In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
class LowRank(nn.Module):
  """Low-rank factorisation of a convolution kernel.

  Holds three trainable factors:
    * ``T`` -- core tensor of shape (low_rank, low_rank, kernel_size, kernel_size)
    * ``O`` -- output-channel factor of shape (low_rank, out_channels)
    * ``I`` -- input-channel factor of shape (low_rank, in_channels)

  Calling the module (no arguments) contracts the factors into a kernel of
  shape (out_channels, in_channels, kernel_size, kernel_size).
  """

  def __init__(self,
               in_channels: int,
               out_channels: int,
               low_rank: int,
               kernel_size: int):
    super().__init__()
    core_shape = (low_rank, low_rank, kernel_size, kernel_size)
    out_shape = (low_rank, out_channels)
    in_shape = (low_rank, in_channels)

    # Registration order (T, O, I) determines the order of parameters().
    self.T = nn.Parameter(torch.empty(size=core_shape), requires_grad=True)
    self.O = nn.Parameter(torch.empty(size=out_shape), requires_grad=True)
    self.I = nn.Parameter(torch.empty(size=in_shape), requires_grad=True)
    self._init_parameters()

  def _init_parameters(self):
    """Fill T, O and I with zero-mean normal noise, std = gain / sqrt(fan_in).

    The gain is the one torch reports for 'relu'; fan_in is whatever
    ``nn.init._calculate_correct_fan`` derives from each factor's shape.
    """
    # The original author notes that initialization affects convergence
    # stability for this parameterization.
    gain = nn.init.calculate_gain('relu', 0)
    # Factors are sampled in the order T, O, I so the random stream is
    # consumed in a fixed sequence.
    for factor in (self.T, self.O, self.I):
      fan_in = nn.init._calculate_correct_fan(factor, mode='fan_in')
      nn.init.normal_(factor, 0, gain / np.sqrt(fan_in))

  def forward(self):
    """Return the composed kernel, shape (out, in, kernel, kernel)."""
    # einsum indices: x, y = the two rank axes of T; z, w = spatial axes;
    # o = output channels (from O); i = input channels (from I).
    subscripts = "xyzw,xo,yi->oizw"
    return torch.einsum(subscripts, self.T, self.O, self.I)

class Conv2d(nn.Module):
  """2-D convolution whose kernel is the Hadamard product of two low-rank kernels.

  The effective weight is ``W1() * W2()`` where ``W1`` and ``W2`` are
  independent ``LowRank`` modules sharing the same rank. The rank is derived
  from ``ratio`` so that the number of trainable weights is roughly
  ``ratio`` times that of a dense
  (out_channels, in_channels, kernel_size, kernel_size) kernel.

  Args:
    in_channels: number of input channels.
    out_channels: number of output channels.
    kernel_size: side length of the square kernel.
    stride: forwarded to ``F.conv2d``.
    padding: forwarded to ``F.conv2d``.
    bias: when True a zero-initialised bias of size ``out_channels`` is
      learned; otherwise ``self.bias`` is ``None``.
    ratio: target fraction of the dense kernel's parameter count.
  """

  def __init__(self,
               in_channels: int,
               out_channels: int,
               kernel_size: int=3,
               stride: int=1,
               padding: int=0,
               bias: bool=False,
               ratio: float=0.0):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride = stride
    self.padding = padding
    self.ratio = ratio
    self.low_rank = self._calc_from_ratio()

    self.W1 = LowRank(in_channels, out_channels, self.low_rank, kernel_size)
    self.W2 = LowRank(in_channels, out_channels, self.low_rank, kernel_size)
    # `self.bias` is only ever the learnable parameter or None; the boolean
    # flag is not stored under the same name first.
    if bias:
      self.bias = nn.Parameter(torch.zeros(out_channels))
    else:
      self.register_parameter('bias', None)

  def _calc_from_ratio(self) -> int:
    """Return the shared rank of W1/W2 for the requested compression ratio.

    Two ``LowRank`` modules of rank r hold ``2 * (r^2 * k^2 + r * (O + I))``
    weights (k = kernel_size, O/I = out/in channels). Setting this equal to
    ``ratio * O * I * k^2`` and solving the quadratic for r gives the value
    below, rounded up. The result is never smaller than
    ``ceil(sqrt(max(O, I)))``.

    Returns:
      The rank as a built-in ``int`` (not a NumPy scalar), so it can be used
      directly as a tensor dimension.
    """
    # Lower bound on the rank. Presumably chosen so that rank^2 can reach the
    # larger channel count -- TODO confirm against the method's reference.
    min_rank = max(int(np.ceil(np.sqrt(self.out_channels))),
                   int(np.ceil(np.sqrt(self.in_channels))))

    channel_sum = self.out_channels + self.in_channels
    num_target_params = self.out_channels * self.in_channels * \
      (self.kernel_size ** 2) * self.ratio
    # Positive root of  k^2 * r^2 + (O + I) * r - target / 2 = 0.
    ratio_rank = np.sqrt(
        (channel_sum ** 2) / (4 * (self.kernel_size ** 4)) + \
        num_target_params / (2 * (self.kernel_size ** 2))
    ) - channel_sum / (2 * (self.kernel_size ** 2))

    return max(min_rank, int(np.ceil(ratio_rank)))

  def forward(self, x):
    """Convolve ``x`` with the Hadamard product of the two low-rank kernels."""
    # Element-wise product of the two composed kernels.
    W = self.W1() * self.W2()
    return F.conv2d(input=x, weight=W, bias=self.bias,
                    stride=self.stride, padding=self.padding)

Check how the number of parameters tracks the requested compression ratio.

In [39]:
# Parameter count of a dense 256 -> 256 convolution with a 3x3 kernel.
orig_num_params = 256 * 256 * 3 * 3

# One layer per target ratio, constructed in increasing-ratio order.
layer1, layer3, layer5, layer7, layer9 = [
    Conv2d(256, 256, 3, 1, 1, False, target_ratio)
    for target_ratio in (0.1, 0.3, 0.5, 0.7, 0.9)
]

# Trainable parameter count of each layer.
num1, num3, num5, num7, num9 = [
    sum(p.numel() for p in module.parameters() if p.requires_grad)
    for module in (layer1, layer3, layer5, layer7, layer9)
]

# Dense count, compressed count, achieved ratio.
for count in (num1, num3, num5, num7, num9):
    print(orig_num_params, count, count / orig_num_params)

589824 60192 0.10205078125
589824 178050 0.3018697102864583
589824 296434 0.5025804307725694
589824 414792 0.7032470703125
589824 533192 0.9039849175347222


Feedforward test

In [40]:
# Forward-pass smoke test on a batched (1, 128, 16, 16) input:
# 128 -> 256 and 128 -> 128 channels, each without and with a bias term.
for _out_channels, _use_bias in ((256, False), (256, True),
                                 (128, False), (128, True)):
    x = torch.randn(size=(1, 128, 16, 16))
    layer = Conv2d(128, _out_channels, 3, 1, 1, _use_bias, 0.1)
    out = layer(x)
    print(out.shape)

torch.Size([1, 256, 16, 16])
torch.Size([1, 256, 16, 16])
torch.Size([1, 128, 16, 16])
torch.Size([1, 128, 16, 16])


In [3]:
# Same smoke test as the batched case, but the input has no batch axis:
# a single (128, 16, 16) tensor is passed straight to the layer.
for _out_channels, _use_bias in ((256, False), (256, True),
                                 (128, False), (128, True)):
    x = torch.randn(size=(128, 16, 16))
    layer = Conv2d(128, _out_channels, 3, 1, 1, _use_bias, 0.1)
    out = layer(x)
    print(out.shape)

torch.Size([256, 16, 16])
torch.Size([256, 16, 16])
torch.Size([128, 16, 16])
torch.Size([128, 16, 16])


In [None]:
# Parameter count of a dense 256 -> 256 convolution with a 3x3 kernel.
orig_num_params = 256 * 256 * 3 * 3

# One layer per target ratio, constructed in increasing-ratio order.
layer1, layer3, layer5, layer7, layer9 = [
    Conv2d(256, 256, 3, 1, 1, False, target_ratio)
    for target_ratio in (0.1, 0.3, 0.5, 0.7, 0.9)
]

# Count every parameter here (no requires_grad filter).
# NOTE(review): the TypeError recorded under this cell is not what the built-in
# `sum` raises on a generator; the kernel may have had `sum` rebound to
# something else -- confirm with Restart Kernel -> Run All.
num1, num3, num5, num7, num9 = [
    sum(p.numel() for p in module.parameters() )
    for module in (layer1, layer3, layer5, layer7, layer9)
]

# Dense count, compressed count, achieved ratio.
for count in (num1, num3, num5, num7, num9):
    print(orig_num_params, count, count / orig_num_params)

TypeError: object of type 'generator' has no len()