In [1]:
#!/usr/bin/python
# -*- coding: utf-8 -*-
# ========== python ==========
import os
from pathlib import Path
from logging import Logger
from typing import List, Dict, Tuple, Optional, Union, Callable, Final, Literal, get_args
from operator import itemgetter, attrgetter
import itertools
from IPython.display import display, clear_output

from utils.setup import setup_logger, get_device
from const.const_values import PROJECT_DIR

os.chdir(PROJECT_DIR)
logger: Logger = setup_logger(__name__, f'{PROJECT_DIR}/log/jupyter_run.log')
device = get_device(device_name='cpu', logger=logger)

In [2]:
# jupyter
import seaborn as sns
import matplotlib.pyplot as plt
# Machine learning
import numpy as np
import pandas as pd
import h5py
import optuna
# torch
import torch
from torch.nn import CrossEntropyLoss
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
# torch ignite
from ignite.engine import Engine
from ignite.handlers import Checkpoint, EpochOutputStore
# My items
from models.datasets.data_helper import MyDataHelperForStory, MyDataLoaderHelper, DefaultTokens
from models.datasets.datasets_for_sequence import StoryTriple
# My utils
from utils.setup import load_param
from utils.torch import load_model, torch_fix_seed
# main function
from run_for_KGC import main_function, fix_args

In [3]:
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 100)

In [4]:
from const.const_values import CPU, MODEL
from models.KGModel.kg_model import HEAD, RELATION, TAIL
from utils.torch_ignite import TRAINER, EVALUATOR
from const.const_values import DATASETS, DATA_HELPER, DATA_LOADERS, TRAIN_RETURNS

In [5]:
SEED: Final[int] = 42
args_path = f'{PROJECT_DIR}/models/230205/02/param.pkl'
model_path = f'{PROJECT_DIR}/models/230205/02/model.pth'

In [6]:
args = load_param(args_path)

# args.pre_train = True
args.logger = logger
# args.device = device
args.batch_size = 16
args.pre_train=False
args.init_embedding_using_bert = False
args.model_path = model_path
args.only_load_trainer_evaluator = True
args.train_anyway=True
args.non_blocking=False
args.lr_head, args.lr_tail = None, None
args.no_grad_entity_embedding=False
args.no_grad_relation_embedding=False

args = fix_args(args)

del args.optuna_file, args.device_name, args.pid, args.study_name, args.n_trials



In [7]:
args

Namespace(notebook=False, console_level='info', logfile='models/230205/02/log.log', param_file='models/230205/02/param.pkl', train_anyway=True, old_data=0, tensorboard_dir='models/230205/02/tensorboard', checkpoint_dir='models/230205/02/checkpoint/', model_path='/Users/ryoyakaneda/Documents/学校/M1Study/knowledge_graph/models/230205/02/model.pth', resume_from_checkpoint=False, resume_from_last_point=False, only_load_trainer_evaluator=True, resume_checkpoint_path=None, pre_train=False, train_valid_test=True, only_train=False, use_for_challenge100=False, use_for_challenge090=False, use_for_challenge075=False, use_title=None, do_optuna=False, story_special_num=5, relation_special_num=5, entity_special_num=5, padding_token_e=0, cls_token_e=1, mask_token_e=2, sep_token_e=3, bos_token_e=4, padding_token_r=0, cls_token_r=1, mask_token_r=2, sep_token_r=3, bos_token_r=4, padding_token_s=0, cls_token_s=1, mask_token_s=2, sep_token_s=3, bos_token_s=4, model_version='03', embedding_dim=128, entity_e

In [8]:
args.lr, args.lr_head, args.lr_relation,  args.lr_tail, args.no_grad_entity_embedding

(0.0001, 5e-05, 1e-06, 1e-06, False)

In [9]:
args.device = device
torch_fix_seed(seed=SEED)
return_dict = main_function(args, logger=logger)

model = return_dict[MODEL]

dataset_train, dataset_valid, dataset_test = return_dict[DATASETS]

triple: torch.Tensor = dataset_train.triple
data_helper: MyDataHelperForStory = return_dict[DATA_HELPER]
evaluator: Checkpoint = return_dict[TRAIN_RETURNS][EVALUATOR]

load_model(model, args.model_path, args.device)
model.eval()
evaluator = return_dict['train_returns']['evaluator']
test = return_dict['data_loaders'].test_dataloader

entities, relations = data_helper.processed_entities, data_helper.processed_relations
d_e, d_r = {e: i for i, e in enumerate(entities)}, {r: i for i, r in enumerate(relations)}

triple_df = pd.DataFrame([(entities[_t[0]], relations[_t[1]], entities[_t[2]]) for _t in triple], columns=[HEAD, RELATION, TAIL])
story_entities = triple_df[HEAD].tolist()
del triple_df, return_dict

2023-02-19 20:54:48 - INFO - run_for_KGC.py - 1029 - ----- make datahelper start. -----
2023-02-19 20:54:48 - INFO - data_helper.py - 336 - entity num: 7812
2023-02-19 20:54:48 - INFO - data_helper.py - 337 - relation num: 62
2023-02-19 20:54:48 - INFO - data_helper.py - 336 - entity num: 7812
2023-02-19 20:54:48 - INFO - data_helper.py - 337 - relation num: 62
2023-02-19 20:54:48 - INFO - data_helper.py - 633 - entity_special_dicts: {0: '<pad_e>', 1: '<cls_e>', 2: '<mask_e>', 3: '<sep_e>', 4: '<bos_e>'}
2023-02-19 20:54:48 - INFO - data_helper.py - 634 - relation_special_dicts: {0: '<pad_r>', 1: '<cls_r>', 2: '<mask_r>', 3: '<sep_r>', 4: '<bos_r>'}
2023-02-19 20:54:48 - INFO - data_helper.py - 635 - processed entity num: 7817
2023-02-19 20:54:48 - INFO - data_helper.py - 636 - processed relation num: 67
2023-02-19 20:54:48 - INFO - run_for_KGC.py - 1031 - ----- make datahelper complete. -----
2023-02-19 20:54:48 - INFO - run_for_KGC.py - 1033 - ----- make datasets start. -----
2023-02

In [10]:
with torch.no_grad():
    eos = EpochOutputStore()
    eos.attach(evaluator, 'output')
    model.eval()
    evaluator.run(test)
    metrics = evaluator.state.metrics
    head_pred, head_ans, relation_pred, relation_ans, tail_pred, tail_ans = zip(*[
        [dict_['head_pred'], dict_['head_ans'],
         dict_['relation_pred'], dict_['relation_ans'],
         dict_['tail_pred'], dict_['tail_ans']]
        for dict_ in evaluator.state.output
    ])
    clear_output()
    display(metrics)
    head_pred, head_ans, relation_pred, relation_ans, tail_pred, tail_ans = [
        torch.cat(item).clone() for item in (head_pred, head_ans, relation_pred, relation_ans, tail_pred, tail_ans)
    ]
    head_pred, relation_pred, tail_pred = [
        F.softmax(item, dim=1) for item in (head_pred, relation_pred, tail_pred)
    ]
    del evaluator, eos

{'loss': 10.60277793903162,
 'head_loss': 1.8657006845615878,
 'relation_loss': 1.6497741465521332,
 'tail_loss': 7.087303138015294,
 'head_accuracy': 0.5514659873066953,
 'relation_accuracy': 0.51275289770865,
 'tail_accuracy': 0.17034653317839157,
 'head_top1': 0.5514659873066953,
 'relation_top1': 0.51275289770865,
 'tail_top1': 0.17034653317839157,
 'head_top3': 0.9021781234170615,
 'relation_top3': 0.7911638508983642,
 'tail_top3': 0.26588897827835883,
 'head_top10': 0.9733693870862012,
 'relation_top10': 0.9677750960936802,
 'tail_top10': 0.3210273829742856}

In [37]:
tail_pred_rank_df = pd.DataFrame(tail_pred.T).rank(method='min', ascending=False)
_list = []
for i in range(len(tail_ans)):
    _list.append(tail_pred_rank_df.iloc[tail_ans[i].item(), i])
_series = pd.Series(_list)
print((1/_series).mean())

0.22686198738598584


In [11]:
def get_str_list(_triple):
    return (entities[_triple[0]], relations[_triple[1]], entities[_triple[2]])

def get_triple_and_sequence_df(_dataset, is_valid_dataset):
    _triple = _dataset.triple
    _sequences = torch.stack([_dataset[i] for i in range(len(_dataset))]) if not is_valid_dataset \
            else torch.stack([_dataset[i][0] for i in range(len(_dataset))])
    _triples_df = pd.DataFrame(_triple, columns=['head', 'relation', 'tail'])
    _all_sequences_df = pd.DataFrame(_sequences.view(-1, 3), columns=['head', 'relation', 'tail'])
    _all_sequences_df['index'] = [i for i in range(len(_dataset)) for _ in range(_dataset.max_len) ]
    if is_valid_dataset:
        _all_sequences_df['is_valid']= torch.cat([_dataset[i][1] for i in range(len(_dataset))])
    return _triples_df, _all_sequences_df

In [12]:
train_triples_df, train_all_sequences_df = get_triple_and_sequence_df(dataset_train, is_valid_dataset=False)
valid_triples_df, valid_all_sequences_df = get_triple_and_sequence_df(dataset_valid, is_valid_dataset=True)
test_triples_df, test_all_sequences_df = get_triple_and_sequence_df(dataset_test, is_valid_dataset=True)

In [None]:
train_entities = pd.concat(
    [train_triples_df['head'].value_counts(), train_triples_df['tail'].value_counts()], axis=1).fillna(0).astype(int)
train_entities['entity'] = train_entities['head'] + train_entities['tail']
display(train_entities['entity'])
display(train_triples_df['head'].value_counts())

In [20]:
valid_all_sequences_df

Unnamed: 0,head,relation,tail,index,is_valid
0,4,4,4,0,False
1,3958,6,5,0,True
2,3958,45,6,0,False
3,3958,26,6,0,False
4,3958,42,382,0,False
...,...,...,...,...,...
1651707,4,4,4,3225,False
1651708,4056,6,337,3225,False
1651709,4056,21,7239,3225,False
1651710,4056,44,437,3225,False


In [69]:
def get_num_df(_df, _index):
    _df = _df[_df['index']==_index]
    _series = _df['tail'].value_counts()
    _series.name = f'{_index}'
    return _series

tail_nums_df = pd.concat([pd.Series([1 for _ in range(len(entities))], index=list(range(len(entities))), name='tmp')]+[
    get_num_df(valid_all_sequences_df, i) for i in range(max(valid_all_sequences_df['index'])+1)
], axis=1).fillna(0).astype('int')
tail_nums_rank_df =tail_nums_df.rank(method='min', ascending=False)
del tail_nums_rank_df['tmp']

In [70]:
display(tail_nums_rank_df)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,...,3176,3177,3178,3179,3180,3181,3182,3183,3184,3185,3186,3187,3188,3189,3190,3191,3192,3193,3194,3195,3196,3197,3198,3199,3200,3201,3202,3203,3204,3205,3206,3207,3208,3209,3210,3211,3212,3213,3214,3215,3216,3217,3218,3219,3220,3221,3222,3223,3224,3225
0,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,200.0,201.0,201.0,200.0,203.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0
1,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,200.0,201.0,201.0,200.0,203.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0
2,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,200.0,201.0,201.0,200.0,203.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0
3,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,200.0,201.0,201.0,200.0,203.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0
4,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
7812,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,200.0,201.0,201.0,200.0,203.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0
7813,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,54.0,53.0,55.0,55.0,54.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0
7814,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,200.0,201.0,201.0,200.0,203.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0
7815,205.0,202.0,201.0,199.0,199.0,199.0,199.0,199.0,200.0,200.0,200.0,200.0,199.0,199.0,200.0,199.0,199.0,199.0,199.0,199.0,197.0,194.0,192.0,190.0,188.0,189.0,190.0,189.0,189.0,190.0,190.0,189.0,189.0,188.0,188.0,188.0,187.0,186.0,185.0,186.0,186.0,190.0,189.0,189.0,190.0,189.0,188.0,187.0,186.0,186.0,...,200.0,201.0,201.0,200.0,203.0,204.0,204.0,206.0,208.0,209.0,209.0,209.0,210.0,213.0,216.0,217.0,218.0,219.0,216.0,213.0,214.0,215.0,215.0,215.0,217.0,216.0,216.0,215.0,213.0,214.0,215.0,215.0,210.0,208.0,206.0,207.0,209.0,209.0,208.0,205.0,204.0,205.0,207.0,206.0,205.0,208.0,208.0,208.0,208.0,205.0


In [71]:
_list = []
for i in range(len(dataset_valid)):
    data, is_valid = dataset_valid[i]
    tails = data[is_valid][:, 2]
    for j in range(len(tails)):
        _list.append(tail_nums_rank_df.iloc[tails[j].item(), i])
_series2 = pd.Series(_list)
print((1/_series2).mean())

0.15344268594965893


In [26]:
def get_triple_ans_valid_id_preds_df()
    k = 50
    triple_ans_valid_id_preds = torch.cat(
        [torch.cat([test_dataset[i][0][test_dataset[i][1]] for i in range(len(test_dataset))]),
         torch.cat([test_dataset[i][1][test_dataset[i][1]==1]*i for i in range(len(test_dataset))])[:, None],
         torch.topk(head_pred, k, dim=1).indices.to(torch.int),
         torch.topk(relation_pred, k, dim=1).indices.to(torch.int),
         torch.topk(tail_pred, k, dim=1).indices.to(torch.int),
        ], dim=1
    )

    return pd.DataFrame(
        triple_ans_valid_id_preds,
        columns=['head', 'relation', 'tail', 'valid_index',
                 *[f'head_pred_rank{i+1}' for i in range(k)],
                 *[f'relation_pred_rank{i+1}' for i in range(k)],
                 *[f'tail_pred_rank{i+1}' for i in range(k)],
                 ]
    )

NameError: name 'test_dataset' is not defined

In [27]:
df = df.iloc[:,:(3*k+4)]
df.loc[:, f'head_is_true'] = (df['head'] == df[f'head_pred_rank1']).copy()
df.loc[:, f'relation_is_true'] = (df['relation'] == df[f'relation_pred_rank1']).copy()
df.loc[:, f'tail_is_true'] = (df['tail'] == df[f'tail_pred_rank1']).copy()
df.loc[:, f'head_is_true_in_top1'] = df[f'head_is_true'].copy()
df.loc[:, f'relation_is_true_in_top1'] = df[f'relation_is_true'].copy()
df.loc[:, f'tail_is_true_in_top1'] = df[f'tail_is_true'].copy()

for i in range(1, k):
    df.loc[:, f'head_is_true_in_top{i+1}'] = (
        df['head'] == df[f'head_pred_rank{i+1}']) | df.loc[:, f'head_is_true_in_top{i}'].copy()
    df.loc[:, f'relation_is_true_in_top{i+1}'] = (
        df['relation'] == df[f'relation_pred_rank{i+1}']) | df.loc[:, f'relation_is_true_in_top{i}'].copy()
    df.loc[:, f'tail_is_true_in_top{i+1}'] = (
        df['tail'] == df[f'tail_pred_rank{i+1}']) | df.loc[:, f'tail_is_true_in_top{i}'].copy()

df = df.reindex(columns=[
    'head', 'relation', 'tail', 'valid_index',
    *[f'head_pred_rank{i+1}' for i in range(k)],
    *[f'relation_pred_rank{i+1}' for i in range(k)],
    *[f'tail_pred_rank{i+1}' for i in range(k)],
    *[f'head_is_true_in_top{i+1}' for i in range(k)],
    *[f'relation_is_true_in_top{i+1}' for i in range(k)],
    *[f'tail_is_true_in_top{i+1}' for i in range(k)],
])
display(df)

NameError: name 'df' is not defined

In [28]:
train_1sequence_df = pd.DataFrame(dataset_train.triple, columns=['head', 'relation', 'tail'] )
predicate_counts = train_1sequence_df[train_1sequence_df['relation']==d_r['kgc:hasPredicate']]['tail'].value_counts()
predicate_counts.index = [entities[_index] for _index in predicate_counts.index]
predicate_counts.mean(), predicate_counts.median()

(3.613559322033898, 1.0)

In [29]:
def get_value_counts(_df):
    _head_value_counts = _df['head'].value_counts()
    _head_value_counts.index = [entities[_index] for _index in _head_value_counts.index]
    _relation_value_counts = _df['relation'].value_counts()
    _relation_value_counts.index = [relations[_index] for _index in _relation_value_counts.index]
    _tail_value_counts = _df['tail'].value_counts()
    _tail_value_counts.index = [entities[_index] for _index in _tail_value_counts.index]
    return _head_value_counts, _relation_value_counts, _tail_value_counts

In [30]:
true_relation_check_by_tail_is_true = pd.concat(
    [get_value_counts(df[df[f'tail_is_true_in_top{i+1}']])[1] for i in range(k)],
    axis=1
)

true_relation_check_by_tail_is_true = pd.concat(
    (true_relation_check_by_tail_is_true, get_value_counts(df)[1],
     get_value_counts(train_triples_df)[1], get_value_counts(train_all_sequences_df)[1]), axis=1
)
true_relation_check_by_tail_is_true = true_relation_check_by_tail_is_true.fillna(0).astype('int')
true_relation_check_by_tail_is_true.columns = (
    *[f'tail_is_true_in_top{i+1}' for i in range(k)],
    'all_in_test', 'train_triples', 'train_all_sequences')

true_relation_check_by_tail_is_true_percent = true_relation_check_by_tail_is_true.copy()
all_in_test_series = true_relation_check_by_tail_is_true_percent['all_in_test']
for i in range(k):
    true_relation_check_by_tail_is_true_percent[f'tail_is_true_in_top{i+1}'] = \
            true_relation_check_by_tail_is_true_percent[f'tail_is_true_in_top{i+1}']/all_in_test_series
display(true_relation_check_by_tail_is_true_percent)

NameError: name 'df' is not defined

In [31]:
to_latex_df = true_relation_check_by_tail_is_true_percent.loc[:,
              ['tail_is_true_in_top1', 'tail_is_true_in_top3', 'tail_is_true_in_top10', 'train_triples', 'train_all_sequences']]
to_latex_df.columns = ['top1', 'top3', 'top10', '訓練データに含まれる個数', '訓練系列に登場する個数']
print(to_latex_df.to_latex())

NameError: name 'true_relation_check_by_tail_is_true_percent' is not defined

In [32]:
what_value_counts = train_triples_df[train_triples_df['relation'] == d_r['kgc:what']]['tail'].value_counts()
what_value_counts.index = [entities[_index] for _index in what_value_counts.index]
what_value_counts.median(), len(what_value_counts[what_value_counts==1])

(1.0, 1779)

In [33]:
one_count = 0
for tail_id in [d_e[_index] for _index in what_value_counts[what_value_counts==1].index]:
    # print(len(train_triples_df[train_triples_df['tail']==tail_id]))
    if len(train_triples_df[train_triples_df['tail']==tail_id])==1:
        print(get_str_list(train_triples_df[train_triples_df['tail']==tail_id].iloc[0, :]))
        one_count=one_count+1
one_count

('DevilsFoot:473', 'kgc:what', 'DevilsFoot:480')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:476')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:483')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:481')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:484')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:477')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:479')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:478')
('SilverBlaze:345', 'kgc:what', 'SilverBlaze:351')
('DevilsFoot:473', 'kgc:what', 'DevilsFoot:485')
('ResidentPatient:006', 'kgc:what', 'ResidentPatient:007')
('ResidentPatient:008', 'kgc:what', 'ResidentPatient:017')
('ResidentPatient:008', 'kgc:what', 'ResidentPatient:015')
('ResidentPatient:008', 'kgc:what', 'ResidentPatient:014')
('ResidentPatient:008', 'kgc:what', 'ResidentPatient:013')
('ResidentPatient:008', 'kgc:what', 'ResidentPatient:012')
('ResidentPatient:008', 'kgc:what', 'ResidentPatient:011')
('ResidentPatient:008', 'kgc:what', 'ResidentPatient:010')
('ResidentPatient:005', 'kgc:what', 

1650

In [34]:
true_head_value_counts, true_relation_value_counts, true_tail_value_counts = get_value_counts(df[df['tail_is_true']])

NameError: name 'df' is not defined

In [35]:
false_head_value_counts, false_relation_value_counts, false_tail_value_counts = get_value_counts(df[~df['tail_is_true']])

NameError: name 'df' is not defined

In [36]:
def get_tf_value_counts(true_value_counts, false_value_counts):
    tf_value_counts = pd.concat([true_value_counts, false_value_counts], axis=1)
    tf_value_counts = tf_value_counts.fillna(0).astype('int')
    tf_value_counts.columns = ['true_counts', 'false_counts']
    tf_value_counts['true_percent'] = tf_value_counts['true_counts']/(
            tf_value_counts['true_counts']+tf_value_counts['false_counts'])
    return tf_value_counts

In [37]:
tf_relation_value_counts = get_tf_value_counts(true_relation_value_counts, false_relation_value_counts)
tf_tail_value_counts = get_tf_value_counts(true_tail_value_counts, false_tail_value_counts)

NameError: name 'true_relation_value_counts' is not defined

In [38]:
tf_relation_value_counts

NameError: name 'tf_relation_value_counts' is not defined

In [39]:
percent_all_head, percent_all_relation, percent_all_tail = np.zeros(10), np.zeros(5), np.zeros(10)

test_all_sequences_df = test_all_sequences_df[test_all_sequences_df['head']!=4]
sequence_count = test_all_sequences_df['index'].max()

for i in range(sequence_count):
    test_sequences_df = test_all_sequences_df[test_all_sequences_df['index']==i]
    only_test = test_sequences_df[test_sequences_df['is_valid']]
    only_not_test = test_sequences_df[~test_sequences_df['is_valid']]
    # print(only_test)
    # print(only_not_test)

    head_unique_series = only_not_test['head'].value_counts()
    relation_unique_series = only_not_test['relation'].value_counts()
    tail_unique_series = only_not_test['tail'].value_counts()

    # print(entities[head_unique_series.index[0]])
    percent_head = np.array([len(only_test[only_test['head'] == head_unique_series.index[i]]) for i in range(10)])/len(only_test)
    percent_head = np.array([np.sum(percent_head[:i+1]) for i in range(10)])
    percent_all_head+=percent_head
    percent_relation = np.array([len(only_test[only_test['relation'] == relation_unique_series.index[i]]) for i in range(5)])/len(only_test)
    percent_relation = np.array([np.sum(percent_relation[:i+1]) for i in range(5)])
    percent_all_relation+=percent_relation
    percent_tail = np.array([len(only_test[only_test['tail'] == tail_unique_series.index[i]]) for i in range(10)])/len(only_test)
    percent_tail = np.array([np.sum(percent_tail[:i+1]) for i in range(10)])
    percent_all_tail+=percent_tail
percent_all_head/=sequence_count
percent_all_relation/=sequence_count
percent_all_tail/=sequence_count
print(percent_all_tail)

[0.10426162 0.17228795 0.21424489 0.2495138  0.28043442 0.30109194
 0.31605248 0.32868557 0.34031899 0.35104279]


In [194]:
most_frequency_percent_list = []
num_list = []

for tensor in dataset_train:
    head_unique, head_unique_count = torch.unique(tensor[:, 0], return_counts=True)
    relation_unique, relation_unique_count = torch.unique(tensor[:, 1], return_counts=True)
    tail_unique, tail_unique_count = torch.unique(tensor[:, 2], return_counts=True)
    entity_unique, entity_unique_count = torch.unique(tensor[:, (0,2)], return_counts=True)

    head_unique, head_unique_count = [_tensor[head_unique!=4] for _tensor in (head_unique, head_unique_count)]
    relation_unique, relation_unique_count = [_tensor[relation_unique!=4] for _tensor in (relation_unique, relation_unique_count)]
    tail_unique, tail_unique_count = [_tensor[tail_unique!=4] for _tensor in (tail_unique, tail_unique_count)]
    entity_unique, entity_unique_count =  [_tensor[entity_unique!=4] for _tensor in (entity_unique, entity_unique_count)]

    head_most_frequency = torch.max(head_unique_count)
    relation_most_frequency = torch.max(relation_unique_count)
    tail_most_frequency = torch.max(tail_unique_count)
    entity_most_frequency = torch.max(entity_unique_count)

    most_frequency_percent_list.append(
        [(value/len(tensor)).item() for value in (head_most_frequency, relation_most_frequency, tail_most_frequency, entity_most_frequency) ])
    num_list.append([1/len(value) for value in (head_unique, relation_unique, tail_unique, entity_unique)])

torch.tensor(most_frequency_percent_list).mean(dim=0), torch.tensor(num_list).mean(dim=0)

AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrange:Lady_Brackenstall
AbbeyGrang

(tensor([0.0429, 0.2160, 0.0969, 0.1016]),
 tensor([0.0102, 0.0824, 0.0059, 0.0042]))

#### Note that the data presented in the previous section are part of this series below.

This is the model made by Pytorch.

In [40]:
model

KgSequenceTransformer03(
  (entity_embeddings): Embedding(7817, 768, padding_idx=0)
  (relation_embeddings): Embedding(67, 64, padding_idx=0)
  (head_maskdlm): Feedforward(
    (linear1): Linear(in_features=128, out_features=128, bias=True)
    (norm): Identity()
    (activation): GELU(approximate='none')
    (linear2): Linear(in_features=128, out_features=7817, bias=True)
  )
  (relation_maskdlm): Feedforward(
    (linear1): Linear(in_features=128, out_features=128, bias=True)
    (norm): Identity()
    (activation): GELU(approximate='none')
    (linear2): Linear(in_features=128, out_features=67, bias=True)
  )
  (tail_maskdlm): Feedforward(
    (linear1): Linear(in_features=128, out_features=128, bias=True)
    (norm): Identity()
    (activation): GELU(approximate='none')
    (linear2): Linear(in_features=128, out_features=7817, bias=True)
  )
  (pe): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): TransformerEncoder(
    (layers): ModuleList(
  

This function is used to visualize Attention. The description is skipped.

In [11]:
# This is sub functions for getting Attention.
def extract(_model, target, inputs):
    """This is sub functions for getting Attention.

    """
    features = None

    def forward_hook(_module, _inputs, _):
        nonlocal features
        x, _, _ = _inputs
        outputs = _module.forward(x, x, x, need_weights=True)[1]
        features = outputs.detach().clone()

    handle = target.register_forward_hook(forward_hook)

    _model.eval()
    _model(inputs, torch.LongTensor([[]]), torch.LongTensor([[]]), torch.LongTensor([[]]))

    handle.remove()

    return features

# This is main functions for getting Attention.
def get_attention(input_):
    """This is main functions for getting Attention.

    """
    assert len(input_) == 1
    features = extract(model, model.transformer.layers[-1].self_attn, input_)[0]
    df_attention = pd.DataFrame([[entities[h], relations[r], entities[t]]+[features[j, i].item() for j in range(len(features))] for i, (h, r, t) in enumerate(input_[0])])
    df_attention.columns=[HEAD, RELATION, TAIL] + [f'atten_from{i}' for i in range(len(df_attention.columns)-3)]
    return df_attention

def show_attention_heatmap(df_attention):
    sns.heatmap(df_attention.iloc[:,3:])
    plt.show()

In [12]:
MASK_E = DefaultTokens.MASK_E
KILL = 'word.predicate:kill'

TAKE = 'word.predicate:take'
BRING = 'word.predicate:bring'
DIE = 'word.predicate:die'
HIDE = 'word.predicate:hide'

The input to the model is the following function's "question_".
There are some last scenes and criminal scene after last scene.
Each part of criminal scene are changed into a victim or <mask>, and the MASK part is estimated.

In [13]:
bos_triple = [d_e[DefaultTokens.BOS_E], d_r[DefaultTokens.BOS_R],d_e[DefaultTokens.BOS_E]]
mask_e_id = d_e[DefaultTokens.MASK_E]
Holmes_id =d_e['AllTitle:Holmes']

def make_ranking(from_story_name, to_story_name, predicate_, whom_, subject_, why_, what_, where_, when_):
    if not (from_story_name is None and to_story_name is None):
        _start_index = story_entities.index(from_story_name)-1
        _end_index = len(story_entities) - story_entities[::-1].index(to_story_name)
    else:
        _start_index = 0
        _end_index = 0
    question_ = torch.tensor(
        [
            # [... last 80 scenes ...],
            # ...
            bos_triple,
            [mask_e_id, d_r['kgc:infoSource'],     Holmes_id      ],
            [mask_e_id, d_r['kgc:hasPredicate'],   d_e[predicate_]],
            [mask_e_id, d_r['kgc:whom'],           d_e[whom_     ]],
            [mask_e_id, d_r['kgc:subject'],        d_e[subject_  ]],
            [mask_e_id, d_r['kgc:why'],            d_e[why_      ]],
            [mask_e_id, d_r['kgc:what'],           d_e[what_     ]],
            [mask_e_id, d_r['kgc:where'],          d_e[where_    ]],
            [mask_e_id, d_r['kgc:when'],          d_e[when_     ]],
        ]
    )
    mask_ = torch.zeros_like(question_, dtype=torch.bool) # not mask all position
    mask_[1:, 0] = True                                   # where head position without bos token
    mask_[1:, 2] = True                                   # where tail position without bos token

    last_triples = triple[_start_index: _end_index]

    questions = torch.cat([last_triples, question_], dim=0).unsqueeze(0)
    masks = torch.cat([torch.zeros_like(last_triples), mask_], dim=0).to(torch.bool).transpose(1,0).unsqueeze(0)

    data_list = []
    with torch.no_grad():
        _, (story_pred, relation_pred, entity_pred) = model(questions, masks[:,0], masks[:,1], masks[:,2])
        sorted_ = torch.argsort(entity_pred, dim=1, descending=True)
        for i in range(sorted_.shape[1]):
            ans_= sorted_[:, i]
            info_source_, predicate_pred, whom_pred, subject_pred, why_pred, what_pred, where_pred, when_pred = ans_
            data_list.append([entities[predicate_pred], entities[whom_pred], entities[subject_pred], entities[why_pred], entities[what_pred], entities[where_pred], entities[when_pred]])
    df_ranking = pd.DataFrame(data_list, columns=['predicate', 'whom', 'subject', 'why', 'what', 'where', 'when'])
    df_attention = get_attention(questions)

    return df_ranking, df_attention

In [14]:
def main_func01(_title, _victim_name, criminal, predicate, _last_index, _story_len):
    from_ = f'{_title}:{_last_index-_story_len+1}'
    to_ = f'{_title}:{_last_index}'
    predicate = predicate
    victim = f'{_title}:{_victim_name}'
    criminal = f'{_title}:{criminal}'
    df_ranking, df_attention = make_ranking(
        from_, to_, predicate, victim, MASK_E, MASK_E, MASK_E, MASK_E, MASK_E)
    df_ranking.index.name='ranking'

    pred_rank = df_ranking.index[df_ranking['subject']==criminal].tolist()
    pred_rank = pred_rank[0] if len(pred_rank)==1 else -1
    logger.info(f"The pred ranking about {criminal} is {pred_rank}")
    # display(df_ranking.iloc[:max(20, pred_rank)])
    len_ = len(df_attention)
    for i in range(len_-10, len_):
        print(f"index={i}, triple={df_attention.iloc[i,:3].tolist()}, attention list")
        display(df_attention.sort_values(f'atten_from{i}', ascending=False).iloc[:,[0,1,2,3+i]],)
    return df_ranking, df_attention

def check_killer(_title, _victim_name, _killer_name, _last_index, _story_len):
    return main_func01(_title, _victim_name, _killer_name, KILL, _last_index, _story_len)

# Estimate Criminals

### SpeckledBand(まだらの紐)
Who killed Julia? (criminal & explanation)
被害者: Julia
犯人: Roylott
犯行に用いたもの: snake
犯行動機: 母の相続財産を独占したい

### Input sequence is like this.


|     head     | relation  |            tail            |
|:------------:|:---------:|:--------------------------:|
| SpeckledBand |  stories  |            ...             |
|     ...      |    ...    |            ...             |
|    \<bos>    |  \<bos>   |           \<bos>           |
|  \<unknown>  | predicate |            kill            |
|  \<unknown>  |   whom    |           Julia            |
|  \<unknown>  |  subject  | \<mask(Answer is Roylott)> |
|  \<unknown>  |    why    |          \<mask>           |
|  \<unknown>  |   what    |          \<mask>           |
|  \<unknown>  |    why    |          \<mask>           |

In [15]:
def do_SpeckledBand_pred():
    title = 'SpeckledBand'
    victim_name = 'Julia'
    killer_name = 'Roylott'
    last_index = 401
    story_len = 80

    df_ranking, df_attention = check_killer(title, victim_name, killer_name, last_index, story_len)
    return df_ranking, df_attention
df_ranking, df_attention = do_SpeckledBand_pred()
display(df_ranking.iloc[:20])
pass

2023-02-13 10:48:19 - INFO - 3100288621.py - 13 - The pred ranking about SpeckledBand:Roylott is 0


index=381, triple=['SpeckledBand:401', 'kgc:hasProperty', 'word.predicate:band'], attention list


Unnamed: 0,head,relation,tail,atten_from381
357,SpeckledBand:396,kgc:time,DateTime:1883-04-01T23:00:00,0.009202
352,SpeckledBand:395,kgc:time,DateTime:1883-04-01T23:00:00,0.008888
360,SpeckledBand:397,kgc:hasProperty,word.predicate:angry,0.008638
351,SpeckledBand:395,kgc:when,SpeckledBand:1883-04-01T23,0.006995
119,SpeckledBand:346,kgc:hasProperty,word.predicate:open,0.006919
362,SpeckledBand:397,kgc:time,DateTime:1883-04-01T23:00:00,0.006653
361,SpeckledBand:397,kgc:when,SpeckledBand:1883-04-01T23,0.006329
293,SpeckledBand:381,kgc:hasProperty,word.predicate:unobservable,0.006259
233,SpeckledBand:369,kgc:hasPredicate,word.predicate:say,0.005774
354,SpeckledBand:396,kgc:subject,AllTitle:Holmes,0.005731


index=382, triple=['<bos_e>', '<bos_r>', '<bos_e>'], attention list


Unnamed: 0,head,relation,tail,atten_from382
345,SpeckledBand:394,kgc:subject,SpeckledBand:Helen,0.032461
6,SpeckledBand:323,kgc:subject,AllTitle:Holmes,0.016706
243,SpeckledBand:371,kgc:subject,AllTitle:Holmes,0.01347
118,SpeckledBand:346,kgc:subject,SpeckledBand:door_of_safe,0.013386
380,SpeckledBand:401,kgc:subject,SpeckledBand:Roma,0.012122
53,SpeckledBand:332,kgc:subject,AllTitle:Holmes,0.011867
48,SpeckledBand:331,kgc:subject,AllTitle:Holmes,0.010265
280,SpeckledBand:378,kgc:subject,SpeckledBand:coroner,0.008342
165,SpeckledBand:356,kgc:subject,SpeckledBand:band,0.008201
220,SpeckledBand:367,kgc:subject,AllTitle:Holmes,0.008151


index=383, triple=['<mask_e>', 'kgc:infoSource', 'AllTitle:Holmes'], attention list


Unnamed: 0,head,relation,tail,atten_from383
357,SpeckledBand:396,kgc:time,DateTime:1883-04-01T23:00:00,0.01112
342,SpeckledBand:393,kgc:what,SpeckledBand:metallic_sound,0.009512
352,SpeckledBand:395,kgc:time,DateTime:1883-04-01T23:00:00,0.008254
362,SpeckledBand:397,kgc:time,DateTime:1883-04-01T23:00:00,0.008114
127,SpeckledBand:347,kgc:time,DateTime:1883-04-02T04:00:00,0.007368
338,SpeckledBand:392,kgc:what,SpeckledBand:safe,0.007311
350,SpeckledBand:395,kgc:what,SpeckledBand:sound_of_snake,0.006885
121,SpeckledBand:346,kgc:time,DateTime:1883-04-02T04:00:00,0.006881
347,SpeckledBand:394,kgc:what,SpeckledBand:metallic_sound,0.006587
235,SpeckledBand:369,kgc:when,SpeckledBand:1881-12-02T00,0.006536


index=384, triple=['<mask_e>', 'kgc:hasPredicate', 'word.predicate:kill'], attention list


Unnamed: 0,head,relation,tail,atten_from384
163,SpeckledBand:355,kgc:time,DateTime:1883-04-02T04:00:00,0.007743
169,SpeckledBand:356,kgc:time,DateTime:1883-04-02T04:00:00,0.006893
110,SpeckledBand:344,kgc:subject,SpeckledBand:lanthanum,0.006786
132,SpeckledBand:348,kgc:time,DateTime:1883-04-02T04:00:00,0.006471
137,SpeckledBand:349,kgc:time,DateTime:1883-04-02T04:00:00,0.005794
91,SpeckledBand:339,kgc:infoSource,AllTitle:Holmes,0.005677
187,SpeckledBand:360,kgc:time,DateTime:1883-04-02T04:00:00,0.005611
156,SpeckledBand:353,kgc:time,DateTime:1883-04-02T04:00:00,0.005598
105,SpeckledBand:343,kgc:subject,SpeckledBand:lanthanum,0.005535
167,SpeckledBand:356,kgc:what,SpeckledBand:neck_of_Roylott,0.005501


index=385, triple=['<mask_e>', 'kgc:whom', 'SpeckledBand:Julia'], attention list


Unnamed: 0,head,relation,tail,atten_from385
31,SpeckledBand:327,kgc:what,SpeckledBand:whistle,0.007086
317,SpeckledBand:387,kgc:what,SpeckledBand:milk,0.006024
127,SpeckledBand:347,kgc:time,DateTime:1883-04-02T04:00:00,0.005735
245,SpeckledBand:371,kgc:what,SpeckledBand:VentilationHole,0.005602
85,SpeckledBand:338,kgc:infoSource,AllTitle:Holmes,0.005505
338,SpeckledBand:392,kgc:what,SpeckledBand:safe,0.005468
122,<bos_e>,<bos_r>,<bos_e>,0.005387
183,SpeckledBand:359,kgc:time,DateTime:1883-04-02T04:00:00,0.00534
117,<bos_e>,<bos_r>,<bos_e>,0.005331
104,<bos_e>,<bos_r>,<bos_e>,0.005294


index=386, triple=['<mask_e>', 'kgc:subject', '<mask_e>'], attention list


Unnamed: 0,head,relation,tail,atten_from386
342,SpeckledBand:393,kgc:what,SpeckledBand:metallic_sound,0.013825
347,SpeckledBand:394,kgc:what,SpeckledBand:metallic_sound,0.012781
123,SpeckledBand:347,kgc:subject,SpeckledBand:Roylott,0.012702
328,SpeckledBand:390,kgc:subject,SpeckledBand:Safe,0.012508
139,SpeckledBand:350,kgc:subject,SpeckledBand:dog_whip,0.012395
134,SpeckledBand:349,kgc:subject,SpeckledBand:Roylott,0.010934
329,SpeckledBand:390,kgc:subject,SpeckledBand:whip,0.009595
171,SpeckledBand:357,kgc:subject,SpeckledBand:band,0.009434
162,SpeckledBand:355,kgc:subject,SpeckledBand:Roylott,0.008797
118,SpeckledBand:346,kgc:subject,SpeckledBand:door_of_safe,0.008632


index=387, triple=['<mask_e>', 'kgc:why', '<mask_e>'], attention list


Unnamed: 0,head,relation,tail,atten_from387
348,SpeckledBand:394,kgc:when,SpeckledBand:069,0.008645
118,SpeckledBand:346,kgc:subject,SpeckledBand:door_of_safe,0.007331
380,SpeckledBand:401,kgc:subject,SpeckledBand:Roma,0.00728
345,SpeckledBand:394,kgc:subject,SpeckledBand:Helen,0.006983
130,SpeckledBand:348,kgc:what,SpeckledBand:decorative_wear,0.006624
131,SpeckledBand:348,kgc:when,SpeckledBand:1883-04-02T04,0.006584
350,SpeckledBand:395,kgc:what,SpeckledBand:sound_of_snake,0.006549
168,SpeckledBand:356,kgc:when,SpeckledBand:1883-04-02T04,0.006506
71,SpeckledBand:335,kgc:hasProperty,word.predicate:suffering_voice,0.005862
102,SpeckledBand:341,kgc:when,SpeckledBand:1883-04-02T04,0.00574


index=388, triple=['<mask_e>', 'kgc:what', '<mask_e>'], attention list


Unnamed: 0,head,relation,tail,atten_from388
345,SpeckledBand:394,kgc:subject,SpeckledBand:Helen,0.028272
118,SpeckledBand:346,kgc:subject,SpeckledBand:door_of_safe,0.018735
380,SpeckledBand:401,kgc:subject,SpeckledBand:Roma,0.013726
165,SpeckledBand:356,kgc:subject,SpeckledBand:band,0.012708
340,SpeckledBand:393,kgc:subject,SpeckledBand:safe,0.011637
280,SpeckledBand:378,kgc:subject,SpeckledBand:coroner,0.010621
226,SpeckledBand:368,kgc:subject,SpeckledBand:Roma,0.010297
328,SpeckledBand:390,kgc:subject,SpeckledBand:Safe,0.009959
139,SpeckledBand:350,kgc:subject,SpeckledBand:dog_whip,0.009543
171,SpeckledBand:357,kgc:subject,SpeckledBand:band,0.008538


index=389, triple=['<mask_e>', 'kgc:where', '<mask_e>'], attention list


Unnamed: 0,head,relation,tail,atten_from389
171,SpeckledBand:357,kgc:subject,SpeckledBand:band,0.009687
120,SpeckledBand:346,kgc:when,SpeckledBand:1883-04-02T04,0.008614
168,SpeckledBand:356,kgc:when,SpeckledBand:1883-04-02T04,0.008527
131,SpeckledBand:348,kgc:when,SpeckledBand:1883-04-02T04,0.008503
165,SpeckledBand:356,kgc:subject,SpeckledBand:band,0.007893
160,SpeckledBand:354,kgc:when,SpeckledBand:1883-04-02T04,0.007825
102,SpeckledBand:341,kgc:when,SpeckledBand:1883-04-02T04,0.007685
115,SpeckledBand:345,kgc:when,SpeckledBand:1883-04-02T04,0.007496
359,SpeckledBand:397,kgc:subject,SpeckledBand:snake,0.007455
136,SpeckledBand:349,kgc:when,SpeckledBand:1883-04-02T04,0.007275


index=390, triple=['<mask_e>', 'kgc:when', '<mask_e>'], attention list


Unnamed: 0,head,relation,tail,atten_from390
131,SpeckledBand:348,kgc:when,SpeckledBand:1883-04-02T04,0.015035
168,SpeckledBand:356,kgc:when,SpeckledBand:1883-04-02T04,0.012354
160,SpeckledBand:354,kgc:when,SpeckledBand:1883-04-02T04,0.010873
136,SpeckledBand:349,kgc:when,SpeckledBand:1883-04-02T04,0.010552
182,SpeckledBand:359,kgc:when,SpeckledBand:1883-04-02T04,0.010026
91,SpeckledBand:339,kgc:infoSource,AllTitle:Holmes,0.00936
115,SpeckledBand:345,kgc:when,SpeckledBand:1883-04-02T04,0.008689
120,SpeckledBand:346,kgc:when,SpeckledBand:1883-04-02T04,0.008518
102,SpeckledBand:341,kgc:when,SpeckledBand:1883-04-02T04,0.008315
126,SpeckledBand:347,kgc:when,SpeckledBand:1883-04-02T04,0.008173


Unnamed: 0_level_0,predicate,whom,subject,why,what,where,when
ranking,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,word.predicate:kill,DevilsFoot:Roundhay,SpeckledBand:Roylott,<bos_e>,<bos_e>,SpeckledBand:mansion_of_Roylott,<bos_e>
1,word.predicate:return,DevilsFoot:Doctor_Richard,<bos_e>,CrookedMan:Nancy,word.predicate:go,SpeckledBand:1883-04-01T07,SpeckledBand:1883-04-01T23
2,DevilsFoot:Roundhay,AllTitle:Watson,AbbeyGrange:Jack_Croker,DevilsFoot:Roundhay,DevilsFoot:Roundhay,<bos_e>,SpeckledBand:snake
3,DateTime:1883-04-01T15:00:00,CrookedMan:Henry,ResidentPatient:Blessington,word.predicate:find,SilverBlaze:John_Straker,word.predicate:find,SpeckledBand:1881-04-01
4,SilverBlaze:Ned_Hunter,AbbeyGrange:Lady_Brackenstall,DevilsFoot:Doctor_Richard,word.predicate:notKnow,word.predicate:find,word.predicate:notExist,DancingMen:Elsie
5,word.predicate:go,ACaseOfIdentity:Sutherland,DevilsFoot:Roundhay,word.predicate:know,SilverBlaze:Ned_Hunter,SpeckledBand:Julia_s_bedroom,DateTime:1883-04-01T23:00:00
6,word.predicate:shoot,CrookedMan:police,SpeckledBand:Helen,SpeckledBand:1883-04-02T04,word.predicate:know,word.predicate:shoot,word.predicate:say
7,word.predicate:drop,word.predicate:drop,SpeckledBand:VentilationHole,word.predicate:notExist,SpeckledBand:Helen,SpeckledBand:Roylott,word.predicate:find
8,word.predicate:find,SpeckledBand:Roylott,SilverBlaze:John_Straker,AbbeyGrange:dining_room,CrookedMan:Nancy,ResidentPatient:Blessington,SpeckledBand:1883-04-02T04
9,<bos_e>,DevilsFoot:window,SilverBlaze:Wife_of_John_Straker,SilverBlaze:Ned_Hunter,word.predicate:want,AbbeyGrange:dining_room,word.predicate:see


In [19]:
to_latex_df = df_ranking.iloc[:10].loc[:, ['what', 'when', 'where']].T
to_latex_df.columns = [f'Rank {i+1}' for i in range(10)]
print(to_latex_df.T.to_latex())

\begin{tabular}{llll}
\toprule
{} &                      what &                          when &                            where \\
\midrule
Rank 1  &                   <bos\_e> &                       <bos\_e> &  SpeckledBand:mansion\_of\_Roylott \\
Rank 2  &         word.predicate:go &    SpeckledBand:1883-04-01T23 &       SpeckledBand:1883-04-01T07 \\
Rank 3  &       DevilsFoot:Roundhay &            SpeckledBand:snake &                          <bos\_e> \\
Rank 4  &  SilverBlaze:John\_Straker &       SpeckledBand:1881-04-01 &              word.predicate:find \\
Rank 5  &       word.predicate:find &              DancingMen:Elsie &          word.predicate:notExist \\
Rank 6  &    SilverBlaze:Ned\_Hunter &  DateTime:1883-04-01T23:00:00 &     SpeckledBand:Julia\_s\_bedroom \\
Rank 7  &       word.predicate:know &            word.predicate:say &             word.predicate:shoot \\
Rank 8  &        SpeckledBand:Helen &           word.predicate:find &             SpeckledBand:Roylott \\
Ra

  print(to_latex_df.T.to_latex())


In [27]:
print(df_attention.sort_values(f'atten_from386', ascending=False).iloc[:,[0,1,2,3+386]].to_latex())

\begin{tabular}{llllr}
\toprule
{} &              head &          relation &                                             tail &  atten\_from386 \\
\midrule
342 &  SpeckledBand:393 &          kgc:what &                      SpeckledBand:metallic\_sound &       0.013825 \\
347 &  SpeckledBand:394 &          kgc:what &                      SpeckledBand:metallic\_sound &       0.012781 \\
123 &  SpeckledBand:347 &       kgc:subject &                             SpeckledBand:Roylott &       0.012702 \\
328 &  SpeckledBand:390 &       kgc:subject &                                SpeckledBand:Safe &       0.012508 \\
139 &  SpeckledBand:350 &       kgc:subject &                            SpeckledBand:dog\_whip &       0.012395 \\
134 &  SpeckledBand:349 &       kgc:subject &                             SpeckledBand:Roylott &       0.010934 \\
329 &  SpeckledBand:390 &       kgc:subject &                                SpeckledBand:whip &       0.009595 \\
171 &  SpeckledBand:357 &       kgc:

  print(df_attention.sort_values(f'atten_from386', ascending=False).iloc[:,[0,1,2,3+386]].to_latex())


### DevilsFoot(悪魔の足跡１)
Who killed the victims? (criminal & explanation)
被害者: Brenda
犯人: Mortimer
犯行動機: 財産

In [None]:
def do_devil1_pred():
    title = 'DevilsFoot'
    victim_name = 'Brenda'
    killer_name = 'Mortimer'
    last_index = 489
    story_len = 80
    df_ranking, df_attention = check_killer(title, victim_name, killer_name, last_index, story_len)
    return df_ranking, df_attention

do_devil1_pred()
pass

### DevilsFoot(悪魔の足跡2)
Who killed the victims? (criminal & explanation)
被害者: Mortimer
犯人: 
犯行動機: 恋人の敵

In [None]:
def do_devil2_pred():
    title = 'DevilsFoot'
    victim_name = 'Mortimer'
    killer_name = 'Sterndale'
    last_index = 489
    story_len = 80
    df_ranking, df_attention = check_killer(title, victim_name, killer_name, last_index, story_len)
    return df_ranking, df_attention

do_devil2_pred()
pass

### AbbeyGrange(僧坊荘園)
Who killed Lord Blackenstall? (criminal & explanation)
被害者: Sir_Eustace_Brackenstall
犯人: 
犯行動機:

In [None]:
def do_AbbeyGrange_pred():
    title = 'AbbeyGrange'
    victim_name = 'Sir_Eustace_Brackenstall'
    killer_name = 'Jack_Croker'
    last_index = 414
    story_len = 80

    df_ranking, df_attention = check_killer(title, victim_name, killer_name, last_index, story_len)
    return df_ranking, df_attention

do_AbbeyGrange_pred()
pass

### 入院患者
Who killed Blessington? (criminal & explanation)
被害者: Blessington
犯人: 3人
犯行動機:

In [None]:
def do_ResidentPatient_pred():
    title = 'ResidentPatient'
    victim_name = 'Blessington'
    killer_name = ''
    last_index = 324
    story_len = 80

    df_ranking, df_attention = check_killer(title, victim_name, killer_name, last_index, story_len)
    return df_ranking, df_attention

do_ResidentPatient_pred()

### 白銀
Who took out the White Silver Blaze? (criminal & explanation) 
被害者: Silver_Blaze
犯人: 
犯行動機:

In [None]:
victim = 'SilverBlaze:Silver_Blaze'
df_ranking_SilverBlaze, df_attension_SilverBlaze = make_ranking(
    'SilverBlaze:330', 'SilverBlaze:396', BRING, MASK_E, MASK_E, MASK_E, victim, MASK_E)

display(df_ranking_SilverBlaze.iloc[:20,:])
# display(df_attension_SpeckledBand)
# ヒートマップの作成
# sns.heatmap(df_atten.iloc[:,3:].iloc[:32,:32])
len_ = len(df_attension_SilverBlaze)
for i in range(len_-20, len_):
    display(i, df_attension_SilverBlaze.iloc[i,:3].tolist())
    display(df_attension_SilverBlaze.sort_values(f'atten_from{i}', ascending=False).iloc[:20,[0,1,2,3+i]])
    print("----------")

### CrookedMan(背中の曲がった男):
Why did Barclay die?
被害者: Barclay
犯人:
犯行動機:

In [None]:
victim = 'CrookedMan:Barclay'
df_ranking_CrookedMan, df_attension_CrookedMan = make_ranking(
    f'CrookedMan:{373-80+1}', 'CrookedMan:373', DIE, MASK_E, victim, MASK_E, MASK_E, MASK_E)

display(df_ranking_CrookedMan.iloc[:20,:])
# display(df_attension_SpeckledBand)
# ヒートマップの作成
# sns.heatmap(df_atten.iloc[:,3:].iloc[:32,:32])
len_ = len(df_attension_CrookedMan)
for i in range(len_-20, len_):
    display(i, df_attension_CrookedMan.iloc[i,:3].tolist())
    display(df_attension_CrookedMan.sort_values(f'atten_from{i}', ascending=False).iloc[:20,[0,1,2,3+i]])
    print("----------")

### 花嫁失踪事件（同一事件）
hozmaの失踪の名瀬を探る
被害者: ACaseOfIdentity:Hozma
犯人: 
犯行動機: 

In [None]:
victim = 'ACaseOfIdentity:Hosmer'
df_ranking_ACaseOfIdentity, df_attension_ACaseOfIdentity = make_ranking(
    'ACaseOfIdentity:510', 'ACaseOfIdentity:578', HIDE, victim, MASK_E, MASK_E, MASK_E, MASK_E)

display(df_ranking_ACaseOfIdentity.iloc[:20,:])
# display(df_attension_SpeckledBand)
# ヒートマップの作成
# sns.heatmap(df_atten.iloc[:,3:].iloc[:32,:32])
len_ = len(df_attension_ACaseOfIdentity)
for i in range(len_-20, len_):
    display(i, df_attension_ACaseOfIdentity.iloc[i,:3].tolist())
    display(df_attension_ACaseOfIdentity.sort_values(f'atten_from{i}', ascending=False).iloc[:20,[0,1,2,3+i]])
    print("----------")
