In [1]:
#!/usr/bin/python
# -*- coding: utf-8 -*-
# ========== python ==========
import os
from pathlib import Path
from logging import Logger
from typing import List, Dict, Tuple, Optional, Union, Callable, Final, Literal, get_args
from operator import itemgetter, attrgetter
import itertools
from IPython.display import display

from utils.setup import setup_logger, get_device
from const.const_values import PROJECT_DIR

# Work from the project root (the saved arguments shown further down use paths relative to it).
os.chdir(PROJECT_DIR)
# Notebook-wide logger; the second argument is the log file path handed to the project helper.
logger: Logger = setup_logger(__name__, f'{PROJECT_DIR}/log/jupyter_run.log')
# Device object requested by name ('cpu') through the project helper.
device = get_device(device_name='cpu', logger=logger)

In [2]:
# jupyter
import seaborn as sns
import matplotlib.pyplot as plt
# Machine learning
import numpy as np
import pandas as pd
import h5py
import optuna
# torch
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter
# torch ignite
from ignite.engine import Engine
from ignite.handlers import Checkpoint, EpochOutputStore
# My items
from models.datasets.data_helper import MyDataHelper, MyDataLoaderHelper, DefaultTokens
from models.datasets.datasets_for_sequence import StoryTriple
# My utils
from utils.setup import load_param
from utils.torch import load_model, torch_fix_seed
# main function
from run_for_KGC import main_function, fix_args

In [3]:
# Raise pandas' display limits so long ranking / attention tables are not truncated.
pd.options.display.max_rows = 500
pd.options.display.max_columns = 100

In [4]:
from const.const_values import CPU, MODEL
from models.KGModel.kg_model import HEAD, RELATION, TAIL
from utils.torch_ignite import TRAINER, EVALUATOR
from const.const_values import DATASETS, DATA_HELPER, DATA_LOADERS, TRAIN_RETURNS

In [5]:
# Seed handed to torch_fix_seed before the model is built.
SEED: Final[int] = 42
# Both artefacts of the trained run live in the same directory.
_run_dir = f'{PROJECT_DIR}/models/230205/01'
args_path = f'{_run_dir}/param.pkl'
model_path = f'{_run_dir}/model.pth'

In [6]:
# Restore the argument namespace that was saved with the trained run.
# NOTE(review): param.pkl is presumably unpickled by load_param — only load files you trust.
args = load_param(args_path)

# args.pre_train = True
# Override the saved settings for this notebook session.
args.logger = logger
# args.device = device
args.batch_size = 16
args.pre_train=False
args.init_embedding_using_bert = False
args.model_path = model_path
args.only_load_trainer_evaluator = True
args.train_anyway=True
args.non_blocking=False
args.lr_head = None
args.lr_tail = None

# Project helper applied after the overrides; its return value replaces args.
args = fix_args(args)

# Drop these attributes before the namespace is displayed and passed to main_function.
del args.optuna_file, args.device_name, args.pid, args.study_name, args.n_trials



In [7]:
# Echo the final argument namespace as the cell output.
args

Namespace(notebook=False, console_level='info', logfile='models/230205/01/log.log', param_file='models/230205/01/param.pkl', train_anyway=True, old_data=0, tensorboard_dir='models/230205/01/tensorboard', checkpoint_dir='models/230205/01/checkpoint/', model_path='/Users/ryoyakaneda/Documents/学校/M1Study/knowledge_graph/models/230205/01/model.pth', resume_from_checkpoint=False, resume_from_last_point=False, only_load_trainer_evaluator=True, resume_checkpoint_path=None, pre_train=False, train_valid_test=True, only_train=False, use_for_challenge100=False, use_for_challenge090=False, use_for_challenge075=False, use_title=None, do_optuna=False, story_special_num=5, relation_special_num=5, entity_special_num=5, padding_token_e=0, cls_token_e=1, mask_token_e=2, sep_token_e=3, bos_token_e=4, padding_token_r=0, cls_token_r=1, mask_token_r=2, sep_token_r=3, bos_token_r=4, padding_token_s=0, cls_token_s=1, mask_token_s=2, sep_token_s=3, bos_token_s=4, model_version='03', embedding_dim=128, entity_e

In [8]:
# Build everything main_function returns (model, datasets, helper, ignite engines) from the restored args.
args.device = get_device(device_name='cpu', logger=logger)
torch_fix_seed(seed=SEED)
return_dict = main_function(args, logger=logger)

model = return_dict[MODEL]

# The first entry of DATASETS is used as the training dataset.
dataset_train: StoryTriple = return_dict[DATASETS][0]
# presumably an (N, 3) tensor of (head, relation, tail) ids — rows are indexed with [0], [1], [2] below
triple: torch.Tensor = dataset_train.triple
data_helper: MyDataHelper = return_dict[DATA_HELPER]
# Annotated as Engine because later cells call evaluator.run(...) and read evaluator.state.
evaluator: Engine = return_dict[TRAIN_RETURNS][EVALUATOR]

# Load the trained weights from args.model_path, then switch to evaluation mode.
load_model(model, args.model_path, args.device)
model.eval()

# id -> name sequences, and the inverse name -> id dictionaries built from them.
entities, relations = data_helper.processed_entities, data_helper.processed_relations
d_e, d_r = {e: i for i, e in enumerate(entities)}, {r: i for i, r in enumerate(relations)}

# Name-based copy of every triple; the head column is kept as a list for the index lookups in make_ranking.
triple_df = pd.DataFrame([(entities[_t[0]], relations[_t[1]], entities[_t[2]]) for _t in triple], columns=[HEAD, RELATION, TAIL])
story_entities = triple_df[HEAD].tolist()

2023-02-09 00:42:55 - INFO - run_for_KGC.py - 989 - ----- make datahelper start. -----
2023-02-09 00:42:55 - INFO - data_helper.py - 334 - entity num: 7812
2023-02-09 00:42:55 - INFO - data_helper.py - 335 - relation num: 62
2023-02-09 00:42:55 - INFO - data_helper.py - 334 - entity num: 7812
2023-02-09 00:42:55 - INFO - data_helper.py - 335 - relation num: 62
2023-02-09 00:42:55 - INFO - data_helper.py - 609 - entity_special_dicts: {0: '<pad_e>', 1: '<cls_e>', 2: '<mask_e>', 3: '<sep_e>', 4: '<bos_e>'}
2023-02-09 00:42:55 - INFO - data_helper.py - 610 - relation_special_dicts: {0: '<pad_r>', 1: '<cls_r>', 2: '<mask_r>', 3: '<sep_r>', 4: '<bos_r>'}
2023-02-09 00:42:55 - INFO - data_helper.py - 611 - processed entity num: 7817
2023-02-09 00:42:55 - INFO - data_helper.py - 612 - processed relation num: 67
2023-02-09 00:42:55 - INFO - run_for_KGC.py - 991 - ----- make datahelper complete. -----
2023-02-09 00:42:55 - INFO - run_for_KGC.py - 993 - ----- make datasets start. -----
2023-02-09

In [None]:
# Attach an ignite EpochOutputStore (registered under the name 'output') to the evaluator.
# NOTE(review): every re-run of this cell attaches one more store to the engine — confirm that is acceptable.
eos = EpochOutputStore()
eos.attach(evaluator, 'output')

with torch.no_grad():
    model.eval()
    # NOTE(review): the variable is called `valid` but it holds the *test* dataloader.
    valid = return_dict['data_loaders'].test_dataloader
    # Re-read the evaluator from return_dict (string keys here, constants in the bootstrap cell).
    evaluator = return_dict['train_returns']['evaluator']
    evaluator.run(valid)
    display(evaluator.state.metrics)
    display(evaluator.state.output)

In [None]:
# Run the evaluator on the test dataloader and show only the resulting metrics.
with torch.no_grad():
    model.eval()
    # NOTE(review): the variable is called `valid` but it holds the *test* dataloader.
    valid = return_dict['data_loaders'].test_dataloader
    evaluator = return_dict['train_returns']['evaluator']
    evaluator.run(valid)
    display(evaluator.state.metrics)

In [17]:
# Length of every run of consecutive identical heads in story_entities.
print([sum(1 for _ in _run) for _, _run in itertools.groupby(story_entities)])

[1, 6, 1, 6, 1, 6, 1, 4, 1, 5, 1, 3, 1, 5, 1, 5, 1, 2, 1, 5, 1, 2, 1, 3, 1, 3, 1, 3, 1, 5, 1, 3, 1, 2, 1, 3, 1, 4, 1, 6, 1, 9, 1, 2, 1, 7, 1, 5, 1, 3, 1, 2, 1, 4, 1, 3, 1, 2, 1, 2, 1, 3, 1, 2, 1, 2, 1, 1, 1, 2, 1, 1, 1, 3, 1, 2, 1, 5, 1, 5, 1, 4, 1, 7, 1, 4, 1, 3, 1, 1, 1, 3, 1, 4, 1, 7, 1, 6, 1, 5, 1, 4, 1, 8, 1, 4, 1, 2, 1, 6, 1, 5, 1, 6, 1, 4, 1, 4, 1, 4, 1, 3, 1, 3, 1, 4, 1, 3, 1, 2, 1, 4, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 4, 1, 4, 1, 3, 1, 1, 1, 3, 1, 4, 1, 5, 1, 3, 1, 2, 1, 3, 1, 7, 1, 6, 1, 3, 1, 4, 1, 3, 1, 4, 1, 3, 1, 4, 1, 2, 1, 3, 1, 2, 1, 1, 1, 2, 1, 2, 1, 4, 1, 4, 1, 4, 1, 5, 1, 3, 1, 3, 1, 4, 1, 3, 1, 5, 1, 2, 1, 3, 1, 3, 1, 3, 1, 5, 1, 3, 1, 3, 1, 2, 1, 2, 1, 3, 1, 1, 1, 4, 1, 5, 1, 3, 1, 4, 1, 4, 1, 2, 1, 1, 1, 2, 1, 4, 1, 3, 1, 4, 1, 2, 1, 2, 1, 3, 1, 3, 1, 3, 1, 3, 1, 2, 1, 4, 1, 5, 1, 3, 1, 2, 1, 3, 1, 2, 1, 4, 1, 4, 1, 5, 1, 4, 1, 1, 1, 2, 1, 2, 1, 3, 1, 4, 1, 3, 1, 4, 1, 5, 1, 4, 1, 3, 1, 4, 1, 3, 1, 3, 1, 3, 1, 2, 1, 5, 1, 4, 1, 3, 1, 4, 1, 4, 1, 3, 1, 

In [None]:
# Per-sample statistics over the training dataset:
#   most_frequency_percent_list -> share of the sample taken by the most frequent head / relation / tail / entity
#   num_list                    -> 1 / (number of distinct heads / relations / tails / entities)
most_frequency_percent_list = []
num_list = []

for tensor in dataset_train:
    # Distinct ids and their counts per column; "entity" pools the head and tail columns.
    head_unique, head_unique_count = torch.unique(tensor[:, 0], return_counts=True)
    relation_unique, relation_unique_count = torch.unique(tensor[:, 1], return_counts=True)
    tail_unique, tail_unique_count = torch.unique(tensor[:, 2], return_counts=True)
    entity_unique, entity_unique_count = torch.unique(tensor[:, (0,2)], return_counts=True)

    # Drop id 4, the <bos> token for both entities and relations (see the special-token dicts logged above).
    # The mask on the right-hand side is evaluated with the not-yet-filtered ids before the names are rebound.
    head_unique, head_unique_count = [_tensor[head_unique!=4] for _tensor in (head_unique, head_unique_count)]
    relation_unique, relation_unique_count = [_tensor[relation_unique!=4] for _tensor in (relation_unique, relation_unique_count)]
    tail_unique, tail_unique_count = [_tensor[tail_unique!=4] for _tensor in (tail_unique, tail_unique_count)]
    entity_unique, entity_unique_count =  [_tensor[entity_unique!=4] for _tensor in (entity_unique, entity_unique_count)]

    # NOTE(review): torch.max (and 1/len below) fail if a sample holds nothing but <bos> ids — confirm that cannot happen.
    head_most_frequency = torch.max(head_unique_count)
    relation_most_frequency = torch.max(relation_unique_count)
    tail_most_frequency = torch.max(tail_unique_count)
    entity_most_frequency = torch.max(entity_unique_count)

    # Name of the most frequent tail entity of this sample.
    print(entities[tail_unique[torch.argmax(tail_unique_count)]])
    most_frequency_percent_list.append(
        [(value/len(tensor)).item() for value in (head_most_frequency, relation_most_frequency, tail_most_frequency, entity_most_frequency) ])
    num_list.append([1/len(value) for value in (head_unique, relation_unique, tail_unique, entity_unique)])

# Cell output: the two statistics averaged over all samples.
torch.tensor(most_frequency_percent_list).mean(dim=0), torch.tensor(num_list).mean(dim=0)

#### Note that the data presented in the previous section are part of this series below.

This is the model, implemented in PyTorch.

In [None]:
# Echo the loaded model's repr as the cell output.
model

The following functions extract and visualize the attention weights. A detailed description is omitted.

In [None]:
# Helper used by get_attention to capture the weights of one attention layer.
def extract(_model, target, inputs):
    """Run ``_model`` once and return the attention weights of ``target``.

    A temporary forward hook on ``target`` re-runs that module as
    self-attention (query = key = value = the first input the layer received)
    with ``need_weights=True`` and keeps a detached copy of the weights.

    NOTE(review): the recomputation inside the hook passes no attention or
    padding mask, so the captured weights may differ from the ones used in the
    real forward pass if the model applies masks — confirm.

    Args:
        _model: Model to run; it is switched to eval mode as a side effect.
        target: Sub-module to hook. It must be called with exactly three
            positional inputs and its ``forward`` must accept ``need_weights``
            and return the weights as its second output.
        inputs: First argument passed to ``_model``; the other three arguments
            are empty ``LongTensor``s.

    Returns:
        Detached clone of the attention weights, or ``None`` if ``target`` was
        never called during the forward pass.
    """
    features = None

    def forward_hook(_module, _inputs, _output):
        nonlocal features
        # The three-way unpack also checks that the layer got (query, key, value).
        query, _key, _value = _inputs
        outputs = _module.forward(query, query, query, need_weights=True)[1]
        features = outputs.detach().clone()

    handle = target.register_forward_hook(forward_hook)
    try:
        _model.eval()
        _model(inputs, torch.LongTensor([[]]), torch.LongTensor([[]]), torch.LongTensor([[]]))
    finally:
        # Remove the hook even when the forward pass raises; otherwise it would
        # stay registered and fire (re-running attention) on every later call.
        handle.remove()

    return features

# Entry point for building the attention table of one input sequence.
def get_attention(input_):
    """Return the last layer's self-attention weights as a labelled DataFrame.

    Args:
        input_: Batch holding exactly one sequence of (head, relation, tail)
            id triples (asserted).

    Returns:
        DataFrame with one row per triple: the head / relation / tail names,
        followed by columns ``atten_from{j}``. Row ``i`` of ``atten_from{j}``
        holds ``weights[j, i]`` of the captured attention matrix.
    """
    assert len(input_) == 1
    last_self_attn = model.transformer.layers[-1].self_attn
    weights = extract(model, last_self_attn, input_)[0]

    rows = []
    for position, (h, r, t) in enumerate(input_[0]):
        labels = [entities[h], relations[r], entities[t]]
        # Column `position` of the weight matrix, as plain Python floats.
        rows.append(labels + weights[:, position].tolist())

    df_attention = pd.DataFrame(rows)
    n_attention_cols = len(df_attention.columns) - 3
    df_attention.columns = [HEAD, RELATION, TAIL] + [f'atten_from{i}' for i in range(n_attention_cols)]
    return df_attention

def show_attention_heatmap(df_attention):
    """Plot the attention columns of ``df_attention`` as a heatmap.

    The first three columns (head / relation / tail labels) are left out.
    """
    attention_only = df_attention.iloc[:, 3:]
    sns.heatmap(attention_only)
    plt.show()

In [None]:
# Entity name of the mask token; passed to make_ranking for every slot the model should predict.
MASK_E = DefaultTokens.MASK_E
# Predicate entity names passed as `predicate_` to make_ranking.
KILL = 'word.predicate:kill'

TAKE = 'word.predicate:take'
BRING = 'word.predicate:bring'
DIE = 'word.predicate:die'
HIDE = 'word.predicate:hide'

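The input to the model is built in the following function: the last scenes of the story, followed by the crime-scene rows (`question_`).
The crime-scene rows come after a `<bos>` row that separates them from the last scenes.
Each slot of the crime scene is either filled with a known value (e.g. the victim) or replaced by `<mask>`, and the model estimates the masked slots.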

In [None]:
# Separator row placed in front of the question rows: (head, relation, tail) ids of the <bos> tokens.
bos_triple = [d_e[DefaultTokens.BOS_E], d_r[DefaultTokens.BOS_R],d_e[DefaultTokens.BOS_E]]
# Entity id of the mask token; used as the head of every question row.
mask_e_id = d_e[DefaultTokens.MASK_E]
# Entity id used as the tail of the kgc:infoSource question row.
Holmes_id =d_e['AllTitle:Holmes']

def make_ranking(from_story_name, to_story_name, predicate_, whom_, subject_, why_, what_, where_):
    """Rank all entities for every slot of a question appended to a story excerpt.

    The model input is the slice of ``triple`` that covers the requested
    scenes, followed by a <bos> row and seven question rows (infoSource,
    hasPredicate, whom, subject, why, what, where). Head and tail of every
    question row are flagged in the mask tensors handed to the model.

    NOTE(review): the attention table comes from a second forward pass made by
    ``get_attention``, which does not receive these masks — confirm that the
    two passes are meant to differ.

    Args:
        from_story_name: Head name of the first scene to include (its first
            occurrence in ``story_entities``); the row just before it is
            included as well. ``None`` together with ``to_story_name`` means
            "no story context".
        to_story_name: Head name of the last scene to include (its last
            occurrence in ``story_entities``).
        predicate_, whom_, subject_, why_, what_, where_: Entity names placed
            in the tail of the corresponding question row; pass the mask token
            name for the slots to be predicted.

    Returns:
        ``(df_ranking, df_attention)``. Row ``k`` of ``df_ranking`` holds the
        entity ranked ``k`` (0 = best) for each of the six slots; the
        infoSource prediction is dropped. ``df_attention`` is the table built
        by ``get_attention`` for the same input.

    Raises:
        ValueError: If only one of the two story names is ``None``, or a name
            does not occur in ``story_entities``.
        KeyError: If a slot value is not a known entity name.
    """
    if from_story_name is None and to_story_name is None:
        # No story context: only the <bos> row and the question rows are used.
        _start_index = 0
        _end_index = 0
    elif from_story_name is None or to_story_name is None:
        # Used to fall through to list.index(None) and fail with the opaque
        # "None is not in list"; same exception type, clearer message.
        raise ValueError('from_story_name and to_story_name must both be given or both be None')
    else:
        # Start one row early to include the row preceding the first scene;
        # clamp at 0 so an index of 0 cannot wrap around to -1.
        _start_index = max(story_entities.index(from_story_name) - 1, 0)
        _end_index = len(story_entities) - story_entities[::-1].index(to_story_name)
    question_ = torch.tensor(
        [
            # [... last 80 scenes ...],
            # ...
            bos_triple,
            [mask_e_id, d_r['kgc:infoSource'],     Holmes_id      ],
            [mask_e_id, d_r['kgc:hasPredicate'],   d_e[predicate_]],
            [mask_e_id, d_r['kgc:whom'],           d_e[whom_     ]],
            [mask_e_id, d_r['kgc:subject'],        d_e[subject_  ]],
            [mask_e_id, d_r['kgc:why'],            d_e[why_      ]],
            [mask_e_id, d_r['kgc:what'],           d_e[what_     ]],
            [mask_e_id, d_r['kgc:where'],          d_e[where_    ]],
        ]
    )
    mask_ = torch.zeros_like(question_, dtype=torch.bool) # start with nothing flagged
    mask_[1:, 0] = True                                   # heads of the question rows (the <bos> row is skipped)
    mask_[1:, 2] = True                                   # tails of the question rows (the <bos> row is skipped)

    last_triples = triple[_start_index: _end_index]

    # Add a batch dimension; the masks are transposed so that index 1 selects head / relation / tail.
    questions = torch.cat([last_triples, question_], dim=0).unsqueeze(0)
    masks = torch.cat([torch.zeros_like(last_triples), mask_], dim=0).to(torch.bool).transpose(1,0).unsqueeze(0)

    data_list = []
    with torch.no_grad():
        _, (_story_pred, _relation_pred, entity_pred) = model(questions, masks[:,0], masks[:,1], masks[:,2])
        sorted_ = torch.argsort(entity_pred, dim=1, descending=True)
        # Convert once to Python ints (one inner list per rank) instead of
        # indexing the tensor element by element.
        for ans_ in sorted_.t().tolist():
            # The seven-way unpack also checks that exactly seven tails were scored.
            _info_source, predicate_pred, whom_pred, subject_pred, why_pred, what_pred, where_pred = ans_
            data_list.append([entities[predicate_pred], entities[whom_pred], entities[subject_pred], entities[why_pred], entities[what_pred], entities[where_pred]])
    df_ranking = pd.DataFrame(data_list, columns=['predicate', 'whom', 'subject', 'why', 'what', 'where'])
    df_attention = get_attention(questions)

    return df_ranking, df_attention

In [None]:
def main_func01(_title, _victim_name, criminal, predicate, _last_index, _story_len):
    """Ask "who did <predicate> to <victim>?" for one story and show the result.

    The question is posed through ``make_ranking`` with the victim in the
    ``whom`` slot and every other slot masked. The context is scenes
    ``_last_index - _story_len + 1`` .. ``_last_index`` of ``_title``.

    Args:
        _title: Story prefix used in entity names, e.g. ``'SpeckledBand'``.
        _victim_name: Victim's name without the story prefix.
        criminal: Expected culprit's name without the story prefix; looked up
            in the ``subject`` column of the ranking.
        predicate: Predicate entity name (e.g. ``KILL``).
        _last_index: Number of the last scene of the context.
        _story_len: Number of scenes in the context.

    Returns:
        ``(df_ranking, df_attention)`` as returned by ``make_ranking``, with
        the ranking index named ``'ranking'``.
    """
    from_ = f'{_title}:{_last_index-_story_len+1}'
    to_ = f'{_title}:{_last_index}'
    victim = f'{_title}:{_victim_name}'
    criminal = f'{_title}:{criminal}'
    df_ranking, df_attention = make_ranking(
        from_, to_, predicate, victim, MASK_E, MASK_E, MASK_E, MASK_E)
    df_ranking.index.name='ranking'

    # 0-based rank of the expected culprit, or -1 when there is not exactly one match.
    pred_rank = df_ranking.index[df_ranking['subject']==criminal].tolist()
    pred_rank = pred_rank[0] if len(pred_rank)==1 else -1
    logger.info(f"The pred ranking about {criminal} is {pred_rank}")
    # iloc's stop is exclusive: +1 keeps the culprit's own row visible when it
    # ranks at position 20 or lower.
    display(df_ranking.iloc[:max(20, pred_rank + 1)])
    len_ = len(df_attention)
    # Last ten rows of the input; clamped so a table shorter than ten rows does
    # not produce negative indices (and a missing 'atten_from-1' column).
    for i in range(max(len_-10, 0), len_):
        print(f"index={i}, triple={df_attention.iloc[i,:3].tolist()}, attention list")
        display(df_attention.sort_values(f'atten_from{i}', ascending=False).iloc[:,[0,1,2,3+i]])
    return df_ranking, df_attention

def check_killer(_title, _victim_name, _killer_name, _last_index, _story_len):
    """Run ``main_func01`` with the KILL predicate and return its result unchanged."""
    return main_func01(
        _title=_title,
        _victim_name=_victim_name,
        criminal=_killer_name,
        predicate=KILL,
        _last_index=_last_index,
        _story_len=_story_len,
    )

# Estimate Criminals

### SpeckledBand(まだらの紐)
Who killed Julia? (criminal & explanation)
被害者: Julia
犯人: Roylott
犯行に用いたもの: snake
犯行動機: 母の相続財産を独占したい

### Input sequence is like this.


|     head     | relation  |            tail            |
|:------------:|:---------:|:--------------------------:|
| SpeckledBand |  stories  |            ...             |
|     ...      |    ...    |            ...             |
|    \<bos>    |  \<bos>   |           \<bos>           |
|  \<unknown>  | predicate |            kill            |
|  \<unknown>  |   whom    |           Julia            |
|  \<unknown>  |  subject  | \<mask(Answer is Roylott)> |
|  \<unknown>  |    why    |          \<mask>           |
|  \<unknown>  |   what    |          \<mask>           |
|  \<unknown>  |   where   |          \<mask>           |

In [None]:
def do_SpeckledBand_pred():
    """Killer ranking for SpeckledBand: victim Julia, expected killer Roylott.

    Context: scenes 322-401 (80 scenes ending at scene 401).

    Returns:
        The ``(df_ranking, df_attention)`` pair from ``check_killer``.
    """
    ranking, attention = check_killer(
        _title='SpeckledBand',
        _victim_name='Julia',
        _killer_name='Roylott',
        _last_index=401,
        _story_len=80,
    )
    return ranking, attention
do_SpeckledBand_pred()
# Trailing statement so the returned frames are not echoed as the cell output.
pass

### DevilsFoot(悪魔の足跡１)
Who killed the victims? (criminal & explanation)
被害者: Brenda
犯人: Mortimer
犯行動機: 財産

In [None]:
def do_devil1_pred():
    """Killer ranking for DevilsFoot (case 1): victim Brenda, expected killer Mortimer.

    Context: scenes 410-489 (80 scenes ending at scene 489).

    Returns:
        The ``(df_ranking, df_attention)`` pair from ``check_killer``.
    """
    ranking, attention = check_killer(
        _title='DevilsFoot',
        _victim_name='Brenda',
        _killer_name='Mortimer',
        _last_index=489,
        _story_len=80,
    )
    return ranking, attention

do_devil1_pred()
# Trailing statement so the returned frames are not echoed as the cell output.
pass

### DevilsFoot(悪魔の足跡2)
Who killed the victims? (criminal & explanation)
被害者: Mortimer
犯人: 
犯行動機: 恋人の敵

In [None]:
def do_devil2_pred():
    """Killer ranking for DevilsFoot (case 2): victim Mortimer, expected killer Sterndale.

    Context: scenes 410-489 (80 scenes ending at scene 489).

    Returns:
        The ``(df_ranking, df_attention)`` pair from ``check_killer``.
    """
    ranking, attention = check_killer(
        _title='DevilsFoot',
        _victim_name='Mortimer',
        _killer_name='Sterndale',
        _last_index=489,
        _story_len=80,
    )
    return ranking, attention

do_devil2_pred()
# Trailing statement so the returned frames are not echoed as the cell output.
pass

### AbbeyGrange(僧坊荘園)
Who killed Sir Eustace Brackenstall? (criminal & explanation)
被害者: Sir_Eustace_Brackenstall
犯人: 
犯行動機:

In [None]:
def do_AbbeyGrange_pred():
    """Killer ranking for AbbeyGrange: victim Sir_Eustace_Brackenstall, expected killer Jack_Croker.

    Context: scenes 335-414 (80 scenes ending at scene 414).

    Returns:
        The ``(df_ranking, df_attention)`` pair from ``check_killer``.
    """
    ranking, attention = check_killer(
        _title='AbbeyGrange',
        _victim_name='Sir_Eustace_Brackenstall',
        _killer_name='Jack_Croker',
        _last_index=414,
        _story_len=80,
    )
    return ranking, attention

do_AbbeyGrange_pred()
# Trailing statement so the returned frames are not echoed as the cell output.
pass

### 入院患者
Who killed Blessington? (criminal & explanation)
被害者: Blessington
犯人: 3人
犯行動機:

In [None]:
def do_ResidentPatient_pred():
    """Killer ranking for ResidentPatient: victim Blessington, no single expected killer.

    Context: scenes 245-324 (80 scenes ending at scene 324).

    NOTE(review): the killer name is empty, so the entity looked up in the
    ranking is 'ResidentPatient:'; presumably it never matches and the logged
    rank is -1 — confirm this is intended.

    Returns:
        The ``(df_ranking, df_attention)`` pair from ``check_killer``.
    """
    ranking, attention = check_killer(
        _title='ResidentPatient',
        _victim_name='Blessington',
        _killer_name='',
        _last_index=324,
        _story_len=80,
    )
    return ranking, attention

# Last expression of the cell: the returned pair is also echoed as cell output.
do_ResidentPatient_pred()

### 白銀
Who took out the White Silver Blaze? (criminal & explanation) 
被害者: Silver_Blaze
犯人: 
犯行動機:

In [None]:
# Question: predicate BRING with Silver_Blaze in the `what` slot; every other slot is masked.
victim = 'SilverBlaze:Silver_Blaze'
df_ranking_SilverBlaze, df_attension_SilverBlaze = make_ranking(
    'SilverBlaze:330', 'SilverBlaze:396', BRING, MASK_E, MASK_E, MASK_E, victim, MASK_E)

# Top 20 candidates per slot.
display(df_ranking_SilverBlaze.iloc[:20,:])
# display(df_attension_SpeckledBand)
# Create a heatmap
# sns.heatmap(df_atten.iloc[:,3:].iloc[:32,:32])
len_ = len(df_attension_SilverBlaze)
# For each of the last 20 input rows: show the row, then the 20 rows with the largest atten_from{i} value.
for i in range(len_-20, len_):
    display(i, df_attension_SilverBlaze.iloc[i,:3].tolist())
    display(df_attension_SilverBlaze.sort_values(f'atten_from{i}', ascending=False).iloc[:20,[0,1,2,3+i]])
    print("----------")

### CrookedMan(背中の曲がった男):
Why did Barclay die?
被害者: Barclay
犯人:
犯行動機:

In [None]:
# Question: predicate DIE with Barclay in the `subject` slot; every other slot is masked.
victim = 'CrookedMan:Barclay'
df_ranking_CrookedMan, df_attension_CrookedMan = make_ranking(
    f'CrookedMan:{373-80+1}', 'CrookedMan:373', DIE, MASK_E, victim, MASK_E, MASK_E, MASK_E)

# Top 20 candidates per slot.
display(df_ranking_CrookedMan.iloc[:20,:])
# display(df_attension_SpeckledBand)
# Create a heatmap
# sns.heatmap(df_atten.iloc[:,3:].iloc[:32,:32])
len_ = len(df_attension_CrookedMan)
# For each of the last 20 input rows: show the row, then the 20 rows with the largest atten_from{i} value.
for i in range(len_-20, len_):
    display(i, df_attension_CrookedMan.iloc[i,:3].tolist())
    display(df_attension_CrookedMan.sort_values(f'atten_from{i}', ascending=False).iloc[:20,[0,1,2,3+i]])
    print("----------")

### 花嫁失踪事件（同一事件）
Hosmerの失踪の謎を探る
被害者: ACaseOfIdentity:Hosmer
犯人: 
犯行動機: 

In [None]:
# Question: predicate HIDE with Hosmer in the `whom` slot; every other slot is masked.
victim = 'ACaseOfIdentity:Hosmer'
df_ranking_ACaseOfIdentity, df_attension_ACaseOfIdentity = make_ranking(
    'ACaseOfIdentity:510', 'ACaseOfIdentity:578', HIDE, victim, MASK_E, MASK_E, MASK_E, MASK_E)

# Top 20 candidates per slot.
display(df_ranking_ACaseOfIdentity.iloc[:20,:])
# display(df_attension_SpeckledBand)
# Create a heatmap
# sns.heatmap(df_atten.iloc[:,3:].iloc[:32,:32])
len_ = len(df_attension_ACaseOfIdentity)
# For each of the last 20 input rows: show the row, then the 20 rows with the largest atten_from{i} value.
for i in range(len_-20, len_):
    display(i, df_attension_ACaseOfIdentity.iloc[i,:3].tolist())
    display(df_attension_ACaseOfIdentity.sort_values(f'atten_from{i}', ascending=False).iloc[:20,[0,1,2,3+i]])
    print("----------")
