In [1]:
# Set up
import os
import numpy as np
import pandas as pd

%load_ext autoreload
%autoreload 2  

from equilibrium import *

from cp_functions import calc_real_I, calc_real_w



# Import real data and check

In [29]:
wd = os.path.expanduser("~/Dropbox/Tariff_Project")

os.chdir(wd)
print(f"Current working directory: {os.getcwd()}")

Current working directory: /Users/yaolangzhong/Dropbox/Tariff_Project


In [30]:
real_data = np.load('model_data_2017.npz')
#real_data = np.load('R_model_data_2017.npz')

for name in real_data.files:
    arr = real_data[name]
    print(f"{name}: shape={arr.shape}, dtype={arr.dtype}")



N: shape=(), dtype=int64
J: shape=(), dtype=int64
country_list: shape=(37,), dtype=<U3
sector_list: shape=(23,), dtype=<U36
alpha: shape=(37, 23), dtype=float64
beta: shape=(37, 23), dtype=float64
gamma: shape=(37, 23, 23), dtype=float64
theta: shape=(23,), dtype=float64
pif: shape=(37, 37, 23), dtype=float64
pim: shape=(37, 37, 23), dtype=float64
pit: shape=(37, 37, 23), dtype=float64
Xf: shape=(37, 23), dtype=float64
Xm: shape=(37, 23), dtype=float64
X: shape=(37, 23), dtype=float64
tilde_tau: shape=(37, 37, 23), dtype=float64
D: shape=(37,), dtype=float64
VA: shape=(37,), dtype=float64


In [31]:
N = int(real_data['N'])  # Number of countries
J = int(real_data['J'])  # Number of sectors
country_list = [c.decode("utf-8") if isinstance(c, bytes) else c
                for c in real_data["country_list"].tolist()]

sector_list  = [s.decode("utf-8") if isinstance(s, bytes) else s
                for s in real_data["sector_list"].tolist()]
theta = real_data['theta']

X = real_data['X']  # X[n,j]: country n's total output in sector j

tradable_sector_list = sector_list[0:10]

In [32]:
# Check \alpha
alpha = real_data['alpha']
# Check 1: For each n, the sum over j equals 1
sum_by_n = np.sum(alpha, axis=1)
check_sum = np.allclose(sum_by_n, 1, atol=1e-6)  # Allowing a small numerical tolerance
if check_sum:
    print("For each country, the sum over secotrs equals 1 ✅")
else:
    print("There are some country where the sum over sectors is not 1 ❌")
    print("The issue occurs at the following n indices:", np.where(np.abs(sum_by_n - 1) > 1e-6)[0])
    print("The corresponding sums are:", sum_by_n[np.where(np.abs(sum_by_n - 1) > 1e-6)[0]])
# Check 2: Every value in alpha is between 0 and 1
check_range = np.all((alpha >= 0) & (alpha <= 1))
if check_range:
    print("Every value in alpha is between 0 and 1 ✅")
else:
    print("There are values in alpha that are not between 0 and 1 ❌")
    print("These values are at positions:", np.where((alpha < 0) | (alpha > 1)))
    print("The values are:", alpha[(alpha < 0) | (alpha > 1)])


# Check \beta
beta = real_data['beta']
# Value-added shares (for clarity, I use \beta instead of gamma in this code)
# beta[n,j]: country n's value-added share in sector j
# Check: Every value in beta is between 0 and 1
check_range = np.all((beta >= 0) & (beta <= 1))
if check_range:
    print("Every value in beta is between 0 and 1 ✅")
else:
    print("There are values in beta that are not between 0 and 1 ❌")
    print("These values are at positions:", np.where((beta < 0) | (beta > 1)))
    print("The values are:", beta[(beta < 0) | (beta > 1)])


# Check \gamma
#gamma_temp = real_data['gamma']
#gamma = np.swapaxes(gamma_temp, 1, 2) # to be consistent with the notation in cp(2015) (第二个维度是被用掉的)
gamma= real_data['gamma']
# Check 
# Check 1: Every value in gamma is between 0 and 1
invalid_values = (gamma < 0) | (gamma > 1)

if np.any(invalid_values):
    print("There are values in gamma that are not between 0 and 1 ❌")
    print("These values are at positions:", np.where(invalid_values))
    print("The values are:", gamma[invalid_values])
else:
    print("Every value in gamma is between 0 and 1 ✅")

# Check 2: sum(k) gamma[n,k,j] + beta[n,j] = 1
temp = np.sum(gamma, axis = 1) + beta
is_valid = np.allclose(temp, 1, atol = 1e-5)

if is_valid:
    print("Condition satisfied: sum(k) gamma[n, k, j] + beta[n, j] = 1 ✅")
else:
    print("Condition not satisfied: sum(k) gamma[n, k, j] + beta[n, j] ≠ 1 ❌")
    print("Positions where the condition fails:", np.where(~np.isclose(temp, 1, atol=1e-5)))
    print("Values that do not satisfy the condition:", temp[~np.isclose(temp, 1, atol=1e-5)])


# check \pi
pi = real_data['pit']

# Check
temp = np.sum(pi, axis=1)  
is_valid = np.allclose(temp, 1, atol=1e-5)  

if is_valid:
    print("Condition satisfied: sum(i) pi[n, i, j] = 1 ✅")
else:
    print("Condition not satisfied: sum(i) pi[n, i, j] ≠ 1 ❌")
    print("Positions where the condition fails:", np.where(~np.isclose(temp, 1, atol=1e-5)))
    print("Values that do not satisfy the condition:", temp[~np.isclose(temp, 1, atol=1e-5)])

# real_data_v2 = np.load('model_data_2017.npz')
# D = -real_data_v2['D']  # D[n]: country n's trade deficit





For each country, the sum over secotrs equals 1 ✅
Every value in alpha is between 0 and 1 ✅
Every value in beta is between 0 and 1 ✅
Every value in gamma is between 0 and 1 ✅
Condition not satisfied: sum(k) gamma[n, k, j] + beta[n, j] ≠ 1 ❌
Positions where the condition fails: (array([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,
        2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
        2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
        4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,
        5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
        5,  5,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
        6,  6,  6,  6,  6,  6,  6,  6,  7,  7,

In [27]:
# Try the New Tariff Data

N              = len(country_list)              
sector_mapping = {
    "Agriculture"          : "Agriculture",
    "Fishing"              : "Agriculture",
    "Other Manufacturing"  : "Other Manufacturing & Recycling",
    "Recycling"            : "Other Manufacturing & Recycling",
}

# ------------- 1. Import Tariff that has been cleaned -------------
# however, we need to combine some sectors
df = (
    pd.read_csv(os.path.join(wd, "3_Result/parameters/All_Tariff_2017.csv"))   
    .loc[:, ["Importer", "Exporter", "Sector", "tariff"]]
    .rename(columns={"tariff": "Tariff"})
)


# ------------- 2. Combine sector（mapping → taking average） -------------
df["Sector"] = df["Sector"].replace(sector_mapping)

df = (
    df.groupby(["Importer", "Exporter", "Sector"], as_index=False, sort=False)["Tariff"]
    .mean()                                          
)


# ------------- 3. Change to numpy -------------
df["Exporter_Sector"] = df["Exporter"] + "_" + df["Sector"]


df["Importer"]        = pd.Categorical(df["Importer"], categories=country_list, ordered=True)
df["Exporter_Sector"] = pd.Categorical(
    df["Exporter_Sector"],
    categories=[
        f"{exp}_{sec}"
        for exp in country_list
        for sec in df["Sector"].unique()               
    ],
    ordered=True
)

pivot = df.pivot_table(index="Importer",
                    columns="Exporter_Sector",
                    values="Tariff",
                    aggfunc="first",
                    observed=False)

J = df["Sector"].nunique()                              
tariff_np   = pivot.to_numpy()                         
tariff_base = tariff_np.reshape(N, N, J)               


tilde_tau = tariff_base + 1

In [18]:
# From X_nj Calculate VA
# Reshape beta to (N, 1, J) to allow broadcasting with pi_prime and tilde_tau_prime
VA = np.zeros(N)

for n in range(N):  
    for j in range(J):  
        inner_sum = 0
        for i in range(N):  
            inner_sum += X[i,j] * (pi[i,n,j] / tilde_tau[i,n,j])
        VA[n] += beta[n,j] * inner_sum

# Check if every value in VA is greater than 0
check_positive = np.all(VA > 0)
if check_positive:
    print("Every country's value added is greater than 0 ✅")
else:
    print("There are values in VA that are less than or equal to 0 ❌")
    print("These values are at positions:", np.where(VA <= 0))
    print("The values are:", VA[VA <= 0])

D = np.zeros(N)

for n in range(N):
    for j in range(J):
        for i in range (N):
            IM = X[n,j] * pi[n,i,j] / tilde_tau[n,i,j]
            EX = X[i,j] * pi[i,n,j] / tilde_tau[i,n,j]

            D[n] += EX - IM



Every country's value added is greater than 0 ✅


# Base year equilibrium

In [19]:
kappa_base = np.ones((N,N,J))

w_hat_initial = np.ones(N)
P_hat_initial = np.ones((N, J)) 

w_hat = w_hat_initial
P_hat = P_hat_initial
kappa_hat = kappa_base
tilde_tau_prime = tilde_tau
X_initial = X

# gamma_reverse = np.swapaxes(gamma, 1, 2)

In [20]:
P_hat_new, c_hat = solve_price_and_cost(w_hat, P_hat, pi, gamma, beta, theta, kappa_hat, N, J)
pi_prime = solve_piprime(c_hat, P_hat_new, pi, theta, kappa_hat, N, J)
X_prime = solve_X_prime(w_hat, alpha, gamma, pi_prime, VA, tilde_tau_prime, D, N, J, X_initial)

In [21]:
kappa_base = np.ones((N,N,J))

w_hat_initial = np.ones(N)
P_hat_initial = np.ones((N, J)) 


w_base, P_base, X_base, pi_base = equilibrium(gamma, beta, theta, tilde_tau, kappa_base, pi, alpha, VA, D, N, J, X, w_hat_initial, P_hat_initial)


Round 1: w_hat_min = 0.5976833482610692, w_hat_max = 1.4135727629919639, min_X_prime = -178259572.83962208, max_X_prime = 8038268412.226635, wfmax = 0.4135727629919639, Pfmax = 2.220446049250313e-16 
Round 2: w_hat_min = 0.29384623140235744, w_hat_max = 1.6931240298450803, min_X_prime = -324549361.7016442, max_X_prime = 8090507083.253502, wfmax = 0.30383711685871173, Pfmax = 0.40064764457369795 
Round 3: w_hat_min = 0.3189068762899382, w_hat_max = 1.8236815041697574, min_X_prime = -435350838.30775523, max_X_prime = 8111247623.237915, wfmax = 1.4794002945128915, Pfmax = 0.5175029076413666 
Round 4: w_hat_min = 0.18982730129881434, w_hat_max = 1.8224533424386735, min_X_prime = -327864503.9764037, max_X_prime = 8131533148.317243, wfmax = 0.30533091262215284, Pfmax = 1.1932967615092631 
Round 5: w_hat_min = 0.344297927880141, w_hat_max = 1.8330386558397902, min_X_prime = -407839548.75847816, max_X_prime = 8063051511.805533, wfmax = 1.4485954662113087, Pfmax = 0.510875541319789 
Round 6: w_

  log_w_hat = np.log(w_hat)  # shape: (N,)


In [11]:
VA_base = np.zeros(N)

for n in range(N):  
    for j in range(J):  
        inner_sum = 0
        for i in range(N):  
            inner_sum += X_base[i,j] * (pi_base[i,n,j] / tilde_tau[i,n,j])
        VA_base[n] += beta[n,j] * inner_sum

w_try, P_try, X_try, pi_try = equilibrium(gamma, beta, theta, tilde_tau, kappa_base, pi_base, alpha, VA_base, D, N, J, X_base, w_hat_initial, P_hat_initial)


Round 1: w_hat_min = 0.9999990197165288, w_hat_max = 1.0000001673213967, min_X_prime = 34152.490709737234, max_X_prime = 9054952410.621124, wfmax = 9.802834711747366e-07, Pfmax = 4.311377921339954e-10 


In [12]:
real_I_base = calc_real_I(pi_base, tilde_tau, X_base, D, P_base, alpha, beta)

# df = pd.DataFrame({
#     'Country': country_list,
#     'Real_w_hat': real_I_base
# })

# df

In [12]:
w_base

array([0.9766961 , 1.00619565, 1.00756329, 0.97709982, 1.01719607,
       1.00159375, 0.98938641, 1.01914484, 1.01630796, 1.01268122,
       1.03322999, 1.01477737, 1.00805579, 1.00463619, 1.01392204,
       0.98169497, 1.01828613, 1.00427082, 1.02055935, 1.03260867,
       1.00371769, 1.00877911, 0.99782968, 1.00225483, 1.02599355,
       1.00324412, 0.96357212, 1.00891034, 1.02536923, 1.01286933,
       1.01007897, 1.06106607, 1.01344023, 1.0074154 , 1.01549527,
       1.00728104, 0.97500542])

# Counterfactual 1
Additional 10% tariff by the US on all imports from any country in the world

In [11]:
# Counterfactual Tariff: U.S.
# U.S.: impose uniform 25% tariff on all imported aluminum products 
target_importers = ['USA']  # one or multiple
target_importers_index = [country_list.index(importer) for importer in target_importers]  

tilde_tau_1 = tilde_tau.copy()

for importer_index in target_importers_index:
    for exporter_index in range(len(country_list)):
        for sector_index in range(len(tradable_sector_list)):
            if importer_index != exporter_index:
                tilde_tau_1[importer_index, exporter_index, sector_index] += 0.1


kappa_hat_1 = tilde_tau_1 / tilde_tau

In [12]:
for importer_index in target_importers_index:
    for exporter_index in range(len(country_list)):
        for sector_index in range(len(tradable_sector_list)):
            if importer_index != exporter_index:
                print(f"Importer: {country_list[importer_index]}, Exporter: {country_list[exporter_index]}, Sector: {tradable_sector_list[sector_index]}, Tariff: {tilde_tau[importer_index, exporter_index, sector_index]}")  # Print the tariff value)

Importer: USA, Exporter: AUS, Sector: Agriculture, Tariff: 1.1952048791749998
Importer: USA, Exporter: AUS, Sector: Mining and Quarrying, Tariff: 1.007837838
Importer: USA, Exporter: AUS, Sector: Food, Tariff: 1.586006451
Importer: USA, Exporter: AUS, Sector: Textiles, Tariff: 1.3192676933333334
Importer: USA, Exporter: AUS, Sector: Wood and Paper, Tariff: 1.0330707643333332
Importer: USA, Exporter: AUS, Sector: Petroleum, Tariff: 1.025495314
Importer: USA, Exporter: AUS, Sector: Metal Products, Tariff: 1.226551668
Importer: USA, Exporter: AUS, Sector: Electrica and Machinery, Tariff: 1.0431085428
Importer: USA, Exporter: AUS, Sector: Transport Equipment, Tariff: 1.0
Importer: USA, Exporter: AUS, Sector: Other Manufacturing and Recycling, Tariff: 1.0166666675
Importer: USA, Exporter: AUT, Sector: Agriculture, Tariff: 2.192080055
Importer: USA, Exporter: AUT, Sector: Mining and Quarrying, Tariff: 1.204188552
Importer: USA, Exporter: AUT, Sector: Food, Tariff: 92.5345393
Importer: USA, E

In [13]:

w_base = np.ones(N)
P_base = np.ones((N, J))

w_hat_1, P_hat_1, X_prime_1, pi_prime_1 = equilibrium(gamma, beta, theta, tilde_tau_1, kappa_hat_1, pi_base, alpha, VA_base, D, N, J, X_base, w_base, P_base)



Round 1: w_hat_min = 0.9855710657263415, w_hat_max = 1.003896557531984, min_X_prime = 34133.43311883327, max_X_prime = 9159197256.31095, wfmax = 0.014428934273658545, Pfmax = 0.04486849102963997 
Round 2: w_hat_min = 0.9740388879330709, w_hat_max = 1.0071121145524289, min_X_prime = 34128.11829955087, max_X_prime = 9181112617.11289, wfmax = 0.011532177793270537, Pfmax = 0.014898054077991829 
Round 3: w_hat_min = 0.9649408730625892, w_hat_max = 1.0098831033621998, min_X_prime = 34122.82217745871, max_X_prime = 9200440265.985008, wfmax = 0.009098014870481674, Pfmax = 0.011483043052995634 
Round 4: w_hat_min = 0.9580113452031055, w_hat_max = 1.0122996112537452, min_X_prime = 34116.40152280263, max_X_prime = 9217231081.867065, wfmax = 0.006929527859483753, Pfmax = 0.00907877323914219 
Round 5: w_hat_min = 0.9529233679962605, w_hat_max = 1.014417895186874, min_X_prime = 34109.04649203188, max_X_prime = 9231975360.854366, wfmax = 0.005087977206844951, Pfmax = 0.006990534233731083 
Round 6: w_

In [14]:
real_w_1 = calc_real_w(w_hat_1, P_hat_1, alpha)

In [15]:
real_I_1 = calc_real_I(pi_prime_1, tilde_tau_1, X_prime_1, D, P_hat_1, alpha, beta)

real_I_1_hat = real_I_1 / real_I_base



# Counterfactual 2

Additional reciprocal tariffs by the US as listed in the table, except for China; the US imposes 125% additional tariffs against all imports from China.

In [16]:
annex_tariff = {
    # -------- EU Members (22)--------
    'AUT': 0.20, 'BEL': 0.20, 'BGR': 0.20, 'CZE': 0.20, 'DNK': 0.20,
    'EST': 0.20, 'FIN': 0.20, 'FRA': 0.20, 'DEU': 0.20, 'GRC': 0.20,
    'HUN': 0.20, 'IRL': 0.20, 'ITA': 0.20, 'LTU': 0.20, 'NLD': 0.20,
    'POL': 0.20, 'PRT': 0.20, 'ROU': 0.20, 'SVK': 0.20, 'SVN': 0.20,
    'ESP': 0.20, 'SWE': 0.20,

    # --------- Non-EU Members (6) ---------
    'CHN': 1.25, 'IND': 0.26, 'JPN': 0.24, 'KOR': 0.25,'TWN': 0.32, 'VNM': 0.46,

    # ---------Not mentioned (10%?) (8)-----------
    # Australia, Brazil, Canada, Mexico, Turkey, Russia
    'AUS':0.1, 'BRA':0.1, 'CAN': 0.1, 'MEX':0.1, 'TUR':0.1, 'GBR':0.1, 'RUS':0.1, 'ROW':0.1,

}

us_index = country_list.index('USA')

exporter_indices = {
    iso: country_list.index(iso)
    for iso in annex_tariff.keys()
    if iso in country_list and iso != 'USA'   
}


tilde_tau_2 = tilde_tau.copy()

for iso, exp_idx in exporter_indices.items():
    tilde_tau_2[us_index, exp_idx, :] += annex_tariff[iso]

tilde_tau_2[us_index,6,1]

kappa_hat_2 = tilde_tau_2 / tilde_tau

In [17]:
w_hat_2, P_hat_2, X_prime_2, pi_prime_2 = equilibrium(gamma, beta, theta, tilde_tau_2, kappa_hat_2, pi_base, alpha, VA_base, D, N, J, X_base,w_base,P_base)


Round 1: w_hat_min = 0.9874252033052447, w_hat_max = 1.004949292644205, min_X_prime = 34119.83694604867, max_X_prime = 9155678040.650473, wfmax = 0.012574796694755253, Pfmax = 0.06637201988338814 
Round 2: w_hat_min = 0.9787035041690316, w_hat_max = 1.0089608832181356, min_X_prime = 34111.109505721746, max_X_prime = 9181717604.335512, wfmax = 0.008721699136213124, Pfmax = 0.016206148895439254 
Round 3: w_hat_min = 0.9724911816825804, w_hat_max = 1.0124412585620421, min_X_prime = 34104.51819779832, max_X_prime = 9206592335.66214, wfmax = 0.006212322486451205, Pfmax = 0.008686720748291243 
Round 4: w_hat_min = 0.9682631015306201, w_hat_max = 1.0155323056676882, min_X_prime = 34098.024542904954, max_X_prime = 9229252474.312548, wfmax = 0.004228080151960301, Pfmax = 0.006201820422691462 
Round 5: w_hat_min = 0.9655999185258664, w_hat_max = 1.0183094918736022, min_X_prime = 34091.637861323514, max_X_prime = 9249963471.212341, wfmax = 0.0027771862059140506, Pfmax = 0.004668179330141853 
Roun

In [18]:
real_w_2 = calc_real_w(w_hat_2, P_hat_2, alpha)



In [19]:
real_I_2 = calc_real_I(pi_prime_2, tilde_tau_2, X_prime_2, D, P_hat_2, alpha, beta)

real_I_2_hat = real_I_2 / real_I_base


# Counterfactual 3

In addition to the second scenario, China imposes 84% on all goods imported from the US. (It is not the additional rate, but the tariff rate applied to all imports by China from the US.)


In [20]:
us_index = country_list.index('USA')
chn_index = country_list.index('CHN')

tilde_tau_3 = tilde_tau_2.copy()


for sector_index in range(len(tradable_sector_list)):
    tilde_tau_3[chn_index, us_index, sector_index] += 0.84

kappa_hat_3 = tilde_tau_3 / tilde_tau

In [21]:
w_hat_3, P_hat_3, X_prime_3, pi_prime_3 = equilibrium(gamma, beta, theta, tilde_tau_3, kappa_hat_3, pi_base, alpha, VA_base, D, N, J, X_base,w_base,P_base)


Round 1: w_hat_min = 0.9874141637529726, w_hat_max = 1.0048843825416451, min_X_prime = 34119.86706709987, max_X_prime = 9153922844.476051, wfmax = 0.01258583624702736, Pfmax = 0.06637201988338814 
Round 2: w_hat_min = 0.9786761081407819, w_hat_max = 1.008838996413049, min_X_prime = 34110.43090853216, max_X_prime = 9179571904.865028, wfmax = 0.008738055612190765, Pfmax = 0.01616990602966739 
Round 3: w_hat_min = 0.9724396425825835, w_hat_max = 1.012268364974588, min_X_prime = 34103.4526474534, max_X_prime = 9204069370.69005, wfmax = 0.006236465558198345, Pfmax = 0.008702926737697014 
Round 4: w_hat_min = 0.9681815455989333, w_hat_max = 1.0153133717280218, min_X_prime = 34096.76420577045, max_X_prime = 9226378803.139599, wfmax = 0.004258096983650206, Pfmax = 0.006225794889870406 
Round 5: w_hat_min = 0.9654849208107378, w_hat_max = 1.0180487772622906, min_X_prime = 34090.29110249906, max_X_prime = 9246766075.702835, wfmax = 0.0027796030910740566, Pfmax = 0.0046835900069780445 
Round 6: w

In [22]:
real_w_3 = calc_real_w(w_hat_3, P_hat_3, alpha)


In [23]:
real_I_3 = calc_real_I(pi_prime_3, tilde_tau_3, X_prime_3, D, P_hat_3, alpha, beta)

real_I_3_hat = real_I_3 / real_I_base

df = pd.DataFrame({
    'Country': country_list,
    'Real_I_hat': real_I_3_hat
})

df

Unnamed: 0,Country,Real_I_hat
0,AUS,0.975341
1,AUT,1.004614
2,BEL,1.005902
3,BRA,0.977569
4,BGR,1.015072
5,CAN,1.003593
6,CHN,0.962647
7,CZE,1.015343
8,DNK,1.01509
9,EST,1.012768


In [28]:
real_hat = pd.DataFrame({
    'Country': country_list,
    'Scenario1_real_w_hat': real_w_1,
    'Scenario2_real_w_hat': real_w_2,
    'Scenario3_real_w_hat': real_w_3,
    'Scenario1_real_I_hat': real_I_1_hat,
    'Scenario2_real_I_hat': real_I_2_hat,
    'Scenario3_real_I_hat': real_I_3_hat
})
real_hat.to_csv(os.path.join(wd, "real_hat.csv"), index=False)    

In [29]:

GO_base = np.einsum('nij,nij,nj->ij', pi_base, 1 / tilde_tau, X_base)  # shape: (N,J)

GO_counter_1 = np.einsum('nij,nij,nj-> ij', pi_prime_1, 1 / tilde_tau_1, X_prime_1)  # shape: (N,J)
GO_counter_2 = np.einsum('nij,nij,nj-> ij', pi_prime_2, 1 / tilde_tau_2, X_prime_2)  # shape: (N,J)
GO_counter_3 = np.einsum('nij,nij,nj-> ij', pi_prime_3, 1 / tilde_tau_3, X_prime_3)  # shape: (N,J)


pct1 = (GO_counter_1 - GO_base) / GO_base          # (N, J)
pct2 = (GO_counter_2 - GO_base) / GO_base          # (N, J)
pct3 = (GO_counter_3 - GO_base) / GO_base          # (N, J)



target_countries = ['CHN', 'JPN', 'USA']
target_idx       = [country_list.index(c) for c in target_countries]

records = []
for c, idx in zip(target_countries, target_idx):
    for j, sector in enumerate(sector_list):
        records.append({
            'Country'   : c,
            'Sector'    : sector,
            'Scenario1' : pct1[idx, j],
            'Scenario2' : pct2[idx, j],
            'Scenario3' : pct3[idx, j],
        })

df_pct = pd.DataFrame(records)

In [30]:
df_pct.to_csv(os.path.join(wd, "pct.csv"), index=False)

In [31]:
gamma = np.swapaxes(gamma, 1, 2) # to be consistent with the notation in cp(2015) (第二个维度是被用掉的)
D = -D

np.savez('CP_data_2017.npz',
        N = N,
        J = J,
        country_list = country_list, 
        sector_list = sector_list, 
        alpha = alpha, 
        beta = beta, 
        gamma = gamma, 
        theta = theta, 
        pi = pi,
        X = X,
        tilde_tau = tilde_tau,
        D = D, 
        VA = VA)