In [None]:
# Imports
import torch
import torchaudio
import torchcrepe
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

import numpy as np
import pretty_midi
import sounddevice as sd
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import soundfile as sf
from torch.nn import MSELoss

# Custom Helpers
from audio_helpers import get_github_audio, plot_pitch, plot_pitch_comparison, animate_pitch_arrays, create_gradio_interface

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using {}".format(device))

In [None]:
@torch.enable_grad()
def calculate_pitch(audio, sample_rate):
    """Estimate a differentiable, cleaned-up pitch contour with torchcrepe.

    Grad mode is forced on by ``@torch.enable_grad()``, so the result stays
    connected to ``audio``'s autograd graph. This holds even when the function
    is called from inside a ``no_grad`` context such as a sampler callback.

    Args:
        audio: waveform as a numpy array or a torch tensor. A 1-D input is
            treated as one mono signal. Otherwise dim 0 (presumably channels)
            is averaged down to a single channel.
        sample_rate: sampling rate of ``audio`` in Hz.

    Returns:
        Pitch tensor from ``torchcrepe.predict`` with one frame per ~5 ms hop.
        Some frames are rejected by ``torchcrepe.threshold.At`` before the
        final mean filter:
        - frames whose median-filtered periodicity is below 0.5;
        - frames quieter than -60 dB (via ``threshold.Silence``).
    """
    # torchcrepe expects a float32 tensor shaped (1, num_samples) on `device`.
    # Convert unconditionally so 1-D tensors are also moved/cast, and add the
    # leading batch axis that a 1-D signal lacks.
    if isinstance(audio, np.ndarray):
        audio = torch.tensor(audio, dtype=torch.float32)
    audio = audio.to(device=device, dtype=torch.float32)
    if audio.ndim == 1:
        audio = audio.unsqueeze(0)
    elif audio.ndim > 1:
        audio = audio.mean(dim=0, keepdim=True)

    hop = int(sample_rate / 200.)  # 5 ms hop
    # NOTE(review): `differentiable=True` and `decode.soft_argmax` presumably
    # require a torchcrepe build that provides them; confirm the installed version.
    pitch, periodicity = torchcrepe.predict(
        audio, sample_rate, hop_length=hop, fmin=50, fmax=550,
        model='tiny',  # or 'full'
        batch_size=2048, device=device, return_periodicity=True,
        decoder=torchcrepe.decode.soft_argmax, differentiable=True,
    )

    # Clean up the contour in four passes:
    # - smooth periodicity;
    # - suppress it in frames below -60 dB;
    # - drop frames with periodicity < 0.5;
    # - smooth what is left of the pitch.
    win_l = 3
    periodicity = torchcrepe.filter.median(periodicity, win_l)
    periodicity = torchcrepe.threshold.Silence(-60.)(periodicity, audio, sample_rate, hop)
    pitch = torchcrepe.threshold.At(.5)(pitch, periodicity)
    pitch = torchcrepe.filter.mean(pitch, win_l)
    return pitch

In [None]:
# Stable Audio Open Small: https://huggingface.co/stabilityai/stable-audio-open-small
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")
model = model.to(device)

# Output sample rate and the generation length later passed to generate_diffusion_cond.
sample_rate, sample_size = model_config["sample_rate"], model_config["sample_size"]

In [None]:
# Download the reference recording if you don't already have it.
# (Presumably saved to data/audio/gc.wav by get_github_audio; confirm.)
get_github_audio("https://github.com/pdx-cs-sound/wavs/raw/refs/heads/main/gc.wav")

target_audio, target_sr = torchaudio.load('data/audio/gc.wav')
if target_sr != sample_rate:
    # Resample(orig_freq, new_freq): convert FROM the file's rate TO the model's rate.
    resampler = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=sample_rate)
    target_audio = resampler(target_audio)

# Trim to the fixed generation length of Stable Audio Open Small.
# 11.888616780045352 s == 524288 / 44100, presumably the model's sample_size at
# 44.1 kHz (confirm against model_config). round() avoids losing a sample to
# float truncation when converting back to a sample count.
time_sec = 11.888616780045352
target_audio = target_audio[:, :int(round(time_sec * sample_rate))]
target_pitch = calculate_pitch(target_audio, sample_rate)

print(f"Target length is: {target_audio.shape[1] / sample_rate}")

In [None]:
from functools import partial

# Per-step history appended to by pitch_callback and read by the visualisation cells.
pitch_array = []
audio_array = []

def pitch_callback(autoencoder, target_pitch, base_step_scale, alpha, in_dict):
    """Sampler callback that nudges ``in_dict['denoised']`` toward ``target_pitch``.

    Meant to be bound with ``functools.partial`` so the sampler only supplies
    ``in_dict``. On each call it does four things:
    - decodes the current ``denoised`` latent to audio;
    - appends the step's pitch contour to ``pitch_array``;
    - appends an int16 copy of the audio to ``audio_array``;
    - takes one clipped gradient-descent step on the MSE between the decoded
      pitch and ``target_pitch``.

    Args:
        autoencoder: module whose ``.decoder`` maps latents to audio (called in fp16).
        target_pitch: reference pitch contour, e.g. from ``calculate_pitch``.
        base_step_scale: base guidance step size.
        alpha: exponent of the ``(1 - t) ** alpha`` step-size schedule.
        in_dict: sampler state; reads ``'denoised'`` and ``'t'``.

    The update is applied in place. ``.detach()`` shares storage with
    ``in_dict['denoised']``, so the change is visible through ``in_dict``.
    Whether the sampler re-reads it afterwards depends on the sampler.
    """
    denoised = in_dict['denoised'].detach()
    t = in_dict['t'].detach().float()

    with torch.enable_grad():
        denoised.requires_grad_(True)

        # PnP-Flow schedule: (1 - t)^alpha -- for alpha > 0 the step grows as t -> 0.
        time_weight = (1.0 - t) ** alpha
        step_scale = base_step_scale * time_weight

        audio = autoencoder.decoder(denoised.half()).float()
        audio = rearrange(audio, "b d n -> d (b n)")
        print("Generated Audio shape:", audio.shape, f"Generated Audio length: {(audio.shape[1]/sample_rate):.2f}")

        # int16 copy for listening, peak-normalised like the final output.
        # clamp_min avoids 0/0 -> NaN if the decode is completely silent.
        peak = audio.detach().abs().max().clamp_min(1e-8)
        display_audio = audio.detach().div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

        # Pitch of the mono mix; stays differentiable w.r.t. denoised.
        pitch = calculate_pitch(audio.mean(dim=0, keepdim=True), sample_rate)

        # Record this step for later inspection in the notebook.
        pitch_array.append(pitch.detach().cpu().numpy())
        audio_array.append(display_audio.numpy())

        # Unvoiced frames are NaN; compare them as 0 Hz in both contours.
        loss = MSELoss()(torch.nan_to_num(pitch, nan=0.),
                         torch.nan_to_num(target_pitch, nan=0.))

        # allow_unused=True returns None (instead of raising) if loss ignores denoised.
        grad = torch.autograd.grad(loss, denoised, allow_unused=True)[0]

    with torch.no_grad():
        if grad is not None:
            # Clip to norm <= 0.1 so one step cannot blow up the latent.
            grad_norm = grad.norm()
            if grad_norm > 0.1:
                grad = grad * (0.1 / grad_norm)
            # Gradient DESCENT on the pitch loss: subtract to move toward the target.
            denoised -= step_scale * grad
        denoised.requires_grad_(False)

    print(f"loss={loss.item():.6f}")

In [None]:
import torch.nn.functional as F

def pitch_guidance_callback(autoencoder, target_pitch, strength, in_dict):
    """Sampler callback that runs several pitch-guidance steps on ``denoised``.

    Guidance uses the torchcrepe pitch (``calculate_pitch``) of the decoded
    audio; it does not use chroma features. Each of the 10 inner steps:
    - decodes ``in_dict['denoised']``;
    - smooths the audio with a moving average;
    - compares its pitch to ``target_pitch`` with MSE;
    - takes a clipped gradient-descent step that updates
      ``in_dict['denoised']`` in place.

    Args:
        autoencoder: module whose ``.decoder`` maps latents to audio (called in fp16).
        target_pitch: reference pitch contour; if ``None`` the callback does nothing.
        strength: overall guidance strength, divided by the number of inner steps.
        in_dict: sampler state; reads ``'t'`` and ``'denoised'``.
    """
    t, denoised = in_dict['t'], in_dict['denoised']
    if target_pitch is None:
        # float() also accepts 1-element tensors, which f"{t:.3f}" would reject.
        print(f"t = {float(t):.3f}: target_pitch is None. Skipping pitch guidance callback")
        return

    inner_steps = 10
    inner_strength = strength / inner_steps  # or /5, adjust based on testing
    kernel_size = 64  # moving-average length in samples (crude low-pass)
    # Weight that rises as t falls: 0.3 at t = 1, 1.0 at t = 0.
    gamma_t = 0.3 + 0.7 * (1 - t)

    print('--')
    with torch.enable_grad():
        for step in range(inner_steps):
            denoised.requires_grad_(True)
            pred_audio = autoencoder.decoder(denoised.half())
            n_samples = pred_audio.shape[-1]
            # Moving average; padding adds one extra sample, trimmed back off.
            pred_audio = F.avg_pool1d(
                pred_audio, kernel_size=kernel_size, stride=1, padding=kernel_size // 2
            )[..., :n_samples]
            pred_audio = rearrange(pred_audio, "b d n -> d (b n)")
            pred_pitch = calculate_pitch(pred_audio, sample_rate)

            # Unvoiced frames are NaN; compare them as 0 Hz.
            loss = F.mse_loss(
                torch.nan_to_num(pred_pitch, nan=0),
                torch.nan_to_num(target_pitch, nan=0),
            )
            grad = torch.autograd.grad(loss, denoised)[0]

            # Clip to norm <= 0.1, then descend in place on the sampler's tensor.
            with torch.no_grad():
                grad_norm = grad.norm()
                if grad_norm > 0.1:
                    grad = grad * (0.1 / grad_norm)
                denoised -= inner_strength * gamma_t * grad
                denoised.requires_grad_(False)

            print(f"  Inner step {step}: loss={loss.item():.6f}")



In [None]:
# Text conditioning; the duration matches the trimmed target recording.
conditioning = [{
    "prompt": "nylon guitar country sound",  # weak prompt on the small model, but small does work
    "seconds_total": time_sec
}]

# Start a fresh per-step history for the callback to append to.
pitch_array, audio_array = [], []

# Guidance hyper-parameters for pitch_callback.
base_step_scale = 1
alpha = 5

# The autoencoder the callback uses to decode latents into audio.
autoencoder = model._modules['pretransform']._modules.get("model")
callback_wrapper = partial(pitch_callback, autoencoder, target_pitch, base_step_scale, alpha)
# Alternative multi-step guidance:
# callback_wrapper = partial(pitch_guidance_callback, autoencoder, target_pitch, 50)

# Marco's notes:
# - About 7 steps works well for SAO small; going higher gets scary.
# - With the full SAO model, more steps are usually fine.
# This run uses 50 steps.
generation_kwargs = dict(
    conditioning=conditioning,
    steps=50,
    cfg_scale=1.,  # CFG scale of 1 is often good for small; higher works on the full model
    sample_size=sample_size,
    sigma_min=10,
    sigma_max=300,
    sampler_type="pingpong",  # for small; use "dpmpp-3m-sde" for the full model
    device=device,
    callback=callback_wrapper,
)

# Generate stereo audio, then flatten the batch into one (channels, samples) sequence.
output = generate_diffusion_cond(model, **generation_kwargs)
output = rearrange(output, "b d n -> d (b n)")

##### Now you can hear the final audio

In [None]:
# Peak-normalise to [-1, 1] and convert to 16-bit PCM for playback.
# clamp_min avoids a 0/0 -> NaN division if the generated audio is completely silent.
output_f32 = output.to(torch.float32)
cleaned_output = output_f32.div(output_f32.abs().max().clamp_min(1e-8)).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

Audio(cleaned_output.numpy(), rate=sample_rate)

Sounds like garbage, right? (At least it does if you used more than about 10 steps.)
Instead, let's launch the Gradio viewer and look at (and listen to) the audio at the intermediate steps of the sampling journey.

In [None]:
# Build the Gradio viewer from the per-step pitch/audio history and the target.
demo = create_gradio_interface(
    pitch_array,
    target_pitch,
    sample_rate,
    audio_array,
    target_audio,
)

In [None]:
# Launch the Gradio viewer created in the previous cell.
demo.launch()

In [None]:
# Shut down this Gradio app when you are done with it.
demo.close()

In [None]:
# Use this to close any Gradio apps that are still open (e.g. ones whose `demo` handle was lost).
import gradio as gr
gr.close_all()

You can also view just the per-step pitch curves overlaid on `target_pitch`, then pick a step's audio to play in the next cell.

In [None]:
# Animate the recorded per-step pitch contours against the target contour.
animate_pitch_arrays(
    pitch_array,
    sample_rate,
    target_pitch,
)

In [None]:
# Which recorded guidance step to listen to (index into audio_array).
# The number of recorded steps depends on the sampler settings, so check the index
# and fail with an actionable message instead of a bare IndexError.
frame_number = 35
assert -len(audio_array) <= frame_number < len(audio_array), (
    f"frame_number={frame_number} is out of range: only {len(audio_array)} steps were recorded"
)
Audio(audio_array[frame_number], rate=sample_rate)

To compare against the original audio, run this cell:

In [None]:
# Play the trimmed target recording for comparison with the generated audio.
original_pcm = target_audio.numpy()
Audio(original_pcm, rate=sample_rate)