In [2]:
#FAST API MUSICGEN INTEGRATION

In [1]:
import asyncio
import threading
import time
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from audiocraft.models import MusicGen
import numpy as np
import soundfile as sf
import tempfile
import os
from pyngrok import ngrok
import uvicorn
import nest_asyncio

In [None]:
# Enable nested event loops for Jupyter
nest_asyncio.apply()

model = None

class PromptRequest(BaseModel):
    prompt: str

# Create FastAPI app
app = FastAPI(title="MusicGen API")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], 
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)

In [None]:
@app.get("/")
async def root():
    return {"message": "MusicGen API is running!", "model_loaded": model is not None}

@app.post("/generate")
async def generate_audio(request: PromptRequest):
    """Generate audio from text prompt"""
    tmp_file_path = None
    try:
        if model is None:
            raise HTTPException(status_code=500, detail="Model not loaded")
        
        print(f"Generating audio for prompt: {request.prompt}")
        
        # Generate audio
        wav = model.generate([request.prompt])
        
        # Convert to numpy array
        audio = wav[0].cpu().numpy()
        sample_rate = model.sample_rate
        
        # Ensure audio is 1D (mono) and normalize
        if audio.ndim > 1:
            audio = audio.squeeze()
            if audio.ndim > 1:
                audio = audio[0]
        
        # Normalize audio
        audio = audio / np.max(np.abs(audio))
        
        # Save to temporary file
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            sf.write(tmp_file.name, audio, sample_rate)
            tmp_file_path = tmp_file.name
            
            # Return file response without background cleanup
            return FileResponse(
                tmp_file.name,
                media_type="audio/wav",
                filename="generated_audio.wav"
            )
            
    except Exception as e:
        # Clean up temp file if there was an error
        if tmp_file_path and os.path.exists(tmp_file_path):
            try:
                os.unlink(tmp_file_path)
            except:
                pass
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "model_loaded": model is not None}

def load_model():
    """Load the MusicGen model"""
    global model
    print("Loading MusicGen model...")
    model = MusicGen.get_pretrained('facebook/musicgen-small')
    model.set_generation_params(duration=2)
    print("✅ Model loaded successfully!")

def start_server_thread():
    """Start the server in a separate thread"""
    def run_server():
        # Prompt for ngrok auth token
        from getpass import getpass
        print("🔑 Please enter your ngrok auth token:")
        print("   Get it from: https://dashboard.ngrok.com/get-started/your-authtoken")
        auth_token = getpass("Auth Token: ")
        
        # Set the auth token
        ngrok.set_auth_token(auth_token)
        
        # Start ngrok tunnel
        public_url = ngrok.connect(7865, domain="steady-notable-manatee.ngrok-free.app")
        print(f"🌐 Public URL: {public_url}")
        print(f"📱 Local URL: http://localhost:7865")
        print(f"📖 API docs: {public_url}/docs")
        
        # Start server
        uvicorn.run(app, host="0.0.0.0", port=7865, log_level="info")
    
    # Start server in background thread
    server_thread = threading.Thread(target=run_server, daemon=True)
    server_thread.start()
    
    # Give the server time to start
    time.sleep(5)
    print("🚀 Server started in background!")

# Load model (run this cell first)
load_model()

# Start server (run this cell second)
start_server_thread()

# Test the API (run this cell to test)
import requests
import json

def test_api(prompt, save_as="test_output.wav"):
    """Test function to generate audio via API"""
    try:
        response = requests.post(
            "http://localhost:7865/generate",
            json={"prompt": prompt},
            stream=True
        )
        
        if response.status_code == 200:
            with open(save_as, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"✅ Audio saved as {save_as}")
            
            # Display audio player
            from IPython.display import Audio, display
            display(Audio(save_as))
            
        else:
            print(f"❌ Error: {response.status_code} - {response.text}")
            
    except Exception as e:
        print(f"❌ Error: {e}")

# Example usage:
# test_api("single kick drum hit, one shot, no music", "kick.wav")
# test_api("808 bass drum, electronic, punchy", "808.wav")

Loading MusicGen model...




✅ Model loaded successfully!
🔑 Please enter your ngrok auth token:
   Get it from: https://dashboard.ngrok.com/get-started/your-authtoken
🚀 Server started in background!


Auth Token:  ········


t=2025-07-16T22:41:01-0400 lvl=warn msg="ngrok config file found at both XDG and legacy locations, using XDG location" xdg_path=/home/jupyter-nn2415/.config/ngrok/ngrok.yml legacy_path=/home/jupyter-nn2415/.ngrok2/ngrok.yml
t=2025-07-16T22:41:01-0400 lvl=warn msg="can't bind default web address, trying alternatives" obj=web addr=127.0.0.1:4040


🌐 Public URL: NgrokTunnel: "https://steady-notable-manatee.ngrok-free.app" -> "http://localhost:7865"
📱 Local URL: http://localhost:7865
📖 API docs: NgrokTunnel: "https://steady-notable-manatee.ngrok-free.app" -> "http://localhost:7865"/docs


INFO:     Started server process [1880485]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:7865 (Press CTRL+C to quit)
t=2025-07-16T22:41:06-0400 lvl=warn msg="failed to check for update" obj=updater err="Post \"https://update.equinox.io/check\": context deadline exceeded"


INFO:     2600:4808:5392:d500:84d9:6a59:910b:6801:0 - "GET / HTTP/1.1" 200 OK
INFO:     2600:4808:5392:d500:84d9:6a59:910b:6801:0 - "GET /favicon.ico HTTP/1.1" 404 Not Found
