# Multinomial NNGP Lengthscale LOOCV

Aitchison distance を指標に、lengthscale の候補を LOOCV で比較します。


※ `max_sites_per_period` を `None` にすると全遺跡で、整数を指定すると各時期の先頭からその件数までのサブセットで実行します。


In [None]:
import gc
from pathlib import Path
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import seaborn as sns
from IPython.display import display
from tqdm.auto import tqdm

from bayesian_statistics.models.preprocessing.data_preprocessor import (
    ObsidianDataPreprocessor,
)
from bayesian_statistics.nngp.model import (
    MultinomialDataset,
    MultinomialNNGPConfig,
    prepare_multinomial_dataset,
    run_mcmc,
)

# Notebook-wide plotting defaults: seaborn "whitegrid" theme and a 10x6
# default figure size.
sns.set_style(style="whitegrid")
plt.rcParams.update({"figure.figsize": (10, 6)})
# The trailing semicolon keeps the notebook from echoing the return value.
gc.collect();

In [None]:
# --- Data location and model inputs ------------------------------------------
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
data_dir = "/home/ohta/dev/bayesian_statistics/data"

# Covariate names handed to prepare_multinomial_dataset.
variable_names = [
    "average_elevation",
    "average_slope_angle",
    "cost_kouzu",
    "cost_shinshu",
    "cost_hakone",
    "cost_takahara",
    "cost_river",
]

# Period index -> period label, numbered from 0 in list order.
time_periods = dict(
    enumerate(["早期・早々期", "前期", "中期", "後期", "晩期"])
)

# Source categories of the multinomial outcome.
origins = ["神津島", "信州", "箱根", "高原山", "その他"]

# --- LOOCV settings ----------------------------------------------------------
# Candidate kernel lengthscales compared by the cross-validation loop.
lengthscale_grid = [0.05, 0.1, 0.15, 0.2]
# Cap on held-out sites per period; set to None to run on all sites.
max_sites_per_period = 10

# --- Sampler / kernel settings passed to MultinomialNNGPConfig ---------------
n_iter, burn_in, thinning = 200, 50, 1
neighbor_count = 40
kernel_variance = 1.0
# Base value from which the per-run seeds are derived.
seed_base = 1234

# Load the raw data once; every period's dataset is built from this object.
preprocessor = ObsidianDataPreprocessor(data_dir, scale_variables=True)
preprocessor.load_data()
gc.collect();

In [None]:
# Build one MultinomialDataset per time period and report its size.
# Arguments that do not vary between periods are collected once.
shared_dataset_kwargs = dict(
    preprocessor=preprocessor,
    origins=origins,
    variable_names=variable_names,
    grid_subsample_ratio=0.01,
    drop_zero_total_sites=True,
)

datasets_by_period: dict[int, MultinomialDataset] = {}
for period in time_periods:
    dataset = prepare_multinomial_dataset(period=period, **shared_dataset_kwargs)
    datasets_by_period[period] = dataset

    n_sites = dataset.num_sites()
    n_grid = dataset.grid_points.shape[0]
    print(f"period {period}: sites={n_sites}, grid={n_grid}")

gc.collect();

In [None]:
def build_loocv_dataset(dataset_full: MultinomialDataset, drop_idx: int):
    """Build the training dataset for one leave-one-out fold.

    The site at ``drop_idx`` is removed from every per-site array, and its
    coordinates and design row are appended as one extra row at the end of the
    prediction grid.

    Args:
        dataset_full: Dataset containing all sites of one period.
        drop_idx: Index of the site to hold out.

    Returns:
        A tuple ``(dataset_train, observed_comp, target_grid_idx)``:
        the dataset without the held-out site, the held-out site's counts
        divided by its total count, and the index of the appended grid row.
    """
    # Boolean selector that is True for every site except the held-out one.
    retained = np.ones(dataset_full.num_sites(), dtype=bool)
    retained[drop_idx] = False

    # The held-out site becomes the last row of the grid arrays.
    grid_with_target = np.vstack(
        [dataset_full.grid_points, dataset_full.coords[drop_idx]]
    )
    design_with_target = np.vstack(
        [dataset_full.design_matrix_grid, dataset_full.design_matrix_sites[drop_idx]]
    )

    dataset_train = MultinomialDataset(
        period=dataset_full.period,
        origins=dataset_full.origins,
        coords=dataset_full.coords[retained],
        counts=dataset_full.counts[retained],
        total_counts=dataset_full.total_counts[retained],
        design_matrix_sites=dataset_full.design_matrix_sites[retained],
        grid_points=grid_with_target,
        design_matrix_grid=design_with_target,
        site_ids=dataset_full.site_ids[retained],
        variable_names=dataset_full.variable_names,
    )

    # Observed composition of the held-out site: counts over their total.
    held_out_counts = dataset_full.counts[drop_idx]
    held_out_total = dataset_full.total_counts[drop_idx]
    observed_comp = held_out_counts / held_out_total

    target_grid_idx = grid_with_target.shape[0] - 1
    return dataset_train, observed_comp, target_grid_idx


def aitchison_distance(p: np.ndarray, q: np.ndarray, eps: float = 1e-9) -> float:
    """Return the Aitchison distance between two compositions.

    Each input is floored at ``eps`` (so zero parts do not produce ``-inf``
    logs) and rescaled to sum to one. The distance is the Euclidean norm of
    the mean-centred difference of the log-parts.

    Args:
        p: First composition.
        q: Second composition, same length as ``p``.
        eps: Lower bound applied to every part before taking logs.

    Returns:
        The distance as a Python ``float``.
    """

    def _log_closed(parts: np.ndarray) -> np.ndarray:
        # Floor the parts, re-close to unit sum, then take logs.
        floored = np.clip(parts, eps, None)
        return np.log(floored / floored.sum())

    diff = _log_closed(p) - _log_closed(q)
    centred = diff - diff.mean()
    return float(np.linalg.norm(centred))


def iter_site_indices(dataset_full: MultinomialDataset, max_sites: Optional[int]):
    """Return the range of site indices to visit, optionally capped.

    Args:
        dataset_full: Dataset whose ``num_sites()`` gives the site count.
        max_sites: Upper bound on the number of sites, or ``None`` for no cap.

    Returns:
        ``range(num_sites)`` when ``max_sites`` is ``None`` or not smaller
        than the site count, otherwise ``range(max_sites)``.
    """
    n_sites = dataset_full.num_sites()
    upper = n_sites if max_sites is None else min(n_sites, max_sites)
    return range(upper)


# Run a garbage-collection pass; the trailing semicolon keeps the notebook
# from echoing the return value.
gc.collect();

In [None]:
# Leave-one-out cross-validation over every (lengthscale, period, site).
# For each fold the model is refitted without one site, and the predicted
# composition at that site's location is compared with the observed one.

# Column order of the raw results table. Passing it explicitly keeps the
# header present even when no fold was run.
result_columns = ["lengthscale", "period", "site_id", "aitchison"]
records: list[dict[str, float]] = []

for lengthscale in lengthscale_grid:
    print(f"Lengthscale {lengthscale}")
    for period in time_periods:
        dataset_full = datasets_by_period[period]
        # Reuse the shared helper so the per-period cap is defined in one place.
        site_indices = iter_site_indices(dataset_full, max_sites_per_period)
        if not site_indices:  # an empty range is falsy: nothing to hold out
            continue

        for drop_idx in tqdm(site_indices, leave=False, desc=f"period {period}"):
            dataset_train, observed_comp, target_idx = build_loocv_dataset(
                dataset_full, drop_idx
            )

            # The seed depends only on period and held-out site, so every
            # lengthscale candidate is run with the same seeds.
            # NOTE(review): seeds of adjacent periods coincide once drop_idx
            # reaches 1000 - confirm this is acceptable for full-site runs.
            config = MultinomialNNGPConfig(
                n_iter=n_iter,
                burn_in=burn_in,
                thinning=thinning,
                neighbor_count=neighbor_count,
                kernel_lengthscale=lengthscale,
                kernel_variance=kernel_variance,
                seed=seed_base + period * 1000 + drop_idx,
            )

            results = run_mcmc(dataset_train, config)
            grid_probs = results.posterior_grid_mean(sample_conditional=False)
            # Assumes grid points lie on the second axis of grid_probs, with
            # the held-out site as the last grid row - TODO confirm.
            pred_comp = grid_probs[:, target_idx]
            dist = aitchison_distance(pred_comp, observed_comp)

            records.append(
                {
                    "lengthscale": lengthscale,
                    "period": period,
                    "site_id": int(dataset_full.site_ids[drop_idx]),
                    "aitchison": dist,
                }
            )

            # Drop the sampler output before the next fit to limit peak memory.
            del results
            gc.collect()

results_df = pd.DataFrame(records, columns=result_columns)
display(results_df.head())

csv_path = Path("output/nngp_ratio_maps/loocv_results.csv")
csv_path.parent.mkdir(parents=True, exist_ok=True)
results_df.to_csv(csv_path, index=False)
print(f"Saved raw results to {csv_path}")

In [None]:
# Summarise the per-site LOOCV distances for each lengthscale candidate and
# report the candidate with the smallest mean Aitchison distance.
if results_df.empty:
    # Fail with a clear message instead of an opaque KeyError/IndexError below.
    raise ValueError("results_df is empty; run the LOOCV cell before summarising.")

summary = (
    results_df.groupby("lengthscale")["aitchison"]
    .agg(["mean", "std", "count"])
    .sort_values("mean")
)
display(summary)
# Rows are sorted by ascending mean, so the first index is the best candidate.
best_lengthscale = summary.index[0]
print(f"Best lengthscale (mean Aitchison): {best_lengthscale}")

summary_path = Path("output/nngp_ratio_maps/loocv_summary.csv")
# Create the output directory here as well, so this cell does not depend on
# another cell having created it.
summary_path.parent.mkdir(parents=True, exist_ok=True)
summary.to_csv(summary_path)
print(f"Saved summary to {summary_path}")

In [None]:
# Box plot of the per-site LOOCV Aitchison distances, one box per lengthscale.
plt.figure(figsize=(8, 6))
sns.boxplot(data=results_df, x="lengthscale", y="aitchison")
plt.title("LOOCV Aitchison distance by lengthscale")
plt.xlabel("lengthscale")
plt.ylabel("Aitchison distance")
plot_path = Path("output/nngp_ratio_maps/loocv_lengthscale_boxplot.png")
# Ensure the target directory exists so the cell also works when run on its own.
plot_path.parent.mkdir(parents=True, exist_ok=True)
# Save before plt.show() so the figure is written while it is still current.
plt.savefig(plot_path, dpi=300, bbox_inches="tight")
plt.show()
print(f"Saved boxplot to {plot_path}")

上記の表・グラフを参考に、平均 Aitchison distance が小さい lengthscale を目安としてチューニングを行ってください。
