In [1]:
#!/usr/bin/python
# -*- coding: utf-8 -*-
# ========== python ==========
import os
from pathlib import Path
from logging import Logger
from typing import List, Dict, Tuple, Optional, Union, Callable, Final, Literal, get_args
from operator import itemgetter, attrgetter
import itertools
import collections

import numpy as np

from utils.setup import easy_logger, get_device
from const.const_values import PROJECT_DIR

# Make the project root the working directory for the rest of the notebook.
os.chdir(PROJECT_DIR)
# File logger for this notebook session; `easy_logger` is a project helper
# (presumably it configures handlers for the given path — confirm in utils.setup).
logger: Logger = easy_logger(f'{PROJECT_DIR}/log/jupyter_run.log')
# Request the CPU device explicitly via the project helper.
device = get_device(device_name='cpu', logger=logger)
# NOTE(review): this prints an absolute local path into the saved cell output.
print(PROJECT_DIR)

/Users/ryoyakaneda/Documents/学校/M1Study/knowledge_graph


In [2]:
import pandas as pd
import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
import matplotlib.pyplot as plt
# Locations of the WN18RR text-format split files, relative to the project root.
wn18rr_train_file_path = f"{PROJECT_DIR}/data/external/KGdata/WN18RR/text/train.txt"
wn18rr_valid_file_path = f"{PROJECT_DIR}/data/external/KGdata/WN18RR/text/valid.txt"
wn18rr_test_file_path = f"{PROJECT_DIR}/data/external/KGdata/WN18RR/text/test.txt"
# Column names used for the triple DataFrames throughout the notebook.
HEAD, RELATION, TAIL = 'head', 'relation', 'tail'
# Extra column holding an integer tag for the split a triple came from
# (main() assigns 1 for train, 2 for valid, 3 for test).
MODE = 'mode'
# Relation label used to select hypernym triples.
HYPERNYM = '_hypernym'

In [91]:
def _get_unique_list(series):
    item2count = pd.DataFrame(series.value_counts()).reset_index()
    item2count.columns=['item', 'count']
    item2count = item2count.sort_values(['count', 'item'], ascending=[False, True])
    return item2count['item'].to_list()

def get_entity_relation(df):
    """Collect every entity and relation occurrence in ``df`` as sorted Series.

    Args:
        df: DataFrame with HEAD, RELATION and TAIL columns.

    Returns:
        tuple: ``(entities, relations)`` where ``entities`` is the HEAD and TAIL
        columns concatenated and sorted, and ``relations`` is the sorted
        RELATION column. Duplicates are kept in both.
    """
    # Each Series is built exactly once; assigning a ``columns`` attribute to a
    # Series does not rename anything, so no such step is needed here.
    entities = pd.concat([df[HEAD], df[TAIL]]).sort_values()
    relations = df[RELATION].sort_values()
    return entities, relations

def get_entity_relation_unique_list(df):
    """Return the frequency-ranked distinct entities and relations of ``df``.

    Returns:
        tuple: ``(entity_list, relation_list)``, each ordered as produced by
        ``_get_unique_list`` (most frequent first, ties broken ascending).
    """
    entity_series, relation_series = get_entity_relation(df)
    unique_entities = _get_unique_list(entity_series)
    unique_relations = _get_unique_list(relation_series)
    return unique_entities, unique_relations

def get_id2entity_id2relation(df):
    """Build index -> name lookup tables for the entities and relations in ``df``.

    Indices follow the order returned by ``get_entity_relation_unique_list``,
    starting at 0.

    Returns:
        tuple: ``(id2entity, id2relation)`` dictionaries.
    """
    entity_names, relation_names = get_entity_relation_unique_list(df)
    id2entity = dict(enumerate(entity_names))
    id2relation = dict(enumerate(relation_names))
    return id2entity, id2relation

def get_entity2idx_relation2idx(df):
    """Build name -> index lookup tables for the entities and relations in ``df``.

    Indices follow the order returned by ``get_entity_relation_unique_list``,
    starting at 0.

    Returns:
        tuple: ``(entity2idx, relation2idx)`` dictionaries.
    """
    entity_names, relation_names = get_entity_relation_unique_list(df)
    entity2idx = dict(zip(entity_names, range(len(entity_names))))
    relation2idx = dict(zip(relation_names, range(len(relation_names))))
    return entity2idx, relation2idx

def change_entity_relation(list_, entity2idx, relation2idx, entity_indexes, relation_indexes):
    """Map selected positions of a sequence through lookup tables.

    Args:
        list_: sequence of items to convert (e.g. one triple).
        entity2idx: mapping applied to items at positions in ``entity_indexes``.
        relation2idx: mapping applied to items at positions in ``relation_indexes``.
        entity_indexes: container of positions holding entities.
        relation_indexes: container of positions holding relations.

    Returns:
        tuple: converted items; positions in neither container pass through
        unchanged. A position listed in both is treated as an entity.

    Raises:
        KeyError: if an item to convert is missing from its lookup table.
    """
    def _convert(position, value):
        # The entity check comes first, so it wins if a position is in both.
        if position in entity_indexes:
            return entity2idx[value]
        if position in relation_indexes:
            return relation2idx[value]
        return value

    return tuple(_convert(position, value) for position, value in enumerate(list_))

def get_hypernym_df(df):
    """Return only the rows of ``df`` whose RELATION column equals HYPERNYM."""
    is_hypernym = df[RELATION] == HYPERNYM
    return df.loc[is_hypernym]

def get_hypernym_dict(df, entities):
    """Map every entity to the tails of its hypernym triples.

    Args:
        df: DataFrame of triples; only rows whose relation is HYPERNYM are used.
        entities: iterable of entity names; each becomes a key, so entities
            without a hypernym triple map to an empty list.

    Returns:
        dict: ``{entity: [tail, ...]}`` with tails in row order.

    Raises:
        KeyError: if a hypernym row's head is not contained in ``entities``.
    """
    hypernym_df = get_hypernym_df(df)
    hypernym_dict = {key: [] for key in entities}
    # Zip the two needed columns directly instead of using iterrows(), which
    # builds a full Series object for every row.
    for head, tail in zip(hypernym_df[HEAD], hypernym_df[TAIL]):
        hypernym_dict[head].append(tail)
    return hypernym_dict

def get_entity2triples(df)->dict:
    """Group the triples of ``df`` by their head entity.

    Args:
        df: DataFrame with HEAD, RELATION, TAIL and MODE columns.

    Returns:
        dict: ``{entity: [(head, relation, tail, mode), ...]}`` in row order.
        Every entity appearing as head or tail is a key; entities that only
        appear as tails map to an empty list.
    """
    entities, _ = get_entity_relation(df)
    entity2triples = {key: [] for key in entities}
    # Zip the four columns directly instead of using iterrows(), which builds
    # a full Series object for every row.
    for triple in zip(df[HEAD], df[RELATION], df[TAIL], df[MODE]):
        entity2triples[triple[0]].append(triple)
    return entity2triples

def get_hypernym_list(key, list_, *, hypernym_dict):
    """Enumerate every path that starts at ``key`` and follows ``hypernym_dict``.

    The function recurses through ``hypernym_dict[key]``. A path ends when the
    current node has no successors, or when a successor is already on the path
    (cycle guard); in the latter case the path is emitted without that successor.

    Args:
        key: node to visit.
        list_: path walked so far (not mutated).
        hypernym_dict: mapping ``node -> list of successor nodes``; must
            contain every reachable node as a key.

    Returns:
        list[list]: all paths, each beginning with the contents of ``list_``
        followed by ``key``.

    Raises:
        KeyError: if a visited node is missing from ``hypernym_dict``.
    """
    path = list_ + [key]
    # Built once per call; it does not change while iterating the successors.
    visited = set(path)
    paths = []
    for successor in hypernym_dict[key]:
        if successor in visited:
            # Following this successor would loop back; stop the path here.
            paths.append(path)
        else:
            paths.extend(get_hypernym_list(successor, path, hypernym_dict=hypernym_dict))
    if not paths:
        # No successors at all: the path ends at ``key``.
        paths = [path]
    return paths

def get_to_top_list_list(df):
    """Build every hypernym path from a bottom entity upwards.

    A bottom entity is one that appears as the head of a hypernym triple but
    never as the tail of one. For each of them, all paths produced by
    ``get_hypernym_list`` are collected.

    Args:
        df: DataFrame of triples with HEAD, RELATION and TAIL columns.

    Returns:
        list[list]: entity paths, grouped by starting entity in sorted order.
    """
    hypernym_df = get_hypernym_df(df)
    bottom_entities = set(hypernym_df[HEAD]) - set(hypernym_df[TAIL])
    entities, _ = get_entity_relation(df)
    hypernym_dict = get_hypernym_dict(hypernym_df, entities)
    to_top_list_list = []
    # Iterate in sorted order: plain set iteration over strings can differ
    # between interpreter runs, which would reorder the output on every restart.
    for e in sorted(bottom_entities):
        to_top_list_list.extend(get_hypernym_list(e, [], hypernym_dict=hypernym_dict))
    return to_top_list_list

def get_to_top_triples_list(df, to_top_list_list):
    """Expand each entity path into the triples headed by its entities.

    Args:
        df: DataFrame of triples used to build the entity -> triples lookup.
        to_top_list_list: list of entity paths.

    Returns:
        list[list[tuple]]: one list per path, holding the triples of every
        entity on the path, concatenated in path order.

    Raises:
        KeyError: if a path contains an entity that does not occur in ``df``.
    """
    triples_by_entity = get_entity2triples(df)
    return [
        [triple for entity in path for triple in triples_by_entity[entity]]
        for path in to_top_list_list
    ]

def get_to_top_triples_list_limit_(to_top_triples_list, limit):
    """Split every list into consecutive chunks of at most ``limit`` items.

    Args:
        to_top_triples_list: list of lists to split.
        limit: maximum chunk length (used as the ``range`` step, so it must be non-zero).

    Returns:
        list[list]: all chunks in order; empty input lists contribute no chunk.
    """
    return [
        triples[start:start + limit]
        for triples in to_top_triples_list
        for start in range(0, len(triples), limit)
    ]

def get_to_top_tensor_list(to_top_triples_list, entity2idx, relation2idx):
    """Encode each list of triples as an int64 tensor of indices.

    Positions 0 and 2 of every triple are mapped through ``entity2idx``,
    position 1 through ``relation2idx``; any further positions are kept as-is.

    Returns:
        list[torch.Tensor]: one int64 tensor per input list.
    """
    entity_positions, relation_positions = (0, 2), (1,)
    tensors = []
    for triples in to_top_triples_list:
        encoded = [
            change_entity_relation(triple, entity2idx, relation2idx,
                                   entity_positions, relation_positions)
            for triple in triples
        ]
        tensors.append(torch.tensor(encoded, dtype=torch.int64))
    return tensors

def get_sequence_tensor(tensor_list, padding_token, max_len):
    """Stack variable-length tensors into one padded int64 tensor.

    Args:
        tensor_list: tensors whose first dimension is the sequence length
            (at most ``max_len``) and whose remaining shape matches ``padding_token``.
        padding_token: scalar or (nested) list used to fill unused positions.
        max_len: sequence length of the output.

    Returns:
        torch.Tensor: int64 tensor of shape
        ``(len(tensor_list), max_len, *shape_of(padding_token))``. Row ``i``
        holds ``tensor_list[i]`` followed by padding. An empty ``tensor_list``
        yields a tensor with a leading dimension of 0 and the same trailing shape.
    """
    pad = torch.tensor(padding_token, dtype=torch.int64)
    # Allocate the result once and let broadcasting fill in the padding,
    # instead of materialising a nested Python list of padding tokens.
    sequence_tensor = torch.empty((len(tensor_list), max_len, *pad.shape), dtype=torch.int64)
    sequence_tensor[...] = pad
    for i, _tensor in enumerate(tensor_list):
        sequence_tensor[i, :len(_tensor)] = _tensor
    return sequence_tensor

In [92]:
def func01(df, *, to_top_list_list=None, limit=32):
    """Build hypernym paths, their triples, and fixed-size chunks of those triples.

    Args:
        df: DataFrame of triples with HEAD, RELATION, TAIL and MODE columns.
        to_top_list_list: optional precomputed entity paths. When ``None`` they
            are derived from ``df``; any other value (including an empty list)
            is used as given.
        limit: maximum number of triples per chunk (default 32).

    Returns:
        tuple: ``(to_top_list_list, to_top_triples_list, chunked_triples_list)``.
    """
    # drop=True: the old index is not needed and would otherwise be added as an
    # extra ``index`` column.
    df = df.sort_values([HEAD, RELATION, TAIL]).reset_index(drop=True)
    # Explicit None check so that a caller-supplied empty list is respected.
    if to_top_list_list is None:
        to_top_list_list = get_to_top_list_list(df)
    to_top_triples_list = get_to_top_triples_list(df, to_top_list_list)
    to_top_triples_list_limit = get_to_top_triples_list_limit_(to_top_triples_list, limit)

    return to_top_list_list, to_top_triples_list, to_top_triples_list_limit

In [93]:
def main():
    """Load the WN18RR splits and build the padded train sequence tensor.

    Steps:
        1. Read the train/valid/test files and tag rows with MODE 1/2/3.
        2. Build entity/relation index tables from all three splits.
        3. Derive hypernym paths from the train split and chunk the triples
           along those paths (chunks of 32, as produced by ``func01``).
        4. Encode the train chunks as index tensors and pad them to length 32.

    Returns:
        torch.Tensor: padded int64 tensor of the train chunks
        (``n_chunks x 32 x 4``).
    """
    column_names = (HEAD, RELATION, TAIL)
    train_df = pd.read_table(wn18rr_train_file_path, header=None, names=column_names).assign(**{MODE: 1})
    valid_df = pd.read_table(wn18rr_valid_file_path, header=None, names=column_names).assign(**{MODE: 2})
    test_df = pd.read_table(wn18rr_test_file_path, header=None, names=column_names).assign(**{MODE: 3})

    entity2idx, relation2idx = get_entity2idx_relation2idx(pd.concat([train_df, valid_df, test_df]))

    # Paths come from the train split only and are reused for valid/test, so
    # the later splits add triples to the same paths.
    train_to_top_list_list, _, train_to_top_triples_list_limit32 = func01(train_df)
    # TODO: the valid/test chunk lists are built but not consumed yet.
    _, _, valid_to_top_triples_list_limit32 = func01(
        pd.concat([train_df, valid_df]), to_top_list_list=train_to_top_list_list)
    _, _, test_to_top_triples_list_limit32 = func01(
        pd.concat([train_df, valid_df, test_df]), to_top_list_list=train_to_top_list_list)

    # Must equal the chunk size used by func01 (32).
    max_len = 32
    # One padding value per encoded column (head, relation, tail, mode).
    padding_token = [0, 0, 0, 0]
    train_to_top_tensor_list = get_to_top_tensor_list(train_to_top_triples_list_limit32, entity2idx, relation2idx)
    sequence_tensor = get_sequence_tensor(train_to_top_tensor_list, padding_token, max_len)
    return sequence_tensor


In [94]:
# Run the whole pipeline defined in main().
main()

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



torch.Size([30425, 32, 4])


In [None]:
# Histogram of the number of triples per chunk for the train split.
# NOTE(review): `train_to_top_triples_list_limit32` is only assigned as a local
# inside main() in the visible cells — confirm this name exists at module level
# after Restart & Run All.
plt.hist([len(list_) for list_ in train_to_top_triples_list_limit32])

In [None]:
# Histogram of the number of triples per chunk for the train+valid data.
# NOTE(review): `valid_to_top_triples_list_limit32` is only assigned as a local
# inside main() in the visible cells — confirm this name exists at module level
# after Restart & Run All.
plt.hist([len(list_) for list_ in valid_to_top_triples_list_limit32])

In [None]:
# Compare how many path-level triple lists exist for train vs. valid.
# NOTE(review): neither name is assigned in the visible cells (main() discards
# the corresponding func01 return value as `_`) — confirm where they are defined.
len(train_to_top_triples_list), len(valid_to_top_triples_list)

In [None]:
# Peek at the first ten path-level triple lists for the valid data.
# NOTE(review): `valid_to_top_triples_list` is not assigned in the visible
# cells — confirm where it is defined.
valid_to_top_triples_list[:10]