# CNN with Xavier/He (matched init)
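CIFAR-10 with an 80/20 train/validation split and a nine-layer convolutional network whose weight initialization is matched to its activation (He for ReLU, Xavier for tanh). Includes weight-decay regularization, dropout, early stopping, and CSV logging (per-epoch detail plus a per-run summary) for plotting.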

In [None]:
import csv
import random
from dataclasses import dataclass
from typing import Callable, List, Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [None]:
def set_seed(seed: int) -> None:
    """Seed every random number generator this notebook relies on.

    Covers Python's ``random``, NumPy's global RNG, torch's CPU RNG and the
    RNGs of all CUDA devices (the CUDA call is harmless without a GPU).

    Args:
        seed: Integer seed applied identically to each generator.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


def get_activation(name: str) -> Tuple[nn.Module, str]:
    """Build a new activation module from its name.

    Args:
        name: Either ``"relu"`` or ``"tanh"``.

    Returns:
        A freshly constructed module and the canonical activation name.

    Raises:
        ValueError: If *name* is not a supported activation.
    """
    supported = (("relu", nn.ReLU), ("tanh", nn.Tanh))
    for key, factory in supported:
        if name == key:
            return factory(), key
    raise ValueError(f"Unsupported activation {name}")


class ConvNet(nn.Module):
    """Nine-layer convolutional classifier for 32x32 RGB images.

    Three stages of three 3x3 convolutions (32, 64, 128 channels, padding 1,
    so spatial size is preserved), each conv followed by the chosen activation.
    Every stage ends with 2x2 max-pooling and dropout. Three poolings take a
    32x32 input down to 4x4, which is what the linear classifier's input size
    (128 * 4 * 4) assumes.

    Args:
        activation: Activation name, resolved through ``get_activation``
            ("relu" or "tanh").
        dropout_p: Dropout probability applied after each pooling stage.
    """

    _STAGE_WIDTHS = (32, 64, 128)
    _CONVS_PER_STAGE = 3

    def __init__(self, activation: str, dropout_p: float) -> None:
        super().__init__()
        layers: List[nn.Module] = []
        in_channels = 3
        for out_channels in self._STAGE_WIDTHS:
            for _ in range(self._CONVS_PER_STAGE):
                # Use a fresh activation module at every position. A single shared
                # instance is listed only once by named_modules() and makes a forward
                # hook on it fire for every layer, which breaks per-layer inspection.
                act, _ = get_activation(activation)
                layers += [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), act]
                in_channels = out_channels
            layers += [nn.MaxPool2d(2), nn.Dropout(dropout_p)]
        # Layer order (and therefore Sequential indices / state_dict keys) matches
        # conv, act, conv, act, conv, act, pool, dropout per stage.
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Linear(self._STAGE_WIDTHS[-1] * 4 * 4, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, 10) for input *x*."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [None]:
def init_weights(module: nn.Module, scheme: str, activation_name: str) -> None:
    """Re-initialise a Conv2d/Linear layer in place; leave any other module untouched.

    Meant to be used through ``nn.Module.apply``.

    Args:
        module: Candidate module.
        scheme: ``"xavier"`` applies uniform Xavier/Glorot init with the gain for
            the activation. ``"he"`` applies ``kaiming_uniform_`` with PyTorch's
            defaults (fan_in mode).
        activation_name: Nonlinearity the layer feeds into. For ``"xavier"``,
            anything other than ``"tanh"`` gets the ReLU gain. For ``"he"``, the
            value is passed through as ``nonlinearity``.

    Raises:
        ValueError: If *module* is a Conv2d/Linear layer and *scheme* is unknown.

    NOTE(review): the final classifier Linear also receives the activation's
    gain even though its output (logits) is not passed through that activation.
    """
    has_weights = isinstance(module, (nn.Conv2d, nn.Linear))
    if has_weights:
        weight = module.weight
        if scheme == "he":
            nn.init.kaiming_uniform_(weight, nonlinearity=activation_name)
        elif scheme == "xavier":
            gain_for = "tanh" if activation_name == "tanh" else "relu"
            nn.init.xavier_uniform_(weight, gain=nn.init.calculate_gain(gain_for))
        else:
            raise ValueError(f"Unknown scheme {scheme}")
        # Biases always start at zero regardless of scheme.
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def get_data(
    batch_size: int,
    seed: int,
    num_workers: int = 2,
    root: str = "data",
) -> Tuple[DataLoader, DataLoader]:
    """Download the CIFAR-10 training set and split it 80/20 into train/validation loaders.

    Images go through ``transforms.ToTensor()`` only (no normalisation). The split
    is drawn from a dedicated ``torch.Generator`` seeded with *seed*, so it is
    reproducible and does not consume the global torch RNG.

    Args:
        batch_size: Mini-batch size for both loaders.
        seed: Seed for the train/validation split.
        num_workers: DataLoader worker processes (default 2, as before).
        root: Directory where CIFAR-10 is stored or downloaded (default "data").

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and the
        validation loader does not.
    """
    transform = transforms.ToTensor()
    full_train = datasets.CIFAR10(root=root, train=True, download=True, transform=transform)
    n_train = int(0.8 * len(full_train))
    n_val = len(full_train) - n_train
    g = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(full_train, [n_train, n_val], generator=g)
    # Pinned host memory only speeds up host->GPU copies. Without CUDA it is useless
    # and DataLoader emits a warning, so enable it only when a GPU is present.
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin)
    return train_loader, val_loader

In [None]:
def activation_stats(model: nn.Module, x: torch.Tensor) -> Tuple[float, float, float]:
    """Average activation statistics over every ReLU/Tanh output in ``model.features``.

    Feeds *x* through ``model.features`` one layer at a time without tracking
    gradients. After each ReLU or Tanh layer it records three values of that
    layer's output:
    - the mean;
    - the standard deviation (``Tensor.std`` default, i.e. unbiased);
    - the fraction of entries exactly equal to zero.

    Each statistic is then averaged across those layers.

    The model's train/eval mode is not changed here, so any Dropout inside
    ``features`` behaves according to the mode the caller left the model in.

    Returns:
        ``(mean, std, frac_zero)``; all 0.0 when ``features`` has no ReLU/Tanh.
    """
    per_layer: List[Tuple[float, float, float]] = []
    h = x
    with torch.no_grad():
        for layer in model.features:
            h = layer(h)
            if not isinstance(layer, (nn.ReLU, nn.Tanh)):
                continue
            per_layer.append((h.mean().item(), h.std().item(), (h == 0).float().mean().item()))
    if not per_layer:
        return 0.0, 0.0, 0.0
    layer_means, layer_stds, layer_zero_fracs = zip(*per_layer)
    return float(np.mean(layer_means)), float(np.mean(layer_stds)), float(np.mean(layer_zero_fracs))


def gradient_norm(model: nn.Module) -> float:
    """Return the mean of per-parameter gradient L2 norms.

    This is the arithmetic mean of ``p.grad.norm()`` over every parameter that
    currently has a gradient. It is not the global norm of all gradients
    concatenated. Parameters whose ``.grad`` is None are skipped, and the
    result is 0.0 when no parameter has a gradient.
    """
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        return 0.0
    per_tensor = [g.norm().item() for g in grads]
    return float(np.mean(per_tensor))

In [None]:
def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: Callable, optimizer: torch.optim.Optimizer, device: str) -> float:
    """Run one optimisation pass over *loader* and return the sample-weighted mean loss.

    Switches *model* to training mode. For every batch it moves inputs and
    targets to *device*, zeroes gradients, runs the forward and backward
    passes, and steps the optimizer. Each batch loss is weighted by its batch
    size, and the sum is divided by ``len(loader.dataset)``.

    NOTE(review): the divisor assumes every dataset sample is visited once,
    i.e. the loader does not use ``drop_last``.
    """
    model.train()
    loss_sum = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item() * inputs.size(0)
    return loss_sum / len(loader.dataset)


def evaluate(model: nn.Module, loader: DataLoader, criterion: Callable, device: str) -> Tuple[float, float]:
    """Score *model* on *loader* without gradient tracking.

    Switches *model* to eval mode, so dropout is disabled.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the batch-size-weighted sum
        of losses divided by ``len(loader.dataset)``. ``accuracy`` is the
        fraction of samples whose argmax logit equals the target.
    """
    model.eval()
    n_seen, n_correct, loss_sum = 0, 0, 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, targets).item() * inputs.size(0)
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)
    return loss_sum / len(loader.dataset), n_correct / n_seen

In [None]:
@dataclass
class Config:
    """Hyper-parameters shared by every run in this notebook."""

    batch_size: int = 128  # mini-batch size for both the train and validation loaders
    epochs: int = 50  # maximum number of training epochs per run
    lr: float = 1e-3  # Adam learning rate
    weight_decay: float = 1e-4  # Adam weight_decay (L2 penalty added to the gradients)
    dropout: float = 0.1  # dropout probability used by ConvNet after each pooling stage
    patience: int = 3  # early-stopping patience in epochs; NOTE(review): only has an effect if the training loop reads cfg.patience
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # chosen once, when the class body executes
    seed: int = 42  # seed for the global RNGs and the train/validation split


# Single shared configuration instance used by the experiment cells.
cfg = Config()


Run one experiment per activation/initialization pair (ReLU + He, tanh + Xavier). Log per-epoch metrics, activation statistics, and gradient norms, and stop early when validation loss stalls.

In [None]:
# Train one model per (activation, init) pair, log per-epoch metrics plus probe
# statistics (activation stats and gradient norm on a fixed validation batch),
# stop a run early once val_loss has not improved for cfg.patience epochs, and
# write the detailed and summary CSVs.
set_seed(cfg.seed)
train_loader, val_loader = get_data(cfg.batch_size, cfg.seed)
results = []
summaries = []

combos = [
    ("relu", "he"),
    ("tanh", "xavier"),
]

# Fixed probe batch. val_loader is unshuffled, so this is the same first batch
# that next(iter(val_loader)) would return every epoch. Fetching it once avoids
# restarting the DataLoader workers each epoch just to read it again.
probe_x, probe_y = next(iter(val_loader))
probe_x, probe_y = probe_x.to(cfg.device), probe_y.to(cfg.device)

for activation, scheme in combos:
    # Re-seed per run so each combo's weight init, dropout masks and shuffling do
    # not depend on how many random numbers earlier combos consumed.
    set_seed(cfg.seed)
    model = ConvNet(activation=activation, dropout_p=cfg.dropout)
    model.apply(lambda m: init_weights(m, scheme, activation))
    model.to(cfg.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    last_metrics = None
    best_val_loss = float("inf")
    epochs_without_improvement = 0

    for epoch in range(cfg.epochs):
        train_loss = train_one_epoch(model, train_loader, criterion, optimizer, cfg.device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, cfg.device)

        # Probe on the fixed batch. evaluate() left the model in eval mode, so
        # dropout is off here. The probe gradients are never stepped on, and the
        # next train_one_epoch zeroes them before its first backward pass.
        model.zero_grad()
        loss_sample = criterion(model(probe_x), probe_y)
        loss_sample.backward()
        grad_norm_val = gradient_norm(model)
        act_mean, act_std, frac_zero = activation_stats(model, probe_x)

        metrics = {
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "activation": activation,
            "init": scheme,
            "act_mean": act_mean,
            "act_std": act_std,
            "frac_zero": frac_zero,
            "grad_norm": grad_norm_val,
        }
        results.append(metrics)
        last_metrics = metrics

        print(
            f"act={activation} init={scheme} epoch={epoch+1:02d} "
            f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} val_acc={val_acc*100:.2f}% "
            f"act_mean={act_mean:.4f} act_std={act_std:.4f} frac_zero={frac_zero:.4f} grad_norm={grad_norm_val:.4f}"
        )

        # Early stopping on stalled validation loss (previously advertised but not implemented).
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= cfg.patience:
                print(
                    f"act={activation} init={scheme} early stop at epoch {epoch+1} "
                    f"(no val_loss improvement for {cfg.patience} epochs)"
                )
                break

    # The summary row is the run's final epoch (which may be an early-stopped one).
    if last_metrics is not None:
        summaries.append(last_metrics)

# Every run with at least one epoch adds both detail rows and a summary row, so the
# two lists are empty together (e.g. cfg.epochs == 0). Guard against results[0]
# raising IndexError in that case.
if results:
    with open("good_runs.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
        writer.writeheader()
        writer.writerows(results)

    with open("good_runs_summary.csv", "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=list(summaries[0].keys()))
        writer.writeheader()
        writer.writerows(summaries)

    print("Wrote good_runs.csv and good_runs_summary.csv")
else:
    print("No epochs were run; no CSV files written")
