# TranSTR CausalVid - Training Notebook

In [None]:
import os
REPO_URL = "https://github.com/DanielQH07/tranSTR_Casual.git"
REPO_NAME = "tranSTR_Casual"
if not os.path.exists(REPO_NAME):
    !git clone {REPO_URL}
os.chdir(os.path.join(REPO_NAME, "causalvid") if os.path.exists(os.path.join(REPO_NAME, "causalvid")) else REPO_NAME)
print(f"CWD: {os.getcwd()}")

In [None]:
# === PATCH DataLoader.py to fix frame/object dimension mismatch ===
# Regenerates DataLoader.py so every sample has exactly ``topK_frame`` frame
# features and ``topK_frame x obj_num`` object features, whatever was extracted
# on disk. The generated module must stay valid on Python < 3.12 (the Kaggle
# default), so no f-string in it may reuse its own quote character inside a
# replacement field (that nesting is only legal since PEP 701 / Python 3.12).
DATALOADER_CODE = '''
import torch
import os
import re
import json
import pandas as pd
import pickle as pkl
import os.path as osp
import numpy as np
from torch.utils.data import Dataset
from utils.util import transform_bb

# Per-object feature width: 2048-d region feature + 5-d transformed bbox.
OBJ_FEAT_DIM = 2053


class VideoQADataset(Dataset):
    """Multiple-choice CausalVid VideoQA samples for one split.

    Items are ``(frame_feat, obj_feat, question, ans_word, ans_id, qns_key)``
    where ``frame_feat`` has ``topK_frame`` rows and ``obj_feat`` has shape
    ``[topK_frame, obj_num, OBJ_FEAT_DIM]``.
    """

    def __init__(self, split, n_query=5, obj_num=10, sample_list_path="/anno",
         video_feature_path="/vit", object_feature_path="/obj", split_dir=None, topK_frame=16):
        super().__init__()
        self.split = split
        self.mc = n_query
        self.obj_num = obj_num
        self.video_feature_path = video_feature_path
        self.object_feature_path = object_feature_path
        self.topK_frame = topK_frame

        # Video ids come from <split_dir>/<split>.txt ("val" is stored as
        # valid.txt); fall back to every sub-directory of sample_list_path.
        valid_vids = set()
        if split_dir:
            txt_name = "valid" if split == "val" else split
            txt_path = osp.join(split_dir, f"{txt_name}.txt")
            if osp.exists(txt_path):
                with open(txt_path) as f:
                    valid_vids = {l.strip() for l in f if l.strip()}

        if not valid_vids and os.path.isdir(sample_list_path):
            valid_vids = {d for d in os.listdir(sample_list_path) if os.path.isdir(osp.join(sample_list_path, d))}

        data_rows = []
        n_bad = 0
        # sorted(): set order depends on per-process string hashing, which
        # would make the row order (and seeded shuffling) differ between runs.
        for vid in sorted(valid_vids):
            vid_path = osp.join(sample_list_path, vid)
            t_json, a_json = osp.join(vid_path, "text.json"), osp.join(vid_path, "answer.json")
            if not (osp.exists(t_json) and osp.exists(a_json)): continue
            try:
                with open(t_json, encoding="utf-8") as f: t_data = json.load(f)
                with open(a_json, encoding="utf-8") as f: a_data = json.load(f)
                for key in ["descriptive", "explanatory", "predictive", "counterfactual"]:
                    if key in t_data and key in a_data:
                        q, a = t_data[key], a_data[key]
                        if "question" in q and "answer" in q and "answer" in a:
                            row = {"video_id": vid, "question": q["question"], "answer": a["answer"], "type": key}
                            for i, c in enumerate(q["answer"]): row[f"a{i}"] = c
                            data_rows.append(row)
                        if key in ["predictive", "counterfactual"] and "reason" in q and "reason" in a:
                            row = {"video_id": vid, "question": "Why?", "answer": a["reason"], "type": f"{key}_reason"}
                            for i, c in enumerate(q["reason"]): row[f"a{i}"] = c
                            data_rows.append(row)
            except Exception:
                # Best-effort: a malformed annotation skips the video, but it is
                # counted and reported instead of silently disappearing.
                n_bad += 1
        if n_bad:
            print(f"Skipped {n_bad} videos with unreadable annotations.")

        self.sample_list = pd.DataFrame(data_rows)
        print(f"Loaded {len(self.sample_list)} QA pairs.")

        # Keep only QA pairs whose frame-feature file exists for this split.
        if len(self.sample_list) > 0:
            existing = {v for v in self.sample_list["video_id"].unique()
                        if osp.exists(osp.join(self.video_feature_path, self.split, f"{v}.pt"))}
            self.sample_list = self.sample_list[self.sample_list["video_id"].isin(existing)]
            print(f"Final: {len(self.sample_list)}")

    def __getitem__(self, idx):
        cur = self.sample_list.iloc[idx]
        vid = str(cur["video_id"])
        qns = str(cur["question"])
        ans_id = int(cur["answer"])
        # Look the candidates up first so the f-string below never nests
        # double quotes (a SyntaxError before Python 3.12).
        choices = [cur[f"a{i}"] for i in range(self.mc)]
        ans_word = [f"[CLS] {qns} [SEP] {c}" for c in choices]

        # 1. Frame features: uniformly subsample or zero-pad to topK_frame.
        frame_feat = torch.load(osp.join(self.video_feature_path, self.split, f"{vid}.pt"))
        if isinstance(frame_feat, np.ndarray): frame_feat = torch.from_numpy(frame_feat)
        frame_feat = frame_feat.float()

        nf = frame_feat.shape[0]
        if nf > self.topK_frame:
            sel = np.linspace(0, nf-1, self.topK_frame).astype(int)
            frame_feat = frame_feat[sel]
        elif nf < self.topK_frame:
            pad = torch.zeros((self.topK_frame - nf, frame_feat.shape[1]))
            frame_feat = torch.cat([frame_feat, pad], 0)

        # 2. Object features: topK_frame per-frame .pkl files (sorted by the
        # last number in the file name), each truncated or zero-padded to
        # obj_num objects; unreadable or missing frames become zeros.
        obj_dir = osp.join(self.object_feature_path, vid)
        obj_feats = []
        def num(f): m = re.findall(r"\\d+", f); return int(m[-1]) if m else -1
        if osp.isdir(obj_dir):
            pkls = sorted([f for f in os.listdir(obj_dir) if f.endswith(".pkl") and not f.startswith("._")], key=num)
            npkl = len(pkls)
            idxs = np.linspace(0, npkl-1, self.topK_frame).astype(int) if npkl > 0 else []
            for i in idxs:
                try:
                    with open(osp.join(obj_dir, pkls[i]), "rb") as fp: c = pkl.load(fp)
                    feat = c.get("feat", c.get("features")) if isinstance(c, dict) else c[0]
                    bbox = c.get("bbox", c.get("boxes", c.get("box"))) if isinstance(c, dict) else c[1]
                    w = c.get("img_w", 640) if isinstance(c, dict) else 640
                    h = c.get("img_h", 480) if isinstance(c, dict) else 480
                    if isinstance(feat, np.ndarray): feat = torch.from_numpy(feat)
                    if isinstance(bbox, np.ndarray): bbox = torch.from_numpy(bbox)
                    if feat.shape[0] > self.obj_num: feat, bbox = feat[:self.obj_num], bbox[:self.obj_num]
                    elif feat.shape[0] < self.obj_num:
                        p = self.obj_num - feat.shape[0]
                        feat = torch.cat([feat, torch.zeros(p, feat.shape[1])], 0)
                        bbox = torch.cat([bbox, torch.zeros(p, bbox.shape[1])], 0)
                    bb = torch.from_numpy(transform_bb(bbox.numpy(), w, h)).float()
                    obj_feats.append(torch.cat([feat.float(), bb], -1))
                except Exception:
                    # Not a bare except: KeyboardInterrupt/SystemExit must still
                    # stop DataLoader worker processes.
                    obj_feats.append(torch.zeros(self.obj_num, OBJ_FEAT_DIM))
        while len(obj_feats) < self.topK_frame:
            obj_feats.append(torch.zeros(self.obj_num, OBJ_FEAT_DIM))
        obj_feat = torch.stack(obj_feats)  # [topK, obj, OBJ_FEAT_DIM]

        qns_type = cur["type"]
        qns_key = f"{vid}_{qns_type}"
        return frame_feat, obj_feat, qns, ans_word, ans_id, qns_key

    def __len__(self): return len(self.sample_list)
'''
with open('DataLoader.py', 'w', encoding='utf-8') as f:
    f.write(DATALOADER_CODE)
print('DataLoader.py patched!')

In [None]:
!pip install -q huggingface_hub
from huggingface_hub import notebook_login, HfApi, hf_hub_download, list_repo_tree
notebook_login()

In [None]:
import os, torch, numpy as np, pandas as pd
from torch.utils.data import DataLoader
from utils.util import set_seed, set_gpu_devices
from networks.model import VideoQAmodel
from DataLoader import VideoQADataset
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from huggingface_hub import hf_hub_download, list_repo_tree, HfApi
print('Imports OK')

In [None]:
def train(model, optimizer, loader, xe, device, use_amp=True, scaler=None):
    """Run one training epoch.

    Args:
        model: called as ``model(frame_feat, obj_feat, questions, answers)``;
            the argmax over the last dim of its output is the predicted answer.
        optimizer: optimizer over ``model``'s parameters.
        loader: sized iterable of ``(frame, obj, qns, ans_word, ans_id, key)``
            batches.
        xe: loss applied as ``xe(logits, ans_id)``.
        device: device the feature tensors and targets are moved to.
        use_amp: enable autocast; only takes effect when ``device`` is CUDA.
        scaler: optional ``torch.amp.GradScaler`` wrapping backward/step.

    Returns:
        ``(mean loss per batch, accuracy in percent)``; ``(0.0, 0.0)`` for an
        empty loader instead of raising.
    """
    model.train()
    # Autocast's device type must match where the tensors live; a hard-coded
    # 'cuda' autocast on a CPU run only warns and disables itself.
    device_type = torch.device(device).type
    amp_enabled = use_amp and device_type == 'cuda'
    total_loss, n_batches, n_correct, n_seen = 0.0, 0, 0, 0
    for f, o, q, a, ans_id, _ in loader:
        f, o, tgt = f.to(device), o.to(device), ans_id.to(device)
        with torch.amp.autocast(device_type, enabled=amp_enabled):
            out = model(f, o, q, a)
            loss = xe(out, tgt)
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
        n_batches += 1
        # Count hits on the fly instead of keeping every batch's predictions
        # alive on the device until the end of the epoch.
        n_correct += (out.argmax(-1) == tgt).sum().item()
        n_seen += tgt.numel()
    if n_batches == 0:
        return 0.0, 0.0
    return total_loss / n_batches, n_correct / n_seen * 100

def eval_model(model, loader, device):
    """Return the multiple-choice accuracy (percent) of ``model`` on ``loader``.

    Batches are ``(frame, obj, qns, ans_word, ans_id, key)`` tuples; the
    prediction is the argmax over the last dimension of the model output.
    An empty loader yields ``0.0`` instead of raising from ``torch.cat([])``.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for f, o, q, a, ans_id, _ in loader:
            out = model(f.to(device), o.to(device), q, a)
            # ans_id stays on the CPU, so bring the predictions back to it.
            n_correct += (out.argmax(-1).cpu() == ans_id).sum().item()
            n_seen += ans_id.numel()
    return n_correct / n_seen * 100 if n_seen else 0.0

print('Functions defined')

In [None]:
RUN_TRAINING = True
HF_REPO_ID = "DanielQ07/transtr-causalvid-weights"
HF_MODEL_FILENAME = "best_model.ckpt"
HF_DATASET_ID = "DanielQ07/kltn"
HF_SHARD_FOLDER = "shards"

import tarfile, shutil
BASE = "/kaggle/working" if os.path.exists("/kaggle/working") else os.getcwd()
OBJ_DIR = os.path.join(BASE, "features", "objects")
MODEL_DIR = os.path.join(BASE, "models")
os.makedirs(OBJ_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Python 3.12+ (and the 3.8-3.11 security backports) support tar extraction
# filters; "data" refuses absolute paths, members escaping the destination and
# special files in the downloaded archives. Older interpreters fall back to
# the unfiltered behaviour.
_TAR_EXTRACT_KW = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}


def _is_shard_wrapper(path, name):
    """Return True for a shard/split wrapper folder whose children are video dirs.

    A per-video object folder holds ``.pkl`` files directly, so a folder that
    contains any ``.pkl`` is never flattened, even when its name happens to
    contain "shard", "train" or "val".
    """
    if not os.path.isdir(path) or not any(tag in name for tag in ("shard", "train", "val")):
        return False
    return not any(entry.endswith(".pkl") for entry in os.listdir(path))


# Object features are considered present once OBJ_DIR has any sub-directory.
if not any(os.path.isdir(os.path.join(OBJ_DIR, d)) for d in os.listdir(OBJ_DIR)):
    print('Downloading data...')
    try:
        files = list_repo_tree(HF_DATASET_ID, repo_type='dataset', path_in_repo=HF_SHARD_FOLDER)
        tars = [f.path for f in files if f.path.endswith('.tar.gz')]
        for t in tars:
            print(f'  {os.path.basename(t)}')
            p = hf_hub_download(HF_DATASET_ID, t, repo_type='dataset', local_dir=BASE)
            with tarfile.open(p, 'r:gz') as tf:
                tf.extractall(OBJ_DIR, **_TAR_EXTRACT_KW)
            os.remove(p)  # free disk space as soon as the shard is unpacked
        # Lift video folders out of shard/split wrapper folders into OBJ_DIR.
        for d in os.listdir(OBJ_DIR):
            dp = os.path.join(OBJ_DIR, d)
            if _is_shard_wrapper(dp, d):
                for s in os.listdir(dp):
                    shutil.move(os.path.join(dp, s), os.path.join(OBJ_DIR, s))
                os.rmdir(dp)
    except Exception as e:
        # Best-effort download: report and continue with whatever is on disk.
        print(f'Error: {e}')
print('Data ready')

In [None]:
class Config:
    """Paths and hyper-parameters for this TranSTR CausalVid run.

    Every public class attribute is forwarded to ``VideoQAmodel`` as a keyword
    argument (the model cell builds ``cfg`` from ``Config.__dict__``), so only
    add attributes the model accepts. The path defaults are evaluated once,
    when this class body runs, relative to the current working directory.
    """
    # PATHS - Update these!
    # Root holding <root>/<split>/<video_id>.pt frame features; VideoQADataset
    # drops every QA pair whose .pt file is missing under this root.
    video_feature_root = "/kaggle/input/your-vit-features"  # << SET THIS
    # Per-video folders of per-frame .pkl object features (filled by the download cell).
    object_feature_path = OBJ_DIR
    # Folder of <video_id>/text.json + answer.json annotation pairs.
    sample_list_path = os.path.join(os.getcwd(), '..', 'data', 'vqa', 'causal', 'anno')
    # Folder holding train.txt / valid.txt / test.txt video-id lists.
    split_dir_txt = os.path.join(os.getcwd(), '..', 'data', 'splits')

    # === CRITICAL: Match your actual data! ===
    topK_frame = 16      # Frames per sample: frame features are subsampled / zero-padded to this
    frame_feat_dim = 1024  # Presumably the last dim of the ViT .pt tensors -- confirm against the features
    obj_feat_dim = 2053  # 2048 + 5 (bbox); DataLoader.py zero-pads missing objects to 2053 wide
    objs = 10            # Objects kept per frame (truncated / zero-padded)
    # =========================================
    
    # Attributes without a trailing comment are read in this notebook only via
    # VideoQAmodel(**cfg); their meaning is defined in networks.model (not shown).
    bs = 8  # DataLoader batch size
    lr = 1e-4  # AdamW learning rate, applied to all model parameters
    epoch = 20  # number of training epochs
    gpu = 0  # passed to set_gpu_devices
    dropout = 0.3
    encoder_dropout = 0.3
    patience = 5  # ReduceLROnPlateau patience (epochs without val improvement)
    gamma = 0.1  # ReduceLROnPlateau lr reduction factor
    decay = 1e-4  # AdamW weight decay
    n_query = 5  # answer candidates per question (dataset columns a0..a{n_query-1})
    d_model = 512
    word_dim = 768
    topK_obj = 10
    num_encoder_layers = 2
    num_decoder_layers = 2
    nheads = 8
    normalize_before = True
    activation = 'gelu'
    text_encoder_lr = 1e-5  # NOTE(review): not used by the AdamW optimizer built in this notebook; confirm the model uses it
    freeze_text_encoder = False
    text_encoder_type = 'bert-base-uncased'
    text_pool_mode = 1
    hard_eval = False
    pos_ratio = 1.0
    neg_ratio = 1.0
    a = 1.0
    use_amp = True  # autocast + GradScaler in the training loop
    num_workers = 2  # DataLoader worker processes

args = Config()
set_gpu_devices(args.gpu)
set_seed(999)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# VideoQADataset keeps only videos whose <root>/<split>/<vid>.pt exists, so a
# wrong (e.g. still-placeholder) root silently produces empty splits and a
# confusing failure much later; flag it here instead.
if not os.path.isdir(args.video_feature_root):
    print(f'WARNING: video_feature_root not found: {args.video_feature_root} '
          '- set Config.video_feature_root before creating the datasets.')
print(f'Device: {device}, topK_frame={args.topK_frame}, frame_feat_dim={args.frame_feat_dim}')

In [None]:
print('Creating datasets...')
# Positional arguments shared by every split, in VideoQADataset's order:
# (n_query, obj_num, sample_list_path, video_feature_path,
#  object_feature_path, split_dir, topK_frame).
_ds_args = (
    args.n_query,
    args.objs,
    args.sample_list_path,
    args.video_feature_root,
    args.object_feature_path,
    args.split_dir_txt,
    args.topK_frame,
)
train_ds = VideoQADataset('train', *_ds_args)
val_ds = VideoQADataset('val', *_ds_args)
test_ds = VideoQADataset('test', *_ds_args)

# Only the training loader shuffles.
_loader_kw = {'num_workers': args.num_workers, 'pin_memory': True}
train_loader = DataLoader(train_ds, args.bs, shuffle=True, **_loader_kw)
val_loader = DataLoader(val_ds, args.bs, shuffle=False, **_loader_kw)
test_loader = DataLoader(test_ds, args.bs, shuffle=False, **_loader_kw)
del _ds_args, _loader_kw
print(f'Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}')

In [None]:
# Forward every public Config attribute (plus the runtime device) to the model.
cfg = {name: val for name, val in vars(Config).items() if not name.startswith('_')}
cfg['device'] = device
model = VideoQAmodel(**cfg)

# NOTE(review): one learning rate is used for all parameters here;
# Config.text_encoder_lr is not applied -- confirm whether the model uses it.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.decay,
)
# mode='max': the scheduler is stepped with validation accuracy.
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=args.gamma, patience=args.patience)
model.to(device)
xe = nn.CrossEntropyLoss()
scaler = torch.amp.GradScaler('cuda', enabled=args.use_amp)
print('Model ready')

In [None]:
save_path = os.path.join(MODEL_DIR, HF_MODEL_FILENAME)
best_acc = 0
if RUN_TRAINING:
    print('Training...')
    for ep in range(1, args.epoch + 1):
        loss, acc = train(model, optimizer, train_loader, xe, device, args.use_amp, scaler)
        val_acc = eval_model(model, val_loader, device)
        scheduler.step(val_acc)
        print(f'Ep {ep}: Loss={loss:.4f}, Train={acc:.2f}%, Val={val_acc:.2f}%')
        # Always checkpoint the first epoch: with a strict ">" against 0, a run
        # whose val accuracy never exceeds 0% (e.g. an empty val split) would
        # save nothing, and the upload/eval below would use a missing or stale
        # file from an earlier run.
        if ep == 1 or val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), save_path)
    print(f'Best Val: {best_acc:.2f}%')
    if os.path.exists(save_path):
        try:
            api = HfApi()
            api.create_repo(HF_REPO_ID, repo_type='model', exist_ok=True)
            api.upload_file(save_path, HF_MODEL_FILENAME, HF_REPO_ID, repo_type='model')
            print('Uploaded!')
        except Exception as e:
            # Best-effort upload: training results stay on local disk.
            print(f'Upload failed: {e}')
else:
    print('Skip training')

In [None]:
import matplotlib.pyplot as plt
import json

# An eval-only run (RUN_TRAINING = False) has no local checkpoint, so fetch the
# one the training branch uploads to HF_REPO_ID before scoring; otherwise the
# untrained weights would be evaluated silently.
if not os.path.exists(save_path):
    try:
        hf_hub_download(HF_REPO_ID, HF_MODEL_FILENAME, repo_type='model', local_dir=MODEL_DIR)
    except Exception as e:
        print(f'Checkpoint download failed: {e}')
if os.path.exists(save_path):
    # map_location lets a checkpoint written on GPU load in a CPU-only session.
    model.load_state_dict(torch.load(save_path, map_location=device))
else:
    print(f'WARNING: no checkpoint at {save_path}; evaluating untrained weights.')
model.eval()

# Per-video correctness keyed by short question type; qns_key is "<vid>_<type>".
type_map = {'descriptive':'d','explanatory':'e','predictive':'p','predictive_reason':'pr','counterfactual':'c','counterfactual_reason':'cr'}
results = {}
with torch.no_grad():
    for f, o, q, a, ans_id, keys in test_loader:
        out = model(f.to(device), o.to(device), q, a)
        preds = out.argmax(-1).cpu().numpy()
        for i, k in enumerate(keys):
            for ts, tsh in type_map.items():
                if k.endswith('_' + ts):
                    vid = k[:-(len(ts) + 1)]
                    results.setdefault(vid, {})[tsh] = bool(preds[i] == ans_id[i].item())
                    break

# 'par' / 'car': the answer AND its reason are both correct for the same video.
report_order = ['d', 'e', 'p', 'pr', 'par', 'c', 'cr', 'car']
stats = {k: [0, 0] for k in report_order}  # [correct, total]
for r in results.values():
    for t in ['d', 'e', 'p', 'pr', 'c', 'cr']:
        if t in r:
            stats[t][1] += 1
            stats[t][0] += int(r[t])
    if 'p' in r and 'pr' in r:
        stats['par'][1] += 1
        stats['par'][0] += int(r['p'] and r['pr'])
    if 'c' in r and 'cr' in r:
        stats['car'][1] += 1
        stats['car'][0] += int(r['c'] and r['cr'])
print('Type   Acc%    Cor/Tot')
for k in report_order:
    c, t = stats[k]
    print(f'{k.upper():<6} {c/t*100 if t else 0:<8.2f} {c}/{t}')
shown = [k for k in report_order if stats[k][1]]
labels = [k.upper() for k in shown]
accs = [stats[k][0] / stats[k][1] * 100 for k in shown]
plt.figure(figsize=(10,4)); plt.bar(labels, accs); plt.ylim(0,100); plt.ylabel('Accuracy %'); plt.show()