# In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample
)
from pytorchvideo.data.encoded_video import EncodedVideo
import numpy as np
from pathlib import Path
from sklearn.model_selection import KFold
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd
import os

class VideoDataset(Dataset):
    """Dataset of video clips stored as ``root_dir/<class_name>/*.mp4``.

    Each immediate sub-directory of ``root_dir`` is one class. Classes are
    numbered ``0..N-1`` in sorted directory-name order; non-directory entries
    in ``root_dir`` are ignored and do not consume a class index.

    Attributes set in ``__init__``:
        root_dir: ``Path`` of the dataset root.
        clip_duration: end time (seconds) of the clip taken from each video.
        transform: optional callable applied to the clip dict.
        samples: list of ``(video_path_str, class_index)`` tuples.
        class_to_idx: mapping from class directory name to class index.
    """
    def __init__(self, root_dir, clip_duration=10, transform=None):
        self.root_dir = Path(root_dir)
        self.clip_duration = clip_duration
        self.transform = transform

        # Get all video paths and their labels
        self.samples = []
        self.class_to_idx = {}

        # Filter to directories *before* enumerating so stray files in the
        # root cannot create gaps in the label indices.
        class_dirs = sorted(
            entry for entry in self.root_dir.glob('*') if entry.is_dir()
        )
        for idx, class_dir in enumerate(class_dirs):
            self.class_to_idx[class_dir.name] = idx

            # Sorted so the sample order is deterministic across runs and
            # file systems.
            for video_path in sorted(class_dir.glob('*.mp4')):
                self.samples.append((str(video_path), idx))

    def __len__(self):
        """Return the number of discovered video files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one clip and return ``(video, label)``.

        The clip spans ``[0, clip_duration]`` seconds of the file. When a
        transform is set it receives the clip dict returned by
        ``EncodedVideo.get_clip``; the ``"video"`` entry of the (possibly
        transformed) dict is returned together with the integer class label.
        """
        video_path, label = self.samples[idx]

        # Load video
        video = EncodedVideo.from_path(video_path)
        try:
            # Extract clip
            video_data = video.get_clip(start_sec=0, end_sec=self.clip_duration)
        finally:
            # Release the underlying container/file handle even if decoding
            # fails; otherwise handles accumulate over an epoch.
            video.close()

        # Apply transform if specified
        if self.transform is not None:
            video_data = self.transform(video_data)

        return video_data["video"], label

def create_video_transform(side_size=256, crop_size=256, num_frames=8):
    """Build the preprocessing pipeline applied to the ``"video"`` entry of a clip dict.

    The steps run in this order: uniform temporal subsampling to
    ``num_frames`` frames, division by 255, per-channel normalisation with
    mean 0.45 and std 0.225, short-side rescale to ``side_size`` and a centre
    crop of ``crop_size`` x ``crop_size``.

    Args:
        side_size: Target length of the shorter spatial side.
        crop_size: Height and width of the centre crop.
        num_frames: Number of frames kept by the temporal subsampling.

    Returns:
        An ``ApplyTransformToKey`` wrapping the composed steps for key ``"video"``.
    """
    channel_mean = [0.45, 0.45, 0.45]
    channel_std = [0.225, 0.225, 0.225]

    steps = [
        UniformTemporalSubsample(num_frames),
        # presumably maps 8-bit pixel values into [0, 1] -- confirm input range
        Lambda(lambda frames: frames / 255.0),
        NormalizeVideo(channel_mean, channel_std),
        ShortSideScale(size=side_size),
        CenterCropVideo(crop_size=(crop_size, crop_size)),
    ]
    return ApplyTransformToKey(key="video", transform=Compose(steps))

def modify_model_head(model, num_classes=3):
    """Replace the classification head of ``model`` with a ``num_classes``-way layer.

    Two layouts are handled:

    * Models exposing ``model.blocks`` whose last block has a ``proj``
      attribute: ``proj`` is replaced by a new ``nn.Linear`` with the same
      ``in_features``.
    * Otherwise, the *last* ``nn.Linear`` among the model's direct children
      is replaced. Earlier linear layers are left untouched so hidden layers
      keep their output width.

    Args:
        model: The ``nn.Module`` to modify in place.
        num_classes: Output size of the new head.

    Returns:
        The same ``model`` object. If no head could be located a
        ``UserWarning`` is emitted and the model is returned unchanged.
    """
    import warnings

    blocks = getattr(model, 'blocks', None)
    # ``if blocks`` also guards against an empty block container, where
    # ``blocks[-1]`` would raise IndexError.
    last_block = blocks[-1] if blocks else None

    if last_block is not None and hasattr(last_block, 'proj'):
        in_features = last_block.proj.in_features
        last_block.proj = nn.Linear(in_features, num_classes)
        return model

    # Generic approach for other model architectures: only the final
    # top-level Linear layer is treated as the classifier.
    linear_names = [
        name for name, module in model.named_children()
        if isinstance(module, nn.Linear)
    ]
    if linear_names:
        head_name = linear_names[-1]
        in_features = getattr(model, head_name).in_features
        setattr(model, head_name, nn.Linear(in_features, num_classes))
    else:
        warnings.warn(
            "modify_model_head: no classification head found; "
            "model returned unchanged",
            stacklevel=2,
        )
    return model

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    The model is put in training mode; for every batch the gradients are
    reset, the loss is back-propagated and the optimizer takes one step.

    Returns:
        A tuple ``(mean_loss, predictions, true_labels)`` where ``mean_loss``
        is the sum of per-batch loss values divided by the number of batches,
        and the two lists hold the arg-max class predictions and the targets
        for every sample seen.
    """
    model.train()
    loss_total = 0.0
    all_preds, all_targets = [], []

    for batch_inputs, batch_labels in tqdm(dataloader, desc="Training"):
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_total += batch_loss.item()
        # Class index with the highest score along dim 1.
        all_preds.extend(logits.argmax(1).cpu().numpy())
        all_targets.extend(batch_labels.cpu().numpy())

    mean_loss = loss_total / len(dataloader)
    return mean_loss, all_preds, all_targets

def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    The model is switched to evaluation mode and every batch is scored
    under ``torch.no_grad()``.

    Returns:
        A tuple ``(mean_loss, predictions, true_labels)`` where ``mean_loss``
        is the sum of per-batch loss values divided by the number of batches,
        and the two lists hold the arg-max class predictions and the targets
        for every sample seen.
    """
    model.eval()
    loss_total = 0.0
    all_preds, all_targets = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(dataloader, desc="Validating"):
            batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            loss_total += batch_loss.item()
            # Class index with the highest score along dim 1.
            all_preds.extend(logits.argmax(1).cpu().numpy())
            all_targets.extend(batch_labels.cpu().numpy())

    mean_loss = loss_total / len(dataloader)
    return mean_loss, all_preds, all_targets

def plot_confusion_matrix(cm, class_names, fold):
    """Render ``cm`` as an annotated heatmap and write it to disk.

    Args:
        cm: Confusion matrix passed to ``seaborn.heatmap`` (integer counts,
            since cells are formatted with ``'d'``).
        class_names: Tick labels used on both axes.
        fold: Fold identifier used in the title and in the output file name
            ``confusion_matrix_fold_<fold>.png``.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        ax=ax,
    )
    ax.set_title(f'Confusion Matrix - Fold {fold}')
    ax.set_ylabel('True Label')
    ax.set_xlabel('Predicted Label')
    fig.tight_layout()
    fig.savefig(f'confusion_matrix_fold_{fold}.png')
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def _summarize_class_metrics(fold_results, class_to_idx):
    """Aggregate per-class precision/recall/F1 over folds into 'mean ± std' strings.

    Args:
        fold_results: List of per-fold dicts, each holding a ``'val_report'``
            produced by ``classification_report(..., output_dict=True)``.
        class_to_idx: Mapping from class name to integer label; the report is
            looked up under ``str(label)``.

    Returns:
        ``{class_name: {'precision': str, 'recall': str, 'f1-score': str}}``.
    """
    metric_names = ('precision', 'recall', 'f1-score')
    collected = {cls: {name: [] for name in metric_names} for cls in class_to_idx}

    for fold_metric in fold_results:
        report = fold_metric['val_report']
        for cls, cls_idx in class_to_idx.items():
            per_class = report[str(cls_idx)]
            for name in metric_names:
                collected[cls][name].append(per_class[name])

    return {
        cls: {
            name: f"{np.mean(values):.3f} ± {np.std(values):.3f}"
            for name, values in per_metric.items()
        }
        for cls, per_metric in collected.items()
    }


def train_model(model_class, dataset_path, num_epochs=30, batch_size=8, num_folds=5,
                learning_rate=0.001, device="cuda"):
    """
    Main training function with cross-validation

    For every fold a fresh model is built, its head is resized to the number
    of classes found in the dataset, and it is trained for ``num_epochs``
    epochs. The weights with the lowest validation loss are written to
    ``best_model_fold_<fold>.pth`` and that epoch's metrics are kept. A
    confusion-matrix image is saved per fold, and the per-class
    'mean ± std' summary across folds is written to
    ``classification_report.csv`` and printed.

    Args:
        model_class: Zero-argument callable returning a new model instance
            (e.g., SlowR50)
        dataset_path: Path to dataset directory
        num_epochs: Number of training epochs
        batch_size: Batch size for training
        num_folds: Number of folds for cross-validation
        learning_rate: Learning rate for optimizer
        device: Device to train on ('cuda' or 'cpu'); falls back to CPU when
            CUDA is unavailable

    Raises:
        ValueError: If no videos are found under ``dataset_path``.
    """
    # Setup
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    transform = create_video_transform()
    dataset = VideoDataset(dataset_path, transform=transform)

    if len(dataset) == 0:
        raise ValueError(
            f"No .mp4 videos found in class folders under '{dataset_path}'"
        )

    class_to_idx = dataset.class_to_idx
    # Names and labels in ascending label order so confusion-matrix rows,
    # report keys and tick labels always line up.
    class_names = sorted(class_to_idx, key=class_to_idx.get)
    class_labels = [class_to_idx[name] for name in class_names]
    # The head must cover the largest label value; this equals the class
    # count for contiguous labels and stays valid if indices are sparse.
    num_classes = max(class_labels) + 1

    # Cross-validation setup
    kfold = KFold(n_splits=num_folds, shuffle=True)
    sample_indices = np.arange(len(dataset))
    fold_results = []

    # Training loop for each fold
    for fold, (train_ids, val_ids) in enumerate(kfold.split(sample_indices)):
        print(f"\nTraining Fold {fold+1}/{num_folds}")

        # Create data loaders
        train_loader = DataLoader(dataset, batch_size=batch_size,
                                  sampler=SubsetRandomSampler(train_ids))
        val_loader = DataLoader(dataset, batch_size=batch_size,
                                sampler=SubsetRandomSampler(val_ids))

        # Initialize model
        model = model_class()
        model = modify_model_head(model, num_classes=num_classes)
        model = model.to(device)

        # Setup training
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3)

        best_val_loss = float('inf')
        fold_best_metrics = {}

        # Training loop
        for epoch in range(num_epochs):
            # Train
            train_loss, train_preds, train_labels = train_epoch(
                model, train_loader, criterion, optimizer, device
            )

            # Validate
            val_loss, val_preds, val_labels = validate(
                model, val_loader, criterion, device
            )

            # Print progress
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_loss:.4f}")

            # Update learning rate
            scheduler.step(val_loss)

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), f'best_model_fold_{fold}.pth')

                # Calculate metrics. Passing ``labels`` guarantees an entry
                # for every class even when a fold's validation split (or
                # the predictions) happen to miss one.
                val_report = classification_report(
                    val_labels, val_preds, labels=class_labels,
                    output_dict=True, zero_division=0
                )
                fold_best_metrics = {
                    'fold': fold,
                    'val_loss': val_loss,
                    'val_report': val_report,
                    'confusion_matrix': confusion_matrix(
                        val_labels, val_preds, labels=class_labels
                    )
                }

        # No epoch improved on +inf (e.g. NaN losses or num_epochs == 0):
        # there are no metrics to record or plot for this fold.
        if not fold_best_metrics:
            print(f"Fold {fold+1}: no valid validation result; skipping metrics")
            continue

        # Save fold results
        fold_results.append(fold_best_metrics)

        # Plot confusion matrix
        plot_confusion_matrix(
            fold_best_metrics['confusion_matrix'],
            class_names,
            fold
        )

    # Calculate and print final results
    print("\nFinal Cross-Validation Results:")
    if not fold_results:
        print("No fold produced validation metrics; nothing to report.")
        return

    final_report = _summarize_class_metrics(fold_results, class_to_idx)

    # Save final report
    report_df = pd.DataFrame(final_report).transpose()
    report_df.to_csv('classification_report.csv')
    print("\nClass-wise Performance:")
    print(report_df)

def main():
    """Train with the default configuration on the ``vrwalking`` dataset."""
    # Configuration
    config = {
        # slow_r50 takes a single clip tensor, which is what the data
        # pipeline in this file produces (create_video_transform defaults to
        # 8 frames at 256x256). slowfast_r50 instead expects a list of two
        # pathway tensors and cannot consume this input.
        'model_class': lambda: torch.hub.load(
            'facebookresearch/pytorchvideo', 'slow_r50', pretrained=True
        ),
        'dataset_path': "vrwalking",
        'num_epochs': 5,
        'batch_size': 8,
        'num_folds': 2,
        'learning_rate': 0.001,
        'device': "cuda"
    }

    # Train model
    train_model(**config)

if __name__ == "__main__":
    main()