In [1]:
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from fisher_resnet import *

Specify number of run and dataset version

In [1]:
run_number = 0
data = 'cifar10'

Experiments were performed for run_number = 0, 1, 2 and data="cifar10", "cifar100"

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


Learning rate scheduler

In [4]:
class LinearLR(torch.optim.lr_scheduler._LRScheduler):

    def __init__(self, optimizer, num_epochs, last_epoch=-1):
        self.num_epochs = max(num_epochs, 1)
        super(LinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        res = []
        for lr in self.base_lrs:
            res.append(np.maximum(
                lr * np.minimum(-(self.last_epoch + 1) * 1. / self.num_epochs + 1., 1.), 0.))
        return res

Erase network

In [5]:
def erase_net_tmp(net):
    for child1 in net.children():
        if isinstance(child1, nn.Sequential):
            for child2 in child1.children():
                if isinstance(child2, PreActBottleneck):
                    for child3 in child2.children():
                        if isinstance(child3, MaskedConv2d) or isinstance(child3, MaskedLinear):
                            child3.erase_tmp()

Specify number of classes

In [6]:
if data == 'cifar10':
    n_classes = 10
    base_path = 'resnet_cifar10/'
    path = base_path + '0_' + str(run_number) + '.pth'
else:
    n_classes = 100
    base_path = 'resnet_cifar100/'
    path = base_path + '0_' + str(run_number) + '.pth'

Make network instance, move it to gpu's

In [6]:
net = resnet50(n_classes)
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = nn.DataParallel(net)

net.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): PreActBottleneck(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): MaskedConv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): MaskedConv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): MaskedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (shortcut): Sequential(
        (0): MaskedConv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): PreActBottleneck(
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True

Load data, make loaders: trainloader with train data for training, testloader with test data for testing, trainloader_prun with batch of train data for pruning (identify the sensitivity of connections).

In [7]:
transform_train = transforms.Compose([
    transforms.Pad(4, padding_mode='symmetric'),
    transforms.RandomCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

In [8]:
if data == 'cifar10':
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
else:
    trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform_test)    

Files already downloaded and verified
Files already downloaded and verified


In [9]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True)

In [10]:
trainloader_prun = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=True)

In [11]:
testloader = torch.utils.data.DataLoader(testset, batch_size=1024, shuffle=False)

Define function for one epoch training and testing

In [13]:
def train(net, optimizer):
    train_loss = 0.
    net.train()
    for input, target in tqdm(trainloader):
        input = input.to(device)
        target = target.to(device)
        
        optimizer.zero_grad()
        output = net(input)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        
    print ('Train loss ', train_loss/len(trainloader))
    
def test(net):
    net.eval()
    with torch.no_grad():
        test_loss = 0.
        correct = 0.
        for input, target in tqdm(testloader):
            input = input.to(device)
            target = target.to(device)
            
            output = net(input)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            _, predicted = torch.max(output.data, 1)
            correct += (predicted == target).sum().item()
            
        test_loss /= len(testset)

        print ('Test loss ', test_loss)
        print ('Test accuracy ', correct / len(testset))

Define loss function, optimizer and learning rate scheduler

In [14]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam([p for p in net.parameters() if p.requires_grad], lr=1e-3)
n_epochs = 200
lr_scheduler = LinearLR(optimizer, n_epochs)

Training cycle: 50 epochs

In [15]:
for epoch in range(n_epochs):
    if epoch > 50:
        break
    print ('Epoch ', epoch)
    lr_scheduler.step()
    train(net, optimizer)
    test(net)

  0%|          | 0/196 [00:00<?, ?it/s]

Epoch  0


100%|██████████| 196/196 [01:46<00:00,  2.29it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  1.7083885791350384


100%|██████████| 10/10 [00:06<00:00,  1.52it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  1.2960661193847656
Test accuracy  0.525
Epoch  1


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  1.1828298486617146


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  1.0250132751464844
Test accuracy  0.6341
Epoch  2


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.9381964474308248


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  1.3506386596679687
Test accuracy  0.5743
Epoch  3


100%|██████████| 196/196 [01:46<00:00,  2.29it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.7906164785428923


100%|██████████| 10/10 [00:06<00:00,  1.52it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.8387103576660156
Test accuracy  0.7147
Epoch  4


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.6621971112124774


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.6198561859130859
Test accuracy  0.7879
Epoch  5


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.5740808392969929


100%|██████████| 10/10 [00:06<00:00,  1.50it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.7081285766601563
Test accuracy  0.7575
Epoch  6


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.503546419952597


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.5690905517578125
Test accuracy  0.8021
Epoch  7


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.4623216386048161


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.4931868133544922
Test accuracy  0.8373
Epoch  8


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.42281889459308314


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.4739133361816406
Test accuracy  0.837
Epoch  9


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.3917872696658786


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.5404795806884766
Test accuracy  0.8345
Epoch  10


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.35792917568160565


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.4434983367919922
Test accuracy  0.8553
Epoch  11


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.3362467859928705


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.4957904113769531
Test accuracy  0.8414
Epoch  12


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.3158384602592916


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.42221148986816404
Test accuracy  0.8627
Epoch  13


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.3006056521315964


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.4622724334716797
Test accuracy  0.8541
Epoch  14


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.27445624099702254


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.37269473876953124
Test accuracy  0.8764
Epoch  15


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.26071678368108614


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.4489701782226562
Test accuracy  0.8567
Epoch  16


100%|██████████| 196/196 [01:45<00:00,  2.32it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.24460370610563123


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.4020872650146484
Test accuracy  0.8722
Epoch  17


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.24011597364228598


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3999902648925781
Test accuracy  0.8752
Epoch  18


100%|██████████| 196/196 [01:45<00:00,  2.32it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.21737052283572908


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3459312255859375
Test accuracy  0.8897
Epoch  19


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.21150539708989008


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3461247589111328
Test accuracy  0.8896
Epoch  20


100%|██████████| 196/196 [01:45<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.2006787828036717


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3528336303710938
Test accuracy  0.8848
Epoch  21


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.1824371055604852


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.43776494140625
Test accuracy  0.8714
Epoch  22


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.18364146693932767


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.35169449462890623
Test accuracy  0.8893
Epoch  23


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.16657669439303632


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3481072937011719
Test accuracy  0.8944
Epoch  24


100%|██████████| 196/196 [01:45<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.15984546879724582


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.31248771057128905
Test accuracy  0.9009
Epoch  25


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.14954571441120032


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3487221649169922
Test accuracy  0.8974
Epoch  26


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.14511993998775677


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.313746630859375
Test accuracy  0.9048
Epoch  27


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.14146097916729597


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.32204950408935545
Test accuracy  0.9037
Epoch  28


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.134784232406896


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.33651532592773437
Test accuracy  0.9049
Epoch  29


100%|██████████| 196/196 [01:45<00:00,  2.32it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.12834411811995872


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3437420669555664
Test accuracy  0.8981
Epoch  30


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.11745796643425616


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.355141162109375
Test accuracy  0.8979
Epoch  31


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.11171244421251575


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3675396728515625
Test accuracy  0.8996
Epoch  32


100%|██████████| 196/196 [01:45<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.10714654358369964


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3150574035644531
Test accuracy  0.9088
Epoch  33


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.10421801601745645


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3610534591674805
Test accuracy  0.9024
Epoch  34


100%|██████████| 196/196 [01:45<00:00,  2.32it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.09376982428437593


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.35960879211425784
Test accuracy  0.9058
Epoch  35


100%|██████████| 196/196 [01:45<00:00,  2.32it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.0943704898925308


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.32832709197998045
Test accuracy  0.9116
Epoch  36


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.08638757596514662


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3445085708618164
Test accuracy  0.912
Epoch  37


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.08895258734724959


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.35338389434814454
Test accuracy  0.9109
Epoch  38


100%|██████████| 196/196 [01:45<00:00,  2.32it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.08487442528296794


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3125431945800781
Test accuracy  0.9183
Epoch  39


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.0776900912787081


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.33332210693359376
Test accuracy  0.9147
Epoch  40


100%|██████████| 196/196 [01:45<00:00,  2.32it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.07310958339699677


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.31007933349609373
Test accuracy  0.9191
Epoch  41


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.07733184227491824


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3422850891113281
Test accuracy  0.9112
Epoch  42


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.06596260744964286


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.39614991455078125
Test accuracy  0.9088
Epoch  43


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.06820601285720358


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3303802017211914
Test accuracy  0.9184
Epoch  44


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.05977750629452722


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.34624663696289065
Test accuracy  0.9164
Epoch  45


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.06354068277631791


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3539949432373047
Test accuracy  0.9151
Epoch  46


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.06494692692114991


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3418929443359375
Test accuracy  0.9163
Epoch  47


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.06736668948160142


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3540228088378906
Test accuracy  0.9132
Epoch  48


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.05447816416355116


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.3535787292480469
Test accuracy  0.9169
Epoch  49


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.05423128986921238


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  0.34511240234375
Test accuracy  0.9221
Epoch  50


100%|██████████| 196/196 [01:45<00:00,  2.31it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.048123802210451386


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]

Test loss  0.3576927993774414
Test accuracy  0.9186





Save the model to directory

In [17]:
torch.save(net.state_dict(), path)

Function for network pruning

In [15]:
def prune_net(net, percentile, N=128):
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), 
                       lr=5e-4)
    erase_net_tmp(net)

    inc = 0
    for input, target in trainloader_prun:
        if inc > N - 1: break
        input, target = input.cuda(), target.cuda()

        optimizer.zero_grad()
        output = net(input)
        loss = criterion(output, target)
        loss.backward()

        for child1 in net.children():
            if isinstance(child1, nn.Sequential):
                for child2 in child1.children():
                    if isinstance(child2, PreActBottleneck):
                        for child3 in child2.children():
                            if isinstance(child3, MaskedConv2d) or isinstance(child3, MaskedLinear):
                                child3.weight_tmp = (child3.weight_tmp + 
                                child3.weight.data ** 2 * child3.weight.grad.data ** 2 / len(trainloader_prun) / 2.)
                                
                                if child3.bias is not None:
                                    child3.bias_tmp = (child3.bias_tmp +
                                    child3.bias.data ** 2 * child3.bias.grad.data ** 2 / len(trainloader_prun) / 2.)

        inc += 1
        
    values = []
    for child1 in net.children():
        if isinstance(child1, nn.Sequential):
            for child2 in child1.children():
                if isinstance(child2, PreActBottleneck):
                    for child3 in child2.children():
                        if isinstance(child3, MaskedConv2d) or isinstance(child3, MaskedLinear):
                            
                            values += [child3.weight_tmp.view(-1)]
                            
                            if child3.bias is not None:
                                values += [child3.bias_tmp.view(-1)]
                
    values = torch.cat(values, dim=0).cpu().detach().numpy()
    value = np.percentile(values, percentile)
    
    for child1 in net.children():
        if isinstance(child1, nn.Sequential):
            for child2 in child1.children():
                if isinstance(child2, PreActBottleneck):
                    for child3 in child2.children():
                        if isinstance(child3, MaskedConv2d) or isinstance(child3, MaskedLinear):
                            child3.truncate(value)
    return net

Define percentiles

In [16]:
percentiles = np.arange(0, 95, 15)[1:]
percentiles

array([15, 30, 45, 60, 75, 90])

Pruning retraining cycle: prune for each percentile, retrain to recover initial quality

In [17]:
for i, percentile in enumerate(percentiles):
    
    net = resnet50(n_classes)
    net.load_state_dict(torch.load(path))
    net.to(device)
    
    net = prune_net(net, percentile)
        
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), 
                      lr=5e-4)

    n_epochs = 1
    for epoch in range(n_epochs):
        print ('Epoch ', epoch)
        test(net)
        train(net, optimizer)
    
    test(net)
    
    torch.save(net.state_dict(), base_path + str(i+1) + '_' + str(run_number) + '.pth')

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch  0


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  317.1786578125
Test accuracy  0.7911


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.016177984980904326


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Test loss  0.3928516174316406
Test accuracy  0.9317


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch  0


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  165.5460078125
Test accuracy  0.8239


100%|██████████| 196/196 [01:46<00:00,  2.29it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.015746946909289086


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Test loss  0.4207188018798828
Test accuracy  0.9267


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch  0


100%|██████████| 10/10 [00:06<00:00,  1.51it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  38.276552734375
Test accuracy  0.8518


100%|██████████| 196/196 [01:46<00:00,  2.29it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.01441514073414918


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Test loss  0.40541546936035155
Test accuracy  0.9317


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch  0


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  1.85873466796875
Test accuracy  0.8777


100%|██████████| 196/196 [01:46<00:00,  2.29it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.01440507626191865


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Test loss  0.4354674011230469
Test accuracy  0.9291


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch  0


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  6.084126171875
Test accuracy  0.7088


100%|██████████| 196/196 [01:46<00:00,  2.29it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.016161300351235027


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Test loss  0.4405260009765625
Test accuracy  0.9338


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch  0


100%|██████████| 10/10 [00:06<00:00,  1.53it/s]
  0%|          | 0/196 [00:00<?, ?it/s]

Test loss  12.77069228515625
Test accuracy  0.3144


100%|██████████| 196/196 [01:46<00:00,  2.30it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Train loss  0.03013655732442834


100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Test loss  0.4074203857421875
Test accuracy  0.9299
