### data.py

In [23]:
from abc import ABC, abstractmethod
import os
import cv2
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
import random
#from src.sampling import Sampler, FixedStepSampler
from src.augmentations import default_transforms, train_augmentations

# Maps each action-category name to its integer class label (25 classes, 0-24).
# process_dataset looks labels up with the category keys of its sources_dict,
# which split_sources takes from the category directory names, so the keys
# here have to match those directory names exactly.
CATEGORY_INDEX = {
    "brush_hair": 0,
    "cartwheel": 1,
    "catch": 2,
    "chew": 3,
    "climb": 4,
    "climb_stairs": 5,
    "draw_sword": 6,
    "eat": 7,
    "fencing": 8,
    "flic_flac": 9,
    "golf": 10,
    "handstand": 11,
    "kiss": 12,
    "pick": 13,
    "pour": 14,
    "pullup": 15,
    "pushup": 16,
    "ride_bike": 17,
    "shoot_bow": 18,
    "shoot_gun": 19,
    "situp": 20,
    "smile": 21,
    "smoke": 22,
    "throw": 23,
    "wave": 24,
}


class VideoDataCollator:
    """
    Batch builder for TimeSFormer.

    Takes a sequence of (clip, label) samples and returns a dictionary with
    the clips stacked under "pixel_values" and the labels under "labels".
    """

    def __call__(self, features):
        # Transpose the list of pairs into one tuple of clips and one of labels.
        stacked_clips, label_values = zip(*features)

        # New leading batch dimension for the clips; labels become int64.
        pixel_values = torch.stack(stacked_clips)
        labels = torch.tensor(label_values, dtype=torch.long)

        return {"pixel_values": pixel_values, "labels": labels}


def split_sources(dataset_path, train_ratio=0.8, seed=None):
    """
    Splits source folders into train and val sets before processing clips.
    Ensures that all clips from a source video stay in the same set.

    Args:
        dataset_path: directory containing one sub-directory per category.
        train_ratio: fraction (0..1) of each category's sources that goes to
            the training set; the remainder goes to validation.
        seed: optional seed for the shuffle. When given, the split is fully
            reproducible. When None, the module-level ``random`` generator is
            used, so calling ``random.seed`` beforehand still controls it.

    Returns:
        (train_sources, val_sources): two dicts mapping category name to the
        list of source entries assigned to that set.

    Raises:
        ValueError: if train_ratio is outside [0, 1].
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio}")

    rng = random if seed is None else random.Random(seed)
    train_sources = {}
    val_sources = {}

    # os.listdir returns entries in arbitrary order; sorting first makes the
    # shuffle (and therefore the split) depend only on the RNG state.
    for category in sorted(os.listdir(dataset_path)):
        category_path = os.path.join(dataset_path, category)
        if not os.path.isdir(category_path):
            continue

        instances = sorted(os.listdir(category_path))  # source folders (video IDs)
        rng.shuffle(instances)

        split_idx = int(len(instances) * train_ratio)
        train_sources[category] = instances[:split_idx]
        val_sources[category] = instances[split_idx:]

    return train_sources, val_sources


def create_clips(frames, clip_size=8, frame_paths=None):
    """
    Given a list of sampled frames, create multiple [clip_size]-frame clips.
    Each clip is returned as a tensor.

    Frames are grouped into consecutive, non-overlapping windows; leftover
    frames that do not fill a whole window are dropped.

    Args:
        frames: list of frame tensors of identical shape. An entry may be
            None to mark a frame that could not be processed.
        clip_size: number of frames per clip; must be positive.
        frame_paths: optional list parallel to ``frames`` with an identifier
            for each frame.

    Returns:
        (clips, path_clips): ``clips`` is a list of stacked tensors with a new
        leading dimension of size ``clip_size``. ``path_clips`` holds the
        matching slices of ``frame_paths`` (one per clip), and is empty when
        ``frame_paths`` is None.

    Raises:
        ValueError: if clip_size is not positive.
    """
    if clip_size <= 0:
        raise ValueError(f"clip_size must be positive, got {clip_size}")

    clips = []
    path_clips = []

    for start in range(0, len(frames) - clip_size + 1, clip_size):
        clip = frames[start : start + clip_size]
        # A None entry marks a frame that failed upstream; torch.stack cannot
        # handle it, so the whole window is skipped instead of crashing.
        if any(frame is None for frame in clip):
            continue
        clips.append(torch.stack(clip))  # Convert the clip to a tensor
        if frame_paths is not None:
            path_clips.append(frame_paths[start : start + clip_size])

    return clips, path_clips
            
            
def _read_rgb_frame(path):
    """Read an image file with OpenCV and return it as an RGB array."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise ValueError("image file could not be read")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def process_dataset(
    dataset_path,
    sources_dict,
    augmentation_transform=None,
    sampler: "Sampler | None" = None,
):
    """
    Processes dataset based on a predefined list of sources.

    For every source directory listed in ``sources_dict`` the sampler picks
    (FrameSampler) or builds (ClipSampler) frames, each frame is passed
    through ``augmentation_transform`` and then ``default_transforms``, and
    the results are grouped into 8-frame clips.

    Args:
        dataset_path: root directory containing one folder per category.
        sources_dict: mapping of category name to a list of source folder
            names; entries that are not directories are skipped.
        augmentation_transform: callable invoked as ``transform(image=...)``
            and returning a mapping with an "image" entry. Defaults to an
            identity transform.
        sampler: a FrameSampler (returns frame file paths) or a ClipSampler
            (returns ``(frames, positions)``). Defaults to a new
            ``FixedStepSampler()``.

    Returns:
        (dataset, dataset_w_paths): a list of ``(clip, label)`` tuples and a
        parallel list of ``(clip, frame_ids, label)`` tuples. For a
        FrameSampler the ids are file paths; for a ClipSampler they are
        ``"<category>_<instance>_<position>"`` strings.

    Raises:
        TypeError: if ``sampler`` is neither a FrameSampler nor a ClipSampler.
    """
    if augmentation_transform is None:
        augmentation_transform = lambda image: {"image": image}
    if sampler is None:
        # Built per call rather than once at definition time.
        sampler = FixedStepSampler()

    dataset = []
    dataset_w_paths = []

    for category, instances in tqdm(sources_dict.items()):
        category_path = os.path.join(dataset_path, category)

        for instance in instances:
            instance_path = os.path.join(category_path, instance)
            if not os.path.isdir(instance_path):
                continue

            if isinstance(sampler, FrameSampler):
                # The file path doubles as the frame identifier.
                frame_ids = list(sampler.sample(instance_path))
                raw_frames = frame_ids
                load = _read_rgb_frame
            elif isinstance(sampler, ClipSampler):
                raw_frames, positions = sampler.sample(instance_path)
                frame_ids = [
                    f"{category}_{instance}_{round(pos, 3)}" for pos in positions
                ]
                load = None  # frames are already in memory
            else:
                raise TypeError(
                    "sampler must be a FrameSampler or a ClipSampler, "
                    f"got {type(sampler).__name__}"
                )

            frames = []
            kept_ids = []
            for frame_id, raw in zip(frame_ids, raw_frames):
                try:
                    image = raw if load is None else load(raw)
                    tensor = default_transforms(
                        image=augmentation_transform(image=image)["image"]
                    )["image"]
                except Exception as e:
                    # Best effort: report the frame and leave it out together
                    # with its identifier, so frames and ids stay aligned and
                    # clip creation never receives a None placeholder.
                    print(f"Error processing frame {frame_id}: {e}")
                    continue
                frames.append(tensor)
                kept_ids.append(frame_id)

            # Create 8-frame clips
            clips, clips_path = create_clips(frames, 8, kept_ids)

            for clip, clip_ids in zip(clips, clips_path):
                dataset.append((clip, CATEGORY_INDEX[category]))
                dataset_w_paths.append((clip, clip_ids, CATEGORY_INDEX[category]))

    return dataset, dataset_w_paths


if __name__ == "__main__":
    # Smoke run: split the sources, build the training set and grab one sample.
    DATASET_PATH = "../data/HMDB_simp_clean/"
    # print(len(FixedStepSampler.sample(frame_dir="../HMDB_simp/")))
    train_sources, val_sources = split_sources(DATASET_PATH)

    # Process train and val sets separately
    # NOTE(review): only the training split is processed in this block;
    # val_sources is produced above but not used here.
    train_dataset,train_dataset_paths = process_dataset(
        DATASET_PATH, train_sources, augmentation_transform=train_augmentations
    )
    a = train_dataset[0]  # first (clip, label) tuple of the training set


100%|███████████████████████████████████████████| 25/25 [00:13<00:00,  1.87it/s]


In [24]:
train_dataset,train_dataset_paths = process_dataset(
        DATASET_PATH, train_sources, augmentation_transform=train_augmentations, sampler = InterpolationSampler()
    )

100%|███████████████████████████████████████████| 25/25 [00:45<00:00,  1.80s/it]


In [26]:
len(train_dataset_paths), train_dataset_paths[1][1]

(980,
 ['ride_bike_C0BC2937_0.0',
  'ride_bike_C0BC2937_11.286',
  'ride_bike_C0BC2937_22.571',
  'ride_bike_C0BC2937_33.857',
  'ride_bike_C0BC2937_45.143',
  'ride_bike_C0BC2937_56.429',
  'ride_bike_C0BC2937_67.714',
  'ride_bike_C0BC2937_79.0'])

In [None]:
dataset_path = "../data/HMDB_simp_clean/"
sources_dict=train_sources
augmentation_transform=train_augmentations
sampler: Sampler = FixedStepSampler()

In [8]:
# Scratch copy of the process_dataset loop body, run on notebook-level variables
# (dataset_path, sources_dict, augmentation_transform, sampler).
dataset = []
dataset_w_paths = []

for category, instances in tqdm(sources_dict.items()):
    category_path = os.path.join(dataset_path, category)
    
    # Debug override: every iteration processes this single instance,
    # regardless of the category the loop is currently on.
    # NOTE(review): the label below still comes from the loop's `category`,
    # so the same frames get stored under each category's label.
    category_path = '../data/HMDB_simp_clean/pour'
    instances=['DC13AD0D']

    for instance in instances:
        instance_path = os.path.join(category_path, instance)
        if not os.path.isdir(instance_path):
            # print(f"Skipping non-directory file: {instance_path}")
            continue

        # Load sampled frames
        frame_paths = sampler.sample(instance_path)

        frames = []

        for path in frame_paths:
            try:
                # Read as BGR, convert to RGB, augment, then apply the default transforms.
                frames.append(
                    default_transforms(
                        image=augmentation_transform(
                            image=cv2.cvtColor(
                                cv2.imread(path), cv2.COLOR_BGR2RGB
                            )
                        )["image"]
                    )["image"]
                )
            except Exception as e:
                print(f"Error processing frame {path}: {e}")
                # Placeholder keeps frames aligned with frame_paths.
                frames.append(None)

        # Create 8-frame clips
        clips, clips_path = create_clips(frames, 8,frame_paths)

        
        for idx, clip in enumerate(clips):
            dataset.append((clip, CATEGORY_INDEX[category]))
            if clips_path is not None:
                dataset_w_paths.append((clip, clips_path[idx], CATEGORY_INDEX[category]))


100%|███████████████████████████████████████████| 25/25 [00:00<00:00, 29.77it/s]


In [None]:
# Scratch version of the create_clips body, run at notebook level on the
# `frames` and `frame_paths` variables left behind by an earlier cell.
#def create_clips(frames, clip_size=8, frame_paths=None):
"""
Given a list of sampled frames, create multiple [clip_size]-frame clips.
Each clip is returned as a tensor.
"""
clips = []
path_clips = []
clip_size=8

# Step through the frames in non-overlapping windows of clip_size.
for i in range(0, len(frames) - clip_size + 1, clip_size):
    clip = frames[i : i + clip_size]
    if len(clip) == clip_size:
        clips.append(torch.stack(clip))  # Convert the clip to a tensor
        if frame_paths is not None:
            # Alternatives tried while debugging are kept commented out;
            # the active line stores one list of paths per clip.
            #path_clips.append(frame_paths[i : i + clip_size])
            #path_clips.extend(frame_paths[i:i + clip_size])
            path_clips.append(frame_paths[i : i + clip_size])
            #path_clips.append([path for path in frame_paths[i:i + clip_size]])



In [None]:
train_dataset_paths[876][1]

In [127]:
train_dataset[0][0].size()

torch.Size([8, 3, 224, 224])

In [29]:
def find_clips_by_substring(dataset_w_paths, substring):
    """
    Print the index, paths, label and length of every clip that has at least
    one path containing ``substring``.

    Args:
        dataset_w_paths (list): List of (clip, clip_paths, label) tuples.
        substring (str): Text to look for inside each clip path.
    """
    for position, entry in enumerate(dataset_w_paths):
        _clip, paths, target = entry

        # Stop scanning a clip's paths at the first hit.
        matched = False
        for candidate in paths:
            if substring in candidate:
                matched = True
                break
        if not matched:
            continue

        print(f"Index: {position}")
        print(f"Clip paths: {paths}")
        print(f"Label: {target}")
        print(f"Clip length: {len(paths)}")

find_clips_by_substring(train_dataset_paths, "1C53D816")


Index: 191
Clip paths: ['brush_hair_1C53D816_0.0', 'brush_hair_1C53D816_44.571', 'brush_hair_1C53D816_89.143', 'brush_hair_1C53D816_133.714', 'brush_hair_1C53D816_178.286', 'brush_hair_1C53D816_222.857', 'brush_hair_1C53D816_267.429', 'brush_hair_1C53D816_312.0']
Label: 0
Clip length: 8


### sampling.py

In [2]:
from abc import ABC, abstractmethod
import os
import cv2
import numpy as np


class Sampler(ABC):
    """Abstract base class for strategies that sample frames from a frame directory."""

    # File extensions treated as frames; compared case-insensitively.
    IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")

    @abstractmethod
    def sample(self, frame_dir=None, *args, **kwargs):
        """Sample from ``frame_dir``; the return value is defined by subclasses."""

    @staticmethod
    def list_frames(frame_dir):
        """
        Return the full paths of all image files in ``frame_dir``.

        Files are matched by extension regardless of case, so ".JPG" and
        ".jpg" are both picked up.

        NOTE(review): ordering is plain lexicographic sorting of the file
        names; this presumably relies on zero-padded frame numbers — confirm
        against the dataset's naming.
        """
        return [
            os.path.join(frame_dir, name)
            for name in sorted(os.listdir(frame_dir))
            if name.lower().endswith(Sampler.IMAGE_EXTENSIONS)
        ]

class FrameSampler(Sampler):
    """Intermediate parent class for samplers that select frame locations."""

    @abstractmethod
    def sample(self, frame_dir):
        ...

class ClipSampler(Sampler):
    """Intermediate parent class for samplers that create frames."""

    @abstractmethod
    def sample(self, frame_dir):
        ...


class FixedStepSampler(FrameSampler):
    """Frame sampler that keeps frames at a constant stride."""

    def __init__(self, step=8):
        # Distance, in frames, between two consecutive picks.
        self.step = step

    def sample(self, frame_dir):
        """
        Return every [step]-th frame path of ``frame_dir``, starting with the
        first listed frame.
        """
        all_frames = self.list_frames(frame_dir)
        stride = slice(None, None, self.step)
        return all_frames[stride]


class EquidistantSampler(FrameSampler):
    """
    Frame sampler that skips the first ``initial_offset`` frames and then
    takes evenly spaced frames from the rest of the video.
    """

    def __init__(self, initial_offset=5, min_frames=8):
        self.initial_offset = initial_offset
        self.min_frames = min_frames

    def sample(self, frame_dir):
        """
        Return evenly spaced frame paths from ``frame_dir``.

        The step is the number of frames after the offset divided by
        ``min_frames`` (truncated, at least 1). When the video has no more
        than ``initial_offset`` frames, all of its frames are returned.
        The number of returned frames is not guaranteed to equal
        ``min_frames``: short videos yield fewer, and truncating the step can
        yield more.
        """
        frame_files = self.list_frames(frame_dir)
        total_frames = len(frame_files)

        if total_frames <= self.initial_offset:
            return frame_files  # Not enough frames, return all

        step = max(1, int((total_frames - self.initial_offset) / self.min_frames))
        # The debug prints of total_frames/step were removed: they wrote two
        # lines to stdout for every sampled video.
        return frame_files[self.initial_offset::step]


class InterpolationSampler(ClipSampler):
    """
    Sample frames from a video by interpolating between key frames.
    Outputs interpolated frames as a numpy array (and frame positions for checking purposes).
    """

    def __init__(self, min_frames=8):
        # Number of interpolated frames produced per video.
        self.min_frames = min_frames
        # NOTE(review): not used anywhere inside this class; kept so existing
        # attribute access keeps working. Requires `transforms` to be in scope.
        self.transform = transforms.ToTensor()

    @staticmethod
    def _read_rgb(path):
        """Read one frame with OpenCV and return it as an RGB array."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread returns None for an unreadable file instead of raising.
            raise ValueError(f"Frame {path} could not be read (maybe corrupted?)")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def sample(self, frame_dir):
        """
        Build ``min_frames`` frames evenly spread over the video in ``frame_dir``.

        Positions are spaced linearly from the first to the last frame index.
        A fractional position is rendered as a linear blend of the two
        neighbouring frames and cast back to uint8.

        Returns:
            (clip_frames, positions): the stacked frames as one numpy array
            (frames along axis 0) and the list of sampled positions.

        Raises:
            ValueError: if the directory contains no frames, or a needed
                frame cannot be read.
        """
        frame_files = self.list_frames(frame_dir)
        total_frames = len(frame_files)

        if total_frames == 0:
            raise ValueError("Video is empty")

        positions = np.linspace(0, total_frames - 1, self.min_frames)

        # Decode only the frames that are actually blended (at most two per
        # position) instead of loading the whole video into memory.
        loaded = {}

        def frame_at(index):
            if index not in loaded:
                loaded[index] = self._read_rgb(frame_files[index])
            return loaded[index]

        clip = []
        for pos in positions:
            low_idx = int(np.floor(pos))
            high_idx = min(low_idx + 1, total_frames - 1)  # clamp at the last frame
            alpha = pos - low_idx  # weight of the higher neighbour

            interp_frame = (1 - alpha) * frame_at(low_idx) + alpha * frame_at(high_idx)
            interp_frame = np.clip(interp_frame, 0, 255).astype(np.uint8)
            clip.append(interp_frame)

        clip_frames = np.stack(clip, axis=0)
        positions_frames = list(positions)

        return clip_frames, positions_frames
        
class AugmentationSampler(ClipSampler):
    """
    Sampler intended to add new augmented frames to a video.
    Not implemented yet: ``sample`` currently just lists the frame paths.
    """

    def __init__(self, min_frames=8):
        self.min_frames = min_frames

    def sample(self, frame_dir):
        # TODO: Elisa-tea
        # Placeholder implementation: hand back the listed frame paths as they are.
        listed = self.list_frames(frame_dir)
        return listed


In [None]:
dataset_path = "../data/HMDB_simp_clean/"

In [None]:
final_output = process_dataset(dataset_path, train_sources, augmentation_transform=None, sampler=FixedStepSampler())

In [None]:
final_output[0]

### process_dataset function

In [None]:
def process_dataset(
    dataset_path,
    sources_dict,
    augmentation_transform=None,
    sampler: Sampler = FixedStepSampler(),
):

In [157]:
dataset_path = "../data/HMDB_simp_clean/"
sources_dict=train_sources
augmentation_transform=train_augmentations
sampler: Sampler = FixedStepSampler()

In [None]:
os.getcwd()

In [158]:
# Scratch re-run of the process_dataset body on the notebook-level variables
# dataset_path, sources_dict, augmentation_transform and sampler.
if augmentation_transform is None:
    augmentation_transform = lambda image: {"image": image}

# The loop used to be nested under the `if` above, so nothing ran whenever an
# augmentation transform was supplied; it now always runs.
dataset = []

for category, instances in tqdm(sources_dict.items()):
    category_path = os.path.join(dataset_path, category)

    for instance in instances:
        instance_path = os.path.join(category_path, instance)
        if not os.path.isdir(instance_path):
            # print(f"Skipping non-directory file: {instance_path}")
            continue

        # Load sampled frames
        frame_paths = sampler.sample(instance_path)

        frames = []

        for path in frame_paths:
            try:
                frames.append(
                    default_transforms(
                        image=augmentation_transform(
                            image=cv2.cvtColor(
                                cv2.imread(path), cv2.COLOR_BGR2RGB
                            )
                        )["image"]
                    )["image"]
                )
            except Exception as e:
                # Skip the frame: a None placeholder cannot be stacked into a clip.
                print(f"Error processing frame {path}: {e}")

        # Create 8-frame clips. Depending on which create_clips definition is
        # live in the notebook, the result is a list of clips or a
        # (clips, path_clips) tuple; accept both.
        result = create_clips(frames, 8)
        clips = result[0] if isinstance(result, tuple) else result
        for clip in clips:
            dataset.append((clip, CATEGORY_INDEX[category]))

In [None]:
#frames is a list of tensors of size X
#clips creates n 8-frame clips
print(len(frames)), print(frames[0].size()),len(clips)

In [None]:
instances[-1]

In [None]:
category_path

In [159]:
frame_paths

### InterpolationSampler

In [4]:
dataset_path = "../data/HMDB_simp_clean/"
sources_dict=train_sources
augmentation_transform=train_augmentations
sampler: Sampler = FixedStepSampler()

In [6]:
# Scratch version of the process_dataset loop that dispatches on the sampler
# type, pinned to a single instance for debugging.
dataset=[]
dataset_w_paths=[]

for category, instances in tqdm(sources_dict.items()):
        category_path = os.path.join(dataset_path, category)
        
        # Debug override: always process this one instance, whatever the
        # loop's current category is (labels below still use `category`).
        category_path = '../data/HMDB_simp_clean/pour'
        instances=['DC13AD0D']

        for instance in instances:
            instance_path = os.path.join(category_path, instance)
            if not os.path.isdir(instance_path):
                # print(f"Skipping non-directory file: {instance_path}")
                continue

            # Load sampled frames
            sampled = sampler.sample(instance_path)
            frames = []

            if isinstance(sampler, FrameSampler): 

                # Frame samplers return file paths that still have to be loaded.
                frame_paths = sampled
                
                for path in frame_paths:
                    try:
                        frames.append(
                            default_transforms(
                                image=augmentation_transform(
                                    image=cv2.cvtColor(
                                        cv2.imread(path), cv2.COLOR_BGR2RGB
                                    )
                                )["image"]
                            )["image"]
                        )
                    except Exception as e:
                        print(f"Error processing frame {path}: {e}")
                        frames.append(None)
    

            if isinstance(sampler, ClipSampler):

                # NOTE(review): this branch never assigns frame_paths, so the
                # create_clips call below uses whatever value an earlier cell
                # or branch left behind — confirm this is intended.
                # NOTE(review): InterpolationSampler.sample returns a
                # (frames, positions) tuple; iterating `sampled` directly
                # presumably predates that change — verify.
                clip_created = sampled

                for frame in clip_created:
                    try:
                        # NOTE(review): the frame is passed positionally here,
                        # while every other call site uses image=... — confirm
                        # the transform accepts that.
                        frames.append(
                            default_transforms(
                                image=augmentation_transform(
                                    frame
                                )["image"]
                            )["image"]
                        )
                    except Exception as e:
                        # NOTE(review): `path` is only bound by the FrameSampler
                        # loop above, not in this branch.
                        print(f"Error processing frame {path}: {e}")
                        frames.append(None)    
                

            # Create 8-frame clips
            clips, clips_path = create_clips(frames, 8,frame_paths)

            for idx, clip in enumerate(clips):
                dataset.append((clip, CATEGORY_INDEX[category]))
                if clips_path is not None:
                    dataset_w_paths.append((clip, clips_path[idx], CATEGORY_INDEX[category]))

            
            

100%|███████████████████████████████████████████| 25/25 [00:00<00:00, 28.06it/s]


In [99]:
len(dataset),dataset[0][0].size()

(70, torch.Size([8, 3, 224, 224]))

In [90]:
dataset_w_paths[0][0].size()

torch.Size([8, 3, 224, 224])

In [9]:
print(f"Sampler type: {type(sampler)}, ClipSampler? {isinstance(sampler, ClipSampler)}, FrameSampler? {isinstance(sampler, FrameSampler)}")


Sampler type: <class '__main__.FixedStepSampler'>, ClipSampler? False, FrameSampler? True


In [11]:
def list_frames(frame_dir):
    """
    Return the full paths of all image files in ``frame_dir``, sorted by
    file name.

    Extensions are matched case-insensitively, so ".JPG"/".PNG" files are
    included alongside ".jpg"/".png"/".jpeg".
    """
    return [
        os.path.join(frame_dir, file)
        for file in sorted(os.listdir(frame_dir))
        if file.lower().endswith((".jpg", ".png", ".jpeg"))
    ]

In [13]:
import numpy as np

In [192]:
# Walks every category/instance pair without processing anything.
# NOTE(review): presumably run only to leave `instance` / `instance_path`
# bound to the last visited entry for the cells below — confirm.
for category, instances in tqdm(sources_dict.items()):
        category_path = os.path.join(dataset_path, category)
        
        # Debug override: instance paths are always built under the 'pour' folder.
        category_path = '../data/HMDB_simp_clean/pour'
        #instances=['DC13AD0D']

        for instance in instances:
            instance_path = os.path.join(category_path, instance)
            if not os.path.isdir(instance_path):
                # print(f"Skipping non-directory file: {instance_path}")
                continue


100%|█████████████████████████████████████████| 25/25 [00:00<00:00, 2899.50it/s]


In [15]:
instance

'DC13AD0D'

In [14]:
# Scratch version of InterpolationSampler.sample, run at notebook level on
# `instance_path`. Later cells inspect the variables it leaves behind
# (positions, low_idx, high_idx, frame, clip_final, position_list).
frame_files = list_frames(instance_path)
total_frames = len(frame_files)

if total_frames == 0:
    raise ValueError("Video is empty")

else:
    # Load every frame of the video as an RGB array.
    video = []
    for f in frame_files:
        frame = cv2.cvtColor(cv2.imread(f), cv2.COLOR_BGR2RGB)
        video.append(frame)

    # 8 evenly spaced (possibly fractional) positions from first to last frame.
    positions = np.linspace(0, total_frames - 1, 8)
    clip = []
    for pos in positions:
        low_idx = int(np.floor(pos))
        high_idx = min(low_idx + 1, total_frames - 1)  # clamp at the last frame
        alpha = pos - low_idx  # fractional part: weight of the higher frame

        frame_low = video[low_idx]
        frame_high = video[high_idx]

        # Linear blend of the two neighbouring frames, then back to uint8.
        interp_frame = (1 - alpha) * frame_low + alpha * frame_high
        interp_frame = np.clip(interp_frame, 0, 255).astype(np.uint8)
        clip.append(interp_frame)

# Stack the frames along a new first axis; keep the positions as a plain list.
clip_final = np.stack(clip, axis=0)
position_list=list(positions)

In [70]:
positions, (total_frames-1)/7

(array([  0.,  41.,  82., 123., 164., 205., 246., 287.]), 41.0)

In [162]:
list(positions)

[0.0, 41.0, 82.0, 123.0, 164.0, 205.0, 246.0, 287.0]

In [71]:
low_idx,high_idx

(287, 287)

In [152]:
type(clip_final),type(frame_paths)

(numpy.ndarray, NoneType)

In [77]:
for frame in clip_final:
    image = augmentation_transform(image=frame)["i

In [170]:
# Apply the augmentation and the default transforms to each interpolated
# frame, then group the results into 8-frame clips.
clip_created = clip_final
transformed_clip = []

for frame_idx, frame in enumerate(clip_created):
    try:
        transformed_clip.append(
            default_transforms(
                image=augmentation_transform(
                    image=frame
                )["image"]
            )["image"]
        )
    except Exception as e:
        # Report the frame by its index in the clip: this cell never binds
        # `path`, so the old message failed or showed a stale value.
        print(f"Error processing frame {frame_idx}: {e}")
        # Placeholder keeps transformed_clip aligned with position_list.
        # NOTE(review): torch.stack inside create_clips cannot stack None,
        # so a failed frame still makes the next line raise.
        transformed_clip.append(None)

clips,clips_path = create_clips(transformed_clip, 8,position_list)

In [None]:
position_list

In [184]:
# Collect the clips built above into (clip, label) and (clip, ids, label) lists.
dataset=[]
dataset_w_paths=[]

# NOTE(review): `category` is not assigned in this cell; the label comes from
# whatever value an earlier loop left in it, which need not be the category
# the frames were read from — verify.
for idx, clip in enumerate(clips):    
    dataset.append((clip, CATEGORY_INDEX[category]))
    if clips_path is not None:
        dataset_w_paths.append((clip, clips_path[idx], CATEGORY_INDEX[category]))


In [191]:
len(dataset_w_paths),dataset_w_paths[0][2]

(1, 21)

In [138]:
len(transformed_clip)

8

In [139]:
len(transformed_clip),transformed_clip[0].size()

(8, torch.Size([3, 224, 224]))

In [131]:
len(clips)

4

In [142]:
len(clips),clips[0].size()

(1, torch.Size([8, 3, 224, 224]))

In [135]:
def create_clips(frames, clip_size=8, frame_paths=None):
    """
    Group sampled frames into consecutive, non-overlapping clips of
    [clip_size] frames, each stacked into a single tensor.

    Returns the list of clips when ``frame_paths`` is None; otherwise returns
    a tuple of (clips, list of matching ``frame_paths`` slices).
    """
    want_paths = frame_paths is not None
    stacked = []
    grouped_paths = []

    for begin in range(0, len(frames) - clip_size + 1, clip_size):
        end = begin + clip_size
        window = frames[begin:end]
        if len(window) != clip_size:
            continue
        stacked.append(torch.stack(window))  # frames -> one clip tensor
        if want_paths:
            grouped_paths.append(frame_paths[begin:end])

    if want_paths:
        return stacked, grouped_paths
    return stacked

array([[[[ 59,   9,   2],
         [ 82,  29,  21],
         [119,  65,  53],
         ...,
         [118,  21,   4],
         [121,  19,   4],
         [121,  19,   4]],

        [[ 82,  29,  23],
         [ 64,  11,   3],
         [ 90,  36,  24],
         ...,
         [122,  22,   6],
         [123,  21,   6],
         [123,  21,   6]],

        [[104,  49,  42],
         [ 68,  14,   4],
         [ 71,  17,   5],
         ...,
         [124,  25,   6],
         [126,  25,   7],
         [126,  25,   7]],

        ...,

        [[143,  99,  86],
         [142,  98,  85],
         [143, 100,  84],
         ...,
         [ 10,   4,   4],
         [  9,   5,   2],
         [  7,   6,   1]],

        [[160, 114, 101],
         [157, 113, 100],
         [153, 110,  94],
         ...,
         [  9,   4,   1],
         [  8,   5,   0],
         [  8,   5,   0]],

        [[160, 114, 101],
         [156, 110,  97],
         [141,  98,  82],
         ...,
         [  9,   4,   1],
        

In [29]:
frame.shape

(240, 320, 3)

In [None]:
import os
import cv2
import torch

# Scratch torch-based interpolation of the frames found in `instance_path`.

# 1. List frames manually
frame_files = sorted([
    os.path.join(instance_path, f)
    for f in os.listdir(instance_path)
    if f.lower().endswith(('.jpg', '.jpeg', '.png'))
])

print(f"Found {len(frame_files)} frames")

total_frames = len(frame_files)

if total_frames == 0:
    raise ValueError("No frames found in the directory!")

# 2. Load frames manually with cv2
video = []
for f in frame_files:
    try:
        frame = cv2.imread(f)
        if frame is None:
            raise ValueError(f"Frame {f} could not be read (maybe corrupted?)")
        
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # HWC uint8 -> CHW float in [0, 1]
        frame_tensor = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
        video.append(frame_tensor)

    except Exception as e:
        print(f"Error reading frame {f}: {e}")

print(f"Successfully loaded {len(video)} frames")

# Frames that failed to load are missing from `video`, so every index below
# must be based on the number of loaded frames, not on the file count.
loaded_frames = len(video)
if loaded_frames == 0:
    raise ValueError("None of the frames could be loaded!")

# 3. Interpolation positions
min_frames = 8  # or whatever you want
positions = torch.linspace(0, loaded_frames - 1, steps=min_frames)
print(f"Interpolation positions: {positions}")

# 4. Interpolate
clip = []
for pos in positions:
    low_idx = int(torch.floor(pos).item())
    high_idx = min(low_idx + 1, loaded_frames - 1)  # clamp at the last loaded frame
    alpha = pos - low_idx  # weight of the higher neighbour

    frame_low = video[low_idx]
    frame_high = video[high_idx]

    interp_frame = (1 - alpha) * frame_low + alpha * frame_high
    clip.append(interp_frame)

clip_final = torch.stack(clip)
print(f"Final clip tensor shape: {clip_final.shape}")


In [None]:
clip_final[0].size()

In [None]:
sample(sources_dict)

In [None]:
video_path = "../data/HMDB_simp_clean/brush_hair/020E3BBA"

In [None]:
os.getcwd()