In [17]:

# Warnings ignoring
import warnings
warnings.filterwarnings("ignore")

# OS tools
import os
import typing
from pathlib import Path
from dataclasses import dataclass
from collections import Counter

# Tables, arrays, and plotters 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Torch
import torch
from torch.utils.data import Dataset

# Video Processing
from torchvision.io import read_video
from torchvision.transforms import v2
import torchvision.transforms as tt

In [3]:
@dataclass
class FilePaths:
    base = Path("../data/UCF101")
    train = Path("../data/UCF101/train.csv")
    test = Path("../data/UCF101/test.csv")
    val = Path("../data/UCF101/val.csv")

In [12]:
os.path.join(FilePaths.base, "train")

'../data/UCF101/train'

In [35]:
class VideoDataset(Dataset):
    
    def __init__(self, dir: Path, meta: Path, clip_len: int, transform: v2.Transform = None, output_format: str = "TCHW") -> None:
        """ Dataset class to load UCF101
        
        Args:
            dir (Path): Path to the directory with video files.
            meta (Path): Path to file with information of video [clip_name, clip_path, label] in csv format
            clip_len (int): The number of frames per video
            transform (Transform, optional): Optional transform to be applied on a sample
            output_format (str, optional): The format of the output video tensors. Can be either "TCHW" (default) or differ combination.
        """
        
        self.dir = dir
        self.clip_len = clip_len
        self.transform = transform
        self.output_format = output_format
        
        df = pd.read_csv(meta)
        
        labels = sorted(df["label"].unique())
        
        self._map_label2idx = {l:i for i, l in enumerate(labels)}
        self._map_idx2label = {i:l for i, l in enumerate(labels)}
        
        self.labels = df["label"].to_numpy()
        self.paths = df["clip_path"].to_numpy()
    
    def __len__(self) -> int:
        return len(self.labels)
    
    def _clip_sampler(self, frames: torch.Tensor) -> torch.Tensor:
        if frames.shape[0] < self.clip_len:
            padding_size = self.clip_len - frames.shape[0]
            last_frame = frames[-1].unsqueeze(0)
            padded_video = torch.cat([frames, last_frame.repeat(padding_size, 1, 1, 1)], dim=0)
            return padded_video
        else:
            padding_size = frames.shape[0] - self.clip_len 
            start_idx = np.random.randint(0, padding_size)
            return frames[start_idx:start_idx + self.clip_len]
    
    def _clip_format(self, frames: torch.Tensor) -> torch.Tensor:
        f_idx = {"T": 0, "C": 1, "H": 2, "W": 3}
        transpose_idx = [f_idx[i] for i in self.output_format]
        return frames.permute(*transpose_idx)
    
    def __getitem__(self, idx) -> typing.Tuple[torch.Tensor, int, int]:
        label = self.labels[idx]
        path = self.paths[idx][1:]
        
        frames, *_ = read_video(os.path.join(self.dir, path), output_format="TCHW")
        frames = frames.float() / 255
        frames = self._clip_sampler(frames)
        
        if self.transform is not None:
            frames = self.transform(frames)
        
        frames = self._clip_format(frames)
        
        return frames, self._map_label2idx[label], idx        

In [36]:
composer = v2.Compose([
    v2.ToDtype(torch.float32),
    v2.Resize(size=(224, 224)),
    v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = VideoDataset(FilePaths.base, FilePaths.train, 32, output_format="CTHW", transform=composer)

In [37]:
video, *_ = dataset[0]
video.shape

torch.Size([3, 32, 224, 224])

In [38]:
video

tensor([[[[ 2.2083,  1.9385,  0.7396,  ..., -1.4091, -1.1884, -1.1180],
          [ 2.2489,  1.8531,  0.3783,  ..., -1.2135, -0.8656, -0.7162],
          [ 2.1391,  1.6261,  0.0411,  ..., -1.0173, -0.5777, -0.3917],
          ...,
          [-0.6452, -0.6452, -0.6452,  ...,  0.6416, -0.3849, -0.5451],
          [-0.6452, -0.6452, -0.6452,  ...,  0.6672, -0.5537, -0.7578],
          [-0.6452, -0.6452, -0.6452,  ...,  0.3302, -0.7919, -0.7410]],

         [[ 0.8575,  0.8070,  0.3087,  ..., -1.3741, -1.2236, -0.8756],
          [ 0.8217,  0.7328,  0.2323,  ..., -1.2865, -1.1027, -0.8096],
          [ 0.8193,  0.7068,  0.2181,  ..., -1.2906, -1.1316, -0.9407],
          ...,
          [-0.6452, -0.6452, -0.6452,  ...,  1.6185,  0.5595, -0.3713],
          [-0.6452, -0.6452, -0.6452,  ...,  0.8799,  0.0061, -0.4373],
          [-0.6452, -0.6452, -0.6452,  ...,  0.0861, -0.5060, -0.4789]],

         [[ 0.8575,  0.8070,  0.3087,  ..., -1.3340, -1.1066, -0.8094],
          [ 0.8217,  0.7328,  