In [1]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import SAGEConv, to_hetero
import tqdm
import torch.nn.functional as F
import torch_geometric.transforms as T

In [2]:
# Lets start by loading the data

data = torch.load("data/hetero_data_no_coauthor.pt", weights_only=False)
data

HeteroData(
  author={ node_id=[90941] },
  paper={
    node_id=[63854],
    x=[63854, 256],
  },
  (author, writes, paper)={ edge_index=[2, 320187] },
  (paper, rev_writes, author)={ edge_index=[2, 320187] }
)

In [3]:
# Do the Train, Val, Test Split
# training (80%), validation (10%), and testing edges (10%).
# Across the training edges, we use 70% of edges for message passing,
# and 30% of edges for supervision. (This is from a tutorial by PyG, we can change this later)
# We further want to generate fixed negative edges for evaluation with a ratio of 2:1. (Again a Hyperparameter we can tune later)
# Negative edges during training will be generated on-the-fly (How?, again this is from the tutorial, need to check later)
transform = T.RandomLinkSplit(
    num_val=0.1, # Validation set percentage
    num_test=0.1, # Test set percentage
    disjoint_train_ratio=0.3, # Percentage of training edges used for supervision, these will not be used for message passing
    neg_sampling_ratio=2.0, # Ratio of negative to positive edges for validation and testing, dont know how this is related to `add_negative_train_samples`, need to check later
    add_negative_train_samples=False, # AYYY NO idea, why this set to False, but somehow it works worse with True ???, Need it investigate later, Prolly because we do LinkNeighborLoader which samples neg edges for us?
    edge_types=("author", "writes", "paper"), # Any ways, these are the edge types we want to predict
    rev_edge_types=("paper", "rev_writes", "author"), # Reverse edge types, so we dont accidentally bleed information into validation/test set
)

train_data, val_data, test_data = transform(data)
train_data

HeteroData(
  author={ node_id=[90941] },
  paper={
    node_id=[63854],
    x=[63854, 256],
  },
  (author, writes, paper)={
    edge_index=[2, 179306],
    edge_label=[76845],
    edge_label_index=[2, 76845],
  },
  (paper, rev_writes, author)={ edge_index=[2, 179306] }
)

In [4]:
# In the first hop, we sample at most 20 neighbors.
# In the second hop, we sample at most 10 neighbors.
# In addition, during training, we want to sample negative edges on-the-fly with
# a ratio of 2:1.
# We can make use of the `loader.LinkNeighborLoader` from PyG:

# This loader is actually SAMPLING the full graph, by first sampling 64 random nodes then 32 neighbors of each node previously sampled node to create a sparse subgraph etc...
# We should be able to load the graph fully into memory, but how would one train that?
# We could probably use the previous random link split to do full batch training, but somehow we would not sample random negative edges then?
# Need to check different loaders which sample the full graph and then do negative sampling on-the-fly
edge_label_index = train_data["author", "writes", "paper"].edge_label_index
edge_label = train_data["author", "writes", "paper"].edge_label

train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[64, 32, 16],
    neg_sampling_ratio=2.0,
    edge_label_index=(("author", "writes", "paper"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

In [None]:

# Simple 3 hop GNN
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(
            hidden_channels,
            hidden_channels,
            aggr="mean",
            project=False,
        )
        self.conv2 = SAGEConv(
            hidden_channels,
            hidden_channels,
            aggr="mean",
            project=False,
        )
        self.conv3 = SAGEConv(
            hidden_channels,
            hidden_channels,
            aggr="mean",
            project=False,
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)
        return x


# Our final classifier applies the dot-product between source and destination
# node embeddings to derive edge-level predictions:
class Classifier(torch.nn.Module):
    def forward(
        self,
        x_user: torch.Tensor,
        x_movie: torch.Tensor,
        edge_label_index: torch.Tensor,
    ) -> torch.Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_user = x_user[edge_label_index[0]]
        edge_feat_movie = x_movie[edge_label_index[1]]
        return (edge_feat_user * edge_feat_movie).sum(dim=-1)


class Model(torch.nn.Module):
    def __init__(self, hidden_channels: int, data: HeteroData):
        super().__init__()

        self.hidden_channels = hidden_channels

        # Instantiate homogeneous GNN:
        self.gnn = GNN(hidden_channels)

        # Convert GNN model into a heterogeneous variant:
        self.gnn = to_hetero(self.gnn, metadata=data.metadata())
        
        # Instantiate link classifier:
        self.classifier = Classifier()

    def forward(self, data: HeteroData) -> torch.Tensor:

        # Set the initial user embeddings to all ones for all authors
        # This makes sure the graph can generalize to unseen authors during inference
        author_embedding = torch.ones(
            (data["author"].num_nodes, self.hidden_channels),
            device=data["paper"].x.device,
        )

        # Extract paper embeddings from the data object
        paper_embedding = data["paper"].x

        # Noew we can create the x_dict required for the GNN
        x_dict = {
            "author": author_embedding,
            "paper": paper_embedding,
        }

        # "x_dict" now holds feature matrices of all node types
        # "edge_index_dict" holds all edge indices, i.e. the connections between users and movies
        # The GNN will predict new embeddings for all node types, we can even check how the user embeddings change
        gnn_pred = self.gnn(x_dict, data.edge_index_dict)

        # Finally we can use the classifier to get the final link predictions
        # This can be done either with the dot product of the updated embeddings
        # or more involved with a linear projection head or smth similar
        cls_pred = self.classifier(
            gnn_pred["author"],
            gnn_pred["paper"],
            data["author", "writes", "paper"].edge_label_index,
        )

        return cls_pred

In [8]:
LR = 0.001
EPOCHS = 20
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model(hidden_channels=256, data=data)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

model = model.to(device)

model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):
        
        optimizer.zero_grad()
        sampled_data.to(device)
        
        y_pred = model(sampled_data)
        y_true = sampled_data["author", "writes", "paper"].edge_label
        
        loss = F.binary_cross_entropy_with_logits(y_pred, y_true)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item() * y_pred.numel()
        total_examples += y_pred.numel()
        
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")

  1%|▏         | 8/601 [00:00<00:07, 78.63it/s]

100%|██████████| 601/601 [00:07<00:00, 83.46it/s]


Epoch: 000, Loss: 0.5890


100%|██████████| 601/601 [00:07<00:00, 82.47it/s]


Epoch: 001, Loss: 0.5046


100%|██████████| 601/601 [00:07<00:00, 83.40it/s]


Epoch: 002, Loss: 0.4771


100%|██████████| 601/601 [00:07<00:00, 82.92it/s]


Epoch: 003, Loss: 0.4601


100%|██████████| 601/601 [00:07<00:00, 82.77it/s]


Epoch: 004, Loss: 0.4500


100%|██████████| 601/601 [00:07<00:00, 84.77it/s]


Epoch: 005, Loss: 0.4413


100%|██████████| 601/601 [00:06<00:00, 86.44it/s]


Epoch: 006, Loss: 0.4336


100%|██████████| 601/601 [00:07<00:00, 83.63it/s]


Epoch: 007, Loss: 0.4259


100%|██████████| 601/601 [00:07<00:00, 84.14it/s]


Epoch: 008, Loss: 0.4246


100%|██████████| 601/601 [00:07<00:00, 84.89it/s]


Epoch: 009, Loss: 0.4185


100%|██████████| 601/601 [00:07<00:00, 84.60it/s]


Epoch: 010, Loss: 0.4135


100%|██████████| 601/601 [00:07<00:00, 83.07it/s]


Epoch: 011, Loss: 0.4095


100%|██████████| 601/601 [00:07<00:00, 83.69it/s]


Epoch: 012, Loss: 0.4069


100%|██████████| 601/601 [00:07<00:00, 83.91it/s]


Epoch: 013, Loss: 0.4038


100%|██████████| 601/601 [00:07<00:00, 83.57it/s]


Epoch: 014, Loss: 0.3983


100%|██████████| 601/601 [00:07<00:00, 84.18it/s]


Epoch: 015, Loss: 0.3987


100%|██████████| 601/601 [00:07<00:00, 84.83it/s]


Epoch: 016, Loss: 0.3966


100%|██████████| 601/601 [00:07<00:00, 83.88it/s]


Epoch: 017, Loss: 0.3903


100%|██████████| 601/601 [00:07<00:00, 84.10it/s]


Epoch: 018, Loss: 0.3902


100%|██████████| 601/601 [00:07<00:00, 83.86it/s]

Epoch: 019, Loss: 0.3913





In [9]:
def evaluate_model(model, data):
    model.eval()
    with torch.no_grad():
        y_pred = model(data)

    y_pred = y_pred.cpu().numpy()
    y_true = data["author", "writes", "paper"].edge_label.cpu().numpy()

    # binary thresholding at 0.5
    y_pred = (y_pred >= 0.5)
            
    FP = ((y_true == 0) & (y_pred == 1)).sum().item()
    TP = ((y_true == 1) & (y_pred == 1)).sum().item()
    FN = ((y_true == 1) & (y_pred == 0)).sum().item()
    TN = ((y_true == 0) & (y_pred == 0)).sum().item()

    precision = TP / (TP + FP + 1e-8)
    recall = TP / (TP + FN + 1e-8)
    f1_score = 2 * (precision * recall) / (precision + recall + 1e-8)
    accuracy = (TP + TN) / (TP + TN + FP + FN + 1e-8)

    return precision, recall, f1_score, accuracy


test_data.to(device)
precision, recall, f1_score, accuracy = evaluate_model(model, test_data)
print("Evaluating on Test set...")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print("--------------------------------------------------")
val_data.to(device)
precision, recall, f1_score, accuracy = evaluate_model(model, val_data)
print("Evaluating on validation set...")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1_score:.4f}")
print(f"Accuracy: {accuracy:.4f}")

Evaluating on Test set...
Precision: 0.8437
Recall: 0.6208
F1 Score: 0.7153
Accuracy: 0.8353
--------------------------------------------------
Evaluating on validation set...
Precision: 0.8489
Recall: 0.5969
F1 Score: 0.7010
Accuracy: 0.8302


In [10]:
from sklearn.metrics import roc_auc_score

with torch.no_grad():
    y_pred = model(val_data)

y_pred = y_pred.cpu().numpy()
y_true = val_data["author", "writes", "paper"].edge_label.cpu().numpy()

auc = roc_auc_score(y_true, y_pred)
print(f"Validation AUC: {auc:.4f}")

Validation AUC: 0.9098
