# IMPORTS

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import EfficientNet_B0_Weights
import os
import cv2
import numpy as np
from torch.utils.data import Dataset
from torchvision import transforms
import glob
from PIL import Image
import time
import torch.optim as optim
from tqdm import tqdm


# TRANSFORMER BLOCK

In [2]:
class TransformerEncoderBlock(nn.Module):
    """
    Transformer Encoder Block.

    Implements a single Transformer Encoder Block as proposed in "Attention Is All You Need" (Vaswani et al.).

    This block consists of:
    - Multi-Head Self-Attention
    - Feedforward Neural Network (FFN)
    - Residual connections and Layer Normalization

    Args:
        input_dim (int): Dimension of the input embeddings/features.
        num_heads (int): Number of attention heads. Default is 8.
        ff_hidden_dim (int): Dimension of the hidden layer in the feedforward network. Default is 2048.
        dropout (float): Dropout rate applied after attention and feedforward layers. Default is 0.1.

    Example:
        >>> encoder_block = TransformerEncoderBlock(input_dim=512, num_heads=8)
        >>> x = torch.randn(32, 10, 512)  # (batch_size, sequence_length, input_dim)
        >>> output = encoder_block(x)

    Shape:
        - Input: (B, L, C) where B = batch size, L = sequence length, and C = input_dim.
        - Output: (B, L, C)
    """

    def __init__(self, input_dim, num_heads=8, ff_hidden_dim=2048, dropout=0.1):
        super(TransformerEncoderBlock, self).__init__()

        # Multi-Head Self-Attention
        self.self_attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, dropout=dropout, batch_first=True)

        # Layer Normalization for attention output
        self.norm1 = nn.LayerNorm(input_dim)

        # Feedforward Network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(input_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ff_hidden_dim, input_dim),
        )

        # Layer Normalization for FFN output
        self.norm2 = nn.LayerNorm(input_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Forward pass of the Transformer Encoder Block.

        Args:
            x (torch.Tensor): Input tensor of shape (B, L, C).
            mask (torch.Tensor, optional): Attention mask of shape (B, L) or (L, L).

        Returns:
            torch.Tensor: Output tensor of shape (B, L, C).
        """
        # Multi-Head Self-Attention with residual connection
        attn_output, _ = self.self_attention(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feedforward Network with residual connection
        ffn_output = self.ffn(x)
        x = self.norm2(x + self.dropout(ffn_output))

        return x

# VIDEO CLASSIFY MODEL

In [3]:
class VideoClassify(nn.Module):
    def __init__(self, num_classes=10, frames=16):
        super(VideoClassify, self).__init__()
        
        # Use EfficientNet B0
        self.backbone = models.efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        
        # Remove the classifier from the backbone
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])
        
        # EfficientNet B0's last feature map has 1280 output channels
        
        # Direct transformation from backbone output to 128 dimensions with stronger downsampling
        self.feature_reducer = nn.Sequential(
            nn.AdaptiveAvgPool2d((2, 2)),  # Aggressive spatial downsampling to 2x2
            nn.Conv2d(1280, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Flatten(),  # Flatten from [B, 128, 2, 2] to [B, 512]
            nn.Linear(512, 128)  # Add linear layer to get to 128 dimensions
        )

        # Keep the rest of the architecture similar, but with small dimensions
        self.pos_embedding = LearnablePositionalEmbedding(128, max_len=frames)
        
        # Smaller transformer to save memory
        self.transformer_block = TransformerEncoderBlock(input_dim=128, num_heads=4, ff_hidden_dim=128, dropout=0.1)
        
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, num_classes)
        )
        
        self.frames = frames

    def forward(self, x):
        # input is in the shape of (B, T, 3, 224, 224) where B is batch size and T is number of frames
        batch_size, seq_len = x.shape[0], x.shape[1]
        
        # Reshape to process all frames as a batch
        x = x.reshape(-1, 3, 224, 224)  # (B*T, 3, 224, 224)
        
        # Extract features from backbone
        features = self.backbone(x)  # Get the feature map
        
        # Reduce dimensions
        spatial_features = self.feature_reducer(features)  # (B*T, 128)
        
        # Reshape to separate batch and time dimensions
        spatial_features = spatial_features.reshape(batch_size, seq_len, 128)
        
        # Add learnable positional embeddings for temporal information
        spatial_features = self.pos_embedding(spatial_features)
        
        # Apply transformer for temporal modeling
        temporal_features = self.transformer_block(spatial_features)  # (B, T, 128)
        
        # Global average pooling across time dimension
        pooled_features = torch.mean(temporal_features, dim=1)  # (B, 128)
   
        # Classification
        output = self.classifier(pooled_features)  # (B, num_classes)
        
        return output


class LearnablePositionalEmbedding(nn.Module):
    """
    Learnable positional embeddings for each position in the sequence.
    Each frame position gets its own learnable embedding vector.
    """
    def __init__(self, d_model, dropout=0.1, max_len=100):
        super(LearnablePositionalEmbedding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create learnable parameter for position embeddings
        # Shape: (1, max_len, d_model) - one embedding vector per position
        self.position_embeddings = nn.Parameter(torch.zeros(1, max_len, d_model))
        
        # Initialize the position embeddings
        nn.init.trunc_normal_(self.position_embeddings, std=0.02)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]        """
        # Add position embeddings to input features
        # Each frame at position i gets the same learnable embedding vector
        x = x + self.position_embeddings[:, :x.size(1), :]
        return self.dropout(x)

# DATASET/DATALOADER

In [None]:
class VideoDataset(Dataset):
    """
    Dataset for loading videos from a directory structure.
    
    The dataset expects videos to be organized in folders where each folder represents
    a class. It samples frames at even intervals to match the required number of frames
    for the VideoClassify model.
    
    Args:
        root_dir (str): Root directory containing video folders.
        num_frames (int): Number of frames to sample from each video.
        transform (callable, optional): Optional transform to be applied on sampled frames.
        extensions (list): List of valid video file extensions. Default: ['.mp4', '.avi', '.mov']
        class_map (dict, optional): Dictionary mapping folder names to class indices.
    """
    def __init__(self, root_dir, num_frames=16, transform=None, 
                 extensions=['.mp4', '.avi', '.mov', '.MOV'], class_map=None):
        self.root_dir = root_dir
        self.num_frames = num_frames
        self.transform = transform
        self.extensions = extensions
        
        # Default transform if none provided
        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        
        # Find all video files and their corresponding classes
        self.video_paths = []
        self.video_labels = []
        
        # Get class folders
        class_folders = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        
        # Create class map if not provided
        if class_map is None:
            self.class_map = {folder: idx for idx, folder in enumerate(sorted(class_folders))}
        else:
            self.class_map = class_map
        
        # Collect videos and their labels
        for class_folder in class_folders:
            if class_folder in self.class_map:
                class_path = os.path.join(root_dir, class_folder)
                class_label = self.class_map[class_folder]
                
                # Find video files with specified extensions
                for ext in self.extensions:
                    videos = glob.glob(os.path.join(class_path, f'*{ext}'))
                    for video_path in videos:
                        self.video_paths.append(video_path)
                        self.video_labels.append(class_label)
        
        print(f"Found {len(self.video_paths)} videos across {len(class_folders)} classes")
    
    def __len__(self):
        return len(self.video_paths)
    
    def sample_frames(self, video_path):
        """
        Sample frames at even intervals from a video.
        
        Args:
            video_path (str): Path to the video file.
            
        Returns:
            list: List of sampled frames as PIL Images.
        """
        cap = cv2.VideoCapture(video_path)
        
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        
        # Get video properties
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            raise ValueError(f"Video has no frames: {video_path}")
        
        # Calculate frame indices to sample
        if total_frames <= self.num_frames:
            # If video has fewer frames than needed, duplicate frames
            indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)
        else:
            # Sample frames at even intervals
            indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)
        
        frames = []
        for idx in indices:
            # Set frame position
            cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ret, frame = cap.read()
            
            if not ret:
                # If frame read failed, create a black frame
                height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                frame = np.zeros((height, width, 3), dtype=np.uint8)
            
            # Convert BGR (OpenCV) to RGB (PIL)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        
        cap.release()
        return frames
    
    def __getitem__(self, idx):
        """
        Get video frames and label for a given index.
        
        Args:
            idx (int): Index of the video to get.
            
        Returns:
            tuple: (frames, label) where frames is a tensor of shape (num_frames, 3, H, W)
            and label is the class index.
        """
        video_path = self.video_paths[idx]
        label = self.video_labels[idx]
        
        # Sample frames from the video
        frames = self.sample_frames(video_path)
        
        # Apply transforms to each frame
        if self.transform:
            frames = [self.transform(frame) for frame in frames]
        
        # Stack frames to create a tensor of shape (T, C, H, W)
        frames_tensor = torch.stack(frames)
        
        return frames_tensor, label

# Example usage to create data loaders
def create_data_loaders(root_dir, batch_size=8, num_frames=16, num_workers=4, 
                       train_transform=None, val_transform=None, train_ratio=0.8):
    """
    Create train and validation data loaders for video classification.
    
    Args:
        root_dir (str): Root directory containing video folders
        batch_size (int): Batch size for data loaders
        num_frames (int): Number of frames to sample per video
        num_workers (int): Number of workers for data loading
        train_transform (callable, optional): Transform for training data
        val_transform (callable, optional): Transform for validation data
        train_ratio (float): Ratio of training data (0.0 to 1.0)
        
    Returns:
        tuple: (train_loader, val_loader)
    """
    from torch.utils.data import DataLoader, random_split
    
    # Default transforms
    if train_transform is None:
        train_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    if val_transform is None:
        val_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    # Create full dataset
    full_dataset = VideoDataset(root_dir=root_dir, num_frames=num_frames, transform=None)
    
    # Split dataset into train and validation
    train_size = int(train_ratio * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    
    # Create dataset objects with appropriate transforms
    train_dataset = VideoDataset(
        root_dir=root_dir,
        num_frames=num_frames,
        transform=train_transform,
        class_map=full_dataset.class_map,
    )
    
    val_dataset = VideoDataset(
        root_dir=root_dir,
        num_frames=num_frames,
        transform=val_transform,
        class_map=full_dataset.class_map,
    )
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    return train_loader, val_loader


# TRAIN/VALIDATE FUNCTIONS AND CONFIGS

In [5]:
# Training configuration
CONFIG = {
    'data_dir': 'dataset',           # Path to dataset directory
    'output_dir': './checkpoints',     # Path to save checkpoints and logs
    'frames': 75,                      # Number of frames to sample per video
    'batch_size': 1,                   # Training batch size
    'val_batch_size': 1,              # Validation batch size
    'epochs': 50,                      # Number of training epochs
    'lr': 0.001,                       # Initial learning rate
    'weight_decay': 1e-4,              # Weight decay
    'num_workers': 0,                  # Number of data loading workers
    'save_freq': 5                     # Save checkpoint every N epochs
}

def train_model():
    """Main training function."""
    # Create output directory
    os.makedirs(CONFIG['output_dir'], exist_ok=True)
    
    # Set up device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create data loaders
    train_loader, val_loader = create_data_loaders(
        root_dir=CONFIG['data_dir'],
        batch_size=CONFIG['batch_size'],
        num_frames=CONFIG['frames'],
        num_workers=CONFIG['num_workers']
    )
    
    # Get number of classes from dataset
    num_classes = len(train_loader.dataset.class_map)
    print(f"Training on {num_classes} classes")
    
    # Create model
    model = VideoClassify(num_classes=num_classes, frames=CONFIG['frames'])
    model = model.to(device)
    
    
    # Set up loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
    
    # Learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    
    # Training state
    start_epoch = 0
    best_val_acc = 0.0
    
    # Training loop
    for epoch in range(start_epoch, CONFIG['epochs']):
        print(f"\nEpoch {epoch+1}/{CONFIG['epochs']}")
        
        # Train for one epoch
        train_loss, train_acc = train_epoch(
            model=model,
            data_loader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            epoch=epoch
        )
        
        # Validate
        val_loss, val_acc = validate(
            model=model,
            data_loader=val_loader,
            criterion=criterion,
            device=device
        )
        
        # Update learning rate
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        
        # Print epoch summary
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")
        print(f"Learning Rate: {current_lr:.6f}")
        
        # Save checkpoint
        is_best = val_acc > best_val_acc
        best_val_acc = max(val_acc, best_val_acc)
        
        if is_best or (epoch + 1) % CONFIG['save_freq'] == 0:
            save_checkpoint({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_val_acc,
                'val_acc': val_acc,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'train_loss': train_loss,
            }, is_best)
    
    print(f"Training completed. Best validation accuracy: {best_val_acc:.4f}")

def train_epoch(model, data_loader, criterion, optimizer, device, epoch):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    running_corrects = 0
    processed_size = 0
    
    start_time = time.time()
    
    pbar = tqdm(data_loader, desc=f"Epoch {epoch+1}")
    for inputs, labels in pbar:
        # Move inputs and labels to device
        inputs = inputs.to(device)
        labels = labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        running_corrects += torch.sum(preds == labels.data).item()
        processed_size += inputs.size(0)
        
        # Update progress bar
        pbar.set_postfix({
            'loss': running_loss / processed_size,
            'acc': running_corrects / processed_size,
            'time': f"{time.time() - start_time:.2f}s"
        })
    
    # Calculate metrics
    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = running_corrects / len(data_loader.dataset)
    
    return epoch_loss, epoch_acc

def validate(model, data_loader, criterion, device):
    """Evaluate the model on the validation set."""
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc="Validating"):
            # Move inputs and labels to device
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # Statistics
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(outputs, 1)
            running_corrects += torch.sum(preds == labels.data).item()
    
    # Calculate metrics
    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = running_corrects / len(data_loader.dataset)
    
    return epoch_loss, epoch_acc

def save_checkpoint(state, is_best, filename='checkpoint.pth'):
    """Save training checkpoint."""
    checkpoint_path = os.path.join(CONFIG['output_dir'], filename)
    torch.save(state, checkpoint_path)
    print(f"Checkpoint saved to {checkpoint_path}")
    
    if is_best:
        best_path = os.path.join(CONFIG['output_dir'], 'best_model.pth')
        torch.save(state, best_path)
        print(f"Best model saved to {best_path}")

# ENTRY POINTS

In [6]:
train_model()

Using device: cuda
Found 57 videos across 5 classes
Found 57 videos across 5 classes
Found 57 videos across 5 classes
Training on 5 classes

Epoch 1/50


Epoch 1:  74%|███████▎  | 42/57 [15:18<05:28, 21.87s/it, loss=1.48, acc=0.381, time=905.64s]


KeyboardInterrupt: 