In [None]:
import XGCN
from XGCN.utils import io
from XGCN.utils.utils import print_dict, ensure_dir, set_random_seed

In [None]:
import os
import os.path as osp

In [None]:
# Root directory that holds both the dataset instances and the model outputs
# (every path built later in this notebook is joined onto it).
# Set the XGCN_DATA_ROOT environment variable to relocate it on another
# machine; when the variable is unset the previous hard-coded location is used,
# so existing setups keep working unchanged.
all_data_root = os.environ.get(
    "XGCN_DATA_ROOT", "/media/xreco/DEV/xiran/code/XGCN_data"
)

In [None]:
# Dataset instance and model used by this run; both names are spliced into
# the dataset and results paths built further down.
dataset_name, model_name = 'facebook', 'GraphSAGE'

In [None]:
# Seed for this run; it also names the per-seed results folder below.
seed = 0

# Dataset instance directory and the two evaluation-set pickles inside it.
dataset_root = osp.join(all_data_root, f'dataset/instance_{dataset_name}')
file_val_set, file_test_set = (
    osp.join(dataset_root, pkl_name)
    for pkl_name in ('val_set.pkl', 'test_set.pkl')
)

# Outputs land in <all_data_root>/model_output/<dataset>/<model>/[seed<N>].
results_dir = f'[seed{seed}]'
results_root = osp.join(
    all_data_root, 'model_output', dataset_name, model_name, results_dir
)

In [None]:
# Run configuration, assembled one group of related keys at a time.
# The finished dict is written to config.yaml and passed to every
# XGCN.build_* call later in the notebook.
config = {}

# Run identity.
config.update({
    "seed": seed,
    "model": model_name,
})

# Input dataset directory and output directory.
config.update({
    "data_root": dataset_root,
    "results_root": results_root,
})

# Training dataset / loader settings.
# NOTE(review): "train_num_layer_sample" is a string-encoded list with two
# entries, matching "num_gcn_layers" = 2 — presumably they must stay in sync;
# confirm against the XGCN docs.
config.update({
    "Dataset_type": "BlockDataset",
    "num_workers": 0,
    "num_gcn_layers": 2,
    "train_num_layer_sample": "[10, 10]",
})

# Positive / negative sample selection.
config.update({
    "NodeListDataset_type": "LinkDataset",
    "pos_sampler": "ObservedEdges_Sampler",
    "neg_sampler": "RandomNeg_Sampler",
    "num_neg": 1,
})

# Training batch generation.
config.update({
    "BatchSampleIndicesGenerator_type": "SampleIndicesWithReplacement",
    "train_batch_size": 2048,
    "epoch_sample_ratio": 0.1,
})

# Validation evaluator and the pickle it reads.
config.update({
    "val_evaluator": "WholeGraph_MultiPos_Evaluator",
    "val_batch_size": 256,
    "file_val_set": file_val_set,
})

# Test evaluator and the pickle it reads.
config.update({
    "test_evaluator": "WholeGraph_MultiPos_Evaluator",
    "test_batch_size": 256,
    "file_test_set": file_test_set,
})

# Training-loop schedule and the metric key used for scoring.
config.update({
    "epochs": 200,
    "val_freq": 1,
    "key_score_metric": "r100",
    "convergence_threshold": 20,
})

# Forward mode and device placement; every device slot is set to "cuda".
config.update({
    "forward_mode": "sample",
    "graph_device": "cuda",
    "emb_table_device": "cuda",
    "gnn_device": "cuda",
    "out_emb_table_device": "cuda",
})

# Embedding table: no pretrained file, not frozen, dense, plus its
# dimension / init std / learning rate.
config.update({
    "from_pretrained": 0,
    "file_pretrained_emb": "",
    "freeze_emb": 0,
    "use_sparse": 0,
    "emb_dim": 64,
    "emb_init_std": 0.1,
    "emb_lr": 0.01,
})

# GNN layers.
# NOTE(review): "gnn_arch" is a string that mentions torch.tanh, and torch is
# not imported in this notebook, so it is presumably parsed inside XGCN —
# confirm before changing its format.
config.update({
    "gnn_arch": '[{"in_feats": 64, "out_feats": 64, "aggregator_type": "pool", "activation": torch.tanh}, {"in_feats": 64, "out_feats": 64, "aggregator_type": "pool"}]',
    "gnn_lr": 0.01,
})

# Loss, regularisation weight, and the (empty) inference-time layer sampling list.
config.update({
    "loss_type": "bpr",
    "L2_reg_weight": 0.0,
    "infer_num_layer_sample": "[]",
})

In [None]:
# Take the output directory from the config so the saved YAML and the
# directory actually used cannot drift apart.
results_root = config['results_root']

# Seed before anything is built, then make sure the output directory exists
# and store the configuration used for this run next to the results.
set_random_seed(config['seed'])
ensure_dir(results_root)
io.save_yaml(osp.join(results_root, 'config.yaml'), config)

In [None]:
# One dict instance is handed to every builder below.
# NOTE(review): presumably the builders populate and read it, which would make
# the call order significant — the order is kept as is; confirm against XGCN.
data = dict()

model = XGCN.build_Model(config, data)
train_dl = XGCN.build_DataLoader(config, data)

# Both evaluators are bound to the same model instance.
val_evaluator = XGCN.build_val_Evaluator(config, data, model)
test_evaluator = XGCN.build_test_Evaluator(config, data, model)

# The trainer receives the model, the training loader and both evaluators.
trainer = XGCN.build_Trainer(
    config, data, model, train_dl, val_evaluator, test_evaluator
)

In [None]:
# Launch the run through the trainer assembled in the previous cell.
# NOTE(review): per its name this trains and then tests; the exact schedule
# (validation frequency, early stopping) is presumably driven by the config —
# confirm in the XGCN Trainer documentation.
trainer.train_and_test()