# **TIER 3 (T2+T1B+NOVEL) MODEL SETUP**

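$\text{Tier 2} + \text{Tier 1B} + \textbf{Novel Genes}$ — Tier 2 AMR genes and chromosomal mutations, plus Tier 1B stress/metal genes, plus per-drug novel Roary pangenome genes, modelled separately for AMX, AMC and CIP.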

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split, StratifiedKFold, cross_val_score
from xgboost import XGBClassifier
from sklearn.metrics import roc_auc_score, average_precision_score, precision_score, recall_score, f1_score, classification_report
import numpy as np
import shap
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import pearsonr

In [None]:
def standardize_sample_id(sample_id):
    """Normalise an isolate identifier to the ``<lane>#<tag>`` form.

    The value is coerced to ``str`` and surrounding whitespace is removed.
    IDs that already contain ``#`` are returned unchanged; otherwise the LAST
    underscore (if any) is replaced by ``#``, e.g. ``"11657_5_25"`` becomes
    ``"11657_5#25"``. IDs without any underscore are returned as-is.
    """
    cleaned = str(sample_id).strip()
    if '#' not in cleaned:
        head, sep, tail = cleaned.rpartition('_')
        if sep:
            return head + '#' + tail
    return cleaned

def fix_sample_ids(df):
    """Return a copy of ``df`` with every index label passed through
    ``standardize_sample_id``.

    The caller's DataFrame is not modified, so re-running a cell that holds a
    reference to the raw frame cannot double-apply the mapping.

    Parameters
    ----------
    df : pd.DataFrame
        Frame indexed by isolate/sample ID.

    Returns
    -------
    pd.DataFrame
        Copy of ``df`` whose index holds the standardized IDs.
    """
    fixed = df.copy()
    fixed.index = df.index.map(standardize_sample_id)
    return fixed

In [None]:
def remove_lineage_markers(roary_df, tier2_df, correlation_threshold=0.7,
                           potential_markers=('rz', 'yedI', 'nmpC', 'gatD', 'betU', 'yeeO')):
    """
    Remove genes highly correlated with known lineage markers
    to reduce population structure confounding.

    A Roary gene is removed when either
      * its name is in ``potential_markers`` (known lineage markers from the
        mutation analysis), or
      * its absolute Pearson correlation with ANY numeric ``tier2_df`` column
        is >= ``correlation_threshold``.

    Correlations are computed only over isolates present in BOTH frames,
    matched by index label (not by row position), so the two frames may have
    different row counts and orders. A column that is constant over the
    shared isolates has an undefined correlation and is treated as 0.
    Non-numeric columns are never correlation-filtered.

    Parameters
    ----------
    roary_df : pd.DataFrame
        Isolates x Roary genes.
    tier2_df : pd.DataFrame
        Isolates x Tier 2 (+ Tier 1B) features.
    correlation_threshold : float
        Genes with |r| >= this value against any Tier 2 feature are removed.
    potential_markers : iterable of str
        Gene names that are always removed.

    Returns
    -------
    (pd.DataFrame, list)
        ``roary_df`` restricted to the kept gene columns (all rows retained),
        and the list of removed gene names in original column order.
    """
    markers = set(potential_markers)

    def _numeric_unique(df):
        # Correlation needs numeric data and one row per isolate label.
        numeric = df.select_dtypes(include=['number', 'bool'])
        return numeric[~numeric.index.duplicated(keep='first')]

    genes = _numeric_unique(roary_df)
    feats = _numeric_unique(tier2_df)
    shared = genes.index.intersection(feats.index)

    # gene -> max |r| over all Tier 2 features (genes absent here count as 0)
    max_corr = {}
    if len(shared) >= 2 and genes.shape[1] > 0 and feats.shape[1] > 0:
        g = genes.loc[shared].to_numpy(dtype=float)
        f = feats.loc[shared].to_numpy(dtype=float)
        g = g - g.mean(axis=0)
        f = f - f.mean(axis=0)
        # Pearson r for every (gene, feature) pair in one matrix product;
        # zero-variance columns give 0/0 -> NaN, which is mapped to 0 below.
        with np.errstate(divide='ignore', invalid='ignore'):
            corr = (g.T @ f) / np.outer(np.linalg.norm(g, axis=0),
                                        np.linalg.norm(f, axis=0))
        abs_corr = np.nan_to_num(np.abs(corr), nan=0.0)
        max_corr = dict(zip(genes.columns, abs_corr.max(axis=1)))

    genes_to_keep = []
    removed_genes = []
    for gene in roary_df.columns:
        if gene in markers or max_corr.get(gene, 0.0) >= correlation_threshold:
            removed_genes.append(gene)
        else:
            genes_to_keep.append(gene)

    print(f"  Correlation filter: kept {len(genes_to_keep)}, removed {len(removed_genes)}")
    return roary_df[genes_to_keep], removed_genes

In [None]:
# Every input and output lives under the mounted Google Drive root.
BASE_DIR = Path('/content/drive/MyDrive')

# Inputs: AMR/stress feature tables and per-drug Roary gene matrices.
DATA_DIR, ROARY_DIR = BASE_DIR / 'amr_features', BASE_DIR / 'pangenome_features'

# Outputs: created if absent. mkdir is called without parents=True, so a
# missing BASE_DIR raises FileNotFoundError instead of being created.
MODEL_DIR, RESULTS_DIR = BASE_DIR / 'models', BASE_DIR / 'results'
MODEL_DIR.mkdir(exist_ok=True)
RESULTS_DIR.mkdir(exist_ok=True)

In [None]:
#1. Load Tier 2 (genes + mutations combined)
tier2_df = fix_sample_ids(
    pd.read_csv(DATA_DIR / 'tier2_amr_genes_plus_mutations.csv', index_col=0)
)

#2. Load Tier 1B Stress/Metal Genes (same DATA_DIR as Tier 2)
tier1b_stress = fix_sample_ids(
    pd.read_csv(DATA_DIR / 'tier1b_stress_genes.csv', index_col='ISOLATE_ID')
)

# A feature name present in both tables would produce two identically named columns.
shared_cols = tier2_df.columns.intersection(tier1b_stress.columns)
if not shared_cols.empty:
    raise ValueError(f"Tier 2 and Tier 1B share feature names: {list(shared_cols)}")

#3. Concatenate Tier 2 and Tier 1B to create the expanded TIER 2 baseline.
# NOTE(review): outer join + fillna(0) encodes "isolate missing from one table"
# the same as "feature absent" - confirm this is intended.
tier2_df = pd.concat([tier2_df, tier1b_stress], axis=1, join='outer').fillna(0)
print(f"  Expanded Tier 2 (T2 + T1B) feature count: {tier2_df.shape[1]}")

  Expanded Tier 2 (T2 + T1B) feature count: 1286


In [None]:
#4. Load phenotypes and find common samples
phenotypes = pd.read_csv(BASE_DIR / 'data' / 'E.coli' / 'phenotypic.csv')

# Accept either isolate-ID column; 'Isolate' takes precedence when both exist.
id_col = next((c for c in ('Isolate', 'Lane.accession') if c in phenotypes.columns), None)
if id_col is None:
    raise KeyError("phenotypic.csv needs an 'Isolate' or 'Lane.accession' column to match isolates")
phenotypes = fix_sample_ids(phenotypes.set_index(id_col))

common_samples = tier2_df.index.intersection(phenotypes.index)
if common_samples.empty:
    raise ValueError("No isolate IDs shared between Tier 2 features and phenotypes - check ID formats")

#align Tier 2 to common samples
tier2_df = tier2_df.loc[common_samples]
print(f"  Final Aligned Tier 2 shape: {tier2_df.shape}")

  Final Aligned Tier 2 shape: (1651, 1286)


## **TIER 3 (T2 + T1B + NOVEL) MODEL TRAINING LOOP**

In [None]:
results = []
novel_gene_discoveries = {}

# Phenotype -> label encoding. Intermediate ('I') isolates are grouped with
# susceptible ('S'); any other value (including missing) maps to NaN and the
# isolate is dropped for that drug.
LABEL_MAP = {'R': 1, 'S': 0, 'I': 0}

for drug in ['AMX', 'AMC', 'CIP']:
    print(f"\n{'='*80}")
    print(f"TIER 3 MODEL: {drug} (T2 + T1B + Novel)")
    print(f"{'='*80}")

    #load Roary filtered genes for this drug (Novel Genes)
    # NOTE(review): this file is produced upstream; if its top-500 selection used the
    # phenotypes of all isolates, the test split below is not fully held out - confirm.
    roary_file = ROARY_DIR / f'roary_filtered_{drug}_top500_decorrelated_v2.csv'

    if not roary_file.exists():
        print(f"Roary file not found: {roary_file}. Skipping {drug}")
        continue

    print(f"Loading Roary genes from: {roary_file.name}")
    roary_df = fix_sample_ids(pd.read_csv(roary_file, index_col=0))
    print(f"  Roary shape: {roary_df.shape}")

    #remove lineage markers (correlation check uses the expanded tier2_df);
    #the function itself prints the kept/removed counts
    print(f"Applying lineage marker filter...")
    roary_filtered, removed = remove_lineage_markers(
        roary_df, tier2_df, correlation_threshold=0.7
    )
    #`removed` is the list of dropped gene names, not a count
    print(f"  Removed genes ({len(removed)}): {removed}")
    print(f"  Filtered Roary shape: {roary_filtered.shape}")

    #a gene name present in both tables would yield two identically named columns
    #and would be labelled 'Tier2+Tier1B' in the attribution below
    name_clash = tier2_df.columns.intersection(roary_filtered.columns)
    if not name_clash.empty:
        raise ValueError(
            f"{drug}: Roary gene names collide with Tier 2 + Tier 1B features: {list(name_clash)}"
        )

    #combine Expanded Tier 2 + Filtered Roary (Novel Genes); the inner join keeps
    #only isolates present in both tables
    tier3_combined = pd.concat([tier2_df, roary_filtered], axis=1, join='inner')
    print(f"\nTier 3 combined shape: {tier3_combined.shape}")
    print(f"  Tier 2 + Tier 1B features: {tier2_df.shape[1]}")
    print(f"  Novel genes:     {roary_filtered.shape[1]}")
    print(f"  Total features:  {tier3_combined.shape[1]}")

    #PREPARE LABELS AND ALIGN FEATURES
    #labels for the isolates in the combined matrix; phenotype values outside
    #LABEL_MAP (including NaN) become NaN and are dropped
    y = phenotypes.loc[tier3_combined.index, drug].map(LABEL_MAP).dropna()

    #y.index is a subset of tier3_combined.index, so this alignment is safe
    X_drug = tier3_combined.loc[y.index]

    print(f"\nData prepared:")
    print(f"  Samples: {len(X_drug)}")
    print(f"  Resistant:   {(y==1).sum()} ({(y==1).sum()/len(y)*100:.1f}%)")
    print(f"  Susceptible: {(y==0).sum()} ({(y==0).sum()/len(y)*100:.1f}%)")

    #train/test split
    X_train, X_test, y_train, y_test = train_test_split(
        X_drug, y, test_size=0.2, random_state=42, stratify=y
    )

    # Class weights: ratio of negatives to positives in the training split
    n_resistant = (y_train == 1).sum()
    n_susceptible = (y_train == 0).sum()
    scale_pos_weight = n_susceptible / n_resistant if n_resistant > 0 else 1.0

    print(f"\nTraining configuration:")
    print(f"  Train: {len(X_train)} (R={n_resistant}, S={n_susceptible})")
    print(f"  Test:  {len(X_test)}")
    print(f"  scale_pos_weight: {scale_pos_weight:.2f}")

    # Train with regularization
    model = XGBClassifier(
        max_depth=5,
        n_estimators=100,
        learning_rate=0.1,
        scale_pos_weight=scale_pos_weight,
        min_child_weight=3,
        gamma=0.1,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_alpha=0.1,
        reg_lambda=1.0,
        random_state=42,
        eval_metric='auc',
        verbosity=0
    )

    model.fit(X_train, y_train)

    # Evaluate on the held-out split
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    y_pred = model.predict(X_test)

    auroc = roc_auc_score(y_test, y_pred_proba)
    auprc = average_precision_score(y_test, y_pred_proba)
    precision = precision_score(y_test, y_pred, zero_division=0)
    recall = recall_score(y_test, y_pred, zero_division=0)
    f1 = f1_score(y_test, y_pred, zero_division=0)

    print(f"\n{'='*80}")
    print(f"PERFORMANCE")
    print(f"{'='*80}")
    print(f"AUROC:     {auroc:.4f}")
    print(f"AUPRC:     {auprc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")

    print(f"\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=['S', 'R'],
                                zero_division=0))

    # Cross-validation over all labelled isolates for this drug (train + test)
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    cv_scores = cross_val_score(model, X_drug, y, cv=cv,
                                 scoring='roc_auc', n_jobs=-1)
    print(f"\n5-Fold CV: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")

    # SHAP analysis on the test split
    print(f"\nComputing SHAP values...")
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)

    # Feature importance = mean |SHAP| per feature
    feature_importance = pd.DataFrame({
        'feature': X_drug.columns,
        'shap_importance': np.abs(shap_values).mean(axis=0)
    }).sort_values('shap_importance', ascending=False)

    # Separate novel genes from Tier 2 (which now includes T1B)
    feature_importance['feature_type'] = feature_importance['feature'].apply(
        lambda x: 'Tier2+Tier1B' if x in tier2_df.columns else 'Novel'
    )

    print(f"\n{'='*80}")
    print(f"TOP 20 FEATURES (ALL)")
    print(f"{'='*80}")
    print(feature_importance.head(20).to_string(index=False))

    # Novel genes only
    novel_features = feature_importance[
        feature_importance['feature_type'] == 'Novel'
    ].head(20)

    print(f"\n{'='*80}")
    print(f"TOP 20 NOVEL GENES")
    print(f"{'='*80}")
    print(novel_features[['feature', 'shap_importance']].to_string(index=False))

    # Store for validation pipeline
    novel_gene_discoveries[drug] = novel_features

    # Feature type contribution (share of total mean |SHAP|)
    tier2_importance = feature_importance[
        feature_importance['feature_type'] == 'Tier2+Tier1B'
    ]['shap_importance'].sum()

    novel_importance = feature_importance[
        feature_importance['feature_type'] == 'Novel'
    ]['shap_importance'].sum()

    total_importance = tier2_importance + novel_importance

    print(f"\n{'='*80}")
    print(f"FEATURE CONTRIBUTION")
    print(f"{'='*80}")
    print(f"Tier 2 + Tier 1B: {tier2_importance:.2f} "
          f"({tier2_importance/total_importance*100:.1f}%)")
    print(f"Novel genes:              {novel_importance:.2f} "
          f"({novel_importance/total_importance*100:.1f}%)")

    # Save results
    results.append({
        'drug': drug,
        'model_type': 'Tier3_T2+T1B_Novel',
        'n_features': len(X_drug.columns),
        'n_tier2_plus_t1b': tier2_df.shape[1],
        'n_novel': roary_filtered.shape[1],
        'n_samples': len(X_drug),
        'auroc': auroc,
        'auprc': auprc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'cv_mean': cv_scores.mean(),
        'cv_std': cv_scores.std(),
        'tier2_plus_t1b_contribution': tier2_importance/total_importance,
        'novel_contribution': novel_importance/total_importance
    })

    # Save model
    model.save_model(MODEL_DIR / f'tier3_t2_t1b_novel_model_{drug}.json')

    # Save feature importance
    feature_importance.to_csv(
        RESULTS_DIR / f'tier3_t2_t1b_novel_feature_importance_{drug}.csv',
        index=False
    )

    # Save novel genes separately
    novel_features.to_csv(
        RESULTS_DIR / f'tier3_t2_t1b_novel_genes_{drug}.csv',
        index=False
    )

    # SHAP plot (closed after saving to avoid accumulating figures in the loop)
    plt.figure(figsize=(10, 8))
    shap.summary_plot(shap_values, X_test, show=False, max_display=20)
    plt.title(f'{drug} - Tier 3 (T2+T1B+Novel) SHAP Summary')
    plt.tight_layout()
    plt.savefig(RESULTS_DIR / f'tier3_t2_t1b_novel_shap_{drug}.png',
                dpi=300, bbox_inches='tight')
    plt.close()


TIER 3 MODEL: AMX (T2 + T1B + Novel)
Loading Roary genes from: roary_filtered_AMX_top500_decorrelated_v2.csv
  Roary shape: (1089, 489)
Applying lineage marker filter...
  Correlation filter: kept 487, removed 2
  Correlation filter: kept 487, removed ['betU', 'rz']
  Filtered Roary shape: (1089, 487)

Tier 3 combined shape: (1089, 1773)
  Tier 2 + Tier 1B features: 1286
  Novel genes:     487
  Total features:  1773

Data prepared:
  Samples: 1089
  Resistant:   659 (60.5%)
  Susceptible: 430 (39.5%)

Data prepared:
  Samples: 1089
  Resistant:   659 (60.5%)
  Susceptible: 430 (39.5%)

Training configuration:
  Train: 871 (R=527, S=344)
  Test:  218
  scale_pos_weight: 0.65

PERFORMANCE
AUROC:     0.9265
AUPRC:     0.9634
Precision: 0.9492
Recall:    0.8485
F1 Score:  0.8960

Classification Report:
              precision    recall  f1-score   support

           S       0.80      0.93      0.86        86
           R       0.95      0.85      0.90       132

    accuracy            

### **FINAL SUMMARY**


In [None]:
# One row per drug model trained in the loop above.
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))

summary_path = RESULTS_DIR / 'tier3_t2_t1b_novel_results_summary.csv'
results_df.to_csv(summary_path, index=False)

# The analysis has already run at this point; report completion, not readiness.
print(f"\nTier 3 (T2+T1B+Novel) analysis complete: {len(results_df)} drug model(s); "
      f"summary written to {summary_path}")

drug         model_type  n_features  n_tier2_plus_t1b  n_novel  n_samples    auroc    auprc  precision   recall       f1  cv_mean   cv_std  tier2_plus_t1b_contribution  novel_contribution
 AMX Tier3_T2+T1B_Novel        1773              1286      487       1089 0.926533 0.963450   0.949153 0.848485 0.896000 0.944742 0.006056                     0.613682            0.386318
 AMC Tier3_T2+T1B_Novel        1773              1286      487       1089 0.883057 0.753697   0.631579 0.738462 0.680851 0.853462 0.019449                     0.586841            0.413159
 CIP Tier3_T2+T1B_Novel        1775              1286      489       1089 0.962454 0.952710   0.970588 0.916667 0.942857 0.978913 0.003845                     0.725260            0.274740

Tier 3 (T2+T1B+Novel) analysis ready to run!


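### **Extracting Novel Gene Annotations**

Look up the Roary metadata (annotation, isolate counts, group sizes) in `gene_presence_absence.csv` for the top novel genes reported by the AMX, AMC and CIP Tier 3 models.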

In [None]:
# Roary gene_presence_absence table; same location as phenotypic.csv (BASE_DIR/data/E.coli).
# low_memory=False makes pandas infer each column's dtype from the whole file.
pan = pd.read_csv(BASE_DIR / 'data' / 'E.coli' / 'gene_presence_absence.csv', low_memory=False)

In [None]:
# Top-20 novel genes per drug, transcribed from the training-loop output.
# NOTE(review): these lists are hand-copied and go stale if the Tier 3 models are
# retrained; the same rankings are saved as RESULTS_DIR / f'tier3_t2_t1b_novel_genes_{drug}.csv'.
amx_novel = ['ybeT','group_3820','yhjK_2','group_11074','group_24688','traL','hsdR','group_7497','group_5999','group_16945',
             'fucA_3','group_7221','group_16687','group_18988','group_11744','group_9728','sopB','group_3326','mtnA','group_16811']

amc_novel = ['group_3326','yfeA','group_24688','group_7896','wcaM','group_16955','insA_1','group_16990','group_17656',
             'group_312','group_16885','sopB','merA_2','group_13334','vapB_2','yhdJ_2','group_20927','group_6211','group_14353','fI']

cip_novel = ['group_16337','group_13365','group_8907','group_9126','ybcQ','group_16890','group_24688','chpB','chpS',
             'group_3308','group_20717','group_11114','sopA','yggM','group_12558','sopB','group_5974','yiaM_1','group_7650','ygjO']

# Unique combined list
novel_genes = sorted(set(amx_novel + amc_novel + cip_novel))

print(f"Unique novel genes: {len(novel_genes)}")
print(novel_genes)

# Per-gene metadata columns run up to and including 'Avg group size nuc'; the
# columns after it are per-isolate (e.g. '11657_5#25'), so a fixed [:16] slice
# pulled isolate columns into the annotation table.
meta_end = pan.columns.get_loc('Avg group size nuc') + 1
meta_cols = list(pan.columns[:meta_end])

# Report any requested gene that has no row in the pangenome table.
missing_genes = sorted(set(novel_genes) - set(pan['Gene']))
if missing_genes:
    print(f"Not found in gene_presence_absence.csv: {missing_genes}")

# Extract metadata rows for the novel genes, with a clean 0..n-1 index
novel_info = pan.loc[pan['Gene'].isin(novel_genes), meta_cols].reset_index(drop=True)

novel_info

Unique novel genes: 55
['chpB', 'chpS', 'fI', 'fucA_3', 'group_11074', 'group_11114', 'group_11744', 'group_12558', 'group_13334', 'group_13365', 'group_14353', 'group_16337', 'group_16687', 'group_16811', 'group_16885', 'group_16890', 'group_16945', 'group_16955', 'group_16990', 'group_17656', 'group_18988', 'group_20717', 'group_20927', 'group_24688', 'group_312', 'group_3308', 'group_3326', 'group_3820', 'group_5974', 'group_5999', 'group_6211', 'group_7221', 'group_7497', 'group_7650', 'group_7896', 'group_8907', 'group_9126', 'group_9728', 'hsdR', 'insA_1', 'merA_2', 'mtnA', 'sopA', 'sopB', 'traL', 'vapB_2', 'wcaM', 'ybcQ', 'ybeT', 'yfeA', 'yggM', 'ygjO', 'yhdJ_2', 'yhjK_2', 'yiaM_1']


Unnamed: 0,Gene,Non-unique Gene name,Annotation,No. isolates,No. sequences,Avg sequences per isolate,Genome Fragment,Order within Fragment,Accessory Fragment,Accessory Order with Fragment,QC,Min group size nuc,Max group size nuc,Avg group size nuc,11657_5#25,11657_5#26
0,ybeT,,outer membrane protein,797,4955,6.22,1,41088,,,,125,3299,1071,11657_5#25_01418\t11657_5#25_02329,
1,wcaM,,colanic acid biosynthesis protein,783,783,1.0,1,19199,1.0,28018.0,,131,1394,1391,11657_5#25_01555,11657_5#26_04551
2,ygjO,,Putative enzyme,747,747,1.0,1,2542,1.0,17085.0,,476,1136,1131,11657_5#25_02390,11657_5#26_01938
3,ybcQ,,lambdoid prophage DLP12 antitermination protein Q,738,738,1.0,1,10583,1.0,12578.0,,215,383,378,11657_5#25_04703,11657_5#26_04568
4,yggM,,putative alpha helix chain,733,733,1.0,1,2790,1.0,1658.0,,197,1007,1002,11657_5#25_03202,11657_5#26_01707
5,traL,,conjugative transfer protein,721,721,1.0,1,42837,1.0,36341.0,,311,314,311,,
6,yfeA,,putative diguanylate cyclase,680,680,1.0,1,17190,1.0,25881.0,,278,2195,2167,11657_5#25_03319,11657_5#26_00353
7,group_9728,nepI,ribonucleoside transporter,547,547,1.0,1,27238,1.0,24380.0,,455,1238,1225,11657_5#25_03652,11657_5#26_02965
8,hsdR,,GntR family transcriptional regulator,542,557,1.03,1,22417,1.0,3851.0,,854,3512,1588,,11657_5#26_02466
9,group_16337,yedI,putative methyl-independent mismatch repair pr...,506,506,1.0,1,33533,1.0,6381.0,,200,917,914,11657_5#25_01412,
