# 🎵 YuE Music Generation (Colab Version)
**Notebook contributed by Digital Down**  
All music examples copyright © Digital Down. More examples: [YouTube Channel](https://www.youtube.com/@digital_down)

## Quick Start
1. Installation
2. Restart session and go to step 3
3. Lyrics Configuration
4. Genre Configuration
5. Mode Configuration
6. Advanced Parameters (optional)
7. Generate Music!


# 1. Installation
- Make sure the runtime type is set to GPU.
- Run the installation cell. This will take some time.
- Restart the session only after the installation cell has finished.

---

In [3]:
# @title Setup Environment {"vertical-output":true}

# Install PyTorch 2.1.0 with CUDA 11.8 (matches YuE requirements)
!pip install torch==2.1.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118

# Install specific Flash Attention version
!pip install flash-attn==2.3.3 --no-build-isolation

# Install YuE requirements
!pip install -r <(curl -sSL https://raw.githubusercontent.com/multimodal-art-projection/YuE/main/requirements.txt)
!pip install --upgrade protobuf

# Install system dependencies
!sudo apt-get update
!sudo apt-get install -y git-lfs
!git lfs install

# Clone main repository with cleanup
!rm -rf YuE
!git clone https://github.com/multimodal-art-projection/YuE.git

# Clone model repository with LFS
%cd YuE/inference
!git clone https://huggingface.co/m-a-p/xcodec_mini_infer

# Ensure LFS files are fetched
%cd xcodec_mini_infer
!git lfs pull
%cd ..

# Return to root directory
%cd ../..

# Load models onto the GPU (CUDA)
%cd YuE/inference

import os
import sys
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gc

# Add necessary paths for Colab environment
current_dir = os.getcwd()
sys.path.append(os.path.join(current_dir, 'xcodec_mini_infer'))
sys.path.append(os.path.join(current_dir, 'xcodec_mini_infer', 'descriptaudiocodec'))

def clear_memory():
    """
    Aggressively clear GPU and CPU memory
    """
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

def move_model_to_gpu(model, stage_name):
    """
    Safely move YuE model stage to GPU with memory optimization
    """
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available")

    try:
        print(f"\nMoving {stage_name} to GPU...")
        model = model.half().cuda()  # Convert to float16 and move to GPU
        print(f"{stage_name} successfully moved to GPU")
        return model

    except Exception as e:
        print(f"Error moving {stage_name} to GPU: {str(e)}")
        raise

def load_stage1_model():
    """
    Load Stage 1 (7B) model with memory optimization and proper GPU initialization
    """
    clear_memory()
    print("\nLoading Stage 1 model...")

    # Load model with updated Flash Attention parameter
    model = AutoModelForCausalLM.from_pretrained(
        "m-a-p/YuE-s1-7B-anneal-en-cot",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2"
    ).to('cuda')  # Explicitly move to GPU using .to()

    return model

def load_stage2_model():
    """
    Load Stage 2 (1B) model with memory optimization and proper GPU initialization
    """
    clear_memory()
    print("\nLoading Stage 2 model...")

    # Load model with updated Flash Attention parameter
    model = AutoModelForCausalLM.from_pretrained(
        "m-a-p/YuE-s2-1B-general",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        attn_implementation="flash_attention_2"
    ).to('cuda')  # Explicitly move to GPU using .to()

    return model

# Before loading models, check GPU compatibility
print("Checking GPU compatibility...")
if torch.cuda.is_available():
    device_name = torch.cuda.get_device_name(0)
    print(f"Found GPU: {device_name}")

    # Updated compatibility check including T4
    compatible_gpus = ['t4', 'a100', 'a40', 'a5000', 'a6000', '3090', '3080', '3070', '3060', '4090', '4080', '4070', '4060']
    is_compatible = any(x in device_name.lower() for x in compatible_gpus)

    if not is_compatible:
        raise RuntimeError(f"GPU {device_name} is not compatible with Flash Attention. Need T4, Ampere (RTX 3000+), or newer.")

    # Set default CUDA device
    torch.cuda.set_device(0)
else:
    raise RuntimeError("No GPU available")

# Usage:
try:
    # Load Stage 1 tokenizer
    tokenizer1 = AutoTokenizer.from_pretrained("m-a-p/YuE-s1-7B-anneal-en-cot")

    # Load and move Stage 1 model
    stage1_model = load_stage1_model()

    # Load Stage 2 tokenizer
    tokenizer2 = AutoTokenizer.from_pretrained("m-a-p/YuE-s2-1B-general")

    # Load and move Stage 2 model
    stage2_model = load_stage2_model()

    print("\nBoth models loaded successfully!")

except Exception as e:
    print(f"\nError during model loading: {str(e)}")
    raise

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu118
Collecting protobuf!=4.24.0,>=3.19.6 (from tensorboard->-r /dev/fd/63 (line 9))
  Using cached protobuf-3.19.6-py2.py3-none-any.whl.metadata (828 bytes)
Using cached protobuf-3.19.6-py2.py3-none-any.whl (162 kB)
Installing collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 5.29.3
    Uninstalling protobuf-5.29.3:
      Successfully uninstalled protobuf-5.29.3
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-ai-generativelanguage 0.6.15 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 3.19.6 which is incompatible.
google-cloud-bigtable 2.28.1 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.2, but you have protobuf 

Collecting protobuf
  Using cached protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl.metadata (592 bytes)
Using cached protobuf-5.29.3-cp38-abi3-manylinux2014_x86_64.whl (319 kB)
Installing collected packages: protobuf
  Attempting uninstall: protobuf
    Found existing installation: protobuf 3.19.6
    Uninstalling protobuf-3.19.6:
      Successfully uninstalled protobuf-3.19.6
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
descript-audiotools 0.7.2 requires protobuf<3.20,>=3.9.2, but you have protobuf 5.29.3 which is incompatible.
Successfully installed protobuf-5.29.3
Hit:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease
Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:3 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:4 http://archive.ubuntu.com/ubuntu jammy-updates 

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

config.json:   0%|          | 0.00/683 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/1.76M [00:00<?, ?B/s]


Loading Stage 2 model...


model.safetensors:   0%|          | 0.00/3.93G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]


Both models loaded successfully!
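
The GPU check in the installation cell matches substrings of the device name, so a GPU whose name is missing from the list is rejected regardless of its hardware generation. As an alternative, here is a minimal sketch that compares compute capability instead. The helper name `meets_min_capability` and the default threshold of 7.5 are assumptions for illustration: 7.5 is the T4's compute capability, and the cell's error message names the T4 as the oldest accepted GPU.

```python
from typing import Tuple

def meets_min_capability(capability: Tuple[int, int],
                         minimum: Tuple[int, int] = (7, 5)) -> bool:
    """Return True if a (major, minor) compute capability is at least `minimum`.

    The (7, 5) default is an assumption: it is the T4's capability, which the
    installation cell treats as the oldest accepted GPU.
    """
    return tuple(capability) >= tuple(minimum)

# Tuples compare element by element, so (8, 0) >= (7, 5) and (7, 0) < (7, 5).
examples = {"T4": (7, 5), "A100": (8, 0), "V100": (7, 0)}
for name, cap in examples.items():
    print(name, meets_min_capability(cap))
```

In the notebook, the tuple would come from `torch.cuda.get_device_capability(0)`.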


# 2. Restart session and continue below

---
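
After the restart, it may be worth confirming which package versions the installation left in place, since the pip log above shows protobuf being swapped between 3.19.6 and 5.29.3. The following is a minimal sketch using only the standard library. The helper name `installed_versions` and the choice of packages to inspect are illustrative assumptions.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

# Distribution names taken from the installation cell and its pip log.
for name, ver in installed_versions(["torch", "flash-attn", "protobuf"]).items():
    print(f"{name}: {ver or 'not installed'}")
```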



# 3. Lyrics

---

In [1]:
# @title  {"vertical-output":true}
# %% Configuration Cell
from IPython.display import display
import ipywidgets as widgets

# --- Segment Controls ---
segments_slider = widgets.IntSlider(
    value=2,
    min=1,
    max=8,
    description='Number of Segments:',
    style={'description_width': 'initial'},
    layout={'width': '90%'}
)

segment_types = []
lyric_sections = []

# --- Dynamic Segment Creator ---
def create_segment_ui(change):
    global segment_types, lyric_sections
    num_segments = change['new']

    # Clear previous widgets
    segment_types.clear()
    lyric_sections.clear()

    # Create new dropdowns and text areas
    for i in range(num_segments):
        # Create segment type dropdown
        default_value = 'verse' if i == 0 else 'chorus' if i == 1 else 'verse'
        seg_dropdown = widgets.Dropdown(
            options=['verse', 'chorus', 'bridge', 'outro'],
            value=default_value,
            description=f'Segment {i+1}:',
            layout={'width': '250px', 'margin': '0 0 5px 0'}
        )

        # Create locked header
        header = widgets.Text(
            value=f'[{seg_dropdown.value}]',
            disabled=True,
            layout={'width': '250px', 'margin': '0 0 5px 0'}
        )

        # Create lyrics input
        lyrics_input = widgets.Textarea(
            placeholder=f'Enter {seg_dropdown.value} lyrics...',
            layout={'width': '100%', 'height': '100px'}
        )

        # Link dropdown to header
        def update_header(change, header=header, lyrics_input=lyrics_input):
            header.value = f'[{change["new"]}]'
            lyrics_input.placeholder = f'Enter {change["new"]} lyrics...'
        seg_dropdown.observe(update_header, names='value')

        segment_types.append(seg_dropdown)
        lyric_sections.append(widgets.VBox([
            header,
            lyrics_input
        ], layout={'margin': '0 0 20px 0'}))

    # Update UI elements
    segments_box.children = tuple(segment_types)
    lyrics_box.children = tuple(lyric_sections)

# --- UI Containers ---
segments_box = widgets.VBox([])
lyrics_box = widgets.VBox([])

# --- File Generation ---
def save_lyrics(_):
    lyrics_content = []
    for seg_type, lyrics in zip(segment_types, lyric_sections):
        header = f'[{seg_type.value}]'
        content = lyrics.children[1].value.strip()
        if content:
            lyrics_content.append(f"{header}\n{content}")
        else:
            lyrics_content.append(f"{header}\n[Add your lyrics here]")

    with open('lyrics.txt', 'w') as f:
        f.write('\n\n'.join(lyrics_content))
    print('✅ lyrics.txt generated successfully!')

# --- Assembly ---
generate_btn = widgets.Button(
    description="Generate Lyrics File",
    layout={'margin': '20px 0 0 0', 'width': '300px'},
    button_style='success'
)

generate_btn.on_click(save_lyrics)
segments_slider.observe(create_segment_ui, names='value')

# Initialize with default segments
create_segment_ui({'new': segments_slider.value})

display(
    widgets.VBox([
        widgets.HTML("<h3 style='color: #2c3e50'>Song Structure Configuration</h3>"),
        segments_slider,
        widgets.HTML("<h4 style='margin-top: 20px; color: #34495e'>Section Types</h4>"),
        segments_box,
        widgets.HTML("<h4 style='margin-top: 30px; color: #34495e'>Lyrics Composition</h4>"),
        lyrics_box,
        generate_btn
    ], layout={
        'padding': '25px',
        'border': '2px solid #ecf0f1',
        'border_radius': '10px',
        'width': '95%'
    })
)

VBox(children=(HTML(value="<h3 style='color: #2c3e50'>Song Structure Configuration</h3>"), IntSlider(value=2, …

✅ lyrics.txt generated successfully!
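
The save handler writes each segment as a bracketed header followed by its lyrics, with a blank line between segments. Below is a minimal sketch of that layout as a standalone function, which can be handy for checking the text before it goes into `lyrics.txt`. `format_lyrics` is a hypothetical helper, not part of the notebook, and the sample lyric line is made up.

```python
from typing import Iterable, Tuple

PLACEHOLDER = "[Add your lyrics here]"  # same fallback text the cell writes

def format_lyrics(segments: Iterable[Tuple[str, str]]) -> str:
    """Join (segment_type, lyrics) pairs into the lyrics.txt layout."""
    blocks = []
    for seg_type, lyrics in segments:
        # Empty or whitespace-only lyrics fall back to the placeholder.
        body = lyrics.strip() or PLACEHOLDER
        blocks.append(f"[{seg_type}]\n{body}")
    return "\n\n".join(blocks)

print(format_lyrics([
    ("verse", "City lights are fading out"),
    ("chorus", ""),
]))
```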


# 4. Genre

---

In [2]:
# @title  {"vertical-output":true}
from IPython.display import display
import ipywidgets as widgets

# Genre Box
genre_box = widgets.VBox([
    widgets.HTML("<h3 style='color: #2c3e50'>Genre Configuration</h3>"),
    widgets.Text(
        value="inspiring female pop electronic vocal",
        placeholder="Enter space-separated genre tags...",
        description='Tags:',
        layout={'width': '90%'},
        style={'description_width': '100px'}
    ),
    widgets.Button(
        description="Generate Genre File",
        button_style='success',
        layout={'width': '300px', 'margin': '20px 0 0 0'}
    )
], layout={
    'padding': '20px',
    'border': '2px solid #e0e0e0',
    'width': '95%',
    'border_radius': '8px'
})

# Save handler
def save_genre(b):
    with open('genre.txt', 'w') as f:
        f.write(genre_box.children[1].value)
    print(f"✅ genre.txt generated successfully!")

genre_box.children[-1].on_click(save_genre)
display(genre_box)

VBox(children=(HTML(value="<h3 style='color: #2c3e50'>Genre Configuration</h3>"), Text(value='inspiring female…

✅ genre.txt generated successfully!
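
The genre field expects space-separated tags and is written to `genre.txt` unchanged. Here is a small sketch, assuming it is desirable to collapse stray whitespace and drop repeated tags before saving. `normalize_genre_tags` is a hypothetical helper, not part of the notebook.

```python
def normalize_genre_tags(raw: str) -> str:
    """Collapse whitespace and drop repeated tags, keeping first-seen order."""
    seen = set()
    tags = []
    for tag in raw.split():
        key = tag.lower()  # compare case-insensitively, keep original spelling
        if key not in seen:
            seen.add(key)
            tags.append(tag)
    return " ".join(tags)

print(normalize_genre_tags("  inspiring  female pop   electronic vocal pop "))
# prints: inspiring female pop electronic vocal
```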


# 5. Mode

---







In [3]:
# @title  {"vertical-output":true}
from IPython.display import display
import ipywidgets as widgets
import os

# --- Style Configuration ---
box_style = {
    'padding': '20px',
    'border': '2px solid #e0e0e0',
    'margin': '0 0 20px 0',
    'border_radius': '8px',
    'width': '95%'
}

# --- Mode Configuration ---
mode_box = widgets.VBox([
    widgets.HTML("<h3 style='color: #2c3e50; margin-bottom: 15px'>Mode Configuration</h3>"),

    widgets.Dropdown(
        options=['CoT', 'Single-Track ICL', 'Dual-Track ICL'],
        value='CoT',
        description='Mode:',
        style={'description_width': '100px'},
        layout={'width': '300px'}
    ),

    widgets.HTML("""<div style="margin: 10px 0; padding: 10px; background: #f8f9fa; border-radius: 5px;">
        <div id="modeDesc" style="color: #2c3e50; font-size: 0.95em; line-height: 1.4;">
            Chain-of-Thought: No reference audio needed
        </div>
    </div>"""),

    # Placeholder for upload widgets, filled in by update_mode_ui
    # (VBox has no `id` trait, so that argument is dropped)
    widgets.VBox([], layout={'width': '100%'}),

    widgets.Text(
        value='180',
        description='Song Length (seconds):',
        style={'description_width': '150px'},
        layout={'width': '300px', 'margin': '20px 0 0 0'}
    )
], layout=box_style)

# --- Generation Parameters ---
param_box = widgets.VBox([
    widgets.HTML("<h3 style='color: #2c3e50; margin-bottom: 15px'>⚙️ Generation Parameters</h3>"),

    widgets.IntSlider(
        value=2,
        min=1,
        max=16,
        description='Batch Size:',
        style={'description_width': '100px'},
        layout={'width': '90%'}
    )
], layout=box_style)

# --- Mode Descriptions ---
mode_descriptions = {
    'CoT': "🎵 <b>Chain-of-Thought (CoT)</b><br>Generate music from scratch using only lyrics and genre tags with no reference audio",
    'Single-Track ICL': "🎧 <b>Single-Track In-Context Learning</b><br>Upload any single audio track (music, vocals, or instrumental) for style reference",
    'Dual-Track ICL': "🎤+🎹 <b>Dual-Track In-Context Learning</b><br>Provide separate vocal and instrumental tracks for precise style matching"
}

# Colab-specific CSS to hide upload counters
display(widgets.HTML('''
<style>
    .widget-fileupload button span:last-child {
        display: none !important;
    }
    .widget-fileupload {
        width: 400px !important;
    }
</style>
'''))

def update_mode_ui(change):
    """Update UI based on selected mode"""
    mode = change['new']

    # Update description
    mode_box.children[2].value = f"""
    <div style="margin: 10px 0; padding: 10px; background: #f8f9fa; border-radius: 5px;">
        <div style="color: #2c3e50; font-size: 0.95em; line-height: 1.4;">
            {mode_descriptions[mode]}
        </div>
    </div>
    """

    # Update file upload section
    uploads = []
    if mode == 'Single-Track ICL':
        uploads.append(widgets.FileUpload(
            description="Upload Reference Track",
            multiple=False,
            layout={'width': '400px'}
        ))
    elif mode == 'Dual-Track ICL':
        uploads.append(widgets.VBox([
            widgets.FileUpload(
                description="Upload Vocals Track",
                multiple=False,
                layout={'width': '400px'}
            ),
            widgets.FileUpload(
                description="Upload Instrumental Track",
                multiple=False,
                layout={'width': '400px'}
            )
        ]))

    mode_box.children[3].children = uploads

# --- Validation ---
def validate_song_length(change):
    try:
        int(change['new'])
        mode_box.children[4].style = {'description_color': 'black'}
    except ValueError:
        mode_box.children[4].style = {'description_color': 'red'}

# --- Save Configuration ---
save_btn = widgets.Button(
    description="💾 Save All Configurations",
    button_style='success',
    layout={'width': '300px', 'margin': '20px 0 0 0'}
)

def save_config(b):
    # Validate song length (must parse as a positive integer)
    try:
        song_length = int(mode_box.children[4].value)
        if song_length <= 0:
            raise ValueError("song length must be positive")
    except ValueError:
        print("❌ Error: Song length must be a positive integer")
        return

    # Save parameters
    config = {
        'mode': mode_box.children[1].value,
        'batch_size': param_box.children[1].value,
        'song_length': song_length
    }

    # Save audio files
    os.makedirs('ref_audio', exist_ok=True)
    mode = config['mode']

    if mode == 'Single-Track ICL' and len(mode_box.children[3].children) > 0:
        uploader = mode_box.children[3].children[0]
        if uploader.value:
            with open('ref_audio/ref_audio.mp3', 'wb') as f:
                f.write(uploader.value[0]['content'])
    elif mode == 'Dual-Track ICL' and len(mode_box.children[3].children) > 0:
        vocals, inst = mode_box.children[3].children[0].children
        if vocals.value and inst.value:
            with open('ref_audio/ref_vocals.mp3', 'wb') as f:
                f.write(vocals.value[0]['content'])
            with open('ref_audio/ref_instrumental.mp3', 'wb') as f:
                f.write(inst.value[0]['content'])

    print(f"""
✅ Configuration saved:
   - Mode: {config['mode']}
   - Batch Size: {config['batch_size']}
   - Song Length: {config['song_length']} seconds
    """)

# --- Setup ---
mode_box.children[1].observe(update_mode_ui, names='value')
mode_box.children[4].observe(validate_song_length, names='value')
save_btn.on_click(save_config)

# Initialize UI
update_mode_ui({'new': 'CoT'})

# Display everything
display(widgets.VBox([
    widgets.HTML("<h2 style='color: #2c3e50; margin-bottom: 20px'></h2>"),
    mode_box,
    param_box,
    save_btn
], layout={'padding': '15px', 'width': '100%'}))

HTML(value='\n<style>\n    .widget-fileupload button span:last-child {\n        display: none !important;\n   …

VBox(children=(HTML(value="<h2 style='color: #2c3e50; margin-bottom: 20px'></h2>"), VBox(children=(HTML(value=…


✅ Configuration saved:
   - Mode: CoT
   - Batch Size: 2
   - Song Length: 180 seconds
    
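
The save handler reads uploads as `uploader.value[0]['content']`, which fits a tuple-of-dicts layout. As far as the ipywidgets documentation describes it, that layout belongs to ipywidgets 8, while ipywidgets 7 exposes `value` as a dict keyed by filename, where an index lookup would fail. The following is a hedged sketch of a reader that accepts both layouts. `first_upload_bytes` is a hypothetical helper, and both layouts are assumptions about the installed ipywidgets version.

```python
from typing import Any, Optional

def first_upload_bytes(value: Any) -> Optional[bytes]:
    """Return the content of the first uploaded file, or None if nothing was uploaded.

    Accepts either a dict keyed by filename (assumed ipywidgets 7 layout) or a
    sequence of dicts (assumed ipywidgets 8 layout).
    """
    if not value:
        return None
    if isinstance(value, dict):
        first = next(iter(value.values()))
    else:
        first = value[0]
    # Content may be bytes or a memoryview; bytes() handles both.
    return bytes(first["content"])

# Simulated upload values for each layout.
v7_style = {"ref.mp3": {"metadata": {"name": "ref.mp3"}, "content": b"ID3"}}
v8_style = ({"name": "ref.mp3", "content": memoryview(b"ID3")},)
print(first_upload_bytes(v7_style), first_upload_bytes(v8_style), first_upload_bytes(()))
```

In `save_config`, the result would replace the direct `value[0]['content']` lookup, with `None` meaning no file was uploaded.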


# 6. Advanced

---

In [None]:
# @title {"vertical-output":true}
from IPython.display import display
import ipywidgets as widgets

param_box = widgets.VBox([
    widgets.HTML("<h3 style='color: #2c3e50; margin-bottom: 15px'>Advanced Parameters (Optional)</h3>"),
    widgets.IntSlider(
        value=2,
        min=1,
        max=16,
        description='Batch Size:',
        style={'description_width': '100px'},
        layout={'width': '90%'}
    )
], layout={
    'padding': '20px',
    'border': '2px solid #e0e0e0',
    'margin': '0 0 20px 0',
    'border_radius': '8px',
    'width': '95%'
})

display(param_box)

# 7. Generate

---

In [4]:
# @title  {"vertical-output":true}
from IPython.display import display
import ipywidgets as widgets
import os

def run_inference(_):
    # Get config values
    mode = mode_box.children[1].value
    batch_size = int(param_box.children[1].value) if 'param_box' in globals() else 4
    n_segments = segments_slider.value
    song_length = int(mode_box.children[4].value)

    # Base command
    cmd = [
        "cd YuE/inference/ &&",
        "python infer.py",
        f"--cuda_idx 0",
        f"--stage2_model m-a-p/YuE-s2-1B-general",
        f"--genre_txt /content/genre.txt",
        f"--lyrics_txt /content/lyrics.txt",
        f"--run_n_segments {n_segments}",
        f"--stage2_batch_size {batch_size}",
        f"--output_dir ../output",
        f"--max_new_tokens {min(3000, song_length*20)}",  # Heuristic based on song length
        "--repetition_penalty 1.1"
    ]

    # Stage1 model selection
    if "ICL" in mode:
        cmd.insert(4, "--stage1_model m-a-p/YuE-s1-7B-anneal-en-icl")
    else:
        cmd.insert(4, "--stage1_model m-a-p/YuE-s1-7B-anneal-en-cot")

    # Add mode-specific parameters.
    # Reference audio is saved relative to the notebook's working directory, but
    # infer.py runs from YuE/inference, so absolute paths are passed to it.
    if mode == 'Single-Track ICL':
        ref_audio = os.path.abspath('ref_audio/ref_audio.mp3')
        if os.path.exists(ref_audio):
            cmd += [
                "--use_audio_prompt",
                f"--audio_prompt_path {ref_audio}",
                "--prompt_start_time 0",
                f"--prompt_end_time {min(30, song_length)}"
            ]

    elif mode == 'Dual-Track ICL':
        ref_vocals = os.path.abspath('ref_audio/ref_vocals.mp3')
        ref_inst = os.path.abspath('ref_audio/ref_instrumental.mp3')
        if os.path.exists(ref_vocals) and os.path.exists(ref_inst):
            cmd += [
                "--use_dual_tracks_prompt",
                f"--vocal_track_prompt_path {ref_vocals}",
                f"--instrumental_track_prompt_path {ref_inst}",
                "--prompt_start_time 0",
                f"--prompt_end_time {min(30, song_length)}"
            ]

    # Convert to single string
    full_cmd = " ".join(cmd)

    # Execute
    print("🚀 Starting music generation with command:")
    print(full_cmd)
    !{full_cmd}
    print("\n🎉 Generation complete! Check output directory for results.")

# Create and display run button
run_btn = widgets.Button(
    description="▶️ Start Music Generation",
    button_style='success',
    layout={'width': '300px', 'margin': '20px 0 0 0'}
)
run_btn.on_click(run_inference)
display(run_btn)

Button(button_style='success', description='▶️ Start Music Generation', layout=Layout(margin='20px 0 0 0', wid…

🚀 Starting music generation with command:
cd YuE/inference/ && python infer.py --cuda_idx 0 --stage2_model m-a-p/YuE-s2-1B-general --stage1_model m-a-p/YuE-s1-7B-anneal-en-cot --genre_txt /content/genre.txt --lyrics_txt /content/lyrics.txt --run_n_segments 2 --stage2_batch_size 2 --output_dir ../output --max_new_tokens 3000 --repetition_penalty 1.1
2025-02-06 23:41:39.082480: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-06 23:41:39.101542: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738885299.123404   12908 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when on
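
The generation cell assembles one shell string and caps `--max_new_tokens` at `min(3000, song_length*20)`, so any song length of 150 seconds or more yields the same 3000-token limit. Below is a minimal sketch of the same command built as an argument list, which avoids quoting problems if a path contains spaces. `build_infer_args` and its `ref_dir` parameter are hypothetical, and every flag is copied from the cell above.

```python
import shlex
from typing import List

def build_infer_args(mode: str, n_segments: int, batch_size: int,
                     song_length: int, ref_dir: str = "/content/ref_audio") -> List[str]:
    """Build the infer.py argument list used by the generation cell."""
    stage1 = ("m-a-p/YuE-s1-7B-anneal-en-icl" if "ICL" in mode
              else "m-a-p/YuE-s1-7B-anneal-en-cot")
    args = [
        "python", "infer.py",
        "--cuda_idx", "0",
        "--stage1_model", stage1,
        "--stage2_model", "m-a-p/YuE-s2-1B-general",
        "--genre_txt", "/content/genre.txt",
        "--lyrics_txt", "/content/lyrics.txt",
        "--run_n_segments", str(n_segments),
        "--stage2_batch_size", str(batch_size),
        "--output_dir", "../output",
        "--max_new_tokens", str(min(3000, song_length * 20)),
        "--repetition_penalty", "1.1",
    ]
    # Reference-audio window shared by both ICL modes.
    prompt_window = ["--prompt_start_time", "0",
                     "--prompt_end_time", str(min(30, song_length))]
    if mode == "Single-Track ICL":
        args += ["--use_audio_prompt",
                 "--audio_prompt_path", f"{ref_dir}/ref_audio.mp3", *prompt_window]
    elif mode == "Dual-Track ICL":
        args += ["--use_dual_tracks_prompt",
                 "--vocal_track_prompt_path", f"{ref_dir}/ref_vocals.mp3",
                 "--instrumental_track_prompt_path", f"{ref_dir}/ref_instrumental.mp3",
                 *prompt_window]
    return args

args = build_infer_args("CoT", n_segments=2, batch_size=2, song_length=180)
print(shlex.join(args))  # shell-safe rendering, for display only
```

The list could then be run with `subprocess.run(args, cwd="YuE/inference", check=True)` in place of the `!` shell escape. Unlike the cell, this sketch does not check that the reference files exist.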

## License & Attribution
**YuE Model License:** Apache 2.0  
**Colab Implementation:** Contributed by Digital Down  
**Music Examples:** © Digital Down, All Rights Reserved  

When sharing generated music:
- Include "AI-generated with YuE" in credits
- Recommended hashtags: #AImusic #YuEGen

[More Music Examples](https://youtube.com/@digital_down) | [Report Issues](https://github.com/multimodal-art-projection/YuE/issues)