In [437]:
import os
import numpy as np
import time
import cProfile, pstats, io
from models import ModelParams, ModelShocks, ModelSol
from optimization import objective_w_hat_reduced
from scipy.optimize import minimize

In [431]:
out_dir = "output"
os.makedirs(out_dir, exist_ok=True)

# data = np.load("real_data.npz")
data = np.load("real_data_2017.npz")
N, J = data["N"], data["J"]
mp = ModelParams(
    N=N,
    J=J,
    alpha=data["alpha"],
    beta=data["beta"],
    gamma=data["gamma"],
    theta=data["theta"],
    pif=data["pi_f"],
    pim=data["pi_m"],
    tilde_tau=data["tilde_tau"],
    Xf=np.ones((N, J)),
    Xm=np.ones((N, J)),
    w0=data["VA"],
    L0=np.ones_like(data["VA"]),
    td=data["D"],
)

bench_shocks = ModelShocks.load_from_npz(os.path.join(out_dir, "benchmark/shocks.npz"), mp)


GLOBAL_mp = None
GLOBAL_bench_sol = None
GLOBAL_numeraire_index = None


class EarlyStopException(Exception):
    """Optimization early stop signal."""

    pass


def init_worker(mp, bench_sol, numeraire_index):
    """Initialize worker processes."""
    global GLOBAL_mp
    global GLOBAL_bench_sol
    global GLOBAL_numeraire_index
    GLOBAL_mp = mp
    GLOBAL_bench_sol = bench_sol
    GLOBAL_numeraire_index = numeraire_index

In [434]:
dim_reduced = N - 1
# Initial guess for the reduced problem
x0_guess = np.ones(dim_reduced)

# (B) Example of early stop with callback
best_x = [None]
res = None

iter_count = [0]

n = len(x0_guess)
eps = 1e-12  # 0に限りなく近い正の値を設定
bnds = [(eps, None)] * n  # 下限：eps, 上限：制限なし

numeraire_index = 0

Xf_init = mp.Xf.copy()
Xm_init = mp.Xm.copy()


def callback_func(xk):
    """Callback function to check the objective value and stop the optimization."""
    iter_count[0] += 1  # Increment the iteration counter
    val = objective_w_hat_reduced(
        xk, mp, bench_shocks, Xf_init, Xm_init, numeraire_index
    )
    # Print the current loss value for each iteration
    print(f"Iteration {iter_count[0]}: loss = {val}")
    threshold = 1e-6
    if val < threshold:
        best_x[0] = xk.copy()
        raise EarlyStopException(
            f"Residual {val} < threshold {threshold}. Early stopping."
        )

# try:
#     # (C) Optimize w_hat by using Nelder-Mead method
#     res = minimize(
#         objective_w_hat_reduced,
#         x0_guess,
#         args=(mp, bench_shocks, Xf_init, Xm_init, numeraire_index),
#         # method="Nelder-Mead",
#         method="L-BFGS-B",
#         bounds=bnds,
#         callback=callback_func,
#         options={"maxiter": 10000, "disp": True},
#     )
# except EarlyStopException as e:
#     print("Early stop triggered:", e)


In [435]:
# Create a profiler object
profiler = cProfile.Profile()
profiler.enable()  # start profiling

# Run the optimization and capture the return value in 'result'
result = minimize(
    objective_w_hat_reduced,
    x0_guess,
    args=(mp, bench_shocks, Xf_init, Xm_init, numeraire_index),
    method="L-BFGS-B",
    bounds=bnds,
    callback=callback_func,
    options={"maxiter": 10000, "disp": True},
)

profiler.disable()  # stop profiling


Iteration 1: loss = 0.012346351852317361
Iteration 2: loss = 0.006001047895900565
Iteration 3: loss = 0.005843174136111443
Iteration 4: loss = 0.004451119477670087
Iteration 5: loss = 0.0035338343563096897
Iteration 6: loss = 0.0032484028623443895
Iteration 7: loss = 0.002954258071530411
Iteration 8: loss = 0.002945698302513719
Iteration 9: loss = 0.0029453131398309166
Iteration 10: loss = 0.002945293121120793
Iteration 11: loss = 0.00294529229266479


In [438]:
s = io.StringIO()
stats = pstats.Stats(profiler, stream=s).sort_stats('cumtime')
stats.print_stats(20)
print(s.getvalue())

         23047238 function calls (23047067 primitive calls) in 364.111 seconds

   Ordered by: cumulative time
   List reduced from 378 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      3/2    0.000    0.000  364.111  182.055 /opt/miniconda3/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3541(run_code)
      3/2    0.000    0.000  364.111  182.055 {built-in method builtins.exec}
        1    0.000    0.000  364.110  364.110 /var/folders/q7/lwn2vx9138g3s5z2vnwb01z80000gn/T/ipykernel_31362/3296041888.py:1(<module>)
        1    0.000    0.000  364.110  364.110 /opt/miniconda3/lib/python3.12/site-packages/scipy/optimize/_minimize.py:53(minimize)
     2376    0.597    0.000  362.056    0.152 /opt/miniconda3/lib/python3.12/site-packages/scipy/optimize/_differentiable_functions.py:16(wrapped)
       71    0.000    0.000  361.161    5.087 /opt/miniconda3/lib/python3.12/site-packages/scipy/optimize/_differentiable_functi

## Check the calc_X function

In [443]:
import numpy as np
import time
from scipy.linalg import block_diag

# ---------------------------------------------------------------------------
# Define dummy ModelParams and ModelShocks classes for testing.
# In your actual code, you likely already have these defined.
# ---------------------------------------------------------------------------
class ModelParams:
    def __init__(self, N, J):
        self.N = N
        self.J = J
        self.alpha = np.random.rand(N, J)          # (N, J)
        self.beta = np.random.rand(N, J)           # (N, J)
        self.gamma = np.random.rand(N, J, J)         # (N, J, J)
        self.theta = np.random.rand(J) + 1.0         # (J,)
        self.pif = np.random.rand(N, N, J)           # (N, N, J)
        self.pim = np.random.rand(N, N, J)           # (N, N, J)
        self.tilde_tau = np.ones((N, N, J))          # (N, N, J)
        self.Xf = np.random.rand(N, J)               # (N, J)
        self.Xm = np.random.rand(N, J)               # (N, J)
        self.w0 = np.random.rand(N) + 1.0            # (N,)
        self.L0 = np.random.rand(N) + 1.0            # (N,)
        self.td = np.random.rand(N)                  # (N,)

class ModelShocks:
    def __init__(self, mp):
        self.params = mp
        self.lambda_hat = np.exp(np.random.normal(0.0, 0.2, size=(mp.N, mp.J)))  # (N, J)
        self.df_hat = np.ones((mp.N, mp.N, mp.J))  # (N, N, J)
        self.dm_hat = np.ones((mp.N, mp.N, mp.J))  # (N, N, J)
        self.tilde_tau_prime = np.ones((mp.N, mp.N, mp.J))  # (N, N, J)

# Define dimensions for testing
N, J = 30, 20
mp = ModelParams(N, J)
shocks = ModelShocks(mp)

# Define additional inputs
w_hat = np.random.rand(N) + 1.0   # (N,)
td_prime = np.random.rand(N)      # (N,)

# ---------------------------------------------------------------------------
# Functions for calculating A and B (unchanged from your specification)
# ---------------------------------------------------------------------------
def calc_A(w_hat, td_prime, mp):
    # Compute Af = alpha * (w_hat * w0 * L0 + td_prime) with shape (N, J)
    Af = mp.alpha * (w_hat * mp.w0 * mp.L0 + td_prime)[:, np.newaxis]
    Am = np.zeros_like(Af)  # (N, J)
    # Vectorize using default C-order; we later re-vectorize in Fortran order
    Af_vec = Af.reshape(-1)
    Am_vec = Am.reshape(-1)
    A = np.concatenate([Af_vec, Am_vec])  # (2 * N * J,)
    return A

def calc_B(pif_hat, pim_hat, mp, shocks):
    N, J = mp.alpha.shape
    # --- Compute Bff ---
    factorff = (shocks.tilde_tau_prime - 1) / shocks.tilde_tau_prime  # shape: (N, N, J)
    pif_prime = pif_hat * mp.pif  # shape: (N, N, J)
    # Sum over the importer dimension (axis=1) for each sector:
    U = np.sum(factorff * pif_prime, axis=1)  # shape: (N, J)
    V = mp.alpha  # shape: (N, J)
    u, v = U.reshape(-1), V.reshape(-1)  # each of shape (N*J,)
    Du = np.diag(u)
    # For C-order, rows (countries) are contiguous.
    # R = kron(I_N, ones(1, J)) sums over columns for each country.
    R = np.kron(np.eye(N), np.ones((1, J)))  # shape: (N, N*J)
    # P = kron(I_N, ones(J, 1)) replicates each country sum J times.
    P = np.kron(np.eye(N), np.ones((J, 1)))   # shape: (N*J, N)
    Dv = np.diag(v)
    Bff = Dv @ P @ R @ Du  # shape: (N*J, N*J)

    # --- Compute Bfm ---
    pim_prime = pim_hat * mp.pim  # shape: (N, N, J)
    U = np.sum(factorff * pim_prime, axis=1)  # shape: (N, J)
    u = U.reshape(-1)
    Du = np.diag(u)
    Bfm = Dv @ P @ R @ Du  # shape: (N*J, N*J)

    # --- Compute Bmf ---
    U_temp = pif_prime / shocks.tilde_tau_prime  # shape: (N, N, J)
    U_trans = U_temp.transpose(1, 0, 2)  # shape: (N, N, J) with indices (exporter, importer, sector)
    V_temp = mp.gamma  # shape: (N, J, J)
    # Compute Bmf_tensor with indices (n, s, i, k) = mp.gamma[n,k,s]*U_trans[i,n,k]
    Bmf_tensor = np.einsum("nks,ink->nsik", V_temp, U_trans)
    Bmf = Bmf_tensor.reshape((N * J, N * J))
    
    # --- Compute Bmm ---
    U_temp = pim_prime / shocks.tilde_tau_prime  # shape: (N, N, J)
    U_trans = U_temp.transpose(1, 0, 2)  # shape: (N, N, J)
    Bmm_tensor = np.einsum("nks,ink->nsik", V_temp, U_trans)
    Bmm = Bmm_tensor.reshape((N * J, N * J))
    
    # --- Assemble full B ---
    B_top = np.hstack((Bff, Bfm))       # shape: (N*J, 2*N*J)
    B_bottom = np.hstack((Bmf, Bmm))      # shape: (N*J, 2*N*J)
    B = np.vstack((B_top, B_bottom))    # shape: (2*N*J, 2*N*J)
    return B

# ---------------------------------------------------------------------------
# Full method: Solve the entire system (I - B)x = A, vectorizing in Fortran order.
# ---------------------------------------------------------------------------
def calc_X_full_F(w_hat, pif_hat, pim_hat, td_prime, mp, shocks):
    N, J = mp.alpha.shape
    I = np.eye(2 * N * J, dtype=np.float64)
    # Re-vectorize A and B in Fortran order.
    A_vec = calc_A(w_hat, td_prime, mp).reshape(-1, order='F')
    B_vec = calc_B(mp.pif, mp.pim, mp, shocks).reshape((2 * N * J, 2 * N * J), order='F')
    X_vec = np.linalg.solve(I - B_vec, A_vec)
    Xf_vec = X_vec[:N * J]
    Xm_vec = X_vec[N * J:]
    Xf = Xf_vec.reshape((N, J), order='F')
    Xm = Xm_vec.reshape((N, J), order='F')
    return Xf, Xm

# ---------------------------------------------------------------------------
# Blockwise method: Solve the system sector-by-sector (in Fortran order)
# ---------------------------------------------------------------------------
def calc_X_block_F(w_hat, pif_hat, pim_hat, td_prime, mp, shocks):
    N, J = mp.alpha.shape
    A_vec = calc_A(w_hat, td_prime, mp).reshape(-1, order='F')
    B_vec = calc_B(mp.pif, mp.pim, mp, shocks).reshape((2 * N * J, 2 * N * J), order='F')
    Xf_solution = np.empty((N, J), dtype=np.float64)
    Xm_solution = np.empty((N, J), dtype=np.float64)
    
    # In Fortran order, for an (N, J) matrix the columns (sectors) are contiguous.
    for j in range(J):
        # The block for sector j in Xf is contiguous in Fortran order:
        start = j * N
        end = (j + 1) * N
        # For Xf (first N*J entries) the indices for sector j:
        indices_f = np.arange(start, end)
        # For Xm (next N*J entries) the indices for sector j:
        indices_m = np.arange(N * J + start, N * J + end)
        sector_indices = np.concatenate([indices_f, indices_m])  # shape: (2N,)
        A_sector = A_vec[sector_indices]
        B_sector = B_vec[np.ix_(sector_indices, sector_indices)]
        I_sector = np.eye(2 * N, dtype=np.float64)
        X_sector = np.linalg.solve(I_sector - B_sector, A_sector)
        Xf_solution[:, j] = X_sector[:N]
        Xm_solution[:, j] = X_sector[N:]
    return Xf_solution, Xm_solution

# ---------------------------------------------------------------------------
# Test and time both methods.
# ---------------------------------------------------------------------------
start_full = time.perf_counter()
Xf_full, Xm_full = calc_X_full_F(w_hat, mp.pif, mp.pim, td_prime, mp, shocks)
end_full = time.perf_counter()
time_full = end_full - start_full

start_block = time.perf_counter()
Xf_block, Xm_block = calc_X_block_F(w_hat, mp.pif, mp.pim, td_prime, mp, shocks)
end_block = time.perf_counter()
time_block = end_block - start_block

print("Time for full method (Fortran order):", time_full, "seconds")
print("Time for block method (Fortran order):", time_block, "seconds")
print("Max difference in Xf:", np.max(np.abs(Xf_full - Xf_block)))
print("Max difference in Xm:", np.max(np.abs(Xm_full - Xm_block)))

Time for full method (Fortran order): 0.02467595797497779 seconds
Time for block method (Fortran order): 0.010439625009894371 seconds
Max difference in Xf: 0.0
Max difference in Xm: 328.13300351406866
