# TranSTR CausalVid - Paper Configuration

In [None]:
# CELL 1: Clone
import os
print('=== CELL 1 ===')
if not os.path.exists('tranSTR_Casual'):
    !git clone https://github.com/DanielQH07/tranSTR_Casual.git
os.chdir('tranSTR_Casual/causalvid' if os.path.exists('tranSTR_Casual/causalvid') else 'tranSTR_Casual')
print(f'CWD: {os.getcwd()}')

In [None]:
# CELL 2: HuggingFace
print('=== CELL 2 ===')
!pip install -q huggingface_hub
from huggingface_hub import login, HfApi, hf_hub_download, list_repo_tree
# notebook_login()
login(token='YOUR_HF_TOKEN') # Replace with your actual token

In [None]:
# CELL 3: Imports
print('=== CELL 3: Imports ===')
import os, torch, numpy as np, pandas as pd, tarfile, shutil, json
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from utils.util import set_seed, set_gpu_devices
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
print('Imports OK')

In [None]:
# CELL 4: Train/Eval functions
print('=== CELL 4 ===')

def train_epoch(model, optimizer, loader, xe, device, scaler):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for batch in loader:
        ff, of, q, a, ans_id, _ = batch
        ff, of, tgt = ff.to(device), of.to(device), ans_id.to(device)
        with torch.amp.autocast('cuda', enabled=True):
            out = model(ff, of, q, a)
            loss = xe(out, tgt)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_loss += loss.item()
        correct += (out.argmax(-1) == tgt).sum().item()
        total += tgt.size(0)
    return total_loss / len(loader), correct / total * 100

def eval_epoch(model, loader, device):
    """Return the top-1 accuracy (percent) of ``model`` over ``loader``.

    Batches use the same layout as in ``train_epoch``. Runs in eval mode
    without gradient tracking. Returns ``0.0`` for a loader with no samples
    (e.g. an empty split) instead of raising ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for ff, of, q, a, ans_id, _ in loader:
            out = model(ff.to(device), of.to(device), q, a)
            correct += (out.argmax(-1) == ans_id.to(device)).sum().item()
            total += ans_id.size(0)
    return correct / total * 100 if total else 0.0

print('Functions defined')

In [None]:
# CELL 5: Organize Kaggle Object Features Dataset
# This cell reads pre-extracted object features from Kaggle input and organizes them
# into the directory structure expected by DataLoader.py
print('=== CELL 5: Organize Object Features ===')
import pickle as pkl
from tqdm.auto import tqdm

# ============================================
# CONFIGURATION - UPDATE THESE PATHS
# ============================================
# Root of the Kaggle input dataset that holds the source object-feature pkls.
KAGGLE_INPUT_PATH = '/kaggle/input/YOUR_DATASET_NAME'  # Change this!
# Outputs live under /kaggle/working when it exists, else under the current directory.
BASE = os.getcwd()
if os.path.exists('/kaggle/working'):
    BASE = '/kaggle/working'
OBJ_DIR = os.path.join(BASE, 'features', 'objects')
MODEL_DIR = os.path.join(BASE, 'models')

for _out_dir in (OBJ_DIR, MODEL_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# ============================================
# CHECK IF ALREADY ORGANIZED
# ============================================
def count_organized_videos(obj_dir=None):
    """Count per-video directories that already contain at least one ``.pkl`` file.

    Args:
        obj_dir: directory to scan. Defaults to the module-level ``OBJ_DIR``,
            looked up at call time, so existing zero-argument calls behave as before.

    Returns:
        Number of non-hidden subdirectories of ``obj_dir`` holding a ``.pkl``
        file, or 0 when ``obj_dir`` is missing or is not a directory.
    """
    root = OBJ_DIR if obj_dir is None else obj_dir
    if not os.path.isdir(root):
        return 0
    count = 0
    with os.scandir(root) as entries:
        for entry in entries:
            # Hidden entries and plain files are not video directories.
            if entry.name.startswith('.') or not entry.is_dir():
                continue
            if any(f.endswith('.pkl') for f in os.listdir(entry.path)):
                count += 1
    return count

import shutil

# ============================================
# HELPERS
# ============================================
def _has_pkl_files(path):
    """Return True if ``path`` is a directory containing at least one ``.pkl`` file."""
    return os.path.isdir(path) and any(f.endswith('.pkl') for f in os.listdir(path))


def _extract_frame_arrays(data):
    """Normalise one loaded source pkl into per-frame numpy arrays.

    Accepts a dict (feature key ``feat``/``features``, box key
    ``bbox``/``boxes``/``box``, optional ``img_w``/``img_h`` defaulting to
    640x480) or a ``(feats, bboxes, ...)`` tuple/list.

    Returns:
        ``(feats, bboxes, img_w, img_h)``; ``feats``/``bboxes`` are numpy
        arrays with a leading frame axis of equal length.

    Raises:
        ValueError: with a human-readable reason when the structure or the
            shapes are unusable.
    """
    if isinstance(data, dict):
        feats = data.get('feat', data.get('features'))
        bboxes = data.get('bbox', data.get('boxes', data.get('box')))
        img_w = data.get('img_w', 640)
        img_h = data.get('img_h', 480)
    elif isinstance(data, (tuple, list)) and len(data) >= 2:
        feats, bboxes = data[0], data[1]
        img_w, img_h = 640, 480
    else:
        raise ValueError(f'Unknown pkl structure type={type(data)}')

    if feats is None or bboxes is None:
        raise ValueError('feats or bboxes is None')

    feats = np.asarray(feats)
    bboxes = np.asarray(bboxes)

    # Expected: [num_frames, num_objects, feat_dim] or [num_objects, feat_dim]
    if feats.ndim == 2:
        # Single frame case: add a leading frame axis to both arrays.
        feats = feats[np.newaxis, ...]
        bboxes = bboxes[np.newaxis, ...]
    elif feats.ndim != 3:
        raise ValueError(f'Unexpected feat shape {feats.shape}')

    num_frames = feats.shape[0]
    if bboxes.shape[0] != num_frames:
        raise ValueError(f'bbox frames {bboxes.shape[0]} != feat frames {num_frames}')
    return feats, bboxes, img_w, img_h


def _write_frame_pkls(video_dir, feats, bboxes, img_w, img_h):
    """Write one ``{frame_idx}.pkl`` dict per frame into ``video_dir``; return the frame count."""
    os.makedirs(video_dir, exist_ok=True)
    num_frames = feats.shape[0]
    for frame_idx in range(num_frames):
        frame_data = {
            'feat': feats[frame_idx].astype(np.float32),
            'bbox': bboxes[frame_idx].astype(np.float32),
            'img_w': int(img_w),
            'img_h': int(img_h),
        }
        with open(os.path.join(video_dir, f'{frame_idx}.pkl'), 'wb') as f:
            pkl.dump(frame_data, f, protocol=pkl.HIGHEST_PROTOCOL)
    return num_frames


def _print_organized_samples(limit=3):
    """Print frame counts and first-frame keys/shapes for up to ``limit`` organized videos."""
    sample_dirs = [d for d in os.listdir(OBJ_DIR)
                   if os.path.isdir(os.path.join(OBJ_DIR, d)) and not d.startswith('.')][:limit]
    for sample_video in sample_dirs:
        sample_path = os.path.join(OBJ_DIR, sample_video)
        pkl_files = sorted(f for f in os.listdir(sample_path) if f.endswith('.pkl'))
        print(f'  Sample: "{sample_video}" has {len(pkl_files)} frame files')
        if pkl_files:
            with open(os.path.join(sample_path, pkl_files[0]), 'rb') as f:
                data = pkl.load(f)
            if isinstance(data, dict):
                print(f'    Keys: {list(data.keys())}')
                if 'feat' in data:
                    print(f'    feat shape: {np.array(data["feat"]).shape}')
                if 'bbox' in data:
                    print(f'    bbox shape: {np.array(data["bbox"]).shape}')


# ============================================
# CHECK WHAT IS ALREADY ORGANIZED
# ============================================
organized_count = count_organized_videos()
already_organized = organized_count > 0

if already_organized:
    print(f'Data already organized: {organized_count} videos in {OBJ_DIR}')
    _print_organized_samples()

# ============================================
# ORGANIZE (OR RESUME) FROM KAGGLE INPUT
# ============================================
# Runs even when some videos are already organized so an interrupted run can
# resume; videos whose output directory already holds frame pkls are skipped.
if not os.path.exists(KAGGLE_INPUT_PATH):
    if not already_organized:
        print(f'ERROR: Kaggle input path not found: {KAGGLE_INPUT_PATH}')
        print('Please update KAGGLE_INPUT_PATH to match your dataset location')
else:
    subdirs = [d for d in os.listdir(KAGGLE_INPUT_PATH)
               if os.path.isdir(os.path.join(KAGGLE_INPUT_PATH, d))]
    print(f'Found {len(subdirs)} subdirectories: {subdirs}')

    # Source pkls, skipping '._*' files (typically macOS metadata).
    all_pkl_files = [
        (subdir, pf)
        for subdir in subdirs
        for pf in os.listdir(os.path.join(KAGGLE_INPUT_PATH, subdir))
        if pf.endswith('.pkl') and not pf.startswith('._')
    ]
    pending = [(subdir, pf) for subdir, pf in all_pkl_files
               if not _has_pkl_files(os.path.join(OBJ_DIR, pf[:-4]))]
    print(f'Total pkl files: {len(all_pkl_files)}, still to organize: {len(pending)}')

    if not pending:
        print('Nothing to organize: every source video already has frame pkls.')
    else:
        print('Organizing object features from Kaggle dataset...')
        video_count = 0
        frame_count = 0
        errors = []

        for subdir, pkl_file in tqdm(pending, desc='Organizing'):
            video_id = pkl_file[:-4]  # Remove '.pkl' extension
            video_dir = os.path.join(OBJ_DIR, video_id)
            src_file = os.path.join(KAGGLE_INPUT_PATH, subdir, pkl_file)

            # The same video id may appear in several subdirectories; keep the first.
            if _has_pkl_files(video_dir):
                continue

            try:
                # NOTE: pickle.load can execute arbitrary code; only point
                # KAGGLE_INPUT_PATH at datasets you trust.
                with open(src_file, 'rb') as f:
                    data = pkl.load(f)
                # Validate before creating video_dir so rejected videos leave
                # no empty directories behind.
                feats, bboxes, img_w, img_h = _extract_frame_arrays(data)
                frame_count += _write_frame_pkls(video_dir, feats, bboxes, img_w, img_h)
                video_count += 1
            except Exception as e:
                errors.append(f'{video_id}: {str(e)}')
                # Remove partially written output so the next run retries this video.
                if os.path.exists(video_dir):
                    shutil.rmtree(video_dir, ignore_errors=True)

        # Summary
        print(f'\n{"="*50}')
        print('ORGANIZATION COMPLETE')
        print(f'{"="*50}')
        print(f'  Videos processed: {video_count}')
        print(f'  Total frames: {frame_count}')
        print(f'  Output directory: {OBJ_DIR}')

        if errors:
            print(f'\n  Errors encountered: {len(errors)}')
            for err in errors[:10]:
                print(f'    - {err}')
            if len(errors) > 10:
                print(f'    ... and {len(errors) - 10} more errors')

        # Verify a sample (hidden entries excluded, matching count_organized_videos)
        print('\nVerification sample:')
        sample_dirs = [d for d in os.listdir(OBJ_DIR)
                       if os.path.isdir(os.path.join(OBJ_DIR, d)) and not d.startswith('.')][:2]
        for sd in sample_dirs:
            print(f'  {sd}: {len(os.listdir(os.path.join(OBJ_DIR, sd)))} files')


In [None]:
# CELL 6: Config
print('=== CELL 6: Config ===')

# Whether to train in this session, and where the checkpoint lives on the Hub.
RUN_TRAINING = True
HF_REPO_ID = 'DanielQ07/transtr-causalvid-weights'
HF_MODEL_FILENAME = 'best_model.ckpt'


class Config:
    """Run configuration.

    Every public class attribute is also forwarded to VideoQAmodel as a
    keyword argument by the model cell, so only add real hyper-parameters here.
    """

    # --- paths ---
    video_feature_root = '/kaggle/input/YOUR_VIT_DATASET'
    object_feature_path = OBJ_DIR
    sample_list_path = os.path.join(os.getcwd(), '..', 'data', 'vqa', 'causal', 'anno')
    split_dir_txt = os.path.join(os.getcwd(), '..', 'data', 'splits')

    # --- frame / object counts ---
    topK_frame = 16
    objs = 20
    frames = 16
    select_frames = 5
    topK_obj = 12

    # --- feature / transformer dimensions ---
    frame_feat_dim = 1024
    obj_feat_dim = 2053
    d_model = 768
    word_dim = 768
    nheads = 8
    num_encoder_layers = 2
    num_decoder_layers = 2
    normalize_before = True
    activation = 'gelu'
    dropout = 0.3
    encoder_dropout = 0.3

    # --- text encoder ---
    text_encoder_type = 'microsoft/deberta-base'
    freeze_text_encoder = False
    text_encoder_lr = 1e-5
    text_pool_mode = 1

    # --- training ---
    bs = 8
    lr = 1e-5
    epoch = 20
    gpu = 0
    patience = 5
    gamma = 0.1
    decay = 1e-4
    n_query = 5
    hard_eval = False
    pos_ratio = 1.0
    neg_ratio = 1.0
    a = 1.0
    use_amp = True
    num_workers = 4


args = Config()
set_gpu_devices(args.gpu)
set_seed(999)
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Device: {device}')

In [None]:
# CELL 7: Create Datasets with Verification
print('=== CELL 7: Datasets ===')

# Configuration for limiting train samples (set to None for no limit)
MAX_TRAIN_SAMPLES = 1000  # Change this to limit training videos, or None for all


def _build_split_dataset(split, max_samples):
    """Announce, then construct, the VideoQADataset for ``split`` from the shared args."""
    print(f'\n--- Creating {split.upper()} dataset ---')
    return VideoQADataset(
        split=split,
        n_query=args.n_query,
        obj_num=args.objs,
        sample_list_path=args.sample_list_path,
        video_feature_path=args.video_feature_root,
        object_feature_path=args.object_feature_path,
        split_dir=args.split_dir_txt,
        topK_frame=args.topK_frame,
        max_samples=max_samples,
        verbose=True,
    )


def _build_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the configured batch size and workers."""
    return DataLoader(dataset, args.bs, shuffle=shuffle,
                      num_workers=args.num_workers, pin_memory=True)


# Only the training split is capped; val/test always use every sample.
train_ds = _build_split_dataset('train', MAX_TRAIN_SAMPLES)
val_ds = _build_split_dataset('val', None)
test_ds = _build_split_dataset('test', None)

train_loader = _build_loader(train_ds, shuffle=True)
val_loader = _build_loader(val_ds, shuffle=False)
test_loader = _build_loader(test_ds, shuffle=False)

# Summary
_rule = '=' * 60
print('\n' + _rule)
print('DATASET SUMMARY')
print(_rule)
for _label, _ds, _dl in (('Train:', train_ds, train_loader),
                         ('Val:', val_ds, val_loader),
                         ('Test:', test_ds, test_loader)):
    print(f'{_label:<6} {len(_ds)} samples -> {len(_dl)} batches')
print(_rule)

# Quick sanity check - load one batch
if len(train_ds) > 0:
    print('\nSanity check: Loading first batch...')
    try:
        ff, of, qns, ans, ans_id, keys = next(iter(train_loader))
        print(f'  ViT features: {ff.shape}')  # Expected: [batch, topK_frame, feat_dim]
        print(f'  Object features: {of.shape}')  # Expected: [batch, topK_frame, obj_num, 2053]
        print(f'  Answer IDs: {ans_id}')
        print('Sanity check PASSED!')
    except Exception as e:
        import traceback
        print(f'  ERROR: {e}')
        traceback.print_exc()


In [None]:
# CELL 8: Model
print('=== CELL 8: Model ===')
# Every public Config attribute becomes a model kwarg. The model's topK_frame
# is overridden with select_frames, while the datasets load Config.topK_frame frames.
cfg = {k: v for k, v in Config.__dict__.items() if not k.startswith('_')}
cfg['device'] = device
cfg['topK_frame'] = args.select_frames
model = VideoQAmodel(**cfg)
model.to(device)
# Config.decay was declared but never applied; pass it as weight decay.
# NOTE(review): Config.text_encoder_lr is still unused; confirm whether the
# text encoder should get its own parameter group with that learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.decay)
# mode='max': the scheduler is stepped with validation accuracy.
scheduler = ReduceLROnPlateau(optimizer, 'max', factor=args.gamma, patience=args.patience)
xe = nn.CrossEntropyLoss()
# Honour Config.use_amp; gradient scaling only applies on CUDA.
scaler = torch.amp.GradScaler('cuda', enabled=bool(args.use_amp and device.type == 'cuda'))
save_path = os.path.join(MODEL_DIR, HF_MODEL_FILENAME)
print(f'Model: {sum(p.numel() for p in model.parameters())/1e6:.1f}M params')

In [None]:
# CELL 9: Training
print('=== CELL 9: Training ===')
# Start below any reachable accuracy so the first epoch always writes a fresh
# checkpoint. Starting at 0 could leave a stale file at save_path (e.g. a Hub
# download into MODEL_DIR) if validation accuracy never rose above 0.
best_acc = -1.0

if RUN_TRAINING:
    for ep in range(1, args.epoch + 1):
        loss, acc = train_epoch(model, optimizer, train_loader, xe, device, scaler)
        val_acc = eval_epoch(model, val_loader, device)
        scheduler.step(val_acc)
        print(f'Ep {ep}: Loss={loss:.4f}, Train={acc:.1f}%, Val={val_acc:.1f}%')
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
            print('  Saved!')
    print(f'\nBest Val: {best_acc:.1f}%')
    if os.path.exists(save_path):
        try:
            api = HfApi()
            api.create_repo(repo_id=HF_REPO_ID, repo_type='model', exist_ok=True)
            api.upload_file(path_or_fileobj=save_path, path_in_repo=HF_MODEL_FILENAME, repo_id=HF_REPO_ID, repo_type='model')
            print('Uploaded!')
        except Exception as e:
            # Best effort: a failed upload must not lose the local checkpoint.
            print(f'Upload failed: {e}')
    else:
        print(f'No checkpoint at {save_path}; skipping upload')
else:
    print('Skipping training (RUN_TRAINING=False)')

In [None]:
# CELL 10: Detailed Evaluation Function (with ALL score)
print('=== CELL 10: Evaluation Functions ===')

# Question-key suffix -> short type label; keys look like '<video_id>_<suffix>'.
_QTYPE_SUFFIX = {
    'descriptive': 'd', 'explanatory': 'e',
    'predictive': 'p', 'predictive_reason': 'pr',
    'counterfactual': 'c', 'counterfactual_reason': 'cr'
}
_BASE_TYPES = ['d', 'e', 'p', 'pr', 'c', 'cr']


def _load_eval_weights(model, device):
    """Load the checkpoint to evaluate into ``model``; return True on success.

    With RUN_TRAINING the local checkpoint at ``save_path`` is used (load
    errors propagate). Otherwise HF_MODEL_FILENAME is downloaded from
    HF_REPO_ID into MODEL_DIR, and download/load errors are printed.
    ``map_location=device`` lets a GPU-saved checkpoint load on a CPU-only machine.
    """
    if RUN_TRAINING:
        if os.path.exists(save_path):
            model.load_state_dict(torch.load(save_path, map_location=device))
            print(f'Loaded Locally Trained Model: {save_path}')
            return True
        return False
    try:
        print(f'Downloading {HF_MODEL_FILENAME} from {HF_REPO_ID}...')
        local_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_FILENAME, local_dir=MODEL_DIR)
        model.load_state_dict(torch.load(local_model_path, map_location=device))
        print(f'Loaded HF Model: {local_model_path}')
        return True
    except Exception as e:
        print(f'Could not download model: {e}')
        return False


def _split_question_key(qkey):
    """Return ``(video_id, short_type)`` for a key ending in ``_<type>``, else ``(None, None)``."""
    for t_str, t_short in _QTYPE_SUFFIX.items():
        suffix = '_' + t_str
        if qkey.endswith(suffix):
            return qkey[:-len(suffix)], t_short
    return None, None


def _compute_type_stats(vid_results):
    """Aggregate ``{video_id: {type: {'correct': bool}}}`` into accuracy counters.

    Returns a dict mapping d/e/p/pr/c/cr plus 'par', 'car' and 'all' to
    ``{'correct': int, 'total': int}``. PAR (CAR) counts a video as correct
    only when both its p and pr (c and cr) answers are correct. ALL pools the
    six base types.
    """
    stats = {k: {'correct': 0, 'total': 0} for k in _BASE_TYPES + ['par', 'car', 'all']}
    for res in vid_results.values():
        for t in _BASE_TYPES:
            if t in res:
                stats[t]['total'] += 1
                stats['all']['total'] += 1
                if res[t]['correct']:
                    stats[t]['correct'] += 1
                    stats['all']['correct'] += 1
        for combo, (answer_t, reason_t) in (('par', ('p', 'pr')), ('car', ('c', 'cr'))):
            if answer_t in res and reason_t in res:
                stats[combo]['total'] += 1
                if res[answer_t]['correct'] and res[reason_t]['correct']:
                    stats[combo]['correct'] += 1
    return stats


def _report_stats(stats, split_name):
    """Print the per-type accuracy table plus ALL, then save/show ``{split_name}_results.png``."""
    labels, accs = [], []
    print(f"\n{'Type':<6} {'Acc %':<10} {'Cor':<6} {'Tot':<6}")
    print('-' * 35)
    for k in ['d', 'e', 'p', 'pr', 'par', 'c', 'cr', 'car']:
        s = stats[k]
        acc = s['correct'] / s['total'] * 100 if s['total'] > 0 else 0
        print(f"{k.upper():<6} {acc:<10.2f} {s['correct']:<6} {s['total']:<6}")
        if s['total'] > 0:
            labels.append(k.upper())
            accs.append(acc)
    print('-' * 35)
    all_s = stats['all']
    all_acc = all_s['correct'] / all_s['total'] * 100 if all_s['total'] > 0 else 0
    print(f"{'ALL':<6} {all_acc:<10.2f} {all_s['correct']:<6} {all_s['total']:<6}")
    print('=' * 35)
    labels.append('ALL')
    accs.append(all_acc)

    plt.figure(figsize=(12, 5))
    # Per-type bars in blue; the pooled ALL bar (always last) in green.
    colors = ['steelblue'] * (len(labels) - 1) + ['darkgreen']
    bars = plt.bar(labels, accs, color=colors)
    plt.ylim(0, 105)
    plt.ylabel('Accuracy (%)')
    plt.title(f'Performance by Question Type ({split_name.upper()} Set)')
    for bar in bars:
        y = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2, y + 1, f'{y:.1f}', ha='center', va='bottom')
    plt.axhline(y=20, color='gray', linestyle='--', alpha=0.5, label='Random (20%)')
    plt.legend()
    plt.savefig(f'{split_name}_results.png', dpi=150)
    print(f'\nSaved: {split_name}_results.png')
    plt.show()


def run_detailed_evaluation(model, loader, device, split_name='test', save_file='failure_cases.json'):
    """Evaluate ``model`` on ``loader`` broken down by question type.

    Loads the checkpoint first (local when RUN_TRAINING, otherwise from the
    Hub), prints per-type accuracy plus PAR/CAR/ALL, saves
    ``{split_name}_results.png`` and writes every wrong answer to
    ``{split_name}_{save_file}`` as JSON.

    Returns:
        ``(stats, failures)``. ``stats`` maps 'd', 'e', 'p', 'pr', 'c', 'cr',
        'par', 'car', 'all' to ``{'correct', 'total'}`` counts. ``failures``
        is the list of wrongly answered questions that was written to disk.
    """
    if not _load_eval_weights(model, device):
        # Do not silently report numbers for whatever weights are in memory.
        print('WARNING: no checkpoint was loaded; evaluating the in-memory weights')

    model.eval()
    vid_results = {}
    failures = []

    print(f'\nRunning Detailed Evaluation on {split_name.upper()} Set...')
    with torch.no_grad():
        for vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_keys in loader:
            out = model(vid_frame_feat.to(device), vid_obj_feat.to(device), qns_word, ans_word)
            preds = out.argmax(dim=-1).cpu().numpy()
            targets = ans_id.numpy()

            for i, qkey in enumerate(qns_keys):
                vid_id, found_type = _split_question_key(qkey)
                if found_type is None:
                    continue  # key without a recognised question-type suffix
                is_correct = bool(preds[i] == targets[i])
                vid_results.setdefault(vid_id, {})[found_type] = {'correct': is_correct}
                if not is_correct:
                    failures.append({
                        'video_id': vid_id,
                        'type': found_type,
                        'question': qns_word[i],
                        'pred': int(preds[i]),
                        'ground_truth': int(targets[i])
                    })

    stats = _compute_type_stats(vid_results)
    _report_stats(stats, split_name)

    failure_file = f'{split_name}_{save_file}'
    with open(failure_file, 'w') as f:
        json.dump(failures, f, indent=4)
    print(f'Saved {len(failures)} failure cases to {failure_file}')

    return stats, failures

print('Evaluation function defined (with ALL score)')

In [None]:
# CELL 11: Evaluate on TEST set
print('=== CELL 11: TEST Set Evaluation ===')
# Per-question-type evaluation of the test split.
test_stats, test_failures = run_detailed_evaluation(
    model, test_loader, device,
    split_name='test',
)

In [None]:
# CELL 12: Evaluate on VALIDATION set
print('=== CELL 12: VALIDATION Set Evaluation ===')
# Per-question-type evaluation of the validation split.
val_stats, val_failures = run_detailed_evaluation(
    model, val_loader, device,
    split_name='val',
)

In [None]:
# CELL 13: Summary CSV
print('=== CELL 13: Summary ===')

type_keys = ['d', 'e', 'p', 'pr', 'par', 'c', 'cr', 'car', 'all']
type_names = ['D', 'E', 'P', 'PR', 'PAR', 'C', 'CR', 'CAR', 'ALL']


def _pct(counts):
    """Accuracy in percent rounded to 2 decimals, or 0 when there are no samples."""
    if counts['total'] > 0:
        return round(counts['correct']/counts['total']*100, 2)
    return 0


# One row per question type; column order is the CSV column order.
summary_data = [
    {
        'Type': name,
        'Val_Correct': val_stats[key]['correct'],
        'Val_Total': val_stats[key]['total'],
        'Val_Acc%': _pct(val_stats[key]),
        'Test_Correct': test_stats[key]['correct'],
        'Test_Total': test_stats[key]['total'],
        'Test_Acc%': _pct(test_stats[key]),
    }
    for key, name in zip(type_keys, type_names)
]

df = pd.DataFrame(summary_data)
df.to_csv('evaluation_summary.csv', index=False)
print('Saved: evaluation_summary.csv\n')
print(df.to_string(index=False))