In [None]:
import numpyro
import numpyro.distributions as dist
import jax.numpy as np
import pandas
import pickle

# Number of host (CPU) devices XLA is asked to expose; the MCMC runs further
# down request num_chains=10, so chains can be spread one per device.
_NUM_HOST_DEVICES = 10

# Both settings must be applied before JAX runs any computation.
numpyro.set_platform(platform='cpu')
numpyro.set_host_device_count(_NUM_HOST_DEVICES)

def preprocess(path='german.data'):
    """Load the raw German credit file and one-hot encode its categorical columns.

    Every non-numeric column is treated as categorical. Each entry has its
    column-specific prefix ``"A<column number, 1-based>"`` stripped and the
    remainder parsed as an integer code. The codes are then expanded into one
    indicator column per distinct code. Numeric columns are copied through as a
    single column. The dtype of every raw column is printed as it is processed.

    Args:
        path: whitespace-separated data file with no header row.

    Returns:
        column_info: dict mapping raw column index -> sorted array of the distinct
            integer codes. Only categorical columns appear here.
        column_index: dict mapping raw column index -> (start, end), the half-open
            span of columns that raw column occupies in the encoded matrix.
        data: the encoded matrix (JAX array), one row per line of the file.
    """
    # Raw string: '\s' in a plain literal is an invalid escape sequence.
    german_data_raw = pandas.read_csv(path, header=None, sep=r'\s+')

    encoded_columns = []
    column_info = {}
    column_index = {}

    column_end = 0
    for i in german_data_raw:
        raw_column = german_data_raw[i]
        print(i, ": ", raw_column.dtype)

        column_start = column_end

        # Test for "not numeric" rather than dtype == 'O' so categorical columns are
        # still detected when pandas infers a dedicated string dtype instead of object.
        if not pandas.api.types.is_numeric_dtype(raw_column.dtype):
            prefix = "A" + str(i + 1)
            codes = np.array(raw_column.map(lambda s: int(s.removeprefix(prefix))))
            column_info[i], indices = np.unique(codes, return_inverse=True)
            width = column_info[i].shape[0]
            # Row k of the identity matrix is the one-hot vector for category k.
            encoded_columns.append(np.eye(width)[indices])
        else:
            width = 1
            encoded_columns.append(np.expand_dims(np.array(raw_column), 1))

        column_end = column_start + width
        column_index[i] = (column_start, column_end)
    return column_info, column_index, np.concatenate(encoded_columns, axis=1)




import os

# Load the preprocessed data from the cache under data/ if it is there;
# otherwise rebuild it with preprocess() and write the cache.
# NOTE: pickle.load can execute arbitrary code, so only point this at cache
# files this notebook produced itself.
try:
    print("load column info...")
    with open('data/column_info.pkl', 'rb') as f:
        cinfo = pickle.load(f)
    print("load column index...")
    with open('data/column_index.pkl', 'rb') as f:
        cidx = pickle.load(f)
    print("load german credit data...")
    german_data = np.load("data/german.npy")
except Exception as e:
    # Any failure reading the cache (missing, partial, corrupt) falls back to a full rebuild.
    print(e)
    cinfo, cidx, german_data = preprocess()
    # Create the cache directory first; without it the first run on a fresh
    # checkout fails with FileNotFoundError while writing the cache.
    os.makedirs('data', exist_ok=True)
    np.save("data/german.npy", german_data)
    with open('data/column_info.pkl', 'wb') as f:
        pickle.dump(cinfo, f)

    with open('data/column_index.pkl', 'wb') as f:
        pickle.dump(cidx, f)

import numpy

# Mutable NumPy copy (JAX arrays are immutable) of every encoded column except
# the last one, which is used as the label below.
normalized_features = numpy.array(german_data[:, :-1])

# mask[i] is True for raw columns that were NOT one-hot encoded, i.e. numeric columns.
mask = numpy.ones(len(cidx), dtype=bool)
mask[numpy.fromiter(cinfo.keys(), dtype=int)] = False

# Encoded start offset of each numeric raw column, skipping the final raw column (the label).
# dtype=int keeps this usable as an index even when no numeric column exists.
numeric_idx = numpy.array([cidx[i][0] for i, c in enumerate(mask[:-1]) if c], dtype=int)

# Standardise numeric columns to zero mean / unit variance; one-hot columns are untouched.
numeric_block = normalized_features[:, numeric_idx]
numeric_std = numpy.std(numeric_block, axis=0)
# A constant column would otherwise divide by zero and become NaN; just centre it.
numeric_std = numpy.where(numeric_std > 0, numeric_std, 1.0)
normalized_features[:, numeric_idx] = (numeric_block - numpy.mean(numeric_block, axis=0)) / numeric_std

normalized_features = np.array(normalized_features)
# Maps label values 1 -> 1 and 2 -> 0 for the Bernoulli likelihood
# (assumes the raw label column only contains 1 and 2 — TODO confirm).
credits = 2 - german_data[:, -1]

In [None]:
def german_credit(features, credits=None):
    """Hierarchical logistic-regression model for the German credit data.

    Every coefficient ``beta[j]`` gets its own log prior scale ``log_tau_beta[j]``.
    Those scales are centred on a shared ``log_tau_beta_global``. The model has
    no intercept term.

    Args:
        features: 2-D design matrix of shape (num_obs, num_feature).
        credits: observed 0/1 labels, one per row of ``features``. Pass None to
            sample the labels from the model instead (e.g. for predictive checks).

    Returns:
        The value of the ``credit`` sample site (the observed labels when given).
    """
    num_obs, num_feature = features.shape

    log_tau_beta_global = numpyro.sample('log_tau_beta_global', dist.Normal(0, 10))

    with numpyro.plate('feature', num_feature):
        log_tau_beta = numpyro.sample('log_tau_beta', dist.Normal(log_tau_beta_global, 1))
        beta = numpyro.sample('beta', dist.Normal(0, np.exp(log_tau_beta)))

    logits = np.dot(features, beta)
    with numpyro.plate('observation', num_obs):
        return numpyro.sample('credit', dist.Bernoulli(logits=logits), obs=credits)



from jax import random
from numpyro.infer import MCMC, NUTS 

from numpyro.handlers import reparam
from numpyro.infer.reparam import LocScaleReparam

# Baseline: NUTS on the model in its original (centred) parameterization.
# 'num_steps' is recorded so the sampling cost of the two runs can be compared later.
nuts_kernel = NUTS(german_credit)

mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=5000, num_chains=10)

rng_key = random.PRNGKey(0)

mcmc.run(rng_key, features=normalized_features, credits=credits, extra_fields=('num_steps',))

mcmc.print_summary()

# Same model with every hierarchical site fully non-centred (LocScaleReparam(0)).
reparam_model = reparam(german_credit, config={
    "log_tau_beta_global": LocScaleReparam(0),
    "log_tau_beta": LocScaleReparam(0),
    "beta": LocScaleReparam(0),
})

nuts_kernel2 = NUTS(reparam_model)

mcmc2 = MCMC(nuts_kernel2, num_warmup=1000, num_samples=5000, num_chains=10)

# Same seed as the baseline, so the two runs differ only in parameterization.
rng_key2 = random.PRNGKey(0)

mcmc2.run(rng_key2, features=normalized_features, credits=credits, extra_fields=('num_steps',))

mcmc2.print_summary()

In [None]:
# Total number of NUTS integration steps recorded for each run: a rough
# measure of how expensive sampling was (centred run first, then non-centred).
fields, fields2 = mcmc.get_extra_fields(), mcmc2.get_extra_fields()
total_steps = [np.sum(run_fields['num_steps']) for run_fields in (fields, fields2)]
print(*total_steps)

In [None]:
import matplotlib.pyplot as plt

samples = mcmc.get_samples()
samples2 = mcmc2.get_samples()

# Which coefficient's (log-scale, coefficient) pair to inspect.
FEATURE_IDX = 8

# Joint posterior of one coefficient and its log scale. Centred samples are drawn
# in the original coordinates; non-centred samples in the decentered coordinates
# the sampler actually explored.
fig, ax = plt.subplots()
ax.scatter(
    samples2["log_tau_beta_decentered"][:, FEATURE_IDX],
    samples2["beta_decentered"][:, FEATURE_IDX],
    s=1, c="red", label="non-centred (decentered coordinates)",
)
ax.scatter(
    samples["log_tau_beta"][:, FEATURE_IDX],
    samples["beta"][:, FEATURE_IDX],
    s=1, label="centred",
)
ax.set(
    xlabel=f"log_tau_beta[{FEATURE_IDX}]",
    ylabel=f"beta[{FEATURE_IDX}]",
    title="Posterior geometry: centred vs non-centred parameterization",
)
ax.legend(markerscale=5);


In [None]:
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

optimizer = Adam(step_size=0.005)

# LocScaleReparam() without a `centered` argument makes the degree of centering
# of each listed site a learnable numpyro.param in [0, 1].
learnable_model = reparam(
    german_credit,
    config={site: LocScaleReparam() for site in ("log_tau_beta_global", "log_tau_beta", "beta")},
)

# Fit an AutoNormal guide by maximizing the ELBO; the fitted values of all
# params (the learned centeredness included) end up in `params`.
svi = SVI(learnable_model, AutoNormal(learnable_model), optimizer, loss=Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 20000, features=normalized_features, credits=credits)
params = svi_result.params

In [None]:
from numpyro.handlers import substitute

# Plug the centeredness learned by SVI into the model before running NUTS.
# Running NUTS on `learnable_model` directly would leave the `*_centered` param
# sites created by LocScaleReparam() at their initialisation value, ignoring
# the SVI result in `params`.
tuned_model = substitute(learnable_model, data=params)

nuts_kernel3 = NUTS(tuned_model)

mcmc3 = MCMC(nuts_kernel3, num_warmup=1000, num_samples=5000, num_chains=10)

rng_key3 = random.PRNGKey(0)

mcmc3.run(rng_key3, features=normalized_features, credits=credits, extra_fields=('num_steps',))

mcmc3.print_summary()

In [None]:
samples3 = mcmc3.get_samples()

# Which coefficient's (log-scale, coefficient) pair to inspect.
FEATURE_IDX = 8

# Same coefficient shown in the model's original coordinates and in the
# partially decentered coordinates the sampler explored.
fig, ax = plt.subplots()
ax.scatter(
    samples3["log_tau_beta"][:, FEATURE_IDX],
    samples3["beta"][:, FEATURE_IDX],
    s=1, label="original coordinates",
)
ax.scatter(
    samples3["log_tau_beta_decentered"][:, FEATURE_IDX],
    samples3["beta_decentered"][:, FEATURE_IDX],
    s=1, c="red", label="decentered coordinates (learned centeredness)",
)
ax.set(
    xlabel=f"log_tau_beta[{FEATURE_IDX}]",
    ylabel=f"beta[{FEATURE_IDX}]",
    title="Posterior geometry with learned centeredness",
)
ax.legend(markerscale=5);