<a href="https://colab.research.google.com/github/ThienAnTrinh/llama2-medical-consultant/blob/master/llama2_7b_medical.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Reference: https://www.philschmid.de/instruction-tune-llama-2#4-test-model-and-run-inference

In [None]:
# Install necessary dependencies

!pip install "transformers==4.31.0" "datasets==2.13.0" "peft==0.4.0" "accelerate==0.21.0" "bitsandbytes==0.40.2" "trl==0.4.7" "safetensors>=0.3.1" --upgrade

Collecting transformers==4.31.0
  Downloading transformers-4.31.0-py3-none-any.whl (7.4 MB)
Collecting datasets==2.13.0
  Downloading datasets-2.13.0-py3-none-any.whl (485 kB)
Collecting peft==0.4.0
  Downloading peft-0.4.0-py3-none-any.whl (72 kB)
Collecting accelerate==0.21.0
  Downloading accelerate-0.21.0-py3-none-any.whl (244 kB)
Collecting bitsandbytes==0.40.2
  Downloading bitsandbytes-0.40.2-py3-none-any.whl (92.5 MB)

## Data

In [None]:
# Load train dataset

from datasets import load_dataset

dataset = load_dataset("medical_dialog", "processed.en", split="train")

Downloading and preparing dataset medical_dialog/processed.en to /root/.cache/huggingface/datasets/medical_dialog/processed.en/2.0.0/0e925f6f3a036cf46434ddd9e73e9a69bfc91dd467825560d27f04c4e226cba6...


Generating train split:   0%|          | 0/482 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/60 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/61 [00:00<?, ? examples/s]

Dataset medical_dialog downloaded and prepared to /root/.cache/huggingface/datasets/medical_dialog/processed.en/2.0.0/0e925f6f3a036cf46434ddd9e73e9a69bfc91dd467825560d27f04c4e226cba6. Subsequent calls will reuse this data.


In [None]:
# Load validation dataset

val_dataset = load_dataset("medical_dialog", "processed.en", split="validation")



In [None]:
dataset, val_dataset

(Dataset({
     features: ['description', 'utterances'],
     num_rows: 482
 }),
 Dataset({
     features: ['description', 'utterances'],
     num_rows: 60
 }))
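The instruction formatter defined further down reads `utterances[0]` as the patient's turn and `utterances[1]` as the doctor's reply, each starting with a speaker label such as `patient:` or `doctor:`. Here is a minimal sketch of a check for that assumption. `split_speaker` and `is_patient_doctor_pair` are hypothetical helpers, not part of `datasets`, and the example strings are shortened from the first training row.

```python
from typing import List, Tuple


def split_speaker(utterance: str) -> Tuple[str, str]:
    """Split 'speaker: text' at the first colon; speaker is '' when no colon is found."""
    speaker, sep, text = utterance.partition(":")
    if not sep:
        return "", utterance.strip()
    return speaker.strip().lower(), text.strip()


def is_patient_doctor_pair(utterances: List[str]) -> bool:
    """True if the dialog starts with a patient turn followed by a doctor turn."""
    if len(utterances) < 2:
        return False
    first_speaker, _ = split_speaker(utterances[0])
    second_speaker, _ = split_speaker(utterances[1])
    return first_speaker == "patient" and second_speaker == "doctor"


row = [
    "patient: throat a bit sore and want to get a good imune booster.",
    "doctor: during this pandemic. throat pain can be from a strep throat infection.",
]
print(split_speaker(row[0])[0])     # patient
print(is_patient_doctor_pair(row))  # True
```

On the real data, the check could be applied with `dataset.filter(lambda s: is_patient_doctor_pair(s["utterances"]))` to drop rows that do not follow the pattern.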

In [None]:
dataset["utterances"][0]

['patient: throat a bit sore and want to get a good imune booster, especially in light of the virus. please advise. have not been in contact with nyone with the virus.',
 "doctor: during this pandemic. throat pain can be from a strep throat infection (antibiotics needed), a cold or influenza or other virus, or from some other cause such as allergies or irritants. usually, a person sees the doctor (call first) if the sore throat is bothersome, recurrent, or doesn't go away quickly. covid-19 infections tend to have cough, whereas strep throat usually lacks cough but has more throat pain. (3/21/20)"]

In [None]:
# Structure the data into instruction format

def format_instruction(sample):
    return f"""### Instruction:
Analyze the indicators and symptoms of the patient in the Input. Provide a Response with doctor's advice to cure or alleviate the related sickness or disease.

### Input:
{sample["utterances"][0]}

### Response:
{sample["utterances"][1]}
"""

In [None]:
# View one instruction example

from random import randrange

print(format_instruction(dataset[randrange(len(dataset))]))

### Instruction:
Analyze the indicators and symptoms of the patient in the Input. Provide a Response with doctor's advice to cure or alleviate the related sickness or disease.

### Input:
patient: are children with a respiratory pathology at greater risk for covid-19 (since there are so little cases of children getting the virus, but still their natural history could be an issue)?

### Response:
doctor: be very cautious . there is not a lot of information on children with covid-19 but definitely would try to protect this child especially and consult the doctor asap if child develops fever and shortness of breath.
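The inference cell later re-types this template by hand. If only one copy is edited, the training and inference prompts will drift apart. Here is a minimal sketch of a single builder used for both, assuming the inference prompt should be the training text with the response removed. `INSTRUCTION`, `build_prompt`, and `format_instruction_shared` are hypothetical names, and the sample dialog is made up.

```python
from typing import Optional

INSTRUCTION = (
    "Analyze the indicators and symptoms of the patient in the Input. "
    "Provide a Response with doctor's advice to cure or alleviate the related sickness or disease."
)


def build_prompt(patient_utterance: str, doctor_utterance: Optional[str] = None) -> str:
    """Training text when doctor_utterance is given, inference prompt when it is None."""
    prompt = (
        f"### Instruction:\n{INSTRUCTION}\n\n"
        f"### Input:\n{patient_utterance}\n\n"
        "### Response:\n"
    )
    if doctor_utterance is None:
        return prompt  # the model continues generating after "### Response:\n"
    return prompt + doctor_utterance + "\n"


def format_instruction_shared(sample: dict) -> str:
    """Produces the same text as format_instruction, built on the shared helper."""
    return build_prompt(sample["utterances"][0], sample["utterances"][1])


demo_sample = {"utterances": ["patient: sore throat for two days.", "doctor: rest and fluids."]}
print(format_instruction_shared(demo_sample))
```

At inference time, `build_prompt(sample["utterances"][0])` would then replace the hand-written f-string.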



## Model

In [None]:
# prepare quantization config

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

use_flash_attention = False
model_id = "NousResearch/Llama-2-7b-hf"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
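Compared with fp16, the 4-bit config reduces weight memory roughly fourfold. The back-of-the-envelope estimate below rests on two assumptions: Llama-2-7B has about 6.74B parameters, and quantization constants are counted as in the QLoRA paper (blocks of 64 weights, 256 constants per second-level block). It ignores activations, the LoRA adapters, optimizer state, and any layers kept in higher precision, so treat the results as rough figures for the weights only.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 6.74e9  # approximate Llama-2-7B parameter count (assumption)

# Overhead of the quantization constants, per the QLoRA paper's accounting
# (blocks of 64 weights; 256 first-level constants per second-level block).
overhead_plain = 32 / 64                     # one fp32 absmax constant per 64 weights
overhead_double = 8 / 64 + 32 / (64 * 256)   # 8-bit constants + fp32 second-level constants

fp16_gb = weight_memory_gb(NUM_PARAMS, 16)
nf4_gb = weight_memory_gb(NUM_PARAMS, 4 + overhead_double)
print(f"fp16 weights:          ~{fp16_gb:.1f} GB")
print(f"NF4 + double quant:    ~{nf4_gb:.1f} GB")
print(f"overhead saved by bnb_4bit_use_double_quant: {overhead_plain - overhead_double:.3f} bits/param")
```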

In [None]:
# load model with quantization config

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, use_cache=False, device_map="auto")
model.config.pretraining_tp = 1  # plain linear layers; values > 1 slice them to reproduce pretraining numerics (slower)

In [None]:
# load tokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
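Llama 2's tokenizer ships without a pad token, so this cell reuses EOS for padding and pads on the right. Here is a minimal sketch of what right padding does to a batch. It assumes EOS id 2 (Llama-2's `</s>`) and uses made-up token ids. `pad_right` is a hypothetical helper, not the tokenizer's API. Because the pad id equals the EOS id, only the attention mask tells padding apart from a real end-of-sequence token.

```python
from typing import List, Tuple

EOS_ID = 2  # assumed Llama-2 </s> id, reused as the pad id


def pad_right(batch: List[List[int]], pad_id: int = EOS_ID) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token id lists to equal length and build matching attention masks."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)  # 0 marks padding positions
    return input_ids, attention_mask


ids, mask = pad_right([[1, 450, 3105], [1, 29871]])
print(ids)   # [[1, 450, 3105], [1, 29871, 2]]
print(mask)  # [[1, 1, 1], [1, 1, 0]]
```

With `packing=True` in the trainer below, training sequences are cut to a fixed length, so padding mainly matters in other settings, such as batched generation.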

In [None]:
# LoRA configuration

from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM"
)

# Prepare the 4-bit model for k-bit training, then wrap it with the LoRA adapters

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)
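Because the `LoraConfig` sets no `target_modules`, PEFT picks a default for each model architecture. This sketch assumes that the default for Llama is `q_proj` and `v_proj`, and that Llama-2-7B has hidden size 4096 and 32 decoder layers. Each adapted `d_out x d_in` layer gains a low-rank pair: A (`r x d_in`) and B (`d_out x r`). The update `B @ A` is then scaled by `lora_alpha / r`.

```python
def lora_params_per_layer(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


HIDDEN = 4096                   # Llama-2-7B hidden size (assumption)
NUM_LAYERS = 32                 # Llama-2-7B decoder layers (assumption)
TARGETS = ["q_proj", "v_proj"]  # assumed default LoRA targets for Llama
r, lora_alpha = 64, 16          # values from the LoraConfig above

trainable = NUM_LAYERS * len(TARGETS) * lora_params_per_layer(HIDDEN, HIDDEN, r)
scaling = lora_alpha / r        # factor applied to the B @ A update
print(f"trainable LoRA params: {trainable:,}")  # trainable LoRA params: 33,554,432
print(f"scaling (lora_alpha / r): {scaling}")    # scaling (lora_alpha / r): 0.25
```

The count can be compared with the output of `model.print_trainable_parameters()` on the PEFT model. If the two disagree, the default target modules differ from those assumed here.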

## Train

In [None]:
# trainer config

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="llama-7b-int4-medical",
    num_train_epochs=3,
    per_device_train_batch_size=6 if use_flash_attention else 2,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    logging_steps=10,
    save_strategy="epoch",
    learning_rate=2e-4,
    bf16=False,  # set True on Ampere or newer GPUs
    tf32=False,  # set True on Ampere or newer GPUs
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    disable_tqdm=True
)


# trainer

from trl import SFTTrainer

max_seq_length = 512 # max sequence length for model and packing of the dataset

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=format_instruction,
    args=args,
)
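With `packing=True`, TRL's `SFTTrainer` passes the formatted examples through a `ConstantLengthDataset`. Roughly, that dataset tokenizes the examples, joins them into one stream with an EOS separator after each, and cuts the stream into chunks of exactly `max_seq_length` tokens. The sketch below is a simplified illustration of that idea on made-up token ids, not TRL's implementation. `pack` is a hypothetical helper, and the EOS id 2 is an assumption.

```python
from typing import List

EOS_ID = 2  # separator token id (assumption: Llama-2's </s>)


def pack(tokenized_examples: List[List[int]], seq_length: int, eos_id: int = EOS_ID) -> List[List[int]]:
    """Concatenate examples (each followed by EOS) and cut the stream into full seq_length chunks."""
    stream: List[int] = []
    for ids in tokenized_examples:
        stream.extend(ids + [eos_id])
    chunks = []
    for start in range(0, len(stream), seq_length):
        chunk = stream[start:start + seq_length]
        if len(chunk) == seq_length:  # the trailing partial chunk is dropped
            chunks.append(chunk)
    return chunks


examples = [[10, 11, 12], [20, 21], [30, 31, 32, 33]]
print(pack(examples, seq_length=4))  # [[10, 11, 12, 2], [20, 21, 2, 30], [31, 32, 33, 2]]
```

Packing removes padding, so every token in a chunk contributes to the loss. The trade-off in this simplified scheme is that a chunk can span two unrelated dialogs, and the trailing partial chunk is discarded.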


In [None]:
# train
trainer.train() # there will not be a progress bar since tqdm is disabled

# save model
trainer.save_model()

{'loss': 1.3697, 'learning_rate': 0.0002, 'epoch': 0.08}
{'loss': 1.3267, 'learning_rate': 0.0002, 'epoch': 0.17}
{'loss': 1.3462, 'learning_rate': 0.0002, 'epoch': 0.25}
{'loss': 1.295, 'learning_rate': 0.0002, 'epoch': 0.33}
{'loss': 1.324, 'learning_rate': 0.0002, 'epoch': 0.41}
{'loss': 1.2526, 'learning_rate': 0.0002, 'epoch': 1.07}
{'loss': 1.2283, 'learning_rate': 0.0002, 'epoch': 1.16}
{'loss': 1.2393, 'learning_rate': 0.0002, 'epoch': 1.24}
{'loss': 1.2487, 'learning_rate': 0.0002, 'epoch': 1.32}
{'loss': 1.3067, 'learning_rate': 0.0002, 'epoch': 1.41}
{'loss': 1.1237, 'learning_rate': 0.0002, 'epoch': 2.07}
{'loss': 1.1314, 'learning_rate': 0.0002, 'epoch': 2.15}
{'loss': 1.1837, 'learning_rate': 0.0002, 'epoch': 2.23}
{'loss': 1.1642, 'learning_rate': 0.0002, 'epoch': 2.32}
{'loss': 1.1947, 'learning_rate': 0.0002, 'epoch': 2.4}
{'train_runtime': 6333.8221, 'train_samples_per_second': 0.228, 'train_steps_per_second': 0.057, 'train_loss': 1.2452210102205963, 'epoch': 2.42}
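Every logged `learning_rate` is 0.0002. That fits `lr_scheduler_type="constant"`: the Transformers constant schedule multiplies the base rate by 1 at every step and does not use `warmup_ratio`. By contrast, `"constant_with_warmup"` would first ramp linearly up from 0. Below is a small sketch of both multipliers. It assumes warmup steps are computed as `ceil(total_steps * warmup_ratio)`, and the 100-step run length is hypothetical.

```python
import math

BASE_LR = 2e-4  # learning_rate from the TrainingArguments above


def constant_lr(step: int) -> float:
    """'constant' schedule: the multiplier is 1 at every step, so warmup is never applied."""
    return BASE_LR  # step is intentionally unused


def constant_with_warmup_lr(step: int, warmup_steps: int) -> float:
    """'constant_with_warmup': linear ramp from 0 over warmup_steps, then constant."""
    if step < warmup_steps:
        return BASE_LR * step / max(1, warmup_steps)
    return BASE_LR


total_steps = 100                             # hypothetical run length
warmup_steps = math.ceil(0.03 * total_steps)  # warmup_ratio=0.03 -> 3 steps
for step in (0, 1, 2, 3, 10):
    print(step, constant_lr(step), constant_with_warmup_lr(step, warmup_steps))
```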


In [None]:
!zip -r llama-7b-int4-medical-2.zip llama-7b-int4-medical

  adding: llama-7b-int4-medical/ (stored 0%)
  adding: llama-7b-int4-medical/runs/ (stored 0%)
  adding: llama-7b-int4-medical/runs/Aug26_16-22-51_057cfdc18a22/ (stored 0%)
  adding: llama-7b-int4-medical/runs/Aug26_16-22-51_057cfdc18a22/events.out.tfevents.1693074625.057cfdc18a22.1959.1 (deflated 61%)
  adding: llama-7b-int4-medical/runs/Aug26_16-22-51_057cfdc18a22/events.out.tfevents.1693066972.057cfdc18a22.1959.0 (deflated 60%)
  adding: llama-7b-int4-medical/adapter_config.json (deflated 43%)
  adding: llama-7b-int4-medical/checkpoint-102/ (stored 0%)
  adding: llama-7b-int4-medical/checkpoint-102/trainer_state.json (deflated 75%)
  adding: llama-7b-int4-medical/checkpoint-102/adapter_config.json (deflated 43%)
  adding: llama-7b-int4-medical/checkpoint-102/scheduler.pt (deflated 51%)
  adding: llama-7b-int4-medical/checkpoint-102/training_args.bin (deflated 49%)
  adding: llama-7b-int4-medical/checkpoint-102/special_tokens_map.json (deflated 72%)
  adding: llama-7b-int4-medical/ch

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import locale
# Workaround for Colab shell (!) commands failing with "A UTF-8 locale is required"
locale.getpreferredencoding = lambda: "UTF-8"

In [None]:
!cp llama-7b-int4-medical-2.zip /content/drive/MyDrive

## Inference

In [None]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

args.output_dir = "llama-7b-int4-medical"

# load base LLM model and tokenizer
model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)

In [None]:
sample = val_dataset[randrange(len(val_dataset))]

In [None]:
prompt = f"""### Instruction:
Analyze the indicators and symptoms of the patient in the Input. Provide a Response with doctor's advice to cure or alleviate the related sickness or disease.

### Input:
{sample["utterances"][0]}

### Response:
"""

input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
# generate() already disables gradient tracking internally
outputs = model.generate(input_ids=input_ids, max_new_tokens=512, do_sample=True, top_p=0.3, temperature=0.9)

# Decode only the newly generated tokens (everything after the prompt) instead of slicing the decoded string
generated_ids = outputs[0][input_ids.shape[1]:]

print(f"Prompt:\n{sample['utterances'][0]}\n")
print(f"Generated response:\n{tokenizer.decode(generated_ids, skip_special_tokens=True)}")
print(f"Ground truth:\n{sample['utterances'][1]}")

Prompt:
patient: my 62 year old sister is currently hospitalized for pneumonia that was dx after foot surgery related to a fall. she had breast ca 8 years ago. currently taking tamoxifin. brca ii gene. 3 sisters also with gene mutation and hx of breast ca. youngest sister passed away from recurrance of breast ca. (mets to lungs, liver and bones.) hospital ruled out blood clot in lung. dx with pneumonia. should ct scan be done since she has has a long standing cough and multiple bouts of pneumonia.

Generated response:
doctor: thanks for your question on healthcare magic.i have gone through your query. yes, ct scan of the chest is advisable. it will help you to diagnose the underlying cause of pneumonia. so better to get done ct thorax. hope this clears your query. i am sure you will like my response.thanks.wishing good health to your sister.regards.

Ground truth:
doctor: hello and welcome to ‘ask a doctor’ service. i have reviewed your query and here is my advice. yes, she can safe
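The generation call samples with `temperature=0.9` and `top_p=0.3`. Below is a simplified sketch of nucleus (top-p) sampling. The logits are divided by the temperature and turned into probabilities. Only the smallest set of most likely tokens whose cumulative probability reaches `top_p` is kept, and that set is renormalized. This is an illustration with made-up logits, not the Transformers `TopPLogitsWarper` implementation, whose boundary handling may differ slightly.

```python
import math
import random
from typing import Dict, List


def top_p_filter(logits: List[float], top_p: float, temperature: float) -> Dict[int, float]:
    """Return {token_id: renormalized probability} for the nucleus after temperature scaling."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:  # smallest top set whose mass reaches top_p
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


def sample_next(logits: List[float], top_p: float = 0.3, temperature: float = 0.9, seed: int = 0) -> int:
    """Draw one token id from the nucleus distribution."""
    nucleus = top_p_filter(logits, top_p, temperature)
    rng = random.Random(seed)
    ids = list(nucleus)
    return rng.choices(ids, weights=[nucleus[i] for i in ids], k=1)[0]


toy_logits = [2.0, 1.5, 0.3, -1.0]  # made-up scores for a 4-token vocabulary
print(top_p_filter(toy_logits, top_p=0.3, temperature=0.9))  # {0: 1.0}
print(top_p_filter(toy_logits, top_p=0.9, temperature=0.9))
```

With these toy logits, `top_p=0.3` keeps only the most likely token, so sampling behaves like greedy decoding. A low `top_p` generally makes outputs more conservative.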

In [None]:
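import torch
from peft import AutoPeftModelForCausalLM

# Reload the adapter on an unquantized fp16 base model so the LoRA weights can be merged into dense weights
model = AutoPeftModelForCausalLM.from_pretrained(
    args.output_dir,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)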

# Merge LoRA and base model
merged_model = model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained("merged_model", safe_serialization=True)
tokenizer.save_pretrained("merged_model")

# push merged model to the hub
# merged_model.push_to_hub("user/repo")
# tokenizer.push_to_hub("user/repo")
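`merge_and_unload()` folds each LoRA adapter into its base linear layer, so the saved `merged_model` can be used without PEFT. For one adapted layer, the merge is `W' = W + (lora_alpha / r) * B @ A`. The merged layer gives the same output as running the base layer and the adapter side by side. Below is a toy-sized sketch of that identity in plain Python. The matrices and `r = 1` are made up, and dropout is omitted because it is inactive at inference.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))] for i in range(len(a))]


def add_scaled(a: Matrix, b: Matrix, s: float) -> Matrix:
    """Elementwise a + s * b."""
    return [[x + s * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def merge_lora(W: Matrix, A: Matrix, B: Matrix, lora_alpha: float, r: int) -> Matrix:
    """Fold a LoRA update into the base weight: W' = W + (lora_alpha / r) * B @ A."""
    return add_scaled(W, matmul(B, A), lora_alpha / r)


# Toy sizes (d_out = d_in = 2, r = 1); values are made up for illustration.
LORA_ALPHA, R = 16, 1
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, -0.5]]    # r x d_in
B = [[2.0], [4.0]]   # d_out x r
x = [[3.0], [1.0]]   # column-vector input

adapter_out = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), LORA_ALPHA / R)  # base + adapter path
merged_out = matmul(merge_lora(W, A, B, LORA_ALPHA, R), x)                      # single merged layer
print(adapter_out, merged_out)  # [[35.0], [65.0]] [[35.0], [65.0]]
```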