In [17]:
# import os
# os.chdir("birthrate_mtgp")
from jax import numpy as jnp
import numpy as np
import numpyro.distributions as dist
import jax.numpy as jnp
import numpyro
from numpyro.handlers import scope

from models.panel_nmf_model import model
from models.utils import missingness_adjustment
from numpyro_to_draws_df_csv import dict_to_tidybayes

import pandas as pd

In [None]:
## Scratch configuration for interactive testing in this notebook.
# NOTE(review): the global `dist` assigned below rebinds the name that the
# `import numpyro.distributions as dist` line at the top of the notebook
# created; code run after this cell that expects the distributions module
# under `dist` will see the string "NB" instead.

# Model specification
dist = "NB"
outcome_type = "births"
cat_name = "total"
rank = 5
model_treated = True

# Dispersion / missingness handling
sample_disp = False
disp_param = 1e-4
missingness = True

# Sampler settings
num_chains = 1
num_warmup = 100
num_samples = 100

# Time window and placebo treatment date
start_time = '2016-01-01'
end_time = '2024-01-01'
placebo_time = "2019-03-01"
# Default locations used by run_model; override per call via data_path / results_dir.
_DEFAULT_DATA_PATH = '/Users/shaokangyang/Library/CloudStorage/GoogleDrive-sky.ang510@gmail.com/My Drive/Code/dobbs_fertility/data/fertility_data.csv'
_DEFAULT_RESULTS_DIR = '/Users/shaokangyang/Library/CloudStorage/GoogleDrive-sky.ang510@gmail.com/My Drive/Code/fertility_results'


def run_model(dist, outcome_type="births", cat_name="total", rank=5, missingness=True,
              disp_param=1e-4, sample_disp=False, placebo_state=None,
              start_time='2016-01-01', end_time='2023-12-31',
              placebo_time=None, dobbs_donor_sensitivity=False,
              model_treated=True, results_file_suffix="",
              num_chains=1, num_warmup=1000, num_samples=1000, thinning=1,
              data_path=_DEFAULT_DATA_PATH, results_dir=_DEFAULT_RESULTS_DIR,
              seed=8675309):
    """Fit the panel NMF model for one configuration and write its draws to CSV.

    Loads and cleans the raw data, optionally builds a unit and/or time placebo
    dataset, fits `model` with NUTS, then draws posterior-predictive `y_obs`
    with `model_treated=False`.

    Args:
        dist: Outcome distribution passed to the model as `outcome_dist`
            (the sweep below uses 'NB'; its comment mentions Poisson or NB).
        outcome_type: Outcome passed to the cleaning/prep helpers; also part
            of the output filenames.
        cat_name: Subgroup category passed to cleaning and `prep_data`.
        rank: Latent rank passed to the model.
        missingness: Passed to the model as `adjust_for_missingness`.
        disp_param: Passed to the model as `nb_disp`.
        sample_disp: Passed to the model as `sample_disp`.
        placebo_state: If set (and not "Texas"), the data is rebuilt with
            `create_unit_placebo_dataset` for that state.
        start_time: Lower date bound. Applied only when `placebo_time` is None.
        end_time: Upper date bound (inclusive), always applied.
        placebo_time: If set, the data is rebuilt with
            `create_time_placebo_dataset` using this treatment start.
        dobbs_donor_sensitivity: Forwarded to `clean_dataframe`.
        model_treated: Passed to the model for the MCMC fit.
        results_file_suffix: Appended to both output filenames. It must
            distinguish runs that differ only in placebo, missingness, or
            dispersion settings, since those are not part of the filename.
        num_chains, num_warmup, num_samples, thinning: NUTS/MCMC settings.
        data_path: CSV file read as the raw input data.
        results_dir: Directory receiving the output CSVs; created if missing.
        seed: Seed for the JAX PRNG key used for MCMC and prediction.

    Writes (to `results_dir`):
        `{dist}_{outcome_type}_{cat_name}_{rank}_{suffix}.csv`: posterior draws
            of mu, te and disp merged with predictive `ypred` per (.chain, .draw).
        `df_{dist}_{outcome_type}_{cat_name}_{rank}_{suffix}.csv`: the cleaned
            input frame that produced those draws.
    """
    import os

    # Per numpyro's docs this only takes effect before JAX initialises its
    # backend, which is why it runs first (e.g. in a fresh joblib worker).
    numpyro.set_host_device_count(num_chains)

    from jax import random
    from numpyro.infer import MCMC, NUTS, Predictive
    from clean_monthly_birth_data import (
        clean_dataframe,
        create_time_placebo_dataset,
        create_unit_placebo_dataset,
        prep_data,
    )

    # ---- data preparation -------------------------------------------------
    df = pd.read_csv(data_path)
    df = clean_dataframe(df, outcome_type, cat_name,
                         dobbs_donor_sensitivity=dobbs_donor_sensitivity, csv_filename=None)
    df = df[df['time'] <= pd.to_datetime(end_time)]
    df = df.sort_values(by=['state', 'year', 'bmcode']).drop_duplicates()

    # NOTE(review): "Texas" is explicitly excluded from unit-placebo
    # construction; the reason is not visible in this file.
    if placebo_state is not None and placebo_state != "Texas":
        df = create_unit_placebo_dataset(df, placebo_state=placebo_state)

    if placebo_time is not None:
        df = create_time_placebo_dataset(df, new_treatment_start=placebo_time)
    else:
        # The start_time cutoff is only applied when not running a time placebo.
        df = df[df['time'] >= pd.to_datetime(start_time)]

    data_dict_cat = prep_data(df, outcome_type=outcome_type, group=cat_name)

    # ---- posterior sampling -----------------------------------------------
    rng_key = random.PRNGKey(seed)
    rng_key, rng_key_ = random.split(rng_key)

    mcmc = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True,
        thinning=thinning,
    )
    mcmc.run(
        rng_key_,
        y=data_dict_cat['Y'],
        denominators=data_dict_cat['denominators'],
        control_idx_array=data_dict_cat['control_idx_array'],
        missing_idx_array=data_dict_cat['missing_idx_array'],
        rank=rank,
        outcome_dist=dist,
        adjust_for_missingness=missingness,
        nb_disp=disp_param,
        sample_disp=sample_disp,
        model_treated=model_treated,
    )
    samples = mcmc.get_samples(group_by_chain=True)

    # ---- posterior predictive ---------------------------------------------
    # NOTE(review): called with model_treated=False and no control/missing
    # index arrays; presumably this yields untreated (counterfactual) draws of
    # y_obs for every cell — confirm against models.panel_nmf_model.
    rng_key, rng_key_ = random.split(rng_key)
    predictive = Predictive(model, mcmc.get_samples(group_by_chain=False))
    predictions = predictive(
        rng_key_,
        denominators=data_dict_cat['denominators'],
        control_idx_array=None,
        missing_idx_array=None,
        rank=rank,
        outcome_dist=dist,
        nb_disp=disp_param,
        sample_disp=sample_disp,
        model_treated=False,
    )['y_obs']

    K, D, N = data_dict_cat['denominators'].shape
    # Let numpy infer draws-per-chain from what Predictive actually returned
    # instead of recomputing it from num_samples / thinning.
    pred_mat = predictions.reshape(mcmc.num_chains, -1, K, D, N)

    # ---- convert to tidybayes draws form ----------------------------------
    params = dict_to_tidybayes({'mu': samples['mu_ctrl'], 'te': samples['te'], 'disp': samples['disp']})
    preds = dict_to_tidybayes({"ypred": pred_mat})
    # Align the prediction rows to the parameter rows' chain/draw labels before merging.
    preds[".chain"] = params[".chain"]
    preds[".draw"] = params[".draw"]
    results_df = params.merge(preds, on=['.draw', '.chain'])

    # ---- write outputs ----------------------------------------------------
    # Both filenames carry the full run identity. Previously the input df was
    # saved as df_<suffix>.csv, so parallel jobs for different categories
    # overwrote one another's file.
    os.makedirs(results_dir, exist_ok=True)
    run_id = "{}_{}_{}_{}_{}".format(dist, outcome_type, cat_name, rank, results_file_suffix)
    df.to_csv(os.path.join(results_dir, "df_{}.csv".format(run_id)))
    results_df.to_csv(os.path.join(results_dir, "{}.csv".format(run_id)))

    
if __name__ == '__main__':
    # Full sweep over every subgroup category and latent rank, fanned out with joblib.
    # NOTE(review): inside a Jupyter kernel __name__ is '__main__', so executing
    # this cell launches the whole sweep, not just the definitions above it.
    from clean_monthly_birth_data import subgroup_definitions
    from joblib import Parallel, delayed

    outcome_type = "births"
    cats = list(subgroup_definitions[outcome_type].keys())

    # Grid axes
    inputs = [6, 7, 8, 9, 10, 11, 12]   # latent ranks
    dists = ['NB']                      # Poisson or NB
    missing_flags = [True]
    disp_params = [1e-4]                # alternative grid: [1e-4, 1e-3]
    placebo_times = [None]              # alternative grid: ["2020-05-01"]
    placebo_states = [None]

    # Settings held fixed across the grid
    sample_disp = False
    dobbs_donor_sensitivity = False

    # One tuple per job: (dist, cat, rank, missingness, disp_param, placebo_state, placebo_time).
    # The loop nesting order (dist, rank, cat, ...) determines job order.
    args = [
        (d, c, r, miss, nb, p_state, p_time)
        for d in dists
        for r in inputs
        for c in cats
        for miss in missing_flags
        for nb in disp_params
        for p_state in placebo_states
        for p_time in placebo_times
    ]

    results = Parallel(n_jobs=8)(
        delayed(run_model)(
            dist=d,
            outcome_type=outcome_type,
            cat_name=c,
            rank=r,
            missingness=miss,
            disp_param=nb,
            sample_disp=sample_disp,
            placebo_state=p_state,
            placebo_time=p_time,
            dobbs_donor_sensitivity=dobbs_donor_sensitivity,
            results_file_suffix="through_june",
            num_chains=4,
            num_samples=2500,
            num_warmup=1000,
            thinning=10,
        )
        for d, c, r, miss, nb, p_state, p_time in args
    )

In [18]:
from run_model import run_model
from joblib import Parallel, delayed

# Rerun only the subgroup categories whose earlier outputs may be missing ypred.
rerun_cats = ['race', 'edu']

inputs = [6, 7, 8, 9, 10, 11, 12]  # latent ranks
outcome_type = "births"
dists = ['NB']
missing_flags = [True]
disp_params = [1e-4]
placebo_times = [None]
placebo_states = [None]
sample_disp = False
results_file_suffix = "rerun_ypred_fix"  # keeps these outputs apart from the original sweep

# One tuple per job: (dist, cat, rank, missingness, disp_param, placebo_state, placebo_time),
# ordered by the nesting dist -> rank -> cat -> ...
args = [
    (d, c, r, miss, nb, p_state, p_time)
    for d in dists
    for r in inputs
    for c in rerun_cats
    for miss in missing_flags
    for nb in disp_params
    for p_state in placebo_states
    for p_time in placebo_times
]

# NOTE(review): the `__main__` sweep in the previous cell used
# dobbs_donor_sensitivity=False, but this rerun passes True, so it is not a
# like-for-like rerun of those fits. Confirm this is intended.
results = Parallel(n_jobs=8)(
    delayed(run_model)(
        dist=d,
        outcome_type=outcome_type,
        cat_name=c,
        rank=r,
        missingness=miss,
        disp_param=nb,
        sample_disp=sample_disp,
        placebo_state=p_state,
        placebo_time=p_time,
        results_file_suffix=results_file_suffix,
        num_chains=2,
        num_samples=1000,
        num_warmup=500,
        thinning=5,
        dobbs_donor_sensitivity=True,  # or False if not filtering
    )
    for d, c, r, miss, nb, p_state, p_time in args
)


['Alabama']
['Alabama']
['Alaska']
['Alaska']
['Arizona']
['Alabama']
['Alabama']
['Arizona']
['Alabama']
['Arkansas']
['Arkansas']
['Alaska']
['Alaska']
['California']
['Alaska']
['California']
['Arizona']
['Arizona']
['Colorado']
['Colorado']
['Arkansas']
['Connecticut']
['Arizona']
['Arkansas']
['Alabama']
['Connecticut']
['California']
['Delaware']
['California']
['Alabama']
['Arkansas']
['Delaware']
['Colorado']
['District of Columbia']
['Alaska']
['Colorado']
['Alaska']
['District of Columbia']
['California']
['Connecticut']
['Florida']
['Arizona']
['Connecticut']
['Delaware']
['Florida']
['Colorado']
['Arizona']
['Arkansas']
['Georgia']
['Delaware']
['District of Columbia']
['Connecticut']
['Arkansas']
['Georgia']
['District of Columbia']
['California']
['Hawaii']
['Florida']
['California']
['Idaho']
['Georgia']
['Florida']
['Delaware']
['Colorado']
['Colorado']
['Hawaii']
['Georgia']
['District of Columbia']
['Connecticut']
['Hawaii']
['Illinois']
['Hawaii']
['Florida']
['Delaw

Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:13<?, ?it/s]
Running chain 0:   0%|          | 0/1500 [00:14<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:14<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:14<?, ?it/s][A
Running chain 1:   0%|          | 0/1500 [00:14<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:14<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:14<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:15<?, ?it/s][A
Running chain 1:   0%|          | 0/1500 [00:15<?, ?it/s][A
Running chain 0:   5%|▌         | 75/1500 [01:44<28:47,  1.21s/it][A
Running chain 1:   5%|▌         | 75/1500 [01:53<31:28,  1.33s/it][A
Running chain 0:   5%|▌         | 75/1500 [02:04<34:38,  1.46s/it][A
Running chain 0:   5%|▌         | 75/1500 [02:04<34:39,  1.46s/it][A
Running chain 0:   5%|▌         | 75/1500 [02:10<36:41,  1.55s/it][A
Running chain 1:   5%|▌         | 75/1500 [02:12<37:25,  1.58s/it][A
Running chain 0:   5%|▌         | 

['mu', 'te', 'disp']



Running chain 1:  95%|█████████▌| 1425/1500 [49:35<02:41,  2.15s/it][A

['ypred']


Running chain 0:  90%|█████████ | 1350/1500 [49:46<05:39,  2.26s/it]
Running chain 1:  90%|█████████ | 1350/1500 [49:53<05:37,  2.25s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:17<?, ?it/s]
Running chain 0:  80%|████████  | 1200/1500 [51:03<13:03,  2.61s/it]
Running chain 1:  80%|████████  | 1200/1500 [51:39<13:07,  2.62s/it][A
Running chain 0:  85%|████████▌ | 1275/1500 [51:57<09:18,  2.48s/it][A
Running chain 1:  85%|████████▌ | 1275/1500 [52:02<09:23,  2.50s/it][A
Running chain 0:  90%|█████████ | 1350/1500 [52:10<05:54,  2.37s/it][A
Running chain 1: 100%|██████████| 1500/1500 [52:11<00:00,  2.09s/it][A

Running chain 0: 100%|██████████| 1500/1500 [52:19<00:00,  2.09s/it][A


['mu', 'te', 'disp']


Running chain 0:   5%|▌         | 75/1500 [02:17<37:54,  1.60s/it]t]
Running chain 1:   5%|▌         | 75/1500 [02:18<38:17,  1.61s/it][A

['ypred']


Running chain 0:  95%|█████████▌| 1425/1500 [52:33<02:48,  2.25s/it]
Running chain 1:  95%|█████████▌| 1425/1500 [52:38<02:47,  2.24s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:18<?, ?it/s]
Running chain 0:  85%|████████▌ | 1275/1500 [54:17<09:45,  2.60s/it]
Running chain 1:  90%|█████████ | 1350/1500 [54:46<06:12,  2.48s/it][A
Running chain 0:  90%|█████████ | 1350/1500 [55:00<06:10,  2.47s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [55:05<02:56,  2.36s/it][A
Running chain 1:  95%|█████████▌| 1425/1500 [55:07<02:56,  2.35s/it][A
Running chain 0: 100%|██████████| 1500/1500 [55:21<00:00,  2.21s/it][A

Running chain 1: 100%|██████████| 1500/1500 [55:25<00:00,  2.22s/it][A
Running chain 0:   5%|▌         | 75/1500 [02:23<39:49,  1.68s/it]
Running chain 1:   5%|▌         | 75/1500 [02:27<41:01,  1.73s/it][A

['mu', 'te', 'disp']
['ypred']


Running chain 0:  10%|█         | 150/1500 [05:45<51:26,  2.29s/it]
Running chain 1:  10%|█         | 150/1500 [05:45<51:24,  2.28s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:19<?, ?it/s]
Running chain 0:  90%|█████████ | 1350/1500 [57:31<06:30,  2.60s/it]
Running chain 1:  95%|█████████▌| 1425/1500 [57:52<03:06,  2.48s/it][A
Running chain 1: 100%|██████████| 1500/1500 [57:58<00:00,  2.32s/it][A
Running chain 0: 100%|██████████| 1500/1500 [58:00<00:00,  2.32s/it]

Running chain 1: 100%|██████████| 1500/1500 [58:03<00:00,  2.32s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [58:04<03:06,  2.48s/it]

['mu', 'te', 'disp']



Running chain 1:  90%|█████████ | 1350/1500 [58:08<06:31,  2.61s/it][A
Running chain 0: 100%|██████████| 1500/1500 [58:14<00:00,  2.33s/it][A


['ypred']
['mu', 'te', 'disp']
['ypred']


Running chain 0:   5%|▌         | 75/1500 [02:27<40:23,  1.70s/it]

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)



Running chain 1:   5%|▌         | 75/1500 [02:36<43:23,  1.83s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:16<?, ?it/s] 2.35s/it]
Running chain 1:   0%|          | 0/1500 [00:16<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:20<?, ?it/s] 2.37s/it][A
Running chain 1:   0%|          | 0/1500 [00:20<?, ?it/s][A
Running chain 0:  95%|█████████▌| 1425/1500 [1:00:44<03:14,  2.59s/it]
Running chain 1: 100%|██████████| 1500/1500 [1:00:53<00:00,  2.44s/it][A
Running chain 0: 100%|██████████| 1500/1500 [1:01:04<00:00,  2.44s/it]
Running chain 0: 100%|██████████| 1500/1500 [1:01:07<00:00,  2.44s/it]


['mu', 'te', 'disp']



Running chain 1: 100%|██████████| 1500/1500 [1:01:13<00:00,  2.45s/it][A
Running chain 0:   5%|▌         | 75/1500 [02:28<42:02,  1.77s/it]
Running chain 1:  95%|█████████▌| 1425/1500 [1:01:16<03:13,  2.58s/it][A

['ypred']



Running chain 1:   5%|▌         | 75/1500 [02:32<43:07,  1.82s/it][A

['mu', 'te', 'disp']
['ypred']



Running chain 0:   5%|▌         | 75/1500 [02:39<44:04,  1.86s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:16<?, ?it/s]
Running chain 0:  10%|█         | 150/1500 [05:52<52:00,  2.31s/it]
Running chain 0:  15%|█▌        | 225/1500 [09:22<53:58,  2.54s/it][A
Running chain 1:  15%|█▌        | 225/1500 [09:29<54:34,  2.57s/it][A
Running chain 0: 100%|██████████| 1500/1500 [1:03:40<00:00,  2.55s/it]

Running chain 1: 100%|██████████| 1500/1500 [1:04:07<00:00,  2.56s/it][A
Running chain 0:   5%|▌         | 75/1500 [02:21<39:27,  1.66s/it]
Running chain 1:   5%|▌         | 75/1500 [02:21<39:29,  1.66s/it][A

['mu', 'te', 'disp']
['ypred']


Running chain 0:  10%|█         | 150/1500 [05:45<51:03,  2.27s/it]
Running chain 1:  10%|█         | 150/1500 [05:47<51:14,  2.28s/it][A
Running chain 0:  15%|█▌        | 225/1500 [08:56<50:26,  2.37s/it][A
Running chain 0:  20%|██        | 300/1500 [12:24<49:52,  2.49s/it][A
Running chain 1:  25%|██▌       | 375/1500 [15:17<46:02,  2.46s/it][A
Running chain 0:  10%|█         | 150/1500 [05:28<48:22,  2.15s/it][A
Running chain 0:  15%|█▌        | 225/1500 [08:38<48:34,  2.29s/it][A
Running chain 0:  15%|█▌        | 225/1500 [08:51<49:03,  2.31s/it][A
Running chain 0:  20%|██        | 300/1500 [11:39<45:54,  2.30s/it][A
Running chain 1:  15%|█▌        | 225/1500 [08:58<50:03,  2.36s/it][A
Running chain 0:  30%|███       | 450/1500 [18:04<41:24,  2.37s/it][A
Running chain 0:  20%|██        | 300/1500 [11:26<45:26,  2.27s/it][A
Running chain 1:  20%|██        | 300/1500 [11:28<45:25,  2.27s/it][A
Running chain 0:  15%|█▌        | 225/1500 [08:30<48:25,  2.28s/it][A
Running c

['mu', 'te', 'disp']



Running chain 0:  80%|████████  | 1200/1500 [47:17<11:52,  2.37s/it][A

['ypred']



Running chain 0:  90%|█████████ | 1350/1500 [50:22<05:30,  2.21s/it][A
Running chain 0:  75%|███████▌  | 1125/1500 [45:14<15:11,  2.43s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [54:26<02:45,  2.21s/it][A
Running chain 0:  85%|████████▌ | 1275/1500 [48:47<08:27,  2.26s/it][A
Running chain 1:  95%|█████████▌| 1425/1500 [52:28<02:38,  2.11s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [52:44<02:38,  2.11s/it][A
Running chain 0: 100%|██████████| 1500/1500 [56:50<00:00,  2.27s/it][A

Running chain 1: 100%|██████████| 1500/1500 [56:51<00:00,  2.27s/it][A

Running chain 1:  90%|█████████ | 1350/1500 [51:10<05:26,  2.17s/it][A

['mu', 'te', 'disp']


Running chain 0:  90%|█████████ | 1350/1500 [51:13<05:24,  2.16s/it]

['ypred']



Running chain 1: 100%|██████████| 1500/1500 [54:41<00:00,  2.19s/it][A

Running chain 0: 100%|██████████| 1500/1500 [54:53<00:00,  2.20s/it][A


['mu', 'te', 'disp']
['ypred']



Running chain 0:  85%|████████▌ | 1275/1500 [50:02<08:04,  2.15s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [53:12<02:29,  1.99s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [54:03<02:26,  1.96s/it][A
Running chain 1: 100%|██████████| 1500/1500 [54:57<00:00,  2.20s/it][A

Running chain 0: 100%|██████████| 1500/1500 [55:00<00:00,  2.20s/it][A


['mu', 'te', 'disp']
['ypred']



Running chain 1: 100%|██████████| 1500/1500 [55:49<00:00,  2.23s/it][A
Running chain 0: 100%|██████████| 1500/1500 [55:51<00:00,  2.23s/it]


['mu', 'te', 'disp']
['ypred']



Running chain 0:  95%|█████████▌| 1425/1500 [53:45<02:15,  1.81s/it][A
Running chain 1: 100%|██████████| 1500/1500 [55:28<00:00,  2.22s/it][A
Running chain 0: 100%|██████████| 1500/1500 [55:29<00:00,  2.22s/it]


['mu', 'te', 'disp']
['ypred']
