In [1]:
import pandas as pd
import ast
from collections import Counter
from sklearn.preprocessing import LabelEncoder
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm

In [2]:
df = pd.read_csv(r'../data/interim/final_data_base.csv')
df.head()

Unnamed: 0,accession,date,organism,geographic_location,isolation_source,genetic_mechanisms,antibiogram
0,SAMN46923997,2025-02,neisseria gonorrhoeae,usa,rectal,none,"{'ciprofloxacin': 'intermediate', 'penicillin'..."
1,SAMN46841726,2025-02,klebsiella pneumoniae,brazil,blood,none,"{'cefotaxime': 'resistant', 'ceftazidime': 're..."
2,SAMN46841725,2025-02,klebsiella pneumoniae,brazil,blood,none,"{'cefotaxime': 'resistant', 'ceftazidime': 're..."
3,SAMN46841724,2025-02,klebsiella pneumoniae,brazil,blood,none,"{'cefotaxime': 'resistant', 'ceftazidime': 're..."
4,SAMN46841723,2025-02,klebsiella pneumoniae,brazil,ascitic fluid,none,"{'cefotaxime': 'resistant', 'ceftazidime': 're..."


In [3]:
# 1. What does the raw column look like?
print(df['antibiogram'].head(5).tolist())
print(type(df['antibiogram'].iloc[0]))

# 2. Try parsing one manually
print(ast.literal_eval(df['antibiogram'].iloc[0]))


["{'ciprofloxacin': 'intermediate', 'penicillin': 'resistant', 'tetracycline': 'resistant', 'cefepime': 'susceptible', 'cefixime': 'susceptible', 'ceftriaxone': 'susceptible', 'cefoxitin': 'susceptible', 'cefotaxime': 'susceptible'}", "{'cefotaxime': 'resistant', 'ceftazidime': 'resistant', 'cefepime': 'resistant', 'ertapenem': 'resistant', 'imipenem': 'resistant', 'meropenem': 'resistant', 'ciprofloxacin': 'resistant', 'gentamicin': 'resistant', 'polymyxin_b': 'resistant'}", "{'cefotaxime': 'resistant', 'ceftazidime': 'resistant', 'cefepime': 'resistant', 'ertapenem': 'resistant', 'imipenem': 'resistant', 'meropenem': 'resistant', 'ciprofloxacin': 'resistant', 'gentamicin': 'resistant', 'polymyxin_b': 'resistant'}", "{'cefotaxime': 'resistant', 'ceftazidime': 'resistant', 'cefepime': 'resistant', 'ertapenem': 'resistant', 'imipenem': 'resistant', 'meropenem': 'resistant', 'ciprofloxacin': 'resistant', 'gentamicin': 'susceptible', 'polymyxin_b': 'resistant'}", "{'cefotaxime': 'resistan

1 . XGBoost approach

In [6]:
tqdm.pandas()  # Enables progress_apply

def preprocess_for_xgboost(df, min_antibiotic_freq=10, min_gene_freq=50):
    def safe_parse_antibiogram(val):
        try:
            parsed = ast.literal_eval(val)
            if isinstance(parsed, dict):
                return parsed
        except Exception:
            pass
        return {}

    def safe_parse_genes(val):
        try:
            parsed = ast.literal_eval(val)
            if isinstance(parsed, list):
                return parsed
        except Exception:
            pass
        return []

    df = df.copy()

    print("Parsing 'antibiogram' column...")
    df['antibiogram'] = df['antibiogram'].progress_apply(safe_parse_antibiogram)

    print("Parsing 'genetic_mechanisms' column...")
    df['genetic_mechanisms'] = df['genetic_mechanisms'].progress_apply(safe_parse_genes)

    # --- Antibiotic Encoding ---
    print("Counting antibiotic frequencies...")
    all_antibiotics = [k for d in tqdm(df['antibiogram'], desc="Collecting antibiotics") for k in d.keys()]
    antibiotic_counts = pd.Series(all_antibiotics).value_counts()
    common_antibiotics = antibiotic_counts[antibiotic_counts >= min_antibiotic_freq].index.tolist()

    status_map = {'resistant': 1, 'intermediate': 0.5, 'susceptible': 0}
    print("Encoding antibiotic resistance levels...")
    ab_data = {
        f'antibiotic_{ab}': df['antibiogram'].progress_apply(lambda d: status_map.get(d.get(ab, 'unknown'), -1))
        for ab in tqdm(common_antibiotics, desc="Encoding antibiotics")
    }
    ab_matrix = pd.DataFrame(ab_data)

    # --- Genetic Marker Encoding ---
    print("Counting genetic marker frequencies...")
    all_genes = [gene for sublist in tqdm(df['genetic_mechanisms'], desc="Collecting genes") for gene in sublist]
    gene_counts = pd.Series(all_genes).value_counts()
    common_genes = gene_counts[gene_counts >= min_gene_freq].index.tolist()

    print("One-hot encoding genetic markers...")
    mlb = MultiLabelBinarizer(classes=common_genes)
    gene_matrix = pd.DataFrame(
        mlb.fit_transform(df['genetic_mechanisms']),
        columns=[f'gene_{g}' for g in mlb.classes_],
        index=df.index
    )

    print("Merging final dataframe...")
    df_processed = pd.concat(
        [df.drop(columns=['antibiogram', 'genetic_mechanisms']), ab_matrix, gene_matrix],
        axis=1
    )

    print("Done.")
    return df_processed, common_antibiotics



def train_models(df, antibiotic_list):
    features = df.drop(columns=[col for col in df.columns if col.endswith('_resistant')])
    models = {}
    reports = {}

    for ab in antibiotic_list:
        target_col = f"{ab}_resistant"
        valid_rows = df[target_col].notna()

        X = features.loc[valid_rows]
        y = df.loc[valid_rows, target_col]

        X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)

        model = XGBClassifier(use_label_encoder=False, eval_metric='logloss')
        model.fit(X_train, y_train)

        y_pred = model.predict(X_test)
        report = classification_report(y_test, y_pred, output_dict=True)

        models[ab] = model
        reports[ab] = report

    return models, reports



In [7]:
df_processed, common_antibiotics = preprocess_for_xgboost(df, min_antibiotic_freq=10)

df_processed.head()

Parsing 'antibiogram' column...


100%|██████████| 35253/35253 [00:02<00:00, 13291.03it/s]


Parsing 'genetic_mechanisms' column...


100%|██████████| 35253/35253 [00:20<00:00, 1707.99it/s]


Counting antibiotic frequencies...


Collecting antibiotics: 100%|██████████| 35253/35253 [00:00<00:00, 453576.32it/s]


Encoding antibiotic resistance levels...


100%|██████████| 35253/35253 [00:00<00:00, 756481.12it/s]t/s]
100%|██████████| 35253/35253 [00:00<00:00, 740991.04it/s]
100%|██████████| 35253/35253 [00:00<00:00, 733838.23it/s], 18.70it/s]
100%|██████████| 35253/35253 [00:00<00:00, 719112.71it/s]
100%|██████████| 35253/35253 [00:00<00:00, 707489.65it/s], 18.41it/s]
100%|██████████| 35253/35253 [00:00<00:00, 703223.10it/s]
100%|██████████| 35253/35253 [00:00<00:00, 813795.72it/s], 17.96it/s]
100%|██████████| 35253/35253 [00:00<00:00, 747962.93it/s]
100%|██████████| 35253/35253 [00:00<00:00, 636022.17it/s], 18.71it/s]
100%|██████████| 35253/35253 [00:00<00:00, 740901.93it/s]
100%|██████████| 35253/35253 [00:00<00:00, 726965.13it/s]6, 18.08it/s]
100%|██████████| 35253/35253 [00:00<00:00, 834971.70it/s]
100%|██████████| 35253/35253 [00:00<00:00, 812378.37it/s]5, 18.61it/s]
100%|██████████| 35253/35253 [00:00<00:00, 799062.92it/s]
100%|██████████| 35253/35253 [00:00<00:00, 698763.26it/s]5, 18.89it/s]
100%|██████████| 35253/35253 [00:00<00:

Counting genetic marker frequencies...


Collecting genes: 100%|██████████| 35253/35253 [00:00<00:00, 599489.95it/s]


One-hot encoding genetic markers...


TypeError: unhashable type: 'dict'

In [None]:
common_antibiotics

In [None]:
print(df['antibiogram'].iloc[0])
print(type(df['antibiogram'].iloc[0]))