### ME3:

In [38]:
# Standard library imports
import time
import json
from collections import defaultdict
from functools import wraps
from pathlib import Path
from typing import Tuple

# Third-party imports
import lightning as L
import torch
from einops import einsum
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.models import ResNet18_Weights, resnet18, ResNet
from torch import nn
from torchmetrics.classification import Accuracy
from torch.nn.utils import prune
from tqdm.auto import tqdm

# set precision for matmuls
torch.set_float32_matmul_precision("medium")   # or "medium"

In [2]:
weights_and_biases = ResNet18_Weights.DEFAULT.get_state_dict()
print(weights_and_biases.keys())

odict_keys(['conv1.weight', 'bn1.running_mean', 'bn1.running_var', 'bn1.weight', 'bn1.bias', 'layer1.0.conv1.weight', 'layer1.0.bn1.running_mean', 'layer1.0.bn1.running_var', 'layer1.0.bn1.weight', 'layer1.0.bn1.bias', 'layer1.0.conv2.weight', 'layer1.0.bn2.running_mean', 'layer1.0.bn2.running_var', 'layer1.0.bn2.weight', 'layer1.0.bn2.bias', 'layer1.1.conv1.weight', 'layer1.1.bn1.running_mean', 'layer1.1.bn1.running_var', 'layer1.1.bn1.weight', 'layer1.1.bn1.bias', 'layer1.1.conv2.weight', 'layer1.1.bn2.running_mean', 'layer1.1.bn2.running_var', 'layer1.1.bn2.weight', 'layer1.1.bn2.bias', 'layer2.0.conv1.weight', 'layer2.0.bn1.running_mean', 'layer2.0.bn1.running_var', 'layer2.0.bn1.weight', 'layer2.0.bn1.bias', 'layer2.0.conv2.weight', 'layer2.0.bn2.running_mean', 'layer2.0.bn2.running_var', 'layer2.0.bn2.weight', 'layer2.0.bn2.bias', 'layer2.0.downsample.0.weight', 'layer2.0.downsample.1.running_mean', 'layer2.0.downsample.1.running_var', 'layer2.0.downsample.1.weight', 'layer2.0.do

In [3]:
weights_and_biases['fc.weight']

Parameter containing:
tensor([[-0.0185, -0.0705, -0.0518,  ..., -0.0390,  0.1735, -0.0410],
        [-0.0818, -0.0944,  0.0174,  ...,  0.2028, -0.0248,  0.0372],
        [-0.0332, -0.0566, -0.0242,  ..., -0.0344, -0.0227,  0.0197],
        ...,
        [-0.0103,  0.0033, -0.0359,  ..., -0.0279, -0.0115,  0.0128],
        [-0.0359, -0.0353, -0.0296,  ..., -0.0330, -0.0110, -0.0513],
        [ 0.0021, -0.0248, -0.0829,  ...,  0.0417, -0.0500,  0.0663]],
       requires_grad=True)

#### Load the Datasets

In [4]:
transform = ResNet18_Weights.IMAGENET1K_V1.transforms()

# Build the full training set once, then carve out a validation split
full_train = CIFAR10(
    root="data/CIFAR10",
    train=True,
    transform=transform,
    download=True,
)

generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_train, [45000, 5000], generator=generator)

cifar10_test = CIFAR10(
    root="data/CIFAR10",
    train=False,
    transform=transform,
    download=True,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
    drop_last=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

test_loader = DataLoader(
    cifar10_test,
    batch_size=256,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)



#### Initialize the model and replace head

In [3]:
from typing import Tuple
import torch
import torch.nn as nn
import pytorch_lightning as L
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.models import resnet18, ResNet18_Weights
from torchmetrics.classification import Accuracy

class LitResnet18(L.LightningModule):
    def __init__(self, num_classes: int = 10, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        self.model.fc = nn.Linear(512, num_classes)

        # freeze backbone
        for name, p in self.model.named_parameters():
            if not name.startswith("fc"):
                p.requires_grad = False

        self.loss = nn.CrossEntropyLoss()
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch, batch_idx) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        loss, preds, y = self._shared_step(batch, batch_idx)
        self.train_acc.update(preds, y)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, y = self._shared_step(batch, batch_idx)
        self.val_acc.update(preds, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val_acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        loss, preds, y = self._shared_step(batch, batch_idx)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

    def configure_optimizers(self):
        trainable = (p for p in self.parameters() if p.requires_grad)
        optimizer = torch.optim.SGD(
            trainable, lr=self.hparams.lr, momentum=0.9, weight_decay=1e-4
        )
        scheduler = MultiStepLR(optimizer, milestones=[30, 60], gamma=0.1)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",
                "frequency": 1,
            }
        }


In [4]:
class CIFAR10DataModule(L.LightningDataModule):
    def __init__(
            self, 
            data_dir: str = "data/CIFAR10", 
            batch_size: int = 256, 
            num_workers: int = 0
        ):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = ResNet18_Weights.IMAGENET1K_V1.transforms()

    def setup(self, stage: str | None = None):
        self.generator = torch.Generator().manual_seed(42)

        if stage in (None, 'fit', 'validate'):
            cifar10_full = CIFAR10(
                root=self.data_dir,
                train=True,
                transform=self.transform,
                download=True,
            )
            self.cifar10_train, self.cifar10_val = random_split(
                cifar10_full, [45000, 5000], generator=self.generator
            )
        
        if stage in (None, "test", "predict"):
            self.cifar10_test = CIFAR10(
                root=self.data_dir,
                train=False,
                transform=self.transform,
                download=True,
            )

    def train_dataloader(self):
        return DataLoader(
            self.cifar10_train,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=False,
            pin_memory=True,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.cifar10_val,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=0,
            pin_memory=True,
            persistent_workers=False,
        )

    def test_dataloader(self):
        return DataLoader(
            self.cifar10_test,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=self.num_workers > 0,
        )


In [None]:
ckpt_path = Path("resnet18_finetuned.ckpt")
if ckpt_path.exists():
    model = LitResnet18.load_from_checkpoint(ckpt_path)
else:
    model = LitResnet18()
    cifar10_dm = CIFAR10DataModule(num_workers=0)

    trainer = L.Trainer(max_epochs=90)
    trainer.fit(model, datamodule=cifar10_dm)
    trainer.save_checkpoint(ckpt_path)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | ResNet             | 11.2 M | train
1 | loss      | CrossEntropyLoss   | 0      | train
2 | val_acc   | MulticlassAccuracy | 0      | train
3 | test_acc  | MulticlassAccuracy | 0      | train
4 | train_acc | MulticlassAccuracy | 0      | train
---------------------------------------------------------
5.1 K     Trainable params
11.2 M    Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)
72        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/aleisley/Documents/ai231/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


                                                                           

/home/aleisley/Documents/ai231/.venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Epoch 89: 100%|██████████| 175/175 [01:55<00:00,  1.52it/s, v_num=10, train_loss_step=0.558, val_loss=0.639, val_acc=0.778, train_loss_epoch=0.606, train_acc=0.793]

`Trainer.fit` stopped: `max_epochs=90` reached.


Epoch 89: 100%|██████████| 175/175 [01:55<00:00,  1.52it/s, v_num=10, train_loss_step=0.558, val_loss=0.639, val_acc=0.778, train_loss_epoch=0.606, train_acc=0.793]


Yes—around 60 cycles is a sensible upper bound if you keep pruning 5 % of the remaining head weights per cycle.

Here’s why:

With s = 0.05 per cycle, total sparsity after n cycles is

1−0.95n
1−0.95
n

So:

20 cycles → ~64 % total sparsity

40 cycles → ~87 % total sparsity

60 cycles → ~95 % total sparsity

95 % total sparsity is right at the range where most ResNet-18 heads can no longer recover validation accuracy.
Beyond that, recovery is usually negligible and training time is wasted.

Going much past 60 cycles gives very small additional pruning (e.g., 80 cycles only brings you to ~98 % sparsity) and very low accuracy.

#### Prune cycle

In [39]:

num_cycles = 30
prune_amount = 0.10 # 95% sparsity will be reached in 30 cycles

layer = model.model.fc
results_path = Path('artifacts/pruning_results.json')
results = json.loads(results_path.read_text()) if results_path.exists() else []
completed = {entry['cycle'] for entry in results}

for cycle in range(1, 31):
    if cycle in completed:
        continue
    prune.l1_unstructured(layer, name="weight", amount=prune_amount)
    prune_trainer = L.Trainer(max_epochs=5)
    prune_trainer.fit(model, datamodule=cifar10_dm)
    acc = prune_trainer.callback_metrics["val_acc"].item()
    results.append({"cycle": cycle, "val_acc": acc})
    prune_trainer.save_checkpoint(f"artifacts/cycle_{cycle:.2f}.ckpt")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/aleisley/Documents/ai231/.venv/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


TypeError: `model` must be a `LightningModule` or `torch._dynamo.OptimizedModule`, got `LitResnet18`

In [33]:
trainer.save_checkpoint(f'data/resnet18_finetuned2.ckpt')

In [37]:
1- 0.9**30

0.9576088417247838

In [8]:
trainer.callback_metrics

{'train_loss': tensor(0.8721),
 'train_loss_step': tensor(0.7433),
 'val_loss': tensor(0.8424),
 'val_acc': tensor(0.7294),
 'train_loss_epoch': tensor(0.8721),
 'train_acc': tensor(0.7277)}

In [9]:
trainer.callback_metrics["val_acc"].item()

0.7293999791145325

In [None]:
for cycle in range(num_cycles):
    trainer.fit(model, datamodule)
    acc = trainer.callback_metrics["val_acc"].item()
    results.append({"cycle": cycle, "val_acc": acc})