In [None]:
import rdflib
from rdflib import URIRef
from rdflib.namespace import OWL, RDF, RDFS,XSD, Namespace
import csv
from torch_geometric.data import HeteroData
import torch
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from torch_geometric.datasets import HGBDataset

# Building the dataset object takes care of downloading/processing the raw
# files when required, so the graph is fetched through the dataset itself
# rather than by calling .process() manually and re-reading the processed
# file from a hard-coded path.
freebase_dataset = HGBDataset(root='.', name="Freebase")
complete_data = freebase_dataset[0]  # first (index 0) graph of the dataset

# Give every node of every type the same constant 1-dimensional feature [1.].
for k in complete_data.node_types:
    complete_data[k].x = torch.ones(complete_data[k].num_nodes, 1)

In [None]:
#GRAPH PREPROCESSING: REMOVE SELF LOOPS AND ISOLATED NODES
from torch_geometric.utils import remove_self_loops

# 1) Self loops can only exist when source and destination share a node type.
for edge_type in list(complete_data.edge_index_dict.keys()):
    if edge_type[0] == edge_type[2]:
        complete_data[edge_type].edge_index = remove_self_loops(complete_data[edge_type].edge_index)[0]

# 2) Isolated nodes are decided per NODE TYPE, looking at every relation at
#    once: a node is kept as soon as it is an endpoint of at least one edge.
#    Relabeling each relation's edge_index on its own would mix up source and
#    destination id spaces and desynchronise the ids between relations.
node_counts = {node_type: complete_data[node_type].num_nodes for node_type in complete_data.node_types}
connected = {node_type: torch.zeros(count, dtype=torch.bool) for node_type, count in node_counts.items()}
for (src_type, _, dst_type), edge_index in complete_data.edge_index_dict.items():
    connected[src_type][edge_index[0]] = True
    connected[dst_type][edge_index[1]] = True

# old node id -> new contiguous node id (-1 marks a removed, isolated node)
new_ids = {}
for node_type, mask in connected.items():
    mapping = torch.full((node_counts[node_type],), -1, dtype=torch.long)
    mapping[mask] = torch.arange(int(mask.sum()))
    new_ids[node_type] = mapping

# Rewrite every relation with the new ids of its own source/destination types.
for (src_type, rel, dst_type), edge_index in complete_data.edge_index_dict.items():
    complete_data[(src_type, rel, dst_type)].edge_index = torch.stack(
        (new_ids[src_type][edge_index[0]], new_ids[dst_type][edge_index[1]])
    )

# Drop the rows of the isolated nodes from every per-node tensor of the type.
for node_type, mask in connected.items():
    if bool(mask.all()):
        continue  # nothing to remove for this node type
    store = complete_data[node_type]
    for key, value in list(store.items()):
        if key == 'num_nodes':
            store[key] = int(mask.sum())
        elif torch.is_tensor(value) and value.dim() > 0 and value.size(0) == node_counts[node_type]:
            store[key] = value[mask]

In [None]:
# Snapshot of the relation keys of the graph, kept as a list so that later
# cells can index and iterate them in one fixed order.
edge_types = [*complete_data.edge_index_dict]

In [None]:
#SAVE ROOT NODES IN A DICT SO THEY ARE NOT LOST DURING THE GNN COMPUTATION
# A "root" node type is one that never appears as the destination (third
# element) of any relation in edge_types; its feature tensor is stored here.
destination_types = {relation[2] for relation in edge_types}
root_nodes_types = {
    node_type: features
    for node_type, features in complete_data.x_dict.items()
    if node_type not in destination_types
}

In [None]:
#MAPPING FROM RELATION TYPE TO AN INT

# Each relation gets its position in edge_types as integer id.
rel_to_index = {relation: position for position, relation in enumerate(edge_types)}

In [None]:
from torch_geometric.nn import HeteroConv, GATConv, Linear

class HeteroGNN(torch.nn.Module):
    """Two-layer heterogeneous GAT encoder with a DistMult-style link scorer.

    Args:
        hidden_channels: output width of the first convolution layer.
        out_channels: output width of the second layer and of each relation vector.
        edge_types: iterable of relation keys; one GATConv per layer and one
            row of ``rel_weight`` is created for each of them.
        root_nodes_types: dict ``node type -> feature tensor`` that is written
            back into the node embeddings after every convolution layer.
    """

    def __init__(self, hidden_channels, out_channels, edge_types, root_nodes_types):
        super().__init__()
        # torch.manual_seed seeds the global RNG and returns its generator;
        # kept as in the original so that initialisation is unchanged.
        g = torch.manual_seed(0)
        # Keep the constructor arguments on the instance so that forward() does
        # not depend on notebook-level globals of the same name.
        self.edge_types = list(edge_types)
        self.root_nodes_types = root_nodes_types
        self.rel_to_index = {edge_t: i for i, edge_t in enumerate(self.edge_types)}
        self.conv1 = HeteroConv({edge_t: GATConv((1, 1), hidden_channels, add_self_loops=False) for edge_t in self.edge_types})
        self.conv2 = HeteroConv({edge_t: GATConv((hidden_channels, hidden_channels), out_channels, add_self_loops=False) for edge_t in self.edge_types})
        # One learnable vector per relation, used by the scorer in forward().
        self.rel_weight = torch.nn.Parameter(torch.randn(len(self.edge_types), out_channels, generator=g))

    def reset_parameters(self):
        """Reset both convolution layers (``rel_weight`` is left untouched)."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, x_dict, edge_index_dict, data):
        """Score the candidate links of every relation.

        Args:
            x_dict: dict ``node type -> node feature tensor``.
            edge_index_dict: dict ``relation -> edge index`` used for message passing.
            data: object indexable by relation whose entries expose
                ``edge_label_index`` (row 0: source ids, row 1: destination ids).

        Returns:
            dict ``relation -> tensor`` with one raw (un-squashed) score per
            candidate link of that relation.
        """
        x_dict = self.conv1(x_dict, edge_index_dict)
        # RE-ADD ROOT NODES THAT ARE DISCARDED BY DEFAULT
        # NOTE(review): the stored root features keep their original width,
        # while conv2 was built for hidden_channels-wide inputs -- confirm this
        # is intended whenever root_nodes_types is not empty.
        for t, v in self.root_nodes_types.items():
            x_dict[t] = v
        x_dict = {key: x.relu() for key, x in x_dict.items()}

        x_dict = self.conv2(x_dict, edge_index_dict)
        for t, v in self.root_nodes_types.items():
            x_dict[t] = v

        out = x_dict

        pred_dict = {}
        ### LINK PREDICTION ACTS HERE ###
        for edge_t in self.edge_types:
            label_index = data[edge_t].edge_label_index
            out_src = out[edge_t[0]][label_index[0]]  # embedding src nodes
            out_dst = out[edge_t[2]][label_index[1]]  # embedding dst nodes

            # LINK EMBEDDING: DistMult with randomly initialised relation
            # weights -- sum over the feature axis of src * relation * dst.
            relation_vector = self.rel_weight[self.rel_to_index[edge_t]]
            pred_dict[edge_t] = torch.sum(out_src * relation_vector * out_dst, dim=-1)
        return pred_dict

In [None]:
import random

# Single seed shared by every random number generator used in the notebook.
SEED = 0

# Use the GPU when one is visible, otherwise fall back to the CPU instead of
# naming a device that may not exist on this machine.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.cuda.empty_cache()

In [None]:
# Build the network (4 hidden channels, 2 output channels) over the relations
# and root node types collected above, then reset its convolution layers.
gnn_kwargs = dict(
    hidden_channels=4,
    out_channels=2,
    edge_types=edge_types,
    root_nodes_types=root_nodes_types,
)
model = HeteroGNN(**gnn_kwargs)
model.reset_parameters()

In [None]:
#SPLIT THE KG IN TRAIN AND TEST SET
from torch_geometric.transforms import RandomLinkSplit

# num_val=0.0 / num_test=0.25 for every relation in edge_types; no reverse
# relation is declared for any of them (one None per relation).
link_split = RandomLinkSplit(
    num_val=0.0,
    num_test=0.25,
    edge_types=edge_types,
    rev_edge_types=[None for _ in edge_types],
)
train_link, val_link, test_link = link_split(complete_data)

In [None]:
# One gradient-free forward pass over the training split (initialises any
# lazy modules before the optimizer is created).
with torch.no_grad():
    out = model(
        train_link.x_dict,
        train_link.edge_index_dict,
        train_link,
    )

In [None]:
#CHOOSE AN OPTIMIZER AND TRAINING HYPERPARAMETERS
weight_decay = 5e-4

# Adam is the optimizer in use. Alternatives kept for reference:
#   torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)
#   torch.optim.RMSprop(model.parameters(), lr=0.1, weight_decay=weight_decay)
#   torch.optim.Adagrad(model.parameters(), lr=0.01, weight_decay=weight_decay)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=0.01,
    weight_decay=weight_decay,
)

# The model outputs raw scores, so the loss applies the sigmoid itself.
criterion = torch.nn.BCEWithLogitsLoss()

In [None]:
def train_hetlinkpre():
    """Run one full-batch optimisation step on the training split.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_link`` and ``edge_types``. The scores and labels of all relations
    are concatenated (in ``edge_types`` order) and a single loss is computed
    over them.

    Returns:
        The loss of this step as a detached 0-dim tensor.
    """
    model.train()
    optimizer.zero_grad()  # Clear gradients.
    pred_dict = model(train_link.x_dict, train_link.edge_index_dict, train_link)  # Perform a single forward pass.

    # Concatenate once instead of growing a tensor relation by relation; this
    # also keeps the result on whatever device the predictions live on.
    preds = torch.cat([pred_dict[edge_t] for edge_t in edge_types], dim=-1)
    edge_labels = torch.cat(
        [train_link[edge_t].edge_label.type_as(pred_dict[edge_t]) for edge_t in edge_types],
        dim=-1,
    )

    # compute loss function based on all edge types
    loss = criterion(preds, edge_labels)
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss.detach()  # the caller only needs the value, not the graph

In [None]:
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score

def test_hetlinkpre(test_link,evaluate='linkpre'):
    """Evaluate the notebook-level ``model`` on a link split.

    Args:
        test_link: split to evaluate; passed to the model and read for the
            ``edge_label`` of every evaluated relation.
        evaluate: which relations to score --
            ``'linkpre'``: relations whose destination is not a property type,
            ``'propdetection'``: relations whose destination is a property type,
            ``'all'``: every relation in ``edge_types``.

    Returns:
        Average precision when ``evaluate == 'propdetection'``, AUROC otherwise.

    Raises:
        NotImplementedError: if ``evaluate`` is not one of the three modes.
        ValueError: if the chosen mode selects no relation at all.
    """
    if evaluate not in ['linkpre','propdetection','all']:
        #linkpre: link between entities
        #propdetection: link between an entity and a property
        #all: both
        raise NotImplementedError()
    model.eval()
    # Evaluation needs no gradients: skip building the autograd graph.
    with torch.no_grad():
        hs_dict = model(test_link.x_dict, test_link.edge_index_dict, test_link)  # one score per candidate link

    ### LINK PREDICTION ACTS HERE ###
    #evaluate distinctly entity-to-entity link prediction and entity-to-property (property detection)
    prop = ['String','Integer','Double','gYear','Date'] #add other property types if used
    rel_with_prop = [edge_t for edge_t in edge_types if edge_t[2] in prop]
    if evaluate == 'linkpre':
        edge_types_to_evaluate = [edge_t for edge_t in edge_types if edge_t not in rel_with_prop]
    elif evaluate == 'propdetection':
        edge_types_to_evaluate = rel_with_prop
    else:
        edge_types_to_evaluate = edge_types
    if not edge_types_to_evaluate:
        raise ValueError(f"no edge types to evaluate for evaluate={evaluate!r}")

    # Gather scores and labels in the same relation order, concatenating once.
    hs = torch.cat([hs_dict[edge_t] for edge_t in edge_types_to_evaluate], dim=-1)
    edge_labels = np.concatenate(
        [test_link[edge_t].edge_label.cpu().detach().numpy() for edge_t in edge_types_to_evaluate]
    )

    pred_cont = torch.sigmoid(hs).cpu().detach().numpy()

    # EVALUATION
    if evaluate=='propdetection':
        score = average_precision_score(edge_labels, pred_cont)
    else:
        score = roc_auc_score(edge_labels, pred_cont)

    return score

In [None]:
# Train for a fixed number of epochs, recording the 'linkpre' score on the
# training and test splits after every optimisation step.
num_epochs = 200
perf_train, perf_test = [], []
for epoch in range(num_epochs):
    loss = train_hetlinkpre()
    roc_train, roc_test = test_hetlinkpre(train_link), test_hetlinkpre(test_link)
    perf_train += [roc_train]
    perf_test += [roc_test]
    #print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, ROC train: {roc_train:.4f}, ROC test: {roc_test:.4f}')

In [None]:
# Final 'linkpre' scores of the trained model: training split first, then test.
roc_train, roc_test = (
    test_hetlinkpre(split, evaluate='linkpre') for split in (train_link, test_link)
)
print(f'Train AUROC: {roc_train:.4f}\nTest AUROC: {roc_test:.4f}')

In [None]:
import matplotlib.pyplot as plt

# Per-epoch score curves of the training (orange) and test (blue) splits.
x = range(num_epochs)
plt.clf()
curves = (
    (perf_train, 'orange', 'auroc_train'),
    (perf_test, 'blue', 'auroc_test'),
)
for series, colour, tag in curves:
    plt.plot(x, series, color=colour, label=tag)
plt.xlabel('Epoch')
plt.ylabel('AUROC-score')
plt.legend()
plt.ylim(top=1)  # scores cannot exceed 1
plt.grid()
plt.show()
#plt.clf()