In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.nn import HeteroConv, GATConv, SAGEConv, Linear
import os.path as path
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import tqdm

In [None]:
data_folder = "ds/"
train_hd = "train_hd_2020_3_0.1.pt"

In [3]:
data = torch.load(path.join(data_folder, train_hd))

data.validate()

  data = torch.load(path.join(data_folder, train_hd))


True

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [5]:
compt_tree_size = [25, 20]

print("Creating train_loader...")
train_loader = LinkNeighborLoader(
    data=data,
    num_neighbors=compt_tree_size,
    neg_sampling_ratio=1,
    edge_label_index=("artist", "collab_with", "artist"),
    batch_size=128,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

Creating train_loader...




In [6]:
class GNN(torch.nn.Module):
    def __init__(self, metadata, hidden_channels, out_channels):
        super().__init__()
        self.metadata = metadata
        self.out_channels = out_channels

        self.conv1 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "has_tag_artists", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "last_fm_match", "artist"): GATConv((-1, -1), hidden_channels),
            ("track", "has_tag_tracks", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "linked_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "musically_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "personally_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("tag", "tags_artists", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("tag", "tags_tracks", "track"): SAGEConv((-1, -1), hidden_channels),
            ("track", "worked_by", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "worked_in", "track"): SAGEConv((-1, -1), hidden_channels),
        }, aggr="mean")

        self.conv2 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "has_tag_artists", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "last_fm_match", "artist"): GATConv((-1, -1), hidden_channels),
            ("track", "has_tag_tracks", "tag"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "linked_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "musically_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("artist", "personally_related_to", "artist"): GATConv((-1, -1), hidden_channels),
            ("tag", "tags_artists", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("tag", "tags_tracks", "track"): SAGEConv((-1, -1), hidden_channels),
            ("track", "worked_by", "artist"): SAGEConv((-1, -1), hidden_channels),
            ("artist", "worked_in", "track"): SAGEConv((-1, -1), hidden_channels),
        }, aggr="mean")

        self.linear1 = Linear(hidden_channels * 2, hidden_channels * 4)
        self.linear2 = Linear(hidden_channels * 4, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict1 = self.conv1(x_dict, edge_index_dict)
        x_dict2 = self.conv2(x_dict1, edge_index_dict)

        x_artist = torch.cat([x_dict1['artist'], x_dict2['artist']], dim=-1)

        x_artist = self.linear1(x_artist)
        x_artist = self.linear2(x_artist)

        # Normalize the artist node features
        x_artist = F.normalize(x_artist, p=2, dim=-1)

        # Update the dictionary with the new 'artist' features, leaving other nodes unchanged
        x_dict['artist'] = x_artist

        return x_dict

In [7]:


def train(model, train_loader, optimizer, criterion, device, num_epochs, val_epochs):
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        epoch_loss = 0.0
        
        for sampled_data in tqdm.tqdm(train_loader):
            # Move data to device
            sampled_data = sampled_data.to(device)
            
            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
            
            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

            src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
            dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
            
            # Compute the dot product between source and destination embeddings
            preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge
            
            # Compute loss
            loss = criterion(preds, edge_label.float())
            epoch_loss += loss.item()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Average loss for the epoch
        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}")

        # Validation?
        if (epoch + 1) % val_epochs != 0:
            continue

        print("Computing validation metrics")
        
        # Validation metrics
        model.eval()  # Set model to evaluation mode
        all_labels = []
        all_probs = []
        
        with torch.no_grad():  # Disable gradient computation for validation
            for sampled_data in tqdm.tqdm(train_loader):
                # Move data to device
                sampled_data = sampled_data.to(device)
                
                # Forward pass
                pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
                
                # Get predictions and labels for the 'collab_with' edge type
                edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
                edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

                src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
                dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
                
                # Compute the dot product between source and destination embeddings
                preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge

                probs = torch.sigmoid(preds)  # Convert to probabilities
                
                # Collect predictions, probabilities, and labels
                all_labels.append(edge_label.cpu())
                all_probs.append(probs.cpu())
        
        # Concatenate all predictions and labels
        all_labels = torch.cat(all_labels)
        all_probs = torch.cat(all_probs)

        # Find threshold for predictions
        print("Looking for threshold")
        best_threshold = 0
        best_f1 = 0
        for threshold in tqdm.tqdm(np.arange(0.2, 0.81, 0.01)):
            preds_binary = (all_probs > threshold).long()
            cm = confusion_matrix(all_labels, preds_binary)
            tp = cm[1, 1]
            fp = cm[0, 1]
            fn = cm[1, 0]
            tn = cm[0, 0]
            precision = 0 if tp == 0 else tp / (tp + fp)
            recall = 0 if tp == 0 else tp / (tp + fn)
            f1 = 0 if precision * recall == 0 else 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_threshold = threshold
                best_f1 = f1
        print(f"Best threshold: {best_threshold}")
        all_preds = (all_probs > best_threshold).long()
        
        # Compute metrics
        cm = confusion_matrix(all_labels, all_preds)
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]
        accuracy = (tp + tn) / (tp + fp + fn + tn)
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        roc_auc = roc_auc_score(all_labels, all_probs)
        
        # Print validation metrics
        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

    return best_threshold


In [8]:
model = GNN(metadata=data.metadata(), hidden_channels=64, out_channels=64).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_threshold = train(
    model,
    train_loader,
    optimizer,
    F.binary_cross_entropy_with_logits,
    device,
    3,
    3
)


100%|██████████| 12516/12516 [28:33<00:00,  7.30it/s]


Epoch 1/3, Training Loss: 0.5424


100%|██████████| 12516/12516 [28:07<00:00,  7.42it/s]


Epoch 2/3, Training Loss: 0.5289


100%|██████████| 12516/12516 [27:56<00:00,  7.47it/s]


Epoch 3/3, Training Loss: 0.5248
Computing validation metrics


100%|██████████| 12516/12516 [21:51<00:00,  9.55it/s]


Looking for threshold


100%|██████████| 62/62 [06:21<00:00,  6.15s/it]


Best threshold: 0.6700000000000004
Validation Metrics - Epoch 3/3:
Accuracy:  0.9210
Precision: 0.8940
Recall:    0.9552
F1-score:  0.9236
ROC-AUC:   0.9560
Confusion Matrix:
1530207 71769
181472 1420504


In [None]:
torch.save(model.state_dict(), "./model_2020_3_0.1.pth")