In [None]:
# Shell command: list the NVIDIA GPUs and driver visible from this notebook.
!nvidia-smi

If this command fails, no GPU is available or none was detected.

In [None]:
!pip install -e.

# Demo of Speech Inpainting

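This notebook shows:
- how to load the data and the models;
- how to inpaint speech with our two methods, I_ea (encoder adaptation) and I_da (decoder adaptation);
- how to evaluate the inpainted speech (PESQ, STOI and CER).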

## Import

In [None]:
import os
import sys
import json
import joblib
import torch
import torchaudio
import whisper
import librosa
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt

from transformers import AutoProcessor
from librosa.util import normalize
from scipy.io.wavfile import read, write
from pesq import pesq
from pystoi import stoi
from jiwer import cer

# From I_ea folder
from I_ea.hifi_gan.inference_modified import load_checkpoint as iea_load_checkpoint
from I_ea.hifi_gan.inference_modified import extend_mel
from I_ea.hifi_gan.models import Generator
from I_ea.hifi_gan.env import AttrDict
from I_ea.model import CustomModel
from I_ea.dataset.mel_dump import MAX_WAV_VALUE, get_mel
from I_ea.utils import choose_device
from I_ea.loss_fn import LossFunction
from I_ea.metrics import Metrics

# From I_da folder
from I_da.src.model import CodeGenerator
from I_da.src.multiseries import match_length
from I_da.src.preprocess import normalize_nonzero
from I_da.src.dataset import extract_fo, generate
# NOTE: the AttrDict imported here rebinds (shadows) I_ea.hifi_gan.env.AttrDict above;
# every later `AttrDict(...)` call in this notebook uses this I_da version.
from I_da.src.utils import (
    AttrDict,
    get_audio_files,
    get_feature_reader,
    parse_speaker,
    load_checkpoint,
    scan_checkpoint,
    get_logger,
)

## Function definitions

In [None]:
def load_wav(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The order is the reverse of ``scipy.io.wavfile.read``, which returns
    ``(rate, data)``.
    """
    rate, samples = read(full_path)
    return (samples, rate)

def save_fig(tensor, path, fig_name='orig'):
    """Save a 2-D array (e.g. a mel spectrogram) as ``<path>/<fig_name>.png``.

    The array is drawn with ``imshow`` next to a colorbar and saved with a tight
    bounding box and no padding. The figure is not closed here.

    Args:
        tensor: 2-D data convertible with ``np.array`` (CPU tensors / arrays).
        path: directory the PNG is written to.
        fig_name: file name without the ``.png`` extension.
    """
    destination = os.path.join(path, fig_name + '.png')
    figure, axis = plt.subplots(1, 1, figsize=(8, 4))
    rendered = axis.imshow(np.array(tensor))
    figure.colorbar(rendered, ax=axis)
    figure.savefig(destination, bbox_inches='tight', pad_inches=0)
def plot_wave(waveform, sr, output_path='waveform_plot.png'):
    """Plot a waveform whose middle half is replaced by noise, highlighting that gap in red.

    Samples in ``[n//4, 3n//4)`` are replaced by uniform noise in ``[-0.2, 0.2]``
    to illustrate a missing segment. The whole signal is drawn first, and then the
    gap is drawn on top in red. The y-axis is fixed to ``[-1, 1]``.

    Args:
        waveform: 1-D sequence of samples. It is copied, so the caller's array is
            left unchanged.
        sr: sampling rate in Hz, used to build the time axis.
        output_path: destination of the PNG (defaults to the previous fixed name).
    """
    # Work on a float copy so the caller's waveform is not overwritten.
    samples = np.array(waveform, dtype=float)
    n = len(samples)
    duration = n / sr
    time = np.linspace(0, duration, num=n)

    # Draw n values (not just the gap length) so the global NumPy RNG advances
    # exactly as before; only the gap slice is used.
    noise = np.random.uniform(low=-0.2, high=0.2, size=n)
    gap = slice(n // 4, n * 3 // 4)
    samples[gap] = noise[gap]

    # Use an explicit figure so the plot never lands on another figure that is still open.
    fig, ax = plt.subplots()
    ax.plot(time, samples)
    # The highlighted segment includes one sample past the gap.
    highlight = slice(n // 4, n * 3 // 4 + 1)
    ax.plot(time[highlight], samples[highlight], color='red')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude')
    ax.set_title('Waveform')
    ax.set_ylim([-1, 1])
    fig.savefig(output_path)
    plt.close(fig)

## Parameters

In [None]:
# Region to inpaint, in seconds from the start of the utterance.
start_pos_in_sec = 1.8
end_pos_in_sec = 2.2
assert 0 <= start_pos_in_sec < end_pos_in_sec, "the masked region must be non-empty"
# Number of k-means clusters; selects the km_model_<n_clusters> model in the I_ea section.
n_clusters = 500
# Device index handed to choose_device() in the I_ea section (rebound there).
device = 0
# Not read by the cells in this notebook.
cache_dir = 'I_ea/pretrained_models'
# Not read by the cells in this notebook (the evaluation loads Whisper "small.en").
asr_name = 'openai/whisper-small'
# Audio extension of the I_da manifest entries.
extension = ".flac"

### Parameters for method I_ea (Encoder Adaptation)

In [None]:
# Paths and reference data for the I_ea (encoder adaptation) pipeline.
i_ea_params = {
    'dataset_path': 'I_ea/dataset/VCTK/wavs',  # not read by the cells in this notebook
    'validation_path': 'I_ea/dataset/VCTK_splits/validation.txt',  # not read by the cells in this notebook
    'wave_path': 'I_ea/dataset/VCTK/wavs/p231_368.wav',  # utterance to inpaint
    'wave_text': 'That deal is a joke.',  # reference transcript used for the CER
    'save_pred': 'I_ea/prediction/VCTK',  # root of the per-utterance output folder
    'path2dict': 'I_ea/results',  # not read by the cells in this notebook
    'checkpoint_file': 'I_ea/hifi_gan/VCTK_V1/g_00022000',  # HiFi-GAN generator; its config.json is in the same folder
    'model_checkpoint': 'I_ea/trained_models/HuBERT_large_VCTK.pt',  # trained HuBERT weights for CustomModel
    'km_model_path': 'I_ea/dataset/kmeans/VCTK/',  # k-means root containing km_model_<n_clusters>/
}

### Parameters for method I_da (Decoder Adaptation)

In [None]:
# Checkpoints, data and output locations for the I_da (decoder adaptation) pipeline.
i_da_params = dict(
    acoustic_model_path="I_da/checkpoints/hubert_large.pt",  # checkpoint for the HuBERT feature reader
    kmeans_model_path="I_da/checkpoints/hubert_large_km500.bin",  # k-means quantizer loaded with joblib
    checkpoint_file="I_da/checkpoints/vctk_hubert_large_500",  # generator checkpoint file or directory
    manifest_path="I_da/datasets/VCTK/manifest.txt",  # manifest listing the audio files
    output_dir="I_da/data/VCTK/prediction",  # destination of the generated wavs
)

## Inpainting process

In [None]:
# One seed for NumPy's global RNG and PyTorch's default generator.
seed = 52
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(seed)

### I_ea Inpainting

In [None]:
# Load the k-means model and the utterance, then build its original and masked mel features.
km_model_path = os.path.join(i_ea_params['km_model_path'], f'km_model_{n_clusters}/model.km')
# NOTE: `device` is rebound to choose_device()'s result, so re-running this cell
# passes that result (not the index from the parameters cell) back to choose_device().
device = choose_device(device)
loss_instance = LossFunction(km_model_path, device=device)
print("Current device:", device)

wave_name = os.path.basename(i_ea_params['wave_path']).split('.')[0]
save_pred = os.path.join(i_ea_params['save_pred'], wave_name)
os.makedirs(save_pred, exist_ok=True)

# The 22.05 kHz copy feeds the HiFi-GAN mel pipeline; the 16 kHz copy feeds HuBERT.
wave_22, sr_22 = librosa.load(i_ea_params['wave_path'], sr=22050)
wave_16, sr_16 = librosa.load(i_ea_params['wave_path'], sr=16000)
assert sr_22 == 22050
assert sr_16 == 16000

# Mask geometry, expressed at 16 kHz.
sf.write(os.path.join(save_pred, 'orig.wav'), wave_16, sr_16)
mask_ms = int((end_pos_in_sec - start_pos_in_sec) * 1000)
# Despite its name, this is the gap length in 20 ms frames (320 samples at 16 kHz,
# the same hop used for mask_pos below).
mask_50ms = mask_ms // 20  # FIXME: get the hop-size of the STFT analysis
start_mask = int(start_pos_in_sec * 16000)
end_mask = int(end_pos_in_sec * 16000)
mask_pos = start_mask // 320  # index of the first masked 20 ms frame

# Mel of the unmasked signal, saved for reference.
wave_22_orig = wave_22.copy()
wave_22_orig = normalize(wave_22_orig) * 0.95  # makes very little difference
norm_wave_22 = torch.FloatTensor(wave_22_orig)
mel_feats_orig = get_mel(norm_wave_22.unsqueeze(0))
save_fig(mel_feats_orig.squeeze(0), save_pred, fig_name='orig')

# Mel of the signal with the gap zeroed (gap bounds rescaled from 16 kHz to 22.05 kHz).
start_mask_22 = start_mask * sr_22 // sr_16
end_mask_22 = end_mask * sr_22 // sr_16
wave_22_masked = wave_22.copy()
wave_22_masked[start_mask_22:end_mask_22] = 0
wave_22_masked = normalize(wave_22_masked) * 0.95
norm_wave_22 = torch.FloatTensor(wave_22_masked)
mel_feats = get_mel(norm_wave_22.unsqueeze(0))
save_fig(mel_feats.squeeze(0), save_pred, fig_name='masked')
feats = extend_mel(mel_feats)

In [None]:
# Build the HiFi-GAN vocoder from the config.json stored next to its checkpoint,
# then vocode the masked mel features as a baseline.
config_file = os.path.join(os.path.dirname(i_ea_params['checkpoint_file']), 'config.json')
with open(config_file) as f:
    json_config = json.load(f)
# NOTE: `AttrDict` resolves to the I_da.src.utils version here, because that import
# comes after (and shadows) I_ea.hifi_gan.env.AttrDict in the imports cell.
h = AttrDict(json_config)
torch.manual_seed(h.seed)
generator = Generator(h).to(device)
state_dict_g = iea_load_checkpoint(i_ea_params['checkpoint_file'], device)
generator.load_state_dict(state_dict_g['generator'])

generator.eval()
generator.remove_weight_norm()
with torch.no_grad():
    y_g_hat = generator(feats.to(device))
    audio = y_g_hat.squeeze() * MAX_WAV_VALUE
    # Clip to the int16 range before casting; out-of-range values would otherwise wrap around.
    audio = np.clip(audio.cpu().numpy(), -32768, 32767).astype('int16')
    sf.write(os.path.join(save_pred, 'hifi_masked.wav'), audio, sr_22)

In [None]:
# Mask the 16 kHz waveform, predict the missing codewords with the trained HuBERT
# and rebuild the gap in mel space.
masked_wave_16 = wave_16.copy()
# Zero from 80 samples into the first masked 20 ms frame up to one sample before
# the end of the last masked frame.
masked_wave_16[mask_pos*320+80:(mask_pos+mask_50ms)*320+79-80] = 0
sf.write(os.path.join(save_pred, 'masked.wav'), masked_wave_16, sr_16)

tokenizer = AutoProcessor.from_pretrained("facebook/hubert-large-ls960-ft")
tokenized_values = tokenizer(masked_wave_16, sampling_rate=sr_16, return_attention_mask=True, return_tensors='pt')
input_values = tokenized_values.input_values.squeeze(0)
attention_mask = tokenized_values.attention_mask.squeeze(0)

# Load a trained HuBERT (type='base' for the base variant).
model = CustomModel(codebook_dim=80, type='large', load_pretrained=False)
model.to(device)
# Map the checkpoint onto the selected device so it also loads on CPU-only machines.
model.load_state_dict(torch.load(i_ea_params['model_checkpoint'], map_location=torch.device(device)))
model.eval()

# Plain ints (not tensors) keep mask_pos intact, so this cell can be re-run.
mask_len = mask_50ms
gap = slice(mask_pos, mask_pos + mask_len)
# The pre-computed labels live under the k-means root directory, in the same
# km_model_<n_clusters>/ folder as model.km (not under the model.km file path).
path2centroids = os.path.join(i_ea_params['km_model_path'], f'km_model_{n_clusters}/label_dir/validation')

with torch.no_grad():
    inputs = input_values.unsqueeze(0).to(device)
    attention_masks = attention_mask.unsqueeze(0).to(device)
    labels = torch.load(os.path.join(path2centroids, wave_name + '_labels.pt')).t()
    labels = labels[gap].unsqueeze(0).to(device)
    outputs = model(inputs, attention_masks)
    # Model outputs restricted to the masked frames (a batch of one utterance).
    values = torch.zeros((1, mask_len, outputs.shape[-1])).to(device)
    values[0] = outputs[0, gap, :]

    # Compute loss
    loss, pred_labels = loss_instance.cos_sim(values, labels)
    cos_sim_pred_target = loss_instance.cos_sim_target_labels(pred_labels, labels)
    print("Loss:", loss.item())

    # Fill the gap on copies so `mel_feats` keeps the masked mel from the previous cell.
    masked_mel = mel_feats.to(device)

    # "Expected" mel: gap filled from the target labels' embeddings (presumably k-means centroids).
    expected_mel = masked_mel.clone()
    exp_mask_mel = loss_instance.all_embeds_t_c[0, labels[0, :], :] + loss_instance.center_
    expected_mel[0, :, gap] = exp_mask_mel.T
    save_fig(expected_mel.cpu().squeeze(0), save_pred, fig_name='expected')
    exp_feats = extend_mel(expected_mel)

    # Inpainted mel: gap filled from the predicted labels' embeddings.
    inpainted_mel = masked_mel.clone()
    pred_mels = loss_instance.all_embeds_t_c[0, pred_labels[0, :], :] + loss_instance.center_
    inpainted_mel[0, :, gap] = pred_mels.T
    save_fig(inpainted_mel.cpu().squeeze(0), save_pred, fig_name='inpainted')
    feats = extend_mel(inpainted_mel)

    print("Target codewords: ", labels)
    print("Predicted codewords: ", pred_labels)
    print(cos_sim_pred_target.shape, cos_sim_pred_target)
    print("Average Cosine Similarity: ", cos_sim_pred_target.mean())

In [None]:
# Generate the expected inpainting and the actual inpainting for comparison.
# After the loop, `audio` holds the inpainted waveform (int16 at sr_22), which the
# evaluation cell reads.
# NOTE: `generator` must still be the HiFi-GAN vocoder; the I_da section rebinds it.
with torch.no_grad():
    for mel_input, out_name in ((exp_feats, 'expected_inpaint'), (feats, 'inpainted')):
        y_g_hat = generator(mel_input.to(device))
        audio = y_g_hat.squeeze() * MAX_WAV_VALUE
        # Clip to the int16 range before casting; out-of-range values would otherwise wrap around.
        audio = np.clip(audio.cpu().numpy(), -32768, 32767).astype('int16')
        sf.write(os.path.join(save_pred, out_name + '.wav'), audio, sr_22)

### I_da Inpainting

In [None]:
# Load the I_da models: generator config/checkpoint, HuBERT feature reader,
# k-means quantizer and the unit-to-speech generator.
# NOTE: this rebinds `device`, `h` and `generator` from the I_ea section.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Resolve the config and the generator checkpoint once. `checkpoint_file` is
# either a directory (latest "g_*" checkpoint is used) or a checkpoint file.
ckpt_path = i_da_params["checkpoint_file"]
if os.path.isdir(ckpt_path):
    config_file = os.path.join(ckpt_path, "config.json")
    cp_g = scan_checkpoint(ckpt_path, "g_")
else:
    config_file = os.path.join(os.path.dirname(ckpt_path), "config.json")
    cp_g = ckpt_path

with open(config_file) as f:
    h = AttrDict(json.load(f))

if cp_g is None or not os.path.isfile(cp_g):
    raise FileNotFoundError(f"Didn't find checkpoints for {cp_g}")

# Feature extraction
feature_reader_cls = get_feature_reader("hubert")
reader = feature_reader_cls(
    checkpoint_path=i_da_params["acoustic_model_path"], layer=-1, use_cuda=use_cuda
)

# K-means model. NOTE: joblib.load unpickles the file; only load trusted checkpoints.
with open(i_da_params["kmeans_model_path"], "rb") as f:
    kmeans_model = joblib.load(f)
kmeans_model.verbose = False
# Private attribute (presumably scikit-learn KMeans) controlling prediction threads.
kmeans_model._n_threads = 40

generator = CodeGenerator(h).to(device)
state_dict_g = load_checkpoint(cp_g)
generator.load_state_dict(state_dict_g["generator"])
generator.eval()
generator.remove_weight_norm()

In [None]:
# Locate the I_ea utterance in the I_da manifest and compute its speaker embedding.
root_dir, fnames, _ = get_audio_files(i_da_params["manifest_path"])
id_to_spkr = sorted({parse_speaker(f, "_") for f in fnames})
spk_name_to_idx = {spk_name: spk_idx for spk_idx, spk_name in enumerate(id_to_spkr)}
wav2mel = torch.jit.load(h["wav2mel_path"])
embedder = torch.jit.load(h["embedder_path"]).eval()

# First manifest entry whose path contains the utterance name (substring match).
file = next((x for x in fnames if wave_name in x), None)
if file is None:
    raise FileNotFoundError(f"{wave_name!r} not found in {i_da_params['manifest_path']}")

os.makedirs(i_da_params["output_dir"], exist_ok=True)

spk_idx = spk_name_to_idx[parse_speaker(file, "_")]
# NOTE(review): gt_file appends `extension` to the manifest entry, while torchaudio.load
# below uses the entry as-is; confirm which one matches the manifest format.
gt_file = os.path.join(root_dir, file + extension)

# Remove the extension as a suffix: str.rstrip(".flac") would strip any trailing
# '.', 'f', 'l', 'a' or 'c' characters, eating into names that end with them.
base_fname = os.path.basename(file)
suffix = "." + extension.lstrip(".")
if base_fname.endswith(suffix):
    base_fname = base_fname[: -len(suffix)]
fname_out_name = base_fname.split(".")[0]

wav_tensor, sample_rate = torchaudio.load(os.path.join(root_dir, file))
mel_tensor = wav2mel(wav_tensor, sample_rate)
emb = embedder.embed_utterance(mel_tensor)
emb = emb.detach().cpu().numpy()

audio_gt = reader.read_audio(gt_file, channel_id="1")
y = audio_gt.copy()

In [None]:
# Mask the gap, re-quantise the signal with HuBERT + k-means, keep the clean units
# outside the gap and resynthesise both unit sequences with the I_da generator.
# NOTE: start_mask/end_mask were computed for 16 kHz audio in the I_ea section; this
# assumes reader.read_audio returns audio at that rate (TODO confirm).
# `y` is the untouched copy of the ground truth from the previous cell; reading only
# from it (never from `audio_gt`, rebound below) keeps this cell re-runnable.

# Create a mask to specify the missing region. The small offset is added before
# masking, so only the gap is exactly zero.
mask = np.ones_like(y)
mask[start_mask:end_mask] = 0
y_inpainting = (y + 1e-6) * mask
audio_mask = y_inpainting.copy()

# HuBERT features of the clean and of the masked signal.
feats = reader.get_feats(None, signal=y, channel_id="1").cpu().numpy()
feats_inpainting = reader.get_feats(None, signal=y_inpainting, channel_id="1").cpu().numpy()

# Discrete units from the k-means quantizer.
code = kmeans_model.predict(feats)
code_inpainting = kmeans_model.predict(feats_inpainting)

# Outside the gap, reuse the units of the clean signal.
hop = h.code_hop_size
code_inpainting[: start_mask // hop] = code[: start_mask // hop]
code_inpainting[end_mask // hop :] = code[end_mask // hop :]

# Normalised F0 and speaker id.
# NOTE(review): F0 is extracted from the unmasked signal, so the pitch contour
# inside the gap comes from the ground truth.
f0 = extract_fo(y, h["sampling_rate"])
fo = normalize_nonzero(f0, np.mean(f0), np.std(f0))
fo = np.expand_dims(fo, axis=0)
spkr = np.array([spk_idx], dtype=np.int64)

# Trim every stream to a common duration. code_inpainting is trimmed together
# with code so both unit sequences stay aligned with the F0 track.
audio_gt, audio_mask, code, code_inpainting, fo = match_length(
    [
        (y, 1),
        (audio_mask, 1),
        (code, hop),
        (code_inpainting, hop),
        (fo, int(h["sampling_rate"] * 0.005)),
    ],
    -1,
)
audio_gt = torch.FloatTensor(audio_gt)


def _generator_inputs(units):
    """Pack one unit sequence with the shared F0, speaker embedding and speaker id."""
    return {
        "code": torch.LongTensor(units).to(device).unsqueeze(0),
        "f0": torch.FloatTensor(fo).to(device).unsqueeze(0),
        # `emb` comes from a neural embedder (floating point); a LongTensor would
        # truncate its values to integers.
        "emb": torch.FloatTensor(emb).to(device).unsqueeze(0),
        "spkr": torch.LongTensor(spkr).to(device).unsqueeze(0),
    }


code = _generator_inputs(code)
code_inpainting = _generator_inputs(code_inpainting)

if h.get("f0_vq_params", None) or h.get("f0_quantizer", None):
    to_remove = audio_gt.shape[-1] % (16 * 80)
    assert to_remove % h["code_hop_size"] == 0

    if to_remove != 0:
        to_remove_code = to_remove // h["code_hop_size"]
        to_remove_f0 = to_remove // 80

        audio_gt = audio_gt[:-to_remove]
        audio_mask = audio_mask[:-to_remove]
        code["code"] = code["code"][..., :-to_remove_code]
        code["f0"] = code["f0"][..., :-to_remove_f0]
        code_inpainting["code"] = code_inpainting["code"][..., :-to_remove_code]
        code_inpainting["f0"] = code_inpainting["f0"][..., :-to_remove_f0]

audio_gen, _ = generate(h, generator, code)
audio_inp, _ = generate(h, generator, code_inpainting)

audio_gen = normalize(audio_gen.astype(np.float32))
audio_mask = normalize(audio_mask.astype(np.float32))
audio_inp = normalize(audio_inp.astype(np.float32))
audio_gt = normalize(audio_gt.squeeze().numpy().astype(np.float32))

# Output paths use their own names so `gt_file` (the input path) is not overwritten.
out_dir = i_da_params["output_dir"]
output_file_gt = os.path.join(out_dir, fname_out_name + "_gt.wav")
output_file_inpainting = os.path.join(out_dir, fname_out_name + "_inpainted.wav")
output_file_gen = os.path.join(out_dir, fname_out_name + "_gen.wav")
output_file_mask = os.path.join(out_dir, fname_out_name + "_masked.wav")
for out_path, wav_data in (
    (output_file_gt, audio_gt),
    (output_file_mask, audio_mask),
    (output_file_gen, audio_gen),
    (output_file_inpainting, audio_inp),
):
    write(out_path, h["sampling_rate"], wav_data)

## Evaluation of the inpainting

In [None]:
# Evaluate both methods with PESQ, STOI and the CER of a Whisper transcript.
# Every signal is first converted to float32 at sr_16 (16 kHz): the I_ea output is
# int16 at sr_22 (22.05 kHz), and PESQ/STOI compare signals at one rate and length.


def _to_eval_rate(signal, orig_sr):
    """Return ``signal`` as a float32 array resampled from ``orig_sr`` to ``sr_16``."""
    signal = np.asarray(signal, dtype=np.float32)
    if orig_sr != sr_16:
        signal = librosa.resample(signal, orig_sr=orig_sr, target_sr=sr_16)
    return signal.astype(np.float32)


def _trim_pair(reference, degraded):
    """Cut both signals to their common length so they are compared sample by sample."""
    length = min(len(reference), len(degraded))
    return reference[:length], degraded[:length]


ground_truth = audio_gt
# I_ea was vocoded from i_ea_params['wave_path'], so it is scored against that same
# recording (wave_16). I_da is scored against its own ground truth.
ref_i_ea, i_ea_inpainted = _trim_pair(
    _to_eval_rate(wave_16, sr_16),
    _to_eval_rate(audio.astype(np.float32) / MAX_WAV_VALUE, sr_22),
)
ref_i_da, i_da_inpainted = _trim_pair(
    _to_eval_rate(ground_truth, h["sampling_rate"]),
    _to_eval_rate(audio_inp, h["sampling_rate"]),
)

asr = whisper.load_model("small.en", device=device)
# Strip surrounding whitespace so it is not counted as character errors.
transcript_i_ea = asr.transcribe(i_ea_inpainted, fp16=False, language="English")["text"].strip()
transcript_i_da = asr.transcribe(i_da_inpainted, fp16=False, language="English")["text"].strip()

score_pesq_i_ea = pesq(sr_16, ref_i_ea, i_ea_inpainted, mode='nb')
score_pesq_i_da = pesq(sr_16, ref_i_da, i_da_inpainted, mode='nb')

score_stoi_i_ea = stoi(ref_i_ea, i_ea_inpainted, sr_16, extended=False)
score_stoi_i_da = stoi(ref_i_da, i_da_inpainted, sr_16, extended=False)

# Each method is scored on its own transcript.
score_cer_i_ea = cer(i_ea_params['wave_text'], transcript_i_ea)
score_cer_i_da = cer(i_ea_params['wave_text'], transcript_i_da)

print("Evalution on I_ea methods :")
print(f"PESQ = {score_pesq_i_ea} - STOI = {score_stoi_i_ea} - CER = {100*score_cer_i_ea}%")
print("Evalution on I_da methods :")
print(f"PESQ = {score_pesq_i_da} - STOI = {score_stoi_i_da} - CER = {100*score_cer_i_da}%")