In [None]:
import nemo.collections.asr as nemo_asr
from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer
import torch
from IPython.display import clear_output
import json
from nemo.collections.tts.helpers.helpers import plot_spectrogram_to_numpy
import IPython.display
import PIL.Image
import matplotlib.pyplot as plt
import os
from nemo.collections.tts.models import hifigan, hifigan_ssl, ssl_tts, fastpitch_ssl
from nemo.core.neural_types.elements import AudioSignal, MelSpectrogramType
import librosa
import pickle

In [None]:
# Restore the SSL disentangler from a local checkpoint, keep it on the CPU and
# switch it to evaluation mode.
ssl_model_ckpt_path = "/home/pneekhara/NeMo2022/SSLCheckPoints/SSLConformer22050_Epoch37.ckpt"
ssl_model = ssl_tts.SSLDisentangler.load_from_checkpoint(ssl_model_ckpt_path, strict=False).cpu()
ssl_model.eval()
clear_output()  # wipe whatever the checkpoint loading wrote to the cell output

In [None]:
# Restore the HiFi-GAN vocoder from a local checkpoint, keep it on the CPU and
# switch it to evaluation mode. `load_fastpitch_model` later attaches this object
# to the FastPitch model it loads.
hifi_path = "/home/pneekhara/NeMo2022/HiFiCKPTS/hifigan_libritts/HiFiLibriEpoch334.ckpt"
vocoder = hifigan.HifiGanModel.load_from_checkpoint(hifi_path)
vocoder = vocoder.cpu()
vocoder.eval()
clear_output()  # wipe whatever the checkpoint loading wrote to the cell output

In [None]:
# Notebook-wide featurizer that `load_wav` uses to read audio files (22050 Hz, no augmentation).
wav_featurizer = WaveformFeaturizer(
    sample_rate=22050,
    int_values=False,
    augmentor=None,
)

In [None]:
def load_wav(wav_path, pad_multiple=1024):
    """Read an audio file and zero-pad it to one sample short of a multiple of ``pad_multiple``.

    The file is read through the notebook-level ``wav_featurizer``. The waveform
    is right-padded with zeros up to the next multiple of ``pad_multiple`` (no
    padding if it already is one) and then the final sample is dropped, so the
    returned length is always ``k * pad_multiple - 1``.

    Args:
        wav_path: Path handed to ``wav_featurizer.process``.
        pad_multiple: Block size the length is aligned to before the last
            sample is removed.

    Returns:
        1-D tensor with the padded waveform minus its last sample.
    """
    samples = wav_featurizer.process(wav_path)
    remainder = samples.shape[0] % pad_multiple
    if remainder != 0:
        padding = torch.zeros(pad_multiple - remainder, dtype=torch.float)
        samples = torch.cat([samples, padding])
    # The last sample is always removed, whether or not padding was added.
    return samples[:-1]

def load_hifigan_model(ckpt_path):
    """Load an SSL-conditioned HiFi-GAN from ``ckpt_path`` for CPU evaluation.

    Args:
        ckpt_path: Checkpoint path passed to
            ``hifigan_ssl.HifiGanModel.load_from_checkpoint``.

    Returns:
        The loaded model, moved to the CPU and switched to evaluation mode.
        The cell output is cleared before returning.
    """
    model = hifigan_ssl.HifiGanModel.load_from_checkpoint(ckpt_path).cpu()
    model.eval()
    clear_output()
    return model

def load_fastpitch_model(ckpt_path):
    """Load an SSL FastPitch model from ``ckpt_path`` and attach the notebook's vocoder.

    Args:
        ckpt_path: Checkpoint path passed to
            ``fastpitch_ssl.FastPitchModel_SSL.load_from_checkpoint``.

    Returns:
        The loaded model, moved to the CPU, switched to evaluation mode, and
        with ``non_trainable_models`` set to a dict holding the notebook-level
        ``vocoder``. The cell output is cleared before returning.
    """
    model = fastpitch_ssl.FastPitchModel_SSL.load_from_checkpoint(ckpt_path).cpu()
    model.eval()
    # Relies on the module-level `vocoder` created in an earlier cell.
    model.non_trainable_models = {'vocoder': vocoder}
    clear_output()
    return model
    
def find_latest_ckpt_path(experiment_dir):
    """Return the most recently modified ``*last.ckpt`` file under ``experiment_dir``.

    The directory tree is walked recursively and every file whose name ends
    with ``"last.ckpt"`` is a candidate. When several candidates exist, the one
    with the newest modification time is returned, so the result no longer
    depends on directory traversal order.

    Args:
        experiment_dir: Root directory to search.

    Returns:
        Path of the newest matching checkpoint, or ``None`` if no file matches.
    """
    candidates = [
        os.path.join(root, filename)
        for root, _dirnames, filenames in os.walk(experiment_dir)
        for filename in filenames
        if filename.endswith("last.ckpt")
    ]
    if not candidates:
        return None
    # max() keeps the first candidate on ties, i.e. the earliest one in walk order.
    return max(candidates, key=os.path.getmtime)
            
def segment_wav(wav, segment_length=44100, hop_size=44100, min_segment_size=22050):
    """Cut a waveform into fixed-length segments, zero-padding short ones.

    A waveform shorter than ``segment_length`` yields a single padded segment.
    Otherwise segments start every ``hop_size`` samples for as long as the
    start index is below ``len(wav) - min_segment_size``; a final segment that
    runs past the end of the waveform is right-padded with zeros.

    Args:
        wav: 1-D tensor of samples.
        segment_length: Length of every returned segment.
        hop_size: Distance between consecutive segment starts.
        min_segment_size: A segment is only started while at least this many
            samples (plus one) remain after the start index.

    Returns:
        List of 1-D tensors, each of length ``segment_length``.
    """
    def _right_pad(chunk):
        # Only chunks shorter than segment_length are touched.
        missing = segment_length - len(chunk)
        if missing > 0:
            return torch.cat([chunk, torch.zeros(missing)])
        return chunk

    total = len(wav)
    if total < segment_length:
        return [_right_pad(wav)]

    chunks = []
    start = 0
    start_limit = total - min_segment_size
    while start < start_limit:
        chunks.append(_right_pad(wav[start:start + segment_length]))
        start += hop_size
    return chunks

def get_speaker_stats(ssl_model, wav_featurizer, audio_paths):
    """Compute a speaker embedding and pitch statistics from one speaker's recordings.

    Every file is loaded with ``load_wav`` and split with ``segment_wav``; all
    segments are pushed through ``ssl_model.forward_for_export`` as one batch.
    The per-segment speaker embeddings are averaged and the mean is scaled to
    unit L2 norm. Pitch statistics are taken over the non-zero frames of the
    ``get_pitch_contour`` output of all files.

    The forward pass runs under ``torch.no_grad()`` so that no autograd graph
    is built or kept alive by the returned embedding.

    Args:
        ssl_model: Object exposing ``forward_for_export(input_signal=...,
            input_signal_length=..., normalize_content=...)`` returning a
            5-tuple whose second item holds one speaker embedding per segment.
        wav_featurizer: Unused; kept for call compatibility. Loading goes
            through ``load_wav``, which uses the notebook-level featurizer.
        audio_paths: Iterable of audio file paths belonging to one speaker.

    Returns:
        Tuple ``(speaker_embedding, pitch_mean, pitch_std)``.
        ``speaker_embedding`` has a leading singleton dimension; the pitch
        values are Python floats. If no non-zero pitch frame is found, the
        fallback values 212.0 and 70.0 are returned.
    """
    all_segments = []
    all_wavs = []
    for audio_path in audio_paths:
        wav = load_wav(audio_path)
        all_segments += segment_wav(wav)
        all_wavs.append(wav)

    signal_batch = torch.stack(all_segments)
    # Every segment has the same (padded) length, so the length vector is constant.
    signal_length_batch = torch.full((len(all_segments),), signal_batch.shape[1], dtype=torch.long)

    # Inference only: do not track gradients through the SSL model.
    with torch.no_grad():
        _, speaker_embeddings, _, _, _ = ssl_model.forward_for_export(
            input_signal=signal_batch, input_signal_length=signal_length_batch, normalize_content=True
        )
        speaker_embedding = torch.mean(speaker_embeddings, dim=0)
        speaker_embedding = speaker_embedding / torch.norm(speaker_embedding, p=2)

    # Zero entries of the pitch contour are excluded from the statistics.
    non_zero_pc = []
    for wav in all_wavs:
        pitch_contour = get_pitch_contour(wav)
        non_zero_pc.append(pitch_contour[pitch_contour != 0])

    non_zero_pc = torch.cat(non_zero_pc)
    if len(non_zero_pc) > 0:
        pitch_mean = non_zero_pc.mean().item()
        pitch_std = non_zero_pc.std().item()
    else:
        print("could not find pitch contour")
        pitch_mean = 212.0
        pitch_std = 70.0

    return speaker_embedding[None], pitch_mean, pitch_std
        
        
def get_ssl_features_disentsngled(ssl_model, wav_featurizer, audio_path, emb_type="embedding_and_probs", use_unique_tokens=False):
    """Extract content features, the speaker embedding and frame durations for one file.

    The file is loaded with ``load_wav`` and passed through
    ``ssl_model.forward_for_export`` (under ``torch.no_grad()``). The content
    features are selected by ``emb_type`` and laid out as (feature, frame).
    Every frame gets a duration of 4.0. With ``use_unique_tokens=True``,
    consecutive frames that share the same arg-max token are merged: their
    feature vectors are averaged and the run gets a duration of 4 x run length.

    Args:
        ssl_model: Object exposing ``forward_for_export``; the indexing below
            expects content embeddings laid out as (batch, frame, dim) and
            log-probabilities as (frame, batch, class).
        wav_featurizer: Unused; kept for call compatibility. Loading goes
            through ``load_wav``.
        audio_path: Path of the audio file.
        emb_type: One of ``"probs"``, ``"embedding"``, ``"log_probs"`` or
            ``"embedding_and_probs"`` (embedding rows followed by probability rows).
        use_unique_tokens: Merge runs of identical predicted tokens.

    Returns:
        Tuple ``(content_features, speaker_embedding, duration)`` where
        ``content_features`` and ``duration`` carry a leading singleton
        dimension and ``speaker_embedding`` is returned as produced by the model.

    Raises:
        ValueError: If ``emb_type`` is not one of the supported values.
    """
    supported_emb_types = ("probs", "embedding", "log_probs", "embedding_and_probs")
    if emb_type not in supported_emb_types:
        # Fail before the model runs instead of hitting an unbound local later.
        raise ValueError(
            "Unknown emb_type {!r}; expected one of {}".format(emb_type, supported_emb_types)
        )

    frame_duration = 4  # duration assigned to every SSL frame

    wav = load_wav(audio_path)
    audio_signal = wav[None]
    audio_signal_length = torch.tensor([wav.shape[0]])
    # Inference only: do not track gradients through the SSL model.
    with torch.no_grad():
        _, speaker_embedding, content_embedding, content_log_probs, encoded_len = ssl_model.forward_for_export(
            input_signal=audio_signal, input_signal_length=audio_signal_length, normalize_content=True
        )

    # Keep only the valid frames of the single batch item, then move features to dim 0.
    n_frames = encoded_len[0].item()
    content_embedding = content_embedding[0, :n_frames].t()
    content_log_probs = content_log_probs[:n_frames, 0, :].t()
    content_probs = torch.exp(content_log_probs)

    if emb_type == "probs":
        final_content_embedding = content_probs
    elif emb_type == "embedding":
        final_content_embedding = content_embedding
    elif emb_type == "log_probs":
        final_content_embedding = content_log_probs
    else:  # "embedding_and_probs"
        final_content_embedding = torch.cat([content_embedding, content_probs], dim=0)

    duration = torch.ones(final_content_embedding.shape[1]) * float(frame_duration)
    if use_unique_tokens:
        token_predictions = torch.argmax(content_probs, dim=0)
        content_buffer = [final_content_embedding[:, 0]]
        unique_content_embeddings = []
        durations = []
        for _t in range(1, final_content_embedding.shape[1]):
            if token_predictions[_t] == token_predictions[_t - 1]:
                content_buffer.append(final_content_embedding[:, _t])
            else:
                # Token changed: close the current run and start a new one.
                durations.append(len(content_buffer) * frame_duration)
                unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))
                content_buffer = [final_content_embedding[:, _t]]

        # The buffer always holds the last (possibly only) run at this point.
        durations.append(len(content_buffer) * frame_duration)
        unique_content_embeddings.append(torch.mean(torch.stack(content_buffer), dim=0))

        final_content_embedding = torch.stack(unique_content_embeddings).t()
        duration = torch.tensor(durations).float()

    return final_content_embedding[None], speaker_embedding, duration[None]

def get_pitch_contour(wav, pitch_mean=None, pitch_std=None):
    """Estimate the pitch contour of a waveform with ``librosa.pyin``.

    Frames for which pyin reports no pitch are filled with 0.0. When both
    ``pitch_mean`` and ``pitch_std`` are given, non-zero frames are normalised
    as ``(f0 - pitch_mean) / pitch_std`` while zero frames stay exactly 0.0.

    The zero frames are identified with a mask taken before normalisation,
    rather than by comparing the shifted values against ``-pitch_mean``.

    Args:
        wav: 1-D CPU tensor of samples (converted with ``.numpy()``); analysed
            at 22050 Hz with a frame length of 1024 and a hop length of 256.
        pitch_mean: Mean used for normalisation, or ``None``.
        pitch_std: Standard deviation used for normalisation, or ``None``.

    Returns:
        1-D float32 tensor with one value per analysis frame.
    """
    f0, _, _ = librosa.pyin(
        wav.numpy(),
        fmin=librosa.note_to_hz('C2'),
        fmax=librosa.note_to_hz('C7'),
        frame_length=1024,
        hop_length=256,
        sr=22050,
        center=True,
        fill_na=0.0,
    )
    pitch_contour = torch.tensor(f0, dtype=torch.float32)
    if (pitch_mean is not None) and (pitch_std is not None):
        # Remember which frames carry a pitch before the values are shifted.
        voiced = pitch_contour != 0
        normalized = (pitch_contour - pitch_mean) / pitch_std
        pitch_contour = torch.where(voiced, normalized, torch.zeros_like(normalized))

    return pitch_contour
    
    
def vocode_ssl_features_disentangled(hifigan_model, content_embedding, speaker_embedding, pitch_contour=None, compute_pitch=True):
    """Forward the given features to ``hifigan_model.synthesize_wav`` and return its result.

    Args:
        hifigan_model: Object exposing ``synthesize_wav``.
        content_embedding: First positional argument of ``synthesize_wav``.
        speaker_embedding: Second positional argument of ``synthesize_wav``.
        pitch_contour: Third positional argument (``None`` by default).
        compute_pitch: Fourth positional argument (``True`` by default).

    Returns:
        Whatever ``synthesize_wav`` returns.
    """
    synthesis_args = (content_embedding, speaker_embedding, pitch_contour, compute_pitch)
    return hifigan_model.synthesize_wav(*synthesis_args)

# Built-in manifests keyed by speaker type. Each entry is
# (manifest path, max utterances kept per speaker, max speakers kept); None means "no limit".
_SPEAKER_MANIFESTS = {
    "seen": ("/home/pneekhara/NeMo2022/libri_val_formatted.json", None, None),
    "vctk": ("/home/pneekhara/Datasets/vctk/vctk_test_local.json", 10, None),
    "bengali": ("/home/pneekhara/Datasets/BengaliData/bengali_manifest.json", 10, None),
}
# Used for every speaker type that is not listed above (e.g. "unseen").
_FALLBACK_SPEAKER_MANIFEST = ("/home/pneekhara/Datasets/LibriDev/libri_dev_clean_local.json", 10, 10)


def _collect_speaker_paths(manifest_path, max_utterances=None, max_speakers=None):
    """Group a JSON-lines manifest by speaker and keep speakers with more than one utterance."""
    grouped = {}
    with open(manifest_path) as manifest:
        for line in manifest:
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)
            record = json.loads(line)
            grouped.setdefault(record['speaker'], []).append(record['audio_filepath'])

    selected = {}
    for speaker, paths in grouped.items():
        if len(paths) <= 1:
            continue
        if max_speakers is not None and len(selected) >= max_speakers:
            break
        # A slice bound of None keeps every utterance.
        selected[speaker] = paths[:max_utterances]
    return selected


def load_speaker_wise_audio_paths(speaker_type="seen", manifest_path=None):
    """Map each speaker to its audio file paths, read from a JSON-lines manifest.

    Each manifest line is a JSON object with the keys ``speaker`` and
    ``audio_filepath``. Speakers with a single utterance are dropped. Speakers
    are returned in order of first appearance in the manifest.

    Limits per speaker type:
        * ``"seen"``: all utterances of all speakers.
        * ``"vctk"`` and ``"bengali"``: first 10 utterances per speaker.
        * anything else (e.g. ``"unseen"``): first 10 utterances of the first
          10 speakers that have more than one utterance.

    Args:
        speaker_type: Selects the built-in manifest and the limits above.
        manifest_path: Optional manifest to read instead of the built-in one
            for ``speaker_type``; the limits of ``speaker_type`` still apply.

    Returns:
        Dict mapping speaker id to a list of audio file paths.
    """
    default_path, max_utterances, max_speakers = _SPEAKER_MANIFESTS.get(
        speaker_type, _FALLBACK_SPEAKER_MANIFEST
    )
    if manifest_path is None:
        manifest_path = default_path
    return _collect_speaker_paths(manifest_path, max_utterances, max_speakers)


def load_speaker_stats(speaker_wise_paths, speaker_type="seen", recache=False):
    """Return per-speaker embedding and pitch statistics, cached in a pickle file.

    The cache file is ``"<speaker_type>_speaker_stats.pkl"`` in the current
    working directory. If it exists and ``recache`` is false, its content is
    returned as-is, regardless of ``speaker_wise_paths``. Otherwise statistics
    are computed with ``get_speaker_stats`` (using the notebook-level
    ``ssl_model`` and ``wav_featurizer``), written to the cache and returned.

    For ``speaker_type == "seen"``, pitch mean/std found in
    ``libri_speaker_stats.json`` (keyed by ``str(speaker)``) replace the
    computed values.

    Args:
        speaker_wise_paths: Mapping from speaker id to a list of audio paths.
        speaker_type: Names the cache file and enables the "seen" override.
        recache: Recompute and overwrite the cache even if it exists.

    Returns:
        Dict mapping speaker id to a dict with the keys ``speaker_embedding``,
        ``pitch_mean`` and ``pitch_std``.
    """
    cache_file = "{}_speaker_stats.pkl".format(speaker_type)
    cache_usable = os.path.exists(cache_file) and not recache
    if cache_usable:
        # pickle can execute arbitrary code on load; only trusted cache files belong here.
        with open(cache_file, 'rb') as cache_in:
            return pickle.load(cache_in)

    precomputed_pitch = {}
    if speaker_type == "seen":
        with open("/home/pneekhara/NeMo2022/libri_speaker_stats.json", "r") as stats_file:
            precomputed_pitch = json.load(stats_file)

    computed = {}
    for speaker, paths in speaker_wise_paths.items():
        print("computing stats for {}".format(speaker))
        embedding, mean, std = get_speaker_stats(ssl_model, wav_featurizer, paths)
        speaker_key = str(speaker)
        if speaker_key in precomputed_pitch:
            # Precomputed values take precedence over the freshly computed ones.
            mean = precomputed_pitch[speaker_key]["pitch_mean"]
            std = precomputed_pitch[speaker_key]["pitch_std"]
        computed[speaker] = {
            'speaker_embedding': embedding,
            'pitch_mean': mean,
            'pitch_std': std,
        }

    with open(cache_file, 'wb') as cache_out:
        pickle.dump(computed, cache_out)

    return computed

In [None]:
# English speakers: "unseen" is not one of the named speaker types, so
# load_speaker_wise_audio_paths falls through to its final branch (LibriDev manifest).
speaker_wise_paths_english = load_speaker_wise_audio_paths("unseen")
# With recache=False an existing "unseen_speaker_stats.pkl" is reused instead of recomputing.
stats_unseen = load_speaker_stats(speaker_wise_paths_english, speaker_type="unseen", recache=False)

# Bengali speakers, cached in "bengali_speaker_stats.pkl".
# NOTE: `speaker_wise_paths` and `stats` are reassigned by the next cell (merged with the English data).
speaker_wise_paths = load_speaker_wise_audio_paths("bengali")
stats = load_speaker_stats(speaker_wise_paths, speaker_type="bengali", recache=False)

In [None]:
# Merge the Bengali and English dictionaries into single lookups. If a speaker id
# occurs in both, the English ("unseen") entry wins because it is unpacked last.
speaker_wise_paths = {**speaker_wise_paths, **speaker_wise_paths_english}
stats = {**stats, **stats_unseen}

In [None]:
def reconstruct_audio(fastpitch_model, speaker_stats, speaker_wise_paths, pitch_conditioning=True, compute_pitch=False, compute_duration=False, use_unique_tokens=False, n_audio=1, n_speakers=10):
    """Play original and re-synthesised audio for the first utterances of each speaker.

    For up to ``n_speakers`` speakers and up to ``n_audio`` files per speaker,
    the content features and pitch contour of the file are extracted and
    passed to ``fastpitch_model.synthesize_wav`` together with the speaker's
    own embedding. Both the original and the generated audio are displayed
    inline at 22050 Hz. Relies on the notebook-level ``ssl_model`` and
    ``wav_featurizer``.

    Args:
        fastpitch_model: Object exposing ``synthesize_wav``.
        speaker_stats: Mapping from speaker id to a dict with the keys
            ``speaker_embedding``, ``pitch_mean`` and ``pitch_std``.
        speaker_wise_paths: Mapping from speaker id to a list of audio paths.
        pitch_conditioning: Unused; kept for call compatibility.
        compute_pitch: Forwarded to ``synthesize_wav``.
        compute_duration: Forwarded to ``synthesize_wav``.
        use_unique_tokens: Forwarded to ``get_ssl_features_disentsngled``.
        n_audio: Number of files processed per speaker.
        n_speakers: Number of speakers processed.
    """
    for speaker_index, speaker in enumerate(speaker_wise_paths):
        if speaker_index >= n_speakers:
            break
        speaker_stat = speaker_stats[speaker]
        for wav_path in speaker_wise_paths[speaker][:n_audio]:
            # Load once and reuse for both the pitch contour and the playback widget.
            wav = load_wav(wav_path)
            # Inference only: keep the SSL forward pass out of autograd as well.
            with torch.no_grad():
                content_embedding, _, duration = get_ssl_features_disentsngled(
                    ssl_model,
                    wav_featurizer,
                    wav_path,
                    emb_type="embedding_and_probs",
                    use_unique_tokens=use_unique_tokens,
                )
            pitch_contour = get_pitch_contour(
                wav, pitch_mean=speaker_stat["pitch_mean"], pitch_std=speaker_stat["pitch_std"]
            )[None]

            print("Original Audio Speaker {}".format(speaker))
            IPython.display.display(IPython.display.Audio(wav, rate=22050))
            with torch.no_grad():
                print("Reconstructed Audio Speaker {}".format(speaker))
                wav_generated = fastpitch_model.synthesize_wav(
                    content_embedding,
                    speaker_stat['speaker_embedding'],
                    pitch_contour=pitch_contour,
                    compute_pitch=compute_pitch,
                    compute_duration=compute_duration,
                    durs_gt=duration,
                )
                IPython.display.display(IPython.display.Audio(wav_generated[0], rate=22050))
            print("**************************")
            
def swap_speakers(fastpitch_model, speaker_stats, speaker_wise_paths, spk1, spk2, pitch_conditioning=True, compute_pitch=False, compute_duration=False, use_unique_tokens=False, n_audio=1, n_speakers=10):
    """Synthesise each speaker's first utterance with the other speaker's embedding.

    The first file of ``spk1`` and of ``spk2`` is played back, then the content
    (features, pitch contour and durations) of each file is synthesised with
    the speaker embedding of the other speaker. All audio is displayed inline
    at 22050 Hz. Relies on the notebook-level ``ssl_model`` and
    ``wav_featurizer``. Model calls run under ``torch.no_grad()``, matching
    ``reconstruct_audio``.

    Args:
        fastpitch_model: Object exposing ``synthesize_wav``.
        speaker_stats: Mapping from speaker id to a dict with the keys
            ``speaker_embedding``, ``pitch_mean`` and ``pitch_std``.
        speaker_wise_paths: Mapping from speaker id to a list of audio paths.
        spk1: Id of the first speaker (must be a key of both mappings).
        spk2: Id of the second speaker (must be a key of both mappings).
        pitch_conditioning: Unused; kept for call compatibility.
        compute_pitch: Forwarded to ``synthesize_wav``.
        compute_duration: Forwarded to ``synthesize_wav``.
        use_unique_tokens: Forwarded to ``get_ssl_features_disentsngled``.
        n_audio: Unused; kept for call compatibility.
        n_speakers: Unused; kept for call compatibility.
    """
    def _prepare(spk):
        # Gather everything derived from the speaker's first utterance.
        wav_path = speaker_wise_paths[spk][0]
        stat = speaker_stats[spk]
        wav = load_wav(wav_path)  # loaded once, reused for pitch and playback
        with torch.no_grad():
            content_embedding, _, duration = get_ssl_features_disentsngled(
                ssl_model,
                wav_featurizer,
                wav_path,
                emb_type="embedding_and_probs",
                use_unique_tokens=use_unique_tokens,
            )
        pitch_contour = get_pitch_contour(
            wav, pitch_mean=stat["pitch_mean"], pitch_std=stat["pitch_std"]
        )[None]
        return wav, content_embedding, duration, pitch_contour, stat["speaker_embedding"]

    def _synthesize(content_embedding, speaker_embedding, pitch_contour, duration):
        # Inference only: no gradients are needed for playback.
        with torch.no_grad():
            return fastpitch_model.synthesize_wav(
                content_embedding,
                speaker_embedding,
                pitch_contour=pitch_contour,
                compute_pitch=compute_pitch,
                compute_duration=compute_duration,
                durs_gt=duration,
            )

    wav1, content_embedding1, duration1, pitch_contour1, speaker_embedding1 = _prepare(spk1)
    wav2, content_embedding2, duration2, pitch_contour2, speaker_embedding2 = _prepare(spk2)

    print("Real Audio Speaker {}".format(spk1))
    IPython.display.display(IPython.display.Audio(wav1, rate=22050))

    print("Real Audio Speaker {}".format(spk2))
    IPython.display.display(IPython.display.Audio(wav2, rate=22050))

    print("Content of {}, Voice of {}".format(spk1, spk2))
    wav_generated = _synthesize(content_embedding1, speaker_embedding2, pitch_contour1, duration1)
    IPython.display.display(IPython.display.Audio(wav_generated[0], rate=22050))

    print("Content of {}, Voice of {}".format(spk2, spk1))
    wav_generated = _synthesize(content_embedding2, speaker_embedding1, pitch_contour2, duration2)
    IPython.display.display(IPython.display.Audio(wav_generated[0], rate=22050))
    

In [None]:
# FastPitch checkpoint used by the cells below. The commented-out path is an
# alternative checkpoint kept for reference.
# ckpt_path = "/home/pneekhara/NeMo2022/tensorboards/FastPitch/ConstLR_PitchConditioningWithEncEpoch37/Epoch319.ckpt"
ckpt_path = "/home/pneekhara/NeMo2022/tensorboards/FastPitch/DurationPredictor/SegDurPerSampleEpoch604.ckpt"
# NOTE: "fasptich_model" is a misspelling of "fastpitch_model"; the later cells
# reference the variable under this exact name, so it must be renamed everywhere at once.
fasptich_model = load_fastpitch_model(ckpt_path)
# Quick check with default options (disabled):
# reconstruct_audio(fasptich_model, stats, speaker_wise_paths)

## Reconstructed using Predicted Pitch

In [None]:
# Reconstruct the first utterance of each speaker with the speaker's own embedding.
# NOTE(review): the section heading says "Predicted Pitch", but this call passes
# compute_pitch=False together with a pitch contour extracted from the source audio —
# confirm which of the two is intended.
reconstruct_audio(
    fasptich_model,
    stats,
    speaker_wise_paths,
    compute_pitch=False,
    compute_duration=False,
    use_unique_tokens=False,
    n_audio=1,
)

## Swapping speakers (content of one utterance, voice of the other)

In [None]:
# Show the available speaker ids; these are the valid spk1/spk2 values for swap_speakers.
speaker_wise_paths.keys()

In [None]:
# Swap voices between two speakers, letting the model compute pitch and duration
# and merging runs of identical content tokens.
swap_speakers(
    fasptich_model,
    stats,
    speaker_wise_paths,
    '00737',
    '1673',
    compute_pitch=True,
    compute_duration=True,
    use_unique_tokens=True,
)

In [None]:
# Swap voices between two speakers, computing pitch but keeping the source durations.
# NOTE(review): these ids are ints, whereas the ids in the previous call are strings —
# verify that both exist as keys of `speaker_wise_paths` and `stats`; swap_speakers
# indexes them directly, so a missing key raises KeyError.
swap_speakers(
    fasptich_model,
    stats,
    speaker_wise_paths,
    283,
    616,
    compute_pitch=True,
    compute_duration=False,
    use_unique_tokens=False,
)