In [46]:
import numpy as np
from scipy.stats import multivariate_normal, gamma, norm, bernoulli, truncnorm
from tqdm import tqdm
import pandas as pd

# %% set graphs style
import matplotlib.pyplot as plt
import matplotlib.style as style
import seaborn as sns
plt.rcParams['xtick.labelsize'] = 16
plt.rcParams['ytick.labelsize'] = 16
plt.rcParams['lines.linewidth'] = 4
style.use('ggplot')

# %% Import data
df = pd.read_csv('../../data/polls.csv')
df = df.dropna()# drop rows with nan
#drop the first three columns
df = df.drop(columns=['org', 'year', 'survey'])
#get dummies
df = pd.get_dummies(df, prefix='education', prefix_sep='.', 
                            columns=['edu'])
df = pd.get_dummies(df, prefix='age', prefix_sep='.', 
                            columns=['age'])
# %% manipulate data to get design matrices
states = df.groupby('state')
st = list(states.groups.keys())
states_X = []
states_y = []
for state in st:
    df_state = states.get_group(state)
    y = df_state['bush'].to_numpy()
    #remove column name state and y
    df_state = df_state.drop(columns=['bush', 'state'])
    #get intercept and concatenate
    interc = np.ones((df_state.shape[0],1))
    rest = df_state.to_numpy()
    X = np.concatenate([interc, rest], axis=1)   
    #append to list
    states_X.append(X)
    states_y.append(y)
    
#%%

class Hierarchical:
    '''Takes list of design matrices (one per store)
    and list of vectors (one per store)'''
    # initialise the object
    def __init__(self, X, y):
        self.X = X
        self.y = y
        self.S = len(y)
        self.n_par = X[0].shape[1]
        

    def initialise_parameters(self):
        self.gam = np.zeros((self.n_par, self.S))
        self.m = np.zeros(self.n_par)
        self.tau_squared = 1
        #for loop for z
        self.z = []
        for s in range(self.S):
            self.z.append([bernoulli.rvs(p= 0.5,size= 1) for i in range(self.X[s].shape[0])])
        
        
        
    
    def _update_z(self, gam): #super slow loop. maybe map function?or smarter way?
        for s in range(self.S):
            for i in range(self.X[s].shape[0]):
                if self.y[s][i] == 0:
                    self.z[s][i] = truncnorm.rvs(-np.inf,0, 
                                                loc=self.X[s][i] @ gam[:,s], scale=1)
                if self.y[s][i] == 1:
                    self.z[s][i] = truncnorm.rvs(0,np.inf, 
                                                loc =self.X[s][i] @ gam[:,s], scale=1)
         
    def _update_gamma(self, tau_squared):
        for s in range(0, self.S):            
            V = np.linalg.inv(self.X[s].T @ self.X[s] +
                              np.identity(self.n_par) / tau_squared)            
            m = V @ ((self.X[s].T @ np.asarray(self.z[s])).reshape(self.X[s].shape[1],) + self.m/(tau_squared))            
            self.gam[:, s] = multivariate_normal.rvs(mean=m, cov=V)    

    
    def _update_m(self, tau_squared, g):
        var = np.identity(self.n_par) * (tau_squared / self.S)
        #var = (np.eye(12)@g).sum(axis=1) / tau_squared
        mean = var @ (np.sum(np.array(g), axis=1)) / (tau_squared)
        self.m = multivariate_normal.rvs(mean, var+np.diag(np.ones(var.shape[0])*1e-9))

        
    
    def _update_tau_squared(self, gam, m):
        alpha = (self.S + 1)/2
        g_sum = 1
        for s in range(0, self.S):
            g_sum = g_sum + (gam[:, s] - m).T @ (gam[:, s] - m)
        beta = 0.5 * (g_sum)
        self.tau_squared = 1 / gamma.rvs(a=alpha, scale=1 / beta, size=1)
        
    def _update_traces(self, it):
        self.traces['gammas'][it, :, :] = self.gam
        self.traces['m'][it] = self.m
        self.traces['tau_squared'][it] = self.tau_squared
        #self.traces['z'] = self.z not necessary to store

    def fit_GibbsSampler(self, n_iter, burn):
        # initialize parameters:
        self.initialise_parameters()
        # housekeeping setup
        self.n_iter = n_iter
        self.burn = burn
        self.traces = {'tau_squared': np.zeros(self.n_iter),
                       'm': np.zeros((self.n_iter, self.n_par)),
                       'gammas': np.zeros((self.n_iter, self.n_par, self.S)),
                       }
        # do gibbs steps:
        for it in tqdm(range(self.n_iter)):
            self._update_m(self.tau_squared, self.gam)
            self._update_tau_squared(self.gam, self.m)
            self._update_gamma(self.tau_squared)
            self._update_z(self.gam)
            self._update_traces(it)
        # remove burnin:
        for trace in self.traces.keys():
            self.traces[trace] = self.traces[trace][self.burn:]
        
        '''Define plot functions'''    
        def plot_gamma_histograms_1(self):        
            for b in range(self.n_par): 
                traces = self.traces['gammas']
                for group in range(self.S):
                    plt.hist(traces[:, b, group], density=True, alpha=.5, bins=50)
                plt.title(f'gamma {b} posteriors, all stores')
                plt.figure()
        
            
        def plot_m_histogram(self):
            traces = self.traces['m']
            for b in range(self.n_par):
                plt.hist(traces[:, b], density= True, alpha=.5, bins=50)
            plt.title('m posterior')
            plt.figure()
            

        def plot_other_histograms(self, variable):
            trace = self.traces[variable]
            plt.hist(trace, density=True, alpha=.5, bins=50)
            plt.title(f'{variable} posterior')
            plt.figure()

        def plot_all_posteriors(self):
            self.plot_beta_histograms_1()
            self.plot_m_histogram()
            keys = list(self.traces.keys())
            keys.remove('m')
            keys.remove('gammas')
            for trace in keys:
                self.plot_other_histograms(trace)
                
        

#%%
model = Hierarchical(states_X, states_y)

#%%
        
model.fit_GibbsSampler(1500, 500)   

#%%
model.plot_all_posteriors()     
        
        

  out = random_state.multivariate_normal(mean, cov, size)
  1%|█▊                                                                                                                           | 21/1500 [00:03<03:40,  6.69it/s]


LinAlgError: Singular matrix

In [49]:
truncnorm.rvs([-np.inf, -np.inf],[0,0], loc=0, scale=1)

array([-1.2460836, -0.9041836])

In [54]:
lis = np.random.choice([1,0], size=10, replace=True, p=None)

In [55]:
lis

array([1, 1, 0, 1, 0, 0, 1, 1, 1, 1])

In [56]:
lis.apply(lambda x: -np.inf if x == 1 else 0)

AttributeError: 'numpy.ndarray' object has no attribute 'apply'

In [67]:
lower_bound

[0, 0, -inf, 0, -inf, -inf, 0, 0, 0, 0]

In [68]:
upper_bound

[inf, inf, 0, inf, 0, 0, inf, inf, inf, inf]

In [69]:
lis

array([1, 1, 0, 1, 0, 0, 1, 1, 1, 1])