<a href="https://colab.research.google.com/github/BCB4PM/GL4SDA/blob/main/GL4SDA_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train and validate the GL4SDA model

## Add this to a Google Colab cell to install the PyTorch Geometric build that matches the runtime's PyTorch and CUDA versions.



```
import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric
```




In [None]:
import torch

def format_pytorch_version(version):
  """Strip the local build tag from a torch version string, e.g. '2.5.1+cu124' -> '2.5.1'."""
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  """Map ``torch.version.cuda`` to the PyG wheel-index tag.

  '12.4' -> 'cu124'. ``None`` (a CPU-only PyTorch build, where ``torch.version.cuda``
  is None) -> 'cpu', the tag PyG uses for its CPU wheels (e.g. torch-2.5.1+cpu.html).
  """
  if version is None:
    # Without this guard a CPU-only runtime crashes with AttributeError on None.replace.
    return 'cpu'
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

# Install PyG's compiled extension packages from the wheel index that matches this
# runtime's torch version and CUDA tag (IPython substitutes the Python variables
# {TORCH} and {CUDA} into these shell commands).
# NOTE(review): `%pip install` targets the running kernel's environment more reliably
# than `!pip`, and pinning versions (e.g. torch-geometric==X.Y.Z) would make the
# install reproducible — consider both.
!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
# torch-geometric itself is installed from the default package index (no -f wheel index).
!pip install torch-geometric

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html


## Import python libraries.

In [None]:

import os
import random
import urllib.request

import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GATv2Conv, GraphConv, HeteroConv
import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric import seed_everything
from sklearn.metrics import accuracy_score, precision_score, recall_score, matthews_corrcoef, f1_score, roc_auc_score

## Set the training-graph location and the random seeds (the graph itself is loaded later, together with the hyperparameters).

In [1]:
# Location of the training graph: directory + file name.
# NOTE: this is a GitHub web ("tree") URL. It serves an HTML page rather than the file,
# and torch.load cannot open URLs, so the file must be fetched to a local path
# (raw.githubusercontent.com serves the raw bytes) before it can be loaded.
traindir = "https://github.com/BCB4PM/GL4SDA/tree/main/data/graph_data/train/"
input_graph = "hetero_graph_2-4_train.pkl"

# Two independent seeds, so the data split and the model can be varied separately.
seed = 41        # train/validation splitting
seed_model = 35  # model initialisation and training


## Define the final classifier.
Our final classifier applies the dot-product between source and destination
node embeddings to derive edge-level predictions:

In [None]:
class Classifier(torch.nn.Module):
    """Score supervision edges by the dot product of their two endpoint embeddings.

    The parameter names appear to be inherited from a user/movie example; in this
    notebook the GNN passes snoRNA embeddings as ``x_user`` and disease embeddings
    as ``x_movie``.

    Args:
        x_user: source-node embeddings, one row per node.
        x_movie: destination-node embeddings, one row per node.
        edge_label_index: row 0 indexes ``x_user``, row 1 indexes ``x_movie``.

    Returns:
        One unnormalised score (logit) per edge column of ``edge_label_index``.
    """
    def forward(self, x_user: Tensor, x_movie: Tensor, edge_label_index: Tensor) -> Tensor:
        src_rows = edge_label_index[0]
        dst_rows = edge_label_index[1]
        # Gather both endpoints of every edge, then reduce their elementwise product.
        products = x_user[src_rows] * x_movie[dst_rows]
        return products.sum(dim=-1)

## Define the bipartite GNN model.
The class below implements a three-layer heterogeneous GNN (GraphConv → GATv2Conv → GraphConv) over both directions of the snoRNA–disease relation; edge weights are not used.

In [None]:
class HeteroGraphGNN(torch.nn.Module):
    """Three-layer heterogeneous GNN over the snoRNA-disease bipartite graph.

    Each layer is a HeteroConv with one operator per edge direction whose outputs are
    summed per node type:
      1. GraphConv  -> hidden_channels[0]
      2. GATv2Conv  -> hidden_channels[1] (``num_heads`` heads, averaged: concat=False)
      3. GraphConv  -> hidden_channels[2]
    Input sizes are inferred lazily on the first forward pass (``(-1, -1)``). ReLU and
    dropout follow every layer except the last, and the final snoRNA/disease embeddings
    are scored edge-by-edge with ``Classifier`` (dot product).

    Args:
        hidden_channels: three output widths, one per layer.
        dropout_rate: dropout probability applied between layers.
        num_heads: number of attention heads in the GATv2Conv layer.
    """

    # Both directions of the relation; every layer holds one operator per entry.
    _EDGE_TYPES = (('snorna', 'to', 'disease'), ('disease', 'rev_to', 'snorna'))

    def __init__(self, hidden_channels, dropout_rate=0.4, num_heads=8):
        super().__init__()

        def bipartite_layer(make_conv):
            # Operators are created in _EDGE_TYPES order, so parameter initialisation
            # consumes the RNG in the same order as a literal dict would.
            return HeteroConv({edge_type: make_conv() for edge_type in self._EDGE_TYPES}, aggr='sum')

        self.convs = torch.nn.ModuleList([
            bipartite_layer(lambda: GraphConv((-1, -1), hidden_channels[0])),
            bipartite_layer(lambda: GATv2Conv((-1, -1), hidden_channels[1], add_self_loops=False,
                                              heads=num_heads, concat=False)),
            bipartite_layer(lambda: GraphConv((-1, -1), hidden_channels[2])),
        ])

        self.p = dropout_rate

        self.classifier = Classifier()

    def forward(self, data, x_dict, edge_index):
        """Return one logit per edge in ``data['snorna', 'to', 'disease'].edge_label_index``.

        Args:
            data: mini-batch graph holding the supervision edges to score.
            x_dict: node type -> node feature tensor.
            edge_index: edge type -> edge index used for message passing.
        """
        last_layer = len(self.convs) - 1
        h = x_dict
        for depth, conv in enumerate(self.convs):
            h = conv(h, edge_index)
            if depth < last_layer:
                # ReLU then dropout for every node type; dropout is inactive in eval mode.
                h = {node_type: F.dropout(z.relu(), p=self.p, training=self.training)
                     for node_type, z in h.items()}

        supervision_edges = data['snorna', 'to', 'disease'].edge_label_index
        return self.classifier(h['snorna'], h['disease'], supervision_edges)



## Define the model hyperparameters and load the training graph.

In [None]:
# Model / optimisation hyperparameters.
n_epochs = 500

hidden_1 = 128
hidden_2 = 128
hidden_3 = 64
hidden_channels = [hidden_1, hidden_2, hidden_3]  # output widths of conv layers 1, 2, 3

lr = 0.0001
# BCEWithLogitsLoss applies the sigmoid internally (more stably than a separate
# sigmoid + BCELoss), so the model must output raw logits.
criterion = torch.nn.BCEWithLogitsLoss()


def _resolve_graph_path(path: str) -> str:
    """Return a local file path for ``path``, downloading it first if it is an http(s) URL.

    torch.load cannot open URLs, and a GitHub ``.../tree/...`` or ``.../blob/...`` URL
    serves an HTML page rather than the file. Such URLs are therefore rewritten to the
    raw.githubusercontent.com endpoint. The file is saved under its base name in the
    current working directory, and an existing copy is reused instead of re-downloaded.
    """
    if not path.startswith(("http://", "https://")):
        return path
    raw_url = path.replace("https://github.com/", "https://raw.githubusercontent.com/", 1)
    if raw_url != path:
        # github.com/<owner>/<repo>/tree/<ref>/... -> raw.githubusercontent.com/<owner>/<repo>/<ref>/...
        raw_url = raw_url.replace("/tree/", "/", 1).replace("/blob/", "/", 1)
    local_path = os.path.basename(raw_url)
    if not os.path.exists(local_path):
        urllib.request.urlretrieve(raw_url, local_path)
    return local_path


# Load the training graph. The file holds a full pickled graph object (it is indexed by
# edge type later), not plain tensors, so weights_only=False is required; PyTorch >= 2.6
# defaults to weights_only=True and would refuse to load it.
# SECURITY: this unpickles arbitrary objects - only load files from a trusted source.
graph = torch.load(_resolve_graph_path(traindir + input_graph), weights_only=False)

# Set the seed for reproducibility.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)            # if you are using multi-GPU.
np.random.seed(seed)                        # Numpy module.
random.seed(seed)                           # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

  graph = torch.load(traindir+input_graph)


## Split the graph into training and validation sets and add negative samples to the validation set

In [None]:
# Re-seed immediately before splitting so this cell is reproducible on its own:
# re-running it yields the same train/validation split instead of depending on RNG
# state left behind by other cells. seed_everything seeds Python `random`, NumPy and
# PyTorch with the same `seed` used earlier, so a top-to-bottom run should be unaffected.
seed_everything(seed)

transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.0,
    is_undirected = True,
    disjoint_train_ratio=0.2, # ratio between supervision and message passing edges
    neg_sampling_ratio=1.0,   # negatives are added to the validation split here
    add_negative_train_samples=False,  # training negatives come from the loader instead
    edge_types=("snorna", "to", "disease"),
    rev_edge_types=("disease", "rev_to", "snorna"),
    )

# num_test=0.0, so the third split is empty and unused.
train_data, val_data, empty_data = transform(graph)


# Define train seed edges; the loader samples negative edges for each mini-batch
# (2 negatives per positive).
edge_label_index = train_data["snorna", "to", "disease"].edge_label_index
edge_label = train_data["snorna", "to", "disease"].edge_label
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 15],
    neg_sampling_ratio=2.0,
    edge_label_index=(("snorna", "to", "disease"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    # NOTE(review): training loaders usually shuffle; changing this would change the
    # reported results, so it is left as-is.
    shuffle=False,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# set the seed for the GNN model
seed_everything(seed_model)

## Instantiate the model and start training.

In [None]:
# Build the model in double precision on the selected device.
# NOTE(review): double() presumably matches the dtype of the stored node features - confirm.
model = HeteroGraphGNN(hidden_channels=hidden_channels).double().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # Choose optimizer.


def _train_one_epoch(loader):
    """Run one optimisation pass over ``loader`` and return the mean per-edge loss.

    Each batch loss is weighted by the number of scored edges in that batch, so the
    result is the average over all edges seen in the epoch.
    """
    model.train()
    loss_sum = 0
    n_scored = 0
    for batch in loader:
        optimizer.zero_grad()
        batch.to(device)
        logits = model(batch, batch.x_dict, batch.edge_index_dict)
        labels = batch["snorna", "to", "disease"].edge_label
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()
        loss_sum += float(batch_loss) * logits.numel()
        n_scored += logits.numel()
    return loss_sum / n_scored


# start training
for epoch in range(n_epochs):
    epoch_loss = _train_one_epoch(train_loader)
    print(f"Epoch: {epoch:03d}, Loss: {epoch_loss:.4f}")


Epoch: 000, Loss: 0.6954
Epoch: 001, Loss: 0.6855
Epoch: 002, Loss: 0.6889
Epoch: 003, Loss: 0.6751
Epoch: 004, Loss: 0.6642
Epoch: 005, Loss: 0.6765
Epoch: 006, Loss: 0.6683
Epoch: 007, Loss: 0.6672
Epoch: 008, Loss: 0.6578
Epoch: 009, Loss: 0.6673
Epoch: 010, Loss: 0.6602
Epoch: 011, Loss: 0.6489
Epoch: 012, Loss: 0.6465
Epoch: 013, Loss: 0.6466
Epoch: 014, Loss: 0.6521
Epoch: 015, Loss: 0.6380
Epoch: 016, Loss: 0.6523
Epoch: 017, Loss: 0.6656
Epoch: 018, Loss: 0.6260
Epoch: 019, Loss: 0.6231
Epoch: 020, Loss: 0.6344
Epoch: 021, Loss: 0.6347
Epoch: 022, Loss: 0.6387
Epoch: 023, Loss: 0.6440
Epoch: 024, Loss: 0.6233
Epoch: 025, Loss: 0.6313
Epoch: 026, Loss: 0.6403
Epoch: 027, Loss: 0.6039
Epoch: 028, Loss: 0.6362
Epoch: 029, Loss: 0.6245
Epoch: 030, Loss: 0.6063
Epoch: 031, Loss: 0.5953
Epoch: 032, Loss: 0.5876
Epoch: 033, Loss: 0.5726
Epoch: 034, Loss: 0.5946
Epoch: 035, Loss: 0.5679
Epoch: 036, Loss: 0.5919
Epoch: 037, Loss: 0.5775
Epoch: 038, Loss: 0.5631
Epoch: 039, Loss: 0.5755


## Validate GL4SDA model.
No extra negative sampling is needed here: `RandomLinkSplit` (with `neg_sampling_ratio=1.0`) already added negative edges to the validation set.

In [None]:
# Define validation seed edges:
edge_label_index = val_data["snorna", "to", "disease"].edge_label_index
edge_label = val_data["snorna", "to", "disease"].edge_label
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[20, 15],
    edge_label_index=(("snorna", "to", "disease"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=False,
)

preds = []
ground_truths = []
with torch.no_grad():
    model.eval()
    for sampled_val_data in val_loader:
        sampled_val_data.to(device)
        preds.append(model(sampled_val_data,sampled_val_data.x_dict,sampled_val_data.edge_index_dict))
        ground_truths.append(sampled_val_data["snorna", "to", "disease"].edge_label)
pred_row = torch.cat(preds, dim=0).cpu()
pred = torch.cat(preds, dim=0).cpu().numpy()
ground_row = torch.cat(ground_truths, dim=0).cpu()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
auc_val = roc_auc_score(ground_truth, pred)
acc_val = accuracy_score( ground_truth, pred_row.sigmoid().round() )
val_pre = precision_score(ground_truth,pred_row.sigmoid().round())
val_rec = recall_score(ground_truth,pred_row.sigmoid().round())
val_mcc = matthews_corrcoef(ground_truth,pred_row.sigmoid().round())
val_f1 = f1_score(ground_truth,pred_row.sigmoid().round())

In [None]:
# print validated output
print()
print(f"Validation AUC: {auc_val:.4f}")
print(f"Validation ACC: {acc_val:.4f}")
print(f"Validation Prec: {val_pre:.4f}")
print(f"Validation Recall: {val_rec:.4f}")
print(f"Validation MCC: {val_mcc:.4f}")
print(f"Validation F-1: {val_f1:.4f}")


Validation AUC: 0.9158
Validation ACC: 0.8049
Validation Prec: 0.9167
Validation Recall: 0.6707
Validation MCC: 0.6330
Validation F-1: 0.7746
