In [1]:
# import pandas as pd
import torch 
import pickle
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
# from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn import GAE, VGAE
from torch_geometric.transforms import RandomLinkSplit
# from torch_geometric.utils import negative_sampling

In [2]:
# node2vec embeddings for all plant proteins 
with open('node2vec_embedding_dict.pkl', 'rb') as f:
    node2vec_embeddings = pickle.load(f)
f.close()

In [3]:
## importing interaccting proteins as 2 lists . first list interact with second list 
with open('first_prot.pkl', 'rb') as f:
    first_prot = pickle.load(f)
f.close()

with open('second_prot.pkl', 'rb') as f:
    second_prot = pickle.load(f)
f.close()

In [None]:
## total number of interacting proteins 
unique_list = list(dict.fromkeys(first_prot + second_prot))
len(unique_list)

In [8]:
## assigning index for each  protein to encode 
ind2node = {index: item for index, item in enumerate(unique_list)} # index to protein mapping 
node2ind = {v: k for k, v in ind2node.items()}  # protein to index mapping 
ind2node = None

In [None]:
# encoded version of both protein lists 
first_prot = [node2ind[protein] for protein in first_prot] 
second_prot = [node2ind[protein] for protein in second_prot]

In [None]:
# Step 1: Convert the interaction data into COO format (edge_index)
# COO format requires edge_index, a 2xN matrix where each column represents an edge (interaction)
edge_index = torch.tensor([first_prot, second_prot], dtype=torch.long)

first_prot = None
second_prot = None

In [None]:
# initial node features
init_node_features = [node2vec_embeddings[protein] for protein in node2ind.keys()]

In [None]:
init_node_features = torch.tensor(init_node_features)

In [None]:
# Step 2: Create a PyTorch Geometric Data object
data = Data(edge_index=edge_index, x= init_node_features)
data

In [12]:
edge_index = torch.randint(0, 1000, (2, 50000)) 
init_node_features = torch.rand(1000, 35)
data = Data(edge_index=edge_index, x= init_node_features)

In [13]:
transform = RandomLinkSplit(is_undirected=True, num_val=0, num_test=0.1, split_labels=True, add_negative_train_samples= False)
train_data, val_data, test_data = transform(data)

In [14]:
train_data

Data(x=[1000, 35], edge_index=[2, 44990], pos_edge_label=[22495], pos_edge_label_index=[2, 22495])

In [15]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
train_data.to(device)
test_data.to(device)

Data(x=[1000, 35], edge_index=[2, 44990], pos_edge_label=[2499], pos_edge_label_index=[2, 2499], neg_edge_label=[2499], neg_edge_label_index=[2, 2499])

In [6]:
# # train_neg_edge_index = train_data.neg_edge_label_index
# # val_neg_edge_index = val_data.neg_edge_label_index.to(device)
# test_neg_edge_index = test_data.neg_edge_label_index

In [2]:
## graph sage model class
class PPIEncdoer(torch.nn.Module):
    def __init__(self, in_channels, mid_channel, out_channels):
        super(PPIEncdoer, self).__init__()
        self.conv1 = SAGEConv(in_channels, mid_channel) 
        self.conv2 = SAGEConv(mid_channel, out_channels)
        self.conv_logstd = SAGEConv(mid_channel, out_channels) ##

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        # x =  self.conv2(x, edge_index)
        # return x
        return self.conv2(x, edge_index), self.conv_logstd(x, edge_index)

In [None]:
# parameters
out_channels = 25 
num_features = 35
mid_channels = 30
epochs = 100

# model
model = VGAE(PPIEncdoer(num_features, mid_channels,out_channels))

# move to GPU 
model = model.to(device)
x = data.x.to(device)

# inizialize the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [13]:
# Manually print out the model summary
def print_model_summary(model):
    print("Model Summary:")
    print(model)
    print("\nModel Parameters:")
    total_params = 0
    for name, param in model.named_parameters():
        print(f"{name}: {param.numel()} parameters")
        total_params += param.numel()
    print(f"\nTotal Parameters: {total_params}")

print_model_summary(model)

Model Summary:
GAE(
  (encoder): PPIEncdoer(
    (conv1): SAGEConv(35, 30, aggr=mean)
    (conv2): SAGEConv(30, 25, aggr=mean)
  )
  (decoder): InnerProductDecoder()
)

Model Parameters:
encoder.conv1.lin_l.weight: 1050 parameters
encoder.conv1.lin_l.bias: 30 parameters
encoder.conv1.lin_r.weight: 1050 parameters
encoder.conv2.lin_l.weight: 750 parameters
encoder.conv2.lin_l.bias: 25 parameters
encoder.conv2.lin_r.weight: 750 parameters

Total Parameters: 3655


In [19]:
def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(x, train_data.edge_index)
    loss = model.recon_loss(z, train_data.pos_edge_label_index)
    loss = loss + (1 / data.num_nodes) * model.kl_loss()
    loss.backward()
    optimizer.step()
    return float(loss)


def test(pos_edge_index, neg_edge_index):
    model.eval()
    with torch.no_grad():
        z = model.encode(x, test_data.edge_index)
    return model.test(z, pos_edge_index, neg_edge_index)

In [None]:
for epoch in range(1, epochs + 1):
    loss = train()

    auc, ap = test(test_data.pos_edge_label_index, test_data.neg_edge_label_index)
    print('Epoch: {:03d}, train loss: {:.3f}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, loss, auc, ap))

In [None]:
torch.save(model.state_dict(), 'VGAE_25_SAGE.pth')