In [1]:
from KGEkeras.models import ComplEx

In [2]:
from rdflib import Graph, URIRef, Literal

import numpy as np

In [3]:

from KGEkeras.models import DistMult, HolE, TransE, ComplEx, HAKE, ConvE, ModE, ConvR, ConvKB, RotatE, pRotatE
import numpy as np
import tensorflow as tf
from random import choice, choices
from collections import defaultdict

from tensorflow.keras.layers import Input 
from tqdm import tqdm
from tensorflow.keras.callbacks import Callback, EarlyStopping
from tensorflow.keras.losses import hinge, binary_crossentropy

from tensorflow.keras.models import Model

from KGEkeras.utils import load_kg, validate, loss_function_lookup, generate_negative, oversample_data

# Registry from the `embedding_model` config string to the KGEkeras class that
# build_model instantiates. NOTE: ModE and ConvKB are imported but not registered.
models = dict(
    DistMult=DistMult,
    TransE=TransE,
    HolE=HolE,
    ComplEx=ComplEx,
    ConvE=ConvE,
    ConvR=ConvR,
    HAKE=HAKE,
    RotatE=RotatE,
    pRotatE=pRotatE,
)

class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding positive triples together with sampled negatives.

    Each item is ``(X, None)``, where ``X = oversample_data(kgs=[positives, negatives])``.
    No target is returned because the model from ``build_model`` registers its
    loss via ``add_loss``.

    Parameters
    ----------
    kg : array-like, shape (n_triples, 3)
        Integer-encoded ``(subject, relation, object)`` triples. Must be non-empty.
    ns : int
        Forwarded to ``generate_negative`` as ``negative``.
    batch_size : int
        Positives per batch, clamped to ``len(kg)``. ``__len__`` floors, so
        triples that do not fill a whole batch are skipped in a given epoch.
    shuffle : bool
        Reshuffle the triple order at the end of every epoch.
    num_entities : int, optional
        Forwarded to ``generate_negative`` as ``N``. Defaults to the largest
        subject/object id in ``kg`` plus one. When training on a fold of a larger
        graph, pass the full graph's entity count instead.

    Raises
    ------
    ValueError
        If ``kg`` is empty or ``batch_size`` is smaller than 1.
    """

    def __init__(self, kg, ns=10, batch_size=32, shuffle=True, num_entities=None):
        super().__init__()
        if len(kg) == 0:
            raise ValueError('DataGenerator needs at least one triple')
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size!r}')

        self.kg = kg
        self._kg_array = np.asarray(kg)
        self.batch_size = min(batch_size, len(kg))
        self.ns = ns

        if num_entities is None:
            # A count of *distinct* ids is smaller than the largest id whenever the
            # ids are a sparse subset (e.g. a CV fold). Assuming generate_negative
            # draws corrupt ids from [0, N) — TODO confirm — that would make the
            # high ids unreachable as negatives. max id + 1 covers every id seen.
            num_entities = int(max(self._kg_array[:, 0].max(),
                                   self._kg_array[:, 2].max())) + 1
        self.num_e = num_entities

        self.shuffle = shuffle
        self.indices = list(range(len(kg)))  # kept for backward compatibility
        self.on_epoch_end()

    def __len__(self):
        return len(self.kg) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        batch = self.index[start:start + self.batch_size]
        return self.__get_data(batch)

    def on_epoch_end(self):
        self.index = np.arange(len(self.indices))
        if self.shuffle:
            np.random.shuffle(self.index)

    def __get_data(self, batch):
        positives = self._kg_array[batch]
        negatives = generate_negative(positives, N=self.num_e, negative=self.ns)
        X = oversample_data(kgs=[positives, negatives])
        return X, None

def build_model(hp):
    """Build and compile a two-input Keras model whose output is a ranking loss.

    Parameters
    ----------
    hp : dict
        Hyper-parameters. Required keys:
          - ``'dim'``
          - ``'embedding_model'`` (a key of ``models``)
          - ``'loss_function'`` (a name accepted by ``loss_function_lookup``)
        ``'margin'`` is optional; when missing or ``None`` it defaults to 1.
        The dict, plus ``e_dim``/``r_dim``/``name``, is forwarded as keyword
        arguments to the embedding-model constructor. ``hp`` itself is not
        modified (a shallow copy is used).

    Returns
    -------
    tf.keras.Model
        Model with inputs ``[triple, ftriple]``, each ``Input((3,))``. Its output
        is the loss tensor, which is also registered via ``add_loss``. The model
        is compiled with ``optimizer='adam'`` and ``loss=None``, so ``fit`` needs
        no targets.
    """
    params = hp.copy()
    params['e_dim'] = params['dim']
    params['r_dim'] = params['dim']
    params['name'] = 'embedding_model'

    # Explicit None check: `params['margin'] or 1` would silently turn a
    # deliberate margin of 0 into 1, and raised KeyError when 'margin' was absent.
    margin = params.get('margin')
    if margin is None:
        margin = 1

    embedding_model = models[params['embedding_model']](**params)

    triple = Input((3,))
    ftriple = Input((3,))
    # The same layer scores both inputs, so the two branches share weights.
    score = embedding_model(triple)
    fscore = embedding_model(ftriple)

    loss_function = loss_function_lookup(params['loss_function'])
    # NOTE(review): the trailing positional `1` is passed through unchanged;
    # its meaning is defined by loss_function_lookup — confirm.
    loss = loss_function(score, fscore, margin, 1)

    model = Model(inputs=[triple, ftriple], outputs=loss)
    model.add_loss(loss)
    model.compile(optimizer='adam', loss=None)
    return model

def pad(kg, bs):
    """Return a copy of ``kg`` padded with random existing rows to a multiple of ``bs``.

    The result's length is the smallest multiple of ``bs`` that is >= ``len(kg)``.
    Padding rows are drawn with replacement from the original rows. The input is
    not modified; use the return value.

    Raises
    ------
    ValueError
        If ``bs`` is smaller than 1.
    """
    if bs < 1:
        raise ValueError(f'bs must be a positive integer, got {bs!r}')
    padded = list(kg)
    # Number of rows still needed to reach the next multiple of bs (0 if aligned).
    shortfall = -len(padded) % bs
    if shortfall:
        padded.extend(choices(padded, k=shortfall))
    return padded
             

In [4]:
# Load the taxonomy graph and encode it as integer (subject, relation, object) triples.
g = Graph()
# `Graph.load` is gone in rdflib 6+; `parse` works across versions.
g.parse('../TERA_OUTPUT/ecotox_taxonomy.nt', format='nt')

# Drop literal-valued triples. Collect them first: removing from the graph
# while iterating over it mutates the store mid-iteration.
literal_triples = [t for t in g if isinstance(t[2], Literal)]
for t in literal_triples:
    g.remove(t)

# Keep only triples whose terms are all URIs, so every term has an id below.
# Blank nodes would otherwise raise KeyError during encoding.
uri_triples = [(s, p, o) for s, p, o in g
               if isinstance(s, URIRef) and isinstance(p, URIRef) and isinstance(o, URIRef)]

# Sorted so the id assignment (and therefore which row of a saved embedding
# belongs to which entity) is stable across kernel restarts. Iterating a set of
# strings is not, because str hashing is randomized per process.
entities = sorted({str(s) for s, _, _ in uri_triples} | {str(o) for _, _, o in uri_triples})
relations = sorted({str(p) for _, p, _ in uri_triples})

entity_mapping = {e: i for i, e in enumerate(entities)}
relation_mapping = {r: i for i, r in enumerate(relations)}

# Sorted as well, so downstream fold splits do not depend on graph iteration order.
train_triples = np.asarray(sorted((entity_mapping[str(s)],
                                   relation_mapping[str(p)],
                                   entity_mapping[str(o)]) for s, p, o in uri_triples))

In [5]:
from sklearn.model_selection import KFold

import os

# Training configuration.
bs = 2048
N_FOLDS = 5
SEED = 42
OUT_DIR = 'tmp'

config = {'num_entities':len(entities),
          'num_relations':len(relations),
          'dim':100,
          'embedding_model':'ComplEx',
          'loss_function':'pairwize_hinge',
          'margin':1}

os.makedirs(OUT_DIR, exist_ok=True)

kfold = KFold(N_FOLDS, shuffle=True, random_state=SEED)
for fold, (train_idx, test_idx) in enumerate(kfold.split(train_triples)):
    # Fresh Keras state per fold, so earlier folds' models can be released.
    tf.keras.backend.clear_session()
    # Seed per fold (after clearing the session) so each fold is reproducible.
    np.random.seed(SEED + fold)
    tf.random.set_seed(SEED + fold)

    model = build_model(config)

    model.fit(DataGenerator(train_triples[train_idx], batch_size=bs),
              validation_data=DataGenerator(train_triples[test_idx], batch_size=bs),
              # Without restore_best_weights, the weights saved below would come
              # from the last epoch, `patience` epochs past the best val_loss.
              callbacks=[EarlyStopping('val_loss', patience=5, restore_best_weights=True)],
              epochs=100,
              verbose=1)

    model.save_weights(os.path.join(OUT_DIR, f'model_weights_{fold}.tf'))
    # Locate the embedding layer by capability rather than a hard-coded position
    # in model.layers.
    embedding_layer = next(layer for layer in model.layers
                           if hasattr(layer, 'entity_embedding'))
    W = embedding_layer.entity_embedding.weights[0].numpy()
    np.save(os.path.join(OUT_DIR, f'model_entities_{fold}.npy'), W)
    

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epo