In [1]:
import torch
from torch_geometric.data import HeteroData
import os.path as path

In [2]:
data_folder = "ds_float32/"

In [3]:
data = HeteroData()

In [4]:
data["artist"].x = torch.load(path.join(data_folder, "artists.pt"), weights_only=True)
print("Artist tensor shape:", data["artist"].x.shape)

data["track"].x = torch.load(path.join(data_folder, "tracks.pt"), weights_only=True)
print("Track tensor shape:", data["track"].x.shape)

data["tag"].x = torch.load(path.join(data_folder, "tags.pt"), weights_only=True)
print("Tag tensor shape:", data["tag"].x.shape)


data["artist", "collab_with", "artist"].edge_index = torch.load(path.join(data_folder, "collab_with.pt"), weights_only=True).t().long()
data["artist", "collab_with", "artist"].edge_attr = torch.load(path.join(data_folder, "collab_with_attr.pt"), weights_only=True)
print("collab_with index tensor shape:", data["artist", "collab_with", "artist"].edge_index.shape)
print("collab_with attr tensor shape:", data["artist", "collab_with", "artist"].edge_attr.shape)

data["artist", "has_tag_artists", "tag"].edge_index = torch.load(path.join(data_folder, "has_tag_artists.pt"), weights_only=True).t().long()
data["track", "has_tag_tracks", "tag"].edge_index = torch.load(path.join(data_folder, "has_tag_tracks.pt"), weights_only=True).t().long()
print("has_tag_artists index tensor shape:", data["artist", "has_tag_artists", "tag"].edge_index.shape)
print("has_tag_tracks index tensor shape:", data["track", "has_tag_tracks", "tag"].edge_index.shape)

data["artist", "last_fm_match", "artist"].edge_index = torch.load(path.join(data_folder, "last_fm_match.pt"), weights_only=True).t().long()
data["artist", "last_fm_match", "artist"].edge_attr = torch.load(path.join(data_folder, "last_fm_match_attr.pt"), weights_only=True)
print("last_fm_match index tensor shape:", data["artist", "last_fm_match", "artist"].edge_index.shape)
print("last_fm_match attr tensor shape:", data["artist", "last_fm_match", "artist"].edge_attr.shape)

data["artist", "linked_to", "artist"].edge_index = torch.load(path.join(data_folder, "linked_to.pt"), weights_only=True).t().long()
data["artist", "linked_to", "artist"].edge_attr = torch.load(path.join(data_folder, "linked_to_attr.pt"), weights_only=True)
print("linked_to index tensor shape:", data["artist", "linked_to", "artist"].edge_index.shape)
print("linked_to attr tensor shape:", data["artist", "linked_to", "artist"].edge_attr.shape)

data["artist", "musically_related_to", "artist"].edge_index = torch.load(path.join(data_folder, "musically_related_to.pt"), weights_only=True).t().long()
data["artist", "musically_related_to", "artist"].edge_attr = torch.load(path.join(data_folder, "musically_related_to_attr.pt"), weights_only=True)
print("musically_related_to index tensor shape:", data["artist", "musically_related_to", "artist"].edge_index.shape)
print("musically_related_to attr tensor shape:", data["artist", "musically_related_to", "artist"].edge_attr.shape)

data["artist", "personally_related_to", "artist"].edge_index = torch.load(path.join(data_folder, "personally_related_to.pt"), weights_only=True).t().long()
data["artist", "personally_related_to", "artist"].edge_attr = torch.load(path.join(data_folder, "personally_related_to_attr.pt"), weights_only=True)
print("personally_related_to index tensor shape:", data["artist", "personally_related_to", "artist"].edge_index.shape)
print("personally_related_to attr tensor shape:", data["artist", "personally_related_to", "artist"].edge_attr.shape)

data["tag", "tags_artists", "artist"].edge_index = torch.load(path.join(data_folder, "tags_artists.pt"), weights_only=True).t().long()
data["tag", "tags_track", "track"].edge_index = torch.load(path.join(data_folder, "tags_tracks.pt"), weights_only=True).t().long()
print("tags_artists index tensor shape:", data["tag", "tags_artists", "artist"].edge_index.shape)
print("tags_tracks index tensor shape:", data["tag", "tags_track", "track"].edge_index.shape)

data["track", "worked_by", "artist"].edge_index = torch.load(path.join(data_folder, "worked_by.pt"), weights_only=True).t().long()
data["artist", "worked_in", "track"].edge_index = torch.load(path.join(data_folder, "worked_in.pt"), weights_only=True).t().long()
print("worked_by index tensor shape:", data["track", "worked_by", "artist"].edge_index.shape)
print("worked_in index tensor shape:", data["artist", "worked_in", "track"].edge_index.shape)

print()

data.validate()

Artist tensor shape: torch.Size([1489250, 16])
Track tensor shape: torch.Size([24324100, 4])
Tag tensor shape: torch.Size([23, 1])
collab_with index tensor shape: torch.Size([2, 2463052])
collab_with attr tensor shape: torch.Size([2463052, 1])
has_tag_artists index tensor shape: torch.Size([2, 2410207])
has_tag_tracks index tensor shape: torch.Size([2, 4030735])
last_fm_match index tensor shape: torch.Size([2, 154865250])
last_fm_match attr tensor shape: torch.Size([154865250, 1])
linked_to index tensor shape: torch.Size([2, 23128])
linked_to attr tensor shape: torch.Size([23128, 1])
musically_related_to index tensor shape: torch.Size([2, 373262])
musically_related_to attr tensor shape: torch.Size([373262, 1])
personally_related_to index tensor shape: torch.Size([2, 26720])
personally_related_to attr tensor shape: torch.Size([26720, 1])
tags_artists index tensor shape: torch.Size([2, 2410207])
tags_tracks index tensor shape: torch.Size([2, 4030735])
worked_by index tensor shape: torch.

True

In [5]:
# OPTIONAL SUBGRAPH

if True:

    # Data
    percentile = 0.85
    artist_popularity = data["artist"].x[:, 8]
    edge_types = [
        ("artist", "collab_with", "artist"),
        ("artist", "has_tag_artists", "tag"),
        ("track", "has_tag_tracks", "tag"),
        ("artist", "last_fm_match", "artist"),
        ("artist", "linked_to", "artist"),
        ("artist", "musically_related_to", "artist"),
        ("artist", "personally_related_to", "artist"),
        ("tag", "tags_artists", "artist"),
        ("tag", "tags_track", "track"),
        ("track", "worked_by", "artist"),
        ("artist", "worked_in", "track")
    ]

    # Threshold obtention
    threshold = torch.quantile(artist_popularity, percentile)
    selected_artists = artist_popularity >= threshold
    selected_artist_ids = torch.nonzero(selected_artists).squeeze()

    # Mapping
    old_to_new_artist_idx = {old: new for new, old in enumerate(selected_artist_ids.tolist())}

    # Subgraph
    subdata = HeteroData()
    for edge_type in edge_types:
        print(f"edge_type: {edge_type}")
        # Filter edge indices
        edge_index = data[edge_type].edge_index
        mask = torch.ones(edge_index.shape[1], dtype=torch.bool)
        if edge_type[0] == "artist":
            mask &= torch.isin(edge_index[0], selected_artist_ids)
        if edge_type[2] == "artist":
            mask &= torch.isin(edge_index[1], selected_artist_ids)

        filtered_edge_index = edge_index[:, mask]

        # Map the old indices to new ones for 'artist' nodes
        if edge_type[0] == "artist":  # Reindex source node
            filtered_edge_index[0] = torch.tensor(
                [old_to_new_artist_idx[idx.item()] for idx in filtered_edge_index[0]],
                dtype=torch.long,
            )
        if edge_type[2] == "artist":  # Reindex destination node
            filtered_edge_index[1] = torch.tensor(
                [old_to_new_artist_idx[idx.item()] for idx in filtered_edge_index[1]],
                dtype=torch.long,
            )

        # Assign filtered edges to subgraph
        subdata[edge_type].edge_index = filtered_edge_index

        # Handle edge attributes if they exist
        if hasattr(data[edge_type], "edge_attr"):
            try:
                subdata[edge_type].edge_attr = data[edge_type].edge_attr[mask]
            except IndexError as e:
                print(f"IndexError for {edge_type}: {e}")
        else:
            print(f"No edge_attr for {edge_type}")

    # Nodes filtering
    subdata["artist"].x = data["artist"].x[selected_artist_ids]
    subdata["track"].x = data["track"].x
    subdata["tag"].x = data["tag"].x

    # Check the shape of the filtered nodes and edges
    for edge_type in edge_types:
        print(f"Edge type: {edge_type}, edge_index shape: {subdata[edge_type].edge_index.shape}")

    # Check the artist features (should only have the selected artists)
    print("Subgraph artist tensor shape:", subdata["artist"].x.shape)
    print("Subgraph track tensor shape:", subdata["track"].x.shape)
    print("Subgraph tag tensor shape:", subdata["tag"].x.shape)

    print("\n")

    # Validate the subgraph
    try:
        subdata.validate()
        print("Validation successful.")

        del data
        data = subdata
    except ValueError as e:
        print("Validation failed:", e)

edge_type: ('artist', 'collab_with', 'artist')
edge_type: ('artist', 'has_tag_artists', 'tag')
No edge_attr for ('artist', 'has_tag_artists', 'tag')
edge_type: ('track', 'has_tag_tracks', 'tag')
No edge_attr for ('track', 'has_tag_tracks', 'tag')
edge_type: ('artist', 'last_fm_match', 'artist')
edge_type: ('artist', 'linked_to', 'artist')
edge_type: ('artist', 'musically_related_to', 'artist')
edge_type: ('artist', 'personally_related_to', 'artist')
edge_type: ('tag', 'tags_artists', 'artist')
No edge_attr for ('tag', 'tags_artists', 'artist')
edge_type: ('tag', 'tags_track', 'track')
No edge_attr for ('tag', 'tags_track', 'track')
edge_type: ('track', 'worked_by', 'artist')
No edge_attr for ('track', 'worked_by', 'artist')
edge_type: ('artist', 'worked_in', 'track')
No edge_attr for ('artist', 'worked_in', 'track')
Edge type: ('artist', 'collab_with', 'artist'), edge_index shape: torch.Size([2, 629340])
Edge type: ('artist', 'has_tag_artists', 'tag'), edge_index shape: torch.Size([2, 

In [6]:
import torch_geometric.transforms as T

transform = T.RandomLinkSplit(
    num_val=0.25,
    num_test=0.25,
    disjoint_train_ratio=0.5,
    neg_sampling_ratio=1,
    add_negative_train_samples=False,
    edge_types=("artist", "collab_with", "artist"),
)

train_data, val_data, test_data = transform(data)

print("Training data:")
print("==============")
print(train_data)
print()
print("Validation data:")
print("================")
print(val_data)
print()
print("Test data:")
print("================")
print(test_data)

Training data:
HeteroData(
  artist={ x=[223388, 16] },
  track={ x=[24324100, 4] },
  tag={ x=[23, 1] },
  (artist, collab_with, artist)={
    edge_index=[2, 157335],
    edge_attr=[157335, 1],
    edge_label=[157335],
    edge_label_index=[2, 157335],
  },
  (artist, has_tag_artists, tag)={ edge_index=[2, 1042766] },
  (track, has_tag_tracks, tag)={ edge_index=[2, 4030735] },
  (artist, last_fm_match, artist)={
    edge_index=[2, 28357816],
    edge_attr=[28357816, 1],
  },
  (artist, linked_to, artist)={
    edge_index=[2, 1438],
    edge_attr=[1438, 1],
  },
  (artist, musically_related_to, artist)={
    edge_index=[2, 41760],
    edge_attr=[41760, 1],
  },
  (artist, personally_related_to, artist)={
    edge_index=[2, 3334],
    edge_attr=[3334, 1],
  },
  (tag, tags_artists, artist)={ edge_index=[2, 1042766] },
  (tag, tags_track, track)={ edge_index=[2, 4030735] },
  (track, worked_by, artist)={ edge_index=[2, 12509457] },
  (artist, worked_in, track)={ edge_index=[2, 12509457] 

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [8]:
from torch_geometric.loader import LinkNeighborLoader

edge_label_index = train_data["artist", "collab_with", "artist"].edge_label_index
edge_label = train_data["artist", "collab_with", "artist"].edge_label

print("Creating train_loader...")
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[25, 20],
    neg_sampling_ratio=1,
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=128,
    shuffle=True,
    num_workers=10,
    pin_memory=True,
)

edge_label_index = val_data["artist", "collab_with", "artist"].edge_label_index
edge_label = val_data["artist", "collab_with", "artist"].edge_label

print("Creating val_loader...")
val_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[25, 20],
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=512,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
)

edge_label_index = test_data["artist", "collab_with", "artist"].edge_label_index
edge_label = test_data["artist", "collab_with", "artist"].edge_label

print("Creating test_loader...")
test_loader = LinkNeighborLoader(
    data=test_data,
    num_neighbors=[25, 20],
    edge_label_index=(("artist", "collab_with", "artist"), edge_label_index),
    edge_label=edge_label,
    batch_size=512,
    shuffle=False,
    num_workers=10,
    pin_memory=True,
)

print("Sampling mini-batch...")

sampled_data = next(iter(train_loader))

print("Sampled mini-batch:")
print("===================")
print(sampled_data)

Creating train_loader...




Creating val_loader...
Creating test_loader...
Sampling mini-batch...
Sampled mini-batch:
HeteroData(
  artist={
    x=[124069, 16],
    n_id=[124069],
  },
  track={
    x=[201198, 4],
    n_id=[201198],
  },
  tag={
    x=[23, 1],
    n_id=[23],
  },
  (artist, collab_with, artist)={
    edge_index=[2, 33146],
    edge_attr=[33146, 1],
    edge_label=[256],
    edge_label_index=[2, 256],
    e_id=[33146],
    input_id=[128],
  },
  (artist, has_tag_artists, tag)={
    edge_index=[2, 460],
    e_id=[460],
  },
  (track, has_tag_tracks, tag)={
    edge_index=[2, 460],
    e_id=[460],
  },
  (artist, last_fm_match, artist)={
    edge_index=[2, 272397],
    edge_attr=[272397, 1],
    e_id=[272397],
  },
  (artist, linked_to, artist)={
    edge_index=[2, 206],
    edge_attr=[206, 1],
    e_id=[206],
  },
  (artist, musically_related_to, artist)={
    edge_index=[2, 5089],
    edge_attr=[5089, 1],
    e_id=[5089],
  },
  (artist, personally_related_to, artist)={
    edge_index=[2, 524],
  

In [None]:
debug = False
if debug:
    print(torch.unique(train_data['artist', 'collab_with', 'artist'].edge_label))
    print(torch.unique(next(iter(train_loader))["artist", "collab_with", "artist"].edge_label))
    print(torch.unique(val_data['artist', 'collab_with', 'artist'].edge_label))
    print(torch.unique(next(iter(val_loader))["artist", "collab_with", "artist"].edge_label))
    print(torch.unique(test_data['artist', 'collab_with', 'artist'].edge_label))
    print(torch.unique(next(iter(test_loader))["artist", "collab_with", "artist"].edge_label))


In [10]:
from torch_geometric.nn import HeteroConv, GATConv, SAGEConv
import torch.nn.functional as F

class GNN(torch.nn.Module):
    def __init__(self, metadata, out_channels):
        super().__init__()
        self.metadata = metadata
        self.out_channels = out_channels

        self.conv1 = HeteroConv({
            ("artist", "collab_with", "artist"): GATConv((-1, -1), out_channels),
            ("artist", "has_tag_artists", "tag"): SAGEConv((-1, -1), out_channels),
            ("artist", "last_fm_match", "artist"): GATConv((-1, -1), out_channels),
            ("track", "has_tag_tracks", "tag"): SAGEConv((-1, -1), out_channels),
            ("artist", "linked_to", "artist"): GATConv((-1, -1), out_channels),
            ("artist", "musically_related_to", "artist"): GATConv((-1, -1), out_channels),
            ("artist", "personally_related_to", "artist"): GATConv((-1, -1), out_channels),
            ("tag", "tags_artists", "artist"): SAGEConv((-1, -1), out_channels),
            ("tag", "tags_tracks", "track"): SAGEConv((-1, -1), out_channels),
            ("track", "worked_by", "artist"): SAGEConv((-1, -1), out_channels),
            ("artist", "worked_in", "track"): SAGEConv((-1, -1), out_channels),
        }, aggr="mean")

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv1(x_dict, edge_index_dict)
        norm_x_dict = {
            key: F.normalize(
                x,
                p=2,
                dim=-1
            )
            for key, x in x_dict.items()
        }
        return norm_x_dict

In [11]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
import tqdm
import numpy as np

def train(model, train_loader, val_loader, optimizer, criterion, device, num_epochs):
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        epoch_loss = 0.0
        
        for sampled_data in tqdm.tqdm(train_loader):
            # Move data to device
            sampled_data = sampled_data.to(device)
            
            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
            
            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

            src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
            dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
            
            # Compute the dot product between source and destination embeddings
            preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge
            
            # Compute loss
            loss = criterion(preds, edge_label.float())
            epoch_loss += loss.item()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # Average loss for the epoch
        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss:.4f}")
        
        # Validation metrics
        model.eval()  # Set model to evaluation mode
        all_labels = []
        all_probs = []
        val_loss = 0.0
        
        with torch.no_grad():  # Disable gradient computation for validation
            for sampled_data in tqdm.tqdm(val_loader):
                # Move data to device
                sampled_data = sampled_data.to(device)
                
                # Forward pass
                pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)
                
                # Get predictions and labels for the 'collab_with' edge type
                edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
                edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

                src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
                dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
                
                # Compute the dot product between source and destination embeddings
                preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge

                probs = torch.sigmoid(preds)  # Convert to probabilities

                loss = criterion(preds, edge_label.float())
                val_loss += loss.item()
                
                # Collect predictions, probabilities, and labels
                all_labels.append(edge_label.cpu())
                all_probs.append(probs.cpu())
        
        # Concatenate all predictions and labels
        all_labels = torch.cat(all_labels)
        all_probs = torch.cat(all_probs)
        val_loss /= len(val_loader)

        # Find threshold for predictions
        best_threshold = 0
        best_f1 = 0
        for threshold in np.arange(0.2, 0.81, 0.01):
            preds_binary = (all_probs > threshold).long()
            cm = confusion_matrix(all_labels, preds_binary)
            tp = cm[1, 1]
            fp = cm[0, 1]
            fn = cm[1, 0]
            tn = cm[0, 0]
            precision = 0 if tp == 0 else tp / (tp + fp)
            recall = 0 if tp == 0 else tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
            if f1 > best_f1:
                best_threshold = threshold
                best_f1 = f1
        print(f"Best threshold: {best_threshold}")
        all_preds = (all_probs > best_threshold).long()
        
        # Compute metrics
        cm = confusion_matrix(all_labels, all_preds)
        tp = cm[1, 1]
        fp = cm[0, 1]
        fn = cm[1, 0]
        tn = cm[0, 0]
        accuracy = (tp + tn) / (tp + fp + fn + tn)
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f1 = 2 * precision * recall / (precision + recall)
        roc_auc = roc_auc_score(all_labels, all_probs)
        
        # Print validation metrics
        print(f"Validation Metrics - Epoch {epoch+1}/{num_epochs}:")
        print(f"Loss:      {val_loss:.4f}")
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall:    {recall:.4f}")
        print(f"F1-score:  {f1:.4f}")
        print(f"ROC-AUC:   {roc_auc:.4f}")
        print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

    return best_threshold


In [12]:
def test_model(model, test_loader, criterion, device, threshold):
    model.eval()  # Set the model to evaluation mode
    all_preds = []
    all_labels = []
    all_probs = []
    test_loss = 0.0

    with torch.no_grad():  # Disable gradient computation
        for sampled_data in tqdm.tqdm(test_loader):
            # Move data to the device
            sampled_data = sampled_data.to(device)

            # Forward pass
            pred_dict = model(sampled_data.x_dict, sampled_data.edge_index_dict)

            # Get predictions and labels for the 'collab_with' edge type
            edge_label_index = sampled_data['artist', 'collab_with', 'artist'].edge_label_index
            edge_label = sampled_data['artist', 'collab_with', 'artist'].edge_label

            src_emb = pred_dict['artist'][edge_label_index[0]]  # Source node embeddings
            dst_emb = pred_dict['artist'][edge_label_index[1]]  # Destination node embeddings
            
            # Compute the dot product between source and destination embeddings
            preds = (src_emb * dst_emb).sum(dim=-1)  # Scalar for each edge
            probs = torch.sigmoid(preds)  # Convert logits to probabilities
            preds_binary = (probs > threshold).long()  # Convert probabilities to binary predictions

            # Compute loss
            loss = criterion(preds, edge_label.float())
            test_loss += loss.item()

            # Collect predictions and labels
            all_preds.append(preds_binary.cpu())
            all_labels.append(edge_label.cpu())
            all_probs.append(probs.cpu())

    # Concatenate all predictions and labels
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    all_probs = torch.cat(all_probs)

    # Compute metrics
    cm = confusion_matrix(all_labels, all_preds)
    tp = cm[1, 1]
    fp = cm[0, 1]
    fn = cm[1, 0]
    tn = cm[0, 0]
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    roc_auc = roc_auc_score(all_labels, all_probs)

    # Average test loss
    test_loss /= len(test_loader)

    print("Test Results:")
    print(f"Loss:      {test_loss:.4f}")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-score:  {f1:.4f}")
    print(f"ROC-AUC:   {roc_auc:.4f}")
    print(f"Confusion Matrix:\n{tp} {fn}\n{fp} {tn}")

In [13]:
model = GNN(metadata=train_data.metadata(), out_channels=64).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

best_threshold = train(
    model,
    train_loader,
    val_loader,
    optimizer,
    F.binary_cross_entropy_with_logits,
    device,
    20
)


100%|██████████| 1230/1230 [02:31<00:00,  8.13it/s]


Epoch 1/20, Training Loss: 0.6736


100%|██████████| 615/615 [02:20<00:00,  4.37it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.46000000000000024
Validation Metrics - Epoch 1/20:
Loss:      0.6551
Accuracy:  0.6701
Precision: 0.6082
Recall:    0.9562
F1-score:  0.7435
ROC-AUC:   0.6488
Confusion Matrix:
150446 6889
96928 60407


100%|██████████| 1230/1230 [02:35<00:00,  7.93it/s]


Epoch 2/20, Training Loss: 0.6134


100%|██████████| 615/615 [02:19<00:00,  4.41it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6200000000000003
Validation Metrics - Epoch 2/20:
Loss:      0.5964
Accuracy:  0.7838
Precision: 0.7432
Recall:    0.8674
F1-score:  0.8005
ROC-AUC:   0.8186
Confusion Matrix:
136469 20866
47154 110181


100%|██████████| 1230/1230 [02:31<00:00,  8.13it/s]


Epoch 3/20, Training Loss: 0.5998


100%|██████████| 615/615 [02:20<00:00,  4.38it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6300000000000003
Validation Metrics - Epoch 3/20:
Loss:      0.5855
Accuracy:  0.7885
Precision: 0.7524
Recall:    0.8602
F1-score:  0.8027
ROC-AUC:   0.8414
Confusion Matrix:
135333 22002
44543 112792


100%|██████████| 1230/1230 [02:31<00:00,  8.10it/s]


Epoch 4/20, Training Loss: 0.5925


100%|██████████| 615/615 [02:20<00:00,  4.38it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6400000000000003
Validation Metrics - Epoch 4/20:
Loss:      0.5876
Accuracy:  0.7927
Precision: 0.7614
Recall:    0.8524
F1-score:  0.8044
ROC-AUC:   0.8470
Confusion Matrix:
134115 23220
42019 115316


100%|██████████| 1230/1230 [02:31<00:00,  8.13it/s]


Epoch 5/20, Training Loss: 0.5904


100%|██████████| 615/615 [02:19<00:00,  4.40it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6400000000000003
Validation Metrics - Epoch 5/20:
Loss:      0.5835
Accuracy:  0.7992
Precision: 0.7699
Recall:    0.8536
F1-score:  0.8096
ROC-AUC:   0.8470
Confusion Matrix:
134295 23040
40145 117190


100%|██████████| 1230/1230 [02:31<00:00,  8.12it/s]


Epoch 6/20, Training Loss: 0.5890


100%|██████████| 615/615 [02:21<00:00,  4.35it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6100000000000003
Validation Metrics - Epoch 6/20:
Loss:      0.5803
Accuracy:  0.7925
Precision: 0.7533
Recall:    0.8699
F1-score:  0.8074
ROC-AUC:   0.8498
Confusion Matrix:
136858 20477
44824 112511


100%|██████████| 1230/1230 [02:32<00:00,  8.09it/s]


Epoch 7/20, Training Loss: 0.5876


100%|██████████| 615/615 [02:20<00:00,  4.38it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6500000000000004
Validation Metrics - Epoch 7/20:
Loss:      0.5821
Accuracy:  0.8002
Precision: 0.7670
Recall:    0.8624
F1-score:  0.8119
ROC-AUC:   0.8531
Confusion Matrix:
135693 21642
41225 116110


100%|██████████| 1230/1230 [02:32<00:00,  8.08it/s]


Epoch 8/20, Training Loss: 0.5886


100%|██████████| 615/615 [02:21<00:00,  4.35it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6300000000000003
Validation Metrics - Epoch 8/20:
Loss:      0.5807
Accuracy:  0.7975
Precision: 0.7631
Recall:    0.8627
F1-score:  0.8099
ROC-AUC:   0.8503
Confusion Matrix:
135736 21599
42136 115199


100%|██████████| 1230/1230 [02:31<00:00,  8.11it/s]


Epoch 9/20, Training Loss: 0.5879


100%|██████████| 615/615 [02:19<00:00,  4.40it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6500000000000004
Validation Metrics - Epoch 9/20:
Loss:      0.5820
Accuracy:  0.8060
Precision: 0.7788
Recall:    0.8548
F1-score:  0.8150
ROC-AUC:   0.8571
Confusion Matrix:
134487 22848
38198 119137


100%|██████████| 1230/1230 [02:33<00:00,  8.02it/s]


Epoch 10/20, Training Loss: 0.5870


100%|██████████| 615/615 [02:22<00:00,  4.32it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6400000000000003
Validation Metrics - Epoch 10/20:
Loss:      0.5813
Accuracy:  0.8050
Precision: 0.7716
Recall:    0.8665
F1-score:  0.8163
ROC-AUC:   0.8590
Confusion Matrix:
136333 21002
40357 116978


100%|██████████| 1230/1230 [02:30<00:00,  8.16it/s]


Epoch 11/20, Training Loss: 0.5857


100%|██████████| 615/615 [02:20<00:00,  4.38it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6200000000000003
Validation Metrics - Epoch 11/20:
Loss:      0.5783
Accuracy:  0.7999
Precision: 0.7612
Recall:    0.8739
F1-score:  0.8137
ROC-AUC:   0.8550
Confusion Matrix:
137495 19840
43124 114211


100%|██████████| 1230/1230 [02:32<00:00,  8.09it/s]


Epoch 12/20, Training Loss: 0.5862


100%|██████████| 615/615 [02:21<00:00,  4.35it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6400000000000003
Validation Metrics - Epoch 12/20:
Loss:      0.5787
Accuracy:  0.8057
Precision: 0.7783
Recall:    0.8548
F1-score:  0.8148
ROC-AUC:   0.8557
Confusion Matrix:
134493 22842
38304 119031


100%|██████████| 1230/1230 [02:31<00:00,  8.10it/s]


Epoch 13/20, Training Loss: 0.5855


100%|██████████| 615/615 [02:20<00:00,  4.39it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6400000000000003
Validation Metrics - Epoch 13/20:
Loss:      0.5801
Accuracy:  0.7981
Precision: 0.7769
Recall:    0.8364
F1-score:  0.8055
ROC-AUC:   0.8501
Confusion Matrix:
131592 25743
37790 119545


100%|██████████| 1230/1230 [02:32<00:00,  8.08it/s]


Epoch 14/20, Training Loss: 0.5853


100%|██████████| 615/615 [02:19<00:00,  4.41it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6300000000000003
Validation Metrics - Epoch 14/20:
Loss:      0.5760
Accuracy:  0.8022
Precision: 0.7750
Recall:    0.8518
F1-score:  0.8116
ROC-AUC:   0.8529
Confusion Matrix:
134021 23314
38914 118421


100%|██████████| 1230/1230 [02:31<00:00,  8.11it/s]


Epoch 15/20, Training Loss: 0.5838


100%|██████████| 615/615 [02:18<00:00,  4.45it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6200000000000003
Validation Metrics - Epoch 15/20:
Loss:      0.5762
Accuracy:  0.8013
Precision: 0.7722
Recall:    0.8548
F1-score:  0.8114
ROC-AUC:   0.8473
Confusion Matrix:
134487 22848
39673 117662


100%|██████████| 1230/1230 [02:31<00:00,  8.13it/s]


Epoch 16/20, Training Loss: 0.5841


100%|██████████| 615/615 [02:19<00:00,  4.41it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6300000000000003
Validation Metrics - Epoch 16/20:
Loss:      0.5756
Accuracy:  0.8046
Precision: 0.7779
Recall:    0.8528
F1-score:  0.8136
ROC-AUC:   0.8542
Confusion Matrix:
134170 23165
38310 119025


100%|██████████| 1230/1230 [02:32<00:00,  8.06it/s]


Epoch 17/20, Training Loss: 0.5843


100%|██████████| 615/615 [02:16<00:00,  4.50it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6100000000000003
Validation Metrics - Epoch 17/20:
Loss:      0.5762
Accuracy:  0.7983
Precision: 0.7583
Recall:    0.8756
F1-score:  0.8127
ROC-AUC:   0.8627
Confusion Matrix:
137762 19573
43908 113427


100%|██████████| 1230/1230 [02:28<00:00,  8.30it/s]


Epoch 18/20, Training Loss: 0.5834


100%|██████████| 615/615 [02:17<00:00,  4.48it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6300000000000003
Validation Metrics - Epoch 18/20:
Loss:      0.5785
Accuracy:  0.8050
Precision: 0.7686
Recall:    0.8727
F1-score:  0.8174
ROC-AUC:   0.8627
Confusion Matrix:
137310 20025
41339 115996


100%|██████████| 1230/1230 [02:30<00:00,  8.16it/s]


Epoch 19/20, Training Loss: 0.5832


100%|██████████| 615/615 [02:26<00:00,  4.19it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6500000000000004
Validation Metrics - Epoch 19/20:
Loss:      0.5793
Accuracy:  0.8051
Precision: 0.7833
Recall:    0.8435
F1-score:  0.8123
ROC-AUC:   0.8630
Confusion Matrix:
132716 24619
36714 120621


100%|██████████| 1230/1230 [02:38<00:00,  7.77it/s]


Epoch 20/20, Training Loss: 0.5830


100%|██████████| 615/615 [02:29<00:00,  4.12it/s]
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)
  precision = tp / (tp + fp)


Best threshold: 0.6500000000000004
Validation Metrics - Epoch 20/20:
Loss:      0.5748
Accuracy:  0.8045
Precision: 0.7849
Recall:    0.8388
F1-score:  0.8110
ROC-AUC:   0.8606
Confusion Matrix:
131969 25366
36163 121172


In [14]:
# best_threshold = train(
#     model,
#     train_loader,
#     val_loader,
#     optimizer,
#     F.binary_cross_entropy_with_logits,
#     device,
#     100
# )

In [20]:
test_model(
    model,
    test_loader,
    F.binary_cross_entropy_with_logits,
    device,
    best_threshold
)

100%|██████████| 615/615 [02:31<00:00,  4.06it/s]


Test Results:
Loss:      0.5729
Accuracy:  0.8108
Precision: 0.7808
Recall:    0.8643
F1-score:  0.8204
ROC-AUC:   0.8695
Confusion Matrix:
135986 21349
38187 119148


In [17]:
torch.save(model.state_dict(), "./normal.pth")

In [21]:
test = GNN(metadata=train_data.metadata(), out_channels=64).to(device)
test.load_state_dict(torch.load("./normal.pth"))
test_model(
    test,
    test_loader,
    F.binary_cross_entropy_with_logits,
    device,
    best_threshold
)

  test.load_state_dict(torch.load("./normal.pth"))
100%|██████████| 615/615 [02:32<00:00,  4.03it/s]


Test Results:
Loss:      0.5729
Accuracy:  0.8108
Precision: 0.7811
Recall:    0.8635
F1-score:  0.8202
ROC-AUC:   0.8697
Confusion Matrix:
135862 21473
38076 119259
