In [1]:
import torch
import torchaudio
from einops import rearrange
from stable_audio_tools import get_pretrained_model
from stable_audio_tools.inference.generation import generate_diffusion_cond

device = "cuda" if torch.cuda.is_available() else "cpu"


  from pkg_resources import packaging


In [2]:
from huggingface_hub import login
login(token="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")

In [3]:
# Download model
model, model_config = get_pretrained_model("stabilityai/stable-audio-open-1.0")
sample_rate = model_config["sample_rate"]
sample_size = model_config["sample_size"]

model = model.to(device)

No module named 'flash_attn'
flash_attn not installed, disabling Flash Attention


  WeightNorm.apply(module, name, dim)
Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:  36%|###6      | 1.75G/4.85G [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [None]:
from torch import nn
import typing as tp
from librosa import filters
import matplotlib as plt
import einops
import numpy as np
import julius
import torchmetrics

class ChromaExtractor(nn.Module):
    
    """
    Chroma extraction and quantization.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12).
        nfft (int, optional): Number of FFT.
        winlen (int, optional): Window length.
        winhop (int, optional): Window hop size.
        argmax (bool, optional): Whether to use argmax. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
    """

    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12, nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None, argmax: bool = False,
                 norm: float = torch.inf):
        super().__init__()
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                                                       n_chroma=self.n_chroma)), persistent=False)
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=True,
                                                      pad=0, normalized=True)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        T = wav.shape[-1]
        # in case we are getting a wav that was dropped out (nullified)
        # from the conditioner, make sure wav length is no less that nfft
        if T < self.nfft:
            pad = self.nfft - T
            r = 0 if pad % 2 == 0 else 1
            wav = torch.nn.functional.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        spec = self.spec(wav).squeeze(1)
        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
        norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
        norm_chroma = einops.rearrange(norm_chroma, 'b d t -> b t d')

        if self.argmax:
            idx = norm_chroma.argmax(-1, keepdim=True)
            norm_chroma[:] = 0
            norm_chroma.scatter_(dim=-1, index=idx, value=1)

        return norm_chroma
    
class ChromaCosineSimilarityMetric(torchmetrics.Metric):
    
    """
    Chroma cosine similarity metric.

        This metric extracts a chromagram for a reference waveform and
        a generated waveform and compares each frame using the cosine similarity
        function. The output is the mean cosine similarity.

        Args:
            sample_rate (int): Sample rate used by the chroma extractor.
            n_chroma (int): Number of chroma used by the chroma extractor.
            radix2_exp (int): Exponent for the chroma extractor.
            argmax (bool): Whether the chroma extractor uses argmax.
            eps (float): Epsilon for cosine similarity computation.
    """

    def __init__(self, sample_rate: int, n_chroma: int, radix2_exp: int, argmax: bool, eps: float = 1e-8):
        super().__init__()
        self.chroma_sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.eps = eps
        self.chroma_extractor = ChromaExtractor(sample_rate=self.chroma_sample_rate, n_chroma=self.n_chroma,
                                                radix2_exp=radix2_exp, argmax=argmax)
        self.add_state("cosine_sum", default=torch.tensor(0.), dist_reduce_fx="sum")
        self.add_state("weight", default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, targets: torch.Tensor,
               sizes: torch.Tensor, sample_rates: torch.Tensor) -> None:
        """Compute cosine similarity between chromagrams and accumulate scores over the dataset."""
        if preds.size(0) == 0:
            return

        assert preds.shape == targets.shape, (
            f"Preds and target shapes mismatch: preds={preds.shape}, targets={targets.shape}")
        assert preds.size(0) == sizes.size(0), (
            f"Number of items in preds ({preds.shape}) mismatch ",
            f"with sizes ({sizes.shape})")
        assert preds.size(0) == sample_rates.size(0), (
            f"Number of items in preds ({preds.shape}) mismatch ",
            f"with sample_rates ({sample_rates.shape})")
        assert torch.all(sample_rates == sample_rates[0].item()), "All sample rates are not the same in the batch"

        device = self.weight.device
        preds, targets = preds.to(device), targets.to(device)  # type: ignore
        sample_rate = sample_rates[0].item()
        preds = convert_audio(preds, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
        targets = convert_audio(targets, from_rate=sample_rate, to_rate=self.chroma_sample_rate, to_channels=1)
        gt_chroma = self.chroma_extractor(targets)
        gen_chroma = self.chroma_extractor(preds)
        chroma_lens = (sizes / self.chroma_extractor.winhop).ceil().int()
        for i in range(len(gt_chroma)):
            t = int(chroma_lens[i].item())
            cosine_sim = torch.nn.functional.cosine_similarity(
                gt_chroma[i, :t], gen_chroma[i, :t], dim=1, eps=self.eps)
            self.cosine_sum += cosine_sim.sum(dim=0)  # type: ignore
            self.weight += torch.tensor(t)  # type: ignore

    def compute(self) -> float:
        """Computes the average cosine similarty across all generated/target chromagrams pairs."""
        assert self.weight.item() > 0, "Unable to compute with total number of comparisons <= 0"  # type: ignore
        return (self.cosine_sum / self.weight).item()  # type: ignore
    
def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Convert audio to new sample rate and number of audio channels."""
    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
    wav = convert_audio_channels(wav, to_channels)
    return wav

def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:

    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        pass
    elif channels == 1:
        # Case 1:
        # The caller asked 1-channel audio, and the stream has multiple
        # channels, downmix all channels.
        wav = wav.mean(dim=-2, keepdim=True)
    elif src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file has
        # a single channel, replicate the audio over all channels.
        wav = wav.expand(*shape, channels, length)
    elif src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file has
        # more channels than requested. In that case return the first channels.
        wav = wav[..., :channels, :]
    else:
        # Case 4: What is a reasonable choice here?
        raise ValueError('The audio file has less channels than requested but is not mono.')
    return wav

def plot_chromagram(chroma, sample_rate, hop_length):
    # Get parameters
    n_frames = chroma.shape[1]        # number of time frames

    # Compute time axis (in seconds)
    times = np.arange(n_frames) * hop_length / sample_rate

    # Plot
    plt.figure(figsize=(10, 4))
    plt.imshow(chroma[0].T.cpu(), 
            aspect='auto', 
            origin='lower', 
            extent=[times[0], times[-1], 0, 12])

    plt.xlabel('Time (s)')
    plt.ylabel('Chroma bins')
    plt.title('Chroma Features')

    # Set y-axis labels to note names
    note_labels = ['C', 'C#', 'D', 'D#', 'E', 'F', 
                'F#', 'G', 'G#', 'A', 'A#', 'B']

    plt.yticks(ticks=range(len(note_labels)), labels=note_labels)
    plt.colorbar(label='Intensity')
    plt.show()

def chroma_guidance_callback(model, target_audio, step_scale, in_dict):
    denoised = in_dict['denoised']  # model prediction (latent audio or waveform)
    x = in_dict['x']
    print(f"t = {in_dict['t']:.3f}, denoised..shape, .requires_grad = {denoised.shape},, {denoised.requires_grad}") 

    # Convert latent -> waveform
    autoencoder = model._modules['pretransform']._modules.get("model")
    pred_audio = autoencoder.decoder(denoised.half())
    pred_audio = einops.rearrange(pred_audio, "b d n -> d (b n)")

    # Create chroma metric/loss
    chroma_metric = ChromaCosineSimilarityMetric(
        sample_rate=autoencoder.sample_rate,
        n_chroma=12,
        radix2_exp=12,
        argmax=False
    )

    # Fake wrappers for batch info
    B = pred_audio.shape[0]
    sizes = torch.tensor([pred_audio.shape[-1]] * B, device=pred_audio.device)
    sample_rates = torch.tensor([44100] * B, device=pred_audio.device)

    # Compute chroma features
    with torch.enable_grad():
        x.requires_grad = True
        denoised.requires_grad = True

        # Compute similarity
        chroma_metric.update(pred_audio, target_audio, sizes, sample_rates)
        similarity = chroma_metric.compute()

        print("chromagram.shape =",pred_audio.shape,", chromagram.requires_grad =", pred_audio.requires_grad)
        plot_chromagram(pred_audio, sample_rate=autoencoder.sample_rate, winhop=ChromaCosineSimilarityMetric.chroma_extractor.winhop)

        # Compute similarity
        chroma_metric.update(pred_audio, target_audio, sizes, sample_rates)
        similarity = chroma_metric.compute()

        # Convert to loss (maximize similarity â†’ minimize 1 - similarity)
        chroma_loss = 1.0 - similarity
        grad_x = torch.autograd.grad(chroma_loss, x, grad_outputs=torch.ones_like(chroma_loss), retain_graph=False)[0]
        d_denoised = -step_scale * grad_x
        denoised = denoised + d_denoised

    denoised.requires_grad = False
    x.requires_grad = False
    # Return or backpropagate
    return

def trim_audio_seconds(wav: torch.Tensor, sample_rate: int, duration_s: float):
    """Trim (or pad) an audio tensor to the desired duration in seconds."""
    target_len = int(sample_rate * duration_s)
    wav = wav[..., :target_len]  # Trim
    return wav


In [None]:
import librosa
from functools import partial 

# Load a .wav file
audio_path = r"C:\Users\simeo\VSCodeProjects\StableAudioProject\stable-audio-tools\panflute_scale.wav"

wav, sr = librosa.load(audio_path)   # wav shape: (channels, samples)
wav = torch.tensor(wav).unsqueeze(0).to(device)
wav = trim_audio_seconds(wav, sample_rate=sr, duration_s=11.888616780045352)
print(type(wav))
print(wav.shape)
print(sr)
callback_wrapper = partial(chroma_guidance_callback, model, target_audio=wav, step_scale=0.1)


conditioning = [{
    "prompt": "mono jazz saxophone solo",
    "seconds_start": 0,
    "seconds_total": 10
}]

# Generate stereo audio
output = generate_diffusion_cond(
    model,
    # Marco's Notes:
    # 7 steps works good for sao small, higher than that gets scary
    # If using normal sao higher steps is usually pretty good.
    steps=7,
    cfg_scale=1, # Config of 1 often good for small, higher works on normal
    conditioning=conditioning,
    sample_size=sample_size,
    sigma_min=.3,
    sigma_max=500,
    #sampler_type="dpmpp-3m-sde",  # Use this for normal open
    sampler_type="pingpong",  # Use this for small
    device=device,
    callback=callback_wrapper,
    seed=1234,
)

<class 'torch.Tensor'>
torch.Size([1, 1357771])
22050
1234


  with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):


  0%|          | 0/7 [00:00<?, ?it/s]

KeyboardInterrupt: 