<a href="https://colab.research.google.com/github/SEOUL-ABSS/SHIPSHIP/blob/main/SONAR8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================================================================
#     ShipsEar Ship vs Noise — V5~V8 Only (RAM‑safe Fine‑Tuning + Report)
#     - Removed V1~V4. Keep V5 (emb) + V6~V8 (fine‑tune) only
#     - Major RAM fixes for V6+ (YAMNet map_fn returns pooled vectors, not (B,T,1024))
#     - No layer creation inside Layer.call; deterministic tf.data prefetch; GPU mem‑growth
#     - Summary CSV + per‑version CM + AP JSON + pretty report printed at the end
# ==============================================================================

print("1) 환경설정/설치 중 ...")
!pip -q install "tensorflow==2.19.0" tensorflow_hub==0.16.1
!pip -q install librosa==0.10.2.post1 soundfile==0.12.1 umap-learn==0.5.6 scikit-learn==1.5.2 psutil==5.9.8 seaborn==0.13.2 joblib==1.4.2

# (Colab only) Mount Google Drive. Outside Colab the import raises and the
# failure is reported and ignored.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)
    print("Drive mounted.")
except Exception as e:
    print("Colab이 아니라면 무시:", e)

# (선택) 한글 폰트
!apt -yq install fonts-nanum >/dev/null

# ------------------------- Imports & Setup ------------------------------------
import os, re, sys, random, math, gc, time, warnings, shutil, glob, json
from collections import Counter, defaultdict
import numpy as np
import pandas as pd
import psutil
import soundfile as sf
import tensorflow as tf
import tensorflow_hub as hub
import librosa
import scipy.signal as spsig
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import (classification_report, confusion_matrix, f1_score,
                             roc_auc_score, average_precision_score,
                             balanced_accuracy_score, top_k_accuracy_score, accuracy_score)
from sklearn.model_selection import GroupShuffleSplit

# Ignore every UserWarning for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)
SEED=42
# Seed NumPy, Python's `random` and TensorFlow with the same value.
np.random.seed(SEED); random.seed(SEED); tf.random.set_seed(SEED)
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so this
# presumably only affects subprocesses spawned later — confirm if hash
# determinism of the current process matters.
os.environ["PYTHONHASHSEED"]=str(SEED)

# Enable GPU memory growth (intended to avoid VRAM OOM)
try:
    gpus=tf.config.experimental.list_physical_devices('GPU')
    for g in gpus:
        tf.config.experimental.set_memory_growth(g, True)
    if gpus: print(f" - GPU found: {len(gpus)} | memory growth enabled")
except Exception as e:
    # Best effort: a failure here is reported and the run continues.
    print(" - GPU memory growth set failed (ok):", e)

def mem():
    """Return the current process's resident set size formatted as 'RSS≈X.XX GB'."""
    rss_bytes = psutil.Process().memory_info().rss
    return f"RSS≈{rss_bytes/1024**3:.2f} GB"

# Font: if NanumGothic is installed, register it with matplotlib and make it the
# default family so Korean labels render in figures.
if os.path.exists('/usr/share/fonts/truetype/nanum/NanumGothic.ttf'):
    fm.fontManager.addfont('/usr/share/fonts/truetype/nanum/NanumGothic.ttf')
    plt.rc('font', family='NanumGothic'); plt.rcParams['axes.unicode_minus'] = False
    print(" - 폰트 OK: NanumGothic")

# ------------------------- Paths & Config -------------------------------------
BASE = "/content"
SHIPSEAR_DRIVE = "/content/drive/MyDrive/ShipsEar"   # <- edit if the dataset lives elsewhere
SHIPSEAR = f"{BASE}/ShipsEar_colab"
# Output folders used by the rest of the script.
os.makedirs("results", exist_ok=True)
os.makedirs("cache", exist_ok=True)
os.makedirs("artifacts", exist_ok=True)

YAM_SR = 16000

# ---- Binary mode (Ship vs Noise) switch ----
BINARY_MODE = True
POS_LABEL = "Ship"

BASE_CONFIG = dict(
    seg_dur=1.0,               # segment length in seconds
    ship_overlap=0.2,          # overlap ratio for classes A-D (hop = seg_dur * (1 - overlap) = 0.8 s)
    noise_overlap=0.0,         # no overlap between class-E (noise) segments
    vad_frame_sec=0.5, vad_hop_sec=0.25, vad_top_db=25.0,  # energy-VAD frame/hop (s) and dB below the loudest frame
    test_size=0.2, epochs=40, batch=32, lr=5e-4,
    spec_per_class=2,
    umap_max_points=2000,
    max_seg_per_group_per_class=500,   # cap on segments per (group, class) pair
    noise_jitter_sec=0.5,              # max +/- random start offset (s) applied to class-E segments
    topk=2,
    cache_emb=True,            # cache embeddings under ./cache
)

# For binary classification Top-K is not meaningful; force k=1 for compatibility.
# This must run AFTER BASE_CONFIG exists (assigning before the dict is created
# raises NameError).
if BINARY_MODE:
    BASE_CONFIG["topk"] = 1

# ------------------------- Version definitions (V5~V8 only) -------------------
# type: 'emb' (precomputed embeddings + head) | 'ft' (end-to-end fine-tuning)
VERSIONS = [
    # ----- (A) embedding based -----
    {
        "name": "v5_meanstd_mlp_aug", "type": "emb", "classifier": "mlp",
        "pooling": "meanstd", "aug": "light",
    },

    # ----- (B) partial fine-tuning (RAM-safe) -----
    {
        "name": "v6_ft_mean_headwarmup_unfreeze", "type": "ft", "pooling": "mean",
        "warmup_epochs": 5, "ft_epochs": 10, "base_lr": 3e-4, "ft_lr": 1e-5,
        "batch_ft": 8, "aug": "light",
    },
    {
        "name": "v7_ft_meanstd_headwarmup_unfreeze", "type": "ft", "pooling": "meanstd",
        "warmup_epochs": 5, "ft_epochs": 10, "base_lr": 3e-4, "ft_lr": 1e-5,
        "batch_ft": 8, "aug": "light",
    },
    {
        "name": "v8_ft_meanstd_fullft_tinyLR", "type": "ft", "pooling": "meanstd",
        "warmup_epochs": 0, "ft_epochs": 12, "base_lr": 1e-5, "ft_lr": 1e-5,
        "batch_ft": 8, "aug": "light",
    },
]

# ======================================================================
# 2) 데이터 확보
# ======================================================================
print("\n2) 데이터 확보 중 ...")
# Fail fast when the Drive copy of the dataset is not there.
if not os.path.exists(SHIPSEAR_DRIVE):
    raise FileNotFoundError(f" - ShipsEar 드라이브 경로 없음: {SHIPSEAR_DRIVE}")
# Copy only when the local working copy is missing or empty.
if os.path.exists(SHIPSEAR) and os.listdir(SHIPSEAR):
    print(" - ShipsEar 이미 존재")
else:
    shutil.copytree(SHIPSEAR_DRIVE, SHIPSEAR, dirs_exist_ok=True)
    print(" - ShipsEar 복사 완료")

# ======================================================================
# 3) 라벨 매핑 & 그룹 키
# ======================================================================
A_kw = ["fishing","trawler","trawl","mussel","tug","dredger","dredge"]
B_kw = ["motorboat","motor boat","pilot","sailboat","sailing"]
C_kw = ["ferry","passenger"]
D_kw = ["oceanliner","ocean liner","ro-ro","roro","ro_ro","cargo","containership","container","tanker","bulk","liner","oceangoing"]
E_kw = ["background","noise","ambient","no_ship","noship","silence"]

# Explicit "class a" .. "class e" tag. Lookarounds are used instead of \b because
# "_" counts as a word character, so \b would miss tags such as "rec_class_a_01".
_CLASS_TAG_RE = re.compile(r'(?<![a-z0-9])class[_\s-]*([a-e])(?![a-z0-9])')

def resolve_ships_ear_class(path):
    """Map a file path to a class letter "A".."E", or None if nothing matches.

    The lower-cased parent folder name and file name are searched for keyword
    substrings. Noise keywords (E) are checked first, then A, B, C, D. If no
    keyword matches, an explicit "class <letter>" tag is looked for.

    Args:
        path: path of an audio file.

    Returns:
        One of "A", "B", "C", "D", "E", or None when the path cannot be mapped.
    """
    name = os.path.basename(path).lower()
    parent = os.path.basename(os.path.dirname(path)).lower()
    txt = f"{parent} {name}"
    # Order matters: E wins over ship keywords when both appear.
    for label, keywords in (("E", E_kw), ("A", A_kw), ("B", B_kw), ("C", C_kw), ("D", D_kw)):
        if any(k in txt for k in keywords):
            return label
    m = _CLASS_TAG_RE.search(txt)
    return m.group(1).upper() if m else None

def ships_ear_group_key(path):
    """Derive a grouping key for a recording so related files stay in one split.

    If the file stem contains a date-time stamp (``YYYYMMDD[_-]HHMM`` or
    ``YYYY-MM-DD[_-]HH[-_]MM`` with ``-``/``_`` separators) that stamp is the
    key. Otherwise the key is ``"<parent folder>:<prefix>"`` where the prefix is
    the first three ``_``/``-`` separated tokens of the stem (or the whole stem
    when it has fewer than three tokens).
    """
    directory, filename = os.path.split(path)
    stem = os.path.splitext(filename)[0]

    stamp_patterns = (
        r'(\d{8}[_-]?\d{4})',
        r'(\d{4}[-_]\d{2}[-_]\d{2}[_-]?\d{2}[-_]?\d{2})',
    )
    for pattern in stamp_patterns:
        hit = re.search(pattern, stem)
        if hit:
            return hit.group(1)

    tokens = re.split(r'[_\-]+', stem)
    prefix = stem if len(tokens) < 3 else "_".join(tokens[:3])
    return f"{os.path.basename(directory)}:{prefix}"

# ======================================================================
# 4) VAD & 세그 생성
# ======================================================================
# Small constant added to the RMS so that log10 never receives 0.
EPS=1e-12

def get_activity_intervals_streaming(file_path, top_db=25.0, frame_sec=0.5, hop_sec=0.25):
    """Energy-based activity detection that reads the file frame by frame.

    The file is scanned twice with a sliding window: the first pass finds the
    loudest frame level in dB, the second marks every frame whose level is at
    least ``max_db - top_db`` as active and merges consecutive active frames
    into spans.

    Args:
        file_path: audio file readable by soundfile.
        top_db: how far below the loudest frame (in dB) a frame may be and
            still count as active.
        frame_sec: analysis window length in seconds.
        hop_sec: step between successive windows in seconds.

    Returns:
        ``(active, inactive)``: two lists of ``(start_sec, end_sec)`` tuples.
        ``([], [])`` is returned when no full frame fits in the file or when
        any exception occurs while reading (errors are swallowed silently).
    """
    try:
        with sf.SoundFile(file_path) as f:
            sr=f.samplerate; n=len(f)
            # Frame and hop sizes in samples (at least one sample each).
            F=max(1,int(round(frame_sec*sr))); H=max(1,int(round(hop_sec*sr)))
            # pass 1: find the maximum frame level (dB)
            max_db=-np.inf; pos=0
            while pos+F<=n:
                f.seek(pos); y=f.read(frames=F, dtype='float32', always_2d=False)
                # Multi-channel frames are averaged down to mono.
                if y.ndim>1: y=y.mean(axis=1)
                rms=float(np.sqrt(np.mean(y**2))+EPS)
                db=20*np.log10(rms+EPS)
                if db>max_db: max_db=db
                pos+=H
            # No frame was read (file shorter than one frame) or level is not finite.
            if not np.isfinite(max_db): return [], []
            th = max_db - top_db
            # pass 2: threshold each frame and merge consecutive active frames
            active=[]; in_act=False; cur=0.0; pos=0
            while pos+F<=n:
                f.seek(pos); y=f.read(frames=F, dtype='float32', always_2d=False)
                if y.ndim>1: y=y.mean(axis=1)
                rms=float(np.sqrt(np.mean(y**2))+EPS); db=20*np.log10(rms+EPS)
                t0=pos/sr; t1=(pos+F)/sr
                if db>=th:
                    if not in_act: in_act=True; cur=t0
                else:
                    # The span is closed at the END of the first inactive frame (t1),
                    # so it extends past the last active frame.
                    if in_act: in_act=False; active.append((cur,t1))
                pos+=H
            # A span still open at the end runs to the end of the file.
            if in_act: active.append((cur,n/sr))
            # (for reference) inactive spans are the gaps between active spans
            inactive=[]; last=0.0; dur=n/sr
            for s,e in active:
                if s>last: inactive.append((last,s))
                last=e
            if last<dur: inactive.append((last,dur))
            return active, inactive
    except Exception:
        return [], []

def slice_spans_to_segments(spans, seg_dur, hop):
    """Cut time spans into fixed-length segment start times.

    Args:
        spans: iterable of ``(start_sec, end_sec)`` tuples.
        seg_dur: segment length in seconds.
        hop: step between successive segment starts in seconds; must be > 0.

    Returns:
        List of 1-tuples ``(start_sec,)``, one per segment that fits entirely
        inside its span. Spans shorter than ``seg_dur`` contribute nothing.

    Raises:
        ValueError: if ``hop`` is not positive (the stepping loop would never
            terminate, e.g. with an overlap ratio of 1.0).
    """
    if hop <= 0:
        raise ValueError(f"hop must be positive, got {hop!r}")
    segs = []
    for s, e in spans:
        if e - s < seg_dur:
            continue
        # Small tolerance so a segment ending exactly at `e` survives float error.
        last_start = e - seg_dur + 1e-9
        st = s
        while st <= last_start:
            segs.append((float(st),))
            st += hop
    return segs

def build_segments_ships_ear(root, cfg):
    """Scan ``root`` for .wav files and build the list of labelled segments.

    Ship classes (A-D) are segmented inside the active spans returned by
    ``get_activity_intervals_streaming``; class E uses the whole file and may
    get a random start jitter. Segments of each file are shuffled before the
    per-(group, class) cap is applied.

    Args:
        root: directory searched recursively for ``*.wav``.
        cfg: config dict; reads ``seg_dur``, ``ship_overlap``,
            ``noise_overlap``, ``vad_top_db``, ``vad_frame_sec``,
            ``vad_hop_sec`` and optionally ``noise_jitter_sec`` and
            ``max_seg_per_group_per_class``.

    Returns:
        ``(infos, labels, groups, summary, missing)`` where ``infos`` is a list
        of ``(path, start_sec, samplerate)``, ``labels``/``groups`` are aligned
        lists of class letters and group keys, ``summary`` maps class letter to
        segment count, and ``missing`` counts files with no class mapping.
    """
    seg_dur = cfg["seg_dur"]
    hop_ship = seg_dur * (1 - cfg["ship_overlap"])
    hop_noise = seg_dur * (1 - cfg["noise_overlap"])
    noise_jitter = cfg.get("noise_jitter_sec", 0.0)
    cap = cfg.get("max_seg_per_group_per_class", None)

    infos = []; labels = []; groups = []
    missing = 0
    per_gc_count = defaultdict(int)
    summary = defaultdict(int)

    # glob returns files in arbitrary (filesystem) order; sorting makes the
    # seeded shuffle/jitter below reproducible across machines.
    wav_paths = sorted(glob.glob(os.path.join(root, "**", "*.wav"), recursive=True))
    for fp in wav_paths:
        cls = resolve_ships_ear_class(fp)
        if cls is None:
            missing += 1
            continue
        try:
            info = sf.info(fp)
        except Exception as e:
            # Unreadable/corrupt file: skip it, but say so instead of hiding it.
            print(f"WARN: skipping unreadable file {os.path.basename(fp)} - {e}")
            continue

        dur = info.frames / info.samplerate
        gkey = ships_ear_group_key(fp)
        key = (gkey, cls)
        if cls in ["A", "B", "C", "D"]:
            act, _ = get_activity_intervals_streaming(fp, top_db=cfg["vad_top_db"],
                                                     frame_sec=cfg["vad_frame_sec"], hop_sec=cfg["vad_hop_sec"])
            spans = act; hop = hop_ship
        else:  # E
            spans = [(0.0, dur)]; hop = hop_noise

        segs = slice_spans_to_segments(spans, seg_dur, hop)
        random.shuffle(segs)

        for (st,) in segs:
            # The key is constant for this file, so once the cap is hit no later
            # segment of the file can be accepted.
            if cap is not None and per_gc_count[key] >= cap:
                break
            if cls == "E" and noise_jitter > 0:
                j = random.uniform(-noise_jitter, noise_jitter)
                # Keep the jittered segment fully inside the file.
                st = max(0.0, min(st + j, dur - seg_dur))
            infos.append((fp, float(st), info.samplerate))
            labels.append(cls)
            groups.append(gkey)
            per_gc_count[key] += 1
            summary[cls] += 1

    return infos, labels, groups, summary, missing

# ======================================================================
# 5) 오디오 로드/증강/임베딩 (★ YAMNet 패치 포함)
# ======================================================================

def load_segment(info, seg_dur, target_sr=YAM_SR, rms_norm=True):
    """Read one segment from disk as a mono float32 waveform at ``target_sr``.

    Args:
        info: ``(path, start_sec, original_samplerate)``.
        seg_dur: requested length in seconds (shorter if the file ends first).
        target_sr: output sample rate; the audio is resampled when it differs.
        rms_norm: when true, scale the waveform to an RMS of -20 dBFS.

    Returns:
        1-D float32 array, or None when the start lies beyond the file end or
        any error occurs (a message is printed in both cases).
    """
    path, t_start, src_sr = info
    try:
        first = int(t_start * src_sr)
        wanted = int(seg_dur * src_sr)
        with sf.SoundFile(path, 'r') as handle:
            remaining = handle.frames - first
        if remaining <= 0:
            print(f"WARN: Segment start ({t_start:.2f}s) beyond file end for {os.path.basename(path)}")
            return None
        count = min(wanted, remaining)

        wave, _ = sf.read(path, start=first, stop=first + count, dtype='float32', always_2d=False)
        if wave is None:
            return None
        if wave.ndim > 1:
            # Down-mix multi-channel audio.
            wave = wave.mean(axis=1)
        if src_sr != target_sr:
            wave = safe_resample(wave, src_sr, target_sr)
        if wave is None:
            print(f"WARN: Resampling returned None for {os.path.basename(path)}")
            return None
        if rms_norm:
            level = np.sqrt(np.mean(wave**2)) + 1e-12
            wave *= (10**(-20/20)) / level  # -20 dBFS
        return wave.astype(np.float32)
    except Exception as e:
        print(f"ERROR: Failed to load segment {info} - {e}")
        return None

def augment_wave(y, sr, kind="light"):
    """Apply a random, length-preserving augmentation to a waveform.

    ``kind="light"`` applies a random gain in [-3, +3] dB and a random time
    shift of up to +/-0.25 s, filling the vacated samples with zeros. Any other
    ``kind`` returns ``y`` unchanged.

    Args:
        y: 1-D NumPy waveform, or None.
        sr: sample rate of ``y`` in Hz.
        kind: augmentation preset name.

    Returns:
        Augmented waveform with the same length and dtype as ``y``, or None
        when ``y`` is None.
    """
    if y is None:
        return None
    if kind != "light":
        return y

    gain_db = random.uniform(-3, 3)
    y = y * (10 ** (gain_db / 20))

    max_shift = int(0.25 * sr)
    shift = random.randint(-max_shift, max_shift)
    # Clamp the shift to the signal length: without this, a waveform shorter
    # than the drawn shift came back longer than it went in.
    n = len(y)
    shift = max(-n, min(shift, n))
    if shift > 0:
        # Delay: zeros in front, tail dropped.
        y = np.concatenate([np.zeros(shift, dtype=y.dtype), y[:n - shift]])
    elif shift < 0:
        # Advance: head dropped, zeros appended.
        y = np.concatenate([y[-shift:], np.zeros(-shift, dtype=y.dtype)])
    return y

# ------------------------- Resampling helper (avoid resampy) ------------------
# Uses SciPy's polyphase resampler when available; otherwise falls back to
# librosa's FFT resampler (no resampy), and finally to linear interpolation.
try:
    import scipy.signal as _spsig_internal
except Exception:
    # SciPy unavailable: safe_resample will use its fallbacks.
    _spsig_internal = None


def safe_resample(y, orig_sr, target_sr):
    """Resample ``y`` from ``orig_sr`` to ``target_sr`` without needing resampy.

    Order of preference: SciPy polyphase resampling, then librosa's FFT
    resampler when SciPy could not be imported, then plain linear interpolation
    if either of those raises.

    Returns:
        ``y`` itself when the rates are equal, otherwise a float32 array.
    """
    if orig_sr == target_sr:
        return y
    try:
        if _spsig_internal is None:
            # Fallback: librosa FFT backend (does not require resampy)
            return librosa.resample(y, orig_sr=orig_sr, target_sr=target_sr, res_type="fft").astype(np.float32)
        src_rate, dst_rate = int(orig_sr), int(target_sr)
        divisor = math.gcd(src_rate, dst_rate)
        resampled = _spsig_internal.resample_poly(y, dst_rate // divisor, src_rate // divisor)
        return resampled.astype(np.float32)
    except Exception:
        # Last resort: linear interpolation
        n_in = len(y)
        n_out = int(round(n_in * float(target_sr) / float(orig_sr)))
        query = np.linspace(0, n_in, n_out, endpoint=False)
        return np.interp(query, np.arange(n_in), y).astype(np.float32)

YAM_URL = "https://tfhub.dev/google/yamnet/1"

def make_yamnet_infer():
    """Return ``infer(y)``, a callable that runs YAMNet on a 1-D float32 waveform.

    ``hub.load`` is tried first; if loading or the warm-up call fails, a
    ``hub.KerasLayer`` is used instead. The returned callable yields the raw
    model outputs. Each backend is warmed up once with one second of zeros.
    """
    probe = np.zeros(16000, np.float32)
    try:
        module = hub.load(YAM_URL)

        def infer(y):
            # (N,) waveform in, raw outputs (tuple or dict) out.
            return module(tf.convert_to_tensor(y, tf.float32))

        infer(probe)
        print("[YAMNet] backend=hub.load")
        return infer
    except Exception as e1:
        print("[YAMNet] hub.load failed → fallback to KerasLayer:", repr(e1))
        layer = hub.KerasLayer(YAM_URL, trainable=False)

        def infer(y):
            wave = tf.convert_to_tensor(y, tf.float32)
            try:
                # Works directly in some environments.
                return layer(wave)
            except Exception:
                # Otherwise retry with an explicit batch dimension (1, N).
                return layer(tf.expand_dims(wave, 0))

        infer(probe)
        print("[YAMNet] backend=KerasLayer")
        return infer

def _extract_embeddings_from_output(out):
    """Pull the frame-embedding tensor out of a raw YAMNet output.

    Supported shapes of ``out``:
      * list/tuple with at least two items: the second item is used;
      * dict: key "embeddings" or "embedding", searched at the top level and
        then one level down in nested dicts.

    Returns:
        A rank-2 tensor (frames x features) when the input had rank 1, 2, or
        rank 3 with a leading dimension of 1; None when no embedding entry is
        found.
    """
    def _lookup(mapping):
        # Explicit None checks: `a.get(..) or b.get(..)` would call bool() on a
        # tensor, which raises for anything but a scalar.
        for key in ("embeddings", "embedding"):
            value = mapping.get(key)
            if value is not None:
                return value
        return None

    emb = None
    if isinstance(out, (list, tuple)):
        if len(out) >= 2:
            emb = out[1]
    elif isinstance(out, dict):
        emb = _lookup(out)
        if emb is None:
            for v in out.values():
                if isinstance(v, dict):
                    emb = _lookup(v)
                    if emb is not None:
                        break
    if emb is None:
        return None

    emb = tf.convert_to_tensor(emb)
    # Drop a singleton batch dimension, and lift a single vector to one frame.
    if emb.shape.rank == 3 and emb.shape[0] == 1:
        emb = tf.squeeze(emb, axis=0)
    if emb.shape.rank == 1:
        emb = tf.expand_dims(emb, 0)
    return emb  # (T, 1024)

def yamnet_embed(infer, y, pooling="meanstd"):
    """Embed one waveform and pool the frame embeddings into a single vector.

    Args:
        infer: callable returning raw YAMNet outputs for a waveform.
        y: waveform, or None.
        pooling: "mean" (mean over frames) or "meanstd" (mean and standard
            deviation over frames, concatenated).

    Returns:
        1-D float32 NumPy vector, or None when ``y`` is None, no usable
        embedding is produced, or any error occurs (the error is printed;
        this includes an invalid ``pooling`` value).
    """
    if y is None:
        return None
    try:
        frames = _extract_embeddings_from_output(infer(y))
        if frames is None or frames.shape.rank != 2 or int(frames.shape[0]) == 0:
            return None
        if pooling not in ("mean", "meanstd"):
            raise ValueError("pooling must be 'mean' or 'meanstd'")
        pooled = tf.reduce_mean(frames, axis=0)
        if pooling == "meanstd":
            spread = tf.math.reduce_std(frames, axis=0)
            pooled = tf.concat([pooled, spread], axis=0)
        return pooled.numpy().astype(np.float32)
    except Exception as e:
        print(f"ERROR: Failed to embed waveform - {e}")
        return None

def embed_many(infos, yam_infer, cfg, pooling="mean", aug=None, cache_key=None, show_every=5000):
    """Embed every segment in ``infos``, optionally using an on-disk cache.

    Args:
        infos: list of ``(path, start_sec, samplerate)`` segment descriptors.
        yam_infer: callable returned by ``make_yamnet_infer``.
        cfg: config dict; reads ``seg_dur`` and ``cache_emb``.
        pooling: pooling mode passed to ``yamnet_embed``.
        aug: augmentation preset passed to ``augment_wave`` (None = off).
        cache_key: when set (and caching is on) results are stored in and read
            from ``cache/emb_<cache_key>.npz``.
        show_every: print progress every this many segments.

    Returns:
        ``(X, keep)``: ``X`` holds one float32 row per successfully embedded
        segment and ``keep`` the int64 indices into ``infos`` of those rows.

    Note:
        A cache whose indices cannot belong to the current ``infos`` is
        discarded and rebuilt. A stale cache for a different segment list of
        the same size cannot be detected here; use distinct cache keys.
    """
    cache_path = None
    if cfg.get("cache_emb", True) and cache_key:
        cache_path = os.path.join("cache", f"emb_{cache_key}.npz")
        if os.path.exists(cache_path):
            try:
                # The context manager closes the archive; plain numeric arrays
                # need no pickle support.
                with np.load(cache_path, allow_pickle=False) as z:
                    X_cached = z["X"]
                    keep_cached = z["keep"]
                if len(X_cached) != len(keep_cached):
                    raise ValueError("X/keep length mismatch")
                if keep_cached.size and (keep_cached.min() < 0 or keep_cached.max() >= len(infos)):
                    raise ValueError("cached indices do not match the current segment list")
                print(f" - 캐시 로드: {cache_path} | X:{X_cached.shape} | keep:{keep_cached.shape}")
                return X_cached, keep_cached
            except Exception as e:
                print(f"WARN: 캐시 로드 실패 {cache_path} - {e}. 재생성합니다.")
                if os.path.exists(cache_path): os.remove(cache_path)

    X = []; keep = []
    total = len(infos)
    for idx, info in enumerate(infos):
        y = load_segment(info, cfg["seg_dur"], YAM_SR, rms_norm=True)
        if aug:
            y = augment_wave(y, YAM_SR, kind=aug)
        e = yamnet_embed(yam_infer, y, pooling=pooling)
        if e is not None:
            X.append(e); keep.append(idx)
        if (idx + 1) % show_every == 0:
            print(f"  ... {idx + 1}/{total} | {mem()}")

    X = np.asarray(X, np.float32)
    keep = np.array(keep, np.int64)

    if X.size == 0:
        print(f"ERROR: Failed to generate any embeddings for {len(infos)} segments.")

    # Never cache an empty result.
    if cache_path is not None and X.size > 0:
        try:
            np.savez_compressed(cache_path, X=X, keep=keep)
            print(f" - 캐시 저장: {cache_path}")
        except Exception as e:
            print(f"WARN: 캐시 저장 실패 {cache_path} - {e}")

    return X, keep

# ======================================================================
# 6) 분할(가능하면 그룹-계층)
# ======================================================================

def stratified_group_split(y, groups, test_size=0.2, seed=SEED):
    """Split sample indices into train/test without splitting any group.

    The first fold of a shuffled ``StratifiedGroupKFold`` is used when that
    works; otherwise a single ``GroupShuffleSplit`` is used. With the k-fold
    path the test share is roughly ``1 / n_splits`` rather than exactly
    ``test_size``.

    Returns:
        ``(train_idx, test_idx, method)`` where ``method`` names the splitter.

    Raises:
        RuntimeError: if fewer than two samples are given.
    """
    n = len(y)
    if n < 2:
        raise RuntimeError("[데이터 부족] 세그먼트가 2개 미만입니다.")

    n_groups = len(np.unique(groups))
    class_counts = np.bincount(y)
    smallest_class = int(class_counts[class_counts > 0].min())
    # Aim for about 1/test_size folds, limited by the number of groups and by
    # the rarest class, and never fewer than two.
    wanted = max(2, int(round(1.0 / max(1e-9, test_size))))
    n_splits = max(2, min(wanted, n_groups, max(2, smallest_class)))

    try:
        from sklearn.model_selection import StratifiedGroupKFold
        splitter = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=seed)
        tr_idx, te_idx = next(splitter.split(np.zeros(n), y, groups))
        method = f"StratifiedGroupKFold(n_splits={n_splits})"
    except Exception:
        splitter = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
        tr_idx, te_idx = next(splitter.split(np.arange(n), y, groups))
        method = "GroupShuffleSplit"
    return tr_idx, te_idx, method

# ======================================================================
# 7) (A) 임베딩 기반 분류기(MLP/LogReg/SVM)
# ======================================================================

def build_mlp(input_dim, num_classes, lr):
    """Build and compile the MLP head used on pooled embeddings.

    Architecture: BatchNorm, Dense(512, relu) + Dropout(0.5),
    Dense(128, relu) + Dropout(0.4), then a softmax output layer. Both hidden
    Dense layers use L2 (1e-4) kernel regularisation. Compiled with Adam and
    categorical cross-entropy, so labels must be one-hot.
    """
    layers = tf.keras.layers
    l2_reg = tf.keras.regularizers.l2(1e-4)

    inputs = tf.keras.Input(shape=(input_dim,), name="emb")
    hidden = layers.BatchNormalization()(inputs)
    for units, drop_rate in ((512, 0.5), (128, 0.4)):
        hidden = layers.Dense(units, activation='relu', kernel_regularizer=l2_reg)(hidden)
        hidden = layers.Dropout(drop_rate)(hidden)
    outputs = layers.Dense(num_classes, activation='softmax')(hidden)

    model = tf.keras.Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

def train_eval_emb(version, Xtr, ytr, Xte, yte, classes, cfg):
    """Train one classifier on pooled embeddings and score it on the test split.

    Args:
        version: dict with "name" and "classifier" ("mlp", "logreg"; any other
            value selects an RBF SVC).
        Xtr, Xte: 2-D embedding matrices (rows = segments).
        ytr, yte: integer-encoded labels aligned with the rows.
        classes: class names; position = encoded label.
        cfg: config dict; reads "lr", "epochs", "batch" for the MLP and "topk"
            for the top-k accuracy (default 1 when absent).

    Returns:
        dict with "acc", "bal_acc", "macroF1", "macroROC", "topk",
        "ap_per_class", "cm", "time_sec", "artifact" (and "scaler" for the
        scikit-learn classifiers). With exactly two classes "macroROC" is the
        ROC-AUC of the positive class and "topk" is NaN.

    Raises:
        RuntimeError: if either embedding matrix is empty or not 2-D.
    """
    res = {}
    if Xtr.size == 0 or Xte.size == 0 or Xtr.ndim != 2 or Xte.ndim != 2:
        raise RuntimeError("[임베딩 실패] Xtr/Xte가 비어있거나 차원이 잘못되었습니다.")
    n_classes = len(classes)
    if version["classifier"] == "mlp":
        clf = build_mlp(Xtr.shape[-1], n_classes, lr=cfg["lr"])
        callbacks = [
            tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True, monitor='val_loss'),
            tf.keras.callbacks.ReduceLROnPlateau(patience=4, factor=0.5, min_lr=1e-6),
        ]
        ytr_cat = tf.keras.utils.to_categorical(ytr, num_classes=n_classes)
        yte_cat = tf.keras.utils.to_categorical(yte, num_classes=n_classes)
        # Inverse-frequency class weights; keys are converted to plain ints.
        cnt_tr = Counter(ytr); total = sum(cnt_tr.values())
        class_weight = {int(cls): total / (len(cnt_tr) * cnt) for cls, cnt in cnt_tr.items()}
        t0 = time.time()
        # NOTE(review): the test split doubles as validation data for early
        # stopping / LR scheduling, so the reported scores are optimistic.
        # A separate group-aware validation split would be needed to fix this.
        clf.fit(Xtr, ytr_cat, validation_data=(Xte, yte_cat),
                epochs=cfg["epochs"], batch_size=cfg["batch"], verbose=0,
                class_weight=class_weight, callbacks=callbacks)
        probs = clf.predict(Xte, verbose=0); pred = probs.argmax(1)
        model_path = f"artifacts/{version['name']}_mlp.keras"; clf.save(model_path)
        res["artifact"] = model_path; res["time_sec"] = time.time() - t0
    else:
        from sklearn.linear_model import LogisticRegression
        from sklearn.svm import SVC
        scaler = StandardScaler().fit(Xtr)
        Xtr_s = scaler.transform(Xtr); Xte_s = scaler.transform(Xte)
        t0 = time.time()
        if version["classifier"] == "logreg":
            clf = LogisticRegression(max_iter=2000, class_weight="balanced", n_jobs=-1)
        else:
            clf = SVC(C=2.0, kernel='rbf', probability=True, class_weight='balanced')
        clf.fit(Xtr_s, ytr); probs = clf.predict_proba(Xte_s); pred = probs.argmax(1)
        res["time_sec"] = time.time() - t0
        import joblib
        model_path = f"artifacts/{version['name']}_{version['classifier']}.joblib"
        scaler_path = f"artifacts/{version['name']}_scaler.joblib"
        joblib.dump(clf, model_path); joblib.dump(scaler, scaler_path)
        res["artifact"] = model_path; res["scaler"] = scaler_path

    true = yte
    res["acc"] = accuracy_score(true, pred)
    res["bal_acc"] = balanced_accuracy_score(true, pred)
    res["macroF1"] = f1_score(true, pred, average='macro')
    try:
        yte_cat = tf.keras.utils.to_categorical(true, num_classes=n_classes)
        res["macroROC"] = roc_auc_score(yte_cat, probs, average='macro', multi_class='ovr')
    except Exception:
        res["macroROC"] = np.nan
    try:
        # Honour the cfg passed by the caller instead of the module-level config.
        res["topk"] = top_k_accuracy_score(true, probs, k=cfg.get("topk", 1), labels=range(n_classes))
    except Exception:
        res["topk"] = np.nan
    # One-vs-rest average precision per class; NaN when a class is absent from
    # (or is the only class in) the test labels.
    ap = {}
    for i, lab in enumerate(classes):
        y_bin = (true == i).astype(int)
        if 0 < y_bin.sum() < len(y_bin):
            ap[lab] = float(average_precision_score(y_bin, probs[:, i]))
        else:
            ap[lab] = float("nan")
    res["ap_per_class"] = ap
    res["cm"] = confusion_matrix(true, pred)
    # Binary metrics override (if applicable)
    if n_classes == 2:
        try:
            pos_idx = classes.index(POS_LABEL) if POS_LABEL in classes else 1
            res["macroROC"] = roc_auc_score(true, probs[:, pos_idx])
        except Exception:
            res["macroROC"] = np.nan
        res["topk"] = np.nan
    return res

# ======================================================================
# 8) (B) 부분 파인튜닝 — RAM‑safe tf.map_fn (returns fixed D per sample)
# ======================================================================

def make_wave_ds(infos, y_enc, cfg, batch, shuffle=False, aug=None):
    """Build a batched ``tf.data.Dataset`` of (waveform, label) pairs.

    Segments are loaded lazily by a Python generator, optionally augmented,
    and zero-padded or truncated to exactly ``YAM_SR * cfg["seg_dur"]``
    samples. Segments that fail to load are skipped.

    Args:
        infos: segment descriptors accepted by ``load_segment``.
        y_enc: integer labels aligned with ``infos``.
        cfg: config dict; reads ``seg_dur``.
        batch: batch size.
        shuffle: shuffle with a buffer of at most 8000 elements.
        aug: augmentation preset for ``augment_wave`` (None = off).
    """
    n_samples = int(YAM_SR * cfg["seg_dur"])

    def gen():
        for info, label in zip(infos, y_enc):
            wav = load_segment(info, cfg["seg_dur"], YAM_SR, rms_norm=True)
            if aug:
                wav = augment_wave(wav, YAM_SR, kind=aug)
            if wav is None:
                continue
            # Force a fixed length: cut the excess, then zero-pad the shortfall.
            wav = wav[:n_samples]
            shortfall = n_samples - len(wav)
            if shortfall > 0:
                wav = np.concatenate([wav, np.zeros(shortfall, dtype=np.float32)])
            yield wav.astype(np.float32), np.int32(label)

    signature = (
        tf.TensorSpec(shape=(n_samples,), dtype=tf.float32),
        tf.TensorSpec(shape=(), dtype=tf.int32),
    )
    ds = tf.data.Dataset.from_generator(gen, output_signature=signature)
    if shuffle:
        ds = ds.shuffle(buffer_size=min(8000, len(infos)))
    # Memory safety: keep the prefetch depth small.
    ds = ds.batch(batch, drop_remainder=False)
    return ds.prefetch(2)

class YamnetEmbeddingLayer(tf.keras.layers.Layer):
    """Keras layer mapping a batch of waveforms (B, N) to pooled YAMNet vectors (B, D).

    Each waveform is embedded and pooled on its own inside ``tf.map_fn``, so
    only one pooled vector per sample (1024 for "mean", 2048 for "meanstd") is
    stacked; a (B, T, 1024) tensor is never built. ``compute_output_shape`` is
    defined so Keras knows the output shape.
    """

    # Width of one YAMNet frame embedding.
    EMB_DIM = 1024

    def __init__(self, yamnet_url, pooling="mean", **kwargs):
        super().__init__(**kwargs)
        if pooling not in ["mean", "meanstd"]:
            raise ValueError("pooling must be 'mean' or 'meanstd'")
        self.yamnet_url = yamnet_url
        self.pooling = pooling
        # YAMNet from TF‑Hub is effectively non‑trainable (frozen graph); keep False to avoid warnings
        self.yamnet_layer = hub.KerasLayer(self.yamnet_url, trainable=False, name="yamnet_base")
        self.out_dim = self.EMB_DIM if pooling == "mean" else 2 * self.EMB_DIM

    @staticmethod
    def _select_embeddings(out):
        """Pick the embedding tensor out of a raw YAMNet output (tuple/list or dict)."""
        if isinstance(out, (list, tuple)) and len(out) > 1:
            return out[1]  # (T,1024) expected
        if isinstance(out, dict):
            # Explicit None checks: `a or b` would evaluate a tensor's truth
            # value, which is an error for non-scalar tensors.
            emb = out.get("embeddings")
            if emb is None:
                emb = out.get("embedding")
            if emb is None and len(out) > 0:
                # fallback: take the last value
                emb = list(out.values())[-1]
            if emb is not None:
                return emb
        raise RuntimeError("Unexpected YAMNet output format")

    def _pool_single(self, waveform):
        """Embed one waveform (N,) and pool its frames to a fixed-size vector."""
        out = self.yamnet_layer(waveform)
        emb = tf.cast(tf.convert_to_tensor(self._select_embeddings(out)), tf.float32)
        # Normalise to (T, 1024). This also folds away a stray leading batch
        # dimension of 1 without a rank-dependent tf.cond, whose two branches
        # would return tensors of different rank.
        emb = tf.reshape(emb, [-1, self.EMB_DIM])
        m = tf.reduce_mean(emb, axis=0)  # (1024,)
        if self.pooling == "mean":
            return m
        s = tf.math.reduce_std(emb, axis=0)  # (1024,)
        return tf.concat([m, s], axis=0)     # (2048,)

    def call(self, inputs):
        # inputs: (B, N). map_fn handles a dynamic batch size; appending to a
        # Python list inside a `tf.range` loop does not work once the loop is
        # compiled into a graph.
        return tf.map_fn(
            self._pool_single,
            inputs,
            fn_output_signature=tf.TensorSpec(shape=(self.out_dim,), dtype=tf.float32),
        )

    def compute_output_shape(self, input_shape):
        # input_shape = (B, N)
        return (input_shape[0], self.out_dim)

    def get_config(self):
        config = super().get_config()
        config.update({"yamnet_url": self.yamnet_url, "pooling": self.pooling})
        return config


def build_yamnet_ft_model(num_classes, pooling="meanstd", lr=3e-4):
    """Build and compile the waveform-in model: YAMNet embedding layer + MLP head.

    The input length is ``YAM_SR * BASE_CONFIG["seg_dur"]`` samples. The head
    is BatchNorm, Dense(512) + Dropout(0.5), Dense(128) + Dropout(0.4), then
    softmax. Compiled with Adam and sparse categorical cross-entropy, so labels
    are integer class ids.
    """
    layers = tf.keras.layers
    n_samples = int(YAM_SR * BASE_CONFIG["seg_dur"])

    wave_in = tf.keras.Input(shape=(n_samples,), dtype=tf.float32, name="wave")
    pooled = YamnetEmbeddingLayer(yamnet_url=YAM_URL, pooling=pooling, name="yamnet_embedding")(wave_in)  # (B, D)

    hidden = layers.BatchNormalization()(pooled)
    for units, drop_rate in ((512, 0.5), (128, 0.4)):
        hidden = layers.Dense(
            units, activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(1e-4),
        )(hidden)
        hidden = layers.Dropout(drop_rate)(hidden)
    out = layers.Dense(num_classes, activation='softmax')(hidden)

    model = tf.keras.Model(wave_in, out)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'],
    )
    return model


def train_eval_ft(version, Xtr_infos, ytr, Xte_infos, yte, classes, yam_infer):
    """Run an 'ft' version as head-only training on frozen YAMNet embeddings.

    Rationale (from the original author): TF-Hub yamnet/1 exposes no trainable
    variables, so the base model cannot be fine-tuned here. Train and test
    segments are embedded (train with the version's augmentation, test
    without), then the MLP head is trained and evaluated via
    ``train_eval_emb``.

    Raises:
        RuntimeError: if ``yam_infer`` is None.
    """
    if yam_infer is None:
        raise RuntimeError("YAMNet infer not initialized.")

    pooling = version.get("pooling", "meanstd")
    aug = version.get("aug", None)

    def _embed_split(banner, split_infos, split_labels, split_aug, key):
        # Embed one split and keep only labels of segments that were embedded.
        print(banner, end="")
        feats, kept = embed_many(split_infos, yam_infer, BASE_CONFIG, pooling=pooling,
                                 aug=split_aug, cache_key=key)
        kept_labels = split_labels[kept]
        print(" OK", feats.shape, "|", mem())
        return feats, kept_labels

    Xtr, ytr_v = _embed_split(" - (Head-only) 임베딩(Train) 생성 중 ...", Xtr_infos, ytr, aug,
                              f"{version['name']}_{pooling}_ft_tr")
    Xte, yte_v = _embed_split(" - (Head-only) 임베딩(Test) 생성 중 ...", Xte_infos, yte, None,
                              f"{version['name']}_{pooling}_ft_te")

    # Reuse the MLP-head train/eval path on a copy, so the caller's version
    # dict is left untouched.
    head_version = {**version, "classifier": "mlp", "type": "ft"}
    return train_eval_emb(head_version, Xtr, ytr_v, Xte, yte_v, classes, cfg=BASE_CONFIG)

# ======================================================================
# 9) 파이프라인 실행 (한 번 분할 → 모든 버전 공통 비교)
# ======================================================================

def _embed_split(message, infos, y, yam_infer, config, pooling, aug, cache_key):
    """Embed one split with progress output; return (X, labels of kept segments)."""
    print(message, end="")
    X, keep = embed_many(infos, yam_infer, config, pooling=pooling, aug=aug, cache_key=cache_key)
    y_kept = y[keep]
    print(" OK", X.shape, "|", mem())
    return X, y_kept


def _save_version_outputs(name, res, classes, out_dir="results"):
    """Write the confusion-matrix PNG and the per-class AP JSON for one version."""
    plt.figure(figsize=(5.5, 4.8))
    sns.heatmap(res["cm"], annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
    plt.xlabel("예측"); plt.ylabel("실제"); plt.title(f"CM — {name}")
    plt.tight_layout()
    plt.savefig(f"{out_dir}/cm_{name}.png", dpi=150)
    plt.close()

    with open(f"{out_dir}/ap_{name}.json", "w", encoding="utf-8") as f:
        json.dump(res["ap_per_class"], f, indent=2)


def _write_report_md(df_sorted, path="results/report.md"):
    """Write the summary table as Markdown, one table row per line."""
    def fmt4(value):
        # Missing metrics (None/NaN) are rendered as "nan".
        return f"{(np.nan if pd.isna(value) else value):.4f}"

    lines = [
        "# Ship vs Noise — V5~V8 비교 요약",
        "",
        "|version|type|pooling|classifier|aug|acc|bal_acc|macroF1|macroROC|topk|time_sec|artifact|",
        "|---|---|---|---|---|---:|---:|---:|---:|---:|---:|---|",
    ]
    for _, row in df_sorted.iterrows():
        lines.append(
            f"|{row['version']}|{row['type']}|{row['pooling']}|{row['classifier']}|{row['aug']}|"
            f"{row['acc']:.4f}|{row['bal_acc']:.4f}|{row['macroF1']:.4f}|{fmt4(row['macroROC'])}|"
            f"{fmt4(row['topk'])}|{row['time_sec']:.1f}|{row['artifact']}|"
        )
    lines += ["", "- 혼동행렬: results/cm_*.png", "- AP per class: results/ap_*.json"]

    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")


def run_all(config=BASE_CONFIG, versions=VERSIONS):
    """Single split → compare across versions (V5~V8), head-only for FT.

    - V5 uses embedding+MLP.
    - V6~V8 run as head-only on frozen YAMNet embeddings (same pipeline as V5),
      differing by pooling/augs/schedules keyed via cache.

    Args:
        config: Pipeline configuration dict; reads ``seg_dur`` and
            ``test_size`` here and is forwarded to the helpers.
        versions: Iterable of version spec dicts. Each needs ``name`` and
            ``type`` ("emb" or "ft"); ``pooling``, ``aug`` and ``classifier``
            are optional.

    Returns:
        None. Outputs are written under ``results/`` (summary.csv, report.md,
        cm_<name>.png, ap_<name>.json) and printed to stdout.

    Raises:
        RuntimeError: If fewer than two segments/classes are found, or a
            version produces empty embeddings.
    """
    # Build segments.
    print("[STEP] 세그먼트 생성 중 ...")
    infos, labels, groups, summary, missing = build_segments_ships_ear(SHIPSEAR, config)
    print(" - 클래스별 개수:", dict(summary), "| 매핑 실패:", missing)

    if len(infos) < 2 or len(set(labels)) < 2:
        raise RuntimeError(
            "[데이터 부족] 세그먼트 수가 너무 적거나 클래스가 2종 미만입니다."
            " - SHIPSEAR_DRIVE 경로와 하위 폴더/파일명을 재확인하세요."
            "- 라벨 매핑 규칙(resolve_ships_ear_class)과 실제 폴더명이 맞는지 점검하세요."
        )

    n_files = len({i[0] for i in infos})
    n_groups = len(set(groups))
    total_h = (len(infos) * config["seg_dur"]) / 3600.0
    print(f" - 세그:{len(infos)} | 파일≈{n_files} | 그룹≈{n_groups} | 총길이≈{total_h:.2f} h")

    # Encode labels; in binary mode classes A-D collapse to "Ship", the rest to "Noise".
    le = LabelEncoder()
    if BINARY_MODE:
        ship_classes = {"A", "B", "C", "D"}
        y_all = le.fit_transform(["Ship" if l in ship_classes else "Noise" for l in labels])
    else:
        y_all = le.fit_transform(labels)
    classes = list(le.classes_)
    g_arr = np.array(groups)

    # One group-aware split shared by every version so results are comparable.
    tr_idx, te_idx, method = stratified_group_split(y_all, g_arr, config["test_size"])
    print(f"[Split] method={method} | train={len(tr_idx)} | test={len(te_idx)} | 그룹수 train/test={len(set(g_arr[tr_idx]))}/{len(set(g_arr[te_idx]))}")

    Xtr_infos = [infos[i] for i in tr_idx]; ytr = y_all[tr_idx]
    Xte_infos = [infos[i] for i in te_idx]; yte = y_all[te_idx]

    # Both "emb" and "ft" (head-only) versions need the YAMNet inference object.
    yam_infer = None
    if any(v["type"] in ["emb", "ft"] for v in versions):
        print("YAMNet infer 준비 중 ...", end="")
        yam_infer = make_yamnet_infer()
        try:
            # Smoke-test both pooling modes on low-amplitude noise.
            test_wav = (np.random.randn(YAM_SR).astype(np.float32) * 1e-3)
            feat_mean = yamnet_embed(yam_infer, test_wav, pooling="mean")
            feat_ms   = yamnet_embed(yam_infer, test_wav, pooling="meanstd")
            print(" [Sanity] mean:", (None if feat_mean is None else feat_mean.shape),
                  "| meanstd:", (None if feat_ms is None else feat_ms.shape))
        except Exception as e:
            # Best effort: disable embedding-based versions instead of aborting.
            print(f" [Sanity Failed] - {e}")
            yam_infer = None

    # Make sure the output directory exists before any file is written.
    os.makedirs("results", exist_ok=True)

    rows = []
    for v in versions:
        print(f"================= {v['name']} =================")
        pooling = v.get("pooling", "mean")
        vtype = v["type"]

        if vtype not in ("emb", "ft"):
            print(f" - 알 수 없는 버전 타입: {v['type']}. 건너뜀.")
            continue
        if yam_infer is None:
            kind_label = "임베딩 기반" if vtype == "emb" else "FT"
            print(f"YAMNet infer 문제로 {kind_label} 버전 {v['name']} 건너뜀.")
            continue

        aug = v.get("aug", None)
        if vtype == "emb":
            cache_key = f"{v['name']}_{pooling}_aug{aug}_{config['seg_dur']}s"
            train_msg = " - 임베딩(Train) 중 ..."
            test_msg = " - 임베딩(Test) 중 ..."
            v_run = v
        else:
            print(" - (Head-only) FT 버전 실행 중 ... |", mem())
            # Head-only: same embeddings, then train an MLP head.
            cache_key = f"{v['name']}_{pooling}_ft_{config['seg_dur']}s"
            train_msg = " - (Head-only) 임베딩(Train) 생성 중 ..."
            test_msg = " - (Head-only) 임베딩(Test) 생성 중 ..."
            v_run = dict(v)  # copy so the caller's spec is not mutated
            v_run["classifier"] = "mlp"
            v_run["type"] = "ft"

        # Augmentation applies to the training split only.
        Xtr, ytr_v = _embed_split(train_msg, Xtr_infos, ytr, yam_infer, config,
                                  pooling, aug, cache_key + "_tr")
        Xte, yte_v = _embed_split(test_msg, Xte_infos, yte, yam_infer, config,
                                  pooling, None, cache_key + "_te")

        if Xtr.size == 0 or Xte.size == 0:
            raise RuntimeError(f"[{v['name']}] 임베딩 생성 실패")

        res = train_eval_emb(v_run, Xtr, ytr_v, Xte, yte_v, classes, cfg=config)

        # Common: collect the summary row and persist per-version outputs.
        rows.append(dict(
            version=v['name'], type=v['type'],
            pooling=v.get('pooling', '-'),
            classifier=(v.get('classifier', '-') if v['type'] == 'emb' else 'mlp'),
            aug=(v.get('aug') or "none"),
            acc=res["acc"], bal_acc=res["bal_acc"], macroF1=res["macroF1"],
            macroROC=res["macroROC"], topk=res["topk"],
            time_sec=res["time_sec"], artifact=res.get("artifact", "")
        ))
        _save_version_outputs(v['name'], res, classes)

    # Summary table.
    df = pd.DataFrame(rows)
    if not df.empty:
        df_sorted = df.sort_values(["macroF1", "bal_acc", "acc"], ascending=False)
        df_sorted.to_csv("results/summary.csv", index=False)
        print("[SUMMARY — V5~V8]")
        print(df_sorted.to_string(index=False))
        _write_report_md(df_sorted, "results/report.md")
    else:
        print("No results to summarize (check data path & pipeline).")

    print("결과 파일:")
    print(" - results/summary.csv")
    print(" - results/report.md")
    print(" - results/cm_*.png")
    print(" - results/ap_*.json")
    print(" - artifacts/* (모델)")

# Entry point: run the full V5~V8 comparison with the module-level defaults.
run_all(config=BASE_CONFIG, versions=VERSIONS)
print("\n🎉 완료 — 버전별 결과는 results/summary.csv, results/report.md, cm_*.png, artifacts/* 에 저장됩니다.")


1) 환경설정/설치 중 ...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Drive mounted.


 - GPU found: 1 | memory growth enabled
 - 폰트 OK: NanumGothic

2) 데이터 확보 중 ...
 - ShipsEar 이미 존재
[STEP] 세그먼트 생성 중 ...
 - 클래스별 개수: {'C': 5101, 'B': 3139, 'E': 1140, 'A': 1856, 'D': 1513} | 매핑 실패: 0
 - 세그:12749 | 파일≈85 | 그룹≈85 | 총길이≈3.54 h
[Split] method=StratifiedGroupKFold(n_splits=5) | train=9789 | test=2960 | 그룹수 train/test=69/16
YAMNet infer 준비 중 ...[YAMNet] backend=hub.load
 [Sanity] mean: (1024,) | meanstd: (2048,)
 - 임베딩(Train) 중 ...  ... 5000/9789 | RSS≈3.92 GB
 - 캐시 저장: cache/emb_v5_meanstd_mlp_aug_meanstd_auglight_1.0s_tr.npz
 OK (9789, 2048) | RSS≈4.04 GB
 - 임베딩(Test) 중 ... - 캐시 저장: cache/emb_v5_meanstd_mlp_aug_meanstd_auglight_1.0s_te.npz
 OK (2960, 2048) | RSS≈4.04 GB
 - (Head-only) FT 버전 실행 중 ... | RSS≈4.24 GB
 - (Head-only) 임베딩(Train) 생성 중 ...  ... 5000/9789 | RSS≈4.26 GB
 - 캐시 저장: cache/emb_v6_ft_mean_headwarmu