# LSTM Ball Trajectory Prediction Video Overlay
- **LSTM predicted ball trajectory** (next 5 frames, about 0.17 s at 30 fps)
- **Actual ball position** for comparison
- **Prediction accuracy visualization**
- **Real-time error metrics**

In [None]:
import sys
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Report the runtime environment so results can be reproduced.
for label, value in (
    ("PyTorch version", torch.__version__),
    ("CUDA available", torch.cuda.is_available()),
):
    print(f"{label}: {value}")

##  Load LSTM Model and Data

In [None]:
# Import LSTM model architecture
from tennis_lstm import TennisLSTMModel, TennisLSTMAligned
from tennis_utils import read_video, MiniCourt
from create_tennis_video_overlay import TennisVideoOverlayCreator


In [None]:
# Load the trained LSTM model
print("📂 Loading LSTM model...")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize LSTM model with paper specifications
model = TennisLSTMModel(
    input_size=69,      # 69 features per frame
    lstm1_units=128,    # First LSTM layer
    lstm2_units=64,     # Second LSTM layer
    dropout_rate=0.2
).to(device)

# Candidate checkpoint files, tried in order; the first one that loads wins.
model_files = [
    'tennis_lstm_final_model.pth',
    'best_lstm_model.pth'
]

model_loaded = False
for model_file in model_files:
    if not os.path.exists(model_file):
        continue
    try:
        # NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
        state_dict = torch.load(model_file, map_location=device)
        model.load_state_dict(state_dict)
    except Exception as e:
        # Best effort: report the broken checkpoint and try the next candidate.
        print(f"⚠Error loading {model_file}: {e}")
        continue
    model_loaded = True
    print(f"LSTM model loaded from: {model_file}")
    break

if not model_loaded:
    print("No trained LSTM model found. Creating demo model with random weights.")

# Switch to inference mode unconditionally. The model is built with
# dropout_rate=0.2, so leaving the random-weight demo model in training mode
# would make its predictions change from run to run.
model.eval()

print(f" LSTM layers: {model.lstm1.hidden_size} → {model.lstm2.hidden_size} units")

## Load Tennis Data

In [None]:


# Load ball tracking data (an empty DataFrame stands in when the CSV is missing)
ball_data_file = "ball_tracking.csv"
if os.path.exists(ball_data_file):
    df_ball = pd.read_csv(ball_data_file)
    print(f" Ball data loaded: {len(df_ball)} records")
else:
    print(f" Ball tracking file not found: {ball_data_file}")
    df_ball = pd.DataFrame()

# Load player tracking data (same empty-DataFrame fallback)
player_data_file = "player_tracking.csv"
if os.path.exists(player_data_file):
    df_players = pd.read_csv(player_data_file)
    print(f" Player data loaded: {len(df_players)} records")
else:
    print(f" Player tracking file not found: {player_data_file}")
    df_players = pd.DataFrame()

# Check data availability
if len(df_ball) > 0:
    # A row counts as a detection when center_x is present.
    detected = df_ball['center_x'].notna()
    video_names = df_ball['video_name'].unique()

    print("\n Ball Data Summary:")
    print(f"Videos: {df_ball['video_name'].nunique()}")
    print(f" Total frames: {len(df_ball)}")
    print(f"Valid detections: {detected.sum()}")
    print(f"Detection rate: {detected.mean():.1%}")

    # Show sample videos (first five, in order of first appearance)
    print("\nAvailable videos:")
    for i, video in enumerate(video_names[:5]):
        in_video = df_ball['video_name'] == video
        print(f"   {i+1}. {video}: {detected[in_video].sum()}/{in_video.sum()} ball detections")

    if len(video_names) > 5:
        print(f" {len(video_names) - 5} more videos")
else:
    print("No ball data available for prediction")

## Create LSTM Prediction Pipeline

In [None]:
class TennisLSTMPredictor:
    """
    Real-time tennis ball trajectory predictor using LSTM.

    Builds engineered per-frame features from ball (and optionally player)
    tracking data and slices them into fixed-length input windows. It then
    runs the wrapped model on one window at a time to obtain future ball
    positions.
    """

    def __init__(self, model, device='cpu'):
        self.model = model
        self.device = device
        self.scaler = StandardScaler()
        self.sequence_length = 12  # Input frames
        self.prediction_frames = 5  # Output frames
        self.features_per_frame = 69  # Input width the model was built with
        self.is_fitted = False

    def create_features(self, df_ball, df_players=None):
        """
        Build per-frame engineered features for every video in ``df_ball``.

        Args:
            df_ball: ball tracking rows with at least ``video_name``,
                ``frame_number``, ``center_x`` and ``center_y`` columns.
                ``court_width_pixels`` / ``court_height_pixels`` default to
                1344 x 756 when absent.
            df_players: optional player tracking rows (``video_name``,
                ``frame_number``, ``player_id``, ``center_x``, ``center_y``).
                Only player 1 is used; missing positions fall back to
                (500, 400).

        Returns:
            ``(df_enhanced, feature_list)``: the per-video feature frame and
            the names of the model input columns. Returns
            ``(empty DataFrame, [])`` when no video has at least
            ``sequence_length + prediction_frames`` rows.
        """
        # Variables that additionally get rolling mean (3, 5) and rolling std (5) columns.
        # Defined once up front because the feature list below also needs it.
        base_vars = ['rel_x', 'rel_y', 'vx', 'vy', 'speed', 'acceleration', 'direction_change', 'player_dist']

        df_features = df_ball.copy()

        # Ensure required columns exist
        if 'court_width_pixels' not in df_features.columns:
            df_features['court_width_pixels'] = 1344.0
        if 'court_height_pixels' not in df_features.columns:
            df_features['court_height_pixels'] = 756.0

        # Add player positions if available
        if df_players is not None and len(df_players) > 0:
            # Get player 1 positions (first detection per frame)
            player1 = df_players[df_players['player_id'] == 1].groupby(['video_name', 'frame_number']).first()
            player1 = player1.reset_index()[['video_name', 'frame_number', 'center_x', 'center_y']]
            player1.columns = ['video_name', 'frame_number', 'player_1_center_x', 'player_1_center_y']

            df_features = df_features.merge(player1, on=['video_name', 'frame_number'], how='left')

        if 'player_1_center_x' not in df_features.columns:
            df_features['player_1_center_x'] = 500.0
        if 'player_1_center_y' not in df_features.columns:
            df_features['player_1_center_y'] = 400.0

        df_features['player_1_center_x'] = df_features['player_1_center_x'].fillna(500.0)
        df_features['player_1_center_y'] = df_features['player_1_center_y'].fillna(400.0)

        # Create time axis (assumes 30 fps footage)
        df_features = df_features.sort_values(['video_name', 'frame_number'])
        df_features['time_seconds'] = df_features['frame_number'] / 30.0

        # Calculate features by video
        enhanced_dfs = []

        for video_name in df_features['video_name'].unique():
            video_df = df_features[df_features['video_name'] == video_name].copy()

            if len(video_df) < self.sequence_length + self.prediction_frames:
                continue  # Skip videos that are too short

            # Time differential; zero gaps and the first row fall back to one 30 fps frame
            time_diff = video_df['time_seconds'].diff().replace(0, 0.0333).fillna(0.0333)

            # Court-relative positions
            video_df['rel_x'] = video_df['center_x'] / video_df['court_width_pixels']
            video_df['rel_y'] = video_df['center_y'] / video_df['court_height_pixels']

            # Velocity components (pixels / second)
            video_df['vx'] = video_df['center_x'].diff() / time_diff
            video_df['vy'] = video_df['center_y'].diff() / time_diff

            # Convert to m/s using a fixed, approximate scale
            pixels_per_meter = 100
            video_df['vx_ms'] = video_df['vx'] / pixels_per_meter
            video_df['vy_ms'] = video_df['vy'] / pixels_per_meter

            # Speed and acceleration
            video_df['speed'] = np.sqrt(video_df['vx']**2 + video_df['vy']**2)
            video_df['speed_ms'] = np.sqrt(video_df['vx_ms']**2 + video_df['vy_ms']**2)
            video_df['acceleration'] = np.sqrt(video_df['vx'].diff()**2 + video_df['vy'].diff()**2) / time_diff

            # Direction change
            video_df['direction_change'] = np.arctan2(video_df['vy'], video_df['vx']).diff()

            # Player distance
            video_df['player_dist'] = np.sqrt(
                (video_df['player_1_center_x'] - video_df['center_x'])**2 +
                (video_df['player_1_center_y'] - video_df['center_y'])**2
            )

            # Moving averages (16 features: 8 vars × 2 windows)
            for var in base_vars:
                video_df[f'{var}_ma3'] = video_df[var].rolling(3, min_periods=1).mean()
                video_df[f'{var}_ma5'] = video_df[var].rolling(5, min_periods=1).mean()

            # Rolling standard deviations (8 features); a single-sample window yields NaN -> 0
            for var in base_vars:
                video_df[f'{var}_std5'] = video_df[var].rolling(5, min_periods=1).std().fillna(0)

            # Court zone features (2 features)
            video_df['near_net'] = ((video_df['rel_x'] > 0.4) & (video_df['rel_x'] < 0.6)).astype(int)
            video_df['in_corner'] = ((video_df['rel_y'] < 0.2) | (video_df['rel_y'] > 0.8)).astype(int)

            # Additional hand-crafted features
            video_df['serve_rally_state'] = ((video_df['speed_ms'] > 15) & (video_df['rel_y'] < 0.5)).astype(int)
            video_df['trajectory_curvature'] = np.abs(video_df['direction_change']).fillna(0)
            video_df['vx_vy_ratio'] = np.abs(video_df['vx'] / (video_df['vy'] + 1e-6))
            video_df['speed_change'] = video_df['speed'].diff().fillna(0)
            video_df['distance_to_net'] = np.abs(video_df['rel_x'] - 0.5)
            video_df['distance_to_baseline'] = np.minimum(video_df['rel_y'], 1 - video_df['rel_y'])
            video_df['distance_to_sideline'] = np.minimum(video_df['rel_x'], 1 - video_df['rel_x'])
            video_df['frame_number_norm'] = (video_df['frame_number'] - video_df['frame_number'].min()) / len(video_df)
            video_df['time_in_rally'] = video_df['time_seconds'] - video_df['time_seconds'].iloc[0]

            # Smoothed features
            video_df['rel_x_smooth'] = video_df['rel_x'].rolling(7, min_periods=1).mean()
            video_df['rel_y_smooth'] = video_df['rel_y'].rolling(7, min_periods=1).mean()
            video_df['vx_smooth'] = video_df['vx'].rolling(5, min_periods=1).mean()
            video_df['vy_smooth'] = video_df['vy'].rolling(5, min_periods=1).mean()
            video_df['speed_smooth'] = video_df['speed'].rolling(5, min_periods=1).mean()
            video_df['kinetic_energy'] = 0.5 * video_df['speed_ms']**2
            video_df['speed_max5'] = video_df['speed'].rolling(5, min_periods=1).max()
            video_df['speed_min5'] = video_df['speed'].rolling(5, min_periods=1).min()

            enhanced_dfs.append(video_df)

        # Combine all videos
        if enhanced_dfs:
            df_enhanced = pd.concat(enhanced_dfs, ignore_index=True)
        else:
            print("No videos suitable for feature creation")
            return pd.DataFrame(), []

        # Model input columns (32 listed here)
        feature_list = [
            'rel_x', 'rel_y', 'vx', 'vy', 'vx_ms', 'vy_ms',
            'speed', 'speed_ms', 'acceleration', 'direction_change', 'speed_change',
            'player_dist', 'trajectory_curvature', 'vx_vy_ratio',
            'distance_to_net', 'distance_to_baseline', 'distance_to_sideline',
            'frame_number_norm', 'time_in_rally',
            'rel_x_smooth', 'rel_y_smooth', 'vx_smooth', 'vy_smooth', 'speed_smooth',
            'kinetic_energy', 'speed_max5', 'speed_min5',
            'near_net', 'in_corner', 'serve_rally_state',
            'court_width_pixels', 'court_height_pixels'
        ]

        # Add moving averages and std (8 vars × 3 = 24 more columns)
        for var in base_vars:
            feature_list.extend([f'{var}_ma3', f'{var}_ma5', f'{var}_std5'])

        # The list above holds 32 + 24 = 56 columns. Surface any mismatch with the
        # model's expected input width instead of silently assuming it matches.
        if len(feature_list) != self.features_per_frame:
            print(f"Warning: built {len(feature_list)} features but the model expects "
                  f"{self.features_per_frame} per frame")
        feature_list = feature_list[:self.features_per_frame]

        print(f"Created {len(feature_list)} features")
        print(f" Enhanced dataset: {len(df_enhanced)} frames")

        return df_enhanced, feature_list

    def prepare_sequences(self, df_enhanced, feature_cols):
        """
        Slice each video into (input window, future target) pairs.

        Windows or targets that contain any NaN are skipped.

        Args:
            df_enhanced: output frame of ``create_features``.
            feature_cols: model input column names.

        Returns:
            ``(X, y, metadata)``:
            - ``X`` has shape (n, sequence_length, len(feature_cols)).
            - ``y`` has shape (n, prediction_frames, 2) and holds
              ``center_x``/``center_y`` of the frames after each window.
            - ``metadata`` has one dict per sequence with the frame numbers
              of the window start, the first target and the last target.

            Returns empty arrays and an empty list when nothing qualifies.
        """
        print(f" Preparing prediction sequences...")

        sequences = []
        targets = []
        metadata = []

        for video_name in df_enhanced['video_name'].unique():
            video_data = df_enhanced[df_enhanced['video_name'] == video_name]
            n_rows = len(video_data)

            if n_rows < self.sequence_length + self.prediction_frames:
                continue

            # Extract the columns once per video; per-window .iloc lookups are slow.
            features = video_data[feature_cols].to_numpy()
            coords = video_data[['center_x', 'center_y']].to_numpy()
            frame_numbers = video_data['frame_number'].to_numpy()

            for i in range(self.sequence_length, n_rows - self.prediction_frames + 1):
                # Input window: sequence_length frames × all features
                x_seq = features[i - self.sequence_length:i]
                # Target: the next prediction_frames ball positions
                y_seq = coords[i:i + self.prediction_frames]

                # Skip if any NaN values
                if np.isnan(x_seq).any() or np.isnan(y_seq).any():
                    continue

                sequences.append(x_seq)
                targets.append(y_seq)
                metadata.append({
                    'video_name': video_name,
                    'start_frame': frame_numbers[i - self.sequence_length],
                    'prediction_frame': frame_numbers[i],
                    'end_frame': frame_numbers[i + self.prediction_frames - 1]
                })

        if sequences:
            X = np.array(sequences)
            y = np.array(targets)

            print(f" Created {len(sequences)} sequences")
            print(f" Input shape: {X.shape} (batch, seq_len, features)")
            print(f" Target shape: {y.shape} (batch, pred_frames, coordinates)")

            return X, y, metadata
        else:
            print(f"No valid sequences created")
            return np.array([]), np.array([]), []

    def fit_scaler(self, X):
        """
        Fit the feature scaler on every frame of every sequence in ``X``.

        Args:
            X: array of shape (batch, seq_len, features); an empty array is a no-op.
        """
        if len(X) > 0:
            # Reshape for normalization: (batch*seq_len, features)
            X_reshaped = X.reshape(-1, X.shape[-1])
            self.scaler.fit(X_reshaped)
            self.is_fitted = True
            print(f"Scaler fitted on {X_reshaped.shape[0]} samples")
        else:
            print(f"No data ")

    def predict_trajectory(self, X_sequence):
        """
        Predict ball trajectory for a single input window.

        Args:
            X_sequence: 2-D array (seq_len, features), e.g. one element of the
                ``X`` returned by ``prepare_sequences``. It is scaled with the
                fitted scaler when available, otherwise used as-is.

        Returns:
            The model output for the window, with the batch dimension removed,
            as a NumPy array. Assumed to be (pred_frames, 2) pixel coordinates;
            TODO confirm against the model definition.
        """
        if not self.is_fitted:
            print("Scaler not fitted. Using unscaled data.")
            X_scaled = X_sequence
        else:
            X_scaled = self.scaler.transform(X_sequence)

        # Add batch dimension and convert to tensor: (1, seq_len, features)
        X_tensor = torch.FloatTensor(X_scaled).unsqueeze(0).to(self.device)

        with torch.no_grad():
            prediction = self.model(X_tensor)
            prediction = prediction.cpu().numpy()[0]

        return prediction


# Wrap the loaded model in the feature/sequence pipeline on the selected device.
predictor = TennisLSTMPredictor(model=model, device=device)
print("LSTM Predictor initialized")

## Prepare Data for Prediction

In [None]:
# Prepare data for LSTM prediction.
# Start from empty defaults so later cells can always test len(X) / len(metadata)
# and read demo_idx without a NameError when any stage below produces nothing.
X, y, metadata = np.array([]), np.array([]), []
demo_idx, demo_meta = 0, None

if len(df_ball) > 0:
    # Keep only frames with a detected ball position
    ball_valid = df_ball.dropna(subset=['center_x', 'center_y']).copy()
    print(f"Valid ball positions: {len(ball_valid)} / {len(df_ball)}")

    # Create features
    df_enhanced, feature_cols = predictor.create_features(ball_valid, df_players)

    if len(df_enhanced) > 0:
        # Prepare sequences
        X, y, metadata = predictor.prepare_sequences(df_enhanced, feature_cols)

        if len(X) > 0:
            # Fit scaler
            predictor.fit_scaler(X)

            print(f"Sequences: {len(X)}")
            print(f"Videos: {len(set(m['video_name'] for m in metadata))}")
            print(f"Features per frame: {len(feature_cols)}")

            # Pick the first sequence whose window + target spans at least 15 frames
            demo_idx, demo_meta = next(
                ((i, meta) for i, meta in enumerate(metadata)
                 if meta['end_frame'] - meta['start_frame'] >= 15),
                (0, None),
            )

            if demo_meta is not None:
                print(f" Video: {demo_meta['video_name']}")
                print(f"Frames: {demo_meta['start_frame']} → {demo_meta['end_frame']}")
                print(f"Prediction starts at frame: {demo_meta['prediction_frame']}")
            else:
                print(f"\n No suitable sequences found for demonstration")
                demo_idx, demo_meta = 0, metadata[0]
        else:
            print(f"No valid sequences created")
    else:
        print(f"Feature creation failed")
else:
    print(f"No ball data available for prediction")

## Create Prediction Video Overlay

In [None]:
class LSTMPredictionVideoOverlay(TennisVideoOverlayCreator):
    """
    Extended video overlay creator with LSTM predictions.

    Colours in ``pred_colors`` are in OpenCV's BGR channel order. The frames
    are drawn with cv2 and written with cv2.VideoWriter, which both use BGR.
    """

    def __init__(self, predictor, court_measurements=None):
        super().__init__(court_measurements)
        self.predictor = predictor

        # Prediction visualization colors (BGR, as cv2 drawing expects)
        self.pred_colors = {
            'predicted_path': (0, 165, 255),    # Orange
            'actual_path': (0, 255, 0),         # Green
            'prediction_point': (255, 0, 255),  # Magenta
            'error_line': (0, 255, 255),        # Yellow
            'confidence_zone': (255, 128, 128)  # Light blue
        }

    @staticmethod
    def _is_valid(point):
        """Return True when neither coordinate of ``point`` is NaN."""
        return not (np.isnan(point[0]) or np.isnan(point[1]))

    @staticmethod
    def _to_pixel(point):
        """Truncate an (x, y) point to integer pixel coordinates for cv2."""
        return (int(point[0]), int(point[1]))

    def draw_prediction_overlay(self, frame, current_frame, ball_history, prediction, actual_future=None):
        """
        Draw the LSTM prediction overlay on a copy of ``frame``.

        Args:
            frame: image to annotate; it is not modified.
            current_frame: frame index shown in the text header.
            ball_history: sequence of past (x, y) ball positions, oldest first.
                The last entry is drawn as the current position.
            prediction: sequence of predicted (x, y) positions; may be empty.
            actual_future: optional ground-truth (x, y) positions, aligned
                index-by-index with ``prediction``.

        Returns:
            The annotated copy of ``frame``.
        """
        overlay = frame.copy()

        # Ball history trail, with older segments faded
        if len(ball_history) > 1:
            history_points = [self._to_pixel(p) for p in ball_history if self._is_valid(p)]

            for i in range(1, len(history_points)):
                alpha = 0.3 + 0.7 * (i / len(history_points))
                color = tuple(int(c * alpha) for c in self.colors['ball'])
                cv2.line(overlay, history_points[i-1], history_points[i], color, 2)

        # Current ball position (last history entry), when valid
        current_point = None
        if len(ball_history) > 0 and self._is_valid(ball_history[-1]):
            current_point = self._to_pixel(ball_history[-1])
            cv2.circle(overlay, current_point, 8, self.colors['ball'], -1)
            cv2.circle(overlay, current_point, 10, (255, 255, 255), 2)

        # Predicted trajectory. pred_points is defined up front so it exists
        # even when prediction is empty.
        pred_points = []
        if len(prediction) > 0:
            pred_points = [self._to_pixel(p) for p in prediction if self._is_valid(p)]

            # Connect current position to first prediction
            if current_point is not None and pred_points:
                cv2.line(overlay, current_point, pred_points[0], self.pred_colors['predicted_path'], 3)

            # Prediction path
            for start, end in zip(pred_points, pred_points[1:]):
                cv2.line(overlay, start, end, self.pred_colors['predicted_path'], 3)

            # Prediction points; radius shrinks the further ahead the step is
            for i, point in enumerate(pred_points):
                radius = max(4, 8 - i)
                cv2.circle(overlay, point, radius, self.pred_colors['prediction_point'], -1)
                cv2.circle(overlay, point, radius + 2, (255, 255, 255), 1)

                frame_text = f"+{i+1}"
                cv2.putText(overlay, frame_text, (point[0] + 12, point[1] - 12),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

        # Actual future positions, if available
        if actual_future is not None and len(actual_future) > 0:
            actual_points = [self._to_pixel(p) for p in actual_future if self._is_valid(p)]

            for start, end in zip(actual_points, actual_points[1:]):
                cv2.line(overlay, start, end, self.pred_colors['actual_path'], 2)

            for point in actual_points:
                cv2.circle(overlay, point, 4, self.pred_colors['actual_path'], -1)

            # Error lines join each prediction to the ground truth for the same
            # future step. Pairs are formed before NaN filtering, so a missing
            # point cannot shift the pairing of later steps.
            for pred, actual in zip(prediction, actual_future):
                if self._is_valid(pred) and self._is_valid(actual):
                    cv2.line(overlay, self._to_pixel(pred), self._to_pixel(actual),
                             self.pred_colors['error_line'], 1)

        # Text header
        info_y = 30
        cv2.putText(overlay, f"LSTM Ball Trajectory Prediction", (10, info_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        info_y += 25
        cv2.putText(overlay, f"Current Frame: {current_frame}", (10, info_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        if len(prediction) > 0:
            info_y += 20
            # Horizon in seconds assumes 30 fps footage
            cv2.putText(overlay, f"Predicting {len(prediction)} frames ahead (~{len(prediction)/30:.1f}s)", (10, info_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        return overlay

    def add_prediction_info_panel(self, combined_frame, prediction, actual_future, frame_number, errors=None):
        """
        Build a 200-px-tall prediction analysis panel.

        Args:
            combined_frame: output frame; only its width is used.
            prediction: predicted (x, y) positions; the first three are listed.
            actual_future: ground-truth positions; the accuracy column is shown
                only when this is non-empty.
            frame_number: frame index shown in the title.
            errors: optional per-step pixel distances between prediction and
                ground truth.

        Returns:
            A new uint8 image of shape (200, combined_frame width, 3).
        """
        panel_height = 200
        panel_width = combined_frame.shape[1]

        info_panel = np.zeros((panel_height, panel_width, 3), dtype=np.uint8)

        # Title
        cv2.putText(info_panel, f"Frame {frame_number} - LSTM Prediction Analysis",
                    (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, self.colors['text'], 2)

        # Prediction information (left column)
        y_offset = 60
        if len(prediction) > 0:
            cv2.putText(info_panel, "TRAJECTORY PREDICTION:", (10, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, self.pred_colors['predicted_path'], 2)
            y_offset += 25

            for i, (x, y) in enumerate(prediction[:3]):
                cv2.putText(info_panel, f"Frame +{i+1}: ({x:.0f}, {y:.0f})",
                            (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.colors['text'], 1)
                y_offset += 18

            if len(prediction) > 3:
                cv2.putText(info_panel, f"... +{len(prediction)} total predictions",
                            (20, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.colors['text'], 1)

        # Actual vs predicted comparison (right column)
        y_offset = 60
        x_offset = 400
        if actual_future is not None and len(actual_future) > 0:
            cv2.putText(info_panel, "PREDICTION ACCURACY:", (x_offset, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, self.pred_colors['actual_path'], 2)
            y_offset += 25

            if errors is not None and len(errors) > 0:
                avg_error = np.mean(errors)
                max_error = np.max(errors)

                cv2.putText(info_panel, f"Average Error: {avg_error:.1f} pixels",
                            (x_offset + 10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.colors['text'], 1)
                y_offset += 18

                cv2.putText(info_panel, f"Max Error: {max_error:.1f} pixels",
                            (x_offset + 10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.colors['text'], 1)
                y_offset += 18

                # Error in meters, using a fixed approximate scale
                pixels_per_meter = 100
                cv2.putText(info_panel, f"Avg Error: {avg_error/pixels_per_meter:.2f} meters",
                            (x_offset + 10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.5, self.colors['text'], 1)

        # Legend
        legend_y = 140
        cv2.putText(info_panel, "LEGEND:", (10, legend_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, self.colors['text'], 2)
        legend_y += 20

        legends = [
            ("Orange: LSTM Prediction", self.pred_colors['predicted_path']),
            ("Green: Actual Position", self.pred_colors['actual_path']),
            ("Yellow: Error", self.pred_colors['error_line'])
        ]

        for i, (text, color) in enumerate(legends):
            x_pos = 20 + (i * 200)
            cv2.rectangle(info_panel, (x_pos, legend_y-8), (x_pos+15, legend_y+2), color, -1)
            cv2.putText(info_panel, text, (x_pos+20, legend_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, self.colors['text'], 1)

        return info_panel

print("LSTM Prediction Video Overlay class created")

## Create Prediction Video

In [None]:
# Create LSTM prediction video overlay
if len(X) > 0 and len(metadata) > 0:
    demo_meta = metadata[demo_idx]
    demo_video = demo_meta['video_name']
    demo_video_path = f"input_videos/{demo_video}"
    
    print(f"Demo video: {demo_video}")
    print(f"Prediction sequence: frames {demo_meta['start_frame']} to {demo_meta['end_frame']}")
    
    # Check if video exists
    if os.path.exists(demo_video_path):
        # Get ball data for this video
        video_ball_data = df_ball[df_ball['video_name'] == demo_video].copy()
        video_player_data = df_players[df_players['video_name'] == demo_video].copy() if len(df_players) > 0 else pd.DataFrame()
        
        print(f"   Ball data: {len(video_ball_data)} frames")
        print(f"   Player data: {len(video_player_data)} detections")

        overlay_creator = LSTMPredictionVideoOverlay(
            predictor=predictor,
            court_measurements={
                'single_line_width': 8.23,
                'double_line_width': 10.97,
                'half_court_height': 11.88,
                'service_line_width': 6.4,
                'double_alley_difference': 1.37,
                'no_mans_land_height': 5.48
            }
        )
        
        # Create output directory
        os.makedirs('prediction_videos', exist_ok=True)
        
        # Generate output filename
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_filename = f"prediction_videos/lstm_prediction_{demo_video.replace('.mp4', '')}_{timestamp}.mp4"
        
        try:

            print(f"Loading video frames")
            video_frames = read_video(demo_video_path)
            
            if len(video_frames) == 0:
                raise ValueError("No frames loaded from video")
            
            print(f"   Loaded {len(video_frames)} frames")
            

            fps = 30
            frame_width = video_frames[0].shape[1]
            frame_height = video_frames[0].shape[0]

            mini_court_width = 300
            mini_court_height = 150
            main_width = frame_width + mini_court_width
            main_height = max(frame_height, mini_court_height)
            info_panel_height = 200
            output_width = main_width
            output_height = main_height + info_panel_height
            
            print(f"Output dimensions: {output_width}x{output_height}")

            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_filename, fourcc, fps, (output_width, output_height))
            
            # Process frames for prediction
            start_frame = max(0, demo_meta['prediction_frame'] - 30)  # Start 1 second before prediction
            end_frame = min(len(video_frames), demo_meta['prediction_frame'] + 90)  # Show 3 seconds of prediction
            
            print(f"Processing frames {start_frame} to {end_frame}")
            
            # Get prediction sequence
            prediction_sequence = X[demo_idx]  # (12, 69)
            actual_targets = y[demo_idx]  # (5, 2)
            
            # Make prediction
            predicted_positions = predictor.predict_trajectory(prediction_sequence)
            print(f"Generated prediction: {predicted_positions.shape}")
            
            # Create mini court instance
            mini_court = MiniCourt(video_frames[start_frame])
            
            processed_frames = 0
            
            for frame_idx in range(start_frame, end_frame):
                if frame_idx >= len(video_frames):
                    break
                
                current_frame = video_frames[frame_idx].copy()
                
                # Create output frame
                output_frame = np.zeros((output_height, output_width, 3), dtype=np.uint8)
                
                # Get ball history up to current frame
                history_frames = video_ball_data[
                    (video_ball_data['frame_number'] >= frame_idx - 12) & 
                    (video_ball_data['frame_number'] <= frame_idx)
                ]
                
                ball_history = []
                for _, row in history_frames.iterrows():
                    if not pd.isna(row['center_x']):
                        ball_history.append((row['center_x'], row['center_y']))

                actual_future = None
                if frame_idx >= demo_meta['prediction_frame']:

                    future_offset = frame_idx - demo_meta['prediction_frame']
                    if future_offset < len(actual_targets):
                        actual_future = actual_targets[:future_offset+1]
                

                prediction_to_show = []
                if frame_idx >= demo_meta['prediction_frame']:
                    prediction_to_show = predicted_positions
                
                # Apply prediction overlay
                frame_with_prediction = overlay_creator.draw_prediction_overlay(
                    current_frame, frame_idx, ball_history, prediction_to_show, actual_future
                )
                
                # Add to output frame
                output_frame[0:frame_height, 0:frame_width] = frame_with_prediction
                
                # Add mini court
                mini_court_template = overlay_creator.create_mini_court(mini_court_width, mini_court_height)
                
                # Draw positions on mini court
                if len(ball_history) > 0:
                    current_pos = ball_history[-1]
                    try:
                        mini_pos = mini_court.convert_position_to_mini_court(current_pos)
                        mini_x = int(mini_pos[0] * mini_court_width / mini_court.court_drawing_width)
                        mini_y = int(mini_pos[1] * mini_court_height / (mini_court.court_drawing_width * 0.5))
                        cv2.circle(mini_court_template, (mini_x, mini_y), 4, overlay_creator.colors['ball'], -1)
                    except:
                        pass
                
                # Add mini court to output
                mini_start_x = frame_width
                mini_start_y = (main_height - mini_court_height) // 2
                output_frame[mini_start_y:mini_start_y + mini_court_height,
                           mini_start_x:mini_start_x + mini_court_width] = mini_court_template
                
                # Add mini court title
                cv2.putText(output_frame, "Mini Court View",
                           (mini_start_x + 10, mini_start_y - 10),
                           cv2.FONT_HERSHEY_SIMPLEX, 0.7, overlay_creator.colors['text'], 2)
                
                # Calculate prediction errors
                errors = None
                if actual_future is not None and len(prediction_to_show) > 0:
                    min_len = min(len(actual_future), len(prediction_to_show))
                    errors = []
                    for i in range(min_len):
                        error = np.sqrt((actual_future[i][0] - prediction_to_show[i][0])**2 +
                                      (actual_future[i][1] - prediction_to_show[i][1])**2)
                        errors.append(error)
                
                # Add prediction info panel
                info_panel = overlay_creator.add_prediction_info_panel(
                    output_frame, prediction_to_show, actual_future, frame_idx, errors
                )
                output_frame[main_height:main_height + info_panel_height, :] = info_panel
                
                # Write frame
                out.write(output_frame)
                processed_frames += 1
                
                # Progress update
                if processed_frames % 10 == 0:
                    progress = (processed_frames / (end_frame - start_frame)) * 100
                    print(f"Progress: {progress:.1f}% ({processed_frames}/{end_frame - start_frame} frames)")
            
            # Cleanup
            out.release()

            print(f"Output: {output_filename}")
            print(f"Processed: {processed_frames} frames")

            
            # Try to display a sample frame
            try:
                cap = cv2.VideoCapture(output_filename)
                if cap.isOpened():
                    cap.set(cv2.CAP_PROP_POS_FRAMES, processed_frames // 2)
                    ret, sample_frame = cap.read()
                    
                    if ret:
                        sample_frame_rgb = cv2.cvtColor(sample_frame, cv2.COLOR_BGR2RGB)
                        
                        plt.figure(figsize=(16, 10))
                        plt.imshow(sample_frame_rgb)
                        plt.title('LSTM Ball Trajectory Prediction - Sample Frame\nOrange: Predicted | Green: Actual | Yellow: Error', 
                                 fontsize=14, fontweight='bold')
                        plt.axis('off')
                        plt.tight_layout()
                        plt.show()
                        
                        print(f"\nSample frame shows the LSTM prediction overlay in action")
                    
                    cap.release()
            except Exception as e:
                print(f"Could not display sample frame: {e}")
        
        except Exception as e:
            print(f"Error creating prediction video: {e}")
            import traceback
            traceback.print_exc()
    
    else:
        print(f" Demo video not found: {demo_video_path}")

else:
    print(f" No prediction data available for video creation")


## Analysis and Evaluation

In [None]:

# --- Evaluate LSTM trajectory predictions on a random sample of sequences ---
# Relies on X (input sequences), y (target trajectories), metadata and
# predictor, all defined in earlier cells of this notebook.

# Frame rate used to convert "frames ahead" into seconds.
EVAL_FPS = 30.0
# NOTE(review): assumed pixel-to-metre scale (100 px ~ 1 m), carried over from the
# original printout. It is not derived from the court calibration; confirm it
# before relying on the metre figures.
EVAL_PIXELS_PER_METER = 100.0
# Number of leading prediction frames broken out in the per-horizon report/plot.
EVAL_HORIZON_FRAMES = 5
# Upper bound on how many sequences are sampled for evaluation.
EVAL_MAX_SEQUENCES = 10


def _trajectory_errors(predicted, actual):
    """Return per-frame Euclidean distances between two (x, y) trajectories.

    Frames are paired in order, and only the overlapping prefix is compared
    (``zip`` stops at the shorter trajectory). A prediction that is longer or
    shorter than its target therefore no longer raises ``IndexError``. Each
    frame is read as ``frame[0]`` (x) and ``frame[1]`` (y); any extra
    components are ignored.

    Returns:
        list[float]: one distance per compared frame; empty if either input is empty.
    """
    return [
        float(np.hypot(p[0] - a[0], p[1] - a[1]))
        for p, a in zip(predicted, actual)
    ]


def _errors_by_horizon(analyses, n_frames):
    """Group per-frame errors by how many frames ahead they were predicted.

    Returns a list of ``n_frames`` lists. Entry ``k`` holds every analysed
    sequence's error at prediction step ``k``; a sequence with fewer than
    ``k + 1`` compared frames contributes nothing to that entry.
    """
    buckets = [[] for _ in range(n_frames)]
    for analysis in analyses:
        for step, err in enumerate(analysis['errors_by_frame'][:n_frames]):
            buckets[step].append(err)
    return buckets


if len(X) > 0 and len(y) > 0:

    # Sample without replacement. This uses the global NumPy RNG, so an earlier
    # np.random.seed(...) call still makes the sample reproducible.
    sample_size = min(EVAL_MAX_SEQUENCES, len(X))
    sample_indices = np.random.choice(len(X), sample_size, replace=False)

    all_errors = []
    prediction_analysis = []

    print(f"Testing LSTM on {sample_size} sequences")

    for i, idx in enumerate(sample_indices):
        meta = metadata[idx]
        prediction = predictor.predict_trajectory(X[idx])
        seq_errors = _trajectory_errors(prediction, y[idx])

        # With no overlapping frames there is nothing to score. np.max on an
        # empty list would raise, so skip the sequence explicitly.
        if not seq_errors:
            print(f"⚠ Skipping sequence {idx} ({meta['video_name']}): no comparable frames")
            continue

        all_errors.extend(seq_errors)
        seq_mean = np.mean(seq_errors)
        seq_max = np.max(seq_errors)

        prediction_analysis.append({
            'video': meta['video_name'],
            'start_frame': meta['start_frame'],
            'prediction_frame': meta['prediction_frame'],
            'avg_error_pixels': seq_mean,
            'max_error_pixels': seq_max,
            'errors_by_frame': seq_errors,
        })

        if i < 3:  # Show details for the first 3 sampled sequences
            print(f"\n📍 Sequence {i+1} ({meta['video_name']})")
            print(f"   Frames: {meta['start_frame']} → {meta['prediction_frame']} → {meta['end_frame']}")
            print(f"   Avg Error: {seq_mean:.1f} pixels ({seq_mean / EVAL_PIXELS_PER_METER:.2f}m)")
            print(f"   Max Error: {seq_max:.1f} pixels ({seq_max / EVAL_PIXELS_PER_METER:.2f}m)")

    # Summary statistics are only meaningful (and np.max only safe) when at
    # least one frame was actually compared.
    if all_errors:
        mean_err = np.mean(all_errors)
        median_err = np.median(all_errors)
        max_err = np.max(all_errors)

        print("\nPREDICTION PERFORMANCE:")
        print(f"   Sequences tested: {len(prediction_analysis)} of {sample_size}")
        print(f"   Average error: {mean_err:.1f} pixels ({mean_err / EVAL_PIXELS_PER_METER:.2f} meters)")
        print(f"   Median error: {median_err:.1f} pixels ({median_err / EVAL_PIXELS_PER_METER:.2f} meters)")
        print(f"   Max error: {max_err:.1f} pixels ({max_err / EVAL_PIXELS_PER_METER:.2f} meters)")
        print(f"   Error std: {np.std(all_errors):.1f} pixels")

        # Computed once and reused by both the text report and the bar chart.
        horizon_errors = _errors_by_horizon(prediction_analysis, EVAL_HORIZON_FRAMES)

        print("\nError by Prediction Horizon:")
        for step, step_errors in enumerate(horizon_errors, start=1):
            if step_errors:
                avg_error = np.mean(step_errors)
                print(f"   +{step} frame ({step / EVAL_FPS:.2f}s): {avg_error:.1f} pixels "
                      f"({avg_error / EVAL_PIXELS_PER_METER:.2f}m)")

        print("\n Error distribution plot")

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Error histogram
        ax1.hist(all_errors, bins=20, edgecolor='black', alpha=0.7)
        ax1.set_title('LSTM Prediction Error Distribution', fontweight='bold')
        ax1.set_xlabel('Error (pixels)')
        ax1.set_ylabel('Frequency')
        ax1.axvline(mean_err, color='red', linestyle='--',
                    label=f'Mean: {mean_err:.1f}px')
        ax1.axvline(median_err, color='green', linestyle='--',
                    label=f'Median: {median_err:.1f}px')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Error by horizon. A horizon with no data is drawn as a 0-height bar,
        # so read an empty bar as "no data", not "no error".
        horizon_means = [np.mean(errs) if errs else 0 for errs in horizon_errors]
        horizon_stds = [np.std(errs) if errs else 0 for errs in horizon_errors]

        x_pos = np.arange(1, EVAL_HORIZON_FRAMES + 1)
        ax2.bar(x_pos, horizon_means, yerr=horizon_stds, capsize=5,
                edgecolor='black', alpha=0.7)
        ax2.set_title('Error by Prediction Horizon', fontweight='bold')
        ax2.set_xlabel('Frames Ahead')
        ax2.set_ylabel('Average Error (pixels)')
        ax2.set_xticks(x_pos)
        ax2.grid(True, alpha=0.3)

        # Secondary x-axis with the same ticks expressed in seconds
        ax2_top = ax2.twiny()
        ax2_top.set_xlim(ax2.get_xlim())
        ax2_top.set_xticks(x_pos)
        ax2_top.set_xticklabels([f'{step / EVAL_FPS:.2f}s' for step in x_pos])
        ax2_top.set_xlabel('Time Ahead (seconds)')

        plt.tight_layout()
        plt.show()

        print("Error analysis complete!")
    else:
        print("No comparable prediction/target frames in the sample; nothing to evaluate")

else:
    print("No data available")