In [1]:
!pip install torch torch-geometric biopython transformers



Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting biopython
  Downloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (13 kB)
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m23.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m74.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython, torch-geometric
Successfully installed biopython-1.86 torch-geometric-2.7.0


In [9]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
import random

# CONFIG SEC
EMBEDDINGS_PATH = "esm2_embeddings_1143_proteins.pkl"
MODEL_PATH = "kyroform_ek.pth"
EDGES_PATH = "training_edges_with_labels.csv"

#Model archi
class HeteroSAGE(torch.nn.Module):
    def __init__(self, input_dim=1280, hidden=256):
        super().__init__()
        # Layer 1: 1280 -> 256
        self.h_conv1 = SAGEConv(input_dim, hidden)
        # Layer 2: 256 -> 256 (Matches your checkpoint)
        self.h_conv2 = SAGEConv(hidden, hidden)

        self.b_conv1 = SAGEConv(input_dim, hidden)
        self.b_conv2 = SAGEConv(hidden, hidden)

    def forward(self, x_dict, edge_index_dict):
        # Extract the interaction edges
        edge_index = edge_index_dict[('human', 'interacts', 'bacterial')]
        rev_edge_index = edge_index.flip(0)

        # Human Path: Sequential processing
        h = F.relu(self.h_conv1(x_dict['human'], rev_edge_index))
        h = F.relu(self.h_conv2(h, rev_edge_index))

        # Bacterial Path: Sequential processing
        b = F.relu(self.b_conv1(x_dict['bacterial'], edge_index))
        b = F.relu(self.b_conv2(b, edge_index))

        return {'human': h, 'bacterial': b}

print("Loading emb")
with open(EMBEDDINGS_PATH, 'rb') as f:
    embeds = pickle.load(f)
print(f"Loaded embeddings for {len(embeds)} proteins")

print("Loading model")
model = HeteroSAGE(input_dim=1280, hidden=256)

# Load state dict with strict=True to ensure every layer matches perfecto-ly
try:
    model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
    model.eval()
    print("Model loaded successfully — architecture matched UwU")
except RuntimeError as e:
    print(f"Error loading model: {e}")

# Load edges for protein lists
df_edges = pd.read_csv(EDGES_PATH)
all_human = [p for p in df_edges['human'].unique() if p in embeds]
all_bact = [p for p in df_edges['bacterial'].unique() if p in embeds]

print(f"Ready: {len(all_human)} human, {len(all_bact)} bacterial proteins available.")

Loading precomputed ESM-2 embeddings...
Loaded embeddings for 1138 proteins
Loading trained model...
Error loading model: Error(s) in loading state_dict for HeteroSAGE:
	size mismatch for h_conv2.lin_l.weight: copying a param with shape torch.Size([256, 1280]) from checkpoint, the shape in current model is torch.Size([256, 256]).
	size mismatch for b_conv2.lin_l.weight: copying a param with shape torch.Size([256, 1280]) from checkpoint, the shape in current model is torch.Size([256, 256]).
Ready: 855 human, 283 bacterial proteins available.


In [10]:
def predict_interaction(human_id, bacterial_id):
    if human_id not in embeds or bacterial_id not in embeds:
        return 0.0

    data = HeteroData()
    data['human'].x = torch.tensor(embeds[human_id]).unsqueeze(0)
    data['bacterial'].x = torch.tensor(embeds[bacterial_id]).unsqueeze(0)

    # Provide a dummy edge for the SAGEConv logic to process could replace it later with SVS Logic?
    data['human', 'interacts', 'bacterial'].edge_index = torch.zeros((2, 1), dtype=torch.long)

    with torch.no_grad():
        z = model(data.x_dict, data.edge_index_dict)
        # Calculate dot-product similarity (or your specific scoring logic)
        score = (z['human'][0] * z['bacterial'][0]).sum().item()
        prob = torch.sigmoid(torch.tensor(score)).item()

    return prob

In [12]:
print("TESTING SAMPLE PREDICTIONS UwU")


random.seed(42)
for _ in range(5):
    h = random.choice(all_human)
    b = random.choice(all_bact)
    prob = predict_interaction(h, b)
    result = "Positive" if prob > 0.5 else "Negative"
    print(f"Pair: {h} + {b} | Score: {prob:.4f} | Prediction: {result}")

TESTING SAMPLE PREDICTIONS UwU
Pair: C9J9G2 + A0A0J6C408 | Score: 0.7088 | Prediction: Positive
Pair: Q5VV89 + A0A0J6C367 | Score: 0.6758 | Prediction: Positive
Pair: H3BLU7 + A0A0J6C5T3 | Score: 0.6108 | Prediction: Positive
Pair: A0A6Q8PFH2 + A0A0J6C5W4 | Score: 0.6473 | Prediction: Positive
Pair: O95670 + A0A0J6C2K2 | Score: 0.8040 | Prediction: Positive
