We're going to use the Generator and the Discriminators to train on the LJSpeech dataset. The dataset consists of around 13,000 audio clips of a single speaker reading passages from 7 non-fiction books. The clips vary in length from 1 to 10 seconds and total around 24 hours.  



In [26]:
import argparse
import jax
import os
import librosa
import optax
import wandb
import equinox as eqx

from Generator import Generator
from Discriminators import MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss, discriminator_loss

def create_parser():
    """Build the command-line argument parser for HiFi-GAN training.

    Returns:
        argparse.ArgumentParser: A parser exposing ``--dataset_path``,
        ``--learning_rate`` and ``--output_path``.
    """
    parser = argparse.ArgumentParser(description="Arguments for training HiFiGaN")

    # The description has to be passed as ``help=``. A second positional string
    # is interpreted by argparse as another option flag and raises ValueError
    # because it does not start with "-".
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to the dataset to use (LJSpeech)",
    )

    # Parsed as float so the value can be handed to the optimiser directly;
    # the default mirrors the default of train_hifigan.
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate during training",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to store model weights",
    )

    return parser

def save_model(model, path):
    """Write the leaves of ``model`` to ``path`` via Equinox serialisation.

    Args:
        model: The pytree (e.g. an Equinox module) to persist.
        path: Destination passed unchanged to ``eqx.tree_serialise_leaves``.
    """
    destination = path
    eqx.tree_serialise_leaves(destination, model)

def get_dataset(dataset_path):
    """Load every mel spectrogram and waveform stored under ``dataset_path``.

    Files in ``<dataset_path>/mel_spectrograms`` are read with
    ``jax.numpy.load`` and files in ``<dataset_path>/processed_wavs`` are read
    with ``librosa.load``.

    Args:
        dataset_path: Root directory containing the two sub-directories above.

    Returns:
        tuple: ``(mels, wavs)`` as ``jax.numpy`` arrays. Entry ``i`` of each
        array comes from the ``i``-th file name, in sorted order, of its
        directory.

    Raises:
        ValueError: If the two directories hold a different number of files.
    """
    mel_dir = os.path.join(dataset_path, 'mel_spectrograms')
    wav_dir = os.path.join(dataset_path, 'processed_wavs')

    # os.listdir returns names in arbitrary order and the two directories are
    # listed independently, so sort both listings to keep the i-th mel aligned
    # with the i-th waveform.
    # NOTE(review): assumes paired files share a base name -- confirm.
    mel_files = sorted(os.listdir(mel_dir))
    wav_files = sorted(os.listdir(wav_dir))

    # Fail before any loading: unequal counts cannot form (mel, wav) pairs.
    if len(mel_files) != len(wav_files):
        raise ValueError(
            f"Found {len(mel_files)} mel spectrograms in {mel_dir!r} but "
            f"{len(wav_files)} waveforms in {wav_dir!r}; expected one waveform per mel."
        )

    mels = []
    for filename in mel_files:
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # file is untrusted; only point this at data you generated yourself.
        np_data = jax.numpy.load(os.path.join(mel_dir, filename), allow_pickle=True)
        # Drop the last column of each spectrogram.
        # NOTE(review): presumably trims one frame to match the waveform length -- confirm.
        mels.append(np_data[:, :-1])

    wavs = []
    for filename in wav_files:
        # The sample rate returned by librosa is not needed here.
        wav_data, _ = librosa.load(os.path.join(wav_dir, filename))
        wavs.append(jax.numpy.array(wav_data))

    # NOTE(review): stacking into one array presumably requires every sample to
    # share a shape -- confirm the preprocessing guarantees this.
    return jax.numpy.array(mels), jax.numpy.array(wavs)


def train_hifigan(dataset_path, output_path, learning_rate=1e-4, batch_size=1, epochs=1, seed=69):
    """Train the HiFi-GAN generator with an L1 loss and log to Weights & Biases.

    The generator is optimised with Adam to map mel spectrograms to waveforms.
    After every epoch its weights are written to
    ``<output_path>/generator_epoch_<epoch>.eqx``.

    Args:
        dataset_path: Root directory of the dataset, forwarded to ``get_dataset``.
        output_path: Directory that receives the per-epoch checkpoints; it is
            created if it does not exist.
        learning_rate: Learning rate passed to ``optax.adam``.
        batch_size: Number of samples per optimisation step.
        epochs: Number of passes over the dataset.
        seed: Seed for the JAX PRNG key used for model initialisation and for
            shuffling.
    """
    # Create the checkpoint directory up front so a missing folder fails fast
    # instead of after a full epoch of training.
    os.makedirs(output_path, exist_ok=True)

    run = wandb.init(
        # Set the project where this run will be logged
        project="HiFiGaN JAX",
        # Track hyperparameters and run metadata
        config={
            "learning_rate": learning_rate,
            "epochs": epochs,
            "batch_size": batch_size,
            "PRNG_SEED": seed,
        },
    )

    try:
        key = jax.random.PRNGKey(seed)
        key1, key2, key3 = jax.random.split(key, 3)

        generator = Generator(channels_in=80, channels_out=1, key=key1)
        # The discriminators are instantiated but do not take part in the loss
        # below yet; only the generator is trained.
        scale_disc = MultiScaleDiscriminator(key=key2)
        period_disc = MultiPeriodDiscriminator(key=key3)

        # Use the caller-supplied path rather than a hard-coded directory.
        dataset_mels, dataset_wavs = get_dataset(dataset_path)

        optim = optax.adam(learning_rate)
        # Initialise the optimiser on array leaves only, so non-array leaves of
        # the module (e.g. activation functions) are not handed to optax.
        optim_state = optim.init(eqx.filter(generator, eqx.is_array))

        @eqx.filter_value_and_grad
        def calculate_loss(model, x, y):
            result = jax.vmap(model)(x)
            return jax.numpy.mean(jax.numpy.abs(result - y))  # L1 loss

        @eqx.filter_jit
        def make_step(model, x, y, optim_state):
            loss, grads = calculate_loss(model, x, y)
            updates, optim_state = optim.update(grads, optim_state)
            model = eqx.apply_updates(model, updates)
            return loss, model, optim_state

        for epoch in range(epochs):
            # Fresh shuffle of the sample indices every epoch.
            key, subkey = jax.random.split(key)
            perm = jax.random.permutation(subkey, len(dataset_mels))

            for batch_start in range(0, len(dataset_mels), batch_size):
                batch_indices = perm[batch_start: batch_start + batch_size]
                x = dataset_mels.take(batch_indices, axis=0)
                y = dataset_wavs.take(batch_indices, axis=0)

                loss, generator, optim_state = make_step(generator, x, y, optim_state)
                # Log a plain Python float instead of a device array.
                wandb.log({"loss": loss.item()})

            save_model(generator, os.path.join(output_path, f"generator_epoch_{epoch}.eqx"))
    finally:
        # Close the wandb run even when training is interrupted or fails.
        run.finish()

# if __name__ == "__main__":
#     parser = create_parser()
#     args = parser.parse_args()

#     train_hifigan(dataset_path=args.dataset_path, output_path=args.output_path, learning_rate=args.learning_rate)


In [27]:
# Launch training with data from "dataset", checkpoints written under
# "checkpoint", and a learning rate of 1e-5 instead of the 1e-4 default.
train_hifigan(dataset_path="dataset", output_path="checkpoint", learning_rate=1e-5)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mtugdual-kerjan[0m ([33mtugdualk[0m). Use [1m`wandb login --relogin`[0m to force relogin


KeyboardInterrupt: 