In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.data import Data
from torch_geometric.nn import RGCNConv
from torch_geometric.utils import negative_sampling
from transformers import BertModel, BertTokenizer
from tqdm import tqdm
import pickle, os

In [2]:
# Load dataset
# -----------------------------
df = pd.read_csv("df_actions_27k.csv")

In [3]:
df.head()

Unnamed: 0,sequence_a,sequence_b,item_id_a,item_id_b,mode,is_directional,a_is_acting,score
0,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGER...,9606.ENSP00000000233,9606.ENSP00000250971,reaction,t,t,900
1,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MALWMRLLPLLALLALWGPDPAAAFVNQHLCGSHLVEALYLVCGER...,9606.ENSP00000000233,9606.ENSP00000250971,reaction,t,f,900
2,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MTECFLPPTSSPSEHRRVEHGSGLTRTPSSEEISPTKFPGLYRTGE...,9606.ENSP00000000233,9606.ENSP00000019317,activation,f,f,175
3,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MQQAPQPYEFFSEENSPKWRGLLVSALRKVQEQVHPTLSANEESLY...,9606.ENSP00000000233,9606.ENSP00000216373,reaction,f,f,161
4,MGLTVSALFSRIFGKKQMRILMVGLDAAGKTTILYKLKLGEIVTTI...,MAMAEGERTECAEPPRDEPPADGALKRAEELKTQANDYFKAKDYEN...,9606.ENSP00000000233,9606.ENSP00000012443,catalysis,t,f,155


In [4]:
# Map interaction mode to integer (used as edge type)
df['edge_type'] = pd.factorize(df['mode'])[0]
mode_to_int = dict(zip(df['mode'], df['edge_type']))

In [5]:
# Create protein-to-sequence map
# -----------------------------
proteins_a = df[["item_id_a", "sequence_a"]].rename(columns={"item_id_a": "item_id", "sequence_a": "sequence"})
proteins_b = df[["item_id_b", "sequence_b"]].rename(columns={"item_id_b": "item_id", "sequence_b": "sequence"})
all_proteins = pd.concat([proteins_a, proteins_b]).drop_duplicates("item_id").set_index("item_id")
protein_to_idx = {pid: i for i, pid in enumerate(all_proteins.index)}

In [6]:
# ProtBERT Embedding
# -----------------------------
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert", do_lower_case=False)
model = BertModel.from_pretrained("Rostlab/prot_bert").eval()

def embed_sequence(seq):
    seq = seq.replace(" ", "").upper()
    seq = " ".join(list(seq))
    tokens = tokenizer(seq, return_tensors="pt", truncation=True, padding=True, max_length=1024)
    with torch.no_grad():
        output = model(**tokens)
    return output.last_hidden_state.mean(dim=1).squeeze().numpy()

  return self.fget.__get__(instance, owner)()


In [7]:
# Embed protein sequences (cached)
# -----------------------------
cache_path = "protbert_embeddings_linkprediction_148k.pkl"
if os.path.exists(cache_path):
    with open(cache_path, "rb") as f:
        protein_embeddings = pickle.load(f)
else:
    protein_embeddings = {}
    for pid in tqdm(all_proteins.index, desc="Embedding proteins"):
        try:
            protein_embeddings[pid] = embed_sequence(all_proteins.loc[pid, "sequence"])
        except Exception as e:
            print(f"Error embedding {pid}: {e}")
    with open(cache_path, "wb") as f:
        pickle.dump(protein_embeddings, f)

Embedding proteins: 100%|████████████████████████████████████████████████████████| 1690/1690 [1:55:44<00:00,  4.11s/it]


In [8]:
# Create feature matrix
# -----------------------------
embedding_dim = len(next(iter(protein_embeddings.values())))
x = np.zeros((len(protein_to_idx), embedding_dim), dtype=np.float32)
for pid, idx in protein_to_idx.items():
    x[idx] = protein_embeddings[pid]
x = torch.tensor(x, dtype=torch.float)

In [9]:
# Build edges and edge types
# -----------------------------
# Direction-aware edge construction
src_nodes = []
dst_nodes = []
edge_types = []

relation_to_id = {}
rel_id_counter = 0

for _, row in df.iterrows():
    a = row["item_id_a"]
    b = row["item_id_b"]
    mode = row["mode"]
    is_dir = row["is_directional"]
    a_acts = row["a_is_acting"]

    if is_dir == "t":
        if a_acts == "t":
            src = protein_to_idx[a]
            dst = protein_to_idx[b]
            rel = f"{mode}_forward"
        else:
            src = protein_to_idx[b]
            dst = protein_to_idx[a]
            rel = f"{mode}_reverse"
    else:
        # undirected → add both edges
        rel = f"{mode}_bidirectional"
        src_nodes.append(protein_to_idx[a])
        dst_nodes.append(protein_to_idx[b])
        if rel not in relation_to_id:
            relation_to_id[rel] = rel_id_counter
            rel_id_counter += 1
        edge_types.append(relation_to_id[rel])

        src_nodes.append(protein_to_idx[b])
        dst_nodes.append(protein_to_idx[a])
        edge_types.append(relation_to_id[rel])
        continue  # skip to next row

    if rel not in relation_to_id:
        relation_to_id[rel] = rel_id_counter
        rel_id_counter += 1

    src_nodes.append(src)
    dst_nodes.append(dst)
    edge_types.append(relation_to_id[rel])

edge_index = torch.tensor([src_nodes, dst_nodes], dtype=torch.long)
edge_type = torch.tensor(edge_types, dtype=torch.long)

print(f"✅ Total unique edge types: {len(relation_to_id)}")

✅ Total unique edge types: 20


In [10]:
# Build PyG graph object
# -----------------------------
data = Data(x=x, edge_index=edge_index, edge_type=edge_type)

# -----------------------------
# Split for link prediction
# -----------------------------
ei = data.edge_index.numpy()
et = data.edge_type.numpy()

ei_train, ei_test, et_train, et_test = train_test_split(ei.T, et, test_size=0.2, random_state=42)
ei_train, ei_val, et_train, et_val = train_test_split(ei_train, et_train, test_size=0.1, random_state=42)

ei_train = torch.tensor(ei_train, dtype=torch.long).t()
ei_val = torch.tensor(ei_val, dtype=torch.long).t()
ei_test = torch.tensor(ei_test, dtype=torch.long).t()

et_train = torch.tensor(et_train, dtype=torch.long)
et_val = torch.tensor(et_val, dtype=torch.long)
et_test = torch.tensor(et_test, dtype=torch.long)

# Generate negative samples
neg_train = negative_sampling(ei_train, num_nodes=data.num_nodes)
neg_val = negative_sampling(ei_val, num_nodes=data.num_nodes)
neg_test = negative_sampling(ei_test, num_nodes=data.num_nodes)

In [11]:
# RGCN Model
# -----------------------------
class RGCNLinkPredictor(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_relations):
        super().__init__()
        self.conv1 = RGCNConv(in_dim, hidden_dim, num_relations)
        self.conv2 = RGCNConv(hidden_dim, out_dim, num_relations)

    def encode(self, x, edge_index, edge_type):
        x = F.relu(self.conv1(x, edge_index, edge_type))
        return self.conv2(x, edge_index, edge_type)

    def decode(self, z, edge_index):
        return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

model = RGCNLinkPredictor(data.num_features, hidden_dim=64, out_dim=32, num_relations=len(relation_to_id))
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

In [12]:
# Training function
# -----------------------------
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(data.x, ei_train, et_train)
    pos_score = model.decode(z, ei_train)
    neg_score = model.decode(z, neg_train)

    pos_labels = torch.ones_like(pos_score)
    neg_labels = torch.zeros_like(neg_score)

    loss = F.binary_cross_entropy_with_logits(pos_score, pos_labels) + \
           F.binary_cross_entropy_with_logits(neg_score, neg_labels)

    loss.backward()
    optimizer.step()
    return loss.item()

In [13]:
# Evaluation function
# -----------------------------
def evaluate(ei_pos, et_pos, ei_neg):
    model.eval()
    with torch.no_grad():
        z = model.encode(data.x, ei_pos, et_pos)
        pos_score = model.decode(z, ei_pos).sigmoid().cpu().numpy()
        neg_score = model.decode(z, ei_neg).sigmoid().cpu().numpy()

        y_true = np.concatenate([np.ones_like(pos_score), np.zeros_like(neg_score)])
        y_score = np.concatenate([pos_score, neg_score])

        roc = roc_auc_score(y_true, y_score)
        pr = average_precision_score(y_true, y_score)

        return roc, pr

In [14]:
# Training loop
# -----------------------------
for epoch in range(1, 201):
    loss = train()
    if epoch % 10 == 0:
        roc, pr = evaluate(ei_val, et_val, neg_val)
        print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val ROC-AUC: {roc:.4f} | PR-AUC: {pr:.4f}")

# Final test evaluation
roc, pr = evaluate(ei_test, et_test, neg_test)
print(f"\n🧪 Test ROC-AUC: {roc:.4f} | Test PR-AUC: {pr:.4f}")

Epoch 010 | Loss: 1.3706 | Val ROC-AUC: 0.3899 | PR-AUC: 0.4716
Epoch 020 | Loss: 1.3162 | Val ROC-AUC: 0.5627 | PR-AUC: 0.6363
Epoch 030 | Loss: 1.2565 | Val ROC-AUC: 0.6912 | PR-AUC: 0.7718
Epoch 040 | Loss: 1.2081 | Val ROC-AUC: 0.7733 | PR-AUC: 0.8301
Epoch 050 | Loss: 1.1452 | Val ROC-AUC: 0.8014 | PR-AUC: 0.8543
Epoch 060 | Loss: 1.0669 | Val ROC-AUC: 0.7763 | PR-AUC: 0.8381
Epoch 070 | Loss: 1.0382 | Val ROC-AUC: 0.7549 | PR-AUC: 0.8243
Epoch 080 | Loss: 1.0076 | Val ROC-AUC: 0.7369 | PR-AUC: 0.8148
Epoch 090 | Loss: 0.9768 | Val ROC-AUC: 0.7449 | PR-AUC: 0.8212
Epoch 100 | Loss: 0.9526 | Val ROC-AUC: 0.7481 | PR-AUC: 0.8224
Epoch 110 | Loss: 0.9301 | Val ROC-AUC: 0.7553 | PR-AUC: 0.8255
Epoch 120 | Loss: 0.9160 | Val ROC-AUC: 0.7473 | PR-AUC: 0.8185
Epoch 130 | Loss: 0.9056 | Val ROC-AUC: 0.7423 | PR-AUC: 0.8151
Epoch 140 | Loss: 0.8958 | Val ROC-AUC: 0.7446 | PR-AUC: 0.8164
Epoch 150 | Loss: 0.8860 | Val ROC-AUC: 0.7464 | PR-AUC: 0.8174
Epoch 160 | Loss: 0.8744 | Val ROC-AUC: 

In [15]:
# Save model
torch.save(model.state_dict(), "protbert_rgcn_link_predictor_modes_148k.pth")