<a href="https://colab.research.google.com/github/Omri-Triff/Text-to-Timbre-Drum-Transfer/blob/main/models/TRIA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TRIA

**What this does:**
- Takes **timbre** from drum audio (defines the sound/style)
- Takes **rhythm** from any audio (beatbox, voice, music)
- Generates drums that combine both!

**Paper:** [The Rhythm In Anything](https://arxiv.org/abs/2509.15625)  
**Demo:** https://therhythminanything.github.io  
**Repo:** https://github.com/interactiveaudiolab/tria

## Tips for Best Results

### Timbre Prompt (Drums)
- Use clean, isolated drum recordings
- 2-6 seconds is ideal
- High quality recordings work best
- Examples: drum kit, drum machine, electronic drums

### Rhythm Prompt (Any Audio)
- Clear rhythmic content helps
- Beatbox, vocals, percussion work great
- Even music with drums can be used
- ⚠️ Very noisy audio may give poor results

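### Parameters
- **seed**: Random seed; change it to get a different variation (default: 0)
- **temp**: Sampling temperature; higher = more random (default: 1.0)
- **guidance_scale**: Higher = stronger conditioning (default: 2.0)
- **prefix_dur**: Seconds of the timbre prompt used as context (default: 2.0s)
- **max_dur**: Total length of timbre prefix plus generated drums (default: 6.0s); the returned audio lasts at most `max_dur - prefix_dur`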
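
How `prefix_dur` and `max_dur` interact follows from the truncation logic in `generate_drums` further down: the timbre prompt is cut to `prefix_dur`, the rhythm prompt to whatever remains of `max_dur`, and only the rhythm portion is returned. A small arithmetic sketch of that relationship (the helper name `output_seconds` is hypothetical and not part of TRIA; durations are in seconds):

```python
def output_seconds(prefix_dur, max_dur, timbre_len, rhythm_len):
    """Approximate length (s) of the generated drums.

    Mirrors the truncation in generate_drums: the timbre prompt is cut to
    prefix_dur, the rhythm prompt to the remaining budget, and only the
    rhythm portion is returned. Hypothetical helper for illustration only.
    """
    used_prefix = min(timbre_len, prefix_dur)   # timbre kept as context
    budget = max(max_dur - used_prefix, 0.0)    # room left for the rhythm
    return min(rhythm_len, budget)


# Defaults: 2 s of timbre context leaves 4 s of generated drums.
print(output_seconds(2.0, 6.0, timbre_len=5.0, rhythm_len=10.0))  # 4.0
# A short rhythm clip caps the output at its own length.
print(output_seconds(2.0, 6.0, timbre_len=5.0, rhythm_len=3.0))   # 3.0
```

With the defaults, a longer output therefore needs a larger `max_dur`, not only a longer rhythm file.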

### Creative Ideas
- Beatbox → Realistic drum kit
- Voice → Drum patterns
- Song → Drums in different style
- MIDI drums → Acoustic drums
- Guitar rhythm → Drum backing

## Setup (Run Once)

In [None]:
%%capture
# Clone repository
!git clone https://github.com/interactiveaudiolab/tria.git
%cd tria

In [None]:
%%capture
# Install dependencies
!pip install -q descript-audiotools librosa soundfile scipy pyloudnorm

In [None]:
# Download model weights from Hugging Face
import os
from pathlib import Path
from huggingface_hub import hf_hub_download

# Configuration
REPO_ID = "canfious/TextDrums"  # HF repo

print(" Downloading pretrained models from Hugging Face...\n")

# Create directories
os.makedirs("pretrained/tria/small_musdb_moises_2b/80000", exist_ok=True)
os.makedirs("pretrained/tokenizer/dac", exist_ok=True)

try:
    # Download TRIA model
    print(" Downloading TRIA model (~165MB)...")
    model_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="tria/small_musdb_moises_2b/80000/model.pt",
        local_dir="pretrained",
        local_dir_use_symlinks=False
    )
    print(f" TRIA model downloaded to: {model_path}")

    # Download tokenizer
    print("\n  Downloading tokenizer (~293MB)...")
    tokenizer_path = hf_hub_download(
        repo_id=REPO_ID,
        filename="tokenizer/dac/dac_44.1kHz_7.7kbps.pt",
        local_dir="pretrained",
        local_dir_use_symlinks=False
    )
    print(f" Tokenizer downloaded to: {tokenizer_path}")

    print("\n" + "="*60)
    print(" All models downloaded successfully!")
    print("="*60)

except Exception as e:
    print(f"\n Error downloading models: {e}")
    print("\n Make sure you:")
    print("  1. Updated REPO_ID with your Hugging Face username")
    print("  2. Uploaded the models to your HF repository")
    print("  3. Made the repository public (or logged in with huggingface-cli)")
    raise

 Downloading pretrained models from Hugging Face...

 Downloading TRIA model (~165MB)...



 TRIA model downloaded to: pretrained/tria/small_musdb_moises_2b/80000/model.pt

  Downloading tokenizer (~293MB)...



 Tokenizer downloaded to: pretrained/tokenizer/dac/dac_44.1kHz_7.7kbps.pt

 All models downloaded successfully!
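
If a later cell fails to find the weights, it can help to confirm that both files actually landed where the loading code expects them. A minimal sketch, assuming the two relative paths used in the download cell above; `missing_checkpoints` is a hypothetical helper, not part of TRIA or `huggingface_hub`:

```python
from pathlib import Path

# Paths taken from the download cell above.
EXPECTED_FILES = [
    "pretrained/tria/small_musdb_moises_2b/80000/model.pt",
    "pretrained/tokenizer/dac/dac_44.1kHz_7.7kbps.pt",
]


def missing_checkpoints(paths, root="."):
    """Return the entries of paths that are absent or empty under root."""
    missing = []
    for rel in paths:
        p = Path(root) / rel
        if not p.is_file() or p.stat().st_size == 0:
            missing.append(rel)
    return missing


missing = missing_checkpoints(EXPECTED_FILES)
print("All checkpoints present" if not missing else f"Missing: {missing}")
```

Run it from the `tria` directory, since the paths are relative to the repository root.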


In [None]:
# %%capture is left out here because it would also hide the prints below
!pip install -q primePy

# Imports
import torch
from functools import partial
from audiotools import AudioSignal
from tria.model.tria import TRIA
from tria.pipelines.tokenizer import Tokenizer
from tria.features import rhythm_features
from IPython.display import Audio, display
import numpy as np

print(" Imports successful!")
print(f"PyTorch: {torch.__version__}")
print(f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}")

## Load Model

In [None]:
%%time
# Device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Model configuration
model_cfg = {
    "codebook_size": 1024,
    "n_codebooks": 9,
    "n_channels": 512,
    "n_feats": 2,
    "n_heads": 8,
    "n_layers": 12,
    "mult": 4,
    "p_dropout": 0.0,
    "bias": True,
    "max_len": 1000,
    "pos_enc": "rope",
    "qk_norm": True,
    "use_sdpa": True,
    "interp": "nearest",
    "share_emb": True,
}

# Load model
print("Loading TRIA model...")
model = TRIA(**model_cfg)
state_dict = torch.load("pretrained/tria/small_musdb_moises_2b/80000/model.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()

# Load tokenizer
print("Loading tokenizer...")
tokenizer = Tokenizer(name="dac")
tokenizer = tokenizer.to(device)

# Feature extraction
feature_cfg = {
    "sample_rate": 16_000,
    "n_bands": 2,
    "n_mels": 40,
    "window_length": 384,
    "hop_length": 192,
    "quantization_levels": 5,
    "slow_ma_ms": 200,
    "post_smooth_ms": 100,
    "legacy_normalize": False,
    "clamp_max": 50.0,
    "normalize_quantile": 0.98,
}
feat_fn = partial(rhythm_features, **feature_cfg)

print(" Model loaded successfully!")

Using device: cuda
Loading TRIA model...
Loading tokenizer...




 Model loaded successfully!
CPU times: user 1.91 s, sys: 847 ms, total: 2.76 s
Wall time: 3.32 s
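
The rhythm feature settings imply a specific time resolution. The arithmetic below is a quick sketch that assumes `window_length` and `hop_length` in `feature_cfg` are measured in samples at the feature `sample_rate`; that reading is an assumption, not something confirmed by the TRIA source here.

```python
# Values copied from feature_cfg above.
sample_rate = 16_000
window_length = 384
hop_length = 192

frame_rate_hz = sample_rate / hop_length         # feature frames per second
window_ms = 1000 * window_length / sample_rate   # analysis window length
hop_ms = 1000 * hop_length / sample_rate         # step between frames

print(f"{frame_rate_hz:.2f} frames/s, {window_ms:.0f} ms window, {hop_ms:.0f} ms hop")
# 83.33 frames/s, 24 ms window, 12 ms hop
```

Under that assumption each feature frame summarizes 24 ms of audio, with a new frame every 12 ms.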


## Inference Functions

In [None]:
@torch.no_grad()
def generate_drums(
    timbre_path, rhythm_path,
    prefix_dur=2.0, max_dur=6.0,
    seed=0,
    # Inference parameters
    top_p=0.95, top_k=None, temp=1.0,
    mask_temp=10.5, guidance_scale=2.0, causal_bias=1.0,
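    iterations=None
):
    """
    Generate drums from timbre and rhythm audio.

    Args:
        timbre_path: Path to drum audio (defines style)
        rhythm_path: Path to any audio (defines rhythm)
        prefix_dur: Duration of timbre prefix (seconds)
        max_dur: Maximum duration of prefix plus generation (seconds)
        seed: Random seed for variations
        top_p, top_k, temp: Sampling parameters
        mask_temp, guidance_scale, causal_bias: Passed through to model.inference
        iterations: Iteration schedule; None uses the default below

    Returns:
        AudioSignal with generated drums (without the timbre prefix)
    """
    # Avoid a mutable default argument; this is the schedule the notebook uses.
    if iterations is None:
        iterations = [8, 8, 8, 8, 4, 4, 4, 4, 4]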

    sample_rate = tokenizer.sample_rate
    n_channels = tokenizer.n_channels
    interp = model.interp

    # Load audio
    timbre = AudioSignal(timbre_path).resample(sample_rate).to(device).to_mono()
    rhythm = AudioSignal(rhythm_path).resample(sample_rate).to(device).to_mono()

    # Truncate
    timbre = timbre.truncate_samples(int(prefix_dur * sample_rate))
    rhythm = rhythm.truncate_samples(int(max_dur * sample_rate) - timbre.signal_length)

    timbre.ensure_max_of_audio()
    rhythm.ensure_max_of_audio()

    # Tokenize
    timbre_tokens = tokenizer.encode(timbre)
    rhythm_tokens = tokenizer.encode(rhythm)

    tokens = torch.cat([timbre_tokens.tokens, rhythm_tokens.tokens], dim=-1)
    n_batch, n_codebooks, n_frames = tokens.shape
    prefix_frames = timbre_tokens.tokens.shape[-1]

    # Extract features
    _feats = feat_fn(rhythm)
    _feats = torch.nn.functional.interpolate(_feats, n_frames - prefix_frames, mode=interp)
    feats = torch.zeros(n_batch, _feats.shape[1], n_frames, device=device)
    feats[..., prefix_frames:] = _feats

    # Masks
    prefix_mask = torch.arange(n_frames, device=device)[None, :] < prefix_frames
    tokens_mask = prefix_mask[:, None, :].repeat(1, n_codebooks, 1)
    feats_mask = ~prefix_mask

    # Generate
    generated = model.inference(
        tokens, feats, tokens_mask, feats_mask,
        top_p=top_p, top_k=top_k, temp=temp,
        mask_temp=mask_temp,
        iterations=iterations,
        guidance_scale=guidance_scale,
        causal_bias=causal_bias,
        seed=[seed],
    )[..., prefix_frames:]

    # Decode
    rhythm_tokens.tokens = generated
    output = tokenizer.decode(rhythm_tokens)
    output.normalize(-20.0)
    output.ensure_max_of_audio()

    # Hand back a CPU signal so playback and saving do not depend on the device
    return output.to("cpu")


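def play_audio(audio_signal, title="Audio"):
    """Display audio in notebook (multi-channel input is downmixed to mono)"""
    # audio_data has shape (batch, channels, samples); flattening a stereo
    # signal would play the channels back to back, so average them instead.
    audio_data = audio_signal.audio_data[0].mean(dim=0).cpu().numpy()
    sample_rate = audio_signal.sample_rate
    print(f" {title}")
    display(Audio(audio_data, rate=sample_rate))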


print(" Functions defined!")

 Functions defined!
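
The frame bookkeeping inside `generate_drums` is easier to follow on a toy example. The sketch below uses plain Python lists instead of tensors: rhythm features are stretched to the number of non-prefix frames with nearest-neighbour resampling, zero-padded over the timbre prefix, and paired with a prefix mask and its complement. The helper names are hypothetical, and the index rule `floor(i * in_len / out_len)` is an assumption about how "nearest" interpolation behaves rather than a statement about TRIA internals.

```python
def nearest_resample(values, out_len):
    """Nearest-neighbour resampling of a 1-D sequence to out_len items."""
    in_len = len(values)
    return [values[min(i * in_len // out_len, in_len - 1)] for i in range(out_len)]


def frame_layout(prefix_frames, n_frames, rhythm_feats):
    """Build per-frame features and masks for [timbre prefix | rhythm part]."""
    n_rhythm = n_frames - prefix_frames
    # Features are zero over the prefix and resampled rhythm features after it.
    feats = [0.0] * prefix_frames + nearest_resample(rhythm_feats, n_rhythm)
    # True over the timbre prefix ...
    prefix_mask = [t < prefix_frames for t in range(n_frames)]
    # ... and True over every frame after the prefix.
    feats_mask = [not m for m in prefix_mask]
    return feats, prefix_mask, feats_mask


feats, prefix_mask, feats_mask = frame_layout(
    prefix_frames=2, n_frames=6, rhythm_feats=[0.1, 0.9]
)
print(feats)        # [0.0, 0.0, 0.1, 0.1, 0.9, 0.9]
print(prefix_mask)  # [True, True, False, False, False, False]
```

In the real function the same layout is repeated across all codebooks for the token mask.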


## Try With Example Files

In [None]:
%%time
# Generate with example files
print("Generating drums from example files...\n")

output = generate_drums(
    timbre_path="assets/drums/drums_1.wav",
    rhythm_path="assets/beatbox/beatbox_1.wav",
    seed=42
)

print("\n Generation complete!\n")

# Play inputs
timbre_audio = AudioSignal("assets/drums/drums_1.wav")
rhythm_audio = AudioSignal("assets/beatbox/beatbox_1.wav")

play_audio(timbre_audio, " Input: Timbre (Drums)")
play_audio(rhythm_audio, " Input: Rhythm (Beatbox)")
play_audio(output, " Output: Generated Drums")

Generating drums from example files...


 Generation complete!

  Input: Timbre (Drums)


  Input: Rhythm (Beatbox)


  Output: Generated Drums


CPU times: user 7.5 s, sys: 852 ms, total: 8.35 s
Wall time: 15.3 s


## Upload Your Own Audio

In [None]:
from google.colab import files

print("Upload timbre audio (drums for style):")
timbre_files = files.upload()
timbre_file = list(timbre_files.keys())[0]

print("\nUpload rhythm audio (any audio for rhythm):")
rhythm_files = files.upload()
rhythm_file = list(rhythm_files.keys())[0]

print(f"\n Uploaded:")
print(f"  Timbre: {timbre_file}")
print(f"  Rhythm: {rhythm_file}")

In [None]:
%%time
# Generate with your files
print("Generating drums from your audio...\n")

output = generate_drums(
    timbre_path=timbre_file,
    rhythm_path=rhythm_file,
    seed=0,
    prefix_dur=2.0,
    max_dur=6.0
)

print("\n Generation complete!\n")

# Play
timbre_audio = AudioSignal(timbre_file)
rhythm_audio = AudioSignal(rhythm_file)

play_audio(timbre_audio, " Your Timbre (Drums)")
play_audio(rhythm_audio, " Your Rhythm")
play_audio(output, " Generated Drums")

# Save and download
output.write("generated_drums.wav")
print("\n Downloading generated audio...")
files.download("generated_drums.wav")

## Generate Multiple Variations

In [None]:
%%time
# Generate 3 variations with different seeds
print("Generating 3 variations...\n")

for i, seed in enumerate([0, 42, 123]):
    print(f"\n{'='*50}")
    print(f"Variation {i+1} (seed={seed})")
    print('='*50)

    output = generate_drums(
        timbre_path="assets/drums/drums_1.wav",
        rhythm_path="assets/beatbox/beatbox_1.wav",
        seed=seed
    )

    play_audio(output, f" Variation {i+1} (seed={seed})")

    # Save
    filename = f"variation_{i+1}_seed{seed}.wav"
    output.write(filename)
    print(f"Saved: {filename}")

## Advanced: Custom Parameters

In [None]:
%%time
from google.colab import files  # needed if the upload cell was skipped

# Generate with custom parameters
output = generate_drums(
    timbre_path="assets/drums/drums_1.wav",
    rhythm_path="assets/beatbox/beatbox_1.wav",

    # Duration
    prefix_dur=2.5,  # More timbre context
    max_dur=8.0,     # Longer generation

    # Randomness
    seed=99,
    temp=1.2,        # More random (0.5-2.0)
    top_p=0.9,       # Nucleus sampling (0.0-1.0)

    # Conditioning
    guidance_scale=3.0,  # Stronger conditioning (0.0-10.0)
    causal_bias=0.8,     # Forward preference (0.0-1.0)
    mask_temp=12.0,      # Masking strategy (0.0-50.0)

    # Quality (more iterations = better but slower)
    iterations=[10, 10, 10, 8, 8, 6, 6, 4, 4]
)

play_audio(output, " Custom Parameters")
output.write("custom_generation.wav")
files.download("custom_generation.wav")
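
To build intuition for `guidance_scale`, `temp`, and `top_p`, here is a toy sketch of how such knobs are commonly applied to a single token distribution: classifier-free guidance on the logits, then temperature scaling, then nucleus (top-p) filtering. This is the textbook formulation written in plain Python; it is an assumption that TRIA's `model.inference` behaves similarly, and the function names are hypothetical.

```python
import math


def guided_logits(cond, uncond, guidance_scale):
    """Classifier-free guidance: move logits away from the unconditional ones."""
    return [u + guidance_scale * (c - u) for c, u in zip(cond, uncond)]


def softmax(logits, temp=1.0):
    """Temperature-scaled softmax; higher temp flattens the distribution."""
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_p_filter(probs, top_p):
    """Keep the smallest set of tokens whose mass reaches top_p, renormalized."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    out = [0.0] * len(probs)
    for i in kept:
        out[i] = probs[i] / mass
    return out


cond, uncond = [2.0, 1.0, 0.0], [1.0, 1.0, 1.0]
logits = guided_logits(cond, uncond, guidance_scale=2.0)
print(logits)  # [3.0, 1.0, -1.0]
probs = top_p_filter(softmax(logits, temp=1.0), top_p=0.9)
print([round(p, 3) for p in probs])  # [0.881, 0.119, 0.0]
```

In this toy setting a larger `guidance_scale` widens the gap between favoured and disfavoured tokens, a larger `temp` narrows it again, and a smaller `top_p` discards more of the low-probability tail.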

## Batch Processing

In [None]:
import zipfile
from google.colab import files  # needed if the upload cell was skipped

# Upload multiple rhythm files
print("Upload multiple audio files to process:")
uploaded = files.upload()

# Process each file
outputs = []
for filename in uploaded.keys():
    print(f"\nProcessing: {filename}")

    output = generate_drums(
        timbre_path="assets/drums/drums_1.wav",
        rhythm_path=filename,
        seed=0
    )

    # Always write WAV, whatever extension the uploaded file had
    out_name = f"drums_{filename.rsplit('.', 1)[0]}.wav"
    output.write(out_name)
    outputs.append(out_name)
    print(f" Saved: {out_name}")

# Create zip
print("\nCreating zip file...")
with zipfile.ZipFile("all_generations.zip", "w") as zipf:
    for out_file in outputs:
        zipf.write(out_file)

print("\n Downloading all files...")
files.download("all_generations.zip")
print(" Done!")
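
When a batch contains files that share a base name (for example `loop.wav` and `loop.mp3`), their outputs could overwrite each other. One possible way to derive distinct `.wav` names is sketched below; `unique_output_names` is a hypothetical helper and the `_1`, `_2` suffix scheme is just one choice.

```python
from pathlib import Path


def unique_output_names(uploaded_names, prefix="drums_"):
    """Map uploaded filenames to distinct .wav output names."""
    seen = {}   # base name -> how many times it has appeared so far
    names = []
    for name in uploaded_names:
        stem = Path(name).stem  # drop directory and extension
        count = seen.get(stem, 0)
        seen[stem] = count + 1
        suffix = "" if count == 0 else f"_{count}"
        names.append(f"{prefix}{stem}{suffix}.wav")
    return names


print(unique_output_names(["loop.wav", "loop.mp3", "beat.flac"]))
# ['drums_loop.wav', 'drums_loop_1.wav', 'drums_beat.wav']
```

The returned list lines up index by index with the uploaded names, so it can be zipped with them inside the processing loop.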