In [1]:
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, jaccard_score, roc_auc_score
from sklearn.multioutput import MultiOutputClassifier
from sklearn.ensemble import RandomForestClassifier
import xgboost as xgb

In [2]:
import xgboost as xgb
print(xgb.__version__)

2.1.3


In [3]:
import os

os.chdir(r"C:\Users\LENOVO\Desktop\Intern_data\Project_Result\Matrix")
print("Current Working Directory:", os.getcwd())

Current Working Directory: C:\Users\LENOVO\Desktop\Intern_data\Project_Result\Matrix


In [4]:
gene_matrix = pd.read_csv("abricate_resfinder_gene_matrix.csv", index_col=0)
resistance_matrix = pd.read_csv("abricate_resfinder_resistance_matrix.csv", index_col=0)

In [5]:
gene_jaccard = pd.read_csv("gene_gene_jaccard_matrix.csv", index_col=0)
res_jaccard = pd.read_csv("res_res_jaccard_matrix.csv", index_col=0)
gene_antibiotic_matrix = pd.read_csv("gene_antibiotic_matrix.csv", index_col=0)


In [6]:
metadata = pd.read_csv("strain_metadata.csv", index_col=0)
metadata.head(5)

Unnamed: 0_level_0,species,total_genes_detected,total_resistance_classes,MDR_status,network_degree,network_betweenness,macrolide_gene_count,tetracycline_gene_count,fluoroquinolone_gene_count,aminoglycoside_gene_count
strain,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
Streptococcus_australis_GCA_900476055.1_53750_F01,,2,,0,,,0,2,0,0
Streptococcus_australis_GCA_900636505.1_42650_H02,,2,,0,,,0,2,0,0
Streptococcus_cristatus_GCA_000385925.1_ASM38592v1,,2,,0,,,1,1,0,0
Streptococcus_cristatus_GCA_900475445.1_42727_F01,,1,,0,,,0,1,0,0
Streptococcus_cristatus_GCA_900478185.1_51342_H01,,2,,0,,,1,1,0,0


In [7]:
print("Shapes:")
print("Gene Matrix:", gene_matrix.shape)
print("Resistance Matrix:", resistance_matrix.shape)

Shapes:
Gene Matrix: (56, 25)
Resistance Matrix: (56, 21)


In [8]:
gene_matrix = gene_matrix.loc[:, gene_matrix.sum(axis=0) > 0]

In [9]:
resistance_matrix = resistance_matrix.loc[:, resistance_matrix.sum(axis=0) > 0]


In [10]:
print("\nAfter Cleaning:")
print("Gene Matrix:", gene_matrix.shape)
print("Resistance Matrix:", resistance_matrix.shape)



After Cleaning:
Gene Matrix: (56, 25)
Resistance Matrix: (56, 21)


In [11]:
gene_matrix.index = gene_matrix.index.str.replace("_genomic", "", regex=False)
resistance_matrix.index = resistance_matrix.index.str.replace("_genomic", "", regex=False)


In [12]:
common_strains = gene_matrix.index.intersection(resistance_matrix.index)
print("Common strains found:", len(common_strains))

Common strains found: 56


In [13]:
gene_matrix = gene_matrix.loc[common_strains]
resistance_matrix = resistance_matrix.loc[common_strains]


In [14]:
gene_matrix = gene_matrix.loc[resistance_matrix.index]


In [33]:
def get_pca_features(matrix, name_prefix, n_components=5):
    scaler = StandardScaler()
    scaled = scaler.fit_transform(matrix)

    pca = PCA(n_components=n_components)
    pca_features = pca.fit_transform(scaled)

    cols = [f"{name_prefix}_PC{i+1}" for i in range(n_components)]
    pca_df = pd.DataFrame(pca_features, index=matrix.index, columns=cols)

    return pca_df, scaler, pca



In [34]:
gene_pca, gene_scaler, gene_pca_model = get_pca_features(gene_jaccard, "GENE")
res_pca, res_scaler, res_pca_model = get_pca_features(res_jaccard, "RES")


In [35]:
X = pd.concat([gene_matrix, gene_pca, res_pca], axis=1)


In [36]:
if metadata is not None:
    X = pd.concat([X, metadata], axis=1)

In [37]:
y_genes = gene_matrix.copy()  # Multi-label gene prediction target
y_resistance = resistance_matrix.copy()  # Antibiotic resistance prediction target

print("\nFinal Feature Matrix Shape:", X.shape)


Final Feature Matrix Shape: (96, 45)


In [38]:
common = X.index.intersection(y_genes.index).intersection(y_resistance.index)

print("Common rows found:", len(common))


Common rows found: 56


In [39]:
X = X.loc[common]
y_genes = y_genes.loc[common]
y_resistance = y_resistance.loc[common]

In [40]:
low_support = y_resistance.sum(axis=0)[y_resistance.sum(axis=0) < 2].index.tolist()

print("Removing rare classes:", low_support)

# Drop them
y_resistance_filtered = y_resistance.drop(columns=low_support)

Removing rare classes: ['Amikacin', 'Gentamicin', 'Tobramycin', 'Streptomycin']


In [41]:
X_train, X_test, y_genes_train, y_genes_test = train_test_split(
    X, y_genes, test_size=0.25, random_state=42
)

_, _, y_res_train, y_res_test = train_test_split(
    X, y_resistance, test_size=0.25, random_state=42
)


In [42]:
from sklearn.ensemble import RandomForestClassifier

rf_balanced = RandomForestClassifier(
    n_estimators=500,
    class_weight="balanced",
    max_depth=None,
    random_state=42
)


In [43]:
rf_genes = RandomForestClassifier(n_estimators=300, random_state=42)
multi_label_model = MultiOutputClassifier(rf_genes)

print("\nTraining Gene Prediction Model...")
multi_label_model.fit(X_train, y_genes_train)


Training Gene Prediction Model...


In [44]:
print("Predicting...")
y_pred_genes = multi_label_model.predict(X_test)

Predicting...


In [45]:
print("\nGENE PREDICTION RESULTS:")
print("Micro F1:", f1_score(y_genes_test, y_pred_genes, average='micro'))
print("Macro F1:", f1_score(y_genes_test, y_pred_genes, average='macro'))
print("Jaccard:", jaccard_score(y_genes_test, y_pred_genes, average='micro'))



GENE PREDICTION RESULTS:
Micro F1: 0.8928571428571429
Macro F1: 0.30666666666666664
Jaccard: 0.8064516129032258


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [46]:
rf_models = {}
res_predictions = pd.DataFrame(index=y_res_test.index)

print("\nTraining Antibiotic Models (RandomForest)...")

for antibiotic in y_resistance.columns:
    print(f" → Training for {antibiotic}...")

    clf = RandomForestClassifier(
        n_estimators=400,
        max_depth=None,
        random_state=42
    )

    clf.fit(X_train, y_res_train[antibiotic])
    rf_models[antibiotic] = clf

    preds = clf.predict(X_test)
    res_predictions[antibiotic] = preds

print("\nANTIBIOTIC RESISTANCE PREDICTION SUMMARY:")


Training Antibiotic Models (RandomForest)...
 → Training for Doxycycline...
 → Training for Tetracycline...
 → Training for Minocycline...
 → Training for Erythromycin...
 → Training for Lincomycin...
 → Training for Clindamycin...
 → Training for Quinupristin...
 → Training for Pristinamycin_IA...
 → Training for Virginiamycin_S...
 → Training for Azithromycin...
 → Training for Telithromycin...
 → Training for Dalfopristin...
 → Training for Pristinamycin_IIA...
 → Training for Virginiamycin_M...
 → Training for Tiamulin...
 → Training for Tigecycline...
 → Training for Chloramphenicol...
 → Training for Amikacin...
 → Training for Gentamicin...
 → Training for Tobramycin...
 → Training for Streptomycin...

ANTIBIOTIC RESISTANCE PREDICTION SUMMARY:


In [47]:
print(classification_report(y_res_test, res_predictions))

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         8
           1       1.00      1.00      1.00         8
           2       1.00      1.00      1.00         6
           3       1.00      1.00      1.00        10
           4       1.00      0.67      0.80         3
           5       1.00      1.00      1.00         2
           6       1.00      1.00      1.00        10
           7       1.00      1.00      1.00        10
           8       1.00      1.00      1.00        10
           9       1.00      1.00      1.00         8
          10       1.00      1.00      1.00         8
          11       0.00      0.00      0.00         1
          12       0.00      0.00      0.00         1
          13       0.00      0.00      0.00         1
          14       0.00      0.00      0.00         1
          15       1.00      1.00      1.00         2
          16       0.00      0.00      0.00         1
          17       0.00    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [48]:
def predict_amr_for_new_strain(new_gene_vector):
    new_df = pd.DataFrame([new_gene_vector])

    for col in X.columns:
        if col not in new_df.columns:
            new_df[col] = 0

    gene_preds = multi_label_model.predict(new_df[X.columns])[0]

    antibiotic_results = {
        ab: model.predict(new_df[X.columns])[0]
        for ab, model in rf_models.items()
    }

    return gene_preds, antibiotic_results

print("\nPipeline Ready — RandomForest Version!")


Pipeline Ready — RandomForest Version!


In [49]:
def recommend_antibiotics(strain_features, rf_models, threshold=0.5):
    """
    Takes a feature vector for a strain (same as X),
    returns recommended antibiotics ranked by effectiveness.
    """

    results = []

    for antibiotic, model in rf_models.items():

        # Probability of resistance (1)
        prob_resistant = model.predict_proba([strain_features])[0][1]

        # Convert probability → predicted status
        predicted_label = int(prob_resistant >= threshold)

        # For recommendation, we want probability of SENSITIVITY
        prob_sensitive = 1 - prob_resistant

        results.append({
            "antibiotic": antibiotic,
            "prob_resistant": prob_resistant,
            "prob_sensitive": prob_sensitive,
            "predicted_resistance": predicted_label
        })

    df = pd.DataFrame(results)

    # Rank antibiotics: highest sensitivity first
    df = df.sort_values(by="prob_sensitive", ascending=False)

    return df


In [50]:
# Select any strain from your test set
strain_id = X_test.index[2]

# Extract features for this strain
strain_features = X_test.loc[strain_id].values

recommendations = recommend_antibiotics(strain_features, rf_models)
print(recommendations)




           antibiotic  prob_resistant  prob_sensitive  predicted_resistance
0         Doxycycline             0.0             1.0                     0
11       Dalfopristin             0.0             1.0                     0
19         Tobramycin             0.0             1.0                     0
18         Gentamicin             0.0             1.0                     0
17           Amikacin             0.0             1.0                     0
16    Chloramphenicol             0.0             1.0                     0
15        Tigecycline             0.0             1.0                     0
14           Tiamulin             0.0             1.0                     0
13    Virginiamycin_M             0.0             1.0                     0
12  Pristinamycin_IIA             0.0             1.0                     0
20       Streptomycin             0.0             1.0                     0
1        Tetracycline             0.0             1.0                     0
5         Cl



In [51]:
def print_recommendation_report(strain_id, recommendations):
    print(f"\n=== Antibiotic Recommendation Report for Strain {strain_id} ===\n")

    print("Recommended Antibiotics (high confidence):")
    for idx, row in recommendations.iterrows():
        if row["predicted_resistance"] == 0 and row["prob_sensitive"] >= 0.70:
            print(f"✔ {row['antibiotic']}  (sensitivity: {row['prob_sensitive']:.2f})")

    print("\nAvoid These Antibiotics:")
    for idx, row in recommendations.iterrows():
        if row["predicted_resistance"] == 1:
            print(f"✘ {row['antibiotic']}  (resistance: {row['prob_resistant']:.2f})")

print_recommendation_report(strain_id, recommendations)




=== Antibiotic Recommendation Report for Strain Streptococcus_oralis_GCA_016028175.1_ASM1602817v1 ===

Recommended Antibiotics (high confidence):
✔ Doxycycline  (sensitivity: 1.00)
✔ Dalfopristin  (sensitivity: 1.00)
✔ Tobramycin  (sensitivity: 1.00)
✔ Gentamicin  (sensitivity: 1.00)
✔ Amikacin  (sensitivity: 1.00)
✔ Chloramphenicol  (sensitivity: 1.00)
✔ Tigecycline  (sensitivity: 1.00)
✔ Tiamulin  (sensitivity: 1.00)
✔ Virginiamycin_M  (sensitivity: 1.00)
✔ Pristinamycin_IIA  (sensitivity: 1.00)
✔ Streptomycin  (sensitivity: 1.00)
✔ Tetracycline  (sensitivity: 1.00)
✔ Clindamycin  (sensitivity: 1.00)
✔ Lincomycin  (sensitivity: 1.00)
✔ Minocycline  (sensitivity: 1.00)

Avoid These Antibiotics:
✘ Azithromycin  (resistance: 1.00)
✘ Virginiamycin_S  (resistance: 1.00)
✘ Pristinamycin_IA  (resistance: 1.00)
✘ Quinupristin  (resistance: 1.00)
✘ Erythromycin  (resistance: 1.00)
✘ Telithromycin  (resistance: 1.00)


In [52]:
import pickle

# Save multi-output gene prediction model
pickle.dump(multi_label_model, open("multi_label_gene_model.pkl", "wb"))
print("Saved: multi_label_gene_model.pkl")


Saved: multi_label_gene_model.pkl


In [53]:
# Save antibiotic ML models dictionary
pickle.dump(rf_models, open("rf_antibiotic_models.pkl", "wb"))
print("Saved: rf_antibiotic_models.pkl")


Saved: rf_antibiotic_models.pkl


In [54]:
X.to_csv("X_features.csv")
print("Saved: X_features.csv")


Saved: X_features.csv


In [55]:
with open("gene_matrix_columns.txt", "w") as f:
    for col in gene_matrix.columns:
        f.write(col + "\n")

print("Saved: gene_matrix_columns.txt")


Saved: gene_matrix_columns.txt


In [56]:
if metadata is not None:
    with open("metadata_columns.txt", "w") as f:
        for col in metadata.columns:
            f.write(col + "\n")
    print("Saved: metadata_columns.txt")
else:
    print("Metadata not used — skipping metadata_columns.txt")


Saved: metadata_columns.txt


In [57]:
pca_objects = {
    "gene_pca_scaler": gene_scaler,
    "gene_pca": gene_pca_model,
    "res_pca_scaler": res_scaler,
    "res_pca": res_pca_model
}

pickle.dump(pca_objects, open("pca_info.pkl", "wb"))
print("Saved: pca_info.pkl")


Saved: pca_info.pkl
