# Tuning Notebook

This is the entire hyperparameter tuning script condensed into two cells so I can copy/paste it into a Jupyter notebook on vast.ai.

In [None]:
import os

# for multi-gpu instances this picks which GPU the notebook uses
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

!pip install pandas
!pip install librosa
!pip install matplotlib
!pip install torch_fidelity
import random
import glob
import pandas as pd
import librosa
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import json
import hashlib
import tempfile
import torch.optim as optim
import time
from torch.nn import BCEWithLogitsLoss
from torch_fidelity import calculate_metrics
from torchvision.utils import save_image
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from scipy.io.wavfile import write as wavwrite

# Pick the compute device once; later cells pass this `device` object into training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Unpack the dataset archive only when the target folder is missing, so that
# re-running the cell does not repeat the extraction.
if not os.path.exists("/workspace/Data/genres_original/"):
    # IPython shell magic: extracts gtzan.zip (expected in the notebook's working directory) into /workspace/
    !unzip -o "gtzan.zip" -d "/workspace/"
else:
    print("Data already exists, skipping unzip")
# Dataset location, the single genre this notebook trains on, and the rate
# every file is loaded at.
base_path = '/workspace/Data/genres_original/'
target_genre = 'disco'
sample_rate = 16000
genre_dir = os.path.join(base_path, target_genre)
# Sorted so the track order (and everything derived from it) is deterministic.
wav_paths = sorted(glob.glob(f"{genre_dir}/*.wav"))
print(f"Loading {target_genre:10s}... {len(wav_paths)} files")
all_audio = {}      # genre name -> list of mono waveforms
genre_audio = []    # waveforms successfully loaded for target_genre
bad_files = []      # paths that could not be decoded
for wav_path in tqdm(wav_paths):
    try:
        # The rate returned by librosa.load is bound to a throwaway name so the
        # loop never rebinds the module-level `sample_rate` that later cells read.
        waveform, _loaded_sr = librosa.load(wav_path, sr=sample_rate, mono=True)
    except Exception as e:
        # Best effort: an unreadable file is reported and skipped, not fatal.
        print(f"Failed to load {wav_path}: {e}")
        bad_files.append(wav_path)
    else:
        genre_audio.append(waveform)
all_audio[target_genre] = genre_audio
print(f"{len(bad_files)} files discarded.")

def chop_audio_segments(segment_length=5, overlap_seconds=0, audio_by_genre=None, sr=None, seed=42):
    """Cut every track into fixed-length, optionally overlapping segments.

    Args:
        segment_length: Segment duration in seconds.
        overlap_seconds: Overlap between consecutive segments, in seconds.
            Must be shorter than ``segment_length``.
        audio_by_genre: Mapping of genre name -> list of 1-D waveforms.
            Defaults to the module-level ``all_audio``.
        sr: Sample rate used to convert seconds to samples. Defaults to the
            module-level ``sample_rate``.
        seed: Seed applied to NumPy's global RNG before shuffling.

    Returns:
        dict mapping each genre to a shuffled list of segment arrays, each
        exactly ``int(sr * segment_length)`` samples long. Segments whose peak
        absolute amplitude is below 1e-4 are dropped.

    Raises:
        ValueError: If the segment is empty or the overlap is not shorter than
            the segment (the hop between segments would be zero or negative).

    Note:
        Calls ``np.random.seed(seed)``, i.e. it reseeds NumPy's *global* RNG so
        that the shuffle order is reproducible.
    """
    if audio_by_genre is None:
        audio_by_genre = all_audio
    if sr is None:
        sr = sample_rate

    segment_samples = int(sr * segment_length)
    overlap_samples = int(sr * overlap_seconds)
    hop_samples = segment_samples - overlap_samples
    if segment_samples <= 0:
        raise ValueError(f"segment_length={segment_length} gives an empty segment at sr={sr}")
    if hop_samples <= 0:
        # A zero hop makes range() raise; a negative hop silently yields nothing.
        raise ValueError(
            f"overlap_seconds={overlap_seconds} must be shorter than segment_length={segment_length}"
        )

    np.random.seed(seed)
    genre_segments = {}

    for genre, tracks in audio_by_genre.items():
        segments = []
        for y in tracks:
            # The range stop guarantees every slice is a full-length segment;
            # a trailing partial segment is never produced.
            for start in range(0, len(y) - segment_samples + 1, hop_samples):
                chunk = y[start:start + segment_samples]
                if np.abs(chunk).max() < 1e-4:
                    continue  # effectively silent
                # Copy so the segment does not keep a view on the full track.
                segments.append(chunk.copy())
        np.random.shuffle(segments)
        genre_segments[genre] = segments
        print(f"Genre {genre:10}: {len(segments)} segments ({segment_length}s, {overlap_seconds}s overlap)")

    return genre_segments

class AudioSegmentDataset(Dataset):
    """Dataset over the waveform segments of a single genre.

    Each item is ``(tensor, genre_label)``. With ``normalize`` the segment is
    divided by its own peak absolute value (left unscaled when the peak is not
    positive); with ``as_channels`` a leading axis of size 1 is added.
    """

    def __init__(self, genre_segments_dict, genre, normalize=True, as_channels=True):
        chosen = genre_segments_dict[genre]
        self.normalize = normalize
        self.as_channels = as_channels
        # Own list copies, so the caller's list is never shared or mutated.
        self.samples = list(chosen)
        self.labels = [genre] * len(chosen)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        wave = self.samples[idx].astype(np.float32)
        if self.normalize:
            peak = np.abs(wave).max()
            # Divide by 1.0 when the peak is not positive to avoid dividing by zero.
            divisor = peak if peak > 0 else 1.0
            wave = wave / divisor
        if self.as_channels:
            wave = wave[np.newaxis, ...]
        return torch.from_numpy(wave), self.labels[idx]

def save_waveform_and_spectrogram(wave, sample_rate, fname, spec_fname=None):
    """Write a waveform as a 16-bit WAV and optionally a mel-spectrogram image.

    Args:
        wave: 1-D float waveform, nominally in [-1, 1].
        sample_rate: Sample rate written to the WAV header and used for the
            spectrogram.
        fname: Destination path of the WAV file.
        spec_fname: Destination path of the spectrogram image; nothing is
            plotted when this is falsy.
    """
    # Import the submodule explicitly: a bare `import librosa` is not
    # guaranteed to expose `librosa.display` on every librosa release.
    import librosa.display

    # Clip before the int16 cast; out-of-range samples would otherwise wrap
    # around. This matches what save_audio_clips_to_folder does.
    pcm = (np.clip(wave, -1.0, 1.0) * 32767).astype(np.int16)
    wavwrite(fname, sample_rate, pcm)
    if not spec_fname:
        return

    S = librosa.feature.melspectrogram(y=wave, sr=sample_rate, n_mels=64, fmax=8000)
    S_dB = librosa.power_to_db(S, ref=np.max)
    plt.figure(figsize=(5, 2))
    try:
        librosa.display.specshow(S_dB, sr=sample_rate, x_axis='time', y_axis='mel', cmap='magma')
        plt.title('')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(spec_fname)
    finally:
        # Always release the figure, even if plotting or saving raised.
        plt.close()

def audio_to_melspec_img(audio, sample_rate=16000, n_mels=64, fmax=8000):
    """Turn a waveform into a min-max scaled mel-spectrogram tensor.

    The mel power spectrogram is converted to dB relative to its maximum and
    rescaled with its own min and max; the 1e-8 term keeps the division
    defined when the spectrogram is flat. The result is returned as a torch
    tensor with a leading axis of size 1 added.
    """
    mel_power = librosa.feature.melspectrogram(
        y=audio, sr=sample_rate, n_mels=n_mels, fmax=fmax
    )
    mel_db = librosa.power_to_db(mel_power, ref=np.max)
    floor = mel_db.min()
    span = mel_db.max() - floor + 1e-8
    scaled = (mel_db - floor) / span
    return torch.tensor(scaled).unsqueeze(0)

def compute_fid_on_spectrograms_torchfidelity(gen_samples, real_segments,sample_rate=16000, n_mels=64, n_per_set=100, device='cuda'):
    """Compute FID between generated and real audio via mel-spectrogram images.

    Both sets are rendered with ``audio_to_melspec_img``, written as PNGs into
    two temporary directories, and compared with torch_fidelity's
    ``calculate_metrics``.

    Args:
        gen_samples: Sliceable sequence of generated waveforms; the first
            ``num`` are used.
        real_segments: Sequence of real waveforms; ``num`` are drawn at random
            with ``random.sample``.
        sample_rate: Sample rate passed to the spectrogram conversion.
        n_mels: Number of mel bands passed to the spectrogram conversion.
        n_per_set: Upper bound on the number of samples used per set.
        device: Device string (``'cuda'``, ``'cuda:0'``, ``'cpu'``, ...) or a
            ``torch.device``.

    Returns:
        The ``'frechet_inception_distance'`` entry of the metrics dict.

    Raises:
        ValueError: If either set is empty or ``n_per_set`` is not positive.
    """
    num = min(n_per_set, len(gen_samples), len(real_segments))
    if num <= 0:
        raise ValueError("FID needs at least one generated and one real sample")
    fake_subset = gen_samples[:num]
    real_subset = random.sample(real_segments, num)

    # Normalise through torch.device so that 'cuda', 'cuda:0' and torch.device
    # objects are all recognised; `device == 'cuda'` only matched the exact
    # string.
    use_cuda = torch.device(device).type == 'cuda'

    with tempfile.TemporaryDirectory() as real_dir, tempfile.TemporaryDirectory() as fake_dir:
        for i, wav in enumerate(real_subset):
            img_tensor = audio_to_melspec_img(wav, sample_rate, n_mels)
            save_image(img_tensor, os.path.join(real_dir, f"real_{i}.png"))
        for i, wav in enumerate(fake_subset):
            img_tensor = audio_to_melspec_img(wav, sample_rate, n_mels)
            save_image(img_tensor, os.path.join(fake_dir, f"fake_{i}.png"))

        # Must run inside the `with`: the image directories vanish on exit.
        metrics = calculate_metrics(
            input1=fake_dir, input2=real_dir,
            cuda=use_cuda, isc=False, fid=True, kid=False, verbose=False
        )
    return metrics['frechet_inception_distance']

def save_audio_clips_to_folder(clips, sample_rate, folder):
    """Write each clip to ``<folder>/<index>.wav`` as 16-bit PCM.

    Args:
        clips: Iterable of waveforms: torch tensors (on any device, with or
            without grad), NumPy arrays, or plain sequences of floats.
        sample_rate: Sample rate written to each WAV header.
        folder: Output directory; created if missing.

    Clips with more than one dimension are squeezed, and samples are clipped
    to [-1, 1] before conversion to int16.
    """
    os.makedirs(folder, exist_ok=True)
    for i, clip in enumerate(clips):
        if hasattr(clip, "cpu"):
            # detach() first: .numpy() raises on tensors that require grad.
            arr = clip.detach().cpu().numpy()
        else:
            arr = np.asarray(clip)
        if arr.ndim > 1:
            arr = arr.squeeze()
        arr = np.clip(arr, -1, 1)
        wavwrite(os.path.join(folder, f"{i}.wav"), sample_rate, (arr * 32767).astype(np.int16))

def hash_hyperparams(hyperparams):
    """Return a short, key-order-independent identifier for a hyperparameter dict.

    The dict is serialised to JSON with sorted keys, hashed with MD5, and the
    first 8 hex characters of the digest are returned.
    """
    canonical = json.dumps(hyperparams, sort_keys=True).encode('utf-8')
    digest = hashlib.md5(canonical).hexdigest()
    return digest[:8]

def train_wavegan(
        run_id,
        audio_segments,
        netG,
        netD,
        train_loader,
        epochs,
        hyperparams,
        device,
        genre,
        checkpoint_interval=10,
):
    """Adversarially train a generator/discriminator pair and save the results.

    Per batch the discriminator is updated first, on real and detached fake
    audio that both have Gaussian noise of scale ``noise_amplitude`` added,
    against smoothed labels (real in [0.8, 1.0], fake in [0.0, 0.2]). The
    generator is then updated on the discriminator's output for the un-noised
    fake audio, using the same smoothed "real" labels.

    Args:
        run_id: Identifier used for the output directory
            ``/workspace/training_outputs/<run_id>/``.
        audio_segments: Unused; kept so existing callers keep working.
        netG: Generator module, called with ``(batch, latent_dim)`` noise.
        netD: Discriminator module returning one logit per example.
        train_loader: Iterable of ``(audio, label)`` batches; labels are ignored.
        epochs: Number of epochs to train.
        hyperparams: Dict read for ``lr_g``, ``lr_d``, ``beta1``,
            ``noise_amplitude`` and ``latent_dim`` (each with a default); the
            whole dict is stored in the checkpoint.
        device: Device the models and batches are moved to.
        genre: Used to name the sample folder ``<genre>_samples``.
        checkpoint_interval: Save five preview samples and print progress
            every this many epochs, and always on the last epoch. A falsy
            value means "last epoch only".

    Returns:
        dict with keys ``"gen"`` and ``"disc"``, each a list holding the mean
        loss of every epoch.

    Side effects:
        Writes WAV/spectrogram previews (at a fixed 16000 Hz) during training
        and one checkpoint with both models and optimizers at the end.
    """
    output_dir=f'/workspace/training_outputs/{run_id}/'

    lr_g = hyperparams.get('lr_g', 0.0001)
    lr_d = hyperparams.get('lr_d', 0.0001)
    beta1 = hyperparams.get('beta1', 0.5)
    noise_amplitude = hyperparams.get('noise_amplitude', 0.05)
    latent_dim = hyperparams.get('latent_dim', 100)

    criterion = BCEWithLogitsLoss()
    optG = optim.Adam(netG.parameters(), lr=lr_g, betas=(beta1, 0.999))
    optD = optim.Adam(netD.parameters(), lr=lr_d, betas=(beta1, 0.999))
    netG.to(device)
    netD.to(device)
    loss_curves = {"gen": [], "disc": []}

    os.makedirs(output_dir, exist_ok=True)
    sample_dir = os.path.join(output_dir, f'{genre}_samples')
    os.makedirs(sample_dir, exist_ok=True)

    epoch_times = []
    # Pre-bind so the final checkpoint below is well-defined when epochs == 0.
    epoch = 0
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        g_losses, d_losses = [], []
        netG.train()
        netD.train()
        for batch in train_loader:
            real_audio, _ = batch
            batch_size = real_audio.size(0)
            real_audio = real_audio.to(device)
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_audio = netG(noise)

            # One-sided-style label smoothing: soft targets instead of hard 1/0.
            label_real = torch.FloatTensor(batch_size).uniform_(0.8, 1.0).to(device)
            label_fake = torch.FloatTensor(batch_size).uniform_(0.0, 0.2).to(device)

            # Noise is added to the discriminator's inputs only.
            real_audio_noisy = real_audio + noise_amplitude * torch.randn_like(real_audio)
            fake_audio_noisy = fake_audio + noise_amplitude * torch.randn_like(fake_audio)

            # --- discriminator step ---
            netD.zero_grad()
            out_real = netD(real_audio_noisy)
            loss_real = criterion(out_real, label_real)
            # detach(): the discriminator loss must not backprop into netG.
            out_fake = netD(fake_audio_noisy.detach())
            loss_fake = criterion(out_fake, label_fake)
            loss_D = (loss_real + loss_fake) / 2
            loss_D.backward()
            optD.step()

            # --- generator step ---
            netG.zero_grad()
            output_gen = netD(fake_audio)
            loss_G = criterion(output_gen, label_real)
            loss_G.backward()
            optG.step()

            d_losses.append(loss_D.item())
            g_losses.append(loss_G.item())

        mean_g = np.mean(g_losses)
        mean_d = np.mean(d_losses)
        loss_curves["gen"].append(mean_g)
        loss_curves["disc"].append(mean_d)

        # Evaluated once per epoch; the guard avoids a modulo-by-zero when
        # checkpoint_interval is 0/None.
        is_checkpoint = (epoch == epochs) or (
            bool(checkpoint_interval) and epoch % checkpoint_interval == 0
        )

        if is_checkpoint:
            # just save the audio and waveform files for now, FID and FAD will be analyzed later on my computer
            netG.eval()
            with torch.no_grad():
                # Fresh noise is drawn at every checkpoint, so previews from
                # different epochs are not the same latent points.
                preview_noise = torch.randn(5, latent_dim, device=device)
                samples = netG(preview_noise).cpu().numpy()
                for i, sample in enumerate(samples):
                    save_waveform_and_spectrogram(sample[0], sample_rate=16000, fname=os.path.join(sample_dir, f"ep{epoch}_samp{i}.wav"), spec_fname=os.path.join(sample_dir, f"ep{epoch}_samp{i}_spec.png"))
            print(f'Epoch {epoch} | Gen loss: {mean_g:.4f} | Disc loss: {mean_d:.4f}')

        epoch_time = time.time() - epoch_start_time
        epoch_times.append(epoch_time)
        avg_time_per_epoch = np.mean(epoch_times)
        epochs_remaining = epochs - epoch
        eta_minutes = (epochs_remaining * avg_time_per_epoch) / 60
        if is_checkpoint:
            print(f"Estimated time remaining: {eta_minutes:.2f} min")

    # save models/optimizers when done
    ckpt_dir = os.path.join(output_dir, "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'netG_state_dict': netG.state_dict(),
        'netD_state_dict': netD.state_dict(),
        'optG_state_dict': optG.state_dict(),
        'optD_state_dict': optD.state_dict(),
        'hyperparams': hyperparams,
    }, os.path.join(ckpt_dir, f'checkpoint_epoch{epoch}.pt'))

    return loss_curves

class WaveGANGenerator(nn.Module):
    """Generator mapping a latent vector to a single-channel waveform.

    A linear layer projects the latent vector to ``16 * gf_dim`` channels of
    length 16; five stride-4 transposed convolutions then upsample by 4x each
    (16 -> 16384 samples). Every stage but the last is followed by a norm
    layer (``InstanceNorm1d`` when ``norm == 'instance'``, otherwise
    ``BatchNorm1d``) and an in-place ReLU; the last stage ends in ``Tanh``.
    """

    def __init__(self, latent_dim=100, gf_dim=64, norm='batch'):
        super().__init__()
        self.init_t = 16
        self.init_channels = 16 * gf_dim
        self.project = nn.Linear(latent_dim, 16 * gf_dim * 16)

        make_norm = nn.InstanceNorm1d if norm == 'instance' else nn.BatchNorm1d

        # Channel widths of the four normalised upsampling stages.
        widths = [16 * gf_dim, 8 * gf_dim, 4 * gf_dim, 2 * gf_dim, gf_dim]
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(
                nn.ConvTranspose1d(c_in, c_out, 25, stride=4, padding=11, output_padding=1)
            )
            layers.append(make_norm(c_out))
            layers.append(nn.ReLU(True))
        # Output stage: collapse to one channel and squash into [-1, 1].
        layers.append(
            nn.ConvTranspose1d(gf_dim, 1, 25, stride=4, padding=11, output_padding=1)
        )
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        seed = self.project(z).view(z.size(0), self.init_channels, self.init_t)
        return self.net(seed)

class WaveGANDiscriminator(nn.Module):
    """Discriminator mapping a single-channel waveform to one logit per example.

    Five stride-4 convolutions (kernel 25, padding 11) widen the channels from
    ``df_dim`` to ``16 * df_dim``; all but the first are followed by
    ``BatchNorm1d``, and every one by ``LeakyReLU(0.2)``. The flattened
    features feed a linear layer of ``256 * df_dim`` inputs, which corresponds
    to 16 remaining time steps, i.e. a 16384-sample input.
    """

    def __init__(self, df_dim=64):
        super().__init__()
        # First stage has no normalisation.
        stages = [
            nn.Conv1d(1, df_dim, 25, stride=4, padding=11),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages double the channel count each time.
        for mult in (1, 2, 4, 8):
            c_in = mult * df_dim
            c_out = 2 * mult * df_dim
            stages.append(nn.Conv1d(c_in, c_out, 25, stride=4, padding=11))
            stages.append(nn.BatchNorm1d(c_out))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.net = nn.Sequential(*stages)
        self.final = nn.Sequential(nn.Flatten(), nn.Linear(256 * df_dim, 1))

    def forward(self, x):
        features = self.net(x)
        logits = self.final(features)
        return logits.view(-1)

def weights_init(m):
    """Re-initialise one module in place; intended for ``Module.apply``.

    Modules whose class name contains ``Conv`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``BatchNorm`` get weights
    drawn from N(1, 0.02) and a zero bias. Anything else is left untouched.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

def random_hyperparam_sample():
    """Draw one random hyperparameter configuration using the ``random`` module.

    Learning rates are sampled log-uniformly between 10**-4.5 and 10**-3; the
    remaining values are picked from small fixed sets. ``latent_dim`` is
    always 100.
    """
    log_lr_bounds = (-4.5, -3)
    # Draw order matters for reproducibility under a fixed random seed.
    gen_lr = 10 ** random.uniform(*log_lr_bounds)
    disc_lr = 10 ** random.uniform(*log_lr_bounds)
    adam_beta1 = random.choice([0.3, 0.5, 0.7])
    noise_scale = random.choice([0.005, 0.01, 0.025, 0.05])
    norm_kind = random.choice(['batch', 'instance'])
    batch = random.choice([8, 16, 32])
    return {
        "lr_g": gen_lr,
        "lr_d": disc_lr,
        "beta1": adam_beta1,
        "noise_amplitude": noise_scale,
        "latent_dim": 100,
        "norm": norm_kind,
        "batch_size": batch,
    }

def run_random_search(num_trials, genre, epochs, checkpoint_interval, data_segments, custom_hyperparams=None):
    """Train ``num_trials`` WaveGAN models, each with its own hyperparameters.

    Every trial uses ``custom_hyperparams`` when it is given (and truthy),
    otherwise a fresh draw from ``random_hyperparam_sample``. The run id is
    the hash of the hyperparameters. For each trial the hyperparameters are
    written to ``/workspace/training_outputs/<run_id>.json`` and a one-row
    summary is appended to ``/workspace/training_outputs/<run_id>.csv``.

    Returns:
        List of per-trial summary dicts (run id, JSON-encoded hyperparameters,
        final generator and discriminator losses).
    """
    completed = []
    for trial_no in range(1, num_trials + 1):
        hyperparams = custom_hyperparams or random_hyperparam_sample()
        run_id = hash_hyperparams(hyperparams)

        print('-' * 60)
        print(f"[{trial_no}/{num_trials}] Starting trial {run_id}")
        print(json.dumps(hyperparams, indent=2))

        dataset = AudioSegmentDataset(data_segments, genre)
        train_loader = DataLoader(
            dataset,
            batch_size=hyperparams['batch_size'],
            shuffle=True,
            drop_last=True,
            num_workers=4,
            pin_memory=True,
        )

        # Generator first, then discriminator: construction and init consume
        # torch's RNG in this order.
        netG = WaveGANGenerator(latent_dim=hyperparams['latent_dim'], gf_dim=64, norm=hyperparams['norm'])
        netD = WaveGANDiscriminator(df_dim=64)
        netG.apply(weights_init)
        netD.apply(weights_init)

        # The training copy carries the run id; the JSON file and the summary
        # keep the plain hyperparameters.
        run_params = dict(hyperparams)
        run_params['run_id'] = run_id

        # save hyperparams json
        os.makedirs('/workspace/training_outputs/', exist_ok=True)
        with open(f'/workspace/training_outputs/{run_id}.json', 'w') as fh:
            json.dump(hyperparams, fh, indent=4)

        loss_curves = train_wavegan(
            run_id,
            data_segments,
            netG,
            netD,
            train_loader,
            epochs=epochs,
            hyperparams=run_params,
            device=device,
            genre=genre,
            checkpoint_interval=checkpoint_interval,
        )

        gen_curve = loss_curves['gen']
        disc_curve = loss_curves['disc']
        record = {
            "run_id": run_id,
            "hyperparams": json.dumps(hyperparams),
            "final_gen_loss": gen_curve[-1] if gen_curve else None,
            "final_disc_loss": disc_curve[-1] if disc_curve else None,
        }
        completed.append(record)

        log_file = f'/workspace/training_outputs/{run_id}.csv'
        # Only write the header when the CSV is being created.
        write_header = not os.path.exists(log_file)
        pd.DataFrame([record]).to_csv(log_file, mode='a', header=write_header, index=False)
        print(f"Completed trial {run_id}")

    return completed

# Segments of 16384 samples each: the length WaveGANGenerator produces
# (16 time steps upsampled 4x by five layers) and the length the
# discriminator's final Linear layer is sized for. Consecutive segments
# overlap by 0.512 s.
audio_segments_1s = chop_audio_segments(16384 / sample_rate, overlap_seconds=0.512)

In [None]:
# run random search (for finding hyperparams)
# 20 trials of 500 epochs each on the target genre; preview samples are saved
# every 20 epochs. The total wall-clock time is printed at the end.
start_time = time.time()
results = run_random_search(num_trials=20, genre=target_genre, epochs=500, checkpoint_interval=20, data_segments=audio_segments_1s)
run_time_seconds = time.time() - start_time
print(f'Finished in {int(run_time_seconds)} seconds')