In [15]:
import neo4j
import jsonlines
import csv
import pandas as pd
from py2neo import Graph, Node, Relationship
import re
import logging
from tqdm import tqdm
import networkx as nx
import os
os.chdir(r"A:\Desktop\COMP7600\dataset_label_fine-tune\retrieve_part")
import glob
import pandas as pd
import matplotlib.pyplot as plt
import faiss
from transformers import ( 
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
import bitsandbytes
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from collections import defaultdict



In [13]:
def model_load():
    embedder_model = SentenceTransformer("A:\Desktop\COMP7600\dataset_label_fine-tune\models\pubmedbert-base-embeddings")
    return embedder_model

def get_entity_embeddings(entities):
    embeddings = []
    for entity in tqdm(entities, desc="Processing entity embeddings"):
        output_embedding = embedder_model.encode(entity)
        embeddings.append(output_embedding)
    return np.vstack(embeddings)

def L2_distance_search(data, k=5, threshold=50):
    entity_embeddings = get_entity_embeddings(data)

    similar_entities = []
    for idx, entity_embedding in enumerate(tqdm(entity_embeddings, desc="Processing similarity search")):
        query_vector = np.array([entity_embedding], dtype=np.float32)
        distances, result_ids = index.search(query_vector, k=k)

        for i, dist in enumerate(distances[0]):
            if dist <= threshold:

                entity_name = data[idx]
                similar_entities.append((result_ids[0][i], entity_name, dist))
    return similar_entities

def normalize_embeddings(embeddings):
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    norms = np.where(norms == 0, 1, norms)  # 避免除以零
    return embeddings / norms

def cosine_distance_search(data, k=5, threshold=0.6):
    entity_embeddings = get_entity_embeddings(data)

    similar_entities = []
    for idx, entity_embedding in enumerate(tqdm(entity_embeddings, desc="Processing similarity search")):
        query_vector = np.array(normalize_embeddings([entity_embedding]), dtype=np.float32)
        distances, result_ids = index.search(query_vector, k=k)
        
        for i, dist in enumerate(distances[0]):
            similarity_score = dist
            
            if similarity_score >= threshold:
                entity_name = data[idx]
                similar_entities.append((result_ids[0][i], entity_name, similarity_score))
                
    return similar_entities

def read_id_table():
    df = pd.read_csv('table.csv')
    return df

def get_type(name):
    query = """
    MATCH (n)
    WHERE n.name IN """ + str(name) + """
    RETURN n.name AS name, labels(n) AS types"""
    result = graph.run(query)
    print(result)
    return result

def get_relation_list():
    df = pd.read_csv('./kg1.csv')

    type_relation_dict = defaultdict(set)  # 使用 set 来自动去重

    for _, row in df.iterrows():
        subject_type = row['subject_type']
        relation = row['relation']
        type_relation_dict[subject_type].add(relation)

    for _, row in df.iterrows():
        object_type = row['object_type']
        relation = row['relation']
        type_relation_dict[object_type].add(relation)

    result_dict = {key: len(list(value)) for key, value in type_relation_dict.items()}

    return result_dict


def get_static_list():
    df = pd.read_csv('./kg1.csv')

    type_relation_dict = defaultdict(set)  # 使用 set 来自动去重
    count_entity_table = {}
    count_relation_table = {}

    for _, row in df.iterrows():
        subject_name = row['subject']
        subject_type = row['subject_type']
        object_name = row['object']
        object_type = row['object_type']
        relation = row['relation']

        # if subject_type not in count_entity_table.keys():
        #     count_entity_table[subject_type] = 0
        # else:
        #     count_entity_table[subject_type] += 1

        # if object_type not in count_entity_table.keys():
        #     count_entity_table[object_type] = 0
        # else:
        #     count_entity_table[object_type] += 1

        # if object_type not in count_entity_table.keys():
        #     count_entity_table[object_type] = 0
        # else:
        #     count_entity_table[object_type] += 1

        if relation not in count_relation_table.keys():
            count_relation_table[relation] = 0
        else:
            count_relation_table[relation] += 1
        type_relation_dict[object_type].add(object_name)
        type_relation_dict[subject_type].add(subject_name)


    result_dict = {key: len(list(value)) for key, value in type_relation_dict.items()}

    return count_entity_table, count_relation_table, result_dict

def build_query(entities, relationships):
    entity_conditions = []
    for entity in entities:
        name = entity['name']
        types = entity['types']
        type_conditions = ''

        type_conditions = " OR ".join([f"(n:`{t}` AND ({''' OR '''.join([f'r:`{r}`' for r in relationships[t]])}))" for t in types])
        entity_conditions.append(f"(n.name = '{name}' AND ({type_conditions}))")

    entity_query = " OR ".join(entity_conditions)

    # relationship_types = []
    # for typ in relationships:
    #     relationship_types.extend(relationships[typ])
    # relationship_query = " OR ".join([f"r:`{rel}`" for rel in relationship_types])
    
    query = f"""
    MATCH (n)-[r]-(m)
    WHERE ({entity_query})
    RETURN n.name AS entity_name, labels(n) AS entity_type, m.name AS related_entity_name, labels(m) AS related_entity_type, type(r) AS relationship_type
    """
    return query

def generate_prompt(data):
    SYSTEM_PROMPT = f"""You are given a list of entities, their types and their relationships."""
    prompt = ''
    for entity in data:
        prompt += f"The entity ({entity['entity_name']}, type: {entity['entity_type'][0]}) and its related entity ({entity['related_entity_name']}, type: {entity['related_entity_type'][0]}) are related by the relationship type: {entity['relationship_type']}.\n"
    return SYSTEM_PROMPT + prompt.strip('\n')

def get_important_entities(question):
    test_text ="""### Instruction: 
    Identify the important entities in the question that you need further extral knowledge to answer.

    ### System Prompt: 
    The following is an unstructured question, please extract the important entities in the text, and the output should follow the following format: [entity1, entity2, entity3]. No more note or explain is needed, only output a list.
    This extracted entities are used for searching further help and knowledge in an extal database. Thus, only find the entities that you are not sure or not familiar with. If you cannot find any entities that satisfy the above requirements, please output: []. No more note or explan is needed.

    ### Input:
    """ + question + """
    ### Response:
    """


    device = "cuda:0"
    inputs = tokenizer(question, return_tensors="pt").to(device)
    outputs = test_model.generate(**inputs, max_new_tokens=2000)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(response)
    return response.split('### Response:')[-1].strip()

def cypher_generate(check_name):
    data = get_type(check_name)

    results = [{"name": record["name"], "type": record["types"]} for record in data]

    aggregated_data = defaultdict(list)
    
    for record in results:
        name = record["name"]
        types = record["type"]
        aggregated_data[name].extend(types)
    
    results = [{"name": name, "types": list(set(types))} for name, types in aggregated_data.items()]
    return results

def load_model_and_tokenizer():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        'A:\Desktop\COMP7600\dataset_label_fine-tune\models\Llama3-Med42-8B',
        device_map={"":0},
        trust_remote_code=True,
        quantization_config=bnb_config
    )
    tokenizer = AutoTokenizer.from_pretrained('A:\Desktop\COMP7600\dataset_label_fine-tune\models\Llama3-Med42-8B')
    lora_config = LoraConfig.from_pretrained('./checkpoint-50')
    test_model = get_peft_model(model, lora_config)
    # tokenizer.pad_token = tokenizer.eos_token
    return test_model, tokenizer




In [14]:
with open("kg_static.txt", "w") as f:
    f.write(str(get_static_list()))

In [59]:
#L2 distnace
df = read_id_table()
embedder_model = model_load()

loaded_embeddings = np.load('./combined_entity_embeddings.npy')
entities_ids = np.array(list(df.to_dict()['name'].keys()))

dimension = loaded_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index_with_ids = faiss.IndexIDMap(index)
index_with_ids.add_with_ids(loaded_embeddings, entities_ids)

index.ntotal

8115

In [3]:
#cosine distnace
df = read_id_table()
embedder_model = model_load()

loaded_embeddings = np.load('./combined_entity_embeddings.npy')
normalize_embeddings = normalize_embeddings(loaded_embeddings)
entities_ids = np.array(list(df.to_dict()['name'].keys()))

dimension = loaded_embeddings.shape[1]
print(dimension)
index = faiss.IndexFlatIP(dimension)
index_with_ids = faiss.IndexIDMap(index)
index_with_ids.add_with_ids(normalize_embeddings, entities_ids)

index.ntotal

768


8115

: 

In [13]:
uri = "neo4j://101.34.58.20:7687"
username = "neo4j"
password = "Tzmt541881"

graph = Graph(uri, auth=(username, password))
result_dict = get_relation_list()

In [15]:
#检验连接是否成功
query = """
MATCH (n)
RETURN COUNT(DISTINCT n) AS distinct_entities_count"""
result = graph.run(query)
print(result)

 distinct_entities_count 
-------------------------
                    5942 



: 

In [20]:
print(get_relation_list())

{'disease': ['parent-child', 'medicine_side-effect_disease', 'prevent', 'is', 'medicine_contraindication_disease', 'disease_treatment', 'disease_caused_disease', 'disease_usually happens_population', 'disease_oral part', 'disease_lack response_examination', 'medicine_treats_disease', 'disease_symptom', 'causes_disease', 'disease_examination', 'disease_has_clinical manifestations', 'disease_has_description'], 'symptom': ['symptom_has_description', 'parent-child', 'prevent', 'is', 'symptom_cause_symptom', 'medicine_side-effect_symptom', 'medicine_reduce_symptom', 'symptom_oral part', 'disease_symptom', 'disease_has_clinical manifestations'], 'clinical manifestations': ['parent-child', 'clinical manifestations_from_examination', 'clinical manifestations_happens at_oral part', 'clinical manifestations_use_treatment', 'is', 'disease_has_clinical manifestations'], 'description': ['symptom_has_description', 'parent-child', 'is', 'examination_has_description', 'medicine_has_description', 'dise

In [40]:
# test_model, tokenizer = load_model_and_tokenizer()
# question = 'A biopsy specimen of the lower lip salivary glands showed replacement of parenchymal tissue by lymphocytes. The patient also had xerostomia and eratoconjunctivitis sicca. These findings are indicative of which of the following?'
# input_entity = get_important_entities(question)

input_entity = ['lower lip salivary glands', 'xerostomia', 'eratoconjunctivitis sicca']
results = cosine_distance_search(input_entity, k=10)
check_name = []

for i in results:
    check_name.append(df.to_dict()['name'][i[0]])

res = cypher_generate(check_name)

query = build_query(cypher_generate(check_name), result_dict)
query_results = graph.run(query).data()

final_prompt = generate_prompt(query_results)

Processing entity embeddings: 100%|██████████| 3/3 [00:00<00:00, 16.94it/s]
Processing similarity search: 100%|██████████| 3/3 [00:00<00:00, 750.59it/s]


 name                  | types         
-----------------------|---------------
 parotid gland         | ['disease']   
 major salivary glands | ['disease']   
 salivary glands       | ['treatment'] 

 name                  | types         
-----------------------|---------------
 parotid gland         | ['disease']   
 major salivary glands | ['disease']   
 salivary glands       | ['treatment'] 



In [41]:
with open('prompt.txt', 'w') as file:
    file.write(final_prompt)