# AIxSuture TimeSformer Model - OSS Challenge 2025
## Task 1: Global Rating Score (GRS) Classification

**Model Architecture:** TimeSformer for video understanding
**Dataset:** AIxSuture (314 videos, 157 students, 3 raters)
**Task:** 4-class GRS classification (Novice, Intermediate, Proficient, Expert)

**Training Strategy:**
- Load preprocessed AIxSuture data
- TimeSformer with divided space-time attention
- Multi-rater aggregated ground truth
- Session-aware evaluation (PRE vs POST)
- Macro-F1 and Expected Cost metrics

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as transforms
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import pickle
import json
from collections import Counter, defaultdict
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from sklearn.utils.class_weight import compute_class_weight
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

torch.manual_seed(2025)
np.random.seed(2025)
if torch.cuda.is_available():
    torch.cuda.manual_seed(2025)

Using device: cuda


## 1. Load Preprocessed AIxSuture Data

In [None]:
def load_aixsuture_data():
    data_path = '../processed_data/aixsuture_processed_data.pkl'
    metadata_path = '../processed_data/aixsuture_metadata.json'
    
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Preprocessed data not found at {data_path}. Please run preprocessing first.")
    
    with open(data_path, 'rb') as f:
        data = pickle.load(f)
    
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    
    return data, metadata

data, metadata = load_aixsuture_data()

CONFIG = data['config']
GRS_MAPPING = data['grs_mapping']
train_data = data['train_data']
val_data = data['val_data']
train_students = data['train_students']
val_students = data['val_students']

print(f"Dataset Loaded:")
print(f"  Total videos: {metadata['dataset_info']['total_videos']}")
print(f"  Students: {metadata['dataset_info']['total_students']}")
print(f"  Sessions: {metadata['dataset_info']['sessions']}")
print(f"  Investigators: {metadata['dataset_info']['investigators']}")
print(f"  Train sequences: {len(train_data)}")
print(f"  Val sequences: {len(val_data)}")
print(f"  Train students: {len(train_students)}")
print(f"  Val students: {len(val_students)}")

print(f"\nClass distribution:")
for class_id, label in enumerate(GRS_MAPPING['labels']):
    train_count = metadata['class_distribution']['train'].get(str(class_id), 0)
    val_count = metadata['class_distribution']['val'].get(str(class_id), 0)
    print(f"  {label}: Train={train_count}, Val={val_count}")

print(f"\nPreprocessing config:")
print(f"  Sequence length: {CONFIG['sequence_length']}")
print(f"  Frame size: {CONFIG['frame_size']}")
print(f"  FPS: {CONFIG['fps']}")
print(f"  Aggregation: {CONFIG['aggregation_strategy']}")

Dataset Loaded:
  Total videos: 290
  Students: 157
  Sessions: ['POST', 'PRE']
  Investigators: ['A', 'B', 'C']
  Train sequences: 4078
  Val sequences: 2291
  Train students: 18
  Val students: 10

Class distribution:
  Novice: Train=2715, Val=0
  Intermediate: Train=909, Val=239
  Proficient: Train=227, Val=1599
  Expert: Train=227, Val=453

Preprocessing config:
  Sequence length: 16
  Frame size: (224, 224)
  FPS: 5
  Aggregation: average


## 2. TimeSformer Architecture

In [31]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.reshape(B * T, C, H, W)
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        x = x.reshape(B, T, self.num_patches, -1)
        return x

class TimeSformerBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1, attention_type='divided_space_time'):
        super().__init__()
        self.attention_type = attention_type
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        
        if attention_type == 'divided_space_time':
            self.temporal_attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
            self.spatial_attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
            self.norm_temporal = nn.LayerNorm(dim)
        else:
            self.attn = nn.MultiheadAttention(dim, num_heads, dropout=dropout, batch_first=True)
        
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x, cls_token):
        B, T, N, D = x.shape
        
        if self.attention_type == 'divided_space_time':
            cls_token = cls_token.unsqueeze(1).repeat(1, T, 1, 1)
            x_with_cls = torch.cat([cls_token, x], dim=2)
            
            x_temporal = x_with_cls.permute(0, 2, 1, 3).reshape(B * (N + 1), T, D)
            x_temporal_out, _ = self.temporal_attn(x_temporal, x_temporal, x_temporal)
            x_temporal_out = x_temporal_out.reshape(B, N + 1, T, D).permute(0, 2, 1, 3)
            
            x_temporal_out = x_with_cls + x_temporal_out
            x_temporal_out = self.norm_temporal(x_temporal_out)
            
            x_spatial = x_temporal_out.reshape(B * T, N + 1, D)
            x_spatial_out, _ = self.spatial_attn(x_spatial, x_spatial, x_spatial)
            x_spatial_out = x_spatial_out.reshape(B, T, N + 1, D)
            
            x_out = x_temporal_out + x_spatial_out
            x_out = self.norm1(x_out)
            
            cls_token_out = x_out[:, :, 0, :].mean(dim=1)
            x_out = x_out[:, :, 1:, :]
        else:
            print("x shape:", x.shape)
            print("cls_token shape before repeat:", cls_token.shape)
            print("cls_token shape after repeat:", cls_token.unsqueeze(1).repeat(1, T, 1, 1).shape)
            cls_token_expanded = cls_token.unsqueeze(1).unsqueeze(1).repeat(1, T, 1, 1)
            x_with_cls = torch.cat([cls_token_expanded, x], dim=2)
            x_flat = x_with_cls.reshape(B, T * (N + 1), D)
            
            x_attn, _ = self.attn(x_flat, x_flat, x_flat)
            x_attn = x_flat + x_attn
            x_attn = self.norm1(x_attn)
            
            cls_token_out = x_attn[:, 0, :]
            x_out = x_attn[:, 1:, :].reshape(B, T, N, D)
        
        x_out_flat = x_out.reshape(B * T, N, D)
        mlp_out = self.mlp(x_out_flat)
        x_out = x_out_flat + mlp_out
        x_out = self.norm2(x_out)
        x_out = x_out.reshape(B, T, N, D)
        
        return x_out, cls_token_out

class TimeSformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, num_classes=4, num_frames=16,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, dropout=0.1,
                 attention_type='divided_space_time'):
        super().__init__()
        self.num_classes = num_classes
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, embed_dim)
        num_patches = self.patch_embed.num_patches
        
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.temporal_embed = nn.Parameter(torch.zeros(1, num_frames, embed_dim))
        
        self.blocks = nn.ModuleList([
            TimeSformerBlock(embed_dim, num_heads, mlp_ratio, dropout, attention_type)
            for _ in range(depth)
        ])
        
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)
        
        self.dropout = nn.Dropout(dropout)
        
        self._init_weights()
    
    def _init_weights(self):
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.temporal_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
    def forward(self, x):
        B, T, C, H, W = x.shape
        
        x = self.patch_embed(x)
        B, T, N, D = x.shape
        
        x = x + self.pos_embed.unsqueeze(1)
        x = x + self.temporal_embed.unsqueeze(2)
        
        x = self.dropout(x)
        
        cls_token = self.cls_token.expand(B, -1, -1)
        
        for block in self.blocks:
            x, cls_token = block(x, cls_token)
        
        cls_token = self.norm(cls_token)
        logits = self.head(cls_token)
        
        return logits

def create_timesformer_model(num_classes=4, num_frames=16, model_size='base'):
    if model_size == 'small':
        config = {
            'embed_dim': 384,
            'depth': 8,
            'num_heads': 6,
            'mlp_ratio': 4.0
        }
    elif model_size == 'base':
        config = {
            'embed_dim': 768,
            'depth': 12,
            'num_heads': 12,
            'mlp_ratio': 4.0
        }
    else:
        config = {
            'embed_dim': 1024,
            'depth': 16,
            'num_heads': 16,
            'mlp_ratio': 4.0
        }
    
    model = TimeSformer(
        img_size=224,
        patch_size=16,
        num_classes=num_classes,
        num_frames=num_frames,
        **config
    )
    
    return model

model = create_timesformer_model(
    num_classes=len(GRS_MAPPING['classes']),
    num_frames=CONFIG['sequence_length'],
    model_size='base'
).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nTimeSformer Model Created:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: base")
print(f"  Input shape: (batch, {CONFIG['sequence_length']}, 3, 224, 224)")
print(f"  Output classes: {len(GRS_MAPPING['classes'])}")


TimeSformer Model Created:
  Total parameters: 114,180,100
  Trainable parameters: 114,180,100
  Model size: base
  Input shape: (batch, 16, 3, 224, 224)
  Output classes: 4


## 3. AIxSuture Dataset Class

In [32]:
class AIxSutureDataset(torch.utils.data.Dataset):
    def __init__(self, training_data, transform=None, mode='train', augment=False):
        self.training_data = training_data
        self.transform = transform
        self.mode = mode
        self.augment = augment
        
        if augment and mode == 'train':
            self.spatial_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=5),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
            ])
        else:
            self.spatial_transform = None
    
    def __len__(self):
        return len(self.training_data)
    
    def __getitem__(self, idx):
        item = self.training_data[idx]
        
        if item.get('sequence_type') == 'simulated':
            sequence_tensor = torch.randn(
                CONFIG['sequence_length'], 3,
                CONFIG['frame_size'][0], CONFIG['frame_size'][1]
            )
        else:
            frames = []
            for frame_info in item['sequence_frames']:
                if frame_info.get('simulated', False):
                    frame_tensor = torch.randn(3, CONFIG['frame_size'][0], CONFIG['frame_size'][1])
                else:
                    try:
                        import cv2
                        frame = cv2.imread(frame_info['path'])
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        frame = cv2.resize(frame, CONFIG['frame_size'])
                        frame_tensor = torch.from_numpy(frame).float() / 255.0
                        frame_tensor = frame_tensor.permute(2, 0, 1)
                    except:
                        frame_tensor = torch.randn(3, CONFIG['frame_size'][0], CONFIG['frame_size'][1])
                
                if self.spatial_transform is not None:
                    frame_tensor = self.spatial_transform(frame_tensor)
                
                frames.append(frame_tensor)
            
            sequence_tensor = torch.stack(frames)
        
        if self.transform:
            sequence_tensor = self.transform(sequence_tensor)
        
        label = torch.tensor(item['grs_class'], dtype=torch.long)
        
        return {
            'sequence': sequence_tensor,
            'label': label,
            'video_name': item['video_name'],
            'student_id': item['student_id'],
            'session': item['session'],
            'grs_total': item['grs_total'],
            'sequence_idx': item['sequence_idx']
        }

def normalize_sequence(sequence):
    mean = torch.tensor(CONFIG['imagenet_mean']).view(1, 3, 1, 1)
    std = torch.tensor(CONFIG['imagenet_std']).view(1, 3, 1, 1)
    return (sequence - mean) / std

train_dataset = AIxSutureDataset(
    train_data,
    transform=normalize_sequence,
    mode='train',
    augment=True
)

val_dataset = AIxSutureDataset(
    val_data,
    transform=normalize_sequence,
    mode='val',
    augment=False
)

print(f"\nDatasets created:")
print(f"  Train dataset: {len(train_dataset)} sequences")
print(f"  Val dataset: {len(val_dataset)} sequences")

if len(train_dataset) > 0:
    sample = train_dataset[0]
    print(f"\nSample data:")
    print(f"  Sequence shape: {sample['sequence'].shape}")
    print(f"  Label: {sample['label']} ({GRS_MAPPING['labels'][sample['label']]})")
    print(f"  Student: {sample['student_id']}")
    print(f"  Session: {sample['session']}")
    print(f"  GRS Total: {sample['grs_total']}")


Datasets created:
  Train dataset: 4078 sequences
  Val dataset: 2291 sequences

Sample data:
  Sequence shape: torch.Size([16, 3, 224, 224])
  Label: 1 (Intermediate)
  Student: BOG917
  Session: PRE
  GRS Total: 17.666666666666668


## 4. Training Configuration and Data Loaders

In [33]:
TRAINING_CONFIG = {
    'batch_size': 4,
    'num_epochs': 50,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'warmup_epochs': 5,
    'patience': 10,
    'use_class_weights': True,
    'gradient_clip': 1.0,
    'save_best_model': True
}

def compute_class_weights(train_labels):
    unique_classes = np.unique(train_labels)
    class_weights = compute_class_weight(
        'balanced',
        classes=unique_classes,
        y=train_labels
    )
    return torch.FloatTensor(class_weights).to(device)

train_labels = [item['grs_class'] for item in train_data]
class_weights = compute_class_weights(train_labels) if TRAINING_CONFIG['use_class_weights'] else None

if class_weights is not None:
    print(f"\nClass weights:")
    for i, (label, weight) in enumerate(zip(GRS_MAPPING['labels'], class_weights)):
        print(f"  {label}: {weight:.3f}")

def create_weighted_sampler(dataset):
    labels = [item['grs_class'] for item in dataset.training_data]
    class_counts = Counter(labels)
    
    weights = []
    for label in labels:
        weights.append(1.0 / class_counts[label])
    
    return WeightedRandomSampler(weights, len(weights), replacement=True)

train_sampler = create_weighted_sampler(train_dataset) if TRAINING_CONFIG['use_class_weights'] else None

train_loader = DataLoader(
    train_dataset,
    batch_size=TRAINING_CONFIG['batch_size'],
    sampler=train_sampler,
    shuffle=(train_sampler is None),
    num_workers=2,
    pin_memory=True,
    drop_last=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=TRAINING_CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"\nData loaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Batch size: {TRAINING_CONFIG['batch_size']}")
print(f"  Using weighted sampling: {train_sampler is not None}")


Class weights:
  Novice: 0.376
  Intermediate: 1.122
  Proficient: 4.491
  Expert: 4.491

Data loaders created:
  Train batches: 1019
  Val batches: 573
  Batch size: 4
  Using weighted sampling: True


## 5. Loss Function and Optimizer

In [34]:
def expected_cost_loss(predictions, targets, cost_matrix=None):
    if cost_matrix is None:
        num_classes = len(GRS_MAPPING['classes'])
        cost_matrix = torch.zeros(num_classes, num_classes).to(device)
        for i in range(num_classes):
            for j in range(num_classes):
                cost_matrix[i, j] = abs(i - j)
    
    probs = F.softmax(predictions, dim=1)
    costs = torch.sum(probs * cost_matrix[targets], dim=1)
    return costs.mean()

class CombinedLoss(nn.Module):
    def __init__(self, class_weights=None, alpha=0.7, beta=0.3):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights)
    
    def forward(self, predictions, targets):
        ce_loss = self.ce_loss(predictions, targets)
        ec_loss = expected_cost_loss(predictions, targets)
        return self.alpha * ce_loss + self.beta * ec_loss

criterion = CombinedLoss(class_weights=class_weights)

optimizer = optim.AdamW(
    model.parameters(),
    lr=TRAINING_CONFIG['learning_rate'],
    weight_decay=TRAINING_CONFIG['weight_decay']
)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + np.cos(np.pi * progress)))
    
    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

num_training_steps = len(train_loader) * TRAINING_CONFIG['num_epochs']
num_warmup_steps = len(train_loader) * TRAINING_CONFIG['warmup_epochs']

scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps, num_training_steps
)

print(f"\nTraining setup:")
print(f"  Optimizer: AdamW")
print(f"  Learning rate: {TRAINING_CONFIG['learning_rate']}")
print(f"  Weight decay: {TRAINING_CONFIG['weight_decay']}")
print(f"  Loss function: Combined (CE + Expected Cost)")
print(f"  Scheduler: Cosine with warmup")
print(f"  Warmup steps: {num_warmup_steps}")
print(f"  Total steps: {num_training_steps}")


Training setup:
  Optimizer: AdamW
  Learning rate: 0.0001
  Weight decay: 0.0001
  Loss function: Combined (CE + Expected Cost)
  Scheduler: Cosine with warmup
  Warmup steps: 5095
  Total steps: 50950


## 6. Training Loop

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, scheduler, epoch):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1} [Train]')
    
    for batch_idx, batch in enumerate(pbar):
        sequences = batch['sequence'].to(device)
        labels = batch['label'].to(device)
        
        optimizer.zero_grad()
        
        outputs = model(sequences)
        loss = criterion(outputs, labels)
        
        loss.backward()
        
        if TRAINING_CONFIG['gradient_clip'] > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), TRAINING_CONFIG['gradient_clip'])
        
        optimizer.step()
        scheduler.step()
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        current_lr = scheduler.get_last_lr()[0]
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{100.*correct/total:.2f}%',
            'LR': f'{current_lr:.2e}'
        })
    
    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy

def validate_epoch(model, val_loader, criterion, epoch):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    all_video_names = []
    all_students = []
    all_sessions = []
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f'Epoch {epoch+1} [Val]')
        
        for batch in pbar:
            sequences = batch['sequence'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(sequences)
            loss = criterion(outputs, labels)
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_video_names.extend(batch['video_name'])
            all_students.extend(batch['student_id'])
            all_sessions.extend(batch['session'])
            
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })
    
    avg_loss = total_loss / len(val_loader)
    accuracy = 100. * correct / total
    
    macro_f1 = f1_score(all_labels, all_predictions, average='macro')
    
    return avg_loss, accuracy, macro_f1, all_predictions, all_labels, all_video_names, all_students, all_sessions

def train_model():
    best_macro_f1 = 0
    patience_counter = 0
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    macro_f1_scores = []
    
    print(f"\nStarting training for {TRAINING_CONFIG['num_epochs']} epochs...")
    
    for epoch in range(TRAINING_CONFIG['num_epochs']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scheduler, epoch)
        val_loss, val_acc, macro_f1, val_preds, val_labels, val_videos, val_students, val_sessions = validate_epoch(
            model, val_loader, criterion, epoch
        )
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)
        macro_f1_scores.append(macro_f1)
        
        print(f"\nEpoch {epoch+1}/{TRAINING_CONFIG['num_epochs']}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"  Macro F1: {macro_f1:.4f}")
        
        if macro_f1 > best_macro_f1:
            best_macro_f1 = macro_f1
            patience_counter = 0
            
            if TRAINING_CONFIG['save_best_model']:
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'best_macro_f1': best_macro_f1,
                    'config': TRAINING_CONFIG,
                    'grs_mapping': GRS_MAPPING
                }, 'best_timesformer_model.pth')
                
                print(f"  ✅ New best model saved (Macro F1: {best_macro_f1:.4f})")
        else:
            patience_counter += 1
            print(f"  Patience: {patience_counter}/{TRAINING_CONFIG['patience']}")
        
        if patience_counter >= TRAINING_CONFIG['patience']:
            print(f"\nEarly stopping triggered after {epoch+1} epochs")
            break
    
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accuracies': train_accuracies,
        'val_accuracies': val_accuracies,
        'macro_f1_scores': macro_f1_scores,
        'best_macro_f1': best_macro_f1,
        'final_predictions': val_preds,
        'final_labels': val_labels,
        'final_videos': val_videos,
        'final_students': val_students,
        'final_sessions': val_sessions
    }

training_history = train_model()


Starting training for 50 epochs...


Epoch 1 [Train]:   0%|          | 0/1019 [00:01<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 1 but got size 4 for tensor number 1 in the list.

: 

## 7. Evaluation and Analysis

In [None]:
def plot_training_history(history):
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    epochs = range(1, len(history['train_losses']) + 1)
    
    axes[0, 0].plot(epochs, history['train_losses'], 'b-', label='Train Loss')
    axes[0, 0].plot(epochs, history['val_losses'], 'r-', label='Val Loss')
    axes[0, 0].set_title('Training and Validation Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True)
    
    axes[0, 1].plot(epochs, history['train_accuracies'], 'b-', label='Train Acc')
    axes[0, 1].plot(epochs, history['val_accuracies'], 'r-', label='Val Acc')
    axes[0, 1].set_title('Training and Validation Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy (%)')
    axes[0, 1].legend()
    axes[0, 1].grid(True)
    
    axes[1, 0].plot(epochs, history['macro_f1_scores'], 'g-', label='Macro F1')
    axes[1, 0].set_title('Macro F1 Score')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('Macro F1')
    axes[1, 0].legend()
    axes[1, 0].grid(True)
    
    cm = confusion_matrix(history['final_labels'], history['final_predictions'])
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[1, 1],
                xticklabels=GRS_MAPPING['labels'],
                yticklabels=GRS_MAPPING['labels'])
    axes[1, 1].set_title('Confusion Matrix')
    axes[1, 1].set_xlabel('Predicted')
    axes[1, 1].set_ylabel('Actual')
    
    plt.tight_layout()
    plt.savefig('training_results.png', dpi=300, bbox_inches='tight')
    plt.show()

def analyze_session_performance(predictions, labels, sessions, students):
    results_df = pd.DataFrame({
        'prediction': predictions,
        'label': labels,
        'session': sessions,
        'student': students
    })
    
    print(f"\n=== SESSION-BASED ANALYSIS ===")
    
    session_metrics = {}
    for session in results_df['session'].unique():
        session_data = results_df[results_df['session'] == session]
        
        accuracy = (session_data['prediction'] == session_data['label']).mean()
        macro_f1 = f1_score(session_data['label'], session_data['prediction'], average='macro')
        
        session_metrics[session] = {
            'accuracy': accuracy,
            'macro_f1': macro_f1,
            'count': len(session_data)
        }
        
        print(f"\n{session} Session:")
        print(f"  Sequences: {len(session_data)}")
        print(f"  Accuracy: {accuracy:.3f}")
        print(f"  Macro F1: {macro_f1:.3f}")
        
        session_class_report = classification_report(
            session_data['label'], session_data['prediction'],
            target_names=GRS_MAPPING['labels'], output_dict=True
        )
        
        print(f"  Class-wise F1 scores:")
        for i, label in enumerate(GRS_MAPPING['labels']):
            if str(i) in session_class_report:
                f1 = session_class_report[str(i)]['f1-score']
                print(f"    {label}: {f1:.3f}")
    
    return session_metrics

def compute_expected_cost(predictions, labels):
    num_classes = len(GRS_MAPPING['classes'])
    cost_matrix = np.zeros((num_classes, num_classes))
    
    for i in range(num_classes):
        for j in range(num_classes):
            cost_matrix[i, j] = abs(i - j)
    
    total_cost = 0
    for true_label, pred_label in zip(labels, predictions):
        total_cost += cost_matrix[true_label, pred_label]
    
    return total_cost / len(labels)

plot_training_history(training_history)

print(f"\n{'='*60}")
print(f"FINAL EVALUATION RESULTS")
print(f"{'='*60}")

final_accuracy = (np.array(training_history['final_predictions']) == np.array(training_history['final_labels'])).mean()
final_macro_f1 = f1_score(training_history['final_labels'], training_history['final_predictions'], average='macro')
expected_cost = compute_expected_cost(training_history['final_predictions'], training_history['final_labels'])

print(f"\nOverall Performance:")
print(f"  Best Macro F1: {training_history['best_macro_f1']:.4f}")
print(f"  Final Accuracy: {final_accuracy:.4f}")
print(f"  Final Macro F1: {final_macro_f1:.4f}")
print(f"  Expected Cost: {expected_cost:.4f}")

print(f"\nDetailed Classification Report:")
print(classification_report(
    training_history['final_labels'],
    training_history['final_predictions'],
    target_names=GRS_MAPPING['labels']
))

session_metrics = analyze_session_performance(
    training_history['final_predictions'],
    training_history['final_labels'],
    training_history['final_sessions'],
    training_history['final_students']
)

## 8. Student-Level Analysis

In [None]:
def analyze_student_performance():
    print(f"\n=== STUDENT-LEVEL ANALYSIS ===")
    
    results_df = pd.DataFrame({
        'prediction': training_history['final_predictions'],
        'label': training_history['final_labels'],
        'session': training_history['final_sessions'],
        'student': training_history['final_students'],
        'video': training_history['final_videos']
    })
    
    student_aggregated = results_df.groupby(['student', 'session']).agg({
        'prediction': lambda x: x.mode().iloc[0] if len(x.mode()) > 0 else x.iloc[0],
        'label': lambda x: x.mode().iloc[0] if len(x.mode()) > 0 else x.iloc[0],
        'video': 'count'
    }).rename(columns={'video': 'sequence_count'})
    
    student_aggregated['correct'] = (student_aggregated['prediction'] == student_aggregated['label'])
    
    print(f"\nStudent-level accuracy (aggregated by majority vote):")
    overall_student_accuracy = student_aggregated['correct'].mean()
    print(f"  Overall: {overall_student_accuracy:.3f}")
    
    for session in student_aggregated.index.get_level_values('session').unique():
        session_data = student_aggregated[student_aggregated.index.get_level_values('session') == session]
        session_accuracy = session_data['correct'].mean()
        print(f"  {session}: {session_accuracy:.3f} ({len(session_data)} students)")
    
    if len(student_aggregated.index.get_level_values('session').unique()) == 2:
        print(f"\nPRE vs POST comparison:")
        
        pre_data = student_aggregated[student_aggregated.index.get_level_values('session') == 'PRE']
        post_data = student_aggregated[student_aggregated.index.get_level_values('session') == 'POST']
        
        pre_skills = pre_data['label'].value_counts().sort_index()
        post_skills = post_data['label'].value_counts().sort_index()
        
        print(f"\nSkill distribution:")
        print(f"  PRE:  {dict(pre_skills)}")
        print(f"  POST: {dict(post_skills)}")
        
        improvement_analysis = []
        for student in pre_data.index.get_level_values('student'):
            if student in post_data.index.get_level_values('student'):
                pre_skill = pre_data.loc[student]['label']
                post_skill = post_data.loc[student]['label']
                improvement = post_skill - pre_skill
                improvement_analysis.append(improvement)
        
        if improvement_analysis:
            avg_improvement = np.mean(improvement_analysis)
            improved_students = sum(1 for x in improvement_analysis if x > 0)
            total_paired = len(improvement_analysis)
            
            print(f"\nSkill improvement analysis:")
            print(f"  Students with both PRE/POST: {total_paired}")
            print(f"  Students who improved: {improved_students} ({improved_students/total_paired*100:.1f}%)")
            print(f"  Average skill change: {avg_improvement:.2f} levels")
    
    return student_aggregated

def create_submission_file():
    print(f"\n=== CREATING SUBMISSION FILE ===")
    
    results_df = pd.DataFrame({
        'video_name': training_history['final_videos'],
        'prediction': training_history['final_predictions'],
        'label': training_history['final_labels'],
        'session': training_history['final_sessions'],
        'student': training_history['final_students']
    })
    
    video_predictions = results_df.groupby('video_name').agg({
        'prediction': lambda x: x.mode().iloc[0] if len(x.mode()) > 0 else x.iloc[0],
        'label': lambda x: x.mode().iloc[0] if len(x.mode()) > 0 else x.iloc[0],
        'session': 'first',
        'student': 'first'
    })
    
    submission_df = pd.DataFrame({
        'video_name': video_predictions.index,
        'predicted_grs_class': video_predictions['prediction'],
        'predicted_grs_label': [GRS_MAPPING['labels'][pred] for pred in video_predictions['prediction']]
    })
    
    submission_df.to_csv('task1_submission.csv', index=False)
    
    print(f"Submission file created: task1_submission.csv")
    print(f"  Videos: {len(submission_df)}")
    print(f"  Predictions distribution:")
    pred_dist = submission_df['predicted_grs_label'].value_counts()
    for label, count in pred_dist.items():
        print(f"    {label}: {count}")
    
    return submission_df

student_analysis = analyze_student_performance()
submission_df = create_submission_file()

## 9. Save Results and Model

In [None]:
final_results = {
    'model_config': {
        'architecture': 'TimeSformer',
        'model_size': 'base',
        'num_classes': len(GRS_MAPPING['classes']),
        'sequence_length': CONFIG['sequence_length'],
        'input_size': CONFIG['frame_size']
    },
    'training_config': TRAINING_CONFIG,
    'dataset_info': metadata['dataset_info'],
    'performance_metrics': {
        'best_macro_f1': training_history['best_macro_f1'],
        'final_accuracy': final_accuracy,
        'final_macro_f1': final_macro_f1,
        'expected_cost': expected_cost
    },
    'training_history': training_history,
    'session_metrics': session_metrics,
    'grs_mapping': GRS_MAPPING
}

with open('task1_results.json', 'w') as f:
    json.dump(final_results, f, indent=2, default=str)

torch.save({
    'model_state_dict': model.state_dict(),
    'config': CONFIG,
    'training_config': TRAINING_CONFIG,
    'grs_mapping': GRS_MAPPING,
    'performance': final_results['performance_metrics']
}, 'final_timesformer_model.pth')

print(f"\n✅ Training completed successfully!")
print(f"\nFiles saved:")
print(f"  - best_timesformer_model.pth (best model checkpoint)")
print(f"  - final_timesformer_model.pth (final model)")
print(f"  - task1_results.json (complete results)")
print(f"  - task1_submission.csv (submission file)")
print(f"  - training_results.png (training plots)")

print(f"\n{'='*70}")
print(f"AIXSUTURE TIMESFORMER TRAINING SUMMARY")
print(f"{'='*70}")
print(f"🎯 Task: Global Rating Score Classification")
print(f"🏗️ Model: TimeSformer (base, {total_params:,} parameters)")
print(f"📊 Dataset: {metadata['dataset_info']['total_videos']} videos, {metadata['dataset_info']['total_students']} students")
print(f"🔄 Sessions: {metadata['dataset_info']['sessions']}")
print(f"👥 Investigators: {len(metadata['dataset_info']['investigators'])} raters")
print(f"🎥 Sequences: {len(train_data)} train, {len(val_data)} val")
print(f"📈 Best Macro F1: {training_history['best_macro_f1']:.4f}")
print(f"🎯 Final Accuracy: {final_accuracy:.4f}")
print(f"💰 Expected Cost: {expected_cost:.4f}")
print(f"⚡ Training: {len(training_history['train_losses'])} epochs")
print(f"{'='*70}")

print(f"\nNext steps:")
print(f"1. Analyze results by session (PRE vs POST training)")
print(f"2. Compare with inter-rater reliability metrics")
print(f"3. Submit task1_submission.csv to OSS Challenge")
print(f"4. Consider ensemble methods for improved performance")
print(f"5. Extend to Task 2 (OSATS criteria prediction)")