# TranSTR + Token Mark (SoM) - Resume Training Support

**Features:**
- Token Mark (SoM) Injection for entity grounding
- DeBERTa text encoder
- W&B logging & checkpoint management
- **Resume training from W&B checkpoint**
- Fix for multiprocessing DataLoader issues

---

## 🔴 RESUME CONFIGURATION

Set `RESUME_FROM_WANDB = True` (in Cell 2) to resume training from a checkpoint stored as a W&B artifact (`RESUME_ARTIFACT_NAME`).

In [None]:
# ==============================================================================
# CELL 1: Git Clone & Setup
# ==============================================================================
import os
import multiprocessing

# Fix multiprocessing for Kaggle/Colab
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    pass  # Already set

REPO_URL = "https://github.com/DanielQH07/tranSTR_Casual.git" 
REPO_NAME = "tranSTR_Casual"
BRANCH = "daniel_setmark"

if not os.path.exists(REPO_NAME):
    print(f"Cloning {REPO_URL}...")
    !git clone {REPO_URL} -b {BRANCH}
else:
    print("Repo already exists.")

# Change Directory
if os.path.basename(os.getcwd()) != "causalvid":
    target_dir = os.path.join(os.getcwd(), REPO_NAME, "causalvid")
    if os.path.exists(target_dir):
        os.chdir(target_dir)
    elif os.path.exists(REPO_NAME):
        os.chdir(REPO_NAME)
print(f"Working directory: {os.getcwd()}")

In [None]:
# ==============================================================================
# CELL 2: W&B Setup
# ==============================================================================
print('=== CELL 2: W&B Setup ===')
!pip install -q wandb --upgrade
import wandb

# ============================================
# üî¥ W&B CONFIG - UPDATE THESE!
# ============================================
WANDB_API_KEY = 'YOUR_WANDB_API_KEY_HERE'  # üî¥ REQUIRED
WANDB_PROJECT = 'transtr-causalvid'
WANDB_ENTITY = None  # Your username or None for default

# ============================================
# üî¥ RESUME SETTINGS
# ============================================
RESUME_FROM_WANDB = False  # Set True to resume from checkpoint
RESUME_ARTIFACT_NAME = 'best-model-som:latest'  # W&B artifact name

# Login
wandb.login(key=WANDB_API_KEY, relogin=True)
print('‚úÖ W&B logged in!')

In [None]:
# ==============================================================================
# CELL 3: Imports
# ==============================================================================
print('=== CELL 3: Imports ===')
import os, torch, numpy as np, pandas as pd, json
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from utils.util import set_seed, set_gpu_devices
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm

print('‚úÖ Imports OK')

In [None]:
# ==============================================================================
# CELL 4: Train/Eval Functions with SoM
# ==============================================================================
print('=== CELL 4: Functions ===')

def train_epoch(model, optimizer, loader, xe, device, epoch, use_som=False):
    """Run one training epoch, optionally forwarding SoM (Token Mark) data.

    Args:
        model: the VideoQA model, called as ``model(ff, of, q, a)`` or
            ``model(ff, of, q, a, som_data=som_data)``.
        optimizer: optimizer stepped once per batch.
        loader: iterable of 7-tuples ``(ff, of, q, a, ans_id, keys, som_data)``
            (the layout produced by ``collate_fn_som``).
        xe: loss function applied as ``xe(logits, targets)``.
        device: device the frame/object features and targets are moved to.
        epoch: epoch index; used for the progress-bar label and W&B global_step.
        use_som: forward ``som_data`` to the model when True and it is not None.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the epoch, or ``(0.0, 0.0)`` if
        the loader yields no batches.
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0
    num_batches = len(loader)

    pbar = tqdm(loader, desc=f'Epoch {epoch}', leave=False)
    for batch_idx, batch in enumerate(pbar):
        ff, of, q, a, ans_id, _, som_data = batch
        ff, of, tgt = ff.to(device), of.to(device), ans_id.to(device)

        if use_som and som_data is not None:
            out = model(ff, of, q, a, som_data=som_data)
        else:
            out = model(ff, of, q, a)

        loss = xe(out, tgt)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Pull scalars to the host once per batch (each .item() is a sync).
        loss_value = loss.item()
        batch_correct = (out.argmax(-1) == tgt).sum().item()
        batch_size = tgt.size(0)

        total_loss += loss_value
        correct += batch_correct
        total += batch_size

        pbar.set_postfix({'loss': total_loss/(batch_idx+1), 'acc': correct/total*100})

        # Log to W&B every 100 batches
        if batch_idx % 100 == 0:
            wandb.log({
                'batch_loss': loss_value,
                'batch_acc': batch_correct / batch_size * 100,
                'global_step': epoch * num_batches + batch_idx
            })

        # Clear cache periodically to prevent OOM
        if batch_idx % 200 == 0:
            torch.cuda.empty_cache()

    if num_batches == 0 or total == 0:
        # Empty loader: avoid ZeroDivisionError.
        return 0.0, 0.0
    return total_loss / num_batches, correct / total * 100

def eval_epoch(model, loader, device, use_som=False):
    """Return the accuracy (percent) of ``model`` over ``loader``.

    Runs under ``torch.no_grad()`` with the model in eval mode. Batches are the
    7-tuples produced by ``collate_fn_som``; ``som_data`` is forwarded to the
    model only when ``use_som`` is True and it is not None.

    Returns 0.0 for a loader that yields no samples (instead of raising
    ZeroDivisionError).
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for ff, of, q, a, ans_id, _, som_data in tqdm(loader, desc='Eval', leave=False):
            if use_som and som_data is not None:
                out = model(ff.to(device), of.to(device), q, a, som_data=som_data)
            else:
                out = model(ff.to(device), of.to(device), q, a)
            correct += (out.argmax(-1) == ans_id.to(device)).sum().item()
            total += ans_id.size(0)
    if total == 0:
        return 0.0
    return correct / total * 100

print('‚úÖ Functions defined!')

In [None]:
# ==============================================================================
# CELL 5: Paths & Config
# ==============================================================================
print('=== CELL 5: Paths & Config ===')

# ============================================
# KAGGLE INPUT PATHS
# ============================================
# Dataset mounts for the precomputed inputs (labels match the checks below).
VIT_FEATURE_PATH = '/kaggle/input/vit-features-full-merged'  # ViT features
OBJ_FEATURE_PATH = '/kaggle/input/object-detection-causal-full'  # object features
ANNOTATION_PATH = '/kaggle/input/text-annotation/QA'  # QA annotations
SPLIT_DIR = '/kaggle/input/casual-vid-data-split/split'  # split lists
SOM_FEATURE_PATH = '/kaggle/input/causal-vqa-object-masks-full/obj_mask_causal_full'  # SoM masks (optional)

# Writable working area; trained models are stored under MODEL_DIR.
BASE = '/kaggle/working'
MODEL_DIR = os.path.join(BASE, 'models')
os.makedirs(MODEL_DIR, exist_ok=True)


def verify_path(name, path):
    """Print whether ``path`` exists, previewing up to three of its entries.

    Returns True if the path exists, False otherwise.
    """
    if not os.path.exists(path):
        print(f'‚ùå {name}: NOT FOUND - {path}')
        return False
    preview = os.listdir(path)[:3]
    print(f'‚úÖ {name}: {preview}')
    return True


print('\n--- Path Verification ---')

# All four core inputs are checked (each status is printed) and all are
# required; the SoM masks are optional and decide whether SoM gets enabled.
all_ok = all([
    verify_path('ViT Features', VIT_FEATURE_PATH),
    verify_path('Object Features', OBJ_FEATURE_PATH),
    verify_path('Annotations', ANNOTATION_PATH),
    verify_path('Splits', SPLIT_DIR),
])
som_ok = verify_path('SoM Masks', SOM_FEATURE_PATH)

if not all_ok:
    print('\n‚ö†Ô∏è Please update paths above!')

# ============================================
# CONFIG
# ============================================
RUN_TRAINING = True  # False skips the training loop in CELL 10; evaluation still runs
MAX_TRAIN_SAMPLES = None  # None = all, or set number for testing
MODEL_FILENAME = 'best_model_som.ckpt'  # best-val-acc model state_dict (saved in CELL 10)
CHECKPOINT_FILENAME = 'training_checkpoint.pt'  # full resume state, rewritten every epoch

class Config:
    """Paths and hyper-parameters shared by the datasets, model and training loop.

    NOTE: CELL 7 turns every public class attribute into a keyword argument for
    VideoQAmodel (via ``Config.__dict__``), so adding, renaming or removing an
    attribute here changes the model constructor call.
    """
    # Paths
    video_feature_root = VIT_FEATURE_PATH
    object_feature_path = OBJ_FEATURE_PATH
    sample_list_path = ANNOTATION_PATH
    split_dir_txt = SPLIT_DIR
    som_feature_path = SOM_FEATURE_PATH if som_ok else None  # None disables SoM loading
    
    # Model architecture
    topK_frame = 16  # frames loaded per sample by the dataset
    objs = 20  # passed to the dataset as obj_num
    frames = 16  # NOTE(review): not used by the visible cells; consumed by VideoQAmodel? confirm
    select_frames = 5  # passed to the model as its topK_frame (frames it selects)
    topK_obj = 12
    frame_feat_dim = 1024  # must be 1024 for the ViT features (see CELL 6.5 check)
    obj_feat_dim = 2053  # presumably 2048-d appearance + 5 box values -- TODO confirm
    d_model = 768
    word_dim = 768
    nheads = 8
    num_encoder_layers = 2
    num_decoder_layers = 2
    normalize_before = True
    activation = 'gelu'
    dropout = 0.3
    encoder_dropout = 0.3
    
    # Token Mark (SoM) Settings
    use_som = som_ok  # enabled only when the SoM mask directory was found
    num_marks = 16
    
    # Text encoder
    text_encoder_type = 'microsoft/deberta-base'
    freeze_text_encoder = False
    text_encoder_lr = 1e-5  # NOTE(review): the optimizer in CELL 7 uses `lr` for all params; confirm the model reads this
    text_pool_mode = 1
    
    # Training
    bs = 8  # batch size for all three DataLoaders
    lr = 1e-5  # AdamW learning rate
    epoch = 20  # last epoch (inclusive) of the training loop
    gpu = 0
    patience = 5  # ReduceLROnPlateau patience (epochs)
    gamma = 0.1  # ReduceLROnPlateau factor
    decay = 1e-4  # AdamW weight decay
    n_query = 5  # presumably answer candidates per question (model output is [B, 5]) -- confirm
    num_workers = 2  # 🔴 Using 2 workers
    
    # Other (forwarded to the dataset/model; semantics defined there)
    hard_eval = False
    pos_ratio = 1.0
    neg_ratio = 1.0
    a = 1.0

args = Config()
set_gpu_devices(args.gpu)
set_seed(999)  # fixed seed; see utils.util.set_seed for which RNGs are seeded
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'\nDevice: {device}')
print(f'Token Mark (SoM): {"ENABLED" if args.use_som else "DISABLED"}')
print(f'Total Epochs: {args.epoch}')
print(f'Num Workers: {args.num_workers}')
print('‚úÖ Config loaded!')

In [None]:
# ==============================================================================
# CELL 6.5: DIAGNOSTIC - Check TokenMark Issues
# ==============================================================================
print('=== CELL 6.5: Diagnostic Check ===')


def run_tokenmark_diagnostics(model, train_ds, train_loader, args, device, num_samples=100):
    """Print sanity checks for the Token Mark (SoM) pipeline.

    Checks, in order:
      1. how many of the first ``num_samples`` training items carry SoM data
         (item index 6), and the structure of the first such item;
      2. config values that must agree between dataset and model;
      3. a no-grad forward pass on one batch from ``train_loader``;
      4. the SoM injector gammas, if the model has ``som_injector``;
      5. agreement between ``args.use_som`` and ``model.use_som``.

    Only prints (returns None). Leaves the model in eval mode; errors in the
    forward-pass check are printed with a traceback instead of raised.
    """
    # 1. Check SoM data availability
    print('\n1. Checking SoM data availability...')
    som_available = 0
    som_missing = 0
    sample_som_data = None

    for i in range(min(num_samples, len(train_ds))):
        som_data = train_ds[i][6]  # Last item is som_data
        if som_data is not None:
            som_available += 1
            if sample_som_data is None:
                sample_som_data = som_data
        else:
            som_missing += 1

    if som_available + som_missing > 0:
        print(f'   SoM available: {som_available}/{som_available+som_missing} ({som_available/(som_available+som_missing)*100:.1f}%)')
        print(f'   SoM missing:   {som_missing}/{som_available+som_missing} ({som_missing/(som_available+som_missing)*100:.1f}%)')
    else:
        # An empty dataset would otherwise raise ZeroDivisionError here.
        print('   No training samples available to check.')

    if sample_som_data:
        print(f'\n   Sample SoM data structure:')
        print(f'     Keys: {list(sample_som_data.keys())}')
        if 'frame_masks' in sample_som_data:
            frame_keys = list(sample_som_data['frame_masks'].keys())
            print(f'     Frame masks: {frame_keys[:5]}... (total: {len(frame_keys)})')
            if frame_keys:
                mask_shape = sample_som_data['frame_masks'][frame_keys[0]].shape
                print(f'     Mask shape: {mask_shape}')
        if 'entity_names' in sample_som_data:
            print(f'     Entity names: {sample_som_data["entity_names"]}')

    # 2. Check config consistency
    print('\n2. Checking config consistency...')
    print(f'   Model use_som: {args.use_som}')
    print(f'   Model frame_feat_dim: {args.frame_feat_dim} (should be 1024 for ViT)')
    print(f'   Model obj_feat_dim: {args.obj_feat_dim}')
    print(f'   Model d_model: {args.d_model}')
    print(f'   Model topK_frame (in model): {args.select_frames} (used for selection)')
    print(f'   DataLoader topK_frame: {args.topK_frame} (used for loading)')
    print(f'   Model topK_obj: {args.topK_obj}')

    if args.frame_feat_dim != 1024:
        print(f'   ‚ö†Ô∏è WARNING: frame_feat_dim={args.frame_feat_dim}, expected 1024 for ViT!')

    # 3. Check model forward pass
    print('\n3. Testing model forward pass...')
    model.eval()
    try:
        batch = next(iter(train_loader))
        ff, of, q, a, ans_id, _, som_data_batch = batch
        ff, of = ff.to(device), of.to(device)

        print(f'   Input shapes:')
        print(f'     ff: {ff.shape} (expected [B, 16, 1024])')
        print(f'     of: {of.shape} (expected [B, 16, 20, 2053])')
        print(f'   SoM data in batch: {[s is not None for s in som_data_batch]}')
        print(f'     Available: {sum(1 for s in som_data_batch if s is not None)}/{len(som_data_batch)}')

        with torch.no_grad():
            if args.use_som:
                out = model(ff, of, q, a, som_data=som_data_batch)
            else:
                out = model(ff, of, q, a)

            print(f'   Output shape: {out.shape} (expected [B, 5])')
            print(f'   Output range: [{out.min():.2f}, {out.max():.2f}]')
            print(f'   Output mean: {out.mean():.2f}')
            print(f'   Output std: {out.std():.2f}')

            # Check if output is reasonable
            if out.std() < 0.1:
                print(f'   ‚ö†Ô∏è WARNING: Output std is very small ({out.std():.2f}), model may not be learning!')
            if torch.isnan(out).any():
                print(f'   ‚ùå ERROR: NaN in output!')
            if torch.isinf(out).any():
                print(f'   ‚ùå ERROR: Inf in output!')

    except Exception as e:
        print(f'   ‚ùå ERROR in forward pass: {e}')
        import traceback
        traceback.print_exc()

    # 4. Check SoM injection parameters
    if args.use_som and hasattr(model, 'som_injector'):
        print('\n4. Checking SoM injection parameters...')
        som = model.som_injector
        print(f'   gamma_frame: {som.gamma_frame.item():.4f}')
        print(f'   gamma_obj: {som.gamma_obj.item():.4f}')
        print(f'   num_marks: {som.num_marks}')

        if abs(som.gamma_frame.item()) < 0.01:
            print(f'   ‚ö†Ô∏è WARNING: gamma_frame is very small, injection may be negligible!')
        if abs(som.gamma_obj.item()) < 0.01:
            print(f'   ‚ö†Ô∏è WARNING: gamma_obj is very small, injection may be negligible!')

    # 5. Check training loop logic
    print('\n5. Checking training loop logic...')
    # getattr: a model without a `use_som` attribute is reported, not a crash.
    model_use_som = getattr(model, 'use_som', None)
    print(f'   Training use_som flag: {args.use_som}')
    print(f'   Model use_som flag: {model_use_som}')
    if args.use_som != model_use_som:
        print(f'   ‚ö†Ô∏è WARNING: Mismatch between training and model use_som flags!')

    print('\n‚úÖ Diagnostic complete!')
    print('\nüîç KEY ISSUES TO CHECK:')
    print('   1. If SoM data is mostly missing, injection won\'t help')
    print('   2. If frame_feat_dim != 1024, model resize layer is wrong size')
    print('   3. If output std is very small, model may not be learning')
    print('   4. If gamma values are too small, SoM injection has no effect')
    print('   5. Check if use_som flags are consistent everywhere')


# This cell sits before CELL 6 (datasets) and CELL 7 (model), so on a
# top-to-bottom run those names do not exist yet. Skip instead of crashing
# with NameError; re-run this cell after CELL 7 to get the diagnostics.
_diag_missing = [name for name in ('args', 'device', 'train_ds', 'train_loader', 'model')
                 if name not in globals()]
if _diag_missing:
    print(f'   Skipping diagnostics: {", ".join(_diag_missing)} not defined yet '
          '(run CELL 6 and CELL 7 first, then re-run this cell).')
else:
    run_tokenmark_diagnostics(model, train_ds, train_loader, args, device)


In [None]:
# ==============================================================================
# CELL 6: Create Datasets with SoM
# ==============================================================================
print('=== CELL 6: Datasets ===')

def collate_fn_som(batch):
    """Collate dataset 7-tuples into one batch, keeping SoM data as a list.

    Frame and object features (items 0 and 1) are stacked along a new batch
    dimension and answer ids (item 4) become a tensor. Questions, answers,
    question keys and SoM entries (which may be None or dicts) are returned as
    plain per-sample lists.
    """
    def column(idx):
        return [sample[idx] for sample in batch]

    frame_feats = torch.stack(column(0))
    obj_feats = torch.stack(column(1))
    answer_ids = torch.tensor(column(4))
    return (frame_feats, obj_feats, column(2), column(3), answer_ids, column(5), column(6))

def _build_split_dataset(split, max_samples):
    """Create the VideoQADataset for ``split`` from the shared Config paths/sizes."""
    return VideoQADataset(
        split=split, n_query=args.n_query, obj_num=args.objs,
        sample_list_path=args.sample_list_path,
        video_feature_path=args.video_feature_root,
        object_feature_path=args.object_feature_path,
        split_dir=args.split_dir_txt, topK_frame=args.topK_frame,
        max_samples=max_samples, verbose=True,
        som_feature_path=args.som_feature_path
    )


def _build_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader that keeps SoM data via collate_fn_som.

    Worker-only options (persistent_workers, prefetch_factor) are passed only
    when worker processes are used: older PyTorch releases raise ValueError for
    an explicit prefetch_factor when num_workers == 0.
    """
    loader_kwargs = {
        'num_workers': args.num_workers,
        'pin_memory': True,
        'collate_fn': collate_fn_som,
    }
    if args.num_workers > 0:
        # Keep workers alive across epochs instead of re-creating them.
        loader_kwargs['persistent_workers'] = True
        loader_kwargs['prefetch_factor'] = 2
    return DataLoader(dataset, args.bs, shuffle=shuffle, **loader_kwargs)


print('Creating TRAIN dataset...')
train_ds = _build_split_dataset('train', MAX_TRAIN_SAMPLES)

print('\nCreating VAL dataset...')
val_ds = _build_split_dataset('val', None)

print('\nCreating TEST dataset...')
test_ds = _build_split_dataset('test', None)

# Only the training split is shuffled.
train_loader = _build_loader(train_ds, shuffle=True)
val_loader = _build_loader(val_ds, shuffle=False)
test_loader = _build_loader(test_ds, shuffle=False)

print('\n' + '='*60)
print('DATASET SUMMARY')
print('='*60)
print(f'Train: {len(train_ds)} samples -> {len(train_loader)} batches')
print(f'Val:   {len(val_ds)} samples -> {len(val_loader)} batches')
print(f'Test:  {len(test_ds)} samples -> {len(test_loader)} batches')
print(f'SoM:   {"ENABLED" if args.som_feature_path else "DISABLED"}')
print('='*60)

In [None]:
# ==============================================================================
# CELL 7: Create Model & Optimizer
# ==============================================================================
print('=== CELL 7: Model ===')

# Every public Config attribute becomes a model kwarg; a few are overridden.
# In particular the model's topK_frame is the number of frames it *selects*
# (select_frames), not the number the dataset loads.
cfg = {name: value for name, value in vars(Config).items() if not name.startswith('_')}
cfg.update(
    device=device,
    topK_frame=args.select_frames,
    use_som=args.use_som,
    num_marks=args.num_marks,
)

model = VideoQAmodel(**cfg)
model.to(device)

# NOTE(review): Config.text_encoder_lr is not applied here -- every parameter,
# including the text encoder's, uses args.lr. Confirm whether VideoQAmodel
# handles a separate text-encoder learning rate itself.
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.decay)
# The scheduler is stepped with validation accuracy, hence mode 'max'.
scheduler = ReduceLROnPlateau(optimizer, 'max', factor=args.gamma, patience=args.patience)
xe = nn.CrossEntropyLoss()

save_path = os.path.join(MODEL_DIR, MODEL_FILENAME)
checkpoint_path = os.path.join(MODEL_DIR, CHECKPOINT_FILENAME)

total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

print(f'Total params:     {total_params/1e6:.1f}M')
print(f'Trainable params: {trainable_params/1e6:.1f}M')
print(f'SoM Enabled:      {args.use_som}')

In [None]:
# ==============================================================================
# CELL 8: Resume from Checkpoint (if enabled)
# ==============================================================================
print('=== CELL 8: Resume Logic ===')

start_epoch = 1
best_acc = 0
history = {'train_loss': [], 'train_acc': [], 'val_acc': []}


def _restore_training_state(checkpoint, fallback_history):
    """Load a full training checkpoint dict into model, optimizer and scheduler.

    ``checkpoint`` has the layout written every epoch by the training loop
    (keys 'epoch', 'model_state_dict', 'optimizer_state_dict',
    'scheduler_state_dict', 'best_acc' and optionally 'history').

    Returns ``(start_epoch, best_acc, history)`` for continuing training.
    """
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return (checkpoint['epoch'] + 1,
            checkpoint['best_acc'],
            checkpoint.get('history', fallback_history))


def _resume_from_wandb_artifact(artifact_name):
    """Download ``artifact_name`` from W&B and restore state from its checkpoint file.

    Handles both artifact layouts the training loop uploads:
      * 'best-model-som' (type 'model'): a bare model state_dict. The resume
        epoch and best accuracy come from the artifact metadata; optimizer and
        scheduler keep their fresh state.
      * 'checkpoint-epN' (type 'checkpoint'): a full training checkpoint, which
        also restores optimizer, scheduler and history.

    Returns ``(start_epoch, best_acc, history)``, or None if the artifact holds
    no .ckpt/.pt file. The temporary download run is finished even on error.
    """
    temp_run = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        job_type='download',
        reinit=True
    )
    try:
        # No `type=` filter so both 'model' and 'checkpoint' artifacts work.
        artifact = temp_run.use_artifact(artifact_name)
        artifact_dir = artifact.download()

        # sorted(): os.listdir order is arbitrary; pick the file deterministically.
        ckpt_files = sorted(f for f in os.listdir(artifact_dir) if f.endswith(('.ckpt', '.pt')))
        if not ckpt_files:
            print(f'No .ckpt/.pt file found in artifact directory: {artifact_dir}')
            return None
        ckpt_path = os.path.join(artifact_dir, ckpt_files[0])
        state = torch.load(ckpt_path, map_location=device)
        fresh_history = {'train_loss': [], 'train_acc': [], 'val_acc': []}

        if isinstance(state, dict) and 'model_state_dict' in state:
            resumed = _restore_training_state(state, fresh_history)
            print(f'‚úÖ Model weights loaded from: {ckpt_path}')
            print(f'‚úÖ Resume from epoch {resumed[0]}, best_acc={resumed[1]:.2f}%')
            return resumed

        # Bare state_dict (best-model artifact).
        model.load_state_dict(state)
        print(f'‚úÖ Model weights loaded from: {ckpt_path}')

        resumed_epoch, resumed_best = 1, 0
        if artifact.metadata:
            resumed_epoch = artifact.metadata.get('epoch', 0) + 1
            resumed_best = artifact.metadata.get('val_acc', 0)
            print(f'‚úÖ Resume from epoch {resumed_epoch}, best_acc={resumed_best:.2f}%')
        return resumed_epoch, resumed_best, fresh_history
    finally:
        temp_run.finish()


# ============================================
# RESUME FROM W&B CHECKPOINT
# ============================================
if RESUME_FROM_WANDB:
    print(f'\nüîÑ Resuming from W&B artifact: {RESUME_ARTIFACT_NAME}')
    try:
        resumed = _resume_from_wandb_artifact(RESUME_ARTIFACT_NAME)
    except Exception as e:
        # NOTE: if load_state_dict failed midway the model may be partially
        # loaded; the counters below stay at their fresh-start values.
        print(f'‚ö†Ô∏è Could not resume from W&B: {e}')
        print('Starting fresh training...')
        resumed = None
    if resumed is not None:
        start_epoch, best_acc, history = resumed

# ============================================
# CHECK LOCAL CHECKPOINT (fallback)
# ============================================
if not RESUME_FROM_WANDB and os.path.exists(checkpoint_path):
    print(f'\nüîÑ Found local checkpoint: {checkpoint_path}')
    try:
        user_input = input('Load local checkpoint? (y/n): ').strip().lower()
    except (EOFError, NotImplementedError):
        # Non-interactive execution (e.g. batch "Save & Run All"): there is no
        # stdin, and IPython raises StdinNotImplementedError (a NotImplementedError).
        print('No interactive input available; not loading the local checkpoint.')
        user_input = 'n'

    if user_input == 'y':
        try:
            checkpoint = torch.load(checkpoint_path, map_location=device)
            # Counters are only updated once every piece of state loaded.
            start_epoch, best_acc, history = _restore_training_state(checkpoint, history)
            print(f'‚úÖ Resumed from local checkpoint at epoch {start_epoch}')
        except Exception as e:
            print(f'‚ö†Ô∏è Could not load local checkpoint: {e}')

print(f'\nüìå Training will start from epoch {start_epoch}')
print(f'üìå Best accuracy so far: {best_acc:.2f}%')

In [None]:
# ==============================================================================
# CELL 9: Initialize W&B Training Run
# ==============================================================================
print('=== CELL 9: Init W&B Run ===')

# Hyper-parameters, dataset sizes and resume info recorded with the run.
wandb_config = dict(
    model='TranSTR-SoM',
    text_encoder=args.text_encoder_type,
    batch_size=args.bs,
    learning_rate=args.lr,
    total_epochs=args.epoch,
    start_epoch=start_epoch,
    max_train_samples=MAX_TRAIN_SAMPLES,
    d_model=args.d_model,
    nheads=args.nheads,
    num_encoder_layers=args.num_encoder_layers,
    num_decoder_layers=args.num_decoder_layers,
    dropout=args.dropout,
    topK_frame=args.select_frames,
    topK_obj=args.topK_obj,
    train_samples=len(train_ds),
    val_samples=len(val_ds),
    test_samples=len(test_ds),
    use_som=args.use_som,
    num_marks=args.num_marks,
    num_workers=args.num_workers,
    resumed=start_epoch > 1,
)

run_name = f'transtr-som-ep{start_epoch}-{args.epoch}'

# NOTE(review): resume='allow' only continues an existing W&B run when a run id
# is supplied (init(id=...) or the WANDB_RUN_ID environment variable);
# otherwise every execution of this cell starts a new run.
run = wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    config=wandb_config,
    name=run_name,
    resume='allow',
    reinit=True,
)

# Track gradient histograms (log_freq=200).
wandb.watch(model, log='gradients', log_freq=200)
print(f'‚úÖ W&B Run: {run.url}')

In [None]:
# ==============================================================================
# CELL 10: Training Loop with Resume Support
# ==============================================================================
print('=== CELL 10: Training ===')

if RUN_TRAINING:
    print('\n' + '='*60)
    print('üöÄ STARTING TRAINING')
    print('='*60)
    print(f'Epochs: {start_epoch} to {args.epoch}')
    print(f'Best Val Acc: {best_acc:.2f}%')
    print('='*60)

    # `ep` is read after the loop for the W&B summary. Pre-set it so that a
    # resume with nothing left to train (start_epoch > args.epoch, the loop
    # body never runs) reports 0 epochs trained instead of raising NameError.
    ep = start_epoch - 1
    if start_epoch > args.epoch:
        print(f'Nothing to train: start_epoch={start_epoch} > total epochs={args.epoch}')

    for ep in range(start_epoch, args.epoch + 1):
        print(f'\nEpoch {ep}/{args.epoch}')

        # Train
        loss, train_acc = train_epoch(
            model, optimizer, train_loader, xe, device, ep,
            use_som=args.use_som
        )

        # Validate; the plateau scheduler tracks validation accuracy.
        val_acc = eval_epoch(model, val_loader, device, use_som=args.use_som)
        scheduler.step(val_acc)

        # Update history
        history['train_loss'].append(loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)

        # Log to W&B
        wandb.log({
            'epoch': ep,
            'train_loss': loss,
            'train_acc': train_acc,
            'val_acc': val_acc,
            'learning_rate': optimizer.param_groups[0]['lr'],
            'best_val_acc': max(best_acc, val_acc)
        })

        print(f'Loss: {loss:.4f} | Train: {train_acc:.1f}% | Val: {val_acc:.1f}%')

        # Save best model (weights only) and upload it as a 'model' artifact.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print(f'‚úÖ New best! Saved to {save_path}')

            # Upload to W&B
            artifact = wandb.Artifact(
                name='best-model-som',
                type='model',
                description=f'Best TranSTR-SoM at epoch {ep}',
                metadata={
                    'epoch': ep,
                    'val_acc': val_acc,
                    'train_acc': train_acc,
                    'train_loss': loss,
                    'use_som': args.use_som
                }
            )
            artifact.add_file(save_path)
            wandb.log_artifact(artifact)
            print('üì§ Uploaded to W&B!')

        # Save full checkpoint EVERY epoch for resume
        torch.save({
            'epoch': ep,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_acc': best_acc,
            'history': history,
        }, checkpoint_path)
        print(f'üíæ Checkpoint saved')

        # Upload checkpoint artifact every 5 epochs (and after the last one)
        if ep % 5 == 0 or ep == args.epoch:
            ckpt_artifact = wandb.Artifact(
                name=f'checkpoint-ep{ep}',
                type='checkpoint',
                description=f'Training checkpoint at epoch {ep}',
                metadata={'epoch': ep, 'val_acc': val_acc, 'best_acc': best_acc}
            )
            ckpt_artifact.add_file(checkpoint_path)
            wandb.log_artifact(ckpt_artifact)
            print(f'üì§ Checkpoint artifact uploaded!')

        # Clear cache
        torch.cuda.empty_cache()

    print(f'\nüèÜ Training Complete! Best Val Accuracy: {best_acc:.1f}%')

    # Final summary
    wandb.run.summary['best_val_acc'] = best_acc
    wandb.run.summary['final_epoch'] = ep
    wandb.run.summary['total_epochs_trained'] = ep - start_epoch + 1
    wandb.run.summary['use_som'] = args.use_som

else:
    print('Skipping training (RUN_TRAINING=False)')

In [None]:
# ==============================================================================
# CELL 11: Detailed Evaluation
# ==============================================================================
print('=== CELL 11: Detailed Evaluation ===')
# NOTE(review): `sns` is not referenced by any visible cell; this import looks
# removable -- confirm no later plotting code relies on it.
import seaborn as sns

def parse_question_key(key):
    """Split a question key ``<video_id>_<qtype>`` into ``(qtype, video_id)``.

    The question type is read from the end of the key. For reason questions it
    is 'predictive_reason' or 'counterfactual_reason' (or the last two
    '_'-separated parts for any other '*_reason' key); otherwise it is the part
    after the last '_'. The video id is everything before that '_<qtype>'
    suffix, so video ids that themselves contain underscores stay intact and
    answer/reason questions of the same video pair up correctly.
    A key without '_' yields ``('unknown', key)``.
    """
    if key.endswith('_reason'):
        if '_predictive_reason' in key:
            qtype = 'predictive_reason'
        elif '_counterfactual_reason' in key:
            qtype = 'counterfactual_reason'
        else:
            parts = key.rsplit('_', 2)
            qtype = '_'.join(parts[1:]) if len(parts) > 1 else 'unknown'
    else:
        parts = key.rsplit('_', 1)
        qtype = parts[1] if len(parts) == 2 else 'unknown'

    suffix = '_' + qtype
    if qtype != 'unknown' and key.endswith(suffix):
        video_id = key[:-len(suffix)]
    else:
        # Fallback for keys whose suffix could not be identified.
        video_id = key.split('_')[0] if '_' in key else key
    return qtype, video_id


def evaluate_detailed(model, loader, device, use_som=False, split_name='val', log_to_wandb=True):
    """Evaluate ``model`` on ``loader`` with per-question-type accuracy.

    Reports accuracy for each question type in ``metrics_map``, the paired
    metrics PAR / CAR (answer AND reason correct for the same video), and
    ``Acc_ALL = (Description + Explanation + PAR + CAR) / 4``.

    Args:
        model: the VideoQA model.
        loader: iterable of 7-tuples as produced by ``collate_fn_som``.
        device: device for the frame/object features.
        use_som: forward SoM data to the model when True and it is not None.
        split_name: label used in printed output and W&B metric names.
        log_to_wandb: log the summary metrics to the active W&B run.

    Returns:
        ``(metrics, type_results)``: ``metrics`` maps metric name to accuracy in
        percent; ``type_results`` maps qtype to a list of dicts with keys
        'video_id', 'pred', 'target', 'correct'.
    """
    model.eval()
    type_results = {}

    print(f"\nüìä Running {split_name.upper()} Evaluation...")
    torch.cuda.empty_cache()

    with torch.no_grad():
        for batch in tqdm(loader, desc=f'Eval {split_name}'):
            ff, of, qns, ans_word, ans_id, qns_keys, som_data = batch
            ff, of = ff.to(device), of.to(device)

            if use_som and som_data is not None:
                out = model(ff, of, qns, ans_word, som_data=som_data)
            else:
                out = model(ff, of, qns, ans_word)

            preds = out.argmax(dim=-1).cpu().numpy()
            targets = ans_id.numpy()

            # Free GPU memory eagerly between batches.
            del out, ff, of
            torch.cuda.empty_cache()

            for key, pred, target in zip(qns_keys, preds, targets):
                qtype, video_id = parse_question_key(key)
                type_results.setdefault(qtype, []).append({
                    'video_id': video_id,
                    'pred': int(pred),
                    'target': int(target),
                    'correct': int(pred) == int(target)
                })

    # Calculate metrics
    metrics = {}
    metrics_map = {
        'Description': 'descriptive',
        'Explanation': 'explanatory',
        'Predictive-Answer': 'predictive',
        'Predictive-Reason': 'predictive_reason',
        'Counterfactual-Answer': 'counterfactual',
        'Counterfactual-Reason': 'counterfactual_reason'
    }

    print("\n" + "="*60)
    print(f"EVALUATION RESULTS - {split_name.upper()}")
    print("="*60)

    for name, qtype in metrics_map.items():
        if qtype in type_results:
            results = type_results[qtype]
            correct = sum(1 for r in results if r['correct'])
            total = len(results)
            acc = correct / total * 100 if total > 0 else 0
        else:
            correct, total, acc = 0, 0, 0
        metrics[name] = acc
        print(f"{name:<25} ==>   {acc:.2f}%  ({correct}/{total})")

    print("-" * 60)

    # Hard metrics: a video counts only if both its answer and its reason
    # question were answered correctly.
    def calc_hard_metric(type_ans, type_reason, name):
        if type_ans not in type_results or type_reason not in type_results:
            metrics[name] = 0
            print(f"{name:<25} ==>   0.00%  (0/0 paired)")
            return

        ans_by_vid = {r['video_id']: r['correct'] for r in type_results[type_ans]}
        reason_by_vid = {r['video_id']: r['correct'] for r in type_results[type_reason]}
        common_vids = set(ans_by_vid.keys()) & set(reason_by_vid.keys())

        both_correct = sum(1 for vid in common_vids if ans_by_vid[vid] and reason_by_vid[vid])
        total = len(common_vids)
        acc = both_correct / total * 100 if total > 0 else 0
        metrics[name] = acc
        print(f"{name:<25} ==>   {acc:.2f}%  ({both_correct}/{total} paired)")

    calc_hard_metric('predictive', 'predictive_reason', 'PAR')
    calc_hard_metric('counterfactual', 'counterfactual_reason', 'CAR')

    print("-" * 60)

    # Acc (ALL)
    d_acc = metrics.get('Description', 0)
    e_acc = metrics.get('Explanation', 0)
    par_acc = metrics.get('PAR', 0)
    car_acc = metrics.get('CAR', 0)
    acc_all = (d_acc + e_acc + par_acc + car_acc) / 4
    metrics['Acc_ALL'] = acc_all
    print(f"{'Acc (ALL)':<25} ==>   {acc_all:.2f}%  ((D+E+PAR+CAR)/4)")
    print("="*60)

    # Log to W&B
    if log_to_wandb:
        wandb.log({
            f'{split_name}/Description': metrics['Description'],
            f'{split_name}/Explanation': metrics['Explanation'],
            f'{split_name}/PAR': metrics['PAR'],
            f'{split_name}/CAR': metrics['CAR'],
            f'{split_name}/Acc_ALL': acc_all
        })
        print('üì§ Metrics logged to W&B!')

    return metrics, type_results

# Load best model: evaluate the best checkpoint written during training, if any.
if os.path.exists(save_path):
    model.load_state_dict(torch.load(save_path, map_location=device))
    print(f'Loaded best model from {save_path}')
else:
    # Make it explicit that the metrics below are NOT for a saved best model.
    print(f'No best model found at {save_path}; evaluating the current in-memory weights.')

# Evaluate on VAL
print("\nüìå VALIDATION SET")
val_metrics, val_raw = evaluate_detailed(
    model, val_loader, device, 
    use_som=args.use_som, split_name='val'
)

# Evaluate on TEST
print("\nüìå TEST SET")
test_metrics, test_raw = evaluate_detailed(
    model, test_loader, device, 
    use_som=args.use_som, split_name='test'
)

In [None]:
# ==============================================================================
# CELL 12: Save Results & Finish
# ==============================================================================
print('=== CELL 12: Save & Finish ===')

# Compare VAL vs TEST
print("\n" + "="*70)
print("üìä VALIDATION vs TEST COMPARISON")
print("="*70)
print(f"{'Metric':<25} {'Val':>10} {'Test':>10} {'Diff':>10}")
print("-"*70)
for key in ['Description', 'Explanation', 'PAR', 'CAR', 'Acc_ALL']:
    val_v = val_metrics.get(key, 0)
    test_v = test_metrics.get(key, 0)
    diff = test_v - val_v
    symbol = '‚Üë' if diff > 0 else ('‚Üì' if diff < 0 else '=')
    print(f"{key:<25} {val_v:>9.2f}% {test_v:>9.2f}% {diff:>+9.2f}% {symbol}")
print("="*70)

# Save results
results = {
    'best_val_acc': best_acc,
    'validation': val_metrics,
    'test': test_metrics,
    'use_som': args.use_som,
    'num_marks': args.num_marks,
    'epochs_trained': len(history['train_loss']),
    'history': history
}

with open('final_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)
print('\nüìÅ Saved: final_results.json')

# Log final artifact (results JSON plus the best model, when it exists)
final_artifact = wandb.Artifact(
    name='final-results',
    type='results',
    description='Final evaluation results'
)
final_artifact.add_file('final_results.json')
if os.path.exists(save_path):
    final_artifact.add_file(save_path)
wandb.log_artifact(final_artifact)

# Update summary
wandb.run.summary['test_Acc_ALL'] = test_metrics['Acc_ALL']
wandb.run.summary['test_PAR'] = test_metrics['PAR']
wandb.run.summary['test_CAR'] = test_metrics['CAR']

# Capture the real run URL before finishing (wandb.run is reset by finish());
# the previous hand-built URL showed a placeholder when WANDB_ENTITY was None.
run_url = wandb.run.url

# Finish W&B
wandb.finish()
print('\n‚úÖ Done! Check W&B for full results.')
print(f'View at: {run_url}')