In [17]:
import matplotlib.pyplot as plt
import numpy as np

from bayesian_statistics.models.preprocessing.data_preprocessor import (
    ObsidianDataPreprocessor,
)
from bayesian_statistics.nngp.model import (
    MultinomialNNGPConfig,
    prepare_multinomial_dataset,
    run_mcmc,
)

# Location of the raw input data handed to the preprocessor.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run on another machine without editing it.
data_dir = "/home/ohta/dev/bayesian_statistics/data"
# Build the preprocessor with scale_variables=True (presumably standardises the
# covariates — confirm against ObsidianDataPreprocessor) and load its data.
preprocessor = ObsidianDataPreprocessor(data_dir, scale_variables=True)
preprocessor.load_data()

from pathlib import Path

import polars as pl

from bayesian_statistics.models.visualization import ObsidianVisualizer

# Covariate columns passed to prepare_multinomial_dataset for every period.
variable_names = [
    "average_elevation",
    "average_slope_angle",
    "cost_kouzu",
    "cost_shinshu",
    "cost_hakone",
    "cost_takahara",
    "cost_river",
]

# 1. Run the inference once per period and keep each period's grid_mean.
# Period index -> display name (the names are also used in plot titles).
time_periods = {
    0: "早期・早々期",
    1: "前期",
    2: "中期",
    3: "後期",
    4: "晩期",
}
# Origin categories; the position in this list is the origin index used below
# when indexing the posterior grid mean.
origins = ["神津島", "信州", "箱根", "高原山", "その他"]

# The sampler settings do not depend on the period, so build the config once
# instead of re-creating an identical object on every iteration.
config = MultinomialNNGPConfig(
    n_iter=200,
    burn_in=50,
    thinning=1,
    neighbor_count=40,
    kernel_lengthscale=0.15,
    kernel_variance=1.0,
    seed=42,
)

results_by_period = {}
datasets_by_period = {}
for period in time_periods:
    dataset = prepare_multinomial_dataset(
        preprocessor=preprocessor,
        period=period,
        origins=origins,
        variable_names=variable_names,
        grid_subsample_ratio=0.1,
        drop_zero_total_sites=True,
    )
    results = run_mcmc(dataset, config)
    # Only the posterior grid mean is kept per period; the dataset is stored
    # alongside it because plotting needs its coordinates and counts.
    results_by_period[period] = results.posterior_grid_mean(sample_conditional=False)
    datasets_by_period[period] = dataset

1. building NNGP factor cache for 53 sites...
2. using 40 neighbors for NNGP approximation...


100%|██████████| 53/53 [00:00<00:00, 2591.16it/s]
100%|██████████| 135752/135752 [01:17<00:00, 1758.33it/s]


3. running MCMC...
4. saving 150 posterior samples...


100%|██████████| 200/200 [00:03<00:00, 57.58it/s]


1. building NNGP factor cache for 61 sites...
2. using 40 neighbors for NNGP approximation...


100%|██████████| 61/61 [00:00<00:00, 1835.90it/s]
100%|██████████| 135752/135752 [01:18<00:00, 1737.64it/s]


3. running MCMC...
4. saving 150 posterior samples...


100%|██████████| 200/200 [00:04<00:00, 49.96it/s]


1. building NNGP factor cache for 146 sites...
2. using 40 neighbors for NNGP approximation...


100%|██████████| 146/146 [00:00<00:00, 1966.07it/s]
100%|██████████| 135752/135752 [01:17<00:00, 1742.42it/s]


3. running MCMC...
4. saving 150 posterior samples...


100%|██████████| 200/200 [00:08<00:00, 22.88it/s]


1. building NNGP factor cache for 59 sites...
2. using 40 neighbors for NNGP approximation...


100%|██████████| 59/59 [00:00<00:00, 2579.25it/s]
100%|██████████| 135752/135752 [01:18<00:00, 1738.52it/s]


3. running MCMC...
4. saving 150 posterior samples...


100%|██████████| 200/200 [00:03<00:00, 53.20it/s]


1. building NNGP factor cache for 18 sites...
2. using 40 neighbors for NNGP approximation...


100%|██████████| 18/18 [00:00<00:00, 3912.39it/s]
100%|██████████| 135752/135752 [00:33<00:00, 4085.38it/s]


3. running MCMC...
4. saving 150 posterior samples...


100%|██████████| 200/200 [00:01<00:00, 156.52it/s]


In [18]:
# Directory the per-origin / per-period figures are written to by the export
# loop further down.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
save_dir = Path(
    "/home/ohta/dev/bayesian_statistics/notebooks/output/obsidian_nngp_multinomial"
)

In [19]:
def plot_result(origin_index: int, grid_mean, preprocessor, dataset, period: int):
    """Plot one origin's grid surface with the observed site ratios on top.

    Three scatter layers are drawn on a single axes: the grid values for the
    selected origin, the per-site observed count ratio for that origin, and a
    grey outline built from ``preprocessor.df_elevation``.

    Parameters
    ----------
    origin_index : int
        Row of ``grid_mean`` and column of ``dataset.counts`` to plot; also
        indexes the module-level ``origins`` list for the title.
    grid_mean
        Array indexed as ``grid_mean[origin_index, :]`` with one value per row
        of ``dataset.grid_points`` (presumably the posterior mean composition
        on the grid — confirm against ``posterior_grid_mean``).
    preprocessor
        Object exposing a polars frame ``df_elevation`` with the columns
        ``x``, ``y``, ``average_elevation`` and ``is_sea``.
    dataset
        Object exposing ``counts`` (sites x origins), ``coords`` and
        ``grid_points`` (both with x in column 0 and y in column 1).
    period : int
        Key into the module-level ``time_periods`` dict; used for the title.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure. The caller is responsible for saving/closing it.
    """
    # Cells that have no elevation value and are not flagged as sea; drawn as
    # tiny grey dots (presumably to outline the study area — confirm).
    boundary = (
        preprocessor.df_elevation.filter(
            pl.col("average_elevation").is_null(), ~pl.col("is_sea")
        )
        .select(["x", "y"])
        .to_numpy()
    )

    # Observed share of each origin per site, transposed to (origins, sites).
    # Sites with a zero total give 0/0; those become 0 via nan_to_num, and the
    # errstate guard keeps NumPy from emitting a RuntimeWarning for them.
    with np.errstate(divide="ignore", invalid="ignore"):
        true_ratio = dataset.counts / dataset.counts.sum(axis=1, keepdims=True)
    true_ratio = np.nan_to_num(true_ratio).T

    grid_values = np.asarray(grid_mean[origin_index, :])
    site_values = np.asarray(true_ratio[origin_index])

    # Both coloured layers use the same colormap, so they must also share one
    # value range; otherwise each layer is normalised to its own min/max and
    # the site colours cannot be read off the colorbar.
    stacked = np.concatenate([np.ravel(grid_values), np.ravel(site_values)])
    finite = stacked[np.isfinite(stacked)]
    if finite.size:
        vmin, vmax = float(finite.min()), float(finite.max())
    else:
        # Nothing to scale against: fall back to matplotlib's autoscaling.
        vmin, vmax = None, None

    fig, ax = plt.subplots(1, 1, figsize=(8, 6), constrained_layout=True)

    grid_scatter = ax.scatter(
        dataset.grid_points[:, 0],
        dataset.grid_points[:, 1],
        c=grid_values,
        cmap="Blues",
        vmin=vmin,
        vmax=vmax,
        s=10,
        alpha=0.8,
    )

    ax.scatter(
        dataset.coords[:, 0],
        dataset.coords[:, 1],
        c=site_values,
        cmap="Blues",
        vmin=vmin,
        vmax=vmax,
        s=25,
        edgecolors="black",
        linewidths=0.5,
        alpha=0.8,
        label="observed ratio at sites",
    )
    ax.scatter(boundary[:, 0], boundary[:, 1], c="grey", s=0.001)

    # Attach the colorbar to the grid layer explicitly rather than relying on
    # the order of ax.collections.
    fig.colorbar(grid_scatter, ax=ax, label="posterior mean")
    ax.set_title(f"{origins[origin_index]}, {time_periods[period]}")
    ax.set_xlabel("longitude")
    ax.set_ylabel("latitude")

    # Thin the axes frame lines.
    for spine in ax.spines.values():
        spine.set_linewidth(0.5)

    return fig


# savefig does not create missing directories, so make sure the target exists
# before writing the first figure.
save_dir.mkdir(parents=True, exist_ok=True)

# Write one PNG per (period, origin) pair.
for period in time_periods:
    for origin_index in range(len(origins)):
        fig = plot_result(
            origin_index=origin_index,
            grid_mean=results_by_period[period],
            preprocessor=preprocessor,
            dataset=datasets_by_period[period],
            period=period,
        )
        fig.savefig(
            save_dir / f"nngp_multinomial_origin{origin_index}_period{period}.png"
        )
        # Close each figure right away so the open figures do not pile up in
        # memory across the loop.
        plt.close(fig)