I tried to change layers in trained network with lightweight layers.

### Converting output
Find eigenvalues of covariance matrix, count how many of them are not less than, for example, 0.001 of the biggest eigenvalue.
So instead of ConvWide(x) we will have UpCh(DownCh(ConvWide(x))), where 'DownCh' decreases channels count and B increases back.
After that I combine DownCh(ConvWide(x)), it's a linear operation, which could be recalculated to ConvNotSoWide(x)
So result will be UpCh(ConvNotSoWide(x))

## Results:

Results are quite good. You could shrink wide model to more narrow architecture. Recomputing for weights is just working.
Accuracy is the same as for narrow model trained from scratch.

### Converting input: not implemented yet
Same as previous, but for layer inputs.

So ConvWide(x) could be replaced as ConvWide(UpCh(DownCh(x))) and recalculated as ConvNotSoWide(Down(Ch))

In terms of pytorch idea 2 and 3 leads to the next replacement for `WideConv`:

```Python
nn.Sequential(
    nn.Conv(in_ch, in_less_ch, kernel_size = 1, bias = False),
    nn.Conv(in_less_ch, out_less_ch, kernel_size = 3, bias = False),
    nn.Conv(out_less_ch, out_ch, kernel_size = 1, bias = True),
)
```

In [1]:
import time
from typing import Dict, Tuple, List, Optional
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from myutil import CovarianceAccumulator

In [2]:
torch.cuda.is_available()

True

In [3]:
torch.__version__

'1.13.1+cu117'

In [4]:
from torchvision import datasets
from torchvision.transforms import ToTensor

train_data = datasets.MNIST(
    root='../models/mnist',
    train=True,
    transform=ToTensor(),
    download=True,
)

test_data = datasets.MNIST(
    root='../models/mnist',
    train=False,
    transform=ToTensor(),
    download=True,
)

In [5]:
class NpAccumulator:
    def __init__(self):
        self.arrays: List[np.ndarray] = []

    def add(self, tensor: torch.Tensor):
        self.arrays.append(tensor.cpu().detach().numpy())

    @property
    def np_arr(self) -> np.ndarray:
        return np.concatenate(self.arrays, axis=0)

In [6]:
class TrainHelper:
    @staticmethod
    def train(cnn: nn.Module,
              *,
              lr: float = 0.001,
              epochs: int,
              train_dataset: datasets.MNIST,
              test_dataset: Optional[datasets.MNIST] = None,
              print_results: bool = True,
              batch_size: int,
              device_name: str = 'cuda') -> List[float]:

        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=1)

        device = torch.device(device_name)

        cnn.to(device)
        cnn.train()

        optimizer = torch.optim.Adam(cnn.parameters(), lr=lr)
        loss_func = nn.CrossEntropyLoss()

        eval_results: List[float] = []

        for epoch in range(epochs):
            for images, labels in train_loader:
                images = Variable(images.to(device))
                labels = Variable(labels.to(device))

                output = cnn(images)
                loss = loss_func(output, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if test_dataset is not None:
                eval_result = TrainHelper.test(cnn, test_dataset, device)
                eval_results.append(eval_result)
                if print_results:
                    print(f"epoch {epoch}, accuracy = {eval_result}, loss = {loss.detach()}")
                cnn.train()

        return eval_results

    @staticmethod
    def test(cnn: nn.Module, test_dataset: datasets.MNIST, device=None) -> float:
        cnn.eval()
        loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=1)
        correct = 0
        incorrect = 0

        for images, labels in loader:
            if device is not None:
                images = images.to(device)

            results = cnn(images)
            predictions = results.detach().cpu().numpy().argmax(axis=1)
            oks = (predictions == labels.numpy()).sum()
            correct += oks
            incorrect += len(predictions) - oks

        return correct / (correct + incorrect)

    @staticmethod
    def train_models(models: List[nn.Module], device_name: str) -> Tuple[int, float]:
        """
        generator yields pair (trainable parameters count, best accuracy) for each network
        :param device_name: 'cuda' or 'cpu'
        """
        assert len(models) > 0

        for model in models:
            start = time.time()
            eval_results = TrainHelper.train(
                cnn=model,
                epochs=20,
                train_dataset=train_data,
                test_dataset=test_data,
                batch_size=2048,
                device_name=device_name,
                print_results=False
            )
            end = time.time()
            best_acc = max(eval_results)
            params_count = TrainHelper.total_parameters_count(model)
            print(f"best accuracy = {best_acc}, parameters = {params_count}, training time = {end - start}")
            yield params_count, best_acc

    @staticmethod
    def total_parameters_count(model: nn.Module) -> int:
        return sum(np.prod(p.size()) for p in model.parameters())

    @staticmethod
    def print_parameters(model: nn.Module):
        print(f"total parameters = {TrainHelper.total_parameters_count(model)}")
        for p in model.parameters():
            print(f"size {np.prod(p.size())}: {p.size()}")

    @staticmethod
    def eval_layer(cnn: nn.Module, x: np.ndarray, batch_size: int) -> np.ndarray:
        acc = NpAccumulator()
        for tensor in TrainHelper.cuda_tensors_from_numpy(x, batch_size):
            acc.add(cnn(tensor))
        return acc.np_arr
    
    @staticmethod
    def compare_layers(layer1: nn.Module, layer2: nn.Module, x: np.ndarray, batch_size: int) -> float:
        y1 = TrainHelper.eval_layer(layer1, x, batch_size)
        y2 = TrainHelper.eval_layer(layer2, x, batch_size)
        return ((y1 - y2) ** 2).mean()

    @staticmethod
    def cuda_tensors_from_numpy(arr: np.ndarray, batch_size: int):
        for i in range(arr.shape[0] // batch_size):
            yield torch.from_numpy(arr[i * batch_size: (i + 1) * batch_size]).to('cuda')

In [7]:
class MyParallelLayer(nn.Module):
    def __init__(self, real_layer: nn.Module):
        super().__init__()
        self.use_real_layer: bool = True
        self.real_layer: nn.Module = real_layer
        self.mirror_layer: Optional[nn.Module] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_real_layer:
            return self.real_layer(x)
        else:
            return self.mirror_layer(x)

In [8]:
class MyConvModel(nn.Module):
    def __init__(self, channels: int):
        super(MyConvModel, self).__init__()

        c = channels
        self.layers = nn.Sequential(
            *self.conv(1, c, kernel_size=3),  # 28 - 26
            *self.conv(c, c, kernel_size=3),  # 26 - 24
            nn.MaxPool2d(2),  # 24 - 12

            *self.conv(c, c * 2, kernel_size=3),  # 12 - 10
            *self.conv(c * 2, c * 2, kernel_size=3),  # 10 - 8
            nn.MaxPool2d(2),  # 8 - 4

            *self.conv(c * 2, c * 4, kernel_size=3),  # 4 - 2
            *self.conv(c * 4, c * 4, kernel_size=2),  # 2 - 1

            nn.Conv2d(c * 4, 10, kernel_size=1, padding='valid', bias=True),
            nn.Flatten(),
        )

    def conv(self, in_ch: int, out_ch: int, *, kernel_size) -> List[nn.Module]:
        return [
            MyParallelLayer(nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding='valid', bias=False),
                nn.BatchNorm2d(out_ch),
            )),
            nn.LeakyReLU(0.1),
        ]

    @property
    def paraller_layers(self) -> List[MyParallelLayer]:
        return [layer for layer in self.layers if isinstance(layer, MyParallelLayer)]

    @property
    def alt_losses(self) -> List[float]:
        return [layer.last_loss for layer in self.layers if isinstance(layer, MyParallelLayer)]

    def forward(self, x: torch.Tensor):
        return self.layers(x)

In [9]:
model = MyConvModel(32).to('cuda')

In [10]:
TrainHelper.train(
    model,
    epochs=5,
    train_dataset=train_data,
    test_dataset=test_data,
    batch_size=2048,
    print_results=True,
)

epoch 0, accuracy = 0.6094, loss = 0.17913737893104553
epoch 1, accuracy = 0.9802, loss = 0.06752708554267883
epoch 2, accuracy = 0.9902, loss = 0.03562994673848152
epoch 3, accuracy = 0.9898, loss = 0.041804224252700806
epoch 4, accuracy = 0.9923, loss = 0.014046435244381428


[0.6094, 0.9802, 0.9902, 0.9898, 0.9923]

In [11]:
TrainHelper.train(
    model,
    lr=0.0001,
    epochs=5,
    train_dataset=train_data,
    test_dataset=test_data,
    batch_size=2048,
    print_results=True,
)

epoch 0, accuracy = 0.9937, loss = 0.013888739049434662
epoch 1, accuracy = 0.9942, loss = 0.009143810719251633
epoch 2, accuracy = 0.9943, loss = 0.011036340147256851
epoch 3, accuracy = 0.9941, loss = 0.005964362993836403
epoch 4, accuracy = 0.9948, loss = 0.005565876606851816


[0.9937, 0.9942, 0.9943, 0.9941, 0.9948]

In [12]:
TrainHelper.test(model, test_data, device='cuda')

0.9948

In [13]:
model

MyConvModel(
  (layers): Sequential(
    (0): MyParallelLayer(
      (real_layer): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=valid, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): LeakyReLU(negative_slope=0.1)
    (2): MyParallelLayer(
      (real_layer): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=valid, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): LeakyReLU(negative_slope=0.1)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): MyParallelLayer(
      (real_layer): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=valid, bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): LeakyReLU(negative_slope=0.1

In [14]:
def get_x():
    for x, loader in torch.utils.data.DataLoader(train_data, batch_size=60000):
        return x.numpy()

@dataclass
class InOut:
    layer: MyParallelLayer
    inp: np.ndarray
    out: np.ndarray
    inp_cov: CovarianceAccumulator
    out_cov: CovarianceAccumulator
    
    @staticmethod
    def calc() -> Dict[MyParallelLayer, 'InOut']:
        x = get_x()
        model.eval()
        
        result = {}
        with torch.no_grad():
            for layer in model.layers:
                x_next = TrainHelper.eval_layer(layer, x, batch_size=6000)
                if isinstance(layer, MyParallelLayer):
                    result[layer] = InOut(
                        layer, x, x_next, 
                        inp_cov=CovarianceAccumulator().add_samples(x, axis=1),
                        out_cov=CovarianceAccumulator().add_samples(x_next, axis=1),
                    )
                    print(f"{x.shape} -> {x_next.shape}")
                x = x_next
        
        model.train()
        return result


inouts = InOut.calc()    

(60000, 1, 28, 28) -> (60000, 32, 26, 26)
(60000, 32, 26, 26) -> (60000, 32, 24, 24)
(60000, 32, 12, 12) -> (60000, 64, 10, 10)
(60000, 64, 10, 10) -> (60000, 64, 8, 8)
(60000, 64, 4, 4) -> (60000, 128, 2, 2)
(60000, 128, 2, 2) -> (60000, 128, 1, 1)


In [20]:
def combine_conv_with_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    with torch.no_grad():
        result = nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=(conv.weight.size(2), conv.weight.size(3)),
            bias=True,
        )
        
        # var_sqrt = (1e-10 + bn.running_var.detach().sqrt())
        var_sqrt = bn.running_var.detach().sqrt()
        result.weight = nn.Parameter(
            conv.weight.detach() * (bn.weight.detach() / var_sqrt)[:, None, None, None]
        )
        result.bias = nn.Parameter(
            (bn.bias.detach() - bn.running_mean.detach() * bn.weight.detach() / var_sqrt)
        )
        
        return result

In [21]:
def combine_convs(conv0: nn.Conv2d, conv1: nn.Conv2d) -> nn.Conv2d:
    assert conv0.bias is None, 'not supported yet'
    w0 = conv0.weight.cpu().detach()
    w1 = conv1.weight.cpu().detach()
    assert w0.size(2) == 1 or w1.size(2) == 1, 'no more than one non point-wise convolution'
    assert w0.size(3) == 1 or w1.size(3) == 1, 'no more than one non point-wise convolution'
    with torch.no_grad():
        w01 = torch.tensordot(w0, w1, dims=[[0], [1]])
        w01 = torch.moveaxis(w01, 3, 0)
        if w01.size(5) != 1:
            w01 = torch.swapaxes(w01, 3, 5)
        if w01.size(4) != 1:
            w01 = torch.swapaxes(w01, 2, 4)
        w01 = torch.reshape(w01, shape=[w01.size(i) for i in range(4)])

    bias = conv1.bias 
    conv01 = nn.Conv2d(
        in_channels=conv0.in_channels, 
        out_channels=conv1.out_channels,
        kernel_size=(w01.size(2), w01.size(3)),
        bias = bias is not None
    )
    
    conv01.weight = nn.Parameter(w01.detach())
    if bias is not None:
        conv01.bias = nn.Parameter(bias.detach())
        
    return conv01

In [48]:
def make_parallel_layer(layer: MyParallelLayer, mid_ch: int):
    inout = inouts[layer]
    real_as_conv: nn.Conv2d = combine_conv_with_bn(conv=layer.real_layer[0], bn=layer.real_layer[1])
    conv_wide = nn.Conv2d(real_as_conv.in_channels, real_as_conv.out_channels, kernel_size=3, bias=False)
    conv_downch = nn.Conv2d(real_as_conv.out_channels, mid_ch, kernel_size=1, bias=False)
    conv_upch = nn.Conv2d(9, real_as_conv.out_channels, kernel_size=1, bias=True)

    shift, m_to, m_back = inout.out_cov.to_eigenvalues_and_back(mid_ch)

    conv_wide.weight = nn.Parameter(real_as_conv.weight.detach())

    conv_downch.weight = nn.Parameter(torch.from_numpy(m_to.astype(np.float32).T[:, :, np.newaxis, np.newaxis]))

    conv_upch.weight = nn.Parameter(torch.from_numpy(m_back.astype(np.float32).T[:, :, np.newaxis, np.newaxis]))
    shift_after = shift - shift @ m_to @ m_back + real_as_conv.bias.cpu().detach().numpy() @ m_to @ m_back
    conv_upch.bias = nn.Parameter(torch.from_numpy(shift_after.astype(np.float32)))
    
    layer.mirror_layer = nn.Sequential(
        combine_convs(conv_wide, conv_downch),
        conv_upch
    ).to('cuda')
    
    model.eval()
    diff = TrainHelper.compare_layers(layer.real_layer, layer.mirror_layer, inout.inp, batch_size=1000)
    print(f"diff = {diff}")

In [49]:
def shrink_layer(layer_no: int, mid_layers: int) -> np.ndarray:
    eig = inouts[model.paraller_layers[layer_no]].out_cov.eigenvalues_normalized
    make_parallel_layer(model.paraller_layers[layer_no], mid_layers)
    print(f"choosen eigenvalues weight = {np.sum(eig[0:mid_layers])}")
    return np.stack([eig, np.cumsum(eig)], axis=0)

In [50]:
shrink_layer(layer_no=0, mid_layers=9)

diff = 9.49143128536889e-08
choosen eigenvalues weight = 1.0000000000001272


array([[ 4.29704253e-01,  2.68115971e-01,  1.24157649e-01,
         8.04575583e-02,  4.39040646e-02,  2.95081168e-02,
         1.19096976e-02,  1.01697522e-02,  2.07293804e-03,
         1.22863553e-13,  9.97315915e-14,  8.51974704e-14,
         7.50987748e-14,  5.55099895e-14,  4.94501527e-14,
         3.97690044e-14,  3.46198724e-14,  2.49301814e-14,
         1.17608943e-14, -1.50150390e-15, -1.45516711e-14,
        -1.84393761e-14, -2.25469884e-14, -3.53238128e-14,
        -4.45581167e-14, -5.31659117e-14, -6.50267717e-14,
        -7.99860865e-14, -8.36323688e-14, -8.82918775e-14,
        -1.05560026e-13, -1.13744656e-13],
       [ 4.29704253e-01,  6.97820224e-01,  8.21977872e-01,
         9.02435431e-01,  9.46339495e-01,  9.75847612e-01,
         9.87757310e-01,  9.97927062e-01,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00

In [54]:
shrink_layer(layer_no=1, mid_layers=24)

diff = 0.005448869429528713
choosen eigenvalues weight = 0.994574078583633


array([[2.81859026e-01, 1.98255604e-01, 1.51850699e-01, 7.84284897e-02,
        6.57679001e-02, 5.68234519e-02, 3.15657182e-02, 2.04241435e-02,
        1.83555612e-02, 1.42297051e-02, 1.35056943e-02, 9.21466554e-03,
        9.20421064e-03, 8.24164731e-03, 6.68852106e-03, 5.77877249e-03,
        5.47644385e-03, 4.15019355e-03, 3.73452905e-03, 3.07917536e-03,
        2.65783320e-03, 2.08764606e-03, 1.75591597e-03, 1.43853129e-03,
        1.25296900e-03, 9.36114113e-04, 7.75216381e-04, 6.38272178e-04,
        5.82369053e-04, 4.95988521e-04, 4.51396236e-04, 2.93595937e-04],
       [2.81859026e-01, 4.80114630e-01, 6.31965329e-01, 7.10393819e-01,
        7.76161719e-01, 8.32985171e-01, 8.64550889e-01, 8.84975033e-01,
        9.03330594e-01, 9.17560299e-01, 9.31065993e-01, 9.40280659e-01,
        9.49484869e-01, 9.57726517e-01, 9.64415038e-01, 9.70193810e-01,
        9.75670254e-01, 9.79820448e-01, 9.83554977e-01, 9.86634152e-01,
        9.89291985e-01, 9.91379631e-01, 9.93135547e-01, 9.94574

In [53]:
shrink_layer(layer_no=2, mid_layers=48)

diff = 0.009351296350359917
choosen eigenvalues weight = 0.9905895888614311


array([[1.84476935e-01, 1.39537819e-01, 1.22777491e-01, 8.16121902e-02,
        6.72731105e-02, 6.35277026e-02, 5.20721369e-02, 3.43373185e-02,
        3.14075872e-02, 2.28305958e-02, 2.08568791e-02, 1.69634740e-02,
        1.54627651e-02, 1.28149683e-02, 1.11716986e-02, 1.02963746e-02,
        9.22164570e-03, 8.50079391e-03, 7.48350372e-03, 6.28762622e-03,
        6.10840598e-03, 5.62741618e-03, 4.89596679e-03, 4.22434669e-03,
        4.13916661e-03, 4.13396309e-03, 3.58474965e-03, 3.51951168e-03,
        3.23601711e-03, 2.95594248e-03, 2.68639006e-03, 2.43302794e-03,
        2.26231306e-03, 2.15242611e-03, 2.02204027e-03, 1.94405962e-03,
        1.82328124e-03, 1.64219316e-03, 1.60803416e-03, 1.47816214e-03,
        1.39105440e-03, 1.34731490e-03, 1.27644017e-03, 1.23373134e-03,
        1.07354374e-03, 1.01312354e-03, 9.50283918e-04, 9.14066685e-04,
        9.00739069e-04, 8.49435170e-04, 7.79678314e-04, 7.67677204e-04,
        7.06470189e-04, 6.73503554e-04, 6.15284598e-04, 5.847901

In [55]:
shrink_layer(layer_no=3, mid_layers=48)

diff = 0.014899151399731636
choosen eigenvalues weight = 0.9847368993848389


array([[1.47121074e-01, 1.16047351e-01, 1.08268925e-01, 1.05549985e-01,
        8.42595544e-02, 5.70158770e-02, 5.16015245e-02, 4.64819724e-02,
        3.33840008e-02, 2.50069074e-02, 2.13963120e-02, 1.83027418e-02,
        1.55315358e-02, 1.31018583e-02, 1.14665500e-02, 1.06392946e-02,
        8.61081684e-03, 8.18654076e-03, 7.99503403e-03, 7.61446938e-03,
        6.61265831e-03, 6.37980679e-03, 5.26401955e-03, 4.86386753e-03,
        4.83957078e-03, 4.44991545e-03, 4.17544895e-03, 4.04689046e-03,
        3.76129715e-03, 3.38518269e-03, 3.35744393e-03, 3.08733753e-03,
        2.94985462e-03, 2.85883876e-03, 2.57259405e-03, 2.45338984e-03,
        2.39973318e-03, 2.29493874e-03, 2.19892180e-03, 2.02775154e-03,
        1.90702084e-03, 1.84602403e-03, 1.76521028e-03, 1.66784037e-03,
        1.57697698e-03, 1.51855034e-03, 1.46728612e-03, 1.42620402e-03,
        1.36307912e-03, 1.26768920e-03, 1.22798461e-03, 1.21097098e-03,
        1.12662137e-03, 1.09652334e-03, 9.89836398e-04, 9.221821

In [57]:
shrink_layer(layer_no=4, mid_layers=96)

diff = 0.008490923792123795
choosen eigenvalues weight = 0.9911743401075941


array([[1.84385464e-01, 1.55088939e-01, 8.71673906e-02, 8.05175919e-02,
        6.88854570e-02, 4.93061176e-02, 3.96039074e-02, 3.24921957e-02,
        3.08897446e-02, 2.52436377e-02, 1.99384378e-02, 1.66164269e-02,
        1.56344948e-02, 1.24548504e-02, 1.17167563e-02, 1.11647709e-02,
        1.01345509e-02, 9.44358734e-03, 7.21642338e-03, 6.43273699e-03,
        6.05714910e-03, 5.75158883e-03, 5.41072922e-03, 4.74186499e-03,
        4.49168525e-03, 4.27619980e-03, 3.83842219e-03, 3.65897138e-03,
        3.42196823e-03, 3.26155958e-03, 3.16222040e-03, 2.90211016e-03,
        2.66146851e-03, 2.45721524e-03, 2.35589572e-03, 2.28768853e-03,
        2.06125825e-03, 1.91944533e-03, 1.90335175e-03, 1.82857493e-03,
        1.79258275e-03, 1.71763161e-03, 1.64722940e-03, 1.61676516e-03,
        1.58946988e-03, 1.48425106e-03, 1.40671482e-03, 1.35445676e-03,
        1.33558048e-03, 1.28045860e-03, 1.27585659e-03, 1.18238436e-03,
        1.15087772e-03, 1.08376979e-03, 1.07774121e-03, 1.055075

In [58]:
shrink_layer(layer_no=5, mid_layers=96)

diff = 0.0050513772293925285
choosen eigenvalues weight = 0.9955017697229525


array([[1.36571249e-01, 1.29763431e-01, 1.13380778e-01, 1.06225043e-01,
        1.04733834e-01, 9.88330742e-02, 8.68179328e-02, 8.23750389e-02,
        7.25473235e-02, 8.20348927e-03, 1.93513623e-03, 1.84405279e-03,
        1.68486348e-03, 1.56738244e-03, 1.55953029e-03, 1.43327481e-03,
        1.35337515e-03, 1.33341209e-03, 1.26607967e-03, 1.25320561e-03,
        1.22448376e-03, 1.16216758e-03, 1.14164923e-03, 1.09310273e-03,
        1.06885901e-03, 1.02958300e-03, 9.76709062e-04, 9.40946720e-04,
        9.28741242e-04, 9.02838013e-04, 8.78413417e-04, 8.64199838e-04,
        8.31411977e-04, 7.97050214e-04, 7.86972696e-04, 7.82043599e-04,
        7.40482828e-04, 7.27695212e-04, 7.21172425e-04, 6.95717507e-04,
        6.87917520e-04, 6.81386033e-04, 6.60601103e-04, 6.41899503e-04,
        6.31135369e-04, 6.16737145e-04, 6.06400919e-04, 5.94189938e-04,
        5.71543358e-04, 5.57672054e-04, 5.55188477e-04, 5.47778843e-04,
        5.28998454e-04, 5.21338721e-04, 5.08907344e-04, 4.976837

In [59]:
def test_layers(mirror_count: int) -> float:
    for i, layer in enumerate(model.paraller_layers):
        layer.use_real_layer = i >= mirror_count
    return TrainHelper.test(model, test_data, device='cuda')

In [60]:
[test_layers(i) for i in range(7)]

[0.9948, 0.9948, 0.995, 0.995, 0.9944, 0.9944, 0.9943]

## Compare accuracy
Create architecture of shrinked model and train it from scratch.

In [61]:
class MyShrinkedConvModel(nn.Module):
    def __init__(self):
        super().__init__()

        c = 32
        self.layers = nn.Sequential(
            *self.conv(1, 9, c, kernel_size=3),  # 28 - 26
            *self.conv(c, 24, c, kernel_size=3),  # 26 - 24
            nn.MaxPool2d(2),  # 24 - 12

            *self.conv(c, 48, c * 2, kernel_size=3),  # 12 - 10
            *self.conv(c * 2, 48, c * 2, kernel_size=3),  # 10 - 8
            nn.MaxPool2d(2),  # 8 - 4

            *self.conv(c * 2, 96, c * 4, kernel_size=3),  # 4 - 2
            *self.conv(c * 4, 96, c * 4, kernel_size=2),  # 2 - 1

            nn.Conv2d(c * 4, 10, kernel_size=1, padding='valid', bias=True),
            nn.Flatten(),
        )

    def conv(self, in_ch: int, mid_ch: int, out_ch: int, *, kernel_size) -> List[nn.Module]:
        return [
            MyParallelLayer(nn.Sequential(
                nn.Conv2d(in_ch, mid_ch, kernel_size=kernel_size, padding='valid', bias=False),
                nn.Conv2d(mid_ch, out_ch, kernel_size=1, padding='valid', bias=False),
                nn.BatchNorm2d(out_ch),
            )),
            nn.LeakyReLU(0.1),
        ]

    @property
    def paraller_layers(self) -> List[MyParallelLayer]:
        return [layer for layer in self.layers if isinstance(layer, MyParallelLayer)]

    @property
    def alt_losses(self) -> List[float]:
        return [layer.last_loss for layer in self.layers if isinstance(layer, MyParallelLayer)]

    def forward(self, x: torch.Tensor):
        return self.layers(x)


shrinked_model = MyShrinkedConvModel().to('cuda')
TrainHelper.train(
    shrinked_model,
    epochs=5,
    train_dataset=train_data,
    test_dataset=test_data,
    batch_size=2048,
    print_results=True,
)

epoch 0, accuracy = 0.1135, loss = 0.18532802164554596
epoch 1, accuracy = 0.9589, loss = 0.07008330523967743
epoch 2, accuracy = 0.99, loss = 0.04463108256459236
epoch 3, accuracy = 0.9885, loss = 0.03158889710903168
epoch 4, accuracy = 0.9907, loss = 0.0327920988202095


[0.1135, 0.9589, 0.99, 0.9885, 0.9907]

In [62]:
TrainHelper.train(
    shrinked_model,
    lr=0.0001,
    epochs=5,
    train_dataset=train_data,
    test_dataset=test_data,
    batch_size=2048,
    print_results=True,
)    

epoch 0, accuracy = 0.9927, loss = 0.01813686452805996
epoch 1, accuracy = 0.9932, loss = 0.020897528156638145
epoch 2, accuracy = 0.9933, loss = 0.014707138761878014
epoch 3, accuracy = 0.9943, loss = 0.011871189810335636
epoch 4, accuracy = 0.9944, loss = 0.01109230611473322


[0.9927, 0.9932, 0.9933, 0.9943, 0.9944]