In [None]:
import sys
import os

sys.path.append(os.path.abspath('..'))

from src.data import *
from src.models import *
from src.utils import *

# SINGLE HEAD

## BASELINE FFN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split, KFold
from sklearn.utils.class_weight import compute_class_weight
import re
import os
import random

# --- GLOBAL CONFIGURATION ---
SEQ_LEN = 30       # Fixed sequence length (passed as max_seq_len when the dataset is built)
BATCH_SIZE = 64    # Default mini-batch size (the single-head run below passes batch_size=512 explicitly)
EPOCHS = 10        # Number of training epochs (the calls below pass epochs explicitly)
LR = 1e-3          # Learning rate
# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Presumably the checkpoint file for the trained model (.pth).
# NOTE(review): not referenced in the visible cells — confirm it is used later.
save_path = 'tennis_shot_forecasting.pth'

# Seed every RNG used in this notebook so runs are reproducible.
def seed_everything(seed=42, deterministic=True):
    """Seed Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed applied to every RNG.
        deterministic: When True (default), also make cuDNN choose deterministic
            kernels and disable its benchmark autotuner. Seeding alone does not
            make GPU runs repeatable while cuDNN may pick different algorithms
            from run to run.

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up, so setting it
        here only affects child processes started afterwards, not the current
        kernel.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if deterministic:
        # Trade a little GPU speed for bit-for-bit repeatable convolutions/RNN kernels.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Random seed set to {seed}")

# Call it immediately
seed_everything(42)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

def train_singlehead_baseline(dataset, epochs=5, batch_size=64, lr=1e-3, device='cuda'):
    """Train ``SimpleUnifiedBaseline`` as a next-token predictor over unified shot tokens.

    The dataset is split 80% / 15% / 5% (train / validation / test) with a fixed
    generator seed (42), so the held-out test subset is identical across runs.
    The model receives ``x_seq[:, :-1]`` and is trained to predict ``x_seq[:, 1:]``.
    Padding (id 0) and serve tokens are excluded from the loss. The same exclusion
    is applied when computing validation accuracy, so the printed number measures
    the task the model is actually trained on.

    Args:
        dataset: Object exposing ``unified_vocab`` (key -> id), ``__len__`` and
            ``__getitem__`` returning dicts with ``'x_seq'`` and ``'context'``.
        epochs: Number of passes over the training split.
        batch_size: Mini-batch size for the train and validation loaders.
        lr: Adam learning rate.
        device: Torch device string for the model and batches.

    Returns:
        tuple: ``(model, test_ds)``, i.e. the trained model and the untouched 5% test
        ``Subset`` (its ``.indices`` point back into ``dataset``).

    NOTE(review): the evaluation code feeds the full ``x_seq`` and compares with
    ``batch['y_target']`` rather than this shifted target — confirm that the two
    conventions agree.
    """
    vocab_size = len(dataset.unified_vocab)
    print(f"--- STARTING SINGLE-HEAD BASELINE (Vocab: {vocab_size}) ---")

    # 1. Identify serve tokens once; they are never used as prediction targets.
    serve_token_ids = {
        idx for key, idx in dataset.unified_vocab.items()
        if key.lower().startswith('serve') or key.startswith('S_')
    }
    print(f"Training will ignore {len(serve_token_ids)} serve tokens as targets.")
    serve_ids_t = torch.tensor(sorted(serve_token_ids), dtype=torch.long, device=device)

    # 2. 80 / 15 / 5 split; the test set takes the rounding remainder.
    total_len = len(dataset)
    train_len = int(0.80 * total_len)
    val_len   = int(0.15 * total_len)
    test_len  = total_len - train_len - val_len
    print(f"Data Split -> Train: {train_len}, Val: {val_len}, Test: {test_len}")

    # Fixed generator seed -> the same split on every run.
    gen = torch.Generator().manual_seed(42)
    train_ds, val_ds, test_ds = torch.utils.data.random_split(
        dataset, [train_len, val_len, test_len], generator=gen
    )

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    model = SimpleUnifiedBaseline(vocab_size=vocab_size, context_dim=10).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    def _inputs_and_targets(full_seq):
        """Return (model input, next-token targets) with serve targets set to 0 (ignored)."""
        x_in = full_seq[:, :-1]
        y_target = full_seq[:, 1:]
        y_masked = y_target.masked_fill(torch.isin(y_target, serve_ids_t), 0)
        return x_in, y_masked

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        n_batches = 0

        for batch in train_loader:
            full_seq = batch['x_seq'].to(device)
            ctx = batch['context'].to(device)
            x_in, y_masked = _inputs_and_targets(full_seq)

            # If every target is padding/serve the mean loss is 0/0 = NaN, which
            # would corrupt the weights on backward(); skip such batches.
            if not (y_masked != 0).any():
                continue

            optimizer.zero_grad()
            logits = model(x_in, ctx)
            loss = criterion(logits.reshape(-1, vocab_size), y_masked.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1

        # Validation on the 15% split, scored on the same targets as the loss.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                full_seq = batch['x_seq'].to(device)
                ctx = batch['context'].to(device)
                x_in, y_masked = _inputs_and_targets(full_seq)

                preds = model(x_in, ctx).argmax(dim=-1)
                mask = y_masked != 0
                correct += (preds[mask] == y_masked[mask]).sum().item()
                total += mask.sum().item()

        acc = (correct / total * 100) if total > 0 else 0
        avg_loss = total_loss / n_batches if n_batches else float('nan')
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Val Acc (15% split, serves excluded): {acc:.2f}%")

    # Return the model AND the held-out test set
    return model, test_ds

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import pandas as pd


# Which tour to load: 'wta' (women's charting data) or 'atp' (men's).
TOUR = 'wta'

# Kaggle input locations per tour; point files are listed newest decade first.
TOUR_SOURCES = {
    'wta': {
        'base_path': '/kaggle/input/wta-points/',
        'point_files': [
            'charting-w-points-2020s.csv',
            'charting-w-points-2010s.csv',
            'charting-w-points-to-2009.csv',
        ],
        'matches_path': '/kaggle/input/wta-matches/charting-w-matches.csv',
    },
    'atp': {
        'base_path': '/kaggle/input/atp-points/',
        'point_files': [
            'charting-m-points-2020s.csv',
            'charting-m-points-2010s.csv',
            'charting-m-points-to-2009.csv',
        ],
        'matches_path': '/kaggle/input/atp-matches-updated/charting-m-matches-updated.csv',
    },
}

base_path = TOUR_SOURCES[TOUR]['base_path']
# All point files to merge
point_files = [base_path + name for name in TOUR_SOURCES[TOUR]['point_files']]
matches_path = TOUR_SOURCES[TOUR]['matches_path']

# Player tables for both tours are always passed to the dataset.
atp_path = '/kaggle/input/atp-players/atp_players.csv'
wta_path = '/kaggle/input/wta-players/wta_players.csv'

datasetSingle = MCPTennisDataset(point_files, matches_path, atp_path, wta_path, max_seq_len=SEQ_LEN)

# Train the single-head baseline; it also returns the held-out 5% test subset.
baselineSingleHead, test_ds = train_singlehead_baseline(
    datasetSingle,
    epochs=10,
    batch_size=512,
    device=DEVICE,
)

# Loader over the 5% test split only.
test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

# Dataset positions of the test samples (the evaluation uses them to look up match ids).
test_indices = test_ds.indices

In [None]:
def _normalize_surface(surface):
    """Bucket a raw surface label into 'Clay', 'Grass' or 'Hard' (the fallback)."""
    # Metadata loaded through pandas can carry NaN (a float) for unknown surfaces;
    # `'Clay' in nan` would raise TypeError, so non-strings fall back to 'Hard'.
    if not isinstance(surface, str):
        return 'Hard'
    if 'Clay' in surface:
        return 'Clay'
    if 'Grass' in surface:
        return 'Grass'
    return 'Hard'


def check_surface_distribution(dataset):
    """Count and plot how many distinct matches in ``dataset`` were played on each surface.

    Every match id in ``dataset.sample_match_ids`` (one entry per rally sequence)
    is counted once. Its surface comes from
    ``dataset.match_meta[match_id]['surface']`` and is bucketed into 'Clay',
    'Grass' or 'Hard'. Missing metadata and non-string values count as 'Hard'.

    Prints a count/percentage table and shows a bar chart. Returns ``None``.
    It also returns early, with a message, if the dataset has no ``sample_match_ids``
    or contains no matches.
    """
    print("--- SURFACE DISTRIBUTION ANALYSIS ---")

    if not hasattr(dataset, 'sample_match_ids'):
        print("Error: Dataset does not track sample_match_ids.")
        return

    # 1. Unique matches that contributed at least one rally.
    unique_match_ids = set(dataset.sample_match_ids)
    print(f"Total Unique Matches with valid rallies: {len(unique_match_ids):,}")
    if not unique_match_ids:
        print("No matches found; nothing to plot.")
        return

    # 2. Count surfaces (default 'Hard' when the match has no metadata).
    surface_counts = Counter(
        _normalize_surface(dataset.match_meta.get(m_id, {}).get('surface', 'Hard'))
        for m_id in unique_match_ids
    )

    # 3. Table sorted by count, with a percentage column.
    df_surf = pd.DataFrame.from_dict(surface_counts, orient='index', columns=['Count'])
    df_surf.index.name = 'Surface'
    df_surf = df_surf.sort_values('Count', ascending=False)

    total_matches = df_surf['Count'].sum()
    df_surf['Percentage'] = (df_surf['Count'] / total_matches * 100).round(2)

    print("\n[Match Counts by Surface]")
    print(df_surf)

    # 4. Bar chart with conventional surface colours.
    plt.figure(figsize=(8, 5))
    colors = {'Hard': '#1f77b4', 'Clay': '#d62728', 'Grass': '#2ca02c'}
    palette = [colors.get(s, 'grey') for s in df_surf.index]

    sns.barplot(x=df_surf.index, y=df_surf['Count'], palette=palette)
    plt.title(f'Distribution of Surfaces (N={total_matches} Matches)')
    plt.ylabel('Number of Matches')
    plt.grid(axis='y', alpha=0.3)

    # Value labels just above each bar.
    for i, v in enumerate(df_surf['Count']):
        plt.text(i, v + (0.01*v), str(v), ha='center', fontweight='bold')

    plt.show()

# --- EXECUTE ---
# Only run the surface breakdown once the single-head dataset has been built.
if 'datasetSingle' not in globals():
    print("Please ensure 'datasetSingle' is loaded.")
else:
    check_surface_distribution(datasetSingle)

### Evaluation

In [None]:
# Decoder map built from the same dataset the single-head model was trained on.
universal_map = get_universal_decoder_map(datasetSingle)

# Wrap the trained model, its dataset and the decoder map for the evaluator.
adapter = UnifiedAdapter(model=baselineSingleHead, device=DEVICE,
                         dataset=datasetSingle, uni_map=universal_map)

# Evaluate on the held-out test split produced by the training cell.
evaluator = TennisEvaluator(adapter, test_loader, test_indices)
evaluator.run_all()

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from collections import Counter
import seaborn as sns
import pandas as pd
import numpy as np
import random
import torch

# --- 1. SHARED CONFIG & HELPERS ---
# Evaluation-only component vocabularies. Each maps one '_'-separated component
# of a unified token key (e.g. the parts of "f_2_8") to a small integer id.
# get_fast_decoder_map looks components up with .get(..., 0), so anything not
# listed here collapses onto id 0.
# Shot-type letters; serve tokens are decoded to the same id as 's'.
EVAL_SHOT_VOCAB = {'<pad>': 0, 'f': 1, 'b': 2, 'r': 3, 'v': 4, 'o': 5, 's': 6, 'u': 7, 'l': 8, 'm': 9, 'z': 10}
# Direction digits. '<pad>' and '0' both map to id 0, which the direction report
# excludes (its printed legend: 1=Right, 2=Center, 3=Left).
EVAL_DIR_VOCAB  = {'<pad>': 0, '0': 0, '1': 1, '2': 2, '3': 3}
# Depth digits 7/8/9 -> ids 1/2/3. '<pad>' and '0' both map to id 0, and the
# depth metrics skip targets with depth id 0.
EVAL_DEPTH_VOCAB = {'<pad>': 0, '0': 0, '7': 1, '8': 2, '9': 3}
# Same device selection as the global config cell; presumably repeated so this
# cell can be run on its own.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_fast_decoder_map(dataset, shot_vocab=None, dir_vocab=None, depth_vocab=None):
    """Build a lookup table ``unified_id -> (type_id, dir_id, depth_id)``.

    Each ``(uid, key)`` in ``dataset.inv_unified_vocab`` is decoded as follows:

    * ``uid`` 0 or 1 (assumed ``<pad>`` / ``<unk>``) -> ``(0, 0, 0)``;
    * serve keys -> ``(shot_vocab['s'], 0, 0)``. A serve key starts with
      ``'serve'`` (any case) or ``'S_'``, the same rule the training loops use to
      drop serve targets;
    * keys whose first ``_`` part is ``'LET'`` or ``'SPECIAL'`` (not shots) -> ``(0, 0, 0)``;
    * any other key is read as ``type[_dir[_depth[...]]]``. Each present
      component is looked up in its vocab, and missing or unknown components give 0.

    Args:
        dataset: Object exposing ``inv_unified_vocab`` (id -> key string).
        shot_vocab: Shot-type vocab; defaults to ``EVAL_SHOT_VOCAB``.
        dir_vocab: Direction vocab; defaults to ``EVAL_DIR_VOCAB``.
        depth_vocab: Depth vocab; defaults to ``EVAL_DEPTH_VOCAB``.

    Returns:
        dict mapping every id of ``inv_unified_vocab`` to a 3-tuple of ints.
    """
    # Resolve defaults at call time so the module-level vocabs are only needed when used.
    shot_vocab = EVAL_SHOT_VOCAB if shot_vocab is None else shot_vocab
    dir_vocab = EVAL_DIR_VOCAB if dir_vocab is None else dir_vocab
    depth_vocab = EVAL_DEPTH_VOCAB if depth_vocab is None else depth_vocab

    # NOTE(review): serves reuse the 's' type id, so after decoding a serve and an
    # 's' shot are indistinguishable — confirm this is intended.
    serve_id = shot_vocab.get('s', 0)

    uni_map = {}
    for uid, key in dataset.inv_unified_vocab.items():
        if uid <= 1:
            uni_map[uid] = (0, 0, 0)
            continue

        parts = key.split('_')
        if key.lower().startswith('serve') or key.startswith('S_'):
            uni_map[uid] = (serve_id, 0, 0)
        elif parts[0] in ('LET', 'SPECIAL'):
            # Lets / special markers are not shots: decode as 'no shot'.
            uni_map[uid] = (0, 0, 0)
        else:
            # Short keys such as "f" or "f_1" get 0 for the missing components.
            typ = shot_vocab.get(parts[0], 0)
            direction = dir_vocab.get(parts[1], 0) if len(parts) > 1 else 0
            depth = depth_vocab.get(parts[2], 0) if len(parts) > 2 else 0
            uni_map[uid] = (typ, direction, depth)
    return uni_map

def decode_unified_predictions(preds, dataset, uni_map=None):
    """Decode unified token ids into parallel (type, direction, depth) id lists.

    Args:
        preds: Iterable of unified token ids (Python ints or numpy integer scalars).
        dataset: Passed to ``get_fast_decoder_map`` when ``uni_map`` is not given.
        uni_map: Optional precomputed ``get_fast_decoder_map(dataset)`` result.
            Pass it when decoding several lists so the table is not rebuilt each time.

    Returns:
        tuple ``(types, dirs, depths)`` of equal-length lists. Ids absent from the
        map decode to ``(0, 0, 0)``.
    """
    if uni_map is None:
        uni_map = get_fast_decoder_map(dataset)
    types, dirs, depths = [], [], []
    for p in preds:
        t, d, dep = uni_map.get(p, (0, 0, 0))
        types.append(t)
        dirs.append(d)
        depths.append(dep)
    return types, dirs, depths

# --- 2. THE ADAPTED EVALUATION FUNCTION ---
def run_full_evaluation(model, dataset, loader, test_indices,
                        live_samples=5000,
                        length_matches=2000,
                        freq_matches=2000,
                        era_matches=50,
                        speed_samples=10000,
                        show_live_samples=False):
    """Run the six-part evaluation suite for a unified (single-head) shot model.

    ``model(x_seq, context)`` returns logits over the unified vocabulary at every
    position. The argmax at position ``t`` is compared with ``y_target[t]``, and
    unified ids are split into (type, direction, depth) with ``get_fast_decoder_map``.

    Parts:
        1. Classification reports for direction, depth and shot type over ``loader``.
        2. Random single-position "live" predictions, bucketed by how many of the
           three components are correct. Individual cases are printed only when
           ``show_live_samples`` is True.
        3. Single / pairwise / whole-shot accuracy vs. shot number with chance
           baselines, on up to ``length_matches`` test matches.
        4. Per-player error vs. how often the player appears (``freq_matches`` matches).
        5. Whole-shot accuracy by era (``era_matches`` matches per era).
        6. Component error rates by surface (up to ``speed_samples // 3`` samples
           per surface).

    Args:
        model: Callable ``model(x_seq, context) -> logits``; switched to eval mode here.
        dataset: Dataset exposing ``sample_match_ids``, ``match_meta``,
            ``inv_unified_vocab`` and items with ``x_seq``, ``context``,
            ``y_target``, ``x_s_id``, ``x_r_id``.
        loader: DataLoader over the test split (Part 1).
        test_indices: Dataset indices of the test split (e.g. ``Subset.indices``).
        live_samples, length_matches, freq_matches, era_matches, speed_samples:
            Sampling budgets for Parts 2-6.
        show_live_samples: Print each Part 2 case, not just the bucket counts.

    Returns:
        None. Results are printed and plotted. Returns early if the loader yields
        no non-padding targets.
    """
    model.eval()
    test_indices = list(test_indices)  # random.sample() requires a sequence
    print(f"Starting Full Evaluation on {len(test_indices)} test samples...")
    uni_map = get_fast_decoder_map(dataset)

    def _decode(ids):
        """Split unified ids into parallel (types, dirs, depths) lists using uni_map."""
        types, dirs, depths = [], [], []
        for uid in ids:
            t, d, dep = uni_map.get(uid, (0, 0, 0))
            types.append(t); dirs.append(d); depths.append(dep)
        return types, dirs, depths

    # Match id -> indices of its samples in the test split.
    match_map = {}
    for idx in test_indices:
        match_map.setdefault(dataset.sample_match_ids[idx], []).append(idx)
    unique_matches = list(match_map.keys())

    # ==============================================================================
    # PART 1: OVERALL TACTICAL METRICS
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 1: OVERALL TACTICAL METRICS \n" + "="*40)
    all_preds_unified, all_targets_unified = [], []

    print("Running Evaluation on TEST SET (Unified Model)...")
    with torch.no_grad():
        for batch in loader:
            x_seq = batch['x_seq'].to(DEVICE)
            x_c = batch['context'].to(DEVICE)
            y = batch['y_target'].to(DEVICE)

            logits = model(x_seq, x_c)

            # Drop padding positions (id 0) before scoring.
            flat_y = y.reshape(-1)
            mask = flat_y != 0
            all_preds_unified.extend(logits.argmax(-1).reshape(-1)[mask].cpu().numpy())
            all_targets_unified.extend(flat_y[mask].cpu().numpy())

    if not all_targets_unified:
        print("No non-padding targets found in the test loader; nothing to evaluate.")
        return

    print("Decoding predictions...")
    pred_t, pred_d, pred_dp = _decode(all_preds_unified)
    targ_t, targ_d, targ_dp = _decode(all_targets_unified)

    def _report(vocab, targets, preds, exclude):
        """Print a classification report restricted to vocab labels present in ``targets``."""
        present = set(np.unique(targets).tolist())
        names = [k for k, v in vocab.items() if v in present and k not in exclude]
        if not names:
            print("(no labelled targets for this component)")
            return
        print(classification_report(targets, preds, labels=[vocab[k] for k in names],
                                    target_names=names, zero_division=0))

    # 1. Direction ('<pad>' and '0' = undefined direction are excluded)
    print("\n=== DIRECTION REPORT (1=Right, 2=Center, 3=Left) ===")
    _report(EVAL_DIR_VOCAB, targ_d, pred_d, exclude=('<pad>', '0'))

    # 2. Depth
    print("\n=== DEPTH REPORT (7=Shallow, 8=Mid, 9=Deep) ===")
    _report(EVAL_DEPTH_VOCAB, targ_dp, pred_dp, exclude=('<pad>', '0'))

    # 3. Shot type
    print("\n=== SHOT TYPE REPORT ===")
    _report(EVAL_SHOT_VOCAB, targ_t, pred_t, exclude=('<pad>',))

    # ==============================================================================
    # PART 2: DETAILED LIVE SAMPLES
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 2: LIVE SAMPLES (Showing {live_samples} Cases) \n" + "="*40)

    inv_dir = {v: k for k, v in EVAL_DIR_VOCAB.items()}
    inv_typ = {v: k for k, v in EVAL_SHOT_VOCAB.items()}

    def d_lbl(x):
        """Readable label for a depth id."""
        if x == 1: return "Short"
        if x == 2: return "Deep"
        if x == 3: return "V.Deep"
        return "N/A"

    # Oversample 2x because some picks are skipped (too short / target is not a shot).
    selected_indices = random.sample(test_indices, min(live_samples * 2, len(test_indices)))
    results_buffer = {3: [], 2: [], 1: [], 0: []}
    printed_count = 0

    with torch.no_grad():
        for idx in selected_indices:
            if printed_count >= live_samples:
                break

            sample = dataset[idx]
            non_zeros = (sample['x_seq'] != 0).nonzero(as_tuple=True)[0]
            if len(non_zeros) < 2:
                continue

            # Predict at a random non-padding position.
            valid_indices = non_zeros.tolist()
            t = random.choice(valid_indices)

            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            logits = model(x_seq, x_c)

            # --- History string up to and including position t ---
            start_idx = valid_indices[0]
            history_parts = []
            for i in range(start_idx, t + 1):
                typ, d, _ = uni_map.get(sample['x_seq'][i].item(), (0, 0, 0))
                z_in = inv_dir.get(d, '?')
                t_in = inv_typ.get(typ, '?')
                if i == start_idx:
                    history_parts.append(f"[Serve {z_in}] " if t_in == 's' else f"[{t_in}{z_in}] ")
                else:
                    history_parts.append(f"-> {t_in}{z_in} ")
            history_str = "".join(history_parts)

            # --- Prediction vs. truth at position t ---
            probs = torch.softmax(logits[0, t], dim=0)
            pred_uid = probs.argmax().item()
            conf = probs.max().item() * 100
            p_t, p_d, p_dp = uni_map.get(pred_uid, (0, 0, 0))

            true_uid = sample['y_target'][t].item()
            tr_t, tr_d, tr_dp = uni_map.get(true_uid, (0, 0, 0))

            # Skip uninformative cases (target is padding / not a shot)
            if tr_t == 0:
                continue

            check_d = "✅" if p_d == tr_d else "❌"
            check_t = "✅" if p_t == tr_t else "❌"
            check_dp = "✅" if p_dp == tr_dp else "❌"
            score = int(p_d == tr_d) + int(p_t == tr_t) + int(p_dp == tr_dp)

            m_id = dataset.sample_match_ids[idx]
            p1 = dataset.match_meta.get(m_id, {}).get('p1_name', 'Unknown')

            out = [
                f"\nMatch {m_id} ({p1}):",
                f"  History:    {history_str}",
                f"  Prediction: {inv_typ.get(p_t, '?')} to {inv_dir.get(p_d, '?')} ({d_lbl(p_dp)}) | Conf: {conf:.0f}%",
                f"  ACTUAL:     {inv_typ.get(tr_t, '?')} to {inv_dir.get(tr_d, '?')} ({d_lbl(tr_dp)}) | {check_t} Type {check_d} Dir {check_dp} Dep",
            ]
            results_buffer[score].append("\n".join(out))
            printed_count += 1

    for s in (3, 2, 1, 0):
        items = results_buffer[s]
        if items:
            print(f"\n{'='*20} {s}/3 CORRECT ({len(items)} cases) {'='*20}")
            if show_live_samples:
                for item in items:
                    print(item)

    # ==============================================================================
    # PART 3: GRANULAR ACCURACY VS RALLY LENGTH
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 3: GRANULAR ACCURACY VS RALLY LENGTH \n" + "="*40)

    # 3.1 Chance baselines: probability that two independent draws from the
    # empirical label distribution agree (sum of squared class frequencies).
    print("Calculating dataset baselines...")
    all_d, all_dp, all_tp = [], [], []
    for i in test_indices:
        for uid in dataset[i]['y_target'].tolist():
            if uid <= 1:
                continue
            t, d, dep = uni_map.get(uid, (0, 0, 0))
            all_tp.append(t)
            all_d.append(d)
            if dep != 0:
                all_dp.append(dep)

    def calc_baseline(data_list):
        if not data_list:
            return 0.33
        counts = Counter(data_list)
        total = sum(counts.values())
        return sum((c / total) ** 2 for c in counts.values())

    base_d = calc_baseline(all_d)
    base_dp = calc_baseline(all_dp)
    base_tp = calc_baseline(all_tp)

    base_pair_avg = (base_d*base_dp + base_d*base_tp + base_tp*base_dp) / 3
    base_whole = base_d * base_dp * base_tp

    print(f"Baselines -> Dir: {base_d:.2f}, Depth: {base_dp:.2f}, Type: {base_tp:.2f}, Whole: {base_whole:.4f}")

    # 3.2 Analysis loop over whole matches so rally structure stays coherent.
    test_match_ids = [dataset.sample_match_ids[i] for i in test_indices]
    unique_matches_p3 = sorted(set(test_match_ids))
    selected_matches_p3 = random.sample(unique_matches_p3, min(length_matches, len(unique_matches_p3)))
    # Set lookup: membership in the sampled *list* made this filter O(N * length_matches).
    selected_set_p3 = set(selected_matches_p3)
    selected_indices_p3 = [i for i, m in zip(test_indices, test_match_ids) if m in selected_set_p3]

    print(f"Analyzing {len(selected_indices_p3)} points from {len(selected_matches_p3)} matches...")

    results_p3 = []
    with torch.no_grad():
        for idx in selected_indices_p3:
            sample = dataset[idx]
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)

            preds = model(x_seq, x_c).argmax(dim=-1).squeeze(0).tolist()
            y = sample['y_target'].tolist()
            seq = sample['x_seq'].tolist()

            seen = 0  # non-padding tokens in seq[:t+1]
            for t, tok in enumerate(seq):
                if tok == 0:
                    continue
                seen += 1
                shot_num = seen + 1  # number of the shot being predicted

                t_uid = y[t]
                if t_uid <= 1:
                    continue
                p_uid = preds[t]

                p_t, p_d, p_dp = uni_map.get(p_uid, (0, 0, 0))
                t_t, t_d, t_dp = uni_map.get(t_uid, (0, 0, 0))

                # Tasks that are always defined
                rows = [
                    ('Direction', 'Single', p_d == t_d),
                    ('Type', 'Single', p_t == t_t),
                    ('Dir + Type', 'Pair', p_d == t_d and p_t == t_t),
                ]
                # Depth-dependent tasks only when the target actually has a depth
                if t_dp != 0:
                    rows += [
                        ('Depth', 'Single', p_dp == t_dp),
                        ('Dir + Depth', 'Pair', p_d == t_d and p_dp == t_dp),
                        ('Type + Depth', 'Pair', p_t == t_t and p_dp == t_dp),
                    ]
                rows.append(('Whole Shot', 'Whole', p_uid == t_uid))

                results_p3.extend(
                    {'Shot_Number': shot_num, 'Task': task, 'Type': kind, 'Accuracy': int(hit)}
                    for task, kind, hit in rows
                )

    if results_p3:
        df_p3 = pd.DataFrame(results_p3)
        df_p3 = df_p3[(df_p3['Shot_Number'] <= 12) & (df_p3['Shot_Number'] >= 2)]

        # --- PLOTTING ---
        palette_single = {'Direction': '#1f77b4', 'Depth': '#d62728', 'Type': '#2ca02c'}
        palette_pair   = {'Dir + Depth': '#9467bd', 'Dir + Type': '#17becf', 'Type + Depth': '#ff7f0e'}
        palette_whole  = {'Whole Shot': '#000000'}

        def setup_plot(title, baseline, base_label):
            plt.figure(figsize=(12, 5))
            plt.title(title, fontsize=14)
            plt.ylabel('Accuracy', fontsize=12)
            plt.xlabel('Shot Number', fontsize=12)
            plt.xticks(np.arange(2, 13, 1))

            ax = plt.gca()
            ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.02))
            plt.grid(True, which='major', axis='y', linestyle='-', linewidth=0.75, color='grey', alpha=0.6)
            plt.grid(True, which='minor', axis='y', linestyle='--', linewidth=0.5, color='grey', alpha=0.3)
            plt.ylim(0.0, 1.0)

            if baseline:
                plt.axhline(baseline, color='#FF1493', linestyle=':', alpha=0.8, linewidth=2, label=base_label)

        # GRAPH 1: SINGLE
        setup_plot('Single Task Accuracy vs. Rally Length', base_tp, f'Random Type ({base_tp:.2f})')
        sns.lineplot(data=df_p3[df_p3['Type']=='Single'], x='Shot_Number', y='Accuracy', hue='Task', style='Task',
                     markers=True, dashes=False, palette=palette_single, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='lower right'); plt.show()

        # GRAPH 2: PAIRWISE
        setup_plot('Pairwise Accuracy vs. Rally Length', base_pair_avg, f'Random Pair (~{base_pair_avg:.2f})')
        sns.lineplot(data=df_p3[df_p3['Type']=='Pair'], x='Shot_Number', y='Accuracy', hue='Task', style='Task',
                     markers=True, dashes=False, palette=palette_pair, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

        # GRAPH 3: WHOLE SHOT
        setup_plot('Whole Shot Accuracy vs. Rally Length', base_whole, f'Random Whole ({base_whole:.4f})')
        sns.lineplot(data=df_p3[df_p3['Type']=='Whole'], x='Shot_Number', y='Accuracy', hue='Task', style='Task',
                     markers=True, dashes={'Whole Shot':(2,2)}, palette=palette_whole, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

    # ==============================================================================
    # PART 4: PLAYER FREQUENCY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 4: PLAYER FREQUENCY ({freq_matches} Matches) \n" + "="*40)
    chosen_matches = random.sample(unique_matches, min(freq_matches, len(unique_matches)))
    pf_indices = [idx for mid in chosen_matches for idx in match_map[mid]]

    p_counts = Counter()
    p_stats = {}

    with torch.no_grad():
        for idx in pf_indices:
            sample = dataset[idx]
            y = sample['y_target'].tolist()
            # Player ids are used only for the statistics; the model does not see them.
            s_id, r_id = sample['x_s_id'].item(), sample['x_r_id'].item()

            preds = model(sample['x_seq'].unsqueeze(0).to(DEVICE),
                          sample['context'].unsqueeze(0).to(DEVICE)).argmax(-1).squeeze(0).tolist()
            # nonzero_upto[t] == number of non-padding tokens in x_seq[:t+1]
            nonzero_upto = (sample['x_seq'] != 0).cumsum(0).tolist()
            last = len(nonzero_upto) - 1

            for t, target in enumerate(y):
                if target == 0:
                    continue
                hist_len = nonzero_upto[min(t, last)]
                # Odd shot numbers are hit by the server, even ones by the returner.
                actor = s_id if (hist_len + 1) % 2 != 0 else r_id
                if actor <= 1:
                    continue

                p_counts[actor] += 1
                stats = p_stats.setdefault(actor, {'tot': 0, 'corr': 0})
                stats['tot'] += 1
                if preds[t] == target:
                    stats['corr'] += 1

    pf_data = [{'Freq': p_counts[a], 'Err': (1 - v['corr']/v['tot'])*100} for a, v in p_stats.items() if p_counts[a] > 10]
    if pf_data:
        df_pf = pd.DataFrame(pf_data)
        plt.figure(figsize=(8, 4))
        sns.regplot(data=df_pf, x='Freq', y='Err', scatter_kws={'alpha':0.5}, line_kws={'color':'red'})
        plt.xscale('log'); plt.title(f"Error vs Frequency (Corr: {df_pf['Freq'].corr(df_pf['Err']):.2f})"); plt.show()

    # ==============================================================================
    # PART 5: ERA STABILITY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 5: ERA STABILITY ({era_matches} Matches/Era) \n" + "="*40)
    eras = {'Pre-2010': [], '2010-2019': [], '2020+': []}
    for m_id in unique_matches:
        try:
            y_year = int(str(m_id)[:4])
        except ValueError:
            continue  # match id does not start with a 4-digit year
        if y_year < 2010: eras['Pre-2010'].append(m_id)
        elif y_year < 2020: eras['2010-2019'].append(m_id)
        else: eras['2020+'].append(m_id)

    era_indices = []
    era_labels_list = []
    for era_name, m_list in eras.items():
        if not m_list:
            continue
        chosen = random.sample(m_list, min(era_matches, len(m_list)))
        for m in chosen:
            era_indices.extend(match_map[m])
            era_labels_list.extend([era_name]*len(match_map[m]))

    era_res = []
    with torch.no_grad():
        for era_label, idx in zip(era_labels_list, era_indices):
            sample = dataset[idx]
            y = sample['y_target'].to(DEVICE)
            preds = model(sample['x_seq'].unsqueeze(0).to(DEVICE),
                          sample['context'].unsqueeze(0).to(DEVICE)).argmax(-1).squeeze(0)
            mask = y != 0
            if mask.sum() > 0:
                acc = (preds[mask] == y[mask]).float().mean().item()
                era_res.append({'Era': era_label, 'Whole Shot Acc': acc})

    if era_res:
        plt.figure(figsize=(6, 4))
        sns.barplot(data=pd.DataFrame(era_res), x='Era', y='Whole Shot Acc', palette='viridis', order=['Pre-2010', '2010-2019', '2020+'])
        plt.title('Accuracy by Era'); plt.ylim(0, 1); plt.show()

    # ==============================================================================
    # PART 6: RAW ERROR ANALYSIS BY SURFACE (Unified Model + Masked Depth)
    # ==============================================================================
    print("\n" + "="*50)
    print(" RAW ERROR ANALYSIS BY SURFACE (Unified Model) ")
    print("="*50)

    # 1. Group test indices by surface (unknown / non-string surfaces -> 'Hard').
    surface_map = {'Clay': [], 'Hard': [], 'Grass': []}
    for idx in test_indices:
        surf = dataset.match_meta.get(dataset.sample_match_ids[idx], {}).get('surface', 'Hard')
        if not isinstance(surf, str):
            surf = 'Hard'  # e.g. NaN from pandas-loaded metadata
        bucket = next((k for k in surface_map if k in surf), 'Hard')
        surface_map[bucket].append(idx)

    # 2. Select samples balanced by surface.
    selected_indices, surface_labels = [], []
    per_surf = speed_samples // 3

    for s, inds in surface_map.items():
        if not inds:
            continue
        chosen = random.sample(inds, min(len(inds), per_surf))
        selected_indices.extend(chosen)
        surface_labels.extend([s]*len(chosen))

    # 3. Evaluation loop (uni_map from the top of the function is reused).
    results = []
    with torch.no_grad():
        for idx, surf in zip(selected_indices, surface_labels):
            sample = dataset[idx]

            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c   = sample['context'].unsqueeze(0).to(DEVICE)
            y_tgt = sample['y_target'].tolist()

            preds = model(x_seq, x_c).argmax(dim=-1).squeeze(0).tolist()  # [SeqLen]

            for t in range(x_seq.shape[1]):
                t_uid = y_tgt[t]
                p_uid = preds[t]

                # Skip padding / unknown (ids 0 and 1)
                if t_uid <= 1:
                    continue

                p_t, p_d, p_dp = uni_map.get(p_uid, (0, 0, 0))
                t_t, t_d, t_dp = uni_map.get(t_uid, (0, 0, 0))

                # Skip targets that decode to 'no shot' (e.g. lets / specials)
                if t_t == 0:
                    continue

                # Whole shot: exact same unified token.
                whole_shot_miss = (p_uid != t_uid)
                type_err = 1.0 if p_t != t_t else 0.0
                dir_err  = 1.0 if p_d != t_d else 0.0
                # Depth is only scored when the TARGET has a depth (!= 0).
                depth_err = (1.0 if p_dp != t_dp else 0.0) if t_dp != 0 else None

                results.append({
                    'Surface': surf,
                    'Type Error': type_err,
                    'Direction Error': dir_err,
                    'Depth Error': depth_err,
                    'Whole Shot Error': 1.0 if whole_shot_miss else 0.0
                })

    # 4. Statistics & plotting
    if not results:
        print("No results generated.")
        return

    df = pd.DataFrame(results)

    # None in 'Depth Error' becomes NaN in the frame and is skipped by mean().
    stats = df.groupby('Surface')[['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error']].mean() * 100
    print("\n--- Mean Error Rates (%) [Depth calc only on non-zero targets] ---")
    print(stats.round(2))

    df_melt = df.melt(id_vars=['Surface'],
                      value_vars=['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error'],
                      value_name='Error Rate')

    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_melt, x='Surface', y='Error Rate', hue='variable',
                order=['Clay', 'Hard', 'Grass'], palette='viridis')
    plt.title('Component Error Rates (Unified Model) by Surface')
    plt.ylabel('Error Rate (0.0 - 1.0)')
    plt.legend(title='Metric')
    plt.grid(axis='y', alpha=0.3)
    plt.show()

# --- 3. RUNNER ---
# Every name used below comes from the training cell; check them all up front so a
# partially-run notebook reports what is missing instead of raising NameError.
_required = ('datasetSingle', 'baselineSingleHead', 'test_ds', 'test_indices')
_missing = [name for name in _required if name not in globals()]

if not _missing:
    print(f"Evaluating on the held-out 5% test set ({len(test_ds)} samples)...")

    # Loader over the test split only
    test_loader = DataLoader(test_ds, batch_size=64, shuffle=False)

    run_full_evaluation(
        model=baselineSingleHead,
        dataset=datasetSingle,
        loader=test_loader,
        test_indices=test_indices,
        live_samples=5000,
        length_matches=2000,
        freq_matches=2000,
    )
else:
    print(f"Error: {', '.join(repr(n) for n in _missing)} not found - run the training cell first.")

## (RICH) LSTM

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

def _rich_batch_to_device(batch, device):
    """Return (x_type, x_dir, x_depth, x_s_id, x_r_id, context, y_target) of a batch moved to `device`."""
    keys = ('x_type', 'x_dir', 'x_depth', 'x_s_id', 'x_r_id', 'context', 'y_target')
    return tuple(batch[key].to(device) for key in keys)


def _mask_serve_targets(y, serve_ids):
    """Return a copy of `y` with every id listed in `serve_ids` replaced by 0 (the ignored index)."""
    if serve_ids.numel() == 0:
        return y.clone()
    return y.masked_fill(torch.isin(y, serve_ids), 0)


def train_rich_model(dataset, epochs=10, batch_size=64, lr=1e-3, device='cuda'):
    """Train a RichInputLSTM on an 80/10/10 random split of `dataset` and return it.

    Serve tokens (unified-vocab keys starting with 'serve' in any case, or with 'S_')
    are replaced by 0 in the targets. CrossEntropyLoss(ignore_index=0) therefore skips
    them, and the reported validation accuracy covers non-serve, non-pad targets only.

    Args:
        dataset: must expose unified_vocab, player_vocab, type_vocab, dir_vocab and
            depth_vocab, and yield dict samples with the keys x_type, x_dir, x_depth,
            x_s_id, x_r_id, context and y_target.
        epochs: number of passes over the training split.
        batch_size: mini-batch size of the train and validation loaders.
        lr: AdamW learning rate.
        device: device the model and every batch are moved to.

    Returns:
        The trained model (in eval mode after the last validation pass when epochs > 0).
    """
    print("--- STARTING RICH INPUT LSTM TRAINING ---")

    serve_token_ids = {
        idx for key, idx in dataset.unified_vocab.items()
        if key.lower().startswith('serve') or key.startswith('S_')
    }
    print(f"Training will ignore {len(serve_token_ids)} serve tokens in targets.")
    # One tensor of serve ids so masking is a single vectorised op per batch.
    serve_ids = torch.tensor(sorted(serve_token_ids), dtype=torch.long, device=device)

    total_len = len(dataset)
    train_len = int(0.80 * total_len)
    val_len   = int(0.10 * total_len)
    test_len  = total_len - train_len - val_len

    print(f"Total Samples: {total_len}")
    print(f"Splits -> Train: {train_len}, Val: {val_len}, Test: {test_len}")

    # Fixed generator seed so the evaluation cell can re-create the same test split.
    train_ds, val_ds, _test_ds = random_split(
        dataset,
        [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(42)
    )

    # Honour the caller's batch_size. The global BATCH_SIZE was used here before, so
    # e.g. batch_size=512 was silently ignored.
    # num_workers=0 is safer for Windows/Debugging. Increase for speed on Linux.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    model = RichInputLSTM(
        unified_vocab_size=len(dataset.unified_vocab),
        num_players=len(dataset.player_vocab),
        type_vocab_size=len(dataset.type_vocab),
        dir_vocab_size=len(dataset.dir_vocab),
        depth_vocab_size=len(dataset.depth_vocab),
        context_dim=10
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    vocab_size = len(dataset.unified_vocab)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch in train_loader:
            x_t, x_d, x_dp, s_id, r_id, ctx, y = _rich_batch_to_device(batch, device)
            y_masked = _mask_serve_targets(y, serve_ids)

            optimizer.zero_grad()
            logits = model(x_t, x_d, x_dp, s_id, r_id, ctx)
            loss = criterion(logits.view(-1, vocab_size), y_masked.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Validation: accuracy over non-serve, non-pad targets.
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for batch in val_loader:
                *inputs, y = _rich_batch_to_device(batch, device)
                y_val_masked = _mask_serve_targets(y, serve_ids)

                preds = model(*inputs).argmax(dim=-1)
                mask = y_val_masked != 0
                n_valid = mask.sum().item()
                if n_valid > 0:
                    correct += (preds[mask] == y_val_masked[mask]).sum().item()
                    total += n_valid

        acc = (correct / total * 100) if total > 0 else 0
        # max(1, ...) avoids ZeroDivisionError when the training split is empty.
        avg_loss = total_loss / max(1, len(train_loader))
        print(f"Epoch {epoch+1} | Loss: {avg_loss:.4f} | Val Acc (Non-Serve): {acc:.2f}%")

    return model

In [None]:
base_path = '/kaggle/input/atp-points/'

# Point-by-point CSVs merged into a single dataset (newest decade first).
point_files = [
    base_path + f'charting-m-points-{suffix}.csv'
    for suffix in ('2020s', '2010s', 'to-2009')
]

# Match file plus the ATP / WTA player files.
matches_path = '/kaggle/input/atp-matches-updated/charting-m-matches-updated.csv'
atp_path = '/kaggle/input/atp-players/atp_players.csv'
wta_path = '/kaggle/input/wta-players/wta_players.csv'

datasetEnhanced = EnhancedTennisDataset(
    points_paths_list=point_files,
    matches_path=matches_path,
    atp_path=atp_path,
    wta_path=wta_path,
    max_seq_len=SEQ_LEN,  # how much rally history each sample looks at
)

# NOTE(review): the dataset is assigned just above, so this check always passes.
if 'datasetEnhanced' in globals():
    richLSTM = train_rich_model(datasetEnhanced, epochs=15, batch_size=512, device=DEVICE)
else:
    print("Please ensure 'dataset' is loaded.")

In [None]:
CHECKPOINT_PATH = "/kaggle/input/richlstm/transformers/default/1/rich_best.pt"

# Point-by-point files to merge, newest decade first.
base_path = '/kaggle/input/atp-points/'
point_files = [base_path + fname for fname in (
    'charting-m-points-2020s.csv',
    'charting-m-points-2010s.csv',
    'charting-m-points-to-2009.csv',
)]

# Match file and the ATP / WTA player files.
matches_path = '/kaggle/input/atp-matches-updated/charting-m-matches-updated.csv'
atp_path, wta_path = (
    '/kaggle/input/atp-players/atp_players.csv',
    '/kaggle/input/wta-players/wta_players.csv',
)

def align_and_load_rich_lstm(checkpoint_path, dataset, device="cuda"):
    """Align `dataset` with a saved RichInputLSTM checkpoint and return the loaded model.

    1. Loads the checkpoint, which must be a dict containing 'unified_vocab'.
    2. Re-maps `dataset.y_target_tensor` and `dataset.x_seq_tensor` from the dataset's
       unified-vocab ids to the checkpoint's ids. Tokens unknown to the checkpoint
       become its '<unk>' id (1 if the checkpoint has no '<unk>' entry).
    3. Replaces the dataset's vocabularies with the checkpoint's where present.
    4. Builds RichInputLSTM with the resulting vocab sizes, loads the weights (from
       'model_state_dict' if present, otherwise from the tensor entries of the
       checkpoint itself) and puts it in eval mode.

    Args:
        checkpoint_path: file passed to torch.load.
        dataset: object exposing unified_vocab, y_target_tensor, x_seq_tensor,
            player_vocab, type_vocab, dir_vocab and depth_vocab. Mutated in place.
        device: map_location for torch.load and device of the returned model.

    Returns:
        The RichInputLSTM in eval mode.

    Raises:
        ValueError: if the checkpoint has no 'unified_vocab' entry.
    """
    print(f"📂 Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # --- STEP 1: VERIFY VOCABS EXIST ---
    if "unified_vocab" not in checkpoint:
        raise ValueError("Checkpoint must contain 'unified_vocab' to perform alignment!")

    ckpt_vocab = checkpoint["unified_vocab"]
    ds_vocab = dataset.unified_vocab  # the vocab the dataset tensors were built with

    print(f"📊 Dataset Vocab Size: {len(ds_vocab)} | Checkpoint Vocab Size: {len(ckpt_vocab)}")

    # --- STEP 2: BUILD TRANSLATION TENSOR (Dataset ID -> Checkpoint ID) ---
    print("🔄 Building translation map (Dataset IDs -> Checkpoint IDs)...")
    unk_id = ckpt_vocab.get('<unk>', 1)

    # Indexed by dataset id. Ids that no dataset token uses stay 0.
    map_tensor = torch.zeros(max(ds_vocab.values()) + 1, dtype=torch.long)
    for word, ds_id in ds_vocab.items():
        map_tensor[ds_id] = ckpt_vocab.get(word, unk_id)

    # --- STEP 3: TRANSLATE TENSORS ---
    print("⚡ Translating dataset tensors...")

    def _translate(ids):
        # Tensor indexing (fast) on the same device as the tensor being translated.
        return map_tensor.to(ids.device)[ids]

    # Targets (needed for correct error calculation).
    dataset.y_target_tensor = _translate(dataset.y_target_tensor)
    # Unified input sequence. RichInputLSTM reads the decomposed inputs; translated to be safe.
    dataset.x_seq_tensor = _translate(dataset.x_seq_tensor)
    # NOTE(review): the decomposed inputs and player ids are NOT re-mapped, although
    # their vocabs may be replaced below. Confirm dataset and checkpoint share them.

    print("✅ Tensors aligned.")

    # --- STEP 4: OVERWRITE DATASET VOCABULARIES ---
    # The tensors now hold checkpoint ids, so decoders must use the checkpoint's strings.
    dataset.unified_vocab = ckpt_vocab
    dataset.inv_unified_vocab = {v: k for k, v in ckpt_vocab.items()}
    for vocab_name in ("player_vocab", "type_vocab", "dir_vocab", "depth_vocab"):
        if vocab_name in checkpoint:
            setattr(dataset, vocab_name, checkpoint[vocab_name])

    print("📦 Dataset vocabularies updated to match checkpoint.")

    # --- STEP 5: LOAD MODEL ---
    model = RichInputLSTM(
        unified_vocab_size=len(dataset.unified_vocab),
        num_players=len(dataset.player_vocab),
        type_vocab_size=len(dataset.type_vocab),
        dir_vocab_size=len(dataset.dir_vocab),
        depth_vocab_size=len(dataset.depth_vocab),
        context_dim=10,
        hidden_dim=256,
        num_layers=2,
        dropout=0.2
    ).to(device)

    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        # Flat checkpoint: the weights sit next to the vocab entries. Passing the whole
        # dict would always fail strict loading on the 'unified_vocab' key, so keep tensors only.
        state_dict = {key: value for key, value in checkpoint.items() if torch.is_tensor(value)}
    model.load_state_dict(state_dict)

    model.eval()
    print("🚀 Model weights loaded successfully.")

    return model

### Evaluation

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from collections import Counter
import seaborn as sns
import pandas as pd
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, random_split

# --- 1. SHARED CONFIG & HELPERS ---
# Component label ids used only by the evaluation code. '<pad>' and '0' both map to 0,
# meaning "no direction / no depth". Shot letters are numbered 1..10 in this order.
EVAL_SHOT_VOCAB = {'<pad>': 0, **{letter: i for i, letter in enumerate('fbrvosulmz', start=1)}}
EVAL_DIR_VOCAB  = {'<pad>': 0, **{str(d): d for d in range(4)}}
EVAL_DEPTH_VOCAB = {'<pad>': 0, '0': 0, **{code: rank for rank, code in enumerate('789', start=1)}}
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_fast_decoder_map(dataset):
    """Map every unified token id of `dataset` to a (type, direction, depth) id triple.

    Keys are split on '_' and decoded as follows:
      * ids 0 and 1 -> (0, 0, 0)
      * keys whose first part is 'Serve' -> (serve type id, 0, 0)
      * keys whose first part is 'LET' or 'SPECIAL' -> (0, 0, 0)
      * anything else is read as '<type>[_<dir>[_<depth>]]'; missing or unknown parts
        fall back to 0, so short keys never raise IndexError.
    """
    serve_triple = (EVAL_SHOT_VOCAB.get('s', 0), 0, 0)
    empty_triple = (0, 0, 0)

    def decode(uid, key):
        if uid <= 1:
            return empty_triple
        head, *rest = key.split('_')
        if head == 'Serve':
            return serve_triple
        if head in ('LET', 'SPECIAL'):
            return empty_triple
        direction = EVAL_DIR_VOCAB.get(rest[0], 0) if rest else 0
        depth = EVAL_DEPTH_VOCAB.get(rest[1], 0) if len(rest) > 1 else 0
        return (EVAL_SHOT_VOCAB.get(head, 0), direction, depth)

    return {uid: decode(uid, key) for uid, key in dataset.inv_unified_vocab.items()}

def decode_unified_predictions(preds, dataset):
    """Split unified token ids into three parallel lists: (types, dirs, depths).

    Ids missing from the dataset's decoder map decode to (0, 0, 0).
    """
    lookup = get_fast_decoder_map(dataset)
    triples = [lookup.get(p, (0, 0, 0)) for p in preds]
    types = [typ for typ, _, _ in triples]
    dirs = [direction for _, direction, _ in triples]
    depths = [depth for _, _, depth in triples]
    return types, dirs, depths

# --- 2. THE RICH-INPUT EVALUATION FUNCTION ---

# Sample/batch keys passed positionally to RichInputLSTM.forward, in order.
_RICH_INPUT_KEYS = ('x_type', 'x_dir', 'x_depth', 'x_s_id', 'x_r_id', 'context')


def _rich_inputs_from_batch(batch, device):
    """Return the six model-input tensors of a DataLoader batch, moved to `device`."""
    return tuple(batch[key].to(device) for key in _RICH_INPUT_KEYS)


def _rich_inputs_from_sample(sample, device):
    """Return the six model-input tensors of one dataset sample with a batch dim of 1, on `device`."""
    return tuple(sample[key].unsqueeze(0).to(device) for key in _RICH_INPUT_KEYS)


def _print_component_report(vocab, targets, preds, excluded, empty_msg):
    """Print a classification report restricted to `vocab` labels that occur in `targets`.

    Label names listed in `excluded` are skipped. When no label remains, `empty_msg` is
    printed instead, because classification_report cannot take an empty label list.
    """
    present = set(np.unique(targets).tolist())
    names = [k for k, v in vocab.items() if v in present and k not in excluded]
    ids = [vocab[k] for k in names]
    if ids:
        print(classification_report(targets, preds, labels=ids, target_names=names, zero_division=0))
    else:
        print(empty_msg)


def _depth_label(depth_id):
    """Readable name for an EVAL_DEPTH_VOCAB id (1/2/3); anything else is 'N/A'."""
    return {1: "Short", 2: "Deep", 3: "V.Deep"}.get(depth_id, "N/A")


def run_full_evaluation(model, dataset, loader, test_indices,
                        live_samples=5000,
                        length_matches=2000,
                        freq_matches=2000,
                        era_matches=50,
                        speed_samples=10000,
                        print_live_cases=False):
    """Print and plot a six-part evaluation of a unified-output RichInputLSTM.

    Parts: (1) classification reports for direction, depth and shot type over `loader`;
    (2) live single-position predictions bucketed by how many components are right;
    (3) accuracy vs. shot number against frequency baselines; (4) error vs. how often
    a player appears; (5) whole-shot accuracy by era; (6) component error rates by
    surface. Returns None.

    Args:
        model: called as model(x_type, x_dir, x_depth, x_s_id, x_r_id, context).
            Expected to return logits indexed [batch, position, unified_vocab].
        dataset: indexable; samples are dicts holding the keys above plus 'x_seq' and
            'y_target'. Must also expose inv_unified_vocab, sample_match_ids and
            match_meta (match id -> dict with a 'surface' entry).
        loader: DataLoader over the test split (used by part 1 only).
        test_indices: dataset indices of the test split.
        live_samples: maximum number of live cases collected in part 2.
        length_matches: number of test matches sampled for part 3.
        freq_matches: number of test matches sampled for part 4.
        era_matches: maximum number of matches sampled per era in part 5. The era is
            read from the first four characters of the match id.
        speed_samples: total sample budget of part 6, split evenly over surfaces.
        print_live_cases: print every collected live case in part 2. By default only
            the per-score counts are printed.
    """
    model.eval()
    # Send inputs to wherever the model's weights live; fall back to the global DEVICE.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    test_indices = list(test_indices)  # random.sample below needs a real sequence
    print(f"Starting Full Evaluation on {len(test_indices)} test samples...")
    uni_map = get_fast_decoder_map(dataset)

    # Pre-calculate Match-to-Index Map
    match_map = {}
    for idx in test_indices:
        match_map.setdefault(dataset.sample_match_ids[idx], []).append(idx)
    unique_matches = list(match_map.keys())

    # ==============================================================================
    # PART 1: OVERALL TACTICAL METRICS
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 1: OVERALL TACTICAL METRICS \n" + "="*40)
    all_preds_unified, all_targets_unified = [], []

    with torch.no_grad():
        for batch in loader:
            y = batch['y_target'].to(device)
            logits = model(*_rich_inputs_from_batch(batch, device))

            mask = y.view(-1) != 0  # drop padded positions
            all_preds_unified.extend(logits.argmax(-1).view(-1)[mask].cpu().numpy())
            all_targets_unified.extend(y.view(-1)[mask].cpu().numpy())

    pred_t, pred_d, pred_dp = decode_unified_predictions(all_preds_unified, dataset)
    targ_t, targ_d, targ_dp = decode_unified_predictions(all_targets_unified, dataset)

    # 1. Direction Report
    print("\n=== DIRECTION REPORT (1=Cross, 2=Center, 3=Line) ===")
    _print_component_report(EVAL_DIR_VOCAB, targ_d, pred_d, ('<pad>', '0'),
                            "No valid directions found in targets.")

    # 2. Depth Report
    print("\n=== DEPTH REPORT (7=Shallow, 8=Mid, 9=Deep) ===")
    _print_component_report(EVAL_DEPTH_VOCAB, targ_dp, pred_dp, ('<pad>', '0'),
                            "No valid depths found in targets.")

    # 3. Shot Type Report
    print("\n=== SHOT TYPE REPORT ===")
    _print_component_report(EVAL_SHOT_VOCAB, targ_t, pred_t, ('<pad>',),
                            "No valid shot types found in targets.")

    # ==============================================================================
    # PART 2: LIVE TEST CASES (Detailed History & Confidence)
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 2: LIVE SAMPLES (Subset of {live_samples}) \n" + "="*40)

    inv_dir = {v: k for k, v in EVAL_DIR_VOCAB.items()}
    inv_typ = {v: k for k, v in EVAL_SHOT_VOCAB.items()}

    # Draw twice as many candidates as needed because some are skipped below.
    selected_indices = random.sample(test_indices, min(live_samples * 2, len(test_indices)))
    results_buffer = {3: [], 2: [], 1: [], 0: []}
    printed_count = 0

    with torch.no_grad():
        for idx in selected_indices:
            if printed_count >= live_samples:
                break

            sample = dataset[idx]
            non_zeros = (sample['x_seq'] != 0).nonzero(as_tuple=True)[0]
            if len(non_zeros) < 2:
                continue

            # Predict at one random non-padded position.
            valid_indices = non_zeros.tolist()
            t = random.choice(valid_indices)

            logits = model(*_rich_inputs_from_sample(sample, device))

            # --- Build history string from the unified sequence up to position t ---
            start_idx = valid_indices[0]
            history_str = ""
            for i in range(start_idx, t + 1):
                typ, d, _ = uni_map.get(sample['x_seq'][i].item(), (0, 0, 0))
                z_in = inv_dir.get(d, '?')
                t_in = inv_typ.get(typ, '?')
                if i == start_idx:
                    history_str += f"[Serve {z_in}] " if t_in == 's' else f"[{t_in}{z_in}] "
                else:
                    history_str += f"-> {t_in}{z_in} "

            probs = torch.softmax(logits[0, t], dim=0)
            pred_uid = probs.argmax().item()
            conf = probs.max().item() * 100

            live_pt, live_pd, live_pdp = uni_map.get(pred_uid, (0, 0, 0))
            true_uid = sample['y_target'][t].item()
            true_t, true_d, true_dp = uni_map.get(true_uid, (0, 0, 0))

            if true_t == 0:
                continue  # target decodes to no shot type (pad/unk/special): nothing to compare

            s_pred_d = inv_dir.get(live_pd, '?'); s_pred_t = inv_typ.get(live_pt, '?')
            s_true_d = inv_dir.get(true_d, '?'); s_true_t = inv_typ.get(true_t, '?')

            check_d = "✅" if live_pd == true_d else "❌"
            check_t = "✅" if live_pt == true_t else "❌"
            check_dp = "✅" if live_pdp == true_dp else "❌"

            score = int(live_pd == true_d) + int(live_pt == true_t) + int(live_pdp == true_dp)
            m_id = dataset.sample_match_ids[idx]

            out = [
                f"\nMatch {m_id}:",
                f"  History:    {history_str}",
                f"  Prediction: {s_pred_t} to {s_pred_d} ({_depth_label(live_pdp)}) | Conf: {conf:.0f}%",
                f"  ACTUAL:     {s_true_t} to {s_true_d} ({_depth_label(true_dp)}) | {check_t} Type {check_d} Dir {check_dp} Dep",
            ]
            results_buffer[score].append("\n".join(out))
            printed_count += 1

    for s in [3, 2, 1, 0]:
        items = results_buffer[s]
        if items:
            print(f"\n{'='*20} {s}/3 CORRECT ({len(items)} cases) {'='*20}")
            if print_live_cases:
                for item in items:
                    print(item)

    # ==============================================================================
    # PART 3: GRANULAR ACCURACY VS RALLY LENGTH (3 Graphs)
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 3: RALLY LENGTH vs ACCURACY ({length_matches} Matches) \n" + "="*40)

    # 3.1 Calculate Baselines
    print("Calculating dataset baselines...")
    all_d, all_dp, all_tp = [], [], []
    for i in test_indices:
        for uid in dataset[i]['y_target'].tolist():
            if uid <= 1:
                continue
            typ, d, dep = uni_map.get(uid, (0, 0, 0))
            all_tp.append(typ); all_d.append(d)
            if dep != 0:
                all_dp.append(dep)

    def calc_baseline(data_list):
        # Accuracy of a guesser drawing from the empirical label distribution: sum(p_i^2).
        if not data_list:
            return 0.33
        counts = Counter(data_list)
        total = sum(counts.values())
        return sum((c / total) ** 2 for c in counts.values())

    base_d = calc_baseline(all_d)
    base_dp = calc_baseline(all_dp)
    base_tp = calc_baseline(all_tp)
    base_pair_avg = (base_d*base_dp + base_d*base_tp + base_tp*base_dp) / 3
    base_whole = base_d * base_dp * base_tp
    print(f"Baselines -> Dir: {base_d:.2f}, Depth: {base_dp:.2f}, Type: {base_tp:.2f}, Whole: {base_whole:.4f}")

    # 3.2 Analysis Loop
    chosen_matches = random.sample(unique_matches, min(length_matches, len(unique_matches)))
    rl_indices = [idx for mid in chosen_matches for idx in match_map[mid]]

    rl_results = []
    with torch.no_grad():
        for idx in rl_indices:
            sample = dataset[idx]
            y = sample['y_target'].to(device)
            preds = model(*_rich_inputs_from_sample(sample, device)).argmax(-1).squeeze(0)
            x_seq_cpu = sample['x_seq']

            for t in range(x_seq_cpu.shape[0]):
                if x_seq_cpu[t] == 0 or y[t] <= 1:
                    continue

                # Shot count logic
                shot_num = (x_seq_cpu[:t+1] != 0).sum().item() + 1

                p_uid = preds[t].item(); t_uid = y[t].item()
                p_typ, p_dir, p_dep = uni_map.get(p_uid, (0, 0, 0))
                t_typ, t_dir, t_dep = uni_map.get(t_uid, (0, 0, 0))
                dir_ok, typ_ok, dep_ok = p_dir == t_dir, p_typ == t_typ, p_dep == t_dep

                rl_results.append({'Shot_Number': shot_num, 'Task': 'Direction', 'Type': 'Single', 'Acc': int(dir_ok)})
                rl_results.append({'Shot_Number': shot_num, 'Task': 'Type', 'Type': 'Single', 'Acc': int(typ_ok)})
                # Uses the decoded direction (p_dir), not the pandas module `pd`.
                rl_results.append({'Shot_Number': shot_num, 'Task': 'Dir + Type', 'Type': 'Pair', 'Acc': int(dir_ok and typ_ok)})

                if t_dep != 0:  # depth is only scored when the target has one
                    rl_results.append({'Shot_Number': shot_num, 'Task': 'Depth', 'Type': 'Single', 'Acc': int(dep_ok)})
                    rl_results.append({'Shot_Number': shot_num, 'Task': 'Dir + Depth', 'Type': 'Pair', 'Acc': int(dir_ok and dep_ok)})
                    rl_results.append({'Shot_Number': shot_num, 'Task': 'Type + Depth', 'Type': 'Pair', 'Acc': int(typ_ok and dep_ok)})

                rl_results.append({'Shot_Number': shot_num, 'Task': 'Whole Shot', 'Type': 'Whole', 'Acc': int(p_uid == t_uid)})

    if rl_results:
        df = pd.DataFrame(rl_results)
        df = df[(df['Shot_Number'] >= 2) & (df['Shot_Number'] <= 12)]

        palette_single = {'Direction': '#1f77b4', 'Depth': '#d62728', 'Type': '#2ca02c'}
        palette_pair   = {'Dir + Depth': '#9467bd', 'Dir + Type': '#17becf', 'Type + Depth': '#ff7f0e'}
        palette_whole  = {'Whole Shot': '#000000'}

        def setup_plot(title, baseline, base_label):
            plt.figure(figsize=(12, 5))
            plt.title(title, fontsize=14)
            plt.ylabel('Accuracy', fontsize=12)
            plt.xlabel('Shot Number', fontsize=12)
            plt.xticks(np.arange(2, 13, 1))

            ax = plt.gca()
            ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.02))
            plt.grid(True, which='major', axis='y', linestyle='-', linewidth=0.75, color='grey', alpha=0.6)
            plt.grid(True, which='minor', axis='y', linestyle='--', linewidth=0.5, color='grey', alpha=0.3)
            plt.ylim(0.0, 1.0)

            if baseline:
                plt.axhline(baseline, color='#FF1493', linestyle=':', alpha=0.8, linewidth=2, label=base_label)

        # Graph 1: Single
        setup_plot('Single Task Accuracy vs. Rally Length', base_tp, f'Random Type ({base_tp:.2f})')
        sns.lineplot(data=df[df['Type']=='Single'], x='Shot_Number', y='Acc', hue='Task', style='Task',
                     markers=True, dashes=False, palette=palette_single, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='lower right'); plt.show()

        # Graph 2: Pairwise
        setup_plot('Pairwise Accuracy vs. Rally Length', base_pair_avg, f'Random Pair (~{base_pair_avg:.2f})')
        sns.lineplot(data=df[df['Type']=='Pair'], x='Shot_Number', y='Acc', hue='Task', style='Task',
                     markers=True, dashes=False, palette=palette_pair, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

        # Graph 3: Whole Shot
        setup_plot('Whole Shot Accuracy vs. Rally Length', base_whole, f'Random Whole ({base_whole:.4f})')
        sns.lineplot(data=df[df['Type']=='Whole'], x='Shot_Number', y='Acc', hue='Task', style='Task',
                     markers=True, dashes={'Whole Shot':(2,2)}, palette=palette_whole, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

    # ==============================================================================
    # PART 4: PLAYER FREQUENCY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 4: PLAYER FREQUENCY ({freq_matches} Matches) \n" + "="*40)
    chosen_matches = random.sample(unique_matches, min(freq_matches, len(unique_matches)))
    pf_indices = [idx for mid in chosen_matches for idx in match_map[mid]]

    p_counts = Counter()
    p_stats = {}

    with torch.no_grad():
        for idx in pf_indices:
            sample = dataset[idx]
            y = sample['y_target']
            preds = model(*_rich_inputs_from_sample(sample, device)).argmax(-1).squeeze(0)

            s_val, r_val = sample['x_s_id'].item(), sample['x_r_id'].item()
            x_seq_cpu = sample['x_seq']

            for t in range(len(y)):
                if y[t] == 0:
                    continue
                hist_len = (x_seq_cpu[:t+1] != 0).sum().item()
                # Shot number is hist_len + 1: odd shots go to the server, even ones to the returner.
                actor = s_val if (hist_len + 1) % 2 != 0 else r_val
                if actor <= 1:
                    continue  # ids 0/1 (presumably pad/unk players) are not tracked

                p_counts[actor] += 1
                actor_stats = p_stats.setdefault(actor, {'tot': 0, 'corr': 0})
                actor_stats['tot'] += 1
                if preds[t].item() == y[t].item():
                    actor_stats['corr'] += 1

    pf_data = [{'Freq': p_counts[a], 'Err': (1 - v['corr']/v['tot'])*100} for a, v in p_stats.items() if p_counts[a] > 10]
    if pf_data:
        df_pf = pd.DataFrame(pf_data)
        plt.figure(figsize=(8, 4))
        sns.regplot(data=df_pf, x='Freq', y='Err', scatter_kws={'alpha':0.5}, line_kws={'color':'red'})
        plt.xscale('log'); plt.title(f"Error vs Frequency (Corr: {df_pf['Freq'].corr(df_pf['Err']):.2f})"); plt.show()

    # ==============================================================================
    # PART 5: ERA STABILITY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 5: ERA STABILITY ({era_matches} Matches/Era) \n" + "="*40)
    eras = {'Pre-2010': [], '2010-2019': [], '2020+': []}
    for m_id in unique_matches:
        try:
            y_year = int(str(m_id)[:4])
        except ValueError:
            # Match id does not start with a year: leave it out of the era analysis.
            continue
        if y_year < 2010: eras['Pre-2010'].append(m_id)
        elif y_year < 2020: eras['2010-2019'].append(m_id)
        else: eras['2020+'].append(m_id)

    era_indices = []
    era_labels_list = []
    for era_name, m_list in eras.items():
        if not m_list:
            continue
        chosen = random.sample(m_list, min(era_matches, len(m_list)))
        for m in chosen:
            era_indices.extend(match_map[m])
            era_labels_list.extend([era_name] * len(match_map[m]))

    era_res = []
    with torch.no_grad():
        for idx, era_name in zip(era_indices, era_labels_list):
            sample = dataset[idx]
            y = sample['y_target'].to(device)
            preds = model(*_rich_inputs_from_sample(sample, device)).argmax(-1).squeeze(0)
            mask = y != 0
            if mask.sum() > 0:
                acc = (preds[mask] == y[mask]).float().mean().item()
                era_res.append({'Era': era_name, 'Whole Shot Acc': acc})

    if era_res:
        plt.figure(figsize=(6, 4))
        sns.barplot(data=pd.DataFrame(era_res), x='Era', y='Whole Shot Acc', palette='viridis', order=['Pre-2010', '2010-2019', '2020+'])
        plt.title('Accuracy by Era'); plt.ylim(0, 1); plt.show()

    # ==============================================================================
    # PART 6: RAW ERROR ANALYSIS BY SURFACE (RichInputLSTM - Unified Output)
    # ==============================================================================
    print("\n" + "="*50)
    print(" RAW ERROR ANALYSIS BY SURFACE (RichInputLSTM) ")
    print("="*50)

    # 1. Group Test Indices by Surface (unrecognised surfaces count as Hard)
    surface_map = {'Clay': [], 'Hard': [], 'Grass': []}
    for idx in test_indices:
        # str() guards the substring test against non-string metadata (e.g. NaN).
        surf = str(dataset.match_meta.get(dataset.sample_match_ids[idx], {}).get('surface', 'Hard'))
        bucket = next((k for k in surface_map if k in surf), 'Hard')
        surface_map[bucket].append(idx)

    # 2. Select Samples Balanced by Surface
    selected_indices, surface_labels = [], []
    per_surf = speed_samples // 3

    for s, inds in surface_map.items():
        if not inds:
            continue
        chosen = random.sample(inds, min(len(inds), per_surf))
        selected_indices.extend(chosen)
        surface_labels.extend([s] * len(chosen))

    # 3. Evaluation Loop (uni_map from the top of the function splits unified ids into components)
    results = []
    with torch.no_grad():
        for idx, surf in zip(selected_indices, surface_labels):
            sample = dataset[idx]
            inputs = _rich_inputs_from_sample(sample, device)
            y_uid_gt = sample['y_target'].to(device)

            pred_uid = model(*inputs).argmax(dim=-1).squeeze(0)

            seq_len = inputs[0].shape[1]
            for t in range(seq_len):
                if y_uid_gt[t] == 0:
                    continue

                p_id = pred_uid[t].item()
                t_id = y_uid_gt[t].item()
                p_t, p_d, p_dp = uni_map.get(p_id, (0, 0, 0))
                t_t, t_d, t_dp = uni_map.get(t_id, (0, 0, 0))

                # Depth error only where the target actually has a depth; None becomes NaN,
                # which mean() skips.
                depth_err = (1.0 if p_dp != t_dp else 0.0) if t_dp != 0 else None

                results.append({
                    'Surface': surf,
                    'Type Error': 1.0 if p_t != t_t else 0.0,
                    'Direction Error': 1.0 if p_d != t_d else 0.0,
                    'Depth Error': depth_err,
                    'Whole Shot Error': 1.0 if p_id != t_id else 0.0
                })

    # 4. Statistics & Plotting
    if not results:
        print("No results generated.")
        return

    df = pd.DataFrame(results)

    # Print Table (Mean Error %) - Pandas ignores None/NaN automatically
    stats = df.groupby('Surface')[['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error']].mean() * 100
    print("\n--- Mean Error Rates (%) [Depth calculated only on non-zero targets] ---")
    print(stats.round(2))

    # Plotting
    df_melt = df.melt(id_vars=['Surface'],
                      value_vars=['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error'],
                      value_name='Error Rate')

    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_melt, x='Surface', y='Error Rate', hue='variable',
                order=['Clay', 'Hard', 'Grass'], palette='viridis')
    plt.title('Error Rates by Component (Masked Depth) vs. Whole Shot')
    plt.ylabel('Error Rate (0.0 - 1.0)')
    plt.legend(title='Metric')
    plt.grid(axis='y', alpha=0.3)
    plt.show()

# ==============================================================================
# RUN THIS BLOCK BEFORE EVALUATION
# ==============================================================================



# --- 3. RUNNER SNIPPET ---
# Rebuilds the dataset from the raw CSVs, aligns it to the saved checkpoint and
# re-creates the 80/10/10 split that train_rich_model uses (same generator seed),
# then evaluates on its test part. The model comes from CHECKPOINT_PATH, so only the
# path configuration is required, not a model trained in this session.
_rich_eval_required = ('point_files', 'matches_path', 'atp_path', 'wta_path', 'CHECKPOINT_PATH')
_rich_eval_missing = [name for name in _rich_eval_required if name not in globals()]

if not _rich_eval_missing:
    # 1. Init Dataset (This creates the "Old" IDs)
    print("Initializing Dataset...")
    datasetEnhanced = EnhancedTennisDataset(
        point_files,
        matches_path,
        atp_path,
        wta_path,
        max_seq_len=SEQ_LEN
    )

    # 2. Align Tensors & Load Model
    richLSTM = align_and_load_rich_lstm(
        checkpoint_path=CHECKPOINT_PATH,
        dataset=datasetEnhanced,
        device=DEVICE
    )

    print("Recreating validation/test split for evaluation...")
    seed_everything(42)

    total_len = len(datasetEnhanced)
    train_len = int(0.80 * total_len)
    val_len   = int(0.10 * total_len)
    test_len  = total_len - train_len - val_len

    gen = torch.Generator().manual_seed(42)
    _, _, test_ds = random_split(datasetEnhanced, [train_len, val_len, test_len], generator=gen)

    test_indices = test_ds.indices
    test_loader_eval = DataLoader(test_ds, batch_size=64, shuffle=False)

    run_full_evaluation(
        model=richLSTM,
        dataset=datasetEnhanced,
        loader=test_loader_eval,
        test_indices=test_indices,
        live_samples=5000,
        length_matches=2000,
        freq_matches=2000
    )
else:
    print(f"Error: {', '.join(_rich_eval_missing)} not defined. Run the checkpoint/path cell first.")

## UNIFIED TRANSFORMER (cristiangpt)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import re
import os
import random

# --- GLOBAL CONFIGURATION ---
# NOTE(review): these constants are re-declared verbatim in the next cell;
# whichever cell ran last wins.
SEQ_LEN = 30       # Fixed sequence length
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-3
# Stored as a string ('cuda' / 'cpu'); `.to(DEVICE)` accepts either form.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the path
# Checkpoint file the training loops write model weights to.
save_path = 'tennis_shot_forecasting.pth'

# Seed every random number generator the notebook relies on, for repeatable runs
def seed_everything(seed=42):
    """Apply one seed to Python's `random`, NumPy and PyTorch (CPU and all GPUs).

    Also exports PYTHONHASHSEED; note that this only takes effect for Python
    processes started after the call, not for the running interpreter.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    print(f"Random seed set to {seed}")

# Seed right away so everything below is reproducible
seed_everything(42)

### DOWNSAMPLED DATASET FOR COMPARING MEN AND WOMEN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
import re
import os
import random

# --- GLOBAL CONFIGURATION ---
SEQ_LEN = 30        # Fixed sequence length
BATCH_SIZE = 64     # used by the train/val/test DataLoaders below
EPOCHS = 10         # number of epochs for OneCycleLR and the training loop below
# NOTE(review): LR is not read by the training loop below, which hardcodes
# lr=1e-4 for AdamW and max_lr=4e-4 for OneCycleLR.
LR = 1e-3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Define the path
# Checkpoint file the training loop below writes its best weights to.
save_path = 'tennis_shot_forecasting.pth'

# Re-seed all RNGs (same helper as the previous cell, re-declared here)
def seed_everything(seed=42):
    """Seed `random`, NumPy and PyTorch (CPU + every CUDA device) with `seed`.

    PYTHONHASHSEED is exported as well; it only influences Python processes
    launched after this call.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for rng_seeder in [random.seed, np.random.seed,
                       torch.manual_seed, torch.cuda.manual_seed,
                       torch.cuda.manual_seed_all]:
        rng_seeder(seed)
    print(f"Random seed set to {seed}")

# Seed right away
seed_everything(42)


In [None]:
# Player metadata tables, one for each tour.
atp_path = '/kaggle/input/atp-players/atp_players.csv'
wta_path = '/kaggle/input/wta-players/wta_players.csv'

from torch.optim.lr_scheduler import OneCycleLR

# List all point files to merge
# NOTE(review): all three point files are the 'charting-m-*' (men's) exports,
# although this section compares men and women; confirm where the women's
# points come from (e.g. inside DownsampledDataset) or add them here.
base_path = '/kaggle/input/atp-points/'
point_files = [
    base_path + 'charting-m-points-2020s.csv',
    base_path + 'charting-m-points-2010s.csv',
    base_path + 'charting-m-points-to-2009.csv'
]

# New Matches File
matches_path = '/kaggle/input/atp-matches-updated/charting-m-matches-updated.csv'

# Standard Focal Loss (Reused for all heads)
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional per-class weights.

    Per position: loss = alpha[y] * (1 - p_y) ** gamma * CE, where CE is the
    unweighted cross-entropy and p_y = exp(-CE) is the predicted probability of
    the true class. Positions whose target equals `ignore_index` contribute 0.

    Args:
        alpha: optional 1-D tensor of per-class weights indexed by class id.
            It is moved (and cached) onto the logits' device on first use.
        gamma: focusing parameter; gamma=0 gives (alpha-weighted) cross-entropy.
        reduction: 'mean' averages over non-ignored positions, 'sum' sums over
            them; any other value returns the per-position losses unreduced.
        ignore_index: target id treated as padding (default 0).
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=0):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.reduction = reduction
        self.alpha = alpha
        self.ignore_index = ignore_index
        
    def forward(self, inputs, targets):
        """inputs: logits with classes on dim 1 (e.g. (N, C)); targets: class ids (e.g. (N,))."""
        mask = targets != self.ignore_index
        # No `weight=` here on purpose: p_t must come from the unweighted CE.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none', ignore_index=self.ignore_index)
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        
        if self.alpha is not None:
            if self.alpha.device != inputs.device:
                self.alpha = self.alpha.to(inputs.device)
            # Route ignored targets to class 0 so the lookup is always in range;
            # their loss is already 0, so the weight they pick up is irrelevant.
            safe_targets = targets.masked_fill(~mask, 0)
            focal_loss = focal_loss * self.alpha[safe_targets]
            
        if self.reduction == 'mean':
            if mask.any():
                return focal_loss[mask].mean()
            # Nothing to average: return a zero that stays attached to the graph.
            return focal_loss.sum() * 0.0
        if self.reduction == 'sum':
            return focal_loss[mask].sum()
        return focal_loss

# Helper function to calculate weights for a specific target list
def get_balanced_weights(target_tensor, num_classes, power=0.3, max_weight=15.0, device=None):
    """Per-class loss weights from 'balanced' inverse class frequency, softened and capped.

    For every class id c found in `target_tensor` (padding id 0 excluded):
        w_c = min((n_valid / (n_distinct_classes * count_c)) ** power, max_weight)
    The inner term is exactly sklearn's compute_class_weight('balanced') formula.

    Args:
        target_tensor: integer tensor of class ids, any shape, CPU or GPU; 0 = padding.
        num_classes: ids 1..num_classes receive a weight.
        power: exponent applied to the balanced weights (<1 flattens them).
        max_weight: cap applied to every weight.
        device: device of the returned tensor; defaults to the global DEVICE.

    Returns:
        float32 tensor of length num_classes + 1. Index 0 (padding) is 0.0;
        ids that never occur get 1.0. If the input is all padding, every
        non-padding id gets 1.0.
    """
    # detach/cpu/reshape: works for GPU and non-contiguous tensors (view() does not).
    flat = target_tensor.detach().reshape(-1).cpu().numpy()
    valid = flat[flat != 0]

    weights = np.ones(num_classes + 1, dtype=np.float64)
    weights[0] = 0.0  # padding
    if valid.size > 0:
        classes, counts = np.unique(valid, return_counts=True)
        smoothed = (valid.size / (len(classes) * counts.astype(np.float64))) ** power
        # Only ids inside 1..num_classes get a slot in the output.
        in_range = (classes >= 1) & (classes <= num_classes)
        weights[classes[in_range]] = smoothed[in_range]

    weights = np.minimum(weights, max_weight)
    if device is None:
        device = DEVICE
    return torch.tensor(weights, dtype=torch.float32).to(device)

# Train the unified single-head model on the downsampled dataset.
# Everything runs only if the matches file exists; the best model (lowest
# validation loss) is written to `save_path` inside the loop.
if os.path.exists(matches_path):
    # 1. Initialize Unified Dataset
    # NOTE(review): DownsampledDataset is defined elsewhere; confirm it is the
    # current version of the dataset class before running this cell.
    dataset = DownsampledDataset(
        points_paths_list=point_files, 
        matches_path=matches_path, 
        atp_path=atp_path, 
        wta_path=wta_path, 
        max_seq_len=SEQ_LEN
    )
    
    # --- 1. CALCULATE WEIGHTS FOR THE UNIFIED TOKENS ---
    print("\nCalculating Balanced Weights for Unified Vocab...")
    
    # There is a single target tensor, y_target_tensor, holding one unified token
    # id per position (0 = padding, which get_balanced_weights excludes).
    w_unified = get_balanced_weights(dataset.y_target_tensor, len(dataset.unified_vocab))
    print(f"Unified Weights Shape: {w_unified.shape}")
        
    # --- SPLIT BY MATCH ---
    # Split on match ids (not samples) so no match straddles train/val/test.
    # 20% of matches are held out, then 25% of those go to test:
    # 80% train / 15% val / 5% test at the match level (sample counts differ).
    print("\nSplitting Data (80/15/5)...")
    all_matches = sorted(list(set(dataset.sample_match_ids)))
    
    train_matches, temp_matches = train_test_split(all_matches, test_size=0.20, random_state=42)
    val_matches, test_matches = train_test_split(temp_matches, test_size=0.25, random_state=42)
    
    # Sets for O(1) membership in the index comprehensions below.
    train_set = set(train_matches)
    val_set = set(val_matches)
    test_set = set(test_matches)
    
    train_indices = [i for i, m in enumerate(dataset.sample_match_ids) if m in train_set]
    val_indices = [i for i, m in enumerate(dataset.sample_match_ids) if m in val_set]
    test_indices = [i for i, m in enumerate(dataset.sample_match_ids) if m in test_set]
    
    print(f"Train: {len(train_indices)} | Val: {len(val_indices)} | Test: {len(test_indices)}")

    # DataLoaders (2 worker processes each). Only the train loader shuffles; the
    # seeded generator fixes its shuffle order.
    g = torch.Generator()
    g.manual_seed(42)
    
    train_loader = DataLoader(Subset(dataset, train_indices), batch_size=BATCH_SIZE, shuffle=True, num_workers=2, generator=g)
    val_loader = DataLoader(Subset(dataset, val_indices), batch_size=BATCH_SIZE, shuffle=False, num_workers=2, generator=g)
    test_loader = DataLoader(Subset(dataset, test_indices), batch_size=BATCH_SIZE, shuffle=False, num_workers=2, generator=g)

    # Re-seed so model initialisation is reproducible.
    seed_everything(42)
    
    # 5. Initialize UNIFIED Model
    model = UnifiedCristianGPT(
        unified_vocab_size=len(dataset.unified_vocab),
        num_players=len(dataset.player_vocab),
        context_dim=10, # NOTE(review): must equal the width of batch['context'] (includes the Height/Pressure features) — confirm
        seq_len=SEQ_LEN,
        embed_dim=128
    ).to(DEVICE)
    
    # The lr given here is only a placeholder: OneCycleLR overrides it, starting
    # at max_lr / div_factor (4e-4 / 25 = 1.6e-5) and peaking at max_lr.
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
    
    # steps_per_epoch = len(train_loader) because scheduler.step() is called
    # once per batch in the loop below.
    scheduler = OneCycleLR(optimizer, max_lr=4e-4, steps_per_epoch=len(train_loader), 
                           epochs=EPOCHS, pct_start=0.3, div_factor=25, final_div_factor=1000)

    # 6. Initialize SINGLE Loss Function
    # Only one target is predicted: the unified token id (padding 0 is ignored).
    criterion = FocalLoss(alpha=w_unified, gamma=2.0)

    # Early stopping: stop after `patience` consecutive epochs without a new
    # best validation loss.
    best_val_loss = float('inf')
    patience = 5
    trigger_times = 0
        
    # 7. Training Loop (Unified)
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        
        for batch in train_loader:
            # Inputs (Updated for Unified Model signature)
            x_seq = batch['x_seq'].to(DEVICE) # The sequence of unified tokens
            x_c = batch['context'].to(DEVICE)
            x_s = batch['x_s_id'].to(DEVICE)
            x_r = batch['x_r_id'].to(DEVICE)
            
            # Single Target
            y = batch['y_target'].to(DEVICE)
            
            optimizer.zero_grad()
            
            # Forward
            logits = model(x_seq, x_c, x_s, x_r)
            
            # Flatten batch and sequence: logits -> (B*T, vocab), targets -> (B*T,)
            loss = criterion(logits.view(-1, len(dataset.unified_vocab)), y.view(-1))
            
            loss.backward()
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()
            
        # Validation
        model.eval()
        val_loss = 0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for batch in val_loader:
                x_seq = batch['x_seq'].to(DEVICE)
                x_c = batch['context'].to(DEVICE)
                x_s = batch['x_s_id'].to(DEVICE)
                x_r = batch['x_r_id'].to(DEVICE)
                y = batch['y_target'].to(DEVICE)
                
                logits = model(x_seq, x_c, x_s, x_r)
                
                # Val Loss
                l = criterion(logits.view(-1, len(dataset.unified_vocab)), y.view(-1))
                val_loss += l.item()
                
                # Accuracy (Unified): exact match of predicted vs target token id,
                # counted over non-padding positions only.
                mask = y != 0
                preds = logits.argmax(dim=-1)
                correct += (preds[mask] == y[mask]).sum().item()
                total += mask.sum().item()

        # Mean of per-batch losses.
        avg_train = train_loss / len(train_loader)
        avg_val = val_loss / len(val_loader)
        
        acc = (correct / total * 100) if total > 0 else 0
        
        print(f"Epoch {epoch+1} | Train Loss: {avg_train:.4f} | Val Loss: {avg_val:.4f} | Unified Acc: {acc:.2f}%")
        
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            trigger_times = 0
            # Best-so-far checkpoint (state_dict only).
            torch.save(model.state_dict(), save_path)
            print(f"   --> New Best Model Saved! (Loss: {best_val_loss:.4f})")
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print("   --> Early Stopping.")
                break
else:
    print("Dataset paths not found.")

# The training loop above already wrote the best-validation weights to `save_path`.
# Saving the final-epoch weights to that same file would silently replace the best
# checkpoint (e.g. after early stopping), so they go to a separate "_last" file.
# The guard mirrors the training cell: `model` only exists if training ran.
if os.path.exists(matches_path):
    last_save_path = os.path.splitext(save_path)[0] + '_last.pth'
    torch.save(model.state_dict(), last_save_path)
    print(f"Final-epoch weights saved to {last_save_path} (best checkpoint kept at {save_path})")
else:
    print("Training did not run (dataset paths not found); nothing to save.")

### Evaluation

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from collections import Counter
import seaborn as sns
import pandas as pd
import numpy as np
import random
import torch
import os

# --- 1. DEFINE STANDARD VOCABS (Quick Fix for Decoding) ---
# Evaluation-side vocabularies: map the characters found in unified-token keys
# to small integer ids used by the decoders and reports below (0 = padding /
# not applicable).
# NOTE(review): '<pad>' and '0' both map to 0, so inverting these dicts with
# {v: k} keeps only the later key ('0') for id 0.
# NOTE(review): the decoders reuse the id of 's' as the serve type; if 's' is
# also a regular shot code in the token keys, serves and those shots become
# indistinguishable after decoding — confirm.
EVAL_SHOT_VOCAB = {'<pad>': 0, 'f': 1, 'b': 2, 'r': 3, 'v': 4, 'o': 5, 's': 6, 'u': 7, 'l': 8, 'm': 9, 'z': 10}
EVAL_DIR_VOCAB  = {'<pad>': 0, '0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6}
# Depth characters '7', '8', '9' become ids 1, 2, 3 (not their face values).
EVAL_DEPTH_VOCAB = {'<pad>': 0, '0': 0, '7': 1, '8': 2, '9': 3}

# --- 2. DECODER HELPER ---
def decode_unified_predictions(preds, dataset):
    """Convert unified token ids into parallel (type, direction, depth) id lists.

    Args:
        preds: iterable of unified token ids (ints or NumPy integer scalars).
        dataset: provides `inv_unified_vocab` (id -> token key). Keys are split
            on '_' as '<type>_<dir>_<depth>', or 'Serve_<dir>' for serves.

    Returns:
        (types, dirs, depths): three lists holding one EVAL_* id per input id.
        Ids <= 1 and ids missing from the vocab decode to (0, 0, 0); key parts
        that are missing or unknown decode to 0 instead of raising.
    """
    serve_type_id = EVAL_SHOT_VOCAB.get('s', 0)

    # Build the id -> (type, dir, depth) table once, then decode by lookup.
    id_map = {}
    for uid, key in dataset.inv_unified_vocab.items():
        if uid <= 1:
            # ids 0 and 1 are treated as special (non-shot) tokens
            id_map[uid] = (0, 0, 0)
            continue
        # Pad the split so short keys (e.g. 'Serve' or special tokens) cannot
        # raise IndexError; extra parts beyond the third are ignored.
        head, direction, depth = (key.split('_') + ['', ''])[:3]
        if head == 'Serve':
            id_map[uid] = (serve_type_id, EVAL_DIR_VOCAB.get(direction, 0), 0)
        else:
            id_map[uid] = (EVAL_SHOT_VOCAB.get(head, 0),
                           EVAL_DIR_VOCAB.get(direction, 0),
                           EVAL_DEPTH_VOCAB.get(depth, 0))

    types, dirs, depths = [], [], []
    for p in preds:
        t, d, dep = id_map.get(p, (0, 0, 0))
        types.append(t)
        dirs.append(d)
        depths.append(dep)
    return types, dirs, depths

def get_fast_decoder_map(dataset):
    """Return a dict mapping each unified token id to its (type, dir, depth) EVAL ids.

    Keys of `dataset.inv_unified_vocab` are split on '_' as
    '<type>_<dir>_<depth>', or 'Serve_<dir>' for serves (type = id of 's',
    depth = 0). Ids <= 1 map to (0, 0, 0). Missing or unknown key parts decode
    to 0 instead of raising.
    """
    serve_id = EVAL_SHOT_VOCAB.get('s', 0)
    uni_map = {}
    for uid, key in dataset.inv_unified_vocab.items():
        if uid <= 1:
            uni_map[uid] = (0, 0, 0)
            continue
        # Padding the split guards against keys with fewer than three parts.
        head, direction, depth = (key.split('_') + ['', ''])[:3]
        if head == 'Serve':
            uni_map[uid] = (serve_id, EVAL_DIR_VOCAB.get(direction, 0), 0)
        else:
            uni_map[uid] = (EVAL_SHOT_VOCAB.get(head, 0),
                            EVAL_DIR_VOCAB.get(direction, 0),
                            EVAL_DEPTH_VOCAB.get(depth, 0))
    return uni_map

def load_weights_into_model(model, path):
    """Load saved weights from `path` into `model` and move it to the best device.

    Accepts either a raw state dict or a checkpoint dict that stores it under
    'model_state_dict'. Loading is best-effort: a missing file or a load error
    is reported and the model keeps its current weights.

    Args:
        model: torch.nn.Module to populate.
        path: file written by torch.save.

    Returns:
        `model` (the same object), moved to CUDA when available, else CPU —
        on every path, so callers can always feed it device tensors.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if not os.path.exists(path):
        print(f"⚠️  Weights file '{path}' not found. Using random init.")
        return model.to(device)
    print(f"🔍 Loading weights from '{path}'...")
    try:
        checkpoint = torch.load(path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        print("✅ Weights loaded successfully!")
    except Exception as e:
        # Deliberately non-fatal: report and continue with the current weights.
        print(f"❌ Error loading weights: {e}")
    return model.to(device)

# ==============================================================================
# 1. MAIN TACTICAL ANALYSIS
# ==============================================================================
def evaluation_analyze_tactics_multitask(model, loader, dataset):
    """Print per-component classification reports for the unified model on `loader`.

    Every non-padding target position (target id != 0) is collected together
    with the model's argmax prediction. Both are decoded into (type, direction,
    depth) ids, and an sklearn classification report is printed per component.

    Args:
        model: module called as model(x_seq, context, x_s_id, x_r_id), returning
            logits whose last dimension is the unified vocab.
        loader: DataLoader yielding dicts with 'x_seq', 'context', 'x_s_id',
            'x_r_id' and 'y_target'.
        dataset: provides `inv_unified_vocab` for decoding.
    """
    model.eval()
    all_preds_unified, all_targets_unified = [], []
    
    print("\nRunning Evaluation on TEST SET (Unified Model)...")
    with torch.no_grad():
        for batch in loader:
            x_seq = batch['x_seq'].to(DEVICE)
            x_c = batch['context'].to(DEVICE)
            x_s = batch['x_s_id'].to(DEVICE)
            x_r = batch['x_r_id'].to(DEVICE)
            y = batch['y_target'].to(DEVICE)
            
            logits = model(x_seq, x_c, x_s, x_r)
            
            mask = y.view(-1) != 0
            # .tolist() gives plain ints (much cheaper to extend than NumPy scalars).
            all_preds_unified.extend(logits.argmax(-1).view(-1)[mask].cpu().tolist())
            all_targets_unified.extend(y.view(-1)[mask].cpu().tolist())

    print("Decoding predictions...")
    pred_t, pred_d, pred_dp = decode_unified_predictions(all_preds_unified, dataset)
    targ_t, targ_d, targ_dp = decode_unified_predictions(all_targets_unified, dataset)

    def _report(title, targets, preds, vocab, excluded):
        # Report only classes that occur in the targets. The set of present
        # ids is computed once rather than re-running np.unique per vocab entry.
        print(title)
        present = set(np.unique(targets).tolist())
        labels = [k for k, v in vocab.items() if v in present and k not in excluded]
        indices = [vocab[k] for k in labels]
        print(classification_report(targets, preds, labels=indices, target_names=labels, zero_division=0))

    # 1. Direction Report
    _report("\n=== DIRECTION REPORT (1=Right, 2=Center, 3=Left) ===",
            targ_d, pred_d, EVAL_DIR_VOCAB, ('<pad>', '0'))
    # 2. Depth Report
    _report("\n=== DEPTH REPORT (7=Shallow, 8=Mid, 9=Deep) ===",
            targ_dp, pred_dp, EVAL_DEPTH_VOCAB, ('<pad>', '0'))
    # 3. Shot Type Report
    _report("\n=== SHOT TYPE REPORT ===",
            targ_t, pred_t, EVAL_SHOT_VOCAB, ('<pad>',))

# ==============================================================================
# 2. LIVE TEST CASES
# ==============================================================================
def evaluation_live_test_cases(model, dataset, test_indices, num_samples=10, print_flag=True):
    """Show single-position predictions on random test samples, bucketed by score.

    For up to `num_samples` randomly chosen samples that have at least two
    non-padding input tokens, a random non-padding position t is picked. The
    rally history up to t, the model's argmax prediction at t (with softmax
    confidence) and the target at t are formatted, then grouped by how many of
    (type, direction, depth) were predicted correctly (3..0).

    Args:
        model: module called as model(x_seq, context, x_s_id, x_r_id) -> logits.
        dataset: indexable of dicts with 'x_seq', 'context', 'x_s_id', 'x_r_id',
            'y_target'; also provides `sample_match_ids` and `inv_unified_vocab`.
        test_indices: candidate dataset indices.
        num_samples: number of valid samples to evaluate.
        print_flag: if False, only the per-score bucket headers are printed.
    """
    model.eval()
    inv_dir = {v: k for k, v in EVAL_DIR_VOCAB.items()}
    inv_typ = {v: k for k, v in EVAL_SHOT_VOCAB.items()}
    # Depth is handled through its id (1, 2, 3 for keys '7', '8', '9'), not the key.
    depth_names = {0: "N/A", 1: "Short", 2: "Deep", 3: "V.Deep"}

    def d_lbl(x):
        """Readable name for a depth id; '?' for anything unexpected."""
        return depth_names.get(x, "?")

    def mark(ok):
        """Check / cross mark for one component comparison."""
        return "✅" if ok else "❌"

    print(f"\n--- LIVE TACTICAL EVALUATION ({num_samples} Random Test Cases) ---")
    # Shuffle the whole pool so enough valid samples remain even if many are skipped.
    candidates = list(test_indices)
    random.shuffle(candidates)
    
    results_buffer = {3: [], 2: [], 1: [], 0: []}
    uni_map = get_fast_decoder_map(dataset)
    count_processed = 0

    with torch.no_grad():
        for idx in candidates:
            # Stop once the requested number of VALID samples has been processed.
            if count_processed >= num_samples:
                break
                
            sample = dataset[idx]
            non_zeros = (sample['x_seq'] != 0).nonzero(as_tuple=True)[0]
            # Skip samples with fewer than two real tokens.
            if len(non_zeros) < 2:
                continue
            
            valid_indices = non_zeros.tolist()
            t = random.choice(valid_indices)
            
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            
            logits = model(x_seq, x_c, x_s, x_r)
            
            # Rally history from the first real token up to and including t.
            start_idx = valid_indices[0]
            history_str = ""
            for i in range(start_idx, t + 1):
                uid = sample['x_seq'][i].item()
                typ, d, dep = uni_map.get(uid, (0, 0, 0))
                z_in = inv_dir.get(d, '?')
                t_in = inv_typ.get(typ, '?')
                if i == start_idx:
                    history_str += f"[Serve {z_in}] " if t_in == 's' else f"[{t_in}{z_in}] "
                else:
                    history_str += f"-> {t_in}{z_in} "

            probs = torch.softmax(logits[0, t], dim=0)
            pred_uid = probs.argmax().item()
            conf = probs.max().item() * 100
            pred_t, pred_d, pred_dp = uni_map.get(pred_uid, (0, 0, 0))
            
            true_uid = sample['y_target'][t].item()
            true_t, true_d, true_dp = uni_map.get(true_uid, (0, 0, 0))
            
            s_pred_d = inv_dir.get(pred_d, '?'); s_pred_t = inv_typ.get(pred_t, '?')
            s_true_d = inv_dir.get(true_d, '?'); s_true_t = inv_typ.get(true_t, '?')
            
            ok_d, ok_t, ok_dp = pred_d == true_d, pred_t == true_t, pred_dp == true_dp
            check_d, check_t, check_dp = mark(ok_d), mark(ok_t), mark(ok_dp)
            score = int(ok_d) + int(ok_t) + int(ok_dp)
            
            out = []
            out.append(f"\nMatch {dataset.sample_match_ids[idx]}:")
            out.append(f"  History:    {history_str}")
            out.append(f"  Prediction: {s_pred_t} to {s_pred_d} ({d_lbl(pred_dp)}) | Conf: {conf:.0f}%")
            out.append(f"  ACTUAL:     {s_true_t} to {s_true_d} ({d_lbl(true_dp)}) | {check_t} Type {check_d} Dir {check_dp} Dep")
            results_buffer[score].append("\n".join(out))
            
            count_processed += 1
            
    for s in [3, 2, 1, 0]:
        items = results_buffer[s]
        if not items:
            continue
        print(f"\n{'='*20} {s}/3 CORRECT ({len(items)} cases) {'='*20}")
        if print_flag:
            for item in items:
                print(item)

# ==============================================================================
# 3. LENGTH VS ERROR
# ==============================================================================
def evaluation_length_vs_errrors(model, dataset, test_indices, num_matches=10):
    """Plot per-component prediction accuracy against the shot number in the rally.

    Steps:
      1. Chance baselines from the target distribution of at most the first 5000
         test samples: per component, sum(p_c ** 2), i.e. the hit rate of a
         guesser that samples classes with their empirical frequencies. Pair and
         whole-shot baselines multiply these, assuming independence.
      2. Pick `num_matches` random test matches, run the model on every sample
         from them, and record hit/miss per component, per pair of components
         and for the whole unified token, keyed by shot number.
      3. Draw three line plots (single, pairwise, whole) for shots 2..12.

    Args:
        model: module called as model(x_seq, context, x_s_id, x_r_id) -> logits.
        dataset: indexable of sample dicts; also provides `sample_match_ids`
            and `inv_unified_vocab`.
        test_indices: dataset indices belonging to the test split.
        num_matches: number of test matches to analyse.
    """
    model.eval()
    
    # 1. Calculate Baselines (Weighted Random Probability)
    print("Calculating dataset baselines...")
    from collections import Counter
    all_d, all_dp, all_tp = [], [], []
    
    check_indices = test_indices[:5000] if len(test_indices) > 5000 else test_indices
    uni_map = get_fast_decoder_map(dataset)
    
    for i in check_indices:
        for uid in dataset[i]['y_target'].tolist():
            if uid <= 1:
                continue
            # .get: an id missing from the vocab no longer raises KeyError.
            t, d, dep = uni_map.get(uid, (0, 0, 0))
            all_tp.append(t)
            all_d.append(d)
            # Depth baseline only over tokens that carry a depth (serves decode to 0).
            if dep != 0:
                all_dp.append(dep)

    def calc_baseline(data_list):
        # Probability that two independent draws from the empirical distribution agree.
        if not data_list:
            return 0.33
        counts = Counter(data_list)
        total = sum(counts.values())
        return sum([(c / total) ** 2 for c in counts.values()])

    base_d = calc_baseline(all_d)
    base_dp = calc_baseline(all_dp)
    base_tp = calc_baseline(all_tp)
    
    base_pair_avg = (base_d * base_dp + base_d * base_tp + base_tp * base_dp) / 3
    base_whole = base_d * base_dp * base_tp
    
    print(f"Baselines -> Dir: {base_d:.2f}, Depth: {base_dp:.2f}, Type: {base_tp:.2f}, Whole: {base_whole:.4f}")
    
    # 2. Analysis Loop
    test_match_ids = [dataset.sample_match_ids[i] for i in test_indices]
    unique_matches = sorted(set(test_match_ids))
    selected_matches = random.sample(unique_matches, min(num_matches, len(unique_matches)))
    # Set membership: testing against the list was O(len(test_indices) * num_matches).
    selected_set = set(selected_matches)
    selected_indices = [i for i in test_indices if dataset.sample_match_ids[i] in selected_set]
    print(f"Analyzing {len(selected_indices)} points from {len(selected_matches)} matches...")

    results = []
    with torch.no_grad():
        for idx in selected_indices:
            sample = dataset[idx]
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y = sample['y_target'].to(DEVICE)

            logits = model(x_seq, x_c, x_s, x_r)
            preds = logits.argmax(dim=-1).squeeze(0)
            
            seq_len = x_seq.shape[1]
            for t in range(seq_len):
                if x_seq[0, t] == 0:
                    continue
                # Rally index of the predicted shot = real shots seen so far + 1.
                history_so_far = x_seq[0, :t + 1]
                true_shot_count = (history_so_far != 0).sum().item()
                shot_num = true_shot_count + 1

                p_uid = preds[t].item(); t_uid = y[t].item()
                if t_uid <= 1:
                    continue
                
                p_t, p_d, p_dp = uni_map.get(p_uid, (0, 0, 0))
                t_t, t_d, t_dp = uni_map.get(t_uid, (0, 0, 0))

                # Tasks that are always defined.
                results.append({'Shot_Number': shot_num, 'Task': 'Direction', 'Type': 'Single', 'Accuracy': 1 if p_d == t_d else 0})
                results.append({'Shot_Number': shot_num, 'Task': 'Type', 'Type': 'Single', 'Accuracy': 1 if p_t == t_t else 0})
                results.append({'Shot_Number': shot_num, 'Task': 'Dir + Type', 'Type': 'Pair', 'Accuracy': 1 if (p_d == t_d and p_t == t_t) else 0})
                
                # Depth-dependent tasks: only when the target actually has a depth.
                if t_dp != 0:
                    results.append({'Shot_Number': shot_num, 'Task': 'Depth', 'Type': 'Single', 'Accuracy': 1 if p_dp == t_dp else 0})
                    results.append({'Shot_Number': shot_num, 'Task': 'Dir + Depth', 'Type': 'Pair', 'Accuracy': 1 if (p_d == t_d and p_dp == t_dp) else 0})
                    results.append({'Shot_Number': shot_num, 'Task': 'Type + Depth', 'Type': 'Pair', 'Accuracy': 1 if (p_t == t_t and p_dp == t_dp) else 0})
                
                # Whole shot counts even when depth is 0: "no depth" is part of the token.
                results.append({'Shot_Number': shot_num, 'Task': 'Whole Shot', 'Type': 'Whole', 'Accuracy': 1 if p_uid == t_uid else 0})

    if not results:
        return
    df = pd.DataFrame(results)
    df = df[(df['Shot_Number'] <= 12) & (df['Shot_Number'] >= 2)]
    
    # --- PLOTTING ---
    palette_single = {'Direction': '#1f77b4', 'Depth': '#d62728', 'Type': '#2ca02c'}
    palette_pair   = {'Dir + Depth': '#9467bd', 'Dir + Type': '#17becf', 'Type + Depth': '#ff7f0e'}
    palette_whole  = {'Whole Shot': '#000000'}

    def setup_plot(title, baseline, base_label):
        # New figure with shared axes styling and an optional chance-level line.
        plt.figure(figsize=(14, 6))
        plt.title(title, fontsize=14)
        plt.ylabel('Accuracy', fontsize=12)
        plt.xlabel('Shot Number', fontsize=12)
        plt.xticks(np.arange(2, 13, 1))
        
        ax = plt.gca()
        ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
        ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.02))
        plt.grid(True, which='major', axis='y', linestyle='-', linewidth=0.75, color='grey', alpha=0.6)
        plt.grid(True, which='minor', axis='y', linestyle='--', linewidth=0.5, color='grey', alpha=0.3)
        plt.ylim(0.0, 1.0)
        
        if baseline:
            plt.axhline(baseline, color='#FF1493', linestyle=':', alpha=0.6, label=base_label)

    # GRAPH 1: SINGLE
    setup_plot('Single Task Accuracy vs. Rally Length', base_tp, f'Random Type ({base_tp:.2f})')
    sns.lineplot(data=df[df['Type'] == 'Single'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                 markers=True, dashes=False, palette=palette_single, linewidth=2.5, errorbar=('ci', 68))
    plt.legend(loc='lower right'); plt.show()
    
    # GRAPH 2: PAIRWISE
    setup_plot('Pairwise Accuracy vs. Rally Length', base_pair_avg, f'Random Pair (~{base_pair_avg:.2f})')
    sns.lineplot(data=df[df['Type'] == 'Pair'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                 markers=True, dashes=False, palette=palette_pair, linewidth=2.5, errorbar=('ci', 68))
    plt.legend(loc='upper right'); plt.show()

    # GRAPH 3: WHOLE SHOT
    setup_plot('Whole Shot Accuracy vs. Rally Length', base_whole, f'Random Whole ({base_whole:.4f})')
    sns.lineplot(data=df[df['Type'] == 'Whole'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                 markers=True, dashes={'Whole Shot': (2, 2)}, palette=palette_whole, linewidth=2.5, errorbar=('ci', 68))
    plt.legend(loc='upper right'); plt.show()

# ==============================================================================
# 4. PLAYER FREQUENCY ANALYSIS
# ==============================================================================
def evaluation_player_frequency_vs_error(model, dataset, test_indices, num_matches=None):
    """
    Analyzes if the model performs better on players it sees more often.
    Plots Error Rate vs. Player Appearance Frequency.

    Each non-padding target position is attributed to a hitter: the server
    (x_s_id) when the predicted shot's rally index is odd, otherwise the
    receiver (x_r_id). Players with fewer than 15 attributed shots are dropped.

    Args:
        model: module called as model(x_seq, context, x_s_id, x_r_id) -> logits.
        dataset: indexable of sample dicts; also provides `sample_match_ids`,
            `player_vocab` and `inv_unified_vocab`.
        test_indices: dataset indices belonging to the test split.
        num_matches: if truthy, restrict the analysis to this many random test
            matches; otherwise use every test index.
    """
    model.eval()
    
    if num_matches:
        match_map = {}
        for idx in test_indices:
            mid = dataset.sample_match_ids[idx]
            match_map.setdefault(mid, []).append(idx)
        selected_mids = random.sample(list(match_map.keys()), min(num_matches, len(match_map)))
        selected_indices = [idx for mid in selected_mids for idx in match_map[mid]]
    else:
        selected_indices = test_indices

    print(f"\n--- PLAYER FREQUENCY ANALYSIS ({len(selected_indices)} samples) ---")
    
    player_shot_counts = Counter()
    player_stats = {} 
    inv_player_vocab = {v: k for k, v in dataset.player_vocab.items()}
    uni_map = get_fast_decoder_map(dataset)

    print("Gathering player stats and predictions (this might take a moment)...")
    
    with torch.no_grad():
        for idx in selected_indices:
            sample = dataset[idx]
            
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y = sample['y_target'].to(DEVICE)
            
            logits = model(x_seq, x_c, x_s, x_r)
            preds = logits.argmax(dim=-1).squeeze(0)
            
            seq_len = x_seq.shape[1]
            s_id = sample['x_s_id'].item(); r_id = sample['x_r_id'].item()
            
            for t in range(seq_len):
                if y[t] == 0:
                    continue
                
                # Hitter of the predicted shot: odd rally index -> server.
                history_len = (x_seq[0, :t + 1] != 0).sum().item()
                actor_id = s_id if (history_len + 1) % 2 != 0 else r_id
                if actor_id <= 1:
                    continue

                player_shot_counts[actor_id] += 1
                if actor_id not in player_stats: 
                    player_stats[actor_id] = {'total': 0, 'correct_whole': 0, 'correct_type': 0}
                player_stats[actor_id]['total'] += 1
                
                p_uid = preds[t].item(); t_uid = y[t].item()
                p_t = uni_map.get(p_uid, (0, 0, 0))[0]
                t_t = uni_map.get(t_uid, (0, 0, 0))[0]
                
                if p_uid == t_uid: player_stats[actor_id]['correct_whole'] += 1
                if p_t == t_t: player_stats[actor_id]['correct_type'] += 1

    # Prepare Data
    plot_data = []
    for pid, counts in player_stats.items():
        if player_shot_counts[pid] < 15:
            continue
        plot_data.append({
            'Player': inv_player_vocab.get(pid, f"ID_{pid}").split(' ')[-1],
            'Frequency': player_shot_counts[pid],
            'Error_Rate_Whole': (1 - counts['correct_whole'] / counts['total']) * 100,
            'Error_Rate_Type': (1 - counts['correct_type'] / counts['total']) * 100
        })
        
    if not plot_data: 
        print("Not enough data points per player to plot.")
        return
        
    df = pd.DataFrame(plot_data)
    
    # Plotting
    fig, axes = plt.subplots(1, 2, figsize=(18, 7))
    
    # Plot A: Whole Shot
    sns.scatterplot(data=df, x='Frequency', y='Error_Rate_Whole', ax=axes[0], color='#d62728', alpha=0.6)
    sns.regplot(data=df, x='Frequency', y='Error_Rate_Whole', ax=axes[0], scatter=False, color='black', line_kws={'linestyle':'--'})
    axes[0].set_xscale('log')
    axes[0].set_title('Does Fame Help? (Whole Shot Error vs. Frequency)', fontsize=14)
    axes[0].grid(True, which="both", alpha=0.2)
    
    # Plot B: Type
    sns.scatterplot(data=df, x='Frequency', y='Error_Rate_Type', ax=axes[1], color='#1f77b4', alpha=0.6)
    sns.regplot(data=df, x='Frequency', y='Error_Rate_Type', ax=axes[1], scatter=False, color='black', line_kws={'linestyle':'--'})
    axes[1].set_xscale('log')
    axes[1].set_title('Shot Type Prediction Error vs. Frequency', fontsize=14)
    axes[1].grid(True, which="both", alpha=0.2)
    
    plt.tight_layout()
    plt.show()
    
    # Pearson correlation between raw frequency and whole-shot error rate.
    corr = df['Frequency'].corr(df['Error_Rate_Whole'])
    print(f"Analyzed {len(df)} unique players.")
    print(f"Correlation between Frequency and Error Rate: {corr:.4f}")
    
    # Every outcome gets a message (the -0.3..0 band and NaN used to print nothing).
    if pd.isna(corr):
        print(">> Observation: Correlation is undefined (too few players or constant values).")
    elif corr < -0.3:
        print(">> Observation: Strong negative correlation. The model is significantly better at predicting famous players.")
    elif corr > 0:
        print(">> Observation: No advantage for famous players. The model generalizes well!")
    else:
        print(">> Observation: Weak negative correlation. Frequent players are only slightly easier to predict.")

# ==============================================================================
# 5. ERA COMPARISON
# ==============================================================================
def _era_for_match(match_id):
    """Return the era bucket for a match id whose first four characters are a year.

    Args:
        match_id: match identifier; only ``str(match_id)[:4]`` is inspected.

    Returns:
        'Pre-2010', '2010-2019' or '2020-Present', or None when that prefix does
        not parse as an integer (the caller skips such matches).
    """
    try:
        year = int(str(match_id)[:4])
    except ValueError:
        # Only a non-numeric prefix is expected; anything else should surface.
        return None
    if year < 2010:
        return 'Pre-2010'
    if year < 2020:
        return '2010-2019'
    return '2020-Present'


def evaluation_compare_eras(model, dataset, test_indices, matches_per_era=50):
    """Compare single-head prediction error rates across three eras of matches.

    Test samples are grouped by match. Matches are bucketed by the year prefix
    of their id, and up to ``matches_per_era`` matches per era are drawn with
    ``random.sample``. Every non-padding target position of the drawn samples
    is decoded through ``get_fast_decoder_map``. The first two decoded fields
    are scored as type and direction, and the raw id as the whole shot.

    Args:
        model: single-head model called as ``model(x_seq, context, x_s_id, x_r_id)``.
            The argmax of its output over the last dim is a unified token id.
        dataset: indexable dataset exposing ``sample_match_ids``. Its samples
            have 'x_seq', 'context', 'x_s_id', 'x_r_id' and 'y_target'.
        test_indices: dataset indices to draw from.
        matches_per_era: maximum number of matches sampled per era.

    Returns:
        None. Prints the per-era mean error table (%) and shows a bar plot.
    """
    model.eval()
    print("\n--- ERA STABILITY ANALYSIS ---")

    # Group test sample indices by the match they belong to.
    match_to_indices = {}
    for idx in test_indices:
        match_to_indices.setdefault(dataset.sample_match_ids[idx], []).append(idx)

    eras = {'Pre-2010': [], '2010-2019': [], '2020-Present': []}
    for m_id in match_to_indices:
        era = _era_for_match(m_id)
        if era is not None:
            eras[era].append(m_id)

    selected_indices = []
    era_labels = []
    uni_map = get_fast_decoder_map(dataset)

    for era, m_list in eras.items():
        chosen = random.sample(m_list, min(matches_per_era, len(m_list)))
        for m in chosen:
            selected_indices.extend(match_to_indices[m])
            era_labels.extend([era] * len(match_to_indices[m]))

    results = []
    with torch.no_grad():
        for idx, era in zip(selected_indices, era_labels):
            sample = dataset[idx]
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y_true = sample['y_target'].to(DEVICE)

            logits = model(x_seq, x_c, x_s, x_r)
            preds = logits.argmax(dim=-1).squeeze(0)

            for t in range(x_seq.shape[1]):
                if y_true[t] == 0:  # padding position
                    continue
                p_uid = preds[t].item()
                t_uid = y_true[t].item()
                # Ids missing from the decoder map decode to (0, 0, 0).
                p_t, p_d, _ = uni_map.get(p_uid, (0, 0, 0))
                t_t, t_d, _ = uni_map.get(t_uid, (0, 0, 0))

                results.append({
                    'Era': era,
                    'Type Error': 1.0 if p_t != t_t else 0.0,
                    'Direction Error': 1.0 if p_d != t_d else 0.0,
                    'Whole Shot Error': 1.0 if p_uid != t_uid else 0.0
                })

    if not results:
        print("No results generated.")
        return

    df = pd.DataFrame(results)
    stats = df.groupby('Era').mean() * 100
    print(stats)

    df_melt = df.melt(id_vars=['Era'],
                      value_vars=['Whole Shot Error', 'Type Error', 'Direction Error'],
                      value_name='Error Rate')
    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_melt, x='Era', y='Error Rate', hue='variable',
                order=list(eras), errorbar=('ci', 95))
    plt.title('Error Rates Across Eras')
    plt.ylabel('Error Rate (0.0 - 1.0)')
    plt.show()

# ==============================================================================
# 6. COURT SPEED (Raw)
# ==============================================================================
def _surface_bucket(raw_surface, buckets=('Clay', 'Hard', 'Grass'), default='Hard'):
    """Return the first name in ``buckets`` found inside ``raw_surface``, else ``default``.

    Matching is case-insensitive. ``raw_surface`` goes through ``str`` first, so
    non-string metadata values (None, NaN floats) fall back to ``default``
    instead of raising TypeError on the ``in`` check.
    """
    text = str(raw_surface).lower()
    for name in buckets:
        if name.lower() in text:
            return name
    return default


def evaluation_court_speed_vs_error(model, dataset, test_indices, sample_size=5000):
    """Compare single-head error rates per court surface (Clay / Hard / Grass).

    Test indices are bucketed by ``dataset.match_meta[match_id]['surface']``.
    Unrecognised or missing surfaces count as 'Hard'. Up to
    ``sample_size // 3`` random samples are drawn per surface. Every non-padding
    target position is decoded with ``get_fast_decoder_map`` into
    (type, direction, depth) and scored. Depth error is only scored where the
    target depth code is non-zero; elsewhere it is NaN and excluded from
    averages and the plot.

    Args:
        model: single-head model called as ``model(x_seq, context, x_s_id, x_r_id)``.
        dataset: dataset exposing ``match_meta`` and ``sample_match_ids``. Its
            samples have 'x_seq', 'context', 'x_s_id', 'x_r_id' and 'y_target'.
        test_indices: dataset indices to draw from.
        sample_size: total sample budget, split evenly across surfaces.

    Returns:
        None. Prints a mean error table (%) and shows a grouped bar plot.
    """
    model.eval()
    print("\n" + "="*50)
    print(" RAW ERROR ANALYSIS BY SURFACE (Fixed Depth Logic) ")
    print("="*50)

    # 1. Group test indices by surface.
    surface_map = {'Clay': [], 'Hard': [], 'Grass': []}
    surface_names = tuple(surface_map)
    for idx in test_indices:
        meta = dataset.match_meta.get(dataset.sample_match_ids[idx], {})
        bucket = _surface_bucket(meta.get('surface', 'Hard'), surface_names, 'Hard')
        surface_map[bucket].append(idx)

    # 2. Select an (at most) equal number of samples per surface.
    selected_indices, surface_labels = [], []
    per_surf = sample_size // len(surface_map)
    uni_map = get_fast_decoder_map(dataset)

    for surf, inds in surface_map.items():
        if not inds:
            continue
        chosen = random.sample(inds, min(len(inds), per_surf))
        selected_indices.extend(chosen)
        surface_labels.extend([surf] * len(chosen))

    # 3. Evaluation loop.
    nan = float('nan')
    results = []
    with torch.no_grad():
        for idx, surf in zip(selected_indices, surface_labels):
            sample = dataset[idx]

            # Add a batch dimension of 1.
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y = sample['y_target'].to(DEVICE)

            logits = model(x_seq, x_c, x_s, x_r)
            preds = logits.argmax(dim=-1).squeeze(0)

            for t in range(x_seq.shape[1]):
                if y[t] == 0:  # padding position
                    continue

                p_uid = preds[t].item()
                t_uid = y[t].item()

                # Decode to (type, direction, depth). Unknown ids decode to (0, 0, 0).
                # Depth code 0 is treated as "no depth annotation". Per the original
                # note, 1/2/3 map to depth '7'/'8'/'9'. TODO: confirm in get_fast_decoder_map.
                p_t, p_d, p_dp = uni_map.get(p_uid, (0, 0, 0))
                t_t, t_d, t_dp = uni_map.get(t_uid, (0, 0, 0))

                # Score depth only when the target has a depth annotation. A float NaN
                # (not None) keeps the column float-typed, so mean() skips it reliably.
                depth_err = (1.0 if p_dp != t_dp else 0.0) if t_dp != 0 else nan

                results.append({
                    'Surface': surf,
                    'Type Error': 1.0 if p_t != t_t else 0.0,
                    'Direction Error': 1.0 if p_d != t_d else 0.0,
                    'Depth Error': depth_err,
                    'Whole Shot Error': 1.0 if p_uid != t_uid else 0.0
                })

    # 4. Statistics & plotting.
    if not results:
        print("No results generated.")
        return

    df = pd.DataFrame(results)
    metrics = ['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error']

    # mean() skips NaN, so Depth Error is averaged over annotated targets only.
    stats = df.groupby('Surface')[metrics].mean() * 100
    print("\n--- Mean Error Rates (%) [Depth calculated only on valid targets] ---")
    print(stats.round(2))

    # Drop unscored (NaN) depth rows so they affect neither bar heights nor CIs.
    df_melt = (
        df.melt(id_vars=['Surface'], value_vars=metrics, value_name='Error Rate')
          .dropna(subset=['Error Rate'])
    )

    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_melt, x='Surface', y='Error Rate', hue='variable',
                order=list(surface_names), palette='viridis')
    plt.title('Error Rates by Component vs. Whole Shot (Depth Masked)')
    plt.ylabel('Error Rate (0.0 - 1.0)')
    plt.legend(title='Metric')
    plt.grid(axis='y', alpha=0.3)
    plt.show()


# --- RUNNER ---
# Uncomment the analyses you want to run. The calls below take `model`,
# `dataset` and the held-out `test_indices` produced by the split cell.
# evaluation_analyze_tactics_multitask(model, test_loader, dataset)
# globals() (not locals()) is required inside the comprehension's own scope.
_runner_missing = [name for name in ('test_indices', 'model', 'dataset') if name not in globals()]
if 'test_indices' in _runner_missing:
    print("Please run split first.")
elif _runner_missing:
    print(f"Please define {', '.join(_runner_missing)} before running the evaluations.")
else:
    # evaluation_live_test_cases(model, dataset, test_indices, num_samples=5000,print_flag=False)
    # evaluation_length_vs_errrors(model, dataset, test_indices, num_matches=2000)
    # evaluation_player_frequency_vs_error(model, dataset, test_indices, num_matches=2000)
    # evaluation_compare_eras(model, dataset, test_indices)
    evaluation_court_speed_vs_error(model, dataset, test_indices, sample_size=10000)

# MULTI HEAD

## BASELINE FFN

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import os
import re

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split, KFold
from sklearn.utils.class_weight import compute_class_weight
import re
import os
import random

# --- GLOBAL CONFIGURATION ---
# Hyper-parameters shared by the cells below.
SEQ_LEN = 30        # fixed sequence length (given to the dataset as max_seq_len)
BATCH_SIZE = 64
EPOCHS = 10
LR = 1e-3

# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# File name for the saved model.
save_path = 'tennis_shot_forecasting.pth'

# Seed every random number generator used by the notebook for reproducible runs.
def seed_everything(seed=42, deterministic=False):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: integer seed applied to every generator.
        deterministic: if True, also set ``torch.backends.cudnn.deterministic = True``
            and ``torch.backends.cudnn.benchmark = False``. This removes cuDNN's
            remaining non-determinism at some speed cost. Defaults to False, so
            existing calls behave as before.

    Note:
        Setting ``PYTHONHASHSEED`` at runtime only affects Python processes started
        afterwards. The current interpreter's ``hash()`` randomization is fixed at
        startup.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # seeds the CPU and all CUDA generators
    torch.cuda.manual_seed_all(seed)  # explicit; silently ignored without CUDA
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    print(f"Random seed set to {seed}")


# Call it immediately
seed_everything(42)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split

def _baseline_serve_type_ids(dataset):
    """Return the type-vocab ids whose token starts with 'serve' (any case) or 'S_'."""
    type_vocab = getattr(dataset, 'type_vocab', {})
    return {idx for key, idx in type_vocab.items()
            if key.lower().startswith('serve') or key.startswith('S_')}


def _baseline_flat_ce(criterion, logits, targets):
    """Apply `criterion` over every (batch, time) position by flattening leading dims.

    `reshape` is used instead of `view` so non-contiguous logits do not raise.
    """
    return criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))


def _baseline_validate(model, loader, serve_tensor, device):
    """Return (correct, counted) dicts keyed 'type' / 'dir' / 'depth' over `loader`.

    Masks:
      * type: ignore padding (0) and serve positions;
      * dir: the type mask plus positions whose direction target is padding (0),
        matching the training loss's ignore_index=0;
      * depth: ignore ids <= 1 and serve positions.
    """
    model.eval()
    correct = {'type': 0, 'dir': 0, 'depth': 0}
    counted = {'type': 0, 'dir': 0, 'depth': 0}
    with torch.no_grad():
        for batch in loader:
            x_seq = batch['x_seq'].to(device)
            ctx = batch['context'].to(device)
            y_type = batch['y_type'].to(device)
            y_dir = batch['y_dir'].to(device)
            y_depth = batch['y_depth'].to(device)

            l_type, l_dir, l_depth = model(x_seq, ctx)
            is_serve = torch.isin(y_type, serve_tensor)

            mask_type = (y_type != 0) & ~is_serve
            mask_dir = mask_type & (y_dir != 0)
            # Depth id 1 is assumed to mean "no depth / N/A". TODO: confirm
            # against dataset.depth_vocab.
            mask_depth = (y_depth > 1) & ~is_serve

            for key, logits, target, mask in (
                ('type', l_type, y_type, mask_type),
                ('dir', l_dir, y_dir, mask_dir),
                ('depth', l_depth, y_depth, mask_depth),
            ):
                n_valid = int(mask.sum().item())
                if n_valid:
                    counted[key] += n_valid
                    correct[key] += int((logits.argmax(-1)[mask] == target[mask]).sum().item())
    return correct, counted


def train_baseline_model(dataset, epochs=5, batch_size=64, lr=1e-3, device='cuda',
                         mask_depth_in_loss=False):
    """Train a SimpleMultiHeadBaseline (type / direction / depth heads) and report val accuracy.

    The dataset is split 80/15/5 (train/val/test) with a fixed generator seed of
    42, so the split is identical on every call. The 5% test part is not used here.

    Args:
        dataset: multi-task dataset exposing `unified_vocab`, `type_vocab`,
            `dir_vocab` and `depth_vocab`. Its samples have 'x_seq', 'context',
            'y_type', 'y_dir' and 'y_depth'.
        epochs: number of training epochs.
        batch_size: batch size for the train and validation loaders.
        lr: Adam learning rate.
        device: torch device string for the model and batches.
        mask_depth_in_loss: if True, the depth loss also ignores serve positions
            and depth id 1 (presumed "no depth"; confirm). The default False keeps
            the depth loss on every non-padding target.

    Returns:
        The trained model.
    """
    print("--- STARTING BASELINE (FEED-FORWARD) TRAINING ---")

    # 1. Serve ids: excluded from the type/direction loss and from evaluation.
    serve_type_ids = _baseline_serve_type_ids(dataset)

    # 2. Split 80 / 15 / remainder.
    total_len = len(dataset)
    train_len = int(0.80 * total_len)
    val_len = int(0.15 * total_len)
    test_len = total_len - train_len - val_len  # remainder, so lengths sum exactly
    train_subset, val_subset, _test_subset = random_split(
        dataset, [train_len, val_len, test_len], generator=torch.Generator().manual_seed(42)
    )

    # Both loaders use the caller's batch_size.
    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    model = SimpleMultiHeadBaseline(
        unified_vocab_size=len(dataset.unified_vocab),
        type_vocab_size=len(dataset.type_vocab),
        dir_vocab_size=len(dataset.dir_vocab),
        depth_vocab_size=len(dataset.depth_vocab),
        context_dim=10  # must match the width of samples' 'context' vector; confirm
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # id 0 is padding
    # Explicit long dtype so the tensor matches integer targets even when empty.
    serve_tensor = torch.tensor(list(serve_type_ids), dtype=torch.long, device=device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for batch in train_loader:
            x_seq = batch['x_seq'].to(device)
            ctx = batch['context'].to(device)
            y_type = batch['y_type'].to(device)
            y_dir = batch['y_dir'].to(device)
            y_depth = batch['y_depth'].to(device)

            # Serve positions are turned into padding, so the loss ignores them.
            is_serve = torch.isin(y_type, serve_tensor)
            y_type_masked = y_type.masked_fill(is_serve, 0)
            y_dir_masked = y_dir.masked_fill(is_serve, 0)
            y_depth_target = y_depth
            if mask_depth_in_loss:
                y_depth_target = y_depth.masked_fill(is_serve | (y_depth == 1), 0)

            optimizer.zero_grad()
            l_type, l_dir, l_depth = model(x_seq, ctx)

            loss = (_baseline_flat_ce(criterion, l_type, y_type_masked)
                    + _baseline_flat_ce(criterion, l_dir, y_dir_masked)
                    + _baseline_flat_ce(criterion, l_depth, y_depth_target))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # --- Evaluation ---
        correct, counted = _baseline_validate(model, val_loader, serve_tensor, device)

        def pct(key):
            return (correct[key] / counted[key] * 100) if counted[key] > 0 else 0

        loss_avg = total_loss / max(len(train_loader), 1)  # guard empty train split
        print(f"Epoch {epoch+1} | Loss: {loss_avg:.4f}")
        print(f"   Type Acc:  {pct('type'):.2f}% (N={counted['type']})")
        print(f"   Dir Acc:   {pct('dir'):.2f}% (N={counted['dir']})")
        print(f"   Depth Acc: {pct('depth'):.2f}% (N={counted['depth']}) <--- Only counting depth > 1")

    return model

In [None]:
# Directory holding the charted point-by-point CSVs (Kaggle input mount).
base_path = '/kaggle/input/atp-points/'

# Point files from all eras, merged into one dataset.
point_files = [
    os.path.join(base_path, 'charting-m-points-2020s.csv'),
    os.path.join(base_path, 'charting-m-points-2010s.csv'),
    os.path.join(base_path, 'charting-m-points-to-2009.csv'),
]

# Match-level file
matches_path = '/kaggle/input/atp-matches-updated/charting-m-matches-updated.csv'

# ATP / WTA player lists
atp_path = '/kaggle/input/atp-players/atp_players.csv'
wta_path = '/kaggle/input/atp-players/wta_players.csv'
dataset = MCPMultiTaskDataset(point_files, matches_path, atp_path, wta_path, max_seq_len=SEQ_LEN)

# `dataset` was just constructed, so check that it actually holds samples;
# an empty split would leave nothing to train on.
if len(dataset) == 0:
    print("Dataset is empty; check the input paths above.")
else:
    baseline = train_baseline_model(dataset, epochs=15, batch_size=512, device=DEVICE)

### Evaluation

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from collections import Counter
import seaborn as sns
import pandas as pd
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, random_split

# --- 1. SHARED CONFIG & HELPERS ---
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

def get_inverse_vocabs(dataset):
    """Build id -> token lookups for the three multi-head vocabularies.

    Returns:
        Tuple (inv_type, inv_dir, inv_depth), built from `dataset.type_vocab`,
        `dataset.dir_vocab` and `dataset.depth_vocab` respectively. If two tokens
        share an id, the later one in the vocab's iteration order wins.
    """
    def invert(vocab):
        return dict(zip(vocab.values(), vocab.keys()))

    vocabs = (dataset.type_vocab, dataset.dir_vocab, dataset.depth_vocab)
    return tuple(invert(vocab) for vocab in vocabs)

# --- 2. THE MULTI-TASK EVALUATION FUNCTION ---
def run_full_evaluation(model, dataset, loader, test_indices, 
                        live_samples=5000, 
                        length_matches=2000, 
                        freq_matches=2000, 
                        era_matches=50, 
                        speed_samples=10000):
    
    model.eval()
    print(f"Starting Multi-Task Evaluation on {len(test_indices)} test samples...")
    print(f"Batch Size: {loader.batch_size} (Part 1 will run much faster now)")

    # --- Helpers ---
    inv_type, inv_dir, inv_depth = get_inverse_vocabs(dataset)
    
    # Reverse Unified Vocab
    inv_unified = {}
    if hasattr(dataset, 'unified_vocab'):
        inv_unified = {v: k for k, v in dataset.unified_vocab.items()}

    serve_type_id = dataset.type_vocab.get('serve', dataset.type_vocab.get('s', -1))
    unk_depth_id = dataset.depth_vocab.get('0', -1)
    unk_dir_id = dataset.dir_vocab.get('0', -1)

    # --- Match Map ---
    if not hasattr(dataset, 'sample_match_ids'):
        print("Error: Dataset missing 'sample_match_ids'.")
        return

    match_map = {}
    for idx in test_indices:
        mid = dataset.sample_match_ids[idx]
        match_map.setdefault(mid, []).append(idx)
    unique_matches = list(match_map.keys())

    # ==============================================================================
    # PART 1: OVERALL TACTICAL METRICS (Vectorized & Fast)
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 1: OVERALL TACTICAL METRICS \n" + "="*40)
    
    all_preds_type, all_targs_type = [], []
    all_preds_dir,  all_targs_dir  = [], []
    all_preds_depth, all_targs_depth = [], []
    
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i % 50 == 0: print(f"Processing batch {i}...", end='\r')
            
            x_seq = batch['x_seq'].to(DEVICE)
            x_c   = batch['context'].to(DEVICE)
            
            # Inference on full batch
            l_type, l_dir, l_depth = model(x_seq, x_c)
            
            # --- Type ---
            y_type = batch['y_type'].to(DEVICE).view(-1)
            mask_t = (y_type != 0)
            if serve_type_id != -1: mask_t &= (y_type != serve_type_id)

            if mask_t.sum() > 0:
                all_preds_type.extend(l_type.argmax(-1).view(-1)[mask_t].cpu().numpy())
                all_targs_type.extend(y_type[mask_t].cpu().numpy())

            # --- Direction ---
            y_dir = batch['y_dir'].to(DEVICE).view(-1)
            mask_d = (y_dir != 0) & (y_dir != unk_dir_id)
            if serve_type_id != -1: mask_d &= (y_type != serve_type_id)

            if mask_d.sum() > 0:
                all_preds_dir.extend(l_dir.argmax(-1).view(-1)[mask_d].cpu().numpy())
                all_targs_dir.extend(y_dir[mask_d].cpu().numpy())

            # --- Depth ---
            y_depth = batch['y_depth'].to(DEVICE).view(-1)
            mask_dp = (y_depth != 0) & (y_depth != unk_depth_id)
            if serve_type_id != -1: mask_dp &= (y_type != serve_type_id)

            if mask_dp.sum() > 0:
                all_preds_depth.extend(l_depth.argmax(-1).view(-1)[mask_dp].cpu().numpy())
                all_targs_depth.extend(y_depth[mask_dp].cpu().numpy())

    # Reports
    print("\n=== SHOT TYPE REPORT (Rally Only) ===")
    labels = [k for k in dataset.type_vocab if k not in ['<pad>', '<unk>', 'serve', 's']]
    indices = [dataset.type_vocab[k] for k in labels]
    present = [i for i in indices if i in np.unique(all_targs_type)]
    present_lbls = [inv_type[i] for i in present]
    if present:
        print(classification_report(all_targs_type, all_preds_type, labels=present, target_names=present_lbls, zero_division=0))

    print("\n=== DIRECTION REPORT ===")
    labels = [k for k in dataset.dir_vocab if k not in ['<pad>', '0']]
    indices = [dataset.dir_vocab[k] for k in labels]
    if indices:
        print(classification_report(all_targs_dir, all_preds_dir, labels=indices, target_names=labels, zero_division=0))

    print("\n=== DEPTH REPORT ===")
    labels = [k for k in dataset.depth_vocab if k not in ['<pad>', '0']]
    indices = [dataset.depth_vocab[k] for k in labels]
    if indices:
        print(classification_report(all_targs_depth, all_preds_depth, labels=indices, target_names=labels, zero_division=0))

    # ==============================================================================
    # PART 2: LIVE SAMPLES (Updated for Large Batch Sizes)
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 2: LIVE SAMPLES (Showing {live_samples} Cases) \n" + "="*40)
    
    processed_count = 0
    correct_3_of_3 = 0
    results_buffer = {3: [], 2:[], 1:[], 0:[]} 

    with torch.no_grad():
        for i, batch in enumerate(loader):
            if processed_count >= live_samples:
                break
                
            x_seq_batch = batch['x_seq'].to(DEVICE)     # [B, SeqLen]
            ctx_batch   = batch['context'].to(DEVICE)   # [B, ContextDim]
            y_t_batch   = batch['y_type'].to(DEVICE)
            y_d_batch   = batch['y_dir'].to(DEVICE)
            y_dp_batch  = batch['y_depth'].to(DEVICE)

            # Inference on WHOLE batch
            l_t_batch, l_d_batch, l_dp_batch = model(x_seq_batch, ctx_batch) 

            # Iterate through items in the batch
            batch_size = x_seq_batch.size(0)
            
            for k in range(batch_size):
                if processed_count >= live_samples: break
                
                # --- SINGLE ITEM LOGIC START ---
                # Check valid indices for this specific rally 'k'
                seq_len = x_seq_batch.size(1)
                valid_indices = []
                for t in range(seq_len):
                    true_type = y_t_batch[k, t].item()
                    if true_type > 1 and inv_type.get(true_type) not in ['serve', 's']:
                        valid_indices.append(t)
                
                if not valid_indices: 
                    continue

                t = random.choice(valid_indices)
                
                # RECONSTRUCT HISTORY (Using index k)
                history_str = ""
                raw_history = x_seq_batch[k, :t+1].cpu().numpy()
                hist_tokens = []
                for h_idx in raw_history:
                    if h_idx == 0: continue 
                    token_str = inv_unified.get(h_idx, '?')
                    parts = token_str.split('_')
                    if parts[0].lower() in ['serve', 's']:
                        hist_tokens.append(f"[S{parts[1] if len(parts)>1 else ''}]")
                    else:
                        shot = parts[0].upper()
                        dire = parts[1] if len(parts)>1 else ''
                        hist_tokens.append(f"{shot}{dire}")
                
                history_str = " -> ".join(hist_tokens[-6:]) 
                
                # PREDICTION (Using index k)
                pred_t = l_t_batch[k, t].argmax().item()
                pred_d = l_d_batch[k, t].argmax().item()
                pred_dp = l_dp_batch[k, t].argmax().item()
                conf_t = torch.softmax(l_t_batch[k, t], dim=0).max().item() * 100

                true_t = y_t_batch[k, t].item()
                true_d = y_d_batch[k, t].item()
                true_dp = y_dp_batch[k, t].item()

                s_pred_t = inv_type.get(pred_t, '?')
                s_pred_d = inv_dir.get(pred_d, '?')
                s_pred_dp = inv_depth.get(pred_dp, '?')
                
                s_true_t = inv_type.get(true_t, '?')
                s_true_d = inv_dir.get(true_d, '?')
                s_true_dp = inv_depth.get(true_dp, '?')

                ok_t = "‚úÖ" if pred_t == true_t else "‚ùå"
                ok_d = "‚úÖ" if pred_d == true_d else "‚ùå"
                ok_dp = "‚úÖ" if pred_dp == true_dp else "‚ùå"

                score = (1 if pred_t==true_t else 0) + (1 if pred_d==true_d else 0) + (1 if pred_dp==true_dp else 0)
                if score == 3: correct_3_of_3 += 1
                
                res_str = (
                    f"Sample #{processed_count+1}\n"
                    f"History:  ... {history_str}\n"
                    f"Model:    [{s_pred_t}] to Zone [{s_pred_d}] (Depth {s_pred_dp}) | Conf: {conf_t:.0f}%\n"
                    f"Actual:   [{s_true_t}] to Zone [{s_true_d}] (Depth {s_true_dp}) | {ok_t}{ok_d}{ok_dp}\n"
                    f"{'-'*40}"
                )
                results_buffer[score].append(res_str)
                processed_count += 1
                # --- SINGLE ITEM LOGIC END ---

            if processed_count % 500 == 0:
                print(f"Processed {processed_count}/{live_samples} samples...", end='\r')

    # Print Best Matches
    print_flag = False
    print("\n--- SAMPLE PREDICTIONS ---")
    for s in [3, 2, 1, 0]:
        items = results_buffer[s]
        if items:
            print(f"\n{'='*20} {s}/3 CORRECT ({len(items)} cases) {'='*20}")
            if print_flag:    
                for item in items:
                    print(item)

    # ==============================================================================
    # PART 3: GRANULAR ACCURACY VS RALLY LENGTH
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 3: GRANULAR ACCURACY VS RALLY LENGTH \n" + "="*40)
    
    # 3.1 Calculate Baselines
    print("Calculating dataset baselines...")
    all_d, all_dp, all_tp = [], [], []
    for i in test_indices:
        sample = dataset[i]
        yt = sample['y_type']
        yd = sample['y_dir']
        ydp = sample['y_depth']
        
        if torch.is_tensor(yt): yt = yt.cpu().numpy()
        if torch.is_tensor(yd): yd = yd.cpu().numpy()
        if torch.is_tensor(ydp): ydp = ydp.cpu().numpy()

        for j in range(len(yt)):
            if yt[j] == 0: continue
            if yt[j] != serve_type_id:
                all_tp.append(yt[j])
                if yd[j] != unk_dir_id and yd[j] != 0: all_d.append(yd[j])
                if ydp[j] != unk_depth_id and ydp[j] != 0: all_dp.append(ydp[j])

    def calc_baseline(data_list):
        if not data_list: return 0.33
        counts = Counter(data_list)
        total = sum(counts.values())
        return sum([(c/total)**2 for c in counts.values()])

    base_d = calc_baseline(all_d)
    base_dp = calc_baseline(all_dp)
    base_tp = calc_baseline(all_tp)
    base_pair_avg = (base_d*base_dp + base_d*base_tp + base_tp*base_dp) / 3
    base_whole = base_d * base_dp * base_tp
    print(f"Baselines -> Dir: {base_d:.2f}, Depth: {base_dp:.2f}, Type: {base_tp:.2f}, Whole: {base_whole:.4f}")
    
    # 3.2 Analysis Loop
    chosen_matches = random.sample(unique_matches, min(length_matches, len(unique_matches)))
    rl_indices = [idx for mid in chosen_matches for idx in match_map[mid]]
    print(f"Analyzing {len(rl_indices)} points from {len(chosen_matches)} matches...")

    results_p3 = []
    with torch.no_grad():
        for idx in rl_indices:
            sample = dataset[idx]
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c   = sample['context'].unsqueeze(0).to(DEVICE)
            
            l_type, l_dir, l_depth = model(x_seq, x_c)
            preds_t = l_type.argmax(dim=-1).squeeze(0)
            preds_d = l_dir.argmax(dim=-1).squeeze(0)
            preds_dp = l_depth.argmax(dim=-1).squeeze(0)
            
            y_t = sample['y_type'].to(DEVICE)
            y_d = sample['y_dir'].to(DEVICE)
            y_dp = sample['y_depth'].to(DEVICE)
            
            x_seq_cpu = sample['x_seq']
            
            limit = min(len(y_t), len(preds_t))

            for t_step in range(limit):
                if y_t[t_step] == 0: continue
                if y_t[t_step] == serve_type_id: continue

                shot_num = (x_seq_cpu[:t_step+1] != 0).sum().item() + 1
                
                pt, p_d, pdp = preds_t[t_step].item(), preds_d[t_step].item(), preds_dp[t_step].item()
                tt, td, tdp = y_t[t_step].item(), y_d[t_step].item(), y_dp[t_step].item()

                ok_t = (pt == tt)
                has_dir = (td != 0 and td != unk_dir_id)
                has_depth = (tdp != 0 and tdp != unk_depth_id)
                ok_d = (p_d == td) if has_dir else False
                ok_dp = (pdp == tdp) if has_depth else False

                # Logic Split - Single
                results_p3.append({'Shot_Number': shot_num, 'Task': 'Type', 'Type': 'Single', 'Accuracy': 1 if ok_t else 0})
                if has_dir:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Direction', 'Type': 'Single', 'Accuracy': 1 if ok_d else 0})
                if has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Depth', 'Type': 'Single', 'Accuracy': 1 if ok_dp else 0})
                
                # Logic Split - Pair
                if has_dir:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Dir + Type', 'Type': 'Pair', 'Accuracy': 1 if (ok_d and ok_t) else 0})
                if has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Type + Depth', 'Type': 'Pair', 'Accuracy': 1 if (ok_dp and ok_t) else 0})
                if has_dir and has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Dir + Depth', 'Type': 'Pair', 'Accuracy': 1 if (ok_d and ok_dp) else 0})
                
                # Logic Split - Whole
                if has_dir and has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Whole Shot', 'Type': 'Whole', 'Accuracy': 1 if (ok_d and ok_dp and ok_t) else 0})

    if results_p3:
        df = pd.DataFrame(results_p3)
        df = df[(df['Shot_Number'] <= 12) & (df['Shot_Number'] >= 2)]
        
        palette_single = {'Direction': '#1f77b4', 'Depth': '#d62728', 'Type': '#2ca02c'}
        palette_pair   = {'Dir + Depth': '#9467bd', 'Dir + Type': '#17becf', 'Type + Depth': '#ff7f0e'}
        palette_whole  = {'Whole Shot': '#000000'}

        def setup_plot(title, baseline, base_label):
            plt.figure(figsize=(12, 5))
            plt.title(title, fontsize=14)
            plt.ylabel('Accuracy', fontsize=12)
            plt.xlabel('Shot Number', fontsize=12)
            plt.xticks(np.arange(2, 13, 1))
            
            ax = plt.gca()
            ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.02))
            plt.grid(True, which='major', axis='y', linestyle='-', linewidth=0.75, color='grey', alpha=0.6)
            plt.grid(True, which='minor', axis='y', linestyle='--', linewidth=0.5, color='grey', alpha=0.3)
            plt.ylim(0.0, 1.0)
            
            if baseline:
                plt.axhline(baseline, color='#FF1493', linestyle=':', alpha=0.8, linewidth=2, label=base_label)

        setup_plot('Single Task Accuracy vs. Rally Length', base_tp, f'Random Type ({base_tp:.2f})')
        sns.lineplot(data=df[df['Type']=='Single'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes=False, palette=palette_single, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='lower right'); plt.show()
        
        setup_plot('Pairwise Accuracy vs. Rally Length', base_pair_avg, f'Random Pair (~{base_pair_avg:.2f})')
        sns.lineplot(data=df[df['Type']=='Pair'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes=False, palette=palette_pair, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

        setup_plot('Whole Shot Accuracy vs. Rally Length', base_whole, f'Random Whole ({base_whole:.4f})')
        sns.lineplot(data=df[df['Type']=='Whole'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes={'Whole Shot':(2,2)}, palette=palette_whole, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()
    # ==============================================================================
    # PART 4: PLAYER FREQUENCY
    # ==============================================================================
    if 'x_s_id' in dataset[0]:
        print("\n" + "="*40 + f"\n PART 4: PLAYER FREQUENCY ({freq_matches} Matches) \n" + "="*40)
        chosen_matches = random.sample(unique_matches, min(freq_matches, len(unique_matches)))
        pf_indices = [idx for mid in chosen_matches for idx in match_map[mid]]
        
        p_counts = Counter()
        p_stats = {}
        
        with torch.no_grad():
            for idx in pf_indices:
                sample = dataset[idx]
                x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
                x_c   = sample['context'].unsqueeze(0).to(DEVICE)
                y_type = sample['y_type'].to(DEVICE)
                
                l_t, _, _ = model(x_seq, x_c)
                preds = l_t.argmax(-1).squeeze(0)
                
                s_val = sample['x_s_id'] if isinstance(sample['x_s_id'], int) else sample['x_s_id'].item()
                r_val = sample['x_r_id'] if isinstance(sample['x_r_id'], int) else sample['x_r_id'].item()
                
                x_seq_cpu = sample['x_seq']
                
                for t in range(len(y_type)):
                    if y_type[t] == 0: continue
                    if y_type[t] == serve_type_id: continue
                    if t >= len(preds): break

                    hist_len = (x_seq_cpu[:t+1] != 0).sum().item()
                    actor = s_val if (hist_len + 1) % 2 != 0 else r_val
                    if actor <= 1: continue 
                    
                    p_counts[actor] += 1
                    if actor not in p_stats: p_stats[actor] = {'tot': 0, 'corr': 0}
                    p_stats[actor]['tot'] += 1
                    if preds[t].item() == y_type[t].item(): p_stats[actor]['corr'] += 1

        pf_data = [{'Freq': p_counts[a], 'Err': (1 - v['corr']/v['tot'])*100} for a, v in p_stats.items() if p_counts[a] > 10]
        if pf_data:
            df_pf = pd.DataFrame(pf_data)
            plt.figure(figsize=(8, 4))
            sns.regplot(data=df_pf, x='Freq', y='Err', scatter_kws={'alpha':0.5}, line_kws={'color':'red'})
            plt.xscale('log'); plt.title(f"Type Error vs Frequency"); plt.show()
    else:
        print("\n(Part 4 Skipped: Player IDs not in dataset)")

    # ==============================================================================
    # PART 5: ERA STABILITY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 5: ERA STABILITY ({era_matches} Matches/Era) \n" + "="*40)
    eras = {'Pre-2010': [], '2010-2019': [], '2020+': []}
    for m_id in unique_matches:
        try: y_year = int(str(m_id)[:4])
        except: continue
        if y_year < 2010: eras['Pre-2010'].append(m_id)
        elif y_year < 2020: eras['2010-2019'].append(m_id)
        else: eras['2020+'].append(m_id)

    era_indices = []
    era_labels_list = []
    for era_name, m_list in eras.items():
        if not m_list: continue
        chosen = random.sample(m_list, min(era_matches, len(m_list)))
        for m in chosen:
            era_indices.extend(match_map[m])
            era_labels_list.extend([era_name]*len(match_map[m]))
            
    era_res = []
    with torch.no_grad():
        for i, idx in enumerate(era_indices):
            sample = dataset[idx]
            x_seq = sample['x_seq'].unsqueeze(0).to(DEVICE)
            x_c   = sample['context'].unsqueeze(0).to(DEVICE)
            y_type = sample['y_type'].to(DEVICE)
            
            l_t, _, _ = model(x_seq, x_c)
            p_t = l_t.argmax(-1).squeeze(0)
            
            # Trim to match logic
            limit = min(len(y_type), len(p_t))
            p_t_trim = p_t[:limit]
            y_type_trim = y_type[:limit]

            mask = (y_type_trim > 1) & (y_type_trim != serve_type_id)
            if mask.sum() > 0:
                correct = (p_t_trim[mask] == y_type_trim[mask])
                acc = correct.float().mean().item()
                era_res.append({'Era': era_labels_list[i], 'Type Acc': acc})
    
    if era_res:
        plt.figure(figsize=(6, 4))
        sns.barplot(data=pd.DataFrame(era_res), x='Era', y='Type Acc', palette='viridis', order=['Pre-2010', '2010-2019', '2020+'])
        plt.title('Shot Type Accuracy by Era'); plt.ylim(0, 1); plt.show()

    # ==============================================================================
    # PART 6: SURFACE DIFFICULTY
    # ==============================================================================
    print("\n" + "="*50)
    print(" PART 6: SURFACE ERROR ANALYSIS (Baseline Model) ")
    print("="*50)

    # --- 1. Efficient Surface Grouping ---
    print(f"Grouping {len(test_indices)} test samples by surface...")
    surface_indices = {'Clay': [], 'Hard': [], 'Grass': []}
    
    # Pre-fetch match IDs via dataset metadata
    for idx in test_indices:
        m_id = dataset.sample_match_ids[idx]
        surf = dataset.match_meta.get(m_id, {}).get('surface', 'Hard')
        
        # Normalize strings
        if 'Clay' in surf:   target = 'Clay'
        elif 'Grass' in surf: target = 'Grass'
        else:                 target = 'Hard'
        
        surface_indices[target].append(idx)

    # --- 2. Evaluation Loop ---
    results = []
    inv_depth = {v: k for k, v in dataset.depth_vocab.items()}
    inv_dir   = {v: k for k, v in dataset.dir_vocab.items()}
    
    # Identify valid IDs (Strict Mode)
    # Valid Depth: '7' (Short), '8' (Deep), '9' (Very Deep)
    valid_depth_ids = [v for k, v in dataset.depth_vocab.items() if k in ['7', '8', '9']]
    
    # Valid Direction: '1', '2', '3' (Excluding '0' center/unknown)
    valid_dir_ids = [v for k, v in dataset.dir_vocab.items() if k not in ['<pad>', '0', '<unk>']]

    with torch.no_grad():
        for surf, inds in surface_indices.items():
            if not inds: 
                continue
            
            print(f"Scanning {len(inds)} samples for {surf}...")
            
            # Create a DataLoader for this surface slice
            surf_ds = Subset(dataset, inds)
            loader = DataLoader(surf_ds, batch_size=batch_size, shuffle=False)
            
            for batch in loader:
                # Move Inputs (Baseline uses x_seq + context)
                x_seq = batch['x_seq'].to(DEVICE)
                ctx   = batch['context'].to(DEVICE)
                
                # Move Targets
                y_t  = batch['y_type'].to(DEVICE)
                y_d  = batch['y_dir'].to(DEVICE)
                y_dp = batch['y_depth'].to(DEVICE)
                
                # Forward Pass (Baseline Signature)
                l_t, l_d, l_dp = model(x_seq, ctx)
                
                # Get Predictions
                p_t  = l_t.argmax(dim=-1)
                p_d  = l_d.argmax(dim=-1)
                p_dp = l_dp.argmax(dim=-1)
                
                # --- Vectorized Error Calculation ---
                
                # 1. Base Mask: Must be a real shot (not padding)
                # We also usually exclude serves from analysis to focus on rally dynamics
                # Check if 'serve' exists in vocab to handle it safely
                serve_id = dataset.type_vocab.get('serve', -1)
                
                valid_mask = (y_t != 0) # Ignore Pad
                if serve_id != -1:
                    valid_mask &= (y_t != serve_id) # Ignore Serves
                
                if valid_mask.sum() == 0: continue

                # Filter down to valid shots (Flattening for easy iteration)
                curr_y_t  = y_t[valid_mask].cpu().numpy()
                curr_p_t  = p_t[valid_mask].cpu().numpy()
                
                curr_y_d  = y_d[valid_mask].cpu().numpy()
                curr_p_d  = p_d[valid_mask].cpu().numpy()
                
                curr_y_dp = y_dp[valid_mask].cpu().numpy()
                curr_p_dp = p_dp[valid_mask].cpu().numpy()

                for i in range(len(curr_y_t)):
                    # A. Type Error
                    t_err = 1.0 if curr_y_t[i] != curr_p_t[i] else 0.0
                    
                    # B. Direction Error (Strict Masking)
                    # Only count if ground truth is a specific direction (1,2,3)
                    if curr_y_d[i] in valid_dir_ids:
                        d_err = 1.0 if curr_y_d[i] != curr_p_d[i] else 0.0
                    else:
                        d_err = None # Ignore 'Center' or 'Unknown' shots
                        
                    # C. Depth Error (VERY Strict Masking)
                    # Only count if ground truth is strictly Deep/Short (7,8,9)
                    if curr_y_dp[i] in valid_depth_ids:
                        dp_err = 1.0 if curr_y_dp[i] != curr_p_dp[i] else 0.0
                    else:
                        dp_err = None # Ignore standard depth shots
                        
                    # D. Whole Shot Error
                    # Defined as: Type must be right, AND (Dir right if valid), AND (Depth right if valid)
                    miss = False
                    if t_err == 1.0: miss = True
                    if d_err is not None and d_err == 1.0: miss = True
                    if dp_err is not None and dp_err == 1.0: miss = True
                    
                    ws_err = 1.0 if miss else 0.0
                    
                    results.append({
                        'Surface': surf,
                        'Type Error': t_err,
                        'Direction Error': d_err,
                        'Depth Error': dp_err,
                        'Whole Shot Error': ws_err
                    })

    # --- 3. Output Generation ---
    if not results:
        print("No results generated.")
        return

    df = pd.DataFrame(results)
    
    # Diagnostics
    depth_counts = df.groupby('Surface')['Depth Error'].count()
    print("\n[Diagnostics] Valid Depth Samples found per surface (Baseline):")
    print(depth_counts)
    
    # Stats Table
    stats = df.groupby('Surface')[['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error']].mean() * 100
    print("\n--- Mean Error Rates (%) [Lower is Better] ---")
    print(stats.round(2))
    
    # Plot
    df_melt = df.melt(id_vars=['Surface'], 
                      value_vars=['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error'], 
                      value_name='Error Rate')
    
    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_melt, x='Surface', y='Error Rate', hue='variable', 
                order=['Clay', 'Hard', 'Grass'], palette='viridis', errorbar=('ci', 95))
    
    plt.title(f'Baseline Model: Error Rates by Component (N={len(df)})')
    plt.ylabel('Error Rate (0.0 - 1.0)')
    plt.legend(title='Metric')
    plt.grid(axis='y', alpha=0.3)
    plt.ylim(0, 1.0)
    plt.show()


# --- 3. RUNNER SNIPPET ---
# Re-create the held-out test split and run the baseline evaluation on it.
# NOTE(review): this relies on the *baseline* `run_full_evaluation` defined earlier in the
# notebook; a later cell redefines that name for the hierarchical model, so re-running this
# cell after that one would call the wrong function.
if 'dataset' in globals() and 'baseline' in globals():
    print("Recreating 80/15/5 split for evaluation...")

    seed_everything(42)

    # Same proportions as the baseline training split. The partition only matches training
    # if training used identical lengths and generator seed -- TODO confirm.
    total_len = len(dataset)
    train_len = int(0.80 * total_len)
    val_len   = int(0.15 * total_len)
    test_len  = total_len - train_len - val_len

    # 3-way split; only the test part is used here.
    _, _, test_ds = random_split(
        dataset,
        [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(42)
    )

    # Indices into `dataset` that belong to the test split.
    test_indices = test_ds.indices

    # A large batch size keeps inference fast. Evaluation needs no shuffling: a fixed order
    # keeps the pass deterministic and does not draw a sampler seed from torch's global RNG.
    test_loader_eval = DataLoader(test_ds, batch_size=512, shuffle=False)

    print(f"Test Set Ready: {len(test_ds)} samples.")

    # Run Full Evaluation
    run_full_evaluation(
        model=baseline,
        dataset=dataset,
        loader=test_loader_eval,
        test_indices=test_indices,
        live_samples=5000,
        length_matches=2000,
        freq_matches=2000,
        era_matches=50,
        speed_samples=10000
    )
else:
    print("Error: 'dataset' or 'baseline' not found. Run training first.")

## CHAINED TRANSFORMER (hierarchical)

In [None]:
def train_hierarchical_model(dataset, epochs=10, batch_size=64, lr=1e-3, device='cuda'):
    print(f"--- STARTING HIERARCHICAL TRANSFORMER TRAINING ---")
    
    # 0. IDENTIFY SPECIAL TOKENS
    serve_type_id = dataset.type_vocab.get('serve', dataset.type_vocab.get('s', -1))
    
    # Identify the ID for '0' (Unknown) in depth/dir vocabs
    unk_depth_id = dataset.depth_vocab.get('0', -1)
    unk_dir_id = dataset.dir_vocab.get('0', -1)
    
    print(f"Masking Config:")
    print(f" - Serve Type ID: {serve_type_id} (Will mask Dir/Depth for serves)")
    print(f" - Unknown Depth ID: {unk_depth_id} (Will be ignored in loss)")
    print(f" - Unknown Dir ID: {unk_dir_id} (Will be ignored in loss)")

    # 1. Split Dataset
    total_len = len(dataset)
    train_len = int(0.8 * total_len)
    val_len = int(0.1 * total_len)
    test_len = total_len - train_len - val_len
    
    train_ds, val_ds, test_ds = random_split(
        dataset, [train_len, val_len, test_len], 
        generator=torch.Generator().manual_seed(42)
    )
    
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    
    try:
        real_context_dim = dataset.context_tensor.shape[1]
    except AttributeError:
        real_context_dim = dataset.context.shape[1]

    # 2. Initialize Model
    model = HierarchicalCristianGPT(
        dir_vocab_size=len(dataset.dir_vocab),
        depth_vocab_size=len(dataset.depth_vocab),
        type_vocab_size=len(dataset.type_vocab), 
        num_players=len(dataset.player_vocab),
        context_dim=real_context_dim,   
        embed_dim=64, 
        n_head=4, 
        n_cycles=3
    ).to(device)
    
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=2)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    
    # 3. Training Loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            x_z = batch['x_dir'].to(device)
            x_t = batch['x_type'].to(device)
            x_c = batch['context'].to(device)
            x_s = batch['x_s_id'].to(device)
            x_r = batch['x_r_id'].to(device)
            
            y_d = batch['y_dir'].to(device)
            y_dp = batch['y_depth'].to(device)
            y_t = batch['y_type'].to(device)
            
            # --- CRITICAL FIX: MASKING TARGETS ---
            # We must mask specific conditions to prevent learning useless "Unknown" labels
            
            # 1. Clone targets
            y_d = y_d.clone()
            y_dp = y_dp.clone()
            y_t = y_t.clone()
            
            # 2. Mask "Unknown" Depths and Directions (Set to 0/Padding)
            # This forces the model to only learn when valid Depth/Dir exists
            if unk_depth_id != -1:
                y_dp[y_dp == unk_depth_id] = 0
            if unk_dir_id != -1:
                y_d[y_d == unk_dir_id] = 0

            # 3. Mask Dir/Depth completely if the shot is a Serve
            if serve_type_id != -1:
                is_serve = (y_t == serve_type_id)
                y_d[is_serve] = 0
                y_dp[is_serve] = 0
                # Note: We do NOT mask y_t here, we want it to learn to predict "Serve" type
            # ---------------------------------------
            
            optimizer.zero_grad()
            
            pred_d, pred_dp, pred_t = model(x_z, x_t, x_c, x_s, x_r)
            
            l_d = criterion(pred_d.view(-1, len(dataset.dir_vocab)), y_d.view(-1))
            l_dp = criterion(pred_dp.view(-1, len(dataset.depth_vocab)), y_dp.view(-1))
            l_t = criterion(pred_t.view(-1, len(dataset.type_vocab)), y_t.view(-1))
            
            loss = l_d + l_dp + l_t
            
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        avg_train_loss = total_loss / len(train_loader)
        
        # Validation
        val_metrics = evaluate_hierarchical(model, val_loader, dataset, device, serve_type_id)
        
        current_lr = optimizer.param_groups[0]['lr']
        scheduler.step(val_metrics['type'])
        
        print(f"Ep {epoch+1}/{epochs} | Loss: {avg_train_loss:.4f} | LR: {current_lr:.1e} | "
              f"Val Dir: {val_metrics['dir']:.1f}% | "
              f"Val Depth: {val_metrics['depth']:.1f}% | "
              f"Val Type: {val_metrics['type']:.1f}%")

    return model

In [None]:
import torch
import torch.nn as nn


In [None]:
if __name__ == "__main__":
    # Kaggle input folder holding the per-era point-by-point charting CSVs.
    base_path = '/kaggle/input/atp-points/'

    # Every era file is merged into one dataset.
    point_files = [
        base_path + file_name
        for file_name in (
            'charting-m-points-2020s.csv',
            'charting-m-points-2010s.csv',
            'charting-m-points-to-2009.csv',
        )
    ]

    # Match metadata and player tables.
    matches_path = '/kaggle/input/atp-matches-updated/charting-m-matches-updated.csv'
    atp_path = '/kaggle/input/atp-players/atp_players.csv'
    wta_path = '/kaggle/input/atp-players/wta_players.csv'

    dataset = HierarchicalTennisDataset(
        point_files, matches_path, atp_path, wta_path, max_seq_len=SEQ_LEN
    )
    hierarchical_model = train_hierarchical_model(
        dataset, epochs=10, batch_size=512, device=DEVICE
    )

### Evaluation

In [None]:
def load_standalone_checkpoint(checkpoint_path, dataset, device='cuda'):
    """Rebuild a ``HierarchicalCristianGPT`` from a checkpoint and sync ``dataset`` to it.

    The checkpoint must be a dict containing ``"config"`` (``context_dim``,
    ``embed_dim``, ``n_head``, ``n_cycles``), ``"vocabs"`` (``type_vocab``,
    ``dir_vocab``, ``depth_vocab``, ``player_vocab``), ``"model_state_dict"``, and
    optionally ``"epoch"`` (only used in the log line).

    Args:
        checkpoint_path: File passed to ``torch.load``.
        dataset: Mutated in place. Its vocabularies and inverse lookups are replaced
            with the checkpoint's, then ``dataset._decompose_data()`` is re-run.
        device: Target device; ``'cuda'`` falls back to ``'cpu'`` when unavailable.

    Returns:
        The model with loaded weights, on ``device``, in eval mode.

    Raises:
        ValueError: If the checkpoint lacks ``"config"`` or ``"vocabs"``.
    """
    print(f"📂 Loading checkpoint from: {checkpoint_path}")
    if device == 'cuda' and not torch.cuda.is_available():
        device = 'cpu'

    # NOTE(security): torch.load unpickles the file (without restriction when
    # weights_only is off, the default in older PyTorch). Only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    cfg = checkpoint.get("config")
    vocabs = checkpoint.get("vocabs")

    if cfg is None or vocabs is None:
        raise ValueError("Checkpoint is missing config/vocabs. Ensure it is the new Hierarchical version.")

    # 1. Re-initialize the model from the saved config so layer shapes match the weights.
    model = HierarchicalCristianGPT(
        dir_vocab_size=len(vocabs["dir_vocab"]),
        depth_vocab_size=len(vocabs["depth_vocab"]),
        type_vocab_size=len(vocabs["type_vocab"]),
        num_players=len(vocabs["player_vocab"]),
        context_dim=cfg["context_dim"],
        embed_dim=cfg["embed_dim"],
        n_head=cfg["n_head"],
        n_cycles=cfg["n_cycles"]
    ).to(device)

    # 2. Load Weights
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # 3. Sync Dataset Vocabs so token ids agree with the ones the model was trained on.
    print("🔄 Syncing dataset vocabularies with checkpoint...")
    dataset.type_vocab   = vocabs["type_vocab"]
    dataset.dir_vocab    = vocabs["dir_vocab"]
    dataset.depth_vocab  = vocabs["depth_vocab"]
    dataset.player_vocab = vocabs["player_vocab"]

    # Inverse lookups (id -> token) rebuilt from the new vocabs.
    dataset.inv_type_vocab = {v: k for k, v in dataset.type_vocab.items()}
    dataset.inv_dir_vocab  = {v: k for k, v in dataset.dir_vocab.items()}
    dataset.inv_depth_vocab= {v: k for k, v in dataset.depth_vocab.items()}

    # 4. Re-decompose the data so already-built tensors (e.g. x_type/x_dir) are re-encoded
    # with the checkpoint's ids rather than the vocab the dataset was constructed with.
    dataset._decompose_data()

    print(f"✅ Model loaded and Dataset tensors refreshed. (Epoch {checkpoint.get('epoch', '?')})")
    return model

In [None]:
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from collections import Counter
import seaborn as sns
import pandas as pd
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, random_split

# --- 1. SHARED CONFIG & HELPERS ---
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda'

def get_inverse_vocabs(dataset):
    """Build id -> token lookups for the type, direction and depth vocabularies.

    Returns:
        A tuple ``(inv_type, inv_dir, inv_depth)``; each dict maps an integer id
        back to its key in ``dataset.type_vocab`` / ``dir_vocab`` / ``depth_vocab``.
    """
    vocabs = (dataset.type_vocab, dataset.dir_vocab, dataset.depth_vocab)
    inv_type, inv_dir, inv_depth = (dict(zip(vocab.values(), vocab.keys())) for vocab in vocabs)
    return inv_type, inv_dir, inv_depth

# --- 2. THE HIERARCHICAL EVALUATION FUNCTION ---
def run_full_evaluation(model, dataset, loader, test_indices, 
                        live_samples=5000, 
                        length_matches=2000, 
                        freq_matches=2000, 
                        era_matches=50, 
                        speed_samples=10000):
    
    model.eval()
    print(f"Starting Full Evaluation on {len(test_indices)} TEST SET samples...")
    
    # Helpers & IDs
    inv_type, inv_dir, inv_depth = get_inverse_vocabs(dataset)
    serve_type_id = dataset.type_vocab.get('serve', dataset.type_vocab.get('s', -1))
    unk_depth_id = dataset.depth_vocab.get('0', -1)
    unk_dir_id = dataset.dir_vocab.get('0', -1)

    # Pre-calculate Match-to-Index Map
    match_map = {}
    for idx in test_indices:
        mid = dataset.sample_match_ids[idx]
        match_map.setdefault(mid, []).append(idx)
    unique_matches = list(match_map.keys())
    
    # ==============================================================================
    # PART 1: OVERALL TACTICAL METRICS
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 1: OVERALL TACTICAL METRICS \n" + "="*40)
    
    all_p_type, all_t_type = [], []
    all_p_dir,  all_t_dir  = [], []
    all_p_depth, all_t_depth = [], []
    
    with torch.no_grad():
        for batch in loader:
            # Inputs
            x_z = batch['x_dir'].to(DEVICE)
            x_t = batch['x_type'].to(DEVICE)
            x_c = batch['context'].to(DEVICE)
            x_s = batch['x_s_id'].to(DEVICE)
            x_r = batch['x_r_id'].to(DEVICE)
            
            # Forward (3 Heads)
            l_dir, l_depth, l_type = model(x_z, x_t, x_c, x_s, x_r)
            
            # Targets
            y_type = batch['y_type'].to(DEVICE).view(-1)
            y_dir  = batch['y_dir'].to(DEVICE).view(-1)
            y_depth = batch['y_depth'].to(DEVICE).view(-1)

            # --- 1. Type (Predicts everything except PAD and SERVES) ---
            mask_t = (y_type != 0)
            if serve_type_id != -1: mask_t &= (y_type != serve_type_id)

            if mask_t.sum() > 0:
                all_p_type.extend(l_type.argmax(-1).view(-1)[mask_t].cpu().numpy())
                all_t_type.extend(y_type[mask_t].cpu().numpy())

            # --- 2. Direction (Mask PAD, UNKNOWN, and SERVES) ---
            mask_d = (y_dir != 0) 
            if unk_dir_id != -1: mask_d &= (y_dir != unk_dir_id)
            if serve_type_id != -1: mask_d &= (y_type != serve_type_id)
            
            if mask_d.sum() > 0:
                all_p_dir.extend(l_dir.argmax(-1).view(-1)[mask_d].cpu().numpy())
                all_t_dir.extend(y_dir[mask_d].cpu().numpy())

            # --- 3. Depth (Mask PAD, UNKNOWN, and SERVES) ---
            mask_dp = (y_depth != 0)
            if unk_depth_id != -1: mask_dp &= (y_depth != unk_depth_id)
            if serve_type_id != -1: mask_dp &= (y_type != serve_type_id)

            if mask_dp.sum() > 0:
                all_p_depth.extend(l_depth.argmax(-1).view(-1)[mask_dp].cpu().numpy())
                all_t_depth.extend(y_depth[mask_dp].cpu().numpy())

    # Reports
    print("\n=== DIRECTION REPORT (Excluding Serves) ===")
    labels = [k for k in dataset.dir_vocab if k not in ['<pad>', '0']]
    indices = [dataset.dir_vocab[k] for k in labels]
    if indices:
        print(classification_report(all_t_dir, all_p_dir, labels=indices, target_names=labels, zero_division=0))

    print("\n=== DEPTH REPORT (Excluding Unknowns) ===")
    labels = [k for k in dataset.depth_vocab if k not in ['<pad>', '0']]
    indices = [dataset.depth_vocab[k] for k in labels]
    print(classification_report(all_t_depth, all_p_depth, labels=indices, target_names=labels, zero_division=0))

    print("\n=== SHOT TYPE REPORT (Rally Only) ===")
    labels = [k for k in dataset.type_vocab if k not in ['<pad>', '<unk>', 'serve', 's']]
    indices = [dataset.type_vocab[k] for k in labels]
    # Filter to present
    present = sorted(list(set(all_t_type)))
    present_names = [inv_type[i] for i in present]
    if present:
        print(classification_report(all_t_type, all_p_type, labels=present, target_names=present_names, zero_division=0))

    # ==============================================================================
    # PART 2: LIVE SAMPLES (Showing {live_samples} Cases)
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 2: LIVE SAMPLES (Showing {live_samples} Cases) \n" + "="*40)
    
    selected_indices = random.sample(test_indices, min(live_samples * 2, len(test_indices)))
    results_buffer = {3: [], 2:[], 1:[], 0:[]}
    printed_count = 0
    
    with torch.no_grad():
        for idx in selected_indices:
            if printed_count >= live_samples: break
            
            sample = dataset[idx]
            non_zeros = (sample['x_type'] != 0).nonzero(as_tuple=True)[0]
            if len(non_zeros) < 2: continue
            
            # Predict a random point
            valid_indices = non_zeros.tolist()
            t = random.choice(valid_indices)
            
            x_z = sample['x_dir'].unsqueeze(0).to(DEVICE)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            
            l_dir, l_depth, l_type = model(x_z, x_t, x_c, x_s, x_r)
            
            # --- Build History String ---
            start_idx = valid_indices[0]
            history_str = ""
            for i in range(start_idx, t + 1):
                typ_idx = sample['x_type'][i].item()
                dir_idx = sample['x_dir'][i].item()
                t_in = inv_type.get(typ_idx, '?')
                z_in = inv_dir.get(dir_idx, '?')
                
                if i == start_idx:
                    history_str += f"[Serve {z_in}] " if t_in in ['serve', 's'] else f"[{t_in}{z_in}] "
                else:
                    history_str += f"-> {t_in}{z_in} "
            
            # --- Get Prediction ---
            probs = torch.softmax(l_type[0, t], dim=0)
            conf = probs.max().item() * 100
            
            pred_t = l_type[0, t].argmax().item()
            pred_d = l_dir[0, t].argmax().item()
            pred_dp = l_depth[0, t].argmax().item()
            
            true_t = sample['y_type'][t].item()
            true_d = sample['y_dir'][t].item()
            true_dp = sample['y_depth'][t].item()
            
            if true_t == 0: continue

            s_pred_d = inv_dir.get(pred_d, '?'); s_pred_t = inv_type.get(pred_t, '?')
            s_true_d = inv_dir.get(true_d, '?'); s_true_t = inv_type.get(true_t, '?')
            
            check_d = "‚úÖ" if pred_d == true_d else "‚ùå"
            check_t = "‚úÖ" if pred_t == true_t else "‚ùå"
            check_dp = "‚úÖ" if pred_dp == true_dp else "‚ùå"
            
            def d_lbl(x): return inv_depth.get(x, 'N/A')
            
            score = (1 if pred_d == true_d else 0) + (1 if pred_t == true_t else 0) + (1 if pred_dp == true_dp else 0)
            m_id = dataset.sample_match_ids[idx]

            out = []
            out.append(f"\nMatch {m_id}:")
            out.append(f"  History:    {history_str}")
            out.append(f"  Prediction: {s_pred_t} to {s_pred_d} (Depth {d_lbl(pred_dp)}) | Conf: {conf:.0f}%")
            out.append(f"  ACTUAL:     {s_true_t} to {s_true_d} (Depth {d_lbl(true_dp)}) | {check_t} Type {check_d} Dir {check_dp} Dep")
            
            results_buffer[score].append("\n".join(out))
            printed_count += 1

    print_flag = False
    
    for s in [3,2,1,0]:
        items = results_buffer[s]
        if items:
            print(f"\n{'='*20} {s}/3 CORRECT ({len(items)} cases) {'='*20}")
            if print_flag:
                for item in items: 
                    print(item)

    # ==============================================================================
    # PART 3: GRANULAR ACCURACY VS RALLY LENGTH (Restored 3 Graphs)
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 3: GRANULAR ACCURACY VS RALLY LENGTH \n" + "="*40)
    
    # 3.1 Calculate Baselines
    print("Calculating dataset baselines...")
    all_d, all_dp, all_tp = [], [], []
    for i in test_indices:
        yt = dataset[i]['y_type']
        yd = dataset[i]['y_dir']
        ydp = dataset[i]['y_depth']
        for j in range(len(yt)):
            if yt[j] == 0: continue
            is_srv = (yt[j] == serve_type_id)
            if not is_srv:
                all_tp.append(yt[j].item())
                if yd[j] != unk_dir_id and yd[j] != 0: all_d.append(yd[j].item())
                if ydp[j] != unk_depth_id and ydp[j] != 0: all_dp.append(ydp[j].item())

    def calc_baseline(data_list):
        if not data_list: return 0.33
        counts = Counter(data_list)
        total = sum(counts.values())
        return sum([(c/total)**2 for c in counts.values()])

    base_d = calc_baseline(all_d)
    base_dp = calc_baseline(all_dp)
    base_tp = calc_baseline(all_tp)
    base_pair_avg = (base_d*base_dp + base_d*base_tp + base_tp*base_dp) / 3
    base_whole = base_d * base_dp * base_tp
    print(f"Baselines -> Dir: {base_d:.2f}, Depth: {base_dp:.2f}, Type: {base_tp:.2f}, Whole: {base_whole:.4f}")
    
    # 3.2 Analysis Loop
    chosen_matches = random.sample(unique_matches, min(length_matches, len(unique_matches)))
    rl_indices = [idx for mid in chosen_matches for idx in match_map[mid]]
    print(f"Analyzing {len(rl_indices)} points from {len(chosen_matches)} matches...")

    results_p3 = []
    with torch.no_grad():
        for idx in rl_indices:
            sample = dataset[idx]
            x_z = sample['x_dir'].unsqueeze(0).to(DEVICE)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y_t = sample['y_type']
            y_d = sample['y_dir']
            y_dp = sample['y_depth']

            l_dir, l_depth, l_type = model(x_z, x_t, x_c, x_s, x_r)
            preds_t = l_type.argmax(dim=-1).squeeze(0)
            preds_d = l_dir.argmax(dim=-1).squeeze(0)
            preds_dp = l_depth.argmax(dim=-1).squeeze(0)
            
            x_seq_cpu = sample['x_type']
            
            for t_step in range(len(y_t)):
                if y_t[t_step] == 0: continue
                # Skip serves for rally analysis
                if y_t[t_step] == serve_type_id: continue

                shot_num = (x_seq_cpu[:t_step+1] != 0).sum().item() + 1
                
                pt, p_d, pdp = preds_t[t_step].item(), preds_d[t_step].item(), preds_dp[t_step].item()
                tt, td, tdp = y_t[t_step].item(), y_d[t_step].item(), y_dp[t_step].item()

                ok_t = (pt == tt)
                has_dir = (td != 0 and td != unk_dir_id)
                has_depth = (tdp != 0 and tdp != unk_depth_id)
                ok_d = (p_d == td) if has_dir else False
                ok_dp = (pdp == tdp) if has_depth else False

                # Logic Split - Single
                results_p3.append({'Shot_Number': shot_num, 'Task': 'Type', 'Type': 'Single', 'Accuracy': 1 if ok_t else 0})
                if has_dir:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Direction', 'Type': 'Single', 'Accuracy': 1 if ok_d else 0})
                if has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Depth', 'Type': 'Single', 'Accuracy': 1 if ok_dp else 0})
                
                # Logic Split - Pair
                if has_dir:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Dir + Type', 'Type': 'Pair', 'Accuracy': 1 if (ok_d and ok_t) else 0})
                if has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Type + Depth', 'Type': 'Pair', 'Accuracy': 1 if (ok_dp and ok_t) else 0})
                if has_dir and has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Dir + Depth', 'Type': 'Pair', 'Accuracy': 1 if (ok_d and ok_dp) else 0})
                
                # Logic Split - Whole
                if has_dir and has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Whole Shot', 'Type': 'Whole', 'Accuracy': 1 if (ok_d and ok_dp and ok_t) else 0})

    if results_p3:
        df = pd.DataFrame(results_p3)
        df = df[(df['Shot_Number'] <= 12) & (df['Shot_Number'] >= 2)]
        
        palette_single = {'Direction': '#1f77b4', 'Depth': '#d62728', 'Type': '#2ca02c'}
        palette_pair   = {'Dir + Depth': '#9467bd', 'Dir + Type': '#17becf', 'Type + Depth': '#ff7f0e'}
        palette_whole  = {'Whole Shot': '#000000'}

        def setup_plot(title, baseline, base_label):
            plt.figure(figsize=(12, 5))
            plt.title(title, fontsize=14)
            plt.ylabel('Accuracy', fontsize=12)
            plt.xlabel('Shot Number', fontsize=12)
            plt.xticks(np.arange(2, 13, 1))
            
            ax = plt.gca()
            ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.02))
            plt.grid(True, which='major', axis='y', linestyle='-', linewidth=0.75, color='grey', alpha=0.6)
            plt.grid(True, which='minor', axis='y', linestyle='--', linewidth=0.5, color='grey', alpha=0.3)
            plt.ylim(0.0, 1.0)
            
            if baseline:
                plt.axhline(baseline, color='#FF1493', linestyle=':', alpha=0.8, linewidth=2, label=base_label)

        # Graph 1: Single
        setup_plot('Single Task Accuracy vs. Rally Length', base_tp, f'Random Type ({base_tp:.2f})')
        sns.lineplot(data=df[df['Type']=='Single'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes=False, palette=palette_single, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='lower right'); plt.show()
        
        # Graph 2: Pairwise
        setup_plot('Pairwise Accuracy vs. Rally Length', base_pair_avg, f'Random Pair (~{base_pair_avg:.2f})')
        sns.lineplot(data=df[df['Type']=='Pair'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes=False, palette=palette_pair, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

        # Graph 3: Whole Shot
        setup_plot('Whole Shot Accuracy vs. Rally Length', base_whole, f'Random Whole ({base_whole:.4f})')
        sns.lineplot(data=df[df['Type']=='Whole'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes={'Whole Shot':(2,2)}, palette=palette_whole, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

    # ==============================================================================
    # PART 4: PLAYER FREQUENCY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 4: PLAYER FREQUENCY ({freq_matches} Matches) \n" + "="*40)
    chosen_matches = random.sample(unique_matches, min(freq_matches, len(unique_matches)))
    pf_indices = [idx for mid in chosen_matches for idx in match_map[mid]]
    
    p_counts = Counter()
    p_stats = {}
    
    with torch.no_grad():
        for idx in pf_indices:
            sample = dataset[idx]
            x_z = sample['x_dir'].unsqueeze(0).to(DEVICE)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y_t = sample['y_type']
            
            _, _, l_t = model(x_z, x_t, x_c, x_s, x_r)
            preds = l_t.argmax(-1).squeeze(0)
            
            s_val, r_val = sample['x_s_id'].item(), sample['x_r_id'].item()
            x_seq_cpu = sample['x_type']
            
            for t in range(len(y_t)):
                if y_t[t] == 0: continue
                if y_t[t] == serve_type_id: continue # Skip serves

                hist_len = (x_seq_cpu[:t+1] != 0).sum().item()
                actor = s_val if (hist_len + 1) % 2 != 0 else r_val
                if actor <= 1: continue
                
                p_counts[actor] += 1
                if actor not in p_stats: p_stats[actor] = {'tot': 0, 'corr': 0}
                p_stats[actor]['tot'] += 1
                if preds[t].item() == y_t[t].item(): p_stats[actor]['corr'] += 1

    pf_data = [{'Freq': p_counts[a], 'Err': (1 - v['corr']/v['tot'])*100} for a, v in p_stats.items() if p_counts[a] > 10]
    if pf_data:
        df_pf = pd.DataFrame(pf_data)
        plt.figure(figsize=(8, 4))
        sns.regplot(data=df_pf, x='Freq', y='Err', scatter_kws={'alpha':0.5}, line_kws={'color':'red'})
        plt.xscale('log'); plt.title(f"Type Error vs Frequency (Corr: {df_pf['Freq'].corr(df_pf['Err']):.2f})"); plt.show()

    # ==============================================================================
    # PART 5: ERA STABILITY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 5: ERA STABILITY ({era_matches} Matches/Era) \n" + "="*40)
    eras = {'Pre-2010': [], '2010-2019': [], '2020+': []}
    for m_id in unique_matches:
        try: y_year = int(str(m_id)[:4])
        except: continue
        if y_year < 2010: eras['Pre-2010'].append(m_id)
        elif y_year < 2020: eras['2010-2019'].append(m_id)
        else: eras['2020+'].append(m_id)

    era_indices = []
    era_labels_list = []
    for era_name, m_list in eras.items():
        if not m_list: continue
        chosen = random.sample(m_list, min(era_matches, len(m_list)))
        for m in chosen:
            era_indices.extend(match_map[m])
            era_labels_list.extend([era_name]*len(match_map[m]))
            
    era_res = []
    with torch.no_grad():
        for i, idx in enumerate(era_indices):
            sample = dataset[idx]
            x_z = sample['x_dir'].unsqueeze(0).to(DEVICE)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y_t = sample['y_type'].to(DEVICE)
            y_d = sample['y_dir'].to(DEVICE)
            y_dp = sample['y_depth'].to(DEVICE)
            
            l_d, l_dp, l_t = model(x_z, x_t, x_c, x_s, x_r)
            p_d = l_d.argmax(-1).squeeze(0)
            p_t = l_t.argmax(-1).squeeze(0)
            p_dp = l_dp.argmax(-1).squeeze(0)
            
            mask = (y_t > 1) & (y_t != serve_type_id) & (y_d != unk_dir_id) & (y_dp != unk_depth_id)
            if mask.sum() > 0:
                correct = (p_t[mask] == y_t[mask]) & (p_d[mask] == y_d[mask]) & (p_dp[mask] == y_dp[mask])
                acc = correct.float().mean().item()
                era_res.append({'Era': era_labels_list[i], 'Whole Shot Acc': acc})
    
    if era_res:
        plt.figure(figsize=(6, 4))
        sns.barplot(data=pd.DataFrame(era_res), x='Era', y='Whole Shot Acc', palette='viridis', order=['Pre-2010', '2010-2019', '2020+'])
        plt.title('Whole Shot Accuracy by Era'); plt.ylim(0, 1); plt.show()

    # ==============================================================================
    # PART 6: RAW ERROR ANALYSIS BY SURFACE (Masked Depth)
    # ==============================================================================
    print("\n" + "="*50)
    print(" RAW ERROR ANALYSIS BY SURFACE (HierarchicalCristianGPT) ")
    print("="*50)

    # 1. Group Test Indices by Surface
    surface_map = {'Clay': [], 'Hard': [], 'Grass': []}
    for idx in test_indices:
        surf = dataset.match_meta.get(dataset.sample_match_ids[idx], {}).get('surface', 'Hard')
        found = False
        for k in surface_map: 
            if k in surf: 
                surface_map[k].append(idx)
                found = True
                break
        if not found: surface_map['Hard'].append(idx)
            
    # 2. Select Samples Balanced by Surface
    selected_indices, surface_labels = [], []
    per_surf = speed_samples // 3 
    
    for s, inds in surface_map.items():
        if not inds: continue
        chosen = random.sample(inds, min(len(inds), per_surf))
        selected_indices.extend(chosen)
        surface_labels.extend([s]*len(chosen))
        
    # 3. Evaluation Loop
    results = []
    with torch.no_grad():
        for i, idx in enumerate(selected_indices):
            sample = dataset[idx]
            surf = surface_labels[i]
            
            x_dir  = sample['x_dir'].unsqueeze(0).to(DEVICE)
            x_type = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_c    = sample['context'].unsqueeze(0).to(DEVICE)
            x_s    = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r    = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            
            y_t_gt = sample['y_type'].to(DEVICE)
            y_d_gt = sample['y_dir'].to(DEVICE)
            y_dp_gt = sample['y_depth'].to(DEVICE)
            
            logits_dir, logits_depth, logits_type_out = model(x_dir, x_type, x_c, x_s, x_r)
            
            pred_t = logits_type_out.argmax(dim=-1).squeeze(0)
            pred_d = logits_dir.argmax(dim=-1).squeeze(0)
            pred_dp = logits_depth.argmax(dim=-1).squeeze(0)
            
            seq_len = x_dir.shape[1]
            for t in range(seq_len):
                if y_t_gt[t] == 0: continue 
                
                p_t, p_d, p_dp = pred_t[t].item(), pred_d[t].item(), pred_dp[t].item()
                t_t, t_d, t_dp = y_t_gt[t].item(), y_d_gt[t].item(), y_dp_gt[t].item()
                
                whole_shot_miss = (p_t != t_t) or (p_d != t_d) or (p_dp != t_dp)

                # --- DEPTH MASKING FIX ---
                # Only calculate Depth Error if the target actually HAS depth (is not 0)
                if t_dp != 0:
                    depth_err = 1.0 if p_dp != t_dp else 0.0
                else:
                    depth_err = None 

                results.append({
                    'Surface': surf,
                    'Type Error': 1.0 if p_t != t_t else 0.0,
                    'Direction Error': 1.0 if p_d != t_d else 0.0,
                    'Depth Error': depth_err,  # <--- Masked
                    'Whole Shot Error': 1.0 if whole_shot_miss else 0.0
                })
                
    # 4. Statistics & Plotting
    if not results:
        print("No results generated.")
        return

    df = pd.DataFrame(results)
    
    # Print Table (Mean Error %) - Pandas ignores None in mean()
    stats = df.groupby('Surface')[['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error']].mean() * 100
    print("\n--- Mean Error Rates (%) [Depth calc only on non-zero targets] ---")
    print(stats.round(2))
    
    # Plotting
    df_melt = df.melt(id_vars=['Surface'], 
                      value_vars=['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error'], 
                      value_name='Error Rate')
    
    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_melt, x='Surface', y='Error Rate', hue='variable', 
                order=['Clay', 'Hard', 'Grass'], palette='viridis')
    plt.title('Error Rates by Component (Masked Depth) vs. Whole Shot')
    plt.ylabel('Error Rate (0.0 - 1.0)')
    plt.legend(title='Metric')
    plt.grid(axis='y', alpha=0.3)
    plt.show()
    
# --- 3. RUNNER SNIPPET ---
# Rebuilds the hierarchical dataset, loads a trained checkpoint and runs the evaluation
# function over a re-derived held-out split.
# NOTE(review): presumably uses the evaluation function defined earlier in this cell; a later
# cell redefines run_full_evaluation with a different model signature — TODO confirm run order.
print("Recreating validation/test split for evaluation...")
# NOTE(review): hard-coded Kaggle input path — adjust when running outside that environment.
checkpoint_path="/kaggle/input/hierarchical-best/other/default/1/hierarchical_best.pth"

# Relies on point_files, matches_path, atp_path and wta_path having been set by an earlier cell.
dataset = HierarchicalTennisDataset(point_files, matches_path, atp_path, wta_path, max_seq_len=SEQ_LEN)

hierarchical_model = load_standalone_checkpoint(checkpoint_path, dataset, device=DEVICE)

# Re-seed the global RNGs so the random.sample(...) calls inside the evaluation are reproducible.
seed_everything(42)

# 80 / 10 / 10 split.
# NOTE(review): the training routines visible in this notebook (train_singlehead_baseline,
# train_hybrid_model) split 80 / 15 / 5. random_split only reproduces the training-time test
# set when BOTH the lengths and the generator seed match the ones used at training time;
# otherwise this "test" set can overlap training/validation samples. TODO confirm against the
# cell that trained hierarchical_best.pth.
total_len = len(dataset)
train_len = int(0.8 * total_len)
val_len = int(0.1 * total_len)
test_len = total_len - train_len - val_len

# Dedicated generator so the split does not depend on the global torch RNG state.
gen = torch.Generator().manual_seed(42)
_, _, test_ds = random_split(dataset, [train_len, val_len, test_len], generator=gen)

# Raw dataset indices of the test subset; the evaluation uses them to look up match ids.
test_indices = test_ds.indices
test_loader_eval = DataLoader(test_ds, batch_size=64, shuffle=False)

run_full_evaluation(
    model=hierarchical_model, 
    dataset=dataset, 
    loader=test_loader_eval, 
    test_indices=test_indices,
    live_samples=5000, 
    length_matches=2000,
    freq_matches=2000
)

# HYBRID

## MULTI-HEAD RICH LSTM

### DOWNSAMPLED DATASET TO COMPARE MEN AND WOMEN

### MODEL

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import copy

# --- 2. Helper to Calculate Accuracy ---
def calculate_accuracy(l_type, l_dir, l_depth, y_type, y_dir, y_depth, pad_idx=0):
    """
    Compute the accuracy of each of the three heads (type, direction, depth) separately.

    For every head the prediction is the argmax over the last (class) dimension of the
    logits; positions whose target equals ``pad_idx`` (padding, 0 by default) are ignored.

    Args:
        l_type, l_dir, l_depth: logits for the three heads, class scores on the last dim.
        y_type, y_dir, y_depth: integer targets matching the logits without the class dim.
        pad_idx: target value treated as padding and excluded from the accuracy.

    Returns:
        Tuple ``(acc_type, acc_dir, acc_depth)`` of Python floats in [0, 1]. A head with no
        non-padding targets yields 0.0. The three values are NOT averaged here; callers
        combine them as they see fit.
    """
    def _masked_accuracy(logits, targets):
        # Fraction of non-padding positions whose argmax prediction matches the target.
        mask = targets != pad_idx
        if not mask.any():
            return 0.0
        preds = logits.argmax(dim=-1)
        return (preds[mask] == targets[mask]).float().mean().item()

    return (
        _masked_accuracy(l_type, y_type),
        _masked_accuracy(l_dir, y_dir),
        _masked_accuracy(l_depth, y_depth),
    )

def _hybrid_batch_step(model, batch, criterion, device):
    """Forward one batch through the hybrid model and score it.

    Moves the batch tensors to ``device``, runs the model and returns
    ``(loss, (acc_type, acc_dir, acc_depth))``. ``loss`` is the unweighted sum of the three
    heads' losses and stays attached to the autograd graph when grad is enabled.
    """
    x_t, x_d, x_dp = (batch[k].to(device) for k in ('x_type', 'x_dir', 'x_depth'))
    s_id, r_id, ctx = (batch[k].to(device) for k in ('x_s_id', 'x_r_id', 'context'))
    y_t, y_d, y_dp = (batch[k].to(device) for k in ('y_type', 'y_dir', 'y_depth'))

    l_t, l_d, l_dp = model(x_t, x_d, x_dp, s_id, r_id, ctx)

    # Flatten all leading dims so the criterion sees (N, num_classes) vs (N,). The class count
    # is read from the logits themselves, and reshape (unlike view) also accepts
    # non-contiguous model outputs.
    loss = sum(
        criterion(logits.reshape(-1, logits.size(-1)), target.reshape(-1))
        for logits, target in ((l_t, y_t), (l_d, y_d), (l_dp, y_dp))
    )
    return loss, calculate_accuracy(l_t, l_d, l_dp, y_t, y_d, y_dp)


def train_hybrid_model(dataset, epochs=10, batch_size=64, lr=1e-3, device='cuda',
                       patience=5, checkpoint_path='hybrid_rich_lstm.pth', context_dim=10):
    """Train a HybridRichLSTM (parallel type / direction / depth heads) on ``dataset``.

    The dataset is split 80 / 15 / 5 with a fixed generator seed (42); the final 5 % is not
    used here. Each epoch trains on the 80 % split, validates on the 15 % split and prints the
    loss plus the mean of the three per-head accuracies.

    Args:
        dataset: dataset whose items are dicts with the keys read in ``_hybrid_batch_step``
            and which exposes ``player_vocab``, ``type_vocab``, ``dir_vocab`` and ``depth_vocab``.
        epochs: maximum number of epochs.
        batch_size: DataLoader batch size for both splits.
        lr: AdamW learning rate.
        device: device the model and every batch are moved to.
        patience: stop after this many consecutive epochs without a lower validation loss.
        checkpoint_path: file the best ``state_dict`` is written to whenever validation
            loss improves.
        context_dim: size of the context vector passed to ``HybridRichLSTM``.

    Returns:
        The model with the weights from the best validation epoch loaded.

    Raises:
        ValueError: if the dataset is too small for the train or validation split to hold
            at least one sample (the per-epoch averages would otherwise divide by zero).
    """
    print("--- STARTING HYBRID (PARALLEL HEADS) TRAINING ---")

    # 80 / 15 / 5 split with its own generator so it is independent of global RNG state.
    train_len = int(0.8 * len(dataset))
    val_len = int(0.15 * len(dataset))
    test_len = len(dataset) - train_len - val_len
    if train_len == 0 or val_len == 0:
        raise ValueError(
            f"Dataset of size {len(dataset)} is too small for an 80/15/5 train/val/test split"
        )

    train_ds, val_ds, _ = torch.utils.data.random_split(
        dataset, [train_len, val_len, test_len],
        generator=torch.Generator().manual_seed(42)
    )

    train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    model = HybridRichLSTM(
        num_players=len(dataset.player_vocab),
        type_vocab_size=len(dataset.type_vocab),
        dir_vocab_size=len(dataset.dir_vocab),
        depth_vocab_size=len(dataset.depth_vocab),
        context_dim=context_dim
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr)

    # Target 0 is padding and contributes nothing to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)

    # Early stopping & checkpointing state
    best_val_loss = float('inf')
    patience_counter = 0
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        # --- TRAINING PHASE ---
        model.train()
        train_loss = 0.0
        train_acc_sum = 0.0
        for batch in train_loader:
            optimizer.zero_grad()
            loss, (at, ad, adp) = _hybrid_batch_step(model, batch, criterion, device)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_acc_sum += at + ad + adp

        avg_train_loss = train_loss / len(train_loader)
        avg_t_acc = train_acc_sum / (3 * len(train_loader))  # mean over batches and 3 heads

        # --- VALIDATION PHASE ---
        model.eval()
        val_loss = 0.0
        val_acc_sum = 0.0
        with torch.no_grad():
            for batch in val_loader:
                loss, (at, ad, adp) = _hybrid_batch_step(model, batch, criterion, device)
                val_loss += loss.item()
                val_acc_sum += at + ad + adp

        avg_val_loss = val_loss / len(val_loader)
        avg_v_acc = val_acc_sum / (3 * len(val_loader))

        # --- REPORTING ---
        print(f"Epoch {epoch+1:02d} | "
              f"Train Loss: {avg_train_loss:.4f} Acc: {avg_t_acc:.2%} | "
              f"Val Loss: {avg_val_loss:.4f} Acc: {avg_v_acc:.2%}")

        # --- EARLY STOPPING & CHECKPOINT ---
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            patience_counter = 0
            # Overwrite the checkpoint every time validation loss improves.
            torch.save(best_model_wts, checkpoint_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs.")
                break

    print("Loading best model weights...")
    model.load_state_dict(best_model_wts)
    return model

In [None]:
# Player metadata files (both are passed to the dataset regardless of tour).
atp_path = '/kaggle/input/atp-players/atp_players.csv'
wta_path = '/kaggle/input/wta-players/wta_players.csv'

# Which tour's charting data to train on: 'atp' (men) or 'wta' (women).
TOUR = 'atp'

# Kaggle input locations of each tour's point-by-point files and match metadata.
TOUR_SOURCES = {
    'atp': {
        'points_dir': '/kaggle/input/atp-points/',
        'points_prefix': 'charting-m-points',
        'matches': '/kaggle/input/atp-matches-updated/charting-m-matches-updated.csv',
    },
    'wta': {
        'points_dir': '/kaggle/input/wta-points/',
        'points_prefix': 'charting-w-points',
        'matches': '/kaggle/input/wta-matches/charting-w-matches.csv',
    },
}

_tour_src = TOUR_SOURCES[TOUR]
base_path = _tour_src['points_dir']

# List all point files to merge (same order as before: 2020s, 2010s, up to 2009).
point_files = [
    base_path + f"{_tour_src['points_prefix']}-{suffix}.csv"
    for suffix in ('2020s', '2010s', 'to-2009')
]

matches_path = _tour_src['matches']

dataset = DownsampledHierarchical(point_files, matches_path, atp_path, wta_path, max_seq_len=SEQ_LEN) 

# Use the notebook-wide DEVICE (falls back to CPU) rather than hard-coding 'cuda'.
rich_hybrid_LSTM = train_hybrid_model(dataset, epochs=15, batch_size=64, lr=1e-3, device=DEVICE)

### TEST SET EVALUATION

In [None]:
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from collections import Counter
import seaborn as sns
import pandas as pd
import numpy as np
import random
import torch
from torch.utils.data import DataLoader, random_split, Subset

# --- 1. SHARED CONFIG & HELPERS ---
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

def get_inverse_vocabs(dataset):
    """Build index -> token lookups for the dataset's hierarchical vocabularies.

    Returns a tuple ``(inv_type, inv_dir, inv_depth)``, each the inverse of
    ``dataset.type_vocab``, ``dataset.dir_vocab`` and ``dataset.depth_vocab`` respectively.
    """
    def _invert(vocab):
        # Swap each (token, index) pair into an index-keyed dict.
        return dict((index, token) for token, index in vocab.items())

    vocabs = (dataset.type_vocab, dataset.dir_vocab, dataset.depth_vocab)
    return tuple(_invert(vocab) for vocab in vocabs)

# Helper for seed
def seed_everything(seed: int = 42) -> None:
    """Seed every RNG this notebook relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch (CPU and all CUDA
    devices), and exports ``PYTHONHASHSEED``.

    Args:
        seed: seed value. Defaults to 42, like the earlier definition of this helper in the
            notebook, so argument-less calls keep working after this cell redefines it.
    """
    import os  # local import so this cell does not depend on an earlier import cell

    random.seed(seed)
    # Only influences interpreters started after this call (e.g. spawned worker processes);
    # the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without a GPU: PyTorch silently ignores it when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

# --- 2. THE HIERARCHICAL EVALUATION FUNCTION (ADAPTED) ---
def run_full_evaluation(model, dataset, loader, test_indices, 
                        live_samples=5000, 
                        length_matches=2000, 
                        freq_matches=2000, 
                        era_matches=50, 
                        speed_samples=10000):
    
    model.eval()
    print(f"Starting Full Evaluation on {len(test_indices)} TEST SET samples...")
    
    # Helpers & IDs
    inv_type, inv_dir, inv_depth = get_inverse_vocabs(dataset)
    serve_type_id = dataset.type_vocab.get('serve', dataset.type_vocab.get('s', -1))
    unk_depth_id = dataset.depth_vocab.get('0', -1)
    unk_dir_id = dataset.dir_vocab.get('0', -1)

    # Pre-calculate Match-to-Index Map
    match_map = {}
    for idx in test_indices:
        mid = dataset.sample_match_ids[idx]
        match_map.setdefault(mid, []).append(idx)
    unique_matches = list(match_map.keys())
    
    # ==============================================================================
    # PART 1: OVERALL TACTICAL METRICS
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 1: OVERALL TACTICAL METRICS \n" + "="*40)
    
    all_p_type, all_t_type = [], []
    all_p_dir,  all_t_dir  = [], []
    all_p_depth, all_t_depth = [], []
    
    with torch.no_grad():
        for batch in loader:
            # Inputs (ADAPTED to include depth and match new model input names/order)
            x_d = batch['x_dir'].to(DEVICE) # Direction Input
            x_t = batch['x_type'].to(DEVICE) # Type Input
            x_dp = batch['x_depth'].to(DEVICE) # DEPTH Input (New)
            x_c = batch['context'].to(DEVICE)
            x_s = batch['x_s_id'].to(DEVICE)
            x_r = batch['x_r_id'].to(DEVICE)
            
            # Forward (3 Heads) - Model returns (Type, Dir, Depth)
            l_type_pred, l_dir_pred, l_depth_pred = model(x_t, x_d, x_dp, x_s, x_r, x_c)
            
            # Re-assign to match original variable names in evaluation logic
            l_type = l_type_pred
            l_dir = l_dir_pred
            l_depth = l_depth_pred
            
            # Targets
            y_type = batch['y_type'].to(DEVICE).view(-1)
            y_dir  = batch['y_dir'].to(DEVICE).view(-1)
            y_depth = batch['y_depth'].to(DEVICE).view(-1)

            # --- 1. Type (Predicts everything except PAD and SERVES) ---
            mask_t = (y_type != 0)
            if serve_type_id != -1: mask_t &= (y_type != serve_type_id)

            if mask_t.sum() > 0:
                all_p_type.extend(l_type.argmax(-1).view(-1)[mask_t].cpu().numpy())
                all_t_type.extend(y_type[mask_t].cpu().numpy())

            # --- 2. Direction (Mask PAD, UNKNOWN, and SERVES) ---
            mask_d = (y_dir != 0) 
            if unk_dir_id != -1: mask_d &= (y_dir != unk_dir_id)
            if serve_type_id != -1: mask_d &= (y_type != serve_type_id)
            
            if mask_d.sum() > 0:
                all_p_dir.extend(l_dir.argmax(-1).view(-1)[mask_d].cpu().numpy())
                all_t_dir.extend(y_dir[mask_d].cpu().numpy())

            # --- 3. Depth (Mask PAD, UNKNOWN, and SERVES) ---
            mask_dp = (y_depth != 0)
            if unk_depth_id != -1: mask_dp &= (y_depth != unk_depth_id)
            if serve_type_id != -1: mask_dp &= (y_type != serve_type_id)

            if mask_dp.sum() > 0:
                all_p_depth.extend(l_depth.argmax(-1).view(-1)[mask_dp].cpu().numpy())
                all_t_depth.extend(y_depth[mask_dp].cpu().numpy())

    # Reports
    print("\n=== DIRECTION REPORT (Excluding Serves) ===")
    labels = [k for k in dataset.dir_vocab if k not in ['<pad>', '0']]
    indices = [dataset.dir_vocab[k] for k in labels]
    if indices:
        print(classification_report(all_t_dir, all_p_dir, labels=indices, target_names=labels, zero_division=0))

    print("\n=== DEPTH REPORT (Excluding Unknowns) ===")
    labels = [k for k in dataset.depth_vocab if k not in ['<pad>', '0']]
    indices = [dataset.depth_vocab[k] for k in labels]
    print(classification_report(all_t_depth, all_p_depth, labels=indices, target_names=labels, zero_division=0))

    print("\n=== SHOT TYPE REPORT (Rally Only) ===")
    labels = [k for k in dataset.type_vocab if k not in ['<pad>', '<unk>', 'serve', 's']]
    indices = [dataset.type_vocab[k] for k in labels]
    # Filter to present
    present = sorted(list(set(all_t_type)))
    present_names = [inv_type[i] for i in present]
    if present:
        print(classification_report(all_t_type, all_p_type, labels=present, target_names=present_names, zero_division=0))

    # ==============================================================================
    # PART 2: LIVE SAMPLES (Showing {live_samples} Cases)
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 2: LIVE SAMPLES (Showing {live_samples} Cases) \n" + "="*40)
    
    selected_indices = random.sample(test_indices, min(live_samples * 2, len(test_indices)))
    results_buffer = {3: [], 2:[], 1:[], 0:[]}
    printed_count = 0
    
    with torch.no_grad():
        for idx in selected_indices:
            if printed_count >= live_samples: break
            
            sample = dataset[idx]
            non_zeros = (sample['x_type'] != 0).nonzero(as_tuple=True)[0]
            if len(non_zeros) < 2: continue
            
            # Predict a random point
            valid_indices = non_zeros.tolist()
            t = random.choice(valid_indices)
            
            # Inputs (ADAPTED)
            x_d = sample['x_dir'].unsqueeze(0).to(DEVICE) # Direction Input (was x_z)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_dp = sample['x_depth'].unsqueeze(0).to(DEVICE) # DEPTH Input (New)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            
            # Forward (3 Heads) - Model returns (Type, Dir, Depth)
            l_type_pred, l_dir_pred, l_depth_pred = model(x_t, x_d, x_dp, x_s, x_r, x_c)
            
            # Re-assign to match original variable names
            l_type = l_type_pred
            l_dir = l_dir_pred
            l_depth = l_depth_pred
            
            # ... [Rest of Part 2 logic remains the same]
            # --- Build History String ---
            start_idx = valid_indices[0]
            history_str = ""
            for i in range(start_idx, t + 1):
                typ_idx = sample['x_type'][i].item()
                dir_idx = sample['x_dir'][i].item()
                t_in = inv_type.get(typ_idx, '?')
                z_in = inv_dir.get(dir_idx, '?')
                
                if i == start_idx:
                    history_str += f"[Serve {z_in}] " if t_in in ['serve', 's'] else f"[{t_in}{z_in}] "
                else:
                    history_str += f"-> {t_in}{z_in} "
            
            # --- Get Prediction ---
            probs = torch.softmax(l_type[0, t], dim=0)
            conf = probs.max().item() * 100
            
            pred_t = l_type[0, t].argmax().item()
            pred_d = l_dir[0, t].argmax().item()
            pred_dp = l_depth[0, t].argmax().item()
            
            true_t = sample['y_type'][t].item()
            true_d = sample['y_dir'][t].item()
            true_dp = sample['y_depth'][t].item()
            
            if true_t == 0: continue

            s_pred_d = inv_dir.get(pred_d, '?'); s_pred_t = inv_type.get(pred_t, '?')
            s_true_d = inv_dir.get(true_d, '?'); s_true_t = inv_type.get(true_t, '?')
            
            check_d = "‚úÖ" if pred_d == true_d else "‚ùå"
            check_t = "‚úÖ" if pred_t == true_t else "‚ùå"
            check_dp = "‚úÖ" if pred_dp == true_dp else "‚ùå"
            
            def d_lbl(x): return inv_depth.get(x, 'N/A')
            
            score = (1 if pred_d == true_d else 0) + (1 if pred_t == true_t else 0) + (1 if pred_dp == true_dp else 0)
            m_id = dataset.sample_match_ids[idx]

            out = []
            out.append(f"\nMatch {m_id}:")
            out.append(f"  History:    {history_str}")
            out.append(f"  Prediction: {s_pred_t} to {s_pred_d} (Depth {d_lbl(pred_dp)}) | Conf: {conf:.0f}%")
            out.append(f"  ACTUAL:     {s_true_t} to {s_true_d} (Depth {d_lbl(true_dp)}) | {check_t} Type {check_d} Dir {check_dp} Dep")
            
            results_buffer[score].append("\n".join(out))
            printed_count += 1

    print_flag = False
    
    for s in [3,2,1,0]:
        items = results_buffer[s]
        if items:
            print(f"\n{'='*20} {s}/3 CORRECT ({len(items)} cases) {'='*20}")
            if print_flag:
                for item in items: 
                    print(item)

    # ==============================================================================
    # PART 3: GRANULAR ACCURACY VS RALLY LENGTH (Restored 3 Graphs)
    # ==============================================================================
    print("\n" + "="*40 + "\n PART 3: GRANULAR ACCURACY VS RALLY LENGTH \n" + "="*40)
    
    # 3.1 Calculate Baselines
    print("Calculating dataset baselines...")
    all_d, all_dp, all_tp = [], [], []
    for i in test_indices:
        yt = dataset[i]['y_type']
        yd = dataset[i]['y_dir']
        ydp = dataset[i]['y_depth']
        for j in range(len(yt)):
            if yt[j] == 0: continue
            is_srv = (yt[j] == serve_type_id)
            if not is_srv:
                all_tp.append(yt[j].item())
                if yd[j] != unk_dir_id and yd[j] != 0: all_d.append(yd[j].item())
                if ydp[j] != unk_depth_id and ydp[j] != 0: all_dp.append(ydp[j].item())

    def calc_baseline(data_list):
        if not data_list: return 0.33
        counts = Counter(data_list)
        total = sum(counts.values())
        return sum([(c/total)**2 for c in counts.values()])

    base_d = calc_baseline(all_d)
    base_dp = calc_baseline(all_dp)
    base_tp = calc_baseline(all_tp)
    base_pair_avg = (base_d*base_dp + base_d*base_tp + base_tp*base_dp) / 3
    base_whole = base_d * base_dp * base_tp
    print(f"Baselines -> Dir: {base_d:.2f}, Depth: {base_dp:.2f}, Type: {base_tp:.2f}, Whole: {base_whole:.4f}")
    
    # 3.2 Analysis Loop
    chosen_matches = random.sample(unique_matches, min(length_matches, len(unique_matches)))
    rl_indices = [idx for mid in chosen_matches for idx in match_map[mid]]
    print(f"Analyzing {len(rl_indices)} points from {len(chosen_matches)} matches...")

    results_p3 = []
    with torch.no_grad():
        for idx in rl_indices:
            sample = dataset[idx]
            
            # Inputs (ADAPTED)
            x_d = sample['x_dir'].unsqueeze(0).to(DEVICE) # Direction Input (was x_z)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_dp = sample['x_depth'].unsqueeze(0).to(DEVICE) # DEPTH Input (New)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            
            y_t = sample['y_type']
            y_d = sample['y_dir']
            y_dp = sample['y_depth']

            # Forward (3 Heads) - Model returns (Type, Dir, Depth)
            l_type_pred, l_dir_pred, l_depth_pred = model(x_t, x_d, x_dp, x_s, x_r, x_c)
            
            # Prediction assignment (ADAPTED)
            preds_t = l_type_pred.argmax(dim=-1).squeeze(0)
            preds_d = l_dir_pred.argmax(dim=-1).squeeze(0)
            preds_dp = l_depth_pred.argmax(dim=-1).squeeze(0)
            
            x_seq_cpu = sample['x_type']
            
            for t_step in range(len(y_t)):
                if y_t[t_step] == 0: continue
                # Skip serves for rally analysis
                if y_t[t_step] == serve_type_id: continue

                shot_num = (x_seq_cpu[:t_step+1] != 0).sum().item() + 1
                
                pt, p_d, pdp = preds_t[t_step].item(), preds_d[t_step].item(), preds_dp[t_step].item()
                tt, td, tdp = y_t[t_step].item(), y_d[t_step].item(), y_dp[t_step].item()

                ok_t = (pt == tt)
                has_dir = (td != 0 and td != unk_dir_id)
                has_depth = (tdp != 0 and tdp != unk_depth_id)
                ok_d = (p_d == td) if has_dir else False
                ok_dp = (pdp == tdp) if has_depth else False

                # Logic Split - Single
                results_p3.append({'Shot_Number': shot_num, 'Task': 'Type', 'Type': 'Single', 'Accuracy': 1 if ok_t else 0})
                if has_dir:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Direction', 'Type': 'Single', 'Accuracy': 1 if ok_d else 0})
                if has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Depth', 'Type': 'Single', 'Accuracy': 1 if ok_dp else 0})
                
                # Logic Split - Pair
                if has_dir:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Dir + Type', 'Type': 'Pair', 'Accuracy': 1 if (ok_d and ok_t) else 0})
                if has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Type + Depth', 'Type': 'Pair', 'Accuracy': 1 if (ok_dp and ok_t) else 0})
                if has_dir and has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Dir + Depth', 'Type': 'Pair', 'Accuracy': 1 if (ok_d and ok_dp) else 0})
                
                # Logic Split - Whole
                if has_dir and has_depth:
                    results_p3.append({'Shot_Number': shot_num, 'Task': 'Whole Shot', 'Type': 'Whole', 'Accuracy': 1 if (ok_d and ok_dp and ok_t) else 0})

    if results_p3:
        df = pd.DataFrame(results_p3)
        df = df[(df['Shot_Number'] <= 12) & (df['Shot_Number'] >= 2)]
        
        palette_single = {'Direction': '#1f77b4', 'Depth': '#d62728', 'Type': '#2ca02c'}
        palette_pair   = {'Dir + Depth': '#9467bd', 'Dir + Type': '#17becf', 'Type + Depth': '#ff7f0e'}
        palette_whole  = {'Whole Shot': '#000000'}

        def setup_plot(title, baseline, base_label):
            plt.figure(figsize=(12, 5))
            plt.title(title, fontsize=14)
            plt.ylabel('Accuracy', fontsize=12)
            plt.xlabel('Shot Number', fontsize=12)
            plt.xticks(np.arange(2, 13, 1))
            
            ax = plt.gca()
            ax.yaxis.set_major_locator(ticker.MultipleLocator(0.1))
            ax.yaxis.set_minor_locator(ticker.MultipleLocator(0.02))
            plt.grid(True, which='major', axis='y', linestyle='-', linewidth=0.75, color='grey', alpha=0.6)
            plt.grid(True, which='minor', axis='y', linestyle='--', linewidth=0.5, color='grey', alpha=0.3)
            plt.ylim(0.0, 1.0)
            
            if baseline:
                plt.axhline(baseline, color='#FF1493', linestyle=':', alpha=0.8, linewidth=2, label=base_label)

        # Graph 1: Single
        setup_plot('Single Task Accuracy vs. Rally Length', base_tp, f'Random Type ({base_tp:.2f})')
        sns.lineplot(data=df[df['Type']=='Single'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes=False, palette=palette_single, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='lower right'); plt.show()
        
        # Graph 2: Pairwise
        setup_plot('Pairwise Accuracy vs. Rally Length', base_pair_avg, f'Random Pair (~{base_pair_avg:.2f})')
        sns.lineplot(data=df[df['Type']=='Pair'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes=False, palette=palette_pair, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

        # Graph 3: Whole Shot
        setup_plot('Whole Shot Accuracy vs. Rally Length', base_whole, f'Random Whole ({base_whole:.4f})')
        sns.lineplot(data=df[df['Type']=='Whole'], x='Shot_Number', y='Accuracy', hue='Task', style='Task', 
                     markers=True, dashes={'Whole Shot':(2,2)}, palette=palette_whole, linewidth=2.5, errorbar=('ci', 68))
        plt.legend(loc='upper right'); plt.show()

    # ==============================================================================
    # PART 4: PLAYER FREQUENCY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 4: PLAYER FREQUENCY ({freq_matches} Matches) \n" + "="*40)
    chosen_matches = random.sample(unique_matches, min(freq_matches, len(unique_matches)))
    pf_indices = [idx for mid in chosen_matches for idx in match_map[mid]]
    
    p_counts = Counter()
    p_stats = {}
    
    with torch.no_grad():
        for idx in pf_indices:
            sample = dataset[idx]
            
            # Inputs (ADAPTED)
            x_d = sample['x_dir'].unsqueeze(0).to(DEVICE) # Direction Input (was x_z)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_dp = sample['x_depth'].unsqueeze(0).to(DEVICE) # DEPTH Input (New)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            y_t = sample['y_type']
            
            # Forward (3 Heads) - Model returns (Type, Dir, Depth)
            l_type_pred, _, _ = model(x_t, x_d, x_dp, x_s, x_r, x_c)
            
            # Re-assign to match original variable names
            l_t = l_type_pred
            
            preds = l_t.argmax(-1).squeeze(0)
            
            s_val, r_val = sample['x_s_id'].item(), sample['x_r_id'].item()
            x_seq_cpu = sample['x_type']
            
            for t in range(len(y_t)):
                if y_t[t] == 0: continue
                if y_t[t] == serve_type_id: continue # Skip serves

                hist_len = (x_seq_cpu[:t+1] != 0).sum().item()
                actor = s_val if (hist_len + 1) % 2 != 0 else r_val
                if actor <= 1: continue
                
                p_counts[actor] += 1
                if actor not in p_stats: p_stats[actor] = {'tot': 0, 'corr': 0}
                p_stats[actor]['tot'] += 1
                if preds[t].item() == y_t[t].item(): p_stats[actor]['corr'] += 1

    pf_data = [{'Freq': p_counts[a], 'Err': (1 - v['corr']/v['tot'])*100} for a, v in p_stats.items() if p_counts[a] > 10]
    if pf_data:
        df_pf = pd.DataFrame(pf_data)
        plt.figure(figsize=(8, 4))
        sns.regplot(data=df_pf, x='Freq', y='Err', scatter_kws={'alpha':0.5}, line_kws={'color':'red'})
        plt.xscale('log'); plt.title(f"Type Error vs Frequency (Corr: {df_pf['Freq'].corr(df_pf['Err']):.2f})"); plt.show()

    # ==============================================================================
    # PART 5: ERA STABILITY
    # ==============================================================================
    print("\n" + "="*40 + f"\n PART 5: ERA STABILITY ({era_matches} Matches/Era) \n" + "="*40)
    eras = {'Pre-2010': [], '2010-2019': [], '2020+': []}
    for m_id in unique_matches:
        try: y_year = int(str(m_id)[:4])
        except: continue
        if y_year < 2010: eras['Pre-2010'].append(m_id)
        elif y_year < 2020: eras['2010-2019'].append(m_id)
        else: eras['2020+'].append(m_id)

    era_indices = []
    era_labels_list = []
    for era_name, m_list in eras.items():
        if not m_list: continue
        chosen = random.sample(m_list, min(era_matches, len(m_list)))
        for m in chosen:
            era_indices.extend(match_map[m])
            era_labels_list.extend([era_name]*len(match_map[m]))
            
    era_res = []
    with torch.no_grad():
        for i, idx in enumerate(era_indices):
            sample = dataset[idx]
            
            # Inputs (ADAPTED)
            x_d = sample['x_dir'].unsqueeze(0).to(DEVICE) # Direction Input (was x_z)
            x_t = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_dp = sample['x_depth'].unsqueeze(0).to(DEVICE) # DEPTH Input (New)
            x_c = sample['context'].unsqueeze(0).to(DEVICE)
            x_s = sample['x_s_id'].unsqueeze(0).to(DEVICE)
            x_r = sample['x_r_id'].unsqueeze(0).to(DEVICE)
            
            y_t = sample['y_type'].to(DEVICE)
            y_d = sample['y_dir'].to(DEVICE)
            y_dp = sample['y_depth'].to(DEVICE)
            
            # Forward (3 Heads) - Model returns (Type, Dir, Depth)
            l_t, l_d, l_dp = model(x_t, x_d, x_dp, x_s, x_r, x_c)
            
            p_d = l_d.argmax(-1).squeeze(0)
            p_t = l_t.argmax(-1).squeeze(0)
            p_dp = l_dp.argmax(-1).squeeze(0)
            
            mask = (y_t > 1) & (y_t != serve_type_id) & (y_d != unk_dir_id) & (y_dp != unk_depth_id)
            if mask.sum() > 0:
                correct = (p_t[mask] == y_t[mask]) & (p_d[mask] == y_d[mask]) & (p_dp[mask] == y_dp[mask])
                acc = correct.float().mean().item()
                era_res.append({'Era': era_labels_list[i], 'Whole Shot Acc': acc})
    
    if era_res:
        plt.figure(figsize=(6, 4))
        sns.barplot(data=pd.DataFrame(era_res), x='Era', y='Whole Shot Acc', palette='viridis', order=['Pre-2010', '2010-2019', '2020+'])
        plt.title('Whole Shot Accuracy by Era'); plt.ylim(0, 1); plt.show()
    
    # ==============================================================================
    # PART 6: SURFACES (Depth Masked - Full Scan)
    # ==============================================================================
    batch_size=64
    print("\n" + "="*50)
    print(" PART 6: SURFACE ERROR ANALYSIS (Full Test Set Scan) ")
    print("="*50)

    # --- 1. Efficient Surface Grouping ---
    print(f"Grouping {len(test_indices)} test samples by surface...")
    surface_indices = {'Clay': [], 'Hard': [], 'Grass': []}
    
    for idx in test_indices:
        m_id = dataset.sample_match_ids[idx]
        surf = dataset.match_meta.get(m_id, {}).get('surface', 'Hard')
        
        if 'Clay' in surf:   target = 'Clay'
        elif 'Grass' in surf: target = 'Grass'
        else:                 target = 'Hard'
        
        surface_indices[target].append(idx)

    # --- 2. Evaluation Loop ---
    results = []
    inv_depth = {v: k for k, v in dataset.depth_vocab.items()}
    inv_dir = {v: k for k, v in dataset.dir_vocab.items()}
    
    # Identify IDs to strictly target or ignore (7, 8, 9)
    valid_depth_ids = [v for k, v in dataset.depth_vocab.items() if k in ['7', '8', '9']]

    with torch.no_grad():
        for surf, inds in surface_indices.items():
            if not inds: 
                continue
            
            print(f"Scanning {len(inds)} samples for {surf}...")
            
            # Create a DataLoader for this surface slice
            surf_ds = Subset(dataset, inds)
            loader = DataLoader(surf_ds, batch_size=batch_size, shuffle=False) # batch_size is now defined
            
            for batch in loader:
                # Move Inputs
                x_t = batch['x_type'].to(DEVICE)
                x_d = batch['x_dir'].to(DEVICE)
                x_dp = batch['x_depth'].to(DEVICE)
                x_s = batch['x_s_id'].to(DEVICE)
                x_r = batch['x_r_id'].to(DEVICE)
                x_c = batch['context'].to(DEVICE)
                
                # Move Targets
                y_t = batch['y_type'].to(DEVICE)
                y_d = batch['y_dir'].to(DEVICE)
                y_dp = batch['y_depth'].to(DEVICE)
                
                l_t, l_d, l_dp = model(x_t, x_d, x_dp, x_s, x_r, x_c)
                
                p_t = l_t.argmax(dim=-1)
                p_d = l_d.argmax(dim=-1)
                p_dp = l_dp.argmax(dim=-1)
                
                # --- Vectorized Error Calculation ---
                valid_mask = (y_t != 0)
                if valid_mask.sum() == 0: continue

                curr_y_t = y_t[valid_mask].cpu().numpy()
                curr_p_t = p_t[valid_mask].cpu().numpy()
                
                curr_y_d = y_d[valid_mask].cpu().numpy()
                curr_p_d = p_d[valid_mask].cpu().numpy()
                
                curr_y_dp = y_dp[valid_mask].cpu().numpy()
                curr_p_dp = p_dp[valid_mask].cpu().numpy()

                for i in range(len(curr_y_t)):
                    # A. Type Error
                    t_err = 1.0 if curr_y_t[i] != curr_p_t[i] else 0.0
                    
                    # B. Direction Error (Strict Masking)
                    d_target_str = inv_dir.get(curr_y_d[i], '0')
                    if curr_y_d[i] != 0 and d_target_str not in ['<pad>', '<unk>', '0']:
                        d_err = 1.0 if curr_y_d[i] != curr_p_d[i] else 0.0
                    else:
                        d_err = None 
                        
                    # C. Depth Error (VERY Strict Masking: Only 7, 8, 9)
                    if curr_y_dp[i] in valid_depth_ids:
                        dp_err = 1.0 if curr_y_dp[i] != curr_p_dp[i] else 0.0
                    else:
                        dp_err = None 
                        
                    # D. Whole Shot
                    miss = False
                    if t_err == 1.0: miss = True
                    if d_err is not None and d_err == 1.0: miss = True
                    if dp_err is not None and dp_err == 1.0: miss = True
                    
                    ws_err = 1.0 if miss else 0.0
                    
                    results.append({
                        'Surface': surf,
                        'Type Error': t_err,
                        'Direction Error': d_err,
                        'Depth Error': dp_err,
                        'Whole Shot Error': ws_err
                    })

    # --- 3. Output Generation ---
    if not results:
        print("No results generated.")
        return

    df = pd.DataFrame(results)
    
    depth_counts = df.groupby('Surface')['Depth Error'].count()
    print("\n[Diagnostics] Valid Depth Samples found per surface:")
    print(depth_counts)
    
    stats = df.groupby('Surface')[['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error']].mean() * 100
    print("\n--- Mean Error Rates (%) [Lower is Better] ---")
    print(stats.round(2))
    
    df_melt = df.melt(id_vars=['Surface'], 
                      value_vars=['Type Error', 'Direction Error', 'Depth Error', 'Whole Shot Error'], 
                      value_name='Error Rate')
    
    plt.figure(figsize=(12, 6))
    sns.barplot(data=df_melt, x='Surface', y='Error Rate', hue='variable', 
                order=['Clay', 'Hard', 'Grass'], palette='viridis', errorbar=('ci', 95))
    
    plt.title(f'Error Rates by Component (Full Test Scan, N={len(df)})')
    plt.ylabel('Error Rate (0.0 - 1.0)')
    plt.legend(title='Metric')
    plt.grid(axis='y', alpha=0.3)
    plt.ylim(0, 1.0)
    plt.show()

# --- 3. RUNNER SNIPPET (ADAPTED for HybridRichLSTM) ---
def recreate_test_split(dataset, seed=42, train_frac=0.8, val_frac=0.15):
    """Rebuild the held-out test subset of an 80/15/5 train/val/test random split.

    Split sizes are computed with ``int`` truncation for train and val; the
    remainder goes to test. ``random_split`` is driven by a dedicated
    ``torch.Generator`` seeded with ``seed``, so the same ``dataset`` always
    yields the same subsets regardless of the global RNG state.

    NOTE(review): this only reproduces the *training-time* test set if training
    also called ``random_split`` with these sizes and a generator seeded with the
    same value — confirm against the training code, otherwise test samples may
    overlap the training data.

    Args:
        dataset: Sized dataset accepted by ``torch.utils.data.random_split``.
        seed: Seed for the split generator.
        train_frac: Fraction of samples assigned to the training split.
        val_frac: Fraction of samples assigned to the validation split.

    Returns:
        The test ``Subset``; its ``.indices`` index into ``dataset``.
    """
    total_len = len(dataset)
    train_len = int(train_frac * total_len)
    val_len = int(val_frac * total_len)
    test_len = total_len - train_len - val_len
    gen = torch.Generator().manual_seed(seed)
    _, _, test_ds = random_split(dataset, [train_len, val_len, test_len], generator=gen)
    return test_ds


if 'dataset' in globals() and 'rich_hybrid_LSTM' in globals():
    print("Recreating validation/test split for evaluation...")
    # Re-seed the global RNGs so the random match sampling done inside
    # run_full_evaluation (random.sample) is reproducible across runs.
    seed_everything(42)

    test_ds = recreate_test_split(dataset, seed=42)
    test_indices = test_ds.indices
    test_loader_eval = DataLoader(test_ds, batch_size=64, shuffle=False)

    run_full_evaluation(
        model=rich_hybrid_LSTM, 
        dataset=dataset, 
        loader=test_loader_eval, 
        test_indices=test_indices,
        live_samples=5000, 
        length_matches=2000,
        freq_matches=2000
    )
else:
    # The message must name the same globals the guard above checks.
    print("Error: 'dataset' or 'rich_hybrid_LSTM' not found. Run training and ensure the trained model is named 'rich_hybrid_LSTM' or update the runner snippet.")