# Antibiotic resistance prediction

## Project Introduction
This project aims to predict antibiotic resistance using structured electronic health record (EHR) data from the Antibiotic Resistance Microbiology Dataset (ARMD). The goal is to classify whether a bacterial isolate is susceptible (S) or resistant (R) to a given antibiotic, based on clinical, demographic, microbiological, and treatment-related features. This binary classification model supports empirical antibiotic selection and contributes to combating antimicrobial resistance in clinical settings.

In [3]:
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.impute import SimpleImputer

final_armd_ds = 'ARMD_Dataset/selected_features_output.parquet'

df = pd.read_parquet(final_armd_ds)
print(df.shape)



(2184195, 27)


### Separate target + binary encoding

In [5]:
target_col = 'susceptibility_label'
df[target_col] = df[target_col].map({'S': 0, 'R': 1})  
y = df[target_col]
X = df.drop(columns=[target_col])
print('y: ',y.shape)
print('X: ',X.shape)

y:  (2184195,)
X:  (2184195, 26)


### Identify column types

In [7]:
X.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2184195 entries, 0 to 2184194
Data columns (total 26 columns):
 #   Column                          Dtype  
---  ------                          -----  
 0   organism_x                      string 
 1   antibiotic_x                    string 
 2   resistant_time_to_culturetime   float64
 3   age                             string 
 4   gender                          string 
 5   adi_score                       string 
 6   adi_state_rank                  string 
 7   median_wbc                      string 
 8   median_neutrophils              string 
 9   median_lymphocytes              string 
 10  median_hgb                      string 
 11  median_plt                      string 
 12  median_na                       string 
 13  median_hco3                     string 
 14  median_bun                      string 
 15  median_cr                       string 
 16  median_lactate                  string 
 17  median_procalcitonin       

## Column Categorization
- First, properly define all column categories
- Combine numerical features

In [9]:
true_categorical_cols = ['organism_x', 'antibiotic_x', 'gender', 'medication_category']
numeric_cols = ['resistant_time_to_culturetime', 'median_heartrate', 'median_resprate',
               'median_temp', 'median_sysbp', 'median_diasbp',
               'medication_time_to_culturetime', 'nursing_home_visit_culture']
numerical_med_cols = ['median_wbc', 'median_neutrophils', 'median_lymphocytes',
                     'median_hgb', 'median_plt', 'median_na', 'median_hco3',
                     'median_bun', 'median_cr', 'median_lactate', 'median_procalcitonin']
ordinal_cols = ['age', 'adi_score', 'adi_state_rank']


all_numerical_cols = numeric_cols + numerical_med_cols


all_columns = true_categorical_cols + all_numerical_cols + ordinal_cols


## Apply CatBoost Encoding

#### Data Splitting

In [12]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print("\nFinal Training Set Shape:", X_train.shape)
print("Final Test Set Shape:", X_test.shape)


Final Training Set Shape: (1747356, 26)
Final Test Set Shape: (436839, 26)


#### Ensure proper data types

In [14]:
X_train[true_categorical_cols] = X_train[true_categorical_cols].astype(str)
X_test[true_categorical_cols] = X_test[true_categorical_cols].astype(str)


#### Handle missing 
- Replace 'Null' with actual NaN values
- **For numerical columns**: Simple imputation (fill with mean/median/mode)
- **For categorical columns (if any)**: Fill with the most frequent.

In [16]:
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

def safe_impute(train_df, test_df, num_cols):
    """
    Safely impute missing values in numerical columns, handling persistent NA types
    """
    # 1. Make copies to avoid SettingWithCopyWarning
    train_df = train_df.copy()
    test_df = test_df.copy()
    
    # 2. Force convert all numerical columns to float, coercing errors
    for col in num_cols:
        train_df[col] = pd.to_numeric(train_df[col], errors='coerce')
        test_df[col] = pd.to_numeric(test_df[col], errors='coerce')
    
    # 3. Convert all possible NA representations to np.nan
    na_values = [pd.NA, pd.NaT, np.nan, None, 'NA', 'NaN', 'nan', 'null', 'Null']
    train_df[num_cols] = train_df[num_cols].replace(na_values, np.nan)
    test_df[num_cols] = test_df[num_cols].replace(na_values, np.nan)
    
    # 4. Verify conversion
    print("Missing values after conversion:")
    print(train_df[num_cols].isna().sum())
    
    # 5. Impute using median - convert to numpy arrays first
    num_imputer = SimpleImputer(strategy='median')
    
    # Fit on training data
    train_imputed = num_imputer.fit_transform(train_df[num_cols])
    test_imputed = num_imputer.transform(test_df[num_cols])
    
    # Convert back to DataFrame
    train_df[num_cols] = pd.DataFrame(train_imputed, 
                                    columns=num_cols,
                                    index=train_df.index)
    test_df[num_cols] = pd.DataFrame(test_imputed,
                                   columns=num_cols,
                                   index=test_df.index)
    
    # 6. Final verification
    print("\nMissing values after imputation:")
    print(train_df[num_cols].isna().sum())
    
    return train_df, test_df

# Apply to your data
X_train, X_test = safe_impute(X_train, X_test, all_numerical_cols)

Missing values after conversion:
resistant_time_to_culturetime       23366
median_heartrate                    23822
median_resprate                     27518
median_temp                         24637
median_sysbp                        24087
median_diasbp                       24087
medication_time_to_culturetime      29372
nursing_home_visit_culture        1746678
median_wbc                          14553
median_neutrophils                 149408
median_lymphocytes                 148739
median_hgb                          14553
median_plt                          14553
median_na                            9819
median_hco3                          9832
median_bun                           9862
median_cr                            9805
median_lactate                      88682
median_procalcitonin               149347
dtype: int64

Missing values after imputation:
resistant_time_to_culturetime     0
median_heartrate                  0
median_resprate                   0
median_temp   

####  Ordinal Encoding

In [18]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OrdinalEncoder
from sklearn.impute import SimpleImputer
import numpy as np
import pandas as pd

# 1. First convert ALL ordinal columns to strings
ordinal_cols = ['age', 'adi_score', 'adi_state_rank']
X_train_ordinal = X_train[ordinal_cols].astype(str)
X_test_ordinal = X_test[ordinal_cols].astype(str)

# 2. Define proper ordering (after seeing your value counts)
age_order = [
    '18-24 years', 
    '25-34 years',
    '35-44 years',
    '45-54 years',
    '55-64 years',
    '65-74 years',
    '75-84 years',
    '85-89 years',
    'above 90'
]

# For numeric ordinals, convert to float first, then sort
adi_order = sorted([float(x) for x in X['adi_score'].unique() if x not in ['Null', 'missing']])
rank_order = sorted([float(x) for x in X['adi_state_rank'].unique() if x not in ['Null', 'missing']])

# 3. Create the pipeline with proper string handling
ordinal_pipeline = Pipeline([
    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
    ('encoder', OrdinalEncoder(
        categories=[age_order, adi_order, rank_order],
        handle_unknown='use_encoded_value',
        unknown_value=-1,
        dtype=np.int32
    ))
])

# 4. Fit and transform
ordinal_pipeline.fit(X_train_ordinal)
X_train[ordinal_cols] = ordinal_pipeline.transform(X_train_ordinal)
X_test[ordinal_cols] = ordinal_pipeline.transform(X_test_ordinal)

# 5. Verification
print("Encoded value counts:")
for col in ordinal_cols:
    print(f"\n{col}:")
    print(pd.Series(X_train[col]).value_counts().sort_index())

Encoded value counts:

age:
age
0      92143
1      10277
2       3094
3      69539
4    1551797
5       9668
6       7598
7       1524
8       1716
Name: count, dtype: int64

adi_score:
adi_score
-1    1747356
Name: count, dtype: int64

adi_state_rank:
adi_state_rank
-1    1747356
Name: count, dtype: int64


#### Apply CatBoost Encoding
Initialize CatBoost encoder with optimal settings
- Added noise to prevent overfitting
- Smoothing parameter

Fit and transform - ensuring no data leakage

In [20]:
from category_encoders import CatBoostEncoder
cbe = CatBoostEncoder(
    cols=true_categorical_cols,
    random_state=42,
    sigma=0.1,  # noise
    a=1.0       # Smoothing
)

X_train_encoded = cbe.fit_transform(X_train[true_categorical_cols], y_train)
X_test_encoded = cbe.transform(X_test[true_categorical_cols])

#### Create final feature sets

In [22]:
final_features = true_categorical_cols + all_numerical_cols + ordinal_cols
X_train_final = pd.concat([
    X_train_encoded,
    X_train[all_numerical_cols + ordinal_cols]
], axis=1)[final_features]  # Ensure consistent column order

X_test_final = pd.concat([
    X_test_encoded,
    X_test[all_numerical_cols + ordinal_cols]
], axis=1)[final_features]

#### Final verification

In [24]:
print("\nEncoded values validation:")
for col in true_categorical_cols:
    print(f"\n{col}:")
    print(f"Unique encoded values: {X_train_encoded[col].nunique()}")
    print("Value distribution:")
    print(X_train_encoded[col].describe())

print("\nFinal training set shape:", X_train_final.shape)
print("Final test set shape:", X_test_final.shape)


Encoded values validation:

organism_x:
Unique encoded values: 1747356
Value distribution:
count    1.747356e+06
mean     4.278171e-01
std      7.143352e-02
min      2.213322e-03
25%      4.044311e-01
50%      4.385614e-01
75%      4.702068e-01
max      1.070724e+00
Name: organism_x, dtype: float64

antibiotic_x:
Unique encoded values: 1747356
Value distribution:
count    1.747356e+06
mean     4.279686e-01
std      1.163655e-01
min      6.041039e-03
25%      3.878881e-01
50%      4.214420e-01
75%      4.539043e-01
max      1.249172e+00
Name: antibiotic_x, dtype: float64

gender:
Unique encoded values: 1747356
Value distribution:
count    1.747356e+06
mean     4.278894e-01
std      4.723353e-02
min      1.254375e-01
25%      3.959535e-01
50%      4.259338e-01
75%      4.572970e-01
max      7.193763e-01
Name: gender, dtype: float64

medication_category:
Unique encoded values: 1747356
Value distribution:
count    1.747356e+06
mean     4.278501e-01
std      7.062163e-02
min      2.895311e

## Models implementation 

In [None]:
from sklearn.preprocessing import RobustScaler  # Better for outliers
from sklearn.calibration import CalibratedClassifierCV

# 1. More balanced class weighting
model = make_pipeline(
    RobustScaler(),  # Handles outliers better than StandardScaler
    LogisticRegression(
        max_iter=5000,
        class_weight='balanced',  # Auto-balance (better than manual here)
        random_state=42,
        solver='saga',
        penalty='l2',  # Simpler regularization
        C=0.1  # Stronger regularization
    )
)

# 2. Train with calibration
calibrated_model = CalibratedClassifierCV(model, cv=3)
calibrated_model.fit(X_train_final, y_train)

# 3. Adjust prediction threshold
y_prob = calibrated_model.predict_proba(X_test_final)[:,1]
y_pred_adj = (y_prob > 0.4).astype(int)  # Lower threshold from 0.5 to 0.4

# 4. Enhanced evaluation
print("Balanced Logistic Regression")
print(classification_report(y_test, y_pred_adj))
print(f"ROC AUC: {roc_auc_score(y_test, y_prob):.3f}")
print("\nAdjusted Confusion Matrix:")
print(confusion_matrix(y_test, y_pred_adj))

# 5. Visualize probability distributions
plt.figure(figsize=(10,4))
plt.hist(y_prob[y_test==0], bins=50, alpha=0.5, label='Class 0')
plt.hist(y_prob[y_test==1], bins=50, alpha=0.5, label='Class 1')
plt.legend()
plt.title("Predicted Probability Distributions")
plt.show()

