# **SHAP Feature Importance Analysis**
Goal: Identify SNP positions most associated with AMR for each antibiotic and Compare to paper's EFS-identified genes

In [None]:
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
import shap
import matplotlib.pyplot as plt
from datetime import datetime

## **LOAD DATA**

In [None]:
data = pd.read_csv("/content/drive/MyDrive/ML-iAMR_Recreation/01_data/raw/giessen/cip_ctx_ctz_gen_multi_data.csv")
pheno = pd.read_csv("/content/drive/MyDrive/ML-iAMR_Recreation/01_data/raw/giessen/cip_ctx_ctz_gen_pheno.csv", index_col=0)

#get SNP position names (columns starting with 'X')
snp_positions = [col for col in data.columns if col.startswith('X')]
X = data[snp_positions].values

In [None]:
EXPERIMENT_ID = f"EXP-007-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
print(f"Experiment: {EXPERIMENT_ID}")

Experiment: EXP-007-20251106_212753


In [None]:
#paper's EFS-identified SNP positions (from Table 8)
PAPER_SNPS = {
    'CIP': [2655873, 4627668, 2017588, 4172893, 4605418, 4453756, 4477553, 4101302, 4466572, 4441487],
    'CTX': [2008324, 2655873, 4453756, 4172893, 18169, 4441487, 4605418, 4477553, 2017588, 4466572],
    'CTZ': [2008324, 3099618, 18169, 4441487, 4453756, 4477553, 4101302, 4605418, 2017588, 4466572],
    'GEN': [4466572, 3644715, 4172893, 2017588, 18169, 4441487, 4605418, 4230581, 4127700, 2655873]
}

In [None]:
all_results = []

for ab in ['CIP', 'CTX', 'CTZ', 'GEN']:
    print(f"\n{'='*60}")
    print(f"Analyzing {ab}")
    print(f"{'='*60}")

    y = pheno[ab].values

    #split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, stratify=y, random_state=42
    )

    #train RF
    print(f"Training Random Forest...")
    rf = RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1)
    rf.fit(X_train, y_train)

    #compute SHAP values (using TreeExplainer for speed)
    print(f"Computing SHAP values (this may take a few minutes)...")
    explainer = shap.TreeExplainer(rf, X_train)

    #use subset for speed (sample 100 instances from test set)
    sample_size = min(100, len(X_test))
    X_test_sample = X_test[:sample_size]

    shap_values = explainer.shap_values(X_test_sample)

    #get SHAP values for resistant class (class 1)
    if isinstance(shap_values, list):
        shap_values_resistant = shap_values[1] #this is a (num_samples, num_features) array
    else:
        # If shap_values is a 3D numpy array (num_samples, num_features, num_classes), extract the resistant class (class 1) SHAP values.
        #following assumes the last dimension is for classes.
        shap_values_resistant = shap_values[:, :, 1] #this will also be a (num_samples, num_features) array

    #calculate mean absolute SHAP value for each feature
    mean_shap = np.abs(shap_values_resistant).mean(axis=0)

    #get top 20 features
    top_20_indices = np.argsort(mean_shap)[-20:][::-1]
    top_20_snps = [snp_positions[i] for i in top_20_indices]
    top_20_shap = mean_shap[top_20_indices]

    #extract position numbers from column names ('X2655873' -> 2655873)
    top_20_positions = [int(snp.replace('X', '').split('.')[0]) for snp in top_20_snps]

    print(f"\nTop 20 AMR-associated SNPs for {ab}:")
    for i, (pos, shap_val) in enumerate(zip(top_20_positions, top_20_shap), 1):
        in_paper = '(Paper)' if pos in PAPER_SNPS[ab] else ''
        print(f"  {i:2d}. Position {pos:7d}: SHAP = {shap_val:.4f} {in_paper}")

    #check overlap with paper's EFS results
    overlap = set(top_20_positions[:10]).intersection(set(PAPER_SNPS[ab]))
    overlap_pct = len(overlap) / 10 * 100

    print(f"\nOverlap with paper's top 10 EFS SNPs: {len(overlap)}/10 ({overlap_pct:.0f}%)")

    #save results
    for i, (pos, shap_val) in enumerate(zip(top_20_positions, top_20_shap), 1):
        all_results.append({
            'Experiment_ID': EXPERIMENT_ID,
            'Antibiotic': ab,
            'Rank': i,
            'SNP_Position': pos,
            'SHAP_Value': round(shap_val, 6),
            'In_Paper_EFS': pos in PAPER_SNPS[ab],
            'Overlap_Rate': overlap_pct if i == 1 else None
        })

    #plot SHAP summary
    plt.figure(figsize=(10, 6))
    plt.barh(range(10), top_20_shap[:10][::-1])
    plt.yticks(range(10), [f"Pos {p}" for p in top_20_positions[:10][::-1]])
    plt.xlabel('Mean |SHAP Value|')
    plt.title(f'{ab}: Top 10 AMR-Associated SNPs')
    plt.tight_layout()
    plt.savefig(f'/content/drive/MyDrive/ML-iAMR_Recreation/05_evaluation/results/{EXPERIMENT_ID}_{ab}_shap_top10.png', dpi=300, bbox_inches='tight')
    plt.close()
    print(f"SHAP plot saved: {EXPERIMENT_ID}_{ab}_shap_top10.png")


Analyzing CIP
Training Random Forest...
Computing SHAP values (this may take a few minutes)...





Top 20 AMR-associated SNPs for CIP:
   1. Position 4441487: SHAP = 0.0042 (Paper)
   2. Position 4428463: SHAP = 0.0037 
   3. Position 4614902: SHAP = 0.0035 
   4. Position 4441501: SHAP = 0.0033 
   5. Position 4488025: SHAP = 0.0031 
   6. Position 4441487: SHAP = 0.0029 (Paper)
   7. Position 2015702: SHAP = 0.0027 
   8. Position 4428463: SHAP = 0.0027 
   9. Position 4458732: SHAP = 0.0024 
  10. Position 4172893: SHAP = 0.0023 (Paper)
  11. Position 4230581: SHAP = 0.0023 
  12. Position 4125149: SHAP = 0.0021 
  13. Position 2591849: SHAP = 0.0021 
  14. Position 2912911: SHAP = 0.0020 
  15. Position 4271406: SHAP = 0.0020 
  16. Position 3538947: SHAP = 0.0019 
  17. Position 4230581: SHAP = 0.0018 
  18. Position 4428463: SHAP = 0.0017 
  19. Position 4428463: SHAP = 0.0017 
  20. Position 2274094: SHAP = 0.0017 

Overlap with paper's top 10 EFS SNPs: 2/10 (20%)
SHAP plot saved: EXP-007-20251106_212753_CIP_shap_top10.png

Analyzing CTX
Training Random Forest...
Computing S




Top 20 AMR-associated SNPs for CTX:
   1. Position 3538947: SHAP = 0.0019 
   2. Position 2732353: SHAP = 0.0018 
   3. Position 3304819: SHAP = 0.0018 
   4. Position 4466572: SHAP = 0.0018 (Paper)
   5. Position 2222412: SHAP = 0.0016 
   6. Position 4477553: SHAP = 0.0015 (Paper)
   7. Position 4632238: SHAP = 0.0015 
   8. Position 4406510: SHAP = 0.0015 
   9. Position 4441487: SHAP = 0.0015 (Paper)
  10. Position 4469454: SHAP = 0.0014 
  11. Position 4428463: SHAP = 0.0014 
  12. Position 3642836: SHAP = 0.0014 
  13. Position 2831042: SHAP = 0.0014 
  14. Position 4125149: SHAP = 0.0014 
  15. Position 4469454: SHAP = 0.0013 
  16. Position 4441501: SHAP = 0.0013 
  17. Position 3644715: SHAP = 0.0013 
  18. Position   36991: SHAP = 0.0013 
  19. Position 2077992: SHAP = 0.0013 
  20. Position 2911855: SHAP = 0.0013 

Overlap with paper's top 10 EFS SNPs: 3/10 (30%)
SHAP plot saved: EXP-007-20251106_212753_CTX_shap_top10.png

Analyzing CTZ
Training Random Forest...
Computing S




Top 20 AMR-associated SNPs for CTZ:
   1. Position 2792739: SHAP = 0.0023 
   2. Position 4466572: SHAP = 0.0022 (Paper)
   3. Position 4114164: SHAP = 0.0018 
   4. Position 4114164: SHAP = 0.0017 
   5. Position 4466572: SHAP = 0.0016 (Paper)
   6. Position 4443644: SHAP = 0.0016 
   7. Position   36991: SHAP = 0.0015 
   8. Position 4619917: SHAP = 0.0015 
   9. Position 4469454: SHAP = 0.0014 
  10. Position 4614839: SHAP = 0.0013 
  11. Position 3538947: SHAP = 0.0013 
  12. Position  907779: SHAP = 0.0013 
  13. Position 4639512: SHAP = 0.0013 
  14. Position 2274094: SHAP = 0.0013 
  15. Position 4469454: SHAP = 0.0013 
  16. Position 4428463: SHAP = 0.0012 
  17. Position 4441501: SHAP = 0.0012 
  18. Position 1754779: SHAP = 0.0012 
  19. Position 4421044: SHAP = 0.0012 
  20. Position 2076861: SHAP = 0.0011 

Overlap with paper's top 10 EFS SNPs: 1/10 (10%)
SHAP plot saved: EXP-007-20251106_212753_CTZ_shap_top10.png

Analyzing GEN
Training Random Forest...
Computing SHAP val




Top 20 AMR-associated SNPs for GEN:
   1. Position 4114164: SHAP = 0.0020 
   2. Position 3538947: SHAP = 0.0019 
   3. Position  898919: SHAP = 0.0016 
   4. Position 3986701: SHAP = 0.0014 
   5. Position 2445120: SHAP = 0.0014 
   6. Position 4614902: SHAP = 0.0014 
   7. Position 1956366: SHAP = 0.0014 
   8. Position 4469454: SHAP = 0.0013 
   9. Position 4125149: SHAP = 0.0013 
  10. Position 1474029: SHAP = 0.0013 
  11. Position 1882400: SHAP = 0.0012 
  12. Position 4271406: SHAP = 0.0012 
  13. Position 3986701: SHAP = 0.0012 
  14. Position 4428463: SHAP = 0.0012 
  15. Position 4230581: SHAP = 0.0011 (Paper)
  16. Position 3986701: SHAP = 0.0011 
  17. Position 3642836: SHAP = 0.0011 
  18. Position 2251940: SHAP = 0.0011 
  19. Position 2445120: SHAP = 0.0010 
  20. Position 4466572: SHAP = 0.0010 (Paper)

Overlap with paper's top 10 EFS SNPs: 0/10 (0%)
SHAP plot saved: EXP-007-20251106_212753_GEN_shap_top10.png


## **SAVE RESULTS**

In [None]:
results_df = pd.DataFrame(all_results)
results_df.to_csv(f"/content/drive/MyDrive/ML-iAMR_Recreation/05_evaluation/results/{EXPERIMENT_ID}_shap_analysis.csv", index=False)

print("SHAP ANALYSIS SUMMARY")


summary = results_df.groupby('Antibiotic').agg({
    'Overlap_Rate': 'first',
    'In_Paper_EFS': 'sum'
}).reset_index()
summary.columns = ['Antibiotic', 'Top10_Overlap_%', 'Total_Matches_in_Top20']

print(summary.to_string(index=False))
print(f"\nFull results saved to results/{EXPERIMENT_ID}_shap_analysis.csv")

SHAP ANALYSIS SUMMARY
Antibiotic  Top10_Overlap_%  Total_Matches_in_Top20
       CIP             20.0                       3
       CTX             30.0                       3
       CTZ             10.0                       2
       GEN              0.0                       2

Full results saved to results/EXP-007-20251106_212753_shap_analysis.csv
