# Tennis Shot Prediction - Repository Demo

This notebook demonstrates how to use the modular tennis shot prediction repository to analyze and predict tennis shots. The repository provides a clean, well-structured approach to tennis analytics using transformer-based neural networks.

## 🎾 What we'll cover:
1. **Setup and Environment Configuration** - Import modules and configure the environment
2. **Load and Initialize Dataset** - Use the repository's data loading functionality  
3. **Model Architecture** - Leverage pre-built transformer models
4. **Pre-trained Model Loading** - Load and configure saved models
5. **Mid-Rally Prediction Testing** - Test on real tennis sequences
6. **Interactive Prediction** - Make predictions on custom rally strings
7. **Performance Analysis** - Comprehensive model evaluation
8. **Tactical Intelligence** - Analyze tennis-specific AI capabilities

---

In [None]:
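from kaggle_secrets import UserSecretsClient

# Read the GitHub credentials stored as Kaggle secrets
user_secrets = UserSecretsClient()
github_token = user_secrets.get_secret("GITHUB_TOKEN")
github_user = user_secrets.get_secret("GITHUB_USER")

# `!cd` runs in a throwaway subshell, so use the %cd magic to actually change directory
%cd /kaggle/working/
# -f avoids an error on the first run, when no previous clone exists
!rm -rf "/kaggle/working/Tennis-Shot-Prediction"

!git clone https://{github_user}:{github_token}@github.com/SoykatAmin/Tennis-Shot-Prediction.git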

In [None]:
import sys
sys.path.append("/kaggle/working/Tennis-Shot-Prediction/")

## 1. Setup and Environment Configuration

Let's start by importing the necessary libraries and setting up our environment. We'll use the modular structure from our repository.

In [None]:
# Core libraries
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import os
import sys
from pathlib import Path

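# Locate the repository root: use the Kaggle clone if it exists, otherwise assume
# the notebook runs from <repo>/notebooks and go up one level
kaggle_repo = Path("/kaggle/working/Tennis-Shot-Prediction")
repo_root = kaggle_repo if kaggle_repo.exists() else Path.cwd().parent
sys.path.append(str(repo_root))
print(f"Repository root: {repo_root}")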

# Import our custom modules from the repository
from src.data import MCPTennisDataset, create_data_loaders, compute_class_weights
from src.data.utils import (
    calculate_directional_accuracy, 
    calculate_winner_detection_rate,
    calculate_top_k_accuracy
)
from src.models import SymbolicTinyRM_PlayerAware, SymbolicTinyRM_Context, FocalLoss
from src.utils import setup_logging, get_device, print_model_summary

print("✅ All imports successful!")

In [None]:
# Configuration
DEVICE = get_device('auto')  # Use our utility function
SEQ_LEN = 30
BATCH_SIZE = 64
EPOCHS = 10

print(f"Using device: {DEVICE}")
print(f"Sequence length: {SEQ_LEN}")
print(f"Batch size: {BATCH_SIZE}")

# Data paths - Update these to match your system
# For demo purposes, we'll use sample paths
DATA_PATHS = {
    'atp_points': '/kaggle/input/tennis-match-charting-project/charting-m-points.csv',
    'atp_matches': '/kaggle/input/tennis-match-charting-project/charting-m-matches.csv', 
    'atp_players': '/kaggle/input/tennis-players/atp_players.csv',
    'wta_players': '/kaggle/input/tennis-players/wta_players.csv',
    'points_path': '/kaggle/input/tennis-match-charting-project/charting-m-points.csv',
    'matches_path': '/kaggle/input/tennis-match-charting-project/charting-m-matches.csv'
}

# Check if data files exist (the same paths are passed to the dataset below)
for name, path in DATA_PATHS.items():
    full_path = Path(path)
    exists = full_path.exists()
    print(f"📁 {name}: {'✅ Found' if exists else '❌ Not found'} at {full_path}")

print("\n💡 Note: Update DATA_PATHS above with your actual file locations if any files are not found.")

## 2. Load and Initialize Dataset

Now we'll use the repository's `MCPTennisDataset` class to load and process tennis data. This class handles all the complex data preprocessing, including:
- Rally sequence parsing
- Player handedness information
- Data augmentation (left/right mirroring)
- Context encoding (surface, score, etc.)
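
As a rough illustration of the fixed-length encoding this kind of preprocessing produces, the sketch below maps a rally of `(zone, shot_type)` pairs to padded index lists. The toy vocabularies and the `encode_rally` helper are assumptions made for illustration, not part of the repository; the only detail taken from this notebook is that index 0 is treated as padding.

```python
def encode_rally(rally, zone_vocab, shot_vocab, max_seq_len):
    """Map (zone, shot_type) pairs to two padded index lists of length max_seq_len."""
    zones = [0] * max_seq_len        # 0 is the padding index
    shot_types = [0] * max_seq_len
    for i, (zone, shot_type) in enumerate(rally[:max_seq_len]):
        zones[i] = zone_vocab.get(zone, 0)            # unknown tokens fall back to padding
        shot_types[i] = shot_vocab.get(shot_type, 0)
    return zones, shot_types

# Toy vocabularies (hypothetical; the real ones come from MCPTennisDataset)
toy_zone_vocab = {'<pad>': 0, '1': 1, '2': 2, '3': 3}
toy_shot_vocab = {'<pad>': 0, 'f': 1, 'b': 2}

zones, shot_types = encode_rally(
    [('1', 'b'), ('3', 'f')], toy_zone_vocab, toy_shot_vocab, max_seq_len=5
)
print(zones)       # [1, 3, 0, 0, 0]
print(shot_types)  # [2, 1, 0, 0, 0]
```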

In [None]:
dataset = MCPTennisDataset(
    DATA_PATHS['points_path'],
    DATA_PATHS['matches_path'],
    DATA_PATHS['atp_players'],
    DATA_PATHS['wta_players'],
    max_seq_len=SEQ_LEN,
)

# Display dataset information
print("\n📊 Dataset Information:")
print(f"Total samples: {len(dataset):,}")
print(f"Shot vocabulary size: {len(dataset.shot_vocab)}")
print(f"Zone vocabulary size: {len(dataset.zone_vocab)}")
print(f"Player vocabulary size: {len(dataset.player_vocab)}")
print(f"Surface types: {list(dataset.surface_vocab.keys())}")

## 3. Define Model Architecture

The repository provides two main model architectures:
1. **`SymbolicTinyRM_PlayerAware`** - Includes player embeddings for personalized predictions
2. **`SymbolicTinyRM_Context`** - Context-only model without player-specific information

Let's initialize both models to see their capabilities.

In [None]:
# Model configuration
MODEL_CONFIG = {
    'embed_dim': 64,
    'n_head': 4,
    'n_cycles': 3,
    'seq_len': SEQ_LEN,
    'context_dim': 6,
    'dropout': 0.1
}

print("🤖 Initializing Models...")

# 1. Player-aware model (includes player embeddings)
player_aware_model = SymbolicTinyRM_PlayerAware(
    zone_vocab_size=len(dataset.zone_vocab),
    type_vocab_size=len(dataset.shot_vocab),
    num_players=len(dataset.player_vocab),
    **MODEL_CONFIG
).to(DEVICE)

print("✅ Player-aware model created")
print_model_summary(player_aware_model, "Player-Aware Transformer")

# 2. Context-only model (no player embeddings)
context_model = SymbolicTinyRM_Context(
    zone_vocab_size=len(dataset.zone_vocab),
    type_vocab_size=len(dataset.shot_vocab),
    **MODEL_CONFIG
).to(DEVICE)

print("✅ Context-only model created")
print_model_summary(context_model, "Context-Only Transformer")

## 4. Load Pre-trained Model Weights

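The next cell supports three options, selected with `USER_CHOICE`: load saved weights from the `checkpoints` directory, train both models from scratch for a few epochs, or continue with randomly initialized weights. It defaults to the short training run, since no pre-trained weights ship with this demo.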
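
The loading code accepts two checkpoint layouts: a dictionary that wraps the weights under `model_state_dict` (optionally alongside `context_model_state_dict`), or a bare state dict. A minimal sketch of that branching, using plain dictionaries in place of real tensors; `split_checkpoint` is a hypothetical helper, not a repository function:

```python
def split_checkpoint(checkpoint):
    """Return (player_aware_state, context_state_or_None) for either layout."""
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        # Wrapped layout, as written by the quick-training code in this notebook
        return checkpoint['model_state_dict'], checkpoint.get('context_model_state_dict')
    # Bare layout: the object itself is the player-aware state dict
    return checkpoint, None

wrapped = {'model_state_dict': {'w': 1}, 'context_model_state_dict': {'w': 2}, 'epoch': 10}
bare = {'layer.weight': 3}

print(split_checkpoint(wrapped))  # ({'w': 1}, {'w': 2})
print(split_checkpoint(bare))     # ({'layer.weight': 3}, None)
```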

In [None]:
# Check for pre-trained models in checkpoints directory
checkpoint_dir = repo_root / 'checkpoints'
checkpoint_dir.mkdir(exist_ok=True)

# Look for saved models
model_files = list(checkpoint_dir.glob('*.pth'))
print(f"🔍 Found {len(model_files)} model files in {checkpoint_dir}")

# User choice: Load weights or train model
print("\n🎯 MODEL INITIALIZATION OPTIONS:")
print("1. Load pre-trained weights (if available)")
print("2. Train model from scratch")
print("3. Use randomly initialized model (for quick demo)")

# For interactive use, you can change this variable
USER_CHOICE = 2  # Change this to 1, 2, or 3 based on your preference

model_loaded = False

if USER_CHOICE == 1 and model_files:
    print("\n📂 Available models:")
    for i, model_file in enumerate(model_files):
        print(f"  {i+1}. {model_file.name}")
    
    # Try to load the first model found (you can modify this to select a specific model)
    try:
        model_path = model_files[0]  # Use first model, or change index to select different model
        print(f"\n🔄 Loading model from {model_path}")
        
        # Load the state dict
        checkpoint = torch.load(model_path, map_location=DEVICE)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            player_aware_model.load_state_dict(checkpoint['model_state_dict'])
            print("✅ Loaded player-aware model state from checkpoint")
            if 'context_model_state_dict' in checkpoint:
                context_model.load_state_dict(checkpoint['context_model_state_dict'])
                print("✅ Loaded context-only model state from checkpoint")
        else:
            player_aware_model.load_state_dict(checkpoint)
            print("✅ Loaded player-aware model state directly")
        
        model_loaded = True
            
    except Exception as e:
        print(f"⚠️ Could not load model: {e}")
        print("Will train model from scratch instead...")
        USER_CHOICE = 2

# A plain `if` (not `elif`) so that a failed load above really falls through to training
if USER_CHOICE == 2:
    print("\n🎓 TRAINING MODELS FROM SCRATCH")
    print("This will train both models for a few epochs...")
    
    # Quick training function
    def quick_train_models(player_model, context_model, dataset, epochs=EPOCHS):
        from torch.utils.data import DataLoader, random_split
        from src.models import FocalLoss
        import torch
        
        # Split dataset into train (80%) and validation (20%)
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
        
        print(f"📊 Dataset split: {train_size} training, {val_size} validation samples")

        # Create data loaders (use the BATCH_SIZE configured above)
        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

        # Loss function and optimizers
        # alpha=None rather than alpha=1.0, which raised a 'float' object error
        criterion = FocalLoss(alpha=None, gamma=2.0, reduction='mean')
        
        optimizer_pa = torch.optim.AdamW(player_model.parameters(), lr=1e-3, weight_decay=1e-4)
        optimizer_co = torch.optim.AdamW(context_model.parameters(), lr=1e-3, weight_decay=1e-4)
        
        print(f"🏃 Training for {epochs} epochs with train/validation split...")
        
        for epoch in range(epochs):
            # --- Training phase ---
            player_model.train()
            context_model.train()
            
            # Running totals for the training phase of this epoch
            total_loss_pa = 0
            total_loss_co = 0
            num_batches = 0

            print(f"\n📈 Epoch {epoch+1}/{epochs} - Training...")
            for batch_idx, batch in enumerate(train_loader):
                if batch_idx >= 15:  # Limit to 15 batches per epoch for quick training
                    break
                
                try:
                    # Handle context
                    if isinstance(batch['context'], (list, tuple)):
                        x_context = torch.stack([torch.tensor(ctx, dtype=torch.float32) for ctx in batch['context']]).to(DEVICE)
                    elif isinstance(batch['context'], torch.Tensor):
                        x_context = batch['context'].to(DEVICE)
                    else:
                        x_context = torch.tensor(batch['context'], dtype=torch.float32).to(DEVICE)
                        if x_context.dim() == 1:
                            x_context = x_context.unsqueeze(0)
                    
                    x_zone = batch['x_zone'].to(DEVICE)
                    x_type = batch['x_type'].to(DEVICE) 
                    y_target = batch['y_target'].to(DEVICE)
                    
                    batch_size = x_zone.size(0)
                    
                    # --- 1. Train Player-Aware Model ---
                    optimizer_pa.zero_grad()
                    
                    # Handle player IDs
                    if 'x_s_id' in batch and 'x_r_id' in batch:
                        x_s_id = batch['x_s_id'].to(DEVICE)
                        x_r_id = batch['x_r_id'].to(DEVICE)
                    else:
                        # Dummy IDs if missing
                        x_s_id = torch.zeros(batch_size, dtype=torch.long, device=DEVICE)
                        x_r_id = torch.zeros(batch_size, dtype=torch.long, device=DEVICE)
                
                    logits_pa = player_model(x_zone, x_type, x_context, x_s_id, x_r_id)
                    loss_pa = criterion(logits_pa.view(-1, logits_pa.size(-1)), y_target.view(-1))
                    loss_pa.backward()
                    optimizer_pa.step()
                    
                    # --- 2. Train Context-Only Model ---
                    optimizer_co.zero_grad()
                    logits_co = context_model(x_zone, x_type, x_context)
                    loss_co = criterion(logits_co.view(-1, logits_co.size(-1)), y_target.view(-1))
                    loss_co.backward()
                    optimizer_co.step()
                    
                    # Update totals
                    total_loss_pa += loss_pa.item()
                    total_loss_co += loss_co.item()
                    num_batches += 1
                
                except Exception as batch_error:
                    print(f"   Skipping batch {batch_idx} due to error: {batch_error}")
                    continue
            
            # Calculate training averages
            avg_train_loss_pa = total_loss_pa / num_batches if num_batches > 0 else 0
            avg_train_loss_co = total_loss_co / num_batches if num_batches > 0 else 0
            
            # --- Validation phase ---
            print(f"📊 Epoch {epoch+1}/{epochs} - Validation...")
            player_model.eval()
            context_model.eval()
            
            val_loss_pa = 0
            val_loss_co = 0
            val_correct_pa = 0
            val_correct_co = 0
            val_total = 0
            val_batches = 0
            
            with torch.no_grad():
                for batch_idx, batch in enumerate(val_loader):
                    if batch_idx >= 5:  # Limit validation batches
                        break
                    
                    try:
                        # Handle validation batch data
                        if isinstance(batch['context'], (list, tuple)):
                            x_context = torch.stack([torch.tensor(ctx, dtype=torch.float32) for ctx in batch['context']]).to(DEVICE)
                        elif isinstance(batch['context'], torch.Tensor):
                            x_context = batch['context'].to(DEVICE)
                        else:
                            x_context = torch.tensor(batch['context'], dtype=torch.float32).to(DEVICE)
                            if x_context.dim() == 1:
                                x_context = x_context.unsqueeze(0)
                        
                        x_zone = batch['x_zone'].to(DEVICE)
                        x_type = batch['x_type'].to(DEVICE) 
                        y_target = batch['y_target'].to(DEVICE)
                        batch_size = x_zone.size(0)
                        
                        # Handle player IDs
                        if 'x_s_id' in batch and 'x_r_id' in batch:
                            x_s_id = batch['x_s_id'].to(DEVICE)
                            x_r_id = batch['x_r_id'].to(DEVICE)
                        else:
                            x_s_id = torch.zeros(batch_size, dtype=torch.long, device=DEVICE)
                            x_r_id = torch.zeros(batch_size, dtype=torch.long, device=DEVICE)
                        
                        # Forward pass
                        logits_pa = player_model(x_zone, x_type, x_context, x_s_id, x_r_id)
                        logits_co = context_model(x_zone, x_type, x_context)
                        
                        # Loss
                        loss_pa = criterion(logits_pa.view(-1, logits_pa.size(-1)), y_target.view(-1))
                        loss_co = criterion(logits_co.view(-1, logits_co.size(-1)), y_target.view(-1))
                        
                        val_loss_pa += loss_pa.item()
                        val_loss_co += loss_co.item()
                        
                        # Accuracy (mask padding)
                        mask = (y_target != 0)
                        if mask.any():
                            pred_pa = logits_pa.argmax(dim=-1)
                            pred_co = logits_co.argmax(dim=-1)
                            
                            val_correct_pa += (pred_pa[mask] == y_target[mask]).sum().item()
                            val_correct_co += (pred_co[mask] == y_target[mask]).sum().item()
                            val_total += mask.sum().item()
                        
                        val_batches += 1
                        
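                    except Exception as val_error:
                        # Report the failure instead of silently dropping the batch
                        print(f"   Skipping validation batch {batch_idx} due to error: {val_error}")
                        continue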
            
            # Calculate validation averages
            avg_val_loss_pa = val_loss_pa / val_batches if val_batches > 0 else 0
            avg_val_loss_co = val_loss_co / val_batches if val_batches > 0 else 0
            val_acc_pa = val_correct_pa / val_total if val_total > 0 else 0
            val_acc_co = val_correct_co / val_total if val_total > 0 else 0
            
            # Print epoch results
            print(f"✅ Epoch {epoch+1}/{epochs} Results:")
            print(f"   Train Loss - PA: {avg_train_loss_pa:.4f}, CO: {avg_train_loss_co:.4f}")
            print(f"   Val Loss   - PA: {avg_val_loss_pa:.4f}, CO: {avg_val_loss_co:.4f}")
            print(f"   Val Acc    - PA: {val_acc_pa*100:.2f}%, CO: {val_acc_co*100:.2f}%")
        
        # Save models
        checkpoint_path = checkpoint_dir / 'quick_trained_model.pth'
        torch.save({
            'model_state_dict': player_model.state_dict(),
            'context_model_state_dict': context_model.state_dict(),
            'epoch': epochs,
            'training_type': 'quick_demo'
        }, checkpoint_path)
        
        print(f"✅ Quick training completed! Models saved to {checkpoint_path}")
        return True
    
    try:
        model_loaded = quick_train_models(player_aware_model, context_model, dataset)
    except Exception as e:
        print(f"⚠️ Training failed: {e}")
        print("Will use randomly initialized models...")
        USER_CHOICE = 3

if not model_loaded or USER_CHOICE == 3:
    print("\n🎲 Using randomly initialized models for demonstration")
    print("   Note: Performance will be random since models are not trained")
    print("   To get meaningful results:")
    print("   1. Set USER_CHOICE = 2 to train models")
    print("   2. Or provide pre-trained weights and set USER_CHOICE = 1")

# Set models to evaluation mode
player_aware_model.eval()
context_model.eval()
print("\n✅ Models ready for inference!")

print("\n📋 Current Setup:")
print(f"   Choice: {['', 'Load pre-trained weights', 'Train from scratch', 'Random initialization'][USER_CHOICE]}")
print(f"   Model loaded: {'Yes' if model_loaded else 'No'}")
print(f"   Ready for predictions: Yes")

## 5. Mid-Rally Prediction Testing

Let's test the model's ability to predict shots in the middle of rallies. This demonstrates the core functionality of the tennis shot prediction system.
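
Each prediction is read off by applying a softmax to the logits at one position and keeping the highest-probability zones. A dependency-free sketch of that step follows; the notebook itself uses `torch.softmax` and `torch.topk`, and the logits here are made up for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Return (index, probability) pairs for the k largest probabilities."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

example_logits = [0.0, 2.0, 1.0, -1.0]  # made-up scores for four zone indices
probs = softmax(example_logits)
for idx, p in top_k(probs, 3):
    print(f"zone index {idx}: {p * 100:.1f}%")
```

The indices returned here would still need to be mapped back to zone labels through a reversed vocabulary, as the function in the next cell does with `idx_to_zone`.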

In [None]:
def test_mid_rally_prediction(model, dataset, num_samples=5, model_type='player_aware'):
    """
    Test the model on partial rally sequences.
    
    Args:
        model: The neural network model
        dataset: Tennis dataset
        num_samples: Number of samples to test
        model_type: 'player_aware' or 'context_only'
    """
    print(f"🎾 Testing {model_type} model on {num_samples} mid-rally predictions...\n")
    
    # Create reverse vocabularies
    idx_to_zone = {v: k for k, v in dataset.zone_vocab.items()}
    idx_to_shot = {v: k for k, v in dataset.shot_vocab.items()}
    
    model.eval()
    correct_predictions = 0
    evaluated = 0  # samples actually scored (short or padded samples are skipped)

    # Randomly select samples for testing
    test_indices = np.random.choice(len(dataset), num_samples, replace=False)

    with torch.no_grad():
        for i, idx in enumerate(test_indices):
            sample = dataset[idx]

            # Find non-padding positions
            valid_positions = (sample['x_zone'] != 0).nonzero(as_tuple=True)[0]

            if len(valid_positions) < 3:  # Need at least 3 shots for meaningful test
                continue

            # Test prediction at the last position
            test_pos = valid_positions[-1].item()
            true_target = sample['y_target'][test_pos].item()

            if true_target == 0:  # Skip if target is padding
                continue

            # Prepare input tensors
            x_zone = sample['x_zone'].unsqueeze(0).to(DEVICE)
            x_type = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_context = sample['context'].unsqueeze(0).to(DEVICE)

            # Make prediction based on model type
            if model_type == 'player_aware':
                x_s_id = sample['x_s_id'].unsqueeze(0).to(DEVICE)
                x_r_id = sample['x_r_id'].unsqueeze(0).to(DEVICE)
                logits = model(x_zone, x_type, x_context, x_s_id, x_r_id)
            else:
                logits = model(x_zone, x_type, x_context)

            # Get prediction at test position
            pred_logits = logits[0, test_pos]
            pred_zone_idx = pred_logits.argmax().item()

            # Get top 3 predictions
            top_probs, top_indices = torch.topk(torch.softmax(pred_logits, dim=0), 3)

            # Convert to readable format
            pred_zone = idx_to_zone.get(pred_zone_idx, '?')
            true_zone = idx_to_zone.get(true_target, '?')

            # Build rally history for display: every input shot up to and including
            # test_pos, since the shot at test_pos is part of the model input
            rally_parts = []
            for j in valid_positions:
                zone = idx_to_zone.get(sample['x_zone'][j].item(), '?')
                shot = idx_to_shot.get(sample['x_type'][j].item(), '?')
                rally_parts.append(f"{zone}{shot}")

            rally_str = " ".join(rally_parts)

            # Check if prediction is correct
            is_correct = (pred_zone == true_zone)
            evaluated += 1
            if is_correct:
                correct_predictions += 1

            # Display results
            status_emoji = "✅" if is_correct else "❌"
            print(f"{status_emoji} Sample {i+1}:")
            print(f"   Rally: {rally_str} → ?")
            print(f"   Predicted: Zone {pred_zone} | Actual: Zone {true_zone}")
            print(f"   Top 3 predictions:")

            for k, (prob, zone_idx) in enumerate(zip(top_probs, top_indices)):
                zone = idx_to_zone.get(zone_idx.item(), '?')
                if zone != '<pad>':
                    print(f"      {k+1}. Zone {zone}: {prob.item()*100:.1f}%")

            print()

    # Divide by the number of samples actually scored, not the number requested
    accuracy = correct_predictions / evaluated if evaluated > 0 else 0
    print(f"📊 Accuracy: {correct_predictions}/{evaluated} ({accuracy*100:.1f}%)")
    return accuracy

# Test both models
print("=" * 60)
accuracy_player = test_mid_rally_prediction(player_aware_model, dataset, 5, 'player_aware')
print("\n" + "=" * 60)
accuracy_context = test_mid_rally_prediction(context_model, dataset, 5, 'context_only')

print("\n🏆 Results Summary:")
print(f"   Player-aware model accuracy: {accuracy_player*100:.1f}%")
print(f"   Context-only model accuracy: {accuracy_context*100:.1f}%")

## 6. Interactive Shot Prediction

Let's create an interactive interface to make predictions based on custom rally sequences.
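
`predict_next_shot` takes the rally as a list of `(zone, shot_type)` tuples. If you would rather type rallies in the compact form the previous section prints (for example `4b 6f`), a small parser can do the conversion. `parse_rally_string` is a hypothetical helper, not part of the repository, and it assumes the shot type is the last character of each whitespace-separated token:

```python
def parse_rally_string(rally_str):
    """Turn '4b 6f' into [('4', 'b'), ('6', 'f')]."""
    rally = []
    for token in rally_str.split():
        if len(token) < 2:
            raise ValueError(f"Cannot parse token {token!r}: expected zone followed by shot type")
        # Everything before the last character is the zone; the last character is the shot type
        rally.append((token[:-1], token[-1]))
    return rally

print(parse_rally_string("4b 6f"))      # [('4', 'b'), ('6', 'f')]
print(parse_rally_string("1b 8f 3b"))   # [('1', 'b'), ('8', 'f'), ('3', 'b')]
```

The result can be passed straight to `predict_next_shot` as its `rally_sequence` argument.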

In [None]:
def predict_next_shot(rally_sequence, model, dataset, model_type='player_aware'):
    """
    Predict the next shot given a rally sequence.
    
    Args:
        rally_sequence: List of tuples like [('1', 'b'), ('8', 'f'), ...]
        model: The neural network model
        dataset: Tennis dataset
        model_type: 'player_aware' or 'context_only'
    
    Returns:
        Dictionary with prediction results
    """
    print(f"🔮 Predicting next shot with {model_type} model...")
    print(f"Rally: {' → '.join([f'Zone{z}{t}' for z, t in rally_sequence])}")
    
    # Create reverse vocabularies
    idx_to_zone = {v: k for k, v in dataset.zone_vocab.items()}
    
    # Convert rally to tensor format - use SEQ_LEN to match model expectations
    max_seq_len = SEQ_LEN  # Use the same sequence length as the model was initialized with
    zones = [0] * max_seq_len
    shot_types = [0] * max_seq_len
    
    for i, (zone, shot_type) in enumerate(rally_sequence[:max_seq_len]):
        zones[i] = dataset.zone_vocab.get(zone, 0)
        shot_types[i] = dataset.shot_vocab.get(shot_type, 0)
    
    # Create tensors
    x_zone = torch.tensor(zones, dtype=torch.long).unsqueeze(0).to(DEVICE)
    x_type = torch.tensor(shot_types, dtype=torch.long).unsqueeze(0).to(DEVICE)
    
    # Create dummy context (for mock data) - match the context_dim from model config
    # Context format: [surface, server_score, receiver_score, is_second_serve, server_hand, receiver_hand]
    dummy_context = [1.0, 0.0, 0.0, 0.0, 1.0, 1.0]  # 6 features
    x_context = torch.tensor(dummy_context, dtype=torch.float).unsqueeze(0).to(DEVICE)  # [1, 6]
    
    model.eval()
    with torch.no_grad():
        # Make prediction based on model type
        if model_type == 'player_aware':
            # Use dummy player IDs for mock prediction
            x_s_id = torch.tensor([2], dtype=torch.long).to(DEVICE)  # Single server ID
            x_r_id = torch.tensor([3], dtype=torch.long).to(DEVICE)  # Single receiver ID
            logits = model(x_zone, x_type, x_context, x_s_id, x_r_id)
        else:
            logits = model(x_zone, x_type, x_context)
        
        # The logits at the last filled position score the next shot, matching how
        # Sections 5 and 7 compare logits[0, pos] against y_target[pos]
        pred_pos = len(rally_sequence) - 1
        if 0 <= pred_pos < max_seq_len:
            pred_logits = logits[0, pred_pos]
            probs = torch.softmax(pred_logits, dim=0)

            # Get top 5 predictions (fewer if the output has fewer classes)
            top_probs, top_indices = torch.topk(probs, min(5, probs.numel()))

            predictions = []
            for prob, zone_idx in zip(top_probs, top_indices):
                zone = idx_to_zone.get(zone_idx.item(), '?')
                if zone != '<pad>':
                    predictions.append({
                        'zone': zone,
                        'probability': prob.item(),
                        'confidence': 'High' if prob.item() > 0.3 else 'Medium' if prob.item() > 0.15 else 'Low'
                    })

            return predictions
        else:
            print("Rally sequence is empty or too long!")
            return []

# Example rally sequences to test
example_rallies = [
    [('4', 'b'), ('6', 'f')],  # Backhand to zone 4, forehand to zone 6
    [('1', 'b'), ('8', 'f'), ('3', 'b')],  # Longer rally
    [('2', 'v'), ('7', 'f')],  # Volley followed by forehand
    [('5', 'f'), ('5', 'b'), ('6', 'f')]  # Back and forth rally
]

print("🎾 Interactive Shot Prediction Examples\n")
print("=" * 70)

# Test with both models
for i, rally in enumerate(example_rallies[:2]):  # Test first 2 examples
    print(f"\n📋 Example {i+1}:")
    print("-" * 30)
    
    # Player-aware model prediction
    predictions_pa = predict_next_shot(rally, player_aware_model, dataset, 'player_aware')
    print("\n🤖 Player-aware model predictions:")
    if not predictions_pa:
        print("   No predictions")
    else:
        for j, pred in enumerate(predictions_pa[:3]):
            print(f"   {j+1}. Zone {pred['zone']}: {pred['probability']*100:.1f}% ({pred['confidence']})")
    
    # Context-only model prediction
    predictions_co = predict_next_shot(rally, context_model, dataset, 'context_only')
    print("\n🧠 Context-only model predictions:")
    if not predictions_co:
        print("   No predictions")
    else:
        for j, pred in enumerate(predictions_co[:3]):
            print(f"   {j+1}. Zone {pred['zone']}: {pred['probability']*100:.1f}% ({pred['confidence']})")
    
    print("\n" + "=" * 70)

## 7. Performance Analysis

Let's analyze the models' performance and compare their capabilities.
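
The analysis reports two kinds of numbers: accuracy over non-padding targets, and the mean and standard deviation of the top predicted probability. A small sketch of both calculations on made-up values; the function name and sample data are illustrative only, and `pstdev` is used because `np.std` defaults to the population standard deviation.

```python
from statistics import mean, pstdev

def masked_accuracy(predictions, targets, pad_idx=0):
    """Accuracy over positions whose target is not the padding index."""
    scored = [(p, t) for p, t in zip(predictions, targets) if t != pad_idx]
    if not scored:
        return 0.0
    return sum(p == t for p, t in scored) / len(scored)

preds = [3, 1, 2, 2, 0]
targets = [3, 2, 2, 0, 0]        # the last two positions are padding
confidences = [0.5, 0.3, 0.4]    # made-up top probabilities for the three scored positions

acc = masked_accuracy(preds, targets)
print(f"Accuracy: {acc * 100:.2f}%")  # 2 of the 3 scored positions are correct
print(f"Confidence: {mean(confidences) * 100:.2f}% (±{pstdev(confidences) * 100:.2f}%)")
```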

In [None]:
def analyze_model_performance(model, dataset, model_type, num_samples=50):
    """
    Comprehensive performance analysis of a model.
    
    Args:
        model: The neural network model
        dataset: Tennis dataset
        model_type: 'player_aware' or 'context_only'
        num_samples: Number of samples to analyze
    
    Returns:
        Dictionary with performance metrics
    """
    print(f"📊 Analyzing {model_type} model performance on {num_samples} samples...")
    
    model.eval()
    correct_predictions = 0
    total_predictions = 0
    confidence_scores = []
    prediction_probabilities = []
    
    # Create reverse vocabularies
    idx_to_zone = {v: k for k, v in dataset.zone_vocab.items()}
    
    with torch.no_grad():
        test_indices = np.random.choice(len(dataset), num_samples, replace=False)
        
        for idx in test_indices:
            sample = dataset[idx]
            
            # Find valid positions
            valid_positions = (sample['x_zone'] != 0).nonzero(as_tuple=True)[0]
            
            if len(valid_positions) < 2:
                continue
            
            # Run the model once per rally; the inputs do not depend on the position scored
            x_zone = sample['x_zone'].unsqueeze(0).to(DEVICE)
            x_type = sample['x_type'].unsqueeze(0).to(DEVICE)
            x_context = sample['context'].unsqueeze(0).to(DEVICE)

            if model_type == 'player_aware':
                x_s_id = sample['x_s_id'].unsqueeze(0).to(DEVICE)
                x_r_id = sample['x_r_id'].unsqueeze(0).to(DEVICE)
                logits = model(x_zone, x_type, x_context, x_s_id, x_r_id)
            else:
                logits = model(x_zone, x_type, x_context)

            # Test multiple positions in each rally
            for pos in valid_positions[1:]:  # Skip first position
                true_target = sample['y_target'][pos].item()
                if true_target == 0:  # Skip padding
                    continue

                # Calculate metrics
                pred_logits = logits[0, pos]
                probs = torch.softmax(pred_logits, dim=0)
                pred_idx = pred_logits.argmax().item()
                
                # Store results
                total_predictions += 1
                if pred_idx == true_target:
                    correct_predictions += 1
                
                max_prob = probs.max().item()
                confidence_scores.append(max_prob)
                prediction_probabilities.append(probs.cpu().numpy())
    
    # Calculate metrics
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    avg_confidence = np.mean(confidence_scores) if confidence_scores else 0
    confidence_std = np.std(confidence_scores) if confidence_scores else 0
    
    results = {
        'accuracy': accuracy,
        'total_predictions': total_predictions,
        'correct_predictions': correct_predictions,
        'average_confidence': avg_confidence,
        'confidence_std': confidence_std,
        'model_type': model_type
    }
    
    # Display results
    print(f"\n📈 {model_type.title()} Model Results:")
    print(f"   Accuracy: {accuracy*100:.2f}% ({correct_predictions}/{total_predictions})")
    print(f"   Average Confidence: {avg_confidence*100:.2f}% (±{confidence_std*100:.2f}%)\n")

    return results

# Analyze both models
print("=" * 80)
results_player = analyze_model_performance(player_aware_model, dataset, 'player_aware', 30)
print("\n" + "=" * 80)
results_context = analyze_model_performance(context_model, dataset, 'context_only', 30)

# Compare models
print("\n" + "=" * 80)
print("🏆 MODEL COMPARISON")
print("=" * 80)
print("📊 Accuracy Comparison:")
print(f"   Player-aware: {results_player['accuracy']*100:.2f}%")
print(f"   Context-only: {results_context['accuracy']*100:.2f}%")
print(f"   Difference: {(results_player['accuracy'] - results_context['accuracy'])*100:.2f} percentage points")

print("\n🎯 Confidence Comparison:")
print(f"   Player-aware: {results_player['average_confidence']*100:.2f}%")
print(f"   Context-only: {results_context['average_confidence']*100:.2f}%")

# Determine better model
better_model = "Player-aware" if results_player['accuracy'] > results_context['accuracy'] else "Context-only"
print(f"\n🥇 Better performing model: {better_model}")

## 8. Tactical Intelligence Analysis

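Let's probe the models with a few hand-written rally setups and label each top prediction with a simple rule-based interpretation. The labels come from fixed zone rules in `interpret_tactical_choice`, so they describe the predicted zone rather than anything the model explains about itself.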

In [None]:
def analyze_tactical_patterns(model, dataset, model_type):
    """
    Analyze tactical patterns learned by the model.
    """
    print(f"🎯 Analyzing tactical patterns in {model_type} model...")
    
    # Define common tactical scenarios
    tactical_scenarios = {
        "Cross-court rallies": [("1", "b"), ("8", "f")],
        "Down-the-line pressure": [("1", "f"), ("1", "b")],
        "Serve and volley": [("4", "s"), ("2", "v")],
        "Defensive lob": [("8", "f"), ("3", "l")],
        "Approach shot": [("5", "f"), ("2", "v")]
    }
    
    print("\n🧠 Tactical Pattern Analysis:")
    print("=" * 50)

    for scenario_name, rally in tactical_scenarios.items():
        print(f"\n📋 {scenario_name}:")
        print(f"   Setup: {' → '.join([f'Zone{z}{t.upper()}' for z, t in rally])}")

        predictions = predict_next_shot(rally, model, dataset, model_type)

        if predictions:
            top_pred = predictions[0]
            print(f"   💡 Model suggests: Zone {top_pred['zone']} ({top_pred['probability']*100:.1f}% confidence)")

            # Tactical interpretation
            tactical_meaning = interpret_tactical_choice(rally[-1], top_pred['zone'])
            print(f"   🎾 Tactical insight: {tactical_meaning}")
        else:
            print("   ❌ No valid predictions")
    
    return None


def interpret_tactical_choice(last_shot, predicted_zone):
    """Provide tactical interpretation of predicted shot."""
    last_zone, last_type = last_shot
    
    # Simple tactical interpretation logic
    if predicted_zone in ['1', '2', '3']:
        if last_zone in ['6', '7', '8']:
            return "Cross-court shot to change direction"
        else:
            return "Down-the-line shot to maintain pressure"
    elif predicted_zone in ['6', '7', '8']:
        if last_zone in ['1', '2', '3']:
            return "Cross-court shot to open the court"
        else:
            return "Same-side shot to maintain rally"
    elif predicted_zone in ['4', '5']:
        return "Central shot to maintain neutral position"
    else:
        return "Defensive positioning"


# Analyze both models' tactical understanding
print("🎾 TACTICAL INTELLIGENCE COMPARISON")
print("=" * 80)

print("\n🤖 PLAYER-AWARE MODEL:")
analyze_tactical_patterns(player_aware_model, dataset, 'player_aware')

print("\n" + "=" * 80)
print("\n🧠 CONTEXT-ONLY MODEL:")
analyze_tactical_patterns(context_model, dataset, 'context_only')