<a href="https://colab.research.google.com/github/WaliSiddiqui1/BPM/blob/main/Training_NSVA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Make the Drive-hosted NSVA data and checkpoints visible under /content/drive.
print("Mounting Google Drive...")
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Mounting Google Drive...
Mounted at /content/drive


In [None]:
# NBA Sports Video Analysis - Simplified Training
# Adapted for Google Colab environment

# ===== 1. Setup and Dependencies =====

!pip install -q webvtt-py
!pip install -q rouge
!pip install -q nltk
!pip install -q transformers


import os
import sys
import numpy as np
import pandas as pd
import json
import cv2
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from datetime import datetime
from transformers import BertTokenizer
import nltk
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu
from rouge import Rouge

# Fetch the NLTK "punkt" tokenizer data package.
nltk.download('punkt')

# ===== 2. Directory Setup =====
# Every input and output of this notebook lives under one Drive folder.

BASE_DIR = '/content/drive/MyDrive/NSVA_Results/'
FEATURES_DIR, ANNOTATIONS_DIR, METADATA_DIR, CHECKPOINTS_DIR, RESULTS_DIR = (
    os.path.join(BASE_DIR, 'features'),
    os.path.join(BASE_DIR, 'annotations'),
    os.path.join(BASE_DIR, 'metadata'),
    os.path.join(BASE_DIR, 'checkpoints'),
    os.path.join(BASE_DIR, 'results'),
)

# exist_ok=True makes re-running this cell harmless.
for directory in (FEATURES_DIR, ANNOTATIONS_DIR, METADATA_DIR, CHECKPOINTS_DIR, RESULTS_DIR):
    os.makedirs(directory, exist_ok=True)

# ===== 3. Define Simplified Model Architecture =====
# Simplified version of the model described in the paper

class SimplifiedSportsVideoUnderstandingModel(keras.Model):
    """
    A simplified version of the Sports Video Understanding model from the paper.
    This model uses pre-extracted features and focuses on:
    1. Encoding global video features (e.g., from ResNet or pre-extracted TimeSformer)
    2. Incorporating object-level features (ball, player, basket)
    3. Generating captions using a transformer-based decoder

    The forward pass is split into ``_encode`` (feature projection + fusion)
    and ``_decode`` (caption decoder). ``call`` runs both. ``generate_caption``
    runs only ``_decode`` against a precomputed encoder output.
    """
    def __init__(
        self,
        vocab_size,
        max_frames=100,
        max_words=30,
        embed_dim=256,  # Reduced from 768 in original
        num_heads=4,    # Reduced from 12 in original
        decoder_layers=2,  # Reduced from 3 in original
        dropout_rate=0.1,
        **kwargs
    ):
        super(SimplifiedSportsVideoUnderstandingModel, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_frames = max_frames
        self.max_words = max_words
        self.embed_dim = embed_dim

        # Feature projection layers (project pre-extracted features to common dimension)
        self.video_projection = keras.layers.Dense(embed_dim)
        self.ball_projection = keras.layers.Dense(embed_dim)
        self.player_projection = keras.layers.Dense(embed_dim)
        self.basket_projection = keras.layers.Dense(embed_dim)
        self.court_projection = keras.layers.Dense(embed_dim)

        # Temporal aggregation (simpler than TimeSformer)
        self.temporal_aggregation = keras.Sequential([
            keras.layers.LayerNormalization(epsilon=1e-6),
            keras.layers.Bidirectional(keras.layers.LSTM(embed_dim // 2, return_sequences=True)),
            keras.layers.LayerNormalization(epsilon=1e-6)
        ])

        # Multi-modal fusion
        self.fusion_layer = keras.Sequential([
            keras.layers.LayerNormalization(epsilon=1e-6),
            keras.layers.Dense(embed_dim * 2, activation='relu'),
            keras.layers.Dropout(dropout_rate),
            keras.layers.Dense(embed_dim),
            keras.layers.LayerNormalization(epsilon=1e-6)
        ])

        # Decoder for caption generation
        self.decoder_embedding = keras.layers.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(max_words, embed_dim)

        # Multi-head attention layers for decoder
        self.decoder_layers = []
        for i in range(decoder_layers):
            self.decoder_layers.append(DecoderLayer(
                embed_dim, num_heads, embed_dim * 4, dropout_rate
            ))

        self.decoder_layernorm = keras.layers.LayerNormalization(epsilon=1e-6)
        self.output_projection = keras.layers.Dense(vocab_size)

        # Task-specific heads for additional tasks
        self.action_classifier = keras.layers.Dense(14)  # 14 action classes
        self.player_classifier = keras.layers.Dense(184)  # 184 player identities

    def _encode(self, video_features, ball_features, player_features,
                basket_features, court_features, training=False):
        """Project each feature stream, sum them, and fuse over time.

        Args:
            video_features, ball_features, basket_features, court_features:
                Per-frame features; the leading two axes are presumably
                (batch, frames) - TODO confirm against the feature extractor.
            player_features: Per-frame, per-player features. The reshape below
                fixes the leading axes as (batch, frames, players, ...).
            training: Forwarded to dropout / normalisation layers.

        Returns:
            Fused encoder features of shape (batch, frames, embed_dim).
        """
        video_embed = self.video_projection(video_features)
        ball_embed = self.ball_projection(ball_features)

        # Flatten (batch, frames) so the projection sees one row per player,
        # then average the projected players back into one vector per frame.
        batch_size = tf.shape(player_features)[0]
        frames = tf.shape(player_features)[1]
        player_features_flat = tf.reshape(
            player_features, [batch_size * frames, -1, tf.shape(player_features)[-1]]
        )
        player_embed_flat = self.player_projection(player_features_flat)
        player_embed_flat = tf.reduce_mean(player_embed_flat, axis=1)
        player_embed = tf.reshape(player_embed_flat, [batch_size, frames, self.embed_dim])

        basket_embed = self.basket_projection(basket_features)
        court_embed = self.court_projection(court_features)

        combined_features = video_embed + ball_embed + player_embed + basket_embed + court_embed
        sequence_features = self.temporal_aggregation(combined_features, training=training)
        return self.fusion_layer(sequence_features, training=training)

    def _decode(self, encoder_output, input_caption_ids, encoder_mask=None,
                decoder_mask=None, training=False):
        """Run the caption decoder and return vocabulary logits.

        Args:
            encoder_output: Output of ``_encode``, (batch, frames, embed_dim).
            input_caption_ids: Integer token ids, (batch, length). The length
                must not exceed ``max_words`` (the positional table size).
            encoder_mask: Optional (batch, frames) mask, 1 = real frame. When
                None, cross-attention may look at every frame.
            decoder_mask: Optional (batch, length) token mask. It is forwarded
                to the decoder layers as ``padding_mask``.
            training: Forwarded to dropout layers.

        Returns:
            Logits of shape (batch, length, vocab_size).
        """
        decoder_output = self.decoder_embedding(input_caption_ids)
        decoder_output = self.positional_encoding(decoder_output)

        look_ahead_mask = create_look_ahead_mask(tf.shape(input_caption_ids)[1])
        padding_mask = None if decoder_mask is None else create_padding_mask(decoder_mask)
        encoder_padding_mask = None if encoder_mask is None else create_padding_mask(encoder_mask)

        for decoder_layer in self.decoder_layers:
            decoder_output = decoder_layer(
                decoder_output,
                encoder_output,
                look_ahead_mask=look_ahead_mask,
                padding_mask=padding_mask,
                encoder_padding_mask=encoder_padding_mask,
                training=training
            )

        decoder_output = self.decoder_layernorm(decoder_output)
        return self.output_projection(decoder_output)

    def call(self, inputs, training=False):
        """Forward pass through the model.

        Args:
            inputs: 8-tuple ``(video_features, video_mask, ball_features,
                player_features, basket_features, court_features,
                input_caption_ids, decoder_mask)``.
            training: Whether dropout etc. are active.

        Returns:
            Dict with ``caption_logits`` (batch, length, vocab),
            ``action_logits`` (batch, 14), ``player_logits`` (batch, 184) and
            ``encoder_output`` (batch, frames, embed_dim).
        """
        (video_features, video_mask,
         ball_features, player_features, basket_features, court_features,
         input_caption_ids, decoder_mask) = inputs

        fused_features = self._encode(
            video_features, ball_features, player_features,
            basket_features, court_features, training=training
        )

        caption_logits = self._decode(
            fused_features, input_caption_ids,
            encoder_mask=video_mask, decoder_mask=decoder_mask, training=training
        )

        # Feature for action and player recognition (first frame of encoder output)
        sequence_representation = fused_features[:, 0, :]
        action_logits = self.action_classifier(sequence_representation)
        player_logits = self.player_classifier(sequence_representation)

        return {
            "caption_logits": caption_logits,
            "action_logits": action_logits,
            "player_logits": player_logits,
            "encoder_output": fused_features
        }

    def generate_caption(self, encoder_output, tokenizer, max_length=30, beam_size=3,
                         encoder_mask=None):
        """Beam-search decode a caption for the FIRST example of ``encoder_output``.

        Previously the decoder was re-run on all-zero placeholder features, so
        ``encoder_output`` was ignored and every video got the same caption.
        Beam scores are now true cumulative negative log-probabilities over the
        full vocabulary (lower is better).

        Args:
            encoder_output: (batch, frames, embed_dim) tensor from ``call``;
                only example 0 is decoded.
            tokenizer: Object with ``cls_token_id``, ``sep_token_id`` and
                ``decode``; CLS starts the sequence and SEP ends it.
            max_length: Maximum caption length in tokens, including the start
                token. Clamped to ``max_words``.
            beam_size: Number of hypotheses kept per step.
            encoder_mask: Optional (batch, frames) frame mask (1 = real frame).

        Returns:
            The decoded caption string (special tokens removed).
        """
        start_token = tokenizer.cls_token_id
        end_token = tokenizer.sep_token_id

        # Only example 0 is returned, so decode only example 0.
        encoder_output = encoder_output[:1]
        if encoder_mask is not None:
            encoder_mask = encoder_mask[:1]

        # The positional-encoding table only covers max_words positions.
        max_length = min(max_length, self.max_words)
        k = max(1, min(beam_size, self.vocab_size))

        beams = [([start_token], 0.0)]  # (token list, cumulative -log p)
        finished = []

        for _ in range(max_length - 1):
            candidates = []
            for tokens, score in beams:
                seq = tf.constant([tokens], dtype=tf.int32)
                logits = self._decode(encoder_output, seq, encoder_mask=encoder_mask, training=False)
                log_probs = tf.nn.log_softmax(logits[0, -1, :])
                top_log_probs, top_ids = tf.math.top_k(log_probs, k=k)
                for log_p, token_id in zip(top_log_probs.numpy(), top_ids.numpy()):
                    candidates.append((tokens + [int(token_id)], score - float(log_p)))

            # Keep the best beam_size expansions. Ones ending in SEP are done.
            candidates.sort(key=lambda c: c[1])
            beams = []
            for tokens, score in candidates[:beam_size]:
                if tokens[-1] == end_token:
                    finished.append((tokens, score))
                else:
                    beams.append((tokens, score))
            if not beams:
                break

        # Hypotheses that hit max_length without SEP still compete.
        finished.extend(beams)
        if not finished:
            return ""

        best_tokens, _ = min(finished, key=lambda c: c[1])
        return tokenizer.decode(best_tokens, skip_special_tokens=True)

class PositionalEncoding(keras.layers.Layer):
    """Add a fixed sinusoidal position table to a sequence of embeddings.

    The table is built once in ``__init__`` with shape
    (1, max_length, d_model). Even feature columns hold sin(angle) and odd
    columns hold cos(angle). Inputs longer than ``max_length`` are not
    supported: the slice in ``call`` would be too short.
    """
    def __init__(self, max_length, d_model):
        super().__init__()
        self.max_length = max_length
        self.d_model = d_model
        self.pos_encoding = self.positional_encoding(max_length, d_model)

    def get_angles(self, pos, i, d_model):
        """Return pos / 10000^(2*(i//2)/d_model), broadcast over pos and i."""
        exponent = (2 * (i // 2)) / np.float32(d_model)
        return pos * (1 / np.power(10000, exponent))

    def positional_encoding(self, max_length, d_model):
        """Build the float32 (1, max_length, d_model) sinusoid table."""
        positions = np.arange(max_length)[:, np.newaxis]
        feature_idx = np.arange(d_model)[np.newaxis, :]
        angles = self.get_angles(positions, feature_idx, d_model)

        # Interleave: sin on columns 2i, cos on columns 2i+1.
        table = np.where(feature_idx % 2 == 0, np.sin(angles), np.cos(angles))

        return tf.cast(table[np.newaxis, ...], dtype=tf.float32)

    def call(self, inputs):
        """Add the first ``inputs.shape[1]`` rows of the table to ``inputs``."""
        length = tf.shape(inputs)[1]
        return inputs + self.pos_encoding[:, :length, :]

class DecoderLayer(keras.layers.Layer):
    """Post-norm transformer decoder block.

    The block runs masked self-attention, then cross-attention into the encoder
    output, then a position-wise feed-forward network. Each sub-layer uses
    dropout, a residual connection and LayerNormalization.
    """
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()
        head_dim = d_model // num_heads

        # Attention over the caption itself, then over the encoder sequence.
        self.self_attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=head_dim, dropout=dropout_rate
        )
        self.cross_attention = keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=head_dim, dropout=dropout_rate
        )

        # Feed-forward: d_model -> dff (ReLU) -> d_model.
        self.ffn = keras.Sequential([
            keras.layers.Dense(dff, activation='relu'),
            keras.layers.Dense(d_model)
        ])

        self.layernorm1, self.layernorm2, self.layernorm3 = (
            keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(3)
        )
        self.dropout1, self.dropout2, self.dropout3 = (
            keras.layers.Dropout(dropout_rate) for _ in range(3)
        )

    def call(self, inputs, encoder_output, look_ahead_mask=None,
             padding_mask=None, encoder_padding_mask=None, training=True):
        """Apply the block.

        ``look_ahead_mask`` is passed to self-attention and
        ``encoder_padding_mask`` to cross-attention, both as Keras
        ``attention_mask``. ``padding_mask`` is accepted but not used here.
        """
        x = inputs

        self_attended = self.self_attention(
            query=x, key=x, value=x,
            attention_mask=look_ahead_mask,
            training=training
        )
        x = self.layernorm1(x + self.dropout1(self_attended, training=training))

        cross_attended = self.cross_attention(
            query=x, key=encoder_output, value=encoder_output,
            attention_mask=encoder_padding_mask,
            training=training
        )
        x = self.layernorm2(x + self.dropout2(cross_attended, training=training))

        projected = self.dropout3(self.ffn(x, training=training), training=training)
        return self.layernorm3(x + projected)

def create_look_ahead_mask(size):
    """Build a causal self-attention mask for the transformer decoder.

    The mask is consumed as ``attention_mask`` by
    ``keras.layers.MultiHeadAttention``, where 1 means "may attend" and 0
    means "blocked". The previous version used the opposite (additive-mask)
    convention: it marked future positions with 1, so each token attended
    only to future tokens and the targets leaked during teacher forcing.

    Args:
        size: Sequence length (Python int or scalar int tensor).

    Returns:
        float32 tensor of shape (1, 1, size, size). Entry [.., i, j] is 1 iff
        j <= i. It broadcasts over batch and heads.
    """
    # band_part(-1, 0) keeps the lower triangle including the diagonal.
    mask = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask[tf.newaxis, tf.newaxis, :, :]  # (1, 1, size, size)

def create_padding_mask(mask):
    """Insert two singleton axes after the batch axis of a (batch, seq_len) mask.

    The result has shape (batch, 1, 1, seq_len), so it broadcasts over
    attention heads and query positions. Mask values are passed through
    unchanged.
    """
    return mask[:, None, None, :]

# ===== 4. Data Processing and Training Functions =====

class NSVADataset:
    """
    Simplified data loader for the NSVA dataset
    Loads pre-extracted features and captions

    Expects ``annotations_file`` to be JSON with a ``"sentences"`` list whose
    items carry ``"video_id"`` and ``"caption"``. Each modality directory holds
    one ``<video_id>.npy`` per video. Only annotations whose video is in the
    requested split AND has all five feature files are kept.
    """
    def __init__(
        self,
        annotations_file,
        video_features_dir,
        ball_features_dir,
        player_features_dir,
        basket_features_dir,
        court_features_dir,
        tokenizer,
        max_frames=100,
        max_words=30,
        split='train',
        split_file=None
    ):
        self.video_features_dir = video_features_dir
        self.ball_features_dir = ball_features_dir
        self.player_features_dir = player_features_dir
        self.basket_features_dir = basket_features_dir
        self.court_features_dir = court_features_dir
        self.tokenizer = tokenizer
        self.max_frames = max_frames
        self.max_words = max_words

        # Load annotations
        with open(annotations_file, 'r') as f:
            self.annotations = json.load(f)
        sentences = self.annotations['sentences']

        # Load split information
        if split_file:
            with open(split_file, 'r') as f:
                splits = json.load(f)
            self.video_ids = splits[split]
        else:
            # Use all videos if no split file is provided (first-seen order, deduplicated)
            self.video_ids = list(dict.fromkeys(s['video_id'] for s in sentences))

        # Set membership keeps filtering linear instead of O(annotations * videos).
        split_ids = set(self.video_ids)
        self.filtered_annotations = [s for s in sentences if s['video_id'] in split_ids]

        print(f"Loaded {len(self.filtered_annotations)} annotations for {len(self.video_ids)} videos in {split} split")

        # Check which videos have features extracted for every modality
        feature_dirs = self._feature_dirs()
        self.available_videos = {
            video_id for video_id in self.video_ids
            if all(os.path.exists(os.path.join(d, f"{video_id}.npy")) for d in feature_dirs)
        }

        # Final filter for annotations with available features
        self.filtered_annotations = [
            s for s in self.filtered_annotations
            if s['video_id'] in self.available_videos
        ]

        print(f"Found {len(self.available_videos)} videos with extracted features")
        print(f"Final number of annotations: {len(self.filtered_annotations)}")

    def __len__(self):
        return len(self.filtered_annotations)

    def _feature_dirs(self):
        """Return the five modality directories in the order load_features returns them."""
        return (
            self.video_features_dir,
            self.ball_features_dir,
            self.player_features_dir,
            self.basket_features_dir,
            self.court_features_dir,
        )

    def _fit_frames(self, array):
        """Truncate or zero-pad ``array`` along axis 0 to exactly max_frames rows (any rank)."""
        array = array[:self.max_frames]
        pad_len = self.max_frames - array.shape[0]
        if pad_len > 0:
            pad_width = [(0, pad_len)] + [(0, 0)] * (array.ndim - 1)
            array = np.pad(array, pad_width, mode='constant')
        return array

    def load_features(self, video_id):
        """Load all features for a video, each fitted to ``max_frames`` rows.

        Each stream is truncated/padded independently. Previously only the
        video stream's length was used, so streams with a different frame
        count came out with the wrong length.

        Returns:
            ``(video_features, video_mask, ball_features, player_features,
            basket_features, court_features)``. ``video_mask`` is an int32
            (max_frames,) array with 1 for the real (un-padded) frames of the
            video stream and 0 elsewhere.
        """
        raw = [np.load(os.path.join(d, f"{video_id}.npy")) for d in self._feature_dirs()]

        # The mask follows the global video stream, as before.
        real_frames = min(raw[0].shape[0], self.max_frames)
        video_mask = np.zeros(self.max_frames, dtype=np.int32)
        video_mask[:real_frames] = 1

        video_features, ball_features, player_features, basket_features, court_features = (
            self._fit_frames(a) for a in raw
        )

        return video_features, video_mask, ball_features, player_features, basket_features, court_features

    def tokenize_caption(self, caption):
        """Tokenize caption and create input/output ids for training.

        Returns ``(decoder_input_ids, decoder_mask, target_ids)``, each of length
        ``max_words``. ``decoder_input_ids`` is ``target_ids`` shifted right by
        one with ``cls_token_id`` prepended. ``decoder_mask`` is the
        tokenizer's attention mask, aligned with ``target_ids``.
        """
        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_words,
            return_tensors="tf"
        )

        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]

        # For decoder input, shift right and add start token
        decoder_input_ids = tf.concat([
            [self.tokenizer.cls_token_id],
            input_ids[:-1]
        ], axis=0)

        # Decoder mask is same as attention mask; target is the original input_ids
        return decoder_input_ids, attention_mask, input_ids

    def __getitem__(self, idx):
        """Get a single sample from the dataset as ``(inputs_8_tuple, target_ids)``."""
        annotation = self.filtered_annotations[idx]
        video_id = annotation['video_id']
        caption = annotation['caption']

        video_features, video_mask, ball_features, player_features, basket_features, court_features = self.load_features(video_id)
        decoder_input_ids, decoder_mask, target_ids = self.tokenize_caption(caption)

        return (
            video_features.astype(np.float32),
            video_mask,
            ball_features.astype(np.float32),
            player_features.astype(np.float32),
            basket_features.astype(np.float32),
            court_features.astype(np.float32),
            decoder_input_ids,
            decoder_mask
        ), target_ids

    def create_tf_dataset(self, batch_size=32, shuffle=True, buffer_size=1000):
        """Create a batched, prefetched ``tf.data.Dataset`` over all samples.

        When ``shuffle`` is true, sample order is randomised both in the Python
        generator and by a ``buffer_size`` shuffle buffer.
        """
        def generator():
            indices = list(range(len(self)))
            if shuffle:
                import random
                random.shuffle(indices)

            for idx in indices:
                yield self[idx]

        # output_signature replaces the deprecated output_types/output_shapes pair.
        output_signature = (
            (
                tf.TensorSpec(shape=(self.max_frames, None), dtype=tf.float32),        # video_features
                tf.TensorSpec(shape=(self.max_frames,), dtype=tf.int32),               # video_mask
                tf.TensorSpec(shape=(self.max_frames, None), dtype=tf.float32),        # ball_features
                tf.TensorSpec(shape=(self.max_frames, None, None), dtype=tf.float32),  # player_features
                tf.TensorSpec(shape=(self.max_frames, None), dtype=tf.float32),        # basket_features
                tf.TensorSpec(shape=(self.max_frames, None), dtype=tf.float32),        # court_features
                tf.TensorSpec(shape=(self.max_words,), dtype=tf.int32),                # decoder_input_ids
                tf.TensorSpec(shape=(self.max_words,), dtype=tf.int32),                # decoder_mask
            ),
            tf.TensorSpec(shape=(self.max_words,), dtype=tf.int32),                    # target_ids
        )

        dataset = tf.data.Dataset.from_generator(generator, output_signature=output_signature)

        if shuffle:
            dataset = dataset.shuffle(buffer_size)

        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset

class NSVATrainer:
    """Custom training loop for the NSVA captioning model.

    Handles optimisation, masked token-level cross-entropy, optional
    checkpointing (``tf.train.CheckpointManager`` keeping the last 5), optional
    TensorBoard scalars, and BLEU/ROUGE caption evaluation.

    Args:
        model: Model whose call returns a dict with ``"caption_logits"`` and
            ``"encoder_output"``, and which exposes
            ``generate_caption(encoder_output, tokenizer)``.
        train_dataset: Iterable of ``(inputs, targets)`` batches; ``inputs[7]``
            is the decoder mask used to weight the loss.
        val_dataset: Optional dataset with the same structure.
        learning_rate, beta_1, beta_2: Adam hyper-parameters.
        weight_decay: Decoupled weight decay. When truthy the optimizer is
            AdamW (this argument used to be silently ignored); when 0/None it
            is plain Adam.
        checkpoint_dir: If given, checkpoints are saved there and the latest
            one is restored on construction.
        log_dir: If given, TensorBoard scalar summaries are written there.
    """
    def __init__(
        self,
        model,
        train_dataset,
        val_dataset=None,
        learning_rate=1e-4,
        beta_1=0.9,
        beta_2=0.999,
        weight_decay=0.01,
        checkpoint_dir=None,
        log_dir=None
    ):
        self.model = model
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        # Setup optimizer
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        if weight_decay:
            self.optimizer = keras.optimizers.AdamW(
                learning_rate=learning_rate,
                weight_decay=weight_decay,
                beta_1=beta_1,
                beta_2=beta_2
            )
        else:
            self.optimizer = keras.optimizers.Adam(
                learning_rate=learning_rate,
                beta_1=beta_1,
                beta_2=beta_2
            )

        # Per-token loss; masking and averaging happen in _masked_caption_loss.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none'
        )

        # Setup metrics
        self.train_loss = keras.metrics.Mean(name='train_loss')
        self.train_accuracy = keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

        self.val_loss = keras.metrics.Mean(name='val_loss')
        self.val_accuracy = keras.metrics.SparseCategoricalAccuracy(name='val_accuracy')

        # Setup checkpointing
        self.checkpoint_dir = checkpoint_dir
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
            self.checkpoint = tf.train.Checkpoint(
                optimizer=self.optimizer,
                model=self.model,
                step=tf.Variable(0),
                epoch=tf.Variable(0)
            )
            self.checkpoint_manager = tf.train.CheckpointManager(
                self.checkpoint, checkpoint_dir, max_to_keep=5
            )

            self.restore_checkpoint()

        # Setup tensorboard
        self.log_dir = log_dir
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
            self.summary_writer = tf.summary.create_file_writer(log_dir)

    def restore_checkpoint(self):
        """Restore from the latest checkpoint. Return True if one was restored."""
        manager = getattr(self, 'checkpoint_manager', None)
        if manager is not None and manager.latest_checkpoint:
            self.checkpoint.restore(manager.latest_checkpoint)
            print(f"Restored from checkpoint: {manager.latest_checkpoint}")
            print(f"Starting from epoch {int(self.checkpoint.epoch.numpy())}")
            return True
        print("Initializing from scratch.")
        return False

    def _masked_caption_loss(self, targets, caption_logits, mask):
        """Mean cross-entropy over token positions where ``mask`` is 1."""
        per_token_loss = self.loss_fn(targets, caption_logits)
        # divide_no_nan keeps an all-padding batch from producing NaN.
        return tf.math.divide_no_nan(
            tf.reduce_sum(per_token_loss * mask), tf.reduce_sum(mask)
        )

    @tf.function
    def train_step(self, inputs, targets):
        """Run one optimisation step and return the masked loss."""
        mask = tf.cast(inputs[7], tf.float32)  # decoder_mask
        with tf.GradientTape() as tape:
            predictions = self.model(inputs, training=True)
            caption_logits = predictions["caption_logits"]
            loss = self._masked_caption_loss(targets, caption_logits, mask)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))

        self.train_loss.update_state(loss)
        self.train_accuracy.update_state(targets, caption_logits, sample_weight=mask)

        return loss

    @tf.function
    def validate_step(self, inputs, targets):
        """Compute the masked loss for one batch without updating weights."""
        mask = tf.cast(inputs[7], tf.float32)  # decoder_mask
        predictions = self.model(inputs, training=False)
        caption_logits = predictions["caption_logits"]
        loss = self._masked_caption_loss(targets, caption_logits, mask)

        self.val_loss.update_state(loss)
        self.val_accuracy.update_state(targets, caption_logits, sample_weight=mask)

        return loss

    def train(self, epochs, eval_freq=1):
        """Train for ``epochs`` epochs, resuming from the checkpointed epoch if any.

        Validation runs every ``eval_freq`` epochs when a validation set exists.
        A checkpoint is saved and TensorBoard scalars are written after every
        epoch when those are configured.
        """
        start_epoch = 0
        if hasattr(self, 'checkpoint'):
            start_epoch = int(self.checkpoint.epoch.numpy())

        for epoch in range(start_epoch, start_epoch + epochs):
            print(f"\nEpoch {epoch + 1}/{start_epoch + epochs}")

            self.train_loss.reset_state()
            self.train_accuracy.reset_state()

            start_time = time.time()
            for step, (inputs, targets) in enumerate(self.train_dataset):
                loss = self.train_step(inputs, targets)

                if hasattr(self, 'checkpoint'):
                    self.checkpoint.step.assign_add(1)

                if step % 10 == 0:
                    print(f"Step {step}: Loss = {loss:.4f}, " +
                          f"Accuracy = {self.train_accuracy.result():.4f}")

            train_time = time.time() - start_time

            print(f"Training time: {train_time:.2f}s")
            print(f"Train Loss: {self.train_loss.result():.4f}")
            print(f"Train Accuracy: {self.train_accuracy.result():.4f}")

            ran_validation = bool(self.val_dataset) and (epoch + 1) % eval_freq == 0
            if ran_validation:
                self.validate()

            if hasattr(self, 'checkpoint'):
                self.checkpoint.epoch.assign_add(1)
                save_path = self.checkpoint_manager.save()
                print(f"Saved checkpoint at: {save_path}")

            if hasattr(self, 'summary_writer'):
                with self.summary_writer.as_default():
                    tf.summary.scalar('train_loss', self.train_loss.result(), step=epoch)
                    tf.summary.scalar('train_accuracy', self.train_accuracy.result(), step=epoch)

                    if ran_validation:
                        tf.summary.scalar('val_loss', self.val_loss.result(), step=epoch)
                        tf.summary.scalar('val_accuracy', self.val_accuracy.result(), step=epoch)

    def validate(self):
        """Run one pass over the validation set and return the mean loss."""
        print("\nRunning validation...")

        self.val_loss.reset_state()
        self.val_accuracy.reset_state()

        start_time = time.time()
        for inputs, targets in self.val_dataset:
            self.validate_step(inputs, targets)

        val_time = time.time() - start_time

        print(f"Validation time: {val_time:.2f}s")
        print(f"Validation Loss: {self.val_loss.result():.4f}")
        print(f"Validation Accuracy: {self.val_accuracy.result():.4f}")

        return self.val_loss.result().numpy()

    def evaluate_captions(self, dataset, tokenizer, num_samples=50):
        """Generate captions and compute corpus BLEU-1..4 and ROUGE-1/2/L F-scores.

        Only the first example of each of the first ``num_samples`` batches is
        captioned. ROUGE is best-effort: if the ``rouge`` package raises (e.g.
        on an empty hypothesis) the error is printed and ROUGE keys are omitted
        from the returned dict.
        """
        print(f"\nGenerating captions for {num_samples} samples...")

        references = []
        hypotheses = []

        rouge = Rouge()

        for i, (inputs, targets) in enumerate(dataset):
            if i >= num_samples:
                break

            gt_caption = tokenizer.decode(targets[0].numpy(), skip_special_tokens=True)

            video_features, video_mask, ball_features, player_features, basket_features, court_features, _, _ = inputs

            # Encode only example 0, the one being scored. Pairing full-batch
            # features with a batch-1 decoder input broke cross-attention
            # whenever batch_size > 1.
            encoder_inputs = (
                video_features[:1], video_mask[:1],
                ball_features[:1], player_features[:1], basket_features[:1], court_features[:1],
                tf.zeros_like(targets[:1]),
                tf.zeros_like(targets[:1])
            )

            outputs = self.model(encoder_inputs, training=False)
            encoder_output = outputs["encoder_output"]

            pred_caption = self.model.generate_caption(encoder_output, tokenizer)

            references.append([gt_caption.split()])
            hypotheses.append(pred_caption.split())

            if i < 5:
                print(f"\nExample {i+1}:")
                print(f"Ground truth: {gt_caption}")
                print(f"Prediction: {pred_caption}")

        # BLEU-n weights must sum to 1 (0.33*3 did not).
        bleu1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
        bleu2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
        bleu3 = corpus_bleu(references, hypotheses, weights=(1 / 3, 1 / 3, 1 / 3, 0))
        bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

        rouge_scores = {}
        try:
            hyp_texts = [' '.join(h) for h in hypotheses]
            ref_texts = [' '.join(r[0]) for r in references]

            rouge_scores = rouge.get_scores(hyp_texts, ref_texts, avg=True)
        except Exception as e:
            print(f"Error calculating ROUGE: {e}")

        print("\nEvaluation Metrics:")
        print(f"BLEU-1: {bleu1:.4f}")
        print(f"BLEU-2: {bleu2:.4f}")
        print(f"BLEU-3: {bleu3:.4f}")
        print(f"BLEU-4: {bleu4:.4f}")

        if rouge_scores:
            print(f"ROUGE-1: {rouge_scores['rouge-1']['f']:.4f}")
            print(f"ROUGE-2: {rouge_scores['rouge-2']['f']:.4f}")
            print(f"ROUGE-L: {rouge_scores['rouge-l']['f']:.4f}")

        metrics = {
            'bleu1': bleu1,
            'bleu2': bleu2,
            'bleu3': bleu3,
            'bleu4': bleu4
        }

        if rouge_scores:
            metrics['rouge1'] = rouge_scores['rouge-1']['f']
            metrics['rouge2'] = rouge_scores['rouge-2']['f']
            metrics['rougeL'] = rouge_scores['rouge-l']['f']

        return metrics

# ===== 5. Main Training Script =====

def main(epochs=1, resume_weights_path=None):
    """Train, evaluate and report on the simplified NSVA captioning model.

    Builds the train/val/test ``NSVADataset`` splits from the pre-extracted
    features under ``FEATURES_DIR``. It then trains
    ``SimplifiedSportsVideoUnderstandingModel`` through ``NSVATrainer`` for
    ``epochs`` epochs and saves the final weights to
    ``RESULTS_DIR/final_model.weights.h5``. Caption metrics are computed on
    50 validation and 100 test samples and written, together with the
    per-epoch history, to ``RESULTS_DIR/metrics.json``. Finally it calls
    ``visualize_results`` to produce plots and a summary report.

    Args:
        epochs: Number of training epochs (default 1, the previous
            hard-coded value). Must be >= 1 because the report uses the
            last recorded epoch.
        resume_weights_path: Optional path to a saved ``.weights.h5`` file
            (for example a previous ``final_model.weights.h5``) to load
            before training. ``None`` (default) trains from freshly
            initialised weights.

    Raises:
        ValueError: If ``epochs`` is less than 1.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    # Configure paths
    annotations_file = os.path.join(ANNOTATIONS_DIR, 'annotations.json')
    split_file = os.path.join(METADATA_DIR, 'splits.json')

    # Configure model parameters
    max_frames = 100
    max_words = 30
    batch_size = 32
    embed_dim = 256
    num_heads = 4
    decoder_layers = 2

    # Initialize tokenizer; its vocabulary size also sizes the model's output vocabulary.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    vocab_size = tokenizer.vocab_size

    # Every split shares the same sources and preprocessing; only `split` differs.
    dataset_kwargs = dict(
        annotations_file=annotations_file,
        video_features_dir=os.path.join(FEATURES_DIR, 'timesformer'),
        ball_features_dir=os.path.join(FEATURES_DIR, 'ball'),
        player_features_dir=os.path.join(FEATURES_DIR, 'player'),
        basket_features_dir=os.path.join(FEATURES_DIR, 'basket'),
        court_features_dir=os.path.join(FEATURES_DIR, 'court'),
        tokenizer=tokenizer,
        max_frames=max_frames,
        max_words=max_words,
        split_file=split_file,
    )

    print("Creating datasets...")
    train_dataset = NSVADataset(split='train', **dataset_kwargs)
    val_dataset = NSVADataset(split='val', **dataset_kwargs)
    test_dataset = NSVADataset(split='test', **dataset_kwargs)

    # Only the training split uses create_tf_dataset's default shuffle setting;
    # val/test explicitly disable shuffling.
    train_tf_dataset = train_dataset.create_tf_dataset(batch_size=batch_size)
    val_tf_dataset = val_dataset.create_tf_dataset(batch_size=batch_size, shuffle=False)
    test_tf_dataset = test_dataset.create_tf_dataset(batch_size=batch_size, shuffle=False)

    # Initialize model
    print("Initializing model...")
    model = SimplifiedSportsVideoUnderstandingModel(
        vocab_size=vocab_size,
        max_frames=max_frames,
        max_words=max_words,
        embed_dim=embed_dim,
        num_heads=num_heads,
        decoder_layers=decoder_layers
    )

    # Configure training
    trainer = NSVATrainer(
        model=model,
        train_dataset=train_tf_dataset,
        val_dataset=val_tf_dataset,
        learning_rate=1e-4,
        checkpoint_dir=CHECKPOINTS_DIR,
        log_dir=os.path.join(RESULTS_DIR, 'logs')
    )

    if resume_weights_path is not None:
        # A subclassed Keras model creates its variables on first call, so push
        # one training batch through it before loading saved weights.
        sample_inputs = next(iter(train_tf_dataset))[0]
        model(sample_inputs)
        model.load_weights(resume_weights_path)
        print(f"Loaded model weights from {resume_weights_path}")

    # Train model
    print(f"Training for {epochs} epochs...")
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'val_loss': [],
        'val_accuracy': []
    }

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Training metrics accumulate across steps; clear them at each epoch start.
        trainer.train_loss.reset_state()
        trainer.train_accuracy.reset_state()

        for inputs, targets in train_tf_dataset:
            trainer.train_step(inputs, targets)

        # Presumably validate() refreshes trainer.val_loss / trainer.val_accuracy,
        # which are read below; its return value is not needed here.
        trainer.validate()

        # Record metrics for plotting
        history['train_loss'].append(trainer.train_loss.result().numpy())
        history['train_accuracy'].append(trainer.train_accuracy.result().numpy())
        history['val_loss'].append(trainer.val_loss.result().numpy())
        history['val_accuracy'].append(trainer.val_accuracy.result().numpy())

        # Save checkpoint (only when the trainer was configured with one)
        if hasattr(trainer, 'checkpoint'):
            trainer.checkpoint.epoch.assign_add(1)
            save_path = trainer.checkpoint_manager.save()
            print(f"Saved checkpoint at: {save_path}")

    # Save final model
    model_path = os.path.join(RESULTS_DIR, 'final_model.weights.h5')
    model.save_weights(model_path)
    print(f"Model saved to {model_path}")

    # Evaluate on validation set
    print("\n===== VALIDATION RESULTS =====")
    val_metrics = trainer.evaluate_captions(val_tf_dataset, tokenizer, num_samples=50)

    # Evaluate on test set
    print("\n===== TEST RESULTS =====")
    test_metrics = trainer.evaluate_captions(test_tf_dataset, tokenizer, num_samples=100)

    # History entries are NumPy scalars (from `.numpy()`), which json may not
    # serialize, so convert them to plain Python floats.
    metrics = {
        'validation': val_metrics,
        'test': test_metrics,
        'history': {name: [float(v) for v in values] for name, values in history.items()},
    }

    metrics_file = os.path.join(RESULTS_DIR, 'metrics.json')
    with open(metrics_file, 'w') as f:
        json.dump(metrics, f, indent=2)

    print(f"Evaluation metrics saved to {metrics_file}")

    # Visualize results
    visualize_results(history, val_metrics, test_metrics)

def _plot_train_val_curve(history, metric, label, path):
    """Plot history['train_<metric>'] and history['val_<metric>'] per epoch and save to `path`."""
    train_values = history[f'train_{metric}']
    val_values = history[f'val_{metric}']

    fig, ax = plt.subplots(figsize=(12, 8))
    # Plot against 1-based epoch numbers so the x axis matches "Epoch N" counting.
    ax.plot(range(1, len(train_values) + 1), train_values,
            label=f'Training {label}', marker='o', linestyle='-', linewidth=2)
    ax.plot(range(1, len(val_values) + 1), val_values,
            label=f'Validation {label}', marker='s', linestyle='--', linewidth=2)
    ax.set_title(f'Training and Validation {label}', fontsize=16)
    ax.set_xlabel('Epoch', fontsize=14)
    ax.set_ylabel(label, fontsize=14)
    ax.legend(fontsize=12)
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(path)


def _plot_score_comparison(val_scores, test_scores, tick_labels, title, path):
    """Save a grouped, value-annotated bar chart of validation vs. test scores to `path`."""
    x = np.arange(len(tick_labels))
    width = 0.35

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.bar(x - width / 2, val_scores, width, label='Validation', color='#5DA5DA', alpha=0.8)
    ax.bar(x + width / 2, test_scores, width, label='Test', color='#F15854', alpha=0.8)

    ax.set_ylabel('Score', fontsize=14)
    ax.set_title(title, fontsize=16)
    ax.set_xticks(x)
    ax.set_xticklabels(tick_labels, fontsize=12)

    # Keep the original 1.2x headroom, but leave room for the labels drawn 0.01
    # above each bar. Never collapse to a singular 0..0 range, which happens when
    # every score is zero (e.g. after very short training).
    top_score = max(max(val_scores), max(test_scores))
    ax.set_ylim(0, max(top_score * 1.2, top_score + 0.05, 0.1))

    # Add score values on top of bars
    for offset, scores in ((-width / 2, val_scores), (width / 2, test_scores)):
        for i, score in enumerate(scores):
            ax.text(i + offset, score + 0.01, f'{score:.4f}', ha='center', fontsize=10)

    ax.legend(fontsize=12)
    fig.tight_layout()
    fig.savefig(path)


def _format_final_value(values):
    """Return the last value of a per-epoch list formatted to 4 decimals, or 'n/a' if empty."""
    return f'{values[-1]:.4f}' if len(values) else 'n/a'


def visualize_results(history, val_metrics, test_metrics):
    """
    Visualize training history and evaluation metrics.

    Writes loss_plot.png, accuracy_plot.png and bleu_scores.png into
    RESULTS_DIR/plots. rouge_scores.png is added when both metric dicts
    contain ROUGE entries. The function then writes
    RESULTS_DIR/summary_report.md. Figures are left open, so a notebook
    still shows them inline.

    Args:
        history: dict of per-epoch lists under 'train_loss', 'train_accuracy',
            'val_loss' and 'val_accuracy'. Empty lists are tolerated: the
            curves are empty and the report shows 'n/a'.
        val_metrics: dict with 'bleu1'..'bleu4' and optionally 'rouge1',
            'rouge2', 'rougeL' for the validation split.
        test_metrics: same layout as `val_metrics`, for the test split.
    """
    # Create results directory for plots
    plots_dir = os.path.join(RESULTS_DIR, 'plots')
    os.makedirs(plots_dir, exist_ok=True)

    # Set plot style (global, as before, so inline rendering stays consistent)
    plt.style.use('ggplot')

    # Plots 1-2: training/validation curves
    _plot_train_val_curve(history, 'loss', 'Loss', os.path.join(plots_dir, 'loss_plot.png'))
    _plot_train_val_curve(history, 'accuracy', 'Accuracy',
                          os.path.join(plots_dir, 'accuracy_plot.png'))

    # Plot 3: BLEU scores for validation and test
    bleu_keys = ['bleu1', 'bleu2', 'bleu3', 'bleu4']
    _plot_score_comparison(
        [val_metrics[k] for k in bleu_keys],
        [test_metrics[k] for k in bleu_keys],
        ['BLEU-1', 'BLEU-2', 'BLEU-3', 'BLEU-4'],
        'BLEU Scores Comparison',
        os.path.join(plots_dir, 'bleu_scores.png'),
    )

    # Plot 4: ROUGE scores, only if both splits computed them
    has_rouge = 'rouge1' in val_metrics and 'rouge1' in test_metrics
    if has_rouge:
        rouge_keys = ['rouge1', 'rouge2', 'rougeL']
        _plot_score_comparison(
            [val_metrics[k] for k in rouge_keys],
            [test_metrics[k] for k in rouge_keys],
            ['ROUGE-1', 'ROUGE-2', 'ROUGE-L'],
            'ROUGE Scores Comparison',
            os.path.join(plots_dir, 'rouge_scores.png'),
        )

    print(f"Visualization plots saved to {plots_dir}")

    # Generate a summary report with key metrics
    summary_file = os.path.join(RESULTS_DIR, 'summary_report.md')
    with open(summary_file, 'w', encoding='utf-8') as f:
        f.write("# NBA Sports Video Analysis Results\n\n")
        f.write(f"Report generated on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        f.write("## Training Performance\n\n")
        f.write(f"- Final Training Loss: {_format_final_value(history['train_loss'])}\n")
        f.write(f"- Final Validation Loss: {_format_final_value(history['val_loss'])}\n")
        f.write(f"- Final Training Accuracy: {_format_final_value(history['train_accuracy'])}\n")
        f.write(f"- Final Validation Accuracy: {_format_final_value(history['val_accuracy'])}\n\n")

        f.write("## Caption Generation Evaluation\n\n")
        for heading, split_metrics in (("Validation Set Metrics", val_metrics),
                                       ("Test Set Metrics", test_metrics)):
            f.write(f"### {heading}\n\n")
            f.write(f"- BLEU-1: {split_metrics['bleu1']:.4f}\n")
            f.write(f"- BLEU-4: {split_metrics['bleu4']:.4f}\n")
            if 'rouge1' in split_metrics:
                f.write(f"- ROUGE-L: {split_metrics['rougeL']:.4f}\n")
            # Always end the list with a blank line, whether or not ROUGE was
            # reported, so the next heading is separated correctly.
            f.write("\n")

        f.write("## Visualizations\n\n")
        f.write("The following plots have been generated:\n\n")
        f.write("1. Training and Validation Loss\n")
        f.write("2. Training and Validation Accuracy\n")
        f.write("3. BLEU Scores Comparison (Validation vs Test)\n")
        if has_rouge:
            f.write("4. ROUGE Scores Comparison (Validation vs Test)\n")

    print(f"Summary report saved to {summary_file}")

if __name__ == "__main__":
    # Run the full pipeline (dataset creation, training, evaluation, plots and
    # summary report) when this file/cell is executed directly.
    main()