# TranSTR CausalVid - Paper Configuration

In [None]:
# CELL 1: Clone
# Fetch the project repository once and switch the working directory into it.
# Presumably this is what lets the project-local imports in CELL 3
# (utils, DataLoader, networks) resolve — confirm against the repository layout.
import os
print('=== CELL 1 ===')
# Skip the clone when a checkout already exists in the current directory.
if not os.path.exists('tranSTR_Casual'):
    # NOTE(review): `-b origin` asks git for a branch or tag literally named
    # "origin" — confirm that ref exists on the remote.
    !git clone https://github.com/DanielQH07/tranSTR_Casual.git -b origin
# Prefer the `causalvid` subfolder when the checkout has one, else the repo root.
# NOTE(review): both paths are relative, so re-running this cell after the chdir
# looks for 'tranSTR_Casual' inside the checkout itself; restart the kernel
# (or chdir back) before re-running.
os.chdir('tranSTR_Casual/causalvid' if os.path.exists('tranSTR_Casual/causalvid') else 'tranSTR_Casual')
print(f'CWD: {os.getcwd()}')

In [None]:
# CELL 2: HuggingFace
# Install the Hub client and authenticate. HfApi is used in CELL 9 to upload the
# checkpoint and hf_hub_download in CELL 10 to fetch it.
print('=== CELL 2 ===')
!pip install -q huggingface_hub
from huggingface_hub import login, HfApi, hf_hub_download, list_repo_tree
# notebook_login()
# NOTE(review): do not commit a real token in this cell; prefer an interactive
# login or a secret store when sharing the notebook.
login(token='YOUR_HF_TOKEN') # Replace with your actual token

In [None]:
# CELL 3: Imports
print('=== CELL 3: Imports ===')
import os, torch, numpy as np, pandas as pd, tarfile, shutil, json
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from utils.util import set_seed, set_gpu_devices
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
print('Imports OK')

In [None]:
# CELL 3.5: Pre-extract DeBERTa Text Features (Run ONCE before training)
print('=== Pre-extracting DeBERTa Text Features ===')
print('This will cache all text embeddings - makes training 5-10x faster!')

import pickle as pkl
from transformers import AutoTokenizer, AutoModel
from tqdm.auto import tqdm

# ============================================
# PATHS - UPDATE THESE
# ============================================
ANNOTATION_PATH = '/kaggle/input/YOUR_ANNOTATION_DATASET'  # Contains video_id/text.json, answer.json
SPLIT_DIR = '/kaggle/input/YOUR_SPLITS_DATASET'           # Contains train.pkl, valid.pkl, test.pkl
TEXT_FEATURE_PATH = '/kaggle/working/text_features'       # Output directory
MODEL_NAME = 'microsoft/deberta-base'

os.makedirs(TEXT_FEATURE_PATH, exist_ok=True)

# ============================================
# LOAD MODEL
# ============================================
# `device` is otherwise first assigned in the later "Paths & Config" cell, so a
# top-to-bottom run would raise NameError at `text_model.to(device)`. Resolve it
# here with the same expression that cell uses.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f'\nLoading {MODEL_NAME}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text_model = AutoModel.from_pretrained(MODEL_NAME)
text_model.to(device)
# Inference only: switch off dropout and other train-time behaviour.
text_model.eval()
print('Model loaded!')

# ============================================
# HELPER FUNCTIONS
# ============================================
def load_split_videos(split_name):
    """Return the set of video IDs listed in a split's pickle file.

    The 'val' split is read from ``valid.pkl``; any other split name maps to
    ``<split_name>.pkl`` inside ``SPLIT_DIR``. A list or set payload is
    converted to a set directly; any other payload is treated as a mapping and
    its keys are returned. A missing file yields an empty set.
    """
    file_stem = split_name if split_name != 'val' else 'valid'
    split_file = os.path.join(SPLIT_DIR, f"{file_stem}.pkl")
    if not os.path.exists(split_file):
        return set()

    with open(split_file, 'rb') as handle:
        payload = pkl.load(handle)

    if isinstance(payload, (list, set)):
        return set(payload)
    return set(payload.keys())

def _encode_choices(question, choices):
    """Encode one question paired with each candidate text and build its cache entry.

    Uses the module-level ``tokenizer``, ``text_model`` and ``device``. Returns a
    dict with the first-token hidden state of every candidate ('cls', float32
    numpy array with one row per candidate) and the question string ('question').
    """
    # NOTE(review): the tokenizer is called with its default special-token
    # handling, so the literal "[CLS]"/"[SEP]" text here may come on top of the
    # tokenizer's own markers. Left unchanged so cached features stay compatible
    # with whatever the model's raw-text path does — confirm.
    texts = [f"[CLS] {question} [SEP] {c}" for c in choices]
    tokenized = tokenizer(texts, padding=True, truncation=True,
                          max_length=256, return_tensors='pt').to(device)
    with torch.no_grad():
        out = text_model(**tokenized).last_hidden_state[:, 0, :]
    return {
        'cls': out.cpu().numpy().astype(np.float32),
        'question': question,
    }


def extract_for_split(split_name, max_videos=None):
    """Pre-compute text features for every question of a split and pickle them.

    For each video of the split that has both ``text.json`` and ``answer.json``
    under ``ANNOTATION_PATH``, every present question type is encoded together
    with its answer candidates (key ``<vid>_<type>``). Predictive and
    counterfactual questions that carry a "reason" list are additionally
    encoded against the fixed question "Why?" (key ``<vid>_<type>_reason``).

    Args:
        split_name: 'train', 'val' or 'test'; also used in the output file name.
        max_videos: optional cap on the number of videos (falsy means no cap).
            The cap is applied to the sorted ID list so that the chosen subset
            is the same on every run.

    Returns:
        The feature dict that was written to
        ``<TEXT_FEATURE_PATH>/<split_name>_text_features.pkl``.
    """
    print(f'\n--- {split_name.upper()} ---')

    # Sort before truncating: iterating a set of strings has no stable order, so
    # slicing it directly would pick a different subset on every run.
    video_ids = sorted(load_split_videos(split_name), key=str)
    if max_videos:
        video_ids = video_ids[:max_videos]
    print(f'Videos: {len(video_ids)}')

    features = {}
    failed = []  # (video id, error repr) for videos that raised mid-extraction

    for vid in tqdm(video_ids, desc=f'Extracting {split_name}'):
        vp = os.path.join(ANNOTATION_PATH, vid)
        tj, aj = os.path.join(vp, "text.json"), os.path.join(vp, "answer.json")
        # answer.json is only checked for presence here; its content is not read.
        if not (os.path.exists(tj) and os.path.exists(aj)):
            continue

        try:
            with open(tj, encoding="utf-8") as f:
                td = json.load(f)

            for k in ("descriptive", "explanatory", "predictive", "counterfactual"):
                if k not in td:
                    continue
                q = td[k]
                if "question" in q and "answer" in q:
                    features[f"{vid}_{k}"] = _encode_choices(q["question"], q["answer"])

                if k in ("predictive", "counterfactual") and "reason" in q:
                    features[f"{vid}_{k}_reason"] = _encode_choices("Why?", q["reason"])
        except Exception as e:
            # Best effort: keep going with the next video, but remember the
            # failure so it is reported instead of silently vanishing.
            failed.append((vid, repr(e)))

        if len(features) % 500 == 0:
            torch.cuda.empty_cache()

    if failed:
        print(f'WARNING: {len(failed)} video(s) raised errors during extraction; '
              f'features gathered before each error were kept.')
        for bad_vid, reason in failed[:5]:
            print(f'   {bad_vid}: {reason}')

    # Save
    output_file = os.path.join(TEXT_FEATURE_PATH, f"{split_name}_text_features.pkl")
    with open(output_file, 'wb') as f:
        pkl.dump(features, f, protocol=pkl.HIGHEST_PROTOCOL)
    print(f'Saved: {output_file} ({len(features)} entries)')

    return features

# ============================================
# EXTRACT FOR ALL SPLITS
# ============================================
# max_videos=None processes the full split; pass a number instead for a quick test
train_text, val_text, test_text = [
    extract_for_split(split, max_videos=None)
    for split in ('train', 'val', 'test')
]

# Cleanup: drop the encoder and tokenizer, then release cached GPU memory
del text_model, tokenizer
torch.cuda.empty_cache()

print('\n'.join([
    '',
    '='*60,
    'TEXT FEATURE EXTRACTION COMPLETE!',
    f'Saved to: {TEXT_FEATURE_PATH}',
    '='*60,
]))


In [None]:
# CELL 4: Train/Eval functions (Optimized for cached text features)
print('=== CELL 4 ===')

def train_epoch(model, optimizer, loader, xe, device, use_cached_text=True):
    """Run one optimisation pass over ``loader``.

    Each batch is unpacked as six items: frame features, object features,
    questions, text (a tensor of cached features or raw text), answer ids, and
    a sixth item that is ignored. When ``use_cached_text`` is set and the text
    item is a tensor, ``model.forward_cached`` is used; otherwise the model is
    called with the questions and the text item.

    Args:
        model: network exposing ``forward_cached`` and ``__call__``.
        optimizer: optimizer stepped once per batch.
        loader: iterable of batches.
        xe: loss callable applied to (logits, targets).
        device: device the tensors are moved to.
        use_cached_text: allow the cached-feature path.

    Returns:
        ``(mean loss per batch, accuracy in percent)``; ``(0.0, 0.0)`` when the
        loader produced no samples.
    """
    model.train()
    total_loss, correct, total, n_batches = 0.0, 0, 0, 0

    for batch in loader:
        ff, of, q, text_feat, ans_id, _ = batch
        ff, of, tgt = ff.to(device), of.to(device), ans_id.to(device)

        if use_cached_text and isinstance(text_feat, torch.Tensor):
            # Cached text features — presumably [batch, 5, 768]
            text_feat = text_feat.to(device)
            out = model.forward_cached(ff, of, text_feat)
        else:
            # Raw text: the model encodes it itself (slow)
            out = model(ff, of, q, text_feat)

        loss = xe(out, tgt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (out.argmax(-1) == tgt).sum().item()
        total += tgt.size(0)
        n_batches += 1

    # An empty loader (e.g. misconfigured paths) must not crash with a
    # ZeroDivisionError; report zeros like the evaluation report does.
    if total == 0:
        return 0.0, 0.0
    return total_loss / n_batches, correct / total * 100

def eval_epoch(model, loader, device, use_cached_text=True):
    """Compute accuracy (in percent) of ``model`` over ``loader`` without gradients.

    Batches are unpacked as six items: frame features, object features,
    questions, text (a tensor of cached features or raw text), answer ids, and
    a sixth item that is ignored. The cached-feature path
    (``model.forward_cached``) is taken when ``use_cached_text`` is set and the
    text item is a tensor; otherwise the model is called with the questions and
    the text item.

    Returns:
        Accuracy in percent, or ``0.0`` when the loader produced no samples.
    """
    model.eval()
    correct, total = 0, 0

    with torch.no_grad():
        for batch in loader:
            ff, of, q, text_feat, ans_id, _ = batch
            ff, of = ff.to(device), of.to(device)

            if use_cached_text and isinstance(text_feat, torch.Tensor):
                text_feat = text_feat.to(device)
                out = model.forward_cached(ff, of, text_feat)
            else:
                out = model(ff, of, q, text_feat)

            correct += (out.argmax(-1) == ans_id.to(device)).sum().item()
            total += ans_id.size(0)

    # Guard the empty-loader case instead of raising ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total * 100

print('Functions defined (cached text support)')


In [None]:
# CELL 5 + 6: Setup Paths & Config
print('=== CELL 5+6: Paths & Config ===')

# ============================================
# KAGGLE INPUT PATHS - UPDATE THESE!
# ============================================
# NOTE: ANNOTATION_PATH and SPLIT_DIR are also assigned in CELL 3.5 (text feature
# extraction); this cell overwrites them, so keep both places in sync.

# ViT video features (folder contains video_id.pt files directly)
VIT_FEATURE_PATH = '/kaggle/input/YOUR_VIT_DATASET'  # Contains: video_id.pt files

# Object detection features (direct read from Kaggle)
OBJ_FEATURE_PATH = '/kaggle/input/object-detection-causal-full'  # Contains: features_node_X/video.pkl

# Annotations (folder contains video_id subfolders with text.json, answer.json)
ANNOTATION_PATH = '/kaggle/input/YOUR_ANNOTATION_DATASET'  # Contains: video_id/text.json, answer.json

# Split files (train.pkl, valid.pkl, test.pkl)
SPLIT_DIR = '/kaggle/input/YOUR_SPLITS_DATASET'  # Contains: train.pkl, valid.pkl, test.pkl

# ============================================
# WORKING DIRECTORIES
# ============================================
# Checkpoints are written under <BASE>/models (see save_path in CELL 8).
BASE = '/kaggle/working'
MODEL_DIR = os.path.join(BASE, 'models')
os.makedirs(MODEL_DIR, exist_ok=True)

# ============================================
# VERIFY PATHS
# ============================================
print('\n--- Path Verification ---')

def verify_path(name, path, expected_sample=None):
    """Report whether ``path`` exists and preview up to five of its entries.

    Args:
        name: label printed next to the check result.
        path: directory to check.
        expected_sample: accepted for interface compatibility; not used.

    Returns:
        True when the path exists, False otherwise.
    """
    if not os.path.exists(path):
        print(f'❌ {name}: NOT FOUND')
        print(f'   Path: {path}')
        return False

    preview = os.listdir(path)[:5]
    print(f'✅ {name}')
    print(f'   Path: {path}')
    print(f'   Sample: {preview}')
    return True

# Build a list (not a generator) so every path is checked and reported even
# after an earlier one has failed.
all_ok = all([
    verify_path(label, location)
    for label, location in (
        ('ViT Features', VIT_FEATURE_PATH),
        ('Object Features', OBJ_FEATURE_PATH),
        ('Annotations', ANNOTATION_PATH),
        ('Splits', SPLIT_DIR),
    )
])

if not all_ok:
    print('\n⚠️  Please update paths above!')

# ============================================
# CONFIG
# ============================================
# RUN_TRAINING=True trains in CELL 9 and evaluates the local checkpoint;
# False skips training and evaluation downloads the checkpoint from the Hub.
RUN_TRAINING = True
HF_REPO_ID = 'DanielQ07/transtr-causalvid-weights'
HF_MODEL_FILENAME = 'best_model.ckpt'

class Config:
    # Central hyper-parameter container for the notebook.
    #
    # CELL 8 copies every attribute whose name does not start with '_' into the
    # keyword arguments of VideoQAmodel (and then overrides `topK_frame` with
    # `select_frames`), so adding, renaming or removing an attribute here changes
    # the model constructor call.

    # Paths from Kaggle input
    video_feature_root = VIT_FEATURE_PATH   # video_id.pt files directly
    object_feature_path = OBJ_FEATURE_PATH  # features_node_X/video.pkl
    sample_list_path = ANNOTATION_PATH      # video_id/text.json, answer.json
    split_dir_txt = SPLIT_DIR               # train.pkl, valid.pkl, test.pkl

    # Model architecture (paper config)
    topK_frame = 16          # passed to the datasets as topK_frame; the model receives select_frames instead
    objs = 20                # passed to the datasets as obj_num
    frames = 16
    select_frames = 5        # replaces topK_frame in the model kwargs (CELL 8)
    topK_obj = 12
    frame_feat_dim = 1024
    obj_feat_dim = 2053
    d_model = 768
    word_dim = 768
    nheads = 8
    num_encoder_layers = 2
    num_decoder_layers = 2
    normalize_before = True
    activation = 'gelu'
    dropout = 0.3
    encoder_dropout = 0.3

    # Text encoder
    # NOTE(review): text_encoder_lr is not used by the optimizer built in CELL 8
    # (a single lr for all parameters) — confirm whether the model consumes it.
    text_encoder_type = 'microsoft/deberta-base'
    freeze_text_encoder = False
    text_encoder_lr = 1e-5
    text_pool_mode = 1

    # Training
    bs = 8                   # DataLoader batch size
    lr = 1e-5                # Adam learning rate (CELL 8)
    epoch = 20               # number of training epochs (CELL 9)
    gpu = 0                  # passed to set_gpu_devices
    patience = 5             # ReduceLROnPlateau patience
    gamma = 0.1              # ReduceLROnPlateau factor
    decay = 1e-4             # NOTE(review): not passed to the optimizer in CELL 8 — confirm intended use
    n_query = 5              # passed to the datasets as n_query

    # Other
    hard_eval = False
    pos_ratio = 1.0
    neg_ratio = 1.0
    a = 1.0
    use_amp = True           # NOTE(review): no AMP code is active in the visible cells (GradScaler is commented out)
    num_workers = 4          # DataLoader worker processes

# Instantiate the config; later cells read hyper-parameters as args.<name>.
args = Config()
# NOTE(review): set_gpu_devices / set_seed come from utils.util (not shown here);
# presumably they select the visible GPU and seed the RNGs — confirm.
set_gpu_devices(args.gpu)
set_seed(999)
# Fall back to CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\nDevice: {device}')
print('Config loaded!')


In [None]:
# CELL 7: Create Datasets with Cached Text Features
print('=== CELL 7: Datasets ===')

# Directory with the cached text features written by extract_for_split
# (set to None if they were not pre-extracted)
TEXT_FEATURE_PATH = '/kaggle/working/text_features'

# Cap on the number of training samples; None uses all of them
MAX_TRAIN_SAMPLES = 2000


def _build_dataset(split, max_samples):
    """Create the VideoQADataset for ``split`` with the arguments shared by all splits."""
    return VideoQADataset(
        split=split,
        n_query=args.n_query,
        obj_num=args.objs,
        sample_list_path=args.sample_list_path,
        video_feature_path=args.video_feature_root,
        object_feature_path=args.object_feature_path,
        split_dir=args.split_dir_txt,
        topK_frame=args.topK_frame,
        max_samples=max_samples,
        verbose=True,
        text_feature_path=TEXT_FEATURE_PATH,  # cached text features
    )


print('\n--- Creating TRAIN dataset ---')
train_ds = _build_dataset('train', MAX_TRAIN_SAMPLES)

print('\n--- Creating VAL dataset ---')
val_ds = _build_dataset('val', None)

print('\n--- Creating TEST dataset ---')
test_ds = _build_dataset('test', None)

# DataLoaders: only the training loader shuffles
_loader_opts = dict(num_workers=args.num_workers, pin_memory=True)
train_loader = DataLoader(train_ds, args.bs, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, args.bs, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_ds, args.bs, shuffle=False, **_loader_opts)

# Summary
print('\n' + '='*60)
print('DATASET SUMMARY')
print('='*60)
print(f'Train: {len(train_ds)} samples -> {len(train_loader)} batches')
print(f'Val:   {len(val_ds)} samples -> {len(val_loader)} batches')
print(f'Test:  {len(test_ds)} samples -> {len(test_loader)} batches')
print(f'Text features: {"CACHED ✓" if train_ds.text_features else "REAL-TIME (slow)"}')
print('='*60)

# Sanity check: pull one training batch and show what each part looks like
if len(train_ds) > 0:
    print('\nSanity check...')
    try:
        ff, of, qns, text_feat, ans_id, keys = next(iter(train_loader))
        print(f'  ViT: {ff.shape}')
        print(f'  Obj: {of.shape}')
        if not isinstance(text_feat, torch.Tensor):
            print(f'  Text (raw): list of {len(text_feat)} strings')
        else:
            # cached features — presumably [batch, 5, 768]
            print(f'  Text (cached): {text_feat.shape}')
        print('Sanity check PASSED!')
    except Exception as e:
        print(f'  ERROR: {e}')
        import traceback
        traceback.print_exc()


In [None]:
# CELL 8: Model
print('=== CELL 8: Model ===')
# Forward every public Config attribute to the model constructor as a keyword
# argument (names starting with '_' are skipped).
cfg = {k: v for k, v in Config.__dict__.items() if not k.startswith('_')}
cfg['device'] = device
# The model gets select_frames as its topK_frame; the datasets keep using
# args.topK_frame.
cfg['topK_frame'] = args.select_frames
model = VideoQAmodel(**cfg)
# NOTE(review): one learning rate for all parameters; Config.text_encoder_lr and
# Config.decay are not used here — confirm that is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
# 'max' mode: the scheduler is stepped with validation accuracy in CELL 9.
scheduler = ReduceLROnPlateau(optimizer, 'max', factor=args.gamma, patience=args.patience)
model.to(device)
xe = nn.CrossEntropyLoss()
#scaler = torch.amp.GradScaler('cuda', enabled=True)
# Checkpoint location shared by training (CELL 9) and evaluation (CELL 10).
save_path = os.path.join(MODEL_DIR, HF_MODEL_FILENAME)
print(f'Model: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params')

In [None]:
# CELL 9: Training
print('=== CELL 9: Training ===')
best_acc = 0

if not RUN_TRAINING:
    print('Skipping training (RUN_TRAINING=False)')
else:
    for ep in range(1, args.epoch + 1):
        loss, acc = train_epoch(model, optimizer, train_loader, xe, device)
        val_acc = eval_epoch(model, val_loader, device)
        # The plateau scheduler tracks validation accuracy.
        scheduler.step(val_acc)
        print(f'Ep {ep}: Loss={loss:.4f}, Train={acc:.1f}%, Val={val_acc:.1f}%')

        # Checkpoint only on a strict improvement of validation accuracy.
        if not (val_acc > best_acc):
            continue
        best_acc = val_acc
        torch.save(model.state_dict(), save_path)
        print('  Saved!')

    print(f'\nBest Val: {best_acc:.1f}%')

    # Best-effort upload of the best checkpoint; a failure is reported, not raised.
    try:
        api = HfApi()
        api.create_repo(repo_id=HF_REPO_ID, repo_type='model', exist_ok=True)
        api.upload_file(
            path_or_fileobj=save_path,
            path_in_repo=HF_MODEL_FILENAME,
            repo_id=HF_REPO_ID,
            repo_type='model',
        )
        print('Uploaded!')
    except Exception as e:
        print(f'Upload failed: {e}')

In [None]:
# CELL 10: Detailed Evaluation Function (with ALL score)
print('=== CELL 10: Evaluation Functions ===')


def _split_question_key(qkey, type_map):
    """Split a question key into (video_id, short_type); (None, None) if no known suffix matches."""
    for t_str, t_short in type_map.items():
        suffix = '_' + t_str
        if qkey.endswith(suffix):
            return qkey[:-len(suffix)], t_short
    return None, None


def _aggregate_stats(vid_results):
    """Count correct/total per question type, plus the combined PAR, CAR and ALL buckets."""
    stats = {k: {'correct': 0, 'total': 0} for k in ['d', 'e', 'p', 'pr', 'c', 'cr', 'par', 'car', 'all']}
    for res in vid_results.values():
        for t in ['d', 'e', 'p', 'pr', 'c', 'cr']:
            if t in res:
                stats[t]['total'] += 1
                stats['all']['total'] += 1
                if res[t]['correct']:
                    stats[t]['correct'] += 1
                    stats['all']['correct'] += 1
        # Combined buckets: a video counts only when both the answer and its
        # reason question are present, and is correct only when both are right.
        for answer_t, reason_t, combined in (('p', 'pr', 'par'), ('c', 'cr', 'car')):
            if answer_t in res and reason_t in res:
                stats[combined]['total'] += 1
                if res[answer_t]['correct'] and res[reason_t]['correct']:
                    stats[combined]['correct'] += 1
    return stats


def run_detailed_evaluation(model, loader, device, split_name='test', save_file='failure_cases.json'):
    """Evaluate ``model`` on ``loader`` with a per-question-type breakdown.

    Weights are loaded from the local checkpoint (``save_path``) when the
    module-level ``RUN_TRAINING`` flag is set, otherwise downloaded from
    ``HF_REPO_ID``. Prints an accuracy table, saves a bar chart to
    ``<split_name>_results.png`` and writes the misclassified questions to
    ``<split_name>_<save_file>``.

    Args:
        model: network to evaluate; its weights are replaced in place when a
            checkpoint can be loaded.
        loader: iterable of six-item batches (frame features, object features,
            questions, text, answer ids, question keys).
        device: device used for inference and for mapping the checkpoint.
        split_name: label used in messages and output file names.
        save_file: suffix of the failure-case JSON file name.

    Returns:
        ``(stats, failures)``: per-type ``{'correct', 'total'}`` counters and the
        list of failure records.
    """
    # Load weights. map_location keeps a GPU-saved checkpoint loadable on CPU.
    loaded = False
    if RUN_TRAINING:
        if os.path.exists(save_path):
            model.load_state_dict(torch.load(save_path, map_location=device))
            print(f'Loaded Locally Trained Model: {save_path}')
            loaded = True
    else:
        try:
            print(f'Downloading {HF_MODEL_FILENAME} from {HF_REPO_ID}...')
            local_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_FILENAME, local_dir=MODEL_DIR)
            model.load_state_dict(torch.load(local_model_path, map_location=device))
            print(f'Loaded HF Model: {local_model_path}')
            loaded = True
        except Exception as e:
            print(f'Could not download model: {e}')

    if not loaded:
        # Keep going (best effort), but make it obvious what is being measured.
        print('WARNING: no checkpoint was loaded; evaluating the model with its current in-memory weights.')

    model.eval()
    vid_results = {}
    failures = []

    type_map = {
        'descriptive': 'd', 'explanatory': 'e',
        'predictive': 'p', 'predictive_reason': 'pr',
        'counterfactual': 'c', 'counterfactual_reason': 'cr'
    }

    print(f'\nRunning Detailed Evaluation on {split_name.upper()} Set...')
    with torch.no_grad():
        for batch in loader:
            vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_keys = batch
            vid_frame_feat = vid_frame_feat.to(device)
            vid_obj_feat = vid_obj_feat.to(device)
            # Same dispatch as eval_epoch: a tensor in the text slot means cached
            # text features, which go through forward_cached rather than the
            # raw-text forward pass.
            if isinstance(ans_word, torch.Tensor):
                out = model.forward_cached(vid_frame_feat, vid_obj_feat, ans_word.to(device))
            else:
                out = model(vid_frame_feat, vid_obj_feat, qns_word, ans_word)
            preds = out.argmax(dim=-1).cpu().numpy()
            targets = ans_id.cpu().numpy()

            for i, qkey in enumerate(qns_keys):
                vid_id, found_type = _split_question_key(qkey, type_map)
                if found_type is None:
                    continue

                is_correct = bool(preds[i] == targets[i])
                vid_results.setdefault(vid_id, {})[found_type] = {'correct': is_correct}

                if not is_correct:
                    failures.append({
                        'video_id': vid_id,
                        'type': found_type,
                        'question': qns_word[i],
                        'pred': int(preds[i]),
                        'ground_truth': int(targets[i])
                    })

    # Calculate stats (including ALL)
    stats = _aggregate_stats(vid_results)

    # Print results with ALL
    labels, accs = [], []
    print(f"\n{'Type':<6} {'Acc %':<10} {'Cor':<6} {'Tot':<6}")
    print('-' * 35)
    for k in ['d', 'e', 'p', 'pr', 'par', 'c', 'cr', 'car']:
        s = stats[k]
        acc = s['correct'] / s['total'] * 100 if s['total'] > 0 else 0
        print(f"{k.upper():<6} {acc:<10.2f} {s['correct']:<6} {s['total']:<6}")
        # Only types that actually occurred get a bar in the plot.
        if s['total'] > 0:
            labels.append(k.upper())
            accs.append(acc)
    print('-' * 35)
    # Print ALL score
    all_s = stats['all']
    all_acc = all_s['correct'] / all_s['total'] * 100 if all_s['total'] > 0 else 0
    print(f"{'ALL':<6} {all_acc:<10.2f} {all_s['correct']:<6} {all_s['total']:<6}")
    print('=' * 35)
    labels.append('ALL')
    accs.append(all_acc)

    # Plot (the ALL bar is highlighted in a different colour)
    plt.figure(figsize=(12, 5))
    colors = ['steelblue'] * (len(labels) - 1) + ['darkgreen']
    bars = plt.bar(labels, accs, color=colors)
    plt.ylim(0, 105)
    plt.ylabel('Accuracy (%)')
    plt.title(f'Performance by Question Type ({split_name.upper()} Set)')
    for bar in bars:
        y = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, y + 1, f'{y:.1f}', ha='center', va='bottom')
    plt.axhline(y=20, color='gray', linestyle='--', alpha=0.5, label='Random (20%)')
    plt.legend()
    plt.savefig(f'{split_name}_results.png', dpi=150)
    print(f'\nSaved: {split_name}_results.png')
    plt.show()

    # Save failures
    failure_file = f'{split_name}_{save_file}'
    with open(failure_file, 'w') as f:
        json.dump(failures, f, indent=4)
    print(f'Saved {len(failures)} failure cases to {failure_file}')

    return stats, failures

print('Evaluation function defined (with ALL score)')

In [None]:
# CELL 11: Evaluate on TEST set
print('=== CELL 11: TEST Set Evaluation ===')
# The checkpoint is (re)loaded inside run_detailed_evaluation before scoring.
test_stats, test_failures = run_detailed_evaluation(
    model=model,
    loader=test_loader,
    device=device,
    split_name='test',
)

In [None]:
# CELL 12: Evaluate on VALIDATION set
print('=== CELL 12: VALIDATION Set Evaluation ===')
# The checkpoint is (re)loaded inside run_detailed_evaluation before scoring.
val_stats, val_failures = run_detailed_evaluation(
    model=model,
    loader=val_loader,
    device=device,
    split_name='val',
)

In [None]:
# CELL 13: Summary CSV
print('=== CELL 13: Summary ===')

type_keys = ['d', 'e', 'p', 'pr', 'par', 'c', 'cr', 'car', 'all']
type_names = ['D', 'E', 'P', 'PR', 'PAR', 'C', 'CR', 'CAR', 'ALL']


def _pct(bucket):
    """Accuracy of a {'correct', 'total'} bucket in percent (2 decimals); 0 when it is empty."""
    if bucket['total'] > 0:
        return round(bucket['correct'] / bucket['total'] * 100, 2)
    return 0


# One row per question type, validation and test side by side.
summary_data = [
    {
        'Type': name,
        'Val_Correct': val_stats[k]['correct'],
        'Val_Total': val_stats[k]['total'],
        'Val_Acc%': _pct(val_stats[k]),
        'Test_Correct': test_stats[k]['correct'],
        'Test_Total': test_stats[k]['total'],
        'Test_Acc%': _pct(test_stats[k]),
    }
    for k, name in zip(type_keys, type_names)
]

df = pd.DataFrame(summary_data)
df.to_csv('evaluation_summary.csv', index=False)
print('Saved: evaluation_summary.csv\n')
print(df.to_string(index=False))