In [1]:
import torch
import torchvision
import torch.nn as nn
# Report whether PyTorch can use a CUDA device (shown as the cell output).
torch.cuda.is_available()

True

In [2]:
import torchvision.models as models 
# Load torchvision's ResNet-18 with pretrained weights; this is the model whose
# size is measured and reduced in the cells below.
# NOTE(review): newer torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument — confirm against the installed version.
net = models.resnet18(pretrained=True)


def get_model_size(model):
    '''
    Return the storage size of a model's parameters in mebibytes (MiB).

    Only tensors yielded by ``model.parameters()`` are counted; buffers
    registered on the module are not included.

    Args:
        model: the nn.Module to measure.

    Returns:
        float: total parameter bytes divided by 1024 * 1024.
    '''
    # Measure the `model` argument itself rather than a notebook-level global,
    # so the result always describes the object the caller passed in.
    total_bytes = sum(
        # Bytes per element * number of elements for each parameter tensor
        param.element_size() * param.nelement()
        for param in model.parameters()
    )

    return total_bytes / (1024 * 1024)

# Baseline: parameter size of the pretrained network before any conversion.
print(f"Total number of MB : {get_model_size(net)}")

Total number of MB : 44.591949462890625


In [3]:
'''To be placed in Preprocessing Module'''

def expand_model(model) -> nn.Module:
    '''
    Flatten a model into an nn.Sequential made of its leaf modules.

    A leaf module is one that has no child modules, for example:
    nn.Conv1d,
    nn.Conv2d,
    nn.BatchNorm2d,
    etc..

    Leaves are collected in the order produced by ``model.modules()``. Only
    the leaf modules themselves are kept; the containers that held them are
    dropped.

    Model must be expanded before passing as argument to the technique modules

    Args:
        model: the nn.Module to flatten.

    Returns:
        nn.Sequential wrapping every leaf module of ``model``.
    '''
    leaf_modules = [
        module for module in model.modules() if not list(module.children())
    ]
    return nn.Sequential(*leaf_modules)


# Rebind `net` to its flattened form so the following cells operate on a flat
# list of leaf layers; the bare `net` on the last line displays the result.
net = expand_model(net)
net

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU(inplace=True)
  (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (8): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (9): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (11): ReLU(inplace=True)
  (12): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (13): BatchNorm2d(64, eps=1e-05, momentum=0.1, affi

In [4]:
import torch.nn as nn
from typing import List


# Pick out the layers you want to binarize.
# Keys are layer classes; a value of True means swap_out_layers converts layers
# of exactly that type, while False (or an absent key) leaves them untouched.
# User input
BIN_LAYERS = {
    nn.Conv1d : True,
    nn.Conv2d : True,
    nn.Linear : True,
    nn.BatchNorm2d : True,
}

def _bin_param_data(child) -> torch.Tensor:
    '''
    Binarize functon

    Binarize each layer's weights and bias, if avaialable
    '''
    
    try:
        child.weight.data.fill_(int(torch.randint(low = 0, high = 15,  size=(1,))))
    except:
        print(f"Cannot .weight {type(child)}")
        # pass
        

    try:
        child.bias.data.fill_(int(torch.randint(low = 0, high = 15,  size=(1,))))
    except:
        print(f"Cannot .bias {type(child)}")
        # pass
    
    return child.type(torch.int8)

def swap_out_layers(model, BIN_LAYERS) -> nn.Module:
    '''
    Build a new nn.Sequential in which selected layers are binarized.

    Every direct child of ``model`` whose exact type is a key of
    ``BIN_LAYERS`` with a truthy value is passed through ``_bin_param_data``;
    all other children are carried over unchanged.

    The binarized version comes from the technique implementation of the binarization method:
    Eg. Hard Sign, Stochastic Sign, Noisy Sign, etc.

    Args:
        model: module whose direct children are inspected.
        BIN_LAYERS: mapping from layer class to a flag enabling conversion.

    Returns:
        nn.Sequential of the (possibly converted) children, in original order.
    '''
    swapped = [
        # .get() yields None for unlisted types, which skips the conversion
        _bin_param_data(layer) if BIN_LAYERS.get(type(layer)) else layer
        for layer in model.children()
    ]
    return nn.Sequential(*swapped)

# Apply the swap to the flattened network, then print each layer's type
# followed by its parameter tensors to inspect the result.
new_net = swap_out_layers(net, BIN_LAYERS)
for child in new_net:
    print(type(child))
    for params in child.parameters():
        print(params)

Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias <class 'torch.nn.modules.conv.Conv2d'>
Cannot .bias

In [5]:
print(f"Total number of MB : {get_model_size(new_net)}")

# Since the model is converted into int8 storage format, notice the significant drop in model size.
# Of course, this comes with the tradeoff of reduced accuracy.

Total number of MB : 11.147987365722656
