## Set up paths and imports

In [None]:
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvggish import vggish
from torchvggish import vggish_input
import librosa
import numpy as np

# If ./notebooks is not visible from the current directory, presumably the
# notebook is running from inside the notebooks/ folder: move one level up so
# the `src.*` imports below resolve (IPython magic; notebook-only).
if not os.path.exists("./notebooks"):
    %cd ..

from src.training import do_train, do_test
from src.audio_dataset_processor import DAPSDatasetProcessor
from src.dataset import BalancedBatchSampler
from src.config import VALID_ACCESS_LABELS, DATA_DIR

# W&B logging is off by default; the optional cell under section 1 turns it on.
# The flag is passed to do_train / do_test in section 7.
wandb_enabled = False

## 1. Define Config

In [None]:
class Config:
    """Hyperparameters for one training run.

    Attributes:
        learning_rate: optimizer step size (passed to the constructor as ``lr``).
        epochs: number of training epochs.
        batch_size: number of samples per batch.
    """

    def __init__(self, lr=0.001, epochs=40, batch_size=32):
        self.learning_rate, self.epochs, self.batch_size = lr, epochs, batch_size

### Optionally initialize W&B project

In [None]:
# Opt in to Weights & Biases logging for do_train / do_test (section 7).
# Skip this cell to keep the default of False set in the first cell.
wandb_enabled = True

## 2. Choose device

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

## 3. Define the VGGish model and feature-extraction code
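`preprocess_audio` loads each `.wav` file with `librosa` (mono, 16 kHz) and converts it into log-Mel examples with VGGish's own `vggish_input` helper. Inference gives 128 floats (intended as semantic features of the clip) for roughly every second of recording. Consecutive embeddings are later grouped `SPLIT_SECONDS` at a time, so any trailing embeddings that do not fill a complete group are dropped. As a result, each file contributes a multiple of `SPLIT_SECONDS` embeddings.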

In [None]:
# VGGish embedding network (presumably with pretrained weights, as provided by
# torchvggish). extract_features() below looks it up through the global name
# `model`, and it is only used for inference, hence eval().
# NOTE(review): section 7 rebinds `model` to the classifier; re-run this cell
# before extracting features again.
model = vggish()
model.eval()

# Number of consecutive ~1-second embeddings concatenated into one classifier
# input (the classifier's INPUT_DIM is 128 * SPLIT_SECONDS).
SPLIT_SECONDS = 3

def preprocess_audio(file_path, target_sample_rate=16000):
    """
    Load an audio file as mono and convert it into VGGish log-Mel examples.

    The file is loaded with ``librosa.load`` at ``target_sample_rate`` and,
    if shorter than one second, right-padded with zeros to exactly one
    second (presumably so that ``waveform_to_examples`` produces at least
    one example).

    Args:
        file_path: path to the audio file.
        target_sample_rate: sample rate in Hz that the audio is loaded at.

    Returns:
        torch.Tensor with the examples produced by
        ``vggish_input.waveform_to_examples``, detached from any autograd graph.
    """
    audio, sr = librosa.load(file_path, sr=target_sample_rate, mono=True)

    if len(audio) < target_sample_rate:
        audio = np.pad(audio, (0, target_sample_rate - len(audio)), mode='constant')

    mel_spec = vggish_input.waveform_to_examples(audio, sr)
    # Some torchvggish versions may already return a tensor here. Copying a
    # tensor with torch.tensor() gives the same result but emits a UserWarning
    # for every file; detach().clone() is the recommended equivalent.
    if isinstance(mel_spec, torch.Tensor):
        return mel_spec.detach().clone()
    return torch.tensor(mel_spec)

def extract_features(file_paths, embedding_model=None):
    """
    Embed audio files and return a flat list of per-second (feature, label) pairs.

    Each file is turned into log-Mel examples by ``preprocess_audio`` and passed
    through the embedding model with gradients disabled. Trailing embeddings
    are dropped so that every file contributes a multiple of ``SPLIT_SECONDS``
    entries; consecutive groups of ``SPLIT_SECONDS`` entries therefore never
    span two files.

    Args:
        file_paths: iterable of audio file paths. The speaker id is the part of
            the file's basename before the first ``_``.
        embedding_model: model applied to the spectrograms. Defaults to the
            module-level ``model``, looked up at call time.

    Returns:
        list of ``(feature, label)`` tuples; ``label`` is 1 when the speaker id
        is in ``VALID_ACCESS_LABELS`` and 0 otherwise.
    """
    # The global `model` is rebound to the classifier in section 7, so callers
    # re-running extraction after that point should pass the VGGish model.
    if embedding_model is None:
        embedding_model = model

    features = []
    for file in file_paths:
        mel_spec = preprocess_audio(file)
        speaker_id = os.path.basename(file).split("_")[0]
        label = int(speaker_id in VALID_ACCESS_LABELS)

        with torch.no_grad():
            file_features = embedding_model(mel_spec)

        # Keep only whole groups of SPLIT_SECONDS embeddings from this file.
        usable = len(file_features) - len(file_features) % SPLIT_SECONDS
        # detach().clone() is the warning-free equivalent of torch.tensor(t).
        features.extend(
            (feature.detach().clone(), label) for feature in file_features[:usable]
        )
    return features

## 4. Use common code to split `.wav` files into datasets

In [None]:
# Recording-condition sub-directories of DATA_DIR to draw audio from
# (presumably DAPS device/room folders — confirm in DAPSDatasetProcessor).
allowed_directories = [
    'ipadflat_confroom1',
    'ipadflat_office1',
    'ipad_balcony1',
    'ipad_bedroom1',
    'ipad_confroom1',
    'ipad_confroom2',
    'ipad_livingroom1',
    'ipad_office1',
    'ipad_office2',
    'iphone_balcony1',
    'iphone_bedroom1',
    'iphone_livingroom1',
]
dataset_processor = DAPSDatasetProcessor(DATA_DIR, VALID_ACCESS_LABELS, allowed_directories)
dataset_processor.compute_statistics()
train_set, validate_set, test_set = dataset_processor.get_datasets()

## 5. Define `VGGish` specific Dataset class
Each item concatenates `SPLIT_SECONDS` consecutive 1-second `VGGish` embeddings from the same file into a single vector, labeled with that file's label. This class is used as the dataset for the simple classifier.

In [None]:
from torch.utils.data import Dataset

class VGGishDataset(Dataset):
    """
    Map-style dataset of concatenated VGGish embeddings.

    Item ``i`` is the concatenation along dim 0 of the ``SPLIT_SECONDS``
    consecutive features starting at position ``i * SPLIT_SECONDS`` of the list
    built by ``extract_features``, paired with the label of the last feature in
    that group. ``extract_features`` keeps a multiple of ``SPLIT_SECONDS``
    features per file, so a group never mixes two files.
    """

    def __init__(self, files):
        # Embeds every file up front; this runs the embedding model over all of them.
        self.data = extract_features(files)

    def __len__(self):
        # Only complete groups of SPLIT_SECONDS features are addressable.
        return len(self.data) // SPLIT_SECONDS

    def __getitem__(self, idx):
        n_items = len(self)
        # Support negative indices like a regular sequence.
        if idx < 0:
            idx += n_items
        if not 0 <= idx < n_items:
            raise IndexError(f"index {idx} out of range for dataset of size {n_items}")

        start = idx * SPLIT_SECONDS
        group = self.data[start:start + SPLIT_SECONDS]
        # Group size follows SPLIT_SECONDS (matching INPUT_DIM = 128 * SPLIT_SECONDS)
        # instead of a hard-coded 3.
        features = [feature for feature, _ in group]
        # All features in a group come from one file and share its label.
        label = group[-1][1]
        return torch.cat(features, dim=0), label

train_dataset = VGGishDataset(train_set)
val_dataset = VGGishDataset(validate_set)
test_dataset = VGGishDataset(test_set)



## 6. Define simple classifier

In [None]:
# Binary decision: label 1 = speaker in VALID_ACCESS_LABELS, 0 otherwise.
N_CLASSES = 2
# One 128-float VGGish embedding per second, SPLIT_SECONDS of them concatenated.
INPUT_DIM = 128 * SPLIT_SECONDS
HIDDEN_DIM = 256
# Total hidden Linear+ReLU layers (the first one is always built).
N_HIDDEN_LAYERS = 1

class ClassifierForVGGish(nn.Module):
    """
    MLP head over concatenated VGGish embeddings.

    Layout: ``Linear(input_dim, hidden_dim) -> ReLU``, followed by
    ``n_hidden_layers - 1`` further ``Linear(hidden_dim, hidden_dim) -> ReLU``
    pairs, then ``Linear(hidden_dim, n_classes)``. At least one hidden layer is
    always built, even when ``n_hidden_layers < 1``.

    Every constructor argument is optional. When omitted, the corresponding
    module-level constant (``INPUT_DIM``, ``HIDDEN_DIM``, ``N_HIDDEN_LAYERS``,
    ``N_CLASSES``) is read at construction time, so ``ClassifierForVGGish()``
    builds the same network as before.
    """

    def __init__(self, input_dim=None, hidden_dim=None, n_hidden_layers=None, n_classes=None):
        super().__init__()
        input_dim = INPUT_DIM if input_dim is None else input_dim
        hidden_dim = HIDDEN_DIM if hidden_dim is None else hidden_dim
        n_hidden_layers = N_HIDDEN_LAYERS if n_hidden_layers is None else n_hidden_layers
        n_classes = N_CLASSES if n_classes is None else n_classes

        layers = [nn.Linear(input_dim, hidden_dim), nn.ReLU()]
        for _ in range(n_hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(hidden_dim, n_classes))

        # The attribute name and layer order determine the state_dict keys
        # ("model.0.weight", ...); keep them unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for ``x``, whose last dimension is ``input_dim``."""
        return self.model(x)



## 7. Define the config, run the training and validation loops, and evaluate on the test set

In [None]:
# NOTE(review): this rebinds the global `model` (until now the VGGish network
# that extract_features() reads) to the classifier; re-run section 3 before
# extracting features again.
model = ClassifierForVGGish()
config = Config(batch_size=32, epochs=40, lr=0.001)
name = "VGGish_transfer_learning"
# The classifier outputs raw logits, which CrossEntropyLoss expects.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

# Training batches come from BalancedBatchSampler (presumably class-balanced —
# confirm in src.dataset); validation and test data are read in order.
sampler = BalancedBatchSampler(train_dataset, config.batch_size)
train_loader = DataLoader(train_dataset, batch_sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False)

run = do_train(name, train_loader, val_loader, config, model, criterion, optimizer, device, wandb_enabled)
# do_test receives the classifier class, not the trained instance; presumably it
# rebuilds the model and loads weights saved by do_train — confirm in src.training.
do_test(name, test_loader, model.__class__, run, device, wandb_enabled)