In [None]:
# %matplotlib widget

from model_vc import Generator, GeneratorV2
import torch
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import pandas as pd
import seaborn as sns
from tqdm import tqdm

# from synthesis import build_model, wavegen
from hparams import hparams
from wavenet_vocoder import WaveNet
from wavenet_vocoder.util import is_mulaw_quantize, is_mulaw, is_raw, is_scalar_input
from tqdm import tqdm
import audio
from nnmnkwii import preprocessing as P
import numpy as np
from scipy.io import wavfile

from data_loader import SpecsCombined

## Accompaniment Generator

In [None]:
# Build the accompaniment network and restore its trained weights for inference.
# - map_location='cpu': the following cells feed CPU tensors to this model, and
#   a checkpoint saved from a GPU would otherwise need CUDA just to deserialize.
# - .eval(): switches any dropout / batch-norm layers to inference behaviour.
#   NOTE(review): Generator's layers live in model_vc.py and are not visible
#   here; eval() is harmless if it has no such layers.
g_accom = Generator(160, 0, 512, 20).eval()
g_accom.load_state_dict(torch.load('model_latest_accom.pth', map_location='cpu'))

## Dataset

In [None]:
# Dataset of paired spectrograms; indexing it yields an (accompaniment, vocals)
# pair, as unpacked in the data-loading cell.
# NOTE(review): the path relies on SpecsCombined expanding "~" — confirm.
# len_crop=860 presumably fixes the number of frames per item (860 is also
# hard-coded in later reshapes) — confirm against data_loader.
dataset = SpecsCombined('~/Data/segments_combined', len_crop=860)

## Data Loading

In [None]:
# Item 500 supplies the working (accompaniment, vocals) pair; each tensor gets
# a leading batch axis of size 1 so it can be fed to the networks.
accom_spec, vocals_spec = (spec.unsqueeze(0) for spec in dataset[500])
print(accom_spec.shape, vocals_spec.shape)

# A second vocal example (item 2), batched the same way; its accompaniment is
# not needed.
vocals_spec_2 = dataset[2][1].unsqueeze(0)

## Accompaniment Latent Vector Generation

In [None]:
# Encode the accompaniment spectrogram with the accompaniment network.
# This notebook never calls backward(), so run the pass under no_grad to avoid
# recording (and holding in memory) an autograd graph for it.
with torch.no_grad():
    accom_vec = g_accom(accom_spec, return_encoder_output=True)
accom_vec.shape

## Random Input

In [None]:
# Stand-in input: standard-normal noise of shape (1, 860, 80), shown as an image.
# NOTE(review): x is unseeded, so the picture changes on every run.
x = torch.randn(1, 860, 80)
# x = torch.sin(x)
plt.imshow(x[0])
# x_noise = torch.FloatTensor(1, 860, 320).uniform_(-0.06, 0.06)
# plt.imshow(x_noise.squeeze(0))

## Real Input

In [None]:
# Load saved vocal features from disk, keep at most the first 860 rows, and add
# a leading batch axis.
feats = np.load('example_vocals-feats.npy')
x = torch.from_numpy(feats)[:860, :].unsqueeze(0)
x.shape

## Vocals Network

In [None]:
# Build the vocals network and restore its trained weights for inference.
# - map_location='cpu': the following cells run this model on CPU tensors, and
#   a checkpoint saved from a GPU would otherwise need CUDA just to deserialize.
# - .eval(): switches any dropout / batch-norm layers to inference behaviour.
#   NOTE(review): GeneratorV2's layers live in model_vc.py and are not visible
#   here; eval() is harmless if it has no such layers.
g_vocals = GeneratorV2(160, 0, 512, 20, 860, 128).eval()
g_vocals.load_state_dict(torch.load('model_lowest_val_vae.pth', map_location='cpu'))

## Random Latent Vector Generation

In [None]:
# Conditioning half: the flattened accompaniment encoding pushed through the
# vocals network's cond_proj layer.
accom_flat = accom_vec.flatten(start_dim=1)
condition_vec = g_vocals.cond_proj(accom_flat)

# Vocal half: 128 random values in place of a real vocal embedding.
# NOTE(review): torch.rand draws uniformly from [0, 1); confirm that is intended
# rather than a normal draw (torch.randn).
random_vocal_vec = torch.rand(1, 128)
latent_vec = torch.cat([random_vocal_vec, condition_vec], dim=-1)

## Seeded Latent Vector Generation

In [None]:
def _vocal_embedding(spec):
    """Encode a vocal spectrogram with g_vocals and project the flattened encoding."""
    encoded = g_vocals(spec, return_encoder_output=True)
    return g_vocals.vocals_proj(encoded.flatten(start_dim=1))


# Conditioning half comes from the accompaniment encoding.
condition_vec = g_vocals.cond_proj(accom_vec.flatten(start_dim=1))

# Vocal half comes from real vocal examples ("seeds").
vocal_vec_1 = _vocal_embedding(vocals_spec)
vocal_vec_2 = _vocal_embedding(vocals_spec_2)  # only referenced by the blends below

# Alternative blends of the two seeds (disabled):
# # Take the average of the two
# vocal_vec = (vocal_vec_1 + vocal_vec_2) / 2

# vocal_vec = (vocal_vec_1 * 0.5) + (vocal_vec_2 * 0.5)

# vocal_vec = vocal_vec_1 + (vocal_vec_2 * 0.5)

latent_vec = torch.cat([vocal_vec_1, condition_vec], dim=-1)

## Encoding

In [None]:
# Reparameterised sampling: the latent vector is mapped to a mean and a
# log-variance, and z is drawn from Normal(mu, std) with rsample().
mu, logvar = g_vocals.mu_fc(latent_vec), g_vocals.logvar_fc(latent_vec)
std = (logvar / 2).exp()
q = torch.distributions.Normal(mu, std)
z = q.rsample()

# Project the sample and reshape it to (1, 860, 320); this is what the
# synthesis cell passes to g_vocals.decoder.
encoder_outputs = g_vocals.latent_proj(z).reshape(1, 860, 320)

# Reference picture: the vocal spectrogram loaded earlier.
plt.imshow(vocals_spec[0])

## Synthesis

In [None]:
# Decode, then add the post-net's output as a residual correction. The post-net
# is called on the tensor with axes 1 and 2 swapped, and its result is swapped
# back before the addition.
mel_outputs = g_vocals.decoder(encoder_outputs)
postnet_residual = g_vocals.postnet(mel_outputs.transpose(2, 1)).transpose(2, 1)
mel_outputs_postnet = mel_outputs + postnet_residual

In [None]:
# Display the post-net output as an image. detach() drops the autograd graph so
# numpy() is allowed. Each squeeze(0) removes axis 0 only when it has size 1,
# so the second call is a no-op unless there are two leading size-1 axes.
plt.imshow(mel_outputs_postnet.squeeze(0).squeeze(0).detach().numpy())

## WaveNet

In [None]:
def build_model():
    """Construct a WaveNet vocoder configured from the global ``hparams``.

    Returns:
        A freshly initialised ``WaveNet`` instance (no checkpoint is loaded).

    Raises:
        RuntimeError: if ``input_type`` is mulaw-quantize but ``out_channels``
            differs from ``quantize_channels``.
    """
    if is_mulaw_quantize(hparams.input_type):
        if hparams.out_channels != hparams.quantize_channels:
            raise RuntimeError(
                "out_channels must equal to quantize_channels if input_type is 'mulaw-quantize'")
    if hparams.upsample_conditional_features and hparams.cin_channels < 0:
        # Configuration smell only: warn and carry on.
        s = "Upsample conv layers were specified while local conditioning disabled. "
        s += "Notice that upsample conv layers will never be used."
        print(s)

    # Work on a copy so that building a model does not write extra keys into
    # the shared hparams.upsample_params dict.
    upsample_params = dict(hparams.upsample_params)
    upsample_params["cin_channels"] = hparams.cin_channels
    upsample_params["cin_pad"] = hparams.cin_pad
    model = WaveNet(
        out_channels=hparams.out_channels,
        layers=hparams.layers,
        stacks=hparams.stacks,
        residual_channels=hparams.residual_channels,
        gate_channels=hparams.gate_channels,
        skip_out_channels=hparams.skip_out_channels,
        cin_channels=hparams.cin_channels,
        gin_channels=hparams.gin_channels,
        n_speakers=hparams.n_speakers,
        dropout=hparams.dropout,
        kernel_size=hparams.kernel_size,
        cin_pad=hparams.cin_pad,
        upsample_conditional_features=hparams.upsample_conditional_features,
        upsample_params=upsample_params,
        scalar_input=is_scalar_input(hparams.input_type),
        output_distribution=hparams.output_distribution,
    )
    return model


def batch_wavegen(model, c=None, g=None, fast=True, tqdm=tqdm):
    """Generate one waveform per batch item from local conditioning features.

    Args:
        model: Module exposing ``make_generation_fast_`` and
            ``incremental_forward`` (the WaveNet built by ``build_model``).
        c: Local conditioning tensor (required). Axis 0 is the batch; the last
            axis is the time axis used to compute the output length.
        g: Optional global conditioning tensor.
        fast: If true, call ``model.make_generation_fast_()`` before generating.
        tqdm: Progress-bar callable forwarded to ``incremental_forward``.

    Returns:
        numpy array of shape (B, T) holding the generated samples, after the
        inverse mu-law / post-processing / gain steps selected by ``hparams``.
    """
    assert c is not None
    B = c.shape[0]
    model.eval()
    if fast:
        model.make_generation_fast_()

    # Move the conditioning to wherever the model lives, instead of relying on
    # a notebook-global `device` that other cells define and rebind.
    device = next(model.parameters()).device
    g = None if g is None else g.to(device)
    c = c.to(device)

    if hparams.upsample_conditional_features:
        # The model upsamples c itself: one frame (minus padding) -> hop_size samples.
        length = (c.shape[-1] - hparams.cin_pad * 2) * audio.get_hop_size()
    else:
        # already duplicated
        length = c.shape[-1]

    with torch.no_grad():
        y_hat = model.incremental_forward(
            c=c, g=g, T=length, tqdm=tqdm, softmax=True, quantize=True,
            log_scale_min=hparams.log_scale_min)

    if is_mulaw_quantize(hparams.input_type):
        # Take the arg-max class per step; needs to be float since the inverse
        # mu-law returns values in the range [-1, 1].
        y_hat = y_hat.max(1)[1].view(B, -1).float().cpu().detach().numpy()
        for i in range(B):
            y_hat[i] = P.inv_mulaw_quantize(y_hat[i], hparams.quantize_channels - 1)
    elif is_mulaw(hparams.input_type):
        y_hat = y_hat.view(B, -1).cpu().detach().numpy()
        for i in range(B):
            y_hat[i] = P.inv_mulaw(y_hat[i], hparams.quantize_channels - 1)
    else:
        y_hat = y_hat.view(B, -1).cpu().detach().numpy()

    # Optional post-processing function looked up by name on the `audio` module.
    if hparams.postprocess is not None and hparams.postprocess not in ["", "none"]:
        for i in range(B):
            y_hat[i] = getattr(audio, hparams.postprocess)(y_hat[i])

    if hparams.global_gain_scale > 0:
        for i in range(B):
            y_hat[i] /= hparams.global_gain_scale

    return y_hat


def to_int16(x):
    """Convert a waveform array to 16-bit integer samples.

    Args:
        x: numpy array. int16 input is returned unchanged (same object);
            floating-point input must lie within [-1, 1].

    Returns:
        int16 array; float samples are scaled by 32767 and truncated.

    Raises:
        AssertionError: if ``x`` is neither int16 nor floating point, or if
            any float sample lies outside [-1, 1] (NaN included).
    """
    if x.dtype == np.int16:
        return x
    # Any float width is acceptable (previously only float32 was).
    assert np.issubdtype(x.dtype, np.floating)
    if x.size == 0:
        # min()/max() raise on empty arrays; an empty signal converts trivially.
        return x.astype(np.int16)
    assert x.min() >= -1 and x.max() <= 1.0
    return (x * 32767).astype(np.int16)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model().to(device)
# map_location lets a checkpoint saved on a GPU load on whichever device was
# selected above. The checkpoint is a dict with the weights under "state_dict".
checkpoint = torch.load("/wavenet_vocoder/checkpoints/checkpoint_latest_ema.pth", map_location=device)
model.load_state_dict(checkpoint["state_dict"])

In [None]:
# outputs = (mel_outputs_postnet/2) + (accom_spec/2)
# c = outputs.squeeze(0).detach()

# Frames per chunk. (This value used to be called `num_chunks`, but it is the
# chunk length; the number of chunks is n_frames // chunk_frames.)
chunk_frames = 20

# Original vocals
# c = vocals_spec.squeeze(0).detach()
# Vocal output
c = mel_outputs_postnet.squeeze(0).detach()
# Accom output
# c = accom_spec.squeeze(0).detach()

# Split the (frames, mel bins) spectrogram into consecutive chunks so that
# WaveNet can generate them in parallel as one batch:
# (frames, mels) -> (n_chunks, mels, chunk_frames).
n_frames, n_mels = c.shape
n_chunks = n_frames // chunk_frames
assert n_chunks > 0, "spectrogram is shorter than one chunk"
# Drop a ragged tail (fewer than chunk_frames frames) so the reshape is valid.
c = c[:n_chunks * chunk_frames]
c = c.T.reshape(n_mels, n_chunks, chunk_frames).permute(1, 0, 2)

# # Resize c to 1, 80, 866
# print(c.shape)
# c = TF.resize(c, (80, 866))
# c = c[:, :, :50]
# print(c.shape)

# Generate
y_hats = batch_wavegen(model, c=c, g=None, fast=True, tqdm=tqdm)

# Rows are chunks in time order, so a row-major flatten concatenates them back
# into a single waveform.
gen = np.clip(y_hats.reshape(-1), -1.0, 1.0)
wavfile.write('test.wav', hparams.sample_rate, to_int16(gen))

In [None]:
# Save the vocals models
# torch.save(g_vocals.state_dict(), './model_v3_7k.pth')

## t-SNE Visualization of Song Distribution

In [None]:
# Compute one latent vector per dataset item: [vocal embedding | accompaniment condition].
# Side effects: both networks are moved to the GPU and stay there, and the
# loop overwrites the notebook-level accom_spec / vocals_spec / accom_vec /
# condition_vec / latent_vec used by earlier cells.
device = 'cuda'
g_accom.to(device)
g_vocals.to(device)

vecs = []
# Inference only: no_grad stops autograd from recording a graph for every
# item, which saves time and GPU memory over the whole dataset.
with torch.no_grad():
    for i in tqdm(range(len(dataset))):
        accom_spec, vocals_spec = dataset[i]
        accom_spec = accom_spec.unsqueeze(0).to(device)
        vocals_spec = vocals_spec.unsqueeze(0).to(device)

        accom_vec = g_accom(accom_spec, return_encoder_output=True)
        condition_vec = g_vocals.cond_proj(accom_vec.flatten(start_dim=1))
        vocal_vec = g_vocals.vocals_proj(g_vocals(vocals_spec, return_encoder_output=True).flatten(start_dim=1))

        latent_vec = torch.cat((vocal_vec, condition_vec), dim=-1)

        vecs.append(latent_vec.cpu().numpy())

In [None]:
# Derive an integer song label for every file in the dataset. The song name is
# the part of the file's base name before the first underscore; ids are handed
# out in order of first appearance.
file_list = dataset.files
name_to_label = {}
labels = []
for path in file_list:
    song = path.rsplit('/', 1)[-1].split('_', 1)[0]
    # setdefault only inserts when the song is new; len() is evaluated before
    # the insert, so a new song receives the next free id.
    labels.append(name_to_label.setdefault(song, len(name_to_label)))

In [None]:
# Combine the per-item vectors into one 2-D array (one row per dataset item)
vec_stack = np.vstack(vecs)
# Get a list of numbers from 0 - 178 as a numpy array
# num_list = np.arange(0, 178)

# Sanity counts: labels and vectors should agree, songs is the distinct count.
for caption, count in (
    ("Number of songs:", len(np.unique(labels))),
    ("Number of labels:", len(labels)),
    ("Number of vectors:", len(vecs)),
):
    print(caption, count)

In [None]:
# Keep only the items whose song label lies in [offset, num_songs), i.e.
# labels 70..79 with the values below (ten songs).
num_songs = 80
offset = 70
selected = [idx for idx, label in enumerate(labels) if offset <= label < num_songs]

filtered_vec_stack = np.vstack([vec_stack[idx] for idx in selected])
filtered_labels = np.array([labels[idx] for idx in selected])

In [None]:
# Project the filtered latent vectors to 2-D with t-SNE.
n_components = 2
# random_state pins the random number generator so the embedding (and the plot
# built from it) is the same on every run.
tsne = TSNE(n_components, learning_rate='auto', init='pca', random_state=0)
tsne_result = tsne.fit_transform(filtered_vec_stack)
tsne_result.shape

In [None]:
# Scatter plot of the 2-D embedding, one colour per song label.
tsne_result_df = pd.DataFrame({
    'tsne_1': tsne_result[:, 0],
    'tsne_2': tsne_result[:, 1],
    'label': filtered_labels,
})
fig, ax = plt.subplots(1)
sns.scatterplot(
    data=tsne_result_df, x='tsne_1', y='tsne_2', hue='label',
    s=120, palette="tab10", ax=ax,
)

# Same limits on both axes (global min/max padded by 5) plus an equal aspect
# ratio, so distances look the same in every direction.
lim = (tsne_result.min() - 5, tsne_result.max() + 5)
for set_limits in (ax.set_xlim, ax.set_ylim):
    set_limits(lim)
ax.set_aspect('equal')

# Legend outside the axes, anchored to the upper-right corner.
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)

## Album-based Labelling

In [None]:
# Song metadata table. The matching cell below unpacks every row into exactly
# six fields (song, artist, writer, album, year, ref), so the CSV needs six
# columns in that order.
song_meta = pd.read_csv('song_meta.csv')

In [None]:
# Match each dataset song name to an album via the metadata table.

# Normalise the metadata once, outside the matching loop:
#   title -> lower-cased with double quotes removed (column 0)
#   album -> '(Deluxe edition)' stripped, then only the first word kept (column 3)
meta_pairs = [
    (row[0].lower().replace('"', ''),
     row[3].replace('(Deluxe edition)', '').split(' ')[0])
    for row in song_meta.itertuples(index=False, name=None)
]

name_to_album = {}
for song_name in (name.lower() for name in name_to_label):
    for title, album_key in meta_pairs:
        # if album_key not in ('1989', 'Taylor'):
        #     continue
        # Substring match: a metadata title contained in the dataset name
        # counts as a hit, and when several titles hit, the last one in the CSV
        # wins. An empty title is skipped because it is contained in every name.
        if title and title in song_name:
            name_to_album[song_name] = album_key

In [None]:
# Collect, for every file whose song has a known album, the album name and the
# file's latent vector (files without an album match are dropped).
album_labels = []
album_vec_rows = []
for i, file in enumerate(file_list):
    song_key = file.split('/')[-1].split('_')[0].lower()
    if song_key not in name_to_album:
        continue
    album_labels.append(name_to_album[song_key])
    album_vec_rows.append(vecs[i])

album_vecs = np.vstack(album_vec_rows)

In [None]:
# Project the album-labelled latent vectors to 2-D with t-SNE.
n_components = 2
# random_state pins the random number generator so the embedding (and the plot
# built from it) is the same on every run.
tsne = TSNE(n_components, learning_rate='auto', init='pca', random_state=0)
tsne_result = tsne.fit_transform(album_vecs)
tsne_result.shape

In [None]:
# Scatter plot of the 2-D embedding, one colour per album.
tsne_result_df = pd.DataFrame({
    'tsne_1': tsne_result[:, 0],
    'tsne_2': tsne_result[:, 1],
    'label': album_labels,
})
fig, ax = plt.subplots(1)
sns.scatterplot(
    data=tsne_result_df, x='tsne_1', y='tsne_2', hue='label',
    s=120, ax=ax,
)

# Same limits on both axes (global min/max padded by 5) plus an equal aspect
# ratio, so distances look the same in every direction.
lim = (tsne_result.min() - 5, tsne_result.max() + 5)
for set_limits in (ax.set_xlim, ax.set_ylim):
    set_limits(lim)
ax.set_aspect('equal')

# Legend outside the axes, anchored to the upper-right corner.
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)

## Mean-Vector Song Labelling

In [None]:
# Group the latent vectors by song, then average each group.
# Each file is paired with its own vector. The previous version indexed `vecs`
# with a loop variable left over from another cell, so every song was "averaged"
# from the same single vector.
name_to_mean_vec = {}
for file, vec in zip(file_list, vecs):
    name = file.split('/')[-1].split('_')[0]
    name_to_mean_vec.setdefault(name, []).append(vec)

# One label and one mean vector per song, in order of first appearance.
mean_vec_labels = list(name_to_mean_vec)
mean_vecs = np.vstack([np.mean(group, axis=0) for group in name_to_mean_vec.values()])

mean_vecs.shape, len(mean_vec_labels)

In [None]:
# Keep at most the first `num_songs` songs (170 here, not 10): the labels and
# the mean vectors are sliced identically so they stay aligned.
num_songs = 170
filtered_mean_vec_labels = mean_vec_labels[:num_songs]
filtered_mean_vecs = mean_vecs[:num_songs]

In [None]:
# t-SNE of the per-song mean vectors, plotted with one colour per song name.
n_components = 2
# random_state pins the random number generator so the embedding (and the
# indices inspected in the following cells) are the same on every run.
tsne = TSNE(n_components, learning_rate=200, init='pca', random_state=0)
tsne_result = tsne.fit_transform(filtered_mean_vecs)
# A bare expression in the middle of a cell displays nothing, so print it.
print(tsne_result.shape)

tsne_result_df = pd.DataFrame({'tsne_1': tsne_result[:,0], 'tsne_2': tsne_result[:,1], 'label': filtered_mean_vec_labels})
fig, ax = plt.subplots(1)
sns.scatterplot(x='tsne_1', y='tsne_2', hue='label', data=tsne_result_df, ax=ax,s=120)
# Same padded limits on both axes plus an equal aspect ratio.
lim = (tsne_result.min()-100, tsne_result.max()+100)
ax.set_xlim(lim)
ax.set_ylim(lim)
ax.set_aspect('equal')
# Legend outside the axes, anchored to the upper-right corner.
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)

In [None]:
# List every embedded point next to its row index, to look up neighbours by eye.
for row_idx, point in enumerate(tsne_result):
    print(row_idx, point)

In [None]:
# Print the song names at two consecutive row indices of the mean-vector list.
# NOTE(review): 34 was presumably picked by eye from the listing above — the
# choice is not derived anywhere in code.
start_id = 34
print(filtered_mean_vec_labels[start_id], filtered_mean_vec_labels[start_id + 1])

## Mean-Vectors Labelled by Album

In [None]:
# Re-label the per-song mean vectors by album, dropping songs that have no
# album match in name_to_album (keys there are lower-cased song names).
album_mean_vec_labels = []
album_mean_vec_rows = []
for song_label, song_vec in zip(mean_vec_labels, mean_vecs):
    song_key = song_label.lower()
    if song_key not in name_to_album:
        continue
    album_mean_vec_labels.append(name_to_album[song_key])
    album_mean_vec_rows.append(song_vec)

album_mean_vecs = np.vstack(album_mean_vec_rows)

In [None]:
# t-SNE of the per-song mean vectors, plotted with one colour per album.
n_components = 2
# random_state pins the random number generator so the embedding (and the plot
# built from it) is the same on every run.
tsne = TSNE(n_components, learning_rate=200, init='pca', random_state=0)
tsne_result = tsne.fit_transform(album_mean_vecs)
print(tsne_result.shape)

tsne_result_df = pd.DataFrame({'tsne_1': tsne_result[:,0], 'tsne_2': tsne_result[:,1], 'label': album_mean_vec_labels})
fig, ax = plt.subplots(1)
sns.scatterplot(x='tsne_1', y='tsne_2', hue='label', data=tsne_result_df, ax=ax,s=120, palette="tab10")
# Same padded limits on both axes plus an equal aspect ratio.
lim = (tsne_result.min()-250, tsne_result.max()+250)
ax.set_xlim(lim)
ax.set_ylim(lim)
ax.set_aspect('equal')
# Legend outside the axes, anchored to the upper-right corner.
ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)

In [None]:
# Sanity check: the number of rows should equal the number of album labels.
album_mean_vecs.shape, len(album_mean_vec_labels)

In [None]:
# `cmap` only takes effect when `c` supplies values to map, so give every point
# the integer code of its album (codes follow the sorted unique album names).
_, album_codes = np.unique(album_mean_vec_labels, return_inverse=True)
plt.scatter(tsne_result[:,0], tsne_result[:,1], c=album_codes, cmap='tab10')