In [None]:
from google.colab import drive
# Mount Google Drive so the notebook can read and write under /content/drive
# (the dataset module below is written into MyDrive/NSVA_Results).
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
%%writefile /content/drive/MyDrive/NSVA_Results/nsva_dataset.py

import tensorflow as tf
import numpy as np
import json
import os
import sys
sys.path.insert(0, '/content/drive/MyDrive/NSVA_Results')

from nsva_model import NSVAModel, PositionalEncoding, DecoderLayer

class NSVADataset:
    """Video-captioning dataset pairing per-video ``.npy`` features with captions.

    Every video listed in the annotations file must have one ``<video_id>.npy``
    file in each directory of ``feature_paths``; videos missing any of them are
    dropped at construction time. ``__getitem__`` returns an ``(inputs, target)``
    pair whose structure matches the signature declared in
    ``create_tf_dataset``.
    """

    # Rank of each feature array as declared for ``tf.data``. 'player' carries
    # one more axis than the other streams. Used both for the dataset signature
    # and for shaping the zero placeholder returned when a feature fails to load.
    FEATURE_RANKS = {
        'timesformer': 2,
        'ball': 2,
        'player': 3,
        'basket': 2,
        'court': 2,
    }

    def __init__(self, annotations_file, feature_paths, tokenizer, max_seq_length=30,
                 fallback_feature_dim=768):
        """Load annotations and index the videos whose feature files all exist.

        Args:
            annotations_file: Path to a JSON file with a top-level ``'sentences'``
                list of ``{'video_id': ..., 'caption': ...}`` entries. If it cannot
                be read, the dataset is empty (an error is printed, not raised).
            feature_paths: Mapping from feature type (e.g. ``'timesformer'``) to
                the directory holding that type's ``<video_id>.npy`` files.
            tokenizer: Callable tokenizer used in ``__getitem__``; it must accept
                ``padding``/``truncation``/``max_length``/``return_tensors`` and
                expose ``cls_token_id``.
            max_seq_length: Fixed caption length (tokens) after padding/truncation.
            fallback_feature_dim: Last-axis width of the zero placeholder emitted
                when a feature file fails to load. Should match the real feature
                width so padded batches keep a consistent shape.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.fallback_feature_dim = fallback_feature_dim

        try:
            with open(annotations_file, 'r', encoding='utf-8') as f:
                self.annotations = json.load(f)
            print(f"Loaded annotations from {annotations_file}")
        except Exception as e:
            # Best effort: an unreadable annotations file yields an empty dataset.
            print(f"Error loading annotations: {e}")
            self.annotations = {'sentences': []}

        self.feature_paths = feature_paths

        # One caption per video: if a video appears several times, the last
        # caption in the file wins.
        self.video_id_to_caption = {}
        for ann in self.annotations.get('sentences', []):
            video_id = ann.get('video_id')
            caption = ann.get('caption')
            if video_id and caption:
                self.video_id_to_caption[video_id] = caption

        self.available_videos = [
            video_id for video_id in self.video_id_to_caption
            if self.check_features_exist(video_id)
        ]

    def _feature_path(self, feature_type, video_id):
        """Return the ``.npy`` path for one feature stream of one video."""
        return os.path.join(self.feature_paths[feature_type], f"{video_id}.npy")

    def check_features_exist(self, video_id):
        """Return True if every configured feature type has a file for ``video_id``."""
        return all(
            os.path.exists(self._feature_path(feature_type, video_id))
            for feature_type in self.feature_paths
        )

    def __len__(self):
        """Number of videos that have a caption and all feature files."""
        return len(self.available_videos)

    def _load_feature(self, feature_type, video_id):
        """Load one feature array plus an all-ones int32 mask over its first axis.

        On any load failure, prints the error and returns a single zero-filled
        placeholder step with a zero mask. The placeholder has the rank declared
        in ``FEATURE_RANKS`` so ``tf.data`` does not reject the element.
        """
        try:
            feature = np.load(self._feature_path(feature_type, video_id))
            mask = np.ones(feature.shape[0], dtype=np.int32)
            return feature, mask
        except Exception as e:
            print(f"Error loading {feature_type} feature for {video_id}: {e}")
            rank = self.FEATURE_RANKS.get(feature_type, 2)
            shape = (1,) * (rank - 1) + (self.fallback_feature_dim,)
            return np.zeros(shape, dtype=np.float32), np.zeros(1, dtype=np.int32)

    def __getitem__(self, idx):
        """Return ``(inputs, target)`` for the ``idx``-th available video.

        ``inputs`` maps each feature type to ``(feature, mask)`` and also holds
        ``'target_ids'`` (decoder input) and ``'target_mask'`` (the tokenizer's
        attention mask for ``target``). ``target`` is the padded caption ids.

        Raises:
            IndexError: if ``idx >= len(self)``.
        """
        if idx >= len(self):
            raise IndexError(f"Index {idx} out of range for dataset with {len(self)} items")

        video_id = self.available_videos[idx]
        caption = self.video_id_to_caption[video_id]

        features = {}
        masks = {}
        for feature_type in self.feature_paths:
            features[feature_type], masks[feature_type] = self._load_feature(
                feature_type, video_id)

        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="tf"
        )

        input_ids = tokenized["input_ids"][0]
        attention_mask = tokenized["attention_mask"][0]

        # Teacher forcing: decoder input is the target shifted right by one with
        # cls_token_id prepended.
        # NOTE(review): if the tokenizer already prepends [CLS] (BERT-style), the
        # decoder input starts with two CLS tokens and position 0 is trained to
        # predict CLS; confirm this matches the model's inference loop.
        decoder_input = tf.concat([[self.tokenizer.cls_token_id], input_ids[:-1]], axis=0)
        target = input_ids

        inputs = {name: (features[name], masks[name]) for name in self.FEATURE_RANKS}
        inputs['target_ids'] = decoder_input
        inputs['target_mask'] = attention_mask
        return inputs, target

    def create_tf_dataset(self, batch_size=64, shuffle=True):
        """Build a batched, prefetched ``tf.data.Dataset`` over this dataset.

        Items are drawn in a fresh random order on every pass when ``shuffle`` is
        True. Items that raise while being built are printed and skipped.
        Variable-length feature streams are zero-padded to the longest element of
        each batch; the padded mask entries are 0.
        """
        def generator():
            import random
            indices = list(range(len(self)))
            if shuffle:
                random.shuffle(indices)

            for idx in indices:
                try:
                    yield self[idx]
                except Exception as e:
                    print(f"Error generating item {idx}: {e}")
                    continue

        feature_shapes = {
            name: (tf.TensorShape([None] * rank), tf.TensorShape([None]))
            for name, rank in self.FEATURE_RANKS.items()
        }
        output_shapes = (
            {
                **feature_shapes,
                'target_ids': tf.TensorShape([self.max_seq_length]),
                'target_mask': tf.TensorShape([self.max_seq_length]),
            },
            tf.TensorShape([self.max_seq_length])
        )

        feature_types = {name: (tf.float32, tf.int32) for name in self.FEATURE_RANKS}
        output_types = (
            {
                **feature_types,
                'target_ids': tf.int32,
                'target_mask': tf.int32,
            },
            tf.int32
        )

        try:
            dataset = tf.data.Dataset.from_generator(
                generator,
                output_types=output_types,
                output_shapes=output_shapes
            )

            # Feature time axes vary per clip, so a plain ``batch`` fails as soon
            # as two clips in a batch differ in length. ``padded_batch`` pads each
            # component to the batch maximum with zeros, so padded steps get
            # mask 0. Fixed-size components (caption tensors) are unaffected.
            dataset = dataset.padded_batch(batch_size)
            dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

            return dataset
        except Exception as e:
            print(f"Error creating TensorFlow dataset: {e}")
            raise


if __name__ == "__main__":
    # Reached only when the module is executed directly; every import above
    # has already succeeded by this point.
    _banner = "Dataset module loaded"
    print(_banner)