In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet18

![Original-ResNet-18-Architecture.png](Original-ResNet-18-Architecture.png)

In [2]:
import sys
if '../tensorly-private' not in sys.path:
    sys.path.append('../tensorly-private')
import tensorly as tl
from tensorly.decomposition import parafac
tl.set_backend('pytorch')
A = torch.rand(90, 80)
_, (B,C) = parafac(A, rank=80, init='random', 
                   tol=1e-8, stop_criterion='rec_error_deviation', n_iter_max=5000)
assert (B@C.T).shape == A.shape

torch.linalg.solve has its arguments reversed and does not return the LU factorization.
To get the LU factorization see torch.lu, which can be used with torch.lu_solve or torch.lu_unpack.
X = torch.solve(B, A).solution
should be replaced with
X = torch.linalg.solve(A, B) (Triggered internally at  ../aten/src/ATen/native/BatchLinearAlgebra.cpp:859.)
  solution, _ = torch.solve(matrix2, matrix1)


In [14]:
torch.sqrt(torch.sum((B@C.T - A)**2) / torch.sum((A)**2)).item()

0.0009052125387825072

In [49]:
A

tensor([[-0.1000,  0.5000,  0.9000,  ..., -1.0000,  0.5000,  0.8000],
        [ 0.9000, -0.5000, -0.4000,  ...,  0.1000,  0.5000, -0.9000],
        [ 0.3000, -0.7000,  0.7000,  ..., -0.6000,  0.5000,  0.4000],
        ...,
        [-0.7000,  0.0000, -0.5000,  ..., -1.0000, -0.7000, -0.8000],
        [ 0.8000,  0.3000, -0.7000,  ..., -0.1000,  0.5000, -0.4000],
        [ 0.1000,  0.6000, -0.7000,  ...,  0.8000,  0.7000,  0.0000]])

In [38]:
A.shape, len(factors), factors[0].shape, factors[1].shape

(torch.Size([90, 80]), 2, torch.Size([90, 50]), torch.Size([80, 50]))

In [3]:
orig_model = resnet18(pretrained=True)
orig_model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
REDUCTION_RATES = [2, 4, 8]
rank_map = {}

X = orig_model.conv1.weight.detach()
print('Layer: conv1')
print('Original Shape:', X.shape)
X = X.reshape((X.shape[0], X.shape[1], -1))
print('Flattened Shape:', X.shape)
print('Parameters:', X.numel())
# print('Feasible ranks <', X.numel()/sum(list(X.shape))/REDUCTION_RATE)
ranks = list(int(X.numel()/sum(list(X.shape))/rate) for rate in REDUCTION_RATES)
print('Feasible ranks = ', ranks)
print()
rank_map['conv1'] = ranks

for layer_name in ['layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.1.conv2', 
                   'layer2.0.conv1', 'layer2.0.conv2', 'layer2.1.conv1', 'layer2.1.conv2',
                   'layer3.0.conv1', 'layer3.0.conv2', 'layer3.1.conv1', 'layer3.1.conv2',
                   'layer4.0.conv1', 'layer4.0.conv2', 'layer4.1.conv1', 'layer4.1.conv2']:
    lname, lidx, ltype = layer_name.split('.')
    lidx = int(lidx)
    layer = orig_model.__getattr__(lname)[lidx].__getattr__(ltype)
    X = layer.weight.detach()
    print('Layer:', layer_name)
    print('Original Shape:', X.shape)
    X = X.reshape((X.shape[0], X.shape[1], -1))
    print('Flattened Shape:', X.shape)
    print('Parameters:', X.numel())
    ranks = list(int(X.numel()/sum(list(X.shape))/rate) for rate in REDUCTION_RATES)
    print('Feasible ranks = ', ranks)
    print()
    rank_map[layer_name] = ranks

X = orig_model.fc.weight.detach()
print('Layer: fc')
print('Original Shape:', X.shape)
print('Parameters:', X.numel())
# print('Feasible rank < ', X.numel()/sum(list(X.shape))/rate)
ranks = list(int(X.numel()/sum(list(X.shape))/rate) for rate in REDUCTION_RATES)
print('Feasible ranks = ', ranks)
rank_map['fc'] = ranks

Layer: conv1
Original Shape: torch.Size([64, 3, 7, 7])
Flattened Shape: torch.Size([64, 3, 49])
Parameters: 9408
Feasible ranks =  [40, 20, 10]

Layer: layer1.0.conv1
Original Shape: torch.Size([64, 64, 3, 3])
Flattened Shape: torch.Size([64, 64, 9])
Parameters: 36864
Feasible ranks =  [134, 67, 33]

Layer: layer1.0.conv2
Original Shape: torch.Size([64, 64, 3, 3])
Flattened Shape: torch.Size([64, 64, 9])
Parameters: 36864
Feasible ranks =  [134, 67, 33]

Layer: layer1.1.conv1
Original Shape: torch.Size([64, 64, 3, 3])
Flattened Shape: torch.Size([64, 64, 9])
Parameters: 36864
Feasible ranks =  [134, 67, 33]

Layer: layer1.1.conv2
Original Shape: torch.Size([64, 64, 3, 3])
Flattened Shape: torch.Size([64, 64, 9])
Parameters: 36864
Feasible ranks =  [134, 67, 33]

Layer: layer2.0.conv1
Original Shape: torch.Size([128, 64, 3, 3])
Flattened Shape: torch.Size([128, 64, 9])
Parameters: 73728
Feasible ranks =  [183, 91, 45]

Layer: layer2.0.conv2
Original Shape: torch.Size([128, 128, 3, 3])
F

In [5]:
rank_map

{'conv1': [40, 20, 10],
 'layer1.0.conv1': [134, 67, 33],
 'layer1.0.conv2': [134, 67, 33],
 'layer1.1.conv1': [134, 67, 33],
 'layer1.1.conv2': [134, 67, 33],
 'layer2.0.conv1': [183, 91, 45],
 'layer2.0.conv2': [278, 139, 69],
 'layer2.1.conv1': [278, 139, 69],
 'layer2.1.conv2': [278, 139, 69],
 'layer3.0.conv1': [375, 187, 93],
 'layer3.0.conv2': [566, 283, 141],
 'layer3.1.conv1': [566, 283, 141],
 'layer3.1.conv2': [566, 283, 141],
 'layer4.0.conv1': [759, 379, 189],
 'layer4.0.conv2': [1141, 570, 285],
 'layer4.1.conv1': [1141, 570, 285],
 'layer4.1.conv2': [1141, 570, 285],
 'fc': [169, 84, 42]}