# Run Pipeline - Simple ResNet

In [1]:
"""
A bite-size ResNet demo with verbose prints to illustrate how residual
connections work.
"""
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import random
import numpy as np

# ---------- 1. Reproducibility ------------------------------------------------
# Seed every random-number source this notebook draws from (PyTorch, NumPy,
# Python's `random`) with one shared value, in that order.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(42)
del _seed_fn

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))

In [2]:
# ---------- 2. Synthetic Dataset ---------------------------------------------
class SyntheticMeanThreshold(Dataset):
    """
    Toy binary-classification dataset of uniform-noise images.

    Each sample is a 1×32×32 image with pixels drawn from [0, 1).
    Label is 1 if the mean pixel intensity > 0.5, else 0.

    All images are drawn once, in ``__init__``, from torch's global RNG.
    Reading ``ds[i]`` therefore returns the same (image, label) pair every
    time, and a validation set stays the same set across epochs.

    Args:
        n_samples: number of samples in the dataset.
    """
    def __init__(self, n_samples: int):
        self.n_samples = n_samples
        # Pre-generate everything: about 4 KiB per sample (1×32×32 float32),
        # which is negligible for demo-sized datasets.
        self.images = torch.rand(n_samples, 1, 32, 32)      # values in [0,1)
        # One label per image: 1 if its mean pixel value exceeds 0.5.
        self.labels = (self.images.mean(dim=(1, 2, 3)) > 0.5).long()

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        # Returns (1×32×32 float image, 0-d long label), as before.
        return self.images[idx], self.labels[idx]

In [3]:
# ---------- 3. Residual Block -------------------------------------------------
class BasicBlock(nn.Module):
    """
    A very small residual block:
    x -> Conv3(→out) -> BN -> ReLU -> Conv3 -> BN
         |                                   |
         +------------ (identity / 1x1) -----+
         -> ReLU

    Args:
        in_ch:   number of input channels.
        out_ch:  number of output channels.
        stride:  stride of the first 3×3 conv and of the 1×1 projection;
                 ``stride=2`` halves the spatial size.
        verbose: when True (the default, matching the original demo) every
                 forward pass prints the tensor shape after each stage.
                 Pass False, or set ``block.verbose = False`` later, to keep
                 training logs readable.
    """
    def __init__(self, in_ch: int, out_ch: int, stride: int = 1,
                 verbose: bool = True):
        super().__init__()
        self.verbose = verbose

        # Main path
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2   = nn.BatchNorm2d(out_ch)

        # Shortcut path: a 1×1 projection is needed whenever the residual
        # add would otherwise see mismatched shapes (new stride or width).
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1,
                                      stride=stride, bias=False)
        else:
            self.shortcut = nn.Identity()

    def _log(self, label: str, t: torch.Tensor, end: str = "\n") -> None:
        """Print ``label`` (padded to 12 chars) and ``t.shape`` if verbose."""
        if self.verbose:
            print(f"{label:<12}: {t.shape}", end=end)

    def forward(self, x):
        self._log("Input", x)

        # ----- main path -----
        out = self.conv1(x)
        self._log("Conv1", out)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        self._log("Conv2", out)
        out = self.bn2(out)

        # ----- shortcut ------
        sc = self.shortcut(x)
        self._log("Shortcut", sc)

        out += sc
        self._log("After add", out)

        out = F.relu(out)
        # Extra blank line separates consecutive blocks in the trace.
        self._log("Block output", out, end="\n\n")
        return out


# ---------- 4. Tiny ResNet ----------------------------------------------------
class SimpleResNet(nn.Module):
    """
    Tiny three-stage ResNet built from BasicBlocks, sized for 32×32 inputs.

    Shapes for a 32×32 input (channels × height × width):
        stem    Conv3×3 (in_channels→16) + ReLU      16×32×32
        layer1  BasicBlock(16→16, stride=1)          16×32×32
        layer2  BasicBlock(16→32, stride=2)          32×16×16
        layer3  BasicBlock(32→64, stride=2)          64× 8× 8
        head    global average pool + Linear(64→n)   n_classes logits

    The first stage keeps 16 channels. The second and third stages each
    double the channel count and halve the spatial size.

    Args:
        n_classes:   number of output logits.
        in_channels: channels in the input image (1 for the grayscale demo
                     data, e.g. 3 for RGB).
    """
    def __init__(self, n_classes: int = 2, in_channels: int = 1):
        super().__init__()
        # Stem: lift the input to 16 feature maps at full resolution.
        # NOTE(review): bias=False is conventional only when a BatchNorm
        # follows; here a ReLU follows directly.
        self.stem = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1,
                              padding=1, bias=False)

        # Residual stages
        self.layer1 = BasicBlock(16, 16, stride=1)   # 32×32
        self.layer2 = BasicBlock(16, 32, stride=2)   # 16×16
        self.layer3 = BasicBlock(32, 64, stride=2)   # 8×8

        # Head
        self.pool = nn.AdaptiveAvgPool2d(1)          # 64×1×1
        self.fc   = nn.Linear(64, n_classes)

    def forward(self, x):
        """
        Map an image batch to class logits.

        ``x`` is passed to a Conv2d, so it is (batch, in_channels, H, W).
        Returns a (batch, n_classes) tensor of logits.
        """
        x = self.stem(x)
        x = F.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.pool(x)
        x = torch.flatten(x, 1)      # (batch, 64, 1, 1) -> (batch, 64)
        x = self.fc(x)
        return x

In [4]:
# ---------- 5. Utilities ------------------------------------------------------
def accuracy(pred_logits, labels):
    """
    Return the fraction of rows whose highest-scoring class equals the label.

    ``pred_logits`` holds one row of class scores per sample (argmax over
    dim 1). ``labels`` holds the matching integer class ids. The result is a
    Python float; an empty batch yields NaN.
    """
    predicted_class = torch.argmax(pred_logits, dim=1)
    hits = torch.eq(predicted_class, labels).to(torch.float32)
    return hits.mean().item()

# ---------- 6. Training Loop --------------------------------------------------
def train(model, loader, criterion, optim, epoch):
    """
    Run one training epoch over ``loader`` and print the epoch averages.

    Args:
        model:     network to optimise; switched to train mode here.
        loader:    iterable of ``(imgs, labels)`` batches.
        criterion: loss called as ``criterion(logits, labels)``. It is
                   assumed to return the batch *mean*, because it is
                   re-weighted by batch size below.
        optim:     optimiser that steps ``model``'s parameters.
        epoch:     epoch number, used only in the printed log line.

    Returns:
        ``(mean_loss, mean_acc)`` over the samples actually processed this
        epoch, or ``(nan, nan)`` if the loader yielded no samples.
    """
    model.train()
    running_loss, running_acc = 0.0, 0.0
    n_seen = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optim.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, labels)
        loss.backward()
        optim.step()

        batch_size = imgs.size(0)
        running_loss += loss.item() * batch_size
        running_acc  += accuracy(logits.detach(), labels) * batch_size
        n_seen += batch_size

    # Normalise by the samples actually processed, not len(loader.dataset).
    # The two differ with drop_last=True or a subset sampler, and an empty
    # loader would otherwise raise ZeroDivisionError.
    if n_seen:
        mean_loss, mean_acc = running_loss / n_seen, running_acc / n_seen
    else:
        mean_loss = mean_acc = float("nan")
    print(f"[Epoch {epoch:02}] "
          f"loss={mean_loss:.4f}  acc={mean_acc:.3f}")
    return mean_loss, mean_acc

In [5]:
# ---------- 7. Main -----------------------------------------------------------
def _evaluate(model, loader, criterion):
    """Return ``(mean_loss, mean_acc)`` of ``model`` on ``loader`` without training it."""
    model.eval()
    total_loss, total_acc, n_seen = 0.0, 0.0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            logits = model(imgs)
            batch_size = imgs.size(0)
            total_loss += criterion(logits, labels).item() * batch_size
            total_acc  += accuracy(logits, labels) * batch_size
            n_seen += batch_size
    if n_seen == 0:
        # Nothing to average over; report NaN rather than divide by zero.
        return float("nan"), float("nan")
    return total_loss / n_seen, total_acc / n_seen


def main(epochs: int = 5):
    """
    Train SimpleResNet on the synthetic mean-threshold task and report
    validation loss/accuracy after every epoch.

    Args:
        epochs: number of training epochs (default 5).
    """
    # Data
    train_set = SyntheticMeanThreshold(1000)
    val_set   = SyntheticMeanThreshold(200)

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    val_loader   = DataLoader(val_set, batch_size=32)

    # Model, loss, optimiser
    model = SimpleResNet().to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Train; `train` switches back to train mode after each evaluation.
    for epoch in range(1, epochs + 1):
        train(model, train_loader, criterion, optimiser, epoch)

        # quick validation
        val_loss, val_acc = _evaluate(model, val_loader, criterion)
        print(f"          ↳ val_loss={val_loss:.4f}  "
              f"val_acc={val_acc:.3f}\n")

In [6]:
# Entry point: run the full train/validate demo. In a notebook kernel
# __name__ is "__main__", so this cell still runs; importing the code as a
# module no longer starts training as a side effect.
if __name__ == "__main__":
    main()

Input       : torch.Size([32, 16, 32, 32])
Conv1       : torch.Size([32, 16, 32, 32])
Conv2       : torch.Size([32, 16, 32, 32])
Shortcut    : torch.Size([32, 16, 32, 32])
After add   : torch.Size([32, 16, 32, 32])
Block output: torch.Size([32, 16, 32, 32])

Input       : torch.Size([32, 16, 32, 32])
Conv1       : torch.Size([32, 32, 16, 16])
Conv2       : torch.Size([32, 32, 16, 16])
Shortcut    : torch.Size([32, 32, 16, 16])
After add   : torch.Size([32, 32, 16, 16])
Block output: torch.Size([32, 32, 16, 16])

Input       : torch.Size([32, 32, 16, 16])
Conv1       : torch.Size([32, 64, 8, 8])
Conv2       : torch.Size([32, 64, 8, 8])
Shortcut    : torch.Size([32, 64, 8, 8])
After add   : torch.Size([32, 64, 8, 8])
Block output: torch.Size([32, 64, 8, 8])

Input       : torch.Size([32, 16, 32, 32])
Conv1       : torch.Size([32, 16, 32, 32])
Conv2       : torch.Size([32, 16, 32, 32])
Shortcut    : torch.Size([32, 16, 32, 32])
After add   : torch.Size([32, 16, 32, 32])
Block output: torc

In [7]:
# Standalone datasets for interactive inspection (1000 train, 200 val samples),
# built in that order.
train_set, val_set = (SyntheticMeanThreshold(size) for size in (1000, 200))

In [18]:
# Shape of the image tensor of sample 1; the dataset documents 1×32×32 images.
print(train_set[1][0].shape)

torch.Size([1, 32, 32])
