# NanoMamba-Tiny Checkpoint Test

**NanoMamba: Noise-Robust KWS with Spectral-Aware State Space Models (Interspeech 2026)**

| Cell | Content | Time |
|:----:|---------|:----:|
| 1 | Setup + Download GSC V2 | ~3min |
| 2 | Load checkpoint + Clean test | ~2min |
| 3 | Noise robustness evaluation | ~10min |
| 4 | Results visualization | instant |
| 5 | Download results (zip) | instant |

Before running, select a GPU runtime: **Runtime > Change runtime type > GPU (T4)**

In [None]:
#@title Cell 1: Setup + Download GSC V2
import torch, os, sys, time
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Clone repo (checkpoints included)
if not os.path.exists('/content/NanoMamba-Interspeech2026'):
    !git clone https://github.com/DrJinHoChoi/NanoMamba-Interspeech2026.git /content/NanoMamba-Interspeech2026

%cd /content/NanoMamba-Interspeech2026

# Verify checkpoint exists
ckpt_path = 'checkpoints_full/NanoMamba-Tiny/best.pt'
assert os.path.exists(ckpt_path), f"Checkpoint not found: {ckpt_path}"
print(f"Checkpoint found: {ckpt_path} ({os.path.getsize(ckpt_path)/1024:.1f} KB)")

# Download Google Speech Commands V2
DATA_DIR = './data'
GSC_DIR = os.path.join(DATA_DIR, 'SpeechCommands', 'speech_commands_v0.02')

if not os.path.exists(GSC_DIR):
    print("\nDownloading Google Speech Commands V2 (~2.3GB)...")
    os.makedirs(GSC_DIR, exist_ok=True)
    !wget -q http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz -O /tmp/gsc_v2.tar.gz
    !tar -xzf /tmp/gsc_v2.tar.gz -C {GSC_DIR}
    !rm /tmp/gsc_v2.tar.gz
    print("Download complete!")
else:
    print(f"GSC V2 already exists at {GSC_DIR}")

classes = [d for d in os.listdir(GSC_DIR)
           if os.path.isdir(os.path.join(GSC_DIR, d)) and not d.startswith('_')]
print(f"Found {len(classes)} keyword classes")
print("\nReady for testing!")

In [None]:
#@title Cell 2: Load Checkpoint + Clean Evaluation
# Builds NanoMamba-Tiny, restores the checkpoint weights, and measures clean
# accuracy on the validation and test splits, including a per-class breakdown.
import time  # used for the evaluation timer; imported here so the cell does not rely on Cell 1

import torch
import numpy as np
from nanomamba import create_nanomamba_tiny
from train_all_models import (
    SpeechCommandsDataset, evaluate, GSC_LABELS_12
)
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1. Create model and load checkpoint
model = create_nanomamba_tiny(n_classes=12)
ckpt = torch.load('checkpoints_full/NanoMamba-Tiny/best.pt',
                   map_location=device, weights_only=True)
# strict=False tolerates key mismatches between checkpoint and model; both
# mismatch lists are reported below so a silent partial load cannot go unnoticed.
missing, unexpected = model.load_state_dict(
    ckpt['model_state_dict'], strict=False)
model = model.to(device)
model.eval()

params = sum(p.numel() for p in model.parameters())
fp32_kb = params * 4 / 1024  # 4 bytes per parameter
int8_kb = params * 1 / 1024  # 1 byte per parameter

print(f"Model: NanoMamba-Tiny")
print(f"  Parameters: {params:,}")
print(f"  FP32 size: {fp32_kb:.1f} KB")
print(f"  INT8 size: {int8_kb:.1f} KB")
print(f"  Checkpoint epoch: {ckpt['epoch']}")
print(f"  Checkpoint val_acc: {ckpt['val_acc']:.2f}%")
if missing:
    print(f"  Missing keys (using defaults): {missing}")
if unexpected:
    # Checkpoint entries the model has no slot for; they were not loaded.
    print(f"  Unexpected keys (ignored): {unexpected}")

# 2. Load validation and test datasets
print("\nLoading datasets...")
val_dataset = SpeechCommandsDataset('./data', subset='validation', augment=False)
test_dataset = SpeechCommandsDataset('./data', subset='testing', augment=False)

val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)

# 3. Clean evaluation
print("\nEvaluating (clean)...")
t0 = time.time()
val_acc, val_preds, val_labels = evaluate(model, val_loader, device)
test_acc, test_preds, test_labels = evaluate(model, test_loader, device)
elapsed = time.time() - t0

print(f"\n{'='*50}")
print(f"  CLEAN ACCURACY RESULTS")
print(f"{'='*50}")
print(f"  Validation: {val_acc:.2f}%")
print(f"  Test:       {test_acc:.2f}%")
print(f"  Eval time:  {elapsed:.1f}s")

# 4. Per-class accuracy
# NOTE(review): the boolean-mask indexing below assumes evaluate() returns
# array-like preds/labels (numpy arrays or tensors) -- confirm in train_all_models.
print(f"\n  Per-class Test Accuracy:")
print(f"  {'Class':<12} {'Correct':>8} {'Total':>8} {'Acc':>8}")
print(f"  {'-'*40}")
for i, label in enumerate(GSC_LABELS_12):
    mask = test_labels == i
    total_i = mask.sum()
    if total_i > 0:  # classes with no test samples are left out of the table
        correct_i = (test_preds[mask] == i).sum()
        acc_i = 100.0 * correct_i / total_i
        print(f"  {label:<12} {correct_i:>8} {total_i:>8} {acc_i:>7.1f}%")

print(f"\n  Overall Test Accuracy: {test_acc:.2f}%")

In [None]:
#@title Cell 3: Noise Robustness Evaluation
# Sweeps every (noise type, SNR) pair over the validation loader, prints a
# summary table, and writes all metrics to results/nanomamba_tiny_test_results.json.
# Requires `model`, `val_loader`, `device`, `params`, `ckpt`, `val_acc` and
# `test_acc` from Cell 2.
import json
import os    # results directory creation; imported here so the cell does not rely on Cell 1
import time  # per-condition timing

from train_all_models import evaluate_noisy, evaluate_noisy_per_class

noise_types = ['factory', 'white', 'babble', 'street', 'pink']
snr_levels = [-15, -10, -5, 0, 5, 10, 15]

print(f"{'='*70}")
print(f"  NOISE ROBUSTNESS EVALUATION")
print(f"  Model: NanoMamba-Tiny ({params:,} params)")
print(f"  Noise types: {noise_types}")
print(f"  SNR levels: {snr_levels} dB")
print(f"{'='*70}")

# noise_results[noise_type][snr] -> accuracy; the extra 'clean' key holds the
# noise-free validation accuracy for the summary table.
noise_results = {}
for noise_type in noise_types:
    noise_results[noise_type] = {}
    print(f"\n  {noise_type.upper()}:")
    for snr in snr_levels:
        t0 = time.time()
        # evaluate_noisy's return type is not visible from this notebook;
        # float() normalizes numpy/torch scalars to a plain Python float so the
        # value is safe to plot and to serialize as JSON later.
        acc = float(evaluate_noisy(model, val_loader, device,
                                   noise_type=noise_type, snr_db=snr))
        elapsed = time.time() - t0
        noise_results[noise_type][snr] = acc
        print(f"    SNR={snr:>4}dB: {acc:.2f}% ({elapsed:.1f}s)")
    noise_results[noise_type]['clean'] = val_acc

# Summary table
print(f"\n{'='*80}")
print(f"  SUMMARY TABLE")
print(f"{'='*80}")
header = f"  {'Noise':<10} | {'Clean':>7} | " + " | ".join(f"{s:>5}dB" for s in snr_levels)
print(header)
print(f"  {'-'*len(header)}")
for noise_type in noise_types:
    clean = noise_results[noise_type]['clean']
    snrs = [noise_results[noise_type][s] for s in snr_levels]
    row = f"  {noise_type:<10} | {clean:>6.1f}% | " + " | ".join(f"{s:>5.1f}%" for s in snrs)
    print(row)

# Save results
# json.dump only accepts built-in numeric types, so every accuracy is coerced
# with float() -- otherwise a numpy.float32 or 0-d tensor would raise TypeError
# after the whole (slow) sweep has already run.
save_results = {
    'model': 'NanoMamba-Tiny',
    'params': params,
    'checkpoint_epoch': ckpt['epoch'],
    'clean_val_acc': float(val_acc),
    'clean_test_acc': float(test_acc),
    'noise_robustness': {}
}
for nt in noise_types:
    # Keys are stringified because they mix int SNR levels with 'clean'.
    save_results['noise_robustness'][nt] = {
        str(k): float(v) for k, v in noise_results[nt].items()
    }

os.makedirs('results', exist_ok=True)
with open('results/nanomamba_tiny_test_results.json', 'w') as f:
    json.dump(save_results, f, indent=2)
print(f"\nResults saved to results/nanomamba_tiny_test_results.json")

In [None]:
#@title Cell 4: Results Visualization
# Draws a two-panel figure from the results of Cells 2 and 3: accuracy-vs-SNR
# curves per noise type on the left, clean per-class test accuracy on the right.
import matplotlib
import matplotlib.pyplot as plt

matplotlib.rcParams['font.size'] = 12

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
ax1, ax2 = axes

# --- Plot 1: Noise robustness curves ---
colors = dict(factory='#e74c3c', white='#3498db', babble='#2ecc71',
              street='#f39c12', pink='#9b59b6')
markers = dict(factory='o', white='s', babble='^', street='D', pink='v')

for noise_type in noise_types:
    per_snr = noise_results[noise_type]
    ax1.plot(
        snr_levels,
        [per_snr[s] for s in snr_levels],
        color=colors[noise_type],
        marker=markers[noise_type],
        linewidth=2,
        markersize=8,
        label=noise_type.capitalize(),
    )

# Dashed reference line at the noise-free validation accuracy.
ax1.axhline(y=val_acc, color='gray', linestyle='--', alpha=0.5, label=f'Clean ({val_acc:.1f}%)')
ax1.set(
    xlabel='SNR (dB)',
    ylabel='Accuracy (%)',
    title=f'NanoMamba-Tiny ({params:,} params) - Noise Robustness',
    ylim=[0, 100],
)
ax1.legend(loc='lower right')
ax1.grid(True, alpha=0.3)


# --- Plot 2: Per-class accuracy (clean) ---
def _class_accuracy(class_idx):
    """Return the clean test accuracy (%) for one class, or 0 if it has no samples."""
    in_class = test_labels == class_idx
    n_total = in_class.sum()
    if n_total > 0:
        n_hit = (test_preds[in_class] == class_idx).sum()
        return 100.0 * n_hit / n_total
    return 0


class_accs = [_class_accuracy(idx) for idx, _ in enumerate(GSC_LABELS_12)]

bars = ax2.barh(GSC_LABELS_12, class_accs, color='#3498db', edgecolor='white')
ax2.set(
    xlabel='Accuracy (%)',
    title=f'Per-Class Test Accuracy (Overall: {test_acc:.1f}%)',
    xlim=[0, 100],
)
# Print each bar's value just past its right end, vertically centered.
for rect, pct in zip(bars, class_accs):
    y_mid = rect.get_y() + rect.get_height()/2
    ax2.text(rect.get_width() + 1, y_mid, f'{pct:.1f}%', va='center', fontsize=10)

plt.tight_layout()
plt.savefig('results/nanomamba_tiny_test_results.png', dpi=150, bbox_inches='tight')
plt.show()
print("Plot saved to results/nanomamba_tiny_test_results.png")

In [None]:
#@title Cell 5: Download Results
# Bundle everything under results/ (the JSON metrics and the PNG plot written
# by the earlier cells) into one archive under /content.
!zip -r /content/nanomamba_tiny_test.zip results/

# google.colab is provided by the Colab runtime; this import is expected to
# fail when the notebook is run elsewhere.
from google.colab import files
# Hand the archive to Colab's file-download helper.
files.download('/content/nanomamba_tiny_test.zip')
print("Results downloaded!")