<a href="https://colab.research.google.com/github/anthony2850/HifiGAN_RPGAN_Formant_loss/blob/main/GAN_stabilization_and_formant_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab-only setup: mount Google Drive under /content/drive.
# `output` is imported here because later cells call output.clear().
from google.colab import drive, output
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install wavenet_vocoder
!pip install librosa==0.9.1
!pip install wandb
!pip install "pip<24.1"
!pip install git+https://github.com/One-sixth/fairseq.git
output.clear()

In [None]:
%cd /content/drive/MyDrive/paper_submit_PAPER_SUBMIT
# Change this to your own project folder path.

/content/drive/.shortcut-targets-by-id/1S6Y0Fb8MYqMax1o_dZBhgLlN3NSMxMXi/PAPER_SUBMIT


In [None]:
!pip install --upgrade scoreq
!pip install torch-pesq
!pip install praat-parselmouth
output.clear()

In [None]:
import os
import pickle
import numpy as np
import soundfile as sf
from scipy import signal
from scipy.signal import get_window
from librosa.filters import mel
from numpy.random import RandomState
import torch
import torch.nn as nn
from collections import OrderedDict
import torch.nn.functional as F
from torch.utils import data
from multiprocessing import Process, Manager
import easydict
from torch.backends import cudnn
import time
import datetime
from math import ceil
from tqdm import tqdm
from wavenet_vocoder import builder
from torch_pesq import PesqLoss
import scoreq
output.clear()

In [None]:
#utils
import glob
import os
import matplotlib
import torch
from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt


def plot_spectrogram(spectrogram):
    """Draw ``spectrogram`` as an image with a colorbar and return the figure.

    The figure is rendered onto its canvas and then closed via pyplot before
    being returned, so it is not left open in pyplot's figure registry.
    """
    figure, axes = plt.subplots(figsize=(10, 2))
    image = axes.imshow(
        spectrogram,
        origin="lower",
        aspect="auto",
        interpolation='none',
    )
    plt.colorbar(image, ax=axes)

    # Force a render so the canvas holds pixel data before the figure is closed.
    figure.canvas.draw()
    plt.close()
    return figure


def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise ``m.weight`` from N(mean, std) when the module's class name contains "Conv".

    Modules whose class name does not contain "Conv" are left untouched.
    Intended to be passed to ``Module.apply``.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)


def apply_weight_norm(m):
    """Wrap ``m`` with ``weight_norm`` when its class name contains "Conv"; otherwise do nothing."""
    is_conv = "Conv" in m.__class__.__name__
    if is_conv:
        weight_norm(m)


def get_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size*dilation - dilation) / 2)``, the padding used for the dilated convolutions in this file."""
    span = kernel_size * dilation - dilation
    half = span / 2
    return int(half)


def load_checkpoint(filepath, device):
    """Load and return the object stored at ``filepath`` with ``torch.load``.

    Args:
        filepath: Path to an existing checkpoint file (asserted with ``os.path.isfile``).
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object was serialised in the checkpoint file.
    """
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded


def save_checkpoint(filepath, obj):
    """Serialise ``obj`` to ``filepath`` with ``torch.save``, printing progress messages before and after."""
    print(f"Saving checkpoint to {filepath}")
    torch.save(obj, filepath)
    print("Complete.")


def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically last path in ``cp_dir`` named ``prefix`` plus exactly eight characters.

    Returns ``None`` when no file matches the glob pattern.
    """
    matches = glob.glob(os.path.join(cp_dir, prefix + '????????'))
    if not matches:
        return None
    # The lexicographic maximum equals the last element of the sorted list.
    return max(matches)

In [None]:
#env
import os
import shutil

class AttrDict(dict):
    """A ``dict`` whose items are also reachable as attributes.

    The instance ``__dict__`` is aliased to the mapping itself, so
    ``d.key`` and ``d['key']`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self


def build_env(config, config_name, path):
    """Copy the config file ``config`` to ``path/config_name``, creating ``path`` if needed.

    Nothing is done when ``config`` already equals the target path string.
    (Original author's note: the env path is e.g. cp_hifigan/config_v1.json.)
    """
    target = os.path.join(path, config_name)
    if config == target:
        return
    os.makedirs(path, exist_ok=True)
    shutil.copyfile(config, target)

In [None]:
import math
import os
import random
import torch
import torch.utils.data
import numpy as np
from librosa.util import normalize
from scipy.io.wavfile import read
from librosa.filters import mel

# Divisor applied to raw wav samples in MelDataset.__getitem__ before further processing.
# NOTE(review): 32768 == 2**15, presumably the int16 full-scale value — confirm input wavs are 16-bit PCM.
MAX_WAV_VALUE = 32768.0

def load_wav(full_path):
    """Read a wav file with ``scipy.io.wavfile.read`` and return ``(data, sampling_rate)``.

    Note the order is swapped relative to scipy, which returns the rate first.
    """
    rate_and_samples = read(full_path)
    return rate_and_samples[1], rate_and_samples[0]

def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """NumPy log compression: ``log(max(x, clip_val) * C)`` element-wise."""
    floored = np.clip(x, a_min=clip_val, a_max=None)
    scaled = floored * C
    return np.log(scaled)

def dynamic_range_decompression(x, C=1):
    """NumPy inverse of the log compression: ``exp(x) / C`` element-wise."""
    expanded = np.exp(x)
    return expanded / C

def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Torch log compression: clamp ``x`` from below at ``clip_val``, scale by ``C``, take the natural log."""
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)

def dynamic_range_decompression_torch(x, C=1):
    """Torch inverse of the log compression: ``exp(x) / C`` element-wise."""
    expanded = torch.exp(x)
    return expanded / C

def spectral_normalize_torch(magnitudes):
    """Apply ``dynamic_range_compression_torch`` with its default arguments to ``magnitudes``."""
    return dynamic_range_compression_torch(magnitudes)

def spectral_de_normalize_torch(magnitudes):
    """Apply ``dynamic_range_decompression_torch`` with its default arguments to ``magnitudes``."""
    return dynamic_range_decompression_torch(magnitudes)

# Module-level caches filled by mel_spectrogram():
#   mel_basis   -> mel filterbank tensors
#   hann_window -> Hann window tensors
mel_basis = {}
hann_window = {}

def librosa_mel_fn_(sampling_rate, n_fft, num_mels, fmin, fmax):
    """Build a mel filterbank with ``librosa.filters.mel`` using keyword arguments.

    Args:
        sampling_rate: Passed as ``sr``.
        n_fft: Passed as ``n_fft``.
        num_mels: Passed as ``n_mels``.
        fmin: Passed as ``fmin``.
        fmax: Passed as ``fmax``.

    Returns:
        Whatever ``librosa.filters.mel`` returns for these arguments.
    """
    # Import locally: before this cell the notebook only runs
    # `from librosa.filters import mel`; the bare name `librosa` is not bound
    # until a later cell, so `librosa.filters.mel(...)` depended on hidden
    # execution order. The alias also avoids clashing with local `mel` names.
    from librosa.filters import mel as build_mel_filterbank

    return build_mel_filterbank(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels,
                                fmin=fmin, fmax=fmax)

def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram of the waveform tensor ``y``.

    Steps: reflect-pad ``y`` by ``(n_fft - hop_size) / 2`` on both sides, take
    the STFT magnitude, project it with a mel filterbank, then apply
    ``spectral_normalize_torch``.

    Args:
        y: Waveform tensor. A message is printed if any value lies outside [-1, 1].
            NOTE(review): assumed to be (batch, samples) — confirm against callers.
        n_fft, hop_size, win_size: STFT parameters.
        num_mels, sampling_rate, fmin, fmax: Mel filterbank parameters.
        center: Forwarded to ``torch.stft``.

    Returns:
        The log-compressed mel spectrogram tensor.
    """
    if torch.min(y) < -1.:
        print('min value is ', torch.min(y))
    if torch.max(y) > 1.:
        print('max value is ', torch.max(y))

    global mel_basis, hann_window
    # The previous check (`fmax not in mel_basis`) never matched the stored key
    # (`str(fmax) + '_' + str(device)`), so the filterbank and window were
    # rebuilt on every call. Key each cache on everything the cached tensor
    # depends on, so hits are real and a changed parameter cannot return a
    # stale entry.
    device_key = str(y.device)
    basis_key = (sampling_rate, n_fft, num_mels, fmin, fmax, device_key)
    window_key = (win_size, device_key)
    if basis_key not in mel_basis:
        mel = librosa_mel_fn_(sampling_rate, n_fft, num_mels, fmin, fmax)
        mel_basis[basis_key] = torch.from_numpy(mel).float().to(y.device)
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(y.device)

    # Apply padding and calculate STFT
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode='reflect')
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[window_key],
                      center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)

    # Keep only the magnitude of the complex STFT.
    spec = spec.abs()

    # Drop dim 1 of a 3-D result when it has size 1 (a no-op otherwise).
    if len(spec.shape) == 3:
        spec = spec.squeeze(1)

    # Project onto the mel filterbank, then log-compress.
    spec = torch.matmul(mel_basis[basis_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec

def get_dataset_filelist(a):
    """Read the training and validation list files and return two lists of wav paths.

    Each non-empty line contributes ``<a.input_wavs_dir>/<text before the first '|'>.wav``.

    Args:
        a: Object exposing ``input_training_file``, ``input_validation_file``
            and ``input_wavs_dir`` attributes.

    Returns:
        ``(training_files, validation_files)``.
    """
    def read_list(list_path):
        with open(list_path, 'r', encoding='utf-8') as fi:
            lines = fi.read().split('\n')
        return [os.path.join(a.input_wavs_dir, line.split('|')[0] + '.wav')
                for line in lines if len(line) > 0]

    training_files = read_list(a.input_training_file)
    validation_files = read_list(a.input_validation_file)
    return training_files, validation_files

class MelDataset(torch.utils.data.Dataset):
    """Dataset over a list of wav paths returning ``(mel, audio, filename, mel_loss)`` per item.

    ``mel`` is either computed from the (optionally cropped/padded) audio or, in
    fine-tuning mode, loaded from a ``.npy`` file under ``base_mels_path``.
    ``mel_loss`` is always computed from the audio using ``fmax_loss`` as the
    upper mel frequency.
    """
    def __init__(self, training_files, segment_size, n_fft, num_mels,
                 hop_size, win_size, sampling_rate,  fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
                 device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
        """Store the configuration; optionally shuffle the file list.

        Note: ``random.seed(1234)`` reseeds the global ``random`` module, and
        ``random.shuffle`` reorders the caller's ``training_files`` list in place.
        """
        self.audio_files = training_files
        random.seed(1234)
        if shuffle:
            random.shuffle(self.audio_files)
        self.segment_size = segment_size
        self.sampling_rate = sampling_rate
        self.split = split
        self.n_fft = n_fft
        self.num_mels = num_mels
        self.hop_size = hop_size
        self.win_size = win_size
        self.fmin = fmin
        self.fmax = fmax
        self.fmax_loss = fmax_loss
        # Waveform cache: the last loaded waveform is reused for the next
        # `n_cache_reuse` calls to __getitem__ (see the counter logic there).
        self.cached_wav = None
        self.n_cache_reuse = n_cache_reuse
        self._cache_ref_count = 0
        self.device = device
        self.fine_tuning = fine_tuning
        self.base_mels_path = base_mels_path

    def __getitem__(self, index):
        """Return ``(mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())`` for ``index``.

        Raises:
            ValueError: if the wav file's sampling rate differs from ``self.sampling_rate``.
        """
        filename = self.audio_files[index]
        if self._cache_ref_count == 0:
            # Cache expired: load from disk and scale by MAX_WAV_VALUE.
            audio, sampling_rate = load_wav(filename)
            audio = audio / MAX_WAV_VALUE
            if not self.fine_tuning:
                audio = normalize(audio) * 0.95
            self.cached_wav = audio
            if sampling_rate != self.sampling_rate:
                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
            self._cache_ref_count = self.n_cache_reuse
        else:
            # NOTE(review): the cached waveform is returned regardless of `index`,
            # while `filename` above still reflects `index` — confirm this is intended.
            audio = self.cached_wav
            self._cache_ref_count -= 1

        audio = torch.FloatTensor(audio).unsqueeze(0)

        if not self.fine_tuning:  # training from scratch: compute the mel from the audio
            if self.split:
                # Random crop to segment_size, or zero-pad on the right if too short.
                if audio.size(1) >= self.segment_size:
                    max_audio_start = audio.size(1) - self.segment_size
                    audio_start = random.randint(0, max_audio_start)
                    audio = audio[:, audio_start:audio_start+self.segment_size]
                else:
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

            mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                  self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
                                  center=False)
        else:
            # Fine-tuning: load the mel from `<base_mels_path>/<wav basename>.npy`.
            mel = np.load(
                os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
            mel = torch.from_numpy(mel)

            if len(mel.shape) < 3:
                mel = mel.unsqueeze(0)

            if self.split:
                frames_per_seg = math.ceil(self.segment_size / self.hop_size)

                if audio.size(1) >= self.segment_size:
                    # Crop mel and audio at the same position (mel frames map to hop_size samples).
                    # NOTE(review): random.randint raises ValueError if
                    # mel.size(2) - frames_per_seg - 1 < 0 — confirm mels are always long enough.
                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
                    mel = mel[:, :, mel_start:mel_start + frames_per_seg]
                    audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
                else:
                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')

        # Second mel computed with fmax_loss as the upper frequency.
        mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
                                   center=False)

        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.audio_files)

## hifi-model

In [None]:
#model
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm

# Negative slope passed to F.leaky_relu by the model blocks defined below.
LRELU_SLOPE = 0.1


class ResBlock1(torch.nn.Module):
    """Residual block of three (dilated conv -> dilation-1 conv) pairs with skip connections.

    Every convolution is weight-normalised, keeps the channel count, and is
    re-initialised with ``init_weights``.
    """

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()
        self.h = h

        def make_conv(rate):
            # Same-channel, stride-1 conv padded via get_padding for this dilation.
            return weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=rate,
                                      padding=get_padding(kernel_size, rate)))

        # First conv of each pair uses the configured dilations (exactly three).
        self.convs1 = nn.ModuleList([make_conv(dilation[idx]) for idx in range(3)])
        self.convs1.apply(init_weights)

        # Second conv of each pair always uses dilation 1.
        self.convs2 = nn.ModuleList([make_conv(1) for _ in range(3)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Apply each conv pair to a leaky-ReLU'd input and add the result back onto ``x``."""
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = dilated_conv(F.leaky_relu(x, LRELU_SLOPE))
            branch = plain_conv(F.leaky_relu(branch, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        for layer in list(self.convs1) + list(self.convs2):
            remove_weight_norm(layer)


class ResBlock2(torch.nn.Module):
    """Lighter residual block: two weight-normalised dilated convs, each with its own skip connection."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super().__init__()
        self.h = h
        # Exactly two convs, using dilation[0] and dilation[1].
        self.convs = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[idx],
                               padding=get_padding(kernel_size, dilation[idx])))
            for idx in range(2)
        ])
        self.convs.apply(init_weights)

    def forward(self, x):
        """Apply each conv to a leaky-ReLU'd input and add the result back onto ``x``."""
        for conv in self.convs:
            branch = conv(F.leaky_relu(x, LRELU_SLOPE))
            x = branch + x
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from both convolutions."""
        for layer in self.convs:
            remove_weight_norm(layer)

class Generator(torch.nn.Module):
    """Upsampling generator: conv_pre -> [transposed conv + averaged residual blocks] x N -> conv_post -> tanh.

    Reads ``resblock_kernel_sizes``, ``resblock_dilation_sizes``, ``upsample_rates``,
    ``upsample_kernel_sizes``, ``upsample_initial_channel`` and ``resblock`` from ``h``.
    """
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        # Input channel count is hard-coded to 80.
        # NOTE(review): presumably the number of mel bins — confirm it matches the dataset's num_mels.
        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
        resblock = ResBlock1 if h.resblock == '1' else ResBlock2

        # One transposed conv per upsample stage; channel count halves at each stage.
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
                                k, u, padding=(k-u)//2)))

        # For every stage, one residual block per (kernel size, dilation set) pair,
        # stored flat: index = stage * num_kernels + kernel index.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
                self.resblocks.append(resblock(h, ch, k, d))

        # `ch` is the channel count of the last stage at this point.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Map the input feature tensor to a single-channel output squashed by tanh."""
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of this stage's residual blocks.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x)
            x = xs / self.num_kernels
        # NOTE(review): this call uses F.leaky_relu's default slope, not LRELU_SLOPE;
        # changing it would alter the behaviour of existing checkpoints.
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from all layers, printing a notice first."""
        print('Removing weight norm...')
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)


class DiscriminatorP(torch.nn.Module):
    """Period discriminator: folds a waveform into a 2-D (time // period, period) grid and applies 2-D convs.

    Args:
        period: Width of the folded grid.
        kernel_size: Kernel height of the five main convolutions.
        stride: Stride (along the folded time axis) of the first four convolutions.
        use_spectral_norm: If not ``False``, use spectral norm instead of weight norm.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        # Derive the padding from `kernel_size`. It used to be hard-coded for a
        # kernel of 5 (get_padding(5, 1) and (2, 0)), which is wrong for any other
        # kernel_size. At the default of 5 the value is unchanged (2).
        pad = (get_padding(kernel_size, 1), 0)
        self.convs = nn.ModuleList([
            norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=pad)),
            norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)),
        ])
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """Return ``(flattened_logits, feature_maps)`` for a (batch, channels, time) input.

        ``feature_maps`` holds the activation after every conv (post leaky-ReLU)
        plus the raw ``conv_post`` output as its last element.
        """
        fmap = []

        # 1d to 2d: right-pad (reflect) so time is a multiple of the period, then fold.
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap

class MultiPeriodDiscriminator(torch.nn.Module):
    """Bundle of ``DiscriminatorP`` instances with periods 2, 3, 5, 7 and 11."""

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(period) for period in (2, 3, 5, 7, 11)]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and ``y_hat``.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-discriminator logits for
            ``y`` and ``y_hat``, then per-discriminator feature maps for each.
        """
        real_logits, fake_logits = [], []
        real_fmaps, fake_fmaps = [], []
        for disc in self.discriminators:
            logits_r, fmap_r = disc(y)
            logits_g, fmap_g = disc(y_hat)
            real_logits.append(logits_r)
            real_fmaps.append(fmap_r)
            fake_logits.append(logits_g)
            fake_fmaps.append(fmap_g)

        return real_logits, fake_logits, real_fmaps, fake_fmaps


class DiscriminatorS(torch.nn.Module):
    """1-D convolutional discriminator: seven normalised convs followed by a single-channel output conv.

    Args:
        use_spectral_norm: If not ``False``, use spectral norm instead of weight norm.
    """

    def __init__(self, use_spectral_norm=False):
        super().__init__()
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        # (in_channels, out_channels, kernel, stride, groups, padding)
        layer_specs = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            # Experiment: additional hidden layer
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList([
            norm_f(Conv1d(c_in, c_out, kernel, stride, groups=groups, padding=pad))
            for c_in, c_out, kernel, stride, groups, pad in layer_specs
        ])
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """Return ``(flattened_logits, feature_maps)``.

        ``feature_maps`` holds each conv's activation (post leaky-ReLU) plus the
        raw ``conv_post`` output as its last element.
        """
        feature_maps = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), LRELU_SLOPE)
            feature_maps.append(x)
        x = self.conv_post(x)
        feature_maps.append(x)
        return torch.flatten(x, 1, -1), feature_maps


class MultiScaleDiscriminator(torch.nn.Module):
    """Three ``DiscriminatorS`` instances applied to progressively average-pooled inputs.

    The first discriminator uses spectral norm; the other two use weight norm.
    """

    def __init__(self):
        super().__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorS(use_spectral_norm=(idx == 0)) for idx in range(3)]
        )
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2) for _ in range(2)]
        )

    def forward(self, y, y_hat):
        """Run every sub-discriminator on ``y`` and ``y_hat``.

        Inputs are pooled cumulatively: discriminator ``i`` (for ``i > 0``) sees
        the previous discriminator's input passed through ``meanpools[i-1]``.

        Returns:
            ``(y_d_rs, y_d_gs, fmap_rs, fmap_gs)``: per-discriminator logits for
            ``y`` and ``y_hat``, then per-discriminator feature maps for each.
        """
        real_logits, fake_logits = [], []
        real_fmaps, fake_fmaps = [], []
        for idx, disc in enumerate(self.discriminators):
            if idx != 0:
                pool = self.meanpools[idx - 1]
                y = pool(y)
                y_hat = pool(y_hat)
            logits_r, fmap_r = disc(y)
            logits_g, fmap_g = disc(y_hat)
            real_logits.append(logits_r)
            real_fmaps.append(fmap_r)
            fake_logits.append(logits_g)
            fake_fmaps.append(fmap_g)

        return real_logits, fake_logits, real_fmaps, fake_fmaps

# hifi - train

In [None]:
import wandb
from getpass import getpass

# SECURITY: never hard-code the WandB API key in the notebook. A key was
# previously committed in this cell; it must be revoked and regenerated in the
# WandB account settings. Provide the key through the WANDB_API_KEY environment
# variable, or enter it at the hidden prompt below.
if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass("WandB API key: ")
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mjeo053472[0m ([33mjeo053472-Chung-Ang University[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Initialize the WandB project
name = "HiFi_GAN_RPGAN_R1R2_lambda(0.0005)_perception(0.0001)"   # e.g. WGAN (see the Notion page)
id = 'rpgan_r1r2_lambda0.0005_perception(0.0001)'     # Unique run id (must differ per experiment; reuse the same id to resume a previous experiment). NOTE(review): this shadows the builtin `id`.
wandb.init(project="PAPER_new", name=name ,id = id, resume = 'allow')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


## Loss

RPGAN Loss

In [None]:
def feature_loss(fmap_r, fmap_g):
    """Feature-matching loss: twice the sum of mean absolute differences between paired feature maps.

    Args:
        fmap_r: Nested iterable (per discriminator, per layer) of feature tensors.
        fmap_g: Same structure as ``fmap_r``.

    Returns:
        The accumulated loss (the int ``0`` when there are no feature maps).
    """
    total = sum(
        torch.mean(torch.abs(real_layer - gen_layer))
        for real_maps, gen_maps in zip(fmap_r, fmap_g)
        for real_layer, gen_layer in zip(real_maps, gen_maps)
    )
    return total * 2

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Relativistic discriminator loss: sum over discriminators of ``mean(softplus(-(real - fake)))``.

    Args:
        disc_real_outputs: Iterable of discriminator outputs for real samples.
        disc_generated_outputs: Iterable of discriminator outputs for generated samples.

    Returns:
        The accumulated loss (the int ``0`` when the inputs are empty).
    """
    return sum(
        torch.mean(F.softplus(-(real_out - fake_out)))
        for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs)
    )

def generator_loss(disc_real_outputs, disc_generated_outputs):
    """Relativistic generator loss: sum over discriminators of ``mean(softplus(real - fake))``.

    Args:
        disc_real_outputs: Iterable of discriminator outputs for real samples.
        disc_generated_outputs: Iterable of discriminator outputs for generated samples.

    Returns:
        The accumulated loss (the int ``0`` when the inputs are empty).
    """
    return sum(
        torch.mean(F.softplus(real_out - fake_out))
        for real_out, fake_out in zip(disc_real_outputs, disc_generated_outputs)
    )

In [None]:
from torchaudio.transforms import Resample
import tempfile
import librosa

def lsd(est ,target):
    """Log-spectral distance between two equally shaped spectrograms.

    Both inputs are squeezed twice on dim 0 and squared. The squared log10
    difference is averaged over dim 1, square-rooted, then averaged over dim 0.
    NOTE(review): presumably the squeezed layout is (freq, time) — confirm with callers.

    Raises:
        AssertionError: if the shapes differ.
    """
    assert est.shape == target.shape, "Spectrograms must have the same shape."
    eps = 1e-10  # avoids log(0)
    est_power = est.squeeze(0).squeeze(0) ** 2
    target_power = target.squeeze(0).squeeze(0) ** 2
    log_gap = torch.log10(target_power + eps) - torch.log10(est_power + eps)
    rms_along_dim1 = torch.mean(log_gap ** 2, dim = 1) ** 0.5
    return torch.mean(rms_along_dim1, dim = 0)

def lsd_hf(est, target, hf_ratio=0.25):
    """Log-spectral distance restricted to the top ``hf_ratio`` fraction of rows (dim 0 after squeezing).

    Same computation as the full-band version, but only rows from
    ``int(num_rows * (1 - hf_ratio))`` onward are compared.
    NOTE(review): dim 0 is treated as the frequency axis — confirm with callers.

    Raises:
        AssertionError: if the shapes differ.
    """
    assert est.shape == target.shape, "Spectrograms must have the same shape."
    eps = 1e-10  # avoids log(0)
    est_power = est.squeeze(0).squeeze(0) ** 2
    target_power = target.squeeze(0).squeeze(0) ** 2

    # First row index of the high-frequency band.
    first_hf_row = int(est_power.shape[0] * (1 - hf_ratio))
    est_band = est_power[first_hf_row:, :]
    target_band = target_power[first_hf_row:, :]

    log_gap = torch.log10(target_band + eps) - torch.log10(est_band + eps)
    rms_along_dim1 = torch.mean(log_gap ** 2, dim=1) ** 0.5
    return torch.mean(rms_along_dim1, dim=0)

def extract_f0_from_audio(audio, sr, fmin=50, fmax=500):
    """Estimate an F0 track with ``librosa.pyin`` and return it as a float32 tensor.

    Frames that pyin flags as unvoiced are set to 0.

    Args:
        audio: Tensor holding the waveform (moved to CPU and converted to NumPy).
        sr: Sampling rate passed to pyin.
        fmin, fmax: F0 search range passed to pyin.
    """
    waveform = audio.cpu().numpy()  # librosa works on NumPy arrays
    f0_track, voiced_flag, _ = librosa.pyin(waveform, fmin=fmin, fmax=fmax, sr=sr)
    unvoiced = ~torch.tensor(voiced_flag, dtype=torch.bool)
    f0 = torch.tensor(f0_track, dtype=torch.float32)
    f0[unvoiced] = 0
    return f0

def f0_rmse(f0_pred, f0_target):
    """Root-mean-square error between two equally shaped F0 tensors.

    Raises:
        AssertionError: if the shapes differ.
    """
    assert f0_pred.shape == f0_target.shape, "F0 shapes must match."
    deviation = f0_pred - f0_target
    return torch.sqrt(torch.mean(deviation ** 2))

In [None]:
!pip install praat-parselmouth



In [None]:
import parselmouth
import torch
import torch.nn.functional as F


def formant_loss(original_signal, reconstructed_signal, sr, num_formants=5, time_step=0.01):
    """
    Formant-based distance between two signals, using parselmouth (Praat) for formant extraction.

    NOTE: this value is NOT differentiable. Both signals are detached and
    converted to NumPy before Praat analyses them, so the returned tensor has no
    autograd connection to ``reconstructed_signal``. Adding it to a training
    loss shifts the loss value but sends no gradient to the generator; treat it
    as a metric.

    Args:
        original_signal (torch.Tensor): Original audio signal, (B, T) or (T,).
        reconstructed_signal (torch.Tensor): Reconstructed/generated audio signal, (B, T) or (T,).
        sr (int): Sampling rate of the signals.
        num_formants (int): Number of formants to consider.
        time_step (float): Time step for formant analysis.

    Returns:
        torch.Tensor: Scalar mean squared formant-frequency error, averaged over
        formants, then analysis frames, then the batch. A batch item with no
        analysis frames contributes 0.
    """
    import math  # local import keeps this cell self-contained

    if original_signal.dim() == 1:
        original_signal = original_signal.unsqueeze(0)
    if reconstructed_signal.dim() == 1:
        reconstructed_signal = reconstructed_signal.unsqueeze(0)

    def _zero():
        # Neutral value for a batch item (or batch) with nothing to compare.
        return torch.zeros((), dtype=torch.float32, device=original_signal.device)

    batch_loss = []
    for b in range(original_signal.shape[0]):
        # Convert signals to numpy for parselmouth (this is where autograd is cut).
        orig_np = original_signal[b].detach().cpu().numpy()
        recon_np = reconstructed_signal[b].detach().cpu().numpy()

        # Create parselmouth Sound objects
        orig_sound = parselmouth.Sound(orig_np, sampling_frequency=sr)
        recon_sound = parselmouth.Sound(recon_np, sampling_frequency=sr)

        # Extract formants using Praat's Burg algorithm
        orig_formant = orig_sound.to_formant_burg(time_step=time_step)
        recon_formant = recon_sound.to_formant_burg(time_step=time_step)

        # Both tracks are sampled at the frame times of the ORIGINAL analysis.
        frame_losses = []
        for t in orig_formant.ts():
            orig_formants = []
            recon_formants = []

            for formant_number in range(1, num_formants + 1):
                orig_value = orig_formant.get_value_at_time(formant_number=formant_number, time=t)
                recon_value = recon_formant.get_value_at_time(formant_number=formant_number, time=t)

                # If either value is NaN, the pair contributes zero error.
                # math.isnan avoids building two tensors per formant just to test for NaN.
                if math.isnan(orig_value) or math.isnan(recon_value):
                    orig_value = 0.0
                    recon_value = 0.0

                orig_formants.append(orig_value)
                recon_formants.append(recon_value)

            # Mean squared error over the formants at time t
            orig_tensor = torch.tensor(orig_formants, dtype=torch.float32, device=original_signal.device)
            recon_tensor = torch.tensor(recon_formants, dtype=torch.float32, device=reconstructed_signal.device)
            frame_losses.append(torch.mean((orig_tensor - recon_tensor) ** 2))

        # Average over frames; torch.stack([]) would raise if Praat produced no frames.
        if frame_losses:
            batch_loss.append(torch.mean(torch.stack(frame_losses)))
        else:
            batch_loss.append(_zero())

    # Average the loss over the batch (an empty batch yields 0 instead of raising).
    if not batch_loss:
        return _zero()
    return torch.mean(torch.stack(batch_loss))

## train

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import itertools
import os
import time
import argparse
import json
import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DistributedSampler, DataLoader
import torch.multiprocessing as mp
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
import easydict
import torchaudio

torch.backends.cudnn.benchmark = True

# Gradient-penalty (R1 / R2) computation functions
def compute_r1_penalty(discriminators, real_samples, device, gamma=0.0005):
    """
    Computes the R1 regularization penalty summed over the given discriminators.

    For each discriminator the penalty is the batch mean of the squared L2 norm
    (over dims 1 and 2) of d(sum of outputs)/d(input), evaluated on real samples.
    The gradient is built with ``create_graph=True`` so the penalty can be
    back-propagated into the discriminator parameters.

    Args:
        discriminators: An iterable of callables returning ``(outputs, feature_maps)``
            (e.g. DiscriminatorP or DiscriminatorS instances).
        real_samples: The real samples (y), a 3-D tensor.
        device: Unused; kept for interface compatibility.
        gamma: The regularization weight (default: 0.0005).
    Returns:
        r1_penalty: ``(gamma / 2) * sum`` of the per-discriminator penalties.
    """
    # Differentiate w.r.t. a detached copy rather than calling requires_grad_()
    # on the caller's tensor, which would silently make their `y` track
    # gradients in every later computation. A tensor that already requires grad
    # is used as-is, so existing graph behaviour is preserved.
    if not real_samples.requires_grad:
        real_samples = real_samples.detach().requires_grad_(True)

    r1_penalties = []
    for d in discriminators:
        real_outputs, _ = d(real_samples)  # discriminator outputs for real samples
        grad_real = torch.autograd.grad(
            outputs=real_outputs.sum(),
            inputs=real_samples,
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]

        # Squared gradient norm per sample, averaged over the batch.
        r1_penalties.append((grad_real.norm(2, dim=(1, 2)) ** 2).mean())

    return (gamma / 2) * sum(r1_penalties)

def compute_r2_penalty(discriminators, fake_samples, device, gamma=0.0005):
    """
    Computes the R2 regularization penalty summed over the given discriminators.

    For each discriminator the penalty is the batch mean of the squared L2 norm
    (over dims 1 and 2) of d(sum of outputs)/d(input), evaluated on generated
    samples. The gradient is built with ``create_graph=True`` so the penalty can
    be back-propagated into the discriminator parameters.

    Args:
        discriminators: An iterable of callables returning ``(outputs, feature_maps)``
            (e.g. DiscriminatorP or DiscriminatorS instances).
        fake_samples: The fake samples (generated by the generator), a 3-D tensor.
        device: Unused; kept for interface compatibility.
        gamma: The regularization weight (default: 0.0005).
    Returns:
        r2_penalty: ``(gamma / 2) * sum`` of the per-discriminator penalties.
    """
    # Differentiate w.r.t. a detached copy rather than calling requires_grad_()
    # on the caller's tensor. A tensor that already requires grad (e.g. a
    # non-detached generator output) is used as-is, so its graph is preserved.
    if not fake_samples.requires_grad:
        fake_samples = fake_samples.detach().requires_grad_(True)

    r2_penalties = []
    for d in discriminators:
        fake_outputs, _ = d(fake_samples)  # discriminator outputs for fake samples
        grad_fake = torch.autograd.grad(
            outputs=fake_outputs.sum(),
            inputs=fake_samples,
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]

        # Squared gradient norm per sample, averaged over the batch.
        r2_penalties.append((grad_fake.norm(2, dim=(1, 2)) ** 2).mean())

    return (gamma / 2) * sum(r2_penalties)


def train(rank, a, h): # rank : 0
    """
    Runs GAN vocoder training with R1/R2 gradient penalties on the discriminators and a
    formant loss on the generator.

    Args:
        rank: Index of this process; selects the ``cuda:<rank>`` device and is used as the
            distributed rank when ``h.num_gpus > 1``. Checkpointing, validation and stdout
            logging run on rank 0 only.
        a: Run arguments. This function reads ``checkpoint_path``, ``training_epochs``,
            ``stdout_interval``, ``checkpoint_interval``, ``validation_interval``,
            ``fine_tuning`` and ``input_mels_dir``; ``a`` is also passed to
            ``get_dataset_filelist``.
        h: Hyperparameters (sampling rate, STFT/mel settings, optimizer settings,
            ``num_gpus``, ``dist_config``, ...); also passed to ``Generator``.
    """
    # Initialize SCOREQ
    scoreq_nr = scoreq.Scoreq(data_domain='natural', mode='nr')  # No-reference

    # Initialize PESQ
    # NOTE(review): `pesq` is built and frozen here but never used below.
    pesq = PesqLoss(0.5,h.sampling_rate).eval()
    for param in pesq.parameters():
        param.requires_grad = False

    # # Set the lambda values
    # ex) loss_scale_formant = 0.001
    # ex) loss_scale_r1r2 = #

    if h.num_gpus > 1:
        init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
                           world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)

    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda:{:d}'.format(rank))

    generator = Generator(h).to(device)
    mpd = MultiPeriodDiscriminator().to(device)
    msd = MultiScaleDiscriminator().to(device)

    if rank == 0:
        os.makedirs(a.checkpoint_path, exist_ok=True)
        print("checkpoints directory : ", a.checkpoint_path)

    # Default to "no checkpoint" so the names exist even when the directory is missing
    # (only rank 0 creates it above).
    cp_g = cp_do = None
    if os.path.isdir(a.checkpoint_path):
        # Files starting with g_ / do_ in the checkpoint folder are picked up for resuming.
        cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
        cp_do = scan_checkpoint(a.checkpoint_path, 'do_')

    steps = 0
    if cp_g is None or cp_do is None:         # train from scratch
        state_dict_do = None
        last_epoch = -1
    else:                                     # resume from checkpoint
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        mpd.load_state_dict(state_dict_do['mpd'])
        msd.load_state_dict(state_dict_do['msd'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if h.num_gpus > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
        msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)

    # The sub-discriminator lists live on the underlying modules; the
    # DistributedDataParallel wrapper does not forward that attribute, so unwrap
    # the same way the checkpoint code does.
    mpd_discriminators = (mpd.module if h.num_gpus > 1 else mpd).discriminators
    msd_discriminators = (msd.module if h.num_gpus > 1 else msd).discriminators

    optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
    optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
                                h.learning_rate, betas=[h.adam_b1, h.adam_b2])

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)

    training_filelist, validation_filelist = get_dataset_filelist(a)

    trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
                          h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
                          shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
                          fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)

    train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None

    train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
                              sampler=train_sampler,
                              batch_size=h.batch_size,
                              pin_memory=True,
                              drop_last=True)

    if rank == 0:
        validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
                              h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
                              fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
                              base_mels_path=a.input_mels_dir)
        validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
                                       sampler=None,
                                       batch_size=1,
                                       pin_memory=True,
                                       drop_last=True)

        # NOTE(review): the SummaryWriter is created but metrics below go to wandb only.
        sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))


    generator.train()
    mpd.train()
    msd.train()
    # Pre-bind `epoch` so the final checkpoint below still works when the loop body
    # never runs (e.g. resuming a run that already reached `training_epochs`).
    epoch = last_epoch
    for epoch in range(max(0, last_epoch), a.training_epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch+1))

        if h.num_gpus > 1:
            train_sampler.set_epoch(epoch)

        for i, batch in enumerate(train_loader):
            if rank == 0:
                start_b = time.time()
            x, y, _, y_mel = batch
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            y_mel = y_mel.to(device, non_blocking=True)
            y = y.unsqueeze(1)

            y_g_hat = generator(x)
            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
                                          h.fmin, h.fmax_for_loss)

            optim_d.zero_grad()

            # MPD
            y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
            loss_disc_f = discriminator_loss(y_df_hat_r, y_df_hat_g)

            # MSD
            y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
            loss_disc_s = discriminator_loss(y_ds_hat_r, y_ds_hat_g)

            # R1 penalty (real samples). A detached handle is passed, mirroring the R2
            # calls, because the penalty helper switches requires_grad on in place and
            # `y` is reused by the generator step below.
            r1_penalty_mpd = compute_r1_penalty(mpd_discriminators, y.detach(), device, gamma=0.0005)
            r1_penalty_msd = compute_r1_penalty(msd_discriminators, y.detach(), device, gamma=0.0005)

            # R2 penalty (generated samples)
            r2_penalty_mpd = compute_r2_penalty(mpd_discriminators, y_g_hat.detach(), device, gamma=0.0005)
            r2_penalty_msd = compute_r2_penalty(msd_discriminators, y_g_hat.detach(), device, gamma=0.0005)

            # Final discriminator loss (adversarial terms + R1 + R2 penalties)
            loss_disc_all = loss_disc_s + loss_disc_f + r1_penalty_mpd + r1_penalty_msd + r2_penalty_mpd + r2_penalty_msd
            # NOTE(review): the key "train/loss_disc_s " carries a trailing space; it is
            # left unchanged here so previously logged runs keep the same metric name.
            wandb.log({"train/loss_disc_all" : loss_disc_all , "train/loss_disc_f" : loss_disc_f , "train/loss_disc_s " : loss_disc_s ,
                       "train/r1_penalty_mpd": r1_penalty_mpd, "train/r1_penalty_msd": r1_penalty_msd, "train/r2_penalty_mpd": r2_penalty_mpd, "train/r2_penalty_msd": r2_penalty_msd})

            loss_disc_all.backward()
            optim_d.step()

            # Generator
            optim_g.zero_grad()

            # L1 Mel-Spectrogram Loss
            loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45

            y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
            y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
            loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
            loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
            loss_gen_f = generator_loss(y_df_hat_r,y_df_hat_g)
            loss_gen_s = generator_loss(y_ds_hat_r, y_ds_hat_g)

            # formant loss
            loss_formant = formant_loss(y.squeeze(1),y_g_hat.squeeze(1), h.sampling_rate, num_formants = 5, time_step = 0.01)
            loss_formant = loss_formant * 0.0001

            # Total generator loss
            loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel + loss_formant
            wandb.log({"train/loss_gen_all" : loss_gen_all , "train/loss_gen_s" : loss_gen_s , "train/loss_gen_f" : loss_gen_f,
                       "train/loss_fm_s" : loss_fm_s , "train/loss_fm_f" : loss_fm_f , "train/loss_mel" : loss_mel, "train/loss_formant" : loss_formant})

            loss_gen_all.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % a.stdout_interval == 0:
                    with torch.no_grad():
                        mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()

                    print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
                          format(steps, loss_gen_all, mel_error, time.time() - start_b))

                # checkpointing
                if steps % a.checkpoint_interval == 0 and steps != 0:
                    checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
                    checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
                    save_checkpoint(checkpoint_path,
                                    {'mpd': (mpd.module if h.num_gpus > 1
                                                         else mpd).state_dict(),
                                     'msd': (msd.module if h.num_gpus > 1
                                                         else msd).state_dict(),
                                     'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                                     'epoch': epoch})
                # Validation
                if steps % a.validation_interval == 0 and steps != 0:
                    generator.eval()
                    torch.cuda.empty_cache()

                    val_err_tot = 0
                    val_err_tot_rmse = 0
                    val_err_tot_f0rmse = 0
                    val_err_tot_lsd = 0
                    val_err_tot_lsd_hf = 0
                    val_err_tot_mos = 0
                    # Counted explicitly so an empty validation set cannot leave the
                    # averages below referring to an unbound loop index.
                    num_val_batches = 0

                    with torch.no_grad():
                        for j, batch in enumerate(validation_loader):
                            num_val_batches += 1
                            x, y, _, y_mel = batch
                            y_g_hat = generator(x.to(device))
                            y_mel = y_mel.to(device, non_blocking=True)
                            y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
                                                          h.hop_size, h.win_size,
                                                          h.fmin, h.fmax_for_loss)
                            # L1 mel error
                            val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()

                            # RMSE
                            val_err_tot_rmse += torch.sqrt(F.mse_loss(y_g_hat_mel, y_mel, reduction='mean')).item()

                            # F0 RMSE
                            f0_y = extract_f0_from_audio(y.squeeze(1), h.sampling_rate)
                            f0_y_g_hat = extract_f0_from_audio(y_g_hat.squeeze(1), h.sampling_rate)
                            val_err_tot_f0rmse += f0_rmse(f0_y_g_hat, f0_y).item()

                            # LSD
                            val_err_tot_lsd += lsd(y_g_hat_mel,  y_mel).item()

                            # LSD_HF
                            val_err_tot_lsd_hf += lsd_hf(y_g_hat_mel,  y_mel).item()

                            # SCOREQ
                            # The generated audio is written to a file because SCOREQ is
                            # given a file path as input.
                            test_audio_path = "temp_test.wav"
                            torchaudio.save(test_audio_path, y_g_hat.squeeze(1).cpu(), h.sampling_rate)

                            # Quality estimate in no-reference mode
                            val_err_tot_mos += scoreq_nr.predict(test_path=test_audio_path).item()

                            # Log audio and spectrograms for the first five validation items
                            # (this branch is only reached with steps != 0).
                            if j <= 4:
                                wandb.log({
                                    "Predicted Audio": wandb.Audio(
                                        y_g_hat.squeeze().cpu().numpy(),
                                        sample_rate=h.sampling_rate,
                                        caption="Predicted Audio"
                                    ),
                                    "Ground Truth Audio": wandb.Audio(
                                        y[0].squeeze().cpu().numpy(),
                                        sample_rate=h.sampling_rate,
                                        caption="Ground Truth Audio"
                                    )})

                                y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
                                                             h.sampling_rate, h.hop_size, h.win_size,
                                                             h.fmin, h.fmax)
                                wandb.log({'generated/y_hat_spec_{}'.format(j):
                                              plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy())})

                        if num_val_batches > 0:
                            val_err_l1 = val_err_tot / num_val_batches
                            val_err_rmse = val_err_tot_rmse / num_val_batches
                            val_err_f0rmse = val_err_tot_f0rmse / num_val_batches
                            val_err_lsd = val_err_tot_lsd / num_val_batches
                            val_err_lsd_hf = val_err_tot_lsd_hf / num_val_batches
                            val_err_mos = val_err_tot_mos / num_val_batches

                            wandb.log({"validation/val_err_l1": val_err_l1, "validation/val_err_rmse": val_err_rmse, "validation/val_err_f0rmse": val_err_f0rmse,
                                       "validation/val_err_lsd": val_err_lsd, "validation/val_err_lsd_hf": val_err_lsd_hf, "validation/val_err_mos": val_err_mos})

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))

    # Final checkpoint, written by rank 0 only (like the periodic checkpoints) so
    # that several processes never write the same file at once.
    if rank == 0:
        checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
        save_checkpoint(checkpoint_path,
                        {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
        checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
        save_checkpoint(checkpoint_path,
                        {'mpd': (mpd.module if h.num_gpus > 1
                                              else mpd).state_dict(),
                         'msd': (msd.module if h.num_gpus > 1
                                              else msd).state_dict(),
                         'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
                         'epoch': epoch})


def main():
    """Assemble the run settings and start training.

    Builds the run arguments in-line, reads the hyperparameters from the JSON file named
    by ``a.config``, seeds PyTorch, divides the batch size by the number of visible CUDA
    devices, and then either spawns one ``train`` process per GPU or calls ``train``
    directly (closing the wandb run afterwards in the single-process case).
    """
    print('Initializing Training Process..')

    a = easydict.EasyDict({
        "group_name": None,
        "input_wavs_dir": 'LJSpeech-1.1/wavs',
        # Not needed unless fine-tuning.
        "input_mels_dir": None,
        "input_training_file": 'LJSpeech-1.1/training.txt',
        "input_validation_file": 'LJSpeech-1.1/validation.txt',
        # Checkpoint folder name; keep it equal to "cp_" + the wandb run name (e.g. cp_WGAN).
        "checkpoint_path": 'cp_HiFi_GAN_RPGAN_R1R2_lambda(0.0005)_perception(0.0001)',
        "config": 'config_v1.json',
        "training_epochs": 200,
        "stdout_interval": 50,
        "checkpoint_interval": 500,
        "summary_interval": 100,
        "validation_interval": 1000,
        "fine_tuning": False,
    })

    # Hyperparameters come from the JSON config file.
    with open(a.config) as config_file:
        hparams = json.load(config_file)
    h = AttrDict(hparams)
    build_env(a.config, 'config.json', a.checkpoint_path)

    torch.manual_seed(h.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(h.seed)
        gpu_count = torch.cuda.device_count()
        h.num_gpus = gpu_count
        # The configured batch size is split evenly across the devices.
        h.batch_size = int(h.batch_size / gpu_count)
        print('Batch size per GPU :', h.batch_size)
    # Without CUDA, h.num_gpus and h.batch_size keep whatever the config file supplied.

    if h.num_gpus > 1:
        mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
        return

    train(0, a, h)
    wandb.finish()

In [None]:
import logging

# Raise the root logger's threshold to ERROR so lower-severity records are dropped.
logging.root.setLevel(logging.ERROR)

In [None]:
# Kick off training when this cell is executed as the top-level module.
if __name__ == '__main__':
    main()

Initializing Training Process..
Batch size per GPU : 32
SCOREQ running on: cuda


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  self.delegate = real_initialize(
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/changes_to_package_header for more information
'config' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  state = load_checkpoint_to_cpu(filename, arg_overrides)
The strict flag in the compose API is deprecated.
See https://hydra.cc/docs/1.2/upgrades/0.11_to_1.0/strict_mode_flag_deprecated for more info.

'config' is validated against ConfigStore schema with the same name.
This behavior is deprecated in Hydra 1.1 and will be removed in Hydra 1.2.
See https://hydra.cc/docs/1.2/upgrades/1.0_to_1.1/automatic_schema_matching for migration instructions.
  scoreq_nr = scoreq.Scoreq(data

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Steps : 49850, Gen Loss Total : 25.430, Mel-Spec. Error : 0.271, s/b : 1.394
Steps : 49900, Gen Loss Total : 26.313, Mel-Spec. Error : 0.273, s/b : 1.395
Steps : 49950, Gen Loss Total : 24.962, Mel-Spec. Error : 0.267, s/b : 1.390
Steps : 50000, Gen Loss Total : 26.163, Mel-Spec. Error : 0.264, s/b : 1.413
Saving checkpoint to cp_HiFi_GAN_RPGAN_R1R2_lambda(0.0005)_perception(0.0001)/g_00050000
Complete.
Saving checkpoint to cp_HiFi_GAN_RPGAN_R1R2_lambda(0.0005)_perception(0.0001)/do_00050000
Complete.
SCOREQ | No-Reference Mode | Domain natural | temp_test.wav: 4.1556
SCOREQ | No-Reference Mode | Domain natural | temp_test.wav: 4.2045
SCOREQ | No-Reference Mode | Domain natural | temp_test.wav: 4.3496
SCOREQ | No-Reference Mode | Domain natural | temp_test.wav: 4.2432
SCOREQ | No-Reference Mode | Domain natural | temp_test.wav: 4.2549
SCOREQ | No-Reference Mode | Domain natural | temp_test.wav: 4.2526
SCOREQ | No-Referenc