In [None]:
# Expose only GPU 0 to this kernel. This cell runs before the cell that imports torch.
%env CUDA_VISIBLE_DEVICES=0

In [None]:
import pytorch_lightning as pl
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy
from torchvision.models import mobilenet_v2
from pathlib import Path

import sys
sys.path.append('../')
from tutorial_utils.dataset import create_imagenette_dataloaders

# Train baseline model

In [None]:
class LitModel(pl.LightningModule):
    """MobileNetV2 (10 output classes) wrapped as a LightningModule.

    The module owns its dataloaders: they are built once in ``__init__`` via
    ``create_imagenette_dataloaders`` and served through the ``*_dataloader``
    hooks, so ``Trainer.fit(model)`` needs no extra arguments. Validation, test
    and predict all read the same ``'tune_validation_dataloader'`` loader.
    """

    def __init__(self):
        super().__init__()
        self.model = mobilenet_v2(pretrained=False, num_classes=10)
        # A single Accuracy instance is shared by validation and test. It is
        # reset at the end of every evaluation epoch (see _log_epoch_summary),
        # so one evaluation run never leaks into the next.
        self.metric = Accuracy()

        HOME_DIR = Path('./')
        DATASETS_DIR = HOME_DIR / 'datasets'
        PROJECT_DIR = HOME_DIR / 'lightning_mobilenet_pruning'

        HOME_DIR.mkdir(exist_ok=True)
        DATASETS_DIR.mkdir(exist_ok=True)
        PROJECT_DIR.mkdir(exist_ok=True)

        # Indexed below by 'tune_train_dataloader' / 'tune_validation_dataloader'.
        self.dataloaders = create_imagenette_dataloaders(
            dataset_root_dir=DATASETS_DIR,
            project_dir=PROJECT_DIR,
            input_size=(224, 224),
            batch_size=64,
        )

    def forward(self, x):
        """Return the wrapped network's output for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Create RAdam (lr=1e-2) with a linear-warmup + cosine-annealing schedule.

        The scheduler is stepped once per optimisation step
        (``'interval': 'step'``), so its ``warmup_epochs`` / ``max_epochs``
        arguments are given in steps: 50 and 500 epochs' worth respectively.
        The optimizer and scheduler are also kept on ``self`` so the epoch-end
        hooks can log the current learning rate.
        """
        optimizer = torch.optim.RAdam(self.parameters(), lr=1e-2)
        self.optimizer = optimizer
        steps_per_epoch = len(self.train_dataloader())
        scheduler = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=50 * steps_per_epoch,
            max_epochs=500 * steps_per_epoch,
        )
        self.scheduler = scheduler

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1,
            }
        }

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one ``(images, labels)`` batch."""
        images, labels = batch
        pred = self(images)
        loss = F.cross_entropy(pred, labels)
        return loss

    def _evaluate_batch(self, batch, stage):
        """Log loss and batch accuracy for one evaluation batch.

        ``stage`` is the key prefix (``'val'`` or ``'test'``). Calling the
        metric both returns the accuracy of this batch and accumulates it into
        the metric's running state used by ``_log_epoch_summary``.
        """
        images, labels = batch
        pred = self(images)
        pred_labels = torch.argmax(pred, dim=1)
        loss = F.cross_entropy(pred, labels)
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage} batch accuracy", self.metric(pred_labels, labels), on_epoch=True)

    def _log_epoch_summary(self, stage):
        """Log the learning rate (if known) and the accuracy of the finished epoch."""
        # self.optimizer only exists once configure_optimizers has run.
        if hasattr(self, 'optimizer'):
            self.log("lr", self.optimizer.param_groups[0]['lr'])
        self.log(f"{stage} epoch accuracy", self.metric.compute())
        # Only values (not the Metric object) are passed to self.log, so
        # Lightning never resets the metric for us. Without this reset,
        # compute() would report accuracy accumulated over every previous
        # evaluation epoch, including the sanity check.
        self.metric.reset()

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and ``val batch accuracy`` for one validation batch."""
        self._evaluate_batch(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and ``test batch accuracy`` for one test batch."""
        self._evaluate_batch(batch, "test")

    def validation_epoch_end(self, validation_step_outputs):
        """Log ``val epoch accuracy`` for the finished epoch and reset the metric."""
        self._log_epoch_summary("val")

    def test_epoch_end(self, validation_step_outputs):
        """Log ``test epoch accuracy`` for the finished test run and reset the metric."""
        self._log_epoch_summary("test")

    def train_dataloader(self):
        """Return the ``'tune_train_dataloader'`` loader."""
        return self.dataloaders['tune_train_dataloader']

    def val_dataloader(self):
        """Return the ``'tune_validation_dataloader'`` loader."""
        return self.dataloaders['tune_validation_dataloader']

    def test_dataloader(self):
        """Return the validation loader; it doubles as the test set here."""
        return self.dataloaders['tune_validation_dataloader']

    def predict_dataloader(self):
        """Return the validation loader; it doubles as the predict set here."""
        return self.dataloaders['tune_validation_dataloader']

In [None]:
# Train the baseline from scratch. max_epochs=500 matches the 500-epoch horizon
# hard-coded in the LR schedule of LitModel.configure_optimizers.
model = LitModel()
trainer = pl.Trainer(max_epochs=500, accelerator='gpu', devices=1)
# No dataloaders are passed here: LitModel provides them through its *_dataloader hooks.
trainer.fit(model=model)

# For pruning, we need to accumulate gradients on the prunable model

In [None]:
from enot.pruning import EnotPruningCalibrator
from enot.pruning import prune_model
from enot.pruning import get_labels_for_equal_pruning
import numpy as np

In [None]:
# Ratio handed to get_labels_for_equal_pruning in a later cell.
pruning_ratio = 0.45  # This gives about a 3x FLOPs reduction.

In [None]:
lit_model = LitModel()
# NOTE: while the two lines below stay commented out, lit_model keeps the randomly
# initialised weights created by LitModel() (mobilenet_v2 with pretrained=False).
# Before pruning, uncomment them and point the path at the checkpoint produced by
# your own baseline run. The original author's checkpoint reached 93.27% accuracy.
# model_state = torch.load('lightning_logs/version_0/checkpoints/epoch=499-step=59000.ckpt')['state_dict']
# lit_model.load_state_dict(model_state)
lit_model.cuda();

In [None]:
# Evaluate the model before pruning as a reference point.
# LitModel.test_dataloader returns the 'tune_validation_dataloader' loader.
trainer = pl.Trainer(accelerator='gpu', devices=1)
trainer.test(lit_model, dataloaders=lit_model.test_dataloader())

In [None]:
# Calibration: run forward/backward passes over the training loader while the
# calibrator context is active, then read the collected pruning_info.
# No optimizer step is taken, so the weights themselves are not updated here.
lit_model.cuda();
pruning_calibrator = EnotPruningCalibrator(model=lit_model)
with pruning_calibrator:
    for batch_idx, (images, labels) in enumerate(lit_model.train_dataloader()):
        # training_step is called directly (no Trainer), so nothing moves the
        # batch to the model's device for us. Do it explicitly; this is a no-op
        # when the loader already yields tensors on that device.
        batch = images.to(lit_model.device), labels.to(lit_model.device)
        loss = lit_model.training_step(batch, batch_idx)
        loss.backward()

pruning_info = pruning_calibrator.pruning_info

In [None]:
# Turn the calibration result into prune labels for the requested pruning_ratio.
# NOTE(review): judging by the function name, every prunable group is pruned by
# the same ratio — confirm against the ENOT documentation.
all_channel_indices_to_prune = get_labels_for_equal_pruning(pruning_info, pruning_ratio)

In [None]:
# Build the pruned network from the calibrated model and the selected labels.
# inplace=False is passed, presumably so that lit_model itself is left untouched;
# it is used again later in the notebook for the MAC comparison.
pruned_model = prune_model(
    model=lit_model,
    pruning_info=pruning_info,
    prune_labels=all_channel_indices_to_prune,
    inplace=False,
)

In [None]:
# Save the whole inner network object (not just a state_dict), presumably because
# the pruned layers no longer match a freshly constructed mobilenet_v2.
torch.save(pruned_model.model, 'pruned_model.pth')

# Finetune pruned model

In [None]:
class TuneLitModel(LitModel):
    """LitModel variant used for fine-tuning: lower learning rate, shorter schedule."""

    def configure_optimizers(self):
        """Create RAdam (lr=1e-3) with warmup + cosine annealing, stepped per batch.

        The scheduler horizon is expressed in optimisation steps: 15 epochs'
        worth of warmup out of 150 epochs' worth in total. The optimizer and
        scheduler are also stored on ``self``.
        """
        iters_per_epoch = len(self.train_dataloader())

        self.optimizer = torch.optim.RAdam(self.parameters(), lr=1e-3)
        self.scheduler = LinearWarmupCosineAnnealingLR(
            self.optimizer,
            warmup_epochs=15 * iters_per_epoch,
            max_epochs=150 * iters_per_epoch,
        )

        # Step the scheduler after every optimisation step rather than every epoch.
        lr_scheduler_config = dict(scheduler=self.scheduler, interval='step', frequency=1)
        return dict(optimizer=self.optimizer, lr_scheduler=lr_scheduler_config)

In [None]:
# Wrap the pruned network in the fine-tuning module. TuneLitModel() first builds
# its own dataloaders and a full-size MobileNetV2; the second line then replaces
# that network with the pruned one saved earlier.
# Note that the name pruned_model is rebound here from the prune_model() result
# to the TuneLitModel wrapper.
# NOTE(review): torch.load unpickles arbitrary objects, so only load files you
# created yourself. Newer PyTorch releases may need weights_only=False to load a
# full module — confirm for your installed version.
pruned_model = TuneLitModel()
pruned_model.model = torch.load('pruned_model.pth')

In [None]:
# Fine-tune the pruned model. max_epochs=150 matches the 150-epoch horizon
# hard-coded in the LR schedule of TuneLitModel.configure_optimizers.
trainer = pl.Trainer(max_epochs=150, accelerator='gpu', devices=1)
trainer.fit(model=pruned_model)

In [None]:
# Evaluate the fine-tuned pruned model on the 'tune_validation_dataloader' loader,
# the same loader key used for the baseline test above.
trainer.test(model=pruned_model, dataloaders=pruned_model.test_dataloader())

# Measure latency difference in FLOPs

In [None]:
from enot.latency import MacCalculatorPthflops

In [None]:
# Cost of the original network for a single 1x3x224x224 input, computed on CPU.
lit_model.cpu()
MacCalculatorPthflops().calculate(lit_model.model, torch.ones((1,3,224,224)))

In [None]:
# Same measurement for the pruned, fine-tuned network; compare with the cell above.
pruned_model.cpu()
MacCalculatorPthflops().calculate(pruned_model.model, torch.ones((1,3,224,224)))