In [77]:
import os

import pandas as pd

os.chdir('/Users/cgu3/Documents/Grape-Pi/graphgym')

In [62]:
import logging

import graphgym.custom_graphgym # noqa, register custom modules
import torch
from torch_geometric import seed_everything
import argparse
from torch_geometric.graphgym.config import (
    cfg,
    dump_cfg,
    load_cfg,
    set_out_dir,
    set_run_dir,
)
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.register import train_dict
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device

from graphgym import logger
import shlex

parser = argparse.ArgumentParser(description='GraphGym')
parser.add_argument('--cfg', dest='cfg_file', type=str, required=True,
                    help='The configuration file path.')
parser.add_argument('--repeat', type=int, default=1,
                    help='The number of repeated jobs.')
parser.add_argument('--mark_done', action='store_true',
                    help='Mark yaml as done after a job has finished.')
parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                    help='See graphgym/config.py for remaining options.')



# Load cmd line args
args = parser.parse_args(shlex.split('--cfg configs/protein/gastric-graphsage-soft-label.yaml'))
# Load config file
load_cfg(cfg, args)
# Set Pytorch environment
torch.set_num_threads(cfg.num_threads)
# Repeat for different random seeds
logger.set_printing()

seed_everything(cfg.seed)
auto_select_device()
# Set machine learning pipeline

model, datamodule = None, None
# use the right customized datamodule and graphgymmodule
if cfg.train.grape_pi == 'graphsage':
    datamodule = train_dict['graphsage_graphgym_datamodule']()
    model = train_dict['graphsage_create_model']()
    # train = train_dict['graphsage_train']
elif cfg.train.grape_pi == 'gcnconv':
    datamodule = GraphGymDataModule()
    model = train_dict['gcnconv_create_model']()
    # train = train_dict['gcnconv_train']

data_batch = next(iter(datamodule.loaders[0]))
mapping = data_batch.mapping
model(data_batch)

# Print model info
logging.info(model)
logging.info(cfg)
cfg.params = params_count(model)
logging.info('Num parameters: %s', cfg.params)

model.load_state_dict(torch.load('/Volumes/cgu3/graph-neural-network/saved_results/gastric-graphsage-soft-label-with-mRNA/0/ckpt/epoch=99-step=9600.ckpt')['state_dict'])


GraphsageGraphGymModule(
  (model): GNN(
    (encoder): FeatureEncoder()
    (pre_mp): GeneralMultiLayer(
      (Layer_0): GeneralLayer(
        (layer): Linear(
          (model): Linear(2, 10, bias=True)
        )
        (post_layer): Sequential(
          (0): ReLU()
        )
      )
    )
    (mp): GNNStackStage(
      (layer0): GeneralLayer(
        (layer): SAGEConv(
          (model): SAGEConv(10, 10, aggr=mean)
        )
        (post_layer): Sequential(
          (0): ReLU()
        )
      )
    )
    (post_mp): ExampleNodeHead(
      (layer_post_mp): MLP(
        (model): Sequential(
          (0): Linear(
            (model): Linear(10, 1, bias=True)
          )
        )
      )
    )
  )
)
accelerator: cpu
benchmark: False
bn:
  eps: 1e-05
  mom: 0.1
cfg_dest: config.yaml
custom_metrics: []
dataset:
  cache_load: False
  cache_save: False
  dir: ../data/single-soft-label
  edge_dim: 128
  edge_encoder: False
  edge_encoder_bn: True
  edge_encoder_name: Bond
  edge_messa

<All keys matched successfully>

In [70]:
# get the dictionary mapping from global node index to original protein accession
mapping = data_batch.mapping
reversed_mapping = {v: k for k, v in mapping.items()}
reversed_mapping

{0: 'A0A024RBG1',
 1: 'A0A075B6H7',
 2: 'A0A075B6H8',
 3: 'A0A075B6L6',
 4: 'A0A075B6N1',
 5: 'A0A075B6N2',
 6: 'A0A075B6N3',
 7: 'A0A075B6R0',
 8: 'A0A075B6S4',
 9: 'A0A075B759',
 10: 'A0A087WSY6',
 11: 'A0A087WUL8',
 12: 'A0A087WUV0',
 13: 'A0A087X0M5',
 14: 'A0A087X1C5',
 15: 'A0A096LP49',
 16: 'A0A096LP55',
 17: 'A0A096LPI5',
 18: 'A0A0A0MS06',
 19: 'A0A0A0MS15',
 20: 'A0A0A0MT36',
 21: 'A0A0A6YYL3',
 22: 'A0A0B4J1V1',
 23: 'A0A0B4J1V2',
 24: 'A0A0B4J1X5',
 25: 'A0A0B4J1X8',
 26: 'A0A0B4J268',
 27: 'A0A0B4J2D5',
 28: 'A0A0C4DH24',
 29: 'A0A0C4DH31',
 30: 'A0A0C4DH32',
 31: 'A0A0C4DH33',
 32: 'A0A0C4DH38',
 33: 'A0A0C4DH42',
 34: 'A0A0C4DH55',
 35: 'A0A0G2JMD5',
 36: 'A0A0G2JMI3',
 37: 'A0A0J9YVY3',
 38: 'A0A0J9YWL9',
 39: 'A0A0J9YX94',
 40: 'A0A0J9YXV3',
 41: 'A0A0J9YXX1',
 42: 'A0A0U1RQS6',
 43: 'A0A0U1RRI6',
 44: 'A0A1B0GTU1',
 45: 'A0A1B0GTY4',
 46: 'A0A1B0GU33',
 47: 'A0A1B0GUA5',
 48: 'A0A1B0GUS4',
 49: 'A0A1B0GUU1',
 50: 'A0A1B0GUV1',
 51: 'A0A1B0GUW6',
 52: 'A0A1B0GV03',
 53

In [65]:
datamodule.loaders[1]

NeighborLoader()

In [100]:
# check why it is full-size batch and need to have original ID in the dataset
# how to retrieve original ID
model.eval()
accession = []
all_pred_prob = []
all_soft_label = []
all_raw_prob = []

with torch.no_grad():
    for batch in datamodule.loaders[1]:
        # batch = batch.to(cfg.accelerator)
        
        # for each batch, only use test nodes in the original mini-batch nodes
        batch_mask = torch.cat([torch.ones(batch.batch_size), torch.zeros(len(batch.y) - batch.batch_size)], dim=0)
        batch_mask = batch_mask.bool()
        mask = batch_mask & batch.test_mask
        raw_prob = batch.x[:, 0][mask]
        
        logits, true = model(batch)

        # for each batch, only use test nodes in the original mini-batch nodes
        global_node_idx = batch.n_id[mask]
        original_id = [reversed_mapping[key] for key in global_node_idx.tolist()]
        
        logits, true = logits[mask], true[mask]
        logits.squeeze_(-1)
        pred_prob = torch.nn.functional.sigmoid(logits)
        
        
        
        accession = accession + original_id
        all_raw_prob += raw_prob.tolist()
        all_pred_prob += pred_prob.tolist()
        all_soft_label += true.tolist()
        # print(original_id)
        # print(pred_prob)
        # print(true)

In [101]:
pd.DataFrame({'accession': accession, 'raw_prob': all_raw_prob, 'pred_prob': all_pred_prob, 'soft_label': all_soft_label})

Unnamed: 0,accession,raw_prob,pred_prob,soft_label
0,A0A087X0M5,0.976280,0.983320,0.979157
1,A0A087X1C5,0.243009,0.675054,0.784043
2,A0A0A6YYL3,0.341797,0.697359,0.855166
3,A0A0B4J1V2,0.572661,0.776625,0.979623
4,A0A0B4J1X8,0.999882,0.983741,0.999882
...,...,...,...,...
4080,Q9BXW3,0.000000,0.729058,0.982459
4081,Q9HAA7,0.000000,0.570404,0.000000
4082,Q9UF83,0.000000,0.570404,0.184923
4083,Q9Y6Z2,0.000000,0.583879,0.000000
