In [100]:
# Imports
import torch
import torchaudio
import torchcrepe
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond
from stable_audio_tools.inference import sampling
from tqdm.notebook import trange

import numpy as np
import math
import pretty_midi
import sounddevice as sd
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import soundfile as sf
from torch.nn import MSELoss

# Custom Helpers
from audio_helpers import get_github_audio, plot_pitch, plot_pitch_comparison, animate_pitch_arrays, create_gradio_interface

# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = "cpu"
print("Using {}".format(device))

Using cuda


In [None]:
@torch.enable_grad()
def calculate_pitch(audio, sample_rate, model_type='tiny', fmin=50, fmax=550):
    """Estimate a pitch contour for ``audio`` with torchcrepe, keeping autograd on.

    The ``torch.enable_grad()`` decorator keeps gradient tracking enabled even
    when this is called from inside a ``no_grad`` context (e.g. the sampler).
    That lets the guidance callback back-propagate a pitch loss to the audio.

    Args:
        audio: waveform as a numpy array or torch tensor. A 1-D input is
            treated as a single channel. Otherwise dim 0 is averaged away
            (assumed to be the channel axis, e.g. ``(channels, samples)`` —
            TODO confirm for other layouts).
        sample_rate: sampling rate of ``audio`` in Hz.
        model_type: torchcrepe model capacity, ``'tiny'`` or ``'full'``.
        fmin, fmax: pitch search range in Hz passed to torchcrepe.

    Returns:
        Pitch tensor from ``torchcrepe.predict``, with one frame every
        ``int(sample_rate / 500)`` samples (a 2 ms hop). Frames whose smoothed
        periodicity is below 0.5 are marked unvoiced by
        ``torchcrepe.threshold.At``. These are presumably NaN, which is why
        callers apply ``nan_to_num``.
    """
    if isinstance(audio, np.ndarray):
        audio = torch.tensor(audio, device=device, dtype=torch.float32)
    else:
        # Tensors may arrive on another device/dtype; .to() keeps the autograd graph.
        audio = audio.to(device=device, dtype=torch.float32)

    if audio.ndim == 1:
        # torchcrepe expects a (1, time) batch, so give a bare waveform its leading axis.
        audio = audio.unsqueeze(0)
    else:
        # Down-mix to mono while keeping a leading length-1 axis.
        audio = audio.mean(dim=0, keepdim=True)

    hop = int(sample_rate / 500.)  # 2 ms hop (sample_rate / 500 samples per frame)
    pitch, periodicity = torchcrepe.predict(
        audio, sample_rate, hop_length=hop, fmin=fmin, fmax=fmax,
        model=model_type,  # 'tiny' or 'full'
        batch_size=4096, device=device, return_periodicity=True,
        # Soft-argmax decoding (instead of a hard argmax/viterbi) so pitch varies smoothly.
        decoder=torchcrepe.decode.soft_argmax,
        # NOTE(review): `differentiable` is not an upstream torchcrepe.predict kwarg;
        # presumably a patched torchcrepe that skips its internal no_grad — confirm.
        differentiable=True,
    )

    # Clean up: median-filter periodicity, gate out low-confidence frames,
    # then smooth the remaining pitch with a short moving mean.
    win_l = 3
    periodicity = torchcrepe.filter.median(periodicity, win_l)
    pitch = torchcrepe.threshold.At(.5)(pitch, periodicity)
    pitch = torchcrepe.filter.mean(pitch, win_l)
    return pitch

In [None]:
# Fetch Stable Audio Open Small
# (https://huggingface.co/stabilityai/stable-audio-open-small) and read the
# audio format it works in from its config.
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")
sample_rate, sample_size = model_config["sample_rate"], model_config["sample_size"]

# Place the weights on the backend selected earlier.
model = model.to(device)

In [None]:
# Download the example clip if it is not already present.
get_github_audio("https://github.com/pdx-cs-sound/wavs/raw/refs/heads/main/gc.wav")

target_audio, target_sr = torchaudio.load('../data/audio/gc.wav')
if target_sr != sample_rate:
    # Resample(orig_freq, new_freq): convert from the file's rate to the model's rate.
    resampler = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=sample_rate)
    target_audio = resampler(target_audio)

# The model is asked for `sample_size` samples, i.e. sample_size / sample_rate
# seconds (~11.89 s for Stable Audio Open Small). Derive it instead of hard-coding,
# and make the target exactly that many samples so both pitch contours line up.
time_sec = sample_size / sample_rate
num_samples = target_audio.shape[1]
if num_samples < sample_size:
    # Pad short clips with trailing silence rather than leaving them shorter.
    target_audio = torch.nn.functional.pad(target_audio, (0, sample_size - num_samples))
else:
    target_audio = target_audio[:, :sample_size]
target_pitch = calculate_pitch(target_audio, sample_rate, 'tiny')

print(f"Target length is: {target_audio.shape[1] / sample_rate}")

In [None]:
# Plot the target pitch contour computed from the reference clip.
plot_pitch(target_pitch, sample_rate)

In [93]:
from functools import partial


def _set_postfix_merged(pbar, **updates):
    """Update entries in a tqdm bar's postfix while keeping the existing ones.

    ``tqdm.set_postfix`` replaces the whole postfix, so the current
    ``"k1=v1, k2=v2"`` string is parsed back into a dict first. Fragments
    without an ``=`` are skipped rather than raising.
    """
    merged = {}
    for item in (pbar.postfix or "").split(','):
        key, sep, value = item.partition('=')
        if sep:
            merged[key.strip()] = value.strip()
    merged.update(updates)
    pbar.set_postfix(merged)


def pitch_callback(autoencoder, target_pitch, base_step_scale, alpha, inner_strength, in_dict, save_every=10):
    """Sampler callback that nudges the denoised latent toward ``target_pitch``.

    Bind the first five arguments with ``functools.partial`` and pass the
    result as the sampler's ``callback``. The sampler supplies ``in_dict``;
    the keys used are ``'denoised'``, ``'t'``, ``'i'`` and optionally ``'pbar'``.

    Each call does the following:
    - Decodes ``denoised`` to audio.
    - Estimates the pitch with ``calculate_pitch``.
    - Computes the MSE against ``target_pitch``, with NaN frames counted as 0.
    - Applies one gradient step, clipped to L2 norm <= 0.1, to ``denoised``
      in place::

        denoised -= inner_strength * base_step_scale * (1 - t)**alpha * clip(grad)

    Args:
        autoencoder: module whose ``.decoder`` maps latents to audio.
        target_pitch: pitch contour from ``calculate_pitch`` on the target clip.
        base_step_scale: base guidance step size.
        alpha: exponent of the PnP-Flow time weight ``(1 - t)**alpha``.
        inner_strength: additional multiplier on the step.
        in_dict: per-step dict supplied by the sampler.
        save_every: every ``save_every``-th step, snapshots are appended to the
            notebook-global ``pitch_array`` / ``audio_array`` lists. These must
            exist before sampling starts.

    Non-finite gradients are skipped (with a warning) instead of being applied.
    ``denoised.requires_grad`` is always reset to False before returning.
    """
    denoised = in_dict['denoised']
    t = in_dict['t'].detach().float()
    step = in_dict['i']
    pbar = in_dict.get('pbar', None)  # provided by the patched sample_flow_pingpong

    # PnP-Flow schedule: (1 - t)^alpha is ~0 while t is large (noisy) and
    # approaches 1 as t -> 0, so guidance concentrates on late steps.
    step_scale = base_step_scale * (1.0 - t) ** alpha

    try:
        with torch.enable_grad():
            denoised.requires_grad_(True)

            # NOTE(review): calls the decoder directly, bypassing any pretransform
            # scaling/bottleneck decode, and assumes fp16 decoder weights — confirm.
            pred_audio = autoencoder.decoder(denoised.half()).float()  # back to float32

            # Light smoothing: 64-tap moving average, trimmed back to the original length
            # (the symmetric padding of an even kernel yields one extra sample).
            num_samples = pred_audio.shape[-1]
            kernel_size = 64
            pred_audio = torch.nn.functional.avg_pool1d(
                pred_audio, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
            )[..., :num_samples]

            pred_audio = rearrange(pred_audio, "b d n -> d (b n)")

            pitch = calculate_pitch(pred_audio.mean(dim=0, keepdim=True), sample_rate, 'tiny')
            loss = torch.nn.functional.mse_loss(torch.nan_to_num(pitch, nan=0.),
                                                torch.nan_to_num(target_pitch, nan=0.))

            if step % save_every == 0:
                # Snapshot for later inspection. The peak is clamped so silent audio
                # doesn't divide by zero (NaN -> garbage int16).
                with torch.no_grad():
                    preview = pred_audio.detach()
                    peak = preview.abs().max().clamp_min(1e-8)
                    display_audio = preview.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
                pitch_array.append(pitch.detach().cpu().numpy())
                audio_array.append(display_audio.numpy())

            grad_x = torch.autograd.grad(loss, denoised, allow_unused=True)[0]

        if pbar is not None:
            _set_postfix_merged(pbar, loss=f"{loss.item():.2f}")

        if grad_x is None:
            print("Warning: grad_x is None (no gradient path from pitch to denoised)")
            return

        with torch.no_grad():
            grad_norm = grad_x.norm()
            if not torch.isfinite(grad_norm):
                # Adding a NaN/Inf gradient would corrupt the latent for all later steps.
                print("Warning: non-finite pitch gradient; skipping guidance at this step")
                return
            if grad_norm > 0.1:
                grad_x = grad_x * (0.1 / grad_norm)  # clip to L2 norm 0.1
            # In-place so the sampler builds the next x from the guided latent.
            denoised.add_(-inner_strength * step_scale * grad_x)
    finally:
        # Hand the tensor back without grad tracking on every path, including
        # early returns and exceptions.
        denoised.requires_grad_(False)

# def cfg_scheduler(cfg_scale, step, p=.4, k =.5):
#     # return min(max(0.0, cfg_scale), (math.log(step) ** 2)/100)
#     return cfg_scale * math.exp(-k * (step / math.exp(p)))

def cfg_scheduler(cfg_scale, step, cutoff=5, middle_descend=2):
    """Piecewise CFG schedule: hold, ramp down linearly, then switch off.

    - ``step < cutoff``: returns ``cfg_scale`` unchanged.
    - ``cutoff <= step < cutoff + middle_descend``: returns
      ``cfg_scale * (middle_descend - (step - cutoff)) / middle_descend``.
      The value at ``step == cutoff`` is therefore still the full scale.
    - Otherwise: returns ``0.0``.
    """
    if step < cutoff:
        return cfg_scale  # full guidance
    in_ramp = step < cutoff + middle_descend
    if not in_ramp:
        return 0.0  # guidance switched off
    remaining = middle_descend - (step - cutoff)
    return cfg_scale * remaining / middle_descend

##### An alternative CFG scheduler I designed in Desmos
##### Note: I didn't end up using this one; the sampler uses the piecewise `cfg_scheduler` above
$$
\Large
\text{CFG}(step) = c \cdot e^{\left(-k \cdot \left(\frac{step}{e^{p}}\right)\right)}
$$
- **\( \text{CFG}(step) \)** — the scheduled classifier-free guidance scale at this sampling step  
- **\( c \)** — the initial CFG scale (your starting guidance strength)  
- **\( step \)** — the current diffusion iteration (1, 2, 3, …)  
- **\( k \)** — a decay constant controlling how quickly CFG falls toward zero  
- **\( p \)** — a shaping parameter that adjusts how sharply the decay curve bends  

In [101]:
# Monkey-patch stable-audio-tools' ping-pong sampler. The patched version passes
# the progress bar to the callback and supports a per-step CFG schedule.
@torch.no_grad()
def sample_flow_pingpong(model, x, steps=None, sigma_max=1, sigmas=None, callback=None, dist_shift=None, **extra_args):
    """Draw samples from a model given starting noise, using ping-pong sampling (distilled models).

    Args:
        model: velocity model called as ``model(x, t * ones, **extra_args)``.
        x: starting noise.
        steps: number of steps; used to build ``linspace(sigma_max, 0, steps + 1)``
            when ``sigmas`` is not given.
        sigma_max: starting time of the schedule.
        sigmas: explicit time schedule; overrides ``steps``/``sigma_max``.
        callback: called once per step with a dict containing ``x``, ``i``,
            ``t``, ``sigma``, ``sigma_hat``, ``denoised`` and ``pbar``. It may
            modify ``denoised`` in place; the next ``x`` is built from it.
        dist_shift: optional object whose ``time_shift(t, length)`` reshapes the schedule.
        **extra_args: forwarded to ``model`` on every step, except that
            ``cfg_scheduler`` is consumed here. If given, it is a callable
            ``(base_cfg, step) -> cfg`` evaluated with a 1-based step. Its
            ``cfg_scale`` is popped as ``base_cfg`` (default 0.0) and replaced
            by the scheduled value on each step.

    Returns:
        The final sample ``x``.
    """
    assert steps is not None or sigmas is not None, "Either steps or sigmas must be provided"

    # Tensor of ones used to broadcast the scalar t to the batch
    ts = x.new_ones([x.shape[0]])

    if sigmas is None:
        # Create the noise schedule
        t = torch.linspace(sigma_max, 0, steps + 1)
        if dist_shift is not None:
            t = dist_shift.time_shift(t, x.shape[-1])
    else:
        t = sigmas

    # Named `scheduler` so it does not shadow the notebook-level cfg_scheduler function.
    scheduler = extra_args.pop("cfg_scheduler", None)
    if scheduler is not None:
        base_cfg = float(extra_args.pop("cfg_scale", 0.0))

    pbar = trange(len(t) - 1, desc="Sampling", ncols=100)
    for i in pbar:
        if scheduler is not None:
            scheduled_cfg = scheduler(base_cfg, i + 1)
            extra_args["cfg_scale"] = scheduled_cfg  # replace cfg_scale with the scheduled value
            # set_postfix replaces the whole postfix, so merge it with existing entries
            # (e.g. the callback's loss), skipping fragments without "=".
            postfix = {}
            for item in (pbar.postfix or "").split(','):
                key, sep, value = item.partition('=')
                if sep:
                    postfix[key.strip()] = value.strip()
            postfix['cfg_scale'] = f"{scheduled_cfg:.2f}"
            pbar.set_postfix(postfix)

        denoised = x - t[i] * model(x, t[i] * ts, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 't': t[i], 'sigma': t[i], 'sigma_hat': t[i], 'denoised': denoised, 'pbar': pbar})

        # Ping-pong step: re-noise the (possibly guided) prediction to the next time.
        t_next = t[i + 1]
        x = (1 - t_next) * denoised + t_next * torch.randn_like(x)

    return x

sampling.sample_flow_pingpong = sample_flow_pingpong

In [102]:
# Text conditioning. seconds_total is the trimmed target length, so the generated
# clip lines up with target_pitch.
conditioning = [{
    # "prompt": "nylon guitar country sound",  # alternative prompt
    "prompt": "my saxophone cried a whale on a sunday",  # This prompt is quite bad on small, but small does work
    "seconds_total": time_sec
}]
# Snapshot buffers that pitch_callback appends to; reset on every run of this cell.
pitch_array = []
audio_array = []

# The autoencoder inside the model's pretransform; pitch_callback uses its decoder.
# NOTE(review): `._modules.get("model")` yields None (instead of raising) if the
# pretransform has no "model" submodule.
autoencoder = model._modules['pretransform']._modules.get("model")
# Guidance params (see pitch_callback):
#   base_step_scale - base guidance step size
#   alpha           - exponent of the (1 - t)^alpha time weight; larger values push guidance to late steps
#   inner_strength  - extra multiplier on the clipped gradient step
base_step_scale = 5
alpha = 10
inner_strength = 50
callback_wrapper = partial(pitch_callback, autoencoder, target_pitch, base_step_scale, alpha, inner_strength)
# The sampler passes a 1-based step. CFG therefore stays at full scale through
# step 20, is halved at step 21, and is 0 from step 22 on.
cfg_scheduler_wrapper = partial(cfg_scheduler, cutoff=20, middle_descend=2)

# Generate stereo audio; pitch guidance runs through the sampler callback.
output = generate_diffusion_cond(
    model,
    # Marco's Notes:
    # 7 steps works good for sao small, higher than that gets scary
    # If using normal sao higher steps is usually pretty good.
    # (This run uses 50 steps, well above that recommendation.)
    conditioning=conditioning,
    steps=50,
    cfg_scale=10, # Starting CFG; a scale of 1 is often good for small, higher works on normal
    sample_size=sample_size,
    # NOTE(review): with the rectified-flow "pingpong" sampler these sigma values
    # may be dropped or clamped inside stable-audio-tools — confirm they have any effect.
    sigma_min=10,
    sigma_max=300,
    # sampler_type="dpmpp-3m-sde",  # Use this for normal open
    sampler_type="pingpong",  # Use this for small (presumably dispatches to the patched sample_flow_pingpong)
    device=device,
    callback=callback_wrapper,
    # Assumed to be forwarded as an extra kwarg to the patched sampler, which pops it — confirm.
    cfg_scheduler=cfg_scheduler_wrapper
)

# Rearrange the audio batch into a single sequence: (batch, channels, samples) -> (channels, batch*samples)
output = rearrange(output, "b d n -> d (b n)")

1509537729
using cool


Sampling:   0%|                                                              | 0/50 [00:00<?, ?it/s]

##### Now you can hear the final audio

In [95]:
# Peak-normalize to [-1, 1] and convert to int16 for playback. The peak is
# clamped away from zero so an all-silent output doesn't produce NaNs, which
# would otherwise turn into garbage int16 samples.
output_f32 = output.to(torch.float32)
peak = output_f32.abs().max().clamp_min(1e-8)
cleaned_output = output_f32.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

Audio(cleaned_output.numpy(), rate=sample_rate)

In [None]:
# Clean up between runs: collect Python garbage and release cached GPU memory.
import gc

import torch

print(gc.collect())  # number of unreachable objects found
if torch.cuda.is_available():
    # Guarded because torch.cuda.synchronize() raises when CUDA is unavailable
    # (e.g. the 'mps'/'cpu' device fallbacks chosen at the top of the notebook).
    torch.cuda.empty_cache()  # return cached, unused blocks to the driver
    torch.cuda.synchronize()  # wait for queued kernels to finish

Sounds like garbage, right? (At least it will if you ran more than about 10 steps.)
Instead, let's launch Gradio and look at what the audio sounded like at intermediate points during sampling.

In [None]:
# Build a Gradio app for browsing the snapshots saved during sampling
# (pitch_array / audio_array) next to the target pitch and audio.
demo = create_gradio_interface(pitch_array, target_pitch, sample_rate, audio_array, target_audio)

In [None]:
# Serve the app on port 7860 in a separate browser tab (inline=False).
# NOTE(review): server_name="0.0.0.0" binds to all network interfaces, so anyone who
# can reach this machine can open the demo; use "127.0.0.1" unless remote access is intended.
demo.launch(server_name="0.0.0.0", server_port=7860, inline=False)

In [None]:
# Stop the Gradio server started above.
demo.close()

In [92]:
# Close every Gradio app still running (e.g. ones whose `demo` handle was
# overwritten by re-running a cell), not just the current one.
import gradio as gr
gr.close_all()

You can also view just the pitch frames overlaid on `target_pitch`, and then pick an audio frame to listen to in the following cell.

In [None]:
# Re-plot the target pitch on its own for reference.
plot_pitch(target_pitch, sample_rate)

In [96]:
# Animate the saved pitch snapshots (pitch_array) against target_pitch
# (helper from audio_helpers).
animate_pitch_arrays(pitch_array, sample_rate, target_pitch)

In [None]:
# Pick a saved snapshot to play. pitch_callback appends one on every step where
# step % 10 == 0 (by default), so frame k is sampling step 10*k. Frame 4 exists
# only if sampling ran at least 41 steps; otherwise this raises IndexError.
frame_number = 4
Audio(audio_array[frame_number], rate=sample_rate)

To compare against the original audio, run this cell:

In [None]:
# Play the trimmed target clip for comparison with the generated output.
Audio(target_audio.numpy(), rate=sample_rate)