# PyTorch Global, Unstructured & Iterative Pruning:

Using _ResNet-18_ CNN trained from scratch on CIFAR-10 dataset.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.nn.utils.prune as prune

import sklearn
from sklearn.metrics import classification_report

import matplotlib.pyplot as plt
import numpy as np

import copy, pickle, os

In [2]:
print(f"PyTorch Version: {torch.__version__}")
print(f"Torchvision Version: {torchvision.__version__}")

PyTorch Version: 1.8.1+cu101
Torchvision Version: 0.9.1+cu101


In [3]:
# Device configuration-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"currently available device: {device}")

currently available device: cuda


In [4]:
# Hyper-parameters-
num_epochs = 100
batch_size = 128
learning_rate = 0.1

In [5]:
# Define transformations for training and test sets-
transform_train = transforms.Compose(
    [
      transforms.RandomCrop(32, padding = 4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ]
     )

transform_test = transforms.Compose(
    [
      transforms.ToTensor(),
      transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
     ]
     )

In [None]:
'''
# Dataset has PILImage images of range [0, 1]. We transform them to Tensors
# of normalized range [-1, 1]
transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ]
        )
'''

In [None]:
os.chdir("/home/arjun/Documents/Programs/Python_Codes/PyTorch_Resources/Good_Codes/")

In [6]:
# Load dataset-
train_dataset = torchvision.datasets.CIFAR10(
        root = './data', train = True,
        download = True, transform = transform_train
        )

test_dataset = torchvision.datasets.CIFAR10(
        root = './data', train = False,
        download = True, transform = transform_test
        )

Files already downloaded and verified
Files already downloaded and verified


In [7]:
print(f"len(train_dataset) = {len(train_dataset)} & len(test_dataset) = {len(test_dataset)}")

len(train_dataset) = 50000 & len(test_dataset) = 10000


In [8]:
# Create training and testing loaders-
train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size = batch_size,
        shuffle = True
        )

test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size = batch_size,
        shuffle = False
        )

In [9]:
print(f"len(train_loader) = {len(train_loader)} & len(test_loader) = {len(test_loader)}")

len(train_loader) = 391 & len(test_loader) = 79


In [10]:
# Sanity check-
len(train_dataset) / batch_size, len(test_dataset) / batch_size

(390.625, 78.125)

In [None]:
# Get some random training images-
# some_img = iter(train_loader)
# images, labels = some_img.next()
images, labels = next(iter(train_loader))

# You get 32 images due to our specified batch size-
print(f"images.shape: {images.shape} & labels.shape: {labels.shape}")

images.shape: torch.Size([128, 3, 32, 32]) & labels.shape: torch.Size([128])


### ResNet model definition:

In [11]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


In [12]:
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


In [13]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


In [14]:
def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

In [None]:
def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])

In [None]:
def ResNet50():
    return ResNet(Bottleneck, [3, 4, 6, 3])

In [None]:
def ResNet101():
    return ResNet(Bottleneck, [3, 4, 23, 3])

In [None]:
def ResNet152():
    return ResNet(Bottleneck, [3, 8, 36, 3])

In [15]:
# Initialize ResNet-18 CNN model-
trained_model = ResNet18()

In [21]:
import os
os.getcwd()

'/content'

In [16]:
# Move file from Google Colab to drive-
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [17]:
# !cp -rv ResNet18_best_trained_loss.pth drive/MyDrive
!cp -rv /content/drive/MyDrive/ResNet18_best_trained_loss.pth /content

'/content/drive/MyDrive/ResNet18_best_trained_loss.pth' -> '/content/ResNet18_best_trained_loss.pth'


In [17]:
# Path to file-
# PATH = "/home/arjun/Deep_Learning_Resources/Computer_Vision_Resources/ResNet_resources/ResNet_Codes/Good_Codes/ResNet18-CosineAnnealing_LR_scheduler/ResNet18_best_trained_loss.pth"
PATH = "ResNet18_best_trained_loss.pth"

In [18]:
# Load pre-trained weights-
trained_model.load_state_dict(torch.load(PATH, map_location = device))

<All keys matched successfully>

In [None]:
# Place model on GPU (if available)-
trained_model.to(device)

## PyTorch Pruning:

- [Reference blog](https://leimao.github.io/blog/PyTorch-Pruning/)

- [Reference GitHub](https://github.com/leimao/PyTorch-Pruning-Example)

### Sparsity for Iterative Pruning

The ```prune.l1_unstructured``` function uses an ```amount``` argument which could be either the percentage of connections to prune (if it is a float between 0 and 1), or the absolute number of connections to prune (if it is a non-negative integer).

__When it is the percentage, it is the the relative percentage to the number of unmasked/remaining parameters in the module/layer__.
For example, in iterative pruning, if we prune the weights of a certain layer by ```amount = 0.2``` in the first iteration and then prune the same module/layer by the same ```amount = 0.2``` in the second iteration. Then:
- _the amount of the valid/surviving parameters after the second round of pruning will be 1 x (1 - 0.2) x (1 - 0.2), (and)_
- _the sparsity of the parameters, i.e., the pruning rate/rate of pruning, in this module/layer will be: 1 - (1 x (1 - 0.2) x (1 - 0.2))_.


Formally, the final prune rate could be calculated using the following equation. Suppose that the relative pruning rate for each iteration is $\gamma$, the final pruning rate, after _n_ iterations, will be:
$1-\left(1-\gamma\right)^n$


Similarly, it is also easy to derive the final pruning rate for the scenario that is different in each iteration.

In [20]:
def evaluate_model(model, test_loader, device, criterion = None):

    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)

    return eval_loss, eval_accuracy


In [21]:
def create_classification_report(model, device, test_loader):

    model.eval()
    model.to(device)

    y_pred = []
    y_true = []

    with torch.no_grad():
        for data in test_loader:
            y_true += data[1].numpy().tolist()
            images, _ = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            y_pred += predicted.cpu().numpy().tolist()

    classification_report = sklearn.metrics.classification_report(
        y_true = y_true, y_pred = y_pred)

    return classification_report


In [35]:
def remove_parameters(model):

    for module_name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            try:
                prune.remove(module, "weight")
            except:
                pass
            try:
                prune.remove(module, "bias")
            except:
                pass
        elif isinstance(module, torch.nn.Linear):
            try:
                prune.remove(module, "weight")
            except:
                pass
            try:
                prune.remove(module, "bias")
            except:
                pass

    return model


In [36]:
def compute_final_pruning_rate(pruning_rate, num_iterations):
    '''
    A function to compute the final pruning rate for iterative pruning.
        Note that this cannot be applied for global pruning rate if the pruning rate is heterogeneous among different layers.
    Args:
        pruning_rate (float): Pruning rate.
        num_iterations (int): Number of iterations.
    Returns:
        float: Final pruning rate.
    '''

    final_pruning_rate = 1 - (1 - pruning_rate) ** num_iterations

    return final_pruning_rate


In [37]:
def measure_module_sparsity(module, weight=True, bias=False, use_mask=False):

    num_zeros = 0
    num_elements = 0

    if use_mask == True:
        for buffer_name, buffer in module.named_buffers():
            if "weight_mask" in buffer_name and weight == True:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
            if "bias_mask" in buffer_name and bias == True:
                num_zeros += torch.sum(buffer == 0).item()
                num_elements += buffer.nelement()
    else:
        for param_name, param in module.named_parameters():
            if "weight" in param_name and weight == True:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()
            if "bias" in param_name and bias == True:
                num_zeros += torch.sum(param == 0).item()
                num_elements += param.nelement()

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity


In [38]:
def measure_global_sparsity(
    model, weight = True,
    bias = False, conv2d_use_mask = False,
    linear_use_mask = False):

    num_zeros = 0
    num_elements = 0

    for module_name, module in model.named_modules():

        if isinstance(module, torch.nn.Conv2d):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=conv2d_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

        elif isinstance(module, torch.nn.Linear):

            module_num_zeros, module_num_elements, _ = measure_module_sparsity(
                module, weight=weight, bias=bias, use_mask=linear_use_mask)
            num_zeros += module_num_zeros
            num_elements += module_num_elements

    sparsity = num_zeros / num_elements

    return num_zeros, num_elements, sparsity


### torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)
[refer](https://pytorch.org/docs/stable/optim.html)

Decays the learning rate of each parameter group by 'gamma' once the number of epoch reaches one of the milestones. Notice that such decay can happen simultaneously with other changes to the learning rate from outside this scheduler. When ```last_epoch = -1```, sets initial lr as lr.

Parameters:

- optimizer (Optimizer) – Wrapped optimizer.

- milestones (list) – List of epoch indices. Must be increasing.

- gamma (float) – Multiplicative factor of learning rate decay. Default: 0.1.

- last_epoch (int) – The index of last epoch. Default: -1.

- verbose (bool) – If True, prints a message to stdout for each update. Default: False.

Example:
```
# Assuming optimizer uses lr = 0.05 for all groups
>>> # lr = 0.05     if epoch < 30
>>> # lr = 0.005    if 30 <= epoch < 80
>>> # lr = 0.0005   if epoch >= 80
>>> scheduler = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()
```

In [22]:
# def train_model(model, train_loader, test_loader, device, l1_regularization_strength = 0,
# num_epochs = 200
def fine_tune_train_model(model, train_loader, test_loader, device, l1_regularization_strength = 0,
                l2_regularization_strength = 1e-4, learning_rate = 1e-1, num_epochs = 20):

    # The training configurations were not carefully selected.

    criterion = nn.CrossEntropyLoss()

    model.to(device)

    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10-
    optimizer = torch.optim.SGD(
        model.parameters(), lr = learning_rate,
        momentum = 0.9, weight_decay = l2_regularization_strength
    )
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    
    # Define learning rate scheduler-
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        # optimizer, milestones = [100, 150],
        optimizer, milestones = [5, 10],
        gamma = 0.1, last_epoch = -1)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500)
    

    # Evaluation-
    model.eval()
    eval_loss, eval_accuracy = evaluate_model(
        model = model, test_loader = test_loader,
        device = device, criterion = criterion)
    
    print(f"Pre fine-tuning: val_loss = {eval_loss:.3f} & val_accuracy = {eval_accuracy * 100:.3f}%")
    # print("Epoch: {:03d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(0, eval_loss, eval_accuracy))

    
    for epoch in range(num_epochs):

        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:

            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            l1_reg = torch.tensor(0.).to(device)
            for module in model.modules():
                mask = None
                weight = None
                for name, buffer in module.named_buffers():
                    if name == "weight_mask":
                        mask = buffer
                for name, param in module.named_parameters():
                    if name == "weight_orig":
                        weight = param
                # We usually only want to introduce sparsity to weights and prune weights.
                # Do the same for bias if necessary.
                if mask is not None and weight is not None:
                    l1_reg += torch.norm(mask * weight, 1)

            loss += l1_regularization_strength * l1_reg

            loss.backward()
            optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(
            model = model, test_loader = test_loader,
            device = device, criterion = criterion)

        # Set learning rate scheduler
        scheduler.step()

        '''
        print(
            "Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}"
            .format(epoch + 1, train_loss, train_accuracy, eval_loss,
                    eval_accuracy))
        '''
        print(f"epoch = {epoch + 1} loss = {train_loss:.3f}, accuracy = {train_accuracy * 100:.3f}%, val_loss = {eval_loss:.3f}, val_accuracy = {eval_accuracy * 100:.3f}% & LR: {optimizer.param_groups[0]['lr']:.4f}")

    return model


In [32]:
# Sanity check-
# fine_tuned_model = fine_tune_train_model(trained_model, train_loader, test_loader, device)

In [26]:
def iterative_pruning_finetuning(
    model, train_loader, test_loader, device,
    learning_rate, l1_regularization_strength,
    l2_regularization_strength, learning_rate_decay = 0.1,
    conv2d_prune_amount = 0.2, linear_prune_amount = 0.1,
    num_iterations = 10, num_epochs_per_iteration = 10,
    model_filename_prefix = "pruned_model", model_dir = "saved_models",
    grouped_pruning = False):
    
    '''
    num_iterations - number of pruning iterations/rounds
    num_epochs_per_iteration - number of fine-tuning rounds
    '''

    for i in range(num_iterations):

        print("\nPruning and Finetuning {}/{}".format(i + 1, num_iterations))

        print("Pruning...")


        # NOTE: For global pruning, linear/dense layer can also be pruned!
        if grouped_pruning == True:
            # grouped_pruning -> Global pruning
            parameters_to_prune = []
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    parameters_to_prune.append((module, "weight"))
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method = prune.L1Unstructured,
                amount = conv2d_prune_amount,
            )
        else:
            for module_name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    prune.l1_unstructured(
                        module, name = "weight",
                        amount = conv2d_prune_amount)
                elif isinstance(module, torch.nn.Linear):
                    prune.l1_unstructured(
                        module, name = "weight",
                        amount = linear_prune_amount)

        # Compute validation accuracy just after pruning-
        _, eval_accuracy = evaluate_model(
            model = model, test_loader = test_loader,
            device = device, criterion = None)

        '''
        classification_report = create_classification_report(
            model=model, test_loader=test_loader, device=device)
        '''

        # Compute global sparsity-
        num_zeros, num_elements, sparsity = measure_global_sparsity(
            model, weight = True,
            bias = False, conv2d_use_mask = True,
            linear_use_mask = False)
        
        print(f"Global sparsity = {sparsity * 100:.3f}% & val_accuracy = {eval_accuracy * 100:.3f}%")
        # print(model.conv1._forward_pre_hooks)

        print("\nFine-tuning...")

        # train_model(
        fine_tuned_model = fine_tune_train_model(
            model = model, train_loader = train_loader,
            test_loader = test_loader, device = device,
            l1_regularization_strength = l1_regularization_strength,
            l2_regularization_strength = l2_regularization_strength,
            # i -> current pruning round-
            # learning_rate = learning_rate * (learning_rate_decay ** i),
            learning_rate = 1e-1,
            num_epochs = num_epochs_per_iteration)

        _, eval_accuracy = evaluate_model(
            model=model, test_loader = test_loader,
            device = device, criterion = None)

        '''
        classification_report = create_classification_report(
            model=model, test_loader=test_loader, device=device)
        '''

        num_zeros, num_elements, sparsity = measure_global_sparsity(
            # model,
            fine_tuned_model, weight = True,
            bias = False, conv2d_use_mask = True,
            linear_use_mask = False)

        print(f"Post fine-tuning: Global sparsity = {sparsity * 100:.3f}% & val_accuracy = {eval_accuracy * 100:.3f}%")

        '''
        model_filename = "{}_{}.pt".format(model_filename_prefix, i + 1)
        model_filepath = os.path.join(model_dir, model_filename)
        save_model(model=model,
                   model_dir=model_dir,
                   model_filename=model_filename)
        model = load_model(model=model,
                           model_filepath=model_filepath,
                           device=device)
        '''
        
    return model


In [28]:
# num_classes = 10
# random_seed = 1
l1_regularization_strength = 0
l2_regularization_strength = 1e-4
learning_rate = 1e-3
learning_rate_decay = 1

In [29]:
_, eval_accuracy = evaluate_model(
    model = trained_model, test_loader=test_loader,
    device = device, criterion = None)

In [30]:
'''
classification_report = create_classification_report(
    model = trained_model, test_loader = test_loader,
    device = device)
'''

In [30]:
num_zeros, num_elements, sparsity = measure_global_sparsity(trained_model)
print(f"Global sparsity = {sparsity:.3f}% & val_accuracy = {eval_accuracy * 100:.3f}%")

# print("Test Accuracy: {:.3f}".format(eval_accuracy))
# print("Classification Report:")
# print(classification_report)
# print("Global Sparsity:")
# print("{:.2f}".format(sparsity))

Global sparsity = 0.000% & val_accuracy = 88.990%


In [31]:
model_dir = "saved_models"
model_filename = "resnet18_cifar10.pth"
model_filename_prefix = "pruned_model"
pruned_model_filename = "resnet18_pruned_cifar10.pth"
model_filepath = os.path.join(model_dir, model_filename)
pruned_model_filepath = os.path.join(model_dir, pruned_model_filename)


In [32]:
import copy

In [33]:
print("Iterative Pruning + Fine-Tuning...")

Iterative Pruning + Fine-Tuning...


In [34]:
# Create a deep copy of the pre-trained model-
pruned_model = copy.deepcopy(trained_model)


# Prune and fine-tune trained model-
'''
num_iterations - number of pruning iterations/rounds
num_epochs_per_iteration - number of fine-tuning rounds
'''
pruned_model = iterative_pruning_finetuning(
        model = pruned_model, train_loader = train_loader,
        test_loader = test_loader, device = device,
        learning_rate = learning_rate, learning_rate_decay = learning_rate_decay,
        l1_regularization_strength = l1_regularization_strength, l2_regularization_strength = l2_regularization_strength,
        conv2d_prune_amount = 0.2, linear_prune_amount = 0.1,
        num_iterations = 15, num_epochs_per_iteration = 10,
        model_filename_prefix = model_filename_prefix, model_dir = model_dir,
        grouped_pruning = True)


# Apply pruned mask to the parameters/weights and remove the masks-
remove_parameters(model = pruned_model)

_, eval_accuracy = evaluate_model(
    model = pruned_model, test_loader = test_loader,
    device = device, criterion = None
)

'''
classification_report = create_classification_report(
    model = pruned_model, test_loader = test_loader,
    device = device)
'''

num_zeros, num_elements, sparsity = measure_global_sparsity(pruned_model)


print(f"Global sparsity = {sparsity:.3f} & val_accuracy = {eval_accuracy:.3f}")
# print("Classification Report:")
# print(classification_report)
# NOTE: classification report is avoided as it's too verbose!



Pruning and Finetuning 1/15
Pruning...
Global sparsity = 19.991% & val_accuracy = 88.990%

Fine-tuning...
Pre fine-tuning: val_loss = 0.355 & val_accuracy = 88.990%
epoch = 1 loss = 0.586, accuracy = 0.796, val_loss = 0.580, val_accuracy = 80.290 & LR: 0.1000
epoch = 2 loss = 0.481, accuracy = 0.836, val_loss = 0.537, val_accuracy = 82.120 & LR: 0.1000
epoch = 3 loss = 0.445, accuracy = 0.845, val_loss = 0.563, val_accuracy = 80.870 & LR: 0.1000
epoch = 4 loss = 0.428, accuracy = 0.852, val_loss = 0.488, val_accuracy = 83.710 & LR: 0.1000
epoch = 5 loss = 0.404, accuracy = 0.860, val_loss = 0.531, val_accuracy = 82.090 & LR: 0.0100
epoch = 6 loss = 0.303, accuracy = 0.897, val_loss = 0.353, val_accuracy = 87.970 & LR: 0.0100
epoch = 7 loss = 0.268, accuracy = 0.907, val_loss = 0.348, val_accuracy = 88.030 & LR: 0.0100
epoch = 8 loss = 0.258, accuracy = 0.912, val_loss = 0.343, val_accuracy = 88.510 & LR: 0.0100
epoch = 9 loss = 0.246, accuracy = 0.914, val_loss = 0.339, val_accuracy =

In [39]:
# Remove pruning parameters-
final_model = remove_parameters(pruned_model)

In [41]:
# Compute final model's val_accuracy and global sparsity-
_, eval_accuracy = evaluate_model(
    model = final_model, test_loader = test_loader,
    device = device, criterion = None)

num_zeros, num_elements, sparsity = measure_global_sparsity(pruned_model)
print(f"Global sparsity = {sparsity * 100:.3f}%"
f" & val_accuracy = {eval_accuracy * 100:.3f}%")


Global sparsity = 96.437% & val_accuracy = 91.320%


In [46]:
# Save final trained and pruned model for later use-
torch.save(final_model.state_dict(), f"ResNet18_trained_sparsity-{sparsity * 100:.3f}.pth")

### Sanity check:

In [47]:
# Initialize and load trained and pruned model-
trained_pruned_model = ResNet18()
trained_pruned_model.load_state_dict(torch.load("/content/ResNet18_trained_sparsity-96.437.pth", map_location = device))
# trained_pruned_model.load_state_dict(final_model)

# Move model to GPU (if available)-
trained_pruned_model.to(device)

# Define cost function and optimizer-
criterion = nn.CrossEntropyLoss()

# It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10-
optimizer = torch.optim.SGD(
        trained_pruned_model.parameters(), lr = learning_rate,
        momentum = 0.9, weight_decay = l2_regularization_strength
)


In [50]:
# Compute final model's val_accuracy and global sparsity-
eval_loss, eval_accuracy = evaluate_model(
    model = trained_pruned_model, test_loader=test_loader,
    device = device, criterion = None)

num_zeros, num_elements, sparsity = measure_global_sparsity(trained_pruned_model)
print(f"Global sparsity = {sparsity * 100:.3f}%, val_loss = {eval_loss:.3f}"
f" & val_accuracy = {eval_accuracy * 100:.3f}%")


Global sparsity = 96.437%, val_loss = 0.000 & val_accuracy = 91.320%
