In [None]:
# Colab-only setup: mount Google Drive at /content/drive so later cells can
# read/write files under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%writefile /content/drive/MyDrive/NSVA_Results/nsva_model.py

import tensorflow as tf
import numpy as np

class PositionalEncoding(tf.keras.layers.Layer):
    """Add a fixed sinusoidal position table to a sequence of embeddings.

    The table is computed once in ``__init__`` for ``max_position`` positions.
    ``call`` adds its first ``seq_len`` rows to the input, where ``seq_len`` is
    the size of axis 1 of the input.

    Args:
        d_model: Embedding width; must match the last axis of the inputs.
        max_position: Number of positions covered by the precomputed table.
            Inputs longer than this cannot be encoded.
    """

    def __init__(self, d_model, max_position):
        super().__init__()
        self.pos_encoding = self.positional_encoding(max_position, d_model)

    def get_angles(self, pos, i, d_model):
        """Return ``pos / 10000 ** (2 * (i // 2) / d_model)``, broadcast over ``pos`` and ``i``."""
        exponent = (2 * (i // 2)) / np.float32(d_model)
        angle_rates = 1 / np.power(10000, exponent)
        return pos * angle_rates

    def positional_encoding(self, position, d_model):
        """Build the position table as a float32 tensor of shape ``(1, position, d_model)``.

        Even feature columns hold the sine of the angle, odd columns the cosine.
        """
        positions = np.arange(position)[:, None]
        dims = np.arange(d_model)[None, :]
        table = self.get_angles(positions, dims, d_model)

        # Overwrite the even / odd columns in place (strided views of ``table``).
        even_cols = slice(0, None, 2)
        odd_cols = slice(1, None, 2)
        table[:, even_cols] = np.sin(table[:, even_cols])
        table[:, odd_cols] = np.cos(table[:, odd_cols])

        # Leading axis of size 1 lets the table broadcast over the batch.
        return tf.cast(np.expand_dims(table, axis=0), dtype=tf.float32)

    def call(self, inputs):
        seq_len = tf.shape(inputs)[1]
        return inputs + self.pos_encoding[:, :seq_len, :]

class DecoderLayer(tf.keras.layers.Layer):
    """Post-norm Transformer decoder block.

    The block runs three sub-blocks in order: masked self-attention,
    attention over the encoder output, and a position-wise feed-forward
    network. Each sub-block output goes through dropout, is added back to
    its input (residual), and is then layer-normalised.

    Args:
        embed_dim: Model width; also the output width of the feed-forward net.
        num_heads: Number of attention heads; each head gets
            ``embed_dim // num_heads`` key dimensions.
        dff: Hidden width of the feed-forward network.
        rate: Dropout rate used after every sub-block.
    """

    def __init__(self, embed_dim, num_heads, dff, rate=0.1):
        super().__init__()
        head_dim = embed_dim // num_heads

        # mha1: self-attention over the target sequence.
        # mha2: attention from the target sequence into the encoder output.
        self.mha1 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=head_dim)
        self.mha2 = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=head_dim)

        self.ffn = tf.keras.Sequential(
            [
                tf.keras.layers.Dense(dff, activation='relu'),
                tf.keras.layers.Dense(embed_dim),
            ]
        )

        # Created (and assigned) in the same order as attributes 1, 2, 3 so
        # sublayer tracking / weight order is stable.
        self.layernorm1, self.layernorm2, self.layernorm3 = (
            tf.keras.layers.LayerNormalization(epsilon=1e-6) for _ in range(3)
        )
        self.dropout1, self.dropout2, self.dropout3 = (
            tf.keras.layers.Dropout(rate) for _ in range(3)
        )

    def call(self, x, enc_output, look_ahead_mask, padding_mask=None, enc_padding_mask=None, training=False):
        """Apply the decoder block.

        Args:
            x: Target-side sequence entering the block.
            enc_output: Encoder sequence used as key/value for ``mha2``.
            look_ahead_mask: Passed as ``attention_mask`` to the self-attention
                layer (Keras convention: 1/True = may attend, 0/False = masked).
            padding_mask: Accepted for API compatibility; not used.
            enc_padding_mask: Passed as ``attention_mask`` to the encoder
                attention layer (``None`` = attend to every encoder position).
            training: Enables dropout when True.

        Returns:
            Tensor with the same shape as ``x``.
        """
        self_attended = self.mha1(
            query=x, key=x, value=x,
            attention_mask=look_ahead_mask,
            training=training,
        )
        hidden = self.layernorm1(x + self.dropout1(self_attended, training=training))

        cross_attended = self.mha2(
            query=hidden, key=enc_output, value=enc_output,
            attention_mask=enc_padding_mask,
            training=training,
        )
        fused = self.layernorm2(hidden + self.dropout2(cross_attended, training=training))

        transformed = self.dropout3(self.ffn(fused), training=training)
        return self.layernorm3(fused + transformed)

class NSVAModel(tf.keras.Model):
    """Encoder/decoder captioning model over several video feature streams.

    Encoder:
    1. Five feature streams (``timesformer``, ``ball``, ``player``,
       ``basket``, ``court``) are each projected to ``embed_dim``.
    2. The ball, player and basket projections are summed into one "object"
       stream.
    3. The TimeSformer, object and court sequences are each passed through
       one shared BiLSTM temporal encoder.
    4. The results are concatenated along the sequence axis and refined by a
       stack of self-attention blocks.

    Decoder: token embedding plus sinusoidal positions, three ``DecoderLayer``
    blocks (causal self-attention plus attention over the encoder output),
    then a linear projection to vocabulary logits.

    Args:
        vocab_size: Number of output tokens.
        max_caption_length: Positions covered by the decoder positional
            encoding, i.e. the longest decoder input the model can take.
        embed_dim: Hidden width used throughout.
        num_heads: Heads in every multi-head attention layer.
        dff: Stored as ``self.dff`` only; see the note in ``__init__``.
    """

    def __init__(self, vocab_size, max_caption_length=30, embed_dim=256, num_heads=4, dff=512):
        super(NSVAModel, self).__init__()

        self.embed_dim = embed_dim
        self.vocab_size = vocab_size
        self.max_caption_length = max_caption_length
        self.num_heads = num_heads
        self.dff = dff

        # One linear projection per input stream, all into ``embed_dim``.
        self.timesformer_projection = tf.keras.layers.Dense(embed_dim)
        self.ball_projection = tf.keras.layers.Dense(embed_dim)
        self.player_projection = tf.keras.layers.Dense(embed_dim)
        self.basket_projection = tf.keras.layers.Dense(embed_dim)
        self.court_projection = tf.keras.layers.Dense(embed_dim)

        # Decoder token embedding and fixed sinusoidal positions.
        self.decoder_embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)
        self.decoder_pos_encoding = PositionalEncoding(embed_dim, max_caption_length)

        # Shared temporal encoder: the SAME instance (same weights) is applied
        # to every stream in ``process_features``. The BiLSTM uses
        # embed_dim // 2 units per direction; with the default concat merge
        # the output width is embed_dim for even embed_dim.
        self.temporal_encoder = tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(epsilon=1e-6),
            tf.keras.layers.Bidirectional(
                tf.keras.layers.LSTM(embed_dim//2, return_sequences=True)
            ),
            tf.keras.layers.LayerNormalization(epsilon=1e-6)
        ])

        # Encoder refinement stack (post-norm, no dropout). Despite the
        # ``cross_`` prefix, these blocks attend from the concatenated encoder
        # sequence to itself.
        self.cross_encoder_layers = 3
        self.cross_attention_layers = []
        self.cross_ffn_layers = []
        self.cross_norm1_layers = []
        self.cross_norm2_layers = []

        for _ in range(self.cross_encoder_layers):
            self.cross_attention_layers.append(
                tf.keras.layers.MultiHeadAttention(
                    num_heads=num_heads, key_dim=embed_dim//num_heads
                )
            )
            self.cross_ffn_layers.append(tf.keras.Sequential([
                tf.keras.layers.Dense(embed_dim*4, activation='relu'),
                tf.keras.layers.Dense(embed_dim)
            ]))
            self.cross_norm1_layers.append(tf.keras.layers.LayerNormalization(epsilon=1e-6))
            self.cross_norm2_layers.append(tf.keras.layers.LayerNormalization(epsilon=1e-6))

        # NOTE(review): the decoder FFN width is ``embed_dim * 4``, so the
        # ``dff`` constructor argument has no effect. Left unchanged because
        # passing ``dff`` would change weight shapes and break existing
        # checkpoints.
        self.decoder_layers = []
        for _ in range(3):
            self.decoder_layers.append(DecoderLayer(embed_dim, num_heads, embed_dim*4))

        self.final_layer = tf.keras.layers.Dense(vocab_size)

    @staticmethod
    def _add_sequence_axis_if_rank2(features):
        """Insert a length-1 axis at position 1 when ``features`` is rank 2; otherwise return it unchanged."""
        if len(features.shape) == 2:
            return tf.expand_dims(features, axis=1)
        return features

    def process_features(self, inputs, training=False):
        """Encode the five feature streams into one sequence.

        Args:
            inputs: Mapping with keys ``'timesformer'``, ``'ball'``,
                ``'player'``, ``'basket'`` and ``'court'``. Element ``[0]`` of
                each value is used.
                NOTE(review): the reason for ``[0]`` is not visible here.
                Presumably each value is a (features, ...) pair or has a
                leading axis of size 1; confirm against the data pipeline.
            training: Forwarded to the encoder attention layers.

        Returns:
            Encoder output of width ``embed_dim``. Its sequence axis (axis 1)
            is the TimeSformer, object and court streams concatenated in that
            order.
        """
        timesformer_features = self._add_sequence_axis_if_rank2(inputs['timesformer'][0])
        ball_features = self._add_sequence_axis_if_rank2(inputs['ball'][0])
        player_features = inputs['player'][0]
        basket_features = self._add_sequence_axis_if_rank2(inputs['basket'][0])
        court_features = self._add_sequence_axis_if_rank2(inputs['court'][0])

        # Rank 3 is used as-is, rank 4 is averaged over axis 2, and any other
        # rank gets a new axis at position 1.
        player_rank = len(player_features.shape)
        if player_rank == 4:
            # NOTE(review): presumably axis 2 indexes individual players — confirm.
            player_features = tf.reduce_mean(player_features, axis=2)
        elif player_rank != 3:
            player_features = tf.expand_dims(player_features, axis=1)

        timesformer_embed = self.timesformer_projection(timesformer_features)
        ball_embed = self.ball_projection(ball_features)
        player_embed = self.player_projection(player_features)
        basket_embed = self.basket_projection(basket_features)
        court_embed = self.court_projection(court_features)

        # Element-wise sum: the three object streams must have
        # broadcast-compatible shapes.
        object_features = ball_embed + player_embed + basket_embed

        timesformer_encoded = self.temporal_encoder(timesformer_embed)
        object_encoded = self.temporal_encoder(object_features)
        court_encoded = self.temporal_encoder(court_embed)

        encoder_output = tf.concat(
            [timesformer_encoded, object_encoded, court_encoded], axis=1
        )

        # Post-norm self-attention + FFN blocks, each with a residual connection.
        for attention, norm1, ffn, norm2 in zip(
            self.cross_attention_layers,
            self.cross_norm1_layers,
            self.cross_ffn_layers,
            self.cross_norm2_layers,
        ):
            attn_output = attention(
                query=encoder_output,
                key=encoder_output,
                value=encoder_output,
                training=training
            )
            encoder_output = norm1(encoder_output + attn_output)
            encoder_output = norm2(encoder_output + ffn(encoder_output))

        return encoder_output

    def call(self, inputs, training=False):
        """Run the model.

        Args:
            inputs: Mapping accepted by ``process_features``, optionally with
                a ``'target_ids'`` entry holding decoder input token ids.
            training: Enables dropout / training behaviour in sublayers.

        Returns:
            The encoder output when ``'target_ids'`` is missing or ``None``.
            Otherwise, vocabulary logits for every position of
            ``target_ids``.
        """
        encoder_output = self.process_features(inputs, training)

        if 'target_ids' not in inputs or inputs['target_ids'] is None:
            return encoder_output

        return self._run_decoder(inputs['target_ids'], encoder_output, training=training)

    def _run_decoder(self, target_ids, encoder_output, training=False):
        """Embed ``target_ids``, run the decoder stack with a causal mask, and return vocabulary logits."""
        look_ahead_mask = self.create_look_ahead_mask(tf.shape(target_ids)[1])

        decoder_output = self.decoder_embedding(target_ids)
        decoder_output = self.decoder_pos_encoding(decoder_output)

        for layer in self.decoder_layers:
            decoder_output = layer(
                decoder_output,
                encoder_output,
                look_ahead_mask,
                training=training
            )

        return self.final_layer(decoder_output)

    def create_look_ahead_mask(self, size):
        """Build a causal mask for ``tf.keras.layers.MultiHeadAttention``.

        Keras MHA treats ``attention_mask`` entries of 1/True as "may attend"
        and 0/False as "masked out". The mask is therefore the lower triangle,
        diagonal included: query position ``i`` may attend to key positions
        ``0..i`` only.

        The previous version returned ``1 - lower_triangle``. That is the
        "1 = blocked" convention from the additive ``mask * -1e9`` style, and
        under Keras semantics it let every position see only FUTURE tokens.
        Models trained with that mask need retraining.

        Returns:
            float32 tensor of shape ``(1, 1, size, size)``, which broadcasts
            over the batch and head axes of the attention scores.
        """
        mask = tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        return mask[tf.newaxis, tf.newaxis, :, :]

    def generate_caption(self, features, tokenizer, max_length=30, verbose=True):
        """Greedily decode one caption.

        Args:
            features: Feature mapping accepted by ``process_features``.
            tokenizer: Object exposing ``cls_token_id`` (used as the start
                token), ``sep_token_id`` (stop token) and ``decode``.
                NOTE(review): presumably a BERT-style Hugging Face tokenizer —
                confirm; tokenizers without CLS/SEP ids will not work here.
            max_length: Maximum number of decoding steps. It is clamped to
                ``self.max_caption_length`` because the decoder input reaches
                ``max_length`` tokens on the last step and the positional
                encoding covers no more positions than that.
            verbose: When True (the default, matching earlier behaviour),
                print per-step debugging output.

        Returns:
            ``tf.int32`` tensor of shape ``(1, n)``: the start token followed
            by the generated ids, ending with the stop token if it was
            produced. If decoding raises, the error and traceback are printed
            and ``[[start_token, end_token]]`` is returned.
        """
        # Read before the ``try`` so the ``except`` fallback can always use
        # them (an early failure used to surface as UnboundLocalError).
        start_token = tokenizer.cls_token_id
        end_token = tokenizer.sep_token_id

        try:
            encoder_output = self.process_features(features, training=False)
            if verbose:
                print(f"Encoder output shape: {encoder_output.shape}")
                print(f"Start token: {start_token}, End token: {end_token}")

            if max_length > self.max_caption_length:
                if verbose:
                    print(
                        f"max_length {max_length} exceeds max_caption_length "
                        f"{self.max_caption_length}; clamping."
                    )
                max_length = self.max_caption_length

            current_tokens = tf.constant([[start_token]], dtype=tf.int32)

            for step in range(max_length):
                # No KV cache: the whole prefix is re-decoded at every step.
                logits = self._run_decoder(current_tokens, encoder_output, training=False)
                next_token_logits = logits[:, -1, :]

                if verbose:
                    # "Probs" are re-normalised over the top 5 logits only.
                    top_values, top_indices = tf.math.top_k(next_token_logits[0], k=5)
                    print(f"Step {step}, Top 5 tokens: {top_indices.numpy()}, Probs: {tf.nn.softmax(top_values).numpy()}")

                # Greedy pick; the reshape to [1, 1] assumes a batch of one.
                next_token = tf.argmax(next_token_logits, axis=-1, output_type=tf.int32)
                next_token = tf.reshape(next_token, [1, 1])
                current_tokens = tf.concat([current_tokens, next_token], axis=1)

                next_token_val = int(next_token.numpy()[0][0])
                if verbose:
                    print(f"Generated token: {next_token_val}, Text so far: {tokenizer.decode(current_tokens[0].numpy())}")

                if next_token_val == end_token:
                    if verbose:
                        print("Generated end token, stopping.")
                    break

            if verbose:
                with_special = tokenizer.decode(current_tokens[0].numpy())
                without_special = tokenizer.decode(current_tokens[0].numpy(), skip_special_tokens=True)
                print(f"Final with special tokens: {with_special}")
                print(f"Final without special tokens: {without_special}")

            return current_tokens

        except Exception as e:
            # Deliberate best-effort fallback: report the failure and return a
            # minimal [start, end] sequence instead of raising.
            import traceback
            print(f"Error in generate_caption: {e}")
            print(traceback.format_exc())
            return tf.constant([[start_token, end_token]], dtype=tf.int32)


if __name__ == "__main__":
    # Only runs when the module is executed directly, not when it is imported.
    banner = "Model module loaded"
    print(banner)