In [1]:
from itertools import product
import pandas as pd
import tensorflow as tf
import numpy as np
from tqdm import tqdm_notebook
import gc

import ampligraph.datasets
from ampligraph.datasets import load_wn18, load_wn18rr, load_fb15k, load_fb15k_237, load_yago3_10
from ampligraph.latent_features import ComplEx, DistMult, TransE, HolE
from ampligraph.evaluation import evaluate_performance, mr_score, mrr_score, hits_at_n_score, generate_corruptions_for_eval
from ampligraph.latent_features import *
from ampligraph.datasets.numpy_adapter import NumpyDatasetAdapter

In [2]:
%env CUDA_VISIBLE_DEVICES=1

env: CUDA_VISIBLE_DEVICES=1


In [3]:
import json
with open("AmpliGraph/experiments/config.json", "r") as f:
    config = json.load(f)

In [4]:
import numpy as np
import tensorflow as tf
from sklearn.utils import check_random_state
import abc
from tqdm import tqdm
import logging
from ampligraph.latent_features.loss_functions import LOSS_REGISTRY
from ampligraph.latent_features.regularizers import REGULARIZER_REGISTRY
from ampligraph.latent_features.optimizers import OPTIMIZER_REGISTRY, SGDOptimizer
from ampligraph.latent_features.initializers import INITIALIZER_REGISTRY, DEFAULT_XAVIER_IS_UNIFORM
from ampligraph.evaluation import generate_corruptions_for_fit, to_idx, generate_corruptions_for_eval, \
    hits_at_n_score, mrr_score
from ampligraph.datasets import AmpligraphDatasetAdapter, NumpyDatasetAdapter
from functools import partial
from ampligraph.latent_features import constants as constants
from ampligraph.latent_features.constants import *

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


class Ensemble(object):
    def __init__(self, mode, models, epochs, batches_count, eta, 
                 loss, loss_params={}, 
                 embedding_model_params={},
                 optimizer='adam', optimizer_params={}, 
                 verbose=True):
        self.eval_dataset_handle = None
        self.mode = mode
        self.models = models
        self.ent_to_idx = models[0].ent_to_idx
        self.rel_to_idx = models[0].rel_to_idx
        assert all(e.ent_to_idx == self.ent_to_idx for e in models)
        assert all(e.rel_to_idx == self.rel_to_idx for e in models)
        self.verbose = verbose
        
        self.is_filtered = False
        self.is_fitted = True
        self.seed = 0
        
        self.batches_count = batches_count
        self.eta = eta
        self.epochs = epochs
        
        self.embedding_model_params = embedding_model_params
        self.loss_params = loss_params

        try:
            self.loss = LOSS_REGISTRY[loss](self.eta, self.loss_params, verbose=verbose)
        except KeyError:
            msg = 'Unsupported loss function: {}'.format(loss)
            logger.error(msg)
            raise ValueError(msg)

        self.optimizer_params = optimizer_params

        try:
            self.optimizer = OPTIMIZER_REGISTRY[optimizer](self.optimizer_params, 
                                                           self.batches_count, 
                                                            verbose)
        except KeyError:
            msg = 'Unsupported optimizer: {}'.format(optimizer)
            logger.error(msg)
            raise ValueError(msg)
        
        self.tf_config = tf.ConfigProto(allow_soft_placement=True)
        self.tf_config.gpu_options.allow_growth = True
        self.sess_train = None
        self.sess_predict = None

    def _save_trained_params(self):
        self.trained_model_params = self.sess_train.run(self.ensemble)

    def _load_model_from_trained_params(self):
        self.ensemble = tf.constant(self.trained_model_params)
        
    def _initialize_parameters(self):
        self.ensemble = tf.get_variable('ensemble', #shape=[len(self.models)],
                                        initializer=np.ones(len(self.models), dtype=np.float32))#tf.contrib.layers.xavier_initializer(uniform=False, seed=self.seed))
          
    def _lookup_embeddings(self, i, x):
        e_s = tf.nn.embedding_lookup(self.ent_emb[i], x[:, 0])
        e_p = tf.nn.embedding_lookup(self.rel_emb[i], x[:, 1], name='embedding_lookup_predicate')
        e_o = tf.nn.embedding_lookup(self.ent_emb[i], x[:, 2])
        return e_s, e_p, e_o
    
    def set_filter_for_eval(self):
        """Configures to use filter
        """
        self.is_filtered = True

    def configure_evaluation_protocol(self, config=None):
        if config is None:
            config = {'corruption_entities': DEFAULT_CORRUPTION_ENTITIES,
                      'corrupt_side': DEFAULT_CORRUPT_SIDE_EVAL,
                      'default_protocol': DEFAULT_PROTOCOL_EVAL}
        self.eval_config = config
        if self.eval_config['default_protocol']:
            self.eval_config['corrupt_side'] = 's+o'

    def test_retrieve(self, mode):
        if self.is_filtered:
            test_generator = partial(self.eval_dataset_handle.get_next_batch_with_filter,
                                     batch_size=1, dataset_type=mode)
        else:
            test_generator = partial(self.eval_dataset_handle.get_next_eval_batch, batch_size=1, dataset_type=mode)
            
        batch_iterator = iter(test_generator())
        indices_obj = np.empty(shape=(0, 1), dtype=np.int32)
        indices_sub = np.empty(shape=(0, 1), dtype=np.int32)
        for i in range(self.eval_dataset_handle.get_size(mode)):
            if self.is_filtered:
                out, indices_obj, indices_sub = next(batch_iterator)
            else:
                out = next(batch_iterator)
  
            yield out, indices_obj, indices_sub
            
            
    def _initialize_eval_graph(self, mode="test"):
        """Initialize the evaluation graph. 
        
        Use prime number based filtering strategy (refer set_filter_for_eval()), if the filter is set
        """

        # Use a data generator which returns a test triple along with the subjects and objects indices for filtering
        # The last two data are used if the graph is large. They are the embeddings of the entities that must be 
        # loaded on the GPU before scoring and the indices of those embeddings. 
        dataset = tf.data.Dataset.from_generator(partial(self.test_retrieve, mode=mode),
                                                 output_types=(tf.int32, tf.int32, tf.int32),
                                                 output_shapes=((1, 3), (None, 1), (None, 1))) 
        dataset = dataset.repeat()
        dataset = dataset.prefetch(1)
        dataset_iter = dataset.make_one_shot_iterator()
        self.X_test_tf, indices_obj, indices_sub = dataset_iter.get_next()

        use_default_protocol = self.eval_config.get('default_protocol', DEFAULT_PROTOCOL_EVAL)
        corrupt_side = self.eval_config.get('corrupt_side', DEFAULT_CORRUPT_SIDE_EVAL)
        # Dependencies that need to be run before scoring
        test_dependency = []
        # For large graphs

        # Rather than generating corruptions in batches do it at once on the GPU for small or medium sized graphs
        all_entities_np = np.arange(len(self.ent_to_idx))

        corruption_entities = self.eval_config.get('corruption_entities', DEFAULT_CORRUPTION_ENTITIES)

        if corruption_entities == 'all':
            corruption_entities = all_entities_np
        elif isinstance(corruption_entities, np.ndarray):
            corruption_entities = corruption_entities
        else:
            msg = 'Invalid type for corruption entities.'
            logger.error(msg)
            raise ValueError(msg)

        # Entities that must be used while generating corruptions
        self.corruption_entities_tf = tf.constant(corruption_entities, dtype=tf.int32)

        corrupt_side = self.eval_config.get('corrupt_side', DEFAULT_CORRUPT_SIDE_EVAL)
        # Generate corruptions
        self.out_corr = generate_corruptions_for_eval(self.X_test_tf,
                                                      self.corruption_entities_tf,
                                                      corrupt_side)

        scores_predict = []
        score_positive = []
        
        if self.mode == 'calibration':
            for i, m in enumerate(self.models):
                w = tf.Variable(m.calibration_parameters[0], dtype=tf.float32, trainable=False)
                b = tf.Variable(m.calibration_parameters[1], dtype=tf.float32, trainable=False)
                # Compute scores for negatives
                e_s, e_p, e_o = self._lookup_embeddings(i, self.out_corr)
                scores_predict.append(tf.sigmoid(-(w*m._fn(e_s, e_p, e_o)+b)))
                # Compute scores for positive
                e_s, e_p, e_o = self._lookup_embeddings(i, self.X_test_tf)
                score_positive.append(tf.sigmoid(-(w*m._fn(e_s, e_p, e_o)+b)))

            dummy = tf.ones(shape=tf.shape(tf.reshape(tf.exp(self.ensemble), (-1, 1))))
            scores_predict = tf.squeeze(tf.matmul(tf.stack(scores_predict, axis=1), dummy))
            score_positive = tf.squeeze(tf.matmul(tf.stack(score_positive, axis=1), dummy))
        elif self.mode == 'expit':
            for i, m in enumerate(self.models):
                # Compute scores for negatives
                e_s, e_p, e_o = self._lookup_embeddings(i, self.out_corr)
                scores_predict.append(tf.sigmoid(m._fn(e_s, e_p, e_o)))
                # Compute scores for positive
                e_s, e_p, e_o = self._lookup_embeddings(i, self.X_test_tf)
                score_positive.append(tf.sigmoid(m._fn(e_s, e_p, e_o)))

            dummy = tf.ones(shape=tf.shape(tf.reshape(tf.exp(self.ensemble), (-1, 1))))
            scores_predict = tf.squeeze(tf.matmul(tf.stack(scores_predict, axis=1), dummy))
            score_positive = tf.squeeze(tf.matmul(tf.stack(score_positive, axis=1), dummy))
        elif self.mode == 'mean':
            for i, m in enumerate(self.models):
                # Compute scores for negatives
                e_s, e_p, e_o = self._lookup_embeddings(i, self.out_corr)
                scores_predict.append(m._fn(e_s, e_p, e_o))
                # Compute scores for positive
                e_s, e_p, e_o = self._lookup_embeddings(i, self.X_test_tf)
                score_positive.append(m._fn(e_s, e_p, e_o))
                
            dummy = tf.ones(shape=tf.shape(tf.reshape(tf.exp(self.ensemble), (-1, 1))))
            scores_predict = tf.squeeze(tf.matmul(tf.stack(scores_predict, axis=1), dummy))
            score_positive = tf.squeeze(tf.matmul(tf.stack(score_positive, axis=1), dummy))
        elif self.mode == 'linear':
            for i, m in enumerate(self.models):
                # Compute scores for negatives
                e_s, e_p, e_o = self._lookup_embeddings(i, self.out_corr)
                scores_predict.append(m._fn(e_s, e_p, e_o))
                # Compute scores for positive
                e_s, e_p, e_o = self._lookup_embeddings(i, self.X_test_tf)
                score_positive.append(m._fn(e_s, e_p, e_o))
                
            dummy = tf.reshape(tf.exp(self.ensemble), (-1, 1))
            scores_predict = tf.squeeze(tf.matmul(tf.stack(scores_predict, axis=1), dummy))
            score_positive = tf.squeeze(tf.matmul(tf.stack(score_positive, axis=1), dummy))
        else:
            raise ValueError("Unknown mode: {}".format(self.mode))
            
        use_default_protocol = self.eval_config.get('default_protocol', DEFAULT_PROTOCOL_EVAL)

        if use_default_protocol:
            obj_corruption_scores = tf.slice(scores_predict,
                                             [0],
                                             [tf.shape(scores_predict)[0] // 2])

            subj_corruption_scores = tf.slice(scores_predict,
                                              [tf.shape(scores_predict)[0] // 2],
                                              [tf.shape(scores_predict)[0] // 2])

        # this is to remove the positives from corruptions - while ranking with filter
        positives_among_obj_corruptions_ranked_higher = tf.constant(0, dtype=tf.int32)
        positives_among_sub_corruptions_ranked_higher = tf.constant(0, dtype=tf.int32)
        
        if self.is_filtered:
            # If a list of specified entities were used for corruption generation
            if isinstance(self.eval_config.get('corruption_entities',
                                               DEFAULT_CORRUPTION_ENTITIES), np.ndarray):
                corruption_entities = self.eval_config.get('corruption_entities',
                                                           DEFAULT_CORRUPTION_ENTITIES).astype(np.int32)
                if corruption_entities.ndim == 1:
                    corruption_entities = np.expand_dims(corruption_entities, 1)
                # If the specified key is not present then it would return the length of corruption_entities
                corruption_mapping = tf.contrib.lookup.MutableDenseHashTable(key_dtype=tf.int32,
                                                                             value_dtype=tf.int32,
                                                                             default_value=len(corruption_entities),
                                                                             empty_key=-2,
                                                                             deleted_key=-1)

                insert_lookup_op = corruption_mapping.insert(corruption_entities, 
                                                             tf.reshape(tf.range(tf.shape(corruption_entities)[0],
                                                                                 dtype=tf.int32), (-1, 1)))

                with tf.control_dependencies([insert_lookup_op]):
                    # remap the indices of objects to the smaller set of corruptions
                    indices_obj = corruption_mapping.lookup(indices_obj)
                    # mask out the invalid indices (i.e. the entities that were not in corruption list
                    indices_obj = tf.boolean_mask(indices_obj, indices_obj < len(corruption_entities))
                    # remap the indices of subject to the smaller set of corruptions
                    indices_sub = corruption_mapping.lookup(indices_sub)
                    # mask out the invalid indices (i.e. the entities that were not in corruption list
                    indices_sub = tf.boolean_mask(indices_sub, indices_sub < len(corruption_entities))

            # get the scores of positives present in corruptions
            if use_default_protocol:
                scores_pos_obj = tf.gather(obj_corruption_scores, indices_obj) 
                scores_pos_sub = tf.gather(subj_corruption_scores, indices_sub)
            else:
                scores_pos_obj = tf.gather(scores_predict, indices_obj) 
                if corrupt_side == 's+o':
                    scores_pos_sub = tf.gather(scores_predict, indices_sub + len(corruption_entities))
                else:
                    scores_pos_sub = tf.gather(scores_predict, indices_sub)
            # compute the ranks of the positives present in the corruptions and 
            # see how many are ranked higher than the test triple
            if corrupt_side == 's+o' or corrupt_side == 'o':
                positives_among_obj_corruptions_ranked_higher = tf.reduce_sum(
                    tf.cast(scores_pos_obj >= score_positive, tf.int32)) 
            if corrupt_side == 's+o' or corrupt_side == 's':
                positives_among_sub_corruptions_ranked_higher = tf.reduce_sum(
                    tf.cast(scores_pos_sub >= score_positive, tf.int32)) 
                
        # compute the rank of the test triple and subtract the positives(from corruptions) that are ranked higher
        if use_default_protocol:       
            self.rank = tf.stack([tf.reduce_sum(tf.cast(
                subj_corruption_scores >= score_positive, 
                tf.int32)) + 1 - positives_among_sub_corruptions_ranked_higher,
                tf.reduce_sum(tf.cast(obj_corruption_scores >= score_positive,
                                      tf.int32)) + 1 - positives_among_obj_corruptions_ranked_higher], 0)
        else:
            self.rank = tf.reduce_sum(tf.cast(
                scores_predict >= score_positive, 
                tf.int32)) + 1 - positives_among_sub_corruptions_ranked_higher - \
                positives_among_obj_corruptions_ranked_higher

        
    def get_ranks(self, dataset_handle):
        if not self.is_fitted:
            msg = 'Model has not been fitted.'
            logger.error(msg)
            raise RuntimeError(msg)
        
        self.eval_dataset_handle = dataset_handle

        # build tf graph for predictions
        if self.sess_predict is None:
            tf.reset_default_graph()
            self.rnd = check_random_state(self.seed)
            tf.random.set_random_seed(self.seed)

            self.ent_emb = [tf.Variable(m.trained_model_params[0], dtype=tf.float32) for m in self.models]
            self.rel_emb = [tf.Variable(m.trained_model_params[1], dtype=tf.float32) for m in self.models]
      
            self._load_model_from_trained_params()
            # build the eval graph
            self._initialize_eval_graph()

            sess = tf.Session()
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            self.sess_predict = sess

        ranks = []
                                                   
        for i in tqdm(range(self.eval_dataset_handle.get_size('test'))):
            rank = self.sess_predict.run(self.rank)
            if self.eval_config.get('default_protocol', DEFAULT_PROTOCOL_EVAL): 
                ranks.append(list(rank))
            else:
                ranks.append(rank)
        return ranks

    def predict(self, X, from_idx=False):
        if not self.is_fitted:
            msg = 'Model has not been fitted.'
            logger.error(msg)
            raise RuntimeError(msg)
        # adapt the data with numpy adapter for internal use
        dataset_handle = NumpyDatasetAdapter()
        dataset_handle.use_mappings(self.rel_to_idx, self.ent_to_idx)
        dataset_handle.set_data(X, "test", mapped_status=from_idx)
        
        self.eval_dataset_handle = dataset_handle
        
        # build tf graph for predictions
        if self.sess_predict is None:
            tf.reset_default_graph()
            self.rnd = check_random_state(self.seed)
            tf.random.set_random_seed(self.seed)

            self.ent_emb = [tf.Variable(m.trained_model_params[0], dtype=tf.float32) for m in self.models]
            self.rel_emb = [tf.Variable(m.trained_model_params[1], dtype=tf.float32) for m in self.models]
      
            self._load_model_from_trained_params()
            # build the eval graph
            self._initialize_eval_graph()

            sess = tf.Session()
            sess.run(tf.tables_initializer())
            sess.run(tf.global_variables_initializer())
            self.sess_predict = sess

        scores = []
                                                   
        for i in tqdm(range(self.eval_dataset_handle.get_size('test'))):
            score = self.sess_predict.run([score_positive])
            if self.eval_config.get('default_protocol', DEFAULT_PROTOCOL_EVAL): 
                scores.extend(list(score)) 
            else:
                scores.append(score)

        return scores
    
    
    def end_evaluation(self):
        """End the evaluation and close the Tensorflow session.
        """
        
        if self.is_filtered and self.eval_dataset_handle is not None:
            self.eval_dataset_handle.cleanup()
            self.eval_dataset_handle = None
            
        if self.sess_predict is not None:
            self.sess_predict.close()
            
        self.sess_predict = None
        self.is_filtered = False
        
        self.eval_config = {}
        

    def _get_model_loss(self, dataset_iterator):
        # get the train triples of the batch, unique entities and the corresponding embeddings
        # the latter 2 variables are passed only for large graphs. 
        x_pos_tf = dataset_iterator.get_next()

        entities_size = 0
        entities_list = None

        x_pos = x_pos_tf

        scores_pos = []
        for i, m in enumerate(self.models):
            e_s, e_p, e_o = self._lookup_embeddings(i, x_pos)
            scores_pos.append(tf.squeeze(m._fn(e_s, e_p, e_o)))
        
        scores_pos = tf.squeeze(tf.matmul(tf.stop_gradient(tf.stack(scores_pos, axis=1)), 
                                          tf.reshape(tf.exp(self.ensemble), (-1, 1))))
                    
        if self.loss.get_state('require_same_size_pos_neg'):
            logger.debug('Requires the same size of postive and negative')
            scores_pos = tf.reshape(tf.tile(scores_pos, [self.eta]), [tf.shape(scores_pos)[0] * self.eta])

        # look up embeddings from input training triples
        negative_corruption_entities = self.embedding_model_params.get('negative_corruption_entities',
                                                                       DEFAULT_CORRUPTION_ENTITIES)

        if negative_corruption_entities == 'all':
            logger.debug('Using all entities for generation of corruptions during training')
            entities_size = tf.shape(self.ent_emb)[0]
        elif negative_corruption_entities == 'batch':
            # default is batch (entities_size=0 and entities_list=None)
            logger.debug('Using batch entities for generation of corruptions during training')
        elif isinstance(negative_corruption_entities, list):
            logger.debug('Using the supplied entities for generation of corruptions during training')
            entities_list = tf.squeeze(tf.constant(np.asarray([idx for uri, idx in self.ent_to_idx.items()
                                                               if uri in negative_corruption_entities]),
                                                   dtype=tf.int32))
        elif isinstance(negative_corruption_entities, int):
            logger.debug('Using first {} entities for generation of corruptions during \
                         training'.format(negative_corruption_entities))
            entities_size = negative_corruption_entities

        loss = 0
        corruption_sides = self.embedding_model_params.get('corrupt_sides', DEFAULT_CORRUPT_SIDE_TRAIN)
        if not isinstance(corruption_sides, list):
            corruption_sides = [corruption_sides]

        for side in corruption_sides:
            # Generate the corruptions
            x_neg_tf = generate_corruptions_for_fit(x_pos_tf, 
                                                    entities_list=entities_list, 
                                                    eta=self.eta, 
                                                    corrupt_side=side, 
                                                    entities_size=entities_size, 
                                                    rnd=self.seed)


            scores_neg = []
            for i, m in enumerate(self.models):
                e_s, e_p, e_o = self._lookup_embeddings(i,x_neg_tf)
                scores_neg.append(m._fn(e_s, e_p, e_o))
            scores_neg = tf.squeeze(tf.matmul(tf.stop_gradient(tf.stack(scores_neg, axis=1)),
                                              tf.reshape(tf.exp(self.ensemble), (-1, 1))))

            # Apply the loss function
            loss += self.loss.apply(scores_pos, scores_neg)

        return loss

    def _training_data_generator(self):
        batch_iterator = iter(self.train_dataset_handle.get_next_train_batch(self.batch_size, "train"))
        for i in range(self.batches_count):
            out = next(batch_iterator)
            yield out
    
    def fit(self, X):
        self.train_dataset_handle = None
        # try-except block is mainly to handle clean up in case of exception or manual stop in jupyter notebook
        try:
            if isinstance(X, np.ndarray):
                # Adapt the numpy data in the internal format - to generalize
                self.train_dataset_handle = NumpyDatasetAdapter()
                self.train_dataset_handle.set_data(X, "train")
            elif isinstance(X, AmpligraphDatasetAdapter):
                self.train_dataset_handle = X
            else:
                msg = 'Invalid type for input X. Expected ndarray/AmpligraphDataset object, got {}'.format(type(X))
                logger.error(msg)
                raise ValueError(msg)

            self.train_dataset_handle.map_data()

            tf.reset_default_graph()
            self.rnd = check_random_state(self.seed)
            tf.random.set_random_seed(self.seed)

            self.ent_emb = [tf.Variable(m.trained_model_params[0], dtype=tf.float32) for m in self.models]
            self.rel_emb = [tf.Variable(m.trained_model_params[1], dtype=tf.float32) for m in self.models]
      
            self.sess_train = tf.Session(config=self.tf_config)

            batch_size = int(np.ceil(self.train_dataset_handle.get_size("train") / self.batches_count))

            self.batch_size = batch_size

            dataset = tf.data.Dataset.from_generator(self._training_data_generator, 
                                                     output_types=tf.int32,
                                                     output_shapes=(None, 3))

            dataset = dataset.repeat().prefetch(1)

            dataset_iterator = dataset.make_one_shot_iterator()
            # init tf graph/dataflow for training
            # init variables (model parameters to be learned - i.e. the embeddings)
            self._initialize_parameters()

            if self.loss.get_state('require_same_size_pos_neg'):
                batch_size = batch_size * self.eta

            loss = self._get_model_loss(dataset_iterator)

            train = self.optimizer.minimize(loss)

            self.sess_train.run(tf.tables_initializer())
            self.sess_train.run(tf.global_variables_initializer())

            epoch_iterator_with_progress = tqdm(range(1, self.epochs + 1), disable=(not self.verbose), unit='epoch')

            for epoch in epoch_iterator_with_progress:
                losses = []
                for batch in range(1, self.batches_count + 1):
                    feed_dict = {}
                    self.optimizer.update_feed_dict(feed_dict, batch, epoch)
                    loss_batch, _ = self.sess_train.run([loss, train], feed_dict=feed_dict)

                    if np.isnan(loss_batch) or np.isinf(loss_batch):
                        msg = 'Loss is {}. Please change the hyperparameters.'.format(loss_batch)
                        logger.error(msg)
                        raise ValueError(msg)

                    losses.append(loss_batch)

                if self.verbose:
                    msg = 'Average Loss: {:10f}'.format(sum(losses) / (batch_size * self.batches_count))

                    logger.debug(msg)
                    epoch_iterator_with_progress.set_description(msg)
            
            self._save_trained_params()
            self._end_training()
        except BaseException as e:
            self._end_training()
            raise e
        
    def _end_training(self):
        """Performs clean up tasks after training.
        """
        # Reset this variable as it is reused during evaluation phase
        if self.is_filtered and self.eval_dataset_handle is not None:
            # cleanup the evaluation data (deletion of tables
            self.eval_dataset_handle.cleanup()
            self.eval_dataset_handle = None
            
        if self.train_dataset_handle is not None:
            self.train_dataset_handle.cleanup()
            self.train_dataset_handle = None
            
        self.is_filtered = False
        self.eval_config = {}

        # close the tf session
        if self.sess_train is not None:
            self.sess_train.close()

        # set is_fitted to true to indicate that the model fitting is completed
        self.is_fitted = True

In [5]:
def ensemble_experiment(dataset):    
    X = getattr(ampligraph.datasets, config['load_function_map'][dataset.upper()])()
    
    filter_triples = np.concatenate((X['train'], X['valid'], X['test']))
    
    model_loss_combs = list(product([ComplEx, DistMult, TransE, HolE], 
                                    ["pairwise", "nll", "multiclass_nll", "self_adversarial"]))
    
    results = {'model_metrics': []}
    models = []
    
    for m, l in tqdm_notebook(model_loss_combs):
        model_name = m.__name__
        pickle_name = "./calibration_ensemble/{}_{}_{}.pkl".format(dataset, model_name.lower(), l)
        c = config['hyperparams'][dataset.upper()][model_name.upper()]

        try:
            model = restore_model(pickle_name)
            print("Restored model {}".format(pickle_name))
        except:
            print("Training model {}".format(pickle_name))
            model = m(batches_count=c['batches_count'], seed=0, epochs=1000, 
                      k=300 if m in (DistMult, TransE) else 150, 
                      eta=c['eta'], optimizer=c['optimizer'], optimizer_params={'lr': c['optimizer_params']['lr']},
                      regularizer=c.get('regularizer'), 
                      regularizer_params={'p':c['regularizer_params']['p'], 'lambda':c['regularizer_params']['lambda']} if 'regularizer' in c and 'regularizer_params' in c else {}, 
                      loss=l, verbose=False)
            try:
                print("Fitting")
                model.fit(X['train'])
                print("Calibrating")
                model.calibrate(X['valid'], positive_base_rate=0.5)
                save_model(model, pickle_name)
            except:
                continue

#         ranks = evaluate_performance(X['test'], 
#                              model=model, 
#                              filter_triples=filter_triples,
#                              use_default_protocol=True, 
#                              verbose=False)
#         results['model_metrics'].append([model_name, l, mrr_score(ranks), hits_at_n_score(ranks, n=10), mr_score(ranks)])
#         print(results['model_metrics'][-1])        
        models.append(model)
        
#     with open('./calibration_ensemble/{}_model_metrics.json'.format(dataset), 'w') as f:
#         json.dump(results['model_metrics'], f)
        
    results = {}
    for mode in ['calibration', 'mean', 'expit']:
        try:
            ens = Ensemble(mode=mode, models=models,
                           epochs=1, 
                           batches_count=1, eta=1,
                           loss='pairwise', optimizer='adam')
        
            ens.fit(X['valid'])

            ranks = evaluate_performance(X['test'], 
                                         model=ens, 
                                         filter_triples=filter_triples,
                                         use_default_protocol=True, 
                                         verbose=False)
        
            results[mode] = [mrr_score(ranks), hits_at_n_score(ranks, n=10), mr_score(ranks)]
            print(mode, results[mode])
        except Exception as e:
            results[mode] = ["exception", "exception"]
            print("Exception: {}".format(str(e)))
            
    with open('./calibration_ensemble/{}_results.json'.format(dataset), 'w') as f:
        json.dump(results, f)

    return models, results

In [6]:
datasets = list(config['load_function_map'].keys())
datasets

['WN18', 'FB15K', 'FB15K-237', 'WN18RR', 'YAGO310']

In [7]:
results = []
for d in ['FB15K-237', 'WN18RR']:
    _, r = ensemble_experiment(d)
    results.append(r)
    gc.collect()

HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

Restored model ./calibration_ensemble/FB15K-237_complex_pairwise.pkl
Restored model ./calibration_ensemble/FB15K-237_complex_nll.pkl
Restored model ./calibration_ensemble/FB15K-237_complex_multiclass_nll.pkl
Restored model ./calibration_ensemble/FB15K-237_complex_self_adversarial.pkl
Restored model ./calibration_ensemble/FB15K-237_distmult_pairwise.pkl
Restored model ./calibration_ensemble/FB15K-237_distmult_nll.pkl
Restored model ./calibration_ensemble/FB15K-237_distmult_multiclass_nll.pkl
Restored model ./calibration_ensemble/FB15K-237_distmult_self_adversarial.pkl
Restored model ./calibration_ensemble/FB15K-237_transe_pairwise.pkl
Restored model ./calibration_ensemble/FB15K-237_transe_nll.pkl
Restored model ./calibration_ensemble/FB15K-237_transe_multiclass_nll.pkl
Restored model ./calibration_ensemble/FB15K-237_transe_self_adversarial.pkl
Restored model ./calibration_ensemble/FB15K-237_hole_pairwise.pkl
Restored model ./calibration_ensemble/FB15K-237_hole_nll.pkl
Restored model ./c

Average Loss:  16.148472: 100%|██████████| 1/1 [00:01<00:00,  1.27s/epoch]
100%|██████████| 20438/20438 [09:40<00:00, 38.04it/s]


calibration [0.30994084986300474, 0.49226930228006655, 173.28207260984442]


Average Loss:  16.148472: 100%|██████████| 1/1 [00:00<00:00,  1.76epoch/s]
100%|██████████| 20438/20438 [08:47<00:00, 38.73it/s]


mean [0.3162208948932377, 0.49867893140228986, 172.99207358841375]


Average Loss:  16.148472: 100%|██████████| 1/1 [00:01<00:00,  1.18s/epoch]
100%|██████████| 20438/20438 [09:36<00:00, 35.42it/s]


expit [0.2762770332507093, 0.4652118602602994, 196.76529014580683]


HBox(children=(IntProgress(value=0, max=16), HTML(value='')))

Restored model ./calibration_ensemble/WN18RR_complex_pairwise.pkl
Restored model ./calibration_ensemble/WN18RR_complex_nll.pkl
Restored model ./calibration_ensemble/WN18RR_complex_multiclass_nll.pkl
Restored model ./calibration_ensemble/WN18RR_complex_self_adversarial.pkl
Restored model ./calibration_ensemble/WN18RR_distmult_pairwise.pkl
Restored model ./calibration_ensemble/WN18RR_distmult_nll.pkl
Restored model ./calibration_ensemble/WN18RR_distmult_multiclass_nll.pkl
Restored model ./calibration_ensemble/WN18RR_distmult_self_adversarial.pkl
Restored model ./calibration_ensemble/WN18RR_transe_pairwise.pkl
Restored model ./calibration_ensemble/WN18RR_transe_nll.pkl
Restored model ./calibration_ensemble/WN18RR_transe_multiclass_nll.pkl
Restored model ./calibration_ensemble/WN18RR_transe_self_adversarial.pkl
Restored model ./calibration_ensemble/WN18RR_hole_pairwise.pkl
Restored model ./calibration_ensemble/WN18RR_hole_nll.pkl
Restored model ./calibration_ensemble/WN18RR_hole_multiclass

Average Loss:   8.735569: 100%|██████████| 1/1 [00:03<00:00,  3.19s/epoch]
100%|██████████| 2924/2924 [02:43<00:00, 17.86it/s]


calibration [0.4950741010374487, 0.5714774281805746, 2329.877222982216]


Average Loss:   8.735569: 100%|██████████| 1/1 [00:02<00:00,  2.10s/epoch]
100%|██████████| 2924/2924 [03:03<00:00, 15.94it/s]


mean [0.4837868888733338, 0.5528385772913816, 4218.939466484268]


Average Loss:   8.735569: 100%|██████████| 1/1 [00:03<00:00,  3.82s/epoch]
100%|██████████| 2924/2924 [03:18<00:00, 14.76it/s]


expit [0.49104252907088924, 0.5745554035567716, 3427.872777017784]
