In [25]:
import ast
import json
import re

import bitsandbytes
import jsonlines
import pandas as pd
import torch
from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)
from trl import SFTTrainer


In [2]:
# Read every annotated record from the JSON Lines export into memory.
all_file = './all.jsonl'
all_data = []
with open(all_file, 'r', encoding='utf-8') as file:
    # jsonlines.Reader yields one decoded object per line.
    all_data.extend(jsonlines.Reader(file))

In [3]:
def convert_formats(data):
    """Flatten raw annotation records into the notebook's intermediate format.

    Args:
        data: Iterable of dict records. Each record must provide ``id`` and
            ``text``. ``entities`` items are read for ``id``, ``label``,
            ``start_offset`` and ``end_offset``; ``relations`` items are read
            for ``id``, ``from_id``, ``to_id`` and ``type``. A missing or
            ``None`` ``entities`` / ``relations`` value is treated as empty.

    Returns:
        ``{'data': [...]}`` with one dict per record holding ``id``, ``test``
        (the record text; the key name is what ``combine_text`` reads),
        ``entity_list`` and ``relation_list``. The single-key wrapper lets the
        result be splatted into ``combine_text(**result)``.
    """
    converted_data = []
    for doc in data:
        entity_list = [
            {
                'id_name': entity['id'],
                'type_name': entity['label'],
                'start_offset': entity['start_offset'],
                'end_offset': entity['end_offset']
            }
            # `or []` covers both an absent key and an explicit null.
            for entity in (doc.get('entities') or [])
        ]
        relation_list = [
            {
                'id_name': relation['id'],
                'type_name': relation['type'],
                'from_id': relation['from_id'],
                'to_id': relation['to_id']
            }
            for relation in (doc.get('relations') or [])
        ]
        converted_data.append({
            'id': doc['id'],
            'test': doc['text'],
            'entity_list': entity_list,
            'relation_list': relation_list
        })
    return {
        'data': converted_data
    }

def combine_text(data):
    """Resolve entity offsets to text and attach entity names/types to relations.

    Args:
        data: List of per-document dicts as produced by ``convert_formats``
            (keys ``id``, ``test``, ``entity_list``, ``relation_list``).

    Returns:
        A dict with three parallel lists (one item per document):

        * ``test_list``: ``{'id', 'test'}`` where ``test`` is the document
          text with its first line dropped.
        * ``total_entity_list``: entities as ``{'id_name', 'type_name',
          'test_name'}``; ``test_name`` is the slice of the *full* text
          (first line included) between the entity offsets.
        * ``total_relation_list``: relations as ``{'id_name', 'type_name',
          'entity1_name', 'entity1_type', 'entity2_name', 'entity2_type'}``.

        A document with no entities (or no relations) gets a single
        placeholder item whose fields are all empty strings.

    Raises:
        KeyError: if a relation references an entity id that is not present
            in the same document's entity list.
    """
    test_list = []
    total_entity_list = []
    total_relation_list = []
    for doc in data:
        full_text = doc['test']
        entity_list = doc['entity_list']
        relation_list = doc['relation_list']
        id_to_name = {}
        # Types are keyed by entity id, not by surface text: two entities can
        # share the same text while carrying different labels, and a
        # text-keyed map would let the later one overwrite the earlier one.
        id_to_type = {}

        if not entity_list:
            doc_entities = [{
                'id_name': '',
                'type_name': '',
                'test_name': ''
            }]
        else:
            doc_entities = []
            for entity in entity_list:
                surface = full_text[entity['start_offset']:entity['end_offset']]
                doc_entities.append({
                    'id_name': entity['id_name'],
                    'type_name': entity['type_name'],
                    'test_name': surface
                })
                id_to_name[entity['id_name']] = surface
                id_to_type[entity['id_name']] = entity['type_name']
        total_entity_list.append(doc_entities)

        # Offsets above index the full text; only the emitted text drops
        # the first line.
        test_list.append({
            'id': doc['id'],
            'test': '\n'.join(full_text.split('\n')[1:])
        })

        if not relation_list:
            doc_relations = [{
                'id_name': '',
                'type_name': '',
                'entity1_name': '',
                'entity1_type': '',
                'entity2_name': '',
                'entity2_type': ''
            }]
        else:
            doc_relations = [
                {
                    'id_name': relation['id_name'],
                    'type_name': relation['type_name'],
                    'entity1_name': id_to_name[relation['from_id']],
                    'entity1_type': id_to_type[relation['from_id']],
                    'entity2_name': id_to_name[relation['to_id']],
                    'entity2_type': id_to_type[relation['to_id']]
                }
                for relation in relation_list
            ]
        total_relation_list.append(doc_relations)
    return {
        'test_list': test_list,
        'total_entity_list': total_entity_list,
        'total_relation_list': total_relation_list
    }

def load_model_and_tokenizer(model_path='./models/Llama3-Med42-8B'):
    """Load a causal LM in 4-bit NF4 quantisation together with its tokenizer.

    Args:
        model_path: Directory (or hub id) passed to both ``from_pretrained``
            calls. Defaults to the local checkpoint this notebook was
            written for, so existing zero-argument calls behave as before.

    Returns:
        ``(model, tokenizer)``.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map={"": 0},  # place the whole model on device index 0
        trust_remote_code=True,
        quantization_config=bnb_config
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer
# Loaded at definition time so the module-level `tokenizer` exists before
# prompt_generation() runs: that function reads `tokenizer` as a global.
model, tokenizer = load_model_and_tokenizer()

def find_max_prompt_tokens(dataset, tokenizer):
    """Find the prompt that encodes to the most tokens.

    Args:
        dataset: Iterable of mappings that each carry a ``'prompt'`` entry.
        tokenizer: Object whose ``encode(text)`` returns a sized sequence.

    Returns:
        ``(token_count, prompt)`` for the longest prompt; the first one wins
        on ties. ``(0, "")`` when nothing exceeds zero tokens.
    """
    longest_len, longest_prompt = 0, ""

    for record in dataset:
        text = record['prompt']
        n_tokens = len(tokenizer.encode(text))

        # Strictly-greater comparison keeps the earliest prompt on a tie.
        if n_tokens <= longest_len:
            continue
        longest_len, longest_prompt = n_tokens, text

    return longest_len, longest_prompt

def prompt_generation(test_list, total_entity_list, total_relation_list):
    """Build the instruction-tuning dataset and dump a debug copy to ``test.txt``.

    Args:
        test_list: Per-document dicts with a ``'test'`` entry (the input text).
        total_entity_list: Accepted so the result of ``combine_text`` can be
            splatted in directly; not used when building prompts.
        total_relation_list: Per-document lists of relation dicts with
            ``type_name``, ``entity1_name``, ``entity1_type``,
            ``entity2_name`` and ``entity2_type``; indexed in parallel with
            ``test_list``.

    Returns:
        A ``datasets.Dataset`` with columns ``input``, ``output`` and
        ``prompt``. ``prompt`` ends at ``### Response:`` and does NOT contain
        ``output``; the two must be joined to form a full training example.

    Side effects:
        Overwrites ``test.txt`` with every ``prompt + output`` pair followed
        by the result of ``find_max_prompt_tokens``. Reads the module-level
        ``tokenizer`` for that token count.
    """
    INSTRUCTION = 'Identify the entities in the sentence and the relationships between them.'
    SYSTEM_PROMPT = """
    The following is an unstructured question, please extract the relationship between the possible entities in the text, and the output should follow the following format: {'relation_type': 'test_relation', 'Entity1_name': 'test1_name', 'Entity1_type': 'test1_type', 'Entity2_name': 'test2_name', 'Entity2_type': 'test2_type}
    You can only use the following entity types and relation types:
        All entity types are: {clinical manifestations, materials, medication information, description, instrument, dosage, symptom, disease, prevention, oral part, causes, population, medicine, treatment, frequency, usage, examination} (seperated by ,)
        All relationship types are: {medicine_contraindicates_population, treatment_has_description, is, medicine_contraindication_disease, disease_caused_disease, symptom_has_description, medicine_side-effect_symptom, medication-information_frequency, disease_usually happens_population, disease_has_clinical manifestations, clinical manifestations_happens at_oral part, Causes_disease, population_use_as_alternative_medicine, symptom_oralpart, clinical manifestations_use_treatment, examination_has_description, symptom_cause_symptom, medicine_medication-information, disease_has_description, oral_part_discription, examination_at_oral part, medicine_side-effect_disease, disease_symptom, disease_oral-part, disease_examination, prevent, disease_treatment, medicine_has_description, disease_lack response_examination, treatment_frequency, medicine_reduce_symptom, materials_use_as_medicine, instrument_discription, parent-child, medicine_treats_disease, clinical manifestations_from_examination, materials_has_description} (seperated by ,)
    Only extract the entities and relation that related with dental medical. Some relation type identify the two entity types that it connects, like medicine_contraindication_disease connect medicine and disease. Please make sure that relationship types like the above are connected to the correct entity types. If anything is related to the picture or illustration that do not appear in the input text, please ignore it.
    """.strip().strip('\n')
    # Everything before the document text is identical for every example.
    prompt_head = '### Instruction: \n' + INSTRUCTION + '\n' + SYSTEM_PROMPT + '\n\n### Input: \n'

    dataset = []
    for i, doc in enumerate(test_list):
        # Each relation is rendered with str(dict), i.e. Python-literal
        # syntax with single quotes, matching the format the prompt asks for.
        rendered_relations = [
            str({
                'relation_type': r['type_name'],
                'Entity1_name': r['entity1_name'],
                'Entity1_type': r['entity1_type'],
                'Entity2_name': r['entity2_name'],
                'Entity2_type': r['entity2_type']
            })
            for r in total_relation_list[i]
        ]
        dataset.append({
            'input': doc['test'],
            'output': ', '.join(rendered_relations),
            'prompt': prompt_head + doc['test'] + '\n\n### Response:'
        })

    # Explicit UTF-8: the source file is read as UTF-8 and may contain
    # non-ASCII characters that a platform-default codec cannot encode.
    with open('test.txt', 'w', encoding='utf-8') as f:
        for example in dataset:
            f.write(str(example['prompt']) + str(example['output']) + '\n')
        f.write(str(find_max_prompt_tokens(dataset, tokenizer)))
    return Dataset.from_pandas(pd.DataFrame(data=dataset))

def print_trainable_parameterss(model):
    """Print trainable vs. total parameter counts and the trainable percentage.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs, where ``param`` has ``numel()`` and
            ``requires_grad``.
    """
    # Snapshot (size, trainable?) once so both totals come from one pass.
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [4]:
# Raw records -> intermediate format -> resolved entities/relations -> prompts.
converted = convert_formats(all_data)
combined = combine_text(**converted)
train_split = prompt_generation(**combined)

# Single 'train' split, shuffled with a fixed seed.
dataset = DatasetDict({'train': train_split}).shuffle(42)

In [5]:
# Load the quantised model and tokenizer again, rebinding `model` and
# `tokenizer`; this gives the LoRA cells below an unwrapped model to start from.
model, tokenizer = load_model_and_tokenizer()
# Last expression: display the quantisation settings the model was loaded with.
model.config.quantization_config.to_dict()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

{'quant_method': <QuantizationMethod.BITS_AND_BYTES: 'bitsandbytes'>,
 '_load_in_8bit': False,
 '_load_in_4bit': True,
 'llm_int8_threshold': 6.0,
 'llm_int8_skip_modules': None,
 'llm_int8_enable_fp32_cpu_offload': False,
 'llm_int8_has_fp16_weight': False,
 'bnb_4bit_quant_type': 'nf4',
 'bnb_4bit_use_double_quant': True,
 'bnb_4bit_compute_dtype': 'float16',
 'bnb_4bit_quant_storage': 'uint8',
 'load_in_4bit': True,
 'load_in_8bit': False}

In [6]:
# LoRA hyper-parameters: adapter rank, scaling numerator and dropout.
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

peft_config = LoraConfig(
    r=lora_r,
    lora_alpha=lora_alpha,
    # "dense" was dropped from this list: the trainable-parameter count
    # logged by print_trainable_parameters() (37,748,736) is consistent with
    # adapters on q/k/v only, so "dense" matched no module in this model and
    # was a silent no-op.
    # NOTE(review): if the attention output projection was meant to be
    # adapted as well, the Llama module name is "o_proj" - confirm intent.
    target_modules=["q_proj", "k_proj", "v_proj"],
    lora_dropout=lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)

In [7]:
# Wrap the loaded model with the adapters described by `peft_config`.
# `model` is rebound to the wrapped object, so this cell is not idempotent:
# running it twice without reloading would wrap an already-wrapped model.
model = get_peft_model(model, peft_config)

# Report how many parameters are trainable after wrapping.
model.print_trainable_parameters()

trainable params: 37,748,736 || all params: 8,068,009,984 || trainable%: 0.4679


In [8]:
# Training hyper-parameters, consumed by the TrainingArguments cell below.
output_dir = "./results"
# Batch of 1 with 2 accumulation steps -> 2 examples per optimizer step per device.
per_device_train_batch_size = 1
gradient_accumulation_steps = 2
optim = "paged_adamw_32bit"
save_steps = 30
# NOTE: has no effect while `max_steps` is set; the trainer logs
# "max_steps is given, it will override any value given in num_train_epochs".
num_train_epochs = 4
logging_steps = 1
# Extra keyword arguments forwarded to the LR scheduler.
lr_scheduler_kwargs = {
    'num_cycles': 5
    }
learning_rate = 2e-4
max_grad_norm = 0.3
max_steps = 1500
warmup_ratio = 0.03
lr_scheduler_type = "cosine"

In [9]:
# Collect the trainer settings in one mapping so they can be inspected
# (or logged) before the TrainingArguments object is built.
training_kwargs = {
    'output_dir': output_dir,
    'per_device_train_batch_size': per_device_train_batch_size,
    'gradient_accumulation_steps': gradient_accumulation_steps,
    'optim': optim,
    # Passed through unchanged; `max_steps` below takes precedence over it.
    'num_train_epochs': num_train_epochs,
    'save_steps': save_steps,
    'logging_steps': logging_steps,
    'learning_rate': learning_rate,
    'fp16': True,
    'max_grad_norm': max_grad_norm,
    'max_steps': max_steps,
    'warmup_ratio': warmup_ratio,
    'group_by_length': True,
    'lr_scheduler_type': lr_scheduler_type,
    'lr_scheduler_kwargs': lr_scheduler_kwargs,
}
training_arguments = TrainingArguments(**training_kwargs)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [10]:
max_seq_length = 2048


def format_training_example(batch):
    """Join each prompt with its target answer to form the training text.

    The dataset's ``prompt`` column stops at ``### Response:`` and does not
    contain the expected answer, so training on that column alone never shows
    the model a single target relation. SFTTrainer calls this function with a
    batch (dict of column -> list) and expects a list of strings back.

    Args:
        batch: Mapping with parallel ``'prompt'`` and ``'output'`` lists.

    Returns:
        One ``"<prompt> <output><eos>"`` string per example.
    """
    # The EOS marker teaches the model where an answer ends; fall back to
    # nothing if the tokenizer does not define one.
    eos = tokenizer.eos_token or ''
    return [
        f"{prompt} {output}{eos}"
        for prompt, output in zip(batch['prompt'], batch['output'])
    ]


trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train'],
    peft_config=peft_config,
    # Replaces dataset_text_field="prompt", which trained on prompts only.
    formatting_func=format_training_example,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Map:   0%|          | 0/402 [00:00<?, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs


In [None]:
# Run fine-tuning with the trainer built in the previous cell.
trainer.train()

In [42]:
# A distributed/parallel wrapper exposes the real model as `.module`;
# fall back to the model itself when there is no such wrapper.
model_to_save = getattr(trainer.model, 'module', trainer.model)
model_to_save.save_pretrained("outputs")

In [12]:
# Directory holding the saved adapter (config + weights).
# NOTE(review): absolute, machine-specific path - consider deriving it from
# `output_dir` so the notebook runs elsewhere.
# adapter_dir = 'outputs'
adapter_dir = '/root/autodl-tmp/results/checkpoint-1500'
lora_config = LoraConfig.from_pretrained(adapter_dir)

# get_peft_model(model, lora_config) only builds *freshly initialised*
# adapter layers from the config; it never reads the trained weights stored
# next to it. Load the adapter weights explicitly instead.
if isinstance(model, PeftModel):
    # `model` was already wrapped earlier in this session: attach the
    # checkpoint as a separate named adapter and make it the active one.
    model.load_adapter(adapter_dir, adapter_name='checkpoint')
    model.set_adapter('checkpoint')
    test_model = model
else:
    test_model = PeftModel.from_pretrained(model, adapter_dir)

# Inference mode: disables the LoRA dropout during generation.
test_model.eval();

In [31]:
# Hand-written inference prompt: instruction header, one input document,
# then the "### Response:" marker the generated answer is split on.
# NOTE(review): this text is not identical to the prompt built by
# prompt_generation() - the entity-type list here omits "medication
# information", the closing paragraph adds "patient inquiry" and an extra
# constraint on entity names, and the lines are not indented the same way.
# Confirm whether the mismatch with the training prompt is intentional.
test_text = """### Instruction: 
Identify the entities in the sentence and the relationships between them.
The following is an unstructured question, please extract the relationship between the possible entities in the text, and the output should follow the following format: {'relation_type': 'test_relation', 'Entity1_name': 'test1_name', 'Entity1_type': 'test1_type', 'Entity2_name': 'test2_name', 'Entity2_type': 'test2_type}
You can only use the following entity types and relation types:
    All entity types are: {clinical manifestations, materials, description, instrument, dosage, symptom, disease, prevention, oral part, causes, population, medicine, treatment, frequency, usage, examination} (seperated by ,)
    All relationship types are: {medicine_contraindicates_population, treatment_has_description, is, medicine_contraindication_disease, disease_caused_disease, symptom_has_description, medicine_side-effect_symptom, medication-information_frequency, disease_usually happens_population, disease_has_clinical manifestations, clinical manifestations_happens at_oral part, Causes_disease, population_use_as_alternative_medicine, symptom_oralpart, clinical manifestations_use_treatment, examination_has_description, symptom_cause_symptom, medicine_medication-information, disease_has_description, oral_part_discription, examination_at_oral part, medicine_side-effect_disease, disease_symptom, disease_oral-part, disease_examination, prevent, disease_treatment, medicine_has_description, disease_lack response_examination, treatment_frequency, medicine_reduce_symptom, materials_use_as_medicine, instrument_discription, parent-child, medicine_treats_disease, clinical manifestations_from_examination, materials_has_description} (seperated by ,)
Only extract the entities and relation that related with patient inquiry dental medical. Some relation type identify the two entity types that it connects, like medicine_contraindication_disease connect medicine and disease. Please make sure that relationship types like the above are connected to the correct entity types and the entities name cannot be the same as the entities type. If anything is related to the picture or illustration that do not appear in the input text, please ignore it.
### Input: 

# Flare-ups

a. This is a true emergency and is so severe pat an unscheduled visit and treatment is required.
b. A history of preoperative pain or swelling is pe best predictor of “flare-up” emergencies.
c. No relationship exists between flare-ups and treatment procedures (i.e., single or multiple visits).
d. Treatment generally involves complete cleaning and shaping of canals, placement of intracanal medicament, and prescription of analgesic.

# Sterilization and Asepsis

# Rationale for sterilization

1. Endodontic instruments are contaminated with blood, soft and hard tissue remnants, bacteria, and bacterial by-products.
2. Instruments must be cleaned often and disinfected during the procedure and sterilized afterward.
3. Because instruments may be contaminated when new, they must be sterilized before initial use.

# Types of sterilization

1. Glutaraldehyde.
2. Pressure sterilization.
3. Dry heat sterilization.
# C. Disinfection

1. Surface disinfection during canal débridement is accomplished by using a sponge soaked in 70% isopropyl alcohol or proprietary quaternary ammonium solutions.

2. Files can be thrust briskly in and out of this sponge to dislodge debris and contact the disinfectant.

3. This procedure cleans but does not disinfect instruments.

# 2.5 Radiographic Techniques

# A. Diagnostic radiographs

1. Angulation

a. Paralleling technique—the most accurate radiographs are made using a paralleling technique.

### Response: 
"""

device = "cuda:0"
inputs = tokenizer(test_text, return_tensors="pt").to(device)
outputs = test_model.generate(**inputs, max_new_tokens=2000)

# The decoded sequence contains the prompt followed by the completion; keep
# only what comes after the last response marker.
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
completion = decoded.split('\n### Response: ')[-1].strip()

# Cut at the last closing brace so a trailing, unfinished relation (e.g.
# generation stopped by max_new_tokens) is dropped before parsing.
last_brace = completion.rfind('}')
if last_brace == -1:
    raise ValueError(f"No complete relation found in model output: {completion!r}")
response = completion[:last_brace + 1]

# The model is prompted to answer in Python-literal dict syntax (single
# quotes), hence literal_eval rather than json.loads. Several comma-separated
# dicts parse to a tuple, which json.dumps writes as a JSON array.
# Parse first so the output file is only touched when parsing succeeds.
dict_data = ast.literal_eval(response)
json_line = json.dumps(dict_data)
with open("output.jsonl", "a", encoding="utf-8") as file:
    file.write(json_line + "\n")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


{'relation_type': 'treatment_has_description', 'Entity1_name': 'Flare-ups', 'Entity1_type': 'disease', 'Entity2_name': 'complete cleaning and shaping of canals, placement of intracanal medicament, and prescription of analgesic', 'Entity2_type': 'treatment'}
<class 'dict'>
