In [6]:
import pandas as pd
import numpy as np
import lightgbm as lgb
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, MultiLabelBinarizer
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import f1_score, jaccard_score, hamming_loss

# Load the merged microbiology/admissions table.
df = pd.read_csv("/home/chanbo.s/personalized_ecoli/new_code/merged_microbiology_admissions_final.csv.csv")


# Bucket age into five right-closed intervals: (0, 20], (20, 40], ..., (80, 100].
# Ages outside (0, 100] fall into no bin and become NaN.
bins = [0, 20, 40, 60, 80, 100]
labels = ['0-20', '21-40', '41-60', '61-80', '80+']
df['age_group'] = pd.cut(df['anchor_age'], bins=bins, labels=labels)
# Take an explicit copy: the column assignment below must write to an
# independent frame, not to a slice of the loaded table (this is what triggers
# pandas' SettingWithCopyWarning).
df = df[['gender', 'age_group', 'effective_antibiotics']].copy()

# Turn each cell into a list of antibiotic names by stripping the surrounding
# brackets and single quotes, then splitting on commas.
# NOTE(review): assumes every cell is a string such as "['A', 'B']"; a missing
# (NaN) cell would raise AttributeError here -- confirm against the CSV.
df['effective_antibiotics'] = df['effective_antibiotics'].apply(
    lambda x: [ant.strip() for ant in x.strip("[]").replace("'", "").split(",") if ant.strip()]
)

# One-hot encode the two categorical predictors; drop='first' removes one
# reference level per feature.
encoder = OneHotEncoder(sparse_output=False, drop='first')
encoded_features = encoder.fit_transform(df[['gender', 'age_group']])
# Reuse df's index so the axis=1 concat below pairs rows by position even if
# df no longer carries a default RangeIndex (e.g. after filtering rows).
encoded_df = pd.DataFrame(
    encoded_features,
    columns=encoder.get_feature_names_out(['gender', 'age_group']),
    index=df.index,
)
df = pd.concat([df, encoded_df], axis=1)
df.drop(columns=['gender', 'age_group'], inplace=True)


# Encode the per-row antibiotic lists as a 0/1 indicator matrix
# (one column per distinct antibiotic name).
mlb = MultiLabelBinarizer()
y = np.array(mlb.fit_transform(df['effective_antibiotics']), dtype=np.int32)

# A label column that is all zeros or all ones carries no signal for a binary
# classifier, so such columns are removed from y. mlb.classes_ is trimmed the
# same way so that inverse_transform still maps columns to the right names.
label_counts = y.sum(axis=0)
cols_to_remove = np.where((label_counts == 0) | (label_counts == len(y)))[0]
if len(cols_to_remove) > 0:
    y = np.delete(y, cols_to_remove, axis=1)
    mlb.classes_ = np.delete(mlb.classes_, cols_to_remove)

# Everything except the raw label lists is a model feature.
X = df.drop(columns=['effective_antibiotics'])

# Reproducible shuffled 80/20 hold-out split.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
    shuffle=True,
)

# Fit one independent binary LightGBM classifier per antibiotic label.
lgb_base = lgb.LGBMClassifier(objective="binary", random_state=42)
multi_lgb = MultiOutputClassifier(lgb_base)
multi_lgb.fit(X_train, y_train)

y_pred = multi_lgb.predict(X_test)

# zero_division=0 scores a label/sample with an empty denominator as 0. This
# is the same value the default "warn" mode reports, but without emitting the
# ill-defined-metric warning on every run.
f1 = f1_score(y_test, y_pred, average='macro', zero_division=0)
jaccard = jaccard_score(y_test, y_pred, average='samples', zero_division=0)
hamming = hamming_loss(y_test, y_pred)

print(f"F1 Score: {f1:.2f}")
print(f"Jaccard Similarity: {jaccard:.2f}")
print(f"Hamming Loss: {hamming:.2f}")

def predict_antibiotics(gender, age, threshold=0.4):
    """Recommend antibiotics for one patient described by gender and age.

    Relies on the module-level fitted objects ``encoder``, ``multi_lgb`` and
    ``mlb``, the training feature frame ``X`` and the age-bucket ``labels``.

    Args:
        gender: Gender value, passed to the fitted one-hot encoder.
        age: Age in years; bucketed with the edges 0/20/40/60/80/100.
        threshold: A label is recommended when its positive-class
            probability is greater than or equal to this value.

    Returns:
        The result of ``mlb.inverse_transform`` on a single indicator row:
        a list holding one tuple of recommended antibiotic names.
    """
    bucket = pd.cut([age], bins=[0, 20, 40, 60, 80, 100], labels=labels)[0]
    raw_row = pd.DataFrame({'gender': [gender], 'age_group': [bucket]})

    encoded_values = encoder.transform(raw_row)
    feature_names = encoder.get_feature_names_out(['gender', 'age_group'])
    encoded_row = pd.DataFrame(encoded_values, columns=feature_names)

    # Start from an all-zero row laid out like the training features, then
    # overlay the encoded values on the matching columns.
    model_input = pd.DataFrame(np.zeros((1, X.shape[1])), columns=X.columns)
    model_input.update(encoded_row)

    # predict_proba yields one probability array per label.
    per_label_proba = multi_lgb.predict_proba(model_input)

    # Column 1 is read as the positive-class probability; an array that does
    # not have exactly two columns is treated as "not recommended".
    decisions = [
        (label_proba[0, 1] >= threshold) if label_proba.shape[1] == 2 else False
        for label_proba in per_label_proba
    ]
    indicator_row = np.array(decisions).reshape(1, -1)
    return mlb.inverse_transform(indicator_row)

# Example query: recommendations for a 30-year-old female patient.
gender = 'F'
age = 30
recommended = predict_antibiotics(gender, age)
print(f"Recommended Antibiotics for {gender}, Age {age}: {recommended}")

[LightGBM] [Info] Number of positive: 1732, number of negative: 17667
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001800 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 10
[LightGBM] [Info] Number of data points in the train set: 19399, number of used features: 5
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.089283 -> initscore=-2.322422
[LightGBM] [Info] Start training from score -2.322422
[LightGBM] [Info] Number of positive: 11165, number of negative: 8234
[LightGBM] [Info] Auto-choosing row-wise multi-threading, the overhead of testing was 0.000390 seconds.
You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 10
[LightGBM] [Info] Number of data points in the train set: 19399, number of used features: 5
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.575545 -> initscore=0.304512
[LightGBM] 

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
