In [None]:
import os
# Set before importing TensorFlow so the restriction to CUDA device 0 is already
# in the environment when TensorFlow is loaded.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import tensorflow as tf
import pandas as pd
# Turn off soft placement (an op pinned to a device that cannot run it is an error
# rather than being silently moved) and log the device each op is placed on.
tf.config.set_soft_device_placement(False)
tf.debugging.set_log_device_placement(True)
import numpy as np
from ampligraph.datasets import load_fb15k_237, load_yago3_10
from ampligraph.evaluation.protocol import create_mappings, to_idx
import time
# The notebook is written against the TensorFlow 2.3.x API; fail fast otherwise.
assert(tf.__version__.startswith('2.3'))

In [None]:
from ampligraph.latent_features import EmbeddingModel

In [None]:
# Split the first physical GPU into two logical devices (memory_limit=8000 each),
# so that multi-device code paths can be exercised on a single card.
gpus = tf.config.experimental.list_physical_devices('GPU')

tf.config.experimental.set_virtual_device_configuration(
    gpus[0],
    [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=8000)
     for _ in range(2)])

logical_gpus = tf.config.experimental.list_logical_devices('GPU')
print(len(gpus), "Physical GPU,", len(logical_gpus), "Logical GPUs")

# Last expression of the cell: display the logical devices that were created.
logical_gpus

In [None]:
# Load FB15k-237 and translate its triples to integer indices.
full_dataset = load_fb15k_237()

# The relation/entity -> index mappings are built from the training split only
# and then applied to both the training and the test split.
rel_to_idx, ent_to_idx = create_mappings(full_dataset['train'])
train_dataset, test_dataset = (
    to_idx(full_dataset[split], ent_to_idx, rel_to_idx)
    for split in ('train', 'test')
)

In [None]:
# Collect every entity index that occurs as a subject (column 0) or an object
# (column 2) of a training triple, then shuffle them before bucketing.
# np.union1d does the de-duplication in vectorised NumPy (sorted output) instead of
# round-tripping through Python sets and lists.
unique_ents = np.random.permutation(np.union1d(train_dataset[:, 0], train_dataset[:, 2]))
unique_ents.shape

In [None]:
# create n buckets of nodes
num_buckets = 4

dataset_df = pd.DataFrame(train_dataset, columns=['s','p', 'o'])
p_triples_bool = dict()
p_triples = dict()

# store entities in buckets
bucketed_entities = dict()

# p_triples_multiple_buckets[i][j]: training triples of edge partition (i, j), j >= i
# p_ent_multiple_buckets[i][j]:     entities that occur in those triples
p_triples_multiple_buckets = dict()
p_ent_multiple_buckets = dict()
start_time = time.time()

# Split the shuffled entities into num_buckets nearly equal buckets.
# np.array_split spreads the remainder over the first buckets; slicing with a
# floor-divided bucket size would leave the last
# (len(unique_ents) % num_buckets) entities in no bucket at all, and every triple
# touching them would silently vanish from all partitions.
for i, bucket in enumerate(np.array_split(unique_ents, num_buckets)):
    bucketed_entities[i] = bucket

# Max entities in a bucket
max_entities = max(bucket.shape[0] for bucket in bucketed_entities.values())

# Membership masks depend only on the bucket, so compute each of them once
# instead of once per (i, j) pair.
subj_in_bucket = {b: dataset_df['s'].isin(bucketed_entities[b]).values
                  for b in range(num_buckets)}
obj_in_bucket = {b: dataset_df['o'].isin(bucketed_entities[b]).values
                 for b in range(num_buckets)}

total_partitions = 0
# partition the edges based on bucketed entities
# if you have 2 buckets (0, 1) then you will have 3 partitions i.e.
#   a. edges with sub and obj in bucket 0 (0 - 0)
#   b. edges with sub and obj in bucket 1 (1 - 1)
#   c. edges with sub in 0 and obj in  1 (0 - 1)
for i in range(num_buckets):
    p_triples_multiple_buckets[i] = dict()
    p_ent_multiple_buckets[i] = dict()
    for j in range(i, num_buckets):
        # a triple belongs to partition (i, j) when its two endpoints fall in
        # buckets i and j, in either direction
        partition_mask = np.logical_or(
            np.logical_and(subj_in_bucket[i], obj_in_bucket[j]),
            np.logical_and(subj_in_bucket[j], obj_in_bucket[i]))
        partition_triples = train_dataset[partition_mask, :]

        p_triples_multiple_buckets[i][j] = partition_triples
        p_ent_multiple_buckets[i][j] = np.union1d(partition_triples[:, 0],
                                                  partition_triples[:, 2])

        print('{} -> {} : {} triples, {} entities'.format(i, j, p_triples_multiple_buckets[i][j].shape,
                                                         p_ent_multiple_buckets[i][j].shape))
        total_partitions +=1

end_time = time.time()

print('Time Taken: {} secs'.format(end_time - start_time) )
print('Total node partitions:', num_buckets)
print('Total edge partitions:', total_partitions)

In [None]:
# this would go into separate classes, like we had in ampligraph 1 (for loss functions and initializers)
# initializer
def xavier(in_shape, out_shape):
    """Sample an ``(in_shape, out_shape)`` float32 matrix from a zero-mean normal.

    The standard deviation is ``sqrt(2 / (in_shape + out_shape))``. Values are drawn
    from NumPy's global random state.
    """
    fan_total = in_shape + out_shape
    scale = np.sqrt(2 / fan_total)
    weights = np.random.normal(loc=0, scale=scale, size=(in_shape, out_shape))
    return weights.astype(np.float32)

#loss function
# NOTE(review): this device scope only wraps the function *definition*; confirm
# whether it was meant to wrap the ops created when nll is called.
with tf.device('GPU:0'):
    def nll(scores_pred, eta):
        """Negative log-likelihood of each positive against its ``eta`` negatives.

        Parameters
        ----------
        scores_pred : sequence
            ``scores_pred[0]`` holds the scores of the positive triples and
            ``scores_pred[1]`` the scores of the negatives. The negatives are
            reshaped to ``[eta, number of positives]``.
        eta : int
            Number of negatives per positive.

        Returns
        -------
        Scalar tensor: ``-sum(log(exp(pos) / (sum_neg exp(neg) + exp(pos))))``.
        """
        scores_pos = scores_pred[0]
        scores_neg = scores_pred[1]

        scores_neg_reshaped = tf.reshape(scores_neg, [eta, tf.shape(scores_pos)[0]])

        # With S = sum_neg exp(neg):
        #   -log(exp(pos) / (S + exp(pos))) = log(1 + exp(log S - pos))
        #                                   = softplus(logsumexp(neg) - pos)
        # Same value as exponentiating, dividing and taking the log, but it does
        # not overflow to inf/NaN for large scores or underflow to log(0).
        neg_lse = tf.math.reduce_logsumexp(scores_neg_reshaped, axis=0)
        loss = tf.reduce_sum(tf.math.softplus(neg_lse - scores_pos))
        return loss

In [None]:
# this class is responsible for training the embeddings on multiple GPUs
class DistributedTrainer():
    """Trains an embedding model over the edge partitions, one partition at a time.

    The full entity/relation embedding matrices (held by the partition manager) and
    the per-row optimizer state are kept as NumPy arrays. For every partition the
    rows it needs are loaded into the model and the optimizer, trained on, and
    written back before the next partition is loaded.
    """

    # this class manages the partitions
    # This would later be responsible for persisting the input data and creating the partitions
    # during training time, it will load the partitions and related embeddings
    class PartitionManager():
        """Owns the full embedding matrices and iterates over the edge partitions.

        The partitions themselves are read from the notebook-level globals
        ``p_ent_multiple_buckets``, ``p_triples_multiple_buckets`` and ``rel_to_idx``.
        """

        def __init__(self, num_buckets, k, num_entities, num_rels, num_devices=1):
            """
            Parameters
            ----------
            num_buckets : int
                Number of entity buckets.
            k : int
                Embedding size.
            num_entities, num_rels : int
                Row counts of the full entity / relation embedding matrices.
            num_devices : int
                Requested device count; only honoured when there is more than
                one partition.
            """
            self.num_buckets = num_buckets

            # n buckets give 1 + 2 + ... + n edge partitions (all pairs i <= j)
            if num_buckets > 1:
                self.num_partitions = np.sum(np.arange(1, num_buckets+1))
            else:
                self.num_partitions = num_buckets

            # needs to go into the database
            self.entity_embeddings = xavier(num_entities, k)
            self.rel_embeddings = xavier(num_rels, k)

            self.num_devices = 1

            # use multiple GPUs only if we do multiple partitions
            if self.num_partitions > 1:
                self.num_devices = num_devices

        def get_next_partition(self):
            """Yield the data of one edge partition at a time.

            Yields
            ------
            tuple
                ``(partition_dict, new_rel_dict, remapped_triples, ent_embs, rel_embs)``:
                the global-entity-index -> partition-row mapping, the identity
                relation mapping, the partition triples re-indexed through those
                mappings, and the matching rows of the full embedding matrices.
            """
            num_rels = len(rel_to_idx)
            for i in range(len(p_ent_multiple_buckets)):
                # the inner dict of bucket i is keyed i .. num_buckets-1, hence i+j
                for j in range(len(p_ent_multiple_buckets[i])):
                    partition_ents = p_ent_multiple_buckets[i][i+j]
                    partition_dict = dict(zip(partition_ents,
                                              np.arange(partition_ents.shape[0])))
                    # relations are not partitioned: identity mapping over all of them
                    new_rel_dict = dict(zip(np.arange(num_rels), np.arange(num_rels)))
                    # remap the triples to reflect the position in the embedding matrix
                    remapped_triples = to_idx(p_triples_multiple_buckets[i][i+j],
                                              partition_dict, new_rel_dict)
                    yield partition_dict, \
                            new_rel_dict, \
                            remapped_triples, \
                            self.entity_embeddings[list(partition_dict.keys()), :], \
                            self.rel_embeddings[list(new_rel_dict.keys()), :]

    def __init__(self, batch_size, max_ent_size, k, num_ents, num_rels, num_buckets=1, num_devices=1):
        """
        Parameters
        ----------
        batch_size : int
            Number of triples per training batch.
        max_ent_size : int
            Max number of entity embeddings that can be loaded in memory at once.
        k : int
            Embedding size.
        num_ents, num_rels : int
            Total number of entities / relations.
        num_buckets : int
            Number of entity buckets.
        num_devices : int
            Number of GPUs to create a model/optimizer for.
        """
        # max_ent_size - is the max embeddings that can be loaded in memory
        self.partition_manager = self.PartitionManager(num_buckets, k, num_ents, num_rels, num_devices)
        self.eta = 5
        self.models = None # would be a list later (one per device)
        self.optimizers = None # would be a list later (one per device)
        self.max_ent_size = max_ent_size
        self.k = k
        self.batch_size = batch_size
        self.num_devices = num_devices
        self.num_ents = num_ents
        self.num_rels = num_rels

        # create embedding model and optimizer
        # NOTE(review): every iteration overwrites self.optimizers / self.models, so
        # only the objects created for the last device are kept.
        for i in range(num_devices):
            with tf.device('GPU:{}'.format(i)):
                # `learning_rate` is the current name of the deprecated `lr` argument
                self.optimizers = tf.optimizers.Adam(learning_rate=0.01)
                self.models = EmbeddingModel(eta=self.eta,
                                             k=k,
                                             max_ent_size=max_ent_size,
                                             max_rel_size=self.num_rels)

        # this is the hyperparams of the optimizer - would be later moved
        # to database calls and retrieved from partition manager
        # layout: [row, 0, :] and [row, 1, :] hold the two optimizer slots of a row
        self.optimizer_hyperparams_ent = np.zeros(shape=(self.num_ents, 2, k),
                                                  dtype=np.float32)
        self.optimizer_hyperparams_rel = np.zeros(shape=(self.num_rels, 2, k),
                                                  dtype=np.float32)

    @staticmethod
    def _pad_rows(values, num_rows):
        """Zero-pad a 2-D array at the bottom so that it has ``num_rows`` rows."""
        return np.pad(values,
                      ((0, num_rows - values.shape[0]), (0, 0)),
                      'constant',
                      constant_values=(0))

    def train_dataset_generator(self, dataset):
        """Yield consecutive int32 batches of at most ``batch_size`` rows.

        The last batch may be smaller; no empty batch is produced. (Counting the
        batches as ``rows // batch_size + 1`` would emit a zero-row batch whenever
        the row count is a multiple of ``batch_size``, and ``train`` divides the
        loss by the batch size.)
        """
        for start in range(0, dataset.shape[0], self.batch_size):
            yield dataset[start: start + self.batch_size, :].astype(np.int32)

    def update_partion_embeddings_after_train(self):
        """Write the trained rows of the current partition back to the full arrays."""
        # before changing the partition, save the trained embeddings and optimizer params
        ent_ids = list(self.ent_dict.keys())
        rel_ids = list(self.rel_dict.keys())
        num_ent, num_rel = len(ent_ids), len(rel_ids)
        encoder = self.models.encoding_layer

        # only the first len(dict) rows of the model variables belong to this partition
        self.partition_manager.entity_embeddings[ent_ids, :] = encoder.ent_emb.numpy()[:num_ent, :]
        self.partition_manager.rel_embeddings[rel_ids, :] = encoder.rel_emb.numpy()[:num_rel, :]

        opt_weights = self.optimizers.get_weights()
        # empty until the optimizer has applied gradients at least once
        if len(opt_weights) > 0:
            # assumes the optimizer weight order is
            # [iterations, ent slot 0, rel slot 0, ent slot 1, rel slot 1] - TODO confirm
            self.optimizer_hyperparams_rel[rel_ids, :, :] = np.stack(
                [opt_weights[2][:num_rel], opt_weights[4][:num_rel]], axis=1)
            self.optimizer_hyperparams_ent[ent_ids, :, :] = np.stack(
                [opt_weights[1][:num_ent], opt_weights[3][:num_ent]], axis=1)

    def change_partition(self):
        """Load the next partition into the model and the optimizer.

        Raises
        ------
        StopIteration
            When the partition iterator is exhausted.
        """
        # load a new partition and update the trainable params and optimizer hyperparams
        self.ent_dict, self.rel_dict, remapped_triples, ent_embs, rel_embs = next(self.partition_iterator)
        print('partition has {} triples'.format(remapped_triples.shape[0]))
        self.partition_dataset_iterator = iter(self.train_dataset_generator(remapped_triples))
        self.models.partition_change_updates(len(self.ent_dict), ent_embs, rel_embs)

        # Restore the saved optimizer state of the rows in this partition as soon as
        # the optimizer has weights (same guard as in
        # update_partion_embeddings_after_train). Gating this on the epoch number
        # would leave the previous partition's moments in place - applied to
        # unrelated rows - for every partition change of the first epochs.
        if len(self.optimizers.get_weights()) > 0:
            ent_ids = list(self.ent_dict.keys())
            rel_ids = list(self.rel_dict.keys())

            # pad to the row counts the slot arrays are restored with
            ent_slot0 = self._pad_rows(self.optimizer_hyperparams_ent[ent_ids, 0, :], self.max_ent_size)
            ent_slot1 = self._pad_rows(self.optimizer_hyperparams_ent[ent_ids, 1, :], self.max_ent_size)
            rel_slot0 = self._pad_rows(self.optimizer_hyperparams_rel[rel_ids, 0, :], self.num_rels)
            rel_slot1 = self._pad_rows(self.optimizer_hyperparams_rel[rel_ids, 1, :], self.num_rels)

            self.optimizers.set_weights([self.optimizers.iterations.numpy(),
                                         ent_slot0,
                                         rel_slot0,
                                         ent_slot1,
                                         rel_slot1
                                        ])

    def get_next_batch(self):
        """Generator over all batches of all partitions, in partition order.

        Every call starts again from the first partition. When a partition runs out
        of batches its trained parameters are saved and the next partition is loaded.
        """
        try:
            self.partition_iterator = iter(self.partition_manager.get_next_partition())
            # get new partition
            self.change_partition()
            while True:
                try:
                    # get batches from the current partition
                    out = next(self.partition_dataset_iterator)
                    yield out
                except StopIteration:
                    # if no more batch data - save the trained params and load next partition
                    self.update_partion_embeddings_after_train()
                    self.change_partition()
        except StopIteration:
            # raised by change_partition when there are no more partitions: end
            return

    @tf.function()
    def train_step(self, inputs, optimizer):
        """Run one optimisation step on a batch and return its (unnormalised) loss."""
        ent_emb = self.models.encoding_layer.ent_emb
        rel_emb = self.models.encoding_layer.rel_emb
        with tf.GradientTape() as tape:
            # get the model predictions
            # NOTE(review): training=0 is passed in the training step - confirm
            # against EmbeddingModel's call signature.
            preds = self.models(inputs, training=0)
            # compute the loss
            loss = nll(preds, self.eta)
            # regularizer - will be in a separate class like ampligraph 1
            # (sum of |x|^3 over both embedding matrices, weight 0.0001)
            loss += (0.0001 * (tf.reduce_sum(tf.pow(tf.abs(ent_emb), 3)) + \
                              tf.reduce_sum(tf.pow(tf.abs(rel_emb), 3))))

        # compute the grads
        gradients = tape.gradient(loss, [ent_emb, rel_emb])
        # update the trainable params
        optimizer.apply_gradients(zip(gradients, [ent_emb, rel_emb]))
        return loss

    def train(self, epochs = 100):
        """Train for ``epochs`` passes over all partitions, printing the mean loss per epoch."""
        dataset = tf.data.Dataset.from_generator(self.get_next_batch,
                                             output_types=(tf.int32),
                                             output_shapes=((None, 3)))
        dataset = dataset.prefetch(0)

        for i in range(epochs):
            total_loss = []
            print(i)
            self.global_epoch = i

            for j, inputs in dataset.enumerate():
                self.global_batch = 0
                # NOTE(review): the device is hard-coded, regardless of num_devices
                with tf.device('GPU:0'):
                    loss = self.train_step(inputs, self.optimizers)

                # per-triple loss of the batch
                total_loss.append(loss/inputs.shape[0])

            print('\n\n\n\nloss------------------{}:{}'.format(i, np.mean(total_loss)))
        print('done')

In [None]:
# Size the trainer from the index mappings; batch size and the in-memory entity
# limit are passed by keyword so the two numbers cannot be confused.
num_ents, num_rels = len(ent_to_idx), len(rel_to_idx)
dist_trainer = DistributedTrainer(
    batch_size=30000,
    max_ent_size=7000,
    k=300,
    num_ents=num_ents,
    num_rels=num_rels,
    num_buckets=num_buckets,
    num_devices=1,
)

In [None]:
# Wall-clock time, in seconds, of the complete training run.
start = time.time()
dist_trainer.train()
end = time.time()

elapsed_secs = end - start
print(elapsed_secs)

In [None]:
# 1 partition
# loss 0.05109425261616707
# 65 sec

# 10 partitions
