# **CNN with GradCAM Interpretability**
1D CNN architecture adapted from paper's design for label-encoded SNPs

GradCAM highlights which SNP regions contribute to resistance predictions

In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_auc_score, classification_report
import matplotlib.pyplot as plt
from datetime import datetime

## **LOAD DATA**

In [2]:
data = pd.read_csv("/content/drive/MyDrive/ML-iAMR_Recreation/01_data/raw/giessen/cip_ctx_ctz_gen_multi_data.csv")
pheno = pd.read_csv("/content/drive/MyDrive/ML-iAMR_Recreation/01_data/raw/giessen/cip_ctx_ctz_gen_pheno.csv", index_col=0)
X = data.drop('prename', axis=1).values

In [None]:
EXPERIMENT_ID = f"EXP-010-{datetime.now().strftime('%Y%m%d_%H%M%S')}"
print(f"Experiment: {EXPERIMENT_ID}")

Experiment: EXP-010-20251107_092352


## **CNN ARCHITECTURE**

In [3]:
def build_1d_cnn(input_shape, ab_name):
    """1D CNN - Sequential model"""
    model = models.Sequential(name=f"CNN_{ab_name}")

    model.add(layers.Input(shape=input_shape))    #input layer
    model.add(layers.Reshape((input_shape[0], 1)))    #add channel dimension

    #convolutional Block 1
    model.add(layers.Conv1D(8, 3, activation='relu', padding='same', name='conv1'))
    model.add(layers.BatchNormalization(name='bn1'))
    model.add(layers.MaxPooling1D(2, name='pool1'))

    #convolutional Block 2
    model.add(layers.Conv1D(8, 3, activation='relu', padding='same', name='conv2'))
    model.add(layers.BatchNormalization(name='bn2'))
    model.add(layers.MaxPooling1D(2, name='pool2'))

    #convolutional Block 3
    model.add(layers.Conv1D(16, 3, activation='relu', padding='same', name='conv3'))
    model.add(layers.MaxPooling1D(2, name='pool3'))

    #convolutional Block 4 (target for GradCAM)
    model.add(layers.Conv1D(16, 3, activation='relu', padding='same', name='conv4_gradcam'))
    model.add(layers.MaxPooling1D(2, name='pool4'))

    #dense layers
    model.add(layers.Flatten(name='flatten'))
    model.add(layers.Dropout(0.5, name='dropout'))
    model.add(layers.Dense(64, activation='relu', name='dense'))
    model.add(layers.Dense(2, activation='softmax', name='output'))

    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

## **GRADCAM IMPLEMENTATION**

In [4]:
def make_gradcam_heatmap(model, sample, conv_layer_name='conv4_gradcam', pred_index=None):
    """
    Generate GradCAM heatmap for a given sample

    Args:
        model: Trained Keras model
        sample: Input sample (1D array)
        conv_layer_name: Name of convolutional layer to visualize
        pred_index: Class index (1 for resistant)

    Returns:
        heatmap: 1D heatmap of feature importance
    """
    #ensure model is built
    if not model.built:
        #build by calling with a dummy input
        dummy_input = tf.zeros((1, sample.shape[0]))
        _ = model(dummy_input)

    #prepare input
    sample_input = tf.convert_to_tensor(sample.reshape(1, -1), dtype=tf.float32)

    #get the target layer
    target_layer = model.get_layer(conv_layer_name)

    #create a sub-model from input to target layer output
    grad_model = keras.Model(
        inputs=model.input,
        outputs=[target_layer.output, model.output]
    )

    #record operations for automatic differentiation
    with tf.GradientTape() as tape:
        conv_output, predictions = grad_model(sample_input, training=False)

        if pred_index is None:
            pred_index = tf.argmax(predictions[0])

        loss = predictions[:, pred_index]

    #compute gradients
    grads = tape.gradient(loss, conv_output)

    if grads is None:
        print(f"  WARNING: Gradients are None, returning zeros")
        return np.zeros(conv_output.shape[1])

    #global average pooling
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1))

    #weight the conv output by pooled gradients
    conv_output = conv_output[0]
    heatmap = conv_output @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)

    #reLU and normalize
    heatmap = tf.maximum(heatmap, 0)
    max_val = tf.reduce_max(heatmap)
    if max_val > 0:
        heatmap = heatmap / max_val

    return heatmap.numpy()

## **TRAIN CNN FOR EACH ANTIBIOTIC**

In [7]:
all_results = []
gradcam_examples = {}

for ab in ['CIP', 'CTX', 'CTZ', 'GEN']:
    print(f"\n{'='*70}")
    print(f"TRAINING CNN FOR {ab}")
    print(f"{'='*70}")

    y = pheno[ab].values

     #Normalize features (important for CNNs)
    X_normalized = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)

    X_train, X_test, y_train, y_test = train_test_split(
        X_normalized, y, test_size=0.2, stratify=y, random_state=42
    )

    print(f"Training samples: {len(X_train)}, Test samples: {len(X_test)}")
    print(f"Class distribution (train): {np.bincount(y_train)}")

    #build model
    input_shape = (X_train.shape[1],)
    model = build_1d_cnn(input_shape, ab)

    print("\nModel Architecture:")
    model.summary()

    #class weights
    class_weight = {0: 1.0, 1: len(y_train) / (2 * np.sum(y_train))}

    #train
    print(f"\nTraining CNN...")
    history = model.fit(
        X_train, y_train,
        epochs=20,
        batch_size=32,
        validation_split=0.2,
        class_weight=class_weight,
        verbose=1
    )

    #evaluate
    y_pred_proba = model.predict(X_test, verbose=0)
    y_pred_class = np.argmax(y_pred_proba, axis=1)

    auc = roc_auc_score(y_test, y_pred_proba[:, 1])

    print(f"\nTest AUC: {auc:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred_class, target_names=['Susceptible', 'Resistant']))

    all_results.append({
        'Experiment_ID': EXPERIMENT_ID,
        'Antibiotic': ab,
        'AUC': round(auc, 4),
        'Train_Samples': len(X_train),
        'Test_Samples': len(X_test),
    })

    #GRADCAM
    print(f"\n--- GradCAM Analysis for {ab} ---")

    #select examples: 1 resistant, 1 susceptible (correctly classified)
    resistant_indices = np.where((y_test == 1) & (y_pred_class == 1))[0]
    susceptible_indices = np.where((y_test == 0) & (y_pred_class == 0))[0]

    if len(resistant_indices) > 0 and len(susceptible_indices) > 0:
        resistant_idx = resistant_indices[0]
        susceptible_idx = susceptible_indices[0]

        try:
            print(f"  Generating GradCAM for resistant sample...")
            heatmap_resistant = make_gradcam_heatmap(model, X_test[resistant_idx], pred_index=1)

            print(f"  Generating GradCAM for susceptible sample...")
            heatmap_susceptible = make_gradcam_heatmap(model, X_test[susceptible_idx], pred_index=0)

            gradcam_examples[ab] = {
                'resistant': heatmap_resistant,
                'susceptible': heatmap_susceptible,
                'resistant_sample': X_test[resistant_idx],
                'susceptible_sample': X_test[susceptible_idx]
            }

            print(f"GradCAM heatmaps generated successfully!")
            print(f"  Resistant: max={heatmap_resistant.max():.4f}, mean={heatmap_resistant.mean():.4f}")
            print(f"  Susceptible: max={heatmap_susceptible.max():.4f}, mean={heatmap_susceptible.mean():.4f}")

        except Exception as e:
            print(f"GradCAM failed: {str(e)}")
            import traceback
            traceback.print_exc()
    else:
        print(f"Not enough correctly classified samples")


TRAINING CNN FOR CIP
Training samples: 647, Test samples: 162
Class distribution (train): [354 293]

Model Architecture:



Training CNN...
Epoch 1/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 488ms/step - accuracy: 0.6576 - loss: 6.7090 - val_accuracy: 0.7385 - val_loss: 0.5702
Epoch 2/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 243ms/step - accuracy: 0.8210 - loss: 0.6730 - val_accuracy: 0.8538 - val_loss: 0.5036
Epoch 3/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 251ms/step - accuracy: 0.8743 - loss: 0.4697 - val_accuracy: 0.8385 - val_loss: 0.5594
Epoch 4/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 243ms/step - accuracy: 0.8617 - loss: 0.4705 - val_accuracy: 0.7692 - val_loss: 0.5313
Epoch 5/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 243ms/step - accuracy: 0.8971 - loss: 0.4645 - val_accuracy: 0.8538 - val_loss: 0.4684
Epoch 6/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 247ms/step - accuracy: 0.9389 - loss: 0.2317 - val_accuracy: 0.7692 - val_loss: 0.6288
Epoch 7/20

Traceback (most recent call last):
  File "/tmp/ipython-input-2783864653.py", line 73, in <cell line: 0>
    heatmap_resistant = make_gradcam_heatmap(model, X_test[resistant_idx], pred_index=1)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3126995105.py", line 28, in make_gradcam_heatmap
    inputs=model.input,
           ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/operation.py", line 276, in input
    return self._get_node_attribute_at_index(0, "input_tensors", "input")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/operation.py", line 307, in _get_node_attribute_at_index
    raise AttributeError(
AttributeError: The layer CNN_CIP has never been called and thus has no defined input.. Did you mean: 'inputs'?


Training samples: 647, Test samples: 162
Class distribution (train): [361 286]

Model Architecture:



Training CNN...
Epoch 1/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 528ms/step - accuracy: 0.5326 - loss: 5.6239 - val_accuracy: 0.5923 - val_loss: 0.6809
Epoch 2/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 245ms/step - accuracy: 0.6160 - loss: 0.6848 - val_accuracy: 0.6154 - val_loss: 0.6900
Epoch 3/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 244ms/step - accuracy: 0.6816 - loss: 0.6026 - val_accuracy: 0.5923 - val_loss: 0.6848
Epoch 4/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 245ms/step - accuracy: 0.7200 - loss: 0.5651 - val_accuracy: 0.7077 - val_loss: 0.6432
Epoch 5/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 245ms/step - accuracy: 0.7927 - loss: 0.4904 - val_accuracy: 0.5308 - val_loss: 0.6700
Epoch 6/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 244ms/step - accuracy: 0.7890 - loss: 0.4849 - val_accuracy: 0.6000 - val_loss: 0.6558
Epoch 7/20




Test AUC: 0.8427

Classification Report:
              precision    recall  f1-score   support

 Susceptible       0.85      0.71      0.78        90
   Resistant       0.70      0.85      0.77        72

    accuracy                           0.77       162
   macro avg       0.78      0.78      0.77       162
weighted avg       0.79      0.77      0.77       162


--- GradCAM Analysis for CTX ---
  Generating GradCAM for resistant sample...
GradCAM failed: The layer CNN_CTX has never been called and thus has no defined input.

TRAINING CNN FOR CTZ


Traceback (most recent call last):
  File "/tmp/ipython-input-2783864653.py", line 73, in <cell line: 0>
    heatmap_resistant = make_gradcam_heatmap(model, X_test[resistant_idx], pred_index=1)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-3126995105.py", line 28, in make_gradcam_heatmap
    inputs=model.input,
           ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/operation.py", line 276, in input
    return self._get_node_attribute_at_index(0, "input_tensors", "input")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/keras/src/ops/operation.py", line 307, in _get_node_attribute_at_index
    raise AttributeError(
AttributeError: The layer CNN_CTX has never been called and thus has no defined input.


Training samples: 647, Test samples: 162
Class distribution (train): [426 221]

Model Architecture:



Training CNN...
Epoch 1/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 474ms/step - accuracy: 0.4719 - loss: 15.0283 - val_accuracy: 0.3231 - val_loss: 0.6940
Epoch 2/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 244ms/step - accuracy: 0.3602 - loss: 0.8042 - val_accuracy: 0.6769 - val_loss: 0.6927
Epoch 3/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 245ms/step - accuracy: 0.6757 - loss: 0.7969 - val_accuracy: 0.6769 - val_loss: 0.6903
Epoch 4/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 242ms/step - accuracy: 0.6277 - loss: 0.8120 - val_accuracy: 0.6769 - val_loss: 0.6884
Epoch 5/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 243ms/step - accuracy: 0.6509 - loss: 0.8033 - val_accuracy: 0.6769 - val_loss: 0.6859
Epoch 6/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 244ms/step - accuracy: 0.6627 - loss: 0.7981 - val_accuracy: 0.6769 - val_loss: 0.6836
Epoch 7/2




Test AUC: 0.5000

Classification Report:
              precision    recall  f1-score   support

 Susceptible       0.66      1.00      0.80       107
   Resistant       0.00      0.00      0.00        55

    accuracy                           0.66       162
   macro avg       0.33      0.50      0.40       162
weighted avg       0.44      0.66      0.53       162


--- GradCAM Analysis for CTZ ---
Not enough correctly classified samples

TRAINING CNN FOR GEN


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Training samples: 647, Test samples: 162
Class distribution (train): [497 150]

Model Architecture:



Training CNN...
Epoch 1/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 510ms/step - accuracy: 0.5467 - loss: 4.5754 - val_accuracy: 0.7462 - val_loss: 0.6835
Epoch 2/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 244ms/step - accuracy: 0.4336 - loss: 0.8643 - val_accuracy: 0.2538 - val_loss: 0.6937
Epoch 3/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 243ms/step - accuracy: 0.4196 - loss: 0.8636 - val_accuracy: 0.7462 - val_loss: 0.6909
Epoch 4/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 245ms/step - accuracy: 0.7745 - loss: 0.8723 - val_accuracy: 0.7462 - val_loss: 0.6880
Epoch 5/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 244ms/step - accuracy: 0.7926 - loss: 0.8552 - val_accuracy: 0.7462 - val_loss: 0.6837
Epoch 6/20
[1m17/17[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 244ms/step - accuracy: 0.7670 - loss: 0.8746 - val_accuracy: 0.7462 - val_loss: 0.6805
Epoch 7/20

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


## **SAVE RESULTS**

In [9]:
results_df = pd.DataFrame(all_results)

print("CNN TRAINING SUMMARY")
print(results_df[['Antibiotic', 'AUC', 'Train_Samples', 'Test_Samples']].to_string(index=False))

results_df.to_csv(f"/content/drive/MyDrive/ML-iAMR_Recreation/05_evaluation/results/{EXPERIMENT_ID}_cnn_results.csv", index=False)
print(f"\nResults saved to results/{EXPERIMENT_ID}_cnn_results.csv")

CNN TRAINING SUMMARY
Antibiotic    AUC  Train_Samples  Test_Samples
       CIP 0.9314            647           162

Results saved to results/EXP-010-20260118_211421_cnn_results.csv


## **VISUALIZE GRADCAM**

In [10]:
if len(gradcam_examples) > 0:
    print("\n--- Generating GradCAM Visualizations ---")

    n_abs = len(gradcam_examples)
    fig, axes = plt.subplots(n_abs, 2, figsize=(16, 4*n_abs))

    if n_abs == 1:
        axes = axes.reshape(1, -1)

    for idx, (ab, example) in enumerate(gradcam_examples.items()):
        #get heatmaps
        heatmap_res = example['resistant']
        heatmap_sus = example['susceptible']

        #upsample to match input size
        input_len = len(example['resistant_sample'])
        repeat_factor = max(1, input_len // len(heatmap_res))

        heatmap_res_up = np.repeat(heatmap_res, repeat_factor)
        heatmap_sus_up = np.repeat(heatmap_sus, repeat_factor)

        #adjust length
        if len(heatmap_res_up) > input_len:
            heatmap_res_up = heatmap_res_up[:input_len]
            heatmap_sus_up = heatmap_sus_up[:input_len]
        elif len(heatmap_res_up) < input_len:
            pad_len = input_len - len(heatmap_res_up)
            heatmap_res_up = np.pad(heatmap_res_up, (0, pad_len), mode='edge')
            heatmap_sus_up = np.pad(heatmap_sus_up, (0, pad_len), mode='edge')

        #plot (show first 1000 SNPs for clarity)
        display_len = min(1000, input_len)

        ax1 = axes[idx, 0]
        im1 = ax1.imshow(heatmap_res_up[np.newaxis, :display_len],
                         cmap='hot', aspect='auto', interpolation='nearest')
        ax1.set_title(f'{ab} - Resistant Sample', fontsize=13, fontweight='bold')
        ax1.set_xlabel(f'SNP Position (first {display_len})', fontsize=11)
        ax1.set_yticks([])
        plt.colorbar(im1, ax=ax1, fraction=0.046, label='GradCAM Activation')

        ax2 = axes[idx, 1]
        im2 = ax2.imshow(heatmap_sus_up[np.newaxis, :display_len],
                         cmap='hot', aspect='auto', interpolation='nearest')
        ax2.set_title(f'{ab} - Susceptible Sample', fontsize=13, fontweight='bold')
        ax2.set_xlabel(f'SNP Position (first {display_len})', fontsize=11)
        ax2.set_yticks([])
        plt.colorbar(im2, ax=ax2, fraction=0.046, label='GradCAM Activation')

    plt.tight_layout()
    plt.savefig(f'{EXPERIMENT_ID}_gradcam.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"GradCAM visualization saved: {EXPERIMENT_ID}_gradcam.png")

## **IDENTIFY KEY SNPS FROM GRADCAM**

In [11]:
snp_cols = [col for col in data.columns if col.startswith('X')]

for ab, example in gradcam_examples.items():
    print(f"\n{ab}:")
    heatmap = example['resistant']

    top_k = min(10, len(heatmap))
    top_indices = np.argsort(heatmap)[-top_k:][::-1]

    repeat_factor = max(1, len(snp_cols) // len(heatmap))

    for rank, idx in enumerate(top_indices, 1):
        snp_idx = min(idx * repeat_factor, len(snp_cols) - 1)
        snp_name = snp_cols[snp_idx]
        snp_pos = snp_name.replace('X', '')
        activation = heatmap[idx]
        print(f"  {rank:2d}. Position {snp_pos:>7s}: Activation={activation:.4f}")

In [12]:
print("TOP SNPs IDENTIFIED BY GRADCAM (per antibiotic)")

snp_positions = [col for col in data.columns if col.startswith('X')]

for ab in ['CIP', 'CTX', 'CTZ', 'GEN']:
    if ab not in gradcam_examples:
        continue

    print(f"\n{ab}:")

    heatmap = gradcam_examples[ab]['resistant']

    #map heatmap back to SNP positions
    repeat_factor = len(snp_positions) // len(heatmap)

    #find top activated regions
    top_k = 10
    top_indices = np.argsort(heatmap)[-top_k:][::-1]

    for rank, idx in enumerate(top_indices, 1):
        #map back to original SNP index
        snp_idx = idx * repeat_factor
        if snp_idx < len(snp_positions):
            snp_name = snp_positions[snp_idx]
            snp_pos = snp_name.replace('X', '')
            activation = heatmap[idx]
            print(f"  {rank:2d}. SNP Position {snp_pos:>7s}: Activation = {activation:.4f}")

print("- Red regions (high activation): SNPs strongly contributing to prediction")
print("- Blue regions (low activation): SNPs with minimal contribution")
print("- Compare resistant vs susceptible samples to identify discriminative SNPs")

TOP SNPs IDENTIFIED BY GRADCAM (per antibiotic)
- Red regions (high activation): SNPs strongly contributing to prediction
- Blue regions (low activation): SNPs with minimal contribution
- Compare resistant vs susceptible samples to identify discriminative SNPs
