# Libraries

In [1]:
# %%capture

# %pip install -U peft
# %pip install -U trl
# %pip install -U bitsandbytes 

In [2]:
import os, torch, wandb

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)

from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format
from dataclasses import dataclass

  from .autonotebook import tqdm as notebook_tqdm


## Setup Huggingface 🤗 & Wandb

In [3]:
from huggingface_hub import login

login(token = "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")

wandb.login(key="ce84c3af2fdee6c3e2696b2a4ad96af49a3dd86e")


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


Token is valid (permission: fineGrained).
Your token has been saved to C:\Users\USER_ELISEY\.cache\huggingface\token
Login successful


[34m[1mwandb[0m: Currently logged in as: [33mez1071[0m ([33mez1071-mipt[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\USER_ELISEY\_netrc


True

In [4]:
run = wandb.init(
    project='Fine-tune Llama 3.1 8B on Russian Dataset', 
    job_type="training"
)

In [5]:
@dataclass
class Config:
#     model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#     model_name = "AnatoliiPotapov/T-lite-instruct-0.1"
    model_name = "google/gemma-2-9b-it"
    dataset_name = "ruslanmv/ai-medical-chatbot"
    new_model = "llama-3.1-8b-chat-doctor"
    torch_dtype = torch.float16
    attn_implementation = "eager"
cfg = Config()

# Loading model and tokenizer

In [6]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=cfg.torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "C:\\Users\\USER_ELISEY\\gemma",
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=cfg.attn_implementation
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.94s/it]


In [7]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("C:\\Users\\USER_ELISEY\\gemma")
model, tokenizer = setup_chat_format(model, tokenizer)
tokenizer.padding_side = 'right'
tokenizer.padding_token = '<|pad|>'

## LoRA adapter

In [8]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

# Data

## Load

In [9]:
dataset = load_dataset(cfg.dataset_name, split="all")

## Format to chat 

In [10]:
def format_chat_template(row):
    row_json = [{"role": "user", "content": row["Patient"]},
               {"role": "assistant", "content": row["Doctor"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

In [11]:
dataset = dataset.map(
    format_chat_template,
    num_proc=1,
)

## Select only part

In [12]:
dataset_sh = dataset.shuffle(seed=2024).select(range(10_000))

In [13]:
dataset_sh = dataset_sh.train_test_split(0.1)
dataset_sh

DatasetDict({
    train: Dataset({
        features: ['Description', 'Patient', 'Doctor', 'text'],
        num_rows: 9000
    })
    test: Dataset({
        features: ['Description', 'Patient', 'Doctor', 'text'],
        num_rows: 1000
    })
})

# Train model

## Training arguments

In [14]:
training_arguments = TrainingArguments(
    output_dir=cfg.new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
#     num_train_epochs=1,
    max_steps=500,
    eval_strategy="steps",
    eval_steps=500,
    logging_steps=10,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb",
    run_name="Llama-3.1-medicine",
)

## Train model

In [15]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_sh["train"],
    eval_dataset=dataset_sh["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Map: 100%|██████████| 9000/9000 [00:01<00:00, 7568.75 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 7149.10 examples/s]
max_steps is given, it will override any value given in num_train_epochs


In [None]:
trainer.train()

  2%|▏         | 10/500 [07:42<5:50:03, 42.86s/it]

{'loss': 3.2385, 'grad_norm': 2.1750564575195312, 'learning_rate': 0.0002, 'epoch': 0.0}


  4%|▍         | 20/500 [13:30<4:37:07, 34.64s/it]

{'loss': 2.8138, 'grad_norm': 1.774752140045166, 'learning_rate': 0.0001959183673469388, 'epoch': 0.0}


  6%|▌         | 30/500 [18:52<4:17:23, 32.86s/it]

{'loss': 2.5374, 'grad_norm': 2.214508533477783, 'learning_rate': 0.00019183673469387756, 'epoch': 0.01}


  8%|▊         | 40/500 [24:27<4:12:28, 32.93s/it]

{'loss': 2.6435, 'grad_norm': 1.7735646963119507, 'learning_rate': 0.00018775510204081634, 'epoch': 0.01}


 10%|█         | 50/500 [29:01<2:46:49, 22.24s/it]

{'loss': 2.546, 'grad_norm': 2.6800169944763184, 'learning_rate': 0.00018367346938775512, 'epoch': 0.01}


 12%|█▏        | 60/500 [37:10<5:27:01, 44.59s/it]

{'loss': 2.6979, 'grad_norm': 1.3342658281326294, 'learning_rate': 0.0001795918367346939, 'epoch': 0.01}


 14%|█▍        | 70/500 [43:42<4:20:41, 36.37s/it]

{'loss': 2.4575, 'grad_norm': 1.6305186748504639, 'learning_rate': 0.00017551020408163265, 'epoch': 0.02}


 16%|█▌        | 80/500 [49:17<3:59:14, 34.18s/it]

{'loss': 2.5765, 'grad_norm': 1.6835591793060303, 'learning_rate': 0.00017142857142857143, 'epoch': 0.02}


 18%|█▊        | 90/500 [55:02<3:49:39, 33.61s/it]

{'loss': 2.3294, 'grad_norm': 1.9981335401535034, 'learning_rate': 0.00016734693877551023, 'epoch': 0.02}


 20%|██        | 100/500 [59:02<2:10:20, 19.55s/it]

{'loss': 2.5165, 'grad_norm': 2.8242881298065186, 'learning_rate': 0.00016326530612244898, 'epoch': 0.02}


 22%|██▏       | 110/500 [1:07:15<4:45:05, 43.86s/it]

{'loss': 2.5328, 'grad_norm': 1.4915645122528076, 'learning_rate': 0.00015918367346938776, 'epoch': 0.02}


 24%|██▍       | 120/500 [1:13:51<3:54:48, 37.07s/it]

{'loss': 2.4477, 'grad_norm': 1.5756242275238037, 'learning_rate': 0.00015510204081632654, 'epoch': 0.03}


 26%|██▌       | 130/500 [1:19:22<3:24:37, 33.18s/it]

{'loss': 2.495, 'grad_norm': 1.7442927360534668, 'learning_rate': 0.0001510204081632653, 'epoch': 0.03}


 28%|██▊       | 140/500 [1:25:00<3:19:52, 33.31s/it]

{'loss': 2.2987, 'grad_norm': 1.8456257581710815, 'learning_rate': 0.0001469387755102041, 'epoch': 0.03}


 30%|███       | 150/500 [1:29:45<2:14:32, 23.07s/it]

{'loss': 2.525, 'grad_norm': 3.773294687271118, 'learning_rate': 0.00014285714285714287, 'epoch': 0.03}


 32%|███▏      | 160/500 [1:38:53<4:14:54, 44.98s/it]

{'loss': 2.3004, 'grad_norm': 1.2723884582519531, 'learning_rate': 0.00013877551020408165, 'epoch': 0.04}


 34%|███▍      | 170/500 [1:44:46<3:12:16, 34.96s/it]

{'loss': 2.4125, 'grad_norm': 1.615640640258789, 'learning_rate': 0.0001346938775510204, 'epoch': 0.04}


 36%|███▌      | 180/500 [1:50:26<3:04:59, 34.68s/it]

{'loss': 2.3533, 'grad_norm': 1.7801198959350586, 'learning_rate': 0.00013061224489795917, 'epoch': 0.04}


 38%|███▊      | 190/500 [1:56:08<2:58:02, 34.46s/it]

{'loss': 2.5303, 'grad_norm': 1.8462018966674805, 'learning_rate': 0.00012653061224489798, 'epoch': 0.04}


 40%|████      | 200/500 [2:00:53<1:53:58, 22.79s/it]

{'loss': 2.5818, 'grad_norm': 2.5140554904937744, 'learning_rate': 0.00012244897959183676, 'epoch': 0.04}


 42%|████▏     | 210/500 [2:14:04<5:20:14, 66.26s/it]

{'loss': 2.6029, 'grad_norm': 1.4094632863998413, 'learning_rate': 0.00011836734693877552, 'epoch': 0.05}


 44%|████▍     | 220/500 [2:24:15<4:31:23, 58.15s/it]

{'loss': 2.3849, 'grad_norm': 1.4031249284744263, 'learning_rate': 0.00011428571428571428, 'epoch': 0.05}


 46%|████▌     | 230/500 [2:33:05<3:42:27, 49.44s/it]

{'loss': 2.6077, 'grad_norm': 1.8373081684112549, 'learning_rate': 0.00011020408163265306, 'epoch': 0.05}


 48%|████▊     | 240/500 [2:42:04<3:50:08, 53.11s/it]

{'loss': 2.3954, 'grad_norm': 1.8295235633850098, 'learning_rate': 0.00010612244897959185, 'epoch': 0.05}


 50%|█████     | 250/500 [2:48:44<2:12:25, 31.78s/it]

{'loss': 2.198, 'grad_norm': 2.901500701904297, 'learning_rate': 0.00010204081632653062, 'epoch': 0.06}


 52%|█████▏    | 260/500 [3:02:05<4:34:24, 68.60s/it]

{'loss': 2.4492, 'grad_norm': 1.3046141862869263, 'learning_rate': 9.79591836734694e-05, 'epoch': 0.06}


 54%|█████▍    | 270/500 [3:11:46<3:33:38, 55.73s/it]

{'loss': 2.4154, 'grad_norm': 1.5080292224884033, 'learning_rate': 9.387755102040817e-05, 'epoch': 0.06}


 56%|█████▌    | 280/500 [3:20:21<3:14:41, 53.10s/it]

{'loss': 2.4427, 'grad_norm': 1.633769154548645, 'learning_rate': 8.979591836734695e-05, 'epoch': 0.06}


 58%|█████▊    | 290/500 [3:29:00<3:00:47, 51.66s/it]

{'loss': 2.2842, 'grad_norm': 1.6494824886322021, 'learning_rate': 8.571428571428571e-05, 'epoch': 0.06}


 60%|██████    | 300/500 [3:35:50<1:51:33, 33.47s/it]

{'loss': 2.4737, 'grad_norm': 2.5093307495117188, 'learning_rate': 8.163265306122449e-05, 'epoch': 0.07}


 62%|██████▏   | 310/500 [3:48:18<3:25:09, 64.79s/it]

{'loss': 2.3785, 'grad_norm': 1.574847936630249, 'learning_rate': 7.755102040816327e-05, 'epoch': 0.07}


 64%|██████▍   | 320/500 [3:57:06<2:30:21, 50.12s/it]

{'loss': 2.4694, 'grad_norm': 1.9669430255889893, 'learning_rate': 7.346938775510205e-05, 'epoch': 0.07}


 66%|██████▌   | 330/500 [4:05:51<2:27:50, 52.18s/it]

{'loss': 2.2753, 'grad_norm': 1.7294447422027588, 'learning_rate': 6.938775510204082e-05, 'epoch': 0.07}


 68%|██████▊   | 340/500 [4:14:02<2:06:59, 47.62s/it]

{'loss': 2.3736, 'grad_norm': 1.6452603340148926, 'learning_rate': 6.530612244897959e-05, 'epoch': 0.08}


 70%|███████   | 350/500 [4:19:28<1:06:43, 26.69s/it]

{'loss': 2.3382, 'grad_norm': 2.6385293006896973, 'learning_rate': 6.122448979591838e-05, 'epoch': 0.08}


 72%|███████▏  | 360/500 [4:32:29<2:38:06, 67.76s/it]

{'loss': 2.6472, 'grad_norm': 1.4074010848999023, 'learning_rate': 5.714285714285714e-05, 'epoch': 0.08}


 74%|███████▍  | 370/500 [4:42:42<2:04:42, 57.56s/it]

{'loss': 2.3262, 'grad_norm': 1.6429510116577148, 'learning_rate': 5.3061224489795926e-05, 'epoch': 0.08}


 76%|███████▌  | 380/500 [4:51:34<1:41:04, 50.54s/it]

{'loss': 2.3398, 'grad_norm': 1.4690697193145752, 'learning_rate': 4.89795918367347e-05, 'epoch': 0.08}


 78%|███████▊  | 390/500 [5:00:21<1:34:34, 51.59s/it]

{'loss': 2.2366, 'grad_norm': 1.5293055772781372, 'learning_rate': 4.4897959183673474e-05, 'epoch': 0.09}


 80%|████████  | 400/500 [5:06:47<51:00, 30.60s/it]  

{'loss': 2.5311, 'grad_norm': 3.3154120445251465, 'learning_rate': 4.0816326530612245e-05, 'epoch': 0.09}


 82%|████████▏ | 410/500 [5:20:36<1:43:23, 68.93s/it]

{'loss': 2.1668, 'grad_norm': 1.5779565572738647, 'learning_rate': 3.673469387755102e-05, 'epoch': 0.09}


 84%|████████▍ | 420/500 [5:30:22<1:16:39, 57.49s/it]

{'loss': 2.3574, 'grad_norm': 1.4244446754455566, 'learning_rate': 3.265306122448979e-05, 'epoch': 0.09}


 86%|████████▌ | 430/500 [5:39:12<1:03:09, 54.14s/it]

{'loss': 2.4089, 'grad_norm': 1.686224102973938, 'learning_rate': 2.857142857142857e-05, 'epoch': 0.1}


 88%|████████▊ | 440/500 [5:47:49<50:52, 50.87s/it]  

{'loss': 2.5172, 'grad_norm': 1.9205021858215332, 'learning_rate': 2.448979591836735e-05, 'epoch': 0.1}


 90%|█████████ | 450/500 [5:54:14<25:25, 30.50s/it]

{'loss': 2.309, 'grad_norm': 3.266089677810669, 'learning_rate': 2.0408163265306123e-05, 'epoch': 0.1}


 92%|█████████▏| 460/500 [6:08:38<47:51, 71.79s/it]  

{'loss': 2.3879, 'grad_norm': 1.2667280435562134, 'learning_rate': 1.6326530612244897e-05, 'epoch': 0.1}


 94%|█████████▍| 470/500 [6:18:42<28:10, 56.34s/it]

{'loss': 2.3044, 'grad_norm': 1.4839311838150024, 'learning_rate': 1.2244897959183674e-05, 'epoch': 0.1}


 96%|█████████▌| 480/500 [6:27:14<17:27, 52.38s/it]

{'loss': 2.1671, 'grad_norm': 1.5683590173721313, 'learning_rate': 8.163265306122448e-06, 'epoch': 0.11}


 98%|█████████▊| 490/500 [6:35:31<07:58, 47.83s/it]

{'loss': 2.4756, 'grad_norm': 2.0352673530578613, 'learning_rate': 4.081632653061224e-06, 'epoch': 0.11}


100%|██████████| 500/500 [6:41:07<00:00, 28.09s/it]

{'loss': 2.3548, 'grad_norm': 2.0355608463287354, 'learning_rate': 0.0, 'epoch': 0.11}


                                                   


{'eval_loss': 2.370558023452759, 'eval_runtime': 7942.5771, 'eval_samples_per_second': 0.126, 'eval_steps_per_second': 0.126, 'epoch': 0.11}


In [None]:
path_to_save = "Llama-finetuned"
trainer.save_model(path_to_save)
model.save_pretrained(path_to_save)
tokenizer.save_pretrained(path_to_save)

In [None]:
del model, tokenizer, trainer

# Compare models

## Init casual LLM

In [None]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=cfg.torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
casual_model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    quantization_config=bnb_config,
#     device_map="auto",
    attn_implementation=cfg.attn_implementation
)

tokenizer = tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
tokenizer.padding_side = 'right'
tokenizer.padding_token = '<|pad_token|>'

In [None]:
casual_model, tokenizer = setup_chat_format(casual_model, tokenizer)

## Get answers

In [None]:
def generate_answer(model, prompt):
    chat = [
        { "role": "user", "content": prompt },
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=150)

    return(tokenizer.decode(outputs[0]))

# Comprasion

In [None]:
q1 = "I have severe headaches help me please"
q2 = "I have a suspiciously large mole. Could I have cancer? How can I determine this at home?"
q3 = "What does abutment of the nerve root mean?"

In [None]:
generate_answer(model, q1)

In [None]:
generate_answer(model, q2)

In [None]:
generate_answer(model, q3)

In [None]:
# Free gpu memory
import numba
numba.cuda.close()

In [None]:
print(generate_answer(casual_model, q1))

In [None]:
generate_answer(casual_model, q2)

In [None]:
generate_answer(casual_model, q3)