In [1]:
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import timeit
from models import *
from equations import *
from solvers import *

# 1. Benchmark Model

## 1.1 baseline

In [15]:
data_file_name = "data/data_2017.npz"
params = ModelParams.load_from_npz(data_file_name)
params.gamma = np.swapaxes(params.gamma, 1, 2)
model0 = Model(params)
shocks = model.shocks

In [16]:
solver = ModelSolver(model0)
solver.solve()

Iter 1: w_min=9.848e-01, w_max=1.007e+00, X_min=1.438e+04, X_max=9.825e+09, Δw=1.524e-02, ΔP=2.220e-16
Iter 2: w_min=9.725e-01, w_max=1.013e+00, X_min=1.435e+04, X_max=9.872e+09, Δw=1.228e-02, ΔP=1.217e-02
Iter 3: w_min=9.641e-01, w_max=1.019e+00, X_min=1.433e+04, X_max=9.906e+09, Δw=8.414e-03, ΔP=1.005e-02
Iter 4: w_min=9.591e-01, w_max=1.023e+00, X_min=1.432e+04, X_max=9.930e+09, Δw=5.057e-03, ΔP=8.079e-03
Iter 5: w_min=9.568e-01, w_max=1.026e+00, X_min=1.430e+04, X_max=9.949e+09, Δw=3.252e-03, ΔP=6.150e-03
Iter 6: w_min=9.560e-01, w_max=1.029e+00, X_min=1.429e+04, X_max=9.963e+09, Δw=2.525e-03, ΔP=4.335e-03
Iter 7: w_min=9.560e-01, w_max=1.031e+00, X_min=1.429e+04, X_max=9.975e+09, Δw=1.956e-03, ΔP=2.975e-03
Iter 8: w_min=9.562e-01, w_max=1.032e+00, X_min=1.428e+04, X_max=9.983e+09, Δw=1.519e-03, ΔP=2.063e-03
Iter 9: w_min=9.564e-01, w_max=1.033e+00, X_min=1.428e+04, X_max=9.990e+09, Δw=1.186e-03, ΔP=1.641e-03
Iter 10: w_min=9.565e-01, w_max=1.034e+00, X_min=1.429e+04, X_max=9.996e+

## 1.2 counterfactual: US imposes 20% tariff rate to ALL countries ALL sectors

In [17]:
target_importers = ['USA']  # one or multiple
country_list = params.country_list.tolist()
sector_list = params.sector_list.tolist()
tilde_tau = params.tilde_tau.copy()
target_importers_index = [country_list.index(importer) for importer in target_importers]  

tilde_tau_1 = tilde_tau.copy()

for importer_index in target_importers_index:
    for exporter_index in range(len(country_list)):
        for sector_index in range(len(sector_list)):
            if importer_index != exporter_index:
                tilde_tau_1[importer_index, exporter_index, sector_index] += 0.2

data_file_name = "data/data_2017.npz"
params = ModelParams.load_from_npz(data_file_name)
params.gamma = np.swapaxes(params.gamma, 1, 2)
params.tilde_tau = tilde_tau_1
model1 = Model(params)

In [18]:
solver = ModelSolver(model1)
solver.solve()

Iter 1: w_min=9.835e-01, w_max=1.009e+00, X_min=1.435e+04, X_max=9.938e+09, Δw=1.654e-02, ΔP=2.220e-16
Iter 2: w_min=9.713e-01, w_max=1.018e+00, X_min=1.432e+04, X_max=9.999e+09, Δw=1.268e-02, ΔP=1.317e-02
Iter 3: w_min=9.629e-01, w_max=1.024e+00, X_min=1.429e+04, X_max=1.004e+10, Δw=8.885e-03, ΔP=1.074e-02
Iter 4: w_min=9.575e-01, w_max=1.030e+00, X_min=1.427e+04, X_max=1.008e+10, Δw=5.433e-03, ΔP=8.418e-03
Iter 5: w_min=9.547e-01, w_max=1.034e+00, X_min=1.425e+04, X_max=1.010e+10, Δw=4.224e-03, ΔP=6.541e-03
Iter 6: w_min=9.535e-01, w_max=1.037e+00, X_min=1.424e+04, X_max=1.012e+10, Δw=3.325e-03, ΔP=4.700e-03
Iter 7: w_min=9.532e-01, w_max=1.040e+00, X_min=1.424e+04, X_max=1.014e+10, Δw=2.624e-03, ΔP=3.453e-03
Iter 8: w_min=9.532e-01, w_max=1.042e+00, X_min=1.424e+04, X_max=1.015e+10, Δw=2.081e-03, ΔP=2.718e-03
Iter 9: w_min=9.533e-01, w_max=1.044e+00, X_min=1.424e+04, X_max=1.016e+10, Δw=1.660e-03, ΔP=2.146e-03
Iter 10: w_min=9.532e-01, w_max=1.045e+00, X_min=1.424e+04, X_max=1.017e+

## 1.3 counterfactual: US imposes 20% tariff rate to ALL countries steel sector

In [20]:
target_importers = ['USA']  # one or multiple
target_sectors = ["Metal Products"]
country_list = params.country_list.tolist()
sector_list = params.sector_list.tolist()
tilde_tau = params.tilde_tau.copy()
target_importers_index = [country_list.index(importer) for importer in target_importers]  

tilde_tau_1 = tilde_tau.copy()

for importer_index in target_importers_index:
    for exporter_index in range(len(country_list)):
        for sector_index in range(len(target_sectors)):
            if importer_index != exporter_index:
                tilde_tau_1[importer_index, exporter_index, sector_index] += 0.2

data_file_name = "data/data_2017.npz"
params = ModelParams.load_from_npz(data_file_name)
params.gamma = np.swapaxes(params.gamma, 1, 2)
params.tilde_tau = tilde_tau_1
model2 = Model(params)

In [21]:
solver = ModelSolver(model2)
solver.solve()

Iter 1: w_min=9.834e-01, w_max=1.009e+00, X_min=1.435e+04, X_max=9.940e+09, Δw=1.661e-02, ΔP=2.220e-16
Iter 2: w_min=9.712e-01, w_max=1.018e+00, X_min=1.432e+04, X_max=1.000e+10, Δw=1.269e-02, ΔP=1.323e-02
Iter 3: w_min=9.628e-01, w_max=1.024e+00, X_min=1.429e+04, X_max=1.004e+10, Δw=8.897e-03, ΔP=1.079e-02
Iter 4: w_min=9.574e-01, w_max=1.030e+00, X_min=1.427e+04, X_max=1.008e+10, Δw=5.438e-03, ΔP=8.426e-03
Iter 5: w_min=9.546e-01, w_max=1.034e+00, X_min=1.425e+04, X_max=1.010e+10, Δw=4.236e-03, ΔP=6.550e-03
Iter 6: w_min=9.535e-01, w_max=1.037e+00, X_min=1.424e+04, X_max=1.012e+10, Δw=3.335e-03, ΔP=4.708e-03
Iter 7: w_min=9.532e-01, w_max=1.040e+00, X_min=1.424e+04, X_max=1.014e+10, Δw=2.632e-03, ΔP=3.463e-03
Iter 8: w_min=9.532e-01, w_max=1.042e+00, X_min=1.424e+04, X_max=1.015e+10, Δw=2.087e-03, ΔP=2.726e-03
Iter 9: w_min=9.532e-01, w_max=1.044e+00, X_min=1.424e+04, X_max=1.016e+10, Δw=1.665e-03, ΔP=2.153e-03
Iter 10: w_min=9.532e-01, w_max=1.045e+00, X_min=1.424e+04, X_max=1.017e+

# 2. Model with half trade elasticity

## 2.1 baseline

In [22]:
data_file_name = "data/data_2017.npz"
params = ModelParams.load_from_npz(data_file_name)
params.gamma = np.swapaxes(params.gamma, 1, 2)
params.theta /= 2
model3 = Model(params)
shocks = model2.shocks

In [23]:
solver = ModelSolver(model3)
solver.solve()

Iter 1: w_min=9.848e-01, w_max=1.007e+00, X_min=1.438e+04, X_max=9.825e+09, Δw=1.524e-02, ΔP=3.331e-16
Iter 2: w_min=9.714e-01, w_max=1.014e+00, X_min=1.435e+04, X_max=9.880e+09, Δw=1.337e-02, ΔP=1.217e-02
Iter 3: w_min=9.604e-01, w_max=1.020e+00, X_min=1.433e+04, X_max=9.928e+09, Δw=1.101e-02, ΔP=1.092e-02
Iter 4: w_min=9.517e-01, w_max=1.025e+00, X_min=1.431e+04, X_max=9.970e+09, Δw=8.697e-03, ΔP=9.352e-03
Iter 5: w_min=9.450e-01, w_max=1.030e+00, X_min=1.429e+04, X_max=1.001e+10, Δw=6.674e-03, ΔP=8.370e-03
Iter 6: w_min=9.400e-01, w_max=1.034e+00, X_min=1.427e+04, X_max=1.004e+10, Δw=5.262e-03, ΔP=7.041e-03
Iter 7: w_min=9.362e-01, w_max=1.038e+00, X_min=1.426e+04, X_max=1.007e+10, Δw=4.275e-03, ΔP=5.695e-03
Iter 8: w_min=9.335e-01, w_max=1.041e+00, X_min=1.424e+04, X_max=1.009e+10, Δw=3.471e-03, ΔP=4.613e-03
Iter 9: w_min=9.314e-01, w_max=1.044e+00, X_min=1.423e+04, X_max=1.012e+10, Δw=2.931e-03, ΔP=3.831e-03
Iter 10: w_min=9.298e-01, w_max=1.047e+00, X_min=1.422e+04, X_max=1.014e+

## 2.2 counterfactual: US imposes 20% tariff rate to ALL countries ALL sectors

In [24]:
target_importers = ['USA']  # one or multiple
country_list = params.country_list.tolist()
sector_list = params.sector_list.tolist()
tilde_tau = params.tilde_tau.copy()
target_importers_index = [country_list.index(importer) for importer in target_importers]  

tilde_tau_1 = tilde_tau.copy()

for importer_index in target_importers_index:
    for exporter_index in range(len(country_list)):
        for sector_index in range(len(sector_list)):
            if importer_index != exporter_index:
                tilde_tau_1[importer_index, exporter_index, sector_index] += 0.2

data_file_name = "data/data_2017.npz"
params = ModelParams.load_from_npz(data_file_name)
params.gamma = np.swapaxes(params.gamma, 1, 2)
params.theta /= 2
params.tilde_tau = tilde_tau_1
model4 = Model(params)

In [25]:
solver = ModelSolver(model4)
solver.solve()

Iter 1: w_min=9.835e-01, w_max=1.009e+00, X_min=1.435e+04, X_max=9.938e+09, Δw=1.654e-02, ΔP=3.331e-16
Iter 2: w_min=9.696e-01, w_max=1.018e+00, X_min=1.432e+04, X_max=1.001e+10, Δw=1.388e-02, ΔP=1.304e-02
Iter 3: w_min=9.587e-01, w_max=1.026e+00, X_min=1.428e+04, X_max=1.007e+10, Δw=1.142e-02, ΔP=1.204e-02
Iter 4: w_min=9.502e-01, w_max=1.033e+00, X_min=1.426e+04, X_max=1.013e+10, Δw=9.131e-03, ΔP=1.009e-02
Iter 5: w_min=9.430e-01, w_max=1.039e+00, X_min=1.423e+04, X_max=1.018e+10, Δw=7.113e-03, ΔP=8.723e-03
Iter 6: w_min=9.376e-01, w_max=1.045e+00, X_min=1.421e+04, X_max=1.022e+10, Δw=5.609e-03, ΔP=7.424e-03
Iter 7: w_min=9.334e-01, w_max=1.049e+00, X_min=1.419e+04, X_max=1.026e+10, Δw=4.870e-03, ΔP=6.088e-03
Iter 8: w_min=9.303e-01, w_max=1.054e+00, X_min=1.418e+04, X_max=1.029e+10, Δw=4.327e-03, ΔP=4.939e-03
Iter 9: w_min=9.279e-01, w_max=1.058e+00, X_min=1.416e+04, X_max=1.032e+10, Δw=3.850e-03, ΔP=4.251e-03
Iter 10: w_min=9.260e-01, w_max=1.061e+00, X_min=1.415e+04, X_max=1.035e+

## 2.3 counterfactual: US imposes 20% tariff rate to ALL countries steel sector

In [26]:
target_importers = ['USA']  # one or multiple
target_sectors = ["Metal Products"]
country_list = params.country_list.tolist()
sector_list = params.sector_list.tolist()
tilde_tau = params.tilde_tau.copy()
target_importers_index = [country_list.index(importer) for importer in target_importers]  

tilde_tau_1 = tilde_tau.copy()

for importer_index in target_importers_index:
    for exporter_index in range(len(country_list)):
        for sector_index in range(len(target_sectors)):
            if importer_index != exporter_index:
                tilde_tau_1[importer_index, exporter_index, sector_index] += 0.2

data_file_name = "data/data_2017.npz"
params = ModelParams.load_from_npz(data_file_name)
params.gamma = np.swapaxes(params.gamma, 1, 2)
params.theta /= 2
params.tilde_tau = tilde_tau_1
model5 = Model(params)

In [27]:
solver = ModelSolver(model5)
solver.solve()

Iter 1: w_min=9.834e-01, w_max=1.009e+00, X_min=1.435e+04, X_max=9.940e+09, Δw=1.661e-02, ΔP=3.331e-16
Iter 2: w_min=9.694e-01, w_max=1.018e+00, X_min=1.432e+04, X_max=1.001e+10, Δw=1.394e-02, ΔP=1.310e-02
Iter 3: w_min=9.585e-01, w_max=1.026e+00, X_min=1.428e+04, X_max=1.007e+10, Δw=1.143e-02, ΔP=1.210e-02
Iter 4: w_min=9.501e-01, w_max=1.033e+00, X_min=1.426e+04, X_max=1.013e+10, Δw=9.142e-03, ΔP=1.014e-02
Iter 5: w_min=9.430e-01, w_max=1.039e+00, X_min=1.423e+04, X_max=1.018e+10, Δw=7.123e-03, ΔP=8.732e-03
Iter 6: w_min=9.375e-01, w_max=1.045e+00, X_min=1.421e+04, X_max=1.022e+10, Δw=5.613e-03, ΔP=7.433e-03
Iter 7: w_min=9.334e-01, w_max=1.050e+00, X_min=1.419e+04, X_max=1.026e+10, Δw=4.885e-03, ΔP=6.096e-03
Iter 8: w_min=9.302e-01, w_max=1.054e+00, X_min=1.418e+04, X_max=1.029e+10, Δw=4.340e-03, ΔP=4.943e-03
Iter 9: w_min=9.278e-01, w_max=1.058e+00, X_min=1.416e+04, X_max=1.032e+10, Δw=3.862e-03, ΔP=4.264e-03
Iter 10: w_min=9.259e-01, w_max=1.061e+00, X_min=1.415e+04, X_max=1.035e+

In [32]:
target_importers = ['USA', 'CHN', "JPN"]  # one or multiple
country_list = params.country_list.tolist()
target_importers_index = [country_list.index(importer) for importer in target_importers]  

In [33]:
target_importers_index

[34, 6, 18]

In [34]:
model1.sol.real_w[target_importers_index]

array([1.01251735, 0.99705262, 0.99599141])

In [35]:
model4.sol.real_w[target_importers_index]

array([1.02231848, 0.99513196, 0.99259115])