In [1]:
import matplotlib

import matplotlib.pyplot as plt
import IPython.display as ipd

import matplotlib.pyplot as plt
import IPython.display as ipd

%matplotlib inline

import os
import json
import argparse
import math
from collections import defaultdict, OrderedDict
import time
from tqdm import tqdm
import numpy as np
import IPython.display as ipd
from pathlib import Path

import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
import torch.distributions as D
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data.distributed import DistributedSampler
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data.sampler import WeightedRandomSampler

import librosa

import soundfile as sf

from torch.optim import Adam
from torch_optimizer import Lamb
import bitsandbytes as bnb

import tb_dllogger as logger

from loss_function import FastPitchLoss, FastPitchMASLoss, FastPitchTVCGMMLoss
from text.text_processing import TextProcessor
from data_utils import batch_to_gpu, TextMelAliCollate, TextMelAliLoader
import models
import commons
import utils

import random
import matplotlib.pyplot as plt
from utils import load_wav_to_torch

import sys
sys.path.append('./BigVGAN_/')

from BigVGAN_.env import AttrDict
from BigVGAN_.meldataset import MAX_WAV_VALUE
from BigVGAN_.models import BigVGAN as Generator

def get_melSpech(wav_path):
    """Load a wav file and compute its mel spectrogram and energy on ``device``.

    Uses the module-level globals ``hps`` (loaded in the FastPitch cell) and
    ``device``; both must exist before this function is called.

    Args:
        wav_path: Path of the wav file to analyse.

    Returns:
        ``(mel_padded, energy_padded)``: the mel spectrogram with a leading batch
        dimension of 1, shape ``(1, n_mel_channels, n_frames)``, and the energy
        with a leading batch dimension, shape ``(1, n_energy)``; both are float32
        tensors on ``device``.

    Raises:
        ValueError: If the file's sampling rate differs from the STFT's target
            rate (``tacotronstft.sampling_rate``).
    """
    audio, sampling_rate = load_wav_to_torch(wav_path)
    tacotronstft = commons.TacotronSTFT(
        hps.data.filter_length, hps.data.hop_length, hps.data.win_length,
        hps.data.n_mel_channels, hps.data.sampling_rate, hps.data.mel_fmin,
        hps.data.mel_fmax)

    if sampling_rate != tacotronstft.sampling_rate:
        # Three placeholders need three arguments: passing only two raised an
        # IndexError that hid the real problem. Report the actual target rate
        # instead of a hard-coded 22050.
        raise ValueError("{} {} SR doesn't match target {} SR".format(
            wav_path, sampling_rate, tacotronstft.sampling_rate))

    audio_norm = (audio / hps.data.max_wav_value).unsqueeze(0)
    melspec, energy = tacotronstft.mel_spectrogram(audio_norm)
    melspec = torch.squeeze(melspec, 0)
    energy = torch.squeeze(energy, 0)

    # Re-add the batch dimension. .float() gives the same float32 result as
    # copying into a pre-allocated FloatTensor buffer.
    mel_padded = melspec.unsqueeze(0).float().to(device)
    energy_padded = energy.unsqueeze(0).float().to(device)

    return mel_padded, energy_padded

# Use the GPU when available; fall back to CPU so the notebook can still run
# (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# --- Load the pretrained BigVGAN vocoder ---
VOCODER_DIR = "BigVGAN_/cp_model"

# Hyper-parameters live next to the checkpoint; AttrDict gives the attribute
# access (h.key) that the Generator constructor reads.
config_file = os.path.join(VOCODER_DIR, 'config.json')
with open(config_file, encoding="utf-8") as f:
    json_config = json.load(f)
h = AttrDict(json_config)

generator = Generator(h).to(device)

# Map to CPU first so the checkpoint loads regardless of the device it was saved on.
state_dict_g = torch.load(os.path.join(VOCODER_DIR, "g_05000000.zip"), map_location="cpu")
generator.load_state_dict(state_dict_g['generator'])
# load_state_dict has copied the weights into the model; drop the CPU copy so
# the checkpoint does not stay resident for the rest of the session.
del state_dict_g

# Inference only: eval mode, and fold weight norm into the conv weights.
generator.eval()
generator.remove_weight_norm()
print("Successfully loaded vocoder")

Removing weight norm...
Succcess Load Vocoder


In [3]:
# --- Load the FastPitch acoustic model from the latest checkpoint ---
model_dir = "logs/base_blank_emo_lang_pitch_emoencoder/"
checkpoint_path = utils.latest_checkpoint_path(model_dir, "fastpitch_*.pt")
hps = utils.get_hparams_from_dir(model_dir)

tp = TextProcessor(None, hps.data.text_cleaners)
n_symbols = len(tp.symbols)

model = models.FastPitch(
    n_mel_channels=hps.data.n_mel_channels,
    n_lang=hps.data.n_lang,
    n_symbols=n_symbols,
    padding_idx=tp.padding_idx,
    **hps.model,
).to(device)
# Route model(...) calls through the inference path so later cells can call
# model(...) directly.
model.forward = model.infer

checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Strip only the leading "module." that DistributedDataParallel adds. The old
# str.replace also rewrote any key that merely contains "module." (for example
# "emo_module.weight" -> "emo_weight"), which breaks a strict load.
ddp_prefix = 'module.'
sd = {
    (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v
    for k, v in checkpoint['state_dict'].items()
}
getattr(model, 'module', model).load_state_dict(sd)

_ = model.eval()

logs/base_blank_emo_lang_pitch_emoencoder/fastpitch_240.pt


In [4]:
# --- Synthesise every utterance in the test filelist ---
# Each row is "wav_path|language_id|transcription". Output goes to
# TestFile/<language_id>/<parent dir of wav>/<wav filename>.
root_database = "/run/media/fourier/Data2/Pras/Thesis/Database/dataset_name"

with open("filelists/runpod_test_filelist.txt", "r", encoding="utf-8") as txt_file:
    lines = txt_file.readlines()

for line in lines:
    if not line.strip():
        # A blank row (e.g. a trailing newline at EOF) has no fields to index.
        continue
    random_test = line.rstrip().split("|")
    wav_path = random_test[0]
    transcription = random_test[2]
    lid = torch.IntTensor([int(random_test[1])]).to(device)
    embeds_filename = wav_path.split("/")[-1].split(".")[0]
    # NOTE(review): index 8 assumes a fixed absolute-path depth for wav_path — confirm.
    database_name = wav_path.split("/")[8]
    mel_ref, energy_ref = get_melSpech(wav_path)

    spk_emb_src = torch.Tensor(np.load(f"{root_database.replace('dataset_name', database_name)}/spk_embeds/{embeds_filename}.npy")).reshape(1,-1).to(device)
    tst_stn = transcription

    text_encoded = torch.IntTensor(tp.encode_text(tst_stn, lid.item())).unsqueeze(0).to(device)

    # Inference only: keep BOTH the acoustic model and the vocoder under no_grad
    # so no autograd graph is built for the model forward pass.
    with torch.no_grad():
        y_gen_tst, *_ = model(text_encoded, mel_tgt=mel_ref, mel_len=torch.tensor([mel_ref.shape[2]]), speaker=spk_emb_src, language=lid)
        # The mel is already on `device`, so feed it straight to the vocoder
        # instead of round-tripping through a NumPy copy on the CPU.
        y_g_hat = generator(y_gen_tst.float())
        audio = y_g_hat.squeeze() * MAX_WAV_VALUE
        # Clip before the int16 cast: a +1.0 sample scales to 32768, which
        # would otherwise wrap around to -32768 and produce a click.
        audio = audio.clamp(-32768, 32767).cpu().numpy().astype('int16')

    wav_file = Path(wav_path)
    out_dir = Path("TestFile") / random_test[1] / wav_file.parent.name
    print(wav_file.name)
    out_dir.mkdir(parents=True, exist_ok=True)
    sf.write(str(out_dir / wav_file.name), audio, hps.data.sampling_rate)

vibid_fyat_058.wav
vibid_fyat_049.wav
vibid_fyat_018.wav
vibid_mdpa_058.wav
vibid_fyat_161.wav
vibid_mdpa_186.wav
vibid_mdpa_110.wav
vibid_mdpa_191.wav
vibid_mdpa_005.wav
vibid_fyat_076.wav
vibid_fyat_154.wav
vibid_fyat_003.wav
vibid_mdpa_070.wav
vibid_mdpa_608.wav
vibid_fyat_109.wav
vibid_mdpa_120.wav
vibid_mdpa_009.wav
vibid_mdpa_200.wav
vibid_fyat_172.wav
vibid_fyat_141.wav
vibid_fyat_291.wav
vibid_fyat_239.wav
vibid_fyat_389.wav
vibid_mdpa_397.wav
vibid_fyat_222.wav
vibid_fyat_258.wav
vibid_mdpa_224.wav
vibid_fyat_340.wav
vibid_mdpa_330.wav
vibid_mdpa_269.wav
vibid_fyat_285.wav
vibid_fyat_395.wav
vibid_mdpa_348.wav
vibid_mdpa_244.wav
vibid_mdpa_621.wav
vibid_mdpa_632.wav
vibid_mdpa_213.wav
vibid_fyat_231.wav
vibid_mdpa_219.wav
vibid_fyat_376.wav
vibid_mmht_1180.wav
vibid_fena_0849.wav
vibid_fena_1243.wav
vibid_fena_0835.wav
vibid_mmht_0140.wav
vibid_fena_0236.wav
vibid_fena_1500.wav
vibid_mmht_0298.wav
vibid_mmht_1418.wav
vibid_mmht_1474.wav
vibid_fena_0771.wav
vibid_fena_0794.wav


In [5]:
import shutil

# --- Copy the reference recordings next to the synthesised ones ---
# Mirrors the TestFile/ layout: TestOri/<language_id>/<parent dir>/<wav filename>.
with open("filelists/runpod_test_filelist.txt", "r", encoding="utf-8") as txt_file:
    lines = txt_file.readlines()

for line in lines:
    if not line.strip():
        # A blank row (e.g. a trailing newline at EOF) has no fields to index.
        continue
    random_test = line.rstrip().split("|")
    wav_path = random_test[0]

    wav_file = Path(wav_path)
    out_dir = Path("TestOri") / random_test[1] / wav_file.parent.name
    print(wav_file.name)
    out_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy(wav_path, out_dir / wav_file.name)

vibid_fyat_058.wav
vibid_fyat_049.wav
vibid_fyat_018.wav
vibid_mdpa_058.wav
vibid_fyat_161.wav
vibid_mdpa_186.wav
vibid_mdpa_110.wav
vibid_mdpa_191.wav
vibid_mdpa_005.wav
vibid_fyat_076.wav
vibid_fyat_154.wav
vibid_fyat_003.wav
vibid_mdpa_070.wav
vibid_mdpa_608.wav
vibid_fyat_109.wav
vibid_mdpa_120.wav
vibid_mdpa_009.wav
vibid_mdpa_200.wav
vibid_fyat_172.wav
vibid_fyat_141.wav
vibid_fyat_291.wav
vibid_fyat_239.wav
vibid_fyat_389.wav
vibid_mdpa_397.wav
vibid_fyat_222.wav
vibid_fyat_258.wav
vibid_mdpa_224.wav
vibid_fyat_340.wav
vibid_mdpa_330.wav
vibid_mdpa_269.wav
vibid_fyat_285.wav
vibid_fyat_395.wav
vibid_mdpa_348.wav
vibid_mdpa_244.wav
vibid_mdpa_621.wav
vibid_mdpa_632.wav
vibid_mdpa_213.wav
vibid_fyat_231.wav
vibid_mdpa_219.wav
vibid_fyat_376.wav
vibid_mmht_1180.wav
vibid_fena_0849.wav
vibid_fena_1243.wav
vibid_fena_0835.wav
vibid_mmht_0140.wav
vibid_fena_0236.wav
vibid_fena_1500.wav
vibid_mmht_0298.wav
vibid_mmht_1418.wav
vibid_mmht_1474.wav
vibid_fena_0771.wav
vibid_fena_0794.wav


In [4]:
# --- Synthesise one randomly chosen utterance for listening ---
with open("filelists/paper2_test_filelist.txt", "r", encoding="utf-8") as txt_file:
    # Drop blank rows so random sampling can never pick an empty line.
    lines = [line for line in txt_file if line.strip()]

root_database = "/run/media/fourier/Data2/Pras/Thesis/Database/dataset_name"
# Set to an int to make the sampled utterance reproducible; None = fresh sample each run.
SAMPLE_SEED = None
random_test = random.Random(SAMPLE_SEED).choice(lines).rstrip().split("|")
wav_path = random_test[0]
transcription = random_test[2]
lid = torch.IntTensor([int(random_test[1])]).to(device)
embeds_filename = wav_path.split("/")[-1].split(".")[0]
# NOTE(review): index 8 assumes a fixed absolute-path depth for wav_path — confirm.
database_name = wav_path.split("/")[8]
mel_ref, energy_ref = get_melSpech(wav_path)

spk_emb_src = torch.Tensor(np.load(f"{root_database.replace('dataset_name', database_name)}/spk_embeds/{embeds_filename}.npy")).reshape(1,-1).to(device)
tst_stn = transcription
print(transcription)
print(wav_path.replace(root_database.replace('dataset_name', database_name), ""))

text_encoded = torch.IntTensor(tp.encode_text(tst_stn, lid.item())).unsqueeze(0).to(device)
# Inference only: no autograd graph is needed for the forward pass.
with torch.no_grad():
    y_gen_tst, *_ = model(text_encoded, mel_tgt=mel_ref, mel_len=torch.tensor([mel_ref.shape[2]]), speaker=spk_emb_src, language=lid)

born once every one hundred years, dies in flames!
/wavs/0018/Happy/0018_000927.wav
torch.Size([1, 50, 768])


In [5]:
# Vocode the generated mel with BigVGAN, save it, and play it inline.
with torch.no_grad():
    # y_gen_tst is already on `device`; no NumPy round trip is needed.
    y_g_hat = generator(y_gen_tst.float())
    audio = y_g_hat.squeeze() * MAX_WAV_VALUE
    # Clip before the int16 cast: a +1.0 sample scales to 32768 and would wrap.
    audio = audio.clamp(-32768, 32767).cpu().numpy().astype('int16')

Path("sample_sound").mkdir(exist_ok=True)
sf.write("sample_sound/generated.wav", audio, hps.data.sampling_rate)
ipd.Audio(audio, rate=hps.data.sampling_rate)

In [6]:
# Load the reference recording at its native sampling rate (sr=None); librosa's
# default would resample it to 22050 Hz before saving the "original".
y, sr = librosa.load(wav_path, sr=None)
Path("sample_sound").mkdir(exist_ok=True)
sf.write("sample_sound/original.wav", y, sr)
ipd.Audio(y, rate=sr)

In [8]:
# Target speaker embedding for the "modified" synthesis below. It currently just
# aliases the source speaker; presumably a placeholder for swapping in another
# speaker's embedding — TODO confirm intent.
spk_emb_tgt = spk_emb_src

In [25]:
# Target language id for the "modified" synthesis below. It currently aliases the
# source language id; presumably meant to be overridden — TODO confirm.
lid_tgt = lid

In [26]:
# Reference mel passed as mel_tgt to the "modified" synthesis below. This is an
# alias of mel_ref (same tensor object), not a copy.
mel_tgt_mod = mel_ref

In [17]:
# Display the reference mel's shape: (1, n_mel_channels, n_frames), as built by get_melSpech.
mel_ref.shape

torch.Size([1, 80, 136])

In [18]:
# Display the modified reference mel's shape.
# NOTE(review): the saved output shows 240 frames while mel_ref showed 136, even
# though mel_tgt_mod is assigned from mel_ref above. That means the cells were run
# out of order (or mel_tgt_mod was reassigned in a since-deleted cell), so these
# outputs are stale — re-run top to bottom.
mel_tgt_mod.shape

torch.Size([1, 80, 240])

In [30]:
# Re-synthesise with the target speaker / language / reference mel, then vocode
# y_gen_mod (not y_gen_tst) so generated_mod.wav actually reflects the
# modified inputs.
with torch.no_grad():
    y_gen_mod, _, _, pitch_pred, energy_pred = model(text_encoded, mel_tgt=mel_tgt_mod, mel_len=torch.tensor([mel_tgt_mod.shape[2]]), speaker=spk_emb_tgt, language=lid_tgt)
    y_g_hat = generator(y_gen_mod.float())
    audio = y_g_hat.squeeze() * MAX_WAV_VALUE
    # Clip before the int16 cast: a +1.0 sample scales to 32768 and would wrap.
    audio = audio.clamp(-32768, 32767).cpu().numpy().astype('int16')

Path("sample_sound").mkdir(exist_ok=True)
sf.write("sample_sound/generated_mod.wav", audio, hps.data.sampling_rate)
ipd.Audio(audio, rate=hps.data.sampling_rate)

torch.Size([1, 27, 768])


In [11]:
# Plot the predicted pitch of the modified synthesis. Move to CPU before
# .numpy(): the model runs on `device`, and .numpy() fails on CUDA tensors.
# NOTE(review): the rank of pitch_pred[0] is not visible here; if it is 2-D,
# matplotlib draws one line per column — confirm this is the intended view.
fig, ax = plt.subplots()
ax.plot(pitch_pred[0].detach().cpu().numpy())
ax.set(title="Predicted pitch (modified synthesis)", xlabel="Index", ylabel="Predicted pitch")
plt.show()