In [None]:
#!pip install -U datasets
# Pinned versions of the Hugging Face datasets / hub clients (the unpinned
# upgrade on the line above is left disabled).
!pip install datasets==2.16.0
!pip install huggingface-hub==0.20.0
# System-level sox development library (presumably for torchaudio audio I/O).
!apt-get install -y libsox-dev
# PyTorch stack pinned to 2.4.0, pulled from the cu121 wheel index.
!pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121
# Mamba packages: causal-conv1d is installed first, then mamba-ssm.
!pip install causal-conv1d==1.4.0 && pip install mamba-ssm==2.2.2

Collecting datasets==2.16.0
  Downloading datasets-2.16.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow-hotfix (from datasets==2.16.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets==2.16.0)
  Downloading dill-0.3.7-py3-none-any.whl.metadata (9.9 kB)
Collecting fsspec<=2023.10.0,>=2023.1.0 (from fsspec[http]<=2023.10.0,>=2023.1.0->datasets==2.16.0)
  Downloading fsspec-2023.10.0-py3-none-any.whl.metadata (6.8 kB)
INFO: pip is looking at multiple versions of multiprocess to determine which version is compatible with other requirements. This could take a while.
Collecting multiprocess (from datasets==2.16.0)
  Downloading multiprocess-0.70.18-py312-none-any.whl.metadata (7.5 kB)
  Downloading multiprocess-0.70.17-py312-none-any.whl.metadata (7.2 kB)
  Downloading multiprocess-0.70.15-py311-none-any.whl.metadata (7.2 kB)
Downloading datasets-2.16.0-py3-none-any.whl (507 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
#!/usr/bin/env python3
"""
Mamba KWS Inference Benchmark
============================

This script benchmarks inference performance for small, medium, and large Mamba models
for Keyword Spotting using the exact implementations from the collab code files.
It can run both locally (CPU) and on Colab (GPU).

Features:
- Tests latency over 1000 runs
- Tests throughput (and, on CUDA, peak memory) across a sweep of batch sizes, from 1 up to roughly 2K
- Uses exact Mamba implementation from collab code
- Uses pre-trained model weights when the checkpoint paths listed in MODEL_CONFIGS exist (random weights otherwise)
"""
from __future__ import annotations
import time
import statistics
import json
import math
from datetime import datetime
from pathlib import Path
import numpy as np
from typing import Optional, Dict, List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.transforms as T

# Runtime environment probe.
# Colab is recognised by finding 'google.colab' in the string form of the
# IPython shell object; when `get_ipython` is not in the module globals the run
# is treated as local.
if 'get_ipython' in globals():
    IS_COLAB = 'google.colab' in str(get_ipython())
else:
    IS_COLAB = False
HAS_CUDA = torch.cuda.is_available()

# Every benchmark below runs on this device: the GPU if one is visible, else CPU.
device = torch.device("cuda") if HAS_CUDA else torch.device("cpu")

print(f"Environment: {'Colab' if IS_COLAB else 'Local'}")
print(f"CUDA available: {HAS_CUDA}")
print(f"Device: {device}")

# Prefer the real mamba_ssm implementation. If the package cannot be imported,
# define a lightweight stand-in with the same constructor keywords so the
# benchmark can still run (useful for timing only, not for accuracy).
try:
    from mamba_ssm import Mamba
except ImportError:
    USE_MAMBA_SSM = False
    print("⚠️  mamba_ssm not available - will run benchmark with dummy Mamba for performance testing")

    class Mamba(nn.Module):
        """Stand-in block: gated depthwise convolution between two linear layers.

        ``d_state`` is accepted for signature compatibility and is not used.
        """

        def __init__(self, d_model: int, d_state: int = 16, expand: int = 2):
            super().__init__()
            hidden = d_model * expand
            # One projection produces both the conv branch and the gate branch.
            self.proj = nn.Linear(d_model, hidden * 2)
            self.conv1d = nn.Conv1d(hidden, hidden, 3, padding=1, groups=hidden)
            self.out_proj = nn.Linear(hidden, d_model)

        def forward(self, x):
            # Split the projected features into the conv branch and the gate.
            branch, gate = self.proj(x).chunk(2, dim=-1)
            # Conv1d expects channels before time, so swap axes around the conv.
            branch = self.conv1d(branch.transpose(1, 2)).transpose(1, 2)
            branch = F.silu(branch)
            return self.out_proj(branch * F.silu(gate))
else:
    print("✅ Using mamba_ssm library")
    USE_MAMBA_SSM = True

# =============================================================================
# Audio Processing - Exact implementation from collab code
# =============================================================================

class WaveToSpec:
    """Turn a waveform tensor into log-mel spectrogram or MFCC features.

    Args:
        feature_type: ``"mel"`` or ``"mfcc"`` (case-insensitive).
        sample_rate: Sample rate handed to the torchaudio transform.
        n_fft: FFT size.
        hop_length: Hop between successive STFT frames, in samples.
        n_mels: Number of mel filterbanks.
        n_mfcc: Number of MFCC coefficients (``"mfcc"`` only).
        top_db: Cut-off used by the power-to-dB conversion (``"mel"`` only).
        apply_mask: Whether frequency/time masking transforms are constructed.
            Only honoured for ``"mel"``; the masks are never applied by
            ``__call__`` because this class is used for inference here.
        freq_mask_param: Parameter for ``T.FrequencyMasking``.
        time_mask_param: Parameter for ``T.TimeMasking``.

    Raises:
        ValueError: If ``feature_type`` is neither ``"mel"`` nor ``"mfcc"``.
    """

    def __init__(self,
                 feature_type: str = "mel",
                 sample_rate: int = 16_000,
                 n_fft: int = 2048,
                 hop_length: int = 256,
                 n_mels: int = 128,
                 n_mfcc: int = 40,
                 top_db: int | None = 80,
                 apply_mask: bool = True,
                 freq_mask_param: int = 15,
                 time_mask_param: int = 10):
        self.feature_type = feature_type.lower()
        if self.feature_type not in {"mel", "mfcc"}:
            raise ValueError(
                f"feature_type must be 'mel' or 'mfcc', got {feature_type!r}"
            )
        self.apply_mask = apply_mask and self.feature_type == "mel"

        # Always define these so the attributes exist on every configuration.
        self.to_db = None
        self.freq_mask = None
        self.time_mask = None

        if self.feature_type == "mel":
            # Keyword arguments are required here: MelSpectrogram's positional
            # order is (sample_rate, n_fft, win_length, hop_length, ...), so a
            # positional call would bind hop_length to win_length and n_mels to
            # hop_length.
            # NOTE(review): if existing checkpoints were trained on features
            # produced by the earlier positional call, confirm the training
            # pipeline before using these features for accuracy numbers.
            self.spec = T.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                power=2,
            )
            self.to_db = T.AmplitudeToDB(stype="power", top_db=top_db)
            if self.apply_mask:
                self.freq_mask = T.FrequencyMasking(freq_mask_param)
                self.time_mask = T.TimeMasking(time_mask_param)
        else:
            self.spec = T.MFCC(
                sample_rate=sample_rate,
                n_mfcc=n_mfcc,
                melkwargs=dict(n_fft=n_fft, hop_length=hop_length, n_mels=n_mels),
            )

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        """Compute features for ``wav``; a 1-D waveform gets a leading axis added."""
        if wav.dim() == 1:
            wav = wav.unsqueeze(0)
        feats = self.spec(wav)
        if self.to_db is not None:
            # Clamp before the dB conversion so exact zeros cannot reach the log.
            feats = self.to_db(feats.clamp(min=1e-10))
        # The masking transforms are intentionally not applied: inference only.
        return feats

# =============================================================================
# Model Architecture - Exact implementation from collab code
# =============================================================================

class MambaKWS(nn.Module):
    """Keyword-spotting classifier built around a stack of Mamba blocks.

    Pipeline: Conv2d front-end -> per-frame linear projection -> ``n_layers``
    pre-norm residual Mamba blocks -> (length-masked) mean pooling over time ->
    small MLP classifier.

    Args:
        num_classes: Size of the output logit vector.
        d_model: Width of the sequence representation fed to the Mamba blocks.
        d_state: Forwarded to each ``Mamba`` block.
        expand: Forwarded to each ``Mamba`` block.
        n_layers: Number of Mamba blocks.
        in_ch: Input channels of the first convolution.
        feature_dim: Number of frequency bins in the input features.
    """

    def __init__(self, num_classes: int, d_model=256, d_state=32, expand=2, n_layers=8, in_ch=1, feature_dim=128):
        super().__init__()

        # Convolutional stem. The first pool halves both frequency and time,
        # the second halves frequency only.
        stem_layers = [
            nn.Conv2d(in_ch, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.SiLU(),
            nn.MaxPool2d((2, 1)),
        ]
        self.conv_embed = nn.Sequential(*stem_layers)

        # Two frequency-halving pools leave feature_dim // 4 bins on each of
        # the 64 channels; those are flattened per time step.
        per_frame_width = 64 * (feature_dim // 4)

        # Map the flattened conv features onto the Mamba width.
        self.proj = nn.Sequential(
            nn.Linear(per_frame_width, d_model),
            nn.LayerNorm(d_model),
            nn.SiLU(),
            nn.Dropout(0.1),
        )

        # Residual Mamba stack; dropout shrinks with depth, floored at 0.02.
        stack = []
        for depth in range(n_layers):
            drop_p = max(0.02, 0.05 - (depth * 0.005))
            stack.append(nn.ModuleDict({
                "norm": nn.LayerNorm(d_model),
                "mamba": Mamba(d_model=d_model, d_state=d_state, expand=expand),
                "dropout": nn.Dropout(drop_p),
            }))
        self.blocks = nn.ModuleList(stack)
        self.pre_classifier_norm = nn.LayerNorm(d_model)

        # Classification head.
        self.classifier_dropout = nn.Dropout(0.1)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.SiLU(),
            nn.Dropout(0.05),
            nn.Linear(d_model // 2, num_classes),
        )

    def forward(self, x, lengths: torch.Tensor | None = None):  # x: [B, T, F]
        """Return class logits for ``x``.

        ``lengths`` (optional) holds per-sample frame counts before the conv
        stem; when given, padded frames are excluded from the time average.
        """
        # [B, T, F] -> [B, 1, F, T] so Conv2d sees frequency x time images.
        feats = self.conv_embed(x.permute(0, 2, 1).unsqueeze(1))      # [B, 64, F', T']

        # Put time first, merge channel and frequency, then project.
        seq = feats.permute(0, 3, 1, 2).contiguous().flatten(2)       # [B, T', 64*F']
        seq = self.proj(seq)                                          # [B, T', d_model]

        # Pre-norm residual blocks.
        for layer in self.blocks:
            seq = seq + layer["dropout"](layer["mamba"](layer["norm"](seq)))

        seq = self.pre_classifier_norm(seq)

        if lengths is None:
            pooled = seq.mean(dim=1)
        else:
            # The stem halves the time axis once, so halve the lengths to match.
            valid_frames = torch.div(lengths, 2, rounding_mode='floor').clamp(min=1).to(seq.device)
            frame_index = torch.arange(seq.size(1), device=seq.device)
            keep = (frame_index[None, :] < valid_frames[:, None]).float().unsqueeze(-1)  # [B, T', 1]
            frame_total = (seq * keep).sum(dim=1)                     # [B, d_model]
            frame_count = keep.sum(dim=1).clamp(min=1.0)              # [B, 1]
            pooled = frame_total / frame_count

        # Single-head output.
        return self.classifier(self.classifier_dropout(pooled))

# =============================================================================
# Model Configurations
# =============================================================================

# One row per benchmarked variant: (size name, checkpoint path, d_model, n_layers).
# The checkpoint paths are relative, so they resolve against the working directory.
_MODEL_VARIANTS = (
    ('small', 'content/small_kws_mel_97.42.pt', 64, 8),
    ('medium', 'content/medium_kws_melSpec_97.58.pt', 128, 10),
    ('large', 'content/large_kws_melSpec_97.75.pt', 192, 12),
)

# Every variant shares d_state=16 and a 35-way output
# (Google Speech Commands v0.02); only width and depth differ.
MODEL_CONFIGS = {
    size_name: {
        'model_path': checkpoint,
        'd_model': width,
        'd_state': 16,
        'n_layers': depth,
        'expected_classes': 35,
    }
    for size_name, checkpoint, width, depth in _MODEL_VARIANTS
}

# =============================================================================
# Utility Functions
# =============================================================================

def load_model(model_size: str, device: torch.device) -> MambaKWS:
    """Build the ``model_size`` variant and load its checkpoint when possible.

    Weights are loaded only if the checkpoint file exists and the real
    mamba_ssm library is in use. In every other case, including a failed load,
    the model keeps its random initialisation so the benchmark can still run.

    Args:
        model_size: Key into ``MODEL_CONFIGS``.
        device: Device the model is moved to (also used as ``map_location``).

    Returns:
        The model on ``device``, in eval mode.
    """
    cfg = MODEL_CONFIGS[model_size]

    net = MambaKWS(
        num_classes=cfg['expected_classes'],
        d_model=cfg['d_model'],
        d_state=cfg['d_state'],
        n_layers=cfg['n_layers'],
        feature_dim=128,  # Mel-spectrogram feature dimension
    )

    weights_file = Path(cfg['model_path'])
    if not (weights_file.exists() and USE_MAMBA_SSM):
        # Explain why pretrained weights are being skipped.
        if not weights_file.exists():
            print(f"⚠️  Model file not found: {weights_file}")
        if not USE_MAMBA_SSM:
            print(f"⚠️  Using dummy Mamba implementation")
        print(f"Using random weights for {model_size} model (benchmark mode)")
    else:
        try:
            # NOTE(review): torch.load unpickles the file; only point this at
            # trusted checkpoints.
            payload = torch.load(weights_file, map_location=device)

            # A checkpoint may be a bare state dict or wrap it under a known key.
            weights = payload
            if isinstance(payload, dict):
                for wrapper_key in ('model_state_dict', 'state_dict'):
                    if wrapper_key in payload:
                        weights = payload[wrapper_key]
                        break

            net.load_state_dict(weights, strict=True)
            print(f"✅ {model_size} model loaded successfully from {weights_file}")
        except Exception as e:
            # Best effort: a failed load degrades to random weights.
            print(f"⚠️  Could not load {model_size} weights ({e}), using random weights for benchmarking")

    net.to(device)
    net.eval()
    return net

def create_dummy_input(batch_size: int, device: torch.device,
                       n_frames: int = 63, n_mels: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create a random mel-spectrogram-shaped batch plus matching lengths.

    Args:
        batch_size: Number of samples in the batch.
        device: Device the tensors are created on.
        n_frames: Time frames per sample. The default of 63 corresponds to
            1 second of 16 kHz audio at a hop of 256 samples.
        n_mels: Frequency bins per frame.

    Returns:
        ``(features, lengths)`` where ``features`` is a float tensor of shape
        ``[batch_size, n_frames, n_mels]`` and ``lengths`` is a long tensor of
        shape ``[batch_size]`` filled with ``n_frames``.
    """
    # Descriptive local names avoid shadowing the module-level `T` and `F` aliases.
    features = torch.randn(batch_size, n_frames, n_mels, device=device)
    lengths = torch.full((batch_size,), n_frames, dtype=torch.long, device=device)
    return features, lengths

def count_parameters(model: nn.Module) -> int:
    """Return how many scalar weights in ``model`` require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

def benchmark_memory_usage(model: nn.Module, device: torch.device,
                          batch_sizes: Optional[List[int]] = None) -> Dict[str, Dict[str, float]]:
    """Measure peak CUDA memory of one forward pass per batch size.

    Args:
        model: Model to run; it is put in eval mode.
        device: Target device. Anything other than CUDA yields an empty result.
        batch_sizes: Batch sizes to probe. Defaults to powers of two from 1 to
            2048.

    Returns:
        A dict keyed by ``"batch_<size>"``. Each value holds, in MiB:
        ``baseline_mb`` (allocated before any inference), ``peak_mb``,
        ``inference_mb`` (peak minus baseline), ``activation_mb`` (peak minus
        the allocation just before this batch) and ``memory_per_sample_mb``.
        Empty when ``device`` is not CUDA.
    """
    if device.type != 'cuda':
        return {}

    if batch_sizes is None:
        # Built per call so no mutable default is shared between calls.
        batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

    model.eval()
    results = {}
    bytes_per_mib = 1024 ** 2

    # Baseline: whatever is allocated before any benchmark inference runs.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    baseline_memory = torch.cuda.memory_allocated()

    for batch_size in batch_sizes:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Allocation right before this batch's input is created.
        memory_before = torch.cuda.memory_allocated()

        dummy_input, dummy_lengths = create_dummy_input(batch_size, device)
        with torch.no_grad():
            output = model(dummy_input, dummy_lengths)

        # Wait for the kernels to finish before reading the peak counter.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()

        inference_memory = peak_memory - baseline_memory
        activation_memory = peak_memory - memory_before

        results[f"batch_{batch_size}"] = {
            'baseline_mb': baseline_memory / bytes_per_mib,
            'peak_mb': peak_memory / bytes_per_mib,
            'inference_mb': inference_memory / bytes_per_mib,
            'activation_mb': activation_memory / bytes_per_mib,
            'memory_per_sample_mb': activation_memory / batch_size / bytes_per_mib
        }

        # Drop the output too, so it cannot inflate the next batch's
        # `memory_before` reading.
        del dummy_input, dummy_lengths, output
        torch.cuda.empty_cache()

    return results

def benchmark_latency(model: nn.Module, device: torch.device, num_runs: int = 1000) -> Dict[str, float]:
    """Time single-sample forward passes and summarise the distribution.

    Args:
        model: Model to benchmark; it is put in eval mode.
        device: Device the dummy inputs are created on.
        num_runs: Number of timed forward passes (after 10 warmup passes).

    Returns:
        Latency statistics in milliseconds: ``mean_ms``, ``median_ms``,
        ``std_ms`` (0.0 when only one run was timed), ``min_ms``, ``max_ms``,
        ``p95_ms`` and ``p99_ms``.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    model.eval()
    on_cuda = device.type == 'cuda'

    # Warmup runs
    dummy_input, dummy_lengths = create_dummy_input(1, device)
    with torch.no_grad():
        for _ in range(10):
            _ = model(dummy_input, dummy_lengths)

    if on_cuda:
        torch.cuda.synchronize()

    # Actual benchmark
    times = []
    with torch.no_grad():
        for _ in range(num_runs):
            dummy_input, dummy_lengths = create_dummy_input(1, device)

            if on_cuda:
                # Input generation is asynchronous on CUDA; wait for it so the
                # timed region contains only the model's forward pass.
                torch.cuda.synchronize()

            start_time = time.perf_counter()
            _ = model(dummy_input, dummy_lengths)

            if on_cuda:
                torch.cuda.synchronize()

            end_time = time.perf_counter()
            times.append((end_time - start_time) * 1000)  # Convert to milliseconds

    return {
        'mean_ms': statistics.mean(times),
        'median_ms': statistics.median(times),
        # stdev needs at least two samples.
        'std_ms': statistics.stdev(times) if len(times) > 1 else 0.0,
        'min_ms': min(times),
        'max_ms': max(times),
        # Plain floats keep the result independent of NumPy scalar types.
        'p95_ms': float(np.percentile(times, 95)),
        'p99_ms': float(np.percentile(times, 99))
    }

def benchmark_throughput(model: nn.Module, device: torch.device,
                        batch_sizes: Optional[List[int]] = None) -> Dict[int, Dict[str, float]]:
    """Measure forward-pass throughput for a range of batch sizes.

    Args:
        model: Model to benchmark; it is put in eval mode.
        device: Device the dummy inputs are created on.
        batch_sizes: Batch sizes to test. Defaults to powers of two from 1 to
            2048.

    Returns:
        A dict keyed by batch size with ``avg_time_s`` (mean seconds per batch),
        ``throughput_samples_per_s`` and ``time_per_sample_ms``.
    """
    if batch_sizes is None:
        # Built per call so no mutable default is shared between calls.
        batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]

    model.eval()
    on_cuda = device.type == 'cuda'
    results = {}

    for batch_size in batch_sizes:
        print(f"  Testing batch size {batch_size}...")

        # Warmup
        dummy_input, dummy_lengths = create_dummy_input(batch_size, device)
        with torch.no_grad():
            for _ in range(5):
                _ = model(dummy_input, dummy_lengths)

        if on_cuda:
            torch.cuda.synchronize()

        # Benchmark
        num_runs = max(10, 100 // batch_size)  # Fewer runs for larger batches
        times = []

        with torch.no_grad():
            for _ in range(num_runs):
                dummy_input, dummy_lengths = create_dummy_input(batch_size, device)

                if on_cuda:
                    # Input generation is asynchronous on CUDA; finish it before
                    # starting the clock so only the forward pass is timed.
                    torch.cuda.synchronize()

                start_time = time.perf_counter()
                _ = model(dummy_input, dummy_lengths)

                if on_cuda:
                    torch.cuda.synchronize()

                end_time = time.perf_counter()
                times.append(end_time - start_time)

        avg_time = statistics.mean(times)
        throughput = batch_size / avg_time

        results[batch_size] = {
            'avg_time_s': avg_time,
            'throughput_samples_per_s': throughput,
            'time_per_sample_ms': (avg_time / batch_size) * 1000
        }

    return results

def run_full_benchmark():
    """Benchmark the small, medium and large models and save the results.

    For each model size this loads the model, counts its parameters, then runs
    the latency and throughput benchmarks, plus the memory benchmark on CUDA.
    Results are printed, written to ``benchmark_results_<timestamp>.json`` in
    the working directory, and returned.

    Returns:
        A dict with run metadata (``timestamp``, ``device``, ``cuda_available``,
        ``using_mamba_ssm``) and a ``models`` dict keyed by model size.
    """
    print(f"\n🚀 Starting Mamba KWS Inference Benchmark")
    print(f"Device: {device}")
    print(f"Using mamba_ssm: {USE_MAMBA_SSM}")
    print("=" * 60)

    results = {
        'timestamp': datetime.now().isoformat(),
        'device': str(device),
        'cuda_available': HAS_CUDA,
        'using_mamba_ssm': USE_MAMBA_SSM,
        'models': {}
    }

    model_sizes = ['small', 'medium', 'large']

    for model_size in model_sizes:
        print(f"\n📊 Benchmarking {model_size.upper()} model...")

        # Load model
        model = load_model(model_size, device)
        param_count = count_parameters(model)

        print(f"Model parameters: {param_count:,}")

        # Latency benchmark
        print("🔍 Running latency benchmark (1000 runs)...")
        latency_results = benchmark_latency(model, device)

        # Throughput benchmark
        print("📈 Running throughput benchmark...")
        throughput_results = benchmark_throughput(model, device)

        # Memory benchmark (CUDA only)
        memory_results = {}
        if device.type == 'cuda':
            print("🧠 Running memory usage benchmark...")
            memory_results = benchmark_memory_usage(model, device)

        results['models'][model_size] = {
            'parameters': param_count,
            'config': MODEL_CONFIGS[model_size],
            'latency': latency_results,
            'throughput': throughput_results,
            'memory': memory_results
        }

        # Print summary
        print(f"\n📋 {model_size.upper()} Results Summary:")
        print(f"  Parameters: {param_count:,}")
        print(f"  Latency (single): {latency_results['mean_ms']:.2f}ms ± {latency_results['std_ms']:.2f}ms")
        print(f"  P95 latency: {latency_results['p95_ms']:.2f}ms")

        # Report the real maximum over every tested batch size instead of
        # assuming it occurs at batch=32.
        best_batch, best_stats = max(
            throughput_results.items(),
            key=lambda item: item[1]['throughput_samples_per_s'],
        )
        print(f"  Max throughput: {best_stats['throughput_samples_per_s']:.1f} samples/sec (batch={best_batch})")

        # Only print the memory lines whose batch sizes were actually measured.
        batch_1_memory = memory_results.get('batch_1')
        if batch_1_memory is not None:
            print(f"  Memory (batch=1): {batch_1_memory['inference_mb']:.1f}MB total, {batch_1_memory['activation_mb']:.1f}MB activations")
        batch_32_memory = memory_results.get('batch_32')
        if batch_32_memory is not None:
            print(f"  Memory (batch=32): {batch_32_memory['inference_mb']:.1f}MB total, {batch_32_memory['memory_per_sample_mb']:.2f}MB per sample")

        # Cleanup
        del model
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    # Save results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"benchmark_results_{timestamp}.json"

    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"\n💾 Results saved to: {results_file}")

    # Print final summary. The throughput column is the batch-size-1 figure.
    print(f"\n🎯 BENCHMARK SUMMARY")
    print("=" * 60)
    if device.type == 'cuda':
        print(f"{'Model':>6} | {'Params':>8} | {'Latency':>8} | {'Throughput':>10} | {'Memory':>8}")
        print("-" * 60)
        for model_size in model_sizes:
            model_results = results['models'][model_size]
            batch_1_memory = model_results['memory'].get('batch_1')
            memory_str = f"{batch_1_memory['inference_mb']:.1f}MB" if batch_1_memory else "N/A"
            print(f"{model_size.upper():>6} | {model_results['parameters']:>8,} | "
                  f"{model_results['latency']['mean_ms']:>6.2f}ms | "
                  f"{model_results['throughput'][1]['throughput_samples_per_s']:>8.1f} sps | "
                  f"{memory_str:>8}")
    else:
        for model_size in model_sizes:
            model_results = results['models'][model_size]
            print(f"{model_size.upper():>6}: {model_results['parameters']:>8,} params | "
                  f"{model_results['latency']['mean_ms']:>6.2f}ms | "
                  f"{model_results['throughput'][1]['throughput_samples_per_s']:>6.1f} sps")

    return results

# Entry point: run the full benchmark when executed as a script or notebook
# cell, keeping the returned results dict in a module-level name.
if __name__ == "__main__":
    results = run_full_benchmark()


Environment: Colab
CUDA available: True
Device: cuda
✅ Using mamba_ssm library

🚀 Starting Mamba KWS Inference Benchmark
Device: cuda
Using mamba_ssm: True

📊 Benchmarking SMALL model...
⚠️  Model file not found: content/small_kws_mel_97.42.pt
Using random weights for small model (benchmark mode)
Model parameters: 462,147
🔍 Running latency benchmark (1000 runs)...
📈 Running throughput benchmark...
  Testing batch size 1...
  Testing batch size 2...
  Testing batch size 4...
  Testing batch size 8...
  Testing batch size 16...
  Testing batch size 32...
  Testing batch size 64...
  Testing batch size 128...
  Testing batch size 256...
  Testing batch size 512...
  Testing batch size 1028...
  Testing batch size 2056...
🧠 Running memory usage benchmark...

📋 SMALL Results Summary:
  Parameters: 462,147
  Latency (single): 5.49ms ± 0.65ms
  P95 latency: 6.59ms
  Max throughput: 3289.6 samples/sec (batch=32)
  Memory (batch=1): 2.1MB total, 2.1MB activations
  Memory (batch=32): 105.5MB to