Training Playground for ModSpy with Ray + PyTorch Lightning ⚡

In [1]:
import GPUtil

# Get the first available GPU
gpu = GPUtil.getGPUs()[0]

print(f"Total GPU memory: {gpu.memoryTotal / 1024} GB")
print(f"Free GPU memory: {gpu.memoryFree / 1024} GB")
print(f"Used GPU memory: {gpu.memoryUsed / 1024} GB")

Total GPU memory: 12.0 GB
Free GPU memory: 11.9072265625 GB
Used GPU memory: 0.0 GB


In [2]:
import sys
sys.path.insert(0, '/home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/src/modspy_data')

In [3]:
from comet_ml import Experiment


import os
import time

from typing import List, Optional, Tuple, Union, cast
from traitlets import default
from rich import print

import numpy as np
import pandas as pd
import torch
import torchmetrics
from torch_geometric.transforms import AddSelfLoops
from torch_geometric.typing import EdgeType
from torch_geometric.data import HeteroData

import optuna
from pytorch_lightning.loggers import CometLogger

# import wandb
import faiss
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import MetaPath2Vec
import torch_geometric.utils as utils

from torch.nn.functional import cosine_similarity
from sklearn.metrics import accuracy_score

from ray import tune
from ray.tune.schedulers import ASHAScheduler


import ray.train.lightning
from ray.air.integrations.comet import CometLoggerCallback
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

import networkx as nx
import plotly.express as px

from models.embed import MetaPath2VecLightningModule



In [4]:
device = 'cpu'

In [5]:
# Top models
# Consolidated meptapaths: /home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/06_models/modspy-experiments/ray-reults/train_m2vec_2024-04-18_17-24-55/train_m2vec_19f34_00000_0_batch_size=512,embedding_dim=128,lr=0.0029,num_negative_samples=5,walk_length=21,walks_per_node=12_2024-04-18_17-25-05/checkpoint_000000
# entire elk: /home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/06_models/modspy-experiments/ray-reults/train_m2vec_2024-04-13_02-43-09/train_m2vec_178a4_00000_0_batch_size=256,embedding_dim=128,lr=0.0248,metapath=biolink_Gene_biolink_interacts_with_biolink_Gene_bio_2024-04-13_02-43-23/checkpoint_000000
# "/home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/06_models/modspy-experiments/ray-reults/train_m2vec_2024-04-18_17-24-55/train_m2vec_19f34_00000_0_batch_size=512,embedding_dim=128,lr=0.0029,num_negative_samples=5,walk_length=21,walks_per_node=12_2024-04-18_17-25-05/checkpoint_000000/checkpoint"
# precious_carp_5539: /home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/06_models/modspy-experiments/ray-reults/train_m2vec_2024-04-03_01-44-54/train_m2vec_4bf4d_00001_1_batch_size=128,embedding_dim=128,lr=0.0050,walk_length=12_2024-04-03_01-44-55/checkpoint_000000/checkpoint

chkpt_filepath = "/home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/06_models/modspy-experiments/ray-reults/train_m2vec_2024-04-14_17-21-49/train_m2vec_016cc_00000_0_batch_size=32,embedding_dim=128,lr=0.0104,metapath=biolink_Gene_biolink_interacts_with_biolink_Gene_biol_2024-04-14_17-22-01/checkpoint_000000/checkpoint"
# chkpt_filepath = "/home/rahit/projects/def-mtarailo/rahit/from_scratch/modspy-data/data/06_models/modspy-experiments/ray-reults/train_m2vec_2024-04-03_01-44-54/train_m2vec_4bf4d_00001_1_batch_size=128,embedding_dim=128,lr=0.0050,walk_length=12_2024-04-03_01-44-55/checkpoint_000000/checkpoint"

In [6]:
model = MetaPath2VecLightningModule.load_from_checkpoint(chkpt_filepath).to(device)

Computing on cpu
True
1
Tesla P100-PCIE-12GB
 
        Metapath: [('biolink:Gene', 'biolink:interacts_with', 'biolink:Gene'), ('biolink:Gene', 'biolink:orthologous_to', 'biolink:Gene'), ('biolink:Gene', 'biolink:interacts_with', 'biolink:Gene')]
        Total nodes: 862115
        Total node types: 88

        Total edges: 11412471
        Total edge types: 289                
        


In [7]:
model.device

device(type='cpu')

In [8]:
model.eval()

MetaPath2VecLightningModule(
  (model): MetaPath2Vec(559272, 128)
  (val_precision): BinaryPrecision()
  (val_recall): BinaryRecall()
)

In [32]:
graph = torch.load(model.hparams['network_filepath']).to(device)

In [33]:
df = pd.read_csv('./data/02_intermediate/2024-03-31-merged-dataset.tsv', sep="\t")

In [34]:
dataset = torch.load('./data/05_model_input/2024-03-31-merged-dataset.pt', map_location=device).detach().numpy()

In [35]:
metapath = model.hparams.metapath

In [36]:
if metapath is None:
    metapath = model.metapath
    print(metapath)

In [37]:
types = set([x[0] for x in metapath]) | set([x[-1] for x in metapath])
types = sorted(list(types))
print(types)

In [38]:
torch.tensor(dataset[:,:2]).shape

torch.Size([7958, 2])

In [39]:
dataset_gene_idx = np.unique(dataset.flatten())

In [40]:
graph[('biolink:Gene', 'IS_MODIFIER', 'biolink:Gene')].edge_index = torch.tensor(dataset[:,:2]).t().contiguous()

# rev_edge_index = data.edge_index_dict[edge_type].flip([0])
# graph[(edge_type[2], edge_type[1], edge_type[0])].edge_index = rev_edge_index

In [50]:
# We can choose to take dataset-only subgraph for better resolution in training.
# However, we are trying all gene subgraph first
subgraph = graph.node_type_subgraph('biolink:Gene')

In [51]:
assert subgraph['biolink:Gene'].num_nodes ==  subgraph['biolink:Gene'].x.shape[0]

In [52]:
subgraph[("biolink:Gene", 'IS_MODIFIER', "biolink:Gene")].y = torch.tensor(dataset[:,2]).t().contiguous().to(torch.long)

In [53]:
display(subgraph[("biolink:Gene", 'IS_MODIFIER', "biolink:Gene")])
print(subgraph[("biolink:Gene", 'IS_MODIFIER', "biolink:Gene")].num_edges)

{'edge_index': tensor([[516790, 516790, 516790,  ...,  86692,  86692,  86692],
        [518908, 526845, 529553,  ...,  93262,  86304,  89182]],
       dtype=torch.int32), 'y': tensor([1, 1, 1,  ..., 0, 0, 0])}

In [54]:
subgraph.validate()

True

In [55]:
subgraph.edge_types

[('biolink:Gene', 'biolink:orthologous_to', 'biolink:Gene'),
 ('biolink:Gene', 'biolink:interacts_with', 'biolink:Gene'),
 ('biolink:Gene', 'IS_MODIFIER', 'biolink:Gene')]

In [56]:
# Assuming `data` is your HeteroData object
# node_types_to_remove = [node_type for node_type, node_store in subgraph.node_stores.items() if node_store.num_nodes == 0]

for idx in range(len(subgraph.node_stores)):
    if 'num_nodes' not in subgraph.node_stores[idx]:
        print(subgraph.node_stores[idx])
        del subgraph.node_stores[idx]

In [57]:
subgraph.node_stores

[{'num_nodes': 559272, 'x': tensor([[1.],
         [1.],
         [1.],
         ...,
         [1.],
         [1.],
         [1.]])}]

In [58]:
# For this, we first split the set of edges into
# training (80%), validation (10%), and testing edges (10%).
# Across the training edges, we use 70% of edges for message passing,
# and 30% of edges for supervision.
# We further want to generate fixed negative edges for evaluation with a ratio of 2:1.
# Negative edges during training will be generated on-the-fly, so we don't want to
# add them to the graph right away.
# Overall, we can leverage the `RandomLinkSplit()` transform for this from PyG:
from torch_geometric.transforms import RandomLinkSplit

transform = RandomLinkSplit(
    num_val=0.20,  # TODO
    num_test=0.05,  # TODO
    # disjoint_train_ratio=...,  # TODO
    # neg_sampling_ratio=...,  # TODO
    # add_negative_train_samples=...,  # TODO
    key='y',
    edge_types=(("biolink:Gene", 'IS_MODIFIER', "biolink:Gene")),
    # rev_edge_types=("movie", "rev_rates", "user"),
)

train_data, val_data, test_data = transform(subgraph)
print("==============")
print("Before Split:")
print("==============")
print(subgraph)
print("==============")
print("Training data:")
print("==============")
print(train_data)
print("================")
print("Validation data:")
print("================")
print(val_data)
print("================")
print("Test data:")
print("================")
print(test_data)

# assert train_data["user", "rates", "movie"].num_edges == 56469
# assert train_data["user", "rates", "movie"].edge_label_index.size(1) == 24201
# assert train_data["movie", "rev_rates", "user"].num_edges == 56469
# # No negative edges added:
# assert train_data["user", "rates", "movie"].edge_label.min() == 1
# assert train_data["user", "rates", "movie"].edge_label.max() == 1

# assert val_data["user", "rates", "movie"].num_edges == 80670
# assert val_data["user", "rates", "movie"].edge_label_index.size(1) == 30249
# assert val_data["movie", "rev_rates", "user"].num_edges == 80670
# # Negative edges with ratio 2:1:
# assert val_data["user", "rates", "movie"].edge_label.long().bincount().tolist() == [20166, 10083]