In [185]:
#imports
from torch.utils.data.sampler import SubsetRandomSampler
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import math
import os
from tqdm.auto import tqdm
import numpy as np
from matplotlib import pyplot as plt
from torchprofile import profile_macs
import time

In [186]:
#utils

# helper functions to measure latency of a regular PyTorch models.
#   Unlike fine-grained pruning, channel pruning
#   can directly leads to model size reduction and speed up.
@torch.no_grad()
def measure_latency(model, dummy_input, n_warmup=20, n_test=100):
    model.eval()
    # warmup
    for _ in range(n_warmup):
        _ = model(dummy_input)
    # real test
    t1 = time.time()
    for _ in range(n_test):
        _ = model(dummy_input)
    t2 = time.time()
    return (t2 - t1) / n_test  # average latency

def get_model_macs(model, inputs) -> int:
    return profile_macs(model, inputs)

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width

Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

In [187]:
#LeNet5 model definition
class LeNet5(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(6),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size = 2, stride = 2))
        self.fc = nn.Linear(400, 120)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(84, num_classes)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.relu(out)
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        return out

In [188]:
def prune(weights, indices, dim=0):
    things = []
    for index,i in enumerate(weights): 
        if index in indices:
            if dim>=1:
                things.append(torch.tensor(i).tolist())
            else:
                things.append(torch.tensor(i).item())
    return torch.tensor(things)

In [189]:
def evaluate(model, loader, device="cpu"):
    correct = 0
    total = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    return 100 * correct / total

In [190]:
def measure(model, loader):
    dummy_input,_ = next(iter(loader))
    model = model.to('cpu')

    size = get_model_size(model=model, count_nonzero_only=True)/(8*2**10)

    latency = measure_latency(model, dummy_input)

    macs = get_model_macs(model, dummy_input)

    param = get_num_parameters(model)

    accuracy = evaluate(model,loader)
    
    return size.item(), latency, param, macs, accuracy

In [191]:
#Variables and model
batch_size=64
def dataset():
    train_data = datasets.MNIST('./data/mnist', train=True, download=True,
                                transform=transforms.Compose([
                                    transforms.Resize((32,32)),
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))
                                ]))
    test_data = datasets.MNIST('./data/mnist', train=False, download=True,
                               transform=transforms.Compose([
                                   transforms.Resize((32,32)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1325,), (0.3105,))
                               ]))

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader
train_loader, test_loader = dataset()

device = "cpu"

model = torch.load("lenet5.pt")
model.eval()
orig_weight_dict = model.state_dict().copy()

In [192]:
orig_size, orig_latency, orig_param, orig_macs, orig_accuracy = measure(model,test_loader)
print("size: ", orig_size, "KB")
print("original_latency: ",orig_latency)
print("macs: ",orig_macs)
print("param: ",orig_param)
print("accuracy: ",orig_accuracy,"%")



size:  241.2109375 KB
original_latency:  0.0036598777770996092
macs:  27060736
param:  61750
accuracy:  95.59 %


In [193]:
#Channel Pruning
model = torch.load("prunedLenet5.pt")
channel_mask=[]

conv=[]
bns=[]
fcs=[]

for i in model.modules():
    if isinstance(i, nn.Conv2d):
        conv.append(i)
    if isinstance(i, nn.BatchNorm2d):
        bns.append(i)
    if isinstance(i, nn.Linear):
        fcs.append(i)

with torch.no_grad():
    sparsity=0.7
    mask_temp=[]
    i=0
    current_conv = conv[i]
    current_bn = bns[i]
    next_conv = conv[i+1] 
    original_channels = current_conv.weight
    n_keep = round((1-sparsity)*len(original_channels))

    mean_conv = torch.linalg.matrix_norm(current_conv.weight)
    threshold_conv = torch.flatten(mean_conv).kthvalue(n_keep).values
    
    indices=[]
    for index, i in enumerate(mean_conv):
        if i<=threshold_conv:
            indices.append(index)
            mask_temp.append(torch.zeros(5,5))
        else:
            mask_temp.append(torch.ones(5,5))
    channel_mask.append(mask_temp)
    
    current_conv.weight.set_(prune(current_conv.weight.detach(),indices,1))
    current_conv.bias.set_(prune(current_conv.bias.detach(),indices))
    current_bn.weight.set_(prune(current_bn.weight.detach(), indices))
    current_bn.bias.set_(prune(current_bn.bias.detach(),indices))
    current_bn.running_mean.set_(prune(current_bn.running_mean.detach(), indices))
    current_bn.running_var.set_(prune(current_bn.running_var.detach(), indices))

    things = []
    for index,i in enumerate(next_conv.weight): 
        things.append(prune(i, indices,1).tolist())
    things = torch.tensor(things)
    next_conv.weight.set_(things)

  things.append(torch.tensor(i).tolist())
  things.append(torch.tensor(i).item())


In [194]:
prune_size, prune_latency, prune_param, prune_macs, prune_accuracy = measure(model,test_loader)
print("size: ", prune_size, "KB")
print("original_latency: ",prune_latency)
print("macs: ",prune_macs)
print("param: ",prune_param)
print("accuracy: ",prune_accuracy,"%")
print("Mask:",channel_mask)



size:  118.54296875 KB
original_latency:  0.0027092885971069336
macs:  15467008
param:  60466
accuracy:  27.25 %
Mask: [[tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]), tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]), tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]), tensor([[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]), tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],