In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Define constants for input sequence lengths (thresholds)
BACTERIUM_THRESHOLD = 7000000  # length for padded bacterium sequence
PHAGE_THRESHOLD = 200000      # length for padded phage sequence

class BacteriaBranch(nn.Module):
    """CNN branch for bacterial DNA sequence."""
    def __init__(self):
        super(BacteriaBranch, self).__init__()
        # Three convolutional layers with specified filters, kernel sizes, and strides
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=64, kernel_size=30, stride=10, bias=True)
        self.pool1 = nn.MaxPool1d(kernel_size=15, stride=5)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=25, stride=10, bias=True)
        self.pool2 = nn.MaxPool1d(kernel_size=10, stride=5)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=10, stride=5, bias=True)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)
    
    def forward(self, x):
        # Expect x shape: (batch, 4, BACTERIUM_THRESHOLD)
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        # Flatten features
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        return x

class PhageBranch(nn.Module):
    """CNN branch for phage DNA sequence."""
    def __init__(self):
        super(PhageBranch, self).__init__()
        # Two convolutional layers for the phage branch
        self.conv1 = nn.Conv1d(in_channels=4, out_channels=64, kernel_size=30, stride=10, bias=True)
        self.pool1 = nn.MaxPool1d(kernel_size=15, stride=5)
        self.conv2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=25, stride=10, bias=True)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
    
    def forward(self, x):
        # Expect x shape: (batch, 4, PHAGE_THRESHOLD)
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = x.view(x.size(0), -1)
        return x

class PerphectInteractionModel(nn.Module):
    """Dual-input CNN model for phage-bacteria interaction prediction."""
    def __init__(self):
        super(PerphectInteractionModel, self).__init__()
        self.bacteria_branch = BacteriaBranch()
        self.phage_branch = PhageBranch()
        # After flattening, expected concatenated feature length = 15296 (8928 + 6368)
        # This is computed from the convolution/pooling sequence given the input lengths
        self.fc1 = nn.Linear(in_features=15296, out_features=100, bias=True)
        self.dropout = nn.Dropout(p=0.1)
        self.fc2 = nn.Linear(in_features=100, out_features=1, bias=True)
        # Sigmoid will be applied in forward for binary classification output
    
    def forward(self, x_bacteria, x_phage):
        # Permute inputs if they are (batch, length, channels) to (batch, channels, length)
        if x_bacteria.dim() == 3 and x_bacteria.size(1) != 4:
            x_bacteria = x_bacteria.permute(0, 2, 1)
        if x_phage.dim() == 3 and x_phage.size(1) != 4:
            x_phage = x_phage.permute(0, 2, 1)
        # Pass through each branch
        feat_bact = self.bacteria_branch(x_bacteria)
        feat_phage = self.phage_branch(x_phage)
        # Concatenate features from both branches
        combined_feat = torch.cat([feat_bact, feat_phage], dim=1)
        # Fully connected layers for prediction
        x = F.relu(self.fc1(combined_feat))
        x = self.dropout(x)
        out = torch.sigmoid(self.fc2(x))  # sigmoid for binary interaction probability
        return out


In [2]:
import h5py
import numpy as np
import torch

def load_keras_weights_to_pytorch(pytorch_model, keras_h5_path):
    """
    Load weights from a Keras .h5 model file into a PyTorch model.
    Assumes `pytorch_model` has the same architecture as the Keras model.
    """
    # Open the Keras weights file
    with h5py.File(keras_h5_path, 'r') as f:
        # Access the 'model_weights' subgroup
        model_weights = f['model_weights']

        state_dict = {}  # will populate with parameter tensors

        # Helper to load conv weights: transpose kernel to PyTorch format
        def copy_conv(layer_name, pytorch_weight_key, pytorch_bias_key):
            keras_kernel = model_weights[layer_name][layer_name]['kernel:0'][()]
            keras_bias = model_weights[layer_name][layer_name]['bias:0'][()]
            # Convert to PyTorch tensor and permute dimensions for kernel
            state_dict[pytorch_weight_key] = torch.tensor(keras_kernel).permute(2, 1, 0)
            state_dict[pytorch_bias_key] = torch.tensor(keras_bias)

        # Bacterial branch conv layers
        copy_conv('bacterial_conv_1', 'bacteria_branch.conv1.weight', 'bacteria_branch.conv1.bias')
        copy_conv('bacterial_conv_2', 'bacteria_branch.conv2.weight', 'bacteria_branch.conv2.bias')
        copy_conv('bacterial_conv_3', 'bacteria_branch.conv3.weight', 'bacteria_branch.conv3.bias')
        # Phage branch conv layers
        copy_conv('phage_conv_1', 'phage_branch.conv1.weight', 'phage_branch.conv1.bias')
        copy_conv('phage_conv_2', 'phage_branch.conv2.weight', 'phage_branch.conv2.bias')

        # Dense layers (fully connected)
        # Keras layer names are 'dense' for the first Dense(100) and 'dense_1' for the final Dense(1)
        dense_kernel = model_weights['dense']['dense']['kernel:0'][()]   # shape (15296, 100)
        dense_bias   = model_weights['dense']['dense']['bias:0'][()]     # shape (100,)
        dense1_kernel = model_weights['dense_1']['dense_1']['kernel:0'][()]  # shape (100, 1)
        dense1_bias   = model_weights['dense_1']['dense_1']['bias:0'][()]    # shape (1,)
        # Transpose dense weight matrices for PyTorch and copy biases
        state_dict['fc1.weight'] = torch.tensor(dense_kernel).t()
        state_dict['fc1.bias']   = torch.tensor(dense_bias)
        state_dict['fc2.weight'] = torch.tensor(dense1_kernel).t()
        state_dict['fc2.bias']   = torch.tensor(dense1_bias)

    # Load state_dict into the PyTorch model
    pytorch_model.load_state_dict(state_dict)
    return pytorch_model



In [3]:
#Example usage:
model = PerphectInteractionModel()
model = load_keras_weights_to_pytorch(model, "/Users/arthurbabey/Desktop/PerphectPredictor/data/saved_model/model_v1.h5")

In [4]:
model

PerphectInteractionModel(
  (bacteria_branch): BacteriaBranch(
    (conv1): Conv1d(4, 64, kernel_size=(30,), stride=(10,))
    (pool1): MaxPool1d(kernel_size=15, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv1d(64, 32, kernel_size=(25,), stride=(10,))
    (pool2): MaxPool1d(kernel_size=10, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv1d(32, 32, kernel_size=(10,), stride=(5,))
    (pool3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (phage_branch): PhageBranch(
    (conv1): Conv1d(4, 64, kernel_size=(30,), stride=(10,))
    (pool1): MaxPool1d(kernel_size=15, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv1d(64, 32, kernel_size=(25,), stride=(10,))
    (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=15296, out_features=100, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc2): Linear(in_features=100, out_features=

In [5]:
import pandas as pd

bacteria_df = pd.read_csv("./data/public_data_set/bacteria_df.csv")
couples_df = pd.read_csv("./data/public_data_set/couples_df.csv")
phages_df = pd.read_csv("./data/public_data_set/phages_df.csv")



In [6]:
bacteria_df

Unnamed: 0,bacterium_id,bacterium_sequence,sequence_length
0,5859,TTGACTGCTGACCCCGACCCACCGTTCGTCGCCGTCTGGAACAGCG...,6988208
1,5787,GGCGCGGTCCTGACGGCACCGGCAGCCGACGCCGCTCCGAAGGAAG...,8409936
2,153,ATGACAGTAGACGAAGCCAACCACGCCAATACTGTCGGAAGTTCCT...,4434904
3,1869,CCCCGTGACTTCAGACCGAGACACTTTCGGCGAGGTGTGGCAGCAG...,5701501
4,126,CCGGTCGACGAGCGGGCTTGTCCCCTGCCGGGCTGGTGCTTCTGGT...,8226158
...,...,...,...
89,244,AGAGATTACGTCTGGTTGCAAGAGATCATGACAGGGGGAATTGGTT...,4857450
90,1053,GTGTCAGTGGAACTTTGGCAGCAGTGCGTGGAGCTTTTGCGCGATG...,6722539
91,586,TAATAGATGGATACTATGTTTTTTGGATTACCCACAGAAACCACAG...,1842899
92,799,TTGTTGATATTCTGTTTTTTCTTTTTTAGTTTTCCACATAAAAAAT...,1900521


In [7]:
phages_df

Unnamed: 0,phage_id,phage_sequence,sequence_length
0,5326,TGCGGCCGCCCCATCCTGTACGGGTTTCCAAGTCGATCGGAGGGCA...,53124
1,6247,GGCTTTCGTGTGAGCCGTGATGTTTTCACGAATATGTGCCCCACCT...,74483
2,1976,GTGGGAATTTTTTTTTTGGGTTGCGCGGTGATCGCCGATGACGACG...,50781
3,430,ATGGCTTCGACTCAGACTCCAGCCGTCGGCAAGACCACGGCCATCG...,71565
4,431,TGCGGCTGAGCCATCGTGTACGGGTTTCCAAGTCCATCAGAGCCGG...,53396
...,...,...,...
3454,6189,TGCGGCTGCCAGATCGTGTACGGGTTTGGAAGTCGACGGAGAGAAC...,49487
3455,6190,ATGCTGGTCATGACCACAAACCACCGCCTCATCTACCTCGTCGGCG...,67744
3456,6191,TGCGGCCGAGGCATCGTGTACGGGTTTCCAAGCCCGAACTAACCAC...,52924
3457,6192,TGCAGATTTTGGTCTGTACGGAACCCGGGGGTTTCGCGGTTTCCCC...,56275


In [8]:
couples_df

Unnamed: 0,id,bacterium_id,phage_id,interaction_type
0,728,5859,5924,1
1,730,5859,1976,1
2,731,5859,430,1
3,732,5859,431,1
4,733,5859,433,1
...,...,...,...,...
4197,27255,153,5913,1
4198,27299,61,6142,1
4199,27307,5731,6150,1
4200,27314,5731,6158,1


In [9]:
import numpy as np

onehot_map = {
    "A": torch.tensor([1, 0, 0, 0], dtype=torch.uint8),
    "C": torch.tensor([0, 1, 0, 0], dtype=torch.uint8),
    "G": torch.tensor([0, 0, 1, 0], dtype=torch.uint8),
    "T": torch.tensor([0, 0, 0, 1], dtype=torch.uint8),
    "R": torch.tensor([1, 0, 1, 0], dtype=torch.uint8),
    "Y": torch.tensor([0, 1, 0, 1], dtype=torch.uint8),
    "K": torch.tensor([0, 0, 1, 1], dtype=torch.uint8),
    "M": torch.tensor([1, 1, 0, 0], dtype=torch.uint8),
    "S": torch.tensor([0, 1, 1, 0], dtype=torch.uint8),
    "W": torch.tensor([1, 0, 0, 1], dtype=torch.uint8),
    "B": torch.tensor([0, 1, 1, 1], dtype=torch.uint8),
    "D": torch.tensor([1, 0, 1, 1], dtype=torch.uint8),
    "H": torch.tensor([1, 1, 0, 1], dtype=torch.uint8),
    "V": torch.tensor([1, 1, 1, 0], dtype=torch.uint8),
    "N": torch.tensor([1, 1, 1, 1], dtype=torch.uint8),
    "Z": torch.tensor([0, 0, 0, 0], dtype=torch.uint8)
}

def translate_sequence_onehot(seq: str):
    translated_seq = np.zeros((len(seq), 4), np.uint8)
    for base_idx, base in enumerate(seq.upper()):
        translated_seq[base_idx, :] = onehot_map[base]
    return translated_seq     

def pad_sequence_onehot(seq: np.ndarray, size: int):
    if seq.shape[0] > size:
        padded_seq = seq[:size, :]
    else:
        padded_seq = np.zeros((size, seq.shape[1]), np.uint8)
        padded_seq[:seq.shape[0], :] = seq
    return padded_seq    

In [10]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

# Fixed lengths for zero padding (as defined in your Keras code)
BACTERIUM_THRESHOLD = 7_000_000
PHAGE_THRESHOLD = 200_000

def encode_pair(row, phages_df, bacteria_df):
    """
    Given a row from couples_df, get the corresponding phage and bacterial sequences,
    translate them to one-hot, pad them to the fixed length, and convert to PyTorch tensors.
    Returns:
        bacterium_tensor: shape (4, BACTERIUM_THRESHOLD)
        phage_tensor: shape (4, PHAGE_THRESHOLD)
    """
    phage_id = row['phage_id']
    bacterium_id = row['bacterium_id']
    
    # Get the raw sequence strings from the DataFrames.
    # It is assumed that phages_df and bacteria_df have 'phage_id' and 'bacterium_id' as columns.
    phage_seq = phages_df.loc[phages_df['phage_id'] == phage_id, 'phage_sequence'].values[0]
    bacterium_seq = bacteria_df.loc[bacteria_df['bacterium_id'] == bacterium_id, 'bacterium_sequence'].values[0]
    
    # Translate sequences to one-hot encoding using your transforms.
    phage_oh = translate_sequence_onehot(phage_seq)       # shape: (seq_length, 4)
    bacterium_oh = translate_sequence_onehot(bacterium_seq) # shape: (seq_length, 4)
    
    # Pad the sequences to fixed length (zero padded).
    phage_padded = pad_sequence_onehot(phage_oh, PHAGE_THRESHOLD)
    bacterium_padded = pad_sequence_onehot(bacterium_oh, BACTERIUM_THRESHOLD)
    
    # PyTorch uses channel-first: convert shape (L, 4) to (4, L) and then to tensor.
    phage_tensor = torch.tensor(phage_padded.T, dtype=torch.float32)
    bacterium_tensor = torch.tensor(bacterium_padded.T, dtype=torch.float32)
    
    return bacterium_tensor, phage_tensor

class InteractionDataset(Dataset):
    """
    A PyTorch Dataset that reads from the couples, phages, and bacteria DataFrames.
    Each sample consists of a bacterial tensor, a phage tensor, and the label.
    """
    def __init__(self, couples_df, phages_df, bacteria_df):
        self.couples_df = couples_df.reset_index(drop=True)
        self.phages_df = phages_df
        self.bacteria_df = bacteria_df

    def __len__(self):
        return len(self.couples_df)

    def __getitem__(self, idx):
        row = self.couples_df.iloc[idx]
        # interaction_type should be 0 or 1
        label = torch.tensor([row['interaction_type']], dtype=torch.float32)
        bacterium_tensor, phage_tensor = encode_pair(row, self.phages_df, self.bacteria_df)
        return bacterium_tensor, phage_tensor, label

bacteria_df = pd.read_csv("./data/public_data_set/bacteria_df.csv")
couples_df = pd.read_csv("./data/public_data_set/couples_df.csv")
phages_df = pd.read_csv("./data/public_data_set/phages_df.csv")


dataset = InteractionDataset(couples_df, phages_df, bacteria_df)
# Using batch_size=1 for simplicity (due to the large sequence sizes)
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [11]:
dataset[0]

  translated_seq[base_idx, :] = onehot_map[base]


(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [1., 1., 0.,  ..., 0., 0., 0.]]),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([1.]))

In [12]:
model

PerphectInteractionModel(
  (bacteria_branch): BacteriaBranch(
    (conv1): Conv1d(4, 64, kernel_size=(30,), stride=(10,))
    (pool1): MaxPool1d(kernel_size=15, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv1d(64, 32, kernel_size=(25,), stride=(10,))
    (pool2): MaxPool1d(kernel_size=10, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv1d(32, 32, kernel_size=(10,), stride=(5,))
    (pool3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (phage_branch): PhageBranch(
    (conv1): Conv1d(4, 64, kernel_size=(30,), stride=(10,))
    (pool1): MaxPool1d(kernel_size=15, stride=5, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv1d(64, 32, kernel_size=(25,), stride=(10,))
    (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc1): Linear(in_features=15296, out_features=100, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (fc2): Linear(in_features=100, out_features=

In [13]:
# move to GPU apple
device = torch.device('cpu')

In [14]:
PYTORCH_ENABLE_MPS_FALLBACK=1

In [15]:
model.eval()
model.to(device)
predictions = []
with torch.no_grad():
    for bact_tensor, phage_tensor, label in dataloader:
        bact_tensor = bact_tensor.to(device)
        phage_tensor = phage_tensor.to(device)
        label = label.to(device)
        # Input tensors should be of shape (batch, 4, length)
        output = model(bact_tensor, phage_tensor)
        predictions.append(output.item())
        print("Predicted interaction probability:", output.item(), "True label:", label.item())

# Optionally, save predictions back to a DataFrame
couples_df['predictions'] = predictions
couples_df.to_csv("couples_predictions.csv", index=False)

  translated_seq[base_idx, :] = onehot_map[base]


Predicted interaction probability: 0.039130888879299164 True label: 1.0
Predicted interaction probability: 0.025687379762530327 True label: 1.0
Predicted interaction probability: 0.012978702783584595 True label: 1.0
Predicted interaction probability: 0.03941984474658966 True label: 1.0
Predicted interaction probability: 0.013013609685003757 True label: 1.0
Predicted interaction probability: 0.04743930324912071 True label: 1.0
Predicted interaction probability: 0.027917126193642616 True label: 1.0
Predicted interaction probability: 0.023473525419831276 True label: 1.0
Predicted interaction probability: 0.01852506771683693 True label: 1.0
Predicted interaction probability: 0.02717985212802887 True label: 1.0
Predicted interaction probability: 0.07235964387655258 True label: 1.0
Predicted interaction probability: 0.03622743859887123 True label: 1.0
Predicted interaction probability: 0.0060615199618041515 True label: 1.0
Predicted interaction probability: 0.010982980020344257 True label: 1

KeyboardInterrupt: 