## Set up paths and imports

In [None]:
import os

import torch
from torchvision import transforms

# If ./notebooks is not visible from the current working directory, step up one
# level with the IPython %cd magic. Presumably the notebook was started from
# inside notebooks/ and the `src` imports below need the project root — verify
# against the repository layout.
if not os.path.exists("./notebooks"):
    %cd ..

import src.model
from src.training import do_train, do_test
from src.dataset import prepare_dataset_loaders
from src.data_processing import load_mean_std
from src.config import DATASET_DIR

# W&B switch forwarded to do_train/do_test; off by default, the optional cell
# further down sets it to True.
wandb_enabled = False

## 1. Load standardization data and define Config

In [None]:
# Read the scaling parameters stored next to the dataset; `mean` and `std` are
# later fed to transforms.Normalize in the standardized experiments.
mean, std = load_mean_std(f"{DATASET_DIR}/scaling_params.json")

class Config:
    """Bundle of training hyperparameters shared by the experiment cells.

    Args:
        lr: Optimizer learning rate, exposed as ``learning_rate``.
        epochs: Number of epochs, exposed as ``epochs``.
        batch_size: Loader batch size, exposed as ``batch_size``.
    """

    def __init__(self, lr=0.001, epochs=40, batch_size=32):
        # Note the rename: the ``lr`` argument is stored under ``learning_rate``.
        self.learning_rate, self.epochs, self.batch_size = lr, epochs, batch_size

### Optionally initialize W&B project

In [None]:
# Optional: run this cell to turn the W&B switch on for the experiment cells
# below; skip it to keep the default (False) from the setup cell.
wandb_enabled = True

## 2. Choose device

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## 3. Choose your architecture

In [None]:
# Baseline experiment: TutorialCNN on 32x32 inputs, without mean/std normalization.
name = "TutorialCNN without standardization"
model = src.model.TutorialCNN()
config = Config()  # default hyperparameters
# Resize to 32x32 and convert to a tensor; no Normalize step in this variant.
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor()
])
train_loader, val_loader, test_loader = prepare_dataset_loaders(transform, config.batch_size)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Train, then evaluate on the test loader. do_test receives the model class
# (not the trained instance) together with the `run` returned by do_train —
# presumably it rebuilds the model from saved weights; confirm in src.training.
run = do_train(name, train_loader, val_loader, config, model, criterion, optimizer, device, wandb_enabled)
do_test(name, test_loader, model.__class__, run, device, wandb_enabled)

In [None]:
# TutorialCNN experiment on 32x32 inputs, this time normalized with the dataset mean/std.
name = "TutorialCNN"
model = src.model.TutorialCNN()
config = Config()  # default hyperparameters
# Resize to 32x32, convert to a tensor, then normalize with the loaded mean/std.
transform = transforms.Compose([
    transforms.Resize((32,32)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
train_loader, val_loader, test_loader = prepare_dataset_loaders(transform, config.batch_size)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Train, then evaluate on the test loader. do_test receives the model class
# (not the trained instance) together with the `run` returned by do_train —
# presumably it rebuilds the model from saved weights; confirm in src.training.
run = do_train(name, train_loader, val_loader, config, model, criterion, optimizer, device, wandb_enabled)
do_test(name, test_loader, model.__class__, run, device, wandb_enabled)

In [None]:
# OriginalSizeCNN experiment: no Resize step, so images keep the size the dataset provides.
name = "OriginalSizeCNN"
model = src.model.OriginalSizeCNN()
config = Config()  # default hyperparameters
# Convert to a tensor and normalize with the loaded mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
train_loader, val_loader, test_loader = prepare_dataset_loaders(transform, config.batch_size)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Train, then evaluate on the test loader. do_test receives the model class
# (not the trained instance) together with the `run` returned by do_train —
# presumably it rebuilds the model from saved weights; confirm in src.training.
run = do_train(name, train_loader, val_loader, config, model, criterion, optimizer, device, wandb_enabled)
do_test(name, test_loader, model.__class__, run, device, wandb_enabled)

In [None]:
# DropoutCNN experiment: original-size inputs normalized with the dataset mean/std.
name = "DropoutCNN"
model = src.model.DropoutCNN()
config = Config()  # default hyperparameters
# Convert to a tensor and normalize with the loaded mean/std (no Resize step).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
train_loader, val_loader, test_loader = prepare_dataset_loaders(transform, config.batch_size)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Forward the selected device and the W&B switch explicitly, exactly like the
# other experiment cells, so this run honours the notebook-level settings.
do_train(name, train_loader, val_loader, config, model, criterion, optimizer, device, wandb_enabled)

In [None]:
# Checkpoints (under ./models) of the pre-trained OriginalSizeCNN variants that
# make up the ensemble.
ensemble_model_names = ["OriginalSizeCNN-HE-RELU", "OriginalSizeCNN-UNIFORM-RELU", "OriginalSizeCNN-XAVIER-RELU"]

ensemble_models = []
for model_name in ensemble_model_names:
    # A dedicated loop name avoids shadowing the ensemble `model` built below.
    member = src.model.OriginalSizeCNN()
    # map_location lets checkpoints saved on a GPU load on CPU-only machines.
    state_dict = torch.load(f"./models/{model_name}.pth", map_location=device, weights_only=True)
    member.load_state_dict(state_dict)
    member.device = device
    member.to(device)
    ensemble_models.append(member)

name = "EnsembleCNN"
# NOTE(review): the meaning of the second argument (2) is defined by
# src.model.EnsembleCNN — confirm there before changing it.
model = src.model.EnsembleCNN(ensemble_models, 2)
config = Config()  # default hyperparameters
# Convert to a tensor and normalize with the loaded mean/std (no Resize step).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
train_loader, val_loader, test_loader = prepare_dataset_loaders(transform, config.batch_size)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Forward the selected device and the W&B switch explicitly, exactly like the
# other experiment cells, so this run honours the notebook-level settings.
do_train(name, train_loader, val_loader, config, model, criterion, optimizer, device, wandb_enabled)

### Comparison of models
A comparison of the architectures trainable with this notebook can be seen [here](https://wandb.ai/mytkom-warsaw-university-of-technology/iml/reports/Comparison-of-from-scratch-architectures--VmlldzoxMDU0MDk4NQ?accessToken=mle3zdqu8bxvrc4z8pdhl89talltdlml5gw5zmictx9e0qhvue0k5awsdggr37vp)