<a href="https://colab.research.google.com/github/SEOUL-ABSS/SHIPSHIP/blob/main/SHIPDETECTOR_V2_BYOL-A.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
#         DeepShip + MBARI (Streaming VAD) — Ship vs Noise Binary Classifier
#     (RAM-safe MBARI ingestion, Reservoir Sampling, Full Audit & Visualization)
# ==============================================================================

print("1) 환경설정/설치 중 ...")
!pip -q install tensorflow tensorflow_hub soundfile librosa boto3 noisereduce umap-learn psutil

# -------------------------- Imports & Setup -----------------------------------
import os, sys, subprocess, random, math, gc, time, warnings, shutil
from collections import Counter
import numpy as np
import pandas as pd
import psutil
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
import librosa, librosa.display
import boto3
from botocore import UNSIGNED
from botocore.client import Config
from sklearn.model_selection import GroupShuffleSplit
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, f1_score, roc_auc_score, roc_curve, auc
import umap.umap_ as umap
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings("ignore", category=UserWarning)

# Fonts (optional, for Korean labels)
!sudo apt-get -y install fonts-nanum > /dev/null
!sudo fc-cache -fv > /dev/null
import matplotlib.font_manager as fm
font_path = '/usr/share/fonts/truetype/nanum/NanumGothic.ttf'
if os.path.exists(font_path):
    fm.fontManager.addfont(font_path)
    plt.rc('font', family='NanumGothic')
    plt.rcParams['axes.unicode_minus'] = False

# Seed every RNG the notebook uses (NumPy, stdlib random, TensorFlow).
SEED = 42
np.random.seed(SEED); random.seed(SEED); tf.random.set_seed(SEED)
# NOTE: hash randomization is fixed at interpreter start-up, so this only
# affects child processes launched after this point.
os.environ["PYTHONHASHSEED"] = str(SEED)

# -------------------------- Paths & Consts ------------------------------------
# Sample rate every segment is resampled to before it is fed to YAMNet.
YAMNET_SAMPLE_RATE = 16000
BASE = "/content"
DEEPSHIP_BASE_PATH = f"{BASE}/DeepShip"
MBARI_BASE_DIR     = f"{BASE}/MBARI_noise_data"
REPORT_DIR         = f"{BASE}/reports"
os.makedirs(REPORT_DIR, exist_ok=True)

# -------------------------- Small Utils ---------------------------------------
class Timer:
    """Context manager that reports elapsed wall-clock time and process RSS.

    On exit it prints one line of the form
    ``[TIMER] <name>: <seconds>s | RSS≈<GiB> GB``.
    """

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        # Start timestamp is kept on the instance so __exit__ can diff it.
        self.t = time.time()
        return self

    def __exit__(self, *a):
        rss_gb = psutil.Process().memory_info().rss / 1024**3
        elapsed = time.time() - self.t
        print(f"[TIMER] {self.name}: {elapsed:.2f}s | RSS≈{rss_gb:.2f} GB")

def mem():
    """Return the current process resident set size as a short ``RSS≈x.xx GB`` label."""
    rss_gb = psutil.Process().memory_info().rss / 1024**3
    return f"RSS≈{rss_gb:.2f} GB"

# ==============================================================================
# 2) Data acquisition (DeepShip + a subset of MBARI)
# ==============================================================================
print("\n2) 데이터 확보 ...")
# Shallow-clone the DeepShip repository unless the target directory already exists.
with Timer("DeepShip clone"):
    if not os.path.exists(DEEPSHIP_BASE_PATH):
        subprocess.run(['git','clone','--depth','1','https://github.com/irfankamboh/DeepShip.git', DEEPSHIP_BASE_PATH],
                       check=True, capture_output=True)
        print(" - DeepShip OK")
    else: print(" - DeepShip 이미 존재")

# Download up to MAX_DL non-empty .wav objects from the 'pacific-sound-16khz'
# bucket (prefix 2018/01/) using unsigned (anonymous) S3 requests.
# The download is skipped entirely when the local directory is non-empty.
with Timer("MBARI fetch (최대 10개)"):
    os.makedirs(MBARI_BASE_DIR, exist_ok=True)
    if not os.listdir(MBARI_BASE_DIR):
        s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
        pages = s3.get_paginator('list_objects_v2').paginate(Bucket='pacific-sound-16khz', Prefix='2018/01/')
        dl_count, MAX_DL = 0, 10
        for page in pages:
            for obj in page.get('Contents', []):
                if obj['Key'].endswith('.wav') and obj.get('Size',0) > 0:
                    # Objects are stored flat under their basename.
                    local = os.path.join(MBARI_BASE_DIR, os.path.basename(obj['Key']))
                    if not os.path.exists(local):
                        s3.download_file('pacific-sound-16khz', obj['Key'], local); dl_count+=1
                # Leave both loops as soon as the download budget is used up.
                if dl_count >= MAX_DL: break
            if dl_count >= MAX_DL: break
        print(f" - MBARI OK ({dl_count}개)")
    else:
        print(f" - MBARI OK (이미 {len(os.listdir(MBARI_BASE_DIR))}개 존재)")

# ==============================================================================
# 3) 스트리밍 VAD + 저수지 표본추출 기반 세그먼트 생성 (RAM-safe)
# ==============================================================================
EPS = 1e-12  # keeps log10 finite on all-zero frames


def get_activity_intervals_streaming(file_path, top_db=25.0, frame_sec=0.5, hop_sec=0.25):
    """Split a sound file into active / inactive time spans with a streaming energy VAD.

    The file is read frame by frame (never loaded whole). Each frame's RMS level
    in dB is computed once and cached; a frame is "active" when its level is
    within ``top_db`` of the loudest frame in the file.

    Args:
        file_path: Path of an audio file readable by ``soundfile``.
        top_db: Distance (dB) below the loudest frame that still counts as active.
        frame_sec: Analysis frame length in seconds.
        hop_sec: Hop between consecutive frame starts in seconds.

    Returns:
        ``(active, inactive)``: two lists of ``(start_sec, end_sec)`` tuples.
        Both lists are empty when the file cannot be read or is shorter than
        one frame.
    """
    frame_db = []          # one float per frame; small even for very long files
    max_db = -np.inf
    try:
        with sf.SoundFile(file_path) as f:
            sr = f.samplerate
            n = len(f)
            F = max(1, int(round(frame_sec * sr)))
            H = max(1, int(round(hop_sec * sr)))

            pos = 0
            while pos + F <= n:
                f.seek(pos)
                y = f.read(frames=F, dtype='float32', always_2d=False)
                if y.ndim > 1:
                    y = y.mean(axis=1)  # down-mix multi-channel audio to mono
                rms = float(np.sqrt(np.mean(y ** 2)) + EPS)
                db = float(20 * np.log10(rms + EPS))
                frame_db.append(db)
                if db > max_db:
                    max_db = db
                pos += H
    except Exception as exc:
        # Best-effort: an unreadable file contributes no segments, but say so.
        print(f" - [VAD] skipped {file_path}: {exc}")
        return [], []

    if not np.isfinite(max_db):
        return [], []

    thresh = max_db - top_db

    # Merge consecutive active frames into spans, reusing the cached levels
    # instead of reading the file a second time.
    active = []
    in_active = False
    cur_start = 0.0
    for k, db in enumerate(frame_db):
        pos = k * H
        if db >= thresh:
            if not in_active:
                in_active = True
                cur_start = pos / sr
        elif in_active:
            in_active = False
            # The span is closed at the end of the first quiet frame.
            active.append((cur_start, (pos + F) / sr))
    if in_active:
        active.append((cur_start, n / sr))

    # Everything not covered by an active span is inactive.
    inactive = []
    last = 0.0
    dur = n / sr
    for s, e in active:
        if s > last:
            inactive.append((last, s))
        last = e
    if last < dur:
        inactive.append((last, dur))
    return active, inactive

def _reservoir_segments_from_spans(spans, file_path, sr_orig, seg_dur, hop, cap):
    """ spans를 seg_dur로 쪼개되, cap개만 저수지 표본추출로 유지 (O(cap)) """
    res=[]; seen=0
    for s,e in spans:
        if e-s < seg_dur: continue
        st = s
        while st <= e - seg_dur + 1e-9:
            seg = (file_path, float(st), sr_orig)
            seen += 1
            if len(res) < cap:
                res.append(seg)
            else:
                j = random.randint(1, seen)
                if j <= cap: res[j-1] = seg
            st += hop
    return res

def _sliding_window_starts(spans, seg_dur, hop):
    """Yield the start time of every ``seg_dur``-long window that fits inside each span."""
    for s, e in spans:
        st = s
        while st <= e - seg_dur + 1e-9:
            yield float(st)
            st += hop


def create_dataset_deepship_plus_mbari(
    deepship_root, mbari_root,
    segment_duration=5.0,
    deepship_overlap=0.2,  # DeepShip: allow a little overlap for diversity
    mbari_overlap=0.0,     # MBARI: 0-0.1 recommended to limit near-duplicates
    mbari_cap_active_per_file=120,
    mbari_cap_inactive_per_file=120,
    vad_frame_sec=0.5, vad_hop_sec=0.25, vad_top_db=25.0,
):
    """Build the list of segment descriptors and labels for ship-vs-noise training.

    DeepShip files are split by the streaming VAD: windows inside active spans
    are labelled ``'ship'``, windows inside inactive spans ``'noise'``.
    MBARI files contribute only ``'noise'`` segments (from both active and
    inactive spans), each capped per file via reservoir sampling.

    Args:
        deepship_root: Directory walked recursively for DeepShip ``.wav`` files.
        mbari_root: Flat directory containing MBARI ``.wav`` files.
        segment_duration: Segment length in seconds.
        deepship_overlap: Fractional overlap between DeepShip windows (< 1).
        mbari_overlap: Fractional overlap between MBARI windows (< 1).
        mbari_cap_active_per_file: Max segments kept from a MBARI file's active spans.
        mbari_cap_inactive_per_file: Max segments kept from a MBARI file's inactive spans.
        vad_frame_sec, vad_hop_sec, vad_top_db: Passed through to the VAD.

    Returns:
        ``(infos, labels)`` where ``infos`` is a list of
        ``(file_path, start_sec, samplerate)`` tuples (ship segments first)
        and ``labels`` the matching list of ``'ship'`` / ``'noise'`` strings.

    Raises:
        ValueError: If an overlap of 1 or more would make the window stop advancing.
    """
    hop_ds = segment_duration * (1 - deepship_overlap)
    hop_mb = segment_duration * (1 - mbari_overlap)
    if hop_ds <= 0 or hop_mb <= 0:
        raise ValueError("segment_duration must be > 0 and overlaps must be < 1")

    ship_segments = []
    noise_segments = []

    # DeepShip
    print("\n[세그 생성] DeepShip ...")
    ds_files = ds_ship = ds_noise = 0
    for root, dirs, files in os.walk(deepship_root):
        dirs.sort()  # fixed traversal order keeps the segment list reproducible
        for fn in sorted(f for f in files if f.lower().endswith('.wav')):
            fp = os.path.join(root, fn)
            try:
                info = sf.info(fp)
            except Exception:
                continue  # unreadable header: skip the file
            ds_files += 1
            act, inact = get_activity_intervals_streaming(
                fp, top_db=vad_top_db, frame_sec=vad_frame_sec, hop_sec=vad_hop_sec)
            for st in _sliding_window_starts(act, segment_duration, hop_ds):
                ship_segments.append((fp, st, info.samplerate))
                ds_ship += 1
            for st in _sliding_window_starts(inact, segment_duration, hop_ds):
                noise_segments.append((fp, st, info.samplerate))
                ds_noise += 1
            gc.collect()
    print(f" - DeepShip 파일:{ds_files} | ship:{ds_ship} | noise:{ds_noise} | {mem()}")

    # MBARI -> hard negatives
    print("[세그 생성] MBARI (hard negatives, cap 적용) ...")
    mb_files = mb_noise_added = 0
    for fn in sorted(f for f in os.listdir(mbari_root) if f.lower().endswith('.wav')):
        fp = os.path.join(mbari_root, fn)
        try:
            info = sf.info(fp)
        except Exception:
            continue
        mb_files += 1
        act, inact = get_activity_intervals_streaming(
            fp, top_db=vad_top_db, frame_sec=vad_frame_sec, hop_sec=vad_hop_sec)
        active_segs = _reservoir_segments_from_spans(
            act, fp, info.samplerate, segment_duration, hop_mb, mbari_cap_active_per_file)
        inactive_segs = _reservoir_segments_from_spans(
            inact, fp, info.samplerate, segment_duration, hop_mb, mbari_cap_inactive_per_file)
        # Both active and inactive MBARI audio is treated as noise.
        noise_segments.extend(active_segs)
        noise_segments.extend(inactive_segs)
        mb_noise_added += len(active_segs) + len(inactive_segs)
        gc.collect()
    print(f" - MBARI 파일:{mb_files} | noise 추가:{mb_noise_added} | {mem()}")

    infos = ship_segments + noise_segments
    labels = (['ship'] * len(ship_segments)) + (['noise'] * len(noise_segments))
    return infos, labels

# ==============================================================================
# 4) 오디오 로딩(세그먼트) & 임베딩 & 모델
# ==============================================================================
def load_and_process_segment(file_info, duration, target_sr, rms_norm=True):
    """Read one segment from disk, down-mix, resample and optionally RMS-normalize it.

    Only the requested frame range is read, so memory use is independent of
    the file length.

    Args:
        file_info: ``(file_path, start_sec, original_samplerate)`` tuple.
        duration: Segment length in seconds.
        target_sr: Sample rate of the returned signal.
        rms_norm: When True, scale the signal to an RMS of ``10**(-20/20)`` (-20 dB).

    Returns:
        1-D float32 array, or ``None`` when the segment cannot be read or
        contains no samples (e.g. the start lies beyond the end of the file).
    """
    file_path, start_time, orig_sr = file_info
    try:
        start = int(start_time * orig_sr)
        num = int(duration * orig_sr)
        y, _ = sf.read(file_path, start=start, stop=start + num, dtype='float32', always_2d=False)
        if y.ndim > 1:
            y = y.mean(axis=1)
        if y.size == 0:
            # An empty read would give a NaN RMS and an empty array that
            # breaks downstream STFT / embedding calls.
            return None
        if orig_sr != target_sr:
            y = librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr, res_type="kaiser_fast")
        if rms_norm:
            rms = np.sqrt(np.mean(y ** 2)) + 1e-12  # epsilon avoids division by zero on silence
            y = y * ((10 ** (-20 / 20)) / rms)
        return y
    except Exception:
        # Best-effort loader: callers treat None as "skip this segment".
        return None

def extract_yamnet_embedding(info, yamnet_model, seg_dur):
    """Return the time-averaged YAMNet embedding of one segment, or None on failure.

    The segment is loaded at ``YAMNET_SAMPLE_RATE`` with RMS normalization,
    passed through ``yamnet_model`` and the per-frame embeddings (second model
    output) are averaged over axis 0.
    """
    waveform = load_and_process_segment(info, seg_dur, YAMNET_SAMPLE_RATE, rms_norm=True)
    if waveform is None:
        return None
    try:
        outputs = yamnet_model(waveform)
        _, frame_embeddings, _ = outputs
        # No frames produced: nothing to average.
        if frame_embeddings.shape[0] == 0:
            return None
        pooled = tf.reduce_mean(frame_embeddings, axis=0)
        return pooled.numpy()
    except Exception:
        return None

def embed_infos(infos, yamnet_model, seg_dur, max_items=None, show_every=1000):
    """Embed a list of segment descriptors with YAMNet.

    Args:
        infos: Sequence of segment descriptors accepted by ``extract_yamnet_embedding``.
        yamnet_model: Model handed through to ``extract_yamnet_embedding``.
        seg_dur: Segment length in seconds.
        max_items: Optional upper bound on the number of embeddings collected.
        show_every: Print a progress line every this many processed items.

    Returns:
        ``(X, kept_indices)``: a float32 array stacking the embeddings that
        succeeded, and the positions in ``infos`` they came from.
    """
    vectors, kept_indices = [], []
    total = len(infos)
    for pos, info in enumerate(infos, start=1):
        if max_items is not None and len(vectors) >= max_items:
            break
        vec = extract_yamnet_embedding(info, yamnet_model, seg_dur)
        if vec is not None:
            vectors.append(vec)
            kept_indices.append(pos - 1)
        if pos % show_every == 0:
            print(f"  임베딩 {pos}/{total} ... {mem()}")
    return np.asarray(vectors, dtype=np.float32), kept_indices

def load_yamnet():
    """Load YAMNet from TensorFlow Hub and return the loaded model object."""
    print("\nYAMNet 로드 ...", end="")
    model = hub.load("https://tfhub.dev/google/yamnet/1")
    print(" OK")
    return model

def build_classifier(input_dim, num_classes=2, lr=5e-4):
    """Build and compile the dense classifier head that sits on top of the embeddings.

    Architecture: Dense(256, relu) -> Dropout(0.5) -> Dense(128, relu) ->
    Dropout(0.5) -> Dense(num_classes, softmax), compiled with Adam and
    categorical cross-entropy (one-hot targets).

    Args:
        input_dim: Length of one embedding vector.
        num_classes: Number of output classes.
        lr: Adam learning rate.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    keras = tf.keras
    emb_in = keras.Input(shape=(input_dim,), name="emb")
    hidden = emb_in
    for units in (256, 128):
        hidden = keras.layers.Dense(units, activation='relu')(hidden)
        hidden = keras.layers.Dropout(0.5)(hidden)
    class_probs = keras.layers.Dense(num_classes, activation='softmax')(hidden)

    model = keras.Model(emb_in, class_probs)
    model.compile(
        optimizer=keras.optimizers.Adam(lr),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

# ==============================================================================
# 5) 데이터셋 생성 + 감사(Audit) + 시각화
# ==============================================================================
# Single place for all pipeline hyper-parameters; later cells read from it.
CONFIG = {
    "segment_duration": 5.0,
    "deepship_overlap": 0.2,
    "mbari_overlap": 0.0,
    "mbari_cap_active_per_file": 120,
    "mbari_cap_inactive_per_file": 120,
    "vad_frame_sec": 0.5, "vad_hop_sec": 0.25, "vad_top_db": 25.0,
    "test_size": 0.2, "epochs": 40, "batch_size": 32, "learning_rate": 5e-4,
}

# Build the (file, start, samplerate) segment descriptors and their labels.
with Timer("세그먼트 생성(Streaming)"):
    infos, labels = create_dataset_deepship_plus_mbari(
        DEEPSHIP_BASE_PATH, MBARI_BASE_DIR,
        segment_duration=CONFIG["segment_duration"],
        deepship_overlap=CONFIG["deepship_overlap"],
        mbari_overlap=CONFIG["mbari_overlap"],
        mbari_cap_active_per_file=CONFIG["mbari_cap_active_per_file"],
        mbari_cap_inactive_per_file=CONFIG["mbari_cap_inactive_per_file"],
        vad_frame_sec=CONFIG["vad_frame_sec"], vad_hop_sec=CONFIG["vad_hop_sec"], vad_top_db=CONFIG["vad_top_db"],
    )

# ---- Audit: 파일 수/총 길이/클래스 분포(세그먼트 기준) ----
def audit_dataset(infos, labels, segment_duration=None):
    """Print a short audit of the segment dataset.

    Reports the number of segments, the number of distinct source files, the
    per-class segment counts and the summed segment duration in hours
    (overlapping segments are counted in full).

    Args:
        infos: Segment descriptors; element 0 of each is the source file path.
        labels: Class label per segment.
        segment_duration: Segment length in seconds. Defaults to
            ``CONFIG["segment_duration"]`` when omitted.
    """
    if segment_duration is None:
        segment_duration = CONFIG["segment_duration"]

    print("\n[DATASET 감사]")
    cnt = Counter(labels)
    n_files = len({info[0] for info in infos})
    # Approximate total length (overlap included): segment count * segment length.
    total_dur_h = (len(infos) * segment_duration) / 3600.0
    print(f" - 세그먼트 수: {len(infos)} (files≈{n_files})")
    for k, v in cnt.items():
        print(f"   · {k}: {v}")
    print(f" - (중복 포함) 세그먼트 총 길이≈ {total_dur_h:.2f} h")

audit_dataset(infos, labels)

# ---- Class 분포 시각화 ----
def plot_class_distribution(labels, title="세그먼트 클래스 분포"):
    """Draw a bar chart of how many segments carry each class label."""
    counts = pd.Series(labels).value_counts().sort_index()
    plt.figure(figsize=(5, 4))
    sns.barplot(x=counts.index, y=counts.values)
    plt.title(title)
    plt.ylabel("count")
    plt.grid(axis='y', alpha=0.3)
    plt.show()

plot_class_distribution(labels)

# ---- 샘플 스펙트로그램 확인 ----
def show_sample_spectrograms(infos, labels, n_per_class=3):
    """Plot log-frequency spectrograms of a few randomly chosen segments per class.

    Segments are loaded at ``YAMNET_SAMPLE_RATE`` without RMS normalization.
    A panel whose segment fails to load only gets a "load fail" title.
    """
    class_names = sorted(set(labels))
    n_rows = len(class_names)
    plt.figure(figsize=(5 * n_per_class, 4 * n_rows))

    panel = 0  # running subplot index, shared across classes
    for cls in class_names:
        members = [k for k, lab in enumerate(labels) if lab == cls]
        random.shuffle(members)
        for k in members[:n_per_class]:
            panel += 1
            file_path, start_time, _ = infos[k]
            wave = load_and_process_segment(infos[k], CONFIG["segment_duration"], YAMNET_SAMPLE_RATE, rms_norm=False)
            plt.subplot(n_rows, n_per_class, panel)
            if wave is None:
                plt.title(f"{cls}: load fail")
                continue
            magnitude = np.abs(librosa.stft(wave, n_fft=1024, hop_length=320))
            spec_db = librosa.amplitude_to_db(magnitude, ref=np.max)
            librosa.display.specshow(spec_db, sr=YAMNET_SAMPLE_RATE, x_axis='time', y_axis='log', cmap='magma')
            plt.title(f"{cls} | {os.path.basename(file_path)} @ {start_time:.1f}s")

    plt.tight_layout()
    plt.show()

show_sample_spectrograms(infos, labels, n_per_class=3)

# ==============================================================================
# 6) 학습/검증 분할 (파일 기준 Group split → 데이터 누수 방지)
# ==============================================================================
# Encode string labels as integers; le.classes_ holds the index -> name mapping.
le = LabelEncoder(); y_enc = le.fit_transform(labels)
# Group key = source file path, so all segments of one file land on the same side.
groups = np.array([i[0] for i in infos])
gss = GroupShuffleSplit(n_splits=1, test_size=CONFIG["test_size"], random_state=SEED)
tr_idx, te_idx = next(gss.split(infos, y_enc, groups))

# Materialize the descriptor lists and encoded labels for each side.
Xtr_info = [infos[i] for i in tr_idx]; ytr_enc = y_enc[tr_idx]
Xte_info = [infos[i] for i in te_idx]; yte_enc = y_enc[te_idx]

print(f"\n[Split] train={len(Xtr_info)} | test={len(Xte_info)} (files train/test = {len(set([i[0] for i in Xtr_info]))}/{len(set([i[0] for i in Xte_info]))})")

# ==============================================================================
# 7) YAMNet 임베딩 → 분류기 학습/평가 + 시각화 + UMAP
# ==============================================================================
yamnet = load_yamnet()

# Embed the training segments; labels are filtered to the segments that
# actually produced an embedding, then one-hot encoded.
with Timer("임베딩(Train)"):
    Xtr, tr_kept_indices = embed_infos(Xtr_info, yamnet, CONFIG["segment_duration"])
    ytr_enc_filtered = ytr_enc[tr_kept_indices]
    ytr = tf.keras.utils.to_categorical(ytr_enc_filtered, num_classes=len(le.classes_))
    print(f" - Xtr:{Xtr.shape} | {mem()}")

# Same for the test segments.
with Timer("임베딩(Test)"):
    Xte, te_kept_indices = embed_infos(Xte_info, yamnet, CONFIG["segment_duration"])
    yte_enc_filtered = yte_enc[te_kept_indices]
    yte = tf.keras.utils.to_categorical(yte_enc_filtered, num_classes=len(le.classes_))
    print(f" - Xte:{Xte.shape} | {mem()}")

# ---- Quick check of the embedding value distribution ----
print(f"[임베딩 norm] Train mean={Xtr.mean():.3f} | std={Xtr.std():.3f}")
print(f"[임베딩 norm] Test  mean={Xte.mean():.3f} | std={Xte.std():.3f}")

# ---- Class weights (used instead of undersampling) ----
# weight = total / (n_classes * class_count), i.e. inverse-frequency weighting.
cnt_tr = Counter(ytr_enc_filtered); total = sum(cnt_tr.values())
class_weight = {cls: total/(len(cnt_tr)*cnt) for cls,cnt in cnt_tr.items()}
print("[class_weight]", class_weight, " (0:noise, 1:ship 순서일 가능성)")

# ---- Model ----
clf = build_classifier(Xtr.shape[-1], num_classes=len(le.classes_), lr=CONFIG["learning_rate"])
cb = [
    tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True, monitor='val_loss'),
    tf.keras.callbacks.ReduceLROnPlateau(patience=4, factor=0.5, min_lr=1e-6)
]

# Early stopping restores the weights with the best val_loss, so the data it
# monitors takes part in model selection. Using the test set for that makes
# the reported test metrics optimistic; hold out ~10% of the *training* files
# instead and keep the test set untouched until evaluation.
_tr_groups = np.array([Xtr_info[i][0] for i in tr_kept_indices])
if len(np.unique(_tr_groups)) >= 2:
    # File-grouped hold-out: no file contributes to both fit and validation.
    _holdout = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=SEED)
    _fit_idx, _val_idx = next(_holdout.split(Xtr, ytr_enc_filtered, _tr_groups))
else:
    # Only one source file: fall back to a random row-level hold-out.
    _perm = np.random.RandomState(SEED).permutation(len(Xtr))
    _n_val = max(1, int(round(0.1 * len(Xtr))))
    _val_idx, _fit_idx = _perm[:_n_val], _perm[_n_val:]

with Timer("모델 학습"):
    hist = clf.fit(Xtr[_fit_idx], ytr[_fit_idx],
                   validation_data=(Xtr[_val_idx], ytr[_val_idx]),
                   epochs=CONFIG["epochs"], batch_size=CONFIG["batch_size"],
                   class_weight=class_weight, verbose=1, callbacks=cb)

# ---- Learning curves ----
# Left panel: accuracy, right panel: loss; each shows the train and val series
# recorded in hist.history.
plt.figure(figsize=(12,4))
plt.subplot(1,2,1); plt.plot(hist.history['accuracy'], label='train'); plt.plot(hist.history['val_accuracy'], label='val'); plt.legend()
plt.title('정확도'); plt.grid(True, alpha=0.3)
plt.subplot(1,2,2); plt.plot(hist.history['loss'], label='train'); plt.plot(hist.history['val_loss'], label='val'); plt.legend()
plt.title('손실'); plt.grid(True, alpha=0.3)
plt.show()

# ---- Evaluation metrics ----
probs = clf.predict(Xte, verbose=0)
preds = probs.argmax(axis=1)
true = yte.argmax(axis=1)
acc = (preds == true).mean()
f1m = f1_score(true, preds, average='macro')

# Resolve the 'ship' column up front so the ROC section below can never hit
# an undefined name; None means the label set has no 'ship' class.
_class_names = list(le.classes_)
_label_ids = np.arange(len(_class_names))
ship_idx = _class_names.index('ship') if 'ship' in _class_names else None
_can_roc = ship_idx is not None and len(np.unique(true)) > 1

auc_ship = float('nan')
if _can_roc:
    try:
        auc_ship = roc_auc_score((true == ship_idx).astype(int), probs[:, ship_idx])
    except ValueError:
        # e.g. non-finite scores; keep NaN rather than aborting the report.
        auc_ship = float('nan')
print(f"\n[TEST] Acc={acc:.4f} | Macro-F1={f1m:.4f} | AUC(ship)={auc_ship:.4f}")
# Passing labels explicitly keeps the report / matrix shape aligned with
# le.classes_ even when a class is absent from the test predictions.
print("\n[분류 리포트]\n", classification_report(true, preds, labels=_label_ids, target_names=le.classes_))

# ---- Confusion matrix ----
cm = confusion_matrix(true, preds, labels=_label_ids)
plt.figure(figsize=(5,4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=le.classes_, yticklabels=le.classes_)
plt.xlabel('예측'); plt.ylabel('실제'); plt.title('혼동 행렬'); plt.show()

# ---- ROC ----
if ship_idx is not None and len(np.unique(true)) == 2:
    fpr, tpr, _ = roc_curve((true == ship_idx).astype(int), probs[:, ship_idx])
    plt.figure(figsize=(5,4))
    plt.plot(fpr, tpr, lw=2, label=f"AUC={auc(fpr,tpr):.3f}")
    plt.plot([0,1],[0,1],'--',alpha=0.5)
    plt.xlabel('FPR'); plt.ylabel('TPR'); plt.title('ROC (ship)'); plt.legend(); plt.grid(True, alpha=0.3); plt.show()

# ---- UMAP (샘플 제한; 메모리 안전) ----
def show_umap(X, y, classes, title="UMAP (YAMNet Embeddings)", max_points=1500):
    """Project embeddings to 2-D with UMAP and scatter-plot them coloured by class.

    At most ``max_points`` rows are used (drawn without replacement with a
    ``SEED``-seeded generator). With fewer than 10 rows a message is printed
    and nothing is plotted.

    Args:
        X: Embedding matrix, one row per sample.
        y: Integer class index per row.
        classes: Sequence mapping class index to display name.
        title: Plot title.
        max_points: Upper bound on the number of plotted samples.
    """
    n_total = len(X)
    if n_total > max_points:
        picked = np.random.RandomState(SEED).choice(n_total, size=max_points, replace=False)
        Xs, ys = X[picked], y[picked]
    else:
        Xs, ys = X, y

    n_used = len(Xs)
    if n_used < 10:
        print("UMAP: 표본이 너무 적어 생략")
        return

    reducer = umap.UMAP(
        n_neighbors=min(15, n_used - 1),
        min_dist=0.1,
        n_components=2,
        random_state=SEED,
    )
    coords = reducer.fit_transform(Xs)
    frame = pd.DataFrame({
        "x": coords[:, 0],
        "y": coords[:, 1],
        "label": [classes[k] for k in ys],
    })
    plt.figure(figsize=(7, 6))
    sns.scatterplot(data=frame, x='x', y='y', hue='label', s=20, alpha=0.7)
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.show()

show_umap(np.vstack([Xtr,Xte]), np.hstack([ytr.argmax(1), yte.argmax(1)]), le.classes_, title="UMAP (Train+Test)")

print("\n🎉 전체 파이프라인 완료.")

1) 환경설정/설치 중 ...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.0/14.0 MB[0m [31m57.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.7/85.7 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hdebconf: unable to initialize frontend: Dialog
debconf: (No usable dialog-like program is installed, so the dialog based frontend cannot be used. at /usr/share/perl5/Debconf/FrontEnd/Dialog.pm line 78, <> line 1.)
debconf: falling back to frontend: Readline
debconf: unable to initialize frontend: Readline
debconf: (This frontend requires a controlling tty.)
debconf: falling back to frontend: Teletype
dpkg-preconfigure: unable to re-open stdin: 

2) 데이터 확보 ...
 - DeepShip OK
[TIMER] DeepShip clone: 80.42s | RSS≈1.46 GB
 - MBARI OK (10개)
[TIMER] MBARI fetch (최대 10개): 316.13s | RSS≈1.47 GB

[세그 생성] DeepShip ...
 - D

In [None]:
# ================================ OOD 평가 모듈 =================================
# 이 블록은 기존 파이프라인에서 학습이 끝난 후에 붙여 실행하세요.
# 필요 전역: YAMNET_SAMPLE_RATE, CONFIG, yamnet, clf(학습된 분류기), le,
#            Xtr/ytr, Xte/yte, Xtr_info/Xte_info (option), BASE 경로
# ==============================================================================

import os, subprocess, random, math, gc, glob, re
import numpy as np
import pandas as pd
import soundfile as sf
import librosa, librosa.display
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, average_precision_score, precision_recall_curve

# ---------- 1) Git에서 OOD 샘플 오디오 가볍게 수집 ----------
OOD_ROOT = f"{BASE}/ood_audio_corpus"
os.makedirs(OOD_ROOT, exist_ok=True)

OOD_REPOS = [
    # 소형 예제/테스트 오디오가 비교적 들어있는 경우가 많음
    ("https://github.com/openai/whisper.git",          "whisper"),
    ("https://github.com/pytorch/audio.git",           "torchaudio"),
    ("https://github.com/iver56/audiomentations.git",  "audiomentations"),
    ("https://github.com/huggingface/transformers.git","transformers"),
]

def clone_if_needed(url, name):
    """Shallow-clone ``url`` into ``OOD_ROOT/name`` unless that path already exists.

    A failed clone is reported and otherwise ignored.

    Returns:
        The destination path, whether or not the clone succeeded.
    """
    dst = os.path.join(OOD_ROOT, name)
    if os.path.exists(dst):
        print(f" - already exists: {url}")
        return dst
    try:
        subprocess.run(["git","clone","--depth","1",url,dst], check=True, capture_output=True)
        print(f" - OK: {url}")
    except Exception as e:
        print(f" - FAIL: {url} ({e})")
    return dst

print("\n[OOD] 리포지토리 수집 ...")
repo_dirs = [clone_if_needed(u,n) for (u,n) in OOD_REPOS]

# Audio file extensions to look for (deliberately broad; the count is capped below).
EXTS = (".wav",".flac",".ogg",".mp3",".m4a",".aac",".wma",".aiff",".aif",".aifc",".au",".mp2",".opus")
def find_audio_files(roots, max_total=200):
    """Collect audio files below the given directories.

    Each root is traversed once; a file matches when its name ends with one
    of ``EXTS``, compared case-insensitively. Roots that do not exist simply
    contribute nothing.

    Args:
        roots: Iterable of directory paths to search recursively.
        max_total: Upper bound on the number of returned paths. When more
            files match, a random subset of this size is returned.

    Returns:
        List of file paths. Sorted when no sampling was necessary.
    """
    matches = set()
    for root in roots:
        # One recursive walk per root instead of one per extension.
        for path in glob.iglob(os.path.join(root, "**", "*"), recursive=True):
            if path.lower().endswith(EXTS) and os.path.isfile(path):
                matches.add(path)

    # Sort before sampling so a seeded `random` yields the same subset
    # regardless of filesystem enumeration order.
    all_files = sorted(matches)
    if len(all_files) > max_total:
        random.shuffle(all_files)
        all_files = all_files[:max_total]
    return all_files

ood_files = find_audio_files(repo_dirs, max_total=250)
print(f" - 수집된 OOD 원본 파일: {len(ood_files)}")

# ---------- 2) OOD 세그먼트(5초) 스트리밍 생성 ----------
def stream_segments_for_ood(file_path, seg_dur=5.0, stride=5.0, cap_per_file=6):
    """Pick up to ``cap_per_file`` random fixed-length segments from one file.

    Only the file header is read. Candidate start times lie on a uniform grid
    with spacing ``stride``; a random subset of them is returned.

    Args:
        file_path: Audio file to sample from.
        seg_dur: Segment length in seconds.
        stride: Spacing between candidate start times in seconds.
        cap_per_file: Maximum number of segments returned.

    Returns:
        List of ``(file_path, start_sec, samplerate)`` tuples; empty when the
        file is unreadable or shorter than ``seg_dur``.
    """
    try:
        info = sf.info(file_path)
        if info.duration < seg_dur:
            return []
        # Candidate starts on a uniform grid, shuffled so the cap takes a random subset.
        starts = np.arange(0, info.duration - seg_dur + 1e-9, stride).tolist()
        random.shuffle(starts)
        return [(file_path, float(st), info.samplerate) for st in starts[:cap_per_file]]
    except Exception:
        # Best-effort: many repository files are in formats the reader cannot
        # open. Ctrl-C is no longer swallowed, unlike with a bare `except:`.
        return []

# Global cap on the number of OOD segments (e.g. 800) so extraction stays cheap.
OOD_GLOBAL_CAP = 800
ood_segments=[]
for f in ood_files:
    segs = stream_segments_for_ood(f, seg_dur=CONFIG["segment_duration"], stride=CONFIG["segment_duration"], cap_per_file=6)
    ood_segments.extend(segs)
    # The check runs after extending, so the final count may exceed the cap
    # by up to one file's worth of segments.
    if len(ood_segments) >= OOD_GLOBAL_CAP: break
print(f" - 생성된 OOD 세그먼트: {len(ood_segments)}")

# ---------- 3) OOD 임베딩 ----------
def load_and_process_segment(info, duration, target_sr, rms_norm=True):
    """Read one segment from disk, down-mix, resample and optionally RMS-normalize it.

    Args:
        info: ``(file_path, start_sec, original_samplerate)`` tuple.
        duration: Segment length in seconds.
        target_sr: Sample rate of the returned signal.
        rms_norm: When True, scale the signal to an RMS of ``10**(-20/20)`` (-20 dB).

    Returns:
        1-D float32 array, or ``None`` when the segment cannot be read or
        contains no samples.
    """
    file_path, start_time, orig_sr = info
    try:
        start = int(start_time * orig_sr)
        num = int(duration * orig_sr)
        y, _ = sf.read(file_path, start=start, stop=start + num, dtype='float32', always_2d=False)
        if y.ndim > 1:
            y = y.mean(axis=1)
        if y.size == 0:
            # Nothing was read (start beyond EOF): avoid a NaN RMS downstream.
            return None
        if orig_sr != target_sr:
            y = librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr, res_type="kaiser_fast")
        if rms_norm:
            rms = np.sqrt(np.mean(y ** 2)) + 1e-12  # epsilon avoids division by zero on silence
            y = y * ((10 ** (-20 / 20)) / rms)
        return y
    except Exception:
        # `except Exception` (not a bare except) so KeyboardInterrupt still
        # stops a long embedding run.
        return None

def yamnet_embed_batch(infos, seg_dur=5.0, batch=128):
    """Embed OOD segments with the global ``yamnet`` model and record their raw RMS.

    The three returned sequences are index-aligned: an RMS value is recorded
    only for a segment whose embedding was actually kept.

    Args:
        infos: Sequence of ``(file_path, start_sec, samplerate)`` descriptors.
        seg_dur: Segment length in seconds.
        batch: Unused; kept so existing callers keep working.

    Returns:
        ``(X, rms, kept)``: float32 array of time-averaged embeddings, array of
        pre-normalization RMS values (NaN when the raw reload fails), and the
        list of descriptors that produced an embedding.
    """
    X, rms_list, kept = [], [], []
    total = len(infos)
    for i, info in enumerate(infos):
        emb_vec = None
        y = load_and_process_segment(info, seg_dur, YAMNET_SAMPLE_RATE, rms_norm=True)
        if y is not None:
            try:
                _, emb, _ = yamnet(y)
                if emb.shape[0] > 0:
                    emb_vec = tf.reduce_mean(emb, axis=0).numpy()
            except Exception:
                emb_vec = None

        if emb_vec is not None:
            # RMS before normalization, kept for the energy-bias analysis.
            # Measured only for kept segments so the lists cannot drift apart.
            y_raw = load_and_process_segment(info, seg_dur, YAMNET_SAMPLE_RATE, rms_norm=False)
            raw_rms = float(np.sqrt(np.mean(y_raw ** 2)) + 1e-12) if y_raw is not None else np.nan
            X.append(emb_vec)
            rms_list.append(raw_rms)
            kept.append(info)

        if (i + 1) % 500 == 0:
            print(f"  OOD 임베딩 {i+1}/{total}...")
    return np.asarray(X, dtype=np.float32), np.asarray(rms_list), kept

# Embed every OOD segment; kept_ood lists the descriptors that produced an embedding.
print("\n[OOD] 임베딩 추출 ...")
Xood, rms_ood, kept_ood = yamnet_embed_batch(ood_segments, seg_dur=CONFIG["segment_duration"])
print(f" - Xood:{Xood.shape}")

# Only a warning: execution continues with an empty OOD set.
if Xood.shape[0] == 0:
    print("경고: OOD 임베딩이 비었습니다. 리포 소스나 max_total, cap을 조정해보세요.")

# ---------- 4) 임계값 선택(검증셋 TPR=95%) & ID/OOD FPR 비교 ----------
# 학습에 사용한 train에서 validation을 분리(간단히 10% hold-out)
def split_val_from_train(Xtr, ytr_onehot, val_ratio=0.1, seed=42):
    """Randomly carve a validation subset out of the training arrays.

    Rows are shuffled with a ``seed``-seeded generator; the first
    ``max(1, round(n * val_ratio))`` shuffled rows become the validation set.

    Returns:
        ``(X_fit, y_fit, X_val, y_val)``.
    """
    n_rows = len(Xtr)
    order = np.arange(n_rows)
    np.random.RandomState(seed).shuffle(order)
    n_val = max(1, int(round(n_rows * val_ratio)))
    val_rows, fit_rows = order[:n_val], order[n_val:]
    return Xtr[fit_rows], ytr_onehot[fit_rows], Xtr[val_rows], ytr_onehot[val_rows]

# NOTE(review): Xval is sampled from Xtr. If clf was fit on those rows, the
# threshold chosen below is tuned on data the model has already seen —
# confirm against the training cell.
Xtr_fit, ytr_fit, Xval, yval = split_val_from_train(Xtr, ytr, val_ratio=0.1, seed=SEED)

# Reuse clf without retraining; only compute class probabilities for val and test.
p_val = clf.predict(Xval, verbose=0)
p_te  = clf.predict(Xte,  verbose=0)

# Binary targets: 1 where the one-hot label is 'ship', else 0.
ship_idx = list(le.classes_).index('ship')
yval_bin = (yval.argmax(1)==ship_idx).astype(int)
yte_bin  = (yte.argmax(1)==ship_idx).astype(int)

def select_threshold_by_tpr(y_true_bin, y_score, target_tpr=0.95):
    """Pick the ROC operating point whose TPR is closest to ``target_tpr``.

    Returns:
        ``(threshold, tpr, fpr)`` at that operating point, as Python floats.
    """
    fpr, tpr, thr = roc_curve(y_true_bin, y_score)
    # Index of the ROC point with the smallest |TPR - target|.
    best = np.argmin(np.abs(tpr - target_tpr))
    chosen = (thr[best], tpr[best], fpr[best])
    return tuple(float(v) for v in chosen)

tau, tpr_at_tau, fpr_at_tau = select_threshold_by_tpr(yval_bin, p_val[:,ship_idx], target_tpr=0.95)
print(f"\n[임계값] TPR@val≈95% → τ={tau:.4f} (val TPR={tpr_at_tau:.3f}, val FPR={fpr_at_tau:.3f})")

# ID-test FPR / OOD FPR
# FPR = false positives / number of negatives. Averaging the false-positive
# mask over *all* test rows would divide by positives as well and understate
# the rate, which also makes it incomparable with the OOD figure below.
_id_neg_scores = p_te[yte_bin == 0, ship_idx]
fpr_id = float((_id_neg_scores >= tau).mean()) if _id_neg_scores.size > 0 else float('nan')

# Every OOD segment is a negative, so its FPR is the share scored >= tau.
p_ood = clf.predict(Xood, verbose=0) if Xood.shape[0]>0 else np.zeros((0,len(le.classes_)),dtype=np.float32)
fpr_ood = float((p_ood[:,ship_idx] >= tau).mean()) if p_ood.shape[0]>0 else float('nan')

print(f"[FPR] ID(Test) FPR@τ={fpr_id:.4f} | OOD FPR@τ={fpr_ood:.4f}")

# ---------- 5) 시각화: 확률 분포 / ROC-PR / 에너지 편향 ----------
# (a) Distribution of P(ship): ID ship, ID noise and (if any) OOD, with the threshold marked.
plt.figure(figsize=(7,5))
sns.kdeplot(p_te[yte_bin==1, ship_idx], label="ID: ship", fill=True, alpha=0.3)
sns.kdeplot(p_te[yte_bin==0, ship_idx], label="ID: noise", fill=True, alpha=0.3)
if p_ood.shape[0]>0:
    sns.kdeplot(p_ood[:, ship_idx], label="OOD (others)", fill=True, alpha=0.3)
plt.axvline(tau, color='k', ls='--', label=f"τ={tau:.2f}")
plt.title("Ship 확률 분포(ID vs OOD)"); plt.xlabel("P(ship)"); plt.legend(); plt.grid(True, alpha=0.3); plt.show()

# (b) ROC / PR curves on the in-distribution test set
fpr_id_curve, tpr_id_curve, _ = roc_curve(yte_bin, p_te[:,ship_idx])
roc_auc_id = auc(fpr_id_curve, tpr_id_curve)
prec, rec, _ = precision_recall_curve(yte_bin, p_te[:,ship_idx])
auprc = average_precision_score(yte_bin, p_te[:,ship_idx])

plt.figure(figsize=(11,4))
# Left panel: ROC with the chance diagonal for reference.
plt.subplot(1,2,1)
plt.plot(fpr_id_curve, tpr_id_curve, lw=2, label=f"AUC={roc_auc_id:.3f}")
plt.plot([0,1],[0,1],'--',alpha=0.4)
plt.xlabel("FPR"); plt.ylabel("TPR"); plt.title("ROC (ID Test)"); plt.legend(); plt.grid(True, alpha=0.3)

# Right panel: precision-recall curve.
plt.subplot(1,2,2)
plt.plot(rec, prec, lw=2)
plt.xlabel("Recall"); plt.ylabel("Precision"); plt.title(f"PR (ID Test), AUPRC={auprc:.3f}")
plt.grid(True, alpha=0.3)
plt.show()

# (c) 에너지 decile 별 FPR (ID-Noise vs OOD)
def segment_rms(info, seg_dur=5.0):
    """Return the RMS of a segment loaded without normalization, or NaN if it cannot be loaded."""
    wave = load_and_process_segment(info, seg_dur, YAMNET_SAMPLE_RATE, rms_norm=False)
    return np.nan if wave is None else float(np.sqrt(np.mean(wave**2))+1e-12)

# RMS and ship probability of the ID noise segments
id_noise_idx = np.where(yte_bin==0)[0]
if 'Xte_info' in globals():
    # yte_bin / p_te are aligned with Xte, i.e. with the segments that survived
    # embedding. Xte_info is the unfiltered list, so positions must be mapped
    # back through te_kept_indices; otherwise every dropped segment shifts the
    # pairing between RMS values and probabilities.
    _te_rows = te_kept_indices if 'te_kept_indices' in globals() else range(len(Xte_info))
    rms_id_noise = np.array([segment_rms(Xte_info[_te_rows[i]], CONFIG["segment_duration"])
                             for i in id_noise_idx], dtype=float)
else:
    rms_id_noise = np.full(len(id_noise_idx), np.nan)
prob_id_noise = p_te[id_noise_idx, ship_idx]

def fpr_by_rms_decile(rms_arr, prob_arr, tau, n_bins=10):
    """Compute the share of scores >= ``tau`` inside each RMS quantile bin.

    Non-finite RMS entries (and their scores) are dropped first.

    Args:
        rms_arr: RMS value per sample.
        prob_arr: Score per sample, aligned with ``rms_arr``.
        tau: Decision threshold.
        n_bins: Number of quantile bins.

    Returns:
        ``None`` when fewer than 10 finite RMS values remain; otherwise
        ``(rates, edges)`` where ``rates`` is a list with one rate per bin
        (NaN for an empty bin) and ``edges`` the ``n_bins + 1`` quantile edges.
    """
    finite = np.isfinite(rms_arr)
    rms_ok = rms_arr[finite]
    prob_ok = prob_arr[finite]
    if len(rms_ok) < 10:
        return None

    edges = np.quantile(rms_ok, np.linspace(0, 1, n_bins + 1))
    # Only interior edges are passed, so bin ids run from 0 to n_bins - 1.
    bin_of = np.digitize(rms_ok, edges[1:-1], right=True)

    rates = []
    for b in range(n_bins):
        in_bin = prob_ok[bin_of == b]
        rates.append(float((in_bin >= tau).mean()) if in_bin.size else np.nan)
    return rates, edges

# RMS of the OOD segments that produced an embedding (aligned with p_ood).
ood_rms = np.zeros(0)
if len(kept_ood)>0:
    ood_rms = np.array([segment_rms(info, CONFIG["segment_duration"]) for info in kept_ood])

res_id = fpr_by_rms_decile(rms_id_noise, prob_id_noise, tau, n_bins=10)
res_ood = (None, None)
if len(ood_rms)>0:
    # fpr_by_rms_decile returns None when there are too few valid samples;
    # keep the (None, None) placeholder then, so the indexing below is safe.
    _ood_deciles = fpr_by_rms_decile(ood_rms, p_ood[:,ship_idx], tau, n_bins=10)
    if _ood_deciles is not None:
        res_ood = _ood_deciles

if res_id is not None:
    fpr_bins_id, qs_id = res_id
    plt.figure(figsize=(7,4))
    plt.plot(range(1,11), fpr_bins_id, marker='o', label='ID-Noise')
    if isinstance(res_ood[0], list):
        plt.plot(range(1,11), res_ood[0], marker='o', label='OOD')
    plt.xticks(range(1,11)); plt.xlabel("RMS decile (낮음→높음)")
    plt.ylabel(f"FPR@τ"); plt.title("에너지 구간별 FPR (낮을수록 좋음)")
    plt.grid(True, alpha=0.3); plt.legend(); plt.show()
else:
    print("RMS decile 분석을 위한 유효 표본이 부족합니다.")

# Final summary of the threshold and the ID / OOD metrics computed above.
print("\n[요약]")
print(f" - 임계값 τ(Val TPR≈95%): {tau:.3f}")
print(f" - FPR(ID-noise)@τ: {fpr_id:.4f}")
print(f" - FPR(OOD)@τ: {fpr_ood:.4f} (낮을수록 좋음)")
print(f" - ROC-AUC(ID test): {roc_auc_id:.3f}, AUPRC(ID test): {auprc:.3f}")
print(" - 그래프: 확률분포/ROC/PR/에너지-디사일 FPR으로, 에너지-편향 여부를 함께 점검")
# ==============================================================================
