In [None]:
# compare_pitch_to_corpus.py
from __future__ import annotations
import os
import glob
import math
import numpy as np
import pandas as pd
from typing import Iterable, Optional, Dict, List, Tuple

# ====== Settings ======
REF_DIR = "/content/drive/MyDrive/aria-midi-v1-unique-ext/chopin_dataset/train"                               # Reference (training) corpus folder; searched recursively for MIDI files
TARGET_MIDI = "/content/drive/MyDrive/aria-midi-v1-unique-ext/chopin_dataset/sample_chopin_long_v1.mid"       # Single target MIDI file to compare against the corpus
SAVE_CSV_PATH = "/content/drive/MyDrive/aria-midi-v1-unique-ext/chopin_dataset/pitch_similarity_results.csv"  # Output CSV path for per-(target, reference) results

DURATION_WEIGHTED = True   # Weight each note by its duration
USE_DRUMS = False          # Exclude drum tracks
VELOCITY_WEIGHTED = False  # Do not weight notes by velocity

# ====== Utilities ======
# Small positive numerical tolerance (1e-12).
_EPS = 1e-12

def _normalize_histogram(h: np.ndarray) -> np.ndarray:
    h = np.asarray(h, dtype=float)
    s = h.sum()
    if s <= 0:
        return np.ones_like(h) / len(h)
    return h / s

def pitch_class_histogram_from_midi(
    midi_path: str,
    duration_weighted: bool = True,
    use_drums: bool = False,
    velocity_weighted: bool = False,
    program_filter: Optional[Iterable[int]] = None,
) -> np.ndarray:
    """Build a normalized 12-bin pitch-class histogram from a MIDI file.

    Every note adds weight to bin ``note.pitch % 12``. The weight starts at
    1.0 and is multiplied by the note's duration (``end - start``, floored at
    0) when *duration_weighted* is set, and by its velocity (floored at 1)
    when *velocity_weighted* is set.

    Args:
        midi_path: Path handed to ``pretty_midi.PrettyMIDI``.
        duration_weighted: Scale each note's weight by its duration.
        use_drums: Include instruments flagged ``is_drum``.
        velocity_weighted: Scale each note's weight by its velocity.
        program_filter: If given, only instruments whose ``program`` is in
            this collection are counted.

    Returns:
        Length-12 array normalized by ``_normalize_histogram``.

    Errors raised by pretty_midi while loading the file propagate.
    """
    import pretty_midi

    midi = pretty_midi.PrettyMIDI(midi_path)
    allowed = None if program_filter is None else set(program_filter)
    counts = np.zeros(12, dtype=float)

    def _selected(instrument) -> bool:
        # Drop drum tracks unless requested, then apply the optional program filter.
        if instrument.is_drum and not use_drums:
            return False
        return allowed is None or instrument.program in allowed

    for instrument in filter(_selected, midi.instruments):
        for n in instrument.notes:
            weight = 1.0
            if duration_weighted:
                weight *= max(float(n.end - n.start), 0.0)
            if velocity_weighted:
                weight *= max(n.velocity, 1)
            counts[n.pitch % 12] += weight

    return _normalize_histogram(counts)

def pitch_class_histogram_from_pitches(
    pitches: Iterable[int], weights: Optional[Iterable[float]] = None
) -> np.ndarray:
    """Build a normalized 12-bin pitch-class histogram from MIDI pitch numbers.

    Args:
        pitches: Integer MIDI note numbers; each is folded into a pitch class
            with ``% 12``.
        weights: Optional per-pitch weights, paired with *pitches* by
            position. When omitted, every pitch counts as 1.0.

    Returns:
        Length-12 array normalized by ``_normalize_histogram``.

    Raises:
        ValueError: If *weights* is given and its length differs from the
            number of pitches.
    """
    hist = np.zeros(12, dtype=float)
    if weights is None:
        for p in pitches:
            hist[p % 12] += 1.0
    else:
        pitch_list = list(pitches)
        weight_list = list(weights)
        # zip() would silently drop the tail of the longer sequence, hiding a
        # caller bug and producing a skewed histogram; reject the mismatch.
        if len(pitch_list) != len(weight_list):
            raise ValueError(
                f"pitches and weights must have the same length "
                f"({len(pitch_list)} != {len(weight_list)})"
            )
        for p, w in zip(pitch_list, weight_list):
            hist[p % 12] += float(w)
    return _normalize_histogram(hist)

def jensen_shannon_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """Return the Jensen-Shannon divergence between two histograms, in bits.

    Both inputs are first normalized with ``_normalize_histogram``. With
    base-2 logarithms and non-negative inputs, the value lies in [0, 1]:
    0 for identical distributions and 1 for disjoint support. Lower means
    more similar.

    Bins where a distribution is zero contribute nothing (the standard
    ``0 * log 0 = 0`` convention). They are masked out rather than clipped to
    a small epsilon, so the result carries no epsilon bias. For example,
    disjoint distributions give exactly 1.0.

    Expects non-negative histogram values.
    """
    p = _normalize_histogram(p)
    q = _normalize_histogram(q)
    m = 0.5 * (p + q)

    def _kl(a: np.ndarray, b: np.ndarray) -> float:
        # Only bins with a > 0 contribute; there b = m >= a / 2 > 0, so the
        # ratio and its log are finite.
        mask = a > 0
        return float(np.sum(a[mask] * np.log2(a[mask] / b[mask])))

    return 0.5 * _kl(p, m) + 0.5 * _kl(q, m)

def cosine_similarity(p: np.ndarray, q: np.ndarray) -> float:
    """Return the cosine of the angle between vectors *p* and *q*.

    The result lies in [-1, 1]. For non-negative histograms it is in [0, 1],
    where 1 means the same shape up to scale. Higher means more similar.

    Degenerate inputs:
      * both vectors all-zero -> 1.0 (treated as identical);
      * exactly one all-zero  -> 0.0 (a zero vector shares no direction
        with a non-zero one, so it must not score as a perfect match).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    norm_p = float(np.linalg.norm(p))
    norm_q = float(np.linalg.norm(q))
    # Test each norm on its own. Thresholding their product mistakes small
    # but perfectly valid vectors (e.g. norms of 1e-8) for zero vectors.
    if norm_p == 0.0 or norm_q == 0.0:
        return 1.0 if norm_p == norm_q else 0.0
    # Normalize before the dot product so tiny norms cannot underflow the
    # denominator.
    return float(np.dot(p / norm_p, q / norm_q))

def earth_movers_distance_1d_circle(p: np.ndarray, q: np.ndarray) -> float:
    """Return the earth mover's distance between two 12-bin circular histograms.

    Both inputs are normalized with ``_normalize_histogram``. Adjacent bins
    are one unit apart on a circle (bin 11 neighbours bin 0).

    For each of the 12 rotations of the difference ``p - q``, the cost is the
    sum of absolute cumulative sums; the smallest cost is kept. Trying every
    rotation lets the cumulative flow start at any bin. The optimal circular
    offset is a median of the cumulative sums, and that median is one of
    those sums. So this gives the exact circular EMD, not an approximation.

    The result is divided by 12. Lower means more similar.
    """
    diff = _normalize_histogram(p) - _normalize_histogram(q)
    costs = (
        float(np.sum(np.abs(np.cumsum(np.roll(diff, shift)))))
        for shift in range(12)
    )
    return min(costs) / 12.0

def compare_pitch_hists(ref_hist: np.ndarray, tgt_hist: np.ndarray) -> Dict[str, float]:
    """Score a target histogram against a reference histogram.

    Returns a dict with these keys, in this order:
      * ``"js_div"`` - Jensen-Shannon divergence (lower is more similar)
      * ``"cosine"`` - cosine similarity (higher is more similar)
      * ``"emd"``    - circular earth mover's distance (lower is more similar)
    """
    scorers = (
        ("js_div", jensen_shannon_divergence),
        ("cosine", cosine_similarity),
        ("emd", earth_movers_distance_1d_circle),
    )
    return {name: score(ref_hist, tgt_hist) for name, score in scorers}

def collect_mid_files(folder: str) -> List[str]:
    """Return a sorted list of MIDI file paths found recursively under *folder*.

    Matching rules:
      * ``.mid`` and ``.midi`` extensions match case-insensitively, so
        ``.MID`` files are not silently skipped.
      * *folder* is escaped, so characters such as ``[`` in the directory
        name are taken literally instead of as glob patterns.
      * Only regular files are returned.
      * As with :func:`glob.glob`, hidden files and directories (names
        starting with ``.``) are not matched.

    A missing folder yields an empty list.
    """
    midi_exts = {".mid", ".midi"}
    pattern = os.path.join(glob.escape(folder), "**", "*")
    return sorted(
        path
        for path in glob.glob(pattern, recursive=True)
        if os.path.splitext(path)[1].lower() in midi_exts and os.path.isfile(path)
    )

# ====== 메인 로직 ======
def main() -> None:
    """Compare the target MIDI's pitch-class distribution against a reference corpus.

    Workflow:
      1. Collect every .mid/.midi file under REF_DIR via ``collect_mid_files``.
      2. Check that TARGET_MIDI exists *before* parsing the corpus, so a bad
         target path fails immediately instead of after the slow parse.
      3. Build a pitch-class histogram for each reference file. Files that
         fail to parse are reported with a ``[WARN]`` line and skipped.
      4. Score the target against every parsed reference (JS divergence,
         cosine similarity, circular EMD).
      5. Print the mean of each metric and write one CSV row per
         (target, reference) pair to SAVE_CSV_PATH.

    Raises:
        FileNotFoundError: If REF_DIR contains no MIDI files, or TARGET_MIDI
            is not an existing file.
    """
    import pretty_midi  # noqa: F401  # fail fast if the dependency is missing

    print("Scanning reference corpus...")
    ref_files = collect_mid_files(REF_DIR)
    if not ref_files:
        raise FileNotFoundError(f"No MIDI files found in REF_DIR: {REF_DIR}")
    print(f"Reference files: {len(ref_files)}")

    # --- Target input: a single file, validated before the expensive corpus parse ---
    if not os.path.isfile(TARGET_MIDI):
        # Only the TARGET_MIDI setting is supported; report the offending path.
        raise FileNotFoundError(
            f"타깃을 찾을 수 없습니다. TARGET_MIDI(파일) 설정을 확인하세요: {TARGET_MIDI}"
        )
    target_files: List[str] = [TARGET_MIDI]

    # Precompute reference-corpus histograms once; they are reused for every target.
    ref_hists: Dict[str, np.ndarray] = {}
    for ref_path in ref_files:
        try:
            ref_hists[ref_path] = pitch_class_histogram_from_midi(
                ref_path, DURATION_WEIGHTED, USE_DRUMS, VELOCITY_WEIGHTED
            )
        except Exception as e:  # best-effort: one corrupt file must not abort the run
            print(f"[WARN] Failed to parse ref: {ref_path} -> {e}")

    rows = []
    for t in target_files:
        try:
            tgt_hist = pitch_class_histogram_from_midi(
                t, DURATION_WEIGHTED, USE_DRUMS, VELOCITY_WEIGHTED
            )
        except Exception as e:
            print(f"[WARN] Failed to parse target: {t} -> {e}")
            continue

        for ref_path, ref_hist in ref_hists.items():
            metrics = compare_pitch_hists(ref_hist, tgt_hist)
            rows.append({
                "target": t,
                "reference": ref_path,
                "js_div": metrics["js_div"],
                "cosine": metrics["cosine"],
                "emd": metrics["emd"],
            })

    if not rows:
        print("No comparisons produced.")
        return

    df = pd.DataFrame(rows)

    # Report only the mean of each metric over all (target, reference) pairs.
    avg_js = df["js_div"].mean()
    avg_emd = df["emd"].mean()
    avg_cos = df["cosine"].mean()

    print("\n=== Pitch Distribution Similarity (평균값) ===")
    print(f"JS Divergence (↓): {avg_js:.6f}")
    print(f"EMD (↓):           {avg_emd:.6f}")
    print(f"Cosine (↑):        {avg_cos:.6f}")

    # Save the per-pair raw results.
    out_dir = os.path.dirname(SAVE_CSV_PATH)
    if out_dir:  # a bare filename has dirname "", and os.makedirs("") raises
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(SAVE_CSV_PATH, index=False, encoding="utf-8")
    print(f"\nSaved raw results -> {SAVE_CSV_PATH}")

if __name__ == "__main__":
    # Entry point: runs only when this module is the top-level program
    # (a script or notebook cell), not when it is imported.
    main()