## Set up paths and imports

In [None]:
import os

from PIL import Image

import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# The project imports below are resolved from the current working directory.
# When no "./notebooks" folder is visible from here, step up one level
# (presumably the notebook was started from inside notebooks/ - confirm).
if not os.path.exists("./notebooks"):
    %cd ..

import src.model
from src.training import train, validate
from src.config import VALID_ACCESS_LABELS, TRAIN_DIR, TEST_DIR, VAL_DIR

# Off by default; the optional W&B cell sets this to True, and the training
# and saving cells check it before calling into wandb.
wandb_enabled = False

## 1. Configure training

In [None]:
from dataclasses import dataclass


@dataclass
class Config:
    """Hyperparameters for one training run.

    Declared as a dataclass so that every field is stored on the instance:
    ``vars(config)`` then yields the complete hyperparameter dict, whereas a
    plain class holding only class-level attributes yields ``{}``.
    Individual values can be overridden per run, e.g. ``Config(epochs=10)``.
    """

    learning_rate: float = 0.005  # learning rate handed to the optimizer
    epochs: int = 40  # number of iterations of the training loop
    batch_size: int = 32  # batch size shared by all data loaders
    image_size: tuple = (32, 32)  # TODO: choose best image_size


config = Config()

### Optionally initialize W&B project

In [None]:
import wandb

# Start a W&B run in the "iml" project, using whatever vars() finds on the
# config object as the run configuration.
# NOTE(review): vars() only returns instance-level attributes - confirm that
# Config stores its hyperparameters on the instance, otherwise this dict is empty.
wandb.init(project="iml", config=vars(config))
# Tell the later cells to send metrics and the saved model to W&B.
wandb_enabled = True

## 2. Define a custom dataset class to load spectrograms from files

In [None]:
class SpectrogramDataset(Dataset):
    """Dataset of grayscale spectrogram PNGs with a binary access label.

    The speaker id is taken from the part of the file name before the first
    underscore. A sample is labelled 1 when that id is in
    ``VALID_ACCESS_LABELS`` and 0 otherwise.

    Args:
        directory: Folder scanned (non-recursively) for ``.png`` files.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, directory, transform=None):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        self.files = sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith('.png')
        )
        self.transform = transform

    def __len__(self):
        """Return the number of spectrogram files found in the directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the file at position ``idx``."""
        img_path = self.files[idx]
        # basename honours the platform's path separator, unlike split('/').
        speaker_id = os.path.basename(img_path).split('_')[0]
        label = int(speaker_id in VALID_ACCESS_LABELS)

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed immediately instead of lingering until garbage collection.
        with Image.open(img_path) as source:
            image = source.convert("L")

        if self.transform:
            image = self.transform(image)

        return image, label


## 3. Set up data transformations and data loaders

In [None]:
# Preprocessing applied to every spectrogram: resize to the configured image
# size, then convert the PIL image to a tensor.
preprocessing_steps = [
    transforms.Resize(config.image_size),
    transforms.ToTensor(),
    # TODO: normalization
]
transform = transforms.Compose(preprocessing_steps)

# One dataset per split; all three share the same preprocessing.
train_dataset = SpectrogramDataset(directory=TRAIN_DIR, transform=transform)
val_dataset = SpectrogramDataset(directory=VAL_DIR, transform=transform)
test_dataset = SpectrogramDataset(directory=TEST_DIR, transform=transform)

# Only the training loader shuffles; validation and test keep dataset order.
loader_kwargs = {"batch_size": config.batch_size}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)


## 4. Initialize model, optimizer and set device

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# model = src.model.SimpleCNN().to(device)
model = src.model.TutorialCNN().to(device)  # - for 32x32 images
# Stored on the model itself; presumably read by train()/validate() - confirm.
model.device = device

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate)


## 5. Training and validation loop

In [None]:

def _print_logger(data, step):
    """Fallback logger: print the logged values to stdout when W&B is off."""
    print(f"  Step {step}: {data}")


# The logging backend cannot change while the loop runs, so choose it once
# instead of on every epoch. wandb.log is only looked up when W&B is enabled.
logger = wandb.log if wandb_enabled else _print_logger

for epoch in range(config.epochs):
    print(f"Epoch {epoch+1}/{config.epochs}")

    # NOTE(review): the last argument is 0 or negative when the loader has
    # fewer than 10 batches - confirm train() tolerates that.
    train(model, train_loader, criterion, optimizer, epoch, logger, len(train_loader) // 5 - 1)
    # 0 for recall and precision in first few epochs is expected (case when one class wasn't predicted yet)
    metrics = validate(model, val_loader)
    print(metrics)

    if wandb_enabled:
        wandb.log({
            "validation/recall": metrics.recall,
            "validation/precision": metrics.precision,
            "validation/f1": metrics.f1,
            "epoch": epoch+1,
        })


## 6. Save the model

In [None]:

# Destination of the serialized weights; the parent folder is created on demand.
# NOTE(review): the file name says simple_cnn while the notebook instantiates
# TutorialCNN - confirm the name is intentional.
model_path = "./models/simple_cnn.pth"
models_dir = os.path.dirname(model_path)
os.makedirs(models_dir, exist_ok=True)

# Only the state dict is written, not the whole module object.
weights = model.state_dict()
torch.save(weights, model_path)

# With W&B enabled, hand the saved file to the run and then close the run.
if wandb_enabled:
    wandb.save(model_path)
    wandb.finish()

print("Training complete and model saved!")
