In [17]:
# coding: utf-8
"""
KnowledgeGraph
"""
import os
import brain.config as config
from transformers import BertTokenizer
import pkuseg
import numpy as np


class KnowledgeGraph(object):
    """
    spo_files - list of Path of *.spo files, or default kg name. e.g., ['HowNet']
    """

    def __init__(self, spo_files, predicate=False, tokenizer_name= ''):
        self.predicate = predicate
        self.spo_file_paths = [config.KGS.get(f, f) for f in spo_files]
        self.lookup_table = self._create_lookup_table()
        self.segment_vocab = list(self.lookup_table.keys()) + config.NEVER_SPLIT_TAG
        self.tokenizer_name = tokenizer_name
        if not tokenizer_name:
            self.tokenizer = pkuseg.pkuseg(model_name="default", postag=False, user_dict=self.segment_vocab)
        else:
            self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        self.special_tags = set(config.NEVER_SPLIT_TAG)

    def _create_lookup_table(self):
        lookup_table = {}
        for spo_path in self.spo_file_paths:
            print("[KnowledgeGraph] Loading spo from {}".format(spo_path))
            with open(spo_path, 'r', encoding='utf-8') as f:
                for line in f:
                    try:
                        subj, pred, obje = line.strip().split("\t")    # medKG which written in english, need to be adopted here
                        subj, pred, obje = subj.replace('_',' '), pred.replace('_',' '), obje.replace('_', '')
                        subj, pred, 
                    except:
                        print("[KnowledgeGraph] Bad spo:", line)
                    if self.predicate:
                        value = pred +' ' + obje
                    else:
                        value = obje
                    if subj in lookup_table.keys():
                        lookup_table[subj].add(' '+value)
                    else:
                        lookup_table[subj] = set([value])
        return lookup_table

    def add_knowledge_with_vm(self, sent_batch, max_entities=config.MAX_ENTITIES, add_pad=True, max_length=128):
        """
        input: sent_batch - list of sentences, e.g., ["abcd", "efgh"]
        return: know_sent_batch - list of sentences with entites embedding
                position_batch - list of position index of each character.
                visible_matrix_batch - list of visible matrixs
                seg_batch - list of segment tags
        """
        if not self.tokenizer_name:
            split_sent_batch = [self.tokenizer.cut(sent) for sent in sent_batch]
        else:
            split_sent_batch = [self.tokenizer.tokenize(sent) for sent in sent_batch]
        know_sent_batch = []
        position_batch = []
        visible_matrix_batch = []
        seg_batch = []
        for split_sent in split_sent_batch:

            # create tree
            sent_tree = []
            pos_idx_tree = []
            abs_idx_tree = []
            pos_idx = -1
            abs_idx = -1
            abs_idx_src = []
            for token in split_sent:
                
                entities = list(self.lookup_table.get(token, []))[:max_entities]
                sent_tree.append((token, entities))

                if token in self.special_tags:
                    token_pos_idx = [pos_idx+1]
                    token_abs_idx = [abs_idx+1]
                else:
                    token_pos_idx = [pos_idx+i for i in range(1, len(token)+1)]
                    token_abs_idx = [abs_idx+i for i in range(1, len(token)+1)]
                abs_idx = token_abs_idx[-1]

                entities_pos_idx = []
                entities_abs_idx = []
                for ent in entities:
                    ent_pos_idx = [token_pos_idx[-1] + i for i in range(1, len(ent)+1)]
                    entities_pos_idx.append(ent_pos_idx)
                    ent_abs_idx = [abs_idx + i for i in range(1, len(ent)+1)]
                    abs_idx = ent_abs_idx[-1]
                    entities_abs_idx.append(ent_abs_idx)

                pos_idx_tree.append((token_pos_idx, entities_pos_idx))
                pos_idx = token_pos_idx[-1]
                abs_idx_tree.append((token_abs_idx, entities_abs_idx))
                abs_idx_src += token_abs_idx

            # Get know_sent and pos
            know_sent = []
            pos = []
            seg = []
            for i in range(len(sent_tree)):
                word = sent_tree[i][0]
                if word in self.special_tags:
                    know_sent += [word]
                    seg += [0]
                else:
                    add_word = list(word)
                    know_sent += add_word 
                    seg += [0] * len(add_word)
                pos += pos_idx_tree[i][0]
                for j in range(len(sent_tree[i][1])):
                    add_word = list(sent_tree[i][1][j])
                    know_sent += add_word
                    seg += [1] * len(add_word)
                    pos += list(pos_idx_tree[i][1][j])

            token_num = len(know_sent)

            # Calculate visible matrix
            visible_matrix = np.zeros((token_num, token_num))
            for item in abs_idx_tree:
                src_ids = item[0]
                for id in src_ids:
                    visible_abs_idx = abs_idx_src + [idx for ent in item[1] for idx in ent]
                    visible_matrix[id, visible_abs_idx] = 1
                for ent in item[1]:
                    for id in ent:
                        visible_abs_idx = ent + src_ids
                        visible_matrix[id, visible_abs_idx] = 1

            src_length = len(know_sent)
            if len(know_sent) < max_length:
                pad_num = max_length - src_length
                know_sent += [config.PAD_TOKEN] * pad_num
                seg += [0] * pad_num
                pos += [max_length - 1] * pad_num
                visible_matrix = np.pad(visible_matrix, ((0, pad_num), (0, pad_num)), 'constant')  # pad 0
            else:
                know_sent = know_sent[:max_length]
                seg = seg[:max_length]
                pos = pos[:max_length]
                visible_matrix = visible_matrix[:max_length, :max_length]
            
            know_sent_batch.append(know_sent)
            position_batch.append(pos)
            visible_matrix_batch.append(visible_matrix)
            seg_batch.append(seg)
        
        return know_sent_batch, position_batch, visible_matrix_batch, seg_batch


In [24]:
x = KnowledgeGraph(['./brain/kgs/kg_anatomy3kAndAtelectasis.spo'],predicate=True, tokenizer_name='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
z = KnowledgeGraph(['./brain/kgs/kg_anatomy3kAndAtelectasis.spo'],predicate=True)

[KnowledgeGraph] Loading spo from ./brain/kgs/kg_anatomy3kAndAtelectasis.spo
[KnowledgeGraph] Loading spo from ./brain/kgs/kg_anatomy3kAndAtelectasis.spo


In [19]:
x.lookup_table

{'x name': {'Predicate yname'},
 'Atelectasis': {' disease phenotype positive phenotype present 22q11.2deletionsyndrome',
  ' disease phenotype positive phenotype present Beemer-Langersyndrome',
  ' disease phenotype positive phenotype present CHANDsyndrome',
  ' disease phenotype positive phenotype present Farberlipogranulomatosis',
  ' disease phenotype positive phenotype present Hyper-IgErecurrentinfectionsyndrome1',
  ' disease phenotype positive phenotype present Rowley-Rosenbergsyndrome',
  ' disease phenotype positive phenotype present WHIMsyndrome',
  ' disease phenotype positive phenotype present Waardenburgsyndrome',
  ' disease phenotype positive phenotype present Zygomycosis',
  ' disease phenotype positive phenotype present acuteinterstitialpneumonia',
  ' disease phenotype positive phenotype present asbestosis',
  ' disease phenotype positive phenotype present bronchopulmonarydysplasia',
  ' disease phenotype positive phenotype present congenitalmerosin-deficientmusculard

In [13]:
x.tokenizer.special_tokens_map

{'unk_token': '[UNK]',
 'sep_token': '[SEP]',
 'pad_token': '[PAD]',
 'cls_token': '[CLS]',
 'mask_token': '[MASK]'}

In [20]:
text = [['this is a test'],
        ['Atelectasis is random thing'],
        ['Atelectasis is another random thing'],
        ['Atelectasis has some relationship with familialnasalacilia']]

In [28]:
[x.tokenizer.tokenize(m) for n in text for m in n]

[['this', 'is', 'a', 'test'],
 ['ate', '##lec', '##ta', '##sis', 'is', 'random', 'thing'],
 ['ate', '##lec', '##ta', '##sis', 'is', 'another', 'random', 'thing'],
 ['ate',
  '##lec',
  '##ta',
  '##sis',
  'has',
  'some',
  'relationship',
  'with',
  'familial',
  '##nas',
  '##ala',
  '##cil',
  '##ia']]

In [30]:
x.tokenizer.encode(['ate',
  '##lec',
  '##ta',
  '##sis',
  'has',
  'some',
  'relationship',
  'with',
  'familial',
  '##nas',
  '##ala',
  '##cil',
  '##ia'])

[2,
 30280,
 25419,
 3857,
 8362,
 2258,
 2673,
 3303,
 1956,
 11444,
 3676,
 8056,
 10401,
 2126,
 3]

In [31]:
x.tokenizer.decode([2,
 30280,
 25419,
 3857,
 8362,
 2258,
 2673,
 3303,
 1956,
 11444,
 3676,
 8056,
 10401,
 2126,
 3])

'[CLS] atelectasis has some relationship with familialnasalacilia [SEP]'

In [27]:
[z.tokenizer.cut(m) for n in text for m in n]

[['this', 'is', 'a', 'test'],
 ['Atelectasis', 'is', 'random', 'thing'],
 ['Atelectasis', 'is', 'another', 'random', 'thing'],
 ['Atelectasis', 'has', 'some', 'relationship', 'with', 'familialnasalacilia']]

In [21]:
y = x.add_knowledge_with_vm(text)

TypeError: expected string or bytes-like object, got 'list'

In [37]:
xs= ['ste','la fwef', 'la fwew', 'la wsst']

In [45]:

[1 if 'la' in x else 0 for x in xs].count(1)

3

In [46]:
tem_dic = {'darer': 1, 'awfw': 2, 'a': 9, 'zzzzzz': 9999, 'c': 0}

In [53]:
list_te =list(tem_dic.keys())

In [54]:
list_te.sort()

In [55]:
list_te

['a', 'awfw', 'c', 'darer', 'zzzzzz']

In [56]:
len(tem_dic)

5

In [58]:
['s' for x in [1,2,3,4]]

['s', 's', 's', 's']

In [64]:
import pandas as pd

test =  pd.read_csv('./datasets/CheXpert/impression/test.csv', sep='\t')

In [65]:
test['label_Atelectasis'].value_counts()

label_Atelectasis
0    24923
1     6524
2     1677
Name: count, dtype: int64