In [1]:
import torch
from torch_geometric.data import HeteroData
import os.path as path

In [2]:
data_folder = "pyg_experiments/ds_float32/"

In [3]:
data = HeteroData()

In [4]:
data["artist"].x = torch.load(path.join(data_folder, "artists.pt"), weights_only=True)
print("Artist tensor shape:", data["artist"].x.shape)

data["track"].x = torch.load(path.join(data_folder, "tracks.pt"), weights_only=True)
print("Track tensor shape:", data["track"].x.shape)

data["tag"].x = torch.load(path.join(data_folder, "tags.pt"), weights_only=True)
print("Tag tensor shape:", data["tag"].x.shape)


data["artist", "collab_with", "artist"].edge_index = torch.load(path.join(data_folder, "collab_with.pt"), weights_only=True).t().long()
data["artist", "collab_with", "artist"].edge_attr = torch.load(path.join(data_folder, "collab_with_attr.pt"), weights_only=True)
print("collab_with index tensor shape:", data["artist", "collab_with", "artist"].edge_index.shape)
print("collab_with attr tensor shape:", data["artist", "collab_with", "artist"].edge_attr.shape)

data["artist", "has_tag_artists", "tag"].edge_index = torch.load(path.join(data_folder, "has_tag_artists.pt"), weights_only=True).t().long()
data["track", "has_tag_tracks", "tag"].edge_index = torch.load(path.join(data_folder, "has_tag_tracks.pt"), weights_only=True).t().long()
print("has_tag_artists index tensor shape:", data["artist", "has_tag_artists", "tag"].edge_index.shape)
print("has_tag_tracks index tensor shape:", data["track", "has_tag_tracks", "tag"].edge_index.shape)

# data["artist", "last_fm_match", "artist"].edge_index = torch.load(path.join(data_folder, "last_fm_match.pt"), weights_only=True).t().long()
# data["artist", "last_fm_match", "artist"].edge_attr = torch.load(path.join(data_folder, "last_fm_match_attr.pt"), weights_only=True)
# print("last_fm_match index tensor shape:", data["artist", "last_fm_match", "artist"].edge_index.shape)
# print("last_fm_match attr tensor shape:", data["artist", "last_fm_match", "artist"].edge_attr.shape)

data["artist", "linked_to", "artist"].edge_index = torch.load(path.join(data_folder, "linked_to.pt"), weights_only=True).t().long()
data["artist", "linked_to", "artist"].edge_attr = torch.load(path.join(data_folder, "linked_to_attr.pt"), weights_only=True)
print("linked_to index tensor shape:", data["artist", "linked_to", "artist"].edge_index.shape)
print("linked_to attr tensor shape:", data["artist", "linked_to", "artist"].edge_attr.shape)

data["artist", "musically_related_to", "artist"].edge_index = torch.load(path.join(data_folder, "musically_related_to.pt"), weights_only=True).t().long()
data["artist", "musically_related_to", "artist"].edge_attr = torch.load(path.join(data_folder, "musically_related_to_attr.pt"), weights_only=True)
print("musically_related_to index tensor shape:", data["artist", "musically_related_to", "artist"].edge_index.shape)
print("musically_related_to attr tensor shape:", data["artist", "musically_related_to", "artist"].edge_attr.shape)

data["artist", "personally_related_to", "artist"].edge_index = torch.load(path.join(data_folder, "personally_related_to.pt"), weights_only=True).t().long()
data["artist", "personally_related_to", "artist"].edge_attr = torch.load(path.join(data_folder, "personally_related_to_attr.pt"), weights_only=True)
print("personally_related_to index tensor shape:", data["artist", "personally_related_to", "artist"].edge_index.shape)
print("personally_related_to attr tensor shape:", data["artist", "personally_related_to", "artist"].edge_attr.shape)

data["tag", "tags_artists", "artist"].edge_index = torch.load(path.join(data_folder, "tags_artists.pt"), weights_only=True).t().long()
data["tag", "tags_track", "track"].edge_index = torch.load(path.join(data_folder, "tags_tracks.pt"), weights_only=True).t().long()
print("tags_artists index tensor shape:", data["tag", "tags_artists", "artist"].edge_index.shape)
print("tags_tracks index tensor shape:", data["tag", "tags_track", "track"].edge_index.shape)

data["track", "worked_by", "artist"].edge_index = torch.load(path.join(data_folder, "worked_by.pt"), weights_only=True).t().long()
data["artist", "worked_in", "track"].edge_index = torch.load(path.join(data_folder, "worked_in.pt"), weights_only=True).t().long()
print("worked_by index tensor shape:", data["track", "worked_by", "artist"].edge_index.shape)
print("worked_in index tensor shape:", data["artist", "worked_in", "track"].edge_index.shape)

Artist tensor shape: torch.Size([1489250, 16])
Track tensor shape: torch.Size([24324100, 4])
Tag tensor shape: torch.Size([23, 1])
collab_with index tensor shape: torch.Size([2, 2463052])
collab_with attr tensor shape: torch.Size([2463052, 1])
has_tag_artists index tensor shape: torch.Size([2, 2410207])
has_tag_tracks index tensor shape: torch.Size([2, 4030735])
linked_to index tensor shape: torch.Size([2, 23128])
linked_to attr tensor shape: torch.Size([23128, 1])
musically_related_to index tensor shape: torch.Size([2, 373262])
musically_related_to attr tensor shape: torch.Size([373262, 1])
personally_related_to index tensor shape: torch.Size([2, 26720])
personally_related_to attr tensor shape: torch.Size([26720, 1])
tags_artists index tensor shape: torch.Size([2, 2410207])
tags_tracks index tensor shape: torch.Size([2, 4030735])
worked_by index tensor shape: torch.Size([2, 27661673])
worked_in index tensor shape: torch.Size([2, 27661673])


In [5]:
import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.2,
    num_test=0.2,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2,
    add_negative_train_samples=False,
    edge_types=("artist", "collab_with", "artist"),
)

train_data, val_data, test_data = transform(data)

print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)
print()
print("Test data:")
print("================")
print(test_data)

Training data:
HeteroData(
  artist={ x=[1489250, 16] },
  track={ x=[24324100, 4] },
  tag={ x=[23, 1] },
  (artist, collab_with, artist)={
    edge_index=[2, 1034483],
    edge_attr=[1034483, 1],
    edge_label=[443349],
    edge_label_index=[2, 443349],
  },
  (artist, has_tag_artists, tag)={ edge_index=[2, 2410207] },
  (track, has_tag_tracks, tag)={ edge_index=[2, 4030735] },
  (artist, linked_to, artist)={
    edge_index=[2, 23128],
    edge_attr=[23128, 1],
  },
  (artist, musically_related_to, artist)={
    edge_index=[2, 373262],
    edge_attr=[373262, 1],
  },
  (artist, personally_related_to, artist)={
    edge_index=[2, 26720],
    edge_attr=[26720, 1],
  },
  (tag, tags_artists, artist)={ edge_index=[2, 2410207] },
  (tag, tags_track, track)={ edge_index=[2, 4030735] },
  (track, worked_by, artist)={ edge_index=[2, 27661673] },
  (artist, worked_in, track)={ edge_index=[2, 27661673] }
)

Validation data:
HeteroData(
  artist={ x=[1489250, 16] },
  track={ x=[24324100, 4] }

In [6]:
from torch_geometric.loader import LinkNeighborLoader

edge_label_index = train_data["artist", "collab_with", "artist"].edge_label_index
edge_label = train_data["artist", "collab_with", "artist"].edge_label

train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio= 2,
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

edge_label_index = val_data["artist", "collab_with", "artist"].edge_label_index
edge_label = val_data["artist", "collab_with", "artist"].edge_label

val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio= 2,
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

edge_label_index = test_data["artist", "collab_with", "artist"].edge_label_index
edge_label = test_data["artist", "collab_with", "artist"].edge_label

test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[20, 10],
    neg_sampling_ratio= 2,
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
)

sampled_data = next(iter(train_loader))

print("Sampled mini-batch:")
print("===================")
print(sampled_data)



Sampled mini-batch:
HeteroData(
  artist={
    x=[16366, 16],
    n_id=[16366],
  },
  track={
    x=[27523, 4],
    n_id=[27523],
  },
  tag={
    x=[23, 1],
    n_id=[23],
  },
  (artist, collab_with, artist)={
    edge_index=[2, 15612],
    edge_attr=[15612, 1],
    edge_label=[384],
    edge_label_index=[2, 384],
    e_id=[15612],
    input_id=[128],
  },
  (artist, has_tag_artists, tag)={
    edge_index=[2, 230],
    e_id=[230],
  },
  (track, has_tag_tracks, tag)={
    edge_index=[2, 230],
    e_id=[230],
  },
  (artist, linked_to, artist)={
    edge_index=[2, 216],
    edge_attr=[216, 1],
    e_id=[216],
  },
  (artist, musically_related_to, artist)={
    edge_index=[2, 3803],
    edge_attr=[3803, 1],
    e_id=[3803],
  },
  (artist, personally_related_to, artist)={
    edge_index=[2, 468],
    edge_attr=[468, 1],
    e_id=[468],
  },
  (tag, tags_artists, artist)={
    edge_index=[2, 8419],
    e_id=[8419],
  },
  (tag, tags_track, track)={
    edge_index=[2, 1179],
    e_id=[1

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [8]:
from torch_geometric.nn import HeteroConv, GATConv, SAGEConv
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "has_tag_artists", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("track", "has_tag_tracks", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "linked_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "musically_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "personally_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("tag", "tags_artists", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("tag", "tags_tracks", "track"): SAGEConv((-1, -1), hidden_channels),
            ("track", "worked_by", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "worked_in", "track"): SAGEConv((-1, -1), hidden_channels),
        }, aggr='sum')
        
        self.conv2 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "has_tag_artists", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("track", "has_tag_tracks", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "linked_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "musically_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "personally_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("tag", "tags_artists", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("tag", "tags_tracks", "track"): SAGEConv((-1, -1), hidden_channels),
            ("track", "worked_by", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "worked_in", "track"): SAGEConv((-1, -1), hidden_channels),
        }, aggr='sum')

        self.lin = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict)
        return {key: self.lin(x) for key, x in x_dict.items()}

In [19]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import tqdm

def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs):
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        epoch_loss = 0.0
        
        for sampled_data in tqdm.tqdm(train_loader):
            # Move data to device
            sampled_data = sampled_data.to(device)
            
            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
            
            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label
            
            preds = pred_dict['artist'][edge_label_index[0]].squeeze()
            
            # Compute loss
            loss = criterion(preds, edge_label.float())
            epoch_loss += loss.item()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Average loss for the epoch
        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}")
        
        # Validation metrics
        model.eval()  # Set model to evaluation mode
        all_preds = []
        all_labels = []
        all_probs = []
        val_loss = 0.0
        
        with torch.no_grad():  # Disable gradient computation for validation
            for sampled_data in tqdm.tqdm(val_loader):
                # Move data to device
                sampled_data = sampled_data.to(device)
                
                # Forward pass
                pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
                
                # Get predictions and labels for the 'collab_with' edge type
                edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
                edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label
                
                preds = pred_dict['artist'][edge_label_index[0]].squeeze()
                probs = torch.sigmoid(preds)  # Convert to probabilities
                preds_binary = (probs > 0.5).long()  # Convert to binary predictions

                loss = criterion(preds, edge_label.float())
                val_loss += loss.item()
                
                # Collect predictions, probabilities, and labels
                all_preds.append(preds_binary.cpu())
                all_labels.append(edge_label.cpu())
                all_probs.append(probs.cpu())
        
        # Concatenate all predictions and labels
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
        all_probs = torch.cat(all_probs)
        val_loss /= len(val_loader)
        
        # Compute metrics
        accuracy = accuracy_score(all_labels, all_preds)
        precision = precision_score(all_labels, all_preds, zero_division=0)
        recall = recall_score(all_labels, all_preds, zero_division=0)
        f1 = f1_score(all_labels, all_preds, zero_division=0)
        roc_auc = roc_auc_score(all_labels, all_probs)
        
        # Print validation metrics
        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss:      {val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")


In [23]:
def test_model(model, test_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode
    all_preds = []
    all_labels = []
    all_probs = []
    test_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for sampled_data in tqdm.tqdm(test_loader):
            # Move data to the device
            sampled_data = sampled_data.to(device)

            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)

            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

            preds = pred_dict['artist'][edge_label_index[0]].squeeze()
            probs = torch.sigmoid(preds)  # Convert logits to probabilities
            preds_binary = (probs > 0.5).long()  # Convert probabilities to binary predictions

            # Compute loss
            loss = criterion(preds, edge_label.float())
            test_loss += loss.item()

            # Collect predictions and labels
            all_preds.append(preds_binary.cpu())
            all_labels.append(edge_label.cpu())
            all_probs.append(probs.cpu())

    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    all_probs = torch.cat(all_probs)

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    roc_auc = roc_auc_score(all_labels, all_probs)

    # Average test loss
    test_loss /= len(test_loader)

    print("Test Results:")
    print(f"Loss:      {test_loss:.4f}")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")

In [20]:
model = GNN(metadata=train_data.metadata(), hidden_channels=64, out_channels=1).to(device)

train(
    model,
    train_loader,
    val_loader,
    torch.optim.Adam(model.parameters(), lr=0.001),
    F.binary_cross_entropy_with_logits,
    device,
    3
)

print("Training complete!")


100%|██████████| 3464/3464 [04:03<00:00, 14.24it/s]


Epoch 1/3, Training Loss: 7.3193


100%|██████████| 1732/1732 [03:41<00:00,  7.81it/s]


Validation Metrics - Epoch 1/3:
Loss:      1.6500
Accuracy:  0.7470
Precision: 0.6045
Recall:    0.6972
F1-score:  0.6475
ROC-AUC:   0.8129


100%|██████████| 3464/3464 [04:21<00:00, 13.26it/s]


Epoch 2/3, Training Loss: 1.2705


100%|██████████| 1732/1732 [03:55<00:00,  7.35it/s]


Validation Metrics - Epoch 2/3:
Loss:      0.6933
Accuracy:  0.8014
Precision: 0.7505
Recall:    0.6053
F1-score:  0.6701
ROC-AUC:   0.8407


100%|██████████| 3464/3464 [04:14<00:00, 13.62it/s]


Epoch 3/3, Training Loss: 0.6519


100%|██████████| 1732/1732 [03:44<00:00,  7.71it/s]


Validation Metrics - Epoch 3/3:
Loss:      0.4708
Accuracy:  0.7855
Precision: 0.7761
Recall:    0.5011
F1-score:  0.6090
ROC-AUC:   0.8540
Training complete!


In [21]:

train(
    model,
    train_loader,
    val_loader,
    torch.optim.Adam(model.parameters(), lr=0.001),
    F.binary_cross_entropy_with_logits,
    device,
    3
)

100%|██████████| 3464/3464 [04:05<00:00, 14.13it/s]


Epoch 1/3, Training Loss: 0.5458


100%|██████████| 1732/1732 [03:44<00:00,  7.73it/s]


Validation Metrics - Epoch 1/3:
Loss:      0.4591
Accuracy:  0.7914
Precision: 0.7928
Recall:    0.5066
F1-score:  0.6182
ROC-AUC:   0.8664


100%|██████████| 3464/3464 [04:01<00:00, 14.35it/s]


Epoch 2/3, Training Loss: 0.4655


100%|██████████| 1732/1732 [03:42<00:00,  7.79it/s]


Validation Metrics - Epoch 2/3:
Loss:      0.4414
Accuracy:  0.8009
Precision: 0.7185
Recall:    0.6621
F1-score:  0.6891
ROC-AUC:   0.8566


100%|██████████| 3464/3464 [04:00<00:00, 14.42it/s]


Epoch 3/3, Training Loss: 0.4513


100%|██████████| 1732/1732 [03:45<00:00,  7.67it/s]


Validation Metrics - Epoch 3/3:
Loss:      0.4281
Accuracy:  0.8034
Precision: 0.7257
Recall:    0.6596
F1-score:  0.6910
ROC-AUC:   0.8674


In [24]:
test_model(
    model,
    test_loader,
    F.binary_cross_entropy_with_logits,
    device
)

100%|██████████| 3464/3464 [04:40<00:00, 12.34it/s]


Test Results:
Loss:      0.4233
Accuracy:  0.8027
Precision: 0.7056
Recall:    0.7002
F1-score:  0.7029
ROC-AUC:   0.8726
