# IndianBatsModel - Testing & Inference

This notebook tests the trained Bat Species Classifier on new audio files.

**Prerequisites:**
1.  **Trained Model**: You must have a trained `.pth` model file (e.g., from the training notebook).
2.  **Test Data**: Audio files organized in folders by species (similar to training data).

**Steps:**
1.  Set up the environment (clone the code and install dependencies).
2.  Prepare Test Data (Generate Spectrograms).
3.  Load Model.
4.  Evaluate Accuracy.
5.  Run Inference on individual files.


In [None]:
# 1. Setup Environment
!git clone https://github.com/Quarkisinproton/IndianBatsModel.git
!pip install librosa pyyaml pandas matplotlib scikit-learn

In [None]:
# 2. Import Modules
import sys
import os
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.metrics import classification_report, confusion_matrix

# Make the cloned repository importable in two ways: with the repo root on
# sys.path the modules resolve as `src.<package>`, and with the `src` folder
# on sys.path they resolve as `<package>` directly (used by the fallback).
REPO_DIR = '/kaggle/working/IndianBatsModel'
SRC_DIR = os.path.join(REPO_DIR, 'src')
sys.path.extend([entry for entry in (REPO_DIR, SRC_DIR) if entry not in sys.path])

# Project modules: try the package-qualified names first and, if that raises
# ImportError, report it and retry with the names relative to `src`.
try:
    from src.data_prep.generate_annotations import generate_annotations
    from src.data_prep.wombat_to_spectrograms import process_all as generate_spectrograms
    from src.data_prep.extract_end_frequency import process_all_and_write_csv as extract_features
    from src.datasets.spectrogram_with_features_dataset import SpectrogramWithFeaturesDataset
    from src.models.cnn_with_features import CNNWithFeatures
    print("Imports successful!")
except ImportError as e:
    print(f"Import Error: {e}")
    # Second attempt: same symbols, imported relative to the `src` directory.
    from data_prep.generate_annotations import generate_annotations
    from data_prep.wombat_to_spectrograms import process_all as generate_spectrograms
    from data_prep.extract_end_frequency import process_all_and_write_csv as extract_features
    from datasets.spectrogram_with_features_dataset import SpectrogramWithFeaturesDataset
    from models.cnn_with_features import CNNWithFeatures
    print("Imports successful (fallback)!")

In [None]:
# 3. Configuration
WORK_DIR = '/kaggle/working'

# --- INPUTS ---
# Path to your trained model file (Upload this to Kaggle Datasets if needed)
# If you just ran training, it might be at: '/kaggle/working/models/bat_fused_best.pth'
MODEL_PATH = '/kaggle/working/models/bat_fused_best.pth'

# Path to TEST audio folders
# (You can use the same folders as training to verify, or new folders for testing)
TEST_AUDIO_DIRS = [
    '/kaggle/input/pip-ceylonicusbat-species',
    '/kaggle/input/pip-tenuisbat-species'
]

# --- OUTPUTS ---
TEST_JSON_DIR = os.path.join(WORK_DIR, 'test_data/annotations')
TEST_SPECT_DIR = os.path.join(WORK_DIR, 'test_data/spectrograms')
TEST_FEATURES_DIR = os.path.join(WORK_DIR, 'test_data/features')
TEST_FEATURES_CSV = os.path.join(TEST_FEATURES_DIR, 'test_features.csv')

# Ensure every output directory exists (annotations, spectrograms, features),
# not just the features folder, so no later step depends on another one
# creating its target directory.
for _out_dir in (TEST_JSON_DIR, TEST_SPECT_DIR, TEST_FEATURES_DIR):
    Path(_out_dir).mkdir(parents=True, exist_ok=True)

# Surface misconfigured inputs here instead of deep inside data preparation.
_missing_audio_dirs = [d for d in TEST_AUDIO_DIRS if not os.path.isdir(d)]
if _missing_audio_dirs:
    print(f"WARNING: test audio folders not found: {_missing_audio_dirs}")

print(f"Model Path: {MODEL_PATH}")
print(f"Test Data Output: {TEST_SPECT_DIR}")

In [None]:
# 4. Prepare Test Data
# The raw test audio is run through three stages in order: annotations,
# then spectrograms, then numeric features (mirroring the training workflow).

# Keyword arguments common to the spectrogram and feature-extraction stages.
shared_stage_kwargs = dict(
    raw_audio_dirs=TEST_AUDIO_DIRS,
    json_dir=TEST_JSON_DIR,
    species_key='label',
)

print("--- Step 1: Generating Annotations ---")
generate_annotations(raw_audio_dirs=TEST_AUDIO_DIRS, output_dir=TEST_JSON_DIR, label_strategy='folder')

print("\n--- Step 2: Generating Spectrograms ---")
generate_spectrograms(out_dir=TEST_SPECT_DIR, **shared_stage_kwargs)

print("\n--- Step 3: Extracting Features ---")
extract_features(out_csv=TEST_FEATURES_CSV, **shared_stage_kwargs)

print("\nTest Data Preparation Complete.")

In [None]:
# 5. Load Test Dataset
# Build the dataset from the spectrogram folder and the features CSV.
# On any failure, fall back to an empty list so that the
# `len(test_dataset) > 0` guards in the following cells skip their work.
try:
    test_dataset = SpectrogramWithFeaturesDataset(root_dir=TEST_SPECT_DIR, features_csv=TEST_FEATURES_CSV)
    print(f"Loaded Test Dataset: {len(test_dataset)} samples")
    print(f"Classes: {test_dataset.class_to_idx}")
except Exception as e:
    test_dataset = []
    print(f"Error loading dataset: {e}")
    print("Did spectrogram generation fail?")

In [None]:
# 6. Load Trained Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# The following cells decide whether to run by testing `'model' in locals()`.
# `model` must therefore exist ONLY when trained weights were really loaded;
# otherwise a randomly initialised network would be evaluated as if it were
# the trained one. Start by dropping any model left over from an earlier run.
globals().pop('model', None)

if len(test_dataset) > 0:
    # Infer the numeric-feature width from the first sample of the dataset.
    sample_img, sample_feat, _ = test_dataset[0]
    NUM_CLASSES = 3  # Adjust if your model was trained with different classes
    FEAT_DIM = sample_feat.shape[0]

    print(f"Initializing model with num_classes={NUM_CLASSES}, feat_dim={FEAT_DIM}")

    # Built under a temporary name; promoted to `model` only after the
    # weights have loaded without error.
    _candidate_model = CNNWithFeatures(num_classes=NUM_CLASSES, numeric_feat_dim=FEAT_DIM, pretrained=False)

    if os.path.exists(MODEL_PATH):
        try:
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from a source you trust.
            state_dict = torch.load(MODEL_PATH, map_location=device)
            _candidate_model.load_state_dict(state_dict)
            _candidate_model.to(device)
            _candidate_model.eval()
        except Exception as e:
            print(f"Error loading model weights: {e}")
            print("Model is NOT available; evaluation and inference will be skipped.")
        else:
            model = _candidate_model
            print("Model loaded successfully!")
    else:
        print(f"CRITICAL: Model file not found at {MODEL_PATH}")
        print("Please upload your trained model or check the path.")
else:
    print("Cannot load model: Dataset is empty.")

In [None]:
# 7. Evaluate Accuracy
# Runs the model over the whole test set and prints per-class metrics.
if len(test_dataset) > 0 and 'model' in locals():
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    all_preds = []
    all_labels = []

    print("Running evaluation...")
    # Inference mode is asserted here too, matching the demo cell, so this
    # cell does not rely on how the model was left by earlier cells.
    model.eval()
    with torch.no_grad():
        for images, features, labels in test_loader:
            images, features = images.to(device), features.to(device)
            outputs = model(images, features)
            preds = outputs.argmax(dim=1)

            # Plain Python ints keep the label sets below simple to build.
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    # Calculate Metrics
    # NOTE(review): assumes class_to_idx maps class name -> integer index and
    # that these indices match the ones used at training time; confirm.
    eval_idx_to_class = {idx: name for name, idx in test_dataset.class_to_idx.items()}

    # The model may have more outputs than the test set has classes
    # (NUM_CLASSES is set by hand), so a prediction can be an index that
    # class_to_idx does not contain. Passing `target_names` alone then raises
    # ValueError; list every index explicitly and give each one a name.
    report_labels = sorted(set(eval_idx_to_class) | set(all_labels) | set(all_preds))
    class_names = [
        eval_idx_to_class.get(idx, f"class_{idx} (not in test set)")
        for idx in report_labels
    ]

    print("\n" + "="*40)
    print("CLASSIFICATION REPORT")
    print("="*40)
    print(classification_report(
        all_labels,
        all_preds,
        labels=report_labels,
        target_names=class_names,
        zero_division=0,  # classes with no samples/predictions report 0 instead of warning
    ))

    print("\nCONFUSION MATRIX:")
    # Same label order as the report so rows/columns line up with class_names.
    print(confusion_matrix(all_labels, all_preds, labels=report_labels))
else:
    print("Skipping evaluation (missing model or data).")

In [None]:
# 8. Make a Guess (Inference on Random Samples)
# Shows up to three randomly chosen test spectrograms with the true class,
# the predicted class and the softmax confidence of the prediction.
import matplotlib.pyplot as plt
import random

if len(test_dataset) > 0 and 'model' in locals():
    # Pick up to 3 random samples (fewer if the dataset is smaller).
    indices = random.sample(range(len(test_dataset)), min(3, len(test_dataset)))

    model.eval()
    fig, axes = plt.subplots(1, len(indices), figsize=(15, 5))
    # With a single subplot matplotlib returns a bare Axes, not an array.
    if len(indices) == 1: axes = [axes]

    idx_to_class = {v: k for k, v in test_dataset.class_to_idx.items()}

    for i, idx in enumerate(indices):
        img, feat, label = test_dataset[idx]

        # Predict
        with torch.no_grad():
            img_batch = img.unsqueeze(0).to(device)
            feat_batch = feat.unsqueeze(0).to(device)
            output = model(img_batch, feat_batch)
            probs = torch.nn.functional.softmax(output, dim=1)
            conf, pred_idx = torch.max(probs, 1)

        pred_id = pred_idx.item()
        # int() accepts both a plain Python int and a one-element tensor,
        # whichever the dataset returns for a single sample.
        true_id = int(label)
        # The model can output an index that the test set has no folder for;
        # fall back to a generic name instead of raising KeyError.
        pred_class = idx_to_class.get(pred_id, f"class_{pred_id} (not in test set)")
        true_class = idx_to_class.get(true_id, f"class_{true_id}")
        confidence = conf.item() * 100

        # Plot
        # Un-normalize for display.
        # NOTE(review): assumes the dataset normalises images with the
        # ImageNet mean/std and returns (C, H, W) tensors; confirm.
        img_disp = img.detach().cpu().permute(1, 2, 0).numpy()
        img_disp = img_disp * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_disp = np.clip(img_disp, 0, 1)

        axes[i].imshow(img_disp)
        axes[i].set_title(f"True: {true_class}\nPred: {pred_class}\nConf: {confidence:.1f}%")
        axes[i].axis('off')

    plt.show()
else:
    print("Skipping inference demo.")