In [3]:
import arviz as az
import numpy as np
import pymc as pm
from pymc.math import dot
import pandas as pd

%load_ext lab_black
%load_ext watermark

# Arrhythmia

A logistic regression example.

Adapted from [unit 7: arrhythmia.odc](https://raw.githubusercontent.com/areding/6420-pymc/main/original_examples/Codes4Unit7/arrhythmia.odc).

Data can be found [here](https://raw.githubusercontent.com/areding/6420-pymc/main/data/arrhythmia.csv).

## Associated lecture video: Unit 7 Lesson 15

In [4]:
%%html
<iframe width="560" height="315" src="https://www.youtube.com/embed?v=xomK4tcePmc&list=PLv0FeK5oXK4l-RdT6DWJj0_upJOG2WKNO&index=77" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen></iframe>

## Problem statement

Patients who undergo Coronary Artery Bypass Graft (CABG) surgery have an approximately 19-40% chance of developing atrial fibrillation (AF). AF can cause blood clots to form, leading to greater in-hospital mortality, strokes, and longer hospital stays. While AF can be prevented with drugs, preventive treatment is very expensive and sometimes dangerous if not warranted. Ideally, identifying risk factors that indicate an increased risk of developing AF in this population could save lives and money by showing which patients need pharmacological intervention. Researchers began collecting data from CABG patients during their hospital stay, including demographics such as age and sex as well as heart rate, cholesterol, operation time, and so on. The researchers then recorded which patients developed AF during their hospital stay. They now want to find the variables that indicate a high risk of AF. In the past, age, hypertension, and body surface area (BSA) have been good indicators, though these alone have not produced a satisfactory solution.

Fibrillation occurs when the heart muscle begins a quivering motion instead of a normal, healthy pumping rhythm. Fibrillation can affect the atrium (atrial fibrillation) or the ventricle (ventricular fibrillation); ventricular fibrillation is immediately life-threatening.

Atrial fibrillation is the quivering, chaotic motion in the upper chambers of the heart, known as the atria. Atrial fibrillation is often due to serious underlying medical conditions, and should be evaluated by a physician. It is not typically a medical emergency.

Ventricular fibrillation occurs in the ventricles (lower chambers) of the heart; it is always a medical emergency. If left untreated, ventricular fibrillation (VF, or V-fib) can lead to death within minutes. When a heart goes into V-fib, effective pumping of the blood stops. V-fib is considered a form of cardiac arrest, and an individual suffering from it will not survive unless cardiopulmonary resuscitation (CPR) and defibrillation are provided immediately.

DATA Arrhythmia
- Y = Fibrillation
- X1 = Age
- X2 = Aortic Cross Clamp Time
- X3 = Cardiopulmonary Bypass Time:
    - Bypass of the heart and lungs as, for example, in open heart surgery. Blood returning to the heart is diverted through a heart-lung machine (a pump-oxygenator) before returning it to the arterial circulation. The machine does the work both of the heart (pump blood) and the lungs (supply oxygen to red blood cells).
- X4 = ICU Time (Intensive Care Unit)
- X5 = Avg Heart Rate
- X6 = Left Ventricle Ejection Fraction
- X7 = Hypertension
- X8 = Gender (1 = Female; 0 = Male)
- X9 = Diabetes
- X10 = Previous MI

In [6]:
data_df = pd.read_csv("../data/arrhythmia.csv")
data_df.info()
X = data_df.iloc[:, 1:].to_numpy()
# add intercept column to X
X_aug = np.concatenate((np.ones((X.shape[0], 1)), X), axis=1)
y = data_df["Fibrillation"].to_numpy()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 81 entries, 0 to 80
Data columns (total 11 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   Fibrillation                   81 non-null     float64
 1   Age                            81 non-null     float64
 2   AorticCrossClampTime           81 non-null     float64
 3   CardiopulmonaryBypassTime      81 non-null     float64
 4   ICUTime                        81 non-null     float64
 5   AvgHeartRate                   81 non-null     float64
 6   LeftVentricleEjectionFraction  81 non-null     float64
 7   Hypertension                   81 non-null     float64
 8   Gender                         81 non-null     float64
 9   Diabetes                       81 non-null     float64
 10  PreviousMI                     81 non-null     float64
dtypes: float64(11)
memory usage: 7.1 KB
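
For reference, the model fit in the next cell can be written out as below. This is a sketch read off the PyMC code rather than a statement from the original example; the symbols $x_{ij}$, $\eta_i$, and $p_i$ are notation introduced here, with $i$ indexing the 81 patients and $j$ indexing the 10 predictor columns.

$$
\begin{aligned}
y_i \mid p_i &\sim \text{Bernoulli}(p_i), \quad i = 1, \dots, 81, \\
\operatorname{logit}(p_i) &= \eta_i = \beta_0 + \sum_{j=1}^{10} \beta_j x_{ij}, \\
\beta_j &\sim \mathcal{N}(0, \tau^{-1}), \quad \tau = 0.001, \quad j = 0, \dots, 10.
\end{aligned}
$$

Because `tau` is a precision, the prior variance is $1/\tau = 1000$, which makes the priors on the coefficients quite diffuse.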


In [7]:
with pm.Model() as m:
    X_data = pm.Data("X_data", X_aug)
    y_data = pm.Data("y_data", y)

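    # diffuse Normal priors; tau is a precision, so tau=0.001 means variance 1000
    betas = pm.Normal("beta", mu=0, tau=0.001, shape=X.shape[1] + 1)

    # linear predictor on the log-odds (logit) scale, not a probability
    logit_p = dot(X_data, betas)

    lik = pm.Bernoulli("y", logit_p=logit_p, observed=y_data)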

    trace = pm.sample(
        10000,
        chains=4,
        tune=500,
        cores=4,
        random_seed=1,
    )

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [beta]


Sampling 4 chains for 500 tune and 10_000 draw iterations (2_000 + 40_000 draws total) took 39 seconds.
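
The `logit_p` argument in the model above tells PyMC that `dot(X_data, betas)` lives on the log-odds scale. The following is a minimal NumPy sketch of that link, plus the prior standard deviation implied by `tau=0.001`. The names `inv_logit`, `X_demo`, and `beta_demo` are hypothetical and the numbers are made up for illustration; they are not estimates from the arrhythmia data.

```python
import numpy as np


def inv_logit(eta):
    """Map log-odds (the linear predictor) to probabilities in (0, 1)."""
    eta = np.asarray(eta, dtype=float)
    return 1.0 / (1.0 + np.exp(-eta))


# Hypothetical design matrix: an intercept column plus one predictor.
X_demo = np.array([[1.0, 0.0], [1.0, 2.0], [1.0, -2.0]])
# Hypothetical coefficients (not estimates from the arrhythmia data).
beta_demo = np.array([0.0, 1.0])

# Plays the same role as dot(X_data, betas) in the model.
eta_demo = X_demo @ beta_demo
p_demo = inv_logit(eta_demo)

# Prior scale implied by the precision used in the model: sd = 1 / sqrt(tau).
prior_sd = 1.0 / np.sqrt(0.001)

print(p_demo)
print(prior_sd)
```

A log-odds of 0 maps to a probability of 0.5, and log-odds of equal size and opposite sign map to probabilities that sum to 1.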


In [8]:
az.summary(trace, hdi_prob=0.95)

| | mean | sd | hdi_2.5% | hdi_97.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| beta[0] | -13.067 | 4.934 | -22.689 | -3.345 | 0.039 | 0.028 | 15806.0 | 21174.0 | 1.0 |
| beta[1] | 0.188 | 0.049 | 0.097 | 0.289 | 0.0 | 0.0 | 19858.0 | 22282.0 | 1.0 |
| beta[2] | 0.033 | 0.026 | -0.017 | 0.084 | 0.0 | 0.0 | 14546.0 | 19261.0 | 1.0 |
| beta[3] | -0.023 | 0.016 | -0.056 | 0.007 | 0.0 | 0.0 | 14811.0 | 19381.0 | 1.0 |
| beta[4] | -0.155 | 0.096 | -0.344 | 0.035 | 0.001 | 0.0 | 34820.0 | 23875.0 | 1.0 |
| beta[5] | 0.006 | 0.032 | -0.055 | 0.07 | 0.0 | 0.0 | 22649.0 | 24365.0 | 1.0 |
| beta[6] | 0.025 | 0.028 | -0.03 | 0.08 | 0.0 | 0.0 | 24866.0 | 26418.0 | 1.0 |
| beta[7] | -0.65 | 0.672 | -1.974 | 0.652 | 0.004 | 0.003 | 33866.0 | 26537.0 | 1.0 |
| beta[8] | -0.327 | 0.684 | -1.7 | 0.998 | 0.004 | 0.003 | 33485.0 | 26523.0 | 1.0 |
| beta[9] | 1.321 | 0.692 | -0.014 | 2.714 | 0.004 | 0.003 | 27431.0 | 25463.0 | 1.0 |
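
Since the coefficients are on the log-odds scale, exponentiating them gives odds ratios, which are often easier to read. The sketch below does this for two posterior means copied from the summary above. Given the column order of `X_aug`, `beta[1]` corresponds to Age and `beta[9]` to Diabetes. The dictionary names are hypothetical, and exponentiating a posterior mean only approximates the posterior mean of the odds ratio; with the draws available, something like `np.exp(trace.posterior["beta"])` would give the full posterior instead.

```python
import numpy as np

# Posterior means copied from the summary table (log-odds scale).
posterior_means = {
    "beta[1] (Age)": 0.188,
    "beta[9] (Diabetes)": 1.321,
}

# exp(beta) is the multiplicative change in the odds of fibrillation for a
# one-unit increase in the predictor, holding the other predictors fixed.
odds_ratios = {name: float(np.exp(b)) for name, b in posterior_means.items()}

for name, ratio in odds_ratios.items():
    print(f"{name}: odds ratio of about {ratio:.2f}")
```

An odds ratio above 1 indicates higher odds of fibrillation as the predictor increases. The 95% HDI for `beta[9]` in the summary includes 0, so the direction of that effect is not settled by these data alone.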


In [9]:
%watermark --iversions -v

Python implementation: CPython
Python version       : 3.10.4
IPython version      : 8.3.0

arviz : 0.12.1
pymc  : 4.0.0b5
pandas: 1.4.2
numpy : 1.22.3

