In [8]:
import torch
import torch_geometric

In [9]:
print(torch.__version__)
!nvcc --version

2.8.0+cu128
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0


In [10]:
# General libraries
import json
from pathlib import Path as Data_Path
import os
from os.path import isfile, join
import pickle
import random

import numpy as np
import networkx as nx
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import tqdm

In [11]:
# Import relevant ML libraries
from typing import Optional, Union

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Embedding, ModuleList, Linear
import torch.nn.functional as F

import torch_geometric
import torch_geometric.nn as pyg_nn
from torch_geometric.data import Data
from torch_geometric.transforms import RandomLinkSplit
from torch.nn.modules.loss import _Loss

from torch_geometric.nn.conv import LGConv, GATConv, SAGEConv
from torch_geometric.typing import Adj, OptTensor, SparseTensor

print(f"Torch version: {torch.__version__}; Torch-cuda version: {torch.version.cuda}; Torch Geometric version: {torch_geometric.__version__}.")

Torch version: 2.8.0+cu128; Torch-cuda version: 12.8; Torch Geometric version: 2.7.0.


In [12]:
# set the seed for reproducibility
seed = 224
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# `Loading the Data from JSON Files`

In [6]:
MAIN_DIR = "/home/DSE411/Documents/Tanishq/mlg-music-recommendation"
DATA_DIR = Data_Path('spotify_million_playlist_dataset/data')
os.chdir(MAIN_DIR)

In [7]:
with open(f"{DATA_DIR}/{os.listdir(DATA_DIR)[0]}") as jf:
  example_file = json.load(jf)

print(example_file['playlists'][0])

{'name': 'OVO', 'collaborative': 'false', 'pid': 241000, 'modified_at': 1493596800, 'num_tracks': 45, 'num_albums': 10, 'num_followers': 1, 'tracks': [{'pos': 0, 'artist_name': 'Drake', 'track_uri': 'spotify:track:7jslhIiELQkgW9IHeYNOWE', 'artist_uri': 'spotify:artist:3TVXtAsR1Inumwj472S9r4', 'track_name': 'Big Rings', 'album_uri': 'spotify:album:1ozpmkWcCHwsQ4QTnxOOdT', 'duration_ms': 217706, 'album_name': 'What A Time To Be Alive'}, {'pos': 1, 'artist_name': 'Drake', 'track_uri': 'spotify:track:2AGottAzfC8bHzF7kEJ3Wa', 'artist_uri': 'spotify:artist:3TVXtAsR1Inumwj472S9r4', 'track_name': 'Diamonds Dancing', 'album_uri': 'spotify:album:1ozpmkWcCHwsQ4QTnxOOdT', 'duration_ms': 314631, 'album_name': 'What A Time To Be Alive'}, {'pos': 2, 'artist_name': 'Drake', 'track_uri': 'spotify:track:27GmP9AWRs744SzKcpJsTZ', 'artist_uri': 'spotify:artist:3TVXtAsR1Inumwj472S9r4', 'track_name': 'Jumpman', 'album_uri': 'spotify:album:1ozpmkWcCHwsQ4QTnxOOdT', 'duration_ms': 205879, 'album_name': 'What A 

In [104]:
"""
Here we define classes for the data that we are going to load. The data is stored in JSON files, each
which contain playlists, which themselves contain tracks. Thus, we define three classes:
  Track       --> contains information for a specific track (its id, name, etc.)
  Playlist    --> contains information for a specific playlist (its id, name, etc. as well as a list of Tracks)
  JSONFile    --> contains the loaded json file and stores a dictionary of all of the Playlists

Note: if we were to use the artist information, we could make an Artist class
"""

class Track:
  """
  Simple class for a track, containing its attributes:
    1. URI (a unique id)
    2. Name
    3. Artist info (URI and name)
    4. Parent playlist
  """

  def __init__(self, track_dict, playlist):
    self.uri = track_dict["track_uri"]
    self.name = track_dict["track_name"]
    self.artist_uri = track_dict["artist_uri"]
    self.artist_name = track_dict["artist_name"]
    self.album_uri = track_dict["album_uri"]
    self.album_name = track_dict["album_name"]
    self.playlist = playlist

  def __str__(self):
    return f"Track {self.uri} called {self.name} by {self.artist_uri} ({self.artist_name}) and in {self.album_uri} | {self.album_name} in playlist {self.playlist}."

  def __repr__(self):
    return f"Track {self.uri}"

class Playlist:
  """
  Simple class for a playlist, containing its attributes:
    1. Name (playlist and its associated index)
    2. Title (playlist title in the Spotify dataset)
    3. Loaded dictionary from the raw json for the playlist
    4. Dictionary of tracks (track_uri : Track), populated by .load_tracks()
    5. List of artists uris
  """

  def __init__(self, json_data, index):

    self.name = f"playlist_{index}"
    self.title = json_data["name"]
    self.data = json_data

    self.tracks = {}

  def load_tracks(self):
    """ Call this function to load all of the tracks in the json data for the playlist."""

    tracks_list = self.data["tracks"]
    self.tracks = {x["track_uri"] : Track(x, self.name) for x in tracks_list}


  def __str__(self):
    return f"Playlist {self.name} with {len(self.tracks)} tracks loaded."

  def __repr__(self):
    return f"Playlist {self.name}"

class JSONFile:
  """
  Simple class for a JSON file, containing its attributes:
    1. File Name
    2. Index to begin numbering playlists at
    3. Loaded dictionary from the raw json for the full file
    4. Dictionary of playlists (name : Playlist), populated by .process_file()
  """

  def __init__(self, data_path, file_name, start_index):

    self.file_name = file_name
    self.start_index = start_index

    with open(join(data_path, file_name)) as json_file:
      json_data = json.load(json_file)
    self.data = json_data

    self.playlists = {}

  def process_file(self):
    """ Call this function to load all of the playlists in the json data."""

    for i, playlist_json in enumerate(self.data["playlists"]):
      playlist = Playlist(playlist_json, self.start_index + i)
      playlist.load_tracks()
      self.playlists[playlist.name] = playlist

  def __str__(self):
    return f"JSON {self.file_name} has {len(self.playlists)} playlists loaded."

  def __repr__(self):
    return self.file_name


In [105]:
DATA_PATH = Data_Path('spotify_million_playlist_dataset/data')
N_FILES_TO_USE = 50

file_names = sorted(os.listdir(DATA_PATH))
file_names_to_use = file_names[:N_FILES_TO_USE]

n_playlists = 0

# load each json file, and store it in a list of files
JSONs = []
for file_name in tqdm(file_names_to_use, desc='Files processed: ', unit='files', total=len(file_names_to_use)):
  json_file = JSONFile(DATA_PATH, file_name, n_playlists)
  json_file.process_file()
  n_playlists += len(json_file.playlists)
  JSONs.append(json_file)

Files processed:   0%|          | 0/50 [00:00<?, ?files/s]

In [106]:
playlist_data = {}
track_data = []
playlists = []
tracks = []

# build list of all unique playlists, tracks
for json_file in tqdm(JSONs):
  playlists += [p.name for p in json_file.playlists.values()]
  tracks += [track.uri for playlist in json_file.playlists.values() for track in list(playlist.tracks.values())]
  playlist_data = playlist_data | json_file.playlists

  0%|          | 0/50 [00:00<?, ?it/s]

In [107]:
len(playlists), len(set(tracks)), len(tracks)

(50000, 457016, 3303932)

In [110]:
n_tracks = len(set(tracks))

In [12]:
## create graph from these lists

# adding nodes
G = nx.Graph()
G.add_nodes_from([
    (p, {'name':p, "node_type" : "playlist"}) for p in playlists
])
G.add_nodes_from([
    (t, {'name':t, "node_type" : "track"}) for t in tracks
])
# add node types of track to album and artist
G.add_nodes_from([
    (track.album_uri, {'name':track.album_uri, "node_type" : "album"}) for p_name, playlist in playlist_data.items() for track in playlist.tracks.values()
])
G.add_nodes_from([
    (track.artist_uri, {'name':track.artist_uri, "node_type" : "artist"}) for p_name, playlist in playlist_data.items() for track in playlist.tracks.values()   
])

# adding edges
track_edge_list = []
album_edge_list = []
artist_edge_list = []
for p_name, playlist in playlist_data.items():
  for track in playlist.tracks.values():
    track_edge_list.append((p_name, track.uri))
    album_edge_list.append((track.uri, track.album_uri))
    artist_edge_list.append((track.uri, track.artist_uri))

G.add_edges_from(track_edge_list, edge_types="track_in_playlist")
G.add_edges_from(album_edge_list, edge_types="track_in_album")
G.add_edges_from(artist_edge_list, edge_types="track_by_artist")

print('Num nodes:', G.number_of_nodes(), '. Num edges:', G.number_of_edges())

KeyboardInterrupt: 

In [None]:
from collections import Counter

cnt = Counter([d["node_type"] for (_, d) in G.nodes(data=True)])
print(cnt)


Counter({'track': 457016, 'album': 192812, 'artist': 79189, 'playlist': 50000})


In [None]:
cmap_theme = "Set1"

# `Graph for Visualization (N=20)`

In [23]:
import random
random.seed(seed)

In [25]:

# -------- CONFIG --------
NUM_PLAYLISTS_TO_SAMPLE = 20

# -------- STEP 1: sample playlists --------
playlist_nodes = [
    n for n, d in G.nodes(data=True) if d["node_type"] == "playlist"
]

sampled_playlists = random.sample(playlist_nodes, NUM_PLAYLISTS_TO_SAMPLE)

sub_nodes = set(sampled_playlists)

# -------- STEP 2: playlist → track --------
for p in sampled_playlists:
    for neigh in G.neighbors(p):
        if G.nodes[neigh]["node_type"] == "track":
            sub_nodes.add(neigh)

# -------- STEP 3: track → album & artist ONLY --------
for t in list(sub_nodes):
    if G.nodes[t]["node_type"] == "track":
        for neigh in G.neighbors(t):
            if G.nodes[neigh]["node_type"] in ["album", "artist"]:
                sub_nodes.add(neigh)

# -------- STEP 4: build subgraph --------
sub_G_vis = G.subgraph(sub_nodes).copy()

print("Nodes:", sub_G_vis.number_of_nodes())
print("Edges:", sub_G_vis.number_of_edges())


Nodes: 2721
Edges: 3746


In [None]:
from pyvis.network import Network

net = Network(height="800px", width="100%", notebook=True)
net.force_atlas_2based()   # fast, stable layout

color_map = {
    "playlist": "#ff3333",
    "track": "#ff9933",
    "album": "#9933ff",
    "artist": "#999999"
}

for node, data in sub_G_vis.nodes(data=True):
    net.add_node(
        node,
        label=data["node_type"],        # node type displayed
        title=node,                     # hover shows node ID
        color=color_map[data["node_type"]],
        size=8
    )

for u, v in sub_G_vis.edges():
    net.add_edge(u, v, color="#cccccc")

net.show("20_subgraph.html")


100_subgraph.html


# `Actual Graph to use for training N=10k`

In [None]:
import random
# -------- CONFIG --------
NUM_PLAYLISTS_TO_SAMPLE = 10000

# -------- STEP 1: sample playlists --------
playlist_nodes = [
    n for n, d in G.nodes(data=True) if d["node_type"] == "playlist"
]

sampled_playlists = random.sample(playlist_nodes, NUM_PLAYLISTS_TO_SAMPLE)

sub_nodes = set(sampled_playlists)

# -------- STEP 2: playlist → track --------
for p in sampled_playlists:
    for neigh in G.neighbors(p):
        if G.nodes[neigh]["node_type"] == "track":
            sub_nodes.add(neigh)

# -------- STEP 3: track → album & artist ONLY --------
for t in list(sub_nodes):
    if G.nodes[t]["node_type"] == "track":
        for neigh in G.neighbors(t):
            if G.nodes[neigh]["node_type"] in ["album", "artist"]:
                sub_nodes.add(neigh)

# -------- STEP 4: build subgraph --------
sub_G = G.subgraph(sub_nodes).copy()

print("Nodes:", sub_G.number_of_nodes())
print("Edges:", sub_G.number_of_edges())


KeyboardInterrupt: 

In [None]:
from collections import Counter

cnt = Counter([d["node_type"] for (_, d) in sub_G.nodes(data=True)])
print(cnt)


Counter({'track': 171855, 'album': 81720, 'artist': 35797, 'playlist': 10000})


# `Save Graph for regular use`

In [33]:
with open("10K_playlist_graph.pkl", "wb") as f:
    pickle.dump(sub_G, f)

# `Load Graph `

In [13]:
# Note if you've already generated the graph above, you can skip those steps, and simply run set reload to True!
reload = True
if reload:
  sub_G = pickle.load(open("10K_playlist_graph.pkl", "rb"))
print('Num nodes:', sub_G.number_of_nodes(), '. Num edges:', sub_G.number_of_edges())

Num nodes: 299372 . Num edges: 1001704


# `Convert to Heterogenous Data Loader in PyG`

In [14]:
import torch
from torch_geometric.data import HeteroData

def nx_to_heterodata(G):
    data = HeteroData()

    # ---- 1. Map nodes per type ----
    node_maps = {}
    for node, attr in G.nodes(data=True):
        ntype = attr["node_type"]
        if ntype not in node_maps:
            node_maps[ntype] = {}
        node_maps[ntype][node] = len(node_maps[ntype])

    # ---- 2. Set num_nodes per node type ----
    for ntype, idmap in node_maps.items():
        data[ntype].num_nodes = len(idmap)

    # ---- 3. Collect edges grouped by (src_type, rel, dst_type) ----
    edge_groups = {}

    for u, v, attr in G.edges(data=True):
        rel = attr["edge_types"]               # your key
        src_t = G.nodes[u]["node_type"]
        dst_t = G.nodes[v]["node_type"]

        src_id = node_maps[src_t][u]
        dst_id = node_maps[dst_t][v]

        edge_type = (src_t, rel, dst_t)

        if edge_type not in edge_groups:
            edge_groups[edge_type] = []

        edge_groups[edge_type].append([src_id, dst_id])

    # ---- 4. Write to PyG ----
    for etype, edges in edge_groups.items():
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        data[etype].edge_index = edge_index

    return data


In [15]:
hetero_data = nx_to_heterodata(sub_G)
print(hetero_data)


HeteroData(
  track={ num_nodes=171855 },
  album={ num_nodes=81720 },
  playlist={ num_nodes=10000 },
  artist={ num_nodes=35797 },
  (track, track_in_playlist, playlist)={ edge_index=[2, 329839] },
  (track, track_in_album, album)={ edge_index=[2, 85611] },
  (track, track_by_artist, artist)={ edge_index=[2, 87106] },
  (album, track_in_album, track)={ edge_index=[2, 86244] },
  (playlist, track_in_playlist, track)={ edge_index=[2, 328155] },
  (artist, track_by_artist, track)={ edge_index=[2, 84749] }
)


In [16]:
edge_types = hetero_data.edge_types
edge_types

[('track', 'track_in_playlist', 'playlist'),
 ('track', 'track_in_album', 'album'),
 ('track', 'track_by_artist', 'artist'),
 ('album', 'track_in_album', 'track'),
 ('playlist', 'track_in_playlist', 'track'),
 ('artist', 'track_by_artist', 'track')]

# `Train/Val/Test Split via RandomLinkSplit`

In [17]:
edge_types

[('track', 'track_in_playlist', 'playlist'),
 ('track', 'track_in_album', 'album'),
 ('track', 'track_by_artist', 'artist'),
 ('album', 'track_in_album', 'track'),
 ('playlist', 'track_in_playlist', 'track'),
 ('artist', 'track_by_artist', 'track')]

In [18]:

### This is wrong -- need to specify which edge types to split on
# convert to train/val/test splits
# transform = RandomLinkSplit(
#     is_undirected=True,
#     add_negative_train_samples=False,
#     neg_sampling_ratio=0,
#     num_val=0.15, num_test=0.15,
#     edge_types=edge_types
# )

transform = RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.0,
    split_labels=True,
    is_undirected=True,
    add_negative_train_samples=False,
    edge_types=[('playlist', 'track_in_playlist', 'track')],
    rev_edge_types=[('track', 'track_in_playlist', 'playlist')],
)


train_split, val_split, test_split = transform(hetero_data)

In [19]:
train_split

HeteroData(
  track={ num_nodes=171855 },
  album={ num_nodes=81720 },
  playlist={ num_nodes=10000 },
  artist={ num_nodes=35797 },
  (track, track_in_playlist, playlist)={ edge_index=[2, 262525] },
  (track, track_in_album, album)={ edge_index=[2, 85611] },
  (track, track_by_artist, artist)={ edge_index=[2, 87106] },
  (album, track_in_album, track)={ edge_index=[2, 86244] },
  (playlist, track_in_playlist, track)={
    edge_index=[2, 262525],
    pos_edge_label=[262525],
    pos_edge_label_index=[2, 262525],
  },
  (artist, track_by_artist, track)={ edge_index=[2, 84749] }
)

In [20]:
edge_type = hetero_data.edge_types[0]
hetero_data[edge_type].edge_index

tensor([[     0,      0,      0,  ..., 171333, 171468, 171623],
        [  4555,   1326,   6228,  ...,   9984,   9993,   9994]])

In [21]:
edge_type, hetero_data.edge_types

(('track', 'track_in_playlist', 'playlist'),
 [('track', 'track_in_playlist', 'playlist'),
  ('track', 'track_in_album', 'album'),
  ('track', 'track_by_artist', 'artist'),
  ('album', 'track_in_album', 'track'),
  ('playlist', 'track_in_playlist', 'track'),
  ('artist', 'track_by_artist', 'track')])

In [23]:
# Edge types
fwd_edge = ('playlist', 'track_in_playlist', 'track')
rev_edge = ('track', 'track_in_playlist', 'playlist')

def convert_edge_types(split, edge_type):
    # Convert message-passing edges
    if 'edge_index' in split[edge_type]:
        split[edge_type].edge_index = split[edge_type].edge_index.long()

    # Convert label edges (used for link prediction)
    if "edge_label_index" in split[edge_type]:
        split[edge_type].edge_label_index = split[edge_type].edge_label_index.long()

# Convert FORWARD + REVERSE edges
for split in [train_split, val_split, test_split]:
    convert_edge_types(split, fwd_edge)
    convert_edge_types(split, rev_edge)

# Print stats
print("Forward edge:", fwd_edge)
print("Reverse edge:", rev_edge)

print(f"Train supervision edges: {train_split[fwd_edge].pos_edge_label_index.shape[1]}")
print(f"Val supervision edges:   {val_split[fwd_edge].pos_edge_label_index.shape[1]}")
print(f"Test supervision edges:  {test_split[fwd_edge].pos_edge_label_index.shape[1]}")

print(f"Train MP edges (fwd):    {train_split[fwd_edge].edge_index.shape[1]}")
print(f"Train MP edges (rev):    {train_split[rev_edge].edge_index.shape[1]}")

print(f"Val MP edges (fwd):      {val_split[fwd_edge].edge_index.shape[1]}")
print(f"Val MP edges (rev):      {val_split[rev_edge].edge_index.shape[1]}")

print(f"Test MP edges (fwd):     {test_split[fwd_edge].edge_index.shape[1]}")
print(f"Test MP edges (rev):     {test_split[rev_edge].edge_index.shape[1]}")


Forward edge: ('playlist', 'track_in_playlist', 'track')
Reverse edge: ('track', 'track_in_playlist', 'playlist')
Train supervision edges: 262525
Val supervision edges:   32815
Test supervision edges:  32815
Train MP edges (fwd):    262525
Train MP edges (rev):    262525
Val MP edges (fwd):      262525
Val MP edges (rev):      262525
Test MP edges (fwd):     295340
Test MP edges (rev):     295340


# Model Creation

In [None]:
# class GCN(torch.nn.Module):
#     """
#       Here we adapt the LightGCN model from Torch Geometric for our purposes. We allow
#       for customizable convolutional layers, custom embeddings. In addition, we deifne some
#       additional custom functions.

#     """

#     def __init__(
#         self,
#         num_nodes: int,
#         embedding_dim: int,
#         num_layers: int,
#         alpha: Optional[Union[float, Tensor]] = None,
#         alpha_learnable = False,
#         conv_layer = "LGC",
#         name = None,
#         **kwargs,
#     ):
#         super().__init__()
#         alpha_string = "alpha" if alpha_learnable else ""
#         self.name = f"LGCN_{conv_layer}_{num_layers}_e{embedding_dim}_nodes{num_nodes}_{alpha_string}"
#         self.num_nodes = num_nodes
#         self.embedding_dim = embedding_dim
#         self.num_layers = num_layers

#         if alpha_learnable == True:
#           alpha_vals = torch.rand(num_layers+1)
#           alpha = nn.Parameter(alpha_vals/torch.sum(alpha_vals))
#           print(f"Alpha learnable, initialized to: {alpha.softmax(dim=-1)}")
#         else:
#           if alpha is None:
#               alpha = 1. / (num_layers + 1)

#           if isinstance(alpha, Tensor):
#               assert alpha.size(0) == num_layers + 1
#           else:
#               alpha = torch.tensor([alpha] * (num_layers + 1))

#         self.register_buffer('alpha', alpha)

#         self.embedding = Embedding(num_nodes, embedding_dim)

#         # initialize convolutional layers
#         self.conv_layer = conv_layer
#         if conv_layer == "LGC":
#           self.convs = ModuleList([LGConv(**kwargs) for _ in range(num_layers)])
#         elif conv_layer == "GAT":
#           # initialize Graph Attention layer with multiple heads
#           # initialize linear layers to aggregate heads
#           n_heads = 5
#           self.convs = ModuleList(
#               [GATConv(in_channels = embedding_dim, out_channels = embedding_dim, heads = n_heads, dropout = 0.5, **kwargs) for _ in range(num_layers)]
#           )
#           self.linears = ModuleList([Linear(n_heads * embedding_dim, embedding_dim) for _ in range(num_layers)])

#         elif conv_layer == "SAGE":
#           #  initialize GraphSAGE conv
#           self.convs = ModuleList(
#               [SAGEConv(in_channels = embedding_dim, out_channels = embedding_dim, **kwargs) for _ in range(num_layers)]
#           )

#         self.reset_parameters()

#     def reset_parameters(self):
#         torch.nn.init.xavier_uniform_(self.embedding.weight)
#         for conv in self.convs:
#             conv.reset_parameters()

#     def get_embedding(self, edge_index: Adj) -> Tensor:
#         x = self.embedding.weight

#         weights = self.alpha.softmax(dim=-1)
#         out = x * weights[0]

#         for i in range(self.num_layers):
#             x = self.convs[i](x, edge_index)
#             if self.conv_layer == "GAT":
#               x = self.linears[i](x)
#             out = out + x * weights[i + 1]

#         return out

#     def initialize_embeddings(self, data):
#       # initialize with the data node features
#         self.embedding.weight.data.copy_(data.node_feature)


#     def forward(self, edge_index: Adj,
#                 edge_label_index: OptTensor = None) -> Tensor:
#         if edge_label_index is None:
#             if isinstance(edge_index, SparseTensor):
#                 edge_label_index = torch.stack(edge_index.coo()[:2], dim=0)
#             else:
#                 edge_label_index = edge_index

#         out = self.get_embedding(edge_index)

#         return self.predict_link_embedding(out, edge_label_index)

#     def predict_link(self, edge_index: Adj, edge_label_index: OptTensor = None,
#                      prob: bool = False) -> Tensor:

#         pred = self(edge_index, edge_label_index).sigmoid()
#         return pred if prob else pred.round()

#     def predict_link_embedding(self, embed: Adj, edge_label_index: Adj) -> Tensor:

#         embed_src = embed[edge_label_index[0]]
#         embed_dst = embed[edge_label_index[1]]
#         return (embed_src * embed_dst).sum(dim=-1)


#     def recommend(self, edge_index: Adj, src_index: OptTensor = None,
#                   dst_index: OptTensor = None, k: int = 1) -> Tensor:
#         out_src = out_dst = self.get_embedding(edge_index)

#         if src_index is not None:
#             out_src = out_src[src_index]

#         if dst_index is not None:
#             out_dst = out_dst[dst_index]

#         pred = out_src @ out_dst.t()
#         top_index = pred.topk(k, dim=-1).indices

#         if dst_index is not None:  # Map local top-indices to original indices.
#             top_index = dst_index[top_index.view(-1)].view(*top_index.size())

#         return top_index


#     def link_pred_loss(self, pred: Tensor, edge_label: Tensor,
#                        **kwargs) -> Tensor:
#         loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)
#         return loss_fn(pred, edge_label.to(pred.dtype))


#     def recommendation_loss(self, pos_edge_rank: Tensor, neg_edge_rank: Tensor,
#                             lambda_reg: float = 1e-4, **kwargs) -> Tensor:
#         r"""Computes the model loss for a ranking objective via the Bayesian
#         Personalized Ranking (BPR) loss."""
#         loss_fn = BPRLoss(lambda_reg, **kwargs)
#         return loss_fn(pos_edge_rank, neg_edge_rank, self.embedding.weight)

#     def bpr_loss(self, pos_scores, neg_scores):
#       return - torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()

#     def __repr__(self) -> str:
#         return (f'{self.__class__.__name__}({self.num_nodes}, '
#                 f'{self.embedding_dim}, num_layers={self.num_layers})')


In [42]:
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Embedding, ModuleList, Linear
from torch_geometric.nn import RGCNConv, SAGEConv, HGTConv, to_hetero
from torch_geometric.typing import Adj, OptTensor
from typing import Optional, Union, Dict, List


class HeteroGNN(torch.nn.Module):
    """
    Base heterogeneous GNN model for playlist-track recommendation.
    Supports RGCN, GraphSAGE-H, and HGT architectures.
    """

    def __init__(
        self,
        metadata: tuple,  # (node_types, edge_types)
        num_nodes_dict: Dict[str, int],  # {'playlist': 10000, 'track': 171855, ...}
        embedding_dim: int,
        num_layers: int,
        model_type: str = "RGCN",  # "RGCN", "SAGE", "HGT"
        alpha: Optional[Union[float, Tensor]] = None,
        heads: int = 4,  # For HGT
        **kwargs,
    ):
        super().__init__()
        
        self.metadata = metadata
        self.node_types = metadata[0]
        self.edge_types = metadata[1]
        self.num_nodes_dict = num_nodes_dict
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.model_type = model_type
        
        # Alpha weighting for layer combinations
        if alpha is None:
            alpha = 1. / (num_layers + 1)
        if isinstance(alpha, Tensor):
            assert alpha.size(0) == num_layers + 1
        else:
            alpha = torch.tensor([alpha] * (num_layers + 1))
        self.register_buffer('alpha', alpha)
        
        # Create embeddings for each node type
        self.embeddings = nn.ModuleDict({
            node_type: Embedding(num_nodes, embedding_dim)
            for node_type, num_nodes in num_nodes_dict.items()
        })
        
        # Create convolutional layers based on model type
        self.convs = ModuleList()
        
        if model_type == "RGCN":
            # Relational GCN - handles different edge types with relation-specific weights
            num_relations = len(self.edge_types)
            for _ in range(num_layers):
                self.convs.append(
                    RGCNConv(
                        embedding_dim, 
                        embedding_dim, 
                        num_relations=num_relations,
                        **kwargs
                    )
                )
        
        elif model_type == "SAGE":
            # GraphSAGE for heterogeneous graphs (will be wrapped with to_hetero)
            for _ in range(num_layers):
                self.convs.append(
                    SAGEConv(embedding_dim, embedding_dim, **kwargs)
                )
        
        elif model_type == "HGT":
            # Heterogeneous Graph Transformer
            for _ in range(num_layers):
                self.convs.append(
                    HGTConv(
                        embedding_dim,
                        embedding_dim,
                        metadata,
                        heads=heads,
                        **kwargs
                    )
                )
        
        else:
            raise ValueError(f"Unknown model_type: {model_type}")
        
        self.reset_parameters()

    def reset_parameters(self):
        for emb in self.embeddings.values():
            torch.nn.init.xavier_uniform_(emb.weight)
        for conv in self.convs:
            conv.reset_parameters()

    def get_embedding(self, batch) -> Dict[str, Tensor]:
        """
        Get node embeddings for heterogeneous batch.
        
        Args:
            batch: HeteroData batch
            
        Returns:
            Dictionary of embeddings for each node type
        """
        # Initialize with learned embeddings using node_id (global IDs in batch)
        # This indexes the FULL embedding table with batch's global node IDs
        x_dict = {
            node_type: self.embeddings[node_type](batch[node_type].n_id)
            for node_type in self.node_types
        }
        
        weights = self.alpha.softmax(dim=-1)
        
        # Initialize output with weighted initial embeddings
        out_dict = {
            node_type: x * weights[0]
            for node_type, x in x_dict.items()
        }
        
        # Message passing
        if self.model_type == "RGCN":
            # RGCN uses homogeneous edge_index with edge_type
            edge_index, edge_type = self._to_homogeneous(batch)
            
            # Convert node features to homogeneous
            x = torch.cat([x_dict[nt] for nt in self.node_types], dim=0)
            
            for i, conv in enumerate(self.convs):
                x = conv(x, edge_index, edge_type)
                x = x.relu()
                
                # Add to output with weight
                start_idx = 0
                for node_type in self.node_types:
                    num_nodes = x_dict[node_type].size(0)
                    out_dict[node_type] = out_dict[node_type] + x[start_idx:start_idx+num_nodes] * weights[i + 1]
                    start_idx += num_nodes
        
        elif self.model_type == "SAGE":
            # GraphSAGE with heterogeneous message passing
            edge_index_dict = batch.edge_index_dict
            
            for i, conv in enumerate(self.convs):
                # Apply convolution for each edge type
                x_dict_new = {}
                for node_type in self.node_types:
                    # Aggregate messages from all edge types involving this node
                    msgs = []
                    for edge_type in self.edge_types:
                        src, rel, dst = edge_type
                        if dst == node_type and edge_type in edge_index_dict:
                            edge_index = edge_index_dict[edge_type]
                            msg = conv((x_dict[src], x_dict[dst]), edge_index)
                            msgs.append(msg)
                    
                    if msgs:
                        x_dict_new[node_type] = torch.stack(msgs).mean(dim=0).relu()
                    else:
                        x_dict_new[node_type] = x_dict[node_type]
                    
                    # Add to output with weight
                    out_dict[node_type] = out_dict[node_type] + x_dict_new[node_type] * weights[i + 1]
                
                x_dict = x_dict_new
        
        elif self.model_type == "HGT":
            # Heterogeneous Graph Transformer
            edge_index_dict = batch.edge_index_dict
            
            for i, conv in enumerate(self.convs):
                x_dict = conv(x_dict, edge_index_dict)
                x_dict = {key: x.relu() for key, x in x_dict.items()}
                
                # Add to output with weight
                for node_type in self.node_types:
                    out_dict[node_type] = out_dict[node_type] + x_dict[node_type] * weights[i + 1]
        
        return out_dict

    def _to_homogeneous(self, batch):
        """Convert heterogeneous batch to homogeneous for RGCN."""
        edge_indices = []
        edge_types = []
        
        # Create node offset mapping
        node_offset = {}
        offset = 0
        for node_type in self.node_types:
            node_offset[node_type] = offset
            offset += batch[node_type].num_nodes
        
        # Collect all edges with their types
        for edge_type_idx, edge_type in enumerate(self.edge_types):
            src_type, _, dst_type = edge_type
            if edge_type in batch.edge_index_dict:
                edge_index = batch.edge_index_dict[edge_type]
                # Offset node indices
                edge_index_offset = edge_index.clone()
                edge_index_offset[0] += node_offset[src_type]
                edge_index_offset[1] += node_offset[dst_type]
                
                edge_indices.append(edge_index_offset)
                edge_types.append(torch.full((edge_index.size(1),), edge_type_idx, dtype=torch.long))
        
        edge_index = torch.cat(edge_indices, dim=1)
        edge_type = torch.cat(edge_types, dim=0)
        
        return edge_index.to(batch['playlist'].n_id.device), edge_type.to(batch['playlist'].n_id.device)

    def forward(self, batch, edge_label_index: OptTensor = None) -> Tensor:
        """
        Forward pass for link prediction.
        
        Args:
            batch: HeteroData batch
            edge_label_index: [2, num_edges] edges to predict (playlist, track)
            
        Returns:
            Prediction scores for edges
        """
        if edge_label_index is None:
            edge_label_index = batch['playlist', 'track_in_playlist', 'track'].pos_edge_label_index
        
        out_dict = self.get_embedding(batch)
        
        return self.predict_link_embedding(out_dict, edge_label_index, batch)

    def predict_link_embedding(self, embed_dict: Dict[str, Tensor], edge_label_index: Tensor, 
                              batch=None) -> Tensor:
        """
        Predict link scores using dot product.
        
        Args:
            embed_dict: Dictionary of embeddings per node type (LOCAL batch embeddings)
            edge_label_index: [2, num_edges] (playlist_global_idx, track_global_idx)
            batch: HeteroData batch for global-to-local mapping
            
        Returns:
            scores: [num_edges] prediction scores
        """
        # embed_dict contains embeddings indexed by LOCAL batch indices
        # edge_label_index contains GLOBAL node IDs
        # We need to map global IDs to local batch indices
        
        if batch is None:
            # No batch provided, assume edge_label_index has local indices
            playlist_emb = embed_dict['playlist']
            track_emb = embed_dict['track']
            embed_src = playlist_emb[edge_label_index[0]]
            embed_dst = track_emb[edge_label_index[1]]
        else:
            # Get full embedding tables and index directly by global ID
            playlist_emb_full = self.embeddings['playlist'].weight
            track_emb_full = self.embeddings['track'].weight
            
            # But we need the message-passed embeddings from embed_dict
            # Solution: index the full embeddings, then apply the final layer output
            # Actually, let's just use batch n_id to map properly
            
            playlist_n_id = batch['playlist'].n_id
            track_n_id = batch['track'].n_id
            
            # Create mapping
            # Find local index for each global ID in edge_label_index
            playlist_global_ids = edge_label_index[0]
            track_global_ids = edge_label_index[1]
            
            # Use searchsorted for efficient mapping (requires sorted n_id)
            playlist_sorted, playlist_sort_idx = playlist_n_id.sort()
            track_sorted, track_sort_idx = track_n_id.sort()
            
            playlist_pos = torch.searchsorted(playlist_sorted, playlist_global_ids)
            track_pos = torch.searchsorted(track_sorted, track_global_ids)
            
            # Get actual local indices
            playlist_local = playlist_sort_idx[playlist_pos]
            track_local = track_sort_idx[track_pos]
            
            # Index embeddings
            embed_src = embed_dict['playlist'][playlist_local]
            embed_dst = embed_dict['track'][track_local]
        
        return (embed_src * embed_dst).sum(dim=-1)

    def predict_link(self, batch, edge_label_index: OptTensor = None, prob: bool = False) -> Tensor:
        """Predict links with optional probability output."""
        pred = self(batch, edge_label_index).sigmoid()
        return pred if prob else pred.round()

    def link_pred_loss(self, pred: Tensor, edge_label: Tensor, **kwargs) -> Tensor:
        """Binary cross-entropy loss for link prediction."""
        loss_fn = torch.nn.BCEWithLogitsLoss(**kwargs)
        return loss_fn(pred, edge_label.to(pred.dtype))

    def bpr_loss(self, pos_scores: Tensor, neg_scores: Tensor) -> Tensor:
        """
        Bayesian Personalized Ranking loss.
        Handles multiple negatives per positive.
        """
        num_pos = pos_scores.size(0)
        num_neg = neg_scores.size(0)
        
        if num_pos == num_neg:
            # Equal number of positives and negatives
            return -torch.log(torch.sigmoid(pos_scores - neg_scores)).mean()
        else:
            # Multiple negatives per positive - expand positives
            neg_ratio = num_neg // num_pos
            pos_scores_expanded = pos_scores.repeat_interleave(neg_ratio)
            return -torch.log(torch.sigmoid(pos_scores_expanded - neg_scores)).mean()

    def recommendation_loss(self, pos_edge_rank: Tensor, neg_edge_rank: Tensor,
                           lambda_reg: float = 1e-4) -> Tensor:
        """BPR loss with L2 regularization."""
        bpr = self.bpr_loss(pos_edge_rank, neg_edge_rank)
        
        # L2 regularization on embeddings
        reg_loss = 0
        for emb in self.embeddings.values():
            reg_loss += emb.weight.norm(2).pow(2)
        
        return bpr + lambda_reg * reg_loss

    def __repr__(self) -> str:
        total_nodes = sum(self.num_nodes_dict.values())
        return (f'{self.__class__.__name__}({self.model_type}, '
                f'nodes={total_nodes}, '
                f'emb_dim={self.embedding_dim}, '
                f'layers={self.num_layers})')


# ============================================================================
# Usage Examples
# ============================================================================

def create_model_rgcn(metadata, num_nodes_dict, embedding_dim=128, num_layers=3):
    """Create RGCN model."""
    return HeteroGNN(
        metadata=metadata,
        num_nodes_dict=num_nodes_dict,
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        model_type="RGCN"
    )

def create_model_sage(metadata, num_nodes_dict, embedding_dim=128, num_layers=3):
    """Create GraphSAGE-H model."""
    return HeteroGNN(
        metadata=metadata,
        num_nodes_dict=num_nodes_dict,
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        model_type="SAGE"
    )

def create_model_hgt(metadata, num_nodes_dict, embedding_dim=128, num_layers=3, heads=4):
    """Create HGT model."""
    return HeteroGNN(
        metadata=metadata,
        num_nodes_dict=num_nodes_dict,
        embedding_dim=embedding_dim,
        num_layers=num_layers,
        model_type="HGT",
        heads=heads
    )


# Example usage:
metadata = train_split.metadata()
num_nodes_dict = {
    'playlist': 10000,
    'track': 171855,
    'album': 81720,
    'artist': 35797
}

# Create model
device = "cpu"
model = create_model_rgcn(metadata, num_nodes_dict, embedding_dim=128, num_layers=3)
model = model.to(device)

# Training loop
for batch in train_loader:
    batch = batch.to(device)
    
    # Forward pass
    pos_pred = model(batch)
    
    # Sample negatives
    neg_edge_index, neg_edge_label = sample_negatives(batch)
    neg_pred = model(batch, neg_edge_index)
    
    # Compute loss
    loss = model.bpr_loss(pos_pred, neg_pred)
    
    # Backward
    loss.backward()
    optimizer.step()

KeyboardInterrupt: 

# `v19 works`

In [41]:
metadata = train_split.metadata()
num_nodes_dict = {
    'playlist': 10000,
    'track': 171855,
    'album': 81720,
    'artist': 35797
}

# Create model
model = create_model_rgcn(metadata, num_nodes_dict, embedding_dim=128, num_layers=3)
model = model.to(device)

# Training loop sanity check
for batch in train_loader:
    batch = batch.to(device)
    optimizer.zero_grad()  # ADD THIS! You forgot to zero gradients
    
    # Forward pass
    pos_pred = model(batch)
    
    # Sample negatives
    neg_edge_index, neg_edge_label = sample_negatives(batch)
    neg_pred = model(batch, neg_edge_index)
    
    # Compute loss
    loss = model.bpr_loss(pos_pred, neg_pred)
    
    print(f"Loss: {loss.item():.4f}")
    print(f"Pos pred range: [{pos_pred.min():.4f}, {pos_pred.max():.4f}]")
    print(f"Neg pred range: [{neg_pred.min():.4f}, {neg_pred.max():.4f}]")
    print(f"Pos pred mean: {pos_pred.mean():.4f}, Neg pred mean: {neg_pred.mean():.4f}")
    
    # Backward
    loss.backward()
    
    # Check gradients are flowing
    total_grad_norm = 0
    for name, param in model.named_parameters():
        if param.grad is not None:
            total_grad_norm += param.grad.norm().item()
    print(f"Total gradient norm: {total_grad_norm:.4f}")
    
    optimizer.step()
    break

Loss: 0.6932
Pos pred range: [0.0002, 0.0044]
Neg pred range: [0.0002, 0.0042]
Pos pred mean: 0.0012, Neg pred mean: 0.0013
Total gradient norm: 0.0024


# `Bayesian Personalized Ranking Loss`

In [99]:
class BPRLoss(_Loss):
    r"""The Bayesian Personalized Ranking (BPR) loss.

    The BPR loss is a pairwise loss that encourages the prediction of an
    observed entry to be higher than its unobserved counterparts
    (see `here <https://arxiv.org/abs/2002.02126>`__).

    .. math::
        L_{\text{BPR}} = - \sum_{u=1}^{M} \sum_{i \in \mathcal{N}_u}
        \sum_{j \not\in \mathcal{N}_u} \ln \sigma(\hat{y}_{ui} - \hat{y}_{uj})
        + \lambda \vert\vert \textbf{x}^{(0)} \vert\vert^2

    where :math:`lambda` controls the :math:`L_2` regularization strength.
    We compute the mean BPR loss for simplicity.

    Args:
        lambda_reg (float, optional): The :math:`L_2` regularization strength
            (default: 0).
        **kwargs (optional): Additional arguments of the underlying
            :class:`torch.nn.modules.loss._Loss` class.
    """
    __constants__ = ['lambda_reg']
    lambda_reg: float

    def __init__(self, lambda_reg: float = 0, **kwargs):
        super().__init__(None, None, "sum", **kwargs)
        self.lambda_reg = lambda_reg

    def forward(self, positives: Tensor, negatives: Tensor,
                parameters: Tensor = None) -> Tensor:
        r"""Compute the mean Bayesian Personalized Ranking (BPR) loss.

        .. note::

            The i-th entry in the :obj:`positives` vector and i-th entry
            in the :obj:`negatives` entry should correspond to the same
            entity (*.e.g*, user), as the BPR is a personalized ranking loss.

        Args:
            positives (Tensor): The vector of positive-pair rankings.
            negatives (Tensor): The vector of negative-pair rankings.
            parameters (Tensor, optional): The tensor of parameters which
                should be used for :math:`L_2` regularization
                (default: :obj:`None`).
        """
        n_pairs = positives.size(0)
        log_prob = F.logsigmoid(positives - negatives).sum()
        regularization = 0

        if self.lambda_reg != 0:
            regularization = self.lambda_reg * parameters.norm(p=2).pow(2)

        return (-log_prob + regularization) / n_pairs

In [61]:
train_split[('playlist', 'track_in_playlist', 'track')].pos_edge_label_index

tensor([[  5923,   8685,    641,  ...,    776,   7743,   5947],
        [122716, 166856, 140277,  ...,  54836, 161325, 105642]])

In [64]:
train_split

HeteroData(
  track={ num_nodes=171855 },
  album={ num_nodes=81720 },
  playlist={ num_nodes=10000 },
  artist={ num_nodes=35797 },
  (track, track_in_playlist, playlist)={ edge_index=[2, 262525] },
  (track, track_in_album, album)={ edge_index=[2, 85611] },
  (track, track_by_artist, artist)={ edge_index=[2, 87106] },
  (album, track_in_album, track)={ edge_index=[2, 86244] },
  (playlist, track_in_playlist, track)={
    edge_index=[2, 262525],
    pos_edge_label=[262525],
    pos_edge_label_index=[2, 262525],
  },
  (artist, track_by_artist, track)={ edge_index=[2, 84749] },
  (track, rev_track_in_playlist, playlist)={}
)

In [62]:
def sample_negatives_hetero(batch, num_neg=5):
    playlists = batch['playlist','track_in_playlist','track'].pos_edge_label_index[0]
    pos_tracks = batch['playlist','track_in_playlist','track'].pos_edge_label_index[1]

    all_tracks = batch['track'].node_id  
    # e.g., [103, 33, 5002, 22331, 655] local track indices inside sampled subgraph

    neg_tracks = all_tracks[torch.randint(0, len(all_tracks), (len(playlists), num_neg))]

    neg_edge_index = torch.stack((playlists.repeat_interleave(num_neg),
                                  neg_tracks.reshape(-1)), dim=0)

    return neg_edge_index
neg_edge_index = sample_negatives_hetero(train_split, num_neg=1)
neg_edge_index

AttributeError: 'NodeStorage' object has no attribute 'node_id'

# `Negative Sampling for Heterogeneous Graph`

In [24]:
train_split.metadata()

(['track', 'album', 'playlist', 'artist'],
 [('track', 'track_in_playlist', 'playlist'),
  ('track', 'track_in_album', 'album'),
  ('track', 'track_by_artist', 'artist'),
  ('album', 'track_in_album', 'track'),
  ('playlist', 'track_in_playlist', 'track'),
  ('artist', 'track_by_artist', 'track')])

In [25]:
train_input_nodes = ('playlist', torch.arange(train_split['playlist'].num_nodes))
train_input_nodes


('playlist', tensor([   0,    1,    2,  ..., 9997, 9998, 9999]))

In [26]:
num_neighbors = {
    ('playlist','track_in_playlist','track'): [20, 10],  
    ('track','track_in_playlist','playlist'): [20, 10],  

    ('track','track_in_album','album'): [10, 5],
    ('album','track_in_album','track'): [10, 5],

    ('track','track_by_artist','artist'): [10, 5],
    ('artist','track_by_artist','track'): [10, 5],
}
num_neighbors

{('playlist', 'track_in_playlist', 'track'): [20, 10],
 ('track', 'track_in_playlist', 'playlist'): [20, 10],
 ('track', 'track_in_album', 'album'): [10, 5],
 ('album', 'track_in_album', 'track'): [10, 5],
 ('track', 'track_by_artist', 'artist'): [10, 5],
 ('artist', 'track_by_artist', 'track'): [10, 5]}

In [27]:
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    train_split,
    num_neighbors=num_neighbors,
    input_nodes=train_input_nodes,
    batch_size=512,
    shuffle=True
)


In [28]:
for batch in train_loader:
    print(batch)
    break


HeteroData(
  track={
    num_nodes=5447,
    n_id=[5447],
    num_sampled_nodes=[3],
  },
  album={
    num_nodes=2919,
    n_id=[2919],
    num_sampled_nodes=[3],
  },
  playlist={
    num_nodes=6952,
    n_id=[6952],
    num_sampled_nodes=[3],
    input_id=[512],
    batch_size=512,
  },
  artist={
    num_nodes=2045,
    n_id=[2045],
    num_sampled_nodes=[3],
  },
  (track, track_in_playlist, playlist)={
    edge_index=[2, 6829],
    e_id=[6829],
    num_sampled_edges=[2],
  },
  (track, track_in_album, album)={
    edge_index=[2, 0],
    e_id=[0],
    num_sampled_edges=[2],
  },
  (track, track_by_artist, artist)={
    edge_index=[2, 0],
    e_id=[0],
    num_sampled_edges=[2],
  },
  (album, track_in_album, track)={
    edge_index=[2, 3702],
    e_id=[3702],
    num_sampled_edges=[2],
  },
  (playlist, track_in_playlist, track)={
    edge_index=[2, 30067],
    pos_edge_label=[30067],
    pos_edge_label_index=[2, 30067],
    e_id=[30067],
    num_sampled_edges=[2],
  },
  (artist

# `Negative Sampling for Heterogenous Graph`

In [29]:
import torch


def sample_negatives(batch, neg_ratio=3):
    """
    Sample random negative edges for playlist-track link prediction.
    
    Args:
        batch: HeteroData batch from neighbor loader
        neg_ratio: Number of negatives per positive edge (default: 3)
        
    Returns:
        neg_edge_index: [2, num_neg] negative edges (playlist_idx, track_idx)
        neg_edge_label: [num_neg] all zeros
    """
    # Get positive edges and batch nodes
    pos_edge_index = batch['playlist', 'track_in_playlist', 'track'].pos_edge_label_index
    batch_playlists = batch['playlist'].n_id
    batch_tracks = batch['track'].n_id
    
    num_pos = pos_edge_index.size(1)
    num_neg = num_pos * neg_ratio
    
    # Create set of positive edges for fast lookup
    pos_set = set(zip(pos_edge_index[0].tolist(), pos_edge_index[1].tolist()))
    
    # Sample negatives
    neg_edges = []
    while len(neg_edges) < num_neg:
        # Random playlist and track from batch
        p_idx = batch_playlists[torch.randint(0, len(batch_playlists), (1,))].item()
        t_idx = batch_tracks[torch.randint(0, len(batch_tracks), (1,))].item()
        
        # Keep if not a positive edge
        if (p_idx, t_idx) not in pos_set:
            neg_edges.append([p_idx, t_idx])
    
    neg_edge_index = torch.tensor(neg_edges, dtype=torch.long).t().to(pos_edge_index.device)
    neg_edge_label = torch.zeros(num_neg, dtype=torch.float, device=pos_edge_index.device)
    
    return neg_edge_index, neg_edge_label
neg_edge_index, neg_edge_label = sample_negatives(batch, neg_ratio=3)

In [44]:
import torch
import torch.nn.functional as F


def sample_hard_negatives(batch, model, device=None, batch_size=500, frac_sample=1.0):
    """
    Sample hard negative edges based on model embeddings.
    
    Args:
        batch: HeteroData batch from neighbor loader
        model: Your GNN model with get_embedding() method
        device: Device to use (default: same as batch)
        batch_size: Batch size for scoring (default: 500)
        frac_sample: Fraction of tracks to consider for sampling (default: 1.0)
        
    Returns:
        neg_edge_index: [2, num_neg] hard negative edges
        neg_edge_label: [num_neg] all zeros
    """
    if device is None:
        device = batch['playlist'].n_id.device
    
    with torch.no_grad():
        # Get embeddings from model using the batch
        embed_dict = model.get_embedding(batch)
        
        # Extract playlist and track embeddings (LOCAL batch embeddings)
        playlist_emb = embed_dict['playlist'].to(device)
        track_emb = embed_dict['track'].to(device)
        
        # Get positive edges (GLOBAL IDs)
        pos_edge_index = batch['playlist', 'track_in_playlist', 'track'].pos_edge_label_index
        positive_playlists_global = pos_edge_index[0]
        positive_tracks_global = pos_edge_index[1]
        num_edges = positive_playlists_global.size(0)
        
        # Get batch node mappings
        batch_playlists = batch['playlist'].n_id
        batch_tracks = batch['track'].n_id
        num_batch_playlists = len(batch_playlists)
        num_batch_tracks = len(batch_tracks)
        
        # Create reverse mappings: global_id → local_idx
        playlist_map = {global_id.item(): local_idx for local_idx, global_id in enumerate(batch_playlists.cpu())}
        track_map = {global_id.item(): local_idx for local_idx, global_id in enumerate(batch_tracks.cpu())}
        
        # Map positive edges from global to local indices
        pos_playlists_local = torch.tensor(
            [playlist_map[pid.item()] for pid in positive_playlists_global.cpu()],
            dtype=torch.long,
            device=device
        )
        pos_tracks_local = torch.tensor(
            [track_map[tid.item()] for tid in positive_tracks_global.cpu()],
            dtype=torch.long,
            device=device
        )
        
        # Create positive edge mask (LOCAL indices)
        positive_mask = torch.zeros(num_batch_playlists, num_batch_tracks, device=device, dtype=torch.bool)
        positive_mask[pos_playlists_local, pos_tracks_local] = True
        
        neg_edges_list = []
        neg_edge_label_list = []
        
        # Process in batches
        for batch_start in range(0, num_edges, batch_size):
            batch_end = min(batch_start + batch_size, num_edges)
            
            # Get local playlist indices for this batch
            playlists_local_batch = pos_playlists_local[batch_start:batch_end]
            
            # Compute similarity scores (using LOCAL embeddings)
            batch_scores = torch.matmul(
                playlist_emb[playlists_local_batch], 
                track_emb.t()
            )
            
            # Mask out positive edges
            batch_scores[positive_mask[playlists_local_batch]] = -float("inf")
            
            # Select top-k highest scoring negative edges
            k = int(frac_sample * 0.99 * num_batch_tracks)
            k = max(1, k)  # Ensure at least 1
            _, top_indices_local = torch.topk(batch_scores, k, dim=1)
            
            # Randomly select one from top-k for each playlist
            selected_indices = torch.randint(0, k, size=(batch_end - batch_start,), device=device)
            top_tracks_local = top_indices_local[torch.arange(batch_end - batch_start), selected_indices]
            
            # Map local indices back to global for output
            playlists_global_batch = positive_playlists_global[batch_start:batch_end]
            tracks_global_batch = batch_tracks[top_tracks_local]
            
            # Create negative edges (GLOBAL IDs)
            neg_edges_batch = torch.stack(
                (playlists_global_batch, tracks_global_batch), dim=0
            )
            neg_edge_label_batch = torch.zeros(neg_edges_batch.shape[1], device=device)
            
            neg_edges_list.append(neg_edges_batch)
            neg_edge_label_list.append(neg_edge_label_batch)
        
        # Concatenate all batches
        neg_edge_index = torch.cat(neg_edges_list, dim=1)
        neg_edge_label = torch.cat(neg_edge_label_list)
        
        return neg_edge_index, neg_edge_label


# Usage:
neg_edge_index, neg_edge_label = sample_hard_negatives(batch, model)  

In [None]:
"""
SAMPLING NEGATIVES EDGES FOR HOMOGENEOUS GRAPHS - ORIGINAL CODE


# def sample_negative_edges_nocheck(data, num_playlists, num_tracks, device = None):
#   # note computationally inefficient to check that these are indeed negative edges
#   playlists = data.edge_label_index[0, :]
#   tracks = torch.randint(num_playlists, num_playlists + num_tracks - 1, size = data.edge_label_index[1, :].size())

#   if playlists.get_device() != -1: # on gpu
#     tracks = tracks.to(device)

#   neg_edge_index = torch.stack((playlists, tracks), dim = 0)
#   neg_edge_label = torch.zeros(neg_edge_index.shape[1])

#   if neg_edge_index.get_device() != -1: # on gpu
#     neg_edge_label = neg_edge_label.to(device)

#   return neg_edge_index, neg_edge_label

# def sample_negative_edges(data, num_playlists, num_tracks, device=None):
#     positive_playlists, positive_tracks = data.edge_label_index

#     # Create a mask tensor with the shape (num_playlists, num_tracks)
#     mask = torch.zeros(num_playlists, num_tracks, device=device, dtype=torch.bool)
#     mask[positive_playlists, positive_tracks - num_playlists] = True

#     # Flatten the mask tensor and get the indices of the negative edges
#     flat_mask = mask.flatten()
#     negative_indices = torch.where(~flat_mask)[0]

#     # Sample negative edges from the negative_indices tensor
#     sampled_negative_indices = negative_indices[
#         torch.randint(0, negative_indices.size(0), size=(positive_playlists.size(0),), device=device)
#     ]

#     # Convert the indices back to playlists and tracks tensors
#     playlists = torch.floor_divide(sampled_negative_indices, num_tracks)
#     tracks = torch.remainder(sampled_negative_indices, num_tracks)
#     tracks = tracks + num_playlists

#     neg_edge_index = torch.stack((playlists, tracks), dim=0)
#     neg_edge_label = torch.zeros(neg_edge_index.shape[1], device=device)

#     return neg_edge_index, neg_edge_label

# def sample_hard_negative_edges(data, model, num_playlists, num_tracks, device=None, batch_size=500, frac_sample = 1):
#     with torch.no_grad():
#         embeddings = model.get_embedding(data.edge_index)
#         playlists_embeddings = embeddings[:num_playlists].to(device)
#         tracks_embeddings = embeddings[num_playlists:].to(device)

#     positive_playlists, positive_tracks = data.edge_label_index
#     num_edges = positive_playlists.size(0)

#     # Create a boolean mask for all the positive edges
#     positive_mask = torch.zeros(num_playlists, num_tracks, device=device, dtype=torch.bool)
#     positive_mask[positive_playlists, positive_tracks - num_playlists] = True

#     neg_edges_list = []
#     neg_edge_label_list = []

#     for batch_start in range(0, num_edges, batch_size):
#         batch_end = min(batch_start + batch_size, num_edges)

#         batch_scores = torch.matmul(
#             playlists_embeddings[positive_playlists[batch_start:batch_end]], tracks_embeddings.t()
#         )

#         # Set the scores of the positive edges to negative infinity
#         batch_scores[positive_mask[positive_playlists[batch_start:batch_end]]] = -float("inf")

#         # Select the top k highest scoring negative edges for each playlist in the current batch
#         # do 0.99 to filter out all pos edges which will be at the end
#         _, top_indices = torch.topk(batch_scores, int(frac_sample * 0.99 * num_tracks), dim=1)
#         selected_indices = torch.randint(0, int(frac_sample * 0.99 *num_tracks), size = (batch_end - batch_start, ))
#         top_indices_selected = top_indices[torch.arange(batch_end - batch_start), selected_indices] + n_playlists

#         # Create the negative edges tensor for the current batch
#         neg_edges_batch = torch.stack(
#             (positive_playlists[batch_start:batch_end], top_indices_selected), dim=0
#         )
#         neg_edge_label_batch = torch.zeros(neg_edges_batch.shape[1], device=device)

#         neg_edges_list.append(neg_edges_batch)
#         neg_edge_label_list.append(neg_edge_label_batch)

#     # Concatenate the batch tensors
#     neg_edges = torch.cat(neg_edges_list, dim=1)
#     neg_edge_label = torch.cat(neg_edge_label_list)

#     return neg_edges, neg_edge_label
"""

In [None]:
# def recall_at_k(data, model, k = 300, batch_size = 64, device = None):
#     with torch.no_grad():
#         embeddings = model.get_embedding(data.edge_index)
#         playlists_embeddings = embeddings[:n_playlists]
#         tracks_embeddings = embeddings[n_playlists:]

#     hits_list = []
#     relevant_counts_list = []

#     for batch_start in range(0, n_playlists, batch_size):
#         batch_end = min(batch_start + batch_size, n_playlists)
#         batch_playlists_embeddings = playlists_embeddings[batch_start:batch_end]

#         # Calculate scores for all possible item pairs
#         scores = torch.matmul(batch_playlists_embeddings, tracks_embeddings.t())

#         # Set the scores of message passing edges to negative infinity
#         mp_indices = ((data.edge_index[0] >= batch_start) & (data.edge_index[0] < batch_end)).nonzero(as_tuple=True)[0]
#         scores[data.edge_index[0, mp_indices] - batch_start, data.edge_index[1, mp_indices] - n_playlists] = -float("inf")

#         # Find the top k highest scoring items for each playlist in the batch
#         _, top_k_indices = torch.topk(scores, k, dim=1)

#         # Ground truth supervision edges
#         ground_truth_edges = data.edge_label_index

#         # Create a mask to indicate if the top k items are in the ground truth supervision edges
#         mask = torch.zeros(scores.shape, device=device, dtype=torch.bool)
#         gt_indices = ((ground_truth_edges[0] >= batch_start) & (ground_truth_edges[0] < batch_end)).nonzero(as_tuple=True)[0]
#         mask[ground_truth_edges[0, gt_indices] - batch_start, ground_truth_edges[1, gt_indices] - n_playlists] = True

#         # Check how many of the top k items are in the ground truth supervision edges
#         hits = mask.gather(1, top_k_indices).sum(dim=1)
#         hits_list.append(hits)

#         # Calculate the total number of relevant items for each playlist in the batch
#         relevant_counts = torch.bincount(ground_truth_edges[0, gt_indices] - batch_start, minlength=batch_end - batch_start)
#         relevant_counts_list.append(relevant_counts)

#     # Compute recall@k
#     hits_tensor = torch.cat(hits_list, dim=0)
#     relevant_counts_tensor = torch.cat(relevant_counts_list, dim=0)
#     # Handle division by zero case
#     recall_at_k = torch.where(
#         relevant_counts_tensor != 0,
#         hits_tensor.true_divide(relevant_counts_tensor),
#         torch.ones_like(hits_tensor)
#     )
#     # take average
#     recall_at_k = torch.mean(recall_at_k)

#     if recall_at_k.numel() == 1:
#         return recall_at_k.item()
#     else:
#         raise ValueError("recall_at_k contains more than one item.")

In [46]:
import torch


def recall_at_k(batch, model, k=300, score_batch_size=64, device=None):
    """
    Calculate Recall@K for heterogeneous batch.
    
    Args:
        batch: HeteroData batch from neighbor loader
        model: Your GNN model with get_embedding() method
        k: Top-k items to consider (default: 300)
        score_batch_size: Batch size for scoring (default: 64)
        device: Device to use
        
    Returns:
        recall_at_k: Scalar recall value
    """
    if device is None:
        device = batch['playlist'].n_id.device
    
    with torch.no_grad():
        # Get embeddings from model using the batch
        embed_dict = model.get_embedding(batch)
        
        # Extract LOCAL batch embeddings
        playlist_emb = embed_dict['playlist'].to(device)
        track_emb = embed_dict['track'].to(device)
        
        num_batch_playlists = playlist_emb.size(0)
        num_batch_tracks = track_emb.size(0)
        
        # Get edges (GLOBAL IDs)
        mp_edge_index = batch['playlist', 'track_in_playlist', 'track'].edge_index
        gt_edge_index = batch['playlist', 'track_in_playlist', 'track'].pos_edge_label_index
        
        # Get batch node mappings
        batch_tracks = batch['track'].n_id
        batch_playlists = batch['playlist'].n_id
        
        # Create reverse mappings: global_id → local_idx
        track_global_to_local = {t.item(): i for i, t in enumerate(batch_tracks.cpu())}
        playlist_global_to_local = {p.item(): i for i, p in enumerate(batch_playlists.cpu())}
        
        hits_list = []
        relevant_counts_list = []
        
        # Process playlists in batches
        for batch_start in range(0, num_batch_playlists, score_batch_size):
            batch_end = min(batch_start + score_batch_size, num_batch_playlists)
            batch_playlist_emb = playlist_emb[batch_start:batch_end]
            
            # Calculate scores for all tracks (using LOCAL embeddings)
            scores = torch.matmul(batch_playlist_emb, track_emb.t())
            
            # Mask out message passing edges (exclude training edges)
            for i in range(mp_edge_index.size(1)):
                p_global = mp_edge_index[0, i].item()
                t_global = mp_edge_index[1, i].item()
                
                # Check if both nodes are in this batch
                if p_global in playlist_global_to_local and t_global in track_global_to_local:
                    p_local = playlist_global_to_local[p_global]
                    t_local = track_global_to_local[t_global]
                    
                    # Check if playlist is in current scoring batch
                    if batch_start <= p_local < batch_end:
                        scores[p_local - batch_start, t_local] = -float("inf")
            
            # Get top-k predictions
            actual_k = min(k, num_batch_tracks)
            _, top_k_indices = torch.topk(scores, actual_k, dim=1)
            
            # Create ground truth mask (LOCAL indices)
            mask = torch.zeros(scores.shape, device=device, dtype=torch.bool)
            
            for i in range(gt_edge_index.size(1)):
                p_global = gt_edge_index[0, i].item()
                t_global = gt_edge_index[1, i].item()
                
                # Check if both nodes are in this batch
                if p_global in playlist_global_to_local and t_global in track_global_to_local:
                    p_local = playlist_global_to_local[p_global]
                    t_local = track_global_to_local[t_global]
                    
                    # Check if playlist is in current scoring batch
                    if batch_start <= p_local < batch_end:
                        mask[p_local - batch_start, t_local] = True
            
            # Count hits (how many ground truth items are in top-k)
            hits = mask.gather(1, top_k_indices).sum(dim=1)
            hits_list.append(hits)
            
            # Count total relevant items per playlist in this scoring batch
            relevant_counts = torch.zeros(batch_end - batch_start, device=device)
            for i in range(gt_edge_index.size(1)):
                p_global = gt_edge_index[0, i].item()
                
                if p_global in playlist_global_to_local:
                    p_local = playlist_global_to_local[p_global]
                    if batch_start <= p_local < batch_end:
                        relevant_counts[p_local - batch_start] += 1
            
            relevant_counts_list.append(relevant_counts)
        
        # Compute recall@k
        hits_tensor = torch.cat(hits_list, dim=0)
        relevant_counts_tensor = torch.cat(relevant_counts_list, dim=0)
        
        # Handle division by zero (playlists with no ground truth edges)
        recall_at_k = torch.where(
            relevant_counts_tensor != 0,
            hits_tensor.float() / relevant_counts_tensor,
            torch.zeros_like(hits_tensor, dtype=torch.float)  # Changed from ones to zeros
        )
        
        # Average recall across all playlists
        recall_at_k = torch.mean(recall_at_k)
        
        return recall_at_k.item()


# Usage:
recall = recall_at_k(batch, model, k=10)

In [48]:
def metrics(labels, preds):
  roc = roc_auc_score(labels.flatten().cpu().numpy(), preds.flatten().data.cpu().numpy())
  return roc

In [33]:
# Train
def train(datasets, model, optimizer, loss_fn, args, neg_samp = "random"):
  print(f"Beginning training for {model.name}")

  train_data = datasets["train"]
  val_data = datasets["val"]

  stats = {
      'train': {
        'loss': [],
        'roc' : []
      },
      'val': {
        'loss': [],
        'recall': [],
        'roc' : []
      }

  }
  val_neg_edge, val_neg_label = None, None
  for epoch in range(args["epochs"]): # loop over each epoch
    model.train()
    optimizer.zero_grad()

    # obtain negative sample
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(train_data, n_playlists, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            train_data, model, n_playlists, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # calculate embedding
    embed = model.get_embedding(train_data.edge_index)
    # calculate pos, negative scores using embedding
    pos_scores = model.predict_link_embedding(embed, train_data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)

    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((train_data.edge_label, neg_edge_label), dim = 0)

    # calculate loss function
    if loss_fn == "BCE":
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    train_roc = metrics(labels, scores)

    loss.backward()
    optimizer.step()

    val_loss, val_roc, val_neg_edge, val_neg_label = test(
        model, val_data, loss_fn, neg_samp, epoch, val_neg_edge, val_neg_label
    )

    stats['train']['loss'].append(loss)
    stats['train']['roc'].append(train_roc)
    stats['val']['loss'].append(val_loss)
    stats['val']['roc'].append(val_roc)

    print(f"Epoch {epoch}; Train loss {loss}; Val loss {val_loss}; Train ROC {train_roc}; Val ROC {val_roc}")

    if epoch % 10 == 0:
      # calculate recall @ K
      val_recall = recall_at_k(val_data, model, k = 300, device = args["device"])
      print(f"Val recall {val_recall}")
      stats['val']['recall'].append(val_recall)

    if epoch % 20 == 0:

      # save embeddings for future visualization
      path = os.path.join("model_embeddings", model.name)
      if not os.path.exists(path):
        os.makedirs(path)
      torch.save(model.embedding.weight, os.path.join("model_embeddings", model.name, f"{model.name}_{loss_fn}_{neg_samp}_{epoch}.pt"))

  pickle.dump(stats, open(f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pkl", "wb"))
  return stats

def test(model, data, loss_fn, neg_samp, epoch = 0, neg_edge_index = None, neg_edge_label = None):

  model.eval()
  with torch.no_grad(): # want to save RAM

    # conduct negative sampling
    if neg_samp == "random":
      neg_edge_index, neg_edge_label = sample_negative_edges(data, n_playlists, n_tracks, args["device"])
    elif neg_samp == "hard":
      if epoch % 5 == 0 or neg_edge_index is None:
        neg_edge_index, neg_edge_label = sample_hard_negative_edges(
            data, model, n_playlists, n_tracks, args["device"], batch_size = 500,
            frac_sample = 1 - (0.5 * epoch / args["epochs"])
        )
    # obtain model embedding
    embed = model.get_embedding(data.edge_index)
    # calculate pos, neg scores using embedding
    pos_scores = model.predict_link_embedding(embed, data.edge_label_index)
    neg_scores = model.predict_link_embedding(embed, neg_edge_index)
    # concatenate pos, neg scores together and evaluate loss
    scores = torch.cat((pos_scores, neg_scores), dim = 0)
    labels = torch.cat((data.edge_label, neg_edge_label), dim = 0)
    # calculate loss
    if loss_fn == "BCE":
      loss = model.link_pred_loss(scores, labels)
    elif loss_fn == "BPR":
      loss = model.recommendation_loss(pos_scores, neg_scores, lambda_reg = 0)

    roc = metrics(labels, scores)

  return loss, roc, neg_edge_index, neg_edge_label

In [34]:
# create a dictionary of the dataset splits
datasets = {
    'train':train_split,
    'val':val_split,
    'test': test_split
}

In [35]:
# initialize our arguments
args = {
    'device' : 'cuda' if torch.cuda.is_available() else 'cpu',
    'num_layers' :  3,
    'emb_size' : 64,
    'weight_decay': 1e-5,
    'lr': 0.01,
    'epochs': 1
}

In [36]:
from torch_geometric.nn import to_hetero

In [38]:
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])


In [37]:
# initialize model and and optimizer
num_nodes = n_playlists + n_tracks
model = GCN(
    num_nodes = num_nodes, num_layers = args['num_layers'],
    embedding_dim = args["emb_size"], conv_layer = "SAGE"
)
model = to_hetero(model, train_split.metadata(), aggr = "sum")
optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])

NameError: name 'n_playlists' is not defined

In [None]:
# Only move model to GPU
model.to(args["device"])

# Convert index tensors (if you're using them separately)
playlists_idx = torch.tensor(playlists_idx, dtype=torch.int64).to(args["device"])
tracks_idx = torch.tensor(tracks_idx, dtype=torch.int64).to(args["device"])


GCN(35300, 64, num_layers=3)

In [112]:
# create directory to save model_stats
MODEL_STATS_DIR = "model_stats"
if not os.path.exists(MODEL_STATS_DIR):
  os.makedirs(MODEL_STATS_DIR)

In [114]:
train(datasets, model, optimizer, "BPR", args, neg_samp = "random")

NameError: name 'train' is not defined

In [40]:
test(model, datasets['test'], "BPR", neg_samp = "random")

(tensor(0.1410, device='cuda:0'),
 0.9499165680823833,
 tensor([[12321, 14144,  3385,  ...,  5691, 21204,  1460],
         [29629, 32538, 35098,  ..., 35196, 30857, 25586]], device='cuda:0'),
 tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0'))

In [41]:
def init_model(conv_layer, args, alpha = False):
  num_nodes = n_playlists + n_tracks
  model = GCN(
      num_nodes = num_nodes, num_layers = args['num_layers'],
      embedding_dim = args["emb_size"], conv_layer = conv_layer,
      alpha_learnable = alpha
  )
  model.to(args["device"])
  optimizer = torch.optim.Adam(model.parameters(), lr=args['lr'], weight_decay=args['weight_decay'])
  return model, optimizer

In [42]:
## For example:

# using BPR loss
loss_fn = "BPR"

# using hard sampling
neg_samp = "hard"

# for LGConv:
args['epochs'] = 301
args['num_layers'] = 4
model, optimizer = init_model("LGC", args)
lgc_stats_hard = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

Beginning training for LGCN_LGC_4_e64_nodes35300_
Epoch 0; Train loss 0.6931434273719788; Val loss 0.6931453943252563; Train ROC 0.6034073015044848; Val ROC 0.6759018905552148
Val recall 0.2589329481124878
Epoch 0; Train loss 0.6931434273719788; Val loss 0.6931453943252563; Train ROC 0.6034073015044848; Val ROC 0.6759018905552148
Val recall 0.2589329481124878
Epoch 1; Train loss 0.693144679069519; Val loss 0.6931245923042297; Train ROC 0.7589306002473327; Val ROC 0.7777647373441927
Epoch 2; Train loss 0.6931230425834656; Val loss 0.6929528117179871; Train ROC 0.8139802534992705; Val ROC 0.8013493128548925
Epoch 1; Train loss 0.693144679069519; Val loss 0.6931245923042297; Train ROC 0.7589306002473327; Val ROC 0.7777647373441927
Epoch 2; Train loss 0.6931230425834656; Val loss 0.6929528117179871; Train ROC 0.8139802534992705; Val ROC 0.8013493128548925
Epoch 3; Train loss 0.6929451823234558; Val loss 0.6922587156295776; Train ROC 0.8156420458589091; Val ROC 0.7996414969775105
Epoch 4; T

In [None]:
# for GATConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("GAT", args)
gat_stats_hard = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# for SAGEConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("SAGE", args)
sage_stats_hard = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# using random sampling
neg_samp = "random"

# for LGConv:
args['epochs'] = 301
args['num_layers'] = 4
model, optimizer = init_model("LGC", args)
lgc_stats = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# for GATConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("GAT", args)
gat_stats = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
# for SAGEConv:
args['epochs'] = 301
args['num_layers'] = 3
model, optimizer = init_model("SAGE", args)
sage_stats = train(datasets, model, optimizer, loss_fn, args, neg_samp = neg_samp)
torch.save(model.state_dict(), f"model_stats/{model.name}_{loss_fn}_{neg_samp}.pt")

In [None]:
def detach_loss(stats):
  return [loss.detach().cpu().numpy().item() for loss in stats]

def plot_train_val_loss(stats_dict):
  fig, ax = plt.subplots(1,1, figsize = (6, 4))
  train_loss = detach_loss(stats_dict["train"]["loss"])
  val_loss = detach_loss(stats_dict["val"]["loss"])
  idx = np.arange(0, len(train_loss), 1)
  ax.plot(idx, train_loss, label = "train")
  ax.plot(idx, val_loss, label = "val")
  ax.legend()
  plt.show()

In [None]:
# If you had to stop for whatever reason, you can always reload the stats here! (just uncomment and change to correct paths)
# lgc_stats = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_LGC_4_e64_nodes34810__BPR_random.pkl", "rb"))
# gat_stats = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_GAT_3_e64_nodes34810__BPR_random.pkl", "rb"))
# sage_stats = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_SAGE_3_e64_nodes34810__BPR_random.pkl", "rb"))
# lgc_stats_hard = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_LGC_4_e64_nodes34810__BPR_hard.pkl", "rb"))
# gat_stats_hard = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_GAT_3_e64_nodes34810__BPR_hard.pkl", "rb"))
# sage_stats_hard = pickle.load(open(f"{MODEL_STATS_DIR}/LGCN_SAGE_3_e64_nodes34810__BPR_hard.pkl", "rb"))

In [None]:
plot_train_val_loss(lgc_stats)
plot_train_val_loss(lgc_stats_hard)

In [None]:
plot_train_val_loss(sage_stats)
plot_train_val_loss(sage_stats_hard)

In [None]:
plot_train_val_loss(gat_stats)
plot_train_val_loss(gat_stats_hard)

In [None]:
fig, ax = plt.subplots(1,1, figsize = (8, 6))
key = "loss"
lgc_loss = pd.Series(detach_loss(lgc_stats["val"][key])).rolling(3).mean()
gat_loss = pd.Series(detach_loss(gat_stats["val"][key])).rolling(3).mean()
sage_loss = pd.Series(detach_loss(sage_stats["val"][key])).rolling(3).mean()
idx = np.arange(0, len(lgc_loss), 1)

colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
ax.plot(idx, lgc_loss, color = colors[0], linestyle = 'dashed', label = "LGC - random")
ax.plot(idx, gat_loss, color = colors[1], linestyle = 'dashed', label = "GAT - random")
ax.plot(idx, sage_loss, color = colors[2], linestyle = 'dashed', label = "SAGE - random")
ax.legend(loc = 'lower right')

ax.set_xlabel("Epochs")
ax.set_ylabel("BPR Loss")
ax.set_title("Model BPR Loss, by convolution type and negative sampling")
ax.set_ylim(0, 0.7)
plt.show()

In [None]:
fig, ax = plt.subplots(1,1, figsize = (8, 6))
key = "loss"
lgc_loss = pd.Series(detach_loss(lgc_stats["val"][key])).rolling(3).mean()
gat_loss = pd.Series(detach_loss(gat_stats["val"][key])).rolling(3).mean()
sage_loss = pd.Series(detach_loss(sage_stats["val"][key])).rolling(3).mean()
lgc_hard_loss = pd.Series(detach_loss(lgc_stats_hard["val"][key])).rolling(3).mean()
gat_hard_loss = pd.Series(detach_loss(gat_stats_hard["val"][key])).rolling(3).mean()
sage_hard_loss = pd.Series(detach_loss(sage_stats_hard["val"][key])).rolling(3).mean()
idx = np.arange(0, len(lgc_loss), 1)

colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
ax.plot(idx, lgc_loss, color = colors[0], linestyle = 'dashed', label = "LGC - random")
ax.plot(idx, lgc_hard_loss, color = colors[0], label = "LGC - hard")
ax.plot(idx, gat_loss, color = colors[1], linestyle = 'dashed', label = "GAT - random")
ax.plot(idx, gat_hard_loss, color = colors[1], label = "GAT - hard")
ax.plot(idx, sage_loss, color = colors[2], linestyle = 'dashed', label = "SAGE - random")
ax.plot(idx, sage_hard_loss, color = colors[2], label = "SAGE - hard")
ax.legend(loc = 'lower left')

ax.set_xlabel("Epochs")
ax.set_ylabel("BPR Loss")
ax.set_title("Model BPR Loss, by convolution type and negative sampling")
ax.set_ylim(0, 0.7)
plt.show()

In [None]:
fig, ax = plt.subplots(1,1, figsize = (8, 6))
key = "recall"
lgc_recall = lgc_stats["val"][key]
gat_recall = gat_stats["val"][key]
sage_recall = sage_stats["val"][key]
lgc_hard_recall = lgc_stats_hard["val"][key]
gat_hard_recall = gat_stats_hard["val"][key]
sage_hard_recall = sage_stats_hard["val"][key]
# increment by 10
idx = np.arange(0, 10 * len(lgc_recall), 10)

colors = ["#1f77b4", "#ff7f0e", "#2ca02c"]
ax.plot(idx, lgc_recall, color = colors[0], linestyle = 'dashed', label = "LGC - random")
ax.plot(idx, lgc_hard_recall, color = colors[0], label = "LGC - hard")
ax.plot(idx, gat_recall, color = colors[1], linestyle = 'dashed', label = "GAT - random")
ax.plot(idx, gat_hard_recall, color = colors[1], label = "GAT - hard")
ax.plot(idx, sage_recall, color = colors[2], linestyle = 'dashed', label = "SAGE - random")
ax.plot(idx, sage_hard_recall, color = colors[2], label = "SAGE - hard")
ax.legend(loc = 'lower right')

ax.set_xlabel("Epochs")
ax.set_ylabel("Recall@300")
ax.set_title("Model Recall@300, by convolution type and negative sampling")
ax.set_ylim(0, 0.7)
plt.show()