## The purpose of this notebook is to be the final version of EnhancerDetector for publication

### This notebook takes an input FASTA file of sequences that are 400 base pairs long and outputs each sequence's probability of being an enhancer
### When using the fly models, the maximum sequence length is 500 base pairs
### Optionally, a Class Activation Map (CAM) is also produced for each sequence

### @author: Luis Solis, Bioinformatics Toolsmith Laboratory, Texas A&M University-Kingsville
### @author: Dr. Hani Z. Girgis, Bioinformatics Toolsmith Laboratory, Texas A&M University-Kingsville

#### Date Created: 05-27-2025

In [None]:
import tensorflow as tf
from tensorflow.keras.models import load_model, Model

from Nets import CustomConvLayer
from Metrics import weighted_f1_score, specificity
from OneNucleotideIndexer import OneNucleotideIndexer

from Bio import SeqIO
import pickle
import numpy as np
import matplotlib.pyplot as plt
import sys

In [None]:
# When True, the final cell computes a Class Activation Map for every input
# sequence and writes it out as a PDF.
output_cam_pdf = True

In [None]:
'''
Input will include all sequences that will be tested to see if they are an enhancer or not, this must be in fasta format
Output will include the sequence id and their probability of being an enhancer.
'''

# FASTA file holding the sequences to score
similar_sequences_file = 'Test_Input/input_human.fasta'

# Trained human network and the pickled indexer that encodes sequences for it
model_folder = 'Models/'
network = f'{model_folder}/Human/Single_Classifier_64_3_20.keras'
indexer_dir = f'{model_folder}/Human/indexer.pkl'

# Directory that receives Model_Output.txt and the CAM PDFs
output_dir = 'Output'

# Sequence length used for the human model (the fly cell overrides it with 500)
max_len = 400

# Set to True to score with the fly ensemble instead of the human model
use_fly = False

### Fly uses an ensemble of three networks; when fly mode is enabled, all three networks are loaded

In [None]:
if use_fly:
    # Fly sequences may be up to 500 base pairs long
    max_len = 500

    # Project-defined layer and metrics that Keras must resolve when loading the saved models
    _fly_custom_objects = {
        'CustomConvLayer': CustomConvLayer,
        'specificity': specificity,
        'weighted_f1_score': weighted_f1_score,
    }

    # The three members of the fly ensemble
    fly_network1 = f'{model_folder}/Fly/Single_Classifier_40_3_20.keras'
    fly_network2 = f'{model_folder}/Fly/Single_Classifier_32_3_20.keras'
    fly_network_finetune = f'{model_folder}/Fly/FineTune_Classifier_64_3_20_With_No_Convolution_0.h5'

    # Pickled indexers: one human-based, one fly-based
    human_indexer_dir = f'{model_folder}/Fly/indexer_human.pkl'
    fly_indexer_dir = f'{model_folder}/Fly/indexer_fly.pkl'

    fly_model1 = load_model(fly_network1, custom_objects=_fly_custom_objects)
    fly_model2 = load_model(fly_network2, custom_objects=_fly_custom_objects)
    fly_model_finetune = load_model(fly_network_finetune, custom_objects=_fly_custom_objects)

    # NOTE: pickle executes arbitrary code on load; only use indexer files from a trusted source
    with open(human_indexer_dir, 'rb') as human_handle:
        human_indexer = pickle.load(human_handle)
    # The unpickled human object is wrapped in a OneNucleotideIndexer built with the fly max_len
    human_indexer = OneNucleotideIndexer(max_len, human_indexer)

    with open(fly_indexer_dir, 'rb') as fly_handle:
        fly_indexer = pickle.load(fly_handle)

### Load the model used by EnhancerDetector
### Load the indexer used to encode the sequences into the numerical format the model understands

In [None]:
if not use_fly:
    # The saved network references a project-defined layer and two custom metrics,
    # so they have to be handed to Keras when the model is loaded
    model = load_model(network, custom_objects={'CustomConvLayer': CustomConvLayer, 'specificity': specificity, 'weighted_f1_score': weighted_f1_score})

    # Open the indexer *path* (indexer_dir); `indexer` itself only exists after this load.
    # NOTE: pickle executes arbitrary code on load; only use indexer files from a trusted source
    with open(indexer_dir, 'rb') as f:
        indexer = pickle.load(f)

### Parse input files and grab their names for output and CAM
### Encode the input sequences

In [None]:
# Read every record of the input FASTA file into a list
similar_seq_list = list(SeqIO.parse(similar_sequences_file, "fasta"))

In [None]:
# Sequence identifiers, in input order; used to label the output rows and CAM plots
similar_name_list = [record.id for record in similar_seq_list]

In [None]:
if use_fly:
    # Fly mode encodes the same sequences twice, once per indexer
    matrix_fly = fly_indexer.encode_list(similar_seq_list)
    matrix_human = human_indexer.encode_list(similar_seq_list)
else:
    matrix = indexer.encode_list(similar_seq_list)

### Create a zero tensor with the input shape expected by the model
### Fill in the tensor with the encoded input sequences

In [None]:
# One row per encoded sequence; the fly and human encodings come from the same input list
batch_size = (matrix_fly if use_fly else matrix).shape[0]
row_size = 1
column_size = max_len
channel_size = 1

_input_shape = (batch_size, row_size, column_size, channel_size)

# `tensor` holds the human-indexer encoding in both modes
tensor = np.zeros(_input_shape, dtype=np.int8)
if use_fly:
    # Second tensor for the fly-indexer encoding
    tensor_fly = np.zeros(_input_shape, dtype=np.int8)

In [None]:
# Notebook display: show the shape of the (still zero-filled) input tensor
tensor.shape

In [None]:
# Copy each encoded sequence into its row of the 4D input tensor(s)
if use_fly:
    for seq_idx in range(batch_size):
        tensor_fly[seq_idx, 0, :, 0] = matrix_fly[seq_idx]
        tensor[seq_idx, 0, :, 0] = matrix_human[seq_idx]
else:
    for seq_idx in range(batch_size):
        tensor[seq_idx, 0, :, 0] = matrix[seq_idx]

### Predict the tensor and write results to output file

In [None]:
if use_fly:
    # The two fly networks read the fly-encoded tensor; the fine-tuned network
    # reads the human-encoded one
    pred1 = fly_model1.predict(tensor_fly)
    pred2 = fly_model2.predict(tensor_fly)
    pred3 = fly_model_finetune.predict(tensor)

    # Ensemble prediction = element-wise average of the three networks
    output_prediction = np.mean([pred1, pred2, pred3], axis=0)
else:
    output_prediction = model.predict(tensor)

In [None]:
# First value of every prediction row, rendered with two decimal places
formatted_output = [f"{row[0]:.2f}" for row in output_prediction]

In [None]:
# One line per sequence: "<sequence id> <probability>"
with open(f'{output_dir}/Model_Output.txt', 'w') as report:
    report.writelines(
        f"{seq_name} {score}\n"
        for seq_name, score in zip(similar_name_list, formatted_output)
    )

### Below is the code for building the CAM model
### The CAM model is based on code from Deep Learning with Python by Francois Chollet

In [None]:
# The CAM is derived from a single network: the human model, or fly_model1 in fly mode
_cam_base_model = fly_model1 if use_fly else model

input_tensor = _cam_base_model.input
last_conv_layer = _cam_base_model.get_layer('custom_conv_layer_3')

# Sub-model mapping the network input to the output of that convolutional layer
cam_model = Model(inputs=input_tensor, outputs=last_conv_layer.output)

In [None]:
# Same network choice as for the CAM model: the human model, or fly_model1 in fly mode
_class_base_model = fly_model1 if use_fly else model

first_dense_layer = _class_base_model.get_layer('fc_layer_1')

# Sub-model mapping the input of the first dense layer to the network's prediction
class_model = Model(inputs=first_dense_layer.input, outputs=_class_base_model.output)

In [None]:
def plot_CAM_map(heatmap_interpolated_list, output_dir, name_list, save_pdf):
    """
    Plots one or more CAM heatmaps as stacked horizontal colour strips.

    Inputs:
    - heatmap_interpolated_list (list): 1D arrays, one per sequence. The colour
      scale is fixed to [0, 1], so values are expected in that range.
    - output_dir (str): Output path WITHOUT extension; '.pdf' is appended when
      the figure is saved.
    - name_list (list or str): Sequence names; entry i labels heatmap i. A single
      string labels every heatmap, and a missing entry falls back to the index.
    - save_pdf (bool): If True, saves the figure to '<output_dir>.pdf'.

    Returns:
    - None. The figure is always closed before returning, so nothing is displayed.
    """

    num_sequences = len(heatmap_interpolated_list)  # Get the actual number of heatmaps
    if num_sequences == 0:
        print("No heatmaps to plot.")
        return

    # squeeze=False always returns a 2D array of axes, so a single heatmap
    # needs no special-casing
    fig, axes_grid = plt.subplots(num_sequences, 1, figsize=(8.5, 2 * num_sequences), squeeze=False)
    axs = list(axes_grid[:, 0])

    try:
        for i, heatmap_interpolated in enumerate(heatmap_interpolated_list):
            # matshow needs a 2D image: one row, one column per position
            heatmap_row = np.asarray(heatmap_interpolated).reshape(1, -1)
            image = axs[i].matshow(heatmap_row, cmap='jet', aspect='auto', vmin=0, vmax=1)

            # Customize the plot appearance
            axs[i].set_yticks([])
            axs[i].xaxis.set_ticks_position('bottom')
            axs[i].set_xlim(-0.5, heatmap_row.shape[1])

            # Label the plot with the name of THIS sequence, not the whole list
            if isinstance(name_list, str):
                seq_label = name_list
            elif name_list is not None and i < len(name_list):
                seq_label = name_list[i]
            else:
                seq_label = str(i)
            axs[i].set_title(f'Seq: {seq_label}', fontsize=10)

            # Add x-axis label to each plot
            axs[i].set_xlabel('Nucleotide position')

            # Hide x-ticks for all except the last plot
            if i != num_sequences - 1:
                axs[i].set_xticks([])

            # Hide box around the heatmap
            for spine in axs[i].spines.values():
                spine.set_visible(False)

        # One colour bar shared by all plots (they all use the same 0-1 scale)
        fig.colorbar(image, ax=axs, orientation='vertical', fraction=0.025, pad=0.02)

        # Save this figure explicitly rather than whatever pyplot considers current
        if save_pdf:
            fig.savefig(f'{output_dir}.pdf')
    finally:
        # Release the figure even if plotting or saving failed
        plt.close(fig)

In [None]:
def calculate_cam(x_batch_sample):
    """
    Calculates a gradient-weighted Class Activation Map (CAM) for a single sequence input.

    Uses the notebook-level `cam_model` (network input -> last convolutional
    feature map) and `class_model` (flattened feature map -> prediction).

    Inputs:
    - x_batch_sample (numpy array): A single input tensor, e.g. of shape
      (1, 1, max_len, 1); the batch dimension must be 1.

    Returns:
    - heatmap (numpy array): Non-negative importance of each position of the
      convolutional feature map, scaled so that its maximum is 1. If no position
      contributes positively, the heatmap is all zeros (not NaN). The caller is
      expected to flatten it before use.
    """

    with tf.GradientTape() as tape:
        # Pass input through the cam model and get the output from the last conv layer
        cam_output = cam_model(x_batch_sample, training=False)  # e.g. (1, 1, 22, 512) — TODO confirm per model

        # Watch the feature map BEFORE it is used so every later op is recorded
        tape.watch(cam_output)

        # Flatten the cam output to match the input shape expected by the class model
        cam_output_flattened = tf.reshape(cam_output, (1, -1))

        # Pass the flattened output through the classification model
        preds = class_model(cam_output_flattened, training=False)

        # Prediction for the single sample in the batch (binary classifier output)
        target_class_pred = preds[0]

    # Gradient of the prediction with respect to the feature map (same shape as cam_output)
    grads = tape.gradient(target_class_pred, cam_output)

    # Average the gradients over axes 1 and 2 to get one weight per channel
    pooled_grads = tf.reduce_mean(grads, axis=(1, 2))

    # Drop the batch dimension and weight each channel by its pooled gradient
    feature_map = cam_output[0]
    weighted_map = feature_map * pooled_grads[0]

    # Aggregate across channels to get the heatmap
    heatmap = tf.reduce_mean(weighted_map, axis=-1)

    # Apply ReLU to ensure only positive contributions are kept
    heatmap = tf.nn.relu(heatmap).numpy()

    # Normalize to [0, 1]; skip when the map is all zeros to avoid a 0/0 -> NaN heatmap
    peak = heatmap.max()
    if peak > 0:
        heatmap = heatmap / peak

    return heatmap

In [None]:
def get_sequence(idx):
    """
    Looks up a sequence name in the notebook-level `similar_name_list`.

    Inputs:
    - idx (int): Position of the sequence in the input FASTA list.

    Returns:
    - The sequence identifier stored at that position.
    """
    return similar_name_list[idx]

### The CAM is only generated if output_cam_pdf is True
### The code goes through each input sequence, calculates its CAM, and plots the heatmap
### Each heatmap is then saved to the output directory as a PDF

In [None]:
if output_cam_pdf:
    # In fly mode the CAM model is built from fly_model1, so it reads the fly-encoded tensor
    cam_inputs = tensor_fly if use_fly else tensor

    for i in range(tensor.shape[0]):
        # Slice with i:i+1 to keep the batch dimension
        heatmap = calculate_cam(cam_inputs[i:i+1]).flatten()

        # Stretch the coarse CAM to max_len positions by linear interpolation
        last_position = heatmap.shape[0] - 1
        coarse_positions = np.linspace(0, last_position, num=heatmap.shape[0])
        fine_positions = np.linspace(0, last_position, num=max_len)
        heatmap_interpolated = np.interp(fine_positions, coarse_positions, heatmap)

        # One PDF per sequence, named after the sequence id
        name = get_sequence(i)
        plot_CAM_map([heatmap_interpolated], f'{output_dir}/{name}_CAM', [name], save_pdf=True)