# Antibiotic resistance prediction

## Project Introduction
This project aims to predict antibiotic resistance using structured electronic health record (EHR) data from the Antibiotic Resistance Microbiology Dataset (ARMD). The goal is to classify whether a bacterial isolate is susceptible (S) or resistant (R) to a given antibiotic, based on clinical, demographic, microbiological, and treatment-related features. This binary classification model supports empirical antibiotic selection and contributes to combating antimicrobial resistance in clinical settings.

In [3]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer

# Location of the ARMD feature table (selected features plus first/last/delta lab values).
final_armd_ds = 'ARMD_Dataset/selected_features_with_deltas.parquet'

df = pd.read_parquet(path=final_armd_ds)
n_rows, n_cols = df.shape
print((n_rows, n_cols))



(2184195, 38)


### Separate the target and encode it as binary (S = 0, R = 1)

In [5]:
target_col = 'susceptibility_label'
# Binary target: S (susceptible) -> 0, R (resistant) -> 1, so R is the positive class.
label_map = {'S': 0, 'R': 1}

# Series.map turns every value that is not a key of label_map into NaN without
# any warning, and re-running this cell would map the already-encoded 0/1
# values to NaN as well. Skip the mapping when the column is already encoded,
# and fail loudly on anything other than 'S'/'R' instead of carrying NaN
# targets into the split and the models.
if not df[target_col].isin(list(label_map.values())).all():
    raw_labels = df[target_col]
    unexpected = sorted(set(raw_labels.dropna().unique()) - set(label_map), key=str)
    n_missing = int(raw_labels.isna().sum())
    if unexpected or n_missing:
        raise ValueError(
            f"{target_col} must contain only {sorted(label_map)}; found "
            f"unexpected values {unexpected} and {n_missing} missing values"
        )
    df[target_col] = raw_labels.map(label_map)

y = df[target_col]
X = df.drop(columns=[target_col])
print('y: ',y.shape)
print('X: ',X.shape)

y:  (2184195,)
X:  (2184195, 37)


### Inspect missing values

In [7]:
# Number of missing entries per feature column (displayed as the cell result).
X.isnull().sum(axis=0)

organism_x                              0
antibiotic_x                            0
resistant_time_to_culturetime       29367
age                                     0
gender                                  0
adi_score                               0
adi_state_rank                          0
median_wbc                          11684
median_neutrophils                  11684
median_lymphocytes                  11684
median_hgb                          11684
median_plt                          11684
median_na                           11684
median_hco3                         11684
median_bun                          11684
median_cr                           11684
median_lactate                      11684
median_procalcitonin                11684
median_heartrate                    29737
median_resprate                     34359
median_temp                         30766
median_sysbp                        30071
median_diasbp                       30071
medication_category               

In [8]:
# Column dtypes; note that many measurement columns arrive as `string`.
X.info(verbose=None)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2184195 entries, 0 to 2184194
Data columns (total 37 columns):
 #   Column                          Dtype  
---  ------                          -----  
 0   organism_x                      string 
 1   antibiotic_x                    string 
 2   resistant_time_to_culturetime   float64
 3   age                             string 
 4   gender                          string 
 5   adi_score                       string 
 6   adi_state_rank                  string 
 7   median_wbc                      string 
 8   median_neutrophils              string 
 9   median_lymphocytes              string 
 10  median_hgb                      string 
 11  median_plt                      string 
 12  median_na                       string 
 13  median_hco3                     string 
 14  median_bun                      string 
 15  median_cr                       string 
 16  median_lactate                  string 
 17  median_procalcitonin       

In [9]:
# Several clinical measurements (e.g. the median_* lab values, see X.info()
# above) are stored as strings in the parquet file. Label-encoding them would
# replace the real values with arbitrary, lexicographically ordered codes
# (e.g. "10.2" sorts before "9.8") and would turn their missing values into a
# category of their own, hiding them from the median imputer. Convert every
# string column whose non-missing values all parse as numbers to float first,
# so that only genuinely categorical columns are label-encoded below.
for col in X.select_dtypes(include=['string']).columns:
    as_number = pd.to_numeric(X[col], errors='coerce')
    # Only convert when no observed value was lost to coercion.
    if as_number.notna().sum() == X[col].notna().sum():
        X[col] = as_number.astype('float64')

categorical_cols = X.select_dtypes(include=['string']).columns
numeric_cols = X.select_dtypes(include=['number']).columns
print(f'categorical_cols: {categorical_cols}')
print('--------------------------------------------')
print(f'numeric_cols: {numeric_cols}')

categorical_cols: Index(['organism_x', 'antibiotic_x', 'age', 'gender', 'adi_score',
       'adi_state_rank', 'median_wbc', 'median_neutrophils',
       'median_lymphocytes', 'median_hgb', 'median_plt', 'median_na',
       'median_hco3', 'median_bun', 'median_cr', 'median_lactate',
       'median_procalcitonin', 'medication_category'],
      dtype='object')
--------------------------------------------
numeric_cols: Index(['resistant_time_to_culturetime', 'median_heartrate', 'median_resprate',
       'median_temp', 'median_sysbp', 'median_diasbp',
       'medication_time_to_culturetime', 'first_wbc', 'last_wbc', 'first_cr',
       'last_cr', 'first_lactate', 'last_lactate', 'first_procalcitonin',
       'last_procalcitonin', 'delta_wbc', 'delta_cr', 'delta_lactate',
       'delta_procalcitonin'],
      dtype='object')


### Encode categorical columns

In [15]:
def _fit_label_encoder(values):
    """Fit a LabelEncoder on the string form of *values*; return (encoder, codes).

    Converting to str first means missing entries become a string (such as
    '<NA>') and are therefore encoded as a category of their own.
    """
    encoder = LabelEncoder()
    codes = encoder.fit_transform(values.astype(str))
    return encoder, codes


# One fitted encoder per categorical column, kept for inverse_transform later.
label_encoders = {}
for column in categorical_cols:
    label_encoders[column], X[column] = _fit_label_encoder(X[column])


### Verify All Columns Are Numeric

In [18]:
# Anything listed here is neither integer nor float and would break the imputer/models.
non_numeric_cols = X.select_dtypes(exclude="number").columns
print("Non-numeric columns:", list(non_numeric_cols))


Non-numeric columns: []


### Fill missing values (every column is numeric at this point)

In [22]:
# Drop columns that contain no observed value at all: there is no median to
# impute them with (SimpleImputer would drop them itself, and the column names
# passed to the DataFrame below would then no longer line up).
X_non_empty = X.dropna(axis=1, how='all')
dropped_cols = X.columns.difference(X_non_empty.columns)
print(f'Dropped all-missing columns: {dropped_cols.tolist()}')

# Replace the remaining missing values with each column's median.
# NOTE(review): the medians are computed on the full dataset before the
# train/test split, so test rows influence the values imputed into training
# rows. Fitting the imputer on the training split only would avoid this leakage.
imputer = SimpleImputer(strategy='median')
X_imputed_array = imputer.fit_transform(X_non_empty)
# Keep the original row index so the frame stays aligned with y.
X_imputed = pd.DataFrame(
    X_imputed_array, columns=X_non_empty.columns, index=X_non_empty.index
)


In [26]:
# The raw imputer output and its DataFrame wrapper must have the same shape.
for imputed in (X_imputed_array, X_imputed):
    print(imputed.shape)

(2184195, 36)
(2184195, 36)


### Combine imputed features with target

In [29]:
# New frame = imputed features plus the target, attached positionally via .values.
df_numeric = X_imputed.assign(susceptibility_label=y.values)


### Save as Parquet

In [34]:
# Single source of truth for the output path, so the confirmation message can
# never disagree with the file actually written (it previously named
# 'numeric_dataset.parquet').
numeric_dataset_path = "numeric_dataset_labelEncoder.parquet"
df_numeric.to_parquet(numeric_dataset_path, index=False)
print(f"✅ All features numeric and saved to '{numeric_dataset_path}'")

✅ All features numeric and saved to 'numeric_dataset.parquet'


## Load the Data

In [36]:
import pandas as pd

# Reload the fully numeric dataset written above and split off the target.
df = pd.read_parquet("numeric_dataset_labelEncoder.parquet")
target_name = "susceptibility_label"
y = df[target_name]
X = df.drop(columns=[target_name])
for label, frame in (('y: ', y), ('X: ', X)):
    print(label, frame.shape)

y:  (2184195,)
X:  (2184195, 36)


## Feature Correlation with the Target


In [40]:
# Pearson correlation of every column with the target, strongest positive first.
# NOTE(review): for label-encoded categorical columns the integer codes carry no
# real order, so their correlation values are not meaningful.
corr_matrix = df.corr()
correlation = corr_matrix['susceptibility_label'].sort_values(ascending=False)
print(correlation)

susceptibility_label              1.000000
median_heartrate                  0.162111
median_diasbp                     0.106129
median_plt                        0.103091
median_hco3                       0.096921
organism_x                        0.087757
median_sysbp                      0.066204
medication_category               0.062786
median_lactate                    0.062596
median_wbc                        0.059287
median_hgb                        0.008144
delta_wbc                         0.007449
last_wbc                          0.006466
delta_lactate                     0.001549
last_lactate                      0.000324
last_procalcitonin                0.000243
delta_cr                         -0.000105
first_lactate                    -0.001241
last_cr                          -0.006196
first_wbc                        -0.013252
median_neutrophils               -0.015547
resistant_time_to_culturetime    -0.016621
median_lymphocytes               -0.033181
median_na  

## Train/Test Split + Scale

In [45]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 80/20 split, stratified so both parts keep the original S/R ratio.
split_options = dict(test_size=0.2, random_state=42, stratify=y)
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

# Standardize with statistics learned on the training split only.
# The tree-based models below use the unscaled frames; these arrays are for
# scale-sensitive models.
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)


## Model Implementation: Random Forest

In [None]:
import matplotlib.pyplot as plt
from imblearn.under_sampling import RandomUnderSampler
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import make_pipeline

# Balance the classes by randomly discarding majority-class rows. Only the
# training split is resampled; the test split keeps its natural class ratio.
rus = RandomUnderSampler(random_state=42)
X_res, y_res = rus.fit_resample(X_train, y_train)

# The 'randomforestclassifier__' prefixes in the grid address the step that
# make_pipeline names after the estimator class, so the model must be this pipeline.
model = make_pipeline(RandomForestClassifier(random_state=42))

# NOTE(review): after undersampling both classes have the same count, so
# class_weight='balanced' gives uniform weights (no effect), while
# {0: 1, 1: 2} deliberately up-weights the resistant class.
param_grid = {
    'randomforestclassifier__n_estimators': [100, 200],
    'randomforestclassifier__max_depth': [10, 20, None],
    'randomforestclassifier__min_samples_split': [2, 10],
    'randomforestclassifier__min_samples_leaf': [1, 5, 10],
    'randomforestclassifier__class_weight': ['balanced', {0: 1, 1: 2}]
}
search = RandomizedSearchCV(
    model, param_grid, cv=3, scoring='roc_auc', n_iter=10, n_jobs=-1, random_state=42
)
# Tune on the undersampled training data.
search.fit(X_res, y_res)
best_model = search.best_estimator_
print(f"Best CV ROC AUC: {search.best_score_:.3f}")
print("Best parameters:", search.best_params_)

# Importances come from the refitted best pipeline.
importances = best_model.named_steps['randomforestclassifier'].feature_importances_
feature_importance_df = pd.DataFrame({
    'feature': X_train.columns,
    'importance': importances
}).sort_values(by='importance', ascending=False)

# Predicted probability of the resistant class (label 1) for each test row.
y_prob = best_model.predict_proba(X_test)[:, 1]

plt.figure(figsize=(6,6))
plt.scatter(y_prob, y_test, alpha=0.01)
plt.xlabel('Predicted Probability')
plt.ylabel('True Label')
plt.title('Predicted Probability vs. True Label')
plt.show()


## Best Models and Techniques for Overlapping Classes

**Gradient Boosting Trees (e.g., XGBoost, LightGBM, CatBoost)**
- They model non-linear decision boundaries well and are often more robust than Random Forests in regions where the classes overlap.
- LightGBM is very fast on large datasets.
- Use `scale_pos_weight` to handle class imbalance, and tune `max_depth`, `min_child_weight`, and `learning_rate`.

## LightGBM Modeling Pipeline

In [None]:
from lightgbm import LGBMClassifier
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix
import matplotlib.pyplot as plt

# Probability cut-off for predicting "resistant" (label 1). Values below the
# default 0.5 flag more isolates as resistant: higher recall, lower precision.
decision_threshold = 0.4

# 1. Estimate imbalance ratio (negatives per positive) for scale_pos_weight
neg, pos = (y_train == 0).sum(), (y_train == 1).sum()
imbalance_ratio = neg / pos

# 2. Define the LightGBM model
lgbm = LGBMClassifier(
    n_estimators=200,
    learning_rate=0.05,
    max_depth=10,
    min_child_weight=30,  # LightGBM: minimum sum of Hessian in a leaf
    subsample=0.8,
    # LightGBM ignores subsample (bagging_fraction) unless subsample_freq
    # (bagging_freq) is > 0; its default of 0 silently disables row bagging.
    subsample_freq=1,
    colsample_bytree=0.8,
    scale_pos_weight=imbalance_ratio,
    random_state=42,
    n_jobs=-1
)

# 3. Train the model
lgbm.fit(X_train, y_train)

# 4. Predict probabilities and apply the adjusted threshold
y_prob = lgbm.predict_proba(X_test)[:, 1]
y_pred_adj = (y_prob > decision_threshold).astype(int)

# 5. Evaluate performance
print("LightGBM Performance:")
print(classification_report(y_test, y_pred_adj))
print(f"\nROC AUC Score: {roc_auc_score(y_test, y_prob):.3f}")
print("\nConfusion Matrix:")
print(confusion_matrix(y_test, y_pred_adj))

# 6. Plot histogram of predicted probabilities per true class.
# Use a plain numpy array for the masks so indexing is purely positional.
true_labels = y_test.to_numpy()
plt.figure(figsize=(10, 4))
plt.hist(y_prob[true_labels == 0], bins=50, alpha=0.5, label='Class 0')
plt.hist(y_prob[true_labels == 1], bins=50, alpha=0.5, label='Class 1')
plt.legend()
plt.title("LightGBM: Probability Distributions by True Class")
plt.xlabel("Predicted Probability")
plt.ylabel("Count")
plt.show()
