In [None]:
%pip install seaborn

In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="darkgrid")

import numpy as np
from collections import defaultdict
import functools
import pandas as pd
import pickle

Архитектура ResNet18:

![resnet18](https://russianblogs.com/images/785/72d76970469cbaae3253584ea2e81441.png)

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self,  in_chnls, use_bn, use_do, do_prob, use_skip, part_num, stride_first=2):
        super().__init__()
        out_chnls = 2 * in_chnls if stride_first == 2 else in_chnls
        self.part_num = part_num
        self.use_skip = use_skip

        for p in range(part_num):
            attr = "part" + str(p)
            self.__setattr__(attr, nn.Sequential())


            self.__getattr__(attr).add_module("conv0", 
                                              nn.Conv2d(
                                                  in_channels=in_chnls if p==0 else out_chnls,
                                                  out_channels=out_chnls,
                                                  kernel_size=3,
                                                  stride=stride_first if p == 0 else 1,
                                                  padding=1))
            if use_bn:
                self.__getattr__(attr).add_module("bn0", nn.BatchNorm2d(out_chnls))
            if use_do:
                self.__getattr__(attr).add_module("do0", nn.Dropout(p=do_prob))
            self.__getattr__(attr).add_module("relu0", nn.ReLU())

            self.__getattr__(attr).add_module("conv1", nn.Conv2d(out_chnls, out_chnls, 3, 1, 1))
            if use_bn:
                self.__getattr__(attr).add_module("bn1", nn.BatchNorm2d(out_chnls))
            if use_do:
                self.__getattr__(attr).add_module("do1", nn.Dropout(p=do_prob))
            
            self.__setattr__("part"+str(p)+"_lastrelu", nn.ReLU())
        
        if use_skip:
            for p in range(part_num): 
                attr = "skip" + str(p)
                if p == 0 and stride_first == 2:
                    self.__setattr__(attr, nn.Sequential(
                                                nn.Conv2d(in_chnls, out_chnls, 1, stride_first, bias=False),
                                                nn.BatchNorm2d(out_chnls)))
                else:
                    self.__setattr__(attr, nn.Sequential())

    def forward(self, x):
        for p in range(self.part_num): 
            out = self.__getattr__("part"+str(p))(x)
            if self.use_skip:
                out += self.__getattr__("skip"+str(p))(x)
            x = self.__getattr__("part"+str(p)+"_lastrelu")(out)

        return x


class CustomResNet18(pl.LightningModule):
    def __init__(self, conf, in_chnls=3, num_classes=10, lr=1e-2):
        super().__init__()
        self.conf = conf
        self.lr = lr
        self.bins = np.linspace(-10, 10, 200001)
        self.binsdict = defaultdict(lambda: np.zeros(200000, dtype=np.float64))
        self.curves = defaultdict(list)

        self.in_conv = nn.Conv2d(in_chnls, conf["width"], kernel_size=1, stride=1)
        self.block0 = ResidualBlock(conf["width"], conf["bn"], conf["do"], conf["do_p"], conf["use_skip"], conf["depth"], 1)
        self.block1 = ResidualBlock(conf["width"], conf["bn"], conf["do"], conf["do_p"], conf["use_skip"], conf["depth"], 2)
        self.block2 = ResidualBlock(2*conf["width"], conf["bn"], conf["do"], conf["do_p"], conf["use_skip"], conf["depth"], 2)
        self.block3 = ResidualBlock(4*conf["width"], conf["bn"], conf["do"], conf["do_p"], conf["use_skip"], conf["depth"], 2)
        self.avgpool = nn.AvgPool2d(4)
        self.fc = nn.Linear(8*conf["width"], num_classes)
    
    def forward(self, x):
        out = self.in_conv(x)
        for i in range(4):
            out = self.__getattr__("block"+str(i))(out)
        out = self.avgpool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)

        return out
    
    def _shared_eval_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = torch.mean((torch.argmax(y_hat, dim=-1) == y).float())

        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_eval_step(batch, batch_idx)
        metrics = {"val_acc": acc, "val_loss": loss}
        self.log_dict(metrics, prog_bar=True)

        return metrics

    def configure_optimizers(self):
        optimizer = None
        if self.conf["optim"] == "adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        elif self.conf["optim"] == "sgd":
            optimizer = torch.optim.SGD(self.parameters(), lr=self.lr)
        elif self.conf["optim"] == "rmsprop":
            optimizer = torch.optim.RMSprop(self.parameters(), lr=self.lr)
        elif self.conf["optim"] == "adamw":
            optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)

        
        return optimizer
    
    def on_after_backward(self):
        for i in range(4):
            attr_list = ['block'+str(i), 'part0', 'conv1']
            layer = functools.reduce(lambda x, y: getattr(x, y), attr_list, self)
            curr_grad = layer.weight.grad.cpu().flatten().numpy()
            self.binsdict["_".join(attr_list)] += np.histogram(curr_grad, self.bins)[0]
        

In [None]:
#!g1.1
dataset = CIFAR10(os.getcwd(), download=True, transform=transforms.ToTensor())
train, val = random_split(dataset, [40_000, 10_000])

dataset.__len__()

In [None]:
#!g1.1
fig, ax = plt.subplots(1,3)
ax[0].imshow(train.dataset.data[0])
ax[1].imshow(train.dataset.data[100])
ax[2].imshow(train.dataset.data[1000])

In [None]:
#!g1.1
train_dataloader = DataLoader(train, batch_size=512, num_workers=8)
test_dataloader = DataLoader(val, batch_size=512, num_workers=8)

In [109]:
#!g1.1
dict_width = {}
main_par = "width"
for curr_width in [8, 16, 32, 64, 128]:
    conf = {"width": curr_width, "bn": True, "do": False, "do_p": 0.0, "use_skip": False, "depth": 2, "optim": "adam"}

    model = CustomResNet18(conf)
    trainer = pl.Trainer(max_epochs=10, accelerator="gpu", gpus=1)
    trainer.fit(model, train_dataloader)

    val = trainer.validate(model, dataloaders=test_dataloader)[0]["val_acc"]
    trn = trainer.validate(model, dataloaders=train_dataloader)[0]["val_acc"]


    sett = [f"$\\bf{{{k}={v}}}$" if k == main_par else f"{k}={v}" for k, v in conf.items()]
    sett += [f"trn_acc={trn:.3f}", f"val_acc={val:.3f}"]
    name = "; ".join(sett)

    dict_width[name] = model.binsdict

with open('dict_width.pickle', 'wb') as handle:
    pickle.dump(dict_width, handle)

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params
------------------------------------------
0 | in_conv | Conv2d        | 32    
1 | block0  | ResidualBlock | 2.4 K 
2 | block1  | ResidualBlock | 8.3 K 
3 | block2  | ResidualBlock | 32.6 K
4 | block3  | ResidualBlock | 129 K 
5 | avgpool | AvgPool2d     | 0     
6 | fc      | Linear        | 650   
------------------------------------------
173 K     Trainable params
0         Non-trainable params
173 K     Total params
0.695     Total estimated model params size (MB)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params
------------------------------------------
0 | in_conv | Conv2d

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.578000009059906, 'val_loss': 1.2772639989852905}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.578000009059906, 'val_loss': 1.2772639989852905}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.39809998869895935, 'val_loss': 1.6334506273269653}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.39809998869895935, 'val_loss': 1.6334506273269653}
--------------------------------------------------------

AttributeError: Can't pickle local object 'CustomResNet18.__init__.<locals>.<lambda>'

In [115]:
dict_width_ = {}
for k, v in dict_width.items():
    dict_width_[k] = dict(v)

In [117]:
with open('dict_width.pickle', 'wb') as handle:
    pickle.dump(dict_width_, handle)

# New Section

In [125]:
#!g1.1

dict_bn_do = {}
main_par = ["bn", "do", "do_p"]
grid = [[False, False, 0.0], [True, False, 0.0],
        [False, True, 0.5], [True, True, 0.5], [True, True, 0.25]]

for curr_bn, curr_do, curr_do_p in grid:
    conf = {"width": 32, "bn": curr_bn, "do": curr_do, "do_p": curr_do_p, "use_skip": False, "depth": 2, "optim": "adam"}

    model = CustomResNet18(conf)
    trainer = pl.Trainer(max_epochs=10, accelerator="gpu", gpus=1)
    trainer.fit(model, train_dataloader)

    val = trainer.validate(model, dataloaders=test_dataloader)[0]["val_acc"]
    trn = trainer.validate(model, dataloaders=train_dataloader)[0]["val_acc"]

    sett = [f"$\\bf{{{k}={v}}}$" if k in main_par else f"{k}={v}" for k, v in conf.items()]
    sett += [f"trn_acc={trn:.3f}", f"val_acc={val:.3f}"]
    name = "; ".join(sett)

    dict_bn_do[name] = dict(model.binsdict)
    
with open('dict_bn_do.pickle', 'wb') as handle:
    pickle.dump(dict_bn_do, handle)

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params
------------------------------------------
0 | in_conv | Conv2d        | 128   
1 | block0  | ResidualBlock | 37.0 K
2 | block1  | ResidualBlock | 129 K 
3 | block2  | ResidualBlock | 516 K 
4 | block3  | ResidualBlock | 2.1 M 
5 | avgpool | AvgPool2d     | 0     
6 | fc      | Linear        | 2.6 K 
------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.004    Total estimated model params size (MB)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(f"you defined a {step_name} bu

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.09570000320672989, 'val_loss': 2.3027024269104004}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.10107500106096268, 'val_loss': 2.3025729656219482}
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.6804999709129333, 'val_loss': 1.2473310232162476}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.7572000026702881, 'val_loss': 0.8124817609786987}
-----------------------------------------------------

In [126]:
with open('dict_bn_do.pickle', 'wb') as handle:
    pickle.dump(dict_bn_do, handle)

In [127]:
dict_bn_do

{'width=32; $\\bf{bn=False}$; $\\bf{do=False}$; $\\bf{do_p=0.0}$; use_skip=False; depth=2; optim=adam; trn_acc=0.101; val_acc=0.096': {'block0_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.]),
  'block1_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.]),
  'block2_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.]),
  'block3_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.])},
 'width=32; $\\bf{bn=True}$; $\\bf{do=False}$; $\\bf{do_p=0.0}$; use_skip=False; depth=2; optim=adam; trn_acc=0.757; val_acc=0.680': {'block0_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.]),
  'block1_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.]),
  'block2_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.]),
  'block3_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.])},
 'width=32; $\\bf{bn=False}$; $\\bf{do=True}$; $\\bf{do_p=0.5}$; use_skip=False; depth=2; optim=adam; trn_acc=0.156; val_acc=0.161': {'block0_part0_conv1': array([0., 0., 0., ..., 0., 0., 0.]),
  'block1_part0_conv1': array([0., 0., 0., 

# New Section

In [135]:
#!g1.1
dict_skip = {}
main_par = ["use_skip", "bn"]
grid = [[False, False], [False, True], [True, False], [True, True]]

for curr_bn, curr_skip in grid:
    conf = {"width": 32, "bn": curr_bn, "do": False, "do_p": 0.0, "use_skip": curr_skip, "depth": 2, "optim": "adam"}

    model = CustomResNet18(conf)
    trainer = pl.Trainer(max_epochs=10, accelerator="gpu", gpus=1)
    trainer.fit(model, train_dataloader)

    val = trainer.validate(model, dataloaders=test_dataloader)[0]["val_acc"]
    trn = trainer.validate(model, dataloaders=train_dataloader)[0]["val_acc"]

    sett = [f"$\\bf{{{k}={v}}}$" if k in main_par else f"{k}={v}" for k, v in conf.items()]
    sett += [f"trn_acc={trn:.3f}", f"val_acc={val:.3f}"]
    name = "; ".join(sett)

    dict_skip[name] = dict(model.binsdict)

with open('dict_skip.pickle', 'wb') as handle:
    pickle.dump(dict_skip, handle)

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(f"you defined a {step_name} but have no {loader_name}. Skipping {stage} loop")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params
------------------------------------------
0 | in_conv | Conv2d        | 128   
1 | block0  | ResidualBlock | 37.0 K
2 | block1  | ResidualBlock | 129 K 
3 | block2  | ResidualBlock | 516 K 
4 | block3  | ResidualBlock | 2.1 M 
5 | avgpool | AvgPool2d     | 0     
6 | fc      | Linear        | 2.6 K 
------------------------------------------
2.8 M     Trainable params
0         Non-trainable params
2.8 M     Total params
11.004    Total estimated model params size (MB)
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(f"you defined a {step_name} bu

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.09880000352859497, 'val_loss': 2.3026626110076904}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.10029999911785126, 'val_loss': 2.302584648132324}
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.6309999823570251, 'val_loss': 1.4472473859786987}
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_acc': 0.7777249813079834, 'val_loss': 0.6707043647766113}
------------------------------------------------------

# AAAAAA

In [None]:
#!g1.1

dict_optim = {}
main_par = ["optim"]
grid = ["sgd", "adam", "rmsprop", "adamw"]

for curr_optim in grid:
    conf = {"width": 32, "bn": True, "do": False, "do_p": 0.0, "use_skip": False, "depth": 2, "optim": curr_optim}

    model = CustomResNet18(conf)
    trainer = pl.Trainer(max_epochs=15, accelerator="gpu", gpus=1)
    trainer.fit(model, train_dataloader)

    val = trainer.validate(model, dataloaders=test_dataloader)[0]["val_acc"]
    trn = trainer.validate(model, dataloaders=train_dataloader)[0]["val_acc"]

    sett = [f"$\\bf{{{k}={v}}}$" if k in main_par else f"{k}={v}" for k, v in conf.items()]
    sett += [f"trn_acc={trn:.3f}", f"val_acc={val:.3f}"]
    name = "; ".join(sett)

    dict_optim[name] = dict(model.binsdict)

with open('dict_optim.pickle', 'wb') as handle:
    pickle.dump(dict_optim, handle)

In [None]:
#!g1.1
