<img src="https://storage.googleapis.com/gweb-uniblog-publish-prod/images/gemma-header.width-1200.format-webp.webp" width="100%">

## Instruct Fine-tuning [Gemma](https://blog.google/technology/developers/gemma-open-models/) using QLoRA and Supervised Fine-tuning

This is a comprehensive notebook and tutorial on how to fine-tune the `gemma-7b-it` model.

All the code will be available on my GitHub. Do drop by and give a follow and a star.
[adithya-s-k](https://github.com/adithya-s-k)
\
[GitHub Code](https://github.com/adithya-s-k/LLM-Cookbook/blob/main/LLMs/Gemma/finetune-gemma.ipynb)

I also post content about LLMs and what I have been working on over on Twitter.
[AdithyaSK (@adithya_s_k) / X](https://twitter.com/adithya_s_k)

## Prerequisites

Before delving into the fine-tuning process, ensure that you have the following prerequisites in place:

1. **GPU**: [gemma-2b](https://huggingface.co/google/gemma-2b) can be fine-tuned on a T4 (free Google Colab), while [gemma-7b](https://huggingface.co/google/gemma-7b) requires an A100 GPU.
2. **Python Packages**: Ensure that you have the necessary Python packages installed. The pinned versions used here are listed in the commented-out `pip3 install` cell under Step 2.

Let's begin by checking if your GPU is correctly detected:

In [1]:
!nvidia-smi

Sat Jul  6 15:16:48 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   41C    P0             46W /  400W |     138MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                

## Step 2 - Model loading
We'll load the model with 4-bit quantization, as used in QLoRA, to reduce memory usage.


In [2]:
#!pip3 install -q -U bitsandbytes==0.42.0
#!pip3 install -q -U peft==0.8.2
#!pip3 install -q -U trl==0.7.10
#!pip3 install -q -U accelerate==0.27.1
#!pip3 install -q -U datasets==2.17.0
#!pip3 install -q -U transformers==4.38.0

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
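To get a feel for why 4-bit loading matters, here is a rough back-of-the-envelope estimate of weight memory. It is a sketch, not a measurement: the parameter count is taken from the trainable/total numbers printed later in this notebook (total minus LoRA parameters), and the helper name `weight_memory_gib` is made up for illustration. It counts weights only and assumes every weight is stored at the given width, so it ignores layers kept in higher precision, quantization constants, activations, and optimizer state.

```python
# Rough weight-memory estimate for different storage widths (illustrative only).
GIB = 1024 ** 3

# Base-model parameters: total reported later in this notebook minus the LoRA parameters.
BASE_PARAMS = 8_737_696_768 - 200_015_872


def weight_memory_gib(num_params: int, bits_per_param: float) -> float:
    """Return the memory needed to store num_params weights at the given bit width."""
    return num_params * bits_per_param / 8 / GIB


fp16_gib = weight_memory_gib(BASE_PARAMS, 16)
nf4_gib = weight_memory_gib(BASE_PARAMS, 4)

print(f"16-bit weights: {fp16_gib:.1f} GiB")
print(f"4-bit weights:  {nf4_gib:.1f} GiB")
```

Under these assumptions the 4-bit weights take about a quarter of the 16-bit footprint, which leaves far more GPU memory for activations and adapter training.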

Now we specify the model ID and then we load it with our previously defined quantization configuration.

In [4]:
# if you are using google colab

# import os
# from google.colab import userdata
# os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

In [4]:
from huggingface_hub import notebook_login
notebook_login()


In [2]:
model_id = "google/gemma-7b-it"
# model_id = "google/gemma-7b"
# model_id = "google/gemma-2b-it"
# model_id = "google/gemma-2b"

model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
'''
def get_completion(query: str, model, tokenizer) -> str:
    device = "cuda:0"

    prompt_template = """
    <start_of_turn>user
    Below is an instruction that describes a task. Write a response that appropriately completes the request.
    {query}
    <end_of_turn>\n<start_of_turn>model


    """
    prompt = prompt_template.format(query=query)

    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    model_inputs = encodeds.to(device)


    generated_ids = model.generate(**model_inputs, max_new_tokens=1000, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    # decoded = tokenizer.batch_decode(generated_ids)
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return (decoded)
'''

In [6]:
#result = get_completion(query="code the fibonacci series in python using reccursion", model=model, tokenizer=tokenizer)
#print(result)

## Step 3 - Load dataset for finetuning

### Let's Load the Dataset

For this tutorial, we will fine-tune Gemma 7B Instruct.

The alpaca-style reference here is this [dataset](https://huggingface.co/datasets/TokenBender/code_instructions_122k_alpaca_style), which is curated by [TokenBender (e/xperiments)](https://twitter.com/4evaBehindSOTA) and is an excellent data source for fine-tuning models for code generation; its `load_dataset` call is kept commented out further down. The cells below instead load a local CSV of clinical notes and rename its columns to the same alpaca-style `instruction`, `input`, and `output` fields. The dataset structure should resemble the following:

```json
{
  "instruction": "Create a function to calculate the sum of a sequence of integers.",
  "input": "[1, 2, 3, 4, 5]",
  "output": "# Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum"
}
```

In [3]:
import pandas as pd
import os
import numpy as np
import torch

In [4]:
os.getcwd()

'/home/jupyter/finetune_LLM'

In [4]:
classified_shrishti_df = pd.read_csv('/home/jupyter/finetune_LLM/Data/shrishti/train/processed_chunked_files_llm_finetuining_shrishti.csv')
classified_shrishti_df

Unnamed: 0,filename,notes,classified
0,Notes_text_shrishti_chunk_1382.txt,amoxycillin +clarithromycin+ esomeprazole (100...,{'recipere': ['amoxycillin +clarithromycin+ es...
1,Notes_text_shrishti_chunk_9906.txt,"examinationfgc,well hydrated,no jaundice,febri...","{'examination': ['examinationfgc,well hydrated..."
2,Notes_text_shrishti_chunk_3595.txt,acute coronary syndrome profiledo cardiac biom...,{'plan': ['acute coronary syndrome profiledo c...
3,Notes_text_shrishti_chunk_2853.txt,"paracetamol tabs 500 mg, dosage: 2, route: 106...","{'recipere': ['paracetamol tabs 500 mg, dosage..."
4,Notes_text_shrishti_chunk_5682.txt,"nausea, poor appetite, no abdominal pain, no h...","{'history_of_previous_illness': ['nausea, poor..."
...,...,...,...
10516,Notes_text_shrishti_chunk_5843.txt,34 year old female with 2days history of lower...,{'history_of_previous_illness': ['34 year old ...
10517,Notes_text_shrishti_chunk_335.txt,"general fair general condition, not pale, not ...",{'examination': ['general fair general conditi...
10518,Notes_text_shrishti_chunk_453.txt,"cetrizine 5mg/5ml syrup 30ml, dosage: 1, route...","{'recipere': ['cetrizine 5mg/5ml syrup 30ml, d..."
10519,Notes_text_shrishti_chunk_2692.txt,"cetrizine 5mg/5ml syrup 50ml, dosage: 5ml, rou...","{'recipere': ['cetrizine 5mg/5ml syrup 50ml, d..."


In [5]:
# Define the instruction string
instruction = '''Classify the input into below categories:
- recipere
- investigations
- plan
- complaints
- history_of_previous_illness
- examination
- diagnoses
Output the classified data into JSON format'''

# Add the 'instruction' column with the repeated string
classified_shrishti_df['instruction'] = instruction

In [6]:
classified_shrishti_df

Unnamed: 0,filename,notes,classified,instruction
0,Notes_text_shrishti_chunk_1382.txt,amoxycillin +clarithromycin+ esomeprazole (100...,{'recipere': ['amoxycillin +clarithromycin+ es...,Classify the input into below categories:\n- r...
1,Notes_text_shrishti_chunk_9906.txt,"examinationfgc,well hydrated,no jaundice,febri...","{'examination': ['examinationfgc,well hydrated...",Classify the input into below categories:\n- r...
2,Notes_text_shrishti_chunk_3595.txt,acute coronary syndrome profiledo cardiac biom...,{'plan': ['acute coronary syndrome profiledo c...,Classify the input into below categories:\n- r...
3,Notes_text_shrishti_chunk_2853.txt,"paracetamol tabs 500 mg, dosage: 2, route: 106...","{'recipere': ['paracetamol tabs 500 mg, dosage...",Classify the input into below categories:\n- r...
4,Notes_text_shrishti_chunk_5682.txt,"nausea, poor appetite, no abdominal pain, no h...","{'history_of_previous_illness': ['nausea, poor...",Classify the input into below categories:\n- r...
...,...,...,...,...
10516,Notes_text_shrishti_chunk_5843.txt,34 year old female with 2days history of lower...,{'history_of_previous_illness': ['34 year old ...,Classify the input into below categories:\n- r...
10517,Notes_text_shrishti_chunk_335.txt,"general fair general condition, not pale, not ...",{'examination': ['general fair general conditi...,Classify the input into below categories:\n- r...
10518,Notes_text_shrishti_chunk_453.txt,"cetrizine 5mg/5ml syrup 30ml, dosage: 1, route...","{'recipere': ['cetrizine 5mg/5ml syrup 30ml, d...",Classify the input into below categories:\n- r...
10519,Notes_text_shrishti_chunk_2692.txt,"cetrizine 5mg/5ml syrup 50ml, dosage: 5ml, rou...","{'recipere': ['cetrizine 5mg/5ml syrup 50ml, d...",Classify the input into below categories:\n- r...


In [7]:
# Rename columns
new_column_names = {
    'notes': 'input',
    'classified': 'output',
}
classified_shrishti_df = classified_shrishti_df.rename(columns=new_column_names)
classified_shrishti_df

Unnamed: 0,filename,input,output,instruction
0,Notes_text_shrishti_chunk_1382.txt,amoxycillin +clarithromycin+ esomeprazole (100...,{'recipere': ['amoxycillin +clarithromycin+ es...,Classify the input into below categories:\n- r...
1,Notes_text_shrishti_chunk_9906.txt,"examinationfgc,well hydrated,no jaundice,febri...","{'examination': ['examinationfgc,well hydrated...",Classify the input into below categories:\n- r...
2,Notes_text_shrishti_chunk_3595.txt,acute coronary syndrome profiledo cardiac biom...,{'plan': ['acute coronary syndrome profiledo c...,Classify the input into below categories:\n- r...
3,Notes_text_shrishti_chunk_2853.txt,"paracetamol tabs 500 mg, dosage: 2, route: 106...","{'recipere': ['paracetamol tabs 500 mg, dosage...",Classify the input into below categories:\n- r...
4,Notes_text_shrishti_chunk_5682.txt,"nausea, poor appetite, no abdominal pain, no h...","{'history_of_previous_illness': ['nausea, poor...",Classify the input into below categories:\n- r...
...,...,...,...,...
10516,Notes_text_shrishti_chunk_5843.txt,34 year old female with 2days history of lower...,{'history_of_previous_illness': ['34 year old ...,Classify the input into below categories:\n- r...
10517,Notes_text_shrishti_chunk_335.txt,"general fair general condition, not pale, not ...",{'examination': ['general fair general conditi...,Classify the input into below categories:\n- r...
10518,Notes_text_shrishti_chunk_453.txt,"cetrizine 5mg/5ml syrup 30ml, dosage: 1, route...","{'recipere': ['cetrizine 5mg/5ml syrup 30ml, d...",Classify the input into below categories:\n- r...
10519,Notes_text_shrishti_chunk_2692.txt,"cetrizine 5mg/5ml syrup 50ml, dosage: 5ml, rou...","{'recipere': ['cetrizine 5mg/5ml syrup 50ml, d...",Classify the input into below categories:\n- r...


In [8]:
from datasets import Dataset
dataset = Dataset.from_pandas(classified_shrishti_df)
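One thing worth checking before training: the `output` column shown above is printed with single quotes, which suggests it holds Python dict literals rather than strict JSON, while the instruction asks for JSON output. If that is the case, a conversion along these lines could be applied to the column first. This is a sketch under that assumption; `to_json_string` is a hypothetical helper, not part of the original notebook, and the sample value is invented.

```python
import ast
import json


def to_json_string(literal_text: str) -> str:
    """Convert a Python dict literal (single-quoted) into a strict JSON string."""
    # literal_eval only parses literals; it never executes code
    parsed = ast.literal_eval(literal_text)
    return json.dumps(parsed, ensure_ascii=False)


# Invented sample in the same shape as the `output` column.
sample = "{'plan': ['review in two weeks'], 'diagnoses': []}"
print(to_json_string(sample))
```

With pandas, this could be applied as `classified_shrishti_df['output'].map(to_json_string)` before building the `Dataset`.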

In [13]:
#from datasets import load_dataset

#dataset = load_dataset("TokenBender/code_instructions_122k_alpaca_style", split="train")
#dataset

In [14]:
#df = dataset.to_pandas()
#df.head(10)

Instruction fine-tuning: prepare the dataset in a "prompt" format that the model can learn from:
1. the function `generate_prompt` takes the instruction, input, and output and builds a prompt string
2. shuffle the dataset
3. tokenize the dataset

### Formatting the Dataset

Now, let's format the dataset in the required [Gemma instruction format](https://huggingface.co/google/gemma-7b-it).

> Many tutorials and blogs skip over this part, but I feel this is a really important step.

```
<start_of_turn>user What is your favorite condiment? <end_of_turn>
<start_of_turn>model Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavor to whatever I'm cooking up in the kitchen!<end_of_turn>
```
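The turn structure can be sketched with a small helper. This assumes the layout documented on the Gemma model card, where each turn is `<start_of_turn>{role}`, a newline, the content, and `<end_of_turn>`; the function name `format_gemma_turns` is hypothetical. Note that the prompts built in this notebook use spaces instead of newlines after the role, so treat this as a reference for the format rather than the exact strings used for training.

```python
from typing import Optional


def format_gemma_turns(user_text: str, model_text: Optional[str] = None) -> str:
    """Format one user turn and, optionally, the model reply in Gemma's turn markup."""
    prompt = f"<start_of_turn>user\n{user_text}<end_of_turn>\n<start_of_turn>model\n"
    if model_text is not None:
        # Training examples include the answer and close the model turn
        prompt += f"{model_text}<end_of_turn>\n"
    return prompt


# Inference prompt: stops right after the model turn opens
print(format_gemma_turns("What is your favorite condiment?"))
# Training example: includes the expected answer
print(format_gemma_turns("What is your favorite condiment?", "Fresh lemon juice."))
```

Leaving the model turn open for inference and closing it for training keeps the two prompt styles consistent with each other.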

You can use the following code to build the prompt string for each row of your dataset in this format:

In [9]:
def generate_prompt(data_point):
    """Build the training prompt from the task instruction, optional input, and answer.

    :param data_point: dict: Data point with "instruction", "input", and "output" keys
    :return: str: Prompt text in the Gemma turn format (not yet tokenized)
    """
    prefix_text = 'You are a Medical Clinical notes document entity extraction specialist.' \
               'You are also provided with the Doctors Note.\n\n'
    # Samples with additional context info.
    if data_point['input']:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} here are the inputs {data_point["input"]} <end_of_turn>\n<start_of_turn>model{data_point["output"]} <end_of_turn>"""
    # Samples without additional context
    else:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} <end_of_turn>\n<start_of_turn>model{data_point["output"]} <end_of_turn>"""
    return text

In [10]:
# add the "prompt" column in the dataset
text_column = [generate_prompt(data_point) for data_point in dataset]
dataset = dataset.add_column("prompt", text_column)

In [11]:
dataset['prompt'][0]

"<start_of_turn>user You are a Medical Clinical notes document entity extraction specialist.You are also provided with the Doctors Note.\n\n Classify the input into below categories:\n- recipere\n- investigations\n- plan\n- complaints\n- history_of_previous_illness\n- examination\n- diagnoses\nOutput the classified data into JSON format here are the inputs amoxycillin +clarithromycin+ esomeprazole (1000+500+20)mg, dosage: 1, route: 106.0, qty: 14.0, duration: 14.0, dura_unit: 193.0, instructions: take three tablets twice as indicated on the kit esomeprazole 40mg tabs, dosage: 1, route: 106.0, qty: 14.0, duration: 14.0, dura_unit: 193.0, instructions: take one tablet twice daily povidone iodine 200mg pessaries, dosage: 1, route: 106.0, qty: 1.0, duration: 7.0, dura_unit: 193.0, instructions: gargle & spit 10ml twice(2) daily terbutaline+ambroxol+guaifensine syrup, dosage: 1, route: 106.0, qty: 2.0, duration: 5.0, dura_unit: 193.0, instructions: take 10ml three(3) times daily dextrometho

We'll need to tokenize our data so the model can understand it.


In [12]:
#from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

#model_id = "google/gemma-7b-it"
#tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token=True)

In [13]:
dataset = dataset.shuffle(seed=1234)  # Shuffle dataset here
dataset = dataset.map(lambda samples: tokenizer(samples["prompt"]), batched=True)

Map:   0%|          | 0/10521 [00:00<?, ? examples/s]

Split the dataset into 80% for training and 20% for testing (`test_size=0.2`)

In [14]:
dataset = dataset.train_test_split(test_size=0.2)
train_data = dataset["train"]
test_data = dataset["test"]
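As a quick sanity check on the split, the expected row counts can be worked out by hand. The sketch below assumes the test split size is rounded up to a whole row, which agrees with the 2105 test rows printed further down; `split_sizes` is a hypothetical helper.

```python
import math


def split_sizes(num_rows: int, test_size: float) -> tuple:
    """Return (train_rows, test_rows) for a fractional test_size, rounding the test split up."""
    test_rows = math.ceil(num_rows * test_size)
    return num_rows - test_rows, test_rows


train_rows, test_rows = split_sizes(10521, 0.2)
print(train_rows, test_rows)
```

Passing a `seed` to `train_test_split` makes the split reproducible across runs.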

### After formatting, we should get something like this

```json
{
  "text": "<start_of_turn>user Create a function to calculate the sum of a sequence of integers. here are the inputs [1, 2, 3, 4, 5] <end_of_turn>\n<start_of_turn>model # Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum <end_of_turn>",
  "instruction": "Create a function to calculate the sum of a sequence of integers",
  "input": "[1, 2, 3, 4, 5]",
  "output": "# Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum",
  "prompt": "<start_of_turn>user Create a function to calculate the sum of a sequence of integers. here are the inputs [1, 2, 3, 4, 5] <end_of_turn>\n<start_of_turn>model # Python code def sum_sequence(sequence): sum = 0 for num in sequence: sum += num return sum <end_of_turn>"
}
```

While using SFT (**[Supervised Fine-tuning Trainer](https://huggingface.co/docs/trl/main/en/sft_trainer)**) for fine-tuning, we pass only the `prompt` column of the dataset, via `dataset_text_field="prompt"`.

In [15]:
print(test_data)

Dataset({
    features: ['filename', 'input', 'output', 'instruction', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 2105
})


## Step 4 - Apply LoRA
Here comes the magic with PEFT! We prepare the quantized model for training with `prepare_model_for_kbit_training`, then wrap it with low-rank adapters (LoRA) using the `get_peft_model` utility function.

In [16]:
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [17]:
print(model)

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=3072, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear4bit(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear4bit(in_features=24576, out_features=3072, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
  

In [18]:
import bitsandbytes as bnb

def find_all_linear_names(model):
    """Collect the names of all 4-bit linear layers to use as LoRA target modules."""
    cls = bnb.nn.Linear4bit  # use bnb.nn.Linear8bitLt for 8-bit or torch.nn.Linear for 16-bit models
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    # Exclude the output head once, after the scan (needed for 16-bit)
    lora_module_names.discard('lm_head')
    return list(lora_module_names)

In [19]:
modules = find_all_linear_names(model)
print(modules)

['o_proj', 'down_proj', 'q_proj', 'gate_proj', 'up_proj', 'v_proj', 'k_proj']


In [20]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=64,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
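For reference, the update that LoRA applies to each targeted linear layer can be written as follows. This is the standard LoRA formulation, with PEFT's default scaling of `lora_alpha / r` assumed:

$$
h = W_0 x + \frac{\alpha}{r} B A x, \qquad A \in \mathbb{R}^{r \times d_{\text{in}}},\; B \in \mathbb{R}^{d_{\text{out}} \times r}
$$

With `r=64` and `lora_alpha=32` as configured above, the scaling factor is $32 / 64 = 0.5$. Only $A$ and $B$ are trained; the quantized $W_0$ stays frozen.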

In [21]:
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

Trainable: 200015872 | total: 8737696768 | Percentage: 2.2891%
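The trainable count can be reproduced by hand, which is a useful check that the adapters landed on the intended layers. Each adapted linear layer with `in_features` inputs and `out_features` outputs gains `r * (in_features + out_features)` parameters. The sketch below plugs in the layer shapes from the `print(model)` output above; the variable and function names are illustrative.

```python
# Reproduce the LoRA trainable-parameter count from the layer shapes printed above.
R = 64            # LoRA rank from lora_config
NUM_LAYERS = 28   # decoder layers in the printed model

# (in_features, out_features) of each targeted projection in one decoder layer
LAYER_SHAPES = {
    "q_proj": (3072, 4096),
    "k_proj": (3072, 4096),
    "v_proj": (3072, 4096),
    "o_proj": (4096, 3072),
    "gate_proj": (3072, 24576),
    "up_proj": (3072, 24576),
    "down_proj": (24576, 3072),
}


def lora_params(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, R) for i, o in LAYER_SHAPES.values())
trainable = per_layer * NUM_LAYERS
total = 8_737_696_768  # total reported by get_nb_trainable_parameters() above

print(f"Trainable: {trainable} | Percentage: {trainable / total * 100:.4f}%")
```

The result agrees with the count reported by `get_nb_trainable_parameters()`, so all seven projection types in all 28 layers carry an adapter.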


## Step 5 - Run the training!

Setting the training arguments:
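* For demonstration purposes, training is capped at a fixed number of steps (`max_steps=1000` in the `SFTTrainer` cell; the commented-out `Trainer` variant below uses 100), just to showcase how to use this integration with existing tools in the HF ecosystem.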

In [23]:
# import transformers

# tokenizer.pad_token = tokenizer.eos_token


# trainer = transformers.Trainer(
#     model=model,
#     train_dataset=train_data,
#     eval_dataset=test_data,
#     args=transformers.TrainingArguments(
#         per_device_train_batch_size=1,
#         gradient_accumulation_steps=4,
#         warmup_steps=0.03,
#         max_steps=100,
#         learning_rate=2e-4,
#         fp16=True,
#         logging_steps=1,
#         output_dir="outputs_mistral_b_finance_finetuned_test",
#         optim="paged_adamw_8bit",
#         save_strategy="epoch",
#     ),
#     data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
# )


### Fine-Tuning with QLoRA and Supervised Fine-Tuning

We're ready to fine-tune our model using QLoRA. For this tutorial, we'll use the `SFTTrainer` from the `trl` library for supervised fine-tuning. Ensure that you've installed the `trl` library as mentioned in the prerequisites.

In [22]:
#new code using SFTTrainer
import transformers

from trl import SFTTrainer

tokenizer.pad_token = tokenizer.eos_token
torch.cuda.empty_cache()

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=test_data,
    dataset_text_field="prompt",
    peft_config=lora_config,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        #warmup_steps=0.03,
        max_steps=1000,
        learning_rate=2e-4,
        logging_steps=100,
        output_dir="./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/",
        optim="paged_adamw_8bit",
        save_strategy="epoch",
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
max_steps is given, it will override any value given in num_train_epochs
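To put these arguments in perspective, the amount of data seen during training follows directly from them. The sketch below is plain arithmetic on the values in the cell above; the training-set size is derived from the 10521 total rows minus the 2105 test rows, and a single GPU is assumed.

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
max_steps = 1000
train_rows = 10521 - 2105  # rows left after the test split

# Sequences consumed per optimizer step on one GPU
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
samples_seen = effective_batch_size * max_steps
epochs_covered = samples_seen / train_rows

print(f"Effective batch size: {effective_batch_size}")
print(f"Samples seen: {samples_seen} (~{epochs_covered:.2f} epochs)")
```

Under these assumptions the run covers a bit under half of the training set; raising `max_steps` or switching to `num_train_epochs` would be needed for a full pass.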


## Let's start training

In [None]:
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

## Share adapters on the 🤗 Hub

In [25]:
new_model = "./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/gemma-Code-Instruct-Finetune-test"  # Local directory where the LoRA adapter is saved (a path, not a Hub repo id)

In [26]:
trainer.model.save_pretrained(new_model)

In [28]:
model_id = "google/gemma-7b-it"

In [29]:
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map={"": 0},
)
merged_model= PeftModel.from_pretrained(base_model, new_model)
merged_model= merged_model.merge_and_unload()


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [30]:
# Save the merged model
merged_model.save_pretrained("./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/merged_model",safe_serialization=True)
tokenizer.save_pretrained("./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/merged_model")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [35]:
# Push the model and tokenizer to the Hugging Face Model Hub
merged_model.push_to_hub(new_model, use_temp_dir=False)
tokenizer.push_to_hub(new_model, use_temp_dir=False)

HFValidationError: Repo id must be in the form 'repo_name' or 'namespace/repo_name': './finetune_LLM/finetuned_LLM/gemma_os_outputs/gemma-Code-Instruct-Finetune-test'. Use `repo_type` argument if needed.
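The error above comes from passing a local directory path to `push_to_hub`, which expects a Hub repo id of the form `repo_name` or `namespace/repo_name`. A rough pre-check could look like the sketch below. The regular expression is a simplified approximation written for this example, not the validation that `huggingface_hub` performs, and the repo id `your-username/gemma-7b-it-finetune` is a placeholder.

```python
import re

# Simplified approximation: an optional namespace, then a repo name,
# each made of letters, digits, '-', '_' or '.'.
REPO_ID_PATTERN = re.compile(
    r"[A-Za-z0-9][A-Za-z0-9_.-]*(/[A-Za-z0-9][A-Za-z0-9_.-]*)?"
)


def looks_like_repo_id(name: str) -> bool:
    """Return True if name has the rough shape 'repo_name' or 'namespace/repo_name'."""
    return REPO_ID_PATTERN.fullmatch(name) is not None


local_path = "./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/gemma-Code-Instruct-Finetune-test"
print(looks_like_repo_id(local_path))
print(looks_like_repo_id("your-username/gemma-7b-it-finetune"))
```

Keeping the local save directory and the Hub repo id in two separate variables avoids this mix-up.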

## Test the Fine-tuned Model

In [31]:
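def generate_test_prompt(data_point):
    """Build the inference prompt from the task instruction and optional input, without the answer.

    :param data_point: dict: Data point
    :return: str: Prompt text ending at the start of the model turn
    """
    prefix_text = 'You are a Medical Clinical notes document entity extraction specialist.' \
               'You are also provided with the Doctors Note.\n\n'
    # Samples with additional context info.
    # The prompt ends with the opening of the model turn, matching the training format.
    if data_point['input']:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} here are the inputs {data_point["input"]} <end_of_turn>\n<start_of_turn>model"""
    # Samples without additional context (avoids an undefined `text` when input is empty)
    else:
        text = f"""<start_of_turn>user {prefix_text} {data_point["instruction"]} <end_of_turn>\n<start_of_turn>model"""

    return text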

In [32]:
test_data

Dataset({
    features: ['filename', 'input', 'output', 'instruction', 'prompt', 'input_ids', 'attention_mask'],
    num_rows: 2105
})

In [34]:
# add the "prompt" column in the dataset
text_column = [generate_test_prompt(data_point) for data_point in test_data]
test_data = test_data.add_column("test_prompt", text_column)

Flattening the indices:   0%|          | 0/2105 [00:00<?, ? examples/s]

In [37]:
prompt = test_data['test_prompt'][0]

In [49]:
def get_completion(query: str, model, tokenizer) -> str:
    """Generate a completion for a prompt that is already in the Gemma turn format."""
    prompt = query

    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    # Send the inputs to whichever device the model lives on
    model_inputs = encodeds.to(model.device)

    # Deterministic beam search; sampling-only options (temperature, top_k, top_p)
    # have no effect when do_sample=False, so they are left out
    generated_ids = model.generate(**model_inputs,
                                   max_new_tokens=512,
                                   num_beams=5,
                                   repetition_penalty=2.0,
                                   early_stopping=True,
                                   do_sample=False,
                                   pad_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return decoded

In [50]:
%%time
result = get_completion(query=prompt, model=merged_model, tokenizer=tokenizer)
print(result)

user You are a Medical Clinical notes document entity extraction specialist.You are also provided with the Doctors Note.

 Classify the input into below categories:
- recipere
- investigations
- plan
- complaints
- history_of_previous_illness
- examination
- diagnoses
Output the classified data into JSON format here are the inputs chlorzoxazone 500mg tablets, dosage: 1tab, route: 106.0, qty: 15.0, duration: 10.0, dura_unit: 193.0, instructions: take one(1) tablet three(3) times daily;when neccessary nifedipine sr 20mg tablets, dosage: 1tab, route: 106.0, qty: 2.0, duration: 1.0, dura_unit: 193.0, instructions: take two tablets stat amlodipine + valsartan+hydrochlothiazide (5+160+12.5)mg tablets, dosage: 1tab, route: 106.0, qty: 7.0, duration: 1.0, dura_unit: 194.0, instructions: take one tablet once daily cefuroxime 500mg tablets, dosage: 1tab, route: 106.0, qty: 14.0, duration: 7.0, dura_unit: 193.0, instructions: take one tablet twice daily
diarrhea three times, watery, foul smelling

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and merged model
model_path = "./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/merged_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
merged_model = AutoModelForCausalLM.from_pretrained(model_path)

# Ensure the model is in evaluation mode
merged_model.eval()

# Move model to CPU
merged_model.to('cpu')


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 3072, padding_idx=0)
    (layers): ModuleList(
      (0-27): 28 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (k_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (v_proj): Linear(in_features=3072, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=3072, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (up_proj): Linear(in_features=3072, out_features=24576, bias=False)
          (down_proj): Linear(in_features=24576, out_features=3072, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): Gemm

In [2]:
import torch.quantization

# Apply dynamic quantization to the model
quantized_model = torch.quantization.quantize_dynamic(
    merged_model,  # the original model
    {torch.nn.Linear},  # layers to quantize
    dtype=torch.qint8  # quantization dtype
)

# Save the quantized model
quantized_model.save_pretrained("./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/quantized_merged_model")
tokenizer.save_pretrained("./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/quantized_merged_model")



AttributeError: 'torch.dtype' object has no attribute 'data_ptr'

The save most likely fails because dynamic quantization replaces each `torch.nn.Linear` with a module whose state dict holds packed parameters alongside non-tensor entries such as a `torch.dtype`, which `save_pretrained` cannot serialize. Since nothing is written to `quantized_merged_model`, the loading cell further down has no checkpoint to read.

In [None]:
test_data

In [None]:
prompt = test_data['test_prompt'][0]

In [None]:
def get_completion(query: str, model, tokenizer) -> str:
    """Generate a completion for a prompt that is already in the Gemma turn format."""
    prompt = query

    encodeds = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

    # Send the inputs to whichever device the model lives on
    model_inputs = encodeds.to(model.device)

    # Deterministic beam search; sampling-only options (temperature, top_k, top_p)
    # have no effect when do_sample=False, so they are left out
    generated_ids = model.generate(**model_inputs,
                                   max_new_tokens=512,
                                   num_beams=5,
                                   repetition_penalty=2.0,
                                   early_stopping=True,
                                   do_sample=False,
                                   pad_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return decoded

In [None]:
# Load the quantized model
quantized_model = AutoModelForCausalLM.from_pretrained("./finetune_LLM/finetuned_LLM/gemma_shrishti_outputs/quantized_merged_model")

# Move the model to CPU
quantized_model.to('cpu')

# Ensure the model is in evaluation mode
quantized_model.eval()

# Example inference
input_text = "You are a Medical Clinical notes document entity extraction specialist."
inputs = tokenizer(input_text, return_tensors="pt")

# Generate output
with torch.no_grad():
    output = quantized_model.generate(**inputs)
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
