In [1]:
from src.dataset import CIFAR10DataModule
from src.models import ModelFactory
from pytorch_lightning.callbacks import ModelPruning
from src.training_module import TrainingModule
from pytorch_lightning.callbacks import Callback
import pytorch_lightning as pl
from src.prune_scheduler import AgpPruningRate
import torchmetrics
from torch import nn
import torch.nn.utils.prune as prune
import torch
from typing import Callable



In [2]:
class PruningTrainingModule(TrainingModule):
    def __init__(
        self, 
        model_name, 
        image_size, 
        num_classes, 
        lr, 
        momentum, 
        epochs,
        weight_decay,
        mixup,
        pre_trained=False,
    ):
        super(TrainingModule, self).__init__()
        self.lr = lr
        self.image_size = image_size
        self.num_classes = num_classes
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.epochs = epochs
        self.mixup = mixup
        self._model = self.create_model(model_name=model_name, pre_trained=pre_trained)
        self._loss = nn.CrossEntropyLoss()
        acc = torchmetrics.Accuracy()
        self.val_acc = acc.clone()
        self.train_acc = acc.clone()

        self.freq = 1
        self.prune_end = int(20 * 0.75)
        self.prune_sch = AgpPruningRate(.05, .50, 1, self.prune_end, self.freq)
        self.prune_layers = [module for module in self._model.modules()][:-1]

    def OG_Pruning_Scheduler(self) -> Callable:
        if self.current_epoch % self.freq == 1 and self.current_epoch <= self.prune_end:
            target = self.prune_sch.step(self.current_epoch)
            return target
        else:
            return 0
        

In [3]:
class OG_Pruning_Callback(Callback):    
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if trainer.current_epoch % pl_module.freq == 1 and trainer.current_epoch <= pl_module.prune_end:
            target = pl_module.prune_sch.step(trainer.current_epoch)
            print(target)
            print(f'pruning {target * 100}% sparsity')
            if trainer.current_epoch > 1 and trainer.current_epoch < pl_module.prune_end:
                for i, layer in enumerate(pl_module.prune_layers):
                    if type(layer) == nn.Conv2d or type(layer) == nn.Linear:
                        prune.remove(layer, "weight")
            for i, layer in enumerate(pl_module.prune_layers):
                if type(layer) == nn.Conv2d or type(layer) == nn.Linear:
                    prune.l1_unstructured(layer, name='weight', amount=float(target))
                    layer_spar = float(torch.sum(layer.weight == 0))
                    layer_spar /= float(layer.weight.nelement())
                    print(f"Sparsity in layer {i} {type(layer)} {layer_spar: 3f}")
        elif trainer.current_epoch > pl_module.prune_end:
            print("All done pruning")
    
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        torch.save(trainer.model.state_dict(), 'og_pruning_weights.pth')

In [4]:
training_module = PruningTrainingModule(
    model_name='resnet34',
    image_size=1,
    num_classes=10,
    pre_trained=False,
    lr=0.01,
    epochs=20,
    mixup=False,
    momentum=0.005,
    weight_decay=1e-5
        )

In [5]:
dm = CIFAR10DataModule(data_dir='data/', num_workers=4, pin_memory=True)

In [5]:
og_trainer = pl.Trainer(gpus=1, max_epochs=20, callbacks=[OG_Pruning_Callback()])

og_trainer.fit(training_module, dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | _model    | ResNet           | 21.3 M
1 | _loss     | CrossEntropyLoss | 0     
2 | val_acc   | Accuracy         | 0     
3 | train_acc | Accuracy         | 0     
-----------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.128    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

All done pruning


Validating: 0it [00:00, ?it/s]

All done pruning


Validating: 0it [00:00, ?it/s]

All done pruning


Validating: 0it [00:00, ?it/s]

All done pruning


Validating: 0it [00:00, ?it/s]

In [6]:
class PL_Pruning_Callback(Callback):
   def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        torch.save(trainer.model.state_dict(), 'pl_pruning_weights.pth')

pl_training_module = PruningTrainingModule(
    model_name='resnet34',
    image_size=1,
    num_classes=10,
    pre_trained=False,
    lr=0.01,
    epochs=20,
    mixup=False,
    momentum=0.005,
    weight_decay=1e-5
        )

pl_pruning_trainer = pl.Trainer(gpus=1, max_epochs=20, callbacks=[PL_Pruning_Callback(), ModelPruning(pruning_fn="l1_unstructured", parameter_names=["weight"], prune_on_train_epoch_end=False, make_pruning_permanent=True, amount=pl_training_module.OG_Pruning_Scheduler(), apply_pruning=True, verbose=2)])

pl_pruning_trainer.fit(pl_training_module ,dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


KeyboardInterrupt: 

In [7]:
#compare the weights
pl_weights = torch.load('pl_pruning_weights.pth')
og_weights = torch.load('og_pruning_weights.pth')

In [12]:
for key in pl_weights.keys():
    if 'weight' in key:
        print(pl_weights[key])


tensor([[[[-1.0641e-01, -5.0489e-03, -2.1420e-01],
          [-3.5454e-01, -3.2358e-01, -3.1218e-01],
          [ 6.3811e-01,  1.0212e+00,  9.6754e-01]],

         [[ 5.2443e-02,  1.3738e-01, -5.4519e-02],
          [-9.4720e-01, -1.5034e+00, -1.3832e+00],
          [-5.1260e-01, -8.9049e-01, -7.9765e-01]],

         [[ 1.3256e+00,  1.5431e+00,  1.6317e+00],
          [ 4.9313e-01,  1.1425e-01, -5.3128e-02],
          [ 1.9431e-01, -4.6809e-02,  4.5639e-02]]],


        [[[-1.1137e-01, -1.2503e-01,  7.7536e-02],
          [-1.4019e-01, -3.6544e-01, -2.5705e-01],
          [ 3.5468e-02, -8.5552e-02, -1.8104e-01]],

         [[-5.4145e-02,  3.1229e-02,  1.9665e-01],
          [-3.9104e-02, -1.5974e-01, -3.7913e-02],
          [ 1.6727e-01,  1.3050e-01,  1.8205e-01]],

         [[-3.9901e-01, -2.4419e-01,  7.2559e-02],
          [-4.1874e-01, -5.4164e-01, -2.7920e-01],
          [-2.2091e-02, -2.0859e-01, -2.1214e-01]]],


        [[[ 5.2291e-01,  3.7866e-01,  1.2178e-01],
          [-3.3

In [9]:

for key in og_weights.keys():
    if 'weight' in key:
        print(og_weights[key])

tensor([[[[-2.5636e-05,  7.6953e-06, -1.4899e-05],
          [-3.6582e-05,  3.4272e-06, -1.4877e-05],
          [-3.0094e-05, -2.8462e-06, -3.0818e-05]],

         [[ 1.9103e-06,  1.9386e-06, -1.6528e-07],
          [ 3.1970e-06,  3.7207e-06, -1.2150e-07],
          [ 4.7266e-06,  9.7194e-06,  1.1327e-06]],

         [[ 7.4194e-08,  8.9581e-07, -1.4759e-06],
          [-1.1106e-07,  1.3331e-06, -1.1651e-06],
          [ 1.8358e-06,  3.3447e-06, -1.9748e-06]]],


        [[[ 6.0838e-01,  7.1134e-01,  2.5318e-01],
          [-4.0637e-01, -3.8191e-03, -3.1392e-01],
          [-6.6857e-01, -6.2625e-02, -4.3601e-01]],

         [[-1.2321e+00, -1.7114e+00, -1.8158e+00],
          [-1.5384e+00, -9.9285e-01, -1.2118e+00],
          [-8.8245e-01, -6.6735e-02, -3.3815e-01]],

         [[ 1.0500e+00,  9.7117e-01,  2.2515e-01],
          [ 4.0887e-01,  6.0113e-01,  3.5747e-01],
          [ 1.6105e-01,  7.2194e-01,  4.8165e-01]]],


        [[[ 2.6409e-01, -8.3532e-01,  7.2853e-01],
          [ 8.6

In [13]:
for key in og_weights.keys():
    if 'weight' in key:
        print(torch.sum(torch.eq(og_weights[key], pl_weights[key])).item()/pl_weights[key].nelement())

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
3.3908420138888887e-06
0.0
0.0
0.0
1.6954210069444444e-06
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
3.3908420138888887e-06
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
8.477105034722222e-07
0.0
0.0
0.0
1.2715657552083333e-06
0.0
8.477105034722222e-07
0.0
1.6954210069444444e-06
0.0
1.6954210069444444e-06
0.0
0.0
