In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torchvision
from torchvision import datasets
from torchvision.transforms import transforms
from torchsummary import summary
import tltorch

import numpy as np



In [3]:
# baseline = np.prod(shape)

def count_parameters_tt(shape, rank):
    """Count the values stored by the cores of a tensor-train (TT) format.

    Core ``k`` contributes ``rank[k] * shape[k] * rank[k + 1]`` values, so
    ``rank`` has to be indexable up to position ``len(shape)``.

    Args:
        shape: sequence holding the size of every mode of the tensor.
        rank: indexable sequence of TT ranks, including both boundary ranks.

    Returns:
        The sum of all core sizes; ``0`` when ``shape`` is empty.
    """
    return sum(
        rank[mode] * size * rank[mode + 1]
        for mode, size in enumerate(shape)
    )

In [4]:
# Checkpoint of the baseline network used throughout the ResNet part of this
# notebook ("rn18" in the run name presumably means ResNet-18 — TODO confirm).
# NOTE(review): absolute, machine-specific path; this cell only runs where
# /bigdata is mounted.
path = "/bigdata/cifar10/logs/baselines/1646668631/rn18_18_dNone_128_adam_l0.001_g0.1_w0.0_sTrue/cnn_best.pth"
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Later cells access model.layer1, model.layer3, ..., so the
# file is expected to contain a whole module object rather than a state_dict.
model = torch.load(path)

In [5]:
from tddl.factorizations import number_layers

# Display the nested {name: (index, module)} mapping of the model. The indices
# shown here appear to be the integer keys used for `layers` in a later cell
# (e.g. 15 is layer1 -> '1' -> conv2) — TODO confirm for the truncated part.
number_layers(model)

{'conv1': (0,
  Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
 'bn1': (1,
  BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 'relu': (2, ReLU(inplace=True)),
 'maxpool': (3,
  MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)),
 'layer1': (4,
  {'0': (5,
    {'conv1': (6,
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'bn1': (7,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
     'relu': (8, ReLU(inplace=True)),
     'conv2': (9,
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'bn2': (10,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True))}),
   '1': (11,
    {'conv1': (12,
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'bn1': (13,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, t

In [6]:
# Inspect the downsample convolution of layer3's first block; it is stored
# under key 41 in the `layers` dict defined in the next cell.
model.layer3[0].downsample[0]

Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)

In [7]:
# Convolutions selected for the factorisation experiments, keyed by an integer
# layer index (presumably the index reported by number_layers — TODO confirm).
# Note that the keys are ints; the `ranks` table further down uses string keys.
layers = {
    15: model.layer1[1].conv2,
    19: model.layer2[0].conv1,
    28: model.layer2[1].conv1,
    38: model.layer3[0].conv2,
    41: model.layer3[0].downsample[0],
    44: model.layer3[1].conv1,
    60: model.layer4[1].conv1,
    63: model.layer4[1].conv2,
}

In [8]:
import tensorly as tl
from tensorly.decomposition import parafac, tucker, tensor_train
from torch.linalg import norm

# Kernel of layer1[1].conv2; the next cell prints its shape as (64, 64, 3, 3).
tensor = layers[15].weight

# Tensor-train decomposition with hand-picked TT ranks (one more entry than
# the tensor has modes; the first and last entries are 1).
tt_tensors = tensor_train(tensor, rank=[1,32,9,3,1])

# Rebuild the dense tensor from the TT cores and take the norm of the residual.
# This is the absolute error; a later cell divides it by norm(tensor).
approximation = tt_tensors.to_tensor()
error = norm(tensor-approximation)



In [9]:
# Report the rank and shape attributes of the tensorly decomposition.
print('rank',tt_tensors.rank)
print('shape',tt_tensors.shape)

rank (1, 32, 9, 3, 1)
shape (64, 64, 3, 3)


In [10]:
# Number of values stored by the TT cores of the decomposition above.
n_param_tt = count_parameters_tt(rank=tt_tensors.rank, shape=tt_tensors.shape)

In [11]:
# Number of values in the dense tensor: the product of its mode sizes.
n_param = np.prod(tt_tensors.shape)

In [12]:
# Parameter ratio of the TT format relative to the dense tensor
# (a value below 1 means the TT format stores fewer values).
n_param_tt/n_param

0.5579969618055556

In [13]:
# Relative reconstruction error; `error` and `tensor` come from the cell that
# ran tensor_train above.
error/norm(tensor)

tensor(0.4570, device='cuda:0', grad_fn=<DivBackward0>)

In [14]:
# Factorise an existing convolution through tltorch rather than by calling
# tensorly directly.
conv = model.layer4[0].conv2
# A float is passed as rank here; the integer TT ranks that result are shown
# when the layer is displayed in the next cell.
rank = 0.5
# presumably initialises the factors by decomposing the trained kernel —
# TODO confirm against the tltorch documentation
decompose_weights = True
factorization = 'tt'

conv_tt = tltorch.FactorizedConv.from_conv(
    conv, 
    rank=rank, 
    decompose_weights=decompose_weights, 
    factorization=factorization,
    # decomposition_kwargs={"init":"random"},
)

In [15]:
# Display the factorized layer. In the recorded repr the layer-level `rank`
# and the rank of its `weight` are not the same tuple.
conv_tt

FactorizedConv(
  in_channels=512, out_channels=512, kernel_size=(3, 3), rank=(1, 1073, 13, 1073, 1), order=2, padding=[1, 1], bias=False
  (weight): TTTensor(shape=(512, 3, 3, 512), rank=(1, 512, 13, 39, 1))
)

In [16]:
# Rank and shape taken from the factorized weight itself. This rebinds the
# notebook-level name `rank`, which held the float 0.5 until now.
rank = conv_tt.weight.rank
shape = conv_tt.weight.shape

In [17]:
# Show the TT ranks read from conv_tt.weight.
rank

(1, 512, 13, 39, 1)

In [18]:
# Show the shape read from conv_tt.weight.
shape

(512, 3, 3, 512)

In [19]:
# Same tltorch factorisation as for conv_tt, now applied to layers[15]
# (layer1[1].conv2). The tuple in the trailing comment is an explicit-rank
# alternative that is not used.
rank = 0.5 #(1,256,512,256,1)
decompose_weights = True
factorization = 'tt'

conv_tt_r = tltorch.FactorizedConv.from_conv(
    layers[15], 
    rank=rank, 
    decompose_weights=decompose_weights, 
    factorization=factorization,
    # decomposition_kwargs={"init":"random"},
)

In [20]:
# The rank value as it was passed to from_conv (0.5 in the recorded output).
conv_tt_r.input_rank

0.5

In [21]:
# Integer TT ranks reported by the layer; compare with conv_tt_r.weight,
# displayed two cells below, whose rank tuple differs.
conv_tt_r.rank

(1, 101, 9, 101, 1)

In [22]:
# Shape reported by the factorized layer.
conv_tt_r.shape

(64, 3, 3, 64)

In [23]:
# Display the factorized weight; its repr carries the weight's own shape and rank.
conv_tt_r.weight

TTTensor(shape=(64, 3, 3, 64), rank=(1, 64, 9, 27, 1))

In [24]:
# Relative reconstruction error of the tltorch TT weight.
approximation_tt = conv_tt_r.weight.to_tensor()
tensor = layers[15].weight
# The reconstructed weight has shape (64, 3, 3, 64) while the conv kernel is
# (64, 64, 3, 3) (both are printed in later cells), so the last axis is moved
# to the front before the two are subtracted.
error = norm(torch.moveaxis(approximation_tt,3,0)-tensor)/norm(tensor)
error

tensor(0.8161, device='cuda:0', grad_fn=<DivBackward0>)

In [25]:
# Same relative error as in the previous cell, written with permute:
# permute(3,0,1,2) gives the same axis order as moveaxis(..., 3, 0) on a
# 4-D tensor.
approximation_tt = conv_tt_r.weight.to_tensor()
tensor = layers[15].weight
error = norm(approximation_tt.permute(3,0,1,2)-tensor)/norm(tensor)
error

tensor(0.8161, device='cuda:0', grad_fn=<DivBackward0>)

In [26]:
# Shape of the reconstructed TT weight, before any axis reordering.
approximation_tt.shape

torch.Size([64, 3, 3, 64])

In [27]:
# Shape of the original conv kernel, for comparison with the cell above.
tensor.shape

torch.Size([64, 64, 3, 3])

In [28]:
# Kernel shape of layer3[1].conv1 (key 44 in `layers`).
layers[44].weight.shape

torch.Size([256, 256, 3, 3])

In [29]:
# Kernel shape of the layer3[0] downsample conv (key 41), which has a
# 1x1 kernel.
layers[41].weight.shape

torch.Size([256, 128, 1, 1])

In [227]:
# Probe a single layer: factorise it with one fractional TT rank and report
# the parameter ratio and the relative reconstruction error. The numbers found
# this way are the ones written down in the `ranks` table in the next cell.
layer = layers[63]

conv_tt = tltorch.FactorizedConv.from_conv(
    layer,
    rank=17.90,
    decompose_weights=True,
    factorization='tt',
    # decomposition_kwargs={"init":"random"},
)
print(conv_tt.rank)
print(conv_tt.shape)

# Parameters are counted from the weight's own rank: as the earlier inspection
# cells show, conv_tt.rank and conv_tt.weight.rank can be different tuples.
n_param_tt = count_parameters_tt(rank=conv_tt.weight.rank, shape=conv_tt.weight.shape)
n_param = np.prod(layer.weight.shape)
print("param", n_param_tt/n_param)

tensor = layer.weight
# Read-only diagnostic: without no_grad the reconstruction is tracked by
# autograd (the printed error carries a grad_fn), which only costs memory here.
with torch.no_grad():
    approximation_tt = conv_tt.weight.to_tensor()
    # The factorized weight keeps the conv kernel's first axis in last
    # position, so move it back to the front before comparing.
    error = norm(torch.moveaxis(approximation_tt,3,0)-tensor)/norm(tensor)
print("error", error)
print("-"*10)



(1, 18323, 213, 18323, 1)
(512, 3, 3, 512)
param 0.4995659722222222
error tensor(0.2446, device='cuda:0', grad_fn=<DivBackward0>)
----------


In [228]:
# Hand-recorded results of the single-layer probe cell for the ResNet layers.
# Per layer and per target parameter percentage: the rank tuple printed for the
# factorized layer; the trailing comment holds the fractional rank that was
# passed to from_conv and the relative reconstruction error that was printed.
# NOTE(review): the keys here are strings ('15') while `layers` uses int keys
# (15) — confirm which form the consumers of this table expect.
ranks = { # layer: {target_percentage: (tt_rank)} # relative_rank: approximation error
    '15':{
        10: (1, 40, 4, 40, 1), # 0.16: 0.9043
        25: (1, 118, 11, 118, 1), # 0.61: 0.7851
        50: (1, 296, 26, 296, 1), # 2.30: 0.5976
    },
    '19':{
        10: (1, 63, 6, 123, 1), # 0.31: 0.8788
        25: (1, 217, 19, 424, 1), # 1.43: 0.6974
        50: (1, 407, 36, 796, 1), # 3.52: 0.5383
    },
    '28':{
        10: (1, 94, 4, 94, 1), # 0.18: 0.9568
        25: (1, 468, 21, 468, 1), # 1.22: 0.8201
        50: (1, 1168, 53, 1168, 1), # 4.57: 0.6437
    },
    '38':{
        10: (1, 207, 5, 207, 1), # 0.19: 0.9612
        25: (1, 1837, 43, 1837, 1), # 2.39: 0.7571
        50: (1, 4600, 107, 4600, 1), # 8.98: 0.5704
    },
    '41':{
        10: (1, 23, 1, 46, 1), # 0.45: 0.9775
        25: (1, 62, 1, 123, 1), # 1.21: 0.9775
        50: (1, 122, 2, 243, 1), # 2.40: 0.9590
    },
    '44':{
        10: (1, 207, 5, 207, 1), # 0.19: 0.9518
        25: (1, 1837, 43, 1837, 1), # 2.39: 0.7340
        50: (1, 4600, 107, 4600, 1), # 8.98: 0.5535
    },
    '60':{
        10: (1, 425, 5, 425, 1), # 0.19: 0.5757
        25: (1, 7338, 85, 7338, 1), # 4.78: 0.3342
        50: (1, 18323, 213, 18323, 1), # 17.90: 0.2561
    },
    '63':{
        10: (1, 425, 5, 425, 1), # 0.19: 0.5394
        25: (1, 7338, 85, 7338, 1), # 4.78: 0.3132
        50: (1, 18323, 213, 18323, 1), # 17.90: 0.2446
    },
}

In [30]:
# For every selected layer, build two TT factorisations of its kernel and
# print, for each, the layer's rank and shape, the parameter ratio relative to
# the dense kernel, and the relative reconstruction error:
#   1. a factorisation requested with the fractional rank 1.5;
#   2. a reference requested with very large explicit ranks.
for layer in layers.values():
    tensor = layer.weight
    n_param = np.prod(tensor.shape)

    conv_tt = tltorch.FactorizedConv.from_conv(
        layer,
        rank=1.5,
        decompose_weights=True,
        factorization='tt',
        # decomposition_kwargs={"init":"random"},
    )
    print(conv_tt.rank)
    print(conv_tt.shape)

    # Parameters are counted from the weight's own rank: as the earlier
    # inspection cells show, the layer's rank and its weight's rank can differ.
    n_param_tt = count_parameters_tt(rank=conv_tt.weight.rank, shape=conv_tt.weight.shape)
    print("param", n_param_tt/n_param)

    # Read-only diagnostic: no_grad keeps autograd from tracking the
    # reconstruction (otherwise the printed error carries a grad_fn).
    with torch.no_grad():
        approximation_tt = conv_tt.weight.to_tensor()
        # The factorized weight keeps the kernel's first axis in last
        # position, so move it back to the front before comparing.
        error = norm(torch.moveaxis(approximation_tt,3,0)-tensor)/norm(tensor)
    print("error", error)
    print("-"*10)

    # Reference run. In the recorded output this request yields a parameter
    # ratio above 1 and an error around 1e-6, i.e. a near-exact reconstruction.
    # An earlier variant derived explicit ranks from the kernel dimensions,
    # (1, out/2, k*k, in/2, 1); the locals that only served that variant have
    # been removed.
    conv_tt_r = tltorch.FactorizedConv.from_conv(
        layer,
        rank=(1,1000,1000,1000,1),
        decompose_weights=True,
        factorization='tt',
        # decomposition_kwargs={"init":"random"},
    )
    print(conv_tt_r.rank)
    print(conv_tt_r.shape)

    n_param_tt_r = count_parameters_tt(rank=conv_tt_r.weight.rank, shape=conv_tt_r.weight.shape)
    print("param", n_param_tt_r/n_param)
    with torch.no_grad():
        approximation_tt_r = conv_tt_r.weight.to_tensor()
        error_r = norm(torch.moveaxis(approximation_tt_r,3,0)-tensor)/norm(tensor)
    print("error", error_r)
    print("="*10)


(1, 223, 20, 223, 1)
(64, 3, 3, 64)
param 0.4171006944444444
error tensor(0.6635, device='cuda:0', grad_fn=<DivBackward0>)
----------
(1, 1000, 1000, 1000, 1)
(64, 3, 3, 64)
param 2.2222222222222223
error tensor(1.7427e-06, device='cuda:0', grad_fn=<DivBackward0>)
(1, 225, 20, 439, 1)
(64, 3, 3, 128)
param 0.2606336805555556
error tensor(0.6862, device='cuda:0', grad_fn=<DivBackward0>)
----------
(1, 1000, 1000, 1000, 1)
(64, 3, 3, 128)
param 1.7777777777777777
error tensor(1.8776e-06, device='cuda:0', grad_fn=<DivBackward0>)
(1, 545, 25, 545, 1)
(128, 3, 3, 128)
param 0.2794664171006944
error tensor(0.7938, device='cuda:0', grad_fn=<DivBackward0>)
----------
(1, 1000, 1000, 1000, 1)
(128, 3, 3, 128)
param 2.2222222222222223
error tensor(2.0746e-06, device='cuda:0', grad_fn=<DivBackward0>)
(1, 1282, 30, 1282, 1)
(256, 3, 3, 256)
param 0.2029690212673611
error tensor(0.8151, device='cuda:0', grad_fn=<DivBackward0>)
----------
(1, 1000, 1000, 1000, 1)
(256, 3, 3, 256)
param 2.22222222222

In [229]:
# Second baseline checkpoint, read from the cifar10/garipov log directory.
# NOTE(review): absolute, machine-specific path, and torch.load unpickles the
# file — only load checkpoints from a trusted source.
garipov_cifar = torch.load('/bigdata/cifar10/logs/garipov/baselines/1647358615/gar_18_dNone_128_sgd_l0.1_g0.1_w0.0_sTrue/cnn_best.pth')

In [231]:
# Convolutions of the Garipov model selected for factorisation, keyed by an
# integer layer index (presumably assigned the same way as the keys of
# `layers` — TODO confirm). conv1 is not included.
garipov_layers = {
    2: garipov_cifar.conv2,
    4: garipov_cifar.conv3,
    6: garipov_cifar.conv4,
    8: garipov_cifar.conv5,
    10: garipov_cifar.conv6,
}

In [399]:
# Probe a single Garipov layer: factorise it with one fractional TT rank and
# report the parameter ratio and the relative reconstruction error. The
# numbers found this way are recorded in `garipov_ranks_cifar`.
layer = garipov_layers[10]

conv_tt = tltorch.FactorizedConv.from_conv(
    layer,
    rank=12.99,
    decompose_weights=True,
    factorization='tt',
    # decomposition_kwargs={"init":"random"},
)
print(conv_tt.rank)
print(conv_tt.shape)

# Parameters are counted from the weight's own rank: as the earlier inspection
# cells show, the layer's rank and its weight's rank can be different tuples.
n_param_tt = count_parameters_tt(rank=conv_tt.weight.rank, shape=conv_tt.weight.shape)
n_param = np.prod(layer.weight.shape)
print("param", n_param_tt/n_param)

tensor = layer.weight
# Read-only diagnostic: without no_grad the reconstruction is tracked by
# autograd (the printed error carries a grad_fn), which only costs memory here.
with torch.no_grad():
    approximation_tt = conv_tt.weight.to_tensor()
    # The factorized weight keeps the conv kernel's first axis in last
    # position, so move it back to the front before comparing.
    error = norm(torch.moveaxis(approximation_tt,3,0)-tensor)/norm(tensor)
print("error", error)
print("-"*10)

(1, 2215, 101, 2215, 1)
(128, 3, 3, 128)
param 0.7482638888888888
error tensor(0.1801, device='cuda:0', grad_fn=<DivBackward0>)
----------


In [48]:
# Hand-recorded results of the single-layer probe cell for the Garipov model.
# Per layer and per target parameter percentage: the fractional rank passed to
# from_conv (a float, not a TT rank tuple); the trailing comment is the
# relative reconstruction error that was printed for it.
garipov_ranks_cifar = { # layernr: {target_percentage_parameters: fractional_rank} # approximation_error
    2:{
        10: 0.16, # 0.8505
        25: 0.61, # 0.6562
        50: 2.31, # 0.4151
        75: 6.60, # 0.2538
        90: 10.33, # 0.1969
    },
    4:{
        10: 0.31, # 0.8272
        25: 1.43, # 0.6319
        50: 3.52, # 0.4728
        75: 7.79, # 0.3309
        90: 12.28, # 0.2503
    },
    6:{
        10: 0.18, # 0.9334
        25: 1.22, # 0.7713
        50: 4.57, # 0.5768
        75: 12.99, # 0.3873
        90: 20.07, # 0.3033
    },
    8:{
        10: 0.18, # 0.9203
        25: 1.22, # 0.7452
        50: 4.57, # 0.5393
        75: 12.99, # 0.3474
        90: 20.07, # 0.2667
    },
    10:{
        10: 0.18, # 0.7626
        25: 1.22, # 0.2897
        50: 4.57, # 0.2346
        75: 12.99, # 0.1801
        90: 20.07, # 0.1511
    },
}

256 128 1 1


In [230]:
# Display the architecture of the model loaded from the cifar10 checkpoint.
garipov_cifar

GaripovNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=128, out_features=1

In [400]:
# Checkpoint read from the f_mnist log directory (presumably Fashion-MNIST —
# TODO confirm).
# NOTE(review): absolute, machine-specific path, and torch.load unpickles the
# file — only load checkpoints from a trusted source.
garipov_fmnist = torch.load('/bigdata/f_mnist/logs/garipov/baselines/1647955843/gar_18_dNone_128_sgd_l0.1_g0.1_w0.0_sTrue/cnn_best.pth')

In [401]:
# Display the architecture of the model loaded from the f_mnist checkpoint;
# in the recorded output its conv1 takes 1 input channel, versus 3 for the
# cifar10 model.
garipov_fmnist

GaripovNet(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv1_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv5_bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=128, out_features=1