In [1]:
import locale
import os

import torch
import torch.nn.functional as F
from torch_geometric.datasets import CitationFull
from torch_geometric.datasets import Reddit

import graphsage_calculate_embeddings
import test_embeddings


# Download dataset

In [2]:
# Load the Reddit dataset using the current directory as its root folder,
# then take the first (index 0) graph object from it.
dataset = Reddit(root="./")
data = dataset[0]

# Hyperparameters

In [6]:
# --- Optimisation settings ---
learning_rate = 1e-4
epochs = 10
batch_size = 512
dropout_rate = 0.4

# --- Model settings ---
aggregator = 'MeanAggregation'
activation_function = F.relu
hidden_layer = 512
embedding_dimension = 128
bias = True
normalization = True
project = False

# --- Neighbourhood settings ---
# NOTE(review): presumably the neighbour sample sizes for the first and second
# hop; confirm against graphsage_calculate_embeddings.
neighborhood_1 = 25
neighborhood_2 = 10

# Obtain embedding matrix

In [7]:
# Record the graph dimensions: feature count as reported by the data object,
# node count taken from the first axis of the feature matrix.
number_features = data.num_features
number_nodes = data.x.shape[0]

# Replace `data` with its sorted version (sort_by_row=False).
data = data.sort(sort_by_row=False)

In [None]:
# Compute the embedding matrix with the hyperparameters from the
# configuration cell. Every tunable value is forwarded by name so that editing
# the configuration cell is enough to change a run; `bias` is now taken from
# the configuration as well instead of being hard-coded here.
embedding_matrix = graphsage_calculate_embeddings.compute_embedding_matrix(
    data=data,
    number_features=number_features,
    number_nodes=number_nodes,
    batch_size=batch_size,
    hidden_layer=hidden_layer,
    epochs=epochs,
    neighborhood_1=neighborhood_1,
    neighborhood_2=neighborhood_2,
    embedding_dimension=embedding_dimension,
    learning_rate=learning_rate,
    dropout_rate=dropout_rate,
    activation_function=activation_function,
    aggregator=aggregator,
    # No configuration variable exists for this flag, so it stays fixed here.
    activation_before_normalization=True,
    bias=bias,
    normalize=normalization,
    project=project,
)


# Save embedding

In [None]:
# Persist the embedding matrix to disk.
file_name = 'embeddings/reddit.pt'

# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing (no-op when it is already there).
os.makedirs(os.path.dirname(file_name), exist_ok=True)
torch.save(embedding_matrix, file_name)

# Testing results

# Node classification

In [None]:
# Run the multi-class node-classification test on the embeddings, using
# data.y as the labels, and unpack the three returned scores.
node_classification_scores = test_embeddings.test_node_classification_multi_class(
    embedding_matrix, data.y
)
acc, f1_macro, f1_micro = node_classification_scores

In [None]:
def format_percentage(value):
    """Return `value * 100` with four decimals and a decimal comma.

    Uses plain string formatting instead of `locale.setlocale`, so it does not
    depend on a 'de_DE' locale being installed on the machine and does not
    change the process-wide locale.
    """
    return f"{value * 100:.4f}".replace('.', ',')


# Format the scores as percentages with four digits after the decimal comma.
formatted_acc = format_percentage(acc)
formatted_f1_macro = format_percentage(f1_macro)
formatted_f1_micro = format_percentage(f1_micro)

print(f"Accuracy: {formatted_acc}, F1_macro: {formatted_f1_macro}, F1_micro: {formatted_f1_micro}")

Accuracy: 44,0762, F1_macro: 25,5765, F1_micro: 44,0762


# Link Prediction

In [None]:
# Split the graph into train and test parts for link prediction.
# TODO: change the is_undirected depending on graph
train_data, test_data = test_embeddings.train_test_split_graph(data=data, is_undirected=True)

# Prepare edges: the label index is transposed after conversion to NumPy.
# .detach().cpu() makes the conversion work even if the tensors live on a GPU
# or carry gradients; it is a no-op for plain CPU tensors.
test_edges = test_data.edge_label_index.detach().cpu().numpy().T
y_true = test_data.edge_label.detach().cpu().numpy()

# Prepare embeddings: detach from the autograd graph and move to the CPU
# before converting, since .numpy() raises on CUDA tensors.
embedding_detached = embedding_matrix.detach().cpu()
embedding_np = embedding_detached.numpy()

In [None]:
# Score the embeddings on link prediction with k=5 cross-validation over the
# held-out edges and their labels.
roc_auc_score = test_embeddings.k_fold_cross_validation_link_prediction(
    embedding_np,
    test_edges,
    y_true,
    k=5,
)


In [None]:
# Express the ROC AUC as a percentage with four decimals, then swap the
# decimal point for a comma before printing.
roc_auc_percent = roc_auc_score * 100
formatted_score = f"{roc_auc_percent:.4f}".replace('.', ',')
print("ROC AUC Score:", formatted_score)