In [None]:
# Clone the project repository; later cells import DataLoader, networks.model and eval_mc from it.
!git clone https://github.com/DanielQH07/tranSTR_Casual.git

In [None]:
# Run the remaining cells from inside the clone so relative paths (networks/, DataLoader.py, ./models, ./prediction) resolve.
%cd /kaggle/working/tranSTR_Casual

In [None]:
# ============================================================
# VERIFY PATCHED FILES (No need to re-write - already patched)
# ============================================================
# Each entry pairs a repository file with marker strings; the file counts as
# patched when ANY of its markers occurs in its text.
import os

PATCH_MARKERS = [
    ('networks/attention.py', ('DataParallel-safe', 'new_mask')),
    ('networks/model.py', ('repeat_interleave',)),
    ('DataLoader.py', ('dummy_bbox', 'Handle different feature shapes')),
]

print("=" * 70)
print("üîç VERIFYING PATCHED FILES")
print("=" * 70)

# One boolean per file, in PATCH_MARKERS order.
checks = []
for patched_path, markers in PATCH_MARKERS:
    if not os.path.exists(patched_path):
        print(f"  ‚ùå MISSING: {patched_path}")
        checks.append(False)
        continue
    with open(patched_path, 'r') as src:
        text = src.read()
    is_patched = any(marker in text for marker in markers)
    checks.append(is_patched)
    label = "‚úÖ PATCHED" if is_patched else "‚ùå NOT PATCHED"
    print(f"  {label}: {patched_path}")

print("=" * 70)
print("‚úÖ ALL FILES VERIFIED! Ready for training." if all(checks)
      else "‚ùå SOME FILES ARE MISSING OR NOT PATCHED!")
print("=" * 70)

In [None]:
import os
import pickle
import h5py
import json

# ============================================================
# DATA PATHS (Kaggle Input)
# ============================================================
# Roots of the Kaggle input datasets; the Config cell reads these globals.
text_feature_path = '/kaggle/input/text-feature'
visual_feature_path = '/kaggle/input/visual-feature'
split_path = '/kaggle/input/casual-vid-data-split/split'
text_annotation_path = '/kaggle/input/text-annotation'

print("=" * 70)
print("üìÇ DATA PATHS")
print("=" * 70)
# Report whether each input root exists (text_feature_path is not checked).
path_checks = {
    "Visual features": visual_feature_path,
    "Split files": split_path,
    "Text annotations": text_annotation_path,
}
for label, root in path_checks.items():
    mark = "‚úì" if os.path.exists(root) else "‚úó"
    print(f"  {mark} {label}: {root}")

# ============================================================
# DATA STATISTICS
# ============================================================
print("\n" + "=" * 70)
print("üìä DATASET STATISTICS")
print("=" * 70)

# 1. Split files: each pickle is loaded and its length reported as a video count.
print("\nüìÅ Split Files:")
split_stats = {}
for split_name in ('train', 'valid', 'test'):
    split_file = f'{split_path}/{split_name}.pkl'
    if not os.path.exists(split_file):
        continue
    with open(split_file, 'rb') as fh:
        vids = pickle.load(fh)
    video_total = len(vids)
    split_stats[split_name] = video_total
    # 6 question types per video -> 6 samples per video.
    print(f"  {split_name:>6}: {video_total:>6} videos ‚Üí {video_total * 6:>6} samples")

# 2. Visual features: index size and the shape of each feature file.
print("\nüé¨ Visual Features:")
idx2vid_file = f'{visual_feature_path}/idx2vid.pkl'
if os.path.exists(idx2vid_file):
    with open(idx2vid_file, 'rb') as fh:
        idx2vid = pickle.load(fh)
    print(f"  Indexed videos: {len(idx2vid)}")

for feat_name in ('appearance_feat.h5', 'motion_feat.h5'):
    feat_file = f'{visual_feature_path}/{feat_name}'
    if not os.path.exists(feat_file):
        continue
    with h5py.File(feat_file, 'r') as fh:
        # Only the shape of the 'resnet_features' dataset is reported.
        shape = fh['resnet_features'].shape
    print(f"  {feat_name}: {shape}")

# 3. Question types: (index as string, name, example question).
print("\n‚ùì Question Types (qtype):")
qtype_info = [
    (str(type_idx), type_name, example)
    for type_idx, (type_name, example) in enumerate([
        ("Descriptive", "What is happening?"),
        ("Explanatory", "Why did it happen?"),
        ("Predictive-Ans", "What will happen?"),
        ("Predictive-Reason", "Why will it happen?"),
        ("Counterfactual-Ans", "What if X didn't happen?"),
        ("Counterfactual-Reason", "Why would that result?"),
    ])
]
for qt, name, desc in qtype_info:
    print(f"  {qt}: {name:<20} - {desc}")

print("\n" + "=" * 70)

In [None]:
!pip install -q transformers einops h5py wandb

# Login to W&B (uncomment v√† th√™m API key c·ªßa b·∫°n)
my_key = "80b5a02ccaed80f35a2e893aed6446d4467c0c45"
import wandb
wandb.login(key=my_key, relogin=True)

In [None]:
import os
import sys
import torch
import numpy as np
import wandb

# ============================================================
# CONFIGURATION
# ============================================================

class Config:
    """Hyperparameters for one CausalVidQA TranSTR run, read as `args.<name>`."""

    # ---- experiment bookkeeping (W&B project / run name) ----
    project_name = "CausalVidQA-TranSTR"
    run_name = "causalvid_1gpu"

    # ---- data locations ----
    # The right-hand names are the module-level path variables from the data
    # cell: class-body name lookup falls through to globals.
    sample_list_path = split_path
    video_feature_path = visual_feature_path
    text_annotation_path = text_annotation_path

    # ---- optimisation schedule ----
    bs = 2                      # micro-batch size per forward/backward pass
    ACCUM_STEPS = 2             # micro-batches per optimizer step; effective batch = bs * ACCUM_STEPS
    lr = 1e-4                   # learning rate for non-text-encoder parameters
    text_encoder_lr = 1e-5      # smaller learning rate for text-encoder parameters
    epoch = 20                  # maximum number of training epochs
    warmup_epochs = 2           # warmup epochs

    # ---- dataset selection ----
    dataset = 'causal-vid'
    qtype = 3                   # question type to use; -1 = all question types
    max_samples = 2000          # sample cap for train/val; None = use all data

    # ---- transformer architecture ----
    d_model = 768
    word_dim = 768
    nheads = 8
    num_encoder_layers = 1
    num_decoder_layers = 1
    dropout = 0.1
    encoder_dropout = 0.1
    activation = 'relu'
    normalize_before = False

    # ---- video features ----
    objs = 20                   # objects per frame
    topK_frame = 8              # frames kept by the top-K frame selection
    topK_obj = 5                # objects kept by the top-K object selection
    frame_feat_dim = 4096       # appearance (2048) + motion (2048)
    obj_feat_dim = 2053         # object feature (2048) + bbox (5)
    n_query = 5                 # candidate answers per question (5-way multiple choice)

    # ---- text encoder ----
    text_encoder_type = "microsoft/deberta-base"
    freeze_text_encoder = False
    text_pool_mode = 0
    hard_eval = False

    # ---- optimizer / LR scheduler ----
    decay = 0.001               # AdamW weight decay
    patience = 3                # ReduceLROnPlateau patience (epochs)
    gamma = 0.5                 # ReduceLROnPlateau LR decay factor

    # ---- early stopping ----
    early_stopping_patience = 5  # epochs without val improvement before stopping

    # ---- contrastive learning ----
    pos_ratio = 0.7
    neg_ratio = 0.3
    a = 1

    # ---- hardware: single GPU, gradient accumulation instead of DataParallel ----
    use_multi_gpu = False
    num_workers = 0             # DataLoader worker processes

    # ---- logging / checkpointing ----
    log_interval = 50           # print/log every N training batches
    save_every = 5              # write a full checkpoint every N epochs

args = Config()

# ============================================================
# GPU SETUP (Single GPU)
# ============================================================
print("=" * 70)
print("üñ•Ô∏è GPU CONFIGURATION")
print("=" * 70)

# List every visible CUDA device with its total memory in GB.
n_gpus = torch.cuda.device_count()
print(f"  Available GPUs: {n_gpus}")
for gpu_idx in range(n_gpus):
    props = torch.cuda.get_device_properties(gpu_idx)
    print(f"    GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
    print(f"           Memory: {props.total_memory / 1e9:.1f} GB")

effective_bs = args.bs * args.ACCUM_STEPS
print("\n  ‚úì Single GPU mode with Gradient Accumulation")
print(f"  ‚úì Accumulation steps: {args.ACCUM_STEPS}")
print(f"  ‚úì Effective batch size: {effective_bs}")

# All tensors and the model go to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"  Primary device: {device}")

# ============================================================
# PRINT CONFIG
# ============================================================
print("\n" + "=" * 70)
print("‚öôÔ∏è TRAINING CONFIG")
print("=" * 70)
config_items = [
    ("Batch size", args.bs),
    ("Accumulation steps", args.ACCUM_STEPS),
    ("Effective batch size", effective_bs),
    ("Learning rate", args.lr),
    ("Text encoder LR", args.text_encoder_lr),
    ("Epochs", args.epoch),
    ("Early stopping", f"{args.early_stopping_patience} epochs"),
    ("d_model", args.d_model),
    ("TopK frames", args.topK_frame),
    ("TopK objects", args.topK_obj),
    ("Objects/frame", args.objs),
    ("Text encoder", args.text_encoder_type),
]
print("\n".join(f"  {label:<20}: {value}" for label, value in config_items))
print("=" * 70)

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from collections import defaultdict
import time
import json

# Local imports
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel
import eval_mc

# ============================================================
# REPRODUCIBILITY
# ============================================================
def set_seed(seed=999):
    """Seed every random generator the notebook may touch and pin cuDNN.

    Seeds Python's `random` module, NumPy's global RNG, and torch (CPU and all
    CUDA devices). It then sets cuDNN to deterministic mode with autotuning
    (benchmark) disabled.

    Args:
        seed: integer seed applied to every generator.
    """
    # Local import: the notebook's import cells do not bring `random` in, but
    # libraries used here (e.g. tokenizers/augmentation) may draw from it.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(999)
print("‚úÖ Modules imported, seed set to 999")

In [None]:
print("Creating datasets...")

# ============================================================
# CREATE DATASETS
# ============================================================
# Constructor arguments shared by every split.
dataset_kwargs = {
    'n_query': args.n_query,
    'obj_num': args.objs,
    'sample_list_path': args.sample_list_path,
    'video_feature_path': args.video_feature_path,
    'text_annotation_path': args.text_annotation_path,
    'qtype': args.qtype,
    'max_samples': args.max_samples,
}

train_dataset = VideoQADataset(split='train', **dataset_kwargs)
val_dataset = VideoQADataset(split='val', **dataset_kwargs)

# The test set ALWAYS uses the full data (no max_samples limit).
test_kwargs = {**dataset_kwargs, 'max_samples': None}
test_dataset = VideoQADataset(split='test', **test_kwargs)

# ============================================================
# CREATE DATALOADERS
# ============================================================
# Worker processes are only used in multi-GPU mode; single-GPU runs load
# batches in the main process.
n_loader_workers = args.num_workers if args.use_multi_gpu else 0
loader_kwargs = dict(
    batch_size=args.bs,
    num_workers=n_loader_workers,
    # Pinned host memory only helps (and only avoids a warning) with CUDA.
    pin_memory=torch.cuda.is_available(),
)
# prefetch_factor is only valid with worker processes. Key it off the worker
# count actually passed above: keying it off args.num_workers made DataLoader
# raise ValueError when num_workers > 0 but multi-GPU mode was off.
if n_loader_workers > 0:
    loader_kwargs['prefetch_factor'] = 2

train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# ============================================================
# DATASET SUMMARY
# ============================================================
print("\n" + "=" * 70)
print("üìä DATALOADER SUMMARY")
print("=" * 70)
print(f"  {'Split':<10} {'Videos':>10} {'Samples':>10} {'Batches':>10}")
print("  " + " ".join(["-" * 10] * 4))
summary_rows = (
    ("Train", train_dataset, train_loader),
    ("Val", val_dataset, val_loader),
    ("Test (FULL)", test_dataset, test_loader),
)
for name, dataset, loader in summary_rows:
    # Datasets without a `vids` attribute show "?" instead of a video count.
    n_vids = "?"
    if hasattr(dataset, 'vids'):
        n_vids = len(dataset.vids)
    print(f"  {name:<10} {n_vids:>10} {len(dataset):>10} {len(loader):>10}")
print("=" * 70)
print("  ‚ÑπÔ∏è  Test set always uses ALL data regardless of max_samples")

In [None]:
# ============================================================
# VERIFY DATA SAMPLE
# ============================================================
print("üîç Verifying data sample...")

# Peek at the first training batch; an empty loader prints no details.
batch = next(iter(train_loader), None)
if batch is not None:
    vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_key = batch
    print(f"\n  Frame features:  {vid_frame_feat.shape}")
    print(f"  Object features: {vid_obj_feat.shape}")
    print(f"  Batch size:      {len(qns_word)}")
    first_question = qns_word[0][:80]
    first_answer = ans_word[0][0][:60]
    print(f"\n  Sample question: {first_question}...")
    print(f"  Sample answer:   {first_answer}...")
    print(f"  Ground truth:    {ans_id[0].item()}")
    print(f"  Question key:    {qns_key[0]}")

print("\n‚úÖ Data verification complete!")

In [None]:
# ============================================================
# CREATE MODEL (Single GPU - No DataParallel)
# ============================================================
print("üèóÔ∏è Creating model...")

# Every model hyperparameter is forwarded from `args` under the same name;
# only `device` comes from the GPU-setup cell.
model_config = {
    key: getattr(args, key)
    for key in (
        'd_model', 'word_dim', 'encoder_dropout', 'dropout',
        'num_encoder_layers', 'num_decoder_layers', 'nheads',
        'normalize_before', 'activation', 'text_encoder_type',
        'freeze_text_encoder', 'text_pool_mode', 'n_query', 'objs',
        'topK_frame', 'topK_obj', 'hard_eval', 'frame_feat_dim',
        'obj_feat_dim',
    )
}
model_config['device'] = device

model = VideoQAmodel(**model_config)
model.to(device)  # single GPU: no DataParallel wrapper

# ============================================================
# MODEL SUMMARY
# ============================================================
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print("\n" + "=" * 70)
print("üß† MODEL SUMMARY")
print("=" * 70)
print(f"  Total parameters:     {total_params / 1e6:.2f}M")
print(f"  Trainable parameters: {trainable_params / 1e6:.2f}M")
print(f"  Device:               {device}")
print(f"  Gradient Accumulation: {args.ACCUM_STEPS} steps")
print("=" * 70)

In [None]:
# ============================================================
# TRAINING FUNCTIONS WITH GRADIENT ACCUMULATION
# ============================================================

def _optimizer_step(model, optimizer, max_grad_norm):
    """Clip the fully accumulated gradients, apply one update, then reset them."""
    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
    optimizer.step()
    optimizer.zero_grad()


def train_epoch(model, optimizer, train_loader, criterion, device, epoch,
                wandb_run=None, accum_steps=2, max_grad_norm=1.0, log_interval=None):
    """Train `model` for one epoch with gradient accumulation.

    Each batch must unpack as
    (vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys). Question
    keys end in ``_<qtype>``. Each batch loss is divided by `accum_steps` before
    backward. After every `accum_steps` batches, the summed gradients are
    clipped as a whole to `max_grad_norm` and applied with a single optimizer
    step. A trailing partial group is applied at the end of the epoch.

    Args:
        model: module called as model(frame_feat, obj_feat, questions, answers).
        optimizer: optimizer over the model's parameters.
        train_loader: iterable of batches that supports len().
        criterion: loss applied to (logits, targets).
        device: device that feature tensors and targets are moved to.
        epoch: current epoch number (stored in wrong samples, used for W&B steps).
        wandb_run: optional W&B run for periodic batch-level logging.
        accum_steps: batches per optimizer step (must be >= 1).
        max_grad_norm: global gradient-norm clip threshold; None disables clipping.
        log_interval: print/log every N batches; defaults to args.log_interval.

    Returns:
        dict with 'loss' (mean unscaled batch loss), 'acc' (percent),
        'time' (seconds), 'qtype_acc' (per-question-type accuracy, percent)
        and 'wrong_samples' (one record per misclassified sample).

    Raises:
        ValueError: if accum_steps < 1 or train_loader yields no batches.
    """
    if accum_steps < 1:
        raise ValueError(f"accum_steps must be >= 1, got {accum_steps}")
    if log_interval is None:
        log_interval = args.log_interval

    model.train()
    optimizer.zero_grad()

    total_loss = 0.0
    predictions = []
    answers = []
    batch_times = []
    wrong_samples = []  # misclassified training samples of this epoch
    qtype_correct = defaultdict(int)
    qtype_total = defaultdict(int)
    n_batches = 0

    start_time = time.time()

    for batch_idx, inputs in enumerate(train_loader):
        batch_start = time.time()
        n_batches = batch_idx + 1

        vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys = inputs
        vid_frame_feat = vid_frame_feat.to(device)
        vid_obj_feat = vid_obj_feat.to(device)
        ans_targets = ans_id.to(device)

        out = model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)
        loss = criterion(out, ans_targets)

        # Scale so the gradient summed over accum_steps batches equals the
        # gradient of their mean loss.
        (loss / accum_steps).backward()

        # Clip and step only once the accumulation group is complete. Clipping
        # after every micro-batch would rescale a partial sum and bias the update.
        if n_batches % accum_steps == 0:
            _optimizer_step(model, optimizer, max_grad_norm)

        batch_loss = loss.item()  # unscaled loss, used for logging
        total_loss += batch_loss

        pred = torch.argmax(out, dim=1).detach().cpu()
        ans = ans_id.detach().cpu()
        predictions.append(pred)
        answers.append(ans)

        # Per-question-type accuracy and wrong-prediction tracking.
        for i, (qkey, p, a) in enumerate(zip(qns_keys, pred.numpy(), ans.numpy())):
            qtype = int(qkey.split('_')[-1])
            qtype_total[qtype] += 1
            if p == a:
                qtype_correct[qtype] += 1
            else:
                wrong_samples.append({
                    "qid": qkey,
                    "prediction": int(p),
                    "answer": int(a),
                    "question": qns_w[i] if isinstance(qns_w, list) else qns_w,
                    "epoch": epoch
                })

        batch_times.append(time.time() - batch_start)

        if n_batches % log_interval == 0:
            avg_loss = total_loss / n_batches
            avg_time = np.mean(batch_times[-log_interval:])
            print(f"    Batch [{n_batches:>4}/{len(train_loader)}] "
                  f"Loss: {batch_loss:.4f} (avg: {avg_loss:.4f}) "
                  f"Time: {avg_time:.3f}s/batch")

            if wandb_run:
                wandb_run.log({
                    "train/batch_loss": batch_loss,
                    "train/avg_loss": avg_loss,
                    "train/batch_time": avg_time,
                }, step=epoch * len(train_loader) + batch_idx)

    if n_batches == 0:
        raise ValueError("train_loader yielded no batches; cannot train an epoch")

    # Apply the leftover gradients of an incomplete final accumulation group.
    if n_batches % accum_steps != 0:
        _optimizer_step(model, optimizer, max_grad_norm)

    # Epoch metrics
    all_preds = torch.cat(predictions, dim=0).long()
    all_ans = torch.cat(answers, dim=0).long()
    epoch_acc = (all_preds == all_ans).sum().item() * 100.0 / len(all_ans)
    epoch_loss = total_loss / n_batches
    epoch_time = time.time() - start_time

    qtype_acc = {}
    qtype_names = ['Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R']
    for qt in range(6):
        if qtype_total[qt] > 0:
            qtype_acc[qtype_names[qt]] = qtype_correct[qt] * 100.0 / qtype_total[qt]

    return {
        'loss': epoch_loss,
        'acc': epoch_acc,
        'time': epoch_time,
        'qtype_acc': qtype_acc,
        'wrong_samples': wrong_samples
    }


def evaluate(model, data_loader, device, split_name='val'):
    """Score `model` on `data_loader` without gradient tracking.

    Batches must unpack as
    (frame_feat, obj_feat, qns_w, ans_w, ans_id, qns_keys); question keys end
    in ``_<qtype>``. `split_name` is accepted for readability at call sites
    but is not used.

    Returns:
        dict with overall 'acc' (percent), per-question-type 'qtype_acc'
        (percent), 'n_samples', and 'wrong_samples' (one record per
        misclassified question).
    """
    model.eval()

    type_labels = ['Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R']
    hits = defaultdict(int)
    seen = defaultdict(int)
    pred_chunks, gold_chunks, wrong_samples = [], [], []

    with torch.no_grad():
        for frame_feat, obj_feat, qns_w, ans_w, ans_id, qns_keys in data_loader:
            logits = model(frame_feat.to(device), obj_feat.to(device), qns_w, ans_w)
            pred = logits.max(-1)[1].cpu()

            pred_chunks.append(pred)
            gold_chunks.append(ans_id)

            for qkey, p, a, question in zip(qns_keys, pred.numpy(), ans_id.numpy(), qns_w):
                qtype = int(qkey.split('_')[-1])
                seen[qtype] += 1
                if p == a:
                    hits[qtype] += 1
                    continue
                wrong_samples.append({
                    "qid": qkey,
                    "video_id": qkey.rsplit('_', 1)[0],
                    "qtype": qtype,
                    "question": question,
                    "prediction": int(p),
                    "answer": int(a),
                })

    all_preds = torch.cat(pred_chunks).long()
    all_gold = torch.cat(gold_chunks).long()
    overall_acc = (all_preds == all_gold).sum().item() * 100.0 / len(all_gold)

    qtype_acc = {
        type_labels[qt]: hits[qt] * 100.0 / seen[qt]
        for qt in range(6)
        if seen[qt] > 0
    }

    return {
        'acc': overall_acc,
        'qtype_acc': qtype_acc,
        'n_samples': len(all_gold),
        'wrong_samples': wrong_samples
    }


def predict_and_save(model, data_loader, device, save_path):
    """Predict every sample and write ``{qid: {prediction, answer}}`` as JSON.

    Batches must unpack as
    (frame_feat, obj_feat, qns_w, ans_w, ans_id, qns_keys); question ids end
    in ``_<qtype>``.

    Returns:
        (results, acc, wrong_samples): the dict written to `save_path`,
        accuracy in percent over its entries, and one record per
        misclassified question.
    """
    model.eval()
    results = {}
    wrong_samples = []

    with torch.no_grad():
        for frame_feat, obj_feat, qns_w, ans_w, ans_id, qns_keys in data_loader:
            logits = model(frame_feat.to(device), obj_feat.to(device), qns_w, ans_w)
            pred = logits.max(-1)[1].cpu()

            for qid, p, a, question in zip(qns_keys, pred.numpy(), ans_id.numpy(), qns_w):
                pred_id, gold_id = int(p), int(a)
                results[qid] = {'prediction': pred_id, 'answer': gold_id}
                if pred_id == gold_id:
                    continue
                wrong_samples.append({
                    "qid": qid,
                    "video_id": qid.rsplit('_', 1)[0],
                    "qtype": int(qid.split('_')[-1]),
                    "question": question,
                    "prediction": pred_id,
                    "answer": gold_id,
                })

    with open(save_path, 'w') as f:
        json.dump(results, f, indent=2)

    # Accuracy is computed over `results`, i.e. once per unique qid.
    n_correct = sum(entry['prediction'] == entry['answer'] for entry in results.values())
    acc = n_correct * 100.0 / len(results)

    return results, acc, wrong_samples


print("‚úÖ Training functions defined with gradient accumulation and wrong sample tracking")

In [None]:
# ============================================================
# SETUP OPTIMIZER, SCHEDULER, CRITERION
# ============================================================
os.makedirs('./models', exist_ok=True)
os.makedirs('./prediction', exist_ok=True)

# Two parameter groups: parameters whose name contains "text_encoder" get the
# smaller text-encoder LR, everything else the main LR. Frozen parameters
# (requires_grad=False) are left out of both.
main_params = [p for n, p in model.named_parameters()
               if "text_encoder" not in n and p.requires_grad]
text_encoder_params = [p for n, p in model.named_parameters()
                       if "text_encoder" in n and p.requires_grad]

param_groups = [{"params": main_params, "lr": args.lr}]
# With freeze_text_encoder=True the text-encoder group would be empty; skip it
# instead of handing the optimizer an empty group. The main group stays first,
# so optimizer.param_groups[0] is always the main LR read by the training loop.
if text_encoder_params:
    param_groups.append({"params": text_encoder_params, "lr": args.text_encoder_lr})

optimizer = torch.optim.AdamW(param_groups, weight_decay=args.decay)
# mode='max': the scheduler is stepped with validation accuracy. The deprecated
# `verbose` flag is not passed; the training loop prints the LR every epoch.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.gamma,
                              patience=args.patience)
criterion = nn.CrossEntropyLoss()

print("‚úÖ Optimizer and scheduler created")
print(f"   Main LR: {args.lr}")
print(f"   Text encoder LR: {args.text_encoder_lr}")
print(f"   Gradient Accumulation: {args.ACCUM_STEPS} steps")

In [None]:
# ============================================================
# INITIALIZE WANDB
# ============================================================
# The run config records every setting that changes what is trained, so runs
# can be told apart: accumulation, question-type filter, sample cap,
# weight decay and early-stopping patience included.
wandb_config = {
    "architecture": "TranSTR",
    "dataset": "CausalVidQA",
    "epochs": args.epoch,
    "batch_size": args.bs,
    "learning_rate": args.lr,
    "text_encoder_lr": args.text_encoder_lr,
    "text_encoder": args.text_encoder_type,
    "d_model": args.d_model,
    "topK_frame": args.topK_frame,
    "topK_obj": args.topK_obj,
    "n_objects": args.objs,
    "num_encoder_layers": args.num_encoder_layers,
    "num_decoder_layers": args.num_decoder_layers,
    "multi_gpu": args.use_multi_gpu,
    "n_gpus": torch.cuda.device_count(),
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset),
    "test_samples": len(test_dataset),
    "accum_steps": args.ACCUM_STEPS,
    "effective_batch_size": args.bs * args.ACCUM_STEPS,
    "qtype": args.qtype,
    "max_samples": args.max_samples,
    "weight_decay": args.decay,
    "early_stopping_patience": args.early_stopping_patience,
}

run = wandb.init(
    project=args.project_name,
    name=args.run_name,
    config=wandb_config,
    tags=["causalvid", "multi-gpu" if args.use_multi_gpu else "single-gpu"]
)

# Log dataset info (0 when a dataset exposes no `vids` attribute)
wandb.log({
    "data/train_videos": len(train_dataset.vids) if hasattr(train_dataset, 'vids') else 0,
    "data/val_videos": len(val_dataset.vids) if hasattr(val_dataset, 'vids') else 0,
    "data/test_videos": len(test_dataset.vids) if hasattr(test_dataset, 'vids') else 0,
})

print(f"‚úÖ W&B initialized: {run.url}")

In [None]:
# ============================================================
# TRAINING LOOP WITH GRADIENT ACCUMULATION & EARLY STOPPING
# ============================================================
# After every epoch the model is scored on val and test. Val accuracy drives
# the LR scheduler, best-model selection and early stopping; test accuracy is
# only reported.

# Start below any reachable accuracy so epoch 1 is always recorded as best and
# a best-model checkpoint is guaranteed to exist for the final-evaluation cell
# (with 0.0 and a strict '>', a 0% first epoch would save nothing).
best_val_acc = float('-inf')
best_epoch = 1
best_model_path = f'./models/best_model-{args.run_name}.ckpt'
history = {'train': [], 'val': [], 'test': []}
epochs_without_improvement = 0

# Wrong predictions from the best epoch so far (replaced whenever val improves).
all_wrong_samples = {
    'train': [],
    'val': [],
    'test': []
}

print("=" * 70)
print(f"üöÄ STARTING TRAINING: {args.run_name}")
print(f"   Epochs: {args.epoch} | Batch size: {args.bs} | Accum steps: {args.ACCUM_STEPS}")
print(f"   Effective batch size: {args.bs * args.ACCUM_STEPS}")
print(f"   Early stopping: {args.early_stopping_patience} epochs without improvement")
print("=" * 70)

for epoch in range(1, args.epoch + 1):
    print(f"\n{'='*70}")
    print(f"üìö EPOCH [{epoch}/{args.epoch}]")
    print(f"{'='*70}")

    # ============ TRAIN with gradient accumulation ============
    train_metrics = train_epoch(
        model, optimizer, train_loader, criterion, device, epoch,
        wandb_run=run, accum_steps=args.ACCUM_STEPS
    )

    # ============ EVALUATE ============
    val_metrics = evaluate(model, val_loader, device, 'val')
    test_metrics = evaluate(model, test_loader, device, 'test')

    # ============ UPDATE SCHEDULER (on val accuracy) ============
    scheduler.step(val_metrics['acc'])
    current_lr = optimizer.param_groups[0]['lr']

    # ============ SAVE BEST MODEL & EARLY STOPPING ============
    is_best = val_metrics['acc'] > best_val_acc
    if is_best:
        best_val_acc = val_metrics['acc']
        best_epoch = epoch
        epochs_without_improvement = 0
        torch.save(model.state_dict(), best_model_path)

        # Keep the wrong samples of the best epoch
        all_wrong_samples['train'] = train_metrics['wrong_samples']
        all_wrong_samples['val'] = val_metrics['wrong_samples']
        all_wrong_samples['test'] = test_metrics['wrong_samples']
    else:
        epochs_without_improvement += 1

    # ============ LOGGING ============
    print(f"\n  üìä Results:")
    print(f"     {'Metric':<15} {'Train':>10} {'Val':>10} {'Test':>10}")
    print(f"     {'-'*15} {'-'*10} {'-'*10} {'-'*10}")
    print(f"     {'Loss':<15} {train_metrics['loss']:>10.4f} {'-':>10} {'-':>10}")
    print(f"     {'Accuracy':<15} {train_metrics['acc']:>9.2f}% {val_metrics['acc']:>9.2f}% {test_metrics['acc']:>9.2f}%")
    print(f"     {'Wrong samples':<15} {len(train_metrics['wrong_samples']):>10} {len(val_metrics['wrong_samples']):>10} {len(test_metrics['wrong_samples']):>10}")

    # Per question type accuracy
    print(f"\n  üìà Per Question Type Accuracy (Val):")
    qtype_order = ['Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R']
    for qt in qtype_order:
        if qt in val_metrics['qtype_acc']:
            print(f"     {qt:<10}: {val_metrics['qtype_acc'][qt]:>6.2f}%")

    print(f"\n  ‚è±Ô∏è  Time: {train_metrics['time']:.1f}s | LR: {current_lr:.2e}")
    print(f"  üìâ No improvement: {epochs_without_improvement}/{args.early_stopping_patience} epochs")
    if is_best:
        print(f"  üíæ Saved best model! (Val acc: {best_val_acc:.2f}%)")

    # ============ WANDB LOGGING ============
    wandb_log = {
        "epoch": epoch,
        "train/loss": train_metrics['loss'],
        "train/acc": train_metrics['acc'],
        "train/wrong_count": len(train_metrics['wrong_samples']),
        "val/acc": val_metrics['acc'],
        "val/wrong_count": len(val_metrics['wrong_samples']),
        "test/acc": test_metrics['acc'],
        "test/wrong_count": len(test_metrics['wrong_samples']),
        "lr": current_lr,
        "epoch_time": train_metrics['time'],
        "best_val_acc": best_val_acc,
        "epochs_without_improvement": epochs_without_improvement,
    }

    # Per question type accuracy
    for qt, acc in train_metrics['qtype_acc'].items():
        wandb_log[f"train/acc_{qt}"] = acc
    for qt, acc in val_metrics['qtype_acc'].items():
        wandb_log[f"val/acc_{qt}"] = acc
    for qt, acc in test_metrics['qtype_acc'].items():
        wandb_log[f"test/acc_{qt}"] = acc

    wandb.log(wandb_log)

    # Save a full checkpoint every N epochs. Scheduler and early-stopping
    # state are included so training can be resumed from it.
    if epoch % args.save_every == 0:
        ckpt_path = f'./models/checkpoint-{args.run_name}-ep{epoch}.ckpt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_metrics['acc'],
            'best_val_acc': best_val_acc,
            'best_epoch': best_epoch,
            'epochs_without_improvement': epochs_without_improvement,
        }, ckpt_path)
        print(f"  üìÅ Checkpoint saved: {ckpt_path}")

    # ============ EARLY STOPPING CHECK ============
    if epochs_without_improvement >= args.early_stopping_patience:
        print(f"\n  ‚ö†Ô∏è EARLY STOPPING: No improvement for {args.early_stopping_patience} epochs")
        print(f"     Best val acc: {best_val_acc:.2f}% at epoch {best_epoch}")
        wandb.log({"early_stopped": True, "stopped_at_epoch": epoch})
        break

# ============================================================
# TRAINING COMPLETE
# ============================================================
print("\n" + "=" * 70)
print("‚úÖ TRAINING COMPLETED!")
print("=" * 70)
print(f"   Best epoch: {best_epoch}")
print(f"   Best val accuracy: {best_val_acc:.2f}%")
print(f"   Model saved: {best_model_path}")
print(f"   Wrong samples tracked: Train={len(all_wrong_samples['train'])}, Val={len(all_wrong_samples['val'])}, Test={len(all_wrong_samples['test'])}")
if epochs_without_improvement >= args.early_stopping_patience:
    print(f"   Stopped early at epoch {epoch}")
print("=" * 70)

In [None]:
# ============================================================
# FINAL EVALUATION WITH BEST MODEL
# ============================================================
print("=" * 70)
print("üìä FINAL EVALUATION")
print("=" * 70)

# Load best model. The file holds a plain state_dict, so weights_only=True
# loads it without unpickling arbitrary objects. map_location puts the
# tensors on the current device even if the checkpoint came from another one.
print("\n  Loading best model...")
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))

# Predict on the full test set with wrong-sample tracking
result_path = f'./prediction/{args.run_name}-ep{best_epoch}-val{best_val_acc:.2f}.json'
results, test_acc, test_wrong_samples = predict_and_save(model, test_loader, device, result_path)

print(f"\n  Test accuracy: {test_acc:.2f}%")
print(f"  Results saved: {result_path}")
print(f"  Wrong predictions: {len(test_wrong_samples)}")

# Save wrong samples (test from the best model, val from the best epoch) to JSON
wrong_samples_path = f'./prediction/{args.run_name}-wrong_samples.json'
with open(wrong_samples_path, 'w') as f:
    json.dump({
        'test': test_wrong_samples,
        'val': all_wrong_samples['val'],
        'metadata': {
            'best_epoch': best_epoch,
            'best_val_acc': best_val_acc,
            'test_acc': test_acc,
            'run_name': args.run_name,
        }
    }, f, indent=2)
print(f"  Wrong samples saved: {wrong_samples_path}")

# Detailed evaluation by question type
print("\n" + "-" * 70)
print("  Detailed Results by Question Type:")
print("-" * 70)
eval_mc.accuracy_metric_cvid(result_path)

# Log final results to wandb
wandb.log({
    "final/test_acc": test_acc,
    "final/best_epoch": best_epoch,
    "final/best_val_acc": best_val_acc,
    "final/wrong_test_count": len(test_wrong_samples),
})

# Save results artifact
artifact = wandb.Artifact(f'predictions-{args.run_name}', type='predictions')
artifact.add_file(result_path)
artifact.add_file(wrong_samples_path)
wandb.log_artifact(artifact)

In [None]:
# ============================================================
# SAVE TO KAGGLE OUTPUT & FINISH WANDB
# ============================================================
import shutil

# Copy the best model, predictions and wrong samples to the Kaggle output
# directory when it exists; otherwise they stay where they were written.
output_dir = '/kaggle/working'
if not os.path.exists(output_dir):
    print("  Not running on Kaggle, files saved locally")
else:
    copy_plan = [
        (best_model_path, os.path.join(output_dir, f'best_model-{args.run_name}.ckpt')),
        (result_path, output_dir),
        (wrong_samples_path, output_dir),
    ]
    for src_file, destination in copy_plan:
        shutil.copy(src_file, destination)
    print(f"‚úÖ Files saved to {output_dir}")

# Upload the best model as a W&B artifact, then close the run.
model_artifact = wandb.Artifact(f'model-{args.run_name}', type='model')
model_artifact.add_file(best_model_path)
wandb.log_artifact(model_artifact)

wandb.finish()
print("‚úÖ W&B run finished")

In [None]:
# ============================================================
# UPLOAD WRONG PREDICTIONS TO HUGGINGFACE DATASET
# ============================================================
# Install huggingface_hub if not already installed
!pip install -q huggingface_hub datasets

from huggingface_hub import HfApi, login
from datasets import Dataset, DatasetDict
import pandas as pd

# ============================================================
# CONFIGURATION
# ============================================================
HF_TOKEN = "YOUR_HF_TOKEN_HERE"  # Replace with your HuggingFace token
HF_REPO_NAME = "your-username/causalvid-wrong-predictions"  # Replace with your repo name

# Login to HuggingFace
login(token=HF_TOKEN)

# ============================================================
# PREPARE WRONG SAMPLES DATA
# ============================================================
print("=" * 70)
print("üì§ PREPARING WRONG PREDICTIONS FOR HUGGINGFACE")
print("=" * 70)

# Load wrong samples
with open(wrong_samples_path, 'r') as f:
    wrong_data = json.load(f)

# Question type names
qtype_names = {
    0: 'descriptive',
    1: 'explanatory', 
    2: 'predictive_answer',
    3: 'predictive_reason',
    4: 'counterfactual_answer',
    5: 'counterfactual_reason'
}

# Prepare test wrong samples as DataFrame
test_wrong = wrong_data.get('test', [])
if test_wrong:
    df_test = pd.DataFrame(test_wrong)
    df_test['qtype_name'] = df_test['qtype'].map(qtype_names)
    df_test['split'] = 'test'
    print(f"  Test wrong samples: {len(df_test)}")
else:
    df_test = pd.DataFrame()
    print("  No test wrong samples")

# Prepare val wrong samples as DataFrame
val_wrong = wrong_data.get('val', [])
if val_wrong:
    df_val = pd.DataFrame(val_wrong)
    df_val['qtype_name'] = df_val['qtype'].map(qtype_names)
    df_val['split'] = 'val'
    print(f"  Val wrong samples: {len(df_val)}")
else:
    df_val = pd.DataFrame()
    print("  No val wrong samples")

# Combine all wrong samples
df_all = pd.concat([df_test, df_val], ignore_index=True) if not df_test.empty or not df_val.empty else pd.DataFrame()

if df_all.empty:
    print("‚ö†Ô∏è No wrong samples to upload!")
else:
    # Add metadata columns
    df_all['model_name'] = args.run_name
    df_all['best_epoch'] = best_epoch
    df_all['best_val_acc'] = best_val_acc
    
    print(f"\n  Total wrong samples: {len(df_all)}")
    print(f"  Columns: {list(df_all.columns)}")
    
    # Show sample data
    print("\n  Sample wrong predictions:")
    print(df_all.head(3).to_string(index=False))

# ============================================================
# CREATE HUGGINGFACE DATASET
# ============================================================
if not df_all.empty:
    print("\n" + "=" * 70)
    print("üì§ UPLOADING TO HUGGINGFACE")
    print("=" * 70)
    
    # Create HuggingFace Dataset
    hf_dataset = Dataset.from_pandas(df_all)
    
    # Create DatasetDict with splits
    dataset_dict = DatasetDict({
        'test': Dataset.from_pandas(df_test) if not df_test.empty else None,
        'val': Dataset.from_pandas(df_val) if not df_val.empty else None,
        'all': hf_dataset
    })
    
    # Remove None splits
    dataset_dict = DatasetDict({k: v for k, v in dataset_dict.items() if v is not None})
    
    print(f"  Dataset splits: {list(dataset_dict.keys())}")
    print(f"  Total samples: {sum(len(v) for v in dataset_dict.values())}")
    
    # Push to HuggingFace Hub
    try:
        dataset_dict.push_to_hub(
            HF_REPO_NAME,
            private=True,  # Set to False if you want public dataset
            token=HF_TOKEN
        )
        print(f"\n‚úÖ Successfully uploaded to: https://huggingface.co/datasets/{HF_REPO_NAME}")
    except Exception as e:
        print(f"\n‚ùå Error uploading to HuggingFace: {e}")
        print("  Make sure to set HF_TOKEN and HF_REPO_NAME correctly")
        
        # Save locally as backup
        backup_path = f'./prediction/{args.run_name}-wrong_samples_hf.parquet'
        df_all.to_parquet(backup_path)
        print(f"  Saved backup to: {backup_path}")

print("=" * 70)

In [None]:
# ============================================================
# ANALYZE WRONG PREDICTIONS
# ============================================================
# Print simple distributions over the test-set wrong predictions written by the
# final-evaluation cell: by question type, by video, by predicted choice and
# by ground-truth choice.
print("=" * 70)
print("üìä WRONG PREDICTIONS ANALYSIS")
print("=" * 70)

# Load wrong samples
with open(wrong_samples_path, 'r') as f:
    wrong_data = json.load(f)

test_wrong = wrong_data.get('test', [])

if test_wrong:
    import pandas as pd  # keeps this cell runnable even if the HF cell was skipped
    df = pd.DataFrame(test_wrong)

    # Question type distribution
    print("\nüîç Wrong predictions by Question Type:")
    print("-" * 40)
    qtype_names = {
        0: 'Descriptive',
        1: 'Explanatory',
        2: 'Predictive-Ans',
        3: 'Predictive-Reason',
        4: 'Counterfactual-Ans',
        5: 'Counterfactual-Reason'
    }
    # Pad to the longest label. A fixed width of 20 misaligned the
    # 21-character 'Counterfactual-Reason' row.
    label_width = max(len(name) for name in qtype_names.values())

    qtype_counts = df['qtype'].value_counts().sort_index()
    for qt, count in qtype_counts.items():
        label = str(qtype_names.get(qt, qt))  # unknown ids are shown as-is
        print(f"  {label:<{label_width}}: {count:>5} ({count/len(df)*100:.1f}%)")

    # Most common video IDs with wrong predictions
    print("\nüé¨ Videos with most wrong predictions:")
    print("-" * 40)
    video_counts = df['video_id'].value_counts().head(10)
    for vid, count in video_counts.items():
        print(f"  {vid}: {count} wrong")

    # Prediction vs Answer distribution
    print("\nüìà Prediction distribution (wrong samples):")
    print("-" * 40)
    pred_dist = df['prediction'].value_counts().sort_index()
    for pred, count in pred_dist.items():
        print(f"  Choice {pred}: {count} times predicted")

    print("\nüìâ Answer distribution (ground truth for wrong samples):")
    print("-" * 40)
    ans_dist = df['answer'].value_counts().sort_index()
    for ans, count in ans_dist.items():
        print(f"  Choice {ans}: {count} times correct")

else:
    print("  No wrong samples to analyze")

print("=" * 70)

### Evaluating pretrained model B2A

### Training