<a href="https://colab.research.google.com/github/Nakshatra1729yuvi/Finetuning/blob/main/Finetuning_Gemma_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U accelerate bitsandbytes peft transformers trl



In [None]:
import os
import torch
from google.colab import userdata
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging
)
from peft import LoraConfig,PeftModel
from trl import SFTTrainer

In [None]:
# Copy the 'HF_TOKEN' Colab secret into the process environment; later cells
# read it back via os.environ['HF_TOKEN'] when loading the model and tokenizer.
os.environ['HF_TOKEN']=userdata.get('HF_TOKEN')

In [None]:
# Gemma chat prompt template (reference only, nothing is executed here):
# <bos><start_of_turn>user
# {user_message}<end_of_turn>
# <start_of_turn>model
# {assistant_message}<end_of_turn>

In [None]:
# ---- Model, dataset and output names ----
model_name = 'google/gemma-2b'            # base checkpoint to fine-tune
dataset_name = 'eswardivi/medical_qa'     # dataset passed to load_dataset
new_model = 'Gemma-2b-chat-finetune'      # name the trained model is saved under

# ---- QLoRA parameters ----
lora_rank = 4
lora_alpha = 16
lora_dropout = 0.1

# ---- bitsandbytes parameters ----
use_4bit = True
bnb_4bit_compute_dtype = 'float16'   # torch dtype name, resolved later with getattr(torch, ...)
bnb_4bit_quant_type = 'nf4'          # quantization type (nf4 or fp4)
use_nested_quant = False

# ---- TrainingArguments parameters ----
output_dir = "./results"
num_train_epochs = 1
fp16 = False
bf16 = False    # set bf16 to True for A100 GPU
per_device_train_batch_size = 1
per_device_eval_batch_size = 1
gradient_accumulation_steps = 1
gradient_checkpointing = True
max_grad_norm = 0.3
learning_rate = 2e-4
weight_decay = 0.001
optim = "paged_adamw_32bit"
lr_scheduler_type = "cosine"
max_steps = -1
warmup_ratio = 0.03
group_by_length = True
save_steps = 0
logging_steps = 25

# ---- SFT parameters ----
max_seq_length = None
packing = False
device_map = {"": 0}   # single mapping entry: "" -> 0



In [None]:
# Load the "train" split of the dataset named in the configuration cell.
dataset=load_dataset(dataset_name,split="train")
def preprocess(example):
    """Render one dataset row as a single chat-formatted training string.

    Reads the ``instruction``, ``input`` and ``output`` fields of ``example``
    (falling back to ``response`` when ``output`` is empty). A field that is
    missing or set to ``None`` is treated as empty text instead of raising.

    Args:
        example: Mapping with the row's fields.

    Returns:
        ``{"text": transcript}`` where ``transcript`` is the user turn followed
        by the model turn, or ``{"text": ""}`` (after printing a warning) when
        the instruction or the answer is empty.
    """

    def _field(key):
        # A key that exists with a null value makes example.get(key, "") return
        # None, so normalise explicitly before stripping.
        value = example.get(key)
        return "" if value is None else str(value).strip()

    instruction = _field("instruction")
    inp = _field("input")
    answer = _field("output") or _field("response")

    # Log if any fields are missing or empty
    if not instruction or not answer:
        print(f"Warning: Empty instruction or answer in example: {example}")
        return {"text": ""}  # Return empty text to avoid breaking the pipeline

    # The optional input goes on its own line after the instruction.
    user_message = f"{instruction}\n{inp}" if inp else instruction

    # Turn layout used throughout this notebook: each message is followed by a
    # newline and then <end_of_turn>; the inference prompts use the same layout.
    text = (
        "<start_of_turn>user\n"
        f"{user_message}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
        f"{answer}\n"
        "<end_of_turn>"
    )
    return {"text": text}

# Apply preprocessing, then drop the rows for which preprocess() returned an
# empty "text" (missing instruction or answer) so they are not trained on.
dataset = dataset.map(preprocess).filter(lambda example: bool(example["text"]))

In [None]:
# Sanity check: the mapped dataset should now expose a "text" column, and the
# first row shows what a rendered example looks like.
print("Columns after preprocessing:", dataset.column_names)
print("First preprocessed example:", dataset[0])

Columns after preprocessing: ['instruction', 'input', 'output', 'text']
First preprocessed example: {'instruction': "My daughter ( F, 18 y/o, 5'5', 165lbs) has been feeling poorly for a 6-8 months. She had COVID a couple of months ago and symptoms have are much worse in the last month or so. Symptoms seem POTS-like. She feels light headed, breathless, dizzy, HR goes from ~65 lying down to ~155-160 on standing. Today she tells me HR has been around 170 all day and she feels really lousy. (She using an OTC pulse ox to measure.) She has a cardiology appt but not until March and a PCP appt but not until April since she's at school and it's a new provider. What to do? Is this a on call nurse sort of issue? Or a trip to the ED? Or wait till tomorrow and try for an early appt? Try a couple of Valsalvas? Wait it out until her cardio appt? Or? She's away at school if Boston, what to do? Thank you", 'input': '', 'output': 'If she actually has a HR of 170 that is accurate, ongoing and persistent,

In [None]:
# Turn the dtype name from the configuration cell into the torch dtype object.
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

# Quantization settings passed to from_pretrained via `quantization_config`.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_use_double_quant=use_nested_quant,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
)

# if compute_dtype==torch.float16 and use_4bit:
#   major,_=torch.cuda.get_device_capability()
#   if major>=8:
#     print("-"*10)
#     print("Your GPU supports bfloat16: accelerate training with bf16=True")
#     print("-"*10)


# Load the base checkpoint with the quantization config built above; the Hub
# token comes from the environment variable set earlier in the notebook.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=os.environ['HF_TOKEN'],
    device_map=device_map,
    quantization_config=bnb_config,
)

# Config overrides applied before training.
model.config.pretraining_tp = 1
model.config.use_cache = False

# Tokenizer for the base checkpoint: pad on the right and reuse the EOS token
# as the padding token.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=os.environ['HF_TOKEN'],
)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapter settings, taken from the QLoRA section of the configuration cell.
peft_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    task_type="CAUSAL_LM",
    bias="none",
)

# Trainer hyper-parameters, all sourced from the configuration cell.
# `per_device_eval_batch_size` and `gradient_checkpointing` are declared there
# too, so they are forwarded here instead of being silently ignored.
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    per_device_eval_batch_size=per_device_eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=gradient_checkpointing,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type,
    report_to="tensorboard"
)
def formatting_prompts_func(example):
    """Return the already-rendered training string stored under ``example["text"]``."""
    rendered = example["text"]
    return rendered

# Supervised fine-tuning run: quantized base model + LoRA config, trained on the
# "text" column via formatting_prompts_func.
# NOTE(review): the `tokenizer` configured earlier in this cell is not passed to
# SFTTrainer here — confirm that relying on the trainer's own tokenizer is intended.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    peft_config=peft_config,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,
)

trainer.train()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Applying formatting function to train dataset:   0%|          | 0/6307 [00:00<?, ? examples/s]

Adding EOS to train dataset:   0%|          | 0/6307 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/6307 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/6307 [00:00<?, ? examples/s]

Step,Training Loss
25,3.9336
50,4.4367
75,3.4594
100,3.5643
125,3.0304
150,3.1119
175,2.9294
200,2.6451
225,2.9706
250,2.6056


In [None]:
# Save the trained model held by the trainer under the name in `new_model`;
# a later cell loads it back with PeftModel.from_pretrained(base_model, new_model).
trainer.model.save_pretrained(new_model)

In [None]:
# Ask PyTorch to release its cached CUDA memory before the model is reloaded.
torch.cuda.empty_cache()

In [None]:
import gc

# `model`, `trainer` and `tokenizer` are not deleted here: a later cell still
# calls `trainer.model.save_pretrained(...)`.
# Collect garbage first (two passes, as before), then release cached CUDA
# memory: emptying the cache only before collection cannot return memory that
# is still held by objects the collector has not freed yet.
gc.collect()
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Load the base checkpoint again, un-quantized in float16, so the LoRA weights
# can be folded into it.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=os.environ['HF_TOKEN'],
    device_map=device_map,
    torch_dtype=torch.float16,
    return_dict=True,
    low_cpu_mem_usage=True,
)

# Attach the adapter saved under `new_model`, then merge it into the base
# weights and drop the adapter wrapper.
model = PeftModel.from_pretrained(base_model, new_model).merge_and_unload()

# Fresh tokenizer with the same padding setup used for training.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    token=os.environ['HF_TOKEN'],
)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Quick smoke test of the merged model; silence library logging first.
logging.set_verbosity(logging.CRITICAL)

prompt = "What is meta"
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    max_length=100,
)

# Same turn layout as the training text, ending with an open model turn so the
# model writes the answer.
result = pipe(
    "<start_of_turn>user\n"
    f"{prompt}\n"
    "<end_of_turn>\n"
    "<start_of_turn>model\n"
)
print(result[0]['generated_text'])

In [None]:
from huggingface_hub import login

# Log in to the Hugging Face Hub with the token already stored in the HF_TOKEN
# environment variable; the token is passed explicitly rather than typed in.
login(token=os.environ['HF_TOKEN'])


In [None]:
# Save `trainer.model` (assumed to be the PEFT model) into the local folder
# "Gemma-2b-lora-adapter", which is the folder published to the Hub further down.
trainer.model.save_pretrained("Gemma-2b-lora-adapter")


In [None]:
from huggingface_hub import HfApi

# Name of the Hub model repository that will hold the adapter.
repo_name = "Gemma-2b-lora-medicalqa"

# Create the repository; exist_ok=True makes re-running this cell harmless.
api = HfApi()
api.create_repo(
    repo_id=repo_name,
    repo_type="model",
    exist_ok=True,
)


In [None]:
from huggingface_hub import HfApi

# Publish the saved adapter folder to the Hub over HTTP.
# The git-based `Repository(local_dir=..., clone_from=...)` helper is not used:
# it would have to clone into "Gemma-2b-lora-adapter", which already contains
# the files written by save_pretrained, and the class is deprecated in
# huggingface_hub.
api = HfApi()

# create_repo is idempotent with exist_ok=True and returns the namespaced
# repository id ("<user>/<name>"), which upload_folder needs.
repo_id = api.create_repo(repo_id=repo_name, exist_ok=True, repo_type="model").repo_id

api.upload_folder(
    folder_path="Gemma-2b-lora-adapter",
    repo_id=repo_id,
    repo_type="model",
    commit_message="Add Gemma LoRA adapter for medical QA",
)


In [None]:
# Quantization settings for the inference section (same values as the training
# configuration).
use_4bit = True
bnb_4bit_compute_dtype = 'float16'
bnb_4bit_quant_type = 'nf4'   # quantization type (nf4 or fp4)
use_nested_quant = False

# Resolve the dtype name to the torch dtype object.
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_use_double_quant=use_nested_quant,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
)


In [None]:
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Hub ids of the base checkpoint and of the published LoRA adapter.
base_model_name = "google/gemma-2b"
lora_model_name = "Nakshatra1729/Gemma-2b-lora-medicalqa"

# Quantized base model.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    token=os.environ['HF_TOKEN'],
)

# Base model wrapped with the adapter weights from the Hub.
model = PeftModel.from_pretrained(base_model, lora_model_name)




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Tokenizer of the base checkpoint, padded on the right with the EOS token.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    token=os.environ['HF_TOKEN'],
)
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

In [None]:
# Smoke test of the base model + Hub adapter; silence library logging first.
logging.set_verbosity(logging.CRITICAL)
prompt="What to do if i have stomach ache"
# max_new_tokens limits only the generated continuation. A total-length cap
# (max_length) also counts the chat-template prompt tokens, which for this
# prompt leaves little or no room for an answer.
pipe=pipeline(task="text-generation",model=model,tokenizer=tokenizer,max_new_tokens=20)
# Same turn layout as the training text, ending with an open model turn.
result=pipe(
        "<start_of_turn>user\n"
        f"{prompt}\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n")
print(result[0]['generated_text'])

KeyboardInterrupt: 