In [None]:
import os
import json
import random
from typing import List, Dict, Tuple, Optional
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from sklearn.model_selection import train_test_split

class TennisDataset(Dataset):
    """
    A custom dataset class for loading and processing tennis-related data.
    Args:
        data (List[Dict]): When sequence_length is falsy, a list of item dictionaries. When sequence_length
            is set, a list of sequences, each sequence being a list of item dictionaries. Every item
            dictionary must provide 'image_path', 'bbox', 'keypoints' and 'label'; items used in
            sequence mode must also provide 'width' and 'height'.
        transform (callable, optional): A function/transform to apply to the images. When omitted, images
            are converted with torchvision's ToTensor so that the result always exposes a tensor shape.
        sequence_length (Optional[int], optional): If provided (truthy), each entry of data is treated as a sequence of items.
    Attributes:
        LABELS (tuple): Label names; the position of a name in this tuple is the class index returned.
        sequence_length (Optional[int]): The length of the sequences if provided.
        transform (callable, optional): A function/transform to apply to the images.
        data (List[Dict]): The dataset containing image paths and annotations.
    Methods:
        __len__() -> int:
            Returns the number of items (or sequences) in the dataset.
        __getitem__(idx: int):
            Retrieves the item at the given index. If sequence_length is provided, retrieves a sequence of items.
            Args:
                idx (int): The index of the item to retrieve.
            Returns:
                tuple: A tuple containing the image(s), bounding box(es), keypoint(s), and label.
    """

    # Single source of truth for the label -> class index mapping used by both modes.
    LABELS = ('backhand', 'forehand', 'serve', 'ready_position')

    def __init__(self, data: List[Dict], transform=None, sequence_length: Optional[int] = None):
        self.sequence_length = sequence_length
        self.transform = transform
        self.data = data

    def __len__(self) -> int:
        return len(self.data)

    def _load_frame(self, item: Dict, original_size: Optional[Tuple[int, int]] = None):
        """
        Load one annotated frame and rescale its annotations to the transformed image size.
        Args:
            item (Dict): Item dictionary with 'image_path', 'bbox' and 'keypoints'.
            original_size (Optional[Tuple[int, int]]): (width, height) the annotations refer to.
                When None, the size of the image read from disk is used.
        Returns:
            tuple: (image, bbox tensor, keypoints tensor).
        """
        # Use a context manager so the underlying file handle is released right away;
        # convert() returns an independent, fully loaded copy.
        with Image.open(item['image_path']) as raw_image:
            image = raw_image.convert("RGB")

        if original_size is None:
            original_size = image.size
        original_width, original_height = original_size

        # Without a transform the PIL image has no .shape, so fall back to a plain tensor conversion.
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        # Index 1 is treated as height and index 2 as width of the transformed image.
        new_height, new_width = image.shape[1], image.shape[2]

        # Compute scaling factors
        scale_x = new_width / original_width
        scale_y = new_height / original_height

        bbox = normalize_bbox(item['bbox'], scale_x, scale_y)
        keypoints = normalize_keypoints(item['keypoints'], scale_x, scale_y)
        return image, bbox, keypoints

    def __getitem__(self, idx: int):
        if self.sequence_length:
            # If sequence_length is provided, retrieve a sequence of items
            sequence = self.data[idx]
            frames, bboxes, keypoints = [], [], []
            for item in sequence:
                # Every frame of the sequence is loaded; sequence mode scales with the
                # width/height stored in the annotation record.
                image, bbox, kpts = self._load_frame(item, (item['width'], item['height']))
                frames.append(image)
                bboxes.append(bbox)
                keypoints.append(kpts)

            # Stack frames, bounding boxes, and keypoints into tensors (new leading dimension)
            frames = torch.stack(frames)
            bboxes = torch.stack(bboxes)
            keypoints = torch.stack(keypoints)

            # The label of the first frame is used for the whole sequence
            label = torch.tensor(self.LABELS.index(sequence[0]['label']))
            return frames, bboxes, keypoints, label

        # If sequence_length is not provided, retrieve a single item; scaling is based on
        # the size of the image as read from disk.
        item = self.data[idx]
        image, bboxes, keypoints = self._load_frame(item)

        # Convert label to tensor
        label = torch.tensor(self.LABELS.index(item['label']))
        return image, bboxes, keypoints, label

def get_train_transform() -> transforms.Compose:
    """
    Build the image preprocessing/augmentation pipeline used for training data.
    The steps are applied in this order:
    - Resize to 320x320 pixels.
    - Colour jitter (brightness, contrast and saturation 0.2, hue 0.1).
    - Gaussian blur with a 5x5 kernel and sigma drawn from 0.1 to 2.0.
    - Conversion to a tensor.
    - Per-channel normalization with the given mean and standard deviation.
    Returns:
        transforms.Compose: The composed transformation pipeline.
    """

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    resize = transforms.Resize((320, 320))
    jitter = transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    blur = transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 2.0))
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

    pipeline = [resize, jitter, blur, to_tensor, normalize]
    return transforms.Compose(pipeline)
    
def get_val_transform():
    """
    Build the deterministic preprocessing pipeline used for validation and test data.
    The steps are applied in this order:
    1. Resize to 320x320 pixels.
    2. Conversion to a tensor.
    3. Per-channel normalization with the given mean and standard deviation.
    Returns:
        torchvision.transforms.Compose: The composed transformation pipeline.
    """

    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    pipeline = [
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)

def normalize_bbox(bbox: List[float], scale_x: float, scale_y: float) -> torch.Tensor:
    """
    Rescale a bounding box and convert it from [x, y, width, height] to corner form.
    Args:
        bbox (List[float]): Four values [x, y, width, height].
        scale_x (float): Multiplier applied to horizontal coordinates.
        scale_y (float): Multiplier applied to vertical coordinates.
    Returns:
        torch.Tensor: float32 tensor [xmin, ymin, xmax, ymax]; every value is clamped to be at least 0.
    """

    left, top, box_width, box_height = bbox
    scaled_corners = (
        left * scale_x,
        top * scale_y,
        (left + box_width) * scale_x,
        (top + box_height) * scale_y,
    )
    # Only the lower bound is enforced; coordinates are never clipped from above.
    clamped = [max(0, coordinate) for coordinate in scaled_corners]
    return torch.tensor(clamped, dtype=torch.float32)

def normalize_keypoints(keypoints: List[float], scale_x: float, scale_y: float) -> torch.Tensor:
    """
    Rescale the x/y coordinates of a flat keypoint list, leaving the third value of each triple untouched.
    Args:
        keypoints (List[float]): Flat list read in consecutive triples [x, y, v].
        scale_x (float): Multiplier applied to every x value.
        scale_y (float): Multiplier applied to every y value.
    Returns:
        torch.Tensor: float32 tensor with the same flat layout; x and y are clamped to be at least 0.
    """

    scaled: List[float] = []
    for offset in range(0, len(keypoints), 3):
        # Explicit indexing (rather than slicing) so an incomplete trailing triple still raises IndexError.
        raw_x = keypoints[offset]
        raw_y = keypoints[offset + 1]
        visibility = keypoints[offset + 2]
        scaled += [max(0, raw_x * scale_x), max(0, raw_y * scale_y), visibility]
    return torch.tensor(scaled, dtype=torch.float32)

def _strip_relative_prefix(path: str) -> str:
    """Remove leading '../', './' and '/' path segments without touching the first real path component."""
    while True:
        if path.startswith('../'):
            path = path[3:]
        elif path.startswith('./'):
            path = path[2:]
        elif path.startswith('/'):
            path = path[1:]
        else:
            return path


def load_annotations(json_files: List[str], base_path: str) -> List[Dict]:
    """
    Load annotations from a list of JSON files and return a list of dictionaries containing image data and annotations.
    Each file is read from '<base_path>/annotations/<json_file>'. The lower-cased name of the file's first
    category is used as the label for every image in that file. Only the first annotation that references
    an image is used; images without any annotation are skipped.
    Args:
        json_files (List[str]): A list of JSON file names containing annotation data.
        base_path (str): The base directory path where the JSON files and images are located.
    Returns:
        List[Dict]: A list of dictionaries, each containing the following keys:
            - 'image_path' (str): The path to the image file.
            - 'bbox' (List[float]): The bounding box coordinates of the annotation.
            - 'keypoints' (List[float]): The keypoints of the annotation.
            - 'label' (str): The label of the annotation (shot type).
            - 'id' (int): The ID of the image.
            - 'height' (int): The height of the image.
            - 'width' (int): The width of the image.
    Notes:
        A missing annotation file does not raise: an error message is printed and the file is skipped.
    Example:
        json_files = ['annotation1.json', 'annotation2.json']
        base_path = '/path/to/dataset'
        annotations = load_annotations(json_files, base_path)
    """

    all_data = []
    for json_file in json_files:
        json_file_path = os.path.join(base_path, 'annotations', json_file)
        try:
            with open(json_file_path, 'r') as f:
                data = json.load(f)
        except FileNotFoundError:
            print(f"Error: Annotation file {json_file_path} not found.")
            continue

        shot_type = data['categories'][0]['name'].lower()

        # Index the first annotation of every image once instead of rescanning the
        # whole annotation list for each image.
        first_annotation = {}
        for ann in data['annotations']:
            first_annotation.setdefault(ann['image_id'], ann)

        for img_info in data['images']:
            # str.lstrip('../') would strip any run of '.' and '/' characters and so
            # eat the leading dot of names such as '.hidden'; strip whole prefixes instead.
            relative_path = _strip_relative_prefix(img_info['path'])
            img_path = os.path.join(base_path, relative_path)
            annotation = first_annotation.get(img_info['id'])
            if annotation is not None:
                all_data.append({
                    'image_path': img_path,
                    'bbox': annotation['bbox'],
                    'keypoints': annotation['keypoints'],
                    'label': shot_type,
                    'id': img_info['id'],
                    'height': img_info['height'],
                    'width': img_info['width']
                })
    return all_data

def sequentialize_data(data: List[Dict], sequence_length: int, batch_size: int = 500) -> List[List[Dict]]:
    """
    Converts a list of dictionaries into a list of overlapping sequences of dictionaries,
    where each sequence has a specified length.
    The entries are ordered by their 'id' (the input list itself is left untouched), cut into
    consecutive batches of batch_size entries, and a sliding window with stride 1 is run inside
    each batch, so no sequence ever spans two batches. A batch shorter than sequence_length
    contributes no sequences.
    Args:
        data (List[Dict]): A list of dictionaries, each containing an 'id' key used for ordering.
        sequence_length (int): The length of each sequence to be generated. Must be at least 1.
        batch_size (int, optional): Number of consecutive entries that form one batch. Defaults to 500.
    Returns:
        List[List[Dict]]: A list of sequences, where each sequence is a list of dictionaries.
    Raises:
        ValueError: If sequence_length or batch_size is smaller than 1.
    Example:
        data = [{'id': 1, 'value': 'a'}, {'id': 2, 'value': 'b'}, {'id': 3, 'value': 'c'}]
        sequence_length = 2
        result = sequentialize_data(data, sequence_length)
        # result will be:
        # [
        #     [{'id': 1, 'value': 'a'}, {'id': 2, 'value': 'b'}],
        #     [{'id': 2, 'value': 'b'}, {'id': 3, 'value': 'c'}]
        # ]
    """

    if sequence_length < 1:
        raise ValueError(f"sequence_length must be at least 1, got {sequence_length}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    # sorted() keeps the caller's list in its original order (list.sort() would reorder it in place).
    ordered = sorted(data, key=lambda entry: entry['id'])

    sequences = []
    for batch_start in range(0, len(ordered), batch_size):
        batch = ordered[batch_start:batch_start + batch_size]
        # range() is empty when the batch is shorter than sequence_length.
        for window_start in range(len(batch) - sequence_length + 1):
            sequences.append(batch[window_start:window_start + sequence_length])

    return sequences

def split_data(data: List[Dict], sequence_length: int = 1) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Splits the input data into training, validation, and test sets.
    Parameters:
    -----------
    data : List[Dict]
        A list of dictionaries where each dictionary represents a data point.
    sequence_length : int, optional
        The length of the sequence to be considered for splitting. Default is 1.
        If greater than 1, the data is split sequentially.
    Returns:
    --------
    Tuple[List[Dict], List[Dict], List[Dict]]
        A tuple containing three lists of dictionaries: train_data, val_data, and test_data.
    Raises:
    -------
    ValueError
        If there is any overlap between the train, validation, and test sets.
    Notes:
    ------
    - When sequence_length is greater than 1, the data is split sequentially into categories,
      and a round-robin selection method is used to ensure balanced splits.
    - The function ensures that there are no overlaps between the train, validation, and test sets.
    - If sequence_length is 1, the data is shuffled and split randomly using a 70-15-15 ratio for train, validation, and test sets respectively.
    """
    
    if sequence_length > 1:
        # Split data sequentially
        
        train_ratio = 0.7
        val_ratio = 0.15

        total_frames = len(data)
        total_categories = 4

        # Calculate sizes
        train_size = int(train_ratio * total_frames)
        val_size = int(val_ratio * total_frames)

        # Initialize splits
        train_data = []
        val_data = []
        test_data = []
        
        category_data = [[] for _ in range(total_categories)]
        for i, ls in enumerate(category_data):
            ls.append(data[i*(len(data)//total_categories):(i+1)*(len(data)//total_categories)])
        
        category_data = [item for sublist in category_data for item in sublist]
        
        # Function to round-robin select frames
        def round_robin_select(target_size, start_indices):
            selected_data = []
            indices = start_indices.copy()
            
            while len(selected_data) < target_size:
                for i in range(total_categories):
                    if indices[i] < len(category_data[i]):
                        selected_data.append(category_data[i][indices[i]])
                        indices[i] += 1
                    if len(selected_data) == target_size:
                        break
                        
                if min(indices) >= len(category_data[0]):
                    break
              
            for i in range(len(indices)):
                indices[i] += sequence_length
                indices[i] -= 1
            return selected_data, indices

        # Initialize indices for each category
        start_indices = [0] * total_categories

        # Fill the train, val, and test sets
        train_data, start_indices = round_robin_select(train_size, start_indices)
        val_data, start_indices = round_robin_select(val_size, start_indices)
        remaining_size = total_frames - len(train_data) - len(val_data)
        test_data, start_indices = round_robin_select(remaining_size, start_indices)
 
        # Ensure no overlaps
        train_ids = {item['id'] for sequence in train_data for item in sequence}
        val_ids = {item['id'] for sequence in val_data for item in sequence}
        test_ids = {item['id'] for sequence in test_data for item in sequence}

        train_val_overlap = train_ids.intersection(val_ids)
        train_test_overlap = train_ids.intersection(test_ids)
        val_test_overlap = val_ids.intersection(test_ids)

        if train_val_overlap:
            raise ValueError(f"Overlap between train and val: {train_val_overlap}")
        if train_test_overlap:
            raise ValueError(f"Overlap between train and test: {train_test_overlap}")
        if val_test_overlap:
            raise ValueError(f"Overlap between val and test: {val_test_overlap}")
        else:
            print('No overlaps found in data')
    else:
        random.shuffle(data)
        train_data, val_test_data = train_test_split(data, train_size=0.7, random_state=42)
        val_data, test_data = train_test_split(val_test_data, train_size=0.5, random_state=42)
    
    return train_data, val_data, test_data

def get_datasets(json_files: List[str], base_path: str, sequence_length: Optional[int] = None) -> Tuple[TennisDataset, TennisDataset, TennisDataset]:
    """
    Build the training, validation and test datasets from JSON annotation files.
    The annotations are loaded, optionally grouped into sequences, split into three parts and
    wrapped in TennisDataset objects. The training dataset uses the training transform; the
    validation and test datasets both use the validation transform.
    Args:
        json_files (List[str]): Annotation file names, passed on to load_annotations.
        base_path (str): Base path to the dataset directory.
        sequence_length (Optional[int], optional): When truthy, the data is sequentialized with this
            length before splitting. Defaults to None.
    Returns:
        Tuple[TennisDataset, TennisDataset, TennisDataset]: The training, validation, and test datasets.
    """

    records = load_annotations(json_files, base_path)
    train_tf = get_train_transform()
    eval_tf = get_val_transform()

    if not sequence_length:
        # Plain per-image split.
        splits = split_data(records)
    else:
        # Build the sequences first, then split the sequences.
        windows = sequentialize_data(records, sequence_length)
        splits = split_data(windows, sequence_length)
    train_part, val_part, test_part = splits

    train_dataset = TennisDataset(train_part, transform=train_tf, sequence_length=sequence_length)
    val_dataset = TennisDataset(val_part, transform=eval_tf, sequence_length=sequence_length)
    test_dataset = TennisDataset(test_part, transform=eval_tf, sequence_length=sequence_length)
    return train_dataset, val_dataset, test_dataset


In [None]:
json_files = ['backhand.json', 'forehand.json', 'serve.json', 'ready_position.json']
base_path = "../og_dataset"

# Per-image (non-sequential) datasets
train_dataset, val_dataset, test_dataset = get_datasets(json_files, base_path)
print(f"Non-sequential - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# Datasets made of 5-frame sequences
train_dataset_seq, val_dataset_seq, test_dataset_seq = get_datasets(json_files, base_path, sequence_length=5)
print(f"Sequential - Train: {len(train_dataset_seq)}, Val: {len(val_dataset_seq)}, Test: {len(test_dataset_seq)}")


def print_sample(dataset, name):
    """Print a summary of the first entry of dataset, labelled with name."""
    # Both dataset modes return a 4-tuple; only the meaning of the first element differs.
    first, bboxes, keypoints, label = dataset[0]
    if not dataset.sequence_length:
        print(f"{name} - Sample:")
        print(f"Image shape: {first.shape}")
        print(f"BBoxes: {bboxes}")
        print(f"Keypoints: {keypoints}")
        print(f"Label: {label}")
        return
    print(f"{name} - Sequence Sample:")
    print(f"Frames shape: {first.shape}")
    print(f"BBoxes shape: {bboxes.shape}")
    print(f"Keypoints shape: {keypoints.shape}")
    print(f"Label: {label}")


# Show one sample from every dataset, non-sequential ones first
for shown_dataset, title in (
    (train_dataset, "Train Dataset"),
    (val_dataset, "Validation Dataset"),
    (test_dataset, "Test Dataset"),
):
    print_sample(shown_dataset, title)
print('\n')
for shown_dataset, title in (
    (train_dataset_seq, "Train Dataset Sequential"),
    (val_dataset_seq, "Validation Dataset Sequential"),
    (test_dataset_seq, "Test Dataset Sequential"),
):
    print_sample(shown_dataset, title)


Non-sequential - Train: 1400, Val: 300, Test: 300
No overlaps found in data
Sequential - Train: 1388, Val: 297, Test: 267
Train Dataset - Sample:
Image shape: torch.Size([3, 320, 320])
BBoxes: tensor([  0.2500, 108.4444,  32.7500, 205.3333])
Keypoints: tensor([ 19.5000, 120.0000,   1.0000,  17.5000, 118.6667,   1.0000,  18.7500,
        118.2222,   1.0000,  14.5000, 120.8889,   1.0000,  17.5000, 121.3333,
          2.0000,  12.7500, 133.7778,   2.0000,  22.2500, 125.7778,   2.0000,
         16.2500, 149.3333,   2.0000,  17.5000, 128.0000,   1.0000,  21.2500,
        134.6667,   1.0000,  11.2500, 131.1111,   2.0000,  19.2500, 159.5556,
          2.0000,  27.5000, 155.1111,   2.0000,  21.0000, 165.7778,   2.0000,
         29.0000, 173.7778,   2.0000,  19.7500, 183.1111,   2.0000,  24.5000,
        188.4444,   2.0000,  16.5000, 124.4444,   2.0000])
Label: 1
Validation Dataset - Sample:
Image shape: torch.Size([3, 320, 320])
BBoxes: tensor([ 85.0000, 141.3333, 119.0000, 262.2222])
Keypoint