In [1]:
%%capture
!pip install pandas scipy==1.5.2

In [2]:
import pandas as pd
from scipy import stats

In [3]:
# Load the medical dataset; the path is relative to the notebook's working directory.
data = pd.read_csv("data/medical_data.csv")
# Render the loaded frame as this cell's output.
display(data)

Unnamed: 0,death,amiodarone,loop_diuretics,ivabradine,ARB,digoxin,MRA,heart_failure,AOS,SBP,DBP,PLT,LDL,HDL,LVEF,Na+,K+,MPV
0,0,0,0,1,1,1,0,2,0,62.0,126.0,196.0,50.0,42.0,28.0,136.0,4.30,9.9
1,0,1,1,0,1,0,1,0,1,72.0,108.0,245.0,59.0,85.0,25.0,147.0,4.58,13.3
2,0,0,0,0,0,1,1,2,2,73.0,109.0,219.0,79.0,61.0,14.0,133.0,4.05,1.5
3,0,0,0,1,1,1,0,2,1,55.0,114.0,294.0,97.0,55.0,8.0,150.0,5.34,7.8
4,0,0,1,0,1,1,1,2,2,70.0,95.0,293.0,96.0,30.0,28.0,151.0,5.25,6.8
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
575,1,1,1,0,0,1,0,2,0,82.0,129.0,186.0,77.0,27.0,61.0,131.0,4.23,0.8
576,1,0,1,0,1,0,1,1,2,86.0,122.0,281.0,111.0,25.0,83.0,142.0,5.04,1.7
577,1,1,1,1,0,0,1,0,2,80.0,121.0,210.0,69.0,85.0,61.0,132.0,3.64,4.5
578,1,1,1,1,0,1,0,1,1,73.0,114.0,231.0,44.0,102.0,77.0,154.0,4.16,8.1


In [4]:
# Subset used for the reference check: the outcome column plus a few
# float-valued and 0/1-valued predictors.
test_data = data.loc[
    :,
    ["death", "Na+", "DBP", "PLT", "ivabradine", "MRA"],
]
display(test_data.head())

Unnamed: 0,death,Na+,DBP,PLT,ivabradine,MRA
0,0,136.0,126.0,196.0,1,0
1,0,147.0,108.0,245.0,0,1
2,0,133.0,109.0,219.0,0,1
3,0,150.0,114.0,294.0,1,0
4,0,151.0,95.0,293.0,0,1


In [70]:
# Report how many distinct values each int64 column holds; other dtypes are skipped.
for col in data.columns:
    if data[col].dtype != 'int64':
        continue
    print(f"Col {col} has {data[col].nunique()} unique values")

Col death has 2 unique values
Col amiodarone has 2 unique values
Col loop_diuretics has 2 unique values
Col ivabradine has 2 unique values
Col ARB has 2 unique values
Col digoxin has 2 unique values
Col MRA has 2 unique values
Col heart_failure has 3 unique values
Col AOS has 3 unique values


In [71]:
# Contingency table: rows are `death` values, columns are `AOS` values.
pd.crosstab(index=data["death"], columns=data["AOS"])

AOS,0,1,2
death,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,165,169,166
1,27,26,27


In [64]:
def chi_squared_test(data: pd.DataFrame, var: str, target: str = "death") -> float:
    """Chi-squared test of independence between ``target`` and ``var``.

    Parameters
    ----------
    data : pd.DataFrame
        Frame containing both the ``target`` and ``var`` columns.
    var : str
        Name of the column to cross-tabulate against ``target``.
    target : str, default "death"
        Name of the outcome column forming the rows of the contingency table.

    Returns
    -------
    float
        The test's p-value rounded to 4 decimal places.
    """
    contingency_table = pd.crosstab(data[target], data[var])
    # chi2_contingency returns (statistic, p-value, dof, expected); only the
    # p-value is needed here.
    p_value = stats.chi2_contingency(contingency_table)[1]

    return round(p_value, 4)


def shapiro_wilk_tests(group_0: pd.Series, group_1: pd.Series) -> tuple:
    """Run a Shapiro-Wilk test on each of the two samples.

    Returns
    -------
    tuple
        ``(p_value_group_0, p_value_group_1)``, each rounded to 4 decimal places.
    """
    # stats.shapiro returns (statistic, p-value); keep only the p-value per sample.
    raw_p_values = [stats.shapiro(sample)[1] for sample in (group_0, group_1)]

    return tuple(round(p, 4) for p in raw_p_values)


def unpaired_ttest_test(group_0: pd.Series, group_1: pd.Series) -> float:
    """Independent two-sample t-test that does not assume equal variances.

    Returns
    -------
    float
        The test's p-value rounded to 4 decimal places.
    """
    # equal_var=False selects Welch's variant of the t-test.
    result = stats.ttest_ind(group_0, group_1, equal_var=False)

    return round(result[1], 4)


def mann_whitney_u_test(group_0: pd.Series, group_1: pd.Series) -> float:
    """Mann-Whitney U test comparing the two samples.

    No ``alternative`` is passed, so SciPy's default for the installed version
    applies.
    NOTE(review): the notebook pins scipy==1.5.2, whose documented default is a
    deprecated mode that returns half the two-sided p-value; newer releases
    default to two-sided. Confirm which one the reference values assume.

    Returns
    -------
    float
        The test's p-value rounded to 4 decimal places.
    """
    # stats.mannwhitneyu returns (statistic, p-value).
    result = stats.mannwhitneyu(group_0, group_1)

    return round(result[1], 4)


def perform_tests(data: pd.DataFrame, alpha: float = 0.05) -> dict:
    """Compare every column of ``data`` between the ``death == 0`` and ``death == 1`` rows.

    Integer-typed columns are treated as categorical and passed to
    ``chi_squared_test``. Every other column gets a Shapiro-Wilk test per group;
    when both (rounded) Shapiro-Wilk p-values exceed ``alpha`` the groups are
    compared with ``unpaired_ttest_test``, otherwise with ``mann_whitney_u_test``.

    Parameters
    ----------
    data : pd.DataFrame
        Must contain a ``death`` column; rows where it equals 0 and 1 form the
        two groups. All other columns are tested.
    alpha : float, default 0.05
        Threshold both Shapiro-Wilk p-values must exceed for the t-test to be
        chosen over the Mann-Whitney U test.

    Returns
    -------
    dict
        Keys ``'mann_whitney'``, ``'ttest'``, ``'chi_square'`` and
        ``'shapiro_wilk'``. Each maps to a list of ``(column, p_value)`` tuples in
        column order; for ``'shapiro_wilk'`` the second element is the
        ``(group_0, group_1)`` pair of p-values.
    """
    mann_whitney_list = []
    ttest_list = []
    chi_square_list = []
    shapiro_wilk_list = []

    death_0 = data.loc[data["death"] == 0]
    death_1 = data.loc[data["death"] == 1]

    for var in data.columns:
        if var == "death":
            continue

        # Any integer dtype (not only int64) counts as categorical, so int32,
        # unsigned and nullable integer columns are not sent to normality tests.
        if pd.api.types.is_integer_dtype(data[var]):
            chi_square_list.append((var, chi_squared_test(data, var)))
            continue

        group_0, group_1 = death_0[var], death_1[var]
        shap_wilk_pval_0, shap_wilk_pval_1 = shapiro_wilk_tests(group_0, group_1)
        shapiro_wilk_list.append((var, (shap_wilk_pval_0, shap_wilk_pval_1)))

        # Use the parametric test only if neither group rejects normality.
        if shap_wilk_pval_0 > alpha and shap_wilk_pval_1 > alpha:
            ttest_list.append((var, unpaired_ttest_test(group_0, group_1)))
        else:
            mann_whitney_list.append((var, mann_whitney_u_test(group_0, group_1)))

    return {
        'mann_whitney': mann_whitney_list,
        'ttest': ttest_list,
        'chi_square': chi_square_list,
        'shapiro_wilk': shapiro_wilk_list
    }

In [67]:
# Run the test battery on every column of `data` (the full frame, not the
# `test_data` subset) and echo the resulting dict as the cell output.
output = perform_tests(data)
output

{'mann_whitney': [('HDL', 0.0),
  ('LVEF', 0.0),
  ('Na+', 0.2143),
  ('K+', 0.4284),
  ('MPV', 0.4331)],
 'ttest': [('SBP', 0.0), ('DBP', 0.0), ('PLT', 0.4739), ('LDL', 0.8392)],
 'chi_square': [('amiodarone', 0.0),
  ('loop_diuretics', 0.0),
  ('ivabradine', 0.0144),
  ('ARB', 0.2338),
  ('digoxin', 0.5185),
  ('MRA', 0.2884),
  ('heart_failure', 0.0),
  ('AOS', 0.974)],
 'shapiro_wilk': [('SBP', (0.2581, 0.5881)),
  ('DBP', (0.5272, 0.3715)),
  ('PLT', (0.2361, 0.6935)),
  ('LDL', (0.9367, 0.95)),
  ('HDL', (0.4388, 0.0006)),
  ('LVEF', (0.0, 0.0)),
  ('Na+', (0.0, 0.0071)),
  ('K+', (0.0, 0.1771)),
  ('MPV', (0.0, 0.0003))]}

In [53]:
# Reference p-values for the `test_data` columns, keyed by test name.
target_output = {
    'mann_whitney': [
        ('Na+', 0.2143),
    ],
    'ttest': [
        ('DBP', 0.0),
        ('PLT', 0.4739),
    ],
    'chi_square': [
        ('ivabradine', 0.0144),
        ('MRA', 0.2884),
    ],
    'shapiro_wilk': [
        ('Na+', (0.0, 0.0071)),
        ('PLT', (0.2361, 0.6935)),
        ('DBP', (0.5272, 0.3715)),
    ],
}

In [66]:
# Recompute on the `test_data` subset that `target_output` describes instead of
# relying on whatever `output` holds from an earlier cell (it may cover all
# columns), and compare each result list order-insensitively so the check does
# not depend on column order.
subset_output = perform_tests(test_data)
{test: sorted(pairs) for test, pairs in subset_output.items()} == {
    test: sorted(pairs) for test, pairs in target_output.items()
}

True