In [1]:
import os, sys, glob, shutil
import argparse
import json
import yaml
import numpy as np
from pprint import pprint

import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from scipy.io import wavfile
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

In [2]:
%cd "D:\Schoolwork\TERM 3\WORK\visual_prosody"

D:\Schoolwork\TERM 3\WORK\visual_prosody


In [10]:
file_paths = glob.glob(r'.\raw_data\LibriTTS\**\*.wav', recursive=True)

In [11]:
len(file_paths)

5736

In [13]:
for file_path in file_paths:
    uid = os.path.basename(file_path)[:-4]
    source_path = file_path
    target_path = rf".\raw_data\LibriTTS\{uid}.wav"
    shutil.move(source_path, target_path)


In [14]:
val_uids = []
for file_path in glob.glob(r'.\raw_data\LibriTTS\*.wav'):
    val_uids.append(os.path.basename(file_path)[:-4])

In [16]:
len(val_uids)

5736

In [17]:
from utils.model import get_model, get_vocoder, get_param_num, vocoder_infer
from utils.tools import to_device, log, synth_one_sample, expand, plot_mel
from model import FastSpeech2Loss
from dataset import Dataset
# from utils.auto_tqdm import tqdm
from evaluate import evaluate

In [18]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [32]:
parser = argparse.ArgumentParser()
parser.add_argument("--restore_step", type=int, default=0)
parser.add_argument(
    "-p",
    "--preprocess_config",
    type=str,
    required=True,
    help="path to preprocess.yaml",
)
parser.add_argument(
    "-m", "--model_config", type=str, required=True, help="path to model.yaml"
)
parser.add_argument(
    "-t", "--train_config", type=str, required=True, help="path to train.yaml"
)

argString = '-p ./config/LibriTTS/0714lb_preprocess.yaml -m ./config/LibriTTS/0714lb_model.yaml -t ./config/LibriTTS/0714lb_train.yaml'
# args = parser.parse_args()
args = parser.parse_args(argString.split())

In [33]:
pprint(args)
# Read Config
preprocess_config = yaml.load(
    open(args.preprocess_config, "r"), Loader=yaml.FullLoader
)
model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
configs = (preprocess_config, model_config, train_config)
print("Prepare training ...")

preprocess_config, model_config, train_config = configs

Namespace(restore_step=0, preprocess_config='./config/LibriTTS/0714lb_preprocess.yaml', model_config='./config/LibriTTS/0714lb_model.yaml', train_config='./config/LibriTTS/0714lb_train.yaml')
Prepare training ...


In [34]:
ckpt_path = r'./output/LibriTTS/LibriTTS_800000.pth.tar'
ckpt = torch.load(ckpt_path)

In [35]:
# Prepare model
model, optimizer = get_model(args, configs, device, train=True)
model.load_state_dict(ckpt["model"], strict=False)
model.to(device)
model = nn.DataParallel(model)
num_param = get_param_num(model)
Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
print("Number of FastSpeech2 Parameters:", num_param)

=> Not using speaker embeddings.
False
None
Number of FastSpeech2 Parameters: 35391169


In [36]:
# Load vocoder
vocoder = get_vocoder(model_config, device)
step = args.restore_step + 1
model.eval()
print()
dataset = Dataset(
    "val.txt", 'val', preprocess_config, train_config, sort=False, drop_last=False
)

Removing weight norm...



In [37]:
len(dataset)

512

In [38]:
batch_size = train_config["optimizer"]["batch_size"]
batch_size = 1

In [39]:
loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )
# Get loss function
Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)

In [40]:
os.path.join(train_config['path']['result_path'], 'plots')

'./output/0714lb/result/LibriTTS\\plots'

In [41]:
output_plot_path = os.path.join(train_config['path']['result_path'], 'plot')
output_mel_syn_path = os.path.join(train_config['path']['result_path'], 'mel', 'syn')
output_mel_gt_path = os.path.join(train_config['path']['result_path'], 'mel', 'gt')
output_wav_syn_path = os.path.join(train_config['path']['result_path'], 'wav', 'synthesized')
output_wav_rec_path = os.path.join(train_config['path']['result_path'], 'wav', 'reconstructed')
os.makedirs(output_plot_path, exist_ok=True)
os.makedirs(output_mel_syn_path, exist_ok=True)
os.makedirs(output_mel_gt_path, exist_ok=True)
os.makedirs(output_wav_syn_path, exist_ok=True)
os.makedirs(output_wav_rec_path, exist_ok=True)

In [42]:
for batchs in tqdm(loader):
    for targets in batchs:
        targets = to_device(targets, device)
        with torch.no_grad():
            predictions = model(*(targets[2:]))
        basenames = targets[0]
        for i in range(len(predictions[0])):
            basename = basenames[i]
            src_len = predictions[8][i].item()
            mel_len = predictions[9][i].item()
            mel_prediction = predictions[1][i, :mel_len].detach().transpose(0, 1)
            mel_target = targets[6][i, :mel_len].detach().transpose(0, 1)

            torch.save(mel_prediction.cpu(), os.path.join(output_mel_syn_path, f"{basename}.pt"))
            torch.save(mel_target.cpu(), os.path.join(output_mel_gt_path, f"{basename}.pt"))
            
            
            duration = predictions[5][i, :src_len].detach().cpu().numpy()
            if preprocess_config["preprocessing"]["pitch"]["feature"] == "phoneme_level":
                pitch = predictions[2][i, :src_len].detach().cpu().numpy()
                pitch = expand(pitch, duration)
            else:
                pitch = predictions[2][i, :mel_len].detach().cpu().numpy()
            if preprocess_config["preprocessing"]["energy"]["feature"] == "phoneme_level":
                energy = predictions[3][i, :src_len].detach().cpu().numpy()
                energy = expand(energy, duration)
            else:
                energy = predictions[3][i, :mel_len].detach().cpu().numpy()

            with open(os.path.join(preprocess_config["path"]["preprocessed_path"], 
                                   "stats.json")) as f:
                stats = json.load(f)
                stats = stats["pitch"] + stats["energy"][:2]
                                       
            fig = plot_mel(
                [
                    (mel_prediction.cpu().numpy(), pitch, energy),
                    (mel_target.cpu().numpy(), pitch, energy),
                ],
                stats,
                ["Synthetized Spectrogram", "Ground-Truth Spectrogram"],
            )
            ### TODO: change to svg
            plt.savefig(os.path.join(output_plot_path, f"{basename}.png"))
            plt.close()

        # from .model import vocoder_infer

        mel_predictions = predictions[1].transpose(1, 2)
        mel_targets = targets[6].transpose(1, 2)
        
        lengths = predictions[9] * preprocess_config["preprocessing"]["stft"]["hop_length"]
        wav_predictions = vocoder_infer(
            mel_predictions, vocoder, model_config, preprocess_config, lengths=lengths
        )
        wav_targets = vocoder_infer(
        mel_targets, vocoder, model_config, preprocess_config, lengths=lengths
    )
    
        sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
        for wav, basename in zip(wav_predictions, basenames):
            wavfile.write(os.path.join(output_wav_syn_path, f"{basename}.wav"), sampling_rate, wav)
        for wav, basename in zip(wav_targets, basenames):
            wavfile.write(os.path.join(output_wav_rec_path, f"{basename}.wav"), sampling_rate, wav)

        
    #     break
    # break

  0%|          | 0/512 [00:00<?, ?it/s]

FileNotFoundError: [Errno 2] No such file or directory: './preprocessed_data/LibriTTS\\mel\\val\\5400-mel-5400_3587_000037_000001.npy'