<a href="https://colab.research.google.com/github/Nishorgo26/Project_2/blob/main/Project_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title ✅ Setup: GPU, deps, optional Drive
# Environment check: print the Python version and whether a CUDA GPU is visible.
import sys, os, torch
print("Python:", sys.version)
print("CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

# Quiet installs
# NOTE(review): versions are unpinned, and torch is already imported above, so any
# torch/torchvision upgrade pulled in here only takes effect after a runtime restart.
# torchaudio is not used anywhere else in this notebook.
!pip -q install torchvision torchaudio scikit-learn matplotlib

# Optional: save to Drive
# USE_DRIVE=True mounts Google Drive and writes artifacts there; otherwise they go
# to the local /content/artifacts directory.
USE_DRIVE = False  #@param {type:"boolean"}
ARTIFACT_DIR = "/content/drive/MyDrive/vgg19_cifar10_prune_artifacts" if USE_DRIVE else "/content/artifacts"
if USE_DRIVE:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)

os.makedirs(ARTIFACT_DIR, exist_ok=True)
print("Artifacts →", ARTIFACT_DIR)


Python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
Artifacts → /content/artifacts


In [None]:
#@title 🔧 Imports, seeds, metrics, dataloaders, plotting (fixed)
import time, copy, random
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             confusion_matrix, classification_report)
from torch.nn.utils import prune

# Repro (keep cudnn fast)
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    cuDNN is intentionally left non-deterministic with autotuning enabled,
    trading bit-exact reproducibility for speed.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Class names assumed to follow torchvision's CIFAR-10 label order (name at index i ↔ label i).
CIFAR10_CLASSES = "airplane automobile bird cat deer dog frog horse ship truck".split()

def count_params(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return total, trainable

def build_dataloaders(data_root, batch_size, num_workers=2, val_ratio=0.1, split_seed=123):
    """Build CIFAR-10 train / validation / test DataLoaders.

    The CIFAR-10 training split is divided (seeded by ``split_seed``) into a
    train and a validation subset. Only the train subset is augmented
    (RandomCrop + horizontal flip + AutoAugment). Validation and test images get
    the plain ToTensor + Normalize pipeline, so validation metrics, which drive
    model selection, are not computed on randomly augmented images.

    Args:
        data_root: directory CIFAR-10 is downloaded to / read from.
        batch_size: batch size shared by all three loaders.
        num_workers: DataLoader worker processes.
        val_ratio: fraction of the training split held out for validation.
        split_seed: seed of the train/val split (the default keeps the same split).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2470, 0.2435, 0.2616])
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize
    ])
    eval_tf = transforms.Compose([transforms.ToTensor(), normalize])

    # Two views over the same training images: augmented for fitting, plain for validation.
    train_aug = datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_tf)
    train_plain = datasets.CIFAR10(root=data_root, train=True, download=False, transform=eval_tf)
    test_set = datasets.CIFAR10(root=data_root, train=False, download=True, transform=eval_tf)

    n_val = int(len(train_aug) * val_ratio)
    n_train = len(train_aug) - n_val
    train_set, val_aug = random_split(train_aug, [n_train, n_val],
                                      generator=torch.Generator().manual_seed(split_seed))
    # Same held-out indices, but served through the un-augmented view.
    val_set = torch.utils.data.Subset(train_plain, val_aug.indices)

    # Pinned host memory only helps host→GPU copies (and warns on CPU-only runtimes).
    loader_kw = dict(batch_size=batch_size, num_workers=num_workers,
                     pin_memory=torch.cuda.is_available())
    train_loader = DataLoader(train_set, shuffle=True, **loader_kw)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kw)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kw)
    return train_loader, val_loader, test_loader

# === Plotting (matplotlib; no seaborn, no custom colors) ===
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def plot_curves(history, outdir):
    """Save train-vs-validation loss and accuracy curves into ``outdir``.

    ``history`` is read at ``epoch``, ``train_loss``, ``val_loss``, ``train_acc``
    and ``val_acc``; each series is plotted against ``history["epoch"]``.
    Writes ``loss_curves.png`` and ``acc_curves.png`` and closes both figures.
    """
    destination = Path(outdir)
    figures = [
        ("train_loss", "val_loss", "Loss", "Training vs Validation Loss", "loss_curves.png"),
        ("train_acc", "val_acc", "Accuracy", "Training vs Validation Accuracy", "acc_curves.png"),
    ]
    for train_key, val_key, y_label, heading, filename in figures:
        plt.figure()
        plt.plot(history["epoch"], history[train_key], label=train_key)
        plt.plot(history["epoch"], history[val_key], label=val_key)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, linestyle=":")
        plt.title(heading)
        plt.tight_layout()
        plt.savefig(destination / filename)
        plt.close()

def plot_confusion(cm, classes, outpath, normalize=False, title="Confusion Matrix"):
    """Render an annotated confusion-matrix heat-map and save it to ``outpath``.

    Args:
        cm: square matrix-like (rows = true label, columns = predicted label).
        classes: tick labels, one per row/column.
        outpath: image destination; missing parent directories are created.
        normalize: if True, divide each row by its sum (all-zero rows become 0)
            and annotate with two decimals; otherwise annotate integers
            (float input is rounded first).
        title: figure title.
    """
    matrix = np.array(cm)
    if not normalize:
        if matrix.dtype.kind == 'f':
            matrix = np.rint(matrix).astype(np.int64)
        cell_fmt = "d"
    else:
        as_float = matrix.astype(np.float64)
        with np.errstate(all='ignore'):
            matrix = np.nan_to_num(as_float / as_float.sum(axis=1, keepdims=True))
        cell_fmt = ".2f"

    plt.figure(figsize=(6.5, 5.5))
    plt.imshow(matrix, interpolation='nearest')
    plt.title(title)
    plt.colorbar()
    tick_pos = np.arange(len(classes))
    plt.xticks(tick_pos, classes, rotation=45, ha="right")
    plt.yticks(tick_pos, classes)
    # np.ndindex walks row-major: (0, 0), (0, 1), ... like the nested loops it replaces.
    for r, c in np.ndindex(matrix.shape):
        plt.text(c, r, format(matrix[r, c], cell_fmt), ha="center", va="center")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(outpath)
    plt.close()

def plot_per_class_bars(report_dict, outstem):
    """Save per-class precision, recall and F1 bar charts for the 10 CIFAR-10 classes.

    ``report_dict`` must carry ``y_true`` and ``y_pred`` arrays (e.g. from
    ``evaluate(..., return_preds=True)``); if either is missing or None the
    function returns without plotting. Files are written to
    ``{outstem}_precision.png``, ``{outstem}_recall.png`` and ``{outstem}_f1.png``.
    """
    y_true = report_dict.get("y_true")
    y_pred = report_dict.get("y_pred")
    if y_true is None or y_pred is None:
        return
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=list(range(10)), zero_division=0)
    positions = np.arange(len(CIFAR10_CLASSES))
    charts = (
        (precision, "Precision", "Per-class Precision", "precision"),
        (recall, "Recall", "Per-class Recall", "recall"),
        (f1, "F1-score", "Per-class F1-score", "f1"),
    )
    for values, y_label, heading, suffix in charts:
        plt.figure()
        plt.bar(positions, values)
        plt.xticks(positions, CIFAR10_CLASSES, rotation=45, ha="right")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.tight_layout()
        plt.savefig(f"{outstem}_{suffix}.png")
        plt.close()


In [None]:
#@title 🧠 Model (VGG19), train/eval, structured pruning, sparsity
def build_vgg19(num_classes=10, pretrained=False):
    """Create a torchvision VGG19 whose last classifier layer outputs ``num_classes`` logits.

    Args:
        num_classes: output size of the freshly initialised final Linear layer.
        pretrained: load ``models.VGG19_Weights.DEFAULT`` before swapping the head;
            otherwise all weights are randomly initialised.

    NOTE(review): this is the ImageNet-layout VGG19 without BatchNorm. For 32x32
    inputs the adaptive average pool presumably tiles a 1x1 feature map up to 7x7,
    so most of the ~139.6M parameters sit in the classifier. Confirm whether a
    CIFAR-sized head (or ``vgg19_bn``) was intended.
    """
    weights = models.VGG19_Weights.DEFAULT if pretrained else None
    net = models.vgg19(weights=weights)
    layers = list(net.classifier.children())
    layers[-1] = nn.Linear(layers[-1].in_features, num_classes)
    net.classifier = nn.Sequential(*layers)
    return net

def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Run one optimisation pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: network to train (switched to train mode).
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: ``torch.device`` batches are moved to; its ``.type`` selects the
            autocast backend.
        scaler: optional GradScaler. fp16 autocast with loss scaling is used only
            when a scaler is given *and* enabled. A disabled scaler (e.g. one built
            with ``enabled=False`` on CPU) now falls back to plain fp32 instead of
            forcing fp16 autocast on that device.

    Returns:
        Sample-weighted mean loss and top-1 accuracy over the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    # A GradScaler object is always truthy, so ask whether it is actually enabled.
    use_amp = scaler is not None and scaler.is_enabled()
    running_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if use_amp:
            with torch.autocast(device_type=device.type, dtype=torch.float16):
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)
    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return running_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, device, return_preds=False, class_names=None):
    """Evaluate ``model`` on ``loader`` without gradients and return a metrics dict.

    Args:
        model: network to evaluate (switched to eval mode and left there).
        loader: iterable of ``(inputs, targets)`` batches with integer class targets.
        device: device batches are moved to.
        return_preds: if True, the result carries the concatenated ``y_true`` /
            ``y_pred`` arrays (otherwise those entries are None).
        class_names: names for label indices ``0..len(class_names)-1``; defaults
            to ``CIFAR10_CLASSES``.

    Returns:
        dict with ``loss`` (sample-weighted mean cross-entropy), ``accuracy``,
        macro/micro precision, recall and F1 (sklearn averages these over the
        labels that actually occur), a fixed-size ``confusion_matrix``,
        ``y_true``, ``y_pred`` and a text ``class_report``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if class_names is None:
        class_names = CIFAR10_CLASSES
    # Pin the label set so the report and confusion matrix stay aligned with
    # class_names even if some class never appears in y_true or y_pred.
    # Without labels=, classification_report raises on the size mismatch.
    labels = list(range(len(class_names)))

    model.eval()
    criterion = nn.CrossEntropyLoss()
    true_chunks, pred_chunks = [], []
    running_loss, total = 0.0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        running_loss += criterion(logits, y).item() * x.size(0)
        true_chunks.append(y.cpu().numpy())
        pred_chunks.append(logits.argmax(1).cpu().numpy())
        total += y.size(0)
    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)
    acc = accuracy_score(y_true, y_pred)
    p_mac, r_mac, f1_mac, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    p_mic, r_mic, f1_mic, _ = precision_recall_fscore_support(y_true, y_pred, average='micro', zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    report = classification_report(y_true, y_pred, labels=labels,
                                   target_names=class_names, zero_division=0)
    return {
        "loss": running_loss / total,
        "accuracy": acc,
        "precision_macro": p_mac, "recall_macro": r_mac, "f1_macro": f1_mac,
        "precision_micro": p_mic, "recall_micro": r_mic, "f1_micro": f1_mic,
        "confusion_matrix": cm,
        "y_true": y_true if return_preds else None,
        "y_pred": y_pred if return_preds else None,
        "class_report": report,
    }

def prune_for_inference(model, amount=0.3, structured=True, progressive=True, prune_classifier=False):
    """Return a pruned deep copy of ``model``; the input model is not modified.

    Structured (default): for every Conv2d / Linear weight, zero the output
    channels/units (``dim=0``) with the smallest L2 norm via ``prune.ln_structured``.
    With ``progressive=True`` the per-layer amount is scaled linearly from 0.8x
    (first pruned layer) to 1.2x (last pruned layer). In both cases it is
    clamped to [0, 0.95].
    Unstructured: L1-magnitude pruning of individual weights at ``amount``.

    The last Linear layer (the classifier head) is skipped unless
    ``prune_classifier=True``. Pruning its output rows wipes out whole class
    logits, e.g. round(0.36 * 10) = 4 of 10 classes at the default
    amount=0.3 with progressive scaling. Those classes are then scored by their
    bias alone.

    Masks are made permanent with ``prune.remove``. Tensor shapes are not shrunk
    and biases are untouched, so a zeroed channel still emits its bias.
    """
    m = copy.deepcopy(model)
    targets = [mod for mod in m.modules() if isinstance(mod, (nn.Conv2d, nn.Linear))]
    if not prune_classifier:
        linears = [mod for mod in targets if isinstance(mod, nn.Linear)]
        if linears:
            head = linears[-1]
            targets = [mod for mod in targets if mod is not head]
    if not targets:
        return m

    if structured:
        if progressive and len(targets) > 1:
            # Deeper layers are usually more redundant, so prune them a bit harder.
            scales = torch.linspace(0.8, 1.2, steps=len(targets)).tolist()
        else:
            scales = [1.0] * len(targets)
        for mod, scale in zip(targets, scales):
            amt = max(0.0, min(0.95, scale * amount))
            # Conv2d: dim 0 = output channels; Linear: dim 0 = output units.
            prune.ln_structured(mod, name="weight", amount=amt, n=2, dim=0)
            prune.remove(mod, "weight")
    else:
        for mod in targets:
            prune.l1_unstructured(mod, name="weight", amount=amount)
            prune.remove(mod, "weight")
    return m

@torch.no_grad()
def report_sparsity(model):
    """Count exactly-zero entries in every parameter of ``model``.

    Returns:
        ``(global_fraction, details)``. ``details`` lists
        ``(name, zeros, numel, fraction)`` in ``named_parameters`` order, skipping
        empty tensors. ``global_fraction`` is total zeros over total elements
        (0.0 when nothing was counted).
    """
    details = []
    for name, param in model.named_parameters():
        if param is None or param.numel() == 0:
            continue
        size = param.numel()
        n_zero = (param == 0).sum().item()
        details.append((name, n_zero, size, n_zero / size))
    zeros_total = sum(entry[1] for entry in details)
    params_total = sum(entry[2] for entry in details)
    global_frac = zeros_total / params_total if params_total else 0.0
    return global_frac, details


In [None]:
#@title 🚀 Train, structured prune-at-inference, evaluate, plot, save (self-contained)
# This cell re-declares every helper it uses (duplicating the two cells above) so it
# can run on its own; when run after them, its definitions shadow the earlier ones.
# NOTE(review): its prune_for_inference has no `progressive` option, so running this
# cell switches to uniform per-layer pruning even if the earlier cell was run.

# --- knobs ---
SEED = 42                 #@param {type:"integer"}
EPOCHS = 40               #@param {type:"integer"}
BATCH  = 256              #@param {type:"integer"}
LR     = 0.01             #@param {type:"number"}
MOMENTUM = 0.9            #@param {type:"number"}
WEIGHT_DECAY = 5e-4       #@param {type:"number"}
# StepLR schedule: the LR is multiplied by GAMMA every STEP_SIZE epochs.
STEP_SIZE = 15            #@param {type:"integer"}
GAMMA     = 0.1           #@param {type:"number"}

# --- pruning ---
STRUCTURED   = True       # default: structured pruning
PRUNE_AMOUNT = 0.30       # try 0.2–0.5; >0.3 usually needs FT

# --- plotting & FT ---
PLOT = True               # save curves + CMs + per-class bars
DO_FT = True              # short fine-tune after pruning (recommended for structured)
FINE_TUNE_EPOCHS = 5      # small bump to recover accuracy
# Fine-tune SGD LR; the FT loop below halves it every FINE_TUNE_EPOCHS//2 epochs.
FINE_TUNE_LR     = 5e-4

# ---------- imports you already had ----------
import time, copy, random
import numpy as np
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             confusion_matrix, classification_report)
from torch.nn.utils import prune

# CIFAR-10 class names, assumed to follow torchvision's label order (name at index i ↔ label i); used by plots
CIFAR10_CLASSES = "airplane automobile bird cat deer dog frog horse ship truck".split()

# --- tiny utils we rely on here ---
def seed_everything(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    cuDNN is intentionally left non-deterministic with autotuning enabled,
    trading bit-exact reproducibility for speed.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

def count_params(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``."""
    total = trainable = 0
    for param in model.parameters():
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return total, trainable

def build_dataloaders(data_root, batch_size, num_workers=2, val_ratio=0.1, split_seed=123):
    """Build CIFAR-10 train / validation / test DataLoaders.

    The CIFAR-10 training split is divided (seeded by ``split_seed``) into a
    train and a validation subset. Only the train subset is augmented
    (RandomCrop + horizontal flip + AutoAugment). Validation and test images get
    the plain ToTensor + Normalize pipeline, so validation metrics, which drive
    model selection, are not computed on randomly augmented images.

    Args:
        data_root: directory CIFAR-10 is downloaded to / read from.
        batch_size: batch size shared by all three loaders.
        num_workers: DataLoader worker processes.
        val_ratio: fraction of the training split held out for validation.
        split_seed: seed of the train/val split (the default keeps the same split).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the train loader shuffles.
    """
    normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                     std=[0.2470, 0.2435, 0.2616])
    train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.AutoAugment(transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        normalize
    ])
    eval_tf = transforms.Compose([transforms.ToTensor(), normalize])

    # Two views over the same training images: augmented for fitting, plain for validation.
    train_aug = datasets.CIFAR10(root=data_root, train=True, download=True, transform=train_tf)
    train_plain = datasets.CIFAR10(root=data_root, train=True, download=False, transform=eval_tf)
    test_set = datasets.CIFAR10(root=data_root, train=False, download=True, transform=eval_tf)

    n_val = int(len(train_aug) * val_ratio)
    n_train = len(train_aug) - n_val
    train_set, val_aug = random_split(train_aug, [n_train, n_val],
                                      generator=torch.Generator().manual_seed(split_seed))
    # Same held-out indices, but served through the un-augmented view.
    val_set = torch.utils.data.Subset(train_plain, val_aug.indices)

    # Pinned host memory only helps host→GPU copies (and warns on CPU-only runtimes).
    loader_kw = dict(batch_size=batch_size, num_workers=num_workers,
                     pin_memory=torch.cuda.is_available())
    train_loader = DataLoader(train_set, shuffle=True, **loader_kw)
    val_loader = DataLoader(val_set, shuffle=False, **loader_kw)
    test_loader = DataLoader(test_set, shuffle=False, **loader_kw)
    return train_loader, val_loader, test_loader

def build_vgg19(num_classes=10, pretrained=False):
    """Create a torchvision VGG19 whose last classifier layer outputs ``num_classes`` logits.

    Args:
        num_classes: output size of the freshly initialised final Linear layer.
        pretrained: load ``models.VGG19_Weights.DEFAULT`` before swapping the head;
            otherwise all weights are randomly initialised.

    NOTE(review): this is the ImageNet-layout VGG19 without BatchNorm. For 32x32
    inputs the adaptive average pool presumably tiles a 1x1 feature map up to 7x7,
    so most of the ~139.6M parameters sit in the classifier. Confirm whether a
    CIFAR-sized head (or ``vgg19_bn``) was intended.
    """
    weights = models.VGG19_Weights.DEFAULT if pretrained else None
    net = models.vgg19(weights=weights)
    layers = list(net.classifier.children())
    layers[-1] = nn.Linear(layers[-1].in_features, num_classes)
    net.classifier = nn.Sequential(*layers)
    return net

def train_one_epoch(model, loader, criterion, optimizer, device, scaler=None):
    """Run one optimisation pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Args:
        model: network to train (switched to train mode).
        loader: iterable of ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(logits, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        device: ``torch.device`` batches are moved to; its ``.type`` selects the
            autocast backend.
        scaler: optional GradScaler. fp16 autocast with loss scaling is used only
            when a scaler is given *and* enabled. A disabled scaler (e.g. one built
            with ``enabled=False`` on CPU) now falls back to plain fp32 instead of
            forcing fp16 autocast on that device.

    Returns:
        Sample-weighted mean loss and top-1 accuracy over the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    # A GradScaler object is always truthy, so ask whether it is actually enabled.
    use_amp = scaler is not None and scaler.is_enabled()
    running_loss, correct, total = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if use_amp:
            with torch.autocast(device_type=device.type, dtype=torch.float16):
                logits = model(x)
                loss = criterion(logits, y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * x.size(0)
        correct += (logits.argmax(1) == y).sum().item()
        total += y.size(0)
    if total == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")
    return running_loss / total, correct / total

@torch.no_grad()
def evaluate(model, loader, device, return_preds=False, class_names=None):
    """Evaluate ``model`` on ``loader`` without gradients and return a metrics dict.

    Args:
        model: network to evaluate (switched to eval mode and left there).
        loader: iterable of ``(inputs, targets)`` batches with integer class targets.
        device: device batches are moved to.
        return_preds: if True, the result carries the concatenated ``y_true`` /
            ``y_pred`` arrays (otherwise those entries are None).
        class_names: names for label indices ``0..len(class_names)-1``; defaults
            to ``CIFAR10_CLASSES``.

    Returns:
        dict with ``loss`` (sample-weighted mean cross-entropy), ``accuracy``,
        macro/micro precision, recall and F1 (sklearn averages these over the
        labels that actually occur), a fixed-size ``confusion_matrix``,
        ``y_true``, ``y_pred`` and a text ``class_report``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if class_names is None:
        class_names = CIFAR10_CLASSES
    # Pin the label set so the report and confusion matrix stay aligned with
    # class_names even if some class never appears in y_true or y_pred.
    # Without labels=, classification_report raises on the size mismatch.
    labels = list(range(len(class_names)))

    model.eval()
    criterion = nn.CrossEntropyLoss()
    true_chunks, pred_chunks = [], []
    running_loss, total = 0.0, 0
    for x, y in loader:
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        running_loss += criterion(logits, y).item() * x.size(0)
        true_chunks.append(y.cpu().numpy())
        pred_chunks.append(logits.argmax(1).cpu().numpy())
        total += y.size(0)
    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")

    y_true = np.concatenate(true_chunks)
    y_pred = np.concatenate(pred_chunks)
    acc = accuracy_score(y_true, y_pred)
    p_mac, r_mac, f1_mac, _ = precision_recall_fscore_support(y_true, y_pred, average='macro', zero_division=0)
    p_mic, r_mic, f1_mic, _ = precision_recall_fscore_support(y_true, y_pred, average='micro', zero_division=0)
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    report = classification_report(y_true, y_pred, labels=labels,
                                   target_names=class_names, zero_division=0)
    return {
        "loss": running_loss / total,
        "accuracy": acc,
        "precision_macro": p_mac, "recall_macro": r_mac, "f1_macro": f1_mac,
        "precision_micro": p_mic, "recall_micro": r_mic, "f1_micro": f1_mic,
        "confusion_matrix": cm,
        "y_true": y_true if return_preds else None,
        "y_pred": y_pred if return_preds else None,
        "class_report": report,
    }

def prune_for_inference(model, amount=0.3, structured=True, prune_classifier=False):
    """Return a pruned deep copy of ``model``; the input model is not modified.

    Structured (default): for every Conv2d / Linear weight, zero the ``amount``
    fraction of output channels/units (``dim=0``) with the smallest L2 norm.
    Unstructured: zero the ``amount`` fraction of individual weights with the
    smallest magnitude (L1).

    The last Linear layer (the classifier head) is skipped unless
    ``prune_classifier=True``. Pruning its output rows removes whole class
    logits (3 of 10 classes at amount=0.3), leaving those classes scored by
    their bias alone.

    Masks are made permanent with ``prune.remove``. Tensor shapes are unchanged
    and biases are untouched, so a zeroed channel still emits its bias.
    """
    m = copy.deepcopy(model)
    targets = [mod for mod in m.modules() if isinstance(mod, (nn.Conv2d, nn.Linear))]
    if not prune_classifier:
        linears = [mod for mod in targets if isinstance(mod, nn.Linear)]
        if linears:
            head = linears[-1]
            targets = [mod for mod in targets if mod is not head]

    for mod in targets:
        if structured:
            # Conv2d: dim 0 = output channels; Linear: dim 0 = output units.
            prune.ln_structured(mod, name="weight", amount=amount, n=2, dim=0)
        else:
            prune.l1_unstructured(mod, name="weight", amount=amount)
        prune.remove(mod, "weight")
    return m

@torch.no_grad()
def report_sparsity(model):
    """Count exactly-zero entries in every parameter of ``model``.

    Returns:
        ``(global_fraction, details)``. ``details`` lists
        ``(name, zeros, numel, fraction)`` in ``named_parameters`` order, skipping
        empty tensors. ``global_fraction`` is total zeros over total elements
        (0.0 when nothing was counted).
    """
    details = []
    for name, param in model.named_parameters():
        if param is None or param.numel() == 0:
            continue
        size = param.numel()
        n_zero = (param == 0).sum().item()
        details.append((name, n_zero, size, n_zero / size))
    zeros_total = sum(entry[1] for entry in details)
    params_total = sum(entry[2] for entry in details)
    global_frac = zeros_total / params_total if params_total else 0.0
    return global_frac, details

# --- plotting helpers embedded here (fixed confusion-matrix formatting) ---
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

def plot_curves(history, outdir):
    """Save train-vs-validation loss and accuracy curves into ``outdir``.

    ``history`` is read at ``epoch``, ``train_loss``, ``val_loss``, ``train_acc``
    and ``val_acc``; each series is plotted against ``history["epoch"]``.
    Writes ``loss_curves.png`` and ``acc_curves.png`` and closes both figures.
    """
    destination = Path(outdir)
    figures = [
        ("train_loss", "val_loss", "Loss", "Training vs Validation Loss", "loss_curves.png"),
        ("train_acc", "val_acc", "Accuracy", "Training vs Validation Accuracy", "acc_curves.png"),
    ]
    for train_key, val_key, y_label, heading, filename in figures:
        plt.figure()
        plt.plot(history["epoch"], history[train_key], label=train_key)
        plt.plot(history["epoch"], history[val_key], label=val_key)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()
        plt.grid(True, linestyle=":")
        plt.title(heading)
        plt.tight_layout()
        plt.savefig(destination / filename)
        plt.close()

def plot_confusion(cm, classes, outpath, normalize=False, title="Confusion Matrix"):
    """Render an annotated confusion-matrix heat-map and save it to ``outpath``.

    Args:
        cm: square matrix-like (rows = true label, columns = predicted label).
        classes: tick labels, one per row/column.
        outpath: image destination; missing parent directories are created.
        normalize: if True, divide each row by its sum (all-zero rows become 0)
            and annotate with two decimals; otherwise annotate integers
            (float input is rounded first).
        title: figure title.
    """
    matrix = np.array(cm)
    if not normalize:
        if matrix.dtype.kind == 'f':
            matrix = np.rint(matrix).astype(np.int64)
        cell_fmt = "d"
    else:
        as_float = matrix.astype(np.float64)
        with np.errstate(all='ignore'):
            matrix = np.nan_to_num(as_float / as_float.sum(axis=1, keepdims=True))
        cell_fmt = ".2f"

    plt.figure(figsize=(6.5, 5.5))
    plt.imshow(matrix, interpolation='nearest')
    plt.title(title)
    plt.colorbar()
    tick_pos = np.arange(len(classes))
    plt.xticks(tick_pos, classes, rotation=45, ha="right")
    plt.yticks(tick_pos, classes)
    # np.ndindex walks row-major: (0, 0), (0, 1), ... like the nested loops it replaces.
    for r, c in np.ndindex(matrix.shape):
        plt.text(c, r, format(matrix[r, c], cell_fmt), ha="center", va="center")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    Path(outpath).parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(outpath)
    plt.close()

def plot_per_class_bars(report_dict, outstem):
    """Save per-class precision, recall and F1 bar charts for the 10 CIFAR-10 classes.

    ``report_dict`` must carry ``y_true`` and ``y_pred`` arrays (e.g. from
    ``evaluate(..., return_preds=True)``); if either is missing or None the
    function returns without plotting. Files are written to
    ``{outstem}_precision.png``, ``{outstem}_recall.png`` and ``{outstem}_f1.png``.
    """
    y_true = report_dict.get("y_true")
    y_pred = report_dict.get("y_pred")
    if y_true is None or y_pred is None:
        return
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=list(range(10)), zero_division=0)
    positions = np.arange(len(CIFAR10_CLASSES))
    charts = (
        (precision, "Precision", "Per-class Precision", "precision"),
        (recall, "Recall", "Per-class Recall", "recall"),
        (f1, "F1-score", "Per-class F1-score", "f1"),
    )
    for values, y_label, heading, suffix in charts:
        plt.figure()
        plt.bar(positions, values)
        plt.xticks(positions, CIFAR10_CLASSES, rotation=45, ha="right")
        plt.ylabel(y_label)
        plt.title(heading)
        plt.tight_layout()
        plt.savefig(f"{outstem}_{suffix}.png")
        plt.close()

# ---------------- run ----------------
seed_everything(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

train_loader, val_loader, test_loader = build_dataloaders("/content/data", BATCH)

model = build_vgg19(num_classes=10, pretrained=False).to(device)
tot, trn = count_params(model)
print(f"VGG19 params: total={tot/1e6:.2f}M, trainable={trn/1e6:.2f}M")

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
scheduler = StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)
# Mixed precision only on CUDA. torch.amp.GradScaler("cuda") replaces the deprecated
# torch.cuda.amp.GradScaler (its deprecation warning appears in this cell's output).
# On CPU no scaler is passed, so training runs in plain fp32 rather than CPU fp16 autocast.
scaler = torch.amp.GradScaler("cuda") if device.type == "cuda" else None

history = {"epoch": [], "train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "lr": []}
best_val_f1 = -1.0
best_state = copy.deepcopy(model.state_dict())

print("\n=== TRAINING ===")
for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    # LR actually used this epoch; read it before scheduler.step() advances it
    # (reading it afterwards logged the *next* epoch's LR).
    epoch_lr = optimizer.param_groups[0]["lr"]
    tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler)
    val_rep = evaluate(model, val_loader, device)
    scheduler.step()

    history["epoch"].append(epoch)
    history["train_loss"].append(tr_loss)
    history["train_acc"].append(tr_acc)
    history["val_loss"].append(val_rep["loss"])
    history["val_acc"].append(val_rep["accuracy"])
    history["lr"].append(epoch_lr)

    # Keep the weights with the best validation macro-F1 seen so far.
    if val_rep["f1_macro"] > best_val_f1:
        best_val_f1 = val_rep["f1_macro"]
        best_state = copy.deepcopy(model.state_dict())

    dt = time.time() - t0
    print(f"Epoch {epoch:02d}/{EPOCHS} | tr_loss {tr_loss:.4f} tr_acc {tr_acc:.4f} "
          f"| val_loss {val_rep['loss']:.4f} val_acc {val_rep['accuracy']:.4f} "
          f"| val_f1(macro) {val_rep['f1_macro']:.4f} | lr {epoch_lr:.2e} | {dt:.1f}s")

# Load best weights by validation macro-F1
model.load_state_dict(best_state)

# Baseline test
print("\n=== TEST: Baseline (no pruning) ===")
base_rep = evaluate(model, test_loader, device, return_preds=True)
print(base_rep["class_report"])
print(f"Acc: {base_rep['accuracy']:.4f} | F1(macro): {base_rep['f1_macro']:.4f}")

# Prune a copy of the best model for inference (structured or not, per STRUCTURED)
prune_kind = "structured" if STRUCTURED else "unstructured"
print(f"\n=== APPLY {prune_kind.upper()} PRUNING (inference) ===")
pruned_model = prune_for_inference(model, amount=PRUNE_AMOUNT, structured=STRUCTURED).to(device)
tot_p, trn_p = count_params(pruned_model)
print(f"Pruned params: total={tot_p/1e6:.2f}M, trainable={trn_p/1e6:.2f}M "
      f"(amount={PRUNE_AMOUNT}, structured={STRUCTURED})")

# Sparsity report
global_sparsity, details = report_sparsity(pruned_model)
# Pruning zeroes entries without shrinking any tensor, so `total` above equals the
# baseline count; the non-zero count is the size figure that actually changed.
nonzero_p = sum(numel - zeros for _name, zeros, numel, _frac in details)
print(f"Non-zero params after pruning: {nonzero_p/1e6:.2f}M of {tot_p/1e6:.2f}M")
print(f"Global weight sparsity after pruning: {100*global_sparsity:.2f}%")
topk = sorted(details, key=lambda x: x[3], reverse=True)[:5]
print("Top sparse params:")
for name, z, n, frac in topk:
    print(f"  {name:45s} {100*frac:5.1f}% zeros  ({z}/{n})")

# Pruned test
print(f"\n=== TEST: Pruned ({prune_kind}) ===")
pruned_rep = evaluate(pruned_model, test_loader, device, return_preds=True)
print(pruned_rep["class_report"])
print(f"Acc: {pruned_rep['accuracy']:.4f} | F1(macro): {pruned_rep['f1_macro']:.4f}")

# Optional short fine-tune to recover accuracy
if DO_FT:
    print("\n=== FINE-TUNE (structured pruned model) ===")
    ft_model = copy.deepcopy(pruned_model).to(device)
    # prune_for_inference already made its masks permanent (prune.remove), so plain
    # SGD would regrow every zeroed weight and "Pruned+FT" would end up dense.
    # Re-attach the current zero pattern as fixed masks for the duration of FT.
    ft_layers = [mod for mod in ft_model.modules() if isinstance(mod, (nn.Conv2d, nn.Linear))]
    for mod in ft_layers:
        prune.custom_from_mask(mod, name="weight", mask=(mod.weight != 0))
    # Build the optimizer *after* masking: the trainable tensor is now `weight_orig`.
    optimizer_ft = optim.SGD(ft_model.parameters(), lr=FINE_TUNE_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler_ft = StepLR(optimizer_ft, step_size=max(1, FINE_TUNE_EPOCHS//2), gamma=0.5)
    for e in range(1, FINE_TUNE_EPOCHS+1):
        tr_loss, tr_acc = train_one_epoch(ft_model, train_loader, criterion, optimizer_ft, device)
        val_rep = evaluate(ft_model, val_loader, device)
        scheduler_ft.step()
        print(f"FT {e:02d}/{FINE_TUNE_EPOCHS} | tr_loss {tr_loss:.4f} tr_acc {tr_acc:.4f} "
              f"| val_loss {val_rep['loss']:.4f} val_acc {val_rep['accuracy']:.4f} "
              f"| val_f1(macro) {val_rep['f1_macro']:.4f}")
    # Bake the masks back into plain `weight` tensors (same state_dict keys as before).
    for mod in ft_layers:
        prune.remove(mod, "weight")
    ft_sparsity, _ft_details = report_sparsity(ft_model)
    print(f"Global weight sparsity after fine-tune: {100*ft_sparsity:.2f}%")

    print("\n=== TEST: Pruned + Fine-Tuned ===")
    ft_rep = evaluate(ft_model, test_loader, device, return_preds=True)
    print(ft_rep["class_report"])
    print(f"Acc: {ft_rep['accuracy']:.4f} | F1(macro): {ft_rep['f1_macro']:.4f}")

# Comparison table
def row(name, rep):
    """Format one comparison-table row: ``name``, then accuracy and macro P/R/F1 to 4 decimals."""
    metric_keys = ("accuracy", "precision_macro", "recall_macro", "f1_macro")
    return [name] + [f"{rep[key]:.4f}" for key in metric_keys]
# Left-aligned text table comparing test metrics across model variants.
header = ["Model", "Acc", "Prec(mac)", "Rec(mac)", "F1(mac)"]
labelled_reports = [("Baseline", base_rep), ("Pruned", pruned_rep)]
if DO_FT:
    labelled_reports.append(("Pruned+FT", ft_rep))
rows = [header] + [row(label, rep) for label, rep in labelled_reports]
# Column width = longest cell (header included) in that column.
colw = [max(len(cells[i]) for cells in rows) for i in range(len(header))]
print("\n=== COMPARISON (test) ===")
print(" | ".join(map(str.ljust, header, colw)))
print("-+-".join("-" * width for width in colw))
for cells in rows[1:]:
    print(" | ".join(cell.ljust(width) for cell, width in zip(cells, colw)))

# Save artifacts
# Honour the ARTIFACT_DIR chosen in the setup cell (it points at Google Drive when
# USE_DRIVE=True); only fall back to the local default when this cell runs alone.
ARTIFACT_DIR = globals().get("ARTIFACT_DIR", "/content/artifacts")
outdir = Path(ARTIFACT_DIR); outdir.mkdir(exist_ok=True, parents=True)
torch.save(model.state_dict(), outdir/"vgg19_cifar10_baseline.pt")
torch.save(pruned_model.state_dict(), outdir/"vgg19_cifar10_pruned.pt")
if DO_FT:
    torch.save(ft_model.state_dict(), outdir/"vgg19_cifar10_pruned_finetuned.pt")
# `history` is a dict, so np.save pickles it; reload with
# np.load(path, allow_pickle=True).item()
np.save(outdir/"history.npy", history, allow_pickle=True)

# Plots
if PLOT:
    plot_curves(history, outdir)
    # (file tag, title label, test report) per evaluated variant. Each variant gets a raw
    # and a row-normalised confusion matrix plus per-class P/R/F1 bars. The Pruned+FT
    # variant previously got no raw matrix; it now gets cm_pruned_ft_raw.png too.
    variants = [("baseline", "Baseline", base_rep), ("pruned", "Pruned", pruned_rep)]
    if DO_FT:
        variants.append(("pruned_ft", "Pruned + FT", ft_rep))
    for tag, label, rep in variants:
        plot_confusion(rep["confusion_matrix"], CIFAR10_CLASSES, outdir/f"cm_{tag}_raw.png",
                       normalize=False, title=f"Confusion Matrix ({label})")
        plot_confusion(rep["confusion_matrix"], CIFAR10_CLASSES, outdir/f"cm_{tag}_norm.png",
                       normalize=True, title=f"Normalized Confusion Matrix ({label})")
        plot_per_class_bars(rep, str(outdir/f"{tag}_perclass"))

print("\nArtifacts saved in:", outdir)


Device: cuda


100%|██████████| 170M/170M [00:05<00:00, 29.9MB/s]


VGG19 params: total=139.61M, trainable=139.61M

=== TRAINING ===


  scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))


Epoch 01/40 | tr_loss 2.2785 tr_acc 0.1222 | val_loss 2.1949 val_acc 0.1632 | val_f1(macro) 0.0911 | 53.5s
Epoch 02/40 | tr_loss 2.1103 tr_acc 0.1932 | val_loss 2.1268 val_acc 0.1872 | val_f1(macro) 0.1162 | 19.9s
Epoch 03/40 | tr_loss 1.9508 tr_acc 0.2492 | val_loss 1.8555 val_acc 0.2914 | val_f1(macro) 0.2541 | 19.9s
Epoch 04/40 | tr_loss 1.8114 tr_acc 0.3148 | val_loss 1.7737 val_acc 0.3420 | val_f1(macro) 0.3257 | 20.0s
Epoch 05/40 | tr_loss 1.6847 tr_acc 0.3730 | val_loss 1.5493 val_acc 0.4256 | val_f1(macro) 0.4064 | 19.8s
Epoch 06/40 | tr_loss 1.5340 tr_acc 0.4381 | val_loss 1.4586 val_acc 0.4774 | val_f1(macro) 0.4716 | 20.5s
Epoch 07/40 | tr_loss 1.3982 tr_acc 0.4917 | val_loss 1.3322 val_acc 0.5180 | val_f1(macro) 0.5049 | 19.9s
Epoch 08/40 | tr_loss 1.2649 tr_acc 0.5438 | val_loss 1.2299 val_acc 0.5602 | val_f1(macro) 0.5561 | 20.2s
Epoch 09/40 | tr_loss 1.1888 tr_acc 0.5725 | val_loss 1.2380 val_acc 0.5622 | val_f1(macro) 0.5522 | 20.3s
Epoch 10/40 | tr_loss 1.1331 tr_acc 0