# NanoMamba - Interspeech 2026 Full Training (GPU)

**NanoMamba: Noise-Robust KWS with SA-SSM**

| Cell | 내용 | 예상 시간 |
|:----:|------|:---------:|
| 1 | 환경설정 + GSC V2 다운로드 | ~5분 |
| 2 | **전체 학습 + 평가 한번에** | ~8-12시간 |
| 3 | 결과 확인 + 다운로드 | 즉시 |

⚠️ **런타임 → 런타임 유형 변경 → GPU (T4)** 선택 필수!

In [None]:
#@title Cell 1: 환경 설정 + 데이터 다운로드
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
else:
    raise RuntimeError("GPU not available! Change runtime type to GPU.")

# Clone repo
!git clone https://github.com/DrJinHoChoi/NanoMamba-Interspeech2026.git
%cd NanoMamba-Interspeech2026

# Download Google Speech Commands V2
import os
DATA_DIR = './data'
os.makedirs(DATA_DIR, exist_ok=True)

if not os.path.exists(os.path.join(DATA_DIR, 'speech_commands_v0.02')):
    print("\n Downloading Google Speech Commands V2...")
    !wget -q http://download.tensorflow.org/data/speech_commands_v0.02.tar.gz -O /tmp/gsc_v2.tar.gz
    !mkdir -p {DATA_DIR}/speech_commands_v0.02
    !tar -xzf /tmp/gsc_v2.tar.gz -C {DATA_DIR}/speech_commands_v0.02
    !rm /tmp/gsc_v2.tar.gz
    print("Download complete!")
else:
    print("Data already exists.")

# Verify
classes = [d for d in os.listdir(f'{DATA_DIR}/speech_commands_v0.02') 
           if os.path.isdir(f'{DATA_DIR}/speech_commands_v0.02/{d}') and not d.startswith('_')]
print(f"\nFound {len(classes)} keyword classes")
print("Ready to train!")

In [None]:
#@title Cell 2: 전체 학습 + 노이즈 평가 (한번에 실행)
#@markdown ### 학습 모델 (9종)
#@markdown - NanoMamba-Tiny (4,634), Small (12,032)
#@markdown - BC-ResNet-1 (7,464), BC-ResNet-3 (43,200), DS-CNN-S (23,756)
#@markdown - SA-SSM Ablation: Full, dt_only, b_only, Standard
#@markdown ### 노이즈 평가
#@markdown - 3 noise types (factory, white, babble) x 7 SNR (-15~+15dB)

# Cell 2: launch src/train_all_models.py as a child process for all nine
# models and the full noise sweep, relaying its output into the notebook.
import os
import subprocess
import sys


def _has_real_fileno(stream):
    """Return True when *stream* exposes a usable OS-level file descriptor."""
    try:
        stream.fileno()
    except (AttributeError, OSError, ValueError):
        # io.UnsupportedOperation derives from both OSError and ValueError.
        return False
    return True


# Add src/ to Python path for imports. Existing PYTHONPATH entries are kept
# (src/ is prepended) instead of being overwritten.
src_dir = os.path.join(os.getcwd(), 'src')
path_entries = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
if src_dir not in path_entries:
    path_entries.insert(0, src_dir)
os.environ['PYTHONPATH'] = os.pathsep.join(path_entries)

ALL_MODELS = ",".join([
    # Proposed
    "NanoMamba-Tiny", "NanoMamba-Small",
    # Baselines
    "BC-ResNet-1", "BC-ResNet-3", "DS-CNN-S",
    # SA-SSM Ablation
    "NanoMamba-Tiny-Full", "NanoMamba-Tiny-dtOnly",
    "NanoMamba-Tiny-bOnly", "NanoMamba-Tiny-Standard",
])

cmd = [
    sys.executable, "-u", "src/train_all_models.py",
    "--data_dir", "./data",
    "--checkpoint_dir", "./checkpoints_full",
    "--epochs", "30",
    "--batch_size", "64",
    "--seed", "42",
    "--models", ALL_MODELS,
    "--noise_types", "factory,white,babble",
    "--snr_range", "-15,-10,-5,0,5,10,15",
    "--per_class",
]

print(f"Running: {' '.join(cmd)}\n")
env = os.environ.copy()

if _has_real_fileno(sys.stdout) and _has_real_fileno(sys.stderr):
    # Streams are backed by file descriptors: let the child write directly.
    process = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stderr, env=env)
    process.wait()
else:
    # Popen needs a real file descriptor for stdout/stderr. When the kernel's
    # replacement streams do not provide one, read the child's merged output
    # through a pipe and re-print it line by line.
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        env=env,
        text=True,
        encoding="utf-8",
        errors="replace",
        bufsize=1,
    )
    for line in process.stdout:
        print(line, end="")
    process.wait()

if process.returncode == 0:
    print("\n" + "="*60)
    print("  ALL TRAINING + EVALUATION COMPLETE!")
    print("="*60)
else:
    print(f"\nProcess exited with code {process.returncode}")

In [None]:
#@title Cell 3: 결과 확인 + 다운로드
import json
import glob

# Find result files
result_files = glob.glob('checkpoints_full/results/*.json')
print(f"Found {len(result_files)} result files:")
for f in sorted(result_files):
    print(f"  {f}")

# Show latest results
if result_files:
    latest = sorted(result_files)[-1]
    with open(latest) as f:
        results = json.load(f)
    
    print(f"\n{'='*70}")
    print(f"  Results from: {latest}")
    print(f"{'='*70}")
    
    # Clean accuracy table
    if 'model_results' in results:
        print(f"\n{'Model':<30} {'Params':>8} {'Val':>8} {'Test':>8}")
        print('-' * 58)
        for name, data in results['model_results'].items():
            val = data.get('best_val_acc', '-')
            test = data.get('test_acc', '-')
            params = data.get('params', '-')
            val_str = f"{val:.2f}%" if isinstance(val, (int, float)) else str(val)
            test_str = f"{test:.2f}%" if isinstance(test, (int, float)) else str(test)
            print(f"{name:<30} {str(params):>8} {val_str:>8} {test_str:>8}")
    
    # Noise robustness table
    if 'noise_results' in results:
        print(f"\n{'='*70}")
        print("  Noise Robustness Results")
        print(f"{'='*70}")
        noise_data = results['noise_results']
        for model_name, model_noise in noise_data.items():
            print(f"\n  {model_name}:")
            for noise_type, snr_results in model_noise.items():
                snr_str = ", ".join([f"{snr}dB:{acc:.1f}%" 
                                     for snr, acc in sorted(snr_results.items(), 
                                     key=lambda x: float(x[0]) if x[0] != 'clean' else 999)])
                print(f"    {noise_type}: {snr_str}")

# Zip all results for download
!zip -r /content/smartear_results.zip checkpoints_full/

from google.colab import files
files.download('/content/smartear_results.zip')
print("\nResults downloaded!")