In [None]:
%pip install -q --extra-index-url https://download.pytorch.org/whl/cpu Pillow scipy "torch>=2.1" torchaudio "diffusers>=0.16.1" "transformers>=4.33.0"
%pip install -q "git+https://github.com/huggingface/optimum-intel.git" "gradio>=3.34.0"
%pip install -q -U --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly "openvino>=2025.1" "openvino-genai>=2025.1"

In [None]:
%pip install openvino-tokenizers

In [2]:
import transformers
# NOTE(review): GlmModel is not used anywhere in this notebook; presumably this import only checks that the
# installed transformers is recent enough — confirm and remove if not needed.
from transformers import GlmModel

# Last expression of the cell, so the installed version is actually displayed.
transformers.__version__

In [9]:
import io
import os
import sys
from pathlib import Path

import gradio as gr
import IPython.display as ipd
import numpy as np
import openvino_genai as ov_genai
import requests
import torch
import torchaudio
from PIL import Image
from scipy.io import wavfile

In [13]:
NOTEBOOKS_RAW_URL = "https://raw.githubusercontent.com/openvinotoolkit/openvino_notebooks/latest"

# Local file name -> URL of the helper modules this notebook imports.
HELPER_FILES = {
    "notebook_utils.py": f"{NOTEBOOKS_RAW_URL}/utils/notebook_utils.py",
    "cmd_helper.py": f"{NOTEBOOKS_RAW_URL}/utils/cmd_helper.py",
    "gradio_helper.py": f"{NOTEBOOKS_RAW_URL}/notebooks/riffusion-text-to-music/gradio_helper.py",
}


def _download_if_missing(filename: str, url: str, timeout: float = 30.0) -> Path:
    """
    Download `url` into `filename` unless that file already exists.

    Parameters:
      filename (str): local destination path
      url (str): source URL
      timeout (float): request timeout in seconds
    Returns:
      Path: the local file path
    Raises:
      requests.HTTPError: if the server answers with an error status
    """
    path = Path(filename)
    if not path.exists():
        response = requests.get(url, timeout=timeout)
        # Fail loudly instead of saving an HTML error page as a Python module.
        response.raise_for_status()
        # Python source files are read as UTF-8; don't depend on the platform's default encoding.
        path.write_text(response.text, encoding="utf-8")
    return path


for helper_name, helper_url in HELPER_FILES.items():
    _download_if_missing(helper_name, helper_url)

from gradio_helper import make_demo

# Read more about telemetry collection at https://github.com/openvinotoolkit/openvino_notebooks?tab=readme-ov-file#-telemetry
from notebook_utils import collect_telemetry

collect_telemetry("riffusion-text-to-music.ipynb")

MODEL_ID = "riffusion/riffusion-model-v1"
MODEL_DIR = Path("riffusion_pipeline")

In [None]:
from huggingface_hub import configure_http_backend
from optimum.intel.openvino import OVStableDiffusionPipeline

In [7]:
def backend_factory() -> requests.Session:
    """
    Build the requests session that huggingface_hub uses for its HTTP calls.

    TLS certificate verification stays enabled by default. If you are behind a proxy that re-signs
    TLS traffic, prefer pointing REQUESTS_CA_BUNDLE at its CA certificate. Only as a last resort,
    set the environment variable RIFFUSION_INSECURE_SSL=1 to disable verification.

    Returns:
      requests.Session: session passed to huggingface_hub
    """
    session = requests.Session()
    if os.environ.get("RIFFUSION_INSECURE_SSL", "0") == "1":
        # Explicit opt-in only: disabling verification exposes downloads to man-in-the-middle tampering.
        session.verify = False
    return session


configure_http_backend(backend_factory=backend_factory)

In [None]:
# Load the PyTorch checkpoint and convert it to OpenVINO IR (`export=True` is the optimum-intel keyword;
# an upper-case `EXPORT` is not recognised as the export switch).
pipe = OVStableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    export=True,
    device="CPU",
    compile=False,  # skip compilation: this pipeline is only used to export the models to disk
)

In [7]:
from transformers import CLIPTokenizerFast
from openvino_tokenizers import convert_tokenizer
from openvino.runtime import serialize



In [None]:
# Convert the model's own CLIP tokenizer (shipped in its "tokenizer" subfolder) to OpenVINO IR.
hf_tok = CLIPTokenizerFast.from_pretrained(MODEL_ID, subfolder="tokenizer")
ov_tok, ov_detok = convert_tokenizer(hf_tok, with_detokenizer=True)

# Save the exported OpenVINO models to MODEL_DIR, the directory used by the rest of the notebook.
pipe.save_pretrained(MODEL_DIR)

# Write the tokenizer IR next to the exported models *before* building the GenAI pipeline from that directory.
tok_dir = MODEL_DIR / "tokenizer"
tok_dir.mkdir(parents=True, exist_ok=True)
serialize(ov_tok, tok_dir / "openvino_tokenizer.xml")
serialize(ov_detok, tok_dir / "openvino_detokenizer.xml")

# Replace the optimum pipeline with an OpenVINO GenAI text-to-image pipeline built from the local export.
pipe = ov_genai.Text2ImagePipeline(MODEL_DIR, "CPU")
pipe

In [11]:
def wav_bytes_from_spectrogram_image(image: Image.Image) -> tuple[io.BytesIO, float]:
    """
    Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.

    Parameters:
      image (Image.Image): generated spectrogram image
    Returns:
      wav_bytes (io.BytesIO): audio signal encoded as 16-bit PCM WAV, rewound to position 0
      duration_s (float): duration in seconds
    """

    max_volume = 50
    power_for_image = 0.25
    Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)

    sample_rate = 44100  # [Hz]
    clip_duration_ms = 5000  # [ms]

    bins_per_image = 512
    n_mels = 512

    # FFT parameters
    window_duration_ms = 100  # [ms]
    padded_duration_ms = 400  # [ms]
    step_size_ms = 10  # [ms]

    # Derived parameters.
    # Clip length in samples = fraction of a full-width image * clip duration in *seconds* * sample rate.
    num_samples = int(image.width / float(bins_per_image) * clip_duration_ms / 1000.0 * sample_rate)
    n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
    hop_length = int(step_size_ms / 1000.0 * sample_rate)
    win_length = int(window_duration_ms / 1000.0 * sample_rate)

    samples = waveform_from_spectrogram(
        Sxx=Sxx,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        num_samples=num_samples,
        sample_rate=sample_rate,
        mel_scale=True,
        n_mels=n_mels,
        num_griffin_lim_iters=32,
    )

    # Saturate at the int16 limits before casting: an out-of-range float -> int16 cast wraps around
    # (e.g. a large positive peak becomes a large negative one), which is audible as harsh clicks.
    int16_info = np.iinfo(np.int16)
    pcm16 = np.clip(samples, int16_info.min, int16_info.max).astype(np.int16)

    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, sample_rate, pcm16)
    wav_bytes.seek(0)

    duration_s = float(len(samples)) / sample_rate

    return wav_bytes, duration_s


def spectrogram_from_image(image: "Image.Image", max_volume: float = 50, power_for_image: float = 0.25) -> np.ndarray:
    """
    Compute a spectrogram magnitude array from a spectrogram image.

    Parameters:
      image (Image.Image): input image; multi-channel images use their first channel,
        single-channel (grayscale) images are used as-is
      max_volume (float, *optional*, 50): max volume for spectrogram magnitude
      power_for_image (float, *optional*, 0.25): power for reversing power curve
    Returns:
      np.ndarray: 2-D float32 magnitude array, rows flipped so the image's bottom row comes first
    """
    # Convert to a numpy array of floats
    data = np.array(image).astype(np.float32)

    # Take a single channel; grayscale images are already 2-D
    if data.ndim == 3:
        data = data[:, :, 0]

    # Flip Y
    data = data[::-1, :]

    # Invert
    data = 255 - data

    # Rescale to max volume
    data = data * max_volume / 255

    # Reverse the power curve
    data = np.power(data, 1 / power_for_image)

    return data


def waveform_from_spectrogram(
    Sxx: np.ndarray,
    n_fft: int,
    hop_length: int,
    win_length: int,
    num_samples: int,
    sample_rate: int,
    mel_scale: bool = True,
    n_mels: int = 512,
    num_griffin_lim_iters: int = 32,
    device: str = "cpu",
) -> np.ndarray:
    """
    Reconstruct a waveform from a spectrogram.
    This is an approximate waveform, using the Griffin-Lim algorithm
    to approximate the phase.

    Parameters:
      Sxx (np.ndarray): 2-D magnitude spectrogram (rows are mel bins when `mel_scale` is True —
        assumed layout (bins, frames); confirm against the caller)
      n_fft (int): FFT size
      hop_length (int): hop between STFT frames, in samples
      win_length (int): STFT window length, in samples
      num_samples (int): accepted for interface compatibility but NOT used; the output length is
        whatever Griffin-Lim produces for the given number of frames
      sample_rate (int): sample rate in Hz (used for the inverse mel filterbank)
      mel_scale (bool): if True, map mel bins back to linear STFT bins first
      n_mels (int): number of mel bins in `Sxx`
      num_griffin_lim_iters (int): Griffin-Lim iterations
      device (str): torch device to run on
    Returns:
      np.ndarray: reconstructed waveform
    """
    # torch.from_numpy rejects negative-stride views, and the torchaudio transforms hold float32
    # buffers, so pass a contiguous float32 array (no copy when Sxx already is one).
    Sxx_torch = torch.from_numpy(np.ascontiguousarray(Sxx, dtype=np.float32)).to(device)

    if mel_scale:
        mel_inv_scaler = torchaudio.transforms.InverseMelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=10000,
            n_stft=n_fft // 2 + 1,
            norm=None,
            mel_scale="htk",
        ).to(device)

        Sxx_torch = mel_inv_scaler(Sxx_torch)

    griffin_lim = torchaudio.transforms.GriffinLim(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=1.0,
        n_iter=num_griffin_lim_iters,
    ).to(device)

    waveform = griffin_lim(Sxx_torch).cpu().numpy()

    return waveform

In [12]:
def generate(
    prompt: str,
    negative_prompt: str = "",
    num_inference_steps: int = 20,
    output_path: str = "output.wav",
) -> tuple[Image.Image, str]:
    """
    Generate a spectrogram image from a text prompt and render it to a WAV file.

    Parameters:
      prompt (str): input prompt for generation.
      negative_prompt (str): negative prompt for generation, contains undesired concepts for generation, which should be avoided. Can be empty.
      num_inference_steps (int): number of denoising steps passed to the pipeline (default 20).
      output_path (str): where the reconstructed audio is written (default "output.wav"; overwritten on each call).
    Returns:
      spec (Image.Image): generated spectrogram image
      output_path (str): path of the written WAV file
    """
    image_tensor = pipe.generate(prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps)
    # `.data` exposes the result as a numpy array; index 0 selects the first generated image
    # (assumes a batch-first layout, as in the OpenVINO GenAI text-to-image examples).
    spec = Image.fromarray(image_tensor.data[0])
    wav_bytes, _duration_s = wav_bytes_from_spectrogram_image(spec)
    Path(output_path).write_bytes(wav_bytes.getbuffer())
    return spec, output_path

In [None]:
# Write the OpenVINO tokenizer/detokenizer IR into MODEL_DIR's tokenizer folder, next to the exported models.
# NOTE(review): a pipeline built from MODEL_DIR only sees these files if this runs before it is constructed;
# in this notebook the GenAI pipeline cell comes earlier — re-create the pipeline after running this cell.
tok_dir = MODEL_DIR / "tokenizer"
tok_dir.mkdir(parents=True, exist_ok=True)
serialize(ov_tok, tok_dir / "openvino_tokenizer.xml")
serialize(ov_detok, tok_dir / "openvino_detokenizer.xml")

In [None]:
# Generate a spectrogram for a sample prompt; `generate` also writes the reconstructed audio and returns its path.
spectrogram, wav_path = generate("Techno beat")

In [None]:
# Display the generated spectrogram image (the cell's last expression is rendered).
spectrogram

In [None]:
# Play the reconstructed audio inline.
ipd.Audio(wav_path)

In [None]:
# Device the current `pipe` was built for; the pipeline is first created on "CPU" earlier in the notebook.
PIPE_DEVICE = "CPU"


def select_device(device_str: str, current_text: str = "", progress: gr.Progress = gr.Progress()):
    """
    Helper function for uploading model on the device.

    Parameters:
      device_str (str): Device name.
      current_text (str): Current content of user instruction field (used only for backup purposes, temporally replacing it on the progress bar during model loading).
      progress (gr.Progress): gradio progress tracker
    Returns:
      current_text
    """
    global pipe, PIPE_DEVICE
    if device_str != PIPE_DEVICE:
        for _ in progress.tqdm(range(1), desc=f"Model loading on {device_str}"):
            # The OpenVINO GenAI pipeline has no `.to()` / `clear_requests()` (those are optimum-intel APIs),
            # so rebuild it from the exported models on the requested device.
            pipe = ov_genai.Text2ImagePipeline(MODEL_DIR, device_str)
        PIPE_DEVICE = device_str
    return current_text

In [None]:
demo = make_demo(generate_fn=generate, select_device_fn=select_device)

# Shared launch options for both attempts.
launch_options = {"debug": True, "height": 800}
try:
    demo.queue().launch(**launch_options)
except Exception:
    # Local launch failed; retry with a public share link.
    demo.queue().launch(share=True, **launch_options)

In [None]:
import io
import numpy as np
from PIL import Image
from scipy.io import wavfile
from scipy import signal
import torch
import torchaudio

def improved_wav_bytes_from_spectrogram_image(image: Image.Image) -> tuple[io.BytesIO, float]:
    """
    Reconstruct a WAV audio clip from a spectrogram image with improved quality.
    This version prioritizes stability and compatibility.

    Parameters:
      image (Image.Image): generated spectrogram image
    Returns:
      wav_bytes (io.BytesIO): audio signal encoded as 16-bit PCM WAV, rewound to position 0
      duration_s (float): duration in seconds
    """

    # Improved but conservative parameters
    max_volume = 25  # Reduced to prevent distortion
    power_for_image = 0.3  # Slightly adjusted for better dynamics

    Sxx = improved_spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)

    sample_rate = 44100  # [Hz]
    clip_duration_ms = 5000  # [ms]

    bins_per_image = 512
    n_mels = 512

    # FFT parameters
    window_duration_ms = 100  # [ms]
    padded_duration_ms = 400  # [ms]
    step_size_ms = 10  # [ms]

    # Derived parameters.
    # Clip length in samples = fraction of a full-width image * clip duration in *seconds* * sample rate
    # (5 s -> 220500 samples for a 512-px-wide image). This value becomes the Griffin-Lim output length,
    # so a unit error here would inflate the waveform (and memory use) by 1000x.
    num_samples = int(image.width / float(bins_per_image) * clip_duration_ms / 1000.0 * sample_rate)
    n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
    hop_length = int(step_size_ms / 1000.0 * sample_rate)
    win_length = int(window_duration_ms / 1000.0 * sample_rate)

    # Use improved but stable waveform generation
    samples = improved_waveform_from_spectrogram(
        Sxx=Sxx,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        num_samples=num_samples,
        sample_rate=sample_rate,
        mel_scale=True,
        n_mels=n_mels,
        num_griffin_lim_iters=40,  # Moderate increase for better quality
    )

    # Apply safe audio enhancements
    samples = apply_safe_audio_enhancements(samples, sample_rate)

    # Normalize and convert to WAV
    wav_bytes = samples_to_wav_bytes(samples, sample_rate)
    duration_s = float(len(samples)) / sample_rate

    return wav_bytes, duration_s


def improved_waveform_from_spectrogram(
    Sxx: np.ndarray,
    n_fft: int,
    hop_length: int,
    win_length: int,
    num_samples: int,
    sample_rate: int,
    mel_scale: bool = True,
    n_mels: int = 512,
    num_griffin_lim_iters: int = 40,
    device: str = "cpu",
) -> np.ndarray:
    """
    Rebuild a time-domain signal from a magnitude spectrogram.

    When `mel_scale` is True the magnitudes are first mapped from mel bins back to linear STFT bins.
    The phase is then estimated with Griffin-Lim (output trimmed/padded to `num_samples`). If
    Griffin-Lim raises, a zero-phase inverse STFT (`basic_istft_reconstruction`) is used instead.

    Note: `device` is accepted for interface compatibility but ignored — everything runs on the CPU.
    """
    cpu = "cpu"  # deliberately overrides `device` for maximum compatibility

    magnitudes = torch.from_numpy(Sxx).to(cpu).float()

    if mel_scale:
        to_linear_bins = torchaudio.transforms.InverseMelScale(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=10000,  # conservative upper frequency
            n_stft=n_fft // 2 + 1,
            norm=None,
            mel_scale="htk",
        ).to(cpu)
        magnitudes = to_linear_bins(magnitudes)

    phase_estimator = torchaudio.transforms.GriffinLim(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=1.0,
        n_iter=num_griffin_lim_iters,
        momentum=0.9,  # a little momentum for faster convergence
        length=num_samples,
        rand_init=False,  # deterministic phase initialisation
    ).to(cpu)

    try:
        return phase_estimator(magnitudes).cpu().numpy()
    except Exception as e:
        print(f"Griffin-Lim failed: {e}, using basic reconstruction")
        return basic_istft_reconstruction(magnitudes, n_fft, hop_length, win_length, num_samples)


def basic_istft_reconstruction(Sxx_torch, n_fft, hop_length, win_length, num_samples):
    """
    Fallback reconstruction: inverse STFT of the magnitudes with an all-zero phase.

    Uses a Hann window of `win_length` and trims/pads the result to `num_samples`.
    If anything raises, prints the error and returns `num_samples` zeros (float32).
    """
    try:
        # Magnitudes become the real part; the imaginary part (phase information) is zero.
        spectrum = torch.complex(Sxx_torch, torch.zeros_like(Sxx_torch))
        hann = torch.hann_window(win_length, device=Sxx_torch.device)
        return torch.istft(
            spectrum,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=hann,
            length=num_samples,
            return_complex=False,
        ).cpu().numpy()
    except Exception as e:
        print(f"Basic ISTFT also failed: {e}, returning silence")
        return np.zeros(num_samples, dtype=np.float32)


def apply_safe_audio_enhancements(samples: np.ndarray, sample_rate: int) -> np.ndarray:
    """
    Apply safe audio enhancements that won't cause errors.
    """
    # 1. Simple clipping prevention
    samples = np.clip(samples, -1.0, 1.0)
    
    # 2. Basic high-pass filter to remove rumble
    try:
        # Very conservative filter
        nyquist = sample_rate / 2
        low_cutoff = 80 / nyquist
        b, a = signal.butter(2, low_cutoff, btype='high')
        samples = signal.filtfilt(b, a, samples)
    except Exception as e:
        print(f"High-pass filter failed: {e}, skipping")
    
    # 3. Remove DC offset
    samples = samples - np.mean(samples)
    
    # 4. Simple normalization with headroom
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples * (0.8 / peak)  # Leave 20% headroom
    
    # 5. Gentle fade in/out to prevent clicks
    fade_length = min(int(0.01 * sample_rate), len(samples) // 10)  # 10ms or 10% of length
    if fade_length > 0:
        # Fade in
        fade_in = np.linspace(0, 1, fade_length)
        samples[:fade_length] *= fade_in
        
        # Fade out
        fade_out = np.linspace(1, 0, fade_length)
        samples[-fade_length:] *= fade_out
    
    return samples


def samples_to_wav_bytes(samples: np.ndarray, sample_rate: int) -> io.BytesIO:
    """
    Encode a float waveform as an in-memory 16-bit PCM WAV file.

    Values are clipped to [-1, 1], scaled by 32767 and truncated to int16.
    The returned buffer is rewound to position 0, ready to be read.
    """
    pcm16 = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

    buffer = io.BytesIO()
    wavfile.write(buffer, sample_rate, pcm16)
    buffer.seek(0)
    return buffer


def improved_spectrogram_from_image(image: Image.Image, max_volume: float = 25, power_for_image: float = 0.3) -> np.ndarray:
    """
    Turn a spectrogram image into a magnitude array, collapsing colour channels first.

    Parameters:
      image (Image.Image): spectrogram image (grayscale, or with one or more channels)
      max_volume (float, *optional*, 25): magnitude given to pure black pixels before the power curve
      power_for_image (float, *optional*, 0.3): image power-curve exponent; 1 / power_for_image is applied
    Returns:
      np.ndarray: 2-D magnitude array with the image's bottom row first
    """
    pixels = np.array(image).astype(np.float32)

    if pixels.ndim == 3:
        if pixels.shape[2] < 3:
            pixels = pixels[:, :, 0]
        else:
            # ITU-R BT.601 luma weights; any extra channel (e.g. alpha) is ignored
            red, green, blue = pixels[:, :, 0], pixels[:, :, 1], pixels[:, :, 2]
            pixels = 0.299 * red + 0.587 * green + 0.114 * blue

    # Bottom row of the image becomes row 0
    flipped = pixels[::-1, :]

    # Invert so black pixels get the largest value, then map [0, 255] onto [0, max_volume]
    magnitude = (255 - flipped) * max_volume / 255

    # Clamp exact zeros to a tiny positive value, then undo the image power curve
    return np.power(np.maximum(magnitude, 1e-10), 1 / power_for_image)