In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Jensen–Shannon（JS）情報量版：
ホワイトリスト16件の TorIP_geocoded.csv (ip,timestamp,lat,lng) を対象に

(1) 全ファイル間の JS 行列を作成（CSV/PNG＋plt.show）
(2) 隣接ペアの JS 時系列（x=2日の中点、日付縦、CSV/PNG＋plt.show）
(3) 行列の各行総和（その日の総乖離度）を時系列表示（PNG＋plt.show）
(4) 起動直後に ./wasserstein_matrix.csv を安全削除（互換対応）

分布推定は lat/lon の 2D ヒストグラム（共通ビン、εスムージング）で離散化して計算。
JS は 0〜ln(2) の範囲（対称・有限）。対数は自然対数。
"""

# ============ Settings ============
PREFIX = "js"       # prefix for every output file name
UNIT = "deg"        # coordinate unit used for the histograms: "deg" or "km"
MIN_ROWS = 1        # a CSV is used only if it has at least this many valid rows
Z_THR_TS = 3.5      # robust-Z threshold applied to the consecutive-pair JS series

# --- JS computation ---
BINS_2D = 64        # bins per axis of the shared 2D (lat x lon) histogram
EPS = 1e-12         # smoothing added to every bin (avoids x/0 and log(0))
RANGE_PAD = 1e-6    # padding added on both sides of the histogram range

# Whitelist: only these 16 files are processed; anything else is ignored.
# The leading 14 digits (YYYYmmddHHMMSS) are parsed as the capture time.
TARGET_FILES = [
    "20251003142418-TorIP_geocoded.csv",
    "20251003143810-TorIP_geocoded.csv",
    "20251003164852-TorIP_geocoded.csv",
    "20251003172946-TorIP_geocoded.csv",
    "20251003173524-TorIP_geocoded.csv",
    "20251004173514-TorIP_geocoded.csv",
    "20251005173523-TorIP_geocoded.csv",
    "20251006173520-TorIP_geocoded.csv",
    "20251007173522-TorIP_geocoded.csv",
    "20251008173518-TorIP_geocoded.csv",
    "20251009173527-TorIP_geocoded.csv",
    "20251010173526-TorIP_geocoded.csv",
    "20251011173523-TorIP_geocoded.csv",
    "20251012173526-TorIP_geocoded.csv",
    "20251013173525-TorIP_geocoded.csv",
    "20251014173524-TorIP_geocoded.csv",
]
# ==================================

import os, sys, csv, math
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import font_manager, rcParams
from datetime import datetime
import matplotlib.dates as mdates

# ===== Japanese font (optional) =====
JP_FONT_CANDIDATES = [
    "Noto Sans CJK JP",
    "Noto Serif CJK JP",
    "IPAexGothic",
    "IPAPGothic",
    "TakaoGothic",
]


def setup_japanese_font():
    """Switch matplotlib to the first installed font from JP_FONT_CANDIDATES.

    If none of the candidates is installed, the font family settings are left
    as they are. In every case ``axes.unicode_minus`` is turned off, so minus
    signs are drawn with an ASCII hyphen-minus rather than U+2212 (a glyph
    some Japanese fonts may lack).
    """
    installed = {entry.name for entry in font_manager.fontManager.ttflist}
    chosen = next((name for name in JP_FONT_CANDIDATES if name in installed), None)
    if chosen is not None:
        rcParams["font.family"] = "sans-serif"
        rcParams["font.sans-serif"] = [chosen]
    rcParams["axes.unicode_minus"] = False

# ===== Text-encoding fallback =====
# "utf-8-sig" comes first: it also decodes BOM-less UTF-8, and unlike plain
# "utf-8" it strips a leading BOM instead of gluing it to the first cell.
# "utf-8" is kept for compatibility with code that inspects this tuple.
ENCODINGS = ("utf-8-sig", "utf-8", "cp932")


def open_with_fallback(path):
    """Open *path* as text, trying each encoding in ENCODINGS in order.

    ``open()`` does not decode anything, so a wrong encoding only fails later
    while reading. To make the fallback effective, each candidate is checked
    by decoding the whole file. The first encoding that succeeds wins. If none
    succeeds, the file is decoded as UTF-8 with undecodable bytes replaced,
    which keeps ASCII numeric columns parseable.

    Returns:
        An in-memory text stream (usable with ``with`` and ``csv.reader``).
        Its line endings are left untranslated, like ``open(..., newline="")``.
        The underlying file is already closed.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
    """
    import io

    for enc in ENCODINGS:
        try:
            with open(path, "r", encoding=enc, newline="") as fh:
                return io.StringIO(fh.read(), newline="")
        except UnicodeDecodeError:
            # Wrong guess: try the next candidate encoding.
            continue
    # Last resort: never fail on bytes; lat/lon are plain ASCII digits anyway.
    with open(path, "r", encoding="utf-8", errors="replace", newline="") as fh:
        return io.StringIO(fh.read(), newline="")

# ===== Header auto-detection =====
def row_has_header_like(cells):
    """Return True when *cells* does not look like a data row.

    A data row has at least 4 cells, and its 3rd and 4th cells (lat, lon)
    parse as floats after stripping whitespace. Anything else, including
    short rows and blank or non-numeric coordinates, is reported as
    header-like.
    """
    if len(cells) >= 4:
        try:
            for cell in (cells[2], cells[3]):
                float((cell or "").strip())
        except Exception:
            return True
        return False
    return True

# ===== CSV -> valid lat/lon extraction =====
def load_valid_latlon(filepath):
    """Read the latitude/longitude columns (indices 2 and 3) of one CSV.

    The first row is skipped when row_has_header_like() flags it; otherwise it
    is parsed as data like every other row. Rows with fewer than 4 cells,
    blank coordinates, unparsable numbers or non-finite values are dropped.

    Returns:
        (lats, lons): two float arrays of equal length. Both are empty when
        the file has no rows.
    """
    lat_vals, lon_vals = [], []

    def _collect(row):
        # Keep the row only if both coordinates parse to finite floats.
        if len(row) < 4:
            return
        lat_txt = (row[2] or "").strip()
        lon_txt = (row[3] or "").strip()
        if not (lat_txt and lon_txt):
            return
        try:
            lat_f = float(lat_txt)
            lon_f = float(lon_txt)
            if math.isfinite(lat_f) and math.isfinite(lon_f):
                lat_vals.append(lat_f)
                lon_vals.append(lon_f)
        except Exception:
            return

    with open_with_fallback(filepath) as fh:
        reader = csv.reader(fh)
        head = next(reader, None)
        if head is None:
            return np.empty(0), np.empty(0)
        if not row_has_header_like(head):
            _collect(head)
        for row in reader:
            _collect(row)
    return np.array(lat_vals, float), np.array(lon_vals, float)

# ===== Unit conversion (deg -> km) =====
def to_km(lat, lon, ref_lat_for_scale=None):
    """Convert latitude/longitude arrays from degrees to approximate km.

    Latitude uses a fixed 111.32 km per degree. Longitude uses the same
    factor multiplied by cos(ref_lat_for_scale). When no reference latitude
    is given, the median of *lat* is used (0.0 if *lat* is empty).

    Returns:
        (lat_km, lon_km) as arrays.
    """
    km_per_deg = 111.32
    if ref_lat_for_scale is None:
        ref_lat_for_scale = 0.0 if not lat.size else np.median(lat)
    lon_factor = km_per_deg * math.cos(float(ref_lat_for_scale) * math.pi / 180.0)
    return lat * km_per_deg, lon * lon_factor

# ===== KL / JS (discrete, 2D histograms) =====
def _normalize_with_eps(arr):
    """Turn a count array into a strictly positive probability vector.

    The array is first normalized to sum to 1; an all-zero (or non-positive)
    sum falls back to the uniform distribution. EPS is then added to every
    entry and the result is renormalized, so no entry is zero and logs and
    ratios stay finite.
    """
    weights = arr.astype(float)
    total = weights.sum()
    probs = np.ones_like(weights) / weights.size if total <= 0 else weights / total
    smoothed = probs + EPS
    return smoothed / smoothed.sum()

def kl_discrete(p, q):
    """Return KL(P || Q) in nats for two discrete distributions.

    Both inputs are passed through _normalize_with_eps first, so raw counts
    are accepted and zero bins do not produce infinities.
    """
    p_s, q_s = _normalize_with_eps(p), _normalize_with_eps(q)
    return float(np.sum(p_s * np.log(p_s / q_s)))

def js_discrete(p, q):
    """Return the Jensen–Shannon divergence of P and Q in nats.

    JS(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), where M = (P + Q) / 2.
    Both inputs are smoothed and normalized with _normalize_with_eps first.
    """
    p_s = _normalize_with_eps(p)
    q_s = _normalize_with_eps(q)
    mix = 0.5 * (p_s + q_s)

    def _kl_to_mix(dist):
        # KL(dist || mix); mix is not re-smoothed so JS stays exactly symmetric.
        return float(np.sum(dist * np.log(dist / mix)))

    return 0.5 * _kl_to_mix(p_s) + 0.5 * _kl_to_mix(q_s)

def symmetric_js_2d(lat_a, lon_a, lat_b, lon_b, bins=BINS_2D):
    """Return the JS divergence between two point sets binned on a shared grid.

    Both samples are histogrammed with identical bin edges. The edges span
    the joint bounding box of the two samples, padded by RANGE_PAD on each
    side. The flattened counts go to js_discrete. JS is symmetric by
    construction and lies in [0, ln 2].
    """
    def _span(a, b):
        # Joint [min, max] of two arrays, widened by RANGE_PAD on each side.
        low = float(min(np.min(a), np.min(b))) - RANGE_PAD
        high = float(max(np.max(a), np.max(b))) + RANGE_PAD
        return [low, high]

    grid_range = [_span(lat_a, lat_b), _span(lon_a, lon_b)]
    counts_a, lat_edges, lon_edges = np.histogram2d(
        lat_a, lon_a, bins=bins, range=grid_range
    )
    # Reuse A's edges so both histograms cover exactly the same cells.
    counts_b = np.histogram2d(lat_b, lon_b, bins=[lat_edges, lon_edges])[0]
    return js_discrete(counts_a.ravel(), counts_b.ravel())

# ===== Robust Z-score (applied to the time series) =====
def robust_zscore(x: np.ndarray) -> np.ndarray:
    """Return the MAD-based robust Z-score 0.6745 * (x - median) / MAD.

    An empty input gives an empty float array. If the MAD is zero (e.g. all
    values equal), an all-zero float array is returned instead of dividing
    by zero.
    """
    if not x.size:
        return np.zeros(0, float)
    center = np.median(x)
    spread = np.median(np.abs(x - center))
    return np.zeros_like(x, float) if spread == 0 else 0.6745 * (x - center) / spread

# ===== Leading 14 digits of the file name -> (date, datetime_full) =====
def date_from_filename(fname: str):
    """Parse the capture time from the part of the base name before the first '-'.

    If that part has at least 14 characters, the first 14 are parsed as
    YYYYmmddHHMMSS. Otherwise, if it has at least 8, the first 8 are parsed
    as YYYYmmdd.

    Returns:
        (date, datetime) on success, or None if the part is too short or
        does not parse.
    """
    stem = os.path.basename(fname).split("-")[0]
    if len(stem) < 8:
        return None
    if len(stem) >= 14:
        head, fmt = stem[:14], "%Y%m%d%H%M%S"
    else:
        head, fmt = stem[:8], "%Y%m%d"
    try:
        parsed = datetime.strptime(head, fmt)
    except Exception:
        return None
    return parsed.date(), parsed

def _remove_if_exists(path):
    """Delete *path* if it exists (best effort).

    Returns True when a file was removed. A failed deletion is reported on
    stderr instead of being silently ignored, and it does not abort the run.
    """
    try:
        if os.path.exists(path):
            os.remove(path)
            return True
    except OSError as e:
        print(f"[WARN] cannot remove {path}: {e}", file=sys.stderr)
    return False


def _existing_target_files():
    """Return the whitelisted TARGET_FILES that exist; warn about each missing one."""
    files = []
    for f in TARGET_FILES:
        if os.path.exists(f):
            files.append(f)
        else:
            print(f"[WARN] missing: {f}", file=sys.stderr)
    return files


def _load_records(files):
    """Load (date, datetime, path, lats, lons) for each usable file.

    Files without a parseable date in the name, or with fewer than MIN_ROWS
    valid rows, are skipped with a warning. When UNIT == "km", every file is
    converted with ONE shared reference latitude (the median over all files).
    With a per-file reference, the same longitude would map to different km
    values in different files, and the shared histogram bins would then
    compare incompatible coordinates.
    """
    records = []
    for f in files:
        parsed = date_from_filename(f)
        if parsed is None:
            print(f"[WARN] bad filename (no date): {f}", file=sys.stderr)
            continue
        day, dt_full = parsed
        lats, lons = load_valid_latlon(f)
        if len(lats) >= MIN_ROWS and len(lons) >= MIN_ROWS:
            records.append((day, dt_full, f, lats, lons))
        else:
            print(f"[WARN] 有効行不足: {f}  lat={len(lats)} lon={len(lons)}", file=sys.stderr)

    if UNIT == "km" and records:
        all_lats = np.concatenate([r[3] for r in records])
        ref_lat = float(np.median(all_lats)) if all_lats.size else 0.0
        records = [
            (day, dt_full, f, *to_km(lats, lons, ref_lat_for_scale=ref_lat))
            for day, dt_full, f, lats, lons in records
        ]
    return records


def _js_matrix(records):
    """Return the symmetric n x n matrix of pairwise JS values (zero diagonal)."""
    n = len(records)
    dist = np.zeros((n, n), float)
    for i in range(n):
        lat_i, lon_i = records[i][3], records[i][4]
        for j in range(i + 1, n):
            lat_j, lon_j = records[j][3], records[j][4]
            dist[i, j] = dist[j, i] = symmetric_js_2d(lat_i, lon_i, lat_j, lon_j, bins=BINS_2D)
    return dist


def _write_matrix_outputs(dist, base_names, file_list):
    """Write the labelled JS matrix CSV and the row-order file list; return both paths."""
    matrix_csv = f"{PREFIX}_matrix.csv"
    files_txt = f"{PREFIX}_files.txt"
    _remove_if_exists(matrix_csv)
    with open(matrix_csv, "w", encoding="utf-8", newline="") as w:
        wr = csv.writer(w)
        wr.writerow([""] + base_names)
        for name, row in zip(base_names, dist):
            wr.writerow([name] + [f"{x:.6f}" for x in row])
    with open(files_txt, "w", encoding="utf-8") as w:
        w.write("\n".join(file_list))
    return matrix_csv, files_txt


def _format_date_axis(ax):
    """Use automatic date ticks on the x axis and rotate their labels vertically."""
    locator = mdates.AutoDateLocator()
    ax.xaxis.set_major_locator(locator)
    ax.xaxis.set_major_formatter(mdates.AutoDateFormatter(locator))
    plt.setp(ax.get_xticklabels(), rotation=90, ha="center")


def _save_and_show(fig, path):
    """Save *fig* to *path* at 200 dpi, display it, then close it."""
    fig.tight_layout()
    fig.savefig(path, dpi=200, bbox_inches="tight")
    plt.show()
    plt.close(fig)


def _plot_heatmap(dist, base_names):
    """(1) Save and show the JS matrix as a heatmap; return the PNG path."""
    n = len(base_names)
    fig = plt.figure(figsize=(max(6, n * 0.7), max(5, n * 0.7)))
    ax = fig.gca()
    im = ax.imshow(dist, aspect="equal", interpolation="nearest")
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(f"Jensen–Shannon (bins={BINS_2D})")
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(base_names, rotation=70, ha="right")
    ax.set_yticklabels(base_names)
    ax.set_title("Jensen–Shannon divergence heatmap")
    out_png_matrix = f"{PREFIX}_heatmap.png"
    _save_and_show(fig, out_png_matrix)
    return out_png_matrix


def _consecutive_pairs(dist, dt_full, base_names):
    """(2) JS of each chronologically adjacent pair, written to CSV and plotted.

    Each point sits at the midpoint of the two files' timestamps. Points with
    |robust Z| > Z_THR_TS are marked as anomalies. The dashed line is the
    upper threshold median + (Z_THR_TS / 0.6745) * MAD (just the median when
    MAD is 0). Returns (csv_path, png_path).
    """
    # dist[i-1, i] was computed by symmetric_js_2d on exactly this pair of
    # records, so it is reused instead of rebuilding the histograms.
    js_vals = np.array([dist[i - 1, i] for i in range(1, len(dt_full))], float)
    x_dates = np.array(
        [a + (b - a) / 2 for a, b in zip(dt_full, dt_full[1:])], dtype=object
    )
    file_pair = [f"{a} → {b}" for a, b in zip(base_names, base_names[1:])]

    # Robust Z and threshold line
    z = robust_zscore(js_vals)
    is_anom = np.abs(z) > Z_THR_TS
    med = np.median(js_vals)
    mad = np.median(np.abs(js_vals - med))
    thr_line = med + (Z_THR_TS / 0.6745) * mad if mad > 0 else med

    ts_csv = f"{PREFIX}_timeseries.csv"
    _remove_if_exists(ts_csv)
    with open(ts_csv, "w", encoding="utf-8", newline="") as w:
        wr = csv.writer(w)
        wr.writerow(["date_mid", "file_pair", "js", "z"])
        for mid, pair, js, zi in zip(x_dates, file_pair, js_vals, z):
            wr.writerow([mid.strftime("%Y-%m-%d %H:%M:%S"), pair, f"{js:.6f}", f"{zi:.3f}"])

    fig = plt.figure(figsize=(10, 4))
    ax = fig.gca()
    normal = ~is_anom
    ax.plot(x_dates[normal], js_vals[normal], marker="o", linestyle="none", label="normal", alpha=0.7)
    ax.plot(x_dates[is_anom], js_vals[is_anom], marker="x", linestyle="none", label="anomaly", alpha=0.95)
    ax.axhline(thr_line, linestyle="--")
    _format_date_axis(ax)
    ax.set_xlabel("date (midpoint of two consecutive files from filename)")
    ax.set_ylabel("Jensen–Shannon divergence (2D histogram)")
    ax.set_title("Consecutive-pair JS divergence (x=mid-date)")
    ax.legend()
    out_png_ts = f"{PREFIX}_timeseries.png"
    _save_and_show(fig, out_png_ts)
    return ts_csv, out_png_ts


def _plot_row_sums(dist, dt_full):
    """(3) Plot each file's row-sum of the JS matrix (its total deviation); return the PNG path."""
    row_sums = dist.sum(axis=1)
    fig = plt.figure(figsize=(10, 4))
    ax = fig.gca()
    ax.plot(dt_full, row_sums, marker="o")
    _format_date_axis(ax)
    ax.set_xlabel("date (from filename)")
    ax.set_ylabel("row-sum of JS divergence")
    ax.set_title("Overall deviation per day (row-sum of JS)")
    out_png_rowsum = f"{PREFIX}_row_sums.png"
    _save_and_show(fig, out_png_rowsum)
    return out_png_rowsum


def main():
    """Run the Jensen–Shannon analysis over the whitelisted TARGET_FILES.

    Steps:
      0. remove the legacy ./wasserstein_matrix.csv (compatibility);
      1. pairwise JS matrix -> CSV, file-order list and heatmap;
      2. JS of chronologically adjacent files -> CSV and time-series plot;
      3. per-file row-sums of the matrix -> time-series plot.
    Returns early, after an error on stderr, when fewer than two usable
    files are available.
    """
    # Remove the legacy Wasserstein output right away (compatibility requirement).
    if _remove_if_exists("./wasserstein_matrix.csv"):
        print("[INFO] removed: ./wasserstein_matrix.csv")

    setup_japanese_font()

    files = _existing_target_files()
    if len(files) < 2:
        print("[ERROR] 処理対象のCSVが2つ未満です。", file=sys.stderr)
        return

    records = _load_records(files)
    if len(records) < 2:
        print("[ERROR] 有効なファイルが2つ未満です。", file=sys.stderr)
        return

    # Chronological order by the full timestamp taken from the file name.
    records.sort(key=lambda r: r[1])
    file_list = [r[2] for r in records]
    base_names = [os.path.splitext(os.path.basename(f))[0] for f in file_list]
    dt_full = [r[1] for r in records]

    dist = _js_matrix(records)
    matrix_csv, files_txt = _write_matrix_outputs(dist, base_names, file_list)
    out_png_matrix = _plot_heatmap(dist, base_names)
    ts_csv, out_png_ts = _consecutive_pairs(dist, dt_full, base_names)
    out_png_rowsum = _plot_row_sums(dist, dt_full)

    # Report
    print("✅ 出力完了")
    print("  ", matrix_csv,     "（JS行列CSV）")
    print("  ", files_txt,      "（行列のファイル順）")
    print("  ", out_png_matrix, "（JSヒートマップPNG＋表示）")
    print("  ", ts_csv,         "（隣接ペアのJS時系列CSV）")
    print("  ", out_png_ts,     "（隣接ペアのJS時系列PNG＋表示）")
    print("  ", out_png_rowsum, "（行総和の時系列PNG＋表示）")

if __name__ == "__main__":
    main()
