# Implement Grad-CAM

In [20]:
import h5py
import os
import pickle
from tqdm import tqdm
from time import gmtime, strftime
import numpy as np
import math
from sklearn import metrics
from sklearn.metrics import roc_curve, confusion_matrix, roc_auc_score
import tensorflow as tf
from tensorflow.keras import layers,Model
from sklearn.model_selection import KFold
import gc
import time
from sklearn.model_selection import KFold
import import_test as data_load

In [35]:
# Input geometry: each sample is MAX_SEQ_LENGTH positions x NUM_FEATURE values
MAX_SEQ_LENGTH = 1100
NUM_FEATURE = 1024  # features per position (original note: "esm1 & 2 a")

# Model / training hyper-parameters
NUM_FILTER = 64            # filters per convolution branch
NUM_HIDDEN = 512           # hidden dense units (original note: "#100")
BATCH_SIZE = 16
WINDOW_SIZES = [4, 8, 16]  # one convolution branch per kernel width
NUM_CLASSES = 2
CLASS_NAMES = ['1', '0']
EPOCHS = 15
VALIDATION_MODE = "independent"  # "cross" or "independent"

# Human-readable class names.
# NOTE(review): CLASS_NAMES lists '1' before '0' and class_names lists
# "Sodium" first; confirm how each maps onto the model's output columns.
class_names = ["Sodium", "Membrane"]

In [36]:
import tensorflow as tf

# Report how many GPU devices TensorFlow is able to use
physical_devices = tf.config.list_physical_devices(device_type='GPU')
num_gpus = len(physical_devices)
print("GPUs available: ", num_gpus)

GPUs available:  1


In [37]:
class DeepScan(Model):
    """Multi-window 2D-CNN classifier over per-position feature sequences.

    One Conv2D branch is built per entry of ``window_sizes`` (kernel
    ``(1, window_size)``, ReLU, 'valid' padding). Each branch is max-pooled
    over its entire valid convolution length and flattened; the branches are
    concatenated and passed through dropout, a ReLU hidden layer and a softmax
    output layer.

    Args:
        input_shape: Per-sample input shape ``(height, seq_len, num_features)``.
            ``seq_len`` (index 1) determines the max-pool extent; if it is
            ``None`` the module-level ``MAX_SEQ_LENGTH`` is used instead.
        window_sizes: Convolution kernel widths, one branch per entry.
        num_filters: Filters per convolution branch.
        num_hidden: Units in the hidden dense layer.
        num_classes: Units in the softmax output layer.

    Raises:
        ValueError: if a window size is larger than ``seq_len``.
    """

    def __init__(self,
                 input_shape=(1, MAX_SEQ_LENGTH, NUM_FEATURE),
                 window_sizes=WINDOW_SIZES,
                 num_filters=NUM_FILTER,
                 num_hidden=NUM_HIDDEN,
                 num_classes=NUM_CLASSES):
        super(DeepScan, self).__init__()
        # Symbolic input; the model is called on it once at the end of
        # __init__ so that every layer gets built.
        self.input_layer = tf.keras.Input(shape=input_shape, name='input_layer')
        self.window_sizes = window_sizes
        self.conv2d = []
        self.maxpool = []
        self.flatten = []

        # Pool over the length of *this* input rather than always over the
        # module-level MAX_SEQ_LENGTH, so a non-default input_shape still
        # collapses each branch to a single position per filter.
        seq_len = input_shape[1] if input_shape[1] is not None else MAX_SEQ_LENGTH

        # Create named layers for each window size
        for i, window_size in enumerate(self.window_sizes):
            if window_size > seq_len:
                raise ValueError(
                    f"window size {window_size} exceeds sequence length {seq_len}")

            self.conv2d.append(
                layers.Conv2D(filters=num_filters,
                              kernel_size=(1, window_size),
                              activation=tf.nn.relu,
                              padding='valid',
                              bias_initializer=tf.constant_initializer(0.1),
                              kernel_initializer=tf.keras.initializers.GlorotUniform(),
                              name=f'conv2d_{i}'))

            # A 'valid' convolution leaves seq_len - window_size + 1 positions;
            # pooling over all of them keeps one value per filter.
            self.maxpool.append(
                layers.MaxPooling2D(pool_size=(1, seq_len - window_size + 1),
                                    strides=(1, seq_len),
                                    padding='valid',
                                    name=f'maxpool_{i}'))

            self.flatten.append(layers.Flatten(name=f'flatten_{i}'))

        self.dropout = layers.Dropout(rate=0.7, name='dropout')
        self.fc1 = layers.Dense(
            num_hidden,
            activation=tf.nn.relu,
            bias_initializer=tf.constant_initializer(0.1),
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='dense_1')

        self.fc2 = layers.Dense(num_classes,
                                activation='softmax',
                                kernel_regularizer=tf.keras.regularizers.l2(1e-3),
                                name='dense_2')

        # Build the model by calling it once on the symbolic input
        self.out = self.call(self.input_layer)

    def call(self, x, training=False):
        """Run the forward pass and return softmax class probabilities.

        ``training`` is forwarded only to the dropout layer.
        """
        branch_features = []
        for i in range(len(self.window_sizes)):
            x_conv = self.conv2d[i](x)
            x_maxp = self.maxpool[i](x_conv)
            branch_features.append(self.flatten[i](x_maxp))
        x = tf.concat(branch_features, 1)
        x = self.dropout(x, training=training)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

    def get_conv_output(self, x, conv_idx):
        """Return the activations of convolution branch ``conv_idx`` for ``x``."""
        return self.conv2d[conv_idx](x)

In [43]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

# NOTE: `initialize_model`, which `main()` below calls, is not defined in this cell;
# the cell that defines it must be run before this one.
class LastLayerGradCAM:
    """Grad-CAM for the last convolution branch of a DeepScan-style model.

    The forward pass is replayed layer by layer using the model's ``conv2d``,
    ``maxpool``, ``flatten``, ``dropout``, ``fc1`` and ``fc2`` attributes, so
    that the output of branch ``len(model.window_sizes) - 1`` can be captured
    inside the gradient tape.

    Args:
        model: Model exposing the attributes listed above (e.g. ``DeepScan``).
        verbose: When True (default) progress messages are printed exactly as
            before; pass False to silence them when processing many samples.
            Errors are always printed.
    """

    def __init__(self, model, verbose=True):
        self.verbose = verbose
        self._log("Initializing LastLayerGradCAM...")
        self.model = model
        # "Last" means the branch built last, i.e. the one for window_sizes[-1]
        self.last_conv_idx = len(model.window_sizes) - 1
        self._log(f"Last convolutional layer index: {self.last_conv_idx}")

    def _log(self, message):
        """Print *message* when verbose output is enabled."""
        if self.verbose:
            print(message)

    def compute_heatmap(self, input_data, target_class_idx):
        """Return the normalized Grad-CAM heatmap for the first sample of *input_data*.

        Args:
            input_data: Model input batch (for DeepScan defaults,
                ``(batch, 1, seq_len, num_features)``).
            target_class_idx: Column of the model output whose score is explained.

        Returns:
            NumPy array of shape ``(height, conv_len)`` -- ``(1, conv_len)`` for
            DeepScan -- scaled to ``[0, 1]`` (all zeros when no position has a
            positive contribution). ``conv_len`` is the length of the captured
            'valid' convolution output, so it is shorter than the input length.

        Raises:
            ValueError: if the gradient w.r.t. the captured conv output is None.
        """
        self._log("\nComputing heatmap...")
        self._log(f"Input shape: {input_data.shape}")

        try:
            with tf.GradientTape() as tape:
                self._log("Setting up gradient tape...")
                inputs = tf.convert_to_tensor(input_data)
                tape.watch(inputs)

                self._log("Starting forward pass...")
                branch_features = []
                last_conv_output = None

                # Replay every conv branch, keeping the selected branch's activations
                for i in range(len(self.model.window_sizes)):
                    self._log(f"Processing conv layer {i}")
                    x_conv = self.model.conv2d[i](inputs)
                    if i == self.last_conv_idx:
                        last_conv_output = x_conv
                        self._log(f"Captured last conv output, shape: {last_conv_output.shape}")
                    x_maxp = self.model.maxpool[i](x_conv)
                    branch_features.append(self.model.flatten[i](x_maxp))

                self._log("Concatenating features...")
                x = tf.concat(branch_features, 1)
                self._log(f"Concatenated shape: {x.shape}")

                self._log("Applying final layers...")
                x = self.model.dropout(x, training=False)
                x = self.model.fc1(x)
                predictions = self.model.fc2(x)
                self._log(f"Predictions shape: {predictions.shape}")

                # Score being explained: the fc2 output column (a softmax
                # probability for DeepScan).
                target_class_score = predictions[:, target_class_idx]
                self._log(f"Target class score shape: {target_class_score.shape}")

            self._log("Computing gradients...")
            grads = tape.gradient(target_class_score, last_conv_output)

            if grads is None:
                raise ValueError("Gradients are None. The computation graph might be disconnected.")

            self._log(f"Gradients shape: {grads.shape}")

            # Grad-CAM channel weights for the first sample: average *its*
            # gradients over the spatial axes only. Averaging over the batch
            # axis too would mix other samples' gradients into these weights.
            pooled_grads = tf.reduce_mean(grads[0], axis=(0, 1))
            sample_activations = last_conv_output[0]
            heatmap = tf.reduce_sum(sample_activations * pooled_grads, axis=-1)

            # Keep positive evidence only, then scale into [0, 1]
            heatmap = tf.maximum(heatmap, 0)
            max_val = float(tf.reduce_max(heatmap))
            if max_val > 0:
                heatmap = heatmap / max_val

            self._log("Heatmap computation successful")
            return heatmap.numpy()

        except Exception as e:
            print(f"Error in compute_heatmap: {str(e)}")
            raise
def save_heatmap(heatmap, save_dir, sequence_idx, label):
    """Save a heatmap as a PNG figure and as a raw ``.npy`` array.

    Files are written to ``<save_dir>/positive`` when ``label == 1`` and to
    ``<save_dir>/negative`` otherwise, named
    ``attention_map_sequence_<sequence_idx>.png`` / ``.npy``.

    Args:
        heatmap: Array-like heatmap; it is flattened to one row for plotting
            and saved to ``.npy`` with its original shape.
        save_dir: Root output directory (label subdirectory is created).
        sequence_idx: Index used in the output file names.
        label: Sample label; 1 selects the positive folder/title.

    Returns:
        True if both files were written, False if plotting or saving raised
        (the error is printed).
    """
    is_positive = label == 1
    label_dir = 'positive' if is_positive else 'negative'
    label_save_dir = os.path.join(save_dir, label_dir)
    os.makedirs(label_save_dir, exist_ok=True)

    save_path = os.path.join(label_save_dir, f'attention_map_sequence_{sequence_idx}.png')
    np_save_path = os.path.join(label_save_dir, f'attention_map_sequence_{sequence_idx}.npy')

    fig = None
    try:
        fig = plt.figure(figsize=(15, 3))

        # Render the heatmap as a single horizontal strip
        reshaped_heatmap = np.asarray(heatmap).reshape(1, -1)
        plt.imshow(reshaped_heatmap, aspect='auto', cmap='hot')
        plt.colorbar(label='Attention Score')
        plt.title(f'Last Layer Attention Map ({"Positive" if is_positive else "Negative"} Sample)')
        plt.xlabel('Sequence Position')
        plt.ylabel('Attention Weight')
        plt.tight_layout()

        # Save visualization and raw data
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        np.save(np_save_path, heatmap)
        return True
    except Exception as e:
        print(f"Error saving heatmap: {str(e)}")
        return False
    finally:
        # Always release the figure -- previously a failure before plt.close()
        # left the figure open, accumulating figures over a batch of samples.
        if fig is not None:
            plt.close(fig)

def _generate_class_heatmaps(gradcam, test_data, indices, save_dir, label, class_name):
    """Compute and save a Grad-CAM heatmap for every test-set index in *indices*.

    *label* is used both as the saved-file label and as the model output column
    that is explained (1 for positive, 0 for negative).
    """
    total = len(indices)
    print(f"\nProcessing {class_name} samples...")
    for i, idx in enumerate(indices, start=1):
        print(f"\nProcessing {class_name} sample {i}/{total}")
        # Add batch and height axes: (1, 1, test_data.shape[1], test_data.shape[2])
        sample = test_data[idx].reshape(1, 1, test_data.shape[1], test_data.shape[2])
        heatmap = gradcam.compute_heatmap(sample, target_class_idx=label)
        if not save_heatmap(heatmap, save_dir, idx, label=label):
            print(f"Failed to save {class_name} sample {i}")


def process_samples(test_data, test_labels, model, save_dir, samples_per_class=40, seed=42):
    """Generate and save Grad-CAM heatmaps for random samples of both classes.

    Args:
        test_data: Array where ``test_data[i]`` is one sample; each sample is
            reshaped to ``(1, 1, test_data.shape[1], test_data.shape[2])``.
        test_labels: Labels aligned with *test_data* (1 = positive,
            0 = negative); assumed 1-D since ``np.where(...)[0]`` is used as
            row indices.
        model: Model passed to ``LastLayerGradCAM``.
        save_dir: Root directory for the saved heatmaps.
        samples_per_class: Samples drawn per class; reduced to the smaller
            class size when either class has fewer.
        seed: Seed for sample selection; the default 42 reproduces the
            selection made by the previous ``np.random.seed(42)`` version.
    """
    gradcam = LastLayerGradCAM(model)

    positive_indices = np.where(test_labels == 1)[0]
    negative_indices = np.where(test_labels == 0)[0]

    print(f"Found {len(positive_indices)} positive samples and {len(negative_indices)} negative samples")

    if len(positive_indices) < samples_per_class or len(negative_indices) < samples_per_class:
        print("Warning: Not enough samples in one or both classes!")
        samples_per_class = min(len(positive_indices), len(negative_indices))
        print(f"Adjusting to {samples_per_class} samples per class")

    # A private RandomState seeded like np.random.seed(seed) draws exactly the
    # same samples, without reseeding NumPy's global RNG for the rest of the notebook.
    rng = np.random.RandomState(seed)
    selected_positive = rng.choice(positive_indices, samples_per_class, replace=False)
    selected_negative = rng.choice(negative_indices, samples_per_class, replace=False)

    # NOTE(review): output column 1 is explained for positives and column 0 for
    # negatives; CLASS_NAMES = ['1', '0'] may imply the opposite column order --
    # confirm against how labels were encoded at training time.
    _generate_class_heatmaps(gradcam, test_data, selected_positive, save_dir,
                             label=1, class_name="positive")
    _generate_class_heatmaps(gradcam, test_data, selected_negative, save_dir,
                             label=0, class_name="negative")

def main(data_dir="C:/jupyter/Malik/SodiumTransporters/ProtTrans",
         save_dir="C:/jupyter/Malik/SodiumTransporters/Code/AttentionMaps",
         weight_file="path_to_your_model_weights/model_weights.h5",
         samples_per_class=40):
    """Generate Grad-CAM heatmaps for positive and negative test samples.

    Loads ``All_Test_data.npy`` and ``All_Test_labels.npy`` from *data_dir*,
    builds the model with ``initialize_model``, loads *weight_file* if it
    exists, and writes heatmaps under ``<save_dir>/positive`` and
    ``<save_dir>/negative``. Errors are printed (with a traceback for
    unexpected ones) rather than raised.
    """
    try:
        print("Setting up constants...")

        print("\nLoading data...")
        try:
            # Memory-map the large feature array so only the rows that are
            # actually selected are read from disk, instead of loading the
            # whole (N, seq_len, features) array into RAM.
            test_data = np.load(os.path.join(data_dir, "All_Test_data.npy"), mmap_mode='r')
            test_labels = np.load(os.path.join(data_dir, "All_Test_labels.npy"))
            print(f"Data loaded successfully. Shape: {test_data.shape}")
        except Exception as e:
            print(f"Error loading data: {str(e)}")
            return

        # Model dimensions come from the module-level constants
        model = initialize_model(MAX_SEQ_LENGTH, NUM_FEATURE)

        # NOTE(review): the captured output suggests initialize_model already
        # loads weights itself; this optional step overrides them when present.
        if os.path.exists(weight_file):
            print(f"Loading model weights from {weight_file}...")
            model.load_weights(weight_file)
            print("Model weights loaded successfully.")
        else:
            print(f"Model weights file {weight_file} not found!")

        save_dir = os.path.abspath(save_dir)
        os.makedirs(save_dir, exist_ok=True)
        print(f"\nSave directory created/verified: {save_dir}")

        process_samples(test_data, test_labels, model, save_dir, samples_per_class)

        print("\nProcess completed!")
        print(f"Please check the directory: {save_dir}")
        print("Positive samples are in the 'positive' subdirectory")
        print("Negative samples are in the 'negative' subdirectory")

    except Exception as e:
        # Top-level boundary: report everything unexpected with its traceback
        print(f"Error in main function: {str(e)}")
        print("Stack trace:")
        import traceback
        traceback.print_exc()

if __name__ == "__main__":
    print("Starting program...")
    main()

Starting program...
Setting up constants...

Loading data...
Data loaded successfully. Shape: (1161, 1100, 1024)
Initializing model...
Building model with dummy data...
Loading weights...
Model initialized successfully!

Save directory created/verified: C:\jupyter\Malik\SodiumTransporters\Code\AttentionMaps
Initializing LastLayerGradCAM...
Last convolutional layer index: 2
Found 85 positive samples and 1076 negative samples

Processing positive samples...

Processing positive sample 1/40

Computing heatmap...
Input shape: (1, 1, 1100, 1024)
Setting up gradient tape...
Starting forward pass...
Processing conv layer 0
Processing conv layer 1
Processing conv layer 2
Captured last conv output, shape: (1, 1, 1085, 64)
Concatenating features...
Concatenated shape: (1, 192)
Applying final layers...
Predictions shape: (1, 2)
Target class score shape: (1,)
Computing gradients...
Gradients shape: (1, 1, 1085, 64)
Heatmap computation successful

Processing positive sample 2/40

Computing heatmap.

## Select a few images to display

In [47]:
import random

# Define the directories where the heatmaps are saved
heatmap_save_dir = "C:/jupyter/Malik/SodiumTransporters/Code/AttentionMaps"
positives_dir = os.path.join(heatmap_save_dir, "positive")
negatives_dir = os.path.join(heatmap_save_dir, "negative")

# Function to randomly select and load heatmaps from the directory
def load_random_heatmaps(directory, num_samples=3, seed=None):
    """Randomly pick rendered heatmap PNGs from *directory* and load them.

    Args:
        directory (str): Directory containing the heatmap ``.png`` files.
        num_samples (int): Number of files to pick. If the directory holds
            fewer PNGs, all of them are used and a warning is printed
            (previously this raised ``ValueError``).
        seed (int | None): Optional seed for a reproducible pick; ``None``
            uses the global ``random`` module state as before.

    Returns:
        list: The loaded images, as returned by ``plt.imread``.
    """
    # Sort so the candidate order (and thus a seeded pick) does not depend on
    # the arbitrary order returned by os.listdir.
    heatmap_files = sorted(f for f in os.listdir(directory) if f.endswith('.png'))

    if len(heatmap_files) < num_samples:
        print(f"Warning: only {len(heatmap_files)} heatmap file(s) in {directory}, "
              f"requested {num_samples}")
        num_samples = len(heatmap_files)

    chooser = random.Random(seed) if seed is not None else random
    selected_files = chooser.sample(heatmap_files, num_samples)

    return [plt.imread(os.path.join(directory, name)) for name in selected_files]

# Function to plot and save the heatmaps in a single figure
def plot_and_save_heatmaps(positives, negatives, save_dir):
    """Plot positive and negative heatmaps in a two-row figure and save it.

    Row 0 holds *positives* under the heading "Sodium Transporters", row 1
    holds *negatives* under "Membrane Proteins". The grid has as many columns
    as the longer list (at least one); unused panels are hidden.

    Args:
        positives (list): Images for the first row.
        negatives (list): Images for the second row.
        save_dir (str): Directory in which ``combined_heatmaps_New.png`` is
            written (at 300 DPI).
    """
    ncols = max(len(positives), len(negatives), 1)
    # squeeze=False keeps `axes` 2-D even when there is a single column
    fig, axes = plt.subplots(2, ncols, figsize=(4 * ncols, 8), squeeze=False)
    try:
        rows = (
            ("Sodium Transporters", positives),
            ("Membrane Proteins", negatives),
        )
        for row, (heading, samples) in enumerate(rows):
            first_ax = axes[row, 0]
            # Axes coordinates place the heading just above the first panel;
            # the default data coordinates put it at the image's top-left pixel.
            first_ax.text(0.5, 1.05, heading, transform=first_ax.transAxes,
                          ha='center', va='bottom', fontsize=14, fontweight='bold')
            for col in range(ncols):
                ax = axes[row, col]
                if col < len(samples):
                    # cmap only affects 2-D inputs; RGB(A) PNGs render as-is
                    ax.imshow(samples[col], cmap='jet', interpolation='nearest')
                ax.axis('off')  # Hide axes for clarity

        save_path = os.path.join(save_dir, "combined_heatmaps_New.png")
        fig.tight_layout()  # Adjust layout to avoid overlap
        fig.savefig(save_path, dpi=300)
    finally:
        # Free the figure even if layout or saving fails
        plt.close(fig)

    print(f"Heatmap figure saved at {save_path}")

# Draw three rendered heatmaps per class: positives first, then negatives
positives, negatives = (
    load_random_heatmaps(class_dir, num_samples=3)
    for class_dir in (positives_dir, negatives_dir)
)

# Combine them into one two-row figure and report completion
plot_and_save_heatmaps(positives, negatives, heatmap_save_dir)
print("Heatmap generation complete!")

Heatmap figure saved at C:/jupyter/Malik/SodiumTransporters/Code/AttentionMaps\combined_heatmaps_New.png
Heatmap generation complete!
