In [1]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import csv

import stan
import arviz as az

import nest_asyncio
nest_asyncio.apply()

In [17]:
# Load the raw insurance records. `with` closes the file handle once the rows
# are read (a bare open() stays open for the life of the kernel), and
# newline='' is the mode the csv module expects for its input files.
with open('insurance.csv', 'r', newline='') as csv_file:
    rows = list(csv.reader(csv_file))

header, records = rows[0], rows[1:]
print(header)
data = np.array(records)

# Encode the categorical text columns as numeric codes. Column indices follow
# the printed header: 1 = sex, 4 = smoker, 5 = region.
# NOTE(review): region is coded 0-3 and later used as a single continuous
# covariate, which imposes an ordering the regions do not have; consider
# one-hot encoding it instead.
category_codes = {
    1: {'male': 0, 'female': 1},
    4: {'no': 0, 'yes': 1},
    5: {'northeast': 0, 'northwest': 1, 'southeast': 2, 'southwest': 3},
}
for col, mapping in category_codes.items():
    for label, code in mapping.items():
        data[data[:, col] == label, col] = code
data = data.astype(np.float64)

# Standardize every column (features and the `charges` target alike) to zero
# mean and unit variance, column by column.
data = (data - data.mean(axis=0)) / data.std(axis=0)

# Features are the first six columns; the target is `charges` (column 6).
# x_data keeps a trailing singleton axis: shape (N, 6, 1).
x_data = data[:, :6].reshape((len(data), 6, 1))
y_data = data[:, 6]
print(x_data.shape)
print(y_data.shape)

['age', 'sex', 'bmi', 'children', 'smoker', 'region', 'charges']
(1338, 6, 1)
(1338,)


In [None]:
N = len(data)
alpha = 2.3
# sigma = 2.
# slope = 4.
# x = np.random.normal (size=N)
# y = alpha + slope * x + sigma * np.random.normal(size=N)

data = {
    'x':x_data,
    'N':N,
    'y':y_data
}

In [None]:
program_code = """
data {
  int<lower=1> N;           // Number of observations
  vector[N] x;              // Covariate
  vector[N] y;              // Outcome
}

parameters {
  real alpha;               // Intercept
  real beta;                // Slope
  real<lower=0> sigma;      // Noise
}

model {
  // Priors
  sigma ~ inv_gamma(1, 1);     //tau0 = 1, tau1 = 1
  alpha ~ normal(0, 10);       //sigma_alhpa = 10
  beta ~ multi_normal(0, 10);  //sigma_beta = 10

  // Likelihood
  for (n in 2:N)
    y[n] ~ normal(alpha + beta * x[n], sigma);
}
"""

In [None]:
model = stan.build(program_code,data)
fit = model.sample(num_chains=3,num_warmup=1000,num_samples=2500)

In [None]:
# Posterior draws as a pandas DataFrame; preview the first five rows.
df = fit.to_frame()
df.head(5)

In [None]:
# ArviZ summary statistics and convergence diagnostics for the fit.
az.summary(data=fit)

In [None]:
# Trace plots for every sampled parameter, with vector components drawn separately.
f = az.plot_trace(data=fit, legend=True, compact=False)
plt.tight_layout()