python == 3.7.4
torch == 1.6.0
torch-cluster == 1.5.7
torch_scatter == 2.0.6
torch-sparse == 0.6.9
torch-geometric == 2.0.0
pandas == 1.2.5
numpy == 1.17.2
conda create -n DAGC python==3.7.4
conda activate DAGC
pip install torch==1.6.0 -f https://download.pytorch.org/whl/cu101/torch-1.6.0%2Bcu101-cp37-cp37m-linux_x86_64.whl
pip install torch-cluster==1.5.7 -f https://data.pyg.org/whl/torch-1.6.0%2Bcu101.html
pip install torch_scatter==2.0.6 -f https://data.pyg.org/whl/torch-1.6.0%2Bcu101.html
pip install torch-sparse==0.6.9 -f https://data.pyg.org/whl/torch-1.6.0%2Bcu101.html
pip install torch-geometric==2.0.0
pip install pandas==1.2.5
pip install numpy==1.17.2
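After installation, a quick sanity check of the environment (independent of DAGC itself) can be run in a Python shell to confirm the pinned versions are importable:

```python
# Sanity check: verify the pinned versions are importable and CUDA is visible.
import torch
import torch_geometric

print(torch.__version__)            # expected: 1.6.0
print(torch.cuda.is_available())    # True if the cu101 build can see a GPU
print(torch_geometric.__version__)  # expected: 2.0.0
```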
A quick start example is given by:
$ python auto_test.py --data_name MUTAG --gpu 0
An example of auto search is as follows:
$ python auto_main.py --data_name MUTAG --gpu 0
or
$ python auto_main.py --data_name COX2 --gpu 0
Name | Default value | Description |
---|---|---|
data_name | MUTAG | name of the dataset for graph classification |
gpu | 0 | GPU device id |
epochs | 100 | number of training epochs used to train each sampled GNN architecture and obtain its reward value during the auto-search |
epochs_test | 100 | number of training epochs when testing the optimal GNN architecture, which is trained from scratch with 20-fold cross-validation |
controller_train_epoch | 200 | number of training epochs of the reinforcement-learning-based controller (the learnable agent) |
search_scale | 100 | number of promising GNN architectures predicted by the trained learnable agent |
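These hyper-parameters are passed as command-line flags. For example, assuming the flag names match the table above, a search run with a larger search scale could be launched as:

$ python auto_main.py --data_name COX2 --gpu 0 --epochs 50 --controller_train_epoch 300 --search_scale 200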
DAGC is designed to be easy to customize: users can freely define their own functional components as long as they follow the custom specification of each component. These specifications are simple and are listed below for each definable component.
First, users can define their own input graph data when using non-benchmark datasets. The data format follows torch-geometric:
```python
import os

import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import Data


class PyGDataset(InMemoryDataset):
    """
    data: list
        every element is a sample containing [feature, torch_edge_index, label]
    dataset_save_path: str
        the saving path for the dataset
    """
    def __init__(self, data, dataset_save_path):
        self.dataset = self.convert_pyg_data(data)
        self.dataset_save_path = dataset_save_path
        if not os.path.exists(self.dataset_save_path):
            os.mkdir(self.dataset_save_path)
        # Remove a stale processed file so that process() is re-run on the new data.
        processed_path = os.path.join(self.dataset_save_path, 'processed', self.processed_file_names[0])
        if os.path.exists(processed_path):
            os.remove(processed_path)
        super(PyGDataset, self).__init__(root=self.dataset_save_path)
        self.data, self.slices = torch.load(self.processed_paths[0])

    def convert_pyg_data(self, data):
        # Convert each [feature, edge_index, label] sample into a torch_geometric Data object.
        converted_data = []
        for d in data:
            feature = d[0]
            torch_edge_index = d[1]
            label = d[2]
            c_d = Data(x=feature, edge_index=torch_edge_index, y=label)
            converted_data.append(c_d)
        return converted_data

    @property
    def processed_file_names(self):
        return ['data.pt']

    def process(self):
        # Collate the list of Data objects into a single object and save it.
        data_list = self.dataset
        data, slices = self.collate(data_list)
        print("Saving processed files...")
        torch.save((data, slices), self.processed_paths[0])
        print('Saving complete!')
```
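As a minimal usage sketch (the toy tensors and the save path below are purely illustrative), such a dataset could be built from hand-made samples as follows:

```python
import torch

# Two hypothetical toy graphs; each sample is [node_features, edge_index, graph_label].
toy_data = [
    [torch.randn(4, 7),                                             # 4 nodes with 7-dim features
     torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]], dtype=torch.long),  # edge_index of shape [2, num_edges]
     torch.tensor([1])],                                            # graph-level label
    [torch.randn(3, 7),
     torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long),
     torch.tensor([0])],
]

dataset = PyGDataset(toy_data, './my_dataset')
print(dataset[0])  # e.g. Data(edge_index=[2, 4], x=[4, 7], y=[1])
```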
Then, users can modify the search space according to their needs. The default search space is as follows:
Search Component | Candidate Operations |
---|---|
GNN Depth | 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 |
Aggregation Function | GCNConv, GATConv, SAGEConv, GINConv, GraphConv, GeneralConv |
Local Pooling | TopKPool, SAGPool, ASAPool, PANPool, HopPool, GCPool, GAPPool, None |
Global Pooling | GlobalMaxPool, GlobalMeanPool, GlobalSumPool |
To modify GNN Depth: modify the 'gnn_layers' key in search_space_config.py.

To add an Aggregation Function:
- modify the 'conv' key in search_space_config.py,
- then define the added aggregation function following the examples in convolution.py,
- and finally register it in the 'conv_map' function in utils.py.

To add a Local Pooling operation:
- modify the 'local_pooling' key in search_space_config.py,
- then define the added local pooling function following the examples in local_pooling.py,
- and finally register it in the 'local_pooling_map' function in utils.py.

To add a Global Pooling operation:
- modify the 'global_pooling' key in search_space_config.py,
- then define the added global pooling function following torch_geometric.nn.pool.glob,
- and finally register it in the 'global_pooling_map' function in utils.py.
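As a rough illustration of how these keys fit together, a hypothetical search-space dictionary might look like the sketch below; the actual layout of search_space_config.py may differ, and the candidate names are taken from the default table above.

```python
# Hypothetical sketch only; the real structure of search_space_config.py may differ.
search_space = {
    'gnn_layers': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'conv': ['GCNConv', 'GATConv', 'SAGEConv', 'GINConv', 'GraphConv', 'GeneralConv'],
    'local_pooling': ['TopKPool', 'SAGPool', 'ASAPool', 'PANPool',
                      'HopPool', 'GCPool', 'GAPPool', 'None'],
    'global_pooling': ['GlobalMaxPool', 'GlobalMeanPool', 'GlobalSumPool'],
}

# Adding a new candidate then amounts to appending its name to the relevant list
# and mapping that name to its implementation in the corresponding *_map function
# in utils.py (conv_map, local_pooling_map, or global_pooling_map).
```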