In [None]:
import os

from glob import glob

import torch
import matplotlib.pyplot as plt

from sklearn.decomposition import PCA
from sklearn.manifold import Isomap

from vseq.data import BaseDataset
from vseq.data.batchers import AudioBatcher, ListBatcher
from vseq.data.datapaths import DATASETS, TIMIT
from vseq.data.loaders import AudioLoader, TIMITAlignmentLoader, TIMITSpeakerLoader
from vseq.data.transforms import MuLawEncode
from vseq.models.clockwork_vae import CWVAEAudioTasNet
from vseq.settings import CHECKPOINT_DIR
from vseq.utils.device import *

In [None]:
# Display the checkpoint root that get_run_path searches by default.
CHECKPOINT_DIR

# Setup

In [None]:
# Experiment configuration: the dataset entry to evaluate on (provides `.train` / `.test`
# sources) and the ID of the training run whose checkpoint is inspected.
dataset = DATASETS[TIMIT]
run_id = "1xrnjn5y"  # matched as a substring of the run directory name by get_run_path

In [None]:
def get_run_path(run_id, checkpoint_dir=CHECKPOINT_DIR):
    """Return the path of the single run directory whose name contains `run_id`.

    Args:
        run_id: Run identifier, matched as a substring of the entry names in `checkpoint_dir`.
        checkpoint_dir: Directory that holds one subdirectory per run.

    Returns:
        `os.path.join(checkpoint_dir, <matching directory name>)`.

    Raises:
        IOError: If no directory, or more than one directory, in `checkpoint_dir` matches `run_id`.
    """
    # Only directories can be runs (the result has "files" joined onto it), so ignore stray
    # files such as logs or archives whose names happen to contain the run ID.
    # Sorting makes the ambiguity error list candidates in a deterministic order.
    matches = sorted(
        name
        for name in os.listdir(checkpoint_dir)
        if run_id in name and os.path.isdir(os.path.join(checkpoint_dir, name))
    )

    if not matches:
        raise IOError(f"No runs found with ID {run_id}")
    if len(matches) > 1:
        raise IOError(f"More than one run found with ID {run_id}: {matches}")
    return os.path.join(checkpoint_dir, matches[0])

# Load model

In [None]:
# Run directories are named like "run-20210802_090834-<run_id>"; the model is loaded
# from the run's "files" subfolder below.
run_path = get_run_path(run_id=run_id)
run_files_path = os.path.join(run_path, "files")

In [None]:
# `get_free_gpus` comes from the star import of vseq.utils.device; presumably it picks a
# free GPU (or a device spec for one) — NOTE(review): confirm its return type.
# Kept as a separate variable because later cells may use `device`.
device = get_free_gpus()
# Load the saved model from the run's "files" directory onto that device.
model = CWVAEAudioTasNet.load(run_files_path, device=device)

In [None]:
# Optional (disabled): print a model summary on CPU for a dummy input of shape
# (4, model.overall_stride). Uncomment to run.
# model.summary(input_size=(4, model.overall_stride), x_sl=torch.tensor([model.overall_stride]), device="cpu");

In [None]:
# One (loader, transform, batcher) tuple per modality: audio, phoneme alignment ("PHN"),
# word alignment ("WRD") and speaker. Items are presumably returned in this order.
# AudioBatcher's padding_module is presumably used to pad audio to a multiple of the
# model's overall stride — confirm against vseq.data.batchers.
# The batcher slot takes an instance (as AudioBatcher(...) shows), so ListBatcher is
# instantiated rather than passed as a bare class.
modalities = [
    (AudioLoader("wav", cache=False), MuLawEncode(bits=16), AudioBatcher(padding_module=model.overall_stride)),
    (TIMITAlignmentLoader("PHN"), None, ListBatcher()),
    (TIMITAlignmentLoader("WRD"), None, ListBatcher()),
    (TIMITSpeakerLoader(), None, ListBatcher()),
]

train_dataset = BaseDataset(
    source=dataset.train,
    modalities=modalities,
)
# The dataset's test split serves as the validation set in this notebook.
valid_dataset = BaseDataset(
    source=dataset.test,
    modalities=modalities,
)


In [None]:
# Unpack in the same order as `modalities`: audio (a), phoneme alignment (p, "PHN"),
# word alignment (w, "WRD"), speaker (s).
(a, p, w, s), metadata = valid_dataset[0]

In [None]:
# Speaker annotation of the first validation example (the TIMITSpeakerLoader modality).
s

# Exploration (scratch cells)

In [None]:
# FIXME: `speaker_metadata` is never defined in this notebook, so this cell raises NameError
# on a fresh kernel (Restart & Run All). Define it in an earlier cell before re-enabling.
# speaker_metadata["ABC0"]

In [None]:
# FIXME: `birthday` is never defined in this notebook, so this cell raises NameError on a
# fresh kernel. Define it in an earlier cell before re-enabling.
# birthday

In [None]:
# FIXME: `source_rows` is never defined in this notebook, so this cell raises NameError on a
# fresh kernel. Define it in an earlier cell before re-enabling.
# source_rows[0]

In [None]:
# Inspect the weight of the first module in the encoder's input transform.
model.cwvae.encoder.in_transform[0].weight

In [None]:
from vseq.utils.device import *

In [None]:
# TODO: Plot the PCA of latent representations for the same word pronounced by two different speakers.
# TODO: Plot the PCA of latent representations for different phonemes pronounced by the same speaker.
# TODO: Plot the PCA of latent representations for the same phoneme pronounced by two different speakers.

In [None]:
# FIXME: `run` is never defined in this notebook, so this cell raises NameError on a fresh
# kernel. Define it in an earlier cell before re-enabling.
# run

In [None]:
# Benchmark: time to list the checkpoint directory with os.listdir (the call get_run_path uses).
%timeit os.listdir(CHECKPOINT_DIR)

In [None]:
# Benchmark: the same listing via glob, for comparison with the os.listdir timing above.
# Note that glob returns full paths and its "*" pattern skips names starting with ".".
%timeit glob(CHECKPOINT_DIR + "/*")