I tried replacing layers of a trained network with lightweight equivalents.

### Converting output
Compute the eigenvalues of the covariance matrix of a layer's output channels and count how many of them are at least, for example, 0.001 times the largest eigenvalue.
So instead of ConvWide(x) we get UpCh(DownCh(ConvWide(x))), where 'DownCh' reduces the channel count and 'UpCh' restores it.
DownCh(ConvWide(x)) is a composition of linear operations (the bias can be moved into UpCh). It can therefore be folded into a single, narrower convolution ConvNotSoWide(x).
So the result is UpCh(ConvNotSoWide(x)).

### Converting input: not implemented yet
Same idea as above, but applied to the layer inputs.

So ConvWide(x) can be replaced with ConvWide(UpCh(DownCh(x))). Folding ConvWide and UpCh together turns this into ConvNotSoWide(DownCh(x)).

In PyTorch terms, combining both ideas leads to the following replacement for `ConvWide`:

```Python
nn.Sequential(
    nn.Conv2d(in_ch, in_less_ch, kernel_size=1, bias=False),
    nn.Conv2d(in_less_ch, out_less_ch, kernel_size=3, bias=False),
    nn.Conv2d(out_less_ch, out_ch, kernel_size=1, bias=True),
)
```

In [1]:
import itertools
import time
from typing import Dict, Optional, Tuple, List, Optional, Iterable
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from myutil import CovarianceAccumulator

In [2]:
# Sanity check: later cells move models and batches to 'cuda', so a GPU must be available.
torch.cuda.is_available()

True

In [3]:
# PyTorch version used to produce the recorded outputs in this notebook.
torch.__version__

'1.13.1+cu117'

In [4]:
from torchvision import datasets
from torchvision.transforms import ToTensor

def _load_mnist(is_train: bool) -> datasets.MNIST:
    """Load one MNIST split from ../models/mnist as tensors, downloading it if missing."""
    return datasets.MNIST(
        root='../models/mnist',
        train=is_train,
        transform=ToTensor(),
        download=True,
    )


train_data = _load_mnist(True)
test_data = _load_mnist(False)

In [5]:
class NpAccumulator:
    """Collect tensors as host numpy arrays and join them along the first axis on demand."""

    def __init__(self):
        # One numpy array per ``add`` call, in insertion order.
        self.arrays: List[np.ndarray] = []

    def add(self, tensor: torch.Tensor):
        """Detach ``tensor`` from autograd, copy it to the CPU and store it as a numpy array."""
        host_array = tensor.detach().cpu().numpy()
        self.arrays.append(host_array)

    @property
    def np_arr(self) -> np.ndarray:
        """Every stored array concatenated along axis 0."""
        collected = self.arrays
        return np.concatenate(collected, axis=0)

In [6]:
class TrainHelper:
    """Static helpers to train, evaluate and probe the models in this notebook."""

    @staticmethod
    def train(cnn: nn.Module,
              *,
              lr: float = 0.001,
              epochs: int,
              train_dataset: datasets.MNIST,
              test_dataset: Optional[datasets.MNIST] = None,
              print_results: bool = True,
              batch_size: int,
              device_name: str = 'cuda') -> List[float]:
        """
        Train ``cnn`` in place with Adam and cross-entropy loss.

        :param cnn: model to train; it is moved to ``device_name`` and left in train mode.
        :param lr: Adam learning rate.
        :param epochs: number of passes over ``train_dataset`` (reshuffled every epoch).
        :param train_dataset: training samples.
        :param test_dataset: when given, accuracy on it is measured after every epoch.
        :param print_results: print per-epoch accuracy and the loss of the last batch.
        :param batch_size: training batch size.
        :param device_name: 'cuda' or 'cpu'.
        :return: per-epoch test accuracies; empty when ``test_dataset`` is None.
        """
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   num_workers=1)

        device = torch.device(device_name)

        cnn.to(device)
        cnn.train()

        optimizer = torch.optim.Adam(cnn.parameters(), lr=lr)
        loss_func = nn.CrossEntropyLoss()

        eval_results: List[float] = []
        # Loss of the most recent batch; stays None if the loader yields no batches.
        loss: Optional[torch.Tensor] = None

        for epoch in range(epochs):
            for images, labels in train_loader:
                # Plain tensors suffice; ``Variable`` is a no-op wrapper in modern PyTorch.
                images = images.to(device)
                labels = labels.to(device)

                output = cnn(images)
                loss = loss_func(output, labels)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if test_dataset is not None:
                eval_result = TrainHelper.test(cnn, test_dataset, device)
                eval_results.append(eval_result)
                if print_results:
                    last_loss = None if loss is None else loss.item()
                    print(f"epoch {epoch}, accuracy = {eval_result}, loss = {last_loss}")
                # ``test`` switches the model to eval mode; resume training mode.
                cnn.train()

        return eval_results

    @staticmethod
    def test(cnn: nn.Module, test_dataset: datasets.MNIST, device=None) -> float:
        """
        Return the top-1 accuracy of ``cnn`` on ``test_dataset``.

        ``cnn`` is switched to eval mode and left there. Inference runs under
        ``torch.no_grad()`` so no autograd graph is built.

        :param device: where each input batch is moved; None leaves batches on the CPU.
        """
        cnn.eval()
        loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=1)
        correct = 0
        total = 0

        with torch.no_grad():
            for images, labels in loader:
                if device is not None:
                    images = images.to(device)

                predictions = cnn(images).cpu().numpy().argmax(axis=1)
                correct += int((predictions == labels.numpy()).sum())
                total += len(predictions)

        return correct / total

    @staticmethod
    def train_models(models: List[nn.Module], device_name: str) -> Iterable[Tuple[int, float]]:
        """
        Train each model for 20 epochs on the global ``train_data``/``test_data``.

        This is a generator: it yields ``(trainable parameters count, best accuracy)`` for
        each network after it has been trained, and prints the same plus the training time.

        :param device_name: 'cuda' or 'cpu'
        """
        assert len(models) > 0

        for model in models:
            start = time.time()
            eval_results = TrainHelper.train(
                cnn=model,
                epochs=20,
                train_dataset=train_data,
                test_dataset=test_data,
                batch_size=2048,
                device_name=device_name,
                print_results=False
            )
            end = time.time()
            best_acc = max(eval_results)
            params_count = TrainHelper.total_parameters_count(model)
            print(f"best accuracy = {best_acc}, parameters = {params_count}, training time = {end - start}")
            yield params_count, best_acc

    @staticmethod
    def total_parameters_count(model: nn.Module) -> int:
        """Total number of scalar values in all parameters of ``model``."""
        # numel() is a plain int and counts 0-dim parameters as 1.
        return sum(p.numel() for p in model.parameters())

    @staticmethod
    def print_parameters(model: nn.Module):
        """Print the total parameter count followed by the size and shape of every parameter."""
        print(f"total parameters = {TrainHelper.total_parameters_count(model)}")
        for p in model.parameters():
            print(f"size {p.numel()}: {p.size()}")

    @staticmethod
    def eval_layer(cnn: nn.Module, x: np.ndarray, batch_size: int) -> np.ndarray:
        """
        Run ``cnn`` over ``x`` in CUDA batches of ``batch_size`` rows.

        Returns the outputs concatenated along axis 0, one output row per input row.
        """
        acc = NpAccumulator()
        # Pure inference: avoid building autograd graphs for every batch.
        with torch.no_grad():
            for tensor in TrainHelper.cuda_tensors_from_numpy(x, batch_size):
                acc.add(cnn(tensor))
        return acc.np_arr

    @staticmethod
    def compare_layers(layer1: nn.Module, layer2: nn.Module, x: np.ndarray, batch_size: int) -> float:
        """Mean squared difference between the outputs of ``layer1`` and ``layer2`` on ``x``."""
        y1 = TrainHelper.eval_layer(layer1, x, batch_size)
        y2 = TrainHelper.eval_layer(layer2, x, batch_size)
        return ((y1 - y2) ** 2).mean()

    @staticmethod
    def cuda_tensors_from_numpy(arr: np.ndarray, batch_size: int) -> Iterable[torch.Tensor]:
        """
        Yield consecutive slices of ``arr`` along axis 0 as CUDA tensors.

        The final slice may hold fewer than ``batch_size`` rows, so every row of ``arr``
        is yielded exactly once.
        """
        for start in range(0, arr.shape[0], batch_size):
            yield torch.from_numpy(arr[start:start + batch_size]).to('cuda')

In [7]:
class MyParallelLayer(nn.Module):
    """
    Hold an original layer together with an optional replacement ("mirror") for it.

    ``use_real_layer`` selects which of the two runs in ``forward``. The replacement can
    therefore be switched in and out without rebuilding the surrounding model.
    """

    def __init__(self, real_layer: nn.Module):
        super().__init__()
        self.use_real_layer: bool = True
        self.real_layer: nn.Module = real_layer
        # Assigned later by the caller; None until a replacement has been built.
        self.mirror_layer: Optional[nn.Module] = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_real_layer:
            return self.real_layer(x)
        if self.mirror_layer is None:
            # Fail with a clear message instead of "'NoneType' object is not callable".
            raise RuntimeError("use_real_layer is False but no mirror_layer has been set")
        return self.mirror_layer(x)

In [8]:
class MyConvModel(nn.Module):
    """
    Small all-convolutional MNIST classifier built from 'valid' convolutions.

    Every convolution except the final 1x1 classifier is wrapped in a MyParallelLayer, so a
    lighter replacement can be swapped in later. Inline comments give the spatial size of
    each stage for a 28x28 input.
    """

    def __init__(self, channels: int):
        super().__init__()

        c = channels
        self.layers = nn.Sequential(
            *self.conv(1, c, kernel_size=3),  # 28 - 26
            *self.conv(c, c, kernel_size=3),  # 26 - 24
            nn.MaxPool2d(2),  # 24 - 12

            *self.conv(c, c * 2, kernel_size=3),  # 12 - 10
            *self.conv(c * 2, c * 2, kernel_size=3),  # 10 - 8
            nn.MaxPool2d(2),  # 8 - 4

            *self.conv(c * 2, c * 4, kernel_size=3),  # 4 - 2
            *self.conv(c * 4, c * 4, kernel_size=2),  # 2 - 1

            # 1x1 classifier with 10 output channels; Flatten turns the 1x1 map into logits.
            nn.Conv2d(c * 4, 10, kernel_size=1, padding='valid', bias=True),
            nn.Flatten(),
        )

    def conv(self, in_ch: int, out_ch: int, *, kernel_size) -> List[nn.Module]:
        """
        Return a swappable 'valid' convolution (with bias) followed by LeakyReLU(0.1).

        The Conv2d is element 0 of the wrapped Sequential.
        """
        return [
            MyParallelLayer(nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding='valid', bias=True),
                # Variant tried earlier: bias-free Conv2d followed by nn.BatchNorm2d(out_ch).
            )),
            nn.LeakyReLU(0.1),
        ]

    @property
    def parallel_layers(self) -> List[MyParallelLayer]:
        """All swappable layers, in forward order."""
        return [layer for layer in self.layers if isinstance(layer, MyParallelLayer)]

    @property
    def paraller_layers(self) -> List[MyParallelLayer]:
        """Misspelled alias of ``parallel_layers``, kept for existing callers."""
        return self.parallel_layers

    @property
    def alt_losses(self) -> List[Optional[float]]:
        """``last_loss`` of every parallel layer, or None for layers that have not recorded one."""
        # MyParallelLayer does not define ``last_loss`` itself, so plain attribute access
        # would raise AttributeError.
        return [getattr(layer, 'last_loss', None) for layer in self.parallel_layers]

    def forward(self, x: torch.Tensor):
        return self.layers(x)

In [11]:
# Base width 32: the conv stages use 32, 64 and 128 channels (see MyConvModel).
model = MyConvModel(32).to('cuda')

In [12]:
# First training pass: 5 epochs at the default learning rate, with test accuracy after each epoch.
TrainHelper.train(
    model,
    epochs=5,
    train_dataset=train_data,
    test_dataset=test_data,
    batch_size=2048,
    print_results=True,
)

epoch 0, accuracy = 0.863, loss = 0.5057899951934814
epoch 1, accuracy = 0.9319, loss = 0.2258572280406952
epoch 2, accuracy = 0.9606, loss = 0.14280276000499725
epoch 3, accuracy = 0.971, loss = 0.10730598866939545
epoch 4, accuracy = 0.9785, loss = 0.07931938022375107


[0.863, 0.9319, 0.9606, 0.971, 0.9785]

In [13]:
# Continue training the same model for 5 more epochs with a 10x smaller learning rate.
TrainHelper.train(
    model,
    lr=0.0001,
    epochs=5,
    train_dataset=train_data,
    test_dataset=test_data,
    batch_size=2048,
    print_results=True,
)

epoch 0, accuracy = 0.9815, loss = 0.04147940129041672
epoch 1, accuracy = 0.9827, loss = 0.060650743544101715
epoch 2, accuracy = 0.9828, loss = 0.05440560728311539
epoch 3, accuracy = 0.9838, loss = 0.05266561731696129
epoch 4, accuracy = 0.9833, loss = 0.05502895265817642


[0.9815, 0.9827, 0.9828, 0.9838, 0.9833]

In [14]:
# Baseline test accuracy of the trained model, before any layer is replaced.
TrainHelper.test(model, test_data, device='cuda')

0.9833

In [15]:
# Show the module structure.
model

MyConvModel(
  (layers): Sequential(
    (0): MyParallelLayer(
      (real_layer): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=valid)
      )
    )
    (1): LeakyReLU(negative_slope=0.1)
    (2): MyParallelLayer(
      (real_layer): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=valid)
      )
    )
    (3): LeakyReLU(negative_slope=0.1)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): MyParallelLayer(
      (real_layer): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=valid)
      )
    )
    (6): LeakyReLU(negative_slope=0.1)
    (7): MyParallelLayer(
      (real_layer): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=valid)
      )
    )
    (8): LeakyReLU(negative_slope=0.1)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): MyParallelLayer(
      (real_l

In [16]:
def get_x(dataset=None) -> np.ndarray:
    """
    Return every input of ``dataset`` as a single numpy array (labels are discarded).

    :param dataset: map-style dataset of ``(input, label)`` pairs; defaults to the global
        ``train_data``.
    :return: the inputs stacked along axis 0, in dataset order.
    :raises ValueError: if the dataset is empty.
    """
    if dataset is None:
        dataset = train_data
    if len(dataset) == 0:
        raise ValueError("dataset is empty")
    # One batch covering the whole dataset; no shuffling, so order matches the dataset.
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset))
    images, _labels = next(iter(loader))
    return images.numpy()

@dataclass
class InOut:
    """Input/output activations recorded around one MyParallelLayer, with covariance statistics."""
    layer: MyParallelLayer
    inp: np.ndarray  # layer input for every sample returned by get_x()
    out: np.ndarray  # layer output for ``inp``
    # Built via CovarianceAccumulator.add_samples(..., axis=1); presumably a per-channel
    # covariance for NCHW activations -- confirm against myutil.
    inp_cov: CovarianceAccumulator
    out_cov: CovarianceAccumulator

    @staticmethod
    def calc(net: Optional[nn.Module] = None, batch_size: int = 6000) -> Dict[MyParallelLayer, 'InOut']:
        """
        Push ``get_x()`` through ``net.layers`` one layer at a time and record every MyParallelLayer.

        Prints the input -> output shape of each recorded layer. The model's train/eval mode
        is restored afterwards, even if a layer raises.
        NOTE: every recorded input and output array is kept in memory.

        :param net: model whose ``layers`` are walked; defaults to the global ``model``.
        :param batch_size: rows per forward pass.
        :return: mapping from each MyParallelLayer to its recorded InOut.
        """
        if net is None:
            net = model
        was_training = net.training
        x = get_x()
        net.eval()

        result = {}
        try:
            with torch.no_grad():
                for layer in net.layers:
                    x_next = TrainHelper.eval_layer(layer, x, batch_size=batch_size)
                    if isinstance(layer, MyParallelLayer):
                        result[layer] = InOut(
                            layer, x, x_next,
                            inp_cov=CovarianceAccumulator().add_samples(x, axis=1),
                            out_cov=CovarianceAccumulator().add_samples(x_next, axis=1),
                        )
                        print(f"{x.shape} -> {x_next.shape}")
                    x = x_next
        finally:
            # Put the model back in whichever mode it was in before the call.
            net.train(was_training)
        return result


# Record activations and covariance statistics for every MyParallelLayer of ``model``.
# NOTE: this keeps the full input and output arrays of each such layer in memory.
inouts = InOut.calc()    

(60000, 1, 28, 28) -> (60000, 32, 26, 26)
(60000, 32, 26, 26) -> (60000, 32, 24, 24)
(60000, 32, 12, 12) -> (60000, 64, 10, 10)
(60000, 64, 10, 10) -> (60000, 64, 8, 8)
(60000, 64, 4, 4) -> (60000, 128, 2, 2)
(60000, 128, 2, 2) -> (60000, 128, 1, 1)


In [17]:
def combine_convs(conv0: nn.Conv2d, conv1: nn.Conv2d) -> nn.Conv2d:
    """
    Fold ``conv1(conv0(x))`` into one equivalent convolution.

    Requirements:
      * both convolutions are plain: stride 1, dilation 1, groups 1, no padding;
      * in each spatial dimension, at most one of the two kernels has size > 1;
      * ``conv0`` has no bias.
    ``conv1``'s bias, if any, is copied to the result.

    :return: a new ``nn.Conv2d`` on the CPU from ``conv0.in_channels`` to ``conv1.out_channels``.
    :raises AssertionError: if a requirement above is violated.
    """
    assert conv0.bias is None, 'not supported yet'
    for conv in (conv0, conv1):
        # The fused layer is built with default stride/dilation/padding, so anything else
        # would silently produce a different operator.
        assert conv.groups == 1, 'grouped convolutions are not supported'
        assert tuple(conv.stride) == (1, 1), 'only stride 1 is supported'
        assert tuple(conv.dilation) == (1, 1), 'only dilation 1 is supported'
        assert conv.padding in ('valid', (0, 0)), 'only unpadded convolutions are supported'
    assert conv1.in_channels == conv0.out_channels, 'channel count mismatch'

    w0 = conv0.weight.detach().cpu()
    w1 = conv1.weight.detach().cpu()
    assert w0.size(2) == 1 or w1.size(2) == 1, 'no more than one non point-wise convolution'
    assert w0.size(3) == 1 or w1.size(3) == 1, 'no more than one non point-wise convolution'
    with torch.no_grad():
        # Contract conv0's output channels with conv1's input channels:
        # (in0, kh0, kw0, out1, kh1, kw1)
        w01 = torch.tensordot(w0, w1, dims=[[0], [1]])
        # -> (out1, in0, kh0, kw0, kh1, kw1)
        w01 = torch.moveaxis(w01, 3, 0)
        # In each spatial dim, move the non-trivial kernel extent into axes 2/3; the asserts
        # above guarantee the other extent is 1.
        if w01.size(5) != 1:
            w01 = torch.swapaxes(w01, 3, 5)
        if w01.size(4) != 1:
            w01 = torch.swapaxes(w01, 2, 4)
        # Drop the two trailing size-1 axes.
        w01 = torch.reshape(w01, shape=[w01.size(i) for i in range(4)])

    bias = conv1.bias
    conv01 = nn.Conv2d(
        in_channels=conv0.in_channels,
        out_channels=conv1.out_channels,
        kernel_size=(w01.size(2), w01.size(3)),
        bias=bias is not None,
    )

    conv01.weight = nn.Parameter(w01.contiguous())
    if bias is not None:
        # Copy to the CPU like the weight, and clone so the new layer does not share
        # storage with conv1's bias.
        conv01.bias = nn.Parameter(bias.detach().cpu().clone())

    return conv01

In [18]:
def make_parallel_layer(layer: MyParallelLayer, mid_ch: int) -> float:
    """
    Build a low-rank replacement for ``layer`` and install it as ``layer.mirror_layer``.

    Steps:
      1. The real convolution (without bias) is followed by a 1x1 projection onto ``mid_ch``
         channels, taken from ``inouts[layer].out_cov.to_eigenvalues_and_back``.
      2. The convolution and the projection are folded into one narrower convolution.
      3. A 1x1 convolution maps back to the original channel count; it also carries the
         folded bias and mean shift.

    :param layer: a MyParallelLayer present in the global ``inouts``.
    :param mid_ch: number of intermediate channels to keep.
    :return: mean squared difference between the real and mirror layers on the recorded
        inputs (also printed).
    """
    inout = inouts[layer]
    real_conv: nn.Conv2d = inout.layer.real_layer[0]
    # Bias-free copy of the real convolution with the same geometry; its bias is folded
    # into conv_upch below.
    conv_wide = nn.Conv2d(
        real_conv.in_channels, real_conv.out_channels,
        kernel_size=real_conv.kernel_size,
        stride=real_conv.stride,
        padding=real_conv.padding,
        dilation=real_conv.dilation,
        groups=real_conv.groups,
        bias=False,
    )
    conv_downch = nn.Conv2d(real_conv.out_channels, mid_ch, kernel_size=1, bias=False)
    conv_upch = nn.Conv2d(mid_ch, real_conv.out_channels, kernel_size=1, bias=True)

    # NOTE(review): semantics are defined in myutil. From how they are used below, m_to
    # appears to be (out_channels, mid_ch) and m_back (mid_ch, out_channels) -- confirm.
    shift, m_to, m_back = inout.out_cov.to_eigenvalues_and_back(mid_ch)

    conv_wide.weight = nn.Parameter(real_conv.weight.detach())

    conv_downch.weight = nn.Parameter(torch.from_numpy(m_to.astype(np.float32).T[:, :, np.newaxis, np.newaxis]))

    conv_upch.weight = nn.Parameter(torch.from_numpy(m_back.astype(np.float32).T[:, :, np.newaxis, np.newaxis]))
    # Approximation: y ~ shift + (y - shift) @ m_to @ m_back, with y = conv(x) + bias.
    # The constant part of that expression becomes the bias of conv_upch.
    shift_after = shift - shift @ m_to @ m_back + real_conv.bias.cpu().detach().numpy() @ m_to @ m_back
    conv_upch.bias = nn.Parameter(torch.from_numpy(shift_after.astype(np.float32)))

    layer.mirror_layer = nn.Sequential(
        combine_convs(conv_wide, conv_downch),
        conv_upch
    ).to('cuda')

    model.eval()
    with torch.no_grad():
        diff = TrainHelper.compare_layers(layer.real_layer, layer.mirror_layer, inout.inp, batch_size=1000)
    print(f"diff = {diff}")
    return diff

In [21]:
def shrink_layer(layer_no: int, mid_layers: int) -> np.ndarray:
    """
    Install a ``mid_layers``-channel mirror for parallel layer number ``layer_no`` of ``model``.

    Also prints the summed share of the first ``mid_layers`` normalized output-covariance
    eigenvalues, as reported by ``covariance_eigenvalues_normalized``.

    :return: array stacking the normalized eigenvalues (row 0) on top of their running sum (row 1).
    """
    target = model.paraller_layers[layer_no]
    eigenvalues = inouts[target].out_cov.covariance_eigenvalues_normalized
    make_parallel_layer(target, mid_layers)
    kept_share = np.sum(eigenvalues[:mid_layers])
    print(f"choosen eigenvalues weight = {kept_share}")
    running_total = np.cumsum(eigenvalues)
    return np.stack((eigenvalues, running_total))

In [22]:
shrink_layer(layer_no=0, mid_layers=9)

diff = 1.3486904337271463e-15
choosen eigenvalues weight = 0.9999999999997045


array([[ 6.93579290e-01,  1.57256679e-01,  9.19444662e-02,
         2.78612548e-02,  1.06627511e-02,  9.27248369e-03,
         5.88955466e-03,  1.95233102e-03,  1.58118955e-03,
         3.73497370e-13,  2.96235600e-13,  2.16443857e-13,
         1.62406429e-13,  1.12770068e-13,  9.71151631e-14,
         7.82251517e-14,  6.10534723e-14,  4.01548849e-14,
         3.17610912e-14,  1.95409871e-14,  1.06333967e-14,
        -1.86857989e-14, -3.61448670e-14, -3.96351742e-14,
        -4.87957283e-14, -6.96424282e-14, -7.86779172e-14,
        -8.81179782e-14, -1.20222649e-13, -1.60944379e-13,
        -2.55296592e-13, -2.88083364e-13],
       [ 6.93579290e-01,  8.50835969e-01,  9.42780435e-01,
         9.70641690e-01,  9.81304441e-01,  9.90576925e-01,
         9.96466479e-01,  9.98418810e-01,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00000000e+00,  1.00000000e+00,  1.00000000e+00,
         1.00

In [23]:
shrink_layer(layer_no=1, mid_layers=16)

diff = 0.00016169504669960588
choosen eigenvalues weight = 0.9992694134739482


array([[6.90872364e-01, 1.51875500e-01, 1.11305604e-01, 2.28447524e-02,
        1.02340958e-02, 5.57610236e-03, 2.35518679e-03, 1.12360681e-03,
        8.27294286e-04, 6.22440049e-04, 4.64976564e-04, 3.21633101e-04,
        2.87125765e-04, 2.61755811e-04, 1.66492338e-04, 1.30483665e-04,
        1.12538187e-04, 8.38287296e-05, 7.79653793e-05, 6.87447916e-05,
        6.41351324e-05, 4.91414604e-05, 4.51465225e-05, 3.96403907e-05,
        3.84260112e-05, 2.91558330e-05, 2.82666889e-05, 2.52277161e-05,
        2.14428159e-05, 1.82928042e-05, 1.47147143e-05, 1.39193485e-05],
       [6.90872364e-01, 8.42747864e-01, 9.54053468e-01, 9.76898220e-01,
        9.87132316e-01, 9.92708418e-01, 9.95063605e-01, 9.96187212e-01,
        9.97014506e-01, 9.97636946e-01, 9.98101923e-01, 9.98423556e-01,
        9.98710682e-01, 9.98972437e-01, 9.99138930e-01, 9.99269413e-01,
        9.99381952e-01, 9.99465780e-01, 9.99543746e-01, 9.99612491e-01,
        9.99676626e-01, 9.99725767e-01, 9.99770914e-01, 9.99810

In [24]:
shrink_layer(layer_no=2, mid_layers=32)

diff = 0.0006230933358892798
choosen eigenvalues weight = 0.9994686847903907


array([[5.41107727e-01, 1.99994860e-01, 1.67957757e-01, 4.05030664e-02,
        2.43670958e-02, 1.13422359e-02, 4.34249801e-03, 2.34529642e-03,
        1.89836527e-03, 1.08557922e-03, 9.33857827e-04, 5.80825992e-04,
        4.39777189e-04, 3.68221598e-04, 3.49847263e-04, 2.53562413e-04,
        2.09806420e-04, 1.94004483e-04, 1.55790686e-04, 1.42688967e-04,
        1.31743372e-04, 1.05355449e-04, 9.31827381e-05, 8.74572888e-05,
        8.15766945e-05, 7.19818764e-05, 6.56479480e-05, 5.67727270e-05,
        5.34645921e-05, 5.13552674e-05, 5.08735441e-05, 4.64097143e-05,
        4.31595493e-05, 4.16782342e-05, 3.40198499e-05, 3.28355303e-05,
        3.20072718e-05, 2.72846195e-05, 2.49601327e-05, 2.37773450e-05,
        2.28731271e-05, 1.93054741e-05, 1.92533642e-05, 1.85251396e-05,
        1.67748439e-05, 1.60158515e-05, 1.51118533e-05, 1.45522741e-05,
        1.31160830e-05, 1.19272371e-05, 1.10114449e-05, 1.03634080e-05,
        9.64787599e-06, 9.38851544e-06, 8.99492550e-06, 8.966921

In [25]:
shrink_layer(layer_no=3, mid_layers=32)

diff = 0.002003157278522849
choosen eigenvalues weight = 0.9998424184481154


array([[3.07208001e-01, 2.72870291e-01, 2.28772508e-01, 9.78029362e-02,
        4.36307817e-02, 2.42746555e-02, 8.96494246e-03, 5.77804621e-03,
        4.11709096e-03, 2.67588362e-03, 1.10984865e-03, 5.24515200e-04,
        4.96975730e-04, 3.48865837e-04, 2.83119196e-04, 1.96321832e-04,
        1.73138622e-04, 1.33447728e-04, 7.85959138e-05, 7.26186205e-05,
        4.95241965e-05, 4.39073251e-05, 3.99645860e-05, 3.58316103e-05,
        2.70210053e-05, 2.57757963e-05, 2.28203668e-05, 2.07502908e-05,
        1.95681933e-05, 1.58548726e-05, 1.45112601e-05, 1.43043448e-05,
        1.32614879e-05, 1.22250092e-05, 1.16588637e-05, 9.25564671e-06,
        8.45779871e-06, 8.16983703e-06, 7.35021179e-06, 6.68566808e-06,
        6.00565246e-06, 5.85584591e-06, 5.36285327e-06, 5.21490360e-06,
        4.86160243e-06, 4.50231430e-06, 4.48313376e-06, 4.23901106e-06,
        3.86662156e-06, 3.66383083e-06, 3.44276582e-06, 3.32529609e-06,
        3.03785773e-06, 2.73651737e-06, 2.61771173e-06, 2.444884

In [26]:
shrink_layer(layer_no=4, mid_layers=64)

diff = 0.04089765623211861
choosen eigenvalues weight = 0.9980147880163952


array([[3.04662521e-01, 1.58349194e-01, 1.25816134e-01, 9.70014777e-02,
        6.28481699e-02, 5.88651191e-02, 3.85269118e-02, 3.02057014e-02,
        2.32203763e-02, 1.75204364e-02, 1.38472211e-02, 1.13532790e-02,
        9.55050552e-03, 6.97411320e-03, 5.98340613e-03, 4.93256677e-03,
        4.54769485e-03, 3.10258562e-03, 2.44730850e-03, 2.03940546e-03,
        1.69276736e-03, 1.41551110e-03, 1.34020977e-03, 1.07393348e-03,
        8.89830611e-04, 7.15702445e-04, 6.60485034e-04, 6.44456786e-04,
        5.34541169e-04, 5.11693313e-04, 5.02689231e-04, 4.40116380e-04,
        3.75370118e-04, 3.57245122e-04, 3.29962898e-04, 3.24354880e-04,
        3.03113508e-04, 2.84904555e-04, 2.59633762e-04, 2.47307453e-04,
        2.34306045e-04, 2.24004384e-04, 2.14141117e-04, 1.97803662e-04,
        1.86137153e-04, 1.74915546e-04, 1.67241141e-04, 1.57087846e-04,
        1.53433159e-04, 1.48848485e-04, 1.41542437e-04, 1.35022027e-04,
        1.20538861e-04, 1.14330715e-04, 1.10383463e-04, 1.063880

In [27]:
shrink_layer(layer_no=5, mid_layers=64)

diff = 0.03849208354949951
choosen eigenvalues weight = 0.9980337282775918


array([[3.06424392e-01, 2.26783453e-01, 1.34995127e-01, 1.03324974e-01,
        7.42818076e-02, 4.05715877e-02, 2.12504356e-02, 1.93831084e-02,
        1.13350797e-02, 9.61845711e-03, 9.21366802e-03, 6.96919250e-03,
        4.95609886e-03, 3.95573996e-03, 3.65463696e-03, 2.60785745e-03,
        2.19521496e-03, 1.92252897e-03, 1.40597379e-03, 1.24498059e-03,
        1.06533565e-03, 9.26767486e-04, 7.83990644e-04, 6.27024586e-04,
        5.99851624e-04, 5.61611956e-04, 5.11547709e-04, 4.33884431e-04,
        4.12262552e-04, 3.78815696e-04, 3.45213791e-04, 3.27943589e-04,
        3.18604491e-04, 2.86902903e-04, 2.73597741e-04, 2.63109562e-04,
        2.55781610e-04, 2.39160697e-04, 2.12447484e-04, 2.11253951e-04,
        1.93599493e-04, 1.81650870e-04, 1.72784248e-04, 1.64902847e-04,
        1.63063376e-04, 1.55786717e-04, 1.49621956e-04, 1.45030318e-04,
        1.32146912e-04, 1.23697265e-04, 1.19250965e-04, 1.16048102e-04,
        1.12558827e-04, 1.05605648e-04, 1.03498235e-04, 9.698522

In [28]:
def test_layers(mirror_count: int = 2) -> float:
    """
    Measure test-set accuracy of ``model`` with its first ``mirror_count`` parallel layers mirrored.

    Parallel layers whose index is below ``mirror_count`` run their ``mirror_layer``; all
    others run the real layer. The switches are left in place after the call.

    :return: test-set accuracy.
    """
    for idx, parallel in enumerate(model.paraller_layers):
        parallel.use_real_layer = mirror_count <= idx
    return TrainHelper.test(model, test_data, device='cuda')

In [29]:
# Test accuracy with the first i parallel layers mirrored, for i = 0..6 (i = 0 is the original
# model). After this cell, every parallel layer is left using its mirror.
[test_layers(i) for i in range(7)]

[0.9833, 0.9833, 0.9829, 0.9828, 0.9829, 0.9829, 0.983]