In [81]:
import datetime
import json
import os
import time
from collections import defaultdict

import datasets
import equinox as eqx
import jax
import librosa
import matplotlib.pyplot as plt
import numpy as np
import optax
from datasets import Array2D, Features, Value
from librosa.util import normalize

from hifigan import (
    Generator,
    mel_spec_base_jit,
)

# Target audio sampling rate in Hz; `transform` compares each clip against it.
SAMPLE_RATE = 22_050
# Number of audio samples kept per example after cropping.
SEGMENT_SIZE = 10 * 8192
# Root PRNG key; later cells split it to initialise the model.
RANDOM = jax.random.key(0)
# Checkpoint directory name (not referenced by the cells shown here).
CHECKPOINT_PATH = "checkpoints"

def transform(sample):
    """Turn one raw dataset row into a fixed-length (mel, audio) pair.

    Based off the original code that can be found here:
    https://github.com/jik876/hifi-gan/blob/master/meldataset.py

    The waveform is resampled to ``SAMPLE_RATE`` when its rate differs,
    normalised with ``librosa.util.normalize`` and scaled by 0.95, then
    cropped (or zero-padded) to exactly ``SEGMENT_SIZE`` samples so that the
    output always matches the fixed-shape dataset schema.

    Args:
        sample (dict): dict entry in HF Dataset. Must provide
            ``sample["audio"]["array"]`` (assumed to be a 1-D mono waveform)
            and ``sample["audio"]["sampling_rate"]``.

    Returns:
        dict: updated entry with keys ``"mel"`` (output of
        ``mel_spec_base_jit``), ``"audio"`` (shape ``(1, SEGMENT_SIZE)``)
        and ``"sample_rate"``.
    """
    # The dataset may hand back JAX arrays; librosa and np.pad expect NumPy.
    wav = np.asarray(sample["audio"]["array"], dtype=np.float32)
    source_rate = int(sample["audio"]["sampling_rate"])
    if source_rate != SAMPLE_RATE:
        # The resampled signal must be re-bound (the call is not in-place);
        # the rates are passed by keyword because newer librosa requires it.
        wav = librosa.resample(wav, orig_sr=source_rate, target_sr=SAMPLE_RATE)
    wav = normalize(wav) * 0.95

    n_samples = wav.shape[0]
    if n_samples >= SEGMENT_SIZE:
        max_audio_start = n_samples - SEGMENT_SIZE
        # NOTE(review): the key is fixed, so the crop offset is identical on
        # every run for a given clip length. Thread a per-sample key through
        # if genuinely random crops are wanted.
        k = jax.random.key(0)
        # `maxval` is exclusive, so +1 keeps the last valid start reachable.
        audio_start = int(jax.random.randint(k, (1,), 0, max_audio_start + 1)[0])
        wav = wav[audio_start : audio_start + SEGMENT_SIZE]
    else:
        # Right-pad short clips with zeros so every example has SEGMENT_SIZE samples.
        wav = np.pad(wav, (0, SEGMENT_SIZE - n_samples), mode="constant")

    # Add a leading axis of size 1 -> (1, SEGMENT_SIZE).
    wav = np.expand_dims(wav, 0)

    mel = mel_spec_base_jit(wav=wav)
    return {"mel": np.array(mel), "audio": np.array(wav), "sample_rate": SAMPLE_RATE}


# Define the exact shapes we expect from the transform function.
features = Features(
    {
        # From mel_spec_base_jit output: 80 x 320 for an 81920-sample segment.
        # NOTE(review): the frame count is still a literal; it has to be
        # updated by hand if SEGMENT_SIZE changes.
        "mel": Array2D(shape=(80, 32 * 10), dtype="float32"),
        # From expand_dims(wav, 0); tied to SEGMENT_SIZE so the schema always
        # matches the crop length used by `transform`.
        "audio": Array2D(shape=(1, SEGMENT_SIZE), dtype="float32"),
        "sample_rate": Value(dtype="int64"),
    }
)

# Load LJSpeech with JAX-formatted outputs and keep only the first row of the
# "train" split for this quick inference check.
lj_speech_data = (
    datasets.load_dataset("keithito/lj_speech", trust_remote_code=True)
    .with_format("jax")["train"]
    .take(1)
)


In [82]:

# Replace the raw columns with the fixed-shape tensors produced by `transform`,
# then switch the result back to JAX-formatted outputs.
lj_speech_data = lj_speech_data.map(
    function=transform,
    features=features,
    remove_columns=lj_speech_data.column_names,  # drop the original columns
    # num_proc=8,
).with_format("jax")

# Derive three sub-keys from the root key; only the first is consumed here.
k1, k2, k3 = jax.random.split(RANDOM, num=3)
# Build a generator skeleton and overwrite its leaves with the saved weights.
generator = eqx.tree_deserialise_leaves(
    "generator.eqx",
    Generator(channels_in=80, channels_out=1, key=k1),
)



Map: 100%|██████████| 1/1 [00:00<00:00,  5.90 examples/s]


In [83]:
# Keep a one-row dataset (row 0) to use as the example for the cells below;
# the bare expression displays its summary.
test = lj_speech_data.select([0])
test

Dataset({
    features: ['mel', 'audio', 'sample_rate'],
    num_rows: 1
})

In [45]:
# Wrap the generator with eqx.filter_jit so calls to it are JIT-compiled.
# NOTE: this rebinds the same name, so re-running the cell wraps the
# already-wrapped callable again; run it once per kernel session.
generator = eqx.filter_jit(generator)

In [33]:

# Time one batched forward pass of the generator on a dummy input.
# NOTE(review): the last axis here is 8192 while the real mel features above
# are declared as (80, 320) -- confirm this is the intended benchmark size.
dummy_mel = jax.numpy.ones((1, 80, 8192))  # built outside the timed region

start = time.perf_counter()
val = jax.vmap(generator)(dummy_mel)
# JAX dispatches work asynchronously: without blocking, the timer would stop
# before the computation has actually finished.
jax.block_until_ready(val)
end = time.perf_counter()
# The first call also pays for tracing/compilation; re-run the cell to see
# steady-state latency.
print(f"Execution time: {end - start:.6f} seconds")

2025-02-13 11:08:37.977327: E external/xla/xla/service/slow_operation_alarm.cc:73] Trying algorithm eng0{} for conv %cudnn-conv-bw-filter.30 = (f32[128,1,524288]{2,1,0}, u8[0]{0}) custom-call(f32[128,1,524288]{2,1,0} %bitcast.7497, f32[128,128,7]{2,1,0} %bitcast.7290), window={size=524288 stride=3 pad=9_9}, dim_labels=bf0_oi0->bf0, custom_call_target="__cudnn$convBackwardFilter", metadata={op_name="jit(Generator)/jit(main)/eqx.nn.WeightNorm/eqx.nn.Conv/conv_general_dilated" source_file="/home/tugdual/hifigan-jax/.venv/lib/python3.11/site-packages/equinox/nn/_conv.py" source_line=239}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false} is taking a while...
2025-02-13 11:08:37.992802: E external/xla/xla/service/slow_operation_alarm.cc:140] The operation took 1.015601706s
Trying algorithm eng0{} for conv %cudnn-conv-bw-

Execution time: 36.237491 seconds


In [84]:
from IPython.display import Audio, display
from matplotlib import pyplot as plt

# Compare the ground-truth mel spectrogram (top) with the mel spectrogram of
# the generator's reconstruction (bottom).
mel = test["mel"]
wav = test["audio"]
fig, (ax1, ax2) = plt.subplots(2, figsize=(10, 2))
im = ax1.imshow(mel[0], aspect="auto", origin="lower", interpolation="none")
ax1.set_title("Ground-truth mel")
fig.colorbar(im, ax=ax1)

# Vocode the reference mel, then re-analyse the predicted waveform.
wav_pred = generator(mel[0])
mel_pred = mel_spec_base_jit(wav_pred)
# Plot the *predicted* mel (the original re-plotted mel[0] here); squeeze
# drops a singleton leading axis in case mel_spec_base_jit returns one.
im2 = ax2.imshow(
    np.squeeze(np.asarray(mel_pred)),
    aspect="auto",
    origin="lower",
    interpolation="none",
)
ax2.set_title("Mel of generated audio")
# The colorbar must be bound to the second image, not the first.
fig.colorbar(im2, ax=ax2)
fig.canvas.draw()

ValueError: cannot reshape array of size 8192 into shape (1,1,81920)

In [74]:
# Audio players for listening comparison: first the reference clip, then the
# generator's output. Both names (`wav`, `wav_pred`) come from the plotting
# cell, so that cell must have run first.
display(Audio(wav[0], rate=SAMPLE_RATE))
display(Audio(wav_pred[0], rate=SAMPLE_RATE))
