# Siamese TCN Training - Full 22-Player Context

This notebook trains a Siamese TCN to learn embeddings for play similarity using **all 22 players**.

**Updates for full 22-player context:**
- Input: (22, 2, T) → flattened to (44, T) for 22 players × 2 coords
- Model: input_channels=44 instead of 22
- Distance normalization: target distances are divided by the 95th percentile of all pair distances (computed automatically from the training pairs)

## Setup

In [None]:
# Upload training files to Colab:
# - aligned_scenes.pkl
# - training_pairs.pkl

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pickle
import os
from tqdm import tqdm

In [None]:
# Configuration
# Input files (produced upstream and uploaded to the runtime).
DATA_PATH = 'aligned_scenes.pkl'      # pickled scenes, each holding a 'scene_tensor'
LABELS_PATH = 'training_pairs.pkl'    # pickled pairs: index_a / index_b / distance

# Optimisation hyper-parameters.
BATCH_SIZE = 32
LEARNING_RATE = 0.001
EPOCHS = 20

# Train on the GPU whenever one is visible to PyTorch.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

print(f"Using device: {DEVICE}")

## 1. Dataset

In [None]:
class SoccerSceneDataset(Dataset):
    """Pairs of scenes with a normalised target distance, for Siamese training.

    Each item is ``(feat_a, feat_b, dist)``:

    * ``feat_a`` / ``feat_b`` are float32 tensors built from the scenes'
      ``'scene_tensor'`` arrays with NaNs replaced by 0 and all leading axes
      flattened into one channel axis, e.g. ``(22, 2, T) -> (44, T)``.
    * ``dist`` is the pair's ``'distance'`` divided by ``dist_norm_factor``,
      the 95th percentile of all pair distances (1.0 if that percentile is
      not positive, to avoid inf/NaN targets).

    Args:
        scenes_path: Path to a pickle holding an indexable collection of scene
            dicts, each with a ``'scene_tensor'`` array.
        labels_path: Path to a pickle holding a sequence of pair dicts with
            ``'index_a'``, ``'index_b'`` and ``'distance'`` keys.

    Raises:
        ValueError: If the labels file contains no pairs.
    """

    def __init__(self, scenes_path, labels_path):
        # NOTE: pickle.load can execute arbitrary code -- only load files you
        # produced yourself.
        print(f"Loading scenes from {scenes_path}...")
        with open(scenes_path, 'rb') as f:
            self.scenes = pickle.load(f)

        print(f"Loading labels from {labels_path}...")
        with open(labels_path, 'rb') as f:
            self.pairs = pickle.load(f)

        if len(self.pairs) == 0:
            raise ValueError(f"No training pairs found in {labels_path}")

        # Compute distance statistics for normalization.
        distances = np.asarray([p['distance'] for p in self.pairs], dtype=np.float64)
        p95 = float(np.percentile(distances, 95))
        # If most distances are 0 the percentile is 0 and dividing by it would
        # turn every target into inf/NaN; fall back to no scaling instead.
        self.dist_norm_factor = p95 if p95 > 0 else 1.0

        print(f"Loaded {len(self.scenes)} scenes and {len(self.pairs)} training pairs")
        print(f"Distance stats: min={np.min(distances):.1f}, max={np.max(distances):.1f}, "
              f"mean={np.mean(distances):.1f}, 95th percentile={p95:.1f}")
        if self.dist_norm_factor != p95:
            print("Warning: 95th percentile is not positive; distances will not be normalized")

    def __len__(self):
        return len(self.pairs)

    def _scene_features(self, scene_idx):
        """Return scene *scene_idx* as a NaN-free ``(channels, T)`` array."""
        # Fill NaNs with 0 (nan_to_num returns a copy; the stored scene is untouched).
        tensor = np.nan_to_num(self.scenes[scene_idx]['scene_tensor'], nan=0.0)
        # Flatten player and coordinate axes into channels: (22, 2, T) -> (44, T).
        return tensor.reshape(-1, tensor.shape[-1])

    def __getitem__(self, idx):
        pair = self.pairs[idx]

        feat_a = self._scene_features(pair['index_a'])
        feat_b = self._scene_features(pair['index_b'])
        dist_target = pair['distance'] / self.dist_norm_factor

        return (
            torch.tensor(feat_a, dtype=torch.float32),
            torch.tensor(feat_b, dtype=torch.float32),
            torch.tensor(dist_target, dtype=torch.float32)
        )

## 2. Model Architecture

In [None]:
class GatedTCNBlock(nn.Module):
    """Gated dilated 1-D convolution block with a residual connection.

    Computes ``dropout(tanh(conv_f(x)) * sigmoid(conv_g(x))) + residual(x)``.
    ``residual`` is the identity when the channel counts match and a 1x1
    convolution otherwise. Padding is centred (non-causal), and the time
    length T is preserved for any kernel size.

    Args:
        in_channels: Channels of the input ``(B, in_channels, T)``.
        out_channels: Channels of the output ``(B, out_channels, T)``.
        kernel_size: Temporal kernel width.
        dilation: Spacing between kernel taps.
        dropout: Dropout probability applied to the gated output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, dropout=0.2):
        super().__init__()
        # padding='same' keeps the output length equal to T. For odd kernels it
        # is exactly the symmetric (kernel_size-1)*dilation//2 padding; for even
        # kernels it pads asymmetrically, where the symmetric formula would
        # shorten the sequence and break the residual addition below.
        self.conv_f = nn.Conv1d(in_channels, out_channels, kernel_size,
                                padding='same',
                                dilation=dilation)
        self.conv_g = nn.Conv1d(in_channels, out_channels, kernel_size,
                                padding='same',
                                dilation=dilation)
        self.dropout = nn.Dropout(dropout)

        # Residual connection: 1x1 projection only when the widths differ.
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None

    def forward(self, x):
        # Gated activation unit: tanh "filter" modulated by a sigmoid "gate".
        f = torch.tanh(self.conv_f(x))
        g = torch.sigmoid(self.conv_g(x))
        out = self.dropout(f * g)

        res = x if self.downsample is None else self.downsample(x)
        return out + res

class SiameseTCN(nn.Module):
    """Shared-weight temporal-conv encoder applied to both members of a pair.

    Args:
        input_channels: Channels of each input sequence (44 = 22 players x 2 coords).
        hidden_channels: Width of every gated TCN block.
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, input_channels=44, hidden_channels=32, embedding_dim=64):
        super().__init__()

        # Four gated blocks with dilations 1, 2, 4, 8; only the first one
        # changes the channel count.
        block_inputs = [input_channels] + [hidden_channels] * 3
        self.tcn = nn.Sequential(*[
            GatedTCNBlock(c_in, hidden_channels, kernel_size=3, dilation=2 ** level)
            for level, c_in in enumerate(block_inputs)
        ])

        # Average over time, then project into a layer-normalised embedding.
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(hidden_channels, embedding_dim),
            nn.LayerNorm(embedding_dim),
        )

    def forward_one(self, x):
        """Embed one batch: ``(B, input_channels, T) -> (B, embedding_dim)``."""
        pooled = self.pool(self.tcn(x)).squeeze(-1)  # (B, hidden)
        return self.fc(pooled)

    def forward(self, x1, x2):
        """Embed both inputs with the same weights and return both embeddings."""
        return self.forward_one(x1), self.forward_one(x2)

## 3. Training Loop

In [None]:
# Create dataset and loader
dataset = SoccerSceneDataset(DATA_PATH, LABELS_PATH)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    # Page-locked host memory speeds up the per-batch .to(DEVICE) copy on GPU;
    # it does not change results and is skipped when training on CPU.
    pin_memory=(DEVICE.type == 'cuda'),
)

# Fail fast, before building the model, if the scene tensors do not flatten
# to the 44 channels (22 players x 2 coords) the model is built for.
INPUT_CHANNELS = 44
sample_channels = dataset[0][0].shape[0]
if sample_channels != INPUT_CHANNELS:
    raise ValueError(
        f"Scene features have {sample_channels} channels, expected {INPUT_CHANNELS} "
        "(22 players x 2 coords)"
    )

# Initialize model
model = SiameseTCN(input_channels=INPUT_CHANNELS, hidden_channels=32, embedding_dim=64).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Regress the embedding distance onto the normalised target distance.
criterion = nn.MSELoss()

print(f"Model has {sum(p.numel() for p in model.parameters())} parameters")

In [None]:
# Training loop: learn embeddings whose Euclidean distance matches the
# normalised target distance of each scene pair.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    total_samples = 0

    pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for feat_a, feat_b, dist_target in pbar:
        feat_a = feat_a.to(DEVICE)
        feat_b = feat_b.to(DEVICE)
        dist_target = dist_target.to(DEVICE)

        # Forward pass
        emb_a, emb_b = model(feat_a, feat_b)

        # Compute Euclidean distance between the paired embeddings
        pred_dist = torch.norm(emb_a - emb_b, p=2, dim=1)

        # Loss
        loss = criterion(pred_dist, dist_target)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() forces a device sync, so read the value once per step.
        batch_loss = loss.item()
        n_batch = dist_target.shape[0]
        # nn.MSELoss() (default reduction='mean') returns the batch mean; weight it
        # by batch size so a short final batch does not skew the epoch average.
        total_loss += batch_loss * n_batch
        total_samples += n_batch
        pbar.set_postfix({'loss': batch_loss})

    # max(..., 1) avoids ZeroDivisionError if the loader yields no batches.
    avg_loss = total_loss / max(total_samples, 1)
    print(f"Epoch {epoch+1}/{EPOCHS}, Average Loss: {avg_loss:.4f}")

print("Training complete!")

## 4. Save Model

In [None]:
# Save model weights (state_dict only: loading requires a SiameseTCN built with
# the same input_channels / hidden_channels / embedding_dim).
torch.save(model.state_dict(), 'siamese_tcn_full.pth')
print("Model saved to siamese_tcn_full.pth")

# The network was trained to predict distances divided by the dataset's
# normalization factor; persist it so predicted distances can be mapped back
# to the original distance units.
import json

with open('siamese_tcn_full_meta.json', 'w', encoding='utf-8') as f:
    json.dump({'dist_norm_factor': float(dataset.dist_norm_factor)}, f)
print("Normalization metadata saved to siamese_tcn_full_meta.json")

# Download both files and use them with the index pipeline