In [5]:
# General libraries
import json
from pathlib import Path as Data_Path
import os
from os.path import isfile, join
import pickle
import random

import numpy as np
import networkx as nx
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

# Import relevant ML libraries
from typing import Optional, Union

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Embedding, ModuleList, Linear
import torch.nn.functional as F

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch.nn.modules.loss import _Loss

from torch_geometric.nn.conv import LGConv, GATConv, SAGEConv
from torch_geometric.typing import Adj, OptTensor, SparseTensor
from tqdm.notebook import tqdm
print(f"Torch version: {torch.__version__}; Torch-cuda version: {torch.version.cuda}; Torch Geometric version: {torch_geometric.__version__}.")

Torch version: 2.2.1; Torch-cuda version: None; Torch Geometric version: 2.5.1.


In [6]:
# set the seed for reproducibility
seed = 224
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [7]:
data = pd.read_csv('../Data/user_songs_filtered.csv')
data = data.sort_values(by = 'Username')
data = data.reset_index(drop=True)

In [None]:
class Track:
  def __init__(self, track_dict):
    self.name = track_dict["track_name"]
    self.artist_name = track_dict["artist_name"]
    self.listeners = track_dict["listeners"]
    self.total_playcount = track_dict["total_playcount"]
    self.emotion1 = track_dict["emotion1"]
    self.emotion1_score = track_dict["emotion1_score"]
    self.emotion2 = track_dict["emotion2"]
    self.emotion2_score = track_dict["emotion2_score"]
    self.rms = track_dict["rms"]
    self.spectral_centroid	= track_dict['spectral_centroid']
    self.tempo = track_dict['tempo']

  def __str__(self):
    return f"Track called {self.name} by ({self.artist_name}) has emotions {self.emotion1} and {self.emotion2}."

  def __repr__(self):
    return f"Track {self.name}"

  def __lt__(self, other):
    return (self.name < other.name) and (self.total_playcount < other.total_playcount)

  def __gt__(self, other):
    return (self.name > other.name) and (self.total_playcount > other.total_playcount)
  
data[['track_name', 'artist_name']].drop_duplicates().to_numpy()

tracks = []
unique_tracks = data[['track_name', 'artist_name']].drop_duplicates().to_numpy()
for track, artist in tqdm(unique_tracks):
    row_data = data.loc[(data['track_name'] == track) & (data['artist_name'] == artist)].iloc[0]
    track_data = row_data[['track_name', 'artist_name', 'listeners', 'total_playcount', 'emotion1', 'emotion1_score',
                           'emotion2', 'emotion2_score', 'rms', 'spectral_centroid', 'tempo']].to_dict()
    tracks += [Track(track_data)]


In [10]:

class User:
  def __init__(self, user_data, top_songs):

    self.name = user_data['Username']
    self.country = user_data['country']
    self.track_count = int(user_data['track_count'])
    self.total_playcount = 0
    self.top_songs = {}
    self.artists = []
    for index, row in top_songs.iterrows():
      rank = row['rank']
      track_name = row['track_name']
      artist_name = row['artist_name']
      track = [obj for obj in tracks if (obj.name == track_name)&(obj.artist_name == artist_name)][0]
      playcount = row['playcount']
      self.top_songs[rank] = (track, playcount)
      self.total_playcount += playcount
      # or should i use a dict for artists
      self.artists +=[artist_name]

    self.top_songs =  {k: self.top_songs[k] for k in sorted(self.top_songs)}
  def __str__(self):
    return f"User {self.name} with {len(self.top_songs)} top tracks loaded, total listen count is {self.total_playcount}."

  def __repr__(self):
    return f"User {self.name}"
  def __lt__(self, other):
    return (self.name < other.name) and (self.total_playcount < other.total_playcount)

  def __gt__(self, other):
    return (self.name > other.name) and (self.total_playcount > other.total_playcount)


In [11]:

unique_users = data.Username.unique()
users = []
for user in tqdm(unique_users):
    user_data = data.loc[data['Username'] == user].iloc[0]
    user_data = user_data[['Username', 'country', 'track_count']].to_dict()
    top_songs = data.loc[data['Username'] == user]
    top_songs = top_songs[['rank', 'track_name', 'artist_name', 'playcount']]
    users +=[User(user_data, top_songs)]

  0%|          | 0/9483 [00:00<?, ?it/s]

In [12]:
# adding nodes
G = nx.Graph()
G.add_nodes_from([
    (p, {'name':p, "node_type" : "user"}) for p in users
])
G.add_nodes_from([
    (t, {'name':t, "node_type" : "track"}) for t in tracks
])

# adding edges
for user in users:
  top_songs = user.top_songs
  user_total_listening = user.total_playcount
  for song, count in top_songs.values():
    G.add_edge(user, song, weight=1 / (count/user_total_listening))


# Make a large subgraph
random.seed(225)
rand_nodes_lg = random.sample(list(G.nodes()), 3000)
sub_G_lg = G.subgraph(rand_nodes_lg)
largest_cc_lg = max(nx.connected_components(sub_G_lg.to_undirected()), key=len)
sub_G_lg = nx.Graph(sub_G_lg.subgraph(largest_cc_lg))
print('Large subgraph Num nodes:', sub_G_lg.number_of_nodes(),
      '. Num edges:', sub_G_lg.number_of_edges())


n_nodes, n_edges = G.number_of_nodes(), G.number_of_edges()

# by sorting them we get an ordering playlist1, ..., playlistN, track1, ..., trackN
sorted_nodes = list(G.nodes())

# create dictionaries to index to 0 to n_nodes, will be necessary for when we are using tensors
node2id = dict(zip(sorted_nodes, np.arange(n_nodes)))
id2node = dict(zip(np.arange(n_nodes), sorted_nodes))

G = nx.relabel_nodes(G, node2id)

# also keep track of how many users, tracks we have
users_idx = [i for i, v in enumerate(node2id.keys()) if isinstance(v, User)] 
tracks_idx = [i for i, v in enumerate(node2id.keys()) if isinstance(v, Track)]
n_users = np.max(users_idx) + 1
n_tracks = n_nodes - n_users

n_users, n_tracks

# turn the graph into a torch_geometric Data object
num_nodes = G.number_of_nodes()
edge_idx = torch.Tensor(np.array(G.edges()).T)
# Get the edge weights from the NetworkX graph
edge_weights = []
for u, v in G.edges():
    edge_weights.append(G[u][v]['weight'])
edge_weights = torch.tensor(edge_weights, dtype=torch.float)
print(edge_weights)
graph_data = Data(edge_index = edge_idx, edge_weight = edge_weights, num_nodes = num_nodes)

# convert to train/val/test splits
transform = RandomLinkSplit(
    is_undirected=True,
    add_negative_train_samples=False,
    neg_sampling_ratio=0,
    num_val=0.15, num_test=0.15
)
train_split, val_split, test_split = transform(graph_data)
for split_data in [train_split, val_split, test_split]:
    edge_weights_split = []
    for u, v in split_data.edge_index.T:
        edge_weights_split.append(G[node2id[id2node[u.item()]]][node2id[id2node[v.item()]]]['weight'])
    split_data.edge_weight = torch.tensor(edge_weights_split, dtype=torch.float)


# note these are stored as float32, we need them to be int64 for future training

# Edge index: message passing edges
train_split.edge_index = train_split.edge_index.type(torch.int64)
val_split.edge_index = val_split.edge_index.type(torch.int64)
test_split.edge_index = test_split.edge_index.type(torch.int64)
# Edge label index: supervision edges
train_split.edge_label_index = train_split.edge_label_index.type(torch.int64)
val_split.edge_label_index = val_split.edge_label_index.type(torch.int64)
test_split.edge_label_index = test_split.edge_label_index.type(torch.int64)

print(f"Train set has {train_split.edge_label_index.shape[1]} positives supervision edges")
print(f"Validation set has {val_split.edge_label_index.shape[1]} positive supervision edges")
print(f"Test set has {test_split.edge_label_index.shape[1]} positive supervision edges")

print(f"Train set has {train_split.edge_index.shape[1]} message passing edges")
print(f"Validation set has {val_split.edge_index.shape[1]} message passing edges")
print(f"Test set has {test_split.edge_index.shape[1]} message passing edges")


Large subgraph Num nodes: 10 . Num edges: 9
tensor([20.5141, 25.7788, 26.7248,  ..., 52.7391, 52.7391, 52.7391])
Train set has 275184 positives supervision edges
Validation set has 58968 positive supervision edges
Test set has 58968 positive supervision edges
Train set has 550368 message passing edges
Validation set has 550368 message passing edges
Test set has 668304 message passing edges


In [13]:
class GCN(torch.nn.Module):
    """
      Here we adapt the LightGCN model from Torch Geometric for our purposes. We allow
      for customizable convolutional layers, custom embeddings. In addition, we deifne some
      additional custom functions.

    """

    def __init__(
        self,
        num_nodes: int,
        embedding_dim: int,
        num_layers: int,
        alpha: Optional[Union[float, Tensor]] = None,
        alpha_learnable = False,
        conv_layer = "LGC",
        name = None,
        **kwargs,
    ):
        super().__init__()
        alpha_string = "alpha" if alpha_learnable else ""
        self.name = f"LGCN_{conv_layer}_{num_layers}_e{embedding_dim}_nodes{num_nodes}_{alpha_string}"
        self.num_nodes = num_nodes
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        if alpha_learnable == True:
          alpha_vals = torch.rand(num_layers+1)
          alpha = nn.Parameter(alpha_vals/torch.sum(alpha_vals))
          print(f"Alpha learnable, initialized to: {alpha.softmax(dim=-1)}")
        else:
          if alpha is None:
              alpha = 1. / (num_layers + 1)

          if isinstance(alpha, Tensor):
              assert alpha.size(0) == num_layers + 1
          else:
              alpha = torch.tensor([alpha] * (num_layers + 1))

        self.register_buffer('alpha', alpha)

        self.embedding = Embedding(num_nodes, embedding_dim)

        # initialize convolutional layers
        self.conv_layer = conv_layer
        if conv_layer == "LGC":
          self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)])
        elif conv_layer == "GAT":
          # initialize Graph Attention layer with multiple heads
          # initialize linear layers to aggregate heads
          n_heads = 5
          self.convs = ModuleList(
              [GATConv(in_channels = embedding_dim, out_channels = embedding_dim, heads = n_heads, dropout = 0.5, **kwargs) for _ in range(num_layers)]
          )
          self.linears = ModuleList([Linear(n_heads * embedding_dim, embedding_dim) for _ in range(num_layers)])

        elif conv_layer == "SAGE":
          #  initialize GraphSAGE conv
          self.convs = ModuleList(
              [SAGEConv(in_channels = embedding_dim, out_channels = embedding_dim, **kwargs) for _ in range(num_layers)]
          )

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.embedding.weight)
        for conv in self.convs:
            conv.reset_parameters()

    def get_embedding(self, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
        x = self.embedding.weight

        weights = self.alpha.softmax(dim=-1)
        out = x * weights[0]

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index, edge_weight)
            if self.conv_layer == "GAT":
              x = self.linears[i](x)
            out = out + x * weights[i + 1]

        return out

    def initialize_embeddings(self, data):
      # initialize with the data node features
        self.embedding.weight.data.copy_(data.node_feature)


    def forward(self, edge_index: Adj,
                edge_label_index: OptTensor = None, edge_weight: OptTensor = None) -> Tensor:
        if edge_label_index is None:
            if isinstance(edge_index, SparseTensor):
                edge_label_index = torch.stack(edge_index.coo()[:2], dim=0)
            else:
                edge_label_index = edge_index

        out = self.get_embedding(edge_index, edge_weight)

        return self.predict_link_embedding(out, edge_label_index)

    def predict_link(self, edge_index: Adj, edge_label_index: OptTensor = None,
                     edge_weight: OptTensor = None,
                     prob: bool = False) -> Tensor:

        pred = self(edge_index, edge_label_index, edge_weight).sigmoid()
        return pred if prob else pred.round()

    def predict_link_embedding(self, embed: Adj, edge_label_index: Adj) -> Tensor:

        embed_src = embed[edge_label_index[0]]
        embed_dst = embed[edge_label_index[1]]
        return (embed_src * embed_dst).sum(dim=-1)


    def recommend(self, edge_index: Adj, edge_weight: OptTensor = None, src_index: OptTensor = None,
                  dst_index: OptTensor = None, k: int = 1, sorted: bool = True) -> Tensor:
        out_src = out_dst = self.get_embedding(edge_index, edge_weight)

        if src_index is not None:
            out_src = out_src[src_index]

        if dst_index is not None:
            out_dst = out_dst[dst_index]

        pred = out_src @ out_dst.t()
        top_index = pred.topk(k, dim=-1, sorted=sorted).indices

        if dst_index is not None:  # Map local top-indices to original indices.
            top_index = dst_index[top_index.view(-1)].view(*top_index.size())

        return top_index


    def link_pred_loss(self, pred: Tensor, edge_label: Tensor,
                       **kwargs) -> Tensor:
        loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)
        return loss_fn(pred, edge_label.to(pred.dtype))


    def recommendation_loss(self, pos_edge_rank: Tensor, neg_edge_rank: Tensor,
                            node_id: Optional[Tensor] = None, lambda_reg: float = 1e-4, **kwargs) -> Tensor:
        r"""Computes the model loss for a ranking objective via the Bayesian
        Personalized Ranking (BPR) loss."""
        loss_fn = BPRLoss(lambda_reg, **kwargs)
        emb = self.embedding.weight
        emb = emb if node_id is None else emb[node_id]
        return loss_fn(pos_edge_rank, neg_edge_rank, emb)

    def bpr_loss(self, pos_scores, neg_scores):
      return - torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_nodes}, '
                f'{self.embedding_dim}, num_layers={self.num_layers})')



class BPRLoss(_Loss):
    r"""The Bayesian Personalized Ranking (BPR) loss.

    The BPR loss is a pairwise loss that encourages the prediction of an
    observed entry to be higher than its unobserved counterparts
    (see `here <https://arxiv.org/abs/2002.02126>`__).

    .. math::
        L_{\text{BPR}} = - \sum_{u=1}^{M} \sum_{i \in \mathcal{N}_u}
        \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj})
        + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2

    where :math:`lambda` controls the :math:`L_2` regularization strength.
    We compute the mean BPR loss for simplicity.

    Args:
        lambda_reg (float, optional): The :math:`L_2` regularization strength
            (default: 0).
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch.nn.modules.loss._Loss` class.
    """
    __constants__ = ['lambda_reg']
    lambda_reg: float

    def __init__(self, lambda_reg: float = 0, **kwargs):
        super().__init__(None, None, "sum", **kwargs)
        self.lambda_reg = lambda_reg

    def forward(self, positives: Tensor, negatives: Tensor,
                parameters: Tensor = None) -> Tensor:
        r"""Compute the mean Bayesian Personalized Ranking (BPR) loss.

        .. note::

            The i-th entry in the :obj:`positives` vector and i-th entry
            in the :obj:`negatives` entry should correspond to the same
            entity (*.e.g*, user), as the BPR is a personalized ranking loss.

        Args:
            positives (Tensor): The vector of positive-pair rankings.
            negatives (Tensor): The vector of negative-pair rankings.
            parameters (Tensor, optional): The tensor of parameters which
                should be used for :math:`L_2` regularization
                (default: :obj:`None`).
        """
        n_pairs = positives.size(0)
        log_prob = F.logsigmoid(positives - negatives).sum()
        regularization = 0

        if self.lambda_reg != 0:
            regularization = self.lambda_reg * parameters.norm(p=2).pow(2)

        return (-log_prob + regularization) / n_pairs

In [14]:
def sample_negative_edges_nocheck(data, num_users, num_tracks, device = None):
  # note computationally inefficient to check that these are indeed negative edges
  users = data.edge_label_index[0, :]
  tracks = torch.randint(num_users, num_users + num_tracks - 1, size = data.edge_label_index[1, :].size())

  if users.get_device() != -1: # on gpu
    tracks = tracks.to(device)

  neg_edge_index = torch.stack((users, tracks), dim = 0)
  neg_edge_label = torch.zeros(neg_edge_index.shape[1])

  if neg_edge_index.get_device() != -1: # on gpu
    neg_edge_label = neg_edge_label.to(device)

  return neg_edge_index, neg_edge_label

def sample_negative_edges(data, num_users, num_tracks, device=None):
    positive_users, positive_tracks = data.edge_label_index

    # Create a mask tensor with the shape (num_playlists, num_tracks)
    mask = torch.zeros(num_users, num_tracks, device=device, dtype=torch.bool)
    mask[positive_users, positive_tracks - num_users] = True

    # Flatten the mask tensor and get the indices of the negative edges
    flat_mask = mask.flatten()
    negative_indices = torch.where(~flat_mask)[0]

    # Sample negative edges from the negative_indices tensor
    sampled_negative_indices = negative_indices[
        torch.randint(0, negative_indices.size(0), size=(positive_users.size(0),), device=device)
    ]

    # Convert the indices back to playlists and tracks tensors
    playlists = torch.floor_divide(sampled_negative_indices, num_tracks)
    tracks = torch.remainder(sampled_negative_indices, num_tracks)
    tracks = tracks + num_users

    neg_edge_index = torch.stack((playlists, tracks), dim=0)
    neg_edge_label = torch.zeros(neg_edge_index.shape[1], device=device)

    return neg_edge_index, neg_edge_label

def sample_hard_negative_edges(data, model, num_users, num_tracks, device=None, batch_size=500, frac_sample = 1):
    with torch.no_grad():
        embeddings = model.get_embedding(data.edge_index, data.edge_weight)
        user_embeddings = embeddings[:num_users].to(device)
        tracks_embeddings = embeddings[num_users:].to(device)

    positive_users, positive_tracks = data.edge_label_index
    num_edges = positive_users.size(0)

    # Create a boolean mask for all the positive edges
    positive_mask = torch.zeros(num_users, num_tracks, device=device, dtype=torch.bool)
    positive_mask[positive_users, positive_tracks - num_users] = True

    neg_edges_list = []
    neg_edge_label_list = []

    for batch_start in range(0, num_edges, batch_size):
        batch_end = min(batch_start + batch_size, num_edges)

        batch_scores = torch.matmul(
            user_embeddings[positive_users[batch_start:batch_end]], tracks_embeddings.t()
        )

        # Set the scores of the positive edges to negative infinity
        batch_scores[positive_mask[positive_users[batch_start:batch_end]]] = -float("inf")

        # Select the top k highest scoring negative edges for each playlist in the current batch
        # do 0.99 to filter out all pos edges which will be at the end
        _, top_indices = torch.topk(batch_scores, int(frac_sample * 0.99 * num_tracks), dim=1)
        selected_indices = torch.randint(0, int(frac_sample * 0.99 *num_tracks), size = (batch_end - batch_start, ))
        top_indices_selected = top_indices[torch.arange(batch_end - batch_start), selected_indices] + n_users

        # Create the negative edges tensor for the current batch
        neg_edges_batch = torch.stack(
            (positive_users[batch_start:batch_end], top_indices_selected), dim=0
        )
        neg_edge_label_batch = torch.zeros(neg_edges_batch.shape[1], device=device)

        neg_edges_list.append(neg_edges_batch)
        neg_edge_label_list.append(neg_edge_label_batch)

    # Concatenate the batch tensors
    neg_edges = torch.cat(neg_edges_list, dim=1)
    neg_edge_label = torch.cat(neg_edge_label_list)

    return neg_edges, neg_edge_label

In [15]:
def recall_at_k(data, model, k = 30, batch_size = 64, device = None):
    with torch.no_grad():
        embeddings = model.get_embedding(data.edge_index, data.edge_weight)
        user_embeddings = embeddings[:n_users]
        tracks_embeddings = embeddings[n_users:]

    hits_list = []
    relevant_counts_list = []

    for batch_start in range(0, n_users, batch_size):
        batch_end = min(batch_start + batch_size, n_users)
        batch_users_embeddings = user_embeddings[batch_start:batch_end]

        # Calculate scores for all possible item pairs
        scores = torch.matmul(batch_users_embeddings, tracks_embeddings.t())

        # Set the scores of message passing edges to negative infinity
        mp_indices = ((data.edge_index[0] >= batch_start) & (data.edge_index[0] < batch_end)).nonzero(as_tuple=True)[0]
        scores[data.edge_index[0, mp_indices] - batch_start, data.edge_index[1, mp_indices] - n_users] = -float("inf")

        # Find the top k highest scoring items for each playlist in the batch
        _, top_k_indices = torch.topk(scores, k, dim=1)

        # Ground truth supervision edges
        ground_truth_edges = data.edge_label_index

        # Create a mask to indicate if the top k items are in the ground truth supervision edges
        mask = torch.zeros(scores.shape, device=device, dtype=torch.bool)
        gt_indices = ((ground_truth_edges[0] >= batch_start) & (ground_truth_edges[0] < batch_end)).nonzero(as_tuple=True)[0]
        mask[ground_truth_edges[0, gt_indices] - batch_start, ground_truth_edges[1, gt_indices] - n_users] = True

        # Check how many of the top k items are in the ground truth supervision edges
        hits = mask.gather(1, top_k_indices).sum(dim=1)
        hits_list.append(hits)

        # Calculate the total number of relevant items for each playlist in the batch
        relevant_counts = torch.bincount(ground_truth_edges[0, gt_indices] - batch_start, minlength=batch_end - batch_start)
        relevant_counts_list.append(relevant_counts)

    # Compute recall@k
    hits_tensor = torch.cat(hits_list, dim=0)
    relevant_counts_tensor = torch.cat(relevant_counts_list, dim=0)
    # Handle division by zero case
    recall_at_k = torch.where(
        relevant_counts_tensor != 0,
        hits_tensor.true_divide(relevant_counts_tensor),
        torch.ones_like(hits_tensor)
    )
    # take average
    recall_at_k = torch.mean(recall_at_k)

    if recall_at_k.numel() == 1:
        return recall_at_k.item()
    else:
        raise ValueError("recall_at_k contains more than one item.")

In [16]:
def metrics(labels, preds):
  roc = roc_auc_score(labels.flatten().cpu().numpy(), preds.flatten().data.cpu().numpy())
  return roc
# Train
def train(datasets, model, optimizer, loss_fn, args, K = 30, neg_samp = "random"):
  # print(f"Beginning training for {model.name}")

  train_data = datasets["train"]
  val_data = datasets["val"]

  stats = {
      'train': {
        'loss': [],
        'roc' : []
      },
      'val': {
        'loss': [],
        'recall': [],
        'roc' : []
      }

  }
  val_neg_edge, val_neg_label = None, None
  for epoch in range(args["epochs"]): # loop over each epoch
    model.train()
    optimizer.zero_grad()

    # obtain negative sample
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(train_data, n_users, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            train_data, model, n_users, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # calculate embedding
    embed = model.get_embedding(train_data.edge_index, train_data.edge_weight)
    # calculate pos, negative scores using embedding
    pos_scores = model.predict_link_embedding(embed, train_data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)

    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((train_data.edge_label, neg_edge_label), dim = 0)

    # calculate loss function
    if loss_fn == "BCE":
      # edge_weight = train_data.edge_weight[train_data.edge_label_index].view(-1)
      loss = weighted_binary_cross_entropy(scores[:len(pos_scores)], labels[:len(pos_scores)], train_data.edge_weight)
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    train_roc = metrics(labels, scores)

    loss.backward()
    optimizer.step()

    val_loss, val_roc, val_neg_edge, val_neg_label = test(
        model, val_data, loss_fn, neg_samp, epoch, val_neg_edge, val_neg_label
    )

    stats['train']['loss'].append(loss)
    stats['train']['roc'].append(train_roc)
    stats['val']['loss'].append(val_loss)
    stats['val']['roc'].append(val_roc)

    print(f"Epoch {epoch}; Train loss {loss}; Val loss {val_loss}; Train ROC {train_roc}; Val ROC {val_roc}")

    if epoch % 10 == 0:
      # calculate recall @ K
      val_recall = recall_at_k(val_data, model, k = K, device = args["device"])
      print(f"Val recall {val_recall}")
      stats['val']['recall'].append(val_recall)

    if epoch % 20 == 0:

      # save embeddings for future visualization
      path = os.path.join("model_embeddings", model.name)
      if not os.path.exists(path):
        os.makedirs(path)
      torch.save(model.embedding.weight, os.path.join("model_embeddings", model.name, f"{model.name}_{loss_fn}_{neg_samp}_{epoch}.pt"))

  pickle.dump(stats, open(f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pkl", "wb"))
  return stats

def test(model, data, loss_fn, neg_samp, epoch = 0, neg_edge_index = None, neg_edge_label = None):

  model.eval()
  with torch.no_grad(): # want to save RAM

    # conduct negative sampling
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(data, n_users, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0 or neg_edge_index is None:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            data, model, n_users, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # obtain model embedding
    embed = model.get_embedding(data.edge_index, data.edge_weight)
    # calculate pos, neg scores using embedding
    pos_scores = model.predict_link_embedding(embed, data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)
    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((data.edge_label, neg_edge_label), dim = 0)
    # calculate loss
    if loss_fn == "BCE":
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    roc = metrics(labels, scores)

  return loss, roc, neg_edge_index, neg_edge_label

def weighted_binary_cross_entropy(output, target, edge_weight):
    # Apply sigmoid activation to ensure output is in [0, 1]
    output = torch.sigmoid(output)

    # Only apply edge weights to the positive scores
    pos_output = output[:len(edge_weight)]
    neg_output = output[len(edge_weight):]
    
    # Adjust the scaling factor and exponent as needed
    scaling_factor = 10
    exponent = -0.5
    weight = scaling_factor * torch.pow(edge_weight[:len(pos_output)] + 1e-8, exponent)
    
    # Compute loss separately for positive and negative scores
    pos_loss = F.binary_cross_entropy(pos_output, target[:len(edge_weight)], weight=weight, reduction='none')
    neg_loss = F.binary_cross_entropy(neg_output, target[len(edge_weight):], reduction='none')
    
    # Combine the losses
    loss = torch.cat([pos_loss, neg_loss], dim=0).mean()
    return loss

In [17]:
# create a dictionary of the dataset splits
datasets = {
    'train':train_split,
    'val':val_split,
    'test': test_split
}

# initialize our arguments
args = {
    'device' : 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_layers' : 3,
    'emb_size' : 64,
    'weight_decay': 1e-5,
    'lr': 0.01,
    'epochs': 301
}

# initialize model and and optimizer
num_nodes = n_users + n_tracks
model = GCN(
    num_nodes = num_nodes, num_layers = args['num_layers'],
    embedding_dim = args["emb_size"]
)
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

# send data, model to GPU if available
users_idx = torch.Tensor(users_idx).type(torch.int64).to(args["device"])
tracks_idx = torch.Tensor(tracks_idx).type(torch.int64).to(args["device"])
datasets['train'].to(args['device'])
datasets['val'].to(args['device'])
datasets['test'].to(args['device'])
model.to(args["device"])

# create directory to save model_stats
MODEL_STATS_DIR = "model_stats"
if not os.path.exists(MODEL_STATS_DIR):
  os.makedirs(MODEL_STATS_DIR)


# can set BCE -> BPR also, see which is better
train(datasets, model, optimizer, "BCE", args, neg_samp = "random")

test(model, datasets['test'], "BCE", neg_samp = "random")

Epoch 0; Train loss 0.6931408047676086; Val loss 0.6931471228599548; Train ROC 0.8400099044006523; Val ROC 0.5015478149854704
Val recall 0.0036700174678117037
Epoch 1; Train loss 0.693142831325531; Val loss 0.6931471824645996; Train ROC 0.7945596209883696; Val ROC 0.50675568041696
Epoch 2; Train loss 0.6931430697441101; Val loss 0.6931470036506653; Train ROC 0.8288427506952658; Val ROC 0.5365929304608617
Epoch 3; Train loss 0.6931446194648743; Val loss 0.6931468844413757; Train ROC 0.8732002962439784; Val ROC 0.5765423522098753
Epoch 4; Train loss 0.6931448578834534; Val loss 0.6931461095809937; Train ROC 0.8752637891650333; Val ROC 0.6157800659207495
Epoch 5; Train loss 0.6931443810462952; Val loss 0.6931446194648743; Train ROC 0.8897223476955162; Val ROC 0.6568835467175103
Epoch 6; Train loss 0.6931427121162415; Val loss 0.6931407451629639; Train ROC 0.9129377227575194; Val ROC 0.6769313169995179
Epoch 7; Train loss 0.6931385397911072; Val loss 0.6931316256523132; Train ROC 0.9040831

(tensor(0.6631),
 0.6955295515266601,
 tensor([[   167,   5295,   1149,  ...,   2280,   8616,   7558],
         [ 76574,  25634,  34069,  ..., 122822,  31866,  87169]]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]))

In [1]:
def init_model(conv_layer, args, alpha = False):
  num_nodes = n_users + n_tracks
  model = GCN(
      num_nodes = num_nodes, num_layers = args['num_layers'],
      embedding_dim = args["emb_size"], conv_layer = conv_layer,
      alpha_learnable = alpha
  )
  model.to(args["device"])
  optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
  return model, optimizer

## For example:

# using BPR loss
loss_fn = "BPR"

# using hard sampling
neg_samp = "hard"

# for LGConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("LGC", args)
lgc_stats_hard = train(datasets, model, optimizer, loss_fn, args, K = 30, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")


the following models are only ran in smaller samples:



In [None]:
# for GATConv:
model, optimizer = init_model("GAT", args)
gat_stats_hard = train(datasets, model, optimizer, loss_fn, args,  K = 30, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

# for SAGEConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("SAGE", args)
sage_stats_hard = train(datasets, model, optimizer, loss_fn, args,  K = 30, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

# # using random sampling
neg_samp = "random"

# for LGConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("LGC", args)
lgc_stats = train(datasets, model, optimizer, loss_fn, args,  K = 30, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

# for GATConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("GAT", args)
gat_stats = train(datasets, model, optimizer, loss_fn, args,  K = 30, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

# for SAGEConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("SAGE", args)
sage_stats = train(datasets, model, optimizer, loss_fn, args,  K = 30, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [22]:
def init_model(conv_layer, args, alpha = False):
  num_nodes = n_users + n_tracks
  model = GCN(
      num_nodes = num_nodes, num_layers = args['num_layers'],
      embedding_dim = args["emb_size"], conv_layer = conv_layer,
      alpha_learnable = alpha
  )
  model.to(args["device"])
  optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
  return model, optimizer

## For example:

# using BPR loss
loss_fn = "BPR"

# using hard sampling
neg_samp = "hard"

# for LGConv:
args['epochs'] = 301
args['num_layers'] = 3
lgc_stats_hard = model, optimizer = init_model("LGC", args)
train(datasets, model, optimizer, loss_fn, args, K = 30, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")


Epoch 0; Train loss 0.6931347250938416; Val loss 0.6931470036506653; Train ROC 0.8390378215074243; Val ROC 0.5124777715564951
Val recall 0.0041725956834852695
Epoch 1; Train loss 0.6931416988372803; Val loss 0.6931442618370056; Train ROC 0.7080801761310822; Val ROC 0.6023386720858938
Epoch 2; Train loss 0.693138599395752; Val loss 0.6931243538856506; Train ROC 0.8598907551618671; Val ROC 0.697656541424913
Epoch 3; Train loss 0.6931167244911194; Val loss 0.6930139660835266; Train ROC 0.9114290648303702; Val ROC 0.7027129257194717
Epoch 4; Train loss 0.692988395690918; Val loss 0.6926167607307434; Train ROC 0.8476822162061597; Val ROC 0.704997132075166
Epoch 5; Train loss 0.6925309896469116; Val loss 0.6916404366493225; Train ROC 0.818687308712473; Val ROC 0.702659423573733
Epoch 6; Train loss 0.69135981798172; Val loss 0.6897045969963074; Train ROC 0.8206334415187109; Val ROC 0.7051177561351866
Epoch 7; Train loss 0.6890003085136414; Val loss 0.6864423751831055; Train ROC 0.823857821473