# YOHO Demo Notebook

Demonstrates inference and visualization with the YOHO model.

**Requirements:**
- Trained model checkpoint (`best_yoho.pth`; without it the demo runs with random weights)
- ESC-50-style dataset structure (optional)
- Required packages: torch, torchaudio, matplotlib, librosa, numpy

In [None]:
import torch
import torchaudio
import numpy as np
import matplotlib.pyplot as plt
import librosa.display
from pathlib import Path

# Import model and utilities
# Adjust paths according to your project structure
from model.yoho import YOHO
from config import YOHOConfig
from utils.visualize import plot_spectrogram_with_preds

print("Imports completed")

## 1. Configuration and Model Loading

In [None]:
# Build the configuration — it must match the one used during training.
cfg = YOHOConfig()

# Demo-specific overrides of the training configuration
cfg.model.num_classes = 50  # ESC-50 has 50 classes
cfg.model.seg_enabled = True
cfg.model.use_memory = False  # disable memory for simple demo

# Prefer the GPU when one is available
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

# Instantiate the network on the chosen device and switch it to eval mode
model = YOHO(cfg.model).to(device)
model.eval()

# Restore trained weights when the checkpoint exists; otherwise keep the
# random initialization so the rest of the notebook still runs.
checkpoint_path = Path("checkpoints/best_yoho.pth")  # <- change to your actual path
if not checkpoint_path.exists():
    print(f"Checkpoint not found at {checkpoint_path}")
    print("Demo will run with random weights (predictions will be meaningless)")
else:
    # NOTE(review): assumes the file stores a bare state_dict (not a wrapper
    # dict with optimizer state, epoch, ...) — confirm against the training
    # script. torch.load unpickles its input, so only load trusted files.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"Model loaded successfully from {checkpoint_path}")

## 2. Generate or Load Sample Audio

In [None]:
def generate_synthetic_audio(duration_sec=5.0, sr=44100, *, tone2_onset_sec=2.0,
                             noise_std=0.15, seed=None):
    """
    Generate simple synthetic audio: sine wave + noise + second tone.

    The signal is a 440 Hz tone for the whole clip, plus an 880 Hz tone that
    switches on at ``tone2_onset_sec``, plus Gaussian noise. It is
    peak-normalized to roughly [-1, 1].

    Args:
        duration_sec: Clip length in seconds.
        sr: Sample rate in Hz.
        tone2_onset_sec: Time (s) at which the 880 Hz tone starts.
        noise_std: Standard deviation (before normalization) of the noise.
        seed: Optional integer seed for reproducible noise. With ``None`` the
            global torch RNG is used.

    Returns:
        (waveform, sr): a 1-D tensor of ``int(sr * duration_sec)`` samples
        and the sample rate.
    """
    num_samples = int(sr * duration_sec)
    # Sample n sits at exactly n / sr seconds. torch.linspace(0, duration, N)
    # includes the endpoint, so its step would be duration / (N - 1) and every
    # tone would come out slightly sharp.
    t = torch.arange(num_samples) / sr

    # 440 Hz tone (A4 note)
    tone1 = 0.7 * torch.sin(2 * np.pi * 440 * t)

    # 880 Hz tone, silent before the onset time
    tone2 = 0.5 * torch.sin(2 * np.pi * 880 * (t - tone2_onset_sec)) * (t >= tone2_onset_sec)

    # Background noise (optionally seeded for reproducibility)
    generator = torch.Generator().manual_seed(seed) if seed is not None else None
    noise = noise_std * torch.randn(t.shape, generator=generator)

    waveform = tone1 + tone2 + noise

    # Nothing to normalize for a zero-length clip (max() of an empty tensor raises)
    if waveform.numel() == 0:
        return waveform, sr

    # Peak-normalize; epsilon guards against division by zero on silence
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)

    return waveform, sr


# Pick the demo input. By default a synthetic clip is generated; to use a real
# recording instead, replace the first line with the two commented lines below.
# torchaudio.load returns a [channels, T] tensor, so averaging over dim 0
# yields a mono [T] waveform like the synthetic one.
# NOTE(review): confirm the sample rate the model was trained on and resample
# if the input differs.
waveform, sr = generate_synthetic_audio(duration_sec=5.0)
# waveform, sr = torchaudio.load("data/sample.wav")
# waveform = waveform.mean(0)

audio_summary = f"Audio shape: {waveform.shape}, sample rate: {sr} Hz"
print(audio_summary)

## 3. Run Inference

In [None]:
# Detection thresholds passed to model.infer (tweak these to experiment)
CONF_THRES = 0.25
IOU_THRES = 0.45

# Add a leading batch axis ([T] -> [1, T]) and move to the model's device
audio_input = waveform.unsqueeze(0).to(device)

# Single-shot inference without autograd bookkeeping
with torch.no_grad():
    predictions = model.infer(audio_input, conf_thres=CONF_THRES, iou_thres=IOU_THRES)

print("Inference completed")
# NOTE(review): assumes infer() returns a dict keyed by 'boxes', ... — confirm.
num_events = len(predictions.get('boxes', []))
print(f"Detected events: {num_events}")

In [None]:
print("\nStreaming inference example:")
# Tile the demo clip along the time axis to simulate a long recording. The
# resulting length depends on the chosen input, so report it rather than
# assuming the 5 s synthetic clip.
num_repeats = 10
long_audio = torch.cat([waveform] * num_repeats, dim=-1)
print(f"Long audio duration: {long_audio.shape[-1] / sr:.1f} s")

# Disable autograd, matching the single-shot inference cell, so no
# computation graph is kept across the many chunks.
with torch.no_grad():
    stream_preds = model.stream_infer(long_audio, chunk_length_sec=5.0, overlap_sec=1.5)
# NOTE(review): assumes stream_infer returns one result per processed chunk — confirm.
print(f"Processed {len(stream_preds)} chunks")

## 4. Compute Spectrogram for Visualization

In [None]:
# Mel-spectrogram used only for plotting. Its parameters should mirror the
# model's front-end; they are read from cfg.model.spec when that section (or
# an individual field) exists, otherwise common defaults are used.
spec_cfg = getattr(cfg.model, 'spec', None)
n_mels = getattr(spec_cfg, 'n_mels', 128)
hop_length = getattr(spec_cfg, 'hop_length', 512)
# torchaudio's MelSpectrogram defaults to n_fft=400. That is shorter than the
# 512-sample hop (frames would skip audio) and gives coarse frequency
# resolution for 128 mel bands (torchaudio warns about all-zero mel filters).
# Fall back to librosa's default of 2048 instead.
n_fft = getattr(spec_cfg, 'n_fft', 2048)

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sr,
    n_fft=n_fft,
    n_mels=n_mels,
    hop_length=hop_length,
    # Never exceed Nyquist, e.g. for 8 kHz / 16 kHz real recordings
    f_max=min(8000.0, sr / 2),
    normalized=True,
)

spec_tensor = mel_transform(waveform)
spec_db = librosa.power_to_db(spec_tensor.numpy(), ref=np.max)

print(f"Spectrogram shape: {spec_db.shape}")

## 5. Visualization

In [None]:
# Prepare predictions in the format expected by plot_spectrogram_with_preds,
# substituting empty tensors (or None for masks) for any missing key.
# NOTE(review): assumes model.infer returns a dict with these keys — adjust
# according to the actual output structure of .infer().
pred_defaults = {
    'boxes': torch.empty((0, 4)),
    'scores': torch.empty((0,)),
    'labels': torch.empty((0,), dtype=torch.long),
    'masks': None,
}

# Inference ran on `device` (possibly CUDA). Detach and move any tensors to
# the CPU so NumPy/matplotlib-based plotting and the printing below work.
vis_preds = {}
for key, default in pred_defaults.items():
    value = predictions.get(key, default)
    vis_preds[key] = value.detach().cpu() if torch.is_tensor(value) else value

# Plot
fig = plot_spectrogram_with_preds(
    spec_db,
    vis_preds,
    title="YOHO Inference on Synthetic Audio",
    save_path="demo_result.png"
)

plt.show()

## 6. Optional: Detailed Output

In [None]:
# Print each detection in human-readable form.
# NOTE(review): assumes each box is [t_start, f_start, t_end, f_end] in
# seconds / Hz as produced by model.infer — confirm against the model.
if len(vis_preds['boxes']) > 0:
    print("\nDetected events:")
    detections = zip(vis_preds['boxes'], vis_preds['scores'], vis_preds['labels'])
    for event_no, (box, score, label) in enumerate(detections, start=1):
        t_start, f_start, t_end, f_end = box.tolist()
        print(f"Event #{event_no}:")
        print(f"  Time: {t_start:.2f}s → {t_end:.2f}s")
        print(f"  Frequency: {f_start:.0f}Hz → {f_end:.0f}Hz")
        print(f"  Class: {label.item()}  |  Confidence: {score.item():.3f}")
        print()
else:
    print("No events detected above confidence threshold.")

---

You can now:
- Replace the synthetic audio with real recordings (see Option 2 in section 2)
- Try different confidence (`conf_thres`) and IoU (`iou_thres`) thresholds in section 3
- Compare results before and after model improvements

Happy experimenting! 🚀