# IndianBatsModel - Real-World Inference Pipeline

This notebook runs the trained model on **raw, uncurated audio files**.
Unlike the test notebook, this pipeline handles:
1.  **Detection**: Finding potential bat calls in long recordings, using high-frequency energy to skip silence and low-frequency noise.
2.  **Classification**: Predicting the species for each detected call.
3.  **Filtering**: Labelling predictions below the confidence threshold as "Uncertain".

**Use this for:** Field recordings, or any files where you don't know whether (or where) bat calls occur.

In [None]:
# 1. Setup Environment
# Clone the project repository so its `src` package can be imported in the next cell.
# NOTE(review): the clone goes into the current working directory; the next cell
# expects it at /kaggle/working/IndianBatsModel — confirm the notebook's working
# directory matches. Re-running this cell makes git report that the destination
# path already exists (the existing clone is left untouched).
!git clone https://github.com/Quarkisinproton/IndianBatsModel.git
# Audio / data / plotting dependencies. torch and torchvision are not installed here;
# presumably the runtime image provides them — verify.
!pip install librosa pyyaml pandas matplotlib scikit-learn tqdm

In [None]:
# 2. Import Modules
import sys
import os
import torch
import numpy as np
import pandas as pd
import librosa
import librosa.display
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from PIL import Image
from torchvision import transforms

# Define working directory
WORK_DIR = '/kaggle/working'

# Make the cloned repository importable. REPO_DIR allows `import src.*`;
# SRC_DIR is also added, presumably so modules inside `src` can import their
# siblings without the `src.` prefix (verify against the repo's import style).
REPO_DIR = os.path.join(WORK_DIR, 'IndianBatsModel')
SRC_DIR = os.path.join(REPO_DIR, 'src')
for _path in (REPO_DIR, SRC_DIR):
    if _path not in sys.path:
        sys.path.append(_path)

# Import project modules. Every later cell depends on these names, so fail
# loudly here instead of printing and continuing into confusing NameErrors
# (or into inference with an unloaded model) further down.
try:
    from src.models.cnn_with_features import CNNWithFeatures
    from src.data_prep.wombat_to_spectrograms import make_mel_spectrogram
    from src.data_prep.extract_end_frequency import compute_end_frequency
except ImportError as e:
    print(f"Import Error: {e}")
    raise ImportError(
        f"Could not import project modules from {REPO_DIR}. "
        "Did the setup cell clone the repository into WORK_DIR?"
    ) from e
else:
    print("Imports successful!")

In [None]:
# 3. Configuration

# --- INPUTS ---
# Path to your trained model. Used as-is when the file exists; otherwise a
# .pth checkpoint is auto-detected under /kaggle/input (see MODEL SETTINGS).
MODEL_PATH = '/kaggle/working/models/bat_fused_best.pth'

# Path to folder containing RAW audio files
INPUT_AUDIO_DIR = '/kaggle/input/your-test-audio-folder'

# File extensions treated as audio when auto-detecting the input directory
AUDIO_EXTENSIONS = ('.wav', '.mp3', '.flac')

# Auto-detect input audio directory if default doesn't exist.
# Subdirectories are visited in sorted order so the pick is reproducible.
if not os.path.exists(INPUT_AUDIO_DIR):
    print(f"Default input directory {INPUT_AUDIO_DIR} not found. Searching /kaggle/input...")
    for root, dirs, files in os.walk('/kaggle/input'):
        dirs.sort()  # in-place sort controls the order os.walk descends
        if any(f.lower().endswith(AUDIO_EXTENSIONS) for f in files):
            INPUT_AUDIO_DIR = root
            print(f"Found audio files in: {INPUT_AUDIO_DIR}")
            break

# --- DETECTION SETTINGS ---
# NOTE(review): SAMPLE_RATE is informational only — the inference loop loads audio
# with librosa.load(sr=None), i.e. at each file's native rate. Confirm whether
# resampling to this rate is intended.
SAMPLE_RATE = 250000  # Typical for bat recorders (adjust if needed)
MIN_FREQ = 15000      # High-pass cutoff (15 kHz) applied to the STFT to suppress wind/low-frequency noise
ENERGY_THRESH = 0.02  # RMS energy (of the high-passed signal) needed to trigger detection
MIN_DURATION = 0.01   # Minimum call duration (seconds)
PAD_DURATION = 0.05   # Padding around detected call (seconds)
CONFIDENCE_THRESH = 70.0 # Predictions with confidence >= 70% are reported; lower ones become "Uncertain"

# --- MODEL SETTINGS ---
NUM_CLASSES = 2       # Must match your trained model
CLASS_NAMES = ['pip-ceylonicusbat-species', 'pip-tenuisbat-species'] # Ensure correct order!
# Catch a mismatch now rather than as a wrong/out-of-range label lookup during inference.
if len(CLASS_NAMES) != NUM_CLASSES:
    raise ValueError(
        f"CLASS_NAMES has {len(CLASS_NAMES)} entries but NUM_CLASSES is {NUM_CLASSES}"
    )

# Auto-detect a model in /kaggle/input ONLY when MODEL_PATH does not exist, so an
# explicitly configured checkpoint is never silently replaced. A checkpoint with
# the same filename as MODEL_PATH is preferred; otherwise the first one found
# (in sorted walk order) is used.
found_models = []
if not os.path.exists(MODEL_PATH):
    for root, dirs, files in os.walk('/kaggle/input'):
        dirs.sort()
        found_models.extend(
            os.path.join(root, f) for f in sorted(files) if f.endswith('.pth')
        )
    if found_models:
        wanted_name = os.path.basename(MODEL_PATH)
        preferred = [p for p in found_models if os.path.basename(p) == wanted_name]
        MODEL_PATH = (preferred or found_models)[0]
        print(f"Auto-detected model: {MODEL_PATH}")

In [None]:
# 4. Define Detection & Processing Functions

def detect_events(y, sr, min_freq=15000, threshold=0.01, min_dur=0.01, pad=0.05):
    """Find candidate call segments in ``y`` from high-frequency RMS energy.

    The signal is high-passed by zeroing STFT bins below ``min_freq``. Frames
    whose RMS energy exceeds ``threshold`` are grouped into contiguous runs.
    Runs at least ``min_dur`` seconds long are padded by ``pad`` seconds on each
    side (clipped to the file bounds), and overlapping segments are merged.

    Args:
        y: Audio samples, as returned by ``librosa.load``.
        sr: Sample rate of ``y`` in Hz.
        min_freq: Cutoff in Hz; spectral energy below it is ignored.
        threshold: RMS energy above which a frame counts as active.
        min_dur: Minimum length (seconds) of an active run to keep.
        pad: Padding (seconds) added before and after each kept run.

    Returns:
        List of ``(start, end)`` tuples in seconds, in time order and
        non-overlapping. Empty when nothing is detected.
    """
    n_fft, hop = 2048, 512
    # Computed once; used to clip padded segment ends to the file length.
    duration = librosa.get_duration(y=y, sr=sr)

    # 1. High-pass in the STFT domain: zero out bins below min_freq.
    #    S is a fresh array used only here, so it is modified in place.
    S = librosa.stft(y, n_fft=n_fft, hop_length=hop)
    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    S[freqs < min_freq, :] = 0

    # 2. RMS energy of the filtered spectrum, one value per STFT frame.
    rms = librosa.feature.rms(S=S, frame_length=n_fft, hop_length=hop)[0]
    times = librosa.frames_to_time(np.arange(len(rms)), sr=sr, hop_length=hop)

    # 3. Threshold, then 4. group consecutive active frames into padded segments.
    events = []
    start = None
    for t, active in zip(times, rms > threshold):
        if active and start is None:
            start = t
        elif not active and start is not None:
            if (t - start) >= min_dur:
                events.append((max(0, start - pad), min(duration, t + pad)))
            start = None

    # A run still active at the last frame is padded/clipped exactly like runs
    # that end mid-file (previously it received no end padding).
    if start is not None:
        end = times[-1]
        if (end - start) >= min_dur:
            events.append((max(0, start - pad), min(duration, end + pad)))

    # 5. Merge overlapping (padded) segments; events are already in time order.
    merged = []
    for seg_start, seg_end in events:
        if merged and seg_start <= merged[-1][1]:
            merged[-1] = (merged[-1][0], max(merged[-1][1], seg_end))
        else:
            merged.append((seg_start, seg_end))
    return merged

def prepare_input(y, sr, start, end):
    """Build model inputs for the ``[start, end)`` segment of ``y``.

    Args:
        y: Full recording samples.
        sr: Sample rate in Hz.
        start: Segment start in seconds.
        end: Segment end in seconds.

    Returns:
        ``(img_tensor, feat_tensor, img, y_seg)``:
        ``img_tensor`` is the normalized 3x224x224 spectrogram image tensor,
        ``feat_tensor`` is a 1-element float32 tensor holding the end frequency,
        ``img`` is the PIL image the tensor was built from, and ``y_seg`` is the
        raw audio slice. All four are ``None`` when the slice is shorter than
        512 samples.
    """
    # Extract audio
    y_seg = y[int(start * sr):int(end * sr)]
    if len(y_seg) < 512:
        return None, None, None, None  # Too short

    # 1. Spectrogram, min-max normalized to [0, 1] (epsilon avoids 0/0 on flat input)
    S_db = make_mel_spectrogram(y_seg, sr)
    S_min, S_max = S_db.min(), S_db.max()
    S_norm = (S_db - S_min) / (S_max - S_min + 1e-8)

    # Magma colormap -> uint8 RGB (alpha dropped), intended to match the training images.
    S_img = (plt.get_cmap('magma')(S_norm)[:, :, :3] * 255).astype(np.uint8)

    # Flip vertically so low frequency is at the bottom
    img = Image.fromarray(np.flipud(S_img))

    # Resize + ImageNet mean/std normalization
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img_tensor = transform(img)

    # 2. Features (End Frequency), computed from the full recording plus segment bounds.
    end_freq = compute_end_frequency(y, sr, start, end)
    # Treat a missing/NaN/inf result as 0.0. np.isnan(None) would raise, and the
    # caller's per-file try/except would then skip every remaining event in the file.
    if end_freq is None or not np.isfinite(end_freq):
        end_freq = 0.0
    feat_tensor = torch.tensor([float(end_freq)], dtype=torch.float32)

    return img_tensor, feat_tensor, img, y_seg


In [None]:
# 5. Load Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Fail loudly: continuing without trained weights would run inference with a
# randomly initialized model that was never moved to `device` or put in eval
# mode, silently producing meaningless predictions.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found: {MODEL_PATH}")

# Initialize model structure (weights come from the checkpoint below)
model = CNNWithFeatures(num_classes=NUM_CLASSES, numeric_feat_dim=1, pretrained=False)

print(f"Loading {MODEL_PATH}...")
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects. Only load
# checkpoints you trust; MODEL_PATH may have been auto-detected from /kaggle/input.
checkpoint = torch.load(MODEL_PATH, map_location=device, weights_only=False)

if isinstance(checkpoint, torch.nn.Module):
    # Whole pickled model object
    model = checkpoint
elif isinstance(checkpoint, dict):
    # Training checkpoint wrapping the weights, or a bare state_dict
    state_dict = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
    model.load_state_dict(state_dict)
else:
    raise TypeError(
        f"Unsupported checkpoint type {type(checkpoint).__name__} in {MODEL_PATH}"
    )

model.to(device)
model.eval()  # inference mode (affects dropout / batch-norm layers, if any)
print("Model loaded!")

In [None]:
# 6. Run Inference Loop
from IPython.display import display

# Create output directory for spectrograms
SPECTROGRAM_DIR = os.path.join(WORK_DIR, 'inference_spectrograms')
os.makedirs(SPECTROGRAM_DIR, exist_ok=True)
print(f'Saving spectrograms to {SPECTROGRAM_DIR}...')

results = []


def _no_detection_row(filename):
    """Result row for a file in which no classifiable event was found."""
    return {
        'filename': filename,
        'start': '-', 'end': '-',
        'prediction': 'No Detection',
        'confidence': 0.0
    }


# Find audio files (sorted so processing/output order is reproducible)
audio_files = []
if os.path.exists(INPUT_AUDIO_DIR):
    for root, dirs, files in os.walk(INPUT_AUDIO_DIR):
        for f in files:
            if f.lower().endswith(('.wav', '.mp3', '.flac')):
                audio_files.append(os.path.join(root, f))
    audio_files.sort()
else:
    print(f"Input directory {INPUT_AUDIO_DIR} does not exist.")

print(f"Found {len(audio_files)} files to process.")

for audio_path in tqdm(audio_files):
    filename = os.path.basename(audio_path)
    try:
        # Load audio at the file's native sample rate
        y, sr = librosa.load(audio_path, sr=None)

        # Detect events using ALL configured detection settings
        events = detect_events(
            y, sr,
            min_freq=MIN_FREQ, threshold=ENERGY_THRESH,
            min_dur=MIN_DURATION, pad=PAD_DURATION,
        )

        n_classified = 0
        for start, end in events:
            img_t, feat_t, pil_img, y_seg = prepare_input(y, sr, start, end)
            if img_t is None:
                continue  # segment too short to classify

            # Inference
            with torch.no_grad():
                img_batch = img_t.unsqueeze(0).to(device)
                feat_batch = feat_t.unsqueeze(0).to(device)

                output = model(img_batch, feat_batch)
                probs = torch.nn.functional.softmax(output, dim=1)
                conf, pred_idx = torch.max(probs, 1)

            confidence_pct = conf.item() * 100
            pred_class = CLASS_NAMES[pred_idx.item()]
            n_classified += 1

            # Filter low confidence
            final_pred = pred_class if confidence_pct >= CONFIDENCE_THRESH else "Uncertain"

            # Display inline and save the model-input spectrogram if confident
            if final_pred != "Uncertain":
                print(f"\nDetected: {final_pred} ({confidence_pct:.1f}%) at {start:.2f}-{end:.2f}s in {os.path.basename(audio_path)}")
                # High-res linear-frequency STFT of the segment, with axes, for display
                plt.figure(figsize=(10, 4))
                D = librosa.amplitude_to_db(np.abs(librosa.stft(y_seg)), ref=np.max)
                librosa.display.specshow(D, sr=sr, x_axis='time', y_axis='linear', cmap='magma')
                plt.colorbar(format='%+2.0f dB')
                plt.title(f'{final_pred} ({confidence_pct:.1f}%)')
                plt.tight_layout()
                plt.show()
                plt.close()  # avoid accumulating open figures on non-inline backends

                # Save Spectrogram
                spec_filename = f"{os.path.basename(audio_path)}_{start:.2f}_{end:.2f}_{final_pred}.png"
                spec_path = os.path.join(SPECTROGRAM_DIR, spec_filename)
                pil_img.save(spec_path)

            results.append({
                'filename': filename,
                'start': f"{start:.2f}",
                'end': f"{end:.2f}",
                'prediction': final_pred,
                'confidence': f"{confidence_pct:.1f}"
            })

        # Every processed file gets at least one row, including files whose
        # detected events were all too short to classify.
        if n_classified == 0:
            results.append(_no_detection_row(filename))

    except Exception as e:
        # Per-file boundary: report and move on to the next recording.
        print(f"Error processing {audio_path}: {e}")

# Display Results
df = pd.DataFrame(results)
print("\nInference Results:")
print(df.to_string())

In [None]:
# 7. Calculate Accuracy (If Ground Truth is known)
# Ground truth is inferred from a keyword in each filename
# (e.g. 'recording_ceylonicus_01.wav'). Files with no matching keyword are
# labelled 'unknown' and left out of the score.

if df.empty or 'prediction' not in df.columns:
    print("No results to evaluate.")
else:
    # Filename keyword -> true class label.
    # UPDATE THIS based on your file naming convention
    keyword_map = {
        'ceylonicus': 'pip-ceylonicusbat-species',
        'tenuis': 'pip-tenuisbat-species'
    }

    def get_true_label(filename):
        """Return the label of the first keyword contained in *filename* (case-insensitive), else 'unknown'."""
        lowered = filename.lower()
        return next(
            (label for key, label in keyword_map.items() if key.lower() in lowered),
            'unknown',
        )

    df['true_label'] = df['filename'].apply(get_true_label)

    # Score only confident detections from files whose ground truth is known.
    excluded_predictions = ['No Detection', 'Uncertain']
    scorable = ~df['prediction'].isin(excluded_predictions) & (df['true_label'] != 'unknown')
    valid_preds = df.loc[scorable].copy()

    if valid_preds.empty:
        print("\nCould not calculate accuracy (no valid matching predictions found).")
    else:
        valid_preds['correct'] = valid_preds['prediction'] == valid_preds['true_label']
        accuracy = valid_preds['correct'].mean() * 100

        print(f"\nAccuracy on {len(valid_preds)} confident detections: {accuracy:.1f}%")
        print(valid_preds[['filename', 'prediction', 'true_label', 'correct']].head(20))