In [1]:
import XGCN
from XGCN.utils import io
from XGCN.utils.utils import print_dict, ensure_dir, set_random_seed

In [2]:
import os
import os.path as osp

In [3]:
# Root directory that holds both the datasets and the model outputs.
# Set the XGCN_DATA_ROOT environment variable to point the notebook at a
# different machine's data; when it is unset or empty, the original
# location is used, so existing runs behave exactly as before.
all_data_root = (
    os.environ.get("XGCN_DATA_ROOT")
    or "/media/xreco/DEV/xiran/code/XGCN_data"
)

In [4]:
# Which dataset instance to train on and which model to build.
dataset_name, model_name = 'facebook', 'GraphSAGE'

In [5]:
seed = 0

# Dataset instance directory and the evaluation split files inside it.
dataset_root = osp.join(all_data_root, f'dataset/instance_{dataset_name}')
file_val_set, file_test_set = (
    osp.join(dataset_root, split_file)
    for split_file in ('val_set.pkl', 'test_set.pkl')
)

# Output directory for this run: <root>/model_output/<dataset>/<model>/<tag>.
results_dir = f'[seed{seed}][cluster8-2]'
results_root = osp.join(
    all_data_root,
    'model_output',
    dataset_name,
    model_name,
    results_dir,
)

In [6]:
# Full run configuration handed to the XGCN builders and saved as YAML.
# Key order is kept stable so the saved file stays comparable across runs.
config = {
    # --- run identity ---
    'seed': seed,
    'model': model_name,

    # --- input / output locations ---
    'data_root': dataset_root,
    'results_root': results_root,

    # --- graph partitioning ---
    'partition_cache_filepath': osp.join(results_root, 'partition_result.pkl'),
    'num_parts': 8,
    'group_size': 2,

    # --- training data loading ---
    # "Dataset_type": "BlockDataset",
    'num_workers': 0,
    'num_gcn_layers': 2,
    'train_num_layer_sample': '[]',

    # Options left disabled in this run:
    # "NodeListDataset_type": "LinkDataset",
    # "pos_sampler": "ObservedEdges_Sampler",
    # "neg_sampler": "RandomNeg_Sampler",
    # "num_neg": 1,

    # "BatchSampleIndicesGenerator_type": "SampleIndicesWithReplacement",
    'train_batch_size': 2048,
    # "train_edge_sample_ratio": 0.1,

    # --- validation ---
    'val_evaluator': 'WholeGraph_MultiPos_Evaluator',
    'val_batch_size': 256,
    'file_val_set': file_val_set,

    # --- test ---
    'test_evaluator': 'WholeGraph_MultiPos_Evaluator',
    'test_batch_size': 256,
    'file_test_set': file_test_set,

    # --- training schedule ---
    'epochs': 200,
    'val_freq': 1,
    'key_score_metric': 'r100',
    'convergence_threshold': 20,

    # --- device placement ---
    'forward_mode': 'sub_graph',
    'graph_device': 'cpu',
    'subgraph_device': 'cuda',
    'emb_table_device': 'cuda',
    'gnn_device': 'cuda',
    'out_emb_table_device': 'cuda',

    # --- embedding table ---
    'from_pretrained': 0,
    'file_pretrained_emb': '',

    'freeze_emb': 0,
    'use_sparse': 0,

    'emb_dim': 64,
    'emb_init_std': 0.1,
    'emb_lr': 0.01,

    # --- GNN layers ---
    # Stored as a string (note the bare `torch.tanh` inside it); presumably
    # evaluated by XGCN where `torch` is in scope — TODO confirm.
    'gnn_arch': (
        '[{"in_feats": 64, "out_feats": 64, "aggregator_type": "pool", "activation": torch.tanh}, '
        '{"in_feats": 64, "out_feats": 64, "aggregator_type": "pool"}]'
    ),
    'gnn_lr': 0.01,

    # --- loss ---
    'loss_type': 'bpr',
    'L2_reg_weight': 0.0,

    # --- inference ---
    'infer_num_layer_sample': '[]',
}

In [7]:
# Seed through XGCN's helper before any model or data object is built.
set_random_seed(config["seed"])

# Make sure the output directory is available, then store the exact
# configuration of this run alongside its results.
results_root = config["results_root"]
ensure_dir(results_root)
io.save_yaml(osp.join(results_root, "config.yaml"), config)

In [8]:
# `data` starts empty and is passed to every builder below
# (presumably a shared store the builders fill in — TODO confirm).
data = {}

model = XGCN.build_Model(config, data)
train_dl = XGCN.build_DataLoader(config, data)

# Both evaluators are built against the same model instance.
val_evaluator = XGCN.build_val_Evaluator(config, data, model)
test_evaluator = XGCN.build_test_Evaluator(config, data, model)

trainer = XGCN.build_Trainer(
    config, data, model, train_dl, val_evaluator, test_evaluator
)

In [9]:
# Launch training followed by the test evaluation. The captured log below
# shows a validation pass around each training epoch, a "new best score"
# line whenever r100 improves, and one final test report.
trainer.train_and_test()

  d[key] = value
val: 100%|██████████| 8/8 [00:02<00:00,  2.81it/s]


val: {'r20': 0.023, 'r50': 0.0495, 'r100': 0.0725, 'r300': 0.1765, 'n20': 0.006376032903790475, 'n50': 0.011577807568013668, 'n100': 0.015310970462858679, 'n300': 0.02919703843072057}
>> new best score - r100 : 0.0725
train epoch 1


74it [00:00, 94.67it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.35it/s]


val: {'r20': 0.2125, 'r50': 0.3865, 'r100': 0.545, 'r300': 0.773, 'n20': 0.08200331737846137, 'n50': 0.11659545716643334, 'n100': 0.14217576391994954, 'n300': 0.17286243205890062}
>> new best score - r100 : 0.545
train epoch 2


75it [00:00, 82.70it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.72it/s]


val: {'r20': 0.21750000000000003, 'r50': 0.3765, 'r100': 0.548, 'r300': 0.7735000000000001, 'n20': 0.08010385914891958, 'n50': 0.11140814793854951, 'n100': 0.1391450522840023, 'n300': 0.16971439379826186}
>> new best score - r100 : 0.548
train epoch 3


75it [00:00, 93.78it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.78it/s]


val: {'r20': 0.254, 'r50': 0.4565, 'r100': 0.627, 'r300': 0.8305, 'n20': 0.08618625354766846, 'n50': 0.12625139208137992, 'n100': 0.15390588760375978, 'n300': 0.18158616364747288}
>> new best score - r100 : 0.627
train epoch 4


75it [00:00, 85.09it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.79it/s]


val: {'r20': 0.1355, 'r50': 0.28850000000000003, 'r100': 0.444, 'r300': 0.752, 'n20': 0.04838614316284656, 'n50': 0.07866126888245344, 'n100': 0.10384456320852042, 'n300': 0.14514034586027263}
>> distance_between_best_epoch: 1 threshold: 20
train epoch 5


75it [00:00, 92.38it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.63it/s]


val: {'r20': 0.185, 'r50': 0.373, 'r100': 0.545, 'r300': 0.794, 'n20': 0.06761031030863524, 'n50': 0.10477740429341793, 'n100': 0.1327407001629472, 'n300': 0.16652721863240005}
>> distance_between_best_epoch: 2 threshold: 20
train epoch 6


75it [00:00, 93.85it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.88it/s]


val: {'r20': 0.195, 'r50': 0.3605, 'r100': 0.5215000000000001, 'r300': 0.7885, 'n20': 0.06947568245232105, 'n50': 0.10223319512605666, 'n100': 0.12844292387366293, 'n300': 0.16451227578520777}
>> distance_between_best_epoch: 3 threshold: 20
train epoch 7


75it [00:00, 83.24it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.84it/s]


val: {'r20': 0.184, 'r50': 0.356, 'r100': 0.5155000000000001, 'r300': 0.76, 'n20': 0.06242803066968918, 'n50': 0.0966092528924346, 'n100': 0.12235618705302478, 'n300': 0.1553657201975584}
>> distance_between_best_epoch: 4 threshold: 20
train epoch 8


76it [00:00, 94.45it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.54it/s]


val: {'r20': 0.1405, 'r50': 0.28, 'r100': 0.44499999999999995, 'r300': 0.727, 'n20': 0.049304743938148016, 'n50': 0.07677892279624939, 'n100': 0.10350973045825959, 'n300': 0.1412840737812221}
>> distance_between_best_epoch: 5 threshold: 20
train epoch 9


74it [00:00, 92.85it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.76it/s]


val: {'r20': 0.2275, 'r50': 0.3915, 'r100': 0.5375000000000001, 'r300': 0.788, 'n20': 0.07687943713366985, 'n50': 0.10919061303138733, 'n100': 0.13297558528929948, 'n300': 0.1667641249448061}
>> distance_between_best_epoch: 6 threshold: 20
train epoch 10


74it [00:00, 93.94it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.65it/s]


val: {'r20': 0.243, 'r50': 0.4605, 'r100': 0.6365, 'r300': 0.8695, 'n20': 0.08747658411413431, 'n50': 0.13026535496115682, 'n100': 0.1587206047922373, 'n300': 0.19043571224808692}
>> new best score - r100 : 0.6365
train epoch 11


75it [00:00, 94.27it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 12.00it/s]


val: {'r20': 0.177, 'r50': 0.33749999999999997, 'r100': 0.522, 'r300': 0.821, 'n20': 0.06016192609071731, 'n50': 0.09178349446505309, 'n100': 0.12178446788340808, 'n300': 0.16236013805121183}
>> distance_between_best_epoch: 1 threshold: 20
train epoch 12


76it [00:00, 84.83it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.41it/s]


val: {'r20': 0.23349999999999999, 'r50': 0.43, 'r100': 0.6065, 'r300': 0.851, 'n20': 0.08639779559522867, 'n50': 0.12508303473889826, 'n100': 0.15367623800784352, 'n300': 0.18678112192451954}
>> distance_between_best_epoch: 2 threshold: 20
train epoch 13


75it [00:00, 92.89it/s]              
val: 100%|██████████| 8/8 [00:00<00:00, 11.92it/s]


val: {'r20': 0.2265, 'r50': 0.41600000000000004, 'r100': 0.625, 'r300': 0.892, 'n20': 0.07799983543157578, 'n50': 0.11533568876236677, 'n100': 0.14917067636549472, 'n300': 0.18545357714220884}
>> distance_between_best_epoch: 3 threshold: 20
train epoch 14


23it [00:00, 89.72it/s]              
test: 100%|██████████| 11/11 [00:00<00:00, 11.41it/s]

test: {'r20': 0.2268947520408134, 'r50': 0.4298230023020434, 'r100': 0.639047283841216, 'r300': 0.8895521520803592, 'n20': 0.10778731456078214, 'n50': 0.16311242970945322, 'n100': 0.209644807172051, 'n300': 0.25663913208898353, 'formatted': 'r20:0.2269 || r50:0.4298 || r100:0.6390 || r300:0.8896 || n20:0.1078 || n50:0.1631 || n100:0.2096 || n300:0.2566 || '}



