In [None]:

"""
KNN genre detector optimized for environments with librosa installed.
- Audio IO/processing via librosa only (no fallbacks).
- Feature: average magnitude spectrum from STFT (Hann), then L2 normalize.
- Distance: Euclidean; fast k-NN using argpartition.
- Parallel feature extraction with joblib.
"""

from __future__ import annotations
import os
import json
from dataclasses import dataclass, asdict
from typing import List, Tuple, Optional, Dict, Any

import numpy as np
import pandas as pd
from joblib import dump, load, Parallel, delayed
import librosa


@dataclass
class SpectralConfig:
    """Settings shared by audio loading, feature extraction and k-NN search.

    The same configuration must be used for training and inference so that
    feature vectors are comparable; the classifier stores it (via ``asdict``)
    in its saved metadata.
    """

    sr: int = 22050             # sample rate handed to librosa.load
    duration: float = 5.0       # seconds loaded per file; every clip is padded/trimmed to duration * sr samples
    n_fft: int = 4096           # n_fft handed to librosa.stft
    hop_length: int = 2048      # hop_length handed to librosa.stft
    k: int = 5                  # default number of neighbours when the caller does not pass k
    res_type: str = "soxr_vhq"  # resampler name handed to librosa.load
    l2_normalize: bool = True   # divide the averaged spectrum by its L2 norm
    center_stft: bool = False   # center=False avoids extra padding and is a bit faster
    n_jobs: int = -1            # joblib n_jobs used for parallel feature extraction
    eps: float = 1e-12          # added to denominators to avoid division by zero


def _detect_columns(df: pd.DataFrame) -> Tuple[str, str]:
    cols = {c.lower(): c for c in df.columns}
    path_candidates = ["path", "filepath", "file", "audio", "filename", "wav", "mp3", "ogg", "uri", "link"]
    label_candidates = ["label", "genre", "class", "target", "y", "tag"]
    path_col = next((cols[c] for c in path_candidates if c in cols), None)
    label_col = next((cols[c] for c in label_candidates if c in cols), None)
    if path_col is None or label_col is None:
        raise ValueError(
            "No pude detectar columnas de ruta y etiqueta. "
            "Incluye ('path'/'filepath'/...) y ('label'/'genre'/...)."
        )
    return path_col, label_col


def _load_audio(path: str, cfg: SpectralConfig) -> np.ndarray:
    """Load *path* as a mono, fixed-length, peak-normalised float32 waveform.

    The file is decoded and resampled by ``librosa.load`` at ``cfg.sr`` using
    ``cfg.res_type``, reading at most ``cfg.duration`` seconds. The result is
    then padded or trimmed to exactly ``round(cfg.duration * cfg.sr)`` samples,
    its mean (DC offset) is removed and it is scaled by its peak absolute value.

    Raises:
        ValueError: if ``cfg.duration * cfg.sr`` does not yield at least one sample.
    """
    target_len = int(round(cfg.duration * cfg.sr))
    if target_len <= 0:
        raise ValueError(f"duration * sr debe ser positivo; se obtuvo {target_len} muestras.")

    y, _ = librosa.load(path, sr=cfg.sr, mono=True, res_type=cfg.res_type, duration=cfg.duration)
    if len(y) != target_len:
        # Short clips are padded, over-long ones trimmed. `size` must be passed
        # by keyword: it is keyword-only in librosa >= 0.10.
        y = librosa.util.fix_length(y, size=target_len)

    # Remove DC and normalise amplitude; eps keeps an all-zero clip from dividing by zero.
    y = y - float(np.mean(y))
    maxabs = float(np.max(np.abs(y)) + cfg.eps)
    y = (y / maxabs).astype(np.float32, copy=False)
    return y


def _spectral_feature(y: np.ndarray, cfg: SpectralConfig) -> np.ndarray:
    """Reduce waveform *y* to its time-averaged STFT magnitude spectrum.

    A Hann-windowed STFT is computed with the sizes from *cfg*, the magnitude
    of every bin is averaged over all frames, and (when ``cfg.l2_normalize``
    is set) the resulting vector is scaled to unit L2 norm. The returned
    vector is float32 with one entry per frequency bin.
    """
    stft_matrix = librosa.stft(
        y,
        n_fft=cfg.n_fft,
        hop_length=cfg.hop_length,
        window="hann",
        center=cfg.center_stft,
    )
    # Rows are frequency bins, columns are frames.
    magnitude = np.abs(stft_matrix).astype(np.float32, copy=False)
    spectrum = magnitude.mean(axis=1)

    if not cfg.l2_normalize:
        return spectrum.astype(np.float32, copy=False)

    # eps keeps an all-zero spectrum from dividing by zero.
    spectrum /= (np.linalg.norm(spectrum) + cfg.eps)
    return spectrum.astype(np.float32, copy=False)


def compute_feature_for_file(path: str, cfg: SpectralConfig) -> np.ndarray:
    """Return the spectral feature vector of the audio file at *path*.

    Chains ``_load_audio`` and ``_spectral_feature`` using the same *cfg*.
    """
    return _spectral_feature(_load_audio(path, cfg), cfg)


def _euclidean(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.linalg.norm(a - b))


class KNNSpectralClassifier:
    """Brute-force k-nearest-neighbour classifier over spectral feature vectors.

    State is populated by ``fit_from_dataframe`` or ``load``: ``X`` holds one
    float32 feature row per training file, while ``labels`` and ``paths`` are
    parallel lists aligned with the rows of ``X``. ``meta`` carries the
    configuration the features were extracted with.
    """

    def __init__(self, cfg: Optional[SpectralConfig] = None):
        """Create an empty classifier using *cfg* (or the default configuration)."""
        self.cfg = cfg or SpectralConfig()
        self.X: Optional[np.ndarray] = None  # (N, D), float32
        self.labels: List[str] = []
        self.paths: List[str] = []
        self.meta: Dict[str, Any] = {}

    def fit_from_dataframe(
        self, df: pd.DataFrame, path_col: Optional[str] = None, label_col: Optional[str] = None
    ) -> "KNNSpectralClassifier":
        """Extract features for every existing audio file listed in *df*.

        Args:
            df: table with one row per audio file.
            path_col: column holding file paths; auto-detected when omitted.
            label_col: column holding labels; auto-detected when omitted.

        Returns:
            ``self``, to allow chaining.

        Raises:
            ValueError: if a column has to be auto-detected and detection fails.
            RuntimeError: if none of the listed paths exists on disk.
        """
        if path_col and label_col:
            pcol, lcol = path_col, label_col
        else:
            # Auto-detect only what the caller left out; an explicitly
            # supplied column always takes precedence over the guess.
            auto_path, auto_label = _detect_columns(df)
            pcol = path_col or auto_path
            lcol = label_col or auto_label

        # Column-wise iteration avoids the per-row dtype coercion of iterrows().
        candidates = [(str(p), str(lab)) for p, lab in zip(df[pcol].tolist(), df[lcol].tolist())]
        # Rows whose file does not exist are skipped rather than failing the fit.
        pairs = [(p, lab) for (p, lab) in candidates if os.path.exists(p)]

        if not pairs:
            raise RuntimeError("No se encontraron rutas de audio existentes en el CSV.")

        # Parallel feature extraction
        feats = Parallel(n_jobs=self.cfg.n_jobs, prefer="threads")(
            delayed(compute_feature_for_file)(p, self.cfg) for (p, _) in pairs
        )

        self.X = np.ascontiguousarray(np.stack(feats).astype(np.float32, copy=False))
        self.labels = [lab for (_, lab) in pairs]
        self.paths = [p for (p, _) in pairs]
        # n_skipped records how many rows were dropped because the file was missing.
        self.meta = {"config": asdict(self.cfg), "n_skipped": len(candidates) - len(pairs)}
        return self

    def kneighbors(self, x: np.ndarray, k: Optional[int] = None):
        """Return the *k* training items closest to feature vector *x*.

        Args:
            x: query feature vector with as many entries as ``X`` has columns.
            k: number of neighbours; ``None`` or ``0`` falls back to ``cfg.k``.
                Values larger than the training set return every item.

        Returns:
            A list of dicts with keys ``idx``, ``path``, ``label`` and
            ``distance``, ordered from nearest to farthest.

        Raises:
            RuntimeError: if the model has not been fitted or loaded.
            ValueError: if the effective *k* is not positive or *x* has the
                wrong number of entries.
        """
        if self.X is None:
            raise RuntimeError("Modelo vacío. Llama fit_from_dataframe primero o load().")
        k = int(k or self.cfg.k)
        if k < 1:
            # A negative k would otherwise produce a bogus partition and slice.
            raise ValueError(f"k debe ser un entero positivo; se recibió {k}.")

        query = np.asarray(x, dtype=self.X.dtype).reshape(-1)
        if query.shape[0] != self.X.shape[1]:
            raise ValueError(
                f"Dimensión de la consulta ({query.shape[0]}) distinta de la del modelo ({self.X.shape[1]})."
            )

        # compute all distances (vectorized); float32 for speed
        diff = self.X - query[None, :]
        dists = np.sqrt(np.sum(diff * diff, axis=1, dtype=np.float32))
        # fast top-k selection: partition first, then sort only the k winners
        if k < len(dists):
            idxs = np.argpartition(dists, k)[:k]
            idxs = idxs[np.argsort(dists[idxs])]
        else:
            idxs = np.argsort(dists)
        return [
            {"idx": int(i), "path": self.paths[int(i)], "label": self.labels[int(i)], "distance": float(dists[int(i)])}
            for i in idxs[:k]
        ]

    @staticmethod
    def _majority_vote(neighbors: List[Dict[str, Any]]) -> Tuple[str, Dict[str, int]]:
        """Pick the most frequent label among *neighbors* (ordered nearest first).

        Ties are broken in favour of the tied label whose closest
        representative is nearest to the query. Returns ``(label, counts)``.
        """
        counts: Dict[str, int] = {}
        for nb in neighbors:
            counts[nb["label"]] = counts.get(nb["label"], 0) + 1
        max_count = max(counts.values())
        tied = {lbl for lbl, c in counts.items() if c == max_count}
        # The overall nearest neighbour may carry a label that lost the vote,
        # so only neighbours whose label is among the tied winners are eligible.
        winner = next(nb["label"] for nb in neighbors if nb["label"] in tied)
        return winner, counts

    def predict(self, path: str, k: Optional[int] = None):
        """Classify the audio file at *path* by majority vote of its neighbours.

        Returns:
            ``(predicted_label, neighbors, probs)`` where ``neighbors`` is the
            output of ``kneighbors`` and ``probs`` maps every training label
            to its share of the neighbour votes.
        """
        x = compute_feature_for_file(path, self.cfg)
        neighbors = self.kneighbors(x, k=k)
        pred, counts = self._majority_vote(neighbors)
        probs = {lbl: counts.get(lbl, 0) / len(neighbors) for lbl in set(self.labels)}
        return pred, neighbors, probs

    def save(self, path: str) -> None:
        """Serialise features, labels, paths and metadata to *path* with joblib.

        Raises:
            RuntimeError: if the model has not been fitted or loaded.
        """
        if self.X is None:
            raise RuntimeError("Nada que guardar: X es None.")
        dump({"X": self.X, "labels": self.labels, "paths": self.paths, "meta": self.meta}, path)

    def load(self, path: str) -> "KNNSpectralClassifier":
        """Restore a model written by ``save`` and return ``self``.

        The saved configuration, when present, replaces ``self.cfg`` so that
        query features are extracted the same way as the stored ones. Unknown
        configuration keys are ignored with a warning; if the saved
        configuration cannot be used at all, the current one is kept and a
        warning is emitted.
        """
        import warnings
        from dataclasses import fields

        # NOTE(security): joblib.load unpickles the file, which can execute
        # arbitrary code. Only load model files from trusted sources.
        payload = load(path)
        self.X = payload["X"]
        self.labels = list(payload["labels"])
        self.paths = list(payload["paths"])
        self.meta = dict(payload.get("meta", {}))
        if "config" in self.meta:
            saved = self.meta["config"]
            try:
                known = {f.name for f in fields(SpectralConfig)}
                unknown = sorted(key for key in saved if key not in known)
                restored = SpectralConfig(**{key: val for key, val in saved.items() if key in known})
            except (TypeError, AttributeError) as err:
                # Best effort: keep the current config, but never silently.
                warnings.warn(
                    f"No se pudo restaurar la configuración guardada ({err}); "
                    "se conserva la configuración actual.",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                if unknown:
                    warnings.warn(
                        f"Se ignoran claves de configuración desconocidas: {unknown}",
                        RuntimeWarning,
                        stacklevel=2,
                    )
                self.cfg = restored
        return self


def auto_fit_and_save(
    csv_path: str,
    out_model_path: str,
    path_col: Optional[str] = None,
    label_col: Optional[str] = None,
    cfg_overrides: Optional[Dict[str, Any]] = None,
) -> str:
    """Train a classifier from the CSV at *csv_path* and write it to disk.

    Args:
        csv_path: CSV file listing audio paths and labels.
        out_model_path: destination of the serialised model.
        path_col: name of the path column, or ``None`` to auto-detect.
        label_col: name of the label column, or ``None`` to auto-detect.
        cfg_overrides: keyword overrides applied on top of the default
            ``SpectralConfig`` values.

    Returns:
        *out_model_path*, unchanged.
    """
    overrides = cfg_overrides or {}
    config = SpectralConfig(**overrides)
    frame = pd.read_csv(csv_path)
    model = KNNSpectralClassifier(config).fit_from_dataframe(
        frame, path_col=path_col, label_col=label_col
    )
    model.save(out_model_path)
    return out_model_path


def classify_with_model(model_path: str, audio_path: str, k: Optional[int] = None) -> Dict[str, Any]:
    """Load the model at *model_path* and classify the file at *audio_path*.

    Returns a dict with the predicted label, the neighbour list, the vote
    share per label, and the configuration stored in the model metadata
    (an empty dict when none was stored).
    """
    model = KNNSpectralClassifier()
    model.load(model_path)
    label, nearest, vote_share = model.predict(audio_path, k=k)

    report: Dict[str, Any] = {}
    report["predicted_label"] = label
    report["neighbors"] = nearest
    report["vote_probs"] = vote_share
    report["config"] = model.meta.get("config", {})
    return report


if __name__ == "__main__":
    # Command-line entry point: train a model from a CSV (--dataset), classify
    # a single audio file with a saved model (--classify), or both in one run.
    import argparse

    parser = argparse.ArgumentParser(description="KNN detector de género musical por espectro (librosa).")
    parser.add_argument("--dataset", type=str, help="Ruta al CSV con columnas de rutas y etiquetas.")
    parser.add_argument("--path-col", type=str, default=None, help="Nombre de la columna de ruta.")
    parser.add_argument("--label-col", type=str, default=None, help="Nombre de la columna de etiqueta/género.")
    parser.add_argument("--out-model", type=str, default="knn_model.joblib", help="Ruta de salida del modelo.")
    parser.add_argument("--sr", type=int, default=22050)
    parser.add_argument("--duration", type=float, default=5.0)
    parser.add_argument("--n-fft", type=int, default=4096)
    parser.add_argument("--hop-length", type=int, default=2048)
    parser.add_argument("--k", type=int, default=5)
    parser.add_argument("--center-stft", action="store_true", help="Usa center=True en STFT (por defecto False).")
    parser.add_argument("--n-jobs", type=int, default=-1, help="Paralelismo en extracción de features.")
    parser.add_argument("--classify", type=str, default=None, help="Ruta a un audio de 5s para clasificar.")
    parser.add_argument("--k-infer", type=int, default=None, help="k a usar en inferencia.")

    args = parser.parse_args()

    # Without either action the script used to exit silently with status 0;
    # report the misuse through argparse instead.
    if not args.dataset and not args.classify:
        parser.error("Nada que hacer: indica --dataset para entrenar y/o --classify para clasificar.")

    # These overrides only affect training; inference uses the configuration
    # stored inside the model file.
    cfg_over = dict(
        sr=args.sr,
        duration=args.duration,
        n_fft=args.n_fft,
        hop_length=args.hop_length,
        k=args.k,
        center_stft=args.center_stft,
        n_jobs=args.n_jobs,
    )

    if args.dataset:
        model_path = auto_fit_and_save(
            csv_path=args.dataset,
            out_model_path=args.out_model,
            path_col=args.path_col,
            label_col=args.label_col,
            cfg_overrides=cfg_over,
        )
        print(f"[ok] Modelo guardado en: {model_path}")

    if args.classify:
        # --out-model doubles as the model to read when classifying.
        if not os.path.exists(args.out_model):
            raise SystemExit(f"Modelo no encontrado: {args.out_model}")
        result = classify_with_model(args.out_model, args.classify, k=args.k_infer)
        print(json.dumps(result, indent=2, ensure_ascii=False))