In [1]:
import arviz as az
import numpy as np
import pymc as pm
from pymc.math import log, dot
import pandas as pd

%load_ext lab_black

# 15. GLM Examples*

## Arrhythmia

A logistic regression example.

Adapted from [Unit 7: arrhythmia.odc](https://raw.githubusercontent.com/areding/6420-pymc/main/original_examples/Codes4Unit7/arrhythmia.odc).

Data can be found [here](https://raw.githubusercontent.com/areding/6420-pymc/main/data/arrhythmia.csv).

Patients who undergo Coronary Artery Bypass Graft (CABG) surgery have roughly a 19-40% chance of developing atrial fibrillation (AF). AF can lead to blood clots, which increase in-hospital mortality, cause strokes, and lengthen hospital stays. Although AF can be prevented with drugs, the treatment is very expensive and sometimes dangerous when not warranted. Identifying risk factors that indicate an increased risk of developing AF in this population could save lives and money by showing which patients need pharmacological intervention. Researchers began collecting data from CABG patients during their hospital stay, including demographics such as age and sex, as well as heart rate, cholesterol, operation time, etc. They then recorded which patients developed AF during their hospital stay. The researchers now want to find the variables that indicate a high risk of AF. In the past, age, hypertension, and body surface area (BSA) have been good indicators, though these alone have not produced a satisfactory solution.

Fibrillation occurs when the heart muscle begins a quivering motion instead of a normal, healthy pumping rhythm. Fibrillation can affect the atrium (atrial fibrillation) or the ventricle (ventricular fibrillation); ventricular fibrillation is immediately life-threatening.

Atrial fibrillation is the quivering, chaotic motion in the upper chambers of the heart, known as the atria. Atrial fibrillation is often due to serious underlying medical conditions, and should be evaluated by a physician. It is not typically a medical emergency.

Ventricular fibrillation occurs in the ventricles (lower chambers) of the heart; it is always a medical emergency. If left untreated, ventricular fibrillation (VF, or V-fib) can lead to death within minutes. When a heart goes into V-fib, effective pumping of the blood stops. V-fib is considered a form of cardiac arrest, and an individual suffering from it will not survive unless cardiopulmonary resuscitation (CPR) and defibrillation are provided immediately.

DATA Arrhythmia
- Y = Fibrillation
- X1 = Age
- X2 = Aortic Cross Clamp Time
- X3 = Cardiopulmonary Bypass Time:
    - Bypass of the heart and lungs as, for example, in open heart surgery. Blood returning to the heart is diverted through a heart-lung machine (a pump-oxygenator) before returning it to the arterial circulation. The machine does the work both of the heart (pump blood) and the lungs (supply oxygen to red blood cells).
- X4 = ICU Time (Intensive Care Unit)
- X5 = Avg Heart Rate
- X6 = Left Ventricle Ejection Fraction
- X7 = Hypertension
- X8 = Gender [1 = Female; 0 = Male]
- X9 = Diabetes
- X10 = Previous MI

In [2]:
data_df = pd.read_csv("../data/arrhythmia.csv")
data_df.info()
X = data_df.iloc[:, 1:].to_numpy()
# add intercept column to X
X_aug = np.concatenate((np.ones((X.shape[0], 1)), X), axis=1)
y = data_df["Fibrillation"].to_numpy()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 81 entries, 0 to 80
Data columns (total 11 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   Fibrillation                   81 non-null     float64
 1   Age                            81 non-null     float64
 2   AorticCrossClampTime           81 non-null     float64
 3   CardiopulmonaryBypassTime      81 non-null     float64
 4   ICUTime                        81 non-null     float64
 5   AvgHeartRate                   81 non-null     float64
 6   LeftVentricleEjectionFraction  81 non-null     float64
 7   Hypertension                   81 non-null     float64
 8   Gender                         81 non-null     float64
 9   Diabetes                       81 non-null     float64
 10  PreviousMI                     81 non-null     float64
dtypes: float64(11)
memory usage: 7.1 KB


In [3]:
with pm.Model() as m:
    X_data = pm.Data("X_data", X_aug, mutable=True)
    y_data = pm.Data("y_data", y, mutable=False)

    betas = pm.Normal("beta", mu=0, tau=0.001, shape=X.shape[1] + 1)

    p = dot(X_data, betas)

    lik = pm.Bernoulli("y", logit_p=p, observed=y_data)

    trace = pm.sample(5000)

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta]


Sampling 4 chains for 1_000 tune and 5_000 draw iterations (4_000 + 20_000 draws total) took 18 seconds.
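
The priors in this notebook are specified through the precision `tau` rather than the standard deviation, following the BUGS convention of the original `.odc` files. As a quick check of how vague these priors are, the sketch below converts a precision to a standard deviation using $\sigma = 1/\sqrt{\tau}$. The helper name `precision_to_sd` is made up for this illustration and is not part of PyMC.

```python
import math


def precision_to_sd(tau: float) -> float:
    """Convert a normal precision tau = 1 / sigma**2 into a standard deviation."""
    return 1.0 / math.sqrt(tau)


# tau=0.001 is the prior precision on the logistic regression coefficients
sd_arrhythmia = precision_to_sd(0.001)
# tau=0.0001 is the prior precision used in the Poisson regression example
sd_ants = precision_to_sd(0.0001)

print(f"tau=0.001  -> sd = {sd_arrhythmia:.2f}")
print(f"tau=0.0001 -> sd = {sd_ants:.2f}")
```

A standard deviation of about 31.6 on the logit scale is very wide, so these priors should have little influence on the posterior relative to the data.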


In [4]:
az.summary(trace, hdi_prob=0.95)

Unnamed: 0,mean,sd,hdi_2.5%,hdi_97.5%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
beta[0],-13.133,4.97,-23.136,-3.628,0.057,0.041,7634.0,9899.0,1.0
beta[1],0.188,0.049,0.097,0.288,0.0,0.0,12150.0,11195.0,1.0
beta[2],0.033,0.026,-0.017,0.083,0.0,0.0,8683.0,11167.0,1.0
beta[3],-0.024,0.016,-0.057,0.006,0.0,0.0,9518.0,11801.0,1.0
beta[4],-0.154,0.096,-0.341,0.033,0.001,0.001,19132.0,14358.0,1.0
beta[5],0.007,0.032,-0.055,0.07,0.0,0.0,12093.0,12380.0,1.0
beta[6],0.025,0.028,-0.029,0.08,0.0,0.0,12060.0,13797.0,1.0
beta[7],-0.655,0.674,-1.997,0.659,0.005,0.004,19381.0,14834.0,1.0
beta[8],-0.332,0.687,-1.641,1.06,0.005,0.005,20494.0,13634.0,1.0
beta[9],1.324,0.692,-0.01,2.687,0.006,0.004,13682.0,13919.0,1.0
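
Because the model is linear on the logit scale, exponentiating a coefficient gives the multiplicative change in the odds of fibrillation for a one-unit increase in that covariate. The sketch below does this for two posterior means copied by hand from the table above, assuming `beta[0]` is the intercept and the remaining coefficients follow the column order of the data frame (so `beta[1]` is Age and `beta[9]` is Diabetes). Note that the exponential of a posterior mean is only an approximation to the posterior mean of the odds ratio; transforming the posterior draws themselves would be more accurate.

```python
import math

# Posterior means copied from the summary table (rounded to 3 decimals).
# Labels assume beta[0] is the intercept and the rest follow the column order.
posterior_means = {
    "Age (beta[1])": 0.188,
    "Diabetes (beta[9])": 1.324,
}

# exp(beta) is the odds ratio for a one-unit increase in the covariate
odds_ratios = {name: math.exp(b) for name, b in posterior_means.items()}

for name, ratio in odds_ratios.items():
    print(f"{name}: odds ratio ~ {ratio:.2f}")
```

Under these assumptions, each additional year of age multiplies the odds of fibrillation by roughly 1.2. The 95% interval for Age excludes zero, while the interval for Diabetes narrowly includes it.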


## Ants

An example of Poisson regression.

Data can be found [here](https://raw.githubusercontent.com/areding/6420-pymc/main/data/ants.csv).

Adapted from [Unit 7: ants.odc](https://raw.githubusercontent.com/areding/6420-pymc/main/original_examples/Codes4Unit7/ants.odc).

The data discussed in Gotelli and Ellison (2002) provide the ant species richness (number of ant species) found in 64-square-meter sampling grids in 22 forests (coded as 1) and 22 bogs (coded as 2) surrounding the forests in Connecticut, Massachusetts, and Vermont. The sites span 3° of latitude in New England. There are 44 observations on four variables (columns in data set):

- Ants: number of species
- Habitat: forests (1) and bogs (2)
- Latitude
- Elevation: in meters above sea level

(a) Using Poisson regression, model the number of ant species (Ants) with covariates Habitat and Elevation.  
(b) For a sampling grid unit located in a forest at an elevation of 100 m, how many species does the model from (a) predict? Report 95% credible sets for the model coefficients and the prediction.

In [6]:
data = pd.read_csv("../data/ants.csv")
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 44 entries, 0 to 43
Data columns (total 3 columns):
 #   Column     Non-Null Count  Dtype
---  ------     --------------  -----
 0   ants       44 non-null     int64
 1   habitat    44 non-null     int64
 2   elevation  44 non-null     int64
dtypes: int64(3)
memory usage: 1.2 KB


In [15]:
with pm.Model() as m:
    ant_species = pm.Data("ant_species", data["ants"].to_numpy(), mutable=False)
    habitat = pm.Data("habitat", data["habitat"].to_numpy(), mutable=True)
    elevation = pm.Data("elevation", data["elevation"].to_numpy(), mutable=True)

    beta0 = pm.Normal("beta0_intercept", mu=0, tau=0.0001)
    beta1 = pm.Normal("beta1_habitat", mu=0, tau=0.0001)
    beta2 = pm.Normal("beta2_elevation", mu=0, tau=0.0001)

    μ = pm.math.exp(beta0 + beta1 * habitat + beta2 * elevation)

    y = pm.Poisson("y", mu=μ, observed=ant_species)

    trace = pm.sample(5000, tune=2000, init="adapt_diag")

Auto-assigning NUTS sampler...
Initializing NUTS using adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta0_intercept, beta1_habitat, beta2_elevation]


Sampling 4 chains for 2_000 tune and 5_000 draw iterations (8_000 + 20_000 draws total) took 5 seconds.


In [16]:
# note: az.summary reports a 94% HDI by default; pass hdi_prob=0.95 for a 95% interval
az.summary(trace)

Unnamed: 0,mean,sd,hdi_3%,hdi_97%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat
beta0_intercept,3.168,0.186,2.808,3.511,0.002,0.001,7699.0,8786.0,1.0
beta1_habitat,-0.636,0.119,-0.858,-0.411,0.001,0.001,8088.0,8822.0,1.0
beta2_elevation,-0.001,0.0,-0.002,-0.001,0.0,0.0,10210.0,9781.0,1.0
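
With a log link, the coefficients act multiplicatively on the expected species count. The sketch below plugs the posterior means from the table above into $\mu = \exp(\beta_0 + \beta_1 \cdot \text{habitat} + \beta_2 \cdot \text{elevation})$ for a forest site (habitat = 1) at 100 m. This is only a rough plug-in estimate: the values are copied by hand, the elevation coefficient is rounded to three decimals in the table, and parameter uncertainty is ignored, so it will not match a posterior predictive summary exactly.

```python
import math

# Posterior means copied from the summary table (rounded to 3 decimals)
beta0, beta_habitat, beta_elevation = 3.168, -0.636, -0.001

# Rate ratio for a one-unit increase in the habitat code (forest = 1 -> bog = 2)
rate_ratio_bog_vs_forest = math.exp(beta_habitat)

# Plug-in expected species count for a forest (habitat = 1) at 100 m elevation
mu_forest_100m = math.exp(beta0 + beta_habitat * 1 + beta_elevation * 100)

print(f"bog vs. forest rate ratio ~ {rate_ratio_bog_vs_forest:.2f}")
print(f"plug-in mean for forest at 100 m ~ {mu_forest_100m:.1f}")
```

On these rounded values, a bog is expected to hold a little over half as many species as a forest at the same elevation.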


In [None]:
# predict for a forest site (habitat = 1) at an elevation of 100 m
with m:
    pm.set_data({"habitat": [1], "elevation": [100]})
    ppc = pm.sample_posterior_predictive(trace, predictions=True)

Sampling: [y]


In [17]:
ppc.predictions

In [18]:
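# the summary has one row per element of y; every row uses the same
# (habitat, elevation) inputs, so average the rows into a single set of statistics
az.summary(ppc.predictions).mean()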

mean            10.894955
sd               3.411932
hdi_3%           5.000000
hdi_97%         17.000000
mcse_mean        0.025000
mcse_sd          0.017773
ess_bulk     18685.977273
ess_tail     19295.681818
r_hat            1.000000
dtype: float64

In [19]:
%load_ext watermark
%watermark -n -u -v -iv -p pytensor

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
Last updated: Wed Mar 22 2023

Python implementation: CPython
Python version       : 3.11.0
IPython version      : 8.9.0

pytensor: 2.10.1

numpy : 1.24.2
arviz : 0.14.0
pandas: 1.5.3
pymc  : 5.1.2

