In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:

!pip install -q webvtt-py
!pip install -q rouge
!pip install -q nltk
!pip install -q transformers
!pip install -q seaborn

import os
import numpy as np
import random
import json
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdm
import time
from datetime import datetime
from transformers import BertTokenizer
import nltk
from nltk.translate.bleu_score import corpus_bleu
from rouge import Rouge
import matplotlib.pyplot as plt
import seaborn as sns
import sys
sys.path.insert(0, '/content/drive/MyDrive/NSVA_Results')


from nsva_dataset import NSVADataset
from nsva_model import NSVAModel, PositionalEncoding, DecoderLayer

def caption_loss_function(real, pred, mask=None):
    """Sparse categorical cross-entropy over caption tokens, optionally masked.

    Args:
        real: Integer target token ids.
        pred: Unnormalized logits (the loss is built with ``from_logits=True``).
        mask: Optional weights multiplied element-wise into the per-element
            loss. Presumably 1 for real tokens and 0 for padding -- confirm
            against the dataset that produces ``inputs['target_mask']``.

    Returns:
        A scalar tensor. With ``mask`` it is ``sum(loss * mask) / sum(mask)``
        (0 when the mask sums to zero); without it, the plain mean of the
        per-element loss.
    """
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_object(real, pred)

    if mask is None:
        return tf.reduce_mean(loss)

    mask = tf.cast(mask, dtype=loss.dtype)
    # divide_no_nan returns 0 instead of NaN when every position is masked
    # out, so one fully padded batch cannot poison the weights or the running
    # loss average.
    return tf.math.divide_no_nan(tf.reduce_sum(loss * mask), tf.reduce_sum(mask))

@tf.function
def train_step(model, inputs, target, optimizer):
    """Run one optimisation step and return the batch loss.

    The forward pass runs with ``training=True`` under a gradient tape, the
    loss is computed by ``caption_loss_function`` using
    ``inputs['target_mask']`` as the mask, and the resulting gradients are
    applied to the model's trainable variables through ``optimizer``.
    """
    with tf.GradientTape() as grad_tape:
        logits = model(inputs, training=True)
        step_loss = caption_loss_function(target, logits, inputs['target_mask'])

    variables = model.trainable_variables
    grads = grad_tape.gradient(step_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    return step_loss

def train_model(model, train_dataset, val_dataset, epochs=20, lr=0.001, patience=3,
                checkpoint_path='checkpoints/model'):
    """Train ``model`` with Adam, per-epoch checkpoints and early stopping.

    Args:
        model: Model called as ``model(inputs, training=...)`` and exposing
            ``save_weights``.
        train_dataset: Iterable of ``(inputs, target)`` batches; ``inputs``
            must contain a ``'target_mask'`` entry.
        val_dataset: Iterable of validation batches of the same form. When it
            is falsy (e.g. ``None``) validation and early stopping are skipped.
        epochs: Maximum number of epochs.
        lr: Adam learning rate.
        patience: Number of consecutive epochs without a validation-loss
            improvement after which training stops.
        checkpoint_path: Path prefix for weight files. The best model is
            written to ``<prefix>_best.weights.h5`` and each completed epoch
            to ``<prefix>_<epoch>.weights.h5``.

    Returns:
        ``{'loss': [...], 'val_loss': [...]}`` with one float per epoch
        (``'val_loss'`` stays empty when no validation data is given).
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    val_loss = tf.keras.metrics.Mean(name='val_loss')

    best_val_loss = float('inf')
    patience_counter = 0

    # A bare prefix such as 'model' has an empty dirname, and os.makedirs('')
    # raises FileNotFoundError, so only create the directory when there is one.
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    history = {
        'loss': [],
        'val_loss': []
    }

    for epoch in range(epochs):
        start = time.time()

        train_loss.reset_state()

        for step, (inputs, target) in enumerate(tqdm(train_dataset, desc=f"Epoch {epoch+1}/{epochs}")):
            batch_loss = train_step(model, inputs, target, optimizer)
            train_loss.update_state(batch_loss)

            if step % 10 == 0:
                print(f'Epoch {epoch+1}, Step {step}, Loss: {batch_loss:.4f}')

        history['loss'].append(float(train_loss.result()))

        if val_dataset:
            val_loss.reset_state()

            for inputs, target in tqdm(val_dataset, desc="Validation"):
                predictions = model(inputs, training=False)
                v_loss = caption_loss_function(target, predictions, inputs['target_mask'])
                val_loss.update_state(v_loss)

            # Read the metric once as a plain float so the history, the
            # comparison and the stored best value all use the same number.
            current_val_loss = float(val_loss.result())
            history['val_loss'].append(current_val_loss)

            if current_val_loss < best_val_loss:
                best_val_loss = current_val_loss
                model.save_weights(f'{checkpoint_path}_best.weights.h5')
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    # Leaving the loop here skips this epoch's checkpoint and
                    # summary line below; the best weights are already on disk.
                    print(f'Early stopping triggered after {epoch+1} epochs')
                    break

        model.save_weights(f'{checkpoint_path}_{epoch+1}.weights.h5')

        # Without validation data the val_loss metric is never updated, so
        # the value printed for it here is just the metric's empty result.
        print(f'Epoch {epoch+1}, Loss: {train_loss.result():.4f}, Val Loss: {val_loss.result():.4f}')
        print(f'Time taken: {time.time() - start:.2f} secs\n')

    return history

def _slice_example_inputs(inputs, index):
    """Build a batch-of-one feature dict for sample ``index`` of ``inputs``.

    Every stream in ``inputs`` is accessed as a pair (``[0]`` and ``[1]``);
    both members are sliced at ``index`` and given a leading axis of size 1.
    """
    streams = ('timesformer', 'ball', 'player', 'basket', 'court')
    return {
        name: (tf.expand_dims(inputs[name][0][index], 0),
               tf.expand_dims(inputs[name][1][index], 0))
        for name in streams
    }


def evaluate_model(model, test_dataset, tokenizer, num_samples=None, output_path=None):
    """Generate captions for test samples and score them with BLEU and ROUGE.

    Args:
        model: Model exposing ``generate_caption(example_inputs, tokenizer)``.
        test_dataset: Iterable of ``(inputs, targets)`` batches.
        tokenizer: Tokenizer whose ``decode`` accepts ``skip_special_tokens``.
        num_samples: Maximum number of samples to attempt; ``None`` means all.
        output_path: Optional JSON file. When given, the results (plus up to
            ten ``(ground_truth, generated)`` example pairs under
            ``'examples'``) are written there.

    Returns:
        A dict that may contain ``bleu1``..``bleu4``, ``rouge_1``, ``rouge_2``,
        ``rouge_l`` and, when ``output_path`` is set, ``examples``. Metric keys
        are absent when no caption could be generated or scoring failed.

    Errors raised while generating a caption or iterating the dataset are
    printed and evaluation continues with whatever was collected so far.
    """
    references = []
    hypotheses = []
    example_pairs = []

    rouge = Rouge()
    sample_count = 0

    try:
        for inputs, targets in test_dataset:
            if num_samples is not None and sample_count >= num_samples:
                break

            for i in range(len(targets)):
                if num_samples is not None and sample_count >= num_samples:
                    break

                sample_count += 1

                example_inputs = _slice_example_inputs(inputs, i)
                gt_caption = tokenizer.decode(targets[i].numpy(), skip_special_tokens=True)

                try:
                    caption_ids = model.generate_caption(example_inputs, tokenizer)
                    token_ids = caption_ids[0].numpy()

                    raw_caption = tokenizer.decode(token_ids, skip_special_tokens=False)
                    # The reference has special tokens stripped, so the
                    # hypothesis used for scoring must be decoded the same
                    # way; otherwise special tokens count as wrong words.
                    gen_caption = tokenizer.decode(token_ids, skip_special_tokens=True)

                    print(f"GT: {gt_caption}")
                    print(f"Pred (with special): {raw_caption}")
                    print(f"Pred (without special): {gen_caption}")

                    references.append([gt_caption.split()])
                    hypotheses.append(gen_caption.split())
                    example_pairs.append((gt_caption, gen_caption))

                    if sample_count % 5 == 0:
                        print(f"\nExample {sample_count}:")
                        print(f"GT: {gt_caption}")
                        print(f"Pred: {gen_caption}")

                except Exception as e:
                    # Best effort: one failed sample must not end the evaluation.
                    print(f"Error generating caption: {e}")

    except Exception as e:
        print(f"Error during evaluation: {e}")

    results = {}

    if references and hypotheses:
        bleu1 = corpus_bleu(references, hypotheses, weights=(1, 0, 0, 0))
        bleu2 = corpus_bleu(references, hypotheses, weights=(0.5, 0.5, 0, 0))
        # Exact thirds so the n-gram weights sum to 1.
        bleu3 = corpus_bleu(references, hypotheses, weights=(1 / 3, 1 / 3, 1 / 3, 0))
        bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

        results['bleu1'] = bleu1
        results['bleu2'] = bleu2
        results['bleu3'] = bleu3
        results['bleu4'] = bleu4

        print(f"\nBLEU-1: {bleu1:.4f}")
        print(f"BLEU-2: {bleu2:.4f}")
        print(f"BLEU-3: {bleu3:.4f}")
        print(f"BLEU-4: {bleu4:.4f}")

        try:
            rouge_refs = [' '.join(ref[0]) for ref in references]
            rouge_hyps = [' '.join(hyp) for hyp in hypotheses]

            # An empty hypothesis or reference makes the scorer raise, which
            # would discard ROUGE for the whole run. Score only the non-empty
            # pairs and rescale the averages so the dropped pairs count as
            # zero overlap.
            scorable = [(hyp, ref) for hyp, ref in zip(rouge_hyps, rouge_refs)
                        if hyp.strip() and ref.strip()]
            if scorable:
                kept_hyps, kept_refs = (list(column) for column in zip(*scorable))
                rouge_scores = rouge.get_scores(kept_hyps, kept_refs, avg=True)
                scale = len(scorable) / len(rouge_hyps)

                results['rouge_1'] = rouge_scores['rouge-1']['f'] * scale
                results['rouge_2'] = rouge_scores['rouge-2']['f'] * scale
                results['rouge_l'] = rouge_scores['rouge-l']['f'] * scale

                print(f"ROUGE-1: {results['rouge_1']:.4f}")
                print(f"ROUGE-2: {results['rouge_2']:.4f}")
                print(f"ROUGE-L: {results['rouge_l']:.4f}")
        except Exception as e:
            print(f"Error calculating ROUGE scores: {e}")

    if output_path:
        results['examples'] = example_pairs[:10]

        # A bare filename has an empty dirname, which os.makedirs rejects.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w') as f:
            json.dump(results, f, indent=2)

    return results

def visualize_training_history(history, output_dir):
    """Plot the loss curves in ``history`` and save them as a PNG.

    ``history['loss']`` is always drawn; ``history['val_loss']`` is drawn when
    that key is present. The figure is written to
    ``<output_dir>/training_history.png`` at 300 dpi and then closed.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(history['loss'], label='Training Loss', marker='o')
    if 'val_loss' in history:
        ax.plot(history['val_loss'], label='Validation Loss', marker='x')

    ax.set_title('Training and Validation Loss', fontsize=14)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=12)

    fig.tight_layout()
    target_file = os.path.join(output_dir, 'training_history.png')
    fig.savefig(target_file, dpi=300)
    plt.close(fig)

def visualize_evaluation_metrics(results, output_dir):
    """Draw a bar chart of the BLEU and ROUGE scores found in ``results``.

    Args:
        results: Mapping that may contain ``bleu1``..``bleu4`` and any keys
            starting with ``rouge_``; other keys are ignored.
        output_dir: Directory that receives ``evaluation_metrics.png``.

    When ``results`` holds none of those metrics (for example because
    evaluation failed) nothing is plotted and a message is printed instead.
    """
    metric_names = []
    metric_values = []

    for i in range(1, 5):
        key = f'bleu{i}'
        if key in results:
            metric_names.append(f'BLEU-{i}')
            metric_values.append(results[key])

    for key in results:
        if key.startswith('rouge_'):
            metric_names.append(f'ROUGE-{key[-1]}')
            metric_values.append(results[key])

    if not metric_values:
        # max() on an empty list raises ValueError; there is nothing to draw.
        print('No evaluation metrics available; skipping evaluation_metrics.png')
        return

    plt.figure(figsize=(12, 6))
    bars = plt.bar(metric_names, metric_values,
                   color=plt.cm.viridis(np.linspace(0, 0.8, len(metric_names))))

    # Print each score just above its bar.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height+0.01,
                f'{height:.4f}', ha='center', fontsize=10)

    plt.title('Evaluation Metrics', fontsize=14)
    plt.ylabel('Score', fontsize=12)
    # Leave headroom for the value labels; fall back to 1.0 when every score
    # is zero, since identical lower and upper limits are degenerate.
    plt.ylim(0, max(metric_values) * 1.15 or 1.0)
    plt.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'evaluation_metrics.png'), dpi=300)
    plt.close()

def visualize_caption_examples(example_pairs, output_dir, num_examples=5):
    """Render a table of ground-truth vs. generated captions as a PNG.

    Args:
        example_pairs: Sequence of ``(ground_truth, generated)`` pairs.
        output_dir: Directory that receives ``caption_examples.png``.
        num_examples: Maximum number of pairs to show; also sets the figure
            height.

    When ``example_pairs`` is empty nothing is drawn and a message is printed.
    """
    examples = example_pairs[:num_examples]
    if not examples:
        # The table needs at least one row of cell text to be built.
        print('No caption examples available; skipping caption_examples.png')
        return

    fig, ax = plt.subplots(figsize=(12, num_examples * 1.2))
    ax.axis('off')

    cell_text = [[f"Example {i+1}", gt, gen] for i, (gt, gen) in enumerate(examples)]

    table = ax.table(cellText=cell_text,
                     colLabels=["", "Ground Truth", "Generated Caption"],
                     cellLoc='left',
                     loc='center',
                     colWidths=[0.1, 0.45, 0.45])

    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1, 1.5)

    # Row 0 holds the column labels. White text needs a dark cell background
    # to be readable on the default white table.
    for j in range(3):
        table[(0, j)].set_facecolor('#4472C4')
        table[(0, j)].set_text_props(color='white', fontweight='bold')

    plt.title('Ground Truth vs. Generated Captions', fontsize=14, pad=20)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'caption_examples.png'), dpi=300)
    plt.close()

def create_results_dashboard(history, results, example_pairs, output_dir):
    """Compose a single-figure summary and save it as ``results_dashboard.png``.

    Layout (3 rows x 2 columns):
        * top-left: training / validation loss from ``history``
        * top-right: BLEU and ROUGE scores found in ``results``
        * middle-left: a bar chart of fixed, illustrative feature weights
        * middle-right: a text description of the processing pipeline
        * bottom: a table of up to three ``(ground_truth, generated)`` pairs

    Args:
        history: Mapping with optional ``'loss'`` and ``'val_loss'`` lists.
        results: Mapping that may contain ``bleu1``..``bleu4`` and keys
            starting with ``rouge_``.
        example_pairs: Sequence of ``(ground_truth, generated)`` pairs.
        output_dir: Directory that receives the PNG.
    """
    fig = plt.figure(figsize=(20, 16))
    gs = fig.add_gridspec(3, 2, height_ratios=[1, 1, 1.5])

    # --- Loss curves -------------------------------------------------------
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.plot(history.get('loss', []), marker='o', label='Training Loss')
    if 'val_loss' in history:
        ax1.plot(history.get('val_loss', []), marker='x', label='Validation Loss')
    ax1.set_title('Model Loss')
    ax1.set_ylabel('Loss')
    ax1.set_xlabel('Epoch')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # --- Evaluation metrics ------------------------------------------------
    ax2 = fig.add_subplot(gs[0, 1])

    metric_names = []
    metric_values = []

    for i in range(1, 5):
        key = f'bleu{i}'
        if key in results:
            metric_names.append(f'BLEU-{i}')
            metric_values.append(results[key])

    for key in results:
        if key.startswith('rouge_'):
            metric_names.append(f'ROUGE-{key[-1]}')
            metric_values.append(results[key])

    if metric_names:
        bars = ax2.bar(metric_names, metric_values,
                       color=plt.cm.viridis(np.linspace(0, 0.8, len(metric_names))))
        for bar in bars:
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width()/2., height+0.01,
                    f'{height:.4f}', ha='center', fontsize=9)
    else:
        # Evaluation may have produced no scores; say so instead of drawing
        # an empty chart.
        ax2.text(0.5, 0.5, 'No evaluation metrics available',
                 ha='center', va='center', fontsize=12, transform=ax2.transAxes)

    ax2.set_title('Evaluation Metrics')
    ax2.set_ylabel('Score')
    ax2.grid(axis='y', alpha=0.3)

    # --- Feature importance ------------------------------------------------
    ax3 = fig.add_subplot(gs[1, 0])

    feature_types = ['TimeSformer', 'Ball', 'Player', 'Basket', 'Court']
    # NOTE(review): these weights are fixed constants, not derived from the
    # model or from `results`; treat the chart as illustrative only.
    importance = [0.45, 0.15, 0.20, 0.10, 0.10]

    bars = ax3.bar(feature_types, importance,
                   color=plt.cm.plasma(np.linspace(0.2, 0.8, len(feature_types))))
    for bar in bars:
        height = bar.get_height()
        ax3.text(bar.get_x() + bar.get_width()/2., height+0.01,
                f'{height:.2f}', ha='center', fontsize=9)

    ax3.set_title('Approximate Feature Importance')
    ax3.set_ylabel('Relative Importance')
    ax3.grid(axis='y', alpha=0.3)

    # --- Pipeline description ----------------------------------------------
    ax4 = fig.add_subplot(gs[1, 1])
    ax4.text(0.5, 0.5, "NSVA Feature Processing Pipeline",
            ha='center', va='center', fontsize=16, fontweight='bold')
    ax4.text(0.5, 0.3, "TimeSformer → Object Detection → Court Analysis → Caption",
            ha='center', va='center', fontsize=12)
    ax4.axis('off')

    # --- Caption examples --------------------------------------------------
    ax5 = fig.add_subplot(gs[2, :])
    ax5.axis('off')

    examples = example_pairs[:3]

    if examples:
        cell_text = [[f"Example {i+1}", gt, gen] for i, (gt, gen) in enumerate(examples)]

        table = ax5.table(cellText=cell_text,
                         colLabels=["", "Ground Truth", "Generated Caption"],
                         cellLoc='left',
                         loc='center',
                         colWidths=[0.1, 0.45, 0.45])

        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)

        # Row 0 holds the column labels: dark background, white bold text.
        for j in range(3):
            table[(0, j)].set_facecolor('#4472C4')
            table[(0, j)].set_text_props(color='white', fontweight='bold')

        # Alternate the background of the data rows for readability.
        for i in range(len(examples)):
            row = i + 1
            color = '#f2f2f2' if i % 2 == 0 else 'white'
            for j in range(3):
                table[(row, j)].set_facecolor(color)
    else:
        # The table needs at least one row of cell text to be built.
        ax5.text(0.5, 0.5, 'No caption examples available',
                 ha='center', va='center', fontsize=12, transform=ax5.transAxes)

    ax5.set_title('Caption Examples', fontsize=14)

    plt.suptitle('NSVA (NBA Sports Video Analysis) Results Dashboard', fontsize=16, y=0.98)
    plt.figtext(0.5, 0.01, f'Generated on {datetime.now().strftime("%Y-%m-%d %H:%M")}',
               ha='center', fontsize=10, style='italic')

    # Reserve room for the suptitle at the top and the timestamp at the bottom.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.savefig(os.path.join(output_dir, 'results_dashboard.png'), dpi=300)
    plt.close()

def main():
    """Run the full pipeline: build data, train, evaluate, plot, save weights.

    All paths point into the mounted Drive folder ``NSVA_Results``. Artifacts
    written: per-epoch and best checkpoints, ``training_history.json``,
    ``evaluation_results.json``, the PNG figures, and the final weights.
    """
    ANNOTATIONS_FILE = '/content/drive/MyDrive/NSVA_Results/annotations/annotations.json'
    FEATURE_PATHS = {
        'timesformer': '/content/drive/MyDrive/NSVA_Results/features/timesformer',
        'ball': '/content/drive/MyDrive/NSVA_Results/features/ball',
        'player': '/content/drive/MyDrive/NSVA_Results/features/player',
        'basket': '/content/drive/MyDrive/NSVA_Results/features/basket',
        'court': '/content/drive/MyDrive/NSVA_Results/features/court',
    }

    OUTPUT_DIR = '/content/drive/MyDrive/NSVA_Results/results'
    CHECKPOINT_DIR = '/content/drive/MyDrive/NSVA_Results/checkpoints'

    MAX_SEQ_LENGTH = 30
    BATCH_SIZE = 32
    EMBED_DIM = 256
    NUM_HEADS = 4
    EPOCHS = 20
    LEARNING_RATE = 1e-4

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    dataset = NSVADataset(
        ANNOTATIONS_FILE,
        FEATURE_PATHS,
        tokenizer,
        MAX_SEQ_LENGTH
    )

    # Split by game rather than by clip: the text before the first underscore
    # of a video id is treated as its game id.
    all_video_ids = dataset.available_videos
    video_to_game = {video_id: video_id.split('_')[0] for video_id in all_video_ids}

    # Sort before shuffling so the order fed to the RNG does not depend on
    # set iteration order. The RNG itself is not seeded here, so splits still
    # differ between runs.
    game_ids = sorted(set(video_to_game.values()))
    random.shuffle(game_ids)

    train_end = int(0.8 * len(game_ids))
    val_end = int(0.9 * len(game_ids))
    # Sets give constant-time membership tests in the filters below.
    train_games = set(game_ids[:train_end])
    val_games = set(game_ids[train_end:val_end])
    test_games = set(game_ids[val_end:])

    train_ids = [vid for vid in all_video_ids if video_to_game[vid] in train_games]
    val_ids = [vid for vid in all_video_ids if video_to_game[vid] in val_games]
    test_ids = [vid for vid in all_video_ids if video_to_game[vid] in test_games]

    # NOTE(review): train_ids / val_ids / test_ids are never passed on. All
    # three datasets below are built from the same `dataset` object with the
    # same arguments, so validation and test appear to cover the training
    # videos too. Confirm how NSVADataset restricts a dataset to a list of
    # ids and wire the split in.
    train_dataset = dataset.create_tf_dataset(BATCH_SIZE, shuffle=True)
    val_dataset = dataset.create_tf_dataset(BATCH_SIZE, shuffle=False)
    test_dataset = dataset.create_tf_dataset(BATCH_SIZE, shuffle=False)

    model = NSVAModel(
        vocab_size=tokenizer.vocab_size,
        max_caption_length=MAX_SEQ_LENGTH,
        embed_dim=EMBED_DIM,
        num_heads=NUM_HEADS
    )

    # Call the model on one batch before asking for a summary.
    for inputs, targets in train_dataset.take(1):
        _ = model(inputs, training=False)
        break

    model.summary()

    history = train_model(
        model,
        train_dataset,
        val_dataset,
        epochs=EPOCHS,
        lr=LEARNING_RATE,
        patience=5,
        checkpoint_path=f'{CHECKPOINT_DIR}/nsva_model'
    )

    with open(f'{OUTPUT_DIR}/training_history.json', 'w') as f:
        json.dump(history, f, indent=2)

    visualize_training_history(history, OUTPUT_DIR)

    results = evaluate_model(
        model,
        test_dataset,
        tokenizer,
        num_samples=100,
        output_path=f'{OUTPUT_DIR}/evaluation_results.json'
    )

    visualize_evaluation_metrics(results, OUTPUT_DIR)
    visualize_caption_examples(results.get('examples', []), OUTPUT_DIR)

    # visualize_feature_importance is not defined in the reviewed part of
    # this file. Look it up defensively so a missing helper cannot abort the
    # run before the dashboard and the final weights are written.
    feature_importance_plot = globals().get('visualize_feature_importance')
    if callable(feature_importance_plot):
        feature_importance_plot(OUTPUT_DIR)
    else:
        print('visualize_feature_importance is not defined; skipping feature-importance plot')

    create_results_dashboard(history, results, results.get('examples', []), OUTPUT_DIR)

    model.save_weights(f'{CHECKPOINT_DIR}/nsva_model_final.weights.h5')

# Run the pipeline only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()