# Antibiotic resistance prediction

## Project Introduction
This project aims to predict antibiotic resistance using structured electronic health record (EHR) data from the Antibiotic Resistance Microbiology Dataset (ARMD). The goal is to classify whether a bacterial isolate is susceptible (S) or resistant (R) to a given antibiotic, based on clinical, demographic, microbiological, and treatment-related features. This binary classification model supports empirical antibiotic selection and contributes to combating antimicrobial resistance in clinical settings.

In [4]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer

final_armd_ds = 'ARMD_Dataset/selected_features_output.parquet'

df = pd.read_parquet(final_armd_ds)
print(df.shape)


(2184195, 27)


### Separate target + binary encoding

In [7]:
target_col = 'susceptibility_label'
df[target_col] = df[target_col].map({'S': 0, 'R': 1})  
y = df[target_col]
X = df.drop(columns=[target_col])
print('y: ',y.shape)
print('X: ',X.shape)

y:  (2184195,)
X:  (2184195, 26)


### Identify column types

In [9]:
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2184195 entries, 0 to 2184194
Data columns (total 26 columns):
 #   Column                          Dtype  
---  ------                          -----  
 0   organism_x                      string 
 1   antibiotic_x                    string 
 2   resistant_time_to_culturetime   float64
 3   age                             string 
 4   gender                          string 
 5   adi_score                       string 
 6   adi_state_rank                  string 
 7   median_wbc                      string 
 8   median_neutrophils              string 
 9   median_lymphocytes              string 
 10  median_hgb                      string 
 11  median_plt                      string 
 12  median_na                       string 
 13  median_hco3                     string 
 14  median_bun                      string 
 15  median_cr                       string 
 16  median_lactate                  string 
 17  median_procalcitonin       

In [12]:
categorical_cols = X.select_dtypes(include=['string']).columns
numeric_cols = X.select_dtypes(include=['number']).columns
print(f'categorical_cols: {categorical_cols}')
print('--------------------------------------------')
print(f'numeric_cols: {numeric_cols}')

categorical_cols: Index(['organism_x', 'antibiotic_x', 'age', 'gender', 'adi_score',
       'adi_state_rank', 'median_wbc', 'median_neutrophils',
       'median_lymphocytes', 'median_hgb', 'median_plt', 'median_na',
       'median_hco3', 'median_bun', 'median_cr', 'median_lactate',
       'median_procalcitonin', 'medication_category'],
      dtype='object')
--------------------------------------------
numeric_cols: Index(['resistant_time_to_culturetime', 'median_heartrate', 'median_resprate',
       'median_temp', 'median_sysbp', 'median_diasbp',
       'medication_time_to_culturetime', 'nursing_home_visit_culture'],
      dtype='object')


## Frequency counts

In [18]:
for cat_col in categorical_cols:
    print(X[cat_col].value_counts())

organism_x
PSEUDOMONAS AERUGINOSA                                             2017924
KLEBSIELLA PNEUMONIAE                                               103259
ESCHERICHIA COLI                                                     32539
ENTEROCOCCUS FAECALIS                                                 9373
STAPHYLOCOCCUS AUREUS                                                 3459
STAPH AUREUS {MRSA}                                                   3175
ENTEROCOCCUS SPECIES                                                  2852
COAG NEGATIVE STAPHYLOCOCCUS                                          2212
ZZZPROTEUS VULGARIS                                                   2048
MUCOID PSEUDOMONAS AERUGINOSA                                         1709
PROTEUS MIRABILIS                                                     1144
ENTEROBACTER CLOACAE COMPLEX                                           734
STREPTOCOCCUS AGALACTIAE (GROUP B)                                     500
KLEBSIELLA OXY

### Clean, Preprocess + Apply One-Hot Encoding to categorical columns (with frequency filtering)

#### Convert "Null" Strings to np.nan
- To Ensures proper missing value handling; avoids treating "Null" as a real category

In [50]:
import pandas as pd
import numpy as np

# Also handle pd.NA, None
X = X.replace(["Null", pd.NA, None], np.nan)


### Identify Categorical vs. Numerical Columns
- **Categorical columns:** `organism_x`, `antibiotic_x`, `gender`, `age`, `medication_category`, `adi_state_rank`, `adi_score`
- **Numerical columns:** everything starting with `median_` (like `median_wbc`, `median_cr`, etc.), and maybe adi_score I will treat it as continuous (for now).

In [57]:
# 2. Categorical: convert to object and force np.nan presence
cat_cols = ['organism_x', 'antibiotic_x', 'gender', 'age', 'medication_category', 'adi_state_rank', 'adi_score']
X[cat_cols] = X[cat_cols].astype(str).replace("nan", np.nan)

# 3. (Optional) force numeric columns to float for mean imputation
num_cols = [col for col in X.columns if col.startswith('median_')]
X[num_cols] = X[num_cols].astype(float)

### Impute Missing Values
I will ues different strategies for categorical vs. numerical:

In [59]:
from sklearn.impute import SimpleImputer

# Categorical imputation
cat_imputer = SimpleImputer(strategy='most_frequent')
X[cat_cols] = cat_imputer.fit_transform(X[cat_cols])

# Numerical imputation
num_cols = [col for col in X.columns if col.startswith('median_')]
num_imputer = SimpleImputer(strategy='mean')
X[num_cols] = num_imputer.fit_transform(X[num_cols])


### Verify All Columns Are Numeric

In [61]:
non_numeric_cols = X.select_dtypes(exclude=["number"]).columns
print("Non-numeric columns:", non_numeric_cols.tolist())


Non-numeric columns: ['organism_x', 'antibiotic_x', 'age', 'gender', 'adi_score', 'adi_state_rank', 'medication_category']


### Fill missing values (numeric only here, since all is numeric now)

In [14]:
imputer = SimpleImputer(strategy='median')
X_imputed = pd.DataFrame(imputer.fit_transform(X), columns=X.columns)


### Combine imputed features with target

In [16]:
df_numeric = X_imputed.copy()
df_numeric['susceptibility_label'] = y.values


### Save as Parquet

In [18]:
df_numeric.to_parquet("numeric_dataset.parquet", index=False)
print("✅ All features numeric and saved to 'numeric_dataset.parquet'")

✅ All features numeric and saved to 'numeric_dataset.parquet'


## Load the Data

In [31]:
import pandas as pd

df = pd.read_parquet("numeric_dataset.parquet")
X = df.drop(columns=["susceptibility_label"])
y = df["susceptibility_label"]


## Train/Test Split + Scale

In [42]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)


## Try Baseline Models

In [44]:
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, roc_auc_score

models = {
    "LogisticRegression": LogisticRegression(max_iter=2000),
    "RandomForest": RandomForestClassifier(n_estimators=100)
}

for name, model in models.items():
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)[:, 1]
    
    print(f"\n{name}")
    print(classification_report(y_test, y_pred))
    print(f"ROC AUC: {roc_auc_score(y_test, y_prob):.3f}")


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(



LogisticRegression
              precision    recall  f1-score   support

           0       0.59      0.99      0.74    249846
           1       0.85      0.08      0.15    186993

    accuracy                           0.60    436839
   macro avg       0.72      0.54      0.45    436839
weighted avg       0.70      0.60      0.49    436839

ROC AUC: 0.561

RandomForest
              precision    recall  f1-score   support

           0       0.59      0.99      0.74    249846
           1       0.85      0.09      0.16    186993

    accuracy                           0.60    436839
   macro avg       0.72      0.54      0.45    436839
weighted avg       0.70      0.60      0.49    436839

ROC AUC: 0.548
