# DDPLKO Moduł 4 - praca domowa - Quickdraw 10 class - regularyzacja - pytorch

Twoim zadaniem w tym module będzie przygotowanie własnego modelu sieci neuronowej korzystając z regularyzacji.

Lista rzeczy które musi spełnić Twój model:
- [x] działać na wybranych przez Ciebie 10 klasach (bazuj na kodzie z modułu 3)
- [x] liczba parametrów pomiędzy 100'000 a 200'000
- [x] wykorzystane przynajmniej 2 sposoby regularyzacji (walki z przeuczeniem)
- [x] mieć wykonane co najmniej 4 zmiany w celu poprawy wyniku; zachowaj wszystkie iteracje (modyfikując model możesz dodać opcje w funkcji, bądź skopiować klasę/funkcję, tak by było widać kolejne architektury)
- [x] opisz co chcesz sprawdzić w kolejnych eksperymentach (np. sprawdzę czy Dropout pomaga i z jaką wartością drop ratio najbardziej)
- [x] uzyskiwać lepsze `validation accuracy` niż w przypadku pierwszego modelu z poprzedniego modułu (im więcej punktów procentowych różnicy tym lepiej)

Zwizualizuj proszę:
- [ ] historie treningów (wystarczy Val acc, ale train acc czy lossy też mogą być)
- [ ] zależność: liczba parametrów - val acc

Możesz (czyli opcjonalne rzeczy):
- pracować na zmniejszonym zbiorze, by dobrać wartość parametrów
- np. zastosować dropout, pooling i early stopping
- zastosować TF2 - Keras / PyTorcha czy PL (Pytorch Lightning)
- dodać LR scheduler do swojego treningu (i sprawdzić czy to poprawiło wynik)
- zwizualizować dodatkowo:
  - confusion matrix
  - błędne przypadki

Warto:
- zmieniać 1 parametr między eksperymentami (szczególnie trudne gdy się już nabierze wyczucia)

In [1]:
! pip install lightning wandb && wandb login



wandb: Currently logged in as: gratkadlafana. Use `wandb login --relogin` to force relogin


## Importy

In [2]:
import gc
import os
import pathlib
import pprint
import urllib
from typing import Any, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset

from torchvision import datasets, transforms

import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from torch.utils.tensorboard import SummaryWriter

from torchmetrics.functional.classification.accuracy import accuracy



### Klasa QuickDrawDataset

In [3]:
# Load data: the ten Quick, Draw! categories used in this notebook.
# A class's position in this list becomes its integer label.
class_names = [
    "airplane", "banana", "cookie", "diamond", "dog",
    "hot air balloon", "knife", "parachute", "scissors", "wine glass",
]
data_folder = "../data/quickdraw/"

# Create the data folder (and any missing parents); a no-op if it already exists.
pathlib.Path(data_folder).mkdir(exist_ok=True, parents=True)


class QuickDrawDataset(Dataset):
    """A Quick, Draw! dataset of 28x28 single-channel bitmaps with integer labels."""

    def __init__(
        self, classes, root_dir, download_data=False, load_data=True, transform=None
    ):
        """
        Arguments:
            classes (list[string]): Class names to use; a class's position in the
                list is its integer label.
            root_dir (string): Directory holding one ``<class name>.npy`` file per class.
            download_data (bool, optional): If True, download the ``.npy`` file of every
                class that is not already present in ``root_dir``.
            load_data (bool, optional): If True, load all classes into memory now.
                Otherwise the dataset starts empty (``len() == 0``) until
                ``_set_data`` is called.
            transform (callable, optional): Applied to every image in ``__getitem__``.
        """
        self.classes = classes
        self.root_dir = root_dir
        self.transform = transform

        # Empty placeholders so that len() works before any data has been set.
        self.data = np.empty((0, 28, 28, 1), dtype=np.uint8)
        self.targets = np.empty((0,), dtype=np.int64)

        if download_data:
            self.download_data()

        if load_data:
            self.data, self.targets = self._load_data()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """Return ``(image, label)``; the image goes through ``transform`` when one is set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img, target = self.data[idx], int(self.targets[idx])

        if self.transform:
            img = self.transform(img)

        return img, target

    def _file_path(self, name):
        """Return the local ``.npy`` path of class ``name`` inside ``root_dir``."""
        # os.path.join works whether or not root_dir ends with a path separator.
        return os.path.join(self.root_dir, f"{name}.npy")

    def download_data(self):
        """Download the bitmap ``.npy`` file of every class missing from ``root_dir``."""
        # `import urllib` alone does not guarantee that the `request` submodule is loaded.
        import urllib.request

        base_url = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/"
        for name in self.classes:
            file_name = self._file_path(name)
            # Some class names (e.g. "hot air balloon") contain spaces.
            url = base_url + name.replace(" ", "%20") + ".npy"

            if not os.path.isfile(file_name):
                print(url, "==>", file_name)
                urllib.request.urlretrieve(url, file_name)

    def _load_data(self):
        """Load every class file and return ``(images, labels)``.

        Images are reshaped to ``(N, 28, 28, 1)``; labels are the class indices
        in ``self.classes``.
        """
        raw_data = []
        for name in self.classes:
            # NOTE(review): allow_pickle=True lets a tampered download execute code
            # when loaded; plain numeric bitmap arrays do not need it.
            raw_data.append(
                np.load(self._file_path(name), fix_imports=True, allow_pickle=True)
            )
            print("%-15s" % name, type(raw_data[-1]))

        reshaped_data = np.concatenate(raw_data).reshape(-1, 28, 28, 1)
        reshaped_targets = np.concatenate(
            [np.full(d.shape[0], i) for i, d in enumerate(raw_data)]
        )

        return reshaped_data, reshaped_targets

    def _set_data(self, data, targets):
        self.data = data
        self.targets = targets

    def _subset(self, data, targets):
        """Return a new dataset with this one's configuration holding ``data``/``targets``."""
        subset = QuickDrawDataset(
            self.classes,
            self.root_dir,
            download_data=False,
            load_data=False,
            transform=self.transform,
        )
        subset._set_data(data, targets)
        return subset

    def split_train_test(self, test_size=0.2, random_state=12):
        """Split into stratified train/test datasets.

        Uses ``sklearn.model_selection.train_test_split`` with ``stratify`` set to
        the labels, so class proportions are kept in both parts.

        Arguments:
            test_size (float, optional): Fraction of samples placed in the test set.
            random_state (int, optional): Seed of the split, for reproducibility.

        Returns:
            ``(train_dataset, test_dataset)``, both sharing ``transform``.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            self.data,
            self.targets,
            test_size=test_size,
            random_state=random_state,
            stratify=self.targets,
        )

        return self._subset(X_train, y_train), self._subset(X_test, y_test)

def get_torch_optimizer(optimizer_name, model_params, lr):
    """Build the optimizer named ``optimizer_name`` ("Adam" or "SGD") over ``model_params``.

    Raises:
        ValueError: if ``optimizer_name`` is not a supported optimizer name.
    """
    optimizer_classes = {
        "Adam": torch.optim.Adam,
        "SGD": torch.optim.SGD,
    }
    optimizer_cls = optimizer_classes.get(optimizer_name)
    if optimizer_cls is None:
        raise ValueError(f"Unknown optimizer {optimizer_name}")
    return optimizer_cls(model_params, lr=lr)

def get_torch_loss(loss_name):
    """Return a new loss module for ``loss_name`` ("cross_entropy" or "mse").

    Raises:
        ValueError: if ``loss_name`` is not a supported loss name.
    """
    loss_classes = {
        "cross_entropy": torch.nn.CrossEntropyLoss,
        "mse": torch.nn.MSELoss,
    }
    if loss_name not in loss_classes:
        raise ValueError(f"Unknown loss {loss_name}")
    return loss_classes[loss_name]()

# from https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

def train(dataloader, model, loss_fn, optimizer, device, log_every=100):
    """Run one training epoch and return sampled ``(mean loss, mean accuracy)``.

    Loss and accuracy are recorded on every ``log_every``-th batch (starting with
    batch 0) and averaged at the end, so the result is a cheap estimate rather
    than an exact epoch metric.

    Args:
        dataloader: iterable of ``(inputs, targets)`` batches.
        model: module returning one score per class (``argmax(1)`` is the prediction).
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar tensor.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        log_every: sampling period, in batches, of the returned metrics.

    Returns:
        ``(mean_loss, mean_accuracy)`` as Python floats; both are ``nan`` when
        ``dataloader`` yields no batches.
    """
    model.train()

    loss_samples = []
    acc_samples = []
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % log_every == 0:
            loss_samples.append(loss.item())
            # Store plain floats (not device tensors) so np.mean also works on CUDA.
            acc_samples.append((pred.argmax(1) == y).float().mean().item())

    if not loss_samples:
        return float("nan"), float("nan")
    return float(np.mean(loss_samples)), float(np.mean(acc_samples))

def test(dataloader, model, loss_fn, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [4]:
# Load every class in class_names (downloading missing files first), then split
# 80/20 into stratified train / validation subsets.
all_dataset = QuickDrawDataset(
    class_names,
    data_folder,
    transform=transforms.ToTensor(),
    load_data=True,
    download_data=True,
)
train_dataset, val_dataset = all_dataset.split_train_test(test_size=0.2)

# The unsplit dataset is no longer needed: drop it and collect garbage to save RAM.
del all_dataset
gc.collect()

print(f"train_dataset: {len(train_dataset)} samples")
print(f"val_dataset: {len(val_dataset)} samples")

airplane        <class 'numpy.ndarray'>
banana          <class 'numpy.ndarray'>
cookie          <class 'numpy.ndarray'>
diamond         <class 'numpy.ndarray'>
dog             <class 'numpy.ndarray'>
hot air balloon <class 'numpy.ndarray'>
knife           <class 'numpy.ndarray'>
parachute       <class 'numpy.ndarray'>
scissors        <class 'numpy.ndarray'>
wine glass      <class 'numpy.ndarray'>
train_dataset: 1245340 samples
val_dataset: 311335 samples


In [5]:
# Hyper-parameters shared by the experiments; use the GPU when CUDA is available.
base_config = dict(
    batch_size=128,
    epochs=10,
    learning_rate=1e-3,
    device="cuda" if torch.cuda.is_available() else "cpu",
)
pprint.pprint(base_config)

{'batch_size': 128, 'device': 'cuda', 'epochs': 10, 'learning_rate': 0.001}


In [6]:
# Prepared dataloaders.
# The training loader reshuffles every epoch: without shuffling it would yield the same
# batches in the same order each epoch, and with limit_train_batches=0.1 every epoch
# would then only ever see the same first 10% of the training set.
train_dataloader = DataLoader(
    train_dataset, batch_size=base_config["batch_size"], shuffle=True
)
# Validation order does not affect the metrics, so it stays deterministic.
val_dataloader = DataLoader(val_dataset, batch_size=base_config["batch_size"])

## Model bazowy

In [7]:
class QuickDrawNetwork(pl.LightningModule):
    """Shared Lightning training/validation logic for the Quick, Draw! classifiers.

    Subclasses must set ``self.model`` (producing one score per class) and
    ``self.num_classes`` in their ``__init__``.
    """

    def forward(self, x):
        """Run ``self.model`` and return per-class log-probabilities."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def _shared_step(self, batch):
        """Return ``(NLL loss, accuracy)`` for one ``(inputs, targets)`` batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y, task="multiclass", num_classes=self.num_classes)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss)
        # Logged so that train and validation accuracy histories can be compared.
        self.log("train_acc", acc)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=1e-5)

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log('val_loss', loss)
        self.log('val_acc', acc)
        return loss

In [8]:
class QuickDrawNetwork_V1(QuickDrawNetwork):
    """Baseline CNN: three Conv-ReLU-MaxPool blocks followed by two Linear layers."""

    def __init__(self, dimensions, num_classes):
        """
        Arguments:
            dimensions (tuple): ``(channels, width, height)`` of an input image.
            num_classes (int): Number of output classes.
        """
        super().__init__()

        self.channels, self.width, self.height = dimensions
        self.num_classes = num_classes
        self.name = "Baseline"

        # For 28x28 inputs the spatial size goes 26 -> 13 -> 11 -> 5 -> 3 -> 1,
        # so the flattened feature vector has exactly 256 entries.
        conv_layers = []
        in_channels = self.channels
        for out_channels in (32, 64, 256):
            conv_layers.extend([nn.Conv2d(in_channels, out_channels, 3), nn.ReLU(), nn.MaxPool2d(2)])
            in_channels = out_channels

        # NOTE(review): there is no activation between the two Linear layers, so the
        # head is effectively one linear map; kept as-is so the baseline is unchanged.
        self.model = nn.Sequential(
            *conv_layers,
            nn.Flatten(),
            nn.Linear(256, 32),
            nn.Linear(32, self.num_classes),
        )

img_dimensions = (1, 28, 28)  # (channels, width, height) as unpacked by the models
model = QuickDrawNetwork_V1(img_dimensions, len(class_names))

trainable = filter(lambda param: param.requires_grad, model.parameters())
num_of_params = sum(param.numel() for param in trainable)
print('Number of parameters:', num_of_params)

# Assignment constraint: between 100k and 200k trainable parameters.
assert num_of_params > 100_000, "Za mało parametrów"
assert num_of_params < 200_000, "Za dużo parametrów"

Number of parameters: 175082


In [10]:
# Training of the baseline (no early stopping).
# 'medium' lets PyTorch trade float32 matmul precision for speed.
torch.set_float32_matmul_precision('medium')

model = QuickDrawNetwork_V1(img_dimensions, len(class_names))

tb_logger = TensorBoardLogger("lightning_logs", name=model.name)
wandb_logger = WandbLogger(project="deepdrive-modul-4", name=model.name)

trainer = pl.Trainer(
    max_epochs=base_config["epochs"],
    precision=32,
    accelerator="gpu",
    logger=[tb_logger, wandb_logger],
    limit_train_batches=0.1,  # train on only 10% of the training batches per epoch
)

trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# A WandbLogger created while a W&B run is still active reuses that run, so close it
# here; otherwise every later experiment would be logged into this "Baseline" run.
wandb_logger.experiment.finish()

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: lightning_logs\Baseline
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 175 K 
-------------------------------------
175 K     Trainable params
0         Non-trainable params
175 K     Total params
0.700     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


## 1 zmiana

Dodaję LR scheduler (OneCycleLR), żeby skrócić czas treningu i przyspieszyć kolejne iteracje zmian. W tym i w następnych modelach zastosowano również early stopping.

In [None]:
class QuickDrawNetwork_V2(QuickDrawNetwork):
    """Baseline architecture (same layers as V1) trained with a OneCycleLR schedule."""

    def __init__(self, dimensions, num_classes):
        """
        Arguments:
            dimensions (tuple): ``(channels, width, height)`` of an input image.
            num_classes (int): Number of output classes.
        """
        super().__init__()

        self.channels, self.width, self.height = dimensions
        self.num_classes = num_classes
        self.name = "OneCycleLR"

        self.model = nn.Sequential(
            nn.Conv2d(self.channels, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 256, 3),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(256, 32),
            nn.Linear(32, self.num_classes),
        )

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) driven by OneCycleLR peaking at 1e-3, stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        )
        # total_steps counts optimizer steps, so the scheduler must advance every batch.
        # Lightning's default interval is "epoch", which would step it only once per
        # epoch and leave the LR near its initial value (max_lr / div_factor).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

model = QuickDrawNetwork_V2(img_dimensions, len(class_names))

tb_logger = TensorBoardLogger("lightning_logs", name=model.name)
wandb_logger = WandbLogger(project="deepdrive-modul-4", name=model.name)

trainer = pl.Trainer(
    max_epochs=base_config["epochs"],
    precision=32,
    accelerator="gpu",
    logger=[tb_logger, wandb_logger],
    callbacks=[
        # Stop when val_acc has not improved by at least 0.01 for 4 validation checks.
        EarlyStopping(monitor="val_acc", min_delta=0.01, patience=4, verbose=False, mode="max")
    ],
    limit_train_batches=0.1,  # train on only 10% of the training batches per epoch
)

trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# A WandbLogger created while a W&B run is still active reuses that run, so close it
# here to give the next experiment its own run.
wandb_logger.experiment.finish()

### 2 zmiana

Dodaję regularyzację poprzez BatchNorm (nn.BatchNorm2d po każdym bloku konwolucyjnym), aby poprawić wynik po ok. 18 tys. kroków, kiedy zaczyna się overfitting.

In [None]:
class QuickDrawNetwork_V3(QuickDrawNetwork):
    """Conv-ReLU-BatchNorm-MaxPool blocks with a OneCycleLR schedule."""

    def __init__(self, dimensions, num_classes):
        """
        Arguments:
            dimensions (tuple): ``(channels, width, height)`` of an input image.
            num_classes (int): Number of output classes.
        """
        super().__init__()

        self.channels, self.width, self.height = dimensions
        self.num_classes = num_classes
        self.name = "OneCycleLR + BatchNorm"

        self.model = nn.Sequential(
            nn.Conv2d(self.channels, 32, 3),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 256, 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2),

            nn.Flatten(),
            nn.Linear(256, 32),
            nn.Linear(32, self.num_classes),
        )

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) driven by OneCycleLR peaking at 1e-3, stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        )
        # total_steps counts optimizer steps, so the scheduler must advance every batch.
        # Lightning's default interval is "epoch", which would step it only once per
        # epoch and leave the LR near its initial value (max_lr / div_factor).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

model = QuickDrawNetwork_V3(img_dimensions, len(class_names))

tb_logger = TensorBoardLogger("lightning_logs", name=model.name)
wandb_logger = WandbLogger(project="deepdrive-modul-4", name=model.name)

trainer = pl.Trainer(
    max_epochs=base_config["epochs"],
    precision=32,
    accelerator="gpu",
    logger=[tb_logger, wandb_logger],
    callbacks=[
        # Stop when val_acc has not improved by at least 0.01 for 4 validation checks.
        EarlyStopping(monitor="val_acc", min_delta=0.01, patience=4, verbose=False, mode="max")
    ],
    limit_train_batches=0.1,  # train on only 10% of the training batches per epoch
)

trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# A WandbLogger created while a W&B run is still active reuses that run, so close it
# here to give the next experiment its own run.
wandb_logger.experiment.finish()

### Trzecia zmiana

Zamiast BatchNorm sprawdzam regularyzację poprzez Dropout (nn.Dropout2d po każdym bloku konwolucyjnym).

In [None]:
class QuickDrawNetwork_V4(QuickDrawNetwork):
    """Conv-ReLU-MaxPool blocks each followed by Dropout2d, with a OneCycleLR schedule."""

    def __init__(self, dimensions, num_classes, dropout_ratio=(0.2, 0.2)):
        """
        Arguments:
            dimensions (tuple): ``(channels, width, height)`` of an input image.
            num_classes (int): Number of output classes.
            dropout_ratio (sequence of 2 floats, optional): Dropout2d probability
                after the first conv block (index 0) and after the second and
                third blocks (index 1).
        """
        super().__init__()

        self.channels, self.width, self.height = dimensions
        self.num_classes = num_classes
        self.name = "OneCycleLR + Dropout"

        self.model = nn.Sequential(
            nn.Conv2d(self.channels, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=dropout_ratio[0]),

            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=dropout_ratio[1]),

            nn.Conv2d(64, 256, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=dropout_ratio[1]),

            nn.Flatten(),
            nn.Linear(256, 32),
            nn.Linear(32, self.num_classes),
        )

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) driven by OneCycleLR peaking at 1e-3, stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        )
        # total_steps counts optimizer steps, so the scheduler must advance every batch.
        # Lightning's default interval is "epoch", which would step it only once per
        # epoch and leave the LR near its initial value (max_lr / div_factor).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

model = QuickDrawNetwork_V4(img_dimensions, len(class_names))

tb_logger = TensorBoardLogger("lightning_logs", name=model.name)
wandb_logger = WandbLogger(project="deepdrive-modul-4", name=model.name)

trainer = pl.Trainer(
    max_epochs=base_config["epochs"],
    precision=32,
    accelerator="gpu",
    logger=[tb_logger, wandb_logger],
    callbacks=[
        # Stop when val_acc has not improved by at least 0.01 for 4 validation checks.
        EarlyStopping(monitor="val_acc", min_delta=0.01, patience=4, verbose=False, mode="max")
    ],
    limit_train_batches=0.1,  # train on only 10% of the training batches per epoch
)

trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# A WandbLogger created while a W&B run is still active reuses that run, so close it
# here to give the next experiment its own run.
wandb_logger.experiment.finish()

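### Trzecia zmiana (cd.) - dobór współczynnika Dropout

Sprawdzam, przy jakich wartościach drop ratio Dropout pomaga najbardziej: najpierw z tą samą wartością dla wszystkich bloków, potem z inną wartością dla pierwszego bloku i dla kolejnych.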

In [None]:
class QuickDrawNetwork_V5(QuickDrawNetwork):
    """Dropout2d model whose run name encodes its dropout ratios (used for the sweep)."""

    def __init__(self, dimensions, num_classes, dropout_ratio=(0.2, 0.2)):
        """
        Arguments:
            dimensions (tuple): ``(channels, width, height)`` of an input image.
            num_classes (int): Number of output classes.
            dropout_ratio (sequence of 2 floats, optional): Dropout2d probability
                after the first conv block (index 0) and after the second and
                third blocks (index 1). A tuple default avoids a shared mutable list.
        """
        super().__init__()

        self.channels, self.width, self.height = dimensions
        self.num_classes = num_classes

        ratio_string = "_".join(str(x) for x in dropout_ratio)
        self.name = f"OneCycleLR + Dropout ({ratio_string})"

        self.model = nn.Sequential(
            nn.Conv2d(self.channels, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=dropout_ratio[0]),

            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=dropout_ratio[1]),

            nn.Conv2d(64, 256, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(p=dropout_ratio[1]),

            nn.Flatten(),
            nn.Linear(256, 32),
            nn.Linear(32, self.num_classes),
        )

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) driven by OneCycleLR peaking at 1e-3, stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        )
        # total_steps counts optimizer steps, so the scheduler must advance every batch.
        # Lightning's default interval is "epoch", which would step it only once per
        # epoch and leave the LR near its initial value (max_lr / div_factor).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

# Dropout sweep, each entry is [first block, later blocks]: first the same ratio
# everywhere, then a different ratio for the first block.
dropout_ratios_same = [[0.2, 0.2], [0.3, 0.3], [0.5, 0.5]]
dropout_ratios_different = [[0.2, 0.5], [0.3, 0.7], [0.5, 0.2]]


for ratio in dropout_ratios_same + dropout_ratios_different:
    # Build the parameterised model with the ratio under test; its name encodes the
    # ratio, so every run is logged separately.
    model = QuickDrawNetwork_V5(img_dimensions, len(class_names), dropout_ratio=ratio)

    tb_logger = TensorBoardLogger("lightning_logs", name=model.name)
    wandb_logger = WandbLogger(project="deepdrive-modul-4", name=model.name)

    trainer = pl.Trainer(
        max_epochs=base_config["epochs"],
        precision=32,
        accelerator="gpu",
        logger=[tb_logger, wandb_logger],
        callbacks=[
            EarlyStopping(monitor="val_acc", min_delta=0.01, patience=4, verbose=False, mode="max")
        ],
        limit_train_batches=0.1,  # train on only 10% of the training batches per epoch
    )

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    # A WandbLogger created while a W&B run is still active reuses that run, so close
    # it before the next iteration creates a new logger.
    wandb_logger.experiment.finish()

Wniosek: Z wybranych opcji, dropout najlepiej zadziałał usuwając 50% map cech z pierwszej warstwy i 20% z warstw późniejszych.

### 4 zmiana - GlobalAvgPool vs GlobalMaxPool

Porównujemy nn.AdaptiveMaxPool2d(1) z nn.AdaptiveAvgPool2d(1), w obu przypadkach z BatchNorm2d i bez niego.

In [None]:
class QuickDrawNetwork_V6(QuickDrawNetwork):
    """Conv stack ending in global Avg/Max pooling, optional BatchNorm2d, OneCycleLR."""

    def __init__(self, dimensions, num_classes, global_pooling, batch_norm=False):
        """
        Arguments:
            dimensions (tuple): ``(channels, width, height)`` of an input image.
            num_classes (int): Number of output classes.
            global_pooling (str): ``"Avg"`` for AdaptiveAvgPool2d(1) or ``"Max"``
                for AdaptiveMaxPool2d(1).
            batch_norm (bool, optional): Add BatchNorm2d after every conv block.

        Raises:
            ValueError: if ``global_pooling`` is neither ``"Avg"`` nor ``"Max"``.
        """
        super().__init__()

        self.channels, self.width, self.height = dimensions
        self.num_classes = num_classes

        batch_norm_part = " + BatchNorm2d" if batch_norm else ""
        self.name = f"OneCycleLR + Global{global_pooling}Pool2d" + batch_norm_part

        global_pooling_layers = {
            "Avg": nn.AdaptiveAvgPool2d,
            "Max": nn.AdaptiveMaxPool2d,
        }
        # Global pooling is required: the head expects exactly 256 features, while
        # without it a 28x28 input leaves a 256 x 11 x 11 map here.
        if global_pooling not in global_pooling_layers:
            raise ValueError(
                f"Unknown global_pooling {global_pooling!r}, expected 'Avg' or 'Max'"
            )
        pooling_layer = global_pooling_layers[global_pooling]

        # nn.Identity ignores its constructor arguments, so it can stand in for
        # BatchNorm2d(num_features); nn.Sequential only accepts nn.Module instances,
        # so a plain lambda cannot be used here.
        batch_norm_layer = nn.BatchNorm2d if batch_norm else nn.Identity

        self.model = nn.Sequential(
            nn.Conv2d(self.channels, 32, 3),
            nn.ReLU(),
            batch_norm_layer(32),

            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            batch_norm_layer(64),

            nn.Conv2d(64, 256, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            batch_norm_layer(256),

            pooling_layer(1),

            nn.Flatten(),
            nn.Linear(256, 32),
            nn.Linear(32, self.num_classes),
        )

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) driven by OneCycleLR peaking at 1e-3, stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        )
        # total_steps counts optimizer steps, so the scheduler must advance every batch.
        # Lightning's default interval is "epoch", which would step it only once per
        # epoch and leave the LR near its initial value (max_lr / div_factor).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

from itertools import product

# Sweep: global Avg vs Max pooling, each with and without BatchNorm2d.
params = {
    "batch_norm": [True, False],
    "global_pooling": ["Avg", "Max"],
}

for global_pooling, batch_norm in product(params["global_pooling"], params["batch_norm"]):
    model = QuickDrawNetwork_V6(img_dimensions, len(class_names), global_pooling, batch_norm)

    tb_logger = TensorBoardLogger("lightning_logs", name=model.name)
    wandb_logger = WandbLogger(project="deepdrive-modul-4", name=model.name)

    trainer = pl.Trainer(
        max_epochs=base_config["epochs"],
        precision=32,
        accelerator="gpu",
        logger=[tb_logger, wandb_logger],
        callbacks=[
            EarlyStopping(monitor="val_acc", min_delta=0.01, patience=4, verbose=False, mode="max")
        ],
        limit_train_batches=0.1,  # train on only 10% of the training batches per epoch
    )

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

    # A WandbLogger created while a W&B run is still active reuses that run, so close
    # it before the next iteration creates a new logger.
    wandb_logger.experiment.finish()

Wniosek: GlobalAvgPool osiąga nieznacznie wyższy wynik

In [None]:
class QuickDrawNetwork_V7(QuickDrawNetwork):
    """Final model combining the options chosen above.

    Conv-ReLU-BatchNorm blocks, Dropout2d of 0.5 after the first block and 0.2
    after the second, global average pooling, and a per-batch OneCycleLR schedule.
    Named V7 so it no longer shadows the V6 class used in the pooling sweep.
    """

    def __init__(self, dimensions, num_classes):
        """
        Arguments:
            dimensions (tuple): ``(channels, width, height)`` of an input image.
            num_classes (int): Number of output classes.
        """
        super().__init__()

        self.channels, self.width, self.height = dimensions
        self.num_classes = num_classes

        self.name = "final_model"

        self.model = nn.Sequential(
            nn.Conv2d(self.channels, 32, 3),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.5),

            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2),
            nn.Dropout2d(0.2),

            nn.Conv2d(64, 256, 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.AdaptiveAvgPool2d(1),

            nn.Flatten(),
            nn.Linear(256, 32),
            nn.Linear(32, self.num_classes),
        )

    def configure_optimizers(self):
        """Adam (weight decay 1e-5) driven by OneCycleLR peaking at 1e-3, stepped per batch."""
        optimizer = torch.optim.Adam(self.parameters(), weight_decay=1e-5)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
        )
        # total_steps counts optimizer steps, so the scheduler must advance every batch.
        # Lightning's default interval is "epoch", which would step it only once per
        # epoch and leave the LR near its initial value (max_lr / div_factor).
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }


# Final training run: unlike the experiments above, it uses the full training set
# (no limit_train_batches).
model = QuickDrawNetwork_V7(img_dimensions, len(class_names))

tb_logger = TensorBoardLogger("lightning_logs", name=model.name)
wandb_logger = WandbLogger(project="deepdrive-modul-4", name=model.name)

trainer = pl.Trainer(
    max_epochs=base_config["epochs"],
    precision=32,
    accelerator="gpu",
    logger=[tb_logger, wandb_logger],
    callbacks=[
        # Stop when val_acc has not improved by at least 0.01 for 4 validation checks.
        EarlyStopping(monitor="val_acc", min_delta=0.01, patience=4, verbose=False, mode="max")
    ],
)

trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

# A WandbLogger created while a W&B run is still active reuses that run, so close it
# to keep any later run separate.
wandb_logger.experiment.finish()