In [None]:
import sys
import functools
import numpy as np

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torchvision

import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

from typing import Union, Tuple

import matplotlib.pyplot as plt

In [None]:
# Fix the global random seed through Lightning so that weight initialisation
# and data augmentation are repeatable across kernel restarts.
pl.seed_everything(24)

In [None]:
batch_size = 32

# Evaluation pipeline: tensor conversion + CIFAR-10 normalisation only.
# Used for both the validation and the test split.
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])

# Training pipeline: padded random crop and horizontal flip as augmentation,
# followed by the same tensor conversion and normalisation as above.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])

cifar10_dm = CIFAR10DataModule(
    train_transforms=train_transforms,
    val_transforms=test_transforms,
    test_transforms=test_transforms,
    batch_size=batch_size,
)

In [None]:
# Show the datamodule's repr as the cell output.
cifar10_dm

In [None]:
# 1-D Hamming window of length 7 ...
hamming1d = np.hamming(7)
# ... and its separable 2-D counterpart: the square root of the outer product
# of the 1-D window with itself, written here as a broadcast product.
hamming2d = np.sqrt(hamming1d[:, np.newaxis] * hamming1d[np.newaxis, :])

In [None]:
# Visualise the 2-D window `hamming2d` as an image.
plt.imshow(hamming2d)

In [None]:
def create_model(first_downsampling=True):
    """Build an untrained torchvision ResNet-18 with a 10-class head.

    Args:
        first_downsampling: when True the stock stem is kept unchanged. When
            False, ``conv1`` is replaced by a stride-1 3x3 convolution
            (3 -> 64 channels, padding 1, no bias) and ``maxpool`` by an
            identity, so the stem no longer reduces spatial resolution.

    Returns:
        The (possibly modified) ResNet-18 module.
    """
    net = torchvision.models.resnet18(pretrained=False, num_classes=10)
    if first_downsampling:
        return net

    net.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    net.maxpool = nn.Identity()
    return net

In [None]:
def partialclass(name, cls, *args, **kwargs):
    """Create a subclass of ``cls`` whose ``__init__`` has arguments pre-bound.

    Works like ``functools.partial`` but yields a real class (a small class
    factory), so the result can be used anywhere a class is expected.

    Args:
        name: ``__name__`` of the generated subclass.
        cls: base class whose ``__init__`` is partially applied.
        *args: positional arguments bound ahead of the caller's arguments.
        **kwargs: keyword arguments bound to ``cls.__init__``.

    Returns:
        The newly created subclass of ``cls``.
    """
    namespace = {'__init__': functools.partialmethod(cls.__init__, *args, **kwargs)}
    subclass = type(name, (cls,), namespace)

    # For pickling to work, __module__ has to name the module of the frame
    # that called this factory. Skip this step in environments where
    # sys._getframe is not defined (Jython, for example) or is not defined
    # for arguments greater than 0 (IronPython).
    try:
        caller_globals = sys._getframe(1).f_globals
        subclass.__module__ = caller_globals.get('__name__', '__main__')
    except (AttributeError, ValueError):
        pass

    return subclass


def make_tuple(x):
    """Expand an ``int`` into the pair ``(x, x)``; return anything else unchanged."""
    return (x, x) if isinstance(x, int) else x


class Conv2dHamming(nn.Module):
    """2-D convolution whose learnable kernel is multiplied by a fixed Hamming window.

    The window is ``sqrt(outer(hamming(k_h), hamming(k_w)))``. It is stored as
    a non-trainable buffer and applied to the weight on every forward pass, so
    the raw ``weight`` parameter itself is never modified.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size, an int or an ``(k_h, k_w)`` pair.
        stride: convolution stride (int or pair).
        padding: input padding (int or pair).
        dilation: kernel dilation (int or pair).
        groups: number of channel groups; must divide both channel counts.
        bias: whether to learn an additive bias. ``None`` means "no bias" and
            any other non-bool value (e.g. another layer's bias tensor) means
            "has bias", so callers that forward ``some_conv.bias`` keep working.

    Raises:
        ValueError: if ``groups`` is not positive or does not divide
            ``in_channels`` and ``out_channels``.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True):
        super(Conv2dHamming, self).__init__()

        if groups < 1:
            raise ValueError('groups must be a positive integer')
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')

        kernel_size = make_tuple(kernel_size)
        stride = make_tuple(stride)
        padding = make_tuple(padding)
        dilation = make_tuple(dilation)

        # Grouped convolutions only connect in_channels // groups inputs to
        # each output channel, which is the weight layout F.conv2d expects.
        self.weight = nn.Parameter(
            torch.zeros(out_channels, in_channels // groups, kernel_size[0], kernel_size[1])
        )
        nn.init.kaiming_normal_(self.weight, mode='fan_out', nonlinearity='relu')

        if self._has_bias(bias):
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

        # Fixed (non-trainable) window; registered as a buffer so that it
        # follows the module across devices and is saved with the state dict.
        hamming2d = np.sqrt(np.outer(np.hamming(kernel_size[0]), np.hamming(kernel_size[1])))
        self.register_buffer('hamming2d', torch.from_numpy(hamming2d).to(torch.float32))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

    @staticmethod
    def _has_bias(bias) -> bool:
        """Interpret the ``bias`` constructor argument as a yes/no flag."""
        if isinstance(bias, bool):
            return bias
        # Non-bool values: None means "no bias", anything else "has bias".
        return bias is not None

    def transform_weight(self):
        """Return the weight multiplied element-wise by the Hamming window.

        The ``(k_h, k_w)`` window broadcasts over the two leading (channel)
        dimensions of the weight.
        """
        return self.weight * self.hamming2d

    def forward(self, input: Tensor) -> Tensor:
        """Convolve ``input`` with the windowed weight."""
        return F.conv2d(input, self.transform_weight(), self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}({self.in_channels}, {self.out_channels},"
            f" kernel_size={self.kernel_size}, stride={self.stride}, padding={self.padding},"
            f" bias={self.bias is not None})"
        )


class Conv2dFactorized(nn.Module):
    """A 2-D convolution split into two consecutive 1-D convolutions.

    ``conv1`` uses a ``(k_0, 1)`` kernel and maps ``in_channels`` to
    ``out_channels``; ``conv2`` uses a ``(1, k_1)`` kernel and keeps
    ``out_channels``. Stride, padding and dilation are split the same way:
    the first component goes to ``conv1``, the second to ``conv2``.
    ``groups`` and ``bias`` are forwarded unchanged to both layers.

    Args:
        ConvClass: class used to build both sub-convolutions; it is called
            positionally as ``ConvClass(in, out, kernel, stride, padding,
            dilation, groups, bias)`` and must expose a ``weight`` attribute.
    """

    def __init__(self, in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 ConvClass=nn.Conv2d):
        super().__init__()

        k = make_tuple(kernel_size)
        s = make_tuple(stride)
        p = make_tuple(padding)
        d = make_tuple(dilation)

        # First factor: acts along the first kernel dimension only.
        self.conv1 = ConvClass(
            in_channels, out_channels,
            (k[0], 1), (s[0], 1), (p[0], 0), (d[0], 1),
            groups, bias,
        )
        # Second factor: acts along the second kernel dimension only.
        self.conv2 = ConvClass(
            out_channels, out_channels,
            (1, k[1]), (1, s[1]), (0, p[1]), (1, d[1]),
            groups, bias,
        )

        # Re-initialise both factors (in order) with Kaiming-normal weights.
        for conv in (self.conv1, self.conv2):
            nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, input: Tensor) -> Tensor:
        """Apply ``conv1`` and then ``conv2`` to ``input``."""
        return self.conv2(self.conv1(input))

    
# Factorized convolution whose two 1-D factors are Hamming-windowed:
# Conv2dFactorized with its ConvClass argument pre-bound to Conv2dHamming.
# NOTE(review): "Haming" is spelled with a single "m"; the name is kept as-is
# because modify_model refers to it.
Conv2dHamingFactorized = partialclass('Conv2dHamingFactorized',
                                      Conv2dFactorized, ConvClass=Conv2dHamming)


def modify_model(model, kernel=3, padding=1, factorized=False, hamming=False, big_first=False):
    """Replace, in place, every 3x3 ``nn.Conv2d`` found anywhere in ``model``.

    Each replacement keeps the original layer's channel counts, stride,
    dilation, groups and presence of a bias; only the layer class, kernel
    size and padding change.

    Args:
        model: module whose sub-modules are rewritten in place.
        kernel: kernel size of the replacement layers.
        padding: padding of the replacement layers.
        factorized: use the factorized (two 1-D convolutions) layer.
        hamming: use the Hamming-windowed layer.
        big_first: give layers that take 3 input channels (the image stem)
            a 9x9 kernel with padding 5 instead of ``kernel``/``padding``.

    Returns:
        None; ``model`` is modified in place.
    """
    # The replacement class depends only on the flags, so pick it once.
    if factorized and hamming:
        ConvClass = Conv2dHamingFactorized
    elif factorized:
        ConvClass = Conv2dFactorized
    elif hamming:
        ConvClass = Conv2dHamming
    else:
        ConvClass = nn.Conv2d

    def _replace_layer(parent_module):
        # Iterate over a snapshot: children are re-assigned while looping.
        for cur_name, cur_module in list(parent_module._modules.items()):
            if cur_module is None:
                continue

            if isinstance(cur_module, nn.Conv2d) and cur_module.kernel_size == (3, 3):
                if big_first and cur_module.in_channels == 3:
                    # NOTE(review): padding 5 with a 9x9 kernel enlarges the
                    # feature map by 2 px per side pair (4 would preserve the
                    # size at stride 1) - confirm this is intended.
                    new_kernel = 9
                    new_padding = 5
                else:
                    new_kernel = kernel
                    new_padding = padding

                # Pass a flag, never the bias tensor itself: using a
                # multi-element tensor as a boolean raises inside nn.Conv2d.
                # None (not False) signals "no bias": nn.Conv2d treats it as
                # falsy and Conv2dHamming reads None as "no bias".
                bias_flag = True if cur_module.bias is not None else None

                new_conv = ConvClass(cur_module.in_channels,
                                     cur_module.out_channels,
                                     new_kernel,
                                     cur_module.stride,
                                     new_padding,
                                     cur_module.dilation,
                                     cur_module.groups,
                                     bias_flag)
                setattr(parent_module, cur_name, new_conv)

            # Descend into the original child (a replaced conv has none).
            if len(cur_module._modules) > 0:
                _replace_layer(cur_module)

    _replace_layer(model)

    
def decrease_channel_width(model, coeff):
    """Scale the channel width of every layer in ``model`` by ``coeff``, in place.

    Convolutions (``nn.Conv2d`` and ``Conv2dHamming``), ``nn.BatchNorm2d`` and
    ``nn.Linear`` layers are replaced by freshly initialised layers of the
    same type and configuration with ``int(width * coeff)`` channels.
    Convolutions with 3 input channels keep those 3 inputs (treated as the
    image stem), and a ``Linear`` layer keeps its number of outputs.

    Args:
        model: module whose sub-modules are rewritten in place.
        coeff: width multiplier applied to channel / feature counts.

    Returns:
        None; ``model`` is modified in place.
    """
    def _decrease_func(parent_module):
        # Iterate over a snapshot: children are re-assigned while looping.
        for cur_name, cur_module in list(parent_module._modules.items()):
            if cur_module is None:
                continue

            if isinstance(cur_module, nn.Conv2d) or isinstance(cur_module, Conv2dHamming):
                ConvClass = type(cur_module)
                if cur_module.in_channels == 3:
                    in_channels = 3
                else:
                    in_channels = int(cur_module.in_channels * coeff)

                # Pass a flag, never the bias tensor itself: using a
                # multi-element tensor as a boolean raises inside nn.Conv2d.
                # None (not False) signals "no bias": nn.Conv2d treats it as
                # falsy and Conv2dHamming reads None as "no bias".
                bias_flag = True if cur_module.bias is not None else None

                new_conv = ConvClass(in_channels,
                                     int(cur_module.out_channels * coeff),
                                     cur_module.kernel_size,
                                     cur_module.stride,
                                     cur_module.padding,
                                     cur_module.dilation,
                                     cur_module.groups,
                                     bias_flag)
                setattr(parent_module, cur_name, new_conv)

            elif isinstance(cur_module, nn.BatchNorm2d):
                # Carry over the layer's configuration, not just its width.
                new_bn = nn.BatchNorm2d(int(cur_module.num_features * coeff),
                                        eps=cur_module.eps,
                                        momentum=cur_module.momentum,
                                        affine=cur_module.affine,
                                        track_running_stats=cur_module.track_running_stats)
                setattr(parent_module, cur_name, new_bn)

            elif isinstance(cur_module, nn.Linear):
                new_ln = nn.Linear(int(cur_module.in_features * coeff),
                                   cur_module.out_features,
                                   bias=cur_module.bias is not None)
                setattr(parent_module, cur_name, new_ln)

            # Descend into the original child (replaced leaf layers have none).
            if len(cur_module._modules) > 0:
                _decrease_func(cur_module)

    _decrease_func(model)


def create_custom_model(kernel=3, padding=1, factorized=False, hamming=False,
                        big_first=False, wm=1, first_downsampling=False):
    """Build a ResNet-18 variant with customised convolutions and width.

    Three steps are applied in order: create the base model, swap its 3x3
    convolutions according to ``kernel``/``padding``/``factorized``/
    ``hamming``/``big_first``, then scale the channel width by ``wm``.

    Returns:
        The modified model.
    """
    net = create_model(first_downsampling=first_downsampling)
    modify_model(
        net,
        kernel=kernel,
        padding=padding,
        factorized=factorized,
        hamming=hamming,
        big_first=big_first,
    )
    decrease_channel_width(net, coeff=wm)
    return net


def calc_params(model):
    """Count the weight and bias elements of every sub-module of ``model``.

    Only attributes literally named ``weight`` and ``bias`` are counted, and
    only when they expose a ``size()`` method (i.e. are tensor-like); a
    ``bias`` of ``None`` contributes nothing.

    Args:
        model: module to inspect; traversed with ``named_modules()``.

    Returns:
        ``(total_params, per_layer_params)`` where ``per_layer_params`` maps
        each module name with a non-zero count to that count (plain ints).
    """
    def _numel(tensor_like):
        # np.prod replaces np.product, which was removed in NumPy 2.0.
        return int(np.prod(list(tensor_like.size())))

    total_params = 0
    per_layer_params = {}
    for name, module in model.named_modules():
        layer_params = 0
        for attr in ("weight", "bias"):
            param = getattr(module, attr, None)
            if hasattr(param, "size"):
                layer_params += _numel(param)

        if layer_params != 0:
            per_layer_params[name] = layer_params
            total_params += layer_params

    return total_params, per_layer_params

In [None]:
# Baseline: ResNet-18 without the downsampling stem.
resnet = create_model(first_downsampling=False)

# 3x3 variants: reduced width, and factorized convolutions.
resnet_narrow = create_custom_model(wm=28/64)
resnet_factorized = create_custom_model(factorized=True)

# 7x7 variants (padding 3) with the big-first-layer option enabled.
resnet_7x7_narrow = create_custom_model(
    kernel=7, padding=3, big_first=True, wm=28/64)
resnet_7x7_factorized = create_custom_model(
    kernel=7, padding=3, factorized=True, big_first=True, wm=50/64)

# The same 7x7 variants built from Hamming-windowed convolutions.
resnet_7x7_hamming_narrow = create_custom_model(
    kernel=7, padding=3, hamming=True, big_first=True, wm=28/64)
resnet_7x7_hamming_factorized = create_custom_model(
    kernel=7, padding=3, factorized=True, hamming=True, big_first=True, wm=50/64)

In [None]:
# Parameter count of the baseline in whole millions (`//` floors the value;
# use `/ 1e6` to see the fractional part).
calc_params(resnet)[0] // 1e6

In [None]:
# Parameter count of the narrow 3x3 model, floored to whole millions.
calc_params(resnet_narrow)[0] // 1e6

In [None]:
# Parameter count of the factorized 3x3 model, floored to whole millions.
calc_params(resnet_factorized)[0] // 1e6

In [None]:
# Parameter count of the narrow 7x7 model, floored to whole millions.
calc_params(resnet_7x7_narrow)[0] // 1e6

In [None]:
# Parameter count of the factorized 7x7 model, floored to whole millions.
calc_params(resnet_7x7_factorized)[0] // 1e6

In [None]:
# Parameter count of the narrow 7x7 Hamming model, floored to whole millions.
calc_params(resnet_7x7_hamming_narrow)[0] // 1e6

In [None]:
# Parameter count of the factorized 7x7 Hamming model, floored to whole millions.
calc_params(resnet_7x7_hamming_factorized)[0] // 1e6

---

In [None]:
import torchsummary

In [None]:
# Summary of the factorized 7x7 model for a 3x32x32 input, computed on the CPU.
torchsummary.summary(resnet_7x7_factorized, (3, 32, 32), device='cpu')

In [None]:
# Summary of the factorized 7x7 Hamming model for a 3x32x32 input, computed on the CPU.
torchsummary.summary(resnet_7x7_hamming_factorized, (3, 32, 32), device='cpu')