In [1]:
!pip install -q torch torchvision numpy

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import numpy as np
from collections import OrderedDict
import itertools
import sys
import time

sys.path.append("..")

from prunenn.data import *
from prunenn.models import *
from prunenn.pruner import *


In [2]:
def train(model, device, train_loader, optimizer, epoch, log_interval=10):
    """Run one training epoch over ``train_loader``.

    For each batch the data is moved to ``device``, gradients are zeroed, the
    negative log-likelihood loss of the model output is computed and
    back-propagated, and the optimizer takes one step.

    Args:
        model: Module to train; put into training mode with ``model.train()``.
            Its output is fed to ``F.nll_loss``, which expects
            log-probabilities.
        device: Target handed to ``Tensor.to`` for every batch.
        train_loader: Iterable of ``(data, target)`` batches that supports
            ``len()`` and exposes a sized ``.dataset`` (both are only used for
            the progress message).
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        epoch: Epoch index; only used in the progress message.
        log_interval: Print a progress line every ``log_interval`` batches.
            Defaults to 10 (the previously hard-coded value); a falsy value
            (``0`` or ``None``) disables progress output.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if log_interval and batch_idx % log_interval == 0:
            # Sample count is approximated as batches-so-far * current batch size.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    return 1. * correct / len(test_loader.dataset)

In [3]:
dataset = 'mnist'

# Select model and data loaders based on data set
if dataset == 'mnist':
    model = MNIST_Net()
    (train_loader, test_loader) = get_mnist_loaders()
elif dataset == 'cifar':
    model = CIFAR_Net()
    (train_loader, test_loader) = get_cifar_loaders()
else:
    # Fail here rather than later with a NameError (or a stale `model` left
    # over in the kernel) when the name is misspelled.
    raise ValueError(
        "Unknown dataset {!r}; expected 'mnist' or 'cifar'".format(dataset))

In [4]:
# Pretraining full model
epoch_range = 1
# Name the device explicitly instead of relying on `.to(None)` acting as a
# no-op; the model and every batch are kept on this device.
device = torch.device("cpu")
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(epoch_range):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)


Test set: Average loss: 0.2130, Accuracy: 9789/10000 (98%)



In [5]:
# Wrap the pretrained model in a Pruner and time one evaluation pass through it.
pruning_model = Pruner(model, thres = 0.5, function = "var")

# NOTE(review): `to_train` is a Pruner flag defined outside this notebook;
# presumably False selects inference behaviour — confirm against prunenn.pruner.
pruning_model.to_train = False
# perf_counter() is monotonic, so the measured interval cannot be skewed by
# system clock adjustments the way time.time() can.
t0 = time.perf_counter()
acc = test(pruning_model, device, test_loader)
t1 = time.perf_counter()
print("testing time", (t1-t0))
# Cell output: number of trainable parameters in the wrapped model.
sum(p.numel() for p in model.parameters() if p.requires_grad)


Test set: Average loss: 0.2130, Accuracy: 9789/10000 (98%)

testing time 3.100321054458618


25670

In [None]:
def test_acc(model, test_loader):
    """Time one evaluation pass of ``model`` and return its accuracy.

    ``model.to_train`` is set to False for the duration of the pass and set
    back to True afterwards, even if evaluation raises, so a failed run does
    not leave the model stuck with the flag off. The elapsed time and the
    accuracy are printed.

    Uses the module-level ``device`` and the notebook's ``test`` function.

    Args:
        model: Model (or Pruner wrapper) to evaluate.
        test_loader: Loader forwarded to ``test``.

    Returns:
        The accuracy value returned by ``test``.
    """
    model.to_train = False
    try:
        # Monotonic clock: immune to wall-clock adjustments during the run.
        t0 = time.perf_counter()
        acc = test(model, device, test_loader)
        t1 = time.perf_counter()
    finally:
        model.to_train = True
    print("Inference time ", t1 - t0)
    print("Accuracy", acc)
    return acc

def retrain(model, test_loader):
    """Fine-tune ``model`` for ``epoch_range`` epochs and return its accuracy.

    Every epoch runs ``train`` with ``model.to_train`` set to True, then
    ``test`` with it set to False, and finally sets the flag back to True.
    The returned value comes from a closing ``test_acc`` call.

    Relies on the module-level ``epoch_range``, ``device``, ``train_loader``
    and ``optimizer`` instead of taking them as arguments.

    NOTE(review): ``optimizer`` is the global instance created from the
    original model's parameters. If pruning replaces parameter tensors, it may
    no longer be updating the live weights — confirm against Pruner.

    Args:
        model: Model (or Pruner wrapper) to fine-tune.
        test_loader: Loader used for the per-epoch and final evaluation.

    Returns:
        Accuracy reported by ``test_acc`` after the last epoch.
    """
    for epoch in range(epoch_range):
        model.to_train = True  # flag on for the training pass
        train(model, device, train_loader, optimizer, epoch)
        model.to_train = False  # flag off for the evaluation pass
        test(model, device, test_loader)
        model.to_train = True
    return test_acc(model, test_loader)

def prune_loop(model, thresholds, sacrifice, test_loader, max_retries=3):
    """Prune ``model`` step by step, stopping once accuracy drops too far.

    For each threshold a ``Pruner`` is built around the current model, one
    evaluation pass is run, ``prune()`` is called and the result is retrained.
    If the retrained accuracy is at or below ``baseline - sacrifice``, up to
    ``max_retries`` further retraining rounds are attempted. A step that still
    fails ends the loop and the model bound before that step is returned;
    otherwise ``pruning_model.model`` becomes the current model.

    Uses the module-level ``device`` and the notebook's ``test``, ``test_acc``
    and ``retrain`` helpers.

    Args:
        model: Starting model; its accuracy from ``test_acc`` is the baseline.
        thresholds: Iterable of threshold values handed to ``Pruner`` in order.
        sacrifice: Accuracy loss (as a fraction) tolerated relative to the
            baseline.
        test_loader: Loader used for all evaluations.
        max_retries: Extra retraining rounds allowed per threshold when the
            accuracy is too low. Defaults to 3, the previously hard-coded value.

    Returns:
        The model bound after the last accepted pruning step.
    """
    init_model_acc = test_acc(model, test_loader)
    # Lowest accuracy that still fails the check; constant across thresholds.
    acc_floor = init_model_acc - sacrifice

    for thres in thresholds:
        print("***** THRES = ", thres, " *****)")
        pruning_model = Pruner(model, thres = thres, function = "corrs")
        pruning_model.to_train = False
        # The result is not needed, but the pass is kept because it runs
        # before prune() — presumably it gathers the statistics prune() uses;
        # TODO confirm against prunenn.pruner.
        test(pruning_model, device, test_loader)
        pruning_model.prune()
        pruning_model.to_train = True
        pruning_model_acc = retrain(pruning_model, test_loader)

        retries = 0
        while pruning_model_acc <= acc_floor and retries < max_retries:
            print("--- accuracy drop ", retries, " ---")
            pruning_model_acc = retrain(pruning_model, test_loader)
            retries += 1

        if pruning_model_acc <= acc_floor:
            # Reject this step and hand back the model from before it.
            return model
        model = pruning_model.model
        print("Number of parameters", sum(p.numel() for p in model.parameters()))
    return model
  
# Ten log-spaced thresholds between 0.7 and 0.98, reversed so the sweep
# starts at the largest value (0.98) and works down to 0.7.
thresholds = np.logspace(np.log10(0.7), np.log10(0.98), num=10)[::-1]
# Tolerated accuracy loss relative to the unpruned baseline.
sacrifice = 0.01
# Rebinds `model` to whatever prune_loop returns.
model = prune_loop(model, thresholds, sacrifice, test_loader)
    


Test set: Average loss: 0.2130, Accuracy: 9789/10000 (98%)

Inference time  3.1375491619110107
Accuracy 0.9789
***** THRES =  0.98  *****)


  c /= stddev[:, None]
  c /= stddev[None, :]
  nodes_to_prune = np.where((self.thres <= corrs[:,:]))



Test set: Average loss: 0.2130, Accuracy: 9789/10000 (98%)


Test set: Average loss: 0.1298, Accuracy: 9765/10000 (98%)


Test set: Average loss: 0.1298, Accuracy: 9765/10000 (98%)

Inference time  8.031752347946167
Accuracy 0.9765
Number of parameters 24164
***** THRES =  0.94403832854698  *****)

Test set: Average loss: 0.1298, Accuracy: 9765/10000 (98%)


Test set: Average loss: 0.0981, Accuracy: 9784/10000 (98%)


Test set: Average loss: 0.0981, Accuracy: 9784/10000 (98%)

Inference time  13.587354183197021
Accuracy 0.9784
Number of parameters 24164
***** THRES =  0.9093962915977304  *****)

Test set: Average loss: 0.0981, Accuracy: 9784/10000 (98%)


Test set: Average loss: 0.0981, Accuracy: 9750/10000 (98%)



In [None]:
# Time one evaluation pass of the model returned by the pruning loop.
# perf_counter() is monotonic, so the interval is not affected by system
# clock adjustments the way time.time() is.
t0 = time.perf_counter()
acc = test(model, device, test_loader)
t1 = time.perf_counter()
print("testing time", (t1-t0))
# Cell output: number of trainable parameters remaining in the model.
sum(p.numel() for p in model.parameters() if p.requires_grad)