数据分析和清洗

In [None]:
import numpy as np
import torch
import pandas as pd
from transformers import AutoTokenizer
import pickle
import pandas as pd
from entity_dataset2 import EntityDataset

# Tokenizer of the fine-tuned medical LLaMA checkpoint: right padding with the
# EOS token reused as pad token, plus one extra "[DASH]" token.
TOKENIZER_CHECKPOINT = "/home/cs/yangyuchen/yushengliao/Medical_LLM/FastChat/checkpoints/medical_llama_13b_chatv1.3/checkpoint-4974/"
tok = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
tok.padding_side = 'right'
tok.pad_token, tok.pad_token_id = tok.eos_token, tok.eos_token_id
tok.add_tokens(["[DASH]"])

# Small dataset instance inspected by the next cell
# (size=20 presumably limits the number of examples — confirm in EntityDataset).
dst = EntityDataset(
    data_path='data/kg_usmle_train_subset.json',
    kg_path='data/bios_kg_with_def.csv',
    ignore_output=True,
    size=20,
    tokenizer=tok,
)

In [None]:
from pprint import pprint
# Inspect one dataset sample: decode it, tag every token with its hard position
# type id, position id and label status, then dump the sample's attention mask
# as an aligned text table (dst_<i>.txt) and as a CSV (dst_<i>.csv).
target_index = 1
print('target_index: ', target_index)
input_ids, attention_mask, labels, hard_position_type_ids, position_ids = dst[target_index]
prompt = tok.decode(input_ids)
pprint(prompt)
print('seq_len: ', input_ids.shape[-1])

# batch_decode iterates over the ids and decodes each one on its own,
# giving one string per token (used as row/column labels below).
tokens = tok.batch_decode(input_ids)
tokens = [t.replace('\n', '\\n') for t in tokens]
# int() on position_ids as well as on hard_position_type_ids: if these are
# tensors, formatting an element directly would print "tensor(n)".
tokens = [
    f"{t}-{int(hard_position_type_ids[i])}-{int(position_ids[i])}-{'label' if int(labels[i]) != -100 else 'not'}"
    for i, t in enumerate(tokens)
]
attention_mask = [[int(v) for v in row] for row in attention_mask.numpy().tolist()]

matrix = attention_mask
row_index = tokens
col_index = tokens

# default=0 so an empty sample does not crash max().
max_row_length = max((len(row) for row in row_index), default=0)
max_col_length = max((len(col) for col in col_index), default=0)

# Write the prompt followed by a fixed-width mask table. UTF-8 is explicit
# because decoded text may contain non-ASCII characters and the platform
# default encoding is not guaranteed to handle them.
with open(f'dst_{target_index}.txt', 'w', encoding='utf-8') as f:
    f.write(prompt + '\n\n')
    # NOTE(review): this legend lists token categories, which look like the
    # hard_position_type_ids shown in the token labels rather than the mask
    # values themselves — confirm.
    f.write('attention_mask:\n0:non-entity tokens，1:entity tokens, 2:triplet tokens, 3:triplet target tokens\n\n')
    # Header line: column labels.
    f.write(" " * max_row_length + "  " + "  ".join(col.ljust(max_col_length) for col in col_index) + '\n')
    # One line per row: row label followed by that row's mask values.
    for row_label, row in zip(row_index, matrix):
        f.write(row_label.ljust(max_row_length) + "  " + "  ".join(str(val).ljust(max_col_length) for val in row) + '\n')

# Same matrix as CSV, with the token labels as both index and header.
df = pd.DataFrame(matrix, index=row_index, columns=col_index)
df.to_csv(f'dst_{target_index}.csv')

# 读取BIOS

In [None]:
from tqdm.auto import tqdm
import igraph as ig
import matplotlib.pyplot as plt
from collections import defaultdict
import marisa_trie
import re
import os


print("Loading BIOS...")
# relations = open("data/bios_v2.2_release/CoreData/Relations.txt").readlines()
uni_relations = open("data/bios_v2.2_release/CoreData/UniRelations.txt").readlines()
concept_terms = open("data/bios_v2.2_release/CoreData/ConceptTerms.txt").readlines()
definitions = open("data/bios_v2.2_release/CoreData/Definitions.txt").readlines()
semantic_types = open("data/bios_v2.2_release/CoreData/SemanticTypes.txt").readlines()

print("Building concept2term...")
concept2term = defaultdict(list)
for line in tqdm(concept_terms):
    ls = line.strip().split('|')
    concept2term[ls[0]].append(ls[2])

print("Building term2concept...")
if os.path.exists('data/term2concept.marisa'):
    term2concept = marisa_trie.BytesTrie()
    term2concept.load("data/term2concept.marisa")
else:
    keys = []
    values = []
    for line in tqdm(concept_terms):
        ls = line.strip().split('|')
        keys.append(ls[2])
        values.append(bytes(ls[0], encoding='utf-8'))
    term2concept = marisa_trie.BytesTrie(zip(keys,values))
    term2concept.save('data/term2concept.marisa')

print("Building node2idx...")
node2idx = {}
for line in tqdm(uni_relations):
    rid, head, tail, relid, rel = line.strip().split("|")
    if not node2idx.get(head):
        node2idx[head] = len(node2idx)
    if not node2idx.get(tail):
        node2idx[tail] = len(node2idx)

print("Building idx2node...")
idx2node = [x[0] for x in sorted(node2idx.items(), key=lambda x: x[1])]

print("Building edges...")
edges = {}
for rl in tqdm(uni_relations):
    ls = rl.strip().split("|")
    edges[(node2idx[ls[1]],node2idx[ls[2]])]=ls[-1]

print("Building graph...")
g = ig.Graph(n=len(node2idx), edges=list(edges.keys()), directed=True)
g.simplify()

print("Adding node descriptions...")
contains_chinese = lambda text:  bool(re.compile(r'[\u4e00-\u9fa5]').search(text))
for i in tqdm(range(len(g.vs))):
    g.vs[i]["cid"] = idx2node[i]
    g.vs[i]["name"] = concept2term[idx2node[i]][0]
    color = "blue"
    neighbors_num = len(g.neighbors(i))
    if neighbors_num > 10:
        color = "green"
    if neighbors_num > 50:
        color = "orange"
    if neighbors_num > 100:
        color = "red"
    if neighbors_num > 1000:
        color = "purple"
    g.vs[i]["color"] = color
    for term in concept2term[idx2node[i]]:
        if contains_chinese(term):
            g.vs[i]["name"] = term
            break

print("Adding edge descriptions...")
for i in tqdm(range(len(g.es))):
    s,t = g.es[i].source, g.es[i].target
    g.es[i]["name"] = edges[(s,t)]

undi_g = g.as_undirected()

In [None]:
import matplotlib.pyplot as plt
# Use large 15x15-inch figures by default for the graph plots.
plt.rcParams.update({'figure.figsize': (15, 15)})

def get_cid(str):
    """Return the BIOS concept ID for a surface term, or None if it is unknown.

    The term is lowercased before it is looked up in ``term2concept``; when the
    trie stores several values for the term, the first one is returned.

    Args:
        str: Surface term to resolve. The name shadows the builtin but is kept
            so existing keyword callers keep working.

    Returns:
        The concept ID decoded as UTF-8, or None.
    """
    # Single trie lookup: .get() already returns the list of stored byte values.
    values = term2concept.get(str.lower())
    if not values:
        return None
    return values[0].decode('utf-8')


def get_n_hop_neighbors_subgraph(str, n=2, max_degree=100, max_neighbors=5):
    """Return the subgraph of ``g`` around the entity ``str``, up to ``n`` hops out.

    The entity is resolved with ``get_cid`` and expanded breadth-first via
    ``g.neighbors``. A vertex with more than ``max_degree`` neighbours is never
    expanded, and only the first ``max_neighbors`` neighbours of an expanded
    vertex are followed. Each returned vertex carries a ``hop`` attribute (its
    BFS level, root = 0). Only edges between consecutive levels are kept.

    Args:
        str: Entity surface term. The name shadows the builtin but is kept for
            backward compatibility.
        n: Number of hops to expand.
        max_degree: Vertices with more neighbours than this are not expanded.
        max_neighbors: Maximum number of neighbours followed per expanded vertex.

    Returns:
        The igraph subgraph, or None (after printing a message) when the term
        is not found in the KG.
    """
    cid = get_cid(str)
    if not cid:
        print("Entity Not found in KG")
        return None
    root = node2idx[cid]
    print(f'cid: {cid} idx: {root} term: {g.vs[root]["name"]}')

    # BFS level of every vertex reached so far (insertion order = discovery order).
    hops = {root: 0}
    frontier = [root]
    for i in range(1, n + 1):
        reached = set()
        # Only the previous level needs expanding: earlier levels were expanded
        # already, and re-expanding them would yield the same neighbours.
        for node in frontier:
            node_neighbors = g.neighbors(node)
            if len(node_neighbors) > max_degree:
                continue  # do not expand high-degree hubs
            reached.update(node_neighbors[:max_neighbors])
        frontier = sorted(v for v in reached if v not in hops)
        for v in frontier:
            hops[v] = i
        print(f'{i}_hop_neighbors num: {len(frontier)}')
    all_neighbors = list(hops)
    print('all_neighbors: ', all_neighbors)

    # Stash the hop levels on g so g.subgraph() copies them onto the result;
    # the temporary attribute is removed from g again even if subgraph() fails.
    for v, hop in hops.items():
        g.vs[v]['hop'] = hop
    try:
        subgraph = g.subgraph(all_neighbors)
    finally:
        if 'hop' in g.vs.attributes():
            del g.vs['hop']

    # Collect first, then delete once: deleting inside a loop over subgraph.es
    # shifts the remaining edge indices and silently skips edges.
    hop_of = subgraph.vs['hop']
    to_delete = [e.index for e in subgraph.es
                 if abs(hop_of[e.source] - hop_of[e.target]) != 1]
    if to_delete:
        subgraph.delete_edges(to_delete)
    return subgraph

def get_n_str_subgraph(str_list):
    """Return the subgraph of ``g`` spanned by shortest paths between the given terms.

    Each term is resolved with ``get_cid``; unknown terms are skipped. For every
    pair of resolved vertices, one shortest path (ignoring edge direction) is
    collected. The induced subgraph on all path vertices plus the query
    vertices themselves is returned. In the returned subgraph the query
    vertices are coloured ``'white'``, while ``g`` keeps its original colours.

    Args:
        str_list: Iterable of entity surface terms.

    Returns:
        The igraph subgraph (empty if no term resolves).
    """
    cids = [get_cid(term) for term in str_list]
    node_idxs = list({node2idx[cid] for cid in cids if cid})
    print('node_idxs: ', node_idxs)

    # Paint the query vertices white on g so the colour is copied into the
    # subgraph; the original colours are restored even if something fails.
    old_colors = [g.vs[i]['color'] for i in node_idxs]
    for i in node_idxs:
        g.vs[i]['color'] = 'white'
    try:
        # One call per source vertex covers all of its later partners at once.
        paths = []
        for pos, source in enumerate(node_idxs[:-1]):
            paths.extend(g.get_shortest_paths(source, to=node_idxs[pos + 1:], mode=ig.ALL))
        # Add the query vertices explicitly so they are kept even when a pair
        # has no path or only a single term resolved.
        all_nodes = list({v for path in paths for v in path}.union(node_idxs))
        print('node num: ', len(all_nodes))
        subgraph = g.subgraph(all_nodes)
    finally:
        for i, color in zip(node_idxs, old_colors):
            g.vs[i]['color'] = color
    return subgraph


In [None]:
get_cid("SAH")

In [None]:
concept2term.get(get_cid("PNA"))

In [None]:
subgraph = get_n_str_subgraph(["periodontal ligament",
            "Hyalinization",
            "Osteoclastic activity",
            "tooth",
            "Crest bone resorption"])
ig.plot(subgraph,vertex_label=subgraph.vs["name"], edge_label=subgraph.es["name"], layout=subgraph.layout("kk"),backend='matplotlib')

In [None]:
subgraph = get_n_hop_neighbors_subgraph("tropicamide", 2)
ig.plot(subgraph,vertex_label=subgraph.vs["name"], edge_label=subgraph.es["name"], layout=subgraph.layout("kk"),backend='matplotlib')

In [None]:
# PageRank over the directed graph (damping factor 0.85), then vertex indices
# ordered from highest to lowest score.
rank = g.pagerank(directed=True, damping=0.85)
sorted_indices = list(range(len(rank)))
sorted_indices.sort(key=rank.__getitem__, reverse=True)

In [None]:
# Print the (up to) 100 highest-PageRank vertices. Slicing instead of indexing
# 0..99 keeps this working on graphs with fewer than 100 vertices.
for position, vertex_idx in enumerate(sorted_indices[:100], start=1):
    print('Rank: %d, Score: %.4f, Node: %s' % (position, rank[vertex_idx], g.vs[vertex_idx]['name']))