In [14]:
import os
import glob

import pandas as pd

import torch

from IPython.display import Audio, HTML, display

from itables import init_notebook_mode, show
init_notebook_mode(all_interactive=True)


from training.datasets import LibriTTSDatasetAcoustic

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Bare expression so the notebook echoes the selected device.
device


<IPython.core.display.Javascript object>

device(type='cuda')

In [3]:
# Build the acoustic dataset once; `get_data` below reads it as a global.
dataset = LibriTTSDatasetAcoustic()
# Map every entry of the wrapped dataset's walker to its position so an item
# can be fetched by identifier instead of by integer index.
# NOTE(review): `_walker` is a private attribute of the wrapped dataset.
index_dict = {
    utterance_id: position
    for position, utterance_id in enumerate(dataset.dataset._walker)
}


def get_data(filename):
    """Return the dataset item that corresponds to an audio file name.

    The extension is stripped from ``filename`` and the remaining stem is
    looked up in the module-level ``index_dict``; the position found there is
    used to index the module-level ``dataset``. A stem that is missing from
    ``index_dict`` raises ``KeyError``.
    """
    stem, _extension = os.path.splitext(filename)
    return dataset[index_dict[stem]]
    

def load_text(file_path):
    """Read a whole text file and return its contents as a string.

    The file is decoded as UTF-8 explicitly rather than with the
    platform-dependent default encoding, so the same file yields the same
    text on every operating system.

    Args:
        file_path: Path of the file to read.

    Returns:
        The complete contents of the file.
    """
    with open(file_path, "r", encoding="utf-8") as text_file:
        return text_file.read()


def create_audio(src: str, mimetype: str = "audio/wav") -> str:
    """Build the HTML markup for an ``<audio>`` player.

    Both values are HTML-escaped before being placed inside the double-quoted
    attributes, so a path containing characters that are special in HTML
    (``&``, ``<``, ``>``, ``"``, ``'``) cannot break the generated markup.
    Values without such characters produce exactly the same string as an
    unescaped interpolation would.

    Args:
        src: Location of the audio file, used as the ``src`` attribute.
        mimetype: MIME type placed in the ``type`` attribute.

    Returns:
        The ``<audio>`` element as a string.
    """
    # Function-scope import keeps this definition self-contained.
    from html import escape

    safe_src = escape(src, quote=True)
    safe_type = escape(mimetype, quote=True)
    return f"""<audio controls="controls"><source src="{safe_src}" type="{safe_type}" />Your browser does not support the audio element.</audio>"""


def audio_for_speaker_libri(dataset = "train-clean-100"):
    """Collect one example utterance per speaker of a LibriTTS subset.

    For every speaker directory under
    ``datasets_cache/LIBRITTS/LibriTTS/<dataset>`` the first chapter directory
    (in sorted order) is taken, and the largest ``.wav`` file of that chapter
    is used as the speaker's example.

    Directory listings are filtered and sorted rather than indexed directly:
    ``os.listdir`` returns entries in arbitrary order and may include stray
    non-directory files. Speakers without a chapter directory, or whose chosen
    chapter holds no ``.wav`` file, are skipped instead of raising
    ``IndexError``.

    Args:
        dataset: Name of the subset directory (e.g. ``"train-clean-360"``).

    Returns:
        A list with one dict per speaker containing ``"READER"`` (the speaker
        directory name), ``"AUDIO"`` (player markup from ``create_audio``) and
        every key of the mapping returned by ``get_data`` for the chosen file.

    Raises:
        KeyError: Propagated from ``get_data`` when the chosen file is not
            present in the dataset index.
    """
    subset_dir = f"datasets_cache/LIBRITTS/LibriTTS/{dataset}"

    def _subdirs(parent):
        # Names of the child directories of `parent`, in a deterministic order.
        return sorted(
            name
            for name in os.listdir(parent)
            if os.path.isdir(os.path.join(parent, name))
        )

    audios_and_text = []
    for speaker in _subdirs(subset_dir):
        speaker_dir = os.path.join(subset_dir, speaker)

        chapters = _subdirs(speaker_dir)
        if not chapters:
            continue
        chapter_dir = os.path.join(speaker_dir, chapters[0])

        # Sort first so that ties on file size resolve the same way every run.
        wav_paths = sorted(glob.glob(os.path.join(chapter_dir, '*.wav')))
        if not wav_paths:
            continue
        audio_path = max(wav_paths, key=os.path.getsize)

        audios_and_text.append(
            {
                "READER": speaker,
                "AUDIO": create_audio(audio_path),
                **get_data(os.path.basename(audio_path)),
            }
        )

    return audios_and_text


# Try the helper on the "train-clean-360" subset; the bare call makes the
# notebook echo the returned list of per-speaker records.
audio_for_speaker_libri("train-clean-360")

[{'READER': '1061',
  'AUDIO': '<audio controls="controls"><source src="datasets_cache/LIBRITTS/LibriTTS/train-clean-360\\1061\\146197\\1061_146197_000015_000000.wav" type="audio/wav" />Your browser does not support the audio element.</audio>',
  'id': '1061_146197_000015_000000',
  'wav': tensor([[ 5.4999e-11, -3.9558e-10,  1.1043e-09,  ...,  7.8759e-02,
            2.6045e-01, -1.5527e-02]]),
  'mel': tensor([[ -3.0387,  -2.2826,  -2.4592,  ...,  -3.3541,  -3.6037,  -3.0326],
          [ -3.2870,  -3.3729,  -4.0267,  ...,  -4.8843,  -4.8406,  -3.9526],
          [ -4.1049,  -4.2063,  -5.7389,  ...,  -5.6448,  -5.1445,  -4.1173],
          ...,
          [ -8.3479,  -8.5202,  -9.5800,  ...,  -9.9146, -10.0145,  -4.5424],
          [ -8.5265,  -8.7139,  -9.7501,  ...,  -9.8467, -10.0137,  -4.5241],
          [ -9.5756,  -9.7532, -10.9799,  ..., -10.9260, -11.0370,  -5.0525]]),
  'pitch': tensor([279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 279.1139,
          279.1139, 2

In [4]:
from training.modules import AcousticModule
from training.loss import FastSpeech2LossGen
import ipywidgets as widgets

# Loss object used later to score predictions (built with fine_tuning disabled).
loss = FastSpeech2LossGen(fine_tuning=False)

# Candidate checkpoints: files in ./checkpoints named "epoch=<N>-....ckpt".
ckpt_files = [f for f in os.listdir('checkpoints') if f.startswith('epoch=') and f.endswith('.ckpt')]
if not ckpt_files:
    # Fail with an actionable message instead of the IndexError that
    # `ckpt_files[-1]` below would otherwise raise.
    raise FileNotFoundError("No 'epoch=*.ckpt' files found in ./checkpoints")
# Sort by epoch number
ckpt_files = sorted(ckpt_files, key=lambda s: int(s.split('-')[0].split('=')[1]))

dropdown = widgets.Dropdown(
    options=ckpt_files,
    description='Select Checkpoint:',
    value=ckpt_files[-1], # Select the last (highest-epoch) checkpoint by default
)

button = widgets.Button(description="Load Checkpoint")
output = widgets.Output()

display(dropdown)
display(button, output)

selected_checkpoint = f'./checkpoints/{dropdown.value}'

# Load the default selection eagerly so `model` exists without any click.
model = AcousticModule.load_from_checkpoint(selected_checkpoint).to(device)
model.eval()
print(f"Loaded checkpoint: {selected_checkpoint}")

def on_button_clicked(b):
    """Reload the global ``model`` from the checkpoint selected in the dropdown.

    The handler reads the module-level ``dropdown``, ``button``, ``output``
    and ``device`` at click time and rebinds the module-level ``model``.
    Progress messages are printed inside the ``output`` widget. The button is
    disabled while loading and is re-enabled in a ``finally`` clause, so it
    does not stay disabled if loading raises.

    Args:
        b: The widget that triggered the callback (unused).
    """
    global model
    button.disabled = True

    try:
        with output:
            print("Loading model, please wait...")

            # Load model with the selected checkpoint
            model = AcousticModule.load_from_checkpoint(f'./checkpoints/{dropdown.value}').to(device)
            model.eval()

            print(f"Loaded checkpoint: {dropdown.value}")
    finally:
        button.disabled = False

button.on_click(on_button_clicked)


Dropdown(description='Select Checkpoint:', index=5, options=('epoch=516-step=100828.ckpt', 'epoch=4796-step=43…

Button(description='Load Checkpoint', style=ButtonStyle())

Output()



Loaded checkpoint: ./checkpoints/epoch=5482-step=601951.ckpt


In [5]:
def prepared_dataset_subset(dataset = "train-clean-100"):
    """Join LibriTTS speaker metadata with one example utterance per speaker.

    Reads ``speakers.tsv``, keeps the rows whose ``SUBSET`` equals ``dataset``
    and inner-joins them on ``READER`` with the records returned by
    ``audio_for_speaker_libri(dataset)``.

    ``READER`` is read as a string on purpose: the example records carry the
    speaker id as a directory name (a string), and merging only works when
    both key columns have the same type. The column is converted to ``int``
    after the merge.

    Args:
        dataset: Name of the LibriTTS subset (e.g. ``"train-clean-360"``).

    Returns:
        A DataFrame with the speaker metadata columns plus every field of the
        example records, one row per speaker present in both sources.
    """
    # Filter speakers by the selected subset
    speakers_df = pd.read_csv(
        "./datasets_cache/LIBRITTS/LibriTTS/speakers.tsv",
        sep="\t",
        names=["READER", "GENDER", "SUBSET", "NAME"],
        # Do not rely on dtype inference for the merge key.
        dtype={"READER": str},
    )
    selected_speakers_subset = speakers_df[speakers_df["SUBSET"] == dataset]

    audio_example_for_author = audio_for_speaker_libri(dataset)

    # Convert the list of per-speaker records to a DataFrame
    audio_example_df = pd.DataFrame(audio_example_for_author)
    merged_subset = pd.merge(selected_speakers_subset, audio_example_df, on='READER')
    merged_subset['READER'] = merged_subset['READER'].astype(int)

    return merged_subset

# Build the demo table for the "train-clean-360" subset and preview only the
# speaker metadata and audio-player columns.
example_demo = prepared_dataset_subset("train-clean-360")
example_demo[['READER', 'GENDER', 'SUBSET', 'NAME', 'AUDIO']]
# Alternative: render the same columns as an interactive itables table.
# show(example_demo[['READER', 'GENDER', 'SUBSET', 'NAME', 'AUDIO']])

Unnamed: 0,READER,GENDER,SUBSET,NAME,AUDIO
0,1061,F,train-clean-360,Missie,"<audio controls=""controls""><source src=""datase..."


In [27]:
from model.helpers.tools import get_mask_from_lengths

# One record per speaker: metadata, reference and predicted audio players,
# and the individual loss terms.
example_demo2 = []

for _, row in example_demo.iterrows():
    # collate_fn receives a single-row list, so every batch holds one utterance.
    # Tensors are moved to the selected device; other entries are kept as is.
    batch = [
        r.to(device) if isinstance(r, torch.Tensor) else r
        for r in dataset.collate_fn([row])
    ]

    (
        _,
        _,
        speakers,
        texts,
        src_lens,
        mels,
        pitches,
        pitches_stat,
        mel_lens,
        langs,
        attn_priors,
        _,
    ) = batch

    src_mask = get_mask_from_lengths(src_lens)
    mel_mask = get_mask_from_lengths(mel_lens)

    with torch.no_grad():
        # Deliberately NOT named `output`: that global is the widgets.Output()
        # used by the checkpoint button handler, and rebinding it here would
        # break the handler's `with output:` on the next click.
        model_output = model.acoustic_model.forward_train(
            x=texts,
            speakers=speakers,
            src_lens=src_lens,
            mels=mels,
            mel_lens=mel_lens,
            pitches=pitches,
            pitches_range=pitches_stat, # type: ignore
            langs=langs,
            attn_priors=attn_priors,
        )

    y_pred = model_output["y_pred"]
    log_duration_prediction = model_output["log_duration_prediction"]
    p_prosody_ref = model_output["p_prosody_ref"]
    p_prosody_pred = model_output["p_prosody_pred"]
    pitch_prediction = model_output["pitch_prediction"]

    # NOTE(review): `step` is a fixed constant here; presumably it only feeds
    # step-dependent loss weighting — confirm against FastSpeech2LossGen.
    (
        total_loss,
        mel_loss,
        ssim_loss,
        duration_loss,
        u_prosody_loss,
        p_prosody_loss,
        pitch_loss,
        ctc_loss,
        bin_loss,
    ) = loss(
        src_masks=src_mask,
        mel_masks=mel_mask,
        mel_targets=mels,
        mel_predictions=y_pred,
        log_duration_predictions=log_duration_prediction,
        u_prosody_ref=model_output["u_prosody_ref"],
        u_prosody_pred=model_output["u_prosody_pred"],
        p_prosody_ref=p_prosody_ref,
        p_prosody_pred=p_prosody_pred,
        pitch_predictions=pitch_prediction,
        p_targets=model_output["pitch_target"],
        durations=model_output["attn_hard_dur"],
        attn_logprob=model_output["attn_logprob"],
        attn_soft=model_output["attn_soft"],
        attn_hard=model_output["attn_hard"],
        src_lens=src_lens,
        mel_lens=mel_lens,
        step=50000,
    )

    # Vocoding is inference only, so run it without building an autograd graph.
    with torch.no_grad():
        wav_pred = model.vocoder_module.forward(y_pred)

    # NOTE(review): rate=22050 assumes the vocoder emits 22.05 kHz audio — confirm.
    example_demo2.append({
        "READER": row["READER"],
        "GENDER": row["GENDER"],
        "SUBSET": row["SUBSET"],
        "NAME": row["NAME"],
        "AUDIO": row["AUDIO"],
        "PRED_AUDIO": HTML(Audio(wav_pred.detach().cpu().numpy(), rate=22050)._repr_html_()),
        "ID": row["id"],
        "TEXT": row["normalized_text"],
        # Losses
        "total_loss": total_loss.item(),
        "mel_loss": mel_loss.item(),
        "ssim_loss": ssim_loss.item(),
        "duration_loss": duration_loss.item(),
        "u_prosody_loss": u_prosody_loss.item(),
        "p_prosody_loss": p_prosody_loss.item(),
        "pitch_loss": pitch_loss.item(),
        "ctc_loss": ctc_loss.item(),
        "bin_loss": bin_loss.item(),
    })


In [26]:
# Render the per-speaker results as an interactive itables table.
show(pd.DataFrame(example_demo2))
# TODO: order by loss (ascending) and generate audio for the top 10.

READER,GENDER,SUBSET,NAME,AUDIO,PRED_AUDIO,ID,TEXT,total_loss,mel_loss,ssim_loss,duration_loss,u_prosody_loss,p_prosody_loss,pitch_loss,ctc_loss,bin_loss
Loading... (need help?),,,,,,,,,,,,,,,,


In [None]:
import torchaudio
from training.preprocess import PreprocessLibriTTS

# Scratch check: run the acoustic preprocessing on a single wav file.
file_audio = 'datasets_cache/LIBRITTS/LibriTTS/train-clean-360/1061/146197/1061_146197_000015_000000.wav'

lang = "en"
preprocess_libtts = PreprocessLibriTTS(lang)

waveform, sample_rate = torchaudio.load(file_audio) # type: ignore

row = example_demo.iloc[0]

# Assemble the input tuple for the acoustic preprocessing step.
# NOTE(review): the inference loop earlier in this notebook reads the transcript
# from `row["normalized_text"]`; a `TEXT` key is only written into the
# `example_demo2` records — confirm `row.TEXT` resolves on a fresh kernel.
# NOTE(review): 111 / "111" look like placeholder chapter and utterance ids —
# confirm against what PreprocessLibriTTS.acoustic expects.
data = (waveform, sample_rate, row.TEXT, row.TEXT, row.READER, 111, "111")

data = preprocess_libtts.acoustic(data)
data

PreprocessForAcousticResult(wav=tensor([ 5.4999e-11, -3.9558e-10,  1.1043e-09,  ...,  7.8759e-02,
         2.6045e-01, -1.5527e-02]), mel=tensor([[ -3.0387,  -2.2826,  -2.4592,  ...,  -3.3541,  -3.6037,  -3.0326],
        [ -3.2870,  -3.3729,  -4.0267,  ...,  -4.8843,  -4.8406,  -3.9526],
        [ -4.1049,  -4.2063,  -5.7389,  ...,  -5.6448,  -5.1445,  -4.1173],
        ...,
        [ -8.3479,  -8.5202,  -9.5800,  ...,  -9.9146, -10.0145,  -4.5424],
        [ -8.5265,  -8.7139,  -9.7501,  ...,  -9.8467, -10.0137,  -4.5241],
        [ -9.5756,  -9.7532, -10.9799,  ..., -10.9260, -11.0370,  -5.0525]]), pitch=tensor([279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 279.1139,
        279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 275.6250, 272.2222,
        262.5000, 259.4118, 272.3529, 285.2941, 298.2353, 311.1765, 324.1176,
        337.0588, 350.0000, 355.6452, 361.4754, 361.4754, 361.4754, 361.4754,
        361.4754, 361.4754, 361.4754, 355.6452, 344.5312, 339.2308, 324.2

In [None]:
from training.datasets import LibriTTSDatasetAcoustic

# NOTE(review): repeats the `dataset = LibriTTSDatasetAcoustic()` call made near
# the top of the notebook; running this cell constructs the dataset again and
# rebinds the global name.
dataset = LibriTTSDatasetAcoustic()

{'1061_146197_000015_000000',
 '1061_146197_000016_000000',
 '1061_152224_000005_000004',
 '1061_152224_000009_000001',
 '1061_152224_000009_000002',
 '1061_152224_000011_000000',
 '1061_152224_000011_000006',
 '1061_152224_000011_000009',
 '1061_152224_000012_000000',
 '1061_152224_000013_000004',
 '1061_152224_000013_000005',
 '1061_152224_000015_000000',
 '1061_152224_000019_000000',
 '1061_152224_000020_000007',
 '1061_152224_000020_000009',
 '1061_152224_000020_000014',
 '1061_152224_000026_000001',
 '1061_152224_000028_000005'}

In [None]:
# Rebuild the identifier -> position lookup from the dataset walker, then show
# the position of one known utterance id.
index_dict = {
    utterance_id: position
    for position, utterance_id in enumerate(dataset.dataset._walker)
}
index_dict['1061_146197_000015_000000']

0

In [None]:
# Inspect the first dataset item (a dict with keys such as 'id', 'wav', 'mel', 'pitch').
dataset[0]

{'id': '1061_146197_000015_000000',
 'wav': tensor([[ 5.4999e-11, -3.9558e-10,  1.1043e-09,  ...,  7.8759e-02,
           2.6045e-01, -1.5527e-02]]),
 'mel': tensor([[ -3.0387,  -2.2826,  -2.4592,  ...,  -3.3541,  -3.6037,  -3.0326],
         [ -3.2870,  -3.3729,  -4.0267,  ...,  -4.8843,  -4.8406,  -3.9526],
         [ -4.1049,  -4.2063,  -5.7389,  ...,  -5.6448,  -5.1445,  -4.1173],
         ...,
         [ -8.3479,  -8.5202,  -9.5800,  ...,  -9.9146, -10.0145,  -4.5424],
         [ -8.5265,  -8.7139,  -9.7501,  ...,  -9.8467, -10.0137,  -4.5241],
         [ -9.5756,  -9.7532, -10.9799,  ..., -10.9260, -11.0370,  -5.0525]]),
 'pitch': tensor([279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 279.1139,
         279.1139, 279.1139, 279.1139, 279.1139, 279.1139, 275.6250, 272.2222,
         262.5000, 259.4118, 272.3529, 285.2941, 298.2353, 311.1765, 324.1176,
         337.0588, 350.0000, 355.6452, 361.4754, 361.4754, 361.4754, 361.4754,
         361.4754, 361.4754, 361.4754, 3

In [None]:
import json

# LibriTTS reader ids to work with; edit this list to choose other speakers,
# or derive it from a DataFrame column, e.g.:
# speaker_ids = selected_speakers_subset['READER'].tolist()
speaker_ids = [19, 26, 27, 32, 39, 40, 60, 78]

# Mapping from LibriTTS reader id (string key) to the model's speaker index.
# The file is opened in a `with` block so the handle is closed deterministically,
# and decoded as UTF-8 explicitly instead of with the platform default.
with open("./speaker_id_mapping_libri.json", "r", encoding="utf-8") as mapping_file:
    speakers_libriid_speakerid = json.load(mapping_file)

# Inverse mapping: model speaker index -> LibriTTS reader id.
speakers_speakerid_libriid = {
    speaker_index: libri_id
    for libri_id, speaker_index in speakers_libriid_speakerid.items()
}

# Sanity check: both directions should agree for one known pair.
speakers_speakerid_libriid[0], speakers_libriid_speakerid["14"]

('14', 0)