# **Control Model**
This notebook trains a control model using $\mathbf{Tier\ 2 + Tier\ 1B}$ features only (the loop below covers $\text{AMX}$, $\text{AMC}$ and $\text{CIP}$) and compares its performance to our $\text{Tier 2 + Novel}$ model for the same drug. If the $\text{Tier 2 + Tier 1B}$ model is clearly worse, that supports the claim that our Novel genes provide unique, non-redundant predictive power beyond the stress response. If it performs about the same, the Novel genes are likely redundant with the Tier 1B stress genes.
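
One way to make "significantly worse" concrete is a paired bootstrap on the held-out test set: resample the same test isolates for both models and look at the distribution of the AUROC difference. The sketch below is a minimal illustration under assumptions, not the procedure used in this project: the function names are hypothetical, the demo scores are synthetic stand-ins, and it assumes both models were evaluated on an identical test split.

```python
import numpy as np

def auroc(y_true, scores):
    """AUROC as the fraction of (resistant, susceptible) pairs ranked correctly; ties count 0.5."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    diff = scores[y_true == 1][:, None] - scores[y_true == 0][None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

def paired_bootstrap_auroc_diff(y_true, scores_a, scores_b, n_boot=2000, seed=42):
    """Mean and 95% percentile interval of AUROC(a) - AUROC(b) over resampled test sets."""
    y_true = np.asarray(y_true)
    scores_a = np.asarray(scores_a, dtype=float)
    scores_b = np.asarray(scores_b, dtype=float)
    rng = np.random.default_rng(seed)
    diffs = []
    while len(diffs) < n_boot:
        #resample test isolates with replacement, using the same indices for both models
        idx = rng.integers(0, len(y_true), len(y_true))
        y = y_true[idx]
        if y.min() == y.max():  #resample lost one class, so AUROC is undefined
            continue
        diffs.append(auroc(y, scores_a[idx]) - auroc(y, scores_b[idx]))
    low, high = np.percentile(diffs, [2.5, 97.5])
    return float(np.mean(diffs)), float(low), float(high)

#synthetic stand-in data (NOT results from this project)
rng = np.random.default_rng(0)
y_demo = rng.integers(0, 2, 200)
novel_scores = y_demo + rng.normal(0, 0.5, 200)    #hypothetical Tier 2 + Novel scores
control_scores = y_demo + rng.normal(0, 2.0, 200)  #hypothetical Tier 2 + Tier 1B scores

mean_diff, low, high = paired_bootstrap_auroc_diff(y_demo, novel_scores, control_scores)
print(f"AUROC difference: {mean_diff:.3f} (95% CI {low:.3f} to {high:.3f})")
```

If the interval excludes zero, the gap between the two models is unlikely to be a test-set sampling artifact. With real data, `scores_a` and `scores_b` would be the `predict_proba(...)[:, 1]` outputs of the two models on the same `X_test` rows.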

In [None]:
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from sklearn.metrics import (roc_auc_score, classification_report,
                              average_precision_score, precision_score,
                              recall_score, f1_score)
import shap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import pickle
from scipy.stats import pearsonr

## **CONFIGURATION**

In [None]:
BASE_DIR = Path('/content/drive/MyDrive')
DATA_DIR = BASE_DIR / 'amr_features'
ROARY_DIR = BASE_DIR / 'pangenome_features'
MODEL_DIR = BASE_DIR / 'models'
RESULTS_DIR = BASE_DIR / 'results'

MODEL_DIR.mkdir(exist_ok=True)
RESULTS_DIR.mkdir(exist_ok=True)

## **HELPER FUNCTIONS**

In [None]:
def standardize_sample_id(sample_id):
    """Standardize sample IDs"""
    sample_id = str(sample_id).strip()
    if '#' in sample_id:
        return sample_id
    parts = sample_id.rsplit('_', 1)
    return f"{parts[0]}#{parts[1]}" if len(parts) == 2 else sample_id

def fix_sample_ids(df):
    """Apply standardization to DataFrame"""
    df.index = df.index.map(standardize_sample_id)
    return df
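
A few example inputs make the ID convention easier to see. The snippet below repeats `standardize_sample_id` so it runs on its own; the example IDs other than the `11657_5#25` pattern are made up for illustration. Note the last case: any ID containing an underscore has its final underscore replaced, which may or may not be wanted for IDs that are not lane accessions.

```python
def standardize_sample_id(sample_id):
    """Standardize sample IDs (copied from the helper cell so this snippet is self-contained)"""
    sample_id = str(sample_id).strip()
    if '#' in sample_id:
        return sample_id
    parts = sample_id.rsplit('_', 1)
    return f"{parts[0]}#{parts[1]}" if len(parts) == 2 else sample_id

examples = ['11657_5_25', '11657_5#25', ' 11657_5_26 ', 'ERR123456', 'sample_A']
converted = [standardize_sample_id(s) for s in examples]

for raw, fixed in zip(examples, converted):
    print(f"{raw!r:>16} -> {fixed!r}")
```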

### **Controlling for confounding factors in genomic epidemiology, specifically population structure (clonality)**

The initial preprocessing step (using `filter_by_correlation_v3`) already removes genes that are highly correlated with `Tier 1A` (known AMR genes) and flags correlation with `Tier 1B` (stress/metal genes) using a high threshold ($\rho \ge 0.90$).

The second, simpler filter, `remove_lineage_markers`, uses a lower threshold ($\rho \ge 0.70$) and serves a **different, complementary purpose** related to **clonality**.


### **`Why the Second Filter is Necessary (Controlling for Clonality)`**

**1. The Limitation of the Initial Filter ($\rho \ge 0.90$)**

The initial filter is designed to remove **near-redundant features**: novel genes that are almost perfectly linked ($\rho \ge 0.90$) to a known AMR gene. This prevents us from mistaking a co-occurring gene for a truly novel mechanism. It does not address genes that track a clonal lineage without being tied to any single AMR gene.

**2. The Role of the Second Filter ($\rho \ge 0.70$)**

The secondary filter, `remove_lineage_markers`, specifically targets **population structure**.

* **What it targets:** **Lineage markers** are genes or mutations that define specific clonal groups (strains) in our *E. coli* population. These markers are often conserved within a strain and typically have no direct functional link to antibiotic resistance.
* **The Problem (Confounding):** If a specific **clonal strain (e.g., ST131)** happens to be highly resistant to a drug, and it also happens to carry a novel gene 'X', the model may learn that $\text{Gene X} \rightarrow \text{Resistance}$. This is an **artifact of clonality**, not a functional link. The model has discovered a **lineage marker**, not a novel AMR gene.
* **The Mechanism:** Lineage markers are usually **chromosomal SNPs/genes** (like the ones we list: `rz`, `yedI`, etc.) that are co-inherited with all other features of that lineage. The secondary filter drops the listed markers outright, then checks each remaining novel gene for a high absolute Pearson correlation ($|\rho| \ge 0.70$) with any **`Tier 2 (Chromosomal SNPs/Mutations)`** feature, since those features are strong indicators of population structure.
* **Lower Threshold:** A correlation of $\rho=0.70$ is used because genes that define a lineage don't need to be perfectly co-inherited, but they do need to be **highly associated** with the clonal backbone features (the Tier 2 mutations). Removing these features reduces the risk that the model is misled by clonal background.

In essence, this is a layered cleanup:
* **First Layer ($\rho \ge 0.90$):** Eliminates redundancy with **known AMR genes** ($\text{Tier 1A}$) and flags correlation with stress/metal genes ($\text{Tier 1B}$).
* **Second Layer ($\rho \ge 0.70$):** Reduces confounding due to **Population Structure/Lineage** ($\text{Tier 2}$).

**Summary of the Function `remove_lineage_markers`**

The core function of this filter is to **decorrelate novel genes from chromosomal background features (Tier 2)**, so that the genes the model finds important are more likely to be associated with resistance itself and not just with strain type.

$$
\text{Focus of } \rho \ge 0.70 \text{ filter} = \text{Novel Gene} \leftrightarrow \text{Tier 2 (Chromosomal Lineage Markers)}
$$

This is a critical step to ensure that the genes nominated for validation are genuine AMR candidates and not merely phylogenetic signals.
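
The effect of the two thresholds can be seen on a toy presence/absence matrix. The sketch below is illustrative only: `max_abs_corr` is a hypothetical vectorized helper (not the notebook's function), and the gene and mutation names are invented. `geneX` is perfectly co-inherited with a lineage mutation, `geneZ` is missing from one isolate of that lineage, and `geneY` is independent of it.

```python
import numpy as np
import pandas as pd

def max_abs_corr(genes, background):
    """Max |Pearson r| of each gene column against any background column."""
    common = genes.index.intersection(background.index)
    g = genes.loc[common].to_numpy(dtype=float)
    b = background.loc[common].to_numpy(dtype=float)
    #z-score each column; constant columns give 0/0 = NaN, mapped to 0 below
    with np.errstate(invalid="ignore", divide="ignore"):
        gz = (g - g.mean(axis=0)) / g.std(axis=0)
        bz = (b - b.mean(axis=0)) / b.std(axis=0)
    corr = np.nan_to_num(gz.T @ bz / len(common))
    return pd.Series(np.abs(corr).max(axis=1), index=genes.columns)

isolates = [f"iso{i}" for i in range(8)]
toy_tier2 = pd.DataFrame({"lineage_mut": [1, 1, 1, 1, 0, 0, 0, 0]}, index=isolates)
toy_genes = pd.DataFrame({
    "geneX": [1, 1, 1, 1, 0, 0, 0, 0],  #perfectly co-inherited with the lineage
    "geneY": [1, 0, 1, 0, 1, 0, 1, 0],  #independent of the lineage
    "geneZ": [1, 1, 1, 0, 0, 0, 0, 0],  #lineage-linked, but lost in one isolate
}, index=isolates)

scores = max_abs_corr(toy_genes, toy_tier2)
flagged_090 = scores[scores >= 0.90].index.tolist()
flagged_070 = scores[scores >= 0.70].index.tolist()
print(scores.round(3).to_dict())
print("flagged at 0.90:", flagged_090)
print("flagged at 0.70:", flagged_070)
```

Here `geneZ` has $\rho \approx 0.77$ with the lineage mutation, so it passes a 0.90 filter but is caught at 0.70, which is the case the lower threshold is meant to handle.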

In [None]:
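def remove_lineage_markers(roary_df, tier2_df, correlation_threshold=0.7):
    """
    Remove genes that are known lineage markers, or that are highly
    correlated (|Pearson r| >= correlation_threshold) with any Tier 2
    feature, to reduce population structure confounding.

    Returns the filtered DataFrame and the list of removed gene names.
    """
    #known lineage markers (from mutation analysis)
    potential_markers = ['rz', 'yedI', 'nmpC', 'gatD', 'betU', 'yeeO']

    #pearsonr pairs values by position, so align both frames on shared samples first
    common = roary_df.index.intersection(tier2_df.index)
    roary_aligned = roary_df.loc[common]
    tier2_aligned = tier2_df.loc[common]

    genes_to_keep = []
    removed_genes = []

    for gene in roary_df.columns:
        if gene in potential_markers:
            removed_genes.append(gene)
            continue

        #check correlation with Tier 2 features (chromosomal mutations)
        max_corr = 0
        for tier2_feature in tier2_df.columns:
            try:
                corr, _ = pearsonr(roary_aligned[gene], tier2_aligned[tier2_feature])
            except ValueError:
                #raised e.g. when there are too few samples to correlate
                continue
            #constant columns yield NaN; treat that as "no correlation"
            if not np.isnan(corr):
                max_corr = max(max_corr, abs(corr))

        if max_corr < correlation_threshold:
            genes_to_keep.append(gene)
        else:
            removed_genes.append(gene)

    print(f"  Correlation filter: kept {len(genes_to_keep)}, removed {len(removed_genes)}")
    return roary_df[genes_to_keep], removed_genes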

## **LOAD DATA**

In [None]:
#load Tier 2 (genes + mutations combined)
print("Loading Tier 2 (genes + mutations)...")
tier2_df = pd.read_csv(DATA_DIR / 'tier2_amr_genes_plus_mutations.csv', index_col=0)
tier2_df = fix_sample_ids(tier2_df)
print(f"  Tier 2 shape: {tier2_df.shape}")

#load Tier 1B (stress/metal genes) and append its columns to the Tier 2 matrix
tier1b_stress = pd.read_csv(DATA_DIR / 'tier1b_stress_genes.csv', index_col='ISOLATE_ID')
tier1b_stress = fix_sample_ids(tier1b_stress)

#NOTE: join='outer' keeps isolates present in only one table; fillna(0) then
#marks every feature that is missing for those isolates as absent
tier2_df = pd.concat([tier2_df, tier1b_stress], axis=1, join='outer')
tier2_df = tier2_df.fillna(0)

Loading Tier 2 (genes + mutations)...
  Tier 2 shape: (1089, 1236)
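
The Tier 2 table has 1089 rows here, while 1651 samples match the phenotypes after the merge, which suggests the outer join adds isolates that only have Tier 1B data. For those isolates, `fillna(0)` turns "not measured" into "absent" for every Tier 2 column. The toy sketch below shows the effect and one way to count the affected isolates; the frames are invented, and only the column names are borrowed from the real tables.

```python
import pandas as pd

#toy stand-ins for the two feature tables (values are made up)
tier2_toy = pd.DataFrame({"gyrA_R91C": [1, 0]}, index=["A#1", "A#2"])
tier1b_toy = pd.DataFrame({"terB": [1, 0, 1]}, index=["A#1", "A#2", "A#3"])

combined = pd.concat([tier2_toy, tier1b_toy], axis=1, join="outer")

#isolates with Tier 1B data but no Tier 2 data
only_tier1b = tier1b_toy.index.difference(tier2_toy.index)
print(f"Isolates missing Tier 2 features: {len(only_tier1b)} -> {list(only_tier1b)}")

filled = combined.fillna(0)
print(filled)
```

If the count is large on the real data, `join='inner'` (or restricting to `tier2_df.index` before filling) would be the more conservative choice, at the cost of fewer samples.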


In [None]:
tier2_df.head()

Unnamed: 0,ampC_T86A,ampC_A356S,gyrA_I198L,gyrA_S111P,gyrA_R91C,robA_L279V,ftsI_H425Q,thyA_Q39E,marR_I137L,thyA_L184F,...,silS,terB,terC,terD,terE,terW,terZ,trxLHR,yfdX1,yfdX2
11657_5#25,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11657_5#26,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11657_5#27,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11657_5#29,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
11657_5#30,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [None]:
#load phenotypes
print("Loading phenotypes...")
phenotypes = pd.read_csv(BASE_DIR / 'data' / 'E.coli' / 'phenotypic.csv')

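if 'Isolate' in phenotypes.columns:
    phenotypes.set_index('Isolate', inplace=True)
elif 'Lane.accession' in phenotypes.columns:
    phenotypes.set_index('Lane.accession', inplace=True)
else:
    #without an ID column the index stays numeric and no samples would match
    raise KeyError("No 'Isolate' or 'Lane.accession' column in phenotypic.csv")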

phenotypes = fix_sample_ids(phenotypes)
print(f"  Phenotype shape: {phenotypes.shape}")

Loading phenotypes...
  Phenotype shape: (1936, 16)


In [None]:
#find common samples
common_samples = tier2_df.index.intersection(phenotypes.index)
print(f"Common samples: {len(common_samples)}")

if len(common_samples) < 100:
    raise ValueError(f"Only {len(common_samples)} matching samples! Check IDs.")

Common samples: 1651


## **TRAIN TIER 2 + TIER 1B CONTROL MODELS (With `Stress` Genes)**

In [None]:
results = []
novel_gene_discoveries = {} #kept for structure, but will remain empty

for drug in ['AMX', 'AMC', 'CIP']:
    print(f"\n{'='*80}")
    print(f"TIER 2 + TIER 1B MODEL: {drug}")
    print(f"{'='*80}")

    #the feature matrix X is the combined Tier 2/Tier 1B dataframe,
    #restricted to isolates that also have phenotype data
    X_drug = tier2_df.loc[common_samples].copy()

    #prepare labels (intermediate 'I' is grouped with susceptible)
    y = phenotypes.loc[X_drug.index, drug].map({'R': 1, 'S': 0, 'I': 0})
    y = y.dropna()
    X_drug = X_drug.loc[y.index] #re-align X after dropping NaNs in y

    print(f"\nData prepared:")
    print(f"  Features: {X_drug.shape[1]}")
    print(f"  Samples: {len(X_drug)}")
    print(f"  Resistant:   {(y==1).sum()} ({(y==1).sum()/len(y)*100:.1f}%)")
    print(f"  Susceptible: {(y==0).sum()} ({(y==0).sum()/len(y)*100:.1f}%)")

    #train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X_drug, y, test_size=0.2, random_state=42, stratify=y
    )

    #class weights
    n_resistant = (y_train == 1).sum()
    n_susceptible = (y_train == 0).sum()
    scale_pos_weight = n_susceptible / n_resistant if n_resistant > 0 else 1.0

    print(f"\nTraining configuration:")
    print(f"  Train: {len(X_train)} (R={n_resistant}, S={n_susceptible})")
    print(f"  Test:  {len(X_test)}")
    print(f"  scale_pos_weight: {scale_pos_weight:.2f}")

    #train with regularization (same parameters as before)
    model = XGBClassifier(
        max_depth=5,
        n_estimators=100,
        learning_rate=0.1,
        scale_pos_weight=scale_pos_weight,
        min_child_weight=3,
        gamma=0.1,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_alpha=0.1,
        reg_lambda=1.0,
        random_state=42,
        eval_metric='auc',
        verbosity=0
    )

    model.fit(X_train, y_train)

    #evaluate
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    y_pred = model.predict(X_test)

    auroc = roc_auc_score(y_test, y_pred_proba)
    auprc = average_precision_score(y_test, y_pred_proba)
    precision = precision_score(y_test, y_pred, zero_division=0)
    recall = recall_score(y_test, y_pred, zero_division=0)
    f1 = f1_score(y_test, y_pred, zero_division=0)

    print(f"\n{'='*80}")
    print(f"PERFORMANCE")
    print(f"{'='*80}")
    print(f"AUROC:     {auroc:.4f}")
    print(f"AUPRC:     {auprc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")

    print(f"\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=['S', 'R'],
                                zero_division=0))

    #cross-validation
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    cv_scores = cross_val_score(model, X_drug, y, cv=cv,
                                 scoring='roc_auc', n_jobs=-1)
    print(f"\n5-Fold CV: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")

    #SHAP analysis
    print(f"\nComputing SHAP values...")
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)

    #feature importance (All features are Tier 2/1B now)
    feature_importance = pd.DataFrame({
        'feature': X_drug.columns,
        'shap_importance': np.abs(shap_values).mean(axis=0)
    }).sort_values('shap_importance', ascending=False)

    #all features are of the same type for this run
    feature_importance['feature_type'] = 'Tier2+Tier1B'

    print(f"\n{'='*80}")
    print(f"TOP 20 FEATURES (TIER 2 + TIER 1B)")
    print(f"{'='*80}")
    print(feature_importance.head(20).to_string(index=False))

    #save results
    results.append({
        'drug': drug,
        'model_type': 'Tier2_plus_Tier1B', #changed model type here
        'n_features': len(X_drug.columns),
        'n_tier2': tier2_df.shape[1],
        'n_novel': 0, #explicitly set to zero
        'n_samples': len(X_drug),
        'auroc': auroc,
        'auprc': auprc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'cv_mean': cv_scores.mean(),
        'cv_std': cv_scores.std(),
        'tier2_contribution': 1.0, #100% contribution from Tier 2 + Tier 1B
        'novel_contribution': 0.0
    })

    #save model and importance (using a different file name)
    model.save_model(MODEL_DIR / f'tier2_plus_tier1b_model_{drug}.json')
    feature_importance.to_csv(
        RESULTS_DIR / f'tier2_plus_tier1b_feature_importance_{drug}.csv',
        index=False
    )

    #SHAP plot (for Tier 2 + Tier 1B)
    plt.figure(figsize=(10, 8))
    shap.summary_plot(shap_values, X_test, show=False, max_display=20)
    plt.title(f'{drug} - Tier 2 + Tier 1B SHAP Summary')
    plt.tight_layout()
    plt.savefig(RESULTS_DIR / f'tier2_plus_tier1b_shap_{drug}.png',
                dpi=300, bbox_inches='tight')
    plt.close()


TIER 2 + TIER 1B MODEL: AMX

Data prepared:
  Features: 1286
  Samples: 1089
  Resistant:   659 (60.5%)
  Susceptible: 430 (39.5%)

Training configuration:
  Train: 871 (R=527, S=344)
  Test:  218
  scale_pos_weight: 0.65

PERFORMANCE
AUROC:     0.9311
AUPRC:     0.9657
Precision: 0.9737
Recall:    0.8409
F1 Score:  0.9024

Classification Report:
              precision    recall  f1-score   support

           S       0.80      0.97      0.87        86
           R       0.97      0.84      0.90       132

    accuracy                           0.89       218
   macro avg       0.89      0.90      0.89       218
weighted avg       0.90      0.89      0.89       218


5-Fold CV: 0.9467 ± 0.0062

Computing SHAP values...

TOP 20 FEATURES (TIER 2 + TIER 1B)
    feature  shap_importance feature_type
      TEM-4         1.534967 Tier2+Tier1B
   blaTEM-1         0.861555 Tier2+Tier1B
       sul1         0.274132 Tier2+Tier1B
 ftsI_L192F         0.172736 Tier2+Tier1B
  gyrA_V85F         0.1
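
As a sanity check on the AMX output above, the reported metrics can be reproduced from a single confusion matrix. The counts below are inferred from the printed precision, recall and support values (they are not printed by the notebook itself), so treat them as a consistency check and not as logged output.

```python
#confusion counts inferred from the AMX report (support: S=86, R=132)
tp, fn = 111, 21  #resistant isolates: 111 + 21 = 132
fp, tn = 3, 83    #susceptible isolates: 3 + 83 = 86

precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)
accuracy = (tp + tn) / (tp + fn + fp + tn)

#class weight used for training: n_susceptible / n_resistant in the train split
scale_pos_weight = 344 / 527

print(f"Precision: {precision:.4f}")
print(f"Recall:    {recall:.4f}")
print(f"F1 Score:  {f1:.4f}")
print(f"Accuracy:  {accuracy:.2f}")
print(f"scale_pos_weight: {scale_pos_weight:.2f}")
```

Because resistant isolates are the majority for AMX, `scale_pos_weight` is below 1, which down-weights the positive class; this is in line with the high precision and lower recall on `R` seen in the report.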

### **TIER 2 + TIER 1B RESULTS SUMMARY**

In [None]:
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))

results_df.to_csv(RESULTS_DIR / 'tier2_plus_tier1b_results_summary.csv', index=False)

print(f"\nTier 2 + Tier 1B analysis complete!")
print(f"Results saved to: {RESULTS_DIR}")

drug        model_type  n_features  n_tier2  n_novel  n_samples    auroc    auprc  precision   recall       f1  cv_mean   cv_std  tier2_contribution  novel_contribution
 AMX Tier2_plus_Tier1B        1286     1286        0       1089 0.931113 0.965674   0.973684 0.840909 0.902439 0.946707 0.006225                 1.0                 0.0
 AMC Tier2_plus_Tier1B        1286     1286        0       1650 0.796630 0.631623   0.565217 0.650000 0.604651 0.797948 0.019398                 1.0                 0.0
 CIP Tier2_plus_Tier1B        1286     1286        0       1650 0.941914 0.854599   0.770270 0.791667 0.780822 0.949577 0.012414                 1.0                 0.0

Tier 2 + Tier 1B analysis complete!
Results saved to: /content/drive/MyDrive/results
