# Text and Beatbox to Drum Generation

## Project Overview

This project explores a two-stage generative audio pipeline that combines:
1. **Text-based timbral control** (drum style / genre)
2. **Rhythmic control via beatboxing**

The goal is to synthesize realistic drum performances that:
- Preserve the rhythmic structure of a beatbox input
- Match a desired drum timbre specified using natural language

We integrate:
- **AudioLDM (Stage 1)** for text → drum timbre generation
- **TRIA (Stage 2)** for rhythm transfer and drum synthesis

## Input

In [35]:
# USER INPUT: text prompt + beatbox WAV

from google.colab import files
%cd /content

# ---- 1. Text prompt ----
TEXT_PROMPT = input("Enter drum style prompt (e.g. 'Rock'): ")

print("\nPrompt received:")
print(f"  \"{TEXT_PROMPT}\"")

# ---- 2. Beatbox WAV upload ----
print("\nPlease upload a BEATBOX WAV file:")
uploaded = files.upload()

# Take the first uploaded file (the dict is empty if nothing was uploaded)
if not uploaded:
    raise ValueError("No file was uploaded")
BEATBOX_WAV_PATH = next(iter(uploaded))

# Basic validation
if not BEATBOX_WAV_PATH.lower().endswith(".wav"):
    raise ValueError("Uploaded file must be a .wav file")

print("\nBeatbox file received:")
print(f"  {BEATBOX_WAV_PATH}")

# ---- 3. Summary ----
print("\nInputs ready")
print("Text prompt:", TEXT_PROMPT)
print("Beatbox WAV:", BEATBOX_WAV_PATH)

/content
Enter drum style prompt (e.g. 'Rock'): soft jazz

Prompt received:
  "soft jazz"

Please upload a BEATBOX WAV file:


Saving beatbox9.wav to beatbox9.wav

Beatbox file received:
  beatbox9.wav

Inputs ready
Text prompt: soft jazz
Beatbox WAV: beatbox9.wav
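
The extension check in the cell above only inspects the filename. A minimal sketch of a stricter check, assuming the beatbox recording is a plain PCM WAV, reads the header with the standard-library `wave` module. The helper name `describe_wav` and the demo file are illustrative and not part of the original notebook; `wave` rejects non-PCM encodings such as 32-bit float, so a failure here does not always mean the file is unusable.

```python
import os
import tempfile
import wave

def describe_wav(path):
    """Return (sample_rate, n_channels, duration_seconds) for a PCM WAV file."""
    with wave.open(path, "rb") as wav:
        rate = wav.getframerate()
        return rate, wav.getnchannels(), wav.getnframes() / rate

# Demo on a synthetic file: one second of 16-bit mono silence at 16 kHz.
demo_path = os.path.join(tempfile.mkdtemp(), "demo.wav")
with wave.open(demo_path, "wb") as wav:
    wav.setnchannels(1)
    wav.setsampwidth(2)
    wav.setframerate(16000)
    wav.writeframes(b"\x00\x00" * 16000)

print(describe_wav(demo_path))
```

In the notebook this would be called as `describe_wav(BEATBOX_WAV_PATH)` right after the upload, which also gives the duration needed later when sizing the generation window.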


In [36]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 1. Text to Drum Timbre Generation - AudioLDM model

We first generate a short drum audio sample that represents the **desired drum style**.
This sample will later serve as the *timbre reference* for TRIA.

Example prompts:
- "Vintage jazz drum kit"
- "Heavy metal drums with distortion"
- "80s electronic drum machine"

### 1.1 Environment Setup

System and GPU check

In [37]:
import torch
if torch.cuda.is_available():
    print(" GPU Connected: ", torch.cuda.get_device_name(0))
else:
    print(" Warning: No GPU connected. Go to Runtime > Change runtime type > T4 GPU")

 GPU Connected:  Tesla T4


Install dependencies

In [38]:
# AudioLDM
!pip install -q "diffusers==0.33.1" transformers accelerate scipy

### 1.2 Run model

In [39]:
%cd /content

/content


In [40]:
%%writefile run_model.py

import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
os.environ["USE_TF"] = "0"

import torch
from diffusers import AudioLDMPipeline
import scipy.io.wavfile
import argparse
import numpy as np

def generate_audio(prompt, duration=5.0, steps=50, output_file="output.wav"):
    print(f"\nStarting generation for prompt: '{prompt}'")

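    # Load the pretrained AudioLDM model
    # The model will be downloaded only if it is not already cached
    # float16 is used on GPU for better performance and lower memory usage;
    # on CPU we fall back to float32
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        pipe = AudioLDMPipeline.from_pretrained(
            "cvssp/audioldm-s-full-v2",
            torch_dtype=dtype  # diffusers expects `torch_dtype`; `dtype` is ignored
        )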
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Move the model to GPU if available
    if torch.cuda.is_available():
        pipe = pipe.to("cuda")
        print("Using CUDA GPU")
    else:
        print("Using CPU (generation may be slow)")

    # Generate audio from text prompt
    print("Generating audio...")
    audio = pipe(
        prompt,
        num_inference_steps=steps,
        audio_length_in_s=duration,
        guidance_scale=1.5,  # lower = often cleaner audio
        negative_prompt="melody, bass, synth, guitar, piano, vocals, singing, speech, chords, orchestra, reverb, ambience, static, hiss, noise, distortion, artifacts, low quality"
    ).audios[0]

    # --- Save the generated audio to a WAV file (robust) ---
    audio = np.asarray(audio)

    # If audio is shape (n,) it's fine; if it's (n,1) flatten it
    audio = audio.squeeze()

    # Clip to valid range
    audio = np.clip(audio, -1.0, 1.0)

    # Convert to int16 PCM (standard wav format)
    audio_int16 = (audio * 32767.0).astype(np.int16)

    scipy.io.wavfile.write(output_file, rate=16000, data=audio_int16)
    print(f"Audio saved to: {output_file}")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate audio from text using AudioLDM")
    parser.add_argument("--prompt", type=str, required=True, help="Text description for audio generation")
    parser.add_argument("--out", type=str, default="generated.wav", help="Output WAV filename")
    parser.add_argument("--time", type=float, default=5.0, help="Audio duration in seconds")
    args = parser.parse_args()

    generate_audio(
        args.prompt,
        duration=args.time,
        output_file=args.out
    )

Overwriting run_model.py
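
The save step in `run_model.py` squeezes, clips, and scales the float waveform before writing 16-bit PCM. The sketch below isolates that conversion on a toy array so the effect of clipping is visible; `float_to_int16` is an illustrative name, not a function from the script.

```python
import numpy as np

def float_to_int16(audio):
    """Convert float audio in roughly [-1, 1] to 16-bit PCM samples."""
    audio = np.asarray(audio, dtype=np.float32).squeeze()  # (n, 1) -> (n,)
    audio = np.clip(audio, -1.0, 1.0)                      # avoid integer overflow
    return (audio * 32767.0).astype(np.int16)              # truncates toward zero

# Out-of-range samples (-2.0 and 2.0) are pinned to the int16 limits.
pcm = float_to_int16([[-2.0], [0.0], [0.5], [2.0]])
print(pcm.tolist())
```

Without the clip, values outside [-1, 1] would fall outside the int16 range after scaling, which typically shows up as loud clicks in the written file.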


Generate drum audio from the text prompt, with the clip length set by `--time`

In [41]:
# Build the full prompt for AudioLDM
FULL_PROMPT = (
    f"drums solo in {TEXT_PROMPT} style"
)
print("\nFull prompt sent to AudioLDM:")
print(f'  "{FULL_PROMPT}"')

# Run AudioLDM using the generated prompt
!python run_model.py --prompt "$FULL_PROMPT" --out "drums.wav" --time 5


Full prompt sent to AudioLDM:
  "drums solo in soft jazz style"

Starting generation for prompt: 'drums solo in soft jazz style'
Keyword arguments {'dtype': torch.float16} are not expected by AudioLDMPipeline and will be ignored.
Loading pipeline components...: 100% 6/6 [00:00<00:00, 17.38it/s]
Using CUDA GPU
Generating audio...
100% 50/50 [00:02<00:00, 17.65it/s]
Audio saved to: drums.wav
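
The cell above passes the prompt through the shell as `"$FULL_PROMPT"`, which can break if the user's text contains quotes. One possible alternative, sketched here under the assumption that `run_model.py` keeps its `--prompt`, `--out`, and `--time` flags, builds an argument list that can be handed to `subprocess.run` without shell quoting. The helper name `build_audioldm_command` is hypothetical.

```python
import sys

def build_audioldm_command(style, out_path="drums.wav", seconds=5.0):
    """Build an argv list for run_model.py from a user-supplied style string."""
    style = " ".join(style.split())  # collapse stray whitespace
    if not style:
        raise ValueError("style prompt must not be empty")
    prompt = f"drums solo in {style} style"  # same template as FULL_PROMPT
    return [
        sys.executable, "run_model.py",
        "--prompt", prompt,
        "--out", out_path,
        "--time", str(seconds),
    ]

cmd = build_audioldm_command("  soft   jazz ")
print(cmd[3])
# To actually run it: subprocess.run(cmd, check=True)
```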


In [42]:
from IPython.display import Audio
Audio("drums.wav")

## 2. Rhythm Transfer with TRIA Model

TRIA (The Rhythm In Anything) allows us to:
- Extract rhythmic structure from any audio (e.g. beatboxing)
- Apply it to a reference timbre audio

Inputs:
- Drum timbre reference (from AudioLDM output)
- Beatbox rhythm audio

Output:
- Drum audio preserving the beatbox rhythm
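
To give an intuition for what "rhythmic structure" means here, the sketch below computes a frame-wise energy envelope of a toy signal with two clicks. This is only an illustration and is not TRIA's feature extractor (the notebook uses `rhythm_features` from the TRIA repository); the function name `energy_envelope` is made up, and the frame and hop sizes are borrowed from the feature configuration used later in this notebook.

```python
import numpy as np

def energy_envelope(signal, frame_length=384, hop_length=192):
    """Frame-wise RMS energy: a crude stand-in for a rhythm feature."""
    signal = np.asarray(signal, dtype=np.float64)
    n_frames = 1 + max(0, len(signal) - frame_length) // hop_length
    env = np.empty(n_frames)
    for i in range(n_frames):
        start = i * hop_length
        frame = signal[start:start + frame_length]
        env[i] = np.sqrt(np.mean(frame ** 2))
    return env

sr = 16_000
clicks = np.zeros(sr)        # one second of silence
clicks[4000:4100] = 1.0      # a "kick" at 0.25 s
clicks[12000:12100] = 1.0    # a "snare" at 0.75 s

env = energy_envelope(clicks)
active = np.flatnonzero(env > 0)   # frames that contain energy
print(len(env), active.tolist())
```

The envelope is zero everywhere except around the two hits, so the timing of events survives while the timbre of the source is discarded. That is the kind of information a rhythm-transfer model needs from the beatbox input.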

### 2.1 Setup

In [43]:
# Install dependencies
!pip install -q descript-audiotools librosa soundfile scipy pyloudnorm
!pip install -q primePy

In [44]:
%%capture
# Clone repository
!git clone https://github.com/interactiveaudiolab/tria.git
%cd tria

In [45]:
# Download model weights from Hugging Face
import os
from pathlib import Path
from huggingface_hub import hf_hub_download

# Configuration
REPO_ID = "canfious/TextDrums"  # HF repo

print(" Downloading pretrained models from Hugging Face...\n")

# Create directories
os.makedirs("pretrained/tria/small_musdb_moises_2b/80000", exist_ok=True)
os.makedirs("pretrained/tokenizer/dac", exist_ok=True)

try:
    # Download TRIA model
    print(" Downloading TRIA model (~165MB)...")
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="tria/small_musdb_moises_2b/80000/model.pt",
        local_dir="pretrained",
        local_dir_use_symlinks=False
    )
    print(f" TRIA model downloaded to: {model_path}")

    # Download tokenizer
    print("\n  Downloading tokenizer (~293MB)...")
    tokenizer_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="tokenizer/dac/dac_44.1kHz_7.7kbps.pt",
        local_dir="pretrained",
        local_dir_use_symlinks=False
    )
    print(f" Tokenizer downloaded to: {tokenizer_path}")

    print("\n" + "="*60)
    print(" All models downloaded successfully!")
    print("="*60)

except Exception as e:
    print(f"\n Error downloading models: {e}")
    print("\n Make sure you:")
    print("  1. Updated REPO_ID with your Hugging Face username")
    print("  2. Uploaded the models to your HF repository")
    print("  3. Made the repository public (or logged in with huggingface-cli)")
    raise

 Downloading pretrained models from Hugging Face...

 Downloading TRIA model (~165MB)...
 TRIA model downloaded to: pretrained/tria/small_musdb_moises_2b/80000/model.pt

  Downloading tokenizer (~293MB)...
 Tokenizer downloaded to: pretrained/tokenizer/dac/dac_44.1kHz_7.7kbps.pt

 All models downloaded successfully!
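
Before loading the model it can help to confirm that both files landed where the later cells expect them. The sketch below is one way to do that with `pathlib`; the relative paths are the ones used in the download cell, while the helper name `missing_checkpoints` and the temporary-directory demo are illustrative.

```python
import tempfile
from pathlib import Path

EXPECTED_FILES = [
    "tria/small_musdb_moises_2b/80000/model.pt",
    "tokenizer/dac/dac_44.1kHz_7.7kbps.pt",
]

def missing_checkpoints(root="pretrained", expected=EXPECTED_FILES):
    """Return the expected checkpoint paths that do not exist under root."""
    root = Path(root)
    return [rel for rel in expected if not (root / rel).is_file()]

# Demo in a scratch directory where only the TRIA checkpoint is present.
with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / EXPECTED_FILES[0]
    target.parent.mkdir(parents=True)
    target.touch()
    still_missing = missing_checkpoints(tmp)

print(still_missing)
```

In the notebook, `missing_checkpoints()` with the default root should return an empty list once both downloads have succeeded.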


In [46]:
%%capture

# Imports
import torch
from functools import partial
from audiotools import AudioSignal
from tria.model.tria import TRIA
from tria.pipelines.tokenizer import Tokenizer
from tria.features import rhythm_features
from IPython.display import Audio, display
import numpy as np

### 2.2 Load Model

In [47]:
%%time
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model configuration
model_cfg = {
    "codebook_size": 1024,
    "n_codebooks": 9,
    "n_channels": 512,
    "n_feats": 2,
    "n_heads": 8,
    "n_layers": 12,
    "mult": 4,
    "p_dropout": 0.0,
    "bias": True,
    "max_len": 1000,
    "pos_enc": "rope",
    "qk_norm": True,
    "use_sdpa": True,
    "interp": "nearest",
    "share_emb": True,
}

# Load model
print("Loading TRIA model...")
model = TRIA(**model_cfg)
state_dict = torch.load("pretrained/tria/small_musdb_moises_2b/80000/model.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()

# Load tokenizer
print("Loading tokenizer...")
tokenizer = Tokenizer(name="dac")
tokenizer = tokenizer.to(device)

# Feature extraction
feature_cfg = {
    "sample_rate": 16_000,
    "n_bands": 2,
    "n_mels": 40,
    "window_length": 384,
    "hop_length": 192,
    "quantization_levels": 5,
    "slow_ma_ms": 200,
    "post_smooth_ms": 100,
    "legacy_normalize": False,
    "clamp_max": 50.0,
    "normalize_quantile": 0.98,
}
feat_fn = partial(rhythm_features, **feature_cfg)

print(" Model loaded successfully!")

Using device: cuda
Loading TRIA model...
Loading tokenizer...
 Model loaded successfully!
CPU times: user 1.44 s, sys: 426 ms, total: 1.86 s
Wall time: 1.86 s
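
The values in `feature_cfg` fix the time resolution of the rhythm features. As a quick worked example using only the numbers from that configuration, the hop and window sizes translate to milliseconds as follows (how `rhythm_features` applies the smoothing windows internally is not shown here and is not assumed):

```python
sample_rate = 16_000
window_length = 384
hop_length = 192

hop_ms = 1000 * hop_length / sample_rate        # time between feature frames
window_ms = 1000 * window_length / sample_rate  # span of each analysis window
frames_per_second = sample_rate / hop_length

print(f"hop: {hop_ms} ms, window: {window_ms} ms, "
      f"{frames_per_second:.2f} frames/s")
```

With a 12 ms hop, the 200 ms and 100 ms smoothing settings (`slow_ma_ms`, `post_smooth_ms`) correspond to roughly 17 and 8 feature frames respectively.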


### 2.3 Inference Functions

In [48]:
@torch.no_grad()
def generate_drums(
    timbre_path, rhythm_path,
    prefix_dur=2.0, max_dur=6.0,
    seed=0,
    # Inference parameters
    top_p=0.95, top_k=None, temp=1.0,
    mask_temp=10.5, guidance_scale=2.0, causal_bias=1.0,
    iterations=[10, 10, 10, 8, 8, 6, 6, 4, 4]
):
    """
    Generate drums from timbre and rhythm audio

    Args:
        timbre_path: Path to drum audio (defines style)
        rhythm_path: Path to any audio (defines rhythm)
        prefix_dur: Duration of timbre prefix (seconds)
        max_dur: Maximum generation duration (seconds)
        seed: Random seed for variations
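        top_p, top_k, temp: Token sampling parameters (nucleus threshold, top-k cutoff, temperature)
        mask_temp, guidance_scale, causal_bias: Masking and conditioning parameters passed to model.inference
        iterations: Iteration schedule passed to model.inference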

    Returns:
        AudioSignal with generated drums
    """

    sample_rate = tokenizer.sample_rate
    n_channels = tokenizer.n_channels
    interp = model.interp

    # Load audio
    timbre = AudioSignal(timbre_path).resample(sample_rate).to(device).to_mono()
    rhythm = AudioSignal(rhythm_path).resample(sample_rate).to(device).to_mono()

    # Truncate
    timbre = timbre.truncate_samples(int(prefix_dur * sample_rate))
    rhythm = rhythm.truncate_samples(int(max_dur * sample_rate) - timbre.signal_length)

    timbre.ensure_max_of_audio()
    rhythm.ensure_max_of_audio()

    # Tokenize
    timbre_tokens = tokenizer.encode(timbre)
    rhythm_tokens = tokenizer.encode(rhythm)

    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
    n_batch, n_codebooks, n_frames = tokens.shape
    prefix_frames = timbre_tokens.tokens.shape[-1]

    # Extract features
    _feats = feat_fn(rhythm)
    _feats = torch.nn.functional.interpolate(_feats, n_frames - prefix_frames, mode=interp)
    feats = torch.zeros(n_batch, _feats.shape[1], n_frames, device=device)
    feats[..., prefix_frames:] = _feats

    # Masks
    prefix_mask = torch.arange(n_frames, device=device)[None, :] < prefix_frames
    tokens_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
    feats_mask = ~prefix_mask

    # Generate
    generated = model.inference(
        tokens, feats, tokens_mask, feats_mask,
        top_p=top_p, top_k=top_k, temp=temp,
        mask_temp=mask_temp,
        iterations=iterations,
        guidance_scale=guidance_scale,
        causal_bias=causal_bias,
        seed=[seed],
    )[..., prefix_frames:]

    # Decode
    rhythm_tokens.tokens = generated
    output = tokenizer.decode(rhythm_tokens)
    output.normalize(-20.0)
    output.ensure_max_of_audio()

    return output


def play_audio(audio_signal, title="Audio"):
    """Display audio in notebook"""
    audio_data = audio_signal.audio_data.cpu().numpy().flatten()
    sample_rate = audio_signal.sample_rate
    print(f" {title}")
    display(Audio(audio_data, rate=sample_rate))


print(" Functions defined!")

 Functions defined!
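
The mask construction inside `generate_drums` is easier to follow on a tiny example. The sketch below mirrors those three lines with NumPy instead of PyTorch, using made-up sizes (3 prefix frames out of 8). As the code reads, `tokens_mask` is true on the timbre prefix and `feats_mask` is true on the frames that follow it.

```python
import numpy as np

n_codebooks, prefix_frames, n_frames = 9, 3, 8  # toy sizes for illustration

# True for frames that belong to the timbre prefix: shape (1, n_frames)
prefix_mask = np.arange(n_frames)[None, :] < prefix_frames

# Same mask repeated for every codebook: shape (1, n_codebooks, n_frames)
tokens_mask = np.repeat(prefix_mask[:, None, :], n_codebooks, axis=1)

# Complement: true where the rhythm features were written
feats_mask = ~prefix_mask

print(prefix_mask[0].astype(int).tolist())
print(feats_mask[0].astype(int).tolist())
print(tokens_mask.shape)
```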


## 3. Generate Drum Audio

In [53]:
%%time

timbre_audio = AudioSignal("/content/drums.wav")
rhythm_audio = AudioSignal(f"/content/{BEATBOX_WAV_PATH}")

rythm_duration = rhythm_audio.duration

print("Generating drums...\n")

output = generate_drums(
    timbre_path="/content/drums.wav",
    rhythm_path=f"/content/{BEATBOX_WAV_PATH}",
    seed=42,
    max_dur=rythm_duration+2
)

print("\nGeneration complete!\n")


# Play inputs
play_audio(timbre_audio, "Input: Timbre (AudioLDM drums.wav)")
play_audio(rhythm_audio, f"Input: Rhythm ({BEATBOX_WAV_PATH})")
play_audio(output, "Output: Generated Drums")

Generating drums...


Generation complete!

 Input: Timbre (AudioLDM drums.wav)


 Input: Rhythm (beatbox12.wav)


 Output: Generated Drums


CPU times: user 5.77 s, sys: 7.15 ms, total: 5.78 s
Wall time: 5.83 s
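
The call above passes `max_dur` as the beatbox duration plus 2 seconds because `generate_drums` spends the first `prefix_dur` seconds (2.0 by default) on the timbre prefix and gives the rhythm input only what is left. The sketch below reproduces that budgeting in plain Python. The helper name `sample_budget` is illustrative, the 44.1 kHz rate is an assumption based on the tokenizer filename, and the guard against a negative budget is an addition that the original function does not have.

```python
def sample_budget(timbre_len, rhythm_len, sample_rate,
                  prefix_dur=2.0, max_dur=6.0):
    """Return (timbre_samples_kept, rhythm_samples_kept) after truncation."""
    timbre_kept = min(timbre_len, int(prefix_dur * sample_rate))
    remaining = max(0, int(max_dur * sample_rate) - timbre_kept)
    rhythm_kept = min(rhythm_len, remaining)
    return timbre_kept, rhythm_kept

sr = 44_100                      # assumed tokenizer sample rate
timbre_len = 5 * sr              # 5 s AudioLDM clip
rhythm_len = 8 * sr              # 8 s beatbox recording

default_budget = sample_budget(timbre_len, rhythm_len, sr)
full_budget = sample_budget(timbre_len, rhythm_len, sr, max_dur=8 + 2)

print(default_budget[1] / sr)    # seconds of beatbox kept with max_dur=6.0
print(full_budget[1] / sr)       # seconds of beatbox kept with max_dur=duration+2
```

With the default `max_dur=6.0` an 8-second beatbox would be cut to 4 seconds, whereas adding the prefix length to the duration keeps the whole recording.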


### 3.1 Generate Multiple Variations

In [54]:
%%time
# Generate 3 variations with different seeds
print("Generating 3 variations...\n")

for i, seed in enumerate([0, 42, 123]):
    print(f"\n{'='*50}")
    print(f"Variation {i+1} (seed={seed})")
    print('='*50)

    output = generate_drums(
        timbre_path="/content/drums.wav",
        rhythm_path=f"/content/{BEATBOX_WAV_PATH}",
        seed=seed,
        max_dur=rythm_duration+2
    )

    play_audio(output, f" Variation {i+1} (seed={seed})")

    # Save
    filename = f"variation_{i+1}_seed{seed}.wav"
    output.cpu().write(filename)
    print(f"Saved: {filename}")

Generating 3 variations...


Variation 1 (seed=0)
  Variation 1 (seed=0)


Saved: variation_1_seed0.wav

Variation 2 (seed=42)
  Variation 2 (seed=42)


Saved: variation_2_seed42.wav

Variation 3 (seed=123)
  Variation 3 (seed=123)


Saved: variation_3_seed123.wav
CPU times: user 17.2 s, sys: 22 ms, total: 17.3 s
Wall time: 17.4 s
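
Different seeds give different variations because the output tokens are sampled rather than chosen deterministically. The sketch below shows the general idea behind the `temp`, `top_p`, and `seed` arguments on a toy distribution, using standard temperature scaling and nucleus (top-p) filtering. It is a generic illustration and is not taken from TRIA's implementation, whose sampler may differ; all function names are hypothetical.

```python
import math
import random

def softmax_with_temperature(logits, temp=1.0):
    """Higher temp flattens the distribution; lower temp sharpens it."""
    scaled = [x / temp for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_p_filter(probs, top_p=0.95):
    """Keep the most likely tokens until their mass reaches top_p, then renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    return {i: probs[i] / mass for i in kept}

def sample_token(logits, temp=1.0, top_p=0.95, seed=0):
    """Draw one token id; the same seed always gives the same draw."""
    rng = random.Random(seed)
    candidates = top_p_filter(softmax_with_temperature(logits, temp), top_p)
    ids = list(candidates)
    return rng.choices(ids, weights=[candidates[i] for i in ids])[0]

toy_logits = [2.0, 1.0, 0.5, -1.0]
draws = [sample_token(toy_logits, seed=s) for s in (0, 42, 123)]
print(draws)
```

Fixing the seed makes a variation reproducible, while raising `temp` or `top_p` widens the set of tokens that can be drawn.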


### 3.2 Custom Parameters

In [51]:
%%time
# Generate with custom parameters; changing them gives different results.
output = generate_drums(
    timbre_path="/content/drums.wav",
    rhythm_path=f"/content/{BEATBOX_WAV_PATH}",

    # Duration
    prefix_dur=2.5,  # More timbre context
    max_dur=rythm_duration+2,

    # Randomness
    seed=99,
    temp=1.2,        # More random (0.5-2.0)
    top_p=0.9,       # Nucleus sampling (0.0-1.0)

    # Conditioning
    guidance_scale=3.0,  # Stronger conditioning (0.0-10.0)
    causal_bias=0.8,     # Forward preference (0.0-1.0)
    mask_temp=12.0,      # Masking strategy (0.0-50.0)

    # Quality (more iterations = better but slower)
    iterations=[12, 12, 12, 10, 10, 8, 8, 6, 6]
)


play_audio(output, " Custom Parameters")
output.cpu().write("custom_generation.wav")

  Custom Parameters


CPU times: user 3.12 s, sys: 4.08 ms, total: 3.12 s
Wall time: 3.14 s


<audiotools.core.audio_signal.AudioSignal at 0x78c6a0f3b7a0>