This notebook is taken directly from https://github.com/tcapelle/llm_recipes/tree/main


# Finetuning Llama-2 to produce BioLlama using HF and WanB

In [1]:
# !pip install wandb transformers trl datasets "protobuf==3.20.3" evaluate
# !wget https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json
from utilities.parse_benchmark import parse_benchmark
# benchmark = "MedQA-4"
# benchmark = "MedQA-5"
# benchmark = "MedMCQA"
benchmark = "PubMedQA"
# benchmark = "bioASQ_with_snippet"
# if benchmark == "PubMedQA":
#     benchmark_questions, benchmark_answers = parse_benchmark(benchmark, "test.json")
# else:
#     benchmark_questions, benchmark_answers = parse_benchmark(benchmark, "train.jsonl")

In [2]:
import wandb
wandb.init(project="biollama_ft", # the project I am working on
           tags=["hf_sft", "BioLlama"]) # the Hyperparameters I want to keep track of

[34m[1mwandb[0m: Currently logged in as: [33mnelectric[0m ([33mneelectric[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: wandb version 0.16.3 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[34m[1mwandb[0m: Tracking run with wandb version 0.16.1


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/home/service/BioLlama/wandb/run-20240305_193753-yclujsr5[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mbreezy-violet-203[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/neelectric/biollama_ft[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/neelectric/biollama_ft/runs/yclujsr5[0m


In [3]:
import os
from datasets import load_dataset
from datasets import Dataset
import json
if benchmark == "PubMedQA":
    artifact_dir = os.getcwd() + "/benchmarks/PubMedQA/edited"
    dataset = load_dataset("json", data_dir=artifact_dir)
else:
    if benchmark == "MedQA-4":
        artifact_dir = os.getcwd() + "/benchmarks/MedQA-4-option/"
    elif benchmark == "MedQA-5":
        artifact_dir = os.getcwd() + "/benchmarks/MedQA-USMLE/"
    elif benchmark == "MedMCQA":
        artifact_dir = os.getcwd() + "/benchmarks/MedMCQA/"
    elif benchmark == "bioASQ_with_snippet":
        print("loading bioASQ")
        artifact_dir = os.getcwd() + "/benchmarks/BioASQ/edited"
    dataset = load_dataset("json", data_dir=artifact_dir)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
dataset

DatasetDict({
    test: Dataset({
        features: ['CONTEXTS', 'QUESTION', 'final_decision'],
        num_rows: 1000
    })
})

In [5]:
if benchmark == "PubMedQA" or benchmark == "bioASQ_with_snippet":
    train_dataset = dataset["test"]
else:
    train_dataset = dataset["train"]
    eval_dataset = dataset["validation"]
    print(len(eval_dataset))
print(len(train_dataset))

1000


In [6]:
from utilities.prompts2 import promptify
# def create_prompt(row):
#     option_string = ""
#     for option in row["options"].keys():
#         option_string += "\n (" + option + ") " + row["options"][option]
#     row["option_string"] = option_string
#     return ("<QUESTION>{question} {option_string}</QUESTION><ANSWER> ({answer_idx}) {answer}</ANSWER>").format_map(row)
def create_prompt(row):
    option_string = ""
    for option in row["options"].keys():
        option_string += "\n (" + option + ") " + row["options"][option]
    MCQ_answer = "(" + row['answer_idx'] + ") " + row["answer"]
    question = row["question"] + option_string
    promptified = promptify(benchmark, question, retrieval_mode = None, retrieved_chunks = None, model = None)
    row["promptified"] = promptified
    row["MCQ_answer"] = MCQ_answer
    return ("{promptified} {MCQ_answer}</ANSWER>").format_map(row)

if benchmark == "MedMCQA":
    def create_prompt(row):
        option_string = "\n(1) " + row['opa']
        option_string += "\n(2) " + row['opb']
        option_string += "\n(3) " + row['opc']
        option_string += "\n(4) " + row['opd']
        row["option_string"] = option_string
        if row['cop'] == 1:
            row['answer'] = row['opa']
        elif row['cop'] == 2:
            row['answer'] = row['opb']
        elif row['cop'] == 3:
            row['answer'] = row['opc']
        elif row['cop'] == 4:
            row['answer'] = row['opd']
        question = row["question"] + option_string
        promptified = promptify(benchmark, question, retrieval_mode = None, retrieved_chunks = None, model = None)
        #replace all occurrences of "{" with "(":
        promptified = promptified.replace("{", "(")
        promptified = promptified.replace("}", ")")
        return (promptified + " {cop}</ANSWER>").format_map(row)
elif benchmark == "PubMedQA":
    def create_prompt(row):
        snippet_string = ""
        for snippet in row["CONTEXTS"]:
            snippet_string += snippet + "\n"
        row["snippet_string"] = snippet_string
        row["example"] = "You start all of your responses with <ANSWER> and end them with </ANSWER>, as shown in the following example:\n<QUESTION>Do mitochondria play a role in remodelling lace plant leaves during programmed cell death?</QUESTION>\n<ANSWER> yes</ANSWER>\nDo not justify your response, respond with only yes, maybe or no.\n"
        return ("Using the following text snippets, answer the question that follows.\n<SNIPPETS>\n{snippet_string}</SNIPPETS>\n{example}<QUESTION>{QUESTION}</QUESTION>\n<ANSWER> {final_decision}</ANSWER>").format_map(row)
elif benchmark == "bioASQ_with_snippet":
    def create_prompt(row):
        question = [row['snippets'], row['question']]
        # print(question)
        promptified = promptify("bioASQ_with_snippet", question, retrieval_mode = None, retrieved_chunks = None, model = None)
        return promptified + " " + row['answer'] + "</ANSWER>"
print(create_prompt(train_dataset[5]))

Using the following text snippets, answer the question that follows.
<SNIPPETS>
From March 2007 to January 2011, 88 DBE procedures were performed on 66 patients. Indications included evaluation anemia/gastrointestinal bleed, small bowel IBD and dilation of strictures. Video-capsule endoscopy (VCE) was used prior to DBE in 43 of the 66 patients prior to DBE evaluation.
The mean age was 62 years. Thirty-two patients were female, 15 were African-American; 44 antegrade and 44 retrograde DBEs were performed. The mean time per antegrade DBE was 107.4±30.0 minutes with a distance of 318.4±152.9 cm reached past the pylorus. The mean time per lower DBE was 100.7±27.3 minutes with 168.9±109.1 cm meters past the ileocecal valve reached. Endoscopic therapy in the form of electrocautery to ablate bleeding sources was performed in 20 patients (30.3%), biopsy in 17 patients (25.8%) and dilation of Crohn's-related small bowel strictures in 4 (6.1%). 43 VCEs with pathology noted were performed prior to

In [7]:
def create_prompt_no_answer(row):
    option_string = ""
    for option in row["options"].keys():
        option_string += "\n (" + option + ") " + row["options"][option]
    row["option_string"] = option_string
    return ("<QUESTION>{question} {option_string}</QUESTION><ANSWER> ").format_map(row)

def return_prompt_no_answer(row):
    return {"text": create_prompt_no_answer(row)}

def return_prompt(row):
    return {"text": create_prompt(row)}
    
if benchmark == "MedQA":
    test_dataset = eval_dataset.map(return_prompt_no_answer)
train_dataset_with_texts = train_dataset.map(return_prompt)
print(train_dataset_with_texts[0]["text"])

Using the following text snippets, answer the question that follows.
<SNIPPETS>
Programmed cell death (PCD) is the regulated death of cells within an organism. The lace plant (Aponogeton madagascariensis) produces perforations in its leaves through PCD. The leaves of the plant consist of a latticework of longitudinal and transverse veins enclosing areoles. PCD occurs in the cells at the center of these areoles and progresses outwards, stopping approximately five cells from the vasculature. The role of mitochondria during PCD has been recognized in animals; however, it has been less studied during PCD in plants.
The following paper elucidates the role of mitochondrial dynamics during developmentally regulated PCD in vivo in A. madagascariensis. A single areole within a window stage leaf (PCD is occurring) was divided into three areas based on the progression of PCD; cells that will not undergo PCD (NPCD), cells in early stages of PCD (EPCD), and cells in late stages of PCD (LPCD). Windo

In [8]:
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM
size = "13"
if size == "7":
    RETRO_layer_ids = [15]
    torch_dtype=torch.float32
elif size == "13":
    RETRO_layer_ids = [19]
    torch_dtype=torch.bfloat16
elif size == "70":
    RETRO_layer_ids = [39]
    torch_dtype=torch.bfloat16

    print("best of luck training 70b lol")
print(f"RETRO_layer_ids is {RETRO_layer_ids} and torch_dtype is {torch_dtype}")

RETRO_layer_ids is [19] and torch_dtype is torch.bfloat16


In [9]:
from utilities.biollama import BioLlama
import torch

amended_questions = ["The main calcium pump of the sarcoplasmic reticulum is "]
# answers = ["Sarcoplasmic reticulum Ca(2+)-ATPase"] # or "SERCA","serca2"
prompt = amended_questions[0]
model_id = "meta-llama/Llama-2-" + size +"b-chat-hf"
chunk_length = 32

BioLlama = BioLlama(
    model_id=model_id,
    chunk_length=chunk_length,
    RETRO_layer_ids=RETRO_layer_ids,
    training=True,
    torch_dtype=torch_dtype)

Loading checkpoint shards:   0%|                                                                  | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:  33%|███████████████████▎                                      | 1/3 [00:02<00:05,  2.91s/it]

Loading checkpoint shards:  67%|██████████████████████████████████████▋                   | 2/3 [00:04<00:02,  2.41s/it]

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:06<00:00,  1.85s/it]

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 3/3 [00:06<00:00,  2.05s/it]




  return self.fget.__get__(instance, owner)()


In [10]:
model = BioLlama.model
tokenizer = BioLlama.tokenizer

In [11]:
print("freezing layers, currently only works for single unfrozen retro layer")
n_freeze = BioLlama.RETRO_layer_ids[0]

# freeze layers (disable gradients)
for param in model.parameters(): 
    param.requires_grad = False
for param in model.lm_head.parameters(): 
    param.requires_grad = True
#for every parameter in retro_layer_params, print where in the model it comes from (ie is it from self attention, layer norm, etc)
print(f"\nprinting layer {n_freeze} params")
for name, param in model.model.layers[n_freeze].named_parameters():
    print(f"{name}, requires_grad = {param.requires_grad}")   

list_of_params_to_unfreeze = [
    "cca_attn.q_proj.weight",
    "cca_attn.k_proj.weight",
    "cca_attn.v_proj.weight",
    "cca_attn.o_proj.weight",
    "pre_cca_layernorm.weight",
]

for name, param in model.model.layers[n_freeze].named_parameters(): 
    if name in list_of_params_to_unfreeze:
        param.requires_grad = True
print(f"\nprinting layer {n_freeze} params")
for name, param in model.model.layers[n_freeze].named_parameters():
    print(f"{name}, requires_grad = {param.requires_grad}")   

freezing layers, currently only works for single unfrozen retro layer

printing layer 19 params
self_attn.q_proj.weight, requires_grad = False
self_attn.k_proj.weight, requires_grad = False
self_attn.v_proj.weight, requires_grad = False
self_attn.o_proj.weight, requires_grad = False
mlp.gate_proj.weight, requires_grad = False
mlp.up_proj.weight, requires_grad = False
mlp.down_proj.weight, requires_grad = False
input_layernorm.weight, requires_grad = False
post_attention_layernorm.weight, requires_grad = False
cca_attn.q_proj.weight, requires_grad = False
cca_attn.k_proj.weight, requires_grad = False
cca_attn.v_proj.weight, requires_grad = False
cca_attn.o_proj.weight, requires_grad = False
pre_cca_layernorm.weight, requires_grad = False

printing layer 19 params
self_attn.q_proj.weight, requires_grad = False
self_attn.k_proj.weight, requires_grad = False
self_attn.v_proj.weight, requires_grad = False
self_attn.o_proj.weight, requires_grad = False
mlp.gate_proj.weight, requires_grad = F

In [12]:
for name, param in model.model.named_parameters(): 
    print(f"{name}, requires_grad = {param.requires_grad}")
    # param.requires_grad = True

BioLlama.model.train()

embed_tokens.weight, requires_grad = False
layers.0.self_attn.q_proj.weight, requires_grad = False
layers.0.self_attn.k_proj.weight, requires_grad = False
layers.0.self_attn.v_proj.weight, requires_grad = False
layers.0.self_attn.o_proj.weight, requires_grad = False
layers.0.mlp.gate_proj.weight, requires_grad = False
layers.0.mlp.up_proj.weight, requires_grad = False
layers.0.mlp.down_proj.weight, requires_grad = False
layers.0.input_layernorm.weight, requires_grad = False
layers.0.post_attention_layernorm.weight, requires_grad = False
layers.1.self_attn.q_proj.weight, requires_grad = False
layers.1.self_attn.k_proj.weight, requires_grad = False
layers.1.self_attn.v_proj.weight, requires_grad = False
layers.1.self_attn.o_proj.weight, requires_grad = False
layers.1.mlp.gate_proj.weight, requires_grad = False
layers.1.mlp.up_proj.weight, requires_grad = False
layers.1.mlp.down_proj.weight, requires_grad = False
layers.1.input_layernorm.weight, requires_grad = False
layers.1.post_attenti

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120)
    (layers): ModuleList(
      (0-18): 19 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
      (19): LlamaDecoderLayer(
        (self_attn

In [13]:
# Just freeze embeddings for small memory decrease
model.model.embed_tokens.weight.requires_grad_(False);

In [14]:
def param_count(m):
    params = sum([p.numel() for p in m.parameters()])/1_000_000
    trainable_params = sum([p.numel() for p in m.parameters() if p.requires_grad])/1_000_000
    print(f"Total params: {params:.2f}M, Trainable: {trainable_params:.2f}M")
    return params, trainable_params

params, trainable_params = param_count(model)

Total params: 13120.73M, Trainable: 268.70M


In [15]:
batch_size = 2

total_num_steps = 11_210 // batch_size
print(total_num_steps)

if benchmark == "MedQA-4" or benchmark == "MedQA-5":
    total_num_steps = 10178
elif benchmark == "MedMCQA":
    total_num_steps = 100000
elif benchmark == "PubMedQA":
    total_num_steps = 1000 * 2
elif benchmark == "bioASQ_with_snippet":
    total_num_steps = 486 
print(f"changing total num size to {total_num_steps}")
print(benchmark)

5605
changing total num size to 2000
PubMedQA


In [16]:
from transformers import TrainingArguments
output_dir = "/home/service/BioLlama/utilities/finetuning/biollama_training_output/" + size  + "/"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size//2,
    bf16=True,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=total_num_steps // 10,
    num_train_epochs=2,
    max_steps=total_num_steps,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    # evaluation_strategy="steps",
    # eval_steps=5000,
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=8,
    save_strategy="epoch", #changed to epoch so we save every epoch i guess?
    save_total_limit=1,
)

In [17]:
from trl import SFTTrainer
import trl
# from utilities.finetuning.sft_trainer import SFTTrainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset_with_texts,
    dataset_text_field="text",
    # eval_dataset=test_dataset,
    packing=True,
    max_seq_length=1024,
    args=training_args,
    formatting_func=create_prompt,
    # compute_metrics=token_accuracy,
)



In [18]:
#very hacky but maybe this will work:
tokenizer.model_input_names = ['labels', 'input_ids', 'attention_mask']
# trainer.args.train_batch_size = 1
# self.args.train_batch_size

#also hacky, but could work:
tokenizer.pad_token = tokenizer.eos_token
print("Starting training")
trainer.train()
wandb.finish()

Starting training






  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss
8,1.7902
16,1.7314
24,1.7545
32,1.7369
40,1.6591
48,1.6162
56,1.6044
64,1.5721
72,1.4918
80,1.4795














































































































































































































































































































































































































































































































































































































































































































































































































































































































































In [None]:
benchmark

In [None]:
print(size)
output_dir = "/home/service/BioLlama/utilities/finetuning/biollama_training_output/" + benchmark + "/" + size  + "/"
print(RETRO_layer_ids)

In [None]:
import os
print(os.path.abspath(output_dir))
trainer.save_model(output_dir)
# !ls -l $output_dir

In [None]:
#load this local model here and use it to generate some text
print(output_dir)

# from transformers import AutoModelForCausalLM, AutoTokenizer
# import time
# import torch
# from utilities.biollama import BioLlama

# chunk_length = 32

# BioLlama = BioLlama(model_id=output_dir, 
#     chunk_length=chunk_length, 
#     RETRO_layer_ids = RETRO_layer_ids, 
#     training=False, 
#     torch_dtype=torch.float32)

In [None]:
BioLlama.training = False
import time
prompt  = '<QUESTION>A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient? \n (A) Ampicillin\n (B) Ceftriaxone\n (C) Ciprofloxacin\n (D) Doxycycline\n (E) Nitrofurantoin</QUESTION>\n<ANSWER> '
time_before_generation = time.time()
num_tokens, text = BioLlama.generate(prompt=prompt, max_new_tokens=50)
time_after = time.time()

print("***Generating***")
print(text)
print(f"Time taken for generation: {time_after - time_before_generation}")
print(f"Tokens per second: {num_tokens/(time_after - time_before_generation)}")

In [None]:
prompt2 = '<QUESTION>A 3-month-old baby died suddenly at night while asleep. His mother noticed that he had died only after she awoke in the morning. No cause of death was determined based on the autopsy. Which of the following precautions could have prevented the death of the baby? \n (A) Placing the infant in a supine position on a firm mattress while sleeping\n (B) Routine postnatal electrocardiogram (ECG)\n (C) Keeping the infant covered and maintaining a high room temperature\n (D) Application of a device to maintain the sleeping position\n (E) Avoiding pacifier use during sleep</QUESTION>\n<ANSWER> '
time_before_generation = time.time()
num_tokens, text = BioLlama.generate(prompt=prompt2, max_new_tokens=50)
time_after = time.time()

print("***Generating***")
print(text)
print(f"Time taken for generation: {time_after - time_before_generation}")
print(f"Tokens per second: {num_tokens/(time_after - time_before_generation)}")

In [None]:
prompt3 = "<QUESTION>A mother brings her 3-week-old infant to the pediatrician's office because she is concerned about his feeding habits. He was born without complications and has not had any medical problems up until this time. However, for the past 4 days, he has been fussy, is regurgitating all of his feeds, and his vomit is yellow in color. On physical exam, the child's abdomen is minimally distended but no other abnormalities are appreciated. Which of the following embryologic errors could account for this presentation? \n (A) Abnormal migration of ventral pancreatic bud\n (B) Complete failure of proximal duodenum to recanalize\n (C) Error in neural crest cell migration\n (D) Abnormal hypertrophy of the pylorus\n (E) Failure of lateral body folds to move ventrally and fuse in the midline</QUESTION>\n<ANSWER> "
time_before_generation = time.time()
num_tokens, text = BioLlama.generate(prompt=prompt3, max_new_tokens=50)
time_after = time.time()

print("***Generating***")
print(text)
print(f"Time taken for generation: {time_after - time_before_generation}")
print(f"Tokens per second: {num_tokens/(time_after - time_before_generation)}")

In [None]:
prompt4 = "<QUESTION>A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? \n (A) Factor V Leiden\n (B) Hemophilia A\n (C) Lupus anticoagulant\n (D) Protein C deficiency\n (E) Von Willebrand disease</QUESTION>\n<ANSWER> "
time_before_generation = time.time()
num_tokens, text = BioLlama.generate(prompt=prompt4, max_new_tokens=50)
time_after = time.time()

print("***Generating***")
print(text)
print(f"Time taken for generation: {time_after - time_before_generation}")
print(f"Tokens per second: {num_tokens/(time_after - time_before_generation)}")