# SpotDiffusion AudioLDM Inference

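This notebook demonstrates how to use SpotDiffusion with AudioLDM for efficient long-form audio generation. SpotDiffusion denoises the latent in non-overlapping time tiles and shifts the tile grid by a random offset at each diffusion step, so the seams between tiles do not stay in one place. Because the tiles never overlap, it avoids the extra denoiser passes that MultiDiffusion spends on overlapping windows.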

## Key Features:
- **Random Shifting**: Applies a new random time shift at each diffusion step, so tile boundaries fall in different places from step to step instead of building up at fixed positions
- **Non-overlapping Tiles**: Denoises non-overlapping latent tiles, so each latent frame is processed once per step
- **Fallback to Regular DDIM**: Uses standard DDIM sampling when the latent fits in a single tile (latent length ≤ `tile_size_time`)
- **Memory Efficient**: Sliding window VAE decoding for long sequences
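The first two features combine into a simple per-step procedure. Below is a minimal NumPy sketch of one SpotDiffusion-style update on a `[batch, channels, time, freq]` latent. It shifts the time axis by a random offset, runs a denoiser on each non-overlapping time tile (always over the full frequency range), then undoes the shift. Some of this is assumed rather than taken from the notebook:

- `spot_step` and `denoise_tile` are hypothetical names.
- The per-step seeding is an assumption.
- The shift is a plain circular `np.roll`, whereas the notebook's `utils.spotutil` reports using reflection padding.

It illustrates the idea, not that module's implementation.

```python
import numpy as np

def spot_step(latent, tile_size, max_shift, step, denoise_tile, seed=0):
    """Hypothetical SpotDiffusion-style update on a [B, C, T, F] latent.

    Assumption: circular shift, seeded per step; not utils.spotutil's code.
    """
    rng = np.random.default_rng(seed + step)          # a different offset at each step
    shift = int(rng.integers(0, max_shift + 1))       # 0 <= shift <= max_shift
    shifted = np.roll(latent, shift, axis=2)          # move the tile grid along time
    out = np.empty_like(shifted)
    n_t = shifted.shape[2]
    for t0 in range(0, n_t, tile_size):               # non-overlapping time tiles
        t1 = min(t0 + tile_size, n_t)
        # each tile keeps the full frequency axis
        out[:, :, t0:t1, :] = denoise_tile(shifted[:, :, t0:t1, :], step)
    return np.roll(out, -shift, axis=2), shift        # undo the shift

# Usage with an identity "denoiser" on a 30 s-sized latent [1, 8, 768, 16]
latent = np.random.default_rng(1).standard_normal((1, 8, 768, 16))
for step in range(3):
    new_latent, shift = spot_step(latent, 256, 256, step, lambda tile, s: tile)
    print(step, shift, np.array_equal(new_latent, latent))
```

With an identity denoiser the round trip returns the input unchanged. This is a cheap sanity check for any shift/unshift pair.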

In [15]:
# Import required libraries
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import soundfile as sf
import time
import os

# Add AudioLDM to path
sys.path.append('./AudioLDM')
sys.path.append('./AudioLDM/audioldm')

# Import AudioLDM
from audioldm import build_model
from audioldm.latent_diffusion.ddim import DDIMSampler
from audioldm.pipeline import duration_to_latent_t_size

# Reload utils.spotutil if it was already imported, so edits to the file (e.g. the reflection-padding change) take effect without restarting the kernel
import importlib
if 'utils.spotutil' in sys.modules:
    importlib.reload(sys.modules['utils.spotutil'])

# Import SpotDiffusion utilities
from utils.spotutil import *

print("✅ All imports successful!")
print("🔧 FIXED: Global noise coherence for tile boundaries")
print("🔧 FIXED: Removed unnecessary frequency tiling - now only tiles in time dimension")
print("🔧 FIXED: Reflection padding instead of circular shifting for natural audio boundaries")

✅ All imports successful!
🔧 FIXED: Global noise coherence for tile boundaries
🔧 FIXED: Removed unnecessary frequency tiling - now only tiles in time dimension
🔧 FIXED: Reflection padding instead of circular shifting for natural audio boundaries


In [7]:
# Load AudioLDM model
print("📦 Loading AudioLDM model...")
model = build_model(model_name="audioldm-m-full")
sr = 16000
print(f"✅ Model loaded successfully! Sample rate: {sr} Hz")

📦 Loading AudioLDM model...
Load AudioLDM: %s audioldm-m-full
DiffusionWrapper has 415.95 M params.
DiffusionWrapper has 415.95 M params.


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


✅ Model loaded successfully! Sample rate: 16000 Hz


In [8]:
# Check actual model latent dimensions
print(f"Model latent_f_size: {model.latent_f_size}")
print(f"Model channels: {model.channels}")
print(f"Model latent_t_size: {model.latent_t_size}")

# Check what a typical latent tensor looks like
duration = 10.0
latent_size = duration_to_latent_t_size(duration)
shape = [1, model.channels, latent_size, model.latent_f_size]
print(f"Typical latent shape for {duration}s audio: {shape}")
print(f"That's: [batch={shape[0]}, channels={shape[1]}, time_frames={shape[2]}, freq_bins={shape[3]}]")

print("\n🎯 CRITICAL INSIGHT:")
print(f"AudioLDM's frequency dimension is only {model.latent_f_size} bins.")
print("Before fix: We were unnecessarily tiling this tiny frequency dimension.")
print("After fix: We only tile in time, always process full frequency spectrum.")
print("This should dramatically improve audio coherence!")

Model latent_f_size: 16
Model channels: 8
Model latent_t_size: 256
Typical latent shape for 10.0s audio: [1, 8, 256, 16]
That's: [batch=1, channels=8, time_frames=256, freq_bins=16]

🎯 CRITICAL INSIGHT:
AudioLDM's frequency dimension is only 16 bins.
Before fix: We were unnecessarily tiling this tiny frequency dimension.
After fix: We only tile in time, always process full frequency spectrum.
This should dramatically improve audio coherence!
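The shapes above, and the audio lengths reported later, follow from a little arithmetic. The sketch below rests on a few assumptions:

- `duration_to_latent_t_size` computes `int(duration * 25.6)`. This is consistent with 10 s → 256 here and 30 s → 768 in the generation cell.
- The VAE upsamples time by 4 (768 → 3072 in the VAE decode log).
- The mel hop is 160 samples at 16 kHz.

The fallback rule `latent_t > tile_size_time` is taken from the generation log. `plan` is a hypothetical helper, not part of `utils.spotutil`.

```python
import math

# Assumptions (not read from AudioLDM's source in this notebook):
LATENT_FPS = 25.6      # duration_to_latent_t_size(d) == int(d * 25.6): 10 s -> 256, 30 s -> 768
VAE_TIME_UPSAMPLE = 4  # latent frame -> mel frames (768 -> 3072 in the VAE decode log)
HOP = 160              # mel hop length in samples
SR = 16000             # sample rate used in this notebook

def plan(duration, tile_size_time=256):
    """Hypothetical helper: latent length, sampler choice, tile count, decoded seconds."""
    latent_t = int(duration * LATENT_FPS)
    use_spot = latent_t > tile_size_time             # otherwise regular DDIM
    n_tiles = math.ceil(latent_t / tile_size_time) if use_spot else 1
    seconds = latent_t * VAE_TIME_UPSAMPLE * HOP / SR
    return latent_t, use_spot, n_tiles, seconds

for d in (10.0, 15.0, 20.0, 30.0):
    print(d, plan(d))
```

Under these assumptions each latent frame stands for 4 × 160 = 640 samples (40 ms). A 30 s request therefore becomes 768 × 640 = 491,520 samples, or 30.72 s. That matches the length reported in the single-prompt cell, where the vocoder returned 491,552 samples (32 more). It also shows that a 10 s request never reaches SpotDiffusion with 256-frame tiles.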


In [16]:
# Demonstrate the reflection padding fix
print("🧪 Demonstrating reflection padding vs circular shifting:")

# Create a simple test signal
test_signal = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8.]).unsqueeze(0).unsqueeze(0).unsqueeze(-1)
print(f"Original signal: {test_signal.squeeze().tolist()}")

# Circular shifting (old method)
shifted_circular = torch.roll(test_signal, shifts=3, dims=2)
print(f"Circular shifted: {shifted_circular.squeeze().tolist()}")
print("  ❌ Note: [6,7,8] wraps to beginning - unnatural for audio!")

# Reflection padding (new method)
padded = torch.nn.functional.pad(test_signal, (0, 0, 3, 3), mode='reflect')
print(f"Reflection padded: {padded.squeeze().tolist()}")
shifted_reflected = torch.roll(padded, shifts=3, dims=2)
print(f"Shifted + padded: {shifted_reflected.squeeze().tolist()}")
cropped = shifted_reflected[:, :, 3:-3, :]
print(f"Final cropped: {cropped.squeeze().tolist()}")
print("  ✅ Note: Wrap-around happens in natural mirrored region!")

print("\n🎯 This should eliminate the unnatural audio boundaries that were degrading quality!")

🧪 Demonstrating reflection padding vs circular shifting:
Original signal: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
Circular shifted: [6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0]
  ❌ Note: [6,7,8] wraps to beginning - unnatural for audio!
Reflection padded: [4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 7.0, 6.0, 5.0]
Shifted + padded: [7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
Final cropped: [4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0]
  ✅ Note: Wrap-around happens in natural mirrored region!

🎯 This should eliminate the unnatural audio boundaries that were degrading quality!
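The cropped result `[4, 3, 2, 1, 2, 3, 4, 5]` is not a rearrangement of the original signal:

- Frames 6 to 8 were cropped away.
- Frames 2 to 4 appear mirrored, so the boundary is continuous in value but plays backwards for a few frames.

For the shift to be undoable, the inverse roll has to be applied to the padded buffer before cropping. The sketch below is NumPy and 1-D for readability. `pad_and_shift` and `unshift_and_crop` are hypothetical helpers, not necessarily how `utils.spotutil` does it. It shows a round trip that returns the original signal.

```python
import numpy as np

def pad_and_shift(x, shift, pad):
    """Reflect-pad a 1-D signal by `pad` on both ends, then roll it by `shift`.

    With shift <= pad, the wrapped samples come from the mirrored padding.
    """
    padded = np.pad(x, pad, mode="reflect")
    return np.roll(padded, shift)

def unshift_and_crop(buf, shift, pad):
    """Undo the roll on the padded buffer first, then drop the padding."""
    return np.roll(buf, -shift)[pad:-pad]

x = np.arange(1.0, 9.0)                  # [1, 2, ..., 8]
buf = pad_and_shift(x, shift=3, pad=3)   # what the denoiser would see
print(buf[3:-3])                         # naive crop loses 6, 7, 8
print(unshift_and_crop(buf, 3, 3))       # round trip recovers 1..8
```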


## SpotDiffusion Configuration

Configure the SpotDiffusion parameters:
- `tile_size_time`: Tile length along the time axis, in latent frames (256 frames correspond to a 10 s `duration`)
- `min_overlap`: Minimum overlap between tiles (not used by the non-overlapping sampler, and not set below)
- `max_shift_ratio`: Maximum random shift, as a fraction of `tile_size_time`

**Important**: SpotDiffusion only tiles in the time dimension. The full frequency spectrum is always processed together for coherent audio generation.
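A tiling that only splits time takes a few lines. `time_tiles` is a hypothetical stand-in for `create_tiles` that returns `(t_start, t_end, f_start, f_end)` tuples like the ones printed in the advanced-usage cell. How the real function handles a latent length that is not a multiple of the tile size is an assumption here: the last tile is simply shorter.

```python
def time_tiles(latent_t, latent_f, tile_size_time):
    """Hypothetical: non-overlapping time tiles, each covering the full frequency axis."""
    tiles = []
    for t_start in range(0, latent_t, tile_size_time):
        t_end = min(t_start + tile_size_time, latent_t)   # last tile may be shorter
        tiles.append((t_start, t_end, 0, latent_f))       # frequency is never split
    return tiles

print(time_tiles(768, 16, 256))  # [(0, 256, 0, 16), (256, 512, 0, 16), (512, 768, 0, 16)]
print(time_tiles(384, 16, 256))  # [(0, 256, 0, 16), (256, 384, 0, 16)]
```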

In [11]:
# Configure SpotDiffusion parameters
config = SpotDiffusionConfig(
    tile_size_time=256,     # Tile length in latent frames (256 frames = 10 s requested duration)
    max_shift_ratio=1.00    # Shift by up to 100% of a tile (256 frames)
)

print(f"SpotDiffusion Config:")
print(f"  Tile size (time): {config.tile_size_time}")
print(f"  Max shift ratio: {config.max_shift_ratio}")
print(f"✅ FIXED: Only tiling in time dimension, full frequency spectrum per tile")

SpotDiffusion Config:
  Tile size (time): 256
  Max shift ratio: 1.0
✅ FIXED: Only tiling in time dimension, full frequency spectrum per tile
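`max_shift_ratio` bounds the offset in latent frames. The sketch makes two assumptions: the offset is drawn uniformly from `0` to `int(max_shift_ratio * tile_size_time)`, and it is applied as a circular roll. The actual distribution used by `generate_random_shift` is not shown in this notebook.

`seam_frames` (hypothetical) maps tile boundaries back to positions in the unshifted latent, which shows that the boundaries move between steps. For a latent whose length is a multiple of the tile size, a shift of exactly one tile puts the seams where a shift of 0 does. Under this assumption, a ratio of 1.0 already covers every distinct seam placement.

```python
import numpy as np

def max_shift_frames(tile_size_time, max_shift_ratio):
    """Largest offset in latent frames (assumed: ratio times tile size, truncated)."""
    return int(max_shift_ratio * tile_size_time)

def seam_frames(latent_t, tile_size_time, shift):
    """Tile starts after np.roll(x, shift), mapped back to unshifted frame indices."""
    # np.roll(x, shift)[i] == x[(i - shift) % latent_t]
    return sorted((t - shift) % latent_t for t in range(0, latent_t, tile_size_time))

rng = np.random.default_rng(0)
limit = max_shift_frames(256, 1.0)                # 256 frames
for step in range(3):
    shift = int(rng.integers(0, limit + 1))       # hypothetical sampling rule
    print(step, shift, seam_frames(768, 256, shift))
```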


## Single Prompt Generation

Generate audio from a single text prompt using SpotDiffusion.

In [17]:
# Test prompt and parameters
prompt = "90s rock song with electric guitar and heavy drums"
duration = 30.0  # seconds
ddim_steps = 200
cfg_scale = 3.0

print(f"🎯 Generating: '{prompt}'")
print(f"⏱️  Duration: {duration}s")
print()

# Generate audio using SpotDiffusion
waveform = spotdiffusion_text_to_audio(
    model=model,
    prompt=prompt,
    duration=duration,
    ddim_steps=ddim_steps,
    unconditional_guidance_scale=cfg_scale,
    config=config,
)

print(f"\n✅ Generation complete!")
print(f"📊 Audio shape: {waveform.shape}")
print(f"📊 Audio length: {len(waveform) / sr:.2f}s")

🎯 Generating: '90s rock song with electric guitar and heavy drums'
⏱️  Duration: 30.0s

🎯 Generating '90s rock song with electric guitar and heavy drums' (30.0s)
📐 Latent shape: [1, 8, 768, 16]
🔄 Using SpotDiffusion (latent_size=768 > tile_size=256)
SpotDiffusion: Processing 200 steps
Tile config: time=256 (freq=full spectrum)
Latent shape: [1, 8, 768, 16]


SpotDiffusion Steps:   0%|          | 0/200 [00:00<?, ?it/s]

SpotDiffusion Steps: 100%|██████████| 200/200 [00:42<00:00,  4.71it/s]


⏱️  Diffusion time: 42.5s
🪟 VAE decoding...
🪟 Using sliding window VAE decode: latent_length=768, window_size=256
   Window 1/4: latent frames 0-256
   Window 2/4: latent frames 192-448
   Window 3/4: latent frames 384-640
   Window 4/4: latent frames 576-768
✅ VAE decode complete: torch.Size([1, 8, 768, 16]) -> torch.Size([1, 1, 3072, 64])
⏱️  VAE time: 0.2s
🔊 Vocoder...
⏱️  Vocoder time: 0.2s
⏱️  Total time: 42.8s

✅ Generation complete!
📊 Audio shape: (491552,)
📊 Audio length: 30.72s
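The window positions in the VAE decode log (0-256, 192-448, 384-640, 576-768) are consistent with 256-frame windows advanced by 192 frames. That gives 64 frames of overlap, with the last window clipped at the end of the latent. The sketch below reproduces those positions and blends the decoded segments with complementary linear ramps.

The crossfade is an assumption, since the notebook does not show how `utils.spotutil` merges windows. `decode` is a hypothetical stand-in for the VAE decoder with 4× time upsampling (768 → 3072 in the log).

```python
import numpy as np

def window_spans(latent_t, window=256, overlap=64):
    """(start, end) latent frames of decode windows, advanced by window - overlap."""
    stride = window - overlap
    spans, start = [], 0
    while True:
        end = min(start + window, latent_t)
        spans.append((start, end))
        if end == latent_t:
            return spans
        start += stride

def blend_decode(latent_t, decode, window=256, overlap=64, up=4):
    """Decode each window and crossfade overlaps; assumes overlap <= window // 2."""
    assert 0 <= overlap <= window // 2
    acc = np.zeros(latent_t * up)
    wsum = np.zeros(latent_t * up)
    ramp = overlap * up
    fade_in = (np.arange(ramp) + 0.5) / ramp      # rises from ~0 to ~1
    for start, end in window_spans(latent_t, window, overlap):
        seg = decode(start, end)                  # length (end - start) * up
        w = np.ones(len(seg))
        if start > 0:
            w[:ramp] = fade_in                    # fade in over the left overlap
        if end < latent_t:
            w[len(w) - ramp:] = fade_in[::-1]     # fade out over the right overlap
        acc[start * up:end * up] += w * seg
        wsum[start * up:end * up] += w
    return acc / wsum                             # the two ramps sum to 1 in each overlap

print(window_spans(768))  # [(0, 256), (192, 448), (384, 640), (576, 768)]
fake_decode = lambda s, e: np.repeat(np.arange(s, e, dtype=float), 4)  # stand-in decoder
print(blend_decode(768, fake_decode).shape)  # (3072,)
```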


In [5]:
# Save and play the generated audio
output_file = "spotdiffusion_output.wav"
sf.write(output_file, waveform, sr)
print(f"💾 Saved audio to: {output_file}")

# Play audio
ipd.Audio(waveform, rate=sr)

💾 Saved audio to: spotdiffusion_output.wav


## Performance Comparison

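Time generation for different audio lengths. With `tile_size_time=256`, the 10 s case fits in a single tile and falls back to regular DDIM. No separate DDIM baseline is run for the longer durations, so a direct comparison against regular DDIM still has to be made separately.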

In [None]:
# Performance comparison
test_prompts = ["electronic music", "acoustic guitar", "nature sounds"]
test_durations = [10.0, 20.0, 30.0]  # seconds
ddim_steps = 50  # Reduced for faster testing

results = []

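# Loop variables are named bench_* so they do not overwrite `prompt`/`duration` used by the plotting cell
for bench_prompt in test_prompts[:1]:  # Test with first prompt only
    for bench_duration in test_durations:
        print(f"\n🔄 Testing: '{bench_prompt}' for {bench_duration}s")
        
        # Clips that fit in one tile fall back to regular DDIM; label them accordingly
        latent_t = duration_to_latent_t_size(bench_duration)
        method = 'SpotDiffusion' if should_use_spotdiffusion(latent_t, config) else 'DDIM (fallback)'
        
        start_time = time.time()
        waveform_spot = spotdiffusion_text_to_audio(
            model=model,
            prompt=bench_prompt,
            duration=bench_duration,
            ddim_steps=ddim_steps,
            config=config
        )
        spot_time = time.time() - start_time
        
        results.append({
            'prompt': bench_prompt,
            'duration': bench_duration,
            'method': method,
            'time': spot_time,
            'audio_length': len(waveform_spot) / sr
        })
        
        print(f"  {method}: {spot_time:.1f}s")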

# Print results table
print("\n" + "="*60)
print("PERFORMANCE RESULTS")
print("="*60)
print(f"{'Prompt':<20} {'Duration':<10} {'Method':<15} {'Time':<10} {'Audio Len'}")
print("-"*60)

for result in results:
    print(f"{result['prompt']:<20} {result['duration']:<10.1f} {result['method']:<15} {result['time']:<10.1f} {result['audio_length']:<10.2f}")
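The speed advantage can be estimated by counting denoiser passes per DDIM step. The sketch compares non-overlapping tiles with an overlapping sliding window in the style of MultiDiffusion. The 64-frame MultiDiffusion stride is a hypothetical choice for illustration. Classifier-free guidance is left out because it adds an unconditional pass to both methods alike. `spot_calls` and `multidiffusion_calls` are hypothetical helpers.

```python
import math

def spot_calls(latent_t, tile):
    """Denoiser passes per step with non-overlapping tiles (1 = DDIM fallback)."""
    return 1 if latent_t <= tile else math.ceil(latent_t / tile)

def multidiffusion_calls(latent_t, window, stride):
    """Passes per step for overlapping windows advanced by `stride` frames."""
    if latent_t <= window:
        return 1
    return math.ceil((latent_t - window) / stride) + 1

for latent_t in (256, 512, 768):
    s = spot_calls(latent_t, 256)
    m = multidiffusion_calls(latent_t, 256, 64)   # hypothetical stride of 64 frames
    print(f"{latent_t:>4} frames: SpotDiffusion {s} passes/step, MultiDiffusion (stride 64) {m}")
```

At 768 latent frames (30 s) this gives 3 passes per step versus 9. If each pass costs about the same, the tiled sampler does roughly a third of the denoising work per step under these assumptions.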

## Visualization and Analysis

Visualize the generated audio and analyze the results.

In [None]:
# Plot waveform
plt.figure(figsize=(15, 6))

plt.subplot(2, 1, 1)
time_axis = np.linspace(0, len(waveform) / sr, len(waveform))
plt.plot(time_axis, waveform)
plt.title(f"SpotDiffusion Generated Audio: '{prompt}'")
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.grid(True, alpha=0.3)

# Plot spectrogram
plt.subplot(2, 1, 2)
from scipy import signal
f, t, Sxx = signal.spectrogram(waveform, sr, nperseg=1024)
plt.pcolormesh(t, f, 10 * np.log10(Sxx + 1e-12), shading='gouraud', cmap='viridis')  # epsilon avoids log10(0) in silent bins
plt.title("Spectrogram")
plt.xlabel("Time (s)")
plt.ylabel("Frequency (Hz)")
plt.colorbar(label='Power (dB)')

plt.tight_layout()
plt.show()

## Configuration Experiments

Test different SpotDiffusion configurations to understand their impact.

In [None]:
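# Test different tile sizes
# Note: 20 s -> 512 latent frames, so tile_size_time=512 fits in one tile and falls back to regular DDIM
tile_sizes = [128, 256, 512]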
test_prompt = "ambient electronic music"
test_duration = 20.0

print("🧪 Testing different tile sizes:")
print()

for tile_size in tile_sizes:
    config_test = SpotDiffusionConfig(
        tile_size_time=tile_size,
        max_shift_ratio=0.5
    )
    
    print(f"Testing tile_size_time = {tile_size}")
    
    start_time = time.time()
    waveform_test = spotdiffusion_text_to_audio(
        model=model,
        prompt=test_prompt,
        duration=test_duration,
        ddim_steps=50,
        config=config_test
    )
    test_time = time.time() - start_time
    
    # Save result
    output_file = f"spot_tile_{tile_size}.wav"
    sf.write(output_file, waveform_test, sr)
    
    print(f"  Time: {test_time:.1f}s")
    print(f"  Saved: {output_file}")
    print()

In [None]:
# Test different shift ratios
shift_ratios = [0.0, 0.25, 0.5, 0.75]
test_prompt = "jazz piano solo"
test_duration = 15.0

print("🧪 Testing different shift ratios:")
print()

for shift_ratio in shift_ratios:
    config_test = SpotDiffusionConfig(
        tile_size_time=256,
        max_shift_ratio=shift_ratio
    )
    
    print(f"Testing max_shift_ratio = {shift_ratio}")
    
    start_time = time.time()
    waveform_test = spotdiffusion_text_to_audio(
        model=model,
        prompt=test_prompt,
        duration=test_duration,
        ddim_steps=50,
        config=config_test
    )
    test_time = time.time() - start_time
    
    # Save result
    output_file = f"spot_shift_{shift_ratio:.2f}.wav"
    sf.write(output_file, waveform_test, sr)
    
    print(f"  Time: {test_time:.1f}s")
    print(f"  Saved: {output_file}")
    print()

## Quality Analysis

Compare simple level statistics (RMS, peak, crest factor) across the generated files. These describe loudness and peakiness, not perceptual quality.

In [None]:
# Load and compare generated audio files
import glob

# Get all generated files
spot_files = sorted(glob.glob("spot_*.wav"))  # sorted for a stable listing order
print(f"Found {len(spot_files)} SpotDiffusion audio files:")

for file in spot_files:
    # Load audio
    audio, _ = sf.read(file)
    
    # Calculate simple metrics
    rms = np.sqrt(np.mean(audio**2))
    peak = np.max(np.abs(audio))
    crest_factor_db = 20 * np.log10(peak / (rms + 1e-8))  # peak-to-RMS ratio, not true dynamic range
    
    print(f"  {file}:")
    print(f"    RMS: {rms:.4f}")
    print(f"    Peak: {peak:.4f}")
    print(f"    Crest factor: {crest_factor_db:.1f} dB")
    print()
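RMS, peak and crest factor will not reveal tile seams. A rough check, sketched here as an assumption rather than an established metric, compares frame-to-frame spectral change (spectral flux) near latent tile boundaries with its mean elsewhere.

Under the 40 ms-per-latent-frame assumption used earlier, 256-frame tiles put fixed boundaries every 10.24 s. That matters mostly for the `max_shift_ratio = 0.0` file, where seams cannot move between steps. `spectral_flux` and `seam_flux_ratio` are hypothetical helpers. The sketch uses NumPy only, with a Hann-windowed STFT.

```python
import numpy as np

def spectral_flux(audio, sr, n_fft=1024, hop=256):
    """Positive frame-to-frame magnitude change of a Hann-windowed STFT."""
    audio = np.asarray(audio, dtype=float)
    if audio.ndim > 1:
        audio = audio.mean(axis=1)                          # fold stereo to mono
    n_frames = 1 + (len(audio) - n_fft) // hop
    window = np.hanning(n_fft)
    frames = np.stack([audio[i * hop:i * hop + n_fft] * window for i in range(n_frames)])
    mag = np.abs(np.fft.rfft(frames, axis=1))
    flux = np.sqrt((np.maximum(np.diff(mag, axis=0), 0.0) ** 2).sum(axis=1))
    times = (np.arange(1, n_frames) * hop + n_fft / 2) / sr  # centre of the later frame
    return times, flux

def seam_flux_ratio(audio, sr, seam_period_s=10.24, tol_s=0.1):
    """Mean flux within tol_s of each multiple of seam_period_s, over mean flux elsewhere."""
    times, flux = spectral_flux(audio, sr)
    near = np.zeros(len(flux), dtype=bool)
    for seam in np.arange(seam_period_s, times[-1], seam_period_s):
        near |= np.abs(times - seam) <= tol_s
    if not near.any() or near.all():
        return float("nan")
    return float(flux[near].mean() / (flux[~near].mean() + 1e-12))

# Synthetic check: a tone that jumps from 440 Hz to 880 Hz at 10.24 s
sr_demo = 16000
t = np.arange(int(15.0 * sr_demo)) / sr_demo
tone_switch = np.where(t < 10.24, np.sin(2 * np.pi * 440 * t), np.sin(2 * np.pi * 880 * t))
print(seam_flux_ratio(tone_switch, sr_demo))  # well above 1: abrupt change at the "seam"
```

In the notebook it could be applied as `seam_flux_ratio(*sf.read(file))`, since `sf.read` returns `(data, samplerate)`. For the 128-frame tile experiment, `seam_period_s` would be 5.12 under the same assumption. Values near 1 suggest no systematic change at the boundaries.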

## Advanced Usage: Custom Diffusion Loop

For research purposes, you can access the low-level SpotDiffusion functions directly.

In [None]:
# Example: Manual control over SpotDiffusion sampling
print("🔬 Advanced SpotDiffusion usage:")

# Setup
prompt = "experimental electronic music"
model.cond_stage_model.embed_mode = "text"
conditioning = model.get_learned_conditioning([prompt])
unconditional_conditioning = model.get_learned_conditioning([""])

sampler = DDIMSampler(model)
duration = 15.0
latent_size = duration_to_latent_t_size(duration)
shape = [1, model.channels, latent_size, model.latent_f_size]

print(f"Latent shape: {shape}")
print(f"Should use SpotDiffusion: {should_use_spotdiffusion(latent_size, config)}")

# Generate random shift for step 0
shift_time = generate_random_shift(shape, config, step=0)
print(f"Random shift for step 0: time={shift_time} (freq=0, no freq shifting)")

# Create tiles
tiles = create_tiles(shape, config)
print(f"Generated {len(tiles)} tiles:")
for i, (t_start, t_end, f_start, f_end) in enumerate(tiles):
    print(f"  Tile {i+1}: time=[{t_start}:{t_end}], freq=[{f_start}:{f_end}] (full spectrum)")
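When experimenting with `create_tiles`, it can help to check that the tiles are non-overlapping, cover every latent frame exactly once, and span the full frequency axis. `check_time_partition` is a hypothetical helper that works on any list of `(t_start, t_end, f_start, f_end)` tuples, such as `tiles` above. In this notebook it would be called as `check_time_partition(tiles, shape[2], shape[3])`.

```python
import numpy as np

def check_time_partition(tiles, latent_t, latent_f):
    """Raise AssertionError unless tiles cover each time frame exactly once with full frequency."""
    coverage = np.zeros(latent_t, dtype=int)
    for t_start, t_end, f_start, f_end in tiles:
        assert 0 <= t_start < t_end <= latent_t, f"bad time range {t_start}:{t_end}"
        assert (f_start, f_end) == (0, latent_f), f"tile does not span full frequency: {f_start}:{f_end}"
        coverage[t_start:t_end] += 1
    assert (coverage == 1).all(), "some frames are covered 0 or 2+ times"
    return True

print(check_time_partition([(0, 256, 0, 16), (256, 384, 0, 16)], 384, 16))  # True
```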

## Summary

SpotDiffusion provides an efficient alternative to MultiDiffusion for long-form audio generation:

### Advantages:
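- **Better Performance**: Non-overlapping tiles mean one denoiser pass per tile per step, with no repeated work on overlapping regions
- **Seamless Results**: Random shifting moves tile boundaries between steps, so seams are not fixed at the same positions in the output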
- **Memory Efficient**: Works with sliding window VAE decoding
- **Automatic Fallback**: Uses regular DDIM for short audio

### Use Cases:
- Long-form music generation (>30 seconds)
- Ambient soundscapes
- Background music for videos
- Audio texture synthesis

### Next Steps:
- Experiment with different tile sizes for your use case
- Try various shift ratios to balance quality vs. artifacts
- Compare with MultiDiffusion for quality assessment
- Implement regional prompting (future work)