## Set up paths and imports

In [None]:
import os

import torch
from torchvision import transforms

# If ./notebooks is not visible from the current directory, move up one level
# with the IPython %cd magic. Presumably the kernel was started inside the
# notebooks/ directory, and this makes the project-local `src` package
# imported below resolvable — confirm against the repo layout.
if not os.path.exists("./notebooks"):
    %cd ..

# Project-local modules; these imports must come after the directory change above.
import src.model
from src.training import train, validate
from src.dataset import prepare_dataset_loaders

# Off by default; the optional W&B cell sets it to True. Later cells check it
# before calling wandb.log / wandb.save / wandb.finish.
wandb_enabled = False

## 1. Configure training

In [None]:
from dataclasses import dataclass


@dataclass
class Config:
    """Training hyperparameters used throughout this notebook.

    Declared as a dataclass so every value is stored on the instance.
    ``vars(config)`` (passed to ``wandb.init`` as the run config) therefore
    returns the hyperparameters. With a plain class and class-level
    attributes, ``vars(config)`` would be an empty dict. Individual values can
    also be overridden per run, e.g. ``Config(epochs=5)``.
    """

    learning_rate: float = 0.005  # passed as lr to the optimizer
    epochs: int = 40  # number of iterations of the training loop
    batch_size: int = 32  # passed to prepare_dataset_loaders
    image_size: tuple[int, int] = (32, 32)  # target size for transforms.Resize; TODO: choose best image_size


config = Config()

### (Optional) Start a W&B run — skip this cell to train without W&B logging

In [None]:
import wandb

# Start a W&B run in the "iml" project, passing vars(config) as the run config.
# NOTE: vars() returns only attributes stored on the instance itself;
# attributes defined solely at class level do not appear in it.
wandb.init(config=vars(config), project="iml")

# Later cells check this flag before calling wandb.log / wandb.save / wandb.finish.
wandb_enabled = True

## 2. Set up data transformations and data loaders

In [None]:
# Per-image preprocessing: resize to config.image_size, then convert to a tensor.
# TODO: normalization
preprocessing = [
    transforms.Resize(config.image_size),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing)

# Build the three loaders from the preprocessing above and the configured
# batch size; the result is unpacked as train / validation / test loaders.
train_loader, val_loader, test_loader = prepare_dataset_loaders(
    transform, config.batch_size
)

## 3. Set the device and initialize the model, loss function and optimizer

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Alternative architecture: model = src.model.SimpleCNN().to(device)
model = src.model.TutorialCNN()  # TutorialCNN is the variant meant for 32x32 images
model = model.to(device)
# Attach the device to the model object. Presumably train()/validate() read it
# to move batches onto the same device — confirm in src.training.
model.device = device

# Cross-entropy loss, optimized with Adam at the configured learning rate.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)


## 4. Training and validation loop

In [None]:

# Train for config.epochs epochs, validating on val_loader after each one.

# Choose the logging sink once, before the loop, since it does not change
# between epochs. train() receives wandb.log when W&B is enabled; otherwise it
# receives a function that prints each record.
if wandb_enabled:
    logger = wandb.log
else:
    def logger(data, step):
        """Print a log record; stands in for wandb.log when W&B is disabled."""
        print(f"  Step {step}: {data}")

# Last positional argument of train(): about a fifth of the epoch's batch
# count, minus one. It is presumably how often train() logs — confirm in
# src.training. The loader length does not change between epochs, so it is
# computed once here.
# NOTE(review): this is -1 for loaders with fewer than 5 batches and 0 for
# 5-9 batches; check that train() handles such values.
log_interval = len(train_loader) // 5 - 1

for epoch in range(config.epochs):
    print(f"Epoch {epoch + 1}/{config.epochs}")

    train(model, train_loader, criterion, optimizer, epoch, logger, log_interval)

    # Recall/precision of 0 in the first few epochs is expected: it happens
    # while some class has not been predicted yet.
    metrics = validate(model, val_loader)
    print(metrics)

    if wandb_enabled:
        wandb.log({
            "validation/recall": metrics.recall,
            "validation/precision": metrics.precision,
            "validation/f1": metrics.f1,
            "epoch": epoch + 1,
        })


## 5. Check model performance on test data

In [None]:
# TODO

## 6. Save the model

In [None]:

# Persist the model's state_dict (its parameters and buffers), not the module object.
# NOTE(review): the active model is TutorialCNN, but the file is named
# simple_cnn.pth — confirm this is the intended checkpoint name.
model_path = "./models/simple_cnn.pth"
model_dir = os.path.dirname(model_path)
os.makedirs(model_dir, exist_ok=True)
torch.save(model.state_dict(), model_path)

if wandb_enabled:
    # Upload the checkpoint to the current W&B run, then close the run.
    wandb.save(model_path)
    wandb.finish()

print("Training complete and model saved!")
