# Antibiotic resistance prediction

## Project Introduction
This project aims to predict antibiotic resistance using structured electronic health record (EHR) data from the Antibiotic Resistance Microbiology Dataset (ARMD). The goal is to classify whether a bacterial isolate is susceptible (S) or resistant (R) to a given antibiotic, based on clinical, demographic, microbiological, and treatment-related features. This binary classification model supports empirical antibiotic selection and contributes to combating antimicrobial resistance in clinical settings.

In [3]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer

final_armd_ds = 'ARMD_Dataset/selected_features_output.parquet'

df = pd.read_parquet(final_armd_ds)
print(df.shape)



(2184195, 27)


### Separate target + binary encoding

In [5]:
target_col = 'susceptibility_label'
# Encode the target: susceptible (S) -> 0, resistant (R) -> 1
df[target_col] = df[target_col].map({'S': 0, 'R': 1})
y = df[target_col]
X = df.drop(columns=[target_col])
print('y: ',y.shape)
print('X: ',X.shape)

y:  (2184195,)
X:  (2184195, 26)
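`Series.map` with a dict returns NaN for any value that is not a key, so a label other than `S` or `R` would silently become a missing target. Below is a minimal sketch of a check on a toy series. The `I` (intermediate) label is a hypothetical example, not something the output above confirms is present in ARMD.

```python
import pandas as pd

# Toy labels; 'I' is a hypothetical extra label used only for illustration
labels = pd.Series(['S', 'R', 'I', 'S'], name='susceptibility_label')

mapping = {'S': 0, 'R': 1}
encoded = labels.map(mapping)

# Any value missing from the mapping ends up as NaN
unmapped = labels[encoded.isna()].unique().tolist()
print('unmapped labels:', unmapped)

# Keep only rows with a valid binary target
y_toy = encoded[encoded.notna()].astype(int)
print(y_toy.tolist())
```

Running the same `isna()` check on the real `y` right after the mapping would show whether any rows need to be dropped before modeling.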


### Identify column types

In [7]:
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2184195 entries, 0 to 2184194
Data columns (total 26 columns):
 #   Column                          Dtype  
---  ------                          -----  
 0   organism_x                      string 
 1   antibiotic_x                    string 
 2   resistant_time_to_culturetime   float64
 3   age                             string 
 4   gender                          string 
 5   adi_score                       string 
 6   adi_state_rank                  string 
 7   median_wbc                      string 
 8   median_neutrophils              string 
 9   median_lymphocytes              string 
 10  median_hgb                      string 
 11  median_plt                      string 
 12  median_na                       string 
 13  median_hco3                     string 
 14  median_bun                      string 
 15  median_cr                       string 
 16  median_lactate                  string 
 17  median_procalcitonin       

In [8]:
categorical_cols = X.select_dtypes(include=['string']).columns
numeric_cols = X.select_dtypes(include=['number']).columns
print(f'categorical_cols: {categorical_cols}')
print('--------------------------------------------')
print(f'numeric_cols: {numeric_cols}')

categorical_cols: Index(['organism_x', 'antibiotic_x', 'age', 'gender', 'adi_score',
       'adi_state_rank', 'median_wbc', 'median_neutrophils',
       'median_lymphocytes', 'median_hgb', 'median_plt', 'median_na',
       'median_hco3', 'median_bun', 'median_cr', 'median_lactate',
       'median_procalcitonin', 'medication_category'],
      dtype='object')
--------------------------------------------
numeric_cols: Index(['resistant_time_to_culturetime', 'median_heartrate', 'median_resprate',
       'median_temp', 'median_sysbp', 'median_diasbp',
       'medication_time_to_culturetime', 'nursing_home_visit_culture'],
      dtype='object')
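The `X.info()` output lists lab values such as `median_wbc` as `string`, so `select_dtypes` files them under the categorical columns. One possible way to recover numeric dtypes is `pd.to_numeric` with `errors='coerce'`, sketched here on a toy frame. The toy values, and the assumption that `"Null"` is the only non-numeric placeholder, are illustrative.

```python
import pandas as pd

# Toy frame: a lab value stored as text, plus a genuine categorical column
toy = pd.DataFrame({
    'median_wbc': ['6.2', 'Null', '7.5'],
    'gender': ['F', 'M', 'F'],
})

# Convert only the lab columns; unparseable entries become NaN
lab_cols = [c for c in toy.columns if c.startswith('median_')]
for c in lab_cols:
    toy[c] = pd.to_numeric(toy[c], errors='coerce')

print(toy.dtypes)
print(toy['median_wbc'].isna().sum(), 'value(s) could not be parsed')
```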


### Frequency counts

In [10]:
for cat_col in categorical_cols:
    print(X[cat_col].value_counts())

organism_x
PSEUDOMONAS AERUGINOSA                                             2017924
KLEBSIELLA PNEUMONIAE                                               103259
ESCHERICHIA COLI                                                     32539
ENTEROCOCCUS FAECALIS                                                 9373
STAPHYLOCOCCUS AUREUS                                                 3459
STAPH AUREUS {MRSA}                                                   3175
ENTEROCOCCUS SPECIES                                                  2852
COAG NEGATIVE STAPHYLOCOCCUS                                          2212
ZZZPROTEUS VULGARIS                                                   2048
MUCOID PSEUDOMONAS AERUGINOSA                                         1709
PROTEUS MIRABILIS                                                     1144
ENTEROBACTER CLOACAE COMPLEX                                           734
STREPTOCOCCUS AGALACTIAE (GROUP B)                                     500
KLEBSIELLA OXY

### Clean, Preprocess + Apply One-Hot Encoding to categorical columns (with frequency filtering)

#### Replace string placeholders for missing values
- Represents every missing value as `np.nan`, the marker that pandas and scikit-learn both recognize, so later steps handle missing data consistently

In [13]:
import pandas as pd
import numpy as np

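# Replace the "Null" string placeholder with np.nan
X = X.replace("Null", np.nan)
# Normalize any remaining missing markers (e.g. pd.NA) to np.nan;
# note that astype(object) makes every column object dtype
X = X.astype(object).where(pd.notna(X), np.nan)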
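To see what the cast does to column dtypes, here is a small sketch on a toy frame (the column values are made up). It compares the `astype(object)` pattern with replacing the placeholder alone.

```python
import numpy as np
import pandas as pd

# Toy frame with one text column and one column that is already numeric
toy = pd.DataFrame({
    'organism_x': ['ESCHERICHIA COLI', 'Null', 'PROTEUS MIRABILIS'],
    'median_temp': [98.6, np.nan, 99.1],
})

# Casting the whole frame to object discards the float dtype
as_object = toy.astype(object).where(pd.notna(toy), np.nan)
print(as_object.dtypes)

# Replacing the placeholder alone leaves the numeric column numeric
replaced = toy.replace('Null', np.nan)
print(replaced.dtypes)
```

If the object cast is not needed for a later step, the second form keeps numeric columns ready for numeric imputation.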

### Identify Categorical vs. Numerical Columns
Separating the columns by type lets each group get an appropriate preprocessing strategy:
- **Categorical columns:** `organism_x`, `antibiotic_x`, `gender`, `age`, `medication_category`, `adi_state_rank`, `adi_score`
- **Numerical columns:** every column whose name starts with `median_` (such as `median_wbc`, `median_cr`, etc.). `adi_score` could arguably be treated as continuous, but for now it stays in the categorical list.

In [15]:
# Define columns
cat_cols = ['organism_x', 'antibiotic_x', 'gender', 'age', 'medication_category', 'adi_state_rank', 'adi_score']
num_cols = [col for col in X.columns if col.startswith('median_')]


### Impute Missing Values
Fill missing values in **categorical columns** with the most frequent value.
- Keeps every row usable for one-hot encoding, at the cost of shifting each column toward its most common category

Fill missing values in **numerical columns** with the mean.
- Keeps the numerical columns fully numeric without dropping rows; note that mean imputation shrinks each column's variance

In [17]:
# Impute missing values
X[cat_cols] = SimpleImputer(strategy='most_frequent').fit_transform(X[cat_cols])
X[num_cols] = SimpleImputer(strategy='mean').fit_transform(X[num_cols])
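The cell above fits both imputers on the full dataset. If a train/test split follows later, one alternative is to fit the imputers on the training rows only and reuse the fitted statistics on the test rows. The sketch below shows this on a toy frame. The data and the positional split are assumptions for illustration, not the notebook's actual split.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

# Toy data with gaps in one numeric and one categorical column
toy = pd.DataFrame({
    'median_wbc': [6.0, np.nan, 8.0, 10.0, np.nan, 7.0],
    'gender': pd.Series(['F', 'M', np.nan, 'F', np.nan, 'M'], dtype=object),
})

# Hypothetical split: first four rows for training, last two for testing
train, test = toy.iloc[:4].copy(), toy.iloc[4:].copy()

# Fit on training rows only, so test rows do not influence the fill values
num_imputer = SimpleImputer(strategy='mean').fit(train[['median_wbc']])
cat_imputer = SimpleImputer(strategy='most_frequent').fit(train[['gender']])

test['median_wbc'] = num_imputer.transform(test[['median_wbc']]).ravel()
test['gender'] = cat_imputer.transform(test[['gender']]).ravel()
print(test)
```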


### Frequency filtering
Group categories that occur fewer than `MIN_FREQ` times into a common value (`__OTHER__`)
- Reduces the dimensionality and sparsity of the one-hot vectors, which lowers memory usage and can help the model generalize.

In [19]:
# Frequency filtering: keep categories seen at least MIN_FREQ times
MIN_FREQ = 100
for col in cat_cols:
    value_counts = X[col].value_counts()
    to_keep = value_counts[value_counts >= MIN_FREQ].index
    # Everything outside to_keep is replaced with the shared '__OTHER__' label
    X[col] = X[col].where(X[col].isin(to_keep), other='__OTHER__')
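The effect of the filter is easiest to see on a small example. The sketch below wraps the same logic in a helper, `group_rare` (a hypothetical name, not part of the notebook), and applies it to a toy series with a threshold of 2.

```python
import pandas as pd

def group_rare(series, min_freq, other='__OTHER__'):
    """Replace categories seen fewer than min_freq times with `other`."""
    counts = series.value_counts()
    keep = counts[counts >= min_freq].index
    return series.where(series.isin(keep), other)

# Toy series: two frequent organisms and one that appears only once
organisms = pd.Series(
    ['ESCHERICHIA COLI'] * 3
    + ['KLEBSIELLA PNEUMONIAE'] * 2
    + ['PROTEUS MIRABILIS']
)

grouped = group_rare(organisms, min_freq=2)
print(grouped.value_counts())
```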


### One-hot encode categorical columns
Converts each categorical column into indicator columns; `drop_first=True` drops one level per column so the indicators are not perfectly collinear

In [21]:
# Now ready for one-hot encoding
X = pd.get_dummies(X, columns=cat_cols, drop_first=True)
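As a quick illustration of what `drop_first=True` does, the toy sketch below (made-up values) encodes a two-level column. The first level in sorted order, `F`, becomes the implicit baseline, so only `gender_M` is created.

```python
import pandas as pd

# Toy frame with one categorical and one numeric column
toy = pd.DataFrame({
    'gender': ['F', 'M', 'F'],
    'median_wbc': [6.2, 7.5, 5.9],
})

# Non-encoded columns come first, followed by the indicator columns
encoded = pd.get_dummies(toy, columns=['gender'], drop_first=True)
print(encoded.columns.tolist())
print(encoded)
```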

### Checklist to verify everything is correctly processed before modeling

#### Check If Missing Values Are Handled

In [24]:
X.head()

| | resistant_time_to_culturetime | median_wbc | median_neutrophils | median_lymphocytes | median_hgb | median_plt | median_na | median_hco3 | median_bun | median_cr | ... | adi_score_5 | adi_score_50 | adi_score_51 | adi_score_6 | adi_score_68 | adi_score_7 | adi_score_71 | adi_score_8 | adi_score_9 | adi_score___OTHER__ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | | 6.2 | 74.790045 | 1.892448 | 12000.0 | 149.0 | 141.0 | 29.0 | 22.0 | 0.9 | ... | False | False | False | False | False | False | False | False | False | True |
| 1 | | 6.2 | 74.790045 | 1.892448 | 12000.0 | 149.0 | 141.0 | 29.0 | 22.0 | 0.9 | ... | False | False | False | False | False | False | False | False | False | True |
| 2 | | 6.2 | 74.790045 | 1.892448 | 12000.0 | 149.0 | 141.0 | 29.0 | 22.0 | 0.9 | ... | False | False | False | False | False | False | False | False | False | True |
| 3 | | 6.2 | 74.790045 | 1.892448 | 12000.0 | 149.0 | 141.0 | 29.0 | 22.0 | 0.9 | ... | False | False | False | False | False | False | False | False | False | True |
| 4 | | 6.2 | 74.790045 | 1.892448 | 12000.0 | 149.0 | 141.0 | 29.0 | 22.0 | 0.9 | ... | False | False | False | False | False | False | False | False | False | True |


In [25]:
print(X.isnull().sum().sum())  # Should return 0


2249379


Huge red flag: more than 2.2 million missing values remain in the dataset after running all of the preprocessing above.
😭 I have to debug this step by step to find out why missing values remain.

In [26]:
print(X.dtypes)

resistant_time_to_culturetime     object
median_wbc                       float64
median_neutrophils               float64
median_lymphocytes               float64
median_hgb                       float64
                                  ...   
adi_score_7                         bool
adi_score_71                        bool
adi_score_8                         bool
adi_score_9                         bool
adi_score___OTHER__                 bool
Length: 134, dtype: object
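One plausible explanation, judging only from the outputs shown so far: `cat_cols` and `num_cols` together do not cover every column. Numeric columns whose names do not start with `median_` (for example `resistant_time_to_culturetime`, which appears empty in `X.head()` and has dtype `object` here) were never passed to an imputer. The sketch below shows how per-column null counts and an "uncovered columns" check could confirm this. It runs on a toy frame, and the `_toy` names are hypothetical stand-ins for the notebook's variables.

```python
import numpy as np
import pandas as pd

# Toy frame mimicking the situation: one column matches neither rule
toy = pd.DataFrame({
    'resistant_time_to_culturetime': [np.nan, np.nan, 12.0],
    'median_wbc': [6.2, 7.5, 5.9],
    'gender': ['F', 'M', 'F'],
})
cat_cols_toy = ['gender']
num_cols_toy = [c for c in toy.columns if c.startswith('median_')]

# Columns that neither list covers are never imputed
covered = set(cat_cols_toy) | set(num_cols_toy)
uncovered = [c for c in toy.columns if c not in covered]
print('not imputed:', uncovered)

# Per-column null counts show where the remaining NaNs live
null_counts = toy.isnull().sum()
print(null_counts[null_counts > 0])
```

On the real data, the uncovered check would have to run before one-hot encoding, while the original column names still exist. The per-column null counts can be taken on `X` as it is now.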
