## Setup

This cell installs the libraries needed to run the code, loads the pre-trained MusicGen small model (https://huggingface.co/facebook/musicgen-small), and initializes the generation pipeline. MusicGen generates music from text descriptions by encoding the text prompt, predicting discrete audio tokens with a transformer conditioned on that prompt, and decoding those tokens into 32 kHz audio waveforms.


In [None]:
import torch
import torchaudio
import random
import soundfile as sf
import numpy as np
import pandas as pd


from scipy.signal import butter, sosfilt
from transformers import AutoProcessor, MusicgenForConditionalGeneration
from IPython.display import Audio, display

In [None]:
!pip install -q transformers accelerate

device = "cuda" if torch.cuda.is_available() else "cpu"

# We use the 'musicgen-small' variant to balance audio quality
# and computational efficiency, making the pipeline suitable
# for Colab.

model_id = "facebook/musicgen-small"
processor = AutoProcessor.from_pretrained(model_id)

model = MusicgenForConditionalGeneration.from_pretrained(
    model_id,
    dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto"
)

print("Ready on:", device)


## Music Generation

We define a set of text prompts for different music genres and randomly select one for each genre. Then we use the MusicGen model to generate audio for each selected prompt. The generated audio is saved as WAV files.

Prompts are intentionally descriptive: this choice ensures that differences in output quality are primarily due to the generation process and not to prompt bias.


In [None]:
# Candidate text prompts, keyed by genre name. Each genre maps to a list
# of five alternative descriptions; one of them is drawn at random per
# genre before generation. The genre keys are also used to build the
# output file names (raw_<genre>.wav, preproc_<genre>.wav, ...).
genre_prompts = {
    "pop": [
        "Upbeat pop track with bright synths, steady four-on-the-floor drums, catchy melody throughout, consistent energy",
        "Emotional pop song with strings, soft drums and piano, 80 BPM",
        "Modern pop song with groovy bass and energetic drums",
        "Dance-pop track with four-on-the-floor kick, pulsing bass and sparkling keys",
        "Mid-tempo pop track with airy pads, dreamy atmosphere and soft percussion"
    ],
    "rock": [
        "Energetic rock track with electric guitars, powerful bass, consistent energy throughout",
        "Classic rock riff with electric guitar, steady bass and bluesy feel",
        "Slow rock ballad with emotional lead guitar, soft drums, building intensity",
        "Indie rock track with jangly guitars and tight drum groove, melodic bass, upbeat throughout",
        "Heavy rock track with aggressive rhythm guitars, palm-muted riffs, thunderous drums, 130 BPM"
    ],
    "jazz": [
        "Jazz song with saxophone lead melody, walking double bass, piano comping, brushed drums throughout",
        "Smooth jazz track with electric piano and soft saxophone melody",
        "Jazz fusion track with syncopated drums and fretless bass, energetic, 120 BPM",
        "Big band jazz swing with brass section, walking bass, piano and saxophone",
        "Atmospheric jazz ballad with trumpet and soft piano"
    ],
    "classical": [
        "Romantic classical song for solo piano with expressive dynamics and flowing arpeggios",
        "Piano playing a romantic, slow song, rich harmonies, cello melody, violin",
        "Full symphony orchestra with sweeping strings and bold brass, dramatic atmosphere",
        "Baroque-style piece with harpsichord and chamber ensemble, 110 BPM",
        "Soft classical piece for piano and cello duet with gentle piano accompaniment, expressive, 60 BPM"
    ],
    "hiphop": [
        "Hip-hop beat with dusty drums and vinyl crackle",
        "Modern trap beat with deep bass and rapid hi-hats",
        "Lo-fi hip-hop beat with warm piano samples and laid-back groove, chill atmosphere throughout",
        "Aggressive hip-hop instrumental with heavy kicks and brass stabs",
        "Chill hip-hop beat with soft Rhodes chords, subtle percussion, warm bass, relaxed vibe"
    ],
    "electronic": [
        "Ambient electronic soundscape with evolving pads and soft pulses, 80 BPM",
        "Melodic techno track with steady kick, hypnotic arpeggios, pulsing bass, building tension",
        "Deep house groove with warm bass and airy chords",
        "Dubstep-style track with bass and sharp snares",
        "Future bass track with detuned synth chords and sidechain pumping"
    ],
}

# Draw one prompt per genre, visiting genres in dictionary order.
# No random seed is set here, so the selection differs between runs.
selected_prompts = {
    name: random.choice(candidates)
    for name, candidates in genre_prompts.items()
}


print("Chosen prompts:")
for name, text in selected_prompts.items():
    print(f"- {name}: {text}")

# MusicGen works on discrete audio tokens, so the token budget below
# is what sets the length of each generated clip.
duration_tokens = 700

# The sampling rate is read from the model's audio encoder config.
sr = model.config.audio_encoder.sampling_rate
print(f"Sampling rate: {sr} Hz")

# Sampling settings shared by every genre:
# - temperature scales randomness
# - top_k / top_p restrict the candidate set at each step
# - guidance_scale sets how strongly the text prompt steers generation
sampling_kwargs = dict(
    do_sample=True,
    guidance_scale=5,
    max_new_tokens=duration_tokens,
    temperature=0.7,
    top_k=250,
    top_p=0.9,
)

for genre, prompt in selected_prompts.items():
    print(f"\nGeneration for '{genre}'...")
    print(f"Prompt: {prompt}")

    inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)

    with torch.inference_mode():
        audio_values = model.generate(**inputs, **sampling_kwargs)

    # Output shape: (batch, channels, samples); take the only batch item
    # and bring it to a float32 NumPy array on the CPU.
    audio = audio_values[0].detach().cpu().float().numpy()

    display(Audio(audio, rate=sr))

    # Transposed on write so that samples run along the first axis.
    filename = f"raw_{genre}.wav"
    sf.write(filename, audio.T, sr)
    print(f"Saved as {filename}")

## Preprocessing: Normalization and High-Pass Filtering
Through this function we first apply peak normalization to standardize signal amplitude across all samples.
This step ensures that differences observed during restoration are not driven by loudness variations.  

Then we apply a second-order Butterworth high-pass filter with a cutoff frequency of 30 Hz.
This removes subsonic components that do not contribute to perceptual audio quality but may interfere with restoration models.
This preprocessing step does not aim to improve audio quality directly, but to standardize the input signals.

In [None]:
def preprocess_normalize_highpass(audio, sr, hp_cutoff=30.0):
    """Peak-normalize an audio clip and then high-pass filter it.

    The steps run in this order:

    1. Cast to float32; a 2-D input is averaged over axis 1 to obtain a
       mono signal (assumes a (samples, channels) layout).
    2. Scale so that the largest absolute sample equals 0.99. An
       all-zero signal is left unscaled.
    3. Apply a second-order Butterworth high-pass filter at `hp_cutoff`.

    Args:
        audio: 1-D or 2-D array-like of audio samples.
        sr: Sampling rate in Hz, passed to the filter design as `fs`.
        hp_cutoff: High-pass cutoff frequency in Hz.

    Returns:
        1-D NumPy array holding the filtered signal. An empty input
        yields an empty array instead of raising.
    """
    audio = np.asarray(audio).astype(np.float32)

    # Stereo to mono
    if audio.ndim == 2:
        audio = np.mean(audio, axis=1)

    # Nothing to normalize or filter; np.max would raise on an empty array.
    if audio.size == 0:
        return audio

    # Peak normalization: bring every clip to the same peak level so
    # that later comparisons are not driven by loudness differences.
    peak = np.max(np.abs(audio))
    if peak > 0:
        audio = 0.99 * (audio / peak)

    # High-pass filtering: remove very low-frequency components
    # (e.g., DC offset) that do not contribute to musical content
    # but may interfere with restoration models.
    sos_hp = butter(
        N=2,
        Wn=hp_cutoff,
        btype="highpass",
        fs=sr,
        output="sos"
    )

    # NOTE(review): the filter runs after normalization, so the output
    # peak is not guaranteed to be exactly 0.99 -- confirm this is
    # acceptable for the file format the result is written to.
    audio_hp = sosfilt(sos_hp, audio)
    return audio_hp

# Run the normalization + high-pass step on every generated clip and
# store the result next to the raw file.
genres = ["pop", "rock", "jazz", "classical", "hiphop", "electronic"]

for genre in genres:
    audio, sr = sf.read(f"raw_{genre}.wav")
    audio_pre = preprocess_normalize_highpass(audio, sr, hp_cutoff=30.0)

    out_filename = f"preproc_{genre}.wav"
    sf.write(out_filename, audio_pre, sr)
    print(f"Preprocess per {genre}: saved as {out_filename}")



### Preprocessing validation
We verify the effectiveness of peak normalization and high-pass filtering by measuring the peak amplitude and the low-frequency energy ratio below 30 Hz.

Peak amplitude is used to confirm that peak normalization successfully standardizes signal levels across samples.
The low-frequency energy ratio serves as an objective indicator of subsonic content, which should be reduced by the high-pass filter.

In [None]:
def peak_amplitude(x):
    """Return the largest absolute sample value found in `x`.

    Used to check what peak normalization did to a signal.
    """
    magnitudes = np.abs(x)
    return magnitudes.max()


def lf_energy_ratio(x, sr, lf_max_hz=30.0):
    """Return the fraction of signal energy at or below `lf_max_hz`.

    The energy is measured on the one-sided FFT of `x`. Bins strictly
    between DC and Nyquist are counted twice, because each of them
    stands for a mirrored negative-frequency bin that `rfft` omits.
    Without that weighting the DC bin would carry twice its true share.

    Args:
        x: 1-D real-valued signal.
        sr: Sampling rate in Hz.
        lf_max_hz: Upper edge of the low-frequency band in Hz (inclusive).

    Returns:
        A ratio between 0 and 1. An empty signal yields 0.0.
    """
    x = np.asarray(x)
    n = len(x)
    if n == 0:
        # rfft raises on zero-length input; an empty clip has no energy.
        return 0.0

    spectrum = np.fft.rfft(x)
    freqs = np.fft.rfftfreq(n, d=1.0 / sr)

    # One-sided Parseval weights: 1 for DC (and for the Nyquist bin,
    # which only exists when n is even), 2 for every other bin.
    weights = np.full(spectrum.shape, 2.0)
    weights[0] = 1.0
    if n % 2 == 0:
        weights[-1] = 1.0
    power = weights * np.abs(spectrum) ** 2

    lf_energy = np.sum(power[freqs <= lf_max_hz])
    # Small constant keeps the division defined for an all-zero signal.
    total_energy = np.sum(power) + 1e-12

    return lf_energy / total_energy

def load_mono(path):
    """Read an audio file as float32 and collapse it to one channel.

    Returns a `(samples, sampling_rate)` pair. A 2-D result from
    `sf.read` is averaged over axis 1; a 1-D result is returned as is.
    """
    samples, rate = sf.read(path, dtype="float32")
    is_multichannel = samples.ndim == 2
    if is_multichannel:
        samples = samples.mean(axis=1)
    return samples, rate

# Compare each clip before and after preprocessing: one table row per
# (genre, stage) with the peak amplitude and the low-frequency energy
# ratio.

genres = ["pop", "rock", "jazz", "classical", "hiphop", "electronic"]

rows = []

for genre in genres:
    raw_audio, sr = load_mono(f"raw_{genre}.wav")
    pre_audio, _ = load_mono(f"preproc_{genre}.wav")

    # Both stages are analysed with the sampling rate of the raw file.
    for stage, signal in (("raw", raw_audio), ("preproc", pre_audio)):
        rows.append({
            "genre": genre,
            "stage": stage,
            "peak": peak_amplitude(signal),
            "lf_ratio_<30Hz": lf_energy_ratio(signal, sr),
        })

df = pd.DataFrame(rows).sort_values(by=["genre", "stage"])

print("Preprocessing validation:")
print(df)


## Denoising: noise reduction with Noisereduce library
We reduce background noise from each preprocessed audio clip using the noisereduce library. The function estimates a noise profile, and applies non-stationary noise reduction.


In [None]:
!pip install -q noisereduce
import noisereduce as nr

def denoise_noisereduce_mono(audio, sr):
    """
    Apply non-stationary noise reduction to a mono audio signal.

    Although MusicGen does not generate environmental noise, its outputs
    often contain low-level artifacts and noise-like components.
    This function treats such artifacts as non-stationary noise and
    applies spectral attenuation through `noisereduce`.

    Args:
        audio: NumPy array of samples; a 2-D array is averaged over
            axis 1 to obtain a mono signal before denoising.
        sr: Sampling rate in Hz, forwarded to `nr.reduce_noise`.

    Returns:
        The denoised signal as returned by `nr.reduce_noise`.
    """
    audio = audio.astype(np.float32)

    if audio.ndim == 2:
        audio = np.mean(audio, axis=1)

    y_denoised = nr.reduce_noise(
        y=audio,
        y_noise=None,       # no separate noise clip is supplied
        sr=sr,
        prop_decrease=0.8,  # noise attenuation percentage
        stationary=False    # assuming non-stationary noise
    )

    return y_denoised


# Apply denoising to preprocessed audio

for genre in genres:
    print(f"\nGenre: {genre}")

    # Load the preprocessed clip and let the user listen to it first.
    audio_pre, sr = sf.read(f"preproc_{genre}.wav")
    print("Preprocessed:")
    display(Audio(audio_pre, rate=sr))

    # Denoise, then play the result for an A/B comparison.
    audio_denoised = denoise_noisereduce_mono(audio_pre, sr)
    print("Denoised:")
    display(Audio(audio_denoised, rate=sr))

    # Keep the denoised clip on disk for the following stages.
    out_file = f"denoised_{genre}.wav"
    sf.write(out_file, audio_denoised, sr)
    print(f"Saved as {out_file}")

In [None]:
import matplotlib.pyplot as plt
import librosa
import librosa.display

def plot_spectrogram(y, sr, title, fmax=None):
    """Show a dB-scaled STFT magnitude spectrogram of `y`.

    Args:
        y: Audio signal passed to `librosa.stft`.
        sr: Sampling rate in Hz, used to label the axes.
        title: Figure title.
        fmax: Optional upper limit of the frequency axis in Hz; when
            omitted the axis limits are left as drawn.
    """
    hop = 512

    # Magnitude spectrogram, expressed in dB with np.max as reference.
    magnitude = np.abs(librosa.stft(y, n_fft=2048, hop_length=hop))
    level_db = librosa.amplitude_to_db(magnitude, ref=np.max)

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(
        level_db, sr=sr, hop_length=hop, x_axis="time", y_axis="hz"
    )

    if fmax is not None:
        plt.ylim(0, fmax)

    plt.colorbar(format="%+2.0f dB")
    plt.title(title)
    plt.tight_layout()
    plt.show()

genres = ["pop", "rock", "jazz", "classical", "hiphop", "electronic"]

for genre in genres:
    print(f"\n Genre: {genre.upper()}")

    y_pre, sr = sf.read(f"preproc_{genre}.wav")
    y_den, _ = sf.read(f"denoised_{genre}.wav")

    # Same axes for both stages so the two figures can be compared;
    # both are drawn with the sampling rate of the preprocessed file.
    for label, clip in (("Preprocessed", y_pre), ("Denoised", y_den)):
        plot_spectrogram(
            clip, sr,
            title=f"{genre} – {label}",
            fmax=16000
        )



## Bandwidth Extension with HiFi‑GAN BWE (from 24 kHz to 48 kHz)

We use a bandwidth‑extension model (HiFi‑GAN BWE, https://github.com/brentspell/hifi-gan-bwe) (which is a third‑party implementation and not the official HiFi‑GAN release) to enhance the bandwidth of the denoised
audio clips.  
Each cleaned audio file (assumed to be mono or converted to mono) is downsampled to the 24 kHz input rate expected by the model and then fed to the BWE model, which generates a 48 kHz version with reconstructed high-frequency content.

If the BWE model fails for any reason, a standard resampling fallback is applied to produce a 48 kHz output.  

- ⚠️**WARNING!**⚠️ we note that the adopted BWE model is primarily optimized for speech signals rather than music.
Its application in this context is therefore exploratory and aims to assess whether generative bandwidth extension can improve the perceived fidelity of music generated by text-to-audio models.


In [None]:
!pip install -q hifi-gan-bwe
import warnings
import os
from hifi_gan_bwe import BandwidthExtender

# Silence library warnings from here on.
warnings.filterwarnings("ignore")

print("starting BWE...")

# Setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained extender only when it is not already defined,
# so re-running the cell reuses the existing instance.
if 'bwe_model' not in locals():
    bwe_model = BandwidthExtender.from_pretrained("hifi-gan-bwe-13-59f00ca-vctk-24kHz-48kHz").to(device)
    bwe_model.eval()

target_sr_out = 48000
fs_in = 24000

# Apply bandwidth extension to every denoised clip.
for genre in genres:
    in_path = f"denoised_{genre}.wav"
    out_path = f"bwe_{genre}_48k.wav"

    if not os.path.exists(in_path):
        print(f"Skipping {genre}: file not found.")
        continue

    y, sr = sf.read(in_path, dtype="float32")

    # Ensure mono input.
    if y.ndim == 2:
        y = np.mean(y, axis=1)

    # Add two leading singleton axes: (samples,) -> (1, 1, samples).
    x = torch.from_numpy(y).float()[None, None, :].to(device)

    # Bring the clip to the model input rate when it differs.
    current_sr = sr
    if current_sr != fs_in:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=fs_in).to(device)
        x = resampler(x)
        current_sr = fs_in

    print(f"Processing {genre}...")

    with torch.no_grad():
        try:
            y_48k_t = bwe_model(x, current_sr)
            y_48k = y_48k_t.squeeze().cpu().numpy()
            print(f"{genre}: BWE Success")

        except Exception as err:
            # Best-effort fallback: report the failure and produce the
            # 48 kHz file by plain resampling instead.
            print(f"{genre}: BWE Failed. Error: {err}")
            print(" -> Fallback: standard resampling")

            y_48k_t = torchaudio.functional.resample(
                x, orig_freq=current_sr, new_freq=target_sr_out
            )
            y_48k = y_48k_t.squeeze().cpu().numpy()

    sf.write(out_path, y_48k, target_sr_out)

    print(f"Preview {genre}:")
    display(Audio(y_48k, rate=target_sr_out))

In [None]:
genres = ["pop", "rock", "jazz", "classical", "hiphop", "electronic"]


def plot_spectrogram(y, sr, title, fmax=None):
    """Draw a dB-scaled STFT magnitude spectrogram of `y` and show it.

    Args:
        y: Audio signal passed to `librosa.stft`.
        sr: Sampling rate in Hz, used to label the axes.
        title: Figure title.
        fmax: Optional upper limit of the frequency axis in Hz.
    """
    stft_args = {"n_fft": 2048, "hop_length": 512}

    # Magnitudes in dB, using np.max as the reference value.
    spec_db = librosa.amplitude_to_db(
        np.abs(librosa.stft(y, **stft_args)),
        ref=np.max,
    )

    plt.figure(figsize=(10, 4))
    librosa.display.specshow(
        spec_db,
        sr=sr,
        hop_length=stft_args["hop_length"],
        x_axis="time",
        y_axis="hz",
    )
    if fmax is not None:
        plt.ylim(0, fmax)
    plt.colorbar(format="%+2.0f dB")
    plt.title(title)
    plt.tight_layout()
    plt.show()

for genre in genres:
    print(f"\n=== GENRE: {genre.upper()} ===")

    # Load both stages before plotting anything.
    y_den, sr_den = sf.read(f"denoised_{genre}.wav")
    y_bwe, sr_bwe = sf.read(f"bwe_{genre}_48k.wav")

    # stereo → mono
    y_den = y_den.mean(axis=1) if y_den.ndim == 2 else y_den
    y_bwe = y_bwe.mean(axis=1) if y_bwe.ndim == 2 else y_bwe

    # Both figures share the same 0-24 kHz frequency axis so the added
    # high-frequency content can be compared directly.
    panels = (
        (y_den, sr_den, "Denoised"),
        (y_bwe, sr_bwe, "BWE Output (48 kHz)"),
    )
    for clip, rate, label in panels:
        plot_spectrogram(
            clip, rate,
            title=f"{genre} – {label}",
            fmax=24000
        )