# GraSP pruning

In [1]:
!pip install torchbearer

Collecting torchbearer
  Downloading torchbearer-0.5.3-py3-none-any.whl (138 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.1/138.1 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchbearer
Successfully installed torchbearer-0.5.3
[0m

In [2]:
import torch
import torch.nn.functional as F
from torch.nn.utils import prune
import torchvision.transforms as transforms
import torchbearer
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
from torchvision.models import vgg16_bn, vgg19_bn
from torchbearer import Trial
import numpy as np
import random

In [3]:
# fix random seed for reproducibility
seed = 7
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

# Dataset Preparation
The CIFAR-10 dataset is downloaded and transformed with a batch size of 128, using the same parameters as the source research paper [1]. The training dataset is transformed by random crop followed by horizontal flips.

In [4]:
train_batch_size = 128
test_batch_size = 128

# convert each image to tensor format
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

# load data
trainset = CIFAR10(root='.', train=True, download=True, transform=transform_train)
testset = CIFAR10(root='.', train=False, download=True, transform=transform_test)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 28599440.37it/s]


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [5]:
# create data loaders
trainloader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True)
testloader = DataLoader(testset, batch_size=test_batch_size, shuffle=False)

# Defining Model
The default VGG16 model from Pytorch is modified to align with the CIFAR-10 dataset which has 10 output classes.

In [29]:
num_classes = 10

model = vgg16_bn()
model.features = model.features[:-1]
model.avgpool = nn.AvgPool2d(2)
model.classifier = nn.Linear(512, num_classes)
model

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

Reinitialised weights using [He initialisation](https://pytorch.org/docs/stable/nn.init.html#torch.nn.init.kaiming_normal_) method for all linear and convolution layers. And using random weight from a uniform distribution for Batch normalization weight with zero bias.

In [30]:
def init_weights(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        # He initialization
        nn.init.kaiming_normal_(m.weight)
        # Glorot initialization
        # nn.init.xavier_normal(m.weight)
    elif isinstance(m, torch.nn.BatchNorm2d):
        m.weight.data = torch.rand(m.weight.data.shape)
        m.bias.data = torch.zeros_like(m.bias.data)

model = model.apply(init_weights)

## GraSP Pruning

In [32]:
def data_sampling(trainset):
    # Sample 10 data from each class (100 data in total)
    s_inputs = []
    s_targets = []
    samples = next(iter(DataLoader(trainset, batch_size=300, shuffle=True)))
    for t in set(trainset.targets):
        indices = random.sample([i for i, x in enumerate(samples[1]) if x == t], 10)
        s_inputs += [samples[0][i].tolist() for i in indices]
        s_targets += [samples[1][i].tolist() for i in indices]

    s_inputs = torch.Tensor(np.array(s_inputs))
    s_targets = torch.Tensor(np.array(s_targets)).to(torch.long)

    print(s_inputs.shape)
    print(s_targets.shape)
    return s_inputs, s_targets

def magnitude_score(model, trainset):
    '''calculate pruning score for all prune-able layers
        return: {layer: tensor of score for each weight in the layer}'''
    # modified from: https://github.com/alecwangcq/GraSP/blob/master/pruner/GraSP.py
    T = 200
    model.zero_grad()
    weights = [weight for name, weight in model.named_parameters() if name.endswith('.weight')]
    
    # sampling from trainset
    s_inputs, s_targets = data_sampling(trainset)
    
    # compute the Hessian-gradient
    outputs = model.forward(s_inputs)/T
    loss = F.cross_entropy(outputs, s_targets)
    grad_w = list(torch.autograd.grad(loss, weights))

    outputs = model.forward(s_inputs)/T
    loss = F.cross_entropy(outputs, s_targets)
    grad_f = list(torch.autograd.grad(loss, weights, create_graph=True))

    z = sum([(gw.data * gf).sum() for gw, gf in zip(grad_w, grad_f)])
    z.backward()

    scores = {}
    for name, weight in model.named_parameters():
        if name.endswith('.weight'):
            # score is calculated by -weight * Hessian-gradient
            scores[name.replace('.weight', '')] = -weight.detach() * weight.grad
    return scores

def create_mask(model, scores, sparse_ratio, prune_type='min', random_shuffling=False):
    # flatted all score to a vector 
    # modified from: https://github.com/alecwangcq/GraSP/blob/master/pruner/GraSP.py
    score_vec = torch.cat([torch.flatten(x) for x in scores.values()])
    
    # nomalisation 
    eps = 1e-10
    norm_factor = torch.abs(torch.sum(score_vec)) + eps
    score_vec.div_(norm_factor)
    
    # calculate number of parameters to prune
    num_prune = np.ceil(len(score_vec) * sparse_ratio).astype(int)
    num_keep = (score_vec.shape - num_prune)[0]
    print("Number of params to prune:", num_prune)
    print("Remaining params:", num_keep)
    
    if prune_type == 'top':
        # prune top k score
        threshold = torch.topk(score_vec, num_prune, sorted=True)[0][-1]
        print('threshold', threshold.data)
    elif prune_type == 'min':
        # prune min k score
        threshold = torch.topk(score_vec, num_keep, sorted=True)[0][-1]
        print('threshold', threshold.data)
        
    # create mask
    masks = {}
    named_modules = dict(model.named_modules())
    
    for m, g in scores.items():
        layer = named_modules[m]
        if prune_type == 'top':
            # prune top k score
            masks[layer] = ((g / norm_factor) <= threshold).float()
        elif prune_type == 'min':
            # prune min k score
            masks[layer] = ((g / norm_factor) >= threshold).float()
            
        if random_shuffling:
            ## randomly shuffle weight within each layer
            idx = torch.randperm(masks[layer].nelement())
            masks[layer] = masks[layer].view(-1)[idx].view(masks[layer].size())
            
    print('masks', torch.sum(torch.cat([torch.flatten(x == 1) for x in masks.values()])))
    return masks

def prune_model(model, masks, reinit=False):
    if reinit:
        # re-initialise weight
        model = model.apply(init_weights)

    for m in masks.keys():
        m = prune.custom_from_mask(m, name='weight', mask=masks[m].data)
    
    return model

In [33]:
sparse_ratio = 0.9
# for inversion, use prune_type = 'min'
prune_type = 'top' # ['top', 'min']
random_shuffling = False
reinit = False

scores = magnitude_score(model, trainset)
masks = create_mask(model, scores, sparse_ratio, prune_type, random_shuffling)
model = prune_model(model, masks, reinit)

torch.Size([100, 3, 32, 32])
torch.Size([100])
Number of params to prune: 13247828
Remaining params: 1471980
threshold tensor(-1.0171e-05)
masks tensor(1471981)


## Training

In [37]:
import matplotlib.pyplot as plt

def plot_loss(train_loss, test_loss):
    plt.plot(train_loss, label="Training data")
    plt.plot(test_loss, label="Validation data")
    plt.xlabel("Epochs", fontsize="18")
    plt.ylabel("Loss", fontsize="18")
    plt.tick_params(axis='both', which='major', labelsize=15)
    plt.legend(fontsize="15")
    plt.grid()
    plt.show();

def plot_acc(train_acc, test_acc):
    plt.plot(train_acc, label="Training data")
    plt.plot(test_acc, label="Validation data")
    plt.xlabel("Epochs", fontsize="18")
    plt.ylabel("Accuracy", fontsize="18")
    plt.tick_params(axis='both', which='major', labelsize=15)
    plt.legend(fontsize="15")
    plt.grid();

In [38]:
from torchbearer import Callback
from torchbearer import callbacks
from torchbearer.callbacks import MultiStepLR

@callbacks.on_end_epoch
def callback(state):
    try:
        train_loss[state[torchbearer.state.EPOCH]] = state[torchbearer.state.METRICS]['loss']
        train_acc[state[torchbearer.state.EPOCH]] = state[torchbearer.state.METRICS]['acc']
        test_loss[state[torchbearer.state.EPOCH]] = state[torchbearer.state.METRICS]['val_loss']
        test_acc[state[torchbearer.state.EPOCH]] = state[torchbearer.state.METRICS]['val_acc']
    except:
        pass

In [39]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# train the model using cross entropy loss with loss ratio scheduler and SDG optimiser
def train_model(model, epochs=80):
    model = model.to(device)
    loss_function = nn.CrossEntropyLoss()
    scheduler = callbacks.MultiStepLR(milestones=[40, 60], gamma=0.1)
    optimiser = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

    trial = torchbearer.Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy'], callbacks=[callback, scheduler]).to(device)
    trial.with_generators(trainloader, test_generator=testloader, val_generator=testloader)
    trial.run(epochs)
    results = trial.evaluate(data_key=torchbearer.TEST_DATA)
    print(results)

In [40]:
num_epochs = 80

# save loss and accuracy during training
train_loss = np.zeros(num_epochs)
train_acc = np.zeros(num_epochs)
test_loss = np.zeros(num_epochs)
test_acc = np.zeros(num_epochs)

# train model
train_model(model, epochs=num_epochs)

0/2(t):   0%|          | 0/391 [00:00<?, ?it/s]



0/2(v):   0%|          | 0/79 [00:00<?, ?it/s]

1/2(t):   0%|          | 0/391 [00:00<?, ?it/s]

1/2(v):   0%|          | 0/79 [00:00<?, ?it/s]

0/1(e):   0%|          | 0/79 [00:00<?, ?it/s]

{'test_loss': 1.0348103046417236, 'test_acc': 0.6349999904632568}


In [None]:
plot_loss(train_loss, test_loss)

In [None]:
plot_acc(train_acc, test_acc)

In [None]:
torch.save(model.state_dict(), "./weight.weights")

# References
[1] Jonathan Frankle, Gintare Karolina Dziugaite, Daniel Roy, and Michael Carbin. Pruning neural net- works at initialization: Why are we missing the mark? In International Conference on Learning Representations, 2021. URL https://openreview.net/forum?id=Ig-VyQc-MLK