This notebook is taken directly from https://github.com/tcapelle/llm_recipes/tree/main

# From Llama to Alpaca: Fine-tuning an LLM with Weights & Biases
In this notebook you will learn how to fine-tune a pretrained Llama model on an instruction dataset. We will use an updated version of the Alpaca dataset that uses GPT-4 generations instead of davinci-003 (GPT-3) ones, which gives an even better instruction dataset. More details are on the [official repo page](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM#how-good-is-the-data).

> This notebook requires an A100/A10 GPU with at least 24 GB of memory. You could scale the parameters down and run on a T4, but it would take a very long time.

This notebook has a companion project and [report](https://wandb.me/alpaca).

In [1]:
# !pip install wandb transformers trl datasets "protobuf==3.20.3" evaluate

## With Hugging Face TRL

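Let's grab the dataset. The original recipe downloads the Alpaca (GPT-4 curated instructions and outputs) data with the commented-out `wget` below; here we load the MedQA-USMLE benchmark instead: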

In [2]:
# !wget https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json
# from uparse_benchmark import parse_benchmark
# from ..utilities.parse_benchmark import parse_benchmark
from utilities.parse_benchmark import parse_benchmark

benchmark = "MedQA"
benchmark_questions, benchmark_answers = parse_benchmark(benchmark)
# print(benchmark_questions[0])
# print(benchmark_answers[0])

Loading Benchmark from MedQA-USMLE/US/test.jsonl
Benchmark contains 1273 questions, made up of 1273 with 5 options and 0 with non-5 options
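`parse_benchmark` lives in the project's `utilities` package and its code isn't shown here. As a rough, hypothetical stand-in, the sketch below parses MedQA-style JSON Lines, assuming each record carries the `question`, `options`, `answer_idx` and `answer` fields that the prompt-building code in this notebook reads, and counts 5-option records the same way the log line above reports them.

```python
import json
from typing import Iterable, List, Tuple


def parse_jsonl_benchmark(lines: Iterable[str]) -> Tuple[List[dict], List[str], int, int]:
    """Hypothetical stand-in for utilities.parse_benchmark (not its real code).

    Returns (questions, answer letters, #records with 5 options, #records with other counts).
    """
    questions, answers = [], []
    n_five, n_other = 0, 0
    for line in lines:
        line = line.strip()
        if not line:  # tolerate blank lines, e.g. a trailing newline
            continue
        record = json.loads(line)
        questions.append({"question": record["question"], "options": record["options"]})
        answers.append(record["answer_idx"])
        if len(record["options"]) == 5:
            n_five += 1
        else:
            n_other += 1
    return questions, answers, n_five, n_other


# Tiny made-up records in the assumed format; a real run would pass an open file,
# e.g. `with open(path_to_jsonl) as f: parse_jsonl_benchmark(f)`
sample = [
    json.dumps({"question": "Q1?", "options": {"A": "a", "B": "b", "C": "c", "D": "d", "E": "e"},
                "answer_idx": "E", "answer": "e"}),
    json.dumps({"question": "Q2?", "options": {"A": "a", "B": "b", "C": "c", "D": "d"},
                "answer_idx": "A", "answer": "a"}),
]
qs, ans, n5, n_other = parse_jsonl_benchmark(sample)
print(f"Benchmark contains {len(qs)} questions, made up of {n5} with 5 options "
      f"and {n_other} with non-5 options")
```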


In [3]:
import wandb
wandb.init(project="biollama_ft",  # W&B project this run is logged to
           tags=["hf_sft", "BioLlama"])  # tags for filtering runs in the W&B UI
# artifact = wandb.use_artifact('Neelectric/MedQA-USMLE', type='dataset')
# artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mnelectric[0m ([33mneelectric[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: wandb version 0.16.3 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[34m[1mwandb[0m: Tracking run with wandb version 0.16.1


[34m[1mwandb[0m: Run data is saved locally in [35m[1m/home/service/BioLlama/wandb/run-20240208_000825-0funche6[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mdeep-monkey-111[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/neelectric/biollama_ft[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/neelectric/biollama_ft/runs/0funche6[0m


In [4]:
import os
# print(artifact_dir)
artifact_dir = os.getcwd() + "/benchmarks/MedQA-USMLE/"
from datasets import load_dataset
#dataset = load_dataset("Neelectric/MedQA-USMLE")
medqa = load_dataset("json", data_dir=artifact_dir)



In [5]:
#trying gsutil for SciFive pretraining corpus
# !pip install gsutil
import pandas as pd
import numpy as np
# abs_1_16 = pd.read_csv("abs_1_16.tsv", sep='\t')
# abs_1_30 = pd.read_csv("abs_1_30.tsv", sep='\t')

In [6]:
# abs_1_16 = abs_1_16.dropna()
# count_nans = abs_1_16.iloc[:, 0].isna().sum()

Let's also log the dataset as a table so we can inspect it in the W&B workspace.

In [7]:
train_dataset = medqa["train"]
eval_dataset = medqa["validation"]
#print sizes
print(len(train_dataset))
print(len(eval_dataset))
# turn both of these into only half their size
# train_dataset = train_dataset.select(range(0, len(train_dataset)//2))
# eval_dataset = eval_dataset.select(range(0, len(eval_dataset)//2))

# print(len(train_dataset))
# print(len(eval_dataset))

10178
1272
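The cell above only prints the split sizes; it doesn't log a table yet. A minimal sketch of the flattening step, assuming the MedQA row layout used throughout this notebook (a `question`, an `options` dict keyed by letter, `answer_idx`, `answer`): each row becomes a flat list with one column per option letter, which is the shape `wandb.Table(columns=..., data=...)` accepts. The helper name `medqa_table` is hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

OPTION_LETTERS = "ABCDE"  # assumption: MedQA-USMLE rows use options A-E


def medqa_table(records: Sequence[Dict], letters: str = OPTION_LETTERS) -> Tuple[List[str], List[list]]:
    """Flatten MedQA-style rows into (columns, data) for a tabular logger."""
    columns = ["question", *[f"option_{letter}" for letter in letters], "answer_idx", "answer"]
    data = []
    for record in records:
        options = record["options"]
        # missing letters become empty cells so every row has the same width
        data.append([record["question"], *[options.get(letter, "") for letter in letters],
                     record["answer_idx"], record["answer"]])
    return columns, data


example = {"question": "Q?", "options": {"A": "a", "B": "b", "C": "c", "D": "d", "E": "e"},
           "answer_idx": "E", "answer": "e"}
columns, data = medqa_table([example])
print(columns)
print(data[0])
```

In the notebook this could be fed with `train_dataset.select(range(500))` (iterating a `datasets.Dataset` yields row dicts, and a subset keeps the table light) and logged with `wandb.log({"medqa_train": wandb.Table(columns=columns, data=data)})`.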


In [8]:
def create_prompt(row):
    option_string = ""
    for option in row["options"].keys():
        option_string += "\n (" + option + ") " + row["options"][option]
    row["option_string"] = option_string
    return ("<QUESTION>{question} {option_string}</QUESTION>\n<ANSWER> ({answer_idx}) {answer}</ANSWER>").format_map(row)
create_prompt(train_dataset[4])

"<QUESTION>A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? \n (A) Factor V Leiden\n (B) Hemophilia A\n (C) Lupus anticoagulant\n (D) Protein C deficiency\n (E) Von Willebrand disease</QUESTION>\n<ANSWER> (E) Von Willebrand disease</ANSWER>"
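`create_prompt` and `create_prompt_no_answer` share everything up to `<ANSWER> `, and for generation it matters that the answer-free prompt is an exact prefix of the training text. A small sketch of the same template in one function, which also avoids writing an extra `option_string` key into the row dict; `format_options` and `build_prompt` are hypothetical names, not part of the original recipe.

```python
def format_options(options: dict) -> str:
    # one "\n (X) text" line per option, in the dict's insertion order (as in create_prompt)
    return "".join(f"\n ({letter}) {text}" for letter, text in options.items())


def build_prompt(row: dict, with_answer: bool = True) -> str:
    """Render the <QUESTION>/<ANSWER> template without mutating `row`."""
    prompt = f"<QUESTION>{row['question']} {format_options(row['options'])}</QUESTION>\n<ANSWER> "
    if with_answer:
        prompt += f"({row['answer_idx']}) {row['answer']}</ANSWER>"
    return prompt


row = {"question": "Which is largest?", "options": {"A": "1", "B": "3"},
       "answer_idx": "B", "answer": "3"}
print(build_prompt(row))
print(build_prompt(row, with_answer=False))
```

With `datasets`, `train_dataset.map(lambda r: {"text": build_prompt(r)})` should give the same `text` column as `return_prompt`.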

In [9]:
def create_prompt_no_answer(row):
    option_string = ""
    for option in row["options"].keys():
        option_string += "\n (" + option + ") " + row["options"][option]
    row["option_string"] = option_string
    return ("<QUESTION>{question} {option_string}</QUESTION>\n<ANSWER> ").format_map(row)

def return_prompt_no_answer(row):
    return {"text": create_prompt_no_answer(row)}

def return_prompt(row):
    return {"text": create_prompt(row)}
    
test_dataset = eval_dataset.map(return_prompt_no_answer)
# print(test_dataset[0]["text"])
train_dataset_with_texts = train_dataset.map(return_prompt)
# print(train_dataset_with_texts[0]["text"])

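Training the full model is expensive; if you have a GPU that can fit it, you can skip this part. The original recipe trains only the last 8 of Llama-2-7B's 32 decoder layers. Here, instead, layer 15 is wrapped with a RETRO chunked cross-attention (CCA) block, and only that layer's CCA and MLP weights plus the LM head are trained.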

In [10]:
from utilities.biollama import BioLlama

# questions = ["Which is the main calcium pump of the sarcoplasmic reticulum? Answer:"]
amended_questions = ["The main calcium pump of the sarcoplasmic reticulum is "]
questions = amended_questions
# answers = ["Sarcoplasmic reticulum Ca(2+)-ATPase"] # or "SERCA","serca2"

prompt = questions[0]
# model_id = "TheBloke/Llama-2-7b-chat-GPTQ"
model_id = "meta-llama/Llama-2-7b-chat-hf"
chunk_length = 32

RETRO_layer_ids = [15]

BioLlama = BioLlama(
    model_id=model_id,
    chunk_length=chunk_length,
    RETRO_layer_ids=RETRO_layer_ids,
    training=True,
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.39s/it]

Wrapping layer 15 with retro



In [11]:
# print(BioLlama.model)
model = BioLlama.model
tokenizer = BioLlama.tokenizer

In [12]:
print("freezing layers, currently only works for single unfrozen retro layer")
# index of the RETRO-wrapped layer to train (despite the name, this layer gets unfrozen)
n_freeze = BioLlama.RETRO_layer_ids[0]
# n_freeze = 15

# freeze every parameter (disable gradients), then re-enable the LM head
for param in model.parameters():
    param.requires_grad = False
for param in model.lm_head.parameters():
    param.requires_grad = True
# print each parameter of the preceding plain layer and of the RETRO layer,
# to see which submodule (self-attention, MLP, layer norm, CCA) it comes from
print(f"printing layer {n_freeze - 1} params")
for name, param in model.model.layers[n_freeze - 1].named_parameters():
    print(f"{name}, requires_grad = {param.requires_grad}")
print(f"\nprinting layer {n_freeze} params")
for name, param in model.model.layers[n_freeze].named_parameters():
    print(f"{name}, requires_grad = {param.requires_grad}")

# parameter names of the RETRO layer to unfreeze, matched by exact string equality
list_of_params_to_unfreeze = [
    "CCA.pre_CCA_layernorm.weight",
    "layer.CCA_attn.q_proj.weight",
    "layer.CCA_attn.k_proj.weight",
    "layer.CCA_attn.v_proj.weight",
    "layer.CCA_attn.o_proj.weight",
    # "layer.post_attention_layernorm.weight",
    "layer.mlp.gate_proj.weight",
    "layer.mlp.up_proj.weight",
    "layer.mlp.down_proj.weight",
]

for name, param in model.model.layers[n_freeze].named_parameters():
    if name in list_of_params_to_unfreeze:
        param.requires_grad = True
print(f"\nprinting layer {n_freeze} params")
for name, param in model.model.layers[n_freeze].named_parameters():
    print(f"{name}, requires_grad = {param.requires_grad}")

freezing layers, currently only works for single unfrozen retro layer
printing layer 14 params
self_attn.q_proj.weight, requires_grad = False
self_attn.k_proj.weight, requires_grad = False
self_attn.v_proj.weight, requires_grad = False
self_attn.o_proj.weight, requires_grad = False
mlp.gate_proj.weight, requires_grad = False
mlp.up_proj.weight, requires_grad = False
mlp.down_proj.weight, requires_grad = False
input_layernorm.weight, requires_grad = False
post_attention_layernorm.weight, requires_grad = False

printing layer 15 params
layer.self_attn.q_proj.weight, requires_grad = False
layer.self_attn.k_proj.weight, requires_grad = False
layer.self_attn.v_proj.weight, requires_grad = False
layer.self_attn.o_proj.weight, requires_grad = False
layer.mlp.gate_proj.weight, requires_grad = False
layer.mlp.up_proj.weight, requires_grad = False
layer.mlp.down_proj.weight, requires_grad = False
layer.input_layernorm.weight, requires_grad = False
layer.post_attention_layernorm.weight, requires_
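The unfreezing loop matches names exactly, so an entry in `list_of_params_to_unfreeze` with a wrong prefix (the list mixes `CCA.` and `layer.` prefixes) would be skipped without any error, and the listing above is cut off before the CCA parameters. A minimal check, written over plain name lists so it runs without a model; in the notebook the available names would come from `[n for n, _ in model.model.layers[n_freeze].named_parameters()]`. The helper `check_requested_params` is hypothetical.

```python
from typing import Iterable, List, Tuple


def check_requested_params(available: Iterable[str],
                           requested: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Split `requested` parameter names into (found, missing) against `available`."""
    available_set = set(available)
    found, missing = [], []
    for name in requested:
        if name in available_set:
            found.append(name)
        else:
            missing.append(name)
    return found, missing


# two names from the layer-15 listing above, plus the CCA layer-norm name
# (assumed here to exist, since the listing is truncated before it)
available = [
    "layer.self_attn.q_proj.weight",
    "layer.mlp.gate_proj.weight",
    "CCA.pre_CCA_layernorm.weight",
]
found, missing = check_requested_params(
    available,
    ["layer.mlp.gate_proj.weight", "CCA.pre_CCA_layernorm.weight", "layer.mlp.typo.weight"],
)
if missing:
    print(f"not found, will stay frozen: {missing}")
```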

In [13]:
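# Freeze the input embeddings. The freezing cell above already disabled gradients for
# everything except the LM head and the selected layer-15 weights, so this is a safeguard.
model.model.embed_tokens.weight.requires_grad_(False);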

In [14]:
def param_count(m):
    params = sum([p.numel() for p in m.parameters()])/1_000_000
    trainable_params = sum([p.numel() for p in m.parameters() if p.requires_grad])/1_000_000
    print(f"Total params: {params:.2f}M, Trainable: {trainable_params:.2f}M")
    return params, trainable_params

params, trainable_params = param_count(model)

Total params: 6805.53M, Trainable: 333.45M
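Both totals can be reproduced by hand from the published Llama-2-7B shapes (hidden size 4096, MLP size 11008, vocabulary 32000, 32 layers, untied input and output embeddings), assuming the CCA block adds four bias-free 4096×4096 projections plus one 4096-wide norm weight. Under that assumption the trainable count is exactly the LM head, layer 15's MLP and the CCA weights, which suggests every entry in the unfreeze list matched a real parameter.

```python
# Published Llama-2-7B shapes; the CCA shapes are an assumption about BioLlama's RETRO block
HIDDEN = 4096
INTERMEDIATE = 11008
VOCAB = 32000
N_LAYERS = 32

embed = VOCAB * HIDDEN               # input embeddings (frozen)
lm_head = VOCAB * HIDDEN             # output projection (trainable, not tied to embed)
attn = 4 * HIDDEN * HIDDEN           # q, k, v, o projections, no bias
mlp = 3 * HIDDEN * INTERMEDIATE      # gate, up, down projections
norms = 2 * HIDDEN                   # input + post-attention RMSNorm weights
per_layer = attn + mlp + norms

base_total = embed + lm_head + N_LAYERS * per_layer + HIDDEN  # + final RMSNorm
cca = 4 * HIDDEN * HIDDEN + HIDDEN   # assumed CCA q/k/v/o + pre_CCA_layernorm

total = base_total + cca
trainable = lm_head + mlp + cca      # LM head + layer-15 MLP + CCA block

print(f"Llama-2-7B: {base_total:,}")
print(f"Total params: {total / 1e6:.2f}M, Trainable: {trainable / 1e6:.2f}M")
# → Total params: 6805.53M, Trainable: 333.45M
```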


In [15]:
batch_size = 2

total_num_steps = 11_210 // batch_size
print(total_num_steps)


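# Override. With max_steps commented out in TrainingArguments, this value only
# sets warmup_steps and eval_steps; the training length comes from num_train_epochs.
total_num_steps = 10000
print(f"changing total num steps to {total_num_steps}")

5605
changing total num steps to 10000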
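Because the trainer below uses `packing=True` with `max_seq_length=1024`, one optimizer step consumes `batch_size` packed 1024-token sequences rather than `batch_size` questions, so the real number of steps per epoch depends on the token count of the training texts, not on the 10,178 rows. A rough estimator, assuming one EOS separator per example and that incomplete trailing chunks are dropped; the exact count depends on how TRL buffers the stream, so treat the result as approximate. `estimate_packed_steps` is a hypothetical helper.

```python
import math


def estimate_packed_steps(total_tokens: int, n_examples: int, seq_len: int = 1024,
                          batch_size: int = 2, grad_accum: int = 1, epochs: int = 10):
    """Rough (steps_per_epoch, total_steps) for packed training; an approximation."""
    stream_len = total_tokens + n_examples        # + one EOS separator per example
    n_sequences = stream_len // seq_len           # only full-length chunks are kept
    steps_per_epoch = math.ceil(n_sequences / (batch_size * grad_accum))
    return steps_per_epoch, steps_per_epoch * epochs


# total_tokens is a made-up placeholder; in the notebook it could be measured with
# sum(len(ids) for ids in tokenizer(train_dataset_with_texts["text"])["input_ids"])
steps_per_epoch, total_steps = estimate_packed_steps(total_tokens=2_000_000, n_examples=10_178)
print(steps_per_epoch, total_steps)
```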


In [16]:
from trl import SFTTrainer
from transformers import TrainingArguments
output_dir = "/home/service/BioLlama/utilities/finetuning/biollama_training_output/"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size//2,
    bf16=True,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=total_num_steps // 10,  # 1000 steps when total_num_steps = 10000
    num_train_epochs=10,
    # max_steps=total_num_steps,  # while commented out, num_train_epochs sets the training length
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    eval_steps=total_num_steps // 6,  # every 1666 steps when total_num_steps = 10000
    # logging strategies
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=8,
    save_strategy="epoch",  # save a checkpoint at the end of every epoch
)
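`lr_scheduler_type="cosine"` together with `warmup_steps` gives a linear warmup followed by a half-cosine decay to zero. A plain-Python sketch of that curve, mirroring the formula behind `transformers.get_cosine_schedule_with_warmup` with its default half cycle; `total_steps` stands for whatever the Trainer ends up running, which, with `max_steps` commented out, comes from `num_train_epochs` rather than `total_num_steps`.

```python
import math


def cosine_with_warmup_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warmup, then cosine decay from base_lr to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


base_lr, warmup = 2e-4, 1000   # learning_rate and total_num_steps // 10 from the cell above
total = 10_000                 # placeholder; the real value is epochs x steps per epoch
for s in (0, 500, 1000, 5500, 10_000):
    print(s, f"{cosine_with_warmup_lr(s, base_lr, warmup, total):.2e}")
```

If the real run is much longer than `total_num_steps`, the 1000 warmup steps become a small fraction of it, which is usually harmless.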

In [17]:
# print("CREATING A TEMPORARY COPY OF TRAIN DATASET TRUNCATED FROM 9240 ONWARDS IN HOPE OF FINDING CULPRIT")
# slicing a datasets.Dataset (ds[9240:]) returns a dict of columns, not a Dataset;
# .select() returns a new Dataset containing only the chosen rows
# temp_dataset = train_dataset_with_texts.select(range(9240, len(train_dataset_with_texts)))

In [18]:
# from utils import LLMSampleCB, token_accuracy
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset_with_texts,
    dataset_text_field="text",
    eval_dataset=test_dataset,
    packing=True,
    max_seq_length=1024,
    args=training_args,
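    # note: TRL uses formatting_func instead of dataset_text_field when both are given,
    # so the eval rows are also rendered by create_prompt (i.e. with answers)
    formatting_func=create_prompt,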
    # compute_metrics=token_accuracy,
)
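With `packing=True`, TRL does not pad each question to `max_seq_length`; it concatenates tokenized examples into one stream, separated by an EOS token, and cuts that stream into fixed-length chunks, so one training sequence can span the end of one question and the start of the next. A toy sketch of the idea on integer token ids (the ids and `eos_id` are made up, and `pack_sequences` is a hypothetical helper); the real implementation lives in TRL's `ConstantLengthDataset` and also handles buffering.

```python
from typing import Iterable, List


def pack_sequences(tokenized: Iterable[List[int]], seq_len: int, eos_id: int) -> List[List[int]]:
    """Concatenate token lists with an EOS separator and cut into full seq_len chunks."""
    stream: List[int] = []
    for ids in tokenized:
        stream.extend(ids)
        stream.append(eos_id)  # separator marking where one example ends
    # keep only complete chunks; a trailing partial chunk is dropped
    return [stream[i:i + seq_len] for i in range(0, len(stream) - seq_len + 1, seq_len)]


examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
chunks = pack_sequences(examples, seq_len=4, eos_id=0)
print(chunks)
```

The trade-off is that attention is not masked at the separators, so later tokens in a chunk can attend to a previous, unrelated question; packing accepts this in exchange for no padding waste.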



In [19]:
len(train_dataset_with_texts)
# train_dataset_with_texts is a datasets.Dataset; we want a copy consisting
# only of the items from 9240 onwards
print(type(train_dataset_with_texts))

<class 'datasets.arrow_dataset.Dataset'>


In [None]:
# hacky workaround: list 'labels' among the tokenizer's model input names
tokenizer.model_input_names = ['labels', 'input_ids', 'attention_mask']
# trainer.args.train_batch_size = 1
# self.args.train_batch_size

# Llama-2's tokenizer defines no pad token, so reuse EOS for padding (a common workaround)
tokenizer.pad_token = tokenizer.eos_token
print("Starting training")
trainer.train()
wandb.finish()

In [None]:
import os
print(os.path.abspath(output_dir))

In [None]:
trainer.save_model(output_dir)
#print contents of output_dir
!ls -l $output_dir
#print full path of output_dir
# !pwd $output_dir

In [None]:
#load this local model here and use it to generate some text
output_dir = "/home/service/BioLlama/utilities/finetuning/biollama_training_output/"
print(output_dir)

from transformers import AutoModelForCausalLM, AutoTokenizer
import time
from utilities.biollama import BioLlama
#answers = ["Sarcoplasmic reticulum Ca(2+)-ATPase"] # or "SERCA","serca2"

chunk_length = 32

BioLlama = BioLlama(model_id=output_dir, chunk_length=chunk_length, RETRO_layer_ids = [15], training=False)
# num_tokens, text = BioLlama.generate(prompt=prompt, max_new_tokens=35)

# new_tokenizer = AutoTokenizer.from_pretrained(output_dir)
# new_model = AutoModelForCausalLM.from_pretrained(output_dir, device_map="auto")
prompt  = '<QUESTION>A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient? \n (A) Ampicillin\n (B) Ceftriaxone\n (C) Ciprofloxacin\n (D) Doxycycline\n (E) Nitrofurantoin</QUESTION>\n<ANSWER> '
prompt2 = '<QUESTION>A 3-month-old baby died suddenly at night while asleep. His mother noticed that he had died only after she awoke in the morning. No cause of death was determined based on the autopsy. Which of the following precautions could have prevented the death of the baby? \n (A) Placing the infant in a supine position on a firm mattress while sleeping\n (B) Routine postnatal electrocardiogram (ECG)\n (C) Keeping the infant covered and maintaining a high room temperature\n (D) Application of a device to maintain the sleeping position\n (E) Avoiding pacifier use during sleep</QUESTION>\n<ANSWER> '
prompt3 = "<QUESTION>A mother brings her 3-week-old infant to the pediatrician's office because she is concerned about his feeding habits. He was born without complications and has not had any medical problems up until this time. However, for the past 4 days, he has been fussy, is regurgitating all of his feeds, and his vomit is yellow in color. On physical exam, the child's abdomen is minimally distended but no other abnormalities are appreciated. Which of the following embryologic errors could account for this presentation? \n (A) Abnormal migration of ventral pancreatic bud\n (B) Complete failure of proximal duodenum to recanalize\n (C) Error in neural crest cell migration\n (D) Abnormal hypertrophy of the pylorus\n (E) Failure of lateral body folds to move ventrally and fuse in the midline</QUESTION>\n<ANSWER> "
prompt4 = "<QUESTION>A 20-year-old woman presents with menorrhagia for the past several years. She says that her menses “have always been heavy”, and she has experienced easy bruising for as long as she can remember. Family history is significant for her mother, who had similar problems with bruising easily. The patient's vital signs include: heart rate 98/min, respiratory rate 14/min, temperature 36.1°C (96.9°F), and blood pressure 110/87 mm Hg. Physical examination is unremarkable. Laboratory tests show the following: platelet count 200,000/mm3, PT 12 seconds, and PTT 43 seconds. Which of the following is the most likely cause of this patient’s symptoms? \n (A) Factor V Leiden\n (B) Hemophilia A\n (C) Lupus anticoagulant\n (D) Protein C deficiency\n (E) Von Willebrand disease</QUESTION>\n<ANSWER> "
# input_ids = new_tokenizer.encode(prompt, return_tensors="pt")

# print(input_ids)
# print(input_ids.shape)

# output = new_model.generate(input_ids, max_new_tokens=35, do_sample=True, top_p=0.95, top_k=60)
# print(new_tokenizer.decode(output[0], skip_special_tokens=True))

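# note: prompt4 is train_dataset[4] (printed earlier), so this checks recall of a
# training example rather than generalization to unseen questions
start = time.perf_counter()  # monotonic timer, better suited to durations than time.time()
num_tokens, text = BioLlama.generate(prompt=prompt4, max_new_tokens=15)
elapsed = time.perf_counter() - start

print("***Generating***")
print(text)
print(f"Time taken for generation: {elapsed:.2f} s")
print(f"Tokens per second: {num_tokens / elapsed:.2f}")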
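To go from one generated answer to a MedQA score, the option letter has to be pulled out of the completion and compared with `answer_idx`. A minimal sketch, assuming the model follows the training template `<ANSWER> (X) text</ANSWER>`; since it isn't shown here whether `BioLlama.generate` returns the prompt along with the completion, the parser only looks after the last `<ANSWER>` tag (or at the whole string if there is none). `extract_answer_letter` and `accuracy` are hypothetical helpers.

```python
import re
from typing import Optional, Sequence

_OPTION_LETTER = re.compile(r"\(([A-E])\)")  # assumption: options are labelled A-E


def extract_answer_letter(generated: str) -> Optional[str]:
    """Return the first '(X)' option letter after the last <ANSWER> tag, or None."""
    tail = generated.rsplit("<ANSWER>", 1)[-1]
    match = _OPTION_LETTER.search(tail)
    return match.group(1) if match else None


def accuracy(predictions: Sequence[Optional[str]], gold: Sequence[str]) -> float:
    """Fraction of predictions equal to the gold letter; unparsable outputs count as wrong."""
    if not gold:
        return 0.0
    return sum(p == g for p, g in zip(predictions, gold)) / len(gold)


completion = "<QUESTION>... \n (A) Factor V Leiden</QUESTION>\n<ANSWER> (E) Von Willebrand disease</ANSWER>"
pred = extract_answer_letter(completion)
print(pred, accuracy([pred], ["E"]))
```

Looping this over `test_dataset` (its `text` column holds the answer-free prompts) and comparing against its `answer_idx` column would give a validation accuracy.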