In [1]:
%%capture
%pip install -U transformers 
%pip install -U datasets 
%pip install -U accelerate 

%pip install -U peft 
%pip install -U trl 
%pip install -U bitsandbytes 
%pip install -U wandb

In [2]:
import multiprocessing as mp
mp.set_start_method('spawn', force=True)


In [3]:
%pip install peft


Note: you may need to restart the kernel to use updated packages.


In [4]:
%pip install trl


Note: you may need to restart the kernel to use updated packages.


In [6]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format

In [11]:
from huggingface_hub import login
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()

hf_token = user_secrets.get_secret("HuggingFace")

login(token = hf_token)

wb_token = user_secrets.get_secret("Weights & Biases")

wandb.login(key=wb_token)
run = wandb.init(
    project='Fine-tune Llama 3 8B on Medical Dataset', 
    job_type="training", 
    anonymous="allow"
)



The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


VBox(children=(Label(value='0.017 MB of 0.017 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

In [12]:
base_model = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"
dataset_name = "ruslanmv/ai-medical-chatbot"
new_model = "llama-3-8b-chat-doctor"

In [13]:
torch_dtype = torch.float16
attn_implementation = "eager"

In [68]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [74]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)

In [75]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

In [76]:
#Importing the dataset
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=65).select(range(1000)) # Only use 1000 samples for quick demo

def format_chat_template(row):
    row_json = [{"role": "user", "content": row["Patient"]},
               {"role": "assistant", "content": row["Doctor"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc=4,
)

dataset['text'][3]

'<|im_start|>user\nFell on sidewalk face first about 8 hrs ago. Swollen, cut lip bruised and cut knee, and hurt pride initially. Now have muscle and shoulder pain, stiff jaw(think this is from the really swollen lip),pain in wrist, and headache. I assume this is all normal but are there specific things I should look for or will I just be in pain for a while given the hard fall?<|im_end|>\n<|im_start|>assistant\nHello and welcome to HCM,The injuries caused on various body parts have to be managed.The cut and swollen lip has to be managed by sterile dressing.The body pains, pain on injured site and jaw pain should be managed by pain killer and muscle relaxant.I suggest you to consult your primary healthcare provider for clinical assessment.In case there is evidence of infection in any of the injured sites, a course of antibiotics may have to be started to control the infection.Thanks and take careDr Shailja P Wahal<|im_end|>\n'

In [77]:
dataset = dataset.train_test_split(test_size=0.1)

In [78]:
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    eval_strategy="steps",  # Updated argument
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb"
)


In [79]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)

Map:   0%|          | 0/900 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [80]:
trainer.train()

Step,Training Loss,Validation Loss
90,3.5294,2.552205
180,2.7655,2.476684
270,2.1201,2.419214
360,2.6061,2.393162
450,2.7057,2.380491


TrainOutput(global_step=450, training_loss=2.666596469614241, metrics={'train_runtime': 1846.6813, 'train_samples_per_second': 0.487, 'train_steps_per_second': 0.244, 'total_flos': 9209935665709056.0, 'train_loss': 2.666596469614241, 'epoch': 1.0})

In [81]:
wandb.finish()
model.config.use_cache = True

VBox(children=(Label(value='0.037 MB of 0.037 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/loss,█▅▃▂▁
eval/runtime,▆▄█▅▁
eval/samples_per_second,▄▄▁▄█
eval/steps_per_second,▄▄▁▄█
train/epoch,▁▁▁▁▁▂▂▂▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇▇██
train/global_step,▁▁▁▂▂▂▁▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇██
train/grad_norm,▂▂▂▂▁▁▁█▆█▂▂▂▂▂▂▁▁▁▂▂▂▂▁▁▁▁▁▁▂▂▁▂▁▁▁▂▂▁▂
train/learning_rate,█▇▇▇▁▆████▇▆▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▁
train/loss,▃▃▁▃▂█▆▆▂▂▃▂▂▂▂▂▃▂▁▁▂▁▁▃▃▂▁▂▂▁▂▂▂▂▂▁▃▂▁▂

0,1
eval/loss,2.38049
eval/runtime,80.4724
eval/samples_per_second,1.243
eval/steps_per_second,1.243
total_flos,9209935665709056.0
train/epoch,1.0
train/global_step,450.0
train/grad_norm,2.41205
train/learning_rate,0.0
train/loss,2.7057


In [82]:
messages = [
    {
        "role": "user",
        "content": "Hello doctor, I have bad acne. How do I get rid of it?"
    }
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, 
                                       add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors='pt', padding=True, 
                   truncation=True).to("cuda")

outputs = model.generate(**inputs, max_length=150, 
                         num_return_sequences=1)

text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(text.split("assistant")[1])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)



Hi. I have gone through your query and understand your concern. I would suggest you to use a combination of topical and oral medications. For topical application, you can use a combination of benzoyl peroxide and salicylic acid. For oral application, you can use a combination of doxycycline and minocycline. You can also use a retinoid cream at night. For more information consult a dermatologist online --> http://www.lybrate.com/consult/dermatologist. Hope I have answered your query. Let me know if I can assist you further. Take care Regards, Dr.


In [84]:
trainer.model.save_pretrained(new_model)
trainer.model.push_to_hub(new_model, use_temp_dir=False)

adapter_model.safetensors:   0%|          | 0.00/168M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/Jenas-Anton/llama-3-8b-chat-doctor/commit/a74eb02e8eaa21a408143602d1fc02e85edb91b8', commit_message='Upload model', commit_description='', oid='a74eb02e8eaa21a408143602d1fc02e85edb91b8', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Jenas-Anton/llama-3-8b-chat-doctor', endpoint='https://huggingface.co', repo_type='model', repo_id='Jenas-Anton/llama-3-8b-chat-doctor'), pr_revision=None, pr_num=None)