In [12]:
import locale
import os

import torch
import torch.nn.functional as F
from torch_geometric.datasets import PPI

import graphsage_calculate_embeddings
import test_embeddings

# Read in data

In [13]:
# Load the PPI dataset, using the current directory as its root folder,
# and work with the first graph it contains.
dataset = PPI(root="./")
data = dataset[0]

# Hyperparameters

In [14]:
# Reproducibility: seed PyTorch's global RNG so repeated runs start from the
# same random state.
# NOTE(review): randomness drawn from other generators (e.g. inside the
# imported helper modules) is not covered by this seed -- confirm if needed.
random_seed = 42
torch.manual_seed(random_seed)

# --- Optimisation settings ---
learning_rate = 0.0001
epochs = 10
batch_size = 512
dropout_rate = 0.4

# --- Model settings ---
aggregator = 'MeanAggregation'
activation_function = F.relu
normalization = True
bias = True
project = False
hidden_layer = 512
embedding_dimension = 128

# --- Neighbourhood settings ---
# presumably the per-hop neighbour sample sizes; verify in
# graphsage_calculate_embeddings
neighborhood_1 = 25
neighborhood_2 = 10

# Obtain embedding matrix

In [15]:
# Record the feature dimensionality and the number of nodes (rows of data.x),
# then replace `data` with its sorted version.
# NOTE(review): sort_by_row=False presumably orders edges by their second
# (column) index -- confirm against the PyG documentation.
number_features = data.num_features
number_nodes = data.x.shape[0]
data = data.sort(sort_by_row=False)

In [16]:
# Train the model and obtain the embedding matrix.
# Every tunable value is taken from the hyperparameter cell; `bias` is now
# forwarded from there as well instead of being hard-coded, so changing it in
# the hyperparameter cell actually takes effect.
embedding_matrix = graphsage_calculate_embeddings.compute_embedding_matrix(
    data = data,
    number_features = number_features,
    number_nodes = number_nodes,
    batch_size = batch_size,
    hidden_layer = hidden_layer, 
    epochs = epochs, 
    neighborhood_1 = neighborhood_1,
    neighborhood_2 = neighborhood_2,
    embedding_dimension = embedding_dimension,
    learning_rate = learning_rate,
    dropout_rate = dropout_rate,
    activation_function = activation_function,
    aggregator = aggregator,
    activation_before_normalization = True,  # fixed for this experiment (no hyperparameter defined for it)
    bias = bias,
    normalize = normalization, 
    project = project
)


Training Progress:   9%|▉         | 1/11 [00:03<00:30,  3.09s/it]

Epoch: 000, Total loss: 11.8686, time_taken: 3.0946149826049805


Training Progress:  18%|█▊        | 2/11 [00:06<00:28,  3.16s/it]

Epoch: 001, Total loss: 11.5195, time_taken: 3.2128729820251465


Training Progress:  27%|██▋       | 3/11 [00:09<00:25,  3.14s/it]

Epoch: 002, Total loss: 11.0202, time_taken: 3.1034491062164307


Training Progress:  36%|███▋      | 4/11 [00:12<00:21,  3.09s/it]

Epoch: 003, Total loss: 10.8675, time_taken: 3.028134822845459


Training Progress:  45%|████▌     | 5/11 [00:15<00:18,  3.08s/it]

Epoch: 004, Total loss: 10.9041, time_taken: 3.0403189659118652


Training Progress:  55%|█████▍    | 6/11 [00:18<00:15,  3.09s/it]

Epoch: 005, Total loss: 10.8273, time_taken: 3.126552104949951


Training Progress:  64%|██████▎   | 7/11 [00:21<00:12,  3.10s/it]

Epoch: 006, Total loss: 10.8377, time_taken: 3.1052498817443848


Training Progress:  73%|███████▎  | 8/11 [00:24<00:09,  3.07s/it]

Epoch: 007, Total loss: 10.8221, time_taken: 3.010427713394165


Training Progress:  82%|████████▏ | 9/11 [00:27<00:06,  3.11s/it]

Epoch: 008, Total loss: 10.8230, time_taken: 3.2019190788269043


Training Progress:  91%|█████████ | 10/11 [00:31<00:03,  3.11s/it]

Epoch: 009, Total loss: 10.8259, time_taken: 3.097565174102783


Training Progress: 100%|██████████| 11/11 [00:34<00:00,  3.10s/it]

Epoch: 010, Total loss: 10.7990, time_taken: 3.070646047592163
Median time per epoch: 3.0976s





# Save embedding matrix

In [17]:
# Persist the embedding matrix so later sessions can reuse it without retraining.
# torch.save does not create missing parent directories, so make sure the
# target folder exists first (a no-op when it is already there).
os.makedirs('embeddings', exist_ok=True)
torch.save(embedding_matrix, 'embeddings/ppi.pt')

In [None]:
# Optional: restore the previously saved embedding matrix from disk
# (overwrites the in-memory `embedding_matrix`).
saved_embeddings_path = 'embeddings/ppi.pt'
embedding_matrix = torch.load(saved_embeddings_path)

# Evaluate node classification 

In [18]:
# Evaluate the embeddings on node classification against the labels in data.y;
# the helper's result is unpacked into accuracy, macro-F1 and micro-F1.
node_classification_scores = test_embeddings.test_node_classification_multi_class(
    embedding_matrix,
    data.y,
)
acc, f1_macro, f1_micro = node_classification_scores

In [19]:
# Report the node-classification metrics as percentages with four decimals and
# a decimal comma.
# Plain string formatting is used instead of locale.setlocale(LC_ALL, 'de_DE'):
# it produces the same text without changing process-wide locale state and
# without requiring the 'de_DE' locale to be installed on the machine.
def format_percent(score):
    """Return `score` scaled to percent with 4 decimals and ',' as decimal mark."""
    return "{:.4f}".format(score * 100).replace('.', ',')


formatted_acc = format_percent(acc)
formatted_f1_macro = format_percent(f1_macro)
formatted_f1_micro = format_percent(f1_micro)

# NOTE(review): an accuracy of exactly 0 next to non-zero F1 scores would be
# consistent with exact-match accuracy on multi-label targets -- confirm how
# test_embeddings computes `acc` before reporting it.
print(f"Accuracy: {formatted_acc}, F1_macro: {formatted_f1_macro}, F1_micro: {formatted_f1_micro}")

Accuracy: 0,0000, F1_macro: 11,8861, F1_micro: 41,0880


# Link Prediction

In [20]:
# Split the graph into a training part and a test part for link prediction.
train_data, test_data = test_embeddings.train_test_split_graph(data = data, is_undirected = False) # TODO: change the is_undirected depending on graph

# Prepare edges
# .cpu() makes the NumPy conversion work even if the tensors live on a GPU
# (Tensor.numpy() only accepts CPU tensors); it is a no-op for CPU tensors.
# presumably edge_label_index is (2, num_edges), so .T yields one edge per row
# -- TODO confirm
test_edges = test_data.edge_label_index.cpu().numpy().T
y_true = test_data.edge_label.cpu().numpy()

# Prepare embeddings
# detach() drops any autograd history before the conversion to NumPy.
embedding_detached = embedding_matrix.detach().cpu()
embedding_np = embedding_detached.numpy()

In [22]:
# Run the link-prediction evaluation helper with k=5 on the prepared test edges;
# its result is reported as a ROC AUC score in the next cell.
roc_auc_score = test_embeddings.k_fold_cross_validation_link_prediction(
    embedding_np,
    test_edges,
    y_true,
    k=5,
)


In [23]:
# Show the ROC AUC as a percentage with four decimals and a decimal comma.
roc_auc_percent = roc_auc_score * 100
formatted_score = format(roc_auc_percent, ".4f").replace('.', ',')
print("ROC AUC Score:", formatted_score)

ROC AUC Score: 88,3325
