# Main Script for a Vocoder

### Imports

In [1]:
try: 
    import librosa
except:
    !pip install librosa


#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, Subset
import torch.optim as optim


# Utils
import numpy as np
from numpy import ndarray
import logging, librosa
from typing import Sequence, Optional, Callable


# Base Scripts
from Libraries.Utils import *
from MainScripts.Conf import conf

Collecting librosa
  Downloading librosa-0.11.0-py3-none-any.whl.metadata (8.7 kB)
Collecting audioread>=2.1.9 (from librosa)
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting numba>=0.51.0 (from librosa)
  Downloading numba-0.61.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.8 kB)
Collecting soundfile>=0.12.1 (from librosa)
  Downloading soundfile-0.13.1-py2.py3-none-manylinux_2_28_x86_64.whl.metadata (16 kB)
Collecting pooch>=1.1 (from librosa)
  Downloading pooch-1.8.2-py3-none-any.whl.metadata (10 kB)
Collecting soxr>=0.3.2 (from librosa)
  Downloading soxr-0.5.0.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.6 kB)
Collecting msgpack>=1.0 (from librosa)
  Downloading msgpack-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)
Collecting llvmlite<0.45,>=0.44.0dev0 (from numba>=0.51.0->librosa)
  Downloading llvmlite-0.44.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x

### Config

### General

In [2]:
remote_kernel: bool = True #Set to true if using a remote Kernel changes the file structure
model_name: str = "MelGan_v1"
training_data_name: str = "training_full_mel"
training_label_name: str = "training_full_wave"
full_model_path: str = path_to_remote_path("{}/{}".format(conf["paths"].model_path, model_name + ".pth"), remote_kernel)

Logging

In [3]:
logging_level: int = logging.INFO
logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

Training Params

In [4]:
device: str = "cuda" if torch.cuda.is_available() else "cpu"
n_training_samples: int = 2496 // 2
batch_size: int = 64
tensor_waver_dim: list = [batch_size, 1, 2**17] #B, C, H = Batch, channels, Time domain
tensor_mel_dim: list = [batch_size, 96, 512]
learning_rate: float = 1e-4
epochs: int = 300
restart_training: bool = True
checkpoint_freq: int = 5


### Data Loading

In [5]:
mel_data: ndarray = load_training_data(path_to_remote_path("{}/{}".format(conf["paths"].data_path, training_data_name + ".npy"), remote_kernel))
audio_data: ndarray = load_training_data(path_to_remote_path("{}/{}".format(conf["paths"].data_path, training_label_name + ".npy"), remote_kernel))

np.random.seed(50)
indicies: ndarray = np.arange(mel_data.shape[0])
np.random.shuffle(indicies)
mel_data = mel_data[indicies]
audio_data = audio_data[indicies]


data_loader = create_dataloader(Audio_Data(mel_data[:n_training_samples], audio_data[:n_training_samples]), batch_size)


### Model

In [6]:
class Mel2Wave(nn.Module):
    def __init__(self, in_channels: int, intermediate_channels: int = 512) -> None:
        super(Mel2Wave, self).__init__()
        self.block = nn.Sequential(
            nn.Conv1d(in_channels, intermediate_channels, kernel_size=7, stride=1, padding=3),
            
            Upsample(in_channels=intermediate_channels, out_channels=intermediate_channels // 2, factor=8),
            ResStack(channels=intermediate_channels // 2),

            Upsample(in_channels=intermediate_channels // 2, out_channels=intermediate_channels // 4, factor=8),
            ResStack(channels=intermediate_channels // 4),

            Upsample(in_channels=intermediate_channels // 4, out_channels=intermediate_channels // 8, factor=2),
            ResStack(channels=intermediate_channels // 8),

            Upsample(in_channels=intermediate_channels // 8, out_channels=intermediate_channels // 16, factor=2),
            ResStack(channels=intermediate_channels // 16),

            nn.Conv1d(intermediate_channels // 16, 1, kernel_size=7, stride=1, padding=3),
            nn.Tanh()
        )
    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)

class Upsample(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, factor: int) -> None:
        super(Upsample, self).__init__()
        kernel_size = factor * 2
        stride = factor
        padding = factor // 2 + factor % 2
        output_padding = factor % 2
        self.block = nn.Sequential(
            nn.ConvTranspose1d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding),
            nn.LeakyReLU(0.2)
        )
    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)

class DilConv(nn.Module):
    def __init__(self, channels: int, dilation: int) -> None:
        super(DilConv, self).__init__()
        self.block = nn.Sequential(
            nn.LeakyReLU(0.2),
            nn.Conv1d(channels, channels, kernel_size=3, stride=1, padding=dilation, dilation=dilation),
            nn.LeakyReLU(0.2),
            nn.Conv1d(channels, channels, kernel_size=1)
        )
    def forward(self, x: Tensor) -> Tensor:
        return x + self.block(x)

class ResStack(nn.Module):
    def __init__(self, channels: int) -> None:
        super(ResStack, self).__init__()
        self.block = nn.Sequential(
            DilConv(channels, 1),
            DilConv(channels, 3),
            DilConv(channels, 9)
        )
    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)

class DiscriminatorBlock(nn.Module):
    def __init__(self, in_channels: int = 1, channels: int = 16) -> None:
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv1d(in_channels, channels, kernel_size=15, stride=1, padding=7),
            nn.LeakyReLU(0.2),

            nn.Conv1d(channels, channels*4, kernel_size=41, stride=4, padding=20, groups=4),
            nn.LeakyReLU(0.2),

            nn.Conv1d(channels*4, channels*8, kernel_size=41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.2),

            nn.Conv1d(channels*8, channels*16, kernel_size=41, stride=4, padding=20, groups=16),
            nn.LeakyReLU(0.2),

            nn.Conv1d(channels*16, channels*16, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(0.2),

            nn.Conv1d(channels*16, 1, kernel_size=3, stride=1, padding=1)
        )

    def forward(self, x: Tensor) -> tuple[Tensor,...]:
        feature_maps = []
        for layer in self.block:
            x = layer(x)
            if isinstance(layer, nn.Conv1d):
                feature_maps.append(x)
        return x, feature_maps

class MultiScaleDiscriminator(nn.Module):
    def __init__(self, in_channels: int = 1, channels: int = 16) -> None:
        super().__init__()
        self.pooling = nn.AvgPool1d(kernel_size=4, stride=2, padding=1)
        self.discriminators = nn.ModuleList([
            DiscriminatorBlock(in_channels, channels),
            DiscriminatorBlock(in_channels, channels),
            DiscriminatorBlock(in_channels, channels),
        ])

    def forward(self, x: Tensor) -> tuple[Tensor, ...]:
        outputs = []
        feature_maps = []
        for disc in self.discriminators:
            out, fmap = disc(x)
            outputs.append(out)
            feature_maps.append(fmap)
            x = self.pooling(x)
        return outputs, feature_maps


In [7]:
generator = Mel2Wave(in_channels=96, intermediate_channels=512).to(device)
discriminator=MultiScaleDiscriminator(in_channels=1, channels=16).to(device)

In [8]:
gen_optim = optim.AdamW(generator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
disc_optim = optim.AdamW(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))

start_epoch: int = 0
if os.path.exists(full_model_path):
    model = torch.load(full_model_path, map_location=device)
    generator.load_state_dict(model["generator"])
    discriminator.load_state_dict(model["discriminator"])
    if not restart_training:
        gen_optim.load_state_dict(model["gen_optim"])
        disc_optim.load_state_dict(model["disc_optim"])
        start_epoch = model.get("epoch", 0)
    logger.info(f"Model {model_name} loaded with {count_parameters(generator)} and {count_parameters(discriminator)} Parameters")
else: 
    logger.info(f"Model {model_name} created with {count_parameters(generator)} and {count_parameters(discriminator)} Parameters")

2025-06-07 07:41:00,127 - INFO - Model MelGan_v1 loaded with ~4.055M and ~1.335M Parameters


In [17]:
adversarial_loss = nn.MSELoss()
l1_loss = nn.L1Loss()

recon_loss_weight: float = 150.0
fm_loss_weight: float = 10.0
label_smooth_val: float = 0.1

n_gen_updates: int = 2
n_disc_updates: int = 1

In [19]:
logger.info(f"Training started on {device}")
scaler = torch.cuda.amp.GradScaler()
loss_d_list: list = []
loss_g_list: list = []
total_time: float = 0.0

for e in range(0, epochs):
    total_d_loss: float = 0
    total_g_loss: float = 0
    start_time: float = time.time()

    for b_idx, (mel, audio) in enumerate(data_loader):
            mel, audio = mel.to(device), audio.to(device).unsqueeze(1)
            for _ in range(n_disc_updates):
                with torch.autocast(device_type=device):
                    fake_waveform = generator(mel).detach()
                    real_preds, _ = discriminator(audio)
                    fake_preds, _ = discriminator(fake_waveform)
                
                d_loss = 0
                for real_pred, fake_pred in zip(real_preds, fake_preds):
                    d_loss += adversarial_loss(real_pred, torch.full_like(real_pred, 1 - label_smooth_val))
                    d_loss += adversarial_loss(fake_pred, torch.full_like(fake_pred, label_smooth_val))
                
                disc_optim.zero_grad()
                scaler.scale(d_loss).backward()
                scaler.step(disc_optim)
                scaler.update()

                total_d_loss += d_loss.item()
                if np.isnan(d_loss.item()):
                    logger.info("Breaking due to NaN Discriminator loss.")
                    break

            for _ in range(n_gen_updates):
                with torch.autocast(device_type=device):
                    fake_waveform = generator(mel)
                    fake_preds, fake_feats = discriminator(fake_waveform)
                    _, real_feats = discriminator(audio)
                
                g_adv_loss = 0
                for fake_pred in fake_preds:
                    g_adv_loss += adversarial_loss(fake_pred, torch.ones_like(fake_pred))

                fm_loss = 0
                num_fmaps = 0
                for real_fmaps, fake_fmaps in zip(real_feats, fake_feats):
                    for real_fmap, fake_fmap in zip(real_fmaps, fake_fmaps):
                        fm_loss += l1_loss(fake_fmap, real_fmap.detach())
                        num_fmaps += 1
                fm_loss = fm_loss / num_fmaps

                recon_loss = l1_loss(fake_waveform, audio)
                g_loss = g_adv_loss + fm_loss_weight * fm_loss + recon_loss_weight * recon_loss

                gen_optim.zero_grad()
                scaler.scale(g_loss).backward()
                scaler.step(gen_optim)
                scaler.update()

                total_g_loss += g_loss.item()
                if np.isnan(g_loss.item()):
                    logger.info("Breaking due to NaN Generator loss.")
                    break

            if logger.getEffectiveLevel() == LIGHT_DEBUG:
                current_batch = b_idx + 1
                print(f"\r{time.strftime('%Y-%m-%d %H:%M:%S')},000 - LIGHT_DEBUG - Batch {current_batch:03d}/{len(data_loader):03d} D/G Loss: {d_loss.item():.3f} {g_loss.item():.3f}", end='', flush=True)
    else:
        if logger.getEffectiveLevel() == LIGHT_DEBUG:
            print(flush=True)

        avg_d_loss = total_d_loss / len(data_loader)
        avg_g_loss = total_g_loss / len(data_loader)
        loss_d_list.append(avg_d_loss)
        loss_g_list.append(avg_g_loss)


        epoch_time = time.time() - start_time
        total_time += epoch_time
        remaining_time = int((total_time / (e + 1)) * (epochs - e - 1))

        logger.info(f"Epoch {e + 1:03d}: Avg. D/G Loss: {avg_d_loss:.4e}, {avg_g_loss:.4e} Remaining Time: {remaining_time // 3600:02d}h {(remaining_time % 3600) // 60:02d}min {round(remaining_time % 60):02d}s LR: {gen_optim.param_groups[0]['lr']:.5e} ")
        
        if checkpoint_freq > 0 and (e + 1) % checkpoint_freq == 0:
            checkpoint_path: str = f"{full_model_path[:-4]}_epoch_{e + 1:03d}.pth"
            torch.save({"generator": generator.state_dict(), "discriminator": discriminator.state_dict(), "gen_optim": gen_optim.state_dict(), "disc_optim": disc_optim.state_dict() , "epoch": e + 1}, checkpoint_path)
            if e + 1 != checkpoint_freq:
                last_path: str = f"{full_model_path[:-4]}_epoch_{(e + 1) - checkpoint_freq:03d}.pth"
                del_if_exists(last_path)
            logger.light_debug(f"Checkpoint saved model to {checkpoint_path}")
        continue
    break


torch.save({"generator": generator.state_dict(), "discriminator": discriminator.state_dict(), "gen_optim": gen_optim.state_dict(), "disc_optim": disc_optim.state_dict() , "epoch": e + 1}, full_model_path)

logger.light_debug(f"Saved model to {full_model_path}")

if checkpoint_freq > 0:
    checkpoint_path: str = f"{full_model_path[:-4]}_epoch_{e + 1 - ((e + 1) % checkpoint_freq):03d}.pth"
    del_if_exists(checkpoint_path)

2025-06-07 07:51:08,749 - INFO - Training started on cuda
2025-06-07 07:51:48,150 - INFO - Epoch 001: Avg. D/G Loss: 3.1395e-01, 5.3380e+01 Remaining Time: 03h 16min 20s LR: 1.00000e-04 
2025-06-07 07:52:28,310 - INFO - Epoch 002: Avg. D/G Loss: 3.0397e-01, 5.2804e+01 Remaining Time: 03h 17min 34s LR: 1.00000e-04 
2025-06-07 07:53:07,992 - INFO - Epoch 003: Avg. D/G Loss: 3.0382e-01, 5.2784e+01 Remaining Time: 03h 16min 44s LR: 1.00000e-04 
2025-06-07 07:53:46,492 - INFO - Epoch 004: Avg. D/G Loss: 2.9763e-01, 5.2760e+01 Remaining Time: 03h 14min 32s LR: 1.00000e-04 
2025-06-07 07:54:27,129 - INFO - Epoch 005: Avg. D/G Loss: 2.9238e-01, 5.3003e+01 Remaining Time: 03h 15min 04s LR: 1.00000e-04 
2025-06-07 07:55:06,638 - INFO - Epoch 006: Avg. D/G Loss: 2.9623e-01, 5.2610e+01 Remaining Time: 03h 14min 11s LR: 1.00000e-04 
2025-06-07 07:55:48,520 - INFO - Epoch 007: Avg. D/G Loss: 3.0015e-01, 5.2692e+01 Remaining Time: 03h 15min 06s LR: 1.00000e-04 
2025-06-07 07:56:28,745 - INFO - Epoch 

### Convert to wave

In [22]:
file_idx: int = 100
with torch.no_grad():
    generated_wave = generator(torch.tensor(mel_data[file_idx]).unsqueeze(0).to(device))
save_audio_file(generated_wave.cpu().numpy()[0,0], "test.wav", 32000)
save_audio_file(audio_data[file_idx], "test_real.wav", 32000)