In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Fixed sequence lengths (in positions) that the inputs of each branch are padded to.
BACTERIUM_THRESHOLD = 7_000_000  # padded bacterium sequence length
PHAGE_THRESHOLD = 200_000        # padded phage sequence length

class BacteriaBranch(nn.Module):
    """Convolutional feature extractor for the bacterial DNA input.

    The branch is three ``Conv1d -> ReLU -> MaxPool1d`` stages followed by a
    flatten. It takes a tensor of shape ``(batch, 4, length)`` and returns a
    2-D tensor ``(batch, 32 * remaining_length)``. The flatten is done on the
    ``(channels, length)`` map, so the channel index varies slowest in the
    returned feature vector.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 4 -> 64 channels.
        self.conv1 = nn.Conv1d(4, 64, kernel_size=30, stride=10)
        self.pool1 = nn.MaxPool1d(15, stride=5)
        # Stage 2: 64 -> 32 channels.
        self.conv2 = nn.Conv1d(64, 32, kernel_size=25, stride=10)
        self.pool2 = nn.MaxPool1d(10, stride=5)
        # Stage 3: 32 -> 32 channels.
        self.conv3 = nn.Conv1d(32, 32, kernel_size=10, stride=5)
        self.pool3 = nn.MaxPool1d(2, stride=2)

    def forward(self, x):
        """Run the three conv/pool stages on ``x`` and flatten the result per sample."""
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in stages:
            x = pool(F.relu(conv(x)))
        # Collapse (channels, length) into one feature axis.
        return x.view(x.shape[0], -1)

class PhageBranch(nn.Module):
    """Convolutional feature extractor for the phage DNA input.

    The branch is two ``Conv1d -> ReLU -> MaxPool1d`` stages followed by a
    flatten. It takes a tensor of shape ``(batch, 4, length)`` and returns a
    2-D tensor ``(batch, 32 * remaining_length)``. The flatten is done on the
    ``(channels, length)`` map, so the channel index varies slowest in the
    returned feature vector.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 4 -> 64 channels.
        self.conv1 = nn.Conv1d(4, 64, kernel_size=30, stride=10)
        self.pool1 = nn.MaxPool1d(15, stride=5)
        # Stage 2: 64 -> 32 channels.
        self.conv2 = nn.Conv1d(64, 32, kernel_size=25, stride=10)
        self.pool2 = nn.MaxPool1d(2, stride=2)

    def forward(self, x):
        """Run both conv/pool stages on ``x`` and flatten the result per sample."""
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            x = pool(F.relu(conv(x)))
        # Collapse (channels, length) into one feature axis.
        return x.view(x.shape[0], -1)

class PerphectInteractionModel(nn.Module):
    """Two-input CNN that scores a (bacterium, phage) pair.

    Each sequence goes through its own convolutional branch; the flattened
    branch features are concatenated (bacterium first, phage second) and fed
    to a ``Linear(15296, 100) -> ReLU -> Dropout(0.1) -> Linear(100, 1)``
    head whose output is squashed with a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.bacteria_branch = BacteriaBranch()
        self.phage_branch = PhageBranch()
        # 15296 = 8928 bacterial + 6368 phage features, the flattened branch
        # outputs obtained for inputs padded to the two length thresholds.
        self.fc1 = nn.Linear(15296, 100)
        self.dropout = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(100, 1)

    @staticmethod
    def _channels_first(x):
        """Return ``x`` as (batch, channels, length), swapping the last two axes if needed."""
        # A 3-D tensor whose second axis is not 4 is taken to be (batch, length, channels).
        looks_channels_last = x.dim() == 3 and x.size(1) != 4
        return x.permute(0, 2, 1) if looks_channels_last else x

    def forward(self, x_bacteria, x_phage):
        """Return the interaction probability, shape (batch, 1), for the given pair(s)."""
        bacteria_features = self.bacteria_branch(self._channels_first(x_bacteria))
        phage_features = self.phage_branch(self._channels_first(x_phage))
        merged = torch.cat((bacteria_features, phage_features), dim=1)
        hidden = self.dropout(F.relu(self.fc1(merged)))
        # Sigmoid turns the single logit into a probability.
        return torch.sigmoid(self.fc2(hidden))


In [2]:
import h5py
import numpy as np
import torch

def _conv_stack_output_length(branch, input_length):
    """Return the sequence length left after all Conv1d/MaxPool1d children of ``branch``.

    Assumes the children are registered in the order ``forward`` applies them,
    use integer padding and keep the default ``ceil_mode=False``.
    """
    def _scalar(value):
        # Conv1d stores its hyper-parameters as 1-tuples, MaxPool1d as given.
        return value[0] if isinstance(value, (tuple, list)) else value

    length = int(input_length)
    for layer in branch.children():
        if not isinstance(layer, (torch.nn.Conv1d, torch.nn.MaxPool1d)):
            continue
        kernel = _scalar(layer.kernel_size)
        stride = _scalar(layer.stride)
        padding = _scalar(layer.padding)
        dilation = _scalar(layer.dilation)
        length = (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    return length


def _keras_rows_to_torch_order(rows, channels, length):
    """Reorder Dense-kernel rows from Keras flatten order to PyTorch flatten order.

    Keras flattens a (length, channels) map, so row ``l * channels + c`` belongs
    to position ``l`` of channel ``c``. PyTorch flattens (channels, length), so
    the same feature sits at index ``c * length + l``.
    """
    units = rows.shape[1]
    return (rows.reshape(length, channels, units)
                .transpose(1, 0, 2)
                .reshape(channels * length, units))


def load_keras_weights_to_pytorch(pytorch_model, keras_h5_path,
                                  bacteria_length=7000000, phage_length=200000):
    """
    Load weights from a Keras .h5 model file into a PyTorch model.

    Conv kernels are permuted from Keras ``(kernel, in, out)`` to PyTorch
    ``(out, in, kernel)`` and Dense kernels are transposed. The rows of the
    first Dense kernel are additionally reordered: Keras flattens each branch
    length-major, whereas the PyTorch branches flatten channel-major, so a
    plain transpose would pair every weight with the wrong feature.

    NOTE(review): this assumes the Keras model used channels-last Conv1D
    layers, flattened each branch separately and concatenated bacterium
    features before phage features. Confirm against the Keras model definition.

    Args:
        pytorch_model: model exposing ``bacteria_branch`` (conv1..conv3),
            ``phage_branch`` (conv1..conv2), ``fc1`` and ``fc2``.
        keras_h5_path: path of the Keras ``.h5`` file holding ``model_weights``.
        bacteria_length: padded bacterium input length the Keras model was
            built for (default equals the file's BACTERIUM_THRESHOLD).
        phage_length: padded phage input length (default equals PHAGE_THRESHOLD).

    Returns:
        ``pytorch_model`` with the converted weights loaded.

    Raises:
        ValueError: if the feature layout derived from the model and the input
            lengths does not match the number of rows of the first Dense kernel.
    """
    conv_layers = (
        ('bacterial_conv_1', 'bacteria_branch.conv1'),
        ('bacterial_conv_2', 'bacteria_branch.conv2'),
        ('bacterial_conv_3', 'bacteria_branch.conv3'),
        ('phage_conv_1', 'phage_branch.conv1'),
        ('phage_conv_2', 'phage_branch.conv2'),
    )

    state_dict = {}
    with h5py.File(keras_h5_path, 'r') as f:
        model_weights = f['model_weights']

        def read(layer_name, variable):
            # Variables are stored as <layer>/<layer>/<variable>; [()] reads the full dataset.
            return np.asarray(model_weights[layer_name][layer_name][variable][()])

        for keras_name, torch_prefix in conv_layers:
            kernel = read(keras_name, 'kernel:0')
            state_dict[torch_prefix + '.weight'] = torch.tensor(kernel).permute(2, 1, 0).contiguous()
            state_dict[torch_prefix + '.bias'] = torch.tensor(read(keras_name, 'bias:0'))

        # 'dense' is the first Dense(100) layer, 'dense_1' the final Dense(1).
        dense_kernel = read('dense', 'kernel:0')      # (features, 100)
        dense_bias = read('dense', 'bias:0')          # (100,)
        dense1_kernel = read('dense_1', 'kernel:0')   # (100, 1)
        dense1_bias = read('dense_1', 'bias:0')       # (1,)

    # Work out the (channels, length) layout of each branch's flattened output.
    bact_channels = state_dict['bacteria_branch.conv3.weight'].shape[0]
    phage_channels = state_dict['phage_branch.conv2.weight'].shape[0]
    bact_len = _conv_stack_output_length(pytorch_model.bacteria_branch, bacteria_length)
    phage_len = _conv_stack_output_length(pytorch_model.phage_branch, phage_length)
    n_bact = bact_channels * bact_len
    n_phage = phage_channels * phage_len
    if n_bact + n_phage != dense_kernel.shape[0]:
        raise ValueError(
            "Dense kernel has %d input features but the branches produce %d + %d; "
            "check bacteria_length/phage_length and the model architecture"
            % (dense_kernel.shape[0], n_bact, n_phage)
        )

    fc1_rows = np.concatenate(
        [
            _keras_rows_to_torch_order(dense_kernel[:n_bact], bact_channels, bact_len),
            _keras_rows_to_torch_order(dense_kernel[n_bact:], phage_channels, phage_len),
        ],
        axis=0,
    )
    state_dict['fc1.weight'] = torch.tensor(fc1_rows).t().contiguous()
    state_dict['fc1.bias'] = torch.tensor(dense_bias)
    state_dict['fc2.weight'] = torch.tensor(dense1_kernel).t().contiguous()
    state_dict['fc2.bias'] = torch.tensor(dense1_bias)

    # Strict load: every parameter of the model must have been converted above.
    pytorch_model.load_state_dict(state_dict)
    return pytorch_model

# Example usage:
# model = PerphectInteractionModel()
# model = load_keras_weights_to_pytorch(model, "model_v1.h5")


ModuleNotFoundError: No module named 'h5py'

In [3]:
class GradCAM:
    """Grad-CAM for the dual-branch 1D CNN model (phage-bacteria interaction).

    Forward hooks capture the output of one layer per branch, by default
    ``model.bacteria_branch.conv3`` and ``model.phage_branch.conv2``. The hooks
    see the layer's raw output, i.e. before whatever activation and pooling the
    model applies afterwards.

    The hooks stay attached to ``model`` until ``remove_hooks()`` is called.
    They only request gradient retention when the captured tensor takes part in
    autograd, so ordinary inference under ``torch.no_grad()`` keeps working
    while a GradCAM instance exists.
    """

    def __init__(self, model, target_layer_bacteria=None, target_layer_phage=None):
        """Attach the feature-capturing hooks.

        Args:
            model: dual-input model; it is switched to eval mode.
            target_layer_bacteria: layer to explain in the bacterial branch
                (defaults to ``model.bacteria_branch.conv3``).
            target_layer_phage: layer to explain in the phage branch
                (defaults to ``model.phage_branch.conv2``).
        """
        self.model = model.eval()  # eval mode disables dropout during explanation
        if target_layer_bacteria is None:
            target_layer_bacteria = model.bacteria_branch.conv3
        if target_layer_phage is None:
            target_layer_phage = model.phage_branch.conv2
        self.target_layer_bacteria = target_layer_bacteria
        self.target_layer_phage = target_layer_phage

        # Feature maps captured by the most recent forward pass.
        self.bact_features = None
        self.phage_features = None

        # Keep the handles so the hooks can be detached again.
        self._hook_handles = [
            target_layer_bacteria.register_forward_hook(self._save_bact_feature),
            target_layer_phage.register_forward_hook(self._save_phage_feature),
        ]

    def remove_hooks(self):
        """Detach the forward hooks from the model (safe to call more than once)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_bact_feature(self, module, input, output):
        """Forward hook: store the bacterial feature map and keep its gradient."""
        self.bact_features = output
        # retain_grad() raises on tensors outside autograd (e.g. under no_grad).
        if output.requires_grad:
            output.retain_grad()

    def _save_phage_feature(self, module, input, output):
        """Forward hook: store the phage feature map and keep its gradient."""
        self.phage_features = output
        if output.requires_grad:
            output.retain_grad()

    @staticmethod
    def _heatmap(features, grads):
        """Turn one captured feature map and its gradient into a normalized 1-D heatmap."""
        # Channel weights: gradient averaged over the length axis -> (batch, channels, 1).
        weights = grads.mean(dim=2, keepdim=True)
        # Weighted channel sum, keeping only positive evidence -> (batch, length).
        cam = F.relu((weights * features).sum(dim=1))
        cam = cam.detach().cpu().numpy()[0]  # first sample only
        peak = cam.max()
        if peak != 0:
            cam = cam / peak
        return cam

    def generate(self, bact_input, phage_input):
        """
        Compute Grad-CAM heatmaps for the given inputs.

        Inputs are the two tensors the model expects, normally
        ``(1, 4, length)`` each. If the batch holds several samples, the
        heatmaps explain the first one.

        Returns:
            cam_bact, cam_phage: 1-D numpy arrays scaled to [0, 1] (all zeros
            if nothing contributes positively). Their lengths equal the
            lengths of the hooked feature maps. With the default layers and
            inputs padded to 7,000,000 / 200,000 positions that is 558 and 398,
            because the hooks sit before the final pooling that halves them
            to 279 / 199.

        Raises:
            RuntimeError: if a target layer did not run during the forward
                pass or received no gradient.
        """
        self.model.eval()
        self.model.zero_grad()
        self.bact_features = None
        self.phage_features = None

        # Gradients are required even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            output = self.model(bact_input, phage_input)
            # Explain the first sample's score; backward() needs a scalar.
            target_score = output.reshape(-1)[0]
            target_score.backward()

        if self.bact_features is None or self.phage_features is None:
            raise RuntimeError("a Grad-CAM target layer was not executed during the forward pass")
        grad_bact = self.bact_features.grad
        grad_phage = self.phage_features.grad
        if grad_bact is None or grad_phage is None:
            raise RuntimeError("no gradient reached a Grad-CAM target layer")

        cam_bact = self._heatmap(self.bact_features, grad_bact)
        cam_phage = self._heatmap(self.phage_features, grad_phage)
        return cam_bact, cam_phage

# Example usage:
# gradcam = GradCAM(model)
# cam_bact, cam_phage = gradcam.generate(bact_tensor, phage_tensor)


In [4]:
import numpy as np
import matplotlib.pyplot as plt

def plot_sequence_gradcam(sequence, importance_scores, start=None, end=None, window=100, cmap='coolwarm'):
    """
    Plot a segment of the sequence with Grad-CAM importance scores overlaid as a heatmap.

    Args:
        sequence: DNA sequence string (e.g., "ACGT...").
        importance_scores: 1-D array-like of Grad-CAM scores, expected in [0, 1].
            It may have one score per base, or fewer scores than bases (a raw
            Grad-CAM map). In the latter case each score is treated as covering
            an equal share of the sequence and the scores are linearly
            interpolated onto the plotted bases. That mapping is an
            approximation of the network's receptive fields.
        start, end: optional bounds (end exclusive) of the region to plot, in
            sequence coordinates. If both are None the region is centred on
            the highest score. If only one is given, the other is placed
            ``window`` bases away.
        window: region width used when ``start`` and/or ``end`` is missing.
        cmap: colormap for the heatmap (default 'coolwarm').

    Returns:
        (fig, ax): the matplotlib figure and axes of the plot.

    Raises:
        ValueError: if the sequence or the scores are empty, or the requested
            region contains no bases.
    """
    seq_len = len(sequence)
    scores = np.asarray(importance_scores, dtype=float).ravel()
    if seq_len == 0 or scores.size == 0:
        raise ValueError("sequence and importance_scores must be non-empty")

    # Number of bases represented by one score (1.0 when scores are per base).
    bases_per_score = seq_len / scores.size

    # Determine the region to visualize, in sequence coordinates.
    if start is None and end is None:
        peak = int(np.argmax(scores))
        center = int((peak + 0.5) * bases_per_score)
        half_win = window // 2
        start, end = center - half_win, center + half_win
    elif start is None:
        start = end - window
    elif end is None:
        end = start + window
    start = max(0, int(start))
    end = min(seq_len, int(end))
    if start >= end:
        raise ValueError(f"empty region: start={start}, end={end}")

    seq_region = sequence[start:end]
    if scores.size == seq_len:
        scores_region = scores[start:end]
    else:
        # Interpolate between the centres of the score bins at each base centre.
        bin_centers = (np.arange(scores.size) + 0.5) * bases_per_score
        base_centers = np.arange(start, end) + 0.5
        scores_region = np.interp(base_centers, bin_centers, scores)

    # One-row heatmap with a nucleotide label under every column.
    fig, ax = plt.subplots(figsize=(max(10, 0.2 * len(seq_region)), 2))
    im = ax.imshow(scores_region[np.newaxis, :], aspect='auto', cmap=cmap, vmin=0, vmax=1)
    ax.set_xticks(np.arange(len(seq_region)))
    ax.set_xticklabels(list(seq_region))
    ax.set_yticks([])  # single row: the y-axis carries no information
    ax.set_xlabel(f"Sequence positions {start} to {end-1}")
    # Long regions need rotated, smaller labels to stay readable.
    if len(seq_region) > 20:
        plt.setp(ax.get_xticklabels(), rotation=90, fontsize=8)
    fig.colorbar(im, ax=ax, fraction=0.015, pad=0.1, label='Grad-CAM importance')
    ax.set_title("Grad-CAM highlight on sequence segment")
    fig.tight_layout()
    plt.show()
    return fig, ax

# Example usage (after obtaining cam_bact and cam_phage from GradCAM):
# original_bact_seq = "ACGTTG... (length 7000000, including 'Z' for padding)" 
# fig, ax = plot_sequence_gradcam(original_bact_seq, cam_bact, window=100)
