In [None]:
import json
import torch
import re
import pandas as pd
from datasets import Dataset, load_dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
import bitsandbytes
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer
import jsonlines
import ast
import os
from tqdm import tqdm


In [None]:
def load_model_and_tokenizer(
    base_model_path='/root/autodl-tmp/models/Llama3-Med42-8B',
    adapter_path='/root/autodl-tmp/resultss/checkpoint-1500',
):
    """Load the 4-bit quantized base model with the trained LoRA adapter applied.

    Args:
        base_model_path: Directory holding the base causal-LM weights and tokenizer.
        adapter_path: Directory of the saved PEFT checkpoint whose trained adapter
            weights are loaded on top of the base model.

    Returns:
        (test_model, tokenizer): the adapter-wrapped model in eval mode and the
        base model's tokenizer.
    """
    # Local import keeps this cell self-contained; peft is already a dependency.
    from peft import PeftModel

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        device_map={"": 0},  # put the whole model on GPU 0
        trust_remote_code=True,
        quantization_config=bnb_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)

    # get_peft_model(model, LoraConfig.from_pretrained(path)) only reads the
    # adapter *config* and attaches freshly initialised LoRA layers, so the
    # trained weights in the checkpoint were never used. PeftModel.from_pretrained
    # loads the config AND the saved adapter weights.
    test_model = PeftModel.from_pretrained(base_model, adapter_path)
    test_model.eval()
    return test_model, tokenizer

def load_data(
    root_directory="/root/autodl-tmp/filtrate_chunks",
    all_file='/root/autodl-tmp/output.jsonl',
):
    """Collect text chunks that have not been processed yet.

    Every ``.txt`` file under ``root_directory`` (searched recursively) is split
    into a title (its first line) and a body (everything after the first newline).
    Chunks whose title already appears as an ``id`` in ``all_file`` are skipped,
    so an interrupted run can resume where it stopped.

    Args:
        root_directory: Directory tree to scan for ``.txt`` chunk files.
        all_file: JSON-lines file of earlier results; each record has an ``id``.
            A missing file means nothing has been processed yet.

    Returns:
        dict mapping title -> body for the unprocessed chunks. If two files share
        a title, the one visited last by ``os.walk`` wins.
    """
    # A set gives O(1) membership checks instead of scanning a list per file.
    processed_ids = set()
    if os.path.isfile(all_file):
        with open(all_file, 'r', encoding='utf-8') as file:
            for line in file:
                if line.strip():  # tolerate blank lines in the results file
                    processed_ids.add(json.loads(line)['id'])

    all_txt_contents = {}
    for dirpath, _dirnames, filenames in os.walk(root_directory):
        for filename in filenames:
            if not filename.endswith(".txt"):
                continue
            file_path = os.path.join(dirpath, filename)
            with open(file_path, "r", encoding="utf-8") as file:
                content = file.read()
            # First line is the title; the rest (possibly empty) is the body.
            title, _, body = content.partition('\n')
            if title not in processed_ids:
                all_txt_contents[title] = body
    return all_txt_contents

def _extract_relation_list(generated_text):
    """Turn raw generated text into the string form of a list of relation dicts.

    Every flat ``{...}`` span is kept and the spans are joined into ``[...]``.
    Separators between objects (commas or newlines) and a trailing truncated
    object are dropped. Returns ``"[]"`` when no complete object is present.
    """
    objects = re.findall(r"\{[^{}]*\}", generated_text)
    return "[" + ",\n".join(objects) + "]"


def prompt_data(test_model, tokenizer, test_data, max_new_tokens=3000, device="cuda:0"):
    """Run relation extraction on one text chunk and return a list literal string.

    Args:
        test_model: Model exposing a Hugging Face style ``generate``.
        tokenizer: Matching tokenizer (callable + ``decode``).
        test_data: Chunk text inserted under ``### Input:``.
        max_new_tokens: Generation budget for the response.
        device: Device the tokenized prompt is moved to.

    Returns:
        A string such as ``"[{...},\\n{...}]"`` suitable for ``ast.literal_eval``;
        ``"[]"`` when the model produced no complete relation object.
    """
    # NOTE(review): prompt kept byte-for-byte (including the missing quote after
    # 'test2_type' and "seperated") since the adapter was presumably trained on
    # this exact template — confirm before editing it.
    test_text = """### Instruction: 
Identify the entities in the sentence and the relationships between them.
The following is an unstructured question, please extract the relationship between the possible entities in the text, and the output should follow the following format: {'relation_type': 'test_relation', 'Entity1_name': 'test1_name', 'Entity1_type': 'test1_type', 'Entity2_name': 'test2_name', 'Entity2_type': 'test2_type}
You can only use the following entity types and relation types:
    All entity types are: {clinical manifestations, materials, description, instrument, dosage, symptom, disease, prevention, oral part, causes, population, medicine, treatment, frequency, usage, examination} (seperated by ,)
    All relationship types are: {medicine_contraindicates_population, treatment_has_description, is, medicine_contraindication_disease, disease_caused_disease, symptom_has_description, medicine_side-effect_symptom, medication-information_frequency, disease_usually happens_population, disease_has_clinical manifestations, clinical manifestations_happens at_oral part, Causes_disease, population_use_as_alternative_medicine, symptom_oralpart, clinical manifestations_use_treatment, examination_has_description, symptom_cause_symptom, medicine_medication-information, disease_has_description, oral_part_discription, examination_at_oral part, medicine_side-effect_disease, disease_symptom, disease_oral-part, disease_examination, prevent, disease_treatment, medicine_has_description, disease_lack response_examination, treatment_frequency, medicine_reduce_symptom, materials_use_as_medicine, instrument_discription, parent-child, medicine_treats_disease, clinical manifestations_from_examination, materials_has_description} (seperated by ,)
Only extract the entities and relation that related with patient inquiry dental medical. Some relation type identify the two entity types that it connects, like medicine_contraindication_disease connect medicine and disease. Please make sure that relationship types like the above are connected to the correct entity types and the entities name cannot be the same as the entities type. If anything is related to the picture or illustration that do not appear in the input text, please ignore it.
### Input: 
""" + test_data + \
"""
### Response: 
"""
    inputs = tokenizer(test_text, return_tensors="pt").to(device)
    outputs = test_model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decode only the newly generated tokens rather than re-splitting the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Stop at any new "###" section the model starts hallucinating after its answer.
    generated = generated.split("###", 1)[0]
    return _extract_relation_list(generated)

def storage(title, output, output_file="/root/autodl-tmp/output.jsonl"):
    """Append one parsed model response as a JSON line to the results file.

    Args:
        title: Chunk title, stored as the record's ``id``.
        output: String form of a Python list literal (e.g. ``"[{...}, {...}]"``).
        output_file: Results file; opened in append mode.

    Raises:
        ValueError / SyntaxError (e.g.): if ``output`` is not a valid Python
            literal. Nothing is written in that case.
    """
    # Parse the argument itself; the original parsed the notebook-global
    # ``response`` and only worked by accident. Parsing before opening the file
    # means a bad response never touches the file. ast.literal_eval evaluates
    # literals only, so model output cannot execute code here.
    list_data = ast.literal_eval(output)
    json_line = json.dumps({
        'id': title,
        'response': list_data,
    })
    with open(output_file, "a", encoding="utf-8") as file:
        file.write(json_line + "\n")

In [None]:
# Run extraction over every unprocessed chunk and append each result to the
# results file. Failed chunks are not written, so they are retried on the next run.
test_model, tokenizer = load_model_and_tokenizer()
all_data = load_data()
failed_titles = []
for title, content in tqdm(all_data.items(), desc="Processing data", ncols=100):
    try:
        response = prompt_data(test_model, tokenizer, content)
        storage(title, response)
    except Exception as exc:
        # Best-effort batch: record and report the failure, then keep going.
        # Catching Exception (not a bare except) lets KeyboardInterrupt stop the loop.
        failed_titles.append(title)
        tqdm.write(f"[skip] {title!r}: {type(exc).__name__}: {exc}")
        continue
    # tqdm.write keeps the progress bar intact, unlike print.
    tqdm.write(f"{title} {response}")
print(f"Done: {len(all_data) - len(failed_titles)} stored, {len(failed_titles)} failed")


In [15]:
test = """[
{'relation_type': 'disease_caused_disease', 'Entity1_name': 'oral pathosis', 'Entity1_type': 'disease', 'Entity2_name': 'dental medical', 'Entity2_type': 'disease'},
{'relation_type': 'examination_at_oral part', 'Entity1_name': 'Surgical endodontic treatment', 'Entity1_type': 'treatment', 'Entity2_name': 'apical region', 'Entity2_type': 'oral part'}
]"""

# Sanity check: a response in the expected format parses with ast.literal_eval
# into a list (despite the name `dict_data`) of relation dicts that carry exactly
# the five keys the prompt asks for.
dict_data = ast.literal_eval(test)
EXPECTED_RELATION_KEYS = {'relation_type', 'Entity1_name', 'Entity1_type', 'Entity2_name', 'Entity2_type'}
assert isinstance(dict_data, list)
assert all(set(item) == EXPECTED_RELATION_KEYS for item in dict_data)
print(dict_data[0])

{'relation_type': 'disease_caused_disease', 'Entity1_name': 'oral pathosis', 'Entity1_type': 'disease', 'Entity2_name': 'dental medical', 'Entity2_type': 'disease'}
