<a href="https://colab.research.google.com/github/EherSenaw/Prune_tutorial/blob/main/Prune_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab notebook (.ipynb) for [Prune_tutorial](https://github.com/EherSenaw/Prune_tutorial)

- \<Imports\>
> Imports needed for this tutorial.
- \<Configs\>
> Configuration values used in this tutorial. Change these to customize the run.
- \<Utils\>
> Utilities used in this tutorial. They may be helpful for readers looking for a starting point for pruning-related research or implementation.
- \<Model\>
> Sample model used in this tutorial. Any other model could be used instead.


## Imports

In [1]:
import torch
import math

import numpy as np 
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune

from copy import deepcopy
from tqdm.notebook import tqdm
from torchvision import transforms
from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.datasets import CIFAR10, SVHN, MNIST

## Configs

In [2]:
batch_size = 512  # samples per batch for both the train and the test DataLoader
step_size = 10  # StepLR period: the scheduler is stepped once per epoch, so this is in epochs
total_epochs = 30  # number of passes over the training set
learning_rate = 3e-4  # initial learning rate handed to Adam
lr_decay = 1e-1  # StepLR multiplicative decay factor (gamma)

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
save_file_path = './best.pth'  # where the best model's state_dict is written / read back

## Mask convention: True -> keep the weight as it is, False -> force the weight to 0.
## Every table below is keyed by "conv_<in_channels>_<out_channels>_<kernel_size>".
true_3x3 = [[True] * 3 for _ in range(3)]
false_3x3 = [[False] * 3 for _ in range(3)]
cross_3x3 = [
    [True, False, True],
    [False, True, False],
    [True, False, True],
]
custom_3x3 = [
    [False, True, True],
    [True, True, True],
    [True, True, True],
]

## Edit the three tables below to change the masking.
# Kernel level: one 3x3 pattern shared by every (out, in) kernel of the layer.
kernel_mask_dict = dict(
    conv_1_24_3=true_3x3,
    conv_24_32_3=custom_3x3,
    conv_32_32_3=true_3x3,
    conv_32_64_3=true_3x3,
    conv_64_64_3=true_3x3,
)
# Input-channel level: one 3x3 pattern per input channel -> (IC, 3, 3).
IC_mask_dict = dict(
    conv_1_24_3=[true_3x3] * 1,
    conv_24_32_3=[true_3x3] * 24,
    conv_32_32_3=[true_3x3] * 32,
    conv_32_64_3=[true_3x3] * 32,
    conv_64_64_3=[true_3x3] * 64,
)
# Output-channel level: one 3x3 pattern per (out, in) pair -> (OC, IC, 3, 3).
OC_mask_dict = dict(
    conv_1_24_3=[[true_3x3] * 1 for _ in range(24)],
    conv_24_32_3=[[true_3x3] * 24 for _ in range(32)],
    conv_32_32_3=[[true_3x3] * 32 for _ in range(32)],
    conv_32_64_3=[[true_3x3] * 32 for _ in range(64)],
    conv_64_64_3=[[true_3x3] * 64 for _ in range(64)],
)

## Utils

In [3]:
# Hierarchical masking of (already initialized) conv weights
## Use with 'torch.nn.Module.apply()'
@torch.no_grad()
def weight_init_custom(submodule):
    """Zero the masked entries of a 3x3 ``Conv2d`` weight, in place.

    The mask is the product of three levels looked up under the key
    ``"conv_<in_channels>_<out_channels>_<kernel_size>"``:
    ``kernel_mask_dict`` (3x3), ``IC_mask_dict`` (IC x 3 x 3) and
    ``OC_mask_dict`` (OC x IC x 3 x 3). Only the weight is touched; the bias
    is left as it is.

    Modules that are not ``Conv2d``, do not have a 3x3 kernel, or have no
    entry in ``kernel_mask_dict`` are left untouched (the same skip rule
    ``mask_prune`` uses), so the function is safe to pass to
    ``model.apply(...)`` on arbitrary models.

    Args:
        submodule: Any ``nn.Module``; typically supplied by ``Module.apply``.

    Returns:
        None. ``submodule.weight`` is modified in place.
    """
    if not isinstance(submodule, torch.nn.Conv2d):
        return
    if submodule.kernel_size != (3, 3) or submodule.weight is None:
        return

    dict_key = "conv_" + str(submodule.in_channels) + "_" + str(submodule.out_channels) + "_" + str(submodule.kernel_size[0])
    if dict_key not in kernel_mask_dict:
        # No mask configured for this layer shape: nothing to do.
        return

    weight = submodule.weight

    def as_mask(values):
        # Build the mask on the weight's device so CUDA models work too.
        return torch.tensor(values, dtype=torch.float32, device=weight.device)

    # Broadcasting combines (3, 3) * (IC, 3, 3) * (OC, IC, 3, 3) -> (OC, IC, 3, 3).
    mask_tensor = (
        as_mask(kernel_mask_dict[dict_key])    # Kernel-mask
        * as_mask(IC_mask_dict[dict_key])      # IC-mask
        * as_mask(OC_mask_dict[dict_key])      # OC-mask
    )
    weight.copy_(weight * mask_tensor)

# Dump every trainable parameter of a model
@torch.no_grad()
def profiler_custom(model):
    """Print the name and raw tensor of each parameter that requires grad.

    Args:
        model: Module whose ``named_parameters()`` are inspected.
    """
    trainable = (
        (param_name, tensor)
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for param_name, tensor in trainable:
        print(param_name, tensor.data)
# Dump every (sub)module of a model
@torch.no_grad()
def profiler_modules(model):
    """Print ``"<name>: <module>"`` for each entry of ``model.named_modules()``.

    Args:
        model: Module to walk; the first entry printed is the model itself
            under the empty name.
    """
    for entry in model.named_modules():
        module_name, module = entry
        print(f"{module_name}: {module}")

# Hook, for masking gradients (and debugging)
class Hook():
    """Attach forward or full-backward hooks to one named module.

    Args:
        module: ``(name, nn.Module)`` pair, e.g. one item of
            ``model._modules.items()``.
        backward: If False (default) ``hook_fn_fwd`` is registered as a
            forward hook; if True ``hook_fn_full_bwd`` is registered as a
            full backward hook.

    When the module is an ``nn.Sequential`` one hook is registered on each of
    its direct children and ``self.hook`` is a list of handles; otherwise
    ``self.hook`` is a single handle. ``self.module_name`` always holds the
    name from the pair. Call ``close()`` to unregister everything.
    """
    ## TODO: Define gradient mask for mask-target layers
    #mask_dict = {
    #    'conv_1_24_3': [[[[False, True, False] for _ in range(3)] for i in range(1)] for j in range(24)],
    #    'conv_24_32_3': [[[[False, True, False] for _ in range(3)] for i in range(24)] for j in range(32)],
    #    'conv_32_32_3': [[[[False, True, False] for _ in range(3)] for i in range(32)] for j in range(32)],
    #    'conv_32_64_3': [[[[False, True, False] for _ in range(3)] for i in range(32)] for j in range(64)],
    #    'conv_64_64_3': [[[[False, True, False] for _ in range(3)] for i in range(64)] for j in range(64)],
    #}
    def __init__(self, module, backward=False):
        # Keep the name for both hook directions so every instance has the
        # same attributes.
        self.module_name = module[0]
        target = module[1]

        if backward:
            def register(mod):
                return mod.register_full_backward_hook(self.hook_fn_full_bwd)
        else:
            def register(mod):
                return mod.register_forward_hook(self.hook_fn_fwd)

        if isinstance(target, nn.Sequential):
            # Hook each direct child so the per-layer tensors are visible.
            self.hook = [register(child) for child in target._modules.values()]
        else:
            self.hook = register(target)

    def hook_fn_fwd(self, module, input, output):
        """Forward hook; a no-op placeholder for inspecting the forward pass."""
        pass

    def hook_fn_full_bwd(self, module, grad_in=None, grad_out=None):
        """Full backward hook: fail fast if the incoming gradient has NaNs.

        Returns None, so the gradients flowing through ``module`` are left
        unchanged.
        """
        # TODO: Repair the commented part below, which was intended for gradient masking.
        #   masked_grad = grad_out[0] * torch.tensor(self.mask_dict[<layer key>], requires_grad=False).float()
        #   return (masked_grad, )
        with torch.no_grad():
            if grad_out is not None and grad_out[0] is not None:
                #print(f"\t\t<GRAD_OUT shape> {grad_out[0].size()}")
                assert not torch.isnan(grad_out[0]).any(), (
                    f"NaN in grad_out of {type(module).__name__} (hooked via '{self.module_name}')"
                )

    def close(self):
        """Unregister every hook this instance registered."""
        if self.hook is not None:
            if isinstance(self.hook, list):
                for subHook in self.hook:
                    subHook.remove()
            else:
                self.hook.remove()


## Pruning-related utilities

In [4]:
# Key into kernel_mask_dict / IC_mask_dict / OC_mask_dict for the layer that is
# about to be pruned. `mask_prune` sets it right before each
# `MaskPruningMethod.apply(...)` call; `compute_mask` reads it.
MPM_dict_key = None

class MaskPruningMethod(prune.BasePruningMethod):
    """Prune a conv weight with the hand-written masks of the config tables.

    The final mask is ``default_mask`` multiplied by three levels looked up
    under ``MPM_dict_key``: the shared 3x3 kernel mask, the per-input-channel
    mask (IC x 3 x 3) and the per-output-channel mask (OC x IC x 3 x 3).
    """
    PRUNING_TYPE = 'unstructured'
    def compute_mask(self, t, default_mask):
        """Return the combined mask for the tensor being pruned.

        Args:
            t: Tensor being pruned (its values are not used; the masks are
                fixed by the config tables).
            default_mask: Mask accumulated so far; the result has the same
                dtype and lives on the same device.

        Raises:
            KeyError: If ``MPM_dict_key`` is not a key of the mask tables.
        """
        key = MPM_dict_key

        def as_mask(values):
            # Build on default_mask's device so pruning an already-moved
            # (e.g. CUDA) module does not hit a device mismatch.
            return torch.tensor(values, dtype=torch.float32, device=default_mask.device)

        mask = default_mask.clone()
        # Broadcasting expands (3, 3) and (IC, 3, 3) to (OC, IC, 3, 3).
        # Kernel-mask
        mask = mask * as_mask(kernel_mask_dict[key])
        # IC-mask
        mask = mask * as_mask(IC_mask_dict[key])
        # OC-mask
        mask = mask * as_mask(OC_mask_dict[key])
        # Cast back so e.g. a half-precision weight keeps a half-precision mask.
        return mask.to(default_mask.dtype)

def mask_prune(module):
    """Apply ``MaskPruningMethod`` to every configured conv layer of ``module``.

    A ``Conv2d`` is pruned when its key
    ``"conv_<in_channels>_<out_channels>_<kernel_size>"`` is present in
    ``kernel_mask_dict``; everything else is skipped. The key is handed to
    the pruning method through the global ``MPM_dict_key``.

    After this call the original weights are stored as ``weight_orig``.
    !!! Call 'permanent_prune()' BEFORE saving the model !!!
    """
    global MPM_dict_key
    conv_layers = (m for m in module.modules() if isinstance(m, nn.Conv2d))
    for conv in conv_layers:
        layer_key = "conv_{}_{}_{}".format(conv.in_channels, conv.out_channels, conv.kernel_size[0])
        if layer_key in kernel_mask_dict:
            # Tell compute_mask which table entries belong to this layer.
            MPM_dict_key = layer_key
            if hasattr(conv, 'weight'):
                MaskPruningMethod.apply(conv, 'weight')
# Use this BEFORE saving model.
def permanent_prune(module):
    """Make the pruning of every configured conv layer permanent.

    For each ``Conv2d`` whose key
    ``"conv_<in_channels>_<out_channels>_<kernel_size>"`` is in
    ``kernel_mask_dict``, the pruning re-parametrization is removed so the
    masked weights are stored as a plain ``weight`` parameter again.

    Layers that carry no pruning are skipped, so calling this twice (or on a
    model that was never pruned) is a no-op instead of raising ``ValueError``.

    Args:
        module: Module whose sub-modules are scanned.
    """
    for submodule in module.modules():
        if not isinstance(submodule, nn.Conv2d):
            continue
        dict_key = "conv_" + str(submodule.in_channels) +"_"+ str(submodule.out_channels) +"_"+ str(submodule.kernel_size[0])
        if dict_key not in kernel_mask_dict:
            continue
        if not prune.is_pruned(submodule):
            # Nothing to remove; prune.remove would raise ValueError here.
            continue
        # Make prune permanent (weights saved as 'weight')
        prune.remove(submodule, 'weight')

## Model

In [5]:
class Sample_Model(nn.Module):
    """Small CNN for 1-channel images, pruned with the configured masks.

    The conv stack ends with 64 channels; the linear head expects exactly 64
    features after flattening, which holds for 28x28 inputs such as MNIST
    (the spatial size is reduced to 1x1).

    ``mask_prune`` is applied to the conv stack at construction time. Masks
    are looked up by layer shape, so layers with identical
    (in_channels, out_channels, kernel_size) share the same mask entry.

    Args:
        num_classes: Number of output logits of the linear head.
    """
    def __init__(self, num_classes=10):
        super().__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(1, 24, 3, padding=1, bias=True),
            nn.MaxPool2d(2),
            nn.GELU(),

            nn.Conv2d(24, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32, eps=1e-4),
            nn.GELU(),

            nn.Conv2d(32, 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32, eps=1e-4),
            nn.MaxPool2d(2),
            nn.GELU(),

            nn.Conv2d(32, 32, 3, padding=0, bias=False),
            nn.BatchNorm2d(32, eps=1e-4),
            nn.GELU(),

            nn.Conv2d(32, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64, eps=1e-4),
            nn.GELU(),

            nn.Conv2d(64, 64, 3, padding=0, bias=False),
            nn.BatchNorm2d(64, eps=1e-4),
            nn.GELU(),

            nn.Dropout(0.5),
            nn.Conv2d(64, 64, 3, padding=0, bias=False),
            nn.GELU(),
        )
        # Apply Prune
        mask_prune(self.layers)
        # Size the head from the argument instead of a hard-coded 10.
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.layers(x)
        # Flatten everything but the batch dimension for the linear head.
        x = x.view([x.shape[0], -1])
        x = self.fc(x)
        return x

In [10]:
def get_dataloader():
    """Build the MNIST train and test loaders.

    Both datasets live under ``./Data`` (downloaded on demand) and are
    converted to tensors. Only the train loader shuffles; both use the
    configured ``batch_size``.

    Returns:
        ``(train_loader, test_loader)``.
    """
    def make_loader(is_train):
        # MNIST
        dataset = MNIST(
            './Data',
            transform=transforms.Compose([
                transforms.ToTensor(),
            ]),
            download=True,
            train=is_train,
        )
        return torch.utils.data.DataLoader(
            dataset,
            shuffle=is_train,
            batch_size=batch_size,
            num_workers=0,
        )

    train_loader = make_loader(True)
    test_loader = make_loader(False)
    return train_loader, test_loader

@torch.no_grad()
def eval(model, test_loader):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    The model is moved to ``device`` and switched to eval mode (and left in
    that state). Note that this name shadows the built-in ``eval``.

    Args:
        model: Classifier returning per-class scores of shape (batch, classes).
        test_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        Fraction of correctly classified samples.
    """
    model.to(device)
    model.eval()

    n_correct, n_seen = 0, 0
    for images, labels in tqdm(test_loader):
        scores = model(images.to(device))
        # Index of the highest score per sample.
        _, best_idx = scores.max(1)
        predicted = best_idx.detach().cpu().numpy()
        expected = labels.cpu().numpy()
        n_correct += (predicted == expected).sum()
        n_seen += len(expected)

    return n_correct / n_seen


def train_model(model, train_loader, test_loader):
    """Train ``model`` with Adam + StepLR and checkpoint the best epoch.

    After every epoch the model is evaluated on ``test_loader``. Whenever the
    accuracy improves, a deep copy of the model is permanently pruned and its
    ``state_dict`` is written to ``save_file_path``. The mean training loss
    and the accuracy of the epoch are shown on the epoch progress bar.

    Args:
        model: Network to train; must expose the pruned conv stack as
            ``model.layers``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
    """
    # Batches are moved to `device` below, so the model has to live there too.
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=lr_decay)

    best_acc = -1
    # For debugging, uncomment following line
    #torch.autograd.set_detect_anomaly(True)
    epoch_bar = tqdm(range(total_epochs))
    for epoch in epoch_bar:
        # Train
        model.train()
        batch_count = 0
        loss_sum = 0.0
        for img, target in train_loader:
            img, target = img.to(device), target.to(device)
            optimizer.zero_grad()
            out = model(img)
            loss = F.cross_entropy(out, target)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            batch_count += 1

        # Guard against an empty loader instead of dividing by zero.
        mean_loss = loss_sum / batch_count if batch_count else float('nan')

        # Evaluate
        model.eval()
        acc = eval(model, test_loader)
        epoch_bar.set_postfix(loss=round(mean_loss, 4), acc=round(float(acc), 4))
        if acc > best_acc:
            best_acc = acc
            print('Best accuracy is updated: %.4f '%(best_acc))
            # Save model. Work on a copy: making the pruning permanent on the
            # live model would drop its masks for the remaining epochs.
            copied_model = deepcopy(model)
            permanent_prune(copied_model.layers)
            torch.save(copied_model.state_dict(), save_file_path)
            del copied_model

        scheduler.step()

    print("Best Accuracy: %.4f"%(best_acc))

## Train, Evaluate, Save

In [11]:
def main():
    """Build the data loaders and the model, then train, evaluate and save.

    Backward hooks (NaN checks on gradients) are registered on the model's
    top-level modules for the duration of training and are always removed
    afterwards, even when training is interrupted.
    """
    train_loader, test_loader = get_dataloader()
    # Follow the configured device so the notebook also runs without a GPU.
    model = Sample_Model(num_classes=10).to(device)

    # Register hook, for masking (and so on).
    # Forward hooks can be added the same way with Hook(layer, backward=False).
    layers = list(model._modules.items())
    HookB = [Hook(layer, backward=True) for layer in layers]

    try:
        # Train, evaluate, save model
        train_model(model, train_loader, test_loader)
    finally:
        # Unregister hook
        for hook in HookB:
            hook.close()

if __name__=='__main__':
  main()

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Best accuracy is updated: 0.9791 


  0%|          | 0/20 [00:00<?, ?it/s]

Best accuracy is updated: 0.9881 


  0%|          | 0/20 [00:00<?, ?it/s]

Best accuracy is updated: 0.9889 


  0%|          | 0/20 [00:00<?, ?it/s]

Best accuracy is updated: 0.9902 


  0%|          | 0/20 [00:00<?, ?it/s]

Best accuracy is updated: 0.9908 


KeyboardInterrupt: ignored

## Load best model and check pruning

In [12]:
model = Sample_Model(num_classes=10)
# The checkpoint was saved after 'permanent_prune', so it stores plain 'weight'
# tensors. Apply 'permanent_prune' here too so the parameter names match.
permanent_prune(model.layers)
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load(save_file_path, map_location=device))
model = model.to(device)
for name, p in model.named_parameters():
    print(name)
 

layers.0.bias
layers.0.weight
layers.3.weight
layers.4.weight
layers.4.bias
layers.6.weight
layers.7.weight
layers.7.bias
layers.10.weight
layers.11.weight
layers.11.bias
layers.13.weight
layers.14.weight
layers.14.bias
layers.16.weight
layers.17.weight
layers.17.bias
layers.20.weight
fc.weight
fc.bias


In [13]:
# 'layers.3' is the Conv2d(24, 32, 3) layer of Sample_Model; show its weight shape.
model.get_parameter('layers.3.weight').shape

torch.Size([32, 24, 3, 3])

In [14]:
# This layer uses 'custom_3x3' as its kernel mask, whose [0][0] entry is False,
# so the top-left element of every 3x3 kernel is expected to print as 0.
print(model.get_parameter('layers.3.weight'))

Parameter containing:
tensor([[[[ 0.0000,  0.0347, -0.0546],
          [-0.0356, -0.0218, -0.0172],
          [-0.0383,  0.0507, -0.0138]],

         [[-0.0000,  0.0390,  0.0070],
          [ 0.0354, -0.0549,  0.0540],
          [-0.0450,  0.0202, -0.0072]],

         [[-0.0000, -0.0391,  0.0122],
          [-0.0549,  0.0374,  0.0332],
          [ 0.0283, -0.0039, -0.0281]],

         ...,

         [[-0.0000, -0.0011,  0.0185],
          [ 0.0334, -0.0079,  0.0712],
          [ 0.0406,  0.0461,  0.0166]],

         [[ 0.0000,  0.0709,  0.0272],
          [ 0.0035,  0.0010, -0.0408],
          [ 0.0343, -0.0102, -0.0494]],

         [[-0.0000,  0.0581,  0.0133],
          [ 0.0163, -0.0444,  0.0078],
          [-0.0094,  0.0470,  0.0024]]],


        [[[-0.0000, -0.0185, -0.0652],
          [-0.0339, -0.0225,  0.0094],
          [-0.0115, -0.0453, -0.0186]],

         [[-0.0000,  0.0058, -0.0376],
          [ 0.0393,  0.0737, -0.0200],
          [-0.0516, -0.0035,  0.0219]],

         