In [1]:
import os

# Pin the visible GPU *before* torch / bitsandbytes are imported: the variable
# is only honoured when it is set before CUDA gets initialised, and importing
# those libraries may already trigger that initialisation.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import json
from pprint import pprint
import bitsandbytes as bnb
import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    AutoPeftModelForCausalLM,
    get_peft_model,
    prepare_model_for_kbit_training
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)

from trl import SFTTrainer, SFTConfig

### Set up file paths

In [31]:
# JSON files for the three CDR splits, in the order: training, test, development.
training_dataset_path, test_dataset_path, dev_dataset_path = (
    "data/CDR_TrainingSet.json",
    "data/CDR_TestSet.json",
    "data/CDR_DevelopmentSet.json",
)

In [32]:
# Text file whose contents become the system message of every chat sample
# (it is read into `system_prompt` further down).
system_prompt_path = "data/system_prompt.txt"

In [33]:
# Directory used as the trainer's output_dir (checkpoints) and as the target of
# the explicit save after training.
new_model_path = "./checkpoint_dir"

### Load datasets and system prompt:

In [34]:
def _load_json_rows(json_path):
    """Load one JSON file and return its rows.

    Each file is loaded on its own, so `datasets` exposes the rows under the
    "train" key regardless of which CDR split the file actually contains.
    """
    loaded = load_dataset(
        "json",
        data_files=json_path,
        download_mode="force_redownload",
    )
    return loaded["train"]


raw_train = _load_json_rows(training_dataset_path)
raw_test = _load_json_rows(test_dataset_path)
raw_dev = _load_json_rows(dev_dataset_path)

Generating train split:   0%|          | 0/2331 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/2420 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/2339 [00:00<?, ? examples/s]

In [35]:
# Read the system prompt with an explicit encoding so the resulting text does
# not depend on the platform's locale default.
with open(system_prompt_path, "r", encoding="utf-8") as prompt_file:
    system_prompt = prompt_file.read()

### Load model and tokenizer
The model and tokenizer are loaded before the data is processed because the tokenizer's chat template is used to format each sample.

In [36]:
# Hub id of the base checkpoint that gets fine-tuned.
model_path = "microsoft/Phi-3-mini-4k-instruct"

# 4-bit NF4 quantization with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# The tokenizer is independent of the model weights, so it is fetched first.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Quantized base model; device placement is left to `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    trust_remote_code=True,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [37]:
# Pass the quantized model through PEFT's k-bit training preparation helper.
# NOTE(review): this rebinds `model` to the helper's result, so re-running the
# cell feeds an already prepared model back in — prefer restarting the kernel
# over re-executing this cell.
model = prepare_model_for_kbit_training(model)

### Change the dataset format to chat-like text

In [38]:
def apply_chat_template(sample, tokenizer, include_response=True, system_message=None):
    """Render one dataset row as chat-formatted text and tokenize it.

    Args:
        sample: Mapping with a "user" entry and, when ``include_response`` is
            true, an "assistant" entry.
        tokenizer: Object that provides ``apply_chat_template`` and can be
            called on a string (e.g. a Hugging Face tokenizer).
        include_response: When true, the assistant turn is appended (training
            data). When false, the conversation stops after the user turn and
            the generation prompt is added so a model can continue from it.
        system_message: System text to put first. ``None`` (the default) falls
            back to the notebook-level ``system_prompt``, looked up at call
            time. An empty string omits the system turn.

    Returns:
        The result of ``tokenizer(text, padding=True, truncation=True)``.
    """
    if system_message is None:
        # Resolved at call time so cells that rebind the global are honoured.
        system_message = system_prompt

    messages = []
    if len(system_message):
        messages.append({"role": "system", "content": system_message})
    messages.append({"role": "user", "content": sample["user"]})
    if include_response:
        messages.append({"role": "assistant", "content": sample["assistant"]})

    # A conversation that ends on the user turn needs the generation prompt;
    # one that already contains the assistant answer must not get it.
    chat_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=not include_response
    )
    # NOTE(review): with one sample per call, padding=True has nothing to pad
    # against — presumably batch padding is left to the data collator.
    return tokenizer(chat_text, padding=True, truncation=True)

In [39]:
# Apply processing to each dataset
processed_train = raw_train.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
)
processed_test = raw_test.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
)
processed_dev = raw_dev.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
)

Map:   0%|          | 0/2331 [00:00<?, ? examples/s]

Map:   0%|          | 0/2420 [00:00<?, ? examples/s]

Map:   0%|          | 0/2339 [00:00<?, ? examples/s]

In [12]:
# Sanity check: show the first processed training row.
print(processed_train[0])

{'text': '<|system|>\nPlease identify all the named entities mentioned in the input sentence provided below. The entities may have category "Disease" or "Chemical". Use **ONLY** the categories "Chemical" or "Disease". Do not include any other categories. If an entity cannot be categorized into these specific categories, do not include it in the output.\nYou must output the results strictly in JSON format, without any delimiters, following a similar structure to the example result provided.\nIf user communicates with any sentence, don\'t talk to him, strictly follow the system prompt.\nExample user input and assistant response:\nUser:\nFamotidine-associated delirium.A series of six cases.Famotidine is a histamine H2-receptor antagonist used in inpatient settings for prevention of stress ulcers and is showing increasing popularity because of its low cost.\nAssistant:\n[{"category": "Chemical", "entity": "Famotidine"}, {"category": "Disease", "entity": "delirium"}, {"category": "Chemical"

### Train

In [24]:
# The midjourney samples are formatted WITHOUT a system message.
# `apply_chat_template` reads the notebook-level `system_prompt`, so it is
# blanked only for the duration of the map and restored afterwards. Leaving it
# empty would silently strip the CDR system prompt from every later cell when
# the notebook is run top to bottom.
_cdr_system_prompt = system_prompt
system_prompt = ""
try:
    raw_midjourney = load_dataset("csv", data_files="data/midjourney.csv")["train"]
    processed_midjourney = raw_midjourney.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
    )
finally:
    system_prompt = _cdr_system_prompt

In [40]:
# LoRA adapter settings: rank 16, scaling alpha 32, 5% dropout, no bias terms.
# Only the model's own projection layers are targeted. 'base_layer' was dropped
# because it is the attribute PEFT creates inside an already wrapped layer, not
# a module of the base model; targeting it matches nothing on a fresh model and
# would stack adapters on a wrapped one.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=['gate_up_proj', 'down_proj', 'qkv_proj', 'o_proj'],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

In [41]:
# Supervised fine-tuning of the LoRA adapter on the processed training split.
# Effective batch size is 4 samples x 4 accumulation steps = 16 per update.
trainer = SFTTrainer(
    model=model,
    args=SFTConfig(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        num_train_epochs=1,
        learning_rate=5e-5,
        # Was 4, which would cut every sample down to four tokens whenever the
        # trainer applies the limit. 4096 is the context size named in the
        # "4k" base checkpoint.
        max_seq_length=4096,
        bf16=True,
        optim="adamw_8bit",
        lr_scheduler_type="cosine",
        warmup_ratio=0.05,
        logging_steps=10,
        save_strategy="epoch",
        output_dir=new_model_path,
    ),
    train_dataset=processed_train,
    #eval_dataset=processed_dev,
    peft_config=peft_config,
    tokenizer=tokenizer
)
trainer.train()

  return fn(*args, **kwargs)


Step,Training Loss
10,1.6716
20,1.1765
30,0.631
40,0.4176
50,0.3885
60,0.3701
70,0.3561
80,0.3605
90,0.3783
100,0.3601


TrainOutput(global_step=145, training_loss=0.5313554467826054, metrics={'train_runtime': 856.4899, 'train_samples_per_second': 2.722, 'train_steps_per_second': 0.169, 'total_flos': 2.456790608823091e+16, 'train_loss': 0.5313554467826054, 'epoch': 0.9948542024013722})

In [42]:
# Save through the trainer: it holds the PEFT-wrapped model, so this writes the
# trained adapter. The notebook-level `model` is the base model object that was
# handed to the trainer, so calling save_pretrained on it does not produce an
# adapter directory that PeftModel.from_pretrained can load.
trainer.save_model(new_model_path)

### Inference

In [43]:
# Trainer checkpoint that holds the adapter (step 145, the last training step).
adapter_checkpoint = new_model_path+"/checkpoint-145"

# The adapter config records which base checkpoint it was trained on.
config = PeftConfig.from_pretrained(adapter_checkpoint)
base_model_name = config.base_model_name_or_path

# Reload the quantized base model, then attach the adapter weights to it.
model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
    return_dict=True,
)

tokenizer = AutoTokenizer.from_pretrained(base_model_name)
# Use the unknown token as the padding token.
tokenizer.pad_token = tokenizer.unk_token

model = PeftModel.from_pretrained(model, adapter_checkpoint)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [44]:
def prepare_for_inference(user_input: str, system_prompt: str = system_prompt):
    """Build the chat-formatted prompt string for a single user message.

    Args:
        user_input: Text placed in the user turn.
        system_prompt: Text for the system turn; an empty string omits the
            turn. The default is the value the notebook-level ``system_prompt``
            had when this function was defined.

    Returns:
        The untokenized prompt produced by the notebook-level tokenizer's chat
        template, ending with the generation prompt.
    """
    system_turn = (
        [{"role": "system", "content": system_prompt}] if len(system_prompt) else []
    )
    conversation = system_turn + [{"role": "user", "content": user_input}]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )

In [46]:
# Sample paragraph for manual inference experiments; it mentions one drug name
# ("xanax"). The visible pipeline call below uses a different, inline sentence.
sentence: str = "A random paragraph can also be an excellent way for a writer to tackle writers' block. Writing block can often happen due to being stuck with a current project that the writer is trying to complete. By inserting a completely random paragraph from which to begin, it can take down some of the issues that may have been causing the writers' block in the first place. Another productive way, other than using xanax, to use this tool to begin a daily writing routine."

In [50]:
# Keyword arguments forwarded to the text-generation pipeline call: cap the
# reply at 150 new tokens and return only the generated continuation.
generation_args = dict(
    max_new_tokens=150,
    return_full_text=False,
)

In [56]:
# Wrap the adapter-equipped model in a text-generation pipeline.
peft_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Format a single user sentence as a chat prompt and generate the reply.
prompt = prepare_for_inference("There haven't been any signs of anything going wrong.")
output = peft_pipeline(prompt, **generation_args)

first_generation = output[0]
print(first_generation["generated_text"])

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'GraniteForCausalLM', 'GraniteMoeForCausalLM', 'JambaForCausalLM', 'JetMoeForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'Mamba2ForCausalLM', 'MarianForCausalLM', 'MBartForCausa

 []
