
Some optimizations for memory and speed #11

Closed

ghost-agent-250 wants to merge 1 commit into HumeAI:main from ghost-agent-250:optimize-memory-and-speed

Conversation


@ghost-agent-250 ghost-agent-250 commented Mar 15, 2026

The example provided in the README is slow and consumes a lot of memory. I have removed the unnecessary loading of the encoder in the main module (tada.py).

For the best speed and memory utilization, we may want to update the example with:

  • Using torch.compile (needs warmup)
  • num_flow_matching_steps = 10
  • Prompt preprocessing and caching
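The prompt-preprocessing-and-caching idea can be sketched in isolation. This is a toy standalone version of the pattern the script below uses, with a hypothetical `ToyPromptOutput` dataclass standing in for the real `EncoderOutput`; all names here are illustrative:

```python
import os
import tempfile
from dataclasses import dataclass

import torch


@dataclass
class ToyPromptOutput:
    # Stand-in for the real EncoderOutput: a couple of tensor fields.
    tokens: torch.Tensor
    embeddings: torch.Tensor


def load_or_build_prompt(cache_path: str) -> ToyPromptOutput:
    """Reload the encoded prompt if cached; otherwise 'encode' once and cache it."""
    if os.path.exists(cache_path):
        # weights_only=False because the cache holds a plain dict produced by
        # vars(); only ever load cache files you created yourself.
        state = torch.load(cache_path, map_location="cpu", weights_only=False)
        return ToyPromptOutput(**state)
    # Pretend this is the expensive encoder forward pass.
    prompt = ToyPromptOutput(
        tokens=torch.arange(8),
        embeddings=torch.randn(8, 4),
    )
    torch.save(vars(prompt), cache_path)
    return prompt


cache = os.path.join(tempfile.mkdtemp(), "prompt_cache.pt")
first = load_or_build_prompt(cache)   # cold run: builds and writes the cache
second = load_or_build_prompt(cache)  # warm run: hits the cache
assert torch.equal(first.embeddings, second.embeddings)
```

In the real script this saves both the encoder load time and the VRAM the encoder would otherwise occupy on every warm run.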

With this script, I was able to run the 3B model on an RTX 5090 at 0.096 RTF and ~9 GB VRAM:

import os
import time
import torch
import torchaudio

from tada.modules.encoder import Encoder, EncoderOutput
from tada.modules.tada import InferenceOptions, TadaForCausalLM

device = "cuda"
PROMPT_CACHE = "prompt_cache.pt"
SAMPLE_RATE = 24000

audio_path = "samples/ljspeech.wav"
prompt_text = "The examination and testimony of the experts, enabled the commission to conclude that five shots may have been fired."

# Reuse the cached encoder output so the encoder never has to be loaded on warm runs.
if os.path.exists(PROMPT_CACHE):
    state = torch.load(PROMPT_CACHE, map_location=device, weights_only=False)
    prompt = EncoderOutput(**state)
else:
    encoder = Encoder.from_pretrained("HumeAI/tada-codec", subfolder="encoder").to(device)
    audio, sample_rate = torchaudio.load(audio_path)
    audio = audio.to(device)
    prompt = encoder(audio, text=[prompt_text], sample_rate=sample_rate)
    torch.save(vars(prompt), PROMPT_CACHE)

# Cast the prompt's floating-point tensors to bfloat16 to match the model dtype.
for field in vars(prompt):
    v = getattr(prompt, field)
    if isinstance(v, torch.Tensor) and v.is_floating_point():
        setattr(prompt, field, v.to(torch.bfloat16))

model = TadaForCausalLM.from_pretrained(
    "HumeAI/tada-3b-ml",
    torch_dtype=torch.bfloat16,
).to(device)
model._decoder.to(torch.bfloat16)
model.compile()  # compiled graphs need the warmup passes below

torch.cuda.empty_cache()

torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()

with torch.inference_mode():
    # Warmup (must share inference_mode context with timed run)
    for i in range(2):
        model.generate(
            prompt=prompt,
            text="Please call Stella. Ask her to bring these things with her from the store.",
            inference_options=InferenceOptions(
                # Use the same settings as the timed run so torch.compile
                # does not retrigger compilation during measurement.
                num_flow_matching_steps=10
            ),
        )

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    output = model.generate(
        prompt=prompt,
        text="Please call Stella. Ask her to bring these things with her from the store.",
        inference_options=InferenceOptions(
            num_flow_matching_steps=10
        ),
    )
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0

out_samples = output.audio[0].shape[-1]
out_duration_sec = out_samples / SAMPLE_RATE
rtf = elapsed / out_duration_sec
peak_gb = torch.cuda.max_memory_allocated() / 2**30
print(f"Peak GPU memory: {peak_gb:.2f} GB")
print(f"RTF: {rtf:.3f} ({elapsed:.2f}s gen / {out_duration_sec:.2f}s audio)")

torchaudio.save("output.wav", output.audio[0].cpu().float().unsqueeze(0), SAMPLE_RATE)
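As a sanity check on the metric the script reports: RTF is wall-clock generation time divided by the duration of the generated audio, so values below 1 mean faster than real time. The numbers below are illustrative, chosen to match the reported 0.096 RTF:

```python
# Real-time factor (RTF): generation time divided by audio duration.
elapsed_s = 0.96          # hypothetical wall-clock generation time
audio_duration_s = 10.0   # hypothetical duration of the generated audio
rtf = elapsed_s / audio_duration_s
print(f"RTF = {rtf:.3f}")  # prints RTF = 0.096, i.e. ~10x faster than real time
```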

@srao25 srao25 closed this Mar 16, 2026
Collaborator

srao25 commented Mar 16, 2026

Fixed these in the latest commits.
