In [None]:
import os
from glob import glob

import IPython
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import torch
import torchaudio
from sklearn.decomposition import PCA
from sklearn.manifold import Isomap

from vseq.data import BaseDataset
from vseq.data.batchers import AudioBatcher, ListBatcher
from vseq.data.datapaths import DATASETS, TIMIT
from vseq.data.loaders import AudioLoader, TIMITAlignmentLoader, TIMITSpeakerLoader
from vseq.data.transforms import MuLawEncode, MuLawDecode
from vseq.models.clockwork_vae import CWVAEAudioTasNet
from vseq.settings import CHECKPOINT_DIR
from vseq.utils.device import *

In [None]:
# Root directory that get_run_path searches for run folders by default.
CHECKPOINT_DIR

# Setup

In [None]:
# Evaluation only: turn off autograd globally so forward passes do not build graphs.
torch.autograd.set_grad_enabled(False)

1xrnjn5y | Big V100 model, 2 layers

In [None]:
# ID of the training run to analyse; get_run_path matches it against checkpoint folder names.
run_id = "1xrnjn5y"
# TIMIT data paths; `.train` and `.test` are used as dataset sources further down.
dataset = DATASETS[TIMIT]

In [None]:
def get_run_path(run_id, checkpoint_dir=CHECKPOINT_DIR):
    """Return the path of the unique entry in ``checkpoint_dir`` whose name contains ``run_id``.

    Args:
        run_id: substring identifying the run (matched against entry names).
        checkpoint_dir: directory to search; defaults to ``CHECKPOINT_DIR``.

    Returns:
        ``os.path.join(checkpoint_dir, <matching entry>)``.

    Raises:
        IOError: if no entry, or more than one entry, contains ``run_id``.
    """
    matches = list(filter(lambda entry: run_id in entry, os.listdir(checkpoint_dir)))

    if not matches:
        raise IOError(f"No runs found with ID {run_id}")
    if len(matches) > 1:
        raise IOError(f"More than one run found with ID {run_id}: {matches}")

    (only_match,) = matches
    return os.path.join(checkpoint_dir, only_match)

In [None]:
def plot_logistic_mixture(means, log_scales, unnorm_weights, x_min=-1, x_max=1, n_vals=100, ax=None):
    """Plot the density of a 1-D mixture of logistic distributions.

    Args:
        means: 1-D tensor of component locations, one entry per component.
        log_scales: 1-D tensor of log scale parameters, same length as ``means``.
        unnorm_weights: 1-D tensor of unnormalised mixture logits; a softmax over the
            last dimension turns them into mixture weights.
        x_min, x_max: interval on which the density is evaluated.
        n_vals: number of evenly spaced evaluation points in ``[x_min, x_max]``.
        ax: matplotlib axes to draw on. A new figure/axes pair is created when None.

    Returns:
        ``(fig, ax)``: the figure that owns ``ax`` and the axes that was drawn on.
    """
    grid = np.linspace(x_min, x_max, n_vals)

    weights = unnorm_weights.detach().softmax(-1).cpu().numpy()
    locs = means.detach().cpu().numpy()
    # The logistic distribution is parameterised by its scale, not its variance.
    scales = log_scales.detach().exp().cpu().numpy()

    # Evaluate every component on the grid at once: (n_vals, 1) against (n_components,)
    # broadcasts to (n_vals, n_components). The weighted sum over components is the
    # mixture density. The number of components comes from the inputs (it used to be
    # hard-coded to 10).
    component_pdfs = scipy.stats.logistic.pdf(grid[:, None], loc=locs, scale=scales)
    pdf = component_pdfs @ weights

    if ax is None:
        fig, ax = plt.subplots(1, 1)
    else:
        # Use the figure that owns `ax`; plt.gcf() may return a different (current) figure.
        fig = ax.figure

    ax.plot(grid, pdf)
    return fig, ax

# Load model

In [None]:
run_path = get_run_path(run_id)
# The model checkpoint is loaded from the run's "files" subdirectory (see the next cell).
run_files_path = os.path.join(run_path, "files")
run_path

In [None]:
# NOTE(review): get_free_gpus comes in through `from vseq.utils.device import *`;
# presumably it selects an unused GPU — confirm.
device = get_free_gpus()
model = CWVAEAudioTasNet.load(run_files_path, device=device)

In [None]:
# Optional model summary on CPU (disabled).
# NOTE(review): remove this cell if the summary is no longer needed.
# model.summary(input_size=(4, model.overall_stride), x_sl=torch.tensor([model.overall_stride]), device="cpu");

In [None]:
# Audio I/O: read .wav files without caching.
audio_loader = AudioLoader("wav", cache=False)
# Assumed from the argument name: batches are padded to a multiple of the model's overall stride — confirm in AudioBatcher.
audio_batcher = AudioBatcher(padding_module=model.overall_stride)
# 16-bit µ-law encoder/decoder; the encoder is applied to audio inside the dataset (see `modalities`).
transform_enc = MuLawEncode(bits=16)
transform_dec = MuLawDecode(bits=16)

In [None]:
# One (loader, transform, batcher) triple per modality. Dataset items are therefore
# ((audio, PHN alignment, WRD alignment, speaker), metadata), in this order.
# Only the audio gets a transform (µ-law encoding).
# NOTE(review): ListBatcher is passed as a class while audio_batcher is an instance — confirm BaseDataset accepts both.
modalities = [
    (audio_loader, transform_enc, audio_batcher),
    (TIMITAlignmentLoader("PHN"), None, ListBatcher),
    (TIMITAlignmentLoader("WRD"), None, ListBatcher),
    (TIMITSpeakerLoader(), None, ListBatcher),
]

train_dataset = BaseDataset(
    source=dataset.train,
    modalities=modalities,
)
# The TIMIT test split serves as the validation set in this notebook.
valid_dataset = BaseDataset(
    source=dataset.test,
    modalities=modalities,
)


In [None]:
# Modalities are ordered (audio, PHN, WRD, speaker) in `modalities`, so unpack as
# a=audio, p=phoneme alignment, w=word alignment, s=speaker (same order as later cells).
(a, p, w, s), metadata = valid_dataset[0]
a, p, w, s, metadata

# Analysis

## Reconstructions

Reconstruction accuracy is generally good.

Observations:
- Occasionally, the sampled reconstruction at a single timestep is an outlier. This happens when a poorly fitting component is sampled from the mixture of logistics. The outliers disappear when decoding from the "mode", i.e. the mean of the highest-weight component at each timestep.
- The mixture weights are generally not one-hot. The empirical distribution of the maximum weight peaks at around 0.6. This leaves considerable probability mass spread over other, partly distinct, components.

In [None]:
# Run the first four validation utterances through the model.
# Audio was already µ-law encoded by the dataset (`transform_enc` in `modalities`),
# so the padded batch is fed to the model as-is.
data = [valid_dataset[idx] for idx in range(4)]
audio = [waveform for (waveform, _phn, _wrd, _spk), _meta in data]
x, x_sl = audio_batcher(audio)
x = x.to(device)

torch.manual_seed(6)  # fixed seed so any sampling inside the forward pass is reproducible
loss, metrics, output = model(x, x_sl)
data[0][-1]

In [None]:
# Shape check: unpadded audio, padded batch, reconstruction and first latent.
audio[0].shape, x.shape, output.reconstruction.shape, output.latents[0].shape

In [None]:
# Target vs. sampled reconstruction of the first utterance in µ-law space.
# Top: samples 16000-32000; bottom: zoom on samples 21000-21500.
fig, axes = plt.subplots(2, 1, figsize=(20, 10))

axes[0].set_title("Reconstructions in µ-law space")

axes[0].plot(x[0][16000:32000].cpu())
axes[0].plot(output.reconstruction[0][16000:32000].cpu(), alpha=0.8)

axes[1].plot(x[0][21000:21500].cpu())
axes[1].plot(output.reconstruction[0][21000:21500].cpu(), alpha=0.8)

In [None]:
# Same comparison after µ-law decoding. In the top panel the reconstruction is
# offset by +1 (presumably to keep the two traces from overlapping).
# Bottom: zoom on samples 21500-22000.
fig, axes = plt.subplots(2, 1, figsize=(20, 10))

axes[0].set_title("Reconstructions in linear space")

axes[0].plot(transform_dec(audio[0][16000:32000].cpu()))
axes[0].plot(transform_dec(output.reconstruction[0][16000:32000].cpu()) + 1)

axes[1].plot(transform_dec(audio[0][21500:22000].cpu()))
axes[1].plot(transform_dec(output.reconstruction[0][21500:22000].cpu()))

In [None]:
# "Mode" decoding: at every timestep pick the mixture component with the largest weight
# (argmax over the softmaxed logits in reconstruction_parameters[0]) and take its mean
# from reconstruction_parameters[1], instead of sampling.
mode_component = output.reconstruction_parameters[0].softmax(-1).argmax(-1).unsqueeze(-1)
mode = torch.gather(output.reconstruction_parameters[1], index=mode_component, dim=-1).squeeze()
output.reconstruction_parameters[1][0].shape, mode_component.shape, mode.shape

In [None]:
# Like the linear-space plot above, but decoding from the mode instead of a sample
# (top-panel reconstruction again offset by +1).
fig, axes = plt.subplots(2, 1, figsize=(20, 10))

axes[0].set_title("Reconstructions in linear space")

axes[0].plot(transform_dec(audio[0][16000:32000].cpu()))
axes[0].plot(transform_dec(mode[0][16000:32000].cpu()) + 1)

axes[1].plot(transform_dec(audio[0][21500:22000].cpu()))
axes[1].plot(transform_dec(mode[0][21500:22000].cpu()))

In [None]:
# MSE in µ-law space: mode decoding vs. sampled reconstruction.
# NOTE(review): no masking by x_sl, so any padding added by audio_batcher is included — confirm intended.
(mode - x).pow(2).mean(), (output.reconstruction.squeeze() - x).pow(2).mean()

In [None]:
# Histograms of per-sample squared errors (log y-axis). The second histogram reuses
# the bin edges of the first so the two are directly comparable.
v, bins, _ = plt.hist((output.reconstruction.squeeze() - x).pow(2).flatten().cpu().numpy(), alpha=0.5, bins=50, label="samples")
v, bins, _ = plt.hist((mode - x).pow(2).flatten().cpu().numpy(), alpha=0.5, bins=bins, label="mode")
plt.yscale("log")
plt.title("MSE between target and samples from or mode of $p(x_t|z)$")
plt.legend()

In [None]:
# Where is the reconstruction error high?

fig, axes = plt.subplots(3, 1, figsize=(20, 10))

axes[0].set_title("Reconstructions in µ-law space")

axes[0].plot(x[0][16000:32000].cpu())
axes[0].plot(output.reconstruction[0][16000:32000].cpu(), alpha=0.8)

axes[1].plot(x[0][25000:25100].cpu())
axes[1].plot(output.reconstruction[0][25000:25100].cpu(), alpha=0.8)

# Per-sample absolute error over the whole first utterance (µ-law space).
abs_err = (x[0].cpu() - output.reconstruction[0].cpu().squeeze()).abs().numpy()
# Moving average over `window` samples. mode="same" keeps the curve aligned with
# `abs_err`; mode="valid" would shift it left by about half a window.
window = 100
running_abs_err = np.convolve(abs_err, np.ones(window) / window, mode="same")
axes[2].plot(abs_err, label="Absolute error")
axes[2].plot(running_abs_err, label="Running absolute error")

axes[2].legend()

In [None]:
import torchaudio

In [None]:
# Power spectrograms: 512-point FFT, 320-sample window, 160-sample hop
# (20 ms / 10 ms if the audio is 16 kHz).
spectrogram = torchaudio.transforms.Spectrogram(
    n_fft=512,
    win_length=320,
    hop_length=160,
    power=2.0,
    normalized=False,
    onesided=True,
)

todb = torchaudio.transforms.AmplitudeToDB()

# Both signals are in µ-law space (audio was encoded by the dataset). The reconstruction
# is truncated to the unpadded length of the target.
spec_audio = todb(spectrogram(audio[0]))
spec_recon = todb(spectrogram(output.reconstruction[0].squeeze().cpu()[:audio[0].shape[0]]))

# Mean absolute dB error per frequency bin (err_freq) and per frame (err_time).
# NOTE: dim 0 is frequency and is flipped here, so err_freq[0] is the HIGHEST frequency bin.
err_freq = (spec_audio.flip(0) - spec_recon.flip(0)).abs().mean(1)
err_time = (spec_audio.flip(0) - spec_recon.flip(0)).abs().mean(0)

spec_recon.shape

In [None]:
# Target spectrogram, reconstruction spectrogram and per-frame error.
# flip(0) puts low frequencies at the bottom, since imshow draws row 0 at the top.
fig, axes = plt.subplots(3, 1, figsize=(20, 10))
axes[0].imshow(spec_audio.flip(0), aspect="auto")
axes[1].imshow(spec_recon.flip(0), aspect="auto")
axes[2].plot(err_time)


In [None]:
# Mean absolute dB error per frequency bin.
# NOTE(review): err_freq was computed on flipped spectrograms, so index 0 is the highest
# frequency bin and [1:] drops that bin, not the DC bin — confirm which one was meant.
plt.plot(err_freq[1:])

In [None]:
# Weight of the most probable mixture component at every timestep of every utterance.
mode_idx = output.reconstruction_parameters[0].softmax(-1).argmax(-1, keepdim=True)
mode_weight = torch.gather(output.reconstruction_parameters[0].softmax(-1), index=mode_idx, dim=-1).squeeze()

In [None]:
# Weights of all non-maximal mixture components, pooled over the whole batch
# (components tied with the maximum are excluded as well).
mixture_weights = output.reconstruction_parameters[0].softmax(-1)
max_val = mixture_weights.max(-1, keepdim=True).values
not_max_idx = mixture_weights < max_val
not_mode_weigts = mixture_weights[not_max_idx]  # (sic) name kept: the plotting cell reads it

In [None]:
# Which mixture components are used? Does it change over time?
fig, axes = plt.subplots(1, 2, figsize=(10, 5))

# Left: mixture weights of the first utterance over time (one row per component).
axes[0].imshow(output.reconstruction_parameters[0][0].squeeze().cpu().softmax(-1).T, aspect="auto", interpolation="none")
# Right: distribution of the maximal vs. the remaining mixture weights. Both are pooled
# over the whole batch (not_mode_weigts is), so the histograms cover the same timesteps.
axes[1].hist(mode_weight.flatten().cpu().numpy(), alpha=0.5, density=True, label="Max weight")
axes[1].hist(not_mode_weigts.cpu().numpy(), alpha=0.5, density=True, label="Other weights")

axes[1].legend();


In [None]:
# Mixture densities of the first utterance at n_rows * n_cols consecutive timesteps
# starting at sample T (panel titles are offsets from T).
T = 21000

# Named n_rows/n_cols rather than h/w so the word alignment `w` is not overwritten.
n_rows, n_cols = 8, 7
n = n_rows * n_cols

fig, axes = plt.subplots(n_rows, n_cols, figsize=(16, 10), sharex=True)
axes = axes.flatten()
# Centre axis of the bottom row carries the shared x-label.
axes[(-n_cols) // 2].set_xlabel("$p(x_t|z^1_t,z^2_t)$ at different timesteps", fontsize=16)
for t in range(n):
    plot_logistic_mixture(
        output.reconstruction_parameters[1][0][T + t, 0, :],  # passed as `means`
        output.reconstruction_parameters[2][0][T + t, 0, :],  # passed as `log_scales`
        output.reconstruction_parameters[0][0][T + t, 0, :],  # passed as `unnorm_weights`
        ax=axes[t],
    )
    axes[t].set_title(f"t={t}")

# Run the layout only after titles and labels exist so it can account for them.
fig.tight_layout()

In [None]:
# Means (reconstruction_parameters[1]) of each mixture component for the first utterance
# over samples 21000-21200, one panel per component.
# NOTE(review): assumes at least 10 components.
fig, axes = plt.subplots(10, 1, figsize=(10, 10))
for i, ax in enumerate(axes):
    ax.plot(output.reconstruction_parameters[1][0][21000:21200, 0, i].cpu(), alpha=0.8)

## Different speaker, same dialect, same gender

In [None]:
# Selection criteria: utterances containing `word`, spoken by speakers with this dialect and sex.
word = "water"
dialect = "New England"
sex = "F"

In [None]:
# Inspect the speaker record and word alignment of one example to see which fields to filter on.
(a, p, w, s), metadata = valid_dataset[0]
s, w

In [None]:
# Keep (index, example) pairs whose word alignment contains `word`. This loads every validation example.
# NOTE(review): w[2] is assumed to hold the word labels; if it is a string, `in` is a substring
# match (e.g. "waters" also matches) — confirm against TIMITAlignmentLoader.
water_sentences = [(i, ((a, p, w, s), metadata)) for i, ((a, p, w, s), metadata) in enumerate(valid_dataset) if word in w[2]]

In [None]:
# Narrow down to speakers matching both `dialect` and `sex`.
water_sentences = [(i, ((a, p, w, s), metadata)) for i, ((a, p, w, s), metadata) in water_sentences if s.dialect == dialect and s.sex == sex]
len(water_sentences)

In [None]:
# List dataset index and speaker of each selected utterance.
# Note: this for-loop rebinds the globals i, a, p, w, s and metadata.
for (i, ((a, p, w, s), metadata)) in water_sentences:
    print(i, s)

In [None]:
# Batch the selected utterances and run them through the model.
audio = [a for _, ((a, _p, _w, _s), _meta) in water_sentences]
x, x_sl = audio_batcher(audio)
# No extra `transform_enc` here: the dataset's audio modality already applies it
# (see `modalities`), so encoding again would µ-law-encode the signal twice.
x = x.to(device)
loss, metrics, output = model(x, x_sl)

In [None]:
# TODO: Plot PCA of latent representations over time spanning a single word

In [None]:
# 2-D whitened PCA projection for visualising latents.
# NOTE(review): not used yet in the cells shown.
pca = PCA(n_components=2, whiten=True)

In [None]:
# 2-D Isomap embedding (5 neighbours), a non-linear alternative to PCA.
# NOTE(review): not used yet in the cells shown.
isomap = Isomap(n_components=2, n_neighbors=5)

## Phonemes, same speaker

In [None]:
# TODO: Plot the PCA of latent representations for the same word pronounced by two different speakers.
# TODO: Plot the PCA of latent representations for different phonemes pronounced by the same speaker.
# TODO: Plot the PCA of latent representations for the same phoneme pronounced by two different speakers.