Much like the LightGCN model that we ran that contained only user-song edges, for this particular model, we will be utilising the LightGCN convolutional layer (LGConv) to run on a graph comprising both user-song and song-song edges.

Importing relevant libraries:

In [76]:
# General libraries
import json
from pathlib import Path as Data_Path
import os
from os.path import isfile, join
import pickle
import random

import numpy as np
import networkx as nx
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

# Import relevant ML libraries
from typing import Optional, Union

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Embedding, ModuleList, Linear
import torch.nn.functional as F

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch.nn.modules.loss import _Loss

from torch_geometric.nn.conv import LGConv
from torch_geometric.typing import Adj, OptTensor, SparseTensor

print(f"Torch version: {torch.__version__}; Torch-cuda version: {torch.version.cuda}; Torch Geometric version: {torch_geometric.__version__}.")

Torch version: 2.2.1+cpu; Torch-cuda version: None; Torch Geometric version: 2.5.2.


Reading data:

`user_song_data` is the dataset that contains information on users and their song listening histories

`song_song_data` is the dataset that contains information on songs and their top 5 most similar songs (similarity was calculated via cosine similarity)

In [2]:
user_song_data = pd.read_csv('../Data/user_songs_filtered.csv')
song_song_data = pd.read_csv('../Data/songs_with_similarities_final.csv')

In [3]:
# set the seed for reproducibility
seed = 224
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [26]:
user_song_data = pd.read_csv('../Data/user_songs_filtered.csv')
user_song_data = user_song_data.sort_values(by = 'Username')
user_song_data = user_song_data.reset_index(drop=True)

Since there are 2 types of nodes, we create 2 different classes, `Track` and `User` in order to represent the 2 different types of nodes.

In [None]:
class Track:
  def __init__(self, track_name, artist_name, similar_tracks):
    self.name = track_name
    self.artist_name = artist_name
    self.similar_tracks = similar_tracks


  def __str__(self):
    return f"Track called {self.name} from artist {self.artist_name}."

  def __repr__(self):
    return f"Track {self.name}"

user_song_data[['track_name', 'artist_name']].drop_duplicates().to_numpy()

tracks = []
unique_tracks = user_song_data[['track_name', 'artist_name']].drop_duplicates().to_numpy()

for track, artist_name in unique_tracks:
    row_similar_songs = song_song_data[(song_song_data['track_name'] == track) & (song_song_data['artist_name'] == artist_name)].iloc[0]
    similar_tracks = [
        (row_similar_songs['Track_Name_1'], row_similar_songs['Similarity_1'], row_similar_songs['Artist_Name_1']),
        (row_similar_songs['Track_Name_2'], row_similar_songs['Similarity_2'], row_similar_songs['Artist_Name_2']),
        (row_similar_songs['Track_Name_3'], row_similar_songs['Similarity_3'], row_similar_songs['Artist_Name_3']),
        (row_similar_songs['Track_Name_4'], row_similar_songs['Similarity_4'], row_similar_songs['Artist_Name_4']),
        (row_similar_songs['Track_Name_5'], row_similar_songs['Similarity_5'], row_similar_songs['Artist_Name_5'])
    ]

    for similar_track in similar_tracks:
      tracks += [Track(similar_track[0], similar_track[2], [])]


    tracks += [Track(track, artist_name, similar_tracks)]


class User:
  def __init__(self, user_data, top_songs):

    self.name = user_data['Username']
    self.country = user_data['country']
    self.track_count = int(user_data['track_count'])
    self.total_playcount = 0
    self.top_songs = {}
    self.artists = []
    for index, row in top_songs.iterrows():
      rank = row['rank']
      track_name = row['track_name']
      artist_name = row['artist_name']
      track = [obj for obj in tracks if (obj.name == track_name)&(obj.artist_name == artist_name)][0]
      playcount = row['playcount']
      self.top_songs[rank] = (track, playcount)
      self.total_playcount += playcount
      self.artists += [artist_name]

    self.top_songs =  {k: self.top_songs[k] for k in sorted(self.top_songs)}
  def __str__(self):
    return f"User {self.name} with {len(self.top_songs)} top tracks loaded, total listen count is {self.total_playcount}."

  def __repr__(self):
    return f"User {self.name}"
  def __lt__(self, other):
    return (self.name < other.name) and (self.total_playcount < other.total_playcount)

  def __gt__(self, other):
    return (self.name > other.name) and (self.total_playcount > other.total_playcount)

unique_users = user_song_data.Username.unique()
users = []
for user in unique_users:
    user_data = user_song_data.loc[user_song_data['Username'] == user].iloc[0]
    user_data = user_data[['Username', 'country', 'track_count']].to_dict()
    top_songs = user_song_data.loc[user_song_data['Username'] == user]
    top_songs = top_songs[['rank', 'track_name', 'artist_name', 'playcount']]
    users +=[User(user_data, top_songs)]



Subsequently, using `user_song_data` and `song_song_data`, we create the weighted edges to connect each of the nodes.

In [46]:
# adding nodes
G = nx.Graph()
G.add_nodes_from([
    (p, {'name':p, "node_type" : "user"}) for p in users
])
for track in tracks:
    G.add_node(track, node_type="track")
    if (len(track.similar_tracks) == 5):
        for similar_track in track.similar_tracks:
            G.add_node(similar_track, node_type="track")


n_nodes, n_edges = G.number_of_nodes(), G.number_of_edges()

# by sorting them we get an ordering playlist1, ..., playlistN, track1, ..., trackN
sorted_nodes = list(G.nodes())

# create dictionaries to index to 0 to n_nodes, will be necessary for when we are using tensors
node2id = dict(zip(sorted_nodes, np.arange(n_nodes)))
id2node = dict(zip(np.arange(n_nodes), sorted_nodes))


# Add edges

for user in users:
  top_songs = user.top_songs
  user_total_listening = user.total_playcount
  for song, count in top_songs.values():
      G.add_edge(node2id[user], node2id[song], weight= (count/user_total_listening))

for track in tracks:
    for similar_track in track.similar_tracks:
        G.add_edge(node2id[track], node2id[similar_track], weight = similar_track[1])

G = nx.relabel_nodes(G, node2id)

# also keep track of how many users, tracks we have
users_idx = [i for i, v in enumerate(node2id.keys()) if isinstance(v, User)] 
tracks_idx = [i for i, v in enumerate(node2id.keys()) if isinstance(v, Track)]

n_users = np.max(users_idx) + 1
n_tracks = n_nodes - n_users

# Relabel nodes to have consecutive integer indices starting from 0
node_mapping = {node: idx for idx, node in enumerate(G.nodes())}

# Relabel the nodes in the graph using the mapping
G = nx.relabel_nodes(G, node_mapping)

# Get the edge indices from the relabeled graph
edge_idx = torch.tensor(list(G.edges)).t()


# Get the edge weights from the NetworkX graph
edge_weights = [G[u][v]['weight'] for u, v in G.edges()]
edge_weights = torch.tensor(edge_weights, dtype=torch.float)

# Create the PyTorch Geometric Data object
num_nodes = len(G.nodes())
graph_data = Data(edge_index=edge_idx, edge_weight=edge_weights, num_nodes=num_nodes)

# Get the edge indices from the relabeled graph
edge_idx = torch.tensor(list(G.edges)).t()

# Get the edge weights from the NetworkX graph
edge_weights = [G[u][v]['weight'] for u, v in G.edges()]
edge_weights = torch.tensor(edge_weights, dtype=torch.float)

# Create the PyTorch Geometric Data object
num_nodes = len(G.nodes())
graph_data = Data(edge_index=edge_idx, edge_weight=edge_weights, num_nodes=num_nodes)


The dataset is then split using `RandomLinkSplit` into a ratio of 0.7 train, 0.15 test, 0.15 validation.

In [77]:
# convert to train/val/test splits
transform = RandomLinkSplit(
    is_undirected=True,
    add_negative_train_samples=False,
    neg_sampling_ratio=0,
    num_val=0.15, num_test=0.15
)

train_split, val_split, test_split = transform(graph_data)

for split_data in [train_split, val_split, test_split]:
    edge_weights_split = []
    for u, v in split_data.edge_index.T:
        edge_weights_split.append(G[node2id[id2node[u.item()]]][node2id[id2node[v.item()]]]['weight'])
    split_data.edge_weight = torch.tensor(edge_weights_split, dtype=torch.float)


# Edge index: message passing edges
train_split.edge_index = train_split.edge_index.type(torch.int64)
val_split.edge_index = val_split.edge_index.type(torch.int64)
test_split.edge_index = test_split.edge_index.type(torch.int64)
# Edge label index: supervision edges
train_split.edge_label_index = train_split.edge_label_index.type(torch.int64)
val_split.edge_label_index = val_split.edge_label_index.type(torch.int64)
test_split.edge_label_index = test_split.edge_label_index.type(torch.int64)

print(f"Train set has {train_split.edge_label_index.shape[1]} positives supervision edges")
print(f"Validation set has {val_split.edge_label_index.shape[1]} positive supervision edges")
print(f"Test set has {test_split.edge_label_index.shape[1]} positive supervision edges")

print(f"Train set has {train_split.edge_index.shape[1]} message passing edges")
print(f"Validation set has {val_split.edge_index.shape[1]} message passing edges")
print(f"Test set has {test_split.edge_index.shape[1]} message passing edges")

Train set has 471498 positives supervision edges
Validation set has 101034 positive supervision edges
Test set has 101034 positive supervision edges
Train set has 942996 message passing edges
Validation set has 942996 message passing edges
Test set has 1145064 message passing edges


We then create the GCN model class.

In [78]:
class GCN(torch.nn.Module):

    def __init__(
        self,
        num_users: int,
        num_items: int,
        embedding_dim: int,
        num_layers: int,
        alpha: Optional[Union[float, Tensor]] = None,
        alpha_learnable = False,
        name = None,
        **kwargs,
    ):
        super().__init__()
        alpha_string = "alpha" if alpha_learnable else ""
        self.name = f"SimpleConv_{num_layers}_e{embedding_dim}_users{num_users}_items{num_items}_{alpha_string}"
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        if alpha_learnable == True:
          alpha_vals = torch.rand(num_layers+1)
          alpha = nn.Parameter(alpha_vals/torch.sum(alpha_vals))
          print(f"Alpha learnable, initialized to: {alpha.softmax(dim=-1)}")
        else:
          if alpha is None:
              alpha = 1. / (num_layers + 1)

          if isinstance(alpha, Tensor):
              assert alpha.size(0) == num_layers + 1
          else:
              alpha = torch.tensor([alpha] * (num_layers + 1))

        self.register_buffer('alpha', alpha)

        self.user_embedding = Embedding(num_users, embedding_dim)
        self.item_embedding = Embedding(num_items, embedding_dim)

        # initialize convolutional layers
        self.convs = ModuleList([LGConv(aggr='mean') for _ in range(num_layers)])


        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.user_embedding.weight)
        torch.nn.init.xavier_uniform_(self.item_embedding.weight)
        for conv in self.convs:
            conv.reset_parameters()

    def get_embedding(self, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
        user_emb = self.user_embedding.weight
        item_emb = self.item_embedding.weight

        weights = self.alpha.softmax(dim=-1)
        user_out = user_emb * weights[0]
        item_out = item_emb * weights[0]

        for i in range(self.num_layers):
            x = torch.cat([user_emb, item_emb], dim=0)
            out = self.convs[i](x, edge_index, edge_weight=edge_weight)
            user_emb, item_emb = out.split([user_emb.size(0), item_emb.size(0)], dim=0)
            user_out = user_out + user_emb * weights[i + 1]
            item_out = item_out + item_emb * weights[i + 1]

        return user_out, item_out

    def forward(self, edge_index: Adj, edge_weight: OptTensor = None, edge_label_index: OptTensor = None) -> Tensor:
        if edge_label_index is None:
            raise ValueError("Edge label index must be provided for link prediction.")

        user_emb, item_emb = self.get_embedding(edge_index, edge_weight)
        return self.predict_link_embedding(user_emb, item_emb, edge_label_index)

    def predict_link_embedding(self, user_emb: Tensor, item_emb: Tensor, edge_label_index: Tensor) -> Tensor:
        user_indices, item_indices = edge_label_index
        combined_emb = torch.cat([user_emb, item_emb], dim=0)
        user_embed_src = combined_emb[user_indices]
        item_embed_dst = combined_emb[item_indices]
        return (user_embed_src * item_embed_dst).sum(dim=-1)
    
    def recommend(self, edge_index: Adj, edge_weight: OptTensor = None, src_index: OptTensor = None,
                  dst_index: OptTensor = None, k: int = 1, sorted: bool = True) -> Tensor:
        out_src = out_dst = self.get_embedding(edge_index, edge_weight)

        if src_index is not None:
            out_src = out_src[src_index]

        if dst_index is not None:
            out_dst = out_dst[dst_index]

        pred = out_src @ out_dst.t()
        top_index = pred.topk(k, dim=-1, sorted=sorted).indices

        if dst_index is not None:  # Map local top-indices to original indices.
            top_index = dst_index[top_index.view(-1)].view(*top_index.size())

        return top_index


    def link_pred_loss(self, pred: Tensor, edge_label: Tensor,
                       **kwargs) -> Tensor:
        loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)
        return loss_fn(pred, edge_label.to(pred.dtype))


    def recommendation_loss(self, pos_edge_rank: Tensor, neg_edge_rank: Tensor,
                            node_id: Optional[Tensor] = None, lambda_reg: float = 1e-4, **kwargs) -> Tensor:
        r"""Computes the model loss for a ranking objective via the Bayesian
        Personalized Ranking (BPR) loss."""
        loss_fn = BPRLoss(lambda_reg, **kwargs)
        if node_id is None:
            emb = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)
        else:
            emb = torch.cat([self.user_embedding.weight, self.item_embedding.weight], dim=0)[node_id]
        return loss_fn(pos_edge_rank, neg_edge_rank, emb)

    def bpr_loss(self, pos_scores, neg_scores):
        epsilon = 1e-12  # To prevent log(0) which results in -inf
        diff = pos_scores - neg_scores
        stable_diff = torch.clamp(diff, min=-10, max=10)  # Clamping to avoid overflow
        return -torch.log(torch.sigmoid(stable_diff) + epsilon).mean()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_nodes}, '
                f'{self.embedding_dim}, num_layers={self.num_layers})')



class BPRLoss(_Loss):
    r"""The Bayesian Personalized Ranking (BPR) loss.

    The BPR loss is a pairwise loss that encourages the prediction of an
    observed entry to be higher than its unobserved counterparts
    (see `here <https://arxiv.org/abs/2002.02126>`__).

    .. math::
        L_{\text{BPR}} = - \sum_{u=1}^{M} \sum_{i \in \mathcal{N}_u}
        \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj})
        + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2

    where :math:`lambda` controls the :math:`L_2` regularization strength.
    We compute the mean BPR loss for simplicity.

    Args:
        lambda_reg (float, optional): The :math:`L_2` regularization strength
            (default: 0).
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch.nn.modules.loss._Loss` class.
    """
    __constants__ = ['lambda_reg']
    lambda_reg: float

    def __init__(self, lambda_reg: float = 0, **kwargs):
        super().__init__(None, None, "sum", **kwargs)
        self.lambda_reg = lambda_reg

    def forward(self, positives: Tensor, negatives: Tensor,
                parameters: Tensor = None) -> Tensor:
        r"""Compute the mean Bayesian Personalized Ranking (BPR) loss.

        .. note::

            The i-th entry in the :obj:`positives` vector and i-th entry
            in the :obj:`negatives` entry should correspond to the same
            entity (*.e.g*, user), as the BPR is a personalized ranking loss.

        Args:
            positives (Tensor): The vector of positive-pair rankings.
            negatives (Tensor): The vector of negative-pair rankings.
            parameters (Tensor, optional): The tensor of parameters which
                should be used for :math:`L_2` regularization
                (default: :obj:`None`).
        """
        diff = torch.clamp(positives - negatives, min=-10, max=10)
        log_prob = F.logsigmoid(diff).sum()
        n_pairs = positives.size(0)
        regularization = 0

        if self.lambda_reg != 0:
            regularization = self.lambda_reg * parameters.norm(p=2).pow(2)

        return (-log_prob + regularization) / n_pairs

Our main specifications will use a Bayesian Personalized Ranking, which is calculated as

\begin{equation*}
    \text{BPR Loss}(i) = \frac{1}{|\mathcal{E}(i)|} \underset{{(i, j_{+}) \in \mathcal{E}(i)}}{\sum} \log \sigma \left( \text{score}(i, j_+) - \text{score}(i, j_-) \right)
\end{equation*}

for a pair of positive edge $(i, j_{+})$ and negative edge $(i, j_{-})$.

Since our model focuses on link prediction between a pair of user and track node, a negative edge means that there is no link between such two nodes.

Important to any link prediction task is negative sampling. In the graph, we observe positive edges, which allows us to capture which nodes should be most similar to one another. Adding negative edges allows the model to explicitly capture that nodes that don't share an edge should have different embeddings. Without negative edges,  a valid loss minimization strategy would be to simply assign all nodes the same embedding, which is obviously not meaningful or desirable.

In [79]:
def sample_negative_edges_nocheck(data, num_users, num_tracks, device = None):
  # note computationally inefficient to check that these are indeed negative edges
    users = data.edge_label_index[0, :]
    tracks = torch.randint(num_users, num_users + num_tracks - 1, size = data.edge_label_index[1, :].size())
    if users.get_device() != -1: # on gpu
        tracks = tracks.to(device)

    neg_edge_index = torch.stack((users, tracks), dim = 0)
    neg_edge_label = torch.zeros(neg_edge_index.shape[1])

    if neg_edge_index.get_device() != -1: # on gpu
        neg_edge_label = neg_edge_label.to(device)
    
    return neg_edge_index, neg_edge_label

def sample_negative_edges(data, num_users, num_tracks, device=None):
  
    # positive_users, positive_tracks = data.edge_label_index
    positive_edges = data.edge_label_index[:, data.edge_label_index[0] < n_users]  # Filter edges that start with a user node
    positive_users = positive_edges[0]
    positive_tracks = positive_edges[1] - n_users  # Adjust track indices to start from 0

    mask = torch.zeros(num_users, num_tracks, device=device, dtype=torch.bool)

    mask[positive_users, positive_tracks] = True

    # Flatten the mask tensor and get the indices of the negative edges
    flat_mask = mask.flatten()
    negative_indices = torch.where(~flat_mask)[0]

    # Sample negative edges from the negative_indices tensor
    sampled_negative_indices = negative_indices[
        torch.randint(0, negative_indices.size(0), size=(positive_users.size(0),), device=device)
    ]

    # Convert the indices back to playlists and tracks tensors
    users = torch.floor_divide(sampled_negative_indices, num_tracks)
    tracks = torch.remainder(sampled_negative_indices, num_tracks)
    tracks = tracks + num_users

    neg_edge_index = torch.stack((users, tracks), dim=0)
    neg_edge_label = torch.zeros(neg_edge_index.shape[1], device=device)
    return neg_edge_index, neg_edge_label

def sample_hard_negative_edges(data, model, num_users, num_items, device=None, batch_size=500, frac_sample=1):
    with torch.no_grad():
        user_embeddings, item_embeddings = model.get_embedding(data.edge_index, data.edge_weight)
        user_embeddings = user_embeddings.to(device)
        item_embeddings = item_embeddings.to(device)

    positive_users, positive_items = data.edge_label_index
    num_edges = positive_users.size(0)

    # Create a boolean mask for all the positive edges
    positive_mask = torch.zeros(num_users, num_items, device=device, dtype=torch.bool)
    positive_mask[positive_users, positive_items] = True

    neg_edges_list = []
    neg_edge_label_list = []

    for batch_start in range(0, num_edges, batch_size):
        batch_end = min(batch_start + batch_size, num_edges)

        batch_scores = torch.matmul(
            user_embeddings[positive_users[batch_start:batch_end]], item_embeddings.t()
        )

        # Set the scores of the positive edges to negative infinity
        batch_scores[positive_mask[positive_users[batch_start:batch_end]]] = -float("inf")

        # Select the top k highest scoring negative edges for each user in the current batch
        # do 0.99 to filter out all pos edges which will be at the end
        _, top_indices = torch.topk(batch_scores, int(frac_sample * 0.99 * num_items), dim=1)
        selected_indices = torch.randint(0, int(frac_sample * 0.99 * num_items), size=(batch_end - batch_start,))
        top_indices_selected = top_indices[torch.arange(batch_end - batch_start), selected_indices]

        # Create the negative edges tensor for the current batch
        neg_edges_batch = torch.stack(
            (positive_users[batch_start:batch_end], top_indices_selected), dim=0
        )
        neg_edge_label_batch = torch.zeros(neg_edges_batch.shape[1], device=device)

        neg_edges_list.append(neg_edges_batch)
        neg_edge_label_list.append(neg_edge_label_batch)

    # Concatenate the batch tensors
    neg_edges = torch.cat(neg_edges_list, dim=1)
    neg_edge_label = torch.cat(neg_edge_label_list)

    return neg_edges, neg_edge_label

The evaluation metrics on top of loss calculation: recall at K.For a playlist $i$, $P^k_i$ represents the set of the top $k$ predicted tracks for $i$ and $R_i$ the ground truth of connected tracks to user $i$, then we calculate
$$
\text{recall}^k_i = \frac{| P^k_i \cap R_i | }{|R_i|}.
$$
If $R_i = 0$, then we assign this value to 1. Note, if $R_i \subset P_i^k$, then the recall is equal to 1. Hence, our choice of $k$ matters a lot.

Note: when evaluating this metric on our validation or test set, we need to make sure to filter the message passing edges from consideration, as the model can directly observe these.

We choose a value of $k = 30$ in this case, as each user has 50 songs in their top songs 

In [80]:
def recall_at_k(data, model, k = 30, batch_size = 64, device = None):
    with torch.no_grad():
        user_embeddings, tracks_embeddings = model.get_embedding(data.edge_index, data.edge_weight)

    hits_list = []
    relevant_counts_list = []

    for batch_start in range(0, n_users, batch_size):
        batch_end = min(batch_start + batch_size, n_users)
        batch_users_embeddings = user_embeddings[batch_start:batch_end]

        # Calculate scores for all possible item pairs
        scores = torch.matmul(batch_users_embeddings, tracks_embeddings.t())

        # Set the scores of message passing edges to negative infinity
        mp_indices = ((data.edge_index[0] >= batch_start) & (data.edge_index[0] < batch_end)).nonzero(as_tuple=True)[0]
        scores[data.edge_index[0, mp_indices] - batch_start, data.edge_index[1, mp_indices] - n_users] = -float("inf")

        # Find the top k highest scoring items for each playlist in the batch
        _, top_k_indices = torch.topk(scores, k, dim=1)

        # Ground truth supervision edges
        ground_truth_edges = data.edge_label_index

        # Create a mask to indicate if the top k items are in the ground truth supervision edges
        mask = torch.zeros(scores.shape, device=device, dtype=torch.bool)
        gt_indices = ((ground_truth_edges[0] >= batch_start) & (ground_truth_edges[0] < batch_end)).nonzero(as_tuple=True)[0]
        mask[ground_truth_edges[0, gt_indices] - batch_start, ground_truth_edges[1, gt_indices] - n_users] = True

        # Check how many of the top k items are in the ground truth supervision edges
        hits = mask.gather(1, top_k_indices).sum(dim=1)
        hits_list.append(hits)

        # Calculate the total number of relevant items for each playlist in the batch
        relevant_counts = torch.bincount(ground_truth_edges[0, gt_indices] - batch_start, minlength=batch_end - batch_start)
        relevant_counts_list.append(relevant_counts)

    # Compute recall@k
    hits_tensor = torch.cat(hits_list, dim=0)
    relevant_counts_tensor = torch.cat(relevant_counts_list, dim=0)
    # Handle division by zero case
    recall_at_k = torch.where(
        relevant_counts_tensor != 0,
        hits_tensor.true_divide(relevant_counts_tensor),
        torch.ones_like(hits_tensor)
    )
    # take average
    recall_at_k = torch.mean(recall_at_k)

    if recall_at_k.numel() == 1:
        return recall_at_k.item()
    else:
        raise ValueError("recall_at_k contains more than one item.")

In [81]:
def metrics(labels, preds):
  roc = roc_auc_score(labels.flatten().cpu().numpy(), preds.flatten().data.cpu().numpy())
  return roc
# Train
def train(datasets, model, optimizer, loss_fn, args, K = 30, neg_samp = "random"):

  train_data = datasets["train"]
  val_data = datasets["val"]

  stats = {
      'train': {
        'loss': [],
        'roc' : []
      },
      'val': {
        'loss': [],
        'recall': [],
        'roc' : []
      }

  }
  val_neg_edge, val_neg_label = None, None
  for epoch in range(args["epochs"]): # loop over each epoch
    model.train()
    optimizer.zero_grad()

    # obtain negative sample
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(train_data, n_users, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            train_data, model, n_users, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # calculate embedding
    user_emb, item_emb = model.get_embedding(train_data.edge_index, train_data.edge_weight)

    # calculate pos, negative scores using embedding
    pos_scores = model.predict_link_embedding(user_emb, item_emb, train_data.edge_label_index)
    neg_scores = model.predict_link_embedding(user_emb, item_emb, neg_edge_index)
    if pos_scores.size(0) > neg_scores.size(0):
      indices = torch.randperm(pos_scores.size(0))[:neg_scores.size(0)]
      pos_scores = pos_scores[indices]
    balanced_pos_labels = torch.ones(pos_scores.size(0))
    balanced_neg_labels = torch.zeros(neg_scores.size(0))
    

    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((balanced_pos_labels, balanced_neg_labels), dim=0)
    loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 1)

    train_roc = metrics(labels, scores)

    loss.backward()
    optimizer.step()

    val_loss, val_roc, val_neg_edge, val_neg_label = test(
        model, val_data, loss_fn, neg_samp, epoch, val_neg_edge, val_neg_label
    )

    stats['train']['loss'].append(loss)
    stats['train']['roc'].append(train_roc)
    stats['val']['loss'].append(val_loss)
    stats['val']['roc'].append(val_roc)

    print(f"Epoch {epoch}; Train loss {loss}; Val loss {val_loss}; Train ROC {train_roc}; Val ROC {val_roc}")

    if epoch % 10 == 0:
      # calculate recall @ K
      val_recall = recall_at_k(val_data, model, k = K, device = args["device"])
      print(f"Val recall {val_recall}")
      stats['val']['recall'].append(val_recall)

    if epoch % 20 == 0:

      # save embeddings for future visualization
      path = os.path.join("model_embeddings", model.name)
      if not os.path.exists(path):
        os.makedirs(path)
      embeddings = {
        'user_embedding': model.user_embedding.weight,
        'item_embedding': model.item_embedding.weight
      }
      torch.save(embeddings, os.path.join("model_embeddings", model.name, f"{model.name}_embeddings_{loss_fn}_{neg_samp}_{epoch}.pt"))

  pickle.dump(stats, open(f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pkl", "wb"))
  return stats

def test(model, data, loss_fn, neg_samp, epoch = 0, neg_edge_index = None, neg_edge_label = None):

  model.eval()
  with torch.no_grad(): # want to save RAM

    # conduct negative sampling
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(data, n_users, n_tracks, args["device"])
    # obtain model embedding
    # embed = model.get_embedding(data.edge_index, data.edge_weight)
    user_emb, item_emb = model.get_embedding(data.edge_index, data.edge_weight)
    # calculate pos, neg scores using embedding
    pos_scores = model.predict_link_embedding(user_emb, item_emb, data.edge_label_index)
    neg_scores = model.predict_link_embedding(user_emb, item_emb, neg_edge_index)
    if pos_scores.size(0) > neg_scores.size(0):
      indices = torch.randperm(pos_scores.size(0))[:neg_scores.size(0)]
      pos_scores = pos_scores[indices]
    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    balanced_pos_labels = torch.ones(pos_scores.size(0))
    balanced_neg_labels = torch.zeros(neg_scores.size(0))
    labels = torch.cat((balanced_pos_labels, balanced_neg_labels), dim=0)

    # calculate loss
    loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 1)

    roc = metrics(labels, scores)

  return loss, roc, neg_edge_index, neg_edge_label

def weighted_binary_cross_entropy(output, target, edge_weight):
    # Apply sigmoid activation to ensure output is in [0, 1]
    output = torch.sigmoid(output)

    # Only apply edge weights to the positive scores
    pos_output = output[:len(edge_weight)]
    neg_output = output[len(edge_weight):]
    
    # Adjust the scaling factor and exponent as needed
    scaling_factor = 10
    exponent = -0.5
    weight = scaling_factor * torch.pow(edge_weight[:len(pos_output)] + 1e-8, exponent)
    
    # Compute loss separately for positive and negative scores
    pos_loss = F.binary_cross_entropy(pos_output, target[:len(edge_weight)], weight=weight, reduction='none')
    neg_loss = F.binary_cross_entropy(neg_output, target[len(edge_weight):], reduction='none')
    
    # Combine the losses
    loss = torch.cat([pos_loss, neg_loss], dim=0).mean()
    return loss

In [82]:
# create a dictionary of the dataset splits
datasets = {
    'train':train_split,
    'val':val_split,
    'test': test_split
}

# initialize our arguments
args = {
    'device' : 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_layers' : 3,
    'emb_size' : 64,
    'weight_decay': 1e-5,
    'lr': 0.01,
    'epochs': 301
}

# initialize model and and optimizer
num_nodes = n_users + n_tracks
model = GCN(
    num_users = n_users, num_items= n_tracks, num_layers = args['num_layers'],
    embedding_dim = args["emb_size"]
)
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

# send data, model to GPU if available
users_idx = torch.Tensor(users_idx).type(torch.int64).to(args["device"])
tracks_idx = torch.Tensor(tracks_idx).type(torch.int64).to(args["device"])
datasets['train'].to(args['device'])
datasets['val'].to(args['device'])
datasets['test'].to(args['device'])
model.to(args["device"])

# create directory to save model_stats
MODEL_STATS_DIR = "model_stats"
if not os.path.exists(MODEL_STATS_DIR):
  os.makedirs(MODEL_STATS_DIR)

train(datasets, model, optimizer, "BPR", args, neg_samp = "random")

test(model, datasets['test'], "BPR", neg_samp = "random")

Epoch 0; Train loss 0.694583535194397; Val loss 0.7402910590171814; Train ROC 0.7381654584340545; Val ROC 0.5017101352848458
Val recall 0.02304498665034771
Epoch 1; Train loss 0.7032086849212646; Val loss 0.697307825088501; Train ROC 0.8689706803386928; Val ROC 0.5003601581994733
Epoch 2; Train loss 0.694034993648529; Val loss 0.7132035493850708; Train ROC 0.8491368602349593; Val ROC 0.4977801333483976
Epoch 3; Train loss 0.6974281668663025; Val loss 0.718108057975769; Train ROC 0.863461740058883; Val ROC 0.4983316696772814
Epoch 4; Train loss 0.6984744668006897; Val loss 0.7020402550697327; Train ROC 0.8754052055463788; Val ROC 0.5008194112778334
Epoch 5; Train loss 0.6950436234474182; Val loss 0.6966994404792786; Train ROC 0.8705915994369144; Val ROC 0.5036311867553779
Epoch 6; Train loss 0.6939041614532471; Val loss 0.7043846249580383; Train ROC 0.8363537709412338; Val ROC 0.5001754375220441
Epoch 7; Train loss 0.6955457925796509; Val loss 0.7072674036026001; Train ROC 0.85580441376

(tensor(0.6931),
 0.4990553388842113,
 tensor([[  3369,   6292,      8,  ...,   6086,    492,   1023],
         [226937, 421341, 671250,  ...,  33975,  51297, 898459]]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]))