In [5]:
import sys, os
# Make the project root (parent directory) importable so `speech_distances`
# can be imported below; only add it once, even if this cell is re-run.
_parent_dir = os.path.abspath('../')
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [10]:
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.sampler import SubsetRandomSampler
# import pytorch_lightning as pl
# from omegaconf import OmegaConf

import nemo
# import nemo.collections.tts as nemo_tts
from nemo.collections.tts.models.base import SpectrogramGenerator, Vocoder
from nemo.collections.tts.helpers.helpers import OperationMode

import IPython.display as ipd
import numpy as np
import datetime
import tensorflow as tf

In [14]:
from speech_distances.datasets import load_dataset
from speech_distances.models import load_model
from speech_distances.dpam import CustomLoss

In [22]:
SEQ_LEN = 1_000  # length (in samples) of the random audio crop taken by pad_collate

## Infer

In [4]:
# spectrogram_generator = SpectrogramGenerator.from_pretrained("tts_en_tacotron2", override_config_path=None)
# spectrogram_generator.training = False
# spectrogram_generator.calculate_loss = False

In [5]:
# tokens = spectrogram_generator.parse('some string that I want to hear. his is not a test alert. Please, start working. Some more text to make this text bigger. Do me a favour, please.')
# spectrogram_generator.generate_spectrogram(tokens=tokens)

In [6]:
# spectrograms = spectrogram_generator.generate_spectrogram(tokens=torch.cat((tokens, tokens, tokens, tokens)))

In [7]:
# vocoder = Vocoder.from_pretrained("tts_waveglow_88m")
# vocoder = Vocoder.from_pretrained("tts_squeezewave")
# vocoder.eval()

In [8]:
# waveforms = vocoder.convert_spectrogram_to_audio(spec=spectrograms).cpu()

In [9]:
# ipd.Audio(waveforms[0].numpy(), rate=22050)

## Dataset

In [23]:
def pad_collate(batch, seq_len=None):
    """Collate dataset items into a batch of fixed-length random audio crops.

    Each item is unpacked as ``(wav, samplerate, _, transcript)``; only the
    waveform is used (sample rates and transcripts are discarded). From each
    waveform a random contiguous window of ``seq_len`` samples is taken; a
    waveform shorter than ``seq_len`` is kept whole and right-padded with zeros.

    Args:
        batch: sequence of 4-tuples as yielded by the dataset. ``wav`` is
            presumably a (1, T) tensor -- TODO confirm against
            speech_distances.datasets; it is squeezed before cropping.
        seq_len: crop length in samples. Defaults to the notebook-level
            ``SEQ_LEN``, looked up at call time so later edits to it apply.

    Returns:
        (audio, lengths): ``audio`` is a (len(batch), seq_len) tensor of crops;
        ``lengths`` is an int32 tensor holding the number of real (non-padded)
        samples in each row.
    """
    if seq_len is None:
        seq_len = SEQ_LEN
    wavs, _samplerates, _, _transcripts = zip(*batch)
    crops = []
    lengths = []
    for wav in wavs:
        audio = torch.squeeze(wav)
        n_samples = audio.shape[-1]
        if n_samples >= seq_len:
            # Every start in [0, n_samples - seq_len] is a valid window, hence the +1
            # on the exclusive upper bound. torch's RNG is used rather than numpy's
            # global one because DataLoader seeds torch per worker process, whereas
            # numpy state can be duplicated across workers (identical crops).
            start = int(torch.randint(0, n_samples - seq_len + 1, (1,)))
            crop = audio[start:start + seq_len]
        else:
            # Too short to crop: keep everything and zero-pad on the right.
            crop = F.pad(audio, (0, seq_len - n_samples))
        crops.append(crop)
        lengths.append(min(n_samples, seq_len))
    return torch.stack(crops), torch.tensor(lengths, dtype=torch.int)

In [24]:
# Load the LJSpeech dataset through the project's dataset helper.
DATASET_NAME = 'ljspeech'
dataset = load_dataset(DATASET_NAME)

In [25]:
# Split the dataset indices into train / validation subsets. A fixed seed keeps
# the split reproducible across runs.
dataset_size = len(dataset)
indices = list(range(dataset_size))
# Fraction of the dataset held out for validation: the first `split` shuffled
# indices go to validation, the remainder to training. (This was 0.8, which
# trained on only 20% of the data and validated on 80%.)
validation_split = 0.2
shuffle_dataset = True
split = int(np.floor(validation_split * dataset_size))
if shuffle_dataset:
    np.random.seed(1337)
    np.random.shuffle(indices)
train_indices, valid_indices = indices[split:], indices[:split]
assert len(train_indices) + len(valid_indices) == dataset_size
assert not set(train_indices) & set(valid_indices), "train/valid overlap"
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(valid_indices)

In [26]:
batch_size = 10
# Both loaders share every option except the index sampler. shuffle stays
# False: the order comes from the SubsetRandomSampler, and DataLoader rejects
# shuffle=True together with a sampler.
_loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=False,
    collate_fn=pad_collate,
    num_workers=4,
)
train_dataloader = DataLoader(dataset, sampler=train_sampler, **_loader_kwargs)
val_dataloader = DataLoader(dataset, sampler=valid_sampler, **_loader_kwargs)

In [27]:
# Pull one validation batch for a quick shape check and listening test.
iterator = iter(val_dataloader)
first_batch = next(iterator)
waveforms, wav_lens = first_batch

In [16]:
# Expect (batch_size, SEQ_LEN) crops and one length per clip.
(waveforms.shape, wav_lens)

(torch.Size([10, 5000]),
 tensor([5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000],
        dtype=torch.int32))

In [28]:
# Listen to the first crop of the batch.
# NOTE(review): the rate is hard-coded; pad_collate drops the dataset's
# samplerates, so confirm 22050 Hz matches the data.
first_clip = waveforms[0].numpy()
ipd.Audio(data=first_clip, rate=22050)

## Train Vocoder

In [14]:
# WaveGlow 88M
# SqueezeWave 24M
# HiFiGan 84M
# MelGan 9M
# UniGlow 4M

In [30]:
# Pretrained UniGlow vocoder, loaded through the project helper.
vocoder_device = 'cpu'  # set to 'cuda' to load the model on the GPU instead
vocoder = load_model('uniglow', device=vocoder_device)

[NeMo I 2021-05-28 02:56:46 cloud:66] Downloading from: https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_uniglow/versions/1.0.0rc1/files/tts_uniglow.nemo to /home/darayavaus/.cache/torch/NeMo/NeMo_1.0.0rc1/tts_uniglow/6d1602a9610471099e8fcc2d29e5ae9a/tts_uniglow.nemo
[NeMo I 2021-05-28 03:00:27 common:654] Instantiating model from pre-trained checkpoint


[NeMo W 2021-05-28 03:00:27 modelPT:133] Please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    dataset:
      _target_: nemo.collections.tts.data.datalayers.AudioDataset
      manifest_filepath: /home/tlv/data/LJSpeech-1.1/train.txt
      max_duration: null
      min_duration: 0.1
      n_segments: 16384
      trim: false
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 12
      num_workers: 4
    
[NeMo W 2021-05-28 03:00:27 modelPT:140] Please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    dataset:
      _target_: nemo.collections.tts.data.datalayers.AudioDataset
      manifest_filepath: /home/tlv/data/LJSpeech-1.1/val.txt
      max_duration: null
      min_duration: 0.1
      n_segments: 49152
      trim:

[NeMo I 2021-05-28 03:00:27 features:240] PADDING: 0
[NeMo I 2021-05-28 03:00:27 features:249] STFT using conv
[NeMo I 2021-05-28 03:00:28 modelPT:376] Model UniGlowModel was successfully restored from /home/darayavaus/.cache/torch/NeMo/NeMo_1.0.0rc1/tts_uniglow/6d1602a9610471099e8fcc2d29e5ae9a/tts_uniglow.nemo.


In [32]:
%%time

# optimizer
optimizer = optim.Adam(vocoder.parameters(), lr=1e-4)
# logger
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/' + current_time
summary_writer = tf.summary.create_file_writer(log_dir)
custom_loss = CustomLoss(seq_len=SEQ_LEN, stft_loss_coef=0.1, dpam_loss_coef=1.0)

max_epoch=30
step = 100
running_loss = 0.0
vocoder.train()
for epoch in range(max_epoch):
    
    if epoch+1 % 5 == 0:
        running_dpam_loss = 0.0
        running_loss = 0.0
        for i, batch in enumerate(val_dataloader, 0):
            vocoder.mode = OperationMode.validation
            z, logdet, predicted_audio, spec, spec_len = vocoder(audio=waveforms.cuda(), audio_len=wav_lens.cuda())
            loss = custom_loss(z=z, logdet=logdet, gt_audio=waveforms.cuda(), predicted_audio=predicted_audio, sigma=1.0)
            running_loss += loss.item()
            shape_diff = SEQ_LEN - predicted_audio.shape[1]
            predicted_audio = F.pad(predicted_audio, (0, shape_diff), mode='constant', value=0)
            dpam_loss = torch.mean(model.model_dist.forward(predicted_audio, waveforms.cuda()))
            running_dpam_loss += dpam_loss.item()
        print('[%d] val loss: %.3f' %
                  (epoch+1, running_loss / step))
        with summary_writer.as_default():
            tf.summary.scalar('validation loss', 
                              running_loss/len(val_dataloader), 
                              step=epoch+1)
            tf.summary.scalar('validation dpam loss', 
                              running_dpam_loss/len(val_dataloader), 
                              step=epoch+1)
        torch.save(vocoder.state_dict(), f'{log_dir}/uniglow_epoch{}.state')

    running_loss = 0.0
    for i, batch in enumerate(train_dataloader, 0):
        optimizer.zero_grad()
        waveforms, wav_lens = batch
        vocoder.mode = OperationMode.training
        z, logdet, predicted_audio = vocoder(audio=waveforms.cuda(), audio_len=wav_lens.cuda())
        loss = custom_loss(z=z, 
                           logdet=logdet, 
                           gt_audio=waveforms.cuda(), 
                           predicted_audio=predicted_audio, 
                           sigma=1.0)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % step == step-1:
            print('[%d, %5d] loss: %.3f' %
                  (epoch+1, i+1, running_loss / step))
            with summary_writer.as_default():
                tf.summary.scalar('training loss', running_loss/step, step=i+1)
            running_loss = 0.0


torch.save(vocoder.state_dict(), f'{log_dir}/uniglow_epoch{}.state')
print('Finished Training')

SyntaxError: f-string: empty expression not allowed (<unknown>, line 32)

In [21]:
# Show GPU utilisation and memory usage (shell escape to the nvidia-smi tool).
!nvidia-smi

Thu May 27 13:20:05 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:02:00.0 Off |                  N/A |
| 27%   35C    P2    59W / 260W |  10998MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces