## Set up paths and imports

In [1]:
import os

import torch
from torchvision import transforms

# If ./notebooks is not visible from the current directory, move one level up
# (presumably to the project root, so that the `src` imports below resolve).
if not os.path.exists("./notebooks"):
    %cd ..

import src.model
from src.training import train, validate
from src.dataset import prepare_dataset_loaders
from src.data_processing import load_mean_std
from src.config import DATASET_DIR, PATIENCE_THRESHOLD

# W&B logging is off by default; the optional W&B cell further down sets this to True.
wandb_enabled = False

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/mytkom/Documents/iml


## 1. Load standardization data and define Config

In [2]:
# Scaling parameters read from the dataset directory; passed to transforms.Normalize
# in the experiment cells below.
mean, std = load_mean_std(f"{DATASET_DIR}/scaling_params.json")

class Config:
    """Hyperparameters for a single training run.

    Attributes:
        learning_rate: optimizer step size (constructor argument ``lr``).
        epochs: maximum number of training epochs.
        batch_size: mini-batch size handed to the dataset loaders.

    The instance dictionary (``vars(config)``) is what ``do_train`` forwards
    to ``wandb.init`` as the run configuration, so only hyperparameters
    should be stored on the instance.
    """

    def __init__(self, lr=0.001, epochs=40, batch_size=32):
        # Single tuple assignment; attributes are still created left to right.
        self.learning_rate, self.epochs, self.batch_size = lr, epochs, batch_size

### Optionally initialize W&B project

In [3]:
# Optional cell: run it to make do_train report to Weights & Biases.
# Skip it to keep wandb_enabled at its default (False) and log to stdout instead.
import wandb

wandb_enabled = True

## 2. Define training and validation loop

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def do_train(name, train_loader, val_loader, config, model, criterion, optimizer):
    """Train ``model`` with early stopping on validation F1 and checkpoint the best weights.

    After every epoch the model is validated. Whenever the validation F1 is at
    least as good as the best value seen so far, the weights are written to
    ``./models/{name}.pth``. Training stops once F1 has been below the best
    value for ``PATIENCE_THRESHOLD`` consecutive epochs, so the file on disk
    always holds the best-F1 weights rather than the last (stale) ones.

    Args:
        name: run name; used as the W&B run name and as the checkpoint file stem.
        train_loader: loader passed to ``train`` (its ``len`` is also used).
        val_loader: loader passed to ``validate``.
        config: object exposing ``epochs``; ``vars(config)`` is sent to W&B.
        model: module to train; moved to the notebook-level ``device``.
        criterion: loss function forwarded to ``train``.
        optimizer: optimizer forwarded to ``train``.

    Side effects:
        Creates ``./models`` if needed, writes the checkpoint file, prints
        progress, and (when ``wandb_enabled`` is true) logs to and finishes a
        W&B run.
    """
    if wandb_enabled:
        wandb.init(name=name, project="iml", config=vars(config))
        logger = wandb.log
    else:
        logger = lambda data, step: print(f"  Step {step}: {data}")

    model.device = device
    model.to(device)

    model_path = f"./models/{name}.pth"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    saved = False
    patience = 0
    best_f1 = -1

    for epoch in range(config.epochs):
        print(f"Epoch {epoch+1}/{config.epochs}")

        # NOTE(review): the last argument is presumably a logging interval in
        # batches; it is <= 0 for loaders with fewer than 10 batches -- confirm
        # how src.training.train handles that.
        train(model, train_loader, criterion, optimizer, epoch, logger, len(train_loader) // 5 - 1)
        metrics = validate(model, val_loader)
        print(metrics)

        if wandb_enabled:
            wandb.log({
                "validation/recall": metrics.recall,
                "validation/accuracy": metrics.accuracy,
                "validation/precision": metrics.precision,
                "validation/f1": metrics.f1,
                "epoch": epoch + 1,
            })

        if metrics.f1 < best_f1:
            patience += 1
            if patience >= PATIENCE_THRESHOLD:
                # Patience exhausted: stop instead of continuing to train.
                print(f"Early stopping: validation F1 below best for {patience} epochs")
                break
        else:
            # New best (or tie): reset patience and checkpoint these weights.
            patience = 0
            best_f1 = metrics.f1
            torch.save(model.state_dict(), model_path)
            saved = True

    if not saved:
        # No epoch produced a checkpoint (e.g. config.epochs == 0); save as-is.
        torch.save(model.state_dict(), model_path)

    if wandb_enabled:
        wandb.save(model_path)
        wandb.finish()


    


In [None]:
# Baseline experiment: TutorialCNN on 32x32 inputs with no Normalize step.
name = "TutorialCNN without standardization"
model = src.model.TutorialCNN()
config = Config()

# Resize and convert to a tensor only; mean/std scaling is intentionally left out.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ]
)
train_loader, val_loader, test_loader = prepare_dataset_loaders(
    transform,
    config.batch_size,
)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

do_train(
    name=name,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
)

In [None]:
# TutorialCNN on 32x32 inputs, this time normalized with the loaded mean/std.
name = "TutorialCNN"
model = src.model.TutorialCNN()
config = Config()

# Resize, convert to a tensor, then normalize with the dataset scaling parameters.
transform = transforms.Compose(
    [
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
train_loader, val_loader, test_loader = prepare_dataset_loaders(
    transform,
    config.batch_size,
)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

do_train(
    name=name,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
)

In [5]:
# OriginalSizeCNN: no Resize step, inputs are only converted and normalized.
name = "OriginalSizeCNN"
model = src.model.OriginalSizeCNN()
config = Config()

# Convert to a tensor and normalize with the dataset scaling parameters.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)
train_loader, val_loader, test_loader = prepare_dataset_loaders(
    transform,
    config.batch_size,
)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

do_train(
    name=name,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mmytkom[0m ([33mmytkom-warsaw-university-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/40
Metrics:
    F1: 0.91,
    Accuracy: 0.88,
    Recall: 0.94,
    Precision: 0.89,
    False acceptance: 0.25,
    False rejection: 0.06
Epoch 2/40
Metrics:
    F1: 0.91,
    Accuracy: 0.88,
    Recall: 0.88,
    Precision: 0.94,
    False acceptance: 0.12,
    False rejection: 0.12
Epoch 3/40
Metrics:
    F1: 0.93,
    Accuracy: 0.91,
    Recall: 0.95,
    Precision: 0.92,
    False acceptance: 0.19,
    False rejection: 0.05
Epoch 4/40
Metrics:
    F1: 0.93,
    Accuracy: 0.90,
    Recall: 0.95,
    Precision: 0.91,
    False acceptance: 0.20,
    False rejection: 0.05
Epoch 5/40
Metrics:
    F1: 0.94,
    Accuracy: 0.92,
    Recall: 0.94,
    Precision: 0.94,
    False acceptance: 0.14,
    False rejection: 0.06
Epoch 6/40
Metrics:
    F1: 0.94,
    Accuracy: 0.92,
    Recall: 0.95,
    Precision: 0.93,
    False acceptance: 0.16,
    False rejection: 0.05
Epoch 7/40
Metrics:
    F1: 0.94,
    Accuracy: 0.92,
    Recall: 0.96,
    Precision: 0.93,
    False acceptance: 0.1

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/accuracy,▁▄▇▇▇▇█▇████████████████████████████████
train/loss,█▇▆▃▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▁▁▁
validation/accuracy,▁▁▄▄▆▆▆▃▅▆▆▆▇▇▅▇▄▅▇▅▆▆▆▇▇▇▇████▆▅▇▇▇▇▅█▆
validation/f1,▂▁▅▄▆▆▆▂▆▇▆▆▇▇▅▇▅▅▇▅▆▆▇▇▇▇▇████▆▅▇▇█▇▅█▆
validation/precision,▁▆▃▃▅▅▄█▃▅▅▇▄▆▇▅▁▆▄█▆▇▃▆▄▆▅▅▆▆▆▅▇▅▅▅▅▆▆▆
validation/recall,▅▁▆▆▅▆▆▁▇▆▆▅▇▆▄▆█▄▇▃▆▅▇▅▇▆▆▇▆▆▆▅▄▇▆▇▆▄▆▅

0,1
epoch,40.0
train/accuracy,0.99769
train/loss,0.00799
validation/accuracy,0.91924
validation/f1,0.94114
validation/precision,0.94294
validation/recall,0.93935
