# TranSTR CausalVid - Final Notebook

Notebook hoàn chỉnh cho **training và evaluation** model TranSTR trên CausalVidQA.

**Tính năng:**
- DeBERTa encode text **real-time** (không cần pre-extraction)
- Download model từ HuggingFace
- Detailed evaluation theo paper metrics (D, E, PAR, CAR, Acc(ALL))

In [None]:
import os
import subprocess

# --- Git Clone & Setup ---
REPO_URL = "https://github.com/DanielQH07/tranSTR_Casual.git"
REPO_NAME = "tranSTR_Casual"
# NOTE(review): "origin" is normally the name of a git *remote*, not a branch;
# `git clone -b origin` only succeeds if the repo really has a branch named "origin".
# Confirm the intended branch (e.g. "main").
BRANCH = "origin"


def _find_repo_root(start, name):
    """Return the nearest directory at or above `start` whose basename is `name`, else None."""
    path = os.path.abspath(start)
    while True:
        if os.path.basename(path) == name:
            return path
        parent = os.path.dirname(path)
        if parent == path:  # reached the filesystem root
            return None
        path = parent


# When this cell is re-run after it already chdir'd into the repo, the repo must be
# found among the cwd's ancestors; otherwise a second copy would be cloned inside it.
repo_root = _find_repo_root(os.getcwd(), REPO_NAME) or os.path.abspath(REPO_NAME)

if not os.path.exists(repo_root):
    print(f"Cloning {REPO_URL}...")
    # check=True: fail here with the git error instead of later with import errors.
    subprocess.run(["git", "clone", REPO_URL, "-b", BRANCH, repo_root], check=True)
else:
    print("Repo already exists.")

# Work from <repo>/causalvid when it exists (presumably where the project modules
# imported in later cells live), otherwise from the repo root.
try:
    target_dir = os.path.join(repo_root, "causalvid")
    os.chdir(target_dir if os.path.isdir(target_dir) else repo_root)
    print(f"Changed directory to: {os.getcwd()}")
except OSError as e:
    print(f"Could not set working directory: {e}")

In [None]:
# CELL 2: HuggingFace
print('=== CELL 2 ===')
import os
import subprocess
import sys

# Install into the interpreter running this kernel (same effect as `!pip install -q ...`).
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "huggingface_hub"], check=True)
from huggingface_hub import login, HfApi, hf_hub_download, list_repo_tree

# Read the token from the environment rather than hard-coding a secret in the notebook
# (e.g. expose a Kaggle secret as HF_TOKEN). Without a token, public repos can still be
# downloaded, but the upload in the training cell will most likely fail.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    print("HF_TOKEN not set - skipping HuggingFace login (public downloads only).")

In [None]:
# CELL 3: Imports
# Project-local modules (utils, DataLoader, networks) are presumably resolved from the
# working directory chosen in cell 1 - run that cell first.
print('=== CELL 3: Imports ===')
import os, torch, numpy as np, pandas as pd, tarfile, shutil, json
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader  # torch's DataLoader class (used in cell 7)
from utils.util import set_seed, set_gpu_devices
# Only VideoQADataset is bound here, so the name `DataLoader` above still refers to
# torch's class even though the project module is also called DataLoader.
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
print('Imports OK')

In [None]:
# CELL 4: Train/Eval functions (No AMP - Safe for DeBERTa)
print('=== CELL 4 ===')

def train_epoch(model, optimizer, loader, xe, device, scaler):
    """Run one full-precision training epoch.

    Args:
        model: module called as ``model(ff, of, q, a)`` that returns logits over
            the answer candidates.
        optimizer: optimizer over ``model``'s parameters.
        loader: iterable of batches ``(ff, of, q, a, ans_id, keys)``.
        xe: loss applied as ``xe(logits, ans_id)``.
        device: device the feature tensors and targets are moved to.
        scaler: unused; kept so existing calls keep working (AMP is disabled
            because DeBERTa has issues with fp16).

    Returns:
        ``(mean_batch_loss, accuracy_percent)``. Both are 0.0 when ``loader``
        yields no batches, instead of raising ZeroDivisionError.
    """
    model.train()
    total_loss, correct, total, n_batches = 0.0, 0, 0, 0
    for ff, of, q, a, ans_id, _ in loader:
        ff, of, tgt = ff.to(device), of.to(device), ans_id.to(device)

        # No autocast - DeBERTa has issues with fp16
        out = model(ff, of, q, a)
        loss = xe(out, tgt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Accuracy uses the logits computed before this step's update.
        correct += (out.argmax(-1) == tgt).sum().item()
        total += tgt.size(0)
        n_batches += 1

    mean_loss = total_loss / n_batches if n_batches else 0.0
    accuracy = correct / total * 100 if total else 0.0
    return mean_loss, accuracy

def eval_epoch(model, loader, device):
    """Return the accuracy (in percent) of ``model`` on ``loader`` without training.

    Batches are ``(ff, of, q, a, ans_id, keys)``. An empty loader yields 0.0
    instead of raising ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for ff, of, q, a, ans_id, _ in loader:
            tgt = ans_id.to(device)
            out = model(ff.to(device), of.to(device), q, a)
            correct += (out.argmax(-1) == tgt).sum().item()
            total += tgt.size(0)
    return correct / total * 100 if total else 0.0

print('Functions defined (No AMP - DeBERTa safe)')


In [None]:
# CELL 5 + 6: Setup Paths & Config
print('=== CELL 5+6: Paths & Config ===')

# ============================================
# KAGGLE INPUT PATHS - UPDATE THESE!
# ============================================
# Each constant below points at an attached Kaggle input dataset. They are checked
# with verify_path() further down and copied into Config as its data paths.
# ViT video features (folder contains video_id.pt files directly)
VIT_FEATURE_PATH = '/kaggle/input/vit-features-full-merged'  # Contains: video_id.pt files

# Object detection features (direct read from Kaggle)
OBJ_FEATURE_PATH = '/kaggle/input/object-detection-causal-full'  # Contains: features_node_X/video.pkl

# Annotations (folder contains video_id subfolders with text.json, answer.json)
ANNOTATION_PATH = '/kaggle/input/text-annotation/QA'  # Contains: video_id/text.json, answer.json

# Split files (train.pkl, valid.pkl, test.pkl)
SPLIT_DIR = '/kaggle/input/casual-vid-data-split/split'  # Contains: train.pkl, valid.pkl, test.pkl

# ============================================
# WORKING DIRECTORIES
# ============================================
BASE = '/kaggle/working'
MODEL_DIR = os.path.join(BASE, 'models')  # checkpoints are saved to / downloaded into here
os.makedirs(MODEL_DIR, exist_ok=True)  # exist_ok: safe to re-run the cell

# ============================================
# VERIFY PATHS
# ============================================
print('\n--- Path Verification ---')

def verify_path(name, path, expected_sample=None):
    """Print a short report on ``path`` and return whether it is a usable directory.

    Args:
        name: human-readable label used in the report.
        path: directory that is expected to exist.
        expected_sample: optional entry name (e.g. ``'train.pkl'``) that must exist
            directly inside ``path``; ignored when None.

    Returns:
        True if ``path`` is a directory (containing ``expected_sample`` when given),
        otherwise False.
    """
    if not os.path.exists(path):
        print(f'‚ùå {name}: NOT FOUND')
        print(f'   Path: {path}')
        return False
    if not os.path.isdir(path):
        # os.listdir would raise NotADirectoryError on a file; report it instead.
        print(f'‚ùå {name}: NOT A DIRECTORY')
        print(f'   Path: {path}')
        return False

    print(f'‚úÖ {name}')
    print(f'   Path: {path}')
    # Sorted so the sample shown is the same on every run.
    print(f'   Sample: {sorted(os.listdir(path))[:5]}')
    if expected_sample is not None and not os.path.exists(os.path.join(path, expected_sample)):
        print(f'   Missing expected entry: {expected_sample}')
        return False
    return True

# Check every input path. A missing path only prints a warning, so the constants above
# can be fixed and the cell re-run; later cells will presumably fail until then.
all_ok = True
all_ok &= verify_path('ViT Features', VIT_FEATURE_PATH)
all_ok &= verify_path('Object Features', OBJ_FEATURE_PATH)
all_ok &= verify_path('Annotations', ANNOTATION_PATH)
all_ok &= verify_path('Splits', SPLIT_DIR)

if not all_ok:
    print('\n‚ö†Ô∏è  Please update paths above!')

# ============================================
# CONFIG
# ============================================
RUN_TRAINING = True  # False: skip the training cell; evaluation then uses a local or downloaded checkpoint
HF_REPO_ID = 'DanielQ07/transtr-causalvid-weights'  # HuggingFace model repo used for upload/download
HF_MODEL_FILENAME = 'best_model.ckpt'  # checkpoint file name, locally and in the repo

class Config:
    """Data paths and hyper-parameters for this notebook.

    Everything is a class attribute: ``args = Config()`` reads them as ``args.x``,
    and the model cell forwards every public attribute to ``VideoQAmodel(**cfg)``.
    """
    # Paths from Kaggle input
    video_feature_root = VIT_FEATURE_PATH   # video_id.pt files directly
    object_feature_path = OBJ_FEATURE_PATH  # features_node_X/video.pkl
    sample_list_path = ANNOTATION_PATH      # video_id/text.json, answer.json
    split_dir_txt = SPLIT_DIR               # train.pkl, valid.pkl, test.pkl
    
    # Model architecture (paper config)
    topK_frame = 16  # frames per video loaded by the datasets; the model cell overrides the model's topK_frame with select_frames
    objs = 20  # passed to the datasets as obj_num
    frames = 16
    select_frames = 5  # becomes the model's topK_frame
    topK_obj = 12
    frame_feat_dim = 1024  # presumably the width of the ViT features in VIT_FEATURE_PATH - TODO confirm
    obj_feat_dim = 2053  # matches the last dim expected for object features in the dataset sanity check
    d_model = 768  # 768 = hidden size of deberta-base
    word_dim = 768
    nheads = 8
    num_encoder_layers = 2
    num_decoder_layers = 2
    normalize_before = True
    activation = 'gelu'
    dropout = 0.3
    encoder_dropout = 0.3
    
    # Text encoder
    text_encoder_type = 'microsoft/deberta-base'
    freeze_text_encoder = False
    # NOTE(review): the optimizer in the model cell uses a single lr (args.lr);
    # this value only reaches the model through cfg - confirm whether it is used there.
    text_encoder_lr = 1e-5
    text_pool_mode = 1
    
    # Training
    bs = 8  # DataLoader batch size
    lr = 1e-5  # Adam learning rate
    epoch = 20  # number of training epochs
    gpu = 0  # passed to set_gpu_devices
    patience = 5  # ReduceLROnPlateau patience (epochs without val-accuracy gain)
    gamma = 0.1  # ReduceLROnPlateau decay factor
    decay = 1e-4  # NOTE(review): not passed as weight_decay to the Adam optimizer in the model cell
    n_query = 5  # passed to the datasets as n_query
    
    # Other
    hard_eval = False
    pos_ratio = 1.0
    neg_ratio = 1.0
    a = 1.0
    # NOTE(review): train_epoch runs without autocast and never uses the GradScaler,
    # so this flag does not enable AMP in the training loop (it is still forwarded to the model).
    use_amp = True
    num_workers = 4  # DataLoader worker processes

# Instantiate the config (attributes live on the class, so `args` just exposes them),
# select GPU `args.gpu` via utils.util.set_gpu_devices, fix the random seed and pick
# the torch device (CPU fallback when CUDA is unavailable).
args = Config()
set_gpu_devices(args.gpu)
set_seed(999)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\nDevice: {device}')
print('Config loaded!')


In [None]:
# CELL 7: Create Datasets with Verification
print('=== CELL 7: Datasets ===')

# Upper bound on TRAIN samples (None = whole split). Val/test are never capped.
MAX_TRAIN_SAMPLES = 2000


def _build_split(split_name, cap):
    """Create the VideoQADataset for `split_name` (verbose logging), capped at `cap` samples."""
    return VideoQADataset(
        split=split_name,
        n_query=args.n_query,
        obj_num=args.objs,
        sample_list_path=args.sample_list_path,
        video_feature_path=args.video_feature_root,
        object_feature_path=args.object_feature_path,
        split_dir=args.split_dir_txt,
        topK_frame=args.topK_frame,
        max_samples=cap,
        verbose=True,
    )


def _wrap_loader(dataset, shuffle):
    """Batch `dataset` with the shared batch-size / worker settings from Config."""
    return DataLoader(dataset, args.bs, shuffle=shuffle,
                      num_workers=args.num_workers, pin_memory=True)


# Each header is printed before the dataset is built so its verbose log sits under it.
print('\n--- Creating TRAIN dataset ---')
train_ds = _build_split('train', MAX_TRAIN_SAMPLES)

print('\n--- Creating VAL dataset ---')
val_ds = _build_split('val', None)

print('\n--- Creating TEST dataset ---')
test_ds = _build_split('test', None)

# Only the training loader is shuffled.
train_loader = _wrap_loader(train_ds, shuffle=True)
val_loader = _wrap_loader(val_ds, shuffle=False)
test_loader = _wrap_loader(test_ds, shuffle=False)

divider = '='*60
print('\n' + divider)
print('DATASET SUMMARY')
print(divider)
print(f'Train: {len(train_ds)} samples -> {len(train_loader)} batches')
print(f'Val:   {len(val_ds)} samples -> {len(val_loader)} batches')
print(f'Test:  {len(test_ds)} samples -> {len(test_loader)} batches')
print(divider)

# Pull one batch through the train loader to surface path / shape problems early.
if len(train_ds):
    print('\nSanity check: Loading first batch...')
    try:
        first_batch = next(iter(train_loader))
        ff, of, qns, ans, ans_id, keys = first_batch
        print(f'  ViT features: {ff.shape}')     # Expected: [batch, topK_frame, feat_dim]
        print(f'  Object features: {of.shape}')  # Expected: [batch, topK_frame, obj_num, 2053]
        print(f'  Answer IDs: {ans_id}')
        print('Sanity check PASSED!')
    except Exception as err:
        import traceback
        print(f'  ERROR: {err}')
        traceback.print_exc()


In [None]:
# CELL 8: Model
print('=== CELL 8: Model ===')
# Model kwargs = every public class attribute of Config (paths, architecture, training knobs).
cfg = {k: v for k, v in Config.__dict__.items() if not k.startswith('_')}
cfg['device'] = device
# The model's topK_frame is replaced by select_frames, while the datasets were built
# with args.topK_frame.
cfg['topK_frame'] = args.select_frames
model = VideoQAmodel(**cfg)
# NOTE(review): a single Adam parameter group at args.lr with no weight decay;
# Config.text_encoder_lr and Config.decay are not applied here.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# 'max' mode: the scheduler is stepped with validation accuracy in the training cell.
scheduler = ReduceLROnPlateau(optimizer, 'max', factor=args.gamma, patience=args.patience)
model.to(device)
xe = nn.CrossEntropyLoss()
# Passed to train_epoch but not used there (training runs without autocast).
scaler = torch.amp.GradScaler('cuda', enabled=True)
save_path = os.path.join(MODEL_DIR, HF_MODEL_FILENAME)
print(f'Model: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params')

In [None]:
# CELL 9: Training
print('=== CELL 9: Training ===')
best_acc = 0

if RUN_TRAINING:
    for ep in range(1, args.epoch + 1):
        loss, acc = train_epoch(model, optimizer, train_loader, xe, device, scaler)
        val_acc = eval_epoch(model, val_loader, device)
        # The scheduler runs in 'max' mode: the LR decays when val accuracy plateaus.
        scheduler.step(val_acc)
        print(f'Ep {ep}: Loss={loss:.4f}, Train={acc:.1f}%, Val={val_acc:.1f}%')
        # Keep only the best-on-validation weights on disk.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print('  Saved!')
    print(f'\nBest Val: {best_acc:.1f}%')

    # A checkpoint is only written when val accuracy rises above 0, so it may not exist.
    if not os.path.exists(save_path):
        print(f'No checkpoint at {save_path}; skipping upload.')
    else:
        try:
            api = HfApi()
            api.create_repo(repo_id=HF_REPO_ID, repo_type='model', exist_ok=True)
            api.upload_file(path_or_fileobj=save_path, path_in_repo=HF_MODEL_FILENAME, repo_id=HF_REPO_ID, repo_type='model')
            print('Uploaded!')
        except Exception as e:
            # Best effort: a failed upload (e.g. not logged in) must not lose the local checkpoint.
            print(f'Upload failed: {e}')
else:
    print('Skipping training (RUN_TRAINING=False)')

In [None]:
# CELL 10: Detailed Test Evaluation (Download from HuggingFace)
print('=== CELL 10: TEST Set Evaluation ===')
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from tqdm.auto import tqdm
from huggingface_hub import hf_hub_download
from networks.model import VideoQAmodel

# --- CONFIG ---
HF_REPO_ID = 'DanielQ07/transtr-causalvid-weights'
HF_FILENAME = 'best_model.ckpt'
LOCAL_MODEL_PATH = os.path.join(MODEL_DIR, HF_FILENAME)

# 1. Download the checkpoint from HuggingFace unless it already exists locally
if not os.path.exists(LOCAL_MODEL_PATH):
    print(f"\nüì• Downloading {HF_FILENAME} from HuggingFace ({HF_REPO_ID})...")
    try:
        model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_FILENAME, local_dir=MODEL_DIR)
        print(f"‚úÖ Downloaded to: {model_path}")
        LOCAL_MODEL_PATH = model_path
    except Exception as e:
        print(f"‚ùå Failed to download model: {e}")
        print("üí° Make sure Internet is ON and Repo ID is correct.")
else:
    print(f"üìÇ Found local model at: {LOCAL_MODEL_PATH}")

# 2. Load the weights into the existing `model`
if os.path.exists(LOCAL_MODEL_PATH):
    print(f"üîß Loading weights...")
    state = torch.load(LOCAL_MODEL_PATH, map_location=device, weights_only=True)
    # strict=False never raises on key mismatches: unmatched model tensors simply keep
    # their current (possibly random) values, so report mismatches explicitly.
    msg = model.load_state_dict(state, strict=False)
    print(f"Load status: {msg}")
    if msg.missing_keys:
        print(f"WARNING: {len(msg.missing_keys)} model keys were not in the checkpoint "
              f"and keep their current values, e.g. {msg.missing_keys[:5]}")
    if msg.unexpected_keys:
        print(f"WARNING: {len(msg.unexpected_keys)} checkpoint keys were ignored, "
              f"e.g. {msg.unexpected_keys[:5]}")
else:
    print("‚ö†Ô∏è CRITICAL: No model weights found! Using Random Weights.")

# DEBUG: Check dataset sample_list to see why answer = -1
print("\nüîç DEBUG - Checking Dataset sample_list:")


def _debug_sample_list(ds, annotation_root):
    """Print the first row of ``ds.sample_list``, its answer distribution and raw answer.json.

    ``ds.sample_list`` is used like a pandas DataFrame with 'video_id', 'type',
    'answer' and 'question' columns. An empty list is reported instead of raising
    IndexError on ``.iloc[0]``.
    """
    import json

    sample_list = ds.sample_list
    if len(sample_list) == 0:
        print("  sample_list is empty - nothing to inspect.")
        return

    sample = sample_list.iloc[0]
    print(f"  First sample in sample_list:")
    print(f"    video_id: {sample['video_id']}")
    print(f"    type: {sample['type']}")
    print(f"    answer: {sample['answer']} (Type: {type(sample['answer'])})")
    print(f"    question: {sample['question'][:50]}...")

    # Answer value distribution; -1 entries mark answers without a ground-truth label.
    print(f"\n  Answer value distribution:")
    print(sample_list['answer'].value_counts().head(10))

    # Compare against the raw answer.json for the same video.
    sample_vid = sample['video_id']
    answer_json_path = os.path.join(annotation_root, str(sample_vid), "answer.json")
    if os.path.exists(answer_json_path):
        with open(answer_json_path, 'r', encoding='utf-8') as f:
            raw_answer = json.load(f)
        print(f"\n  Raw answer.json for {sample_vid}:")
        print(f"    {json.dumps(raw_answer, indent=2)[:500]}...")
    else:
        print(f"\n  ‚ùå answer.json not found at: {answer_json_path}")


if 'test_ds' in globals():
    _debug_sample_list(test_ds, args.sample_list_path)

# 3. Detailed evaluation - metrics are computed DIRECTLY from each batch
_KNOWN_QTYPES = (
    'predictive_reason', 'counterfactual_reason',
    'descriptive', 'explanatory', 'predictive', 'counterfactual',
)


def _parse_qns_key(key):
    """Split a question key ``'<video_id>_<qtype>'`` into ``(video_id, qtype)``.

    Known CausalVidQA types are matched as a suffix, so underscores inside the
    video id are preserved. Unknown ``*_reason`` keys split on the last two
    underscores, other unknown keys on the last one; a key without an underscore
    yields ``(key, 'unknown')``.
    """
    for qtype in _KNOWN_QTYPES:
        if key.endswith('_' + qtype):
            return key[:-(len(qtype) + 1)], qtype
    if key.endswith('_reason'):
        parts = key.rsplit('_', 2)
        video_id = parts[0] if len(parts) > 2 else key
        qtype = '_'.join(parts[1:]) if len(parts) > 1 else 'unknown'
        return video_id, qtype
    video_id, sep, qtype = key.rpartition('_')
    return (video_id, qtype) if sep else (key, 'unknown')


def _paired_accuracy(type_results, type_ans, type_reason):
    """Return ``(both_correct, n_paired)`` over videos having both an answer and a reason result."""
    if type_ans not in type_results or type_reason not in type_results:
        return 0, 0
    ans_by_vid = {r['video_id']: r['correct'] for r in type_results[type_ans]}
    reason_by_vid = {r['video_id']: r['correct'] for r in type_results[type_reason]}
    common_vids = ans_by_vid.keys() & reason_by_vid.keys()
    both_correct = sum(1 for vid in common_vids if ans_by_vid[vid] and reason_by_vid[vid])
    return both_correct, len(common_vids)


def evaluate_detailed_v2(model, loader, device):
    """Evaluate ``model`` per question type and print the CausalVidQA metrics.

    Metrics are computed straight from each batch's question keys, with no join
    against the original dataframe. Samples whose target is negative (e.g. -1 on an
    unlabeled test split) are not valid class indices; they are skipped and
    counted instead of being scored as wrong.

    Args:
        model: module called as ``model(ff, of, qns, ans_word)`` returning logits.
        loader: iterable of batches ``(ff, of, qns, ans_word, ans_id, qns_keys)``.
        device: device the feature tensors are moved to.

    Returns:
        ``(metrics, type_results)``. ``metrics`` maps 'Description', 'Explanation',
        'Predictive-*', 'Counterfactual-*', 'PAR', 'CAR' and 'Acc (ALL)' to accuracy
        in percent. ``type_results`` maps each question type to a list of
        ``{'video_id', 'pred', 'target', 'correct'}`` records for labeled samples.
    """
    model.eval()

    type_results = {}  # {qtype: [{'video_id', 'pred', 'target', 'correct'}, ...]}
    unlabeled = {}     # {qtype: number of samples with a negative (missing) target}

    print("\nüìä Running Detailed Evaluation...")
    with torch.no_grad():
        for ff, of, qns, ans_word, ans_id, qns_keys in tqdm(loader):
            out = model(ff.to(device), of.to(device), qns, ans_word)
            preds = out.argmax(dim=-1).cpu().numpy()
            targets = ans_id.numpy()

            for key, pred, target in zip(qns_keys, preds, targets):
                video_id, qtype = _parse_qns_key(key)
                pred, target = int(pred), int(target)
                if target < 0:
                    unlabeled[qtype] = unlabeled.get(qtype, 0) + 1
                    continue
                type_results.setdefault(qtype, []).append({
                    'video_id': video_id,
                    'pred': pred,
                    'target': target,
                    'correct': pred == target,
                })

    if unlabeled:
        print(f"\nWARNING: skipped {sum(unlabeled.values())} unlabeled samples "
              f"(target < 0), by type: {unlabeled}")

    # DEBUG: show the first predictions of each type
    print("\nüîç DEBUG - Sample Predictions vs Targets:")
    for qtype, results in type_results.items():
        if not results:
            continue
        first = results[0]
        correct_count = sum(1 for r in results if r['correct'])
        print(f"  [{qtype}] Count: {len(results)}, Correct: {correct_count}")
        print(f"    First: pred={first['pred']}, target={first['target']}, match={first['correct']}")
        if len(results) > 1:
            second = results[1]
            print(f"    Second: pred={second['pred']}, target={second['target']}, match={second['correct']}")

    metrics = {}
    metrics_map = {
        'Description': 'descriptive',
        'Explanation': 'explanatory',
        'Predictive-Answer': 'predictive',
        'Predictive-Reason': 'predictive_reason',
        'Counterfactual-Answer': 'counterfactual',
        'Counterfactual-Reason': 'counterfactual_reason'
    }

    print("\n" + "="*60)
    print("EVALUATION RESULTS - TEST SET")
    print("="*60)

    # Per-type accuracy; a type with no labeled samples reports 0.
    for name, qtype in metrics_map.items():
        results = type_results.get(qtype, [])
        correct = sum(1 for r in results if r['correct'])
        total = len(results)
        acc = correct / total * 100 if total > 0 else 0
        metrics[name] = acc
        print(f"{name:<25} ==>   {acc:.2f}%  ({correct}/{total})")

    # Hard metrics (AND logic): both the answer and the reason must be correct
    # for the same video.
    print("-" * 60)
    for name, type_ans, type_reason in (('PAR', 'predictive', 'predictive_reason'),
                                        ('CAR', 'counterfactual', 'counterfactual_reason')):
        both_correct, total = _paired_accuracy(type_results, type_ans, type_reason)
        acc = both_correct / total * 100 if total > 0 else 0
        metrics[name] = acc
        print(f"{name:<25} ==>   {acc:.2f}%  ({both_correct}/{total} paired)")
    print("-" * 60)

    # Acc (ALL) = (D + E + PAR + CAR) / 4 (paper definition)
    acc_all = sum(metrics.get(k, 0) for k in ('Description', 'Explanation', 'PAR', 'CAR')) / 4
    metrics['Acc (ALL)'] = acc_all
    print(f"{'Acc (ALL)':<25} ==>   {acc_all:.2f}%  ((D+E+PAR+CAR)/4)")
    print("="*60)

    plot_metrics(metrics)
    return metrics, type_results

def plot_metrics(metrics, title='VideoQA Performance on Test Set', save_path='test_results.png'):
    """Bar-plot the headline metrics (D, E, PAR, CAR, Acc (ALL)) and show the figure.

    Args:
        metrics: mapping from metric name to accuracy in percent; missing names plot as 0.
        title: plot title; override it when plotting non-test (e.g. validation) results.
        save_path: file the figure is saved to before showing; None skips saving.
    """
    keys = ['Description', 'Explanation', 'PAR', 'CAR', 'Acc (ALL)']
    values = [metrics.get(k, 0) for k in keys]

    plt.figure(figsize=(10, 6))
    bars = plt.bar(keys, values, color=sns.color_palette("viridis", len(keys)))
    plt.ylim(0, 100)
    plt.ylabel('Accuracy (%)')
    plt.title(title)
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    # Label each bar with its value just above the bar top.
    for bar in bars:
        plt.text(bar.get_x() + bar.get_width()/2., bar.get_height() + 1,
                f'{bar.get_height():.1f}%', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()

# --- EXECUTION ---
# NOTE: the TEST split may have no ground-truth labels (every answer is -1).
# In that case set USE_VAL_SET = True to evaluate on the VALIDATION split instead.
from itertools import islice

RUN_SMALL_TEST = False  # True: only the first 5 batches (quick smoke test); False: full run
USE_VAL_SET = False     # True: use the VAL split (has labels); False: use the TEST split

if USE_VAL_SET:
    print("\nüìå Using VALIDATION SET (has ground truth labels)")
else:
    print("\nüìå Using TEST SET (‚ö†Ô∏è may have -1 labels if held out)")

# Resolve the loader inside try: referencing an undefined val_loader/test_loader
# would otherwise raise NameError before the helpful message below is printed.
try:
    eval_loader = val_loader if USE_VAL_SET else test_loader
except NameError:
    eval_loader = None

if eval_loader is None:
    print("‚ö†Ô∏è 'val_loader'/'test_loader' not defined. Run previous cells to load data first.")
else:
    loader_to_run = eval_loader
    if RUN_SMALL_TEST:
        print("‚ö†Ô∏è RUNNING SMALL TEST MODE (5 batches only)")
        print("To run full evaluation, set RUN_SMALL_TEST = False")
        loader_to_run = list(islice(eval_loader, 5))

    metrics, raw_results = evaluate_detailed_v2(model, loader_to_run, device)