In [1]:
import torch
from torch import Tensor
from torch.nn import Linear, ReLU, Sequential
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import HeteroConv, GATConv, SAGEConv, Linear
from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptPairTensor, Adj, Size
import torch_geometric.transforms as T
from torch.utils.data import SubsetRandomSampler
import os.path as path
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import tqdm
import copy
import requests

In [2]:
data_folder = "ds/"
model_name = "main_mb_test"
year = 2019
month = 11
perc = 0.9
latest_epoch = 0
train_hd = f"train_hdmb_{year}_{month}_{perc}.pt"
# train_hd = f"train_hd_{year}_{month}_{perc}.pt"
# train_hd = f"train_hd_nomatch_{year}_{month}_{perc}.pt"

In [3]:
data = torch.load(path.join(data_folder, train_hd))

data.validate()

  data = torch.load(path.join(data_folder, train_hd))


True

In [4]:
artist_channels = data["artist"].x.size(1)
track_channels = data["track"].x.size(1)
tag_channels = data["tag"].x.size(1)

print(f"Artist channels: {artist_channels}")
print(f"Track channels: {track_channels}")
print(f"Tag channels: {tag_channels}")

Artist channels: 17
Track channels: 5
Tag channels: 24


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

data["artist", "collab_with", "artist"].edge_index = data["artist", "collab_with", "artist"].edge_index.contiguous()

print(f"Device: '{device}'")

Device: 'cuda'


In [6]:
compt_tree_size = [25, 20]

edge_indices = torch.arange(data["artist", "collab_with", "artist"].edge_index.shape[1])

# Shuffle and split
num_edges = len(edge_indices)
perm = torch.randperm(num_edges)
split_idx = int(0.8 * num_edges)

train_sampler = SubsetRandomSampler(perm[:split_idx])
val_sampler = SubsetRandomSampler(perm[split_idx:])

print("Creating train_loader...")
train_loader = LinkNeighborLoader(
    data=data,
    num_neighbors=compt_tree_size,
    neg_sampling_ratio=1,
    edge_label_index=("artist", "collab_with", "artist"),
    batch_size=64,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
    sampler=train_sampler,
)

print("Creating val loader...")
val_loader = LinkNeighborLoader(
    data=data,
    num_neighbors=compt_tree_size,
    neg_sampling_ratio=1,
    edge_label_index=("artist", "collab_with", "artist"),
    batch_size=64,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
    sampler=val_sampler,
)

print("Number of train batches:", len(train_loader))
print("Number of validation batches:", len(val_loader))

Creating train_loader...
Creating val loader...
Number of train batches: 22067
Number of validation batches: 5517


In [7]:
class GNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels):
        super().__init__()
        self.metadata = metadata
        self.out_channels = out_channels

        self.conv1 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((artist_channels, artist_channels), hidden_channels, heads=3, concat=False),
            ("artist", "has_tag_artists", "tag"): SAGEConv((artist_channels, tag_channels), hidden_channels, normalize=True, project=True),
            ("artist", "last_fm_match", "artist"): GATConv((artist_channels, artist_channels), hidden_channels, heads=3, concat=False),
            ("track", "has_tag_tracks", "tag"): SAGEConv((track_channels, tag_channels), hidden_channels, normalize=True, project=True),
            ("artist", "linked_to", "artist"): GATConv((artist_channels, artist_channels), hidden_channels, heads=3, concat=False),
            ("artist", "musically_related_to", "artist"): GATConv((artist_channels, artist_channels), hidden_channels, heads=3, concat=False),
            ("artist", "personally_related_to", "artist"): GATConv((artist_channels, artist_channels), hidden_channels, heads=3, concat=False),
            ("tag", "tags_artists", "artist"): SAGEConv((tag_channels, artist_channels), hidden_channels, normalize=True, project=True),
            ("tag", "tags_tracks", "track"): SAGEConv((tag_channels, track_channels), hidden_channels, normalize=True, project=True),
            ("track", "worked_by", "artist"): SAGEConv((track_channels, artist_channels), hidden_channels, normalize=True, project=True),
            ("artist", "worked_in", "track"): SAGEConv((artist_channels, track_channels), hidden_channels, normalize=True, project=True),
        }, aggr="mean")

        self.conv2 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((hidden_channels, hidden_channels), hidden_channels, heads=3, concat=False),
            ("artist", "has_tag_artists", "tag"): SAGEConv((hidden_channels, hidden_channels), hidden_channels, normalize=True, project=True),
            ("artist", "last_fm_match", "artist"): GATConv((hidden_channels, hidden_channels), hidden_channels, heads=3, concat=False),
            ("track", "has_tag_tracks", "tag"): SAGEConv((hidden_channels, hidden_channels), hidden_channels, normalize=True, project=True),
            ("artist", "linked_to", "artist"): GATConv((hidden_channels, hidden_channels), hidden_channels, heads=3, concat=False),
            ("artist", "musically_related_to", "artist"): GATConv((hidden_channels, hidden_channels), hidden_channels, heads=3, concat=False),
            ("artist", "personally_related_to", "artist"): GATConv((hidden_channels, hidden_channels), hidden_channels, heads=3, concat=False),
            ("tag", "tags_artists", "artist"): SAGEConv((hidden_channels, hidden_channels), hidden_channels, normalize=True, project=True),
            ("tag", "tags_tracks", "track"): SAGEConv((hidden_channels, hidden_channels), hidden_channels, normalize=True, project=True),
            ("track", "worked_by", "artist"): SAGEConv((hidden_channels, hidden_channels), hidden_channels, normalize=True, project=True),
            ("artist", "worked_in", "track"): SAGEConv((hidden_channels, hidden_channels), hidden_channels, normalize=True, project=True),
        }, aggr="mean")

        self.linear1 = Linear(hidden_channels * 2, hidden_channels * 4)
        self.linear2 = Linear(hidden_channels * 4, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict1 = self.conv1(x_dict, edge_index_dict)
        x_dict2 = self.conv2(x_dict1, edge_index_dict)

        x_artist = torch.cat([x_dict1['artist'], x_dict2['artist']], dim=-1)

        x_artist = self.linear1(x_artist)
        x_artist = F.relu(x_artist)
        x_artist = self.linear2(x_artist)

        # Normalize the artist node features
        x_artist = F.normalize(x_artist, p=2, dim=-1)

        # Update the dictionary with the new 'artist' features, leaving other nodes unchanged
        x_dict['artist'] = x_artist

        return x_dict

In [None]:
def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs, patience=5):
    best_val_f1 = 0.0
    best_threshold = 0
    epochs_no_improve = 0
    best_model_state = None
    train_losses = list()
    val_losses = list()
    best_epoch = 0

    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        epoch_loss = 0.0
        
        for sampled_data in tqdm.tqdm(train_loader):
            # Move data to device
            sampled_data = sampled_data.to(device)
            
            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
            
            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

            src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
            dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
            
            # Compute the dot product between source and destination embeddings
            preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge
            
            # Compute loss
            loss = criterion(preds, edge_label.float())
            epoch_loss += loss.item()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Average loss for the epoch
        epoch_loss /= len(train_loader)
        train_losses.append(epoch_loss)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}")

        print("Computing validation metrics")
        
        # Validation metrics
        model.eval()  # Set model to evaluation mode
        all_labels = []
        all_probs = []
        val_loss = 0.0
        
        with torch.no_grad():  # Disable gradient computation for validation
            for sampled_data in tqdm.tqdm(val_loader):
                # Move data to device
                sampled_data = sampled_data.to(device)

                # Forward pass
                pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)

                # Get predictions and labels for the 'collab_with' edge type
                edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
                edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

                src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
                dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings

                # Compute the dot product between source and destination embeddings
                preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge

                loss = criterion(preds, edge_label.float())
                val_loss += loss.item()

                probs = torch.sigmoid(preds)  # Convert to probabilities

                # Collect predictions, probabilities, and labels
                all_labels.append(edge_label.cpu())
                all_probs.append(probs.cpu())
        
        # Concatenate all predictions and labels
        all_labels = torch.cat(all_labels)
        all_probs = torch.cat(all_probs)

        val_loss /= len(val_loader)
        val_losses.append(val_loss)

        # Find threshold for predictions
        print("Looking for threshold")
        best_threshold_epoch = 0
        best_f1_epoch = 0
        for threshold in tqdm.tqdm(np.arange(0.2, 0.91, 0.01)):
            preds_binary = (all_probs > threshold).long()
            cm = confusion_matrix(all_labels, preds_binary)
            tp = cm[1, 1]
            fp = cm[0, 1]
            fn = cm[1, 0]
            tn = cm[0, 0]
            precision = 0 if tp == 0 else tp / (tp + fp)
            recall = 0 if tp == 0 else tp / (tp + fn)
            f1 = 0 if precision * recall == 0 else 2 * precision * recall / (precision + recall)
            if f1 > best_f1_epoch:
                best_threshold_epoch = threshold
                best_f1_epoch = f1
        print(f"Best threshold: {best_threshold_epoch}")
        all_preds = (all_probs > best_threshold_epoch).long()
        
        # Compute metrics
        cm = confusion_matrix(all_labels, all_preds)
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]
        accuracy = (tp + tn) / (tp + fp + fn + tn)
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        roc_auc = roc_auc_score(all_labels, all_probs)
        
        # Print validation metrics
        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss:      {val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

        new_row = {
            "model": model_name,
            "year": year,
            "month": month,
            "perc": perc,
            "epoch": latest_epoch + epoch + 1,
            "train_loss": epoch_loss,
            "val_loss": val_loss,
            "acc": accuracy,
            "prec": precision,
            "rec": recall,
            "f1": f1,
            "auc": roc_auc,
            "tp": int(tp),
            "fp": int(fp),
            "fn": int(fn),
            "tn": int(tn),
            "best_threshold": best_threshold_epoch,
            "done": False
        }
        url = "http://localhost:5000/save_results"
        response = requests.post(url, json=new_row)
        assert response.status_code == 200

        torch.save(model.state_dict(), f"./model_{model_name}_{year}_{month}_{perc}_{latest_epoch + epoch + 1}.pth")

        if f1 > best_val_f1:
            best_val_f1 = f1
            best_threshold = best_threshold_epoch
            epochs_no_improve = 0
            best_model_state = copy.deepcopy(model.state_dict())
            best_epoch = latest_epoch + epoch + 1
        else:
            epochs_no_improve += 1
            if epochs_no_improve == patience:
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print(f"Early stopping!!!")
                print("Best epoch:", best_epoch)
                model.load_state_dict(best_model_state)
                break

    return best_threshold, train_losses, val_losses


In [None]:
model = GNN(metadata=data.metadata(), hidden_channels=64, out_channels=64).to(device)

if latest_epoch > 0:
    model.load_state_dict(torch.load(f"./model_{model_name}_{year}_{month}_{perc}_{latest_epoch}.pth"))

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

best_threshold, train_losses, val_losses = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    F.binary_cross_entropy_with_logits,
    device,
    100
)


  0%|          | 0/22067 [00:02<?, ?it/s]


Epoch 1/100, Training Loss: 0.0000
Computing validation metrics


  0%|          | 0/5517 [00:01<?, ?it/s]


Looking for threshold


100%|██████████| 71/71 [00:00<00:00, 1108.39it/s]


Best threshold: 0.5700000000000003
Validation Metrics - Epoch 1/100:
Loss:      0.0001
Accuracy:  0.6562
Precision: 0.6000
Recall:    0.9375
F1-score:  0.7317
ROC-AUC:   0.6016
Confusion Matrix:
60 4
40 24


  0%|          | 0/22067 [00:00<?, ?it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7bb8d029f560>
Traceback (most recent call last):
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__
    self._shutdown_workers()
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1441, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aleferu/miniforge3/envs/musicbrainz/lib/python3.12/multiprocessing/connection.

Epoch 2/100, Training Loss: 0.0000
Computing validation metrics


  0%|          | 0/5517 [00:01<?, ?it/s]


Looking for threshold


100%|██████████| 71/71 [00:00<00:00, 1307.47it/s]


Best threshold: 0.5800000000000003
Validation Metrics - Epoch 2/100:
Loss:      0.0001
Accuracy:  0.7500
Precision: 0.6702
Recall:    0.9844
F1-score:  0.7975
ROC-AUC:   0.6920
Confusion Matrix:
63 1
31 33


  0%|          | 0/22067 [00:02<?, ?it/s]


Epoch 3/100, Training Loss: 0.0000
Computing validation metrics


  0%|          | 0/5517 [00:01<?, ?it/s]


Looking for threshold


100%|██████████| 71/71 [00:00<00:00, 1298.10it/s]


Best threshold: 0.5700000000000003
Validation Metrics - Epoch 3/100:
Loss:      0.0001
Accuracy:  0.7500
Precision: 0.6702
Recall:    0.9844
F1-score:  0.7975
ROC-AUC:   0.6699
Confusion Matrix:
63 1
31 33


  0%|          | 0/22067 [00:02<?, ?it/s]


Epoch 4/100, Training Loss: 0.0000
Computing validation metrics


  0%|          | 0/5517 [00:01<?, ?it/s]


Looking for threshold


100%|██████████| 71/71 [00:00<00:00, 1278.20it/s]


Best threshold: 0.5600000000000003
Validation Metrics - Epoch 4/100:
Loss:      0.0001
Accuracy:  0.7734
Precision: 0.6882
Recall:    1.0000
F1-score:  0.8153
ROC-AUC:   0.7074
Confusion Matrix:
64 0
29 35


  0%|          | 0/22067 [00:00<?, ?it/s]

In [None]:
print("BEST THRESHOLD:", best_threshold)

BEST THRESHOLD: 0.7300000000000004


In [None]:
torch.save(model.state_dict(), f"./model_{model_name}_{year}_{month}_{perc}.pth")