<a href="https://colab.research.google.com/github/AntoineChapel/Trade_HW_1/blob/main/TTrade_hw1_AC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
import time

### Solver: Armington model

In [None]:
def Armington_solver(A: np.ndarray,
                     L: np.ndarray,
                     tau: np.ndarray,
                     sigma: float,
                     damp: float = 0.9,
                     verbose = True,
                     normalize = True,
                     tol: float = 1e-3,
                     max_iter: int = 100_000) -> np.ndarray:
  """
  Solves the Armington model for N countries by damped fixed-point iteration
  on the market-clearing condition Y_i = sum_j lambda_ij Y_j, with Y_i = w_i L_i.

  Trade shares (share of destination j's spending on origin i's good) are
      lambda_ij = A_i^(sigma-1) (w_i tau_ij)^(1-sigma)
                  / sum_k A_k^(sigma-1) (w_k tau_kj)^(1-sigma)

  Takes as input:
    * A (array-like) : productivities, N entries (reshaped to N x 1)
    * L (array-like) : labor endowments, N entries (reshaped to N x 1)
    * tau (array-like) : N x N matrix of bilateral trade costs; the entry on
           row i, column j, tau_{ij}, is the cost of shipping from origin i
           to destination j
    * sigma (float) : elasticity parameter (scalar, strictly positive)
    * damp (float in [0, 1) ) : weight on the previous iterate in the convex
                                combination step. With damp=0 (not
                                recommended), full updating.
    * verbose (bool) : print progress every 20 iterations and a final
           convergence message
    * normalize (bool) : enforce the numeraire w_1 = 1 after every update
    * tol (float) : stop once max_i |Y'_i - Y_i| < tol
    * max_iter (int) : maximum number of iterations

  Returns:
    np.ndarray of shape (N, 1) with the equilibrium wages.
  """

  tic = time.time()

  A = np.asarray(A, dtype=float).reshape(-1, 1)
  L = np.asarray(L, dtype=float).reshape(-1, 1)
  tau = np.asarray(tau, dtype=float)

  assert A.shape == L.shape, "The two arrays A and L should have the same size"

  N = A.shape[0]
  assert tau.ndim == 2 and tau.shape[0] == tau.shape[1], "The tau matrix should be square"
  assert tau.shape[0] == N, "The tau matrix should have dimension N x N"
  assert sigma > 0, "The elasticity parameter sigma should be strictly positive"
  assert 0 <= damp < 1, "The damping parameter should be in [0, 1)"

  epsilon = sigma - 1

  L_jnp = jnp.asarray(L)
  T_jnp = jnp.asarray(A)**epsilon  # (N, 1): one entry per ORIGIN country i
  # Loop-invariant part of the share numerator: T_i * tau_ij^(-epsilon).
  # T_jnp is a column, so it scales row i (the origin), not column j.
  trade_base = T_jnp * jnp.asarray(tau)**(-epsilon)

  norm = np.inf
  iter_count = 0

  w = jnp.ones((N, 1))

  while norm > tol and iter_count < max_iter:
    Y = w * L_jnp
    # numer[i, j] = T_i (w_i tau_ij)^(-epsilon)
    numer = trade_base * w**(-epsilon)
    # Divide by the column sums so that each destination's shares sum to 1.
    lambda_mat = numer / jnp.sum(numer, axis=0, keepdims=True)
    Y_prime = lambda_mat @ Y

    norm = jnp.max(jnp.abs(Y_prime - Y))
    Y = damp*Y + (1-damp)*Y_prime

    w = Y / L_jnp
    if normalize:
      w = w.at[0].set(1) #normalization enforcement: w_1 = 1
    iter_count += 1

    if verbose and iter_count % 20 == 0:
      print(f"Iteration {iter_count}, norm: {norm}")

  tac = time.time()

  # Decide on the residual itself rather than on iter_count, so that a run
  # converging exactly at max_iter, or one that produced NaNs, is reported correctly.
  if jnp.isnan(norm):
    print("Iteration diverged: NaN encountered")
  elif norm > tol:
    print("Maximum number of iterations reached")
  elif verbose:
    print(f"Convergence reached in {iter_count} iterations and {tac - tic} seconds.")

  return np.array(w)

In [None]:
### Example: a random 30-country economy

np.random.seed(123)
n_countries = 30
sigma = 5

# These draws consume the seeded global RNG stream in this order: A, then L, then tau.
A = np.random.normal(loc=1, scale=0.1, size=(n_countries, 1))              # productivities
L = np.random.normal(loc=5, scale=0.1, size=(n_countries, 1))              # labor endowments
tau = np.random.normal(loc=5, scale=0.4, size=(n_countries, n_countries))  # bilateral trade costs

w = Armington_solver(A, L, tau, sigma, verbose=True)

Iteration 20, norm: 0.06522607803344727
Iteration 40, norm: 0.053789615631103516
Iteration 60, norm: 0.044370174407958984
Iteration 80, norm: 0.036608219146728516
Iteration 100, norm: 0.030209064483642578
Iteration 120, norm: 0.024932384490966797
Iteration 140, norm: 0.020577430725097656
Iteration 160, norm: 0.016986370086669922
Iteration 180, norm: 0.014017581939697266
Iteration 200, norm: 0.011566638946533203
Iteration 220, norm: 0.00954437255859375
Iteration 240, norm: 0.00787353515625
Iteration 260, norm: 0.006495952606201172
Iteration 280, norm: 0.005356788635253906
Iteration 300, norm: 0.004414081573486328
Iteration 320, norm: 0.003635883331298828
Iteration 340, norm: 0.0029935836791992188
Iteration 360, norm: 0.002464771270751953
Iteration 380, norm: 0.002025604248046875
Iteration 400, norm: 0.0016608238220214844
Iteration 420, norm: 0.0013623237609863281
Iteration 440, norm: 0.0011148452758789062
Convergence reached in 451 iterations and 3.0502233505249023 seconds.


In [None]:
# Equilibrium wages displayed as a single 1 x N row
w.transpose()

array([[1.        , 1.0147899 , 0.98749226, 0.99407387, 1.0079896 ,
        1.0054353 , 0.98809654, 1.021088  , 1.0073427 , 1.0014774 ,
        1.0048668 , 1.0116875 , 0.9936751 , 0.98970354, 0.9882356 ,
        0.9853505 , 0.98072785, 0.9725412 , 0.9800803 , 1.009552  ,
        0.98060745, 0.9859122 , 0.97389203, 1.0425824 , 1.0244395 ,
        0.997455  , 1.0073949 , 0.9813022 , 0.9871493 , 1.0052215 ]],
      dtype=float32)

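### Unit Test: $\tau_{ij}=1 \ \forall i, j$

With frictionless trade and identical productivities ($A_i = 1$), trade shares do not depend on the destination, and market clearing yields the closed form $w_i \propto L_i^{-1/\sigma}$. Setting $L_1 = 1$ makes this closed form consistent with the normalization $w_1 = 1$.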

In [None]:
np.random.seed(123)
n_countries = 100
sigma_test = 5

A_test = np.ones((n_countries, 1))  # identical productivities
L_test = np.square(np.random.normal(1, 0.1, size=(n_countries, 1)))  # labor endowments
L_test[0] = 1.0  # L_1 = 1, so L**(-1/sigma) already satisfies the normalization w_1 = 1

tau_test = np.ones((n_countries, n_countries))  # frictionless trade: tau_ij = 1

In [None]:
def unit_test(A, L, sigma, tol=1e-2):
  """
  Checks Armington_solver against the closed-form equilibrium under free trade.

  With tau_ij = 1 for all i, j, trade shares do not depend on the destination,
  and market clearing gives w_i^sigma * L_i proportional to A_i^(sigma-1), i.e.
      w_i  proportional to  (A_i^(sigma-1) / L_i)^(1/sigma),
  which reduces to L_i^(-1/sigma) when A_i = 1. The closed form is rescaled so
  that w_1 = 1, matching the solver's normalization; L_1 therefore need not be 1.

  Returns True if the largest absolute deviation from the closed form is < tol.
  """
  # Force column vectors: the solver returns (N, 1), and a 1-D closed form would
  # broadcast to an N x N matrix of cross-country differences.
  A = np.asarray(A, dtype=float).reshape(-1, 1)
  L = np.asarray(L, dtype=float).reshape(-1, 1)
  n_countries = A.shape[0]

  closed_form = (A**(sigma - 1) / L)**(1 / sigma)
  closed_form = closed_form / closed_form[0]  # numeraire: w_1 = 1

  tau = np.ones((n_countries, n_countries))
  equilibrium_w = Armington_solver(A, L, tau, sigma, verbose=True, damp=0.9)

  return bool(np.max(np.abs(equilibrium_w - closed_form)) < tol)

In [None]:
# Compare the numerical equilibrium with the closed form under frictionless trade
unit_test(A=A_test, L=L_test, sigma=sigma_test)

Iteration 20, norm: 0.03693592548370361
Iteration 40, norm: 0.03335893154144287
Iteration 60, norm: 0.0301363468170166
Iteration 80, norm: 0.027231335639953613
Iteration 100, norm: 0.024611234664916992
Iteration 120, norm: 0.022247314453125
Iteration 140, norm: 0.020113468170166016
Iteration 160, norm: 0.01818668842315674
Iteration 180, norm: 0.01644730567932129
Iteration 200, norm: 0.014876246452331543
Iteration 220, norm: 0.01345670223236084
Iteration 240, norm: 0.012173295021057129
Iteration 260, norm: 0.011012911796569824
Iteration 280, norm: 0.009963512420654297
Iteration 300, norm: 0.00901496410369873
Iteration 320, norm: 0.00815737247467041
Iteration 340, norm: 0.007381796836853027
Iteration 360, norm: 0.006679892539978027
Iteration 380, norm: 0.006044745445251465
Iteration 400, norm: 0.0054700374603271484
Iteration 420, norm: 0.0049495697021484375
Iteration 440, norm: 0.004479169845581055
Iteration 460, norm: 0.0040531158447265625
Iteration 480, norm: 0.0036679506301879883
Iter

Array(True, dtype=bool)

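### Gains from Trade

For each country $j$, gains from trade are measured as real income under trade, $w_j L_j / P_j$ with $P_j$ the CES price index, divided by real income in autarky.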

In [None]:
np.random.seed(123)
n_countries = 10
sigma_test2 = 10

A_test2 = np.ones((n_countries, 1))  # identical productivities
L_test2 = np.square(np.random.normal(2, 0.1, size=(n_countries, 1)))  # labor endowments
L_test2[0] = 1.0  # normalization

tau_test2 = np.ones((n_countries, n_countries))  # frictionless trade: tau_ij = 1

In [None]:
def GFT(A, L, tau, sigma, damp = 0.9):
  """
  Gains from trade: real income under trade relative to autarky, per country.

  Under trade, with prices p_ij = w_i tau_ij / A_i, the CES price index of
  destination j is
      P_j = ( sum_i (w_i tau_ij / A_i)^(1-sigma) )^(1/(1-sigma)),
  and real income is w_j L_j / P_j. In autarky country j buys only its own good,
  so P_j = w_j tau_jj / A_j and real income is A_j L_j / tau_jj
  (equal to A_j L_j when domestic trade is costless, tau_jj = 1).

  Returns an (N, 1) array of ratios real_income_trade / real_income_autarky.
  """
  A = np.asarray(A, dtype=float).reshape(-1, 1)
  L = np.asarray(L, dtype=float).reshape(-1, 1)
  tau = np.asarray(tau, dtype=float)

  w_FT = Armington_solver(A, L, tau, sigma, verbose=False, normalize=True, damp=damp).reshape(-1, 1)

  # Aggregate over origins i (rows) for each destination j (columns).
  P_FT = (np.sum((w_FT*tau/A)**(1-sigma), axis=0)**(1/(1-sigma))).reshape(-1, 1)

  real_income_FT = (w_FT*L)/P_FT
  # Divide by tau_jj: the autarky price of the home good includes domestic trade costs.
  real_income_autarky = A*L/np.diag(tau).reshape(-1, 1)

  return real_income_FT/real_income_autarky

In [None]:
# Ratio of free-trade to autarky real income, one row per country
GFT(A=A_test2, L=L_test2, tau=tau_test2, sigma=sigma_test2)

Convergence reached in 212 iterations and 1.3409764766693115 seconds.


array([[1.46922958],
       [1.2667754 ],
       [1.27557532],
       [1.29935464],
       [1.28669717],
       [1.25902727],
       [1.31268858],
       [1.28472283],
       [1.26355909],
       [1.29054959]])