In [1]:
import os

import hydra
import torch
import torchaudio
from stable_audio_tools.inference.generation import generate_diffusion_cond

In [None]:
# params
seed = 42
num_samples = 2
exp_cfg = "train_musdb_controlnet_audio_large"
ckpt_path = "../ckpts/musdb-audio/epoch=192-valid_loss=0.418.ckpt"
dataset_path = "../data/musdb18hq/"

# load config
# Strip any trailing slash before appending the archive names so the composed
# override paths do not end up with a doubled separator ("...//test.tar").
dataset_root = dataset_path.rstrip("/")
cfg_overrides = [
    f"exp={exp_cfg}",
    f"datamodule.val_dataset.path={dataset_root}/test.tar",
    f"datamodule.train_dataset.path={dataset_root}/train.tar",
]
with hydra.initialize(config_path="..", version_base=None):
    cond_cfg = hydra.compose(config_name="config", overrides=cfg_overrides)

# init model
model = hydra.utils.instantiate(cond_cfg["model"])
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu")
# strict=False tolerates missing/unexpected keys, so a mismatched checkpoint
# will not raise here.
model.load_state_dict(ckpt['state_dict'], strict=False)
model = model.cuda()

# load dataloader
datamodule = hydra.utils.instantiate(cond_cfg["datamodule"])
val_dataloader = datamodule.val_dataloader()

In [3]:
# load conditioning (replace with your audio and prompts; prompts must follow "in: stems; out:stems" structure)

# Take the first validation batch; its first element is not used here.
_, y, prompts, start_seconds, total_seconds = next(iter(val_dataloader))

# Limit the conditioning audio to the [-1, 1] range.
y = y.clamp(min=-1, max=1)

# Never request more samples than the batch actually holds.
num_samples = min(num_samples, y.shape[0])

# One conditioning dict per sample; the slice keeps the leading batch axis
# on the audio tensor before it is moved to the GPU.
conditioning = [
    dict(
        audio=y[idx:idx + 1].cuda(),
        prompt=prompts[idx],
        seconds_start=start_seconds[idx],
        seconds_total=total_seconds[idx],
    )
    for idx in range(num_samples)
]

In [None]:
# generate

# Sample from the inner `model.model` using the conditioning list built above.
output = generate_diffusion_cond(
    model.model,
    # what to generate
    conditioning=conditioning,
    batch_size=num_samples,
    sample_size=y.shape[-1],  # same last-axis length as the conditioning audio
    # reproducibility
    seed=seed,
    # sampler settings
    sampler_type="dpmpp-3m-sde",
    steps=100,
    cfg_scale=7.0,
    sigma_min=0.3,
    sigma_max=500,
    device="cuda",
)

In [5]:
# save results
# Create the output directory if it is missing; no error if it already exists.
os.makedirs("out", exist_ok=True)

sample_rate = 44100  # sample rate written into every saved file

for i in range(num_samples):
    # Remove spaces so the prompt can be embedded in the file name.
    # This must be a plain string: wrapping it in braces builds a set, whose
    # repr ("{'...'}") would otherwise end up in every file name.
    prompt = prompts[i].replace(" ", "")

    # Move each tensor to the CPU once and reuse it for the mix.
    input_audio = y[i].cpu()
    output_audio = output[i].cpu()

    torchaudio.save(f"out/input_{i}_prompt_{prompt}.wav", input_audio, sample_rate=sample_rate)
    torchaudio.save(f"out/output_{i}_prompt_{prompt}.wav", output_audio, sample_rate=sample_rate)
    # The mix is the plain sum of the conditioning input and the generated output.
    torchaudio.save(f"out/mix_{i}_prompt_{prompt}.wav", input_audio + output_audio, sample_rate=sample_rate)