## Set up paths and imports

In [None]:
import os

import torch
import torch.nn as nn
from torchvision import transforms

# When the current working directory has no ./notebooks folder, the notebook is
# presumably being run from inside notebooks/ itself — step up one level so the
# `src` package imported below resolves. (`%cd` is IPython magic, not plain Python.)
if not os.path.exists("./notebooks"):
    %cd ..

from src.training import train, validate
from src.dataset import prepare_dataset_loaders
from src.audio_dataset_processor import AudioDatasetProcessor
from src.config import VALID_ACCESS_LABELS, DATA_DIR, PATIENCE_THRESHOLD

# W&B logging is off by default; the optional W&B cell further down sets this to True.
# do_train/do_test read this flag before touching any wandb API.
wandb_enabled = False

In [None]:
class Config:
    """Bag of training hyper-parameters.

    ``vars(config)`` yields ``learning_rate``, ``epochs`` and ``batch_size`` (in
    that order); ``do_train`` passes exactly that dict to ``wandb.init``.

    Args:
        lr: Learning rate, stored as ``learning_rate``.
        epochs: Maximum number of training epochs.
        batch_size: Mini-batch size used when building the data loaders.
    """

    def __init__(self, lr=0.001, epochs=40, batch_size=32):
        # Fill the instance dict in a single call; keyword order fixes the
        # attribute insertion order that vars() reports.
        vars(self).update(
            learning_rate=lr,
            epochs=epochs,
            batch_size=batch_size,
        )

### Optionally initialize W&B project

In [None]:
# Optional: run this cell to enable Weights & Biases logging. It flips the
# module-level flag that do_train/do_test check before calling any wandb API.
import wandb

wandb_enabled = True

## 1. Split all allowed .wav files
We are using the [DAPS](https://zenodo.org/records/4660670) dataset. It has several directories of .wav files containing 5 scripts read by 20 speakers. The directories differ in their augmentation, which is labeled by `room` and `recording device`. In this cell we specify the allowed directories; their contents are discovered and split into 3 datasets (training, validation and test). Recordings of the same speaker reading the same script (differing only in augmentation) cannot end up in different datasets — the `AudioDatasetProcessor` class takes care of that. Setting the `balanced` parameter to `True` balances the file counts of authorized and unauthorized speakers in the training dataset.

In [None]:
# DAPS sub-directories (device_room) whose .wav files take part in the split.
allowed_directories = [
    "ipadflat_confroom1",
    "ipadflat_office1",
    "ipad_balcony1",
    "ipad_bedroom1",
    "ipad_confroom1",
    "ipad_confroom2",
    "ipad_livingroom1",
    "ipad_office1",
    "ipad_office2",
    "iphone_balcony1",
    "iphone_bedroom1",
    "iphone_livingroom1",
]

dataset_processor = AudioDatasetProcessor(DATA_DIR, VALID_ACCESS_LABELS, allowed_directories)
dataset_processor.compute_statistics()

# Pass balanced=False to get an unbalanced training set instead.
train_set, validate_set, test_set = dataset_processor.get_datasets(balanced=True)

## 2. Define training and validation loop

In [None]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def do_train(name, train_loader, val_loader, config, model, criterion, optimizer):
    """Train ``model`` with early stopping on validation F1 and checkpoint it.

    Each epoch runs ``train`` on ``train_loader`` and then ``validate`` on
    ``val_loader``. Whenever the validation F1 matches or beats the best value
    seen so far, the model's ``state_dict`` is written to
    ``./models/{name}.pth``. After ``PATIENCE_THRESHOLD`` consecutive epochs
    without such an improvement, training stops early. The checkpoint on disk
    is therefore the best-F1 model.

    Args:
        name: Run name, used for the W&B run and the checkpoint file name.
        train_loader: Loader passed to ``train``; its length sets the log interval.
        val_loader: Loader passed to ``validate`` after every epoch.
        config: ``Config`` instance. ``config.epochs`` bounds the loop, and
            ``vars(config)`` is sent to W&B.
        model: Module to train; it is moved to the module-level ``device``.
        criterion: Loss function forwarded to ``train``.
        optimizer: Optimizer forwarded to ``train``.

    Returns:
        The W&B run object when ``wandb_enabled`` is true, otherwise ``None``.
    """
    run = None
    if wandb_enabled:
        run = wandb.init(name=name, project="iml", config=vars(config))

    model.device = device
    model.to(device)

    # Defined up front so the path exists even if the loop body never runs
    # (previously wandb.save(model_path) raised NameError for epochs == 0).
    model_path = f"./models/{name}.pth"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)

    # The logger depends only on the W&B flag, so choose it once, not per epoch.
    if wandb_enabled:
        logger = wandb.log
    else:
        def logger(data, step):
            print(f"  Step {step}: {data}")

    # Roughly five log points per epoch. Clamp to 1 so small loaders
    # (fewer than 10 batches) never produce a 0 or negative interval.
    # NOTE(review): assumes `train` uses this as a step interval — confirm.
    log_interval = max(1, len(train_loader) // 5 - 1)

    best_f1 = -1
    patience = 0
    saved = False

    for epoch in range(config.epochs):
        print(f"Epoch {epoch+1}/{config.epochs}")

        train(model, train_loader, criterion, optimizer, epoch, logger, log_interval)
        metrics = validate(model, val_loader)
        print(metrics)

        if wandb_enabled:
            wandb.log({"validation/recall": metrics.recall, "validation/accuracy": metrics.accuracy, "validation/precision": metrics.precision, "validation/f1": metrics.f1, "epoch": epoch+1})

        if metrics.f1 >= best_f1:
            # New best (ties count): reset patience and checkpoint this model.
            best_f1 = metrics.f1
            patience = 0
            torch.save(model.state_dict(), model_path)
            saved = True
        else:
            patience += 1
            if patience >= PATIENCE_THRESHOLD:
                # Early stopping: the original loop kept training (and
                # re-saving a non-best model) here because it never broke out.
                print(f"Early stopping after epoch {epoch+1}: no F1 improvement for {patience} epochs (best f1={best_f1})")
                break

    # Only reachable when no epoch ran (or no F1 ever compared >= best, e.g. NaN).
    if not saved:
        torch.save(model.state_dict(), model_path)

    if wandb_enabled:
        wandb.save(model_path)
        wandb.finish()
    return run
    
def do_test(name, test_loader, model_class, run):
    """Evaluate the checkpoint written by ``do_train`` on ``test_loader``.

    Args:
        name: Run/checkpoint name. Weights are read from ``./models/{name}.pth``.
        test_loader: Loader passed to ``validate``.
        model_class: Either a zero-argument callable that returns a module whose
            architecture matches the checkpoint, or an ``nn.Module`` instance to
            load the weights into directly. ``some_model.__class__`` is usually
            *not* a suitable factory here. torchvision's EfficientNet/VGG
            constructors require arguments, and a fresh instance would not have
            the custom ``classifier`` head swapped in by this notebook.
        run: W&B run returned by ``do_train``; its ``id`` is used to resume that
            run. Ignored when ``wandb_enabled`` is false.
    """
    if wandb_enabled:
        wandb.init(name=name, project="iml", resume="must", id=run.id)

    # Accept a ready-made module as well as a factory (backward-compatible).
    model = model_class if isinstance(model_class, nn.Module) else model_class()
    model.device = device
    model.to(device)

    model_path = f"./models/{name}.pth"
    model.load_state_dict(torch.load(model_path, map_location=device))

    metrics = validate(model, test_loader)
    print(metrics)

    if wandb_enabled:
        wandb.log({"test/recall": metrics.recall, "test/accuracy": metrics.accuracy, "test/precision": metrics.precision, "test/f1": metrics.f1})
        wandb.finish()

## 3. Choose pretrained model architecture.
`EfficientNetB0` is much lighter than `VGG16`, and according to our tests their accuracy is almost the same.

In [None]:
# EfficientNetB0 backbone with its default pretrained weights.
from torchvision.models import EfficientNet_B0_Weights, efficientnet_b0

weights = EfficientNet_B0_Weights.DEFAULT
pre_trans = weights.transforms()  # preprocessing that matches these weights
pretrained_model = efficientnet_b0(weights=weights)
# Input width of the classifier layer at index 1; the custom head is sized from it.
num_features = pretrained_model.classifier[1].in_features
name_base = "EfficientNet_B0"

In [None]:
# VGG16 backbone with its default pretrained weights.
from torchvision.models import VGG16_Weights, vgg16

weights = VGG16_Weights.DEFAULT
pre_trans = weights.transforms()  # preprocessing that matches these weights
pretrained_model = vgg16(weights=weights)
# Input width of the first classifier layer; the custom head is sized from it.
num_features = pretrained_model.classifier[0].in_features
name_base = "VGG16"

## 4. Choose training approach
Options are:
1. Freeze the pretrained model and train only a small classifier on top of the features it extracts (transfer learning).
2. Train both the pretrained model and the added classifier (fine-tuning).

In [None]:
# Freeze base model (transfer learning)
pretrained_model.requires_grad_(False)
# Sanity check that *every* backbone parameter is frozen. The previous bare
# `next(iter(...)).requires_grad` expression was not the cell's last line, so
# Jupyter never displayed it, and it only looked at the first parameter.
assert not any(p.requires_grad for p in pretrained_model.parameters()), "base model is not fully frozen"
name = name_base + "_transfer_learning"

In [None]:
# Fine-tuning: this cell leaves requires_grad untouched, so the backbone's
# parameters are trained together with the new classifier.
name = f"{name_base}_fine_tuning"

##  5. Add our small classifier after the pretrained model's feature extractor

In [None]:
# Our own classifier head, replacing the backbone's original `classifier`.
N_CLASSES = 2  # authorized vs. unauthorized speaker

my_model = pretrained_model  # same object: the head is swapped in place
my_model.classifier = nn.Sequential(
    nn.Linear(num_features, 256),  # backbone features -> hidden layer
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, N_CLASSES),     # hidden layer -> class logits
)

In [None]:
from torch.utils.data import Dataset
from PIL import Image

class SpectrogramRGBDataset(Dataset):
    """Dataset of spectrogram PNG files labelled by speaker access rights.

    Every file ending in ``.png`` directly inside ``directory`` is one sample.
    The speaker id is the part of the file name before the first ``_``. The
    label is ``1`` when that id is in ``VALID_ACCESS_LABELS`` and ``0``
    otherwise.

    Args:
        directory: Folder containing the spectrogram images.
        transform: Optional callable applied after the pretrained model's
            preprocessing (``pre_trans``).
    """

    def __init__(self, directory, transform=None):
        # Sort so sample indices are reproducible; os.listdir order is arbitrary.
        self.files = sorted(
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.endswith(".png")
        )
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        img_path = self.files[idx]
        # basename handles the OS path separator used by os.path.join above,
        # unlike split("/"), which breaks on Windows paths.
        speaker_id = os.path.basename(img_path).split("_")[0]
        label = int(speaker_id in VALID_ACCESS_LABELS)

        # Close the file handle deterministically instead of leaving it to GC.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        image = pre_trans(image)

        if self.transform:
            image = self.transform(image)

        return image, label
    

In [None]:
model = my_model
config = Config(lr=0.0001, epochs=40, batch_size=32)

# Loss and optimizer for the selected model / training approach.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Empty Compose = identity: only the backbone's own preprocessing is applied.
transform = transforms.Compose([])
train_loader, val_loader, test_loader = prepare_dataset_loaders(
    transform,
    config.batch_size,
    SpectrogramRGBDataset,
)

run = do_train(name, train_loader, val_loader, config, model, criterion, optimizer)

In [None]:
# `model.__class__` is not a usable factory: torchvision's EfficientNet/VGG
# constructors require arguments, and a fresh instance would not carry the
# custom classifier head. Hand do_test a factory returning the trained model so
# the saved checkpoint is loaded back into the matching architecture.
do_test(name, test_loader, lambda: model, run)