# Experiments results

This notebook visualizes the results of the research: the speech generated by models trained with various setups, a comparison of their performance, etc.

This notebook assumes the following models have been trained and compiled:
- Vanilla acoustic model without GST support
- Acoustic model trained with GST support but compiled as vanilla
- Acoustic model accepting GST weights as input
- Acoustic model accepting reference speech as input


Additionally, the following models are expected to be trained:
- GST predictor


First, all necessary packages are imported and the paths to the components used are set.

In [None]:
import sys
import os
import numpy as np
import json

from tensorboard.backend.event_processing import event_accumulator
from torchvision import transforms
import torch
import matplotlib.pyplot as plt
import torchaudio

sys.path.append('../src')

from data import data_loading
from data import visualization
from data.preprocessing import text as text_prep
from models.gst_predictor import utils as gst_pred_utils
from utilities import diffusion as diff_utils

In [None]:
# Root of the workspace holding the trained artefacts and the dataset.
# Defaults to the devcontainer location used for the experiments; export
# GST_WORKSPACE_ROOT to run the notebook from a different checkout.
WORKSPACE_ROOT = os.environ.get('GST_WORKSPACE_ROOT', '/home/devcontainer/workspace')
TMP_DIR = f'{WORKSPACE_ROOT}/tmp'

# Compiled acoustic models, one per GST setup (loaded with torch.jit.load or
# handed to the inference script through `compiled_model_path`).
COMPILED_VANILLA_PATH = f'{TMP_DIR}/compiled_vanilla.pt'
COMPILED_GST_VANILLA_PATH = f'{TMP_DIR}/compiled_gst_vanilla.pt'
COMPILED_GST_REF_PATH = f'{TMP_DIR}/compiled_gst_reference.pt'
COMPILED_GST_WEIGHTS_PATH = f'{TMP_DIR}/compiled_gst_weights.pt'

GST_PREDICTOR_CHECKPOINT_PATH = f'{TMP_DIR}/gst_predictor/checkpoints/gst_predictor_ckpt_9'

LJSPEECH_DS_PATH = f'{WORKSPACE_ROOT}/.vscode/13100-dataset'

# TensorBoard event files written during training of the individual models.
GST_MODEL_TB_OUTPUT_PATH = f'{TMP_DIR}/acoustic_gst/events.out.tfevents.1733311117.867ead6dfff8.69863.0'
VANILLA_MODEL_TB_OUTPUT_PATH = f'{TMP_DIR}/acoustic_vanilla/events.out.tfevents.1733305123.867ead6dfff8.37740.0'
MEL_TO_LIN_MODEL_TB_OUTPUT_PATH = f'{TMP_DIR}/mel_to_lin/events.out.tfevents.1733482633.867ead6dfff8.3499.0'
GST_PRED_MODEL_TB_OUTPUT_PATH = f'{TMP_DIR}/gst_predictor/events.out.tfevents.1733681817.fd1e12a87d0d.179259.0'

In [None]:

def _load_tb_events(events_path):
    """Create an EventAccumulator for ``events_path`` and read the event file into it."""
    accumulator = event_accumulator.EventAccumulator(events_path)
    accumulator.Reload()
    return accumulator


# Compiled GST acoustic models, loaded in the same order as before:
# vanilla-compiled, reference-speech input, weights input.
gst_vanilla_model, gst_ref_model, gst_weights_model = (
    torch.jit.load(compiled_path)
    for compiled_path in (COMPILED_GST_VANILLA_PATH,
                          COMPILED_GST_REF_PATH,
                          COMPILED_GST_WEIGHTS_PATH)
)

# TensorBoard logs of the individual training runs.
gst_tb_events = _load_tb_events(GST_MODEL_TB_OUTPUT_PATH)
vanilla_tb_events = _load_tb_events(VANILLA_MODEL_TB_OUTPUT_PATH)
mel_to_lin_tb_events = _load_tb_events(MEL_TO_LIN_MODEL_TB_OUTPUT_PATH)
gst_pred_tb_events = _load_tb_events(GST_PRED_MODEL_TB_OUTPUT_PATH)

## Visualization of models' results

In this chapter the compiled models are loaded and, depending on each model's setup, speech samples are generated in different configurations and visualized.

In [None]:
def annotate_spectrogram_with_phoneme_durations(spectrogram, durations, title):
    """Plot ``spectrogram`` and mark the boundaries between consecutive phonemes.

    Args:
        spectrogram: 2-D array-like passed straight to ``imshow``.
        durations: per-phoneme lengths; their running sum gives the boundary
            positions on the horizontal axis.
        title: title of the figure.

    A dashed red vertical line is drawn after every phoneme except the last
    one. Nothing is returned; the figure is left for the notebook to display.
    """
    boundary_positions = np.cumsum(durations)

    _, ax = plt.subplots(figsize=(10, 5))
    ax.imshow(spectrogram, aspect='auto', origin='lower', cmap='viridis')

    # The end of the last phoneme coincides with the end of the plot, so only
    # the first len(durations) - 1 boundaries are drawn.
    for boundary_idx in range(len(durations) - 1):
        ax.axvline(boundary_positions[boundary_idx], color='red', linestyle='--')

    ax.set(title=title,
           xlabel='Time bin index',
           ylabel='Frequency bin index')

# Only the third dataset returned by the loader (the test one) is used here.
split_dir = os.path.join(LJSPEECH_DS_PATH, 'split')
_, _, test_ds = data_loading.get_datasets(split_dir,
                                          train_split_ratio=0.98,
                                          n_test_files=100)

sample_spectrogram, sample_transcript, sample_log_durations = test_ds[0]

decoded_transcript = visualization.decode_transcript(
    sample_transcript, text_prep.ENHANCED_MFA_ARP_VOCAB
)
print('Example input phonemes:', decoded_transcript)

# Durations are stored as base-2 logarithms and are inverted with 2 ** x below.
# Entries with a log-duration > 0 are counted and that many leading entries are
# kept - presumably the remaining ones are padding; TODO confirm, since a
# one-frame phoneme (log-duration 0) would not be counted either.
log_durations_np = sample_log_durations.numpy()
durations_mask = (log_durations_np > 0).astype(np.uint16)
n_kept_durations = np.sum(durations_mask).item()

# The small epsilon is added before the truncating integer cast, presumably to
# guard against values landing just below an integer.
pow_durations = (np.power(2, log_durations_np) +
                 1e-4).astype(np.uint16)[:n_kept_durations]

# NOTE(review): this calls the helper from the `visualization` module (two
# arguments), not the three-argument function defined at the top of this cell.
visualization.annotate_spectrogram_with_phoneme_durations(
    sample_spectrogram, pow_durations
)

In [None]:
def visualize_vanilla_output(inference_cfg,
                             cfg_name,
                             gt_wav_path):
    """Run the inference script for ``inference_cfg`` and plot its output next to the ground truth.

    The configuration is dumped to ``.experiments_results/<cfg_name>.json``,
    ``../scripts/inference/run_inference.py`` is executed with that file, and
    the dB-scaled spectrograms of the ground-truth recording and of the
    generated audio are shown one above the other.

    Args:
        inference_cfg: JSON-serialisable dict handed to the inference script;
            its ``'output_path'`` entry is read back here as the generated
            audio file.
        cfg_name: base name (without extension) of the JSON file to write.
        gt_wav_path: path of the ground-truth recording to compare against.

    Raises:
        subprocess.CalledProcessError: if the inference script exits with a
            non-zero status. Failing here prevents plotting a stale output
            file left over from an earlier run.
    """
    import subprocess  # local import: only this helper needs it

    results_dir = '.experiments_results'
    # The directory does not exist in a fresh checkout; open() would fail.
    os.makedirs(results_dir, exist_ok=True)

    cfg_path = f'{results_dir}/{cfg_name}.json'
    with open(cfg_path, 'w') as f:
        json.dump(inference_cfg, f)

    # Argument-list invocation instead of a shell line, so `cfg_name` is never
    # interpreted by a shell. The script runs with the kernel's interpreter
    # and, as before, with PYTHONPATH pointing at ../src.
    completed = subprocess.run(
        [sys.executable, '../scripts/inference/run_inference.py', '--config_path', cfg_path],
        env={**os.environ, 'PYTHONPATH': '../src'},
        capture_output=True,
        text=True,
    )
    # Child output is captured and re-printed so it shows up in the cell.
    print(completed.stdout, end='')
    print(completed.stderr, end='', file=sys.stderr)
    completed.check_returncode()

    # NOTE(review): win_length=256 is shorter than n_fft=1024 and equal to the
    # hop length, while the GST reference config in this notebook uses a
    # 1024-sample window - confirm this is the intended setup.
    spec_trans = transforms.Compose([
        torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, win_length=256),
        torchaudio.transforms.AmplitudeToDB()
    ])

    gt_waveform, _ = torchaudio.load(gt_wav_path)
    output_waveform, _ = torchaudio.load(inference_cfg['output_path'])

    gt_spec = spec_trans(gt_waveform)
    output_spec = spec_trans(output_waveform)

    _, ax = plt.subplots(2, 1, figsize=(20, 20))
    panels = (
        (ax[0], gt_spec, 'Ground truth spectrogram'),
        (ax[1], output_spec, 'Output spectrogram'),
    )
    for panel_ax, spec, panel_title in panels:
        # Only the first channel of each recording is displayed.
        panel_ax.imshow(spec[0].numpy(), aspect='auto', origin='lower', cmap='viridis')
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel('Time bin index')
        panel_ax.set_ylabel('Frequency bin index')

def visualize_gst_weights_output(inference_cfg,
                                 cfg_name,
                                 gt_wav_path):
    """Run inference as ``visualize_vanilla_output`` does, then plot the GST weights that were used.

    Args:
        inference_cfg: inference configuration; besides the entries needed by
            ``visualize_vanilla_output`` it must hold
            ``['gst_weights_cfg']['weights_path']``, the file the weights are
            loaded from with ``torch.load``.
        cfg_name: base name of the JSON config file to write.
        gt_wav_path: path of the ground-truth recording.
    """
    visualize_vanilla_output(inference_cfg, cfg_name, gt_wav_path)

    weights_path = inference_cfg['gst_weights_cfg']['weights_path']
    gst_weights = torch.load(weights_path, weights_only=True)

    _, weights_ax = plt.subplots(figsize=(20, 20))
    weights_ax.plot(gst_weights)
    weights_ax.set_title('GST weights')


def visualize_gst_reference_output(inference_cfg,
                                   cfg_name,
                                   gt_wav_path):
    """Run inference as ``visualize_vanilla_output`` does, then plot the GST reference recording.

    Args:
        inference_cfg: inference configuration; besides the entries needed by
            ``visualize_vanilla_output`` it must hold
            ``['gst_reference_cfg']['reference_audio_path']``, the audio file
            whose dB-scaled spectrogram is displayed.
        cfg_name: base name of the JSON config file to write.
        gt_wav_path: path of the ground-truth recording.
    """
    visualize_vanilla_output(inference_cfg, cfg_name, gt_wav_path)

    reference_audio_path = inference_cfg['gst_reference_cfg']['reference_audio_path']
    gst_ref_waveform, _ = torchaudio.load(reference_audio_path)

    # Same spectrogram settings as used for the ground-truth / output plots.
    to_db_spectrogram = transforms.Compose([
        torchaudio.transforms.Spectrogram(n_fft=1024, hop_length=256, win_length=256),
        torchaudio.transforms.AmplitudeToDB(),
    ])
    gst_ref_spec = to_db_spectrogram(gst_ref_waveform)

    _, ref_ax = plt.subplots(figsize=(20, 10))
    # Only the first channel is displayed.
    ref_ax.imshow(gst_ref_spec[0].numpy(), aspect='auto', origin='lower', cmap='viridis')
    ref_ax.set_title('GST reference spectrogram')

### Vanilla model

In [None]:
# Inference setup for the vanilla acoustic model (no GST input).
vanilla_inference_cfg = {
    "compiled_model_path": COMPILED_VANILLA_PATH,
    "input_phonemes_length": 20,
    "input_text": "The crime, long carried on without detection, was first discovered in eighteen twenty.",
    "gst_mode": "none",
    "scale_max": 45.8506,
    "scale_min": -100.0,
    "output_path": ".experiments_results/vanilla_output.wav",
}

# Recording plotted as the ground truth next to the generated audio.
vanilla_gt_wav_path = os.path.join(LJSPEECH_DS_PATH, 'raw/LJSpeech-1.1/wavs/LJ011-0018.wav')

visualize_vanilla_output(vanilla_inference_cfg, 'vanilla_inference_cfg', vanilla_gt_wav_path)

### Vanilla GST model

In [None]:
# Inference setup for the GST-trained model that was compiled as vanilla,
# hence "gst_mode" stays "none".
vanilla_gst_inference_cfg = {
    "compiled_model_path": COMPILED_GST_VANILLA_PATH,
    "input_phonemes_length": 20,
    "input_text": "The crime, long carried on without detection, was first discovered in eighteen twenty.",
    "gst_mode": "none",
    "scale_max": 45.8506,
    "scale_min": -100.0,
    "output_path": ".experiments_results/vanilla_gst_output.wav",
}

# Recording plotted as the ground truth next to the generated audio.
vanilla_gst_gt_wav_path = os.path.join(LJSPEECH_DS_PATH, 'raw/LJSpeech-1.1/wavs/LJ011-0018.wav')

visualize_vanilla_output(vanilla_gst_inference_cfg,
                         'vanilla_gst_inference_cfg',
                         vanilla_gst_gt_wav_path)

### GST model with input weights

#### Example 1

In [None]:

# First hand-picked set of 32 GST weights; saved to disk because the inference
# config refers to the weights by file path.
gst_weights_1_path = '.experiments_results/gst_weights_1.pt'

weights = torch.tensor([
    0.0003, 0.0100, 0.0038, 0.1589, 0.0324, 0.0657, 0.0221, 0.0317,
    0.0155, 0.0051, 0.0652, 0.0430, 0.0019, 0.0403, 0.0196, 0.0174,
    0.0654, 0.0011, 0.0280, 0.0002, 0.0115, 0.0107, 0.0223, 0.0610,
    0.0195, 0.0172, 0.0160, 0.0040, 0.0269, 0.0028, 0.0345, 0.0111,
])
torch.save(weights, gst_weights_1_path)

gst_weights_inference_cfg_1 = {
    "compiled_model_path": COMPILED_GST_WEIGHTS_PATH,
    "input_phonemes_length": 20,
    "input_text": "The crime, long carried on without detection, was first discovered in eighteen twenty.",
    "gst_mode": "weights",
    "scale_max": 45.8506,
    "scale_min": -100.0,
    "output_path": ".experiments_results/gst_weights_output_1.wav",
    "gst_weights_cfg": {"weights_path": gst_weights_1_path},
}

visualize_gst_weights_output(
    gst_weights_inference_cfg_1,
    'gst_weights_inference_cfg_1',
    os.path.join(LJSPEECH_DS_PATH, 'raw/LJSpeech-1.1/wavs/LJ011-0018.wav'),
)

#### Example 2

In [None]:
# Second hand-picked set of 32 GST weights; saved to disk because the inference
# config refers to the weights by file path.
gst_weights_2_path = '.experiments_results/gst_weights_2.pt'

weights = torch.tensor([
    0.0095, 0.0242, 0.0075, 0.1424, 0.0184, 0.0415, 0.0071, 0.0085,
    0.0049, 0.0405, 0.0976, 0.0585, 0.1090, 0.0401, 0.0091, 0.0186,
    0.0436, 0.0144, 0.0827, 0.0013, 0.0338, 0.0018, 0.0119, 0.0194,
    0.0384, 0.0081, 0.0124, 0.0512, 0.0120, 0.0130, 0.0157, 0.0028,
])
torch.save(weights, gst_weights_2_path)

gst_weights_inference_cfg_2 = {
    "compiled_model_path": COMPILED_GST_WEIGHTS_PATH,
    "input_phonemes_length": 20,
    "input_text": "The crime, long carried on without detection, was first discovered in eighteen twenty.",
    "gst_mode": "weights",
    "scale_max": 45.8506,
    "scale_min": -100.0,
    "output_path": ".experiments_results/gst_weights_output_2.wav",
    "gst_weights_cfg": {"weights_path": gst_weights_2_path},
}

visualize_gst_weights_output(
    gst_weights_inference_cfg_2,
    'gst_weights_inference_cfg_2',
    os.path.join(LJSPEECH_DS_PATH, 'raw/LJSpeech-1.1/wavs/LJ011-0018.wav'),
)

### GST model with input reference speech 

In [None]:
# The same recording serves as the GST reference audio and as the ground truth
# the generated speech is compared with.
reference_wav_path = os.path.join(LJSPEECH_DS_PATH, 'raw/LJSpeech-1.1/wavs/LJ011-0018.wav')

gst_reference_inference_cfg = {
    "compiled_model_path": COMPILED_GST_REF_PATH,
    "input_phonemes_length": 20,
    "input_text": "The crime, long carried on without detection, was first discovered in eighteen twenty.",
    "gst_mode": "reference",
    "scale_max": 45.8506,
    "scale_min": -100.0,
    "output_path": ".experiments_results/gst_reference_output.wav",
    "gst_reference_cfg": {
        "reference_audio_path": reference_wav_path,
        "spectrogram_window_length": 1024,
        "spectrogram_hop_length": 256,
        "n_mels": 80,
        "spec_length": 200,
        "sample_rate": 22050,
    },
}

visualize_gst_reference_output(gst_reference_inference_cfg,
                               'gst_reference_inference_cfg',
                               reference_wav_path)

### GST model with predicted GST weights

In [None]:
class GSTPredictorInferenceModel(torch.nn.Module):
    """Samples GST weights for a phoneme sequence by running the reverse diffusion loop.

    Sampling starts from Gaussian noise and walks the timesteps from
    ``diff_handler.num_steps - 1`` down to ``0``. At every step the decoder
    predicts the noise from the current sample, the timestep and the phoneme
    embedding produced by the encoder, and ``diff_handler.remove_noise``
    produces the next sample. The final sample is de-standardised with the
    supplied mean and standard deviation.

    Args:
        encoder: module mapping the phoneme input to an embedding.
        decoder: module called as ``decoder(sample, timestep, embedding)``
            that returns the predicted noise.
        global_mean: mean added to the final sample.
        global_std: standard deviation the final sample is multiplied by.
        diff_handler: object exposing ``num_steps`` and
            ``remove_noise(sample, predicted_noise, timestep)``.
        n_weights: number of GST weights to sample.
    """

    def __init__(self,
                 encoder: torch.nn.Module,
                 decoder: torch.nn.Module,
                 global_mean: torch.Tensor,
                 global_std: torch.Tensor,
                 diff_handler: 'diff_utils.DiffusionHandler',
                 n_weights: int):

        super().__init__()

        self._encoder = encoder
        self._decoder = decoder
        self._global_mean = global_mean
        self._global_std = global_std
        self._diff_handler = diff_handler
        self._n_weights = n_weights

    def forward(self, phonemes):
        """Return a ``(1, n_weights)`` tensor of sampled, de-standardised GST weights.

        The noise tensor always has a single row, so one sequence is expected
        per call (NOTE(review): confirm whether the decoder would broadcast a
        batched ``phonemes`` input).
        """
        # Inference-only model: never track gradients through the sampling
        # loop, even when the caller forgets to wrap the call in no_grad.
        with torch.no_grad():
            noised_data = torch.randn(1, self._n_weights)
            phoneme_embedding = self._encoder(phonemes)

            # Explicit floating dtype; matches what the legacy
            # torch.Tensor([...]) constructor produced.
            timestep_dtype = torch.get_default_dtype()

            for timestep in reversed(range(self._diff_handler.num_steps)):

                predicted_noise = self._decoder(
                    noised_data,
                    torch.tensor([timestep], dtype=timestep_dtype),
                    phoneme_embedding)

                noised_data = self._diff_handler.remove_noise(noised_data,
                                                              predicted_noise,
                                                              timestep)

            return (noised_data * self._global_std) + self._global_mean



In [None]:

# Architecture settings handed to the GST predictor factory.
gst_pred_cfg = {
    "decoder": {"timestep_embedding_size": 256, "internal_channels": 64, "n_conv_blocks": 6},
    "encoder": {"n_conv_blocks": 6, "embedding_size": 128},
    "dropout_rate": 0.2,
}

# Linear noise schedule used by the diffusion handler.
diff_cfg = {"n_steps": 1000, "beta_min": 0.0001, "beta_max": 0.02}

# Per-weight statistics used to de-standardise the sampled GST weights.
global_mean = torch.tensor([
    0.0362, 0.0354, 0.0347, 0.0343, 0.0259, 0.0346, 0.0244, 0.0342,
    0.0281, 0.0316, 0.0365, 0.0270, 0.0412, 0.0305, 0.0279, 0.0336,
    0.0336, 0.0341, 0.0292, 0.0277, 0.0288, 0.0273, 0.0328, 0.0329,
    0.0305, 0.0232, 0.0295, 0.0271, 0.0316, 0.0271, 0.0413, 0.0273,
])

global_stddev = torch.tensor([
    0.0289, 0.0285, 0.0275, 0.0289, 0.0193, 0.0269, 0.0212, 0.0265,
    0.0265, 0.0249, 0.0389, 0.0205, 0.0345, 0.0271, 0.0283, 0.0286,
    0.0286, 0.0281, 0.0233, 0.0227, 0.0224, 0.0250, 0.0324, 0.0276,
    0.0263, 0.0215, 0.0259, 0.0234, 0.0255, 0.0227, 0.0366, 0.0225,
])

# NOTE(review): GST_PREDICTOR_CHECKPOINT_PATH is defined in the configuration
# cell, but nothing visible here loads it. Unless create_model_components
# restores weights internally, the predictor below runs with freshly
# initialised parameters - confirm and load the checkpoint if needed.
gst_pred_components = gst_pred_utils.create_model_components((20, 73), gst_pred_cfg, 'cpu')

noise_scheduler = diff_utils.LinearScheduler(diff_cfg['beta_min'],
                                             diff_cfg['beta_max'],
                                             diff_cfg['n_steps'])
diffusion_handler = diff_utils.DiffusionHandler(noise_scheduler, 'cpu')

# eval() returns the module itself, so construction and mode switch are chained.
gst_predictor_inf = GSTPredictorInferenceModel(
    gst_pred_components.encoder,
    gst_pred_components.decoder,
    global_mean,
    global_stddev,
    diffusion_handler,
    32,
).eval()

In [None]:
input_text = 'The crime, long carried on without detection, was first discovered in eighteen twenty.'
predicted_weights_path = '.experiments_results/gst_predicted_weights.pt'

# Text -> phonemes -> padded, one-hot encoded predictor input.
phonemes = text_prep.G2PTransform()(input_text)
encoded_phonemes = transforms.Compose([
    text_prep.PadSequenceTransform(20),
    text_prep.OneHotEncodeTransform(text_prep.ENHANCED_MFA_ARP_VOCAB)
])(phonemes)

# The predictor starts from random noise; fix the seed so that re-running the
# notebook reproduces the same predicted weights (the value is arbitrary).
torch.manual_seed(0)

# Make sure the target directory exists before the weights are saved.
os.makedirs(os.path.dirname(predicted_weights_path), exist_ok=True)

with torch.no_grad():
    predicted_gst = gst_predictor_inf(encoded_phonemes.unsqueeze(0))
    # NOTE(review): the predictor output is already de-standardised with the
    # dataset mean/std of the weights; a softmax over such small values yields
    # a nearly uniform distribution - confirm this is the intended
    # post-processing.
    predicted_gst = torch.nn.functional.softmax(predicted_gst, dim=-1)
    torch.save(predicted_gst.squeeze(0), predicted_weights_path)

visualize_gst_weights_output(
    {
        "compiled_model_path": COMPILED_GST_WEIGHTS_PATH,
        "input_phonemes_length": 20,
        # Same sentence the weights were predicted for.
        "input_text": input_text,
        "gst_mode": "weights",
        "scale_max": 45.8506,
        "scale_min": -100.0,
        "output_path": ".experiments_results/gst_predicted_weights_output.wav",
        "gst_weights_cfg": {
            "weights_path": predicted_weights_path
        }
    },
    'gst_predicted_weights_inference_cfg',
    os.path.join(LJSPEECH_DS_PATH, 'raw/LJSpeech-1.1/wavs/LJ011-0018.wav')
)