### 0. Validate pretrained
Take the pretrained ResNet32 (``pretrained_models/resnet32-d509ac18.th``) and check the accuracy of the original model.

In [3]:
%run stage_0/validate_pretrained.py

Files already downloaded and verified
Test: [0/79]	Time 9.200 (9.200)	Loss 0.2115 (0.2115)	Prec@1 93.750 (93.750)
Test: [50/79]	Time 0.332 (0.538)	Loss 0.4470 (0.3828)	Prec@1 90.625 (92.371)
 * Prec@1 92.630
 * Timec@1 0.493


### 1. Get initial weights
Take the pretrained ResNet32 (``pretrained_models/resnet32-d509ac18.th``), extract the weights, and store the tensors in ``weights/weigths_base.mat``.

In [20]:
%run stage_1/get_initial_weights.py

Then factor the weights into a CP decomposition using NLS (non-linear least squares) with ``cpd_nls`` from [Tensorlab](https://www.tensorlab.net) in MATLAB (a sample script is in ``MATLAB/script_matlab_decompose.m``). The result is stored in ``weights/weights_nls.mat``.
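
For reference, a rank-$R$ CP decomposition approximates a 4-way convolution kernel as a sum of $R$ outer products of vectors. The symbol names and the mode order below are assumptions chosen for illustration; the layout actually stored in the ``.mat`` files may differ.

$$
\mathcal{K}_{t,s,j,i} \approx \sum_{r=1}^{R} A_{t,r}\, B_{s,r}\, C_{j,r}\, D_{i,r}
$$

Here $A$ and $B$ are the factor matrices for the output and input channels, and $C$ and $D$ are the factor matrices for the two spatial dimensions of the filter.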

### 2. Fine-tune initial decomposition
Replace a set of layers in the model (all with 3x3 filters) with their decomposed versions, using the factors obtained with ``cpd_nls`` (``weigths/weigths_nls.mat``). Then fine-tune the weights (``epochs=50``). The best model is saved in ``decomposed/best_initial_decompose.th``. The replaced layers are listed in ``functional.py``; this example replaces only 6.
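
To get a feel for the savings, the sketch below counts the weights of one 3x3 layer before and after the replacement. It assumes the common four-convolution CP layout (1x1, two one-dimensional depthwise convolutions, 1x1) and a rank of 8, read off the ``Conv2d(16, 8, kernel_size=(1, 1))`` line printed by ``stage_2/check.py``. The function names are hypothetical, and the actual layout is whatever ``functional.py`` defines.

```python
def conv_params(c_in, c_out, k_h, k_w):
    """Weights in a bias-free Conv2d with full channel mixing."""
    return c_in * c_out * k_h * k_w


def cp_conv_params(c_in, c_out, k, rank):
    """Weights in a rank-`rank` CP replacement of a k x k convolution.

    Assumed layout (not confirmed by the source):
    1x1 conv c_in -> rank, depthwise k x 1, depthwise 1 x k, 1x1 conv rank -> c_out.
    """
    pointwise_in = c_in * rank
    vertical = rank * k      # one k x 1 filter per rank component
    horizontal = rank * k    # one 1 x k filter per rank component
    pointwise_out = rank * c_out
    return pointwise_in + vertical + horizontal + pointwise_out


original = conv_params(16, 16, 3, 3)
decomposed = cp_conv_params(16, 16, k=3, rank=8)
print(original, decomposed)  # 2304 304
```

Under these assumptions, a 16-to-16 layer shrinks from 2304 to 304 weights.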

In [None]:
%run stage_2/fine_tune_initial_decomposition.py

![initial_decomposition](stage_2/initial_decomposition.png)

You can run ``stage_2/check.py`` to inspect the model structure and the test-set accuracy of ``decompose/best_initial_decompose.th`` (this is quick). Accuracy is about 1.1 percentage points below the original model (91.52 vs. 92.63).

In [11]:
%run stage_2/check.py

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Sequential(
          (0_decomposed): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1)

Files already downloaded and verified
Test: [0/79]	Time 7.342 (7.342)	Loss 0.3062 (0.3062)	Prec@1 92.969 (92.969)
Test: [50/79]	Time 0.303 (0.508)	Loss 0.5858 (0.4408)	Prec@1 90.625 (91.330)
 * Prec@1 91.520
 * Timec@1 0.430


### 3. Extend Decomposed Kernels
Take the fine-tuned model from the previous step (``decomposed/best_initital_decompose.th``). Extend the factors (``1_decomposed``, ``2_decomposed``) to sizes 1x21 and 21x1. Add the sigmas (``resnet_with_sigmas.py``) and train the model together with them. The model with the best approximation is saved in ``decompose/best_extended_decompose.th``.
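
One way to extend a factor is to zero-pad it symmetrically, as in the minimal NumPy sketch below. Whether ``stage_3/extend_decomposed_kernels.py`` initializes the new taps with zeros is an assumption, and ``extend_factor`` and the array shapes are hypothetical.

```python
import numpy as np


def extend_factor(factor, new_size=21):
    """Zero-pad a 1-D kernel factor symmetrically along its last axis."""
    old_size = factor.shape[-1]
    if new_size < old_size or (new_size - old_size) % 2 != 0:
        raise ValueError("new_size must be >= old size and differ by an even amount")
    pad = (new_size - old_size) // 2
    # Pad only the last axis; leave all leading axes untouched.
    pad_width = [(0, 0)] * (factor.ndim - 1) + [(pad, pad)]
    return np.pad(factor, pad_width)


# Hypothetical 1x3 factor holding 8 rank components.
factor_1x3 = np.arange(24, dtype=np.float32).reshape(8, 3)
factor_1x21 = extend_factor(factor_1x3)
print(factor_1x21.shape)  # (8, 21)
```

With zero padding the original three taps stay in the center, so the layer computes the same output as before, provided the convolution's padding along that axis grows from 1 to 10.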

In [None]:
%run stage_3/extend_decomposed_kernels.py

![best_extended_decomposition](stage_3/extended_kernels_decomposition.png)

![sigmas](stage_3/sigmas_values.png)

You can run ``stage_3/check.py`` to inspect the model structure, the sigma values, the kernel sizes they correspond to, and the test-set accuracy of ``decompose/best_extended_decompose.th`` (this is quick).

In [17]:
%run stage_3/check.py

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Sequential(
          (0_decomposed): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1)

best sigmas:  [-1.0235847234725952, -1.6244574785232544, -0.6324220895767212, -1.0306727886199951, -0.5999158024787903, -0.45050451159477234]
best kernels_sz:  [9, 7, 11, 9, 11, 11]
Files already downloaded and verified
Test: [0/79]	Time 7.086 (7.086)	Loss 0.2054 (0.2054)	Prec@1 91.406 (91.406)
Test: [50/79]	Time 0.432 (0.575)	Loss 0.4664 (0.3881)	Prec@1 85.938 (88.817)
 * Prec@1 89.050
 * Timec@1 0.520


### 4. Get decomposed weights
Take the model obtained in the previous step, apply the mask to the weights, and crop them to the resulting kernel sizes. Then store the decomposed weights in ``weights/decomposed_weights.mat``.

In [21]:
import warnings
import numpy as np  # needed for np.VisibleDeprecationWarning
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

%run stage_4/get_decomposed_weights.py

kernels_sz:  [9, 7, 11, 9, 11, 11]
sigmas:  [-1.0235847234725952, -1.6244574785232544, -0.6324220895767212, -1.0306727886199951, -0.5999158024787903, -0.45050451159477234]
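
The mask-and-crop step can be sketched as follows. The Gaussian-shaped mask is a stand-in chosen for illustration only: the actual mask and the mapping from a sigma to a kernel size are defined in ``resnet_with_sigmas.py`` and are not reproduced here. ``mask_and_crop`` and the array shapes are hypothetical.

```python
import numpy as np


def mask_and_crop(factor, mask, size):
    """Multiply a 1-D factor by a mask, then keep the central `size` taps."""
    masked = factor * mask
    start = (factor.shape[-1] - size) // 2
    return masked[..., start:start + size]


taps = np.arange(21) - 10                      # tap offsets -10..10
mask = np.exp(-0.5 * (taps / 2.0) ** 2)        # hypothetical Gaussian-shaped mask
factor = np.ones((8, 21), dtype=np.float32)    # hypothetical extended factor
cropped = mask_and_crop(factor, mask, size=9)
print(cropped.shape)  # (8, 9)
```

Cropping around the center keeps the kernel symmetric, so an odd target size such as 9 leaves four taps on each side of the central one.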


Then compose the tensors back with ``cpdgen`` from [Tensorlab](https://www.tensorlab.net) in MATLAB (a sample script is in ``MATLAB/script_matlab_compose.m``). The result is stored in ``weights/weights_composed.mat``.
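
For readers without MATLAB, the same composition (a sum of rank-one outer products of the factor columns) can be sketched in NumPy. This is an illustrative equivalent, not the script used here; ``cp_to_tensor``, the factor shapes, and the mode order are assumptions.

```python
import numpy as np


def cp_to_tensor(factors):
    """Sum of R outer products; factors[n] has shape (I_n, R)."""
    a, b, c, d = factors
    return np.einsum("ir,jr,kr,lr->ijkl", a, b, c, d)


rng = np.random.default_rng(0)
rank = 8
# Hypothetical mode sizes for a 16 -> 16 layer with a 9 x 9 kernel.
shapes = (16, 16, 9, 9)
factors = [rng.standard_normal((n, rank)) for n in shapes]
kernel = cp_to_tensor(factors)
print(kernel.shape)  # (16, 16, 9, 9)
```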

### 5. Final fine-tune
Replace the set of layers in the model with regular ``conv2`` layers that use the new filter sizes and the weights composed in the previous step. Run the final fine-tuning (``epochs=50``). The best model is saved in ``decomposed/best_final.th``.
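
The new kernel sizes determine both the padding and the weight count of each replaced layer. The small sketch below checks the arithmetic for the sizes reported in stage 4; the helper names are hypothetical, and the 16-to-16 channel count is the one shown for ``layer1.1.conv2``.

```python
def same_padding(kernel_size):
    """Padding that keeps the spatial size at stride 1 for an odd kernel."""
    if kernel_size % 2 == 0:
        raise ValueError("expected an odd kernel size")
    return (kernel_size - 1) // 2


def conv_weight_count(c_in, c_out, kernel_size):
    """Weights in a bias-free square convolution."""
    return c_in * c_out * kernel_size * kernel_size


kernel_sizes = [9, 7, 11, 9, 11, 11]  # from the stage 4 output
paddings = [same_padding(k) for k in kernel_sizes]
print(paddings)                       # [4, 3, 5, 4, 5, 5]
print(conv_weight_count(16, 16, 3))   # 2304
print(conv_weight_count(16, 16, 9))   # 20736
```

The 9x9 case matches the ``padding=(4, 4)`` and the 20736 parameters that the stage 5 outputs report for ``layer1.1.conv2``.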

In [28]:
%run stage_5/final_fine_tune.py

DataParallel(
  (module): ResNet(
    (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential()
      )
      (1): BasicBlock(
        (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(16, 16, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4), bias=False)
       

Test: [50/79]	Time 0.404 (0.577)	Loss 0.3990 (0.3361)	Prec@1 90.625 (90.119)
 * Prec@1 90.080
 * Timec@1 0.520
best precision:  90.08
Epoch: [1][0/391]	Time 12.782 (12.782)	Data 10.005 (10.005)	Loss 0.0851 (0.0851)	Prec@1 98.438 (98.438)
Epoch: [1][50/391]	Time 2.878 (2.925)	Data 0.001 (0.197)	Loss 0.2397 (0.1652)	Prec@1 90.625 (94.393)
Epoch: [1][100/391]	Time 2.787 (2.848)	Data 0.001 (0.100)	Loss 0.1388 (0.1692)	Prec@1 96.094 (94.144)
Epoch: [1][150/391]	Time 3.821 (2.807)	Data 0.000 (0.067)	Loss 0.1967 (0.1674)	Prec@1 92.188 (94.221)
Epoch: [1][200/391]	Time 2.650 (2.842)	Data 0.002 (0.050)	Loss 0.1502 (0.1692)	Prec@1 93.750 (94.213)
Epoch: [1][250/391]	Time 2.799 (2.807)	Data 0.000 (0.041)	Loss 0.1251 (0.1661)	Prec@1 94.531 (94.273)
Epoch: [1][300/391]	Time 2.893 (2.851)	Data 0.001 (0.034)	Loss 0.1617 (0.1642)	Prec@1 92.969 (94.321)
Epoch: [1][350/391]	Time 2.852 (2.861)	Data 0.001 (0.029)	Loss 0.1448 (0.1630)	Prec@1 96.094 (94.351)
Test: [0/79]	Time 8.000 (8.000)	Loss 0.1636 (0.16

Test: [50/79]	Time 0.333 (0.444)	Loss 0.3590 (0.3024)	Prec@1 89.844 (90.564)
 * Prec@1 90.690
 * Timec@1 0.401
best precision:  90.8
Epoch: [9][0/391]	Time 9.761 (9.761)	Data 7.667 (7.667)	Loss 0.1564 (0.1564)	Prec@1 94.531 (94.531)
Epoch: [9][50/391]	Time 2.025 (2.189)	Data 0.000 (0.151)	Loss 0.1008 (0.1420)	Prec@1 96.875 (95.221)
Epoch: [9][100/391]	Time 2.031 (2.110)	Data 0.000 (0.076)	Loss 0.1681 (0.1348)	Prec@1 93.750 (95.429)
Epoch: [9][150/391]	Time 2.036 (2.084)	Data 0.000 (0.051)	Loss 0.1482 (0.1333)	Prec@1 95.312 (95.494)
Epoch: [9][200/391]	Time 2.026 (2.071)	Data 0.000 (0.038)	Loss 0.1173 (0.1327)	Prec@1 96.875 (95.503)
Epoch: [9][250/391]	Time 2.035 (2.063)	Data 0.000 (0.031)	Loss 0.1211 (0.1343)	Prec@1 95.312 (95.440)
Epoch: [9][300/391]	Time 2.021 (2.058)	Data 0.000 (0.026)	Loss 0.1343 (0.1359)	Prec@1 95.312 (95.416)
Epoch: [9][350/391]	Time 2.034 (2.054)	Data 0.000 (0.022)	Loss 0.1286 (0.1359)	Prec@1 95.312 (95.430)
Test: [0/79]	Time 6.071 (6.071)	Loss 0.1746 (0.1746)	P

Test: [50/79]	Time 0.330 (0.443)	Loss 0.3460 (0.2941)	Prec@1 89.844 (90.809)
 * Prec@1 90.910
 * Timec@1 0.400
best precision:  91.09
Epoch: [17][0/391]	Time 9.643 (9.643)	Data 7.518 (7.518)	Loss 0.1466 (0.1466)	Prec@1 94.531 (94.531)
Epoch: [17][50/391]	Time 2.040 (2.176)	Data 0.000 (0.148)	Loss 0.0951 (0.1256)	Prec@1 98.438 (95.818)
Epoch: [17][100/391]	Time 2.034 (2.102)	Data 0.000 (0.075)	Loss 0.1345 (0.1253)	Prec@1 94.531 (95.831)
Epoch: [17][150/391]	Time 2.028 (2.077)	Data 0.000 (0.050)	Loss 0.1019 (0.1257)	Prec@1 96.875 (95.923)
Epoch: [17][200/391]	Time 2.021 (2.064)	Data 0.000 (0.038)	Loss 0.0998 (0.1229)	Prec@1 95.312 (95.997)
Epoch: [17][250/391]	Time 2.034 (2.057)	Data 0.000 (0.030)	Loss 0.1288 (0.1243)	Prec@1 96.094 (95.966)
Epoch: [17][300/391]	Time 2.032 (2.052)	Data 0.000 (0.025)	Loss 0.1587 (0.1255)	Prec@1 93.750 (95.935)
Epoch: [17][350/391]	Time 2.033 (2.049)	Data 0.000 (0.022)	Loss 0.1395 (0.1253)	Prec@1 95.312 (95.909)
Test: [0/79]	Time 6.030 (6.030)	Loss 0.1560 (

Test: [50/79]	Time 0.335 (0.441)	Loss 0.3334 (0.2836)	Prec@1 89.062 (91.222)
 * Prec@1 91.250
 * Timec@1 0.399
best precision:  91.25
Epoch: [25][0/391]	Time 9.861 (9.861)	Data 7.757 (7.757)	Loss 0.1164 (0.1164)	Prec@1 95.312 (95.312)
Epoch: [25][50/391]	Time 2.028 (2.180)	Data 0.000 (0.152)	Loss 0.0667 (0.1271)	Prec@1 100.000 (95.787)
Epoch: [25][100/391]	Time 2.023 (2.105)	Data 0.000 (0.077)	Loss 0.1164 (0.1218)	Prec@1 94.531 (95.908)
Epoch: [25][150/391]	Time 2.028 (2.080)	Data 0.000 (0.052)	Loss 0.1430 (0.1197)	Prec@1 93.750 (95.975)
Epoch: [25][200/391]	Time 2.024 (2.067)	Data 0.000 (0.039)	Loss 0.0948 (0.1212)	Prec@1 96.094 (95.903)
Epoch: [25][250/391]	Time 2.029 (2.060)	Data 0.000 (0.031)	Loss 0.1196 (0.1204)	Prec@1 95.312 (95.885)
Epoch: [25][300/391]	Time 2.030 (2.055)	Data 0.000 (0.026)	Loss 0.1998 (0.1198)	Prec@1 93.750 (95.909)
Epoch: [25][350/391]	Time 2.027 (2.052)	Data 0.000 (0.022)	Loss 0.1443 (0.1195)	Prec@1 94.531 (95.951)
Test: [0/79]	Time 6.035 (6.035)	Loss 0.1467 

Test: [50/79]	Time 0.332 (0.441)	Loss 0.3383 (0.2819)	Prec@1 89.844 (91.176)
 * Prec@1 91.190
 * Timec@1 0.398
best precision:  91.42
Epoch: [33][0/391]	Time 9.622 (9.622)	Data 7.546 (7.546)	Loss 0.1651 (0.1651)	Prec@1 95.312 (95.312)
Epoch: [33][50/391]	Time 2.010 (2.174)	Data 0.000 (0.148)	Loss 0.0862 (0.1171)	Prec@1 97.656 (95.987)
Epoch: [33][100/391]	Time 2.021 (2.101)	Data 0.000 (0.075)	Loss 0.0937 (0.1132)	Prec@1 97.656 (96.194)
Epoch: [33][150/391]	Time 2.030 (2.076)	Data 0.000 (0.050)	Loss 0.1655 (0.1162)	Prec@1 92.969 (96.145)
Epoch: [33][200/391]	Time 2.012 (2.064)	Data 0.000 (0.038)	Loss 0.0865 (0.1172)	Prec@1 96.094 (96.125)
Epoch: [33][250/391]	Time 2.023 (2.057)	Data 0.000 (0.030)	Loss 0.0752 (0.1169)	Prec@1 97.656 (96.122)
Epoch: [33][300/391]	Time 2.021 (2.051)	Data 0.000 (0.025)	Loss 0.1625 (0.1171)	Prec@1 94.531 (96.143)
Epoch: [33][350/391]	Time 2.029 (2.048)	Data 0.000 (0.022)	Loss 0.0591 (0.1161)	Prec@1 99.219 (96.169)
Test: [0/79]	Time 6.062 (6.062)	Loss 0.1377 (

Test: [50/79]	Time 0.479 (0.585)	Loss 0.3490 (0.2786)	Prec@1 89.062 (91.131)
 * Prec@1 91.160
 * Timec@1 0.535
best precision:  91.49
Epoch: [41][0/391]	Time 11.868 (11.868)	Data 9.155 (9.155)	Loss 0.0838 (0.0838)	Prec@1 97.656 (97.656)
Epoch: [41][50/391]	Time 2.708 (2.826)	Data 0.000 (0.180)	Loss 0.1721 (0.1128)	Prec@1 95.312 (96.431)
Epoch: [41][100/391]	Time 2.606 (2.731)	Data 0.000 (0.091)	Loss 0.0992 (0.1133)	Prec@1 96.094 (96.357)
Epoch: [41][150/391]	Time 2.776 (2.705)	Data 0.001 (0.061)	Loss 0.1313 (0.1112)	Prec@1 94.531 (96.352)
Epoch: [41][200/391]	Time 2.597 (2.694)	Data 0.000 (0.046)	Loss 0.0989 (0.1106)	Prec@1 96.875 (96.432)
Epoch: [41][250/391]	Time 1.969 (2.668)	Data 0.001 (0.037)	Loss 0.1208 (0.1098)	Prec@1 96.875 (96.464)
Epoch: [41][300/391]	Time 2.897 (2.655)	Data 0.001 (0.031)	Loss 0.1745 (0.1101)	Prec@1 93.750 (96.442)
Epoch: [41][350/391]	Time 2.641 (2.657)	Data 0.001 (0.027)	Loss 0.1039 (0.1106)	Prec@1 96.875 (96.423)
Test: [0/79]	Time 7.561 (7.561)	Loss 0.1380

Test: [0/79]	Time 8.375 (8.375)	Loss 0.1370 (0.1370)	Prec@1 94.531 (94.531)
Test: [50/79]	Time 0.462 (0.613)	Loss 0.3376 (0.2745)	Prec@1 90.625 (91.452)
 * Prec@1 91.550
 * Timec@1 0.555
best precision:  91.55
Epoch: [49][0/391]	Time 16.061 (16.061)	Data 11.727 (11.727)	Loss 0.1170 (0.1170)	Prec@1 95.312 (95.312)
Epoch: [49][50/391]	Time 2.757 (3.229)	Data 0.000 (0.231)	Loss 0.1123 (0.1123)	Prec@1 96.094 (96.400)
Epoch: [49][100/391]	Time 2.835 (3.030)	Data 0.000 (0.117)	Loss 0.0833 (0.1076)	Prec@1 96.875 (96.473)
Epoch: [49][150/391]	Time 2.852 (2.966)	Data 0.000 (0.078)	Loss 0.0699 (0.1085)	Prec@1 96.875 (96.394)
Epoch: [49][200/391]	Time 3.926 (2.887)	Data 0.002 (0.059)	Loss 0.1385 (0.1072)	Prec@1 96.875 (96.506)
Epoch: [49][250/391]	Time 2.991 (2.897)	Data 0.002 (0.047)	Loss 0.0841 (0.1063)	Prec@1 96.875 (96.526)
Epoch: [49][300/391]	Time 2.754 (2.903)	Data 0.000 (0.040)	Loss 0.0684 (0.1060)	Prec@1 98.438 (96.538)
Epoch: [49][350/391]	Time 2.887 (2.900)	Data 0.002 (0.034)	Loss 0.10

![final](stage_5/final.png)

You can run ``stage_5/check.py`` to inspect the model structure and the test-set accuracy of ``decompose/best_final.th`` (this is quick).

In [1]:
%run stage_5/check.py

+------------------------------+------------+
|           Modules            | Parameters |
+------------------------------+------------+
|     module.conv1.weight      |    432     |
|      module.bn1.weight       |     16     |
|       module.bn1.bias        |     16     |
| module.layer1.0.conv1.weight |    2304    |
|  module.layer1.0.bn1.weight  |     16     |
|   module.layer1.0.bn1.bias   |     16     |
| module.layer1.0.conv2.weight |    2304    |
|  module.layer1.0.bn2.weight  |     16     |
|   module.layer1.0.bn2.bias   |     16     |
| module.layer1.1.conv1.weight |    2304    |
|  module.layer1.1.bn1.weight  |     16     |
|   module.layer1.1.bn1.bias   |     16     |
| module.layer1.1.conv2.weight |    2304    |
|  module.layer1.1.bn2.weight  |     16     |
|   module.layer1.1.bn2.bias   |     16     |
| module.layer1.2.conv1.weight |    2304    |
|  module.layer1.2.bn1.weight  |     16     |
|   module.layer1.2.bn1.bias   |     16     |
| module.layer1.2.conv2.weight |  

+------------------------------+------------+
|           Modules            | Parameters |
+------------------------------+------------+
|     module.conv1.weight      |    432     |
|      module.bn1.weight       |     16     |
|       module.bn1.bias        |     16     |
| module.layer1.0.conv1.weight |    2304    |
|  module.layer1.0.bn1.weight  |     16     |
|   module.layer1.0.bn1.bias   |     16     |
| module.layer1.0.conv2.weight |    2304    |
|  module.layer1.0.bn2.weight  |     16     |
|   module.layer1.0.bn2.bias   |     16     |
| module.layer1.1.conv1.weight |    2304    |
|  module.layer1.1.bn1.weight  |     16     |
|   module.layer1.1.bn1.bias   |     16     |
| module.layer1.1.conv2.weight |   20736    |
|  module.layer1.1.bn2.weight  |     16     |
|   module.layer1.1.bn2.bias   |     16     |
| module.layer1.2.conv1.weight |    2304    |
|  module.layer1.2.bn1.weight  |     16     |
|   module.layer1.2.bn1.bias   |     16     |
| module.layer1.2.conv2.weight |  