# Download Datasets

In [None]:
!git clone https://github.com/HumBug-Mosquito/HumBugDB.git

In [None]:
!wget https://zenodo.org/record/4904800/files/humbugdb_neurips_2021_1.zip?download=1
!wget https://zenodo.org/record/4904800/files/humbugdb_neurips_2021_2.zip?download=1
!wget https://zenodo.org/record/4904800/files/humbugdb_neurips_2021_3.zip?download=1
!wget https://zenodo.org/record/4904800/files/humbugdb_neurips_2021_4.zip?download=1

In [None]:
!unzip /content/humbugdb_neurips_2021_1.zip?download=1 -d '/content/HumBugDB/data/audio'
!unzip /content/humbugdb_neurips_2021_2.zip?download=1 -d '/content/HumBugDB/data/audio'
!unzip /content/humbugdb_neurips_2021_3.zip?download=1 -d '/content/HumBugDB/data/audio'
!unzip /content/humbugdb_neurips_2021_4.zip?download=1 -d '/content/HumBugDB/data/audio'

# Data Preprocessing

In [None]:
import os
import pandas as pd
import torchaudio
import numpy as np
import random
import torch
from torch.utils.data import Dataset

# Define paths using raw strings
mosquito_csv_path = r'C:\Users\akash\OneDrive\Desktop\ML projects\mosquitos\HumBugDB\data\metadata\neurips_2021_zenodo_0_0_1.csv'
noise_csv_path = r'C:\Users\akash\OneDrive\Desktop\ML projects\mosquitos\Actual noise for mosquito project\meta\esc50_noise.csv'
mosquito_audio_dir = r'C:\Users\akash\OneDrive\Desktop\ML projects\mosquitos\HumBugDB\data\audio'
noise_audio_dir = r'C:\Users\akash\OneDrive\Desktop\ML projects\mosquitos\Actual noise for mosquito project'

# Load mosquito data
mosquito_df = pd.read_csv(mosquito_csv_path)
mosquito_df = mosquito_df[mosquito_df['sound_type'] == 'mosquito']

# Clean species column
mosquito_df = mosquito_df[mosquito_df['species'].apply(lambda x: isinstance(x, str))]
unique_species = mosquito_df['species'].unique()
print("Unique species detected:", unique_species)

# Create mapping of species to indices
species_to_index = {species: idx + 1 for idx, species in enumerate(unique_species)}
species_to_index['No Mosquito'] = 0

# List valid mosquito audio files
mosquito_files = [f"{row['id']}.wav" for _, row in mosquito_df.iterrows() if os.path.exists(os.path.join(mosquito_audio_dir, f"{row['id']}.wav"))]

# Load noise data
noise_df = pd.read_csv(noise_csv_path)
noise_files = [row['filename'] for _, row in noise_df.iterrows() if os.path.exists(os.path.join(noise_audio_dir, row['filename']))]

if not mosquito_files or not noise_files:
    raise FileNotFoundError("No valid audio files found. Check paths or download data.")

# Define the Dataset Class
class MixedMosquitoDataset(Dataset):
    def __init__(self, mosquito_files, noise_files, sample_rate=16000, duration=2):
        self.mosquito_files = mosquito_files
        self.noise_files = noise_files
        self.sample_rate = sample_rate
        self.duration = duration
        self.num_samples = len(mosquito_files) * 2  # Twice the number of mosquito files

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if idx < len(self.mosquito_files):
            # Mixed sample
            mosquito_file = random.choice(self.mosquito_files)
            noise_file = random.choice(self.noise_files)

            # Check if mosquito file exists
            mosquito_file_path = os.path.join(mosquito_audio_dir, mosquito_file)
            if not os.path.exists(mosquito_file_path):
                print(f"File not found: {mosquito_file_path}")
                return None, 0  # Return dummy data

            mosquito_waveform, _ = torchaudio.load(mosquito_file_path)

            # Check if noise file exists
            noise_file_path = os.path.join(noise_audio_dir, noise_file)
            if not os.path.exists(noise_file_path):
                print(f"File not found: {noise_file_path}")
                return None, 0  # Return dummy data

            noise_waveform, _ = torchaudio.load(noise_file_path)

            # Ensure both audio files have the same sample rate
            if mosquito_waveform.shape[0] != noise_waveform.shape[0]:
                noise_waveform = torchaudio.transforms.Resample(orig_freq=noise_waveform.shape[0], new_freq=mosquito_waveform.shape[0])(noise_waveform)

            # Adjust lengths
            target_length = max(mosquito_waveform.shape[1], noise_waveform.shape[1])
            mosquito_waveform = torch.nn.functional.pad(mosquito_waveform, (0, target_length - mosquito_waveform.shape[1]))
            noise_waveform = torch.nn.functional.pad(noise_waveform, (0, target_length - noise_waveform.shape[1]))

            # Mix audio
            mixed_waveform = 0.5 * mosquito_waveform + 0.5 * noise_waveform

            # Save mixed audio
            output_file_name = f'mixed_{idx}.wav'
            output_file_path = os.path.join(audio_dir, output_file_name)
            torchaudio.save(output_file_path, mixed_waveform, self.sample_rate)

            # Prepare metadata entry
            mosquito_id = mosquito_file.split('.')[0]
            noise_id = noise_file.split('.')[0]
            species = mosquito_df.loc[mosquito_df['id'] == mosquito_id, 'species'].values[0]
            return mixed_waveform, {
                'id': idx,
                'file_name': output_file_name,
                'mosquito_id': mosquito_id,
                'mosquito_species': species,
                'noise_id': noise_id,
                'noise_type': noise_df.loc[noise_df['filename'] == noise_file, 'category'].values[0]  # Assuming 'category' column exists
            }
        else:
            # Pure noise sample
            noise_file = random.choice(self.noise_files)
            noise_waveform, _ = torchaudio.load(os.path.join(noise_audio_dir, noise_file))
            noise_waveform = self._adjust_length(noise_waveform)
            mixed_waveform = noise_waveform
            label = 0  # No Mosquito

        # Convert to Mel spectrogram
        mel_spectrogram = torchaudio.transforms.MelSpectrogram(n_mels=128, n_fft=2048, hop_length=512)(mixed_waveform)
        mel_spectrogram = torchaudio.transforms.AmplitudeToDB()(mel_spectrogram)

        return mel_spectrogram.squeeze(0), label

    def _adjust_length(self, waveform):
        # Adjust length to 2 seconds (sample_rate * duration)
        target_length = self.sample_rate * self.duration
        if waveform.shape[1] < target_length:
            # Pad
            padding = target_length - waveform.shape[1]
            waveform = torch.nn.functional.pad(waveform, (0, padding))
        else:
            # Trim
            start = random.randint(0, waveform.shape[1] - target_length)
            waveform = waveform[:, start:start + target_length]
        return waveform

# Model Training

In [None]:
# -*- coding: utf-8 -*-
"""
Mosquito species CNN training (CPU-only) using mixed audios + controlled noise-only negatives.

- Reads your metadata CSV (must include: mix_file, species; optional: sound_type)
- Uses mix_{id}.wav at AUDIO_OUTPUT_DIR and species labels from CSV
- Adds extra negatives from NOISE_DIR as class 0 = "no_mosquito" (controlled fraction)
- Stratified split with robust fallback if some species are too rare
- Class-balanced sampler for training
- Early stopping + checkpointing best val accuracy
- Final export to:
    * PyTorch weights (.pt)
    * TorchScript (.pt)
    * ONNX (.onnx)

Avoids classification_report label mismatch by explicitly passing labels seen in y_true/y_pred.
"""
import os
import random
import warnings
from typing import List, Tuple, Dict

import numpy as np
import pandas as pd
from tqdm import tqdm

import librosa
import soundfile as sf

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset, WeightedRandomSampler

from sklearn.model_selection import StratifiedShuffleSplit, train_test_split
from sklearn.metrics import classification_report, confusion_matrix

# -----------------------
# ==== YOUR PATHS ====
# -----------------------
AUDIO_OUTPUT_DIR = r"C:\Users\akash\OneDrive\Desktop\ML projects\mosquitos\mosquito_noise_mix"
CSV_PATH         = r"C:\Users\akash\OneDrive\Desktop\ML projects\mosquitos\mosquito_noise_mix\metadata\esc50_noise.csv"
NOISE_DIR        = r"C:\Users\akash\OneDrive\Desktop\ML projects\mosquitos\Actual noise for mosquito project"

# -----------------------
# ==== CONFIG ===========
# -----------------------
SEED = 42
SR = 16000
DURATION = 2.0            # seconds (will pad/trim to this)
N_MELS = 128
N_FFT = 1024
HOP_LENGTH = 160
FMIN = 20
FMAX = SR // 2
BATCH_SIZE = 32
EPOCHS = 30
PATIENCE = 5
LR = 3e-4
WEIGHT_DECAY = 1e-4
VAL_SIZE = 0.2
NUM_WORKERS = 0           # Windows-safe
PIN_MEMORY = False        # CPU-only
MAX_NOISE_FRACTION = 0.25 # cap negatives as 25% of total to avoid overfitting

MODEL_DIR = os.path.join(AUDIO_OUTPUT_DIR, "models")
EXPORT_DIR = os.path.join(MODEL_DIR, "exports")
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(EXPORT_DIR, exist_ok=True)

# Force CPU
DEVICE = torch.device("cpu")

# Reproducibility
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# -----------------------
# ===== Utilities =======
# -----------------------
def _safe_exists(fp: str) -> bool:
    try:
        return os.path.exists(fp)
    except Exception:
        return False

def list_audio_files(root: str, exts=(".wav", ".flac", ".mp3")) -> List[str]:
    out = []
    for dirpath, _, filenames in os.walk(root):
        for f in filenames:
            if f.lower().endswith(exts):
                out.append(os.path.join(dirpath, f))
    return out

def load_audio_fixed(path: str, sr: int, duration: float) -> np.ndarray:
    """Load audio, resample, mono, pad/trim to fixed duration."""
    target_len = int(sr * duration)
    try:
        y, file_sr = librosa.load(path, sr=sr, mono=True)
    except Exception:
        # fallback to soundfile then resample with librosa
        y, file_sr = sf.read(path, always_2d=False)
        if y.ndim > 1:
            y = np.mean(y, axis=1)
        if file_sr != sr:
            y = librosa.resample(y, orig_sr=file_sr, target_sr=sr)
    if y.size == 0:
        y = np.zeros(target_len, dtype=np.float32)
    # pad/trim
    if len(y) < target_len:
        pad = target_len - len(y)
        y = np.pad(y, (0, pad), mode="constant")
    elif len(y) > target_len:
        y = y[:target_len]
    return y.astype(np.float32)

def to_logmelspec(y: np.ndarray, sr: int) -> np.ndarray:
    S = librosa.feature.melspectrogram(
        y=y, sr=sr, n_fft=N_FFT, hop_length=HOP_LENGTH,
        n_mels=N_MELS, fmin=FMIN, fmax=FMAX, power=2.0
    )
    S_db = librosa.power_to_db(S, ref=np.max)
    # per-sample standardization
    mu = S_db.mean()
    sigma = S_db.std() + 1e-6
    S_db = (S_db - mu) / sigma
    return S_db.astype(np.float32)  # (n_mels, time)

def spec_augment(spec: torch.Tensor, time_mask_prob=0.3, freq_mask_prob=0.3,
                 time_mask_max=20, freq_mask_max=12) -> torch.Tensor:
    """SpecAugment on (1, n_mels, time)."""
    _, n_mels, T = spec.shape
    out = spec.clone()
    if random.random() < time_mask_prob:
        t = random.randint(1, min(time_mask_max, max(1, T // 8)))
        t0 = random.randint(0, max(0, T - t))
        out[:, :, t0:t0 + t] = 0.0
    if random.random() < freq_mask_prob:
        f = random.randint(1, min(freq_mask_max, max(1, n_mels // 8)))
        f0 = random.randint(0, max(0, n_mels - f))
        out[:, f0:f0 + f, :] = 0.0
    return out

# -----------------------
# ===== Dataset(s) ======
# -----------------------
class MixedMosquitoDataset(Dataset):
    """Samples from mixed audios listed in CSV (mix_file, species)."""
    def __init__(self, df: pd.DataFrame, label2idx: Dict[str, int], augment: bool = True):
        self.df = df.reset_index(drop=True)
        self.label2idx = label2idx
        self.augment = augment

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        path = str(row["mix_file"])
        species = str(row["species"]).strip()
        label = self.label2idx.get(species, None)
        if label is None:
            # safety: treat unknown as "no_mosquito"
            label = self.label2idx["no_mosquito"]

        y = load_audio_fixed(path, SR, DURATION)
        # light waveform jitter
        if self.augment:
            # random time shift up to 10% length
            max_shift = int(0.1 * len(y))
            if max_shift > 0 and random.random() < 0.5:
                s = random.randint(-max_shift, max_shift)
                y = np.roll(y, s)
            # tiny gaussian noise
            if random.random() < 0.3:
                y = y + 0.005 * np.random.randn(len(y)).astype(np.float32)

        spec = to_logmelspec(y, SR)  # (mels, T)
        spec = torch.from_numpy(spec).unsqueeze(0)  # (1, mels, T)
        if self.augment:
            spec = spec_augment(spec)
        return spec, int(label)

class NoiseOnlyDataset(Dataset):
    """Negative-only samples from NOISE_DIR => class 0 (no_mosquito)."""
    def __init__(self, files: List[str], label_no_mosquito: int, augment: bool = True):
        self.files = files
        self.label = label_no_mosquito
        self.augment = augment

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx: int):
        path = self.files[idx]
        y = load_audio_fixed(path, SR, DURATION)
        if self.augment:
            max_shift = int(0.1 * len(y))
            if max_shift > 0 and random.random() < 0.5:
                s = random.randint(-max_shift, max_shift)
                y = np.roll(y, s)
            if random.random() < 0.3:
                y = y + 0.005 * np.random.randn(len(y)).astype(np.float32)
        spec = to_logmelspec(y, SR)
        spec = torch.from_numpy(spec).unsqueeze(0)
        if self.augment:
            spec = spec_augment(spec)
        return spec, self.label

# -----------------------
# ======= Model =========
# -----------------------
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, pool=(2, 2)):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
        self.pool = nn.MaxPool2d(pool)

    def forward(self, x):
        x = self.conv(x)
        x = self.pool(x)
        return x

class TinyCNN(nn.Module):
    def __init__(self, n_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            ConvBlock(1, 32),
            ConvBlock(32, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
        )
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.3),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.head(x)
        return x

# -----------------------
# ==== Train / Eval =====
# -----------------------
def build_label_space(df: pd.DataFrame) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Build label mapping with class 0 = no_mosquito, others from species present."""
    species = df["species"].fillna("").astype(str).str.strip()
    species = sorted([s for s in species.unique() if s != ""])
    label2idx = {"no_mosquito": 0}
    for i, s in enumerate(species, start=1):
        label2idx[s] = i
    idx2label = {v: k for k, v in label2idx.items()}
    return label2idx, idx2label

def make_datasets(df_all: pd.DataFrame) -> Tuple[Dataset, Dataset, Dict[str, int], Dict[int, str]]:
    # Filter mixed rows that actually exist and have species
    mask_exist = df_all["mix_file"].astype(str).apply(_safe_exists)
    df = df_all[mask_exist].copy()
    # Keep only rows with a species label (these define positive classes)
    df = df[df["species"].fillna("").astype(str).str.strip() != ""].copy()
    if len(df) == 0:
        raise RuntimeError("No valid mixed samples found with existing 'mix_file' and non-empty 'species'.")

    label2idx, idx2label = build_label_space(df)

    # Mixed dataset
    mixed_ds = MixedMosquitoDataset(df=df, label2idx=label2idx, augment=True)

    # Noise-only dataset (class 0)
    noise_files = list_audio_files(NOISE_DIR, exts=(".wav", ".flac"))
    random.shuffle(noise_files)
    # Cap negatives so they don't exceed MAX_NOISE_FRACTION of total
    max_noise = int(MAX_NOISE_FRACTION * len(mixed_ds))
    if max_noise < 1 and len(noise_files) > 0:
        max_noise = 1
    noise_files = noise_files[:max_noise]
    noise_ds = NoiseOnlyDataset(noise_files, label_no_mosquito=label2idx["no_mosquito"], augment=True)

    # Combined
    full_ds = ConcatDataset([mixed_ds, noise_ds])

    return full_ds, mixed_ds, label2idx, idx2label

def stratified_split_indices(labels: np.ndarray, val_size: float) -> Tuple[np.ndarray, np.ndarray]:
    """Try stratified; if fails due to rare classes, fall back to random split."""
    try:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=val_size, random_state=SEED)
        train_idx, val_idx = next(sss.split(np.zeros_like(labels), labels))
    except Exception:
        warnings.warn("Stratified split failed due to rare classes. Falling back to random split.", RuntimeWarning)
        train_idx, val_idx = train_test_split(
            np.arange(len(labels)), test_size=val_size, random_state=SEED, shuffle=True
        )
    return train_idx, val_idx

def indices_and_labels_from_concat(ds: ConcatDataset) -> np.ndarray:
    """Extract labels for each item in ConcatDataset without loading audio."""
    labels = []
    for sub in ds.datasets:
        if isinstance(sub, MixedMosquitoDataset):
            labs = sub.df["species"].fillna("").astype(str).str.strip().map(lambda s: sub.label2idx.get(s, 0)).tolist()
            labels.extend(labs)
        elif isinstance(sub, NoiseOnlyDataset):
            labels.extend([sub.label] * len(sub))
        else:
            raise TypeError("Unexpected dataset type in ConcatDataset.")
    return np.array(labels, dtype=np.int64)

def make_loaders(ds_full: ConcatDataset, train_idx: np.ndarray, val_idx: np.ndarray,
                 labels_all: np.ndarray) -> Tuple[DataLoader, DataLoader]:
    # Train weights for class-balanced sampling
    y_train = labels_all[train_idx]
    class_counts = np.bincount(y_train, minlength=y_train.max() + 1)
    class_weights = 1.0 / (class_counts + 1e-6)
    sample_weights = class_weights[y_train]
    # ensure proper dtype for sampler
    sampler = WeightedRandomSampler(weights=torch.tensor(sample_weights, dtype=torch.double),
                                    num_samples=len(train_idx),
                                    replacement=True)

    train_subset = torch.utils.data.Subset(ds_full, train_idx.tolist())
    val_subset   = torch.utils.data.Subset(ds_full, val_idx.tolist())

    train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, sampler=sampler,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    val_loader   = DataLoader(val_subset, batch_size=BATCH_SIZE, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
    return train_loader, val_loader

def evaluate_and_save(model: nn.Module, val_loader: DataLoader,
                      idx2label: Dict[int, str], best_model_path: str, export_dir: str):
    """Evaluate model on val_loader and export in multiple formats."""
    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(DEVICE, non_blocking=False)
            yb = yb.to(DEVICE, non_blocking=False)
            logits = model(xb)
            preds = torch.argmax(logits, dim=1)
            y_true.extend(yb.cpu().numpy().tolist())
            y_pred.extend(preds.cpu().numpy().tolist())

    y_true = np.array(y_true, dtype=np.int64)
    y_pred = np.array(y_pred, dtype=np.int64)

    # Explicitly pass labels actually present (union) to avoid mismatch
    labels_present = sorted(set(y_true.tolist()) | set(y_pred.tolist()))
    target_names = [idx2label[i] for i in labels_present]

    print("\nValidation report (best checkpoint):")
    print(classification_report(y_true, y_pred, labels=labels_present,
                                target_names=target_names, zero_division=0))
    print("Confusion matrix (rows=true, cols=pred):")
    print(confusion_matrix(y_true, y_pred, labels=labels_present))

    # Ensure export dir
    os.makedirs(export_dir, exist_ok=True)

    # Also save a final PyTorch weights copy
    final_pt = os.path.join(export_dir, "mosquito_cnn_final.pt")
    torch.save(model.state_dict(), final_pt)

    # TorchScript
    example_input, _ = next(iter(val_loader))
    example_input = example_input.to("cpu")
    ts = torch.jit.trace(model.cpu(), example_input)
    ts_path = os.path.join(export_dir, "mosquito_cnn_final_torchscript.pt")
    ts.save(ts_path)

    # ONNX
    onnx_path = os.path.join(export_dir, "mosquito_cnn_final.onnx")
    torch.onnx.export(
        model.cpu(),
        example_input,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        opset_version=11
    )
    print(f"\n✅ Models exported to:\n  - {final_pt}\n  - {ts_path}\n  - {onnx_path}")

def train():
    # Load metadata CSV
    if not _safe_exists(CSV_PATH):
        raise FileNotFoundError(f"CSV_PATH not found: {CSV_PATH}")
    df_all = pd.read_csv(CSV_PATH)
    required_cols = {"mix_file", "species"}
    missing = required_cols - set([c.lower() for c in df_all.columns.str.lower()])
    # Try to normalize column names if needed
    df_all.columns = [c.strip() for c in df_all.columns]
    # remap columns case-insensitively
    rename_map = {}
    for col in df_all.columns:
        cl = col.lower()
        if cl == "mix_file":
            rename_map[col] = "mix_file"
        elif cl == "species":
            rename_map[col] = "species"
    if rename_map:
        df_all = df_all.rename(columns=rename_map)
    assert "mix_file" in df_all.columns and "species" in df_all.columns, \
        "CSV must contain 'mix_file' and 'species' columns."

    # Build datasets
    full_ds, mixed_ds, label2idx, idx2label = make_datasets(df_all)

    # Labels for split
    labels_all = indices_and_labels_from_concat(full_ds)

    # Split
    train_idx, val_idx = stratified_split_indices(labels_all, VAL_SIZE)

    # DataLoaders
    train_loader, val_loader = make_loaders(full_ds, train_idx, val_idx, labels_all)

    # Model
    n_classes = len(label2idx)
    print(f"Classes ({n_classes}): { {i: lbl for i, lbl in idx2label.items()} }")
    print(f"Training samples: {len(train_idx)} | Validation: {len(val_idx)}")
    model = TinyCNN(n_classes).to(DEVICE)

    # Loss with class weights (inverse frequency from train set)
    y_train = labels_all[train_idx]
    class_counts = np.bincount(y_train, minlength=n_classes)
    class_w = 1.0 / (class_counts + 1e-6)
    class_w = class_w / class_w.mean()
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_w, dtype=torch.float32, device=DEVICE))

    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # Remove verbose kwarg for compatibility with older PyTorch versions
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="max", factor=0.5, patience=2)

    # Training loop with early stopping
    best_acc = -1.0
    patience = PATIENCE
    best_path = os.path.join(MODEL_DIR, "mosquito_cnn_best.pt")

    print("\nStarting training on CPU...")
    for epoch in range(1, EPOCHS + 1):
        model.train()
        running_loss, correct, total = 0.0, 0, 0

        for xb, yb in tqdm(train_loader, desc=f"[Epoch {epoch:02d}] Train", leave=False):
            xb = xb.to(DEVICE, non_blocking=False)
            yb = yb.to(DEVICE, non_blocking=False)

            optimizer.zero_grad()
            logits = model(xb)
            loss = criterion(logits, yb)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * xb.size(0)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == yb).sum().item()
            total += yb.size(0)

        train_loss = running_loss / max(1, total)
        train_acc = correct / max(1, total)

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for xb, yb in tqdm(val_loader, desc=f"[Epoch {epoch:02d}] Val  ", leave=False):
                xb = xb.to(DEVICE, non_blocking=False)
                yb = yb.to(DEVICE, non_blocking=False)

                logits = model(xb)
                loss = criterion(logits, yb)

                val_loss += loss.item() * xb.size(0)
                preds = torch.argmax(logits, dim=1)
                val_correct += (preds == yb).sum().item()
                val_total += yb.size(0)

        val_loss = val_loss / max(1, val_total)
        val_acc = val_correct / max(1, val_total)

        print(f"[Epoch {epoch:02d}] train_loss={train_loss:.4f} acc={train_acc:.4f} | val_loss={val_loss:.4f} acc={val_acc:.4f}")

        # Scheduler + Early stopping
        scheduler.step(val_acc)
        # manual LR log for compatibility across PyTorch versions
        for param_group in optimizer.param_groups:
            print(f"    LR adjusted to: {param_group['lr']:.6f}")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), best_path)
            print(f"  ✅ Saved best checkpoint to: {best_path}")
            patience = PATIENCE
        else:
            patience -= 1
            if patience <= 0:
                print("Early stopping triggered.")
                break

    # Load best checkpoint for final eval/export
    model.load_state_dict(torch.load(best_path, map_location="cpu"))
    evaluate_and_save(model, val_loader, idx2label, best_model_path=best_path, export_dir=EXPORT_DIR)

if __name__ == "__main__":
    warnings.filterwarnings("ignore", category=UserWarning)
    train()

# Inference

In [2]:
import torch
import torch.nn.functional as F
import torchaudio
import sounddevice as sd
import numpy as np
import time

# ========== Load Model ==========
model_path = r"mosquito_cnn_final_torchscript.pt"
model = torch.jit.load(model_path, map_location="cpu")
model.eval()
print("✅ Model loaded successfully!")

# Species index mapping
species_index_map = {
    0: 'no_mosquito', 1: 'ae aegypti', 2: 'ae albopictus', 3: 'an albimanus',
    4: 'an arabiensis', 5: 'an atroparvus', 6: 'an barbirostris', 7: 'an coluzzii',
    8: 'an coustani', 9: 'an dirus', 10: 'an farauti', 11: 'an freeborni',
    12: 'an funestus', 13: 'an funestus sl', 14: 'an funestus ss', 15: 'an gambiae',
    16: 'an gambiae sl', 17: 'an gambiae ss', 18: 'an harrisoni', 19: 'an leesoni',
    20: 'an maculatus', 21: 'an maculipalpis', 22: 'an merus', 23: 'an minimus',
    24: 'an pharoensis', 25: 'an quadriannulatus', 26: 'an rivulorum', 27: 'an sinensis',
    28: 'an squamosus', 29: 'an stephensi', 30: 'an ziemanni', 31: 'coquillettidia sp',
    32: 'culex pipiens complex', 33: 'culex quinquefasciatus', 34: 'culex tarsalis',
    35: 'culex tigripes', 36: 'ma africanus', 37: 'ma uniformis', 38: 'toxorhynchites brevipalpis'
}

# ========== Audio Recording Settings ==========
SAMPLE_RATE = 16000   # Hz
DURATION = 1          # seconds (record in chunks of 1 second)

# ========== Transform to Spectrogram ==========
transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_mels=64,
    n_fft=1024,
    hop_length=512
)

def predict_from_audio(audio_chunk):
    """Convert raw audio -> mel spectrogram -> model prediction"""
    # Convert to tensor
    audio_tensor = torch.from_numpy(audio_chunk).float()
    
    # If stereo, take mean (convert to mono)
    if audio_tensor.ndim > 1:
        audio_tensor = audio_tensor.mean(dim=1)
    
    # Convert to spectrogram (shape: [n_mels, time])
    spec = transform(audio_tensor)
    
    # Resize/Pad/Crop to match training input size (64x64 assumed)
    spec = torch.nn.functional.interpolate(spec.unsqueeze(0).unsqueeze(0),
                                           size=(64, 64),
                                           mode="bilinear",
                                           align_corners=False)
    
    # Run inference
    with torch.no_grad():
        output = model(spec)
        probs = F.softmax(output, dim=1)
        predicted_class = torch.argmax(probs, dim=1).item()
    
    return predicted_class, probs[0, predicted_class].item()

# ========== Live Loop ==========
print("🎤 Listening... Press Ctrl+C to stop.")
try:
    while True:
        # Record 1 sec of audio
        audio = sd.rec(int(SAMPLE_RATE * DURATION), samplerate=SAMPLE_RATE, channels=1, dtype='float32')
        sd.wait()

        audio = np.squeeze(audio)  # remove channel dim

        # Predict species
        pred_class, confidence = predict_from_audio(audio)
        species_name = species_index_map[pred_class]

        print(f"🔊 Detected: {species_name}  (confidence: {confidence:.2f})")
        time.sleep(0.01)

except KeyboardInterrupt:
    print("\n🛑 Stopped recording.")

✅ Model loaded successfully!
🎤 Listening... Press Ctrl+C to stop.
🔊 Detected: an sinensis  (confidence: 0.14)
🔊 Detected: an sinensis  (confidence: 0.14)
🔊 Detected: an sinensis  (confidence: 0.14)

🛑 Stopped recording.
