# Riffusion AI Music Generator (Gradio)

## Credit to the [Riffusion](https://github.com/hmartiro/riffusion-inference) and [Amrrs](https://github.com/amrrs/ai-music-video) projects.


### Running this notebook on Google Colab is highly recommended

In [None]:
# Show which GPU (if any) is attached; later cells move the pipelines to "cuda".
!nvidia-smi

# Install the following libraries

In [None]:
# NOTE(review): this prebuilt xformers wheel is tagged cp38 (CPython 3.8, Linux x86_64);
# pip refuses it on any other Python version, in which case xformers will be unavailable.
!pip install -q https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl

In [None]:
! pip install -U transformers diffusers gradio ftfy pydub -q 

In [169]:
"""
Audio processing tools to convert between spectrogram images and waveforms.
"""
import io
import typing as T

import numpy as np
from PIL import Image
import pydub
from scipy.io import wavfile
import torch
import torchaudio


def wav_bytes_from_spectrogram_image(image: Image.Image) -> T.Tuple[io.BytesIO, float]:
    """
    Reconstruct a WAV audio clip from a spectrogram image. Also returns the duration in seconds.

    The image is decoded back into a magnitude spectrogram, turned into a waveform with
    Griffin-Lim phase reconstruction, and encoded as 16-bit PCM WAV.

    Args:
        image: Spectrogram image; its width is the number of time frames.

    Returns:
        A tuple ``(wav_bytes, duration_s)``: a BytesIO rewound to the start of the WAV
        data, and the clip length in seconds.
    """
    max_volume = 50
    power_for_image = 0.25
    Sxx = spectrogram_from_image(image, max_volume=max_volume, power_for_image=power_for_image)

    sample_rate = 44100  # [Hz]
    clip_duration_ms = 5000  # [ms] nominal duration of a bins_per_image-wide image

    bins_per_image = 512
    n_mels = 512

    # FFT parameters
    window_duration_ms = 100  # [ms]
    padded_duration_ms = 400  # [ms]
    step_size_ms = 10  # [ms]

    # Derived parameters.
    # Expected sample count for this image width: convert ms -> s before multiplying
    # by the sample rate.
    num_samples = int(
        image.width / float(bins_per_image) * clip_duration_ms / 1000.0 * sample_rate
    )

    n_fft = int(padded_duration_ms / 1000.0 * sample_rate)
    hop_length = int(step_size_ms / 1000.0 * sample_rate)
    win_length = int(window_duration_ms / 1000.0 * sample_rate)

    samples = waveform_from_spectrogram(
        Sxx=Sxx,
        n_fft=n_fft,
        hop_length=hop_length,
        win_length=win_length,
        num_samples=num_samples,
        sample_rate=sample_rate,
        mel_scale=True,
        n_mels=n_mels,
        max_mel_iters=200,
        num_griffin_lim_iters=32,
    )

    # Clip to the int16 range before casting: casting out-of-range floats to int16 is
    # not saturating and would produce loud wrap-around artifacts.
    int16_info = np.iinfo(np.int16)
    pcm = np.clip(samples, int16_info.min, int16_info.max).astype(np.int16)

    wav_bytes = io.BytesIO()
    wavfile.write(wav_bytes, sample_rate, pcm)
    wav_bytes.seek(0)

    duration_s = float(len(samples)) / sample_rate

    return wav_bytes, duration_s


def spectrogram_from_image(
    image: Image.Image, max_volume: float = 50, power_for_image: float = 0.25
) -> np.ndarray:
    """
    Compute a spectrogram magnitude array from a spectrogram image.

    Undoes the pixel encoding applied by ``image_from_spectrogram`` (vertical flip,
    intensity inversion, power curve), rescaling intensities to ``[0, max_volume]``.

    Args:
        image: Spectrogram image. For multi-channel images (e.g. RGB) only the first
            channel is read; single-channel (grayscale) images are used as-is.
        max_volume: Value a fully dark (0) pixel maps to before the power curve is undone.
        power_for_image: Exponent that was applied when the image was created.

    Returns:
        A float32 array of shape (height, width).
    """
    # Convert to a numpy array of floats
    data = np.array(image).astype(np.float32)

    # Take a single channel; grayscale images are already 2-D.
    if data.ndim == 3:
        data = data[:, :, 0]

    # Flip Y (the image stores row 0 of the spectrogram at the bottom)
    data = data[::-1]

    # Invert
    data = 255 - data

    # Rescale to max volume
    data = data * max_volume / 255

    # Reverse the power curve
    data = np.power(data, 1 / power_for_image)

    return data

def image_from_spectrogram(
    spectrogram: np.ndarray, max_volume: float = 50, power_for_image: float = 0.25
) -> Image.Image:
    """
    Compute a spectrogram image from a spectrogram magnitude array.

    Args:
        spectrogram: Non-negative magnitude array; presumably 2-D
            (frequency bins x frames) — TODO confirm for other ranks.
        max_volume: Accepted for symmetry with ``spectrogram_from_image`` but unused:
            the data is normalized by its own maximum instead.
        power_for_image: Exponent applied to compress the dynamic range.

    Returns:
        An RGB image in which row 0 of the spectrogram is at the bottom and larger
        magnitudes are darker.
    """
    # Apply the power curve
    data = np.power(spectrogram, power_for_image)

    # Rescale to 0-1. Skip the division for an all-silent input: dividing by a zero
    # maximum would yield NaNs, which cast to arbitrary uint8 values.
    peak = np.max(data)
    if peak > 0:
        data = data / peak

    # Rescale to 0-255
    data = data * 255

    # Invert
    data = 255 - data

    # Convert to a PIL image
    image = Image.fromarray(data.astype(np.uint8))

    # Flip Y
    image = image.transpose(Image.FLIP_TOP_BOTTOM)

    # Convert to RGB
    image = image.convert("RGB")

    return image

def spectrogram_from_waveform(
    waveform: np.ndarray,
    sample_rate: int,
    n_fft: int,
    hop_length: int,
    win_length: int,
    mel_scale: bool = True,
    n_mels: int = 512,
) -> np.ndarray:
    """
    Compute a magnitude spectrogram from a waveform.

    The waveform is flattened into a single channel (``reshape(1, -1)``), so it is
    expected to be mono. A complex STFT is taken and reduced to magnitudes. When
    ``mel_scale`` is set, the linear-frequency bins (``n_fft // 2 + 1`` of them) are
    projected onto ``n_mels`` HTK mel bins spanning 0-10000 Hz.
    """
    signal = torch.from_numpy(waveform.astype(np.float32)).reshape(1, -1)

    stft = torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=None,  # keep the complex STFT; magnitudes are taken explicitly below
    )
    magnitudes = np.abs(stft(signal).numpy()[0])

    if not mel_scale:
        return magnitudes

    to_mel = torchaudio.transforms.MelScale(
        n_mels=n_mels,
        sample_rate=sample_rate,
        f_min=0,
        f_max=10000,
        n_stft=n_fft // 2 + 1,
        norm=None,
        mel_scale="htk",
    )
    return to_mel(torch.from_numpy(magnitudes)).numpy()


def waveform_from_spectrogram(
    Sxx: np.ndarray,
    n_fft: int,
    hop_length: int,
    win_length: int,
    num_samples: int,
    sample_rate: int,
    mel_scale: bool = True,
    n_mels: int = 512,
    max_mel_iters: int = 200,
    num_griffin_lim_iters: int = 32,
    device: T.Optional[str] = None,
) -> np.ndarray:
    """
    Reconstruct a waveform from a spectrogram.

    This is an approximate inverse of spectrogram_from_waveform, using the Griffin-Lim
    algorithm to approximate the phase.

    Args:
        Sxx: Magnitude spectrogram; mel-scaled (``n_mels`` rows) when ``mel_scale`` is True.
        n_fft, hop_length, win_length: STFT parameters; should match the ones used to
            build the spectrogram.
        num_samples: Expected output length. Currently unused: the output length is
            whatever Griffin-Lim produces for the given number of frames.
        sample_rate: Sample rate in Hz, used to build the mel filterbank.
        mel_scale: Whether ``Sxx`` must first be mapped back to linear frequency.
        n_mels: Number of mel bins in ``Sxx``.
        max_mel_iters: Passed to ``InverseMelScale`` as ``max_iter`` on torchaudio
            versions that accept it; ignored otherwise.
        num_griffin_lim_iters: Number of Griffin-Lim iterations.
        device: Torch device to run on. ``None`` selects ``"cuda:0"`` when CUDA is
            available and ``"cpu"`` otherwise.

    Returns:
        The reconstructed waveform as a NumPy array.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # torch.from_numpy rejects arrays with negative strides (e.g. a bare [::-1] view).
    Sxx_torch = torch.from_numpy(np.ascontiguousarray(Sxx)).to(device)

    if mel_scale:
        mel_kwargs = dict(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=0,
            f_max=10000,
            n_stft=n_fft // 2 + 1,
            norm=None,
            mel_scale="htk",
        )
        try:
            mel_inv_scaler = torchaudio.transforms.InverseMelScale(
                max_iter=max_mel_iters, **mel_kwargs
            )
        except TypeError:
            # Newer torchaudio releases solve the inverse by least squares and no
            # longer accept max_iter.
            mel_inv_scaler = torchaudio.transforms.InverseMelScale(**mel_kwargs)

        Sxx_torch = mel_inv_scaler.to(device)(Sxx_torch)

    griffin_lim = torchaudio.transforms.GriffinLim(
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        power=1.0,
        n_iter=num_griffin_lim_iters,
    ).to(device)

    waveform = griffin_lim(Sxx_torch).cpu().numpy()

    return waveform


def mp3_bytes_from_wav_bytes(wav_bytes: io.BytesIO) -> io.BytesIO:
    """
    Transcode an in-memory WAV clip to MP3 using pydub.

    Args:
        wav_bytes: Buffer containing WAV data.

    Returns:
        A new BytesIO holding the MP3 data, rewound to the start.
    """
    segment = pydub.AudioSegment.from_wav(wav_bytes)
    out = io.BytesIO()
    segment.export(out, format="mp3")
    out.seek(0)
    return out

# Import all required libraries

In [2]:
import gradio as gr

import torch
from diffusers import StableDiffusionPipeline


# Load the model from the Hugging Face Model Hub

In [None]:
# Riffusion checkpoint; the images it generates are decoded as spectrograms further below.
model_id = "riffusion/riffusion-model-v1"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # half precision to reduce GPU memory use
)

In [5]:
pipe = pipe.to("cuda")
# xformers is optional: the prebuilt wheel installed at the top only targets CPython 3.8,
# so keep the default attention implementation if xformers cannot be enabled.
try:
    pipe.enable_xformers_memory_efficient_attention()
except (ImportError, ValueError, RuntimeError) as err:
    print(f"xformers memory-efficient attention not enabled: {err}")

# Create the core Audio Function

In [6]:
import random
# Gradient colour pairs for the waveform bars (one pair is picked at random per video).
# Each primary is paired with the next one, wrapping around: R->G, G->B, B->R.
_primaries = ["#ff0000", "#00ff00", "#0000ff"]
COLORS = [[start, end] for start, end in zip(_primaries, _primaries[1:] + _primaries[:1])]
        


In [None]:
from diffusers import StableDiffusionPipeline
import torch

# The original "runwayml/stable-diffusion-v1-5" repository has been removed from the
# Hugging Face Hub; this is the community-maintained mirror of the SD v1.5 checkpoint.
img_model_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Half precision comes from torch_dtype; the legacy revision="fp16" branch is deprecated
# in diffusers and not guaranteed to exist on the mirror.
img_pipe = StableDiffusionPipeline.from_pretrained(img_model_id, torch_dtype=torch.float16)
img_pipe = img_pipe.to("cuda")
# xformers is optional; keep the default attention implementation if it is unavailable.
try:
    img_pipe.enable_xformers_memory_efficient_attention()
except (ImportError, ValueError, RuntimeError) as err:
    print(f"xformers memory-efficient attention not enabled: {err}")

In [190]:
# Text prompt used by the cells below. The diffusion pipelines may reject or blank out
# outputs flagged as NSFW; if that happens, retry with a different description.
prompt = 'skeleton man in a basement in ohio'

In [None]:
# Generate a spectrogram image from the prompt and decode it into "output.wav".
spectogram = pipe(prompt).images[0]
wav = wav_bytes_from_spectrogram_image(spectogram)
wav_buffer = wav[0]  # (BytesIO, duration_s) -> keep the audio bytes
with open("output.wav", "wb") as f:
    f.write(wav_buffer.getbuffer())


In [None]:
# Generate background artwork for the video from the same prompt plus style keywords.
image = img_pipe(prompt + ", photorealistic, emotionally evocative, a thing of beauty beyond imagination or words").images[0]

In [None]:
image  # display the generated image as this cell's output

In [64]:
# Saved to disk so gr.make_waveform can use it as the video background.
background_path = "image.png"
image.save(background_path)

In [171]:
# Render "output.wav" as animated waveform bars over the saved artwork; the call
# returns the path of the generated video file.
video_path = gr.make_waveform(
    "output.wav",
    bg_image="image.png",
    bars_color=random.choice(COLORS),
)

In [None]:
from IPython.display import HTML
from base64 import b64encode
 
def _video_data_url(video_path):
    """Return the file at *video_path* encoded as a base64 ``data:video/mp4`` URL."""
    # Read-only binary mode is sufficient; the context manager closes the handle.
    with open(video_path, "rb") as video_file:
        encoded = b64encode(video_file.read()).decode()
    return f"data:video/mp4;base64,{encoded}"


def show_video(video_path, video_width=600):
    """
    Embed a video file inline in the notebook output.

    Args:
        video_path: Path to the video file; it is embedded as an MP4 data URL.
        video_width: Value of the ``width`` attribute on the ``<video>`` element.

    Returns:
        An ``HTML`` object containing a ``<video>`` player whose source is the
        base64-encoded file contents.
    """
    video_url = _video_data_url(video_path)
    return HTML(f"""<video width={video_width} controls><source src="{video_url}"></video>""")
 
# Preview the generated video inline in the notebook.
show_video(video_path)

In [176]:
def audio_gen(prompt):
    """
    Turn a text prompt into a music clip and a waveform video.

    Steps: the Riffusion pipeline renders a spectrogram image that is decoded to
    "output.wav"; the image pipeline paints background artwork saved to "image.png";
    gr.make_waveform combines both into a video. Both files are overwritten on every call.

    Args:
        prompt: Text description of the music and artwork.

    Returns:
        A tuple ``("output.wav", video)`` where ``video`` is the value returned by
        ``gr.make_waveform``.
    """
    wav_path, image_path = "output.wav", "image.png"
    style_suffix = ", artstation hall of fame gallery, editors choice, #1 digital painting of all time, most beautiful image ever created, emotionally evocative, greatest art ever made, lifetime achievement magnum opus masterpiece, the most amazing breathtaking image with the deepest message ever painted, a thing of beauty beyond imagination or words"

    # 1. Prompt -> spectrogram image -> WAV file.
    spectrogram_image = pipe(prompt).images[0]
    wav_buffer, _duration_s = wav_bytes_from_spectrogram_image(spectrogram_image)
    with open(wav_path, "wb") as wav_file:
        wav_file.write(wav_buffer.getbuffer())
    print("audio saved")

    # 2. Prompt + style keywords -> background artwork.
    print("image started")
    artwork = img_pipe(prompt + style_suffix).images[0]
    artwork.save(image_path)
    print("image saved")

    # 3. Audio + artwork -> waveform video.
    video = gr.make_waveform(wav_path, bg_image=image_path, bars_color=random.choice(COLORS))
    print("video done!")
    return (wav_path, video)


In [None]:
# Run the full prompt -> audio/video pipeline once before launching the UI.
audio_gen(prompt)

In [None]:
# Gradio interface: a prompt box in, the generated audio and waveform video out.
gr.Interface(
    audio_gen,
    inputs=[gr.Textbox(label="prompt")],
    outputs=[
        gr.Audio(type='filepath'),
        # gr.Video has no `type` argument (Gradio 4+ rejects unknown keyword
        # arguments); it displays a returned file path as-is.
        gr.Video(),
    ],
    title='Riffusion Music Page',
).launch(debug=True)