In [None]:
import os

# Move into the GraphGym project directory so that the relative paths used
# later in the notebook (e.g. the 'configs/...' YAML path) resolve.
# The location can be overridden through the GRAPE_PI_GRAPHGYM_DIR
# environment variable; when it is unset, the original author's local
# checkout path is used, so existing behaviour is unchanged.
os.chdir(os.environ.get('GRAPE_PI_GRAPHGYM_DIR',
                        '/Users/cgu3/Documents/Grape-Pi/graphgym'))

In [48]:
import logging

import graphgym.custom_graphgym # noqa, register custom modules
import torch
from torch_geometric import seed_everything
import argparse
from torch_geometric.graphgym.config import (
    cfg,
    dump_cfg,
    load_cfg,
    set_out_dir,
    set_run_dir,
)
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.register import train_dict
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device

from graphgym import logger
import shlex

# Build the GraphGym command-line interface from a small declarative table:
# a required config path, two optional options, and a catch-all positional
# that collects whatever is left on the command line.
parser = argparse.ArgumentParser(description='GraphGym')
_cli_arguments = (
    (('--cfg',), dict(dest='cfg_file', type=str, required=True,
                      help='The configuration file path.')),
    (('--repeat',), dict(type=int, default=1,
                         help='The number of repeated jobs.')),
    (('--mark_done',), dict(action='store_true',
                            help='Mark yaml as done after a job has finished.')),
    # argparse.REMAINDER gathers all remaining tokens into `opts`.
    (('opts',), dict(default=None, nargs=argparse.REMAINDER,
                     help='See graphgym/config.py for remaining options.')),
)
for _flags, _options in _cli_arguments:
    parser.add_argument(*_flags, **_options)



# Emulate a command-line invocation from inside the notebook: the argument
# string is tokenised with shlex and passed to the parser explicitly.
# The config path is relative, so it is resolved against the current working
# directory.
args = parser.parse_args(shlex.split('--cfg configs/protein/gastric-graphsage-soft-label.yaml'))
# Load the config file named in `args` into the global GraphGym `cfg`.
load_cfg(cfg, args)
# Set the PyTorch thread count from the loaded configuration.
torch.set_num_threads(cfg.num_threads)
# Set up log printing through the project's `logger` module.
# NOTE(review): an earlier comment here read "Repeat for different random
# seeds", but this cell contains no loop over seeds; a single run seeded with
# cfg.seed is performed.
logger.set_printing()

# Seed the random number generators with the configured seed.
seed_everything(cfg.seed)
# Presumably resolves which accelerator/device the run uses — TODO confirm
# against torch_geometric.graphgym.utils.device.
auto_select_device()
# Set machine learning pipeline

# Pick the datamodule/model pair matching the configured GNN variant.
# The customised factories are looked up in GraphGym's `train_dict` registry
# (presumably filled by the `graphgym.custom_graphgym` import — TODO confirm).
if cfg.train.grape_pi == 'graphsage':
    datamodule = train_dict['graphsage_graphgym_datamodule']()
    model = train_dict['graphsage_create_model']()
elif cfg.train.grape_pi == 'gcnconv':
    datamodule = GraphGymDataModule()
    model = train_dict['gcnconv_create_model']()
else:
    # Fail right here with an explicit message; previously an unsupported
    # value left `datamodule` as None and surfaced later as an opaque
    # AttributeError on `datamodule.loaders`.
    raise ValueError(
        f"Unsupported cfg.train.grape_pi value {cfg.train.grape_pi!r}; "
        "expected 'graphsage' or 'gcnconv'.")

# Run one batch from the first loader through the model before the parameters
# are counted and the checkpoint is loaded (presumably to materialise
# lazily-initialised parameters — TODO confirm). The output is discarded, so
# gradient tracking is switched off to avoid building an autograd graph.
data_batch = next(iter(datamodule.loaders[0]))
with torch.no_grad():
    model(data_batch)

# Log the model structure, the full configuration and the parameter count.
logging.info(model)
logging.info(cfg)
cfg.params = params_count(model)
logging.info('Num parameters: %s', cfg.params)

# Restore the trained weights from the checkpoint file; its 'state_dict'
# entry holds the parameters.
# map_location='cpu' makes the tensors deserialise onto the CPU regardless of
# the device they were saved from, so a checkpoint written on a GPU machine
# also loads on a CPU-only one; load_state_dict then copies the values into
# the model's existing parameters.
# NOTE: torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load(
    '/Users/cgu3/Documents/Grape-Pi/saved_results/gastric-graphsage-soft-label/0/ckpt/epoch=99-step=9600.ckpt',
    map_location='cpu',
)
# Kept as the last expression so the notebook displays the key-matching result.
model.load_state_dict(checkpoint['state_dict'])


GraphsageGraphGymModule(
  (model): GNN(
    (encoder): FeatureEncoder()
    (pre_mp): GeneralMultiLayer(
      (Layer_0): GeneralLayer(
        (layer): Linear(
          (model): Linear(1, 10, bias=True)
        )
        (post_layer): Sequential(
          (0): ReLU()
        )
      )
    )
    (mp): GNNStackStage(
      (layer0): GeneralLayer(
        (layer): SAGEConv(
          (model): SAGEConv(10, 10, aggr=mean)
        )
        (post_layer): Sequential(
          (0): ReLU()
        )
      )
    )
    (post_mp): ExampleNodeHead(
      (layer_post_mp): MLP(
        (model): Sequential(
          (0): Linear(
            (model): Linear(10, 1, bias=True)
          )
        )
      )
    )
  )
)
accelerator: cpu
benchmark: False
bn:
  eps: 1e-05
  mom: 0.1
cfg_dest: config.yaml
custom_metrics: []
dataset:
  cache_load: False
  cache_save: False
  dir: ../data/single-soft-label
  edge_dim: 128
  edge_encoder: False
  edge_encoder_bn: True
  edge_encoder_name: Bond
  edge_messa

<All keys matched successfully>

In [51]:
# TODO: check why each batch is full-size; also, the dataset needs to keep the original IDs

# Inference over the second loader: switch the model to eval mode, disable
# gradient tracking, and for every batch print the sigmoid of the model's
# prediction followed by the labels stored in batch.y.
model.eval()
with torch.no_grad():
    for batch in datamodule.loaders[1]:
        # The model returns a pair; only the first element is used below.
        logits, true = model(batch)
        # Drop the trailing dimension, then map the scores into (0, 1).
        logits = logits.squeeze(-1)
        pre_prob = logits.sigmoid()
        print(pre_prob, batch.y, sep='\n')

tensor([0.9494, 0.8825, 0.8027,  ..., 0.7440, 0.9760, 0.7440])
tensor([0.6953, 0.3215, 0.0851,  ..., 0.1640, 0.3708, 0.2497])
tensor([0.9815, 0.8622, 0.8326,  ..., 0.7974, 0.8462, 0.7715])
tensor([0.9398, 0.5928, 0.5864,  ..., 0.3749, 0.6537, 0.2740])
tensor([0.8184, 0.7451, 0.7808,  ..., 0.9853, 0.9853, 0.7440])
tensor([0.1593, 0.6712, 0.2991,  ..., 1.0000, 0.9998, 0.2935])
tensor([0.9826, 0.6864, 0.6538,  ..., 0.7440, 0.7879, 0.9615])
tensor([0.9972, 0.0936, 0.2152,  ..., 0.1211, 0.4463, 0.8313])
tensor([0.7677, 0.7717, 0.8024,  ..., 0.7440, 0.7440, 0.9853])
tensor([0.5043, 0.6033, 0.5178,  ..., 0.1217, 0.0000, 0.8158])
tensor([0.7543, 0.9233, 0.8382,  ..., 0.8431, 0.7440, 0.8529])
tensor([0.7218, 0.9518, 0.8988,  ..., 0.7838, 0.0000, 0.5119])
tensor([0.8397, 0.8017, 0.8079,  ..., 0.7440, 0.7440, 0.7440])
tensor([0.8337, 0.8726, 0.4443,  ..., 0.1779, 0.2358, 0.0289])
tensor([0.8064, 0.8641, 0.7513,  ..., 0.7440, 0.7924, 0.7440])
tensor([0.4626, 0.5632, 0.8809,  ..., 0.2358, 0.2422, 0