# Libraries

In [1]:
# %pip install -U transformers
# %pip install -U datasets
# %pip install -U accelerate
# %pip install -U peft
# %pip install -U trl
# %pip install -U bitsandbytes

In [2]:
import os, torch, wandb

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)

from datasets import load_dataset, concatenate_datasets
from trl import SFTTrainer, setup_chat_format
from dataclasses import dataclass

  from .autonotebook import tqdm as notebook_tqdm


## Setup Huggingface ü§ó & Wandb

In [5]:
@dataclass
class Config:
#     model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
#     model_name = "AnatoliiPotapov/T-lite-instruct-0.1"
    model_name = "notused"
    dataset_name = "notused"
    new_model = "russia_chad"
    torch_dtype = torch.float16
    attn_implementation = "eager"
cfg = Config()

# Loading model and tokenizer

In [6]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=cfg.torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "C:\\Users\\USER_ELISEY\\gemma",
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=cfg.attn_implementation
)

Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4/4 [00:08<00:00,  2.03s/it]


In [7]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("C:\\Users\\USER_ELISEY\\gemma")
tokenizer.padding_side = 'right'
tokenizer.padding_token = '<|pad|>'
print(len(tokenizer))

256000


## LoRA adapter

In [8]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

# Data

## Load

In [9]:
dataset = load_dataset('miracl/miracl', 'ru', trust_remote_code=True)
# data_eval["dev"]['negative_passages']

In [10]:
dataset

DatasetDict({
    dev: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
        num_rows: 1252
    })
    testB: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
        num_rows: 718
    })
    train: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
        num_rows: 4683
    })
    testA: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
        num_rows: 911
    })
})

## Format to chat 

In [11]:
def format_chat_template(row):
    row_json = [{"role": "user", "content": row["query"]},
               {"role": "assistant", "content": row["positive_passages"][0]["text"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

In [12]:
dataset = concatenate_datasets([
    dataset['dev'],
    dataset['train']
])

dataset = dataset.remove_columns('negative_passages')

In [13]:
dataset['positive_passages'][0][0]["text"]

'–ö–∞—Ä–∏ÃÅ–±—Å–∫–∏–π –∫—Ä–∏–∑–∏—Å\xa0‚Äî –∏—Å—Ç–æ—Ä–∏—á–µ—Å–∫–∏–π —Ç–µ—Ä–º–∏–Ω, –æ–ø—Ä–µ–¥–µ–ª—è—é—â–∏–π —á—Ä–µ–∑–≤—ã—á–∞–π–Ω–æ –Ω–∞–ø—Ä—è–∂—ë–Ω–Ω–æ–µ –ø–æ–ª–∏—Ç–∏—á–µ—Å–∫–æ–µ, –¥–∏–ø–ª–æ–º–∞—Ç–∏—á–µ—Å–∫–æ–µ –∏ –≤–æ–µ–Ω–Ω–æ–µ –ø—Ä–æ—Ç–∏–≤–æ—Å—Ç–æ—è–Ω–∏–µ –º–µ–∂–¥—É –°–æ–≤–µ—Ç—Å–∫–∏–º –°–æ—é–∑–æ–º –∏ –°–æ–µ–¥–∏–Ω—ë–Ω–Ω—ã–º–∏ –®—Ç–∞—Ç–∞–º–∏ –≤ –æ–∫—Ç—è–±—Ä–µ 1962 –≥–æ–¥–∞, –∫–æ—Ç–æ—Ä–æ–µ –±—ã–ª–æ –≤—ã–∑–≤–∞–Ω–æ —Ä–∞–∑–º–µ—â–µ–Ω–∏–µ–º –°–®–ê —è–¥–µ—Ä–Ω–æ–≥–æ –æ—Ä—É–∂–∏—è –≤ –¢—É—Ä—Ü–∏–∏ –≤ 1961 –≥–æ–¥—É –∏ –≤–ø–æ—Å–ª–µ–¥—Å—Ç–≤–∏–∏ —Ç–∞–π–Ω–æ–π –ø–µ—Ä–µ–±—Ä–æ—Å–∫–æ–π –∏ —Ä–∞–∑–º–µ—â–µ–Ω–∏–µ–º –Ω–∞ –ö—É–±–µ –≤–æ–µ–Ω–Ω—ã—Ö —á–∞—Å—Ç–µ–π –∏ –ø–æ–¥—Ä–∞–∑–¥–µ–ª–µ–Ω–∏–π –í–æ–æ—Ä—É–∂—ë–Ω–Ω—ã—Ö –°–∏–ª –°–°–°–†, —Ç–µ—Ö–Ω–∏–∫–∏ –∏ –≤–æ–æ—Ä—É–∂–µ–Ω–∏—è, –≤–∫–ª—é—á–∞—è —è–¥–µ—Ä–Ω–æ–µ –æ—Ä—É–∂–∏–µ. –ö—Ä–∏–∑–∏—Å –º–æ–≥ –ø—Ä–∏–≤–µ—Å—Ç–∏ –∫ –≥–ª–æ–±–∞–ª—å–Ω–æ–π —è–¥–µ—Ä–Ω–æ–π –≤–æ–π–Ω–µ. –ö—É–±–∏–Ω—Ü—ã –Ω–∞–∑—ã–≤–∞—é—Ç –µ–≥–æ ¬´–û–∫—Ç—è–±—Ä—å—Å–∫–∏–º –∫—Ä–∏–∑–∏—Å–æ–º¬ª (), –≤ –°–®–ê —Ä–∞—Å–ø—Ä–æ—Å—Ç—Ä–∞–Ω

In [14]:
dataset = dataset.map(
    format_chat_template,
    num_proc=1,
)

## Select only part

In [15]:
dataset_sh = dataset.shuffle(seed=2024)#.select(range(10_000))
dataset_sh

Dataset({
    features: ['query_id', 'query', 'positive_passages', 'text'],
    num_rows: 5935
})

In [16]:
dataset_sh = dataset_sh.train_test_split(0.1)

In [17]:
dataset_sh

DatasetDict({
    train: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'text'],
        num_rows: 5341
    })
    test: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'text'],
        num_rows: 594
    })
})

# Train model

## Training arguments

In [18]:
training_arguments = TrainingArguments(
    output_dir=cfg.new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
#     num_train_epochs=1,
    max_steps=500,
    eval_strategy="steps",
    eval_steps=100,
    logging_steps=10,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=True,
    bf16=False,
    group_by_length=True,
    report_to="wandb",
    run_name="Llama-3.1-medicine",
)

In [19]:
# print(len(tokenizer))

In [20]:
# model.resize_token_embeddings(256000)

## Train model

In [21]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_sh["train"],
    eval_dataset=dataset_sh["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 5341/5341 [00:00<00:00, 7611.39 examples/s]
Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 594/594 [00:00<00:00, 6795.04 examples/s]
max_steps is given, it will override any value given in num_train_epochs


In [None]:
trainer.train()

  2%|‚ñè         | 10/500 [08:31<5:48:18, 42.65s/it]

{'loss': 2.3, 'grad_norm': 2.1840381622314453, 'learning_rate': 0.0002, 'epoch': 0.0}


  4%|‚ñç         | 20/500 [13:25<3:47:05, 28.39s/it]

{'loss': 2.0288, 'grad_norm': 1.8392691612243652, 'learning_rate': 0.0001959183673469388, 'epoch': 0.01}


  6%|‚ñå         | 30/500 [18:52<4:17:19, 32.85s/it]

{'loss': 1.968, 'grad_norm': 2.3942601680755615, 'learning_rate': 0.00019183673469387756, 'epoch': 0.01}


  8%|‚ñä         | 40/500 [23:50<3:31:20, 27.57s/it]

{'loss': 1.7308, 'grad_norm': 2.6008598804473877, 'learning_rate': 0.00018775510204081634, 'epoch': 0.01}


 10%|‚ñà         | 50/500 [26:59<2:18:42, 18.49s/it]

{'loss': 1.6545, 'grad_norm': 4.566784858703613, 'learning_rate': 0.00018367346938775512, 'epoch': 0.02}


 12%|‚ñà‚ñè        | 60/500 [36:40<5:36:39, 45.91s/it]

{'loss': 1.8122, 'grad_norm': 1.5805020332336426, 'learning_rate': 0.0001795918367346939, 'epoch': 0.02}


 14%|‚ñà‚ñç        | 70/500 [42:37<3:56:23, 32.98s/it]

{'loss': 1.6243, 'grad_norm': 2.4196367263793945, 'learning_rate': 0.00017551020408163265, 'epoch': 0.03}


 16%|‚ñà‚ñå        | 80/500 [47:39<3:45:45, 32.25s/it]

{'loss': 1.6776, 'grad_norm': 2.4707908630371094, 'learning_rate': 0.00017142857142857143, 'epoch': 0.03}


 18%|‚ñà‚ñä        | 90/500 [52:18<2:41:45, 23.67s/it]

{'loss': 1.4889, 'grad_norm': 2.4366214275360107, 'learning_rate': 0.00016734693877551023, 'epoch': 0.03}


 20%|‚ñà‚ñà        | 100/500 [55:06<1:47:43, 16.16s/it]

{'loss': 1.5494, 'grad_norm': 2.841456651687622, 'learning_rate': 0.00016326530612244898, 'epoch': 0.04}


                                                   
 20%|‚ñà‚ñà        | 100/500 [1:57:25<1:47:43, 16.16s/it]

{'eval_loss': 1.6582292318344116, 'eval_runtime': 3739.4813, 'eval_samples_per_second': 0.159, 'eval_steps_per_second': 0.159, 'epoch': 0.04}


 22%|‚ñà‚ñà‚ñè       | 110/500 [2:06:19<10:03:33, 92.86s/it]   

{'loss': 1.7072, 'grad_norm': 1.3870774507522583, 'learning_rate': 0.00015918367346938776, 'epoch': 0.04}


 24%|‚ñà‚ñà‚ñç       | 120/500 [2:12:43<3:53:04, 36.80s/it] 

{'loss': 1.6798, 'grad_norm': 1.8217418193817139, 'learning_rate': 0.00015510204081632654, 'epoch': 0.04}


 26%|‚ñà‚ñà‚ñå       | 130/500 [2:18:51<4:00:28, 39.00s/it]

{'loss': 1.8031, 'grad_norm': 2.0685861110687256, 'learning_rate': 0.0001510204081632653, 'epoch': 0.05}


 28%|‚ñà‚ñà‚ñä       | 140/500 [2:25:00<3:25:59, 34.33s/it]

{'loss': 1.5728, 'grad_norm': 3.1982595920562744, 'learning_rate': 0.0001469387755102041, 'epoch': 0.05}


 30%|‚ñà‚ñà‚ñà       | 150/500 [2:28:46<2:00:25, 20.64s/it]

{'loss': 1.6333, 'grad_norm': nan, 'learning_rate': 0.00014326530612244898, 'epoch': 0.06}


 32%|‚ñà‚ñà‚ñà‚ñè      | 160/500 [2:38:55<4:51:20, 51.41s/it]

{'loss': 1.8147, 'grad_norm': 1.7196639776229858, 'learning_rate': 0.00013918367346938776, 'epoch': 0.06}


 34%|‚ñà‚ñà‚ñà‚ñç      | 170/500 [2:45:44<3:24:34, 37.20s/it]

{'loss': 1.8404, 'grad_norm': 2.099281072616577, 'learning_rate': 0.00013510204081632654, 'epoch': 0.06}


 36%|‚ñà‚ñà‚ñà‚ñå      | 180/500 [2:51:54<3:24:31, 38.35s/it]

{'loss': 1.6406, 'grad_norm': 1.7859504222869873, 'learning_rate': 0.00013102040816326531, 'epoch': 0.07}


 38%|‚ñà‚ñà‚ñà‚ñä      | 190/500 [2:57:52<2:47:27, 32.41s/it]

{'loss': 1.4023, 'grad_norm': 2.9348511695861816, 'learning_rate': 0.00012693877551020406, 'epoch': 0.07}


 40%|‚ñà‚ñà‚ñà‚ñà      | 200/500 [3:01:31<1:43:49, 20.76s/it]

{'loss': 1.5202, 'grad_norm': 4.104401588439941, 'learning_rate': 0.00012285714285714287, 'epoch': 0.07}


                                                     
 40%|‚ñà‚ñà‚ñà‚ñà      | 200/500 [4:18:27<1:43:49, 20.76s/it]

{'eval_loss': 1.6190582513809204, 'eval_runtime': 4616.0114, 'eval_samples_per_second': 0.129, 'eval_steps_per_second': 0.129, 'epoch': 0.07}


 42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 210/500 [4:28:46<8:34:52, 106.53s/it]   

{'loss': 1.7892, 'grad_norm': 1.7662056684494019, 'learning_rate': 0.00011877551020408165, 'epoch': 0.08}


 44%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 220/500 [4:35:58<3:20:44, 43.02s/it] 

{'loss': 1.7324, 'grad_norm': 2.059535503387451, 'learning_rate': 0.00011469387755102041, 'epoch': 0.08}


 46%|‚ñà‚ñà‚ñà‚ñà‚ñå     | 230/500 [4:42:03<2:53:37, 38.58s/it]

{'loss': 1.7448, 'grad_norm': 2.26393461227417, 'learning_rate': 0.00011061224489795919, 'epoch': 0.09}


 48%|‚ñà‚ñà‚ñà‚ñà‚ñä     | 240/500 [4:47:41<2:07:04, 29.33s/it]

{'loss': 1.6726, 'grad_norm': 2.531026840209961, 'learning_rate': 0.00010653061224489795, 'epoch': 0.09}


 50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 250/500 [4:51:18<1:21:05, 19.46s/it]

{'loss': 1.4, 'grad_norm': 4.796047687530518, 'learning_rate': 0.00010244897959183674, 'epoch': 0.09}


 52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 260/500 [5:00:34<3:04:46, 46.19s/it]

{'loss': 1.8523, 'grad_norm': 1.7520687580108643, 'learning_rate': 9.836734693877552e-05, 'epoch': 0.1}


 54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 270/500 [5:06:42<2:17:31, 35.88s/it]

{'loss': 1.6612, 'grad_norm': 1.7135014533996582, 'learning_rate': 9.428571428571429e-05, 'epoch': 0.1}


 56%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå    | 280/500 [5:13:03<2:17:52, 37.60s/it]

{'loss': 1.4078, 'grad_norm': 2.065044641494751, 'learning_rate': 9.020408163265308e-05, 'epoch': 0.1}


 58%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 290/500 [5:18:42<1:49:47, 31.37s/it]

{'loss': 1.3318, 'grad_norm': 3.0822982788085938, 'learning_rate': 8.612244897959184e-05, 'epoch': 0.11}


 60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 300/500 [5:22:10<1:02:32, 18.76s/it]

{'loss': 1.5945, 'grad_norm': 2.65639066696167, 'learning_rate': 8.204081632653062e-05, 'epoch': 0.11}


                                                     
 60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 300/500 [6:38:11<1:02:32, 18.76s/it]

{'eval_loss': 1.5992199182510376, 'eval_runtime': 4561.7417, 'eval_samples_per_second': 0.13, 'eval_steps_per_second': 0.13, 'epoch': 0.11}


 62%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè   | 310/500 [6:48:45<5:42:46, 108.25s/it]  

{'loss': 1.791, 'grad_norm': 1.5422536134719849, 'learning_rate': 7.79591836734694e-05, 'epoch': 0.12}


 64%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç   | 320/500 [6:55:14<1:52:04, 37.36s/it] 

{'loss': 1.6219, 'grad_norm': 1.8921648263931274, 'learning_rate': 7.387755102040816e-05, 'epoch': 0.12}


 66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå   | 330/500 [7:01:39<1:50:11, 38.89s/it]

{'loss': 1.591, 'grad_norm': 2.324852705001831, 'learning_rate': 6.979591836734695e-05, 'epoch': 0.12}


 68%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä   | 340/500 [7:07:37<1:29:15, 33.47s/it]

{'loss': 1.4953, 'grad_norm': 2.306370496749878, 'learning_rate': 6.571428571428571e-05, 'epoch': 0.13}


 70%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 350/500 [7:11:32<52:43, 21.09s/it]  

{'loss': 1.4416, 'grad_norm': 2.694382905960083, 'learning_rate': 6.163265306122449e-05, 'epoch': 0.13}


 72%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè  | 360/500 [7:20:59<1:53:06, 48.47s/it]

{'loss': 1.6559, 'grad_norm': 1.485676884651184, 'learning_rate': 5.755102040816327e-05, 'epoch': 0.13}


 74%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç  | 370/500 [7:27:20<1:17:02, 35.56s/it]

{'loss': 1.509, 'grad_norm': 2.0512969493865967, 'learning_rate': 5.346938775510204e-05, 'epoch': 0.14}


 76%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 380/500 [7:33:40<1:16:09, 38.08s/it]

{'loss': 1.6028, 'grad_norm': 2.4973597526550293, 'learning_rate': 4.938775510204082e-05, 'epoch': 0.14}


 78%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä  | 390/500 [7:39:18<54:28, 29.71s/it]  

{'loss': 1.4453, 'grad_norm': 2.477543830871582, 'learning_rate': 4.5306122448979595e-05, 'epoch': 0.15}


 80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 400/500 [7:42:35<29:47, 17.88s/it]

{'loss': 1.4901, 'grad_norm': 4.916280269622803, 'learning_rate': 4.122448979591837e-05, 'epoch': 0.15}


                                                   
 80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 400/500 [8:59:21<29:47, 17.88s/it]

{'eval_loss': 1.5806922912597656, 'eval_runtime': 4606.0529, 'eval_samples_per_second': 0.129, 'eval_steps_per_second': 0.129, 'epoch': 0.15}


 82%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè | 410/500 [9:07:26<2:25:25, 96.95s/it]   

{'loss': 1.8179, 'grad_norm': 1.8731502294540405, 'learning_rate': 3.7142857142857143e-05, 'epoch': 0.15}


 84%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç | 420/500 [9:12:38<41:37, 31.22s/it]  

{'loss': 1.6391, 'grad_norm': 1.7788532972335815, 'learning_rate': 3.306122448979592e-05, 'epoch': 0.16}


 86%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 430/500 [9:18:11<38:33, 33.04s/it]

{'loss': 1.7178, 'grad_norm': 2.2750277519226074, 'learning_rate': 2.8979591836734692e-05, 'epoch': 0.16}


 88%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä | 440/500 [9:23:33<31:22, 31.37s/it]

{'loss': 1.311, 'grad_norm': 2.0465900897979736, 'learning_rate': 2.489795918367347e-05, 'epoch': 0.16}


 90%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 450/500 [9:26:51<14:13, 17.07s/it]

{'loss': 1.4539, 'grad_norm': 3.6052777767181396, 'learning_rate': 2.0816326530612247e-05, 'epoch': 0.17}


 92%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè| 460/500 [9:35:39<29:35, 44.38s/it]

{'loss': 1.7986, 'grad_norm': 1.5277632474899292, 'learning_rate': 1.673469387755102e-05, 'epoch': 0.17}


 94%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç| 470/500 [9:41:13<15:16, 30.55s/it]

{'loss': 1.6121, 'grad_norm': 2.007514238357544, 'learning_rate': 1.2653061224489795e-05, 'epoch': 0.18}


 96%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 480/500 [9:46:41<10:49, 32.49s/it]

{'loss': 1.6892, 'grad_norm': 2.6101012229919434, 'learning_rate': 8.571428571428573e-06, 'epoch': 0.18}


 98%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä| 490/500 [9:50:31<03:22, 20.30s/it]

{'loss': 1.4698, 'grad_norm': 2.0734493732452393, 'learning_rate': 4.489795918367347e-06, 'epoch': 0.18}


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [9:53:19<00:00, 15.90s/it]

{'loss': 1.4128, 'grad_norm': 2.824575901031494, 'learning_rate': 4.081632653061225e-07, 'epoch': 0.19}


                                                   
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [10:55:23<00:00, 15.90s/it]

{'eval_loss': 1.570472002029419, 'eval_runtime': 3723.8734, 'eval_samples_per_second': 0.16, 'eval_steps_per_second': 0.16, 'epoch': 0.19}


In [None]:
path_to_save = "Llama-finetuned"
trainer.save_model(path_to_save)
model.save_pretrained(path_to_save)
tokenizer.save_pretrained(path_to_save)

In [None]:
del model, tokenizer, trainer

# Compare models

## Init casual LLM

In [None]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=cfg.torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
casual_model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name,
    quantization_config=bnb_config,
#     device_map="auto",
    attn_implementation=cfg.attn_implementation
)

tokenizer = tokenizer = AutoTokenizer.from_pretrained(cfg.model_name)
tokenizer.padding_side = 'right'
tokenizer.padding_token = '<|pad_token|>'

In [None]:
casual_model, tokenizer = setup_chat_format(casual_model, tokenizer)

## Get answers

In [None]:
def generate_answer(model, prompt):
    chat = [
        { "role": "user", "content": prompt },
    ]
    prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    outputs = model.generate(input_ids=inputs.to(model.device), max_new_tokens=150)

    return(tokenizer.decode(outputs[0]))

# Comprasion

In [None]:
q1 = "I have severe headaches help me please"
q2 = "I have a suspiciously large mole. Could I have cancer? How can I determine this at home?"
q3 = "What does abutment of the nerve root mean?"

In [None]:
generate_answer(model, q1)

In [None]:
generate_answer(model, q2)

In [None]:
generate_answer(model, q3)

In [None]:
# Free gpu memory
import numba
numba.cuda.close()

In [None]:
print(generate_answer(casual_model, q1))

In [None]:
generate_answer(casual_model, q2)

In [None]:
generate_answer(casual_model, q3)