<a href="https://colab.research.google.com/github/JamesMaxwellHarrison/gravity_simulator/blob/main/simulation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [8]:
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, FloatSlider, IntSlider, Checkbox, Button, Label, VBox, HBox, interactive_output, Dropdown
from IPython.display import display
from matplotlib.table import Table

# Warm-start cache for the multilateral-resistance solver: plot_trade_potential
# passes the previous call's (Pi, P) in as starting values and stores the new
# solution back here. Both start out empty.
Pi_cache, P_cache = None, None

def _node_radius(Y):
    """Return the drawing radius of an economy's circle for GDP ``Y``."""
    return 0.15 + 0.02 * np.sqrt(Y)


def _bilateral_trade_costs(positions, tau_mult, asymmetric_trade):
    """Build the N x N trade-cost matrix ``tau`` from economy positions.

    ``tau[i, j] = (1 + d_ij * tau_mult / 4) * asym_factor``, where ``d_ij`` is the
    Euclidean distance between positions i and j and
    ``asym_factor = clip(1 + asymmetric_trade * sign(x_j - x_i) / 2, 0.5, 1)``.
    Because the factor is capped at 1, asymmetry can only lower costs, and only
    in one horizontal direction. Diagonal entries stay at 1.
    """
    n = len(positions)
    tau = np.ones((n, n))
    for i, p_i in enumerate(positions):
        for j, p_j in enumerate(positions):
            if i == j:
                continue
            d_ij = np.linalg.norm(p_i - p_j)
            # +1 rightward, -1 leftward, 0 when both share the same x.
            direction = np.sign(p_j[0] - p_i[0])
            asym_factor = np.clip(1 + asymmetric_trade * direction / 2, 0.5, 1)
            tau[i, j] = (1 + d_ij * tau_mult / 4) * asym_factor
    return tau


def _compute_gravity_flows(Y_i, tau, sigma, tol=5e-3, max_iter=50, Pi_init=None, P_init=None):
    """Solve for multilateral resistance terms and GE-consistent trade flows.

    With theta = Y / Y_W, the solver iterates

        P_j  = (sum_i theta_i * (tau_ij / Pi_i) ** (1 - sigma)) ** (1 / (1 - sigma))
        Pi_i = (sum_j theta_j * (tau_ij / P_j)  ** (1 - sigma)) ** (1 / (1 - sigma))

    It updates P from the current Pi, then Pi from the new P. It stops once both
    vectors move by less than ``tol`` in max-abs norm, or after ``max_iter`` sweeps.

    Parameters:
        Y_i (array): Vector of GDPs, length N.
        tau (2D array): N x N bilateral trade costs (clipped to [1e-4, 100]).
        sigma (float): Elasticity of substitution; must differ from 1.
        tol (float): Convergence tolerance.
        max_iter (int): Maximum number of sweeps.
        Pi_init, P_init (array or None): Warm-start values. Ignored unless they
            have length N.

    Returns:
        flows (2D array): flows[i, j] = Y_i Y_j / Y_W * (tau_ij / (Pi_i P_j)) ** (1 - sigma)
        Pi (array): Outward multilateral resistance terms.
        P (array): Inward multilateral resistance terms.

    Raises:
        ValueError: if ``sigma == 1`` (the 1 / (1 - sigma) exponent is undefined).

    NOTE: if (Pi, P) satisfies the equations, so does (c * Pi, P / c) for any
    c > 0. The individual levels therefore depend on the starting point; only
    the products Pi_i * P_j, and hence ``flows``, are invariant to that rescaling.
    """
    exponent = 1 - sigma
    if exponent == 0:
        raise ValueError("sigma must differ from 1 (the 1 / (1 - sigma) exponent is undefined)")

    Y_i = np.asarray(Y_i, dtype=np.float64)
    n = len(Y_i)
    Y_w = np.sum(Y_i)
    Y_share = Y_i / Y_w
    tau = np.clip(tau, 1e-4, 100)
    epsilon = 1e-9

    Pi = np.array(Pi_init, dtype=np.float64) if Pi_init is not None and len(Pi_init) == n else np.ones(n)
    P = np.array(P_init, dtype=np.float64) if P_init is not None and len(P_init) == n else np.ones(n)

    for _ in range(max_iter):
        Pi_prev, P_prev = Pi, P
        # P_j sums over origins i (axis 0) and depends only on Pi.
        ratio_in = np.clip(tau / np.clip(Pi, epsilon, None)[:, None], epsilon, None)
        P = np.sum(Y_share[:, None] * ratio_in ** exponent, axis=0) ** (1 / exponent)
        # Pi_i sums over destinations j (axis 1) and uses the freshly updated P.
        ratio_out = np.clip(tau / np.clip(P, epsilon, None)[None, :], epsilon, None)
        Pi = np.sum(Y_share[None, :] * ratio_out ** exponent, axis=1) ** (1 / exponent)

        if np.max(np.abs(P - P_prev)) < tol and np.max(np.abs(Pi - Pi_prev)) < tol:
            break

    numerator = np.outer(Y_i, Y_i) / Y_w
    denominator = np.clip(np.outer(Pi, P), epsilon, None)
    flows = numerator * (tau / denominator) ** exponent
    return flows, Pi, P


def _draw_matrix_table(fig, table_type, labels, flows, tau, Pi, P):
    """Add a matrix table in a new axes along the bottom of ``fig``.

    If ``table_type == 'flows'``, the body shows ``flows``, with row totals on
    the right, column totals along the bottom and the grand total in the corner.
    For any other value, the body shows ``tau``, with Pi_i on the right and P_j
    along the bottom.
    """
    show_flows = table_type == 'flows'
    matrix = flows if show_flows else tau
    n = len(labels)
    cell_w, label_w, cell_h = 1 / (n + 1), 0.1, 0.2

    ax_table = fig.add_axes([0.1, 0, 0.8, 0.25])
    ax_table.axis('off')
    tbl = Table(ax_table, bbox=[0, 0, 1, 1])

    def cell(row, col, text, facecolor, width=cell_w):
        tbl.add_cell(row, col, width, cell_h, text=text, loc='center', facecolor=facecolor)

    # Header row: destination labels, then the title of the right-hand summary column.
    for j, label in enumerate(labels):
        cell(0, j + 1, label, 'lightgray')
    cell(0, n + 1, "Total" if show_flows else r"$\Pi$", 'lightgray')

    # Body rows: origin label, matrix values (diagonal highlighted), right-hand summary.
    for i, label in enumerate(labels):
        cell(i + 1, 0, label, 'lightgray', width=label_w)
        for j in range(n):
            cell(i + 1, j + 1, f"{matrix[i, j]:.1f}", 'lightyellow' if i == j else 'white')
        right_text = f"{flows[i].sum():.1f}" if show_flows else f"{Pi[i]:.2f}"
        cell(i + 1, n + 1, right_text, 'lightyellow')

    # Bottom row: column totals (flows) or inward resistances P_j (tau).
    for j in range(n):
        bottom_text = f"{flows[:, j].sum():.1f}" if show_flows else f"{P[j]:.2f}"
        cell(n + 1, j + 1, bottom_text, 'lightyellow')
    # Bottom-right: grand total for flows, blank for tau.
    cell(n + 1, n + 1, f"{flows.sum():.1f}" if show_flows else "", 'khaki')
    cell(n + 1, 0, "Total" if show_flows else "$P$", 'lightgray', width=label_w)

    ax_table.add_table(tbl)
    ax_table.text(-0.01, 0.5, "Origin", ha='center', va='center', rotation='vertical',
                  transform=ax_table.transAxes, fontsize=10)
    ax_table.text(0.55, 1.08, "Destination", ha='center', va='center',
                  transform=ax_table.transAxes, fontsize=10)


def plot_trade_potential(Y_orange=30, Y_green=30, Y_row=90,
                         orange_x=2.4, green_x=8.0, row_x=8.0,
                         line_width=1, sigma=5, seed=60, dispersion=0.5,
                         tau_mult=1, asymmetric_trade=0,
                         table_type='none', show_mr=False, show_gdp_labels=False):
    """Draw a six-economy structural-gravity trade diagram, with an optional table.

    The economies are orange ('o'), green ('g') and four rest-of-world
    economies ('a'-'d'). Trade costs grow with distance. Flows come from the
    general-equilibrium solver, which is warm-started from the module-level
    ``Pi_cache`` / ``P_cache``.

    Parameters:
        Y_orange, Y_green: GDPs of the orange and green economies.
        Y_row: Total rest-of-world GDP, split across the four RoW economies by a
            seeded Dirichlet(1, 1, 1, 1) draw.
        orange_x, green_x, row_x: Horizontal positions (all start at y = 3.25).
            Green and every RoW economy are jittered on both axes by
            ``uniform(-1, 1) * dispersion``.
        line_width: Multiplier applied to flows to get arrow line widths.
        sigma: Elasticity of substitution; must differ from 1.
        seed: Seed for NumPy's global RNG (GDP split and position jitter).
        dispersion: Amplitude of the position jitter.
        tau_mult: Slope of trade cost in distance: tau = (1 + d * tau_mult / 4) * asym.
        asymmetric_trade: Direction-dependent cost factor in [-1, 1].
        table_type: 'none' for no table, 'flows' for the flow matrix; any other
            value (the widget offers 'tau') shows the trade-cost matrix.
        show_mr: Annotate each economy with its Pi and P.
        show_gdp_labels: Annotate each economy with its GDP.

    Side effects: reseeds NumPy's global RNG, overwrites ``Pi_cache`` /
    ``P_cache``, and shows the figure with ``plt.show()``. Returns None.

    Raises:
        ValueError: if ``sigma == 1``.
    """
    global Pi_cache, P_cache

    np.random.seed(int(seed))

    # --- Setup: GDPs and positions (keep the draw order fixed for reproducibility) ---
    Y_k, Y_l, Y_m, Y_n = np.random.dirichlet([1, 1, 1, 1]) * Y_row
    pos_orange = np.array([orange_x, 3.25])
    pos_green = np.array([green_x, 3.25]) + np.random.uniform(-1, 1, size=2) * dispersion
    row_positions = [
        np.array([row_x, 3.25]) + np.random.uniform(-1, 1, size=2) * dispersion
        for _ in range(4)
    ]

    positions = [pos_orange, pos_green] + row_positions
    gdps = [Y_orange, Y_green, Y_k, Y_l, Y_m, Y_n]
    colors = ['darkorange', 'lightgreen'] + ['skyblue'] * 4
    labels = ['o', 'g', 'a', 'b', 'c', 'd']

    # --- Trade costs and general-equilibrium solve ---
    tau = np.clip(_bilateral_trade_costs(positions, tau_mult, asymmetric_trade), 1e-4, 100)
    Y_i = np.array(gdps, dtype=np.float64)
    Y_w = np.sum(Y_i)
    sigma = float(sigma)  # IntSlider delivers an int

    flows, Pi, P = _compute_gravity_flows(Y_i, tau, sigma, Pi_init=Pi_cache, P_init=P_cache)
    Pi_cache, P_cache = Pi, P

    # --- Figure ---
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 6.5)
    ax.set_aspect('equal')
    ax.axis('off')

    for pos, gdp, color in zip(positions, gdps, colors):
        ax.add_patch(plt.Circle(pos, _node_radius(gdp), color=color, alpha=0.8))

    # One double-headed arrow per pair; its width and opacity track the i -> j flow.
    n = len(positions)
    for i in range(n):
        for j in range(i + 1, n):
            lw_ij = flows[i, j] * line_width
            alpha_ij = min(1.0, 0.0005 + lw_ij / 4)
            ax.annotate('', xy=positions[j], xytext=positions[i],
                        arrowprops=dict(arrowstyle='<->', lw=lw_ij, color='black', alpha=alpha_ij))

    # --- Labels ---
    for pos, label in zip(positions, labels):
        ax.text(pos[0], pos[1], label, ha='center', va='center', fontsize=12, color='black')

    if show_mr or show_gdp_labels:
        for i, (pos, label, gdp) in enumerate(zip(positions, labels, gdps)):
            label_parts = []
            if show_mr:
                # Raw f-string: "\P" is an invalid escape in a plain string literal.
                label_parts.append(rf"$\Pi_{label}$: {Pi[i]:.2f}, $P_{label}$: {P[i]:.2f}")
            if show_gdp_labels:
                label_parts.append(f"$Y_{label}$: {gdp:.0f}")
            ax.text(pos[0], pos[1] + _node_radius(gdp) + 0.05, ", ".join(label_parts),
                    ha='center', fontsize=10)

    # Gravity equation for the orange -> green flow, evaluated with the solved terms.
    # The outward term of the origin (Pi_o) pairs with the inward term of the
    # destination (P_g), matching flows[o, g].
    o, g = 0, 1
    gravity_val = (gdps[o] * gdps[g]) / Y_w * (tau[o, g] / (Pi[o] * P[g])) ** (1 - sigma)
    ax.text(
        0.05, 6.2,
        fr"$X_{{og}} = \left(\frac{{Y_o Y_g}}{{Y_W}}\right) \cdot \left(\frac{{\tau_{{og}}}}{{\Pi_o P_g}}\right)^{{1-\sigma}} =  \left(\frac{{{gdps[o]:.2f} \times {gdps[g]:.2f}}}{{{Y_w:.2f}}} \right) \cdot \left( \frac{{{tau[o, g]:.2f}}}{{{Pi[o]:.2f} \times {P[g]:.2f}}} \right)^{{{1 - sigma}}} = {gravity_val:.2f}$",
        fontsize=15, ha='left'
    )

    # --- Optional table ---
    if table_type != 'none':
        _draw_matrix_table(fig, table_type, labels, flows, tau, Pi, P)

    plt.show()

# Widget configuration: one control per keyword argument of
# plot_trade_potential, keyed by the argument it drives.
controls = {
    'Y_orange': FloatSlider(min=1, max=100, step=1, value=30, description="GDP Orange"),
    'Y_green': FloatSlider(min=1, max=100, step=1, value=30, description="GDP Green"),
    'Y_row': FloatSlider(min=10, max=300, step=1, value=90, description="GDP ROW"),
    'orange_x': FloatSlider(min=2.0, max=10.0, step=0.1, value=2.4, description="Orange X"),
    'green_x': FloatSlider(min=2.0, max=10.0, step=0.1, value=8.0, description="Green X"),
    'row_x': FloatSlider(min=2.0, max=10.0, step=0.1, value=8.0, description="ROW X"),
    'seed': IntSlider(min=0, max=100, step=1, value=60, description="Seed"),
    'sigma': IntSlider(min=2, max=10, step=1, value=5, description="Sigma"),
    'tau_mult': FloatSlider(min=0, max=2.5, step=0.1, value=1, description="Tau Multiplier"),
    'asymmetric_trade': FloatSlider(min=-1, max=1, step=0.1, value=0, description="Tau Asymmetry"),
    'line_width': FloatSlider(min=0.25, max=4, step=0.25, value=1, description="Line Width"),
    'dispersion': FloatSlider(min=0.5, max=2.5, step=0.1, value=1.5, description="Dispersion"),
    'show_mr': Checkbox(value=False, description="Show MR"),
    'show_gdp_labels': Checkbox(value=False, description="Show GDP"),
    'table_type': Dropdown(options=['none', 'flows', 'tau'], value='none', description="Table Type"),
}

# Thematic groups. They hold the very same widget objects as `controls`.
ui_xpos = VBox([controls[key] for key in ('orange_x', 'green_x', 'row_x')])
ui_gdp = VBox([controls[key] for key in ('Y_orange', 'Y_green', 'Y_row')])
ui_misc = VBox([controls[key] for key in ('seed', 'sigma')])
ui_tau = VBox([controls[key] for key in ('tau_mult', 'asymmetric_trade')])
ui_display = VBox([controls[key] for key in ('line_width', 'dispersion')])
ui_labels = VBox([controls[key] for key in ('show_mr', 'show_gdp_labels', 'table_type')])

# Re-render the plot whenever any control changes.
out = interactive_output(plot_trade_potential, controls)

# Two rows of titled columns, each column reusing one group's widgets.
labeled_ui = VBox([
    HBox([
        VBox([Label(title), *group.children])
        for title, group in (
            ("Location Controls", ui_xpos),
            ("GDP Parameters", ui_gdp),
            ("Model Settings", ui_misc),
        )
    ]),
    HBox([
        VBox([Label(title), *group.children])
        for title, group in (
            ("Tau Settings", ui_tau),
            ("Display Settings", ui_display),
            ("Labels", ui_labels),
        )
    ]),
])

display(VBox([labeled_ui, out]))




VBox(children=(VBox(children=(HBox(children=(VBox(children=(Label(value='Location Controls'), FloatSlider(valu…