In [1]:
# Cell 1 — setup
import os
import json
from pathlib import Path

import numpy as np
import torch

# Reprodutibilidade
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))

# Repo paths (ajuste se necessário)
DATA_DIR = Path("../data")
CACHE_ROOT = DATA_DIR / "nsynth_mel_cache"

TRAIN_ROOT = DATA_DIR / "nsynth-train.jsonwav" / "nsynth-train"
VALID_ROOT = DATA_DIR / "nsynth-valid.jsonwav" / "nsynth-valid"
TEST_ROOT  = DATA_DIR / "nsynth-test.jsonwav"  / "nsynth-test"

# JSONs
TRAIN_JSON = TRAIN_ROOT / "examples.json"
VALID_JSON = VALID_ROOT / "examples.json"
TEST_JSON  = TEST_ROOT  / "examples.json"

# Cache dirs
TRAIN_CACHE = CACHE_ROOT / "train"
VALID_CACHE = CACHE_ROOT / "valid"
TEST_CACHE  = CACHE_ROOT / "test"

print("CACHE_ROOT exists:", CACHE_ROOT.exists(), "|", CACHE_ROOT)
print("TRAIN_CACHE exists:", TRAIN_CACHE.exists(), "|", TRAIN_CACHE)
print("VALID_CACHE exists:", VALID_CACHE.exists(), "|", VALID_CACHE)
print("TEST_CACHE  exists:", TEST_CACHE.exists(),  "|", TEST_CACHE)

print("TRAIN_JSON exists:", TRAIN_JSON.exists())
print("VALID_JSON exists:", VALID_JSON.exists())
print("TEST_JSON  exists:", TEST_JSON.exists())


device: cuda
GPU: NVIDIA GeForce RTX 3060 Ti
CACHE_ROOT exists: True | ../data/nsynth_mel_cache
TRAIN_CACHE exists: True | ../data/nsynth_mel_cache/train
VALID_CACHE exists: True | ../data/nsynth_mel_cache/valid
TEST_CACHE  exists: True | ../data/nsynth_mel_cache/test
TRAIN_JSON exists: True
VALID_JSON exists: True
TEST_JSON  exists: True


In [2]:
# Cell 2 — import model code
import sys
sys.path.append(str(Path("..").resolve()))  # permite importar scripts/...

from scripts.models import ConditionalVAE  # ajuste se o nome do módulo/classe diferir

print("Imported ConditionalVAE from scripts/models.py")


Imported ConditionalVAE from scripts/models.py


In [33]:
# Cell 3 — choose split + load examples
SPLIT = "valid"  # "valid" ou "test" (recomendo começar em "valid" se quiser comparar com plots anteriores)

SPLIT_TO_JSON = {"train": TRAIN_JSON, "valid": VALID_JSON, "test": TEST_JSON}
SPLIT_TO_CACHE = {"train": TRAIN_CACHE, "valid": VALID_CACHE, "test": TEST_CACHE}

json_path = SPLIT_TO_JSON[SPLIT]
cache_dir = SPLIT_TO_CACHE[SPLIT]

with open(json_path, "r") as f:
    examples = json.load(f)

keys = list(examples.keys())
print("SPLIT:", SPLIT)
print("json_path:", json_path)
print("cache_dir:", cache_dir)
print("Number of examples:", len(keys))

# sample
for k in keys[:3]:
    meta = examples[k]
    print(k, "| pitch:", meta["pitch"], "| family:", meta["instrument_family"], "| source:", meta.get("instrument_source", None))


SPLIT: valid
json_path: ../data/nsynth-valid.jsonwav/nsynth-valid/examples.json
cache_dir: ../data/nsynth_mel_cache/valid
Number of examples: 12678
keyboard_acoustic_004-060-025 | pitch: 60 | family: 4 | source: 0
bass_synthetic_033-050-100 | pitch: 50 | family: 0 | source: 2
bass_synthetic_009-052-050 | pitch: 52 | family: 0 | source: 2


In [34]:
# Cell 4 — Dataset from cache
from torch.utils.data import Dataset

class NsynthMelCacheDataset(Dataset):
    def __init__(self, keys, examples, cache_dir: Path):
        self.keys = list(keys)
        self.examples = examples
        self.cache_dir = Path(cache_dir)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        k = self.keys[idx]
        pt_path = self.cache_dir / f"{k}.pt"
        if not pt_path.exists():
            raise FileNotFoundError(f"Missing cache file: {pt_path}")

        x = torch.load(pt_path, weights_only=True)  # (1,80,128) normalizado [-1,1]
        pitch = torch.tensor(int(self.examples[k]["pitch"]), dtype=torch.long)
        family = torch.tensor(int(self.examples[k]["instrument_family"]), dtype=torch.long)
        return x, pitch, family, k

# sanity sample
ds_tmp = NsynthMelCacheDataset(keys[:10], examples, cache_dir)
x0, p0, f0, k0 = ds_tmp[0]
print("sample key:", k0)
print("x shape:", tuple(x0.shape), "| dtype:", x0.dtype, "| range:", float(x0.min()), float(x0.max()))
print("pitch:", int(p0), "| family:", int(f0))


sample key: keyboard_acoustic_004-060-025
x shape: (1, 80, 128) | dtype: torch.float32 | range: -1.0 1.0
pitch: 60 | family: 4


In [35]:
# Cell 5 — DataLoader
from torch.utils.data import DataLoader

BATCH_SIZE = 256
NUM_WORKERS = 8

MAX_KEYS = 50_000   # coloque None para usar tudo (mas pra começar é melhor limitar)
keys_use = keys if MAX_KEYS is None else keys[:MAX_KEYS]

ds = NsynthMelCacheDataset(keys_use, examples, cache_dir)

loader = DataLoader(
    ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=(device.type == "cuda"),
    persistent_workers=(NUM_WORKERS > 0),
)

x, pitch, family, k = next(iter(loader))
print("batch x:", tuple(x.shape), x.dtype)
print("batch pitch:", pitch.shape, pitch.dtype, "| min/max:", int(pitch.min()), int(pitch.max()))
print("batch family:", family.shape, family.dtype, "| min/max:", int(family.min()), int(family.max()))
print("first key:", k[0])


batch x: (256, 1, 80, 128) torch.float32
batch pitch: torch.Size([256]) torch.int64 | min/max: 15 117
batch family: torch.Size([256]) torch.int64 | min/max: 0 10
first key: keyboard_acoustic_004-060-025


In [36]:
# Cell 6 — load checkpoint
CKPT_PATH = Path("../notebooks/ckpts/cvae_pitch_lat32_beta2.0_fb0.5_20260209_184708.pt")  # <-- ajuste pro seu ckpt do CVAE (nome/dir)

LATENT_DIM = 32
PITCH_VOCAB = 128
COND_DIM = 16

cvae = ConditionalVAE(
    latent_dim=LATENT_DIM,
    pitch_vocab=PITCH_VOCAB,
    cond_dim=COND_DIM,
).to(device)

ckpt = torch.load(CKPT_PATH, map_location="cpu")

# Suporta: state_dict direto OU dict com chaves comuns
if isinstance(ckpt, dict):
    if "model_state_dict" in ckpt:
        state = ckpt["model_state_dict"]
    elif "state_dict" in ckpt:
        state = ckpt["state_dict"]
    else:
        # às vezes o próprio dict já é o state_dict
        state = ckpt
else:
    state = ckpt

missing, unexpected = cvae.load_state_dict(state, strict=False)
cvae.eval()

print("Loaded ckpt:", CKPT_PATH)
print("missing keys:", len(missing))
print("unexpected keys:", len(unexpected))


Loaded ckpt: ../notebooks/ckpts/cvae_pitch_lat32_beta2.0_fb0.5_20260209_184708.pt
missing keys: 19
unexpected keys: 5



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [37]:
# Cell 7 — extract embeddings (mu/logvar) + pitch cond + metadata and save npz
from tqdm.auto import tqdm
import numpy as np
import torch
from pathlib import Path

N_SAMPLES = 20_000  # ajuste
OUT_DIR = Path("../outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)

out_npz = OUT_DIR / f"embeddings_cvae_{SPLIT}_N{N_SAMPLES}.npz"

cvae.eval()

mus, logvars = [], []
conds = []  # ✅ pitch embedding (cond_dim)
pitches, families = [], []
keys_out = []

count = 0
with torch.no_grad():
    for x, pitch, family, k in tqdm(loader, desc=f"Extract mu/logvar/cond ({SPLIT})"):
        x = x.to(device, non_blocking=True)
        pitch_dev = pitch.to(device, non_blocking=True)

        # ✅ encoder NÃO usa pitch no seu models.py
        mu, logvar = cvae.encoder(x)  # (B, latent_dim)

        # ✅ cond é o embedding do pitch
        cond = cvae.pitch_cond(pitch_dev)  # (B, cond_dim)

        bsz = mu.shape[0]
        mus.append(mu.detach().cpu())
        logvars.append(logvar.detach().cpu())
        conds.append(cond.detach().cpu())

        pitches.append(pitch.detach().cpu())
        families.append(family.detach().cpu())
        keys_out.extend(list(k))

        count += bsz
        if count >= N_SAMPLES:
            break

mu_all = torch.cat(mus, dim=0)[:N_SAMPLES].numpy()
logvar_all = torch.cat(logvars, dim=0)[:N_SAMPLES].numpy()
cond_all = torch.cat(conds, dim=0)[:N_SAMPLES].numpy()

pitch_all = torch.cat(pitches, dim=0)[:N_SAMPLES].numpy().astype(np.int32)
family_all = torch.cat(families, dim=0)[:N_SAMPLES].numpy().astype(np.int32)
keys_all = np.array(keys_out[:N_SAMPLES])

print("mu_all:", mu_all.shape)
print("logvar_all:", logvar_all.shape)
print("cond_all:", cond_all.shape)
print("pitch_all:", pitch_all.shape, "| min/max:", int(pitch_all.min()), int(pitch_all.max()))
print("family_all:", family_all.shape, "| unique:", len(np.unique(family_all)))
print("keys_all:", keys_all.shape)

np.savez_compressed(
    out_npz,
    mu=mu_all,
    logvar=logvar_all,
    cond=cond_all,          # ✅ pitch embedding
    pitch=pitch_all,
    family=family_all,
    keys=keys_all,
    split=SPLIT,
    latent_dim=mu_all.shape[1],
    cond_dim=cond_all.shape[1],
)
print("Saved:", out_npz)


Extract mu/logvar/cond (valid):   0%|          | 0/50 [00:00<?, ?it/s]

Extract mu/logvar/cond (valid): 100%|██████████| 50/50 [00:01<00:00, 26.40it/s]


mu_all: (12678, 32)
logvar_all: (12678, 32)
cond_all: (12678, 16)
pitch_all: (12678,) | min/max: 9 120
family_all: (12678,) | unique: 10
keys_all: (12678,)
Saved: ../outputs/embeddings_cvae_valid_N20000.npz


In [38]:
# Cell 8 — load saved embeddings and basic sanity checks
import numpy as np
from pathlib import Path

OUT_DIR = Path("../outputs")
in_npz = OUT_DIR / f"embeddings_cvae_{SPLIT}_N{N_SAMPLES}.npz"

assert in_npz.exists(), f"File not found: {in_npz}"

data = np.load(in_npz, allow_pickle=True)

mu_all     = data["mu"]        # (N, latent_dim)
logvar_all = data["logvar"]    # (N, latent_dim)
cond_all   = data["cond"]      # (N, cond_dim)
pitch_all  = data["pitch"]     # (N,)
family_all = data["family"]    # (N,)
keys_all   = data["keys"]      # (N,)

print("Loaded:", in_npz)
print("-" * 40)
print("mu_all:", mu_all.shape)
print("logvar_all:", logvar_all.shape)
print("cond_all:", cond_all.shape)
print("pitch_all:", pitch_all.shape, "| min/max:", pitch_all.min(), pitch_all.max())
print("family_all:", family_all.shape, "| unique:", np.unique(family_all))
print("keys_all:", keys_all.shape)

# -----------------------------
# Sanity checks
# -----------------------------

# 1) Shapes consistency
N = mu_all.shape[0]
assert logvar_all.shape[0] == N
assert cond_all.shape[0] == N
assert pitch_all.shape[0] == N
assert family_all.shape[0] == N
assert keys_all.shape[0] == N

# 2) Finite values
assert np.isfinite(mu_all).all(), "NaNs or infs in mu_all"
assert np.isfinite(logvar_all).all(), "NaNs or infs in logvar_all"
assert np.isfinite(cond_all).all(), "NaNs or infs in cond_all"

# 3) Pitch-conditioning sanity:
# same pitch → same cond embedding (up to numerical precision)
unique_pitches = np.unique(pitch_all)
test_pitch = unique_pitches[len(unique_pitches) // 2]

idx = np.where(pitch_all == test_pitch)[0]
if len(idx) >= 2:
    diff = np.abs(cond_all[idx[0]] - cond_all[idx[1]]).max()
    print(f"Cond diff for same pitch ({test_pitch}): {diff:.6e}")
else:
    print(f"Not enough samples to test cond consistency for pitch {test_pitch}")

print("Sanity checks passed ✅")


Loaded: ../outputs/embeddings_cvae_valid_N20000.npz
----------------------------------------
mu_all: (12678, 32)
logvar_all: (12678, 32)
cond_all: (12678, 16)
pitch_all: (12678,) | min/max: 9 120
family_all: (12678,) | unique: [ 0  1  2  3  4  5  6  7  8 10]
keys_all: (12678,)
Cond diff for same pitch (65): 0.000000e+00
Sanity checks passed ✅


In [39]:
# Cell 9 — PCA 2D on mu + Plotly visualizations
import numpy as np
from sklearn.decomposition import PCA
import plotly.express as px

# Map instrument family to readable labels
NSYNTH_FAMILY_MAP = {
    0: "bass",
    1: "brass",
    2: "flute",
    3: "guitar",
    4: "keyboard",
    5: "mallet",
    6: "organ",
    7: "reed",
    8: "string",
    9: "synth_lead",
    10: "vocal",
}

family_labels = np.array([
    f"{int(fid)} – {NSYNTH_FAMILY_MAP.get(int(fid), 'unknown')}"
    for fid in family_all
])

# --- PCA 2D ---
pca = PCA(n_components=2, random_state=SEED)
mu_pca = pca.fit_transform(mu_all)

print("Explained variance ratio:", pca.explained_variance_ratio_,
      "| sum:", pca.explained_variance_ratio_.sum())

data_pca = {
    "pc1": mu_pca[:, 0],
    "pc2": mu_pca[:, 1],
    "pitch": pitch_all.astype(int),
    "family": family_all.astype(int),
    "family_label": family_labels,   # categorical legend
    "key": keys_all,
}

# --- Plot 1: categorical legend by family ---
fig1 = px.scatter(
    data_pca,
    x="pc1",
    y="pc2",
    color="family_label",
    hover_data=["pitch", "family", "key"],
    title=f"PCA (mu) — color: instrument family (SPLIT={SPLIT}, N={len(mu_all)})",
)
fig1.show()

# --- Plot 2: gradient by pitch ---
fig2 = px.scatter(
    data_pca,
    x="pc1",
    y="pc2",
    color="pitch",
    hover_data=["family_label", "key"],
    title=f"PCA (mu) — color: pitch (SPLIT={SPLIT}, N={len(mu_all)})",
)
fig2.show()


Explained variance ratio: [0.39533982 0.09419107] | sum: 0.4895309


In [40]:
# Cell 10 — UMAP 2D on mu + Plotly visualizations
import numpy as np
import umap
import plotly.express as px

# garante SEED como int (evita erro tipo '<' not supported entre str e float)
SEED_INT = int(SEED)

# Map instrument family to readable labels
NSYNTH_FAMILY_MAP = {
    0: "bass",
    1: "brass",
    2: "flute",
    3: "guitar",
    4: "keyboard",
    5: "mallet",
    6: "organ",
    7: "reed",
    8: "string",
    9: "synth_lead",
    10: "vocal",
}

family_labels = np.array([
    f"{int(fid)} – {NSYNTH_FAMILY_MAP.get(int(fid), 'unknown')}"
    for fid in family_all
])

# --- UMAP 2D ---
reducer = umap.UMAP(
    n_neighbors=30,
    min_dist=0.1,
    n_components=2,
    metric="euclidean",
    init="spectral",
    learning_rate=1.0,
    random_state=SEED,
)
mu_umap = reducer.fit_transform(mu_all)

data_umap = {
    "u1": mu_umap[:, 0],
    "u2": mu_umap[:, 1],
    "pitch": pitch_all.astype(int),
    "family": family_all.astype(int),
    "family_label": family_labels,   # categorical legend
    "key": keys_all,
}

# --- Plot 1: categorical legend by family ---
fig1 = px.scatter(
    data_umap,
    x="u1",
    y="u2",
    color="family_label",
    hover_data=["pitch", "family", "key"],
    title=f"UMAP (mu) — color: instrument family (SPLIT={SPLIT}, N={len(mu_all)})",
)
fig1.show()

# --- Plot 2: gradient by pitch ---
fig2 = px.scatter(
    data_umap,
    x="u1",
    y="u2",
    color="pitch",
    hover_data=["family_label", "key"],
    title=f"UMAP (mu) — color: pitch (SPLIT={SPLIT}, N={len(mu_all)})",
)
fig2.show()



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [41]:
# Cell 11 — Silhouette analysis (global + by pitch window)
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_samples

def silhouette_per_group(
    X: np.ndarray,
    group: np.ndarray,
    n_clusters: int = 10,
    seed: int = 0,
) -> pd.DataFrame:
    """
    Clusters X with KMeans and computes silhouette per sample.
    Returns mean/std/count grouped by `group` (e.g., family or pitch).
    """
    km = KMeans(n_clusters=n_clusters, random_state=seed, n_init="auto")
    labels = km.fit_predict(X)

    sil = silhouette_samples(X, labels)

    df = pd.DataFrame({
        "group": group.astype(int),
        "silhouette": sil.astype(float),
    })

    out = df.groupby("group")["silhouette"].agg(
        mean_silhouette="mean",
        std_silhouette="std",
        n_samples="count",
    ).sort_values("mean_silhouette", ascending=False)

    out.index.name = "group"
    return out

def silhouette_pitch_window(
    X: np.ndarray,
    pitch: np.ndarray,
    family: np.ndarray,
    center_pitch: int,
    tol: int = 1,
    min_samples_family: int = 10,
    n_clusters: int = 10,
    seed: int = 0,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Filters samples to pitch in [center_pitch - tol, center_pitch + tol].
    Optionally removes families with too few samples.
    Returns:
      - per-family silhouette stats (mean/std/n)
      - the filtered dataframe used (for debugging)
    """
    pitch = pitch.astype(int)
    family = family.astype(int)

    mask = (pitch >= center_pitch - tol) & (pitch <= center_pitch + tol)
    idx = np.where(mask)[0]

    df = pd.DataFrame({
        "idx": idx,
        "pitch": pitch[idx],
        "family": family[idx],
    })

    # keep only families with enough samples
    fam_counts = df["family"].value_counts()
    keep = fam_counts[fam_counts >= min_samples_family].index.tolist()
    df = df[df["family"].isin(keep)].copy()

    print(f"Pitch window: {center_pitch} ± {tol}")
    print("Total samples:", len(idx))
    print("Families kept:", keep)
    print("Samples after filter:", len(df))

    if len(df) < (n_clusters + 2):
        raise ValueError(f"Not enough samples after filtering to run KMeans(n_clusters={n_clusters}).")

    Xf = X[df["idx"].to_numpy()]
    famf = df["family"].to_numpy()

    per_family = silhouette_per_group(
        X=Xf,
        group=famf,
        n_clusters=n_clusters,
        seed=seed,
    )
    per_family.index.name = "family"
    return per_family, df

# -------------------------------------------------------
# 1) Global silhouette per FAMILY (using ALL samples)
# -------------------------------------------------------
SEED_INT = int(SEED)
global_silhouette = silhouette_per_group(
    X=mu_all,
    group=family_all,
    n_clusters=10,
    seed=SEED_INT,
)
global_silhouette.index.name = "family"
print("\nGlobal silhouette per family:")
display(global_silhouette)

# -------------------------------------------------------
# 2) Pitch-window silhouette (example: pitch 60 ± 1)
#    (Você pode mudar center_pitch/tol depois)
# -------------------------------------------------------
results_pitch_60, df_pitch_60 = silhouette_pitch_window(
    X=mu_all,
    pitch=pitch_all,
    family=family_all,
    center_pitch=60,
    tol=1,
    min_samples_family=10,
    n_clusters=5,       # menor pq o subset fica pequeno; ajuste se tiver mais amostras
    seed=SEED_INT,
)

print("\nPitch 60 ± 1 — silhouette per family:")
display(results_pitch_60)



Global silhouette per family:


Unnamed: 0_level_0,mean_silhouette,std_silhouette,n_samples
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
3,0.130457,0.094854,2081
0,0.128407,0.101196,2638
1,0.119569,0.117244,886
8,0.112678,0.090846,814
4,0.105469,0.098428,2404
7,0.097416,0.084086,720
5,0.09115,0.092647,663
10,0.066734,0.083478,404
6,0.053935,0.091226,1598
2,0.050528,0.074785,470


Pitch window: 60 ± 1
Total samples: 522
Families kept: [0, 4, 3, 6, 1, 7, 8, 10, 5, 2]
Samples after filter: 522

Pitch 60 ± 1 — silhouette per family:


Unnamed: 0_level_0,mean_silhouette,std_silhouette,n_samples
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0.386704,0.139014,108
5,0.380028,0.103056,18
8,0.336036,0.148352,27
3,0.202809,0.089115,74
1,0.19527,0.091251,55
10,0.185722,0.078356,25
7,0.17662,0.081371,46
6,0.169729,0.09491,66
4,0.135333,0.12297,91
2,0.1329,0.026111,12
