## Set up paths and imports

In [None]:
import os

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvggish import vggish
from torchvggish import vggish_input
import librosa
import numpy as np

# If ./notebooks is not visible from the current working directory, the kernel
# is presumably running from inside notebooks/ — step up one level so the
# `src` package imported below can be resolved.
if not os.path.exists("./notebooks"):
    %cd ..

from src.training import train, validate
from src.audio_dataset_processor import AudioDatasetProcessor
from src.config import PATIENCE_THRESHOLD, VALID_ACCESS_LABELS, DATA_DIR

# W&B logging is off by default; the optional W&B cell further down sets this
# flag to True.
wandb_enabled = False

## 1. Load standardization data and define Config

In [None]:
class Config:
    """Container for the training hyperparameters of a run.

    ``vars(config)`` is what gets reported as the W&B run config, so the
    instance intentionally holds exactly these three attributes.

    Attributes:
        learning_rate: Optimizer step size (passed in as ``lr``).
        epochs: Number of training epochs to run.
        batch_size: Mini-batch size used when building the DataLoaders.
    """

    def __init__(self, lr=0.001, epochs=40, batch_size=32):
        self.learning_rate, self.epochs, self.batch_size = lr, epochs, batch_size

### Optionally initialize W&B project

In [None]:
import wandb

# Opt in to W&B: do_train checks this flag to open a run and log via wandb.log;
# when it stays False, step logs are printed to stdout instead.
wandb_enabled = True

In [None]:
# Recording sessions (sub-directories of DATA_DIR) to include in the splits.
allowed_directories = [
    'ipadflat_confroom1',
    'ipadflat_office1',
    'ipad_balcony1',
    'ipad_bedroom1',
    'ipad_confroom1',
    'ipad_confroom2',
    'ipad_livingroom1',
    'ipad_office1',
    'ipad_office2',
    'iphone_balcony1',
    'iphone_bedroom1',
    'iphone_livingroom1',
]

dataset_processor = AudioDatasetProcessor(DATA_DIR, VALID_ACCESS_LABELS, allowed_directories)
dataset_processor.compute_statistics()

# Pass balanced=False instead to get the unbalanced splits.
train_set, validate_set, test_set = dataset_processor.get_datasets(balanced=True)

## 2. Define training and validation loop

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def do_train(name, train_loader, val_loader, config, model, criterion, optimizer):
    """Train ``model`` with F1-based early stopping and save its best weights.

    Each epoch calls ``train`` on ``train_loader`` and then ``validate`` on
    ``val_loader``. Training stops early once the validation F1 has stayed
    below the best value seen so far for ``PATIENCE_THRESHOLD`` consecutive
    epochs. At the end, the weights from the best-F1 epoch are loaded back
    into ``model`` and written to ``./models/{name}.pth``.

    When ``wandb_enabled`` is set, a W&B run is opened, per-epoch validation
    metrics are logged, and the checkpoint is uploaded with ``wandb.save``.

    Args:
        name: Run name; also used as the checkpoint file stem.
        train_loader: DataLoader providing training batches.
        val_loader: DataLoader providing validation batches.
        config: Object with an ``epochs`` attribute; ``vars(config)`` is sent
            to W&B as the run config.
        model: The ``nn.Module`` to train; moved to ``device`` in place.
        criterion: Loss function forwarded to ``train``.
        optimizer: Optimizer forwarded to ``train``.
    """
    # Pick the step logger once; it does not change between epochs.
    if wandb_enabled:
        wandb.init(name=name, project="iml", config=vars(config))
        logger = wandb.log
    else:
        logger = lambda data, step: print(f"  Step {step}: {data}")

    model.device = device
    model.to(device)

    model_path = f"./models/{name}.pth"
    best_f1 = -1
    best_state = None
    patience = 0

    for epoch in range(config.epochs):
        print(f"Epoch {epoch+1}/{config.epochs}")

        # NOTE(review): this interval is <= 0 for loaders with fewer than 10
        # batches; confirm how train() uses it before relying on tiny loaders.
        train(model, train_loader, criterion, optimizer, epoch, logger, len(train_loader) // 5 - 1)
        metrics = validate(model, val_loader)
        print(metrics)

        if wandb_enabled:
            wandb.log({
                "validation/recall": metrics.recall,
                "validation/accuracy": metrics.accuracy,
                "validation/precision": metrics.precision,
                "validation/f1": metrics.f1,
                "epoch": epoch + 1,
            })

        if metrics.f1 < best_f1:
            patience += 1
        else:
            # A tie counts as an improvement (same criterion as before).
            patience = 0
            best_f1 = metrics.f1
            # state_dict() returns live references to the parameters, which
            # later optimizer steps would overwrite, so snapshot a copy.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        if patience >= PATIENCE_THRESHOLD:
            print(f"Early stopping after epoch {epoch+1}: no F1 improvement for {patience} epochs")
            break

    # Keep the in-memory model consistent with the checkpoint we write.
    # best_state stays None only when config.epochs == 0.
    if best_state is not None:
        model.load_state_dict(best_state)

    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)

    if wandb_enabled:
        wandb.save(model_path)
        wandb.finish()
    


In [None]:
# Pretrained VGGish, used only as a frozen feature extractor. It gets its own
# name because a later cell rebinds ``model`` to the trainable classifier;
# extract_features must keep using VGGish regardless.
vggish_model = vggish()
vggish_model.eval()
model = vggish_model  # backward-compatible alias for code that reads ``model``

SPLIT_SECONDS = 3

def preprocess_audio(file_path, target_sample_rate=16000):
    """Load an audio file as mono and turn it into VGGish log-Mel examples.

    Audio shorter than one second (``target_sample_rate`` samples) is
    zero-padded at the end up to one second before conversion.

    Args:
        file_path: Path of the audio file, read with ``librosa.load``.
        target_sample_rate: Sample rate (Hz) the audio is resampled to.

    Returns:
        A ``torch.Tensor`` holding the output of
        ``vggish_input.waveform_to_examples``.
    """
    audio, sr = librosa.load(file_path, sr=target_sample_rate, mono=True)

    if len(audio) < target_sample_rate:
        audio = np.pad(audio, (0, target_sample_rate - len(audio)), mode='constant')

    mel_spec = vggish_input.waveform_to_examples(audio, sr)
    # as_tensor accepts either an ndarray or a tensor (depending on the
    # torchvggish version) without torch.tensor's copy-construct warning.
    return torch.as_tensor(mel_spec)

def extract_features(file_paths, feature_model=None):
    """Embed each file and return a flat list of ``(embedding, label)`` pairs.

    Each file contributes a number of embeddings truncated down to a multiple
    of ``SPLIT_SECONDS``. As a result, consecutive runs of ``SPLIT_SECONDS``
    entries in the result always come from a single file.

    Args:
        file_paths: Iterable of audio file paths. The speaker id is the part
            of the file name before the first ``_``.
        feature_model: Module used to embed the log-Mel examples. Defaults to
            the pretrained ``vggish_model``.

    Returns:
        List of ``(embedding, label)`` tuples. ``label`` is 1 when the speaker
        id is in ``VALID_ACCESS_LABELS`` and 0 otherwise.
    """
    if feature_model is None:
        feature_model = vggish_model

    features = []
    for file in file_paths:
        mel_spec = preprocess_audio(file)
        speaker_id = os.path.basename(file).split("_")[0]
        label = int(speaker_id in VALID_ACCESS_LABELS)

        with torch.no_grad():
            file_features = feature_model(mel_spec)

        # NOTE(review): the VGGish postprocessor may squeeze a one-example
        # batch down to 1-D. Restore the batch axis so we iterate over
        # embeddings, not over individual scalars.
        if file_features.dim() == 1:
            file_features = file_features.unsqueeze(0)

        # Drop the trailing remainder so groups never straddle two files.
        usable = len(file_features) - len(file_features) % SPLIT_SECONDS
        features.extend(
            (feature.detach().clone(), label) for feature in file_features[:usable]
        )
    return features

In [None]:
from torch.utils.data import Dataset

class VGGishDataset(Dataset):
    """Dataset whose items are ``SPLIT_SECONDS`` VGGish embeddings concatenated.

    ``extract_features`` emits each file's embeddings in runs whose length is
    a multiple of ``SPLIT_SECONDS``. Each item therefore groups consecutive
    embeddings from a single file and carries that file's label.
    """

    def __init__(self, files):
        self.data = extract_features(files)

    def __len__(self):
        return len(self.data) // SPLIT_SECONDS

    def __getitem__(self, idx):
        start = idx * SPLIT_SECONDS
        # Explicit indexing (rather than slicing) keeps IndexError on
        # out-of-range idx, and follows SPLIT_SECONDS instead of a literal 3.
        chunks = [self.data[start + offset] for offset in range(SPLIT_SECONDS)]
        embeddings = torch.cat([embedding for embedding, _ in chunks], dim=0)
        label = chunks[-1][1]
        return embeddings, torch.tensor(label, dtype=torch.long)
    
# Build the train/validation datasets; each constructor embeds its whole file list up front.
train_dataset, val_dataset = VGGishDataset(train_set), VGGishDataset(validate_set)


In [None]:
N_CLASSES = 2

class ClassifierForVGGish(nn.Module):
    """Two-layer MLP head (Linear -> ReLU -> Linear) that outputs class logits.

    Args:
        input_dim: Size of the flat input feature vector.
        hidden_dim: Width of the hidden layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # Attribute names define the state_dict keys of saved checkpoints,
        # so they must stay fc1 / fc2.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map ``(..., input_dim)`` features to ``(..., num_classes)`` logits."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Each dataset item concatenates SPLIT_SECONDS embeddings of 128 features each.
input_dim = 128 * SPLIT_SECONDS
hidden_dim = 256
num_classes = N_CLASSES  # reuse the constant instead of repeating the literal

my_model = ClassifierForVGGish(input_dim, hidden_dim, num_classes)

In [None]:
# Run identity and hyperparameters for this experiment.
name = "VGGish_transfer_learning"
config = Config(batch_size=32, epochs=40, lr=0.0001)

# Model, loss and optimizer.
model = my_model
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
transform = transforms.Compose([])  # empty pipeline; not passed to the loaders below

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)

do_train(name, train_loader, val_loader, config, model, criterion, optimizer)