In [1]:
import time
import copy
import torch
import torchvision
import matplotlib.pyplot as plt

from torchvision import models
import torch.nn.functional as F
from torch import nn
import torch.optim as optim
from torch.optim import lr_scheduler

from load_models import load_mobilenet, load_resnet
from load_data import load_imagenette
from utils import batch_accuracy

if torch.cuda.is_available() == True:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
dtype = torch.float32

cuda:0


# Load Model

In [79]:
MOBILENET_PATH = './models/MobileNetV3Small.pt'
model = load_mobilenet(MOBILENET_PATH).to(device)

RESNET_PATH = './models/ResNet18.pt'
model = load_resnet(RESNET_PATH).to(device)


# Load Data

In [2]:
PATH = '/home/florian/data/imagenette2'
train_dl, val_dl = load_imagenette(PATH, 128, normalize=False)



image = next(iter(train_dl))[0][1].unsqueeze(0).to(device)
img_batch, label_batch = next(iter(train_dl))
img_batch, label_batch = img_batch.to(device), label_batch.to(device)

In [81]:
batch_accuracy(model, img_batch, label_batch)

0.9453

# Prune Model

In [8]:
from torch.nn.utils import prune

### Identify Modules to Prune

In [3]:
def get_prunable_modules(model):
    modules_to_prune = []
    for i, m in enumerate(list(model.named_modules())):
        if isinstance(m[1], torch.nn.Conv2d):
            modules_to_prune.append(model.get_submodule(m[0]))
    return modules_to_prune

In [84]:
modules_to_prune = get_prunable_modules(model)

In [71]:
modules_to_prune

[Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
 Conv2d(128, 256, kernel_size=(1, 1), stride=

In [4]:
# Helper Function to determine the min, max and average sparsity of the prunable modules
def get_sparsity(modules):
    sparsities = []
    for m in modules:
        sparsities.append(((m.weight == 0).sum()/m.weight.numel()).item())
    sparsities = torch.Tensor(sparsities)
    return {
        'min':sparsities.min().item(), 
        'max':sparsities.max().item(), 
        'mean':sparsities.mean().item()}

### Pre Pruning Stats

In [54]:
batch_accuracy(model, img_batch, label_batch)

0.9766

In [55]:
get_sparsity(modules_to_prune)

{'min': 0.0, 'max': 0.0, 'mean': 0.0}

### Apply Pruning

In [5]:
def l1_prune(modules, amount):
    for m in modules:
        prune.L1Unstructured(.0).apply(m, 'weight', amount)
    return get_sparsity(modules)

In [73]:
l1_prune(modules_to_prune, .5)

{'min': 0.875, 'max': 0.875, 'mean': 0.875}

### Post Pruning Stats

In [57]:
batch_accuracy(model, img_batch, label_batch)

0.3281

In [58]:
get_sparsity(modules_to_prune)

{'min': 0.5, 'max': 0.5, 'mean': 0.5}

# Finetune model

In [9]:
EPOCHS = 25
RESNET_PATH = './models/ResNet18.pt'
model = load_resnet(RESNET_PATH).to(device)
modules_to_prune = get_prunable_modules(model)

dataloaders = {
    'train':train_dl, 
    'val':val_dl
}
dataset_sizes = {
    'train':len(train_dl.dataset), 
    'val':len(val_dl.dataset)
}

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

print('CPR-1')
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

print('CPR-2')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR2-ResNet18.pt'
#SAVE_PATH = './models/CPR2-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-4')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR4-ResNet18.pt'
#SAVE_PATH = './models/CPR4-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-8')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR8-ResNet18.pt'
#SAVE_PATH = './models/CPR8-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-16')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR16-ResNet18.pt'
#SAVE_PATH = './models/CPR16-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-32')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR32-ResNet18.pt'
#SAVE_PATH = './models/CPR32-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

CPR-1
0.9766
{'min': 0.0, 'max': 0.0, 'mean': 0.0}
CPR-2
Training complete in 9m 45s
Best val Acc: 0.929682
0.9766
{'min': 0.5, 'max': 0.5, 'mean': 0.5}
CPR-4
Training complete in 9m 45s
Best val Acc: 0.897325
0.9531
{'min': 0.75, 'max': 0.75, 'mean': 0.75}
CPR-8
Training complete in 9m 48s
Best val Acc: 0.802803
0.8438
{'min': 0.875, 'max': 0.875, 'mean': 0.875}
CPR-16
Training complete in 9m 47s
Best val Acc: 0.577580
0.6328
{'min': 0.9375, 'max': 0.9375, 'mean': 0.9375}
CPR-32
Training complete in 9m 46s
Best val Acc: 0.245350
0.2109
{'min': 0.96875, 'max': 0.96875, 'mean': 0.96875}


In [10]:
EPOCHS = 100
RESNET_PATH = './models/ResNet18.pt'
model = load_resnet(RESNET_PATH).to(device)
modules_to_prune = get_prunable_modules(model)

dataloaders = {
    'train':train_dl, 
    'val':val_dl
}
dataset_sizes = {
    'train':len(train_dl.dataset), 
    'val':len(val_dl.dataset)
}

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

print('CPR-1')
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

print('CPR-2')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR2-ResNet18.pt'
#SAVE_PATH = './models/CPR2-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-4')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR4-ResNet18.pt'
#SAVE_PATH = './models/CPR4-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-8')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR8-ResNet18.pt'
#SAVE_PATH = './models/CPR8-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-16')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR16-ResNet18.pt'
#SAVE_PATH = './models/CPR16-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

print('CPR-32')
l1_prune(modules_to_prune, .5)
cpr2_state_dict = train_model(model, dataloaders, criterion, optimizer,exp_lr_scheduler, num_epochs=EPOCHS)
print(batch_accuracy(model, img_batch, label_batch))
print(get_sparsity(modules_to_prune))

SAVE_PATH = './models/CPR32-ResNet18.pt'
#SAVE_PATH = './models/CPR32-MobileNet.pt'
torch.save(model.state_dict(), SAVE_PATH)

CPR-1
0.9766
{'min': 0.0, 'max': 0.0, 'mean': 0.0}
CPR-2
Training complete in 38m 56s
Best val Acc: 0.935541
0.9844
{'min': 0.5, 'max': 0.5, 'mean': 0.5}
CPR-4
Training complete in 39m 7s
Best val Acc: 0.903185
0.9609
{'min': 0.75, 'max': 0.75, 'mean': 0.75}
CPR-8
Training complete in 39m 5s
Best val Acc: 0.779873
0.8516
{'min': 0.875, 'max': 0.875, 'mean': 0.875}
CPR-16
Training complete in 39m 9s
Best val Acc: 0.543949
0.5703
{'min': 0.9375, 'max': 0.9375, 'mean': 0.9375}
CPR-32
Training complete in 39m 11s
Best val Acc: 0.241274
0.2891
{'min': 0.96875, 'max': 0.96875, 'mean': 0.96875}


In [6]:
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        #print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        #print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            #print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            #    phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return copy.deepcopy(model.state_dict())

In [270]:
prune.l1_unstructured(model[1].layer1[0].conv1, 'weight', .1)

Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)

In [286]:
t = torch.ones((3,3));t

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

In [287]:
prune.L1Unstructured(.1).prune(t)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 0., 1.]])

In [288]:
prune.L1Unstructured(.1).apply(model[1].layer1[0].conv2, 'weight', .1)

<torch.nn.utils.prune.L1Unstructured at 0x7fa43421a2e0>

In [309]:
(model[1].layer1[0].conv2.weight == 0).sum() / model[1].layer1[0].conv2.weight.numel()

tensor(0.5500, device='cuda:0')

In [310]:
(model[1].layer1[0].conv1.weight == 0).sum() / model[1].layer1[0].conv1.weight.numel()

tensor(0.5500, device='cuda:0')

In [267]:
prune.RandomUnstructured(model[1].conv1, 'weight', .1)

AttributeError: 'ResNet' object has no attribute 'conv2'

In [55]:
t = torch.nn.Conv2d(3,3,1)
t.weight

Parameter containing:
tensor([[[[ 0.2369]],

         [[ 0.3606]],

         [[-0.0933]]],


        [[[-0.0698]],

         [[-0.3554]],

         [[ 0.3857]]],


        [[[-0.5537]],

         [[-0.1553]],

         [[-0.4332]]]], requires_grad=True)

In [61]:
p = prune.l1_unstructured(t, 'weight',.1)

In [62]:
p.weight

tensor([[[[ 0.2369]],

         [[ 0.3606]],

         [[-0.0000]]],


        [[[-0.0000]],

         [[-0.3554]],

         [[ 0.3857]]],


        [[[-0.5537]],

         [[-0.0000]],

         [[-0.4332]]]], grad_fn=<MulBackward0>)

In [None]:
p = torch.ran