In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from utils.load_data import load_data
from utils.load_model import load_model
from utils.test_model import test_model

SyntaxError: invalid syntax (test_model.py, line 33)

In [None]:
# Build the data loaders from ../data via the project helper load_data.
# The result is unpacked into three values; later cells in this notebook use
# `dataloaders` (indexed with 'test') and `class_name`.
dataloaders, dataset_size, class_name = load_data('../data')

In [None]:
# layer_id of every nn.Conv2d. A layer_id is the 1-based position of a module in
# the PruneDDL list (the head node is 1, each appended node is previous + 1).
# NOTE(review): the values look like torchvision's ResNet-18 module order — confirm.
conv_idx = [
    2, 8, 11, 14, 17,
    21, 24, 27, 30, 33,
    37, 40, 43, 46, 49,
    53, 56, 59, 62, 65,
]

# Conv layers that keep all of their filters (Node.make_mask gives them an all-ones mask).
skip_layers = [2, 11, 14, 17, 24, 33, 40, 49, 56, 59, 65]

# Every remaining conv layer is a pruning candidate.
prune_layers = [layer_id for layer_id in conv_idx if layer_id not in skip_layers]

# Fraction of filters removed, indexed by the stage number computed in Node.make_mask.
prune_prob = [0.1, 0.2, 0.3, 0.4, 0.6]

In [None]:
class Node():
    """Doubly-linked-list node recording the pruning decision for one module.

    Attributes:
        next, prev: neighbouring nodes; ``None`` until the owning list links them.
        name: ``type(mod)``, the class of the wrapped module.
        layer_id: position in the list; starts at 1 and is reassigned by the
            list when the node is appended after another node.
        prev_pruned: set to ``True`` by the owning list when the preceding
            conv layer was pruned; otherwise left as ``None``.
        pruned, no_filter, mask_filter: filled in by ``make_mask``.
    """

    def __init__(self, mod):
        self.name = type(mod)
        self.layer_id = 1

        # Links are established by the list that owns the node.
        self.prev = None
        self.next = None

        # Pruning bookkeeping, populated later.
        self.mask_filter = None
        self.no_filter = None
        self.pruned = None
        self.prev_pruned = None

    def make_mask(self, mod, skip_layers, prune_layers):
        """Decide which output filters of ``mod`` survive and store the result.

        Args:
            mod: module exposing ``weight`` and ``kernel_size``; the weight is
                reduced over axes 1, 2 and 3, so it must be 4-D.
            skip_layers: layer ids whose filters are all kept.
            prune_layers: layer ids whose weakest filters are dropped.

        The pruning ratio is read from the module-level ``prune_prob`` list.
        If the layer is not 1x1 and its id is in neither list, the node is
        left untouched.
        """
        total_filters = mod.weight.data.shape[0]

        # 1x1 convolutions are never pruned and get no mask at all.
        if mod.kernel_size == (1, 1):
            self.pruned, self.no_filter, self.mask_filter = False, None, None
            return

        # Skipped layers keep every filter: all-ones mask.
        if self.layer_id in skip_layers:
            self.pruned = False
            self.no_filter = total_filters
            self.mask_filter = torch.ones(total_filters)
            return

        if self.layer_id not in prune_layers:
            return

        # Stage = number of upper bounds the layer id exceeds
        # (<=8 -> 0, <=21 -> 1, <=36 -> 2, <=53 -> 3, otherwise 4).
        stage = sum(self.layer_id > upper for upper in (8, 21, 36, 53))
        ratio = prune_prob[stage]

        # Rank filters by the sum of absolute weights and keep the largest ones.
        abs_weights = mod.weight.data.abs().clone().cpu().numpy()
        filter_l1 = abs_weights.sum(axis=(1, 2, 3))
        keep_count = int(total_filters * (1 - ratio))
        strongest_first = np.argsort(filter_l1)[::-1]
        kept = strongest_first[:keep_count]

        keep_mask = torch.zeros(total_filters)
        keep_mask[kept.tolist()] = 1

        self.pruned = True
        self.no_filter = keep_count
        self.mask_filter = keep_mask
        return

In [None]:
class PruneDDL():
    """Doubly linked list of ``Node`` objects, one per appended module.

    The first appended module becomes the head (``layer_id`` 1); every later
    node gets the previous node's ``layer_id`` + 1. The class reads the
    module-level ``conv_idx``, ``skip_layers`` and ``prune_layers`` lists.
    """

    def __init__(self):
        self.head = None
        # Cached last node so append() does not re-walk the list on every call.
        self._tail = None

    def _last_node(self):
        """Return the current last node; ``self.head`` must not be ``None``."""
        tail = self._tail if self._tail is not None else self.head
        # Zero iterations in normal use; only walks if nodes were linked
        # behind this object's back.
        while tail.next is not None:
            tail = tail.next
        self._tail = tail
        return tail

    def _prev_conv_node(self, layer_id):
        """Return the node of the conv layer listed just before ``layer_id`` in ``conv_idx``.

        Raises:
            ValueError: if ``layer_id`` is not in ``conv_idx`` (from
                ``list.index``) or the predecessor is not in the list yet.
        """
        prev_layer_id = conv_idx[conv_idx.index(layer_id) - 1]
        ptr = self.head
        while ptr is not None and ptr.layer_id != prev_layer_id:
            ptr = ptr.next
        if ptr is None:
            raise ValueError(
                'no node with layer_id %r (predecessor of conv layer %r) in the list'
                % (prev_layer_id, layer_id)
            )
        return ptr

    def append(self, mod, process = False):
        """Append a node for ``mod`` at the tail of the list.

        Args:
            mod: module to wrap in a ``Node``.
            process: when true and ``mod`` is an ``nn.Conv2d`` that is not the
                head, record whether the preceding conv layer was pruned
                (``prev_pruned``) and compute this layer's filter mask.
        """
        new_node = Node(mod)

        # First node: becomes the head and is never mask-processed.
        if self.head is None:
            self.head = new_node
            self._tail = new_node
            return

        tail = self._last_node()
        tail.next = new_node
        new_node.prev = tail
        new_node.layer_id = tail.layer_id + 1
        self._tail = new_node

        if isinstance(mod, nn.Conv2d) and process:
            # layer_id 2 is treated as the first conv and has no predecessor.
            if new_node.layer_id != 2:
                if self._prev_conv_node(new_node.layer_id).pruned:
                    new_node.prev_pruned = True

            new_node.make_mask(mod, skip_layers, prune_layers)

    def prev_conv_dim(self, layer_id):
        """Return ``no_filter`` of the conv layer preceding ``layer_id``.

        Returns ``None`` for ``layer_id == 2``, which has no predecessor.
        """
        if layer_id == 2:
            return None
        return self._prev_conv_node(layer_id).no_filter

In [None]:
# Load the network from ../models/resnet18_01 (project helper load_model) and
# mirror its module sequence in a PruneDDL list.
model_orig = load_model('../models/resnet18_01')
model_orig_ddl = PruneDDL()

# modules() yields the root module first, so the root becomes the head node
# (layer_id 1). With process=True every later nn.Conv2d node computes its mask.
for m in model_orig.modules():
    model_orig_ddl.append(m, True)

# NOTE(review): modules() above already yields the nn.Linear layer, so this adds
# a second node for it at the tail — confirm that this is intended.
# This also rebinds `m` from a module to a list of modules.
m = [x for x in model_orig.modules() if isinstance(x, nn.Linear)]
model_orig_ddl.append(m[0])

# Cursor into the list; the weight-copy loop further down advances it.
node_orig = model_orig_ddl.head

In [None]:
# Second model instance that receives the pruned weights in the copy loop below.
# NOTE(review): load_model() is called without a checkpoint path here; presumably
# it returns a freshly constructed network of the same architecture — confirm.
model_prune = load_model()

In [None]:
def _kept_indices(mask):
    """Return, as a list, the positions where ``mask`` is non-zero (filters to keep)."""
    return np.flatnonzero(mask.cpu().numpy()).tolist()


# Copy weights from model_orig into model_prune, dropping pruned filters.
# Restart from the head so this cell can be re-run on its own; the loop below
# consumes node_orig.
node_orig = model_orig_ddl.head

for m0, m1 in zip(model_orig.modules(), model_prune.modules()):

    if node_orig is None:
        break

    if isinstance(m0, nn.Conv2d):
        # Start from the full kernel so layers that are neither pruned nor fed
        # by a pruned layer are still copied, then slice as required.
        w = m0.weight.data

        # Drop this layer's pruned output filters.
        if node_orig.pruned:
            w = w[_kept_indices(node_orig.mask_filter), :, :, :]

        # Drop the input channels produced by filters pruned in the previous
        # conv layer. Slicing `w` rather than the original weight keeps the
        # output-filter selection above when both conditions hold.
        if node_orig.prev_pruned:
            ptr = model_orig_ddl.head
            prev_layer_id = conv_idx[conv_idx.index(node_orig.layer_id) - 1]
            while ptr.layer_id != prev_layer_id:
                ptr = ptr.next
            w = w[:, _kept_indices(ptr.mask_filter), :, :]

        m1.weight.data = w.clone()

    elif isinstance(m0, nn.BatchNorm2d):
        assert isinstance(m1, nn.BatchNorm2d)

        # The node just before a batch norm is taken to be the layer whose
        # output it normalises, so the batch norm keeps the same channels.
        if node_orig.prev.pruned:
            keep = _kept_indices(node_orig.prev.mask_filter)
            m1.weight.data = m0.weight.data[keep].clone()
            m1.bias.data = m0.bias.data[keep].clone()
            m1.running_mean = m0.running_mean[keep].clone()
            m1.running_var = m0.running_var[keep].clone()
        else:
            m1.weight.data = m0.weight.data.clone()
            m1.bias.data = m0.bias.data.clone()
            m1.running_mean = m0.running_mean.clone()
            m1.running_var = m0.running_var.clone()

    elif isinstance(m0, nn.Linear):
        m1.weight.data = m0.weight.data.clone()
        m1.bias.data = m0.bias.data.clone()

    node_orig = node_orig.next
    

In [None]:
# Baseline: run the project helper test_model on the original model with the 'test' loader.
# NOTE(review): the recorded output of the import cell shows test_model.py failing
# with a SyntaxError (line 33); this call cannot run until that file is fixed.
test_model(model_orig, dataloaders['test'])

In [None]:
# Same evaluation for the pruned model, using the weights copied from model_orig.
test_model(model_prune, dataloaders['test'])

In [None]:
# Overwrite the copied weights with the checkpoint '../models/pruned_model_retrained'
# (by its file name, the pruned model after retraining).
# map_location='cpu' lets a checkpoint saved from a GPU be read on a CPU-only
# machine; load_state_dict then copies the values onto model_prune's own device.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model_prune.load_state_dict(torch.load('../models/pruned_model_retrained', map_location='cpu'))

In [None]:
# Evaluate the pruned model again, now with the state dict loaded in the previous cell.
test_model(model_prune, dataloaders['test'])

In [None]:
from utils.visualize import visualize_model
from utils.train_model import train_model

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = ('cuda:0' if torch.cuda.is_available() else 'cpu')
# visualize_model is a project helper; its return value is used as a figure
# object here (it must provide savefig).
fig = visualize_model(model_orig, dataloaders, device, class_name, name='orig')
# Written to the current working directory under the name 'orig'.
fig.savefig('orig')

In [None]:
# Same visualisation for the pruned model. The result gets its own name:
# binding it to `plt` would overwrite the matplotlib.pyplot alias imported at
# the top of the notebook and break any later plt.* call.
fig_prune = visualize_model(model_prune, dataloaders, device, class_name, name='prune')

In [None]:
from torchsummary import summary

In [None]:
# torchsummary report for the original model, given a (3, 224, 224) input size.
summary(model_orig, (3,224,224))

In [None]:
# Same report for the pruned model, for comparison with the original above.
summary(model_prune, (3,224,224))