# Dataset

In [None]:
import zipfile
from pathlib import Path

import gdown

url = "https://drive.google.com/uc?id=1rWs5F3UO9h1BDnFSFDhLJmDC3F0BvhSM"
output = "dataset.zip"  # change name if needed
extracted_root = Path("Dataset")  # top-level folder contained in the archive

# The archive is ~1.8 GB: skip the download when a previous run already
# fetched it, so "Restart & Run All" stays cheap.
if not Path(output).exists():
    if gdown.download(url, output, quiet=False) is None:
        raise RuntimeError(f"Failed to download dataset from {url}")

# Extract once and silently. Listing every extracted file produces tens of
# thousands of output lines, and re-extracting over an existing tree with
# the `unzip` CLI would stop to ask for overwrite confirmation.
if not extracted_root.is_dir():
    with zipfile.ZipFile(output) as archive:
        archive.extractall()

Downloading...
From (original): https://drive.google.com/uc?id=1rWs5F3UO9h1BDnFSFDhLJmDC3F0BvhSM
From (redirected): https://drive.google.com/uc?id=1rWs5F3UO9h1BDnFSFDhLJmDC3F0BvhSM&confirm=t&uuid=b24f266f-02d1-4122-b83e-83b794c7cc0b
To: /content/dataset.zip
100%|██████████| 1.81G/1.81G [00:28<00:00, 63.2MB/s]


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: Dataset/Validation/Real/real_5499.jpg  
  inflating: Dataset/Validation/Real/real_55.jpg  
  inflating: Dataset/Validation/Real/real_550.jpg  
  inflating: Dataset/Validation/Real/real_5500.jpg  
  inflating: Dataset/Validation/Real/real_5501.jpg  
  inflating: Dataset/Validation/Real/real_5502.jpg  
  inflating: Dataset/Validation/Real/real_5503.jpg  
  inflating: Dataset/Validation/Real/real_5504.jpg  
  inflating: Dataset/Validation/Real/real_5505.jpg  
  inflating: Dataset/Validation/Real/real_5506.jpg  
  inflating: Dataset/Validation/Real/real_5507.jpg  
  inflating: Dataset/Validation/Real/real_5508.jpg  
  inflating: Dataset/Validation/Real/real_5509.jpg  
  inflating: Dataset/Validation/Real/real_551.jpg  
  inflating: Dataset/Validation/Real/real_5510.jpg  
  inflating: Dataset/Validation/Real/real_5511.jpg  
  inflating: Dataset/Validation/Real/real_5512.jpg  
  inflating: Dataset/Validation/Real/r

In [None]:
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Path to your dataset root. The dataset cell below reads the "Train",
# "Validation" and "Test" sub-folders from it.
DATASET_ROOT = Path("Dataset")

# Settings forwarded to the module-level DataLoaders (batch_size,
# num_workers and pin_memory respectively).
BATCH_SIZE = 32
NUM_WORKERS = 0
PIN_MEMORY = False
# NOTE(review): IMAGE_SIZE is not referenced by the transforms or loaders in
# the visible cells; the images appear to be 256x256 on disk already
# (see the printed batch shape) -- confirm before relying on this constant.
IMAGE_SIZE = 256

In [None]:
# Per-channel statistics used to normalize every split (these are the
# values commonly used for ImageNet-pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline. The augmentation steps are currently switched off, so
# this performs the same tensor conversion + normalization as evaluation.
_train_steps = [
    # transforms.RandomHorizontalFlip(),
    # transforms.RandomRotation(5),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation / test pipeline: tensor conversion + normalization only.
_eval_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
eval_transform = transforms.Compose(_eval_steps)

In [None]:
def _load_split(split_dir, transform):
    """Build an ImageFolder over one sub-folder of DATASET_ROOT."""
    return datasets.ImageFolder(root=DATASET_ROOT / split_dir, transform=transform)


# One dataset per split; only the training split uses the training transform.
train_dataset = _load_split("Train", train_transform)
val_dataset = _load_split("Validation", eval_transform)
test_dataset = _load_split("Test", eval_transform)

# Label vocabulary as discovered by ImageFolder on the training split.
class_names = train_dataset.classes
class_to_idx = train_dataset.class_to_idx

print("Classes:", class_names)
print("Class to index:", class_to_idx)

# Number of samples in each split.
print("Train size:", len(train_dataset))
print("Val size:", len(val_dataset))
print("Test size:", len(test_dataset))


Classes: ['Fake', 'Real']
Class to index: {'Fake': 0, 'Real': 1}
Train size: 140002
Val size: 39428
Test size: 10905


In [None]:
# Options shared by all three loaders; only `shuffle` differs per split.
_loader_options = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Training batches are reshuffled every epoch; evaluation order is fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Sanity check: pull a single training batch and report its tensor shapes.
images, labels = next(iter(train_loader))
print("Batch image tensor shape:", images.shape)
print("Batch label tensor shape:", labels.shape)

Batch image tensor shape: torch.Size([32, 3, 256, 256])
Batch label tensor shape: torch.Size([32])


# Training

In [None]:
from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional, Type, Union

import torch
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm  # <-- added


def train_model(
    model: nn.Module,
    train_dataset: Dataset,
    eval_dataset: Dataset,
    *,
    num_epochs: int,
    base_lr: float,
    optimizer_cls: Type[Optimizer] = torch.optim.Adam,
    optimizer_kwargs: Optional[Dict[str, Any]] = None,
    optimizer_params: Optional[Iterable[Union[nn.Parameter, Dict[str, Any]]]] = None,
    lr_scheduler_cls: Optional[Type[_LRScheduler]] = None,
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
    scheduler_step_per_batch: bool = False,
    batch_size: int = 32,
    num_workers: int = 0,
    pin_memory: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    criterion: Optional[nn.Module] = None,
    use_amp: bool = True,
    grad_clip_norm: Optional[float] = None,
    non_blocking: bool = True,
    verbose: bool = True,
) -> Dict[str, List[float]]:
    """
    Train and evaluate a model on a classification task.

    Each epoch runs one pass over `train_dataset` (with parameter updates)
    followed by one pass over `eval_dataset` (no gradients), and records the
    sample-weighted mean loss and the accuracy of both passes.

    Args:
        model: The nn.Module to train. It is moved to `device` in place.
        train_dataset: Dataset used for training.
        eval_dataset: Dataset used for evaluation (validation OR test).
        num_epochs: Number of full passes over the training dataset.
        base_lr: Base learning rate. Used as optimizer `lr` unless overridden
                 via `optimizer_kwargs`.
        optimizer_cls: Optimizer class (e.g. torch.optim.Adam, SGD, AdamW).
        optimizer_kwargs: Extra kwargs passed to the optimizer constructor.
                          The dict is copied and never modified, so the same
                          dict can safely be reused across several calls.
        optimizer_params: Iterable of parameters or param groups. If None,
                          `model.parameters()` is used.
        lr_scheduler_cls: LR scheduler class (e.g. StepLR, CosineAnnealingLR).
        lr_scheduler_kwargs: Extra kwargs for the scheduler.
        scheduler_step_per_batch: If True, call `scheduler.step()` every batch.
                                  If False, call it once per epoch.
        batch_size: Batch size for both train and eval loaders.
        num_workers: DataLoader workers.
        pin_memory: DataLoader pin_memory flag.
        device: Device string or torch.device. If None, auto-selects CUDA if available.
        criterion: Loss function. Defaults to nn.CrossEntropyLoss().
        use_amp: If True and the device is CUDA, use mixed-precision training.
        grad_clip_norm: If not None, apply gradient clipping with this max norm.
        non_blocking: If True, use non_blocking transfers to device.
        verbose: If True, show progress bars and print epoch-level metrics.

    Returns:
        A dict with per-epoch metrics:
            {
                "train_loss": [...],
                "train_acc":  [...],
                "eval_loss":  [...],
                "eval_acc":   [...],
            }
        where "eval_*" corresponds to whatever you passed as `eval_dataset`
        (validation, test, etc.).
    """
    # ---- device & loss ----
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    model.to(device)

    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    # ---- dataloaders ----
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    eval_loader = DataLoader(
        eval_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    # ---- optimizer ----
    # Work on a private copy. The caller's dict may be shared between runs
    # (e.g. one entry of a hyperparameter grid); writing "lr" into it would
    # pin the first run's learning rate for every later run that reuses it.
    optimizer_kwargs = dict(optimizer_kwargs) if optimizer_kwargs is not None else {}

    # Respect explicit lr if user put it in optimizer_kwargs; otherwise use base_lr
    optimizer_kwargs.setdefault("lr", base_lr)

    if optimizer_params is None:
        optimizer_params = model.parameters()

    optimizer = optimizer_cls(optimizer_params, **optimizer_kwargs)

    # ---- scheduler (optional) ----
    scheduler: Optional[_LRScheduler] = None
    if lr_scheduler_cls is not None:
        if lr_scheduler_kwargs is None:
            lr_scheduler_kwargs = {}
        scheduler = lr_scheduler_cls(optimizer, **lr_scheduler_kwargs)

    # ---- AMP scaler ----
    # Mixed precision is only used on CUDA devices.
    use_amp = bool(use_amp and device.type == "cuda")
    # `torch.cuda.amp.GradScaler` is deprecated in recent PyTorch in favour of
    # `torch.amp.GradScaler("cuda", ...)`; keep the old spelling as a fallback
    # for releases that do not ship the new one yet.
    grad_scaler_cls = getattr(getattr(torch, "amp", None), "GradScaler", None)
    if grad_scaler_cls is not None:
        scaler = grad_scaler_cls("cuda", enabled=use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "eval_loss": [],
        "eval_acc": [],
    }

    # outer progress bar over epochs
    epoch_iter = range(1, num_epochs + 1)
    if verbose:
        epoch_iter = tqdm(epoch_iter, desc="Epochs")

    for epoch in epoch_iter:
        # ========================= TRAIN =========================
        model.train()
        train_loss_sum = 0.0
        train_correct = 0
        train_total = 0

        # inner progress bar over training batches
        train_batch_iter = train_loader
        if verbose:
            train_batch_iter = tqdm(
                train_loader,
                desc=f"Train {epoch}/{num_epochs}",
                leave=False,
            )

        for inputs, targets in train_batch_iter:
            inputs = inputs.to(device, non_blocking=non_blocking)
            targets = targets.to(device, non_blocking=non_blocking)

            optimizer.zero_grad(set_to_none=True)

            if use_amp:
                # Device-agnostic replacement for the deprecated
                # `torch.cuda.amp.autocast()`.
                with torch.autocast(device_type="cuda"):
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                scaler.scale(loss).backward()

                if grad_clip_norm is not None:
                    # Gradients must be unscaled before their norm is measured.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)

                scaler.step(optimizer)
                scaler.update()
            else:
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()

                if grad_clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm)

                optimizer.step()

            # Weight the batch loss by its size so the epoch mean stays exact
            # even when the last batch is smaller.
            batch_size_curr = inputs.size(0)
            train_loss_sum += loss.item() * batch_size_curr
            preds = outputs.argmax(dim=1)
            train_correct += (preds == targets).sum().item()
            train_total += batch_size_curr

            if scheduler is not None and scheduler_step_per_batch:
                scheduler.step()

        epoch_train_loss = train_loss_sum / train_total
        epoch_train_acc = train_correct / train_total

        # ========================= EVAL =========================
        model.eval()
        eval_loss_sum = 0.0
        eval_correct = 0
        eval_total = 0

        # inner progress bar over eval batches
        eval_batch_iter = eval_loader
        if verbose:
            eval_batch_iter = tqdm(
                eval_loader,
                desc=f"Eval {epoch}/{num_epochs}",
                leave=False,
            )

        with torch.no_grad():
            for inputs, targets in eval_batch_iter:
                inputs = inputs.to(device, non_blocking=non_blocking)
                targets = targets.to(device, non_blocking=non_blocking)

                if use_amp:
                    with torch.autocast(device_type="cuda"):
                        outputs = model(inputs)
                        loss = criterion(outputs, targets)
                else:
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                batch_size_curr = inputs.size(0)
                eval_loss_sum += loss.item() * batch_size_curr
                preds = outputs.argmax(dim=1)
                eval_correct += (preds == targets).sum().item()
                eval_total += batch_size_curr

        epoch_eval_loss = eval_loss_sum / eval_total
        epoch_eval_acc = eval_correct / eval_total

        if scheduler is not None and not scheduler_step_per_batch:
            scheduler.step()

        history["train_loss"].append(epoch_train_loss)
        history["train_acc"].append(epoch_train_acc)
        history["eval_loss"].append(epoch_eval_loss)
        history["eval_acc"].append(epoch_eval_acc)

        if verbose:
            print(
                f"Epoch {epoch:03d}/{num_epochs:03d} "
                f"- train_loss: {epoch_train_loss:.4f}, train_acc: {epoch_train_acc:.4f} "
                f"- eval_loss: {epoch_eval_loss:.4f}, eval_acc: {epoch_eval_acc:.4f}"
            )

    return history


In [None]:
from __future__ import annotations

from typing import Any, Callable, Dict, Iterable, List

from itertools import product

import torch
from torch import nn
from torch.utils.data import Dataset


def grid_combinations(
    static: Dict[str, Any],
    grid: Dict[str, Iterable[Any]],
) -> List[Dict[str, Any]]:
    """
    Expand a hyperparameter grid into a list of concrete config dicts.

    Every returned dict starts from a shallow copy of `static` and is then
    overlaid with one point of the Cartesian product of the `grid` values,
    so a key present in both takes its value from `grid`.

    - `static`: fixed parameters (same for all runs).
    - `grid`:   each key maps to an iterable of values to sweep over.

    An empty `grid` yields a single config equal to `static`.
    """
    if not grid:
        return [dict(static)]

    names = list(grid)
    # Materialize each axis once so one-shot iterables are supported too.
    axes = [list(grid[name]) for name in names]

    return [
        {**static, **dict(zip(names, point))}  # grid overrides static
        for point in product(*axes)
    ]


def run_ablation(
    build_model: Callable[..., nn.Module],
    train_model_fn: Callable[..., Dict[str, List[float]]],
    train_dataset: Dataset,
    eval_dataset: Dataset,
    *,
    model_static: Dict[str, Any],
    model_grid: Dict[str, Iterable[Any]],
    train_static: Dict[str, Any],
    train_grid: Dict[str, Iterable[Any]],
    num_epochs: int = 20,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Exhaustive ablation over model + training hyperparameter grids.

    Every model config is paired with every training config. For each pair a
    fresh model is built and trained, and the epoch with the highest eval
    accuracy (ties broken by lower eval loss) is recorded as that run's best.

    The kwargs handed to `build_model` and `train_model_fn` are deep copies,
    so a callee that mutates a nested value (for example adds a key to an
    `optimizer_kwargs` dict) cannot influence later runs or the recorded
    configuration.

    Args:
        build_model:
            Callable that instantiates the model.
            It will be called as: `build_model(**model_kwargs)`.

        train_model_fn:
            Training loop function, called as:
                train_model_fn(
                    model=model,
                    train_dataset=train_dataset,
                    eval_dataset=eval_dataset,
                    num_epochs=epochs,
                    **train_kwargs,
                )
            and must return a history dict with at least:
                history["eval_acc"], history["eval_loss"]

        train_dataset, eval_dataset:
            Passed through unchanged to `train_model_fn`
            (e.g. Train / Validation).

        model_static:
            Dict of static kwargs for `build_model`.

        model_grid:
            Dict of hyperparameter -> iterable of values for the model.
            These are combined with `model_static`.

        train_static:
            Dict of static kwargs for `train_model_fn`
            (e.g. optimizer_cls, optimizer_kwargs, base_lr, etc.).

        train_grid:
            Dict of hyperparameter -> iterable of values for the training
            loop (e.g. base_lr, batch_size, etc.).

        num_epochs:
            Default number of epochs per run.
            If "num_epochs" is present in a training config, it overrides
            this for that run.

        verbose:
            Whether to print progress and best run info.

    Returns:
        results dict:

        {
            "runs": [
                {
                    "run_id": int,
                    "model_kwargs": {...},
                    "train_kwargs": {...},      # without 'num_epochs'
                    "num_epochs": int,
                    "history": {
                        "train_loss": [...],
                        "train_acc": [...],
                        "eval_loss": [...],
                        "eval_acc": [...],
                    },
                    "best_epoch": int,          # 1-based
                    "best_eval_acc": float,
                    "best_eval_loss": float,
                },
                ...
            ],
            "best": {
                ... same structure as one entry of "runs" ...
            },
        }

    Raises:
        ValueError: If a run returns an empty "eval_acc" history (e.g. it was
            trained for zero epochs), since no best epoch can be selected.
    """
    import copy  # local import: only needed to isolate per-run kwargs

    # Build all model and training configurations
    model_cfgs = grid_combinations(model_static, model_grid)
    train_cfgs = grid_combinations(train_static, train_grid)

    total_runs = len(model_cfgs) * len(train_cfgs)
    if verbose:
        print(
            f"Ablation: {len(model_cfgs)} model configs × "
            f"{len(train_cfgs)} train configs = {total_runs} runs"
        )

    results: Dict[str, Any] = {"runs": [], "best": None}
    global_best_acc = float("-inf")
    global_best_loss = float("inf")

    run_id = 0
    for m_cfg in model_cfgs:
        for t_cfg in train_cfgs:
            run_id += 1

            # Copy so originals stay unchanged
            model_kwargs = dict(m_cfg)
            train_kwargs = dict(t_cfg)

            # Allow num_epochs to be part of the train grid
            epochs = int(train_kwargs.pop("num_epochs", num_epochs))

            if verbose:
                print(f"\n=== Run {run_id}/{total_runs} ===")
                print("Model kwargs:", model_kwargs)
                print("Train kwargs:", train_kwargs, f"(epochs={epochs})")

            # 1) Instantiate model.
            # The configs produced by the grid share their nested values
            # (e.g. the same dict object appears in several configs), so the
            # callees get deep copies they are free to modify.
            model = build_model(**copy.deepcopy(model_kwargs))

            # 2) Train + evaluate
            history = train_model_fn(
                model=model,
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                num_epochs=epochs,
                **copy.deepcopy(train_kwargs),
            )

            eval_acc_list = history["eval_acc"]
            eval_loss_list = history["eval_loss"]

            if len(eval_acc_list) == 0:
                raise ValueError(
                    f"Run {run_id} returned an empty eval history "
                    f"(epochs={epochs}); cannot select a best epoch."
                )

            # Best epoch by eval accuracy (tie-breaker: lower loss)
            best_epoch_idx = max(
                range(len(eval_acc_list)),
                key=lambda i: (eval_acc_list[i], -eval_loss_list[i]),
            )
            best_eval_acc = float(eval_acc_list[best_epoch_idx])
            best_eval_loss = float(eval_loss_list[best_epoch_idx])

            if verbose:
                print(
                    f"Best eval acc this run: {best_eval_acc:.4f} "
                    f"at epoch {best_epoch_idx + 1}"
                )

            run_result = {
                "run_id": run_id,
                "model_kwargs": model_kwargs,
                "train_kwargs": train_kwargs,
                "num_epochs": epochs,
                "history": history,
                "best_epoch": best_epoch_idx + 1,  # 1-based
                "best_eval_acc": best_eval_acc,
                "best_eval_loss": best_eval_loss,
            }
            results["runs"].append(run_result)

            # Global best by eval acc, then loss
            if (best_eval_acc > global_best_acc) or (
                best_eval_acc == global_best_acc
                and best_eval_loss < global_best_loss
            ):
                global_best_acc = best_eval_acc
                global_best_loss = best_eval_loss
                results["best"] = run_result

    # Final summary
    if verbose and results["best"] is not None:
        b = results["best"]
        print("\n=== Best overall run ===")
        print(f"Run id:        {b['run_id']}")
        print(f"Best eval acc:  {b['best_eval_acc']:.4f}")
        print(f"Best eval loss: {b['best_eval_loss']:.4f}")
        print(f"Best epoch:     {b['best_epoch']}")
        print("Best model kwargs:", b["model_kwargs"])
        print("Best train kwargs:", b["train_kwargs"])

    return results


## CNN

In [None]:
import torch
from torch import nn


class CNNBlock(nn.Module):
    """
    Residual block built from two 3x3 convolutions.

    Main path:  Conv3x3 -> BN -> ReLU -> Conv3x3 -> BN
    Output:     ReLU(main path + shortcut)

    The shortcut is the input itself when the block keeps both the channel
    count and the spatial resolution; otherwise it is a strided 1x1
    convolution followed by batch norm so the two tensors can be added.
    """
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        # Settings shared by both 3x3 convolutions (no bias: BN follows).
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        # Only the first convolution may change the resolution.
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            self.downsample = None
        else:
            projection = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
            )
            self.downsample = nn.Sequential(projection, nn.BatchNorm2d(out_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Main path.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Shortcut path: identity unless a projection was created.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)


class SimpleCNN(nn.Module):
    """
    Small ResNet-like CNN classifier (defaults target binary Real/Fake
    classification of RGB images).

    Layout, with C = base_channels:
        stem:    3x3 conv -> BN -> ReLU              in_channels -> C
        stage1:  CNNBlock x num_blocks_per_stage[0]  C   -> 2C, first block stride 2
        stage2:  CNNBlock x num_blocks_per_stage[1]  2C  -> 4C, first block stride 2
        stage3:  CNNBlock x num_blocks_per_stage[2]  4C  -> 8C, first block stride 2
        head:    global avg pool -> Dropout -> Linear -> num_classes

    The three strided stages reduce the resolution by 8x overall
    (e.g. 256 -> 128 -> 64 -> 32 for a 256x256 input).
    """
    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 2,
        base_channels: int = 32,
        num_blocks_per_stage: tuple[int, int, int] = (2, 2, 2),
        dropout: float = 0.3,
    ):
        super().__init__()

        stem_layers = [
            nn.Conv2d(
                in_channels,
                base_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
        ]
        self.stem = nn.Sequential(*stem_layers)

        # Channel width after the stem and after each of the three stages.
        widths = [base_channels * factor for factor in (1, 2, 4, 8)]

        # Every stage halves the resolution in its first block.
        self.stage1 = self._make_stage(
            in_channels=widths[0],
            out_channels=widths[1],
            num_blocks=num_blocks_per_stage[0],
            stride_first=2,
        )
        self.stage2 = self._make_stage(
            in_channels=widths[1],
            out_channels=widths[2],
            num_blocks=num_blocks_per_stage[1],
            stride_first=2,
        )
        self.stage3 = self._make_stage(
            in_channels=widths[2],
            out_channels=widths[3],
            num_blocks=num_blocks_per_stage[2],
            stride_first=2,
        )

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # A non-positive dropout rate disables the layer entirely.
        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = nn.Identity()
        self.fc = nn.Linear(widths[3], num_classes)

    def _make_stage(
        self,
        in_channels: int,
        out_channels: int,
        num_blocks: int,
        stride_first: int,
    ) -> nn.Sequential:
        """Stack blocks: the first changes width/stride, the rest preserve both."""
        # A stage always contains the first block, even if num_blocks < 1.
        head_block = CNNBlock(in_channels, out_channels, stride=stride_first)
        tail_blocks = [
            CNNBlock(out_channels, out_channels, stride=1)
            for _ in range(num_blocks - 1)
        ]
        return nn.Sequential(head_block, *tail_blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.stem(x)
        for stage in (self.stage1, self.stage2, self.stage3):
            features = stage(features)

        pooled = self.global_pool(features).flatten(1)  # [B, C, 1, 1] -> [B, C]
        return self.fc(self.dropout(pooled))            # [B, num_classes]


In [None]:
def build_model(
    arch: str = "simple_resnet_cnn",
    in_channels: int = 3,
    num_classes: int = 2,
    base_channels: int = 32,
    num_blocks_per_stage: tuple[int, int, int] = (2, 2, 2),
    dropout: float = 0.3,
    **kwargs,
) -> nn.Module:
    """
    Model factory with a keyword-only friendly signature for run_ablation.

    Only `arch="simple_resnet_cnn"` is supported; it returns a SimpleCNN
    built from the remaining named arguments. Any other `arch` raises
    ValueError. Extra keyword arguments are accepted and ignored.
    """
    if arch == "simple_resnet_cnn":
        cnn_config = dict(
            in_channels=in_channels,
            num_classes=num_classes,
            base_channels=base_channels,
            num_blocks_per_stage=num_blocks_per_stage,
            dropout=dropout,
        )
        return SimpleCNN(**cnn_config)

    raise ValueError(f"Unknown architecture: {arch}")

from torch.optim import AdamW

# Model hyperparameters: `model_static` is fixed for every run, `model_grid`
# lists the values that are swept.

model_static = {
    "arch": "simple_resnet_cnn",
    "in_channels": 3,
    "num_classes": 2,
    "num_blocks_per_stage": (2, 2, 2),
}

model_grid = {
    "base_channels": [32, 48],
    "dropout": [0.0, 0.3],
}

# Training hyperparameters, split into fixed and swept values the same way.

train_static = {
    "optimizer_cls": AdamW,
    "use_amp": True,        # T4 supports mixed precision well
    "grad_clip_norm": 1.0,  # mild gradient clipping
    "batch_size": 32,
}

train_grid = {
    "base_lr": [1e-4, 3e-4],
    # Sweep weight decay by passing different optimizer_kwargs dicts
    "optimizer_kwargs": [
        {"weight_decay": 0.0},
        {"weight_decay": 1e-4},
    ],
    # You can add "batch_size": [32, 64] here if you want to include that too
    # "batch_size": [32, 64],
}

NUM_EPOCHS = 5

from torch.utils.data import Subset

# The sweep trains on a random 10% of the training set to keep it affordable.
# The permutation is drawn from a dedicated, seeded generator so that the
# same subset is selected on every run of the notebook and results stay
# comparable between sessions.
SUBSET_SEED = 0
n_total = len(train_dataset)
n_small = n_total // 10
subset_rng = torch.Generator().manual_seed(SUBSET_SEED)
# Plain Python ints are the index type Subset is documented to take.
indices = torch.randperm(n_total, generator=subset_rng)[:n_small].tolist()
train_dataset_subset = Subset(train_dataset, indices)

results = run_ablation(
    build_model=build_model,
    train_model_fn=train_model,
    train_dataset=train_dataset_subset,
    eval_dataset=val_dataset,
    model_static=model_static,
    model_grid=model_grid,
    train_static=train_static,
    train_grid=train_grid,
    num_epochs=NUM_EPOCHS,
    verbose=True,
)

Ablation: 4 model configs × 4 train configs = 16 runs

=== Run 1/16 ===
Model kwargs: {'arch': 'simple_resnet_cnn', 'in_channels': 3, 'num_classes': 2, 'num_blocks_per_stage': (2, 2, 2), 'base_channels': 32, 'dropout': 0.0}
Train kwargs: {'optimizer_cls': <class 'torch.optim.adamw.AdamW'>, 'use_amp': True, 'grad_clip_norm': 1.0, 'batch_size': 32, 'base_lr': 0.0001, 'optimizer_kwargs': {'weight_decay': 0.0}} (epochs=5)


  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Train 1/5:   0%|          | 0/438 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Eval 1/5:   0%|          | 0/1233 [00:00<?, ?it/s]

  with torch.cuda.amp.autocast():


Epoch 001/005 - train_loss: 0.6695, train_acc: 0.6006 - eval_loss: 0.7092, eval_acc: 0.5573


Train 2/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 2/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 002/005 - train_loss: 0.5845, train_acc: 0.6994 - eval_loss: 0.6709, eval_acc: 0.6489


Train 3/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 3/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 003/005 - train_loss: 0.4439, train_acc: 0.7981 - eval_loss: 0.5252, eval_acc: 0.7535


Train 4/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 4/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 004/005 - train_loss: 0.3236, train_acc: 0.8654 - eval_loss: 1.0843, eval_acc: 0.6121


Train 5/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 5/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 005/005 - train_loss: 0.2657, train_acc: 0.8919 - eval_loss: 0.3112, eval_acc: 0.8578
Best eval acc this run: 0.8578 at epoch 5

=== Run 2/16 ===
Model kwargs: {'arch': 'simple_resnet_cnn', 'in_channels': 3, 'num_classes': 2, 'num_blocks_per_stage': (2, 2, 2), 'base_channels': 32, 'dropout': 0.0}
Train kwargs: {'optimizer_cls': <class 'torch.optim.adamw.AdamW'>, 'use_amp': True, 'grad_clip_norm': 1.0, 'batch_size': 32, 'base_lr': 0.0001, 'optimizer_kwargs': {'weight_decay': 0.0001}} (epochs=5)


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Train 1/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 1/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 001/005 - train_loss: 0.6661, train_acc: 0.5960 - eval_loss: 0.7017, eval_acc: 0.5929


Train 2/5:   0%|          | 0/438 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3adc1dc2c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3adc1dc2c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Eval 2/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 002/005 - train_loss: 0.5796, train_acc: 0.7066 - eval_loss: 0.7479, eval_acc: 0.6750


Train 3/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 3/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 003/005 - train_loss: 0.4544, train_acc: 0.7929 - eval_loss: 0.4199, eval_acc: 0.8013


Train 4/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 4/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 004/005 - train_loss: 0.3505, train_acc: 0.8480 - eval_loss: 0.6309, eval_acc: 0.7407


Train 5/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 5/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 005/005 - train_loss: 0.2680, train_acc: 0.8918 - eval_loss: 0.3337, eval_acc: 0.8430
Best eval acc this run: 0.8430 at epoch 5

=== Run 3/16 ===
Model kwargs: {'arch': 'simple_resnet_cnn', 'in_channels': 3, 'num_classes': 2, 'num_blocks_per_stage': (2, 2, 2), 'base_channels': 32, 'dropout': 0.0}
Train kwargs: {'optimizer_cls': <class 'torch.optim.adamw.AdamW'>, 'use_amp': True, 'grad_clip_norm': 1.0, 'batch_size': 32, 'base_lr': 0.0003, 'optimizer_kwargs': {'weight_decay': 0.0, 'lr': 0.0001}} (epochs=5)


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Train 1/5:   0%|          | 0/438 [02:20<?, ?it/s]

Eval 1/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 001/005 - train_loss: 0.6720, train_acc: 0.5980 - eval_loss: 0.6969, eval_acc: 0.5665


Train 2/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 2/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 002/005 - train_loss: 0.6024, train_acc: 0.6826 - eval_loss: 0.5392, eval_acc: 0.7185


Train 3/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 3/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 003/005 - train_loss: 0.4793, train_acc: 0.7771 - eval_loss: 0.6188, eval_acc: 0.6990


Train 4/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 4/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 004/005 - train_loss: 0.3554, train_acc: 0.8470 - eval_loss: 0.4099, eval_acc: 0.8092


Train 5/5:   0%|          | 0/438 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3adc1dc2c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3adc1dc2c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Eval 5/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 005/005 - train_loss: 0.2762, train_acc: 0.8896 - eval_loss: 0.8931, eval_acc: 0.6260
Best eval acc this run: 0.8092 at epoch 4

=== Run 4/16 ===
Model kwargs: {'arch': 'simple_resnet_cnn', 'in_channels': 3, 'num_classes': 2, 'num_blocks_per_stage': (2, 2, 2), 'base_channels': 32, 'dropout': 0.0}
Train kwargs: {'optimizer_cls': <class 'torch.optim.adamw.AdamW'>, 'use_amp': True, 'grad_clip_norm': 1.0, 'batch_size': 32, 'base_lr': 0.0003, 'optimizer_kwargs': {'weight_decay': 0.0001, 'lr': 0.0001}} (epochs=5)


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Train 1/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 1/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 001/005 - train_loss: 0.6622, train_acc: 0.6157 - eval_loss: 0.6647, eval_acc: 0.6111


Train 2/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 2/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 002/005 - train_loss: 0.5766, train_acc: 0.7118 - eval_loss: 0.6251, eval_acc: 0.6854


Train 3/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 3/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 003/005 - train_loss: 0.4835, train_acc: 0.7774 - eval_loss: 0.7391, eval_acc: 0.6105


Train 4/5:   0%|          | 0/438 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3adc1dc2c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f3adc1dc2c0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Eval 4/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 004/005 - train_loss: 0.3684, train_acc: 0.8399 - eval_loss: 0.5015, eval_acc: 0.7657


Train 5/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 5/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 005/005 - train_loss: 0.2870, train_acc: 0.8834 - eval_loss: 0.4451, eval_acc: 0.8170
Best eval acc this run: 0.8170 at epoch 5

=== Run 5/16 ===
Model kwargs: {'arch': 'simple_resnet_cnn', 'in_channels': 3, 'num_classes': 2, 'num_blocks_per_stage': (2, 2, 2), 'base_channels': 32, 'dropout': 0.3}
Train kwargs: {'optimizer_cls': <class 'torch.optim.adamw.AdamW'>, 'use_amp': True, 'grad_clip_norm': 1.0, 'batch_size': 32, 'base_lr': 0.0001, 'optimizer_kwargs': {'weight_decay': 0.0, 'lr': 0.0001}} (epochs=5)


Epochs:   0%|          | 0/5 [00:00<?, ?it/s]

Train 1/5:   0%|          | 0/438 [00:00<?, ?it/s]

Eval 1/5:   0%|          | 0/1233 [00:00<?, ?it/s]

Epoch 001/005 - train_loss: 0.6874, train_acc: 0.5839 - eval_loss: 0.7242, eval_acc: 0.5393


Train 2/5:   0%|          | 0/438 [00:00<?, ?it/s]

## ViT

In [None]:
import torch
from torch import nn
from torchvision.models.vision_transformer import VisionTransformer


class ViTClassifier(nn.Module):
    """
    Small Vision Transformer classifier (defaults: 256x256 inputs, 2 classes).

    Thin wrapper around ``torchvision.models.vision_transformer.VisionTransformer``.
    The number of attention heads and the MLP width are derived from
    ``hidden_dim`` and ``mlp_ratio`` instead of being passed explicitly.

    Args:
        image_size: Forwarded to ``VisionTransformer`` as ``image_size``.
        patch_size: Forwarded as ``patch_size``; must divide ``image_size``.
        num_layers: Forwarded as ``num_layers``.
        hidden_dim: Forwarded as ``hidden_dim``; also drives the head count.
        mlp_ratio: ``mlp_dim`` is computed as ``int(hidden_dim * mlp_ratio)``.
        dropout: Forwarded as ``dropout``.
        attention_dropout: Forwarded as ``attention_dropout``.
        num_classes: Forwarded as ``num_classes``.

    Raises:
        AssertionError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(
        self,
        image_size: int = 256,
        patch_size: int = 16,
        num_layers: int = 6,
        hidden_dim: int = 256,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        num_classes: int = 2,
    ):
        super().__init__()

        assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
        self.image_size = image_size
        self.patch_size = patch_size

        # Aim for ~64 dims per head, but never hand the encoder a head count
        # that does not divide hidden_dim (multi-head attention rejects that).
        num_heads = self._resolve_num_heads(hidden_dim)
        mlp_dim = int(hidden_dim * mlp_ratio)

        self.vit = VisionTransformer(
            image_size=image_size,
            patch_size=patch_size,
            num_layers=num_layers,
            num_heads=num_heads,
            hidden_dim=hidden_dim,
            mlp_dim=mlp_dim,
            dropout=dropout,
            attention_dropout=attention_dropout,
            num_classes=num_classes,
            representation_size=None,
        )

    @staticmethod
    def _resolve_num_heads(hidden_dim: int, target_head_dim: int = 64) -> int:
        """
        Pick the attention head count for a given embedding width.

        Starts from ``max(1, hidden_dim // target_head_dim)`` and walks down
        until the count divides ``hidden_dim`` evenly. The loop always stops,
        at the latest when it reaches 1.
        """
        num_heads = max(1, hidden_dim // target_head_dim)
        while hidden_dim % num_heads != 0:
            num_heads -= 1
        return num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped ``VisionTransformer`` and return its output."""
        return self.vit(x)


In [None]:
from torch.optim import AdamW

def build_vit_model(
    arch: str = "small_vit",
    image_size: int = 256,
    patch_size: int = 16,
    num_layers: int = 6,
    hidden_dim: int = 256,
    mlp_ratio: float = 4.0,
    dropout: float = 0.1,
    attention_dropout: float = 0.1,
    num_classes: int = 2,
    **kwargs,
) -> nn.Module:
    """
    Build a ``ViTClassifier`` from flat keyword hyperparameters.

    Intended to be passed as ``build_model`` to ``run_ablation``.

    Args:
        arch: Architecture tag; only ``"small_vit"`` is accepted. Kept for
            consistency with the CNN setup.
        image_size, patch_size, num_layers, hidden_dim, mlp_ratio, dropout,
        attention_dropout, num_classes: Forwarded unchanged to ``ViTClassifier``.
        **kwargs: Accepted and ignored.

    Returns:
        A newly constructed ``ViTClassifier``.

    Raises:
        ValueError: If ``arch`` is not ``"small_vit"``.
    """
    if arch != "small_vit":
        raise ValueError(f"Unknown architecture: {arch}")

    # The class defined in this notebook is ViTClassifier.
    return ViTClassifier(
        image_size=image_size,
        patch_size=patch_size,
        num_layers=num_layers,
        hidden_dim=hidden_dim,
        mlp_ratio=mlp_ratio,
        dropout=dropout,
        attention_dropout=attention_dropout,
        num_classes=num_classes,
    )


# ------------------------------------------------------------
# ViT model hyperparameters
#   vit_model_static -> passed to run_ablation as model_static
#   vit_model_grid   -> passed to run_ablation as model_grid
# ------------------------------------------------------------

vit_model_static = dict(
    arch="small_vit",
    image_size=256,
    patch_size=16,
    num_classes=2,
    mlp_ratio=4.0,
    dropout=0.1,
    attention_dropout=0.1,
)

vit_model_grid = dict(
    hidden_dim=[192, 256],  # embedding width of the model
    num_layers=[6, 8],      # number of transformer encoder layers
    # "patch_size": [16, 32] could be added to explore fewer tokens
)


# ------------------------------------------------------------
# ViT training hyperparameters
#   vit_train_static -> passed to run_ablation as train_static
#   vit_train_grid   -> passed to run_ablation as train_grid
# ------------------------------------------------------------

vit_train_static = dict(
    optimizer_cls=AdamW,
    use_amp=True,
    grad_clip_norm=1.0,
    batch_size=32,
)

vit_train_grid = dict(
    base_lr=[1e-4, 3e-4],
    optimizer_kwargs=[
        dict(weight_decay=0.0),
        dict(weight_decay=1e-4),
    ],
)


from torch.utils.data import Subset

NUM_EPOCHS = 5

# Seed for drawing the training subset. A dedicated generator keeps the sample
# identical across kernel restarts and independent of the global RNG state,
# so every configuration in the sweep is compared on the same data.
SUBSET_SEED = 42
subset_generator = torch.Generator().manual_seed(SUBSET_SEED)

# Run the sweep on a random 10% sample of the training set.
n_total = len(train_dataset)
n_small = max(1, n_total // 10)  # avoid a zero-length subset on tiny datasets
# Subset expects a sequence of ints, so hand it a plain list rather than a tensor.
indices = torch.randperm(n_total, generator=subset_generator)[:n_small].tolist()
train_dataset_subset = Subset(train_dataset, indices)

# Grid over the ViT model/training settings; evaluation uses the full val_dataset.
vit_results = run_ablation(
    build_model=build_vit_model,
    train_model_fn=train_model,
    train_dataset=train_dataset_subset,
    eval_dataset=val_dataset,
    model_static=vit_model_static,
    model_grid=vit_model_grid,
    train_static=vit_train_static,
    train_grid=vit_train_grid,
    num_epochs=NUM_EPOCHS,
    verbose=True,
)
