In [1]:
# Imports
import torch
import torchaudio
import torchcrepe
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

import numpy as np
import pretty_midi
import sounddevice as sd
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio
import soundfile as sf
from torch.nn import MSELoss

# Custom Helpers
from audio_helpers import render_midi

# Pick the best available torch backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device}")

  from pkg_resources import packaging


Using cached SoundFont: soundfonts/FluidR3Mono_GM.sf3
Using mps


In [2]:
@torch.enable_grad()
def calculate_pitch(audio, sample_rate):
    """Estimate a smoothed pitch track for `audio` using torchcrepe.

    Runs with gradients enabled (see the decorator) and asks torchcrepe for a
    differentiable prediction, so the result can be used inside a loss.

    Args:
        audio: Audio tensor; it is averaged over dim 0 (keeping that dim)
            before analysis. Presumably (channels, samples) - confirm with callers.
        sample_rate: Sample rate of `audio` in Hz.

    Returns:
        The pitch tensor produced by torchcrepe after periodicity-based
        thresholding and mean filtering.
    """
    smoothing_window = 3
    hop_length = int(sample_rate / 200.)  # 1/200 s, i.e. a 5 ms hop

    # Mix down to a single channel and move to the analysis device as float32.
    mono = audio.mean(dim=0, keepdim=True).to(device=device, dtype=torch.float32)

    f0, confidence = torchcrepe.predict(
        mono,
        sample_rate,
        hop_length=hop_length,
        fmin=50,
        fmax=550,
        model='tiny',  # or 'full'
        batch_size=2048,
        device=device,
        return_periodicity=True,
        decoder=torchcrepe.decode.soft_argmax,
        differentiable=True,
    )

    # Post-process: smooth the periodicity, gate it with a -60 silence
    # threshold, then gate the pitch wherever periodicity is below 0.5.
    confidence = torchcrepe.filter.median(confidence, smoothing_window)
    silence_gate = torchcrepe.threshold.Silence(-60.)
    confidence = silence_gate(confidence, mono, sample_rate, hop_length)
    voicing_gate = torchcrepe.threshold.At(.5)
    f0 = voicing_gate(f0, confidence)
    return torchcrepe.filter.mean(f0, smoothing_window)

In [3]:
# Fetch the pretrained Stable Audio Open Small checkpoint together with its config.
# Model card: `https://huggingface.co/stabilityai/stable-audio-open-small`
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-small")

# Audio geometry the model was configured with; reused by later cells.
sample_rate, sample_size = model_config["sample_rate"], model_config["sample_size"]

# Move the model to the selected compute device.
model = model.to(device)

No module named 'flash_attn'
flash_attn not installed, disabling Flash Attention


  WeightNorm.apply(module, name, dim)


In [4]:
# Load the reference recording whose pitch contour will guide generation.
target_audio, target_sr = torchaudio.load("./data/BDCT-0/4YNW3G/Audio Files/Bass.08_01.wav")
if sample_rate != target_sr:
    # Resample takes (orig_freq, new_freq): convert FROM the file's rate TO the model's rate.
    resampler = torchaudio.transforms.Resample(orig_freq=target_sr, new_freq=sample_rate)
    target_audio = resampler(target_audio)

# Keep only the first `time_sec` seconds (about 11.89 s).
# NOTE(review): presumably chosen to match the model's generation window - confirm.
time_sec = 11.888616780045352
target_audio = target_audio[:, :int(time_sec * sample_rate)]

# The target is a fixed reference, so drop any autograd graph that
# calculate_pitch built while running with gradients enabled.
target_pitch = calculate_pitch(target_audio, sample_rate).detach()

print(f"Target length is: {target_audio.shape[1] / sample_rate}")

Target length is: 11.888616780045352


In [5]:
from functools import partial

def pitch_callback(model, target_pitch, step_scale, in_dict):
    """Sampler callback that nudges the denoised latents toward a target pitch track.

    The current denoised latents are decoded to audio, a pitch track is
    estimated with ``calculate_pitch``, and one gradient-descent step on the
    MSE between that track and ``target_pitch`` is applied to the latents.

    Args:
        model: Diffusion model wrapper; its ``pretransform.model.decoder`` is
            used to decode latents to audio.
        target_pitch: Reference pitch track (as returned by ``calculate_pitch``).
        step_scale: Step size that multiplies the gradient.
        in_dict: Dict supplied by the sampler. Reads ``'denoised'`` and ``'t'``
            and rebinds ``'denoised'`` to the guided latents.

    Returns:
        None. The result is communicated through ``in_dict['denoised']``; the
        dict is left untouched when no usable gradient could be computed.

    NOTE(review): rebinding ``in_dict['denoised']`` only influences sampling if
    the sampler reads the dict back after invoking the callback - confirm
    against the sampler implementation in use.
    """
    # Differentiate w.r.t. a detached leaf copy so the sampler's own tensors
    # do not have their requires_grad flag flipped as a side effect.
    denoised = in_dict['denoised'].detach().requires_grad_(True)

    with torch.enable_grad():
        print(f"t = {in_dict['t']:.3f}, denoised..shape, .requires_grad = {denoised.shape},, {denoised.requires_grad}")

        autoencoder = model.pretransform.model

        audio = autoencoder.decoder(denoised.half())
        audio = rearrange(audio, "b d n -> d (b n)")
        print("Generated Audio shape:", audio.shape, f"Generated Audio length: {(audio.shape[1]/sample_rate):.2f}")

        # Compute pitch (calculate_pitch does the mono mixdown and float32 cast itself).
        pitch = calculate_pitch(audio, sample_rate)
        print("pitch.requires_grad =", pitch.requires_grad)
        print("pitch.shape =", pitch.shape)

        # Compare only frames present in both tracks and finite in both: the
        # thresholding in calculate_pitch can leave NaN for unvoiced frames,
        # and a single NaN would turn the whole MSE (and its gradient) into NaN.
        target = target_pitch.detach().to(device=pitch.device, dtype=pitch.dtype)
        n_frames = min(pitch.shape[-1], target.shape[-1])
        pitch, target = pitch[..., :n_frames], target[..., :n_frames]
        valid = torch.isfinite(pitch) & torch.isfinite(target)
        if not valid.any():
            print("Skipping pitch guidance: no frames with a finite pitch in both tracks")
            return

        loss = MSELoss()(pitch[valid], target[valid])
        print("loss.requires_grad =", loss.requires_grad)

        # allow_unused=True makes autograd return None instead of raising when
        # the loss does not depend on the latents.
        grad = torch.autograd.grad(loss, denoised, allow_unused=True)[0]

    if grad is None or not torch.isfinite(grad).all():
        print("Skipping pitch guidance: gradient is missing or non-finite")
        return

    # One gradient-descent step; detach so no graph is carried into later sampler steps.
    in_dict['denoised'] = (denoised - step_scale * grad).detach()

# Bind everything except the sampler's info dict, which stays the only call argument.
callback_wrapper = partial(pitch_callback, model, target_pitch, 1.)

In [6]:
# Text and duration conditioning for a single sample.
# (This prompt is quite bad on the small model, but the small model does work.)
conditioning = [
    dict(
        prompt="40 BPM jazz saxophone solo funk fusion",
        seconds_total=time_sec,
    )
]

# Generate stereo audio
# Marco's notes:
#   - 7 steps works well for SAO small; more than that gets scary.
#   - With the normal SAO model, higher step counts are usually pretty good.
output = generate_diffusion_cond(
    model,
    conditioning=conditioning,
    sample_size=sample_size,
    device=device,
    callback=callback_wrapper,
    # Sampler choice: "pingpong" for small; use "dpmpp-3m-sde" for normal open.
    sampler_type="pingpong",
    steps=7,
    cfg_scale=.8,  # A CFG scale around 1 is often good for small; higher works on normal
    sigma_min=.3,
    sigma_max=400,
)

# Rearrange audio batch to a single sequence
# output = rearrange(output, "b d n -> d (b n)")

893241753


  self.setter(val)
  with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
  0%|          | 0/7 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


t = 1.000, denoised..shape, .requires_grad = torch.Size([1, 64, 256]),, True
Generated Audio shape: torch.Size([2, 524288]) Generated Audio length: 11.89


 14%|█▍        | 1/7 [00:00<00:05,  1.17it/s]

pitch.requires_grad = True
pitch.shape = torch.Size([1, 2408])
loss.requires_grad = True
t = 0.992, denoised..shape, .requires_grad = torch.Size([1, 64, 256]),, True
Generated Audio shape: torch.Size([2, 524288]) Generated Audio length: 11.89


 29%|██▊       | 2/7 [00:01<00:04,  1.04it/s]

pitch.requires_grad = True
pitch.shape = torch.Size([1, 2408])
loss.requires_grad = True
t = 0.976, denoised..shape, .requires_grad = torch.Size([1, 64, 256]),, True
Generated Audio shape: torch.Size([2, 524288]) Generated Audio length: 11.89


 43%|████▎     | 3/7 [00:03<00:04,  1.04s/it]

pitch.requires_grad = True
pitch.shape = torch.Size([1, 2408])
loss.requires_grad = True
t = 0.929, denoised..shape, .requires_grad = torch.Size([1, 64, 256]),, True
Generated Audio shape: torch.Size([2, 524288]) Generated Audio length: 11.89


 57%|█████▋    | 4/7 [00:04<00:03,  1.08s/it]

pitch.requires_grad = True
pitch.shape = torch.Size([1, 2408])
loss.requires_grad = True
t = 0.807, denoised..shape, .requires_grad = torch.Size([1, 64, 256]),, True
Generated Audio shape: torch.Size([2, 524288]) Generated Audio length: 11.89


 71%|███████▏  | 5/7 [00:05<00:02,  1.10s/it]

pitch.requires_grad = True
pitch.shape = torch.Size([1, 2408])
loss.requires_grad = True
t = 0.571, denoised..shape, .requires_grad = torch.Size([1, 64, 256]),, True
Generated Audio shape: torch.Size([2, 524288]) Generated Audio length: 11.89


 86%|████████▌ | 6/7 [00:06<00:01,  1.12s/it]

pitch.requires_grad = True
pitch.shape = torch.Size([1, 2408])
loss.requires_grad = True
t = 0.298, denoised..shape, .requires_grad = torch.Size([1, 64, 256]),, True
Generated Audio shape: torch.Size([2, 524288]) Generated Audio length: 11.89


100%|██████████| 7/7 [00:07<00:00,  1.09s/it]

pitch.requires_grad = True
pitch.shape = torch.Size([1, 2408])
loss.requires_grad = True





In [7]:
# Flatten the batch into one long (channels, samples) sequence.
# Guarded on rank so that re-running this cell on the already-flattened
# tensor is a no-op instead of an einops shape error.
if output.ndim == 3:
    output = rearrange(output, "b d n -> d (b n)")

In [8]:
# Peak normalize, convert to int16
audio_f32 = output.to(torch.float32)
# Floor the peak so an all-zero (silent) result does not divide by zero and
# turn into NaN before the int16 conversion.
peak = audio_f32.abs().max().clamp(min=1e-8)
cleaned_output = audio_f32.div(peak).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

# Optional: clip to the requested length
# cleaned_output = cleaned_output[..., :int(sample_rate * time_sec)]

display(Audio(cleaned_output.numpy(), rate=sample_rate))