# CD4MT

## 1. env

In [54]:
# old ver: from flash_attn.flash_attention import FlashAttention

In [55]:
# Auto-select CUDA device with max free memory
import os, torch

def auto_select_cuda(min_free_gb=4.0):
    """Return the CUDA device with the most free memory, or the CPU device.

    Args:
        min_free_gb: Desired minimum amount of free memory in GiB. The device
            with the most free memory is always the one returned; if even that
            device is below this threshold a warning is printed so the caller
            knows the run may not fit.

    Returns:
        ``torch.device("cuda:<i>")`` for the least-loaded visible GPU, or
        ``torch.device("cpu")`` when CUDA is unavailable or no device is listed.
    """
    if not torch.cuda.is_available():
        print("CUDA not available -> cpu")
        return torch.device("cpu")

    gib = 1024 ** 3
    # mem_get_info returns (free_bytes, total_bytes) for the given device index.
    free_gb = [
        torch.cuda.mem_get_info(i)[0] / gib
        for i in range(torch.cuda.device_count())
    ]
    if not free_gb:
        # Guard against building the invalid device string "cuda:None".
        print("CUDA reports no devices -> cpu")
        return torch.device("cpu")

    # max() returns the first maximal index, i.e. the lowest index wins ties.
    best = max(range(len(free_gb)), key=free_gb.__getitem__)
    best_free = free_gb[best]
    if best_free < min_free_gb:
        print(
            f"WARNING: no CUDA device has >= {min_free_gb} GB free; "
            "falling back to the least-loaded one"
        )
    print(f"Selected CUDA:{best} (free ~{best_free:.1f} GB)")
    return torch.device(f"cuda:{best}")

# Notebook-wide device used by the cells below.
device = auto_select_cuda(min_free_gb=4.0)



Selected CUDA:3 (free ~23.2 GB)


In [56]:
# env: cdp10
# Python 3.10.18
# torch 2.8.0
# cu12

In [57]:
%matplotlib inline
import os, sys, yaml, torch, numpy as np, matplotlib.pyplot as plt
import soundfile as sf
from pathlib import Path
from IPython.display import Audio, display, HTML, Markdown
import matplotlib.font_manager as fm

# Project root; it is also made the working directory below so that relative
# paths used later in the notebook resolve against it.
ROOT = "/data1/yuchen/cd4mt"
# Extend the import search path. Note: re-running this cell appends the same
# entries again.
sys.path.append("/data1/yuchen/MusicLDM-Ext/src")
sys.path.append(ROOT)
sys.path.append(f"{ROOT}/ldm")
os.chdir(ROOT)

print(f"Working directory: {os.getcwd()}")
# Project-local imports: they must stay below the sys.path edits above.
from src.music2latent.music2latent import EncoderDecoder

from ldm.data.multitrack_datamodule import DataModuleFromConfig
device = device  # no-op re-binding; keeps the device chosen by auto_select_cuda()
print(f"Device: {device}")


Working directory: /data1/yuchen/cd4mt
Device: cuda:3


## 2. config and load_data

In [58]:
# Relative path: resolved against the working directory set by os.chdir(ROOT).
CFG_PATH = "configs/cd4mt_medium.yaml"

# Parse the experiment YAML into a plain dict used by the cells below.
with open(CFG_PATH, 'r') as f:
    cfg = yaml.safe_load(f)

In [59]:
# Load ct_config and ct_hparams from medium YAML (for CT) with robust typing
import yaml, torch, numpy as np
# NOTE(review): absolute path that appears to point at the same file as
# ROOT + CFG_PATH; the two can drift apart — consider deriving one from the other.
HCONF_PATH = '/data1/yuchen/cd4mt/configs/cd4mt_medium.yaml'
with open(HCONF_PATH, 'r') as _f:
    ct_config = yaml.safe_load(_f)

def _get(d, path, default=None):
    cur = d
    for k in path.split('.'):
        if not isinstance(cur, dict) or k not in cur:
            return default
        cur = cur[k]
    return cur

def _coerce_hparam(raw, caster, default, name, default_label):
    """Cast a raw config value, falling back to a default when it cannot be cast.

    Args:
        raw: Value as read from the YAML config.
        caster: Callable performing the conversion (``float`` or ``int`` here).
        default: Value returned when ``caster(raw)`` fails.
        name: Hyper-parameter name used in the warning message.
        default_label: Text shown for the fallback value in the warning.

    Returns:
        ``caster(raw)``, or ``default`` after printing a warning.
    """
    try:
        return caster(raw)
    except (TypeError, ValueError, OverflowError):
        print(f'[ct_hparams] WARN: invalid {name}, fallback to {default_label}; raw=', raw)
        return default


# Robust type casting: YAML may read scientific-notation values as strings.
_lr_raw = _get(ct_config, 'model.params.base_learning_rate', 5e-5)
lr = _coerce_hparam(_lr_raw, float, 5e-5, 'lr', '5e-5')

_bs_raw = _get(ct_config, 'data.params.batch_size', 4)
batch_size = _coerce_hparam(_bs_raw, int, 4, 'batch_size', '4')

_sigma_data_raw = _get(ct_config, 'model.params.diffusion_sigma_data', 0.5)
sigma_data = _coerce_hparam(_sigma_data_raw, float, 0.5, 'sigma_data', '0.5')

# Hyper-parameters for consistency training: the first three come from the
# YAML config, the rest are fixed in this notebook.
ct_hparams = {
    'batch_size': batch_size,
    'lr': lr,
    'ema_decay': 0.95,
    'num_scales': 32,
    'sigma_min': 0.002,
    'sigma_max': 80.0,
    'sigma_data': sigma_data,
    'epochs': 120,
    'grad_clip': 1.0,
    'log_interval': 20,
}
print('ct_hparams loaded from', HCONF_PATH)
print('ct_hparams:', ct_hparams, '| types: lr', type(lr).__name__, 'batch_size', type(batch_size).__name__, 'sigma_data', type(sigma_data).__name__)


ct_hparams loaded from /data1/yuchen/cd4mt/configs/cd4mt_medium.yaml
ct_hparams: {'batch_size': 6, 'lr': 0.0001, 'ema_decay': 0.95, 'num_scales': 32, 'sigma_min': 0.002, 'sigma_max': 80.0, 'sigma_data': 0.5, 'epochs': 120, 'grad_clip': 1.0, 'log_interval': 20} | types: lr float batch_size int sigma_data float


In [60]:
# Build the data module from the "data.params" section of the YAML config.
dm = DataModuleFromConfig(**cfg["data"]["params"])
dm.prepare_data()
dm.setup(stage="fit")

# Pull a single training batch; it is reused by the CAE test cells below.
train_loader = dm.train_dataloader()
print(f"train_loader: {len(train_loader)}")
batch = next(iter(train_loader))
print(f"batch.keys(): {list(batch.keys())}")

wav_stems = batch["waveform_stems"]  # (B, S, T)
wav_mix = batch.get("waveform", None)  # (B, T)

print(f"wav_stems: {wav_stems.shape}")
if wav_mix is not None:
    print(f"wav_mix: {wav_mix.shape}")

# B, S, T become notebook-level globals; S and T are read by later cells.
B, S, T = wav_stems.shape
print(f"Batch={B}, Stems={S}, Time={T}")

Found 1290 tracks.
sr=44100, min: 10, max: 600
Keeping 1289 of 1290 tracks
Data size: 26309
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:26309 | Epoch Length: 26309
Found 271 tracks.
sr=44100, min: 10, max: 600
Keeping 270 of 271 tracks
Data size: 5422
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:5422 | Epoch Length: 5422
Found 152 tracks.
sr=44100, min: 10, max: 600
Keeping 151 of 152 tracks
Data size: 3249
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:3249 | Epoch Length: 3249
train_loader: 4385
batch.keys(): ['fname', 'fbank_stems', 'waveform_stems', 'waveform', 'fbank']
wav_stems: torch.Size([6, 4, 524288])
wav_mix: torch.Size([6, 524288])
Batch=6, Stems=4, Time=524288


## 3. CAE test

In [61]:
# Encode every stem of the sample batch with the CAE and stack the latents.
ae = EncoderDecoder(device=device)
print(f"\n stem_num {S} ")

stem_names = cfg['model']['params']['stem_names']
latents_list = []
encode_shapes = []

for s in range(S):
    # Fall back to a generic label when the config lists fewer names than stems.
    stem_name = stem_names[s] if s < len(stem_names) else f"stem_{s}"
    print(f"\n {stem_name} stem {s}:")
    stem_audio = wav_stems[:, s].cpu().numpy()
    print(f"stem_audio : {stem_audio.shape}")

    stem_latents = ae.encode(stem_audio)
    # Normalise the encoder output to a torch tensor.
    if isinstance(stem_latents, np.ndarray):
        stem_latents = torch.from_numpy(stem_latents)

    # Log label fixed: it previously read "tem_latents.dtype".
    print(f"stem_latents.shape {stem_latents.shape}, stem_latents.dtype {stem_latents.dtype}, range: [{stem_latents.min():.3f}, {stem_latents.max():.3f}]")
    latents_list.append(stem_latents)
    encode_shapes.append(stem_latents.shape)


latents_stacked = torch.stack(latents_list, dim=1)  # (B, S, C, L)
print(f"latents_stacked.shape: {latents_stacked.shape}")
print(f" Batch={latents_stacked.shape[0]}, Stems={latents_stacked.shape[1]}, Channels={latents_stacked.shape[2]}, Length={latents_stacked.shape[3]}")
latents = latents_stacked.to(device)
print(f" latents on : {device}")

/data1/yuchen/cd4mt/src/music2latent/music2latent

 stem_num 4 

 bass stem 0:
stem_audio : (6, 524288)
stem_latents.shape torch.Size([6, 64, 127]), tem_latents.dtype torch.float16, range: [-4.656, 4.070]

 drums stem 1:
stem_audio : (6, 524288)
stem_latents.shape torch.Size([6, 64, 127]), tem_latents.dtype torch.float16, range: [-4.254, 3.434]

 guitar stem 2:
stem_audio : (6, 524288)
stem_latents.shape torch.Size([6, 64, 127]), tem_latents.dtype torch.float16, range: [-4.598, 3.605]

 piano stem 3:
stem_audio : (6, 524288)
stem_latents.shape torch.Size([6, 64, 127]), tem_latents.dtype torch.float16, range: [-5.102, 4.488]
latents_stacked.shape: torch.Size([6, 4, 64, 127])
 Batch=6, Stems=4, Channels=64, Length=127
 latents on : cuda:3


In [62]:
def _center_fit(audio, target_len):
    """Center-crop or zero-pad the last axis of ``audio`` to ``target_len``.

    Args:
        audio: NumPy array whose last axis is time.
        target_len: Desired number of samples on the last axis.

    Returns:
        An array with the same leading axes and ``target_len`` samples. When
        the difference is odd, the extra sample is trimmed or padded at the end.
    """
    current_len = audio.shape[-1]
    if current_len > target_len:
        start = (current_len - target_len) // 2
        return audio[..., start:start + target_len]
    if current_len < target_len:
        deficit = target_len - current_len
        left = deficit // 2
        # Pad the time axis only, whatever the number of leading axes.
        pad_width = [(0, 0)] * (audio.ndim - 1) + [(left, deficit - left)]
        return np.pad(audio, pad_width, mode='constant', constant_values=0)
    return audio


# Decode each stem's latents and align the result to the original length T.
recst_list = []
for s in range(S):
    stem_name = stem_names[s] if s < len(stem_names) else f"stem_{s}"
    print(f"\nDecode {stem_name}")
    stem_latents = latents[:, s].cpu().numpy()  # (B, C, L)

    # Decoder errors propagate unchanged; the former try/except only re-raised.
    recst = ae.decode(stem_latents)
    print(f"recst.shape: {recst.shape}")
    print(f"range: [{recst.min():.3f}, {recst.max():.3f}]")

    if isinstance(recst, torch.Tensor):
        recst = recst.detach().cpu().numpy()
    recst_list.append(_center_fit(recst, T))

recst_aud = np.stack(recst_list, axis=1)  # (B, S, T)
recst_tensor = torch.from_numpy(recst_aud).to(device)

print(f"\nrecst shape: {recst_aud.shape}")
print(f"original length: {T}, recst length: {recst_aud.shape[2]}")

# NOTE(review): the MSE assumes the decoder output is centred in the original
# window, which is what the symmetric crop/pad above implies — confirm.
if recst_aud.shape[2] == T:
    mse_error = np.mean((wav_stems.cpu().numpy() - recst_aud)**2)
    print(f"MSE: {mse_error:.6f}")
else:
    print("length error")


Decode bass
recst.shape: torch.Size([6, 521728])
range: [-0.262, 0.248]

Decode drums
recst.shape: torch.Size([6, 521728])
range: [-0.529, 0.462]

Decode guitar
recst.shape: torch.Size([6, 521728])
range: [-0.510, 0.624]

Decode piano
recst.shape: torch.Size([6, 521728])
range: [-0.340, 0.366]

recst shape: (6, 4, 524288)
original lenght: 524288, recst length: 524288
MSE: 0.002353


## 4. CD4MT 扩散模型初始化

In [63]:
# !pip install flash_attn
# !pip install piq
# !pip install blobfile

In [64]:
from src.cm.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import torch.nn as nn

# Re-read the YAML so this cell does not depend on mutations of `cfg` elsewhere.
with open(CFG_PATH, "r") as f:
    cfg_fresh = yaml.safe_load(f)

# Start from the library defaults; entries are overridden below.
cm_model_params = model_and_diffusion_defaults()

# Use model.params.cm_model_override from the YAML if it is present and
# non-empty; otherwise derive the equivalent settings from the UNet config.
ovr = (cfg_fresh.get('model',{}).get('params',{}).get('cm_model_override'))
if not ovr:
    u = cfg_fresh['model']['params']['unet_']['params']
    # Render a list/tuple of ints as a comma-separated string; strings pass through.
    def _to_str_list(v):
        if isinstance(v,str): return v
        if isinstance(v,(list,tuple)): return ','.join(str(int(x)) for x in v)
        return str(v)
    ovr = {
        'image_size': u.get('image_size', cm_model_params['image_size']),
        'num_channels': u.get('model_channels', cm_model_params['num_channels']),
        'num_res_blocks': u.get('num_res_blocks', cm_model_params['num_res_blocks']),
        'channel_mult': _to_str_list(u.get('channel_mult','1,2,4')),
        'num_heads': u.get('num_heads', cm_model_params['num_heads']),
        'num_head_channels': u.get('num_head_channels', 32),
        'num_heads_upsample': -1,
        'attention_resolutions': _to_str_list(u.get('attention_resolutions',[8,4,2])),
        'dropout': u.get('dropout', 0.1),
        'class_cond': False,
        'use_checkpoint': False,
        'use_scale_shift_norm': True,
        'resblock_updown': False,
        'use_fp16': False,
        'use_new_attention_order': False,
        'learn_sigma': False,
        'weight_schedule': 'karras',
        'sigma_min': cfg_fresh['model']['params'].get('sigma_min', cm_model_params['sigma_min']),
        'sigma_max': cfg_fresh['model']['params'].get('sigma_max', cm_model_params['sigma_max']),
    }
else:
    # Copy so the parsed YAML dict is not mutated by the normalisation below.
    ovr = dict(ovr)
# Enforce flash-attn constraint: any head width other than 16/32/64 is
# silently replaced by 32.
ovr['num_head_channels'] = int(ovr.get('num_head_channels',32))
if ovr['num_head_channels'] not in (16,32,64):
    ovr['num_head_channels'] = 32
# Normalize string fields: these two entries are stored as comma-separated
# strings (channel_mult is split on ',' just below).
for k in ['channel_mult','attention_resolutions']:
    if k in ovr and not isinstance(ovr[k], str):
        if isinstance(ovr[k], (list,tuple)):
            ovr[k] = ','.join(str(int(x)) for x in ovr[k])
        else:
            ovr[k] = str(ovr[k])
cm_model_params.update(ovr)

# Sanity check divisibility per level: every level's channel count must be a
# multiple of the attention head width.
_ch_mult = [int(x) for x in cm_model_params['channel_mult'].split(',')]
assert all((cm_model_params['num_channels']*m) % cm_model_params['num_head_channels'] == 0 for m in _ch_mult),     f"num_channels*channel_mult must be divisible by num_head_channels; got num_channels={cm_model_params['num_channels']}, channel_mult={_ch_mult}, num_head_channels={cm_model_params['num_head_channels']}"

# Build model and diffusion
# Channel counts of the diffusion input/output come from the UNet config.
in_ch = cfg_fresh['model']['params']['unet_']['params']['in_channels']
out_ch = cfg_fresh['model']['params']['unet_']['params']['out_channels']
# Note: this rebinds the notebook-level `sigma_data` set in the ct_hparams cell.
sigma_data = cfg_fresh['model']['params'].get('diffusion_sigma_data', 0.5)
model, diffusion = create_model_and_diffusion(distillation=False, **cm_model_params)
diffusion.sigma_data = sigma_data

# Speed-ups: loss=L2, convert model to fp16 torso
try:
    diffusion.loss_norm = 'l2'
    print('[CT] diffusion.loss_norm = l2')
except Exception as e:
    print('[CT] set loss_norm failed:', e)
try:
    # NOTE(review): the "done" message below is printed even when the model
    # has no convert_to_fp16 method (the hasattr check only guards the call).
    if hasattr(model,'convert_to_fp16'): model.convert_to_fp16()
    print('[CT] model.convert_to_fp16() done')
except Exception as e:
    print('[CT] convert_to_fp16 failed:', e)

# Patch I/O convs to match diffusion channels
# NOTE(review): the replacement convs are created after convert_to_fp16() and
# are default-dtype modules — confirm the resulting dtype mix is intended.
new_in = nn.Conv2d(in_ch, model.input_blocks[0][0].out_channels,
                   kernel_size=model.input_blocks[0][0].kernel_size,
                   stride=model.input_blocks[0][0].stride,
                   padding=model.input_blocks[0][0].padding,
                   bias=model.input_blocks[0][0].bias is not None)
nn.init.kaiming_normal_(new_in.weight, mode='fan_out', nonlinearity='relu')
if new_in.bias is not None: nn.init.zeros_(new_in.bias)
model.input_blocks[0][0] = new_in

# The output conv is zero-initialised (weights and bias).
new_out = nn.Conv2d(model.out[-1].in_channels, out_ch,
                    kernel_size=model.out[-1].kernel_size,
                    stride=model.out[-1].stride,
                    padding=model.out[-1].padding,
                    bias=model.out[-1].bias is not None)
nn.init.zeros_(new_out.weight)
if new_out.bias is not None: nn.init.zeros_(new_out.bias)
model.out[-1] = new_out

# NOTE(review): `model` and `diffusion` are rebound by the consistency-training
# setup cell further down, so this instance is not the one that gets trained.
model = model.to(device).eval()
print('CM UNet params:', cm_model_params)


[CT] diffusion.loss_norm = l2
[CT] model.convert_to_fp16() done
CM UNet params: {'sigma_min': 0.0001, 'sigma_max': 3.0, 'image_size': 32, 'num_channels': 192, 'num_res_blocks': 2, 'num_heads': 6, 'num_heads_upsample': -1, 'num_head_channels': 32, 'attention_resolutions': '16,8,4', 'channel_mult': '1,2,4', 'dropout': 0.1, 'class_cond': False, 'use_checkpoint': False, 'use_scale_shift_norm': True, 'resblock_updown': False, 'use_fp16': False, 'use_new_attention_order': False, 'learn_sigma': False, 'weight_schedule': 'karras'}


## 5. Consistency Training (CAE → Karras)


In [65]:
import math
import copy
import random

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import swanlab

# Fix the seeds of the Python, NumPy and PyTorch random number generators.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

def reshape_latents(latent_tensor: torch.Tensor):
    """Fold stacked stem latents into square 2-D "images".

    Args:
        latent_tensor: Tensor of shape ``(B, S, C, L)``. It does not need to
            be contiguous.

    Returns:
        A tuple ``(imgs, meta)`` where ``imgs`` has shape
        ``(B, S * C, side, side)`` with ``side = ceil(sqrt(L))``; the flattened
        latent axis is zero-padded on the right to ``side * side``. ``meta`` is
        ``{'side': side, 'pad': pad, 'latent_len': L}``, the information needed
        to undo the fold.
    """
    B, S, C, L = latent_tensor.shape
    # reshape() copies only when needed, so non-contiguous inputs are accepted.
    flat = latent_tensor.reshape(B, S * C, L)
    # Exact integer ceil(sqrt(L)); avoids float rounding in math.sqrt.
    side = math.isqrt(L)
    if side * side < L:
        side += 1
    pad = side * side - L
    if pad > 0:
        flat = F.pad(flat, (0, pad))
    imgs = flat.reshape(B, S * C, side, side)
    return imgs, {'side': side, 'pad': pad, 'latent_len': L}


In [66]:
# Helper: encode a dataloader batch of waveforms to 2D imgs for CM (no interface mismatch)
import math
import torch
import numpy as np
import torch.nn.functional as F

def encode_batch_to_imgs(ae, batch, to_float32=True):
    """Encode a batch of multi-stem waveforms into square latent "images".

    Args:
        ae: Encoder object exposing ``encode(np.ndarray) -> array or tensor``.
            It is called once per stem with a ``(B, T)`` NumPy array and is
            expected to return ``(B, C, L)`` latents.
        batch: Mapping with key ``'waveform_stems'`` holding a ``(B, S, T)``
            torch tensor or NumPy array.
        to_float32: When true, cast each stem's latents to ``torch.float32``.

    Returns:
        A tuple ``(imgs, meta)``: ``imgs`` has shape ``(B, S * C, side, side)``
        with ``side = ceil(sqrt(L))`` (the latent axis is zero-padded on the
        right), and ``meta`` is
        ``{"side": side, "pad": pad, "latent_len": L}``.

    Raises:
        ValueError: If ``waveform_stems`` is not 3-dimensional.
    """
    wav_stems = batch['waveform_stems']
    if isinstance(wav_stems, np.ndarray):
        wav_stems = torch.from_numpy(wav_stems)
    if wav_stems.dim() != 3:
        raise ValueError(f"expected wav_stems of shape (B,S,T); got {tuple(wav_stems.shape)}")

    num_stems = wav_stems.shape[1]
    latents_list = []
    for s in range(num_stems):
        # (B, T) slice for one stem; the CAE encoder treats B as audio
        # channels and returns (B, C, L).
        # detach() so tensors that require grad can be converted to NumPy.
        stem_audio = wav_stems[:, s].detach().cpu().numpy()
        stem_lat = ae.encode(stem_audio)  # (B, C, L)
        if isinstance(stem_lat, np.ndarray):
            stem_lat = torch.from_numpy(stem_lat)
        if to_float32:
            stem_lat = stem_lat.to(torch.float32)
        latents_list.append(stem_lat)

    # Stack into (B, S, C, L).
    latents = torch.stack(latents_list, dim=1).contiguous()

    # Fold into (B, S*C, H, W) with H = W = ceil(sqrt(L)).
    B, S, C, L = latents.shape
    flat = latents.view(B, S * C, L)
    # Exact integer ceil(sqrt(L)); avoids float rounding in math.sqrt.
    side = math.isqrt(L)
    if side * side < L:
        side += 1
    pad = side * side - L
    if pad > 0:
        flat = F.pad(flat, (0, pad))
    imgs = flat.view(B, S * C, side, side)
    return imgs, {"side": side, "pad": pad, "latent_len": L}



In [67]:
from src.cm.script_util import create_model
from src.cm.karras_diffusion import KarrasDenoiser
from src.cm.nn import update_ema

# Fold the sample latents once to learn the image side and channel count that
# the consistency model must accept.
latents_cpu = latents.detach().to(torch.float32).cpu()
imgs_sample, reshape_meta = reshape_latents(latents_cpu)
print(f'Latent reshape -> {imgs_sample.shape}, meta={reshape_meta}')

# steps_per_epoch will use dm.train_dataloader() later in training cell

# NOTE: this rebinds `model`, replacing the one built in the model-init cell.
model = create_model(
    image_size=imgs_sample.shape[-1],
    num_channels=64,
    num_res_blocks=1,
    channel_mult='1',
    learn_sigma=False,
    class_cond=False,
    use_checkpoint=False,
    attention_resolutions='1024',  # presumably chosen so no level gets attention — TODO confirm
    num_heads=1,
    num_head_channels=-1,
    num_heads_upsample=-1,
    use_scale_shift_norm=True,
    dropout=0.05,
    resblock_updown=False,
    use_fp16=False,
    use_new_attention_order=False,
)

# Swap the first and last convolutions so the network takes and returns
# S * C channels, keeping kernel size, stride, padding and bias setting.
in_channels = imgs_sample.shape[1]
model.input_blocks[0][0] = torch.nn.Conv2d(
    in_channels,
    model.input_blocks[0][0].out_channels,
    kernel_size=model.input_blocks[0][0].kernel_size,
    stride=model.input_blocks[0][0].stride,
    padding=model.input_blocks[0][0].padding,
    bias=model.input_blocks[0][0].bias is not None,
)
model.out[-1] = torch.nn.Conv2d(
    model.out[-1].in_channels,
    in_channels,
    kernel_size=model.out[-1].kernel_size,
    stride=model.out[-1].stride,
    padding=model.out[-1].padding,
    bias=model.out[-1].bias is not None,
)

# NOTE(review): sigma_min/sigma_max come from ct_hparams (0.002 / 80.0), not
# from the YAML-derived values printed by the model-init cell — confirm intent.
diffusion = KarrasDenoiser(
    sigma_data=ct_hparams['sigma_data'],
    sigma_min=ct_hparams['sigma_min'],
    sigma_max=ct_hparams['sigma_max'],
    weight_schedule='karras',
    loss_norm='l2',
)
device = device  # no-op re-binding; keeps the notebook-wide device
model.to(device)

# Frozen copy of the online model, passed to consistency_losses as target_model.
target_model = copy.deepcopy(model)
for p in target_model.parameters():
    p.requires_grad_(False)
target_model.to(device)
target_model.eval()

optimizer = torch.optim.Adam(model.parameters(), lr=ct_hparams['lr'])



Latent reshape -> torch.Size([6, 256, 12, 12]), meta={'side': 12, 'pad': 17, 'latent_len': 127}


In [68]:
# Sanity check: single batch forward/backward to validate pipeline
# One forward/backward/optimizer step on a real batch to validate the pipeline.
# NOTE: this performs a real update on `model` (with a throw-away optimizer),
# so afterwards `model` no longer equals the `target_model` copy made earlier.
from functools import partial

batch0 = next(iter(dm.train_dataloader()))
assert 'waveform_stems' in batch0
imgs0, meta0 = encode_batch_to_imgs(ae, batch0)
print('Sanity imgs:', imgs0.shape, meta0)
imgs0 = imgs0.to(device)

# torch.cuda.amp.autocast / GradScaler are deprecated in favour of torch.amp
# with an explicit device type. The two names keep their old call signature
# (`autocast(dtype=...)`, `GradScaler()`), bound to the CUDA device type.
autocast = partial(torch.amp.autocast, 'cuda')
GradScaler = partial(torch.amp.GradScaler, 'cuda')

_tmp_opt = torch.optim.Adam(model.parameters(), lr=float(ct_hparams['lr']))
_tmp_scaler = GradScaler()

model.train()
with autocast(dtype=torch.float16):
    _outs = diffusion.consistency_losses(
        model,
        imgs0.half(),
        num_scales=ct_hparams['num_scales'],
        target_model=target_model,
        teacher_model=None,
        teacher_diffusion=None,
    )
    _loss = _outs['loss'].mean()

_tmp_opt.zero_grad()
_tmp_scaler.scale(_loss).backward()
# Unscale before clipping so the threshold applies to the true gradient norm.
_tmp_scaler.unscale_(_tmp_opt)
torch.nn.utils.clip_grad_norm_(model.parameters(), ct_hparams['grad_clip'])
_tmp_scaler.step(_tmp_opt)
_tmp_scaler.update()
print(f"Sanity check OK. loss={float(_loss):.6f}")



Sanity imgs: torch.Size([6, 256, 12, 12]) {'side': 12, 'pad': 17, 'latent_len': 127}
Sanity check OK. loss=0.173927


  _tmp_scaler = GradScaler()
  with autocast(dtype=torch.float16):


In [69]:
# # CM training loop (epoch-level avg; deterministic val; best EMA ckpt)
# import torch, os, copy, numpy as np
# from torch.cuda.amp import autocast, GradScaler

# ae = EncoderDecoder(device=device)
# optimizer = torch.optim.Adam(model.parameters(), lr=float(ct_hparams['lr']))
# scaler = GradScaler()

# from src.cm.nn import update_ema
# # persistent EMA target model
# target_model = copy.deepcopy(model).to(device)
# for p in target_model.parameters():
#     p.requires_grad_(False)
# target_model.eval()

# best, best_path = float('inf'), None
# os.makedirs('checkpoints', exist_ok=True)
# import time, swanlab
# run = swanlab.init(project="cd4mt", experiment_name=f"ct_unet_{time.strftime('%m%d_%H%M%S')}", config={**ct_hparams, "num_params": sum(p.numel() for p in model.parameters())})
# _total_step = 0


# for epoch in range(ct_hparams['epochs']):
#     # Train
#     model.train();
#     train_loader = dm.train_dataloader()
#     run, steps = 0.0, 0
#     for step, batch in enumerate(train_loader):
#         imgs, meta = encode_batch_to_imgs(ae, batch)
#         imgs = imgs.to(device)
#         with autocast(dtype=torch.float16):
#             losses = diffusion.consistency_losses(
#                 model,
#                 imgs.half(),
#                 num_scales=ct_hparams['num_scales'],
#                 target_model=target_model,
#                 teacher_model=None,
#                 teacher_diffusion=None,
#             )
#             loss = losses['loss'].mean()
#         optimizer.zero_grad()
#         scaler.scale(loss).backward()
#         scaler.unscale_(optimizer)
#         torch.nn.utils.clip_grad_norm_(model.parameters(), ct_hparams['grad_clip'])
#         scaler.step(optimizer)
#         scaler.update()
#         update_ema(target_model.parameters(), model.parameters(), rate=ct_hparams['ema_decay'])
#         run += float(loss.item()); steps += 1
#         _total_step += 1
#         if step % ct_hparams['log_interval'] == 0:
#             swanlab.log({"train/loss": float(loss.item()), "epoch": epoch, "step": _total_step})
#             print(f"[train] epoch={epoch} step={step}/{len(train_loader)} loss={loss.item():.4f}")
#     avg_train = run / max(1, steps)
#     swanlab.log({"train/avg": float(avg_train), "epoch": epoch})
#     print(f"[train] epoch={epoch} avg={avg_train:.4f}")

#     # Val (deterministic)
#     model.eval();
#     val_loader = dm.val_dataloader()
#     vrun, vsteps = 0.0, 0
#     # save/restore RNG
#     _cpu = torch.get_rng_state(); _cuda = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None
#     torch.manual_seed(12345); np.random.seed(12345)
#     if torch.cuda.is_available(): torch.cuda.manual_seed_all(12345)
#     with torch.no_grad():
#         for vstep, vbatch in enumerate(val_loader):
#             vimgs, vmeta = encode_batch_to_imgs(ae, vbatch)
#             vimgs = vimgs.to(device)
#             gen = torch.Generator(device=vimgs.device); gen.manual_seed(12345 + vstep)
#             vnoise = torch.randn_like(vimgs, generator=gen)
#             with autocast(dtype=torch.float16):
#                 vloss = diffusion.consistency_losses(
#                     model,
#                     vimgs.half(),
#                     num_scales=ct_hparams['num_scales'],
#                     target_model=target_model,
#                     teacher_model=None,
#                     teacher_diffusion=None,
#                     noise=vnoise,
#                 )['loss'].mean()
#             vrun += float(vloss.item()); vsteps += 1
#     # restore RNG
#     torch.set_rng_state(_cpu); 
#     if _cuda is not None: torch.cuda.set_rng_state_all(_cuda)
#     avg_val = vrun / max(1, vsteps)
#     swanlab.log({"val/avg": float(avg_val), "epoch": epoch})
#     print(f"[val]  epoch={epoch} avg={avg_val:.4f}")

#     # Save best by val
#     if avg_val < best:
#         best = avg_val
#         best_path = f"checkpoints/ct_unet_ema_best_val{best:.6f}.pth"
#         torch.save({'state_dict': target_model.state_dict(), 'epoch': epoch, 'ct_hparams': ct_hparams, 'meta': meta}, best_path)
#         swanlab.log({"val/best": float(best)})
#         print(f"[ckpt] saved best -> {best_path}")

# print(f"Best ckpt: {best_path}, val={best:.6f}")

# swanlab.finish()



In [70]:
# Cleanup CUDA memory in Jupyter
import gc, torch

def cleanup_cuda(devices=None, names=None):
    """Drop large notebook globals and release cached CUDA memory.

    Args:
        devices: Iterable of CUDA device indices whose cache should be emptied.
            Defaults to every visible device. Integer indices outside the
            visible range are reported and skipped.
        names: Names of notebook-level variables to remove from ``globals()``.
            Defaults to the model, data and latent variables used in this
            notebook.

    Returns:
        None. Progress and problems are reported with ``print``.
    """
    # 1) Remove the global references, moving modules back to the CPU first.
    if names is None:
        names = [
            'model', 'target_model', 'optimizer', 'scaler',
            'train_loader', 'val_loader', 'dm',
            'imgs', 'vimgs', 'latents', 'latents_cpu', 'latents_stacked', 'imgs_sample',
            'ae',  # EncoderDecoder container
        ]
    for n in names:
        obj = globals().pop(n, None)
        if obj is None:
            continue
        try:
            # nn.Module.to() moves parameters in place. Tensor.to() returns a
            # new tensor, so for tensors only dropping the reference matters.
            if hasattr(obj, 'to'):
                obj.to('cpu')
            # EncoderDecoder: move its `gen` sub-module back to the CPU.
            if hasattr(obj, 'gen') and hasattr(obj.gen, 'to'):
                obj.gen.to('cpu')
        except Exception as e:
            # Best effort: a failed move must not stop the cleanup.
            print(f'[cleanup] skip moving {n}: {e}')
        del obj

    # 2) Garbage-collect, then empty the CUDA cache on each requested device.
    gc.collect()
    n_dev = torch.cuda.device_count()
    if devices is None:
        devices = list(range(n_dev))
    for i in devices:
        if isinstance(i, int) and not 0 <= i < n_dev:
            print(f'[cleanup] skip device {i}: only {n_dev} CUDA device(s) visible')
            continue
        try:
            with torch.cuda.device(i):
                torch.cuda.empty_cache()
        except Exception as e:
            print(f'[cleanup] empty_cache failed on device {i}: {e}')
    if n_dev > 0:
        try:
            torch.cuda.ipc_collect()
        except Exception as e:
            print(f'[cleanup] ipc_collect failed: {e}')
    print('[cleanup] done')

# For the current situation (devices 0 and 4 may be occupied).
# NOTE(review): the notebook selected its device with auto_select_cuda(), which
# need not be 0 or 4 — confirm this device list is still the intended one.
cleanup_cuda(devices=[0, 4])

[cleanup] done


In [None]:
# Quick small-epoch run with param summary (auto-appended)
import os, yaml, torch
from ldm.data.multitrack_datamodule import DataModuleFromConfig
from ldm.models.diffusion.cd4mt_diffusion import ScoreDiffusionModel
from ldm.modules.util import summarize_params
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl

# The config can be overridden through the CFG_PATH environment variable.
CFG_PATH = os.getenv('CFG_PATH', 'configs/cd4mt_small_plus.yaml')
print(f'Using config: {CFG_PATH}')
with open(CFG_PATH, 'r') as f:
    cfg = yaml.safe_load(f)

# Data
dm = DataModuleFromConfig(**cfg['data']['params'])
dm.prepare_data()
dm.setup('fit')

# Model
# NOTE(review): the recorded output of this cell ends in
# "ValueError: EncoderDecoder requires explicit 'device'"; presumably it is
# raised while the model is built or fitted — confirm where, and pass the
# device through the model config if needed.
train_model_config = cfg['model']['params'].copy()
train_unet_config = train_model_config.pop('unet_')
model = ScoreDiffusionModel(unet_config=train_unet_config, **train_model_config)

# Reuse the notebook-wide device instead of rebinding `device` to the bare
# string 'cuda', which discarded the GPU picked by auto_select_cuda().
try:
    device = torch.device(device)
except NameError:
    # This cell was run on its own: fall back to the default device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Param summary
try:
    summarize_params(model.audio_diffusion, name='AudioDiffusionModel_2d')
    if hasattr(model.audio_diffusion, 'unet'):
        summarize_params(model.audio_diffusion.unet, name='UNet (inner)')
except Exception as e:
    print('summarize_params failed:', e)

# Trainer (small): one epoch, two training batches, one validation batch.
checkpoint_dir = './training_logs/checkpoints_nb'
ckpt_cb = ModelCheckpoint(dirpath=checkpoint_dir, filename='nb-test-{epoch:02d}-{step:04d}', save_top_k=1, save_last=True, every_n_epochs=1, verbose=True)
_on_gpu = device.type == 'cuda'
trainer = pl.Trainer(
    accelerator='gpu' if _on_gpu else 'cpu',
    # Pin the trainer to the selected GPU index when one is known.
    devices=[device.index] if _on_gpu and device.index is not None else 1,
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=1,
    num_sanity_val_steps=0,
    logger=False,
    callbacks=[ckpt_cb],
    precision=cfg.get('trainer', {}).get('precision', 16),
)

trainer.fit(model, dm)



Using config: configs/cd4mt_small_plus.yaml
Found 1290 tracks.
sr=44100, min: 10, max: 600
Keeping 1289 of 1290 tracks
Data size: 26309
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:26309 | Epoch Length: 26309
Found 271 tracks.
sr=44100, min: 10, max: 600
Keeping 270 of 271 tracks
Data size: 5422
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:5422 | Epoch Length: 5422
Found 152 tracks.
sr=44100, min: 10, max: 600
Keeping 151 of 152 tracks
Data size: 3249
Use mixup rate of 0.0; Use SpecAug (T,F) of (0, 0); Use blurring effect or not False
| Audiostock Dataset Length:3249 | Epoch Length: 3249


ValueError: EncoderDecoder requires explicit 'device' (e.g., torch.device('cuda:4') or 'cpu')

[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 114
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 114
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 0
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 0
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 97
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 116
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 116
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 97
[amax:471158] [[23851,0],0] tcp_peer_recv_blocking: recv() failed for [[513,790],16777472]: Connection reset by peer (104)
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 34
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 49
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 214
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 49
[amax:471158] tcp_peer_recv_connect_ack: invalid header type: 122
[amax:471158] tcp_peer_recv_

: 

In [None]:
# CM UNet param size only (no training)
import os, yaml, torch, numpy as np
from src.cm.script_util import create_model
from ldm.modules.util import summarize_params

# The config can be overridden through the CFG_PATH environment variable.
CFG_PATH = os.getenv('CFG_PATH', 'configs/cd4mt_small_plus.yaml')
with open(CFG_PATH, 'r') as f:
    cfg = yaml.safe_load(f)

# `or {}` also covers keys that are present but null in the YAML.
_model_params = (cfg.get('model') or {}).get('params') or {}
S = int(_model_params.get('num_stems', 4))
C_lat = int(_model_params.get('cae_z_channels', 64))
in_channels = S * C_lat
cmov = _model_params.get('cm_model_override') or {}


def _as_csv(value):
    """Render a list/tuple of ints as 'a,b,c'; strings pass through unchanged."""
    if isinstance(value, str):
        return value
    if isinstance(value, (list, tuple)):
        return ','.join(str(int(x)) for x in value)
    return str(value)


# channel_mult and attention_resolutions are normalised to comma-separated
# strings, mirroring the normalisation done in the model-init cell.
args = dict(
    image_size=cmov.get('image_size', 32),
    num_channels=cmov.get('num_channels', 128),
    num_res_blocks=cmov.get('num_res_blocks', 2),
    channel_mult=_as_csv(cmov.get('channel_mult', '1,2,4')),
    learn_sigma=cmov.get('learn_sigma', False),
    class_cond=cmov.get('class_cond', False),
    use_checkpoint=cmov.get('use_checkpoint', False),
    attention_resolutions=_as_csv(cmov.get('attention_resolutions', '16,8,4')),
    num_heads=cmov.get('num_heads', 4),
    num_head_channels=cmov.get('num_head_channels', 32),
    num_heads_upsample=cmov.get('num_heads_upsample', -1),
    use_scale_shift_norm=cmov.get('use_scale_shift_norm', True),
    dropout=cmov.get('dropout', 0.1),
    resblock_updown=cmov.get('resblock_updown', False),
    use_fp16=cmov.get('use_fp16', False),
    use_new_attention_order=cmov.get('use_new_attention_order', False),
)

cm_unet = create_model(**args)
# Adjust I/O conv to match multitrack channels (S * C_lat in and out).
try:
    if hasattr(cm_unet, 'input_blocks') and hasattr(cm_unet.input_blocks[0][0], 'weight'):
        orig_in = cm_unet.input_blocks[0][0]
        cm_unet.input_blocks[0][0] = torch.nn.Conv2d(
            in_channels, orig_in.out_channels,
            kernel_size=orig_in.kernel_size, stride=orig_in.stride, padding=orig_in.padding,
            bias=(orig_in.bias is not None),
        )
    if hasattr(cm_unet, 'out') and hasattr(cm_unet.out[-1], 'weight'):
        orig_out = cm_unet.out[-1]
        cm_unet.out[-1] = torch.nn.Conv2d(
            orig_out.in_channels, in_channels,
            kernel_size=orig_out.kernel_size, stride=orig_out.stride, padding=orig_out.padding,
            bias=(orig_out.bias is not None),
        )
except Exception as e:
    # Best effort: if this fails, the summary below describes the unpatched model.
    print('Adjust I/O conv failed:', e)

# Last expression of the cell, so its return value is displayed.
summarize_params(cm_unet, name='CM.UNet (standalone)')



[Model] CM.UNet (standalone): params=67,550,336 (trainable=67,550,336), param_mem≈257.7 MB (dtype=torch.float32)


{'name': 'CM.UNet (standalone)',
 'total_params': 67550336,
 'trainable_params': 67550336,
 'buffers': 0,
 'bytes_per_param': 4,
 'param_mem_mb': 257.68408203125,
 'dtype': 'torch.float32'}