In [None]:
# Conv-TasNet Hyperparameter Tuning Notebook
# Author: Graham Pellegrini | UOM Final Year Project

# ============================
# 📦 1. Setup & Imports
# ============================
import os
import torch
import torchaudio
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import time

from torch.utils.data import DataLoader
from torch import nn
from Utils.train import train_eval
from Utils.models import ConvTasNet
from Utils.dataset import DynamicBuckets, BucketSampler
import config

# ============================
# 🧩 2. Configuration
# ============================
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Dataset and preprocessing parameters, all read from the project config module.
sr, n_fft, hop_length = config.SAMPLE_RATE, config.N_FFT, config.HOP_LENGTH
batch_size = config.BATCH_SIZE
accum_steps = config.ACCUMULATION_STEPS
num_workers, num_buckets = config.NUM_WORKERS, config.NUM_BUCKET

# Both splits are built with the same preprocessing arguments, in this order.
# NOTE(review): validation is drawn from "trainset_28spk" (a training split);
# confirm it does not overlap "trainset_56spk".
_bucket_args = (sr, n_fft, hop_length, num_buckets)
train_dataset = DynamicBuckets(config.DATASET_DIR, "trainset_56spk", *_bucket_args)
val_dataset = DynamicBuckets(config.DATASET_DIR, "trainset_28spk", *_bucket_args)

# The samplers are passed as `batch_sampler`, so each one yields a full batch
# of indices per step, built from the dataset's bucket_indices.
train_sampler = BucketSampler(train_dataset.bucket_indices, batch_size=batch_size)
val_sampler = BucketSampler(val_dataset.bucket_indices, batch_size=batch_size)
train_loader, val_loader = (
    DataLoader(split, batch_sampler=sampler, num_workers=num_workers)
    for split, sampler in ((train_dataset, train_sampler), (val_dataset, val_sampler))
)

# ============================
# 🔍 3. Define Search Space
# ============================
# Each trial is one row of values, zipped with these keys. The key order fixes
# the column order of the results table built from these dicts.
_HPARAM_KEYS = ("enc_dim", "feature_dim", "kernel_size", "num_layers", "num_stacks")

hparam_trials = [
    dict(zip(_HPARAM_KEYS, row))
    for row in (
        (64, 32, (3, 3), 3, 2),
        (128, 48, (3, 3), 4, 2),  # baseline
        (128, 64, (5, 5), 4, 3),
        (192, 64, (3, 3), 5, 2),
    )
]

# Filled with one summary dict per trial by the training loop.
results = []

# ============================
# 🚀 4. Run Trials
# ============================
# Fewer epochs than a full training run: each trial only needs to rank configs.
tuning_epochs = 5

# save_pth below points into Models/; create it up front so a fresh checkout
# does not fail when a checkpoint is written there.
os.makedirs("Models", exist_ok=True)

# MSELoss holds no state, so a single instance is shared by every trial.
criterion = nn.MSELoss()

for i, hparams in enumerate(hparam_trials):
    trial_id = i + 1
    print(f"\n🎯 Trial {trial_id}: {hparams}")

    model = ConvTasNet(
        enc_dim=hparams["enc_dim"],
        feature_dim=hparams["feature_dim"],
        kernel_size=hparams["kernel_size"],
        num_layers=hparams["num_layers"],
        num_stacks=hparams["num_stacks"]
    )

    # A fresh optimizer per trial, so no Adam state leaks between models.
    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    save_path = f"Models/Trial_{trial_id}_ConvTasNet.pth"

    # perf_counter is monotonic, so wall-clock adjustments cannot skew durations.
    start = time.perf_counter()
    train_eval(
        device,
        model,
        train_loader,
        val_loader,
        optimizer,
        criterion,
        epochs=tuning_epochs,
        accumulation_steps=accum_steps,
        save_pth=save_path,
        pto=False,
        scheduler=config.SCHEDULER
    )
    end = time.perf_counter()

    # NOTE(review): nothing visible here sets `model.best_val_loss`; confirm
    # that train_eval stores it on the model, otherwise every trial records NaN.
    # NaN (rather than the string "NA") keeps the column numeric for the
    # CSV export and the bar plot.
    val_loss = getattr(model, "best_val_loss", float("nan"))

    # Record result
    results.append({
        "trial": trial_id,
        **hparams,
        "val_loss": val_loss,
        "time": end - start
    })

# ============================
# 📊 5. Save Results & Plot
# ============================
# Both the CSV and the PNG are written under Output/; make sure the paths exist.
os.makedirs("Output/png", exist_ok=True)

results_df = pd.DataFrame(results)
results_df.to_csv("Output/hparam_tuning_results.csv", index=False)

if results_df.empty:
    print("No trial results to plot.")
else:
    # Coerce to float for plotting only. Non-numeric placeholders (e.g. "NA" for
    # trials without a recorded loss) become NaN gaps instead of crashing
    # plt.bar. The CSV above keeps the raw values.
    val_losses = pd.to_numeric(results_df["val_loss"], errors="coerce")

    # Plot
    plt.figure(figsize=(10, 5))
    plt.bar(
        results_df["trial"],
        val_losses,
        tick_label=[f"T{trial}" for trial in results_df["trial"]],
    )
    plt.title("Conv-TasNet Validation Loss per Trial")
    plt.xlabel("Trial")
    plt.ylabel("Validation Loss")
    plt.grid(True, axis="y")
    plt.tight_layout()
    plt.savefig("Output/png/hparam_val_loss_plot.png")
    plt.show()
