In [None]:
import os
import hydra
import torch
import torchaudio
from stable_audio_tools.inference.generation import generate_diffusion_cond

In [None]:
# params
seed = 42

num_samples = 1
exp_cfg = "train_moisesdb_controlnet_audio_large"
ckpt_path = "../ckpts/moisesdb-audio/last.ckpt"
dataset_path = "../data/moisesdb/"

# Join with os.path.join so the trailing slash on `dataset_path` does not
# produce a doubled separator ("moisesdb//19.tar") in the config overrides.
shard_path = os.path.join(dataset_path, "19.tar")

# load config
# The same shard is used for both the train and the val dataset, and both
# batch sizes are forced to 1.
with hydra.initialize(config_path="..", version_base=None):
    cond_cfg = hydra.compose(
        config_name="config",
        overrides=[
            f"exp={exp_cfg}",
            f"datamodule.val_dataset.path={shard_path}",
            f"datamodule.train_dataset.path={shard_path}",
            "datamodule.batch_size_train=1",
            "datamodule.batch_size_val=1",
        ],
    )

# init model
model = hydra.utils.instantiate(cond_cfg["model"])
# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints that come from a trusted source.
ckpt = torch.load(ckpt_path, map_location="cpu")
# strict=False tolerates key mismatches, so report them instead of letting a
# partially loaded model go unnoticed.
load_result = model.load_state_dict(ckpt['state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"load_state_dict: {len(load_result.missing_keys)} missing key(s), "
        f"{len(load_result.unexpected_keys)} unexpected key(s)"
    )
# This notebook only runs inference: move to GPU and switch to eval mode so
# train-only layers (e.g. dropout) are disabled.
model = model.cuda().eval()

# load dataloader
datamodule = hydra.utils.instantiate(cond_cfg["datamodule"])
val_dataloader = datamodule.val_dataloader()

In [None]:
def generate_and_save(b, out_path, sample_rate=44100):
    """Generate audio for the first item of a batch and write the results to disk.

    Uses the notebook-level `model` and `seed`.

    Args:
        b: Batch unpacked as a 5-tuple; the first element is ignored, the
            rest are the conditioning audio `y`, the text prompts, the start
            seconds and the total seconds.
            NOTE(review): assumes these are batch-indexable along dim 0 —
            confirm against the datamodule.
        out_path: Directory that receives `prompt.txt`, `input.wav`,
            `output.wav` and `mix.wav`. Created if it does not exist.
        sample_rate: Sample rate passed to `torchaudio.save` for the three
            WAV files. Defaults to 44100.
    """
    _, y, prompts, start_seconds, total_seconds = b
    y = torch.clip(y, -1, 1)

    # Make sure the target directory exists *before* the expensive sampling
    # run, so a missing folder cannot fail the save afterwards.
    os.makedirs(out_path, exist_ok=True)

    # Only the first item of the batch is generated and saved. A distinct
    # local name avoids shadowing the notebook-level `num_samples`.
    batch_size = 1

    conditioning = [{
        "audio": y[i:i+1].cuda(),
        "prompt": prompts[i],
        "seconds_start": start_seconds[i],
        "seconds_total": total_seconds[i],
    } for i in range(batch_size)]

    output = generate_diffusion_cond(
            model.model,
            seed=seed,
            batch_size=batch_size,
            steps=100,
            cfg_scale=7.0,
            conditioning=conditioning,
            # Generate as many samples as the conditioning audio contains.
            sample_size=y.shape[-1],
            sigma_min=0.3,
            sigma_max=500,
            sampler_type="dpmpp-3m-sde",
            device="cuda"
        )

    # Move each tensor to the CPU once and reuse it for every file.
    input_audio = y[0].cpu()
    generated_audio = output[0].cpu()

    # Explicit UTF-8 so prompts with non-ASCII characters are written the
    # same way regardless of the platform's default encoding.
    with open(os.path.join(out_path, "prompt.txt"), "w", encoding="utf-8") as file:
        file.write(prompts[0])
    torchaudio.save(os.path.join(out_path, "input.wav"), input_audio, sample_rate=sample_rate)
    torchaudio.save(os.path.join(out_path, "output.wav"), generated_audio, sample_rate=sample_rate)
    # The mix is the plain sum of the conditioning input and the generated audio.
    torchaudio.save(os.path.join(out_path, "mix.wav"), input_audio + generated_audio, sample_rate=sample_rate)

In [None]:
# save results
out_path = "out"

# makedirs(exist_ok=True) is idempotent and also works for nested paths,
# unlike checking the name against os.listdir() of the current directory.
os.makedirs(out_path, exist_ok=True)

# Take the first validation batch and generate/save audio for it.
b = next(iter(val_dataloader))
generate_and_save(b, out_path)