In [None]:
# Imports
import random
import torch
import soundfile as sf
import os
import json
import math
import sys
from contextlib import contextmanager
from transformers import AutoProcessor, MusicgenForConditionalGeneration, EncodecModel
from datasets import load_dataset
from musicrfm import MusicGenController
from musicrfm.utils import make_json_serializable
import IPython.display as ipd

# Context manager to suppress print statements for the training block (since xRFM prints a lot w/o a verbose option)
@contextmanager
def suppress_output():
    """Temporarily redirect ``sys.stdout`` and ``sys.stderr`` to ``os.devnull``.

    While the ``with`` body runs, anything written to either stream is
    discarded. The streams that were active on entry are put back when the
    body finishes, whether it returns normally or raises.
    """
    sink = open(os.devnull, 'w')
    saved_streams = (sys.stdout, sys.stderr)
    try:
        sys.stdout = sys.stderr = sink
        yield
    finally:
        # Restore the streams first, then release the devnull handle.
        sys.stdout, sys.stderr = saved_streams
        sink.close()


## Config
Modify the parameters below to customize your training and generation

In [None]:
# Configuration
# ==================== Training Configuration ====================
SEED = 319  # Seed applied to Python's `random` and to PyTorch below
TARGET_NOTE = "C"  # The note to control (C, D, E, F, G, A, B, or with sharps/flats like C#, D#)
NUM_EXAMPLES = -1  # Number of examples per class (-1 for all available)
N_COMPONENTS = 12  # Number of RFM components
RFM_ITERS = 30  # Number of RFM iterations
BATCH_SIZE = 16

# ==================== Generation Configuration ====================
# Test prompts for generation
TEST_PROMPTS = [
    "Trance edm song with synths and reverb",
    "slow and relaxing, chill lofi hip hop song",
    "Fast beat, hip hop, upbeat that has a positive vibe",
    "Dreamy future bass with chopped synth chords"
]

# Layer selection method: "all", "top_k", or "exp_weighting"
LAYER_SELECTION = "exp_weighting"
TOP_K = 12  # Only used if LAYER_SELECTION == "top_k"
EXP_BASE_WEIGHT = 1.0  # Base weight for exponential weighting
EXP_DECAY_RATE = 0.95  # Decay rate for exponential weighting

# Control coefficients to test
CONTROL_COEFFICIENTS = [0.4]

# Time control: None, "exp_decay", or "linear_decay"
TIME_CONTROL = "exp_decay"
TIME_DECAY_RATE = 0.998  # For exponential decay

# Probabilistic injection (0.0-1.0)
INJECT_CHANCE = 0.3

# Is this a regression task? (False for classification like notes)
IS_REGRESSION = False

# ==================== Paths ====================
OUTPUT_DIR = "./trained_concepts"
DIRECTIONS_PATH = os.path.join(OUTPUT_DIR, "directions")
RESULTS_PATH = os.path.join(OUTPUT_DIR, "results")
GENERATIONS_PATH = os.path.join(OUTPUT_DIR, "generations")

# Set random seeds
random.seed(SEED)
torch.manual_seed(SEED)

# ==================== Device ====================
# Preferred GPU index. It is only used when CUDA is available and the machine
# actually exposes that many GPUs; otherwise we fall back to the first GPU,
# or to the CPU when CUDA is unavailable.
CUDA_DEVICE_INDEX = 3

if not torch.cuda.is_available():
    DEVICE = "cpu"
elif 0 <= CUDA_DEVICE_INDEX < torch.cuda.device_count():
    DEVICE = f"cuda:{CUDA_DEVICE_INDEX}"
else:
    # The requested index does not exist on this machine; use the first GPU
    # instead of failing later when the models are moved to the device.
    DEVICE = "cuda:0"

print("Configuration loaded")
print(f"Device: {DEVICE}")
print(f"Target: {TARGET_NOTE}")
print(f"Layer selection: {LAYER_SELECTION}")


## Load Models and Dataset
To train a different concept, change the dataset loaded in the training section below.

In [None]:
# Hugging Face checkpoint identifiers for the music generator and the audio codec
MUSICGEN_CHECKPOINT = "facebook/musicgen-large"
ENCODEC_CHECKPOINT = "facebook/encodec_32khz"

print("Loading models...")

# MusicGen model and its processor
music_model = MusicgenForConditionalGeneration.from_pretrained(MUSICGEN_CHECKPOINT)
music_model = music_model.to(DEVICE)
music_processor = AutoProcessor.from_pretrained(MUSICGEN_CHECKPOINT)

# EnCodec model and its processor
encodec_model = EncodecModel.from_pretrained(ENCODEC_CHECKPOINT)
encodec_model = encodec_model.to(DEVICE)
encodec_processor = AutoProcessor.from_pretrained(ENCODEC_CHECKPOINT)

# Controller settings taken from the configuration cell
controller_settings = {
    "control_method": "music_rfm",
    "n_components": N_COMPONENTS,
    "rfm_iters": RFM_ITERS,
    "batch_size": BATCH_SIZE,
}
controller = MusicGenController(
    music_model,
    music_processor,
    encodec_model,
    encodec_processor,
    **controller_settings,
)
print("Models loaded successfully!")

## Training and Saving Control Directions

You only need to load data / train the directions once to use them for any prompt. After training once you can simply use them in the next step (generation).

In [None]:
print("Loading Syntheory dataset...")
dataset = load_dataset("meganwei/syntheory", "notes")["train"]
print(f"Dataset loaded with {len(dataset)} examples")

# Split data into positive (target note) and negative (other notes).
# A single pass over the dataset is enough to fill both lists, and keeps the
# same relative order in each list as two separate filtered passes would.
print(f"\nPreparing data for note: {TARGET_NOTE}")
positive_examples = []
negative_examples = []
for example in dataset:
    if example["root_note_name"] == TARGET_NOTE:
        positive_examples.append(example)
    else:
        negative_examples.append(example)

print(f"Found {len(positive_examples)} positive examples (note {TARGET_NOTE})")
print(f"Found {len(negative_examples)} negative examples (other notes)")

# Fail early with a clear message (e.g. on a mistyped TARGET_NOTE) instead of
# producing empty splits that only break later during feature concatenation.
if not positive_examples or not negative_examples:
    raise ValueError(
        f"Need at least one positive and one negative example for note "
        f"{TARGET_NOTE!r}; found {len(positive_examples)} positive and "
        f"{len(negative_examples)} negative. Check TARGET_NOTE."
    )

# Sample and split into train/val/test (70/15/15)
num_examples = min(len(positive_examples), len(negative_examples))
if NUM_EXAMPLES != -1:
    num_examples = min(num_examples, NUM_EXAMPLES)

positive_samples = random.sample(positive_examples, num_examples)
negative_samples = random.sample(negative_examples, num_examples)

# Both classes have `num_examples` samples, so the per-class split sizes match.
n_train_pos = n_train_neg = int(0.7 * num_examples)
n_val_pos = n_val_neg = int(0.15 * num_examples)

# int() truncation can leave the train or validation split empty when there
# are very few examples per class.
if n_train_pos == 0 or n_val_pos == 0:
    raise ValueError(
        f"{num_examples} example(s) per class is too few for a 70/15/15 split "
        f"(train={n_train_pos}, val={n_val_pos} per class). Increase NUM_EXAMPLES."
    )

train_samples = positive_samples[:n_train_pos] + negative_samples[:n_train_neg]
val_samples = (
    positive_samples[n_train_pos:n_train_pos + n_val_pos]
    + negative_samples[n_train_neg:n_train_neg + n_val_neg]
)
# The test split takes whatever remains after train and validation.
test_samples = (
    positive_samples[n_train_pos + n_val_pos:]
    + negative_samples[n_train_neg + n_val_neg:]
)

random.shuffle(train_samples)
random.shuffle(val_samples)
random.shuffle(test_samples)

print(f"\n✓ Data split complete:")
print(f"  Train: {len(train_samples)} samples")
print(f"  Val: {len(val_samples)} samples")
print(f"  Test: {len(test_samples)} samples")


In [None]:
# Build the feature and label tensors used to train the control directions
def _target_note_labels(samples):
    """Return an (N, 1) tensor with 1 where a sample's root note is TARGET_NOTE, else 0."""
    flags = [1 if sample["root_note_name"] == TARGET_NOTE else 0 for sample in samples]
    return torch.tensor(flags).reshape(-1, 1)


print("Extracting audio features...")
train_features, val_features, test_features = (
    [controller.get_audio_features(sample) for sample in split]
    for split in (train_samples, val_samples, test_samples)
)
print("Features extracted")

train_labels = _target_note_labels(train_samples)
val_labels = _target_note_labels(val_samples)
test_labels = _target_note_labels(test_samples)

# Concatenate the per-sample features of each split along dim 0
train_data, val_data, test_data = (
    torch.cat(features, dim=0)
    for features in (train_features, val_features, test_features)
)

print(f"\n✓ Data prepared:")
for split_name, split_data in (("Train", train_data), ("Val", val_data), ("Test", test_data)):
    print(f"  {split_name}: {split_data.shape}")


In [None]:
print("Computing control directions (this may take 10-20 minutes)...")
print("Note: Progress output is suppressed for cleaner output. Be patient!")

# Arguments for compute_directions, collected up front for readability
direction_kwargs = dict(
    train_data=train_data,
    train_labels=train_labels,
    val_data=val_data,
    val_labels=val_labels,
    test_data=test_data,
    test_labels=test_labels,
    hidden_layers=list(range(-1, -48, -1)),
    tuning_metric='auc',
    pooling='mean',
    hyperparam_samples=20,  # number of times to randomly sample hyperparameters
)

# compute_directions output is large, so it is silenced here; drop the
# context manager if you want to see it.
with suppress_output():
    test_predictor_accs, test_direction_accs, results = controller.compute_directions(
        **direction_kwargs
    )

# Store the configuration and test metrics alongside the returned results
results['n_components'] = N_COMPONENTS
results['test_predictor_accs'] = test_predictor_accs
results['test_direction_accs'] = test_direction_accs

banner = "=" * 60
print("\n" + banner)
print("Training Results")
print(banner)
print(f"Test Predictor Accuracy: {test_predictor_accs}")
print(f"Test Direction Accuracy: {test_direction_accs}")
print(banner)

# Make sure every output directory exists before writing anything
for output_dir in (OUTPUT_DIR, DIRECTIONS_PATH, RESULTS_PATH, GENERATIONS_PATH):
    os.makedirs(output_dir, exist_ok=True)

# Write the results as JSON, named after the concept
concept_name = f"note_{TARGET_NOTE}_ncomp{N_COMPONENTS}"
results_file = os.path.join(RESULTS_PATH, f"{concept_name}.json")

with open(results_file, 'w') as results_out:
    json.dump(make_json_serializable(results), results_out, indent=4)

# Save the trained directions through the controller
controller.save(
    concept=concept_name,
    model_name="musicgen_large",
    path=DIRECTIONS_PATH,
)

print(f"Saved concept: {concept_name}")
print(f"Directions: {DIRECTIONS_PATH}")
print(f"Results: {results_file}")


## Layer Selection (if you only want to generate, start from here)

Select which layers to control based on training performance:


In [None]:
# Rebuild the concept name and results path from the configuration so this
# cell does not depend on variables set by the training cell.
concept_name = f"note_{TARGET_NOTE}_ncomp{N_COMPONENTS}"
results_file = os.path.join(RESULTS_PATH, f"{concept_name}.json")

# Load the saved training results (per-layer metrics) from disk
with open(results_file, 'r') as results_in:
    results = json.load(results_in)

def select_top_k_layers(results, k, regression=False):
    """Pick the ``k`` best layers according to their training score.

    Args:
        results: Mapping with a ``'layer_metrics'`` entry whose keys are layer
            identifiers convertible with ``int()`` and whose values contain
            ``['train_results']['score']``.
        k: Number of layers to keep.
        regression: If true, lower scores rank first; otherwise higher scores
            rank first.

    Returns:
        The selected layer numbers as integers, sorted ascending.
    """
    metrics_by_layer = results['layer_metrics']

    def train_score(layer_key):
        return metrics_by_layer[layer_key]['train_results']['score']

    # Ascending for regression (lower is better), descending for classification.
    ranked_layers = sorted(metrics_by_layer, key=train_score, reverse=not regression)
    layers_to_control = sorted(int(layer_key) for layer_key in ranked_layers[:k])

    print(f"Selected top {k} layers: {layers_to_control}")
    return layers_to_control

def select_exponential_layer_dropout(results, regression=False, base_weight=1.0, decay_rate=0.95):
    """Select all layers and weight each one by its normalized training score.

    Each layer's score is min-max normalized to [0, 1] so that the best layer
    maps to 1 and the worst to 0, then turned into a weight
    ``base_weight * normalized_score ** (1 / decay_rate)``.

    Args:
        results: Mapping with a ``'layer_metrics'`` entry whose keys are layer
            identifiers convertible with ``int()`` and whose values contain
            ``['train_results']['score']``.
        regression: If true, lower scores are better; otherwise higher scores
            are better.
        base_weight: Weight given to the best-scoring layer.
        decay_rate: Controls the exponent ``1 / decay_rate`` applied to the
            normalized score.

    Returns:
        A ``(layers_to_control, layer_weights)`` pair of parallel lists, with
        layers in ascending numeric order.

    Raises:
        ValueError: If ``results['layer_metrics']`` contains no layers.
    """
    layer_scores = {
        int(layer_str): metrics['train_results']['score']
        for layer_str, metrics in results['layer_metrics'].items()
    }
    if not layer_scores:
        raise ValueError("results['layer_metrics'] is empty; there are no layers to weight")

    # Normalize scores to [0, 1]
    lowest = min(layer_scores.values())
    highest = max(layer_scores.values())
    if highest == lowest:
        # Every layer scored the same (this includes the single-layer case).
        # Min-max normalization would divide by zero, so treat all layers as
        # equally good and give each the full weight.
        normalized_scores = {layer: 1.0 for layer in layer_scores}
    elif regression:
        best_score, worst_score = lowest, highest
        normalized_scores = {layer: (worst_score - score) / (worst_score - best_score)
                             for layer, score in layer_scores.items()}
    else:
        best_score, worst_score = highest, lowest
        normalized_scores = {layer: (score - worst_score) / (best_score - worst_score)
                             for layer, score in layer_scores.items()}

    # Calculate exponential weights, visiting layers in ascending order
    layers_to_control = sorted(layer_scores)
    layer_weights = [
        base_weight * (normalized_scores[layer_num] ** (1 / decay_rate))
        for layer_num in layers_to_control
    ]

    print(f"Using exponential layer dropout with {len(layers_to_control)} layers")
    print(f"Top 5 layers by weight:")
    sorted_by_weight = sorted(zip(layers_to_control, layer_weights), key=lambda x: x[1], reverse=True)
    for layer, weight in sorted_by_weight[:5]:
        print(f"  Layer {layer}: weight={weight:.4f}")

    return layers_to_control, layer_weights

# Choose the layers to control (and optional per-layer weights) from the configuration
layer_weights = None

# Reject unknown methods up front so the branches below only handle valid ones
if LAYER_SELECTION not in ("all", "top_k", "exp_weighting"):
    raise ValueError(f"Invalid layer selection method: {LAYER_SELECTION}")

if LAYER_SELECTION == "exp_weighting":
    layers_to_control, layer_weights = select_exponential_layer_dropout(
        results, IS_REGRESSION, EXP_BASE_WEIGHT, EXP_DECAY_RATE
    )
elif LAYER_SELECTION == "top_k":
    layers_to_control = select_top_k_layers(results, TOP_K, IS_REGRESSION)
else:
    # "all": the same layer range used when computing the directions
    layers_to_control = list(range(-1, -48, -1))
    print(f"Using all {len(layers_to_control)} layers")


## Time Control Functions (Optional)

Define time-varying control functions:


In [None]:
def exponential_decay(t, base_coef, decay_rate=0.998):
    """Return ``base_coef`` scaled by ``decay_rate`` raised to the power ``t``."""
    decay_factor = decay_rate ** t
    return base_coef * decay_factor

def linear_decay(t, base_coef, total_steps=1500):
    """Scale ``base_coef`` linearly from its full value at ``t = 0`` to zero at ``t = total_steps``.

    The fraction ``t / total_steps`` is clamped to [0, 1], so the result stays
    at ``base_coef`` for negative ``t`` and at zero once ``t`` reaches
    ``total_steps``.
    """
    progress = t / total_steps
    clamped_progress = min(max(progress, 0), 1)
    return base_coef * (1 - clamped_progress)

# Setup time control function.
# `time_control_fn` stays None for constant control; otherwise it is a
# two-argument callable (t, base) wrapping one of the decay functions above.
time_control_fn = None
if TIME_CONTROL is None:
    print("Using constant control (no time variation)")
elif TIME_CONTROL == "exp_decay":
    time_control_fn = lambda t, base: exponential_decay(t, base, TIME_DECAY_RATE)
    print(f"Using exponential decay with rate {TIME_DECAY_RATE}")
elif TIME_CONTROL == "linear_decay":
    time_control_fn = lambda t, base: linear_decay(t, base)
    print("Using linear decay over 1500 steps")
else:
    # A misspelled option would otherwise silently fall back to constant
    # control; fail loudly, as the layer-selection cell does for LAYER_SELECTION.
    raise ValueError(f"Invalid time control method: {TIME_CONTROL}")


## Generate Controlled Music

Generate music with and without control for comparison


In [None]:
import time

# The training cell also creates this directory, but this cell must work when
# starting from the layer-selection section with previously saved directions.
os.makedirs(GENERATIONS_PATH, exist_ok=True)

# Build a fresh controller and load the saved directions for the concept
controller = MusicGenController(
    music_model,
    music_processor,
    encodec_model,
    encodec_processor,
    control_method="music_rfm",
    n_components=N_COMPONENTS,  # Should match training
    rfm_iters=RFM_ITERS,
    batch_size=BATCH_SIZE
)

controller.load(
    concept=concept_name,
    model_name="musicgen_large",
    path=DIRECTIONS_PATH
)

sampling_rate = music_model.config.audio_encoder.sampling_rate

print("\n" + "="*60)
print("Generating Controlled Music")
print("="*60)

BATCH_SIZE_GEN = 4  # Number of prompts generated together
MAX_NEW_TOKENS = 1000  # Generation length, shared by baseline and controlled runs

# Configuration-dependent filename suffix. It does not change per prompt or
# per coefficient, so it is built once. The "Listen to Results" cell rebuilds
# the same filenames, so the naming scheme must stay in sync with it.
filename_suffix_parts = []
if LAYER_SELECTION != "all":
    filename_suffix_parts.append(LAYER_SELECTION)
if TIME_CONTROL:
    filename_suffix_parts.append(TIME_CONTROL)
if INJECT_CHANCE < 1.0:
    filename_suffix_parts.append(f"inject{INJECT_CHANCE}")

for i in range(0, len(TEST_PROMPTS), BATCH_SIZE_GEN):
    batch_prompts = TEST_PROMPTS[i:i+BATCH_SIZE_GEN]
    # Global prompt indices of this batch, used in the output filenames
    batch_indices = list(range(i, i + len(batch_prompts)))

    print(f"\nProcessing batch {i//BATCH_SIZE_GEN + 1}: prompts {i+1}-{i + len(batch_prompts)}")
    for idx, prompt in zip(batch_indices, batch_prompts):
        print(f"  Prompt {idx+1}: '{prompt}'")

    # Generate baseline (no control) for entire batch
    print("  Generating baseline (no control)...")
    inputs = music_processor(
        text=batch_prompts,
        padding=True,
        return_tensors="pt"
    ).to(DEVICE)

    start_time = time.time()
    with torch.no_grad():
        baseline_audios = music_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
    baseline_elapsed = time.time() - start_time
    print(f"    ✓ Baseline generation took {baseline_elapsed:.2f} seconds")

    for prompt_idx, audio in zip(batch_indices, baseline_audios):
        baseline_path = f"{GENERATIONS_PATH}/prompt{prompt_idx}_baseline.flac"
        # audio[0] selects the first entry along the leading axis of each item
        # (presumably the single audio channel — confirm against the model output)
        sf.write(baseline_path, audio[0].cpu().numpy(), sampling_rate, format='FLAC')
    print(f"    ✓ Saved {len(batch_indices)} baseline files")

    # Generate with different control strengths
    for control_coef in CONTROL_COEFFICIENTS:
        print(f"  Generating with control={control_coef}...")

        start_time = time.time()
        controlled_audios = controller.generate(
            batch_prompts,
            layers_to_control=layers_to_control,
            control_coef=control_coef,
            max_new_tokens=MAX_NEW_TOKENS,
            time_control_fn=time_control_fn,
            layer_weights=layer_weights,
            inject_chance=INJECT_CHANCE
        )
        controlled_elapsed = time.time() - start_time
        print(f"    ✓ Controlled generation took {controlled_elapsed:.2f} seconds")

        # Save each controlled output
        for prompt_idx, audio in zip(batch_indices, controlled_audios):
            # Build filename with config info
            filename_parts = [f"prompt{prompt_idx}", f"control{control_coef}"] + filename_suffix_parts
            output_path = f"{GENERATIONS_PATH}/{'_'.join(filename_parts)}.flac"
            sf.write(output_path, audio[0].cpu().numpy(), sampling_rate, format='FLAC')
        print(f"    ✓ Saved {len(batch_indices)} controlled files")

print("\n" + "="*60)
print("✓ Generation complete!")
print(f"  Generated {len(TEST_PROMPTS) * (len(CONTROL_COEFFICIENTS) + 1)} audio files")
print(f"  Output directory: {GENERATIONS_PATH}")
print("="*60)


    ✓ Controlled generation took 48.30 seconds
    ✓ Saved 4 controlled files

✓ Generation complete!
  Generated 8 audio files
  Output directory: ./trained_concepts/generations


## Listen to Results

In [39]:
# Index into TEST_PROMPTS of the prompt to listen to
prompt_idx = 3

print(f"Prompt: '{TEST_PROMPTS[prompt_idx]}'")
print("\nBaseline (no control):")
display(ipd.Audio(f"{GENERATIONS_PATH}/prompt{prompt_idx}_baseline.flac"))

# Configuration-dependent filename suffix; mirrors the naming used when the
# controlled files were written by the generation cell.
config_suffix = []
if LAYER_SELECTION != "all":
    config_suffix.append(LAYER_SELECTION)
if TIME_CONTROL:
    config_suffix.append(TIME_CONTROL)
if INJECT_CHANCE < 1.0:
    config_suffix.append(f"inject{INJECT_CHANCE}")

for control_coef in CONTROL_COEFFICIENTS:
    print(f"\nWith control coefficient {control_coef}:")

    filename_parts = [f"prompt{prompt_idx}", f"control{control_coef}", *config_suffix]
    audio_path = f"{GENERATIONS_PATH}/{'_'.join(filename_parts)}.flac"
    display(ipd.Audio(audio_path))


Prompt: 'Dreamy future bass with chopped synth chords'

Baseline (no control):



With control coefficient 0.4:
