<a href="https://colab.research.google.com/github/Ishita7078/test_tpu/blob/main/GPUvsTPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Why TPUs?

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


TPUs (Tensor Processing Units) are accelerators developed by Google to speed up operations on a TensorFlow graph. Each TPU packs up to 180 teraflops of floating-point performance and 64 GB of high-bandwidth memory onto a single board. Here is a comparison between TPUs and Nvidia GPUs. The y-axis shows the number of images processed per second and the x-axis lists the models.

<img src="https://cdn-images-1.medium.com/max/800/1*tVHGjJHJrhKaKECT3Z4CIw.png" alt="Drawing" style="width: 150px;"/>

# Experiment

TPUs were originally available only on Google Cloud, but they are now available for free in Colab. Here we compare a TPU against a GPU in Colab using the MNIST dataset, measuring the time per step and per epoch for several batch sizes.

# Download MNIST

In [None]:
import tensorflow as tf
import os
import numpy as np
from tensorflow.keras.utils import to_categorical

def get_data():

  #Load mnist data set
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

  x_train = x_train.astype('float32') / 255
  x_test = x_test.astype('float32') / 255

  x_train = np.expand_dims(x_train, 3)
  x_test = np.expand_dims(x_test, 3)

  y_train = to_categorical(y_train)
  y_test  = to_categorical(y_test)

  return x_train, y_train, x_test, y_test
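
The preprocessing in `get_data` can be checked without downloading anything. The sketch below repeats the same three steps (scale to [0, 1], add a channel axis, one-hot encode) on synthetic arrays. The `preprocess` helper and the fake data are illustrative names, not part of the notebook, and the one-hot step uses a NumPy identity matrix as a stand-in for `to_categorical`.

```python
import numpy as np

def preprocess(images, labels, num_classes=10):
    """Scale uint8 images to [0, 1], add a channel axis, one-hot encode labels."""
    x = images.astype("float32") / 255.0
    x = np.expand_dims(x, axis=-1)                      # (N, 28, 28) -> (N, 28, 28, 1)
    y = np.eye(num_classes, dtype="float32")[labels]    # row k of the identity is the one-hot vector for class k
    return x, y

# Synthetic stand-in for MNIST: four random 28x28 images with integer labels
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)
fake_labels = np.array([3, 0, 9, 3])

x, y = preprocess(fake_images, fake_labels)
print(x.shape, y.shape)
```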

# Basic CNN

Running the code on a TPU takes a little more work: we need to specify the address of the TPU and tell TensorFlow to run the model on the TPU cluster.

In [None]:
# TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2.x
from tensorflow.contrib.tpu.python.tpu import keras_support

def get_model(tpu = False):
  model = tf.keras.Sequential()

  #add layers to the model
  model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)))
  model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
  model.add(tf.keras.layers.Dropout(0.3))

  model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
  model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
  model.add(tf.keras.layers.Dropout(0.3))

  model.add(tf.keras.layers.Flatten())
  model.add(tf.keras.layers.Dense(256, activation='relu'))
  model.add(tf.keras.layers.Dropout(0.5))
  model.add(tf.keras.layers.Dense(10, activation='softmax'))

  #compile the model
  model.compile(loss='categorical_crossentropy',
               optimizer='adam',
               metrics=['accuracy'])

  #flag to run on tpu
  if tpu:
    tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]

    #connect the TPU cluster using the address
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)

    #run the model on different clusters
    strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)

    #convert the model to run on tpu
    model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
  return model
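
The `tf.contrib` calls above exist only in TensorFlow 1.x. As a rough sketch of how the same connection might look on TensorFlow 2.3 or later, the helper below uses `tf.distribute.TPUStrategy`. It assumes that `COLAB_TPU_ADDR` is set, as in the original code, which may not hold on every Colab runtime. The function names `get_tpu_strategy` and `build_model_on` are made up for this example, and TensorFlow is imported lazily so the cell can be defined even where it is not installed.

```python
import os

def get_tpu_strategy():
    """Connect to the Colab TPU and return a TPUStrategy (TensorFlow 2.3+ assumed)."""
    import tensorflow as tf  # lazy import: only needed when the function is called

    # Assumption: the TPU address is exposed through COLAB_TPU_ADDR
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
        tpu="grpc://" + os.environ["COLAB_TPU_ADDR"])
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)

def build_model_on(strategy, build_fn):
    """Build and compile the model inside the strategy scope so its variables are placed on the TPU."""
    with strategy.scope():
        return build_fn()
```

With these helpers, something like `model = build_model_on(get_tpu_strategy(), lambda: get_model(tpu=False))` would replace the `keras_to_tpu_model` conversion.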

# GPU vs TPU


In [None]:
x_train, y_train, x_test, y_test = get_data()

Each time you want to run the model on a TPU, make sure to set the `tpu` flag and change the runtime environment via Edit > Notebook settings > Hardware accelerator > TPU, then click Save.

In [None]:
#set tpu = True if you want to run the model on TPU
model = get_model(tpu = False)

In [None]:
model.fit(x_train,
         y_train,
         batch_size=1024,
         epochs=10,
         validation_data=(x_test, y_test))

Train on 60000 samples, validate on 10000 samples
Epoch 1/3
Epoch 2/3
Epoch 3/3


# Benchmarks

Note that TPU setup takes some time while the model is compiled and the data is distributed across the cluster, so the first epoch takes longer. I only reported the time for the later epochs, averaged across those epochs.
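
The averaging described above can be sketched as a small helper that drops the warm-up epoch before taking the mean. The function name and the timings are hypothetical. They are not the measurements behind the tables.

```python
def mean_epoch_time(epoch_seconds, warmup=1):
    """Average per-epoch wall time, skipping the first `warmup` epochs."""
    steady = list(epoch_seconds)[warmup:]
    if not steady:
        raise ValueError("need at least one epoch after the warm-up")
    return sum(steady) / len(steady)

# Hypothetical timings: the first epoch includes compilation, the rest are steady state
timings = [21.0, 2.1, 1.9, 2.0]
print(mean_epoch_time(timings))
```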

### Epoch Time ($s$)

$$\left[\begin{array}{c|c|c}  
 \textbf{Batch Size} & \textbf{GPU} & \textbf{TPU} \\
 256 & 6s & 6s\\  
 512 & 5s & 3s\\
 1024 & 4s & 2s\\
\end{array}\right]$$

### Step Time ($\mu s$)

$$\left[\begin{array}{c|c|c}  
 \textbf{Batch Size} & \textbf{GPU} & \textbf{TPU} \\
 256 & 94 \mu s & 97 \mu s\\  
 512 & 82 \mu s & 58 \mu s\\
 1024 & 79 \mu s & 37 \mu s\\
\end{array}\right]$$
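
As a quick consistency check between the two tables, the sketch below multiplies each step time by the 60,000 training samples. This assumes the step times are per training sample, which is my reading of the numbers rather than something stated above. Under that assumption the products land close to the reported epoch times.

```python
TRAIN_SAMPLES = 60_000  # size of the MNIST training set

# Times in microseconds, copied from the Step Time table
step_time_us = {
    256: {"GPU": 94, "TPU": 97},
    512: {"GPU": 82, "TPU": 58},
    1024: {"GPU": 79, "TPU": 37},
}

for batch_size, times in step_time_us.items():
    for device, us in times.items():
        epoch_s = us * 1e-6 * TRAIN_SAMPLES
        print(f"batch {batch_size:>4} {device}: {us} us -> {epoch_s:.1f} s per epoch")

# Ratio of GPU to TPU step time at the largest batch size
speedup_1024 = step_time_us[1024]["GPU"] / step_time_us[1024]["TPU"]
print(f"TPU speedup at batch 1024: {speedup_1024:.2f}x")
```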

# References



*   https://qiita.com/koshian2/items/25a6341c035e8a260a01
*   https://medium.com/tensorflow/hello-deep-learning-fashion-mnist-with-keras-50fcff8cd74a
*   https://blog.riseml.com/benchmarking-googles-new-tpuv2-121c03b71384
*   https://cloudplatform.googleblog.com/2018/02/Cloud-TPU-machine-learning-accelerators-now-available-in-beta.html



In [19]:
import os

current_directory = os.getcwd()
print(current_directory)

import time
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

assert any(d.platform == "tpu" for d in jax.devices()), "Not using TPU"
print(jax.devices())

/content




[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]


In [20]:
@jax.jit
def _mvm_full(A, x):
    return A @ x

class TPUDenseMat:
    """
    Wraps a real dense matrix A so `A @ x` runs on TPU via JAX.
    Returns a numpy array, so the rest of the numpy model code remains unchanged.
    """
    def __init__(self, A_np: np.ndarray, jax_dtype=jnp.float32):
        A_np = np.asarray(A_np)
        if A_np.ndim != 2 or A_np.shape[0] != A_np.shape[1]:
            raise ValueError(f"Expected square 2D matrix, got shape {A_np.shape}")
        self.N = int(A_np.shape[0])
        self.jax_dtype = jax_dtype

        self.A_dev = jax.device_put(jnp.asarray(A_np, dtype=jax_dtype))

        x0 = jnp.zeros((self.N,), dtype=jax_dtype)
        _ = _mvm_full(self.A_dev, x0).block_until_ready()

    @property
    def shape(self):
        return (self.N, self.N)

    def __matmul__(self, x):
        x_np = np.asarray(x)
        if x_np.ndim != 1 or x_np.shape[0] != self.N:
            raise ValueError(f"Expected vector shape ({self.N},), got {x_np.shape}")

        x_dev = jax.device_put(jnp.asarray(x_np, dtype=self.jax_dtype))
        y = _mvm_full(self.A_dev, x_dev)
        y.block_until_ready()
        return np.asarray(y)
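
Because `TPUDenseMat` casts the matrix and vector to `float32` by default, the product is less precise than a NumPy `float64` product. The sketch below measures that gap on the CPU with NumPy alone, using a random positive matrix as a stand-in for the real one. It says nothing about the TPU itself. As far as I know, JAX on TPU may use an even lower default matmul precision, which would be worth verifying before relying on tight convergence thresholds.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 256
A = rng.random((N, N))     # positive dense matrix, illustrative only
x = rng.random(N)

y64 = A @ x                                          # float64 reference
y32 = A.astype(np.float32) @ x.astype(np.float32)    # same product carried out in float32

# Largest absolute deviation, relative to the largest entry of the reference
rel_err = np.max(np.abs(y32.astype(np.float64) - y64)) / np.max(np.abs(y64))
print(f"max relative error of the float32 product: {rel_err:.2e}")
```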

In [17]:
import numpy as np
import scipy.io as sio

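def _load_mat_any(path):
    """Load a .mat file, falling back to h5py for MATLAB v7.3 (HDF5) files."""
    try:
        return sio.loadmat(path)
    except NotImplementedError:
        # scipy.io.loadmat does not support v7.3 files, which are HDF5 containers
        import h5py

        out = {}
        with h5py.File(path, 'r') as f:
            for k in f.keys():
                # Note: arrays read via h5py have their axes reversed relative to MATLAB
                out[k] = np.array(f[k])
        return out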


def _load_mat_var(path, var_name=None):
    d = _load_mat_any(path)
    if var_name is not None:
        if var_name not in d:
            raise KeyError(f"Variable '{var_name}' not found in {path}. Keys: {list(d.keys())}")
        return d[var_name]

    keys = [k for k in d.keys() if not k.startswith('__')]
    if len(keys) != 1:
        raise KeyError(f"Expected exactly one variable in {path}; found keys: {keys}")
    return d[keys[0]]


def _load_txt_flat(path):
    arr = np.loadtxt(path, delimiter=',')
    return np.asarray(arr).reshape(-1)


def initialize(load_trmult):
    global H0, a, a_norm, m2, C_vect, tau0, pop0, pop5, pop5_fertadj, popminus5, popminus10, ubar
    global trmult_reduced, n, earth_indices, indicator_sea, subs, beta, tail_bands, ind_islands
    global alpha, theta, Omega, vect_omega

    H0 = _load_mat_var('H0.mat', 'H0')

    a = _load_mat_var('a_H0.mat', 'a_H0')
    tau0 = _load_mat_var('tau_H0.mat', 'tau_H0')
    m2 = _load_mat_var('m2.mat', 'm2')

    a = np.asarray(a).reshape(-1)
    tau0 = np.asarray(tau0).reshape(-1)
    m2 = np.asarray(m2).reshape(-1)

    a_norm = None

    pop0 = _load_txt_flat('l.csv')
    pop5 = _load_txt_flat('pop5.csv')
    popminus5 = _load_txt_flat('popminus5.csv')
    popminus10 = _load_txt_flat('popminus10.csv')
    pop5_fertadj = _load_txt_flat('pop5_fertadj.csv')

    H0_arr = np.asarray(H0)
    earth_indices = np.flatnonzero(H0_arr.reshape(-1) > 0)
    n = int(earth_indices.size)
    indicator_sea = (H0_arr == 0)

    ubar = _load_txt_flat('ubar.csv')
    ubar[np.isnan(ubar)] = 0
    ubar[np.isinf(ubar)] = 0

    if load_trmult == 1:
        trmult_reduced = _load_mat_var('drive/MyDrive/trmult_reduced.mat', 'trmult_reduced')
        trmult_reduced = np.asarray(trmult_reduced)
        trmult_reduced[trmult_reduced < 1e-12] = 0
    else:
        trmult_reduced = None

    C = _load_txt_flat('C.csv')
    C_stock = C[earth_indices]
    indices = np.unique(C_stock)
    C_stock_2 = C_stock.copy()
    for i, idx in enumerate(indices, start=1):
        C_stock_2[C_stock == idx] = i
    C[earth_indices] = C_stock_2

    subs = C.reshape(-1) + 1
    C_vect = C[earth_indices]

    beta = 0.965
    tail_bands = 0.2
    alpha = 0.06
    theta = 6.5
    Omega = 0.5

    results = [
        H0, a, a_norm, m2, C_vect, tau0, pop0, pop5, pop5_fertadj, popminus5, popminus10, ubar,
        trmult_reduced, n, earth_indices, indicator_sea, subs, None, beta, tail_bands, None, alpha, theta, Omega
    ]
    return results
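
The loop in `initialize` that relabels country codes to consecutive integers 1, 2, ... can also be written with `np.unique(..., return_inverse=True)`. The sketch below compares the two on made-up codes. It is meant as an equivalent formulation to consider, not as a change to the function.

```python
import numpy as np

# Hypothetical raw country codes for six land cells
C_stock = np.array([840., 124., 840., 484., 124., 840.])

# Loop version, following initialize()
relabeled_loop = C_stock.copy()
for i, code in enumerate(np.unique(C_stock), start=1):
    relabeled_loop[C_stock == code] = i

# Vectorized version: return_inverse gives each element's 0-based rank among the sorted unique codes
_, inverse = np.unique(C_stock, return_inverse=True)
relabeled_vec = inverse.reshape(C_stock.shape) + 1

print(relabeled_loop, relabeled_vec)
```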


In [4]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from pathlib import Path

def maps(series1, series2, series3, series4, t, earth_indices=None):
    """
    Creates 4 maps at time t.
    The series must be in order:
    - series1: l(t)
    - series2: u(t)
    - series3: prod(t)
    - series4: realgdp(t)
    """
    if earth_indices is None:
        # There is no `init` module in this notebook; use the global set by initialize()
        earth_indices = globals()['earth_indices']

    # Take logs of variables
    series1 = np.log(series1)
    series2 = np.log(series2)
    series3 = np.log(series3)
    series4 = np.log(series4)

    # Define titles for the plots
    titles = [
        'Log population density, time {}'.format(t),
        'Log utility, time {}'.format(t),
        'Log productivity, time {}'.format(t),
        'Log real GDP per capita, time {}'.format(t)
    ]

    # Define title names for saving files
    title_names = ['PD', 'U', 'PR', 'RO']

    # Plot each figure
    for i, (series, title, title_name) in enumerate(zip([series1, series2, series3, series4], titles, title_names), start=1):
        plt.figure()

        # Create the map array
        varm = np.full((180, 360), -np.inf)
        varm.flat[earth_indices] = series

        # Set color limits based on the series and time
        if i == 1:
            vmin, vmax = -10, 21
        elif i == 3:
            if t == 1:
                vmin, vmax = -3, 7
            elif t == 600:
                vmin, vmax = 11, 21
            else:
                vmin, vmax = None, None
        elif i == 4:
            if t == 1:
                vmin, vmax = -4, 3
            else:
                vmin, vmax = None, None
        else:
            vmin, vmax = None, None

        # Plot the map
        plt.imshow(varm, cmap='jet', vmin=vmin, vmax=vmax)
        plt.colorbar(label='Value', orientation='vertical')
        plt.title(title)

        # Save the output to disk
        Path('Maps').mkdir(parents=True, exist_ok=True)
        filename = 'Maps/{}_NF_{}_1000.png'.format(title_name, t)
        plt.savefig(filename)
        plt.close()  # Close the figure to avoid memory issues
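
The step `varm.flat[earth_indices] = series` scatters one value per land cell back onto the 180 x 360 grid. A minimal sketch on a 3 x 4 grid (all values invented) shows why this works: `np.flatnonzero` on the flattened array and `.flat` both use row-major order, so the indices line up.

```python
import numpy as np

# Tiny 3x4 grid standing in for the 180x360 map; zeros are sea cells
land_area = np.array([[0., 2., 0., 1.],
                      [3., 0., 0., 4.],
                      [0., 0., 5., 0.]])
earth_idx = np.flatnonzero(land_area.reshape(-1) > 0)    # row-major positions of land cells

values = np.log(np.array([10., 20., 30., 40., 50.]))     # one value per land cell

grid = np.full(land_area.shape, -np.inf)                 # sea cells stay at -inf
grid.flat[earth_idx] = values                            # .flat indexes in the same row-major order

print(earth_idx)
print(grid)
```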


In [5]:
import numpy as np
from scipy.special import gamma

def model(H, T, vars):
    # Global variables (make sure these are initialized or imported correctly)
    global a, a_norm, m2, ubar, C_vect, tau0, pop0, trmult_reduced, earth_indices, H0, n, alpha, theta, Omega
    H0, a, a_norm, m2, C_vect, tau0, pop0, _, _, _, _, ubar, trmult_reduced, n, earth_indices, _, _, _, _, _, _, alpha, theta, Omega = vars

    # Initialize parameters and output
    # Normalize population to population density
    H0_arr = np.asarray(H0).reshape(-1)
    popdens = pop0.copy()
    popdens[earth_indices] = popdens[earth_indices] / H0_arr[earth_indices]
    popdens[np.isinf(popdens)] = 0
    popdens[np.isnan(popdens)] = 0

    # Parameter values
    lbar = 5.9174e+09  # Total population
    lambda_ = 0.32     # Congestion externalities
    gamma1 = 0.319    # Elasticity of tomorrow's productivity w.r.t. today's innovation
    gamma2 = 0.99246  # Elasticity of tomorrow's productivity w.r.t. today's productivity
    eta = 1           # Parameter driving scale of technology diffusion
    mu = 0.8          # Labor share in production
    nu = 0.15         # Intercept parameter in innovation cost function
    ksi = 125         # Elasticity of innovation costs w.r.t. innovation
    sigma = 4         # Elasticity of substitution
    rad = 6371        # Radius of Earth
    psi = 1.8         # Subjective wellbeing parameter
    khi = lambda_ - (alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta) / (1 + 2 * theta)
    kappa1 = ((mu * ksi + gamma1) / ksi) ** (-(mu + gamma1 / ksi) * theta) * mu ** (mu * theta) * \
             (ksi * nu / gamma1) ** (-gamma1 / ksi * theta) * gamma(1 - (sigma - 1) / theta) ** (theta / (sigma - 1))

    # Calculate utility from subjective wellbeing
    u0 = np.exp(psi * ubar[earth_indices])

    # Back out amenities
    a_norm = a * u0

    # Initialize output variables
    l = np.zeros((n, T))
    w = np.zeros((n, T))
    phi = np.zeros((n, T))
    realgdp = np.zeros((n, T))
    tau = np.zeros((n, T))
    u = np.zeros((n, T))
    uhat = np.zeros((n, T))

    # 2. Simulate the model

    # Update productivity from period 0 to period 1 levels according to equations (8), (12), and (13)
    avgprod = np.sum(tau0) / n
    tau[:, 0] = eta * tau0 ** gamma2 * avgprod ** (1 - gamma2) * \
                (gamma1 / (nu * (gamma1 + mu * ksi)) * popdens[earth_indices]) ** (gamma1 * theta / ksi)

    # Initial guess for uhat
    uhat_loop = np.ones(n) / n

    # Calculate equilibrium distribution for each period
    for t in range(T):
        print(f't={t + 1}')

        # Solve for uhat using equation (51)
        error = 1e+10

        # Pre-computed quantities used in the while loop
        aa = a_norm ** (theta ** 2 / (1 + 2 * theta))
        aa2 = a_norm ** ((1 + theta) / ((khi / Omega + (1 + theta) / (1 + 2 * theta)) * (1 + 2 * theta)))
        exponent_l = (1 - lambda_ * theta + (1 + theta) / (1 + 2 * theta) * (alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta))
        input_integral_outer = aa * (H ** (theta / (1 + 2 * theta) - 1 + lambda_ * theta - (1 + theta) / (1 + 2 * theta) * (alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta))) * \
                               tau[:, t] ** ((1 + theta) / (1 + 2 * theta)) * m2 ** (-exponent_l / Omega)

        input_integral_outer[np.isnan(input_integral_outer)] = 0
        input_uhat_inner = H ** ((lambda_ - (alpha + (lambda_ + gamma1 / ksi - (1 - mu)) * theta) / (1 + 2 * theta)) / (khi / Omega + (1 + theta) / (1 + 2 * theta))) * \
                            tau[:, t] ** (1 / ((khi / Omega + (1 + theta) / (1 + 2 * theta)) * (1 + 2 * theta))) * \
                            m2 ** (khi / (Omega * (khi / Omega + (1 + theta) / (1 + 2 * theta))))
        input_uhat_inner[H == 0] = 0

        # Inner loop
        while error >= 1e-2:
            uhat_old = uhat_loop.copy()
            input_integral_inner = input_integral_outer * uhat_loop ** (exponent_l / Omega - theta ** 2 / (1 + 2 * theta))
            input_integral_inner[uhat_loop == 0] = 0

            # Matrix product
            rhs = np.dot(trmult_reduced, input_integral_inner)
            eps_val = 1e-12
            rhs = np.maximum(rhs, eps_val)

            uhat_loop = aa2 * input_uhat_inner * rhs ** (1 / (khi * theta / Omega + theta * (1 + theta) / (1 + 2 * theta)))
            error = np.sum((uhat_loop - uhat_old) ** 2)

        uhat[:, t] = uhat_loop

        # Solve for u using equation (53)
        u[:, t] = uhat[:, t] / (lbar / np.sum(uhat[:, t] ** (1 / Omega) * m2 ** (-1 / Omega))) ** \
                  (Omega * (((1 / Omega) * (((lambda_ + (1 - mu) - gamma1 / ksi) * theta) - alpha) + theta) / theta - 1))

        # Solve for population using equation (7)
        l[:, t] = H ** -1 * u[:, t] ** (1 / Omega) * m2 ** (-1 / Omega)

        # Rescale L so that H * L sums to lbar
        l[:, t] = l[:, t] / np.sum(H * l[:, t]) * lbar

        # Calculate other quantities
        phi[:, t] = (gamma1 / (nu * (gamma1 + mu * ksi))) ** (1 / ksi) * l[:, t] ** (1 / ksi)
        w[:, t] = a_norm ** (-theta / (1 + 2 * theta)) * u[:, t] ** (theta / (1 + 2 * theta)) * H ** (-1 / (1 + 2 * theta)) * \
                  tau[:, t] ** (1 / (1 + 2 * theta)) * l[:, t] ** ((alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta) / (1 + 2 * theta))

        # Normalize wages relative to Princeton, NJ
        w[:, t] = w[:, t] / w[3198, t]  # 3198 is the Python index for Princeton, NJ

        # Calculate real GDP per capita using equation (22)
        realgdp[:, t] = u[:, t] / a_norm * l[:, t] ** lambda_

        # Calculate trade to GDP ratio in periods 1 and T
        if t == 0 or t == T - 1:
            print('TOTAL IMPORTS TO WORLD GDP')
            trsharesum = np.dot(trmult_reduced, (tau[:, t] * l[:, t] ** (alpha - (1 - mu - gamma1 / ksi) * theta) * w[:, t] ** (-theta)))
            eps_val = 1e-12
            trsharesum = np.maximum(trsharesum, eps_val)
            domtrade = 0
            for i in range(n):
                for j in range(n):
                    if C_vect[i] == C_vect[j]:
                        domtrade += (tau[j, t] * l[j, t] ** (alpha - (1 - mu - gamma1 / ksi) * theta) * w[j, t] ** (-theta) *
                                     trmult_reduced[i, j] / trsharesum[i] * w[i, t] * H[i] * l[i, t])
            print(1 - domtrade / np.sum(w[:, t] * H * l[:, t]))

        # Update productivity according to equation (8)
        if t < T - 1:
            avgprod = np.sum(tau[:, t]) / n
            tau[:, t + 1] = eta * tau[:, t] ** gamma2 * avgprod ** (1 - gamma2) * phi[:, t] ** (gamma1 * theta)

    # Handle NaN values
    realgdp[np.isnan(realgdp)] = 0
    tau[np.isnan(tau)] = 0
    phi[np.isnan(phi)] = 0
    w[np.isnan(w)] = 0
    l[np.isnan(l)] = 0

    return l, w, u, tau, phi, realgdp
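
The double loop that accumulates `domtrade` in `model` runs n squared Python iterations. One possible vectorization, sketched below on random data, builds a same-country mask and folds the inner sum into a matrix-vector product. The argument names (`T`, `x`, `s`, `y`, `country`) are shorthand introduced here for `trmult_reduced`, the bracketed productivity term, `trsharesum`, `w * H * l`, and `C_vect`. For large n the mask needs n squared entries, so processing rows in blocks may be necessary.

```python
import numpy as np

def domestic_trade_loop(T, x, s, y, country):
    """Reference implementation: same structure as the loop in model()."""
    total = 0.0
    n = len(country)
    for i in range(n):
        for j in range(n):
            if country[i] == country[j]:
                total += x[j] * T[i, j] / s[i] * y[i]
    return total

def domestic_trade_vectorized(T, x, s, y, country):
    """Same sum, using a boolean mask of same-country pairs."""
    same = country[:, None] == country[None, :]   # (n, n) mask
    within = (T * same) @ x                       # for each i: sum over same-country j of T[i, j] * x[j]
    return float(np.sum(within / s * y))

rng = np.random.default_rng(0)
n = 30
T = rng.random((n, n))
x, y = rng.random(n), rng.random(n)
s = T @ x
country = rng.integers(1, 5, size=n)

loop_value = domestic_trade_loop(T, x, s, y, country)
vec_value = domestic_trade_vectorized(T, x, s, y, country)
print(loop_value, vec_value)
```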


In [6]:
import numpy as np
from math import gamma

def backward(H, T, vars):
    # Ensure global variables are available
    global a_norm, m2, tau0, pop0, trmult_reduced, earth_indices, H0, n, alpha, theta, Omega
    H0, a, a_norm, m2, _, tau0, pop0, _, _, _, _, ubar, trmult_reduced, n, earth_indices, _, _, _, _, _, _, alpha, theta, Omega = vars

    # Initialize parameters and output
    # Normalize population to population density
    H0_arr = np.asarray(H0).reshape(-1)
    popdens = np.copy(pop0)
    popdens[earth_indices] = popdens[earth_indices] / H0_arr[earth_indices]
    popdens[np.isinf(popdens)] = 0
    popdens[np.isnan(popdens)] = 0

    if a_norm is None:
        psi = 1.8
        u0 = np.exp(psi * ubar[earth_indices])
        a_norm = np.asarray(a).reshape(-1) * u0

    # Parameter values
    lbar = 5.9174e+09
    lambda_ = 0.32
    gamma1 = 0.319
    gamma2 = 0.99246
    mu = 0.8
    nu = 0.15
    ksi = 125
    sigma = 4
    rad = 6371
    khi = lambda_ - (alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta) / (1 + 2 * theta)
    kappa1 = ((mu * ksi + gamma1) / ksi) ** (-(mu + gamma1 / ksi) * theta) * \
             mu ** (mu * theta) * (ksi * nu / gamma1) ** (-gamma1 / ksi * theta) * \
             gamma(1 - (sigma - 1) / theta) ** (theta / (sigma - 1))

    # Initialize output variables
    l = np.zeros((n, T))
    u = np.zeros((n, T))
    w = np.zeros((n, T))
    phi = np.zeros((n, T))
    tau = np.zeros((n, T))
    realgdp = np.zeros((n, T))

    # 2. Simulate the model backwards

    # Initial guess for Lhat
    l_loop = np.copy(popdens[earth_indices])

    # Outer loop
    for t in range(T):
        print(f't={-t - 1}')

        # Next period's productivity
        if t > 0:
            taunext = tau[:, t - 1]
        else:
            taunext = tau0

        eps_val = 1e-12
        eps_pos = 1e-300
        taunext = np.asarray(taunext)
        taunext = np.maximum(taunext, eps_pos)
        taunext = np.minimum(taunext, 1e300)

        # Solve for Lhat
        error = 1e+10

        # Pre-computed quantities used in the while loop
        aa = a_norm ** (theta ** 2 / (1 + 2 * theta))
        aa2 = a_norm ** ((1 + theta) / ((khi + Omega * (1 + theta) / (1 + 2 * theta) + theta / (1 + 2 * theta) * gamma1 / (ksi * gamma2)) * (1 + 2 * theta)))
        exponent_l = (1 - lambda_ * theta + (1 + theta) / (1 + 2 * theta) * (alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta))
        input_integral_outer = aa * H ** ((theta - theta ** 2 * Omega) / (1 + 2 * theta)) * \
                               taunext ** ((1 + theta) / (gamma2 * (1 + 2 * theta))) * \
                               m2 ** (-theta ** 2 / (1 + 2 * theta))
        input_integral_outer[~np.isfinite(input_integral_outer)] = 0
        denom_inner = (khi + Omega * (1 + theta) / (1 + 2 * theta) + theta / (1 + 2 * theta) * gamma1 / (ksi * gamma2))
        input_l_inner = H ** (-(1 + Omega * (1 + theta)) / (denom_inner * (1 + 2 * theta))) * \
                        taunext ** (1 / (denom_inner * gamma2 * (1 + 2 * theta))) * \
                        m2 ** (-(1 + theta) / (denom_inner * (1 + 2 * theta)))
        input_l_inner[H == 0] = 0
        input_l_inner[~np.isfinite(input_l_inner)] = 0
        input_l_inner = np.maximum(input_l_inner, 0)
        input_l_inner = np.minimum(input_l_inner, 1e300)

        # Inner loop - solve for l using equation (40)
        it = 0
        max_it = 2000
        while error >= 1:
            l_old = np.copy(l_loop)

            l_loop = np.maximum(l_loop, eps_pos)
            l_loop = np.minimum(l_loop, 1e300)

            input_integral_inner = input_integral_outer * \
                                   l_loop ** (exponent_l - Omega * theta ** 2 / (1 + 2 * theta) - theta * (1 + theta) / (1 + 2 * theta) * gamma1 / (ksi * gamma2))
            input_integral_inner[l_loop == 0] = 0
            input_integral_inner[~np.isfinite(input_integral_inner)] = 0

            # Matrix product
            rhs = np.dot(trmult_reduced, input_integral_inner)
            rhs = np.maximum(rhs, eps_val)

            l_loop = aa2 * input_l_inner * rhs ** (1 / ((khi + Omega * (1 + theta) / (1 + 2 * theta) + theta / (1 + 2 * theta) * gamma1 / (ksi * gamma2)) * theta))
            l_loop[~np.isfinite(l_loop)] = eps_pos
            l_loop = np.minimum(l_loop, 1e300)
            error = np.sum((l_loop - l_old) ** 2)
            if not np.isfinite(error):
                error = 0
            it += 1
            if it >= max_it:
                error = 0

        # Rescale L so that H * L sum to lbar
        denom = np.sum(H * l_loop)
        if (not np.isfinite(denom)) or denom <= eps_pos:
            denom = eps_pos
        l[:, t] = l_loop / denom * lbar

        # Back out productivity using equation (39)
        tau[:, t] = ((mu + gamma1 / ksi) / (gamma1 / ksi) * nu) ** (theta * gamma1 / (ksi * gamma2)) * \
                    taunext ** (1 / gamma2) * l[:, t] ** (-theta * gamma1 / (ksi * gamma2))
        avgprodtogamma2 = np.sum(tau[:, t]) / n
        tau[:, t] = avgprodtogamma2 ** (gamma2 - 1) * tau[:, t]
        tau[:, t] = np.maximum(tau[:, t], eps_pos)
        tau[:, t] = np.minimum(tau[:, t], 1e300)

        # Calculate utility
        u[:, t] = m2 * l[:, t] ** Omega * (kappa1 ** (1 / Omega) * \
                  ((mu + gamma1 / ksi) / (gamma1 / ksi) * nu) ** (gamma1 / (ksi * gamma2)) * \
                  (np.sum(tau[:, t]) / n) ** (1 / theta * (1 - 1 / gamma2)) * \
                  (lbar / denom) ** (1 / theta - 2 * lambda_ + (alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta) / theta - Omega - gamma1 / (ksi * gamma2)))
        u[:, t][~np.isfinite(u[:, t])] = 0

        # Calculate real GDP per capita using equation (22)
        realgdp[:, t] = u[:, t] / a_norm * l[:, t] ** lambda_

        # Calculate innovation using equation (12) and (13)
        phi[:, t] = (gamma1 / (nu * (gamma1 + mu * ksi))) ** (1 / ksi) * l[:, t] ** (1 / ksi)

        # Calculate wage using equation (23)
        w[:, t] = a_norm ** (-theta / (1 + 2 * theta)) * u[:, t] ** (theta / (1 + 2 * theta)) * H ** (-1 / (1 + 2 * theta)) * \
                  tau[:, t] ** (1 / (1 + 2 * theta)) * l[:, t] ** ((alpha - 1 + (lambda_ + gamma1 / ksi - (1 - mu)) * theta) / (1 + 2 * theta))

        # Normalize wages relative to Princeton, NJ (Python index adjustment)
        w[:, t] = w[:, t] / w[3198, t]  # Adjust index as necessary for your data

    # Handle NaN values
    realgdp[np.isnan(realgdp)] = 0
    tau[np.isnan(tau)] = 0
    phi[np.isnan(phi)] = 0
    w[np.isnan(w)] = 0
    u[np.isnan(u)] = 0
    l[np.isnan(l)] = 0

    return l, u, w, tau, phi, realgdp
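
The inner loop of `backward` is a fixed-point iteration with an iteration cap and a guard against non-finite errors. Both cases are handled there by silently setting `error = 0`. The sketch below shows the same pattern in isolation, but returns a flag so the caller can tell convergence from a forced stop. The helper `fixed_point` and the toy map are assumptions for illustration and do not reproduce the model equations.

```python
import numpy as np

def fixed_point(step, x0, tol=1e-10, max_it=2000):
    """Iterate x <- step(x) until the squared change drops below tol or max_it is reached."""
    x = np.asarray(x0, dtype=float)
    for it in range(1, max_it + 1):
        x_new = step(x)
        error = float(np.sum((x_new - x) ** 2))
        x = x_new
        if not np.isfinite(error):
            return x, it, False        # diverged: report it instead of hiding it
        if error < tol:
            return x, it, True
    return x, max_it, False            # hit the iteration cap

rng = np.random.default_rng(0)
A = rng.random((50, 50)) + 0.1                 # strictly positive kernel, illustrative only
step = lambda x: (A @ np.sqrt(x)) ** 0.5       # combined exponent 0.25, so the map contracts
x_star, iters, converged = fixed_point(step, np.ones(50))
print(iters, converged)
```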


In [8]:
import numpy as np
import matplotlib.pyplot as plt
# from maps import maps
# import init
from pathlib import Path

def plots(H, realgdp_w, u_w, u2_w, prod_w, l, u, tau, realgdp):
    # alpha and theta are globals set by initialize(); the `init` module is not imported here
    global m2, earth_indices, tail_bands, alpha, theta, Omega

    T = int(len(realgdp_w))

    # Calculate world productivity and real GDP, correlations
    prworld = prod_w
    rgdpworld = realgdp_w
    uworld = u_w
    u2world = u2_w
    prgrowth = np.zeros(T)
    rgdpgrowth = np.zeros(T)
    ugrowth = np.zeros(T)
    u2growth = np.zeros(T)
    corr_rgdppop = np.zeros((T, 2, 2))
    corr_prpop = np.zeros((T, 2, 2))
    corr_prrgdp = np.zeros((T, 2, 2))

    for t in range(T):
        if t > 0:
            prgrowth[t] = prworld[t] / prworld[t - 1]
            rgdpgrowth[t] = rgdpworld[t] / rgdpworld[t - 1]
            ugrowth[t] = uworld[t] / uworld[t - 1]
            u2growth[t] = u2world[t] / u2world[t - 1]

        rgdp_vector = realgdp[:, t]
        pop_vector = l[:, t]
        pr_vector = (tau[:, t] * l[:, t] ** alpha) ** (1 / theta)

        rgdp_vector = rgdp_vector[H > 0]
        pop_vector = pop_vector[H > 0]
        pr_vector = pr_vector[H > 0]

        corr_rgdppop[t, :, :] = np.corrcoef(np.log(rgdp_vector), np.log(pop_vector))
        corr_prpop[t, :, :] = np.corrcoef(np.log(pr_vector), np.log(pop_vector))
        corr_prrgdp[t, :, :] = np.corrcoef(np.log(pr_vector), np.log(rgdp_vector))

    # Time series plots
    fig, axs = plt.subplots(2, 3, figsize=(15, 10))

    axs[0, 0].plot(range(1, T), prgrowth[1:T])
    axs[0, 0].set_title('Growth rate of productivity')
    axs[0, 0].set_xlabel('Time')

    axs[0, 1].plot(range(1, T), rgdpgrowth[1:T])
    axs[0, 1].set_title('Growth rate of real GDP')
    axs[0, 1].set_xlabel('Time')

    axs[0, 2].plot(range(1, T), ugrowth[1:T])
    axs[0, 2].set_title('Growth rate of utility (u)')
    axs[0, 2].set_xlabel('Time')

    axs[1, 0].plot(range(T), np.log(prworld[:T]))
    axs[1, 0].set_title('Ln world average productivity')
    axs[1, 0].set_xlabel('Time')

    axs[1, 1].plot(range(T), np.log(rgdpworld[:T]))
    axs[1, 1].set_title('Ln world average real GDP')
    axs[1, 1].set_xlabel('Time')

    axs[1, 2].plot(range(T), np.log(uworld[:T]))
    axs[1, 2].set_title('Ln world utility (u)')
    axs[1, 2].set_xlabel('Time')

    plt.tight_layout()
    Path('Output').mkdir(parents=True, exist_ok=True)
    plt.savefig('Output/world_aggregates.png')
    plt.close(fig)

    # Additional Time series plots
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))

    axs[0, 0].plot(range(1, T), ugrowth[1:T])
    axs[0, 0].set_title('Growth rate of utility (u)')
    axs[0, 0].set_xlabel('Time')

    axs[0, 1].plot(range(2, T), u2growth[2:T])
    axs[0, 1].set_title('Growth rate of E(u*epsilon)')
    axs[0, 1].set_xlabel('Time')

    axs[1, 0].plot(range(T), np.log(uworld[:T]))
    axs[1, 0].set_title('Ln world utility (u)')
    axs[1, 0].set_xlabel('Time')

    axs[1, 1].plot(range(T), np.log(u2world[:T]))
    axs[1, 1].set_title('Ln E(u*epsilon)')
    axs[1, 1].set_xlabel('Time')

    plt.tight_layout()
    Path('Output').mkdir(parents=True, exist_ok=True)
    plt.savefig('Output/world_utility.png')
    plt.close(fig)

    # Correlation plots
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    axs[0].plot(corr_rgdppop[:, 0, 1])
    axs[0].set_title('Corr (log real GDP per capita, log population density)')
    axs[0].set_xlabel('Time')

    axs[1].plot(corr_prpop[:, 0, 1])
    axs[1].set_title('Corr (log productivity, log population density)')
    axs[1].set_xlabel('Time')

    axs[2].plot(corr_prrgdp[:, 0, 1])
    axs[2].set_title('Corr (log productivity, log real GDP per capita)')
    axs[2].set_xlabel('Time')

    plt.tight_layout()
    Path('Output').mkdir(parents=True, exist_ok=True)
    plt.savefig('Output/correlations.png')
    plt.close(fig)

    # Cell-level maps
    if T >= 1:
        maps(l[:, 0], u[:, 0], (tau[:, 0] * l[:, 0] ** alpha) ** (1 / theta), realgdp[:, 0], 1, earth_indices)
    if T >= 200:
        maps(l[:, 199], u[:, 199], (tau[:, 199] * l[:, 199] ** alpha) ** (1 / theta), realgdp[:, 199], 200, earth_indices)
    if T >= 600:
        maps(l[:, 599], u[:, 599], (tau[:, 599] * l[:, 599] ** alpha) ** (1 / theta), realgdp[:, 599], 600, earth_indices)
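
The correlation plots read element `[0, 1]` of each stored matrix because `np.corrcoef` on two vectors returns a 2 x 2 matrix with ones on the diagonal and the cross-correlation off the diagonal. A small sketch on synthetic data (not model output):

```python
import numpy as np

rng = np.random.default_rng(0)
log_pop = rng.normal(size=500)
log_gdp = 0.5 * log_pop + rng.normal(scale=0.1, size=500)   # synthetic, positively related

C = np.corrcoef(log_gdp, log_pop)   # 2x2 symmetric matrix
r = C[0, 1]                         # the entry that plots() stores as corr_rgdppop[t, 0, 1]
print(C.shape, r)
```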


In [9]:
import numpy as np
import scipy.io as sio
from scipy.stats import pearsonr
# from model import model

def accumarray(indices, values, size):
    # A replacement for MATLAB's accumarray
    result = np.zeros(size)
    for idx, value in zip(indices, values):
        j = int(idx) - 1
        if 0 <= j < size:
            result[j] += value
    return result

def results(H, T, vars):
    # Global variables
    global a, a_norm, m2, C_vect, pop0, pop5, pop5_fertadj, beta, n, H0, earth_indices, alpha, theta, Omega
    H0, a, a_norm, m2, C_vect, _, pop0, pop5, pop5_fertadj, _, _, _, _, n, earth_indices, _, _, _, beta, _, _, alpha, theta, Omega = vars

    # Initialize output arrays
    realgdp_w = np.zeros(T)
    u_w = np.zeros(T)
    u2_w = np.zeros(T)
    prod_w = np.zeros(T)
    phi_w = np.zeros(T)
    PDV_u_w = 0
    PDV_u2_w = 0
    PDV_realgdp_w = 0

    # Simulate the model
    l, w, u, tau, phi, realgdp = model(H, T, vars)

    # Calculate correlations - Cell Level
    print('CORRELATIONS - CELL LEVEL')
    corr_pop5 = pearsonr(pop5[earth_indices], H * l[:, 4])
    corr_log_pop5 = pearsonr(np.log(pop5[earth_indices]), np.log(H * l[:, 4]))
    corr_pop5_diff = pearsonr(pop5[earth_indices] - pop0[earth_indices], H * l[:, 4] - pop0[earth_indices])
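    corr_log_pop5_diff = pearsonr(np.log(pop5[earth_indices]) - np.log(pop0[earth_indices]), np.log(H * l[:, 4]) - np.log(pop0[earth_indices]))
    # pearsonr returns (coefficient, p-value); report the coefficients under the header
    print(corr_pop5[0], corr_log_pop5[0], corr_pop5_diff[0], corr_log_pop5_diff[0], sep='\n')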

    print('CORRELATIONS - COUNTRY LEVEL')
    pop5_ctry_d = accumarray(C_vect, pop5[earth_indices], 168)
    pop5_ctry_m = accumarray(C_vect, H * l[:, 4], 168)
    pop0_ctry = accumarray(C_vect, pop0[earth_indices], 168)
    corr_pop5_ctry_d = pearsonr(pop5_ctry_d, pop5_ctry_m)
    corr_log_pop5_ctry_d = pearsonr(np.log(pop5_ctry_d), np.log(pop5_ctry_m))
    corr_pop5_ctry_diff = pearsonr(pop5_ctry_d - pop0_ctry, pop5_ctry_m - pop0_ctry)
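    corr_log_pop5_ctry_diff = pearsonr(np.log(pop5_ctry_d) - np.log(pop0_ctry), np.log(pop5_ctry_m) - np.log(pop0_ctry))
    # Report the country-level coefficients under the header
    print(corr_pop5_ctry_d[0], corr_log_pop5_ctry_d[0], corr_pop5_ctry_diff[0], corr_log_pop5_ctry_diff[0], sep='\n')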

    # Fertility-Adjusted Correlations - Cell Level
    print('CORRELATIONS (FERTILITY-ADJUSTED) - CELL LEVEL')
    corr_pop5_fertadj = pearsonr(pop5_fertadj[earth_indices], H * l[:, 4])
    corr_log_pop5_fertadj = pearsonr(np.log(pop5_fertadj[earth_indices]), np.log(H * l[:, 4]))
    corr_pop5_fertadj_diff = pearsonr(pop5_fertadj[earth_indices] - pop0[earth_indices], H * l[:, 4] - pop0[earth_indices])
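    corr_log_pop5_fertadj_diff = pearsonr(np.log(pop5_fertadj[earth_indices]) - np.log(pop0[earth_indices]), np.log(H * l[:, 4]) - np.log(pop0[earth_indices]))
    # Report the fertility-adjusted cell-level coefficients under the header
    print(corr_pop5_fertadj[0], corr_log_pop5_fertadj[0], corr_pop5_fertadj_diff[0], corr_log_pop5_fertadj_diff[0], sep='\n')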

    print('CORRELATIONS (FERTILITY-ADJUSTED) - COUNTRY LEVEL')
    pop5_fertadj_ctry = accumarray(C_vect, pop5_fertadj[earth_indices], 168)
    corr_pop5_fertadj_ctry = pearsonr(pop5_fertadj_ctry, pop5_ctry_m)
    corr_log_pop5_fertadj_ctry = pearsonr(np.log(pop5_fertadj_ctry), np.log(pop5_ctry_m))
    corr_pop5_fertadj_ctry_diff = pearsonr(pop5_fertadj_ctry - pop0_ctry, pop5_ctry_m - pop0_ctry)
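    corr_log_pop5_fertadj_ctry_diff = pearsonr(np.log(pop5_fertadj_ctry) - np.log(pop0_ctry), np.log(pop5_ctry_m) - np.log(pop0_ctry))
    # Report the fertility-adjusted country-level coefficients under the header
    print(corr_pop5_fertadj_ctry[0], corr_log_pop5_fertadj_ctry[0], corr_pop5_fertadj_ctry_diff[0], corr_log_pop5_fertadj_ctry_diff[0], sep='\n')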

    # Compute world aggregates
    u2 = np.zeros((n, T))
    m1 = np.power(m2, -1)
    for t in range(T):
        u2[:, t] = np.sum(np.power(u[:, t], 1 / Omega) * np.power(m2, -1 / Omega)) ** Omega * m2
        u_w[t] = np.sum(u[:, t] * H * l[:, t])
        u2_w[t] = np.sum(u2[:, t] * H * l[:, t])
        realgdp_w[t] = np.sum(realgdp[:, t] * H * l[:, t])
        prod_w[t] = np.sum(np.power(tau[:, t] * H * np.power(l[:, t], 1 + alpha), 1 / theta))
        phi_w[t] = np.sum(phi[:, t] * H * l[:, t])
        PDV_u_w += beta ** t * u_w[t]
        PDV_u2_w += beta ** t * u2_w[t]
        PDV_realgdp_w += beta ** t * realgdp_w[t]

    if beta * u_w[-1] / u_w[-2] < 1:
        PDV_u_w += (beta ** T * u_w[-1] ** 2 / u_w[-2]) / (1 - beta * u_w[-1] / u_w[-2])
    else:
        PDV_u_w = np.nan

    if beta * u2_w[-1] / u2_w[-2] < 1:
        PDV_u2_w += (beta ** T * u2_w[-1] ** 2 / u2_w[-2]) / (1 - beta * u2_w[-1] / u2_w[-2])
    else:
        PDV_u2_w = np.nan

    if beta * realgdp_w[-1] / realgdp_w[-2] < 1:
        PDV_realgdp_w += (beta ** T * realgdp_w[-1] ** 2 / realgdp_w[-2]) / (1 - beta * realgdp_w[-1] / realgdp_w[-2])
    else:
        PDV_realgdp_w = np.nan

    # Share of migrants - Cell Level
    # Sum of positive population changes in each period, relative to initial world population
    migr_cell = np.zeros(T)
    pop_prev = pop0[earth_indices]
    for t in range(T):
        pop_cur = H * l[:, t]
        diff = pop_cur - pop_prev
        migr_cell[t] = np.sum(diff[diff > 0]) / np.sum(pop0)
        pop_prev = pop_cur

    # Share of migrants - Country Level
    migr_ctry = np.zeros(T)
    pop_ctry_m = np.zeros((168, T))
    pop_prev = pop0_ctry
    for t in range(T):
        pop_ctry_m[:, t] = accumarray(C_vect, H * l[:, t], 168)
        diff = pop_ctry_m[:, t] - pop_prev
        migr_ctry[t] = np.sum(diff[diff > 0]) / np.sum(pop0_ctry)
        pop_prev = pop_ctry_m[:, t]

    return realgdp_w, u_w, u2_w, prod_w, phi_w, PDV_u_w, PDV_u2_w, PDV_realgdp_w, migr_cell, migr_ctry, l, u, u2, tau, realgdp
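
The tail terms that `results` adds to `PDV_u_w`, `PDV_u2_w` and `PDV_realgdp_w` extend each series beyond period `T` by assuming it keeps growing at its last observed gross rate, `g = x[-1] / x[-2]`. The remainder is then a geometric series, `beta**T * x[-1] * g / (1 - beta * g)`, which converges only when `beta * g < 1`; otherwise the code sets the value to NaN. A minimal standalone sketch of that calculation follows. The function name `pdv_with_tail` and the toy inputs are hypothetical and only illustrate the formula.

```python
import numpy as np

def pdv_with_tail(series, beta):
    """Present discounted value of a series plus a geometric continuation.

    Hypothetical helper mirroring the tail correction in results():
    after the last period the series is assumed to grow at the gross
    rate g = series[-1] / series[-2] forever.
    """
    series = np.asarray(series, dtype=float)
    T = series.size
    growth = series[-1] / series[-2]
    # Discounted sum over the simulated periods t = 0, ..., T-1
    pdv = np.sum(beta ** np.arange(T) * series)
    if beta * growth >= 1:
        # The continuation does not converge
        return np.nan
    # sum_{t >= T} beta**t * series[-1] * growth**(t - T + 1)
    tail = beta ** T * series[-1] * growth / (1 - beta * growth)
    return pdv + tail

# A constant series of ones discounted at beta = 0.5
constant_pdv = pdv_with_tail([1.0, 1.0, 1.0], beta=0.5)
# A series that triples each period cannot be discounted at beta = 0.5
explosive_pdv = pdv_with_tail([1.0, 3.0], beta=0.5)
```

For the constant series the result equals `1 / (1 - beta) = 2.0`, the value of the infinite discounted sum, so the truncation at `T` leaves no error in that case. The second call returns NaN because `beta * g = 1.5`.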


In [None]:
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from pathlib import Path
import jax.numpy as jnp  # provides the dtype used for JAX_DTYPE below

# from init import initialize
# from results import results
# from backward import backward
# from plots import plots
import time
_program_start_time = time.perf_counter()

global H0, a, a_norm, m2, C_vect, tau0, pop0, pop5, pop5_fertadj, popminus5, popminus10, ubar, trmult_reduced, n, earth_indices, indicator_sea, subs, subs_vect, beta, tail_bands, ind_islands, alpha, theta, Omega

# Initialize model
vars = initialize(1)
H0, a, a_norm, m2, C_vect, tau0, pop0, pop5, pop5_fertadj, popminus5, popminus10, ubar, trmult_reduced, n, earth_indices, indicator_sea, subs, subs_vect, beta, tail_bands, ind_islands, alpha, theta, Omega = vars
JAX_DTYPE = jnp.float32  # or jnp.bfloat16

_trmult_np = np.asarray(trmult_reduced)

trmult_reduced = TPUDenseMat(_trmult_np, jax_dtype=JAX_DTYPE)

print("Wrapped init.trmult_reduced for TPU matvec:","shape=", trmult_reduced.shape, "dtype=", JAX_DTYPE)

# Distribution of land for simulation
H0_arr = np.asarray(H0).reshape(-1)
H = H0_arr[earth_indices]

# Number of periods
nb_per = 600

# Run the model and obtain summary statistics
results_data = results(H, nb_per, vars)
realgdp_w, u_w, u2_w, prod_w, phi_w, PDV_u_w, PDV_u2_w, PDV_realgdp_w, migr_cell, migr_ctry, l, u, u2, tau, realgdp = results_data

# Plot time series and maps, and save them
plots(H, realgdp_w, u_w, u2_w, prod_w, l, u, tau, realgdp)

# Number of periods for backward simulation
nb_back = 180

# Run model backwards
l_b, u_b, w_b, tau_b, phi_b, realgdp_b = backward(H, nb_back, vars)

# Calculate correlations
def calculate_correlation(x, y):
    return np.corrcoef(x, y)[0, 1]

print('CORRELATIONS WITH 1995 DATA - CELL LEVEL')
print(calculate_correlation(popminus5[earth_indices], H0_arr[earth_indices] * l_b[:, 4]))
print(calculate_correlation(np.log(popminus5[earth_indices]), np.log(H0_arr[earth_indices] * l_b[:, 4])))
print(calculate_correlation(pop0[earth_indices] - popminus5[earth_indices], pop0[earth_indices] - H0_arr[earth_indices] * l_b[:, 4]))
print(calculate_correlation(np.log(pop0[earth_indices]) - np.log(popminus5[earth_indices]), np.log(pop0[earth_indices]) - np.log(H0_arr[earth_indices] * l_b[:, 4])))

print('CORRELATIONS WITH 1990 DATA - CELL LEVEL')
print(calculate_correlation(popminus10[earth_indices], H0_arr[earth_indices] * l_b[:, 9]))
print(calculate_correlation(np.log(popminus10[earth_indices]), np.log(H0_arr[earth_indices] * l_b[:, 9])))
print(calculate_correlation(pop0[earth_indices] - popminus10[earth_indices], pop0[earth_indices] - H0_arr[earth_indices] * l_b[:, 9]))
print(calculate_correlation(np.log(pop0[earth_indices]) - np.log(popminus10[earth_indices]), np.log(pop0[earth_indices]) - np.log(H0_arr[earth_indices] * l_b[:, 9])))

print('CORRELATIONS WITH 1995 DATA - COUNTRY LEVEL')
ctry_idx = C_vect.astype(int) - 1
popminus5_ctry_d = np.bincount(ctry_idx, weights=popminus5[earth_indices])
popminus5_ctry_m = np.bincount(ctry_idx, weights=H0_arr[earth_indices] * l_b[:, 4])
pop0_ctry = np.bincount(ctry_idx, weights=pop0[earth_indices])
print(calculate_correlation(popminus5_ctry_d, popminus5_ctry_m))
print(calculate_correlation(np.log(popminus5_ctry_d), np.log(popminus5_ctry_m)))
print(calculate_correlation(pop0_ctry - popminus5_ctry_d, pop0_ctry - popminus5_ctry_m))
print(calculate_correlation(np.log(pop0_ctry) - np.log(popminus5_ctry_d), np.log(pop0_ctry) - np.log(popminus5_ctry_m)))

print('CORRELATIONS WITH 1990 DATA - COUNTRY LEVEL')
popminus10_ctry_d = np.bincount(ctry_idx, weights=popminus10[earth_indices])
popminus10_ctry_m = np.bincount(ctry_idx, weights=H0_arr[earth_indices] * l_b[:, 9])
print(calculate_correlation(popminus10_ctry_d, popminus10_ctry_m))
print(calculate_correlation(np.log(popminus10_ctry_d), np.log(popminus10_ctry_m)))
print(calculate_correlation(pop0_ctry - popminus10_ctry_d, pop0_ctry - popminus10_ctry_m))
print(calculate_correlation(np.log(pop0_ctry) - np.log(popminus10_ctry_d), np.log(pop0_ctry) - np.log(popminus10_ctry_m)))

# Save all the output to disk, one .mat file per variable
Path('Output').mkdir(parents=True, exist_ok=True)
outputs = {
    'realgdp_w': realgdp_w, 'u_w': u_w, 'u2_w': u2_w, 'prod_w': prod_w, 'phi_w': phi_w,
    'PDV_u_w': PDV_u_w, 'PDV_u2_w': PDV_u2_w, 'PDV_realgdp_w': PDV_realgdp_w,
    'migr_cell': migr_cell, 'migr_ctry': migr_ctry,
    'l': l, 'u': u, 'realgdp': realgdp, 'tau': tau, 'l_b': l_b,
}
for name, value in outputs.items():
    sio.savemat(f'Output/{name}.mat', {name: value})

# Report total wall-clock time for the run
_program_total_time = (time.perf_counter() - _program_start_time) * 1e3
print(f"\nTOTAL RUNTIME: {_program_total_time:.9f} ms ({_program_total_time / 1000:.9f} s)")


Wrapped init.trmult_reduced for TPU matvec: shape= (17048, 17048) dtype= <class 'jax.numpy.float32'>
t=1
TOTAL IMPORTS TO WORLD GDP
4.0523494250654934e-05
t=2
t=3
t=4
t=5
t=6
t=7
t=8
t=9
t=10
t=11
t=12
t=13
t=14
t=15
t=16
t=17
t=18
t=19
t=20
t=21
t=22
t=23
t=24
t=25
t=26
t=27
t=28
t=29
t=30
t=31
t=32
t=33
t=34
t=35
t=36
t=37
t=38
t=39
t=40
t=41
t=42
t=43
t=44
t=45
t=46
t=47
t=48
t=49
t=50
t=51
t=52
t=53
t=54
t=55
t=56
t=57
t=58
t=59
t=60
t=61
t=62
t=63
t=64
t=65
t=66
t=67
t=68
t=69
t=70
t=71
t=72
t=73
t=74
t=75
t=76
t=77
t=78
t=79
t=80
t=81
t=82
t=83
t=84
t=85
t=86
t=87
t=88
t=89
t=90
t=91
t=92
t=93
t=94
t=95
t=96
t=97
t=98
t=99
t=100
t=101
t=102
t=103
t=104
t=105
t=106
t=107
t=108
t=109
t=110
t=111
t=112
t=113
t=114
t=115
t=116
t=117
t=118
t=119
t=120
t=121
t=122
t=123
t=124
t=125
t=126
t=127
t=128
t=129
t=130
t=131
t=132
t=133
t=134
t=135
t=136
t=137
t=138
t=139
t=140
t=141
t=142
t=143
t=144
t=145
t=146
t=147
t=148
t=149
t=150
t=151
t=152
t=153
t=154
t=155
t=156
t=157
t=158
t=159
t=1