In [3]:
%pip install \
accelerate==0.23.0 \
bitsandbytes==0.41.1 \
datasets==2.13.0 \
openai==0.28.1 \
peft==0.4.0 \
safetensors==0.4.0 \
transformers==4.34.0 \
trl==0.4.7


Note: you may need to restart the kernel to use updated packages.


In [4]:
%pip install py7zr

Note: you may need to restart the kernel to use updated packages.


In [5]:
import pandas as pd
from datasets import Dataset

# Load the CSV file into a pandas DataFrame
df = pd.read_csv("MeQSum.csv")

# Convert the DataFrame to a datasets.Dataset object
dataset = Dataset.from_pandas(df)

# Split the dataset into training and test datasets
train_dataset = dataset.select(range(800))
test_dataset = dataset.select(range(800, len(dataset)))

# Print information about the datasets
print("Train Dataset:")
print(train_dataset)

print("\nTest Dataset:")
print(test_dataset)


Train Dataset:
Dataset({
    features: ['CHQ', 'Summary'],
    num_rows: 800
})

Test Dataset:
Dataset({
    features: ['CHQ', 'Summary'],
    num_rows: 200
})


  from .autonotebook import tqdm as notebook_tqdm


In [6]:
print(train_dataset)

Dataset({
    features: ['CHQ', 'Summary'],
    num_rows: 800
})


In [7]:
train_dataset[0]

{'CHQ': 'SUBJECT: who and where to get cetirizine - D\nMESSAGE: I need/want to know who manufscturs Cetirizine. My Walmart is looking for a new supply and are not getting the recent',
 'Summary': 'Who manufactures cetirizine?'}

In [8]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "meta-llama/Llama-2-7b-chat-hf"
access_token = "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"  # Replace with your actual token

# Set up the configuration for quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Authenticate and load the model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=False,
    device_map="auto",
    use_auth_token=access_token  # Use the token for authentication
)


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.25it/s]


In [9]:
def prompt_formatter(sample):
	return f"""<s>### Instruction:
You are a helpful, respectful and honest assistant. \
Your task is to summarize the following consumer health query. \
Your answer should be based on the provided consumer health query only.

### Consumer Health Query:
{sample['CHQ']}

### Summary:
{sample['Summary']} </s>"""

n = 0
print(prompt_formatter(train_dataset[n]))

<s>### Instruction:
You are a helpful, respectful and honest assistant. Your task is to summarize the following consumer health query. Your answer should be based on the provided consumer health query only.

### Consumer Health Query:
SUBJECT: who and where to get cetirizine - D
MESSAGE: I need/want to know who manufscturs Cetirizine. My Walmart is looking for a new supply and are not getting the recent

### Summary:
Who manufactures cetirizine? </s>


In [10]:
!huggingface-cli login --token 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

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/bio/.cache/huggingface/token
Login successful


In [11]:
%load_ext autoreload
%autoreload 2

In [12]:
from transformers import TrainingArguments, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

#
# construct a Peft model.
# the QLoRA paper recommends LoRA dropout = 0.05 for small models (7B, 13B)
#
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, peft_config)

#
# set up the trainer
#
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

args = TrainingArguments(
    output_dir="llama2-7b-chat-meqsum",
    num_train_epochs=10,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    logging_steps=4,
    save_strategy="epoch",
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    bf16=False,  # Disable bf16 precision
    fp16=True,   # Enable fp16 precision
    tf32=False,  # Disable tf32 precision (optional, depends on your setup)
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    disable_tqdm=False,
    report_to="none",

)


In [13]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    max_seq_length=1024,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=prompt_formatter,
    args=args,
)

In [14]:
import gc

gc.collect()

torch.cuda.empty_cache()

In [15]:
trainer.train()

  0%|          | 0/4000 [00:00<?, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  0%|          | 4/4000 [00:08<2:14:33,  2.02s/it]

{'loss': 2.066, 'learning_rate': 0.0002, 'epoch': 0.01}


  0%|          | 8/4000 [00:16<2:09:40,  1.95s/it]

{'loss': 1.7871, 'learning_rate': 0.0002, 'epoch': 0.02}


  0%|          | 12/4000 [00:23<2:08:27,  1.93s/it]

{'loss': 1.8278, 'learning_rate': 0.0002, 'epoch': 0.03}


  0%|          | 16/4000 [00:31<2:08:08,  1.93s/it]

{'loss': 1.8316, 'learning_rate': 0.0002, 'epoch': 0.04}


  0%|          | 20/4000 [00:39<2:08:11,  1.93s/it]

{'loss': 1.6887, 'learning_rate': 0.0002, 'epoch': 0.05}


  1%|          | 24/4000 [00:47<2:08:03,  1.93s/it]

{'loss': 1.771, 'learning_rate': 0.0002, 'epoch': 0.06}


  1%|          | 28/4000 [00:54<2:08:16,  1.94s/it]

{'loss': 1.7668, 'learning_rate': 0.0002, 'epoch': 0.07}


  1%|          | 32/4000 [01:02<2:07:41,  1.93s/it]

{'loss': 1.5569, 'learning_rate': 0.0002, 'epoch': 0.08}


  1%|          | 36/4000 [01:10<2:07:46,  1.93s/it]

{'loss': 1.632, 'learning_rate': 0.0002, 'epoch': 0.09}


  1%|          | 40/4000 [01:17<2:07:16,  1.93s/it]

{'loss': 1.5671, 'learning_rate': 0.0002, 'epoch': 0.1}


  1%|          | 44/4000 [01:25<2:07:13,  1.93s/it]

{'loss': 1.5857, 'learning_rate': 0.0002, 'epoch': 0.11}


  1%|          | 48/4000 [01:33<2:07:23,  1.93s/it]

{'loss': 1.6277, 'learning_rate': 0.0002, 'epoch': 0.12}


  1%|▏         | 52/4000 [01:41<2:07:03,  1.93s/it]

{'loss': 1.5439, 'learning_rate': 0.0002, 'epoch': 0.13}


  1%|▏         | 56/4000 [01:48<2:06:52,  1.93s/it]

{'loss': 1.4477, 'learning_rate': 0.0002, 'epoch': 0.14}


  2%|▏         | 60/4000 [01:56<2:06:32,  1.93s/it]

{'loss': 1.4691, 'learning_rate': 0.0002, 'epoch': 0.15}


  2%|▏         | 64/4000 [02:04<2:06:25,  1.93s/it]

{'loss': 1.437, 'learning_rate': 0.0002, 'epoch': 0.16}


  2%|▏         | 68/4000 [02:12<2:06:35,  1.93s/it]

{'loss': 1.404, 'learning_rate': 0.0002, 'epoch': 0.17}


  2%|▏         | 72/4000 [02:19<2:07:36,  1.95s/it]

{'loss': 1.3931, 'learning_rate': 0.0002, 'epoch': 1.01}


  2%|▏         | 76/4000 [02:27<2:06:40,  1.94s/it]

{'loss': 1.4074, 'learning_rate': 0.0002, 'epoch': 1.02}


  2%|▏         | 80/4000 [02:35<2:06:30,  1.94s/it]

{'loss': 1.478, 'learning_rate': 0.0002, 'epoch': 1.03}


  2%|▏         | 84/4000 [02:43<2:06:10,  1.93s/it]

{'loss': 1.3179, 'learning_rate': 0.0002, 'epoch': 1.04}


  2%|▏         | 88/4000 [02:50<2:06:33,  1.94s/it]

{'loss': 1.3354, 'learning_rate': 0.0002, 'epoch': 1.05}


  2%|▏         | 92/4000 [02:58<2:06:17,  1.94s/it]

{'loss': 1.3853, 'learning_rate': 0.0002, 'epoch': 1.06}


  2%|▏         | 96/4000 [03:06<2:05:49,  1.93s/it]

{'loss': 1.3562, 'learning_rate': 0.0002, 'epoch': 1.07}


  2%|▎         | 100/4000 [03:14<2:05:39,  1.93s/it]

{'loss': 1.4242, 'learning_rate': 0.0002, 'epoch': 1.08}


  3%|▎         | 104/4000 [03:21<2:05:32,  1.93s/it]

{'loss': 1.2911, 'learning_rate': 0.0002, 'epoch': 1.09}


  3%|▎         | 108/4000 [03:29<2:05:20,  1.93s/it]

{'loss': 1.3348, 'learning_rate': 0.0002, 'epoch': 1.1}


  3%|▎         | 112/4000 [03:37<2:05:01,  1.93s/it]

{'loss': 1.2842, 'learning_rate': 0.0002, 'epoch': 1.11}


  3%|▎         | 116/4000 [03:45<2:05:11,  1.93s/it]

{'loss': 1.3963, 'learning_rate': 0.0002, 'epoch': 1.12}


  3%|▎         | 120/4000 [03:52<2:04:48,  1.93s/it]

{'loss': 1.5367, 'learning_rate': 0.0002, 'epoch': 1.13}


  3%|▎         | 124/4000 [04:00<2:04:37,  1.93s/it]

{'loss': 1.1995, 'learning_rate': 0.0002, 'epoch': 1.14}


  3%|▎         | 128/4000 [04:08<2:04:55,  1.94s/it]

{'loss': 1.295, 'learning_rate': 0.0002, 'epoch': 1.15}


  3%|▎         | 132/4000 [04:15<2:04:49,  1.94s/it]

{'loss': 1.3329, 'learning_rate': 0.0002, 'epoch': 1.16}


  3%|▎         | 136/4000 [04:23<2:04:44,  1.94s/it]

{'loss': 1.481, 'learning_rate': 0.0002, 'epoch': 1.17}


  4%|▎         | 140/4000 [04:31<2:06:09,  1.96s/it]

{'loss': 1.2545, 'learning_rate': 0.0002, 'epoch': 2.0}


  4%|▎         | 144/4000 [04:39<2:04:29,  1.94s/it]

{'loss': 1.2863, 'learning_rate': 0.0002, 'epoch': 2.02}


  4%|▎         | 148/4000 [04:47<2:04:13,  1.94s/it]

{'loss': 1.2434, 'learning_rate': 0.0002, 'epoch': 2.02}


  4%|▍         | 152/4000 [04:54<2:03:54,  1.93s/it]

{'loss': 1.2757, 'learning_rate': 0.0002, 'epoch': 2.04}


  4%|▍         | 156/4000 [05:02<2:03:33,  1.93s/it]

{'loss': 1.2095, 'learning_rate': 0.0002, 'epoch': 2.04}


  4%|▍         | 160/4000 [05:10<2:03:38,  1.93s/it]

{'loss': 1.3356, 'learning_rate': 0.0002, 'epoch': 2.06}


  4%|▍         | 164/4000 [05:17<2:03:29,  1.93s/it]

{'loss': 1.3166, 'learning_rate': 0.0002, 'epoch': 2.06}


  4%|▍         | 168/4000 [05:25<2:03:48,  1.94s/it]

{'loss': 1.3626, 'learning_rate': 0.0002, 'epoch': 2.08}


  4%|▍         | 172/4000 [05:33<2:03:20,  1.93s/it]

{'loss': 1.2758, 'learning_rate': 0.0002, 'epoch': 2.08}


  4%|▍         | 176/4000 [05:41<2:03:04,  1.93s/it]

{'loss': 1.1727, 'learning_rate': 0.0002, 'epoch': 2.1}


  4%|▍         | 180/4000 [05:48<2:03:07,  1.93s/it]

{'loss': 1.3802, 'learning_rate': 0.0002, 'epoch': 2.1}


  5%|▍         | 184/4000 [05:56<2:03:19,  1.94s/it]

{'loss': 1.2125, 'learning_rate': 0.0002, 'epoch': 2.12}


  5%|▍         | 188/4000 [06:04<2:02:49,  1.93s/it]

{'loss': 1.4167, 'learning_rate': 0.0002, 'epoch': 2.12}


  5%|▍         | 192/4000 [06:12<2:02:54,  1.94s/it]

{'loss': 1.2277, 'learning_rate': 0.0002, 'epoch': 2.13}


  5%|▍         | 196/4000 [06:19<2:02:28,  1.93s/it]

{'loss': 1.3319, 'learning_rate': 0.0002, 'epoch': 2.15}


  5%|▌         | 200/4000 [06:27<2:01:44,  1.92s/it]

{'loss': 1.3447, 'learning_rate': 0.0002, 'epoch': 2.15}


  5%|▌         | 204/4000 [06:35<2:02:25,  1.94s/it]

{'loss': 1.195, 'learning_rate': 0.0002, 'epoch': 2.17}


  5%|▌         | 208/4000 [06:43<2:05:16,  1.98s/it]

{'loss': 1.2635, 'learning_rate': 0.0002, 'epoch': 3.0}


  5%|▌         | 212/4000 [06:50<2:02:41,  1.94s/it]

{'loss': 1.1371, 'learning_rate': 0.0002, 'epoch': 3.01}


  5%|▌         | 216/4000 [06:58<2:02:36,  1.94s/it]

{'loss': 1.2451, 'learning_rate': 0.0002, 'epoch': 3.02}


  6%|▌         | 220/4000 [07:06<2:01:57,  1.94s/it]

{'loss': 1.1364, 'learning_rate': 0.0002, 'epoch': 3.03}


  6%|▌         | 224/4000 [07:14<2:01:38,  1.93s/it]

{'loss': 1.1423, 'learning_rate': 0.0002, 'epoch': 3.04}


  6%|▌         | 228/4000 [07:21<2:02:07,  1.94s/it]

{'loss': 1.1962, 'learning_rate': 0.0002, 'epoch': 3.05}


  6%|▌         | 232/4000 [07:29<2:01:43,  1.94s/it]

{'loss': 1.1956, 'learning_rate': 0.0002, 'epoch': 3.06}


  6%|▌         | 236/4000 [07:37<2:01:21,  1.93s/it]

{'loss': 1.2208, 'learning_rate': 0.0002, 'epoch': 3.07}


  6%|▌         | 240/4000 [07:45<2:01:16,  1.94s/it]

{'loss': 1.1875, 'learning_rate': 0.0002, 'epoch': 3.08}


  6%|▌         | 244/4000 [07:52<2:01:26,  1.94s/it]

{'loss': 1.2052, 'learning_rate': 0.0002, 'epoch': 3.09}


  6%|▌         | 248/4000 [08:00<2:01:18,  1.94s/it]

{'loss': 1.1782, 'learning_rate': 0.0002, 'epoch': 3.1}


  6%|▋         | 252/4000 [08:08<2:01:10,  1.94s/it]

{'loss': 1.1695, 'learning_rate': 0.0002, 'epoch': 3.11}


  6%|▋         | 256/4000 [08:16<2:01:57,  1.95s/it]

{'loss': 1.1837, 'learning_rate': 0.0002, 'epoch': 3.12}


  6%|▋         | 260/4000 [08:24<2:01:00,  1.94s/it]

{'loss': 1.0713, 'learning_rate': 0.0002, 'epoch': 3.13}


  7%|▋         | 264/4000 [08:31<2:00:28,  1.93s/it]

{'loss': 1.2614, 'learning_rate': 0.0002, 'epoch': 3.14}


  7%|▋         | 268/4000 [08:39<2:02:49,  1.97s/it]

{'loss': 1.2924, 'learning_rate': 0.0002, 'epoch': 3.15}


  7%|▋         | 272/4000 [08:47<2:02:56,  1.98s/it]

{'loss': 1.258, 'learning_rate': 0.0002, 'epoch': 3.16}


  7%|▋         | 276/4000 [08:55<2:02:21,  1.97s/it]

{'loss': 1.2997, 'learning_rate': 0.0002, 'epoch': 3.17}


  7%|▋         | 280/4000 [09:03<2:02:29,  1.98s/it]

{'loss': 1.2464, 'learning_rate': 0.0002, 'epoch': 4.01}


  7%|▋         | 284/4000 [09:11<2:00:37,  1.95s/it]

{'loss': 1.0327, 'learning_rate': 0.0002, 'epoch': 4.02}


  7%|▋         | 288/4000 [09:19<1:59:59,  1.94s/it]

{'loss': 1.0894, 'learning_rate': 0.0002, 'epoch': 4.03}


  7%|▋         | 292/4000 [09:26<1:59:40,  1.94s/it]

{'loss': 1.0919, 'learning_rate': 0.0002, 'epoch': 4.04}


  7%|▋         | 296/4000 [09:34<1:59:29,  1.94s/it]

{'loss': 1.1295, 'learning_rate': 0.0002, 'epoch': 4.05}


  8%|▊         | 300/4000 [09:42<1:59:29,  1.94s/it]

{'loss': 1.12, 'learning_rate': 0.0002, 'epoch': 4.06}


  8%|▊         | 304/4000 [09:50<1:59:05,  1.93s/it]

{'loss': 1.0878, 'learning_rate': 0.0002, 'epoch': 4.07}


  8%|▊         | 308/4000 [09:57<1:59:20,  1.94s/it]

{'loss': 1.0348, 'learning_rate': 0.0002, 'epoch': 4.08}


  8%|▊         | 312/4000 [10:05<1:58:30,  1.93s/it]

{'loss': 1.1492, 'learning_rate': 0.0002, 'epoch': 4.09}


  8%|▊         | 316/4000 [10:13<1:58:30,  1.93s/it]

{'loss': 1.0457, 'learning_rate': 0.0002, 'epoch': 4.1}


  8%|▊         | 320/4000 [10:20<1:58:57,  1.94s/it]

{'loss': 1.1969, 'learning_rate': 0.0002, 'epoch': 4.11}


  8%|▊         | 324/4000 [10:28<1:58:54,  1.94s/it]

{'loss': 1.0878, 'learning_rate': 0.0002, 'epoch': 4.12}


  8%|▊         | 328/4000 [10:36<1:58:42,  1.94s/it]

{'loss': 1.0716, 'learning_rate': 0.0002, 'epoch': 4.13}


  8%|▊         | 332/4000 [10:44<1:58:13,  1.93s/it]

{'loss': 1.0607, 'learning_rate': 0.0002, 'epoch': 4.14}


  8%|▊         | 336/4000 [10:51<1:58:09,  1.94s/it]

{'loss': 1.1082, 'learning_rate': 0.0002, 'epoch': 4.15}


  8%|▊         | 340/4000 [10:59<1:58:14,  1.94s/it]

{'loss': 1.0939, 'learning_rate': 0.0002, 'epoch': 4.16}


  9%|▊         | 344/4000 [11:07<1:57:52,  1.93s/it]

{'loss': 1.1325, 'learning_rate': 0.0002, 'epoch': 4.17}


  9%|▊         | 348/4000 [11:15<1:59:18,  1.96s/it]

{'loss': 1.003, 'learning_rate': 0.0002, 'epoch': 5.01}


  9%|▉         | 352/4000 [11:23<1:58:05,  1.94s/it]

{'loss': 1.1069, 'learning_rate': 0.0002, 'epoch': 5.02}


  9%|▉         | 356/4000 [11:30<1:57:41,  1.94s/it]

{'loss': 0.9751, 'learning_rate': 0.0002, 'epoch': 5.03}


  9%|▉         | 360/4000 [11:38<1:57:33,  1.94s/it]

{'loss': 0.9583, 'learning_rate': 0.0002, 'epoch': 5.04}


  9%|▉         | 364/4000 [11:46<1:57:29,  1.94s/it]

{'loss': 1.0833, 'learning_rate': 0.0002, 'epoch': 5.05}


  9%|▉         | 368/4000 [11:54<1:57:02,  1.93s/it]

{'loss': 0.9555, 'learning_rate': 0.0002, 'epoch': 5.06}


  9%|▉         | 372/4000 [12:01<1:57:02,  1.94s/it]

{'loss': 0.9841, 'learning_rate': 0.0002, 'epoch': 5.07}


  9%|▉         | 376/4000 [12:09<1:56:55,  1.94s/it]

{'loss': 1.0234, 'learning_rate': 0.0002, 'epoch': 5.08}


 10%|▉         | 380/4000 [12:17<1:56:44,  1.94s/it]

{'loss': 0.9634, 'learning_rate': 0.0002, 'epoch': 5.09}


 10%|▉         | 384/4000 [12:25<1:56:35,  1.93s/it]

{'loss': 1.013, 'learning_rate': 0.0002, 'epoch': 5.1}


 10%|▉         | 388/4000 [12:32<1:56:19,  1.93s/it]

{'loss': 0.9498, 'learning_rate': 0.0002, 'epoch': 5.11}


 10%|▉         | 392/4000 [12:40<1:56:14,  1.93s/it]

{'loss': 1.0, 'learning_rate': 0.0002, 'epoch': 5.12}


 10%|▉         | 396/4000 [12:48<1:56:27,  1.94s/it]

{'loss': 0.9775, 'learning_rate': 0.0002, 'epoch': 5.13}


 10%|█         | 400/4000 [12:56<1:56:33,  1.94s/it]

{'loss': 0.9961, 'learning_rate': 0.0002, 'epoch': 5.14}


 10%|█         | 404/4000 [13:03<1:56:03,  1.94s/it]

{'loss': 0.9708, 'learning_rate': 0.0002, 'epoch': 5.15}


 10%|█         | 408/4000 [13:11<1:55:52,  1.94s/it]

{'loss': 0.9954, 'learning_rate': 0.0002, 'epoch': 5.16}


 10%|█         | 412/4000 [13:19<1:55:38,  1.93s/it]

{'loss': 1.0176, 'learning_rate': 0.0002, 'epoch': 5.17}


 10%|█         | 416/4000 [13:27<1:57:35,  1.97s/it]

{'loss': 0.8582, 'learning_rate': 0.0002, 'epoch': 6.0}


 10%|█         | 420/4000 [13:34<1:56:02,  1.94s/it]

{'loss': 0.8817, 'learning_rate': 0.0002, 'epoch': 6.01}


 11%|█         | 424/4000 [13:42<1:55:44,  1.94s/it]

{'loss': 0.8246, 'learning_rate': 0.0002, 'epoch': 6.03}


 11%|█         | 428/4000 [13:50<1:56:47,  1.96s/it]

{'loss': 0.8312, 'learning_rate': 0.0002, 'epoch': 6.04}


 11%|█         | 432/4000 [13:58<1:56:08,  1.95s/it]

{'loss': 0.788, 'learning_rate': 0.0002, 'epoch': 6.04}


 11%|█         | 436/4000 [14:06<1:55:24,  1.94s/it]

{'loss': 0.8695, 'learning_rate': 0.0002, 'epoch': 6.05}


 11%|█         | 440/4000 [14:13<1:55:59,  1.96s/it]

{'loss': 0.8507, 'learning_rate': 0.0002, 'epoch': 6.07}


 11%|█         | 444/4000 [14:21<1:56:35,  1.97s/it]

{'loss': 0.9669, 'learning_rate': 0.0002, 'epoch': 6.08}


 11%|█         | 448/4000 [14:29<1:55:06,  1.94s/it]

{'loss': 0.8534, 'learning_rate': 0.0002, 'epoch': 6.08}


 11%|█▏        | 452/4000 [14:37<1:54:34,  1.94s/it]

{'loss': 0.8673, 'learning_rate': 0.0002, 'epoch': 6.09}


 11%|█▏        | 456/4000 [14:45<1:54:35,  1.94s/it]

{'loss': 0.8737, 'learning_rate': 0.0002, 'epoch': 6.11}


 12%|█▏        | 460/4000 [14:52<1:54:19,  1.94s/it]

{'loss': 0.9189, 'learning_rate': 0.0002, 'epoch': 6.12}


 12%|█▏        | 464/4000 [15:00<1:53:52,  1.93s/it]

{'loss': 0.8992, 'learning_rate': 0.0002, 'epoch': 6.12}


 12%|█▏        | 468/4000 [15:08<1:53:42,  1.93s/it]

{'loss': 0.9342, 'learning_rate': 0.0002, 'epoch': 6.13}


 12%|█▏        | 472/4000 [15:16<1:53:45,  1.93s/it]

{'loss': 0.9168, 'learning_rate': 0.0002, 'epoch': 6.14}


 12%|█▏        | 476/4000 [15:23<1:53:44,  1.94s/it]

{'loss': 0.8167, 'learning_rate': 0.0002, 'epoch': 6.16}


 12%|█▏        | 480/4000 [15:31<1:53:44,  1.94s/it]

{'loss': 1.0379, 'learning_rate': 0.0002, 'epoch': 6.17}


 12%|█▏        | 484/4000 [15:39<1:56:06,  1.98s/it]

{'loss': 0.9059, 'learning_rate': 0.0002, 'epoch': 7.0}


 12%|█▏        | 488/4000 [15:47<1:54:06,  1.95s/it]

{'loss': 0.8056, 'learning_rate': 0.0002, 'epoch': 7.01}


 12%|█▏        | 492/4000 [15:55<1:53:38,  1.94s/it]

{'loss': 0.7358, 'learning_rate': 0.0002, 'epoch': 7.02}


 12%|█▏        | 496/4000 [16:02<1:53:42,  1.95s/it]

{'loss': 0.7545, 'learning_rate': 0.0002, 'epoch': 7.03}


 12%|█▎        | 500/4000 [16:10<1:53:02,  1.94s/it]

{'loss': 0.7752, 'learning_rate': 0.0002, 'epoch': 7.04}


 13%|█▎        | 504/4000 [16:18<1:52:48,  1.94s/it]

{'loss': 0.7598, 'learning_rate': 0.0002, 'epoch': 7.05}


 13%|█▎        | 508/4000 [16:26<1:52:51,  1.94s/it]

{'loss': 0.809, 'learning_rate': 0.0002, 'epoch': 7.06}


 13%|█▎        | 512/4000 [16:33<1:54:22,  1.97s/it]

{'loss': 0.7867, 'learning_rate': 0.0002, 'epoch': 7.07}


 13%|█▎        | 516/4000 [16:41<1:52:00,  1.93s/it]

{'loss': 0.7379, 'learning_rate': 0.0002, 'epoch': 7.08}


 13%|█▎        | 520/4000 [16:49<1:54:23,  1.97s/it]

{'loss': 0.7078, 'learning_rate': 0.0002, 'epoch': 7.09}


 13%|█▎        | 524/4000 [16:57<1:52:48,  1.95s/it]

{'loss': 0.8203, 'learning_rate': 0.0002, 'epoch': 7.1}


 13%|█▎        | 528/4000 [17:04<1:52:03,  1.94s/it]

{'loss': 0.7121, 'learning_rate': 0.0002, 'epoch': 7.11}


 13%|█▎        | 532/4000 [17:12<1:52:01,  1.94s/it]

{'loss': 0.8042, 'learning_rate': 0.0002, 'epoch': 7.12}


 13%|█▎        | 536/4000 [17:20<1:51:56,  1.94s/it]

{'loss': 0.7618, 'learning_rate': 0.0002, 'epoch': 7.13}


 14%|█▎        | 540/4000 [17:28<1:51:30,  1.93s/it]

{'loss': 0.7665, 'learning_rate': 0.0002, 'epoch': 7.14}


 14%|█▎        | 544/4000 [17:35<1:51:23,  1.93s/it]

{'loss': 0.7867, 'learning_rate': 0.0002, 'epoch': 7.15}


 14%|█▎        | 548/4000 [17:43<1:51:32,  1.94s/it]

{'loss': 0.7553, 'learning_rate': 0.0002, 'epoch': 7.16}


 14%|█▍        | 552/4000 [17:51<1:52:02,  1.95s/it]

{'loss': 0.8819, 'learning_rate': 0.0002, 'epoch': 7.17}


 14%|█▍        | 556/4000 [17:59<1:51:43,  1.95s/it]

{'loss': 0.6363, 'learning_rate': 0.0002, 'epoch': 8.01}


 14%|█▍        | 560/4000 [18:07<1:50:56,  1.94s/it]

{'loss': 0.6919, 'learning_rate': 0.0002, 'epoch': 8.02}


 14%|█▍        | 564/4000 [18:14<1:50:45,  1.93s/it]

{'loss': 0.6105, 'learning_rate': 0.0002, 'epoch': 8.03}


 14%|█▍        | 568/4000 [18:22<1:50:35,  1.93s/it]

{'loss': 0.6422, 'learning_rate': 0.0002, 'epoch': 8.04}


 14%|█▍        | 572/4000 [18:30<1:50:21,  1.93s/it]

{'loss': 0.5242, 'learning_rate': 0.0002, 'epoch': 8.05}


 14%|█▍        | 576/4000 [18:38<1:50:24,  1.93s/it]

{'loss': 0.6248, 'learning_rate': 0.0002, 'epoch': 8.06}


 14%|█▍        | 580/4000 [18:45<1:50:02,  1.93s/it]

{'loss': 0.6631, 'learning_rate': 0.0002, 'epoch': 8.07}


 15%|█▍        | 584/4000 [18:53<1:50:07,  1.93s/it]

{'loss': 0.6197, 'learning_rate': 0.0002, 'epoch': 8.08}


 15%|█▍        | 588/4000 [19:01<1:49:57,  1.93s/it]

{'loss': 0.7303, 'learning_rate': 0.0002, 'epoch': 8.09}


 15%|█▍        | 592/4000 [19:08<1:50:06,  1.94s/it]

{'loss': 0.7027, 'learning_rate': 0.0002, 'epoch': 8.1}


 15%|█▍        | 596/4000 [19:16<1:49:49,  1.94s/it]

{'loss': 0.72, 'learning_rate': 0.0002, 'epoch': 8.11}


 15%|█▌        | 600/4000 [19:24<1:49:30,  1.93s/it]

{'loss': 0.6643, 'learning_rate': 0.0002, 'epoch': 8.12}


 15%|█▌        | 604/4000 [19:32<1:49:26,  1.93s/it]

{'loss': 0.722, 'learning_rate': 0.0002, 'epoch': 8.13}


 15%|█▌        | 608/4000 [19:39<1:49:47,  1.94s/it]

{'loss': 0.6222, 'learning_rate': 0.0002, 'epoch': 8.14}


 15%|█▌        | 612/4000 [19:47<1:49:15,  1.93s/it]

{'loss': 0.76, 'learning_rate': 0.0002, 'epoch': 8.15}


 15%|█▌        | 616/4000 [19:55<1:49:22,  1.94s/it]

{'loss': 0.6592, 'learning_rate': 0.0002, 'epoch': 8.16}


 16%|█▌        | 620/4000 [20:03<1:48:49,  1.93s/it]

{'loss': 0.6933, 'learning_rate': 0.0002, 'epoch': 8.17}


 16%|█▌        | 624/4000 [20:11<1:49:39,  1.95s/it]

{'loss': 0.5479, 'learning_rate': 0.0002, 'epoch': 9.01}


 16%|█▌        | 628/4000 [20:18<1:48:52,  1.94s/it]

{'loss': 0.4928, 'learning_rate': 0.0002, 'epoch': 9.02}


 16%|█▌        | 632/4000 [20:26<1:48:28,  1.93s/it]

{'loss': 0.5857, 'learning_rate': 0.0002, 'epoch': 9.03}


 16%|█▌        | 636/4000 [20:34<1:48:03,  1.93s/it]

{'loss': 0.5307, 'learning_rate': 0.0002, 'epoch': 9.04}


 16%|█▌        | 640/4000 [20:41<1:48:25,  1.94s/it]

{'loss': 0.4939, 'learning_rate': 0.0002, 'epoch': 9.05}


 16%|█▌        | 644/4000 [20:49<1:48:06,  1.93s/it]

{'loss': 0.581, 'learning_rate': 0.0002, 'epoch': 9.06}


 16%|█▌        | 648/4000 [20:57<1:48:00,  1.93s/it]

{'loss': 0.5117, 'learning_rate': 0.0002, 'epoch': 9.07}


 16%|█▋        | 652/4000 [21:05<1:47:56,  1.93s/it]

{'loss': 0.5995, 'learning_rate': 0.0002, 'epoch': 9.08}


 16%|█▋        | 656/4000 [21:12<1:47:38,  1.93s/it]

{'loss': 0.5124, 'learning_rate': 0.0002, 'epoch': 9.09}


 16%|█▋        | 660/4000 [21:20<1:47:22,  1.93s/it]

{'loss': 0.5418, 'learning_rate': 0.0002, 'epoch': 9.1}


 17%|█▋        | 664/4000 [21:28<1:47:15,  1.93s/it]

{'loss': 0.5624, 'learning_rate': 0.0002, 'epoch': 9.11}


 17%|█▋        | 668/4000 [21:36<1:47:07,  1.93s/it]

{'loss': 0.5983, 'learning_rate': 0.0002, 'epoch': 9.12}


 17%|█▋        | 672/4000 [21:43<1:47:13,  1.93s/it]

{'loss': 0.5429, 'learning_rate': 0.0002, 'epoch': 9.13}


 17%|█▋        | 676/4000 [21:51<1:47:15,  1.94s/it]

{'loss': 0.6862, 'learning_rate': 0.0002, 'epoch': 9.14}


 17%|█▋        | 680/4000 [21:59<1:46:46,  1.93s/it]

{'loss': 0.6781, 'learning_rate': 0.0002, 'epoch': 9.15}


 17%|█▋        | 684/4000 [22:06<1:46:55,  1.93s/it]

{'loss': 0.5494, 'learning_rate': 0.0002, 'epoch': 9.16}


 17%|█▋        | 688/4000 [22:14<1:46:52,  1.94s/it]

{'loss': 0.533, 'learning_rate': 0.0002, 'epoch': 9.17}


 17%|█▋        | 690/4000 [22:18<1:47:02,  1.94s/it]

{'train_runtime': 1338.7335, 'train_samples_per_second': 5.976, 'train_steps_per_second': 2.988, 'train_loss': 1.0474436658016149, 'epoch': 9.17}





TrainOutput(global_step=690, training_loss=1.0474436658016149, metrics={'train_runtime': 1338.7335, 'train_samples_per_second': 5.976, 'train_steps_per_second': 2.988, 'train_loss': 1.0474436658016149, 'epoch': 9.17})

# 3. Run inference using the fine-tuned model

In [16]:
trainer.save_model()

In [26]:
import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoTokenizer

model_folder = "llama2-7b-chat-meqsum"

# load both the adapter and the base model
model = AutoPeftModelForCausalLM.from_pretrained(
    model_folder,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
    device_map='auto'
)
# tokenizer = AutoTokenizer.from_pretrained(model_folder)


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.21it/s]


In [27]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): Linear4bit(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (v_proj): Linear4

In [61]:
sample = test_dataset[50]

prompt = f"""### Instruction:
You are a helpful, respectful and honest assistant. \
Your task is to summarize the following consumer  health query. \
Your answer should be based on the provided text only.

### Consumer Health Query:
{sample['CHQ']}

### Summary:
"""

print(prompt)

### Instruction:
You are a helpful, respectful and honest assistant. Your task is to summarize the following consumer  health query. Your answer should be based on the provided text only.

### Consumer Health Query:
KNEE OSTEOARTHRITIS.
 Good morning about 20 years ago I suffered ruptured anterior cruciate ligament and removal of domestic law meniscus, was operated and made me clancy, at present unfortunately my knee is totally affected and I have arthritis and severe pain, according to a dr traumatologo commented me I need a knee prosthesis my question is can you treat me, or turn me can recommend doctors or hospitals to treat in the U.S.   [NAME]

### Summary:



In [64]:
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
outputs = model.generate(input_ids=input_ids, max_new_tokens=50, temperature=0.7)

print('Output:\n',
      tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):])
print('\nGround truth:\n', sample['Summary'])


Output:
 Whatbara patronlibs setContentView shed Monte pó patronлон Monte Montenten Unabararog shedierz patron Philip patronyanrog patronbara patronga Van Monte Montebaradatenigaitzen Рес Kriegs patron setContentView setContentView partiesbara KriegsNaN shedSBNSBNierz Kriegsphiarog

Ground truth:
 How can I find physician(s) or hospital(s) who specialize in knee osteoarthritis?
