# Main Script HiFiGAN

### Imports

In [1]:
try: 
    import librosa
except:
    !pip install librosa
try: 
    import optuna, plotly
except:
    !pip install optuna
    !pip install plotly

#Set Dir 
import sys, os
sys.path.append(os.path.abspath('..'))

# Torch
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader, Subset
from torch.nn.utils import weight_norm 
from torch.nn import utils
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchaudio.transforms as T
import optuna, plotly
from optuna.importance import get_param_importances
from optuna.visualization import plot_param_importances

# Utils
import numpy as np
from numpy import ndarray
import logging, librosa, itertools, tensorboard
from typing import Sequence, Optional, Callable


# Base Scripts
from Libraries.Utils import *
from MainScripts.Conf import conf

[0mCollecting plotly
  Downloading plotly-6.1.2-py3-none-any.whl.metadata (6.9 kB)
Collecting narwhals>=1.15.1 (from plotly)
  Downloading narwhals-1.43.0-py3-none-any.whl.metadata (11 kB)
Downloading plotly-6.1.2-py3-none-any.whl (16.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m113.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading narwhals-1.43.0-py3-none-any.whl (362 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m362.7/362.7 kB[0m [31m75.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: narwhals, plotly
Successfully installed narwhals-1.43.0 plotly-6.1.2
[0m

2025-06-16 17:56:57.327530: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-16 17:56:57.327588: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-16 17:56:57.328758: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-16 17:56:57.334814: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Config

### General

In [2]:
remote_kernel: bool = True #Set to true if using a remote Kernel changes the file structure
model_name: str = "HiFiGAN_v1"
training_data_name: str = "training_full_mel"
training_label_name: str = "training_full_wave"
full_model_path: str = path_to_remote_path("{}/{}".format(conf["paths"].model_path, model_name + ".pth"), remote_kernel)
sw = SummaryWriter(path_to_remote_path("{}/{}".format(conf["paths"].model_path, 'logs'), remote_kernel))

Logging

In [3]:
logging_level: int = LIGHT_DEBUG
logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
logger: logging.Logger = logging.getLogger(__name__)

Training Params

In [4]:
device: str = "cuda" if torch.cuda.is_available() else "cpu"
n_training_samples: int = 2496 // 2
batch_size: int = 16
tensor_wave_dim: list = [batch_size, 1, 2**17] #B, C, H = Batch, channels, Time domain
tensor_mel_dim: list = [batch_size, 96, 512]
learning_rate: float = 1e-4
b1, b2 = [0.7, 0.99]
epochs: int = 300
restart_training: bool = True
checkpoint_freq: int = 5
num_workers: int = 4


### Data Loading

In [5]:
mel_data: ndarray = load_training_data(path_to_remote_path("{}/{}".format(conf["paths"].data_path, training_data_name + ".npy"), remote_kernel))
audio_data: ndarray = load_training_data(path_to_remote_path("{}/{}".format(conf["paths"].data_path, training_label_name + ".npy"), remote_kernel))

np.random.seed(50)
indicies: ndarray = np.arange(mel_data.shape[0])
np.random.shuffle(indicies)
mel_data = mel_data[indicies]
audio_data = audio_data[indicies]

data_loader = create_dataloader(Audio_Data(mel_data[:n_training_samples], audio_data[:n_training_samples]), batch_size, num_workers)


2025-06-16 17:57:00,863 - LIGHT_DEBUG - Ndarray loaded from Data/training_full_mel.npy of shape: (6867, 96, 512)
2025-06-16 17:57:02,183 - LIGHT_DEBUG - Ndarray loaded from Data/training_full_wave.npy of shape: (6867, 131072)


### Original Impl

In [6]:
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class ResBlock1(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                                padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                                padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                                padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, 0.1)
            xt = c1(xt)
            xt = F.leaky_relu(xt, 0.1)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            utils.remove_weight_norm(l)
        for l in self.convs2:
            utils.remove_weight_norm(l)


class ResBlock2(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                                padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                                padding=get_padding(kernel_size, dilation[1])))
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        for c in self.convs:
            xt = F.leaky_relu(x, 0.1)
            xt = c(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs:
            utils.remove_weight_norm(l)

class Generator(nn.Module):
    def __init__(self, n_mel_channels: int, resblock_kernel_sizes: list[int], upsample_rates: list[int], upsample_initial_channel: int, upsample_kernel_sizes: list[int], resblock_dilation_sizes: list[int], resblock: int = 1):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = weight_norm(nn.Conv1d(n_mel_channels, upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if resblock == '1' else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, 0.1)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        logger.light_debug('Removing weight norm...')
        for l in self.ups:
            utils.remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        utils.remove_weight_norm(self.conv_pre)
        utils.remove_weight_norm(self.conv_post)


class DiscriminatorP(nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else utils.spectral_norm
        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
            norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
        ])
        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # 1d to 2d
        b, c, t = x.shape
        if t % self.period != 0: # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiPeriodDiscriminator(nn.Module):
    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorP(2),
            DiscriminatorP(3),
            DiscriminatorP(5),
            DiscriminatorP(7),
            DiscriminatorP(11),
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm == False else utils.spectral_norm
        self.convs = nn.ModuleList([
            norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
            norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
            norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
            norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
            norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
            norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
            norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class MultiScaleDiscriminator(torch.nn.Module):
    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList([
            DiscriminatorS(use_spectral_norm=True),
            DiscriminatorS(),
            DiscriminatorS(),
        ])
        self.meanpools = nn.ModuleList([
            nn.AvgPool1d(4, 2, padding=2),
            nn.AvgPool1d(4, 2, padding=2)
        ])

    def forward(self, y, y_hat):
        y_d_rs = []
        y_d_gs = []
        fmap_rs = []
        fmap_gs = []
        for i, d in enumerate(self.discriminators):
            if i != 0:
                y = self.meanpools[i-1](y)
                y_hat = self.meanpools[i-1](y_hat)
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl - gl))

    return loss*2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1-dr)**2)
        g_loss = torch.mean(dg**2)
        loss += (r_loss + g_loss)
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1-dg)**2)
        gen_losses.append(l)
        loss += l

    return loss, gen_losses

In [7]:
generator = Generator(
                    n_mel_channels=96,
                    upsample_rates=[8,8,2,2],
                    upsample_kernel_sizes=[16,16,4,4],
                    upsample_initial_channel=512,
                    resblock_kernel_sizes=[3,7,11],
                    resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
                    resblock=1
                ).to(device)
mpd = MultiPeriodDiscriminator().to(device)
msd = MultiScaleDiscriminator().to(device)



In [8]:
optim_g = torch.optim.AdamW(generator.parameters(), learning_rate, betas=[b1, b2])
optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                learning_rate, betas=[b1, b2])
gen_lr_scheduler = optim.lr_scheduler.ExponentialLR(optim_g, gamma=0.999)
disc_lr_scheduler = optim.lr_scheduler.ExponentialLR(optim_d, gamma=0.999)
start_epoch: int = 0
if os.path.exists(full_model_path):
    model = torch.load(full_model_path, map_location=device)
    generator.load_state_dict(model["generator"])
    msd.load_state_dict(model["msd"])
    mpd.load_state_dict(model["mpd"])
    if not restart_training:
        optim_g.load_state_dict(model["optim_g"])
        optim_d.load_state_dict(model["optim_d"])
        start_epoch = model.get("epoch", 0)
    logger.info(f"Model {model_name} loaded with {count_parameters(generator)} G and {count_parameters(mpd)}, {count_parameters(msd)} D Parameters")
else: 
    logger.info(f"Model {model_name} loaded with {count_parameters(generator)} G and {count_parameters(mpd)}, {count_parameters(msd)} D Parameters")

2025-06-16 17:57:05,968 - INFO - Model HiFiGAN_v1 loaded with ~6.671M G and ~41.10M, ~29.61M D Parameters


#### Optuna

In [9]:
def static_model() -> nn.Module:
    generator = Generator(
                    n_mel_channels=96,
                    upsample_rates=[8,8,2,2],
                    upsample_kernel_sizes=[16,16,4,4],
                    upsample_initial_channel=512,
                    resblock_kernel_sizes=[3,7,11],
                    resblock_dilation_sizes=[[1,3,5], [1,3,5], [1,3,5]],
                    resblock=1
                ).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)
    return generator, mpd, msd

def objective(trial: optuna.Trial) -> float:
    lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
    b1 = trial.suggest_float("b1", 0.4, 0.99)
    b2 = trial.suggest_float("b2", 0.4, 0.999)
    lr_decay = trial.suggest_float("lr_decay", 0.7, 0.99999)
    generator, mpd, msd = static_model()
    optim_g = torch.optim.AdamW(generator.parameters(), lr, betas=[b1, b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                    lr, betas=[b1, b2])
    gen_lr_scheduler = optim.lr_scheduler.ExponentialLR(optim_g, gamma=lr_decay)
    disc_lr_scheduler = optim.lr_scheduler.ExponentialLR(optim_d, gamma=lr_decay)

    
    n_epochs = 10
    best_reconst_loss: float = float('inf')
    generator.train()
    mpd.train()
    msd.train()
    for e in range(0, n_epochs):
        total_reconst_loss: float = 0

        for b_idx, (mel, audio) in enumerate(data_loader):
            mel, audio = mel.to(device), audio.to(device).unsqueeze(1)
            with torch.autocast(device_type=device):
                generated_audio = generator(mel)

            generated_mel = T.MelSpectrogram(sample_rate = 32000, n_fft=1023, hop_length=256, n_mels=96, f_min=30).to(device)(generated_audio.squeeze(1))
            optim_d.zero_grad()

            with torch.autocast(device_type=device):
                real_mpd_scores, fake_mpd_scores, _, _ = mpd(audio, generated_audio.detach())
            loss_d_s, mpd_loss_real, mpd_loss_fake = discriminator_loss(real_mpd_scores, fake_mpd_scores)
            
            with torch.autocast(device_type=device):
                real_msd_scores, fake_msd_scores, _, _ = msd(audio, generated_audio.detach())
            loss_d_f, msd_loss_real, msd_loss_fake = discriminator_loss(real_msd_scores, fake_msd_scores)

            total_disc_loss: Tensor = loss_d_s + loss_d_f

            total_disc_loss.backward()
            optim_d.step()

            optim_g.zero_grad()

            mel_loss = F.l1_loss(mel, generated_mel) * 45

            with torch.autocast(device_type=device):
                real_mpd_scores, fake_mpd_scores, real_mpd_features, fake_mpd_features = mpd(audio, generated_audio)
                real_msd_scores, fake_msd_scores, real_msd_features, fake_msd_features = msd(audio, generated_audio)
            
            mpd_feature_loss = feature_loss(real_mpd_features, fake_mpd_features)
            msd_feature_loss = feature_loss(real_msd_features, fake_msd_features)
            mpd_gen_loss, _ = generator_loss(fake_mpd_scores)
            msd_gen_loss, _ = generator_loss(fake_msd_scores)
            
            total_reconst_loss += (F.l1_loss(audio, generated_audio) + mel_loss / 45) / 2

            total_gen_loss = mpd_gen_loss + msd_gen_loss + mpd_feature_loss + msd_feature_loss + mel_loss

            total_gen_loss.backward()
            optim_g.step()

            avg_reconst_loss = total_reconst_loss / len(data_loader)


            gen_lr_scheduler.step()
            disc_lr_scheduler.step()

    
        if avg_reconst_loss < best_reconst_loss:
            best_reconst_loss = avg_reconst_loss
        trial.report(avg_reconst_loss, e)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return best_reconst_loss

def run_optim(n_trials: int, name: str ="main_study_wave") -> None:
    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner(),study_name=name)
    study.optimize(objective, n_trials=n_trials)
    logger.info("Finished Study")
    logger.info(f"Best trial: {study.best_trial} with value: {study.best_trial.value} using params:")
    for key, val in study.best_trial.params.items():
        logger.info(f"{key}:{val}")
    
    logger.info("Param importance:")
    for param, importance in get_param_importances(study).items():
            logger.info(f"{param}: {importance:.4f}")
    fig = plot_param_importances(study)
    fig.show()

In [None]:
run_optim(25, "main_study")

[I 2025-06-16 17:57:06,176] A new study created in memory with name: main_study
[I 2025-06-16 18:19:40,773] Trial 0 finished with value: 0.23749111592769623 and parameters: {'lr': 0.0006364028564452524, 'b1': 0.7872070788505228, 'b2': 0.5387522714858821, 'lr_decay': 0.8797010728452869}. Best is trial 0 with value: 0.23749111592769623.
[I 2025-06-16 18:42:15,557] Trial 1 finished with value: 0.2389201670885086 and parameters: {'lr': 0.0006317571936581328, 'b1': 0.9317838884335691, 'b2': 0.6985585219527081, 'lr_decay': 0.9189079868516059}. Best is trial 0 with value: 0.23749111592769623.
[I 2025-06-16 19:04:53,468] Trial 2 finished with value: 0.23809592425823212 and parameters: {'lr': 0.00015222819211399175, 'b1': 0.8835259616332495, 'b2': 0.6759253199716412, 'lr_decay': 0.7871656983646544}. Best is trial 0 with value: 0.23749111592769623.
[I 2025-06-16 19:27:31,760] Trial 3 finished with value: 0.23827217519283295 and parameters: {'lr': 0.00035883244083511793, 'b1': 0.6676474017690415,

In [None]:
torch.backends.cudnn.benchmark = True

In [None]:
logger.info(f"Training started on {device}")
loss_d_list: list = []
loss_g_list: list = []
total_time: float = 0.0

generator.train()
mpd.train()
msd.train()
for e in range(0, epochs):
    total_d_loss: float = 0
    total_g_loss: float = 0
    total_reconst_loss: float = 0
    start_time: float = time.time()

    for b_idx, (mel, audio) in enumerate(data_loader):
            mel, audio = mel.to(device), audio.to(device).unsqueeze(1)
            with torch.autocast(device_type=device):
                generated_audio = generator(mel)

            generated_mel = T.MelSpectrogram(sample_rate = 32000, n_fft=1023, hop_length=256, n_mels=96, f_min=30).to(device)(generated_audio.squeeze(1))
            optim_d.zero_grad()

            with torch.autocast(device_type=device):
                real_mpd_scores, fake_mpd_scores, _, _ = mpd(audio, generated_audio.detach())
            loss_d_s, mpd_loss_real, mpd_loss_fake = discriminator_loss(real_mpd_scores, fake_mpd_scores)
            
            with torch.autocast(device_type=device):
                real_msd_scores, fake_msd_scores, _, _ = msd(audio, generated_audio.detach())
            loss_d_f, msd_loss_real, msd_loss_fake = discriminator_loss(real_msd_scores, fake_msd_scores)

            total_disc_loss: Tensor = loss_d_s + loss_d_f

            total_disc_loss.backward()
            optim_d.step()

            optim_g.zero_grad()

            mel_loss = F.l1_loss(mel, generated_mel) * 45

            with torch.autocast(device_type=device):
                real_mpd_scores, fake_mpd_scores, real_mpd_features, fake_mpd_features = mpd(audio, generated_audio)
                real_msd_scores, fake_msd_scores, real_msd_features, fake_msd_features = msd(audio, generated_audio)
            
            mpd_feature_loss = feature_loss(real_mpd_features, fake_mpd_features)
            msd_feature_loss = feature_loss(real_msd_features, fake_msd_features)
            mpd_gen_loss, _ = generator_loss(fake_mpd_scores)
            msd_gen_loss, _ = generator_loss(fake_msd_scores)
            
            total_reconst_loss += F.l1_loss(audio, generated_audio)

            total_gen_loss = mpd_gen_loss + msd_gen_loss + mpd_feature_loss + msd_feature_loss + mel_loss

            total_gen_loss.backward()
            optim_g.step()

            if logger.getEffectiveLevel() == LIGHT_DEBUG:
                current_batch = b_idx + 1
                print(f"\r{time.strftime('%Y-%m-%d %H:%M:%S')},000 - LIGHT_DEBUG - Batch {current_batch:03d}/{len(data_loader):03d} D/G Loss: {total_disc_loss.item():.3f} {total_gen_loss.item():.3f}", end='', flush=True)
    else:
        if logger.getEffectiveLevel() == LIGHT_DEBUG:
            print(flush=True)

        avg_d_loss = total_disc_loss / len(data_loader)
        avg_g_loss = total_gen_loss / len(data_loader)
        avg_reconst_loss = total_reconst_loss / len(data_loader)
        loss_d_list.append(avg_d_loss)
        loss_g_list.append(avg_g_loss)
        if gen_lr_scheduler is not None:
            gen_lr_scheduler.step()
        if disc_lr_scheduler is not None:
            disc_lr_scheduler.step()

        sw.add_scalar("training/gen_loss", avg_g_loss, e)
        sw.add_scalar("training/disc_loss", avg_d_loss, e)
        sw.add_scalar("training/reconstr_loss", avg_reconst_loss, e)
        sw.add_scalar("training/lr", optim_g.param_groups[0]["lr"], e)
        sw.flush()
        
        epoch_time = time.time() - start_time
        total_time += epoch_time
        remaining_time = int((total_time / (e + 1)) * (epochs - e - 1))

        logger.info(f"Epoch {e + 1:03d}: Avg. D/G Loss: {avg_d_loss:.4e}, {avg_g_loss:.4e} Avg. reconst. Loss: {avg_reconst_loss:.4e} Remaining Time: {remaining_time // 3600:02d}h {(remaining_time % 3600) // 60:02d}min {round(remaining_time % 60):02d}s LR: {optim_d.param_groups[0]['lr']:.5e} ")
        
        if checkpoint_freq > 0 and (e + 1) % checkpoint_freq == 0:
            checkpoint_path: str = f"{full_model_path[:-4]}_epoch_{e + 1:03d}.pth"
            torch.save({"generator": generator.state_dict(), "msd": msd.state_dict(), "mpd": mpd.state_dict(), "optim_g": optim_g.state_dict(), "optim_d": optim_d.state_dict() , "epoch": e + 1}, checkpoint_path)
            if e + 1 != checkpoint_freq:
                last_path: str = f"{full_model_path[:-4]}_epoch_{(e + 1) - checkpoint_freq:03d}.pth"
                del_if_exists(last_path)
            logger.light_debug(f"Checkpoint saved model to {checkpoint_path}")
        continue


torch.save({"generator": generator.state_dict(), "msd": msd.state_dict(), "mpd": mpd.state_dict(), "optim_g": optim_g.state_dict(), "optim_d": optim_d.state_dict() , "epoch": e + 1}, full_model_path)

logger.light_debug(f"Saved model to {full_model_path}")

if checkpoint_freq > 0:
    checkpoint_path: str = f"{full_model_path[:-4]}_epoch_{e + 1 - ((e + 1) % checkpoint_freq):03d}.pth"
    del_if_exists(checkpoint_path)

scatter_plot(loss_d_list)
scatter_plot(loss_g_list)

2025-06-16 17:40:23,964 - INFO - Training started on cuda


2025-06-16 17:40:25,000 - LIGHT_DEBUG - Batch 001/078 D/G Loss: 1.570 123.769

KeyboardInterrupt: 

### Convert to wave

In [None]:
file_idx: int = 4000
with torch.no_grad():
    generated_wave = generator(torch.tensor(mel_data[file_idx]).unsqueeze(0).to(device))
save_audio_file(generated_wave.cpu().numpy()[0,0], "test3.wav", 32000)
#save_audio_file(librosa.feature.inverse.mel_to_audio(mel_data[file_idx], n_fft=1023, hop_length=256, sr=32000), "test_gl.wav", 32000)
#save_audio_file(audio_data[file_idx], "test_real.wav", 32000)

2025-06-16 17:39:08,196 - LIGHT_DEBUG - Normalized to range: [-0.99999,0.99999]
2025-06-16 17:39:08,210 - LIGHT_DEBUG - Saved file to:test3.wav


In [None]:
spect = load_spectrogram("spect.npz")
with torch.no_grad():
    generated_wave = generator(torch.tensor(spect).unsqueeze(0).to(device))
save_audio_file(generated_wave.cpu().numpy()[0,0], "muGen_out2.wav", 32000)
save_audio_file(librosa.feature.inverse.mel_to_audio(spect, n_fft=1023, hop_length=256, sr=32000), "muGen_out1_gl.wav", 32000)