# Tutorial: training a GCL benchmark model

This tutorial uses the `ckg_benchmarks` package to train a GCL model and evaluate it on CompanyKG.

The trained model is unlikely to produce good results, since the hyperparameters are chosen to minimize computation (e.g. only one training epoch). However, the code serves both as a template for training better models and as a quick test of the GCL training code.

We demonstrate training with both GRACE and MVGRL. The only difference between the two is the `method` argument to `train_model`.

You can apply an almost identical training procedure to other GNN training methods by using their `train_model` functions and adjusting the parameters.

In [1]:
%load_ext autoreload
%autoreload 2

We initialize logging so that we see model training progress.

In [2]:
import logging
logger = logging.getLogger()
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
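
In a notebook this cell is easily re-executed, and each run attaches another handler to the root logger, so every log line is then printed several times. A minimal sketch of an idempotent variant follows; the helper name `init_logging` and the marker attribute are hypothetical and not part of `ckg_benchmarks`.

```python
import logging

LOG_FORMAT = "%(asctime)s %(name)-12s %(levelname)-8s %(message)s"


def init_logging(level=logging.INFO):
    """Attach a single stream handler to the root logger, however often this is called.

    Hypothetical helper for illustration; not part of ckg_benchmarks.
    """
    root = logging.getLogger()
    # Our handler carries a marker attribute so repeated calls can recognise it
    if not any(getattr(h, "_tutorial_handler", False) for h in root.handlers):
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(LOG_FORMAT))
        handler._tutorial_handler = True
        root.addHandler(handler)
    root.setLevel(level)
    return root


init_logging()
init_logging()  # the second call does not add another handler
```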

In [4]:
from companykg import CompanyKG
from ckg_benchmarks.gcl.train import train_model

## Load data

Prepare the CompanyKG dataset. 

The first time you run this, the data will be downloaded from Zenodo to the `data` subdirectory, which could take some time. After that, it will be quicker to load.

The dataset is then loaded into memory using the mSBERT node feature type.

This step is not strictly necessary, since the data loaded here is not used for training: the training routine loads the data itself. Loading it here does, however, trigger the download if the data is not yet present.

In [5]:
ckg = CompanyKG(
    nodes_feature_type="msbert", 
    load_edges_weights=True,
)

2023-06-14 12:39:58,723 companykg.kg INFO     [DONE] Loaded ./data/edges.pt
2023-06-14 12:40:10,900 companykg.kg INFO     [DONE] Loaded ./data/edges_weight.pt
2023-06-14 12:40:31,937 companykg.kg INFO     [DONE] Loaded ./data/nodes_feature_msbert.pt
2023-06-14 12:40:33,453 companykg.kg INFO     [DONE] Loaded ./data/eval_task_sp.parquet.gz
2023-06-14 12:40:33,521 companykg.kg INFO     [DONE] Loaded ./data/eval_task_sr.parquet.gz
2023-06-14 12:40:33,526 companykg.kg INFO     [DONE] Loaded ./data/eval_task_cr.parquet.gz


## GRACE training

### Train model
Now we run a minimal GCL training using GRACE.

Training uses a GPU if it's available.

To train a better model, adjust the parameters set here, in particular `epochs`.

Calling the training function as shown below is equivalent to running the following command:
```
python -m ckg_benchmarks.gcl.train \
    --device -1 \
    --method grace \
    --n-layer 1 \
    --embedding-dim 8 \
    --epochs 1 \
    --sampler-edges 2 \
    --batch-size 128
```
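
To move beyond this minimal run, one option is to enumerate a small grid of settings and call `train_model` once per combination. The sketch below only builds the argument dictionaries and does not train anything. The grid values follow the typical ranges mentioned in the comments of the training cell, except for the embedding dimension 16, which is an assumed extra value. The helper name `build_configs` is hypothetical.

```python
from itertools import product


def build_configs(method="grace", epochs=100):
    """Return one keyword-argument dict per hyperparameter combination.

    Hypothetical helper; the keys mirror the train_model arguments used in this tutorial.
    """
    embedding_dims = [8, 16]  # 8 is the usual minimum; 16 is an assumed extra value
    n_layers = [2, 3]         # typical values according to the tutorial
    sampler_edges = [5, 10]   # typical values according to the tutorial
    configs = []
    for dim, layers, edges in product(embedding_dims, n_layers, sampler_edges):
        configs.append(
            dict(
                nodes_feature_type="msbert",
                method=method,
                embedding_dim=dim,
                n_layer=layers,
                sampler_edges=edges,
                batch_size=128,
                epochs=epochs,
            )
        )
    return configs


configs = build_configs()
print(f"{len(configs)} configurations")
# Each entry could then be passed on as train_model(**config)
```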

In [7]:
trainer = train_model(
    nodes_feature_type="msbert",
    # Train with GRACE; you can also use 'mvgrl' here
    method="grace",
    # Minimum value we usually consider is 8
    embedding_dim=8,
    # Typically we use 2 or 3
    n_layer=1,
    # We usually sample 5 or 10 edges for training
    sampler_edges=2,
    # On a GPU, use a larger batch size if memory allows, as it speeds up training
    batch_size=128,
    # For our experiments we trained for 100 epochs, here just 1 for testing
    epochs=1,
)

2023-06-14 12:41:41,929 ckg_benchmarks.base INFO     Initializing model and trainer
2023-06-14 12:41:42,184 companykg.kg INFO     [DONE] Loaded ./data/edges.pt
2023-06-14 12:41:42,912 companykg.kg INFO     [DONE] Loaded ./data/nodes_feature_msbert.pt
2023-06-14 12:41:42,917 companykg.kg INFO     [DONE] Loaded ./data/eval_task_sp.parquet.gz
2023-06-14 12:41:42,921 companykg.kg INFO     [DONE] Loaded ./data/eval_task_sr.parquet.gz
2023-06-14 12:41:42,925 companykg.kg INFO     [DONE] Loaded ./data/eval_task_cr.parquet.gz


data_root_folder=./data
n_nodes=1169931, n_edges=50815503
nodes_feature_type=msbert
nodes_feature_dimension=512
sp: 3219 samples
sr: 1856 samples
cr: 400 samples


2023-06-14 12:41:43,293 ckg_benchmarks.base INFO     Data(x=[1169931, 512], edge_index=[2, 101631006])
2023-06-14 12:41:43,296 ckg_benchmarks.base INFO     Starting model training
2023-06-14 12:42:02,491 ckg_benchmarks.base INFO     Sending training logs to experiments/grace/msbert_1_1_8_2_128_42.log
2023-06-14 12:42:02,492 ckg_benchmarks.base INFO     Strating model training
2023-06-14 12:42:04,991 ckg_benchmarks.gcl.train INFO     Starting epoch 1


Epoch 1/1:   0%|          | 0/9141 [00:00<?, ?it/s]

2023-06-14 12:43:30,915 ckg_benchmarks.gcl.train INFO     Epoch 1 loss: 27873.784133195877
2023-06-14 12:43:30,916 ckg_benchmarks.base INFO     Model training complete
2023-06-14 12:43:30,917 ckg_benchmarks.base INFO     Projecting full KG using final model
2023-06-14 12:44:06,688 ckg_benchmarks.base INFO     Best embeddings saved to experiments/grace/msbert_1_1_8_2_128_42.pt
2023-06-14 12:44:06,689 ckg_benchmarks.base INFO     Model training complete
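
Two numbers in the log above can be checked by hand. The width of `edge_index` is exactly twice `n_edges`, which is consistent with each edge being stored once per direction, and the progress bar total of 9141 equals the number of batches of size 128 needed to cover all nodes. That batches are formed over nodes is an inference from these numbers, not something the log states.

```python
import math

n_nodes = 1_169_931
n_edges = 50_815_503
batch_size = 128

# Each edge counted once per direction
directed_edges = 2 * n_edges
# Batches needed to visit every node once; the last batch is partial
batches_per_epoch = math.ceil(n_nodes / batch_size)

print(directed_edges)     # 101631006
print(batches_per_epoch)  # 9141
```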


### Evaluation

At the end of training, the final model was used to project all companies (nodes) in the graph into the learned vector space. We now feed these embeddings into the CompanyKG evaluation method, which reports results on its three tasks: similarity prediction (SP), similarity ranking (SR) and competitor retrieval (CR).

In [8]:
eval_results = trainer.evaluate()

Evaluate Custom Embeddings:
Evaluate SP ...
SP AUC: 0.7642423012861067
Evaluate SR ...
SR Validation ACC: 0.625 SR Test ACC: 0.5866935483870968
Evaluate CR with top-K hit rate (K=[50, 100, 200, 500, 1000, 2000, 5000, 10000]) ...
CR Hit Rates: [0.0, 0.018543043214095844, 0.06450880990354674, 0.1578890407837776, 0.24386012379433433, 0.330349548112706, 0.4522841193893825, 0.5296498822814613]
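
As an illustration of the last metric, a top-K hit rate can be read as follows: for a query company, rank all other companies by embedding similarity and measure what fraction of its known competitors appear among the first K. The sketch below implements that reading with cosine similarity on toy data. It is an assumption about the general idea, not the CompanyKG implementation, and the names `cosine` and `topk_hit_rate` are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def topk_hit_rate(embeddings, query, targets, k):
    """Fraction of `targets` found among the k nodes most similar to `query`.

    Hypothetical illustration; not the CompanyKG evaluation code.
    """
    scored = [
        (cosine(embeddings[query], vec), idx)
        for idx, vec in enumerate(embeddings)
        if idx != query  # never retrieve the query itself
    ]
    scored.sort(reverse=True)
    top_k = {idx for _, idx in scored[:k]}
    return len(top_k & set(targets)) / len(targets)


# Toy data: 5 nodes in 2 dimensions; nodes 1 and 2 are the known targets of node 0
emb = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.7, 0.7], [-1.0, 0.0]]
rate_at_2 = topk_hit_rate(emb, query=0, targets=[1, 2], k=2)
rate_at_3 = topk_hit_rate(emb, query=0, targets=[1, 2], k=3)
print(rate_at_2, rate_at_3)
```

With larger K more targets fall inside the retrieved set, which is why the reported hit rates grow from K=50 to K=10000.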


The evaluation results are all stored in a dict, from which we can pick out the ones we want to report.

In [9]:
print(f"SP: {eval_results['sp_auc']}")

SP: 0.7642423012861067
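
For intuition about this number, an AUC can be read as the probability that a randomly chosen positive pair receives a higher score than a randomly chosen negative pair, so 0.5 corresponds to chance. The sketch below computes it by direct pair counting on made-up scores. It is a generic illustration, not the CompanyKG evaluation code, and `pairwise_auc` is a hypothetical name.

```python
def pairwise_auc(pos_scores, neg_scores):
    """AUC as the share of (positive, negative) pairs ranked correctly.

    Ties count as half a correct pair. Hypothetical illustration only.
    """
    wins = 0.0
    for p in pos_scores:
        for n in neg_scores:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos_scores) * len(neg_scores))


# Made-up similarity scores for similar (positive) and dissimilar (negative) pairs
positives = [0.9, 0.8, 0.4]
negatives = [0.7, 0.3]
auc = pairwise_auc(positives, negatives)
print(f"{auc:.4f}")
```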


In [10]:
print(f"SR: {eval_results['sr_test_acc']*100.:.2f}%")

SR: 58.67%


In [13]:
print("CR:")
for k, cr in zip(trainer.comkg.eval_cr_top_ks, eval_results['cr_topk_hit_rate']):
    print(f"R@{k}: {cr*100.:.2f}%")

CR:
R@50: 0.00%
R@100: 1.85%
R@200: 6.45%
R@500: 15.79%
R@1000: 24.39%
R@2000: 33.03%
R@5000: 45.23%
R@10000: 52.96%


## MVGRL training

### Training

Training with MVGRL works exactly as with GRACE; only the `method` argument changes.

In [15]:
trainer2 = train_model(
    nodes_feature_type="msbert",
    method="mvgrl",
    embedding_dim=8,
    n_layer=1,
    sampler_edges=2,
    batch_size=128,
    epochs=1,
)

2023-06-14 13:04:58,229 ckg_benchmarks.base INFO     Initializing model and trainer
2023-06-14 13:04:58,482 companykg.kg INFO     [DONE] Loaded ./data/edges.pt
2023-06-14 13:04:59,208 companykg.kg INFO     [DONE] Loaded ./data/nodes_feature_msbert.pt
2023-06-14 13:04:59,214 companykg.kg INFO     [DONE] Loaded ./data/eval_task_sp.parquet.gz
2023-06-14 13:04:59,218 companykg.kg INFO     [DONE] Loaded ./data/eval_task_sr.parquet.gz
2023-06-14 13:04:59,221 companykg.kg INFO     [DONE] Loaded ./data/eval_task_cr.parquet.gz


data_root_folder=./data
n_nodes=1169931, n_edges=50815503
nodes_feature_type=msbert
nodes_feature_dimension=512
sp: 3219 samples
sr: 1856 samples
cr: 400 samples


2023-06-14 13:04:59,557 ckg_benchmarks.base INFO     Data(x=[1169931, 512], edge_index=[2, 101631006])
2023-06-14 13:04:59,560 ckg_benchmarks.base INFO     Starting model training
2023-06-14 13:04:59,562 ckg_benchmarks.base INFO     Sending training logs to experiments/mvgrl/msbert_1_1_8_2_128_42.log
2023-06-14 13:04:59,563 ckg_benchmarks.base INFO     Strating model training
2023-06-14 13:05:01,739 ckg_benchmarks.gcl.train INFO     Starting epoch 1


Epoch 1/1:   0%|          | 0/9141 [00:00<?, ?it/s]

2023-06-14 13:06:34,663 ckg_benchmarks.gcl.train INFO     Epoch 1 loss: -8337.374342504889
2023-06-14 13:06:34,664 ckg_benchmarks.base INFO     Model training complete
2023-06-14 13:06:34,665 ckg_benchmarks.base INFO     Projecting full KG using final model
2023-06-14 13:07:10,336 ckg_benchmarks.base INFO     Best embeddings saved to experiments/mvgrl/msbert_1_1_8_2_128_42.pt
2023-06-14 13:07:10,338 ckg_benchmarks.base INFO     Model training complete


### Evaluation

In [16]:
eval_results2 = trainer2.evaluate()

Evaluate Custom Embeddings:
Evaluate SP ...
SP AUC: 0.5797767103886661
Evaluate SR ...
SR Validation ACC: 0.5543478260869565 SR Test ACC: 0.5477150537634409
Evaluate CR with top-K hit rate (K=[50, 100, 200, 500, 1000, 2000, 5000, 10000]) ...
CR Hit Rates: [0.0018796992481203006, 0.015037593984962405, 0.02106829573934837, 0.04037793347003873, 0.06084054834054833, 0.09382072605756817, 0.15318409660514923, 0.23358538391433128]


In [17]:
print(f"SP: {eval_results2['sp_auc']}")

SP: 0.5797767103886661


In [18]:
print(f"SR: {eval_results2['sr_test_acc']*100.:.2f}%")

SR: 54.77%


In [19]:
print("CR:")
for k, cr in zip(trainer2.comkg.eval_cr_top_ks, eval_results2['cr_topk_hit_rate']):
    print(f"R@{k}: {cr*100.:.2f}%")

CR:
R@50: 0.19%
R@100: 1.50%
R@200: 2.11%
R@500: 4.04%
R@1000: 6.08%
R@2000: 9.38%
R@5000: 15.32%
R@10000: 23.36%
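
To compare the two runs side by side, the two result dicts can be printed as one table. The sketch below uses the values printed above as literals so that it runs on its own; in the notebook one would pass `eval_results` and `eval_results2` instead. The helper name `comparison_rows` is hypothetical.

```python
TOP_KS = [50, 100, 200, 500, 1000, 2000, 5000, 10000]

# Values copied from the evaluation output of the two runs above
grace = {
    "sp_auc": 0.7642423012861067,
    "sr_test_acc": 0.5866935483870968,
    "cr_topk_hit_rate": [0.0, 0.018543043214095844, 0.06450880990354674,
                         0.1578890407837776, 0.24386012379433433, 0.330349548112706,
                         0.4522841193893825, 0.5296498822814613],
}
mvgrl = {
    "sp_auc": 0.5797767103886661,
    "sr_test_acc": 0.5477150537634409,
    "cr_topk_hit_rate": [0.0018796992481203006, 0.015037593984962405, 0.02106829573934837,
                         0.04037793347003873, 0.06084054834054833, 0.09382072605756817,
                         0.15318409660514923, 0.23358538391433128],
}


def comparison_rows(results_by_method, top_ks):
    """Build table rows (header first) comparing several evaluation result dicts."""
    methods = list(results_by_method)
    rows = [("metric", *methods)]
    rows.append(("SP AUC", *(f"{results_by_method[m]['sp_auc']:.4f}" for m in methods)))
    rows.append(
        ("SR test acc", *(f"{results_by_method[m]['sr_test_acc'] * 100:.2f}%" for m in methods))
    )
    for i, k in enumerate(top_ks):
        rows.append(
            (f"CR R@{k}",
             *(f"{results_by_method[m]['cr_topk_hit_rate'][i] * 100:.2f}%" for m in methods))
        )
    return rows


rows = comparison_rows({"grace": grace, "mvgrl": mvgrl}, TOP_KS)
for row in rows:
    print("".join(f"{cell:<14}" for cell in row))
```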
