In [1]:
# Imports
import torch
import torchaudio
import torchcrepe
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

import numpy as np
import pretty_midi
import sounddevice as sd
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import soundfile as sf
from torch.nn import MSELoss

# Custom Helpers
from audio_helpers import get_github_audio, plot_pitch, plot_pitch_comparison, animate_pitch_arrays, create_gradio_interface

# Pick the best available backend: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using {}".format(device))

Using cached SoundFont: soundfonts/FluidR3Mono_GM.sf3
Using mps


In [2]:
@torch.enable_grad()
def calculate_pitch(audio, sample_rate):
    """Estimate a pitch contour for `audio` with torchcrepe, keeping autograd enabled.

    Args:
        audio: waveform as a NumPy array or torch tensor. Input with more than
            one dimension is averaged over dim 0 (kept as a length-1 axis);
            1-D input gets a leading length-1 axis so both cases reach
            torchcrepe in the same layout.
        sample_rate: sampling rate of `audio` in Hz.

    Returns:
        The pitch tensor returned by `torchcrepe.predict` after periodicity
        thresholding and mean filtering. Callers in this notebook pass the
        result through `torch.nan_to_num`, so low-confidence frames are
        presumably NaN -- confirm against the installed torchcrepe build.

    Notes:
        Reads the notebook-level `device`. The `@torch.enable_grad()` decorator
        re-enables gradient tracking even when the caller is running under
        `torch.no_grad()`.
    """
    # Always end up with a float32 tensor on `device`, whatever the input rank.
    if isinstance(audio, np.ndarray):
        audio = torch.tensor(audio, device=device, dtype=torch.float32)
    else:
        audio = audio.to(device=device, dtype=torch.float32)

    if audio.ndim == 1:
        # Bare waveform: add the leading axis the multi-dim path produces.
        audio = audio.unsqueeze(0)
    elif audio.ndim > 1:
        # Collapse dim 0 by averaging.
        audio = audio.mean(dim=0, keepdim=True)

    hop = int(sample_rate / 200.)  # 5 ms hop
    pitch, periodicity = torchcrepe.predict(
        audio, sample_rate, hop_length=hop, fmin=50, fmax=550,
        model='tiny',  # or 'full'
        batch_size=2048, device=device, return_periodicity=True,
        decoder=torchcrepe.decode.soft_argmax, differentiable=True)

    # Clean up pitch: smooth the periodicity track, apply the -60 silence
    # threshold, apply the 0.5 periodicity threshold to the pitch, then
    # mean-filter the pitch track.
    win_l = 3
    periodicity = torchcrepe.filter.median(periodicity, win_l)
    periodicity = torchcrepe.threshold.Silence(-60.)(periodicity, audio, sample_rate, hop)
    pitch = torchcrepe.threshold.At(.5)(pitch, periodicity)
    pitch = torchcrepe.filter.mean(pitch, win_l)
    return pitch

In [3]:
# Download model | Stable Audio Open Small
# `https://huggingface.co/stabilityai/stable-audio-open-small`
pretrained_name = "stabilityai/stable-audio-open-small"
model, model_config = get_pretrained_model(pretrained_name)

# Audio rate and per-generation sample count come straight from the model config.
sample_rate, sample_size = model_config["sample_rate"], model_config["sample_size"]

# Move the weights to the backend selected earlier.
model = model.to(device)

No module named 'flash_attn'
flash_attn not installed, disabling Flash Attention


  WeightNorm.apply(module, name, dim)


In [4]:
# if you don't have the audio file download it
get_github_audio("https://github.com/pdx-cs-sound/wavs/raw/refs/heads/main/gc.wav")

target_audio, target_sr = torchaudio.load('data/audio/gc.wav')
if sample_rate != target_sr:  # Resample to model rate
    # Resample takes (orig_freq, new_freq): the file's rate comes first and the
    # model's rate second. The previous call had them swapped.
    resampler = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=sample_rate)
    target_audio = resampler(target_audio)

# Reduce to this really specific time that stable audio open small has
time_sec = 11.888616780045352
target_audio = target_audio[:, :int(time_sec * sample_rate)]
# Reference pitch contour that the guidance callbacks compare against.
target_pitch = calculate_pitch(target_audio, sample_rate)

print(f"Target length is: {target_audio.shape[1] / sample_rate}")

Downloading gc.wav...
Saved to data/audio/gc.wav
Target length is: 11.888616780045352


In [80]:
from functools import partial

def pitch_callback(autoencoder, target_pitch, base_step_scale, alpha, inner_strength, in_dict,
                   *, loss_threshold=13800, refine_iters=30):
    """Sampler callback that nudges the denoised latents toward a target pitch contour.

    Each call decodes `in_dict['denoised']` to audio, smooths it, estimates the
    pitch with `calculate_pitch`, and takes gradient steps on `denoised`
    (modified in place) to reduce the MSE between the estimated and target
    pitch, with NaNs replaced by 0 on both sides. Normally one step is taken
    per call. If the first loss of a call is at or below `loss_threshold`,
    `refine_iters` steps are taken in that call instead.

    Relies on notebook globals: `sample_rate`, `calculate_pitch`, and the lists
    `pitch_array` / `audio_array`, which receive one pitch / int16-audio
    snapshot per call (taken on the last inner iteration).

    Args:
        autoencoder: object whose `.decoder` maps latents to audio.
        target_pitch: pitch tensor to match.
        base_step_scale: base multiplier for the update size.
        alpha: exponent of the `(1 - t) ** alpha` time weighting.
        inner_strength: additional multiplier on every update.
        in_dict: sampler state; reads `'denoised'` and `'t'`. `t` is assumed
            to be a single-element tensor (it is converted with `float()`).
        loss_threshold: loss at or below which the multi-step refinement is
            triggered. Defaults to the previously hard-coded 13800.
        refine_iters: number of updates in a refinement burst (expected >= 1).
            Defaults to the previously hard-coded 30.

    Returns:
        None. Returns early, leaving `denoised` unchanged, when autograd
        reports no gradient path from the loss to `denoised`.
    """
    denoised = in_dict['denoised']
    t = in_dict['t'].detach().float()
    refinement_pending = True  # the refinement burst may fire at most once per call

    # PnP-Flow schedule: (1 - t)^alpha
    time_weight = (1.0 - t) ** alpha
    step_scale = base_step_scale * time_weight
    print(f"time_weight = {float(time_weight):.4f}, step_scale = {float(step_scale):.4f}, inner_strength = {float(inner_strength):.2f}")

    kernel_size = 64     # width of the moving-average smoothing applied to the audio
    max_grad_norm = 0.1  # gradients with a larger norm are rescaled to this norm
    loss_fn = MSELoss()

    counter = 1
    print('--')
    with torch.enable_grad():
        # Loop-invariant: NaN-free version of the target contour.
        clean_target = torch.nan_to_num(target_pitch, nan=0.)

        while counter >= 1:
            denoised.requires_grad_(True)

            # Decoder runs in half precision; cast back to float32 afterwards.
            pred_audio = autoencoder.decoder(denoised.half()).float()

            # Moving-average smoothing; padding adds one extra sample, so trim
            # back to the original length.
            n_samples = pred_audio.shape[-1]
            pred_audio = torch.nn.functional.avg_pool1d(
                pred_audio, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
            )
            pred_audio = pred_audio[..., :n_samples]

            # Fold the batch into the time axis, then average dim 0 for pitch tracking.
            multi_audio = rearrange(pred_audio, "b d n -> d (b n)")
            mono_audio = multi_audio.mean(dim=0, keepdim=True)
            pitch = calculate_pitch(mono_audio, sample_rate)

            loss = loss_fn(torch.nan_to_num(pitch, nan=0.), clean_target)

            # find a good starting point
            if counter == 1 and refinement_pending and loss.item() <= loss_threshold:
                print("Found good starting pitch beginning iterations")
                counter = refine_iters  # do this many times once
                refinement_pending = False

            if not refinement_pending:
                print(f"Iteration: {counter}")

            # Record a snapshot on the last inner iteration of this call.
            if counter == 1:
                with torch.no_grad():
                    # Peak-normalise to int16 (how stable audio converted it).
                    # The clamp avoids a divide-by-zero on silent audio.
                    peak = multi_audio.abs().max().clamp_min(1e-8)
                    display_audio = (
                        multi_audio.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
                    )
                # These lists live in the notebook namespace so later cells can use them.
                pitch_array.append(pitch.detach().cpu().numpy())
                audio_array.append(display_audio.numpy())

            grad_x = torch.autograd.grad(loss, denoised, allow_unused=True)[0]

            if grad_x is None:
                print("Warning: grad_x is None (no gradient path from pitch to denoised)")
                # Do not hand the sampler back a tensor that still tracks gradients.
                denoised.requires_grad_(False)
                return

            with torch.no_grad():
                grad_norm = grad_x.norm()
                if torch.isfinite(grad_norm):
                    if grad_norm > max_grad_norm:
                        grad_x = grad_x * (max_grad_norm / grad_norm)
                    denoised.add_(-inner_strength * step_scale * grad_x)
                else:
                    # A NaN/inf gradient would corrupt the latents; skip this update.
                    print("Warning: non-finite gradient, skipping update")
                denoised.requires_grad_(False)

                print(f"loss={loss.item():.6f}")

            counter -= 1

In [81]:
# Text prompt and clip duration handed to the model as conditioning.
conditioning = [dict(
    prompt="nylon guitar country sound",  # This prompt is quite bad on small, but small does work
    seconds_total=time_sec,
)]

# Snapshot lists that pitch_callback appends to during sampling.
pitch_array, audio_array = [], []

autoencoder = model._modules['pretransform']._modules.get("model")
# Params
base_step_scale, alpha, inner_strength = 2, 1, 50
callback_wrapper = partial(pitch_callback, autoencoder, target_pitch, base_step_scale, alpha, inner_strength)

# callback_wrapper = partial(pitch_guidance_callback, autoencoder, target_pitch, 50)

# Marco's Notes:
# 7 steps works good for sao small, higher than that gets scary
# If using normal sao higher steps is usually pretty good.
sampler_settings = dict(
    conditioning=conditioning,
    steps=30,
    cfg_scale=1.,  # Config of 1 often good for small, higher works on normal
    sample_size=sample_size,
    sigma_min=10,
    sigma_max=300,
    # sampler_type="dpmpp-3m-sde",  # Use this for normal open
    sampler_type="pingpong",  # Use this for small
    device=device,
    callback=callback_wrapper,
)

# Generate stereo audio
output = generate_diffusion_cond(model, **sampler_settings)

# Rearrange audio batch to a single sequence
output = rearrange(output, "b d n -> d (b n)")

3377606376


  0%|          | 0/30 [00:00<?, ?it/s]

time_weight = 0.0000, step_scale = 0.0000, inner_strength = 50.00
--


  3%|▎         | 1/30 [00:01<00:56,  1.95s/it]

loss=13954.159180
time_weight = 0.0032, step_scale = 0.0065, inner_strength = 50.00
--


  7%|▋         | 2/30 [00:03<00:45,  1.63s/it]

loss=14234.647461
time_weight = 0.0042, step_scale = 0.0084, inner_strength = 50.00
--


 10%|█         | 3/30 [00:04<00:41,  1.52s/it]

loss=14098.815430
time_weight = 0.0055, step_scale = 0.0110, inner_strength = 50.00
--


 13%|█▎        | 4/30 [00:06<00:38,  1.49s/it]

loss=14072.356445
time_weight = 0.0072, step_scale = 0.0143, inner_strength = 50.00
--


 17%|█▋        | 5/30 [00:07<00:36,  1.46s/it]

loss=13947.341797
time_weight = 0.0093, step_scale = 0.0186, inner_strength = 50.00
--
Found good starting pitch beginning iterations
Iteration: 30
loss=13726.415039
Iteration: 29
loss=13729.912109
Iteration: 28
loss=13743.445312
Iteration: 27
loss=13751.296875
Iteration: 26
loss=13747.855469
Iteration: 25
loss=13795.323242
Iteration: 24
loss=13708.367188
Iteration: 23
loss=13698.061523
Iteration: 22
loss=13691.707031
Iteration: 21
loss=13694.344727
Iteration: 20
loss=13659.861328
Iteration: 19
loss=13603.322266
Iteration: 18
loss=13665.824219
Iteration: 17
loss=13669.346680
Iteration: 16
loss=13587.822266
Iteration: 15
loss=13599.118164
Iteration: 14
loss=13640.424805
Iteration: 13
loss=13652.423828
Iteration: 12
loss=13651.112305
Iteration: 11
loss=13628.159180
Iteration: 10
loss=13609.585938
Iteration: 9
loss=13604.594727
Iteration: 8
loss=13613.001953
Iteration: 7
loss=13610.200195
Iteration: 6
loss=13640.562500
Iteration: 5
loss=13589.104492
Iteration: 4
loss=13580.833984
Iteratio

 20%|██        | 6/30 [00:48<05:53, 14.71s/it]

loss=13578.734375
time_weight = 0.0121, step_scale = 0.0243, inner_strength = 50.00
--


 23%|██▎       | 7/30 [00:49<03:58, 10.35s/it]

loss=15073.933594
time_weight = 0.0158, step_scale = 0.0316, inner_strength = 50.00
--


 27%|██▋       | 8/30 [00:50<02:44,  7.49s/it]

loss=14157.361328
time_weight = 0.0205, step_scale = 0.0410, inner_strength = 50.00
--


 30%|███       | 9/30 [00:52<01:56,  5.56s/it]

loss=14191.029297
time_weight = 0.0266, step_scale = 0.0532, inner_strength = 50.00
--


 33%|███▎      | 10/30 [00:53<01:25,  4.26s/it]

loss=14307.361328
time_weight = 0.0344, step_scale = 0.0689, inner_strength = 50.00
--


 37%|███▋      | 11/30 [00:54<01:03,  3.36s/it]

loss=13960.926758
time_weight = 0.0445, step_scale = 0.0890, inner_strength = 50.00
--


 40%|████      | 12/30 [00:56<00:49,  2.74s/it]

loss=13919.247070
time_weight = 0.0573, step_scale = 0.1146, inner_strength = 50.00
--


 43%|████▎     | 13/30 [00:57<00:39,  2.31s/it]

loss=14107.547852
time_weight = 0.0736, step_scale = 0.1471, inner_strength = 50.00
--


 47%|████▋     | 14/30 [00:58<00:32,  2.01s/it]

loss=15685.587891
time_weight = 0.0939, step_scale = 0.1878, inner_strength = 50.00
--


 50%|█████     | 15/30 [01:00<00:27,  1.80s/it]

loss=14846.633789
time_weight = 0.1192, step_scale = 0.2384, inner_strength = 50.00
--


 53%|█████▎    | 16/30 [01:01<00:23,  1.67s/it]

loss=13896.565430
time_weight = 0.1502, step_scale = 0.3003, inner_strength = 50.00
--


 57%|█████▋    | 17/30 [01:02<00:20,  1.55s/it]

loss=14026.687500
time_weight = 0.1874, step_scale = 0.3749, inner_strength = 50.00
--


 60%|██████    | 18/30 [01:03<00:17,  1.48s/it]

loss=14026.687500
time_weight = 0.2315, step_scale = 0.4630, inner_strength = 50.00
--


 63%|██████▎   | 19/30 [01:05<00:15,  1.43s/it]

loss=14026.687500
time_weight = 0.2822, step_scale = 0.5645, inner_strength = 50.00
--


 67%|██████▋   | 20/30 [01:06<00:13,  1.39s/it]

loss=14293.677734
time_weight = 0.3392, step_scale = 0.6785, inner_strength = 50.00
--
Found good starting pitch beginning iterations
Iteration: 30
loss=13028.583008
Iteration: 29
loss=13141.450195
Iteration: 28
loss=13041.323242
Iteration: 27
loss=13326.332031
Iteration: 26
loss=18504.062500
Iteration: 25
loss=18642.794922
Iteration: 24
loss=18705.371094
Iteration: 23
loss=18839.251953
Iteration: 22
loss=18886.009766
Iteration: 21
loss=18843.224609
Iteration: 20
loss=18704.697266
Iteration: 19
loss=18369.351562
Iteration: 18
loss=18624.544922
Iteration: 17
loss=18948.015625
Iteration: 16
loss=18958.521484
Iteration: 15
loss=18682.402344
Iteration: 14
loss=18798.525391
Iteration: 13
loss=18607.890625
Iteration: 12
loss=18302.781250
Iteration: 11
loss=17785.712891
Iteration: 10
loss=17161.939453
Iteration: 9
loss=16680.158203
Iteration: 8
loss=15689.381836
Iteration: 7
loss=14834.624023
Iteration: 6
loss=14617.749023
Iteration: 5
loss=14355.282227
Iteration: 4
loss=14072.061523
Iteratio

 70%|███████   | 21/30 [01:44<01:51, 12.43s/it]

loss=13529.489258
time_weight = 0.4013, step_scale = 0.8026, inner_strength = 50.00
--


 73%|███████▎  | 22/30 [01:46<01:13,  9.14s/it]

loss=18374.810547
time_weight = 0.4667, step_scale = 0.9334, inner_strength = 50.00
--


 77%|███████▋  | 23/30 [01:47<00:47,  6.78s/it]

loss=18715.005859
time_weight = 0.5333, step_scale = 1.0666, inner_strength = 50.00
--


 80%|████████  | 24/30 [01:48<00:30,  5.14s/it]

loss=16218.311523
time_weight = 0.5987, step_scale = 1.1974, inner_strength = 50.00
--


 83%|████████▎ | 25/30 [01:50<00:19,  3.99s/it]

loss=72056.812500
time_weight = 0.6608, step_scale = 1.3215, inner_strength = 50.00
--


 87%|████████▋ | 26/30 [01:51<00:12,  3.17s/it]

loss=14583.484375
time_weight = 0.7178, step_scale = 1.4355, inner_strength = 50.00
--


 90%|█████████ | 27/30 [01:52<00:07,  2.60s/it]

loss=14026.687500
time_weight = 0.7685, step_scale = 1.5370, inner_strength = 50.00
--


 93%|█████████▎| 28/30 [01:53<00:04,  2.20s/it]

loss=14026.687500
time_weight = 0.8126, step_scale = 1.6251, inner_strength = 50.00
--


 97%|█████████▋| 29/30 [01:55<00:01,  1.92s/it]

loss=14026.687500
time_weight = 0.8498, step_scale = 1.6997, inner_strength = 50.00
--
Found good starting pitch beginning iterations
Iteration: 30
loss=13733.216797
Iteration: 29
loss=13809.914062
Iteration: 28
loss=13845.854492
Iteration: 27
loss=13629.458008
Iteration: 26
loss=13918.074219
Iteration: 25
loss=14009.900391
Iteration: 24
loss=13527.650391
Iteration: 23
loss=13829.571289
Iteration: 22
loss=13706.516602
Iteration: 21
loss=13936.793945
Iteration: 20
loss=13923.268555
Iteration: 19
loss=14026.687500
Iteration: 18
loss=14026.687500
Iteration: 17
loss=14026.687500
Iteration: 16
loss=14026.687500
Iteration: 15
loss=14026.687500
Iteration: 14
loss=14026.687500
Iteration: 13
loss=14026.687500
Iteration: 12
loss=14026.687500
Iteration: 11
loss=14026.687500
Iteration: 10
loss=14026.687500
Iteration: 9
loss=14026.687500
Iteration: 8
loss=14026.687500
Iteration: 7
loss=14026.687500
Iteration: 6
loss=14026.687500
Iteration: 5
loss=14026.687500
Iteration: 4
loss=14026.687500
Iteratio

100%|██████████| 30/30 [02:31<00:00,  5.04s/it]

loss=14026.687500





##### Now you can hear the final audio

In [82]:
# Peak normalize, convert to int16
output_peak = torch.max(torch.abs(output))
normalized_output = output.to(torch.float32) / output_peak
cleaned_output = normalized_output.clamp(-1, 1).mul(32767).to(torch.int16).cpu()

# Clip length
#clipped_output = cleaned_output[..., :int(sample_rate * total_seconds)]

# Last expression of the cell, so the audio player is rendered as its output.
Audio(cleaned_output.numpy(), rate=sample_rate)

Sounds like garbage, right? (Well, that's if you took more than about 10 steps.)
Instead, let's run Gradio and see what some of the audio looked like during the journey.

In [83]:
# Build the Gradio interface from the per-step pitch/audio snapshots collected by
# the callback, together with the target pitch and target audio for reference.
# (create_gradio_interface comes from audio_helpers; its internals are not shown here.)
demo = create_gradio_interface(pitch_array, target_pitch, sample_rate, audio_array, target_audio)

In [84]:
# Start the interface; the local URL it is served on is printed in the cell output.
demo.launch()

* Running on local URL:  http://127.0.0.1:7861
* To create a public link, set `share=True` in `launch()`.




In [None]:
# Shut the `demo` interface down once you are done with it.
demo.close()

In [None]:
# Fallback cleanup: use this to close any Gradio interfaces that are still open
# (for example when `demo` no longer refers to the one that is running).
import gradio as gr
gr.close_all()

You can also view just the pitch frames overlaid with the `target_pitch` and then select the audio frame in the next cell

In [None]:
# Animate the recorded pitch snapshots against the target pitch.
# NOTE(review): animate_pitch_arrays is an audio_helpers function; its exact
# rendering behaviour is not visible in this notebook.
animate_pitch_arrays(pitch_array, sample_rate, target_pitch)

In [None]:
# Snapshot to listen to. audio_array holds one entry per callback invocation, so
# a fixed index such as 35 is out of range when fewer snapshots were recorded.
# Clamp it to the last available one.
if not audio_array:
    raise RuntimeError("audio_array is empty - run the generation cell first")
frame_number = min(35, len(audio_array) - 1)
Audio(audio_array[frame_number], rate=sample_rate)

If you want to compare to the original audio, run this.

In [None]:
# Also display original audio: the target clip as loaded (and possibly resampled)
# and trimmed to `time_sec` earlier in the notebook.
Audio(target_audio.numpy(), rate=sample_rate)

In [None]:
import torch.nn.functional as F

def pitch_guidance_callback(autoencoder, target_pitch, strength, in_dict):
    """Sampler callback: ten gradient steps pulling the denoised latents toward `target_pitch`.

    Each inner step decodes `in_dict['denoised']` to audio, smooths it,
    estimates pitch with `calculate_pitch`, and moves `denoised` (in place)
    down the gradient of the MSE between estimated and target pitch, with
    NaNs replaced by 0 on both sides.

    Relies on the notebook globals `sample_rate` and `calculate_pitch`.

    Args:
        autoencoder: object whose `.decoder` maps latents to audio.
        target_pitch: pitch tensor to match, or None to skip guidance.
        strength: overall guidance strength; each inner step uses `strength / 10`.
        in_dict: sampler state; reads `'t'` and `'denoised'`.

    Returns:
        None. `denoised` is updated in place.
    """
    t, denoised = in_dict['t'], in_dict['denoised']
    if target_pitch is None:
        print(f"t = {t:.3f}: target_pitch is None. Skipping pitch guidance callback")
        return

    inner_strength = strength / 10  # or /5, adjust based on testing
    # Time weighting, constant across the inner steps.
    # Goes from 0.3 -> 1.0 as t goes from 1 to 0.
    gamma_t = 0.3 + 0.7 * (1 - t)
    kernel_size = 64     # width of the moving-average smoothing applied to the audio
    max_grad_norm = 0.1  # gradients with a larger norm are rescaled to this norm

    print('--')
    with torch.enable_grad():
        # Loop-invariant: NaN-free version of the target contour.
        clean_target = torch.nan_to_num(target_pitch, nan=0)

        for step in range(10):
            denoised.requires_grad_(True)  # Enable grad on denoised itself

            # Decoder runs in half precision; cast back to float32 before pooling,
            # matching pitch_callback.
            pred_audio = autoencoder.decoder(denoised.half()).float()

            # Moving-average smoothing; padding adds one extra sample, so trim
            # back to the original length.
            n_samples = pred_audio.shape[-1]
            pred_audio = torch.nn.functional.avg_pool1d(
                pred_audio, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
            )
            pred_audio = pred_audio[..., :n_samples]

            pred_audio = rearrange(pred_audio, "b d n -> d (b n)")
            pred_pitch = calculate_pitch(pred_audio, sample_rate)

            loss = F.mse_loss(torch.nan_to_num(pred_pitch, nan=0), clean_target)
            grad = torch.autograd.grad(loss, denoised)[0]

            # Clip and update
            with torch.no_grad():
                grad_norm = grad.norm()
                if torch.isfinite(grad_norm):
                    if grad_norm > max_grad_norm:
                        grad = grad * (max_grad_norm / grad_norm)
                    denoised -= inner_strength * gamma_t * grad
                # A non-finite gradient would corrupt the latents, so it is skipped.
                denoised.requires_grad_(False)

            print(f"  Inner step {step}: loss={loss.item():.6f}")

