In [42]:
from os.path import join
import torch
import pickle
import random
import numpy as np
from typing import *
import math
from torch.utils.data import Sampler, TensorDataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import uniform
import copy

import csv
from random import *
from collections import defaultdict
from os.path import join
import os
print(os.getcwd())

import sys

import time
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import ndcg_score

import pandas as pd
from torch.utils.data import Dataset, TensorDataset
from torch.distributions import uniform
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
import pprint

KG_PATH = 'data/kg_data'
KG_MODEL = 'data/kg_data/trained_models'


/content


In [43]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [44]:
random_seed = 1  # NOTE(review): not referenced by any visible cell; the seeds below are hard-coded to 0

# torch.manual_seed(1222)
# random.seed(159)
# np.random.seed(2333)

# Seed the global PyTorch and NumPy random number generators.
# NOTE(review): Python's built-in `random` module is not seeded here - confirm whether it is used downstream.
torch.manual_seed(0)
np.random.seed(0)


In [45]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = 'cpu'

## Utils

In [46]:
def get_vocab(filename):
    """Build a lookup table from a tab-separated vocabulary file.

    Every line is stripped and split on tabs; the second field becomes the key
    and the first field the value (kept as a string, exactly as read).

    Args:
        filename: Path of the vocabulary file.

    Returns:
        A ``defaultdict`` created without a default factory (a missing key
        still raises ``KeyError``) mapping second field -> first field.
    """
    vocab = defaultdict()
    with open(filename) as handle:
        for raw_line in handle:
            fields = raw_line.strip().split('\t')
            vocab[fields[1]] = fields[0]
    return vocab


def get_new_model(params):
    """Instantiate the model selected by ``params.model_name`` on ``params.device``.

    Args:
        params: Configuration object. Reads ``model_name``, ``device``,
            ``VOCAB_SIZE``, ``DIM``, ``NEG_PER_POS`` and, for ``'QuatE'``,
            ``REL_VOCAB_SIZE``.

    Returns:
        The freshly constructed model, moved to ``params.device``.

    Raises:
        ValueError: If ``params.model_name`` is not a supported model name.
    """
    if params.model_name == 'BiGumbelBox':
        model = BiGumbelBox(params.device, params.VOCAB_SIZE, params.DIM, params.NEG_PER_POS,
                            [1e-4, 0.01], [-0.1, -0.001], params).to(params.device)
    elif params.model_name == 'QuatE':
        model = QuatE(emb_dim=params.DIM, n_entities=params.VOCAB_SIZE,
                      n_relations=params.REL_VOCAB_SIZE, ratio=params.NEG_PER_POS, params=params).to(params.device)
    else:
        # Fail with a clear message instead of an UnboundLocalError on `model`.
        raise ValueError(
            f"Unsupported model_name: {params.model_name!r} (expected 'BiGumbelBox' or 'QuatE')"
        )
    return model




def get_subset_of_given_relations(ids, rel_list):
    """Select the triples whose relation id (column 1) appears in ``rel_list``.

    The result is grouped by relation in the order of ``rel_list``; within one
    relation the original row order of ``ids`` is kept.

    Args:
        ids: 2-D tensor of triples with the relation id in column 1.
        rel_list: Iterable of relation ids to keep.

    Returns:
        Tensor containing the selected rows, concatenated along dim 0.
    """
    relation_column = ids[:, 1]
    per_relation = [
        ids[torch.nonzero(relation_column == rel).squeeze(1)]  # rows using this relation
        for rel in rel_list
    ]
    return torch.cat(per_relation, dim=0)

def load_hr_map(data_dir):
    """Unpickle and return the object stored in ``<data_dir>/ndcg_test.pickle``.

    Note: ``pickle.load`` can execute arbitrary code, so only point this at
    trusted files.
    """
    pickle_path = join(data_dir, 'ndcg_test.pickle')
    with open(pickle_path, 'rb') as handle:
        return pickle.load(handle)

## Model

In [47]:

class Box:
    """Axis-aligned box described by its lower and upper corner.

    ``min_embed`` and ``max_embed`` are stored as passed in; ``delta_embed`` is
    computed once, at construction, as ``max_embed - min_embed``.
    """

    def __init__(self, min_embed, max_embed):
        self.min_embed, self.max_embed = min_embed, max_embed
        self.delta_embed = self.max_embed - self.min_embed
    

class BiGumbelBox(nn.Module):
    def __init__(self, device, vocab_size, embed_dim, ratio, min_init_value, delta_init_value, params):
        """Create the entity box embeddings and the per-relation affine parameters.

        Args:
            device: Device the relation tensors are moved to before being wrapped
                as parameters; also stored as ``self.device``.
            vocab_size: Number of rows in the entity embedding tables.
            embed_dim: Number of columns in the entity embedding tables.
            ratio: Stored as ``self.ratio``; the negative samplers use it as the
                default number of negatives per positive.
            min_init_value: ``[low, high]`` bounds of the uniform initialiser for
                the min-corner table.
            delta_init_value: ``[low, high]`` bounds of the uniform initialiser for
                the delta table (``get_entity_boxes`` exponentiates these values).
            params: Configuration object; this constructor reads
                ``REL_VOCAB_SIZE``, ``DIM`` and ``GUMBEL_BETA`` from it.
        """
        super(BiGumbelBox, self).__init__()
        # super(BiGumbelBox, self).__init__(device, vocab_size, embed_dim, ratio, min_init_value, delta_init_value, params)

        self.euler_gamma = 0.57721566490153286060
        self.min_init_value = min_init_value
        self.delta_init_value = delta_init_value

        # Entity tables: one row per entity, sampled uniformly from the given ranges.
        min_embedding = self.init_embedding(vocab_size, embed_dim, min_init_value)
        delta_embedding = self.init_embedding(vocab_size, embed_dim, delta_init_value)
        self.min_embedding = nn.Parameter(min_embedding)
        self.delta_embedding = nn.Parameter(delta_embedding)

        # Relation-specific affine transform applied to head boxes:
        # translations start near 0, scales start near 1.
        # NOTE(review): relation tensors are sized with params.DIM while entity tensors use
        # embed_dim; they are combined elementwise, so the two presumably must match.
        rel_trans_for_head = torch.empty(params.REL_VOCAB_SIZE, params.DIM)
        rel_scale_for_head = torch.empty(params.REL_VOCAB_SIZE, params.DIM)
        torch.nn.init.normal_(rel_trans_for_head, mean=0, std=1e-4)  # 1e-4 before
        torch.nn.init.normal_(rel_scale_for_head, mean=1, std=0.2)  # 0.2 before

        # Same initialisation for the transform applied to tail boxes.
        rel_trans_for_tail = torch.empty(params.REL_VOCAB_SIZE, params.DIM)
        rel_scale_for_tail = torch.empty(params.REL_VOCAB_SIZE, params.DIM)
        torch.nn.init.normal_(rel_trans_for_tail, mean=0, std=1e-4)
        torch.nn.init.normal_(rel_scale_for_tail, mean=1, std=0.2)

        # make nn.Parameter
        self.rel_trans_for_head, self.rel_scale_for_head = nn.Parameter(rel_trans_for_head.to(device)), nn.Parameter(
            rel_scale_for_head.to(device))
        self.rel_trans_for_tail, self.rel_scale_for_tail = nn.Parameter(rel_trans_for_tail.to(device)), nn.Parameter(
            rel_scale_for_tail.to(device))

        self.true_head, self.true_tail = None, None  # for negative sample filtering
        self.gumbel_beta = params.GUMBEL_BETA
        self.params = params
        self.device = device
        self.ratio = ratio
        self.vocab_size = vocab_size
        # The four attributes below are not read by any method of this class.
        self.alpha = 1e-16
        self.clamp_min = 0.0
        self.clamp_max = 1e10
        self.REL_VOCAB_SIZE = params.REL_VOCAB_SIZE


    def load_model(self, PATH):
        model.load_state_dict(torch.load(PATH))
        model.eval()
        
    def forward(self, ids, probs, train=True):

        head_boxes = self.transform_head_boxes(ids)
        tail_boxes = self.transform_tail_boxes(ids)

        intersection_boxes = self.intersection(head_boxes, tail_boxes)

        log_intersection = self.log_volumes(intersection_boxes)

        # condition on subject or object
        log_prob = log_intersection - self.log_volumes(tail_boxes)

        pos_predictions = log_prob
        return pos_predictions, probs


    def transform_head_boxes(self, ids):
        head_boxes = self.get_entity_boxes(ids[:, 0])

        rel_ids = ids[:, 1]
        relu = nn.ReLU()

        translations = self.rel_trans_for_head[rel_ids]
        scales = relu(self.rel_scale_for_head[rel_ids])

        # affine transformation
        head_boxes.min_embed += translations
        head_boxes.delta_embed *= scales
        head_boxes.max_embed = head_boxes.min_embed + head_boxes.delta_embed

        return head_boxes

    def transform_tail_boxes(self, ids):
        tail_boxes = self.get_entity_boxes(ids[:, 2])

        rel_ids = ids[:, 1]
        relu = nn.ReLU()

        translations = self.rel_trans_for_tail[rel_ids]
        scales = relu(self.rel_scale_for_tail[rel_ids])

        # affine transformation
        tail_boxes.min_embed += translations
        tail_boxes.delta_embed *= scales
        tail_boxes.max_embed = tail_boxes.min_embed + tail_boxes.delta_embed

        return tail_boxes


    def intersection(self, boxes1, boxes2):
        """Return the intersection box of ``boxes1`` and ``boxes2``.

        The lower corner is ``beta * logsumexp`` of the two lower corners divided
        by ``beta`` (a smooth maximum), then bounded below by their hard
        elementwise maximum. The upper corner is built symmetrically with a
        smooth minimum bounded above by the hard elementwise minimum.
        ``beta`` is ``self.gumbel_beta``.
        """
        beta = self.gumbel_beta

        # Lower corner: smooth max, never below the hard max.
        smooth_lower = beta * torch.logsumexp(
            torch.stack((boxes1.min_embed / beta, boxes2.min_embed / beta)),
            0
        )
        hard_lower = torch.max(boxes1.min_embed, boxes2.min_embed)
        lower_corner = torch.max(smooth_lower, hard_lower)

        # Upper corner: smooth min, never above the hard min.
        smooth_upper = - beta * torch.logsumexp(
            torch.stack((-boxes1.max_embed / beta, -boxes2.max_embed / beta)),
            0
        )
        hard_upper = torch.min(boxes1.max_embed, boxes2.max_embed)
        upper_corner = torch.min(smooth_upper, hard_upper)

        return Box(lower_corner, upper_corner)

    def log_volumes(self, boxes, temp=1., gumbel_beta=1., scale=1.):
        eps = torch.finfo(boxes.min_embed.dtype).tiny  # type: ignore

        if isinstance(scale, float):
            s = torch.tensor(scale)
        else:
            s = scale

        log_vol = torch.sum(
            torch.log(
                F.softplus(boxes.delta_embed - 2 * self.euler_gamma * self.gumbel_beta, beta=temp).clamp_min(eps)
            ),
            dim=-1
        ) + torch.log(s)

        return log_vol

    def get_entity_boxes(self, entities):
        """Build the boxes of the given entity ids.

        The lower corner is read from ``min_embedding``; the upper corner is the
        lower corner plus ``exp`` of the matching ``delta_embedding`` rows.
        """
        lower_corner = self.min_embedding[entities]  # batchsize * embedding_size
        width = torch.exp(self.delta_embedding[entities])
        return Box(lower_corner, lower_corner + width)

    def init_embedding(self, vocab_size, embed_dim, init_value):
        distribution = uniform.Uniform(init_value[0], init_value[1])
        box_embed = distribution.sample((vocab_size, embed_dim))
        return box_embed

    def random_negative_sampling(self, positive_samples, pos_probs, neg_per_pos=None):
        if neg_per_pos is None:
            neg_per_pos = self.ratio
        negative_samples1 = torch.repeat_interleave(positive_samples, neg_per_pos, dim=0)
        negative_samples2 = torch.repeat_interleave(positive_samples, neg_per_pos, dim=0)

        corrupted_heads = [self.get_negative_samples_for_one_positive(pos, neg_per_pos, mode='corrupt_head') for pos in positive_samples]
        corrupted_tails = [self.get_negative_samples_for_one_positive(pos, neg_per_pos, mode='corrupt_tail') for pos in positive_samples]

        negative_samples1[:, 0] = torch.cat(corrupted_heads)
        negative_samples2[:, 2] = torch.cat(corrupted_tails)
        negative_samples = torch.cat((negative_samples1, negative_samples2), 0).to(device)
        neg_probs = torch.zeros(negative_samples.shape[0], dtype=pos_probs.dtype).to(device)

        return negative_samples, neg_probs

    def random_negative_sampling0(self, positive_samples, pos_probs, neg_per_pos=None):
        if neg_per_pos is None:
            neg_per_pos = self.ratio
        negative_samples1 = torch.repeat_interleave(positive_samples, neg_per_pos, dim=0)
        negative_samples2 = torch.repeat_interleave(positive_samples, neg_per_pos, dim=0)

        # corrupt tails
        corrupted_heads = torch.randint(self.vocab_size, (negative_samples1.shape[0],)).to(device)
        corrupted_tails = torch.randint(self.vocab_size, (negative_samples1.shape[0],)).to(device)

        #filter
        bad_heads_idxs = (corrupted_heads == negative_samples1[:,0])
        bad_tails_idxs = (corrupted_tails == negative_samples2[:,2])
        corrupted_heads[bad_heads_idxs] = torch.randint(self.vocab_size, (torch.sum(bad_heads_idxs),)).to(device)
        corrupted_tails[bad_tails_idxs] = torch.randint(self.vocab_size, (torch.sum(bad_tails_idxs),)).to(device)

        negative_samples1[:, 0] = corrupted_heads
        negative_samples2[:, 2] = corrupted_tails
        negative_samples = torch.cat((negative_samples1, negative_samples2), 0).to(device)
        neg_probs = torch.zeros(negative_samples.shape[0], dtype=pos_probs.dtype).to(device)

        return negative_samples, neg_probs


    def get_negative_samples_for_one_positive(self, positive_sample, neg_per_pos, mode):
        head, relation, tail = positive_sample
        negative_sample_list = []
        negative_sample_size = 0
        while negative_sample_size < neg_per_pos:
            negative_sample = np.random.randint(self.params.VOCAB_SIZE, size=neg_per_pos * 2)

            # filter true values
            if mode == 'corrupt_head' and (int(relation), int(tail)) in self.true_head:  # filter true heads
                # For test data, some (relation, tail) pairs may be unseen and not in self.true_head
                mask = np.in1d(
                    negative_sample,
                    self.true_head[(int(relation), int(tail))],
                    assume_unique=True,
                    invert=True
                )
                negative_sample = negative_sample[mask]
            elif mode == 'corrupt_tail' and (int(head), int(relation)) in self.true_tail:
                mask = np.in1d(
                    negative_sample,
                    self.true_tail[(int(head), int(relation))],
                    assume_unique=True,
                    invert=True
                )
                negative_sample = negative_sample[mask]
            negative_sample_list.append(negative_sample)
            negative_sample_size += negative_sample.size

        negative_sample = np.concatenate(negative_sample_list)[:neg_per_pos]

        negative_sample = torch.from_numpy(negative_sample)
        return negative_sample


    def head_transformation(self, head_boxes, rel_ids):
        relu = nn.ReLU()
        translations = self.rel_trans_for_head[rel_ids]
        scales = relu(self.rel_scale_for_head[rel_ids])
        # affine transformation
        head_boxes.min_embed += translations
        head_boxes.delta_embed *= scales
        head_boxes.max_embed = head_boxes.min_embed + head_boxes.delta_embed

        return head_boxes

    def tail_transformation(self, tail_boxes, rel_ids):
        relu = nn.ReLU()
        translations = self.rel_trans_for_tail[rel_ids]
        scales = relu(self.rel_scale_for_tail[rel_ids])
        # affine transformation
        tail_boxes.min_embed += translations
        tail_boxes.delta_embed *= scales
        tail_boxes.max_embed = tail_boxes.min_embed + tail_boxes.delta_embed

        return tail_boxes

    def get_entity_boxes_detached(self, entities):
        """
        Build entity boxes from detached copies of the entity embeddings.

        Used by the logic-constraint losses: only the relation parameters should
        be optimised there, so no gradient flows back to the entity tables.
        """
        lower_corner = self.min_embedding[entities].detach()
        log_width = self.delta_embedding[entities].detach()
        return Box(lower_corner, lower_corner + torch.exp(log_width))

    def transitive_rule_loss(self, ids):
        """Penalty encouraging transitive relations to map head boxes over tail boxes.

        Only triples whose relation is listed in
        ``params.RULE_CONFIGS['transitive']['relations']`` take part. Entity boxes
        are detached, so gradients reach the relation parameters only.

        Returns:
            Norm of ``1 - exp(log Vol(head & tail) - log Vol(tail))`` over the
            selected triples.
        """
        relation_column = ids[:, 1]
        transitive_relations = self.params.RULE_CONFIGS['transitive']['relations']
        rule_triples = torch.cat(
            [ids[torch.nonzero(relation_column == rel).squeeze(1), :] for rel in transitive_relations],
            dim=0
        )
        rule_relations = rule_triples[:, 1]

        # only optimize relation parameters
        heads = self.head_transformation(self.get_entity_boxes_detached(rule_triples[:, 0]), rule_relations)
        tails = self.tail_transformation(self.get_entity_boxes_detached(rule_triples[:, 2]), rule_relations)

        log_overlap = self.log_volumes(self.intersection(heads, tails))

        # P(f_r(epsilon_box)|g_r(epsilon_box)) should be 1
        return torch.norm(1 - torch.exp(log_overlap - self.log_volumes(tails)))

    def composition_rule_loss(self, ids):
        """Penalty encouraging the composition of r1 and r2 to act like r3.

        For every ``(r1, r2, r3)`` listed in
        ``params.RULE_CONFIGS['composite']['relations']``, the heads of the batch's
        r1 triples are paired with the tails of its r2 triples (cartesian product).
        The boxes obtained by applying the r1 then the r2 transform are compared
        with the boxes obtained by applying the r3 transform, in both conditioning
        directions. Entity boxes are detached, so only relation parameters receive
        gradients.

        Returns:
            The accumulated loss, or the int ``0`` when no rule has both r1 and r2
            triples in ``ids``.
        """
        def rels(size, rid):
            # fill a tensor with relation id
            # NOTE(review): created without a device argument, i.e. on the default device.
            return torch.full((size,), rid, dtype=torch.long)

        def biconditioning(boxes1, boxes2):
            # Sum of both conditional-probability penalties; zero when the two
            # boxes fully contain each other.
            intersection_boxes = self.intersection(boxes1, boxes2)
            log_intersection = self.log_volumes(intersection_boxes)
            # || 1-P(Box2|Box1) ||
            condition_on_box1 = torch.norm(1 - torch.exp(log_intersection - self.log_volumes(boxes1)))
            # || 1-P(Box1|Box2) ||
            condition_on_box2 = torch.norm(1 - torch.exp(log_intersection - self.log_volumes(boxes2)))
            loss = condition_on_box1 + condition_on_box2
            return loss

        vol_loss = 0
        for rule_combn in self.params.RULE_CONFIGS['composite']['relations']:
            r1, r2, r3 = rule_combn
            r1_triples = ids[(ids[:, 1] == r1).nonzero().squeeze(1), :]
            r2_triples = ids[(ids[:, 1] == r2).nonzero().squeeze(1), :]

            # use heads and tails from r1, r2 as reasonable entity samples to help optimize relation parameters
            if len(r1_triples) > 0 and len(r2_triples) > 0:
                entities = torch.cartesian_prod(r1_triples[:,0], r2_triples[:,2])
                head_ids, tail_ids = entities[:,0], entities[:,1]
                size = len(entities)

                # only optimize relation parameters
                head_boxes_r1r2 = self.get_entity_boxes_detached(head_ids)
                tail_boxes_r1r2 = self.get_entity_boxes_detached(tail_ids)
                # The transformations mutate and return the box they are given, so the
                # second call stacks the r2 transform on top of the r1 transform.
                r1r2_head = self.head_transformation(head_boxes_r1r2, rels(size, r1))
                r1r2_head = self.head_transformation(r1r2_head, rels(size, r2))
                r1r2_tail = self.tail_transformation(tail_boxes_r1r2, rels(size, r1))
                r1r2_tail = self.tail_transformation(r1r2_tail, rels(size, r2))

                # head_boxes_r1r2 have been modified in transformation
                # so make separate box objects with the same parameters
                head_boxes_r3 = self.get_entity_boxes_detached(head_ids)
                tail_boxes_r3 = self.get_entity_boxes_detached(tail_ids)
                r3_head = self.head_transformation(head_boxes_r3, rels(size, r3))
                r3_tail = self.tail_transformation(tail_boxes_r3, rels(size, r3))

                head_transform_loss = biconditioning(r1r2_head, r3_head)
                tail_transform_loss = biconditioning(r1r2_tail, r3_tail)
                vol_loss += head_transform_loss
                vol_loss += tail_transform_loss
        return vol_loss




## Losses

In [48]:
def get_logic_loss(model, ids, params):
    """Weighted sum of the logic-rule penalties, divided by the batch size.

    A rule's loss method is looked up and called only when its coefficient in
    ``params.regularization`` is strictly positive; otherwise the rule
    contributes the int ``0``.

    Args:
        model: Object providing ``transitive_rule_loss(ids)`` and
            ``composition_rule_loss(ids)``.
        ids: Batch of triples.
        params: Provides ``regularization`` (keys ``'transitive'`` and
            ``'composite'``) and ``device``.

    Returns:
        ``(transitive term + composite term) / len(ids)``.
    """
    def weighted_penalty(coefficient_key, loss_method_name):
        weight = torch.tensor(params.regularization[coefficient_key]).to(params.device)
        if weight > 0:
            # Resolve the method lazily so a disabled rule never touches the model.
            return weight * getattr(model, loss_method_name)(ids)
        return 0

    # transitive rule loss regularization
    transitive_rule_reg = weighted_penalty('transitive', 'transitive_rule_loss')
    # composite rule loss regularization
    composition_rule_reg = weighted_penalty('composite', 'composition_rule_loss')

    return (transitive_rule_reg + composition_rule_reg) / len(ids)


def main_mse_loss(model, ids, cls):
    """Mean squared error between ``exp(model output)`` and the targets.

    ``model(ids, cls, train=True)`` must return ``(log_prediction, target)``.
    """
    log_prediction, target = model(ids, cls, train=True)
    return nn.MSELoss(reduction='mean')(torch.exp(log_prediction), target)


def main_msle_loss(model, ids, cls):
    """Mean squared logarithmic error between predicted and target probabilities.

    ``model(ids, cls, train=True)`` must return ``(log_prediction, target)``.
    The model output is a log-probability ``log p``, so the MSLE term
    ``log(p + 1)`` is ``log1p(exp(log p))`` - not ``log p + 1``.

    Returns:
        Scalar tensor; zero exactly when the predicted probabilities equal the targets.
    """
    criterion = nn.MSELoss(reduction='mean')
    prediction, truth = model(ids, cls, train=True)
    log_pred_plus_one = torch.log1p(torch.exp(prediction))  # log(p + 1)
    return criterion(log_pred_plus_one, torch.log1p(truth))


def main_mle_loss(model, ids, cls):
    """Mean absolute logarithmic error between predicted and target probabilities.

    ``model(ids, cls, train=True)`` must return ``(log_prediction, target)``.
    The model output is a log-probability ``log p``, so ``log(p + 1)`` is
    computed as ``log1p(exp(log p))`` - not ``log p + 1``.

    Returns:
        Scalar tensor; zero exactly when the predicted probabilities equal the targets.
    """
    criterion = nn.L1Loss(reduction='mean')
    prediction, truth = model(ids, cls, train=True)
    log_pred_plus_one = torch.log1p(torch.exp(prediction))  # log(p + 1)
    return criterion(log_pred_plus_one, torch.log1p(truth))


def L2_regularization(model, ids, params):
    """Norm penalties on the embeddings touched by the batch.

    Each term is a coefficient from ``params.regularization`` times the mean
    row-wise norm of the selected rows:

    * ``'delta'``: ``exp(delta_embedding)`` of the heads and of the tails,
    * ``'min'``: ``min_embedding`` of the heads and of the tails,
    * ``'rel_trans'``: ``exp`` of the head and tail relation translations,
    * ``'rel_scale'``: ``exp`` of the head and tail relation scales.

    NOTE(review): the relation terms penalise ``exp(parameter)`` rather than the
    parameter itself - confirm that this is intended.

    Returns:
        Scalar tensor holding the sum of all terms.
    """
    weights = params.regularization
    target_device = params.device

    def coefficient(name):
        return torch.tensor(weights[name]).to(target_device)

    def mean_row_norm(rows):
        return torch.norm(rows, dim=1).mean()

    head_ids, rel_ids, tail_ids = ids[:, 0], ids[:, 1], ids[:, 2]

    # regularization on delta
    delta_coff, min_coff = coefficient('delta'), coefficient('min')
    delta_reg1 = delta_coff * mean_row_norm(torch.exp(model.delta_embedding[head_ids]))
    delta_reg2 = delta_coff * mean_row_norm(torch.exp(model.delta_embedding[tail_ids]))

    min_reg1 = min_coff * mean_row_norm(model.min_embedding[head_ids])
    min_reg2 = min_coff * mean_row_norm(model.min_embedding[tail_ids])

    rel_trans_reg = coefficient('rel_trans') * (
        mean_row_norm(torch.exp(model.rel_trans_for_head[rel_ids]))
        + mean_row_norm(torch.exp(model.rel_trans_for_tail[rel_ids]))
    )

    rel_scale_reg = coefficient('rel_scale') * (
        mean_row_norm(torch.exp(model.rel_scale_for_head[rel_ids]))
        + mean_row_norm(torch.exp(model.rel_scale_for_tail[rel_ids]))
    )

    return delta_reg1 + delta_reg2 + min_reg1 + min_reg2 + rel_trans_reg + rel_scale_reg

def kg_mse_loss(model, ids, cls):
    """MSE loss on the positives plus MSE loss on freshly sampled negatives.

    Negatives come from ``model.random_negative_sampling(ids, cls)`` and are
    drawn after the positive loss has been computed.

    Returns:
        ``(main_loss, pos_loss, neg_loss)`` with
        ``main_loss = pos_loss + NEG_RATIO * neg_loss``.
    """
    NEG_RATIO = 1  # fixed weight of the negative term
    pos_loss = main_mse_loss(model, ids, cls)
    neg_loss = main_mse_loss(model, *model.random_negative_sampling(ids, cls))
    return pos_loss + NEG_RATIO * neg_loss, pos_loss, neg_loss

def my_loss(model, ids, cls, params):
    """Total training loss: MSE on positives and negatives + L2 penalties + logic-rule penalties.

    Returns:
        ``(loss, pos_loss, neg_loss, logic_loss)``.
    """
    main_loss, pos_loss, neg_loss = kg_mse_loss(model, ids, cls)
    logic_loss = get_logic_loss(model, ids, params)
    total_loss = main_loss + L2_regularization(model, ids, params) + logic_loss
    return total_loss, pos_loss, neg_loss, logic_loss


## Dataset

In [49]:
class TripleDataset(TensorDataset):
    """Pairwise Probability dataset

    Loads one or more pickled 2-D arrays and stacks them row-wise. Columns 0-2
    are the (head, relation, tail) ids. With more than four columns, columns
    3 and onward are kept as float32 targets; otherwise column 3 is cast to an
    integer target.
    """

    def __init__(self, filenames):
        """
        Args:
            filenames: Iterable of paths to pickle files, each holding a 2-D array.

        Raises:
            ValueError: If ``filenames`` is empty.
        """
        # NOTE: pickle.load can execute arbitrary code - only load trusted files.
        chunks = []
        for filename in filenames:
            with open(filename, 'rb') as f:
                chunks.append(pickle.load(f))
        if not chunks:
            raise ValueError("filenames must contain at least one pickle file")
        # A single concatenation avoids re-copying the accumulated data per file.
        data = chunks[0] if len(chunks) == 1 else np.concatenate(chunks, axis=0)

        # np.int64 maps to torch.long on every platform and NumPy version.
        self.ids = torch.from_numpy(data[:, :3].astype(np.int64))
        if data.shape[1] > 4:
            self.probs = torch.from_numpy(data[:, 3:].astype(np.float32))
        else:
            self.probs = torch.from_numpy(data[:, 3].astype(np.int64))
        self.length = self.ids.shape[0]
        super().__init__(self.ids, self.probs)

    def __getitem__(self, index):
        """Return ``(ids[index], probs[index])``."""
        return self.ids[index], self.probs[index]

    def __len__(self):
        """Number of triples."""
        return self.length



class UncertainTripleDataset(TensorDataset):
    """Triples with a confidence value, read from a CSV file with columns ``h, r, t, p``.

    Besides the tensors, the dataset keeps ``true_head`` / ``true_tail`` lookup
    dictionaries built from all of its triples.
    """

    def __init__(self, data_dir, filename):
        """
        Args:
            data_dir: Directory containing the CSV file.
            filename: Name of the CSV file; it must have a header row naming the
                columns ``h``, ``r``, ``t`` and ``p``.
        """
        # df = pd.read_csv(join(data_dir, filename), sep='\t', names=['h', 'r', 't', 'p'])
        df = pd.read_csv(join(data_dir, filename))
        triples = df[['h', 'r', 't']].values
        prob = df['p'].values

        # np.int64 maps to torch.long on every platform and NumPy version.
        self.ids = torch.from_numpy(triples.astype(np.int64))
        self.probs = torch.from_numpy(prob.astype(np.float32))
        self.length = self.ids.shape[0]
        super().__init__(self.ids, self.probs)

        self.true_head, self.true_tail = self.get_true_head_and_tail(self.ids)

    def __getitem__(self, index):
        """Return ``(ids[index], probs[index])``."""
        return self.ids[index], self.probs[index]

    def __len__(self):
        """Number of triples."""
        return self.length

    def get_true_head_and_tail(self, triples):
        '''
        Build a dictionary of true triples that will
        be used to filter these true triples for negative sampling

        Args:
            triples: Tensor/array of shape (n, 3), or any iterable of
                (head, relation, tail) triples.

        Returns:
            ``(true_head, true_tail)`` where ``true_head[(relation, tail)]`` is an
            array of the distinct heads seen with that pair and
            ``true_tail[(head, relation)]`` an array of the distinct tails.
        '''
        # Converting once with .tolist() avoids creating one tensor per row/element.
        rows = triples.tolist() if hasattr(triples, 'tolist') else triples

        heads_by_pair = {}
        tails_by_pair = {}
        for head0, relation0, tail0 in rows:
            head, relation, tail = int(head0), int(relation0), int(tail0)
            tails_by_pair.setdefault((head, relation), []).append(tail)
            heads_by_pair.setdefault((relation, tail), []).append(head)

        true_head = {pair: np.array(list(set(heads))) for pair, heads in heads_by_pair.items()}
        true_tail = {pair: np.array(list(set(tails))) for pair, tails in tails_by_pair.items()}

        return true_head, true_tail



In [50]:
######################################
# Custom Sampler
######################################

class TensorBatchSampler(Sampler):
    """Sampler that yields whole batches of indices as 1-D index tensors.

    Each iteration produces either a fresh random permutation (``shuffle=True``)
    or ``arange(len(data_source))``, split into chunks of ``batch_size``. With
    ``drop_last=True`` a trailing incomplete chunk is discarded.
    """

    def __init__(self, data_source, batch_size, shuffle=False, drop_last=False):
        if not isinstance(data_source, TensorDataset):
            raise ValueError(f"data_source should be an instance of torch.utils.data.TensorDataset, but got data_source={data_source}")

        # bool is a subclass of int, so it has to be rejected explicitly.
        is_positive_int = (
            isinstance(batch_size, int)
            and not isinstance(batch_size, bool)
            and batch_size > 0
        )
        if not is_positive_int:
            raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")

        if not isinstance(shuffle, bool):
            raise ValueError(f"shuffle should be a boolean value, but got shuffle={shuffle}")
        if not isinstance(drop_last, bool):
            raise ValueError(f"drop_last should be a boolean value, but got drop_last={drop_last}")

        self.data_source = data_source
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.drop_last = drop_last

    def __iter__(self):
        n_items = len(self.data_source)
        order = torch.randperm(n_items) if self.shuffle else torch.arange(n_items)
        batches = order.split(self.batch_size)
        has_partial_batch = n_items % self.batch_size != 0
        if self.drop_last and has_partial_batch:
            batches = batches[:-1]
        return iter(batches)

    def __len__(self):
        batches_exact = len(self.data_source) / self.batch_size
        if self.drop_last:
            return math.floor(batches_exact)
        return math.ceil(batches_exact)

def unwrap_collate_fn(batch):
    """Collate function that returns the first element of ``batch`` unchanged."""
    first_item = batch[0]
    return first_item

class TensorDataLoader(DataLoader):
    """DataLoader that fetches each batch with a single tensor-index lookup.

    A ``TensorBatchSampler`` produces one index tensor per batch; the base
    loader is configured with ``batch_size=1`` and ``unwrap_collate_fn`` so that
    each item it yields is the dataset indexed by that tensor.
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, *, drop_last=False, collate_fn=None, **kwargs):
        unsupported_options = (sampler, batch_sampler, collate_fn)
        if any(option is not None for option in unsupported_options):
            raise ValueError("TensorDataLoader does not support alternate samplers, batch samplers, or collate functions.")

        index_batches = TensorBatchSampler(dataset, batch_size, shuffle, drop_last)
        super().__init__(dataset, batch_size=1, shuffle=False, sampler=index_batches,
                         drop_last=False, collate_fn=unwrap_collate_fn, **kwargs)


## Evaluators

In [51]:
class Tester:
    class IndexScore:
        """
        An entity index paired with a score (the weight of a tail for a given h and r).

        Instances order by ``score`` alone, which is what the ranking task
        relies on for comparison and sorting. Both string forms print the score
        with three decimal places.
        """

        def __init__(self, index, score):
            self.index, self.score = index, score

        def __lt__(self, other):
            # Only the score takes part in the ordering.
            return self.score < other.score

        def __repr__(self):
            # return "(index: %d, w:%.3f)" % (self.index, self.score)
            fields = (self.index, self.score)
            return "(%d, %.3f)" % fields

        def __str__(self):
            fields = (self.index, self.score)
            return "(index: %d, w:%.3f)" % fields

    def __init__(self, model, num_entity):
        """
        :type test_dataset: ShirleyTripleDataset
        """
        self.model = model
        self.num_entity = num_entity

    def get_score(self, h, r, i):
        ids = torch.LongTensor([[h, r, i]])
        cls = torch.Tensor([0])  # dummy
        log_score, _ = self.model(ids, cls)
        return torch.exp(log_score).detach().cpu().numpy()[0]

    def get_t_ranks(self, h, r, ts):
        """
        Given some t index, return the ranks for each t

        Every entity is scored as a candidate tail of ``(h, r)``. The rank of a
        target is the number of candidates whose score is greater than or equal
        to the target's score.

        :param ts: indices of the target tails
        :return: numpy array with one rank per entry of ``ts``
        """
        all_candidates = NDCGRankingTestDataset(
            h, r, self.num_entity
        )  # for one hr
        single_batch_loader = TensorDataLoader(
            all_candidates,
            batch_size=all_candidates.length,
            shuffle=False
        )
        with torch.no_grad():
            for candidate_ids in single_batch_loader:  # only one batch
                # candidate_ids: [[h,r,0],[h,r,1]...]
                dummy_targets = torch.zeros(candidate_ids.shape[0])
                scores, _ = self.model(candidate_ids, dummy_targets)  # log scores
                target_scores = scores[ts]
                ranks = np.array(
                    [(scores >= target).sum().detach().cpu().numpy() for target in target_scores]
                )
                break

        return ranks

    def ndcg0(self, h, r, tw_truth):
        """
        Compute nDCG(normalized discounted cummulative gain)
        sum(score_ground_truth / log2(rank+1)) / max_possible_dcg
        :param tw_truth: [IndexScore1, IndexScore2, ...], soreted by IndexScore.score descending
        :return:
        """
        # prediction
        ts = [tw.index for tw in tw_truth]
        ranks = self.get_t_ranks(h, r, ts)

        # linear gain
        gains = np.array([tw.score for tw in tw_truth])
        discounts = np.log2(ranks + 2)  # avoid division by 0
        discounted_gains = gains / discounts
        dcg = np.sum(discounted_gains)  # discounted cumulative gain

        # normalize
        best_possible_ranks = np.array([(gains >= g).sum() for g in gains])  # gains [0.9, 0.8, 0.8, 0.7] -> [1,3,3,4]
        max_possible_dcg = np.sum(gains / np.log2(best_possible_ranks + 1))
        # max_possible_dcg = np.sum(gains / np.log2(np.arange(len(gains)) + 2))

        ndcg = dcg / max_possible_dcg  # normalized discounted cumulative gain

        # exponential gain
        exp_gains = np.array([2 ** tw.score - 1 for tw in tw_truth])
        exp_discounted_gains = exp_gains / discounts
        exp_dcg = np.sum(exp_discounted_gains)
        # normalize
        exp_best_possible_ranks = np.array([(exp_gains >= g).sum() for g in exp_gains])
        exp_max_possible_dcg = np.sum(exp_gains / np.log2(exp_best_possible_ranks + 1))
        # exp_max_possible_dcg = np.sum(exp_gains / np.log2(np.arange(len(gains)) + 2))
        exp_ndcg = exp_dcg / exp_max_possible_dcg

        return ndcg, exp_ndcg, ranks


    def ndcg(self, h, r, tw_truth):
        """nDCG of the model's tail ranking for ``(h, r)``, computed with sklearn's ``ndcg_score``.

        The relevance vector has one entry per entity: the ground-truth weight for
        the tails in ``tw_truth`` and 0 elsewhere. The predicted scores are
        ``exp`` of the model's log scores for every candidate tail.

        :return: ``(linear_ndcg, linear_ndcg, None)`` - the exponential-gain
            variant is not computed, the linear value fills both slots.
        """
        with torch.no_grad():
            truth_indices = torch.LongTensor([tw.index for tw in tw_truth])
            truth_weights = torch.FloatTensor([tw.score for tw in tw_truth])
            gains = torch.zeros(self.num_entity)
            gains[truth_indices] = truth_weights

            all_candidates = NDCGRankingTestDataset(
                h, r, self.num_entity
            )  # for one hr
            single_batch_loader = TensorDataLoader(
                all_candidates,
                batch_size=all_candidates.length,
                shuffle=False
            )
            for candidate_ids in single_batch_loader:  # only one batch
                # candidate_ids: [[h,r,0],[h,r,1]...]
                dummy_targets = torch.zeros(candidate_ids.shape[0])
                log_scores, _ = self.model(candidate_ids, dummy_targets)
                scores = torch.exp(log_scores)

                relevance = gains.unsqueeze(0).detach().cpu().numpy()
                predicted = scores.unsqueeze(0).detach().cpu().numpy()
                linear_ndcg = ndcg_score(relevance, predicted)
                return linear_ndcg, linear_ndcg, None


    def mean_ndcg(self, hr_map):
        """
        :param hr_map: {h:{r:{t:w}}}
        :return:
        """
        ndcg_sum = 0  # nDCG with linear gain
        exp_ndcg_sum = 0
        count = 0

        t0 = time.time()

        # debug ndcg
        res = []  # [(h,r,tw_truth, ndcg)]

        for h in hr_map:
            for r in hr_map[h]:
                tw_dict = hr_map[h][r]  # {t:w}
                tw_truth = [self.IndexScore(t, w) for t, w in tw_dict.items()]
                tw_truth.sort(reverse=True)  # descending on w
                ndcg, exp_ndcg, ranks = self.ndcg(h, r, tw_truth)  # nDCG with linear gain and exponential gain

                ndcg_sum += ndcg
                exp_ndcg_sum += exp_ndcg
                count += 1

                # debug
                res.append((h, r, tw_truth, ndcg, ranks))

        return ndcg_sum / count, exp_ndcg_sum / count

In [52]:
def evaluate_mse(prediction, truth):
    """Mean squared error and mean absolute error between two tensors.

    Both inputs are detached, moved to the CPU and converted to NumPy arrays
    before any arithmetic, so the metrics are computed purely in NumPy.

    Args:
        prediction: Tensor of predicted values.
        truth: Tensor of target values, broadcast-compatible with ``prediction``.

    Returns:
        ``(mse, mae)`` as NumPy floating-point scalars.
    """
    pred_np = prediction.detach().cpu().numpy()
    truth_np = truth.detach().cpu().numpy()

    errors = pred_np - truth_np
    mse = np.square(errors).mean()
    mae = np.absolute(errors).mean()
    return mse, mae


def evaluate_ndcg(model, hr_map, num_entity):
    """Mean linear-gain and exponential-gain nDCG of ``model`` over all queries in ``hr_map``.

    Args:
        model: Scoring model handed to ``Tester``.
        hr_map: ``{h: {r: {t: w}}}`` ground-truth weights.
        num_entity: Number of candidate entities to rank.

    Returns:
        ``(mean_linear_ndcg, mean_exp_ndcg)``.
    """
    mean_linear_ndcg, mean_exp_ndcg = Tester(model, num_entity).mean_ndcg(hr_map)
    return mean_linear_ndcg, mean_exp_ndcg

In [53]:
class NDCGRankingTestDataset(TensorDataset):
    """All candidate triples ``(h, r, t)`` for one fixed ``(h, r)`` and every tail ``t``."""

    def __init__(self, h, r, num_entities):
        """
        Args:
            h: Head entity id shared by all candidates.
            r: Relation id shared by all candidates.
            num_entities: Number of candidate tails; tails run from 0 to
                ``num_entities - 1``.
        """
        self.h, self.r = h, r
        self.num_entities = num_entities
        self.length = num_entities

        # make candidate list for ranking task
        self.candidate_triples = self.get_all_candidate_triples()

        # Initialise the TensorDataset base so the inherited `tensors` attribute exists.
        super().__init__(self.candidate_triples)

    def get_all_candidate_triples(self):
        """Return a ``(num_entities, 3)`` long tensor with rows ``(h, r, t)``."""
        # candidate triples:
        # (h, r, 0), (h, r, 1), (h, r, 2) ...
        candidates = torch.zeros((self.num_entities, 3), dtype=torch.long)
        candidates[:, 0] = self.h
        candidates[:, 1] = self.r
        candidates[:, 2] = torch.arange(0, self.num_entities)
        return candidates

    def __getitem__(self, index):
        """Return the candidate row(s) selected by ``index`` as a tensor (not a tuple)."""
        return self.candidate_triples[index, :]

    def __len__(self):
        """Number of candidate triples."""
        return self.length

## Params

In [54]:
class Params():
    """Plain attribute container; configuration fields are attached after construction."""

    def __init__(self):
        """Create an empty container (no fields are set here)."""

def set_params(args):
    """Assemble the tuned hyper-parameter bundle for the 'cn15k' / 'nl27k' datasets.

    Args:
        args: object exposing ``data`` (dataset name), ``task`` and
            ``model_name``.

    Returns:
        A ``Params`` instance. Paths, ``hr_map`` (loaded via ``load_hr_map``),
        ``device``, ``early_stop`` and ``whichmodel`` are always set. Vocabulary
        sizes, tuned hyper-parameters and ``RULE_CONFIGS`` are only set when
        ``args.data`` is 'cn15k' or 'nl27k'; any other name leaves them unset.
    """
    data_name, task, model_name = args.data, args.task, args.model_name

    params = Params()
    params.data_name = data_name
    params.model_name = model_name

    # (entity vocabulary size, relation vocabulary size) per known dataset
    vocab_sizes = {'cn15k': (15000, 36), 'nl27k': (27221, 417)}
    if data_name in vocab_sizes:
        params.VOCAB_SIZE, params.REL_VOCAB_SIZE = vocab_sizes[data_name]

    params.data_dir = join('./data', data_name)
    params.model_dir = join('./trained_models/', data_name)
    params.hr_map = load_hr_map(params.data_dir)

    params.device = device

    # task 'mse' monitors validation MSE; every other task value monitors nDCG
    # (possible monitor names: 'valid_mse' or 'valid_mae' or 'ndcg')
    params.early_stop = 'valid_mse' if task == 'mse' else 'ndcg'

    params.whichmodel = 'bigumbelbox'

    # Tuned settings keyed by (dataset, early-stopping monitor). The table is
    # rebuilt on every call so each Params gets its own regularization dict.
    tuned = {
        ('cn15k', 'valid_mse'): {
            'DIM': 64,
            'NEG_PER_POS': 30,
            'EPOCH': 1000,
            'BATCH_SIZE': 4096,
            # no composition rule for CN15k
            'regularization': {'delta': 1, 'min': 1e-3, 'rel_trans': 1e-3, 'rel_scale': 1e-3,
                               'inverse': 0, 'transitive': 0.1, 'composite': 0},
            'GUMBEL_BETA': 0.01,  # gumbel box
            'LR': 1e-4,
            'NEG_RATIO': 1,
        },
        ('cn15k', 'ndcg'): {
            'DIM': 300,
            'NEG_PER_POS': 30,
            'LR': 1e-4,
            'EPOCH': 1000,
            'BATCH_SIZE': 2048,
            # no composition rule for CN15k
            'regularization': {'delta': 0.5, 'min': 0, 'rel_trans': 0, 'rel_scale': 0,
                               'inverse': 0, 'transitive': 0.1, 'composite': 0},
            'GUMBEL_BETA': 0.001,  # gumbel box
            'NEG_RATIO': 1,
        },
        ('nl27k', 'valid_mse'): {
            'DIM': 64,
            'NEG_PER_POS': 30,
            'LR': 1e-4,
            'EPOCH': 1000,
            'BATCH_SIZE': 2048,
            'regularization': {'delta': 1, 'min': 1e-3, 'rel_trans': 1e-3, 'rel_scale': 1e-3,
                               'inverse': 0, 'transitive': 0.1, 'composite': 0.1},
            'GUMBEL_BETA': 0.01,  # gumbel box
            'NEG_RATIO': 1,
        },
        ('nl27k', 'ndcg'): {
            'DIM': 150,
            'NEG_PER_POS': 30,
            'LR': 1e-4,
            'EPOCH': 1000,
            'BATCH_SIZE': 256,
            'regularization': {'delta': 0, 'min': 0, 'rel_trans': 0, 'rel_scale': 0,
                               'inverse': 0, 'transitive': 0.1, 'composite': 0.1},
            'GUMBEL_BETA': 0.0001,  # gumbel box
            'NEG_RATIO': 1,
        },
    }
    for attr_name, attr_value in tuned.get((data_name, params.early_stop), {}).items():
        setattr(params, attr_name, attr_value)

    # Logic-rule configuration per dataset.
    rule_configs = {
        'cn15k': {
            'transitive': {  # (a,r,b)^(b,r,c)=>(a,r,c)
                'use': True,
                'relations': [0, 3, 22],
            },
        },
        'nl27k': {
            'transitive': {
                'use': True,
                'relations': [272, 178, 294],
            },
            'composite': {
                'use': True,
                'relations': [(57, 35, 78)],
            },
        },
    }
    if data_name in rule_configs:
        params.RULE_CONFIGS = rule_configs[data_name]

    return params


In [55]:
def set_general_params(args):
    """Build a hyper-parameter bundle that uses one shared setting for every dataset.

    Unlike a per-dataset tuned configuration, the model / optimisation values
    below do not depend on ``args.data``; only the vocabulary sizes, the
    data / model directories and ``RULE_CONFIGS`` do.

    Args:
        args: object exposing ``data``, ``task`` and ``model_name``.

    Returns:
        A ``Params`` instance. ``VOCAB_SIZE``, ``REL_VOCAB_SIZE`` and
        ``RULE_CONFIGS`` are only set for 'cn15k' and 'nl27k'.
    """
    data_name, task, model_name = args.data, args.task, args.model_name

    params = Params()
    params.data_name, params.model_name = data_name, model_name

    # (dataset, entity vocabulary size, relation vocabulary size)
    for known_name, n_entities, n_relations in (('cn15k', 15000, 36),
                                                ('nl27k', 27221, 417)):
        if data_name == known_name:
            params.VOCAB_SIZE = n_entities
            params.REL_VOCAB_SIZE = n_relations
            break

    params.data_dir = join('./data', data_name)
    params.model_dir = join('./trained_models/', data_name)
    params.hr_map = load_hr_map(params.data_dir)
    params.device = device

    # task 'mse' monitors validation MSE; anything else monitors nDCG
    # (possible monitor names: 'valid_mse' or 'valid_mae' or 'ndcg')
    params.early_stop = 'valid_mse' if task == 'mse' else 'ndcg'

    shared_settings = {
        'whichmodel': 'bigumbelbox',
        'DIM': 100,
        'NEG_PER_POS': 10,
        'LR': 1e-4,
        'EPOCH': 1000,
        'BATCH_SIZE': 1000,
        # composite weight is 0 here for every dataset
        # (original note: no composition rule for CN15k)
        'regularization': {'delta': 1, 'min': 1e-3, 'rel_trans': 1e-3, 'rel_scale': 1e-3,
                           'inverse': 0, 'transitive': 0.1, 'composite': 0},
        'GUMBEL_BETA': 0.01,  # gumbel box
        'NEG_RATIO': 1,
    }
    for attr_name, attr_value in shared_settings.items():
        setattr(params, attr_name, attr_value)

    # define RULE_CONFIGS
    if data_name == 'nl27k':
        params.RULE_CONFIGS = {
            'transitive': {
                'use': True,
                'relations': [272, 178, 294],
            },
            'composite': {
                'use': True,
                'relations': [(57, 35, 78)],
            },
        }
    elif data_name == 'cn15k':
        params.RULE_CONFIGS = {
            'transitive': {  # (a,r,b)^(b,r,c)=>(a,r,c)
                'use': True,
                'relations': [0, 3, 22],
            },
        }

    return params


In [56]:
def kg_params(args):
    """Build the hyper-parameter bundle for the knowledge graph stored under ``KG_PATH``.

    Vocabulary sizes and directories are fixed here and do not depend on
    ``args.data``; ``args.data`` only selects which ``RULE_CONFIGS`` (if any)
    is attached.

    Args:
        args: object exposing ``data``, ``task`` and ``model_name``.

    Returns:
        A ``Params`` instance. ``hr_map`` is NOT set by this function (the
        loader call is disabled), and ``RULE_CONFIGS`` is only set when
        ``args.data`` is 'cn15k' or 'nl27k'.
    """
    data_name, task, model_name = args.data, args.task, args.model_name

    params = Params()

    fixed_settings = {
        'data_name': data_name,
        'model_name': model_name,
        'VOCAB_SIZE': 20685,
        'REL_VOCAB_SIZE': 12,
        'data_dir': join(KG_PATH, 'split'),
        'model_dir': join(KG_MODEL, 'split'),
        # hr_map intentionally left unset: load_hr_map(params.data_dir) is disabled.
        # NOTE(review): code that reads params.hr_map (the nDCG evaluation path)
        # would presumably fail with this configuration -- confirm before using
        # a task other than 'mse'.
        'device': device,
        # task 'mse' monitors validation MSE; anything else monitors nDCG
        # (possible monitor names: 'valid_mse' or 'valid_mae' or 'ndcg')
        'early_stop': 'valid_mse' if task == 'mse' else 'ndcg',
        'whichmodel': 'bigumbelbox',
        'DIM': 100,
        'NEG_PER_POS': 10,
        'LR': 1e-4,
        'EPOCH': 1000,
        'BATCH_SIZE': 1000,
        # original note: no composition rule for CN15k
        'regularization': {'delta': 1, 'min': 1e-3, 'rel_trans': 1e-3, 'rel_scale': 1e-3,
                           'inverse': 0, 'transitive': 0.1, 'composite': 0},
        'GUMBEL_BETA': 0.01,  # gumbel box
        'NEG_RATIO': 1,
    }
    for attr_name, attr_value in fixed_settings.items():
        setattr(params, attr_name, attr_value)

    # define RULE_CONFIGS
    # NOTE(review): REL_VOCAB_SIZE is 12 above, while the 'cn15k' rule lists
    # relation id 22 -- confirm these ids are meaningful for this graph.
    if data_name == 'nl27k':
        params.RULE_CONFIGS = {
            'transitive': {
                'use': True,
                'relations': [272, 178, 294],
            },
            'composite': {
                'use': True,
                'relations': [(57, 35, 78)],
            },
        }
    elif data_name == 'cn15k':
        params.RULE_CONFIGS = {
            'transitive': {  # (a,r,b)^(b,r,c)=>(a,r,c)
                'use': True,
                'relations': [0, 3, 22],
            },
        }

    return params


## Train Loop

In [57]:
class Params:
    """Empty attribute bag; callers attach configuration fields after construction.

    This repeats the ``Params`` definition from the "Params" section of this
    notebook; running this cell rebinds the name to an equivalent class.
    """

    def __init__(self):
        """Create the container with no attributes set."""

In [58]:
class TrainLoop():
    """Owns the data splits, model and optimizer for one run and drives training.

    Hyper-parameters come from ``kg_params(args)``. The best model is
    checkpointed to ``<params.model_dir>/<params.whichmodel>.pt`` and training
    stops early once the monitored metric (``params.early_stop``) has not
    improved for a fixed number of epochs.
    """

    def __init__(self, args):
        """Load the train/val/test splits and build the model and Adam optimizer.

        Args:
            args: object exposing ``data``, ``task`` and ``model_name``
                (forwarded to ``kg_params``).
        """
        # The notebook also carried ``set_params(args)`` and '*.tsv' split names
        # as a commented-out alternative to the configuration used here.
        params = kg_params(args)
        self.params = params

        self.train_dataset = UncertainTripleDataset(params.data_dir, 'train.csv')
        self.dev_dataset = UncertainTripleDataset(params.data_dir, 'val.csv')
        self.test_dataset = UncertainTripleDataset(params.data_dir, 'test.csv')

        print(self.params.whichmodel)
        print(self.params.early_stop)

        os.makedirs(self.params.model_dir, exist_ok=True)

        self.model = get_new_model(self.params)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.params.LR)

        # Best linear nDCG seen so far; read and updated by eval_step.
        self.last_best_ndcg = 0

    def train_step(self, best_metric):
        """Run one epoch over the training set, then score the test and dev splits.

        Args:
            best_metric: dict with keys 'test_mse', 'test_mae', 'valid_mse' and
                'valid_mae' holding the lowest values seen so far. It is
                updated in place.

        Returns:
            ``(train_loss, best_metric, current_metric)`` where ``train_loss``
            is the sum of per-batch losses and ``current_metric`` holds this
            epoch's four metric values.
        """
        train_loss = 0
        batch_size = self.params.BATCH_SIZE
        data = TensorDataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        current_metric = {}
        # for negative sampling
        self.model.true_head, self.model.true_tail = self.train_dataset.true_head, self.train_dataset.true_tail
        for ids, cls in data:
            self.model.train()
            ids, cls = ids.to(self.params.device), cls.to(self.params.device)
            # Clear gradients left over from the previous batch; without this
            # backward() keeps adding to them and every step uses a running sum.
            self.optimizer.zero_grad()
            loss, pos_loss, neg_loss = kg_mse_loss(self.model, ids, cls)
            train_loss += loss.item()
            loss.backward()
            self.optimizer.step()

        # test at the end of epoch
        test_MSE, test_MAE, _, _ = self.test(self.test_dataset, ndcg_also=False)
        # validation: positive and sampled-negative errors, weighted by NEG_RATIO
        valid_pos_mse, valid_mae, valid_neg_mse, _ = self.test(self.dev_dataset, neg_mse_also=True)
        valid_mse = (valid_pos_mse + self.params.NEG_RATIO * valid_neg_mse) / (1 + self.params.NEG_RATIO)

        current_metric['test_mse'] = test_MSE
        current_metric['test_mae'] = test_MAE
        current_metric['valid_mse'] = valid_mse
        current_metric['valid_mae'] = valid_mae

        # Keep the running minimum of every tracked metric.
        for name, value in current_metric.items():
            if value < best_metric[name]:
                best_metric[name] = value

        return train_loss, best_metric, current_metric

    def eval_step(self, epoch, best_metric, last_best_metric, last_best_epoch):
        """Apply the early-stopping rule for this epoch and checkpoint on improvement.

        Args:
            epoch: zero-based epoch index.
            best_metric: running-best metrics as maintained by ``train_step``.
            last_best_metric: snapshot of ``best_metric`` at the last improvement.
            last_best_epoch: epoch index of the last improvement.

        Returns:
            ``(last_best_metric, past_patience, last_best_epoch)``;
            ``past_patience`` is True when training should stop.

        Behaviour by ``params.early_stop``:
            * 'ndcg': evaluated only every 10th epoch; patience is 200 epochs.
              The best linear nDCG is kept on ``self.last_best_ndcg``.
            * 'valid_mse' / 'valid_mae': patience is 50 epochs.
        """
        past_patience = False
        monitor = self.params.early_stop
        checkpoint_path = join(self.params.model_dir, f'{self.params.whichmodel}.pt')

        if monitor == 'ndcg' and epoch % 10 == 0:
            print('####NDCG####')
            linear_ndcg, _exp_ndcg = evaluate_ndcg(self.model, self.params.hr_map, self.params.VOCAB_SIZE)
            # The best nDCG lives on the instance: it must persist between
            # calls, and as a plain local it was read before assignment.
            if linear_ndcg > getattr(self, 'last_best_ndcg', 0):
                self.last_best_ndcg = linear_ndcg
                last_best_epoch = epoch
                torch.save(self.model, checkpoint_path)
            elif epoch >= 1 and epoch - last_best_epoch >= 200:
                past_patience = True  # early stop
        elif monitor in ('valid_mse', 'valid_mae'):
            # best_metric is a running minimum, so ">=" means "no improvement".
            if epoch >= 1 and best_metric[monitor] >= last_best_metric[monitor]:
                if monitor == 'valid_mae':
                    print('epoch', epoch, 'last_best_epoch', last_best_epoch)
                if epoch - last_best_epoch >= 50:  # patience
                    past_patience = True
            else:
                last_best_metric = best_metric.copy()
                last_best_epoch = epoch
                torch.save(self.model, checkpoint_path)

        return last_best_metric, past_patience, last_best_epoch

    def train(self):
        """Train for up to ``params.EPOCH`` epochs, stopping early when
        ``eval_step`` reports that patience is exhausted.

        Progress (current ``c_*`` and best ``b_*`` metrics) is shown on a tqdm
        bar, which is closed when the loop ends for any reason.
        """
        best_metric = {
            'test_mse': 1,
            'valid_mse': 1,
            'ndcg': 0,
            'test_mae': 100,
            'valid_mae': 100
        }

        last_best_metric = best_metric.copy()
        self.last_best_ndcg = 0
        last_best_epoch = 0  # for early stopping
        pbar = tqdm(total=self.params.EPOCH)
        try:
            for epoch in range(self.params.EPOCH):
                loss, best_metric, current_metric = self.train_step(best_metric)
                last_best_metric, past_patience, last_best_epoch = self.eval_step(
                    epoch, best_metric, last_best_metric, last_best_epoch)
                if past_patience:
                    break
                current_part = '| '.join(f'c_{k}:{float(v):.4f}' for k, v in current_metric.items())
                best_part = '| '.join(f'b_{k}:{float(v):.4f}' for k, v in best_metric.items())
                pbar.set_description(
                    'E {}| loss {:.2f}| '.format(epoch + 1, loss) + current_part + '| ' + best_part)
                pbar.update()
        finally:
            pbar.close()

    def test(self, test_data, threshold=None, neg_mse_also=False, ndcg_also=False):
        """Score the model on ``test_data`` in a single full-size batch, without gradients.

        The model output is passed through ``torch.exp`` before comparison with
        the labels (presumably the model emits log-scale scores -- confirm
        against the model's forward).

        Args:
            test_data: dataset exposing ``length``; loaded as one batch.
            threshold: accepted for backward compatibility; not used.
            neg_mse_also: also score one random negative per positive
                (used for validation).
            ndcg_also: also compute nDCG via ``evaluate_ndcg``.

        Returns:
            ``(mse, mae, neg_mse, ndcg)``. ``ndcg`` is None unless ``ndcg_also``
            (then it is the pair returned by ``evaluate_ndcg``). Without
            ``neg_mse_also``, ``neg_mse`` is None and ``mae`` covers positives
            only; with it, ``mae`` is the NEG_RATIO-weighted mean of the
            positive and negative MAE. Returns None if the loader yields no
            batch.
        """
        with torch.no_grad():
            data = TensorDataLoader(test_data, batch_size=test_data.length)
            for ids, cls in data:
                ids, cls = ids.to(self.params.device), cls.to(self.params.device)
                prediction, truth = self.model(ids, cls)
                mse, mae = evaluate_mse(torch.exp(prediction), truth)

                ndcg = None
                if ndcg_also:
                    ndcg = evaluate_ndcg(self.model, self.params.hr_map, self.params.VOCAB_SIZE)

                if not neg_mse_also:
                    return mse, mae, None, ndcg

                # test for negative samples
                negative_samples, neg_probs = self.model.random_negative_sampling(ids, cls, neg_per_pos=1)
                neg_prediction, _ = self.model(negative_samples, neg_probs)
                neg_mse, neg_mae = evaluate_mse(torch.exp(neg_prediction), neg_probs)

                # Same weighting as the validation MSE in train_step
                # (unchanged result when NEG_RATIO == 1).
                neg_ratio = self.params.NEG_RATIO
                combined_mae = (mae + neg_ratio * neg_mae) / (1 + neg_ratio)  # for validation

                return mse, combined_mae, neg_mse, ndcg

In [59]:
def parse_args(args=None):
    """Return the hard-coded run arguments used by this notebook.

    The ``args`` parameter is accepted but ignored; a fresh ``Params`` is
    always returned with:
        * ``data``       -- dataset name: 'cn15k' or 'nl27k'
        * ``task``       -- 'mse' or 'ndcg'
        * ``model_name`` -- 'BiGumbelBox' or 'QuatE'
    """
    run_args = Params()
    defaults = (
        ('data', 'cn15k'),
        ('task', 'mse'),
        ('model_name', 'BiGumbelBox'),
    )
    for field, value in defaults:
        setattr(run_args, field, value)
    return run_args

In [60]:
# Build the training loop from the hard-coded notebook arguments returned by parse_args().
tloop = TrainLoop(parse_args())

bigumbelbox
valid_mse


In [None]:
# Train for up to params.EPOCH epochs; the loop breaks earlier when eval_step reports patience exhausted.
tloop.train()

  0%|          | 0/1000 [00:00<?, ?it/s]