In [7]:
# import os
# os.chdir("birthrate_mtgp")
from jax import numpy as jnp
import numpy as np
import numpyro.distributions as dist
import jax.numpy as jnp
import numpyro
from numpyro.handlers import scope

from models.panel_nmf_model import model
from models.utils import missingness_adjustment
from numpyro_to_draws_df_csv import dict_to_tidybayes

import pandas as pd

In [None]:
## Defaults for interactive testing of run_model.
# NOTE(review): rebinding `dist` below shadows the `numpyro.distributions as dist`
# import at the top of the notebook; after this cell `dist` is the string "NB".

# Outcome / likelihood selection
dist = "NB"
outcome_type = "births"
cat_name = "total"

# Model structure
rank = 5
sample_disp = False
missingness = True
disp_param = 1e-4
model_treated = True

# Placebo date and analysis window
placebo_time = "2019-03-01"
start_time = "2016-01-01"
end_time = "2024-01-01"

# Sampler settings. run_model's signature takes `num_chains` from here as its
# default (captured when the def runs); its own num_samples/num_warmup
# defaults are 1000, not these values.
num_chains = 1
num_samples = 100
num_warmup = 100
# Default input/output locations for run_model; override with its
# data_path / results_dir arguments when running elsewhere.
_DEFAULT_DATA_PATH = (
    '/Users/shaokangyang/Library/CloudStorage/GoogleDrive-sky.ang510@gmail.com'
    '/My Drive/Code/dobbs_fertility/data/fertility_data.csv'
)
_DEFAULT_RESULTS_DIR = (
    '/Users/shaokangyang/Library/CloudStorage/GoogleDrive-sky.ang510@gmail.com'
    '/My Drive/Code/fertility_results'
)


def run_model(dist, outcome_type="births", cat_name="total", rank=5, missingness=True,
              disp_param=1e-4, sample_disp=False, placebo_state=None,
              start_time='2016-01-01', end_time='2023-12-31',
              placebo_time=None, dobbs_donor_sensitivity=False,
              model_treated=True, results_file_suffix="",
              num_chains=num_chains, num_warmup=1000, num_samples=1000, thinning=1,
              data_path=_DEFAULT_DATA_PATH, results_dir=_DEFAULT_RESULTS_DIR):
    """Fit the panel NMF model to one outcome/category panel and save the draws.

    Reads the fertility CSV, cleans and filters it (optionally turning it into
    a unit- or time-placebo dataset), runs NUTS on ``model``, draws posterior
    predictions of ``y_obs`` with ``model_treated=False``, and writes two CSVs
    into ``results_dir``: the prepared input frame and the tidy posterior draws.

    Args:
        dist: Likelihood name forwarded to the model as ``outcome_dist``
            (e.g. 'NB'); also the first component of the output filename.
        outcome_type: Outcome passed to ``clean_dataframe``/``prep_data``;
            also part of the posterior-draws filename.
        cat_name: Subgroup name passed to ``clean_dataframe`` and to
            ``prep_data`` as ``group``.
        rank: Latent rank forwarded to the model.
        missingness: Forwarded to the model as ``adjust_for_missingness``
            during fitting.
        disp_param: Forwarded to the model as ``nb_disp``.
        sample_disp: Forwarded to the model as ``sample_disp``.
        placebo_state: If set and not "Texas", the data are replaced by
            ``create_unit_placebo_dataset`` for that state.
        start_time: Rows with ``time`` before this date are dropped, but only
            when ``placebo_time`` is None.
        end_time: Rows with ``time`` after this date are dropped (inclusive bound).
        placebo_time: If set, the data are passed through
            ``create_time_placebo_dataset`` with this as the new treatment start.
        dobbs_donor_sensitivity: Forwarded to ``clean_dataframe``.
        model_treated: Forwarded to the model while fitting (predictions
            always use ``model_treated=False``).
        results_file_suffix: Appended to both output filenames.
        num_chains: Number of MCMC chains. The default is the notebook-level
            ``num_chains`` value at the time this ``def`` is executed.
        num_warmup: NUTS warmup iterations per chain.
        num_samples: Post-warmup iterations per chain (before thinning).
        thinning: Keep every ``thinning``-th post-warmup draw.
        data_path: Input CSV path (defaults to the original hard-coded path).
        results_dir: Output directory (defaults to the original hard-coded path).

    Returns:
        None; results are only written to disk.
    """
    # Ask numpyro for one host device per chain so chains can run in parallel.
    # NOTE(review): per the numpyro docs this only takes effect before JAX has
    # initialised its backend in this process.
    numpyro.set_host_device_count(num_chains)

    import os

    from clean_monthly_birth_data import (
        prep_data, clean_dataframe, create_unit_placebo_dataset, create_time_placebo_dataset,
    )
    from jax import random
    from numpyro.infer import MCMC, NUTS, Predictive

    # ---- Load, clean and filter the panel ----
    df = pd.read_csv(data_path)
    df = clean_dataframe(df, outcome_type, cat_name,
                         dobbs_donor_sensitivity=dobbs_donor_sensitivity, csv_filename=None)
    df = df[df['time'] <= pd.to_datetime(end_time)]
    df = df.sort_values(by=['state', 'year', 'bmcode'])
    df = df.drop_duplicates()

    if placebo_state is not None and placebo_state != "Texas":
        df = create_unit_placebo_dataset(df, placebo_state=placebo_state)

    if placebo_time is not None:
        df = create_time_placebo_dataset(df, new_treatment_start=placebo_time)
    else:
        # Only restrict to start_time onwards when not running a time placebo.
        df = df[df['time'] >= pd.to_datetime(start_time)]

    data_dict_cat = prep_data(df, outcome_type=outcome_type, group=cat_name)

    # ---- Fit the model with NUTS (fixed seed for reproducibility) ----
    rng_key = random.PRNGKey(8675309)
    rng_key, rng_key_ = random.split(rng_key)

    mcmc = MCMC(
        NUTS(model),
        num_warmup=num_warmup,
        num_samples=num_samples,
        num_chains=num_chains,
        progress_bar=True,
        thinning=thinning,
    )
    mcmc.run(
        rng_key_,
        y=data_dict_cat['Y'],
        denominators=data_dict_cat['denominators'],
        control_idx_array=data_dict_cat['control_idx_array'],
        missing_idx_array=data_dict_cat['missing_idx_array'],
        rank=rank,
        outcome_dist=dist,
        adjust_for_missingness=missingness,
        nb_disp=disp_param,
        sample_disp=sample_disp,
        model_treated=model_treated,
    )

    # ---- Posterior predictions of y_obs with model_treated=False ----
    samples = mcmc.get_samples(group_by_chain=True)
    predictive = Predictive(model, mcmc.get_samples(group_by_chain=False))
    rng_key, rng_key_ = random.split(rng_key)

    # y is not passed, so y_obs is sampled rather than observed.
    # NOTE(review): adjust_for_missingness is not forwarded here, so the model's
    # own default applies for these draws -- confirm that is intended.
    predictions = predictive(
        rng_key_,
        denominators=data_dict_cat['denominators'],
        control_idx_array=None,
        missing_idx_array=None,
        rank=rank,
        outcome_dist=dist,
        nb_disp=disp_param,
        sample_disp=sample_disp,
        model_treated=False,
    )['y_obs']

    K, D, N = data_dict_cat['denominators'].shape
    # The flat draws come from get_samples(group_by_chain=False); split them back
    # into (chain, draw, ...) and let reshape infer the per-chain draw count
    # instead of recomputing it with float division of num_samples / thinning.
    pred_mat = predictions.reshape(mcmc.num_chains, -1, K, D, N)

    # ---- Convert to tidy draws-matrix form ----
    params = dict_to_tidybayes({'mu': samples['mu_ctrl'], 'te': samples['te'], 'disp': samples['disp']})
    preds = dict_to_tidybayes({"ypred": pred_mat})

    # Copy the parameter frame's draw identifiers onto the prediction frame
    # (index-aligned assignment) before joining on them.
    # NOTE(review): assumes both frames list draws in the same row order -- confirm.
    preds[".chain"] = params[".chain"]
    preds[".draw"] = params[".draw"]

    results_df = pd.DataFrame(params.merge(preds, on=['.draw', '.chain']))

    # ---- Save outputs ----
    # NOTE(review): this filename depends only on results_file_suffix, so
    # parallel runs over different categories/ranks write the same file and the
    # last writer wins. Left unchanged for downstream readers.
    df.to_csv(os.path.join(results_dir, 'df_{}.csv'.format(results_file_suffix)))
    # Use outcome_type (previously hard-coded "births") so non-birth outcomes
    # get correctly labelled, non-colliding files.
    results_df.to_csv(os.path.join(
        results_dir,
        '{}_{}_{}_{}_{}.csv'.format(dist, outcome_type, cat_name, rank, results_file_suffix),
    ))

    
if __name__ == '__main__':
    from clean_monthly_birth_data import subgroup_definitions
    from joblib import Parallel, delayed

    # ---- Grid of model configurations to sweep over ----
    inputs = [6, 7, 8, 9, 10, 11, 12]  # latent ranks
    outcome_type = "births"
    cats = list(subgroup_definitions[outcome_type].keys())
    dists = ['NB']  # Poisson or NB
    missing_flags = [True]
    disp_params = [1e-4]  # alternative: [1e-4, 1e-3]
    placebo_times = [None]  # alternative: ["2020-05-01"]
    placebo_states = [None]
    sample_disp = False
    dobbs_donor_sensitivity = False

    # One (dist, cat, rank, missingness, disp, placebo_state, placebo_time)
    # tuple per fit; rank is the outer loop and category the inner one.
    args = [
        (d, c, r, m, dp, ps, pt)
        for d in dists
        for r in inputs
        for c in cats
        for m in missing_flags
        for dp in disp_params
        for ps in placebo_states
        for pt in placebo_times
    ]

    # Keyword arguments shared by every fit in the sweep.
    common_kwargs = dict(
        outcome_type=outcome_type,
        sample_disp=sample_disp,
        dobbs_donor_sensitivity=dobbs_donor_sensitivity,
        results_file_suffix="through_june",
        num_chains=4,
        num_samples=2500,
        num_warmup=1000,
        thinning=10,
    )

    # Fan the fits out over 8 joblib workers.
    results = Parallel(n_jobs=8)(
        delayed(run_model)(
            dist=d, cat_name=c, rank=r, missingness=m, disp_param=dp,
            placebo_state=ps, placebo_time=pt, **common_kwargs
        )
        for d, c, r, m, dp, ps, pt in args
    )

In [16]:
from run_model import run_model
from joblib import Parallel, delayed

# Re-fit only the subgroups whose saved output may be missing ypred draws.
rerun_cats = ['race', 'edu']

inputs = [6, 7, 8, 9, 10, 11, 12]  # latent ranks
outcome_type = "births"
dists = ['NB']
missing_flags = [True]
disp_params = [1e-4]
placebo_times = [None]
placebo_states = [None]
sample_disp = False
results_file_suffix = "rerun_ypred_fix"  # distinct suffix so earlier outputs are kept

# One (dist, cat, rank, missingness, disp, placebo_state, placebo_time)
# tuple per re-fit; rank is the outer loop and category the inner one.
args = [
    (d, c, r, m, dp, ps, pt)
    for d in dists
    for r in inputs
    for c in rerun_cats
    for m in missing_flags
    for dp in disp_params
    for ps in placebo_states
    for pt in placebo_times
]

# Settings shared by every re-fit.
# NOTE(review): the __main__ sweep in run_model passes
# dobbs_donor_sensitivity=False with 4 chains x 2500 samples; this re-run
# uses True and 2 x 1000 -- confirm the difference is intended.
rerun_kwargs = dict(
    outcome_type=outcome_type,
    sample_disp=sample_disp,
    results_file_suffix=results_file_suffix,
    num_chains=2,
    num_samples=1000,
    num_warmup=500,
    thinning=5,
    dobbs_donor_sensitivity=True,  # or False if not filtering
)

# Run the re-fits in parallel over 8 joblib workers.
results = Parallel(n_jobs=8)(
    delayed(run_model)(
        dist=d, cat_name=c, rank=r, missingness=m, disp_param=dp,
        placebo_state=ps, placebo_time=pt, **rerun_kwargs
    )
    for d, c, r, m, dp, ps, pt in args
)


['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['Alabama']
['California']
['Alabama']
['Alaska']
['Colorado']
['Arizona']
['Alaska']
['Alabama']
['Connecticut']
['Arkansas']
['Arizona']
['Alaska']
['Delaware']
['California']
['Arkansas']
['District of Columbia']
['Arizona']
['Colorado']
['Florida']
['California']
['Arkansas']
['Connecticut']
['Georgia']
['Colorado']
['California']
['Delaware']
['Hawaii']
['District of Columbia']
['Colorado']
['Connecticut']
['Alabama']
['Idaho']
['Florida']
['Alabama']
['Connecticut']
['Delaware']
['Alaska']
['Illinois']['Alabama']

['Georgia']
['Delaware']
['Alaska']
['District of Columbia']
['Arizona']
['Indiana']
['Alaska']
['Hawaii']
['Arizona']
['District of Columbia']
['Arkansas']
['Florida']
['Iowa']
['Arizona']
['Idaho']
['Arkansas']
['Alabama']
['Kansas']
['Florida']
['California']
['Arkansas']
['Georgia']
['Illinois']
['California']
['Georgia']
['California']
['Alaska']
['Kentucky']
['Colorado']
['Colorado']
['Hawaii']
['Indiana']
['Colorado

Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:12<?, ?it/s]
Running chain 0:   0%|          | 0/1500 [00:12<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:12<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:12<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:12<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:13<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:13<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:13<?, ?it/s][A
Running chain 1:   0%|          | 0/1500 [00:13<?, ?it/s][A
Running chain 1:   5%|▌         | 75/1500 [01:28<24:03,  1.01s/it][A
Running chain 0:   5%|▌         | 75/1500 [01:48<29:59,  1.26s/it][A
Running chain 0:   5%|▌         | 75/1500 [01:51<31:25,  1.32s/it][A
Running chain 1:   5%|▌         | 75/1500 [01:55<32:36,  1.37s/it][A
Running chain 0:   5%|▌         | 75/1500 [01:59<34:01,  1.43s/it][A
Running chain 1:   5%|▌         | 75/1500 [02:02<34:32,  1.45s/it][A
Running chain 1:   5%|▌         | 

['mu', 'te', 'disp']


Running chain 0:  95%|█████████▌| 1425/1500 [54:53<02:56,  2.35s/it]

['ypred']


Running chain 0:  90%|█████████ | 1350/1500 [55:18<06:17,  2.52s/it]

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:34<?, ?it/s]
Running chain 0:  80%|████████  | 1200/1500 [56:49<14:51,  2.97s/it]
Running chain 0:  85%|████████▌ | 1275/1500 [56:59<10:32,  2.81s/it][A
Running chain 0:  85%|████████▌ | 1275/1500 [57:43<10:35,  2.83s/it][A
Running chain 0:  90%|█████████ | 1350/1500 [57:47<06:41,  2.68s/it][A
Running chain 1: 100%|██████████| 1500/1500 [57:51<00:00,  2.31s/it][A

Running chain 0: 100%|██████████| 1500/1500 [58:01<00:00,  2.32s/it][A

Running chain 1:  90%|█████████ | 1350/1500 [58:03<06:42,  2.69s/it][A

['mu', 'te', 'disp']



Running chain 1:  95%|█████████▌| 1425/1500 [58:09<03:11,  2.56s/it][A

['ypred']


Running chain 0:  95%|█████████▌| 1425/1500 [58:39<03:12,  2.56s/it]
Running chain 0:   5%|▌         | 75/1500 [02:53<43:55,  1.85s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:29<?, ?it/s]
Running chain 0:  90%|█████████ | 1350/1500 [1:00:43<07:09,  2.86s/it]
Running chain 0:  95%|█████████▌| 1425/1500 [1:01:25<03:25,  2.75s/it][A
Running chain 0:  90%|█████████ | 1350/1500 [1:01:31<07:13,  2.89s/it][A
Running chain 1:  95%|█████████▌| 1425/1500 [1:01:35<03:25,  2.75s/it][A
Running chain 1:  90%|█████████ | 1350/1500 [1:01:38<07:14,  2.90s/it][A
Running chain 1: 100%|██████████| 1500/1500 [1:01:40<00:00,  2.47s/it][A

Running chain 0: 100%|██████████| 1500/1500 [1:01:57<00:00,  2.48s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [1:01:58<03:26,  2.76s/it]

['mu', 'te', 'disp']
['ypred']



Running chain 0:   5%|▌         | 75/1500 [03:07<49:55,  2.10s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)



Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]:04,  2.67s/it][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:24<?, ?it/s],  2.68s/it]
Running chain 0:  90%|█████████ | 1350/1500 [1:04:19<07:26,  2.98s/it]
Running chain 0: 100%|██████████| 1500/1500 [1:04:36<00:00,  2.58s/it][A

Running chain 1: 100%|██████████| 1500/1500 [1:04:40<00:00,  2.59s/it][A

Running chain 1:  95%|█████████▌| 1425/1500 [1:04:46<03:30,  2.80s/it][A

['mu', 'te', 'disp']


Running chain 0:  95%|█████████▌| 1425/1500 [1:04:49<03:31,  2.81s/it]
Running chain 1: 100%|██████████| 1500/1500 [1:04:53<00:00,  2.60s/it][A

Running chain 1:  95%|█████████▌| 1425/1500 [1:04:55<03:31,  2.82s/it][A

['ypred']


Running chain 0: 100%|██████████| 1500/1500 [1:05:02<00:00,  2.60s/it]


['mu', 'te', 'disp']
['ypred']
['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   5%|▌         | 75/1500 [02:41<43:12,  1.82s/it]
Running chain 1:   5%|▌         | 75/1500 [02:47<45:14,  1.90s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)



Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]6,  2.58s/it][A
  0%|          | 0/1500 [00:00<?, ?it/s][A
Running chain 0:   0%|          | 0/1500 [00:21<?, ?it/s]
Running chain 1:   0%|          | 0/1500 [00:21<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:22<?, ?it/s] 2.64s/it]
Running chain 1:   0%|          | 0/1500 [00:22<?, ?it/s][A
Running chain 0: 100%|██████████| 1500/1500 [1:07:22<00:00,  2.70s/it][A
Running chain 0:  95%|█████████▌| 1425/1500 [1:07:49<03:39,  2.92s/it]
Running chain 1: 100%|██████████| 1500/1500 [1:07:52<00:00,  2.72s/it][A

Running chain 1:  95%|█████████▌| 1425/1500 [1:07:54<03:37,  2.89s/it][A

['mu', 'te', 'disp']


Running chain 0: 100%|██████████| 1500/1500 [1:08:10<00:00,  2.73s/it]

Running chain 1: 100%|██████████| 1500/1500 [1:08:15<00:00,  2.73s/it][A


['ypred']



Running chain 1:   5%|▌         | 75/1500 [02:46<45:57,  1.94s/it][A

['mu', 'te', 'disp']


Running chain 0:   5%|▌         | 75/1500 [02:50<47:07,  1.98s/it]

['ypred']



Running chain 0:   5%|▌         | 75/1500 [02:56<48:45,  2.05s/it][A

['Alabama']
['Alaska']
['Arizona']
['Arkansas']
['California']
['Colorado']
['Connecticut']
['Delaware']
['District of Columbia']
['Florida']
['Georgia']
['Hawaii']
['Idaho']
['Illinois']
['Indiana']
['Iowa']
['Kansas']
['Kentucky']
['Louisiana']
['Maine']
['Maryland']
['Massachusetts']
['Michigan']
['Minnesota']
['Mississippi']
['Missouri']
['Montana']
['Nebraska']
['Nevada']
['New Hampshire']
['New Jersey']
['New Mexico']
['New York']
['North Carolina']
['North Dakota']
['Ohio']
['Oklahoma']
['Oregon']
['Pennsylvania']
['Rhode Island']
['South Carolina']
['South Dakota']
['Tennessee']
['Texas']
['Utah']
['Vermont']
['Virginia']
['Washington']
['West Virginia']
['Wisconsin']
['Wyoming']
['Alabama', 'Arkansas', 'Georgia', 'Idaho', 'Kentucky', 'Louisiana', 'Mississippi', 'Missouri', 'Oklahoma', 'South Dakota', 'Tennessee', 'Texas', 'West Virginia', 'Wisconsin']
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s]
  0%|          | 0/1500 [00:00<?, ?it/s][A
Compiling.. :   0%|          | 0/1500 [00:00<?, ?it/s][A

(4, 51)
(4, 51, 48)
(4, 51, 48)
(4, 51)
(4, 51, 48)
(4, 51, 48)


Running chain 0:   0%|          | 0/1500 [00:25<?, ?it/s] 2.51s/it]
Running chain 1:   0%|          | 0/1500 [00:25<?, ?it/s][A
Running chain 1:  10%|█         | 150/1500 [06:30<57:01,  2.53s/it][A
Running chain 0:  15%|█▌        | 225/1500 [10:41<59:55,  2.82s/it][A
Running chain 0: 100%|██████████| 1500/1500 [1:11:02<00:00,  2.84s/it]

Running chain 1: 100%|██████████| 1500/1500 [1:11:09<00:00,  2.85s/it][A


['mu', 'te', 'disp']
['ypred']


Running chain 0:   5%|▌         | 75/1500 [02:39<42:27,  1.79s/it]
Running chain 1:   5%|▌         | 75/1500 [02:39<42:30,  1.79s/it][A
Running chain 0:  10%|█         | 150/1500 [06:21<55:41,  2.48s/it][A

KeyboardInterrupt: 