In [None]:
import os
from pathlib import Path
import math
import random
import json
import time
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import soundfile as sf
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
# reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
BASE_DIR = Path(r"C:/Users/ADMIN/Downloads/SSP/VoiceProject/processed")
TRAIN_META = BASE_DIR / "train_metadata.csv"
TEST_META  = BASE_DIR / "test_metadata.csv"
NORM_MEL_DIR = BASE_DIR / "normalized"   # contains mel .npy
CHECKPOINT_DIR = BASE_DIR / "checkpoints/hifigan"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
RUNS_DIR = "runs/hifigan"

SR = 16000
N_MELS = 80
HOP_LENGTH = 256
WIN_LENGTH = 1024
MEL_MEAN = np.load(BASE_DIR / "mel_mean.npy")
MEL_STD  = np.load(BASE_DIR / "mel_std.npy")

In [None]:
#Training hyperparams
BATCH_SIZE = 4
NUM_WORKERS = 0
PIN_MEMORY = True
EPOCHS = 100
STEPS_PER_EPOCH = None   # set None to run through DataLoader fully
LR_G = 2e-4
LR_D = 2e-4
BETAS = (0.9, 0.999)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = True

print("Device:", DEVICE)
print("Base dir:", BASE_DIR)

In [5]:
def find_mel_file(utt_id: str):
    """Try common filename patterns in normalized dir."""
    candidates = [
        NORM_MEL_DIR / f"{utt_id}_mel.npy",
        NORM_MEL_DIR / f"{utt_id}.npy",
        NORM_MEL_DIR / f"{utt_id}_mel.npy".replace("//","/"),
    ]
    for p in candidates:
        if p.exists():
            return p
    return None

In [6]:
def load_mel_for_vocoder(utt_id: str):
    """Load mel (denormalized) as float32 numpy array shape [T, n_mels]."""
    p = find_mel_file(utt_id)
    if p is None:
        raise FileNotFoundError(f"Mel not found for utt {utt_id}")
    mel = np.load(p).astype(np.float32)
    # denormalize (HiFi-GAN expects mel in original log-power scale)
    mel = mel * MEL_STD + MEL_MEAN
    return mel  # [T, n_mels]

In [7]:
def load_wav(path: str, sr=SR):
    """Load waveform, convert to mono, resample to sr, return numpy float32 [-1,1]."""
    wav, orig_sr = torchaudio.load(path)  # [channels, time]
    if wav.ndim > 1:
        wav = wav.mean(dim=0, keepdim=True)
    wav = wav.squeeze(0).numpy().astype(np.float32)
    if orig_sr != sr:
        wav = torchaudio.functional.resample(torch.from_numpy(wav), orig_sr, sr).numpy()
    # ensure float32 and range [-1,1] (torchaudio does that by default)
    return wav

In [8]:
import pandas as pd

class HiFiGANDataset(Dataset):
    def __init__(self, metadata_csv, mel_dir, wav_mean=None, wav_std=None, sr=SR, max_wav_seconds=None):
        self.df = pd.read_csv(metadata_csv)
        self.mel_dir = Path(mel_dir)
        self.sr = sr
        # filter out missing mels/wavs
        rows = []
        for _, r in self.df.iterrows():
            utt = str(r['utt_id'])
            melp = find_mel_file(utt)
            wavp = Path(r['path'])
            if melp is None or not wavp.exists():
                continue
            rows.append(r)
        self.df = pd.DataFrame(rows).reset_index(drop=True)
        print(f"HiFiGANDataset: {len(self.df)} examples (from {metadata_csv})")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        utt = str(r['utt_id'])
        mel = load_mel_for_vocoder(utt)        # [T, n_mels]
        wav = load_wav(r['path'], sr=self.sr)  # [time]
        # convert types
        mel = torch.from_numpy(mel.T).float() # [n_mels, T]  <-- HiFi-GAN uses [B, n_mels, T]
        wav = torch.from_numpy(wav).float()   # [Twav]
        return mel, wav, utt

In [9]:
def hifigan_collate(batch):
    """Pad mel and wav to longest in batch. Return tensors and lengths."""
    mels, wavs, utts = zip(*batch)
    # mel lengths
    mel_lens = [m.shape[1] for m in mels]
    max_mel = max(mel_lens)
    mel_pad = torch.zeros(len(mels), N_MELS, max_mel, dtype=torch.float32)
    for i, m in enumerate(mels):
        mel_pad[i, :, :m.shape[1]] = m

    # waveform lengths (samples)
    wav_lens = [w.shape[0] for w in wavs]
    max_wav = max(wav_lens)
    wav_pad = torch.zeros(len(wavs), max_wav, dtype=torch.float32)
    for i, w in enumerate(wavs):
        wav_pad[i, :w.shape[0]] = w

    return mel_pad, torch.tensor(mel_lens), wav_pad, torch.tensor(wav_lens), list(utts)

In [10]:
class ResBlock(nn.Module):
    def __init__(self, channels, kernel_size, dilation):
        super().__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=((kernel_size-1)//2)*dilation, dilation=dilation)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, padding=(kernel_size-1)//2)
        self.act = nn.LeakyReLU(0.1)

    def forward(self, x):
        out = self.act(self.conv1(x))
        out = self.conv2(out)
        return self.act(out + x)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    def __init__(
        self,
        mel_channels=80,
        upsample_rates=(8, 8, 2, 2),
        upsample_kernel_sizes=(16, 16, 4, 4),
        resblock_kernel_sizes=(3, 3, 3),
        resblock_dilations=((1, 3, 5), (1, 3, 5), (1, 3, 5))
    ):
        super().__init__()

        # initial projection
        self.conv_in = nn.Conv1d(mel_channels, 512, kernel_size=7, padding=3)

        # upsampling layers
        self.ups = nn.ModuleList()
        in_ch = 512
        for r, k in zip(upsample_rates, upsample_kernel_sizes):
            out_ch = in_ch // 2
            self.ups.append(
                nn.ConvTranspose1d(
                    in_ch,
                    out_ch,
                    kernel_size=k,
                    stride=r,
                    padding=(k - r) // 2
                )
            )
            in_ch = out_ch  # update for next stage

        # residual blocks (1 per upsample stage for simplicity)
        self.resblocks = nn.ModuleList()
        in_ch = 512
        for i in range(len(self.ups)):
            out_ch = in_ch // 2
            for k, d_tuple in zip(resblock_kernel_sizes, resblock_dilations):
                for d in d_tuple:
                    self.resblocks.append(ResBlock(out_ch, k, d))
            in_ch = out_ch

        # final conv
        self.conv_out = nn.Conv1d(in_ch, 1, kernel_size=7, padding=3)
        self.tanh = nn.Tanh()

    def forward(self, m):
        # m: [B, n_mels, Tm]
        x = self.conv_in(m)
        cur = x
        for i, up in enumerate(self.ups):
            cur = up(cur)
            cur = F.relu(cur)
            # (you can apply resblocks here if you want more realism)
        out = self.conv_out(cur)
        return self.tanh(out).squeeze(1)  # [B, Ts]

In [12]:
# Simple multi-scale discriminator
class DiscriminatorP(nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super().__init__()
        self.period = period
        def SNConv(in_ch, out_ch, k, s, p):
            c = nn.Conv1d(in_ch, out_ch, k, s, p)
            return c
        self.convs = nn.ModuleList([
            SNConv(1, 32, 5, 3, 2),
            SNConv(32, 128, 5, 3, 2),
            SNConv(128, 512, 5, 3, 2),
            SNConv(512, 1024, 5, 3, 2),
            SNConv(1024, 1024, 5, 1, 2),
        ])
        self.final = nn.Conv1d(1024, 1, 3, 1, 1)

    def forward(self, x):
        # x: [B, T]
        b, t = x.shape
        if t % self.period != 0:
            pad_len = self.period - (t % self.period)
            x = F.pad(x, (0, pad_len))
            t = t + pad_len
        x = x.view(b, 1, t)
        features = []
        for c in self.convs:
            x = F.leaky_relu(c(x), 0.1)
            features.append(x)
        out = self.final(x)
        features.append(out)
        return out, features

In [13]:
class MultiScaleDiscriminator(nn.Module):
    def __init__(self, periods=(2,3,5)):
        super().__init__()
        self.discriminators = nn.ModuleList([DiscriminatorP(p) for p in periods])

    def forward(self, x):
        outs = []
        feats = []
        for d in self.discriminators:
            out, f = d(x)
            outs.append(out)
            feats.append(f)
        return outs, feats

In [14]:
# losses
adv_criterion = nn.MSELoss()  # LSGAN style
mel_criterion = nn.L1Loss()

def discriminator_loss(real_scores, fake_scores):
    loss = 0.0
    for r, f in zip(real_scores, fake_scores):
        loss += adv_criterion(r, torch.ones_like(r)) + adv_criterion(f, torch.zeros_like(f))
    return loss

def generator_adversarial_loss(fake_scores):
    loss = 0.0
    for f in fake_scores:
        loss += adv_criterion(f, torch.ones_like(f))
    return loss

def feature_matching_loss(real_feats, fake_feats):
    loss = 0.0
    for rf_list, ff_list in zip(real_feats, fake_feats):
        for r_feat, f_feat in zip(rf_list, ff_list):
            loss += F.l1_loss(f_feat, r_feat.detach())
    return loss


In [15]:
def generator_loss(d_fake_mpd, d_fake_msd, fmap_r_mpd, fmap_g_mpd, fmap_r_msd, fmap_g_msd, fake_wav, real_wav):
    """
    Complete generator loss for HiFi-GAN.
    Handles variable length feature maps by truncating to the minimum length.
    """
    # Adversarial loss
    adv_loss = 0.0
    for f_mpd, f_msd in zip(d_fake_mpd, d_fake_msd):
        adv_loss += adv_criterion(f_mpd, torch.ones_like(f_mpd))
        adv_loss += adv_criterion(f_msd, torch.ones_like(f_msd))
    
    # Feature matching loss - handle variable lengths
    fm_loss = 0.0
    
    # For Multi-Period Discriminator features
    for rf_list, gf_list in zip(fmap_r_mpd, fmap_g_mpd):
        for r_feat, g_feat in zip(rf_list, gf_list):
            # Truncate to minimum length
            min_length = min(r_feat.size(-1), g_feat.size(-1))
            r_feat_trunc = r_feat[..., :min_length]
            g_feat_trunc = g_feat[..., :min_length]
            fm_loss += F.l1_loss(g_feat_trunc, r_feat_trunc.detach())
    
    # For Multi-Scale Discriminator features
    for rf_list, gf_list in zip(fmap_r_msd, fmap_g_msd):
        for r_feat, g_feat in zip(rf_list, gf_list):
            # Truncate to minimum length
            min_length = min(r_feat.size(-1), g_feat.size(-1))
            r_feat_trunc = r_feat[..., :min_length]
            g_feat_trunc = g_feat[..., :min_length]
            fm_loss += F.l1_loss(g_feat_trunc, r_feat_trunc.detach())
    
    # Mel-spectrogram reconstruction loss
    mel_fake = mel_spectrogram(fake_wav)
    mel_real = mel_spectrogram(real_wav)
    
    # Match lengths
    min_len = min(mel_fake.size(-1), mel_real.size(-1))
    mel_fake = mel_fake[..., :min_len]
    mel_real = mel_real[..., :min_len]
    
    mel_loss = mel_criterion(mel_fake, mel_real)
    
    # Weighting factors (standard HiFi-GAN weights)
    lambda_adv = 1.0
    lambda_fm = 2.0
    lambda_mel = 45.0
    
    total_loss = (lambda_adv * adv_loss + 
                 lambda_fm * fm_loss + 
                 lambda_mel * mel_loss)
    
    return total_loss

In [16]:
# init models + optimizers
gen = Generator().to(DEVICE)
disc = MultiScaleDiscriminator().to(DEVICE)

opt_g = torch.optim.AdamW(gen.parameters(), lr=LR_G, betas=BETAS)
opt_d = torch.optim.AdamW(disc.parameters(), lr=LR_D, betas=BETAS)
# learning rate schedulers 
scheduler_g = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt_g, mode='min', factor=0.5, patience=5
)
scheduler_d = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt_d, mode='min', factor=0.5, patience=5
)
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

print("Gen params (M):", sum(p.numel() for p in gen.parameters())/1e6)
print("Disc params (M):", sum(p.numel() for p in disc.parameters())/1e6)


Gen params (M): 7.659137
Disc params (M): 24.655299


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [17]:
# dataloaders
train_ds = HiFiGANDataset(TRAIN_META, NORM_MEL_DIR)
val_ds   = HiFiGANDataset(TEST_META, NORM_MEL_DIR)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=hifigan_collate)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, collate_fn=hifigan_collate)


HiFiGANDataset: 29102 examples (from C:\Users\ADMIN\Downloads\SSP\VoiceProject\processed\train_metadata.csv)
HiFiGANDataset: 4134 examples (from C:\Users\ADMIN\Downloads\SSP\VoiceProject\processed\test_metadata.csv)


In [18]:
class PeriodDiscriminator(nn.Module):
    def __init__(self, period=2):
        super().__init__()
        self.period = period
        self.convs = nn.ModuleList([
            nn.Conv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)),
            nn.Conv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)),
            nn.Conv2d(128, 512, (5, 1), (3, 1), padding=(2, 0))
        ])
        self.conv_post = nn.Conv2d(512, 1, (3, 1), 1, padding=(1, 0))

    def forward(self, x):
        fmap = []
        # reshape waveform into [B, 1, T//p, p]
        b, t = x.size(0), x.size(1)
        if t % self.period != 0:
            pad_len = self.period - (t % self.period)
            x = F.pad(x, (0, pad_len), "reflect")
        x = x.view(b, 1, -1, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, 0.1)
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap

In [19]:
class MultiPeriodDiscriminator(nn.Module):
    def __init__(self, periods=[2,3,5,7,11]):
        super().__init__()
        self.discriminators = nn.ModuleList([PeriodDiscriminator(p) for p in periods])

    def forward(self, real, fake):
        real_outputs, fake_outputs, real_feats, fake_feats = [], [], [], []
        for d in self.discriminators:
            r_out, r_feat = d(real)
            f_out, f_feat = d(fake)
            real_outputs.append(r_out)
            fake_outputs.append(f_out)
            real_feats.append(r_feat)
            fake_feats.append(f_feat)
        return real_outputs, fake_outputs, real_feats, fake_feats

In [20]:
class ScaleDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([
            nn.Conv1d(1, 128, 15, 1, padding=7),
            nn.Conv1d(128, 128, 41, 4, groups=4, padding=20),
            nn.Conv1d(128, 256, 41, 4, groups=16, padding=20),
            nn.Conv1d(256, 512, 41, 4, groups=16, padding=20),
            nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20),
            nn.Conv1d(1024, 1024, 5, 1, padding=2)
        ])
        self.conv_post = nn.Conv1d(1024, 1, 3, 1, padding=1)

    def forward(self, x):
        fmap = []
        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, 0.1)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap

In [21]:
class MultiScaleDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList([ScaleDiscriminator() for _ in range(3)])
        self.pools = nn.ModuleList([nn.Identity(), nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)])

    def forward(self, real, fake):
        real_outputs, fake_outputs, real_feats, fake_feats = [], [], [], []
        for pool, d in zip(self.pools, self.discriminators):
            r_out, r_feat = d(pool(real.unsqueeze(1)))
            f_out, f_feat = d(pool(fake.unsqueeze(1)))
            real_outputs.append(r_out)
            fake_outputs.append(f_out)
            real_feats.append(r_feat)
            fake_feats.append(f_feat)
        return real_outputs, fake_outputs, real_feats, fake_feats

In [22]:
# Initialize models
gen = Generator().to(DEVICE)
mpd = MultiPeriodDiscriminator().to(DEVICE)
msd = MultiScaleDiscriminator().to(DEVICE)
# Optimizers
opt_g = torch.optim.AdamW(gen.parameters(), lr=2e-4, betas=(0.8, 0.99))
opt_d = torch.optim.AdamW(list(mpd.parameters()) + list(msd.parameters()), lr=2e-4, betas=(0.8, 0.99))

# Mixed precision
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [23]:
def train_one_epoch(epoch, gen, mpd, msd, train_loader,
                    optimizer_g, optimizer_d, scaler, DEVICE, USE_AMP=True):
    """
    Train HiFi-GAN for one epoch (memory-optimized).
    """
    gen.train()
    mpd.train()
    msd.train()

    running_g, running_d = 0.0, 0.0
    print(f"\n[Epoch {epoch}] Training...")
    print_cuda_memory("Start of Epoch")

    for batch_idx, batch in enumerate(train_loader):
        # Unpack all 5 values from the collate function
        mel_pad, mel_len, wav_pad, wav_len, utts = batch

        mel_pad = mel_pad.to(DEVICE, non_blocking=True)
        wav_pad = wav_pad.to(DEVICE, non_blocking=True)

        if isinstance(mel_len, torch.Tensor):
            mel_len = mel_len.to(DEVICE, non_blocking=True)
        if isinstance(wav_len, torch.Tensor):
            wav_len = wav_len.to(DEVICE, non_blocking=True)

        optimizer_g.zero_grad(set_to_none=True)
        optimizer_d.zero_grad(set_to_none=True)

        # Train Discriminator
        with torch.amp.autocast("cuda", enabled=USE_AMP):
            fake_wav = gen(mel_pad)
            real_wav = wav_pad

            # Ensure lengths match by truncating
            min_length = min(fake_wav.size(-1), real_wav.size(-1))
            fake_wav = fake_wav[..., :min_length]
            real_wav = real_wav[..., :min_length]

            d_real_mpd, d_fake_mpd, _, _ = mpd(real_wav, fake_wav.detach())
            d_real_msd, d_fake_msd, _, _ = msd(real_wav, fake_wav.detach())
            d_loss = (discriminator_loss(d_real_mpd, d_fake_mpd) +
                      discriminator_loss(d_real_msd, d_fake_msd))

        # Scale and backward for discriminator
        scaler.scale(d_loss).backward()
        scaler.step(optimizer_d)  # Update discriminator weights

        # Train Generator
        with torch.amp.autocast("cuda", enabled=USE_AMP):
            d_real_mpd, d_fake_mpd, fmap_r_mpd, fmap_g_mpd = mpd(real_wav, fake_wav)
            d_real_msd, d_fake_msd, fmap_r_msd, fmap_g_msd = msd(real_wav, fake_wav)

            g_loss = generator_loss(
                d_fake_mpd, d_fake_msd,
                fmap_r_mpd, fmap_g_mpd,
                fmap_r_msd, fmap_g_msd,
                fake_wav, real_wav
            )

        # Scale and backward for generator
        scaler.scale(g_loss).backward()
        scaler.step(optimizer_g)  # Update generator weights
        
        # Update the scaler only once per batch
        scaler.update()

        #   Logging + Cleanup
        running_d += d_loss.item()
        running_g += g_loss.item()

        if batch_idx % 50 == 0:
            avg_g = running_g / (batch_idx + 1)
            avg_d = running_d / (batch_idx + 1)
            print(f"  Batch {batch_idx:04d} | g_loss={avg_g:.4f}, d_loss={avg_d:.4f}")
            print_cuda_memory(f"After batch {batch_idx}")

        # Free unused references (light cleanup)
        del fake_wav, real_wav

    print_cuda_memory("End of Epoch")

    return {
        "g_loss": running_g / len(train_loader),
        "d_loss": running_d / len(train_loader),
    }

@torch.no_grad()
def validate_epoch(epoch, gen, val_loader, DEVICE, USE_AMP=True):
    """
    Validate HiFi-GAN for one epoch (memory-optimized).
    Only computes mel-spectrogram reconstruction loss (L1).
    """
    gen.eval()
    val_loss = 0.0

    print(f"\n[Epoch {epoch}] Validation...")
    print_cuda_memory("Start of Validation")

    for batch_idx, batch in enumerate(val_loader):
        # Unpack all 5 values from the collate function
        mel_pad, mel_len, wav_pad, wav_len, utts = batch  # ← Added utts here
        mel_pad = mel_pad.to(DEVICE, non_blocking=True)
        wav_pad = wav_pad.to(DEVICE, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=USE_AMP):
            fake_wav = gen(mel_pad)

            # Match lengths
            min_len = min(fake_wav.size(-1), wav_pad.size(-1))
            fake_wav = fake_wav[..., :min_len]
            real_wav = wav_pad[..., :min_len]

            # Compute mel loss
            mel_fake = mel_spectrogram(fake_wav)
            mel_real = mel_spectrogram(real_wav)
            loss = F.l1_loss(mel_fake, mel_real)

        val_loss += loss.item()

        if batch_idx % 20 == 0:
            avg_loss = val_loss / (batch_idx + 1)
            print(f"  Val batch {batch_idx:04d} | mel_loss={avg_loss:.4f}")
            print_cuda_memory(f"After val batch {batch_idx}")

        # cleanup
        del fake_wav, real_wav, mel_fake, mel_real

    val_loss /= len(val_loader)
    print(f"[Epoch {epoch}] Validation Loss: {val_loss:.4f}")
    print_cuda_memory("End of Validation")

    return {"val_loss": val_loss}

In [24]:
from tqdm import tqdm

def train_one_epoch(epoch, gen, mpd, msd, train_loader,
                    optimizer_g, optimizer_d, scaler, DEVICE, USE_AMP=True):
    """
    Train HiFi-GAN for one epoch with progress bar.
    """
    gen.train()
    mpd.train()
    msd.train()

    running_g, running_d = 0.0, 0.0
    print(f"\n[Epoch {epoch}] Training...")
    print_cuda_memory("Start of Epoch")

    # Add progress bar
    pbar = tqdm(train_loader, desc=f"Epoch {epoch} Training", leave=False)
    
    for batch_idx, batch in enumerate(pbar):
        # Unpack all 5 values from the collate function
        mel_pad, mel_len, wav_pad, wav_len, utts = batch

        mel_pad = mel_pad.to(DEVICE, non_blocking=True)
        wav_pad = wav_pad.to(DEVICE, non_blocking=True)

        if isinstance(mel_len, torch.Tensor):
            mel_len = mel_len.to(DEVICE, non_blocking=True)
        if isinstance(wav_len, torch.Tensor):
            wav_len = wav_len.to(DEVICE, non_blocking=True)

        optimizer_g.zero_grad(set_to_none=True)
        optimizer_d.zero_grad(set_to_none=True)

        # Train Discriminator
        with torch.amp.autocast("cuda", enabled=USE_AMP):
            fake_wav = gen(mel_pad)
            real_wav = wav_pad

            # Ensure lengths match by truncating
            min_length = min(fake_wav.size(-1), real_wav.size(-1))
            fake_wav = fake_wav[..., :min_length]
            real_wav = real_wav[..., :min_length]

            d_real_mpd, d_fake_mpd, _, _ = mpd(real_wav, fake_wav.detach())
            d_real_msd, d_fake_msd, _, _ = msd(real_wav, fake_wav.detach())
            d_loss = (discriminator_loss(d_real_mpd, d_fake_mpd) +
                      discriminator_loss(d_real_msd, d_fake_msd))

        # Scale and backward for discriminator
        scaler.scale(d_loss).backward()
        scaler.step(optimizer_d)

        # Train Generator
        with torch.amp.autocast("cuda", enabled=USE_AMP):
            d_real_mpd, d_fake_mpd, fmap_r_mpd, fmap_g_mpd = mpd(real_wav, fake_wav)
            d_real_msd, d_fake_msd, fmap_r_msd, fmap_g_msd = msd(real_wav, fake_wav)

            g_loss = generator_loss(
                d_fake_mpd, d_fake_msd,
                fmap_r_mpd, fmap_g_mpd,
                fmap_r_msd, fmap_g_msd,
                fake_wav, real_wav
            )

        # Scale and backward for generator
        scaler.scale(g_loss).backward()
        scaler.step(optimizer_g)
        scaler.update()

        # Update running losses
        running_d += d_loss.item()
        running_g += g_loss.item()

        # Update progress bar
        avg_g = running_g / (batch_idx + 1)
        avg_d = running_d / (batch_idx + 1)
        pbar.set_postfix({
            'g_loss': f'{avg_g:.4f}',
            'd_loss': f'{avg_d:.4f}',
            'gpu_mem': f'{torch.cuda.memory_allocated()/1024**3:.1f}GB'
        })

        # Free unused references
        del fake_wav, real_wav

    pbar.close()
    print_cuda_memory("End of Epoch")

    return {
        "g_loss": running_g / len(train_loader),
        "d_loss": running_d / len(train_loader),
    }

@torch.no_grad()
def validate_epoch(epoch, gen, val_loader, DEVICE, USE_AMP=True):
    """
    Validate HiFi-GAN for one epoch with progress bar.
    """
    gen.eval()
    val_loss = 0.0

    print(f"\n[Epoch {epoch}] Validation...")
    print_cuda_memory("Start of Validation")

    # Add progress bar for validation
    pbar = tqdm(val_loader, desc=f"Epoch {epoch} Validation", leave=False)
    
    for batch_idx, batch in enumerate(pbar):
        # Unpack all 5 values from the collate function
        mel_pad, mel_len, wav_pad, wav_len, utts = batch
        mel_pad = mel_pad.to(DEVICE, non_blocking=True)
        wav_pad = wav_pad.to(DEVICE, non_blocking=True)

        with torch.amp.autocast("cuda", enabled=USE_AMP):
            fake_wav = gen(mel_pad)

            # Match lengths
            min_len = min(fake_wav.size(-1), wav_pad.size(-1))
            fake_wav = fake_wav[..., :min_len]
            real_wav = wav_pad[..., :min_len]

            # Compute mel loss
            mel_fake = mel_spectrogram(fake_wav)
            mel_real = mel_spectrogram(real_wav)
            loss = F.l1_loss(mel_fake, mel_real)

        val_loss += loss.item()

        # Update progress bar
        avg_loss = val_loss / (batch_idx + 1)
        pbar.set_postfix({
            'val_loss': f'{avg_loss:.4f}',
            'gpu_mem': f'{torch.cuda.memory_allocated()/1024**3:.1f}GB'
        })

        # cleanup
        del fake_wav, real_wav, mel_fake, mel_real

    pbar.close()
    val_loss /= len(val_loader)
    print(f"[Epoch {epoch}] Validation Loss: {val_loss:.4f}")
    print_cuda_memory("End of Validation")

    return {"val_loss": val_loss}

In [25]:
# Optimizers and Scaler 
LEARNING_RATE = 2e-4  # default for HiFi-GAN
betas = (0.8, 0.99)

optimizer_g = torch.optim.AdamW(gen.parameters(), lr=LEARNING_RATE, betas=betas)
optimizer_d = torch.optim.AdamW(
    list(mpd.parameters()) + list(msd.parameters()), 
    lr=LEARNING_RATE, 
    betas=betas
)

# Mixed precision scaler
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


  scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)


In [26]:
def mel_spectrogram(wav):
    """Convert waveform to mel-spectrogram using the same parameters as preprocessing."""
    # You'll need to implement this using torchaudio or your preferred method
    # This should match the parameters used during data preprocessing
    # (SR=16000, N_MELS=80, HOP_LENGTH=256, WIN_LENGTH=1024)
    
    # Example implementation using torchaudio:
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=SR,
        n_fft=WIN_LENGTH,
        win_length=WIN_LENGTH,
        hop_length=HOP_LENGTH,
        n_mels=N_MELS,
        center=True
    ).to(wav.device)
    
    mel = mel_transform(wav)
    mel = torch.log(torch.clamp(mel, min=1e-5))  # Convert to log scale
    return mel

In [None]:
import torch
import time
from pathlib import Path

# Directory to save checkpoints
CHECKPOINT_DIR = Path("./checkpoints")
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Checkpoint Helpers
def save_checkpoint(epoch, gen, mpd, msd, optimizer_g, optimizer_d, scaler, path):
    state = {
        "epoch": epoch,
        "gen": gen.state_dict(),
        "mpd": mpd.state_dict(),
        "msd": msd.state_dict(),
        "optimizer_g": optimizer_g.state_dict(),
        "optimizer_d": optimizer_d.state_dict(),
        "scaler": scaler.state_dict()
    }
    torch.save(state, path)
    print(f" Checkpoint saved at {path}")

def load_checkpoint(path, gen, mpd, msd, optimizer_g, optimizer_d, scaler, device):
    checkpoint = torch.load(path, map_location=device)
    gen.load_state_dict(checkpoint["gen"])
    mpd.load_state_dict(checkpoint["mpd"])
    msd.load_state_dict(checkpoint["msd"])
    optimizer_g.load_state_dict(checkpoint["optimizer_g"])
    optimizer_d.load_state_dict(checkpoint["optimizer_d"])
    scaler.load_state_dict(checkpoint["scaler"])
    start_epoch = checkpoint["epoch"] + 1
    print(f" Loaded checkpoint from {path}, resuming at epoch {start_epoch}")
    return start_epoch

def print_cuda_memory(tag=""):
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / (1024 ** 2)
        reserved = torch.cuda.memory_reserved() / (1024 ** 2)
        print(f"[{tag}] CUDA Memory - Allocated: {allocated:.2f} MB | Reserved: {reserved:.2f} MB")
    else:
        print(f"[{tag}] CUDA not available")

#  Training Loop
start_epoch = 1
resume_ckpt = CHECKPOINT_DIR / "checkpoint_last.pt"  # change if needed

# Resume if checkpoint exists
if resume_ckpt.exists():
    start_epoch = load_checkpoint(resume_ckpt, gen, mpd, msd, optimizer_g, optimizer_d, scaler, DEVICE)
# Early stopping parameters
PATIENCE = 10
MIN_DELTA = 0.001
best_val_loss = float('inf')
patience_counter = 0
early_stop = False

# Training Loop with progress bars and schedulers
for epoch in range(start_epoch, EPOCHS+1):
    if early_stop:
        print(f"Early stopping triggered at epoch {epoch}")
        break
        
    t0 = time.time()

    # Training with progress bar
    train_stats = train_one_epoch(
        epoch,
        gen,
        mpd,
        msd,
        train_loader,
        optimizer_g,
        optimizer_d,
        scaler,
        DEVICE,
        USE_AMP
    )

    # Validation with progress bar
    val_stats = validate_epoch(
        epoch,
        gen,
        val_loader,
        DEVICE,
        USE_AMP
    )
    
    current_val_loss = val_stats["val_loss"]
    
    # Update learning rate schedulers
    scheduler_g.step(current_val_loss)
    scheduler_d.step(current_val_loss)
    
    dt = time.time() - t0
    
    print(f" Epoch {epoch} finished in {dt/60:.2f} min | "
          f"G Loss: {train_stats['g_loss']:.4f} | D Loss: {train_stats['d_loss']:.4f} | "
          f"Val Loss: {current_val_loss:.4f}")

    # Early stopping logic
    if current_val_loss < best_val_loss - MIN_DELTA:
        best_val_loss = current_val_loss
        patience_counter = 0
        best_ckpt_path = CHECKPOINT_DIR / "checkpoint_best.pt"
        save_checkpoint(epoch, gen, mpd, msd, optimizer_g, optimizer_d, scaler, best_ckpt_path)
        print(f"✓ New best validation loss: {best_val_loss:.4f}")
    else:
        patience_counter += 1
        print(f"✗ Validation loss didn't improve. Patience: {patience_counter}/{PATIENCE}")
        if patience_counter >= PATIENCE:
            early_stop = True
            print("Early stopping triggered!")

    # Save checkpoints
    ckpt_path = CHECKPOINT_DIR / f"checkpoint_epoch{epoch}.pt"
    save_checkpoint(epoch, gen, mpd, msd, optimizer_g, optimizer_d, scaler, ckpt_path)
    save_checkpoint(epoch, gen, mpd, msd, optimizer_g, optimizer_d, scaler, resume_ckpt)

 Loaded checkpoint from checkpoints\checkpoint_last.pt, resuming at epoch 2

[Epoch 2] Training...
[Start of Epoch] CUDA Memory - Allocated: 444.89 MB | Reserved: 580.00 MB


Epoch 2 Training:   0%|              | 4/7276 [00:03<1:16:17,  1.59it/s, g_loss=133.3203, d_loss=2.7370, gpu_mem=4.8GB]