# In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import least_squares

# -------------------- Core helpers --------------------


def neighbors_8(idx, G):
    """Return the linear indices of the Moore (8-cell) neighbors of ``idx``.

    ``idx`` is a row-major linear index on a GxG grid.  Neighbors that would
    fall outside the grid are dropped, so corner cells have 3 neighbors and
    edge cells have 5.  The result is ordered by row offset, then by column
    offset.
    """
    row, col = divmod(idx, G)
    # All (row, col) offsets of the 3x3 stencil except the cell itself.
    offsets = (
        (d_row, d_col)
        for d_row in (-1, 0, 1)
        for d_col in (-1, 0, 1)
        if (d_row, d_col) != (0, 0)
    )
    return [
        (row + d_row) * G + (col + d_col)
        for d_row, d_col in offsets
        if 0 <= row + d_row < G and 0 <= col + d_col < G
    ]


def build_M(G):
    """Build the (G*G, G*G) matrix M of the linear site-value map S = M n.

    Each row has weight 1 on the cell itself and a total weight of 0.5 shared
    equally among the cell's in-grid Moore neighbors.
    """
    n_cells = G * G
    # Identity supplies the unit own-cell weight on the diagonal.
    weights = np.eye(n_cells)
    for cell in range(n_cells):
        adjacent = neighbors_8(cell, G)
        if adjacent:
            weights[cell, adjacent] = 0.5 / len(adjacent)
    return weights


def fischer_burmeister(a, b):
    """Fischer–Burmeister function φ(a,b) = sqrt(a^2 + b^2) - a - b.

    φ(a, b) = 0 exactly when a >= 0, b >= 0 and a * b = 0, which is why it is
    used to express complementarity conditions as equations.  Works
    elementwise on scalars or broadcastable arrays.

    The norm is computed with ``np.hypot`` rather than squaring explicitly, so
    the intermediate a^2 + b^2 can neither overflow to ``inf`` for very large
    inputs nor underflow to 0 for very small ones.
    """
    return np.hypot(a, b) - a - b


# -------------------- Equilibrium solver --------------------


def solve_equilibrium(
    c,
    beta=0.95,
    kappa=0.0,
    delta_b=0.8,
    G=10,
    mass_total=1.0,
    n_init=None,
    u_init=0.0,
    tol=1e-12,
    max_nfev=4000,
):
    """Solve for equilibrium given cost c and other parameters.

    The unknowns are the cell masses ``n`` (length G*G, bounded below by 0)
    and the scalar ``u``.  They are found by bounded nonlinear least squares
    on the stacked residual ``[φ(S - u - c, n), sum(n) - mass_total]``, where
    ``S = M n`` and φ is the Fischer–Burmeister function.

    Parameters
    ----------
    c : float
        Cost parameter entering the complementarity residual, ``B`` and ``pi``.
    beta : float
        Used only in ``V = pi / (1 - beta)``.
    kappa : float
        Weight on ``B`` in ``pi = (r - c) * n - kappa * B``.
    delta_b : float
        Divisor in ``B = (c / delta_b) * n``.
    G : int
        Grid side length; there are G*G cells.
    mass_total : float
        Required total of ``n`` over all cells.
    n_init : array-like of shape (G*G,), optional
        Initial guess for ``n``.  If omitted, a fixed-seed Dirichlet draw
        scaled to ``mass_total`` is used.
    u_init : float
        Initial guess for ``u``.
    tol : float
        Passed to ``least_squares`` as ``xtol``, ``ftol`` and ``gtol``.
    max_nfev : int
        Maximum number of residual evaluations for ``least_squares``.

    Returns
    -------
    dict
        ``n``, ``u``, ``S = M n``, ``r = S - u``, ``B``, ``pi`` (profits),
        ``V`` (values), ``vacancy_rate_cells`` (share of cells with
        ``n <= 1e-10``), ``vacancy_rate_mass`` (share of ``mass_total`` not
        located in occupied cells) and ``solver`` (the raw ``least_squares``
        result, left unmodified).

    Raises
    ------
    ValueError
        If ``n_init`` does not have shape (G*G,).
    RuntimeError
        If ``least_squares`` reports ``success == False``.

    Notes
    -----
    ``success`` only means a convergence criterion was met; it does not by
    itself guarantee a zero residual.  Inspect ``solver.cost`` if that matters.
    """
    N = G * G
    M = build_M(G)

    # Initial guess
    if n_init is None:
        # Fixed seed keeps cold starts reproducible.
        n0 = np.random.default_rng(123).dirichlet(np.ones(N)) * mass_total
    else:
        n0 = np.asarray(n_init, dtype=float)
        if n0.shape != (N,):
            raise ValueError(
                f"n_init must have shape ({N},), got {n0.shape}"
            )
    z0 = np.concatenate([n0, [u_init]])

    # n is constrained to be non-negative; u is free.
    lower = np.concatenate([np.zeros(N), [-np.inf]])
    upper = np.concatenate([np.full(N, np.inf), [np.inf]])

    def system(z):
        n = z[:N]
        u = z[-1]
        S = M @ n
        lam = S - u - c
        fb = fischer_burmeister(lam, n)  # complementarity residuals
        pop = np.array([n.sum() - mass_total])  # population constraint
        return np.concatenate([fb, pop])

    res = least_squares(
        system,
        z0,
        bounds=(lower, upper),
        xtol=tol,
        ftol=tol,
        gtol=tol,
        max_nfev=max_nfev,
    )
    if not res.success:
        raise RuntimeError("Equilibrium solver failed: " + res.message)

    # Copy so that the cleanup below does not alter res.x, which is returned
    # to the caller under the "solver" key.
    n = res.x[:N].copy()
    u = res.x[-1]
    n[n < 1e-12] = 0.0  # clean numerics

    S = M @ n
    r = S - u
    B = (c / delta_b) * n
    pi = (r - c) * n - kappa * B
    V = pi / (1 - beta)

    threshold = 1e-10
    occupied_mask = n > threshold
    vacancy_rate_cells = 1 - occupied_mask.mean()  # share of empty cells
    occupied_mass = n[occupied_mask].sum()
    if mass_total:
        # Normalise by the total mass so the rate is a share for any
        # mass_total; it should be ~0 when the mass constraint holds.
        vacancy_rate_mass = 1 - occupied_mass / mass_total
    else:
        # No mass to normalise by; avoid dividing by zero.
        vacancy_rate_mass = 1 - occupied_mass

    return dict(
        n=n,
        u=u,
        S=S,
        r=r,
        B=B,
        pi=pi,
        V=V,
        vacancy_rate_cells=vacancy_rate_cells,
        vacancy_rate_mass=vacancy_rate_mass,
        solver=res,
    )


# -------------------- Comparative statics wrapper --------------------


def compare_over_c(c_values, beta=0.95, kappa=0.0, delta_b=0.8, G=10, mass_total=1.0):
    """Loop over a list/array of c values and collect equilibrium statistics.

    Each value of ``c`` is solved with ``solve_equilibrium``; the solve is
    warm-started from the previous solution (``n`` and ``u``) to speed it up.
    The first solve uses the solver's default initial ``n`` and ``u = 0``.

    Returns a DataFrame with one row per ``c`` and the columns ``c``, ``u``,
    ``mean_r`` (mean of ``r`` over cells), ``vacancy_cells`` and
    ``vacancy_mass``.  The columns are present even when ``c_values`` is
    empty, so downstream column access does not fail.

    Any exception raised by ``solve_equilibrium`` propagates to the caller.
    """
    columns = ["c", "u", "mean_r", "vacancy_cells", "vacancy_mass"]
    results = []
    n_guess = None
    u_guess = 0.0
    for c in c_values:
        out = solve_equilibrium(
            c,
            beta=beta,
            kappa=kappa,
            delta_b=delta_b,
            G=G,
            mass_total=mass_total,
            n_init=n_guess,
            u_init=u_guess,
        )
        results.append(
            dict(
                c=c,
                u=out["u"],
                mean_r=out["r"].mean(),
                vacancy_cells=out["vacancy_rate_cells"],
                vacancy_mass=out["vacancy_rate_mass"],
            )
        )
        # warm starts
        n_guess, u_guess = out["n"], out["u"]
    # Pin the column set so an empty c_values still yields the full schema.
    return pd.DataFrame(results, columns=columns)


# -------------------- Run comparative statics --------------------

# 21 evenly spaced cost levels from 0.5 to 100 inclusive (step 4.975).
# NOTE(review): an earlier comment described this grid as "0.5 to 5 in steps
# of 0.2"; confirm that the 0.5-100 range is the intended one.
c_grid = np.linspace(0.5, 100, 21)
cs_df = compare_over_c(c_grid)

print("Comparative statics results:")
print(cs_df)

# -------------------- Plots --------------------
# One figure per statistic, each plotted against c.
for column, y_label, title in (
    ("vacancy_cells", "Vacancy rate (cells)", "Vacancy rate across c"),
    ("u", "u*", "Equilibrium utility across c"),
    ("mean_r", "Mean rent", "Mean rent across c"),
):
    plt.figure()
    plt.plot(cs_df["c"], cs_df[column], marker="o")
    plt.xlabel("c")
    plt.ylabel(y_label)
    plt.title(title)
    plt.tight_layout()

cs_df.head()