# Imports

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from tqdm import tqdm

# Config

In [None]:
# Preprocessed plays table (parquet) loaded by NFLPlayDataset.
PROCESSED_PATH: str = (
    "/lustre/proyectos/p037/datasets/processed/plays_processed.parquet"
)

# General parameters
BATCH_SIZE: int = 8  # plays per DataLoader batch (each play expands into several frame pairs)
NUM_WORKERS: int = 4  # DataLoader worker processes

# Dataset

In [None]:
class NFLPlayDataset(Dataset):
    """
    Build consecutive-frame pairs (X_t, X_{t+1}) for each play, for SSL
    pretraining.

    One item corresponds to one play, identified by ``(game_id, play_id)``.
    ``__getitem__`` returns a list of ``(X_t, X_tp1)`` tuples, one per pair of
    consecutive frames. Each tensor is float32 with shape
    ``(n_players, n_features)``. Plays that cannot be turned into aligned
    pairs are returned as ``None`` so that ``collate_fn`` can drop them.

    Args:
        parquet_path: Path to the processed parquet file. It must contain
            ``game_id``, ``play_id``, ``frame_id``, ``nfl_id`` and every
            feature column.
        feature_cols: Optional sequence of per-player feature columns to use
            instead of the class default ``feature_cols``.

    Raises:
        KeyError: If any feature column is missing from the parquet file.
    """

    # Default per-player features (position, speed, etc.). An instance may
    # override this through the ``feature_cols`` constructor argument.
    feature_cols = ("x", "y", "s", "a", "o", "dir")

    def __init__(self, parquet_path, feature_cols=None):
        self.df = pd.read_parquet(parquet_path)
        print(f"Loaded Data: {self.df.shape[0]:,} filas.")

        # Only shadow the class default when the caller supplies columns.
        if feature_cols is not None:
            self.feature_cols = tuple(feature_cols)

        # Fail fast here rather than with a KeyError inside a DataLoader worker.
        missing = [c for c in self.feature_cols if c not in self.df.columns]
        if missing:
            raise KeyError(f"Missing feature columns in {parquet_path}: {missing}")

        # Group plays (per game_id, play_id).
        # NOTE(review): list(groupby) keeps one DataFrame per play alive (in
        # every worker process too); storing per-play row indices instead
        # would be lighter if memory becomes an issue.
        self.groups = list(self.df.groupby(["game_id", "play_id"]))
        print(f"Found plays: {len(self.groups):,}")

    def __len__(self):
        """Return the number of plays (one item per play)."""
        return len(self.groups)

    def __getitem__(self, idx):
        """
        Return the list of ``(X_t, X_tp1)`` pairs for play ``idx``.

        Returns ``None`` (filtered out by ``collate_fn``) when the play
        - has fewer than two frames, so no pair can be formed, or
        - does not form a complete frame x player grid, because some player is
          missing from a frame or a ``(frame_id, nfl_id)`` row is duplicated.

        Such plays cannot be reshaped into aligned
        ``(n_frames, n_players, n_features)`` arrays. A missing row makes the
        reshape fail. A duplicate that offsets a missing row reshapes without
        error but misaligns players across frames.
        """
        _, play_df = self.groups[idx]

        # Sort by frame, then player, so the reshape below is frame-major and
        # every frame lists its players in the same order.
        play_df = play_df.sort_values(["frame_id", "nfl_id"]).reset_index(drop=True)
        frames = play_df["frame_id"].unique()
        nfl_ids = play_df["nfl_id"].unique()

        # A play with <2 frames yields no pair, so it is not usable for SSL.
        if len(frames) < 2:
            return None

        # Suppose no (frame, player) key is duplicated. Then exactly F * P rows
        # means every player appears in every frame, and the reshape is valid.
        is_full_grid = (
            len(play_df) == len(frames) * len(nfl_ids)
            and not play_df.duplicated(["frame_id", "nfl_id"]).any()
        )
        if not is_full_grid:
            return None

        # Convert the whole play once: (n_frames, n_players, n_features).
        feats = play_df[list(self.feature_cols)].to_numpy(dtype=np.float32)
        X = torch.tensor(
            feats.reshape(len(frames), len(nfl_ids), -1), dtype=torch.float32
        )

        # Generate pairs (X_t, X_{t+1}).
        return [(X[i], X[i + 1]) for i in range(len(frames) - 1)]

# Collate data

In [None]:
def collate_fn(batch):
    """
    Merge a list of dataset items into two stacked tensors.

    Each non-None item is one play's list of ``(X_t, X_tp1)`` pairs. The pairs
    of all plays in the batch are concatenated in order. The current-frame and
    next-frame sides are then each stacked along a new leading dimension.

    Returns:
        ``(X_t_batch, X_tp1_batch)``, or ``(None, None)`` when the batch is
        empty or every item in it is ``None``.

    NOTE(review): ``torch.stack`` requires identically shaped tensors, so all
    plays in one batch must have the same number of players.
    """
    # Plays the dataset could not use come back as None; ignore them.
    usable = [play for play in batch if play is not None]
    if len(usable) == 0:
        return None, None

    # One flat, order-preserving list of pairs across every play.
    flat_pairs = [(cur, nxt) for play in usable for cur, nxt in play]
    current = torch.stack([cur for cur, _ in flat_pairs])
    following = torch.stack([nxt for _, nxt in flat_pairs])
    return current, following

# Dataloader

In [None]:
# Load all plays once and wrap them in a shuffling, multi-worker loader.
# collate_fn turns each batch of plays into two stacked frame-pair tensors.
dataset = NFLPlayDataset(PROCESSED_PATH)
dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,  # drops skipped (None) plays and flattens pairs
    num_workers=NUM_WORKERS,
    shuffle=True,
    batch_size=BATCH_SIZE,  # counted in plays, not in frame pairs
)

# Test

In [None]:
# Smoke test: print the shapes of the first batch that contains usable plays.
for batch in dataloader:
    if batch[0] is None:
        continue  # every play in this batch was skipped by the dataset
    X_t, X_tp1 = batch
    print(f"Batch X_t: {X_t.shape}, Batch X_tp1: {X_tp1.shape}")
    break