In [1]:
# --- Path setup: allow importing from scripts/ ---
import sys
from pathlib import Path

# Project root = parent of the current working directory (the notebook
# presumably runs from a subdirectory such as `notebooks/`), so that the
# local `scripts` package becomes importable.
PROJECT_ROOT = Path().resolve().parent
_project_root_str = str(PROJECT_ROOT)
if _project_root_str not in sys.path:
    # Prepend so the local package takes precedence over installed namesakes.
    sys.path.insert(0, _project_root_str)

# Report only the root: dumping all of sys.path leaks absolute local paths
# (home directory / user name) into the saved notebook output.
print("Project root:", PROJECT_ROOT)

# --- Import model + helpers ---
from scripts.models import ConditionalVAE, vae_loss, linear_beta_schedule


['/home/satan/git/VAE-Timbre-Spaces', '/home/satan/miniconda3/envs/timbre_space_dl/lib/python310.zip', '/home/satan/miniconda3/envs/timbre_space_dl/lib/python3.10', '/home/satan/miniconda3/envs/timbre_space_dl/lib/python3.10/lib-dynload', '', '/home/satan/miniconda3/envs/timbre_space_dl/lib/python3.10/site-packages']
Project root: /home/satan/git/VAE-Timbre-Spaces


In [2]:
# --- Reproducibility ---
import random
import numpy as np
import torch

SEED = 42
# Seed every RNG in play, in order: Python, NumPy, PyTorch CPU, all CUDA
# devices (the CUDA call is silently ignored when CUDA is unavailable).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# --- Device ---
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print("device:", device)
if device.type == "cuda":
    print("gpu:", torch.cuda.get_device_name(0))
    print("torch cuda:", torch.version.cuda)

# --- Audio parameters ---
SR: int = 16000    # sample rate requested from librosa.load
N_MELS: int = 80   # number of mel bands
N_FFT: int = 1024  # STFT window size
HOP: int = 256     # STFT hop length
T: int = 128       # fixed number of time frames (pad/crop target)

# --- Training parameters ---
BATCH_SIZE: int = 128
EPOCHS: int = 15
LR: float = 5e-4

# --- Conditional VAE / anti-collapse ---
LATENT_DIM: int = 64
BETA_MAX: float = 1.0     # KL weight reached at the end of the warm-up
WARMUP_STEPS: int = 2000  # length of the beta warm-up, in optimizer steps
FREE_BITS: float = 0.25   # forwarded to vae_loss as free_bits

print("Config OK")


device: cuda
gpu: NVIDIA GeForce RTX 3060 Ti
torch cuda: 12.1
Config OK


In [3]:
from pathlib import Path

# --- Split roots (adjust if your data lives elsewhere) ---
# Every split uses the same layout: <root>/audio/*.wav + <root>/examples.json
VALID_ROOT = Path("../data/nsynth-valid.jsonwav/nsynth-valid")
TRAIN_ROOT = Path("../data/nsynth-train.jsonwav/nsynth-train")
TEST_ROOT = Path("../data/nsynth-test.jsonwav/nsynth-test")

VALID_AUDIO_DIR, VALID_JSON_PATH = VALID_ROOT / "audio", VALID_ROOT / "examples.json"
TEST_AUDIO_DIR, TEST_JSON_PATH = TEST_ROOT / "audio", TEST_ROOT / "examples.json"
TRAIN_AUDIO_DIR, TRAIN_JSON_PATH = TRAIN_ROOT / "audio", TRAIN_ROOT / "examples.json"

# Report existence of every path, in the order valid -> train -> test.
for _split_tag, _split_root, _split_audio, _split_json in (
    ("VALID", VALID_ROOT, VALID_AUDIO_DIR, VALID_JSON_PATH),
    ("TRAIN", TRAIN_ROOT, TRAIN_AUDIO_DIR, TRAIN_JSON_PATH),
    ("TEST", TEST_ROOT, TEST_AUDIO_DIR, TEST_JSON_PATH),
):
    print(f"{_split_tag}_ROOT :", _split_root)
    print(f"{_split_tag}_AUDIO_DIR exists:", _split_audio.exists(), "|", _split_audio)
    print(f"{_split_tag}_JSON_PATH exists:", _split_json.exists(), "|", _split_json)
print("Paths OK")

VALID_ROOT : ../data/nsynth-valid.jsonwav/nsynth-valid
VALID_AUDIO_DIR exists: True | ../data/nsynth-valid.jsonwav/nsynth-valid/audio
VALID_JSON_PATH exists: True | ../data/nsynth-valid.jsonwav/nsynth-valid/examples.json
TRAIN_ROOT : ../data/nsynth-train.jsonwav/nsynth-train
TRAIN_AUDIO_DIR exists: True | ../data/nsynth-train.jsonwav/nsynth-train/audio
TRAIN_JSON_PATH exists: True | ../data/nsynth-train.jsonwav/nsynth-train/examples.json
TEST_ROOT : ../data/nsynth-test.jsonwav/nsynth-test
TEST_AUDIO_DIR exists: True | ../data/nsynth-test.jsonwav/nsynth-test/audio
TEST_JSON_PATH exists: True | ../data/nsynth-test.jsonwav/nsynth-test/examples.json
Paths OK


In [4]:
import json
from pathlib import Path
from collections import Counter
import numpy as np

# Split to inspect: "train" | "valid" | "test"
SPLIT = "train"

ROOTS = {"train": TRAIN_ROOT, "valid": VALID_ROOT, "test": TEST_ROOT}

root = Path(ROOTS[SPLIT])
JSON_PATH, AUDIO_DIR = root / "examples.json", root / "audio"

print(f"SPLIT: {SPLIT}")
print("ROOT    :", root)
print("JSON    :", JSON_PATH.exists(), "|", JSON_PATH)
print("AUDIO   :", AUDIO_DIR.exists(), "|", AUDIO_DIR)

with JSON_PATH.open("r") as fh:
    examples = json.load(fh)

keys = list(examples)
print("Number of examples:", len(keys))

# -------------------------
# Sanity check: the first few entries
# -------------------------
print("\n--- Sample entries ---")
for k in keys[:3]:
    meta = examples[k]
    print(
        k,
        "| wav:", (AUDIO_DIR / f"{k}.wav").exists(),
        "| pitch:", meta["pitch"],
        "| family:", meta["instrument_family"],
        "| source:", meta.get("instrument_source"),
    )

# -------------------------
# Quick label statistics (useful for the paper and for debugging)
# -------------------------
pitches = [entry["pitch"] for entry in examples.values()]
families = [entry["instrument_family"] for entry in examples.values()]

print("\n--- Pitch stats ---")
print("min/max:", min(pitches), max(pitches))
print("mean/std:", float(np.mean(pitches)), float(np.std(pitches)))

print("\n--- Instrument family stats ---")
fam_counts = Counter(families)
print("num families:", len(fam_counts))
print("top 5:", fam_counts.most_common(5))

# -------------------------
# WAV existence check (first keys only, to stay fast)
# -------------------------
sample_n = min(200, len(keys))
sample_keys = keys[:sample_n]
missing = [name for name in sample_keys if not (AUDIO_DIR / f"{name}.wav").exists()]
print(f"\n--- WAV existence check (first {sample_n}) ---")
print("missing:", len(missing))
if missing:
    print("examples:", missing[:5])


SPLIT: train
ROOT    : ../data/nsynth-train.jsonwav/nsynth-train
JSON    : True | ../data/nsynth-train.jsonwav/nsynth-train/examples.json
AUDIO   : True | ../data/nsynth-train.jsonwav/nsynth-train/audio
Number of examples: 289205

--- Sample entries ---
guitar_acoustic_001-082-050 | wav: True | pitch: 82 | family: 3 | source: 0
bass_synthetic_120-108-050 | wav: True | pitch: 108 | family: 0 | source: 2
organ_electronic_120-050-127 | wav: True | pitch: 50 | family: 6 | source: 1

--- Pitch stats ---
min/max: 9 120
mean/std: 62.311657820577096 23.082482171499365

--- Instrument family stats ---
num families: 11
top 5: [(0, 65474), (4, 51821), (6, 34477), (5, 34201), (3, 32690)]

--- WAV existence check (first 200) ---
missing: 0


In [5]:
from torch.utils.data import Dataset
import torch
from pathlib import Path

# Root of the precomputed log-mel cache (one subdirectory per split).
CACHE_ROOT = Path("../data/nsynth_mel_cache")
print("CACHE_ROOT exists:", CACHE_ROOT.exists(), "|", CACHE_ROOT)
# The cache is built by a later cell, so on a fresh checkout the root may not
# exist yet: report an empty listing instead of raising FileNotFoundError.
# Sorted because iterdir() order is filesystem-dependent.
cache_subdirs = (
    sorted(p.name for p in CACHE_ROOT.iterdir() if p.is_dir())
    if CACHE_ROOT.is_dir()
    else []
)
print("subdirs:", cache_subdirs)

class NsynthMelCacheDataset(Dataset):
    """Map-style dataset that serves precomputed log-mel tensors from a `.pt` cache.

    Item ``idx`` is ``(x, pitch, family, key)``:
      - ``x``: tensor stored at ``cache_dir / f"{key}.pt"``, loaded with
        ``weights_only=True`` (the cache writer in this notebook stores
        (1, N_MELS, T) tensors normalised to [-1, 1]);
      - ``pitch`` / ``family``: 0-dim ``torch.long`` tensors built from
        ``examples[key]["pitch"]`` and ``examples[key]["instrument_family"]``;
      - ``key``: the example id string.

    Raises:
        FileNotFoundError: if the cache file for the requested key is missing.
    """

    def __init__(self, keys, examples, cache_dir: Path):
        self.keys = list(keys)  # materialise so len()/indexing work on any iterable
        self.examples = examples
        self.cache_dir = Path(cache_dir)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        key = self.keys[idx]
        cached_file = self.cache_dir / f"{key}.pt"
        if not cached_file.exists():
            raise FileNotFoundError(f"Missing cache file: {cached_file}")

        x = torch.load(cached_file, weights_only=True)

        # Labels as long tensors so the default collate stacks them directly.
        meta = self.examples[key]
        pitch, family = (
            torch.tensor(int(meta[field]), dtype=torch.long)
            for field in ("pitch", "instrument_family")
        )
        return x, pitch, family, key

# -------- per-split cache directories (sanity check) --------
TRAIN_CACHE, VALID_CACHE, TEST_CACHE = (CACHE_ROOT / split for split in ("train", "valid", "test"))

for _cache_label, _cache_dir in (
    ("TRAIN_CACHE exists:", TRAIN_CACHE),
    ("VALID_CACHE exists:", VALID_CACHE),
    ("TEST_CACHE  exists:", TEST_CACHE),
):
    print(_cache_label, _cache_dir.exists(), "|", _cache_dir)



CACHE_ROOT exists: True | ../data/nsynth_mel_cache
subdirs: ['test', 'train', 'valid']
TRAIN_CACHE exists: True | ../data/nsynth_mel_cache/train
VALID_CACHE exists: True | ../data/nsynth_mel_cache/valid
TEST_CACHE  exists: True | ../data/nsynth_mel_cache/test


In [6]:
import numpy as np
import torch
import torch.nn.functional as F
import librosa
from pathlib import Path

# These parameters must already be defined by the config cell:
# SR = 16000
# N_MELS = 80
# N_FFT = 1024
# HOP = 256
# T = 128

def wav_to_logmel_tensor(wav_path: Path) -> torch.Tensor:
    """Load ``wav_path`` and return a fixed-size, normalised log-mel tensor.

    Pipeline: load mono at SR -> power mel spectrogram (N_FFT / HOP / N_MELS)
    -> dB relative to the clip's own maximum, clipped to [-80, 0] -> linearly
    rescaled to [-1, 1] -> cropped or right-padded to exactly T frames.

    Args:
        wav_path: path to the audio file.

    Returns:
        float32 tensor of shape (1, N_MELS, T).
    """
    audio, sr = librosa.load(wav_path, sr=SR, mono=True)

    mel = librosa.feature.melspectrogram(
        y=audio,
        sr=sr,
        n_fft=N_FFT,
        hop_length=HOP,
        n_mels=N_MELS,
        power=2.0,   # power spectrogram (librosa.melspectrogram's default)
    )

    # dB relative to this clip's maximum, limited to an 80 dB dynamic range.
    log_mel = librosa.power_to_db(mel, ref=np.max)
    log_mel = np.clip(log_mel, -80.0, 0.0)

    # normalize dB [-80, 0] -> [0, 1] -> [-1, 1]
    x_norm = (log_mel + 80.0) / 80.0
    x_norm = 2.0 * x_norm - 1.0

    x = torch.tensor(x_norm, dtype=torch.float32).unsqueeze(0)  # (1, N_MELS, frames)

    # Pad/crop the time axis to exactly T frames. Padding uses -1.0, the
    # normalised value of the -80 dB floor (silence); F.pad's default of 0.0
    # would instead append a constant -40 dB band.
    time_frames = x.shape[-1]
    if time_frames >= T:
        return x[:, :, :T]
    return F.pad(x, (0, T - time_frames), value=-1.0)


In [7]:
import json
from pathlib import Path
from tqdm import tqdm
import torch

# Split whose log-mel cache should be (re)built: "train" | "valid" | "test"
SPLIT = "test"

ROOTS = {
    "train": TRAIN_ROOT,
    "valid": VALID_ROOT,
    "test":  TEST_ROOT,
}

root = Path(ROOTS[SPLIT])
audio_dir = root / "audio"
json_path = root / "examples.json"

cache_root = Path("../data/nsynth_mel_cache")
cache_dir = cache_root / SPLIT
cache_dir.mkdir(parents=True, exist_ok=True)

with open(json_path, "r") as f:
    examples = json.load(f)

keys = list(examples.keys())

# Optional slicing so a long build can be done in chunks / resumed.
MAX_ITEMS = None  # None = everything from START_AT to the end
START_AT = 0
PROGRESS_FILE = cache_dir / "_progress.txt"

end_at = len(keys) if MAX_ITEMS is None else min(len(keys), START_AT + MAX_ITEMS)
keys_slice = keys[START_AT:end_at]

skipped = written = errors = 0
first_error_printed = False

for k in tqdm(keys_slice):
    out_path = cache_dir / f"{k}.pt"
    # An existing file is trusted as complete, so files must never be left
    # half-written: each tensor is saved to a temp file and then atomically
    # renamed into place (an interrupted save only leaves the temp file).
    if out_path.exists():
        skipped += 1
        continue

    wav_path = audio_dir / f"{k}.wav"
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        x = wav_to_logmel_tensor(wav_path)
        torch.save(x, tmp_path)
        tmp_path.replace(out_path)  # atomic rename within the same directory
        written += 1
    except Exception as e:
        # Best-effort: count failures, keep going, report the first in detail.
        errors += 1
        tmp_path.unlink(missing_ok=True)
        if not first_error_printed:
            print("FIRST ERROR at key:", k)
            print("wav_path:", wav_path)
            print("exception:", repr(e))
            first_error_printed = True

PROGRESS_FILE.write_text(str(end_at))
print("\nDone.")
print("written:", written, "| skipped:", skipped, "| errors:", errors)
print("Total processed:", written + skipped + errors)

100%|██████████| 4096/4096 [00:00<00:00, 194127.20it/s]


Done.
written: 0 | skipped: 4096 | errors: 0
Total processed: 4096





In [8]:
from torch.utils.data import DataLoader
import json
from pathlib import Path

# --- Split roots (TRAIN_ROOT / VALID_ROOT / TEST_ROOT come from the paths cell) ---
TRAIN_ROOT, VALID_ROOT, TEST_ROOT = Path(TRAIN_ROOT), Path(VALID_ROOT), Path(TEST_ROOT)

# --- Per-split cache dirs (CACHE_ROOT comes from the dataset cell) ---
TRAIN_CACHE, VALID_CACHE, TEST_CACHE = (CACHE_ROOT / split for split in ("train", "valid", "test"))

# --- examples.json metadata for the splits used in training ---
with open(TRAIN_ROOT / "examples.json", "r") as fh:
    train_examples = json.load(fh)
with open(VALID_ROOT / "examples.json", "r") as fh:
    val_examples = json.load(fh)

train_keys, val_keys = list(train_examples), list(val_examples)

print("train:", len(train_keys), "| cache:", TRAIN_CACHE)
print("val  :", len(val_keys),   "| cache:", VALID_CACHE)

# --- Datasets ---
train_ds = NsynthMelCacheDataset(train_keys, train_examples, TRAIN_CACHE)
val_ds = NsynthMelCacheDataset(val_keys, val_examples, VALID_CACHE)

# --- DataLoaders (shared settings; pinned memory only helps host->GPU copies) ---
pin = device.type == "cuda"
_loader_settings = dict(
    batch_size=BATCH_SIZE,
    num_workers=8,
    pin_memory=pin,
    persistent_workers=True,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_settings)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_settings)

# --- Sanity check on one shuffled training batch ---
x, pitch, family, k = next(iter(train_loader))
print("\n[Train batch sanity check]")
print("batch x:", tuple(x.shape), x.dtype, "| min/max:", float(x.min()), float(x.max()))
for _label_name, _label_t in (("batch pitch:", pitch), ("batch family:", family)):
    print(_label_name, _label_t.shape, _label_t.dtype, "| min/max:", int(_label_t.min()), int(_label_t.max()))
print("first key:", k[0])


train: 289205 | cache: ../data/nsynth_mel_cache/train
val  : 12678 | cache: ../data/nsynth_mel_cache/valid

[Train batch sanity check]
batch x: (128, 1, 80, 128) torch.float32 | min/max: -1.0 1.0
batch pitch: torch.Size([128]) torch.int64 | min/max: 12 108
batch family: torch.Size([128]) torch.int64 | min/max: 0 10
first key: vocal_synthetic_001-048-100


In [9]:
import torch

# --- Model: ConditionalVAE conditioned on pitch, plus its Adam optimizer ---
cvae = ConditionalVAE(
    latent_dim=LATENT_DIM,
    pitch_vocab=128,  # presumably one entry per MIDI pitch 0..127 — confirm in scripts.models
    cond_dim=16,      # size of the pitch embedding
).to(device)

optimizer = torch.optim.Adam(params=cvae.parameters(), lr=LR)

def beta_schedule(global_step: int) -> float:
    """KL weight to use at optimizer step ``global_step``.

    Thin wrapper around ``linear_beta_schedule`` that binds the notebook's
    WARMUP_STEPS and BETA_MAX, so callers only supply the step.
    """
    schedule_kwargs = {"warmup_steps": WARMUP_STEPS, "beta_max": BETA_MAX}
    return linear_beta_schedule(global_step=global_step, **schedule_kwargs)

print("ConditionalVAE initialized")
n_model_params = sum(param.numel() for param in cvae.parameters())
print("Total parameters:", n_model_params)


ConditionalVAE initialized
Total parameters: 2206897


In [10]:
import torch

# --- One training batch, moved to the device ---
x, pitch, family, k = next(iter(train_loader))
# family is unused by the model for now but moved along for parity.
x, pitch, family = (t.to(device, non_blocking=True) for t in (x, pitch, family))

# --- Forward pass ---
x_hat, mu, logvar, z = cvae(x, pitch)

print("Shapes:")
for _shape_label, _shape_t in (
    ("x     :", x),
    ("x_hat :", x_hat),
    ("mu    :", mu),
    ("logvar:", logvar),
    ("z     :", z),
):
    print(_shape_label, tuple(_shape_t.shape))

# --- Loss with the step-0 beta of the warm-up schedule ---
beta = beta_schedule(global_step=0)
total, recon, kl_raw, kl_fb = vae_loss(x_hat, x, mu, logvar, beta=beta, free_bits=FREE_BITS)

print("\nLosses (untrained sanity check):")
print("beta :", beta)
for _loss_label, _loss_val in (
    ("total:", total),
    ("recon:", recon),
    ("kl_raw:", kl_raw),
    ("kl_fb :", kl_fb),
):
    print(_loss_label, float(_loss_val))


Shapes:
x     : (128, 1, 80, 128)
x_hat : (128, 1, 80, 128)
mu    : (128, 64)
logvar: (128, 64)
z     : (128, 64)

Losses (untrained sanity check):
beta : 0.0
total: 0.4757314622402191
recon: 0.4757314622402191
kl_raw: 0.060875315219163895
kl_fb : 0.25


In [11]:
import time
import numpy as np
import torch

def train_one_epoch(model, loader, optimizer, global_step: int, log_every: int = 200):
    """Run one training epoch with the annealed KL weight.

    For each batch: look up beta for the current step, run the forward pass,
    compute ``vae_loss`` and take one optimizer step.

    Args:
        model: conditional VAE called as ``model(x, pitch)`` and returning
            ``(x_hat, mu, logvar, z)``.
        loader: iterable of ``(x, pitch, family, key)`` batches; only ``x``
            and ``pitch`` are used.
        optimizer: optimizer over ``model``'s parameters.
        global_step: step counter at the start of the epoch (drives beta).
        log_every: print batch losses whenever ``global_step`` is a multiple
            of this value.

    Returns:
        ``(stats, global_step)``: per-batch Python floats under the keys
        "total", "recon", "kl_raw", "kl_fb", and the step counter advanced by
        one per batch.
    """
    model.train()
    loss_names = ("total", "recon", "kl_raw", "kl_fb")
    history = {name: [] for name in loss_names}

    for x, pitch, _family, _key in loader:
        x = x.to(device, non_blocking=True)
        pitch = pitch.to(device, non_blocking=True)
        beta = beta_schedule(global_step)

        x_hat, mu, logvar, _z = model(x, pitch)
        total, recon, kl_raw, kl_fb = vae_loss(
            x_hat, x, mu, logvar,
            beta=beta,
            free_bits=FREE_BITS,
        )

        optimizer.zero_grad(set_to_none=True)
        total.backward()
        optimizer.step()

        batch_values = dict(zip(loss_names, (t.item() for t in (total, recon, kl_raw, kl_fb))))
        for name, value in batch_values.items():
            history[name].append(value)

        if global_step % log_every == 0:
            print(
                f"[step {global_step}] beta={beta:.3f} "
                f"total={batch_values['total']:.3f} recon={batch_values['recon']:.3f} "
                f"kl_raw={batch_values['kl_raw']:.3f} kl_fb={batch_values['kl_fb']:.3f}"
            )

        global_step += 1

    return history, global_step


@torch.no_grad()
def eval_one_epoch(model, loader, beta: float = 1.0):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: conditional VAE called as ``model(x, pitch)`` and returning
            ``(x_hat, mu, logvar, z)``; it is left in eval mode.
        loader: iterable of ``(x, pitch, family, key)`` batches.
        beta: fixed KL weight used for the reported ``total``. Defaults to 1.0,
            the value this loop always used. Training anneals beta up to
            BETA_MAX, so pass ``beta=BETA_MAX`` when that is not 1.0 to keep
            validation totals comparable with post-warm-up training totals.

    Returns:
        dict of per-batch Python floats under "total", "recon", "kl_raw", "kl_fb".
    """
    model.eval()
    stats = {"total": [], "recon": [], "kl_raw": [], "kl_fb": []}

    for x, pitch, family, k in loader:
        x = x.to(device, non_blocking=True)
        pitch = pitch.to(device, non_blocking=True)

        x_hat, mu, logvar, z = model(x, pitch)

        total, recon, kl_raw, kl_fb = vae_loss(
            x_hat, x, mu, logvar,
            beta=beta,
            free_bits=FREE_BITS,
        )

        stats["total"].append(total.item())
        stats["recon"].append(recon.item())
        stats["kl_raw"].append(kl_raw.item())
        stats["kl_fb"].append(kl_fb.item())

    return stats


def summarize_stats(stats: dict) -> dict:
    """Collapse each per-batch loss history into its mean, as a Python float.

    Keys are preserved in their original order.
    """
    summary = {}
    for name, values in stats.items():
        summary[name] = float(np.mean(values))
    return summary


In [35]:
def _format_epoch_means(tag: str, means: dict) -> str:
    """One indented summary line of mean losses for an epoch."""
    return (
        f"  {tag}: total={means['total']:.3f} "
        f"recon={means['recon']:.3f} "
        f"kl_raw={means['kl_raw']:.3f} "
        f"kl_fb={means['kl_fb']:.3f}"
    )


global_step = 0  # fresh run: the beta warm-up starts at step 0

for epoch in range(1, EPOCHS + 1):
    epoch_start = time.time()

    train_stats, global_step = train_one_epoch(
        cvae, train_loader, optimizer, global_step, log_every=200
    )
    val_stats = eval_one_epoch(cvae, val_loader)

    train_mean, val_mean = summarize_stats(train_stats), summarize_stats(val_stats)

    print(f"\nEpoch {epoch}/{EPOCHS} ({time.time() - epoch_start:.1f}s)")
    print(_format_epoch_means("Train", train_mean))
    print(_format_epoch_means("Val  ", val_mean))


[step 0] beta=0.000 total=0.446 recon=0.446 kl_raw=0.057 kl_fb=0.250
[step 200] beta=0.100 total=0.264 recon=0.233 kl_raw=0.271 kl_fb=0.307
[step 400] beta=0.200 total=0.263 recon=0.211 kl_raw=0.231 kl_fb=0.258
[step 600] beta=0.300 total=0.291 recon=0.215 kl_raw=0.221 kl_fb=0.254
[step 800] beta=0.400 total=0.290 recon=0.187 kl_raw=0.241 kl_fb=0.256
[step 1000] beta=0.500 total=0.308 recon=0.183 kl_raw=0.234 kl_fb=0.251
[step 1200] beta=0.600 total=0.320 recon=0.169 kl_raw=0.238 kl_fb=0.251
[step 1400] beta=0.700 total=0.361 recon=0.185 kl_raw=0.232 kl_fb=0.251
[step 1600] beta=0.800 total=0.406 recon=0.205 kl_raw=0.239 kl_fb=0.251
[step 1800] beta=0.900 total=0.380 recon=0.154 kl_raw=0.233 kl_fb=0.250
[step 2000] beta=1.000 total=0.417 recon=0.167 kl_raw=0.230 kl_fb=0.250
[step 2200] beta=1.000 total=0.410 recon=0.160 kl_raw=0.234 kl_fb=0.250

Epoch 1/15 (85.6s)
  Train: total=0.330 recon=0.188 kl_raw=0.276 kl_fb=0.298
  Val  : total=0.424 recon=0.172 kl_raw=0.238 kl_fb=0.252
[step 2

In [13]:
from pathlib import Path
import torch
import json
from datetime import datetime

CKPT_DIR = Path("../notebooks/ckpts")  # or ../outputs/ckpts
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Record the step actually reached by training. global_step must NOT be reset
# here: the checkpoint-loading cell restores it to resume the beta warm-up, so
# saving 0 would restart the warm-up on resume. Falls back to 0 only when no
# training has run in this session.
saved_step = int(globals().get("global_step", 0))
saved_epoch = globals().get("epoch")  # None if the training loop never ran

run_name = f"cvae_pitch_lat{LATENT_DIM}_beta{BETA_MAX}_fb{FREE_BITS}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
ckpt_path = CKPT_DIR / f"{run_name}.pt"
meta_path = CKPT_DIR / f"{run_name}.json"

ckpt = {
    "model_state": cvae.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "global_step": saved_step,
    "epoch": saved_epoch,
    "config": {
        "LATENT_DIM": LATENT_DIM,
        "BETA_MAX": BETA_MAX,
        "FREE_BITS": FREE_BITS,
        "WARMUP_STEPS": WARMUP_STEPS,
        "LR": LR,
        "cond_dim": 16,
        "pitch_vocab": 128,
        "N_MELS": N_MELS,
        "N_FFT": N_FFT,
        "HOP": HOP,
        "T": T,
        "SR": SR,
    },
}

torch.save(ckpt, ckpt_path)

# Optional: the config as JSON too (readable without torch).
with open(meta_path, "w") as f:
    json.dump(ckpt["config"], f, indent=2)

print("Saved:", ckpt_path)
print("Config:", meta_path)


Saved: ../notebooks/ckpts/cvae_pitch_lat64_beta1.0_fb0.25_20260212_190037.pt
Config: ../notebooks/ckpts/cvae_pitch_lat64_beta1.0_fb0.25_20260212_190037.json


In [14]:
import torch
from pathlib import Path

# Checkpoint to restore (point this at the run you want).
CKPT_PATH = Path("../notebooks/ckpts/cvae_pitch_lat32_beta2.0_fb0.5_20260209_184708.pt")

# weights_only=True: checkpoints written by the save cell contain only state
# dicts, an int step, an int/None epoch and a config dict of plain numbers,
# all of which PyTorch's restricted unpickler supports. This avoids executing
# arbitrary pickled code and silences torch.load's FutureWarning.
ckpt = torch.load(CKPT_PATH, map_location=device, weights_only=True)

# Rebuild the model exactly as it was configured when saved.
ckpt_cfg = ckpt["config"]
cvae = ConditionalVAE(
    latent_dim=ckpt_cfg["LATENT_DIM"],
    pitch_vocab=ckpt_cfg["pitch_vocab"],
    cond_dim=ckpt_cfg["cond_dim"],
).to(device)

optimizer = torch.optim.Adam(cvae.parameters(), lr=ckpt_cfg["LR"])

cvae.load_state_dict(ckpt["model_state"])
optimizer.load_state_dict(ckpt["optimizer_state"])

# NOTE: the notebook globals (LATENT_DIM, BETA_MAX, FREE_BITS, ...) are NOT
# updated from ckpt_cfg; resumed training would use the config cell's values.
global_step = int(ckpt.get("global_step", 0))
start_epoch = (ckpt.get("epoch") or 0) + 1

print("Loaded:", CKPT_PATH)
print("global_step:", global_step, "| start_epoch:", start_epoch)

Loaded: ../notebooks/ckpts/cvae_pitch_lat32_beta2.0_fb0.5_20260209_184708.pt
global_step: 33900 | start_epoch: 16


  ckpt = torch.load(CKPT_PATH, map_location=device)


In [15]:
import numpy as np
import torch

N_SAMPLES = 2000  # number of validation examples to embed (adjust as needed)

cvae.eval()

# Per-batch accumulators; concatenated and truncated after the loop.
mus, logvars, pitches, families, keys_out = [], [], [], [], []

def extract_mu_logvar(model, x, pitch):
    """
    Return the posterior parameters ``(mu, logvar)`` for a batch.

    Compatible with several ConditionalVAE layouts, tried in this order:
    - ``model.encode(x, pitch)``  -> (mu, logvar[, ...])
    - ``model.encoder(x, pitch)`` -> (mu, logvar[, ...])
    - ``model.encoder(x)``        -> (mu, logvar[, ...])  (unconditional encoder)

    Only the first two elements of whatever is returned are used.

    Raises:
        TypeError: when the conditional ``encoder(x, pitch)`` call raises
            TypeError and the unconditional fallback also raises TypeError.
            The conditional call's error is re-raised, so a genuine TypeError
            inside the encoder is not masked by the fallback's signature error.
    """
    # 1) preferred: an explicit encode method
    encode = getattr(model, "encode", None)
    if callable(encode):
        out = encode(x, pitch)
        return out[0], out[1]

    # 2) encoder that accepts the pitch condition, else 3) unconditional encoder
    try:
        out = model.encoder(x, pitch)
    except TypeError as cond_err:
        # Either the encoder takes no condition argument (signature mismatch)
        # or it raised TypeError internally; only the first case should fall back.
        try:
            out = model.encoder(x)
        except TypeError:
            raise cond_err
    return out[0], out[1]

with torch.no_grad():
    n_seen = 0
    for x, pitch, family, k in val_loader:
        mu, logvar = extract_mu_logvar(
            cvae,
            x.to(device, non_blocking=True),
            pitch.to(device, non_blocking=True),
        )  # each (B, latent_dim)

        mus.append(mu.detach().cpu())
        logvars.append(logvar.detach().cpu())
        pitches.append(pitch.detach().cpu())
        families.append(family.detach().cpu())
        keys_out.extend(list(k))

        n_seen += mu.shape[0]
        if n_seen >= N_SAMPLES:
            break

# Concatenate the batches and keep exactly the first N_SAMPLES rows.
mu_all, logvar_all, pitch_all, family_all = (
    torch.cat(chunks, dim=0)[:N_SAMPLES].numpy()
    for chunks in (mus, logvars, pitches, families)
)
keys_all = np.array(keys_out[:N_SAMPLES])

print("mu_all:", mu_all.shape)
print("logvar_all:", logvar_all.shape)
print("pitch_all:", pitch_all.shape, "| min/max:", pitch_all.min(), pitch_all.max())
print("family_all:", family_all.shape, "| unique:", len(np.unique(family_all)))
print("keys_all:", keys_all.shape)


mu_all: (2000, 32)
logvar_all: (2000, 32)
pitch_all: (2000,) | min/max: 9 118
family_all: (2000,) | unique: 10
keys_all: (2000,)


In [16]:
import numpy as np
from sklearn.decomposition import PCA
import plotly.express as px

# --- NSynth instrument_family id -> name ---
FAMILY_NAMES = dict(enumerate([
    "bass", "brass", "flute", "guitar", "keyboard", "mallet",
    "organ", "reed", "string", "synth_lead", "vocal",
]))

family_names = np.vectorize(FAMILY_NAMES.get)(family_all)

# --- 2-D PCA of the posterior means ---
pca = PCA(n_components=2, random_state=SEED)
mu_pca = pca.fit_transform(mu_all)
explained = pca.explained_variance_ratio_

print(
    "Explained variance ratio:",
    explained,
    "| sum:",
    explained.sum(),
)

# Columns for plotly; family names are strings, so the legend is categorical.
data = {
    "pc1": mu_pca[:, 0],
    "pc2": mu_pca[:, 1],
    "pitch": pitch_all.astype(int),
    "instrument_family": family_names,
    "key": keys_all,
}

# Plot 1: colour by family (categorical); plot 2: colour by pitch (continuous).
_pca_plots = (
    ("instrument_family", ["pitch", "key"], "PCA (mu) — color: instrument family"),
    ("pitch", ["instrument_family", "key"], "PCA (mu) — color: pitch"),
)
fig1, fig2 = (
    px.scatter(data, x="pc1", y="pc2", color=color_col, hover_data=hover_cols, title=plot_title)
    for color_col, hover_cols, plot_title in _pca_plots
)
fig1.show()
fig2.show()


Explained variance ratio: [0.9634293  0.03452188] | sum: 0.99795115


In [17]:
import numpy as np
import umap.umap_ as umap
import plotly.express as px

# NSynth instrument_family id -> name.
NSYNTH_FAMILY_MAP = {
    fid: name
    for fid, name in enumerate([
        "bass", "brass", "flute", "guitar", "keyboard", "mallet",
        "organ", "reed", "string", "synth_lead", "vocal",
    ])
}

# "<id> – <name>" string labels so plotly treats the family as categorical.
family_labels = np.array(
    [f"{int(fid)} – {NSYNTH_FAMILY_MAP.get(int(fid), 'unknown')}" for fid in family_all]
)

# --- 2-D UMAP of the posterior means ---
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=30,
    min_dist=0.1,
    metric="euclidean",
    init="spectral",
    learning_rate=1.0,   # explicit numeric value (previously "auto")
    random_state=SEED,   # deterministic, but UMAP then runs single-threaded
)
mu_umap = reducer.fit_transform(mu_all)

data_umap = {
    "u1": mu_umap[:, 0],
    "u2": mu_umap[:, 1],
    "pitch": pitch_all.astype(int),
    "family_label": family_labels,  # categorical -> one legend entry per family
    "key": keys_all,
}

_umap_plots = (
    ("family_label", ["pitch", "key"], "UMAP (mu) — color: instrument family"),
    ("pitch", ["family_label", "key"], "UMAP (mu) — color: pitch"),
)
fig1, fig2 = (
    px.scatter(data_umap, x="u1", y="u2", color=color_col, hover_data=hover_cols, title=plot_title)
    for color_col, hover_cols, plot_title in _umap_plots
)
fig1.show()
fig2.show()



IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [18]:
import numpy as np
import torch
from plotly.subplots import make_subplots
import plotly.graph_objects as go

def denormalize_db(x_norm: np.ndarray) -> np.ndarray:
    """Map normalised values back to dB: [-1, 1] -> [-80, 0].

    Inverse of the cache normalisation: -1 -> -80 dB, +1 -> 0 dB, linear in
    between. Works elementwise on arrays and on scalars.
    """
    return (x_norm + 1.0) * 40.0 - 80.0

# --- Reconstruct one example from the first validation batch ---
cvae.eval()
x, pitch, family, k = next(iter(val_loader))
x, pitch = x.to(device, non_blocking=True), pitch.to(device, non_blocking=True)

with torch.no_grad():
    x_hat, mu, logvar, z = cvae(x, pitch)

i = 0  # index of the example within the batch
x_i, xhat_i = (t[i, 0].detach().cpu().numpy() for t in (x, x_hat))  # each (80, T)

x_i_db, xhat_i_db = denormalize_db(x_i), denormalize_db(xhat_i)
diff_db = x_i_db - xhat_i_db

key_i = k[i]
pitch_i, family_i = int(pitch[i].detach().cpu()), int(family[i].detach().cpu())

print("Example key:", key_i)
print("pitch:", pitch_i, "| family:", family_i)
print("x     min/max:", float(x_i_db.min()), float(x_i_db.max()))
print("x_hat min/max:", float(xhat_i_db.min()), float(xhat_i_db.max()))

# --- Plotly figure: original | reconstruction | difference ---
vmin, vmax = -80, 0
diff_lim = 20  # dB range of the difference colour scale (adjust as needed)
_db_hover = "mel=%{y}<br>t=%{x}<br>dB=%{z:.2f}<extra></extra>"

fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=[
        "Original log-mel (dB)",
        "Reconstruction (dB)",
        "Difference (orig - recon) (dB)"
    ],
    horizontal_spacing=0.05
)

_panels = [
    go.Heatmap(
        z=x_i_db, zmin=vmin, zmax=vmax, colorscale="Viridis",
        colorbar=dict(title="dB"), hovertemplate=_db_hover,
    ),
    go.Heatmap(
        z=xhat_i_db, zmin=vmin, zmax=vmax, colorscale="Viridis",
        showscale=False, hovertemplate=_db_hover,
    ),
    go.Heatmap(
        z=diff_db, zmin=-diff_lim, zmax=diff_lim, colorscale="RdBu",
        colorbar=dict(title="Δ dB"),
        hovertemplate="mel=%{y}<br>t=%{x}<br>ΔdB=%{z:.2f}<extra></extra>",
    ),
]
for _col, _panel in enumerate(_panels, start=1):
    fig.add_trace(_panel, row=1, col=_col)

fig.update_layout(
    title=f"Key: {key_i} | pitch={pitch_i} | family={family_i}",
    width=1200,
    height=420,
)

# Time frames on x, mel bins on y, for every panel.
for _col in (1, 2, 3):
    fig.update_xaxes(title_text="time frames", row=1, col=_col)
    fig.update_yaxes(title_text="mel bins", row=1, col=_col)

fig.show()


Example key: keyboard_acoustic_004-060-025
pitch: 60 | family: 4
x     min/max: -80.0 0.0
x_hat min/max: -80.52855682373047 -0.34371185302734375


In [19]:
import numpy as np
from sklearn.metrics import silhouette_score, silhouette_samples
import pandas as pd

# ----------------------------------
# Inputs: posterior means and instrument-family labels
# ----------------------------------
X = mu_all            # (N, latent_dim)
labels = family_all   # (N,)

print("X shape:", X.shape)
print("Num families:", len(np.unique(labels)))

# ----------------------------------
# Per-sample silhouette: the O(N^2) pairwise-distance step, done once
# ----------------------------------
sil_samples = silhouette_samples(X, labels, metric="euclidean")

# ----------------------------------
# Global silhouette score: by definition the mean of the per-sample values,
# so reuse them instead of recomputing all pairwise distances.
# ----------------------------------
sil_global = float(np.mean(sil_samples))
print(f"Global silhouette score: {sil_global:.4f}")

# ----------------------------------
# Aggregate per family
# ----------------------------------
df = pd.DataFrame({
    "family": labels,
    "silhouette": sil_samples,
})

sil_by_family = (
    df
    .groupby("family")
    .agg(
        mean_silhouette=("silhouette", "mean"),
        std_silhouette=("silhouette", "std"),
        n_samples=("silhouette", "count"),
    )
    .sort_values("mean_silhouette", ascending=False)
)

sil_by_family


X shape: (2000, 32)
Num families: 10
Global silhouette score: -0.3960


Unnamed: 0_level_0,mean_silhouette,std_silhouette,n_samples
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5,0.529885,0.121267,104
10,0.198969,0.135269,58
2,-0.089783,0.247468,71
1,-0.270457,0.1018,130
6,-0.345868,0.257432,257
7,-0.423224,0.212638,106
3,-0.462207,0.285176,325
4,-0.516194,0.24086,372
0,-0.553224,0.126866,447
8,-0.721152,0.170664,130


In [20]:
import plotly.express as px

sil_plot_df = sil_by_family.reset_index()
fig = px.bar(
    sil_plot_df,
    x="family", y="mean_silhouette", error_y="std_silhouette",
    title="Silhouette score per instrument family (latent space μ)",
    labels={"family": "Instrument family", "mean_silhouette": "Mean silhouette score"},
)
fig.show()


In [21]:
import numpy as np
import pandas as pd
from sklearn.metrics import silhouette_samples

def silhouette_by_family(mu, families):
    """
    Per-family summary of Euclidean silhouette coefficients.

    Args:
        mu: (N, D) embedding matrix, e.g. posterior means.
        families: (N,) label per row.

    Returns:
        DataFrame indexed by ``family`` with columns ``mean_silhouette``,
        ``std_silhouette`` and ``n_samples``, sorted by mean silhouette
        (best-separated family first).
    """
    per_sample = silhouette_samples(mu, families, metric="euclidean")
    named_aggs = {
        "mean_silhouette": ("silhouette", "mean"),
        "std_silhouette": ("silhouette", "std"),
        "n_samples": ("silhouette", "count"),
    }
    return (
        pd.DataFrame({"family": families, "silhouette": per_sample})
        .groupby("family")
        .agg(**named_aggs)
        .sort_values("mean_silhouette", ascending=False)
    )
def run_silhouette_for_pitch_window(
    mu_all,
    pitch_all,
    family_all,
    pitch_center,
    tolerance=1,
    min_samples_per_family=10,
):
    """
    Per-family silhouette scores restricted to a pitch window.

    Keeps samples with ``|pitch - pitch_center| <= tolerance``, drops families
    with fewer than ``min_samples_per_family`` samples inside that window, and
    scores the remaining samples with ``silhouette_by_family``.

    Parameters
    ----------
    mu_all : array-like, shape (n_samples, n_features)
    pitch_all, family_all : array-like, shape (n_samples,)
    pitch_center : center of the pitch window
    tolerance : half-width of the window (inclusive)
    min_samples_per_family : minimum in-window count for a family to be kept

    Returns
    -------
    pandas.DataFrame
        Same layout as ``silhouette_by_family``. It is EMPTY (same columns,
        index named ``family``) when fewer than two families survive, or when
        there are not more samples than families. A silhouette score is
        undefined in both cases, and ``silhouette_samples`` would raise.
    """
    mu_all = np.asarray(mu_all)
    pitch_all = np.asarray(pitch_all)
    family_all = np.asarray(family_all)

    in_window = np.abs(pitch_all - pitch_center) <= tolerance
    mu_sel = mu_all[in_window]
    family_sel = family_all[in_window]

    print(f"Pitch window: {pitch_center} ± {tolerance}")
    print("Total samples:", len(mu_sel))

    # Drop families with too few examples in this window (single pass via counts)
    fams, counts = np.unique(family_sel, return_counts=True)
    valid_families = [f for f, c in zip(fams, counts) if c >= min_samples_per_family]

    keep = np.isin(family_sel, valid_families)
    mu_sel = mu_sel[keep]
    family_sel = family_sel[keep]

    print("Families kept:", valid_families)
    print("Samples after filter:", len(mu_sel))

    if len(valid_families) < 2 or len(mu_sel) <= len(valid_families):
        print("Not enough families/samples in this window; silhouette is undefined.")
        return pd.DataFrame(
            columns=["mean_silhouette", "std_silhouette", "n_samples"],
            index=pd.Index([], name="family"),
        )

    return silhouette_by_family(mu_sel, family_sel)


In [22]:
# Same analysis at three pitch windows (MIDI 60 / 36 / 84), done in a loop
# instead of copy-pasted calls. The iteration order is kept so the printed
# log matches the order used before.
PITCH_WINDOW_TOLERANCE = 1

results_by_pitch = {
    center: run_silhouette_for_pitch_window(
        mu_all, pitch_all, family_all,
        pitch_center=center,
        tolerance=PITCH_WINDOW_TOLERANCE,
    )
    for center in (60, 36, 84)
}

# Named handles used by later cells
results_pitch_60 = results_by_pitch[60]
results_pitch_36 = results_by_pitch[36]
results_pitch_84 = results_by_pitch[84]

results_pitch_60


Pitch window: 60 ± 1
Total samples: 88
Families kept: [np.int64(0), np.int64(3), np.int64(4), np.int64(6)]
Samples after filter: 59
Pitch window: 36 ± 1
Total samples: 87
Families kept: [np.int64(0), np.int64(3), np.int64(4), np.int64(6), np.int64(8)]
Samples after filter: 73
Pitch window: 84 ± 1
Total samples: 72
Families kept: [np.int64(3), np.int64(4)]
Samples after filter: 36


Unnamed: 0_level_0,mean_silhouette,std_silhouette,n_samples
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
3,0.444127,0.08653,13
0,0.007102,0.370149,17
6,-0.140917,0.499867,15
4,-0.233343,0.201049,14


In [23]:
# Global per-family silhouette: all pitches pooled (no pitch control)
global_silhouette = silhouette_by_family(mu_all, family_all)

print("Global silhouette per family:")
global_silhouette


Global silhouette per family:


Unnamed: 0_level_0,mean_silhouette,std_silhouette,n_samples
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
5,0.529885,0.121267,104
10,0.198969,0.135269,58
2,-0.089783,0.247468,71
1,-0.270457,0.1018,130
6,-0.345868,0.257432,257
7,-0.423224,0.212638,106
3,-0.462207,0.285176,325
4,-0.516194,0.24086,372
0,-0.553224,0.126866,447
8,-0.721152,0.170664,130


In [24]:
# Mean silhouette per family side by side: global vs. each pitch window.
# NaN means the family was filtered out of that window.
comparison = pd.DataFrame(
    {
        column: summary["mean_silhouette"]
        for column, summary in (
            ("global", global_silhouette),
            ("pitch_36", results_pitch_36),
            ("pitch_60", results_pitch_60),
            ("pitch_84", results_pitch_84),
        )
    }
)

comparison


Unnamed: 0_level_0,global,pitch_36,pitch_60,pitch_84
family,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
0,-0.553224,0.469967,0.007102,
1,-0.270457,,,
2,-0.089783,,,
3,-0.462207,-0.24523,0.444127,0.210038
4,-0.516194,-0.186535,-0.233343,0.165442
5,0.529885,,,
6,-0.345868,-0.108674,-0.140917,
7,-0.423224,,,
8,-0.721152,-0.502282,,
10,0.198969,,,


In [25]:
import pandas as pd
import plotly.graph_objects as go

# --- Data: taken from the computed `comparison` table instead of hard-coded
# literals. The previous literals were stale: they did not match the global /
# pitch-60 silhouette tables computed above, so the chart contradicted the
# analysis. Rows = families that survived the pitch-60 window filter.
df = (
    comparison[["global", "pitch_60"]]
    .dropna(subset=["pitch_60"])
    .rename_axis("family")
    .reset_index()
)

# --- Plot ---
fig = go.Figure()

fig.add_trace(
    go.Bar(
        x=df["family"],
        y=df["global"],
        name="Global",
        marker_color="rgba(99, 110, 250, 0.7)",
    )
)

fig.add_trace(
    go.Bar(
        x=df["family"],
        y=df["pitch_60"],
        name="Pitch 60 ± 1",
        marker_color="rgba(239, 85, 59, 0.7)",
    )
)

fig.update_layout(
    title="Silhouette score per instrument family<br><sub>Global vs Pitch 60 ± 1 (latent space μ)</sub>",
    xaxis_title="Instrument family",
    yaxis_title="Mean silhouette score",
    barmode="group",
    width=900,
    height=450,
)

fig.show()


In [26]:
import numpy as np
import umap
import plotly.express as px
import pandas as pd

# -----------------------------
# Restrict to pitch 60 ± 1
# -----------------------------
PITCH_CENTER = 60
DELTA = 1

mask = np.abs(pitch_all - PITCH_CENTER) <= DELTA

mu_p60, pitch_p60, family_p60, keys_p60 = (
    arr[mask] for arr in (mu_all, pitch_all, family_all, keys_all)
)

print("Pitch window:", PITCH_CENTER, "±", DELTA)
print("Samples:", mu_p60.shape[0])
print("Families:", np.unique(family_p60))

# -----------------------------
# 2-D UMAP of the windowed mu (seeded for reproducibility)
# -----------------------------
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=20,
    min_dist=0.1,
    random_state=SEED,
)

umap_p60 = reducer.fit_transform(mu_p60)

# -----------------------------
# Family labels (NSynth family id -> name)
# -----------------------------
NSYNTH_FAMILY_MAP = dict(enumerate([
    "bass", "brass", "flute", "guitar", "keyboard", "mallet",
    "organ", "reed", "string", "synth_lead", "vocal",
]))

# "id – name" label per point; unknown ids fall back to 'unknown'
family_labels = [f"{fid} – {NSYNTH_FAMILY_MAP.get(fid, 'unknown')}" for fid in family_p60]

# -----------------------------
# DataFrame for Plotly
# -----------------------------
df_umap = pd.DataFrame(
    dict(
        u1=umap_p60[:, 0],
        u2=umap_p60[:, 1],
        family=family_p60.astype(int),
        family_label=family_labels,
        pitch=pitch_p60.astype(int),
        key=keys_p60,
    )
)

# -----------------------------
# Plot
# -----------------------------
fig = px.scatter(
    df_umap, x="u1", y="u2",
    color="family_label",
    hover_data=["key"],
    title="UMAP (μ) — Pitch 60 ± 1",
)
fig.update_layout(width=900, height=500)
fig.show()


Pitch window: 60 ± 1
Samples: 88
Families: [ 0  1  2  3  4  5  6  7  8 10]



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.



In [44]:
import numpy as np
import torch
import librosa
from IPython.display import Audio, display

# -----------------------------
# Helpers
# -----------------------------
def logmel_db_to_norm(logmel_db: np.ndarray) -> np.ndarray:
    """Clip a log-mel (dB) to [-80, 0] and map it linearly onto [-1, 1] (-80 -> -1, 0 -> 1)."""
    clipped = np.clip(logmel_db, -80.0, 0.0)
    return 2.0 * ((clipped + 80.0) / 80.0) - 1.0

def norm_to_logmel_db(x_norm: np.ndarray) -> np.ndarray:
    """Inverse of the [-80, 0] dB -> [-1, 1] normalization: map [-1, 1] back to dB (no clipping)."""
    return ((x_norm + 1.0) / 2.0) * 80.0 - 80.0

def logmel_db_to_audio(
    logmel_db: np.ndarray,
    sr: int = SR,
    n_fft: int = N_FFT,
    hop_length: int = HOP,
    n_mels: int = N_MELS,
    n_iter: int = 256,
) -> np.ndarray:
    """
    Approximate waveform from a log-mel spectrogram via Griffin-Lim.

    Parameters
    ----------
    logmel_db : (n_mels, T) log-mel in dB, expected roughly in [-80, 0], as
        produced by ``librosa.power_to_db`` on a power mel spectrogram.
    sr, n_fft, hop_length : analysis parameters; must match those used to
        compute the mel spectrogram.
    n_mels : kept for API compatibility; librosa infers the mel count from
        ``logmel_db``.
    n_iter : number of Griffin-Lim iterations.

    Returns
    -------
    np.ndarray : 1-D waveform.
    """
    # dB -> power mel (the mel spectrogram was computed with librosa's default power=2)
    mel_power = librosa.db_to_power(logmel_db, ref=1.0)
    # power=2.0 declares the input a *power* mel, so mel_to_stft returns a
    # linear *magnitude* STFT, which is what griffinlim expects. power=1.0
    # returned power values that griffinlim then misread as magnitudes.
    stft_mag = librosa.feature.inverse.mel_to_stft(
        M=mel_power, sr=sr, n_fft=n_fft, power=2.0
    )
    wav = librosa.griffinlim(
        stft_mag, n_iter=n_iter, hop_length=hop_length, win_length=n_fft
    )
    return wav

# -----------------------------
# Pick one example (key + pitch)
# -----------------------------
ROOTS = {
    "train": TRAIN_ROOT,
    "valid": VALID_ROOT,
    "test":  TEST_ROOT,
}

# NOTE(review): SPLIT is not defined in this cell; it presumably comes from an
# earlier cell. Confirm it names the split keys_all / pitch_all were built from.
root = Path(ROOTS[SPLIT])
# VALID_ROOT is intentionally NOT reassigned here. The former
# `VALID_ROOT = root / "valid"` pointed at a non-existent directory and, on
# re-run, fed that broken path back into ROOTS["valid"].

JSON_PATH = root / "examples.json"
AUDIO_DIR = root / "audio"

# Key and pitch come from the same row of the embedding arrays, so they always match.
EXAMPLE_IDX = 4
key_i = keys_all[EXAMPLE_IDX]
pitch_i = int(pitch_all[EXAMPLE_IDX])

wav_path = VALID_AUDIO_DIR / f"{key_i}.wav"
print(wav_path.exists(), "|", wav_path)
print("key:", key_i, "| pitch:", pitch_i)
print("wav_path:", wav_path)
if not wav_path.exists():
    raise FileNotFoundError(f"WAV not found for key {key_i!r}: {wav_path}")

# -----------------------------
# Load REAL wav and compute FULL log-mel (no crop)
# -----------------------------
wav_real, _ = librosa.load(wav_path, sr=SR, mono=True)

mel = librosa.feature.melspectrogram(
    y=wav_real, sr=SR, n_fft=N_FFT, hop_length=HOP, n_mels=N_MELS
)
logmel_db_full = librosa.power_to_db(mel, ref=np.max)   # (N_MELS, T_full)
x_norm_full = logmel_db_to_norm(logmel_db_full)         # (N_MELS, T_full) in [-1, 1]

T_full = x_norm_full.shape[1]
print("Full mel frames:", T_full)

# -----------------------------
# Sliding window + overlap-add
# -----------------------------
cvae.eval()

win = T                 # model input width in frames
hop_win = T // 2        # 50% overlap. Try T // 4 for smoother stitching.

# Strictly positive Hann taper, shape (1, win). np.hanning(win) is 0 at both
# ends. Frame 0 is covered only by the first sample of the first window, so it
# got zero total weight and its reconstruction was forced to 0 (= -40 dB).
# Dropping the end-points of a length win+2 Hann keeps the taper without zeros.
w = np.hanning(win + 2)[1:-1].astype(np.float32)[None, :]

# Accumulators on the REAL length
y_acc = np.zeros((N_MELS, T_full), dtype=np.float32)
w_acc = np.zeros((1, T_full), dtype=np.float32)

pitch_tensor = torch.tensor([pitch_i], device=device, dtype=torch.long)

with torch.no_grad():
    for start in range(0, T_full, hop_win):
        end = start + win

        # real chunk (shorter than win on the last windows)
        chunk = x_norm_full[:, start:end]  # (N_MELS, <=win)
        valid_len = chunk.shape[1]
        if valid_len == 0:
            break

        # pad to win with -1.0 (normalized -80 dB, i.e. silence) before feeding the model
        if valid_len < win:
            chunk = np.pad(
                chunk,
                ((0, 0), (0, win - valid_len)),
                mode="constant",
                constant_values=-1.0
            )

        # to tensor: (1, 1, N_MELS, win)
        x_in = torch.tensor(chunk, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)

        # forward
        x_hat, mu, logvar, z = cvae(x_in, pitch_tensor)

        # back to numpy: (N_MELS, win)
        chunk_hat = x_hat[0, 0].detach().cpu().numpy()

        # safety: force the decoder output to exactly win frames
        if chunk_hat.shape[1] != win:
            if chunk_hat.shape[1] > win:
                chunk_hat = chunk_hat[:, :win]
            else:
                chunk_hat = np.pad(
                    chunk_hat,
                    ((0, 0), (0, win - chunk_hat.shape[1])),
                    mode="constant",
                    constant_values=-1.0
                )

        # overlap-add ONLY valid_len frames (prevents broadcast errors at the end)
        y_acc[:, start:start+valid_len] += chunk_hat[:, :valid_len] * w[:, :valid_len]
        w_acc[:, start:start+valid_len] += w[:, :valid_len]

# Normalize overlap-add (every frame now has positive weight; the floor is only a guard)
w_acc = np.maximum(w_acc, 1e-6)
x_hat_norm_full = y_acc / w_acc  # (N_MELS, T_full)

# -----------------------------
# Invert ORIGINAL full mel vs RECONSTRUCTED full mel
# -----------------------------
x_db_full = norm_to_logmel_db(x_norm_full)
xhat_db_full = norm_to_logmel_db(x_hat_norm_full)

wav_inv_from_orig = logmel_db_to_audio(x_db_full, n_iter=64)
wav_inv_from_recon = logmel_db_to_audio(xhat_db_full, n_iter=64)

print("wav lengths (samples): real / inv(orig mel) / inv(recon mel):",
      len(wav_real), len(wav_inv_from_orig), len(wav_inv_from_recon))

# -----------------------------
# Listen
# -----------------------------
print("\n▶ REAL NSynth WAV")
display(Audio(wav_real, rate=SR))

print("▶ Inverted from ORIGINAL full log-mel (approx)  [inversion bottleneck]")
display(Audio(wav_inv_from_orig, rate=SR))

print("▶ Inverted from RECONSTRUCTED full log-mel (approx)  [model + inversion]")
display(Audio(wav_inv_from_recon, rate=SR))


True | ../data/nsynth-valid.jsonwav/nsynth-valid/audio/bass_synthetic_034-030-050.wav
key: bass_synthetic_034-030-050 | pitch: 30
wav_path: ../data/nsynth-valid.jsonwav/nsynth-valid/audio/bass_synthetic_034-030-050.wav
Full mel frames: 251
wav lengths (samples): real / inv(orig mel) / inv(recon mel): 64000 64000 64000

▶ REAL NSynth WAV


▶ Inverted from ORIGINAL full log-mel (approx)  [inversion bottleneck]


▶ Inverted from RECONSTRUCTED full log-mel (approx)  [model + inversion]


In [45]:
import librosa
from IPython.display import Audio, display

def mel_db_to_audio_librosa_mel_to_audio(logmel_db, n_iter=128):
    """
    Invert a log-mel (dB) spectrogram to a waveform in one call via
    ``librosa.feature.inverse.mel_to_audio`` (Griffin-Lim internally).

    Uses the notebook-level SR / N_FFT / HOP settings; ``n_iter`` sets the
    number of Griffin-Lim iterations.
    """
    return librosa.feature.inverse.mel_to_audio(
        M=librosa.db_to_power(logmel_db, ref=1.0),
        sr=SR,
        n_fft=N_FFT,
        hop_length=HOP,
        win_length=N_FFT,
        n_iter=n_iter,
        power=2.0,  # db_to_power yields a power mel spectrogram
    )

# Compare mel_to_audio inversions at two Griffin-Lim iteration counts,
# for the original and the reconstructed full-length log-mel.
inversions_by_n_iter = {}
for n_iter in (128, 64):
    print(f"▶ mel_to_audio inversion | n_iter={n_iter}")
    wav_orig = mel_db_to_audio_librosa_mel_to_audio(x_db_full, n_iter=n_iter)
    wav_recon = mel_db_to_audio_librosa_mel_to_audio(xhat_db_full, n_iter=n_iter)
    inversions_by_n_iter[n_iter] = (wav_orig, wav_recon)

    print("ORIG mel -> audio")
    display(Audio(wav_orig, rate=SR))

    print("RECON mel -> audio")
    display(Audio(wav_recon, rate=SR))

# Named handles kept for later cells
wav_inv_orig_128, wav_inv_recon_128 = inversions_by_n_iter[128]
wav_inv_orig_64, wav_inv_recon_64 = inversions_by_n_iter[64]

▶ mel_to_audio inversion | n_iter=128
ORIG mel -> audio


RECON mel -> audio


▶ mel_to_audio inversion | n_iter=64
ORIG mel -> audio


RECON mel -> audio


## Interpolation

In [46]:
import numpy as np
import torch
import librosa

def norm_to_logmel_db(x_norm: np.ndarray) -> np.ndarray:
    """Map normalized log-mel values from [-1, 1] back to dB in [-80, 0] (linear, unclipped)."""
    fraction = (x_norm + 1.0) / 2.0
    return fraction * 80.0 - 80.0

def logmel_db_to_audio(
    logmel_db: np.ndarray,
    sr: int,
    n_fft: int,
    hop_length: int,
    n_mels: int,
    n_iter: int = 64,
) -> np.ndarray:
    """
    Approximate waveform from a log-mel spectrogram (dB) via Griffin-Lim.

    Parameters
    ----------
    logmel_db : (n_mels, T) log-mel in dB, as produced by ``power_to_db`` on a
        power mel spectrogram.
    sr, n_fft, hop_length : analysis parameters of the original mel.
    n_mels : kept for API compatibility; librosa infers it from ``logmel_db``.
    n_iter : number of Griffin-Lim iterations.

    Returns
    -------
    np.ndarray : 1-D waveform.
    """
    mel_power = librosa.db_to_power(logmel_db, ref=1.0)
    # The input is a *power* mel, so power=2.0 makes mel_to_stft return a
    # linear magnitude STFT for griffinlim. power=1.0 returned power values,
    # which griffinlim misinterpreted as magnitudes.
    stft_mag = librosa.feature.inverse.mel_to_stft(
        M=mel_power, sr=sr, n_fft=n_fft, power=2.0
    )
    wav = librosa.griffinlim(
        stft_mag, n_iter=n_iter, hop_length=hop_length, win_length=n_fft
    )
    return wav

@torch.no_grad()
def decode_windowed_overlap_add(
    cvae,
    x_norm_full: np.ndarray,        # (n_mels, T_full) in [-1, 1]
    pitch_i: int,
    device,
    win: int = 128,
    hop_win: int = 64,
):
    """
    Reconstruct a full-length normalized log-mel by running the model on
    sliding windows of ``win`` frames and stitching them with weighted
    overlap-add.

    Parameters
    ----------
    cvae : model called as ``cvae(x, pitch)`` returning a 4-tuple whose first
        element is the reconstruction of its (1, 1, n_mels, win) input.
    x_norm_full : (n_mels, T_full) normalized log-mel.
    pitch_i : pitch id, passed to the model as a 1-element long tensor.
    device : torch device for the forward passes.
    win, hop_win : window length and hop, in frames.

    Returns
    -------
    np.ndarray of shape (n_mels, T_full), float32: the reconstructed
    normalized log-mel.
    """
    cvae.eval()
    n_mels, T_full = x_norm_full.shape

    # Strictly positive Hann taper (1, win). np.hanning(win) is 0 at its ends,
    # which gave frame 0 zero total weight and forced it to 0 (= -40 dB).
    w = np.hanning(win + 2)[1:-1].astype(np.float32)[None, :]
    y_acc = np.zeros((n_mels, T_full), dtype=np.float32)
    w_acc = np.zeros((1, T_full), dtype=np.float32)

    pitch_tensor = torch.tensor([pitch_i], device=device, dtype=torch.long)

    for start in range(0, T_full, hop_win):
        chunk = x_norm_full[:, start:start + win]
        valid_len = chunk.shape[1]  # >= 1 because start < T_full

        if valid_len < win:
            # pad with -1.0 (normalized -80 dB, i.e. silence) so the model always sees win frames
            chunk = np.pad(
                chunk,
                ((0, 0), (0, win - valid_len)),
                mode="constant",
                constant_values=-1.0
            )

        x_in = torch.tensor(chunk, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)
        x_hat, _, _, _ = cvae(x_in, pitch_tensor)
        chunk_hat = x_hat[0, 0].detach().cpu().numpy()  # (n_mels, win)

        # only the real (unpadded) frames contribute
        y_acc[:, start:start + valid_len] += chunk_hat[:, :valid_len] * w[:, :valid_len]
        w_acc[:, start:start + valid_len] += w[:, :valid_len]

    w_acc = np.maximum(w_acc, 1e-6)
    return y_acc / w_acc


In [47]:
import numpy as np

# NSynth family id -> name (ids are the list positions 0..10)
NSYNTH_FAMILY_MAP = dict(enumerate([
    "bass", "brass", "flute", "guitar", "keyboard", "mallet",
    "organ", "reed", "string", "synth_lead", "vocal",
]))

def pick_pair_same_family_pitch(keys_all, pitch_all, family_all, target_family: int, target_pitch: int, tol: int = 0):
    """
    Pick the first two examples of ``target_family`` whose pitch lies within
    ``target_pitch ± tol``.

    Returns
    -------
    (keyA, keyB, pitch_used) where ``pitch_used`` is the pitch of the first
    match, or ``None`` when fewer than two examples match.
    """
    pitches = np.asarray(pitch_all).astype(int)
    families = np.asarray(family_all).astype(int)
    keys = np.asarray(keys_all)

    matches = np.flatnonzero((families == target_family) & (np.abs(pitches - target_pitch) <= tol))
    if matches.size < 2:
        return None

    first, second = matches[0], matches[1]
    return keys[first], keys[second], int(pitches[first])

# Initial choice: the vocal family (use target_family = 5 for mallet)
target_family = 10  # vocal
target_pitch = 60   # C4
tol = 1             # accept ±1 semitone (adjust as needed)

pair = pick_pair_same_family_pitch(
    keys_all, pitch_all, family_all, target_family, target_pitch, tol=tol
)
print("pair:", pair, "| family:", target_family, NSYNTH_FAMILY_MAP.get(target_family))

# If pair is None, widen tol or choose another pitch/family


pair: (np.str_('vocal_synthetic_003-059-127'), np.str_('vocal_synthetic_003-061-100'), 59) | family: 10 vocal


In [55]:
import numpy as np

FAM_MALLET = 5
FAM_VOCAL  = 10
PITCH_TARGET = 60   # optional pitch filter
PITCH_TOL = 2       # optional ± tolerance around PITCH_TARGET

def pick_idx(fam_id, pitch_target=None, pitch_tol=2):
    """
    Return the first index whose family is ``fam_id`` (reads the notebook
    globals ``family_all`` / ``pitch_all``). If ``pitch_target`` is given,
    only indices with ``|pitch - pitch_target| <= pitch_tol`` qualify.

    Raises ValueError when no index qualifies.
    """
    candidates = np.flatnonzero(family_all == fam_id)
    if pitch_target is not None:
        candidates = np.array(
            [i for i in candidates if abs(int(pitch_all[i]) - pitch_target) <= pitch_tol],
            dtype=int,
        )
    if len(candidates) == 0:
        raise ValueError(f"Não achei amostra pra family={fam_id} com esse filtro de pitch.")
    return int(candidates[0])  # or np.random.choice(candidates)

iA, iB = (
    pick_idx(fam, pitch_target=PITCH_TARGET, pitch_tol=PITCH_TOL)
    for fam in (FAM_MALLET, FAM_VOCAL)
)

keyA, pitchA, famA = keys_all[iA], int(pitch_all[iA]), int(family_all[iA])
keyB, pitchB, famB = keys_all[iB], int(pitch_all[iB]), int(family_all[iB])

print("A:", keyA, "pitch", pitchA, "fam", famA)
print("B:", keyB, "pitch", pitchB, "fam", famB)

# Condition on A's pitch; both were filtered to lie near PITCH_TARGET
pair = (keyA, keyB, pitchA)


A: mallet_acoustic_062-059-025 pitch 59 fam 5
B: vocal_synthetic_003-059-127 pitch 59 fam 10


In [56]:
from pathlib import Path
import librosa

AUDIO_DIR = VALID_AUDIO_DIR  # or TEST_AUDIO_DIR

def wav_to_full_norm_logmel(wav_path: Path, sr: int, n_fft: int, hop: int, n_mels: int):
    """
    Load a wav and compute its full-length (uncropped) normalized log-mel.

    The log-mel is in dB relative to its own maximum (``ref=np.max``), clipped
    to [-80, 0] and mapped linearly onto [-1, 1].

    Returns (waveform, x_norm_full).
    """
    wav, _ = librosa.load(wav_path, sr=sr, mono=True)
    mel_power = librosa.feature.melspectrogram(y=wav, sr=sr, n_fft=n_fft, hop_length=hop, n_mels=n_mels)
    db = np.clip(librosa.power_to_db(mel_power, ref=np.max), -80.0, 0.0)
    return wav, 2.0 * ((db + 80.0) / 80.0) - 1.0

keyA, keyB, pitch_used = pair
pathA, pathB = (Path(AUDIO_DIR) / f"{k}.wav" for k in (keyA, keyB))
print("A:", pathA.exists(), pathA)
print("B:", pathB.exists(), pathB)
print("pitch_used:", pitch_used)

wavA, xA_norm_full = wav_to_full_norm_logmel(pathA, SR, N_FFT, HOP, N_MELS)
wavB, xB_norm_full = wav_to_full_norm_logmel(pathB, SR, N_FFT, HOP, N_MELS)

print("A frames:", xA_norm_full.shape[1], "| B frames:", xB_norm_full.shape[1])


A: True ../data/nsynth-valid.jsonwav/nsynth-valid/audio/mallet_acoustic_062-059-025.wav
B: True ../data/nsynth-valid.jsonwav/nsynth-valid/audio/vocal_synthetic_003-059-127.wav
pitch_used: 59
A frames: 251 | B frames: 251


In [57]:
import torch

@torch.no_grad()
def encode_mu_windowed(cvae, x_norm_full: np.ndarray, device, win: int = 128):
    """
    Encode one representative mu from the FIRST ``win`` frames of a clip.

    Inputs shorter than ``win`` frames are right-padded with -1.0. A mean over
    all windows is available in ``encode_mu_mean_over_windows``. Per the
    original note, the encoder does not take pitch.

    Returns the first row of ``cvae.encoder``'s mu output as a CPU tensor.
    """
    cvae.eval()
    first_window = x_norm_full[:, :win]
    missing = win - first_window.shape[1]
    if missing > 0:
        first_window = np.pad(first_window, ((0, 0), (0, missing)), mode="constant", constant_values=-1.0)

    batch = torch.tensor(first_window, dtype=torch.float32, device=device)[None, None]  # (1, 1, n_mels, win)
    mu, _logvar = cvae.encoder(batch)
    return mu[0].detach().cpu()

# First-window mu of each clip (A first, then B)
muA, muB = (encode_mu_windowed(cvae, x, device, win=T) for x in (xA_norm_full, xB_norm_full))

print("muA:", muA.shape, "muB:", muB.shape)


muA: torch.Size([32]) muB: torch.Size([32])


In [58]:
from IPython.display import Audio, display

@torch.no_grad()
def decode_from_mu(cvae, mu_vec: torch.Tensor, pitch_i: int, device):
    """
    Deterministic decode: use z = ``mu_vec`` (no sampling noise), conditioned
    on ``cvae.pitch_cond`` of ``pitch_i``.

    Returns the first item / first channel of the decoder output as a NumPy
    array (per the original comments, (80, 128) with this notebook's config).
    """
    cvae.eval()
    z_batch = mu_vec.to(device).unsqueeze(0)  # (1, latent_dim)
    pitch_batch = torch.tensor([pitch_i], device=device, dtype=torch.long)
    decoded = cvae.decoder(z_batch, cvae.pitch_cond(pitch_batch))
    return decoded[0, 0].detach().cpu().numpy()

def interpolate(a: torch.Tensor, b: torch.Tensor, n: int):
    """Return ([(1 - t) * a + t * b for t in ts], ts) with ts = linspace(0, 1, n)."""
    alphas = np.linspace(0.0, 1.0, n)
    points = []
    for alpha in alphas:
        points.append((1 - alpha) * a + alpha * b)
    return points, alphas

N_STEPS = 7
mus_interp, alphas = interpolate(muA, muB, N_STEPS)

print("Interpolando:", keyA, "->", keyB, "| pitch:", pitch_used)
for t, mu_t in zip(alphas, mus_interp):
    # single-window decode -> dB -> Griffin-Lim audio
    xhat_norm = decode_from_mu(cvae, mu_t, pitch_used, device)
    xhat_db = norm_to_logmel_db(xhat_norm)
    wav_t = logmel_db_to_audio(
        xhat_db, sr=SR, n_fft=N_FFT, hop_length=HOP, n_mels=N_MELS, n_iter=64,
    )

    print(f"alpha={t:.2f}")
    display(Audio(wav_t, rate=SR))


Interpolando: mallet_acoustic_062-059-025 -> vocal_synthetic_003-059-127 | pitch: 59
alpha=0.00


alpha=0.17


alpha=0.33


alpha=0.50


alpha=0.67


alpha=0.83


alpha=1.00


In [59]:
import numpy as np
import torch

@torch.no_grad()
def encode_mu_mean_over_windows(
    cvae,
    x_norm_full: np.ndarray,   # (80, T_full) em [-1,1]
    device,
    win: int = 128,
    hop_win: int = 64,
    max_windows: int | None = None,
):
    """
    Faz sliding window no log-mel full e calcula mu em cada janela.
    Retorna:
      mu_mean: (latent_dim,)
      mu_all:  (n_windows, latent_dim)
      starts:  lista com os inícios das janelas
    """
    cvae.eval()
    N_MELS, T_full = x_norm_full.shape

    mus = []
    starts = []

    n = 0
    for start in range(0, T_full, hop_win):
        end = start + win
        chunk = x_norm_full[:, start:end]
        valid_len = chunk.shape[1]
        if valid_len == 0:
            break

        # pad pra win
        if valid_len < win:
            chunk = np.pad(
                chunk,
                ((0, 0), (0, win - valid_len)),
                mode="constant",
                constant_values=-1.0
            )

        x_in = torch.tensor(chunk, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)  # (1,1,80,win)
        mu, logvar = cvae.encoder(x_in)  # encoder NÃO usa pitch no seu models.py

        mus.append(mu[0].detach().cpu())
        starts.append(start)

        n += 1
        if (max_windows is not None) and (n >= max_windows):
            break

    mu_all = torch.stack(mus, dim=0)                  # (n_windows, latent_dim)
    mu_mean = mu_all.mean(dim=0)                      # (latent_dim,)
    return mu_mean, mu_all, starts

# Exemplo de uso:
# muA_mean, muA_all, startsA = encode_mu_mean_over_windows(cvae, xA_norm_full, device, win=T, hop_win=T//2)
# muB_mean, muB_all, startsB = encode_mu_mean_over_windows(cvae, xB_norm_full, device, win=T, hop_win=T//2)


In [60]:
import numpy as np
import torch
from IPython.display import Audio, display

@torch.no_grad()
def decode_full_from_mu_with_overlap_add(
    cvae,
    x_norm_full_ref: np.ndarray,  # only used for its shape (n_mels, T_full)
    mu_vec: torch.Tensor,         # (latent_dim,)
    pitch_i: int,
    device,
    win: int = 128,
    hop_win: int = 64,
):
    """
    Build a full-length normalized log-mel (n_mels, T_full) from ONE fixed mu.

    The same ``z = mu_vec`` and pitch condition apply to every window, so the
    decoder is run once. Its ``win``-frame output is then tiled across the
    window grid and blended with weighted overlap-add.

    Returns
    -------
    np.ndarray of shape (n_mels, T_full), float32.
    """
    cvae.eval()
    n_mels, T_full = x_norm_full_ref.shape

    # Strictly positive Hann taper (1, win). np.hanning(win) is 0 at its ends,
    # which gave frame 0 zero weight and forced it to 0 (= -40 dB).
    w = np.hanning(win + 2)[1:-1].astype(np.float32)[None, :]
    y_acc = np.zeros((n_mels, T_full), dtype=np.float32)
    w_acc = np.zeros((1, T_full), dtype=np.float32)

    z = mu_vec.to(device).unsqueeze(0)  # (1, latent_dim)
    pitch_t = torch.tensor([pitch_i], device=device, dtype=torch.long)
    cond = cvae.pitch_cond(pitch_t)     # (1, cond_dim)

    # Loop-invariant: identical inputs for every window, so decode once.
    # (Assumes the decoder is deterministic in eval mode.)
    chunk_hat = cvae.decoder(z, cond)[0, 0].detach().cpu().numpy()  # (n_mels, win)

    for start in range(0, T_full, hop_win):
        valid_len = min(win, T_full - start)
        y_acc[:, start:start + valid_len] += chunk_hat[:, :valid_len] * w[:, :valid_len]
        w_acc[:, start:start + valid_len] += w[:, :valid_len]

    w_acc = np.maximum(w_acc, 1e-6)
    return y_acc / w_acc  # (n_mels, T_full)

def interpolate(a: torch.Tensor, b: torch.Tensor, n: int):
    """Linear path from ``a`` to ``b`` in ``n`` evenly spaced steps; returns (points, alphas)."""
    ts = np.linspace(0.0, 1.0, n)
    path = list(map(lambda t: (1 - t) * a + t * b, ts))
    return path, ts

# --- Full example: per-clip mean mu + full-length audio ---
N_STEPS = 7
win = T
hop_win = T // 2

# 1) Mean mu over all windows of each clip (more robust than the first window alone)
muA_mean, _, _ = encode_mu_mean_over_windows(cvae, xA_norm_full, device, win=win, hop_win=hop_win)
muB_mean, _, _ = encode_mu_mean_over_windows(cvae, xB_norm_full, device, win=win, hop_win=hop_win)

# 2) Interpolation path
mus_interp, alphas = interpolate(muA_mean, muB_mean, N_STEPS)

print("FULL-LENGTH interpolation:", keyA, "->", keyB, "| pitch:", pitch_used)
for t, mu_t in zip(alphas, mus_interp):
    # 3) Full-length mel; clip A only provides the target length.
    # *_interp names keep the global `xhat_db_full` (the model reconstruction
    # inverted by the mel_to_audio comparison cell) from being overwritten.
    xhat_norm_interp = decode_full_from_mu_with_overlap_add(
        cvae=cvae,
        x_norm_full_ref=xA_norm_full,
        mu_vec=mu_t,
        pitch_i=pitch_used,
        device=device,
        win=win,
        hop_win=hop_win,
    )

    # 4) mel -> audio (Griffin-Lim)
    xhat_db_interp = norm_to_logmel_db(xhat_norm_interp)
    wav_interp = logmel_db_to_audio(
        xhat_db_interp,
        sr=SR,
        n_fft=N_FFT,
        hop_length=HOP,
        n_mels=N_MELS,
        n_iter=64,
    )

    print(f"alpha={t:.2f}")
    display(Audio(wav_interp, rate=SR))


FULL-LENGTH interpolation: mallet_acoustic_062-059-025 -> vocal_synthetic_003-059-127 | pitch: 59
alpha=0.00


alpha=0.17


alpha=0.33


alpha=0.50


alpha=0.67


alpha=0.83


alpha=1.00


In [61]:
import numpy as np
import torch

@torch.no_grad()
def encode_mu(x, pitch=None):
    """
    Encode a batch with the global ``cvae`` and return ``(mu, logvar)``.

    ``pitch`` is accepted but ignored; per the original note, the encoder in
    models.py does not take pitch.
    """
    mean, log_var = cvae.encoder(x)
    return mean, log_var

@torch.no_grad()
def decode_from_z(z, pitch):
    """
    Decode latent codes ``z`` with the global ``cvae``, conditioned on
    ``cvae.pitch_cond(pitch)``. Returns the decoder output unchanged.
    """
    return cvae.decoder(z, cvae.pitch_cond(pitch))

def lerp(a, b, alpha):
    """Linear interpolation: ``a`` at alpha=0, ``b`` at alpha=1."""
    weight_a = 1 - alpha
    return weight_a * a + alpha * b


In [62]:
# Two rows of the embedding table to interpolate between (edit freely)
iA, iB = 10, 50

keyA, keyB = keys_all[iA], keys_all[iB]
pitchA, pitchB = int(pitch_all[iA]), int(pitch_all[iB])
famA, famB = int(family_all[iA]), int(family_all[iB])

print("A:", iA, keyA, "pitch", pitchA, "family", famA)
print("B:", iB, keyB, "pitch", pitchB, "family", famB)

# Pure timbre interpolation: keep the conditioning pitch fixed (recommended)
pitch_interp = pitchA  # or set manually

# mu stands in for z: the path runs through the latent space
zA, zB = (mu_all[i].astype(np.float32) for i in (iA, iB))

N_STEPS = 9  # 7 or 9 works well
alphas = np.linspace(0, 1, N_STEPS).astype(np.float32)

Z = np.stack([lerp(zA, zB, a) for a in alphas], axis=0)  # (N_STEPS, D)
print("Z:", Z.shape, "pitch_interp:", pitch_interp)


A: 10 organ_electronic_028-053-127 pitch 53 family 6
B: 50 organ_electronic_113-037-025 pitch 37 family 6
Z: (9, 32) pitch_interp: 53


In [63]:
# Same pitch for every step, so only the latent code varies along the path
pitch_t = torch.full((N_STEPS,), pitch_interp, device=device, dtype=torch.long)
Z_t = torch.tensor(Z, device=device, dtype=torch.float32)

with torch.no_grad():
    # decode the latent path
    x_hat = decode_from_z(Z_t, pitch_t)
    # (optional but recommended) re-encode to place the decoded points in the same mu space
    mu_hat, logvar_hat = cvae.encoder(x_hat)

mu_hat_np = mu_hat.detach().cpu().numpy()
print("x_hat:", tuple(x_hat.shape), "| mu_hat:", mu_hat_np.shape)


x_hat: (9, 1, 80, 128) | mu_hat: (9, 32)


In [64]:
import plotly.express as px
import pandas as pd

# PCA projection of the dataset plus the re-encoded interpolation path.
# NOTE(review): `pca` comes from an earlier cell; presumably a 2-component PCA
# fitted on mu_all. Confirm.
mu_pca = pca.transform(mu_all)          # (N, 2)
traj_pca = pca.transform(mu_hat_np)     # (steps, 2)

df_base = pd.DataFrame({
    "pc1": mu_pca[:, 0],
    "pc2": mu_pca[:, 1],
    "family": family_all.astype(int),
    "pitch": pitch_all.astype(int),
    "key": keys_all
})

df_traj = pd.DataFrame({
    "pc1": traj_pca[:, 0],
    "pc2": traj_pca[:, 1],
    "alpha": alphas,
    "step": np.arange(N_STEPS),
})
# "step | a=alpha" label, shown on hover for each trajectory point
df_traj["label"] = df_traj["step"].astype(str) + " | a=" + df_traj["alpha"].round(2).astype(str)

fig = px.scatter(
    df_base, x="pc1", y="pc2",
    color=df_base["family"].astype(str),  # categorical (legend per family)
    hover_data=["pitch", "key"],
    title=f"PCA(mu): trajetória interpolada | A={keyA} -> B={keyB} | pitch_fix={pitch_interp}"
)

# trajectory overlay: line + markers annotated with the step index
fig.add_scatter(
    x=df_traj["pc1"], y=df_traj["pc2"],
    mode="lines+markers+text",
    text=df_traj["step"],
    hovertext=df_traj["label"],
    textposition="top center",
    name="traj",
)

fig.show()


In [65]:
import plotly.express as px
import pandas as pd

import umap

# The global `reducer` was fitted only on the pitch-60 subset (88 samples), so
# pushing all of mu_all through it is out-of-sample extrapolation. Fit a
# dedicated UMAP on the full embedding with the same hyper-parameters, then
# project the trajectory into that embedding.
reducer_all = umap.UMAP(
    n_neighbors=20,
    min_dist=0.1,
    n_components=2,
    random_state=SEED,
)
mu_umap = reducer_all.fit_transform(mu_all)       # (N, 2)
traj_umap = reducer_all.transform(mu_hat_np)      # (steps, 2)

df_base = pd.DataFrame({
    "u1": mu_umap[:, 0],
    "u2": mu_umap[:, 1],
    "family": family_all.astype(int),
    "pitch": pitch_all.astype(int),
    "key": keys_all
})

df_traj = pd.DataFrame({
    "u1": traj_umap[:, 0],
    "u2": traj_umap[:, 1],
    "alpha": alphas,
    "step": np.arange(N_STEPS),
})

fig = px.scatter(
    df_base, x="u1", y="u2",
    color=df_base["family"].astype(str),  # categorical
    hover_data=["pitch", "key"],
    title=f"UMAP(mu): trajetória interpolada | A={keyA} -> B={keyB} | pitch_fix={pitch_interp}"
)

fig.add_scatter(
    x=df_traj["u1"], y=df_traj["u2"],
    mode="lines+markers+text",
    text=df_traj["step"],
    textposition="top center",
    name="traj",
)

fig.show()
