# Config

In [1]:
import yaml
import logging

_config = None
class ConfigDict(dict):
    __getattr__ = dict.__getitem__

def config(config_path='./config/config_wn18rr.yaml'):
    """
    default: config("config_wn18rr.yaml")
    """
    def _make_config_dict(obj):
        if isinstance(obj, dict):
            return ConfigDict({k: _make_config_dict(v) for k, v in obj.items()})
        elif isinstance(obj, list):
            return [_make_config_dict(x) for x in obj]
        else:
            return obj
    
    global _config
    if _config is None:
        with open(config_path) as f:
            _config = _make_config_dict(yaml.load(f, Loader=yaml.FullLoader))
    return _config

def overwrite_config_with_args(args=[], sep='.'):
    """
    Manually pass parameters. E.g. overwrite_config_with_args(["--pretrain_config=TransD"])
    """
    def path_set(path, val, sep='.', auto_convert=False):
        steps = path.split(sep)
        obj = _config
        for step in steps[:-1]:
            obj = obj[step]
        old_val = obj[steps[-1]]
        
        if not auto_convert:
            obj[steps[-1]] = val
        elif isinstance(old_val, bool):
            obj[steps[-1]] = val.lower() == 'true'
        elif isinstance(old_val, float):
            obj[steps[-1]] = float(val)
        elif isinstance(old_val, int):
            try:
                obj[steps[-1]] = int(val)
            except ValueError:
                obj[steps[-1]] = float(val)
        else:
            obj[steps[-1]] = val
    
    for arg in args:
        if arg.startswith('--') and '=' in arg:
            path, val = arg[2:].split('=')
            if path != 'config':
                path_set(path, val, sep, auto_convert=True)

def dump_config():
    def _dump_config(obj, prefix):
        if isinstance(obj, dict):
            for k, v in obj.items():
                _dump_config(v, prefix + (k,))
        elif isinstance(obj, list):
            for i, v in enumerate(obj):
                _dump_config(v, prefix + (str(i),))
        else:
            logging.debug('%s=%s', '.'.join(prefix), repr(obj))
    return _dump_config(_config, tuple())

## Select GPU

In [2]:
import subprocess
import logging
import torch

def select_gpu():
    if not torch.cuda.is_available():
        logging.warning("No GPU available. Running on CPU.")
        return None

    try:
        nvidia_info = subprocess.run(
            ['nvidia-smi'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=True
        )
    except FileNotFoundError:
        logging.warning("nvidia-smi not found. Running on CPU.")
        return None

    gpu_info = False
    gpu_info_line = 0
    proc_info = False
    gpu_mem = []
    gpu_occupied = set()

    for line in nvidia_info.stdout.split(b'\n'):
        line = line.decode().strip()
        if gpu_info:
            gpu_info_line += 1
            if line == '':
                gpu_info = False
                continue
            if gpu_info_line % 3 == 2:
                mem_info = line.split('|')[2]
                used_mem_mb = int(mem_info.strip().split()[0][:-3])
                gpu_mem.append(used_mem_mb)
        if proc_info:
            if line == '|  No running processes found                                                 |':
                continue
            if line == '+-----------------------------------------------------------------------------+':
                proc_info = False
                continue
            proc_gpu = int(line.split()[1])
            gpu_occupied.add(proc_gpu)
        if line == '|===============================+======================+======================|':
            gpu_info = True
        if line == '|=============================================================================|':
            proc_info = True

    if not gpu_mem:
        logging.warning("Could not parse nvidia-smi output. Defaulting to GPU 0.")
        return 0

    for i in range(len(gpu_mem)):
        if i not in gpu_occupied:
            logging.info('Automatically selected GPU %d because it is vacant.', i)
            return i
    for i in range(len(gpu_mem)):
        if gpu_mem[i] == min(gpu_mem):
            logging.info('All GPUs are occupied. Automatically selected GPU %d because it has the most free memory.', i)
            return i

## Logger Init

In [3]:
import logging
import os
import datetime

def logger_init():
    root_logger = logging.getLogger()
    root_logger.handlers.clear()    # Xoá các handler mặc định của Jupyter
    root_logger.setLevel(logging.DEBUG) # Hiện cả DEBUG, INFO...

    # Hiện log ra output của cell
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(module)15s %(asctime)s %(message)s', datefmt='%H:%M:%S'))
    root_logger.addHandler(console_handler)

    if (config().log.to_file):
        log_dir = './output/' + config().task.dir + '/logs'
        os.makedirs(log_dir, exist_ok=True)
        log_filename = os.path.join(
            log_dir,
            _config.log.prefix + datetime.datetime.now().strftime("%m%d%H%M%S") + ".log"
        )
        file_handler = logging.FileHandler(log_filename)
        file_handler.setFormatter(logging.Formatter('%(module)15s %(asctime)s %(message)s', datefmt='%H:%M:%S'))
        root_logger.addHandler(file_handler)

    if config().log.dump_config:
        dump_config()

# Datasets

## Data Utils

In [4]:
from random import randint
from collections import defaultdict
import torch

def sparse_heads_tails(n_entity, train_data, valid_data=None, test_data=None):
    if train_data:
        train_head, train_relation, train_tail = train_data
    else:
        train_head = train_relation = train_tail = []
    if valid_data:
        valid_head, valid_relation, valid_tail = valid_data
    else:
        valid_head = valid_relation = valid_tail = []
        
    if test_data:
        test_head, test_relation, test_tail = test_data
    else:
        test_head = test_relation = test_tail = []
        
    all_head = train_head + valid_head + test_head
    all_relation = train_relation + valid_relation + test_relation
    all_tail = train_tail + valid_tail + test_tail
    
    heads = defaultdict(lambda: set())
    tails = defaultdict(lambda: set())
    for h, r, t in zip(all_head, all_relation, all_tail):
        heads[(t, r)].add(h)
        tails[(h, r)].add(t)
    
    heads_sparse = tails_sparse = {}
    for k in heads.keys():
        heads_sparse[k] = torch.sparse.FloatTensor(torch.LongTensor([list(heads[k])]), torch.ones(len(heads[k])), torch.Size([n_entity]))
    for k in tails.keys():
        tails_sparse[k] = torch.sparse.FloatTensor(torch.LongTensor([list(tails[k])]), torch.ones(len(tails[k])), torch.Size([n_entity]))
    return heads_sparse, tails_sparse

def inplace_shuffle(*lists):
    idx = []
    for i in range(len(lists[0])):
        idx.append(randint(0, i+1))
    for ls in lists:
        for i, item in enumerate(ls):
            j = idx[i]
            ls[i], ls[j] = ls[j], ls[i]

def batch_by_num(n_batch, *lists, n_sample=None):
    if n_sample is None:
        n_sample = len(lists[0])
    for i in range(n_batch):
        head = int(n_sample * i / n_batch)
        tail = int(n_sample * (i + 1) / n_batch)
        ret = [ls[head:tail] for ls in lists]
        if len(ret) > 1:
            yield ret
        else:
            yield ret[0]

def batch_by_size(batch_size, *lists, n_sample=None):
    if n_sample is None:
        n_sample = len(lists[0])
    head = 0
    while head < n_sample:
        tail = min(n_sample, head + batch_size)
        ret = [ls[head:tail] for ls in lists]
        head += batch_size
        if len(ret) > 1:
            yield ret
        else:
            yield ret[0]

## Corrupter

In [5]:
import torch
from collections import defaultdict
import numpy as np
from numpy.random import choice

def get_bern_prob(data, n_relation):
    head, relation, tail = data
    edges = defaultdict(lambda: defaultdict(lambda: set()))
    rev_edges = defaultdict(lambda: defaultdict(lambda: set()))
    for s, r, t in zip(head, relation, tail):
        edges[r][s].add(t)
        rev_edges[r][t].add(s)
    bern_prob = torch.zeros(n_relation)
    for r in edges.keys():
        tph = sum(len(tails) for tails in edges[r].values()) / len(edges[r])
        htp = sum(len(heads) for heads in rev_edges[r].values()) / len(rev_edges[r])
        bern_prob[r] = tph / (tph + htp)
    return bern_prob

class BernCorrupter(object):
    def __init__(self, data, n_entity, n_relation):
        self.bern_prob = get_bern_prob(data, n_relation)
        self.n_entity = n_entity

    def corrupt(self, head, relation, tail):
        prob = self.bern_prob[relation]
        selection = torch.bernoulli(prob).numpy().astype('int64')
        entity_random = choice(self.n_entity, len(head))
        head_out = (1 - selection) * head.numpy() + selection * entity_random
        tail_out = selection * tail.numpy() + (1 - selection) * entity_random
        return torch.from_numpy(head_out), torch.from_numpy(tail_out)

class BernCorrupterMulti(object):
    def __init__(self, data, n_entity, n_relation, n_sample):
        self.bern_prob = get_bern_prob(data, n_relation)
        self.n_entity = n_entity
        self.n_sample = n_sample

    def corrupt(self, head, relation, tail, keep_truth=True):
        n = len(head)
        prob = self.bern_prob[relation]
        selection = torch.bernoulli(prob).numpy().astype('bool')
        head_out = np.tile(head.numpy(), (self.n_sample, 1)).transpose()
        tail_out = np.tile(tail.numpy(), (self.n_sample, 1)).transpose()
        relation_out = relation.unsqueeze(1).expand(n, self.n_sample)
        if keep_truth:
            entity_random = choice(self.n_entity, (n, self.n_sample - 1))
            head_out[selection, 1:] = entity_random[selection]
            tail_out[~selection, 1:] = entity_random[~selection]
        else:
            entity_random = choice(self.n_entity, (n, self.n_sample))
            head_out[selection, :] = entity_random[selection]
            tail_out[~selection, :] = entity_random[~selection]
        return torch.from_numpy(head_out), relation_out, torch.from_numpy(tail_out)

# Data Loader

In [6]:
from itertools import count
from collections import namedtuple
import logging

KBIndex = namedtuple('KBIndex', ['entity_list', 'relation_list', 'entity_id', 'relation_id'])

def index_entity_relation(*filenames):
    entity_set = set()
    relation_set = set()
    for filename in filenames:
        with open(filename) as f:
            for ln in f:
                s, r, t = ln.strip().split('\t')[:3]
                entity_set.add(s)
                entity_set.add(t)
                relation_set.add(r)
    entity_list = sorted(list(entity_set))
    relation_list = sorted(list(relation_set))
    entity_id = dict(zip(entity_list, count()))
    relation_id = dict(zip(relation_list, count()))
    return KBIndex(entity_list, relation_list, entity_id, relation_id)

def graph_size(kb_index):
    return len(kb_index.entity_id), len(kb_index.relation_id)

def read_data(filename, kb_index, with_label=False):
    heads, relations, tails = [], [], []
    labels = []
    skipped_count = 0
    
    with open(filename) as f:
        for ln in f:
            parts = ln.strip().split('\t')
            h, r, t = parts[:3]
            
            # Check if entity and relation exist in kb_index
            if h not in kb_index.entity_id:
                skipped_count += 1
                continue
            if r not in kb_index.relation_id:
                skipped_count += 1
                continue
            if t not in kb_index.entity_id:
                skipped_count += 1
                continue
            
            # All entities and relations are valid, add to lists
            heads.append(kb_index.entity_id[h])
            relations.append(kb_index.relation_id[r])
            tails.append(kb_index.entity_id[t])

            if with_label and len(parts) > 3:
                labels.append(int(parts[3]))
    
    if skipped_count > 0:
        logging.warning(f"Skipped {skipped_count} triples with entities/relations not in kb_index from {filename}")
    
    if with_label:
        return heads, relations, tails, labels
    else:
        return heads, relations, tails

## Metrics for Evaluation


In [18]:
def mr_mrr_hitsk(scores, target, k_list=[1, 3, 10]):
    _, sorted_idx = torch.sort(scores)
    find_target = sorted_idx == target
    target_rank = torch.nonzero(find_target)[0, 0] + 1
    target_score = scores[target].item()  # Get the score of the target entity
    return target_rank, 1 / target_rank, [int(target_rank <= k) for k in k_list], target_score

def acc_pre_rec_f1(predictions, true_labels):
    try:
        y_pred = np.asarray(predictions)
        y_true = np.asarray(true_labels)
    except Exception as e:
        print(f"Error converting inputs to numpy arrays: {e}")
        return None, None, None, None

    if y_pred.shape != y_true.shape:
        raise ValueError("Predictions and true labels must have the same shape.")

    # TP: y_pred == 1 AND y_true == 1
    TP = np.sum((y_pred == 1) & (y_true == 1))
    # TN: y_pred == 0 AND y_true == 0
    TN = np.sum((y_pred == 0) & (y_true == 0))
    # FP: y_pred == 1 AND y_true == 0
    FP = np.sum((y_pred == 1) & (y_true == 0))
    # FN: y_pred == 0 AND y_true == 1
    FN = np.sum((y_pred == 0) & (y_true == 1))

    # (TP + TN) / (TP + TN + FP + FN)
    total_samples = TP + TN + FP + FN
    accuracy = (TP + TN) / total_samples if total_samples > 0 else 0.0

    # TP / (TP + FP)
    precision_denominator = TP + FP
    precision = TP / precision_denominator if precision_denominator > 0 else 0.0

    # TP / (TP + FN)
    recall_denominator = TP + FN
    recall = TP / recall_denominator if recall_denominator > 0 else 0.0

    # 2 * (Precision * Recall) / (Precision + Recall)
    f1_denominator = precision + recall
    f1_score = 2 * (precision * recall) / f1_denominator if f1_denominator > 0 else 0.0
    return accuracy, precision, recall, f1_score


# Base Models

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torch.optim import Adam
import numpy as np
import logging

from config import config
from datasets import batch_by_size


class BaseModule(nn.Module):
    def __init__(self):
        super().__init__()

    def init_weight(self):
        pass

    def forward(self, head, relation, tail):
        pass

    def dist(self, head, relation, tail):
        pass

    def score(self, head, relation, tail):
        pass

    def prob_logit(self, head, relation, tail):
        pass

    def constraint(self):
        pass

    def prob(self, head, relation, tail):
        return nnf.softmax(self.prob_logit(head, relation, tail), dim=-1)

    def pair_loss(self, head, relation, tail, head_bad, tail_bad):
        d_good = self.dist(head, relation, tail)
        d_bad = self.dist(head_bad, relation, tail_bad)
        return nnf.relu(self.margin + d_good - d_bad)

    def softmax_loss(self, head, relation, tail, truth):
        probs = self.prob(head, relation, tail)
        n = probs.size(0)
        truth_probs = torch.log(probs[torch.arange(0, n).type(torch.LongTensor).cuda(), truth] + 1e-30)
        return -truth_probs

class BaseModel(object):
    def __init__(self):
        self.mdl = None # type: BaseModule
        self.weight_decay = 0

    def train(self, train_data, corrupter, tester):
        pass

    def save(self, filename):
        torch.save(self.mdl.state_dict(), filename)

    def load(self, filename):
        state_dict = torch.load(filename, map_location=lambda storage, location: storage.cuda(), weights_only=True)
        self.mdl.load_state_dict(state_dict)

    def _ensure_optimizer(self):
        if not hasattr(self, 'opt'):
            self.opt = Adam(self.mdl.parameters(), weight_decay=self.weight_decay)

    def evaluate(self, test_data, n_entity, heads, tails, filt=True):
        """
        Evaluate the model on Link Prediction task.
        """
        mr_total = mrr_total = 0
        k_list = [1, 3, 10]
        hits_total = [0] * len(k_list)

        count = 0
        with torch.no_grad():  # Thay volatile=True
            for batch_head, batch_relation, batch_tail in batch_by_size(config().test_batch_size, *test_data):
                batch_size = batch_head.size(0)

                all_var = torch.arange(0, n_entity).unsqueeze(0).expand(batch_size, n_entity).long().cuda()
                head_var = batch_head.unsqueeze(1).expand(batch_size, n_entity).cuda()
                relation_var = batch_relation.unsqueeze(1).expand(batch_size, n_entity).cuda()
                tail_var = batch_tail.unsqueeze(1).expand(batch_size, n_entity).cuda()

                batch_head_scores = self.mdl.score(all_var, relation_var, tail_var)
                batch_tail_scores = self.mdl.score(head_var, relation_var, all_var)
            
                # Convert to numpy if needed
                batch_head_scores = batch_head_scores.detach()
                batch_tail_scores = batch_tail_scores.detach()

                for head, relation, tail, head_scores, tail_scores in zip(batch_head, batch_relation, batch_tail, batch_head_scores, batch_tail_scores):
                    head_id, relation_id, tail_id = head.item(), relation.item(), tail.item()
                    if filt:
                        key_head = (tail_id, relation_id)
                        if key_head in heads and heads[key_head]._nnz() > 1:
                            tmp = head_scores[head_id].item()
                            head_scores += heads[key_head].cuda() * 1e30
                            head_scores[head_id] = tmp
                            
                        key_tail = (head_id, relation_id)
                        if key_tail in tails and tails[key_tail]._nnz() > 1:
                            tmp = tail_scores[tail_id].item()
                            tail_scores += tails[key_tail].cuda() * 1e30
                            tail_scores[tail_id] = tmp

                    head_mr, head_mrr, head_hits, head_target_score = mr_mrr_hitsk(scores=head_scores, target=head_id, k_list=k_list)
                    tail_mr, tail_mrr, tail_hits, tail_target_score = mr_mrr_hitsk(scores=tail_scores, target=tail_id, k_list=k_list)                    
                    
                    mr_total += (head_mr + tail_mr)
                    mrr_total += (head_mrr + tail_mrr)
                    hits_total = [(hits_total[i] + head_hits[i] + tail_hits[i]) for i in range(len(k_list))]
                    count += 2
                    
        mr_rate = mr_total / count
        mrr_rate = mrr_total / count
        hits_rate = [hit_total / count for hit_total in hits_total]
        
        metrics_str = f"MR = {mr_rate}\nMRR = {mrr_rate}\n"
        for i in range(len(k_list)):
            metrics_str += f"Hit@{k_list[i]} = {hits_rate[i]}\n"

        logging.info(metrics_str)
        print(metrics_str)
        return mrr_rate

## Discriminator Models

### TransE

In [9]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable
import logging
import os

class TransEModule(BaseModule):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        self.p = config.p
        self.margin = config.margin
        self.temp = config.get('temp', 1)
        self.relation_embed = nn.Embedding(n_relation, config.dim)
        self.entity_embed = nn.Embedding(n_entity, config.dim)
        self.init_weight()

    def init_weight(self):
        for param in self.parameters():
            param.data.normal_(1 / param.size(1) ** 0.5)
            param.data.renorm_(2, 0, 1)

    def forward(self, head, relation, tail):
        return torch.norm(self.entity_embed(tail) - self.entity_embed(head) - self.relation_embed(relation) + 1e-30, p=self.p, dim=-1)

    def dist(self, head, relation, tail):
        return self.forward(head, relation, tail)

    def score(self, head, relation, tail):
        return self.forward(head, relation, tail)

    def prob_logit(self, head, relation, tail):
        return -self.forward(head, relation ,tail) / self.temp

    def constraint(self):
        self.entity_embed.weight.data.renorm_(2, 0, 1)
        self.relation_embed.weight.data.renorm_(2, 0, 1)

class TransE(BaseModel):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        self.mdl = TransEModule(n_entity, n_relation, config)
        self.mdl.cuda()
        self.config = config

    def train(self, train_data, corrupter, tester):
        head, relation, tail = train_data
        n_train = len(head)
        n_epoch = self.config.n_epoch
        n_batch = self.config.n_batch
        optimizer = Adam(self.mdl.parameters())

        best_perf = 0
        for epoch in range(n_epoch):
            rand_idx = torch.randperm(n_train)
            head = head[rand_idx]
            relation = relation[rand_idx]
            tail = tail[rand_idx]

            head_corrupted, tail_corrupted = corrupter.corrupt(head, relation, tail)
            head_cuda = head.cuda()
            relation_cuda = relation.cuda()
            tail_cuda = tail.cuda()
            head_corrupted = head_corrupted.cuda()
            tail_corrupted = tail_corrupted.cuda()

            epoch_loss = 0
            for h0, r, t0, h1, t1 in batch_by_num(n_batch, head_cuda, relation_cuda, tail_cuda, head_corrupted, tail_corrupted, n_sample=n_train):
                self.mdl.zero_grad()

                loss = torch.sum(self.mdl.pair_loss(Variable(h0), Variable(r), Variable(t0), Variable(h1), Variable(t1)))
                loss.backward()

                optimizer.step()
                self.mdl.constraint()
                epoch_loss += loss.item()

            logging.info('Epoch %d/%d, Loss=%f', epoch + 1, n_epoch, epoch_loss / n_train)
            if ((epoch + 1) % self.config.epoch_per_test == 0):
                test_perf = tester()
                if (test_perf > best_perf):
                    task_dir = './output/' + config().task.dir + '/models'
                    os.makedirs(task_dir, exist_ok=True)
                    self.save(os.path.join(task_dir, self.config.model_file))
                    
                    best_perf = test_perf
        return best_perf

### TransD

In [10]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable
import logging
import os
import numpy as np

class TransDModule(BaseModule):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        self.margin = config.margin
        self.p = config.p
        self.temp = config.get('temp', 1)
        self.relation_embed = nn.Embedding(n_relation, config.dim)
        self.entity_embed = nn.Embedding(n_entity, config.dim)
        self.proj_relation_embed = nn.Embedding(n_relation, config.dim)
        self.proj_entity_embed = nn.Embedding(n_entity, config.dim)
        self.init_weight()

    def init_weight(self):
        for param in self.parameters():
            param.data.normal_(1 / param.size(1) ** 0.5)
            param.data.renorm_(2, 0, 1)

    def forward(self, head, relation, tail):
        head_proj = self.entity_embed(head) +\
                   torch.sum(self.proj_entity_embed(head) * self.entity_embed(head), dim=-1, keepdim=True) * self.proj_relation_embed(relation)
        tail_proj = self.entity_embed(tail) +\
                   torch.sum(self.proj_entity_embed(tail) * self.entity_embed(tail), dim=-1, keepdim=True) * self.proj_relation_embed(relation)
        return torch.norm(tail_proj - self.relation_embed(relation) - head_proj + 1e-30, p=self.p, dim=-1)

    def dist(self, head, relation, tail):
        return self.forward(head, relation, tail)

    def score(self, head, relation, tail):
        return self.forward(head, relation, tail)

    def prob_logit(self, head, relation, tail):
        return -self.forward(head, relation ,tail) / self.temp

    def constraint(self):
        for param in self.parameters():
            param.data.renorm_(2, 0, 1)

class TransD(BaseModel):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        self.mdl = TransDModule(n_entity, n_relation, config)
        self.mdl.cuda()
        self.config = config

    def load_vec(self, path):
        entity_mat = np.loadtxt(os.path.join(path, 'entity2vec.vec'))
        self.mdl.entity_embed.weight.data.copy_(torch.from_numpy(entity_mat))

        relation_mat = np.loadtxt(os.path.join(path, 'relation2vec.vec'))
        n_relation = relation_mat.shape[0]
        self.mdl.relation_embed.weight.data.copy_(torch.from_numpy(relation_mat))

        a_mat = np.loadtxt(os.path.join(path, 'A.vec'))
        self.mdl.proj_relation_embed.weight.data.copy_(torch.from_numpy(a_mat[:n_relation, :]))
        self.mdl.proj_entity_embed.weight.data.copy_(torch.from_numpy(a_mat[n_relation:, :]))
        self.mdl.cuda()

    def train(self, train_data, corrupter, tester):
        head, relation, tail = train_data
        n_train = len(head)
        n_epoch = self.config.n_epoch
        n_batch = self.config.n_batch
        optimizer = Adam(self.mdl.parameters())

        best_perf = 0
        for epoch in range(n_epoch):
            rand_idx = torch.randperm(n_train)
            head = head[rand_idx]
            relation = relation[rand_idx]
            tail = tail[rand_idx]

            head_corrupted, tail_corrupted = corrupter.corrupt(head, relation, tail)
            head_cuda = head.cuda()
            relation_cuda = relation.cuda()
            tail_cuda = tail.cuda()
            head_corrupted = head_corrupted.cuda()
            tail_corrupted = tail_corrupted.cuda()

            epoch_loss = 0
            for h0, r, t0, h1, t1 in batch_by_num(n_batch, head_cuda, relation_cuda, tail_cuda, head_corrupted, tail_corrupted, n_sample=n_train):
                self.mdl.zero_grad()

                loss = torch.sum(self.mdl.pair_loss(Variable(h0), Variable(r), Variable(t0), Variable(h1), Variable(t1)))
                loss.backward()

                optimizer.step()
                self.mdl.constraint()
                epoch_loss += loss.item()

            logging.info('Epoch %d/%d, Loss=%f', epoch + 1, n_epoch, epoch_loss / n_train)
            if ((epoch + 1) % self.config.epoch_per_test == 0):
                test_perf = tester()
                if (test_perf > best_perf):
                    task_dir = './output/' + config().task.dir + '/models'
                    os.makedirs(task_dir, exist_ok=True)
                    self.save(os.path.join(task_dir, self.config.model_file))
                    
                    best_perf = test_perf
        return best_perf

## Generator Models

### DistMult

In [11]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable
import logging
import os

class DistMultModule(BaseModule):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        sigma = 0.2
        self.relation_embed = nn.Embedding(n_relation, config.dim)
        self.relation_embed.weight.data.div_((config.dim / sigma ** 2) ** (1 / 6))
        self.entity_embed = nn.Embedding(n_entity, config.dim)
        self.entity_embed.weight.data.div_((config.dim / sigma ** 2) ** (1 / 6))

    def forward(self, head, relation, tail):
        return torch.sum(self.entity_embed(tail) * self.entity_embed(head) * self.relation_embed(relation), dim=-1)

    def dist(self, head, relation, tail):
        return -self.forward(head, relation, tail)
    
    def score(self, head, relation, tail):
        return -self.forward(head, relation, tail)

    def prob_logit(self, head, relation, tail):
        return self.forward(head, relation, tail)

class DistMult(BaseModel):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        self.mdl = DistMultModule(n_entity, n_relation, config)
        self.mdl.cuda()
        self.config = config
        self.weight_decay = config.lam / config.n_batch

    def train(self, train_data, corrupter, tester):
        head, relation, tail = train_data
        n_train = len(head)
        n_epoch = self.config.n_epoch
        n_batch = self.config.n_batch
        optimizer = Adam(self.mdl.parameters(), weight_decay=self.weight_decay)
    
        best_perf = 0
        for epoch in range(n_epoch):
            epoch_loss = 0
            if (epoch % self.config.sample_freq == 0):
                rand_idx = torch.randperm(n_train)
                head = head[rand_idx]
                relation = relation[rand_idx]
                tail = tail[rand_idx]

                head_corrupted, relation_corrupted, tail_corrupted = corrupter.corrupt(head, relation, tail)
                head_corrupted = head_corrupted.cuda()
                relation_corrupted = relation_corrupted.cuda()
                tail_corrupted = tail_corrupted.cuda()

            for hs, rs, ts in batch_by_num(n_batch, head_corrupted, relation_corrupted, tail_corrupted, n_sample=n_train):
                self.mdl.zero_grad()
                label = torch.zeros(len(hs)).type(torch.LongTensor).cuda()

                loss = torch.sum(self.mdl.softmax_loss(Variable(hs), Variable(rs), Variable(ts), label))
                loss.backward()

                optimizer.step()
                epoch_loss += loss.item()

            logging.info('Epoch %d/%d, Loss=%f', epoch + 1, n_epoch, epoch_loss / n_train)
            if ((epoch + 1) % self.config.epoch_per_test == 0):
                test_perf = tester()
                if (test_perf > best_perf):
                    task_dir = './output/' + config().task.dir + '/models'
                    os.makedirs(task_dir, exist_ok=True)
                    self.save(os.path.join(task_dir, self.config.model_file))

                    best_perf = test_perf
        return best_perf

### ComplEx

In [12]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.autograd import Variable
import logging
import os

class ComplExModule(BaseModule):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        self.sigma = 0.2
        self.relation_re_embed = nn.Embedding(n_relation, config.dim)
        self.relation_im_embed = nn.Embedding(n_relation, config.dim)
        self.entity_re_embed = nn.Embedding(n_entity, config.dim)
        self.entity_im_embed = nn.Embedding(n_entity, config.dim)
        self.init_weight()

    def init_weight(self):
        for param in self.parameters():
            param.data.div_((config.dim / self.sigma ** 2) ** (1 / 6))

    def forward(self, head, relation, tail):
        return torch.sum(self.relation_re_embed(relation) * self.entity_re_embed(head) * self.entity_re_embed(tail), dim=-1) \
            + torch.sum(self.relation_re_embed(relation) * self.entity_im_embed(head) * self.entity_im_embed(tail), dim=-1) \
            + torch.sum(self.relation_im_embed(relation) * self.entity_re_embed(head) * self.entity_im_embed(tail), dim=-1) \
            - torch.sum(self.relation_im_embed(relation) * self.entity_im_embed(head) * self.entity_re_embed(tail), dim=-1)

    def dist(self, head, relation, tail):
        return -self.forward(head, relation, tail)
    
    def score(self, head, relation, tail):
        return -self.forward(head, relation, tail)

    def prob_logit(self, head, relation, tail):
        return self.forward(head, relation, tail)

class ComplEx(BaseModel):
    def __init__(self, n_entity, n_relation, config):
        super().__init__()
        self.mdl = ComplExModule(n_entity, n_relation, config)
        self.mdl.cuda()
        self.config = config
        self.weight_decay = config.lam / config.n_batch

    def train(self, train_data, corrupter, tester):
        head, relation, tail = train_data
        n_train = len(head)
        n_epoch = self.config.n_epoch
        n_batch = self.config.n_batch
        optimizer = Adam(self.mdl.parameters(), weight_decay=self.weight_decay)

        best_perf = 0
        for epoch in range(n_epoch):
            epoch_loss = 0
            if (epoch % self.config.sample_freq == 0):
                rand_idx = torch.randperm(n_train)
                head = head[rand_idx]
                relation = relation[rand_idx]
                tail = tail[rand_idx]

                head_corrupted, relation_corrupted, tail_corrupted = corrupter.corrupt(head, relation, tail)
                head_corrupted = head_corrupted.cuda()
                relation_corrupted = relation_corrupted.cuda()
                tail_corrupted = tail_corrupted.cuda()

            for hs, rs, ts in batch_by_num(n_batch, head_corrupted, relation_corrupted, tail_corrupted, n_sample=n_train):
                self.mdl.zero_grad()
                label = torch.zeros(len(hs)).type(torch.LongTensor).cuda()

                loss = torch.sum(self.mdl.softmax_loss(Variable(hs), Variable(rs), Variable(ts), label))
                loss.backward()

                optimizer.step()
                epoch_loss += loss.item()

            logging.info('Epoch %d/%d, Loss=%f', epoch + 1, n_epoch, epoch_loss / n_train)
            if ((epoch + 1) % self.config.epoch_per_test == 0):
                test_perf = tester()
                if (test_perf > best_perf):
                    task_dir = './output/' + config().task.dir + '/models'
                    os.makedirs(task_dir, exist_ok=True)
                    self.save(os.path.join(task_dir, self.config.model_file))

                    best_perf = test_perf
        return best_perf

# KBGAN

In [None]:
import pandas as pd
import datetime
from torch.autograd import Variable
import os

class Component():
    def __init__(self, model_type, role='discriminator'):
        """
        model_type = ["TransE", "TransD", "DistMult", "ComplEx"]
        role = ["discriminator", "generator"]
        """
        self.model_type = model_type
        self.model = None 
        self.model_config = None
        self.role = role

    def load_model(self, model_type, role, n_entity, n_relation, model_path=None):
        print(f'Loading pretrained {self.model_type} model...')
        self.model_type = model_type
        self.role = role
        self.model_config = config()[self.model_type]
        output_dir = './output/' + config().task.dir + '/models'
        if self.model_type == 'TransE':
            self.model = TransE(n_entity, n_relation, self.model_config)
        elif self.model_type == 'TransD':
            self.model = TransD(n_entity, n_relation, self.model_config)
        elif self.model_type == 'DistMult':
            self.model = DistMult(n_entity, n_relation, self.model_config)
        elif self.model_type == 'ComplEx':
            self.model = ComplEx(n_entity, n_relation, self.model_config)
        self.model.load(model_path if model_path is not None else os.path.join(output_dir, self.model_config.model_file))

    def pretrain(self, n_entity, n_relation, heads, tails, train_data, valid_data):    
        overwrite_config_with_args(["--pretrain_config=" + self.model_type])
        overwrite_config_with_args(["--log.prefix=" + self.model_type + '_'])
        logger_init()
        
        print(f'Pretraining {self.model_type} model...')
        self.model_config = config()[self.model_type]
        
        if self.model_type == 'TransE':
            corrupter = BernCorrupter(train_data, n_entity, n_relation)
            self.model = TransE(n_entity, n_relation, self.model_config)
        elif self.model_type == 'TransD':
            corrupter = BernCorrupter(train_data, n_entity, n_relation)
            self.model = TransD(n_entity, n_relation, self.model_config)
        elif self.model_type == 'DistMult':
            corrupter = BernCorrupterMulti(train_data, n_entity, n_relation, self.model_config.n_sample)
            self.model = DistMult(n_entity, n_relation, self.model_config)
        elif self.model_type == 'ComplEx':
            corrupter = BernCorrupterMulti(train_data, n_entity, n_relation, self.model_config.n_sample)
            self.model = ComplEx(n_entity, n_relation, self.model_config)
            
        tester = lambda: self.model.evaluate(valid_data, n_entity, heads, tails)
        self.model.train(train_data, corrupter, tester)

    def step(self, head, relation, tail, **kwargs):
        """
        Unified step function that handles both generator and discriminator logic.
        
        For generator:
            kwargs: n_sample=1, temperature=1.0, train=True
            Returns generator coroutine (yields samples, receives rewards)
            
        For discriminator:
            kwargs: head_fake, tail_fake, train=True
            Returns (losses, rewards)
        """
        if self.role == 'generator':
            return self._generator_step(head, relation, tail, **kwargs)
        elif self.role == 'discriminator':
            return self._discriminator_step(head, relation, tail, **kwargs)
        else:
            raise ValueError("Role must be either 'generator' or 'discriminator'")
    
    def _generator_step(self, head, relation, tail, n_sample=1, temperature=1.0, train=True):
        """Generator step: sample fake triples and update with REINFORCE"""
        # Forward pass: generate samples
        n, m = tail.size()
        relation_var = Variable(relation.cuda())
        head_var = Variable(head.cuda())
        tail_var = Variable(tail.cuda())

        logits = self.model.prob_logit(head_var, relation_var, tail_var) / temperature
        probs = nnf.softmax(logits, dim=-1)
        row_idx = torch.arange(0, n).type(torch.LongTensor).unsqueeze(1).expand(n, n_sample)
        sample_idx = torch.multinomial(probs, n_sample, replacement=True)
        sample_heads = head[row_idx, sample_idx.data.cpu()]
        sample_tails = tail[row_idx, sample_idx.data.cpu()]
        
        # Yield samples to get rewards from discriminator
        rewards = yield sample_heads, sample_tails
        
        # Backward pass: update generator with REINFORCE
        if train:            
            self.model._ensure_optimizer()
            self.model.mdl.zero_grad()
            log_probs = nnf.log_softmax(logits, dim=-1)
            reinforce_loss = -torch.sum(Variable(rewards) * log_probs[row_idx.cuda(), sample_idx.data])
            reinforce_loss.backward()
            self.model.opt.step()
            self.model.mdl.constraint()

        yield None

    def _discriminator_step(self, head, relation, tail, head_fake=None, tail_fake=None, train=True):
        """Discriminator step: distinguish real from fake triples"""
        if head_fake is None or tail_fake is None:
            raise ValueError("head_fake and tail_fake must be provided for discriminator step")
        
        # Forward pass: compute losses and scores
        head_var = Variable(head.cuda())
        relation_var = Variable(relation.cuda())
        tail_var = Variable(tail.cuda())
        head_fake_var = Variable(head_fake.cuda())
        tail_fake_var = Variable(tail_fake.cuda())
        
        losses = self.model.mdl.pair_loss(head_var, relation_var, tail_var, head_fake_var, tail_fake_var)
        fake_scores = self.model.mdl.score(head_fake_var, relation_var, tail_fake_var)
                
        # Backward pass: update discriminator
        if train:
            self.model._ensure_optimizer()
            self.model.mdl.zero_grad()
            torch.sum(losses).backward()
            self.model.opt.step()
            self.model.mdl.constraint()
        
        return losses.data, -fake_scores.data
        
    def evaluate(self, test_data, n_entity, heads, tails):
        """
        Evaluate the model on Link Prediction task.
        """
        output_dir = './output/' + config().task.dir + '/models'
        self.model.load(os.path.join(output_dir, self.model_config.model_file))
        print(f'Testing {self.model_type} model...')
        self.model.evaluate(test_data, n_entity, heads, tails)

class KBGAN():
    def __init__(self, discriminator_type="TransE", generator_type="DistMult"):
        """
        discriminator_type = ["TransE", "TransD"]
        generator_type = ["DistMult", "ComplEx"]
        """    
        self.discriminator_type = discriminator_type
        self.discriminator = Component(model_type=discriminator_type, role='discriminator')
        self.generator_type = generator_type
        self.generator = Component(model_type=generator_type, role='generator')

    def load_discriminator(self, n_entity, n_relation, test_data, heads, tails, disc_model_path=None):
        self.discriminator.load_model(self.discriminator_type, 'discriminator', n_entity, n_relation, disc_model_path)
        self.discriminator.evaluate(test_data, n_entity, heads, tails)

    def load_generator(self, n_entity, n_relation, test_data, heads, tails, gen_model_path=None):
        self.generator.load_model(self.discriminator_type, 'generator', n_entity, n_relation, gen_model_path)
        self.generator.evaluate(test_data, n_entity, heads, tails)
        
    def pretrain(self, n_entity, n_relation, heads, tails, train_data, valid_data, test_data):
        if not isinstance(train_data[0], torch.Tensor):
            train_data = [torch.LongTensor(vec) for vec in train_data]
        if not isinstance(valid_data[0], torch.Tensor):
            valid_data = [torch.LongTensor(vec) for vec in valid_data]
        if not isinstance(test_data[0], torch.Tensor):
            test_data = [torch.LongTensor(vec) for vec in test_data]

        # Pretrain discriminator
        self.discriminator.pretrain(n_entity, n_relation, heads, tails, train_data, valid_data)
        self.discriminator.evaluate(test_data, n_entity, heads, tails)

        # Pretrain generator
        self.generator.pretrain(n_entity, n_relation, heads, tails, train_data, valid_data)
        self.generator.evaluate(test_data, n_entity, heads, tails)

    def train_n_test(self, n_entity, n_relation, heads, tails, train_data, valid_data, test_data):
        if not isinstance(train_data[0], torch.Tensor):
            train_data = [torch.LongTensor(vec) for vec in train_data]
        if not isinstance(valid_data[0], torch.Tensor):
            valid_data = [torch.LongTensor(vec) for vec in valid_data]
        if not isinstance(test_data[0], torch.Tensor):
            test_data = [torch.LongTensor(vec) for vec in test_data]
        
        overwrite_config_with_args(["--log.prefix=" + self.discriminator_type + '-' + self.generator_type + "_"])
        logger_init()

        generator_config = config()[config().g_config]
        discriminator_config = config()[config().d_config]

        models = {'TransE': TransE, 'TransD': TransD, 'DistMult': DistMult, 'ComplEx': ComplEx}

        # Load pretrained models into the Component instances
        model_dir = './output/' + config().task.dir + '/models'
        
        # Initialize generator model if not already done
        if self.generator.model is None:
            self.generator.model_config = generator_config
            self.generator.model = models[config().g_config](n_entity, n_relation, generator_config)
        self.generator.model.load(os.path.join(model_dir, generator_config.model_file))
        
        # Initialize discriminator model if not already done
        if self.discriminator.model is None:
            self.discriminator.model_config = discriminator_config
            self.discriminator.model = models[config().d_config](n_entity, n_relation, discriminator_config)
        self.discriminator.model.load(os.path.join(model_dir, discriminator_config.model_file))

        corrupter = BernCorrupterMulti(train_data, n_entity, n_relation, config().KBGAN.n_sample)
        head, relation, tail = train_data
        n_train = len(head)
        n_epoch = config().KBGAN.n_epoch
        n_batch = config().KBGAN.n_batch

        model_name = 'gan_' + self.discriminator_type + '-dis_' + self.generator_type + '-gen_' + datetime.datetime.now().strftime("%m%d%H%M%S") + '.mdl'
        best_perf = 0
        avg_reward = 0

        print(f'Training KBGAN with {self.generator_type} as generator and {self.discriminator_type} as discriminator...')
        for epoch in range(n_epoch):
            epoch_d_loss = 0
            epoch_reward = 0

            head_cand, relation_cand, tail_cand = corrupter.corrupt(head, relation, tail, keep_truth=False)
            for h, r, t, hs, rs, ts in batch_by_num(n_batch, head, relation, tail, head_cand, relation_cand, tail_cand, n_sample=n_train):
                gen_step = self.generator.step(hs, rs, ts, temperature=config().KBGAN.temperature)
                head_smpl, tail_smpl = next(gen_step)
                
                losses, rewards = self.discriminator.step(h, r, t, head_fake=head_smpl.squeeze(), tail_fake=tail_smpl.squeeze())
                epoch_reward += torch.sum(rewards)

                rewards = rewards - avg_reward
                
                # Update generator with rewards
                try:
                    gen_step.send(rewards.unsqueeze(1))
                except StopIteration:
                    pass
                
                epoch_d_loss += torch.sum(losses)
                
            avg_loss = epoch_d_loss / n_train
            avg_reward = epoch_reward / n_train
            logging.info('Epoch %d/%d, D_loss=%f, reward=%f', epoch + 1, n_epoch, avg_loss, avg_reward)
            
            if (epoch + 1) % config().KBGAN.epoch_per_test == 0:
                perf = self.discriminator.model.evaluate(valid_data, n_entity, heads, tails)
                if perf > best_perf:
                    best_perf = perf
                    save_dir = './output/' + config().task.dir + '/kbgan/'
                    os.makedirs(save_dir, exist_ok=True)
                    self.discriminator.model.save(os.path.join(save_dir, model_name))

        save_dir = './output/' + config().task.dir + '/kbgan/'
        self.discriminator.model.load(os.path.join(save_dir, model_name))
        print('Testing KBGAN discriminator model...')
        self.discriminator.model.evaluate(test_data, n_entity, heads, tails)

    def evaluate(self, model_path, model_type, threshold = None, auto_threshold = True, eval_subdir='evaluation on TP'):
        """
        Evaluate a pretrained model on Triple Classification task.
        
        Args:
            model_type: Type of model to evaluate ['DistMult', 'ComplEx', 'TransE', 'TransD']
            eval_subdir: Subdirectory name for evaluation data (default: 'evaluation on TP')
        
        Returns:
            Dictionary containing evaluation results
        """
        # Configuration
        base_task_name = config().task.dir  # e.g., 'wn18rr' or 'FB15k-237'
        train_dir = f'./data/{base_task_name}'
        eval_dir = f'./data/{eval_subdir}/{base_task_name}'
        
        print(f"Loading data from:")
        print(f"  Training dir: {train_dir}")
        print(f"  Evaluation dir: {eval_dir}")
        
        # Create kb_index from training directory
        kb_index = index_entity_relation(
            os.path.join(train_dir, 'train.txt'),
            os.path.join(train_dir, 'valid.txt'),
            os.path.join(train_dir, 'test.txt')
        )
        n_entity, n_relation = graph_size(kb_index)
        print(f"  Entities: {n_entity}, Relations: {n_relation}")
        
        # Load evaluation data with labels
        test_data = read_data(os.path.join(eval_dir, 'test.txt'), kb_index, with_label=True)
        print(f"  Test data: {len(test_data[0])} triples")
        test_labels = test_data[3]
        test_data = test_data[:3]
        
        valid_data = read_data(os.path.join(eval_dir, 'valid.txt'), kb_index, with_label=True)
        print(f"  Valid data: {len(valid_data[0])} triples")
        valid_labels = valid_data[3]
        valid_data = valid_data[:3]
        
        # Create filtering dictionaries
        heads, tails = sparse_heads_tails(n_entity, train_data=None, valid_data=valid_data, test_data=test_data)
        
        # Load pretrained model
        model_dir = f'./output/{base_task_name}/models'
        model_config = config()[self.discriminator_type]
        
        print(f"\nLoading {self.discriminator_type} model from {model_dir}")
        
        if self.discriminator_type == 'DistMult':
            model = DistMult(n_entity, n_relation, model_config)
        elif self.discriminator_type == 'ComplEx':
            model = ComplEx(n_entity, n_relation, model_config)
        elif self.discriminator_type == 'TransE':
            model = TransE(n_entity, n_relation, model_config)
        elif self.discriminator_type == 'TransD':
            model = TransD(n_entity, n_relation, model_config)
        else:
            raise ValueError(f"Unknown model type: {self.discriminator_type}")
        
        model.load(model_path if model_path is not None else os.path.join(model_dir, model_config.model_file))
        print(f"Model loaded successfully!\n")
        
        # Evaluate on triple classification
        print("=" * 60)
        print(f"Evaluating {model_path} on Triple Classification task")
        print("=" * 60)

        heads_test, relations_test, tails_test = test_data
        predictions = [] 
        scores_list = []
        
        with torch.no_grad():
            for i in range(len(heads_test)):
                head = torch.LongTensor([heads_test[i]]).cuda()
                relation = torch.LongTensor([relations_test[i]]).cuda()
                tail = torch.LongTensor([tails_test[i]]).cuda()
                
                score = model.mdl.score(head, relation, tail).item()
                scores_list.append(score)
        
        # Auto-compute threshold if needed
        if threshold is None and auto_threshold:
            # Use validation data to find optimal threshold
            threshold = find_optimal_threshold(model, model_type, valid_data, valid_labels)
            logging.info(f"Auto-computed threshold: {threshold:.4f}")
        elif threshold is None:
            threshold = 0.0
            logging.info(f"Using default threshold: {threshold:.4f}")
        
        is_distance_based = model_type in ['TransE', 'TransD']
        
        
        for score in scores_list:
            if is_distance_based:
                predictions.append(1 if score < threshold else 0)
            else:       
                predictions.append(0 if score < threshold else 1)
       
        
        # Compute metrics
        accuracy, precision, recall, f1 = acc_pre_rec_f1(predictions, test_labels)
        
        metrics_str = f"\nTriple Classification Results:\n"
        metrics_str += f"Threshold = {threshold:.4f}\n"
        metrics_str += f"Accuracy = {accuracy:.4f}\n"
        metrics_str += f"Precision = {precision:.4f}\n"
        metrics_str += f"Recall = {recall:.4f}\n"
        metrics_str += f"F1 Score = {f1:.4f}\n"
        logging.info(metrics_str)

        # # ======================================================
        # # ✅ EXPORT TO EXCEL
        # # ======================================================
        # output_df = pd.DataFrame({
        #     "score": scores_list,
        #     "prediction": predictions,
        #     "label": test_labels
        # })

        # output_df["threshold_used"] = threshold

        # csv_path = f"./output/{base_task_name}/{os.path.basename(model_path)}.csv"
        # output_df.to_csv(csv_path, index=False)
        
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'threshold': threshold,
            'predictions': predictions,
            'scores': scores_list
        }

def find_optimal_threshold(model, model_type, validation_data, labels, n_thresholds=100):
    """
    Find the optimal threshold for triple classification using validation data.

    Args:
        model: The model to evaluate
        validation_data: Tuple of (heads, relations, tails)
        labels: Ground truth labels for validation data
        n_thresholds: Number of threshold values to try

    Returns:
        Optimal threshold value that maximizes F1 score
    """
    heads, relations, tails = validation_data

    # Compute scores for all validation samples
    scores_list = []
    with torch.no_grad():
        for i in range(len(heads)):
            head = torch.LongTensor([heads[i]]).cuda()
            relation = torch.LongTensor([relations[i]]).cuda()
            tail = torch.LongTensor([tails[i]]).cuda()
            score = model.mdl.score(head, relation, tail).item()
            scores_list.append(score)

    # Try different threshold values
    min_score = min(scores_list)
    max_score = max(scores_list)
    threshold_values = np.linspace(min_score, max_score, n_thresholds)

    best_f1 = 0
    best_threshold = 0

    # Determine if model is distance-based or similarity-based
    is_distance_based = model_type in ['TransE', 'TransD']
 
    for threshold in threshold_values:
        predictions = []
        for score in scores_list:
            if is_distance_based:
                predictions.append(1 if score < threshold else 0)
            else:       
                predictions.append(0 if score < threshold else 1)
        
        _, _, _, f1 = acc_pre_rec_f1(predictions, labels)
        
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    logging.info(f"Optimal threshold: {best_threshold:.4f} (F1: {best_f1:.4f})")
    return best_threshold

## PU cell for fixing bugs

In [56]:
main = KBGAN()
result = main.evaluate('output/wn18rr/kbgan/gan_TransE-dis_DistMult-gen_1117155145.mdl', 'TransE')

Loading data from:
  Training dir: ./data/wn18rr
  Evaluation dir: ./data/evaluation on TP/wn18rr
  Entities: 40943, Relations: 11
  Test data: 6268 triples
  Valid data: 6068 triples

Loading TransE model from ./output/wn18rr/models
Model loaded successfully!

Evaluating output/wn18rr/kbgan/gan_TransE-dis_DistMult-gen_1117155145.mdl on Triple Classification task


     2331752683 17:56:40 Optimal threshold: 7.1507 (F1: 0.7924)
     2331752683 17:56:40 Auto-computed threshold: 7.1507
     2331752683 17:56:40 
Triple Classification Results:
Threshold = 7.1507
Accuracy = 0.8028
Precision = 0.8625
Recall = 0.7205
F1 Score = 0.7851



# Main

In [None]:
torch.cuda.set_device(select_gpu())

_config = config()
_config.log.to_file = True
# overwrite_config_with_args(["--TransE.n_epoch=10"])
# overwrite_config_with_args(["--TransE.epoch_per_test=10"])
# overwrite_config_with_args(["--DistMult.n_epoch=10"])
# overwrite_config_with_args(["--DistMult.epoch_per_test=10"])
# overwrite_config_with_args(["--KBGAN.n_epoch=10"])
# overwrite_config_with_args(["--KBGAN.epoch_per_test=10"])

# Load data
task_dir = _config.task.dir
task_dir = './data/' + task_dir
kb_index = index_entity_relation(os.path.join(task_dir, 'train.txt'),
                                    os.path.join(task_dir, 'valid.txt'),
                                    os.path.join(task_dir, 'test.txt'))
n_entity, n_relation = graph_size(kb_index)

train_data = read_data(os.path.join(task_dir, 'train.txt'), kb_index)
inplace_shuffle(*train_data)

valid_data = read_data(os.path.join(task_dir, 'valid.txt'), kb_index)
test_data = read_data(os.path.join(task_dir, 'test.txt'), kb_index)
heads, tails = sparse_heads_tails(n_entity, train_data, valid_data, test_data)

valid_data = [torch.LongTensor(vec) for vec in valid_data]
test_data = [torch.LongTensor(vec) for vec in test_data]
train_data = [torch.LongTensor(vec) for vec in train_data]

main_model = KBGAN(discriminator_type="TransE", generator_type="DistMult")
main_model.load_models(n_entity, n_relation, test_data, heads, tails, disc_model_path='output/wn18rr/models/transe.mdl', gen_model_path='output/wn18rr/models/distmult.mdl')
# main_model.train_n_test(n_entity, n_relation, heads, tails, train_data, valid_data, test_data)



  heads_sparse[k] = torch.sparse.FloatTensor(torch.LongTensor([list(heads[k])]), torch.ones(len(heads[k])), torch.Size([n_entity]))


Loading pretrained TransE model...
Loading pretrained DistMult model...
Testing TransE model...
hhehe
MR = 5660.6748046875
MRR = 0.1760859191417694
Hit@1 = 0.006541161455009573
Hit@3 = 0.33806636885768987
Hit@10 = 0.40682833439693683

Testing DistMult model...
hhehe
MR = 5310.08203125
MRR = 0.3606793284416199
Hit@1 = 0.30871091257179323
Hit@3 = 0.3948627951499681
Hit@10 = 0.44416081684747927



In [15]:
main_model.train_n_test(n_entity, n_relation, heads, tails, train_data, valid_data, test_data)

Training KBGAN with DistMult as generator and TransE as discriminator...


      433983126 15:51:46 Epoch 1/5000, D_loss=0.062418, reward=-8.051137
      433983126 15:51:46 Epoch 2/5000, D_loss=0.062277, reward=-8.047962
      433983126 15:51:47 Epoch 3/5000, D_loss=0.063198, reward=-8.047464
      433983126 15:51:47 Epoch 4/5000, D_loss=0.060982, reward=-8.056492
      433983126 15:51:47 Epoch 5/5000, D_loss=0.063656, reward=-8.048387
      433983126 15:51:48 Epoch 6/5000, D_loss=0.062375, reward=-8.055072
      433983126 15:51:48 Epoch 7/5000, D_loss=0.061417, reward=-8.051257
      433983126 15:51:49 Epoch 8/5000, D_loss=0.061189, reward=-8.055042
      433983126 15:51:49 Epoch 9/5000, D_loss=0.061309, reward=-8.050263
      433983126 15:51:49 Epoch 10/5000, D_loss=0.062492, reward=-8.051524
      433983126 15:51:50 Epoch 11/5000, D_loss=0.060551, reward=-8.056904
      433983126 15:51:50 Epoch 12/5000, D_loss=0.061429, reward=-8.046429
      433983126 15:51:50 Epoch 13/5000, D_loss=0.060959, reward=-8.050442
      433983126 15:51:51 Epoch 14/5000, D_loss=

hhehe


      248942384 15:52:27 MR = 5369.865234375
MRR = 0.17133787274360657
Hit@1 = 0.004449571522742254
Hit@3 = 0.32333553065260384
Hit@10 = 0.41529334212261043



MR = 5369.865234375
MRR = 0.17133787274360657
Hit@1 = 0.004449571522742254
Hit@3 = 0.32333553065260384
Hit@10 = 0.41529334212261043



      433983126 15:52:27 Epoch 101/5000, D_loss=0.049927, reward=-7.979209
      433983126 15:52:28 Epoch 102/5000, D_loss=0.052282, reward=-7.972497
      433983126 15:52:28 Epoch 103/5000, D_loss=0.049850, reward=-7.979791
      433983126 15:52:28 Epoch 104/5000, D_loss=0.050921, reward=-7.976317
      433983126 15:52:29 Epoch 105/5000, D_loss=0.050676, reward=-7.978796
      433983126 15:52:29 Epoch 106/5000, D_loss=0.050434, reward=-7.972086
      433983126 15:52:30 Epoch 107/5000, D_loss=0.049596, reward=-7.976666
      433983126 15:52:30 Epoch 108/5000, D_loss=0.051329, reward=-7.973950
      433983126 15:52:30 Epoch 109/5000, D_loss=0.049334, reward=-7.977036
      433983126 15:52:31 Epoch 110/5000, D_loss=0.050504, reward=-7.970263
      433983126 15:52:31 Epoch 111/5000, D_loss=0.053034, reward=-7.966420
      433983126 15:52:31 Epoch 112/5000, D_loss=0.048658, reward=-7.972291
      433983126 15:52:32 Epoch 113/5000, D_loss=0.050715, reward=-7.971625
      433983126 15:52:32 

hhehe


      248942384 15:53:08 MR = 5111.28125
MRR = 0.1725415587425232
Hit@1 = 0.005273566249176005
Hit@3 = 0.3246539222148978
Hit@10 = 0.41727092946605143



MR = 5111.28125
MRR = 0.1725415587425232
Hit@1 = 0.005273566249176005
Hit@3 = 0.3246539222148978
Hit@10 = 0.41727092946605143



      433983126 15:53:09 Epoch 201/5000, D_loss=0.047369, reward=-7.932260
      433983126 15:53:09 Epoch 202/5000, D_loss=0.046649, reward=-7.942058
      433983126 15:53:09 Epoch 203/5000, D_loss=0.045997, reward=-7.936792
      433983126 15:53:10 Epoch 204/5000, D_loss=0.048778, reward=-7.936220
      433983126 15:53:10 Epoch 205/5000, D_loss=0.049351, reward=-7.930022
      433983126 15:53:10 Epoch 206/5000, D_loss=0.048522, reward=-7.927615
      433983126 15:53:11 Epoch 207/5000, D_loss=0.048237, reward=-7.930089
      433983126 15:53:11 Epoch 208/5000, D_loss=0.047345, reward=-7.925621
      433983126 15:53:12 Epoch 209/5000, D_loss=0.046669, reward=-7.930151
      433983126 15:53:12 Epoch 210/5000, D_loss=0.047763, reward=-7.928990
      433983126 15:53:12 Epoch 211/5000, D_loss=0.046871, reward=-7.928370
      433983126 15:53:13 Epoch 212/5000, D_loss=0.048148, reward=-7.928568
      433983126 15:53:13 Epoch 213/5000, D_loss=0.047528, reward=-7.932072
      433983126 15:53:13 

hhehe


      248942384 15:53:50 MR = 5053.005859375
MRR = 0.17394746840000153
Hit@1 = 0.005767963085036256
Hit@3 = 0.3264667106130521
Hit@10 = 0.4225444957152274



MR = 5053.005859375
MRR = 0.17394746840000153
Hit@1 = 0.005767963085036256
Hit@3 = 0.3264667106130521
Hit@10 = 0.4225444957152274



      433983126 15:53:51 Epoch 301/5000, D_loss=0.045747, reward=-7.911760
      433983126 15:53:51 Epoch 302/5000, D_loss=0.047738, reward=-7.907630
      433983126 15:53:51 Epoch 303/5000, D_loss=0.046424, reward=-7.912888
      433983126 15:53:52 Epoch 304/5000, D_loss=0.047167, reward=-7.902901
      433983126 15:53:52 Epoch 305/5000, D_loss=0.047320, reward=-7.910397
      433983126 15:53:52 Epoch 306/5000, D_loss=0.045935, reward=-7.914910
      433983126 15:53:53 Epoch 307/5000, D_loss=0.044857, reward=-7.907165
      433983126 15:53:53 Epoch 308/5000, D_loss=0.045954, reward=-7.905906
      433983126 15:53:54 Epoch 309/5000, D_loss=0.045179, reward=-7.907722
      433983126 15:53:54 Epoch 310/5000, D_loss=0.045070, reward=-7.910141
      433983126 15:53:54 Epoch 311/5000, D_loss=0.047052, reward=-7.906309
      433983126 15:53:55 Epoch 312/5000, D_loss=0.047657, reward=-7.903251
      433983126 15:53:55 Epoch 313/5000, D_loss=0.045973, reward=-7.908085
      433983126 15:53:55 

hhehe


      248942384 15:54:32 MR = 4901.384765625
MRR = 0.17428149282932281
Hit@1 = 0.0054383651944627555
Hit@3 = 0.3276203032300593
Hit@10 = 0.4215557020435069



MR = 4901.384765625
MRR = 0.17428149282932281
Hit@1 = 0.0054383651944627555
Hit@3 = 0.3276203032300593
Hit@10 = 0.4215557020435069



      433983126 15:54:33 Epoch 401/5000, D_loss=0.044692, reward=-7.890929
      433983126 15:54:33 Epoch 402/5000, D_loss=0.045659, reward=-7.890118
      433983126 15:54:33 Epoch 403/5000, D_loss=0.046801, reward=-7.882823
      433983126 15:54:34 Epoch 404/5000, D_loss=0.045887, reward=-7.889050
      433983126 15:54:34 Epoch 405/5000, D_loss=0.044632, reward=-7.893925
      433983126 15:54:34 Epoch 406/5000, D_loss=0.046047, reward=-7.890211
      433983126 15:54:35 Epoch 407/5000, D_loss=0.045610, reward=-7.888832
      433983126 15:54:35 Epoch 408/5000, D_loss=0.045283, reward=-7.888434
      433983126 15:54:36 Epoch 409/5000, D_loss=0.045644, reward=-7.889419
      433983126 15:54:36 Epoch 410/5000, D_loss=0.043950, reward=-7.894065
      433983126 15:54:36 Epoch 411/5000, D_loss=0.046460, reward=-7.889514
      433983126 15:54:37 Epoch 412/5000, D_loss=0.046064, reward=-7.879041
      433983126 15:54:37 Epoch 413/5000, D_loss=0.046212, reward=-7.884784
      433983126 15:54:38 

hhehe


      248942384 15:55:14 MR = 4910.8974609375
MRR = 0.17493632435798645
Hit@1 = 0.004943968358602505
Hit@3 = 0.32679630850362557
Hit@10 = 0.4268292682926829



MR = 4910.8974609375
MRR = 0.17493632435798645
Hit@1 = 0.004943968358602505
Hit@3 = 0.32679630850362557
Hit@10 = 0.4268292682926829



      433983126 15:55:15 Epoch 501/5000, D_loss=0.045607, reward=-7.865664
      433983126 15:55:15 Epoch 502/5000, D_loss=0.044791, reward=-7.873797
      433983126 15:55:16 Epoch 503/5000, D_loss=0.043893, reward=-7.875847
      433983126 15:55:16 Epoch 504/5000, D_loss=0.046280, reward=-7.871843
      433983126 15:55:16 Epoch 505/5000, D_loss=0.043968, reward=-7.873231
      433983126 15:55:17 Epoch 506/5000, D_loss=0.044250, reward=-7.874781
      433983126 15:55:17 Epoch 507/5000, D_loss=0.045207, reward=-7.873032
      433983126 15:55:17 Epoch 508/5000, D_loss=0.045119, reward=-7.869075
      433983126 15:55:18 Epoch 509/5000, D_loss=0.043706, reward=-7.875257
      433983126 15:55:18 Epoch 510/5000, D_loss=0.043775, reward=-7.875618
      433983126 15:55:19 Epoch 511/5000, D_loss=0.044062, reward=-7.879329
      433983126 15:55:19 Epoch 512/5000, D_loss=0.044053, reward=-7.879839
      433983126 15:55:19 Epoch 513/5000, D_loss=0.043846, reward=-7.875437
      433983126 15:55:20 

hhehe


      248942384 15:55:56 MR = 4772.8681640625
MRR = 0.17553958296775818
Hit@1 = 0.004943968358602505
Hit@3 = 0.3276203032300593
Hit@10 = 0.4288068556361239



MR = 4772.8681640625
MRR = 0.17553958296775818
Hit@1 = 0.004943968358602505
Hit@3 = 0.3276203032300593
Hit@10 = 0.4288068556361239



      433983126 15:55:56 Epoch 601/5000, D_loss=0.045569, reward=-7.853202
      433983126 15:55:57 Epoch 602/5000, D_loss=0.045192, reward=-7.854895
      433983126 15:55:57 Epoch 603/5000, D_loss=0.043935, reward=-7.855207
      433983126 15:55:57 Epoch 604/5000, D_loss=0.045824, reward=-7.856401
      433983126 15:55:58 Epoch 605/5000, D_loss=0.044454, reward=-7.859341
      433983126 15:55:58 Epoch 606/5000, D_loss=0.043726, reward=-7.854714
      433983126 15:55:58 Epoch 607/5000, D_loss=0.046050, reward=-7.849056
      433983126 15:55:59 Epoch 608/5000, D_loss=0.045044, reward=-7.851746
      433983126 15:55:59 Epoch 609/5000, D_loss=0.044976, reward=-7.854652
      433983126 15:56:00 Epoch 610/5000, D_loss=0.043867, reward=-7.852609
      433983126 15:56:00 Epoch 611/5000, D_loss=0.044232, reward=-7.857585
      433983126 15:56:00 Epoch 612/5000, D_loss=0.044906, reward=-7.859560
      433983126 15:56:01 Epoch 613/5000, D_loss=0.042780, reward=-7.853894
      433983126 15:56:01 

hhehe


      248942384 15:56:37 MR = 4762.50244140625
MRR = 0.1745133250951767
Hit@1 = 0.004614370468029005
Hit@3 = 0.3266315095583388
Hit@10 = 0.4293012524719842



MR = 4762.50244140625
MRR = 0.1745133250951767
Hit@1 = 0.004614370468029005
Hit@3 = 0.3266315095583388
Hit@10 = 0.4293012524719842



      433983126 15:56:38 Epoch 701/5000, D_loss=0.047420, reward=-7.838552
      433983126 15:56:38 Epoch 702/5000, D_loss=0.044551, reward=-7.839141
      433983126 15:56:38 Epoch 703/5000, D_loss=0.044110, reward=-7.837630
      433983126 15:56:39 Epoch 704/5000, D_loss=0.043370, reward=-7.841883
      433983126 15:56:39 Epoch 705/5000, D_loss=0.043514, reward=-7.849239
      433983126 15:56:39 Epoch 706/5000, D_loss=0.042605, reward=-7.842292
      433983126 15:56:40 Epoch 707/5000, D_loss=0.044818, reward=-7.835853
      433983126 15:56:40 Epoch 708/5000, D_loss=0.042166, reward=-7.843591
      433983126 15:56:41 Epoch 709/5000, D_loss=0.043358, reward=-7.848644
      433983126 15:56:41 Epoch 710/5000, D_loss=0.042756, reward=-7.840231
      433983126 15:56:41 Epoch 711/5000, D_loss=0.045131, reward=-7.840500
      433983126 15:56:42 Epoch 712/5000, D_loss=0.045610, reward=-7.842658
      433983126 15:56:42 Epoch 713/5000, D_loss=0.045723, reward=-7.839142
      433983126 15:56:42 

hhehe


      248942384 15:57:18 MR = 4789.22607421875
MRR = 0.17442521452903748
Hit@1 = 0.004449571522742254
Hit@3 = 0.3246539222148978
Hit@10 = 0.4299604482531312



MR = 4789.22607421875
MRR = 0.17442521452903748
Hit@1 = 0.004449571522742254
Hit@3 = 0.3246539222148978
Hit@10 = 0.4299604482531312



      433983126 15:57:18 Epoch 801/5000, D_loss=0.045540, reward=-7.821571
      433983126 15:57:19 Epoch 802/5000, D_loss=0.042724, reward=-7.829964
      433983126 15:57:19 Epoch 803/5000, D_loss=0.045321, reward=-7.822020
      433983126 15:57:19 Epoch 804/5000, D_loss=0.045212, reward=-7.825594
      433983126 15:57:20 Epoch 805/5000, D_loss=0.045258, reward=-7.823433
      433983126 15:57:20 Epoch 806/5000, D_loss=0.044512, reward=-7.820435
      433983126 15:57:21 Epoch 807/5000, D_loss=0.045307, reward=-7.820511
      433983126 15:57:21 Epoch 808/5000, D_loss=0.044146, reward=-7.824034
      433983126 15:57:21 Epoch 809/5000, D_loss=0.043557, reward=-7.824746
      433983126 15:57:22 Epoch 810/5000, D_loss=0.045349, reward=-7.824862
      433983126 15:57:22 Epoch 811/5000, D_loss=0.044292, reward=-7.822688
      433983126 15:57:22 Epoch 812/5000, D_loss=0.044360, reward=-7.823657
      433983126 15:57:23 Epoch 813/5000, D_loss=0.045932, reward=-7.822784
      433983126 15:57:23 

hhehe


      248942384 15:58:00 MR = 4721.64794921875
MRR = 0.17756271362304688
Hit@1 = 0.007251153592617007
Hit@3 = 0.3266315095583388
Hit@10 = 0.4342452208305867



MR = 4721.64794921875
MRR = 0.17756271362304688
Hit@1 = 0.007251153592617007
Hit@3 = 0.3266315095583388
Hit@10 = 0.4342452208305867



      433983126 15:58:00 Epoch 901/5000, D_loss=0.044935, reward=-7.807064
      433983126 15:58:00 Epoch 902/5000, D_loss=0.045312, reward=-7.802032
      433983126 15:58:01 Epoch 903/5000, D_loss=0.043868, reward=-7.806353
      433983126 15:58:01 Epoch 904/5000, D_loss=0.045326, reward=-7.801484
      433983126 15:58:02 Epoch 905/5000, D_loss=0.045224, reward=-7.808917
      433983126 15:58:02 Epoch 906/5000, D_loss=0.046778, reward=-7.800796
      433983126 15:58:02 Epoch 907/5000, D_loss=0.045651, reward=-7.809041
      433983126 15:58:03 Epoch 908/5000, D_loss=0.044285, reward=-7.804794
      433983126 15:58:03 Epoch 909/5000, D_loss=0.043641, reward=-7.810477
      433983126 15:58:04 Epoch 910/5000, D_loss=0.046691, reward=-7.805077
      433983126 15:58:04 Epoch 911/5000, D_loss=0.046752, reward=-7.804611
      433983126 15:58:04 Epoch 912/5000, D_loss=0.046979, reward=-7.799766
      433983126 15:58:05 Epoch 913/5000, D_loss=0.044085, reward=-7.803256
      433983126 15:58:05 

hhehe


      248942384 15:58:41 MR = 4640.31396484375
MRR = 0.1773034781217575
Hit@1 = 0.0065919578114700065
Hit@3 = 0.3269611074489123
Hit@10 = 0.43342122610415296



MR = 4640.31396484375
MRR = 0.1773034781217575
Hit@1 = 0.0065919578114700065
Hit@3 = 0.3269611074489123
Hit@10 = 0.43342122610415296



      433983126 15:58:42 Epoch 1001/5000, D_loss=0.045539, reward=-7.781003
      433983126 15:58:42 Epoch 1002/5000, D_loss=0.044238, reward=-7.785508
      433983126 15:58:42 Epoch 1003/5000, D_loss=0.046808, reward=-7.782728
      433983126 15:58:43 Epoch 1004/5000, D_loss=0.046011, reward=-7.776143
      433983126 15:58:43 Epoch 1005/5000, D_loss=0.046864, reward=-7.787508
      433983126 15:58:43 Epoch 1006/5000, D_loss=0.045537, reward=-7.783315
      433983126 15:58:44 Epoch 1007/5000, D_loss=0.045893, reward=-7.781581
      433983126 15:58:44 Epoch 1008/5000, D_loss=0.046408, reward=-7.778918
      433983126 15:58:45 Epoch 1009/5000, D_loss=0.046089, reward=-7.778867
      433983126 15:58:45 Epoch 1010/5000, D_loss=0.045345, reward=-7.779580
      433983126 15:58:45 Epoch 1011/5000, D_loss=0.045479, reward=-7.782218
      433983126 15:58:46 Epoch 1012/5000, D_loss=0.045468, reward=-7.784847
      433983126 15:58:46 Epoch 1013/5000, D_loss=0.047425, reward=-7.779617
      433983

hhehe


      248942384 15:59:23 MR = 4651.92626953125
MRR = 0.17738938331604004
Hit@1 = 0.005932762030323006
Hit@3 = 0.3269611074489123
Hit@10 = 0.4332564271588662



MR = 4651.92626953125
MRR = 0.17738938331604004
Hit@1 = 0.005932762030323006
Hit@3 = 0.3269611074489123
Hit@10 = 0.4332564271588662



      433983126 15:59:23 Epoch 1101/5000, D_loss=0.045773, reward=-7.758100
      433983126 15:59:23 Epoch 1102/5000, D_loss=0.046090, reward=-7.763723
      433983126 15:59:24 Epoch 1103/5000, D_loss=0.044957, reward=-7.757276
      433983126 15:59:24 Epoch 1104/5000, D_loss=0.047444, reward=-7.752429
      433983126 15:59:25 Epoch 1105/5000, D_loss=0.046432, reward=-7.758199
      433983126 15:59:25 Epoch 1106/5000, D_loss=0.045863, reward=-7.759752
      433983126 15:59:25 Epoch 1107/5000, D_loss=0.045206, reward=-7.761202
      433983126 15:59:26 Epoch 1108/5000, D_loss=0.044578, reward=-7.754630
      433983126 15:59:26 Epoch 1109/5000, D_loss=0.046476, reward=-7.752301
      433983126 15:59:26 Epoch 1110/5000, D_loss=0.045462, reward=-7.758770
      433983126 15:59:27 Epoch 1111/5000, D_loss=0.045785, reward=-7.752745
      433983126 15:59:27 Epoch 1112/5000, D_loss=0.045467, reward=-7.753915
      433983126 15:59:28 Epoch 1113/5000, D_loss=0.047828, reward=-7.750113
      433983

hhehe


      248942384 16:00:03 MR = 4628.58837890625
MRR = 0.1757526844739914
Hit@1 = 0.005108767303889255
Hit@3 = 0.3231707317073171
Hit@10 = 0.4312788398154252



MR = 4628.58837890625
MRR = 0.1757526844739914
Hit@1 = 0.005108767303889255
Hit@3 = 0.3231707317073171
Hit@10 = 0.4312788398154252



      433983126 16:00:04 Epoch 1201/5000, D_loss=0.045141, reward=-7.737478
      433983126 16:00:04 Epoch 1202/5000, D_loss=0.046819, reward=-7.731623
      433983126 16:00:04 Epoch 1203/5000, D_loss=0.047127, reward=-7.730025
      433983126 16:00:05 Epoch 1204/5000, D_loss=0.046370, reward=-7.730460
      433983126 16:00:05 Epoch 1205/5000, D_loss=0.047656, reward=-7.726569
      433983126 16:00:06 Epoch 1206/5000, D_loss=0.048688, reward=-7.726820
      433983126 16:00:06 Epoch 1207/5000, D_loss=0.046540, reward=-7.728196
      433983126 16:00:06 Epoch 1208/5000, D_loss=0.047504, reward=-7.726711
      433983126 16:00:07 Epoch 1209/5000, D_loss=0.046738, reward=-7.727208
      433983126 16:00:07 Epoch 1210/5000, D_loss=0.047708, reward=-7.721283
      433983126 16:00:07 Epoch 1211/5000, D_loss=0.048069, reward=-7.720345
      433983126 16:00:08 Epoch 1212/5000, D_loss=0.047706, reward=-7.723946
      433983126 16:00:08 Epoch 1213/5000, D_loss=0.047445, reward=-7.728153
      433983

hhehe


      248942384 16:00:45 MR = 4626.69873046875
MRR = 0.17784765362739563
Hit@1 = 0.005767963085036256
Hit@3 = 0.32679630850362557
Hit@10 = 0.43243243243243246



MR = 4626.69873046875
MRR = 0.17784765362739563
Hit@1 = 0.005767963085036256
Hit@3 = 0.32679630850362557
Hit@10 = 0.43243243243243246



      433983126 16:00:45 Epoch 1301/5000, D_loss=0.047460, reward=-7.703501
      433983126 16:00:45 Epoch 1302/5000, D_loss=0.047297, reward=-7.698467
      433983126 16:00:46 Epoch 1303/5000, D_loss=0.047584, reward=-7.697904
      433983126 16:00:46 Epoch 1304/5000, D_loss=0.047846, reward=-7.699614
      433983126 16:00:46 Epoch 1305/5000, D_loss=0.048169, reward=-7.695851
      433983126 16:00:47 Epoch 1306/5000, D_loss=0.047025, reward=-7.700907
      433983126 16:00:47 Epoch 1307/5000, D_loss=0.047089, reward=-7.707117
      433983126 16:00:48 Epoch 1308/5000, D_loss=0.048719, reward=-7.703499
      433983126 16:00:48 Epoch 1309/5000, D_loss=0.048462, reward=-7.699844
      433983126 16:00:48 Epoch 1310/5000, D_loss=0.046847, reward=-7.699240
      433983126 16:00:49 Epoch 1311/5000, D_loss=0.047271, reward=-7.699429
      433983126 16:00:49 Epoch 1312/5000, D_loss=0.048194, reward=-7.697919
      433983126 16:00:49 Epoch 1313/5000, D_loss=0.049502, reward=-7.695900
      433983

hhehe


      248942384 16:01:26 MR = 4626.4912109375
MRR = 0.17746135592460632
Hit@1 = 0.005108767303889255
Hit@3 = 0.32778510217534607
Hit@10 = 0.43441001977587346



MR = 4626.4912109375
MRR = 0.17746135592460632
Hit@1 = 0.005108767303889255
Hit@3 = 0.32778510217534607
Hit@10 = 0.43441001977587346



      433983126 16:01:27 Epoch 1401/5000, D_loss=0.049672, reward=-7.666975
      433983126 16:01:27 Epoch 1402/5000, D_loss=0.048225, reward=-7.671791
      433983126 16:01:27 Epoch 1403/5000, D_loss=0.047758, reward=-7.676845
      433983126 16:01:28 Epoch 1404/5000, D_loss=0.047675, reward=-7.671540
      433983126 16:01:28 Epoch 1405/5000, D_loss=0.047337, reward=-7.673842
      433983126 16:01:29 Epoch 1406/5000, D_loss=0.048340, reward=-7.669434
      433983126 16:01:29 Epoch 1407/5000, D_loss=0.049533, reward=-7.671974
      433983126 16:01:29 Epoch 1408/5000, D_loss=0.049535, reward=-7.673113
      433983126 16:01:30 Epoch 1409/5000, D_loss=0.047255, reward=-7.669598
      433983126 16:01:30 Epoch 1410/5000, D_loss=0.049078, reward=-7.667931
      433983126 16:01:30 Epoch 1411/5000, D_loss=0.049930, reward=-7.673573
      433983126 16:01:31 Epoch 1412/5000, D_loss=0.048409, reward=-7.673744
      433983126 16:01:31 Epoch 1413/5000, D_loss=0.050835, reward=-7.671271
      433983

hhehe


      248942384 16:02:08 MR = 4599.9267578125
MRR = 0.18133024871349335
Hit@1 = 0.008404746209624258
Hit@3 = 0.3314106789716546
Hit@10 = 0.4372116018457482



MR = 4599.9267578125
MRR = 0.18133024871349335
Hit@1 = 0.008404746209624258
Hit@3 = 0.3314106789716546
Hit@10 = 0.4372116018457482



      433983126 16:02:08 Epoch 1501/5000, D_loss=0.050165, reward=-7.647502
      433983126 16:02:08 Epoch 1502/5000, D_loss=0.050491, reward=-7.646110
      433983126 16:02:09 Epoch 1503/5000, D_loss=0.048644, reward=-7.653928
      433983126 16:02:09 Epoch 1504/5000, D_loss=0.047805, reward=-7.647807
      433983126 16:02:10 Epoch 1505/5000, D_loss=0.048258, reward=-7.652993
      433983126 16:02:10 Epoch 1506/5000, D_loss=0.049238, reward=-7.643524
      433983126 16:02:10 Epoch 1507/5000, D_loss=0.050686, reward=-7.646611
      433983126 16:02:11 Epoch 1508/5000, D_loss=0.049254, reward=-7.649440
      433983126 16:02:11 Epoch 1509/5000, D_loss=0.051681, reward=-7.648094
      433983126 16:02:11 Epoch 1510/5000, D_loss=0.049092, reward=-7.654383
      433983126 16:02:12 Epoch 1511/5000, D_loss=0.050252, reward=-7.640727
      433983126 16:02:12 Epoch 1512/5000, D_loss=0.049116, reward=-7.641704
      433983126 16:02:13 Epoch 1513/5000, D_loss=0.048796, reward=-7.650681
      433983

hhehe


      248942384 16:02:49 MR = 4501.2490234375
MRR = 0.18070189654827118
Hit@1 = 0.006262359920896506
Hit@3 = 0.3360250494396836
Hit@10 = 0.4362228081740277



MR = 4501.2490234375
MRR = 0.18070189654827118
Hit@1 = 0.006262359920896506
Hit@3 = 0.3360250494396836
Hit@10 = 0.4362228081740277



      433983126 16:02:49 Epoch 1601/5000, D_loss=0.048574, reward=-7.624144
      433983126 16:02:50 Epoch 1602/5000, D_loss=0.049220, reward=-7.626698
      433983126 16:02:50 Epoch 1603/5000, D_loss=0.049794, reward=-7.630246
      433983126 16:02:50 Epoch 1604/5000, D_loss=0.049144, reward=-7.617232
      433983126 16:02:51 Epoch 1605/5000, D_loss=0.048330, reward=-7.627537
      433983126 16:02:51 Epoch 1606/5000, D_loss=0.050618, reward=-7.624564
      433983126 16:02:52 Epoch 1607/5000, D_loss=0.051424, reward=-7.622464
      433983126 16:02:52 Epoch 1608/5000, D_loss=0.049650, reward=-7.623703
      433983126 16:02:52 Epoch 1609/5000, D_loss=0.048742, reward=-7.625051
      433983126 16:02:53 Epoch 1610/5000, D_loss=0.049100, reward=-7.625636
      433983126 16:02:53 Epoch 1611/5000, D_loss=0.049694, reward=-7.625449
      433983126 16:02:53 Epoch 1612/5000, D_loss=0.047545, reward=-7.627980
      433983126 16:02:54 Epoch 1613/5000, D_loss=0.049966, reward=-7.617781
      433983

hhehe


      248942384 16:03:30 MR = 4478.18408203125
MRR = 0.18287399411201477
Hit@1 = 0.008075148319050759
Hit@3 = 0.33569545154911007
Hit@10 = 0.4401779828609097



MR = 4478.18408203125
MRR = 0.18287399411201477
Hit@1 = 0.008075148319050759
Hit@3 = 0.33569545154911007
Hit@10 = 0.4401779828609097



      433983126 16:03:30 Epoch 1701/5000, D_loss=0.048348, reward=-7.609323
      433983126 16:03:31 Epoch 1702/5000, D_loss=0.050975, reward=-7.601363
      433983126 16:03:31 Epoch 1703/5000, D_loss=0.050582, reward=-7.604121
      433983126 16:03:31 Epoch 1704/5000, D_loss=0.048839, reward=-7.606769
      433983126 16:03:32 Epoch 1705/5000, D_loss=0.049708, reward=-7.598791
      433983126 16:03:32 Epoch 1706/5000, D_loss=0.050567, reward=-7.608658
      433983126 16:03:32 Epoch 1707/5000, D_loss=0.050544, reward=-7.601024
      433983126 16:03:33 Epoch 1708/5000, D_loss=0.050806, reward=-7.601025
      433983126 16:03:33 Epoch 1709/5000, D_loss=0.048570, reward=-7.600719
      433983126 16:03:34 Epoch 1710/5000, D_loss=0.050365, reward=-7.599561
      433983126 16:03:34 Epoch 1711/5000, D_loss=0.050717, reward=-7.596979
      433983126 16:03:34 Epoch 1712/5000, D_loss=0.049803, reward=-7.598984
      433983126 16:03:35 Epoch 1713/5000, D_loss=0.049088, reward=-7.598957
      433983

hhehe


      248942384 16:04:11 MR = 4543.48974609375
MRR = 0.18234692513942719
Hit@1 = 0.007910349373764008
Hit@3 = 0.3380026367831246
Hit@10 = 0.43852999340804216



MR = 4543.48974609375
MRR = 0.18234692513942719
Hit@1 = 0.007910349373764008
Hit@3 = 0.3380026367831246
Hit@10 = 0.43852999340804216



      433983126 16:04:12 Epoch 1801/5000, D_loss=0.051725, reward=-7.577101
      433983126 16:04:12 Epoch 1802/5000, D_loss=0.049891, reward=-7.578153
      433983126 16:04:12 Epoch 1803/5000, D_loss=0.048826, reward=-7.578783
      433983126 16:04:13 Epoch 1804/5000, D_loss=0.048490, reward=-7.577110
      433983126 16:04:13 Epoch 1805/5000, D_loss=0.049611, reward=-7.578300
      433983126 16:04:14 Epoch 1806/5000, D_loss=0.049796, reward=-7.581692
      433983126 16:04:14 Epoch 1807/5000, D_loss=0.050209, reward=-7.577842
      433983126 16:04:14 Epoch 1808/5000, D_loss=0.049460, reward=-7.579483
      433983126 16:04:15 Epoch 1809/5000, D_loss=0.049332, reward=-7.579086
      433983126 16:04:15 Epoch 1810/5000, D_loss=0.052811, reward=-7.569573
      433983126 16:04:15 Epoch 1811/5000, D_loss=0.050028, reward=-7.581668
      433983126 16:04:16 Epoch 1812/5000, D_loss=0.050191, reward=-7.577880
      433983126 16:04:16 Epoch 1813/5000, D_loss=0.050578, reward=-7.575333
      433983

hhehe


      248942384 16:04:53 MR = 4395.1435546875
MRR = 0.18395082652568817
Hit@1 = 0.008075148319050759
Hit@3 = 0.34278180619644033
Hit@10 = 0.4398483849703362



MR = 4395.1435546875
MRR = 0.18395082652568817
Hit@1 = 0.008075148319050759
Hit@3 = 0.34278180619644033
Hit@10 = 0.4398483849703362



      433983126 16:04:53 Epoch 1901/5000, D_loss=0.049108, reward=-7.553177
      433983126 16:04:54 Epoch 1902/5000, D_loss=0.050863, reward=-7.561576
      433983126 16:04:54 Epoch 1903/5000, D_loss=0.050645, reward=-7.561898
      433983126 16:04:54 Epoch 1904/5000, D_loss=0.049056, reward=-7.554152
      433983126 16:04:55 Epoch 1905/5000, D_loss=0.050119, reward=-7.558410
      433983126 16:04:55 Epoch 1906/5000, D_loss=0.051800, reward=-7.556155
      433983126 16:04:56 Epoch 1907/5000, D_loss=0.049556, reward=-7.562667
      433983126 16:04:56 Epoch 1908/5000, D_loss=0.049890, reward=-7.556755
      433983126 16:04:56 Epoch 1909/5000, D_loss=0.050838, reward=-7.551769
      433983126 16:04:57 Epoch 1910/5000, D_loss=0.050325, reward=-7.554568
      433983126 16:04:57 Epoch 1911/5000, D_loss=0.051205, reward=-7.557166
      433983126 16:04:57 Epoch 1912/5000, D_loss=0.049957, reward=-7.553424
      433983126 16:04:58 Epoch 1913/5000, D_loss=0.050222, reward=-7.554236
      433983

hhehe


      248942384 16:05:34 MR = 4312.9677734375
MRR = 0.1838178038597107
Hit@1 = 0.00922874093605801
Hit@3 = 0.33569545154911007
Hit@10 = 0.4396835860250494



MR = 4312.9677734375
MRR = 0.1838178038597107
Hit@1 = 0.00922874093605801
Hit@3 = 0.33569545154911007
Hit@10 = 0.4396835860250494



      433983126 16:05:35 Epoch 2001/5000, D_loss=0.049689, reward=-7.536351
      433983126 16:05:35 Epoch 2002/5000, D_loss=0.048703, reward=-7.540701
      433983126 16:05:36 Epoch 2003/5000, D_loss=0.051491, reward=-7.536658
      433983126 16:05:36 Epoch 2004/5000, D_loss=0.051623, reward=-7.534793
      433983126 16:05:36 Epoch 2005/5000, D_loss=0.051514, reward=-7.529397
      433983126 16:05:37 Epoch 2006/5000, D_loss=0.048654, reward=-7.535520
      433983126 16:05:37 Epoch 2007/5000, D_loss=0.050887, reward=-7.529681
      433983126 16:05:37 Epoch 2008/5000, D_loss=0.049645, reward=-7.532733
      433983126 16:05:38 Epoch 2009/5000, D_loss=0.049717, reward=-7.533763
      433983126 16:05:38 Epoch 2010/5000, D_loss=0.053165, reward=-7.524553
      433983126 16:05:39 Epoch 2011/5000, D_loss=0.049726, reward=-7.532300
      433983126 16:05:39 Epoch 2012/5000, D_loss=0.052258, reward=-7.531756
      433983126 16:05:39 Epoch 2013/5000, D_loss=0.049567, reward=-7.534202
      433983

KeyboardInterrupt: 