In [None]:
import os, time, json
import copy
import math
import argparse
from collections import OrderedDict, defaultdict
from typing import Dict, Tuple, List, Any

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

import torch
from torch.utils.data import random_split
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.serialization import add_safe_globals


**Task 1**

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"

normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

VAL_SIZE = 5000
TEST_SIZE = len(CIFAR10_testset) - VAL_SIZE

CIFAR10_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
CIFAR10_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

CIFAR10_valset, CIFAR10_testset_final = random_split(CIFAR10_testset, [VAL_SIZE, TEST_SIZE], generator=torch.Generator().manual_seed(42))

CIFAR10_trainloader = DataLoader(CIFAR10_trainset, batch_size=64, shuffle=True, num_workers=2)

CIFAR10_valloader = DataLoader(CIFAR10_valset, batch_size=64, shuffle=False, num_workers=2)
CIFAR10_testloader = DataLoader(CIFAR10_testset_final, batch_size=64, shuffle=False, num_workers=2)

CIFAR10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_classes = len(CIFAR10_classes)

vgg_model = torchvision.models.vgg11_bn(pretrained=False)
vgg_model.classifier[6] = nn.Linear(in_features=vgg_model.classifier[6].in_features, out_features=num_classes)
vgg_model.to(device)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU(inplace=True)
    (11): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (13): ReLU(inplace=True)
    (14): MaxPool2d(ke

*Unstructured Pruning*

In [None]:
def get_prunable_param_names(model):
    prunable = []
    for name, param in model.named_parameters():
        if name.endswith(".weight"):
            comps = name.split(".")
            module = model
            for comp in comps[:-1]:
                if hasattr(module, comp):
                    module = getattr(module, comp)
                else:
                    try:
                        idx = int(comp)
                        module = module[idx]
                    except Exception:
                        module = None
                        break
            if module is not None:
                if isinstance(module, (nn.Conv2d, nn.Linear)):
                    prunable.append(name)
    return prunable

def init_masks(model, prunable_names):
  masks = {}
  for name, param in model.named_parameters():
    if name in prunable_names:
      masks[name] = torch.ones_like(param.data, dtype=torch.uint8).cpu()
  return masks

def apply_masks(model, masks):
  with torch.no_grad():
    for name, param in model.named_parameters():
      if name in masks:
        mask = masks[name].to(param.device).to(param.dtype)
        param.data.mul_(mask)

def count_global_sparsity(masks):
  total = sum(m.numel() for m in masks.values())
  nonzero = sum(int(m.sum().item()) for m in masks.values())
  return 1.0 - (nonzero/total), nonzero, total

In [None]:
def prune_layer_by_magnitude(model, masks, layer_name, target_sparsity):
  assert 0.0 <= target_sparsity <= 1.0
  param = dict(model.named_parameters())[layer_name]
  w = param.data.detach().cpu().clone()
  mask = masks[layer_name].cpu().clone()
  total = mask.numel()
  target_pruned = int(math.floor(target_sparsity * total))
  currently_pruned = int((mask==0).sum().item())
  need_prune = max(0, target_pruned - currently_pruned)
  if need_prune == 0:
    return
  active_flat_idx = torch.nonzero(mask.view(-1)).squeeze()
  if active_flat_idx.numel() == 0:
    return
  active_vals = w.view(-1).abs()[active_flat_idx]
  if need_prune >= active_vals.numel():
    new_mask = torch.zeros_like(mask)
    masks[layer_name] = new_mask.to(torch.uint8)
    return
  _, order = torch.sort(active_vals)
  selected = active_flat_idx[order[:need_prune]]
  new_mask_flat = mask.view(-1).clone()
  new_mask_flat[selected] = 0
  masks[layer_name] = new_mask_flat.view(mask.shape).to(torch.uint8)


In [18]:
def train_one_epoch(model, dataloader, criterion, optimizer, masks=None):
  model.train()
  running_loss = 0.0
  correct = 0
  total = 0
  for images, labels in dataloader:
    images, labels = images.to(device), labels.to(device)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    if masks is not None:
      apply_masks(model, masks)
    batch_size = images.size(0)
    running_loss += loss.item() * batch_size
    _, preds = outputs.max(1)
    correct += preds.eq(labels).sum().item()
    total += batch_size
  return running_loss / total, 100.0 * correct / total

@torch.no_grad()
def evaluate(model, dataloader):
  model.eval()
  total = 0
  correct = 0
  for images, labels in dataloader:
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    _, preds = outputs.max(1)
    correct += preds.eq(labels).sum().item()
    total += images.size(0)
  return 100.0 * correct / total

In [None]:
def save_layer_histogram(model, masks, layer_name, out_dir, tag):
  os.makedirs(out_dir, exist_ok=True)
  w = dict(model.named_parameters())[layer_name].data.detach().cpu().numpy().flatten()
  m = masks[layer_name].cpu().numpy().flatten()
  active = w[m==1]
  plt.figure(figsize=(6,4))
  plt.hist(w, bins=120, alpha=0.4, label='all', density=True)
  if active.size>0:
    plt.hist(active, bins=120, alpha=0.6, label='active (mask=1)', density=True)
  plt.title(f"{layer_name} ({tag})")
  plt.legend()
  plt.tight_layout()
  fname = os.path.join(out_dir, f"{layer_name.replace('.','_')}_{tag}.png")
  plt.savefig(fname)
  plt.close()

def save_all_histograms(model, masks, prunable_names, out_dir, tag):
  for name in prunable_names:
    save_layer_histogram(model, masks, name, out_dir, tag)\

def display_figure(fig):
  try:
    from IPython.display import display
    ipy_available = True
  except ImportError:
      ipy_available = False

  try:
      if ipy_available and fig is not None:
          display(fig)
      else:
          plt.show()
  except Exception:
      plt.show()

In [None]:
def run_sensitivity(baseline_checkpoint, prunable_names, dataloaders, device, sparsity_list, short_finetune_epochs=3, out_dir='sensitivity'):
    os.makedirs(out_dir, exist_ok=True)
    trainloader, testloader = dataloaders

    model = torchvision.models.vgg11_bn(pretrained=False)
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    ck = torch.load(baseline_checkpoint, map_location='cpu')
    model.load_state_dict(ck['model_state_dict'])
    model.to(device)

    baseline_acc = evaluate(model, testloader)  
    print("Baseline accuracy (loaded):", baseline_acc)

    results = {}
    criterion = nn.CrossEntropyLoss()

    for layer in prunable_names:
        results[layer] = {'sparsities': [], 'acc_no_ft': []}
        for s in sparsity_list:
            sp = float(s) / 100.0
            model.load_state_dict(ck['model_state_dict'])
            model.to(device)
            masks = init_masks(model, prunable_names)

            prune_layer_by_magnitude(model, masks, layer, sp)
            apply_masks(model, masks)

            acc_no = evaluate(model, testloader)

            results[layer]['sparsities'].append(s)
            results[layer]['acc_no_ft'].append(acc_no)
            print(f"[sens] {layer} @ {s}% -> no_ft {acc_no:.2f}")

        plt.figure(figsize=(6,4))
        plt.plot(results[layer]['sparsities'], results[layer]['acc_no_ft'], marker='o', label='no finetune')
        plt.xlabel('Sparsity (%)'); plt.ylabel('Val Accuracy (%)'); plt.title(layer)
        plt.legend(); plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f"{layer.replace('.','_')}_sensitivity.png"))
        plt.close()

    np.save(os.path.join(out_dir, 'sensitivity_results.npy'), results)
    print("Saved sensitivity results to", out_dir)

    avg_acc_drop = {}
    for layer, rec in results.items():
      base = baseline_acc
      mean_no_ft = np.mean(rec['acc_no_ft'])
      avg_acc_drop[layer] = base - mean_no_ft
    sorted_layers = sorted(avg_acc_drop.items(), key=lambda x: x[1])  
    ranked_path = os.path.join(out_dir, 'ranked_layers_by_sensitivity.json')
    json.dump(sorted_layers, open(ranked_path, 'w'), indent=2)
    print(f"Saved layer ranking (least sensitive first) → {ranked_path}")

    return results, baseline_acc

In [None]:
def train_baseline(baseline_path="baseline_vgg11.pth",
                   epochs=20,
                   lr=0.001,
                   batch_size=64,
                   momentum=0.9,
                   weight_decay=5e-4):
    vgg_model.to(device)
    p = next(vgg_model.parameters())
    print("Model param device:", p.device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(vgg_model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    best_acc = 0.0
    for epoch in range(epochs):
        vgg_model.train()
        total, correct, running_loss = 0, 0, 0.0
        for images, labels in CIFAR10_trainloader:
            images, labels = images.to(device), labels.to(device)   
            optimizer.zero_grad()
            outputs = vgg_model(images)                             
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)
            _, preds = outputs.max(1)
            total += labels.size(0)
            correct += preds.eq(labels).sum().item()

        train_acc = 100.0 * correct / total
        train_loss = running_loss / total

        vgg_model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in CIFAR10_valloader:
                images, labels = images.to(device), labels.to(device)
                outputs = vgg_model(images)
                _, preds = outputs.max(1)
                total += labels.size(0)
                correct += preds.eq(labels).sum().item()
        val_acc = 100.0 * correct / total

        print(f"[Epoch {epoch+1}/{epochs}] Train Acc: {train_acc:.2f}%  Val Acc: {val_acc:.2f}%")

        scheduler.step()

        if val_acc > best_acc:
            best_acc = val_acc
            state_dict_cpu = {k: v.cpu() for k, v in vgg_model.state_dict().items()}
            torch.save({
                'model_state_dict': state_dict_cpu,
                'num_classes': num_classes,
                'arch': 'vgg11_bn'
            }, baseline_path)
            print(f"Saved baseline clean state_dict -> {baseline_path}")
            print(f"Saved new best model ({best_acc:.2f}%)")

    print(f"Baseline training complete. Best validation accuracy: {best_acc:.2f}%")


In [None]:
import math, json, numpy as np

def suggest_unstructured_allocation(prunable_names, sensitivity_results, target_global_sparsity=0.70, per_layer_cap=0.95, anchor_sparsity=50, baseline_acc=None):
    name_to_param = {n: dict(vgg_model.named_parameters())[n] for n in prunable_names}
    layer_param_counts = {n: name_to_param[n].numel() for n in prunable_names}
    total_params = sum(layer_param_counts.values())
    target_prune_params = int(round(target_global_sparsity * total_params))

    scores = {}
    for n in prunable_names:
        res = sensitivity_results.get(n, None)
        if res:
            spars_list = res.get('sparsities', [])
            accs_no = res.get('acc_no_ft', []) or res.get('acc_after_ft', [])
            if len(spars_list) and len(accs_no):
                idx = min(range(len(spars_list)), key=lambda i: abs(spars_list[i] - anchor_sparsity))
                scores[n] = float(accs_no[idx])
            else:
                if accs_no:
                    scores[n] = float(max(accs_no))
                else:
                    scores[n] = -1.0
        else:
            scores[n] = -1.0

    if baseline_acc is None:
        all_accs = []
        for res in sensitivity_results.values():
            all_accs += res.get('acc_no_ft', []) or res.get('acc_after_ft', []) or []
        baseline_acc = float(max(all_accs)) if all_accs else None

    items = sorted(prunable_names, key=lambda x: scores.get(x, -1.0), reverse=True)

    to_prune = target_prune_params
    allocated = {n: 0 for n in prunable_names}
    for name in items:
        if to_prune <= 0:
            break
        layer_total = layer_param_counts[name]
        cap = int(math.floor(per_layer_cap * layer_total))
        take = min(cap, to_prune)
        allocated[name] = take
        to_prune -= take

    if to_prune > 0:
        for name in items:
            if to_prune <= 0:
                break
            layer_total = layer_param_counts[name]
            cap = int(math.floor(per_layer_cap * layer_total))
            avail = cap - allocated[name]
            if avail <= 0:
                continue
            add = min(avail, to_prune)
            allocated[name] += add
            to_prune -= add

    per_layer_sparsities = {}
    for name in prunable_names:
        frac = float(allocated.get(name, 0)) / float(layer_param_counts[name]) if layer_param_counts[name] > 0 else 0.0
        per_layer_sparsities[name] = min(max(frac, 0.0), per_layer_cap)

    est_pruned = sum(int(round(per_layer_sparsities[n] * layer_param_counts[n])) for n in prunable_names)
    info = {
        'total_params': int(total_params),
        'target_prune_params': int(target_prune_params),
        'estimated_pruned_params': int(est_pruned),
        'estimated_global_sparsity': float(est_pruned / total_params) if total_params > 0 else 0.0,
        'per_layer_sparsities': per_layer_sparsities,
        'scores_anchor': {n: float(scores.get(n, -1.0)) for n in prunable_names},
        'baseline_acc_used': float(baseline_acc) if baseline_acc is not None else None
    }

    print(f"[allocator] total_params={total_params}, target_prune_params={target_prune_params}, estimated_pruned={est_pruned}, est_global_sparsity={info['estimated_global_sparsity']:.4f}")
    return per_layer_sparsities, info


In [None]:
def apply_per_layer_and_finetune(baseline_checkpoint, per_layer_sparsities, dataloaders, device, finetune_epochs, out_dir='final_pruned'):

    os.makedirs(out_dir, exist_ok=True)
    trainloader, testloader = dataloaders

    model = torchvision.models.vgg11_bn(pretrained=False)
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    ckpt = torch.load(baseline_checkpoint, map_location='cpu')
    model.load_state_dict(ckpt['model_state_dict'])
    model = model.to(device)

    prunable = get_prunable_param_names(model)
    masks = {name: torch.ones_like(p, dtype=torch.uint8).cpu()
             for name, p in model.named_parameters() if name in prunable}

    for name, s in per_layer_sparsities.items():
        if name not in masks:
            print(f"[WARN] Skipping {name}: not found or not prunable.")
            continue
        prune_layer_by_magnitude(model, masks, name, s)

    apply_masks(model, masks)
    initial_acc = evaluate(model, testloader)
    print(f"Initial accuracy after applying sparsities: {initial_acc:.2f}%")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    for epoch in range(finetune_epochs):
        train_loss, train_acc = train_one_epoch(model, trainloader, criterion, optimizer, masks=masks)
        val_acc = evaluate(model, testloader)
        print(f"[Finetune {epoch+1}/{finetune_epochs}] Train Acc={train_acc:.2f}%, Val Acc={val_acc:.2f}%")

    total_params, nonzero_params = 0, 0
    for m in masks.values():
        total_params += m.numel()
        nonzero_params += (m == 1).sum().item()
    global_sparsity = 1.0 - (nonzero_params / total_params)
    print(f"Final global sparsity: {global_sparsity*100:.2f}%")

    final_sd = {k: v.cpu() for k, v in model.state_dict().items()}
    torch.save({
        'model_state_dict': final_sd,
        'num_classes': num_classes,
        'arch': 'vgg11_bn',
        'masks': {k: v.cpu().astype('uint8') if isinstance(v, np.ndarray) else v.cpu().numpy() for k,v in masks.items()},
        'global_sparsity': global_sparsity
    }, os.path.join(out_dir, 'final_pruned_state_dict.pth'))
    print("Saved final unstructured state_dict:", os.path.join(out_dir, 'final_pruned_state_dict.pth'))

    return model, masks, global_sparsity


In [None]:
def run_unstructured_pipeline(baseline_checkpoint, sparsity_list_percent=[10, 30, 50, 70], short_finetune_epochs=3, final_finetune_epochs=15, out_dir='unstructured_outputs', auto_apply_suggestion=True):
  
    os.makedirs(out_dir, exist_ok=True)

    if not os.path.exists(baseline_checkpoint):
        raise FileNotFoundError(f"Baseline checkpoint not found: {baseline_checkpoint}")

    ck = torch.load(baseline_checkpoint, map_location='cpu')
    try:
        vgg_model.load_state_dict(ck['model_state_dict'])
    except Exception:
        vgg_model.load_state_dict(ck, strict=False)
    vgg_model.to(device)

    prunable_names = get_prunable_param_names(vgg_model)

    masks_before = init_masks(vgg_model, prunable_names)
    save_all_histograms(vgg_model, masks_before, prunable_names, os.path.join(out_dir,'histograms_before'), tag='before')

    trainloader, testloader = CIFAR10_trainloader, CIFAR10_testloader
    print("Running sensitivity analysis (unstructured). This may take some time...")
    results, baseline_acc = run_sensitivity(baseline_checkpoint, prunable_names, (trainloader, testloader),
                                           device, sparsity_list_percent, short_finetune_epochs,
                                           out_dir=os.path.join(out_dir,'sensitivity'))
    print("Sensitivity done. Baseline accuracy:", baseline_acc)

    per_layer_suggestion, info = suggest_unstructured_allocation(prunable_names, results, target_global_sparsity=0.70, anchor_sparsity=50)
    suggestion_path = os.path.join(out_dir, 'suggested_unstruct_sparsities.json')
    with open(suggestion_path, 'w') as f:
        json.dump(per_layer_suggestion, f, indent=2)
    print(f"Saved suggested per-layer sparsities to {suggestion_path}")

    if not auto_apply_suggestion:
        print("auto_apply_suggestion=False: returning suggestion path for manual editing.")
        return suggestion_path

    pruned_model, masks_final, spars = apply_per_layer_and_finetune(baseline_checkpoint, per_layer_suggestion,
                                                                   (trainloader, testloader), device,
                                                                   finetune_epochs=final_finetune_epochs,
                                                                   out_dir=os.path.join(out_dir,'final'))
    
    save_all_histograms(pruned_model, masks_final, prunable_names, os.path.join(out_dir,'histograms_after'), tag='after')
    print("Final unstructured model saved. Global sparsity:", spars)
    return pruned_model, masks_final, spars

In [None]:
baseline_ckpt_path = "baseline_vgg11.pth"   

train_baseline( baseline_path=baseline_ckpt_path, epochs=15, batch_size=64)


Model param device: cuda:0
[Epoch 1/15] Train Acc: 38.02%  Val Acc: 51.94%
Saved baseline clean state_dict -> baseline_vgg11.pth
✓ Saved new best model (51.94%)
[Epoch 2/15] Train Acc: 53.41%  Val Acc: 60.48%
Saved baseline clean state_dict -> baseline_vgg11.pth
✓ Saved new best model (60.48%)
[Epoch 3/15] Train Acc: 59.80%  Val Acc: 65.38%
Saved baseline clean state_dict -> baseline_vgg11.pth
✓ Saved new best model (65.38%)
[Epoch 4/15] Train Acc: 64.26%  Val Acc: 69.24%
Saved baseline clean state_dict -> baseline_vgg11.pth
✓ Saved new best model (69.24%)
[Epoch 5/15] Train Acc: 67.56%  Val Acc: 70.26%
Saved baseline clean state_dict -> baseline_vgg11.pth
✓ Saved new best model (70.26%)
[Epoch 6/15] Train Acc: 70.17%  Val Acc: 72.88%
Saved baseline clean state_dict -> baseline_vgg11.pth
✓ Saved new best model (72.88%)
[Epoch 7/15] Train Acc: 72.04%  Val Acc: 71.44%
[Epoch 8/15] Train Acc: 73.58%  Val Acc: 74.86%
Saved baseline clean state_dict -> baseline_vgg11.pth
✓ Saved new best mo

In [None]:
baseline_ckpt_path = "baseline_vgg11.pth"   

results = run_unstructured_pipeline( baseline_checkpoint=baseline_ckpt_path, sparsity_list_percent=[50, 70], short_finetune_epochs=0, final_finetune_epochs=10, out_dir='unstructured_outputs')

Running sensitivity analysis (unstructured). This may take some time...
Baseline accuracy (loaded): 80.19
[sens] features.0.weight @ 50% -> no_ft 74.16
[sens] features.0.weight @ 70% -> no_ft 49.70
[sens] features.4.weight @ 50% -> no_ft 79.27
[sens] features.4.weight @ 70% -> no_ft 71.81
[sens] features.8.weight @ 50% -> no_ft 78.64
[sens] features.8.weight @ 70% -> no_ft 69.29
[sens] features.11.weight @ 50% -> no_ft 78.66
[sens] features.11.weight @ 70% -> no_ft 73.23
[sens] features.15.weight @ 50% -> no_ft 78.98
[sens] features.15.weight @ 70% -> no_ft 75.73
[sens] features.18.weight @ 50% -> no_ft 79.57
[sens] features.18.weight @ 70% -> no_ft 77.63
[sens] features.22.weight @ 50% -> no_ft 79.63
[sens] features.22.weight @ 70% -> no_ft 77.99
[sens] features.25.weight @ 50% -> no_ft 80.23
[sens] features.25.weight @ 70% -> no_ft 79.97
[sens] classifier.0.weight @ 50% -> no_ft 80.10
[sens] classifier.0.weight @ 70% -> no_ft 80.20
[sens] classifier.3.weight @ 50% -> no_ft 80.18
[sen

In [None]:
pruned_model = results[0]
print(f"Test Accuracy of pruned model: {evaluate(pruned_model, CIFAR10_testloader)}%")

Test Accuracy of pruned model: 83.87%


In [None]:
def load_unstructured_checkpoint(path, device=None):
    if device is None:
        device = globals().get('device', 'cpu')
    device = torch.device(device)
    ck = torch.load(path, map_location='cpu')  
    state = ck.get('model_state_dict', ck)
    num_classes = ck.get('num_classes', None)

    model = torchvision.models.vgg11_bn(pretrained=False)
    if num_classes is not None and model.classifier[6].out_features != int(num_classes):
        model.classifier[6] = torch.nn.Linear(model.classifier[6].in_features, int(num_classes))

    new_state = {}
    for k, v in state.items():
        new_k = k[len('module.'):] if k.startswith('module.') else k
        new_state[new_k] = v

    try:
        model.load_state_dict(new_state, strict=True)
    except RuntimeError as e:
        print("Strict load failed for unstructured checkpoint:", e)
        model.load_state_dict(new_state, strict=False)
        print("Loaded with strict=False (check missing/unexpected keys).")

    model.to(device)
    model.eval()
    return model, ck

def load_structured_checkpoint(path, device=None):
    if device is None:
        device = globals().get('device', 'cpu')
    device = torch.device(device)
    ck = torch.load(path, map_location='cpu')
    state = ck.get('model_state_dict', ck)
    kept_indices = ck.get('kept_indices', None)
    if kept_indices is None:
        raise RuntimeError("structured checkpoint missing 'kept_indices' — cannot rebuild pruned architecture.")

    normalized = {}
    for k, v in kept_indices.items():
        key = k
        if key.endswith('.weight') or key.endswith('.bias'):
            key = '.'.join(key.split('.')[:2])
        if isinstance(v, torch.Tensor):
            lst = v.cpu().tolist()
        elif isinstance(v, (list, tuple)):
            lst = [int(x) for x in v]
        elif isinstance(v, np.ndarray):
            lst = v.tolist()

def _resolve_kaggle_path(path: str) -> str:
    if os.path.exists(path):
        return path
    base = '/kaggle/input'
    if not os.path.exists(base):
        return path
    candidate = os.path.join(base, path)
    if os.path.exists(candidate):
        return candidate
    filename = os.path.basename(path)
    for root, _, files in os.walk(base):
        if filename in files:
            return os.path.join(root, filename)
    return path

def _torch_load_robust(path: str):
    try:
        return torch.load(path, map_location='cpu')
    except Exception as e:
        msg = str(e)
        if 'WeightsUnpickler' in msg or 'multiarray._reconstruct' in msg or 'UnpicklingError' in msg:
            print("WARN: torch.load raised an unpickling/weights-only error. Retrying with a safe-globals allowlist (only do this for trusted checkpoints).")
            try:
                add_safe_globals([np.core.multiarray._reconstruct])
            except Exception:
                try:
                    add_safe_globals([np._core.multiarray._reconstruct])
                except Exception:
                    pass
            return torch.load(path, map_location='cpu', weights_only=False)
        raise

def load_unstructured_checkpoint(path: str, device=None) -> Tuple[torch.nn.Module, Dict[str,Any]]:
    resolved = _resolve_kaggle_path(path)
    if not os.path.exists(resolved):
        raise FileNotFoundError(f"Checkpoint not found: {resolved}")
    ck = _torch_load_robust(resolved)
    state = ck.get('model_state_dict', ck) if isinstance(ck, dict) else ck
    num_classes = ck.get('num_classes', None) if isinstance(ck, dict) else None

    model = torchvision.models.vgg11_bn(pretrained=False)
    if num_classes is not None:
        numc = int(num_classes)
        if model.classifier[6].out_features != numc:
            model.classifier[6] = nn.Linear(model.classifier[6].in_features, numc)

    new_state = {}
    for k, v in state.items():
        new_k = k[len('module.'): ] if k.startswith('module.') else k
        new_state[new_k] = v

    try:
        model.load_state_dict(new_state, strict=True)
    except Exception as e:
        print("Strict load failed for unstructured checkpoint:", e)
        res = model.load_state_dict(new_state, strict=False)
        print("Loaded with strict=False. Missing keys:", res.missing_keys, "Unexpected keys:", res.unexpected_keys)

    if device is None:
        device = globals().get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
    model.to(torch.device(device))
    model.eval()
    return model, ck


def load_structured_checkpoint(path: str, device=None) -> Tuple[torch.nn.Module, Dict[str, list], Dict[str,Any]]:
    resolved = _resolve_kaggle_path(path)
    if not os.path.exists(resolved):
        raise FileNotFoundError(f"Checkpoint not found: {resolved}")
    ck = _torch_load_robust(resolved)
    state = ck.get('model_state_dict', ck) if isinstance(ck, dict) else ck
    kept_indices = ck.get('kept_indices', None) if isinstance(ck, dict) else None
    if kept_indices is None:
        raise RuntimeError("Structured checkpoint missing 'kept_indices' — cannot rebuild pruned architecture.")

    normalized = {}
    for k, v in kept_indices.items():
        key = k
        if key.endswith('.weight') or key.endswith('.bias') or key.endswith('.running_mean') or key.endswith('.running_var'):
            key = '.'.join(key.split('.')[:2])
        if isinstance(v, torch.Tensor):
            lst = v.cpu().tolist()
        elif isinstance(v, np.ndarray):
            lst = v.tolist()
        elif isinstance(v, (list, tuple)):
            lst = [int(x) for x in v]
        else:
            lst = [int(v)]
        normalized[key] = lst

    pruned_model, used_kept = build_pruned_vgg(normalized)
    pruned_model.to('cpu') 
    
    new_state = {}
    for k, v in state.items():
        new_k = k[len('module.'):] if k.startswith('module.') else k
        new_state[new_k] = v

    try:
        pruned_model.load_state_dict(new_state, strict=True)
    except Exception as e:
        print("Strict load failed for structured checkpoint (expected if shapes/buffers differ):", e)
        res = pruned_model.load_state_dict(new_state, strict=False)
        print("Loaded with strict=False. Missing keys:", res.missing_keys, "Unexpected keys:", res.unexpected_keys)

    if device is None:
        device = globals().get('device', 'cuda' if torch.cuda.is_available() else 'cpu')
    pruned_model.to(torch.device(device))
    pruned_model.eval()
    return pruned_model, normalized, ck

model_u, ck_u = load_unstructured_checkpoint('/kaggle/input/unstructured/final/final_pruned_state_dict.pth', device=device)

WARN: torch.load raised an unpickling/weights-only error. Retrying with a safe-globals allowlist (only do this for trusted checkpoints).
done


In [None]:

def _torch_load_robust(path):
    try:
        return torch.load(path, map_location='cpu')
    except Exception as e:
        msg = str(e)
        if 'WeightsUnpickler' in msg or 'multiarray._reconstruct' in msg or 'UnpicklingError' in msg:
            print("WARN: torch.load raised an unpickling/weights-only error. Retrying with a safe-globals allowlist.")
            try:
                add_safe_globals([np.core.multiarray._reconstruct])
            except Exception:
                try:
                    add_safe_globals([np._core.multiarray._reconstruct])
                except Exception:
                    pass
            return torch.load(path, map_location='cpu', weights_only=False)
        raise

def _extract_weight_values_from_state_dict(sd, max_total_elems=5_000_000):
    vals_list = []
    total_elems = 0
    for k, v in sd.items():
        if 'weight' not in k:
            continue
        try:
            if isinstance(v, torch.Tensor):
                arr = v.detach().cpu().numpy()
            elif isinstance(v, np.ndarray):
                arr = v
            elif isinstance(v, (list, tuple)):
                arr = np.asarray(v)
            else:
                continue
        except Exception:
            continue
        if arr.size == 0:
            continue
        vals_list.append(arr.ravel())
        total_elems += arr.size

    if len(vals_list) == 0:
        return np.array([], dtype=np.float32)

    if total_elems <= max_total_elems:
        return np.concatenate(vals_list).astype(np.float32)

    print(f"Note: total weight elements = {total_elems:,} > {max_total_elems:,}; sampling for plotting.")
    concat_len = total_elems
    rng = np.random.default_rng(42)
    idx = rng.choice(concat_len, size=max_total_elems, replace=False)
    idx.sort()
    out = np.empty(len(idx), dtype=np.float32)
    pos = 0
    sample_ptr = 0
    for arr in vals_list:
        n = arr.size
        lo = np.searchsorted(idx, pos, side='left')
        hi = np.searchsorted(idx, pos+n, side='left')
        if lo < hi:
            relative = idx[lo:hi] - pos
            out[sample_ptr:sample_ptr + (hi-lo)] = arr.ravel()[relative]
            sample_ptr += (hi-lo)
        pos += n
    return out[:sample_ptr].astype(np.float32)

def plot_weight_distributions(baseline_ckpt, pruned_ckpt, out_path='weight_dists.png', bins=200, log_scale=False, per_layer=False, sample_limit=5_000_000):
   
    for p in (baseline_ckpt, pruned_ckpt):
        if not os.path.exists(p):
            raise FileNotFoundError(f"Checkpoint file not found: {p}")

    def load_and_extract(p):
        ck = _torch_load_robust(p)
        state = ck.get('model_state_dict', ck) if isinstance(ck, dict) else ck
        # ensure it's a mapping
        if not isinstance(state, dict):
            raise RuntimeError("Loaded checkpoint did not contain a state-dict mapping.")
        vals = _extract_weight_values_from_state_dict(state, max_total_elems=sample_limit)
        return vals, state

    w_base, sd_base = load_and_extract(baseline_ckpt)
    w_pruned, sd_pruned = load_and_extract(pruned_ckpt)

    if w_base.size == 0:
        raise RuntimeError("Baseline checkpoint contains no weight tensors matching 'weight' keys.")
    if w_pruned.size == 0:
        print("Warning: pruned checkpoint contains no weight tensors matching 'weight' keys — plot will be empty for pruned.")

    abs_base = np.abs(w_base)
    abs_pruned = np.abs(w_pruned)

    combined = np.concatenate([abs_base, abs_pruned]) if abs_pruned.size>0 else abs_base
    if combined.size == 0:
        raise RuntimeError("No numeric weight data found in either checkpoint.")
    eps = 1e-12
    hist_min, hist_max = combined.min(), combined.max() + eps
    if hist_min == hist_max:
        hist_max = hist_min + 1e-6
    bin_edges = np.linspace(hist_min, hist_max, bins+1)

    plt.figure(figsize=(7,4))
    plt.hist(abs_base, bins=bin_edges, alpha=0.7, label='Before pruning', density=False)
    if abs_pruned.size > 0:
        plt.hist(abs_pruned, bins=bin_edges, alpha=0.7, label='After pruning', density=False)
    plt.xlabel('|Weight|')
    plt.ylabel('Count')
    plt.title('Distribution of Weight Magnitudes (Before vs After Pruning)')
    plt.legend()
    if log_scale:
        plt.yscale('log')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print("Saved weight distribution comparison to", out_path)

    if per_layer:
        base_dir = os.path.splitext(out_path)[0] + "_per_layer"
        os.makedirs(base_dir, exist_ok=True)
        def safe_to_array(v):
            try:
                if isinstance(v, torch.Tensor): return v.detach().cpu().numpy().ravel()
                if isinstance(v, np.ndarray): return v.ravel()
                if isinstance(v, (list,tuple)): return np.asarray(v).ravel()
            except Exception:
                return np.array([], dtype=np.float32)
        keys = sorted([k for k in sd_base.keys() if 'weight' in k])
        for k in keys:
            a = safe_to_array(sd_base.get(k, np.array([])))
            b = safe_to_array(sd_pruned.get(k, np.array([])))
            if a.size == 0 and b.size == 0:
                continue
            if a.size > sample_limit:
                rng = np.random.default_rng(0); a = rng.choice(a, sample_limit, replace=False)
            if b.size > sample_limit:
                rng = np.random.default_rng(0); b = rng.choice(b, sample_limit, replace=False)
            plt.figure(figsize=(6,3))
            plt.hist(np.abs(a), bins=50, alpha=0.6, label='before')
            if b.size>0:
                plt.hist(np.abs(b), bins=50, alpha=0.6, label='after')
            plt.title(k); plt.xlabel('|w|'); plt.ylabel('count'); plt.legend()
            plt.tight_layout()
            outk = os.path.join(base_dir, k.replace('.','_') + '.png')
            plt.savefig(outk); plt.close()
        print(f"Saved per-layer plots to {base_dir}")

plot_weight_distributions("baseline_vgg11.pth",
                         "/kaggle/input/unstructured/final/final_pruned_state_dict.pth",  out_path='weight_dists.png', bins=200, log_scale=True, per_layer=False)


Note: total weight elements = 128,799,104 > 5,000,000; sampling for plotting.
WARN: torch.load raised an unpickling/weights-only error. Retrying with a safe-globals allowlist.
Note: total weight elements = 128,799,104 > 5,000,000; sampling for plotting.
Saved weight distribution comparison to weight_dists.png


In [None]:

ranked_path = "/kaggle/input/unstructured/sensitivity/ranked_layers_by_sensitivity.json"  # adjust path
out_dir = "sensitivity_vis"
os.makedirs(out_dir, exist_ok=True)

with open(ranked_path, "r") as f:
    ranked_layers = json.load(f)

if isinstance(ranked_layers, dict):
    items = list(ranked_layers.items())
elif isinstance(ranked_layers, list):
    items = [(k, float(v)) for k, v in ranked_layers]
else:
    raise RuntimeError(f"Unexpected type: {type(ranked_layers)}")

items = sorted(items, key=lambda x: x[1], reverse=True) 
layers = [k for k, _ in items]
drops  = [v for _, v in items]

topN = min(40, len(layers))
plt.figure(figsize=(8, max(3, 0.25 * topN)))
plt.barh(np.arange(topN)[::-1], drops[:topN])
plt.yticks(np.arange(topN), layers[:topN][::-1])
plt.xlabel("Average accuracy drop (%)")
plt.title("Layer sensitivity ranking (higher = more sensitive)")
plt.tight_layout()
plt.savefig(os.path.join(out_dir, "sensitivity_ranked_from_json.png"))
plt.close()
print(f"Saved ranked plot to {os.path.join(out_dir, 'sensitivity_ranked_from_json.png')}")

print("Top 5 most sensitive layers:")
for k, v in items[:5]:
    print(f"  {k:30s}  drop={v:.2f}")
print("Top 5 least sensitive layers:")
for k, v in items[-5:]:
    print(f"  {k:30s}  drop={v:.2f}")


Saved ranked plot to sensitivity_vis/sensitivity_ranked_from_json.png
Top 5 most sensitive layers:
  features.0.weight               drop=18.26
  features.8.weight               drop=6.22
  features.4.weight               drop=4.65
  features.11.weight              drop=4.25
  features.15.weight              drop=2.83
Top 5 least sensitive layers:
  features.22.weight              drop=1.38
  classifier.6.weight             drop=0.12
  features.25.weight              drop=0.09
  classifier.0.weight             drop=0.04
  classifier.3.weight             drop=-0.06


*Part 2*

In [None]:
def collect_convs_from_global_vgg():
  convs = []
  for idx, mod in enumerate(vgg_model.features):
      if isinstance(mod, nn.Conv2d):
          convs.append((f"features.{idx}", mod))
  return convs

def channel_l2_norms(conv_mod):
  w = conv_mod.weight.data.detach()
  norms = torch.norm(w.view(w.shape[0], -1), p=2, dim=1)  
  return norms

def choose_kept_channels(conv_mod, keep_fraction):
  norms = channel_l2_norms(conv_mod).cpu()
  k = max(1, int(round(keep_fraction * norms.numel())))
  vals, idx = torch.topk(norms, k=k, largest=True)
  kept = sorted(idx.tolist())
  return kept

In [None]:

def build_pruned_vgg(kept_map):
    orig = vgg_model  
    orig.eval()
    conv_entries = collect_convs_from_global_vgg()
    conv_names = [n for n,_ in conv_entries]

    kept = {}
    for name, conv in conv_entries:
        if name in kept_map:
            kept[name] = list(kept_map[name])
        else:
            kept[name] = list(range(conv.out_channels))

    new_features = []
    prev_kept = None
    conv_iter = iter(conv_entries)
    for idx, module in enumerate(orig.features):
        if isinstance(module, nn.Conv2d):
            name, old_conv = next(conv_iter)
            kept_out = kept[name]
            if prev_kept is None:
                kept_in = list(range(old_conv.in_channels))
            else:
                kept_in = prev_kept
            new_conv = nn.Conv2d(in_channels=len(kept_in),
                                 out_channels=len(kept_out),
                                 kernel_size=old_conv.kernel_size,
                                 stride=old_conv.stride,
                                 padding=old_conv.padding,
                                 dilation=old_conv.dilation,
                                 groups=old_conv.groups,
                                 bias=(old_conv.bias is not None))
            with torch.no_grad():
                old_w = old_conv.weight.data.cpu()
                new_w = old_w[kept_out, :, :, :].clone()
                new_w = new_w[:, kept_in, :, :].clone()
                new_conv.weight.data.copy_(new_w.to(new_conv.weight.device))
                if old_conv.bias is not None:
                    new_conv.bias.data.copy_(old_conv.bias.data.cpu()[kept_out].to(new_conv.bias.device))
            new_features.append(new_conv)
            prev_kept = kept_out
        elif isinstance(module, nn.BatchNorm2d):
            if prev_kept is None:
                new_features.append(copy.deepcopy(module))
            else:
                kept_out = prev_kept
                new_bn = nn.BatchNorm2d(num_features=len(kept_out))
                with torch.no_grad():
                    new_bn.weight.data.copy_(module.weight.data.cpu()[kept_out].to(new_bn.weight.device))
                    new_bn.bias.data.copy_(module.bias.data.cpu()[kept_out].to(new_bn.bias.device))
                    new_bn.running_mean.data.copy_(module.running_mean.data.cpu()[kept_out].to(new_bn.running_mean.device))
                    new_bn.running_var.data.copy_(module.running_var.data.cpu()[kept_out].to(new_bn.running_var.device))
                new_features.append(new_bn)
        else:
            new_features.append(copy.deepcopy(module))

    new_features = nn.Sequential(*new_features)

    new_vgg = torchvision.models.vgg11_bn(pretrained=False)
    new_vgg.classifier[6] = nn.Linear(in_features=new_vgg.classifier[6].in_features,
                                      out_features=vgg_model.classifier[6].out_features)
    new_vgg.features = new_features

    new_vgg.eval()
    with torch.no_grad():
        dummy = torch.randn(1,3,32,32).to(next(new_vgg.parameters()).device)
        feat_out = new_vgg.features(dummy)
        new_flat = int(feat_out.view(1,-1).shape[1])

    old_clf = vgg_model.classifier
    new_clf = []
    for i, m in enumerate(old_clf):
        if i == 0 and isinstance(m, nn.Linear):
            new_m = nn.Linear(in_features=new_flat, out_features=m.out_features, bias=(m.bias is not None))
            new_clf.append(new_m)
        else:
            new_clf.append(copy.deepcopy(m))
    new_vgg.classifier = nn.Sequential(*new_clf)

    with torch.no_grad():
        orig_for_fc = None
        try:
            baseline_ckpt = globals().get('baseline_ckpt_path', None)
            if baseline_ckpt and os.path.exists(baseline_ckpt):
                ck = torch.load(baseline_ckpt, map_location='cpu')
                orig_for_fc = torchvision.models.vgg11_bn(pretrained=False)
                orig_for_fc.classifier[6] = nn.Linear(orig_for_fc.classifier[6].in_features,
                                                      vgg_model.classifier[6].out_features)
                try:
                    orig_for_fc.load_state_dict(ck['model_state_dict'])
                except Exception:
                    orig_for_fc.load_state_dict(ck, strict=False)
                orig_for_fc.to(next(orig.parameters()).device)
            else:
                orig_for_fc = vgg_model
        except Exception:
            orig_for_fc = vgg_model

        dummy2 = torch.randn(1,3,32,32).to(next(orig_for_fc.parameters()).device)
        orig_feat = orig_for_fc.features(dummy2)
        orig_C = orig_feat.shape[1]
        H = orig_feat.shape[2]
        W = orig_feat.shape[3]

        last_conv_name = conv_names[-1]
        last_kept = kept[last_conv_name]

        old_fc0 = vgg_model.classifier[0]
        old_w0 = old_fc0.weight.data.cpu()  
        old_flat = old_fc0.in_features

        if orig_C * H * W != old_flat:
            spatial = old_flat // orig_C if orig_C > 0 else 1
            h_w = int(math.sqrt(spatial))
            if h_w * h_w == spatial:
                H = W = h_w
            else:
                H = 1
                W = spatial

        block_size = H * W 

        keep_indices = []
        for c in last_kept:
            start = int(c) * block_size
            keep_indices.extend(list(range(start, start + block_size)))
        if len(keep_indices) == 0:
            keep_indices = list(range(min(block_size, old_flat)))

        keep_idx = torch.tensor(keep_indices, dtype=torch.long)

        new_w0 = old_w0.index_select(dim=1, index=keep_idx)  

        new_in = new_w0.shape[1]
        new_out = old_w0.shape[0]
        new_fc0 = nn.Linear(in_features=new_in, out_features=new_out, bias=(old_fc0.bias is not None))
        new_fc0.weight.data.copy_(new_w0.to(new_fc0.weight.device))
        new_fc0.bias.data.copy_(old_fc0.bias.data.clone().to(new_fc0.bias.device))
        new_vgg.classifier[0] = new_fc0.to(next(new_vgg.parameters()).device)

        for i in range(1, len(old_clf)):
            if isinstance(old_clf[i], nn.Linear) and isinstance(new_vgg.classifier[i], nn.Linear):
                ow = old_clf[i].weight.data.clone()
                ob = old_clf[i].bias.data.clone()
                if new_vgg.classifier[i].weight.data.shape == ow.shape:
                    new_vgg.classifier[i].weight.data.copy_(ow.to(new_vgg.classifier[i].weight.device))
                    new_vgg.classifier[i].bias.data.copy_(ob.to(new_vgg.classifier[i].bias.device))
                else:
                    min_r = min(ow.shape[0], new_vgg.classifier[i].weight.data.shape[0])
                    min_c = min(ow.shape[1], new_vgg.classifier[i].weight.data.shape[1])
                    if min_r > 0 and min_c > 0:
                        new_vgg.classifier[i].weight.data[:min_r,:min_c].copy_(ow[:min_r,:min_c].to(new_vgg.classifier[i].weight.device))
                        new_vgg.classifier[i].bias.data[:min_r].copy_(ob[:min_r].to(new_vgg.classifier[i].bias.device))

    new_vgg.to(device)
    return new_vgg, kept


In [None]:
def structured_sensitivity(sparsity_list_percent=[50,70], out_dir='structured_sensitivity', eval_loader=None):
 
    os.makedirs(out_dir, exist_ok=True)
    convs = collect_convs_from_global_vgg()
    results = {}

    if eval_loader is None:
        eval_loader = CIFAR10_testloader

    baseline = vgg_model
    baseline.to(device)
    baseline.eval()
    baseline_acc = evaluate(baseline, eval_loader)
    print(f"[structured_sensitivity] Baseline accuracy on eval_loader: {baseline_acc:.2f}%")

    for (layer_name, conv_mod) in convs:
        results[layer_name] = {'sparsities': [], 'acc_no_ft': []}
        norms_before = channel_l2_norms(conv_mod).cpu().numpy()
        for s in sparsity_list_percent:
            sp = float(s) / 100.0
            keep_frac = max(0.0, 1.0 - sp)

            kept_map = {}
            for nm, cm in convs:
                if nm == layer_name:
                    kept_map[nm] = choose_kept_channels(cm, keep_frac)
                else:
                    kept_map[nm] = list(range(cm.out_channels))

            pruned_net, kept_used = build_pruned_vgg(kept_map)
            pruned_net.to(device)
            pruned_net.eval()

            acc_no = evaluate(pruned_net, eval_loader)
            results[layer_name]['sparsities'].append(s)
            results[layer_name]['acc_no_ft'].append(acc_no)
            print(f"[structured_sens] {layer_name} @ {s}% -> no_ft {acc_no:.2f}")

        plt.figure(figsize=(6,4))
        plt.plot(results[layer_name]['sparsities'], results[layer_name]['acc_no_ft'], marker='o', label='no finetune')
        plt.xlabel('Sparsity (%)'); plt.ylabel('Validation Accuracy (%)'); plt.title(layer_name)
        plt.legend(); plt.grid(True)
        plt.tight_layout()
        plt.savefig(os.path.join(out_dir, f"{layer_name.replace('.','_')}_structured_sensitivity.png"))
        plt.close()

    np.save(os.path.join(out_dir, 'structured_sensitivity_results.npy'), results)
    avg_noft = {n: float(np.mean(results[n]['acc_no_ft'])) for n in results}
    ranked = sorted(avg_noft.items(), key=lambda x: x[1], reverse=True)
    ranked_path = os.path.join(out_dir, 'ranked_layers_by_noft.json')
    with open(ranked_path, 'w') as f:
        json.dump(ranked, f, indent=2)
    print(f"Saved structured sensitivity results and ranking to {out_dir} (ranking -> {ranked_path})")
    return results, baseline_acc


In [None]:
vgg_model.to(device)
vgg_model.eval()  

struct_results_quick, baseline_acc = structured_sensitivity( sparsity_list_percent=[50,70], out_dir='structured_quick', eval_loader=CIFAR10_valloader)

[structured_sensitivity] Baseline accuracy on eval_loader: 79.62%
[structured_sens] features.0 @ 50% -> no_ft 50.24
[structured_sens] features.0 @ 70% -> no_ft 31.24
[structured_sens] features.4 @ 50% -> no_ft 40.48
[structured_sens] features.4 @ 70% -> no_ft 25.42
[structured_sens] features.8 @ 50% -> no_ft 58.04
[structured_sens] features.8 @ 70% -> no_ft 42.78
[structured_sens] features.11 @ 50% -> no_ft 44.82
[structured_sens] features.11 @ 70% -> no_ft 22.32
[structured_sens] features.15 @ 50% -> no_ft 66.38
[structured_sens] features.15 @ 70% -> no_ft 28.48
[structured_sens] features.18 @ 50% -> no_ft 49.94
[structured_sens] features.18 @ 70% -> no_ft 14.30
[structured_sens] features.22 @ 50% -> no_ft 76.46
[structured_sens] features.22 @ 70% -> no_ft 64.60
[structured_sens] features.25 @ 50% -> no_ft 78.72
[structured_sens] features.25 @ 70% -> no_ft 69.60
Saved structured sensitivity results and ranking to structured_quick (ranking -> structured_quick/ranked_layers_by_noft.json

In [None]:
def structured_finetune_topk(ranked_json='structured_quick/ranked_layers_by_noft.json', top_k=8, sparsity_percent=70, short_ft_epochs=2, eval_loader=None, out_dir='structured_topk'):
    os.makedirs(out_dir, exist_ok=True)
    ranked = json.load(open(ranked_json))
    top_layers = [r[0] for r in ranked[:top_k]]
    print("Top-k least-sensitive layers:", top_layers)

    convs = collect_convs_from_global_vgg()
    results = {}
    if eval_loader is None:
        eval_loader = CIFAR10_testloader

    for name, conv in convs:
        results[name] = {'sparsity': sparsity_percent, 'acc_no_ft': None, 'acc_after_ft': None}
        keep_map = {}
        for nm, cm in convs:
            if nm == name:
                keep_frac = max(0.0, 1.0 - float(sparsity_percent)/100.0)
                keep_map[nm] = choose_kept_channels(cm, keep_frac)
            else:
                keep_map[nm] = list(range(cm.out_channels))
        pruned_net, kept_used = build_pruned_vgg(keep_map)
        pruned_net.to(device)
        pruned_net.eval()

        acc_no = evaluate(pruned_net, eval_loader)
        results[name]['acc_no_ft'] = acc_no

        if name in top_layers:
            criterion = nn.CrossEntropyLoss()
            optimizer = optim.SGD(pruned_net.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
            best_val = acc_no
            no_imp = 0
            for ep in range(short_ft_epochs):
                train_one_epoch(pruned_net, CIFAR10_trainloader, criterion, optimizer)
                val = evaluate(pruned_net, eval_loader)
                if val > best_val + 1e-6:
                    best_val = val; no_imp = 0
                else:
                    no_imp += 1
                if no_imp >= 1:
                    break
            results[name]['acc_after_ft'] = best_val
            print(f"[topk_ft] {name} @ {sparsity_percent}% -> no_ft {acc_no:.2f}, after_ft {best_val:.2f}")
        else:
            results[name]['acc_after_ft'] = None
            print(f"[topk_skip] {name} @ {sparsity_percent}% -> no_ft {acc_no:.2f} (skipped ft)")

    np.save(os.path.join(out_dir, 'structured_topk_results.npy'), results)
    return results


In [None]:
topk_results = structured_finetune_topk(ranked_json='structured_quick/ranked_layers_by_noft.json', top_k=8, sparsity_percent=70, 
                                        short_ft_epochs=2, eval_loader=CIFAR10_valloader, out_dir='structured_topk')

Top-k least-sensitive layers: ['features.25', 'features.22', 'features.18', 'features.0', 'features.15', 'features.8', 'features.11', 'features.4']
[topk_ft] features.0 @ 70% -> no_ft 39.40, after_ft 73.75
[topk_ft] features.4 @ 70% -> no_ft 21.50, after_ft 73.30
[topk_ft] features.8 @ 70% -> no_ft 18.85, after_ft 70.10
[topk_ft] features.11 @ 70% -> no_ft 17.00, after_ft 73.55
[topk_ft] features.15 @ 70% -> no_ft 25.95, after_ft 76.15
[topk_ft] features.18 @ 70% -> no_ft 35.80, after_ft 72.60
[topk_ft] features.22 @ 70% -> no_ft 49.40, after_ft 73.65
[topk_ft] features.25 @ 70% -> no_ft 76.10, after_ft 76.10


In [None]:

def suggest_structured_allocation(conv_list=None, sensitivity_results=None, target_param_reduction=0.70, max_prune_per_layer=0.90, anchor_sparsity=50):

    if conv_list is None:
        conv_list = collect_convs_from_global_vgg()

    def param_count(m): return sum(p.numel() for p in m.parameters())
    orig_total_params = param_count(vgg_model)
    target_prune_params = int(round(target_param_reduction * orig_total_params))

    layers = []
    for name, conv in conv_list:
        oc = conv.out_channels
        ic = conv.in_channels
        k_h, k_w = conv.kernel_size if isinstance(conv.kernel_size, tuple) else (conv.kernel_size, conv.kernel_size)
        cost_per_out = int(ic * k_h * k_w) 
        layers.append({
            'name': name,
            'module': conv,
            'out_channels': oc,
            'in_channels': ic,
            'cost_per_out': cost_per_out
        })

    scores = {}
    for L in layers:
        nm = L['name']
        score = None
        if sensitivity_results and nm in sensitivity_results:
            res = sensitivity_results[nm]
            spars_list = res.get('sparsities', [])
            accs = res.get('acc_after_ft', [])
            if len(spars_list) and len(accs):
                idx = min(range(len(spars_list)), key=lambda i: abs(spars_list[i] - anchor_sparsity))
                score = float(accs[idx])
        if score is None:
            score = 50.0
        scores[nm] = score

    layers_sorted = sorted(layers, key=lambda x: scores.get(x['name'], 0.0), reverse=True)

    removed_params = 0
    channels_to_remove = {L['name']: 0 for L in layers}
    for L in layers_sorted:
        name = L['name']
        oc = L['out_channels']
        cost = L['cost_per_out']
        max_remove = int(math.floor(max_prune_per_layer * oc))
        remaining = max(0, target_prune_params - removed_params)
        if remaining <= 0:
            break
        need = cost and int(math.floor(remaining / cost)) or max_remove
        remove_here = min(max_remove, need)
        remove_here = min(remove_here, max(0, oc-1))
        channels_to_remove[name] = remove_here
        removed_params += remove_here * cost

    if removed_params < target_prune_params:
        candidates = []
        for L in layers_sorted:
            name = L['name']
            oc = L['out_channels']
            cost = L['cost_per_out']
            max_remove = int(math.floor(max_prune_per_layer * oc))
            remain_slots = max_remove - channels_to_remove[name]
            for _ in range(remain_slots):
                candidates.append((scores.get(name, 0.0), cost, name))
        candidates.sort(key=lambda x: (x[0], x[1]), reverse=True)
        idx = 0
        while removed_params < target_prune_params and idx < len(candidates):
            _, cost, name = candidates[idx]
            Linfo = next(x for x in layers if x['name'] == name)
            if channels_to_remove[name] < int(math.floor(max_prune_per_layer * Linfo['out_channels'])) and channels_to_remove[name] < Linfo['out_channels'] - 1:
                channels_to_remove[name] += 1
                removed_params += cost
            idx += 1

    per_layer_sparsities = {}
    for L in layers:
        name = L['name']
        oc = L['out_channels']
        rem = channels_to_remove.get(name, 0)
        frac = float(rem) / float(oc)
        frac = min(max(frac, 0.0), 0.95)
        per_layer_sparsities[name] = frac

    estimated_removed = sum(channels_to_remove[n] * next(l['cost_per_out'] for l in layers if l['name']==n) for n in channels_to_remove)
    achieved_fraction = float(estimated_removed) / float(orig_total_params) if orig_total_params > 0 else 0.0

    info = {
        'orig_total_params': orig_total_params,
        'target_prune_params': target_prune_params,
        'estimated_removed_params': estimated_removed,
        'estimated_fraction_reduction': achieved_fraction,
        'channels_to_remove': channels_to_remove,
        'sensitivity_scores': scores
    }
    return per_layer_sparsities, info


In [None]:
conv_list = collect_convs_from_global_vgg()  
per_layer_sparsities, info = suggest_structured_allocation(conv_list=conv_list,
                                                           sensitivity_results=struct_results_quick,
                                                           target_param_reduction=0.70,
                                                           max_prune_per_layer=0.9,
                                                           anchor_sparsity=50)
print("Estimated global reduction:", info['estimated_fraction_reduction'])
import json
with open('suggested_struct_sparsities.json','w') as f:
    json.dump(per_layer_sparsities, f, indent=2)


Estimated global reduction: 0.0642912688574995


In [None]:
def apply_structured_and_finetune(per_layer_sparsities, finetune_epochs=20, out_dir='structured_final', 
                                  use_unstructured_ranked_path='unstructured_outputs/sensitivity/ranked_layers_by_sensitivity.json', 
                                  use_structured_ranked_path='structured_quick/ranked_layers_by_noft.json', ranked_k=8):

    os.makedirs(out_dir, exist_ok=True)

    convs = collect_convs_from_global_vgg() 
    conv_names = [n for n,_ in convs]

    candidate_layers = None

    if use_structured_ranked_path and os.path.exists(use_structured_ranked_path):
        try:
            sr = json.load(open(use_structured_ranked_path, 'r'))
            sr_names = [x[0] for x in sr]
            candidate_layers = [n for n in sr_names if n in conv_names]
            print(f"Loaded structured ranking, candidate conv layers (top {ranked_k}): {candidate_layers[:ranked_k]}")
        except Exception as e:
            print("Failed to load structured ranked file:", e)

    if candidate_layers is None and use_unstructured_ranked_path and os.path.exists(use_unstructured_ranked_path):
        try:
            ur = json.load(open(use_unstructured_ranked_path, 'r'))
            ur_names = [x[0] for x in ur]
            candidate_layers = [n for n in ur_names if n in conv_names]
            print(f"Loaded unstructured ranking, candidate conv layers (top {ranked_k}): {candidate_layers[:ranked_k]}")
        except Exception as e:
            print("Failed to load unstructured ranked file:", e)

    if candidate_layers is None:
        candidate_layers = conv_names.copy()
        print("No ranking files found; defaulting to all conv layers as candidates.")

    prune_candidates = candidate_layers[:ranked_k] if ranked_k is not None else candidate_layers
    print("Prune candidate convs used:", prune_candidates)

    kept_map = {}
    for name, conv_mod in convs:
        requested_frac = float(per_layer_sparsities.get(name, 0.0))
        if name in prune_candidates:
            spars = requested_frac
        else:
            spars = 0.0
        keep_frac = max(0.0, 1.0 - spars)
        kept_map[name] = choose_kept_channels(conv_mod, keep_frac)

    print("Per-layer kept channels (orig -> kept):")
    for name, conv_mod in convs:
        orig_ch = conv_mod.out_channels
        kept_ch = len(kept_map[name])
        print(f"  {name}: {orig_ch} -> {kept_ch}  (prune frac requested {per_layer_sparsities.get(name,0.0):.3f})")

    pruned_model, kept_used = build_pruned_vgg(kept_map)
    pruned_model.to(device)
    pruned_model.eval()

    acc_before = evaluate(pruned_model, CIFAR10_testloader)
    print(f"[structured_final] val acc immediately after pruning: {acc_before:.2f}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    best_val = acc_before
    no_imp = 0
    for ep in range(finetune_epochs):
        train_one_epoch(pruned_model, CIFAR10_trainloader, criterion, optimizer)
        val_acc = evaluate(pruned_model, CIFAR10_testloader)
        print(f"[structured_final] finetune epoch {ep+1}/{finetune_epochs} val_acc={val_acc:.2f}")
        if val_acc > best_val + 1e-6:
            best_val = val_acc
            no_imp = 0
        else:
            no_imp += 1
        if no_imp >= 2:
            print("Early stopping: no improvement in 2 epochs.")
            break

    save_path = os.path.join(out_dir, 'structured_pruned_final.pth')

    def param_count(m): return sum(p.numel() for p in m.parameters())
    orig_params = param_count(vgg_model)
    new_params = param_count(pruned_model)
    global_sparsity = 1.0 - (new_params / orig_params) if orig_params>0 else 0.0
    summary = {'orig_params': orig_params, 'pruned_params': new_params, 'global_sparsity': global_sparsity}
    with open(os.path.join(out_dir, 'structured_prune_summary.json'), 'w') as f:
        json.dump(summary, f, indent=2)

    pruned_sd_cpu = {k: v.cpu() for k,v in pruned_model.state_dict().items()}
    torch.save({
        'model_state_dict': pruned_sd_cpu,
        'num_classes': num_classes,
        'arch': 'vgg11_bn',
        'kept_indices': kept_used,
        'structured_per_layer_sparsities': per_layer_sparsities,  
        'global_sparsity': summary['global_sparsity']
    }, save_path) 
    print("Saved final structured state_dict:", save_path)
    print(f"Saved final structured pruned model to {save_path}. Global sparsity (by params): {global_sparsity:.4f}")
    return pruned_model, kept_used, summary


In [26]:
pruned_model, kept_used, summary = apply_structured_and_finetune(per_layer_sparsities, finetune_epochs=10, out_dir='structured_final')
print("Structured summary:", summary)

Loaded structured ranking, candidate conv layers (top 8): ['features.25', 'features.22', 'features.18', 'features.0', 'features.15', 'features.8', 'features.11', 'features.4']
Prune candidate convs used: ['features.25', 'features.22', 'features.18', 'features.0', 'features.15', 'features.8', 'features.11', 'features.4']
Per-layer kept channels (orig -> kept):
  features.0: 64 -> 7  (prune frac requested 0.891)
  features.4: 128 -> 13  (prune frac requested 0.898)
  features.8: 256 -> 26  (prune frac requested 0.898)
  features.11: 256 -> 26  (prune frac requested 0.898)
  features.15: 512 -> 52  (prune frac requested 0.898)
  features.18: 512 -> 52  (prune frac requested 0.898)
  features.22: 512 -> 52  (prune frac requested 0.898)
  features.25: 512 -> 52  (prune frac requested 0.898)
[structured_final] val acc immediately after pruning: 10.00
[structured_final] finetune epoch 1/10 val_acc=44.71
[structured_final] finetune epoch 2/10 val_acc=57.72
[structured_final] finetune epoch 3/1

In [None]:
pruned_model_s, kept_map_s, ck_s = load_structured_checkpoint('structured-f/structured_pruned_final.pth', device=device)
plot_weight_distributions("baseline_vgg11.pth", "/kaggle/input/structured-f/structured_pruned_final.pth", 
                          out_path='weight_dists_struct.png', bins=200, log_scale=True, per_layer=False)

Note: total weight elements = 128,799,104 > 5,000,000; sampling for plotting.
Note: total weight elements = 27,350,374 > 5,000,000; sampling for plotting.
Saved weight distribution comparison to weight_dists_struct.png


In [None]:

ranked_path = "/kaggle/working/structured_quick/ranked_layers_by_noft.json" 
out_dir = "sensitivity_vis_struct"
os.makedirs(out_dir, exist_ok=True)

with open(ranked_path, "r") as f:
    ranked_layers = json.load(f)

if isinstance(ranked_layers, dict):
    items = list(ranked_layers.items())
elif isinstance(ranked_layers, list):
    items = [(k, float(v)) for k, v in ranked_layers]
else:
    raise RuntimeError(f"Unexpected type: {type(ranked_layers)}")

items = sorted(items, key=lambda x: x[1], reverse=True)  

layers = [k for k, _ in items]
drops  = [v for _, v in items]

topN = min(40, len(layers))
plt.figure(figsize=(8, max(3, 0.25 * topN)))
plt.barh(np.arange(topN)[::-1], drops[:topN])
plt.yticks(np.arange(topN), layers[:topN][::-1])
plt.xlabel("Average accuracy drop (%)")
plt.title("Layer sensitivity ranking (higher = more sensitive)")
plt.tight_layout()
plt.savefig(os.path.join(out_dir, "sensitivity_ranked_from_json.png"))
plt.close()
print(f"Saved ranked plot to {os.path.join(out_dir, 'sensitivity_ranked_struct.png')}")

print("Top 5 most sensitive layers:")
for k, v in items[:5]:
    print(f"  {k:30s}  drop={v:.2f}")
print("Top 5 least sensitive layers:")
for k, v in items[-5:]:
    print(f"  {k:30s}  drop={v:.2f}")


Saved ranked plot to sensitivity_vis_struct/sensitivity_ranked_struct.png
Top 5 most sensitive layers:
  features.25                     drop=74.16
  features.22                     drop=70.53
  features.8                      drop=50.41
  features.15                     drop=47.43
  features.0                      drop=40.74
Top 5 least sensitive layers:
  features.15                     drop=47.43
  features.0                      drop=40.74
  features.11                     drop=33.57
  features.4                      drop=32.95
  features.18                     drop=32.12


*Part 3*

In [None]:

def load_vgg_from_checkpoint(ckpt_path, device=None, prefer_state_key='model_state_dict'):

    if device is None:
        try:
            device = globals().get('device', 'cpu')
        except Exception:
            device = 'cpu'
    device = torch.device(device)

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    ck = torch.load(ckpt_path, map_location='cpu')

    if isinstance(ck, dict) and prefer_state_key in ck:
        state = ck[prefer_state_key]
    elif isinstance(ck, dict) and 'state_dict' in ck:
        state = ck['state_dict']
    else:
        state = ck

    if not isinstance(state, dict):
        raise RuntimeError("Loaded checkpoint does not contain a state-dict mapping.")

    num_classes = None
    if 'classifier.6.weight' in state:
        w = state['classifier.6.weight']
        try:
            num_classes = int(w.shape[0])
        except Exception:
            pass
    if num_classes is None and 'classifier.6.bias' in state:
        b = state['classifier.6.bias']
        try:
            num_classes = int(b.shape[0])
        except Exception:
            pass

    if num_classes is None:
        for k, v in state.items():
            if k.endswith('classifier.6.weight') or k.endswith('classifier.6.bias'):
                try:
                    num_classes = int(state[k].shape[0])
                    break
                except Exception:
                    pass

    if num_classes is None:
        raise RuntimeError(
            "Could not infer classifier output dimension from checkpoint. "
            "Please pass a checkpoint that contains 'classifier.6.weight' or provide a cleaned state_dict."
        )

    model = torchvision.models.vgg11_bn(pretrained=False)
    in_features = model.classifier[6].in_features  
    if model.classifier[6].out_features != num_classes:
        model.classifier[6] = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
    new_state = {}
    for k, v in state.items():
        new_k = k
        if k.startswith('module.'):
            new_k = k[len('module.'):]
        new_state[new_k] = v

    try:
        model.load_state_dict(new_state, strict=True)
    except RuntimeError as e:
        err_msg = str(e)
        print("Strict load failed:", err_msg)
        load_res = model.load_state_dict(new_state, strict=False)  
        print("Loaded state_dict with strict=False. Missing keys:", load_res.missing_keys)
        print("Unexpected keys:", load_res.unexpected_keys)
        mismatches = []
        for k, v in new_state.items():
            if k in dict(model.named_parameters()) or k in dict(model.named_buffers()):
                try:
                    model_tensor = dict(model.named_parameters()).get(k, None)
                    if model_tensor is None:
                        model_tensor = dict(model.named_buffers()).get(k, None)
                    if model_tensor is not None:
                        if tuple(model_tensor.shape) != tuple(v.shape):
                            mismatches.append((k, tuple(v.shape), tuple(model_tensor.shape)))
                except Exception:
                    pass
        if mismatches:
            print("Parameter shape mismatches (loaded_shape -> model_shape):")
            for k, ls, ms in mismatches:
                print(f"  {k}: {ls} -> {ms}")
        else:
            print("No parameter shape mismatches detected beyond missing/unexpected keys.")
    model.to(device)
    model.eval()
    print(f"Loaded VGG from checkpoint '{ckpt_path}' with classifier out_features = {num_classes}")
    return model

def get_module_by_qname(model, qname):
    parts = qname.split('.')
    mod = model
    for p in parts:
        if p.isdigit():
            mod = mod[int(p)]
        else:
            if not hasattr(mod, p):
                raise AttributeError(f"Module {mod} has no attribute {p} (qname={qname})")
            mod = getattr(mod, p)
    return mod

def get_default_target_layer_name(model):
    last_idx = None
    for idx, module in enumerate(model.features):
        if isinstance(module, nn.Conv2d):
            last_idx = idx
    if last_idx is None:
        raise RuntimeError("No Conv2d found in model.features")
    return f"features.{last_idx}"

In [None]:
def compute_gradcam(model, input_tensor, target_layer_name=None, class_idx=None, device=device):
    model.eval()
    if target_layer_name is None:
        target_layer_name = get_default_target_layer_name(model)

    activations = None
    gradients = None

    def forward_hook(module, inp, out):
        nonlocal activations
        activations = out.detach()

    def backward_hook(module, grad_input, grad_output):
        nonlocal gradients
        gradients = grad_output[0].detach()

    target_module = get_module_by_qname(model, target_layer_name)
    fh = target_module.register_forward_hook(forward_hook)
    try:
        bh = target_module.register_full_backward_hook(backward_hook)
    except AttributeError:
        bh = target_module.register_backward_hook(backward_hook)

    input_tensor = input_tensor.to(device)
    model.to(device)
    input_tensor.requires_grad_(True)

    out = model(input_tensor)             
    probs = F.softmax(out, dim=1)
    pred_class = int(probs.argmax(dim=1).item())
    if class_idx is None:
        class_idx = pred_class

    model.zero_grad()
    score = out[0, class_idx]
    score.backward(retain_graph=False)

    if activations is None or gradients is None:
        fh.remove(); bh.remove()
        raise RuntimeError("Hooks did not capture activations/gradients. Check target_layer_name.")

    weights = gradients.mean(dim=(2,3), keepdim=False).squeeze(0)  
    cam = (weights.view(-1,1,1) * activations[0]).sum(dim=0)    
    cam = F.relu(cam)
    cam_np = cam.cpu().numpy()
    cam_np -= cam_np.min()
    if cam_np.max() > 0:
        cam_np = cam_np / cam_np.max()

    _, _, H_in, W_in = input_tensor.shape
    cam_tensor = torch.from_numpy(cam_np).unsqueeze(0).unsqueeze(0).float()  # (1,1,Hf,Wf)
    cam_resized = F.interpolate(cam_tensor, size=(H_in, W_in), mode='bilinear', align_corners=False)
    cam_resized = cam_resized.squeeze().cpu().numpy()
    fh.remove(); bh.remove()
    return cam_resized, pred_class

def overlay_and_save(original_img_tensor, heatmap, out_path, alpha=0.5):
    mean = np.array([0.4914, 0.4822, 0.4465])
    std  = np.array([0.2470, 0.2435, 0.2616])
    img = original_img_tensor.detach().cpu().numpy().transpose(1,2,0)  # H,W,C normalized
    img = (img * std + mean)
    img = np.clip(img, 0, 1)

    H, W = img.shape[0], img.shape[1]
    if heatmap.shape != (H, W):
        heatmap_img = Image.fromarray(np.uint8(heatmap * 255))
        heatmap_img = heatmap_img.resize((W, H), resample=Image.BILINEAR)
        heatmap_resized = np.array(heatmap_img) / 255.0
    else:
        heatmap_resized = heatmap

    cmap = plt.get_cmap('jet')
    heat_col = cmap(heatmap_resized)[:, :, :3]  
    overlay = (1-alpha) * img + alpha * heat_col
    overlay = np.clip(overlay, 0, 1)

    fig, axs = plt.subplots(1,3, figsize=(9,3))
    axs[0].imshow(img); axs[0].set_title("Original"); axs[0].axis('off')
    axs[1].imshow(heatmap_resized, cmap='jet'); axs[1].set_title("Grad-CAM"); axs[1].axis('off')
    axs[2].imshow(overlay); axs[2].set_title("Overlay"); axs[2].axis('off')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)

In [None]:
def compare_gradcam_models(models_info, image_indices=[0], target_layer_name=None, out_dir='gradcam_compare', device=None):
 
    if device is None:
        device = globals().get('device', torch.device('cpu'))
    os.makedirs(out_dir, exist_ok=True)

    label_model_map = {}
    for label, info in models_info.items():
        if isinstance(info, tuple) and info[0] in ('original','unstructured','structured'):
            typ, path = info
            if typ == 'original':
                m = vgg_model
                m.to(device)
                m.eval()
                label_model_map[label] = m
            else:
                if path is None or (isinstance(path, str) and not os.path.exists(path)):
                    raise ValueError(f"Checkpoint path required & must exist for model label {label}: got {path}")
                ck = torch.load(path, map_location='cpu')
                if isinstance(ck, dict) and 'kept_indices' in ck:
                    m, kept, ckmeta = load_structured_checkpoint(path, device=device)   
                    label_model_map[label] = m
                else:
                    m, ckmeta = load_unstructured_checkpoint(path, device=device)     
                    label_model_map[label] = m
        elif isinstance(info, nn.Module):
            label_model_map[label] = info.to(device)
        elif isinstance(info, str) and os.path.exists(info):
            ck = torch.load(info, map_location='cpu')
            if isinstance(ck, dict) and 'kept_indices' in ck:
                m, kept, ckmeta = load_structured_checkpoint(info, device=device)
                label_model_map[label] = m
            else:
                m, ckmeta = load_unstructured_checkpoint(info, device=device)
                label_model_map[label] = m
        else:
            raise ValueError(f"Unsupported models_info entry for '{label}': {info}")

    dataset = CIFAR10_testset
    for idx in image_indices:
        img, gt = dataset[idx]
        inp = img.unsqueeze(0).to(device)   
        for label, model in label_model_map.items():
            try:
                cam, pred = compute_gradcam(model, inp.clone(), target_layer_name=target_layer_name, class_idx=None, device=device)
            except Exception as e:
                print(f"[GradCAM] failed for {label} idx {idx}: {e}")
                continue
            out_path = os.path.join(out_dir, f"{label}_img{idx}.png")
            overlay_and_save(img, cam, out_path)  
            print(f"[GradCAM] saved {out_path} (pred {pred})")
    print("Saved Grad-CAM comparisons to", out_dir)

def measure_inference_time(model, input_tensor, runs=200, warmup=20, device=None):
    if device is None:
        device = globals().get('device', torch.device('cpu'))
    device = torch.device(device)
    model.to(device)
    model.eval()
    inp = input_tensor.to(device)
    with torch.no_grad():
        for _ in range(warmup):
            _ = model(inp)
    times = []
    with torch.no_grad():
        for _ in range(runs):
            t0 = time.time()
            _ = model(inp)
            if device.type == 'cuda':
                torch.cuda.synchronize()
            t1 = time.time()
            times.append((t1 - t0) * 1000.0)
    times = np.array(times)
    return float(times.mean()), float(times.std()), times.tolist()


In [None]:
def load_checkpoint_auto(ckpt_path, device=None):
    if device is None:
        device = globals().get('device', torch.device('cpu'))
    device = torch.device(device)

    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    ck = torch.load(ckpt_path, map_location='cpu')

    if not isinstance(ck, dict):
        try:
            m = torchvision.models.vgg11_bn(pretrained=False)
            return load_vgg_from_checkpoint(ckpt_path, device=device), {}
        except Exception:
            raise RuntimeError("Unable to load checkpoint (not dict and not a simple state_dict).")

    if 'kept_indices' in ck or 'kept_indices' in ck.get('model_state_dict', {}):
        try:
            model, kept, meta = load_structured_checkpoint(ckpt_path, device=device)
            return model, meta
        except NameError:
            kept_indices = ck.get('kept_indices') or ck.get('model_state_dict', {}).get('kept_indices')
            if kept_indices is None:
                raise RuntimeError("kept_indices present but loader function missing.")
            pruned_model, _ = build_pruned_vgg(kept_indices)
            pruned_model.to(device)
            sd = ck.get('model_state_dict', None) or ck
            try:
                pruned_model.load_state_dict(sd, strict=False)
            except Exception as e:
                print("Warning: loading structured checkpoint with strict=False:", e)
            pruned_model.eval()
            return pruned_model, ck
    try:
        model = load_vgg_from_checkpoint(ckpt_path, device=device)
        return model, ck
    except Exception as e:
        m = torchvision.models.vgg11_bn(pretrained=False)
        if isinstance(ck, dict):
            state = ck.get('model_state_dict', ck)
            if 'classifier.6.weight' in state:
                try:
                    outdim = int(state['classifier.6.weight'].shape[0])
                    m.classifier[6] = nn.Linear(m.classifier[6].in_features, outdim)
                except Exception:
                    pass
        try:
            m.load_state_dict(state, strict=False)
            m.to(device); m.eval()
            return m, ck
        except Exception as e2:
            raise RuntimeError(f"Failed to auto-load checkpoint {ckpt_path}: {e2}") from e2


In [None]:
def model_checkpoint_size_bytes(ckpt_path):
    if not os.path.exists(ckpt_path):
        return None
    return os.path.getsize(ckpt_path)

def compute_sparsity_from_state_dict(state_dict):
    total = 0
    nz = 0
    for k, v in state_dict.items():
        if 'weight' not in k:
            continue
        try:
            if isinstance(v, torch.Tensor):
                arr = v.detach().cpu().numpy()
            else:
                arr = np.asarray(v)
        except Exception:
            continue
        total += arr.size
        nz += np.count_nonzero(arr)
    if total == 0:
        return None
    return 1.0 - (nz / total)

def evaluate_checkpoint_accuracy(ckpt_path, batch_size=128):
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(ckpt_path)
    model, ckmeta = load_checkpoint_auto(ckpt_path, device=device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x,y in CIFAR10_testloader:
            x,y = x.to(device), y.to(device)
            out = model(x)
            _, pred = out.max(1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)
    return 100.0 * correct / total


In [None]:
def accuracy_vs_sparsity_plot(readings, out_path='acc_vs_spars.png'):
    xs = []
    ys = []
    notes = []
    for s, ck in readings:
      s_pct = float(s)*100.0 if (0.0 <= float(s) <= 1.0) else float(s)
      if not os.path.exists(ck):
          print(f"Skipping missing checkpoint {ck}")
          continue
      try:
          acc = evaluate_checkpoint_accuracy(ck)
      except Exception as e:
          print(f"Failed to eval {ck}: {e}")
          continue
      xs.append(s_pct); ys.append(acc); notes.append((s_pct, acc))
      print(f"  -> accuracy: {acc:.2f}%")
    order = np.argsort(xs)
    xs_sorted = np.array(xs)[order]
    ys_sorted = np.array(ys)[order]
    plt.figure(figsize=(6,4))
    plt.plot(xs_sorted, ys_sorted, marker='o')
    plt.xlabel('Target sparsity (%)')
    plt.ylabel('Validation accuracy (%)')
    plt.title('Accuracy vs target sparsity')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close()
    print("Saved accuracy vs sparsity plot to", out_path)
    return notes

def summarize_models(original_label, original_model_or_ckpt, unstructured_label, unstructured_ckpt, structured_label, 
                     structured_ckpt, sample_image_index=0, runs=200, out_dir='model_summary'):
  
    os.makedirs(out_dir, exist_ok=True)
    def load_entry(entry):
      if isinstance(entry, nn.Module):
          return entry
      if entry == 'original':
          return vgg_model
      if isinstance(entry, str) and os.path.exists(entry):
          m, ckmeta = load_checkpoint_auto(entry, device=device)
          return m
      raise ValueError("Entry must be 'original' or path or nn.Module")


    models = {}
    ckpt_paths = {'original': None, 'unstructured': None, 'structured': None}
    models[original_label] = load_entry(original_model_or_ckpt)
    models[unstructured_label] = load_entry(unstructured_ckpt)
    models[structured_label] = load_entry(structured_ckpt)
    if isinstance(original_model_or_ckpt, str) and os.path.exists(original_model_or_ckpt):
        ckpt_paths['original'] = original_model_or_ckpt
    if os.path.exists(unstructured_ckpt):
        ckpt_paths['unstructured'] = unstructured_ckpt
    if os.path.exists(structured_ckpt):
        ckpt_paths['structured'] = structured_ckpt

    img, _ = CIFAR10_testset[sample_image_index]
    input_tensor = img.unsqueeze(0)

    summary = {}
    for label, model in models.items():
        mean_ms, std_ms, _ = measure_inference_time(model, input_tensor, runs=runs)
        param_cnt = sum(p.numel() for p in model.parameters())
        ckpt_size = None
        ckpt_path = None
        if label == original_label:
            ckpt_path = original_model_or_ckpt if isinstance(original_model_or_ckpt, str) and os.path.exists(original_model_or_ckpt) else None
        elif label == unstructured_label:
            ckpt_path = unstructured_ckpt
        elif label == structured_label:
            ckpt_path = structured_ckpt
        if ckpt_path and os.path.exists(ckpt_path):
            ckpt_size = os.path.getsize(ckpt_path)
            ck = torch.load(ckpt_path, map_location='cpu')
            if 'model_state_dict' in ck:
                sd = ck['model_state_dict']
            elif 'state_dict' in ck:
                sd = ck['state_dict']
            else:
                sd = ck
            spars = compute_sparsity_from_state_dict(sd)
        else:
            spars = None
        summary[label] = {
            'mean_inference_ms': mean_ms,
            'std_inference_ms': std_ms,
            'param_count': param_cnt,
            'ckpt_path': ckpt_path,
            'ckpt_size_bytes': ckpt_size,
            'computed_sparsity': spars
        }

    summary_serializable = {}
    for label, info in summary.items():
        summary_serializable[label] = {
            'mean_inference_ms': float(info['mean_inference_ms']),
            'std_inference_ms': float(info['std_inference_ms']),
            'param_count': int(info['param_count']),
            'ckpt_path': info['ckpt_path'],
            'ckpt_size_bytes': int(info['ckpt_size_bytes']) if info['ckpt_size_bytes'] is not None else None,
            'computed_sparsity': float(info['computed_sparsity']) if info['computed_sparsity'] is not None else None
        }
    with open(os.path.join(out_dir, 'models_summary.json'), 'w') as f:
        json.dump(summary_serializable, f, indent=2)

In [None]:
baseline_ckpt = "baseline_vgg11.pth"
unstructured_ckpt = "unstructured_outputs/final/final_pruned_state_dict.pth"   
structured_ckpt   = "structured_final/structured_pruned_final_state_dict.pth"  
models_info = {
    "baseline": ("original", None),
    "unstructured": ("unstructured", unstructured_ckpt),
    "structured": ("structured", structured_ckpt)
}
image_indices = [0, 7, 42]
compare_gradcam_models(models_info, image_indices=image_indices, target_layer_name=None, out_dir='gradcam_compare', device=device)


In [None]:
summary = summarize_models(
    original_label="baseline", original_model_or_ckpt='original',
    unstructured_label="unstructured", unstructured_ckpt=unstructured_ckpt,
    structured_label="structured",   structured_ckpt=structured_ckpt,
    sample_image_index=0,   
    runs=200,               
    out_dir='model_summary'
)
print(json.dumps(summary, indent=2))

print("Baseline ckpt size:", model_checkpoint_size_bytes(baseline_ckpt))
print("Unstructured ckpt size:", model_checkpoint_size_bytes(unstructured_ckpt))
print("Structured ckpt size:", model_checkpoint_size_bytes(structured_ckpt))

ck = torch.load(unstructured_ckpt, map_location='cpu')
sd = ck.get('model_state_dict', ck)
print("Unstructured computed sparsity:", compute_sparsity_from_state_dict(sd))

In [None]:
readings = [
    (0.0, "baseline_vgg11.pth"),
    (0.3, "unstructured_outputs/ckpt_s30.pth"),
    (0.5, "unstructured_outputs/ckpt_s50.pth"),
    (0.7, "unstructured_outputs/ckpt_s70.pth"),
    (0.9, "unstructured_outputs/ckpt_s90.pth"),
]

notes = accuracy_vs_sparsity_plot(readings, out_path='acc_vs_spars.png')
print("Readings (sparsity_pct, acc):", notes)


img, _ = CIFAR10_testset[0]
inp = img.unsqueeze(0)
m = load_vgg_from_checkpoint(unstructured_ckpt)  # or pruned model object
mean_ms, std_ms, all_times = measure_inference_time(m, inp, runs=200, warmup=20)
print("Mean ms:", mean_ms, "std ms:", std_ms)

cam, pred = compute_gradcam(m, inp, target_layer_name=None, class_idx=None, device=device)
overlay_and_save(img, cam, out_path='single_gradcam_unstructured.png')


In [None]:
import math
from collections import OrderedDict
from copy import deepcopy

def generate_unstructured_pruned_checkpoints(baseline_ckpt, sparsity_list, out_dir='unstructured_outputs/generated', device='cpu'):

    os.makedirs(out_dir, exist_ok=True)
    ck = torch.load(baseline_ckpt, map_location='cpu')
    state = ck.get('model_state_dict', ck) if isinstance(ck, dict) else ck
    prunable_items = [(k, v.clone()) for k, v in state.items() if k.endswith('.weight')]
    if len(prunable_items) == 0:
        raise RuntimeError("No weight tensors found in checkpoint. Check state dict keys.")

    flat_vals = []
    mapping = []  
    offset = 0
    for k, v in prunable_items:
        arr = v.detach().cpu().abs().view(-1)
        L = arr.numel()
        flat_vals.append(arr)
        mapping.append((k, offset, L))
        offset += L
    flat_all = torch.cat(flat_vals)  # 1D
    N_total = flat_all.numel()
    flat_all_np = flat_all.numpy()

    def make_sparsity_checkpoint(target_sparsity, fname_suffix):
        s = float(target_sparsity)
        if s > 1.0: s = s/100.0 if s>1.0 and s<=100.0 else s
        target_pruned = int(math.floor(s * N_total))
        if target_pruned <= 0:
            new_state = {k: v.clone().cpu() for k,v in state.items()}
            save_path = os.path.join(out_dir, f"ckpt_us{int(round(s*100)):02d}.pth")
            torch.save({'model_state_dict': new_state, 'global_sparsity': 0.0}, save_path)
            return save_path, 0.0

        thresh = np.partition(flat_all_np, target_pruned-1)[target_pruned-1] if target_pruned>0 else 0.0
        new_state = {}
        for k, v in state.items():
            if k.endswith('.weight'):
                w = v.clone().cpu()
                mask = w.abs() > float(thresh)  
                w = w * mask
                new_state[k] = w
            else:
                new_state[k] = v.clone().cpu()
        tot = 0
        nz = 0
        for k,v in new_state.items():
            if k.endswith('.weight'):
                arr = v.numpy()
                tot += arr.size
                nz += np.count_nonzero(arr)
        achieved = 1.0 - (nz / tot) if tot>0 else 0.0
        save_path = os.path.join(out_dir, f"ckpt_s{int(round(s*100)):02d}.pth")
        torch.save({'model_state_dict': new_state, 'global_sparsity': achieved}, save_path)
        return save_path, achieved

    results = []
    for s in sparsity_list:
        path, achieved = make_sparsity_checkpoint(s, None)
        print(f"Saved unstructured ckpt: {path} (achieved sparsity {achieved:.4f})")
        results.append((s, path, achieved))
    return results

generate_unstructured_pruned_checkpoints("baseline_vgg11.pth", sparsity_list=[0.3, 0.5, 0.7])


Saved unstructured ckpt: unstructured_outputs/generated/ckpt_s30.pth (achieved sparsity 0.3000)
Saved unstructured ckpt: unstructured_outputs/generated/ckpt_s50.pth (achieved sparsity 0.5000)
Saved unstructured ckpt: unstructured_outputs/generated/ckpt_s70.pth (achieved sparsity 0.7000)


[(0.3, 'unstructured_outputs/generated/ckpt_s30.pth', 0.29999999844719416),
 (0.5, 'unstructured_outputs/generated/ckpt_s50.pth', 0.5000000077640292),
 (0.7, 'unstructured_outputs/generated/ckpt_s70.pth', 0.6999999937887766)]

In [None]:
def generate_structured_pruned_checkpoints(baseline_ckpt, sparsity_list, out_dir='structured_final/generated', device='cpu'):
  
    os.makedirs(out_dir, exist_ok=True)
    ck = torch.load(baseline_ckpt, map_location='cpu')
    state = ck.get('model_state_dict', ck) if isinstance(ck, dict) else ck
    try:
        vgg_model.load_state_dict(state)
    except Exception:
        vgg_model.load_state_dict(state, strict=False)
    vgg_model.to(device)
    vgg_model.eval()

    convs = collect_convs_from_global_vgg() 

    saved = []
    for s in sparsity_list:
        sp = float(s)
        if sp>1.0: sp = sp/100.0
        keep_frac = max(0.0, 1.0 - sp)
        kept_map = {}
        for name, conv in convs:
            kept_map[name] = choose_kept_channels(conv, keep_frac)
        pruned_model, kept_used = build_pruned_vgg(kept_map)
        pruned_model.to('cpu')
        orig_params = sum(p.numel() for p in vgg_model.parameters())
        new_params  = sum(p.numel() for p in pruned_model.parameters())
        global_spars = 1.0 - (new_params / orig_params)
        pruned_state = {k: v.cpu() for k,v in pruned_model.state_dict().items()}
        save_path = os.path.join(out_dir, f"ckpt_s{int(round(sp*100)):02d}.pth")
        torch.save({
            'model_state_dict': pruned_state,
            'kept_indices': kept_used,
            'structured_per_layer_sparsities': {k: float(1.0 - len(v)/get_conv_out_channels(k)) for k,v in kept_used.items()} if isinstance(kept_used, dict) else None,
            'global_sparsity': float(global_spars)
        }, save_path)
        print(f"Saved structured ckpt: {save_path} (global_sparsity={global_spars:.4f})")
        saved.append((s, save_path, global_spars))
    return saved

def get_conv_out_channels(qname):
    mod = get_module_by_qname(vgg_model, qname)
    if isinstance(mod, nn.Conv2d):
        return mod.out_channels
    return None

generate_structured_pruned_checkpoints("baseline_vgg11.pth", sparsity_list=[0.3,0.5,0.7])


Saved structured ckpt: structured_final/generated/ckpt_s30.pth (global_sparsity=0.2765)
Saved structured ckpt: structured_final/generated/ckpt_s50.pth (global_sparsity=0.4526)
Saved structured ckpt: structured_final/generated/ckpt_s70.pth (global_sparsity=0.6229)


[(0.3, 'structured_final/generated/ckpt_s30.pth', 0.2765286697805909),
 (0.5, 'structured_final/generated/ckpt_s50.pth', 0.45257312529708804),
 (0.7, 'structured_final/generated/ckpt_s70.pth', 0.6229331927469015)]

In [96]:
readings_us = [
    (0.0, "baseline_vgg11.pth"),
    (0.3, "unstructured_outputs/generated/ckpt_s30.pth"),
    (0.5, "unstructured_outputs/generated/ckpt_s50.pth"),
    (0.7, "unstructured_outputs/generated/ckpt_s70.pth"),
]

readings_s = [
    (0.0, "baseline_vgg11.pth"),
    (0.3, "structured_final/generated/ckpt_s30.pth"),
    (0.5, "structured_final/generated/ckpt_s50.pth"),
    (0.7, "structured_final/generated/ckpt_s70.pth"),
]

notes = accuracy_vs_sparsity_plot(readings_us, out_path='acc_vs_spars_us.png')
notes = accuracy_vs_sparsity_plot(readings_s, out_path='acc_vs_spars_struct.png')

Loaded VGG from checkpoint 'baseline_vgg11.pth' with classifier out_features = 10
  -> accuracy: 80.00%
Loaded VGG from checkpoint 'unstructured_outputs/generated/ckpt_s30.pth' with classifier out_features = 10
  -> accuracy: 79.96%
Loaded VGG from checkpoint 'unstructured_outputs/generated/ckpt_s50.pth' with classifier out_features = 10
  -> accuracy: 79.62%
Loaded VGG from checkpoint 'unstructured_outputs/generated/ckpt_s70.pth' with classifier out_features = 10
  -> accuracy: 78.36%
Saved accuracy vs sparsity plot to acc_vs_spars_us.png
Loaded VGG from checkpoint 'baseline_vgg11.pth' with classifier out_features = 10
  -> accuracy: 80.00%
  -> accuracy: 11.54%
  -> accuracy: 10.72%
  -> accuracy: 10.72%
Saved accuracy vs sparsity plot to acc_vs_spars_struct.png
