In [1]:
import copy
import json
import os

import tensorly as tl
import torch

from tddl.factorizations import factorize_network, number_layers
from tddl.models.utils import count_parameters
from tddl.utils.hardware import select_hardware



In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to imported project code are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Make tensorly operate on PyTorch tensors.
tl.set_backend('pytorch')

# Pick the devices for this session.
# NOTE(review): presumably "6" is a CUDA device id and "2" a CPU setting —
# confirm against tddl.utils.hardware.select_hardware.
select_hardware(cuda="6", cpu="2")

In [4]:
# Shell escape: print the kernel's current working directory.
!pwd

/home/jetzeschuurman/gitProjects/phd/tddl/notebooks


In [18]:
# Location for saved output: a "tmp" folder inside the kernel's working
# directory (save_path is not used in the cells visible here).
# Derived from os.getcwd() instead of a user-specific absolute path, so the
# notebook is not tied to one machine; when the kernel runs from the
# notebooks directory this resolves to the same folder as before.
save_path = os.path.join(os.getcwd(), "tmp")

In [5]:
# Path to the checkpoint of the pretrained baseline network.
pretrained = "/scratch/jetzeschuurman/f_mnist/logs/parn_18_d0.5_256_sgd_l0.1_g0.1_sTrue/1633280228/cnn_best"

# load pretrained model
# torch.load unpickles the file, so only point this at checkpoints from a
# trusted source. The checkpoint holds a whole model object (the next cell
# displays it as a ResNet module), not just a state_dict.
# NOTE(review): no map_location is given, so tensors are restored to the
# device they were saved from — confirm this matches the selected hardware.
pretrained_model = torch.load(pretrained)

In [6]:
# Display the module tree of the loaded network.
pretrained_model

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): PreActBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (shortcut): Sequential()
    )
    (1): PreActBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1

In [7]:
# Parameter count of the unfactorized baseline model, kept for later comparison.
# NOTE(review): whether count_parameters counts only trainable parameters is
# not visible here — confirm in tddl.models.utils.
pre_param = count_parameters(pretrained_model)
pre_param

11170122

In [21]:
# Enumerate every (sub)module with a running index. The result is a nested
# dict: name -> (index, module), where containers hold a nested dict of their
# children instead of a module. These indices appear to be the ones used to
# select layers for factorization further down.
number_layers(pretrained_model)

{'conv1': (0,
  Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
 'bn1': (1,
  BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 'layer1': (2,
  {'0': (3,
    {'bn1': (4,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
     'conv1': (5,
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'bn2': (6,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
     'conv2': (7,
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'shortcut': (8, Sequential())}),
   '1': (9,
    {'bn1': (10,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
     'conv1': (11,
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'bn2': (12,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 

In [22]:
# Factorize a deep copy so that `pretrained_model` itself is not modified.
fact_model = copy.deepcopy(pretrained_model)

# Indices of the layers to factorize, using the numbering printed by
# `number_layers` (the visible entries are the conv1/conv2 layers of the
# residual blocks).
# TODO: should the skip-connection (shortcut) layers be considered as well?
# For now they are not.
layers = [
    5, 7, 11, 13,
    18, 20, 25, 27,
    32, 34, 39, 41,
    46, 48, 53, 55,
]
factorization = 'tucker'
rank = 0.5
decompose_weights = True

# NOTE(review): the two settings below are computed but never handed to
# `factorize_network` in this cell — confirm whether they should be passed
# on or can be dropped.
if factorization == 'cp':
    decomposition_kwargs = {'init': 'random'}
else:
    decomposition_kwargs = {}

if factorization == 'tucker':
    fixed_rank_modes = 'spatial'
else:
    fixed_rank_modes = None

# return_error=True makes the call return a per-layer structure that includes
# an error entry; verbose=True prints each visited layer.
output = factorize_network(
    fact_model, layers=layers, factorization=factorization, rank=rank,
    decompose_weights=decompose_weights, return_error=True, verbose=True,
)

0 conv1 <class 'torch.nn.modules.conv.Conv2d'>
1 bn1 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
2 layer1 <class 'torch.nn.modules.container.Sequential'>
3 0 <class 'tddl.models.resnet.PreActBlock'>
4 bn1 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
5 conv1 <class 'torch.nn.modules.conv.Conv2d'>




6 bn2 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
7 conv2 <class 'torch.nn.modules.conv.Conv2d'>
8 shortcut <class 'torch.nn.modules.container.Sequential'>
9 1 <class 'tddl.models.resnet.PreActBlock'>
10 bn1 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
11 conv1 <class 'torch.nn.modules.conv.Conv2d'>
12 bn2 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
13 conv2 <class 'torch.nn.modules.conv.Conv2d'>
14 shortcut <class 'torch.nn.modules.container.Sequential'>
15 layer2 <class 'torch.nn.modules.container.Sequential'>
16 0 <class 'tddl.models.resnet.PreActBlock'>
17 bn1 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
18 conv1 <class 'torch.nn.modules.conv.Conv2d'>
19 bn2 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
20 conv2 <class 'torch.nn.modules.conv.Conv2d'>
21 shortcut <class 'torch.nn.modules.container.Sequential'>
22 0 <class 'torch.nn.modules.conv.Conv2d'>
23 1 <class 'tddl.models.resnet.PreActBlock'>
24 bn1 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
25 conv1

In [35]:
# Inspect what factorize_network returned with return_error=True: each entry
# maps a layer name to (index, error, module-or-nested-dict).
# NOTE(review): in the displayed result, non-factorized layers that follow a
# factorized one (e.g. bn2 at index 6, shortcut at index 8) carry the same
# error value as the preceding conv layer, while layers before the first
# factorized one show None — looks like the error is carried over rather than
# reset per layer; confirm in factorize_network.
output

{'conv1': (0,
  None,
  Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
 'bn1': (1,
  None,
  BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 'layer1': (2,
  None,
  {'0': (3,
    None,
    {'bn1': (4,
      None,
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
     'conv1': (5,
      tensor(1.9908, device='cuda:0', grad_fn=<CopyBackwards>),
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'bn2': (6,
      tensor(1.9908, device='cuda:0', grad_fn=<CopyBackwards>),
      BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
     'conv2': (7,
      tensor(3.0027, device='cuda:0', grad_fn=<CopyBackwards>),
      Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)),
     'shortcut': (8,
      tensor(3.0027, device='cuda:0', grad_fn=<CopyBackwards>),
      Sequential())}),
   '1': (9,
    None

In [2]:
# Load previously saved error values from disk (the file's contents are not
# shown in this notebook).
# NOTE(review): the file name is spelled "erros.json"; it is kept unchanged
# because it has to match the name the file was written under — confirm.
errors_path = "/local/jetzeschuurman/f_mnist/logs/erros.json"

# Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
# relying on the platform's locale default used by open().
with open(errors_path, encoding="utf-8") as f:
    errors = json.load(f)

In [24]:
# Move both networks to the CPU. nn.Module.cpu() moves the module's
# parameters and buffers in place and returns the module itself.
pretrained_model.cpu()

# `res` is the same object as fact_model; because the cell ends in an
# assignment, no module repr is displayed.
res = fact_model.cpu()


In [31]:
# Materialize the (name, parameter) pairs of the factorized model as a list.
fact_param = list(fact_model.named_parameters())

In [32]:
# Number of named parameter tensors in the factorized model (tensors, not
# individual scalar weights).
len(fact_param)

120

In [33]:
# Look at the first (name, Parameter) pair. Note this echoes the full tensor;
# inspect fact_param[0][0] or fact_param[0][1].shape for a shorter output.
fact_param[0]

('conv1.weight',
 Parameter containing:
 tensor([[[[-1.4902e-01,  2.8401e-01, -9.5481e-02],
           [ 3.4867e-02,  4.3530e-01,  3.4018e-02],
           [ 2.6803e-01, -2.4621e-01,  8.4183e-04]]],
 
 
         [[[-6.1156e-01, -3.9539e-01, -5.7530e-01],
           [ 1.2588e-01,  2.9693e-01,  2.7905e-02],
           [ 5.0103e-01,  3.8716e-01,  7.1203e-01]]],
 
 
         [[[ 8.4599e-03, -1.7036e-01,  2.4733e-02],
           [-3.3256e-02, -3.5303e-01, -5.7239e-02],
           [ 8.7161e-02, -2.2029e-01,  2.8481e-02]]],
 
 
         [[[ 1.2420e-01, -1.0697e-01, -6.0135e-02],
           [-8.7881e-02, -2.8268e-01, -1.9654e-01],
           [-2.5943e-02, -2.2010e-01,  7.1406e-02]]],
 
 
         [[[-2.7283e-02,  4.7559e-01,  1.3859e-01],
           [-4.1556e-01, -4.4288e-01,  1.9513e-02],
           [ 3.5110e-01,  2.6675e-01, -3.7744e-02]]],
 
 
         [[[ 1.1480e-01,  3.4332e-01,  2.7545e-02],
           [ 1.6841e-01, -2.6461e-01, -2.8250e-01],
           [ 5.1656e-02, -3.6538e-01, -3.9081e