# 🎸 Generating Music using MusicGen and W&B 🐝

In this notebook, we demonstrate how to generate music from text prompts, or generate new music conditioned on existing music, using the MusicGen model from [Audiocraft](https://github.com/facebookresearch/audiocraft). The generated audio is then played back and visualized using [Weights & Biases](https://wandb.ai/site).

In [None]:
# @title Install AudioCraft + WandB
!pip install -q -U audiocraft wandb

In [None]:
# @title
import os
import random
from tempfile import TemporaryDirectory

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import wandb
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from google.colab import files
from scipy import signal
from scipy.io import wavfile
from tqdm.auto import tqdm

In [None]:
# @title ## MusicGen Configs

# @markdown WandB Project Name
project_name = "audiocraft" # @param {type:"string"}

# Start the W&B run first so that every form value assigned to `config`
# below is recorded with the run.
wandb.init(project=project_name, job_type="musicgen/inference")

config = wandb.config

# @markdown Select the MusicGen variant
config.model_name = "small" # @param ["small", "medium", "large", "melody"]

# @markdown ---
# @markdown ## Conditional Generation Configs

# @markdown The prompt for generating audio. You can give multiple prompts separated by `|` in the input. You can also leave it blank for unconditional generation.
config.prompts = "happy rock | energetic EDM | sad jazz" # @param {type:"string"}

# Split on `|` and drop blank entries, so a trailing/doubled separator (or an
# input made only of separators) does not turn into an empty prompt.
descriptions = [
    prompt.strip() for prompt in config.prompts.split("|") if prompt.strip()
]
# No usable prompt at all means unconditional generation.
config.is_unconditional = not descriptions

# @markdown **Note:** If you have provided prompts, you will be prompted to provide an audio file in addition to the prompts to condition the model. If you don't want to provide a file as an additional condition to the model, just press on the `cancel` button.
input_audio, input_sampling_rate, wandb_input_audio = None, None, None
if not config.is_unconditional:
    # `files.upload()` yields an empty dict when the upload dialog is cancelled.
    input_audio_file = files.upload()
    if input_audio_file:
        # Only the first uploaded file is used as the conditioning audio.
        input_audio_path = next(iter(input_audio_file))
        wandb_input_audio = wandb.Audio(input_audio_path)
        input_audio, input_sampling_rate = torchaudio.load(input_audio_path)

# Record the flag for every run, including unconditional ones.
config.input_audio_available = input_audio is not None

# @markdown Number of audio samples generated, this is relevant only for unconditional generation, i.e., if `config.prompts` is left blank.
config.num_samples = 4 # @param {type:"slider", min:1, max:10, step:1}

# @markdown Specify the random seed
seed = None # @param {type:"raw"}

# Normalise the form value: anything that is not an integer is replaced by a
# random seed, negatives are mirrored to positive, and the result is wrapped
# into the range [0, max_seed).
max_seed = 2 ** 30
seed = seed if isinstance(seed, int) else random.randint(1, max_seed)
seed = abs(seed) % max_seed
config.seed = seed

# @markdown ---
# @markdown ## Generation Parameters
# The values below are only stored on `config` here; the generation cell
# passes them to `model.set_generation_params`.
# @markdown Use sampling if True, else do argmax decoding
config.use_sampling = True # @param {type:"boolean"}

# @markdown `top_k` used for sampling; limits us to the top `k` tokens to consider.
config.top_k = 250 # @param {type:"slider", min:0, max:1000, step:1}

# @markdown `top_p` used for sampling; limits us to the top tokens within a probability mass `p`
config.top_p = 0.0 # @param {type:"slider", min:0, max:1.0, step:0.01}

# @markdown Softmax temperature parameter
config.temperature = 1.0 # @param {type:"slider", min:0, max:1.0, step:0.01}

# @markdown Duration of the generated waveform
config.duration = 10 # @param {type:"slider", min:1, max:30, step:1}

# @markdown Coefficient used for classifier free guidance
config.cfg_coef = 3 # @param {type:"slider", min:1, max:100, step:1}

# @markdown Whether to perform 2 forward passes for Classifier Free Guidance instead of batching together the two. This has some impact on how things are padded but seems to have little impact in practice.
config.two_step_cfg = False # @param {type:"boolean"}

# @markdown When doing extended generation (i.e. more than 30 seconds), by how much should we extend the audio each time. Larger values will mean less context is preserved, and smaller values will require extra computations.
config.extend_stride = 18 # @param {type:"slider", min:1, max:30, step:1}

In [None]:
# @title Generate Audio using MusicGen

# Apply the seed that was logged to `config.seed`; without this the logged
# seed has no effect on sampling and runs cannot be reproduced from it.
torch.manual_seed(config.seed)

model = MusicGen.get_pretrained(config.model_name)
model.set_generation_params(
    use_sampling=config.use_sampling,
    top_k=config.top_k,
    top_p=config.top_p,
    temperature=config.temperature,
    duration=config.duration,
    cfg_coef=config.cfg_coef,
    two_step_cfg=config.two_step_cfg,
    extend_stride=config.extend_stride
)

# Three generation modes:
#   * no prompts            -> unconditional generation of `num_samples` clips
#   * prompts only          -> text-conditioned generation, one clip per prompt
#   * prompts + input audio -> text + melody conditioned generation
# The input audio is only ever uploaded when prompts were given, so the melody
# branch has to live on the conditional side.
generated_wav = None
if config.is_unconditional:
    generated_wav = model.generate_unconditional(
        num_samples=config.num_samples, progress=True
    )
elif input_audio is None:
    generated_wav = model.generate(descriptions, progress=True)
else:
    # NOTE(review): melody conditioning presumably requires the "melody"
    # model variant -- confirm against the selected `config.model_name`.
    # The same melody is repeated once per prompt so the batch sizes match.
    generated_wav = model.generate_with_chroma(
        descriptions,
        input_audio[None].expand(len(descriptions), -1, -1),
        input_sampling_rate,
        progress=True
    )

In [None]:
# @title Log Audio to Weights & Biases Dashboard

def get_spectrogram(audio_file, output_file):
    """Render the log-power spectrogram of a WAV file to a PNG for W&B.

    Args:
        audio_file: Path of the WAV file to read.
        output_file: Path the PNG image is written to.

    Returns:
        A `wandb.Image` created from `output_file`.
    """
    sample_rate, samples = wavfile.read(audio_file)
    # Multi-channel files are read as (n_samples, n_channels); down-mix to
    # mono so the spectrogram is computed along time rather than channels.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    frequencies, times, Sxx = signal.spectrogram(samples, sample_rate)

    # Power in dB; the epsilon keeps log10 finite for zero-power bins.
    log_Sxx = 10 * np.log10(Sxx + 1e-10)
    # Clip the colour range to the 5th-95th percentile to limit outliers.
    vmin = np.percentile(log_Sxx, 5)
    vmax = np.percentile(log_Sxx, 95)

    # Crop the frequency axis at the highest band whose average level is
    # above the 5th percentile of the per-band averages.
    mean_spectrum = np.mean(log_Sxx, axis=1)
    threshold_low = np.percentile(mean_spectrum, 5)
    active_frequencies = frequencies[mean_spectrum > threshold_low]
    freq_min = 20
    # A flat spectrum (e.g. silence) leaves no band above the threshold;
    # fall back to the full range instead of failing on an empty `.max()`.
    if active_frequencies.size:
        freq_max = active_frequencies.max()
    else:
        freq_max = frequencies.max()

    fig, ax = plt.subplots()
    cmap = plt.get_cmap('magma')

    ax.pcolormesh(
        times,
        frequencies,
        log_Sxx,
        shading='gouraud',
        cmap=cmap,
        vmin=vmin,
        vmax=vmax
    )
    ax.axis('off')
    ax.set_ylim([freq_min, freq_max])

    # Remove all margins so the saved image contains only the spectrogram.
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
    fig.savefig(
        output_file, format='png', bbox_inches='tight', pad_inches=0
    )
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

    return wandb.Image(output_file)


# Scratch directory for the WAV/PNG files handed to W&B. Cleanup sits in
# `finally` so the directory is removed even when a step below raises.
temp_dir = TemporaryDirectory()
try:
    columns = ["Prompt", "Generated-Audio", "Spectrogram", "Seed"]
    if input_audio is not None:
        columns.insert(1, "Input-Audio")
    wandb_table = wandb.Table(columns=columns)

    for idx, wav in enumerate(generated_wav):
        # `audio_write` is given the path without extension; the file is then
        # read back below with the ".wav" suffix appended.
        file_name = os.path.join(temp_dir.name, str(idx))
        audio_write(
            file_name,
            wav.cpu(),
            model.sample_rate,
            strategy="loudness",
            loudness_compressor=True,
        )
        wav_path = file_name + ".wav"
        wandb_audio = wandb.Audio(wav_path)
        wandb.log({"Generated-Audio": wandb_audio})
        # With several prompts each clip maps to its own prompt; otherwise
        # (single prompt or unconditional) the raw prompt string is shown.
        desc = descriptions[idx] if len(descriptions) > 1 else config.prompts
        wandb_table_row = [
            desc,
            wandb_audio,
            get_spectrogram(
                audio_file=wav_path,
                output_file=os.path.join(temp_dir.name, str(idx) + ".png")
            ),
            config.seed
        ]
        if input_audio is not None:
            # Keep the row aligned with the optional "Input-Audio" column.
            wandb_table_row.insert(1, wandb_input_audio)
        wandb_table.add_data(*wandb_table_row)

    wandb.log({"Generated-Audio-Table": wandb_table})

    # Finish the run only on success, so a failed cell can be re-run against
    # the same (still open) run.
    wandb.finish()
finally:
    temp_dir.cleanup()