In [1]:
%run data_setup

In [2]:
import numpyro
from numpyro.infer import SVI, Trace_ELBO, MCMC, NUTS, Predictive
import numpyro.distributions as dist
from sklearn.preprocessing import LabelEncoder
from numpyro.infer.reparam import NeuTraReparam
from numpyro.infer.autoguide import AutoDiagonalNormal
from numpyro.handlers import reparam
from numpyro.optim import Adam
import jax.numpy as jnp
from jax import random
import statsmodels.api as sm
from numpyro.infer.reparam import LocScaleReparam

numpyro.enable_x64()
prng_seed = random.PRNGKey(0)
assert numpyro.__version__.startswith("0.13.2")

# Hierarchical Models

In [3]:
# global control variables
learning_rate = 1e-2
vi_iters = 10000

num_warmup = 250
num_samples = 1000
num_chains = 1
thinning = 1
target_accept_prob = 0.99
max_tree_depth = 12

num_resamples = 5

In [4]:
def hg_intercept_partial(pool_code, logX_ijt, hyperparameters, logM_ijt=None):
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(10.0))

    μ_α = numpyro.sample("μ_α", dist.Normal(hyperparameters['Β_param'][0], np.sqrt(np.diag(hyperparameters['cov_Β'])[0])))
    β_1 = numpyro.sample("β_1", dist.Normal(hyperparameters['Β_param'][1], np.sqrt(np.diag(hyperparameters['cov_Β'])[1])))
    β_2 = numpyro.sample("β_2", dist.Normal(hyperparameters['Β_param'][2], np.sqrt(np.diag(hyperparameters['cov_Β'])[2])))
    β_3 = numpyro.sample("β_3", dist.Normal(hyperparameters['Β_param'][3], np.sqrt(np.diag(hyperparameters['cov_Β'])[3])))
    μ_σ = numpyro.sample("μ_σ", dist.HalfNormal(np.sqrt(np.pi/2) * hyperparameters['σ_param']))

    n_pairs = len(np.unique(pool_code))
    with numpyro.plate("levels", n_pairs):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        σ = numpyro.sample("σ", dist.HalfNormal(np.sqrt(np.pi/2) * μ_σ))

    μ_ijt = α[pool_code] + β_1 * logX_ijt[:,0] + β_2 * logX_ijt[:,1] + β_3 * logX_ijt[:,2]
    σ_ij = σ[pool_code]

    with numpyro.plate("data", len(pool_code)):
        numpyro.sample("logM_ijt", dist.TruncatedNormal(μ_ijt, σ_ij, low=0), obs=logM_ijt)

In [5]:
def hg_fully_partial(pool_code, logX_ijt, hyperparameters, logM_ijt=None):
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(10))
    σ_β_1 = numpyro.sample("σ_β_1", dist.HalfNormal(10))
    σ_β_2 = numpyro.sample("σ_β_2", dist.HalfNormal(10))

    μ_α = numpyro.sample("μ_α", dist.Normal(hyperparameters['Β_param'][0], np.sqrt(np.diag(hyperparameters['cov_Β'])[0])))
    μ_β_1 = numpyro.sample("μ_β_1", dist.Normal(hyperparameters['Β_param'][1], np.sqrt(np.diag(hyperparameters['cov_Β'])[1])))
    μ_β_2 = numpyro.sample("μ_β_2", dist.Normal(hyperparameters['Β_param'][2], np.sqrt(np.diag(hyperparameters['cov_Β'])[2])))
    β_3 = numpyro.sample("β_3", dist.Normal(hyperparameters['Β_param'][3], np.sqrt(np.diag(hyperparameters['cov_Β'])[3])))
    μ_σ = numpyro.sample("μ_σ", dist.HalfNormal(np.sqrt(np.pi/2) * hyperparameters['σ_param']))

    n_pairs = len(np.unique(pool_code))
    with numpyro.plate("levels", n_pairs):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β_1 = numpyro.sample("β_1", dist.Normal(μ_β_1, σ_β_1))
        β_2 = numpyro.sample("β_2", dist.Normal(μ_β_2, σ_β_2))
        σ = numpyro.sample("σ", dist.HalfNormal(np.sqrt(np.pi/2) * μ_σ))

    μ_ijt = α[pool_code] + β_1[pool_code] * logX_ijt[:,0] + β_2[pool_code] * logX_ijt[:,1] + β_3 * logX_ijt[:,2]
    σ_ij = σ[pool_code]
    
    with numpyro.plate("data", len(pool_code)):
        numpyro.sample("logM_ijt", dist.TruncatedNormal(μ_ijt, σ_ij, low=0), obs=logM_ijt)

In [6]:
def hr_intercept_partial(pool_code, logX_ijt, hyperparameters, logM_ijt=None):
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(10))

    μ_α = numpyro.sample("μ_α", dist.Normal(hyperparameters['Β_param'][0], np.sqrt(np.diag(hyperparameters['cov_Β'])[0])))
    β_1 = numpyro.sample("β_1", dist.Normal(hyperparameters['Β_param'][1], np.sqrt(np.diag(hyperparameters['cov_Β'])[1])))
    β_2 = numpyro.sample("β_2", dist.Normal(hyperparameters['Β_param'][2], np.sqrt(np.diag(hyperparameters['cov_Β'])[2])))
    β_3 = numpyro.sample("β_3", dist.Normal(hyperparameters['Β_param'][3], np.sqrt(np.diag(hyperparameters['cov_Β'])[3])))
    β_4 = numpyro.sample("β_4", dist.Normal(hyperparameters['Β_param'][4], np.sqrt(np.diag(hyperparameters['cov_Β'])[4])))
    μ_σ = numpyro.sample("μ_σ", dist.HalfNormal(np.sqrt(np.pi/2) * hyperparameters['σ_param']))

    n_pairs = len(np.unique(pool_code))
    with numpyro.plate("levels", n_pairs):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        σ = numpyro.sample("σ", dist.HalfNormal(np.sqrt(np.pi/2) * μ_σ))

    μ_ijt = α[pool_code] + β_1 * logX_ijt[:,0] + β_2 * logX_ijt[:,1] + β_3 * logX_ijt[:,2] + β_4 * logX_ijt[:,3]
    σ_ij = σ[pool_code]

    with numpyro.plate("data", len(pool_code)):
        numpyro.sample("logM_ijt", dist.TruncatedNormal(μ_ijt, σ_ij, low=0), obs=logM_ijt)

In [7]:
def hr_fully_partial(pool_code, logX_ijt, hyperparameters, logM_ijt=None):
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(10))
    σ_β_1 = numpyro.sample("σ_β_1", dist.HalfNormal(10))
    σ_β_2 = numpyro.sample("σ_β_2", dist.HalfNormal(10))
    σ_β_3 = numpyro.sample("σ_β_3", dist.HalfNormal(10))
    σ_β_4 = numpyro.sample("σ_β_4", dist.HalfNormal(10))

    μ_α = numpyro.sample("μ_α", dist.Normal(hyperparameters['Β_param'][0], np.sqrt(np.diag(hyperparameters['cov_Β'])[0])))
    μ_β_1 = numpyro.sample("μ_β_1", dist.Normal(hyperparameters['Β_param'][1], np.sqrt(np.diag(hyperparameters['cov_Β'])[1])))
    μ_β_2 = numpyro.sample("μ_β_2", dist.Normal(hyperparameters['Β_param'][2], np.sqrt(np.diag(hyperparameters['cov_Β'])[2])))
    μ_β_3 = numpyro.sample("μ_β_3", dist.Normal(hyperparameters['Β_param'][3], np.sqrt(np.diag(hyperparameters['cov_Β'])[3])))
    μ_β_4 = numpyro.sample("μ_β_4", dist.Normal(hyperparameters['Β_param'][4], np.sqrt(np.diag(hyperparameters['cov_Β'])[4])))
    μ_σ = numpyro.sample("μ_σ", dist.HalfNormal(np.sqrt(np.pi/2) * hyperparameters['σ_param']))

    n_pairs = len(np.unique(pool_code))
    with numpyro.plate("levels", n_pairs):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
        β_1 = numpyro.sample("β_1", dist.Normal(μ_β_1, σ_β_1))
        β_2 = numpyro.sample("β_2", dist.Normal(μ_β_2, σ_β_2))
        β_3 = numpyro.sample("β_3", dist.Normal(μ_β_3, σ_β_3))
        β_4 = numpyro.sample("β_4", dist.Normal(μ_β_4, σ_β_4))
        σ = numpyro.sample("σ", dist.HalfNormal(np.sqrt(np.pi/2) * μ_σ))

    μ_ijt = α[pool_code] + β_1[pool_code] * logX_ijt[:,0] + β_2[pool_code] * logX_ijt[:,1] + β_3[pool_code] * logX_ijt[:,2] + β_4[pool_code] * logX_ijt[:,3]
    σ_ij = σ[pool_code]
    
    with numpyro.plate("data", len(pool_code)):
        numpyro.sample("logM_ijt", dist.TruncatedNormal(μ_ijt, σ_ij, low=0), obs=logM_ijt)

In [8]:
def run_experiment(model, prng_seed=prng_seed, print_progress=False, path_count=path_count, vi=True):

    mae_model, r_squared_model, cpc_model, cpcd_model = np.empty(path_count), np.empty(path_count), np.empty(path_count), np.empty(path_count)
    
    observations = np.random.normal(np.repeat(df_train.query("path_ind == 0").M_ij_mean, num_resamples)  ,  
                np.repeat(df_train.query("path_ind == 0").M_ij_sd, num_resamples)).astype('int')

    tmp = pd.DataFrame(np.repeat(df_train.query("path_ind == 0").values, num_resamples, axis=0), columns=df_train.columns)
    tmp.loc[:,'M_ij'] = observations

    df_train_filtered = tmp.query("M_ij > 0")
    if model in (hg_intercept_partial, hg_fully_partial):
        X_train = df_train_filtered.set_index('State_pair')[['P_i','P_j','D_ij']]
        pair_le = LabelEncoder().fit(X_train.index)
        pair_code = pair_le.transform(X_train.index)
    elif model in (hr_intercept_partial, hr_fully_partial):
        X_train = df_train_filtered.set_index('State_pair')[['P_i','P_j','SP_ij']]
        X_train['P_i + SP_ij'] = X_train.P_i + X_train.SP_ij
        X_train['P_i + P_j + SP_ij'] = X_train.P_i + X_train.P_j + X_train.SP_ij
        pair_le = LabelEncoder().fit(X_train.index)
        pair_code = pair_le.transform(X_train.index)
        X_train.drop('SP_ij', axis=1, inplace=True)
        
    X_train = np.log(X_train.astype('float').values)
    y_train = np.log(df_train_filtered.M_ij.astype('float').values)

    Β_param = np.linalg.inv(X_train.T @ X_train) @ X_train.T @ y_train
    α_param = y_train.mean() - Β_param @ X_train.mean(axis=0)
    ε = y_train - α_param - Β_param @ X_train.T
    σ_param = np.sqrt(ε @ ε / ( X_train.shape[0] - X_train.shape[1] - 1 ))
    cov_Β = σ_param * np.linalg.inv(sm.add_constant(X_train).T @ sm.add_constant(X_train))
    hyperparameters = {'Β_param': np.concatenate([[α_param],Β_param]),
                    'σ_param': σ_param,
                    'cov_Β': cov_Β}
    
    if model in (hg_intercept_partial, hr_intercept_partial):
        reparam_config = {"α": LocScaleReparam(0)}
    elif model == hg_fully_partial:
        reparam_config = {"α": LocScaleReparam(0), "β_1": LocScaleReparam(0), "β_2": LocScaleReparam(0)}
    elif model == hr_fully_partial:
        reparam_config = {"α": LocScaleReparam(0), "β_1": LocScaleReparam(0), "β_2": LocScaleReparam(0), "β_3": LocScaleReparam(0), "β_4": LocScaleReparam(0)}

    reparam_model = reparam(model, config=reparam_config)
    
    if vi:
        guide = AutoDiagonalNormal(reparam_model, init_loc_fn=numpyro.infer.init_to_feasible)
        svi = SVI(reparam_model, guide, Adam(learning_rate), Trace_ELBO())
        svi_result = svi.run(prng_seed, vi_iters, pair_code, X_train, hyperparameters, y_train, progress_bar=print_progress)
        neutra = NeuTraReparam(guide, svi_result.params)
        reparam_model = neutra.reparam(reparam_model)

    nuts_kernel = NUTS(reparam_model, init_strategy=numpyro.infer.init_to_feasible, target_accept_prob=target_accept_prob, max_tree_depth=max_tree_depth)

    mcmc = MCMC(nuts_kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=print_progress)
    
    mcmc.warmup(prng_seed, pair_code, X_train, hyperparameters, y_train, collect_warmup=True, extra_fields=["potential_energy"])
    poten_model = mcmc.get_extra_fields()["potential_energy"]
    divergences_model = mcmc.get_extra_fields()["diverging"]
    mcmc.run(prng_seed, pair_code, X_train, hyperparameters, y_train)
    
    if vi:
        posteriors_model = neutra.transform_sample(mcmc.get_samples()['auto_shared_latent'])
    else:
        posteriors_model = mcmc.get_samples()
    
    for path in df.path_ind.unique()[:path_count]:

        df_test_filtered = df_test.query("path_ind == @path & M_ij != 0 & State_pair in {}".format(list(df_train_filtered.State_pair.unique())))

        if model in (hg_intercept_partial, hg_fully_partial):
            X_test = df_test_filtered.set_index('State_pair')[['P_i','P_j','D_ij']]
            pair_code = pair_le.transform(X_test.index)
        elif model in (hr_intercept_partial, hr_fully_partial):
            X_test = df_test_filtered.set_index('State_pair')[['P_i','P_j','SP_ij']]
            X_test['P_i + SP_ij'] = X_test.P_i + X_test.SP_ij
            X_test['P_i + P_j + SP_ij'] = X_test.P_i + X_test.P_j + X_test.SP_ij
            pair_code = pair_le.transform(X_test.index)
            X_test.drop('SP_ij', axis=1, inplace=True)

        X_test = np.log(X_test.astype('float').values)
        y_test = df_test_filtered.M_ij

        predictive = Predictive(reparam_model, posteriors_model, return_sites=["logM_ijt"])

        samples_predictive = predictive(prng_seed, pair_code, X_test, hyperparameters)

        pred = np.exp( samples_predictive["logM_ijt"].mean(axis=0) )

        mae_model[path] = mae(y_test, pred )
        r_squared_model[path] = r_squared(y_test, pred)
        cpc_model[path] = cpc(y_test, pred)
        cpcd_model[path] = cpcd(y_test, pred, df_test_filtered.D_ij)

        print("Path {}/{} delivers MAE {:0.3f}".format(path+1,path_count, mae_model[path]))
    
    model_results = {'mae_model': mae_model,
                     'r_squared_model': r_squared_model,
                     'cpc_model': cpc_model,
                     'cpcd_model': cpcd_model,
                     'poten_model': poten_model,
                     'posteriors_model': posteriors_model,
                     'divergences': divergences_model
    }
    
    return model_results

In [9]:
def print_results(model_results):
    print("----------------------------------------------")
    for metric in list(model_results.keys())[:4]:
        print("{}: {:0.3f}, +/- {:0.3f}".format(metric,
                                    model_results[metric].mean(),
                                    norm.ppf(.975) * model_results[metric].std(ddof=1) / np.sqrt(path_count)
            )
        )
    print("----------------------------------------------")
    print("Percent of warm-up transitions that are divergent: {:0.1f}%".format(100*np.array(model_results['divergences']).mean()))

In [10]:
def save_results(model_results):
    with open('../results/{}.pkl'.format(f'{model_results=}'.split('=')[0]), 'wb') as f:
        pickle.dump(model_results, f, protocol=pickle.HIGHEST_PROTOCOL)

## Gravity

### Varying intercept

In [11]:
hg_intercept_partial_results_upsampled = run_experiment(hg_intercept_partial, print_progress=False, path_count=5, vi=True)

print_results(hg_intercept_partial_results_upsampled)
# save_results(hg_intercept_partial_results_upsampled)

----------------------------------------------
mae_model: 0.000, +/- 0.000
r_squared_model: 0.000, +/- 0.000
cpc_model: 0.000, +/- 0.000
cpcd_model: 0.000, +/- 0.000
----------------------------------------------
Percent of warm-up transitions that are divergent: 1.6%


### Varying intercept & coefficients

In [12]:
hg_fully_partial_results_upsampled = run_experiment(hg_fully_partial, path_count=5, vi=True)

print_results(hg_fully_partial_results_upsampled)
# save_results(hg_fully_partial_results_upsampled)

----------------------------------------------
mae_model: 0.000, +/- 0.000
r_squared_model: 0.000, +/- 0.000
cpc_model: 0.000, +/- 0.000
cpcd_model: 0.000, +/- 0.000
----------------------------------------------
Percent of warm-up transitions that are divergent: 2.4%


## Radiation

### Varying intercept

In [13]:
hr_intercept_partial_results_upsampled = run_experiment(hr_intercept_partial, path_count=5, vi=True)

print_results(hr_intercept_partial_results_upsampled)
# save_results(hr_intercept_partial_results_upsampled)

Path 1/5 delivers MAE 1068.069
Path 2/5 delivers MAE 1079.170
Path 3/5 delivers MAE 1043.206
Path 4/5 delivers MAE 1086.266
Path 5/5 delivers MAE 1074.553
----------------------------------------------
mae_model: 1070.253, +/- 14.474
r_squared_model: 0.830, +/- 0.006
cpc_model: 0.827, +/- 0.002
cpcd_model: 0.960, +/- 0.001
----------------------------------------------
Percent of warm-up transitions that are divergent: 2.4%


### Varying intercept & coefficients

In [14]:
hr_fully_partial_results_upsampled = run_experiment(hr_fully_partial, path_count=5, vi=False)

print_results(hr_fully_partial_results_upsampled)
save_results(hr_fully_partial_results_upsampled)

Path 1/5 delivers MAE 1066.853
Path 2/5 delivers MAE 1080.202
Path 3/5 delivers MAE 1043.461
Path 4/5 delivers MAE 1084.733
Path 5/5 delivers MAE 1073.525
----------------------------------------------
mae_model: 1069.755, +/- 14.184
r_squared_model: 0.831, +/- 0.006
cpc_model: 0.827, +/- 0.002
cpcd_model: 0.961, +/- 0.001
----------------------------------------------
Percent of warm-up transitions that are divergent: 2.4%
