In [163]:
# Ensure project libs are on sys.path before imports.
# Each entry is added only if missing so re-running this cell does not keep
# growing sys.path with duplicates. ROOT can be overridden via CD4MT_ROOT.
import sys, os
ROOT = os.environ.get('CD4MT_ROOT', '/data1/yuchen/cd4mt')
for _path in (f'{ROOT}/src', ROOT, f'{ROOT}/ldm'):
    if _path not in sys.path:
        sys.path.append(_path)
print('[ok] project paths configured')


# Require micromamba env 'cdp10': the active env name is read from
# CONDA_DEFAULT_ENV, then MAMBA_DEFAULT_ENV, else treated as empty.
_env = os.environ.get('CONDA_DEFAULT_ENV') or os.environ.get('MAMBA_DEFAULT_ENV') or ''
print(f"Active env: {_env}")
assert 'cdp10' in _env, "Please 'micromamba activate cdp10' before running this notebook."


[ok] project paths configured
Active env: cdp10


# CD4MT 


This notebook presents a clean, student-style walkthrough for CD4MT.

- Goal: show how to train a small CT model and run inference.
- Dataset: `dataset/slakh_44100` (44100 Hz).
- Checkpoints: under `checkpoints/`.

Structure (section numbers match the headings below):
- Overview (this cell)
1. Environment
2. Config & Data
3. CAE Encoder Test
4. Model
5. Inference (pretrained)
6. Training (short run)
7. Visualization & Metrics
8. Checkpoints & Notes


- Flow: load config/data, encode stems to CAE latents, sanity-check decode, build CT UNet, inspect shapes, visualize and play audio.

Key shapes used below:
- Waveforms: `(B, S, T)` e.g. `(4, 4, 524288)`; mix `(B, T)`.
- CAE latent per stem: `(B, C_lat, L)` e.g. `(4, 64, 127)`.
- Stacked latents: `(B, S, C_lat, L)` e.g. `(4, 4, 64, 127)`.
- 2D map for UNet: `(B, S*C_lat, H, W)` with `H=W=ceil(sqrt(L))` (for `L=127`, `H=W=12`).

Section guide:
- 1. env - device and imports.
- 2. config and load_data - prints config and fetches a batch; logs `(B,S,T)`.
- 3. CAE encoder test - encodes to `(B,64,L)`, stacks to `(B,4,64,L)`; then decode + MSE.
- 4. CD4MT model init - CT UNet with in/out channels `S*C_lat=256`; EMA ckpt loaded in 4.2.
- 4.1 Shape instrumentation - param shapes and an optional `ShapeLogger` callback for short training runs.
- 5. Generation test - use Karras sampler (test.py) or Lightning `ScoreDiffusionModel.sample(...)`.
- 7-9. Visualization and audio players.

## 1. Environment

In [164]:
import torch

# Select the compute device: the first CUDA GPU when one is visible, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Device: {device}")
print(torch.__version__)  # torch build string, e.g. 2.7.1+cu126

import os
import sys
import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
import soundfile as sf
from pathlib import Path
import logging
from datetime import datetime
import pytorch_lightning as pl
import torch.cuda

# Trade float32 matmul precision for speed ('medium'), but only on torch
# builds that expose this setting.
_set_matmul_precision = getattr(torch, 'set_float32_matmul_precision', None)
if _set_matmul_precision is not None:
    _set_matmul_precision('medium')
from pytorch_lightning.callbacks import ModelCheckpoint

# optional: swanlab for experiment tracking
import importlib.util as _ilu
swanlab = None
if _ilu.find_spec('swanlab') is not None:
    import swanlab


from IPython.display import Audio, display, HTML, Markdown
import matplotlib.font_manager as fm
from src.music2latent.music2latent import EncoderDecoder
from ldm.data.multitrack_datamodule import DataModuleFromConfig
from src.cm.script_util import (
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    cm_train_defaults,
    args_to_dict,
    add_dict_to_argparser,
    create_ema_and_scales_fn,
)


Device: cuda:0
2.7.1+cu126


In [165]:
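# Optional one-off: install a local torchvision wheel (prefer %pip over !pip so it targets this kernel).
# NOTE: this wheel is tagged cu118, while the kernel reports torch 2.7.1+cu126; check that the
# CUDA/Python tags match the active env before installing.
# !pip install /data1/yuchen/cd4mt/src/env/torchvision-0.22.0+cu118-cp39-cp39-manylinux_2_28_x86_64.whl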


## 2. Config & Data

Explainer - Config and First Batch (Shapes):

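- Reads the YAML config (`configs/cd4mt_small.yaml`), builds the DataModule, and prints the number of training batches.
- The first batch is pulled at the start of Section 3 to inspect shapes; the sample rate is read from `data.params.preprocessing.audio.sampling_rate`.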
- Expected keys: `['fname', 'fbank_stems', 'waveform_stems', 'waveform', 'fbank']`.
- Typical shapes here:
  - `waveform_stems`: `(B, S, T)`, e.g. `(4, 4, 524288)`.
  - `waveform` (mix): `(B, T)`, e.g. `(4, 524288)`.

In [166]:
# Build the DataModule directly from the YAML config. There is deliberately
# no try/except, so a bad path or config fails loudly here.
import yaml
from ldm.data.multitrack_datamodule import DataModuleFromConfig

CFG_PATH = "configs/cd4mt_small.yaml"
with open(CFG_PATH, 'r') as cfg_file:
    cfg = yaml.safe_load(cfg_file)
print(f"Loaded config: {CFG_PATH}")

data_params = cfg["data"]["params"]
dm = DataModuleFromConfig(**data_params)
dm.prepare_data()
dm.setup(stage="fit")

# len() of the loader is the number of training batches per epoch.
train_loader = dm.train_dataloader()
print(f"train_loader: {len(train_loader)}")


Loaded config: configs/cd4mt_small.yaml
Found 1290 tracks.
sr=44100, min: 10, max: 600
Keeping 1289 of 1290 tracks
Data size: 26309
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:26309 | Epoch Length: 26309
Found 271 tracks.
sr=44100, min: 10, max: 600
Keeping 270 of 271 tracks
Data size: 5422
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:5422 | Epoch Length: 5422
Found 152 tracks.
sr=44100, min: 10, max: 600
Keeping 151 of 152 tracks
Data size: 3249
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:3249 | Epoch Length: 3249
train_loader: 6578


## 3. CAE Encoder Test

Explainer - CAE Encode/Stack/Decode (Shapes):

- For each stem `s` in 4 stems, encode `(B, T)` to `(B, C_lat, L)` with CAE (e.g. `(4, 64, 127)`).
- Stack across stems → `latents_stacked`: `(B, S, C_lat, L)` (e.g. `(4, 4, 64, 127)`).
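- Decode each stem back to a waveform and align it to the original length `T` by center-cropping or symmetric zero-padding. The CAE decoder returns slightly fewer samples, e.g. 521728 vs 524288.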
- Report reconstruction MSE for sanity.

In [167]:
# Pull the first training batch and expose its waveform tensors.
batch = next(iter(train_loader))
assert 'waveform_stems' in batch, 'Expected key waveform_stems in batch'

wav_stems = batch['waveform_stems']  # Tensor (B, S, T)
# Use the loader's mix when present; otherwise rebuild it by summing stems.
if 'waveform' in batch:
    wave_mix = batch['waveform']
else:
    wave_mix = wav_stems.sum(dim=1)
B, S_batch, T = wav_stems.shape
audio_cfg = cfg['data']['params']['preprocessing']['audio']
sample_rate = int(audio_cfg['sampling_rate'])
print(f'Batch shapes: wav_stems={tuple(wav_stems.shape)}, waveform={tuple(wave_mix.shape)}')


Batch shapes: wav_stems=(4, 4, 524288), waveform=(4, 524288)


In [168]:
# Encode each stem separately with the music2latent CAE, then stack the
# per-stem latents into (B, S, C_lat, L).
# The stem count comes from the batch itself rather than a hard-coded 4, so
# indexing never runs past the stems that are actually present.
S = S_batch
stem_names = cfg['model']['params']['stem_names']
latents_list = []
encode_shapes = []
ae = EncoderDecoder(device=device)
print(f"\n stem_num {S} ")

for s in range(S):
    stem_name = stem_names[s] if s < len(stem_names) else f"stem_{s}"
    print(f"\n {stem_name} stem {s}:")
    # one stem for the whole batch, handed to the CAE as a numpy array (B, T)
    stem_audio = wav_stems[:, s].cpu().numpy()
    print(f"stem_audio : {stem_audio.shape}")

    stem_latents = ae.encode(stem_audio)
    # encode may hand back numpy or torch; normalize to a torch tensor
    if isinstance(stem_latents, np.ndarray):
        stem_latents = torch.from_numpy(stem_latents)

    # unify dtype to float32 to avoid AMP mix issues
    stem_latents = stem_latents.to(dtype=torch.float32)
    print(f"stem_latents.shape {stem_latents.shape}, stem_latents.dtype {stem_latents.dtype}, range: [{stem_latents.min():.3f}, {stem_latents.max():.3f}]")
    latents_list.append(stem_latents)
    encode_shapes.append(stem_latents.shape)

# ensure same latent length across stems for stacking (crop to the shortest)
min_L = min(t.shape[-1] for t in latents_list)
if any(t.shape[-1] != min_L for t in latents_list):
    print(f"[Warn] Latent length mismatch across stems, cropping all to L={min_L}")
    latents_list = [t[..., :min_L] for t in latents_list]

latents_stacked = torch.stack(latents_list, dim=1)  # (B, S, C, L)
print(f"latents_stacked.shape: {latents_stacked.shape}")
print(f" Batch={latents_stacked.shape[0]}, Stems={latents_stacked.shape[1]}, Channels={latents_stacked.shape[2]}, Length={latents_stacked.shape[3]}")
latents = latents_stacked.to(device)
print(f" latents on : {device}")

/data1/yuchen/cd4mt/src/music2latent/music2latent

 stem_num 4 

 bass stem 0:
stem_audio : (4, 524288)
stem_latents.shape torch.Size([4, 64, 127]), tem_latents.dtype torch.float32, range: [-4.180, 3.848]

 drums stem 1:
stem_audio : (4, 524288)
stem_latents.shape torch.Size([4, 64, 127]), tem_latents.dtype torch.float32, range: [-4.422, 3.328]

 guitar stem 2:
stem_audio : (4, 524288)
stem_latents.shape torch.Size([4, 64, 127]), tem_latents.dtype torch.float32, range: [-5.012, 3.721]

 piano stem 3:
stem_audio : (4, 524288)
stem_latents.shape torch.Size([4, 64, 127]), tem_latents.dtype torch.float32, range: [-4.805, 4.062]
latents_stacked.shape: torch.Size([4, 4, 64, 127])
 Batch=4, Stems=4, Channels=64, Length=127
 latents on : cuda:0


In [169]:
def _fit_length(x: np.ndarray, target: int) -> np.ndarray:
    """Center-crop or symmetrically zero-pad the LAST axis of `x` to `target` samples.

    Leading axes (e.g. batch) are left untouched, so this works for any rank.
    """
    cur = x.shape[-1]
    if cur > target:
        start = (cur - target) // 2
        return x[..., start:start + target]
    if cur < target:
        deficit = target - cur
        left = deficit // 2
        pad_width = [(0, 0)] * (x.ndim - 1) + [(left, deficit - left)]
        return np.pad(x, pad_width, mode='constant', constant_values=0)
    return x


# Decode each stem's latents back to audio and compare with the input stems.
# The decoder returns slightly fewer samples than T (e.g. 521728 vs 524288),
# so outputs are center-aligned to T before computing the MSE.
# NOTE(review): the generation cell aligns decoded audio at the START
# (crop/pad at the end) — confirm which alignment matches the decoder.
recst_list = []
for s in range(S):
    stem_name = stem_names[s] if s < len(stem_names) else f"stem_{s}"
    print(f"\nDecode {stem_name}")
    stem_latents = latents[:, s].cpu().numpy()  # (B, C, L)
    recst = ae.decode(stem_latents)
    print(f"recst.shape: {recst.shape}")
    print(f"range: [{recst.min():.3f}, {recst.max():.3f}]")

    if isinstance(recst, torch.Tensor):
        recst = recst.cpu().numpy()
    recst_list.append(_fit_length(recst, T))

recst_aud = np.stack(recst_list, axis=1)  # (B, S, T)
recst_tensor = torch.from_numpy(recst_aud).to(device)

print(f"\nrecst shape: {recst_aud.shape}")
print(f"original length: {T}, recst length: {recst_aud.shape[2]}")

if recst_aud.shape[2] == T:
    mse_error = np.mean((wav_stems.cpu().numpy() - recst_aud)**2)
    print(f"MSE: {mse_error:.6f}")
else:
    print("length error")


Decode bass
recst.shape: torch.Size([4, 521728])
range: [-0.192, 0.281]

Decode drums
recst.shape: torch.Size([4, 521728])
range: [-0.612, 0.625]

Decode guitar
recst.shape: torch.Size([4, 521728])
range: [-0.329, 0.316]

Decode piano
recst.shape: torch.Size([4, 521728])
range: [-0.349, 0.352]

recst shape: (4, 4, 524288)
original lenght: 524288, recst length: 524288
MSE: 0.001932


## 4. CD4MT Model Initialization (CT)

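CT UNet Setup (built from `configs/cd4mt_medium.yaml`)
- Compute `in_channels = S * C_lat` (here 4×64=256) and replace the first/last convs to match.
- Load the best EMA checkpoint from `checkpoints/ct_unet_ema_best_val*.pth`, choosing the lowest value parsed from the filename. If none exists, fall back to the most recently modified `ct_unet_ema_last_e*.pth`.
- Print the total parameter count and `in_channels`.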

### 4.1 Model Init and Shape

In [170]:
# Build the CT UNet from the *medium* config. The best EMA checkpoint is
# loaded in a later cell, after the first/last convs are resized to S*C.
# NOTE(review): the DataModule above is built from configs/cd4mt_small.yaml;
# confirm both configs agree on stems and latent channels.
import yaml, torch, re
from pathlib import Path
from src.cm.script_util import create_model

CFG_PATH_LOCAL = 'configs/cd4mt_medium.yaml'
with open(CFG_PATH_LOCAL, 'r') as f:
    cfg_local = yaml.safe_load(f)
model_params = cfg_local['model']['params']
S = int(model_params.get('num_stems', 4))
C = int(model_params.get('cae_z_channels', 64))
in_channels = S * C

# A `cm_model_override:` key that is present but empty loads from YAML as
# None, so fall back to an empty mapping instead of crashing on `.get`.
cmov = model_params.get('cm_model_override') or {}


def getv(key, default):
    """Return the UNet override for `key`, or `default` when not overridden."""
    return cmov.get(key, default)


def _as_csv(value) -> str:
    """Normalize a YAML list/tuple like [1, 2, 4] to the comma form '1,2,4'.

    The defaults passed below are comma-separated strings; presumably
    create_model splits them on ',' (TODO confirm), so a stringified list
    such as "[1, 2, 4]" would not parse.
    """
    if isinstance(value, (list, tuple)):
        return ','.join(str(v) for v in value)
    return str(value)


model = create_model(
    image_size=int(getv('image_size', 32)),
    num_channels=int(getv('num_channels', 192)),
    num_res_blocks=int(getv('num_res_blocks', 2)),
    channel_mult=_as_csv(getv('channel_mult', '1,2,4')),
    learn_sigma=bool(getv('learn_sigma', False)),
    class_cond=bool(getv('class_cond', False)),
    use_checkpoint=bool(getv('use_checkpoint', False)),
    attention_resolutions=_as_csv(getv('attention_resolutions', '16,8,4')),
    num_heads=int(getv('num_heads', 6)),
    num_head_channels=int(getv('num_head_channels', 32)),
    num_heads_upsample=int(getv('num_heads_upsample', -1)),
    use_scale_shift_norm=bool(getv('use_scale_shift_norm', True)),
    dropout=float(getv('dropout', 0.1)),
    resblock_updown=bool(getv('resblock_updown', False)),
    use_fp16=bool(getv('use_fp16', False)),
    use_new_attention_order=bool(getv('use_new_attention_order', False)),
)



In [171]:
# Utilities: concise shape logging for batches and modules
import os, torch, pytorch_lightning as pl
from typing import Any

def shape_of(x: Any):
    """Describe `x` compactly: its shape as a tuple, or its type name if it has no `.shape`."""
    if not hasattr(x, 'shape'):
        return type(x).__name__
    return tuple(x.shape)

class ShapeLogger(pl.Callback):
    """Print batch shapes at the start of every train/validation batch.

    Each line carries the epoch, global step, local rank (printed as ``gpu``)
    and batch index.

    Args:
        prefix: tag for training-batch lines.
        val_prefix: tag for validation-batch lines (default keeps "[ValData]").
    """

    def __init__(self, prefix="[Data]", val_prefix="[ValData]"):
        super().__init__()
        self.prefix = prefix
        self.val_prefix = val_prefix

    @staticmethod
    def _batch_shapes(batch):
        """Summarize a batch.

        A dict maps each key to its shape, keeping only shape-bearing values.
        A list/tuple gives a list of shapes or type names. Anything else gives
        its own shape or type name.
        """
        if isinstance(batch, dict):
            return {k: shape_of(v) for k, v in batch.items() if hasattr(v, "shape")}
        if isinstance(batch, (list, tuple)):
            return [shape_of(v) for v in batch]
        return shape_of(batch)

    def _log(self, tag, trainer, batch, batch_idx):
        """Print one shape line for `batch`, tagged with `tag`."""
        local_rank = getattr(trainer.strategy, "local_rank", 0)
        print(f"{tag} epoch={trainer.current_epoch} step={trainer.global_step} "
              f"gpu={local_rank} batch_idx={batch_idx} shapes={self._batch_shapes(batch)}")

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        self._log(self.prefix, trainer, batch, batch_idx)

    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx=0):
        self._log(self.val_prefix, trainer, batch, batch_idx)
        
# Dump every state_dict entry of the model together with its shape
# (non-tensor entries fall back to their type string).
sd = model.state_dict()
print("[Model Param Shapes] name -> shape")
for name, tensor in sd.items():
    s = str(type(tensor)) if not hasattr(tensor, 'shape') else tuple(tensor.shape)
    print(f"{name}: {s}")



[Model Param Shapes] name -> shape
time_embed.0.weight: (768, 192)
time_embed.0.bias: (768,)
time_embed.2.weight: (768, 768)
time_embed.2.bias: (768,)
input_blocks.0.0.weight: (192, 3, 3, 3)
input_blocks.0.0.bias: (192,)
input_blocks.1.0.in_layers.0.weight: (192,)
input_blocks.1.0.in_layers.0.bias: (192,)
input_blocks.1.0.in_layers.2.weight: (192, 192, 3, 3)
input_blocks.1.0.in_layers.2.bias: (192,)
input_blocks.1.0.emb_layers.1.weight: (384, 768)
input_blocks.1.0.emb_layers.1.bias: (384,)
input_blocks.1.0.out_layers.0.weight: (192,)
input_blocks.1.0.out_layers.0.bias: (192,)
input_blocks.1.0.out_layers.3.weight: (192, 192, 3, 3)
input_blocks.1.0.out_layers.3.bias: (192,)
input_blocks.2.0.in_layers.0.weight: (192,)
input_blocks.2.0.in_layers.0.bias: (192,)
input_blocks.2.0.in_layers.2.weight: (192, 192, 3, 3)
input_blocks.2.0.in_layers.2.bias: (192,)
input_blocks.2.0.emb_layers.1.weight: (384, 768)
input_blocks.2.0.emb_layers.1.bias: (384,)
input_blocks.2.0.out_layers.0.weight: (192,)


### 4.2 Adapt I/O Convs and Load EMA Checkpoint

In [172]:
import torch.nn as nn


def _resized_conv(conv: nn.Conv2d, in_ch: int, out_ch: int) -> nn.Conv2d:
    """New Conv2d with `conv`'s kernel/stride/padding/bias setting but new channel counts."""
    return nn.Conv2d(in_ch, out_ch,
                     kernel_size=conv.kernel_size,
                     stride=conv.stride,
                     padding=conv.padding,
                     bias=(conv.bias is not None))


def _ckpt_val(path: Path) -> float:
    """Score parsed from a `...best_val<number>.pth` filename; +inf when absent."""
    match = re.search(r'best_val([0-9.]+)\.pth$', path.name)
    return float(match.group(1)) if match else float('inf')


# Adjust first/last conv to S*C channels: before this swap the input conv
# weight is (192, 3, 3, 3), i.e. built for 3 input channels.
first_conv = model.input_blocks[0][0]
model.input_blocks[0][0] = _resized_conv(first_conv, in_channels, first_conv.out_channels)
last_conv = model.out[-1]
model.out[-1] = _resized_conv(last_conv, last_conv.in_channels, in_channels)
model = model.to(device); model.eval()

# Pick the best EMA ckpt from ./checkpoints (lowest val in the filename);
# otherwise fall back to the most recently modified "last" checkpoint.
ckdir = Path('checkpoints')
bests = sorted(ckdir.glob('ct_unet_ema_best_val*.pth'))
if bests:
    ckpt = min(bests, key=_ckpt_val)
else:
    lasts = list(ckdir.glob('ct_unet_ema_last_e*.pth'))
    if not lasts:
        raise FileNotFoundError(
            f"No ct_unet_ema_best_val*.pth or ct_unet_ema_last_e*.pth checkpoint in {ckdir.resolve()}")
    ckpt = max(lasts, key=lambda p: p.stat().st_mtime)
print('Using CT EMA ckpt:', ckpt)
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
ck = torch.load(str(ckpt), map_location='cpu')
sd = ck.get('state_dict', ck)  # accept {'state_dict': ...} wrappers and bare state dicts
missing, unexpected = model.load_state_dict(sd, strict=False)
print(f'CT load_state: missing={len(missing)}, unexpected={len(unexpected)}')
if missing or unexpected:
    # strict=False silently skips mismatches; surface a few key names to make them visible
    print(f'  [Warn] first missing: {list(missing)[:5]}, first unexpected: {list(unexpected)[:5]}')
# report params
_total = sum(p.numel() for p in model.parameters())
print(f'CT.UNet params: {_total:,} (~{_total*4/1024/1024:.1f} MB fp32) in_channels={in_channels}')

Using CT EMA ckpt: checkpoints/ct_unet_ema_best_val0.073735.pth


CT load_state: missing=0, unexpected=0
CT.UNet params: 151,484,992 (~577.9 MB fp32) in_channels=256


## 5. Audio Generation Test

Explainer - Sampling Note (How to Generate):

- The raw CT UNet doesn’t implement `sample(...)`.
- Use the CT helpers (see `test.py`): `KarrasDenoiser` + `karras_sample(...)`, then invert 2D latents to `(B,S,C_lat,L)` and CAE-decode.
- Or instantiate Lightning `ScoreDiffusionModel` and call `.sample(...)` which wraps schedule/sampler.

In [173]:
from src.cm.karras_diffusion import KarrasDenoiser, karras_sample
import math, numpy as np, torch


def _fit_stem_wave(audio, T_target):
    """Turn one decoded stem into a float32 numpy waveform with time on the last axis.

    Singleton axes are squeezed. When `T_target` is given, the last (time)
    axis is cropped or zero-padded at the END to exactly `T_target` samples.
    The time axis is assumed to be last, as in the batched decode above where
    the output is (B, T').
    """
    if isinstance(audio, torch.Tensor):
        wav = audio.detach().float().cpu().numpy()
    else:
        wav = np.asarray(audio, dtype=np.float32)
    wav = np.squeeze(wav).astype(np.float32)
    if T_target is not None:
        cur = wav.shape[-1]
        if cur > T_target:
            wav = wav[..., :T_target]
        elif cur < T_target:
            pad_width = [(0, 0)] * (wav.ndim - 1) + [(0, T_target - cur)]
            wav = np.pad(wav, pad_width, mode='constant')
    return wav


with torch.no_grad():
    # Derive stems/channels/latent length, preferring values from earlier cells
    # and falling back to the medium config when they are not defined.
    S_local = int(S) if 'S' in globals() else int(cfg_local['model']['params'].get('num_stems', 4))
    C_local = int(C) if 'C' in globals() else int(cfg_local['model']['params'].get('cae_z_channels', 64))
    if 'latents_stacked' in globals():
        L_local = int(latents_stacked.shape[-1])
    else:
        L_local = int(cfg_local.get('sampling', {}).get('length', 127))
    # side = ceil(sqrt(L)) so that side*side >= L (L=127 -> 12)
    side = math.isqrt(L_local)
    if side * side < L_local:
        side += 1
    in_channels = S_local * C_local
    gen_batch_size = 1
    gen_steps = 10

    sigma_min = float(cfg_local['model']['params'].get('sigma_min', 1e-4))
    sigma_max = float(cfg_local['model']['params'].get('sigma_max', 3.0))
    sigma_data = float(cfg_local['model']['params'].get('diffusion_sigma_data', 0.5))

    print('Generate with Karras sampler')
    print(f'  2D shape: (B={gen_batch_size}, SC={in_channels}, H=W={side}), steps={gen_steps}')

    diffusion = KarrasDenoiser(
        sigma_data=sigma_data,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        weight_schedule='karras',
        loss_norm='l2',
    )

    shape = (gen_batch_size, in_channels, side, side)
    torch.manual_seed(12345)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(12345)

    gen_imgs = karras_sample(
        diffusion,
        model,
        shape=shape,
        steps=gen_steps,
        device=device,
        sigma_min=sigma_min,
        sigma_max=sigma_max,
        model_kwargs={},
    )

    # Invert the 2D map to latents: flatten H*W (>= L by construction of
    # `side`), keep the first L positions, split channels back into (S, C).
    flat = gen_imgs.float().cpu().reshape(gen_batch_size, in_channels, side * side)[..., :L_local]
    gen_latents = flat.reshape(gen_batch_size, S_local, C_local, L_local)

    # Decode the first sample to audio stems. Alignment to T here is at the
    # start (crop/pad at the end).
    # NOTE(review): the reconstruction cell center-aligns; confirm which one matches the decoder.
    gen_wavs = []
    T_target = int(wav_stems.shape[-1]) if 'wav_stems' in globals() else None
    for s in range(S_local):
        stem_lat = gen_latents[0, s].numpy()  # (C_lat, L)
        audio = ae.decode(stem_lat, denoising_steps=1)
        gen_wavs.append(_fit_stem_wave(audio, T_target))
    gen_wavs = np.stack(gen_wavs, axis=0)
    gen_aud = torch.from_numpy(gen_wavs[None, ...])
    print(f'Generated audio tensor: {tuple(gen_aud.shape)}')


Generate with Karras sampler
  2D shape: (B=1, SC=256, H=W=12), steps=10
Generated audio tensor: (1, 4, 524288)
