In [None]:
import csv
from typing import Callable, Optional

import pytorch_lightning as pl
import torch
import torch.nn.utils.prune as prune
import torchmetrics
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks import ModelPruning
from pytorch_lightning.loggers.base import LightningLoggerBase
from torch import nn

from src.dataset import CIFAR10DataModule
from src.models import ModelFactory
from src.prune_scheduler import AgpPruningRate
from src.training_module import TrainingModule

In [None]:
class SparsityLogger(LightningLoggerBase):
    """Minimal Lightning logger that writes per-layer sparsity rows to a CSV file.

    The file at `file_path` is (re)created with a header row on construction;
    every `log` call appends one `(epoch, layer, sparsity)` row.

    Args:
        file_path: path of the CSV file to create / append to.
        header: column names written as the first row.
    """

    def __init__(self, file_path, header=('epoch', 'layer', 'sparsity')):
        # The original referenced `super().__init__` without calling it, so the
        # base logger state was never initialised.
        super().__init__()
        self.file_path = file_path

        # newline='' is what the csv module requires; without it the writer
        # emits blank lines between rows on Windows.
        with open(self.file_path, 'w', newline='') as f:
            csv.writer(f).writerow(header)

    # --- LightningLoggerBase abstract interface -------------------------------
    # The base class declares these as abstract; without them the logger cannot
    # be instantiated at all.

    @property
    def name(self):
        return "sparsity"

    @property
    def version(self):
        return 0

    @property
    def experiment(self):
        # No backing experiment object: rows are written straight to `file_path`.
        return None

    def log_hyperparams(self, params, *args, **kwargs):
        # Hyperparameters are intentionally not recorded in the sparsity CSV.
        pass

    def log_metrics(self, metrics, step=None):
        # Lightning sends every logged metric here (loss, accuracy, ...); only
        # dicts that carry a full sparsity row are written, others are ignored
        # instead of raising KeyError inside `log`.
        if all(key in metrics for key in ("epoch", "layer", "sparsity")):
            self.log(metrics)

    # --- sparsity rows ----------------------------------------------------------

    def log(self, metrics):
        """Append one row built from `metrics["epoch"]`, `["layer"]`, `["sparsity"]`."""
        fields = [metrics["epoch"], metrics["layer"], metrics["sparsity"]]
        with open(self.file_path, 'a', newline='') as f:
            csv.writer(f).writerow(fields)


In [None]:
class PruningTrainingModule(TrainingModule):
    """TrainingModule that also carries an AGP (gradual magnitude) pruning schedule.

    The schedule attributes (`freq`, `prune_start`, `prune_end`, `prune_sch`,
    `prune_layers`) are read by pruning callbacks and by `OG_Pruning_Scheduler`.
    The module's own `on_train_epoch_start` hook only prunes when
    `prune_in_module=True`. It is off by default so that attaching an external
    pruning callback does not prune every layer, and advance the schedule, twice
    per epoch.

    Args:
        model_name: name passed to `create_model`.
        image_size, num_classes, lr, momentum, epochs, weight_decay, mixup:
            stored as attributes of the same name; `epochs` also sets `prune_end`.
        pre_trained: passed to `create_model`.
        prune_in_module: apply the pruning schedule from this module's
            `on_train_epoch_start` hook instead of from a callback.
    """

    def __init__(
        self,
        model_name,
        image_size,
        num_classes,
        lr,
        momentum,
        epochs,
        weight_decay,
        mixup,
        pre_trained=False,
        prune_in_module=False,
    ):
        # Deliberately skips TrainingModule.__init__ (calls the next class in the
        # MRO) and re-creates its state by hand below.
        # NOTE(review): confirm this stays in sync with TrainingModule.__init__.
        super(TrainingModule, self).__init__()
        self.lr = lr
        self.image_size = image_size
        self.num_classes = num_classes
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.epochs = epochs
        self.mixup = mixup
        self._model = self.create_model(model_name=model_name, pre_trained=pre_trained)
        self._loss = nn.CrossEntropyLoss()
        acc = torchmetrics.Accuracy()
        self.val_acc = acc.clone()
        self.train_acc = acc.clone()

        self.prune_in_module = prune_in_module
        self.freq = 1
        # First epoch that prunes; presumably the 3rd AgpPruningRate argument is
        # this start epoch (TODO confirm against src.prune_scheduler).
        self.prune_start = 1
        # Stop increasing sparsity after 75% of training. Derived from `epochs`
        # rather than a hard-coded 20 so the schedule follows the run length.
        self.prune_end = int(epochs * 0.75)
        self.prune_sch = AgpPruningRate(.05, .50, self.prune_start, self.prune_end, self.freq)
        # Every module of the wrapped model except the last one yielded by
        # modules() (presumably the classifier head — confirm for the model used).
        self.prune_layers = [module for module in self._model.modules()][:-1]

    def _should_prune(self, epoch: int) -> bool:
        """True when `epoch` is a scheduled pruning step in [prune_start, prune_end]."""
        # The original `epoch % freq == 1` is never true for freq == 1, so no
        # pruning step ever ran; align steps to `prune_start` instead.
        return (
            self.prune_start <= epoch <= self.prune_end
            and (epoch - self.prune_start) % self.freq == 0
        )

    def _prunable_layers(self):
        """Return (index, layer) for each Conv2d/Linear entry of `prune_layers`."""
        return [
            (i, layer)
            for i, layer in enumerate(self.prune_layers)
            if isinstance(layer, (nn.Conv2d, nn.Linear))
        ]

    def on_train_epoch_start(self) -> None:
        """Optionally apply the AGP step for the current epoch (see `prune_in_module`)."""
        if not self.prune_in_module:
            return
        epoch = self.current_epoch
        if self._should_prune(epoch):
            target = float(self.prune_sch.step(epoch))
            print(f'pruning {target * 100}% sparsity')
            for i, layer in self._prunable_layers():
                # Bake the previous mask into the weights first. Otherwise a second
                # l1_unstructured call prunes `target` of the *remaining* weights on
                # top of the old mask instead of reaching `target` overall sparsity.
                if hasattr(layer, "weight_orig"):
                    prune.remove(layer, "weight")
                prune.l1_unstructured(layer, name='weight', amount=target)
                layer_spar = float(torch.sum(layer.weight == 0))
                layer_spar /= float(layer.weight.nelement())
                print(f"Sparsity in layer {i} {type(layer)} {layer_spar: 3f}")
        elif epoch > self.prune_end:
            print("All done pruning")

    def OG_Pruning_Scheduler(self, epoch: Optional[int] = None) -> float:
        """Return the AGP target sparsity due at `epoch`, or 0 when no step is due.

        Args:
            epoch: epoch to query; defaults to `self.current_epoch`. Accepting it
                lets the bound method be used directly as a callable `amount`
                that receives the epoch.

        Returns:
            The scheduler's target sparsity as a float, or 0 off-schedule.
        """
        if epoch is None:
            epoch = self.current_epoch
        if self._should_prune(epoch):
            return float(self.prune_sch.step(epoch))
        return 0



In [None]:
class OG_Pruning_Callback(Callback):
    """Apply the module's AGP schedule with torch.nn.utils.prune each epoch.

    Reads `freq`, `prune_end`, `prune_sch`, `prune_layers` (and, if present,
    `prune_start`) from the LightningModule. Only nn.Conv2d / nn.Linear weights
    are pruned. At the end of training the pruning is made permanent and the
    weights are saved to 'og_pruning_weights.pth'.
    """

    @staticmethod
    def _prunable(pl_module):
        """Yield (index, layer) for each Conv2d/Linear entry of `pl_module.prune_layers`."""
        for i, layer in enumerate(pl_module.prune_layers):
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                yield i, layer

    @staticmethod
    def _is_weight_pruned(layer) -> bool:
        # prune.* replaces `weight` by `weight_orig` * `weight_mask`; `weight_orig`
        # exists exactly while a weight pruning reparametrisation is attached.
        return hasattr(layer, "weight_orig")

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        epoch = trainer.current_epoch
        start = getattr(pl_module, "prune_start", 1)
        # The original `epoch % freq == 1` is never true for freq == 1, so nothing
        # was ever pruned; align steps to the schedule's start epoch instead.
        due = start <= epoch <= pl_module.prune_end and (epoch - start) % pl_module.freq == 0
        if due:
            target = float(pl_module.prune_sch.step(epoch))
            print(f'pruning {target * 100}% sparsity')
            for i, layer in self._prunable(pl_module):
                # Make the previous mask permanent before re-pruning, so this step
                # reaches `target` overall sparsity. The old epoch-range check skipped
                # this at epoch == prune_end, which stacked masks and over-pruned.
                if self._is_weight_pruned(layer):
                    prune.remove(layer, "weight")
                prune.l1_unstructured(layer, name='weight', amount=target)
                layer_spar = float(torch.sum(layer.weight == 0))
                layer_spar /= float(layer.weight.nelement())
                print(f"Sparsity in layer {i} {type(layer)} {layer_spar: 3f}")
        elif epoch > pl_module.prune_end:
            print("All done pruning")

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Fold masks into the weights so the saved state_dict has plain `weight`
        # tensors (loadable into an unpruned model) instead of
        # `weight_orig`/`weight_mask` pairs.
        for _, layer in self._prunable(pl_module):
            if self._is_weight_pruned(layer):
                prune.remove(layer, "weight")
        torch.save(trainer.model.state_dict(), 'og_pruning_weights.pth')

In [None]:
# Module for the hand-written ("OG") pruning run.
training_module = PruningTrainingModule(
    # architecture / data
    model_name='resnet34',
    pre_trained=False,
    num_classes=10,
    # NOTE(review): image_size=1 looks unusual for CIFAR-10 (32x32) — confirm
    # what TrainingModule expects here.
    image_size=1,
    # optimisation
    lr=0.01,
    # NOTE(review): momentum=0.005 is almost no momentum — possibly meant 0.9?
    momentum=0.005,
    weight_decay=1e-5,
    mixup=False,
    epochs=20,
)

In [None]:
# CIFAR-10 data shared by both pruning runs.
dm = CIFAR10DataModule(
    data_dir='data/',
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Train with the hand-written pruning callback; the epoch budget is taken from
# the module so the trainer and the pruning schedule agree.
og_callbacks = [OG_Pruning_Callback()]
og_trainer = pl.Trainer(
    gpus=1,
    max_epochs=training_module.epochs,
    callbacks=og_callbacks,
)

og_trainer.fit(training_module, dm)

In [None]:
class PL_Pruning_Callback(Callback):
    """Save the final weights of the ModelPruning run to 'pl_pruning_weights.pth'."""

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        torch.save(trainer.model.state_dict(), 'pl_pruning_weights.pth')

pl_training_module = PruningTrainingModule(
    model_name='resnet34',
    image_size=1,
    num_classes=10,
    pre_trained=False,
    lr=0.01,
    epochs=20,
    mixup=False,
    momentum=0.005,
    weight_decay=1e-5
)

# Prune the same layer set as the OG run (Conv2d/Linear entries of prune_layers),
# instead of every module that has a `weight`, so the two runs are comparable.
pl_parameters_to_prune = [
    (layer, "weight")
    for layer in pl_training_module.prune_layers
    if isinstance(layer, (nn.Conv2d, nn.Linear))
]

pl_pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    parameters_to_prune=pl_parameters_to_prune,
    parameter_names=["weight"],
    prune_on_train_epoch_end=False,
    make_pruning_permanent=True,
    # Must be a callable: the original passed OG_Pruning_Scheduler()'s value,
    # computed once here at epoch 0 (amount 0), so nothing was ever pruned.
    # ModelPruning calls amount(epoch) during training, when the module's
    # current_epoch equals that epoch.
    # NOTE(review): ModelPruning prunes `amount` of the *remaining* weights on
    # each call, while the AGP schedule yields cumulative targets, so this run
    # likely over-prunes relative to the OG run — confirm/convert if comparing.
    amount=lambda epoch: pl_training_module.OG_Pruning_Scheduler(),
    apply_pruning=True,
    verbose=2,
)

# ModelPruning goes first: Lightning runs callback hooks in list order, and
# ModelPruning makes the pruning permanent in on_train_end. With the saver first,
# the checkpoint held weight_orig/weight_mask pairs instead of plain weights.
pl_pruning_trainer = pl.Trainer(
    gpus=1,
    max_epochs=pl_training_module.epochs,
    callbacks=[pl_pruning, PL_Pruning_Callback()],
)

pl_pruning_trainer.fit(pl_training_module, dm)