# Эксперименты с классификатором

## Подготовка данных

In [5]:
import torch
import torchvision
import matplotlib.pyplot as plt
import torchmetrics

from torch import nn
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

from models.autoencoder import MyAutoencoder
from models.classifier import MyClassifier
from src.utils import plot_reconstructed, grid_plot, vis_confusion

In [6]:
# Data-loading settings shared by all dataloaders below.
SEED = 42
BATCH_SIZE = 1000

transform = transforms.Compose([
    transforms.ToTensor(),
])


# Prepare test data
cifar_test = CIFAR10('data/', train=False, download=True, transform=transform)
test_dataloader = torch.utils.data.DataLoader(dataset=cifar_test, batch_size=BATCH_SIZE)


# Prepare train/val data: hold out `val_size` images of the official train split for validation.
cifar_trainval = CIFAR10('data/', train=True, download=True, transform=transform)

val_size = 2000
train_size = len(cifar_trainval) - val_size

# Seed the global RNG (consumed later, e.g. by the shuffling train loader).
torch.manual_seed(SEED)
# A dedicated generator makes the train/val split reproducible regardless of
# how much global randomness was consumed before this cell ran.
split_generator = torch.Generator().manual_seed(SEED)
cifar_train, cifar_val = torch.utils.data.random_split(
    cifar_trainval, [train_size, val_size], generator=split_generator)
assert len(cifar_train) == train_size and len(cifar_val) == val_size


train_dataloader = torch.utils.data.DataLoader(dataset=cifar_train, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(dataset=cifar_val, batch_size=BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
# CIFAR-10 class names; their position in this tuple is the label index.
CLASS_NAMES = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

# Name -> label-index mapping passed to MyClassifier.
classes = {label: i for i, label in enumerate(CLASS_NAMES)}

## Модели

### Автоэнкодер

In [8]:
autoencoder_weights = 'outputs/autoencoder_model.pth'

autoencoder = MyAutoencoder()
# map_location='cpu' lets a checkpoint saved from a GPU session load on a CPU-only
# machine (the training logs in this notebook report no GPU available).
# torch.load unpickles the file, so only load checkpoints from trusted sources.
autoencoder.load_state_dict(torch.load(autoencoder_weights, map_location='cpu'))

<All keys matched successfully>

### Классификатор

In [9]:
def train(clf, name, max_epochs=20):
    net = MyClassifier(autoencoder, clf, classes, lr=1e-3)

    logger = TensorBoardLogger('', name='outputs', version=name)

    trainer = Trainer(max_epochs=max_epochs, logger=logger)
    trainer.fit(net, train_dataloader, val_dataloader)
    
    # trainer = Trainer(max_epochs=20, check_val_every_n_epoch=2,
    #      limit_train_batches=5, limit_val_batches=5, logger=logger)
    # trainer = Trainer(check_val_every_n_epoch=5)

    # trainer = Trainer(max_epochs=20, limit_train_batches=5, limit_val_batches=5, logger=logger)
    # trainer = Trainer(max_epochs=20, log_every_n_steps=10, logger=logger)
    # trainer = Trainer(max_epochs=20, logger=logger,
    #                     check_val_every_n_epoch=1)

In [10]:
clf = nn.Sequential(
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    )

train(clf, 'test_run3')

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name         | Type                      | Params
-----------------------------------------------------------
0 | _autoencoder | MyAutoencoder             | 1.5 K 
1 | encoder      | Sequential                | 772   
2 | clf          | Sequential                | 41.8 K
3 | accuracy     | MulticlassAccuracy        | 0     
4 | conf_matrix  | MulticlassConfusionMatrix | 0     
-----------------------------------------------------------
41.8 K    Trainable params
1.5 K     Non-trainable params
43.3 K    Total params
0.173     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0:  50%|█████     | 1/2 [00:00<00:00,  3.21it/s]

  ax.set_xticklabels([''] + all_categories, rotation=90)
  ax.set_yticklabels([''] + all_categories)


                                                                           

  rank_zero_warn(
  rank_zero_warn(


Epoch 19: 100%|██████████| 50/50 [00:09<00:00,  5.24it/s, loss=1.44, v_num=run3]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 50/50 [00:09<00:00,  5.24it/s, loss=1.44, v_num=run3]


In [None]:
#         self.conv1 = nn.Conv2d(4, 6, 5)
#         self.pool = nn.MaxPool2d(2, 2)
#         self.conv2 = nn.Conv2d(6, 16, 5)
#         self.fc1 = nn.Linear(16 * 5 * 5, 120)
#         self.fc2 = nn.Linear(120, 84)
#         self.fc3 = nn.Linear(84, 10)

#     def forward(self, x):
#         x = self.pool(F.relu(self.conv1(x)))
#         x = self.pool(F.relu(self.conv2(x)))
#         x = torch.flatten(x, 1) # flatten all dimensions except batch
#         x = F.relu(self.fc1(x))
#         x = F.relu(self.fc2(x))
#         x = self.fc3(x)
#         return x

In [None]:
# def conv_block(in_f, out_f, *args, **kwargs):
# 	return nn.Sequential(
# 		nn.Conv2d(in_f, out_f, *args, **kwargs),
# 		nn.BatchNorm2d(out_f),
# 		nn.ReLU()
# 		)


# class MyCNNClassifier(nn.Module):
#     def __init__(self, in_c, n_classes):
#         super().__init__()
#         self.encoder = nn.Sequential(
#             conv_block(in_c, 32, kernel_size=3, padding=1),
#             conv_block(32, 64, kernel_size=3, padding=1)
#         )
#
#         self.decoder = nn.Sequential(
#             nn.Linear(32 * 28 * 28, 1024),
#             nn.Sigmoid(),
#             nn.Linear(1024, n_classes)
#         )
#
#     def forward(self, x):
#         x = self.encoder(x)
#
#         x = x.view(x.size(0), -1) # flat
#
#         x = self.decoder(x)
#
# NOTE(review): the encoder outputs 64 channels at the input resolution (padding=1),
# so the first Linear needs 64 * H * W inputs (64 * 32 * 32 for CIFAR-10), not
# 32 * 28 * 28; forward() is also missing `return x`.

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name        | Type                      | Params
----------------------------------------------------------
0 | clf         | Sequential                | 41.8 K
1 | accuracy    | MulticlassAccuracy        | 0     
2 | conf_matrix | MulticlassConfusionMatrix | 0     
----------------------------------------------------------
41.8 K    Trainable params
0         Non-trainable params
41.8 K    Total params
0.167     Total estimated model params size (MB)


Epoch 2:  80%|████████  | 40/50 [00:30<00:07,  1.30it/s, loss=2.3, v_num=un_2] 

In [None]:
# Persist only the learned parameters (the state dict) of the trained classifier.
torch.save(obj=net.state_dict(), f='outputs/clf_model.pth')

## Оценка качества

# Evaluate the in-memory fitted model explicitly. Without a model argument, Lightning
# falls back to a checkpoint from the previous fit() call (with a warning).
trainer.test(net, dataloaders=test_dataloader)

In [None]:
# Rebuild the model with the same constructor arguments used in `train`, then restore
# the saved weights. NOTE(review): this reuses the `clf` head object defined above.
net = MyClassifier(autoencoder, clf, classes, lr=1e-3)
net.load_state_dict(torch.load('outputs/clf_model.pth', map_location='cpu'))

<All keys matched successfully>