In [1]:
import torch
import torch_geometric

In [2]:
print(torch.__version__)
!nvcc --version

2.8.0+cu128
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0


In [3]:
# General libraries
import json
from pathlib import Path as Data_Path
import os
from os.path import isfile, join
import pickle
import random

import numpy as np
import networkx as nx
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

In [4]:
# Import relevant ML libraries
from typing import Optional, Union

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Embedding, ModuleList, Linear
import torch.nn.functional as F

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch.nn.modules.loss import _Loss

from torch_geometric.nn.conv import LGConv, GATConv, SAGEConv
from torch_geometric.typing import Adj, OptTensor, SparseTensor

print(f"Torch version: {torch.__version__}; Torch-cuda version: {torch.version.cuda}; Torch Geometric version: {torch_geometric.__version__}.")

Torch version: 2.8.0+cu128; Torch-cuda version: 12.8; Torch Geometric version: 2.7.0.


In [5]:
# set the seed for reproducibility
seed = 224
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [6]:
MAIN_DIR = "/home/DSE411/Documents/Tanishq/mlg-music-recommendation"
DATA_DIR = Data_Path('spotify_million_playlist_dataset/data')
os.chdir(MAIN_DIR)

In [7]:
with open(f"{DATA_DIR}/{os.listdir(DATA_DIR)[0]}") as jf:
  example_file = json.load(jf)

print(example_file['playlists'][0])

{'name': 'OVO', 'collaborative': 'false', 'pid': 241000, 'modified_at': 1493596800, 'num_tracks': 45, 'num_albums': 10, 'num_followers': 1, 'tracks': [{'pos': 0, 'artist_name': 'Drake', 'track_uri': 'spotify:track:7jslhIiELQkgW9IHeYNOWE', 'artist_uri': 'spotify:artist:3TVXtAsR1Inumwj472S9r4', 'track_name': 'Big Rings', 'album_uri': 'spotify:album:1ozpmkWcCHwsQ4QTnxOOdT', 'duration_ms': 217706, 'album_name': 'What A Time To Be Alive'}, {'pos': 1, 'artist_name': 'Drake', 'track_uri': 'spotify:track:2AGottAzfC8bHzF7kEJ3Wa', 'artist_uri': 'spotify:artist:3TVXtAsR1Inumwj472S9r4', 'track_name': 'Diamonds Dancing', 'album_uri': 'spotify:album:1ozpmkWcCHwsQ4QTnxOOdT', 'duration_ms': 314631, 'album_name': 'What A Time To Be Alive'}, {'pos': 2, 'artist_name': 'Drake', 'track_uri': 'spotify:track:27GmP9AWRs744SzKcpJsTZ', 'artist_uri': 'spotify:artist:3TVXtAsR1Inumwj472S9r4', 'track_name': 'Jumpman', 'album_uri': 'spotify:album:1ozpmkWcCHwsQ4QTnxOOdT', 'duration_ms': 205879, 'album_name': 'What A 

In [8]:
"""
Here we define classes for the data that we are going to load. The data is stored in JSON files, each
which contain playlists, which themselves contain tracks. Thus, we define three classes:
  Track       --> contains information for a specific track (its id, name, etc.)
  Playlist    --> contains information for a specific playlist (its id, name, etc. as well as a list of Tracks)
  JSONFile    --> contains the loaded json file and stores a dictionary of all of the Playlists

Note: if we were to use the artist information, we could make an Artist class
"""

class Track:
  """
  Simple class for a track, containing its attributes:
    1. URI (a unique id)
    2. Name
    3. Artist info (URI and name)
    4. Parent playlist
  """

  def __init__(self, track_dict, playlist):
    self.uri = track_dict["track_uri"]
    self.name = track_dict["track_name"]
    self.artist_uri = track_dict["artist_uri"]
    self.artist_name = track_dict["artist_name"]
    self.album_uri = track_dict["album_uri"]
    self.album_name = track_dict["album_name"]
    self.playlist = playlist

  def __str__(self):
    return f"Track {self.uri} called {self.name} by {self.artist_uri} ({self.artist_name}) and in {self.album_uri} | {self.album_name} in playlist {self.playlist}."

  def __repr__(self):
    return f"Track {self.uri}"

class Playlist:
  """
  Simple class for a playlist, containing its attributes:
    1. Name (playlist and its associated index)
    2. Title (playlist title in the Spotify dataset)
    3. Loaded dictionary from the raw json for the playlist
    4. Dictionary of tracks (track_uri : Track), populated by .load_tracks()
    5. List of artists uris
  """

  def __init__(self, json_data, index):

    self.name = f"playlist_{index}"
    self.title = json_data["name"]
    self.data = json_data

    self.tracks = {}

  def load_tracks(self):
    """ Call this function to load all of the tracks in the json data for the playlist."""

    tracks_list = self.data["tracks"]
    self.tracks = {x["track_uri"] : Track(x, self.name) for x in tracks_list}


  def __str__(self):
    return f"Playlist {self.name} with {len(self.tracks)} tracks loaded."

  def __repr__(self):
    return f"Playlist {self.name}"

class JSONFile:
  """
  Simple class for a JSON file, containing its attributes:
    1. File Name
    2. Index to begin numbering playlists at
    3. Loaded dictionary from the raw json for the full file
    4. Dictionary of playlists (name : Playlist), populated by .process_file()
  """

  def __init__(self, data_path, file_name, start_index):

    self.file_name = file_name
    self.start_index = start_index

    with open(join(data_path, file_name)) as json_file:
      json_data = json.load(json_file)
    self.data = json_data

    self.playlists = {}

  def process_file(self):
    """ Call this function to load all of the playlists in the json data."""

    for i, playlist_json in enumerate(self.data["playlists"]):
      playlist = Playlist(playlist_json, self.start_index + i)
      playlist.load_tracks()
      self.playlists[playlist.name] = playlist

  def __str__(self):
    return f"JSON {self.file_name} has {len(self.playlists)} playlists loaded."

  def __repr__(self):
    return self.file_name


In [9]:
DATA_PATH = Data_Path('spotify_million_playlist_dataset/data')
N_FILES_TO_USE = 50

file_names = sorted(os.listdir(DATA_PATH))
file_names_to_use = file_names[:N_FILES_TO_USE]

n_playlists = 0

# load each json file, and store it in a list of files
JSONs = []
for file_name in tqdm(file_names_to_use, desc='Files processed: ', unit='files', total=len(file_names_to_use)):
  json_file = JSONFile(DATA_PATH, file_name, n_playlists)
  json_file.process_file()
  n_playlists += len(json_file.playlists)
  JSONs.append(json_file)

Files processed:   0%|          | 0/50 [00:00<?, ?files/s]

In [10]:
playlist_data = {}
track_data = []
playlists = []
tracks = []

# build list of all unique playlists, tracks
for json_file in tqdm(JSONs):
  playlists += [p.name for p in json_file.playlists.values()]
  tracks += [track.uri for playlist in json_file.playlists.values() for track in list(playlist.tracks.values())]
  playlist_data = playlist_data | json_file.playlists

  0%|          | 0/50 [00:00<?, ?it/s]

In [11]:
len(playlists), len(set(tracks)), len(tracks)

(50000, 457016, 3303932)

In [12]:
## create graph from these lists

# adding nodes
G = nx.Graph()
G.add_nodes_from([
    (p, {'name':p, "node_type" : "playlist"}) for p in playlists
])
G.add_nodes_from([
    (t, {'name':t, "node_type" : "track"}) for t in tracks
])
# add node types of track to album and artist
G.add_nodes_from([
    (track.album_uri, {'name':track.album_uri, "node_type" : "album"}) for p_name, playlist in playlist_data.items() for track in playlist.tracks.values()
])
G.add_nodes_from([
    (track.artist_uri, {'name':track.artist_uri, "node_type" : "artist"}) for p_name, playlist in playlist_data.items() for track in playlist.tracks.values()   
])

# adding edges
track_edge_list = []
album_edge_list = []
artist_edge_list = []
for p_name, playlist in playlist_data.items():
  for track in playlist.tracks.values():
    track_edge_list.append((p_name, track.uri))
    album_edge_list.append((track.uri, track.album_uri))
    artist_edge_list.append((track.uri, track.artist_uri))

G.add_edges_from(track_edge_list, edge_types="track_in_playlist")
G.add_edges_from(album_edge_list, edge_types="track_in_album")
G.add_edges_from(artist_edge_list, edge_types="track_by_artist")

print('Num nodes:', G.number_of_nodes(), '. Num edges:', G.number_of_edges())

Num nodes: 779017 . Num edges: 4217964


In [13]:
from collections import Counter

cnt = Counter([d["node_type"] for (_, d) in G.nodes(data=True)])
print(cnt)


Counter({'track': 457016, 'album': 192812, 'artist': 79189, 'playlist': 50000})


In [14]:
cmap_theme = "Set1"

# Graph for Visualization (N=20)

In [23]:
import random
random.seed(seed)

In [25]:

# -------- CONFIG --------
NUM_PLAYLISTS_TO_SAMPLE = 20

# -------- STEP 1: sample playlists --------
playlist_nodes = [
    n for n, d in G.nodes(data=True) if d["node_type"] == "playlist"
]

sampled_playlists = random.sample(playlist_nodes, NUM_PLAYLISTS_TO_SAMPLE)

sub_nodes = set(sampled_playlists)

# -------- STEP 2: playlist → track --------
for p in sampled_playlists:
    for neigh in G.neighbors(p):
        if G.nodes[neigh]["node_type"] == "track":
            sub_nodes.add(neigh)

# -------- STEP 3: track → album & artist ONLY --------
for t in list(sub_nodes):
    if G.nodes[t]["node_type"] == "track":
        for neigh in G.neighbors(t):
            if G.nodes[neigh]["node_type"] in ["album", "artist"]:
                sub_nodes.add(neigh)

# -------- STEP 4: build subgraph --------
sub_G_vis = G.subgraph(sub_nodes).copy()

print("Nodes:", sub_G_vis.number_of_nodes())
print("Edges:", sub_G_vis.number_of_edges())


Nodes: 2721
Edges: 3746


In [None]:
from pyvis.network import Network

net = Network(height="800px", width="100%", notebook=True)
net.force_atlas_2based()   # fast, stable layout

color_map = {
    "playlist": "#ff3333",
    "track": "#ff9933",
    "album": "#9933ff",
    "artist": "#999999"
}

for node, data in sub_G_vis.nodes(data=True):
    net.add_node(
        node,
        label=data["node_type"],        # node type displayed
        title=node,                     # hover shows node ID
        color=color_map[data["node_type"]],
        size=8
    )

for u, v in sub_G_vis.edges():
    net.add_edge(u, v, color="#cccccc")

net.show("20_subgraph.html")


100_subgraph.html


# Actual Graph to use for training N=10k

In [26]:
import random
# -------- CONFIG --------
NUM_PLAYLISTS_TO_SAMPLE = 10000

# -------- STEP 1: sample playlists --------
playlist_nodes = [
    n for n, d in G.nodes(data=True) if d["node_type"] == "playlist"
]

sampled_playlists = random.sample(playlist_nodes, NUM_PLAYLISTS_TO_SAMPLE)

sub_nodes = set(sampled_playlists)

# -------- STEP 2: playlist → track --------
for p in sampled_playlists:
    for neigh in G.neighbors(p):
        if G.nodes[neigh]["node_type"] == "track":
            sub_nodes.add(neigh)

# -------- STEP 3: track → album & artist ONLY --------
for t in list(sub_nodes):
    if G.nodes[t]["node_type"] == "track":
        for neigh in G.neighbors(t):
            if G.nodes[neigh]["node_type"] in ["album", "artist"]:
                sub_nodes.add(neigh)

# -------- STEP 4: build subgraph --------
sub_G = G.subgraph(sub_nodes).copy()

print("Nodes:", sub_G.number_of_nodes())
print("Edges:", sub_G.number_of_edges())


Nodes: 299372
Edges: 1001704


In [38]:
from collections import Counter

cnt = Counter([d["node_type"] for (_, d) in sub_G.nodes(data=True)])
print(cnt)


Counter({'track': 171855, 'album': 81720, 'artist': 35797, 'playlist': 10000})


# Save Graph for regular use

In [33]:
with open("10K_playlist_graph.pkl", "wb") as f:
    pickle.dump(sub_G, f)

# Load Graph 

In [36]:
# Note if you've already generated the graph above, you can skip those steps, and simply run set reload to True!
reload = False
if reload:
  sub_G = pickle.load(open("10K_playlist_graph.pkl", "rb"))
print('Num nodes:', sub_G.number_of_nodes(), '. Num edges:', sub_G.number_of_edges())

Num nodes: 299372 . Num edges: 1001704


In [47]:
import torch
from torch_geometric.data import HeteroData

def nx_to_heterodata(G):
    data = HeteroData()

    # ---- 1. Map nodes per type ----
    node_maps = {}
    for node, attr in G.nodes(data=True):
        ntype = attr["node_type"]
        if ntype not in node_maps:
            node_maps[ntype] = {}
        node_maps[ntype][node] = len(node_maps[ntype])

    # ---- 2. Set num_nodes per node type ----
    for ntype, idmap in node_maps.items():
        data[ntype].num_nodes = len(idmap)

    # ---- 3. Collect edges grouped by (src_type, rel, dst_type) ----
    edge_groups = {}

    for u, v, attr in G.edges(data=True):
        rel = attr["edge_types"]               # your key
        src_t = G.nodes[u]["node_type"]
        dst_t = G.nodes[v]["node_type"]

        src_id = node_maps[src_t][u]
        dst_id = node_maps[dst_t][v]

        edge_type = (src_t, rel, dst_t)

        if edge_type not in edge_groups:
            edge_groups[edge_type] = []

        edge_groups[edge_type].append([src_id, dst_id])

    # ---- 4. Write to PyG ----
    for etype, edges in edge_groups.items():
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        data[etype].edge_index = edge_index

    return data


In [48]:
hetero_data = nx_to_heterodata(sub_G)
print(hetero_data)


HeteroData(
  track={ num_nodes=171855 },
  album={ num_nodes=81720 },
  playlist={ num_nodes=10000 },
  artist={ num_nodes=35797 },
  (track, track_in_playlist, playlist)={ edge_index=[2, 329839] },
  (track, track_in_album, album)={ edge_index=[2, 85611] },
  (track, track_by_artist, artist)={ edge_index=[2, 87106] },
  (album, track_in_album, track)={ edge_index=[2, 86244] },
  (playlist, track_in_playlist, track)={ edge_index=[2, 328155] },
  (artist, track_by_artist, track)={ edge_index=[2, 84749] }
)


In [65]:
edge_types = hetero_data.edge_types
edge_types

[('track', 'track_in_playlist', 'playlist'),
 ('track', 'track_in_album', 'album'),
 ('track', 'track_by_artist', 'artist'),
 ('album', 'track_in_album', 'track'),
 ('playlist', 'track_in_playlist', 'track'),
 ('artist', 'track_by_artist', 'track')]

In [66]:


# convert to train/val/test splits
transform = RandomLinkSplit(
    is_undirected=True,
    add_negative_train_samples=False,
    neg_sampling_ratio=0,
    num_val=0.15, num_test=0.15,
    edge_types=edge_types
)
train_split, val_split, test_split = transform(hetero_data)

In [69]:
edge_type = hetero_data.edge_types[0]
hetero_data[edge_type].edge_index

tensor([[     0,      0,      0,  ..., 171333, 171468, 171623],
        [  4555,   1326,   6228,  ...,   9984,   9993,   9994]])

In [64]:
edge_type, hetero_data.edge_types

(('track', 'track_in_playlist', 'playlist'),
 [('track', 'track_in_playlist', 'playlist'),
  ('track', 'track_in_album', 'album'),
  ('track', 'track_by_artist', 'artist'),
  ('album', 'track_in_album', 'track'),
  ('playlist', 'track_in_playlist', 'track'),
  ('artist', 'track_by_artist', 'track')])

In [78]:
# note these are stored as float32, we need them to be int64 for future training

# Edge index: message passing edges
for edge_type in train_split.edge_types:
    # Convert message-passing edges
    train_split[edge_type].edge_index = train_split[edge_type].edge_index.long()
    val_split[edge_type].edge_index   = val_split[edge_type].edge_index.long()
    test_split[edge_type].edge_index  = test_split[edge_type].edge_index.long()

    # Convert label edges (if present)
    if "edge_label_index" in train_split[edge_type]:
        train_split[edge_type].edge_label_index = train_split[edge_type].edge_label_index.long()
    if "edge_label_index" in val_split[edge_type]:
        val_split[edge_type].edge_label_index   = val_split[edge_type].edge_label_index.long()
    if "edge_label_index" in test_split[edge_type]:
        test_split[edge_type].edge_label_index  = test_split[edge_type].edge_label_index.long()


    print(f"Train set has {train_split[edge_type].edge_label_index.shape[1]} positives upervision edges")
    print(f"Validation set has {val_split[edge_type].edge_label_index.shape[1]} positive supervision edges")
    print(f"Test set has {test_split[edge_type].edge_label_index.shape[1]} positive supervision edges")

    print(f"Train set has {train_split[edge_type].edge_index.shape[1]} message passing edges")
    print(f"Validation set has {val_split[edge_type].edge_index.shape[1]} message passing edges")
    print(f"Test set has {test_split[edge_type].edge_index.shape[1]} message passing edges")

Train set has 230889 positives upervision edges
Validation set has 49475 positive supervision edges
Test set has 49475 positive supervision edges
Train set has 230889 message passing edges
Validation set has 230889 message passing edges
Test set has 280364 message passing edges
Train set has 59929 positives upervision edges
Validation set has 12841 positive supervision edges
Test set has 12841 positive supervision edges
Train set has 59929 message passing edges
Validation set has 59929 message passing edges
Test set has 72770 message passing edges
Train set has 60976 positives upervision edges
Validation set has 13065 positive supervision edges
Test set has 13065 positive supervision edges
Train set has 60976 message passing edges
Validation set has 60976 message passing edges
Test set has 74041 message passing edges
Train set has 60372 positives upervision edges
Validation set has 12936 positive supervision edges
Test set has 12936 positive supervision edges
Train set has 60372 messag

In [None]:
def RGCN(torch.nn.Module):
    

In [28]:
class GCN(torch.nn.Module):
    """
      Here we adapt the LightGCN model from Torch Geometric for our purposes. We allow
      for customizable convolutional layers, custom embeddings. In addition, we deifne some
      additional custom functions.

    """

    def __init__(
        self,
        num_nodes: int,
        embedding_dim: int,
        num_layers: int,
        alpha: Optional[Union[float, Tensor]] = None,
        alpha_learnable = False,
        conv_layer = "LGC",
        name = None,
        **kwargs,
    ):
        super().__init__()
        alpha_string = "alpha" if alpha_learnable else ""
        self.name = f"LGCN_{conv_layer}_{num_layers}_e{embedding_dim}_nodes{num_nodes}_{alpha_string}"
        self.num_nodes = num_nodes
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        if alpha_learnable == True:
          alpha_vals = torch.rand(num_layers+1)
          alpha = nn.Parameter(alpha_vals/torch.sum(alpha_vals))
          print(f"Alpha learnable, initialized to: {alpha.softmax(dim=-1)}")
        else:
          if alpha is None:
              alpha = 1. / (num_layers + 1)

          if isinstance(alpha, Tensor):
              assert alpha.size(0) == num_layers + 1
          else:
              alpha = torch.tensor([alpha] * (num_layers + 1))

        self.register_buffer('alpha', alpha)

        self.embedding = Embedding(num_nodes, embedding_dim)

        # initialize convolutional layers
        self.conv_layer = conv_layer
        if conv_layer == "LGC":
          self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)])
        elif conv_layer == "GAT":
          # initialize Graph Attention layer with multiple heads
          # initialize linear layers to aggregate heads
          n_heads = 5
          self.convs = ModuleList(
              [GATConv(in_channels = embedding_dim, out_channels = embedding_dim, heads = n_heads, dropout = 0.5, **kwargs) for _ in range(num_layers)]
          )
          self.linears = ModuleList([Linear(n_heads * embedding_dim, embedding_dim) for _ in range(num_layers)])

        elif conv_layer == "SAGE":
          #  initialize GraphSAGE conv
          self.convs = ModuleList(
              [SAGEConv(in_channels = embedding_dim, out_channels = embedding_dim, **kwargs) for _ in range(num_layers)]
          )

        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.embedding.weight)
        for conv in self.convs:
            conv.reset_parameters()

    def get_embedding(self, edge_index: Adj) -> Tensor:
        x = self.embedding.weight

        weights = self.alpha.softmax(dim=-1)
        out = x * weights[0]

        for i in range(self.num_layers):
            x = self.convs[i](x, edge_index)
            if self.conv_layer == "GAT":
              x = self.linears[i](x)
            out = out + x * weights[i + 1]

        return out

    def initialize_embeddings(self, data):
      # initialize with the data node features
        self.embedding.weight.data.copy_(data.node_feature)


    def forward(self, edge_index: Adj,
                edge_label_index: OptTensor = None) -> Tensor:
        if edge_label_index is None:
            if isinstance(edge_index, SparseTensor):
                edge_label_index = torch.stack(edge_index.coo()[:2], dim=0)
            else:
                edge_label_index = edge_index

        out = self.get_embedding(edge_index)

        return self.predict_link_embedding(out, edge_label_index)

    def predict_link(self, edge_index: Adj, edge_label_index: OptTensor = None,
                     prob: bool = False) -> Tensor:

        pred = self(edge_index, edge_label_index).sigmoid()
        return pred if prob else pred.round()

    def predict_link_embedding(self, embed: Adj, edge_label_index: Adj) -> Tensor:

        embed_src = embed[edge_label_index[0]]
        embed_dst = embed[edge_label_index[1]]
        return (embed_src * embed_dst).sum(dim=-1)


    def recommend(self, edge_index: Adj, src_index: OptTensor = None,
                  dst_index: OptTensor = None, k: int = 1) -> Tensor:
        out_src = out_dst = self.get_embedding(edge_index)

        if src_index is not None:
            out_src = out_src[src_index]

        if dst_index is not None:
            out_dst = out_dst[dst_index]

        pred = out_src @ out_dst.t()
        top_index = pred.topk(k, dim=-1).indices

        if dst_index is not None:  # Map local top-indices to original indices.
            top_index = dst_index[top_index.view(-1)].view(*top_index.size())

        return top_index


    def link_pred_loss(self, pred: Tensor, edge_label: Tensor,
                       **kwargs) -> Tensor:
        loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)
        return loss_fn(pred, edge_label.to(pred.dtype))


    def recommendation_loss(self, pos_edge_rank: Tensor, neg_edge_rank: Tensor,
                            lambda_reg: float = 1e-4, **kwargs) -> Tensor:
        r"""Computes the model loss for a ranking objective via the Bayesian
        Personalized Ranking (BPR) loss."""
        loss_fn = BPRLoss(lambda_reg, **kwargs)
        return loss_fn(pos_edge_rank, neg_edge_rank, self.embedding.weight)

    def bpr_loss(self, pos_scores, neg_scores):
      return - torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.num_nodes}, '
                f'{self.embedding_dim}, num_layers={self.num_layers})')


In [29]:
class BPRLoss(_Loss):
    r"""The Bayesian Personalized Ranking (BPR) loss.

    The BPR loss is a pairwise loss that encourages the prediction of an
    observed entry to be higher than its unobserved counterparts
    (see `here <https://arxiv.org/abs/2002.02126>`__).

    .. math::
        L_{\text{BPR}} = - \sum_{u=1}^{M} \sum_{i \in \mathcal{N}_u}
        \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj})
        + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2

    where :math:`lambda` controls the :math:`L_2` regularization strength.
    We compute the mean BPR loss for simplicity.

    Args:
        lambda_reg (float, optional): The :math:`L_2` regularization strength
            (default: 0).
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch.nn.modules.loss._Loss` class.
    """
    __constants__ = ['lambda_reg']
    lambda_reg: float

    def __init__(self, lambda_reg: float = 0, **kwargs):
        super().__init__(None, None, "sum", **kwargs)
        self.lambda_reg = lambda_reg

    def forward(self, positives: Tensor, negatives: Tensor,
                parameters: Tensor = None) -> Tensor:
        r"""Compute the mean Bayesian Personalized Ranking (BPR) loss.

        .. note::

            The i-th entry in the :obj:`positives` vector and i-th entry
            in the :obj:`negatives` entry should correspond to the same
            entity (*.e.g*, user), as the BPR is a personalized ranking loss.

        Args:
            positives (Tensor): The vector of positive-pair rankings.
            negatives (Tensor): The vector of negative-pair rankings.
            parameters (Tensor, optional): The tensor of parameters which
                should be used for :math:`L_2` regularization
                (default: :obj:`None`).
        """
        n_pairs = positives.size(0)
        log_prob = F.logsigmoid(positives - negatives).sum()
        regularization = 0

        if self.lambda_reg != 0:
            regularization = self.lambda_reg * parameters.norm(p=2).pow(2)

        return (-log_prob + regularization) / n_pairs

In [30]:
def sample_negative_edges_nocheck(data, num_playlists, num_tracks, device = None):
  # note computationally inefficient to check that these are indeed negative edges
  playlists = data.edge_label_index[0, :]
  tracks = torch.randint(num_playlists, num_playlists + num_tracks - 1, size = data.edge_label_index[1, :].size())

  if playlists.get_device() != -1: # on gpu
    tracks = tracks.to(device)

  neg_edge_index = torch.stack((playlists, tracks), dim = 0)
  neg_edge_label = torch.zeros(neg_edge_index.shape[1])

  if neg_edge_index.get_device() != -1: # on gpu
    neg_edge_label = neg_edge_label.to(device)

  return neg_edge_index, neg_edge_label

def sample_negative_edges(data, num_playlists, num_tracks, device=None):
    positive_playlists, positive_tracks = data.edge_label_index

    # Create a mask tensor with the shape (num_playlists, num_tracks)
    mask = torch.zeros(num_playlists, num_tracks, device=device, dtype=torch.bool)
    mask[positive_playlists, positive_tracks - num_playlists] = True

    # Flatten the mask tensor and get the indices of the negative edges
    flat_mask = mask.flatten()
    negative_indices = torch.where(~flat_mask)[0]

    # Sample negative edges from the negative_indices tensor
    sampled_negative_indices = negative_indices[
        torch.randint(0, negative_indices.size(0), size=(positive_playlists.size(0),), device=device)
    ]

    # Convert the indices back to playlists and tracks tensors
    playlists = torch.floor_divide(sampled_negative_indices, num_tracks)
    tracks = torch.remainder(sampled_negative_indices, num_tracks)
    tracks = tracks + num_playlists

    neg_edge_index = torch.stack((playlists, tracks), dim=0)
    neg_edge_label = torch.zeros(neg_edge_index.shape[1], device=device)

    return neg_edge_index, neg_edge_label

def sample_hard_negative_edges(data, model, num_playlists, num_tracks, device=None, batch_size=500, frac_sample = 1):
    with torch.no_grad():
        embeddings = model.get_embedding(data.edge_index)
        playlists_embeddings = embeddings[:num_playlists].to(device)
        tracks_embeddings = embeddings[num_playlists:].to(device)

    positive_playlists, positive_tracks = data.edge_label_index
    num_edges = positive_playlists.size(0)

    # Create a boolean mask for all the positive edges
    positive_mask = torch.zeros(num_playlists, num_tracks, device=device, dtype=torch.bool)
    positive_mask[positive_playlists, positive_tracks - num_playlists] = True

    neg_edges_list = []
    neg_edge_label_list = []

    for batch_start in range(0, num_edges, batch_size):
        batch_end = min(batch_start + batch_size, num_edges)

        batch_scores = torch.matmul(
            playlists_embeddings[positive_playlists[batch_start:batch_end]], tracks_embeddings.t()
        )

        # Set the scores of the positive edges to negative infinity
        batch_scores[positive_mask[positive_playlists[batch_start:batch_end]]] = -float("inf")

        # Select the top k highest scoring negative edges for each playlist in the current batch
        # do 0.99 to filter out all pos edges which will be at the end
        _, top_indices = torch.topk(batch_scores, int(frac_sample * 0.99 * num_tracks), dim=1)
        selected_indices = torch.randint(0, int(frac_sample * 0.99 *num_tracks), size = (batch_end - batch_start, ))
        top_indices_selected = top_indices[torch.arange(batch_end - batch_start), selected_indices] + n_playlists

        # Create the negative edges tensor for the current batch
        neg_edges_batch = torch.stack(
            (positive_playlists[batch_start:batch_end], top_indices_selected), dim=0
        )
        neg_edge_label_batch = torch.zeros(neg_edges_batch.shape[1], device=device)

        neg_edges_list.append(neg_edges_batch)
        neg_edge_label_list.append(neg_edge_label_batch)

    # Concatenate the batch tensors
    neg_edges = torch.cat(neg_edges_list, dim=1)
    neg_edge_label = torch.cat(neg_edge_label_list)

    return neg_edges, neg_edge_label

In [31]:
def recall_at_k(data, model, k = 300, batch_size = 64, device = None):
    with torch.no_grad():
        embeddings = model.get_embedding(data.edge_index)
        playlists_embeddings = embeddings[:n_playlists]
        tracks_embeddings = embeddings[n_playlists:]

    hits_list = []
    relevant_counts_list = []

    for batch_start in range(0, n_playlists, batch_size):
        batch_end = min(batch_start + batch_size, n_playlists)
        batch_playlists_embeddings = playlists_embeddings[batch_start:batch_end]

        # Calculate scores for all possible item pairs
        scores = torch.matmul(batch_playlists_embeddings, tracks_embeddings.t())

        # Set the scores of message passing edges to negative infinity
        mp_indices = ((data.edge_index[0] >= batch_start) & (data.edge_index[0] < batch_end)).nonzero(as_tuple=True)[0]
        scores[data.edge_index[0, mp_indices] - batch_start, data.edge_index[1, mp_indices] - n_playlists] = -float("inf")

        # Find the top k highest scoring items for each playlist in the batch
        _, top_k_indices = torch.topk(scores, k, dim=1)

        # Ground truth supervision edges
        ground_truth_edges = data.edge_label_index

        # Create a mask to indicate if the top k items are in the ground truth supervision edges
        mask = torch.zeros(scores.shape, device=device, dtype=torch.bool)
        gt_indices = ((ground_truth_edges[0] >= batch_start) & (ground_truth_edges[0] < batch_end)).nonzero(as_tuple=True)[0]
        mask[ground_truth_edges[0, gt_indices] - batch_start, ground_truth_edges[1, gt_indices] - n_playlists] = True

        # Check how many of the top k items are in the ground truth supervision edges
        hits = mask.gather(1, top_k_indices).sum(dim=1)
        hits_list.append(hits)

        # Calculate the total number of relevant items for each playlist in the batch
        relevant_counts = torch.bincount(ground_truth_edges[0, gt_indices] - batch_start, minlength=batch_end - batch_start)
        relevant_counts_list.append(relevant_counts)

    # Compute recall@k
    hits_tensor = torch.cat(hits_list, dim=0)
    relevant_counts_tensor = torch.cat(relevant_counts_list, dim=0)
    # Handle division by zero case
    recall_at_k = torch.where(
        relevant_counts_tensor != 0,
        hits_tensor.true_divide(relevant_counts_tensor),
        torch.ones_like(hits_tensor)
    )
    # take average
    recall_at_k = torch.mean(recall_at_k)

    if recall_at_k.numel() == 1:
        return recall_at_k.item()
    else:
        raise ValueError("recall_at_k contains more than one item.")

In [32]:
def metrics(labels, preds):
  roc = roc_auc_score(labels.flatten().cpu().numpy(), preds.flatten().data.cpu().numpy())
  return roc

In [33]:
# Train
def train(datasets, model, optimizer, loss_fn, args, neg_samp = "random"):
  print(f"Beginning training for {model.name}")

  train_data = datasets["train"]
  val_data = datasets["val"]

  stats = {
      'train': {
        'loss': [],
        'roc' : []
      },
      'val': {
        'loss': [],
        'recall': [],
        'roc' : []
      }

  }
  val_neg_edge, val_neg_label = None, None
  for epoch in range(args["epochs"]): # loop over each epoch
    model.train()
    optimizer.zero_grad()

    # obtain negative sample
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(train_data, n_playlists, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            train_data, model, n_playlists, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # calculate embedding
    embed = model.get_embedding(train_data.edge_index)
    # calculate pos, negative scores using embedding
    pos_scores = model.predict_link_embedding(embed, train_data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)

    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((train_data.edge_label, neg_edge_label), dim = 0)

    # calculate loss function
    if loss_fn == "BCE":
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    train_roc = metrics(labels, scores)

    loss.backward()
    optimizer.step()

    val_loss, val_roc, val_neg_edge, val_neg_label = test(
        model, val_data, loss_fn, neg_samp, epoch, val_neg_edge, val_neg_label
    )

    stats['train']['loss'].append(loss)
    stats['train']['roc'].append(train_roc)
    stats['val']['loss'].append(val_loss)
    stats['val']['roc'].append(val_roc)

    print(f"Epoch {epoch}; Train loss {loss}; Val loss {val_loss}; Train ROC {train_roc}; Val ROC {val_roc}")

    if epoch % 10 == 0:
      # calculate recall @ K
      val_recall = recall_at_k(val_data, model, k = 300, device = args["device"])
      print(f"Val recall {val_recall}")
      stats['val']['recall'].append(val_recall)

    if epoch % 20 == 0:

      # save embeddings for future visualization
      path = os.path.join("model_embeddings", model.name)
      if not os.path.exists(path):
        os.makedirs(path)
      torch.save(model.embedding.weight, os.path.join("model_embeddings", model.name, f"{model.name}_{loss_fn}_{neg_samp}_{epoch}.pt"))

  pickle.dump(stats, open(f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pkl", "wb"))
  return stats

def test(model, data, loss_fn, neg_samp, epoch = 0, neg_edge_index = None, neg_edge_label = None):

  model.eval()
  with torch.no_grad(): # want to save RAM

    # conduct negative sampling
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(data, n_playlists, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0 or neg_edge_index is None:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            data, model, n_playlists, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # obtain model embedding
    embed = model.get_embedding(data.edge_index)
    # calculate pos, neg scores using embedding
    pos_scores = model.predict_link_embedding(embed, data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)
    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((data.edge_label, neg_edge_label), dim = 0)
    # calculate loss
    if loss_fn == "BCE":
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    roc = metrics(labels, scores)

  return loss, roc, neg_edge_index, neg_edge_label

In [34]:
# create a dictionary of the dataset splits
datasets = {
    'train':train_split,
    'val':val_split,
    'test': test_split
}

In [35]:
# initialize our arguments
args = {
    'device' : 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_layers' :  3,
    'emb_size' : 64,
    'weight_decay': 1e-5,
    'lr': 0.01,
    'epochs': 301
}

In [36]:
# initialize model and and optimizer
num_nodes = n_playlists + n_tracks
model = GCN(
    num_nodes = num_nodes, num_layers = args['num_layers'],
    embedding_dim = args["emb_size"], conv_layer = "SAGE"
)
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

In [37]:
# send data, model to GPU if available
playlists_idx = torch.Tensor(playlists_idx).type(torch.int64).to(args["device"])
tracks_idx =torch.Tensor(tracks_idx).type(torch.int64).to(args["device"])
datasets['train'].to(args['device'])
datasets['val'].to(args['device'])
datasets['test'].to(args['device'])
model.to(args["device"])

GCN(35300, 64, num_layers=3)

In [38]:
# create directory to save model_stats
MODEL_STATS_DIR = "model_stats"
if not os.path.exists(MODEL_STATS_DIR):
  os.makedirs(MODEL_STATS_DIR)

In [39]:
train(datasets, model, optimizer, "BPR", args, neg_samp = "random")

Beginning training for LGCN_SAGE_3_e64_nodes35300_
Epoch 0; Train loss 0.6927112936973572; Val loss 0.6699389219284058; Train ROC 0.5236944881664941; Val ROC 0.8081507123303305
Val recall 0.22803421318531036
Epoch 0; Train loss 0.6927112936973572; Val loss 0.6699389219284058; Train ROC 0.5236944881664941; Val ROC 0.8081507123303305
Val recall 0.22803421318531036
Epoch 1; Train loss 0.669262170791626; Val loss 0.5582274794578552; Train ROC 0.8147294370524052; Val ROC 0.7799351087426378
Epoch 1; Train loss 0.669262170791626; Val loss 0.5582274794578552; Train ROC 0.8147294370524052; Val ROC 0.7799351087426378
Epoch 2; Train loss 0.5527150630950928; Val loss 0.5579432845115662; Train ROC 0.7861143212979961; Val ROC 0.7705621662799603
Epoch 2; Train loss 0.5527150630950928; Val loss 0.5579432845115662; Train ROC 0.7861143212979961; Val ROC 0.7705621662799603
Epoch 3; Train loss 0.5422505736351013; Val loss 0.4192062318325043; Train ROC 0.775803129346007; Val ROC 0.8102440714623504
Epoch 4;

{'train': {'loss': [tensor(0.6927, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.6693, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.5527, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.5423, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.4046, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3855, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3664, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3786, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3814, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3602, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3473, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3488, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3505, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3465, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3423, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3409, device='cuda:0', grad_fn=<DivBackward0>),
   tensor(0.3399, devic

In [40]:
test(model, datasets['test'], "BPR", neg_samp = "random")

(tensor(0.1410, device='cuda:0'),
 0.9499165680823833,
 tensor([[12321, 14144,  3385,  ...,  5691, 21204,  1460],
         [29629, 32538, 35098,  ..., 35196, 30857, 25586]], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'))

In [41]:
def init_model(conv_layer, args, alpha = False):
  num_nodes = n_playlists + n_tracks
  model = GCN(
      num_nodes = num_nodes, num_layers = args['num_layers'],
      embedding_dim = args["emb_size"], conv_layer = conv_layer,
      alpha_learnable = alpha
  )
  model.to(args["device"])
  optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
  return model, optimizer

In [42]:
## For example:

# using BPR loss
loss_fn = "BPR"

# using hard sampling
neg_samp = "hard"

# for LGConv:
args['epochs'] = 301
args['num_layers'] = 4
model, optimizer = init_model("LGC", args)
lgc_stats_hard = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

Beginning training for LGCN_LGC_4_e64_nodes35300_
Epoch 0; Train loss 0.6931434273719788; Val loss 0.6931453943252563; Train ROC 0.6034073015044848; Val ROC 0.6759018905552148
Val recall 0.2589329481124878
Epoch 0; Train loss 0.6931434273719788; Val loss 0.6931453943252563; Train ROC 0.6034073015044848; Val ROC 0.6759018905552148
Val recall 0.2589329481124878
Epoch 1; Train loss 0.693144679069519; Val loss 0.6931245923042297; Train ROC 0.7589306002473327; Val ROC 0.7777647373441927
Epoch 2; Train loss 0.6931230425834656; Val loss 0.6929528117179871; Train ROC 0.8139802534992705; Val ROC 0.8013493128548925
Epoch 1; Train loss 0.693144679069519; Val loss 0.6931245923042297; Train ROC 0.7589306002473327; Val ROC 0.7777647373441927
Epoch 2; Train loss 0.6931230425834656; Val loss 0.6929528117179871; Train ROC 0.8139802534992705; Val ROC 0.8013493128548925
Epoch 3; Train loss 0.6929451823234558; Val loss 0.6922587156295776; Train ROC 0.8156420458589091; Val ROC 0.7996414969775105
Epoch 4; T

In [None]:
# for GATConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("GAT", args)
gat_stats_hard = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# for SAGEConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("SAGE", args)
sage_stats_hard = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# using random sampling
neg_samp = "random"

# for LGConv:
args['epochs'] = 301
args['num_layers'] = 4
model, optimizer = init_model("LGC", args)
lgc_stats = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# for GATConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("GAT", args)
gat_stats = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# for SAGEConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("SAGE", args)
sage_stats = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
def detach_loss(stats):
  return [loss.detach().cpu().numpy().item() for loss in stats]

def plot_train_val_loss(stats_dict):
  fig, ax = plt.subplots(1,1, figsize = (6, 4))
  train_loss = detach_loss(stats_dict["train"]["loss"])
  val_loss = detach_loss(stats_dict["val"]["loss"])
  idx = np.arange(0, len(train_loss), 1)
  ax.plot(idx, train_loss, label = "train")
  ax.plot(idx, val_loss, label = "val")
  ax.legend()
  plt.show()

In [None]:
# If you had to stop for whatever reason, you can always reload the stats here! (just uncomment and change to correct paths)
# lgc_stats = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_LGC_4_e64_nodes34810__BPR_random.pkl", "rb"))
# gat_stats = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_GAT_3_e64_nodes34810__BPR_random.pkl", "rb"))
# sage_stats = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_SAGE_3_e64_nodes34810__BPR_random.pkl", "rb"))
# lgc_stats_hard = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_LGC_4_e64_nodes34810__BPR_hard.pkl", "rb"))
# gat_stats_hard = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_GAT_3_e64_nodes34810__BPR_hard.pkl", "rb"))
# sage_stats_hard = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_SAGE_3_e64_nodes34810__BPR_hard.pkl", "rb"))

In [None]:
plot_train_val_loss(lgc_stats)
plot_train_val_loss(lgc_stats_hard)

In [None]:
plot_train_val_loss(sage_stats)
plot_train_val_loss(sage_stats_hard)

In [None]:
plot_train_val_loss(gat_stats)
plot_train_val_loss(gat_stats_hard)

In [None]:
fig, ax = plt.subplots(1,1, figsize = (8, 6))
key = "loss"
lgc_loss = pd.Series(detach_loss(lgc_stats["val"][key])).rolling(3).mean()
gat_loss = pd.Series(detach_loss(gat_stats["val"][key])).rolling(3).mean()
sage_loss = pd.Series(detach_loss(sage_stats["val"][key])).rolling(3).mean()
idx = np.arange(0, len(lgc_loss), 1)

colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
ax.plot(idx, lgc_loss, color = colors[0], linestyle = 'dashed', label = "LGC - random")
ax.plot(idx, gat_loss, color = colors[1], linestyle = 'dashed', label = "GAT - random")
ax.plot(idx, sage_loss, color = colors[2], linestyle = 'dashed', label = "SAGE - random")
ax.legend(loc = 'lower right')

ax.set_xlabel("Epochs")
ax.set_ylabel("BPR Loss")
ax.set_title("Model BPR Loss, by convolution type and negative sampling")
ax.set_ylim(0, 0.7)
plt.show()

In [None]:
fig, ax = plt.subplots(1,1, figsize = (8, 6))
key = "loss"
lgc_loss = pd.Series(detach_loss(lgc_stats["val"][key])).rolling(3).mean()
gat_loss = pd.Series(detach_loss(gat_stats["val"][key])).rolling(3).mean()
sage_loss = pd.Series(detach_loss(sage_stats["val"][key])).rolling(3).mean()
lgc_hard_loss = pd.Series(detach_loss(lgc_stats_hard["val"][key])).rolling(3).mean()
gat_hard_loss = pd.Series(detach_loss(gat_stats_hard["val"][key])).rolling(3).mean()
sage_hard_loss = pd.Series(detach_loss(sage_stats_hard["val"][key])).rolling(3).mean()
idx = np.arange(0, len(lgc_loss), 1)

colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
ax.plot(idx, lgc_loss, color = colors[0], linestyle = 'dashed', label = "LGC - random")
ax.plot(idx, lgc_hard_loss, color = colors[0], label = "LGC - hard")
ax.plot(idx, gat_loss, color = colors[1], linestyle = 'dashed', label = "GAT - random")
ax.plot(idx, gat_hard_loss, color = colors[1], label = "GAT - hard")
ax.plot(idx, sage_loss, color = colors[2], linestyle = 'dashed', label = "SAGE - random")
ax.plot(idx, sage_hard_loss, color = colors[2], label = "SAGE - hard")
ax.legend(loc = 'lower left')

ax.set_xlabel("Epochs")
ax.set_ylabel("BPR Loss")
ax.set_title("Model BPR Loss, by convolution type and negative sampling")
ax.set_ylim(0, 0.7)
plt.show()

In [None]:
fig, ax = plt.subplots(1,1, figsize = (8, 6))
key = "recall"
lgc_recall = lgc_stats["val"][key]
gat_recall = gat_stats["val"][key]
sage_recall = sage_stats["val"][key]
lgc_hard_recall = lgc_stats_hard["val"][key]
gat_hard_recall = gat_stats_hard["val"][key]
sage_hard_recall = sage_stats_hard["val"][key]
# increment by 10
idx = np.arange(0, 10 * len(lgc_recall), 10)

colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
ax.plot(idx, lgc_recall, color = colors[0], linestyle = 'dashed', label = "LGC - random")
ax.plot(idx, lgc_hard_recall, color = colors[0], label = "LGC - hard")
ax.plot(idx, gat_recall, color = colors[1], linestyle = 'dashed', label = "GAT - random")
ax.plot(idx, gat_hard_recall, color = colors[1], label = "GAT - hard")
ax.plot(idx, sage_recall, color = colors[2], linestyle = 'dashed', label = "SAGE - random")
ax.plot(idx, sage_hard_recall, color = colors[2], label = "SAGE - hard")
ax.legend(loc = 'lower right')

ax.set_xlabel("Epochs")
ax.set_ylabel("Recall@300")
ax.set_title("Model Recall@300, by convolution type and negative sampling")
ax.set_ylim(0, 0.7)
plt.show()