In [53]:
import os
# Work from the GraphGym project directory: later cells import
# `graphgym.custom_graphgym` and open a config file through a relative path.
# Set the GRAPE_PI_GRAPHGYM_DIR environment variable to point at a different
# checkout; without it the original author's location is used.
os.chdir(os.environ.get('GRAPE_PI_GRAPHGYM_DIR',
                        '/Users/cgu3/Documents/Grape-Pi/graphgym'))

In [54]:
import logging

import graphgym.custom_graphgym # noqa, register custom modules
import torch
from torch_geometric import seed_everything
import argparse
from torch_geometric.graphgym.config import (
    cfg,
    dump_cfg,
    load_cfg,
    set_out_dir,
    set_run_dir,
)
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.register import train_dict
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device

from graphgym import logger
import shlex

# Argument parser for the run: a required config-file path, an optional
# repeat count, a "mark done" switch, and a trailing free-form list of
# config overrides collected under `opts`.
parser = argparse.ArgumentParser(description='GraphGym')

# One (flags, keyword-options) entry per argument, registered in this order.
_cli_spec = [
    (('--cfg',), {'dest': 'cfg_file', 'type': str, 'required': True,
                  'help': 'The configuration file path.'}),
    (('--repeat',), {'type': int, 'default': 1,
                     'help': 'The number of repeated jobs.'}),
    (('--mark_done',), {'action': 'store_true',
                        'help': 'Mark yaml as done after a job has finished.'}),
    (('opts',), {'default': None, 'nargs': argparse.REMAINDER,
                 'help': 'See graphgym/config.py for remaining options.'}),
]
for _flags, _options in _cli_spec:
    parser.add_argument(*_flags, **_options)



# Emulate a command-line invocation from inside the notebook: the string is
# tokenised the way a shell would and handed to the parser.
cli_line = '--cfg configs/protein/gastric-graphsage-soft-label.yaml'
args = parser.parse_args(shlex.split(cli_line))

# Populate the global `cfg` from the parsed arguments.
load_cfg(cfg, args)

# Runtime setup driven by the loaded config.
torch.set_num_threads(cfg.num_threads)  # cap PyTorch's thread count
logger.set_printing()                   # configure log output

seed_everything(cfg.seed)               # seed the RNGs from the config
auto_select_device()                    # let GraphGym choose the device
# Set up the machine learning pipeline: pick the datamodule and model
# factories registered for the configured GRAPE-Pi variant.
grape_pi_variant = cfg.train.grape_pi
if grape_pi_variant == 'graphsage':
    datamodule = train_dict['graphsage_graphgym_datamodule']()
    model = train_dict['graphsage_create_model']()
elif grape_pi_variant == 'gcnconv':
    datamodule = GraphGymDataModule()
    model = train_dict['gcnconv_create_model']()
else:
    # Fail here with a clear message instead of later with an opaque
    # "'NoneType' object has no attribute 'loaders'".
    raise ValueError(
        f"Unsupported cfg.train.grape_pi value {grape_pi_variant!r}; "
        "expected 'graphsage' or 'gcnconv'."
    )

# Run one batch from the first loader through the model and discard the
# result. NOTE(review): presumably this materialises lazily-initialised
# parameters before they are counted and the checkpoint is loaded - confirm.
# No gradients are needed for this throw-away pass.
data_batch = next(iter(datamodule.loaders[0]))
with torch.no_grad():
    model(data_batch)

# Log the model architecture and the full configuration, then record the
# parameter count on the config so it is logged alongside them.
logging.info(model)
logging.info(cfg)
cfg.params = params_count(model)
logging.info('Num parameters: %s', cfg.params)

# Restore the trained weights from the saved checkpoint.
# `map_location='cpu'` lets a checkpoint written on a GPU machine be read on
# a CPU-only one; `load_state_dict` then copies the values into the model's
# existing parameters, wherever those live.
CKPT_PATH = '/Users/cgu3/Documents/Grape-Pi/saved_results/gastric-graphsage-soft-label/0/ckpt/epoch=99-step=9600.ckpt'
checkpoint = torch.load(CKPT_PATH, map_location='cpu')

# Kept as the cell's final expression so the notebook displays the
# key-matching result returned by `load_state_dict`.
model.load_state_dict(checkpoint['state_dict'])


GraphsageGraphGymModule(
  (model): GNN(
    (encoder): FeatureEncoder()
    (pre_mp): GeneralMultiLayer(
      (Layer_0): GeneralLayer(
        (layer): Linear(
          (model): Linear(1, 10, bias=True)
        )
        (post_layer): Sequential(
          (0): ReLU()
        )
      )
    )
    (mp): GNNStackStage(
      (layer0): GeneralLayer(
        (layer): SAGEConv(
          (model): SAGEConv(10, 10, aggr=mean)
        )
        (post_layer): Sequential(
          (0): ReLU()
        )
      )
    )
    (post_mp): ExampleNodeHead(
      (layer_post_mp): MLP(
        (model): Sequential(
          (0): Linear(
            (model): Linear(10, 1, bias=True)
          )
        )
      )
    )
  )
)
accelerator: cpu
benchmark: False
bn:
  eps: 1e-05
  mom: 0.1
cfg_dest: config.yaml
custom_metrics: []
dataset:
  cache_load: False
  cache_save: False
  dir: ../data/single-soft-label
  edge_dim: 128
  edge_encoder: False
  edge_encoder_bn: True
  edge_encoder_name: Bond
  edge_messa

<All keys matched successfully>

In [58]:
# TODO: find out why every batch from this loader is full-size, and how to
# retrieve the original node IDs (the dataset needs to carry them).
#
# Inference over the second loader: for every batch, print the predicted
# probabilities and the true labels of the test nodes.
model.eval()
with torch.no_grad():
    for batch in datamodule.loaders[1]:
        logits, true = model(batch)

        # Boolean mask selecting the first `batch.batch_size` rows, i.e. the
        # original mini-batch nodes. It is created directly as a bool tensor
        # on the same device as `batch.test_mask` so the `&` below also
        # works when the batch is not on the CPU.
        batch_mask = (
            torch.arange(len(batch.y), device=batch.test_mask.device)
            < batch.batch_size
        )

        # For each batch, only use test nodes among the original mini-batch
        # nodes.
        keep = batch_mask & batch.test_mask
        logits, true = logits[keep], true[keep]

        # Drop the trailing singleton dimension and map logits to
        # probabilities.
        logits = logits.squeeze(-1)
        pre_prob = torch.sigmoid(logits)
        print(pre_prob)
        print(true)

tensor([0.9846, 0.7246, 0.7594, 0.8309, 0.9855, 0.7146, 0.7903, 0.9855, 0.8620,
        0.7229, 0.7216, 0.8104, 0.6917, 0.9387, 0.8324, 0.7320, 0.8754, 0.7130,
        0.7604, 0.7912, 0.8094, 0.8835])
tensor([0.9792, 0.7840, 0.8552, 0.9796, 0.9999, 0.6696, 0.6229, 0.9999, 0.9726,
        0.8648, 0.8807, 0.9713, 0.7343, 0.9677, 0.9408, 0.2259, 0.7554, 0.5107,
        0.6694, 0.5175, 0.7110, 0.9594])
tensor([0.9361, 0.9841, 0.8637, 0.7099, 0.8323, 0.8945, 0.8030, 0.8248, 0.9641,
        0.8639, 0.6874, 0.9839, 0.6978, 0.9855, 0.7829, 0.7458, 0.8173, 0.9855,
        0.6926, 0.8519, 0.6878, 0.7048])
tensor([0.8564, 0.9929, 0.6873, 0.9894, 0.6458, 1.0000, 0.4027, 0.9435, 0.9604,
        0.9215, 0.6193, 0.9929, 0.8005, 0.9986, 0.3540, 0.4232, 0.3837, 1.0000,
        0.2079, 0.4347, 0.5594, 0.8891])
tensor([0.7912, 0.7401, 0.9855, 0.7932, 0.7093, 0.7916, 0.8327, 0.9855, 0.8459,
        0.7236, 0.7218, 0.8254, 0.9843, 0.9855, 0.7089, 0.8336, 0.7201, 0.7917,
        0.8266, 0.8233, 0.8841, 0.75