In [None]:
import sys, os
# Make the parent directory importable (presumably where the `speech_distances`
# package imported below lives); append it only once so re-running is harmless.
_parent_dir = os.path.abspath('../')
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [None]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.sampler import SubsetRandomSampler
# import pytorch_lightning as pl
# from omegaconf import OmegaConf

import nemo
# import nemo.collections.tts as nemo_tts
from nemo.collections.tts.models.base import SpectrogramGenerator, Vocoder
from nemo.collections.tts.helpers.helpers import OperationMode

import IPython.display as ipd
import numpy as np
import datetime
import tensorflow as tf
import soundfile as sf
import torchaudio
from tqdm import tqdm

In [None]:
from speech_distances.datasets import load_dataset
from speech_distances.models import load_model
from speech_distances.dpam import CustomLoss

In [None]:
SEQ_LEN = 5_000  # length, in samples, of the random audio crops that pad_collate cuts

## Infer

In [None]:
# Pretrained NeMo Tacotron2 (text -> mel-spectrogram), used here for inference only.
spectrogram_generator = SpectrogramGenerator.from_pretrained("tts_en_tacotron2", override_config_path=None)
# `.eval()` sets `training = False` on every submodule; assigning the attribute
# directly only flipped the top-level flag and left submodules in training mode.
spectrogram_generator.eval()
spectrogram_generator.calculate_loss = False

In [None]:
# Sanity check: tokenize one sentence and synthesize its mel-spectrogram.
tokens = spectrogram_generator.parse('This is a sample text for a deep learning course project.')
with torch.no_grad():  # inference only -- no autograd graph needed
    sample_spectrogram = spectrogram_generator.generate_spectrogram(tokens=tokens)
sample_spectrogram.shape  # show the shape instead of dumping the whole tensor

In [None]:
# Four copies of the same sentence, concatenated along dim 0 (presumably the batch
# dimension of `tokens` -- TODO confirm), synthesized without autograd because this
# is inference only.
with torch.no_grad():
    spectrograms = spectrogram_generator.generate_spectrogram(tokens=torch.cat((tokens, tokens, tokens, tokens)))

In [None]:
# UniGlow vocoder (spectrogram -> waveform) on the CPU, switched to eval mode for inference.
vocoder = load_model('uniglow', device='cpu')
vocoder.eval();

In [None]:
# Checkpoint for the vocoder above, mapped onto the CPU.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
state = torch.load(f='models/uniglow.model', map_location="cpu")

In [None]:
# Replace the vocoder's weights with the checkpoint; the returned key report is displayed.
vocoder.load_state_dict(state_dict=state)

In [None]:
# Vocode the spectrogram batch into waveforms on the CPU. This is inference only, so
# autograd is disabled: no graph is kept through Tacotron2 + vocoder.
# NOTE(review): if convert_spectrogram_to_audio does not disable grad itself, the
# result would require grad and the `.numpy()` calls in later cells would raise.
with torch.no_grad():
    waveforms = vocoder.convert_spectrogram_to_audio(spec=spectrograms.cpu()).cpu()

In [None]:
# Listen to the first synthesized utterance (22050 Hz).
ipd.Audio(data=waveforms[0].numpy(), rate=22050)

In [None]:
# Save the first synthesized utterance. Create the output folder first so this cell
# also works in a fresh checkout where `wavs/` does not exist yet.
os.makedirs('wavs', exist_ok=True)
sf.write('wavs/after_finetune.wav', waveforms[0].numpy(), 22050)

## Dataset

In [None]:
def pad_collate(batch, seq_len=None):
    """Collate dataset items into a batch of fixed-length random audio crops.

    Each waveform is flattened (assumed mono, e.g. shape (1, T) -- TODO confirm for
    datasets other than LJSpeech). A window of ``seq_len`` samples is then cut at a
    uniformly random offset. Clips no longer than ``seq_len`` are kept whole and
    zero-padded on the right.

    Args:
        batch: sequence of 4-tuples from the dataset; only the first element
            (the waveform tensor) is used.
        seq_len: crop length in samples; defaults to the notebook-level ``SEQ_LEN``.

    Returns:
        (audio, lengths): ``audio`` has shape (len(batch), seq_len). ``lengths`` is
        an int32 tensor holding the number of real (non-padding) samples per row,
        i.e. ``seq_len`` for every clip longer than that.
    """
    if seq_len is None:
        seq_len = SEQ_LEN
    wavs, _sample_rates, _, _transcripts = zip(*batch)
    crops, lengths = [], []
    for wav in wavs:
        wav = wav.reshape(-1)
        n_samples = wav.shape[0]
        if n_samples > seq_len:
            # `high` is exclusive, so +1 allows a window ending on the last sample.
            # The old `randint(0, T - SEQ_LEN - 1)` skipped the last two offsets and
            # raised ValueError for clips of SEQ_LEN + 1 samples or fewer.
            # torch's RNG is seeded per DataLoader worker, so workers draw distinct crops.
            start = int(torch.randint(0, n_samples - seq_len + 1, (1,)))
            crops.append(wav[start:start + seq_len])
            lengths.append(seq_len)
        else:
            crops.append(F.pad(wav, (0, seq_len - n_samples)))
            lengths.append(n_samples)
    return torch.stack(crops), torch.tensor(lengths, dtype=torch.int)

In [None]:
# LJSpeech through the project's loader. pad_collate unpacks its items as 4-tuples
# whose first element is the waveform.
dataset = load_dataset('ljspeech')

In [None]:
# Reproducible, shuffled train/validation split over dataset indices.
# NOTE(review): this used to be 0.8, which put 80% of the clips in validation and
# trained on only 20% of them.
validation_split = 0.2   # fraction of clips held out for validation
shuffle_dataset = True   # replaces the old always-true `if 1:`
random_seed = 1337

dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    # A private RandomState produces the same permutation as
    # `np.random.seed(1337); np.random.shuffle(indices)` without reseeding numpy's
    # global RNG as a side effect.
    np.random.RandomState(random_seed).shuffle(indices)
train_indices, valid_indices = indices[split:], indices[:split]
assert len(train_indices) + len(valid_indices) == dataset_size
assert not set(train_indices) & set(valid_indices)
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

In [None]:
# Train and validation loaders share every setting except the sampler. pad_collate
# turns each batch into fixed-length random crops.
batch_size = 10
common_loader_kwargs = dict(batch_size=batch_size,
                            shuffle=False,  # ordering comes from the sampler
                            collate_fn=pad_collate,
                            num_workers=4)
train_dataloader = DataLoader(dataset, sampler=train_sampler, **common_loader_kwargs)
val_dataloader = DataLoader(dataset, sampler=valid_sampler, **common_loader_kwargs)

In [None]:
# Pull one validation batch to sanity-check the loader and collate function.
iterator = iter(val_dataloader)
[waveforms, wav_lens] = next(iterator)

In [None]:
# Audio batch shape next to the per-row lengths reported by pad_collate.
(waveforms.shape, wav_lens)

In [None]:
# Listen to the first random crop of the sampled validation batch (22050 Hz).
ipd.Audio(data=waveforms[0].numpy(), rate=22050)

## Train Vocoder

In [None]:
# WaveGlow 88M
# SqueezeNet 24M
# HiFiGan 84M
# MelGan 9M
# UniGlow 4M

In [None]:
# model: a fresh UniGlow vocoder for training. Put it on the GPU when one is
# available so training runs on CUDA, and fall back to the CPU otherwise.
vocoder = load_model('uniglow', device='cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
%%time

# optimizer
optimizer = optim.Adam(vocoder.parameters(), lr=1e-4)
# logger
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)
custom_loss = CustomLoss(seq_len=SEQ_LEN, stft_loss_coef=0.1, dpam_loss_coef=1.0)

max_epoch=30
step = 100
running_loss = 0.0
vocoder.train()
for epoch in range(max_epoch):
    
    if epoch+1 % 5 == 0:
        running_dpam_loss = 0.0
        running_loss = 0.0
        for i, batch in enumerate(val_dataloader, 0):
            vocoder.mode = OperationMode.validation
            z, logdet, predicted_audio, spec, spec_len = vocoder(audio=waveforms.cuda(), audio_len=wav_lens.cuda())
            loss = custom_loss(z=z, logdet=logdet, gt_audio=waveforms.cuda(), predicted_audio=predicted_audio, sigma=1.0)
            running_loss += loss.item()
            shape_diff = SEQ_LEN - predicted_audio.shape[1]
            predicted_audio = F.pad(predicted_audio, (0, shape_diff), mode='constant', value=0)
            dpam_loss = torch.mean(model.model_dist.forward(predicted_audio, waveforms.cuda()))
            running_dpam_loss += dpam_loss.item()
        print('[%d] val loss: %.3f' %
                  (epoch+1, running_loss / step))
        with summary_writer.as_default():
            tf.summary.scalar('validation loss', 
                              running_loss/len(val_dataloader), 
                              step=epoch+1)
            tf.summary.scalar('validation dpam loss', 
                              running_dpam_loss/len(val_dataloader), 
                              step=epoch+1)
        torch.save(vocoder.state_dict(), f'{log_dir}/uniglow_{epoch}.state')

    running_loss = 0.0
    for i, batch in enumerate(train_dataloader, 0):
        optimizer.zero_grad()
        waveforms, wav_lens = batch
        vocoder.mode = OperationMode.training
        z, logdet, predicted_audio = vocoder(audio=waveforms.cuda(), audio_len=wav_lens.cuda())
        loss = custom_loss(z=z, 
                           logdet=logdet, 
                           gt_audio=waveforms.cuda(), 
                           predicted_audio=predicted_audio, 
                           sigma=1.0)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % step == step-1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch+1, i+1, running_loss / step))
            with summary_writer.as_default():
                tf.summary.scalar('training loss', running_loss/step, step=i+1)
            running_loss = 0.0


torch.save(vocoder.state_dict(), f'{log_dir}/uniglow_final.state')
print('Finished Training')

In [None]:
# Inspect GPU memory and utilization via the NVIDIA driver's CLI.
!nvidia-smi