## 1. 📦 Install Dependencies

In [None]:
# Install the extra packages used by this notebook. The leading "!" is an
# IPython shell escape, so this line only runs inside Jupyter/Kaggle.
!pip install -q transformers einops h5py wandb

# Log in to Weights & Biases (W&B) before any run is created.
import wandb
wandb.login()  # Prompts for an API key if not already logged in

## 1.5 📝 Patch DataLoader

In [None]:
# Patch DataLoader.py so it handles feature-dimension mismatches.
# Run this cell BEFORE importing DataLoader: it overwrites ./DataLoader.py with
# the VideoQADataset implementation stored in `patch_code` below.

patch_code = '''
import torch
import os
import h5py
import os.path as osp
import numpy as np
import json
import pickle as pkl
from torch.utils import data
from utils.util import load_file, pause, transform_bb, pkload
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizerFast


# Per-sample qtype (0-5) -> (section of text.json / answer.json, field that
# holds the candidate list in text.json and the correct index in answer.json).
_QTYPE_FIELDS = {
    0: ('descriptive', 'answer'),
    1: ('explanatory', 'answer'),
    2: ('predictive', 'answer'),
    3: ('predictive', 'reason'),
    4: ('counterfactual', 'answer'),
    5: ('counterfactual', 'reason'),
}

# Dataset-level `qtype` argument -> per-sample qtypes it expands to. Any value
# not listed here (e.g. 4 or 5) selects only that single per-sample qtype.
_QTYPE_GROUPS = {
    -1: (0, 1, 2, 3, 4, 5),
    0: (0,),
    1: (1,),
    2: (2, 3),
    3: (4, 5),
}


class VideoQADataset(Dataset):
    """
    CausalVidQA dataset whose samples use the same output format as NextQA.

    Each sample is a (video id, per-sample qtype) pair; see _QTYPE_GROUPS for
    how the `qtype` argument expands into per-sample qtypes. __getitem__
    returns (frame_feat, obj_feat, question, candidate_strings, answer_index,
    qns_key), where qns_key is the video id and the qtype joined by "_".
    """

    def __init__(self, split, n_query=5, obj_num=1,
                 sample_list_path=None,
                 video_feature_path=None,
                 text_annotation_path=None,
                 qtype=-1,
                 max_samples=None):
        super().__init__()

        self.split = split
        self.mc = n_query
        self.obj_num = obj_num
        self.qtype = qtype
        self.video_feature_path = video_feature_path
        self.text_annotation_path = text_annotation_path
        self.max_samples = max_samples

        # Load video ids for this split ('val' falls back to valid.pkl).
        if split == 'val':
            split_file = osp.join(sample_list_path, 'val.pkl')
            if not osp.exists(split_file):
                split_file = osp.join(sample_list_path, 'valid.pkl')
        else:
            split_file = osp.join(sample_list_path, f'{split}.pkl')

        if not osp.exists(split_file):
            raise FileNotFoundError(f"Split file not found: {split_file}")

        self.vids = pkload(split_file)

        if self.vids is None:
            raise ValueError(f"Failed to load split file: {split_file}")

        if max_samples is not None and max_samples > 0:
            self.vids = self.vids[:max_samples]
            print(f"Limited to {len(self.vids)} videos (max_samples={max_samples})")
        else:
            print(f"Loaded {len(self.vids)} videos from {split_file}")

        # Map each video of this split to its row in the feature files. The
        # set keeps each membership test O(1) instead of scanning the list.
        idx2vid_file = osp.join(video_feature_path, 'idx2vid.pkl')
        wanted = set(self.vids)
        self.vf_info = {vid: idx for idx, vid in enumerate(pkload(idx2vid_file))
                        if vid in wanted}

        # Fail fast here rather than with a KeyError from __getitem__ mid-epoch.
        missing = [vid for vid in self.vids if vid not in self.vf_info]
        if missing:
            raise ValueError(
                f"{len(missing)} video(s) of split '{split}' are not listed in "
                f"{idx2vid_file}, e.g. {missing[:5]}")

        self.app_feats = self._load_h5_feats(osp.join(video_feature_path, 'appearance_feat.h5'))
        self.mot_feats = self._load_h5_feats(osp.join(video_feature_path, 'motion_feat.h5'))

        self._build_sample_list()

    def _load_h5_feats(self, h5_file):
        """Read the `resnet_features` row of every video in self.vf_info."""
        print(f'Loading {h5_file}...')
        with h5py.File(h5_file, 'r') as fp:
            feats = fp['resnet_features']
            return {vid: feats[idx][...] for vid, idx in self.vf_info.items()}

    def _build_sample_list(self):
        """Expand every video into (vid, per-sample qtype) pairs."""
        qtypes = _QTYPE_GROUPS.get(self.qtype, (self.qtype,))
        self.samples = [(vid, qt) for vid in self.vids for qt in qtypes]
        print(f"Total samples: {len(self.samples)}")

    def _load_text(self, vid, qtype):
        """Return (question, candidate answers, correct index) for one sample."""
        if qtype not in _QTYPE_FIELDS:
            raise ValueError(f"Invalid qtype: {qtype}")
        section, field = _QTYPE_FIELDS[qtype]

        text_file = osp.join(self.text_annotation_path, vid, 'text.json')
        answer_file = osp.join(self.text_annotation_path, vid, 'answer.json')

        # Annotations may live directly under the root or under a QA/ folder.
        if not osp.exists(text_file):
            text_file = osp.join(self.text_annotation_path, 'QA', vid, 'text.json')
            answer_file = osp.join(self.text_annotation_path, 'QA', vid, 'answer.json')

        if not osp.exists(text_file):
            raise FileNotFoundError(f"Text annotation not found for video: {vid}")

        with open(text_file, 'r', encoding='utf-8') as f:
            text = json.load(f)
        with open(answer_file, 'r', encoding='utf-8') as f:
            answer = json.load(f)

        return text[section]['question'], text[section][field], answer[section][field]

    @staticmethod
    def _to_2d(feat):
        """Coerce one video's feature array to 2-D (rows, feature_dim)."""
        # 3-D: average over axis 1 (or just drop it when its size is 1).
        if feat.ndim == 3:
            feat = feat.mean(axis=1) if feat.shape[1] > 1 else feat.squeeze(1)
        # 1-D: treat as a single row.
        if feat.ndim == 1:
            feat = feat[np.newaxis, :]
        return feat

    def __getitem__(self, idx):
        vid, qtype = self.samples[idx]

        qns_word, cand_ans, ans_id = self._load_text(vid, qtype)
        ans_word = ['[CLS] ' + qns_word + ' [SEP] ' + str(cand_ans[i]) for i in range(self.mc)]

        # Video features, reshaped to (T, D) so app and mot can be concatenated.
        app_feat = self._to_2d(self.app_feats[vid])
        mot_feat = self._to_2d(self.mot_feats[vid])

        # Frame feature: concatenate app + mot along the feature axis.
        frame_feat = np.concatenate([app_feat, mot_feat], axis=-1)
        vid_frame_feat = torch.from_numpy(frame_feat).type(torch.float32)

        # Object features: no detector output here, so every "object" is the
        # frame's appearance feature paired with a full-frame box [0, 0, 1, 1, 1].
        T = app_feat.shape[0]
        obj_feat = np.tile(app_feat[:, np.newaxis, :], (1, self.obj_num, 1))
        dummy_bbox = np.zeros((T, self.obj_num, 5), dtype=np.float32)
        dummy_bbox[:, :, :4] = np.array([0.0, 0.0, 1.0, 1.0])
        dummy_bbox[:, :, 4] = 1.0

        obj_feat = np.concatenate([obj_feat, dummy_bbox], axis=-1)
        vid_obj_feat = torch.from_numpy(obj_feat).type(torch.float32)

        qns_key = vid + '_' + str(qtype)

        return vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_key

    def __len__(self):
        return len(self.samples)
'''

# Write patched DataLoader.py (explicit UTF-8 so the result does not depend on
# the platform's default encoding).
with open('DataLoader.py', 'w', encoding='utf-8') as f:
    f.write(patch_code)

print("‚úÖ DataLoader.py patched with dimension fix!")

## 2. 📂 Data Paths & Statistics

In [None]:
import os
import pickle
import h5py
import json

# ============================================================
# DATA PATHS (Kaggle Input)
# ============================================================
text_feature_path = '/kaggle/input/text-feature'
visual_feature_path = '/kaggle/input/visual-feature'
split_path = '/kaggle/input/dataset-split-1'
text_annotation_path = '/kaggle/input/text-annotation'

print("=" * 70)
print("üìÇ DATA PATHS")
print("=" * 70)
for name, path in [("Visual features", visual_feature_path), 
                   ("Split files", split_path), 
                   ("Text annotations", text_annotation_path)]:
    status = "‚úì" if os.path.exists(path) else "‚úó"
    print(f"  {status} {name}: {path}")

# ============================================================
# DATA STATISTICS
# ============================================================
print("\n" + "=" * 70)
print("üìä DATASET STATISTICS")
print("=" * 70)

# 1. Split files
# Each label maps to the file stems tried in order. VideoQADataset loads the
# validation split from val.pkl and falls back to valid.pkl, so the same names
# are checked here (previously only valid.pkl was counted, and a missing split
# was skipped silently).
print("\nüìÅ Split Files:")
split_stats = {}
split_candidates = {
    'train': ['train'],
    'valid': ['val', 'valid'],
    'test': ['test'],
}
for split_name, stems in split_candidates.items():
    found = [f'{split_path}/{stem}.pkl' for stem in stems
             if os.path.exists(f'{split_path}/{stem}.pkl')]
    if not found:
        print(f"  {split_name:>6}: not found")
        continue
    split_file = found[0]
    # NOTE: pickle.load can execute code embedded in the file; only point this
    # at trusted inputs (here: the Kaggle dataset attached to the notebook).
    with open(split_file, 'rb') as f:
        vids = pickle.load(f)
    split_stats[split_name] = len(vids)
    samples = len(vids) * 6  # 6 question types per video (qtype=-1)
    print(f"  {split_name:>6}: {len(vids):>6} videos ‚Üí {samples:>6} samples")

# 2. Visual features
print("\nüé¨ Visual Features:")
idx2vid_file = f'{visual_feature_path}/idx2vid.pkl'
if os.path.exists(idx2vid_file):
    with open(idx2vid_file, 'rb') as f:
        idx2vid = pickle.load(f)
    print(f"  Indexed videos: {len(idx2vid)}")

for feat_name in ['appearance_feat.h5', 'motion_feat.h5']:
    feat_file = f'{visual_feature_path}/{feat_name}'
    if os.path.exists(feat_file):
        with h5py.File(feat_file, 'r') as f:
            shape = f['resnet_features'].shape
        print(f"  {feat_name}: {shape}")

# 3. Question types (per-sample qtype ids, as they appear in question keys)
print("\n‚ùì Question Types (qtype):")
qtype_info = [
    ("0", "Descriptive", "What is happening?"),
    ("1", "Explanatory", "Why did it happen?"),
    ("2", "Predictive-Ans", "What will happen?"),
    ("3", "Predictive-Reason", "Why will it happen?"),
    ("4", "Counterfactual-Ans", "What if X didn't happen?"),
    ("5", "Counterfactual-Reason", "Why would that result?"),
]
for qt, name, desc in qtype_info:
    print(f"  {qt}: {name:<20} - {desc}")

print("\n" + "=" * 70)

## 3. ⚙️ Configuration

In [None]:
import os
import sys
import torch
import numpy as np
import wandb

# ============================================================
# CONFIGURATION
# ============================================================

class Config:
    """Training configuration for CausalVidQA.

    Hyperparameters are plain class attributes. The notebook instantiates this
    once as ``args = Config()`` and reads them as ``args.<name>``; the GPU setup
    cell may set ``use_multi_gpu`` to False on that instance.

    The three data-path attributes copy the module-level ``split_path``,
    ``visual_feature_path`` and ``text_annotation_path`` globals, so the
    "Data Paths" cell must run first. ``text_annotation_path =
    text_annotation_path`` works because a name that is not yet bound in the
    class body is looked up in the module globals.
    """
    
    # Experiment
    project_name = "CausalVidQA-TranSTR"
    run_name = "causalvid_2gpu"
    
    # Data paths
    sample_list_path = split_path
    video_feature_path = visual_feature_path
    text_annotation_path = text_annotation_path
    
    # Training
    bs = 16                    # Batch size (total per step; DataParallel splits it across GPUs)
    lr = 1e-4                  # Learning rate
    text_encoder_lr = 1e-5     # Text encoder LR (lower)
    epoch = 20
    warmup_epochs = 2          # Warmup epochs (NOTE(review): not read by any cell in this notebook)
    
    # Dataset
    dataset = 'causal-vid'
    qtype = -1                 # -1 = all question types
    max_samples = None         # None = use all data (train/val only; test always uses all)
    
    # Model architecture
    d_model = 768
    word_dim = 768
    nheads = 8
    num_encoder_layers = 1
    num_decoder_layers = 1
    dropout = 0.1
    encoder_dropout = 0.1
    activation = 'relu'
    normalize_before = False
    
    # Video features
    objs = 20                  # Objects per frame
    topK_frame = 8             # Top-K frames to select
    topK_obj = 5               # Top-K objects to select
    frame_feat_dim = 4096      # app(2048) + mot(2048)
    obj_feat_dim = 2053        # feat(2048) + bbox(5)
    n_query = 5                # 5-way multiple choice
    
    # Text encoder
    text_encoder_type = "microsoft/deberta-base"
    freeze_text_encoder = False
    text_pool_mode = 0
    hard_eval = False
    
    # Optimizer
    decay = 0.001              # Weight decay (AdamW)
    patience = 3               # LR scheduler patience (ReduceLROnPlateau, epochs)
    gamma = 0.5                # LR decay factor
    
    # Early stopping
    early_stopping_patience = 5  # Stop after 5 epochs without improvement
    
    # Contrastive learning
    # NOTE(review): pos_ratio / neg_ratio / a are not passed to the model or
    # read anywhere in this notebook — confirm whether they are still needed.
    pos_ratio = 0.7
    neg_ratio = 0.3
    a = 1
    
    # Multi-GPU
    use_multi_gpu = True       # Enable DataParallel (switched off when < 2 GPUs are visible)
    num_workers = 4            # DataLoader workers (only used on the multi-GPU path)
    
    # Logging
    log_interval = 50          # Log every N batches
    save_every = 5             # Save checkpoint every N epochs

args = Config()

# ============================================================
# GPU SETUP
# ============================================================
# List the visible CUDA devices. Multi-GPU mode needs BOTH at least two GPUs
# and args.use_multi_gpu; otherwise the flag is switched off on `args`.
print("=" * 70)
print("üñ•Ô∏è GPU CONFIGURATION")
print("=" * 70)

n_gpus = torch.cuda.device_count()
print(f"  Available GPUs: {n_gpus}")
for gpu_idx in range(n_gpus):
    print(f"    GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)}")
    gpu_mem_gb = torch.cuda.get_device_properties(gpu_idx).total_memory / 1e9
    print(f"           Memory: {gpu_mem_gb:.1f} GB")

wants_multi_gpu = args.use_multi_gpu
if n_gpus < 2 or not wants_multi_gpu:
    print(f"\n  ‚Üí Single GPU mode")
    args.use_multi_gpu = False
else:
    print(f"\n  ‚úì Multi-GPU mode: DataParallel on {n_gpus} GPUs")
    print(f"  ‚úì Effective batch size: {args.bs} (total)")

# Primary device: with DataParallel, inputs are placed here and scattered.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"  Primary device: {device}")

# ============================================================
# PRINT CONFIG
# ============================================================
print("\n" + "=" * 70)
print("‚öôÔ∏è TRAINING CONFIG")
print("=" * 70)
config_items = [
    ("Batch size", args.bs),
    ("Learning rate", args.lr),
    ("Text encoder LR", args.text_encoder_lr),
    ("Epochs", args.epoch),
    ("Early stopping", f"{args.early_stopping_patience} epochs"),
    ("d_model", args.d_model),
    ("TopK frames", args.topK_frame),
    ("TopK objects", args.topK_obj),
    ("Objects/frame", args.objs),
    ("Text encoder", args.text_encoder_type),
]

print("\n".join(f"  {label:<20}: {value}" for label, value in config_items))
print("=" * 70)

## 4. 📚 Import Modules

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from collections import defaultdict
import time
import json

# Local imports
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel
import eval_mc

# ============================================================
# REPRODUCIBILITY
# ============================================================
def set_seed(seed=999):
    """Seed every random number generator the pipeline may draw from.

    Seeds Python's ``random`` module (previously left unseeded), NumPy, and
    PyTorch on the CPU and all CUDA devices. It also puts cuDNN in
    deterministic mode with autotuning disabled, trading some speed for
    run-to-run reproducibility.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: this cell does not import `random` at top level

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed(999)
print("‚úÖ Modules imported, seed set to 999")

## 5. 📊 Create DataLoaders

In [None]:
print("Creating datasets...")

# ============================================================
# CREATE DATASETS
# ============================================================
dataset_kwargs = dict(
    n_query=args.n_query,
    obj_num=args.objs,
    sample_list_path=args.sample_list_path,
    video_feature_path=args.video_feature_path,
    text_annotation_path=args.text_annotation_path,
    qtype=args.qtype,
    max_samples=args.max_samples
)

train_dataset = VideoQADataset(split='train', **dataset_kwargs)
val_dataset = VideoQADataset(split='val', **dataset_kwargs)

# The test set ALWAYS uses the full data (max_samples is not applied).
test_kwargs = dict(dataset_kwargs, max_samples=None)
test_dataset = VideoQADataset(split='test', **test_kwargs)

# ============================================================
# CREATE DATALOADERS (optimized for multi-GPU)
# ============================================================
# Worker processes are only used on the multi-GPU path. prefetch_factor is only
# valid when num_workers > 0 (DataLoader raises ValueError otherwise), so it is
# keyed off the *effective* worker count. Keying it off args.num_workers broke
# the single-GPU/CPU path, where num_workers is forced to 0.
num_loader_workers = args.num_workers if args.use_multi_gpu else 0
loader_kwargs = dict(
    batch_size=args.bs,
    num_workers=num_loader_workers,
    pin_memory=True,
)
if num_loader_workers > 0:
    loader_kwargs['prefetch_factor'] = 2

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# ============================================================
# DATASET SUMMARY
# ============================================================
# Split column is 12 wide so "Test (FULL)" (11 chars) stays aligned.
print("\n" + "=" * 70)
print("üìä DATALOADER SUMMARY")
print("=" * 70)
print(f"  {'Split':<12} {'Videos':>10} {'Samples':>10} {'Batches':>10}")
print(f"  {'-'*12} {'-'*10} {'-'*10} {'-'*10}")
for name, dataset, loader in [
    ("Train", train_dataset, train_loader),
    ("Val", val_dataset, val_loader),
    ("Test (FULL)", test_dataset, test_loader)
]:
    n_vids = len(dataset.vids) if hasattr(dataset, 'vids') else "?"
    print(f"  {name:<12} {n_vids:>10} {len(dataset):>10} {len(loader):>10}")
print("=" * 70)
print("  ‚ÑπÔ∏è  Test set always uses ALL data regardless of max_samples")

## 6. 🔍 Check Data Sample

In [None]:
# ============================================================
# VERIFY DATA SAMPLE
# ============================================================
# Pull a single batch from the training loader and print shapes and a few
# fields as a sanity check before the model is built.
print("üîç Verifying data sample...")

for batch in train_loader:
    # Same order as the tuple returned by VideoQADataset.__getitem__, collated.
    vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_key = batch
    
    print(f"\n  Frame features:  {vid_frame_feat.shape}")
    print(f"  Object features: {vid_obj_feat.shape}")
    print(f"  Batch size:      {len(qns_word)}")
    print(f"\n  Sample question: {qns_word[0][:80]}...")
    # The default collate transposes each sample's candidate list, so
    # ans_word[0][0] is presumably candidate 0 of sample 0 — confirm.
    print(f"  Sample answer:   {ans_word[0][0][:60]}...")
    print(f"  Ground truth:    {ans_id[0].item()}")
    print(f"  Question key:    {qns_key[0]}")
    break  # one batch is enough for the check

print("\n‚úÖ Data verification complete!")

## 7. 🏗️ Create Model

In [None]:
# ============================================================
# CREATE MODEL
# ============================================================
print("üèóÔ∏è Creating model...")

# Keyword arguments for VideoQAmodel: each comes from the Config attribute of
# the same name, plus the torch device chosen in the GPU setup cell.
model_config = dict(
    d_model=args.d_model,
    word_dim=args.word_dim,
    encoder_dropout=args.encoder_dropout,
    dropout=args.dropout,
    num_encoder_layers=args.num_encoder_layers,
    num_decoder_layers=args.num_decoder_layers,
    nheads=args.nheads,
    normalize_before=args.normalize_before,
    activation=args.activation,
    text_encoder_type=args.text_encoder_type,
    freeze_text_encoder=args.freeze_text_encoder,
    text_pool_mode=args.text_pool_mode,
    n_query=args.n_query,
    objs=args.objs,
    topK_frame=args.topK_frame,
    topK_obj=args.topK_obj,
    hard_eval=args.hard_eval,
    frame_feat_dim=args.frame_feat_dim,
    obj_feat_dim=args.obj_feat_dim,
    device=device,
)

model = VideoQAmodel(**model_config)

# ============================================================
# MULTI-GPU SETUP (DataParallel)
# ============================================================
# Wrap only when multi-GPU is requested AND more than one GPU is visible; the
# (possibly wrapped) model is then moved to the primary device.
wrap_in_data_parallel = args.use_multi_gpu and torch.cuda.device_count() > 1
if wrap_in_data_parallel:
    print(f"  ‚Üí Wrapping model with DataParallel ({torch.cuda.device_count()} GPUs)")
    model = nn.DataParallel(model)

model.to(device)

# ============================================================
# MODEL SUMMARY
# ============================================================
param_list = list(model.parameters())
total_params = sum(p.numel() for p in param_list)
trainable_params = sum(p.numel() for p in param_list if p.requires_grad)

print("\n" + "=" * 70)
print("üß† MODEL SUMMARY")
print("=" * 70)
print(f"  Total parameters:     {total_params / 1e6:.2f}M")
print(f"  Trainable parameters: {trainable_params / 1e6:.2f}M")
print(f"  Multi-GPU:            {args.use_multi_gpu and torch.cuda.device_count() > 1}")
print("=" * 70)

## 8. 🎯 Training Functions

In [None]:
# ============================================================
# TRAINING FUNCTIONS
# ============================================================

def train_epoch(model, optimizer, train_loader, criterion, device, epoch, wandb_run=None,
                log_interval=None, max_grad_norm=1.0):
    """Train ``model`` for one pass over ``train_loader``.

    Each batch is ``(vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id,
    qns_keys)``. The trailing ``_<qtype>`` of each question key selects the
    per-question-type accuracy bucket.

    Args:
        model: Module called as ``model(frame_feat, obj_feat, questions, answers)``
            that returns one score per candidate answer.
        optimizer: Optimizer stepping the model parameters.
        train_loader: Iterable of batches; ``len()`` is used for progress and
            for the W&B step.
        criterion: Loss taking ``(scores, target_indices)``.
        device: Device the visual features and targets are moved to.
        epoch: Current epoch number (only used for the W&B step).
        wandb_run: Optional W&B run to which batch metrics are logged.
        log_interval: Print/log every N batches; defaults to ``args.log_interval``.
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        dict with ``loss`` (mean batch loss), ``acc`` (percent), ``time``
        (seconds) and ``qtype_acc`` ({type name: percent} for the types seen).
        An empty loader yields zero loss/accuracy instead of raising.
    """
    if log_interval is None:
        log_interval = args.log_interval
    qtype_names = ['Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R']
    model.train()

    total_loss = 0.0
    n_batches = 0
    n_correct = 0
    n_samples = 0
    batch_times = []

    # Per question type tracking
    qtype_correct = defaultdict(int)
    qtype_total = defaultdict(int)

    start_time = time.time()

    for batch_idx, inputs in enumerate(train_loader):
        batch_start = time.time()

        vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys = inputs
        vid_frame_feat = vid_frame_feat.to(device)
        vid_obj_feat = vid_obj_feat.to(device)
        ans_targets = ans_id.to(device)

        # Forward pass
        optimizer.zero_grad()
        out = model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)
        loss = criterion(out, ans_targets)

        # Backward pass (clip the global gradient norm before stepping)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Track metrics with running counts instead of keeping every prediction
        batch_loss = loss.item()
        total_loss += batch_loss
        n_batches += 1
        hits = out.argmax(-1).cpu().long() == ans_id.long()
        n_correct += int(hits.sum().item())
        n_samples += len(ans_id)

        # Per question type accuracy; keys end in "_<qtype>"
        for qkey, hit in zip(qns_keys, hits.tolist()):
            qtype = int(qkey.split('_')[-1])
            qtype_total[qtype] += 1
            qtype_correct[qtype] += int(hit)

        batch_times.append(time.time() - batch_start)

        # Logging
        if (batch_idx + 1) % log_interval == 0:
            avg_loss = total_loss / (batch_idx + 1)
            avg_time = np.mean(batch_times[-log_interval:])
            print(f"    Batch [{batch_idx+1:>4}/{len(train_loader)}] "
                  f"Loss: {batch_loss:.4f} (avg: {avg_loss:.4f}) "
                  f"Time: {avg_time:.3f}s/batch")

            if wandb_run:
                wandb_run.log({
                    "train/batch_loss": batch_loss,
                    "train/avg_loss": avg_loss,
                    "train/batch_time": avg_time,
                }, step=epoch * len(train_loader) + batch_idx)

    # Compute epoch metrics (guarded so an empty loader does not crash)
    epoch_loss = total_loss / n_batches if n_batches else 0.0
    epoch_acc = n_correct * 100.0 / n_samples if n_samples else 0.0
    epoch_time = time.time() - start_time

    qtype_acc = {
        name: qtype_correct[qt] * 100.0 / qtype_total[qt]
        for qt, name in enumerate(qtype_names)
        if qtype_total[qt] > 0
    }

    return {
        'loss': epoch_loss,
        'acc': epoch_acc,
        'time': epoch_time,
        'qtype_acc': qtype_acc
    }


def evaluate(model, data_loader, device, split_name='val'):
    """Compute multiple-choice accuracy of ``model`` on ``data_loader``.

    Runs in eval mode under ``torch.no_grad()``. Batches have the same layout as
    in ``train_epoch``, and the trailing ``_<qtype>`` of each question key
    selects the per-question-type bucket.

    Args:
        model: Module called as ``model(frame_feat, obj_feat, questions, answers)``
            that returns one score per candidate answer.
        data_loader: Iterable of batches.
        device: Device the visual features are moved to.
        split_name: Split label; currently unused and kept for call-site
            compatibility.

    Returns:
        dict with ``acc`` (overall accuracy, percent), ``qtype_acc``
        ({type name: percent} for the types present) and ``n_samples``.
        An empty loader yields ``acc == 0.0`` and ``n_samples == 0`` instead of
        raising. Video-level combined metrics (e.g. Pred-A and Pred-R both
        correct) are not computed here.
    """
    qtype_names = ['Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R']
    model.eval()

    n_correct = 0
    n_samples = 0
    qtype_correct = defaultdict(int)
    qtype_total = defaultdict(int)

    with torch.no_grad():
        for inputs in data_loader:
            vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys = inputs
            vid_frame_feat = vid_frame_feat.to(device)
            vid_obj_feat = vid_obj_feat.to(device)

            out = model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)
            hits = out.argmax(-1).cpu().long() == ans_id.long()
            n_correct += int(hits.sum().item())
            n_samples += len(ans_id)

            for qkey, hit in zip(qns_keys, hits.tolist()):
                qtype = int(qkey.split('_')[-1])
                qtype_total[qtype] += 1
                qtype_correct[qtype] += int(hit)

    overall_acc = n_correct * 100.0 / n_samples if n_samples else 0.0

    # Per question type accuracy
    qtype_acc = {
        name: qtype_correct[qt] * 100.0 / qtype_total[qt]
        for qt, name in enumerate(qtype_names)
        if qtype_total[qt] > 0
    }

    return {
        'acc': overall_acc,
        'qtype_acc': qtype_acc,
        'n_samples': n_samples
    }


def predict_and_save(model, data_loader, device, save_path):
    """Predict every sample in ``data_loader`` and write the results as JSON.

    Args:
        model: Module called as ``model(frame_feat, obj_feat, questions, answers)``
            that returns one score per candidate answer.
        data_loader: Iterable of batches in the ``VideoQADataset`` layout.
        device: Device the visual features are moved to.
        save_path: Output JSON path; its directory must already exist.

    Returns:
        ``(results, acc)``. ``results`` maps question key to
        ``{'prediction': int, 'answer': int}`` (a repeated key keeps its last
        prediction). ``acc`` is the percentage of keys predicted correctly,
        or 0.0 when there are no results.
    """
    model.eval()
    results = {}

    with torch.no_grad():
        for inputs in data_loader:
            vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys = inputs
            vid_frame_feat = vid_frame_feat.to(device)
            vid_obj_feat = vid_obj_feat.to(device)

            out = model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)
            pred = out.argmax(-1).cpu()

            for qid, p, a in zip(qns_keys, pred.tolist(), ans_id.tolist()):
                results[qid] = {'prediction': int(p), 'answer': int(a)}

    with open(save_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    # Compute accuracy (guard against an empty loader)
    correct = sum(v['prediction'] == v['answer'] for v in results.values())
    acc = correct * 100.0 / len(results) if results else 0.0

    return results, acc


print("‚úÖ Training functions defined")

## 9. 🔧 Optimizer, Scheduler & W&B Setup

In [None]:
# ============================================================
# SETUP OPTIMIZER, SCHEDULER, CRITERION
# ============================================================
os.makedirs('./models', exist_ok=True)
os.makedirs('./prediction', exist_ok=True)

# Get base model for parameter groups (handle DataParallel)
base_model = model.module if hasattr(model, 'module') else model

# Two parameter groups: parameters whose name contains "text_encoder" get the
# lower args.text_encoder_lr, everything else gets args.lr. Frozen parameters
# are skipped, and an empty group (e.g. with freeze_text_encoder=True) is
# dropped so the optimizer only carries groups that actually train.
trainable = [(n, p) for n, p in base_model.named_parameters() if p.requires_grad]
param_groups = [
    {
        "params": [p for n, p in trainable if "text_encoder" not in n],
        "lr": args.lr
    },
    {
        "params": [p for n, p in trainable if "text_encoder" in n],
        "lr": args.text_encoder_lr
    }
]
param_groups = [group for group in param_groups if group["params"]]

optimizer = torch.optim.AdamW(param_groups, weight_decay=args.decay)
# `verbose` is deprecated for PyTorch LR schedulers (since 2.2), so it is not
# passed; the training loop prints and logs the current LR every epoch.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.gamma,
                              patience=args.patience)
criterion = nn.CrossEntropyLoss()

print("‚úÖ Optimizer and scheduler created")
print(f"   Main LR: {args.lr}")
print(f"   Text encoder LR: {args.text_encoder_lr}")

In [None]:
# ============================================================
# INITIALIZE WANDB
# ============================================================
wandb_config = {
    "architecture": "TranSTR",
    "dataset": "CausalVidQA",
    "epochs": args.epoch,
    "batch_size": args.bs,
    "learning_rate": args.lr,
    "text_encoder_lr": args.text_encoder_lr,
    "text_encoder": args.text_encoder_type,
    "d_model": args.d_model,
    "topK_frame": args.topK_frame,
    "topK_obj": args.topK_obj,
    "n_objects": args.objs,
    "num_encoder_layers": args.num_encoder_layers,
    "num_decoder_layers": args.num_decoder_layers,
    "multi_gpu": args.use_multi_gpu,
    "n_gpus": torch.cuda.device_count(),
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset),
    "test_samples": len(test_dataset),
    # Remaining hyperparameters that shape the run. Without them, runs that
    # differ only in regularisation, data selection or schedule look identical
    # in W&B.
    "weight_decay": args.decay,
    "dropout": args.dropout,
    "encoder_dropout": args.encoder_dropout,
    "nheads": args.nheads,
    "freeze_text_encoder": args.freeze_text_encoder,
    "qtype": args.qtype,
    "max_samples": args.max_samples,
    "lr_scheduler_patience": args.patience,
    "lr_decay_factor": args.gamma,
    "early_stopping_patience": args.early_stopping_patience,
}

run = wandb.init(
    project=args.project_name,
    name=args.run_name,
    config=wandb_config,
    tags=["causalvid", "multi-gpu" if args.use_multi_gpu else "single-gpu"]
)

# Log dataset info
wandb.log({
    "data/train_videos": len(train_dataset.vids) if hasattr(train_dataset, 'vids') else 0,
    "data/val_videos": len(val_dataset.vids) if hasattr(val_dataset, 'vids') else 0,
    "data/test_videos": len(test_dataset.vids) if hasattr(test_dataset, 'vids') else 0,
})

print(f"‚úÖ W&B initialized: {run.url}")

## 10. 🚀 Training Loop

In [None]:
# ============================================================
# TRAINING LOOP WITH EARLY STOPPING
# ============================================================
# Start at -inf (not 0.0) so the first epoch always counts as an improvement
# and writes best_model_path. With 0.0 and the strict ">" below, a run whose
# val accuracy never rose above 0 saved no checkpoint, and the final-evaluation
# cell then failed to load it.
best_val_acc = float('-inf')
best_epoch = 1
best_model_path = f'./models/best_model-{args.run_name}.ckpt'
history = {'train': [], 'val': [], 'test': []}  # per-epoch metric dicts
epochs_without_improvement = 0  # Early stopping counter

print("=" * 70)
print(f"üöÄ STARTING TRAINING: {args.run_name}")
print(f"   Epochs: {args.epoch} | Batch size: {args.bs} | GPUs: {torch.cuda.device_count()}")
print(f"   Early stopping: {args.early_stopping_patience} epochs without improvement")
print("=" * 70)

for epoch in range(1, args.epoch + 1):
    print(f"\n{'='*70}")
    print(f"üìö EPOCH [{epoch}/{args.epoch}]")
    print(f"{'='*70}")
    
    # ============ TRAIN ============
    train_metrics = train_epoch(model, optimizer, train_loader, criterion, device, epoch, run)
    
    # ============ EVALUATE ============
    # The test split is evaluated for monitoring only; model selection, LR
    # scheduling and early stopping all use the validation accuracy.
    val_metrics = evaluate(model, val_loader, device, 'val')
    test_metrics = evaluate(model, test_loader, device, 'test')
    history['train'].append(train_metrics)
    history['val'].append(val_metrics)
    history['test'].append(test_metrics)
    
    # ============ UPDATE SCHEDULER ============
    scheduler.step(val_metrics['acc'])
    current_lr = optimizer.param_groups[0]['lr']
    
    # Save the unwrapped module so state-dict keys carry no "module." prefix.
    base_model = model.module if hasattr(model, 'module') else model
    
    # ============ SAVE BEST MODEL & EARLY STOPPING ============
    is_best = val_metrics['acc'] > best_val_acc
    if is_best:
        best_val_acc = val_metrics['acc']
        best_epoch = epoch
        epochs_without_improvement = 0  # Reset counter
        torch.save(base_model.state_dict(), best_model_path)
    else:
        epochs_without_improvement += 1
    
    # ============ LOGGING ============
    print(f"\n  üìä Results:")
    print(f"     {'Metric':<15} {'Train':>10} {'Val':>10} {'Test':>10}")
    print(f"     {'-'*15} {'-'*10} {'-'*10} {'-'*10}")
    print(f"     {'Loss':<15} {train_metrics['loss']:>10.4f} {'-':>10} {'-':>10}")
    print(f"     {'Accuracy':<15} {train_metrics['acc']:>9.2f}% {val_metrics['acc']:>9.2f}% {test_metrics['acc']:>9.2f}%")
    
    # Per question type accuracy
    print(f"\n  üìà Per Question Type Accuracy (Val):")
    qtype_order = ['Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R']
    for qt in qtype_order:
        if qt in val_metrics['qtype_acc']:
            print(f"     {qt:<10}: {val_metrics['qtype_acc'][qt]:>6.2f}%")
    
    print(f"\n  ‚è±Ô∏è  Time: {train_metrics['time']:.1f}s | LR: {current_lr:.2e}")
    print(f"  üìâ No improvement: {epochs_without_improvement}/{args.early_stopping_patience} epochs")
    if is_best:
        print(f"  üíæ Saved best model! (Val acc: {best_val_acc:.2f}%)")
    
    # ============ WANDB LOGGING ============
    wandb_log = {
        "epoch": epoch,
        "train/loss": train_metrics['loss'],
        "train/acc": train_metrics['acc'],
        "val/acc": val_metrics['acc'],
        "test/acc": test_metrics['acc'],
        "lr": current_lr,
        "epoch_time": train_metrics['time'],
        "best_val_acc": best_val_acc,
        "epochs_without_improvement": epochs_without_improvement,
    }
    
    # Log per question type accuracy
    for qt, acc in train_metrics['qtype_acc'].items():
        wandb_log[f"train/acc_{qt}"] = acc
    for qt, acc in val_metrics['qtype_acc'].items():
        wandb_log[f"val/acc_{qt}"] = acc
    for qt, acc in test_metrics['qtype_acc'].items():
        wandb_log[f"test/acc_{qt}"] = acc
    
    wandb.log(wandb_log)
    
    # Save a resumable checkpoint every N epochs (includes scheduler state so
    # the LR schedule and its patience counter survive a restart).
    if epoch % args.save_every == 0:
        ckpt_path = f'./models/checkpoint-{args.run_name}-ep{epoch}.ckpt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': base_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_metrics['acc'],
            'best_val_acc': best_val_acc,
        }, ckpt_path)
        print(f"  üìÅ Checkpoint saved: {ckpt_path}")
    
    # ============ EARLY STOPPING CHECK ============
    if epochs_without_improvement >= args.early_stopping_patience:
        print(f"\n  ‚ö†Ô∏è EARLY STOPPING: No improvement for {args.early_stopping_patience} epochs")
        print(f"     Best val acc: {best_val_acc:.2f}% at epoch {best_epoch}")
        wandb.log({"early_stopped": True, "stopped_at_epoch": epoch})
        break

# ============================================================
# TRAINING COMPLETE
# ============================================================
print("\n" + "=" * 70)
print("‚úÖ TRAINING COMPLETED!")
print("=" * 70)
print(f"   Best epoch: {best_epoch}")
print(f"   Best val accuracy: {best_val_acc:.2f}%")
print(f"   Model saved: {best_model_path}")
if epochs_without_improvement >= args.early_stopping_patience:
    print(f"   Stopped early at epoch {epoch}")
print("=" * 70)

In [None]:
# ============================================================
# FINAL EVALUATION WITH BEST MODEL
# ============================================================
print("=" * 70)
print("üìä FINAL EVALUATION")
print("=" * 70)

# Load best model (weights were saved from the unwrapped module)
print("\n  Loading best model...")
if not os.path.exists(best_model_path):
    raise FileNotFoundError(
        f"Best-model checkpoint not found: {best_model_path} "
        "(did any training epoch improve the validation accuracy?)"
    )
base_model = model.module if hasattr(model, 'module') else model
# map_location puts the tensors on this session's device even if the file was
# written from a different one.
base_model.load_state_dict(torch.load(best_model_path, map_location=device))

# Predict on test set
result_path = f'./prediction/{args.run_name}-ep{best_epoch}-val{best_val_acc:.2f}.json'
results, test_acc = predict_and_save(model, test_loader, device, result_path)

print(f"\n  Test accuracy: {test_acc:.2f}%")
print(f"  Results saved: {result_path}")

# Detailed evaluation by question type
print("\n" + "-" * 70)
print("  Detailed Results by Question Type:")
print("-" * 70)
eval_mc.accuracy_metric_cvid(result_path)

# Log final results to wandb
wandb.log({
    "final/test_acc": test_acc,
    "final/best_epoch": best_epoch,
    "final/best_val_acc": best_val_acc,
})

# Save results artifact
artifact = wandb.Artifact(f'predictions-{args.run_name}', type='predictions')
artifact.add_file(result_path)
wandb.log_artifact(artifact)

## 11. 💾 Save & Cleanup

In [None]:
# ============================================================
# SAVE TO KAGGLE OUTPUT & FINISH WANDB
# ============================================================
import shutil

output_dir = '/kaggle/working'
on_kaggle = os.path.exists(output_dir)
if not on_kaggle:
    print("  Not running on Kaggle, files saved locally")
else:
    # The checkpoint gets an explicit file name; the prediction JSON keeps
    # its own name inside output_dir.
    model_copy_path = os.path.join(output_dir, f'best_model-{args.run_name}.ckpt')
    shutil.copy(best_model_path, model_copy_path)
    shutil.copy(result_path, output_dir)
    print(f"‚úÖ Files saved to {output_dir}")

# Upload the best checkpoint as a W&B model artifact, then close the run.
model_artifact = wandb.Artifact(f'model-{args.run_name}', type='model')
model_artifact.add_file(best_model_path)
wandb.log_artifact(model_artifact)

wandb.finish()
print("‚úÖ W&B run finished")

---

## 📝 Notes

### Multi-GPU Usage
- The code uses `DataParallel` to split each batch evenly across the 2 GPUs
- A batch size of 16 is therefore split into 8 samples per GPU
- To increase throughput, you can raise the batch size to 32 or 64

### W&B Metrics Logged
- `train/loss`, `train/acc` - Training metrics
- `val/acc`, `test/acc` - Evaluation accuracy  
- `train/acc_*`, `val/acc_*`, `test/acc_*` - Per-question-type accuracy
- `lr` - Current learning rate
- `epoch_time` - Time per epoch

### Question Types
- **Des**: Descriptive (What is happening?)
- **Exp**: Explanatory (Why did it happen?)
- **Pred-A/R**: Predictive Answer/Reason
- **CF-A/R**: Counterfactual Answer/Reason