In [2]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import logging
import os
import custom_graphgym  # noqa, register custom modules
import torch
from pytorch_lightning import seed_everything
from torch_geometric.graphgym.cmd_args import parse_args
from torch_geometric.graphgym.config import (
    cfg,
    dump_cfg,
    load_cfg,
    set_out_dir,
    set_run_dir,
)
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device
import pandas as pd
import numpy as np
from torch_geometric.graphgym import register

This notebook covers dataset creation, config loading, and the training procedure.

In [4]:
# Parse the input arguments for the script

import argparse

class NotebookArgParser:
    """Parse GraphGym-style command-line arguments from inside a notebook.

    The parser recognises ``--cfg`` (required), ``--repeat``, ``--mark_done``
    and a trailing ``opts`` remainder. The parsed namespace is stored on
    ``self.args``.

    Args:
        args_str: The arguments to parse, either as a list of tokens
            (e.g. ``['--cfg', 'a.yaml']``) or as a single command-line
            string (e.g. ``'--cfg a.yaml --repeat 2'``).
    """

    def __init__(self, args_str):
        self.args = self.parse_args(args_str)

    def parse_args(self, args_str):
        """Build the argument parser and parse ``args_str``.

        Args:
            args_str: A list of argument tokens or a single string. A string
                is tokenised with shell-like rules (quotes are honoured), so
                a quoted path containing spaces stays one token. Any other
                value is handed to ``argparse`` unchanged.

        Returns:
            argparse.Namespace: Namespace with the attributes ``cfg_file``,
            ``repeat``, ``mark_done`` and ``opts``.

        Raises:
            SystemExit: Raised by ``argparse`` when the arguments are invalid
                (for example when ``--cfg`` is missing).
        """
        # Local import keeps this class self-contained within its cell.
        import shlex

        parser = argparse.ArgumentParser(description='GraphGym')

        # Add command-line arguments
        parser.add_argument('--cfg',
                            dest='cfg_file',
                            type=str,
                            required=True,
                            help='The configuration file path.')
        parser.add_argument('--repeat',
                            type=int,
                            default=1,
                            help='The number of repeated jobs.')
        parser.add_argument('--mark_done',
                            action='store_true',
                            help='Mark yaml as done after a job has finished.')
        parser.add_argument('opts',
                            default=None,
                            nargs=argparse.REMAINDER,
                            help='See graphgym/config.py for remaining options.')

        # argparse iterates over whatever it is given; a bare string would be
        # consumed one character at a time, so split it into tokens first.
        if isinstance(args_str, str):
            args_str = shlex.split(args_str)

        # Parse the command-line arguments
        return parser.parse_args(args_str)

In [5]:
# Base name (without the .yaml extension) of the config to run; the next cell
# interpolates it into ./configs/pyg/<config_name>.yaml.
config_name = 'yeast_static'

In [6]:
# Reproduce the command line that would be used to launch main_pyg.py.
# The first two tokens ("python", "main_pyg.py") are dropped so that only
# the actual arguments are handed to the parser.
command = f"python main_pyg.py --cfg ./configs/pyg/{config_name}.yaml --repeat 1"
args_str = command.split()[2:]

# Parse the emulated arguments and keep the resulting namespace
notebook_parser = NotebookArgParser(args_str)
args = notebook_parser.args

# Echo every parsed value, one per line
print(
    "Parsed Arguments:",
    f"Configuration File: {args.cfg_file}",
    f"Repeat: {args.repeat}",
    f"Mark Done: {args.mark_done}",
    f"Remaining Options: {args.opts}",
    sep="\n",
)

Parsed Arguments:
Configuration File: ./configs/pyg/yeast_static.yaml
Repeat: 1
Mark Done: False
Remaining Options: []


In [7]:
# NOTE(review): this re-parses the same `args_str` as the previous cell and
# rebinds `notebook_parser` / `args` — looks redundant; confirm before removing.
notebook_parser = NotebookArgParser(args_str)

# Access parsed arguments
args = notebook_parser.args

In [None]:
# Populate the global `cfg` from the parsed arguments and set the output
# directory from the config file path
load_cfg(cfg, args)
set_out_dir(cfg.out_dir, args.cfg_file)

# PyTorch environment: set the CPU thread count, then dump the config
torch.set_num_threads(cfg.num_threads)
dump_cfg(cfg)

# One full training run per repetition; the seed is bumped before each run
for i in range(args.repeat):
    set_run_dir(cfg.out_dir, i)
    set_printing()

    # Per-run configuration
    cfg.seed += 1
    seed_everything(cfg.seed, workers=True)
    # If not set in the yaml config, selects the cuda accelerator when
    # available (single device)
    auto_select_device()

    # Choose the registered datamodule that matches the configured dataset
    if cfg.dataset.name == 'yeast-ppi':
        print('Loading BIOGRID')
        datamodule_key = "BioGridGraphGymDataModule"
    else:
        print('Loading grn-ecoli')
        datamodule_key = "CustomGraphGymDataModule"
    datamodule = register.train_dict[datamodule_key](split_type=cfg.dataset.split_type)

    # Shared settings are overridden after the datamodule is created and
    # before the model is built
    cfg.share.dim_out = 1
    cfg.share.num_splits = 3
    model = create_model()

    # Log the model, the config and the parameter count
    logging.info(model)
    logging.info(cfg)
    cfg.params = params_count(model)
    logging.info('Num parameters: %s', cfg.params)

    # Hand over to the custom training function registered as "train_pl"
    train_fn = register.train_dict["train_pl"]
    train_fn(model, datamodule, logger=True)