In [5]:
!git clone https://github.com/DanielQH07/tranSTR_Casual.git

Cloning into 'tranSTR_Casual'...
remote: Enumerating objects: 43, done.[K
remote: Counting objects: 100% (43/43), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 43 (delta 11), reused 41 (delta 9), pack-reused 0 (from 0)[K
Receiving objects: 100% (43/43), 267.42 KiB | 3.38 MiB/s, done.
Resolving deltas: 100% (11/11), done.


# 🚀 CausalVidQA Training on Kaggle (2× T4 GPUs)

This notebook trains the TranSTR model on the CausalVidQA dataset with:
- ✅ **Multi-GPU support** (`nn.DataParallel` across 2× T4)
- ✅ **W&B logging** with per-question-type metrics
- ✅ **Early stopping** after 5 epochs without improvement
- ✅ **Full test evaluation**, regardless of the training sample limit (`max_samples`)

## 📋 Cell Execution Order
1. Clone the repo from GitHub
2. Change into the repo directory
3. **Patch `networks/attention.py`** (DataParallel mask fix)
4. **Patch `networks/model.py`** (`repeat_interleave` fix)
5. Patch `DataLoader.py` (feature-dimension handling)
6. Continue with the training setup

**Important:** After cloning, run the patch cells (steps 3–5) before importing the model or the data loader; they fix DataParallel compatibility issues.

In [6]:
# Run everything below from the repository root: the patch cells write to the
# relative paths 'networks/attention.py', 'networks/model.py' and 'DataLoader.py'.
%cd /kaggle/working/tranSTR_Casual

/kaggle/working/tranSTR_Casual


In [None]:
# ============================================================
# PATCH ATTENTION.PY - Fix DataParallel mask handling
# ============================================================
# Overwrites networks/attention.py with a batch-first MultiheadAttention whose
# key_padding_mask handling:
#   (a) trims or right-pads the mask to the key length,
#   (b) lets a mask with batch size 1 broadcast over the batch, and
#   (c) fills masked scores with the dtype's most negative finite value
#       instead of -inf, so a query whose keys are ALL padding gets uniform
#       weights instead of NaN.

attention_fix = '''import math
import torch
from torch import nn

class MultiheadAttention(nn.Module):
    """Batch-first multi-head scaled dot-product attention."""

    def __init__(self, dim, n_heads, dropout=0.1 ):
        super().__init__()

        self.n_heads = n_heads
        self.dim = dim
        self.dropout = nn.Dropout(p=dropout)

        # Each head gets an equal slice of the model dimension.
        assert self.dim % self.n_heads == 0

        self.q_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.k_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.v_lin = nn.Linear(in_features=self.dim, out_features=self.dim)
        self.out_lin = nn.Linear(in_features=self.dim, out_features=self.dim)

    def forward(self, query, key, value, key_padding_mask, attn_mask=None, output_attentions=False):
        """
        Parameters
        ----------
        query: torch.tensor(bs, q_length, dim)
        key: torch.tensor(bs, k_length, dim)
        value: torch.tensor(bs, k_length, dim)
        key_padding_mask: torch.BoolTensor(bs, mask_length) or None
            True = key may be attended to, False = padding. A mask longer than
            k_length is trimmed; a shorter one is right-padded with True.
            A mask batch dimension of 1 is broadcast over the batch.
        attn_mask: optional tensor broadcastable to (bs, n_heads, q_length, k_length)
            Multiplied into the weights after softmax and dropout.
        output_attentions: bool
            If True, also return the head-averaged attention weights.

        Outputs
        -------
        context: torch.tensor(bs, q_length, dim)
            Contextualized layer.
        weights: torch.tensor(bs, q_length, k_length)
            Attention weights averaged over heads. Only returned if `output_attentions=True`.
        """
        bs, q_length, dim = query.size()
        k_length = key.size(1)

        dim_per_head = self.dim // self.n_heads

        def shape(x):
            """ separate heads """
            return x.view(bs, -1, self.n_heads, dim_per_head).transpose(1, 2)

        def unshape(x):
            """ group heads """
            return (
                x.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * dim_per_head)
            )

        q = shape(self.q_lin(query))  # (bs, n_heads, q_length, dim_per_head)
        k = shape(self.k_lin(key))  # (bs, n_heads, k_length, dim_per_head)
        v = shape(self.v_lin(value))  # (bs, n_heads, k_length, dim_per_head)

        q = q / math.sqrt(dim_per_head)  # (bs, n_heads, q_length, dim_per_head)
        scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, q_length, k_length)

        if key_padding_mask is not None:
            # DataParallel fix: reshape with the mask's own batch size; expand_as
            # below broadcasts a batch of 1 and rejects any other mismatch.
            actual_bs = key_padding_mask.size(0)
            mask_len = key_padding_mask.size(1)

            # Trim or pad mask to match k_length (padded keys count as valid).
            if mask_len > k_length:
                key_padding_mask = key_padding_mask[:, :k_length]
            elif mask_len < k_length:
                padding = torch.ones(actual_bs, k_length - mask_len,
                                     dtype=key_padding_mask.dtype,
                                     device=key_padding_mask.device)
                key_padding_mask = torch.cat([key_padding_mask, padding], dim=1)

            # reshape (not view): the trimmed mask may be non-contiguous.
            mask_reshp = (actual_bs, 1, 1, k_length)
            padding_mask = (~key_padding_mask).reshape(mask_reshp).expand_as(scores)
            # Most negative finite value rather than -inf: masked keys still get
            # zero weight, but a fully padded row yields uniform weights instead
            # of NaN (softmax over a row of -inf is 0/0).
            scores = scores.masked_fill(padding_mask, torch.finfo(scores.dtype).min)

        weights = nn.Softmax(dim=-1)(scores)  # (bs, n_heads, q_length, k_length)
        weights = self.dropout(weights)  # (bs, n_heads, q_length, k_length)

        # Mask heads if we want to
        if attn_mask is not None:
            weights = weights * attn_mask

        context = torch.matmul(weights, v)  # (bs, n_heads, q_length, dim_per_head)
        context = unshape(context)  # (bs, q_length, dim)
        context = self.out_lin(context)  # (bs, q_length, dim)

        if output_attentions:
            return context, weights.mean(1)
        else:
            return context
'''

# Write patched attention.py
import os
os.makedirs('networks', exist_ok=True)
with open('networks/attention.py', 'w') as f:
    f.write(attention_fix)

print("‚úÖ attention.py patched for DataParallel compatibility!")
print(f"   File written to: {os.path.abspath('networks/attention.py')}")

In [None]:
# ============================================================
# PATCH MODEL.PY - Fix repeat_interleave for DataParallel
# ============================================================
# Rewrites the obj_decoder call in networks/model.py so q_local / q_mask are
# repeated into named variables first and a None q_mask is passed through.
# Idempotent: re-running after a successful patch reports "already patched"
# instead of the ambiguous "pattern not found".

MODEL_PATH = 'networks/model.py'

# Read current model.py
with open(MODEL_PATH, 'r', encoding='utf-8') as f:
    model_code = f.read()

# Fix the obj_decoder call to handle repeat_interleave properly
old_pattern = '''        obj_local, obj_att = self.obj_decoder(obj_local.flatten(0,1),
                                            q_local.repeat_interleave(self.frame_topK, dim=0), 
                                            memory_key_padding_mask=q_mask.repeat_interleave(self.frame_topK, dim=0),
                                            output_attentions=True
                                            )  # b*16,5,d        #.view(B, F, O, -1) # b,16,5,d'''

new_pattern = '''        # Repeat q_local and q_mask for each frame (handle potential batch size mismatch)
        q_local_repeated = q_local.repeat_interleave(self.frame_topK, dim=0)
        q_mask_repeated = q_mask.repeat_interleave(self.frame_topK, dim=0) if q_mask is not None else None
        
        obj_local, obj_att = self.obj_decoder(obj_local.flatten(0,1),
                                            q_local_repeated, 
                                            memory_key_padding_mask=q_mask_repeated,
                                            output_attentions=True
                                            )  # b*16,5,d        #.view(B, F, O, -1) # b,16,5,d'''

if new_pattern in model_code:
    print("model.py already patched - nothing to do.")
elif old_pattern in model_code:
    model_code = model_code.replace(old_pattern, new_pattern)
    with open(MODEL_PATH, 'w', encoding='utf-8') as f:
        f.write(model_code)
    print("‚úÖ model.py patched for DataParallel compatibility!")
else:
    print("‚ö†Ô∏è  Pattern not found - networks/model.py differs from the expected code; NOT patched")

In [7]:
# Patch DataLoader.py to handle the feature-dimension mismatch.
# Run this cell BEFORE importing DataLoader.
#
# Notes on the written loader:
#   * features stored as (T, R, D) are averaged to (T, D) once at load time
#     (same values as averaging per item), so the in-RAM cache holds (T, D);
#   * split membership is checked against a set, avoiding an
#     O(len(idx2vid) * len(split)) list scan when building the feature index;
#   * object features are synthesized by tiling the frame appearance feature
#     and appending a constant dummy box.

patch_code = '''
import torch
import os
import h5py
import os.path as osp
import numpy as np
import json
import pickle as pkl
from torch.utils import data
from utils.util import load_file, pause, transform_bb, pkload
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizerFast


class VideoQADataset(Dataset):
    """
    DataLoader for CausalVidQA with an output format compatible with NextQA.

    Each item is (frame_feat, obj_feat, question, candidate_strings, answer_id, qns_key).
    """

    # qtype -> (section of text.json / answer.json, key holding candidates and answer id)
    _QTYPE_FIELDS = {
        0: ('descriptive', 'answer'),
        1: ('explanatory', 'answer'),
        2: ('predictive', 'answer'),
        3: ('predictive', 'reason'),
        4: ('counterfactual', 'answer'),
        5: ('counterfactual', 'reason'),
    }
    
    def __init__(self, split, n_query=5, obj_num=1, 
                 sample_list_path=None,
                 video_feature_path=None,
                 text_annotation_path=None,
                 qtype=-1,
                 max_samples=None):
        super(VideoQADataset, self).__init__()
        
        self.split = split
        self.mc = n_query
        self.obj_num = obj_num
        self.qtype = qtype
        self.video_feature_path = video_feature_path
        self.text_annotation_path = text_annotation_path
        self.max_samples = max_samples
        
        # Load video ids for this split ('val' falls back to 'valid.pkl')
        if split == 'val':
            split_file = osp.join(sample_list_path, 'val.pkl')
            if not osp.exists(split_file):
                split_file = osp.join(sample_list_path, 'valid.pkl')
        else:
            split_file = osp.join(sample_list_path, f'{split}.pkl')
        
        if not osp.exists(split_file):
            raise FileNotFoundError(f"Split file not found: {split_file}")
        
        self.vids = pkload(split_file)
        
        if self.vids is None:
            raise ValueError(f"Failed to load split file: {split_file}")
        
        # max_samples caps the number of VIDEOS, not (video, qtype) samples.
        if max_samples is not None and max_samples > 0:
            self.vids = self.vids[:max_samples]
            print(f"Limited to {len(self.vids)} videos (max_samples={max_samples})")
        else:
            print(f"Loaded {len(self.vids)} videos from {split_file}")
        
        # Map each video of this split to its row in the feature files.
        # A set makes each membership test O(1) instead of a list scan.
        idx2vid_file = osp.join(video_feature_path, 'idx2vid.pkl')
        vf_info = pkload(idx2vid_file)
        split_vids = set(self.vids)
        self.vf_info = dict()
        for idx, vid in enumerate(vf_info):
            if vid in split_vids:
                self.vf_info[vid] = idx
        
        self.app_feats = self._load_features(osp.join(video_feature_path, 'appearance_feat.h5'))
        self.mot_feats = self._load_features(osp.join(video_feature_path, 'motion_feat.h5'))
        
        self._build_sample_list()

    @staticmethod
    def _to_frame_level(feat):
        """Reduce a per-video feature array to (T, D).

        3-D (T, R, D) arrays are averaged over axis 1 (squeezed when R == 1),
        1-D (D,) arrays become (1, D), 2-D arrays are returned unchanged.
        """
        if feat.ndim == 3:
            feat = feat.mean(axis=1) if feat.shape[1] > 1 else feat.squeeze(1)
        if feat.ndim == 1:
            feat = feat[np.newaxis, :]
        return feat

    def _load_features(self, feat_file):
        """Read the 'resnet_features' rows of this split's videos, reduced to (T, D)."""
        print(f'Loading {feat_file}...')
        feats_by_vid = dict()
        with h5py.File(feat_file, 'r') as fp:
            feats = fp['resnet_features']
            for vid, idx in self.vf_info.items():
                feats_by_vid[vid] = self._to_frame_level(feats[idx][...])
        return feats_by_vid

    def _build_sample_list(self):
        """Expand videos into (vid, qtype) samples according to self.qtype.

        -1 -> all six types; 2 -> types 2 and 3; 3 -> types 4 and 5;
        any other value -> that single type.
        """
        self.samples = []
        
        if self.qtype == -1:
            for vid in self.vids:
                for qt in range(6):
                    self.samples.append((vid, qt))
        elif self.qtype == 0 or self.qtype == 1:
            for vid in self.vids:
                self.samples.append((vid, self.qtype))
        elif self.qtype == 2:
            for vid in self.vids:
                self.samples.append((vid, 2))
                self.samples.append((vid, 3))
        elif self.qtype == 3:
            for vid in self.vids:
                self.samples.append((vid, 4))
                self.samples.append((vid, 5))
        else:
            for vid in self.vids:
                self.samples.append((vid, self.qtype))
        
        print(f"Total samples: {len(self.samples)}")

    def _load_text(self, vid, qtype):
        """Return (question, candidate_answers, answer_id) for one (video, qtype)."""
        text_file = osp.join(self.text_annotation_path, vid, 'text.json')
        answer_file = osp.join(self.text_annotation_path, vid, 'answer.json')
        
        # Fall back to the <root>/QA/<vid>/ layout.
        if not osp.exists(text_file):
            text_file = osp.join(self.text_annotation_path, 'QA', vid, 'text.json')
            answer_file = osp.join(self.text_annotation_path, 'QA', vid, 'answer.json')
        
        if not osp.exists(text_file):
            raise FileNotFoundError(f"Text annotation not found for video: {vid}")
        
        with open(text_file, 'r', encoding='utf-8') as f:
            text = json.load(f)
        with open(answer_file, 'r', encoding='utf-8') as f:
            answer = json.load(f)
        
        if qtype not in self._QTYPE_FIELDS:
            raise ValueError(f"Invalid qtype: {qtype}")
        section, field = self._QTYPE_FIELDS[qtype]
        qns = text[section]['question']
        cand_ans = text[section][field]
        ans_id = answer[section][field]
        
        return qns, cand_ans, ans_id


    def __getitem__(self, idx):
        vid, qtype = self.samples[idx]
        
        qns_word, cand_ans, ans_id = self._load_text(vid, qtype)
        # One "[CLS] question [SEP] candidate" string per answer option.
        ans_word = ['[CLS] ' + qns_word + ' [SEP] ' + str(cand_ans[i]) for i in range(self.mc)]
        
        # (T, D) features; already reduced at load time, and the helper is a
        # no-op on 2-D arrays, so this also accepts raw arrays.
        app_feat = self._to_frame_level(self.app_feats[vid])
        mot_feat = self._to_frame_level(self.mot_feats[vid])
        
        # Frame feature: concatenate app + mot along the feature axis
        frame_feat = np.concatenate([app_feat, mot_feat], axis=-1)
        vid_frame_feat = torch.from_numpy(frame_feat).type(torch.float32)
        
        # Object features are synthesized: every one of the obj_num slots repeats
        # the frame appearance feature, followed by a constant 5-value box
        # [0, 0, 1, 1, 1] (presumably a normalized full-frame box plus score).
        T = app_feat.shape[0]
        obj_feat = np.tile(app_feat[:, np.newaxis, :], (1, self.obj_num, 1))
        dummy_bbox = np.zeros((T, self.obj_num, 5), dtype=np.float32)
        dummy_bbox[:, :, :4] = np.array([0.0, 0.0, 1.0, 1.0])
        dummy_bbox[:, :, 4] = 1.0
        
        obj_feat = np.concatenate([obj_feat, dummy_bbox], axis=-1)
        vid_obj_feat = torch.from_numpy(obj_feat).type(torch.float32)
        
        qns_key = vid + '_' + str(qtype)
        
        return vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_key


    def __len__(self):
        return len(self.samples)
'''

# Write patched DataLoader.py
with open('DataLoader.py', 'w') as f:
    f.write(patch_code)

print("‚úÖ DataLoader.py patched with dimension fix!")

‚úÖ DataLoader.py patched with dimension fix!


In [8]:
import os
import pickle
import h5py
import json

# ============================================================
# DATA PATHS (Kaggle Input)
# ============================================================
# Roots of the attached Kaggle datasets; the Config cell reads these names.
text_feature_path = '/kaggle/input/text-feature'
visual_feature_path = '/kaggle/input/visual-feature'
split_path = '/kaggle/input/casual-vid-data-split/split'
text_annotation_path = '/kaggle/input/text-annotation'

RULE = "=" * 70


def _print_section(title, leading_newline=False):
    """Print a 70-char ruled section header, optionally preceded by a blank line."""
    print(("\n" + RULE) if leading_newline else RULE)
    print(title)
    print(RULE)


_print_section("üìÇ DATA PATHS")
checked_paths = {
    "Visual features": visual_feature_path,
    "Split files": split_path,
    "Text annotations": text_annotation_path,
}
for label, folder in checked_paths.items():
    mark = "‚úì" if os.path.exists(folder) else "‚úó"
    print(f"  {mark} {label}: {folder}")

# ============================================================
# DATA STATISTICS
# ============================================================
_print_section("üìä DATASET STATISTICS", leading_newline=True)

# 1. Split files: each video contributes 6 (video, qtype) samples.
print("\nüìÅ Split Files:")
split_stats = {}
for split_name in ('train', 'valid', 'test'):
    split_file = f'{split_path}/{split_name}.pkl'
    if not os.path.exists(split_file):
        continue
    with open(split_file, 'rb') as fh:
        n_videos = len(pickle.load(fh))
    split_stats[split_name] = n_videos
    print(f"  {split_name:>6}: {n_videos:>6} videos ‚Üí {n_videos * 6:>6} samples")

# 2. Visual features: index size and HDF5 dataset shapes.
print("\nüé¨ Visual Features:")
idx2vid_file = f'{visual_feature_path}/idx2vid.pkl'
if os.path.exists(idx2vid_file):
    with open(idx2vid_file, 'rb') as fh:
        idx2vid = pickle.load(fh)
    print(f"  Indexed videos: {len(idx2vid)}")

for feat_name in ('appearance_feat.h5', 'motion_feat.h5'):
    feat_file = f'{visual_feature_path}/{feat_name}'
    if os.path.exists(feat_file):
        with h5py.File(feat_file, 'r') as fh:
            feat_shape = fh['resnet_features'].shape
        print(f"  {feat_name}: {feat_shape}")

# 3. Question types (qtype id, name, informal meaning).
print("\n‚ùì Question Types (qtype):")
qtype_info = [
    ("0", "Descriptive", "What is happening?"),
    ("1", "Explanatory", "Why did it happen?"),
    ("2", "Predictive-Ans", "What will happen?"),
    ("3", "Predictive-Reason", "Why will it happen?"),
    ("4", "Counterfactual-Ans", "What if X didn't happen?"),
    ("5", "Counterfactual-Reason", "Why would that result?"),
]
print("\n".join(f"  {qt}: {name:<20} - {desc}" for qt, name, desc in qtype_info))

print("\n" + RULE)

üìÇ DATA PATHS
  ‚úì Visual features: /kaggle/input/visual-feature
  ‚úì Split files: /kaggle/input/casual-vid-data-split/split
  ‚úì Text annotations: /kaggle/input/text-annotation

üìä DATASET STATISTICS

üìÅ Split Files:
   train:  18776 videos ‚Üí 112656 samples
   valid:   2695 videos ‚Üí  16170 samples
    test:   5429 videos ‚Üí  32574 samples

üé¨ Visual Features:
  Indexed videos: 26900
  appearance_feat.h5: (26900, 8, 16, 2048)
  motion_feat.h5: (26900, 8, 2048)

‚ùì Question Types (qtype):
  0: Descriptive          - What is happening?
  1: Explanatory          - Why did it happen?
  2: Predictive-Ans       - What will happen?
  3: Predictive-Reason    - Why will it happen?
  4: Counterfactual-Ans   - What if X didn't happen?
  5: Counterfactual-Reason - Why would that result?



In [9]:
!pip install -q transformers einops h5py wandb

# Login to W&B (uncomment v√† th√™m API key c·ªßa b·∫°n)
my_key = "80b5a02ccaed80f35a2e893aed6446d4467c0c45"
import wandb
wandb.login(key=my_key, relogin=True)

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhaidang262004[0m ([33mintroSE[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [10]:
import os
import sys
import torch
import numpy as np
import wandb

# ============================================================
# CONFIGURATION
# ============================================================

class Config:
    """Training configuration for CausalVidQA.

    Plain class attributes read through an instance (``args = Config()``);
    assigning on the instance (e.g. ``args.use_multi_gpu = False``) overrides
    the class default for that instance only.
    """
    
    # Experiment
    project_name = "CausalVidQA-TranSTR"
    run_name = "causalvid_2gpu"
    
    # Data paths (taken from the module-level path variables of the data-paths cell;
    # inside a class body the right-hand names resolve to those globals)
    sample_list_path = split_path
    video_feature_path = visual_feature_path
    text_annotation_path = text_annotation_path
    
    # Training
    bs = 2                     # TOTAL batch size; DataParallel splits it across GPUs (1 per GPU on 2x T4)
    lr = 1e-4                  # Learning rate
    text_encoder_lr = 1e-5     # Text encoder LR (lower)
    epoch = 20
    warmup_epochs = 2          # Warmup epochs
    
    # Dataset
    dataset = 'causal-vid'
    # Question-type selection, as interpreted by VideoQADataset._build_sample_list
    # in the patched DataLoader.py:
    #   -1 -> all six types (0..5) per video
    #    0 -> descriptive only, 1 -> explanatory only
    #    2 -> predictive answer + reason (types 2 and 3)
    #    3 -> counterfactual answer + reason (types 4 and 5)
    #    any other value (e.g. 4 = counterfactual answer) -> that single type only
    qtype = 4
    max_samples = 500          # Cap on VIDEOS (not samples) for train/val; None = all. Test always uses all.
    
    # Model architecture
    d_model = 768
    word_dim = 768
    nheads = 8
    num_encoder_layers = 1
    num_decoder_layers = 1
    dropout = 0.1
    encoder_dropout = 0.1
    activation = 'relu'
    normalize_before = False
    
    # Video features
    objs = 20                  # Object slots per frame (filled by the patched DataLoader)
    topK_frame = 8             # Top-K frames to select
    topK_obj = 5               # Top-K objects to select
    frame_feat_dim = 4096      # app(2048) + mot(2048)
    obj_feat_dim = 2053        # feat(2048) + bbox(5)
    n_query = 5                # 5-way multiple choice
    
    # Text encoder
    text_encoder_type = "microsoft/deberta-base"
    freeze_text_encoder = False
    text_pool_mode = 0
    hard_eval = False
    
    # Optimizer
    decay = 0.001              # Weight decay
    patience = 3               # LR scheduler patience
    gamma = 0.5                # LR decay factor
    
    # Early stopping
    early_stopping_patience = 5  # Stop after 5 epochs without improvement
    
    # Contrastive learning
    pos_ratio = 0.7
    neg_ratio = 0.3
    a = 1
    
    # Multi-GPU
    use_multi_gpu = True       # Enable DataParallel
    num_workers = 2            # DataLoader workers
    
    # Logging
    log_interval = 50          # Log every N batches
    save_every = 5             # Save checkpoint every N epochs

args = Config()

# ============================================================
# GPU SETUP
# ============================================================
print("=" * 70)
print("üñ•Ô∏è GPU CONFIGURATION")
print("=" * 70)

n_gpus = torch.cuda.device_count()
print(f"  Available GPUs: {n_gpus}")
for i in range(n_gpus):
    print(f"    GPU {i}: {torch.cuda.get_device_name(i)}")
    mem = torch.cuda.get_device_properties(i).total_memory / 1e9
    print(f"           Memory: {mem:.1f} GB")

# Multi-GPU only when requested AND at least two GPUs are visible; otherwise
# force it off on this instance so later cells see a consistent flag.
if n_gpus >= 2 and args.use_multi_gpu:
    print(f"\n  ‚úì Multi-GPU mode: DataParallel on {n_gpus} GPUs")
    print(f"  ‚úì Effective batch size: {args.bs} (total)")
else:
    print(f"\n  ‚Üí Single GPU mode")
    args.use_multi_gpu = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"  Primary device: {device}")

# ============================================================
# PRINT CONFIG
# ============================================================
print("\n" + "=" * 70)
print("‚öôÔ∏è TRAINING CONFIG")
print("=" * 70)
config_items = [
    ("Batch size", args.bs),
    ("Learning rate", args.lr),
    ("Text encoder LR", args.text_encoder_lr),
    ("Epochs", args.epoch),
    ("Early stopping", f"{args.early_stopping_patience} epochs"),
    ("d_model", args.d_model),
    ("TopK frames", args.topK_frame),
    ("TopK objects", args.topK_obj),
    ("Objects/frame", args.objs),
    ("Text encoder", args.text_encoder_type),
]

for name, val in config_items:
    print(f"  {name:<20}: {val}")
print("=" * 70)

üñ•Ô∏è GPU CONFIGURATION
  Available GPUs: 2
    GPU 0: Tesla T4
           Memory: 15.8 GB
    GPU 1: Tesla T4
           Memory: 15.8 GB

  ‚úì Multi-GPU mode: DataParallel on 2 GPUs
  ‚úì Effective batch size: 2 (total)
  Primary device: cuda

‚öôÔ∏è TRAINING CONFIG
  Batch size          : 2
  Learning rate       : 0.0001
  Text encoder LR     : 1e-05
  Epochs              : 20
  Early stopping      : 5 epochs
  d_model             : 768
  TopK frames         : 8
  TopK objects        : 5
  Objects/frame       : 20
  Text encoder        : microsoft/deberta-base


In [11]:
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts
from collections import defaultdict
import time
import json

# Local imports
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel
import eval_mc

# ============================================================
# REPRODUCIBILITY
# ============================================================
import random


def set_seed(seed=999):
    """Seed every global RNG the pipeline may use and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's legacy global RNG and torch (CPU and all
    CUDA devices), and turns off cuDNN benchmark mode so algorithm selection
    does not vary between runs.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


set_seed(999)
print("‚úÖ Modules imported, seed set to 999")

‚úÖ Modules imported, seed set to 999


In [12]:
print("Creating datasets...")

# ============================================================
# CREATE DATASETS
# ============================================================
dataset_kwargs = dict(
    n_query=args.n_query,
    obj_num=args.objs,
    sample_list_path=args.sample_list_path,
    video_feature_path=args.video_feature_path,
    text_annotation_path=args.text_annotation_path,
    qtype=args.qtype,
    max_samples=args.max_samples
)

train_dataset = VideoQADataset(split='train', **dataset_kwargs)
val_dataset = VideoQADataset(split='val', **dataset_kwargs)

# The test set ALWAYS uses the full data (max_samples is not applied).
test_kwargs = {**dataset_kwargs, 'max_samples': None}
test_dataset = VideoQADataset(split='test', **test_kwargs)

# ============================================================
# CREATE DATALOADERS (optimized for multi-GPU)
# ============================================================
# Worker processes are only used in multi-GPU mode. prefetch_factor must follow
# the EFFECTIVE worker count: DataLoader raises ValueError when prefetch_factor
# is given while num_workers == 0, so it is only passed when workers exist.
loader_workers = args.num_workers if args.use_multi_gpu else 0
loader_kwargs = dict(
    batch_size=args.bs,
    num_workers=loader_workers,
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
)
if loader_workers > 0:
    loader_kwargs['prefetch_factor'] = 2

train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# ============================================================
# DATASET SUMMARY
# ============================================================
print("\n" + "=" * 70)
print("üìä DATALOADER SUMMARY")
print("=" * 70)
print(f"  {'Split':<10} {'Videos':>10} {'Samples':>10} {'Batches':>10}")
print(f"  {'-'*10} {'-'*10} {'-'*10} {'-'*10}")
for name, dataset, loader in [
    ("Train", train_dataset, train_loader),
    ("Val", val_dataset, val_loader),
    ("Test (FULL)", test_dataset, test_loader)
]:
    n_vids = len(dataset.vids) if hasattr(dataset, 'vids') else "?"
    print(f"  {name:<10} {n_vids:>10} {len(dataset):>10} {len(loader):>10}")
print("=" * 70)
print("  ‚ÑπÔ∏è  Test set always uses ALL data regardless of max_samples")

Creating datasets...
Limited to 500 videos (max_samples=500)
Loading /kaggle/input/visual-feature/appearance_feat.h5...
Loading /kaggle/input/visual-feature/motion_feat.h5...
Total samples: 500
Limited to 500 videos (max_samples=500)
Loading /kaggle/input/visual-feature/appearance_feat.h5...
Loading /kaggle/input/visual-feature/motion_feat.h5...
Total samples: 500
Loaded 5429 videos from /kaggle/input/casual-vid-data-split/split/test.pkl
Loading /kaggle/input/visual-feature/appearance_feat.h5...
Loading /kaggle/input/visual-feature/motion_feat.h5...
Total samples: 5429

üìä DATALOADER SUMMARY
  Split          Videos    Samples    Batches
  ---------- ---------- ---------- ----------
  Train             500        500        250
  Val               500        500        250
  Test (FULL)       5429       5429       2715
  ‚ÑπÔ∏è  Test set always uses ALL data regardless of max_samples


In [13]:
# ============================================================
# VERIFY DATA SAMPLE
# ============================================================
# Fetch a single training batch and print tensor shapes plus the first
# example, as a sanity check of the DataLoader output before building the model.
print("üîç Verifying data sample...")

batch = next(iter(train_loader), None)
if batch is not None:
    (vid_frame_feat, vid_obj_feat,
     qns_word, ans_word, ans_id, qns_key) = batch

    # Shapes; qns_word is the collated list of question strings.
    print(f"\n  Frame features:  {vid_frame_feat.shape}")
    print(f"  Object features: {vid_obj_feat.shape}")
    print(f"  Batch size:      {len(qns_word)}")

    # First example, truncated. After default collation ans_word is indexed
    # [option][example], so ans_word[0][0] is option 0 of example 0.
    print(f"\n  Sample question: {qns_word[0][:80]}...")
    print(f"  Sample answer:   {ans_word[0][0][:60]}...")
    print(f"  Ground truth:    {ans_id[0].item()}")
    print(f"  Question key:    {qns_key[0]}")

print("\n‚úÖ Data verification complete!")

üîç Verifying data sample...

  Frame features:  torch.Size([2, 8, 4096])
  Object features: torch.Size([2, 8, 20, 2053])
  Batch size:      2

  Sample question: What will happen if [person_1] cries?...
  Sample answer:   [CLS] What will happen if [person_1] cries? [SEP] [person_3]...
  Ground truth:    0
  Question key:    tEdsMfaFCQM_000134_000144_4

‚úÖ Data verification complete!


In [14]:
# ============================================================
# CREATE MODEL
# ============================================================
print("üèóÔ∏è Creating model...")

# Each VideoQAmodel hyper-parameter has a Config attribute of the same name,
# so the kwargs are gathered by name (same order as before) plus the device.
_MODEL_ARG_NAMES = (
    'd_model', 'word_dim', 'encoder_dropout', 'dropout',
    'num_encoder_layers', 'num_decoder_layers', 'nheads',
    'normalize_before', 'activation', 'text_encoder_type',
    'freeze_text_encoder', 'text_pool_mode', 'n_query', 'objs',
    'topK_frame', 'topK_obj', 'hard_eval', 'frame_feat_dim', 'obj_feat_dim',
)
model_config = {arg_name: getattr(args, arg_name) for arg_name in _MODEL_ARG_NAMES}
model_config['device'] = device

model = VideoQAmodel(**model_config)

# ============================================================
# MULTI-GPU SETUP (DataParallel)
# ============================================================
visible_gpus = torch.cuda.device_count()
wrap_in_data_parallel = args.use_multi_gpu and visible_gpus > 1
if wrap_in_data_parallel:
    print(f"  ‚Üí Wrapping model with DataParallel ({visible_gpus} GPUs)")
    model = nn.DataParallel(model)

model.to(device)

# ============================================================
# MODEL SUMMARY
# ============================================================
param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_counts)
trainable_params = sum(count for count, trainable in param_counts if trainable)

print("\n" + "=" * 70)
print("üß† MODEL SUMMARY")
print("=" * 70)
print(f"  Total parameters:     {total_params / 1e6:.2f}M")
print(f"  Trainable parameters: {trainable_params / 1e6:.2f}M")
print(f"  Multi-GPU:            {wrap_in_data_parallel}")
print("=" * 70)

üèóÔ∏è Creating model...


config.json:   0%|          | 0.00/474 [00:00<?, ?B/s]

2025-12-08 15:59:50.250529: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765209590.467672      38 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765209590.540981      38 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/559M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/52.0 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/559M [00:00<?, ?B/s]

  ‚Üí Wrapping model with DataParallel (2 GPUs)

üß† MODEL SUMMARY
  Total parameters:     180.36M
  Trainable parameters: 180.36M
  Multi-GPU:            True


In [15]:
# ============================================================
# TRAINING FUNCTIONS
# ============================================================

def train_epoch(model, optimizer, train_loader, criterion, device, epoch, wandb_run=None,
                log_interval=None, max_grad_norm=1.0):
    """Train ``model`` for one epoch and return loss / accuracy statistics.

    Args:
        model: module (optionally DataParallel-wrapped) called as
            ``model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)``; the argmax of
            its last dimension is taken as the predicted answer index.
        optimizer: optimizer stepping ``model``'s parameters.
        train_loader: sized iterable yielding
            ``(vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys)``.
        criterion: loss applied as ``criterion(logits, ans_id)``.
        device: device the video tensors and targets are moved to.
        epoch: epoch number; only used to compute the W&B step.
        wandb_run: optional W&B run for batch-level logging.
        log_interval: print/log every this many batches. ``None`` falls back to
            the notebook-level ``args.log_interval``.
        max_grad_norm: threshold for gradient-norm clipping (default 1.0).

    Returns:
        dict with ``loss`` (mean batch loss), ``acc`` (percent), ``time``
        (seconds) and ``qtype_acc`` (percent per question type present).

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    qtype_names = ('Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R')
    if log_interval is None:
        # Backward-compatible default: the notebook-level `args` namespace.
        log_interval = args.log_interval

    model.train()

    total_loss = 0.0
    predictions = []
    answers = []
    batch_times = []

    # Per question type tracking
    qtype_correct = defaultdict(int)
    qtype_total = defaultdict(int)

    start_time = time.time()

    for batch_idx, inputs in enumerate(train_loader):
        batch_start = time.time()

        vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys = inputs
        vid_frame_feat = vid_frame_feat.to(device)
        vid_obj_feat = vid_obj_feat.to(device)
        ans_targets = ans_id.to(device)

        # Forward pass
        optimizer.zero_grad()
        out = model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)
        loss = criterion(out, ans_targets)

        # Backward pass with gradient-norm clipping
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Track metrics (.item() syncs with the device; call it once per batch)
        batch_loss = loss.item()
        total_loss += batch_loss
        pred = out.max(-1)[1].cpu()
        predictions.append(pred)
        answers.append(ans_id)

        # Question keys end in "_<type index>"; the index selects qtype_names.
        for qkey, p, a in zip(qns_keys, pred.numpy(), ans_id.numpy()):
            qtype = int(qkey.split('_')[-1])
            qtype_total[qtype] += 1
            if p == a:
                qtype_correct[qtype] += 1

        batch_times.append(time.time() - batch_start)

        # Logging
        if (batch_idx + 1) % log_interval == 0:
            avg_loss = total_loss / (batch_idx + 1)
            avg_time = np.mean(batch_times[-log_interval:])
            print(f"    Batch [{batch_idx+1:>4}/{len(train_loader)}] "
                  f"Loss: {batch_loss:.4f} (avg: {avg_loss:.4f}) "
                  f"Time: {avg_time:.3f}s/batch")

            if wandb_run:
                wandb_run.log({
                    "train/batch_loss": batch_loss,
                    "train/avg_loss": avg_loss,
                    "train/batch_time": avg_time,
                }, step=epoch * len(train_loader) + batch_idx)

    n_samples = sum(len(a) for a in answers)
    if n_samples == 0:
        # torch.cat([]) / division by zero would otherwise fail opaquely.
        raise ValueError("train_loader yielded no samples; cannot compute epoch metrics")

    # Compute epoch metrics
    all_preds = torch.cat(predictions, dim=0).long()
    all_ans = torch.cat(answers, dim=0).long()
    epoch_acc = (all_preds == all_ans).sum().item() * 100.0 / n_samples
    epoch_loss = total_loss / len(predictions)  # mean over the batches seen
    epoch_time = time.time() - start_time

    # Per question type accuracy (only types that occurred)
    qtype_acc = {
        name: qtype_correct[qt] * 100.0 / qtype_total[qt]
        for qt, name in enumerate(qtype_names)
        if qtype_total[qt] > 0
    }

    return {
        'loss': epoch_loss,
        'acc': epoch_acc,
        'time': epoch_time,
        'qtype_acc': qtype_acc
    }


def evaluate(model, data_loader, device, split_name='val'):
    """Evaluate ``model`` on one split with overall and per-question-type accuracy.

    Args:
        model: module (optionally DataParallel-wrapped) called as
            ``model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)``; the argmax of
            its last dimension is taken as the predicted answer index.
        data_loader: iterable yielding
            ``(vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys)``.
        device: device the video feature tensors are moved to.
        split_name: split label, used in the error raised for an empty loader.

    Returns:
        dict with ``acc`` (percent), ``qtype_acc`` (percent per question type
        present in the split) and ``n_samples``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    qtype_names = ('Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R')
    model.eval()

    predictions = []
    answers = []
    qtype_correct = defaultdict(int)
    qtype_total = defaultdict(int)

    with torch.no_grad():
        for inputs in data_loader:
            vid_frame_feat, vid_obj_feat, qns_w, ans_w, ans_id, qns_keys = inputs
            vid_frame_feat = vid_frame_feat.to(device)
            vid_obj_feat = vid_obj_feat.to(device)

            out = model(vid_frame_feat, vid_obj_feat, qns_w, ans_w)
            pred = out.max(-1)[1].cpu()

            predictions.append(pred)
            answers.append(ans_id)

            # Question keys end in "_<type index>"; the index selects qtype_names.
            for qkey, p, a in zip(qns_keys, pred.numpy(), ans_id.numpy()):
                qtype = int(qkey.split('_')[-1])
                qtype_total[qtype] += 1
                if p == a:
                    qtype_correct[qtype] += 1

    n_samples = sum(len(a) for a in answers)
    if n_samples == 0:
        # torch.cat([]) / division by zero would otherwise fail opaquely.
        raise ValueError(f"{split_name} loader yielded no samples; cannot compute accuracy")

    all_preds = torch.cat(predictions, dim=0).long()
    all_ans = torch.cat(answers, dim=0).long()
    overall_acc = (all_preds == all_ans).sum().item() * 100.0 / n_samples

    # Per question type accuracy (only types that occurred). Video-level
    # combined metrics (e.g. Pred-A and Pred-R both correct for one video)
    # need per-video grouping and are not computed here.
    qtype_acc = {
        name: qtype_correct[qt] * 100.0 / qtype_total[qt]
        for qt, name in enumerate(qtype_names)
        if qtype_total[qt] > 0
    }

    return {
        'acc': overall_acc,
        'qtype_acc': qtype_acc,
        'n_samples': n_samples
    }


def predict_and_save(model, data_loader, device, save_path):
    """Run ``model`` over ``data_loader``, dump per-question predictions, report accuracy.

    Every entry of the written JSON (and of the returned dict) maps a question
    key to ``{'prediction': int, 'answer': int}``; a key seen twice keeps its
    last value.

    Returns:
        ``(results, acc)``: the dict written to ``save_path`` and the percentage
        of entries whose prediction equals the answer.
    """
    model.eval()
    results = {}

    with torch.no_grad():
        for frame_feat, obj_feat, question_text, answer_text, answer_idx, question_keys in data_loader:
            logits = model(frame_feat.to(device), obj_feat.to(device), question_text, answer_text)
            predicted = logits.max(-1)[1].cpu().numpy()
            results.update(
                (key, {'prediction': int(guess), 'answer': int(truth)})
                for key, guess, truth in zip(question_keys, predicted, answer_idx.numpy())
            )

    with open(save_path, 'w') as handle:
        json.dump(results, handle, indent=2)

    n_correct = sum(entry['prediction'] == entry['answer'] for entry in results.values())
    acc = n_correct * 100.0 / len(results)
    return results, acc


print("‚úÖ Training functions defined")

‚úÖ Training functions defined


In [16]:
# ============================================================
# SETUP OPTIMIZER, SCHEDULER, CRITERION
# ============================================================
os.makedirs('./models', exist_ok=True)
os.makedirs('./prediction', exist_ok=True)

# Get base model for parameter groups (handle DataParallel)
base_model = model.module if hasattr(model, 'module') else model

# One pass over the trainable parameters: the text encoder gets its own LR.
# Group order stays fixed (main first) because the training loop reports
# optimizer.param_groups[0]['lr'].
main_params, text_encoder_params = [], []
for name, param in base_model.named_parameters():
    if param.requires_grad:
        (text_encoder_params if "text_encoder" in name else main_params).append(param)

param_groups = [
    {"params": main_params, "lr": args.lr},
    {"params": text_encoder_params, "lr": args.text_encoder_lr},
]

optimizer = torch.optim.AdamW(param_groups, weight_decay=args.decay)
# No `verbose=`: it is deprecated for LR schedulers since PyTorch 2.2 and only
# emits a warning; the training loop prints and logs the current LR each epoch.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.gamma,
                              patience=args.patience)
criterion = nn.CrossEntropyLoss()

print("‚úÖ Optimizer and scheduler created")
print(f"   Main LR: {args.lr}")
print(f"   Text encoder LR: {args.text_encoder_lr}")

‚úÖ Optimizer and scheduler created
   Main LR: 0.0001
   Text encoder LR: 1e-05


In [17]:
# ============================================================
# INITIALIZE WANDB
# ============================================================
# Whether DataParallel is actually in use. Wrapping happens only when
# args.use_multi_gpu is set AND more than one GPU is visible, so the flag alone
# can overstate it.
multi_gpu_active = isinstance(model, nn.DataParallel)

wandb_config = {
    "architecture": "TranSTR",
    "dataset": "CausalVidQA",
    "epochs": args.epoch,
    "batch_size": args.bs,
    "learning_rate": args.lr,
    "text_encoder_lr": args.text_encoder_lr,
    "text_encoder": args.text_encoder_type,
    "d_model": args.d_model,
    "topK_frame": args.topK_frame,
    "topK_obj": args.topK_obj,
    "n_objects": args.objs,
    "num_encoder_layers": args.num_encoder_layers,
    "num_decoder_layers": args.num_decoder_layers,
    "multi_gpu": args.use_multi_gpu,          # requested
    "multi_gpu_active": multi_gpu_active,     # actually used
    "n_gpus": torch.cuda.device_count(),
    "train_samples": len(train_dataset),
    "val_samples": len(val_dataset),
    "test_samples": len(test_dataset),
}

run = wandb.init(
    project=args.project_name,
    name=args.run_name,
    config=wandb_config,
    # Tag runs by what actually ran, not by what was requested.
    tags=["causalvid", "multi-gpu" if multi_gpu_active else "single-gpu"],
)

# Log dataset info (datasets without a `vids` attribute report 0)
wandb.log({
    "data/train_videos": len(getattr(train_dataset, 'vids', ())),
    "data/val_videos": len(getattr(val_dataset, 'vids', ())),
    "data/test_videos": len(getattr(test_dataset, 'vids', ())),
})

print(f"‚úÖ W&B initialized: {run.url}")

‚úÖ W&B initialized: https://wandb.ai/introSE/CausalVidQA-TranSTR/runs/okts7dm0


In [18]:
# ============================================================
# TRAINING LOOP WITH EARLY STOPPING
# ============================================================
def _unwrapped_state_dict(net):
    """Return ``net``'s state dict without the DataParallel ``module.`` key prefix."""
    return (net.module if hasattr(net, 'module') else net).state_dict()


# Start below any achievable accuracy so the first epoch always writes
# best_model_path; later cells load it unconditionally. With 0.0, a run whose
# val accuracy never exceeds 0% would leave no file to load.
best_val_acc = float('-inf')
best_epoch = 1
best_model_path = f'./models/best_model-{args.run_name}.ckpt'
history = {'train': [], 'val': [], 'test': []}
epochs_without_improvement = 0  # Early stopping counter
stopped_early = False

print("=" * 70)
print(f"üöÄ STARTING TRAINING: {args.run_name}")
print(f"   Epochs: {args.epoch} | Batch size: {args.bs} | GPUs: {torch.cuda.device_count()}")
print(f"   Early stopping: {args.early_stopping_patience} epochs without improvement")
print("=" * 70)

for epoch in range(1, args.epoch + 1):
    print(f"\n{'='*70}")
    print(f"üìö EPOCH [{epoch}/{args.epoch}]")
    print(f"{'='*70}")

    # ============ TRAIN ============
    train_metrics = train_epoch(model, optimizer, train_loader, criterion, device, epoch, run)

    # ============ EVALUATE ============
    # Model selection uses val only; test accuracy is tracked for monitoring.
    val_metrics = evaluate(model, val_loader, device, 'val')
    test_metrics = evaluate(model, test_loader, device, 'test')

    # ============ UPDATE SCHEDULER ============
    scheduler.step(val_metrics['acc'])
    current_lr = optimizer.param_groups[0]['lr']

    # ============ SAVE BEST MODEL & EARLY STOPPING ============
    is_best = val_metrics['acc'] > best_val_acc
    if is_best:
        best_val_acc = val_metrics['acc']
        best_epoch = epoch
        epochs_without_improvement = 0  # Reset counter
        torch.save(_unwrapped_state_dict(model), best_model_path)
    else:
        epochs_without_improvement += 1

    # ============ LOGGING ============
    print(f"\n  üìä Results:")
    print(f"     {'Metric':<15} {'Train':>10} {'Val':>10} {'Test':>10}")
    print(f"     {'-'*15} {'-'*10} {'-'*10} {'-'*10}")
    print(f"     {'Loss':<15} {train_metrics['loss']:>10.4f} {'-':>10} {'-':>10}")
    print(f"     {'Accuracy':<15} {train_metrics['acc']:>9.2f}% {val_metrics['acc']:>9.2f}% {test_metrics['acc']:>9.2f}%")

    # Per question type accuracy
    print(f"\n  üìà Per Question Type Accuracy (Val):")
    qtype_order = ['Des', 'Exp', 'Pred-A', 'Pred-R', 'CF-A', 'CF-R']
    for qt in qtype_order:
        if qt in val_metrics['qtype_acc']:
            print(f"     {qt:<10}: {val_metrics['qtype_acc'][qt]:>6.2f}%")

    print(f"\n  ‚è±Ô∏è  Time: {train_metrics['time']:.1f}s | LR: {current_lr:.2e}")
    print(f"  üìâ No improvement: {epochs_without_improvement}/{args.early_stopping_patience} epochs")
    if is_best:
        print(f"  üíæ Saved best model! (Val acc: {best_val_acc:.2f}%)")

    # ============ WANDB LOGGING ============
    wandb_log = {
        "epoch": epoch,
        "train/loss": train_metrics['loss'],
        "train/acc": train_metrics['acc'],
        "val/acc": val_metrics['acc'],
        "test/acc": test_metrics['acc'],
        "lr": current_lr,
        "epoch_time": train_metrics['time'],
        "best_val_acc": best_val_acc,
        "epochs_without_improvement": epochs_without_improvement,
    }

    # Log per question type accuracy
    for split, metrics in (("train", train_metrics), ("val", val_metrics), ("test", test_metrics)):
        for qt, acc in metrics['qtype_acc'].items():
            wandb_log[f"{split}/acc_{qt}"] = acc

    wandb.log(wandb_log)

    # Save a resumable checkpoint every N epochs (save_every == 0 disables it
    # instead of raising ZeroDivisionError). The scheduler state is included so
    # ReduceLROnPlateau's patience/best tracking survives a resume.
    if args.save_every and epoch % args.save_every == 0:
        ckpt_path = f'./models/checkpoint-{args.run_name}-ep{epoch}.ckpt'
        torch.save({
            'epoch': epoch,
            'model_state_dict': _unwrapped_state_dict(model),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_metrics['acc'],
            'best_val_acc': best_val_acc,
        }, ckpt_path)
        print(f"  üìÅ Checkpoint saved: {ckpt_path}")

    # ============ EARLY STOPPING CHECK ============
    if epochs_without_improvement >= args.early_stopping_patience:
        print(f"\n  ‚ö†Ô∏è EARLY STOPPING: No improvement for {args.early_stopping_patience} epochs")
        print(f"     Best val acc: {best_val_acc:.2f}% at epoch {best_epoch}")
        wandb.log({"early_stopped": True, "stopped_at_epoch": epoch})
        stopped_early = True
        break

# ============================================================
# TRAINING COMPLETE
# ============================================================
print("\n" + "=" * 70)
print("‚úÖ TRAINING COMPLETED!")
print("=" * 70)
print(f"   Best epoch: {best_epoch}")
print(f"   Best val accuracy: {best_val_acc:.2f}%")
print(f"   Model saved: {best_model_path}")
if stopped_early:
    print(f"   Stopped early at epoch {epoch}")
print("=" * 70)

üöÄ STARTING TRAINING: causalvid_2gpu
   Epochs: 20 | Batch size: 2 | GPUs: 2
   Early stopping: 5 epochs without improvement

üìö EPOCH [1/20]


RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/parallel/parallel_apply.py", line 96, in _worker
    output = module(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kaggle/working/tranSTR_Casual/networks/model.py", line 102, in forward
    frame_local, frame_att = self.frame_decoder(frame_feat,
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kaggle/working/tranSTR_Casual/networks/multimodal_transformer.py", line 59, in forward
    output, c_att = layer(output, memory, tgt_mask=tgt_mask,
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kaggle/working/tranSTR_Casual/networks/multimodal_transformer.py", line 158, in forward
    tgt2, c_att = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/kaggle/working/tranSTR_Casual/networks/attention.py", line 66, in forward
    key_padding_mask = key_padding_mask.reshape(bs, k_length)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 18]' is invalid for input of size 36


In [None]:
# ============================================================
# FINAL EVALUATION WITH BEST MODEL
# ============================================================
print("=" * 70)
print("üìä FINAL EVALUATION")
print("=" * 70)

# Load best model into the unwrapped module: the checkpoint was saved from
# `model.module` when DataParallel is active, so its keys carry no "module." prefix.
print("\n  Loading best model...")
base_model = model.module if hasattr(model, 'module') else model
# weights_only: the file is a plain tensor state dict, so the restricted loader
# (no arbitrary unpickling) suffices; map_location restores the tensors onto
# `device` instead of the CUDA device they were saved from.
best_state = torch.load(best_model_path, map_location=device, weights_only=True)
base_model.load_state_dict(best_state)

# Predict on test set
result_path = f'./prediction/{args.run_name}-ep{best_epoch}-val{best_val_acc:.2f}.json'
results, test_acc = predict_and_save(model, test_loader, device, result_path)

print(f"\n  Test accuracy: {test_acc:.2f}%")
print(f"  Results saved: {result_path}")

# Detailed evaluation by question type
print("\n" + "-" * 70)
print("  Detailed Results by Question Type:")
print("-" * 70)
eval_mc.accuracy_metric_cvid(result_path)

# Log final results to wandb
wandb.log({
    "final/test_acc": test_acc,
    "final/best_epoch": best_epoch,
    "final/best_val_acc": best_val_acc,
})

# Save results artifact
artifact = wandb.Artifact(f'predictions-{args.run_name}', type='predictions')
artifact.add_file(result_path)
wandb.log_artifact(artifact)

In [None]:
# ============================================================
# SAVE TO KAGGLE OUTPUT & FINISH WANDB
# ============================================================
import shutil

# Mirror the best checkpoint and the prediction JSON into the Kaggle output
# directory when it exists; otherwise they stay where training wrote them.
output_dir = '/kaggle/working'
on_kaggle = os.path.exists(output_dir)

if not on_kaggle:
    print("  Not running on Kaggle, files saved locally")
else:
    kaggle_ckpt_path = os.path.join(output_dir, f'best_model-{args.run_name}.ckpt')
    shutil.copy(best_model_path, kaggle_ckpt_path)
    shutil.copy(result_path, output_dir)
    print(f"‚úÖ Files saved to {output_dir}")

# Upload the best checkpoint as a W&B model artifact, then close the run.
model_artifact = wandb.Artifact(f'model-{args.run_name}', type='model')
model_artifact.add_file(best_model_path)
wandb.log_artifact(model_artifact)

wandb.finish()
print("‚úÖ W&B run finished")

In [None]:
# Load best model. The checkpoint is an unwrapped state dict (no "module." key
# prefix), so it must go into the underlying module when DataParallel is used;
# loading it into the wrapper fails with missing/unexpected keys.
print("Loading best model...")
base_model = model.module if hasattr(model, 'module') else model
base_model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))

# Predict on test set. Other cells name outputs after args.run_name; use it
# when args has no `v` attribute (NOTE(review): confirm `v` is still defined).
run_tag = getattr(args, 'v', args.run_name)
result_path = f'./prediction/{run_tag}-{best_epoch}-{best_val_acc:.2f}.json'
results, test_acc = predict_and_save(model, test_loader, device, result_path)

print(f"\nüìä Final Test Results:")
print(f"   Overall Accuracy: {test_acc:.2f}%")
print(f"   Results saved to: {result_path}")

In [None]:
# # Quick test with 10 videos
# args.max_samples = 10
# args.bs = 4
# args.epoch = 2
# 
# # Re-create datasets with max_samples
# train_dataset = VideoQADataset(
#     split='train', n_query=args.n_query, obj_num=args.objs,
#     sample_list_path=args.sample_list_path,
#     video_feature_path=args.video_feature_path,
#     text_annotation_path=args.text_annotation_path,
#     qtype=args.qtype, max_samples=args.max_samples
# )
# print(f"Quick test dataset: {len(train_dataset)} samples")

### Evaluating pretrained model B2A

### Training