# TranSTR + Token Mark - Inference & Visualization

**This notebook:**
1. Loads the best model checkpoint from a W&B artifact
2. Runs inference on test samples with and without Token Mark (SoM)
3. Visualizes:
   - 16 uniformly sampled frames from the raw video
   - The frames selected by hard top-K filtering
   - Object bounding boxes (a feature loader is provided; boxes are not drawn yet)
   - Token Mark entity masks (if available)
4. Shows detailed Q&A predictions for both model variants

---

## 🔴 REQUIREMENTS
- Path to the raw videos
- Best model checkpoint logged as a W&B artifact

In [None]:
# ==============================================================================
# CELL 1: Setup & Clone
# ==============================================================================
# Clones the project repository (unless a folder of that name already exists
# in the current directory) and changes into it, so that the project-local
# modules imported in Cell 3 (utils, DataLoader, networks) resolve.
# NOTE: the `!` line is an IPython shell escape; this cell only runs in
# Jupyter/Kaggle, not as a plain Python script.
import os
import sys

REPO_URL = "https://github.com/DanielQH07/tranSTR_Casual.git" 
REPO_NAME = "tranSTR_Casual"
BRANCH = "daniel_setmark"

# NOTE(review): this check is relative to the *current* directory. After the
# chdir below, re-running the cell no longer sees REPO_NAME and clones a
# nested copy — consider using an absolute path.
if not os.path.exists(REPO_NAME):
    print(f"Cloning {REPO_URL}...")
    !git clone {REPO_URL} -b {BRANCH}
else:
    print("Repo already exists.")

# Change Directory
# Prefer <repo>/causalvid; fall back to the repository root if that
# sub-folder does not exist. Skipped when already inside "causalvid".
if os.path.basename(os.getcwd()) != "causalvid":
    target_dir = os.path.join(os.getcwd(), REPO_NAME, "causalvid")
    if os.path.exists(target_dir):
        os.chdir(target_dir)
    elif os.path.exists(REPO_NAME):
        os.chdir(REPO_NAME)
print(f"Working directory: {os.getcwd()}")

In [None]:
# ==============================================================================
# CELL 2: Install & W&B Login
# ==============================================================================
# Installs runtime dependencies (IPython `!` shell escape) and logs in to
# Weights & Biases, which Cell 6 uses to download the model artifact.
!pip install -q wandb decord opencv-python matplotlib seaborn
import wandb

# ============================================
# 🔴 W&B CONFIG
# ============================================
# NOTE(review): do not commit a real key in the notebook; reading it from an
# environment variable or a Kaggle secret would be safer.
WANDB_API_KEY = 'YOUR_WANDB_API_KEY_HERE'  # 🔴 UPDATE
WANDB_PROJECT = 'transtr-causalvid'
WANDB_ENTITY = None  # None: let W&B pick the entity (presumably the user's default) — confirm

# Model artifact to load ("<name>:<alias>"; ':latest' picks the newest version)
ARTIFACT_NAME = 'best-model-som:latest'  # 🔴 UPDATE if needed

# relogin=True forces the key above to be used even if a login is cached.
wandb.login(key=WANDB_API_KEY, relogin=True)
print('‚úÖ W&B logged in!')

In [None]:
# ==============================================================================
# CELL 3: Imports
# ==============================================================================
# Third-party imports followed by project-local modules. The latter only
# resolve after Cell 1 has changed into the repository directory.
import torch
import numpy as np
import pandas as pd
import json
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.gridspec import GridSpec
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm
from einops import rearrange

# decord is optional: USE_DECORD selects the video-reading backend used by
# load_video_frames (decord if importable, otherwise OpenCV).
try:
    from decord import VideoReader, cpu
    USE_DECORD = True
except ImportError:
    USE_DECORD = False
    print('‚ö†Ô∏è decord not available, using OpenCV')

# Project-local modules (from the cloned repository).
from torch.utils.data import DataLoader
from utils.util import set_seed, set_gpu_devices
from DataLoader import VideoQADataset
from networks.model import VideoQAmodel

print('‚úÖ Imports OK')

In [None]:
# ==============================================================================
# CELL 4: Paths Configuration
# ==============================================================================
# Input locations (Kaggle dataset mounts). All of these are copied into
# Config in Cell 5. RAW_VIDEO_PATH is also read directly when loading frames
# for visualization.
print('=== CELL 4: Paths ===')

# ============================================
# 🔴 UPDATE THESE PATHS
# ============================================
VIT_FEATURE_PATH = '/kaggle/input/vit-features-full-merged'  # frame-level (ViT) features
OBJ_FEATURE_PATH = '/kaggle/input/object-detection-causal-full'  # per-video frame{i}.npy object features
ANNOTATION_PATH = '/kaggle/input/text-annotation/QA'  # QA CSVs (descriptive.csv, explanatory.csv, ...)
SPLIT_DIR = '/kaggle/input/casual-vid-data-split/split'  # passed to the dataset as split_dir
SOM_FEATURE_PATH = '/kaggle/input/causal-vqa-object-masks-full/obj_mask_causal_full'  # Token Mark (SoM) masks

# 🔴 RAW VIDEO PATH - for visualization
RAW_VIDEO_PATH = '/kaggle/input/causal-vid-qa-raw-videos/videos'  # 🔴 UPDATE

# Verify paths
def verify_path(name, path):
    """Report whether a data directory exists and how many entries it holds.

    Args:
        name: Human-readable label used in the printed message.
        path: Directory to check.

    Returns:
        True if ``path`` exists and is a directory, False otherwise (the
        reason is printed in either case).
    """
    if not os.path.exists(path):
        print(f'‚ùå {name}: NOT FOUND - {path}')
        return False
    if not os.path.isdir(path):
        # os.listdir would raise NotADirectoryError here; report instead.
        print(f'‚ùå {name}: NOT A DIRECTORY - {path}')
        return False
    # List the directory only once; only the entry count is needed.
    print(f'‚úÖ {name}: {len(os.listdir(path))} items')
    return True

# Check each feature/annotation directory in turn (messages are printed by
# verify_path; the return values of these four are not needed).
for _label, _dir in (
    ('ViT Features', VIT_FEATURE_PATH),
    ('Object Features', OBJ_FEATURE_PATH),
    ('Annotations', ANNOTATION_PATH),
    ('SoM Masks', SOM_FEATURE_PATH),
):
    verify_path(_label, _dir)

# The raw-video check is kept separate because its result drives the warning below.
video_ok = verify_path('Raw Videos', RAW_VIDEO_PATH)
if not video_ok:
    print('\n‚ö†Ô∏è Raw videos not found! Frame visualization will be limited.')

In [None]:
# ==============================================================================
# CELL 5: Config & Device
# ==============================================================================
print('=== CELL 5: Config ===')

class Config:
    """Paths and hyper-parameters for inference/visualization.

    NOTE: Cell 8 forwards every public class attribute of Config to
    VideoQAmodel(**kwargs), so any attribute added here becomes a model
    keyword argument. Architecture values must match the training run that
    produced the checkpoint — TODO confirm against the training config.
    """
    # Paths
    video_feature_root = VIT_FEATURE_PATH
    object_feature_path = OBJ_FEATURE_PATH
    sample_list_path = ANNOTATION_PATH
    split_dir_txt = SPLIT_DIR
    som_feature_path = SOM_FEATURE_PATH
    raw_video_path = RAW_VIDEO_PATH
    
    # Model architecture
    topK_frame = 16      # frames loaded per video by the dataset (the model itself is built with select_frames, see Cell 8)
    objs = 20            # objects per frame; passed to the dataset as obj_num
    frames = 16
    select_frames = 5    # number of frames the model keeps after top-K selection
    topK_obj = 12
    frame_feat_dim = 1024
    obj_feat_dim = 2053  # NOTE(review): presumably appearance features + box info — confirm layout
    d_model = 768
    word_dim = 768
    nheads = 8
    num_encoder_layers = 2
    num_decoder_layers = 2
    normalize_before = True
    activation = 'gelu'
    dropout = 0.3
    encoder_dropout = 0.3
    
    # Token Mark (SoM)
    use_som = True
    num_marks = 16
    
    # Text encoder
    text_encoder_type = 'microsoft/deberta-base'
    freeze_text_encoder = False
    text_encoder_lr = 1e-5  # training-time setting; forwarded to the model kwargs as-is
    text_pool_mode = 1
    
    # Eval
    bs = 1  # Single sample for visualization
    n_query = 5  # passed to the dataset as n_query
    gpu = 0
    hard_eval = True  # Use hard topK for clear visualization
    
    # Forwarded to VideoQAmodel; their meaning is not visible in this notebook.
    pos_ratio = 1.0
    neg_ratio = 1.0
    a = 1.0

args = Config()
set_gpu_devices(args.gpu)
set_seed(999)  # fixed seed for reproducible runs
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

In [None]:
# ==============================================================================
# CELL 6: Download Model from W&B
# ==============================================================================
# Downloads the model artifact and sets CHECKPOINT_PATH (None on failure).
print('=== CELL 6: Download Model ===')

# Initialize temp run to download (also used by Cell 15 for logging)
temp_run = wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    job_type='inference',
    name='inference-visualization',
    reinit=True
)


def _fmt_metric(value, spec='.2f'):
    """Format a numeric metadata value; non-numbers (e.g. 'N/A', None) fall back to str()."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


print(f'Downloading artifact: {ARTIFACT_NAME}')
CHECKPOINT_PATH = None
try:
    artifact = temp_run.use_artifact(ARTIFACT_NAME, type='model')
    artifact_dir = artifact.download()

    # Sorted so that the chosen checkpoint does not depend on listdir order.
    ckpt_files = sorted(
        f for f in os.listdir(artifact_dir) if f.endswith(('.ckpt', '.pt'))
    )
    if not ckpt_files:
        raise FileNotFoundError("No checkpoint file found")
    CHECKPOINT_PATH = os.path.join(artifact_dir, ckpt_files[0])
    print(f'‚úÖ Checkpoint: {CHECKPOINT_PATH}')
except Exception as e:
    print(f'‚ùå Error: {e}')
    CHECKPOINT_PATH = None
else:
    # Metadata is informational only. It is printed outside the try so that
    # missing/non-numeric values can never discard a valid checkpoint path.
    if artifact.metadata:
        print(f"   Epoch: {artifact.metadata.get('epoch', 'N/A')}")
        print(f"   Val Acc: {_fmt_metric(artifact.metadata.get('val_acc', 'N/A'))}%")

In [None]:
# ==============================================================================
# CELL 7: Create Test Dataset
# ==============================================================================
print('=== CELL 7: Test Dataset ===')

def collate_fn_som(batch):
    """Collate dataset items, keeping the per-sample SoM payload as a plain list.

    Items are indexed positionally as (frame features, object features, qns,
    ans, answer id, question key, SoM data). The two feature fields are
    stacked into tensors and answer ids become a tensor. Every other field is
    returned as a list in batch order.
    """
    def _column(pos):
        return [item[pos] for item in batch]

    frame_feats = torch.stack(_column(0))
    obj_feats = torch.stack(_column(1))
    questions = _column(2)
    answers = _column(3)
    answer_ids = torch.tensor(_column(4))
    keys = _column(5)
    som_items = _column(6)
    return frame_feats, obj_feats, questions, answers, answer_ids, keys, som_items

# Dataset arguments, gathered in one place. max_samples=100 limits how much
# is loaded (for faster loading).
_test_ds_kwargs = dict(
    split='test',
    n_query=args.n_query,
    obj_num=args.objs,
    sample_list_path=args.sample_list_path,
    video_feature_path=args.video_feature_root,
    object_feature_path=args.object_feature_path,
    split_dir=args.split_dir_txt,
    topK_frame=args.topK_frame,
    max_samples=100,
    verbose=True,
    som_feature_path=args.som_feature_path,
)
test_ds = VideoQADataset(**_test_ds_kwargs)

# One sample per batch, in dataset order, loaded in the main process.
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn_som,
)

print(f'Test samples: {len(test_ds)}')

In [None]:
# ==============================================================================
# CELL 8: Load Models (with and without SoM)
# ==============================================================================
print('=== CELL 8: Load Models ===')

# Model kwargs = every public Config attribute, plus inference overrides.
cfg_som = {k: v for k, v in Config.__dict__.items() if not k.startswith('_')}
cfg_som['device'] = device
cfg_som['topK_frame'] = args.select_frames  # model keeps select_frames of the loaded frames
cfg_som['use_som'] = True
cfg_som['num_marks'] = args.num_marks
cfg_som['hard_eval'] = True  # For clear visualization

model_som = VideoQAmodel(**cfg_som)
model_som.to(device)

# Same configuration, Token Mark disabled.
cfg_no_som = cfg_som.copy()
cfg_no_som['use_som'] = False

model_no_som = VideoQAmodel(**cfg_no_som)
model_no_som.to(device)

# Load weights
if CHECKPOINT_PATH and os.path.exists(CHECKPOINT_PATH):
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)

    # Load into SoM model (strict: architecture must match the checkpoint)
    model_som.load_state_dict(state_dict)
    print('‚úÖ Model WITH SoM loaded')

    # Load into non-SoM model (ignore som_injector keys)
    filtered_state = {k: v for k, v in state_dict.items() if 'som_injector' not in k}
    load_report = model_no_som.load_state_dict(filtered_state, strict=False)
    print('‚úÖ Model WITHOUT SoM loaded (som_injector ignored)')
    # strict=False tolerates mismatches silently. Report them: any missing
    # parameter keeps its random initialization.
    if load_report.missing_keys:
        print(f'   WARNING: {len(load_report.missing_keys)} missing keys, e.g. {load_report.missing_keys[:10]}')
    if load_report.unexpected_keys:
        print(f'   WARNING: {len(load_report.unexpected_keys)} unexpected keys, e.g. {load_report.unexpected_keys[:10]}')
else:
    # Without this both models would be evaluated with random weights, silently.
    print(f'WARNING: no checkpoint loaded (CHECKPOINT_PATH={CHECKPOINT_PATH!r}); '
          'both models use RANDOM weights')

model_som.eval()
model_no_som.eval()

print(f'\nTotal params: {sum(p.numel() for p in model_som.parameters())/1e6:.1f}M')

In [None]:
# ==============================================================================
# CELL 9: Video Frame Extraction Utilities
# ==============================================================================
print('=== CELL 9: Video Utils ===')

def load_video_frames(video_path, num_frames=16):
    """Load ``num_frames`` uniformly spaced frames from a video file.

    Uses decord when ``USE_DECORD`` is set, otherwise OpenCV. The OpenCV path
    converts BGR to RGB.

    Args:
        video_path: Path to the video file (``None``/empty is treated as missing).
        num_frames: Number of frames to sample, evenly spaced from the first
            to the last frame.

    Returns:
        ``np.ndarray`` of shape ``[N, H, W, C]``, or ``None`` if the file is
        missing, cannot be opened, or reports no frames. With OpenCV, N may be
        smaller than ``num_frames`` when individual seeks/reads fail.
    """
    if not video_path or not os.path.exists(video_path):
        return None

    if USE_DECORD:
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(vr)
        if total_frames <= 0:
            # linspace(0, -1, ...) would yield negative indices.
            return None
        indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
        return vr.get_batch(indices).asnumpy()  # [N, H, W, C]

    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            return None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            # Some containers report 0/negative counts; avoid bogus seeks.
            return None
        indices = np.linspace(0, total_frames - 1, num_frames).astype(int)
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if ret:
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture handle even if decoding raises.
        cap.release()
    return np.array(frames) if frames else None

def find_video_file(video_id, video_dir):
    """Locate the raw video file for ``video_id`` inside ``video_dir``.

    Common extensions are tried first. Otherwise the first file (in sorted
    order) named ``<video_id>.<anything>`` is returned. The ``.`` must follow
    the id directly, so id ``12`` never resolves to ``123.mp4``.

    Args:
        video_id: Video identifier (str or int).
        video_dir: Directory containing the raw videos.

    Returns:
        Full path of the matching file, or ``None`` if ``video_dir`` is not a
        directory or no file matches.
    """
    if not os.path.isdir(video_dir):
        return None

    video_id = str(video_id)
    for ext in ('.mp4', '.avi', '.mkv', '.webm', '.mov'):
        path = os.path.join(video_dir, f"{video_id}{ext}")
        if os.path.exists(path):
            return path

    # Fallback for other extensions. Sorted so the choice is deterministic.
    prefix = video_id + '.'
    for fname in sorted(os.listdir(video_dir)):
        if fname.startswith(prefix):
            return os.path.join(video_dir, fname)

    return None

def load_object_boxes(video_id, obj_feature_path, frame_idx=0, num_frames=16):
    """Load the per-frame object feature arrays of one video.

    Reads ``<obj_feature_path>/<video_id>/frame{i}.npy`` for ``i`` in
    ``range(num_frames)`` and returns the arrays unchanged. Despite the name,
    no box columns are extracted. NOTE(review): some trailing columns may
    encode box coordinates; confirm the layout before slicing them out.

    Args:
        video_id: Sub-directory name of the video (str or int).
        obj_feature_path: Root directory of the object features.
        frame_idx: Unused; kept for backward compatibility.
        num_frames: Number of frame files to look for (default 16, the frame
            count used throughout this notebook).

    Returns:
        List of length ``num_frames``. Each entry is the loaded array, or
        ``None`` if that frame's file does not exist.
    """
    video_dir = os.path.join(obj_feature_path, str(video_id))
    boxes_list = []
    for fidx in range(num_frames):
        feat_path = os.path.join(video_dir, f'frame{fidx}.npy')
        boxes_list.append(np.load(feat_path) if os.path.exists(feat_path) else None)
    return boxes_list

print('‚úÖ Video utilities defined')

In [None]:
# ==============================================================================
# CELL 10: Inference with Attention Extraction
# ==============================================================================
print('=== CELL 10: Inference Functions ===')

def inference_with_attention(model, ff, of, qns, ans, som_data, device, use_som=True):
    """
    Run inference and extract frame/object selection indices.
    Returns prediction and selection info.

    The frame/object selections are recomputed by calling the model's
    sub-modules by hand (frame_resize, forward_text, frame_decoder,
    obj_resize, som_injector, obj_decoder) with HardtopK. The prediction comes
    from a separate full ``model(...)`` call. NOTE(review): the visualized
    selection matches what the model actually used only if this manual path
    mirrors VideoQAmodel.forward — confirm.

    Args:
        model: VideoQAmodel (SoM or non-SoM variant); switched to eval mode.
        ff: frame-feature batch (moved to ``device``).
        of: object-feature batch, unpacked below as (B, F, O, ...).
        qns, ans: question / answer batches as produced by collate_fn_som.
        som_data: list of per-sample SoM entries; ``[None]`` disables SoM.
        device: torch device to run on.
        use_som: apply the SoM injector and pass ``som_data`` to the model.

    Returns:
        dict with
          'pred'             : argmax answer index (uses ``.item()``, so B must be 1),
          'probs'            : softmax of the answer logits for sample 0 (numpy),
          'selected_frames'  : for each of ``model.frame_topK`` slots, the
                               original frame index with the largest weight,
          'frame_weights'    : summed selection weight per original frame,
          'selected_objects' : per selected frame, the top ``model.obj_topK``
                               object indices in descending weight order.
    """
    model.eval()
    # Batch size, number of frames, objects per frame.
    B, F, O = of.size()[:3]
    
    with torch.no_grad():
        ff = ff.to(device)
        of = of.to(device)
        
        # Manual forward to extract indices
        frame_feat = model.frame_resize(ff)
        q_local, q_mask = model.forward_text(list(qns), device)
        
        # All F frames are valid; here the mask only feeds the positional encoding.
        frame_mask = torch.ones(B, F).bool().to(device)
        frame_local, frame_att = model.frame_decoder(
            frame_feat, q_local,
            memory_key_padding_mask=q_mask,
            query_pos=model.pos_encoder_1d(frame_mask, model.d_model),
            output_attentions=True
        )
        
        # Get frame selection indices
        # NOTE(review): HardtopK presumably returns one-hot selection columns;
        # summing over the query axis q yields a frame selector — confirm.
        from networks.topk import HardtopK
        idx_frame = rearrange(
            HardtopK(frame_att.flatten(1,2), model.frame_topK), 
            'b (f q) k -> b f q k', f=F
        ).sum(-2)  # [B, F, frame_topK]
        
        # Selected frame indices (which original frames were chosen)
        selected_frame_indices = idx_frame[0].argmax(dim=0).cpu().numpy()  # [frame_topK]
        frame_weights = idx_frame[0].sum(dim=1).cpu().numpy()  # [F] total weight per frame
        
        # Gather the selected frames' features via the selector matrix.
        frame_local = (frame_local.transpose(1,2) @ idx_frame).transpose(1,2)
        
        # Object processing: keep only objects of the selected frames.
        obj_feat = (of.flatten(-2,-1).transpose(1,2) @ idx_frame).transpose(1,2)
        obj_feat = obj_feat.view(B, model.frame_topK, O, -1)
        obj_local = model.obj_resize(obj_feat)
        
        # Apply SoM if enabled
        if use_som and hasattr(model, 'som_injector') and som_data[0] is not None:
            frame_local, obj_local = model.som_injector(
                frame_local, obj_local, som_data, idx_frame=idx_frame
            )
        
        # Object selection: each selected frame is decoded as its own batch
        # row, so the question tensors are repeated frame_topK times.
        q_local_rep = q_local.repeat_interleave(model.frame_topK, dim=0)
        q_mask_rep = q_mask.repeat_interleave(model.frame_topK, dim=0) if q_mask is not None else None
        
        obj_local_flat, obj_att = model.obj_decoder(
            obj_local.flatten(0,1), q_local_rep,
            memory_key_padding_mask=q_mask_rep,
            output_attentions=True
        )
        
        idx_obj = rearrange(
            HardtopK(obj_att.flatten(1,2), model.obj_topK),
            'b (o q) k -> b o q k', o=O
        ).sum(-2)
        
        # Selected object indices per frame
        # (idx_obj's first axis is B * frame_topK; indexing by f_idx is only
        # per-frame-correct when B == 1)
        selected_obj_indices = []
        for f_idx in range(model.frame_topK):
            obj_w = idx_obj[f_idx].sum(dim=1).cpu().numpy()  # [O]
            top_objs = np.argsort(obj_w)[-model.obj_topK:][::-1]
            selected_obj_indices.append(top_objs)
        
        # Full forward for prediction
        if use_som and som_data[0] is not None:
            out = model(ff, of, qns, ans, som_data=som_data)
        else:
            # Need to handle differently for no-som model
            out = model(ff, of, qns, ans)
        
        pred = out.argmax(-1).item()
        probs = torch.softmax(out, dim=-1)[0].cpu().numpy()
        
    return {
        'pred': pred,
        'probs': probs,
        'selected_frames': selected_frame_indices,
        'frame_weights': frame_weights,
        'selected_objects': selected_obj_indices,
    }

print('‚úÖ Inference function defined')

In [None]:
# ==============================================================================
# CELL 11: Visualization Functions
# ==============================================================================
print('=== CELL 11: Visualization ===')

def visualize_sample(sample_data, result_som, result_no_som, video_frames=None, som_masks=None):
    """
    Create a comprehensive comparison figure for one sample.

    Layout (4 x 8 grid):
      rows 1-2 : up to 16 raw video frames; frames chosen by the SoM model's
                 top-K selection are outlined in lime
      row 3    : question/answers (left), per-option probabilities of both
                 models (right)
      row 4    : per-frame selection weight of the SoM model (left), textual
                 results summary (right)

    Args:
        sample_data: dict with 'qns_key', 'question', 'answers', 'correct_ans'.
        result_som, result_no_som: dicts returned by inference_with_attention.
        video_frames: optional sequence of frames accepted by ``imshow``.
        som_masks: accepted for interface compatibility; not drawn here.

    Returns:
        The matplotlib Figure (the caller saves/closes it).
    """
    qns_key = sample_data['qns_key']
    question = sample_data['question']
    answers = sample_data['answers']
    correct_ans = sample_data['correct_ans']
    selected_som = {int(f) for f in np.asarray(result_som['selected_frames']).ravel()}

    # Create figure. 8 columns give one cell per frame in the two frame rows;
    # the lower panels each span 4 columns (half the width).
    fig = plt.figure(figsize=(20, 16))
    gs = GridSpec(4, 8, figure=fig, hspace=0.3, wspace=0.2)

    # Title
    fig.suptitle(f"Sample: {qns_key}", fontsize=16, fontweight='bold')

    # ============================================
    # Rows 1-2: 16 Sampled Frames
    # ============================================
    if video_frames is not None:
        for i in range(min(16, len(video_frames))):
            ax = fig.add_subplot(gs[i // 8, i % 8])
            ax.imshow(video_frames[i])

            # Highlight selected frames
            if i in selected_som:
                # axis('off') would also hide the spines, so remove only the
                # ticks and draw the highlight on the spines.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_edgecolor('lime')
                    spine.set_linewidth(4)
                ax.set_title(f'F{i}‚úì', fontsize=8, color='lime')
            else:
                ax.set_title(f'F{i}', fontsize=8)
                ax.axis('off')
    else:
        ax = fig.add_subplot(gs[0:2, :])
        ax.text(0.5, 0.5, 'Video frames not available',
                ha='center', va='center', fontsize=14)
        ax.axis('off')

    # ============================================
    # Row 3: Q&A Info and Comparison
    # ============================================
    ax_qa = fig.add_subplot(gs[2, :4])
    ax_qa.axis('off')

    qa_text = f"""QUESTION:
{question}

ANSWERS:
"""
    for i, ans in enumerate(answers):
        marker = ''
        if i == correct_ans:
            marker = ' ‚úì (correct)'
        if i == result_som['pred']:
            marker += ' ‚Üê SoM pred'
        if i == result_no_som['pred']:
            marker += ' ‚Üê No-SoM pred'
        qa_text += f"  [{i}] {ans}{marker}\n"

    ax_qa.text(0.02, 0.98, qa_text, transform=ax_qa.transAxes,
               fontsize=10, verticalalignment='top', fontfamily='monospace',
               bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # ============================================
    # Row 3: Prediction Comparison
    # ============================================
    ax_comp = fig.add_subplot(gs[2, 4:])

    probs_som = np.asarray(result_som['probs'])
    probs_no_som = np.asarray(result_no_som['probs'])
    n_opts = len(probs_som)  # number of answer options (5 in CausalVid-QA)
    x = np.arange(n_opts)
    width = 0.35

    ax_comp.bar(x - width/2, probs_som, width, label='With SoM', color='green', alpha=0.7)
    ax_comp.bar(x + width/2, probs_no_som, width, label='Without SoM', color='red', alpha=0.7)

    ax_comp.set_ylabel('Probability')
    ax_comp.set_xlabel('Answer Option')
    ax_comp.set_title('Prediction Probabilities')
    ax_comp.set_xticks(x)
    tick_labels = ax_comp.set_xticklabels([f'A{i}' for i in range(n_opts)])
    ax_comp.legend()
    if n_opts:
        # Uniform-chance baseline (0.2 for five options).
        ax_comp.axhline(y=1.0 / n_opts, color='gray', linestyle='--', alpha=0.5)

    # Highlight correct answer (guarded: an out-of-range id must not crash the plot)
    if 0 <= correct_ans < n_opts:
        tick_labels[correct_ans].set_color('blue')
        tick_labels[correct_ans].set_fontweight('bold')

    # ============================================
    # Row 4: Selected Frames Detail
    # ============================================
    ax_sel = fig.add_subplot(gs[3, :4])

    # Frame selection weights (length = number of feature frames)
    frame_w = np.asarray(result_som['frame_weights'])
    n_frames = len(frame_w)
    colors = ['lime' if i in selected_som else 'gray' for i in range(n_frames)]
    ax_sel.bar(range(n_frames), frame_w, color=colors)
    ax_sel.set_xlabel('Frame Index')
    ax_sel.set_ylabel('Attention Weight')
    ax_sel.set_title(f'Frame Selection (Top {len(result_som["selected_frames"])} selected in green)')
    ax_sel.set_xticks(range(n_frames))

    # ============================================
    # Row 4: Summary
    # ============================================
    ax_sum = fig.add_subplot(gs[3, 4:])
    ax_sum.axis('off')

    som_correct = result_som['pred'] == correct_ans
    no_som_correct = result_no_som['pred'] == correct_ans

    summary = f"""RESULTS SUMMARY
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
With Token Mark (SoM):
  Prediction: {result_som['pred']} {'‚úÖ CORRECT' if som_correct else '‚ùå WRONG'}
  Confidence: {result_som['probs'][result_som['pred']]*100:.1f}%
  Selected Frames: {list(result_som['selected_frames'])}

Without Token Mark:
  Prediction: {result_no_som['pred']} {'‚úÖ CORRECT' if no_som_correct else '‚ùå WRONG'}
  Confidence: {result_no_som['probs'][result_no_som['pred']]*100:.1f}%
  Selected Frames: {list(result_no_som['selected_frames'])}
‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ
Correct Answer: {correct_ans}
"""

    ax_sum.text(0.02, 0.98, summary, transform=ax_sum.transAxes,
               fontsize=10, verticalalignment='top', fontfamily='monospace',
               bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))

    fig.tight_layout()
    return fig

print('‚úÖ Visualization function defined')

In [None]:
# ==============================================================================
# CELL 12: Run Inference on Selected Samples
# ==============================================================================
print('=== CELL 12: Run Inference ===')

# One annotation CSV per question family (source of the question/answer text).
annotation_files = {
    family: os.path.join(ANNOTATION_PATH, f'{family}.csv')
    for family in ('descriptive', 'explanatory', 'predictive', 'counterfactual')
}

# Load the CSVs that exist; missing families are skipped without a message.
annotations = {}
for qtype, path in annotation_files.items():
    if not os.path.exists(path):
        continue
    annotations[qtype] = pd.read_csv(path)
    print(f'Loaded {qtype}: {len(annotations[qtype])} samples')

def get_sample_info(video_id, qtype, annotations_by_type=None):
    """Look up the question text and answer options for a video.

    Args:
        video_id: Video identifier parsed from the question key. Numeric ids
            are first matched against the ``video_id`` column as integers. If
            that finds nothing, or the id is not numeric, both sides are
            compared as strings.
        qtype: Question type; ``*_reason`` variants map to their base CSV.
        annotations_by_type: Optional ``{csv_name: DataFrame}``. Defaults to
            the module-level ``annotations`` dict.

    Returns:
        ``{'question', 'answers', 'correct'}`` built from the first matching
        row, or ``None`` if the question type or the video is not found.
    """
    if annotations_by_type is None:
        annotations_by_type = annotations

    qtype_map = {
        'descriptive': 'descriptive',
        'explanatory': 'explanatory',
        'predictive': 'predictive',
        'predictive_reason': 'predictive',
        'counterfactual': 'counterfactual',
        'counterfactual_reason': 'counterfactual'
    }

    df = annotations_by_type.get(qtype_map.get(qtype, qtype))
    if df is None:
        return None

    ids = df['video_id']
    try:
        mask = ids == int(video_id)
    except (TypeError, ValueError):
        # Non-numeric id: int() would raise and abort the caller's loop.
        mask = None
    if mask is None or not mask.any():
        # Also covers string-typed id columns, which never equal an int.
        mask = ids.astype(str) == str(video_id)

    rows = df[mask]
    if rows.empty:
        return None
    row = rows.iloc[0]
    return {
        'question': row.get('question', 'N/A'),
        'answers': [row.get(f'a{i}', f'Option {i}') for i in range(5)],
        'correct': row.get('answer', 0)
    }

print('\n‚úÖ Annotation data loaded')

In [None]:
# ==============================================================================
# CELL 13: Process Multiple Samples
# ==============================================================================
print('=== CELL 13: Process Samples ===')

NUM_SAMPLES = 5  # Number of samples to visualize
results_list = []

# Question types whose own name contains '_'. They must be matched as whole
# suffixes: a plain rsplit('_', 1) would turn '<vid>_predictive_reason' into
# video_id='<vid>_predictive', qtype='reason'.
_COMPOUND_QTYPES = ('predictive_reason', 'counterfactual_reason')


def _parse_qns_key(key):
    """Split a '<video_id>_<qtype>' question key into (video_id, qtype).

    Keys without any '_' yield (key, 'unknown').
    """
    for compound in _COMPOUND_QTYPES:
        suffix = '_' + compound
        if key.endswith(suffix) and len(key) > len(suffix):
            return key[:-len(suffix)], compound
    head, sep, tail = key.rpartition('_')
    return (head, tail) if sep else (key, 'unknown')


for idx, batch in enumerate(tqdm(test_loader, total=NUM_SAMPLES)):
    if idx >= NUM_SAMPLES:
        break

    ff, of, qns, ans, ans_id, qns_key, som_data = batch
    qns_key = qns_key[0]
    correct = ans_id[0].item()

    # Parse video_id and question type
    video_id, qtype = _parse_qns_key(qns_key)

    print(f'\n--- Sample {idx+1}: {qns_key} ---')
    print(f'Video ID: {video_id}, Type: {qtype}')

    # Prefer the annotation CSV text; fall back to what the dataset returned.
    qa_info = get_sample_info(video_id, qtype)
    if qa_info:
        question = qa_info['question']
        answers = qa_info['answers']
    else:
        question = qns[0]
        answers = list(ans[0]) if isinstance(ans[0], (list, tuple)) else ['N/A'] * 5
    # CSV cells can be NaN (a float); the printing below uses len()/slicing.
    question = str(question)

    # Run inference with SoM
    result_som = inference_with_attention(
        model_som, ff, of, qns, ans, som_data, device, use_som=True
    )

    # Run inference without SoM (no SoM data at all)
    result_no_som = inference_with_attention(
        model_no_som, ff, of, qns, ans, [None], device, use_som=False
    )

    # Load video frames; a corrupt/unreadable file must not abort the loop.
    video_path = find_video_file(video_id, RAW_VIDEO_PATH)
    try:
        video_frames = load_video_frames(video_path) if video_path else None
    except Exception as err:
        print(f'Could not read video {video_path}: {err}')
        video_frames = None

    # Get SoM masks if available
    som_entry = som_data[0]
    som_masks = som_entry.get('frame_masks', {}) if som_entry else None
    entity_names = som_entry.get('entity_names', {}) if som_entry else None

    # Store results
    sample_data = {
        'idx': idx,
        'qns_key': qns_key,
        'video_id': video_id,
        'qtype': qtype,
        'question': question,
        'answers': answers,
        'correct_ans': correct,
        'entity_names': entity_names,
    }

    results_list.append({
        'sample_data': sample_data,
        'result_som': result_som,
        'result_no_som': result_no_som,
        'video_frames': video_frames,
        'som_masks': som_masks
    })

    # Print quick summary
    print(f'Question: {question[:80]}...' if len(question) > 80 else f'Question: {question}')
    print(f'Correct: {correct}, SoM pred: {result_som["pred"]}, No-SoM pred: {result_no_som["pred"]}')
    print(f'SoM correct: {result_som["pred"] == correct}, No-SoM correct: {result_no_som["pred"] == correct}')
    if entity_names:
        print(f'Entities: {entity_names}')

print(f'\n‚úÖ Processed {len(results_list)} samples')

In [None]:
# ==============================================================================
# CELL 14: Visualize All Samples
# ==============================================================================
# Render, save (inference_sample_<n>.png, 1-based) and display one figure per
# processed sample.
print('=== CELL 14: Visualizations ===')

n_to_plot = len(results_list)
for sample_no, res in enumerate(results_list, start=1):
    info = res["sample_data"]
    print(f'\nüìä Visualizing sample {sample_no}/{n_to_plot}: {info["qns_key"]}')

    fig = visualize_sample(
        info,
        res['result_som'],
        res['result_no_som'],
        res['video_frames'],
        res['som_masks'],
    )
    fig.savefig(f'inference_sample_{sample_no}.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # free the figure; many open figures accumulate memory

print('\n‚úÖ All visualizations complete!')

In [None]:
# ==============================================================================
# CELL 15: Summary Statistics
# ==============================================================================
print('=== CELL 15: Summary ===')


def _is_correct(res, key):
    """Return True if ``res[key]['pred']`` equals the sample's gold answer index."""
    return res[key]['pred'] == res['sample_data']['correct_ans']


def _pct(count, n):
    """Return count/n as a percentage, or 0.0 when n is 0 (no samples processed)."""
    return count / n * 100 if n else 0.0


som_correct = sum(1 for r in results_list if _is_correct(r, 'result_som'))
no_som_correct = sum(1 for r in results_list if _is_correct(r, 'result_no_som'))
total = len(results_list)

print('\n' + '='*60)
print('INFERENCE SUMMARY')
print('='*60)
print(f'Total samples: {total}')
print(f'\nWith Token Mark (SoM):')
print(f'  Correct: {som_correct}/{total} ({_pct(som_correct, total):.1f}%)')
print(f'\nWithout Token Mark:')
print(f'  Correct: {no_som_correct}/{total} ({_pct(no_som_correct, total):.1f}%)')
print('='*60)

# Detailed breakdown
print('\nPer-sample breakdown:')
print('-'*60)
for i, res in enumerate(results_list):
    som_ok = '‚úÖ' if _is_correct(res, 'result_som') else '‚ùå'
    no_som_ok = '‚úÖ' if _is_correct(res, 'result_no_som') else '‚ùå'
    print(f"{i+1}. {res['sample_data']['qns_key'][:30]:<30} SoM:{som_ok} No-SoM:{no_som_ok}")

# Log to W&B (into the run started in Cell 6), then close that run.
wandb.log({
    'inference/som_accuracy': _pct(som_correct, total),
    'inference/no_som_accuracy': _pct(no_som_correct, total),
    'inference/samples': total
})

wandb.finish()
print('\n‚úÖ Done!')

In [None]:
# ==============================================================================
# CELL 16: Display 16 Frames with Entity Masks (if available)
# ==============================================================================
print('=== CELL 16: Entity Mask Visualization ===')


def _mask_to_numpy(mask):
    """Return *mask* as a numpy array; torch tensors (CPU or GPU) are copied to host."""
    if hasattr(mask, 'detach'):
        return mask.detach().cpu().numpy()
    return np.asarray(mask)


# Select the first sample that has both SoM masks and decoded video frames
sample_with_som = next(
    (res for res in results_list if res['som_masks'] and res['video_frames'] is not None),
    None,
)

if sample_with_som is not None:
    print(f"Visualizing entity masks for: {sample_with_som['sample_data']['qns_key']}")

    frames = sample_with_som['video_frames']
    masks = sample_with_som['som_masks']
    entities = sample_with_som['sample_data'].get('entity_names', {})
    selected_frames = {
        int(f) for f in np.asarray(sample_with_som['result_som']['selected_frames']).ravel()
    }

    # 2 x 8 = exactly 16 panels (a 4 x 8 grid left 16 empty axes).
    fig, axes = plt.subplots(2, 8, figsize=(24, 7))
    axes = axes.flatten()

    # Define colors for entities. plt.cm.get_cmap was removed in Matplotlib
    # 3.9; pyplot.get_cmap is still available.
    cmap = plt.get_cmap('tab10')

    for i in range(16):
        ax = axes[i]

        if i < len(frames):
            frame = frames[i]
            ax.imshow(frame)

            # Overlay mask if available
            if i in masks:
                # NOTE(review): assumes a 2-D label map (0 = background,
                # >0 = entity id) — confirm against the SoM mask format.
                mask = _mask_to_numpy(masks[i])
                overlay = np.zeros((*mask.shape, 4))
                for entity_id in np.unique(mask):
                    if entity_id > 0:  # Skip background
                        # int(): a float id would be read by the colormap as
                        # a [0, 1] fraction rather than a color index.
                        color = cmap(int(entity_id) % 10)
                        overlay[mask == entity_id] = [*color[:3], 0.4]  # RGBA with alpha
                # Stretch the overlay onto the frame's pixel grid, since masks
                # may be stored at a different resolution than the raw video.
                h, w = frame.shape[:2]
                ax.imshow(overlay, extent=(-0.5, w - 0.5, h - 0.5, -0.5),
                          interpolation='nearest')

        # Highlight selected frames
        if i in selected_frames:
            # axis('off') would hide the spines too, so remove only the ticks.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_edgecolor('lime')
                spine.set_linewidth(4)
            ax.set_title(f'F{i} ‚úì', fontsize=10, color='lime', fontweight='bold')
        else:
            ax.set_title(f'F{i}', fontsize=10)
            ax.axis('off')

    # Legend (entity_names may be a dict id->name or a plain list of names)
    if entities:
        pairs = entities.items() if isinstance(entities, dict) else enumerate(entities)
        legend_text = 'Entities: ' + ', '.join(f'{k}:{v}' for k, v in pairs)
        fig.text(0.5, 0.02, legend_text, ha='center', fontsize=12,
                 bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    fig.suptitle(f"16 Frames with Entity Masks - {sample_with_som['sample_data']['qns_key']}",
                 fontsize=14, fontweight='bold')
    # Leave room at the bottom for the legend and at the top for the title.
    fig.tight_layout(rect=(0, 0.06, 1, 0.95))
    fig.savefig('frames_with_masks.png', dpi=150, bbox_inches='tight')
    plt.show()
    plt.close(fig)
else:
    print('No sample with both video frames and SoM masks found.')

print('\n‚úÖ Entity mask visualization complete!')