# In [None]:
# -*- coding: utf-8 -*-
"""
IVDD binary grading (2-class: ivdd / normal) from DeepLabCut CSV (3-level header)
- 5 keypoints × (x,y) = 10 dims
- Windowing: SEQ_LEN=60, STRIDE=30
- Model: TimeDistributed(Dense->ReLU) -> LSTM -> LSTM -> Dense(2 logits)
- ラベルは「ファイル名」に 'ivdd' または 'normal' を含むことで自動判定
"""

import os
import glob
import math
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt

# ========= Settings =========
DATA_DIR   = r"./data_ivdd"                 # folder containing the DLC CSV files
CSV_GLOB   = os.path.join(DATA_DIR, "*.csv")

# The 5 keypoints to use (user-specified). They are matched against the CSV
# bodypart names after lower-casing and removing spaces/underscores/hyphens.
KEYPOINTS = [
    "left back paw",
    "right back paw",
    "left front paw",
    "right front paw",
    "tail set",
]

# If True, (x, y) values whose likelihood is below the threshold are set to
# NaN and later filled by interpolation.
USE_LIKELIHOOD = False
MIN_KEEP_LIKELIHOOD = 0.6

SEQ_LEN = 60                   # frames per window
STRIDE  = 30                   # frames between consecutive window starts
# Feature width = (x, y) per keypoint. Derived from KEYPOINTS so that editing
# the keypoint list cannot silently disagree with the width check done when
# building the dataset.
DIMS    = 2 * len(KEYPOINTS)
BATCH_SIZE = 32
EPOCHS     = 50
LR         = 1e-3
VAL_SPLIT_BY_FILE = True       # split train/val per file so one recording's windows never leak across splits

# ---- 2-class definition ----
CLASS_NAMES  = ["ivdd", "normal"]
CLASS_TO_IDX = {c: i for i, c in enumerate(CLASS_NAMES)}
N_CLASSES    = len(CLASS_NAMES)

# Seed the global NumPy and TensorFlow RNGs to make runs more reproducible.
np.random.seed(42)
tf.random.set_seed(42)

# ========= ラベル判定（ファイル名から） =========
def infer_label_from_filename(filename: str) -> int:
    """Return the class index encoded in a CSV file name.

    The directory part is dropped and the remaining base name (extension
    included) is lower-cased and searched for the substrings "ivdd" and
    "normal". Exactly one of the two must occur.

    NOTE(review): this is a plain substring test, so e.g. "abnormal_01.csv"
    would count as "normal" -- confirm the naming convention rules that out.

    Raises:
        ValueError: if neither or both substrings occur.
    """
    name = os.path.basename(filename).lower()
    found = [label for label in ("ivdd", "normal") if label in name]
    if len(found) != 1:
        raise ValueError(f"ラベルを特定できません: {filename}（ivdd/normal が含まれている必要があります）")
    return CLASS_TO_IDX[found[0]]

# ========= 入出力ユーティリティ =========
def _norm_name(s: str) -> str:
    # 小文字化し、空白/アンダースコア/ハイフンを除去して比較
    return "".join(ch for ch in s.lower() if ch not in " _-")

def _resolve_keypoints(all_bodyparts: list[str], requested: list[str]) -> list[str]:
    """Map requested keypoint names onto bodypart names that exist in the CSV.

    Both sides are compared through ``_norm_name``, so differences in case,
    spaces, underscores and hyphens are tolerated. If several CSV bodyparts
    normalize to the same key, the one appearing first in *all_bodyparts*
    is used.

    Returns:
        The matching CSV bodypart names, in the order of *requested*.

    Raises:
        ValueError: if any requested name has no match; the message lists the
            missing names and every available bodypart.
    """
    lookup: dict[str, str] = {}
    for bodypart in all_bodyparts:
        # setdefault keeps the first bodypart seen for each normalized key.
        lookup.setdefault(_norm_name(bodypart), bodypart)

    missing = [name for name in requested if _norm_name(name) not in lookup]
    if missing:
        raise ValueError(f"指定キーポイントがCSVで見つかりません: {missing}\n利用可能: {all_bodyparts}")
    return [lookup[_norm_name(name)] for name in requested]

def read_dlc_5kp_xy(csv_path: str,
                    keypoints: list[str],
                    use_likelihood=True,
                    min_keep_likelihood=0.6) -> tuple[np.ndarray, list[str]]:
    """Load a DeepLabCut-style CSV with a 3-row header and return the requested (x, y) columns.

    The three header rows become a column MultiIndex. Columns are selected by
    levels 1 and 2 as (bodypart, coordinate); level 0 (presumably the DLC
    scorer name) is not used for selection. The first CSV column is the index.

    Args:
        csv_path: path of the CSV file.
        keypoints: requested bodypart names; spelling differences in case,
            spaces, underscores and hyphens are resolved by ``_resolve_keypoints``.
        use_likelihood: if True, (x, y) values whose "likelihood" column is
            below *min_keep_likelihood* are set to NaN before gap filling.
            Bodyparts without a likelihood column are left untouched.
        min_keep_likelihood: threshold for the likelihood filter.

    Returns:
        ``(X, used_keypoints)``. X is float32 with columns ordered
        [kp0_x, kp0_y, kp1_x, kp1_y, ...], i.e. shape
        (n_frames, 2 * len(keypoints)) when each (bodypart, coordinate) pair
        maps to a single column. used_keypoints are the resolved CSV names.

    Gaps (NaN) are filled by linear interpolation, then backward/forward fill,
    and finally 0.0 for columns that contain no valid value at all.
    """
    df = pd.read_csv(csv_path, header=[0, 1, 2], index_col=0)
    # Keep the CSV's column order (dict keys are insertion-ordered). A set
    # would make the order -- and with it _resolve_keypoints' "first match
    # wins" tie-breaking and its error message -- depend on string hashing.
    bodyparts = list(dict.fromkeys(bp for _, bp, _ in df.columns))
    use_kps = _resolve_keypoints(bodyparts, keypoints)

    # Collect the (x, y) columns, one pair per keypoint.
    cols = {}
    for bp in use_kps:
        cols[f"{bp}_x"] = df.xs((bp, "x"), level=[1, 2], axis=1)
        cols[f"{bp}_y"] = df.xs((bp, "y"), level=[1, 2], axis=1)
    # Float dtype so NaN can be assigned below even if a column parsed as int.
    X_df = pd.concat(cols.values(), axis=1).astype(float)
    X_df.columns = list(cols.keys())

    # Quality control: low-likelihood detections become NaN.
    if use_likelihood:
        for bp in use_kps:
            try:
                lik = df.xs((bp, "likelihood"), level=[1, 2], axis=1)
            except KeyError:
                continue  # no likelihood column for this bodypart: keep values as-is
            low = lik.to_numpy().ravel() < min_keep_likelihood
            # Assign through .loc rather than writing into ``.values``: under
            # pandas Copy-on-Write that array may be a read-only view.
            X_df.loc[low, [f"{bp}_x", f"{bp}_y"]] = np.nan

    # Linear interpolation -> backward/forward fill -> zeros.
    X_df = X_df.interpolate(method="linear", limit_direction="both", axis=0)
    # .bfill()/.ffill() replace fillna(method=...), deprecated since pandas 2.1.
    X_df = X_df.bfill().ffill().fillna(0.0)

    X = X_df.to_numpy(dtype=np.float32)
    return X, use_kps

def zscore_per_file(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Standardize every column of *X* to zero mean and unit (population) std.

    Statistics are taken over axis 0 of this array only, i.e. per file when
    called with one file's frames. *eps* is added to the std so constant
    columns do not cause a division by zero (they map to 0).
    """
    centered = X - X.mean(axis=0, keepdims=True)
    scale = X.std(axis=0, keepdims=True) + eps
    return centered / scale

def make_windows(X: np.ndarray, seq_len: int, stride: int) -> np.ndarray:
    """Cut a (frames, features) array into fixed-length, possibly overlapping windows.

    Window k covers frames ``[k*stride, k*stride + seq_len)``. Only windows
    that fit entirely are produced; a trailing partial window is dropped.

    Returns:
        A new array (not a view) of shape (n_windows, seq_len, features)
        with X's dtype. When X has fewer than *seq_len* frames it has zero
        windows.

    Raises:
        ValueError: if *seq_len* or *stride* is smaller than 1.
    """
    if seq_len < 1 or stride < 1:
        raise ValueError(f"seq_len and stride must be >= 1 (got seq_len={seq_len}, stride={stride})")
    n = X.shape[0]
    if n < seq_len:
        return np.empty((0, seq_len, X.shape[1]), dtype=X.dtype)
    starts = np.arange(0, n - seq_len + 1, stride)
    # (n_windows, seq_len) frame indices; fancy indexing gathers all windows at once.
    frame_idx = starts[:, None] + np.arange(seq_len)[None, :]
    return X[frame_idx]

def build_dataset(csv_paths, seq_len=SEQ_LEN, stride=STRIDE):
    """Turn DLC CSV files into a windowed, labelled dataset.

    For each path, in order:
      1. the label is inferred from the file name,
      2. the KEYPOINTS columns are read using the USE_LIKELIHOOD /
         MIN_KEEP_LIKELIHOOD settings,
      3. the features are z-scored per file and cut into windows of *seq_len*
         frames every *stride* frames.
    Files too short to yield a single window are skipped with a warning.

    Returns:
        X: all windows concatenated, shape (n_windows, seq_len, DIMS).
        y: int64 class index per window, shape (n_windows,).
        file_ids: base name of the source file of each window, shape (n_windows,).

    Raises:
        ValueError: e.g. when a file name has no unique label or a file
            yields a feature width other than DIMS.
        RuntimeError: when no file produced any window.
    """
    windows, labels, sources = [], [], []
    last_kps = None

    for path in csv_paths:
        fname = os.path.basename(path)
        label = infer_label_from_filename(path)  # the label comes from the file name

        feats, kps = read_dlc_5kp_xy(
            path,
            keypoints=KEYPOINTS,
            use_likelihood=USE_LIKELIHOOD,
            min_keep_likelihood=MIN_KEEP_LIKELIHOOD,
        )
        if feats.shape[1] != DIMS:
            raise ValueError(f"{fname}: 取り出し次元 {feats.shape[1]} が期待の {DIMS} と違います")
        last_kps = kps  # only the most recent file's names are reported below

        wins = make_windows(zscore_per_file(feats), seq_len, stride)
        n_win = wins.shape[0]
        if n_win == 0:
            print(f"[WARN] {fname}: フレーム不足（{seq_len}未満）でスキップ")
            continue

        windows.append(wins)
        labels.append(np.full((n_win,), label, dtype=np.int64))
        sources.extend([fname] * n_win)

    if not windows:
        raise RuntimeError("データが作れませんでした。CSV と命名規則（ivdd/normal を含む）を確認してください。")

    X = np.concatenate(windows, axis=0)
    y = np.concatenate(labels, axis=0)
    print(f"[INFO] 使用キーポイント実名: {last_kps}")
    return X, y, np.array(sources)

# ========= モデル（構造そのまま） =========
class LSTM_RNN(keras.Model):
    """Per-timestep Dense(ReLU) embedding -> LSTM -> LSTM -> linear head.

    `call` returns unnormalized class logits computed from the final state of
    the second LSTM (it does not return sequences). The input is presumably
    (batch, time, features) -- TODO confirm against callers.
    """
    def __init__(self, n_input, n_hidden, n_classes):
        """Create the layers.

        Args:
            n_input: per-timestep feature width. NOTE(review): not referenced
                by this class.
            n_hidden: units of the Dense embedding and of both LSTM layers.
            n_classes: number of output logits.
        """
        super().__init__()
        # The same Dense layer is applied independently at every timestep.
        self.input_dense = keras.layers.Dense(n_hidden, activation='relu')
        self.time_dist   = keras.layers.TimeDistributed(self.input_dense)
        self.lstm1 = keras.layers.LSTM(n_hidden, return_sequences=True)  # emits the full sequence for lstm2
        self.lstm2 = keras.layers.LSTM(n_hidden)  # emits only the last hidden state
        self.out   = keras.layers.Dense(n_classes)  # logits
        # NOTE(review): attribute names likely determine weight paths in saved
        # models/checkpoints -- verify before renaming them.

    def call(self, x, training=False):
        """Return logits of shape (batch, n_classes); `training` is forwarded to the LSTMs."""
        x = self.time_dist(x)
        x = self.lstm1(x, training=training)
        x = self.lstm2(x, training=training)
        x = self.out(x)
        return x

class LSTMWithL2(LSTM_RNN):
    """LSTM_RNN trained with a custom step that adds an L2 weight penalty.

    Loss = categorical cross-entropy on logits (one-hot targets) plus
    ``l2_lambda * sum(tf.nn.l2_loss(v))`` over all trainable variables
    (biases and recurrent kernels included). The penalty is added in
    ``test_step`` as well, so ``val_loss`` includes it.

    Batches may be ``(x, y)`` or ``(x, y, sample_weight)``; the latter is what
    Keras passes when ``fit`` receives ``class_weight`` or ``sample_weight``.
    The weights are applied to the cross-entropy term; accuracy is unweighted.
    """
    def __init__(self, n_input, n_hidden, n_classes, l2_lambda=1e-4):
        super().__init__(n_input, n_hidden, n_classes)
        self.l2_lambda = l2_lambda
        self.loss_fn   = keras.losses.CategoricalCrossentropy(from_logits=True)
        self.metric_acc  = keras.metrics.CategoricalAccuracy(name="accuracy")
        self.metric_loss = keras.metrics.Mean(name="loss")

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at the start of each
        # epoch and each evaluation. Otherwise they keep accumulating across
        # epochs, and training and validation values get mixed together.
        return [self.metric_loss, self.metric_acc]

    def compile(self, optimizer, **kwargs):
        super().compile(optimizer=optimizer, **kwargs)

    def _l2(self):
        """L2 penalty over every trainable variable, scaled by l2_lambda."""
        return self.l2_lambda * tf.add_n([tf.nn.l2_loss(v) for v in self.trainable_variables])

    @staticmethod
    def _unpack_batch(data):
        """Split a Keras batch into (x, y, sample_weight); sample_weight may be None."""
        if isinstance(data, (tuple, list)) and len(data) == 3:
            return data[0], data[1], data[2]
        x, y = data
        return x, y, None

    def _update_and_report(self, loss, y, logits):
        """Update the loss/accuracy trackers and return the logs dict for Keras."""
        self.metric_loss.update_state(loss)
        self.metric_acc.update_state(y, logits)
        return {"loss": self.metric_loss.result(), "accuracy": self.metric_acc.result()}

    def train_step(self, data):
        x, y, sample_weight = self._unpack_batch(data)
        with tf.GradientTape() as tape:
            logits = self(x, training=True)
            loss = self.loss_fn(y, logits, sample_weight=sample_weight) + self._l2()
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return self._update_and_report(loss, y, logits)

    def test_step(self, data):
        x, y, sample_weight = self._unpack_batch(data)
        logits = self(x, training=False)
        loss = self.loss_fn(y, logits, sample_weight=sample_weight) + self._l2()
        return self._update_and_report(loss, y, logits)

# ========= データ読み込み & 学習 =========
# Collect the CSV files in a stable (sorted) order and build the windowed dataset.
csv_files = glob.glob(CSV_GLOB)
csv_files.sort()
if not csv_files:
    raise FileNotFoundError(f"CSV が見つかりません: {CSV_GLOB}")

X, y, file_ids = build_dataset(csv_files, seq_len=SEQ_LEN, stride=STRIDE)
print(f"X: {X.shape} y: {y.shape} files: {len(np.unique(file_ids))}")

# One-hot targets for the categorical cross-entropy loss.
y_oh = keras.utils.to_categorical(y, num_classes=N_CLASSES)

# Split per file to prevent leakage: overlapping windows of one recording must
# not end up in both train and val.
if VAL_SPLIT_BY_FILE:
    uniq = np.unique(file_ids)
    # Every window of a file carries the same label (derived from the file
    # name), so the first window's label is the file's label.
    file_labels = np.array([y[file_ids == f][0] for f in uniq])
    try:
        # Stratify so that both classes appear in train and val whenever possible.
        tr_files, va_files = train_test_split(
            uniq, test_size=0.2, random_state=42, shuffle=True, stratify=file_labels
        )
    except ValueError as err:
        # Too few files per class for a stratified split: fall back to a plain random split.
        print("[WARN] stratified file split failed; using an unstratified split:", err)
        tr_files, va_files = train_test_split(uniq, test_size=0.2, random_state=42, shuffle=True)
    tr_mask = np.isin(file_ids, tr_files)
    va_mask = np.isin(file_ids, va_files)
    X_train, y_train = X[tr_mask], y_oh[tr_mask]
    X_val,   y_val   = X[va_mask], y_oh[va_mask]
else:
    X_train, X_val, y_train, y_val = train_test_split(X, y_oh, test_size=0.2, random_state=42, stratify=y)

# Class-imbalance weights ("balanced": n_samples / (n_classes * class_count)).
# compute_class_weight raises ValueError when a class is missing from the
# training labels; training then proceeds without class weights.
class_weight = None
try:
    cls_w = compute_class_weight("balanced", classes=np.arange(N_CLASSES), y=np.argmax(y_train, axis=1))
except ValueError as e:
    print("[WARN] class_weight の計算に失敗（クラス数不足？）:", e)
else:
    class_weight = {int(c): float(w) for c, w in enumerate(cls_w)}

n_hidden = 30
model = LSTMWithL2(n_input=DIMS, n_hidden=n_hidden, n_classes=N_CLASSES, l2_lambda=1e-4)

# Constant learning rate; LR reductions are left to ReduceLROnPlateau.
# A LearningRateSchedule is deliberately not used. When the optimizer is built
# from a schedule, its learning rate cannot be set, so ReduceLROnPlateau would
# raise an error the first time it tried to lower it. The former
# ExponentialDecay(decay_steps=100000, staircase=True) only decayed after
# 100000 optimizer steps anyway, i.e. not at all within EPOCHS=50 unless an
# epoch exceeded 2000 batches.
opt = keras.optimizers.Adam(learning_rate=LR)
model.compile(optimizer=opt)

# Best-so-far checkpoint location: <DATA_DIR>/checkpoints/best.keras
ckpt_dir = os.path.join(DATA_DIR, "checkpoints")
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_path = os.path.join(ckpt_dir, "best.keras")

# Save the model whenever val_accuracy improves; stop after 10 epochs without
# improvement (restoring the best weights); halve the LR after 5 epochs
# without val_loss improvement, down to 1e-5.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    ckpt_path, monitor="val_accuracy", save_best_only=True, verbose=1
)
early_stop_cb = keras.callbacks.EarlyStopping(
    monitor="val_accuracy", patience=10, restore_best_weights=True, verbose=1
)
reduce_lr_cb = keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss", factor=0.5, patience=5, min_lr=1e-5, verbose=1
)
callbacks = [checkpoint_cb, early_stop_cb, reduce_lr_cb]

history = model.fit(
    X_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val, y_val),
    class_weight=class_weight,
    callbacks=callbacks,
    verbose=1,
)

# ========= Evaluation =========
# The model outputs logits; their argmax equals the argmax of the softmax probabilities.
y_val_prob = model.predict(X_val, batch_size=BATCH_SIZE)
y_val_pred = np.argmax(y_val_prob, axis=1)
y_val_true = np.argmax(y_val, axis=1)

# Pass the full label set explicitly. A file-level split can leave a class
# absent from validation. classification_report would then reject the two
# target_names, and the confusion matrix would shrink below 2x2.
all_labels = np.arange(N_CLASSES)
print("\nValidation report:")
print(classification_report(y_val_true, y_val_pred, labels=all_labels,
                            target_names=CLASS_NAMES, digits=4))
cm = confusion_matrix(y_val_true, y_val_pred, labels=all_labels)
print("Confusion matrix:\n", cm)

# Learning curves: train vs. validation for loss (left) and accuracy (right).
hist = history.history
plt.figure(figsize=(10, 4))
for pos, (key, title) in enumerate([("loss", "Loss"), ("accuracy", "Accuracy")], start=1):
    plt.subplot(1, 2, pos)
    plt.plot(hist[key])
    plt.plot(hist[f"val_{key}"])
    plt.title(title)
    plt.legend(["train", "val"])
plt.tight_layout()
plt.show()

# Save the model in its state at the end of fit().
final_path = os.path.join(DATA_DIR, "ivdd_binary_lstm_final.keras")
model.save(final_path)
print("Saved model to:", final_path)
