In [None]:
from model_vc import Generator
import torch
import matplotlib.pyplot as plt

# from synthesis import build_model, wavegen
from hparams import hparams
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw, is_scalar_input
from tqdm import tqdm
import audio
from nnmnkwii import preprocessing as P
import numpy as np
from scipy.io import wavfile

In [None]:
# Build the generator and restore its trained weights.
# NOTE(review): the meaning of the constructor arguments (160, 0, 512, 16) is
# defined in model_vc.Generator and is not visible here — confirm there.
g = Generator(160, 0, 512, 16)
# This notebook only runs inference, so switch off training-time behaviour
# (e.g. dropout / batch-norm statistic updates, if the model has any).
g.eval()
# `g` is never moved off the CPU in this notebook, so map the checkpoint
# tensors to CPU; this also lets a checkpoint saved on a GPU load on a
# machine without one.
g.load_state_dict(torch.load('model_latest.pth', map_location='cpu'))

## Random Input

In [None]:
# Gaussian-noise tensor with a leading batch axis of size 1; the next cell
# feeds it to g.decoder.
x = torch.randn(1, 880, 320)
# Drop the batch axis so the remaining 2-D tensor can be shown as an image.
noise_image = x.squeeze(0)
plt.imshow(noise_image)

In [None]:
# Run only the decoder half of the generator on the random input.
mel_outputs = g.decoder(x)

# The postnet is applied on the tensor with its last two axes swapped, and its
# output is swapped back before being added to the decoder output as a
# residual correction.
postnet_residual = g.postnet(mel_outputs.transpose(2, 1)).transpose(2, 1)
mel_outputs_postnet = mel_outputs + postnet_residual

In [None]:
# Detach from autograd, strip leading singleton axes and convert to NumPy so
# matplotlib can render the result as an image.
spectrogram = mel_outputs_postnet.detach().squeeze(0).squeeze(0).numpy()
plt.imshow(spectrogram)

## Seeded Input

In [None]:
# Load pre-computed features from disk and wrap them in a tensor.
example_feats = np.load('example_vocals-feats.npy')
# Keep the first 720 rows and add a leading batch axis of size 1.
# The synthesis cell later reshapes the frame axis into 20-frame chunks, so
# the row count must stay a multiple of 20 (720 = 36 * 20).
x = torch.from_numpy(example_feats)[:720, :].unsqueeze(0)
x.shape

In [None]:
# Full generator forward pass on the seeded input. This rebinds `mel_outputs`
# and `mel_outputs_postnet` from the random-input section above;
# `mel_outputs_postnet` is what the WaveNet section below consumes.
mel_outputs, mel_outputs_postnet, encoder_output = g(x)

In [None]:
# Sanity check: show the shape of the post-net output before vocoding.
mel_outputs_postnet.shape

## WaveNet

In [None]:
def build_model():
    """Construct a ``WaveNet`` vocoder configured from the global ``hparams``.

    No checkpoint is loaded here; the caller is responsible for restoring
    weights into the returned model.

    Returns:
        WaveNet: a newly constructed model.

    Raises:
        RuntimeError: if ``hparams.input_type`` is mu-law quantized but
            ``hparams.out_channels`` differs from ``hparams.quantize_channels``.
    """
    if is_mulaw_quantize(hparams.input_type):
        if hparams.out_channels != hparams.quantize_channels:
            raise RuntimeError(
                "out_channels must equal quantize_channels if input_type is 'mulaw-quantize'")
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        # Not an error: the upsampling layers are simply never used.
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        print(s)

    # Work on a shallow copy so that adding the conditioning keys does not
    # mutate the shared hparams dictionary as a side effect.
    upsample_params = dict(hparams.upsample_params)
    upsample_params["cin_channels"] = hparams.cin_channels
    upsample_params["cin_pad"] = hparams.cin_pad
    model = WaveNet(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        cin_pad=hparams.cin_pad,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
    )
    return model


def batch_wavegen(model, c=None, g=None, fast=True, tqdm=tqdm):
    """Synthesise a batch of waveforms from conditioning features.

    Args:
        model: Vocoder exposing ``eval``, ``make_generation_fast_`` and
            ``incremental_forward``.
        c: Local conditioning tensor (required). Axis 0 is the batch axis and
            the last axis is the one the output length is derived from.
        g: Optional global conditioning tensor, forwarded to the model.
        fast: If true, call ``model.make_generation_fast_()`` before
            generating.
        tqdm: Progress-bar callable forwarded to ``incremental_forward``.

    Returns:
        numpy.ndarray: one row of generated samples per batch element, after
        inverse mu-law decoding, optional post-processing and gain scaling as
        configured in ``hparams``.

    Raises:
        AssertionError: if ``c`` is ``None``.
    """
    assert c is not None
    B = c.shape[0]
    model.eval()
    if fast:
        model.make_generation_fast_()

    # Move the inputs to wherever the model lives instead of relying on a
    # notebook-global `device` that may not be defined yet (or may disagree
    # with the model's actual device).
    try:
        target_device = next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: nothing to match, stay on CPU.
        target_device = torch.device("cpu")
    g = None if g is None else g.to(target_device)
    c = c.to(target_device)

    if hparams.upsample_conditional_features:
        # The model upsamples the features itself: one hop of samples per
        # frame, excluding the padding frames on both sides.
        length = (c.shape[-1] - hparams.cin_pad * 2) * audio.get_hop_size()
    else:
        # already duplicated
        length = c.shape[-1]

    with torch.no_grad():
        y_hat = model.incremental_forward(
            c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True,
            log_scale_min=hparams.log_scale_min)

    if is_mulaw_quantize(hparams.input_type):
        # needs to be float since mulaw_inv returns in range of [-1, 1]
        y_hat = y_hat.max(1)[1].view(B, -1).float().detach().cpu().numpy()
        for i in range(B):
            y_hat[i] = P.inv_mulaw_quantize(y_hat[i], hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        y_hat = y_hat.view(B, -1).detach().cpu().numpy()
        for i in range(B):
            y_hat[i] = P.inv_mulaw(y_hat[i], hparams.quantize_channels - 1)
    else:
        y_hat = y_hat.view(B, -1).detach().cpu().numpy()

    if hparams.postprocess is not None and hparams.postprocess not in ["", "none"]:
        # `hparams.postprocess` names a function in the `audio` module.
        postprocess = getattr(audio, hparams.postprocess)
        for i in range(B):
            y_hat[i] = postprocess(y_hat[i])

    if hparams.global_gain_scale > 0:
        for i in range(B):
            y_hat[i] /= hparams.global_gain_scale

    return y_hat


def to_int16(x):
    """Convert a waveform array to 16-bit integer PCM samples.

    Args:
        x: NumPy array that is either already ``int16`` (returned unchanged)
            or of a floating-point dtype with every value in ``[-1, 1]``.

    Returns:
        numpy.ndarray: ``int16`` array of the same shape. Floating-point
        input is scaled by 32767 and truncated toward zero by the cast.

    Raises:
        AssertionError: if ``x`` is neither ``int16`` nor floating point, or
            if any value lies outside ``[-1, 1]`` (NaN included).
    """
    if x.dtype == np.int16:
        return x
    assert np.issubdtype(x.dtype, np.floating), \
        "expected int16 or floating-point samples, got {}".format(x.dtype)
    if x.size == 0:
        # min()/max() are undefined on an empty array; nothing to scale.
        return x.astype(np.int16)
    assert x.min() >= -1 and x.max() <= 1.0, "samples must lie in [-1, 1]"
    # Scale in at least float32 so that half-precision input cannot round
    # 32767 up to 32768 and overflow int16; float32 input is left as is.
    work_dtype = np.result_type(x.dtype, np.float32)
    return (x.astype(work_dtype, copy=False) * 32767).astype(np.int16)

In [None]:
# Prefer the GPU, but fall back to CPU so the notebook still runs (slowly)
# on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model().to(device)
# map_location makes a checkpoint saved on a GPU loadable on whichever
# device was selected above.
checkpoint = torch.load("/wavenet_vocoder/checkpoints/checkpoint_latest_ema.pth",
                        map_location=device)
model.load_state_dict(checkpoint["state_dict"])

In [None]:
# Conditioning features for the vocoder: drop the batch axis and detach from
# autograd. The reshape below treats axis 0 as frames and axis 1 as feature
# channels.
c = mel_outputs_postnet.squeeze(0).detach()

# Split the frame axis into fixed-size chunks so that all chunks can be
# vocoded in parallel as a single batch.
CHUNK_FRAMES = 20
length = c.shape[0]
n_chunks = length // CHUNK_FRAMES
if n_chunks == 0:
    raise ValueError(
        "need at least {} frames to synthesise, got {}".format(CHUNK_FRAMES, length))
c = c.T
print(c.shape)
n_channels = c.shape[0]
# Trailing frames that do not fill a whole chunk are dropped; without this
# the reshape fails whenever the frame count is not a multiple of the chunk
# size.
c = c[:, :n_chunks * CHUNK_FRAMES]
c_chunks = c.reshape(n_channels, n_chunks, CHUNK_FRAMES)
# (channels, chunks, frames) -> (chunks, channels, frames): chunks become the
# batch axis.
c_chunks = c_chunks.permute(1, 0, 2)
c = c_chunks
print(c.shape)

# Generate
y_hats = batch_wavegen(model, c=c, g=None, fast=True, tqdm=tqdm)
# Lay the per-chunk waveforms end to end, in chunk order, as one long signal.
y_hats = y_hats.reshape(1, -1)

# Clip to the valid range before the int16 conversion, which rejects
# out-of-range samples.
gen = np.clip(y_hats[0], -1.0, 1.0)
wavfile.write('test.wav', hparams.sample_rate, to_int16(gen))