In [1]:
import rdflib
import networkx as nx
import torch
import time
import random
import matplotlib.pyplot as plt
from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import numpy as np
from rdflib.util import guess_format
from sklearn.preprocessing import label_binarize
from hyperparameter_tuning import weight_tuning
from torch_geometric.utils import from_networkx
from sklearn.metrics import precision_score, f1_score, roc_auc_score
from sampling_techniques import random_node_sampling, node_type_sampling, edge_type_sampling, node_edge_type_sampling, \
    degree_based_sampling, degree_centrality_sampling, pagerank_sampling, node_type_pagerank_sampling, edge_type_pagerank_sampling, \
    node_type_degree_based_sampling

In [2]:
def get_node_labels(G):
    """Return one integer class label per node, in the iteration order of ``G.nodes``.

    The label is looked up from each node's ``node_type`` attribute:
    ``'uri'`` -> 0, ``'literal'`` -> 1, ``'predicate'`` -> 2. Any other
    value maps to -1. A node without a ``node_type`` attribute raises ``KeyError``.
    """
    type_to_label = {'uri': 0, 'literal': 1, 'predicate': 2}
    return [
        type_to_label.get(attributes['node_type'], -1)
        for _, attributes in G.nodes(data=True)
    ]

In [3]:
def create_masks(data, train_nodes, val_nodes, test_nodes):
    """Attach boolean ``train_mask``, ``val_mask`` and ``test_mask`` tensors to ``data``.

    Each mask has ``data.num_nodes`` entries and is True exactly at the node
    indices listed for that split. ``data`` is modified in place; nothing is returned.
    """
    splits = (
        ('train_mask', train_nodes),
        ('val_mask', val_nodes),
        ('test_mask', test_nodes),
    )
    for mask_name, node_indices in splits:
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[node_indices] = True
        setattr(data, mask_name, mask)

In [4]:
def rdf_to_data(rdf_file, sampling_technique, type_weights):
    """Parse an RDF file and turn it into a PyTorch Geometric data object with split masks.

    Every triple (s, p, o) contributes three nodes (subject, predicate, object)
    and two directed edges, s -> p and p -> o. Nodes carry a ``node_type``
    attribute that ``get_node_labels`` converts into the class labels ``data.y``.

    Edges are split 40% / 30% / 30% into train / validation / test. The nodes
    touched by each edge subset define the corresponding mask. The training
    mask is then narrowed to the nodes kept by ``sampling_technique``.

    Args:
        rdf_file: Path of the RDF file; the serialization format is guessed
            from the file name with ``rdflib.util.guess_format``.
        sampling_technique: Callable invoked as
            ``sampling_technique(train_graph, num_samples)``. It must return an
            object whose ``.nodes`` yields node identifiers of ``train_graph``.
        type_weights: Currently unused; kept so existing callers keep working.

    Returns:
        The object produced by ``from_networkx`` with ``x`` (a constant
        one-dimensional feature), ``y``, ``train_mask``, ``val_mask`` and
        ``test_mask`` set.
    """
    # Parsing the RDF file
    rdf_graph = rdflib.Graph()
    rdf_graph.parse(rdf_file, format=guess_format(rdf_file))

    # Creating a NetworkX graph from the RDF triples
    G = nx.MultiDiGraph()
    for s, p, o in rdf_graph:
        # Adding subject and object nodes to the graph
        G.add_node(str(s), node_type='uri')
        if isinstance(o, rdflib.URIRef):
            G.add_node(str(o), node_type='uri')
        else:
            # Every object that is not a URIRef is tagged 'literal'
            G.add_node(str(o), node_type='literal')

        # Adding predicate nodes to the graph
        G.add_node(str(p), node_type='predicate')

        # Adding edges connecting subject, predicate, and object nodes
        G.add_edge(str(s), str(p), key=str(o), label=str(p), edge_type='subj_pred')
        G.add_edge(str(p), str(o), key=str(o), label=str(p), edge_type='pred_obj')

    # Converting the graph to Pytorch Geometric data
    data = from_networkx(G)
    data.x = torch.ones((data.num_nodes, 1))
    data.y = torch.tensor(get_node_labels(G), dtype=torch.long)

    # The node-to-index mapping must follow G.nodes() order: that is the order
    # used by from_networkx for edge_index and by get_node_labels for data.y.
    # Any other numbering makes the masks select different nodes than the
    # labels describe.
    node_to_index = {node: index for index, node in enumerate(G.nodes())}

    # Splitting edges into training, validation and test sets.
    # Keyed edges (u, v, key) are unique in a MultiDiGraph, so shuffling once
    # and slicing gives three disjoint subsets whose sizes always add up.
    edges = list(G.edges(keys=True))
    num_edges = len(edges)
    num_train = int(0.4 * num_edges)
    num_val = int(0.3 * num_edges)

    shuffled_edges = random.sample(edges, num_edges)
    train_edges = shuffled_edges[:num_train]
    val_edges = shuffled_edges[num_train:num_train + num_val]
    test_edges = shuffled_edges[num_train + num_val:]

    def _endpoints(edge_list):
        # Set of nodes touched by the edges; the edge key is not a node.
        return {node for u, v, _key in edge_list for node in (u, v)}

    # Creating masks for training, validation, and test sets.
    # NOTE(review): the split is over edges, so a node touched by edges from
    # different subsets is flagged in more than one mask.
    train_nodes = list(_endpoints(train_edges))
    create_masks(
        data,
        [node_to_index[n] for n in train_nodes],
        [node_to_index[n] for n in _endpoints(val_edges)],
        [node_to_index[n] for n in _endpoints(test_edges)],
    )

    # Applying the sampling technique on the training nodes
    train_G = G.subgraph(train_nodes)
    num_samples = int(0.3 * train_G.number_of_nodes())

    G_sampled = sampling_technique(train_G, num_samples)

    # Updating the training mask so that it keeps only the sampled nodes
    train_nodes_sampled = [node_to_index[node] for node in G_sampled.nodes]
    data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.train_mask[train_nodes_sampled] = True

    return data

In [5]:
class GNN(torch.nn.Module):
    """Two-layer graph convolutional network for node classification.

    The first ``GCNConv`` layer maps ``num_features`` inputs to 16 hidden
    channels, the second maps those 16 channels to ``num_classes`` outputs.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        """Return per-node log-probabilities for ``data`` (reads ``data.x`` and ``data.edge_index``)."""
        hidden = F.relu(self.conv1(data.x, data.edge_index))
        # Dropout is only active while the module is in training mode
        hidden = F.dropout(hidden, training=self.training)
        scores = self.conv2(hidden, data.edge_index)
        return F.log_softmax(scores, dim=1)

In [None]:
# Sampling strategies compared in this experiment, in the order they are plotted
sampling_techniques = [
    random_node_sampling,
    node_type_sampling,
    edge_type_sampling,
    node_edge_type_sampling,
    degree_based_sampling,
    degree_centrality_sampling,
    pagerank_sampling,
    node_type_pagerank_sampling,
    edge_type_pagerank_sampling,
    node_type_degree_based_sampling,
]

# One entry per sampling technique, appended in loop order
accuracy_results = []
precision_results = []
f1_results = []
roc_results = []

best_weights = []

# Seed applied before every run so each technique starts from the same random state
SEED = 42

rdf_file = "datasets/aifb.nt"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for sampling_technique in sampling_techniques:

    # # Code for running the Parameter Tuning
    # if sampling_technique == node_type_sampling or sampling_technique == edge_type_sampling or sampling_technique == pagerank_sampling:
    #     best_weights = weight_tuning(rdf_to_data, GNN, rdf_file, sampling_technique)
    #     print(best_weights)

    # Re-seeding all random number generators used in this notebook
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    start_time = time.time()
    data_sampled = rdf_to_data(rdf_file, sampling_technique, None)

    # Retrieving the number of features and classes
    num_features = data_sampled.x.shape[1]
    num_classes = len(torch.unique(data_sampled.y))

    # Initializing the model and optimizer on the selected device
    model = GNN(num_features=num_features, num_classes=num_classes).to(device)
    data_sampled = data_sampled.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    # Using CrossEntropyLoss for multi-class classification.
    # GNN.forward already returns log-probabilities; the log-softmax applied
    # inside CrossEntropyLoss leaves them unchanged, so this is redundant but harmless.
    loss_fn = torch.nn.CrossEntropyLoss()

    # Training settings
    model.train()

    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    sampled_losses = []

    # Training loop
    for epoch in range(1000):
        optimizer.zero_grad()
        out = model(data_sampled)

        # Calculating loss only for nodes in the training set
        loss = loss_fn(out[data_sampled.train_mask], data_sampled.y[data_sampled.train_mask])
        loss.backward()
        optimizer.step()

        # Recording the loss before the early-stopping check so that the
        # final epoch also appears in the loss curve
        sampled_losses.append(loss.item())

        # Evaluating on the validation set
        model.eval()
        with torch.no_grad():
            val_out = model(data_sampled)
            val_loss = loss_fn(val_out[data_sampled.val_mask], data_sampled.y[data_sampled.val_mask]).item()

        # If validation loss has not improved, increment patience counter
        if val_loss >= best_val_loss:
            patience_counter += 1
        else:  # Else, save this model and reset patience counter
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
            patience_counter = 0

        # If patience counter has reached limit, stop training
        if patience_counter >= patience:
            print("Early stopping")
            break

        model.train()

    # Loading the best model
    model.load_state_dict(torch.load('best_model.pth'))

    # Final evaluation needs no gradients
    model.eval()
    with torch.no_grad():
        out = model(data_sampled)
    _, pred = out.max(dim=1)

    # Converting the outputs (log-probabilities) to probabilities using softmax
    probs = torch.nn.functional.softmax(out, dim=1)

    # Converting the PyTorch tensors to numpy arrays for sklearn
    test_mask = data_sampled.test_mask
    y_true = data_sampled.y[test_mask].cpu().numpy()
    y_pred = pred[test_mask].cpu().numpy()
    y_prob = probs.detach().cpu().numpy()
    y_prob_test = y_prob[test_mask.cpu().numpy()]

    # Performing one-hot encoding of the true labels.
    # The classes come from the full label set, not only from the test labels,
    # so the indicator columns line up with the model's probability columns
    # even if a class is missing from the test split.
    class_labels = torch.unique(data_sampled.y).cpu().numpy()
    y_true_bin = label_binarize(y_true, classes=class_labels)

    # Computing accuracy, precision, F1-score and ROC AUC score
    correct = float(pred[test_mask].eq(data_sampled.y[test_mask]).sum().item())
    accuracy = correct / test_mask.sum().item()
    precision = precision_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    roc_auc = roc_auc_score(y_true_bin, y_prob_test, multi_class='ovr', average='macro')

    accuracy_results.append(accuracy)
    precision_results.append(precision)
    f1_results.append(f1)
    roc_results.append(roc_auc)

    end_time = time.time()
    total_time = end_time - start_time

    print('Accuracy: {:.4f}'.format(accuracy))
    print('Precision: {:.4f}'.format(precision))
    print('F1-score: {:.4f}'.format(f1))
    print('ROC: {:.4f}'.format(roc_auc))
    print('Total time: {:.4f} seconds'.format(total_time))

    # Plotting the losses
    plt.figure()
    plt.plot(range(1, len(sampled_losses) + 1), sampled_losses, label='Sampled Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title(f'Losses for {sampling_technique.__name__}')
    plt.legend()
    plt.show()

# Summary figure: one line per metric, one x position per sampling technique
fig, ax = plt.subplots(figsize=(10, 7))

sampling_techniques_names = [func.__name__ for func in sampling_techniques]

# (legend label, values collected in the experiment loop), drawn in this order
metric_series = [
    ('Accuracy', accuracy_results),
    ('Precision', precision_results),
    ('F1-score', f1_results),
    ('ROC', roc_results),
]
for metric_label, metric_values in metric_series:
    ax.plot(sampling_techniques_names, metric_values, label=metric_label, marker='o')

ax.set_xlabel('Sampling Techniques')
ax.set_ylabel('Performance Metrics')
ax.set_title('Performance comparison of different sampling techniques')
ax.legend()

# All metrics lie in [0, 1]; rotated tick labels keep the long function names readable
ax.set_ylim([0, 1])
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()