In [35]:
import re
import os 
# NOTE(review): hard-coded absolute Windows path. Every relative path used later
# ('val.txt', 'retrieve_data.txt', 'test.txt', './jiarenmen.jsonl', './models/...',
# './results') is resolved against this directory, so the notebook only runs on a
# machine where it exists. Consider a configurable base directory instead.
os.chdir(r'A:\Desktop\COMP7600\dataset_label_fine-tune\retrieve_part')
import jsonlines
from datasets import Dataset, load_dataset, DatasetDict
import pandas as pd


In [42]:
def load_data(path='val.txt', encoding='utf-8'):
    """Read a whole text file and return its contents as a single string.

    Args:
        path: File to read. Defaults to 'val.txt' (relative to the current
            working directory), which is what a zero-argument call reads.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding; the
            JSONL loader in this notebook also reads its input as UTF-8.

    Returns:
        str: The complete file contents.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        UnicodeDecodeError: If the file is not valid in ``encoding``.
    """
    with open(path, 'r', encoding=encoding) as file:
        return file.read()

def find_sentences(text):
    """Collect the text that follows each line-leading item number.

    A line is picked up when it starts (at column 0) with ``<digits>.`` or
    ``# <digits>.``; everything after that marker up to the end of the line is
    kept, with surrounding whitespace removed.

    Args:
        text: Multi-line string to scan.

    Returns:
        list[str]: One stripped entry per matching line, in order of
        appearance. A line holding only the marker yields an empty string.
    """
    numbered_line = re.compile(r'(?m)^(?:\d+\.|# \d+\.)(.*)')
    sentences = []
    for hit in numbered_line.finditer(text):
        # Group 1 is the remainder of the line after the numbering marker.
        sentences.append(hit.group(1).strip())
    return sentences

def storage(text, path='retrieve_data.txt', min_words=5, encoding='utf-8'):
    """Append the sufficiently long items of ``text`` to a file, one per line.

    An item is written only when splitting it on single spaces yields more
    than ``min_words`` pieces. Because the split is on ``' '`` exactly,
    consecutive spaces produce empty pieces that also count.

    Args:
        text: Iterable of strings to filter and store.
        path: Destination file, opened in append mode so earlier content is
            kept. Defaults to 'retrieve_data.txt' in the working directory.
        min_words: Items must have strictly more pieces than this to be kept.
        encoding: Encoding used for writing. Defaults to UTF-8 so non-ASCII
            text does not fail or get re-encoded under the platform locale.

    Returns:
        None
    """
    with open(path, 'a', encoding=encoding) as file:
        for item in text:
            if len(item.split(' ')) > min_words:
                file.write(item + '\n')

def convert_formats(data):
    """Reduce annotated documents to id, text and the annotated entity strings.

    Each entity annotation is indexed as a sequence whose first two elements
    are the start and end offsets into the document text; any further
    elements are ignored. Documents without annotations (empty list, ``None``
    or no 'entities' key at all) are skipped.

    Args:
        data: Iterable of dicts with the keys 'id', 'text' and, optionally,
            'entities'.

    Returns:
        list[dict]: One dict per annotated document with the keys 'id',
        'text' and 'entity_list' (the substrings ``text[start:end]`` in
        annotation order).
    """
    converted_data = []
    for doc in data:
        spans = doc.get('entities')
        if not spans:
            # Nothing annotated: the document contributes no training example.
            continue
        text = doc['text']
        converted_data.append({
            'id': doc['id'],
            'text': text,
            'entity_list': [text[span[0]:span[1]] for span in spans],
        })
    return converted_data

def load_model_and_tokenizer():
    """Load the local Llama3-Med42-8B checkpoint in 4-bit form plus its tokenizer.

    NOTE(review): ``BitsAndBytesConfig``, ``AutoModelForCausalLM``,
    ``AutoTokenizer`` and ``torch`` are not imported in any visible cell;
    confirm they are imported elsewhere, otherwise this raises NameError on a
    fresh kernel.

    Returns:
        tuple: ``(model, tokenizer)`` as returned by the two
        ``from_pretrained`` calls.
    """
    checkpoint_dir = './models/Llama3-Med42-8B'

    # 4-bit NF4 weights with double quantization; float16 is the compute dtype.
    quantization = BitsAndBytesConfig(
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        load_in_4bit=True,
    )

    # device_map maps the root module ("") to device index 0.
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        quantization_config=quantization,
        trust_remote_code=True,
        device_map={"":0},
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
    # tokenizer.pad_token = tokenizer.eos_token  (left disabled, as before)

    return model, tokenizer

def prompt_generation(data):
    """Build prompt/response records from converted documents and wrap them in a Dataset.

    For every document a record with three string fields is produced:
    'input' (the raw text), 'output' (the entity list rendered as
    ``[e1, e2, ...]``, or ``[]`` when empty) and 'prompt' (instruction,
    system prompt and input, ending with the '### Response:' header).

    The entity list is rendered with an explicit join, not by stripping
    quotes from ``str(list)``. The latter removed apostrophes from inside
    entities and left repr-style double quotes and escapes in the label.

    Side effect: the full record list is dumped with ``str()`` to 'test.txt'
    (overwritten on every call, written as UTF-8).

    NOTE(review): the wording and spelling of the prompt strings are left
    untouched because they are emitted verbatim into the training text.
    NOTE(review): 'prompt' stops at '### Response:' and does not contain
    'output'; a trainer that reads only the 'prompt' column never sees the
    target. Confirm this is intended.

    Args:
        data: Sequence of dicts with the keys 'text' and 'entity_list'.

    Returns:
        Dataset: Built from a pandas DataFrame of the records.
    """
    INSTRUCTION = 'Identify the important entities in the question that you need further extral knowledge to answer.'
    SYSTEM_PROMPT = """
    The following is an unstructured question, please extract the important entities in the text, and the output should follow the following format: [entity1, entity2, entity3]. No more note or explain is needed, only output a list.
    This extracted entities are used for searching further help and knowledge in an extal database. Thus, only find the entities that you are not sure or not familiar with. If you cannot find any entities that satisfy the above requirements, please output: []. No more note or explan is needed.
    """.strip()

    # Everything in the prompt that does not depend on the document.
    prompt_head = '### Instruction: \n' + INSTRUCTION + '\n\n### System Prompt: \n' + \
        SYSTEM_PROMPT + '\n\n### Input: \n'

    dataset = []
    for doc in data:
        text = doc['text']
        dataset.append({
            'input': text,
            'output': '[' + ', '.join(str(entity) for entity in doc['entity_list']) + ']',
            'prompt': prompt_head + text + '\n\n### Response:'
        })

    with open('test.txt', 'w', encoding='utf-8') as f:
        f.write(str(dataset))
    return Dataset.from_pandas(pd.DataFrame(data=dataset))



In [43]:
# Read every record of the JSON Lines annotation file into a list.
all_file = './jiarenmen.jsonl'
with open(all_file, 'r', encoding='utf-8') as file:
    all_data = list(jsonlines.Reader(file))

In [None]:
# Convert the raw annotations into prompt records and expose them as the
# 'train' split (the key was previously misspelled 'tarin'), shuffled with a
# fixed seed so the order is reproducible.
train_data = DatasetDict({
    'train': prompt_generation(convert_formats(all_data))
}).shuffle(seed=42)

In [None]:
# Load the quantized model and its tokenizer via load_model_and_tokenizer().
model, tokenizer = load_model_and_tokenizer()
# Last expression of the cell: shows the quantization settings attached to the
# loaded model's config as a plain dict.
model.config.quantization_config.to_dict()

In [None]:
# LoRA hyperparameters, kept as separate names so they can be tuned in one place.
lora_r = 64
lora_alpha = 16
lora_dropout = 0.1

# NOTE(review): LoraConfig is not imported in any visible cell; confirm it is
# imported elsewhere.
# NOTE(review): "dense" is probably not a module name in a Llama-architecture
# model (its attention output projection is usually called "o_proj"), so that
# entry may match nothing. Confirm against model.named_modules().
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=lora_r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "dense"],
)

In [None]:
# Wrap the loaded model with the LoRA adapters described by peft_config.
# NOTE(review): this cell rebinds `model` to a wrapped version of itself, so
# running it twice wraps an already-wrapped model; reload the base model first.
# NOTE(review): the base model is loaded with load_in_4bit=True; PEFT's docs
# describe prepare_model_for_kbit_training() for that case. Confirm whether it
# should be called before this line.
model = get_peft_model(model, peft_config)
# Prints the count of trainable vs. total parameters after wrapping.
model.print_trainable_parameters()

In [None]:
# --- Where checkpoints and logs go, and how often ---
output_dir = "./results"
save_steps = 1
logging_steps = 1

# --- Batch shape ---
per_device_train_batch_size = 1
gradient_accumulation_steps = 2

# --- Optimizer and learning-rate schedule ---
optim = "paged_adamw_32bit"
learning_rate = 2e-4
lr_scheduler_type = "cosine"
warmup_ratio = 0.03
max_grad_norm = 0.3

# --- Training length ---
# NOTE(review): both an epoch count and a step cap are set; per the Transformers
# docs a positive max_steps takes precedence over num_train_epochs. Confirm
# which one is meant to bound the run.
num_train_epochs = 4
max_steps = 20

In [None]:
# Bundle the hyperparameters defined in the previous cell into TrainingArguments.
# NOTE(review): TrainingArguments is not imported in any visible cell; confirm
# it is imported elsewhere.
training_arguments = TrainingArguments(
    # output and bookkeeping
    output_dir=output_dir,
    save_steps=save_steps,
    logging_steps=logging_steps,
    # batch shape
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    group_by_length=True,
    # optimizer and schedule
    optim=optim,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    max_grad_norm=max_grad_norm,
    # precision
    fp16=True,
    # training length
    num_train_epochs=num_train_epochs,
    max_steps=max_steps,
)

In [None]:
# Upper bound on the sequence length handed to the trainer.
max_seq_length = 2048

# NOTE(review): `dataset` is not assigned in any visible cell. The only
# DatasetDict built in this notebook is `train_data`, and no "validation"
# split is created anywhere visible. Confirm where `dataset` comes from.
# NOTE(review): SFTTrainer is not imported in any visible cell either.
# The text fed to the trainer is the "prompt" column of each example.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_arguments,
    peft_config=peft_config,
    train_dataset=dataset['train'],
    eval_dataset=dataset["validation"],
    dataset_text_field="prompt",
    max_seq_length=max_seq_length,
)

In [None]:
# Start fine-tuning with the trainer configured in the previous cell.
trainer.train()