In [None]:
# referred from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Mistral/Supervised_fine_tuning_(SFT)_of_an_LLM_using_Hugging_Face_tooling.ipynb
from datasets import load_dataset
import json
from tqdm import tqdm
from itertools import chain
from datasets import DatasetDict
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
from trl import SFTTrainer
from peft import LoraConfig
from transformers import TrainingArguments
from helpers import *
from datasets import Dataset, DatasetDict
import pandas as pd


def read_jsonl_file(file_path):
    """
    Parses a JSONL (JSON Lines) file and returns a list of dictionaries.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        file_path (str): The path to the JSONL file to be read.

    Returns:
        list of dict: A list where each element is a dictionary representing
            a JSON object from the file, in file order. Empty for an empty file.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding="utf-8") as file:
        # json.loads accepts surrounding whitespace but not an empty document,
        # so lines that contain only whitespace are filtered out up front.
        return [json.loads(line) for line in file if line.strip()]

def convert_to_template(data):
    """
    Renders one question/answer record as a chat-formatted string.

    Args:
        data (dict): A record with a "question" entry (used as the user turn)
            and an "answer" entry (used as the assistant turn).

    Returns:
        The result of ``tokenizer.apply_chat_template`` called with
        ``tokenize=False`` on the two-turn conversation.
    """
    conversation = [
        {"role": "user", "content": data["question"]},
        {"role": "assistant", "content": data["answer"]},
    ]
    # Relies on the notebook-level `tokenizer`, which must be created before
    # this function is called.
    return tokenizer.apply_chat_template(conversation, tokenize=False)

# Load the training corpus. Later cells iterate over the result and read
# "title", "abstract" and "annotations" from every element.
# NOTE(review): parse_dataset is not defined in this notebook; presumably it is
# provided by `from helpers import *` — confirm.
mesh_dataset = parse_dataset("CDR_TrainingSet.PubTator.txt")

In [None]:
model_id = "mistralai/Mistral-7B-Instruct-v0.2"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Fall back to the EOS token for padding only when the tokenizer does not
# already define a pad token; an existing pad token is left untouched.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"

# Set reasonable default for models without max length: a very large
# model_max_length is treated as "unset" and replaced by 512, the value that
# is later passed to the trainer as max_seq_length.
if tokenizer.model_max_length > 100_000:
    tokenizer.model_max_length = 512

In [None]:
prepared_dataset = []
system_prompt = "Answer the question factually and precisely."
entity_prompt = "What are the chemical and disease related entities present in this biomedical text?"


def prepare_instructions(elem):
    """
    Builds one question/answer training record from a parsed document.

    Args:
        elem (dict): A document with "title" and "abstract" strings and an
            "annotations" list whose items each carry a "text" string.

    Returns:
        dict: ``{"question": ..., "answer": ...}`` where the question is the
            system prompt, the entity prompt and the document text (title and
            abstract separated by a space), each on its own line, and the
            answer lists the distinct annotation texts, comma-separated, in
            order of first appearance.
    """
    # dict.fromkeys drops repeated mentions while keeping first-seen order,
    # without rescanning a list for every annotation.
    entities = list(dict.fromkeys(x["text"] for x in elem["annotations"]))

    question = system_prompt + "\n" + entity_prompt + "\n" + elem["title"] + " " + elem["abstract"]
    return {"question": question, "answer": "The entities are:" + ",".join(entities)}
    

In [None]:
# One question/answer record per parsed document, with a progress bar.
questions = list(map(prepare_instructions, tqdm(mesh_dataset)))

# Render every record through the chat template and wrap it under the "text"
# key, which is the field name the trainer is configured to read.
chat_format_questions = [
    {"text": rendered} for rendered in map(convert_to_template, tqdm(questions))
]

In [None]:
# Turn the list of {"text": ...} records into a one-column DataFrame, then
# into a `datasets.Dataset` that is handed to the trainer.
df = pd.DataFrame(chat_format_questions)
train_dataset = Dataset.from_pandas(df)

In [None]:
# 4-bit NF4 quantization settings, with bfloat16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Map the root module ("") to the current CUDA device when one is available;
# otherwise pass no device map at all.
if torch.cuda.is_available():
    device_map = {"": torch.cuda.current_device()}
else:
    device_map = None

# Keyword arguments forwarded to the model loader through the trainer.
model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "use_cache": False,  # set to False as we're going to use gradient checkpointing
    "device_map": device_map,
    "quantization_config": quantization_config,
}

In [None]:
# Directory used both as the trainer's output_dir and as the target of the
# explicit trainer.save_model(output_dir) call after training.
output_dir = 'entity_finetune'

# based on config
training_args = TrainingArguments(
    bf16=True, # bfloat16 mixed-precision training; needs a GPU that supports bf16
    do_eval=False,
    # evaluation_strategy="no",
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    learning_rate=1.0e-04,
    log_level="info",
    logging_steps=5,
    logging_strategy="steps",
    lr_scheduler_type="cosine",
    max_steps=-1, # no step cap; training length is governed by num_train_epochs
    num_train_epochs=5,
    output_dir=output_dir,
    overwrite_output_dir=True,
    per_device_eval_batch_size=1, # reduced from the original value of 8
    per_device_train_batch_size=8, # kept at the original value of 8
    save_strategy="no", # no intermediate checkpoints; the model is saved explicitly after training
    save_total_limit=None,
    seed=42,
)

# based on config
# LoRA adapter settings passed to the trainer: rank 16, alpha 16, 10% dropout,
# no bias parameters trained, applied to the modules named q_proj, k_proj,
# v_proj and o_proj.
peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

# Build the supervised fine-tuning trainer. The model is given by its hub id
# together with model_init_kwargs, so the trainer is responsible for loading it.
# The "text" column of train_dataset holds the chat-formatted examples, and
# sequences are limited to tokenizer.model_max_length tokens.
# NOTE(review): eval_dataset is the training set itself, while training_args
# sets do_eval=False — confirm whether this argument is still needed.
trainer = SFTTrainer(
        model=model_id,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        dataset_text_field="text",
        tokenizer=tokenizer,
        packing = True,
        peft_config=peft_config,
        max_seq_length=tokenizer.model_max_length,
    )

## Train!

Finally, training is as simple as calling trainer.train()!

In [None]:
# Run fine-tuning; whatever trainer.train() returns is kept in train_result.
train_result = trainer.train()

## Saving the model

Next, we save the fine-tuned model to `output_dir` with `trainer.save_model`. Note that only the model is saved here; the Trainer's state and the training metrics are not written out.

In [None]:
# Write the trained model to output_dir ('entity_finetune').
trainer.save_model(output_dir)

In [None]:
import json

def write_jsonl(data, filename):
    """
    Writes a list of dictionaries to a JSONL file, one JSON object per line.

    The file is opened in write mode, so an existing file with the same name
    is overwritten. It is written as UTF-8, the same encoding that
    ``read_jsonl_file`` uses when reading.

    Parameters:
    - data (list of dict): The data to be written to the file.
    - filename (str): The name of the file to write to.

    Raises:
    - TypeError: If an item cannot be serialized by ``json.dumps``.
    """
    with open(filename, 'w', encoding="utf-8") as file:
        # Each item is serialized on its own line; the generator keeps only
        # one serialized line in memory at a time.
        file.writelines(json.dumps(item) + '\n' for item in data)

In [None]:
# Persist the question/answer records built by prepare_instructions (the
# records before chat templating) as a JSONL file.
write_jsonl(questions, "mesh_entity_finetune_data.jsonl")