This AI Music Generator by **M. Ashbil Shahid** leverages *Facebook's MusicGen* and *MultiBandDiffusion* to create original compositions from text prompts, with four selectable model variants (small, medium, melody, and large). Features include multi-model support, parameter customization, optional diffusion decoding, and interactive playback controls in a streamlined interface. Ideal for musicians, content creators, and audio enthusiasts seeking instant royalty-free music generation!


# Necessary Libraries

In [1]:
!pip install 'torch>=2.0'
!pip install -U git+https://git@github.com/facebookresearch/audiocraft
!pip install gradio==3.50.2

import os
import time
import logging
import soundfile as sf
import numpy as np
import gradio as gr
from audiocraft.models import MusicGen, MultiBandDiffusion

Collecting git+https://****@github.com/facebookresearch/audiocraft
  Cloning https://****@github.com/facebookresearch/audiocraft to /tmp/pip-req-build-t1af7kwm
  Running command git clone --filter=blob:none --quiet 'https://****@github.com/facebookresearch/audiocraft' /tmp/pip-req-build-t1af7kwm
  Resolved https://****@github.com/facebookresearch/audiocraft to commit e5fcc458a4dc1c6f7248cbceac9cfe471f2c92b8
  Preparing metadata (setup.py) ... [?25l[?25hdone



# Functionalities

In [2]:
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Model cache shared by the functions below (rebound via `global`).
# Nothing is loaded yet, so the model, its decoder and the cached model
# name all start out as None, and diffusion starts disabled.
CURRENT_MODEL = DIFFUSION_DECODER = CURRENT_MODEL_NAME = None
USE_DIFFUSION = False

def load_model(model_size, use_diffusion):
    """Ensure the requested MusicGen model (and optional decoder) is loaded.

    Results are cached in the module-level globals ``CURRENT_MODEL``,
    ``CURRENT_MODEL_NAME``, ``DIFFUSION_DECODER`` and ``USE_DIFFUSION`` so
    that repeated calls with the same settings do no work.

    Args:
        model_size: Suffix of the checkpoint name; the checkpoint loaded is
            ``facebook/musicgen-<model_size>``.
        use_diffusion: When true, the MultiBandDiffusion decoder is loaded
            as well (once; it is kept cached afterwards).

    Raises:
        Exception: Any error raised while loading is logged and re-raised.
    """
    global CURRENT_MODEL, DIFFUSION_DECODER, CURRENT_MODEL_NAME, USE_DIFFUSION
    try:
        # Reload the (large) MusicGen checkpoint only when the requested
        # size actually changed -- toggling diffusion alone must not
        # trigger a full model reload.
        if CURRENT_MODEL is None or CURRENT_MODEL_NAME != model_size:
            logger.info(f"Loading model: {model_size}, diffusion: {use_diffusion}")
            model = MusicGen.get_pretrained(f'facebook/musicgen-{model_size}')
            # Commit the cache entry only after the load succeeded.
            CURRENT_MODEL = model
            CURRENT_MODEL_NAME = model_size
            logger.info("Model loaded successfully")

        # Load the decoder lazily and keep it cached once available.
        if use_diffusion and DIFFUSION_DECODER is None:
            logger.info("Loading Multi-Band Diffusion decoder")
            DIFFUSION_DECODER = MultiBandDiffusion.get_mbd_musicgen()

        # Record the diffusion flag last: if the decoder load above raised,
        # the cached state must not claim diffusion is ready, otherwise the
        # next identical request would skip loading and run without it.
        USE_DIFFUSION = use_diffusion
    except Exception as e:
        logger.error(f"Model loading failed: {str(e)}")
        raise

def generate_music(prompt, duration, model_size, use_diffusion, top_k, top_p, temperature):
    """Generate a music clip from a text prompt and save it as a WAV file.

    Args:
        prompt: Text description of the desired music; must be non-blank.
        duration: Clip length in seconds; must lie in ``[1, 600]``.
        model_size: MusicGen variant name, forwarded to ``load_model``.
        use_diffusion: Decode the generated tokens with the
            MultiBandDiffusion decoder instead of the model's own output.
        top_k: Top-k sampling parameter (converted with ``int``).
        top_p: Top-p sampling parameter (converted with ``float``).
        temperature: Sampling temperature (converted with ``float``).

    Returns:
        A 3-tuple ``(audio, filename, status)``. On success ``audio`` is
        ``(sample_rate, numpy_array)``, ``filename`` is the path of the WAV
        file written under ``outputs/`` and ``status`` is a success
        message. On any failure the first two items are ``None`` and
        ``status`` holds an error message; exceptions are caught and
        reported through ``status`` rather than propagated.
    """
    try:
        start_time = time.time()

        # --- Input validation (no model work is done for bad input) ---
        if not prompt or not prompt.strip():
            return None, None, "Error: Please enter a valid prompt"

        if duration < 1 or duration > 600:
            return None, None, "Error: Duration must be between 1 and 600 seconds"

        try:
            load_model(model_size, use_diffusion)
        except Exception as e:
            # load_model already logged the failure before re-raising.
            return None, None, f"Model Error: {str(e)}"

        CURRENT_MODEL.set_generation_params(
            use_sampling=True,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            duration=int(duration),
        )

        try:
            # return_tokens=True yields (waveform, tokens); the tokens are
            # what the diffusion decoder consumes below.
            output = CURRENT_MODEL.generate(
                descriptions=[prompt],
                progress=True,
                return_tokens=True
            )
        except Exception as e:
            # Keep the traceback in the log; the UI only shows the message.
            logger.exception("Generation failed:")
            return None, None, f"Generation Error: {str(e)}"

        try:
            diffusion_active = bool(use_diffusion) and DIFFUSION_DECODER is not None
            if diffusion_active:
                audio = DIFFUSION_DECODER.tokens_to_wav(output[1])
            else:
                if use_diffusion:
                    # Make the fallback visible instead of silently
                    # ignoring the user's diffusion request.
                    logger.warning(
                        "Diffusion requested but no decoder is loaded; "
                        "using the model's own decoded audio"
                    )
                audio = output[0]

            # detach() guards against tensors that still carry autograd
            # state, for which .numpy() would raise.
            audio_np = audio.detach().cpu().numpy().squeeze()

            # Ask whichever component produced the waveform for its sample
            # rate; fall back to 32 kHz if it does not expose one.
            producer = DIFFUSION_DECODER if diffusion_active else CURRENT_MODEL
            sr = int(getattr(producer, "sample_rate", 32000))

            os.makedirs("outputs", exist_ok=True)
            filename = f"outputs/musicgen_{int(time.time())}.wav"
            sf.write(filename, audio_np, sr)

            generation_time = time.time() - start_time
            return (sr, audio_np), filename, f"Success! Generated in {generation_time:.1f}s"

        except Exception as e:
            logger.exception("Audio post-processing failed:")
            return None, None, f"Processing Error: {str(e)}"

    except Exception as e:
        logger.exception("Critical error during generation:")
        return None, None, f"Unexpected Error: {str(e)}"



# Interface

In [3]:
def create_interface():
    """Assemble the Gradio Blocks UI and return it without launching.

    Layout: a left column with model settings, generation parameters and
    the generate button; a right column with the audio player, a file
    download slot, a status box and three playback buttons. The playback
    buttons run client-side JavaScript only (no Python callback), while the
    generate button calls ``generate_music``. The returned app has its
    queue enabled with ``concurrency_count=1``.
    """
    # Browser-side snippets for the playback buttons. Each one looks up the
    # <audio> tag inside the component whose elem_id is "main_audio".
    play_js = """
            () => {
                try {
                    const container = document.getElementById('main_audio');
                    const audioElement = container?.querySelector('audio');
                    if (audioElement) audioElement.play();
                } catch (e) {
                    console.error('Play error:', e);
                }
            }
            """
    pause_js = """
            () => {
                try {
                    const audioElement = document.querySelector('#main_audio audio');
                    if (audioElement) audioElement.pause();
                } catch (e) {
                    console.error('Pause error:', e);
                }
            }
            """
    stop_js = """
            () => {
                try {
                    const audioElement = document.querySelector('#main_audio audio');
                    if (audioElement) {
                        audioElement.pause();
                        audioElement.currentTime = 0;
                    }
                } catch (e) {
                    console.error('Stop error:', e);
                }
            }
            """

    with gr.Blocks(title="AI MusicGen by ASH", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🎵 Professional Music Generation Interface")

        with gr.Row():
            # ---- Left column: inputs ----
            with gr.Column(scale=2):
                with gr.Group():
                    gr.Markdown("## Model Configuration")
                    model_choice = gr.Dropdown(
                        label="Model Architecture",
                        choices=["small", "medium", "melody", "large"],
                        value="large",
                    )
                    diffusion_toggle = gr.Checkbox(
                        label="Enable Multi-Band Diffusion",
                        value=False,
                    )

                with gr.Group():
                    gr.Markdown("## Generation Parameters")
                    prompt_box = gr.Textbox(
                        label="Music Description",
                        lines=4,
                        placeholder="Describe the music you want to generate...",
                    )
                    duration_slider = gr.Slider(
                        label="Duration (seconds)",
                        minimum=5,
                        maximum=600,
                        step=5,
                        value=30,
                    )
                    with gr.Row():
                        top_k_slider = gr.Slider(
                            label="Top-K Sampling",
                            minimum=0,
                            maximum=500,
                            step=10,
                            value=250,
                        )
                        top_p_slider = gr.Slider(
                            label="Top-P Sampling",
                            minimum=0,
                            maximum=1,
                            step=0.05,
                            value=0.9,
                        )
                    temperature_slider = gr.Slider(
                        label="Creativity Temperature",
                        minimum=0.1,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                    )

                generate_btn = gr.Button("Generate Music", variant="primary")

            # ---- Right column: outputs and playback ----
            with gr.Column(scale=3):
                gr.Markdown("## Output Console")
                with gr.Group():
                    audio_player = gr.Audio(
                        label="Generated Composition",
                        elem_id="main_audio",
                        format="wav",
                    )
                    wav_download = gr.File(label="Download WAV File")
                    status_box = gr.Textbox(label="Generation Status")

                with gr.Group():
                    gr.Markdown("### Playback Management")
                    with gr.Row():
                        play_btn = gr.Button("▶️ Play")
                        pause_btn = gr.Button("⏸️ Pause")
                        stop_btn = gr.Button("⏹️ Stop")

        # Playback handlers are JS-only, registered in play/pause/stop order.
        playback_handlers = (
            (play_btn, play_js),
            (pause_btn, pause_js),
            (stop_btn, stop_js),
        )
        for button, script in playback_handlers:
            button.click(None, _js=script)

        # Input order must match generate_music's positional parameters.
        generation_inputs = [
            prompt_box,
            duration_slider,
            model_choice,
            diffusion_toggle,
            top_k_slider,
            top_p_slider,
            temperature_slider,
        ]
        generate_btn.click(
            fn=generate_music,
            inputs=generation_inputs,
            outputs=[audio_player, wav_download, status_box],
        )

        # Enable the request queue with a concurrency count of one.
        demo.queue(concurrency_count=1)

    return demo

# Entry point: build the UI and start serving it when this code is run
# directly (not when it is imported as a module).
if __name__ == "__main__":
    interface = create_interface()
    # share=True requests a temporary public Gradio share URL in addition to
    # the local server.
    interface.launch(share=True)

IMPORTANT: You are using gradio version 3.50.2, however version 4.44.1 is available, please upgrade.
--------
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://258587e8f3179f4766.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
