In [1]:
!pip install pystan



In [3]:
import pystan
import pandas as pd
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import roc_auc_score
from scipy.special import expit

In [5]:
# Stan program (kept as a Python string) for a Bayesian logistic regression.
# It is handed to `pystan.StanModel(model_code=...)` inside `fit()`.
#   data:       N observations, K features, an N x K matrix X and 0/1 targets y.
#   parameters: a scalar intercept `alpha` and a length-K weight vector `beta`.
#   model:      alpha ~ normal(0, 5), beta ~ normal(0, 1.0), and a
#               bernoulli_logit likelihood on `alpha + X * beta`.
# The `generated quantities` block is intentionally empty.
log_ab_normal = """
    /* Stan model code for logistic regression */
    data {                                 
        int<lower=0> N;  // count of observations
        int<lower=0> K;  // count of features
        matrix[N, K] X;  // feature matrix
        int<lower=0,upper=1> y[N];  // target
    }
    parameters {
        real alpha;  // bias
        vector[K] beta;  // feature weights
    }
    model { 
        alpha ~ normal(0,5);  // bias prior
        beta ~ normal(0,1.0);  // featue weights prior
        y ~ bernoulli_logit(alpha + X * beta);  // likelihood
    }
    generated quantities {}
    """


In [11]:
def fit(n_features=5, n_samples=1000, warmup=250, n_iter=1000, seed=None):
    """Fit the Bayesian logistic regression model with MCMC on synthetic data.

    A binary classification dataset is generated with scikit-learn's
    `make_classification`, the Stan program stored in `log_ab_normal` is
    compiled, and the posterior is sampled.

    Args:
        n_features: number of feature columns to generate (default 5).
        n_samples: number of observations to generate (default 1000).
        warmup: number of warmup iterations passed to the sampler (default 250).
        n_iter: value passed as the sampler's `iter` argument (default 1000).
        seed: optional seed forwarded both to `make_classification`
            (`random_state`) and to the Stan sampler (`seed`). With the
            default `None` neither is seeded, matching the original behaviour.

    Returns:
        A tuple `(features, labels, stan_fit)`: the generated feature matrix,
        the generated labels, and the fitted Stan output.
    """
    features, labels = make_classification(
        n_features=n_features, n_samples=n_samples, random_state=seed)

    # Keys must match the names declared in the Stan `data` block.
    stan_datadict = {
        'N': features.shape[0],
        'K': features.shape[1],
        'X': features,
        'y': labels,
    }

    model = pystan.StanModel(model_code=log_ab_normal)

    # Named `stan_fit` so the local does not shadow this function's own name.
    stan_fit = model.sampling(
        data=stan_datadict, warmup=warmup, iter=n_iter, seed=seed, verbose=True)

    return (features, labels, stan_fit)


def evaluate(features, labels, fit):
    """Compute the ROC AUC of the fitted model on the given data.

    Predictions use a point estimate of the parameters: the mean of the
    extracted `alpha` draws and the per-feature mean of the extracted `beta`
    draws. The function scores whatever `features`/`labels` it is given; it
    does not create or hold out a test set itself.

    Args:
        features: feature matrix; it is matrix-multiplied with the averaged
            `beta` vector, so its column count must match `beta`.
        labels: true targets, passed to `roc_auc_score` as `y_true`.
        fit: fitted Stan output providing `extract()` for 'alpha' and 'beta'.

    Returns:
        score: ROC AUC of the predicted probabilities against `labels`.
    """
    # One extract call for both parameters instead of one call per parameter.
    posterior = fit.extract(['alpha', 'beta'])
    bias = posterior['alpha'].mean()
    weights = posterior['beta'].mean(axis=0)  # average over draws, keep one value per feature

    logits = features @ weights + bias
    probabilities = expit(logits)  # logistic sigmoid maps logits to (0, 1)

    return roc_auc_score(labels, probabilities)

In [12]:
# Compile the Stan model and run the sampler. `fit()` returns
# (features, labels, fitted Stan output).
# NOTE(review): the result is not bound to a name, so the whole tuple is echoed
# as the cell output and is not available by name to later cells such as a call
# to `evaluate(...)`.
fit()

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_3dd09af7b0835515b3175e391df50a26 NOW.


(array([[-1.70544481,  1.05184672, -0.81634385,  0.00663493, -1.76764936],
        [ 0.78569833,  0.90930996, -0.93044512,  1.52942914,  1.22270603],
        [ 0.7774027 , -1.58365708, -0.01941882, -1.2169977 ,  0.48227936],
        ...,
        [-0.91313133, -1.75223818, -0.40237757, -2.54208078, -1.62475287],
        [-1.36950389,  1.27720649, -1.18267143,  0.48088931, -1.29273602],
        [ 1.28207779, -0.54549664,  1.62807193,  0.26463011,  1.4006835 ]]),
 array([1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0,
        0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0,
        0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0,
        0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1,
        0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0,
        0, 1, 1, 