In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from data import Dataset
from torch_geometric.data import DataLoader

from evaluate import evaluate_node, evaluate_clustering, run_similarity_search

## Evaluation of embeddings
Because of file-size limits on uploads, we attach only two sets of embeddings from our model.  
They are the node-classification embeddings for the **Amazon Computers** and **Coauthor CS** datasets.

In [2]:
# Amazon Computers: node-classification check on the attached embeddings.
embedding_path = "./embeddings/embeddings_computers_node.pt"

# Index 0 of the project Dataset object is what the evaluator is given.
computers_dataset = Dataset(root="data", dataset="computers")[0]
embeddings = torch.load(embedding_path)
evaluate_node(embeddings, computers_dataset, "computers")

Evaluate node classification results
** Val: 90.0800 (0.7563) | Test: 89.8100 (0.3403) **


In [3]:
# Coauthor CS: node-classification check on the attached embeddings.
embedding_path = "./embeddings/embeddings_cs_node.pt"

# Index 0 of the project Dataset object is what the evaluator is given.
cs_dataset = Dataset(root="data", dataset="cs")[0]
embeddings = torch.load(embedding_path)
evaluate_node(embeddings, cs_dataset, "cs")

Evaluate node classification results
** Val: 93.1506 (0.4517) | Test: 93.3022 (0.2037) **


### Evaluation for Node Classification
This is for reproducing **Table 2** in the paper

In [4]:
# WikiCS: node classification from the saved embedding file.
embedding_path = "./embeddings/embeddings_wikics_node.pt"

# Build the dataset, read the saved embeddings, then hand both to the
# node-classification evaluator together with the dataset name.
wikics_dataset = Dataset(root="data", dataset="wikics")[0]
embeddings = torch.load(embedding_path)
evaluate_node(embeddings, wikics_dataset, "wikics")

Evaluate node classification results
** Val: 78.3041 (0.9529) | Test: 77.5517 (0.5158) **


In [5]:
# Computers: node classification from the saved embedding file.
# NOTE(review): this uses the same dataset and embedding file as the earlier
# Computers cell, yet the recorded numbers differ slightly between the two
# runs -- evaluate_node is presumably unseeded; confirm.
embedding_path = "./embeddings/embeddings_computers_node.pt"

computers_dataset = Dataset(root="data", dataset="computers")[0]
embeddings = torch.load(embedding_path)
evaluate_node(embeddings, computers_dataset, "computers")

Evaluate node classification results
** Val: 90.1127 (0.7972) | Test: 89.7959 (0.3182) **


In [6]:
# Photo: node classification from the saved embedding file.
embedding_path = "./embeddings/embeddings_photo_node.pt"

# Build the dataset, read the saved embeddings, then hand both to the
# node-classification evaluator together with the dataset name.
photo_dataset = Dataset(root="data", dataset="photo")[0]
embeddings = torch.load(embedding_path)
evaluate_node(embeddings, photo_dataset, "photo")

Evaluate node classification results
** Val: 93.0980 (0.9691) | Test: 93.1634 (0.3363) **


In [7]:
# CS: node classification from the saved embedding file.
# NOTE(review): this uses the same dataset and embedding file as the earlier
# CS cell, yet the recorded numbers differ slightly between the two runs --
# evaluate_node is presumably unseeded; confirm.
embedding_path = "./embeddings/embeddings_cs_node.pt"

cs_dataset = Dataset(root="data", dataset="cs")[0]
embeddings = torch.load(embedding_path)
evaluate_node(embeddings, cs_dataset, "cs")

Evaluate node classification results
** Val: 93.1833 (0.4807) | Test: 93.3073 (0.2143) **


In [8]:
# Physics: node classification from the saved embedding file.
embedding_path = "./embeddings/embeddings_physics_node.pt"

# Build the dataset, read the saved embeddings, then hand both to the
# node-classification evaluator together with the dataset name.
physics_dataset = Dataset(root="data", dataset="physics")[0]
embeddings = torch.load(embedding_path)
evaluate_node(embeddings, physics_dataset, "physics")

Evaluate node classification results
** Val: 95.5697 (0.3342) | Test: 95.6067 (0.0860) **


### Clustering
This is for reproducing **Table 3** in the paper

In [9]:
# WikiCS: clustering quality of the saved clustering embeddings.
embedding_path = "./embeddings/embeddings_wikics_clustering.pt"

# The evaluator prints the clustering NMI and the homogeneity score.
wikics_dataset = Dataset(root="data", dataset="wikics")[0]
embeddings = torch.load(embedding_path)
evaluate_clustering(embeddings, wikics_dataset)

Evaluate clustering results
** Clustering NMI: 0.4115 | homogeneity score: 0.4288 **


In [10]:
# Computers: clustering quality of the saved clustering embeddings.
embedding_path = "./embeddings/embeddings_computers_clustering.pt"

# The evaluator prints the clustering NMI and the homogeneity score.
computers_dataset = Dataset(root="data", dataset="computers")[0]
embeddings = torch.load(embedding_path)
evaluate_clustering(embeddings, computers_dataset)

Evaluate clustering results
** Clustering NMI: 0.5442 | homogeneity score: 0.5936 **


In [11]:
# Photo: clustering quality of the saved clustering embeddings.
embedding_path = "./embeddings/embeddings_photo_clustering.pt"

# The evaluator prints the clustering NMI and the homogeneity score.
photo_dataset = Dataset(root="data", dataset="photo")[0]
embeddings = torch.load(embedding_path)
evaluate_clustering(embeddings, photo_dataset)

Evaluate clustering results
** Clustering NMI: 0.6647 | homogeneity score: 0.6807 **


In [12]:
# CS: clustering quality of the saved clustering embeddings.
embedding_path = "./embeddings/embeddings_cs_clustering.pt"

# The evaluator prints the clustering NMI and the homogeneity score.
cs_dataset = Dataset(root="data", dataset="cs")[0]
embeddings = torch.load(embedding_path)
evaluate_clustering(embeddings, cs_dataset)

Evaluate clustering results
** Clustering NMI: 0.7778 | homogeneity score: 0.8088 **


In [13]:
# Physics: clustering quality of the saved clustering embeddings.
embedding_path = "./embeddings/embeddings_physics_clustering.pt"

# The evaluator prints the clustering NMI and the homogeneity score.
physics_dataset = Dataset(root="data", dataset="physics")[0]
embeddings = torch.load(embedding_path)
evaluate_clustering(embeddings, physics_dataset)

Evaluate clustering results
** Clustering NMI: 0.7288 | homogeneity score: 0.7353 **


### Similarity Search
This is for reproducing **Table 4** in the paper

In [14]:
# WikiCS: similarity search over the saved similarity embeddings.
embedding_path = "./embeddings/embeddings_wikics_similarity.pt"

# The search routine prints sim@5 and sim@10.
wikics_dataset = Dataset(root="data", dataset="wikics")[0]
embeddings = torch.load(embedding_path)
run_similarity_search(embeddings, wikics_dataset)

Evaluate similarity search results
** sim@5 : 0.781 | sim@10 : 0.7657 **


In [15]:
# Computers: similarity search over the saved similarity embeddings.
embedding_path = "./embeddings/embeddings_computers_similarity.pt"

# The search routine prints sim@5 and sim@10.
computers_dataset = Dataset(root="data", dataset="computers")[0]
embeddings = torch.load(embedding_path)
run_similarity_search(embeddings, computers_dataset)

Evaluate similarity search results
** sim@5 : 0.8944 | sim@10 : 0.8867 **


In [16]:
# Photo: similarity search over the saved similarity embeddings.
embedding_path = "./embeddings/embeddings_photo_similarity.pt"

# The search routine prints sim@5 and sim@10.
photo_dataset = Dataset(root="data", dataset="photo")[0]
embeddings = torch.load(embedding_path)
run_similarity_search(embeddings, photo_dataset)

Evaluate similarity search results
** sim@5 : 0.9226 | sim@10 : 0.9158 **


In [17]:
# CS: similarity search over the saved similarity embeddings.
embedding_path = "./embeddings/embeddings_cs_similarity.pt"

# The search routine prints sim@5 and sim@10.
cs_dataset = Dataset(root="data", dataset="cs")[0]
embeddings = torch.load(embedding_path)
run_similarity_search(embeddings, cs_dataset)

Evaluate similarity search results
** sim@5 : 0.918 | sim@10 : 0.9141 **


In [18]:
# Physics: similarity search over the saved similarity embeddings.
embedding_path = "./embeddings/embeddings_physics_similarity.pt"

# The search routine prints sim@5 and sim@10.
physics_dataset = Dataset(root="data", dataset="physics")[0]
embeddings = torch.load(embedding_path)
run_similarity_search(embeddings, physics_dataset)

Evaluate similarity search results
** sim@5 : 0.9529 | sim@10 : 0.9491 **
