In [1]:
# Imports, global seed and the raw heterogeneous graph.
import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
import tqdm
from torch_geometric import seed_everything
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero

from evaluation.ranking_metrics import evaluate_ranking_metrics

SEED = 42  # must match the seed used when the evaluated checkpoints were trained
DATA_PATH = "data/hetero_data_no_coauthor.pt"

# Seed Python/NumPy/PyTorch RNGs once all imports are done and before anything random
# happens: the edge split in a later cell depends on this RNG state.
seed_everything(SEED)

# Let's start by loading the data.
# weights_only=False lets torch.load unpickle arbitrary Python objects (the file
# presumably stores a whole HeteroData object, not just tensors). Unpickling can execute
# code, so only load files from a trusted source.
data = torch.load(DATA_PATH, weights_only=False)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
!pip install torch_scatter


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [3]:
# Split the ("author", "writes", "paper") edges into train/val/test supervision sets.
# The split is random. The checkpoints evaluated below were trained on a split produced by
# this same call, so the split is only comparable if the RNG state matches the training run.
# Re-seeding right here makes this cell idempotent: re-running it alone no longer draws
# a different split that could move training edges into the test set.
# NOTE(review): assumes the training notebook seeds with 42 immediately before its own
# RandomLinkSplit with identical arguments — confirm in example_training.ipynb.
seed_everything(42)
transform = T.RandomLinkSplit(
    num_val=0.1,  # fraction of edges held out for validation
    num_test=0.1,  # fraction of edges held out for testing
    disjoint_train_ratio=0.3,  # fraction of training edges used only as supervision labels, never for message passing
    neg_sampling_ratio=2.0,  # negative edges sampled per positive edge (val/test always; train only if add_negative_train_samples=True)
    # Training negatives are expected to be sampled on the fly during training (presumably by
    # LinkNeighborLoader's negative sampling), so none are added to the training split here.
    # TODO: investigate why True reportedly scored worse.
    add_negative_train_samples=False,
    edge_types=("author", "writes", "paper"),  # edge type we want to predict
    rev_edge_types=("paper", "rev_writes", "author"),  # split reverse edges too, so held-out links cannot leak via message passing
)

train_data, val_data, test_data = transform(data)


In [4]:
# Models to evaluate
from models.GNN import Model, BaselineNoGraphModel
import numpy as np
import torch

# Checkpoint (state dict) per model, keyed by the display name used in every table below.
model_checkpoints = {
    "BaselineGNN": "checkpoints/modelGNN_weights.pkl",
    "Baseline1HopGNN": "checkpoints/baseline_weights.pkl",
}

# Constructor kwargs per model.
# NOTE(review): a data split is passed in, presumably so the model can read graph metadata
# (node/edge types) — TODO confirm in models.GNN.
model_settings = {
    "BaselineGNN": {
        "hidden_channels": 256,
        "data": test_data,
    },
    "Baseline1HopGNN": {
        "hidden_channels": 256,
        "data": test_data,
    },
}

model_classes = {
    "BaselineGNN": Model,
    "Baseline1HopGNN": BaselineNoGraphModel,
}

models = {}
for key, checkpoint_path in model_checkpoints.items():
    net = model_classes[key](**model_settings[key])
    # weights_only=True restricts unpickling to tensors and plain containers, which is all a
    # plain state dict of tensors needs. It avoids executing arbitrary pickled code and
    # presumably silences the warning torch.load emitted for this line before.
    state_dict = torch.load(checkpoint_path, map_location=torch.device("cpu"), weights_only=True)
    net.load_state_dict(state_dict)
    net.eval()  # inference only from here on
    models[key] = net


  models[key].load_state_dict(torch.load(model_checkpoints[key], map_location=torch.device('cpu')))


In [5]:
def evaluate_model(model, data, threshold=0.5):
    """Threshold-based binary classification metrics on ("author", "writes", "paper") edge labels.

    Args:
        model: module with ``eval()``; ``model(data)`` must return one score per labelled
            edge, presumably in the same order as ``edge_label`` (TODO confirm in models.GNN).
        data: a split whose ``data["author", "writes", "paper"].edge_label`` holds 0/1 labels.
        threshold: scores ``>= threshold`` count as positive predictions. The default 0.5
            assumes the model outputs probabilities; if it returns raw logits, pass 0.0.

    Returns:
        ``(precision, recall, f1_score, accuracy)`` as floats. A 1e-8 epsilon in every
        denominator avoids division by zero when a class is empty.

    Raises:
        ValueError: if the number of scores differs from the number of labels.
    """
    eps = 1e-8

    model.eval()
    with torch.no_grad():
        scores = model(data)

    # Flatten both sides: a (N, 1) score tensor compared with (N,) labels would otherwise
    # broadcast to an (N, N) grid and silently corrupt every count below.
    y_score = scores.detach().cpu().numpy().reshape(-1)
    y_true = data["author", "writes", "paper"].edge_label.cpu().numpy().reshape(-1)
    if y_score.size != y_true.size:
        raise ValueError(
            f"model returned {y_score.size} scores for {y_true.size} labelled edges"
        )

    y_pred = y_score >= threshold
    is_pos = y_true == 1
    is_neg = y_true == 0

    TP = (is_pos & y_pred).sum().item()
    FP = (is_neg & y_pred).sum().item()
    FN = (is_pos & ~y_pred).sum().item()
    TN = (is_neg & ~y_pred).sum().item()

    precision = TP / (TP + FP + eps)
    recall = TP / (TP + FN + eps)
    f1_score = 2 * (precision * recall) / (precision + recall + eps)
    accuracy = (TP + TN) / (TP + TN + FP + FN + eps)

    return precision, recall, f1_score, accuracy

def dump_quick_model_metrics(model):
    """Print threshold-based metrics for ``model`` on the test split, then the validation split.

    A safety check to compare against the outputs of example_training.ipynb. It reads the
    notebook-level ``test_data`` and ``val_data`` splits and calls ``evaluate_model`` on each.
    Output only; nothing is written to disk.
    """
    splits = (("Test set", test_data), ("validation set", val_data))
    for index, (split_name, split_data) in enumerate(splits):
        if index > 0:
            # separator between splits only, never after the last one
            print("--------------------------------------------------")
        precision, recall, f1_score, accuracy = evaluate_model(model, split_data)
        print(f"Evaluating on {split_name}...")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1 Score: {f1_score:.4f}")
        print(f"Accuracy: {accuracy:.4f}")

In [6]:

# Threshold-based metrics for every loaded checkpoint, one block per model.
for name, candidate in models.items():
    print(f"Evaluating model: {name}")
    dump_quick_model_metrics(candidate)


Evaluating model: BaselineGNN
Evaluating on Test set...
Precision: 0.8224
Recall: 0.6427
F1 Score: 0.7215
Accuracy: 0.8346
--------------------------------------------------
Evaluating on validation set...
Precision: 0.8288
Recall: 0.6206
F1 Score: 0.7097
Accuracy: 0.8308
Evaluating model: Baseline1HopGNN
Evaluating on Test set...
Precision: 0.7469
Recall: 0.4315
F1 Score: 0.5470
Accuracy: 0.7618
--------------------------------------------------
Evaluating on validation set...
Precision: 0.7588
Recall: 0.4299
F1 Score: 0.5489
Accuracy: 0.7644


In [7]:
# Ranking metrics on the held-out test split at each cut-off K.
# NOTE(review): in the LaTeX table printed further down, Precision@K and Recall@12 are identical
# for both models. That suggests most heads have <= K candidate edges, in which case @K
# metrics cannot tell rankings apart — check the candidate count per head before
# drawing conclusions from them.
Ks = (4, 12)
metrics = {}
with torch.no_grad():  # evaluation only; skip autograd bookkeeping
    for model_name, model in models.items():
        model.eval()
        metrics[model_name] = evaluate_ranking_metrics(model, test_data, ks=Ks)

In [8]:
# Sanity check: one ranking-metrics entry per evaluated model.
metrics.keys()

dict_keys(['BaselineGNN', 'Baseline1HopGNN'])

In [9]:
# Metric names reported by evaluate_ranking_metrics for a single model (listed in the output below).
metrics["BaselineGNN"].keys()

dict_keys(['num_heads', 'MRR', 'MAP', 'Hits@4', 'Precision@4', 'Recall@4', 'F1@4', 'MAP@4', 'NDCG@4', 'Hits@12', 'Precision@12', 'Recall@12', 'F1@12', 'MAP@12', 'NDCG@12'])

In [10]:
# Characters with special meaning in LaTeX text mode and their escaped forms.
_LATEX_SPECIAL = {
    "\\": "\\textbackslash{}",
    "&": "\\&",
    "%": "\\%",
    "$": "\\$",
    "#": "\\#",
    "_": "\\_",
    "{": "\\{",
    "}": "\\}",
    "~": "\\textasciitilde{}",
    "^": "\\textasciicircum{}",
}


def _latex_escape(text: str) -> str:
    """Escape LaTeX special characters so ``text`` renders literally in a table cell."""
    return "".join(_LATEX_SPECIAL.get(ch, ch) for ch in text)


def make_latex_table(
    metrics: dict,
    ks=(4, 12),
    metric_names=("MAP", "Precision", "Recall", "F1"),
    caption="The specific metrics were chosen based on the lecture.",
    label="tbl:initial_metrics",
) -> str:
    r"""Render per-model ranking metrics as a LaTeX ``table`` environment.

    The output uses \toprule/\midrule/\bottomrule and \multirow, so the document needs
    the booktabs and multirow packages.

    Args:
        metrics: maps model name -> dict of values keyed ``f"{metric}@{k}"``
            (e.g. ``"MAP@4"``). Model names are LaTeX-escaped.
        ks: cut-offs shown under every metric, in column order.
        metric_names: metric column groups, in column order.
        caption: text placed in \caption.
        label: key placed in \label.

    Returns:
        The table source with lines joined by newlines. Missing metric values render as
        ``--`` instead of a misleading ``0.000``.
    """
    n_ks = len(ks)
    n_metrics = len(metric_names)

    lines = [
        "\\begin{table}[h]",
        "\\centering",
        "\\begin{tabular}{c" + "".join("|" + "c" * n_ks for _ in metric_names) + "}",
        "\\toprule",
        "\\multirow{2}{*}{\\textbf{Model}} &",
    ]
    for i, name in enumerate(metric_names):
        is_last = i == n_metrics - 1
        # the last group has no right-hand rule and ends the header row
        align = "c" if is_last else "c|"
        ending = " \\\\" if is_last else " &"
        lines.append("\\multicolumn{" + str(n_ks) + "}{" + align + "}{\\textbf{" + name + "}}" + ending)
    lines.append("& " + " & ".join(f"@{k}" for _ in metric_names for k in ks) + " \\\\")
    lines.append("\\midrule")

    for model_name, vals in metrics.items():
        cells = [_latex_escape(str(model_name))]
        for name in metric_names:
            for k in ks:
                value = vals.get(f"{name}@{k}")
                cells.append("--" if value is None else f"{value:.3f}")
        lines.append(" & ".join(cells) + " \\\\")

    lines.append("\\bottomrule")
    lines.append("\\end{tabular}")
    lines.append("\\vspace{0.1in}")
    lines.append("\\caption{" + caption + "}")
    lines.append("\\label{" + label + "}")
    lines.append("\\end{table}")
    return "\n".join(lines)


In [11]:
# Emit the LaTeX source so it can be pasted straight into the report.
latex_source = make_latex_table(metrics)
print(latex_source)

\begin{table}[h]
\centering
\begin{tabular}{c|cc|cc|cc|cc}
\toprule
\multirow{2}{*}{\textbf{Model}} &
\multicolumn{2}{c|}{\textbf{MAP}} &
\multicolumn{2}{c|}{\textbf{Precision}} &
\multicolumn{2}{c|}{\textbf{Recall}} &
\multicolumn{2}{c}{\textbf{F1}} \\
& @4 & @12 & @4 & @12 & @4 & @12 & @4 & @12 \\
\midrule
BaselineGNN & 0.953 & 0.954 & 0.132 & 0.047 & 0.987 & 1.000 & 0.484 & 0.210 \\
Baseline1HopGNN & 0.947 & 0.949 & 0.132 & 0.047 & 0.986 & 1.000 & 0.484 & 0.210 \\
\bottomrule
\end{tabular}
\vspace{0.1in}
\caption{The specific metrics were chosen based on the lecture.}
\label{tbl:initial_metrics}
\end{table}
