# Implementing a Recommender System using LightGCN

In [1]:
import matplotlib.pyplot as plt
from modeling.sampling import sample_minibatch_V2
from modeling.metrics import calculate_metrics
from modeling.losses import BPR_loss
import torch_geometric.transforms as T 

import time

import torch
from torch import optim

In [2]:
# Load the saved graph object from disk.
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects,
# so this path must only ever point at a trusted file.
data = torch.load("data/hetero_data_no_coauthor.pt", weights_only=False)
assert data.is_undirected(), "Data should be undirected"

# Give every author node the same constant feature row: 256 ones.
data["author"].x = torch.ones(data["author"].num_nodes, 256)

print(data)

HeteroData(
  author={
    node_id=[90941],
    x=[90941, 256],
  },
  paper={
    node_id=[63854],
    x=[63854, 256],
  },
  (author, writes, paper)={ edge_index=[2, 320187] },
  (paper, rev_writes, author)={ edge_index=[2, 320187] }
)


In [3]:
# Link-level split of the (author, writes, paper) relation into
# train / validation / test views of the graph.
link_split = T.RandomLinkSplit(
    # which relation is split, and its mirrored reverse relation
    edge_types=[("author", "writes", "paper")],
    rev_edge_types=[("paper", "rev_writes", "author")],
    is_undirected=True,
    # 10% of the edges each for validation and test
    num_val=0.1,
    num_test=0.1,
    # fraction of training edges reserved for supervision only
    # (see the RandomLinkSplit documentation for exact semantics)
    disjoint_train_ratio=0.3,
    # no negative edges are generated by the transform itself
    add_negative_train_samples=False,
    neg_sampling_ratio=0.0,
)
train_data, val_data, test_data = link_split(data)

val_data

HeteroData(
  author={
    node_id=[90941],
    x=[90941, 256],
  },
  paper={
    node_id=[63854],
    x=[63854, 256],
  },
  (author, writes, paper)={
    edge_index=[2, 256151],
    edge_label=[32018],
    edge_label_index=[2, 32018],
  },
  (paper, rev_writes, author)={ edge_index=[2, 256151] }
)

In [4]:
from torch_geometric.nn import to_hetero
from torch_geometric.data import HeteroData
import torch.nn.functional as F
from torch_geometric.nn.conv import SAGEConv


class GNN(torch.nn.Module):
    """Homogeneous stack of mean-aggregating ``SAGEConv`` layers.

    The network holds ``num_layers - 1`` hidden convolutions, each followed
    by a ReLU, plus one output convolution with no activation after it.
    Every layer maps ``embedding_dim`` features to ``embedding_dim`` features.

    Args:
        embedding_dim: Input and output feature width of every layer.
        num_layers: Total number of convolutions, counting the output layer.
    """

    def __init__(self, embedding_dim: int, num_layers: int):
        super().__init__()

        def make_conv():
            # All layers share one configuration.
            return SAGEConv(
                embedding_dim,
                embedding_dim,
                aggr="mean",
                project=False,
            )

        # Hidden layers are created first and the output layer last, so the
        # parameter-initialisation order stays fixed.
        hidden_layers = torch.nn.ModuleList()
        for _ in range(num_layers - 1):
            hidden_layers.append(make_conv())

        self.convs = hidden_layers
        self.out_conv = make_conv()

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through all hidden layers and then the output layer.

        Args:
            x: Node feature tensor handed to the first convolution.
            edge_index: Edge index tensor passed unchanged to every layer.

        Returns:
            The output of the final convolution (no activation applied).
        """
        hidden = x
        for layer in self.convs:
            hidden = F.relu(layer(hidden, edge_index))

        return self.out_conv(hidden, edge_index)


class Model(torch.nn.Module):
    """Heterogeneous wrapper around :class:`GNN` built with ``to_hetero``.

    Args:
        embedding_dim: Feature width used by every layer of the wrapped GNN.
        num_layers: Number of convolution layers in the wrapped GNN.
        metadata: Optional graph metadata in the form returned by
            ``HeteroData.metadata()``. When omitted, the metadata of the
            module-level ``data`` object is used, so existing calls such as
            ``Model(embedding_dim=256, num_layers=5)`` behave as before.
    """

    def __init__(
        self,
        embedding_dim: int = 256,
        num_layers: int = 5,
        metadata=None,
    ):
        super().__init__()

        self.embedding_dim = embedding_dim

        if metadata is None:
            # Backward-compatible fallback: read the node/edge types from the
            # notebook-level graph, as the original implementation did.
            metadata = data.metadata()

        # Convert the homogeneous GNN into a heterogeneous variant; messages
        # arriving over different relations are combined with a mean.
        self.gnn = to_hetero(
            GNN(embedding_dim, num_layers),
            metadata=metadata,
            aggr="mean",
        )

    def forward(self, data: HeteroData) -> "dict[str, torch.Tensor]":
        """Compute embeddings for the node types of ``data``.

        Args:
            data: Graph providing ``x`` for the ``"author"`` and ``"paper"``
                node types as well as ``edge_index_dict``.

        Returns:
            The dictionary produced by the heterogeneous GNN; callers index it
            by node type (``"author"``, ``"paper"``).
        """
        x_dict = {
            "author": data["author"].x,
            "paper": data["paper"].x,
        }

        return self.gnn(x_dict, data.edge_index_dict)

In [None]:
# --- constants -------------------------------------------------------------
# optimisation
ITERATIONS = 10000
LR = 1e-3

# evaluation cadence and ranking cut-off
ITERS_PER_EVAL = 1000
K = 20

# mini-batch sampling
BATCH_SIZE = 4096
NEG_SAMPLE_RATIO = 10

# relation used for sampling and evaluation
TEST_EDGE_TYPE = ("author", "writes", "paper")

# --- setup -----------------------------------------------------------------
model = Model(
    embedding_dim=256,
    num_layers=5,
)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device {device}.")

# Move all three graph splits and then the model onto the chosen device.
train_data, val_data, test_data = (
    split.to(device) for split in (train_data, val_data, test_data)
)

model = model.to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=LR)

# --- bookkeeping filled in by the training loop ----------------------------
# one loss value and one wall-clock duration per stage, per iteration
train_losses, batching_times, forward_times, loss_times, backward_times = (
    [],
    [],
    [],
    [],
    [],
)

Using device cuda.


In [24]:
# Main optimisation loop: sample a mini-batch, embed all nodes, score the
# sampled (author, paper) pairs, take one BPR step, and evaluate periodically.
#
# NOTE: the per-stage timings are plain wall-clock differences. CUDA kernels
# are launched asynchronously, so on a GPU they are only approximate.
for iteration in range(ITERATIONS):  # not `iter`: that would shadow the builtin
    # Progress line, overwritten in place via "\r".
    recent_losses = train_losses[-100:]
    print(
        f"Iteration {iteration + 1}/{ITERATIONS} | Average Loss over last 100 iters: {sum(recent_losses) / len(recent_losses) if recent_losses else 0:.5f}",
        end="\r",
    )

    # mini batching
    start_time = time.time()
    sampled_author_ids, sampled_pos_paper_ids, sampled_neg_paper_ids = (
        sample_minibatch_V2(
            data=train_data,
            edge_type=TEST_EDGE_TYPE,
            batch_size=BATCH_SIZE,
            neg_sample_ratio=NEG_SAMPLE_RATIO,
        )
    )
    batching_times.append(time.time() - start_time)

    # forward propagation over the whole training graph
    start_time = time.time()
    # Call the module itself rather than .forward() so registered hooks run.
    embeddings = model(train_data)
    author_embeddings = embeddings["author"]
    paper_embeddings = embeddings["paper"]
    forward_times.append(time.time() - start_time)

    # Dot-product scores; the author rows are gathered once and reused for
    # both the positive and the negative papers.
    batch_author_embeddings = author_embeddings[sampled_author_ids]
    pos_scores = torch.sum(
        batch_author_embeddings * paper_embeddings[sampled_pos_paper_ids],
        dim=1,
    )
    neg_scores = torch.sum(
        batch_author_embeddings * paper_embeddings[sampled_neg_paper_ids],
        dim=1,
    )

    # loss computation
    start_time = time.time()
    train_loss = BPR_loss(pos_scores, neg_scores)
    loss_times.append(time.time() - start_time)

    # backward propagation and parameter update
    start_time = time.time()
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()
    backward_times.append(time.time() - start_time)

    train_losses.append(train_loss.item())

    if (iteration + 1) % ITERS_PER_EVAL == 0:
        model.eval()

        with torch.no_grad():
            # Embeddings are recomputed on the training graph only; the
            # supervision edges are not added to the graph for evaluation.
            embeddings = model(train_data)
            author_embeddings = embeddings["author"]
            paper_embeddings = embeddings["paper"]

        val_edge_label_index = val_data[TEST_EDGE_TYPE].edge_label_index
        train_edge_index = train_data[TEST_EDGE_TYPE].edge_index
        train_edge_label_index = train_data[TEST_EDGE_TYPE].edge_label_index

        # The list argument is presumably the set of known edges to exclude
        # from the ranking; confirm against modeling.metrics.calculate_metrics.
        val_recall, val_precision = calculate_metrics(
            author_embeddings,
            paper_embeddings,
            val_edge_label_index,
            [train_edge_index, train_edge_label_index],
            k=K,
        )

        train_recall, train_precision = calculate_metrics(
            author_embeddings,
            paper_embeddings,
            train_edge_label_index,
            [train_edge_index],
            k=K,
        )

        print(
            f"[Iteration {iteration + 1}/{ITERATIONS}] train_loss: {train_loss.item():.05f}, val_recall@{K}: {val_recall:.05f}, val_precision@{K}: {val_precision:.05f}, train_recall@{K}: {train_recall:.05f}, train_precision@{K}: {train_precision:.05f}"
        )
        # Back to training mode for the next iterations.
        model.train()

# Best layer = 8: [Iteration 10000/10000] train_loss: 0.08947, val_recall@20: 0.02844, val_precision@20: 0.00221, train_recall@20: 0.03590, train_precision@20: 0.00380
#               [Iteration 100000/100000] train_loss: 0.03324, val_recall@20: 0.07191, val_precision@20: 0.00573, train_recall@20: 0.26362, train_precision@20: 0.02820

Iteration 4/10000 | Average Loss over last 100 iters: 0.66681

[Iteration 1000/10000] train_loss: 0.61854, val_recall@20: 0.00043, val_precision@20: 0.00003, train_recall@20: 0.00037, train_precision@20: 0.00004
[Iteration 2000/10000] train_loss: 0.62204, val_recall@20: 0.00051, val_precision@20: 0.00003, train_recall@20: 0.00038, train_precision@20: 0.00004
Iteration 2090/10000 | Average Loss over last 100 iters: 0.61767

KeyboardInterrupt: 

In [7]:
# Recompute the embeddings of all nodes on the training graph without
# tracking gradients. Only train_data is passed in, so the supervision
# edges are not part of the graph used here.
with torch.no_grad():
    embeddings = model.forward(train_data)

author_embeddings, paper_embeddings = embeddings["author"], embeddings["paper"]

author_embeddings

tensor([[ 0.7311,  0.3519, -0.0115,  ..., -0.8251, -0.0769, -0.5254],
        [ 0.7340,  0.3517, -0.0138,  ..., -0.8301, -0.0754, -0.5290],
        [ 0.2681, -0.0973, -0.1264,  ..., -0.6281,  0.0778, -0.2580],
        ...,
        [ 0.2681, -0.0973, -0.1264,  ..., -0.6281,  0.0778, -0.2580],
        [ 0.7304,  0.3509, -0.0130,  ..., -0.8254, -0.0760, -0.5265],
        [ 0.2681, -0.0973, -0.1264,  ..., -0.6281,  0.0778, -0.2580]],
       device='cuda:0')

In [21]:
# One manual pass with a tiny batch (5 samples) to inspect intermediate values.
# Note that the timings are appended to the same lists the training loop uses.
start_time = time.time()
sampled_author_ids, sampled_pos_paper_ids, sampled_neg_paper_ids = sample_minibatch_V2(
    data=train_data,
    edge_type=TEST_EDGE_TYPE,
    batch_size=5,
    neg_sample_ratio=NEG_SAMPLE_RATIO,
)
batching_times.append(time.time() - start_time)

# forward propagation (gradients are tracked here, unlike in evaluation)
start_time = time.time()
embeddings = model.forward(train_data)
forward_times.append(time.time() - start_time)

# Dot product between each sampled author and its positive / negative paper.
pos_scores = (
    embeddings["author"][sampled_author_ids]
    * embeddings["paper"][sampled_pos_paper_ids]
).sum(dim=1)
neg_scores = (
    embeddings["author"][sampled_author_ids]
    * embeddings["paper"][sampled_neg_paper_ids]
).sum(dim=1)

embeddings["paper"]


tensor([[-0.7302, -0.6382, -0.1720,  ..., -0.0179,  0.7548, -0.2720],
        [-0.7279, -0.6374, -0.1700,  ..., -0.0210,  0.7586, -0.2762],
        [-0.7282, -0.6360, -0.1692,  ..., -0.0187,  0.7530, -0.2726],
        ...,
        [-0.7365, -0.6418, -0.1756,  ..., -0.0199,  0.7560, -0.2647],
        [ 0.0610, -0.0696, -0.1163,  ..., -0.8232,  0.4937, -0.4334],
        [-0.7401, -0.6417, -0.1762,  ..., -0.0149,  0.7484, -0.2589]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [None]:
# save the train_loss curve
plt.plot(train_losses, label="train")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.title("training and validation loss curves")
plt.legend()
plt.grid()
plt.yscale("log")
plt.xscale("log")
plt.savefig("training_loss_curve.png")
plt.show()

plt.plot(batching_times[100:], label="batching")
plt.plot(forward_times[100:], label="forwarding")
plt.plot(loss_times[100:], label="loss computation")
plt.plot(backward_times[100:], label="backwarding")
plt.xlabel("iteration")
plt.ylabel("time (s)")
plt.title("time per operation")
plt.legend()
plt.grid()
plt.show()