<a href="https://colab.research.google.com/github/ashishmohapatra240/fine-tune-LLaMA-2/blob/main/fine_tune_LLaMA_2_on_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
!pip install transformers peft trl datasets



In [13]:
!pip install einops



In [14]:
!pip install bitsandbytes accelerate



In [15]:
import os
import torch
import gc
import pandas as pd
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoTokenizer
from transformers import TrainingArguments, pipeline
from peft import LoraConfig, PeftModel, get_peft_config
from trl import SFTTrainer

In [16]:
import warnings

warnings.filterwarnings("ignore")


In [17]:
  bnb_config=BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_use_double_quant=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_compute_dytype=torch.bfloat16
  )

In [18]:
device_map="auto"

In [19]:
df=pd.read_csv("/content/medquad.csv")

In [20]:
df.shape

(16412, 4)

In [21]:
df.head()

Unnamed: 0,question,answer,source,focus_area
0,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma
1,What causes Glaucoma ?,"Nearly 2.7 million people have glaucoma, a lea...",NIHSeniorHealth,Glaucoma
2,What are the symptoms of Glaucoma ?,Symptoms of Glaucoma Glaucoma can develop in ...,NIHSeniorHealth,Glaucoma
3,What are the treatments for Glaucoma ?,"Although open-angle glaucoma cannot be cured, ...",NIHSeniorHealth,Glaucoma
4,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma


In [22]:
data=Dataset.from_pandas(pd.DataFrame(data=df))

In [23]:
model_name="NousResearch/Llama-2-7b-chat-hf"

In [24]:
tokenizer=AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token=tokenizer.eos_token

Downloading (…)okenizer_config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)in/added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

In [25]:
torch.cuda.empty_cache()

In [26]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map,
)


Downloading (…)lve/main/config.json:   0%|          | 0.00/583 [00:00<?, ?B/s]

Downloading (…)fetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/179 [00:00<?, ?B/s]

In [27]:
model.config.pretraining_tp=1
torch.cuda.empty_cache()

In [28]:
LORA_ALPHA=16
LORA_DROPOUT=0.2
LORA_R=64

In [29]:
peft_config=LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    bias="none",
    task_type="CAUSAl_LM",
)

In [30]:
LEARNING_RATE=1e-4
NUM_EPOCHS=10
BATCH_SIZE=16
WEIGHT_DECAY=0.001
MAX_GRAD_NORM=0.3
gradient_accumulation_steps=16
STEPS=1
OPTIM="adam"
MAX_STEPS=10

In [31]:
OUTPUT_DIR="./results"

In [32]:
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=LEARNING_RATE,
    logging_steps=STEPS,
    num_train_epochs=NUM_EPOCHS,
    max_steps=MAX_STEPS,
)


In [33]:
torch.cuda.empty_cache()

In [34]:
trainer=SFTTrainer(
    model=model,
    train_dataset=data,
    peft_config=peft_config,
    dataset_text_field="question",
    max_seq_length=500,
    tokenizer=tokenizer,
    args=training_args,
)

Map:   0%|          | 0/16412 [00:00<?, ? examples/s]

In [35]:
data = Dataset.from_pandas(df)
data = data.map(
    lambda example: tokenizer(example["question"], padding="max_length", truncation=True),
    batched=True,
)

Map:   0%|          | 0/16412 [00:00<?, ? examples/s]

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [36]:
trainer.train

<bound method Trainer.train of <trl.trainer.sft_trainer.SFTTrainer object at 0x7cd3924597b0>>

In [37]:
torch.cuda.empty_cache()

In [38]:
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=500)

Xformers is not installed correctly. If you want to use memory_efficient_attention to accelerate training use the following command to install Xformers
pip install xformers.


In [39]:
prompt="what is Glaucoma?"

In [40]:
template=f"""<s>[INST]<<SYS>>
You are a helpful, respectful and honest Medical Assistant.
You always answer to the context and do not hallucinate. You only answer to the questions that are related to Medical.
if a question does not make any sense, Just answer I don't know, please don't share false information.
<</SYS>>
{prompt}[/INST]
"""

In [41]:
result=pipe(template)

In [42]:
response=result[0]['generated_text']
response

"<s>[INST]<<SYS>>\nYou are a helpful, respectful and honest Medical Assistant.\nYou always answer to the context and do not hallucinate. You only answer to the questions that are related to Medical.\nif a question does not make any sense, Just answer I don't know, please don't share false information.\n<</SYS>>\nwhat is Glaucoma?[/INST]\nHello! I'm here to help you with any medical-related questions you may have. Glaucoma is a group of eye conditions that can damage the optic nerve, which carries visual information from the eye to the brain. It is often associated with increased pressure inside the eye, which can lead to vision loss and even blindness if left untreated.\n\nThere are several types of glaucoma, including:\n\n1. Open-angle glaucoma: This is the most common form of glaucoma, where the drainage channels in the eye are clogged, causing pressure to build up slowly over time.\n2. Closed-angle glaucoma: This type of glaucoma occurs when the drainage channels in the eye are comp