In [3]:
import torch
import numpy as np
import cv2
import pandas as pd
import os

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [14]:
root_dir = "./colors"
colors_dir = os.listdir(root_dir)
all_black = os.listdir(os.path.join(root_dir, colors_dir[0]))
first_vid = os.listdir(os.path.join(root_dir, colors_dir[0], all_black[0]))
print(len(colors_dir), len(all_black), len(first_vid))

8 40 30


In [15]:
class ColorVideoDataset(Dataset):
    """
    PyTorch Dataset for color video classification.
    Each video is a sequence of frames, labeled by the color folder it belongs to.
    """
    
    def __init__(self, root_dir, transform=None, sequence_length=None):
        """
        Args:
            root_dir (string): Directory with all the color folders.
            transform (callable, optional): Optional transform to be applied on frames.
            sequence_length (int, optional): Fixed length for video sequences. If None, uses all frames.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.sequence_length = sequence_length
        
        # Get all color categories
        self.color_folders = [f for f in os.listdir(root_dir) if f.startswith('colors_')]
        self.color_to_idx = {color.replace('colors_', ''): idx for idx, color in enumerate(self.color_folders)}
        self.idx_to_color = {idx: color.replace('colors_', '') for idx, color in enumerate(self.color_folders)}
        
        # Build dataset index
        self.video_paths = []
        self.labels = []
        
        for color_folder in self.color_folders:
            color_path = os.path.join(root_dir, color_folder)
            color_name = color_folder.replace('colors_', '')
            label = self.color_to_idx[color_name]
            
            # Get all video folders for this color
            video_folders = [v for v in os.listdir(color_path) if v.startswith(color_name + '_Video_')]
            
            for video_folder in video_folders:
                video_path = os.path.join(color_path, video_folder)
                self.video_paths.append(video_path)
                self.labels.append(label)
    
    def __len__(self):
        return len(self.video_paths)
    
    def __getitem__(self, idx):
        """
        Returns:
            video_frames: Tensor of shape (T, C, H, W) where T is sequence length
            label: Integer label for the color category
            video_info: Dictionary with metadata
        """
        video_path = self.video_paths[idx]
        label = self.labels[idx]
        
        # Get all frame files in the video folder
        frame_files = sorted([f for f in os.listdir(video_path) if f.endswith('.jpg')])
        
        frames = []
        for frame_file in frame_files:
            frame_path = os.path.join(video_path, frame_file)
            
            # Load image
            frame = cv2.imread(frame_path)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB
            
            # Apply transforms if provided
            if self.transform:
                frame = self.transform(frame)
            else:
                # Convert to tensor and normalize to [0, 1]
                frame = torch.from_numpy(frame).permute(2, 0, 1).float() / 255.0
            
            frames.append(frame)
        
        # Convert to tensor
        video_frames = torch.stack(frames)  # Shape: (T, C, H, W)
        
        # Handle sequence length
        if self.sequence_length is not None:
            if len(video_frames) >= self.sequence_length:
                # Randomly sample frames or take first N frames
                video_frames = video_frames[:self.sequence_length]
            else:
                # Pad with last frame if sequence is shorter
                padding_needed = self.sequence_length - len(video_frames)
                last_frame = video_frames[-1].unsqueeze(0)
                padding = last_frame.repeat(padding_needed, 1, 1, 1)
                video_frames = torch.cat([video_frames, padding], dim=0)
        
        # Create metadata
        video_info = {
            'video_path': video_path,
            'color': self.idx_to_color[label],
            'num_frames': len(frame_files),
            'video_folder': os.path.basename(video_path)
        }
        
        return video_frames, label, video_info
    
    def get_class_names(self):
        """Returns list of color class names"""
        return [self.idx_to_color[i] for i in range(len(self.idx_to_color))]
    
    def get_num_classes(self):
        """Returns number of color classes"""
        return len(self.color_to_idx)

In [18]:
# Example usage and testing
if __name__ == "__main__":
    # Create dataset
    dataset = ColorVideoDataset(root_dir="./colors", sequence_length=15)
    
    print(f"Dataset size: {len(dataset)}")
    print(f"Number of classes: {dataset.get_num_classes()}")
    print(f"Class names: {dataset.get_class_names()}")
    
    # Test loading a sample
    video_frames, label, video_info = dataset[0]
    print(f"\nSample video info:")
    print(f"Video path: {video_info['video_path']}")
    print(f"Color: {video_info['color']}")
    print(f"Label: {label}")
    print(f"Video frames shape: {video_frames.shape}")
    print(f"Original frames: {video_info['num_frames']}")
    
    # Create custom collate function for DataLoader to handle metadata
    def custom_collate_fn(batch):
        videos, labels, infos = zip(*batch)
        videos = torch.stack(videos)
        labels = torch.tensor(labels)
        return videos, labels, infos
    
    # Create DataLoader for batching
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, 
                          collate_fn=custom_collate_fn, num_workers=0)
    
    # Test batch loading
    for batch_idx, (batch_frames, batch_labels, batch_info) in enumerate(dataloader):
        print(f"\nBatch {batch_idx}:")
        print(f"Batch frames shape: {batch_frames.shape}")
        print(f"Batch labels: {batch_labels}")
        print(f"Colors in batch: {[info['color'] for info in batch_info]}")
        print(f"Video folders: {[info['video_folder'] for info in batch_info]}")
        if batch_idx == 0:  # Only show first batch
            break

Dataset size: 320
Number of classes: 8
Class names: ['black', 'blue', 'brown', 'green', 'orange', 'red', 'white', 'yellow']

Sample video info:
Video path: ./colors\colors_black\black_Video_1
Color: black
Label: 0
Video frames shape: torch.Size([15, 3, 480, 640])
Original frames: 30

Batch 0:
Batch frames shape: torch.Size([4, 15, 3, 480, 640])
Batch labels: tensor([4, 0, 4, 4])
Colors in batch: ['orange', 'black', 'orange', 'orange']
Video folders: ['orange_Video_1', 'black_Video_12', 'orange_Video_16', 'orange_Video_18']


In [None]:
# Advanced usage: Train/Validation split and data statistics
from torch.utils.data import random_split
from collections import Counter

# Create dataset
full_dataset = ColorVideoDataset(root_dir="./colors", sequence_length=30)

# Create train/validation split (80/20)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

print(f"Full dataset size: {len(full_dataset)}")
print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

# Check label distribution in training set
train_labels = [full_dataset.labels[i] for i in train_dataset.indices]
label_counts = Counter(train_labels)
print(f"\nTraining set label distribution:")
for label, count in label_counts.items():
    color_name = full_dataset.idx_to_color[label]
    print(f"  {color_name}: {count} videos")

# Create data loaders with proper collate function
def video_collate_fn(batch):
    videos, labels, infos = zip(*batch)
    videos = torch.stack(videos)
    labels = torch.tensor(labels)
    return videos, labels, infos

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, 
                         collate_fn=video_collate_fn, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, 
                       collate_fn=video_collate_fn, num_workers=0)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# Example: Calculate dataset statistics
print(f"\nDataset statistics:")
print(f"Video resolution: {full_dataset[0][0].shape[2:]} (H x W)")
print(f"Sequence length: {full_dataset[0][0].shape[0]}")
print(f"Color channels: {full_dataset[0][0].shape[1]}")
print(f"Pixel value range: [{full_dataset[0][0].min():.3f}, {full_dataset[0][0].max():.3f}]")

Full dataset size: 320
Train dataset size: 256
Validation dataset size: 64
number of training labels: 256

Training set label distribution:
  red: 32 videos
  white: 33 videos
  orange: 32 videos
  yellow: 33 videos
  green: 27 videos
  brown: 32 videos
  blue: 32 videos
  black: 35 videos

Train batches: 32
Validation batches: 8

Dataset statistics:
Video resolution: torch.Size([480, 640]) (H x W)
Sequence length: 30
Color channels: 3
Sequence length: 30
Color channels: 3
Pixel value range: [0.000, 1.000]
Pixel value range: [0.000, 1.000]


In [None]:
full_dataset = ColorVideoDataset(root_dir="./colors", sequence_length=30)

In [None]:
# Optional: Define transforms for data augmentation and preprocessing
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image

def create_video_transform(image_size=(224, 224)):
    """
    Create a transform pipeline for video frames
    """
    transforms_list = [
        lambda x: Image.fromarray(x),  # Convert numpy array to PIL Image
        Resize(image_size),
        ToTensor(),  # Converts to [0, 1] and changes to (C, H, W)
    ]
    
    return Compose(transforms_list)

transform = create_video_transform(image_size=(128, 128))
dataset_with_transforms = ColorVideoDataset(
    root_dir="./colors", 
    transform=transform,
    sequence_length=10
)