# Evaluation of a LightGCN-based Model on the Author-Paper Heterogeneous Graph

## Step 1: Import Libraries and Load Data

In [26]:
from modeling.metrics import calculate_metrics
import torch_geometric.transforms as T
from modeling.models.lightGCN2 import LightGCN
from modeling.models.simple_V2 import Model
import torch
from torch import optim
from modeling.utils import add_coauthor_edges

# Fix the RNG seeds (CPU and every CUDA device) so the random link split below
# is reproducible across kernel restarts.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)


# Load the pre-built heterogeneous author-paper graph and preprocess it:
# add self-loops, then normalize the node features.
# NOTE(review): weights_only=False unpickles arbitrary Python objects -- only
# load graph files that come from a trusted source.
graph_path = "data/hetero_data_filtered_3_2.pt"
preprocess = T.Compose([T.AddSelfLoops(), T.NormalizeFeatures()])
data = preprocess(torch.load(graph_path, weights_only=False))

# Hold out 10% of the ("author", "writes", "paper") edges for validation and
# another 10% for testing. No negative samples are generated here
# (neg_sampling_ratio=0.0, add_negative_train_samples=False).
link_splitter = T.RandomLinkSplit(
    edge_types=[("author", "writes", "paper")],
    rev_edge_types=[("paper", "rev_writes", "author")],
    num_val=0.1,
    num_test=0.1,
    is_undirected=True,
    neg_sampling_ratio=0.0,
    add_negative_train_samples=False,
    disjoint_train_ratio=0.0,
)
train_data, val_data, test_data = link_splitter(data)


In [27]:
# Show the summary of the full (unsplit) graph: node counts, feature shapes
# and the number of edges per edge type.
data

HeteroData(
  author={
    node_id=[20950],
    x=[20950, 256],
  },
  paper={
    node_id=[39802],
    x=[39802, 256],
  },
  (author, writes, paper)={ edge_index=[2, 185956] },
  (paper, rev_writes, author)={ edge_index=[2, 185956] }
)

## Step 2: Load Models and Evaluate

In [28]:

# Checkpoint file for each model, keyed by the name that is also used to
# look the model up in `models` and to label its results.
ckpt_paths = dict(
    LightGCN="checkpoints/model_LightGCN_15000.pth",
    GNN="checkpoints/model_GNN_15000.pth",
)

# Instantiate both architectures with freshly initialised parameters; the
# trained weights are loaded from `ckpt_paths` in the evaluation cell.
# The hyperparameters presumably have to match the ones used when the
# checkpoints were trained -- TODO confirm against the training script.
models = dict(
    # Sized from the node counts of the training split.
    # K is presumably the number of propagation layers -- TODO confirm in
    # modeling.models.lightGCN2.
    LightGCN=LightGCN(
        num_authors=train_data["author"].num_nodes,
        num_papers=train_data["paper"].num_nodes,
        embedding_dim=256,
        K=6,
    ),
    # Receives the training split itself at construction time.
    GNN=Model(
        data=train_data,
        embedding_dim=256,
        num_layers=5,
    ),
)

In [29]:
# Cut-offs for the top-k ranking metrics and the edge type that is evaluated.
Ks = [5, 10, 20]
TEST_EDGE_TYPE = ("author", "writes", "paper")

# Edge sets produced by RandomLinkSplit for the evaluated edge type.
# `edge_index` holds the edges stored on each split's graph; `edge_label_index`
# holds that split's labelled (supervision) edges.
# Only `train_edge_index` and `test_edge_index` are used by the loop below.
train_message_passing_edge_index = train_data[TEST_EDGE_TYPE].edge_index
train_supervision_edge_index = train_data[TEST_EDGE_TYPE].edge_label_index
train_edge_index = train_message_passing_edge_index
val_edge_index = val_data[TEST_EDGE_TYPE].edge_label_index
test_edge_index = test_data[TEST_EDGE_TYPE].edge_label_index

# NOTE(review): the 4th argument of calculate_metrics is presumably the list of
# already-known edges to exclude from the ranking -- confirm in modeling.metrics.
# If so, consider passing [train_edge_index, val_edge_index] so that validation
# positives are not counted as wrong recommendations on the test set.

# results[model_name]["recall@k" | "precision@k" | "F1@k"] -> metric value
results = {}

for model_name, model in models.items():
    print(f"Evaluating {model_name}...")

    # weights_only=True restricts unpickling to tensors and plain containers.
    # That is all a state dict needs; it avoids executing arbitrary pickled
    # code and silences torch.load's FutureWarning about the default value.
    state_dict = torch.load(
        ckpt_paths[model_name],
        map_location=torch.device("cpu"),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()

    model_metrics = {}
    results[model_name] = model_metrics

    # Embeddings are computed once per model on the training graph and reused
    # for every cut-off k; no gradients are needed for evaluation.
    with torch.no_grad():
        embeddings = model(train_data)
    author_embeddings = embeddings["author"]
    paper_embeddings = embeddings["paper"]

    for k in Ks:
        recall, precision = calculate_metrics(
            author_embeddings,
            paper_embeddings,
            test_edge_index,
            [train_edge_index],
            k=k,
        )
        model_metrics[f"recall@{k}"] = recall
        model_metrics[f"precision@{k}"] = precision
        # The small epsilon avoids a division by zero when both precision and
        # recall are 0.
        model_metrics[f"F1@{k}"] = 2 * (precision * recall) / (precision + recall + 1e-9)


Evaluating LightGCN...


  model.load_state_dict(torch.load(ckpt_paths[model_name], map_location=torch.device("cpu")))


Evaluating GNN...


In [30]:
# Render the collected metrics as a grid table: one row per model and one
# column per (metric, k) pair.
from tabulate import tabulate

# (key prefix used in `results`, column label shown in the table)
metric_columns = [("recall", "Recall"), ("precision", "Precision"), ("F1", "F1")]

# Headers and rows are both derived from `Ks` and `metric_columns`, so the
# columns stay aligned with the values if the cut-offs are changed.
headers = ["Model"] + [f"{label}@{k}" for k in Ks for _, label in metric_columns]
table = [
    [model_name]
    + [f"{metrics[f'{key}@{k}']:.4f}" for k in Ks for key, _ in metric_columns]
    for model_name, metrics in results.items()
]
print(tabulate(table, headers=headers, tablefmt="grid"))

+----------+------------+---------------+--------+-------------+----------------+---------+-------------+----------------+---------+
| Model    |   Recall@5 |   Precision@5 |   F1@5 |   Recall@10 |   Precision@10 |   F1@10 |   Recall@20 |   Precision@20 |   F1@20 |
| LightGCN |     0.2957 |        0.0977 | 0.1469 |      0.3823 |         0.0646 |  0.1105 |      0.4676 |         0.0401 |  0.0739 |
+----------+------------+---------------+--------+-------------+----------------+---------+-------------+----------------+---------+
| GNN      |     0.0458 |        0.0139 | 0.0213 |      0.0698 |         0.0107 |  0.0185 |      0.1032 |         0.0082 |  0.0152 |
+----------+------------+---------------+--------+-------------+----------------+---------+-------------+----------------+---------+
