# Evaluation of the models

## Step 1: Import Libraries and Load Data

In [1]:
import torch
import torch_geometric.transforms as T
from tabulate import tabulate

from modeling.metrics import calculate_metrics
from modeling.models.lightGCN import LightGCN
from modeling.models.simpleGNN import SimpleGNN
from modeling.models.TextDotProduct import TextDotProductModel

# Seed every RNG that the split below (and any model initialisation) draws from.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Let's start by loading the pre-built heterogeneous author/paper graph.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; it is
# required because the file stores a whole HeteroData object, so only ever
# point this at a trusted file.
data = torch.load("data/hetero_data.pt", weights_only=False)

# Preprocessing, applied in this order: add self-loops, then normalise node features.
# NOTE(review): PyG's AddSelfLoops skips bipartite edge stores, so it is likely
# a no-op for the author<->paper relations in this graph — confirm.
preprocess = T.Compose([T.AddSelfLoops(), T.NormalizeFeatures()])
data = preprocess(data)

# Hold out 10% of the author->paper edges for validation and 10% for test
# (the remaining 80% are training edges). The reverse relation is split
# consistently, and no negative edges are sampled.
link_splitter = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    edge_types=[("author", "writes", "paper")],
    rev_edge_types=[("paper", "rev_writes", "author")],
    disjoint_train_ratio=0.0,
    neg_sampling_ratio=0.0,
    add_negative_train_samples=False,
)
train_data, val_data, test_data = link_splitter(data)
data

HeteroData(
  author={
    node_id=[28772],
    x=[28772, 256],
  },
  paper={
    node_id=[56063],
    x=[56063, 256],
  },
  (author, writes, paper)={ edge_index=[2, 236503] },
  (paper, rev_writes, author)={ edge_index=[2, 236503] }
)

## Step 2: Instantiate models (trained checkpoints are loaded in the evaluation step)

In [2]:
# Checkpoint file for each model that has trained weights. TextDotProduct is
# not listed, so the evaluation loop uses it without loading any checkpoint.
ckpt_paths = {
    "GraphSAGE": "checkpoints/model_GraphSAGE_50k.pth",
    "LightGCN": "checkpoints/model_LightGCN_50k.pth",
}

# Architecture hyperparameters. These must match the values the checkpoints
# were trained with, otherwise load_state_dict will fail on shape mismatches.
EMBEDDING_DIM = 256
SAGE_NUM_LAYERS = 5
LIGHTGCN_K = 5

num_authors = train_data["author"].num_nodes
num_papers = train_data["paper"].num_nodes

# Insertion order (GraphSAGE, LightGCN, TextDotProduct) sets the evaluation
# and table row order.
models = {}
models["GraphSAGE"] = SimpleGNN(
    data=train_data,
    embedding_dim=EMBEDDING_DIM,
    num_layers=SAGE_NUM_LAYERS,
)
models["LightGCN"] = LightGCN(
    num_authors=num_authors,
    num_papers=num_papers,
    embedding_dim=EMBEDDING_DIM,
    K=LIGHTGCN_K,
)
models["TextDotProduct"] = TextDotProductModel()

In [3]:
# Score every model on the held-out test edges at several top-K cut-offs.
# Embeddings are computed from the training graph only. Training and validation
# edges are passed to `calculate_metrics` as the list of known edges, presumably
# masked out of the top-K ranking (see modeling.metrics to confirm).
Ks = [5, 10, 20, 50, 100]
TEST_EDGE_TYPE = ("author", "writes", "paper")

train_edge_index = train_data[TEST_EDGE_TYPE].edge_index
val_edge_index = val_data[TEST_EDGE_TYPE].edge_label_index
test_edge_index = test_data[TEST_EDGE_TYPE].edge_label_index

# results[model_name] maps "recall@K" / "precision@K" to the metric value.
results = {}

for model_name, model in models.items():
    print(f"Evaluating {model_name}...")
    if model_name in ckpt_paths:  # only models with trained weights have a checkpoint
        # The checkpoint is fed straight into load_state_dict, so it is a plain
        # state_dict; the tensor-only loader is enough and avoids full unpickling.
        state_dict = torch.load(
            ckpt_paths[model_name],
            map_location=torch.device("cpu"),
            weights_only=True,
        )
        model.load_state_dict(state_dict)
    model.eval()
    with torch.no_grad():
        embeddings = model.forward(train_data)
    author_embeddings = embeddings["author"]
    paper_embeddings = embeddings["paper"]

    model_results = {}
    for k in Ks:
        recall, precision = calculate_metrics(
            author_embeddings,
            paper_embeddings,
            test_edge_index,
            [train_edge_index, val_edge_index],
            k=k,
        )
        # Keep both metrics; previously precision was computed and thrown away.
        model_results[f"recall@{k}"] = recall
        model_results[f"precision@{k}"] = precision
    results[model_name] = model_results

Evaluating GraphSAGE...
Evaluating LightGCN...
Evaluating TextDotProduct...


In [4]:

# Print Recall@K for every model as a grid table.
# The headers are derived from `Ks`, so the columns always match the cut-offs
# that were actually evaluated.
headers = ["Model"] + [f"Recall@{k}" for k in Ks]
table = [
    [model_name] + [float(metrics[f"recall@{k}"]) for k in Ks]
    for model_name, metrics in results.items()
]
# Pass raw floats and let tabulate format them. Pre-formatted "%.4f" strings are
# re-parsed as numbers and reprinted with the default "g" format, which drops
# trailing zeros (the old output showed "0.176" and "0.25" next to "0.2054").
print(tabulate(table, headers=headers, tablefmt="grid", floatfmt=".4f"))

+----------------+------------+-------------+-------------+-------------+--------------+
| Model          |   Recall@5 |   Recall@10 |   Recall@20 |   Recall@50 |   Recall@100 |
| GraphSAGE      |     0.1572 |      0.2054 |      0.2582 |      0.3341 |       0.3962 |
+----------------+------------+-------------+-------------+-------------+--------------+
| LightGCN       |     0.176  |      0.25   |      0.3359 |      0.4439 |       0.5132 |
+----------------+------------+-------------+-------------+-------------+--------------+
| TextDotProduct |     0.0506 |      0.0697 |      0.0944 |      0.1353 |       0.1768 |
+----------------+------------+-------------+-------------+-------------+--------------+
