# Environment Setup (Kaggle)
This cell clones the repository and sets the working directory.

In [None]:
import os
# --- Git Clone & Setup ---
REPO_URL = "https://github.com/DanielQH07/tranSTR_Casual.git" 
REPO_NAME = "tranSTR_Casual"
BRANCH = "main" 

if not os.path.exists(REPO_NAME):
    print(f"Cloning {REPO_URL}...")
    !git clone {REPO_URL}
else:
    print("Repo already exists.")

# Change Directory to the repo root 
if os.path.basename(os.getcwd()) != REPO_NAME:
    try:
        target_dir = os.path.join(os.getcwd(), REPO_NAME, "causalvid")
        if os.path.exists(target_dir):
             os.chdir(target_dir)
        elif os.path.exists(REPO_NAME):
             os.chdir(REPO_NAME)
        
        print(f"Changed directory to: {os.getcwd()}")
    except Exception as e:
             print(f"Could not set working directory: {e}")

In [None]:
# --- Install & Login Hugging Face ---
# Requires for uploading/downloading weights
!pip install huggingface_hub
from huggingface_hub import notebook_login, HfApi, hf_hub_download, list_repo_tree

print("Please login to Hugging Face to enable model upload/download.")
notebook_login()

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import argparse
from torch.utils.data import DataLoader
from utils.util import set_seed, set_gpu_devices, save_file
from networks.model import VideoQAmodel
from DataLoader import VideoQADataset
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.logger import logger

from train import train, eval, predict

print("Libraries imported successfully.")

## 0. Data & Pipeline Configuration
Configure whether to Train or just Test, and where to store Model Weights on HF.

In [None]:
# --- WORKFLOW CONTROL ---
RUN_TRAINING = True   # Set to False to skip training and load weights from HF
HF_REPO_ID = "DanielQ07/transtr-causalvid-weights" # Your Model Repo ID
HF_MODEL_FILENAME = "best_model_causalvid.ckpt" # Name of the file on HF

# --- Data Configuration ---
HF_DATASET_ID = "DanielQ07/kltn"
HF_SHARD_FOLDER = "shards" # Folder inside the dataset containing tar.gz files

import os
import tarfile
import shutil
import urllib.request

BASE_WORK_DIR = "/kaggle/working" if os.path.exists("/kaggle/working") else os.getcwd()
OBJ_DIR = os.path.join(BASE_WORK_DIR, "features", "objects")
ANNO_DIR = os.path.join(BASE_WORK_DIR, "data", "annotations")
MODEL_DIR = os.path.join(BASE_WORK_DIR, "models")

for d in [OBJ_DIR, ANNO_DIR, MODEL_DIR]:
    os.makedirs(d, exist_ok=True)

print(f"Object Features will be in: {OBJ_DIR}")

# Top-level folder name prefixes that mark an extracted shard wrapper
# (OBJ_DIR/shard_xxxxx/video_id) rather than a video folder.
SHARD_DIR_PREFIXES = ("shard_", "train_", "val_")


def extract_tar_safely(tar_path, dest_dir):
    """Extract a .tar.gz archive into dest_dir, refusing paths that escape it.

    Uses the standard library "data" extraction filter when this Python
    provides it; otherwise every member path is checked manually before
    anything is written.

    Raises:
        tarfile.TarError: if a member would be written outside dest_dir.
    """
    with tarfile.open(tar_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest_dir, filter="data")
            return
        dest_root = os.path.realpath(dest_dir)
        for member in tar.getmembers():
            member_path = os.path.realpath(os.path.join(dest_root, member.name))
            if os.path.commonpath([dest_root, member_path]) != dest_root:
                raise tarfile.TarError(f"Unsafe path in archive: {member.name}")
        tar.extractall(path=dest_dir)


def flatten_shard_dirs(root_dir, prefixes=SHARD_DIR_PREFIXES):
    """Move the contents of shard wrapper folders up into root_dir.

    An item is skipped when root_dir already holds an entry of the same name.
    A wrapper folder is removed only once it is empty; one that still holds
    skipped duplicates is kept (and reported) instead of raising, so the
    remaining shards are still processed.

    Returns:
        The number of items moved.
    """
    moved = 0
    for entry in os.listdir(root_dir):
        entry_path = os.path.join(root_dir, entry)
        if not (os.path.isdir(entry_path) and entry.startswith(tuple(prefixes))):
            continue
        for child in os.listdir(entry_path):
            dst_path = os.path.join(root_dir, child)
            if not os.path.exists(dst_path):
                shutil.move(os.path.join(entry_path, child), dst_path)
                moved += 1
        if os.listdir(entry_path):
            print(f"Warning: kept {entry_path}; it still holds items that already exist in {root_dir}")
        else:
            os.rmdir(entry_path)
    return moved


# --- 1. Download & Extract Shards ---
# We only proceed if OBJ_DIR seems empty (no video folders)
# Heuristic: check if any subdirectory exists
# NOTE: a run that failed half-way leaves some folders behind, which makes this
# heuristic skip the download next time; clear OBJ_DIR to force a full retry.
has_subdirs = any(os.path.isdir(os.path.join(OBJ_DIR, i)) for i in os.listdir(OBJ_DIR))

if not has_subdirs:
    print(f"Fetching shards from {HF_DATASET_ID}/{HF_SHARD_FOLDER}...")
    try:
        # List files in the shards folder
        files = list_repo_tree(repo_id=HF_DATASET_ID, repo_type="dataset", path_in_repo=HF_SHARD_FOLDER)
        tar_files = [f.path for f in files if f.path.endswith('.tar.gz')]
        print(f"Found {len(tar_files)} shards to process.")

        for remote_path in tar_files:
            filename = os.path.basename(remote_path)
            print(f"Downloading & Extracting {filename}...")

            local_tar = hf_hub_download(repo_id=HF_DATASET_ID, filename=remote_path, repo_type="dataset", local_dir=BASE_WORK_DIR)

            try:
                extract_tar_safely(local_tar, OBJ_DIR)
            finally:
                # Free the disk space even when extraction fails.
                os.remove(local_tar)

        print("Download & Extraction Complete. Checking Structure...")
        # Structure might be OBJ_DIR/shard_xxxxx/video_id; we want OBJ_DIR/video_id
        moved_count = flatten_shard_dirs(OBJ_DIR)
        print(f"Flattened {moved_count} video folders to root of {OBJ_DIR}")

    except Exception as e:
        # Best effort: report and let the notebook continue.
        print(f"Error processing data: {e}")
else:
    print("Features folder not empty, skipping download.")

In [None]:
class Config:
    """Paths and hyper-parameters for the experiment, held as class attributes.

    The path-resolution statements below run once, when the class body is
    executed, and the resolved paths are printed at that point.

    Every name bound in this class body (including `repo_anno` / `repo_split`
    when their branches run) becomes a class attribute. A later cell forwards
    all attributes that do not start with a double underscore to
    `VideoQAmodel` as keyword arguments, so renaming or adding names here
    changes what the model constructor receives.
    """

    # --- PATHS ---
    # 1. Video Features (ViT features)
    # Default is a local Windows development path.
    video_feature_root = r"D:\KLTN\TranSTR\causalvid\features\vit"  
    if not os.path.exists(video_feature_root) and os.path.exists("/kaggle/input"):
        # Example Kaggle Input path - UPDATE THIS
        # (placeholder dataset name; it must be edited before running on Kaggle)
        video_feature_root = "/kaggle/input/your-dataset-name/features/vit"

    # 2. Object Features 
    # Directory the shard download cell extracts into.
    object_feature_path = OBJ_DIR

    # 3. Annotations 
    # Resolution order: local Windows path -> ../data/vqa/causal/anno relative
    # to the current working directory -> ANNO_DIR.
    sample_list_path = r"D:\KLTN\TranSTR\causalvid\data\vqa\causal\anno"
    if not os.path.exists(sample_list_path):
        repo_anno = os.path.join(os.getcwd(), "..", "data", "vqa", "causal", "anno")
        if os.path.exists(repo_anno):
            sample_list_path = repo_anno
        else:
            sample_list_path = ANNO_DIR

    # 4. Splits 
    # Resolution order: local Windows path -> ../data/splits relative to the
    # current working directory. There is no further fallback: if neither
    # exists, the Windows path is kept as-is.
    split_dir_txt = r"D:\KLTN\TranSTR\causalvid\data\splits"
    if not os.path.exists(split_dir_txt):
         repo_split = os.path.join(os.getcwd(), "..", "data", "splits")
         if os.path.exists(repo_split):
             split_dir_txt = repo_split
    
    print(f"Configured Paths:\n - Video Feat: {video_feature_root}\n - Object Feat: {object_feature_path}\n - Anno Path: {sample_list_path}\n - Split Dir: {split_dir_txt}")

    # Model Params
    v = "v1"
    bs = 8  # batch size used for the train/val/test DataLoaders
    lr = 1e-4  # base learning rate given to AdamW
    epoch = 20  # number of training epochs
    gpu = 0  # value passed to set_gpu_devices
    dropout = 0.3
    encoder_dropout = 0.3
    patience = 5  # ReduceLROnPlateau patience
    gamma = 0.1  # ReduceLROnPlateau factor
    decay = 1e-4  # AdamW weight_decay
    
    # Dataset Params
    dataset = 'causal-vid'
    objs = 10  # passed to VideoQADataset as obj_num
    n_query = 5  # passed to VideoQADataset as n_query
    
    # Transformer
    d_model = 512
    word_dim = 768
    topK_frame = 20  # passed to VideoQADataset as topK_frame
    topK_obj = 10
    num_encoder_layers = 2
    num_decoder_layers = 2
    nheads = 8
    normalize_before = True
    activation = 'gelu'
    
    # Text Encoder
    text_encoder_lr = 1e-5  # learning rate for parameters whose name contains "text_encoder"
    freeze_text_encoder = False
    text_encoder_type = "bert-base-uncased"
    text_pool_mode = 1
    
    # Misc
    hard_eval = False
    pos_ratio = 1.0
    neg_ratio = 1.0
    a = 1.0
    use_amp = True  # enables the GradScaler and is forwarded to train()
    num_workers = 2  # DataLoader worker processes
    frame_feat_dim = 768 
    # presumably 2048 appearance dims plus 5 box-related values - TODO confirm
    obj_feat_dim = 2048 + 5 

# Instantiate the configuration, select the GPU and fix the random seed.
args = Config()
set_gpu_devices(args.gpu)
set_seed(999)

# Prefer CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

## 1. Initialize DataLoaders


In [None]:
print("Initializing Datasets...")

try:
    # Arguments shared by all three splits; only `split` differs.
    shared_dataset_kwargs = dict(
        n_query=args.n_query,
        obj_num=args.objs,
        sample_list_path=args.sample_list_path,
        video_feature_path=args.video_feature_root,
        object_feature_path=args.object_feature_path,
        split_dir=args.split_dir_txt,
        topK_frame=args.topK_frame,
    )

    # Datasets are built in train -> val -> test order.
    train_dataset = VideoQADataset(split='train', **shared_dataset_kwargs)
    val_dataset = VideoQADataset(split='val', **shared_dataset_kwargs)
    test_dataset = VideoQADataset(split='test', **shared_dataset_kwargs)

    # Loader settings shared by all splits; only the training loader shuffles.
    shared_loader_kwargs = dict(
        batch_size=args.bs,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **shared_loader_kwargs)
    val_loader = DataLoader(dataset=val_dataset, shuffle=False, **shared_loader_kwargs)
    test_loader = DataLoader(dataset=test_dataset, shuffle=False, **shared_loader_kwargs)

    print("\nDatasets initialized successfully.")
except Exception as e:
    # Report the failure with its traceback and let the notebook continue.
    print(f"\nError initializing datasets: {e}")
    import traceback
    traceback.print_exc()

## 3. Model Setup
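Builds the model from the configuration, then creates the AdamW optimizer (with a separate learning rate for the text encoder), the learning-rate scheduler, and the loss function.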

In [None]:
# Every Config class attribute that does not start with a double underscore is
# forwarded to the model as a keyword argument, plus the selected device.
config_dict = {k: v for k, v in Config.__dict__.items() if not k.startswith('__')}
config_dict['device'] = device
model = VideoQAmodel(**config_dict)

# Split the trainable parameters into two groups so the text encoder can use
# its own (smaller) learning rate; frozen parameters are left out entirely.
text_encoder_params, other_params = [], []
for param_name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if "text_encoder" in param_name:
        text_encoder_params.append(param)
    else:
        other_params.append(param)

param_dicts = [
    {"params": other_params},
    {"params": text_encoder_params, "lr": args.text_encoder_lr}
]
optimizer = torch.optim.AdamW(params=param_dicts, lr=args.lr, weight_decay=args.decay)

# 'max' mode: the scheduler is stepped with validation accuracy.
# `verbose` is deprecated in PyTorch; a release that no longer accepts the
# keyword raises TypeError, so retry without it instead of failing the cell.
try:
    scheduler = ReduceLROnPlateau(optimizer, 'max', factor=args.gamma, patience=args.patience, verbose=True)
except TypeError:
    scheduler = ReduceLROnPlateau(optimizer, 'max', factor=args.gamma, patience=args.patience)

model.to(device)
xe = nn.CrossEntropyLoss().to(device)
print("Model and Optimizer created.")

## 4. Train Loop with HF Upload
Runs only if `RUN_TRAINING = True`. Saves the best model (by validation accuracy) and uploads it to Hugging Face.

In [None]:
best_eval_score = 0.0
best_epoch = 1
save_path = os.path.join(MODEL_DIR, HF_MODEL_FILENAME)
# True once this run has written a checkpoint; guards the upload so a missing
# or stale file from an earlier session is never pushed.
saved_this_run = False

# --- Create Scaler ---
try:
    scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
    print("GradScaler initialized.")
except (AttributeError, TypeError):
    # Fallback for older torch versions without torch.amp.GradScaler
    scaler = torch.cuda.amp.GradScaler(enabled=args.use_amp)
    print("GradScaler initialized (fallback).")

if RUN_TRAINING:
    print("Starting Training Loop...")
    if len(train_loader) > 0:
        for epoch in range(1, args.epoch + 1):
            print(f"Epoch {epoch}/{args.epoch}")

            # Pass scaler to train function
            train_loss, train_acc = train(model, optimizer, train_loader, xe, device, use_amp=args.use_amp, scaler=scaler)

            # `eval` here is the function imported from train.py, not the builtin.
            eval_score = eval(model, val_loader, device)
            scheduler.step(eval_score)

            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f} | Val Acc: {eval_score:.2f}")

            # Keep only the checkpoint with the best validation score so far.
            if eval_score > best_eval_score:
                best_eval_score = eval_score
                best_epoch = epoch
                torch.save(model.state_dict(), save_path)
                saved_this_run = True
                print(f"  Saved new best model to {save_path}")

        print(f"Training Complete. Best Val Acc: {best_eval_score:.2f}")
        if saved_this_run:
            print(f"Best epoch: {best_epoch}")

        # --- UPLOAD TO HUGGING FACE ---
        if not saved_this_run:
            print("No checkpoint was saved during this run. Skipping upload.")
        else:
            try:
                api = HfApi()
                print(f"Creating/Accessing Repo {HF_REPO_ID}...")
                api.create_repo(repo_id=HF_REPO_ID, repo_type="model", exist_ok=True)

                print(f"Uploading {save_path} to Hugging Face...")
                api.upload_file(
                    path_or_fileobj=save_path,
                    path_in_repo=HF_MODEL_FILENAME,
                    repo_id=HF_REPO_ID,
                    repo_type="model"
                )
                print("Upload SUCCESS!")
            except Exception as e:
                # Best effort: the local checkpoint is still available.
                print(f"Upload FAILED: {e}")

    else:
        print("Train Loader empty. Skipping training.")
else:
    print("RUN_TRAINING = False. Skipping Training.")

## 5. Detailed Evaluation (Test Mode)
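If `RUN_TRAINING` is `False`, the weights are downloaded from Hugging Face; otherwise the locally saved checkpoint is used. The detailed per-question-type analysis then runs on the test set.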

In [None]:
import matplotlib.pyplot as plt
import json

def _load_eval_weights(model, device):
    """Load the checkpoint to evaluate; return True when weights were loaded."""
    if RUN_TRAINING:
        # Use local file if we just trained
        if os.path.exists(save_path):
            # map_location lets a checkpoint saved on GPU load on a CPU session.
            model.load_state_dict(torch.load(save_path, map_location=device))
            print(f"Loaded Locally Trained Model: {save_path}")
            return True
        print(f"No local checkpoint found at {save_path}. Evaluating the model's current weights (Warning!)")
        return False

    # Download from HF
    try:
        print(f"Downloading {HF_MODEL_FILENAME} from {HF_REPO_ID}...")
        local_model_path = hf_hub_download(repo_id=HF_REPO_ID, filename=HF_MODEL_FILENAME, local_dir=MODEL_DIR)
        model.load_state_dict(torch.load(local_model_path, map_location=device))
        print(f"Loaded HF Model: {local_model_path}")
        return True
    except Exception as e:
        # Best effort: the evaluation still runs, but on untrained weights.
        print(f"Could not download model from HF: {e}\nRunning with random weights (Warning!)")
        return False


def _split_question_key(qkey, type_map):
    """Split '<video_id>_<question type>' into (video_id, short type code).

    Returns (None, None) when the key ends with no known question type.
    """
    for t_str, t_short in type_map.items():
        suffix = '_' + t_str
        if qkey.endswith(suffix):
            return qkey[:-len(suffix)], t_short
    return None, None


def _aggregate_type_stats(vid_results):
    """Count correct/total answers per question type, plus combined PAR/CAR."""
    stats = {k: {'correct': 0, 'total': 0} for k in ['d', 'e', 'p', 'pr', 'c', 'cr', 'par', 'car']}
    for res in vid_results.values():
        for t in ['d', 'e', 'p', 'pr', 'c', 'cr']:
            if t in res:
                stats[t]['total'] += 1
                if res[t]['correct']:
                    stats[t]['correct'] += 1
        # Combined scores need both the answer and its reason to be correct:
        # PAR = p + pr, CAR = c + cr.
        for combined, (first, second) in (('par', ('p', 'pr')), ('car', ('c', 'cr'))):
            if first in res and second in res:
                stats[combined]['total'] += 1
                if res[first]['correct'] and res[second]['correct']:
                    stats[combined]['correct'] += 1
    return stats


def run_detailed_evaluation(model, loader, device, save_file='failure_cases.json'):
    """Evaluate the model per question type, plot accuracies and dump failures.

    Weights are loaded first: the locally saved checkpoint when RUN_TRAINING
    is true, otherwise the checkpoint downloaded from the Hugging Face repo.
    If loading fails, a warning is printed and the evaluation still runs.

    Args:
        model: model to evaluate; switched to eval mode.
        loader: iterable of batches shaped as
            (vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_keys).
        device: device the video features are moved to.
        save_file: path of the JSON file the failure cases are written to.
            (This parameter name shadows the `save_file` helper imported from
            utils.util inside this function.)

    Returns:
        (stats, failures): stats maps each type code ('d', 'e', 'p', 'pr',
        'c', 'cr', 'par', 'car') to {'correct', 'total'} counts; failures is a
        list of dicts describing each wrong prediction.
    """
    # --- LOAD WEIGHTS LOGIC ---
    _load_eval_weights(model, device)

    model.eval()
    vid_results = {}
    failures = []

    # Question-key suffix -> short type code.
    type_map = {
        'descriptive': 'd',
        'explanatory': 'e',
        'predictive': 'p',
        'predictive_reason': 'pr',
        'counterfactual': 'c',
        'counterfactual_reason': 'cr'
    }

    print("Running Detailed Evaluation on Test Set...")
    with torch.no_grad():
        for batch in loader:
            vid_frame_feat, vid_obj_feat, qns_word, ans_word, ans_id, qns_keys = batch
            vid_frame_feat, vid_obj_feat = vid_frame_feat.to(device), vid_obj_feat.to(device)
            out = model(vid_frame_feat, vid_obj_feat, qns_word, ans_word)
            preds = out.argmax(dim=-1).cpu().numpy()
            targets = ans_id.cpu().numpy()

            for i, qkey in enumerate(qns_keys):
                vid_id, found_type = _split_question_key(qkey, type_map)
                if not found_type:
                    # Unknown question type: not counted.
                    continue

                # Plain bool rather than numpy.bool_ so results stay simple to reuse.
                is_correct = bool(preds[i] == targets[i])
                vid_results.setdefault(vid_id, {})[found_type] = {'correct': is_correct}

                if not is_correct:
                    failures.append({
                        'video_id': vid_id,
                        'type': found_type,
                        'question': qns_word[i],
                        'pred': int(preds[i]),
                        'ground_truth': int(targets[i])
                    })

    stats = _aggregate_type_stats(vid_results)

    # Print
    labels, accs = [], []
    print(f"{'Type':<6} {'Acc %':<10} {'Cor':<6} {'Tot':<6}")
    print("-"*35)
    for k in ['d', 'e', 'p', 'pr', 'par', 'c', 'cr', 'car']:
        s = stats[k]
        acc = s['correct']/s['total']*100 if s['total'] > 0 else 0
        print(f"{k.upper():<6} {acc:<10.2f} {s['correct']:<6} {s['total']:<6}")
        # Only types that actually occurred are plotted.
        if s['total'] > 0:
            labels.append(k.upper())
            accs.append(acc)

    # Plot
    plt.figure(figsize=(10, 5))
    bars = plt.bar(labels, accs, color='steelblue')
    plt.ylim(0, 105)
    plt.ylabel('Accuracy (%)')
    plt.title('Performance by Question Type')
    for bar in bars:
        y = bar.get_height()
        plt.text(bar.get_x()+bar.get_width()/2, y+1, f"{y:.1f}", ha='center', va='bottom')
    plt.show()

    # Save Failures
    with open(save_file, 'w') as f:
        json.dump(failures, f, indent=4)
    print(f"\nSaved {len(failures)} failure cases to {save_file}")
    return stats, failures

# Evaluate on the test loader; failure cases go to the default 'failure_cases.json'.
stats, failures = run_detailed_evaluation(model, test_loader, device)

In [None]:
import pandas as pd
import seaborn as sns
from IPython.display import display

print("\n=== Detailed Accuracy Metrics ===")

# The six main question types reported in the table (D, E, P, PR, C, CR).
target_types = ['d', 'e', 'p', 'pr', 'c', 'cr']


def _accuracy_row(type_key):
    """Build one summary-table row from the evaluation counts for a type."""
    counts = stats[type_key]
    total, correct = counts['total'], counts['correct']
    pct = correct / total * 100 if total > 0 else 0
    return {
        'Type': type_key.upper(),
        'Total': total,
        'Correct': correct,
        'Wrong': total - correct,
        'Accuracy (%)': f"{pct:.2f}"
    }


# Convert stats to DataFrame
data = [_accuracy_row(type_key) for type_key in target_types if type_key in stats]
df_results = pd.DataFrame(data)
display(df_results) # Pretty print

# Visualization of Wrong Answers
print("\n=== Failure Distribution (Count of Wrong Answers) ===")
if not df_results.empty:
    plt.figure(figsize=(10, 5))
    # Make sure the failure counts are numeric before plotting.
    df_results['Wrong'] = pd.to_numeric(df_results['Wrong'])
    sns.barplot(x='Type', y='Wrong', data=df_results, palette='Reds')
    plt.title("Number of Incorrect Answers by Question Type")
    plt.ylabel("Count of Failures")
    plt.show()