In [1]:
from utils.trainer import ClassifierTrainer

In [2]:
from utils.nn import create_mlp_layers
import torch


# MLP classifier for flattened MNIST images: INPUT_DIM inputs -> HIDDEN_DIMS -> N_CLASSES outputs.
# NOTE(review): the actual layer stack (activations, etc.) is built by utils.nn.create_mlp_layers — confirm there.
INPUT_DIM = 784           # 28 * 28 pixels, matching the 'mnist_784' OpenML dataset
HIDDEN_DIMS = [300, 100]  # widths of the hidden layers
N_CLASSES = 10            # digits 0-9
SEED = 0

# Seed torch's global RNG before weight initialisation so Restart & Run All reproduces the same
# initial weights. By default the DataLoader's shuffling also draws from this RNG, so batch order
# becomes reproducible too.
torch.manual_seed(SEED)
clf = torch.nn.Sequential(*create_mlp_layers(INPUT_DIM, HIDDEN_DIMS, N_CLASSES))

In [3]:
from sklearn.datasets import fetch_openml


# Fetch MNIST (784 pixel features per image) from OpenML; scikit-learn caches the download locally.
mnist = fetch_openml(name='mnist_784', version=1, parser='auto')
X = mnist['data']    # pixel features
y = mnist['target']  # digit labels

In [4]:
class NamedDataset(torch.utils.data.Dataset):
    """Map-style dataset that zips several parallel sequences into named samples.

    ``dataset[i]`` returns ``{name: seq[i]}`` for every ``(name, seq)`` pair, so a
    DataLoader using the default collate function produces dict batches keyed by ``names``.

    Args:
        names: One key per entry of ``data``.
        data: Parallel indexable sequences (tensors, or e.g. NumPy arrays) that all have
            the same length along their first axis.

    Raises:
        ValueError: If ``names`` and ``data`` differ in length, or if the sequences in
            ``data`` do not all have the same length.
    """

    def __init__(self, names: list[str], data: list[torch.Tensor]):
        # zip() would silently drop unmatched names or sequences, so reject the mismatch up front.
        if len(names) != len(data):
            raise ValueError(f"got {len(names)} names for {len(data)} data sequences")
        # Unequal lengths would otherwise only surface as an IndexError partway through an epoch.
        lengths = {len(seq) for seq in data}
        if len(lengths) > 1:
            raise ValueError(f"all data sequences must have the same length, got lengths {sorted(lengths)}")
        self.names = names
        self.data = data

    def __len__(self) -> int:
        # All sequences share one length (checked in __init__); with no sequences the dataset is empty.
        return len(self.data[0]) if self.data else 0

    def __getitem__(self, idx) -> dict[str, torch.Tensor]:
        return {name: seq[idx] for name, seq in zip(self.names, self.data)}


def create_generator(dataset: torch.utils.data.Dataset, batch_size: int = 128, shuffle: bool = True, drop_last: bool = True, **kwargs):
    """Yield batches from ``dataset`` forever, re-iterating one DataLoader each time an epoch ends.

    Args:
        dataset: Dataset to draw batches from.
        batch_size: Number of samples per batch.
        shuffle: Reshuffle the sample order at the start of every epoch.
        drop_last: Drop the final incomplete batch of each epoch.
        **kwargs: Forwarded unchanged to ``torch.utils.data.DataLoader`` (e.g. ``num_workers``).

    Yields:
        Batches as produced by the DataLoader's collate function.

    Raises:
        ValueError: If a complete pass over the dataset yields no batches. This happens, for
            example, with ``drop_last=True`` and fewer than ``batch_size`` samples. Without
            this check the loop would spin forever without ever yielding.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, **kwargs)
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError(
                f"DataLoader produced no batches in a full pass over the dataset "
                f"(batch_size={batch_size}, drop_last={drop_last})"
            )

In [5]:
# Pixels -> float32, scaled by 1/255 (presumably raw values are 0-255, giving [0, 1]).
# float32 matches the default dtype of torch.nn parameters. X's own dtype depends on how the
# OpenML data was parsed, so without the explicit cast the model could receive float64 input.
mnist_pixels = torch.tensor(X.to_numpy(dtype='float32')) / 255.0
# Labels -> int64 class indices. astype(int) alone yields int32 on Windows with NumPy < 2.
# Index-style targets for losses such as cross-entropy must be int64.
# NOTE(review): ClassifierTrainer's loss is presumably of that kind — confirm in utils.trainer.
mnist_labels = torch.as_tensor(y.values.astype(int), dtype=torch.int64)
train_generator = create_generator(NamedDataset(['x', 'y'], [mnist_pixels, mnist_labels]))

In [6]:
# Train the MLP on the batch stream from create_generator. That stream never ends by itself, so
# ClassifierTrainer.train must decide when to stop. The captured output below shows a 5000-step
# progress bar, which suggests a fixed step budget — confirm in utils.trainer.
trainer = ClassifierTrainer(clf)
# NOTE(review): the output also shows a wandb login notice, so the trainer presumably logs to
# Weights & Biases. The alternating "a"/"b" lines look like leftover debug prints inside the
# trainer — TODO remove. The recorded run ended in a KeyboardInterrupt at roughly 22%.
trainer.train(train_generator)

[34m[1mwandb[0m: Currently logged in as: [33mantonii-belyshev[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 6/5000 [00:00<01:41, 49.39it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


  0%|          | 20/5000 [00:00<01:23, 59.89it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  1%|          | 34/5000 [00:00<01:18, 62.90it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  1%|          | 41/5000 [00:00<01:43, 47.89it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a


  1%|          | 54/5000 [00:01<01:35, 51.64it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


  1%|▏         | 71/5000 [00:01<01:13, 66.73it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  2%|▏         | 89/5000 [00:01<01:04, 76.22it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  2%|▏         | 107/5000 [00:01<01:01, 80.16it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  2%|▎         | 125/5000 [00:01<00:58, 83.05it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  3%|▎         | 134/5000 [00:01<01:00, 80.56it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  3%|▎         | 153/5000 [00:02<00:58, 82.80it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  3%|▎         | 172/5000 [00:02<00:56, 86.04it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


  4%|▍         | 190/5000 [00:02<01:01, 78.28it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  4%|▍         | 208/5000 [00:02<00:59, 81.15it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  5%|▍         | 226/5000 [00:03<00:56, 83.84it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  5%|▍         | 236/5000 [00:03<00:55, 85.71it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  5%|▌         | 254/5000 [00:03<00:54, 86.39it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  5%|▌         | 274/5000 [00:03<00:52, 89.64it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


  6%|▌         | 294/5000 [00:03<00:51, 90.65it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  6%|▋         | 313/5000 [00:04<00:54, 85.68it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


  7%|▋         | 331/5000 [00:04<00:54, 85.38it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  7%|▋         | 350/5000 [00:04<00:52, 88.12it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  7%|▋         | 368/5000 [00:04<00:53, 87.36it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  8%|▊         | 388/5000 [00:04<00:51, 90.13it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  8%|▊         | 408/5000 [00:05<00:50, 91.56it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


  9%|▊         | 428/5000 [00:05<00:50, 91.17it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


  9%|▉         | 438/5000 [00:05<00:49, 91.49it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


  9%|▉         | 458/5000 [00:05<00:49, 92.23it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 10%|▉         | 478/5000 [00:05<00:48, 93.94it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


 10%|▉         | 498/5000 [00:06<00:47, 93.84it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


 10%|█         | 518/5000 [00:06<00:48, 93.18it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


 11%|█         | 538/5000 [00:06<00:47, 93.63it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 11%|█         | 557/5000 [00:06<00:50, 87.57it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


 12%|█▏        | 576/5000 [00:06<00:49, 89.07it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


 12%|█▏        | 595/5000 [00:07<00:49, 89.15it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 12%|█▏        | 614/5000 [00:07<00:48, 90.68it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 13%|█▎        | 634/5000 [00:07<00:46, 93.27it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a

 13%|█▎        | 644/5000 [00:07<00:48, 89.51it/s]


b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 13%|█▎        | 663/5000 [00:07<00:47, 90.78it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 14%|█▎        | 683/5000 [00:08<00:48, 89.65it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 14%|█▍        | 701/5000 [00:08<00:50, 85.06it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 14%|█▍        | 719/5000 [00:08<00:50, 84.21it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 15%|█▍        | 737/5000 [00:08<00:50, 83.91it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 15%|█▍        | 746/5000 [00:08<00:51, 82.00it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 15%|█▌        | 764/5000 [00:09<00:51, 82.47it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 16%|█▌        | 783/5000 [00:09<00:48, 86.22it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 16%|█▌        | 802/5000 [00:09<00:47, 88.37it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 16%|█▋        | 821/5000 [00:09<00:47, 87.78it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 17%|█▋        | 840/5000 [00:10<00:46, 88.63it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 17%|█▋        | 860/5000 [00:10<00:46, 88.49it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 18%|█▊        | 878/5000 [00:10<00:49, 84.11it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 18%|█▊        | 887/5000 [00:10<00:50, 81.48it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 18%|█▊        | 904/5000 [00:10<00:54, 75.75it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 18%|█▊        | 912/5000 [00:11<01:11, 56.98it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


 19%|█▊        | 928/5000 [00:11<01:04, 63.07it/s]

a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 19%|█▊        | 935/5000 [00:11<01:15, 53.84it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 19%|█▉        | 947/5000 [00:11<01:16, 53.24it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 19%|█▉        | 961/5000 [00:11<01:08, 58.91it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a

 20%|█▉        | 975/5000 [00:12<01:07, 60.01it/s]


b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 20%|█▉        | 991/5000 [00:12<01:00, 65.77it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 20%|██        | 1007/5000 [00:12<00:55, 72.25it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 20%|██        | 1025/5000 [00:12<00:52, 76.15it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 21%|██        | 1035/5000 [00:12<00:49, 80.61it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 21%|██        | 1053/5000 [00:13<00:47, 83.34it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 21%|██▏       | 1071/5000 [00:13<00:45, 85.66it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 22%|██▏       | 1089/5000 [00:13<00:47, 82.29it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 22%|██▏       | 1107/5000 [00:13<00:48, 80.38it/s]

b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a
b
a


 22%|██▏       | 1119/5000 [00:13<00:48, 80.31it/s]


b
a
b
a
b
a
b
a
b
a
b
a
b
a
b


KeyboardInterrupt: 