# **CHROMOSOMAL MUTATIONS ONLY ANALYSIS**

**CHROMOSOMAL MUTATIONS (fixed genomic positions, inherited vertically / clonally): mutation-only models**

**`Tests the hypothesis:`** "`Can mutations alone predict resistance?`"

This is `critical for understanding` what drives resistance to each drug:
1. CIP resistance (expected: HIGH performance - mutation-driven)
2. AMX/AMC resistance (expected: LOW performance - gene-driven)

In [None]:
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.metrics import (roc_auc_score, classification_report,
                              average_precision_score, precision_score,
                              recall_score, f1_score, roc_curve)
import shap
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pathlib import Path
import pickle

## **CONFIGURATION**

In [None]:
# All inputs and outputs live under a single Drive folder.
BASE_DIR = Path('/content/drive/MyDrive')
DATA_DIR = BASE_DIR.joinpath('amr_features')
MODEL_DIR = BASE_DIR.joinpath('models')
RESULTS_DIR = BASE_DIR.joinpath('results')

# Output folders are created on first run; existing folders are left alone.
for output_dir in (MODEL_DIR, RESULTS_DIR):
    output_dir.mkdir(exist_ok=True)

## **Standardize sample IDs**

In [None]:
def standardize_sample_id(sample_id):
    """
    Normalise a sample ID so the final separator is '#'.

    Examples:
    - ERR123456_1 ---> ERR123456#1
    - 11657_5_25 ---> 11657_5#25
    - ERR123456 ---> ERR123456

    The input is converted to ``str`` and stripped of surrounding
    whitespace first. IDs that already contain '#' are returned unchanged
    (after stripping), as are IDs with no underscore.
    """
    cleaned = str(sample_id).strip()

    # Already in the target format: nothing to rewrite.
    if '#' in cleaned:
        return cleaned

    # Split around the LAST underscore only; `separator` is empty when the
    # ID contains no underscore at all.
    prefix, separator, suffix = cleaned.rpartition('_')
    if not separator:
        return cleaned

    return f"{prefix}#{suffix}"

def fix_sample_ids(df):
    """
    Rewrite every label of ``df.index`` with ``standardize_sample_id``.

    The index is replaced on the DataFrame that was passed in (the input is
    modified), and that same DataFrame is returned for convenience.
    """
    standardized_index = df.index.map(standardize_sample_id)
    df.index = standardized_index
    return df

## **LOAD AND PREPARE DATA**

In [None]:
# Load the filtered SNP mutation matrix (chromosomal only); the first CSV
# column holds the sample IDs and becomes the index.
snp_csv = DATA_DIR.joinpath('tier2_snp_mutations_filtered.csv')
tier2_snps = pd.read_csv(snp_csv, index_col=0)
print(f"SNP Data Shape: {tier2_snps.shape}")
print(f"Sample ID format: {tier2_snps.index[0]}")

tier2_snps.head()

SNP Data Shape: (1089, 827)
Sample ID format: 11657_5_25


Unnamed: 0,ampC_T86A,ampC_A356S,gyrA_I198L,gyrA_S111P,gyrA_R91C,robA_L279V,ftsI_H425Q,thyA_Q39E,marR_I137L,thyA_L184F,...,gyrA_E505K,ampC_V194F,ampC_I300M,gyrB_V122L,tolC_*377*,tolC_E318D,gyrB_V371M,parC_A533S,phoP_G24R,phoQ_E336K
11657_5_25,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
11657_5_26,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
11657_5_27,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
11657_5_29,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0
11657_5_30,1,1,1,1,1,1,1,1,1,1,...,0,0,0,0,0,0,0,0,0,0


In [None]:
# Load the phenotype table; it is read with the default integer index and
# re-indexed by sample ID in a later cell.
phenotype_csv = BASE_DIR.joinpath('data', 'E.coli', 'phenotypic.csv')
phenotypes = pd.read_csv(phenotype_csv)
print(f"\nPhenotype Data Shape: {phenotypes.shape}")
phenotypes.head()


Phenotype Data Shape: (1936, 17)


Unnamed: 0,ENA.Accession.Number,Isolate,Lane.accession,Sequecning Status,Year,CTZ,CTX,AMP,AMX,AMC,TZP,CXM,CET,GEN,TBM,TMP,CIP
0,ERS356929,11657_5#10,ERR434268,Previously sequenced,2010.0,S,S,S,,S,S,S,I,S,S,S,S
1,ERS356935,11657_5#11,ERR434269,Previously sequenced,2010.0,S,S,R,,R,S,S,I,S,S,R,R
2,ERS356938,11657_5#12,ERR434270,Previously sequenced,2010.0,S,S,S,,S,S,S,S,S,S,S,S
3,ERS356941,11657_5#13,ERR434271,Previously sequenced,2010.0,S,S,R,,R,S,S,I,S,S,S,R
4,ERS356967,11657_5#14,ERR434272,Previously sequenced,2010.0,S,S,R,,S,S,S,S,S,S,R,S


In [None]:
# Index the phenotype table by sample ID, preferring 'Isolate' and falling
# back to 'Lane.accession'.
#
# The outer check makes this cell safe to re-run: once 'Isolate' has become
# the index it is no longer a column, so without the guard a second execution
# would fall through to the 'Lane.accession' branch and silently replace the
# index with a different ID scheme, breaking the sample matching below.
if phenotypes.index.name not in ('Isolate', 'Lane.accession'):
    if 'Isolate' in phenotypes.columns:
        phenotypes = phenotypes.set_index('Isolate')
    elif 'Lane.accession' in phenotypes.columns:
        phenotypes = phenotypes.set_index('Lane.accession')

print(f"Phenotype sample ID format: {phenotypes.index[0]}")

# Bring both tables to the same ID format. standardize_sample_id leaves IDs
# that already contain '#' unchanged, so repeating this step is harmless.
tier2_snps = fix_sample_ids(tier2_snps)
phenotypes = fix_sample_ids(phenotypes)

Phenotype sample ID format: 11657_5#10


In [None]:
# Sample IDs present in BOTH the SNP matrix and the phenotype table.
# Only these samples have features and labels, so everything downstream is
# restricted to them. The printed counts show how many samples each table
# contributes and how many survive the match.
common_samples = tier2_snps.index.intersection(phenotypes.index)
print(f"{'='*80}")
print(f"SAMPLE MATCHING")
print(f"{'='*80}")
print(f"SNP samples: {len(tier2_snps)}")
print(f"Phenotype samples: {len(phenotypes)}")
print(f"Common samples: {len(common_samples)}")

SAMPLE MATCHING
SNP samples: 1089
Phenotype samples: 1936
Common samples: 1089


In [None]:
def _separator_insensitive_key(sample_id):
    """Return the ID with '#' and '_' treated as the same separator."""
    return str(sample_id).strip().replace('#', '_')


def _unique_key_lookup(sample_ids):
    """Map separator-insensitive key -> original ID, dropping ambiguous keys.

    A key produced by more than one ID is discarded, so the lookup can only
    ever pair a sample with exactly one counterpart.
    """
    lookup = {}
    ambiguous = set()
    for sample_id in sample_ids:
        key = _separator_insensitive_key(sample_id)
        if key in lookup:
            ambiguous.add(key)
        lookup[key] = sample_id
    return {key: sid for key, sid in lookup.items() if key not in ambiguous}


# Fallback for the case where the standardized IDs barely overlap.
if len(common_samples) < 100:
    print("WARNING: Very few matching samples!")
    print("First 5 SNP IDs:", list(tier2_snps.index[:5]))
    print("First 5 Phenotype IDs:", list(phenotypes.index[:5]))

    #try alternative matching
    print("\nTrying alternative ID matching...")

    # Match on the FULL ID with '#'/'_' differences ignored. Truncating to the
    # text before the first separator (e.g. '11657_5#25' -> '11657') makes
    # every sample of a run share one key, which pairs SNP rows with the
    # phenotypes of unrelated samples.
    snp_lookup = _unique_key_lookup(tier2_snps.index)
    pheno_lookup = _unique_key_lookup(phenotypes.index)

    #find matches
    shared_keys = set(snp_lookup) & set(pheno_lookup)

    if len(shared_keys) > len(common_samples):
        print(f"Found {len(shared_keys)} matches using separator-insensitive IDs!")
        # SNP ID -> phenotype ID, one-to-one by construction
        id_mapping = {snp_lookup[key]: pheno_lookup[key] for key in shared_keys}

        #rename SNP indices
        tier2_snps = tier2_snps.rename(index=id_mapping)
        common_samples = tier2_snps.index.intersection(phenotypes.index)
        print(f"New common samples: {len(common_samples)}")

# Nothing downstream can run without at least one sample in both tables.
assert len(common_samples) > 0, "No matching samples found! Check ID formats."

In [None]:
#align datasets
# Select the shared samples from both tables, in the order of common_samples.
X = tier2_snps.loc[common_samples]
phenotypes_aligned = phenotypes.loc[common_samples]

# A duplicated sample ID in either table makes .loc return extra rows, which
# would silently pair features with the wrong labels. Fail loudly instead,
# before `phenotypes` is overwritten.
assert len(X) == len(common_samples), \
    "Duplicate sample IDs in SNP data: aligned rows != common samples."
assert len(phenotypes_aligned) == len(common_samples), \
    "Duplicate sample IDs in phenotype data: aligned rows != common samples."
assert X.index.equals(phenotypes_aligned.index), \
    "SNP and phenotype rows are not in the same sample order."

phenotypes = phenotypes_aligned

print(f"Final aligned data: {X.shape}")
print(f"Features (mutations): {X.shape[1]}")

Final aligned data: (1089, 827)
Features (mutations): 827


## **TRAIN MODELS FOR EACH DRUG**

In [None]:
# Per-drug metric rows and held-out predictions; both are read again by the
# summary and pickle cells further down the notebook.
results = []
all_predictions = {}

# Folds for the stratified cross-validation. Each class must have at least
# this many samples, otherwise the stratified split / AUROC cannot be
# computed reliably and the drug is skipped.
N_CV_SPLITS = 5

for drug in ['AMX', 'AMC', 'CIP']:
    print(f"\n{'='*80}")
    print(f"CHROMOSOMAL MUTATIONS MODEL: {drug}")
    print(f"{'='*80}")

    #prepare labels
    if drug not in phenotypes.columns:
        print(f"WARNING: {drug} not found in phenotype data!")
        print(f"Available columns: {phenotypes.columns.tolist()}")
        continue

    # Binary target: resistant (R) = 1; susceptible (S) and intermediate (I) = 0.
    # Any other value, including missing, maps to NaN and is dropped.
    y = phenotypes[drug].map({'R': 1, 'S': 0, 'I': 0})
    # map() yields floats whenever NaN is present; cast back to int so the
    # labels have the same dtype for every drug.
    y = y.dropna().astype(int)
    X_drug = X.loc[y.index]

    n_labelled = len(y)
    print(f"\nData prepared for {drug}")
    print(f"  Total samples: {len(X_drug)}")
    # The percentages divide by the label count, so they are only printed
    # when at least one label remains (otherwise: ZeroDivisionError).
    if n_labelled > 0:
        print(f"  Resistance counts:")
        print(f"    Resistant (R=1):   {(y==1).sum()} ({(y==1).sum()/len(y)*100:.1f}%)")
        print(f"    Susceptible (S=0): {(y==0).sum()} ({(y==0).sum()/len(y)*100:.1f}%)")

    if len(X_drug) < 50:
        print(f"WARNING: Only {len(X_drug)} samples for {drug}. Skipping.")
        continue

    # Stratified splitting and AUROC both need the two classes to be present
    # in every partition; skip drugs where that cannot be guaranteed.
    class_counts = y.value_counts()
    if len(class_counts) < 2 or class_counts.min() < N_CV_SPLITS:
        print(f"WARNING: {drug} needs at least {N_CV_SPLITS} samples in each class "
              f"(found {class_counts.to_dict()}). Skipping.")
        continue

    #train/test split (80/20, class proportions preserved)
    X_train, X_test, y_train, y_test = train_test_split(
        X_drug, y, test_size=0.2, random_state=42, stratify=y
    )

    #calculate class weights: ratio of susceptible to resistant training samples
    n_resistant = (y_train == 1).sum()
    n_susceptible = (y_train == 0).sum()
    scale_pos_weight = n_susceptible / n_resistant if n_resistant > 0 else 1.0

    print(f"\nTraining configuration:")
    print(f"  Train samples: {len(X_train)} (R={n_resistant}, S={n_susceptible})")
    print(f"  Test samples:  {len(X_test)}")
    print(f"  scale_pos_weight: {scale_pos_weight:.2f}")

    #train model
    model = XGBClassifier(
        max_depth=6,
        n_estimators=100,
        learning_rate=0.1,
        scale_pos_weight=scale_pos_weight,
        random_state=42,
        eval_metric='auc',
        verbosity=0
    )

    model.fit(X_train, y_train)

    #evaluate on the held-out test split
    y_pred_proba = model.predict_proba(X_test)[:, 1]
    y_pred = model.predict(X_test)

    auroc = roc_auc_score(y_test, y_pred_proba)
    auprc = average_precision_score(y_test, y_pred_proba)
    precision = precision_score(y_test, y_pred, zero_division=0)
    recall = recall_score(y_test, y_pred, zero_division=0)
    f1 = f1_score(y_test, y_pred, zero_division=0)

    print(f"\n{'='*80}")
    print(f"PERFORMANCE METRICS")
    print(f"{'='*80}")
    print(f"AUROC:     {auroc:.4f}")
    print(f"AUPRC:     {auprc:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")

    print(f"\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=['S', 'R'],
                                 zero_division=0))

    #cross-validation over ALL labelled samples for this drug (X_drug, y),
    #not only the training split, using the same model configuration
    print(f"\n{N_CV_SPLITS}-Fold Cross-Validation:")
    cv = StratifiedKFold(n_splits=N_CV_SPLITS, shuffle=True, random_state=42)
    cv_scores = cross_val_score(model, X_drug, y, cv=cv,
                                 scoring='roc_auc', n_jobs=-1)
    print(f"  Mean AUROC: {cv_scores.mean():.4f} ± {cv_scores.std():.4f}")
    print(f"  Scores: {[f'{s:.3f}' for s in cv_scores]}")

    #SHAP analysis on the test split
    print(f"\nComputing SHAP values...")
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_test)

    #rank features by mean absolute SHAP value
    # assumes shap_values is a 2-D (samples x features) array, as pairing it
    # with X_drug.columns below requires
    feature_importance = pd.DataFrame({
        'feature': X_drug.columns,
        'shap_importance': np.abs(shap_values).mean(axis=0)
    }).sort_values('shap_importance', ascending=False)

    print(f"\n{'='*80}")
    print(f"TOP 20 MUTATIONS")
    print(f"{'='*80}")
    print(feature_importance.head(20).to_string(index=False))

    #gene-level summary
    print(f"\n{'='*80}")
    print(f"MUTATIONS BY GENE")
    print(f"{'='*80}")

    #extract gene names from mutation features: text before the first '_'
    #(e.g. 'gyrA_S83L' -> 'gyrA')
    feature_importance['gene'] = feature_importance['feature'].str.split('_').str[0]
    gene_importance = feature_importance.groupby('gene')['shap_importance'].agg([
        'sum', 'mean', 'count'
    ]).sort_values('sum', ascending=False)

    print(gene_importance.head(10).to_string())

    #save results (one row per drug for the summary table)
    results.append({
        'drug': drug,
        'model_type': 'Mutations_Only',
        'n_features': len(X_drug.columns),
        'n_samples': len(X_drug),
        'auroc': auroc,
        'auprc': auprc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'cv_mean': cv_scores.mean(),
        'cv_std': cv_scores.std()
    })

    #save predictions
    all_predictions[drug] = {
        'y_test': y_test,
        'y_pred': y_pred,
        'y_pred_proba': y_pred_proba
    }

    #save model (the one fitted on the training split above)
    model.save_model(MODEL_DIR / f'mutations_only_model_{drug}.json')

    #save feature importance (includes the 'gene' column added above)
    feature_importance.to_csv(
        RESULTS_DIR / f'mutations_only_feature_importance_{drug}.csv',
        index=False
    )

    #save SHAP plot; the figure is closed afterwards so figures do not
    #accumulate across drugs
    plt.figure(figsize=(10, 8))
    shap.summary_plot(shap_values, X_test, show=False, max_display=20)
    plt.title(f'{drug} - Chromosomal Mutations SHAP Summary')
    plt.tight_layout()
    plt.savefig(RESULTS_DIR / f'mutations_only_shap_{drug}.png',
                dpi=300, bbox_inches='tight')
    plt.close()


CHROMOSOMAL MUTATIONS MODEL: AMX

Data prepared for AMX
  Total samples: 1089
  Resistance counts:
    Resistant (R=1):   659 (60.5%)
    Susceptible (S=0): 430 (39.5%)

Training configuration:
  Train samples: 871 (R=527, S=344)
  Test samples:  218
  scale_pos_weight: 0.65

PERFORMANCE METRICS
AUROC:     0.7416
AUPRC:     0.8170
Precision: 0.8229
Recall:    0.5985
F1 Score:  0.6930

Classification Report:
              precision    recall  f1-score   support

           S       0.57      0.80      0.66        86
           R       0.82      0.60      0.69       132

    accuracy                           0.68       218
   macro avg       0.69      0.70      0.68       218
weighted avg       0.72      0.68      0.68       218


5-Fold Cross-Validation:
  Mean AUROC: 0.6862 ± 0.0215
  Scores: ['0.657', '0.686', '0.668', '0.709', '0.711']

Computing SHAP values...

TOP 20 MUTATIONS
   feature  shap_importance
 gyrA_S83L         0.359015
 gyrA_V85F         0.183294
parC_L157V         0.

## **COMPARATIVE ANALYSIS: MUTATIONS vs GENES**

In [None]:
# One row per modelled drug, built from the dicts collected in `results`.
results_df = pd.DataFrame(results)
summary_text = results_df.to_string(index=False)
print(summary_text)

# Persist the same table as CSV, without the integer row index.
summary_csv = RESULTS_DIR / 'mutations_only_results_summary.csv'
results_df.to_csv(summary_csv, index=False)

drug     model_type  n_features  n_samples    auroc    auprc  precision   recall       f1  cv_mean   cv_std
 AMX Mutations_Only         827       1089 0.741587 0.816960   0.822917 0.598485 0.692982 0.686216 0.021464
 AMC Mutations_Only         827       1089 0.706536 0.545280   0.455446 0.707692 0.554217 0.710650 0.024922
 CIP Mutations_Only         827       1089 0.972375 0.958387   1.000000 0.916667 0.956522 0.979688 0.005370


## **SAVE PREDICTIONS FOR STATISTICAL TESTING**

In [None]:
# Pickle the per-drug test labels and predictions collected in
# `all_predictions` so later notebooks can load them.
predictions_path = RESULTS_DIR / 'mutations_only_predictions.pkl'
with predictions_path.open('wb') as handle:
    pickle.dump(all_predictions, handle)

print(f"All results saved to: {RESULTS_DIR}")

All results saved to: /content/drive/MyDrive/results
