# Plant Ultrasonic Pops Classifier
Goal: detect, label, classify, and visualize ultrasonic pops emitted by seedlings.
This notebook expects WAV files recorded with the Pi GUI at a sample rate of 192000 or 384000 Hz.

In [None]:
# Imports. Install missing packages as needed.
# pip installs (if needed):
# !pip install numpy scipy soundfile matplotlib scikit-learn librosa joblib ipywidgets umap-learn
import os, sys, glob, json, math
from pathlib import Path
import numpy as np
import soundfile as sf
import scipy.signal as sig
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
import joblib

# Optional
try:
    import umap
except Exception:
    umap = None

print("OK: imports loaded")

In [None]:
# Configuration
# Directory whose top level is searched for "*.wav" recordings.
DATA_DIR = Path("/home/griermarkov/nvme0n1/data/rec").resolve()
# CSV of per-event features and labels: written by batch_detect_and_featurize,
# read by load_features and predict_csv. The relative path is resolved
# against the working directory at the moment this cell runs.
FEATURES_CSV = Path("../models/features_labels.csv").resolve()
# File the trained pipeline is saved to by train_svm and loaded from by predict_csv.
MODEL_PATH = Path("../models/svm_model.joblib").resolve()
SAMPLE_RATE_HINT = 384000  # fallback sample rate, used when a file reports a sample rate <= 0

# Detection parameters
# STFT segment length and FFT size (samples).
NFFT = 2048
# Step between consecutive STFT frames (samples); the overlap is NFFT - HOP.
HOP = 512
MIN_FREQ = 20000   # Hz, lower edge of the band searched for events
MAX_FREQ = 100000  # Hz, upper edge of the band searched for events
# Threshold in dB relative to the spectrogram maximum (stft_db normalizes to 0 dB).
MAG_THRESH_DB = -40.0
# Events whose duration in milliseconds falls outside [MIN_EVENT_MS, MAX_EVENT_MS] are discarded.
MIN_EVENT_MS = 1.0
MAX_EVENT_MS = 60.0

print("DATA_DIR =", DATA_DIR)

In [None]:
def load_wav(path: Path):
    """Read an audio file and return ``(samples, sample_rate)``.

    The file is read with ``soundfile``. When the data has more than one
    dimension only the first channel (column 0) is kept. Samples are
    returned as ``float32`` and the sample rate as a plain ``int``.
    """
    samples, rate = sf.read(str(path), always_2d=False)
    # Multi-channel data arrives as (frames, channels); keep channel 0 only.
    mono = samples if samples.ndim <= 1 else samples[:, 0]
    return mono.astype(np.float32), int(rate)

def stft_db(x, sr):
    """Return ``(f, t, S_db)``: the STFT of ``x`` as a dB spectrogram.

    The transform uses segments of ``NFFT`` samples advanced by ``HOP``
    samples, with no boundary extension and no zero padding of the tail.
    Magnitudes are normalized by their own maximum, so the loudest bin of
    the returned spectrogram is 0 dB and every other bin is negative.
    """
    overlap = NFFT - HOP
    freqs, times, spectrum = sig.stft(
        x,
        fs=sr,
        nperseg=NFFT,
        noverlap=overlap,
        nfft=NFFT,
        padded=False,
        boundary=None,
    )
    # The small offset keeps log10 finite for bins that are exactly zero.
    magnitude = np.abs(spectrum) + 1e-12
    reference = np.max(magnitude)
    return freqs, times, 20 * np.log10(magnitude / reference)

def detect_events(S_db, f, t, min_freq=MIN_FREQ, max_freq=MAX_FREQ, mag_thresh_db=MAG_THRESH_DB, min_ms=MIN_EVENT_MS, max_ms=MAX_EVENT_MS):
    """Find runs of consecutive STFT frames that exceed a dB threshold in a band.

    A frame is "active" when at least one bin with frequency in
    ``[min_freq, max_freq]`` is strictly above ``mag_thresh_db``. Each
    maximal run of active frames is a candidate event.

    Args:
        S_db: 2D spectrogram in dB, indexed ``[frequency, frame]``.
        f: 1D array of bin frequencies matching the first axis of ``S_db``.
        t: 1D array of frame times matching the second axis of ``S_db``.
            Durations are computed as ``(t[last] - t[first]) * 1000`` and
            compared to millisecond limits, so ``t`` is expected in seconds.
        min_freq, max_freq: inclusive band limits, same unit as ``f``.
        mag_thresh_db: threshold applied to ``S_db``.
        min_ms, max_ms: inclusive duration limits in milliseconds.

    Returns:
        A list of ``(start_idx, end_idx)`` tuples of Python ints in STFT
        frame indices. ``end_idx`` is exclusive for every event, including
        one still active at the final frame, so ``S_db[:, start_idx:end_idx]``
        always covers the whole event.

    Note:
        Duration is measured between the times of the first and last active
        frame, so a single-frame run has a duration of 0 ms and is rejected
        whenever ``min_ms > 0``.
    """
    band = (f >= min_freq) & (f <= max_freq)
    # Collapse frequency to a 1D time mask: any in-band bin above threshold.
    time_mask = (S_db[band, :] > mag_thresh_db).any(axis=0)

    # Pad with False on both sides so every run, including one that touches
    # the first or last frame, has both a rising and a falling edge.
    padded = np.concatenate(([False], time_mask, [False])).astype(np.int8)
    edges = np.diff(padded)
    starts = np.flatnonzero(edges == 1)
    ends = np.flatnonzero(edges == -1)  # exclusive end of each run

    events = []
    for start_i, end_i in zip(starts, ends):
        dt = (t[end_i - 1] - t[start_i]) * 1000.0
        if min_ms <= dt <= max_ms:
            events.append((int(start_i), int(end_i)))
    return events  # list of (start_idx, end_idx) in STFT frame index, end exclusive

In [None]:
def extract_features(x, sr, f, t, S_db, evt):
    """Compute scalar features for one detected event.

    Args:
        x: 1D array of audio samples.
        sr: sample rate used to convert a sample count into a duration.
        f: 1D array of bin frequencies matching the first axis of ``S_db``.
        t: 1D array of frame times; only its length is used here.
        S_db: 2D spectrogram in dB, indexed ``[frequency, frame]``.
        evt: ``(start_idx, end_idx)`` in STFT frame indices, end exclusive.

    Returns:
        A dict of plain Python floats with keys ``dur_ms``, ``peak_db``,
        ``peak_freq_hz``, ``mean_db``, ``p95_db`` and ``energy``. The
        spectral features are taken from the bins whose frequency lies in
        ``[MIN_FREQ, MAX_FREQ]`` and whose frame lies in the event.

    Raises:
        ValueError: if the event covers no frames or no bin falls in the band.
    """
    si, ei = evt
    # Time crop: frame indices are mapped to sample indices proportionally
    # (len(x) / len(t) samples per frame), which approximates the hop size.
    samples_per_frame = len(x) / len(t)
    t0 = int(si * samples_per_frame)
    t1 = int(ei * samples_per_frame)
    xseg = x[t0:t1]
    dur_s = max(1e-6, len(xseg) / sr)

    # Spectral features in band
    band = (f >= MIN_FREQ) & (f <= MAX_FREQ)
    Sseg = S_db[band, si:ei]
    if Sseg.size == 0:
        # Fail with a readable message instead of NumPy's empty-argmax error.
        raise ValueError(
            f"empty spectrogram segment for event {evt!r}: "
            "no frames in the event or no frequency bins in the band"
        )

    # Peak dB and frequency: locate the maximum once and reuse its position.
    peak_row, peak_col = np.unravel_index(np.argmax(Sseg), Sseg.shape)
    peak_db = float(Sseg[peak_row, peak_col])
    peak_freq = float(f[band][peak_row])

    # Simple stats
    mean_db = float(np.mean(Sseg))
    p95_db = float(np.percentile(Sseg, 95))
    # Energy proxy: sum of the linear magnitudes recovered from the dB values.
    energy = float(np.sum(10 ** (Sseg / 20.0)))
    return {
        "dur_ms": dur_s * 1000.0,
        "peak_db": peak_db,
        "peak_freq_hz": peak_freq,
        "mean_db": mean_db,
        "p95_db": p95_db,
        "energy": energy,
    }

In [None]:
import csv

def batch_detect_and_featurize(wav_paths, limit=None, label_default="unlabeled"):
    """Detect events in each WAV file, featurize them, and write FEATURES_CSV.

    Args:
        wav_paths: iterable of paths to audio files.
        limit: maximum number of files to process; ``None`` means no limit
            and ``0`` means process no files.
        label_default: value stored in the ``label`` column of every row.

    Returns:
        The list of feature dicts (one per detected event), each extended
        with ``file`` and ``label`` keys.

    Side effects:
        Overwrites ``FEATURES_CSV`` (creating its parent directory if needed)
        and prints one line per failed file plus a final summary. A file that
        raises during loading, detection or featurization is reported and
        skipped; rows already collected from it are kept.
    """
    rows = []
    for k, p in enumerate(wav_paths):
        # Compare against None explicitly: a bare truthiness test would
        # treat limit=0 as "no limit" and process every file.
        if limit is not None and k >= limit:
            break
        try:
            x, sr = load_wav(p)
            if sr <= 0:
                sr = SAMPLE_RATE_HINT
            f, t, S_db = stft_db(x, sr)
            evts = detect_events(S_db, f, t)
            for evt in evts:
                feats = extract_features(x, sr, f, t, S_db, evt)
                feats["file"] = str(p)
                feats["label"] = label_default
                rows.append(feats)
        except Exception as e:
            # Best effort: one unreadable or malformed file must not abort the batch.
            print("Error on", p, e)
    # Save
    FEATURES_CSV.parent.mkdir(parents=True, exist_ok=True)
    fieldnames = ["file", "dur_ms", "peak_db", "peak_freq_hz", "mean_db", "p95_db", "energy", "label"]
    # Write UTF-8 explicitly so the file matches pandas.read_csv's default
    # encoding regardless of the platform locale.
    with open(FEATURES_CSV, "w", newline="", encoding="utf-8") as fp:
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    print("Wrote", FEATURES_CSV, "with", len(rows), "rows")
    return rows

# Gather the .wav recordings found directly inside DATA_DIR, in sorted order.
wav_list = list(DATA_DIR.glob("*.wav"))
wav_list.sort()
print("Found", len(wav_list), "wav files")
# Uncomment to run detection on a subset:
# rows = batch_detect_and_featurize(wav_list, limit=50)

In [None]:
# Train a simple SVM on labeled rows in FEATURES_CSV
import pandas as pd

def load_features(csv_path=FEATURES_CSV):
    """Read the features CSV and return it without rows that contain a missing value."""
    return pd.read_csv(csv_path).dropna()

def train_svm(df):
    """Train, evaluate, and save an RBF-kernel SVM on the labeled rows of ``df``.

    Rows whose ``label`` is ``"unlabeled"`` are ignored. The remaining rows
    are split 75/25 (stratified by label, fixed random seed), a
    ``StandardScaler`` + ``SVC`` pipeline is fitted on the training part, and
    a classification report and confusion matrix for the held-out part are
    printed. The fitted pipeline is saved to ``MODEL_PATH``.

    Args:
        df: DataFrame with the six feature columns and a ``label`` column.

    Returns:
        The fitted pipeline.

    Raises:
        ValueError: if there are no labeled rows or fewer than two distinct
            labels. scikit-learn may also raise when the labeled data cannot
            be split or fitted.
    """
    feature_cols = ["dur_ms", "peak_db", "peak_freq_hz", "mean_db", "p95_db", "energy"]
    df_sup = df[df.label != "unlabeled"].copy()
    X = df_sup[feature_cols].values
    y = df_sup["label"].values

    # Fail early with an actionable message instead of a scikit-learn
    # error raised from inside the split or the fit.
    if len(y) == 0:
        raise ValueError("No labeled rows to train on: every row is 'unlabeled'.")
    if len(set(y)) < 2:
        raise ValueError("Need at least two distinct labels to train a classifier.")

    Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.25, random_state=42, stratify=y)
    clf = make_pipeline(StandardScaler(), SVC(kernel="rbf", probability=True))
    clf.fit(Xtrain, ytrain)
    yp = clf.predict(Xtest)
    print(classification_report(ytest, yp))
    print(confusion_matrix(ytest, yp))
    # The model directory may not exist yet if features came from elsewhere.
    MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(clf, MODEL_PATH)
    print("Saved model to", MODEL_PATH)
    return clf

# df = load_features()
# clf = train_svm(df)

In [None]:
# Inference helper
def predict_csv(model_path=MODEL_PATH, csv_path=FEATURES_CSV):
    """Predict labels for the unlabeled rows of a features CSV.

    Args:
        model_path: path of the saved pipeline to load.
        csv_path: path of the features CSV to read.

    Returns:
        The path of the written ``predictions.csv``, placed next to
        ``model_path``, or ``None`` when the CSV has no row labeled
        ``"unlabeled"``. The output holds every input row; ``pred_label`` and
        ``pred_conf`` are filled in only for the unlabeled rows.
    """
    import pandas as pd
    df = pd.read_csv(csv_path)
    mask = df.label == "unlabeled"
    if not mask.any():
        print("No unlabeled rows found.")
        return None
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code.
    # Only load model files from a trusted source.
    clf = joblib.load(model_path)
    feature_cols = ["dur_ms", "peak_db", "peak_freq_hz", "mean_db", "p95_db", "energy"]
    X = df.loc[mask, feature_cols].values
    probs = clf.predict_proba(X)
    preds = clf.classes_[np.argmax(probs, axis=1)]
    df.loc[mask, "pred_label"] = preds
    df.loc[mask, "pred_conf"] = np.max(probs, axis=1)
    # Derive the output location from the model actually used rather than
    # from the global default, so a custom model_path is honoured.
    out_csv = Path(model_path).parent / "predictions.csv"
    df.to_csv(out_csv, index=False)
    print("Wrote", out_csv)
    return out_csv

# out = predict_csv()

In [None]:
# Quick spectrogram utility to visualize an example file
def plot_spectrogram(wav_path, fmin=MIN_FREQ, fmax=MAX_FREQ):
    """Plot the dB spectrogram of one recording, restricted to a frequency band.

    Args:
        wav_path: path of the audio file, as a ``Path`` or a string.
        fmin, fmax: inclusive limits of the displayed band, same unit as the
            frequencies returned by ``stft_db``.

    The figure is titled with the file name and shown immediately.
    """
    x, sr = load_wav(wav_path)
    f, t, S_db = stft_db(x, sr)
    band = (f >= fmin) & (f <= fmax)
    band_freqs = f[band]
    plt.figure(figsize=(10,4))
    # origin="lower" puts the lowest frequency at the bottom of the image.
    plt.imshow(S_db[band,:], aspect="auto", origin="lower",
               extent=[t[0], t[-1], band_freqs[0], band_freqs[-1]])
    plt.xlabel("Time (s)")
    plt.ylabel("Freq (Hz)")
    # Wrap in Path so a plain string path works here as it does in load_wav.
    plt.title(str(Path(wav_path).name))
    plt.colorbar(label="dB rel")
    plt.show()

# Example:
# if len(wav_list) > 0:
#     plot_spectrogram(wav_list[0])