# SSNE Miniproject 3
### 318703 Tomasz Owienko
### 318718 Anna Schäfer
### Grupa piątek

In [1]:
import os
from typing import Any, Callable

import PIL.Image
import pandas as pd
import torch
import torchmetrics
import torch.nn as nn
import torchvision.transforms as transforms
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import pytorch_lightning as pl

In [2]:
RANDOM_SEED = 123
torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7ffb33c9ec50>

In [3]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_float32_matmul_precision('medium')

In [4]:
class ImagesDataModule(pl.LightningDataModule):
    class FastDataset(Dataset):
        def __init__(self, data, labels, num_classes):
            self.dataset = data
            self.labels = labels
            self.number_classes = num_classes

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index):
            return self.dataset[index], self.labels[index]

    def __init__(self, path: str, transform: Callable[[Any], torch.Tensor], *, val_fraction: float,
                 test_fraction: float, in_memory=False):
        super().__init__()
        assert 0 <= val_fraction + test_fraction <= 1
        assert val_fraction * test_fraction >= 0

        self.image_folder = ImageFolder(path, transform=transform)
        self.dataset: ImagesDataModule.FastDataset | None = None
        self._val_fraction = val_fraction
        self._test_fraction = test_fraction
        self._in_memory = in_memory

        self._train = self._val = self._test = None

    def prepare_data(self) -> None:
        if self._in_memory:
            loader = DataLoader(self.image_folder, batch_size=len(self.image_folder))
            data = next(iter(loader))
            dataset = ImagesDataModule.FastDataset(data[0], data[1], num_classes=len(self.image_folder.classes))
        else:
            dataset = self.image_folder

        val_size = int(len(dataset) * self._val_fraction)
        test_size = int(len(dataset) * self._test_fraction)
        train_size = len(dataset) - val_size - test_size

        self._train, self._val, self._test = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(self._train, batch_size=1536, shuffle=True, num_workers=8 if not self._in_memory else 0,
                          pin_memory=True)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self._val, batch_size=64, shuffle=False, num_workers=8 if not self._in_memory else 0,
                          pin_memory=True)

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(self._test, batch_size=64, shuffle=False, num_workers=8 if not self._in_memory else 0,
                          pin_memory=True)

In [5]:
class ImageClassifier(pl.LightningModule):
    def __init__(self, num_classes, lr, weight_decay, loss):
        super().__init__()
        self.conv1_1 = nn.Conv2d(3, 32, 5, padding=2)
        self.conv1_2 = nn.Conv2d(32, 32, 5, padding=2)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(3, 3)
        self.act1 = nn.ReLU()
        self.conv2_1 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.act2 = nn.ReLU()
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.act3 = nn.ReLU()
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.act4 = nn.ReLU()
        self.flatten = nn.Flatten()
        self.dropout1 = nn.Dropout(0.2)
        self.fc1 = nn.Linear(1024, 1024)
        self.act5 = nn.ReLU()
        self.bn5 = nn.BatchNorm1d(1024)
        self.dropout2 = nn.Dropout(0.2)
        self.fc2 = nn.Linear(1024, 512)
        self.act6 = nn.ReLU()
        self.fc3 = nn.Linear(512, num_classes)

        self._accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes)

        self.example_input_array = torch.rand((16, 3, 64, 64)).to(device)

        self._lr = lr
        self._weight_decay = weight_decay
        self._loss = loss
        
        self._val_preds: list | None = None
        self._test_preds: list | None = None

    def forward(self, x: torch.Tensor):
        x = self.act1(self.pool1(self.bn1(self.conv1_2(self.conv1_1(x)))))
        x = self.act2(self.pool2(self.bn2(self.conv2_2(self.conv2_1(x)))))
        x = self.act3(self.pool3(self.bn3(self.conv3(x))))
        x = self.act4(self.pool4(self.bn4(self.conv4(x))))
        x = self.flatten(x)
        x = self.dropout1(x)
        x = self.act5(self.fc1(x))
        x = self.bn5(x)
        x = self.dropout2(x)
        x = self.act6(self.fc2(x))
        x = self.fc3(x)
        return x

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001, weight_decay=0.001)
        return optimizer

    def on_train_start(self) -> None:
        self.logger.experiment.add_graph(self, self.example_input_array)
        
    def training_step(self, batch, batch_idx):
        assert self.training
        images, labels = batch
        out = self.forward(images)
        loss = self._loss(out, labels)
        self.log('train_loss', loss, on_step=False, on_epoch=True)
        return loss
    
    def on_validation_epoch_start(self) -> None:
        self._val_preds = []

    def validation_step(self, batch, batch_idx):
        assert not self.training
        images, labels = batch
        out = self.forward(images)
        loss = self._loss(out, labels)
        self.log('val_loss', loss, on_step=False, on_epoch=True)
        out = self.forward(images)
        preds = torch.argmax(out, dim=1)
        
        self._val_preds.append((preds, labels))

        return loss
    
    def on_validation_epoch_end(self) -> None:
        preds, labels = zip(*self._val_preds)
        acc = self._accuracy(torch.cat(preds), torch.cat(labels))
        self.log('val_accuracy', acc)
    
    def on_test_epoch_start(self) -> None:
        self._test_preds = []

    def test_step(self, batch, batch_idx):
        assert not self.training
        images, labels = batch
        out = self.forward(images)
        preds = torch.argmax(out, dim=1)
        loss = self._loss(out, labels)
        self.log('test_loss', loss, on_step=False, on_epoch=True)

        self.logger.log_hyperparams(
            {
                'lr': self._lr,
                'weight_decay': self._weight_decay,
                'loss': str(self._loss)
            }
        )
        
        self._test_preds.append((preds, labels))
        
        return loss
        
    def on_test_epoch_end(self) -> None:
        preds, labels = zip(*self._test_preds)
        acc = self._accuracy(torch.cat(preds), torch.cat(labels))
        self.log('test_accuracy', acc)


In [6]:
transform = transforms.Compose([
    transforms.RandAugment(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5)),
])

dm = ImagesDataModule('data/train', transform, val_fraction=0.1, test_fraction=0.1)

In [7]:
model = ImageClassifier(num_classes=len(dm.image_folder.classes), lr=1e-3, weight_decay=1e-4,
                        loss=torch.nn.CrossEntropyLoss().to(device))

checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="val_loss",
    mode="min",
    filename="checkpoint-{epoch:02d}-{val_loss:.2f}",
)

trainer = pl.Trainer(max_epochs=250, callbacks=[checkpoint_callback])
trainer.fit(model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type               | Params | In sizes          | Out sizes        
------------------------------------------------------------------------------------------
0  | conv1_1   | Conv2d             | 2.4 K  | [16, 3, 64, 64]   | [16, 32, 64, 64] 
1  | conv1_2   | Conv2d             | 25.6 K | [16, 32, 64, 64]  | [16, 32, 64, 64] 
2  | bn1       | BatchNorm2d        | 64     | [16, 32, 64, 64]  | [16, 32, 64, 64] 
3  | pool1     | MaxPool2d          | 0      | [16, 32, 64, 64]  | [16, 32, 21, 21] 
4  | act1      | ReLU               | 0      | [16, 32, 21, 21]  | [16, 32, 21, 21] 
5  | conv2_1   | Conv2d             | 18.5 K | [16, 32, 21, 21]  | [16, 64, 21, 21] 
6  | conv2_2   | Conv2d             | 36.9 K | [16, 64, 21, 21]  | [16, 64, 21, 21] 
7  | bn2       | BatchNorm2d      

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=250` reached.


In [8]:
trainer.test(model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.7726643681526184, 'test_accuracy': 0.7801985144615173}]

In [9]:
%load_ext tensorboard
%tensorboard --logdir=lightning_logs

Launching TensorBoard...

In [10]:
acc = torchmetrics.classification.MulticlassAccuracy(num_classes=50)
model.eval()

ImageClassifier(
  (conv1_1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv1_2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (act1): ReLU()
  (conv2_1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (act2): ReLU()
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (act3): ReLU()
  (conv4): Co

In [11]:
# sanity check

with torch.no_grad():
    test_dl = dm.test_dataloader()
    test_res = [(torch.argmax(model.forward(x), dim=1), y) for x, y in test_dl]
    test_preds, test_labels = zip(*test_res)
    print(acc(torch.cat(test_preds), torch.cat(test_labels)))

tensor(0.7811)


In [12]:
# sanity check

with torch.no_grad():
    val_dl = dm.val_dataloader()
    val_res = [(torch.argmax(model.forward(x), dim=1), y) for x, y in val_dl]
    val_preds, val_labels = zip(*val_res)
    print(acc(torch.cat(val_preds), torch.cat(val_labels)))

tensor(0.7814)


### Predict on test data

In [13]:
full_dm = ImagesDataModule('data/train', transform=transform, val_fraction=0., test_fraction=0.)
final_model = ImageClassifier(num_classes=len(full_dm.image_folder.classes), lr=1e-3, weight_decay=1e-4,
                        loss=torch.nn.CrossEntropyLoss().to(device))
final_trainer = pl.Trainer(max_epochs=250, enable_checkpointing=False)
final_trainer.fit(final_model, datamodule=full_dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name      | Type               | Params | In sizes          | Out sizes        
------------------------------------------------------------------------------------------
0  | conv1_1   | Conv2d             | 2.4 K  | [16, 3, 64, 64]   | [16, 32, 64, 64] 
1  | conv1_2   | Conv2d             | 25.6 K | [16, 32, 64, 64]  | [16, 32, 64, 64] 
2  | bn1       | BatchNorm2d        | 64     | [16, 32, 64, 64]  | [16, 32, 64, 64] 
3  | pool1     | MaxPool2d          | 0      | [16, 32, 64, 64]  | [16, 32, 21, 21] 
4  | act1      | ReLU               | 0      | [16, 32, 21, 21]  | [16, 32, 21, 21] 
5  | conv2_1   | Conv2d             | 18.5 K | [16, 32, 21, 21]  | [16, 64, 21, 21] 
6  | conv2_2   | Conv2d             | 36.9 K | [16, 64, 21, 21]  | [16, 64, 21, 21] 
7  | bn2       | BatchNorm2d      

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=250` reached.


In [14]:
class TestDataset(Dataset):
    def __init__(self, path, num_classes, transform):
        self.img_paths = [f"{path}/{file}" for file in os.listdir(path)]
        self.number_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        return (self.img_paths[index], self.transform(PIL.Image.open(self.img_paths[index]).convert('RGB'))), torch.tensor(0)
    
    
test_dataset = TestDataset('data/test_all', num_classes=50, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5)),
]))

test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True)

final_model.eval()
with torch.no_grad():
    preds = pd.DataFrame([(path[0].split('/')[-1], int(final_model(x).argmax(axis=1))) for (path, x), _ in test_dataloader])
preds.to_csv('owienko_schafer.csv', header=None, index=None)
print(preds)

                           0   1
0       835086824463163.JPEG  38
1     39130056525035284.JPEG  36
2     22458983111906805.JPEG  24
3      2359038148696866.JPEG  18
4      8048582063490501.JPEG  23
...                      ...  ..
9995   3292234890542963.JPEG   8
9996   7254685650337267.JPEG  43
9997   7494989598289197.JPEG   1
9998   7850868808571556.JPEG  28
9999   7409775133938961.JPEG  37

[10000 rows x 2 columns]
