In [None]:
!pip install torchbearer

In [None]:
import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torchbearer

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer the GPU when present

from torchvision.models import vgg16_bn, vgg19_bn
num_classes = 10

# Adapt torchvision's VGG16-BN to 32x32 CIFAR-10 inputs: drop the trunk's final max-pool,
# pool the remaining (presumably 2x2) feature map down to 1x1, and replace the large
# ImageNet classifier head with a single 512 -> num_classes linear layer.
model = vgg16_bn()
model.features = model.features[:-1]
model.avgpool, model.classifier = nn.AvgPool2d(2), nn.Linear(512, num_classes)

In [None]:
def init_weights(m):
    """Re-initialise one module in place; meant to be passed to ``model.apply``.

    * ``nn.Linear`` / ``nn.Conv2d``: Kaiming-normal weights (biases are left as-is).
    * ``nn.BatchNorm2d``: scale drawn from U[0, 1), shift set to zero.
    Any other module type is ignored.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        nn.init.kaiming_normal_(m.weight)
    elif isinstance(m, nn.BatchNorm2d):
        # Initialise in place so the parameter keeps its device and dtype; assigning
        # `.data = torch.rand(shape)` would swap in a CPU tensor of the default dtype.
        # The None checks cover BatchNorm2d(affine=False), which has no weight/bias.
        if m.weight is not None:
            nn.init.uniform_(m.weight, 0.0, 1.0)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

model.apply(init_weights)  # Module.apply recurses into every submodule and mutates `model` in place

In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

train_batch_size, test_batch_size = 128, 128

# Shared tail of both pipelines: image -> float tensor, then per-channel standardisation.
# NOTE(review): these look like ImageNet mean/std rather than CIFAR-10's own statistics — confirm intended.
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]

# Training prepends light augmentation (padded random crop + horizontal flip).
transform_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] + _tensor_and_normalize
)
# Evaluation uses only the deterministic tail.
transform_test = transforms.Compose(list(_tensor_and_normalize))

# Download (if needed) and wrap both CIFAR-10 splits.
trainset = CIFAR10(root='.', train=True, download=True, transform=transform_train)
testset = CIFAR10(root='.', train=False, download=True, transform=transform_test)

In [None]:
from torch.utils.data import DataLoader, random_split

# Reshuffle training batches every epoch; keep test order fixed for reproducible evaluation.
trainloader = DataLoader(dataset=trainset, batch_size=train_batch_size, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=test_batch_size, shuffle=False)

In [None]:
import random
import numpy as np

sparse_ratio = 0.6  # fraction of the scored weights that will be pruned

# Class-balanced mini-batch for the one-shot SNIP saliency pass:
# `samples_per_class` random training images from every class.
samples_per_class = 10
all_targets = torch.as_tensor(trainset.targets)

chosen_indices = []
for t in sorted(set(trainset.targets)):
    # Sample from the full per-class index list so enough candidates always exist.
    # A random 300-image batch could hold fewer than 10 images of some class,
    # which made random.sample raise.
    class_indices = (all_targets == t).nonzero(as_tuple=True)[0].tolist()
    chosen_indices += random.sample(class_indices, samples_per_class)

# Indexing the dataset applies transform_train, just like batches drawn from a DataLoader.
chosen_pairs = [trainset[i] for i in chosen_indices]
s_inputs = torch.stack([image for image, _ in chosen_pairs])
s_targets = torch.tensor([label for _, label in chosen_pairs], dtype=torch.long)

print(s_inputs.shape)
print(s_targets.shape)

In [None]:
import time

# One forward/backward pass on the balanced sample batch. The gradients it leaves in each
# parameter's `.grad` feed the SNIP saliency |w * dL/dw| computed in the next cell.
# NOTE(review): train mode means this pass also updates BatchNorm running statistics.
model.train()
model.zero_grad()

outputs = model(s_inputs)
loss = F.cross_entropy(outputs, s_targets)  # keep the loss tensor; .backward() returns None
loss.backward()


In [None]:
# SNIP connection sensitivity per weight tensor: |w * dL/dw|, keyed by the owning
# module's qualified name (the parameter name with '.weight' removed).
scores = {
    param_name.replace('.weight', ''): (param.detach() * param.grad).abs()
    for param_name, param in model.named_parameters()
    if param_name.endswith('.weight')
}

In [None]:
# Flatten every per-module score tensor into one vector and normalise it to sum to 1.
flat_scores = [s.flatten() for s in scores.values()]
score_vec = torch.cat(flat_scores)
norm_factor = score_vec.sum()
score_vec.div_(norm_factor)

In [None]:
# Prune ceil(sparse_ratio * N) of the N scored weights (kept as a numpy integer).
num_prune = np.ceil(len(score_vec) * sparse_ratio).astype(int)
print("Number of params to prune:", num_prune)
# numel() gives a plain count; `score_vec.shape - num_prune` only worked through numpy
# broadcasting over the torch.Size and printed a one-element array.
print("Remaining params:", score_vec.numel() - num_prune)

In [None]:
# Number of weights that survive pruning, as a plain Python int.
num_keep = score_vec.numel() - int(num_prune)
# Smallest normalised score among the `num_prune` largest ones. The mask cell prunes
# everything at or above it. With nothing to prune, use +inf so that every weight is kept.
if int(num_prune) > 0:
    threshold = torch.topk(score_vec, int(num_prune), sorted=True).values[-1]
else:
    threshold = torch.tensor(float('inf'))
print(threshold)

In [None]:
# "Inverted" SNIP: prune the `num_prune` MOST salient weights and keep the rest.
# The threshold is itself one of the pruned top-`num_prune` scores, so the comparison
# must be strict; `<=` kept that element as well, i.e. num_keep + 1 weights.
masks = {}
named_modules = dict(model.named_modules())

for module_name, raw_score in scores.items():
    # Normalise with the same factor as score_vec so the values are comparable to `threshold`.
    masks[named_modules[module_name]] = ((raw_score / norm_factor) < threshold).float()

In [None]:
print('Masks')
# Total number of weights kept (mask entries equal to 1) across all masked modules.
print(torch.stack([(mask == 1).sum() for mask in masks.values()]).sum())

In [None]:
from torch.nn.utils import prune

# Attach each precomputed mask to its module's `weight`. This modifies the module in place,
# adding `weight_mask` and `weight_orig`, so no rebinding of the loop variable is needed.
for module, mask in masks.items():
    prune.custom_from_mask(module, name='weight', mask=mask)

In [None]:
import matplotlib.pyplot as plt

def _plot_history(train_values, test_values, ylabel):
    """Plot per-epoch training vs. validation curves on the current figure and show it."""
    plt.plot(train_values, label="Training data")
    plt.plot(test_values, label="Validation data")
    plt.xlabel("Epochs", fontsize="18")
    plt.ylabel(ylabel, fontsize="18")
    plt.tick_params(axis='both', which='major', labelsize=15)
    plt.legend(fontsize="15")
    plt.grid()
    # Render and finish this figure so a later plot call starts on a fresh one.
    plt.show()


def plot_loss(train_loss, test_loss):
    """Plot per-epoch training and validation loss curves."""
    _plot_history(train_loss, test_loss, "Loss")


def plot_acc(train_acc, test_acc):
    """Plot per-epoch training and validation accuracy curves (now also calls plt.show())."""
    _plot_history(train_acc, test_acc, "Accuracy")

In [None]:
from torchbearer import Trial

In [None]:
from torchbearer import Callback
from torchbearer import callbacks
from torchbearer.callbacks import MultiStepLR

@callbacks.on_end_epoch
def callback(state):
    """Copy this epoch's train/validation metrics into the global history arrays.

    Writes ``train_loss``, ``train_acc``, ``test_loss`` and ``test_acc`` (module-level
    arrays created in the training cell) at index ``state[torchbearer.state.EPOCH]``.
    A metric missing from ``state[torchbearer.state.METRICS]``, or an epoch index past
    an array's end, is skipped with a warning. Each metric is handled independently,
    so one missing value no longer drops the others.
    """
    import warnings

    epoch = state[torchbearer.state.EPOCH]
    metrics = state[torchbearer.state.METRICS]
    for history, key in ((train_loss, 'loss'), (train_acc, 'acc'),
                         (test_loss, 'val_loss'), (test_acc, 'val_acc')):
        if key not in metrics:
            warnings.warn(f"metric {key!r} missing at epoch {epoch}; not recorded")
        elif epoch >= len(history):
            warnings.warn(f"epoch {epoch} exceeds history length {len(history)}; {key!r} not recorded")
        else:
            # float() accepts both Python numbers and scalar tensors (on any device).
            history[epoch] = float(metrics[key])

In [None]:
def train_model(model, epochs=80, lr=0.1, milestones=(40, 60)):
    """Train ``model`` with SGD and step LR decay via torchbearer, then evaluate on the test set.

    Args:
        model: network to train; it is moved to the global ``device``.
        epochs: number of training epochs.
        lr: initial SGD learning rate.
        milestones: epochs at which the learning rate is multiplied by 0.1.

    Returns:
        Whatever ``trial.evaluate`` returns for the test data (it is also printed).
    """
    model = model.to(device)
    loss_function = nn.CrossEntropyLoss()
    scheduler = callbacks.MultiStepLR(milestones=list(milestones), gamma=0.1)
    # Use torch.optim directly: the short `optim` alias is only imported in a later cell.
    optimiser = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)

    trial = torchbearer.Trial(model, optimiser, loss_function, metrics=['loss', 'accuracy'], callbacks=[callback, scheduler]).to(device)
    # NOTE(review): the test set doubles as the validation set here.
    trial.with_generators(trainloader, test_generator=testloader, val_generator=testloader)
    trial.run(epochs)
    results = trial.evaluate(data_key=torchbearer.TEST_DATA)
    print(results)
    return results

In [None]:
from torch import optim


num_epochs = 80

# Per-epoch history buffers; `callback` fills them in at the end of every epoch.
train_loss, train_acc, test_loss, test_acc = (np.zeros(num_epochs) for _ in range(4))

train_model(model, epochs=num_epochs)

In [None]:
plot_loss(train_loss=train_loss, test_loss=test_loss)

In [None]:
plot_acc(train_acc=train_acc, test_acc=test_acc)

In [None]:
# Save the trained weights. torch.nn.utils.prune reparametrizes the pruned modules, so this
# state_dict holds `weight_orig` + `weight_mask` entries instead of `weight`. Load it into a
# model pruned the same way, or call prune.remove on each module first to get a plain checkpoint.
weights_path = "./vgg16-60-inversion.weights"
torch.save(model.state_dict(), weights_path)
# from google.colab import files
# files.download('vgg19-90-2.weights')