<a href="https://colab.research.google.com/github/Mechanics-Mechatronics-and-Robotics/CV-2025/blob/main/Week_11/MovingMNIST_comparison_of_models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q pytorch-lightning clearml

In [None]:
import os
import random
import re
from getpass import getpass

import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from clearml import Task
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities import rank_zero_info
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [None]:
# Step 2 of the logging instruction: point the ClearML SDK at the hosted server
# and hand it API credentials through environment variables.
os.environ["CLEARML_WEB_HOST"] = "https://app.clear.ml/"
os.environ["CLEARML_API_HOST"] = "https://api.clear.ml"
os.environ["CLEARML_FILES_HOST"] = "https://files.clear.ml"

# SECURITY: API credentials must never be written into the notebook itself.
# Any key pair that was ever committed in this cell has to be treated as leaked
# and revoked in the ClearML web UI. Credentials that are already exported in
# the environment are reused; otherwise they are requested interactively, which
# keeps them out of the saved file and out of the cell output.
for _clearml_var in ("CLEARML_API_ACCESS_KEY", "CLEARML_API_SECRET_KEY"):
    if not os.environ.get(_clearml_var):
        os.environ[_clearml_var] = getpass(f"{_clearml_var}: ")

In [None]:


# Experiment configuration shared by the datasets, the models and the trainer.
# The dictionary is filled section by section; key order matches the sections.
config = {}

# --- Data ---
config.update({
    "img_size": 64,
    "total_frames": 20,  # number of frames in every generated video
    "sample_frames": {"2D": 1, "3D": 5, "2D+1D": 5},  # frames sampled per model type
    "batch_size": 32,
})

# --- Architecture ---
config.update({
    "conv2d_channels": [1, 32, 64],
    "conv3d_channels": [1, 32, 64],
    "kernel_size": (3, 3, 3),
    "pool_size": 2,
    "linear_hidden": 128,
    "lstm_hidden": 128,
})

# --- Training ---
config.update({
    "max_epochs": 2,
    "learning_rate": 1e-3,
    "early_stop_patience": 5,
})

# --- System ---
config.update({
    "accelerator": "auto",
    "log_dir": "./logs",
    "checkpoint_dir": "./checkpoints",
})


def collate_fn(batch):
    """Collate ``(video, target)`` samples into a pair of batched tensors.

    Args:
        batch: Sequence of ``(video, target)`` pairs. Every ``video`` must be a
            tensor of the same shape (e.g. ``[C, H, W]`` single frames or
            ``[T, C, H, W]`` clips). ``target`` may be a tensor or a plain
            Python number.

    Returns:
        Tuple ``(videos, targets)``: the videos stacked along a new leading
        batch dimension, and the targets as one tensor with a leading batch
        dimension.

    Raises:
        ValueError: If the videos are not tensors.
    """
    videos, targets = zip(*batch)

    if not isinstance(videos[0], torch.Tensor):
        raise ValueError("Unexpected input dimensions")

    # torch.stack only accepts tensors. Targets can also be plain numbers (for
    # example a sum of two integer labels), so those are converted instead.
    if isinstance(targets[0], torch.Tensor):
        batched_targets = torch.stack(targets)
    else:
        batched_targets = torch.as_tensor(targets)

    return torch.stack(videos), batched_targets

# Enhanced MovingMNIST Dataset with frame sampling
class MovingMNIST(Dataset):
    """Synthetic video dataset of two MNIST digits bouncing inside a square canvas.

    Sample ``idx`` combines MNIST images ``2 * idx`` and ``2 * idx + 1``. Each
    digit gets a random start position and a random per-axis velocity between
    -3 and 3 pixels per frame, and is reflected at the canvas border. The target
    is the sum of the two digit labels.

    Positions and velocities are drawn from NumPy's global RNG on every access,
    so the same index yields a different video each time unless the RNG is
    re-seeded.

    Args:
        train: Passed to ``MNIST(train=...)`` to select the split.
        img_size: Side length of the square canvas in pixels.
        total_frames: Number of frames generated per video.

    Raises:
        ValueError: If ``img_size`` is smaller than a digit or ``total_frames``
            is not positive.
    """

    DIGIT_SIZE = 28  # side length of the digit patch pasted into each frame

    def __init__(self, train=True, img_size=64, total_frames=20):
        # Fail early with a clear message instead of deep inside __getitem__.
        if img_size < self.DIGIT_SIZE:
            raise ValueError(
                f"img_size must be at least {self.DIGIT_SIZE}, got {img_size}"
            )
        if total_frames < 1:
            raise ValueError(f"total_frames must be positive, got {total_frames}")

        self.mnist = MNIST(root='./data', train=train, download=True, transform=ToTensor())
        self.img_size = img_size
        self.total_frames = total_frames

    def __len__(self):
        """Number of videos: every video consumes two MNIST images."""
        return len(self.mnist) // 2

    @staticmethod
    def _bounce(pos, vel, limit):
        """Advance a position by one step and reflect it off the walls.

        Args:
            pos: Integer array with the current (x, y) top-left corner.
            vel: Integer array with the per-axis velocity.
            limit: Largest valid coordinate; the walls are at 0 and ``limit``.

        Returns:
            Tuple ``(new_pos, new_vel)`` of new arrays; the inputs are not
            modified. The velocity is negated on every axis that hit a wall.
        """
        moved = pos + vel
        below = moved < 0
        above = moved > limit

        # Mirror the overshoot back into the canvas on both sides.
        reflected = np.where(below, -moved, moved)
        reflected = np.where(above, 2 * limit - moved, reflected)

        new_vel = np.where(below | above, -vel, vel)
        # On a canvas narrower than the step size a reflection can overshoot the
        # opposite wall; clipping keeps the paste window inside the frame.
        return np.clip(reflected, 0, limit), new_vel

    def __getitem__(self, idx):
        """Generate one video.

        Returns:
            Tuple ``(video, target)`` where ``video`` is a float tensor of shape
            ``[total_frames, 1, img_size, img_size]`` with values clamped to
            ``[0, 1]`` and ``target`` is ``label1 + label2``.
        """
        digit1, label1 = self.mnist[2 * idx]
        digit2, label2 = self.mnist[2 * idx + 1]

        size = self.DIGIT_SIZE
        limit = self.img_size - size

        pos1 = np.random.randint(0, limit + 1, size=2)
        pos2 = np.random.randint(0, limit + 1, size=2)
        vel1 = np.random.randint(-3, 4, size=2)
        vel2 = np.random.randint(-3, 4, size=2)

        frames = []
        for _ in range(self.total_frames):
            frame = torch.zeros((1, self.img_size, self.img_size))

            pos1, vel1 = self._bounce(pos1, vel1, limit)
            pos2, vel2 = self._bounce(pos2, vel2, limit)

            # Index 0 of a position is the column (x), index 1 the row (y).
            for digit, pos in ((digit1, pos1), (digit2, pos2)):
                x, y = int(pos[0]), int(pos[1])
                frame[:, y:y + size, x:x + size] += digit.squeeze()

            # Overlapping digits can sum above 1, so clamp back to [0, 1].
            frames.append(frame.clamp(0, 1))

        video = torch.stack(frames)
        target = label1 + label2
        return video, target

# Frame sampling data loader
class FrameSampler:
    """Map-style wrapper that cuts each video of a dataset down to a model's input.

    The wrapped dataset must return ``(video, target)`` pairs whose ``video``
    supports ``len()`` and indexing/slicing along its first axis.

    Args:
        dataset: Source dataset of ``(video, target)`` pairs.
        model_type: ``"2D"`` selects a single random frame; any other value
            selects a random contiguous clip.
        sample_frames: Clip length used for every model type other than ``"2D"``.
    """

    def __init__(self, dataset, model_type, sample_frames):
        self.dataset, self.model_type, self.sample_frames = (
            dataset, model_type, sample_frames
        )

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        video, target = self.dataset[idx]
        n_frames = len(video)

        if self.model_type != "2D":
            # 3D / hybrid models: random window, temporal axis kept ([T,C,H,W]).
            # A video shorter than the window is returned from its first frame.
            last_start = n_frames - self.sample_frames
            if last_start < 0:
                last_start = 0
            first = random.randint(0, last_start)
            clip = video[first:first + self.sample_frames]
            return clip, target

        # 2D model: one random frame, temporal axis dropped ([C,H,W]).
        pick = random.randint(0, n_frames - 1)
        return video[pick], target

# Model Architectures
class CNN2D(pl.LightningModule):
    """Single-frame classifier: two conv/ReLU/max-pool stages followed by an MLP head.

    The head outputs 19 logits. Training uses cross-entropy loss and Adam with
    the learning rate from ``config["learning_rate"]``.

    Args:
        config: Dictionary providing ``conv2d_channels``, ``pool_size``,
            ``img_size``, ``linear_hidden`` and ``learning_rate``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()

        channels = config["conv2d_channels"]
        hidden_ch, out_ch = channels[1], channels[2]
        pool = config["pool_size"]

        backbone_layers = [
            nn.Conv2d(1, hidden_ch, 3),
            nn.ReLU(),
            nn.MaxPool2d(pool),
            nn.Conv2d(hidden_ch, out_ch, 3),
            nn.ReLU(),
            nn.MaxPool2d(pool),
            nn.Flatten(),
        ]
        self.backbone = nn.Sequential(*backbone_layers)

        # Probe the backbone once to learn the flattened feature size.
        side = config["img_size"]
        with torch.no_grad():
            probe = torch.rand(1, 1, side, side)
            features = self.backbone(probe).shape[1]

        hidden = config["linear_hidden"]
        self.classifier = nn.Sequential(
            nn.Linear(features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 19),
        )

        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=19)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=19)

    def forward(self, x):
        """Return class logits for ``x`` of shape (B,1,H,W).

        A 5-D input has its axis 1 squeezed first, which only removes that
        axis when its length is 1.
        """
        frames = x.squeeze(1) if x.dim() == 5 else x
        embedded = self.backbone(frames)
        return self.classifier(embedded)

    def _shared_step(self, batch, accuracy, loss_name, acc_name):
        """Compute the loss, update ``accuracy`` and log both under the given names."""
        inputs, labels = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, labels)
        accuracy(logits, labels)
        self.log_dict({loss_name: loss, acc_name: accuracy}, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_acc, "train_loss", "train_acc")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_acc, "val_loss", "val_acc")

    def configure_optimizers(self):
        lr = self.config["learning_rate"]
        return Adam(self.parameters(), lr=lr)

class CNN3D(pl.LightningModule):
    """Clip classifier built from factorised 3-D convolutions.

    Spatial (1x3x3) and temporal (3x1x1) convolutions are applied separately,
    pooling only reduces height and width, and the final adaptive pooling
    collapses the spatial axes to 1x1 while keeping the temporal axis. The head
    outputs 19 logits.

    The first linear layer is sized from a probe clip with
    ``config["sample_frames"]["3D"]`` frames.

    Args:
        config: Dictionary providing ``conv3d_channels``, ``sample_frames``,
            ``img_size``, ``linear_hidden`` and ``learning_rate``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()

        channels = config["conv3d_channels"]
        hidden_ch, out_ch = channels[1], channels[2]

        spatial_kernel, spatial_pad = (1, 3, 3), (0, 1, 1)    # H,W only, size preserved
        temporal_kernel, temporal_pad = (3, 1, 1), (1, 0, 0)  # T only, size preserved
        spatial_pool = (1, 2, 2)                              # halves H,W, keeps T

        conv_layers = [
            # Block 1: spatial features, then temporal mixing
            nn.Conv3d(1, hidden_ch, kernel_size=spatial_kernel, padding=spatial_pad),
            nn.ReLU(),
            nn.Conv3d(hidden_ch, hidden_ch, kernel_size=temporal_kernel, padding=temporal_pad),
            nn.ReLU(),
            nn.MaxPool3d(spatial_pool),
            # Block 2: spatial only
            nn.Conv3d(hidden_ch, out_ch, kernel_size=spatial_kernel, padding=spatial_pad),
            nn.ReLU(),
            nn.MaxPool3d(spatial_pool),
            # Collapse H,W to 1x1; None leaves the temporal axis untouched
            nn.AdaptiveAvgPool3d((None, 1, 1)),
        ]
        self.conv3d = nn.Sequential(*conv_layers)

        # Probe with a single clip (batch of one) to get the flattened feature count.
        n_frames = config["sample_frames"]["3D"]
        side = config["img_size"]
        with torch.no_grad():
            probe = torch.rand(1, 1, n_frames, side, side)
            features = self.conv3d(probe).numel()

        hidden = config["linear_hidden"]
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 19),
        )

        # Metrics
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=19)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=19)

    def forward(self, x):
        """Return class logits for a clip batch ``x`` of shape (B,T,C,H,W)."""
        channels_first = x.permute(0, 2, 1, 3, 4)  # (B,C,T,H,W) as Conv3d expects
        volume = self.conv3d(channels_first)
        return self.classifier(volume)

    def _shared_step(self, batch, accuracy, loss_name, acc_name):
        """Compute the loss, update ``accuracy`` and log both under the given names."""
        inputs, labels = batch
        logits = self(inputs)
        loss = F.cross_entropy(logits, labels)
        accuracy(logits, labels)
        self.log_dict({loss_name: loss, acc_name: accuracy}, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_acc, "train_loss", "train_acc")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_acc, "val_loss", "val_acc")

    def configure_optimizers(self):
        lr = self.config["learning_rate"]
        return Adam(self.parameters(), lr=lr)

class HybridCNN(pl.LightningModule):
    """2D+1D clip classifier: a per-frame CNN followed by an LSTM over time.

    Every frame is encoded independently by the spatial CNN, the resulting
    feature sequence is fed to a batch-first LSTM, and the LSTM output of the
    last timestep is mapped to 19 logits.

    Args:
        config: Dictionary providing ``conv2d_channels``, ``pool_size``,
            ``img_size``, ``lstm_hidden`` and ``learning_rate``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters()

        # Spatial processing
        self.spatial = nn.Sequential(
            nn.Conv2d(1, config["conv2d_channels"][1], 3),
            nn.ReLU(),
            nn.MaxPool2d(config["pool_size"]),
            nn.Conv2d(config["conv2d_channels"][1], config["conv2d_channels"][2], 3),
            nn.ReLU(),
            nn.MaxPool2d(config["pool_size"]),
            nn.Flatten()
        )

        # Probe the spatial encoder once to size the LSTM input.
        with torch.no_grad():
            dummy = torch.rand(1, 1, config["img_size"], config["img_size"])
            spatial_features = self.spatial(dummy).shape[1]

        # Temporal processing
        self.lstm = nn.LSTM(
            input_size=spatial_features,
            hidden_size=config["lstm_hidden"],
            batch_first=True
        )

        self.classifier = nn.Linear(config["lstm_hidden"], 19)

        # Metrics
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=19)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=19)

    def forward(self, x):
        """Return class logits for a clip batch ``x`` of shape (B,T,C,H,W)."""
        batch_size, timesteps = x.shape[:2]

        # reshape (unlike view) also accepts non-contiguous input, e.g. a clip
        # batch that was permuted or sliced before being passed in.
        frames = x.reshape(batch_size * timesteps, *x.shape[2:])  # (B*T,C,H,W)
        spatial = self.spatial(frames)  # (B*T, D)
        spatial = spatial.reshape(batch_size, timesteps, -1)  # (B,T,D)

        lstm_out, _ = self.lstm(spatial)  # (B,T,H)
        return self.classifier(lstm_out[:, -1, :])  # Last timestep

    def _shared_step(self, batch, accuracy, loss_name, acc_name):
        """Compute the loss, update ``accuracy`` and log both under the given names."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        accuracy(y_hat, y)
        self.log_dict({loss_name: loss, acc_name: accuracy}, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_acc, "train_loss", "train_acc")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.val_acc, "val_loss", "val_acc")

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.config["learning_rate"])

# Comparison Function
def compare_models():
    """Train, evaluate and compare the 2D, 3D and 2D+1D models on MovingMNIST.

    For every architecture the function builds frame-sampled train/validation
    loaders, trains with checkpointing and early stopping on ``val_loss``,
    reloads the best checkpoint, validates it, and plots a few predictions.
    A summary table is printed at the end. The run is tracked as a ClearML
    task, which is closed even when training raises.

    Returns:
        dict: Maps each model name to a dict with ``"val_loss"``, ``"val_acc"``
        and ``"params"`` (parameter count). Architectures for which no
        checkpoint was saved are omitted.
    """
    task = Task.init(project_name="MovingMNIST-Comparison",
                    task_name="Architecture Comparison")
    results = {}

    try:
        task.connect(config)

        architectures = {
            "2D": CNN2D,
            "3D": CNN3D,
            "2D+1D": HybridCNN
        }

        # Create base datasets; they are shared by all architectures
        train_dataset = MovingMNIST(train=True,
                                  img_size=config["img_size"],
                                  total_frames=config["total_frames"])
        val_dataset = MovingMNIST(train=False,
                                img_size=config["img_size"],
                                total_frames=config["total_frames"])

        for model_name, model_class in architectures.items():
            print(f"\n=== Training {model_name} Model ===")

            # Create frame-sampled datasets
            train_sampler = FrameSampler(train_dataset, model_name,
                                       config["sample_frames"][model_name])
            val_sampler = FrameSampler(val_dataset, model_name,
                                     config["sample_frames"][model_name])

            train_loader = DataLoader(train_sampler,
                                     batch_size=config["batch_size"],
                                     shuffle=True)
            val_loader = DataLoader(val_sampler,
                                   batch_size=config["batch_size"])

            # Initialize model
            model = model_class(config)

            # Callbacks
            checkpoint = ModelCheckpoint(
                dirpath=config["checkpoint_dir"],
                filename=f"{model_name}-best",
                monitor="val_loss",
                mode="min"
            )

            trainer = pl.Trainer(
                max_epochs=config["max_epochs"],
                accelerator=config["accelerator"],
                logger=CSVLogger(save_dir=f"{config['log_dir']}/{model_name}"),
                callbacks=[checkpoint, EarlyStopping(monitor="val_loss",
                                                   patience=config["early_stop_patience"])]
            )

            # Train
            trainer.fit(model, train_loader, val_loader)

            # Evaluate the best checkpoint rather than the last training state
            best_path = checkpoint.best_model_path
            if not best_path:
                # Make the omission visible instead of silently dropping the model
                print(f"No checkpoint was saved for {model_name}; skipping evaluation.")
                continue

            model = model_class.load_from_checkpoint(best_path)
            val_results = trainer.validate(model, val_loader)[0]
            results[model_name] = {
                "val_loss": val_results["val_loss"],
                "val_acc": val_results["val_acc"],
                "params": sum(p.numel() for p in model.parameters())
            }

            # Visualize predictions
            sample_batch = next(iter(val_loader))
            visualize_predictions(sample_batch, model, model_name)

        # Print comparison table
        print("\n=== Final Comparison ===")
        print(f"{'Model':<10} {'Val Loss':<10} {'Val Acc':<10} {'Params':<10}")
        for name, res in results.items():
            print(f"{name:<10} {res['val_loss']:.4f}     {res['val_acc']:.4f}     {res['params']:,}")
    finally:
        # Always close the ClearML task so a failed run is not left open
        task.close()

    return results

# Visualization function
def visualize_predictions(batch, model, model_name):
    """Plot up to three samples of a batch with their true and predicted labels.

    Args:
        batch: Tuple ``(x, y)`` of inputs and integer targets. ``x`` is either
            5-D (clips, the middle frame is shown) or 4-D (single frames);
            channel 0 is displayed in both cases.
        model: Classifier returning per-class scores of shape ``(B, classes)``.
            It is switched to eval mode and left in that mode.
        model_name: Label used in the figure title.
    """
    x, y = batch
    model.eval()

    # Show at most three samples; smaller batches must not raise IndexError.
    n_show = min(3, len(x))
    if n_show == 0:
        return

    # Run inference on the device the model lives on; the batch itself stays
    # where it is so it can be plotted directly.
    device = next((p.device for p in model.parameters()), x.device)
    with torch.no_grad():
        preds = model(x.to(device)).argmax(dim=1).cpu()

    plt.figure(figsize=(10, 4))
    plt.suptitle(f"{model_name} Predictions", fontsize=14)

    for i in range(n_show):
        plt.subplot(1, n_show, i+1)

        if len(x.shape) == 5:  # 3D or Hybrid - show middle frame
            frame_idx = x.shape[1] // 2
            plt.imshow(x[i, frame_idx, 0].cpu(), cmap='gray')
        else:  # 2D - show single frame
            plt.imshow(x[i, 0].cpu(), cmap='gray')

        plt.title(f"True: {y[i].item()}\nPred: {preds[i].item()}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Entry point: run the full comparison only when this code is executed directly
# (not when it is imported). The returned results dictionary is discarded here.
if __name__ == "__main__":
    compare_models()