<a href="https://colab.research.google.com/github/BCB4PM/GL4SDA/blob/main/GL4SDA_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train and validate the GL4SDA model

## Install PyTorch Geometric in Google Colab
Add the following to a Google Colab cell to install the PyTorch Geometric packages built for the runtime's PyTorch and CUDA versions.



```
import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric
```




In [1]:
import torch

def format_pytorch_version(version):
  return version.split('+')[0]

TORCH_version = torch.__version__
TORCH = format_pytorch_version(TORCH_version)

def format_cuda_version(version):
  return 'cu' + version.replace('.', '')

CUDA_version = torch.version.cuda
CUDA = format_cuda_version(CUDA_version)

!pip install torch-scatter -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://data.pyg.org/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_scatter-2.1.2%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (10.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.8/10.8 MB[0m [31m30.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2+pt25cu124
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.5.0%2Bcu124/torch_sparse-0.6.18%2Bpt25cu124-cp311-cp311-linux_x86_64.whl (5.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.18+pt25cu124
Looking in links: https://data.pyg.org/whl/torch-2.5.1+cu124.html
Collecting torch-cluster
  Downloading https://data.p

## Import python libraries.

In [2]:
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv, GATv2Conv, GraphConv, HeteroConv
import torch_geometric.transforms as T
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric import seed_everything
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score, matthews_corrcoef, f1_score, roc_auc_score

## Get sample data and set seed parameters.

In [3]:
file_download_link = "https://docs.google.com/uc?export=download&id=1SMmwi1WU1WQqTc_PpCyiYGzeVqZrKbeB"

!wget -O hetero_graph_2-4_train.pkl --no-check-certificate "$file_download_link"

input_graph = "hetero_graph_2-4_train.pkl" # training graph input file

seed_model = 35 #seed for model training
seed = 41 # seed for train/validation splitting

--2025-02-26 08:34:21--  https://docs.google.com/uc?export=download&id=1SMmwi1WU1WQqTc_PpCyiYGzeVqZrKbeB
Resolving docs.google.com (docs.google.com)... 142.251.170.139, 142.251.170.102, 142.251.170.113, ...
Connecting to docs.google.com (docs.google.com)|142.251.170.139|:443... connected.
HTTP request sent, awaiting response... 303 See Other
Location: https://drive.usercontent.google.com/download?id=1SMmwi1WU1WQqTc_PpCyiYGzeVqZrKbeB&export=download [following]
--2025-02-26 08:34:22--  https://drive.usercontent.google.com/download?id=1SMmwi1WU1WQqTc_PpCyiYGzeVqZrKbeB&export=download
Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 64.233.188.132, 2404:6800:4008:c02::84
Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|64.233.188.132|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2758761 (2.6M) [application/octet-stream]
Saving to: ‘hetero_graph_2-4_train.pkl’


2025-02-26 08:34:25 (56.7 MB/s) - ‘hetero_graph_2-4_

## Define the final classifier.
The final classifier takes the dot product of the source (snoRNA) and destination (disease)
node embeddings, producing one edge-level prediction (a logit) per candidate link:

In [4]:
class Classifier(torch.nn.Module):
    """Score candidate links with the dot product of their endpoint embeddings.

    The argument names appear to be carried over from a user/movie example; in
    this notebook ``x_user`` receives the 'snorna' embeddings and ``x_movie`` the
    'disease' embeddings (see ``HeteroGraphGNN.forward``).
    """

    def forward(self, x_user: Tensor, x_movie: Tensor, edge_label_index: Tensor) -> Tensor:
        """Return one raw (pre-sigmoid) score per supervision edge.

        Args:
            x_user: source-node embeddings, one row per source node.
            x_movie: destination-node embeddings, one row per destination node.
            edge_label_index: index tensor whose row 0 holds source-node indices
                and row 1 destination-node indices.

        Returns:
            The row-wise dot product of the two gathered embeddings, one value per edge.
        """
        src_rows = edge_label_index[0]
        dst_rows = edge_label_index[1]
        # Gather each edge's endpoint embeddings, multiply element-wise and
        # reduce over the feature axis.
        pairwise = x_user[src_rows] * x_movie[dst_rows]
        return pairwise.sum(dim=-1)

## Define the bipartite GNN model.
The class below implements a three-layer heterogeneous GNN over snoRNA–disease edges (GraphConv → GATv2Conv → GraphConv). Edge weights are not used.

In [5]:
class HeteroGraphGNN(torch.nn.Module):
    """Three-layer heterogeneous GNN that scores snoRNA -> disease links.

    Every layer holds one operator per relation (('snorna', 'to', 'disease') and
    its reverse ('disease', 'rev_to', 'snorna')), wrapped in a HeteroConv with
    aggr='sum':
      1. GraphConv  -> hidden_channels[0]
      2. GATv2Conv  -> hidden_channels[1] (num_heads heads, concat=False, no self-loops)
      3. GraphConv  -> hidden_channels[2]
    ReLU and dropout follow every layer except the last. The final node
    embeddings are scored by a dot-product ``Classifier``.

    Args:
        hidden_channels: indexable of (at least) three output widths, one per layer.
        dropout_rate: dropout probability applied between layers.
        num_heads: number of attention heads in the GATv2 layer.
    """

    def __init__(self, hidden_channels, dropout_rate=0.4, num_heads=8):
        super().__init__()

        forward_rel = ('snorna', 'to', 'disease')
        reverse_rel = ('disease', 'rev_to', 'snorna')

        def graph_conv(width):
            # (-1, -1): source/destination input sizes are inferred lazily.
            return GraphConv((-1, -1), width)

        def gat_conv(width):
            # Bipartite relation, so no self-loops; concat=False combines heads
            # into a single `width`-sized output.
            return GATv2Conv((-1, -1), width, add_self_loops=False,
                             heads=num_heads, concat=False)

        layer_specs = (
            (graph_conv, hidden_channels[0]),
            (gat_conv, hidden_channels[1]),
            (graph_conv, hidden_channels[2]),
        )

        self.convs = torch.nn.ModuleList()
        for make_conv, width in layer_specs:
            # Forward relation is built before the reverse one, layer by layer,
            # so parameter creation order is fixed.
            self.convs.append(HeteroConv(
                {forward_rel: make_conv(width), reverse_rel: make_conv(width)},
                aggr='sum',
            ))

        self.p = dropout_rate
        self.classifier = Classifier()

    def forward(self, data, x_dict, edge_index):
        """Embed all nodes, then score the supervision edges stored in ``data``.

        Args:
            data: (mini-batch) graph; only its ('snorna', 'to', 'disease')
                ``edge_label_index`` is read here.
            x_dict: node type -> node feature tensor.
            edge_index: edge type -> edge_index mapping (the notebook passes
                ``edge_index_dict``).

        Returns:
            One logit per supervision edge.
        """
        *hidden_convs, output_conv = self.convs

        for conv in hidden_convs:
            x_dict = conv(x_dict, edge_index)
            x_dict = {
                node_type: F.dropout(h.relu(), p=self.p, training=self.training)
                for node_type, h in x_dict.items()
            }

        # No activation/dropout after the last layer: its output feeds the dot product.
        x_dict = output_conv(x_dict, edge_index)

        return self.classifier(
            x_dict['snorna'],
            x_dict['disease'],
            data['snorna', 'to', 'disease'].edge_label_index,
        )

## Define the model hyperparameters and load the training graph.

In [6]:
n_epochs = 500

hidden_1 = 128
hidden_2 = 128
hidden_3 = 64
hidden_channels = [hidden_1, hidden_2, hidden_3]

lr = 0.0001
# BCEWithLogitsLoss applies the sigmoid internally, so the model must output raw logits.
criterion = torch.nn.BCEWithLogitsLoss()

# Load the training graph downloaded above.
# weights_only=False is required because the file is a pickled PyG graph object
# (it is passed to RandomLinkSplit below), not a plain tensor state dict.
# PyTorch 2.5 warns about the implicit default and PyTorch >= 2.6 defaults to
# weights_only=True, which refuses such files.
# SECURITY: this unpickles arbitrary Python objects; only load files from a trusted source.
graph = torch.load(input_graph, weights_only=False)

# Seed every RNG before the train/validation split for reproducibility.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)            # if you are using multi-GPU.
np.random.seed(seed)                        # Numpy module.
random.seed(seed)                           # Python random module.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

  graph = torch.load(input_graph)


## Split the edges into training and validation sets (negative samples are added to the validation set)

In [7]:
# Relation to predict and its reverse counterpart.
target_edge = ("snorna", "to", "disease")
reverse_edge = ("disease", "rev_to", "snorna")

# Hold out 10% of the target edges for validation (no test split). Of the
# remaining training edges, 20% are used only as supervision targets and are kept
# out of the message-passing graph (disjoint_train_ratio). One negative per
# positive is added to the validation split; none are added to the training split.
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.0,
    is_undirected=True,
    disjoint_train_ratio=0.2,
    neg_sampling_ratio=1.0,
    add_negative_train_samples=False,
    edge_types=target_edge,
    rev_edge_types=reverse_edge,
)
train_data, val_data, empty_data = transform(graph)  # empty_data: unused (num_test=0.0)

# Training mini-batches are seeded from the training supervision edges; the
# loader samples 2 negatives per positive on the fly and 20 then 15 neighbours
# for the two hops.
# NOTE(review): shuffle=False keeps the same batch order every epoch; shuffle=True
# is the usual choice for training but would change the reported results.
train_supervision = train_data[target_edge]
edge_label_index = train_supervision.edge_label_index
edge_label = train_supervision.edge_label
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 15],
    neg_sampling_ratio=2.0,
    edge_label_index=(target_edge, edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=False,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Re-seed with the model seed so weight initialisation and training are
# reproducible independently of the split seed.
seed_everything(seed_model)



## Instantiate the model and start training.

In [8]:
model = HeteroGraphGNN(hidden_channels=hidden_channels)
model = model.double()  # float64 parameters, presumably to match the graph's feature dtype — TODO confirm
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

supervised_edge = ("snorna", "to", "disease")

# Train for n_epochs, reporting the example-weighted mean loss of each epoch.
for epoch in range(n_epochs):
    model.train()
    loss_sum, n_seen = 0, 0
    for batch in train_loader:
        optimizer.zero_grad()
        batch.to(device)
        logits = model(batch, batch.x_dict, batch.edge_index_dict)
        targets = batch[supervised_edge].edge_label
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()
        n_edges = logits.numel()
        loss_sum += float(loss) * n_edges
        n_seen += n_edges
    print(f"Epoch: {epoch:03d}, Loss: {loss_sum / n_seen:.4f}")


Epoch: 000, Loss: 0.6899
Epoch: 001, Loss: 0.6819
Epoch: 002, Loss: 0.7001
Epoch: 003, Loss: 0.6940
Epoch: 004, Loss: 0.6691
Epoch: 005, Loss: 0.6781
Epoch: 006, Loss: 0.6753
Epoch: 007, Loss: 0.6657
Epoch: 008, Loss: 0.6795
Epoch: 009, Loss: 0.6754
Epoch: 010, Loss: 0.6685
Epoch: 011, Loss: 0.6568
Epoch: 012, Loss: 0.6581
Epoch: 013, Loss: 0.6626
Epoch: 014, Loss: 0.6534
Epoch: 015, Loss: 0.6469
Epoch: 016, Loss: 0.6406
Epoch: 017, Loss: 0.6701
Epoch: 018, Loss: 0.6474
Epoch: 019, Loss: 0.6344
Epoch: 020, Loss: 0.6169
Epoch: 021, Loss: 0.6194
Epoch: 022, Loss: 0.5958
Epoch: 023, Loss: 0.6110
Epoch: 024, Loss: 0.6108
Epoch: 025, Loss: 0.5977
Epoch: 026, Loss: 0.5760
Epoch: 027, Loss: 0.5477
Epoch: 028, Loss: 0.5479
Epoch: 029, Loss: 0.5939
Epoch: 030, Loss: 0.5921
Epoch: 031, Loss: 0.5537
Epoch: 032, Loss: 0.5641
Epoch: 033, Loss: 0.5454
Epoch: 034, Loss: 0.5830
Epoch: 035, Loss: 0.5773
Epoch: 036, Loss: 0.5524
Epoch: 037, Loss: 0.5507
Epoch: 038, Loss: 0.5492
Epoch: 039, Loss: 0.5769


## Validate the GL4SDA model.
The validation loader does not need to sample negatives: `RandomLinkSplit` (with `neg_sampling_ratio=1.0`) already added them to the validation split.

In [9]:
# Define validation seed edges:
edge_label_index = val_data["snorna", "to", "disease"].edge_label_index
edge_label = val_data["snorna", "to", "disease"].edge_label
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[20, 15],
    edge_label_index=(("snorna", "to", "disease"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=False,
)

preds = []
ground_truths = []
with torch.no_grad():
    model.eval()
    for sampled_val_data in val_loader:
        sampled_val_data.to(device)
        preds.append(model(sampled_val_data,sampled_val_data.x_dict,sampled_val_data.edge_index_dict))
        ground_truths.append(sampled_val_data["snorna", "to", "disease"].edge_label)
pred_row = torch.cat(preds, dim=0).cpu()
pred = torch.cat(preds, dim=0).cpu().numpy()
ground_row = torch.cat(ground_truths, dim=0).cpu()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
auc_val = roc_auc_score(ground_truth, pred)
acc_val = accuracy_score( ground_truth, pred_row.sigmoid().round() )
val_pre = precision_score(ground_truth,pred_row.sigmoid().round())
val_rec = recall_score(ground_truth,pred_row.sigmoid().round())
val_mcc = matthews_corrcoef(ground_truth,pred_row.sigmoid().round())
val_f1 = f1_score(ground_truth,pred_row.sigmoid().round())

In [10]:
# print validated output
print()
print(f"Validation AUC: {auc_val:.4f}")
print(f"Validation ACC: {acc_val:.4f}")
print(f"Validation Prec: {val_pre:.4f}")
print(f"Validation Recall: {val_rec:.4f}")
print(f"Validation MCC: {val_mcc:.4f}")
print(f"Validation F-1: {val_f1:.4f}")


Validation AUC: 0.9126
Validation ACC: 0.8354
Validation Prec: 0.9365
Validation Recall: 0.7195
Validation MCC: 0.6895
Validation F-1: 0.8138
