In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
from utils.load_data import load_data
from utils.load_model import load_model
from utils.test_model import test_model

In [3]:
dataloaders, dataset_size, class_name = load_data('../data')

In [4]:
conv_idx = [2, 8, 11, 14, 17, 21, 24, 27, 30, 33, 37, 40, 43, 46, 49, 53, 56, 59, 62, 65]
skip_layers = [2, 11, 14, 17, 24, 33, 40, 49, 56, 59, 65]
prune_layers = [x for x in conv_idx if x not in skip_layers]
prune_prob = [0.1, 0.2, 0.3, 0.4, 0.6]

In [5]:
class Node():
    
    def __init__(self, mod):
        self.next = None
        self.prev = None
        self.name = type(mod)
        
        self.prev_pruned = None
        self.pruned = None
        self.no_filter = None
        self.mask_filter = None
        
        self.layer_id = 1
            
    def make_mask(self, mod, skip_layers, prune_layers):
        
        no_filter = mod.weight.data.shape[0]

        if mod.kernel_size == (1,1):
            self.pruned = False
            self.no_filter = None
            self.mask_filter = None
            return
        
        elif self.layer_id in skip_layers:
            mask_filter = torch.ones(no_filter)
            
            self.pruned = False
            self.no_filter = no_filter
            self.mask_filter = mask_filter
            return
        
        elif self.layer_id in prune_layers:
            if self.layer_id <= 8:
                stage = 0
            elif self.layer_id <= 21:
                stage = 1
            elif self.layer_id <= 36:
                stage = 2
            elif self.layer_id <= 53:
                stage = 3
            else:
                stage = 4
            prune_prob_stage = prune_prob[stage]
            weight_copy = mod.weight.data.abs().clone().cpu().numpy()
            L1_norm = np.sum(weight_copy, axis = (1,2,3))
            num_keep = int(no_filter * (1 - prune_prob_stage))
            arg_max = np.argsort(L1_norm)
            arg_max_rev = arg_max[::-1][:num_keep]
            mask_filter = torch.zeros(no_filter)
            mask_filter[arg_max_rev.tolist()] = 1
            
            self.pruned = True
            self.no_filter = num_keep
            self.mask_filter = mask_filter
            return

In [6]:
class PruneDDL():
        
    def __init__(self):
        self.head = None
        
    def append(self, mod, process = False):
        
        if self.head == None:
            self.head = Node(mod)
            self.head.next = None
            self.head.prev = None
        else:
            ptr = self.head
            while ptr.next!=None:
                ptr = ptr.next
            new_node = Node(mod)
            new_node.next = None
            ptr.next = new_node
            new_node.prev = ptr
        
            if new_node.prev == None:
                new_node.layer_id = 2
            else:
                new_node.layer_id = new_node.prev.layer_id + 1
            
            ptr = self.head
            
            if isinstance(mod, nn.Conv2d) and process:
                
                if new_node.layer_id!=2:
                    prev_layer_id = conv_idx[conv_idx.index(new_node.layer_id) - 1]
                    while ptr.layer_id!=prev_layer_id:
                        ptr = ptr.next

                    if ptr.pruned:
                        new_node.prev_pruned = True
                
                new_node.make_mask(mod, skip_layers, prune_layers)
                
    def prev_conv_dim(self, layer_id):
        
        if layer_id == 2:
            return
        prev_layer_id = conv_idx[conv_idx.index(layer_id) - 1]
        ptr = self.head
        
        while ptr.layer_id!=prev_layer_id:
            ptr = ptr.next
        
        return ptr.no_filter

In [7]:
model_orig = load_model('../models/resnet18_01')
model_orig_ddl = PruneDDL()

for m in model_orig.modules():
    model_orig_ddl.append(m, True)

m = [x for x in model_orig.modules() if isinstance(x, nn.Linear)]
model_orig_ddl.append(m[0])

node_orig = model_orig_ddl.head

<All keys matched successfully>


In [8]:
model_prune = load_model()

In [9]:
for m0, m1 in zip(model_orig.modules(), model_prune.modules()):
    
    if node_orig == None:
        break
        
    if isinstance(m0, nn.Conv2d):
        
        if m0.kernel_size == (1,1):
            m1.weight.data = m0.weight.data.clone()
            
        if node_orig.layer_id == 2:
            m1.weight.data = m0.weight.data.clone()
            
        if node_orig.pruned:
            mask = node_orig.mask_filter
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            w = m0.weight.data[idx.tolist(), :, :, :].clone()
            m1.weight.data = w.clone()
            
        if node_orig.prev_pruned:
            ptr = model_orig_ddl.head
            prev_layer_id = conv_idx[conv_idx.index(node_orig.layer_id)-1]
            while ptr.layer_id!=prev_layer_id:
                ptr = ptr.next
            mask = ptr.mask_filter
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            w = m0.weight.data[:, idx.tolist(), :, :].clone()
            m1.weight.data = w.clone()
            
    elif isinstance(m0, nn.BatchNorm2d):
        assert isinstance(m1, nn.BatchNorm2d)
        
        if node_orig.prev.pruned:
            mask = node_orig.prev.mask_filter
            idx = np.squeeze(np.argwhere(np.asarray(mask.cpu().numpy())))
            if idx.size == 1:
                idx = np.resize(idx, (1,))
            m1.weight.data = m0.weight.data[idx.tolist()].clone()
            m1.bias.data = m0.bias.data[idx.tolist()].clone()
            m1.running_mean = m0.running_mean[idx.tolist()].clone()
            m1.running_var = m0.running_var[idx.tolist()].clone()
        else:
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()
            
    elif isinstance(m0, nn.Linear):
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()
            
    node_orig = node_orig.next  
    

In [10]:
test_model(model_orig, dataloaders['test'])

Accuracy of the model on 540 test images is 93.33


In [11]:
test_model(model_prune, dataloaders['test'])

Accuracy of the model on 540 test images is 85.56


In [12]:
# torch.save(model_prune.state_dict(), '../models/pruned_model1')

In [13]:
import torch.optim as optim
from torch.optim import lr_scheduler
import matplotlib.pyplot as plt

In [14]:
from utils.visualize import visualize_model
from utils.train_model import train_model

In [15]:
device = ('cuda:0' if torch.cuda.is_available() else 'cpu')

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_prune.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
trainloaders = {x: dataloaders[x] for x in ['train', 'val']}

In [17]:
model_prune_retrain = train_model(model_prune, trainloaders, dataset_size, criterion, optimizer, exp_lr_scheduler, 10)

Epoch 0/9
----------
train Loss: 0.3407 Acc: 0.8685
val Loss: 0.1921 Acc: 0.9247
Epoch time: 2m 41s

Epoch 1/9
----------
train Loss: 0.2614 Acc: 0.9061
val Loss: 0.1673 Acc: 0.9486
Epoch time: 2m 43s

Epoch 2/9
----------
train Loss: 0.2304 Acc: 0.9100
val Loss: 0.1776 Acc: 0.9452
Epoch time: 2m 42s

Epoch 3/9
----------
train Loss: 0.2228 Acc: 0.9161
val Loss: 0.3381 Acc: 0.9144
Epoch time: 2m 41s

Epoch 4/9
----------
train Loss: 0.1890 Acc: 0.9284
val Loss: 0.1838 Acc: 0.9384
Epoch time: 2m 59s

Epoch 5/9
----------
train Loss: 0.1764 Acc: 0.9320
val Loss: 0.2031 Acc: 0.9281
Epoch time: 2m 57s

Epoch 6/9
----------
train Loss: 0.1767 Acc: 0.9323
val Loss: 0.1432 Acc: 0.9384
Epoch time: 2m 46s

Epoch 7/9
----------
train Loss: 0.1102 Acc: 0.9615
val Loss: 0.1516 Acc: 0.9555
Epoch time: 3m 10s

Epoch 8/9
----------
train Loss: 0.1047 Acc: 0.9612
val Loss: 0.1282 Acc: 0.9555
Epoch time: 2m 57s

Epoch 9/9
----------
train Loss: 0.1016 Acc: 0.9630
val Loss: 0.1218 Acc: 0.9555
Epoch time

In [18]:
test_model(model_prune_retrain, dataloaders['test'])

Accuracy of the model on 540 test images is 95.56


In [19]:
model_prune_retrain.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [26]:
example = torch.rand(1, 3, 224, 224)
example = example.to(device)
traced_script_module = torch.jit.trace(model_prune_retrain, example)
traced_script_module.save('../models/model_pruned.pt')

In [28]:
traced_script_module = torch.jit.trace(model_orig, example)
traced_script_module.save('../models/model_orig.pt')

In [58]:
checkpoint = {
    'model': model_orig.__class__,
    'state_dict' : model_orig.state_dict()
}

torch.save(checkpoint, '../models/model_orig_checkpoint.pth')

In [59]:
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    model_o = checkpoint['model']
    model_o.load_state_dict(torch.load('../models/resnet18_01'))
    
    return model_o

In [60]:
model = load_checkpoint('../models/model_orig_checkpoint.pth')

TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'

In [48]:
torch.save(model_prune_retrain.state_dict(), '../models/pruned_model_retrained')