# 🏆 Apex-X Ascension V5 Professional Training Suite

**The production-grade engine for world-class segmentation.**

This notebook provides a complete E2E pipeline for:
1. **Dataset Exploratory Data Analysis (EDA)**: Understand your data bias and statistics.
2. **SOTA Ascension V5 Training**: Stable, high-precision training loop.
3. **Production Export (ONNX)**: Deploy your model with dynamic axes support.


## 1. System Setup & Diagnostics

In [None]:
import os, sys, warnings
warnings.filterwarnings('ignore', category=UserWarning, module='IPython')

# 1. Install critical system dependencies first
!pip install pickleshare structlog -q

if not os.path.exists('Apex-X'):
    !git clone https://github.com/Voskan/Apex-X.git
    print('‚úÖ Repository cloned')
else:
    !cd Apex-X && git pull
    print('‚úÖ Repository updated')

%cd Apex-X
!pip install -e . -q
!pip install pycocotools albumentations matplotlib seaborn tqdm transformers timm peft -q
print('\n‚úÖ Environment Ready')

import torch, time
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
# Report the PyTorch build and, when a GPU is visible, its name and total memory in GiB.
print(f"‚úÖ PyTorch: {torch.__version__}")
if torch.cuda.is_available():
    print(f"   Device: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Triton is optional: report its version when importable, otherwise keep going.
# The handlers are explicit so that KeyboardInterrupt / SystemExit are no longer
# swallowed, and a broken install is not misreported as "not found".
try:
    import triton
    print(f"‚úÖ Triton: {triton.__version__}")
except ImportError:
    print("‚ö†Ô∏è Triton not found. Using CPU fallback for Geometrical Branch.")
except Exception as exc:
    print(f"Triton is installed but could not be loaded ({type(exc).__name__}: {exc}).")

## 2. Dataset Statistical Intelligence (EDA)

In [None]:
from apex_x.data import YOLOSegmentationDataset
from torch.utils.data import DataLoader

DATASET_ROOT = '/workspace/YOLO26_Merged'
IMAGE_SIZE = 1024
EDA_MAX_SAMPLES = 500   # upper bound on the number of images inspected for the statistics
EDA_SEED = 0            # fixed seed so the sampled subset (and the plots) are reproducible

ds = YOLOSegmentationDataset(DATASET_ROOT, split='train', image_size=IMAGE_SIZE)
print(f"Analyzing {len(ds)} images...")

class_counts = {}      # class id (int) -> number of instances seen in the sample
instance_counts = []   # number of instances per sampled image
mask_areas = []        # per-instance mask area as a fraction of the mask grid

# Inspect at most EDA_MAX_SAMPLES randomly chosen images (the whole dataset when it is smaller).
eda_rng = np.random.default_rng(EDA_SEED)
sample_indices = eda_rng.choice(len(ds), size=min(EDA_MAX_SAMPLES, len(ds)), replace=False)

for idx in tqdm(sample_indices):
    sample = ds[int(idx)]
    cids = sample.class_ids
    instance_counts.append(len(cids))
    for cid in cids:
        # Normalise to a plain int so equal ids always share one dictionary key,
        # whatever scalar type the dataset yields.
        key = int(cid)
        class_counts[key] = class_counts.get(key, 0) + 1

    if sample.masks is not None:
        masks = sample.masks
        # Relative area is measured against the mask's own spatial size, so it stays
        # correct even if masks are not exactly IMAGE_SIZE x IMAGE_SIZE.
        # assumes masks is laid out as (num_instances, H, W) — TODO confirm
        mask_h, mask_w = masks.shape[-2], masks.shape[-1]
        areas = masks.sum(axis=(1, 2)) / (mask_h * mask_w)
        mask_areas.extend(areas.tolist())

fig, (ax_cls, ax_inst, ax_area) = plt.subplots(1, 3, figsize=(18, 5))

sorted_ids = sorted(class_counts)
ax_cls.bar(sorted_ids, [class_counts[c] for c in sorted_ids])
ax_cls.set_title("Class Distribution")
ax_cls.set_xlabel("Class ID")
ax_cls.set_ylabel("Instances")

ax_inst.hist(instance_counts, bins=20)
ax_inst.set_title("Instances per Image")
ax_inst.set_xlabel("Count")

ax_area.hist(mask_areas, bins=50, color='orange')
ax_area.set_title("Relative Mask Area")
ax_area.set_yscale('log')
plt.show()

## 3. Advanced SOTA Configuration

In [None]:
from apex_x.config import ApexXConfig, ModelConfig, TrainConfig, LossConfig
from apex_x.data import standard_collate_fn
from apex_x.data.transforms import build_robust_transforms
from apex_x.model import TeacherModelV5

# Hyperparameters: each sub-config is built on its own, then assembled into ApexXConfig.
model_cfg = ModelConfig(input_height=IMAGE_SIZE, input_width=IMAGE_SIZE)
train_cfg = TrainConfig(
    batch_size=16,
    epochs=200,
    lr=1e-4,
    weight_decay=0.05,
    grad_accum=16,
)
loss_cfg = LossConfig(
    topological_persistence=True,   # Stability Fix Applied
    flow_symmetry=True,             # Physics-informed boundaries
    self_distillation=True,         # Recursive refinement
)
config = ApexXConfig(model=model_cfg, train=train_cfg, loss=loss_cfg)

# NOTE(review): train_tf is created here but is not handed to `ds` or to the loader
# anywhere in this cell — confirm where these transforms are meant to be applied.
train_tf = build_robust_transforms(IMAGE_SIZE, IMAGE_SIZE)

# Shuffled training loader over the dataset built in the EDA section.
train_loader = DataLoader(
    ds,
    batch_size=config.train.batch_size,
    shuffle=True,
    collate_fn=standard_collate_fn,
    num_workers=4,
)

# Model, optimizer, per-epoch cosine LR schedule and AMP gradient scaler.
model = TeacherModelV5(num_classes=ds.num_classes).to('cuda')
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.train.lr,
    weight_decay=config.train.weight_decay,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.train.epochs)
scaler = torch.amp.GradScaler('cuda')

## 4. Professional Training Loop

In [None]:
from apex_x.train.train_losses_v5 import compute_v5_training_losses
from apex_x.train.validation import validate_epoch

# Not read anywhere yet: reserved for model selection once validation is enabled below.
best_val_loss = float('inf')

# Number of micro-batches whose gradients are accumulated into one optimizer step.
accum_steps = max(1, config.train.grad_accum)
num_batches = len(train_loader)

for epoch in range(config.train.epochs):
    model.train()
    # Start every epoch from clean gradients so nothing leaks across epochs.
    optimizer.zero_grad()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.train.epochs}")

    for i, batch in enumerate(pbar):
        imgs = batch['images'].to('cuda')
        # Move every tensor entry of the batch to the GPU; leave non-tensor entries untouched.
        targets = {k: v.to('cuda') if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

        with torch.amp.autocast('cuda'):
            outputs = model(imgs)
            loss, loss_dict = compute_v5_training_losses(outputs, targets, model, config)

        # Divide by the accumulation length so the accumulated gradient is the mean
        # over the window rather than a sum that grows with grad_accum.
        scaler.scale(loss / accum_steps).backward()

        # Step at the end of every accumulation window, and also on the last batch of
        # the epoch so a trailing partial window is applied instead of being dropped
        # or carried into the next epoch.
        window_complete = (i + 1) % accum_steps == 0
        last_batch = (i + 1) == num_batches
        if window_complete or last_batch:
            # Unscale first so clipping operates on the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # NOTE(review): this filter hides every entry whose key contains 'loss' —
        # confirm that is intended and not an inverted condition.
        # float() accepts both 0-dim tensors and plain Python numbers.
        pbar.set_postfix({k: f"{float(v):.3f}" for k, v in loss_dict.items() if 'loss' not in k})

    scheduler.step()

    # Validation
    # TODO: no val_loader is defined in the cells shown here; build one (it needs
    # standard_collate_fn as well) and then enable the lines below.
    with torch.no_grad():
        # val_metrics = validate_epoch(model, val_loader, device='cuda', loss_fn=compute_v5_training_losses, config=config)
        # print(f"\n‚≠ê Epoch {epoch+1} Results | Val Loss: {val_metrics['val_loss']:.4f}")
        pass

## 5. Production Handoff: ONNX Export

In [None]:
from apex_x.export.onnx_export import export_to_onnx, verify_onnx_model

# Destination file for the exported ONNX graph (a relative path, resolved against the current working directory).
onnx_path = "artifacts/ascension_v5_flagship.onnx"

class ONNXWrapper(torch.nn.Module):
    """Adapter that turns a dict-returning model into a tuple-returning module.

    The wrapped model is expected to return a mapping containing the keys
    ``'boxes'``, ``'scores'`` and ``'masks'``; ``forward`` returns those three
    entries as a tuple, in that order, and ignores any other keys.

    The base class is referenced as ``torch.nn.Module`` because the notebook
    only imports ``torch`` and never binds a bare ``nn`` name.
    """

    def __init__(self, model):
        """Store the model whose dict output will be flattened."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return ``(boxes, scores, masks)``.

        Raises:
            KeyError: if the wrapped model's output lacks one of the three keys.
        """
        out = self.model(x)
        # Flatten DICT to tuple for ONNX compatibility
        return out['boxes'], out['scores'], out['masks']

# Module.cpu() moves parameters in place, so the underlying `model` also ends up on
# the CPU after this line; move it back to the GPU before any further training.
wrapper = ONNXWrapper(model).cpu().eval()

# Make sure the destination folder exists before the exporter writes to it.
onnx_dir = os.path.dirname(onnx_path)
if onnx_dir:
    os.makedirs(onnx_dir, exist_ok=True)

# NOTE(review): the dynamic_axes keys must match the input/output names that
# export_to_onnx assigns to the graph — confirm 'input' / 'output1..3' against it.
export_to_onnx(
    wrapper,
    onnx_path,
    input_shape=(1, 3, IMAGE_SIZE, IMAGE_SIZE),
    opset_version=17,
    dynamic_axes={
        'input': {0: 'batch_size', 2: 'height', 3: 'width'},
        'output1': {0: 'batch_size'}, # boxes
        'output2': {0: 'batch_size'}, # scores
        'output3': {0: 'batch_size'}, # masks
    }
)

print(f"\nüéÅ Ascension V5 ONNX Suite Ready at {onnx_path}")