# AI-Powered Protein-Protein Interaction Prediction using GNN and STRING Dataset
This notebook demonstrates how to use Graph Neural Networks (GNNs) for predicting protein-protein interactions using the STRING dataset.

In [49]:
import pandas as pd
import torch
import torch.nn as nn
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler


In [50]:
df = pd.read_csv('STRING.csv')
scaler = MinMaxScaler()
df['combined_score'] = scaler.fit_transform(df[['combined_score']])

In [51]:
protein1 = df['protein1'].values
protein2 = df['protein2'].values
interactions = df['combined_score'].values

In [52]:
protein_to_idx = {protein: idx for idx, protein in enumerate(set(protein1).union(set(protein2)))}

In [53]:
protein1_idx = [protein_to_idx[p] for p in protein1]
protein2_idx = [protein_to_idx[p] for p in protein2]

In [54]:
edge_index = torch.tensor([protein1_idx, protein2_idx], dtype=torch.long)

In [55]:
num_proteins = len(protein_to_idx)

In [56]:
x = torch.eye(num_proteins)

In [57]:
y = torch.tensor(interactions, dtype=torch.float).unsqueeze(1)

In [58]:
data = Data(x=x, edge_index=edge_index, y=y)
print(data)

Data(x=[19484, 19484], edge_index=[2, 1048575], y=[1048575, 1])


In [59]:
train_idx, test_idx = train_test_split(range(edge_index.shape[1]), test_size=0.2, random_state=42)

In [60]:
train_edge_index = edge_index[:, train_idx]
test_edge_index = edge_index[:, test_idx]

In [61]:
train_labels = y[train_idx]
test_labels = y[test_idx]

In [62]:
train_data = Data(x=x, edge_index=train_edge_index, y=train_labels)
test_data = Data(x=x, edge_index=test_edge_index, y=test_labels)

In [63]:
train_loader = DataLoader([train_data], batch_size=32, shuffle=True)
test_loader = DataLoader([test_data], batch_size=32, shuffle=False)



In [64]:
class GNN(nn.Module):
    def __init__(self, num_nodes, embedding_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()
        self.embedding = nn.Embedding(num_nodes, embedding_dim)
        self.conv1 = GCNConv(embedding_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)
        self.fc = nn.Linear(output_dim, 1)  # 1 for regression output

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.embedding(torch.arange(x.size(0), device=x.device))
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)
        edge_predictions = self.fc((x[edge_index[0]] + x[edge_index[1]]) / 2)
        return edge_predictions


In [65]:
model = GNN(num_nodes=19484, embedding_dim=128,hidden_dim=64, output_dim=32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

In [66]:
def train():
    model.train()
    for data in train_loader:
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
    return loss.item()


In [67]:
def evaluate(loader):
    threshold = 0.5
    model.eval()
    all_preds = []
    all_labels = []
    for data in loader:
        with torch.no_grad():
            out = model(data)
            all_preds.append(out)
            all_labels.append(data.y)
    
    preds = torch.cat(all_preds, dim=0).cpu().numpy()
    labels = torch.cat(all_labels, dim=0).cpu().numpy()
    binary_preds = (preds >= threshold).astype(int)
    binary_labels = (labels >= threshold).astype(int)
    accuracy = accuracy_score(binary_labels, binary_preds)
    return accuracy

In [68]:
best_val_acc = 0
for epoch in range(1, 101):
    loss = train()
    train_acc = evaluate(train_loader)
    val_acc = evaluate(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')

Epoch: 001, Loss: 0.0388, Train Acc: 0.9397, Val Acc: 0.9400
Epoch: 002, Loss: 0.0898, Train Acc: 0.9397, Val Acc: 0.9395
Epoch: 003, Loss: 0.0433, Train Acc: 0.9397, Val Acc: 0.9348
Epoch: 004, Loss: 0.0349, Train Acc: 0.9397, Val Acc: 0.8883
Epoch: 005, Loss: 0.0426, Train Acc: 0.9397, Val Acc: 0.8687
Epoch: 006, Loss: 0.0432, Train Acc: 0.9397, Val Acc: 0.9118
Epoch: 007, Loss: 0.0385, Train Acc: 0.9397, Val Acc: 0.9316
Epoch: 008, Loss: 0.0353, Train Acc: 0.9397, Val Acc: 0.9368
Epoch: 009, Loss: 0.0345, Train Acc: 0.9397, Val Acc: 0.9384
Epoch: 010, Loss: 0.0351, Train Acc: 0.9397, Val Acc: 0.9388
Epoch: 011, Loss: 0.0359, Train Acc: 0.9397, Val Acc: 0.9390
Epoch: 012, Loss: 0.0365, Train Acc: 0.9397, Val Acc: 0.9393
Epoch: 013, Loss: 0.0367, Train Acc: 0.9397, Val Acc: 0.9396
Epoch: 014, Loss: 0.0366, Train Acc: 0.9397, Val Acc: 0.9396
Epoch: 015, Loss: 0.0363, Train Acc: 0.9397, Val Acc: 0.9396
Epoch: 016, Loss: 0.0359, Train Acc: 0.9397, Val Acc: 0.9395
Epoch: 017, Loss: 0.0355

In [69]:
final_accuracy = evaluate(test_loader)
print(f"Final Model Evaluation - Accuracy: {final_accuracy:.4f}")

Final Model Evaluation - Accuracy: 0.9371


In [70]:
model.load_state_dict(torch.load("best_model.pth"))

  model.load_state_dict(torch.load("best_model.pth"))


<All keys matched successfully>

In [71]:
for epoch in range(1, 101):
    loss = train()
    train_acc = evaluate(train_loader)
    val_acc = evaluate(test_loader)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}')
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'best_model.pth')

Epoch: 001, Loss: 0.0898, Train Acc: 0.9397, Val Acc: 0.9389
Epoch: 002, Loss: 0.0378, Train Acc: 0.9397, Val Acc: 0.8122
Epoch: 003, Loss: 0.0465, Train Acc: 0.9397, Val Acc: 0.5936
Epoch: 004, Loss: 0.0536, Train Acc: 0.9397, Val Acc: 0.8729
Epoch: 005, Loss: 0.0400, Train Acc: 0.9397, Val Acc: 0.9323
Epoch: 006, Loss: 0.0347, Train Acc: 0.9397, Val Acc: 0.9381
Epoch: 007, Loss: 0.0355, Train Acc: 0.9397, Val Acc: 0.9388
Epoch: 008, Loss: 0.0371, Train Acc: 0.9397, Val Acc: 0.9396
Epoch: 009, Loss: 0.0379, Train Acc: 0.9397, Val Acc: 0.9400
Epoch: 010, Loss: 0.0378, Train Acc: 0.9397, Val Acc: 0.9400
Epoch: 011, Loss: 0.0372, Train Acc: 0.9397, Val Acc: 0.9400
Epoch: 012, Loss: 0.0363, Train Acc: 0.9397, Val Acc: 0.9400
Epoch: 013, Loss: 0.0355, Train Acc: 0.9397, Val Acc: 0.9400
Epoch: 014, Loss: 0.0349, Train Acc: 0.9397, Val Acc: 0.9398
Epoch: 015, Loss: 0.0347, Train Acc: 0.9397, Val Acc: 0.9393
Epoch: 016, Loss: 0.0350, Train Acc: 0.9397, Val Acc: 0.9391
Epoch: 017, Loss: 0.0355

In [72]:
final_accuracy = evaluate(test_loader)
print(f"Final Model Evaluation - Accuracy: {final_accuracy:.4f}")

Final Model Evaluation - Accuracy: 0.9388


In [75]:
import numpy as np

def get_top_interactions(loader, model, top_n=5):
    model.eval()  
    all_preds = []
    all_labels = []
    all_edges = []
    
    for data in loader:
        with torch.no_grad():
            out = model(data)  
            all_preds.append(out)
            all_labels.append(data.y)
            all_edges.append(data.edge_index.T) 

    preds = torch.cat(all_preds, dim=0).cpu().numpy().flatten()
    labels = torch.cat(all_labels, dim=0).cpu().numpy().flatten()
    edges = torch.cat(all_edges, dim=0).cpu().numpy()

    differences = np.abs(preds - labels)
   
    top_indices = np.argsort(differences)[:top_n]  
    
    print("Top Interactions (Protein1, Protein2) with Predicted and Actual Values:")
    for idx in top_indices:
        protein1, protein2 = edges[idx]
        print(f"Proteins: ({protein1}, {protein2}), Predicted: {preds[idx]:.4f}, Actual: {labels[idx]:.4f}")

get_top_interactions(test_loader, model, top_n=5)


Top Interactions (Protein1, Protein2) with Predicted and Actual Values:
Proteins: (12330, 15216), Predicted: 0.0989, Actual: 0.0989
Proteins: (153, 8696), Predicted: 0.1625, Actual: 0.1625
Proteins: (14563, 11703), Predicted: 0.1531, Actual: 0.1531
Proteins: (5998, 9302), Predicted: 0.0836, Actual: 0.0836
Proteins: (9058, 2211), Predicted: 0.0212, Actual: 0.0212
