In [2]:
import torch
from torch_geometric.data import HeteroData
import os.path as path
import pandas as pd
import numpy as np
from tqdm import tqdm

In [3]:
data_folder = "ds/"

In [4]:
data = HeteroData()

In [5]:
data["artist"].x = torch.load(path.join(data_folder, "artists.pt"), weights_only=True)
print("Artist tensor shape:", data["artist"].x.shape)

data["track"].x = torch.load(path.join(data_folder, "tracks.pt"), weights_only=True)
print("Track tensor shape:", data["track"].x.shape)

data["tag"].x = torch.load(path.join(data_folder, "tags.pt"), weights_only=True)
print("Tag tensor shape:", data["tag"].x.shape)


data["artist", "collab_with", "artist"].edge_index = torch.load(path.join(data_folder, "collab_with.pt"), weights_only=True)
data["artist", "collab_with", "artist"].edge_attr = torch.load(path.join(data_folder, "collab_with_attr.pt"), weights_only=True)
print("collab_with index tensor shape:", data["artist", "collab_with", "artist"].edge_index.shape)
print("collab_with attr tensor shape:", data["artist", "collab_with", "artist"].edge_attr.shape)

data["artist", "has_tag_artists", "tag"].edge_index = torch.load(path.join(data_folder, "has_tag_artists.pt"), weights_only=True)
data["track", "has_tag_tracks", "tag"].edge_index = torch.load(path.join(data_folder, "has_tag_tracks.pt"), weights_only=True)
print("has_tag_artists index tensor shape:", data["artist", "has_tag_artists", "tag"].edge_index.shape)
print("has_tag_tracks index tensor shape:", data["track", "has_tag_tracks", "tag"].edge_index.shape)

data["artist", "last_fm_match", "artist"].edge_index = torch.load(path.join(data_folder, "last_fm_match.pt"), weights_only=True)
data["artist", "last_fm_match", "artist"].edge_attr = torch.load(path.join(data_folder, "last_fm_match_attr.pt"), weights_only=True)
print("last_fm_match index tensor shape:", data["artist", "last_fm_match", "artist"].edge_index.shape)
print("last_fm_match attr tensor shape:", data["artist", "last_fm_match", "artist"].edge_attr.shape)

data["artist", "linked_to", "artist"].edge_index = torch.load(path.join(data_folder, "linked_to.pt"), weights_only=True)
data["artist", "linked_to", "artist"].edge_attr = torch.load(path.join(data_folder, "linked_to_attr.pt"), weights_only=True)
print("linked_to index tensor shape:", data["artist", "linked_to", "artist"].edge_index.shape)
print("linked_to attr tensor shape:", data["artist", "linked_to", "artist"].edge_attr.shape)

data["artist", "musically_related_to", "artist"].edge_index = torch.load(path.join(data_folder, "musically_related_to.pt"), weights_only=True)
data["artist", "musically_related_to", "artist"].edge_attr = torch.load(path.join(data_folder, "musically_related_to_attr.pt"), weights_only=True)
print("musically_related_to index tensor shape:", data["artist", "musically_related_to", "artist"].edge_index.shape)
print("musically_related_to attr tensor shape:", data["artist", "musically_related_to", "artist"].edge_attr.shape)

data["artist", "personally_related_to", "artist"].edge_index = torch.load(path.join(data_folder, "personally_related_to.pt"), weights_only=True)
data["artist", "personally_related_to", "artist"].edge_attr = torch.load(path.join(data_folder, "personally_related_to_attr.pt"), weights_only=True)
print("personally_related_to index tensor shape:", data["artist", "personally_related_to", "artist"].edge_index.shape)
print("personally_related_to attr tensor shape:", data["artist", "personally_related_to", "artist"].edge_attr.shape)

data["tag", "tags_artists", "artist"].edge_index = torch.load(path.join(data_folder, "tags_artists.pt"), weights_only=True)
data["tag", "tags_track", "track"].edge_index = torch.load(path.join(data_folder, "tags_tracks.pt"), weights_only=True)
print("tags_artists index tensor shape:", data["tag", "tags_artists", "artist"].edge_index.shape)
print("tags_tracks index tensor shape:", data["tag", "tags_track", "track"].edge_index.shape)

data["track", "worked_by", "artist"].edge_index = torch.load(path.join(data_folder, "worked_by.pt"), weights_only=True)
data["artist", "worked_in", "track"].edge_index = torch.load(path.join(data_folder, "worked_in.pt"), weights_only=True)
print("worked_by index tensor shape:", data["track", "worked_by", "artist"].edge_index.shape)
print("worked_in index tensor shape:", data["artist", "worked_in", "track"].edge_index.shape)

print()

data.validate()

Artist tensor shape: torch.Size([1489250, 16])
Track tensor shape: torch.Size([24324100, 4])
Tag tensor shape: torch.Size([23, 24])
collab_with index tensor shape: torch.Size([2, 2463052])
collab_with attr tensor shape: torch.Size([2463052, 1])
has_tag_artists index tensor shape: torch.Size([2, 2410207])
has_tag_tracks index tensor shape: torch.Size([2, 4030735])
last_fm_match index tensor shape: torch.Size([2, 154865250])
last_fm_match attr tensor shape: torch.Size([154865250, 1])
linked_to index tensor shape: torch.Size([2, 23128])
linked_to attr tensor shape: torch.Size([23128, 1])
musically_related_to index tensor shape: torch.Size([2, 373262])
musically_related_to attr tensor shape: torch.Size([373262, 1])
personally_related_to index tensor shape: torch.Size([2, 26720])
personally_related_to attr tensor shape: torch.Size([26720, 1])
tags_artists index tensor shape: torch.Size([2, 2410207])
tags_tracks index tensor shape: torch.Size([2, 4030735])
worked_by index tensor shape: torch

True

In [6]:
# OPTIONAL SUBGRAPH

edge_types = [
    ("artist", "collab_with", "artist"),
    ("artist", "has_tag_artists", "tag"),
    ("track", "has_tag_tracks", "tag"),
    ("artist", "last_fm_match", "artist"),
    ("artist", "linked_to", "artist"),
    ("artist", "musically_related_to", "artist"),
    ("artist", "personally_related_to", "artist"),
    ("tag", "tags_artists", "artist"),
    ("tag", "tags_track", "track"),
    ("track", "worked_by", "artist"),
    ("artist", "worked_in", "track")
]

if False:

    # Data
    percentile = 0.85
    artist_popularity = data["artist"].x[:, 8]

    # Threshold obtention
    threshold = torch.quantile(artist_popularity, percentile)
    selected_artists = artist_popularity >= threshold
    selected_artist_ids = torch.nonzero(selected_artists).squeeze()

    # Mapping
    old_to_new_artist_idx = torch.zeros(
        data["artist"].x.shape[0],
        dtype=torch.long
    )

    for i, selected_artist_id in enumerate(selected_artist_ids):
        old_to_new_artist_idx[selected_artist_id] = i

    # Subgraph
    for edge_type in edge_types:
        print(f"edge_type: {edge_type}")
        # Filter edge indices
        edge_index = data[edge_type].edge_index
        mask = torch.ones(edge_index.shape[1], dtype=torch.bool)
        if edge_type[0] == "artist":
            mask &= torch.isin(edge_index[0], selected_artist_ids)
        if edge_type[2] == "artist":
            mask &= torch.isin(edge_index[1], selected_artist_ids)

        filtered_edge_index = edge_index[:, mask]

        # Map the old indices to new ones for 'artist' nodes
        # if edge_type[0] == "artist":  # Reindex source node
        #     filtered_edge_index[0] = torch.tensor(
        #         [old_to_new_artist_idx[idx.item()] for idx in filtered_edge_index[0]],
        #         dtype=torch.long,
        #     )
        # if edge_type[2] == "artist":  # Reindex destination node
        #     filtered_edge_index[1] = torch.tensor(
        #         [old_to_new_artist_idx[idx.item()] for idx in filtered_edge_index[1]],
        #         dtype=torch.long,
        #     )
        if edge_type[0] == "artist":  # Reindex source node
            filtered_edge_index[0] = old_to_new_artist_idx[filtered_edge_index[0]]
        if edge_type[2] == "artist":  # Reindex destination node
            filtered_edge_index[1] = old_to_new_artist_idx[filtered_edge_index[1]]

        # Assign filtered edges to subgraph
        data[edge_type].edge_index = filtered_edge_index

        # Handle edge attributes if they exist
        if hasattr(data[edge_type], "edge_attr"):
            try:
                data[edge_type].edge_attr = data[edge_type].edge_attr[mask]
            except IndexError as e:
                print(f"IndexError for {edge_type}: {e}")
        else:
            print(f"No edge_attr for {edge_type}")

    # Nodes filtering
    data["artist"].x = data["artist"].x[selected_artist_ids]
    # data["track"].x = data["track"].x
    # data["tag"].x = data["tag"].x

# Check the shape of the filtered (or not) nodes and edges
for edge_type in edge_types:
    print(f"Edge type: {edge_type}, edge_index shape: {data[edge_type].edge_index.shape}")

# Check the artist features (should only have the selected artists)
print("Subgraph artist tensor shape:", data["artist"].x.shape)
print("Subgraph track tensor shape:", data["track"].x.shape)
print("Subgraph tag tensor shape:", data["tag"].x.shape)

print("\n")

# Validate the subgraph
try:
    data.validate()
    print("Validation successful.")
except ValueError as e:
    print("Validation failed:", e)

Edge type: ('artist', 'collab_with', 'artist'), edge_index shape: torch.Size([2, 2463052])
Edge type: ('artist', 'has_tag_artists', 'tag'), edge_index shape: torch.Size([2, 2410207])
Edge type: ('track', 'has_tag_tracks', 'tag'), edge_index shape: torch.Size([2, 4030735])
Edge type: ('artist', 'last_fm_match', 'artist'), edge_index shape: torch.Size([2, 154865250])
Edge type: ('artist', 'linked_to', 'artist'), edge_index shape: torch.Size([2, 23128])
Edge type: ('artist', 'musically_related_to', 'artist'), edge_index shape: torch.Size([2, 373262])
Edge type: ('artist', 'personally_related_to', 'artist'), edge_index shape: torch.Size([2, 26720])
Edge type: ('tag', 'tags_artists', 'artist'), edge_index shape: torch.Size([2, 2410207])
Edge type: ('tag', 'tags_track', 'track'), edge_index shape: torch.Size([2, 4030735])
Edge type: ('track', 'worked_by', 'artist'), edge_index shape: torch.Size([2, 27661673])
Edge type: ('artist', 'worked_in', 'track'), edge_index shape: torch.Size([2, 27661

In [7]:
# Track list
cut_year = 2020
cut_month = 3

df = pd.read_csv("../data/year_month_track.csv")
df["track_ids"] = df.track_ids.apply(eval)
df.head()

Unnamed: 0,year,month,track_ids
0,2016,12,"[0, 195, 350, 366, 458, 749, 1014, 1352, 1552,..."
1,2000,6,"[1, 4005, 4028, 4935, 9400, 9504, 9717, 12368,..."
2,2013,4,"[2, 199, 551, 670, 1136, 1300, 1519, 2253, 242..."
3,1997,13,"[3, 42, 68, 136, 418, 438, 541, 543, 619, 662,..."
4,2010,9,"[4, 348, 353, 578, 1532, 2345, 2358, 2479, 252..."


In [8]:
mask = (df["year"] < cut_year) | ((df["year"] == cut_year) & (df["month"] < cut_month))
train_tracks_neo4j = df[mask]["track_ids"].explode().unique().tolist()

In [9]:
import pickle

with open(path.join(data_folder, "track_map.pkl"), "rb") as in_file:
    track_map = pickle.load(in_file)

In [10]:
train_tracks_pyg = [track_map[track_id] for track_id in train_tracks_neo4j]
train_tracks_pyg_t = torch.tensor(train_tracks_pyg)
train_artists_pyg = data["artist", "worked_in", "track"].edge_index[0, :][
    torch.isin(data["artist", "worked_in", "track"].edge_index[1, :], train_tracks_pyg_t)
]
set_tracks_pyg = list(
    set(range(data["track"].x.shape[0])) - set(train_tracks_pyg)
)

In [11]:
train_data = data.subgraph({
    "track": train_tracks_pyg_t,
    "artist": train_artists_pyg
})

In [12]:
collab_with_edge_index = train_data["artist", "collab_with", "artist"].edge_index[:, ::2]
collab_with_edge_attr = train_data["artist", "collab_with", "artist"].edge_attr[::2]
n_collabs = collab_with_edge_attr.shape[0]
worked_in_edge_index = train_data["artist", "worked_in", "track"].edge_index

# Prepare a dictionary to quickly map artists to their tracks
print("Computing unique artists in collab_with")
unique_artists = torch.unique(torch.cat((collab_with_edge_index[0, :], collab_with_edge_index[1, :])))

print("Building dict...")
artist_tracks_dict = {}
# TODO: OPTIMIZE THIS
# FIXME
for artist in tqdm(unique_artists):
    artist_tracks_dict[artist.item()] = worked_in_edge_index[1, worked_in_edge_index[0, :] == artist]

# Collect the new collaboration edges
new_collab_with_edge_index = list()
new_collab_with_edge_attr = list()
print("Building lists...")
for i, (a0, a1) in enumerate(zip(collab_with_edge_index[0, :], collab_with_edge_index[1, :])):
    print(f"{i + 1} out of {n_collabs}")
    a0_item = a0.item()
    a1_item = a1.item()
    intersection_len = len(np.intersect1d(artist_tracks_dict[a0_item], artist_tracks_dict[a1_item]))
    if intersection_len > 0:
        new_collab_with_edge_index.append((a0_item, a1_item))
        new_collab_with_edge_index.append((a1_item, a0_item))
        new_collab_with_edge_attr.extend([intersection_len, intersection_len])

train_data["artist", "collab_with", "artist"].edge_index = new_collab_with_edge_index
train_data["artist", "collab_with", "artist"].edge_attr = new_collab_with_edge_attr

train_data.validate()

Computing unique artists in collab_with
Building dict...


  0%|          | 0/477152 [00:00<?, ?it/s]


AttributeError: 'EdgeStorage' object has no attribute 'get_neighbors'

In [6]:
import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.15,
    num_test=0.15,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=1,
    add_negative_train_samples=False,
    edge_types=("artist", "collab_with", "artist")
)

train_data, val_data, test_data = transform(data)

print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)
print()
print("Test data:")
print("================")
print(test_data)

print(f"Training edges: {train_data['artist', 'collab_with', 'artist'].edge_index.shape[1]}")
print(f"Validation edges: {val_data['artist', 'collab_with', 'artist'].edge_index.shape[1]}")
print(f"Test edges: {test_data['artist', 'collab_with', 'artist'].edge_index.shape[1]}")

train_edges = set(map(tuple, train_data['artist', 'collab_with', 'artist'].edge_index.T.tolist()))
val_edges = set(map(tuple, val_data['artist', 'collab_with', 'artist'].edge_index.T.tolist()))
test_edges = set(map(tuple, test_data['artist', 'collab_with', 'artist'].edge_index.T.tolist()))

print(f"Train-Val Overlap: {len(train_edges & val_edges)}")  # Should be 0
print(f"Train-Test Overlap: {len(train_edges & test_edges)}")  # Should be 0
print(f"Val-Test Overlap: {len(val_edges & test_edges)}")  # Should be 0

Training data:
HeteroData(
  artist={ x=[223388, 16] },
  track={ x=[24324100, 4] },
  tag={ x=[23, 24] },
  (artist, collab_with, artist)={
    edge_index=[2, 308377],
    edge_attr=[308377, 1],
    edge_label=[132161],
    edge_label_index=[2, 132161],
  },
  (artist, has_tag_artists, tag)={ edge_index=[2, 1042766] },
  (track, has_tag_tracks, tag)={ edge_index=[2, 4030735] },
  (artist, last_fm_match, artist)={
    edge_index=[2, 28357816],
    edge_attr=[28357816, 1],
  },
  (artist, linked_to, artist)={
    edge_index=[2, 1438],
    edge_attr=[1438, 1],
  },
  (artist, musically_related_to, artist)={
    edge_index=[2, 41760],
    edge_attr=[41760, 1],
  },
  (artist, personally_related_to, artist)={
    edge_index=[2, 3334],
    edge_attr=[3334, 1],
  },
  (tag, tags_artists, artist)={ edge_index=[2, 1042766] },
  (tag, tags_track, track)={ edge_index=[2, 4030735] },
  (track, worked_by, artist)={ edge_index=[2, 12509457] },
  (artist, worked_in, track)={ edge_index=[2, 12509457]

In [7]:
# # Get edge index
# edge_index = data['artist', 'collab_with', 'artist'].edge_index

# # Shuffle edges
# num_edges = edge_index.shape[1]
# perm = torch.randperm(num_edges)
# edge_index = edge_index[:, perm]

# # Define sizes
# num_test = int(0.15 * num_edges)
# num_val = int(0.15 * num_edges)
# num_train = num_edges - num_val - num_test

# # Split
# train_edges = edge_index[:, :num_train]
# val_edges = edge_index[:, num_train:num_train + num_val]
# test_edges = edge_index[:, num_train + num_val:]

# # Verify disjoint sets
# train_set = set(map(tuple, train_edges.T.tolist()))
# val_set = set(map(tuple, val_edges.T.tolist()))
# test_set = set(map(tuple, test_edges.T.tolist()))

# print(f"Train-Val Overlap: {len(train_set & val_set)}")  # Should be 0
# print(f"Train-Test Overlap: {len(train_set & test_set)}")  # Should be 0
# print(f"Val-Test Overlap: {len(val_set & test_set)}")  # Should be 0

# # Store the new edge splits in PyG format
# train_data = data.clone()
# train_data['artist', 'collab_with', 'artist'].edge_index = train_edges

# val_data = data.clone()
# val_data['artist', 'collab_with', 'artist'].edge_index = val_edges

# test_data = data.clone()
# test_data['artist', 'collab_with', 'artist'].edge_index = test_edges

# print(f"Training edges: {train_data['artist', 'collab_with', 'artist'].edge_index.shape[1]}")
# print(f"Validation edges: {val_data['artist', 'collab_with', 'artist'].edge_index.shape[1]}")
# print(f"Test edges: {test_data['artist', 'collab_with', 'artist'].edge_index.shape[1]}")

# from torch_geometric.utils import negative_sampling

# def create_edge_labels(data):
#     edge_index = data['artist', 'collab_with', 'artist'].edge_index
#     num_nodes = data['artist'].num_nodes

#     # Generate negative edges
#     neg_edge_index = negative_sampling(
#         edge_index=edge_index,
#         num_nodes=num_nodes,
#         num_neg_samples=edge_index.shape[1]  # 1:1 ratio of positive to negative samples
#     )

#     # Concatenate positive and negative edges
#     edge_label_index = torch.cat([edge_index, neg_edge_index], dim=1)
#     edge_label = torch.cat([torch.ones(edge_index.shape[1]), torch.zeros(neg_edge_index.shape[1])], dim=0)

#     return edge_label_index, edge_label

# # Apply to train, val, test
# train_edge_label_index, train_edge_label = create_edge_labels(train_data)
# val_edge_label_index, val_edge_label = create_edge_labels(val_data)
# test_edge_label_index, test_edge_label = create_edge_labels(test_data)

# # Assign to datasets
# train_data['artist', 'collab_with', 'artist'].edge_label_index = train_edge_label_index
# train_data['artist', 'collab_with', 'artist'].edge_label = train_edge_label

# val_data['artist', 'collab_with', 'artist'].edge_label_index = val_edge_label_index
# val_data['artist', 'collab_with', 'artist'].edge_label = val_edge_label

# test_data['artist', 'collab_with', 'artist'].edge_label_index = test_edge_label_index
# test_data['artist', 'collab_with', 'artist'].edge_label = test_edge_label


In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [9]:
from torch_geometric.loader import LinkNeighborLoader

compt_tree_size = [25, 20]

edge_label_index = train_data["artist", "collab_with", "artist"].edge_label_index
edge_label = train_data["artist", "collab_with", "artist"].edge_label

print("Creating train_loader...")
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=compt_tree_size,
    neg_sampling_ratio=1,
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=64,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

edge_label_index = val_data["artist", "collab_with", "artist"].edge_label_index
edge_label = val_data["artist", "collab_with", "artist"].edge_label

print("Creating val_loader...")
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=compt_tree_size,
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=64,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
)

print("Sampling mini-batch...")

sampled_data = next(iter(train_loader))

print("Sampled mini-batch:")
print("===================")
print(sampled_data)

Creating train_loader...




Creating val_loader...
Sampling mini-batch...
Sampled mini-batch:
HeteroData(
  artist={
    x=[90052, 16],
    n_id=[90052],
  },
  track={
    x=[119352, 4],
    n_id=[119352],
  },
  tag={
    x=[23, 24],
    n_id=[23],
  },
  (artist, collab_with, artist)={
    edge_index=[2, 35614],
    edge_attr=[35614, 1],
    edge_label=[128],
    edge_label_index=[2, 128],
    e_id=[35614],
    input_id=[64],
  },
  (artist, has_tag_artists, tag)={
    edge_index=[2, 460],
    e_id=[460],
  },
  (track, has_tag_tracks, tag)={
    edge_index=[2, 460],
    e_id=[460],
  },
  (artist, last_fm_match, artist)={
    edge_index=[2, 155658],
    edge_attr=[155658, 1],
    e_id=[155658],
  },
  (artist, linked_to, artist)={
    edge_index=[2, 159],
    edge_attr=[159, 1],
    e_id=[159],
  },
  (artist, musically_related_to, artist)={
    edge_index=[2, 3251],
    edge_attr=[3251, 1],
    e_id=[3251],
  },
  (artist, personally_related_to, artist)={
    edge_index=[2, 332],
    edge_attr=[332, 1],
    

In [10]:
debug = False
if debug:
    print(torch.unique(train_data['artist', 'collab_with', 'artist'].edge_label))
    print(torch.unique(next(iter(train_loader))["artist", "collab_with", "artist"].edge_label))
    print(torch.unique(val_data['artist', 'collab_with', 'artist'].edge_label))
    print(torch.unique(next(iter(val_loader))["artist", "collab_with", "artist"].edge_label))
    print(torch.unique(test_data['artist', 'collab_with', 'artist'].edge_label))
    print(torch.unique(next(iter(test_loader))["artist", "collab_with", "artist"].edge_label))


In [11]:
from torch_geometric.nn import HeteroConv, GATConv, SAGEConv, Linear
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels):
        super().__init__()
        self.metadata = metadata
        self.out_channels = out_channels

        self.conv1 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "has_tag_artists", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "last_fm_match", "artist"): GATConv((-1, -1), hidden_channels),
            ("track", "has_tag_tracks", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "linked_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "musically_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "personally_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("tag", "tags_artists", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("tag", "tags_tracks", "track"): SAGEConv((-1, -1), hidden_channels),
            ("track", "worked_by", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "worked_in", "track"): SAGEConv((-1, -1), hidden_channels),
        }, aggr="mean")

        self.conv2 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "has_tag_artists", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "last_fm_match", "artist"): GATConv((-1, -1), hidden_channels),
            ("track", "has_tag_tracks", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "linked_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "musically_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "personally_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("tag", "tags_artists", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("tag", "tags_tracks", "track"): SAGEConv((-1, -1), hidden_channels),
            ("track", "worked_by", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "worked_in", "track"): SAGEConv((-1, -1), hidden_channels),
        }, aggr="mean")

        self.linear1 = Linear(hidden_channels * 2, hidden_channels * 4)
        self.linear2 = Linear(hidden_channels * 4, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict1 = self.conv1(x_dict, edge_index_dict)
        x_dict2 = self.conv2(x_dict1, edge_index_dict)

        x_artist = torch.cat([x_dict1['artist'], x_dict2['artist']], dim=-1)

        x_artist = self.linear1(x_artist)
        x_artist = self.linear2(x_artist)

        # Normalize the artist node features
        x_artist = F.normalize(x_artist, p=2, dim=-1)

        # Update the dictionary with the new 'artist' features, leaving other nodes unchanged
        x_dict['artist'] = x_artist

        return x_dict

In [12]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import tqdm
import numpy as np

def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs):
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        epoch_loss = 0.0
        
        for sampled_data in tqdm.tqdm(train_loader):
            # Move data to device
            sampled_data = sampled_data.to(device)
            
            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
            
            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

            src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
            dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
            
            # Compute the dot product between source and destination embeddings
            preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge
            
            # Compute loss
            loss = criterion(preds, edge_label.float())
            epoch_loss += loss.item()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Average loss for the epoch
        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}")
        
        # Validation metrics
        model.eval()  # Set model to evaluation mode
        all_labels = []
        all_probs = []
        val_loss = 0.0
        
        with torch.no_grad():  # Disable gradient computation for validation
            for sampled_data in tqdm.tqdm(val_loader):
                # Move data to device
                sampled_data = sampled_data.to(device)
                
                # Forward pass
                pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
                
                # Get predictions and labels for the 'collab_with' edge type
                edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
                edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

                src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
                dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
                
                # Compute the dot product between source and destination embeddings
                preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge

                probs = torch.sigmoid(preds)  # Convert to probabilities

                loss = criterion(preds, edge_label.float())
                val_loss += loss.item()
                
                # Collect predictions, probabilities, and labels
                all_labels.append(edge_label.cpu())
                all_probs.append(probs.cpu())
        
        # Concatenate all predictions and labels
        all_labels = torch.cat(all_labels)
        all_probs = torch.cat(all_probs)
        val_loss /= len(val_loader)

        # Find threshold for predictions
        best_threshold = 0
        best_f1 = 0
        for threshold in np.arange(0.2, 0.81, 0.01):
            preds_binary = (all_probs > threshold).long()
            cm = confusion_matrix(all_labels, preds_binary)
            tp = cm[1, 1]
            fp = cm[0, 1]
            fn = cm[1, 0]
            tn = cm[0, 0]
            precision = 0 if tp == 0 else tp / (tp + fp)
            recall = 0 if tp == 0 else tp / (tp + fn)
            f1 = 0 if precision * recall == 0 else 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_threshold = threshold
                best_f1 = f1
        print(f"Best threshold: {best_threshold}")
        all_preds = (all_probs > best_threshold).long()
        
        # Compute metrics
        cm = confusion_matrix(all_labels, all_preds)
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]
        accuracy = (tp + tn) / (tp + fp + fn + tn)
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        roc_auc = roc_auc_score(all_labels, all_probs)
        
        # Print validation metrics
        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss:      {val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

    return best_threshold


In [13]:
def test_model(model, test_loader, criterion, device, threshold):
    model.eval()  # Set the model to evaluation mode
    all_preds = []
    all_labels = []
    all_probs = []
    test_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for sampled_data in tqdm.tqdm(test_loader):
            # Move data to the device
            sampled_data = sampled_data.to(device)

            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)

            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

            src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
            dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
            
            # Compute the dot product between source and destination embeddings
            preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge
            probs = torch.sigmoid(preds)  # Convert logits to probabilities
            preds_binary = (probs > threshold).long()  # Convert probabilities to binary predictions

            # Compute loss
            loss = criterion(preds, edge_label.float())
            test_loss += loss.item()

            # Collect predictions and labels
            all_preds.append(preds_binary.cpu())
            all_labels.append(edge_label.cpu())
            all_probs.append(probs.cpu())

    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    all_probs = torch.cat(all_probs)

    # Compute metrics
    cm = confusion_matrix(all_labels, all_preds)
    tp = cm[1, 1]
    fp = cm[0, 1]
    fn = cm[1, 0]
    tn = cm[0, 0]
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    roc_auc = roc_auc_score(all_labels, all_probs)

    # Average test loss
    test_loss /= len(test_loader)

    print("Test Results:")
    print(f"Loss:      {test_loss:.4f}")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")
    print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

In [14]:
model = GNN(metadata=train_data.metadata(), hidden_channels=64, out_channels=64).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_threshold = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    F.binary_cross_entropy_with_logits,
    device,
    20
)


100%|██████████| 2066/2066 [02:47<00:00, 12.33it/s]


Epoch 1/20, Training Loss: 0.5898


100%|██████████| 2951/2951 [02:18<00:00, 21.35it/s]


Best threshold: 0.6800000000000004
Validation Metrics - Epoch 1/20:
Loss:      0.5768
Accuracy:  0.7943
Precision: 0.7502
Recall:    0.8823
F1-score:  0.8109
ROC-AUC:   0.8223
Confusion Matrix:
83289 11112
27734 66667


100%|██████████| 2066/2066 [03:00<00:00, 11.44it/s]


Epoch 2/20, Training Loss: 0.5672


100%|██████████| 2951/2951 [02:20<00:00, 20.99it/s]


Best threshold: 0.6600000000000004
Validation Metrics - Epoch 2/20:
Loss:      0.5621
Accuracy:  0.8310
Precision: 0.7851
Recall:    0.9116
F1-score:  0.8436
ROC-AUC:   0.8867
Confusion Matrix:
86053 8348
23559 70842


100%|██████████| 2066/2066 [03:00<00:00, 11.46it/s]


Epoch 3/20, Training Loss: 0.5611


100%|██████████| 2951/2951 [02:21<00:00, 20.85it/s]


Best threshold: 0.7200000000000004
Validation Metrics - Epoch 3/20:
Loss:      0.5676
Accuracy:  0.8532
Precision: 0.8438
Recall:    0.8669
F1-score:  0.8552
ROC-AUC:   0.9029
Confusion Matrix:
81837 12564
15149 79252


 10%|█         | 211/2066 [00:22<03:13,  9.58it/s]


KeyboardInterrupt: 

In [14]:
# best_threshold = train(
#     model,
#     train_loader,
#     val_loader,
#     optimizer,
#     F.binary_cross_entropy_with_logits,
#     device,
#     20
# )

In [15]:
edge_label_index = test_data["artist", "collab_with", "artist"].edge_label_index
edge_label = test_data["artist", "collab_with", "artist"].edge_label

print("Creating test_loader...")
test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=compt_tree_size,
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=512,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
)

Creating test_loader...




In [17]:
test_model(
    model,
    test_loader,
    F.binary_cross_entropy_with_logits,
    device,
    best_threshold
)

100%|██████████| 369/369 [01:43<00:00,  3.55it/s]


Test Results:
Loss:      0.5527
Accuracy:  0.8295
Precision: 0.8379
Recall:    0.8171
F1-score:  0.8274
ROC-AUC:   0.8862
Confusion Matrix:
77136 17265
14925 79476


In [19]:
torch.save(model.state_dict(), "./normal-newdata.pth")

In [18]:
test = GNN(metadata=train_data.metadata(), out_channels=64).to(device)
test.load_state_dict(torch.load("./normal-newdata.pth"))
test_model(
    test,
    test_loader,
    F.binary_cross_entropy_with_logits,
    device,
    best_threshold
)

  test.load_state_dict(torch.load("./normal-nolfm.pth"))
100%|██████████| 615/615 [00:29<00:00, 21.13it/s]


Test Results:
Loss:      0.5740
Accuracy:  0.8026
Precision: 0.7512
Recall:    0.9051
F1-score:  0.8210
ROC-AUC:   0.8483
Confusion Matrix:
142398 14937
47169 110166
