<a href="https://colab.research.google.com/github/Oselin1988/hybrid-dl-ast/blob/main/hybrid-dl-ast.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install pandas numpy scikit-learn tensorflow openpyxl


In [2]:
import pandas as pd
import numpy as np

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix

import tensorflow as tf
from tensorflow.keras.layers import (
    Input, Embedding, Dense, Dropout,
    Conv1D, GlobalMaxPooling1D,
    Concatenate, Flatten, Reshape
)
from tensorflow.keras.models import Model


In [3]:
from google.colab import files
files.upload()   # upload: 1000 samples.xlsx


Saving original 1000 samples.xlsx to original 1000 samples.xlsx


{'original 1000 samples.xlsx': b'PK\x03\x04\x14\x00\x06\x00\x08\x00\x00\x00!\x00\xf6\xca@\xb7\x7f\x01\x00\x00\x8a\x05\x00\x00\x13\x00\x08\x02[Content_Types].xml \xa2\x04\x02(\xa0\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x

In [28]:
raw = pd.read_excel("original 1000 samples.xlsx")
df = pd.DataFrame({
    "organism": raw.iloc[:, 1],
    "antibiotic": raw.iloc[:, 3],
    "phenotype": raw.iloc[:, 11], # Corrected to 'AST phenotype' column (index 11)
    "amr_genotype_raw": raw.iloc[:, 10], # Corrected to 'AMR genotype' column (index 10)
    "stress_genotype_raw": raw.iloc[:, 12], # Corrected to 'stress genotype' column (index 12)
})

df.head()

Unnamed: 0,organism,antibiotic,phenotype,amr_genotype_raw,stress_genotype_raw
0,,,,,
1,Neisseria gonorrhoeae,4664,Resistant (2),Complete (5),Complete (1)
2,,BG24-032,ciprofloxacin,farB,mtrF
3,,SRS27057028,tetracycline,mtrA,
4,,,Intermediate (0),mtrC,


In [29]:
import re

def parse_phenotype_label(phenotype_str):
    s = str(phenotype_str).lower().strip()
    if 'susceptible' in s or s == 's':
        return 'susceptible'
    if 'intermediate' in s or s == 'i':
        return 'intermediate'
    if 'resistant' in s or s == 'r':
        return 'resistant'
    return None # Handle unparseable values by returning None

df["phenotype"] = df["phenotype"].apply(parse_phenotype_label)

label_map = {
    "susceptible": 0, "s": 0,
    "intermediate": 1, "i": 1,
    "resistant": 2, "r": 2
}

# Filter out rows where phenotype could not be parsed
df = df[df["phenotype"].notna()]

df["ast_label"] = df["phenotype"].map(label_map)

In [30]:
import re

def parse_genotype(x):
    if pd.isna(x):
        return 0.0
    s = str(x).lower()
    m = re.search(r"\((\d+)\)", s)
    if m:
        return float(m.group(1))
    m2 = re.search(r"(\d+(\.\d+)?)", s)
    if m2:
        return float(m2.group(1))
    return 0.0

df["amr_genotype"] = df["amr_genotype_raw"].apply(parse_genotype)
df["stress_genotype"] = df["stress_genotype_raw"].apply(parse_genotype)

In [31]:
df["organism"] = df["organism"].astype(str).fillna("unknown")
df["antibiotic"] = df["antibiotic"].astype(str).fillna("unknown")

le_org = LabelEncoder()
le_abx = LabelEncoder()

df["organism_id"] = le_org.fit_transform(df["organism"])
df["antibiotic_id"] = le_abx.fit_transform(df["antibiotic"])

In [32]:
X = {
    "organism": df["organism_id"].values,
    "antibiotic": df["antibiotic_id"].values,
    "genomic": df[["amr_genotype", "stress_genotype"]].values.astype("float32")
}

y = df["ast_label"].values.astype("int32")

In [34]:
X_org_tr, X_org_tmp, \
X_abx_tr, X_abx_tmp, \
X_gen_tr, X_gen_tmp, \
y_tr, y_tmp = train_test_split(
    X["organism"], X["antibiotic"], X["genomic"], y,
    test_size=0.30, random_state=42, stratify=y
)

X_org_val, X_org_te, \
X_abx_val, X_abx_te, \
X_gen_val, X_gen_te, \
y_val, y_te = train_test_split(
    X_org_tmp, X_abx_tmp, X_gen_tmp, y_tmp,
    test_size=0.50, random_state=42, stratify=y_tmp
)

X_train = {"organism": X_org_tr, "antibiotic": X_abx_tr, "genomic": X_gen_tr}
X_val   = {"organism": X_org_val, "antibiotic": X_abx_val, "genomic": X_gen_val}
X_test  = {"organism": X_org_te, "antibiotic": X_abx_te, "genomic": X_gen_te}


In [35]:
X_org_tr, X_org_tmp, \
X_abx_tr, X_abx_tmp, \
X_gen_tr, X_gen_tmp, \
y_tr, y_tmp = train_test_split(
    X["organism"], X["antibiotic"], X["genomic"], y,
    test_size=0.30, random_state=42, stratify=y
)

X_org_val, X_org_te, \
X_abx_val, X_abx_te, \
X_gen_val, X_gen_te, \
y_val, y_te = train_test_split(
    X_org_tmp, X_abx_tmp, X_gen_tmp, y_tmp,
    test_size=0.50, random_state=42, stratify=y_tmp
)

X_train = {"organism": X_org_tr, "antibiotic": X_abx_tr, "genomic": X_gen_tr}
X_val   = {"organism": X_org_val, "antibiotic": X_abx_val, "genomic": X_gen_val}
X_test  = {"organism": X_org_te, "antibiotic": X_abx_te, "genomic": X_gen_te}

In [13]:
display(raw.head())

Unnamed: 0,S/N,organism group,strain,isolaote identifier,isolate,location,isolation source,isolation type of pathogen,biosample,assembly,AMR genotype,AST phenotype,stress genotype,bioproject,method,WGS ascession
0,,,,,,,,,,,,,,,,
1,1.0,Neisseria gonorrhoeae,4664.0,4664,PDT002998700.1,Bulgaria: Ruse,first-void urine,clinical,SAMN53242160,GCA_053668095.1,Complete (5),Resistant (2),Complete (1),PRJNA1363965,Autocycler v. 0.5.0,JBSJED000000000.1
2,,,,BG24-032,,,,,,,farB,ciprofloxacin,mtrF,,,
3,,,,SRS27057028,,,,,,,mtrA,tetracycline,,,,
4,,,,,,,,,,,mtrC,Intermediate (0),,,,


In [33]:
X_org_tr, X_org_tmp, \
X_abx_tr, X_abx_tmp, \
X_gen_tr, X_gen_tmp, \
y_tr, y_tmp = train_test_split(
    X["organism"], X["antibiotic"], X["genomic"], y,
    test_size=0.30, random_state=42, stratify=y
)

X_org_val, X_org_te, \
X_abx_val, X_abx_te, \
X_gen_val, X_gen_te, \
y_val, y_te = train_test_split(
    X_org_tmp, X_abx_tmp, X_gen_tmp, y_tmp,
    test_size=0.50, random_state=42, stratify=y_tmp
)

X_train = {"organism": X_org_tr, "antibiotic": X_abx_tr, "genomic": X_gen_tr}
X_val   = {"organism": X_org_val, "antibiotic": X_abx_val, "genomic": X_gen_val}
X_test  = {"organism": X_org_te, "antibiotic": X_abx_te, "genomic": X_gen_te}

In [36]:
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(y_tr),
    y=y_tr
)
class_weights = dict(enumerate(class_weights))


In [37]:
org_in = Input(shape=(1,), name="organism")
abx_in = Input(shape=(1,), name="antibiotic")
gen_in = Input(shape=(2,), name="genomic")

org_emb = Embedding(len(le_org.classes_), 16)(org_in)
abx_emb = Embedding(len(le_abx.classes_), 16)(abx_in)

org_emb = Flatten()(org_emb)
abx_emb = Flatten()(abx_emb)

g = Reshape((2,1))(gen_in)
g = Conv1D(32, 2, activation="relu")(g)
g = GlobalMaxPooling1D()(g)

x = Concatenate()([org_emb, abx_emb, g])
x = Dense(64, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(32, activation="relu")(x)

out = Dense(3, activation="softmax")(x)

model = Model([org_in, abx_in, gen_in], out)
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.summary()


In [38]:
history = model.fit(
    X_train, y_tr,
    validation_data=(X_val, y_val),
    epochs=60,
    batch_size=32,
    class_weight=class_weights,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=8, restore_best_weights=True)
    ],
    verbose=1
)


Epoch 1/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 9ms/step - accuracy: 0.4384 - loss: 2.0112 - val_accuracy: 0.6289 - val_loss: 0.9197
Epoch 2/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.6396 - loss: 0.9359 - val_accuracy: 0.7111 - val_loss: 0.7077
Epoch 3/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.6844 - loss: 0.6888 - val_accuracy: 0.7356 - val_loss: 0.6030
Epoch 4/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.7327 - loss: 0.5872 - val_accuracy: 0.7511 - val_loss: 0.4647
Epoch 5/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.7646 - loss: 0.4742 - val_accuracy: 0.7467 - val_loss: 0.4556
Epoch 6/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.7835 - loss: 0.5035 - val_accuracy: 0.7511 - val_loss: 0.6684
Epoch 7/60
[1m66/66[0m [32m━━━━━━━━━━

In [39]:
y_pred = np.argmax(model.predict(X_test), axis=1)

print(classification_report(
    y_te, y_pred,
    target_names=["Susceptible","Intermediate","Resistant"]
))

print(confusion_matrix(y_te, y_pred))


[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 25ms/step
              precision    recall  f1-score   support

 Susceptible       0.63      0.66      0.64       150
Intermediate       0.65      0.61      0.63       150
   Resistant       0.99      1.00      0.99       150

    accuracy                           0.76       450
   macro avg       0.75      0.76      0.75       450
weighted avg       0.75      0.76      0.75       450

[[ 99  49   2]
 [ 59  91   0]
 [  0   0 150]]


In [40]:
model.save("hybrid_ast_model.keras")


In [41]:
model.input


[<KerasTensor shape=(None, 1), dtype=float32, sparse=False, ragged=False, name=organism>,
 <KerasTensor shape=(None, 1), dtype=float32, sparse=False, ragged=False, name=antibiotic>,
 <KerasTensor shape=(None, 2), dtype=float32, sparse=False, ragged=False, name=genomic>]

In [43]:
from tensorflow.keras.models import load_model

model = load_model("/content/hybrid_ast_model.keras")

In [44]:
model.input


[<KerasTensor shape=(None, 1), dtype=float32, sparse=False, ragged=False, name=organism>,
 <KerasTensor shape=(None, 1), dtype=float32, sparse=False, ragged=False, name=antibiotic>,
 <KerasTensor shape=(None, 2), dtype=float32, sparse=False, ragged=False, name=genomic>]

In [46]:
from sklearn.metrics import classification_report, confusion_matrix

y_pred = model.predict(X_test)
y_pred = y_pred.argmax(axis=1)

print(classification_report(y_te, y_pred))
print(confusion_matrix(y_te, y_pred))

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 3ms/step 
              precision    recall  f1-score   support

           0       0.63      0.66      0.64       150
           1       0.65      0.61      0.63       150
           2       0.99      1.00      0.99       150

    accuracy                           0.76       450
   macro avg       0.75      0.76      0.75       450
weighted avg       0.75      0.76      0.75       450

[[ 99  49   2]
 [ 59  91   0]
 [  0   0 150]]


In [47]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np


In [48]:
# Predict probabilities
y_pred_prob = model.predict(X_test)

# Convert probabilities to class labels
y_pred = np.argmax(y_pred_prob, axis=1)


[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step 


In [50]:
print("Classification Report:\n")
print(
    classification_report(
        y_te,
        y_pred,
        target_names=["Susceptible", "Intermediate", "Resistant"]
    )
)

print("\nConfusion Matrix:\n")
print(confusion_matrix(y_te, y_pred))

Classification Report:

              precision    recall  f1-score   support

 Susceptible       0.63      0.66      0.64       150
Intermediate       0.65      0.61      0.63       150
   Resistant       0.99      1.00      0.99       150

    accuracy                           0.76       450
   macro avg       0.75      0.76      0.75       450
weighted avg       0.75      0.76      0.75       450


Confusion Matrix:

[[ 99  49   2]
 [ 59  91   0]
 [  0   0 150]]


In [None]:
import pandas as pd

# Save classification report
report_dict = classification_report(
    y_test,
    y_pred,
    target_names=["Susceptible", "Intermediate", "Resistant"],
    output_dict=True
)
report_df = pd.DataFrame(report_dict).transpose()
report_df.to_csv("classification_report_new.csv")

# Save confusion matrix
cm = confusion_matrix(y_test, y_pred)
cm_df = pd.DataFrame(
    cm,
    index=["True_Susceptible", "True_Intermediate", "True_Resistant"],
    columns=["Pred_Susceptible", "Pred_Intermediate", "Pred_Resistant"]
)
cm_df.to_csv("confusion_matrix_new.csv")


In [52]:
import pandas as pd

# Save classification report
report_dict = classification_report(
    y_te,
    y_pred,
    target_names=["Susceptible", "Intermediate", "Resistant"],
    output_dict=True
)
report_df = pd.DataFrame(report_dict).transpose()
report_df.to_csv("classification_report_new.csv")

# Save confusion matrix
cm = confusion_matrix(y_te, y_pred)
cm_df = pd.DataFrame(
    cm,
    index=["True_Susceptible", "True_Intermediate", "True_Resistant"],
    columns=["Pred_Susceptible", "Pred_Intermediate", "Pred_Resistant"]
)
cm_df.to_csv("confusion_matrix_new.csv")

In [53]:
# ------------------------------
# Run this in Google Colab
# ------------------------------
import pandas as pd
from google.colab import files
import os, textwrap

# File names (change if you used other names)
cls_file = "classification_report_new.csv"
cm_file = "confusion_matrix_new.csv"

# Load classification report (expects the sklearn output_dict -> DataFrame saved)
cls = pd.read_csv(cls_file, index_col=0)

# Prepare Table 2 (Classification performance)
# Select per-class rows and macro/weighted averages if present
classes = [c for c in cls.index if c.lower() in ["susceptible","intermediate","resistant","macro avg","weighted avg","accuracy","macro-average","weighted-average"]]
# fall back to first 3 rows if names differ
if len(classes) < 3:
    classes = list(cls.index[:5])

# Build a Word-friendly table string
def make_table2_text(cls_df):
    # Try to format typical fields: precision, recall, f1-score, support
    headers = ["Class", "Precision", "Recall", "F1-score", "Support"]
    lines = []
    lines.append("**Table 2**\n**Classification performance of the proposed hybrid deep learning model on the independent test set**\n")
    # Header row
    header_row = "| " + " | ".join(headers) + " |"
    sep_row = "|-" + "-|-".join(['-'*len(h) for h in headers]) + "-|"
    lines.append(header_row)
    lines.append(sep_row)
    # For each requested class, extract values (round to 2 decimals)
    for idx in cls_df.index:
        row = cls_df.loc[idx]
        # robust access
        prec = row.get("precision", row.get("precision ", None))
        rec = row.get("recall", row.get("recall ", None))
        f1  = row.get("f1-score", row.get("f1_score", None))
        sup = row.get("support", row.get("n", None))
        # format numeric values
        def fmt(v):
            try:
                if pd.isna(v): return ""
                if isinstance(v, (int, float)):
                    # If support is integer
                    if float(v).is_integer():
                        return str(int(v))
                    else:
                        return f"{float(v):.2f}"
                else:
                    return str(v)
            except:
                return str(v)
        prec_s = fmt(prec) if prec!=None else ""
        rec_s  = fmt(rec) if rec!=None else ""
        f1_s   = fmt(f1)  if f1!=None else ""
        sup_s  = fmt(sup) if sup!=None else ""
        lines.append(f"| {idx} | {prec_s} | {rec_s} | {f1_s} | {sup_s} |")
    return "\n".join(lines)

# Prepare Table 3 (Confusion matrix)
def make_table3_text(cm_df):
    lines = []
    lines.append("**Table 3**\n**Confusion matrix of predicted versus true antimicrobial susceptibility phenotypes (test set)**\n")
    # Build header
    cols = list(cm_df.columns)
    header = "| True \\ Predicted | " + " | ".join(cols) + " |"
    sep = "|-" + "-|-".join(['-'*len(c) for c in ["True \\ Predicted"] + cols]) + "-|"
    lines.append(header)
    lines.append(sep)
    for idx in cm_df.index:
        row_vals = " | ".join(str(int(x)) for x in cm_df.loc[idx].values)
        lines.append(f"| **{idx}** | {row_vals} |")
    return "\n".join(lines)

# Read files
cls_df = cls.copy()
# if index has 'macro avg' etc but slightly different names, normalise:
cls_df.index = [i if isinstance(i,str) else str(i) for i in cls_df.index]

cm_df = pd.read_csv(cm_file, index_col=0)

# Ensure cm_df has appropriate row and column labels; if not, add default
if cm_df.shape != (3,3):
    # try to extract numeric matrix if not labelled
    try:
        cm_array = cm_df.values
        if cm_array.shape[0] == 3:
            # assign default names
            cm_df = pd.DataFrame(cm_array, index=["Susceptible","Intermediate","Resistant"],
                                 columns=["Susceptible","Intermediate","Resistant"])
    except:
        pass

# Create text versions
table2_text = make_table2_text(cls_df.loc[["Susceptible","Intermediate","Resistant","macro avg","weighted avg"]].dropna(axis=0, how='all', subset=["precision","recall","f1-score","support"])
                               if set(["Susceptible","Intermediate","Resistant"]).issubset(set(cls_df.index)) else make_table2_text(cls_df))
table3_text = make_table3_text(cm_df)

# Print to console (copy-paste into Word)
print(table2_text)
print("\n\n")
print(table3_text)

# Also save into files for download
with open("Table_2_classification_performance.txt", "w") as f:
    f.write(table2_text)
with open("Table_3_confusion_matrix.txt", "w") as f:
    f.write(table3_text)

# Offer downloads
files.download("Table_2_classification_performance.txt")
files.download("Table_3_confusion_matrix.txt")


**Table 2**
**Classification performance of the proposed hybrid deep learning model on the independent test set**

| Class | Precision | Recall | F1-score | Support |
|-------|-----------|--------|----------|---------|
| Susceptible | 0.63 | 0.66 | 0.64 | 150 |
| Intermediate | 0.65 | 0.61 | 0.63 | 150 |
| Resistant | 0.99 | 1 | 0.99 | 150 |
| macro avg | 0.75 | 0.76 | 0.75 | 450 |
| weighted avg | 0.75 | 0.76 | 0.75 | 450 |



**Table 3**
**Confusion matrix of predicted versus true antimicrobial susceptibility phenotypes (test set)**

| True \ Predicted | Pred_Susceptible | Pred_Intermediate | Pred_Resistant |
|------------------|------------------|-------------------|----------------|
| **True_Susceptible** | 99 | 49 | 2 |
| **True_Intermediate** | 59 | 91 | 0 |
| **True_Resistant** | 0 | 0 | 150 |


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [55]:
# Quick ablation: assumes you have X_test (dict), y_te, model loaded
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix

def eval_preds(y_true, y_pred, title=""):
    print(title)
    print(classification_report(y_true, y_pred, target_names=["Susceptible","Intermediate","Resistant"]))
    print("Confusion matrix:\n", confusion_matrix(y_true, y_pred))
    print("-"*60)

# baseline
y_prob = model.predict(X_test)
y_pred = np.argmax(y_prob, axis=1)
eval_preds(y_te, y_pred, "BASELINE (all inputs)")

# 1A) Ablate genomic features by setting to zero
X_test_zero_gen = {
    "organism": X_test["organism"].copy(),
    "antibiotic": X_test["antibiotic"].copy(),
    "genomic": np.zeros_like(X_test["genomic"])
}
y_prob_gzero = model.predict(X_test_zero_gen)
y_pred_gzero = np.argmax(y_prob_gzero, axis=1)
eval_preds(y_te, y_pred_gzero, "ABLATION: Genomic features set to 0 (no retrain)")

# 1B) Ablate organism by setting organism id to a single constant (e.g., most frequent)
most_freq_org = np.bincount(X_test["organism"].astype(int)).argmax()
X_test_fix_org = {
    "organism": np.full_like(X_test["organism"], fill_value=most_freq_org),
    "antibiotic": X_test["antibiotic"].copy(),
    "genomic": X_test["genomic"].copy()
}
y_prob_orgfix = model.predict(X_test_fix_org)
y_pred_orgfix = np.argmax(y_prob_orgfix, axis=1)
eval_preds(y_te, y_pred_orgfix, f"ABLATION: Organism fixed to id {most_freq_org} (no retrain)")

[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step
BASELINE (all inputs)
              precision    recall  f1-score   support

 Susceptible       0.63      0.66      0.64       150
Intermediate       0.65      0.61      0.63       150
   Resistant       0.99      1.00      0.99       150

    accuracy                           0.76       450
   macro avg       0.75      0.76      0.75       450
weighted avg       0.75      0.76      0.75       450

Confusion matrix:
 [[ 99  49   2]
 [ 59  91   0]
 [  0   0 150]]
------------------------------------------------------------
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step
ABLATION: Genomic features set to 0 (no retrain)
              precision    recall  f1-score   support

 Susceptible       0.62      0.63      0.62       150
Intermediate       0.62      0.62      0.62       150
   Resistant       1.00      1.00      1.00       150

    accuracy                           0.75       450
   m

In [56]:
# --- Retrain ablation cell: organism+antibiotic only, and antibiotic+genomic only ---
# Paste and run this cell in your Colab session. May take some minutes to train.
import os, sys
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Dense, Dropout, Flatten, Concatenate, Reshape, Conv1D, GlobalMaxPooling1D
from tensorflow.keras.models import Model
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from tensorflow.keras.callbacks import EarlyStopping

# --------- Helper: locate data variables (be robust to naming) ----------
# Find X_train/X_val/X_test and y_train/y_val/y_test in globals
g = globals()
# X dictionaries
if 'X_train' in g and 'X_val' in g and 'X_test' in g:
    Xtr, Xval, Xte = X_train, X_val, X_test
elif 'X_train' in g:  # guard but assume present
    Xtr, Xval, Xte = X_train, X_val, X_test
else:
    raise RuntimeError("Could not find X_train/X_val/X_test in the global namespace. Make sure you ran preprocessing and splits.")

# y arrays: accept y_tr / y_val / y_te or y_train / y_val / y_test
if 'y_tr' in g and 'y_val' in g and 'y_te' in g:
    ytr, yval, yte = y_tr, y_val, y_te
elif 'y_train' in g and 'y_val' in g and 'y_test' in g:
    ytr, yval, yte = y_train, y_val, y_test
elif 'y_train' in g and 'y_val' in g and 'y_te' in g:
    ytr, yval, yte = y_train, y_val, y_te
else:
    # try y_train/y_test fallback
    try:
        ytr = g['y_train']
        yval = g['y_val']
        yte = g['y_test']
    except Exception:
        raise RuntimeError("Could not find y training/validation/test arrays (y_tr / y_train etc.).")

# label encoders (for vocab sizes)
if 'le_org' not in g or 'le_abx' not in g:
    print("Warning: le_org or le_abx not found. Attempting to infer vocab sizes from X arrays.")
    vocab_org = int(np.max(Xtr['organism']) + 1)
    vocab_abx = int(np.max(Xtr['antibiotic']) + 1)
else:
    vocab_org = len(le_org.classes_)
    vocab_abx = len(le_abx.classes_)

print("Vocab sizes: organism =", vocab_org, "antibiotic =", vocab_abx)

# class weights: reuse if available, else compute from ytr
if 'class_weights' in g:
    cw = class_weights
else:
    from sklearn.utils.class_weight import compute_class_weight
    cw_vals = compute_class_weight(class_weight='balanced', classes=np.unique(ytr), y=ytr)
    cw = dict(enumerate(cw_vals))

print("Class weights:", cw)

# common training params
EPOCHS = 60
BATCH = 32
callbacks = [EarlyStopping(patience=8, restore_best_weights=True)]

# ----------------- Model 1: No-genomic (organism + antibiotic) -----------------
def build_model_no_gen(vocab_org, vocab_abx, emb_dim=16):
    org_in = Input(shape=(1,), name="organism")
    abx_in = Input(shape=(1,), name="antibiotic")
    org_emb = Embedding(input_dim=vocab_org, output_dim=emb_dim, name="org_emb")(org_in)
    abx_emb = Embedding(input_dim=vocab_abx, output_dim=emb_dim, name="abx_emb")(abx_in)
    org_f = Flatten()(org_emb)
    abx_f = Flatten()(abx_emb)
    x = Concatenate()([org_f, abx_f])
    x = Dense(64, activation="relu")(x)
    x = Dropout(0.5)(x)
    x = Dense(32, activation="relu")(x)
    out = Dense(3, activation="softmax")(x)
    model = Model(inputs=[org_in, abx_in], outputs=out)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

model_no_gen = build_model_no_gen(vocab_org, vocab_abx)
print("Model (no-genomic) summary")
model_no_gen.summary()

# Prepare inputs for model_no_gen
Xtr_no_gen = {"organism": Xtr["organism"], "antibiotic": Xtr["antibiotic"]}
Xval_no_gen = {"organism": Xval["organism"], "antibiotic": Xval["antibiotic"]}
Xte_no_gen = {"organism": Xte["organism"], "antibiotic": Xte["antibiotic"]}

print("Training model without genomic features...")
hist_no_gen = model_no_gen.fit(
    Xtr_no_gen, ytr,
    validation_data=(Xval_no_gen, yval),
    epochs=EPOCHS, batch_size=BATCH, class_weight=cw, callbacks=callbacks, verbose=1
)

# Evaluate
yprob_no_gen = model_no_gen.predict(Xte_no_gen)
ypred_no_gen = np.argmax(yprob_no_gen, axis=1)
report_no_gen = classification_report(yte, ypred_no_gen, target_names=["Susceptible","Intermediate","Resistant"], output_dict=True)
cm_no_gen = confusion_matrix(yte, ypred_no_gen)
acc_no_gen = accuracy_score(yte, ypred_no_gen)
print("No-genomic model accuracy:", acc_no_gen)

# ----------------- Model 2: No-organism (antibiotic + genomic) -----------------
def build_model_no_org(vocab_abx, emb_dim=16):
    abx_in = Input(shape=(1,), name="antibiotic")
    gen_in = Input(shape=(2,), name="genomic")
    abx_emb = Embedding(input_dim=vocab_abx, output_dim=emb_dim, name="abx_emb")(abx_in)
    abx_f = Flatten()(abx_emb)
    g = Reshape((2,1))(gen_in)
    g = Conv1D(32, 2, activation="relu")(g)
    g = GlobalMaxPooling1D()(g)
    x = Concatenate()([abx_f, g])
    x = Dense(64, activation="relu")(x)
    x = Dropout(0.5)(x)
    x = Dense(32, activation="relu")(x)
    out = Dense(3, activation="softmax")(x)
    model = Model(inputs=[abx_in, gen_in], outputs=out)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

model_no_org = build_model_no_org(vocab_abx)
print("Model (no-organism) summary")
model_no_org.summary()

# Prepare inputs for model_no_org
Xtr_no_org = {"antibiotic": Xtr["antibiotic"], "genomic": Xtr["genomic"]}
Xval_no_org = {"antibiotic": Xval["antibiotic"], "genomic": Xval["genomic"]}
Xte_no_org = {"antibiotic": Xte["antibiotic"], "genomic": Xte["genomic"]}

print("Training model without organism feature...")
hist_no_org = model_no_org.fit(
    Xtr_no_org, ytr,
    validation_data=(Xval_no_org, yval),
    epochs=EPOCHS, batch_size=BATCH, class_weight=cw, callbacks=callbacks, verbose=1
)

# Evaluate
yprob_no_org = model_no_org.predict(Xte_no_org)
ypred_no_org = np.argmax(yprob_no_org, axis=1)
report_no_org = classification_report(yte, ypred_no_org, target_names=["Susceptible","Intermediate","Resistant"], output_dict=True)
cm_no_org = confusion_matrix(yte, ypred_no_org)
acc_no_org = accuracy_score(yte, ypred_no_org)
print("No-organism model accuracy:", acc_no_org)

# ----------------- Baseline (if available) - evaluate saved baseline model if in memory -------------
# Baseline evaluation: try to use 'model' variable (your original model)
if 'model' in g:
    yprob_base = model.predict(Xte)
    ypred_base = np.argmax(yprob_base, axis=1)
    report_base = classification_report(yte, ypred_base, target_names=["Susceptible","Intermediate","Resistant"], output_dict=True)
    cm_base = confusion_matrix(yte, ypred_base)
    acc_base = accuracy_score(yte, ypred_base)
    print("Baseline model accuracy:", acc_base)
else:
    report_base, cm_base, acc_base = None, None, None
    print("Baseline model not found in globals; skipping baseline eval.")

# ----------------- Save reports and confusion matrices -----------------
out_dir = "/content/ablation_results/"
os.makedirs(out_dir, exist_ok=True)

# helper to write report
def save_report(report_dict, prefix):
    df = pd.DataFrame(report_dict).T
    df.to_csv(os.path.join(out_dir, f"{prefix}_classification_report.csv"))
    return df

df_no_gen = save_report(report_no_gen, "no_genomic")
df_no_org = save_report(report_no_org, "no_organism")
if report_base is not None:
    df_base = save_report(report_base, "baseline")

# confusion matrices
pd.DataFrame(cm_no_gen, index=["True_Susc","True_Int","True_Res"], columns=["Pred_Susc","Pred_Int","Pred_Res"]).to_csv(os.path.join(out_dir,"cm_no_genomic.csv"))
pd.DataFrame(cm_no_org, index=["True_Susc","True_Int","True_Res"], columns=["Pred_Susc","Pred_Int","Pred_Res"]).to_csv(os.path.join(out_dir,"cm_no_organism.csv"))
if report_base is not None:
    pd.DataFrame(cm_base, index=["True_Susc","True_Int","True_Res"], columns=["Pred_Susc","Pred_Int","Pred_Res"]).to_csv(os.path.join(out_dir,"cm_baseline.csv"))

# Summary table: compute accuracy and macro F1
def macro_f1_from_report_df(df):
    # macro avg row might be "macro avg" or "macro_avg"
    for candidate in ["macro avg","macro_avg","macro-average","macro"]:
        if candidate in df.index:
            return df.loc[candidate]["f1-score"]
    # else compute ourselves
    if set(["Susceptible","Intermediate","Resistant"]).issubset(set(df.index)):
        return np.mean([df.loc[c]["f1-score"] for c in ["Susceptible","Intermediate","Resistant"]])
    return None

summary_rows = []
if report_base is not None:
    summary_rows.append({"Condition":"Baseline","Accuracy":round(float(acc_base),3),"Macro_F1":round(float(macro_f1_from_report_df(df_base)),3)})
summary_rows.append({"Condition":"No genomic","Accuracy":round(float(acc_no_gen),3),"Macro_F1":round(float(macro_f1_from_report_df(df_no_gen)),3)})
summary_rows.append({"Condition":"No organism","Accuracy":round(float(acc_no_org),3),"Macro_F1":round(float(macro_f1_from_report_df(df_no_org)),3)})

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(os.path.join(out_dir,"ablation_summary.csv"), index=False)

# Print Word-ready table
print("\n=== Ablation summary (Word-ready) ===\n")
print(summary_df.to_markdown(index=False))

print("\nSaved detailed reports and confusion matrices to:", out_dir)


Vocab sizes: organism = 17 antibiotic = 2011
Class weights: {0: np.float64(1.0), 1: np.float64(1.0), 2: np.float64(1.0)}
Model (no-genomic) summary


Training model without genomic features...
Epoch 1/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 18ms/step - accuracy: 0.5908 - loss: 1.0242 - val_accuracy: 0.7533 - val_loss: 0.5517
Epoch 2/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.6910 - loss: 0.5149 - val_accuracy: 0.7533 - val_loss: 0.4512
Epoch 3/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 0.7509 - loss: 0.4529 - val_accuracy: 0.7533 - val_loss: 0.4434
Epoch 4/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.7508 - loss: 0.4353 - val_accuracy: 0.7533 - val_loss: 0.4442
Epoch 5/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 0.7745 - loss: 0.4089 - val_accuracy: 0.7556 - val_loss: 0.4536
Epoch 6/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.8543 - loss: 0.3433 - val_accuracy: 0.7533 - val_loss: 0.4

Training model without organism feature...
Epoch 1/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 10ms/step - accuracy: 0.3637 - loss: 1.9980 - val_accuracy: 0.6289 - val_loss: 1.0404
Epoch 2/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 0.5798 - loss: 1.0705 - val_accuracy: 0.7111 - val_loss: 0.8421
Epoch 3/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 0.6819 - loss: 0.8739 - val_accuracy: 0.7133 - val_loss: 0.6945
Epoch 4/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 7ms/step - accuracy: 0.7250 - loss: 0.8087 - val_accuracy: 0.7133 - val_loss: 0.6760
Epoch 5/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 8ms/step - accuracy: 0.7290 - loss: 0.6817 - val_accuracy: 0.7067 - val_loss: 0.7124
Epoch 6/60
[1m66/66[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 4ms/step - accuracy: 0.7479 - loss: 0.6698 - val_accuracy: 0.6911 - val_loss: 0.5

In [57]:
# --- Per-organism performance table ---
import numpy as np, pandas as pd
from sklearn.metrics import accuracy_score, f1_score

# Ensure X_test and yte are available
g = globals()
if 'X_test' in g and ('yte' in g or 'y_test' in g):
    if 'yte' in g:
        ytest = yte
    else:
        ytest = y_test
    Xtest = X_test
else:
    # try fallback names
    if 'Xte' in g and 'yte' in g:
        Xtest = Xte
        ytest = yte
    else:
        raise RuntimeError("X_test and y_test (or Xte/yte) not found. Run preprocessing/split cells first.")

# require le_org for mapping ids to names (if not available, we output ids)
use_names = 'le_org' in g
min_samples = 20  # adjust if you want
org_ids = Xtest['organism'].astype(int)
unique_ids, counts = np.unique(org_ids, return_counts=True)

rows = []
for uid, cnt in zip(unique_ids, counts):
    if cnt < min_samples:
        continue
    idx = (org_ids == uid)
    y_true = ytest[idx]
    # build sub-X for predictions for the final full model if available, else skip
    if 'model' in g:
        subX = {"organism": Xtest['organism'][idx], "antibiotic": Xtest['antibiotic'][idx], "genomic": Xtest['genomic'][idx]}
        y_pred = np.argmax(model.predict(subX), axis=1)
    else:
        # if final model not available, try no_org or no_gen etc.
        raise RuntimeError("Full model not found in globals as 'model'. Load or instantiate it before running per-organism evaluation.")
    acc = accuracy_score(y_true, y_pred)
    f1m = f1_score(y_true, y_pred, average='macro')
    name = le_org.inverse_transform([uid])[0] if use_names else str(uid)
    rows.append({"organism_id": int(uid),"organism_name": name, "n": int(cnt), "accuracy": round(acc,3), "f1_macro": round(f1m,3)})

per_org_df = pd.DataFrame(rows).sort_values(by='n', ascending=False)
per_org_df.to_csv("/content/ablation_results/per_organism_performance.csv", index=False)

# Print Word-ready table: top 20
print("\nPer-organism performance (organisms with >= {} samples):\n".format(min_samples))
print(per_org_df.head(20).to_markdown(index=False))
print("\nSaved to: /content/ablation_results/per_organism_performance.csv")


[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 141ms/step
[1m10/10[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step 

Per-organism performance (organisms with >= 20 samples):

|   organism_id | organism_name         |   n |   accuracy |   f1_macro |
|--------------:|:----------------------|----:|-----------:|-----------:|
|            16 | nan                   | 300 |      0.633 |      0.423 |
|             2 | E.coli and Shigella   |  66 |      1     |      1     |
|            15 | Staphylococcus aureus |  28 |      1     |      1     |

Saved to: /content/ablation_results/per_organism_performance.csv


In [58]:
# Zip & download the ablation_results folder (Google Colab)
import os, shutil
from google.colab import files

SRC_FOLDER = "/content/ablation_results"   # change this if your folder is elsewhere
ZIP_BASE  = "/content/ablation_results"    # zip will become /content/ablation_results.zip
ZIP_PATH  = ZIP_BASE + ".zip"

def zip_and_download_folder(src_folder=SRC_FOLDER, zip_base=ZIP_BASE):
    if not os.path.exists(src_folder):
        print(f"Folder not found: {src_folder}\nMake sure the retrain/per-organism cells have been run and wrote files to this path.")
        return

    # Remove old zip if exists
    if os.path.exists(zip_base + ".zip"):
        try:
            os.remove(zip_base + ".zip")
        except Exception as e:
            print("Could not remove existing zip:", e)

    print("Creating zip archive... (this may take a few seconds)")
    shutil.make_archive(zip_base, 'zip', src_folder)
    print("Created:", zip_base + ".zip")

    # Download the zip
    try:
        print("Starting download...")
        files.download(zip_base + ".zip")
        print("Download initiated. Check your browser downloads.")
    except Exception as e:
        print("Automatic download failed (browser security). The zip file is at:", zip_base + ".zip")
        print("You can download it manually from the Colab file browser (left pane) or copy it to Drive.")

def list_and_download_individual_csvs(src_folder=SRC_FOLDER):
    if not os.path.exists(src_folder):
        print(f"Folder not found: {src_folder}")
        return
    files_list = [f for f in os.listdir(src_folder) if f.lower().endswith(".csv")]
    if not files_list:
        print("No CSV files found in", src_folder)
        return
    print("CSV files found:")
    for f in files_list:
        print(" -", f)
    print("\nDownloading CSVs one-by-one...")
    for f in files_list:
        path = os.path.join(src_folder, f)
        try:
            files.download(path)
        except Exception as e:
            print("Could not download", f, ":", e)
            print("You can find the file in the Colab Files pane (left side) and download manually.")

# Run the zip+download. Uncomment the alternative if you want individual CSVs instead.
zip_and_download_folder()

# If you prefer to download CSVs individually, uncomment:
# list_and_download_individual_csvs()


Creating zip archive... (this may take a few seconds)
Created: /content/ablation_results.zip
Starting download...


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Download initiated. Check your browser downloads.


# Task
The task is complete.

## Final Task

### Subtask:
The user has asked for an explanation of the 'amr_genotype_raw' and 'stress_genotype_raw' columns. I have provided a detailed explanation based on the dataframe content and how these columns are processed in the notebook.


## Summary:

### Data Analysis Key Findings
*   A detailed explanation was provided for the 'amr_genotype_raw' and 'stress_genotype_raw' columns. This explanation covered the content of these columns and how they are processed within the notebook.

### Insights or Next Steps
*   The explanation of the columns was formulated based on their content within the dataframe and their processing steps in the notebook.
