In [None]:
import sys
import contextlib
import os
import warnings

import numpy as np

from data_generating_process import gen_X, dgp_ols, dgp_sem, dgp_sar, dgp_gwr, dgp_grf
from diagnostics import diagnostics_and_heatmaps

if __name__ == "__main__":
    # NOTE: inside a notebook `__name__` is always "__main__", so this guard
    # always runs; the names it binds (`n`, `rng`, `datasets`, ...) become
    # module-level globals that later plotting cells read.

    @contextlib.contextmanager
    def suppress_stdout():
        """Temporarily route ``sys.stdout`` to the null device.

        ``contextlib.redirect_stdout`` restores the previous stream on exit,
        including when the body raises.
        """
        with open(os.devnull, "w") as fnull, contextlib.redirect_stdout(fnull):
            yield

    # `message` is matched with re.match (anchored at the start of the warning
    # text), so the leading `.*` must apply to *both* alternatives for
    # "not fully connected" to be found mid-message.
    warnings.filterwarnings(
        "ignore", message=r".*(?:is an island|not fully connected)"
    )
    warnings.filterwarnings(
        "ignore",
        message="The weights matrix is not fully connected",
        category=UserWarning,
        module="libpysal.weights.weights",
    )
    # `np.exceptions` was added in NumPy 1.25; older releases expose the
    # warning class at the top level of the package.
    complex_warning = (
        np.exceptions.ComplexWarning if hasattr(np, "exceptions") else np.ComplexWarning
    )
    warnings.filterwarnings(
        "ignore",
        message="Casting complex values to real discards the imaginary part",
        category=complex_warning,
        module="spreg.diagnostics",
    )

    rng = np.random.default_rng(121)  # fixed seed -> reproducible simulations

    n = 25  # lattice side length; gen_X is asked for n * n rows

    # Reuse one design matrix across every DGP (presumably so the datasets
    # differ only in their spatial process — confirm with the DGP module).
    X_shared = gen_X(n * n, rng=rng)

    # numbers chosen to exaggerate spatial autocorrelation
    df_ols, _ = dgp_ols(sigma=1.0, n=n, X=X_shared, rng=rng)
    df_sem, _ = dgp_sem(lam=0.8, sigma=0.3, n=n, X=X_shared, rng=rng)
    df_sar, _ = dgp_sar(rho=0.8, sigma=0.3, n=n, X=X_shared, rng=rng)
    df_gwr, _ = dgp_gwr(beta_grad=np.array([3, 1, 0, -1, 2, 0.6, -0.4]),
                        sigma=0.3, n=n, X=X_shared, rng=rng)
    # NOTE(review): unlike the other DGPs, dgp_grf's result is not unpacked,
    # and the GRF dataset is stored under the key "ESF" — confirm both.
    df_esf = dgp_grf(sigma2=1.0, ell=2.5, nugget=0.05, n=n, X=X_shared, rng=rng)

    datasets = dict(OLS=df_ols, SEM=df_sem, SAR=df_sar, GWR=df_gwr, ESF=df_esf)

    # Hide anything diagnostics_and_heatmaps prints to stdout.
    with suppress_stdout():
        diagnostics_and_heatmaps(datasets, sar_order=1, b_size=5, n_splits=10)

In [16]:
def _grid_to_matrix(df, value_col, n):
    grid = np.full((n, n), np.nan)
    for _, r in df.iterrows():
        grid[int(r["row"]), int(r["col"])] = r[value_col]
    return grid

def plot_discrete(df, value_col="y", n: int = 10, ax=None, cmap="viridis"):
    """
    Draw ``df[value_col]`` as an ``n x n`` heatmap with unsmoothed cells.

    Cells are rendered with ``interpolation="nearest"``, so every lattice site
    shows as a distinct square. Axis ticks are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format lattice with ``"row"``, ``"col"`` and ``value_col`` columns.
    value_col : str, default "y"
        Column to colour by.
    n : int, default 10
        Grid side length; every ``row``/``col`` index in ``df`` must be < n.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.
    cmap : str or matplotlib colormap, default "viridis"
        Colormap passed to ``imshow``. The default is sequential; pass a
        qualitative map (e.g. ``"tab20"``) for categorical values.

    Returns
    -------
    tuple
        ``(ax, im)`` — the axes drawn on and the ``AxesImage`` from ``imshow``.
    """
    grid = _grid_to_matrix(df, value_col, n)
    if ax is None:
        _, ax = plt.subplots()
    im = ax.imshow(grid, cmap=cmap, interpolation="nearest")
    ax.set_xticks([])
    ax.set_yticks([])
    return ax, im

def plot_continuous(df, value_col="y", n:int = 10, ax=None, cmap="viridis"):
    """
    Draw ``df[value_col]`` as a smoothed ``n x n`` heatmap.

    The grid is rendered with ``interpolation="bilinear"``, which blends
    neighbouring cells. Axis ticks are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format lattice with ``"row"``, ``"col"`` and ``value_col`` columns.
    value_col : str, default "y"
        Column to colour by.
    n : int, default 10
        Grid side length; every ``row``/``col`` index in ``df`` must be < n.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; a new figure and axes are created when omitted.
    cmap : str or matplotlib colormap, default "viridis"
        Colormap passed to ``imshow``.

    Returns
    -------
    tuple
        ``(ax, im)`` — the axes drawn on and the ``AxesImage`` from ``imshow``.
    """
    values = _grid_to_matrix(df, value_col, n)
    target = ax if ax is not None else plt.subplots()[1]
    image = target.imshow(values, cmap=cmap, interpolation="bilinear")
    for clear_ticks in (target.set_xticks, target.set_yticks):
        clear_ticks([])
    return target, image

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot every simulated dataset on the lattice size it was generated with
# (`n` from the simulation cell). A larger hard-coded size would only pad
# each heatmap with an empty NaN border.
model_names = ["OLS", "SEM", "SAR", "GWR", "ESF"]

fig, axes = plt.subplots(3, 2, figsize=(15, 15))
for name, ax in zip(model_names, axes.flat):
    plot_discrete(datasets[name], n=n, ax=ax)
    ax.set_title(name)
# More panels than models: blank out the leftovers.
for ax in axes.flat[len(model_names):]:
    ax.axis("off")

plt.tight_layout()
plt.show()

# One standalone PNG per model; close each figure to avoid accumulating them.
for name in model_names:
    fig, ax = plt.subplots()
    plot_discrete(datasets[name], n=n, ax=ax)
    ax.set_title(name)
    fig.tight_layout()
    fig.savefig(f"{name}_heatmap.png", dpi=150)
    plt.close(fig)
