In [17]:
import os
import shutil
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader

# Report whether PyTorch can see a CUDA device in this kernel.
print(torch.cuda.is_available())

True


### Loading Datasets

The KTH Actions dataset is loaded from `/home/nfs/inf6/data/datasets/kth_actions`; the code below reads from its `processed/` subdirectory.

In [18]:
import os
from glob import glob

# Subject-wise split: the two sets are disjoint, so no person contributes
# sequences to both training and validation.
# NOTE(review): 'person22' appears in neither set — confirm this is intentional.
train_persons = {
    'person02', 'person03', 'person05', 'person06',
    'person07', 'person08', 'person09', 'person10',
    'person11', 'person12', 'person13', 'person14',
    'person15', 'person16', 'person17', 'person18',
}
val_persons = {
    'person01', 'person04', 'person19', 'person20',
    'person21', 'person23', 'person24', 'person25',
}

def get_sequences_with_labels(base_dir, persons):
    """Collect the sequence paths of the given persons with their class labels.

    ``base_dir`` must contain one sub-directory per action class. Inside each
    class directory, every entry whose name starts with one of the person ids
    is treated as one sequence of that class.

    The result is deterministic: class indices follow the alphabetically
    sorted class directory names, and persons and matching paths are visited
    in sorted order. Without sorting, ``os.listdir``, ``glob`` and iteration
    over a ``set`` of strings give no ordering guarantee, so label indices and
    sequence order could differ between runs or machines.

    Args:
        base_dir: Root directory holding one sub-directory per class.
        persons: Iterable of person-id prefixes (e.g. ``'person11'``).

    Returns:
        Tuple ``(sequences, labels)`` of two equally long lists: the matched
        paths and the integer class index of each path.
    """
    # Only directories are classes; a stray file in base_dir would otherwise
    # consume a class index and shift every label after it.
    classes = sorted(
        entry for entry in os.listdir(base_dir)
        if os.path.isdir(os.path.join(base_dir, entry))
    )
    class_to_idx = {cls_name: idx for idx, cls_name in enumerate(classes)}
    ordered_persons = sorted(persons)

    sequences = []
    labels = []
    for cls in classes:
        for person in ordered_persons:
            person_sequences = sorted(glob(os.path.join(base_dir, cls, f'{person}*')))
            sequences.extend(person_sequences)
            labels.extend([class_to_idx[cls]] * len(person_sequences))

    return sequences, labels

# Dataset root; get_sequences_with_labels expects <base_dir>/<class>/<person...>.
base_dir = '/home/nfs/inf6/data/datasets/kth_actions/processed'

# Resolve the sequence paths and class labels of each subject split.
train_sequences, train_labels = get_sequences_with_labels(
    base_dir=base_dir, persons=train_persons
)
val_sequences, val_labels = get_sequences_with_labels(
    base_dir=base_dir, persons=val_persons
)


In [19]:
import cv2
import numpy as np

def load_frames(sequence_path):
    """Load every PNG frame of one sequence as a normalized RGB array.

    Args:
        sequence_path: Directory containing the frames as ``*.png`` files.
            Frames are ordered by sorted file name.
            NOTE(review): this is a lexicographic sort — confirm the frame
            names are zero-padded so it matches temporal order.

    Returns:
        List with one ``float32`` array per frame, channels in RGB order and
        pixel values divided by 255. Empty if the directory has no PNG files.

    Raises:
        IOError: If a frame file exists but cannot be decoded.
    """
    frames = []
    for frame_file in sorted(glob(os.path.join(sequence_path, '*.png'))):
        bgr = cv2.imread(frame_file)
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of deep inside cvtColor.
        if bgr is None:
            raise IOError(f'Could not read frame: {frame_file}')
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # float32 halves the memory of the float64 that `uint8 / 255.0` yields.
        frames.append(rgb.astype(np.float32) / 255.0)
    return frames

def create_subsequences(frames, subsequence_length=13):
    """Return every contiguous window of ``subsequence_length`` frames.

    Windows advance one frame at a time, so neighbouring windows share all but
    one frame. When ``frames`` is shorter than ``subsequence_length`` the
    result is an empty list.

    Args:
        frames: Sliceable sequence of frames.
        subsequence_length: Number of frames per window.

    Returns:
        List of slices of ``frames``, in order of their start index.
    """
    window_count = len(frames) - subsequence_length + 1
    return [
        frames[start:start + subsequence_length]
        for start in range(window_count)
    ]

# Smoke test: load the first training sequence and cut it into windows of the
# default length.
sequence_path = train_sequences[0]
frames = load_frames(sequence_path)
subsequences = create_subsequences(frames)


In [20]:
class VideoDataset(Dataset):
    """Dataset of fixed-length frame windows cut from labelled sequences.

    All frames of all sequences are read eagerly in ``__init__`` and kept in
    memory; each sequence is then expanded into every contiguous window of
    ``subsequence_length`` frames, and each window becomes one item.

    Each item is ``(clip, label)``: ``clip`` is a ``float32`` tensor stacked
    along a leading time axis from per-frame CxHxW tensors, and ``label`` is
    the sequence's class index as a tensor.

    Args:
        sequences: Sequence directories, as accepted by ``load_frames``.
        labels: Class index of each entry in ``sequences``.
        transform: Optional torchvision-style callable applied to each frame.
            It receives a PIL image and may return a CxHxW tensor (e.g. a
            pipeline ending in ``ToTensor``) or an HxWxC image/array.
        subsequence_length: Number of frames per item.
    """

    def __init__(self, sequences, labels, transform=None, subsequence_length=13):
        self.sequences = sequences
        self.labels = labels
        self.transform = transform
        self.subsequence_length = subsequence_length
        self.data = self.load_data()

    def load_data(self):
        """Read every sequence and return a list of ``(window, label)`` pairs."""
        data = []
        for sequence, label in zip(self.sequences, self.labels):
            frames = load_frames(sequence)
            subsequences = create_subsequences(frames, self.subsequence_length)
            for subsequence in subsequences:
                data.append((subsequence, label))
        return data

    def __len__(self):
        """Number of windows across all sequences."""
        return len(self.data)

    @staticmethod
    def _to_chw_tensor(frame):
        """Convert one frame to a float32 CxHxW tensor.

        Tensors are taken to be CxHxW already (the layout ``ToTensor``
        produces) and are only cast. Anything else is read as an HxWxC array;
        uint8 data is rescaled by 1/255.
        """
        if isinstance(frame, torch.Tensor):
            return frame.float()
        array = np.asarray(frame)
        scale = 255.0 if array.dtype == np.uint8 else 1.0
        array = np.ascontiguousarray(array, dtype=np.float32) / scale
        return torch.from_numpy(array).permute(2, 0, 1)

    def __getitem__(self, idx):
        """Return the ``idx``-th window as ``(clip, label)`` tensors."""
        subsequence, label = self.data[idx]
        if self.transform is not None:
            # Resize/crop transforms need a PIL image (or tensor) and ToTensor
            # needs a PIL image (or ndarray), so PIL is the one input type a
            # standard pipeline accepts end to end. Stored frames are floats
            # scaled by 1/255, hence the conversion back to uint8 first.
            to_pil = transforms.ToPILImage()
            # NOTE(review): the transform runs once per frame, so random
            # transforms draw new parameters for every frame of a clip —
            # confirm this is intended.
            subsequence = [
                self.transform(
                    to_pil(np.clip(np.rint(np.asarray(frame) * 255.0), 0, 255).astype(np.uint8))
                )
                for frame in subsequence
            ]
        # Transform output may already be CxHxW; permuting it again would
        # scramble the axes, so the layout is decided per frame.
        clip = torch.stack([self._to_chw_tensor(frame) for frame in subsequence])
        label = torch.tensor(label)
        return clip, label

# Per-split frame preprocessing: random crop and flip for training, a fixed
# resize plus center crop for validation. Both pipelines end with the same
# tensor conversion and channel-wise normalization.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(size=224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]),
)

# One dataset per split; VideoDataset reads its frames during construction.
train_dataset = VideoDataset(
    sequences=train_sequences,
    labels=train_labels,
    transform=data_transforms['train'],
)
val_dataset = VideoDataset(
    sequences=val_sequences,
    labels=val_labels,
    transform=data_transforms['val'],
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val_dataset, batch_size=2, shuffle=False, num_workers=4)