Import Statements

In [3]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format

Retrieve HuggingFace and W&B API Key/Tokens from .env file and log in.

In [4]:
import huggingface_hub as hf_hub
from dotenv import load_dotenv
load_dotenv()

hf_token = os.getenv("HUGGINGFACE_TOKEN")
hf_hub.login(token=hf_token)

wb_api_key = os.getenv("WANDB_API_KEY")
wandb.login(key=wb_api_key)
run = wandb.init(
    project='Fine-tune Llama 3 8B on Medical Dataset', 
    job_type="training", 
    anonymous="allow"
)



Paths to the base model, dataset, and new model name

In [23]:
base_model = "/home/arlamb/.llama/checkpoints/Llama3.1-8B-hf"
dataset_name = "ruslanmv/ai-medical-chatbot"
new_model = "llama-3-8b-chat-doctor"

Set the data type and attention implementation

In [24]:
torch_dtype = torch.float16
attn_implementation = "eager"

Load the model

In [25]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.95s/it]


Load the Tokenizer

In [26]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


Use Low-Rank Adaption (LoRA) to fine-tune the model efficiently

In [27]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

Load the dataset, select only the top 1000 rows to reduce training time, and apply chat template tranformation

In [28]:
#Importing the dataset
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=65).select(range(1000)) # Only use 1000 samples for quick demo

def format_chat_template(row):
    row_json = [{"role": "user", "content": row["Patient"]},
               {"role": "assistant", "content": row["Doctor"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc=4,
)

dataset['text'][3]

Generating train split: 100%|██████████| 256916/256916 [00:00<00:00, 652093.86 examples/s]
Map (num_proc=4): 100%|██████████| 1000/1000 [00:00<00:00, 6195.83 examples/s]


'<|im_start|>user\nFell on sidewalk face first about 8 hrs ago. Swollen, cut lip bruised and cut knee, and hurt pride initially. Now have muscle and shoulder pain, stiff jaw(think this is from the really swollen lip),pain in wrist, and headache. I assume this is all normal but are there specific things I should look for or will I just be in pain for a while given the hard fall?<|im_end|>\n<|im_start|>assistant\nHello and welcome to HCM,The injuries caused on various body parts have to be managed.The cut and swollen lip has to be managed by sterile dressing.The body pains, pain on injured site and jaw pain should be managed by pain killer and muscle relaxant.I suggest you to consult your primary healthcare provider for clinical assessment.In case there is evidence of infection in any of the injured sites, a course of antibiotics may have to be started to control the infection.Thanks and take careDr Shailja P Wahal<|im_end|>\n'

Split dataset into training and validation

In [29]:
dataset = dataset.train_test_split(test_size=0.1)

Set model hyperparameters

In [30]:
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb"
)



Set-up a supervised fine-tuning trainer and provide a trainer. We provide a train and evaluation dataset, LoRA configuration, training argument, tokenizer, and model. The max_sequ_length is set to 512 for low VRAM usage

In [None]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Map: 100%|██████████| 900/900 [00:00<00:00, 3050.15 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 2447.04 examples/s]
  super().__init__(


Train

In [32]:
trainer.train()



Step,Training Loss,Validation Loss
90,2.0432,2.633747
180,2.4868,2.603117
270,2.7716,2.577504
360,2.346,2.55715
450,2.2501,2.54453


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


TrainOutput(global_step=450, training_loss=2.5712594265407986, metrics={'train_runtime': 324.9197, 'train_samples_per_second': 2.77, 'train_steps_per_second': 1.385, 'total_flos': 9312135459201024.0, 'train_loss': 2.5712594265407986, 'epoch': 1.0})

Model evaluation

In [33]:
wandb.finish()
model.config.use_cache = True

0,1
eval/loss,█▆▄▂▁
eval/runtime,█▇▁▃▁
eval/samples_per_second,▁▂█▆█
eval/steps_per_second,▁▂█▆█
train/epoch,▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
train/global_step,▁▁▁▁▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇████
train/grad_norm,▄█▄▂▂▂▂▃▂▂▃▃▂▁▂▂▂▂▂▂▃▂▁▂▂▁▁▂▂▂▁▂▂▂▁▁▂▂▁▂
train/learning_rate,▆▇██▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁
train/loss,█▅▆▆▇▇▆▅▇▅▆▅▄▆▄▅▄▇▃▇▅▆▄▆▃▁▄▃▄▃▄▂▄▃▄▇▃▄▄▆

0,1
eval/loss,2.54453
eval/runtime,10.6788
eval/samples_per_second,9.364
eval/steps_per_second,9.364
total_flos,9312135459201024.0
train/epoch,1.0
train/global_step,450.0
train/grad_norm,1.7339
train/learning_rate,0.0
train/loss,2.2501


In [34]:
messages = [
    {
        "role": "user",
        "content": "Hello doctor, I have bad acne. How do I get rid of it?"
    }
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, 
                                       add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors='pt', padding=True, 
                   truncation=True).to("cuda")

outputs = model.generate(**inputs, max_length=150, 
                         num_return_sequences=1)

text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(text.split("assistant")[1])


Hi. For the treatment of acne, you should follow the steps below: 1. Cleanse your face with a mild soap and water twice daily. 2. Apply an antibacterial lotion containing 2% to 3% benzoyl peroxide. 3. Apply an oral antibiotic like minocycline or doxycycline. 4. Avoid greasy makeup. 5. Avoid sun exposure. 6. Avoid spicy food. 7. Avoid fatty food. 8. Avoid junk food. 9. Avoid stress. 10. Avoid oily hair. 11. Avoid


In [None]:
trainer.model.save_pretrained(new_model)
trainer.model.push_to_hub(new_model, use_temp_dir=False)

adapter_model.safetensors: 100%|██████████| 2.27G/2.27G [01:17<00:00, 29.3MB/s]


CommitInfo(commit_url='https://huggingface.co/AustianoTrojan/llama-3-8b-chat-doctor/commit/936762562635cd40ebe96c818e2926e46615768e', commit_message='Upload model', commit_description='', oid='936762562635cd40ebe96c818e2926e46615768e', pr_url=None, repo_url=RepoUrl('https://huggingface.co/AustianoTrojan/llama-3-8b-chat-doctor', endpoint='https://huggingface.co', repo_type='model', repo_id='AustianoTrojan/llama-3-8b-chat-doctor'), pr_revision=None, pr_num=None)