# Gemma Fine-tuning Prototype

This notebook demonstrates a basic workflow for fine-tuning a Gemma model with LoRA adapters on top of a 4-bit quantized base model (QLoRA-style).

## Setup

First, let's install the required dependencies:

In [1]:
!pip install -q transformers datasets accelerate peft bitsandbytes evaluate



## Load Dataset

We'll load a sample healthcare Q&A dataset:

In [2]:
import json
import pandas as pd
from datasets import Dataset

# Read the JSONL file: one JSON object per line.
data = []
# Explicit UTF-8 so the read does not depend on the platform's default encoding.
with open('../examples/healthcare_sample.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        # Skip blank lines (e.g. a stray empty line at the end of the file);
        # json.loads would raise on them.
        if line.strip():
            data.append(json.loads(line))

# Build a DataFrame with one row per parsed record.
df = pd.DataFrame(data)
print(f"Loaded {len(df)} examples")
# Last expression of the cell: preview the first rows.
df.head()

ModuleNotFoundError: No module named 'pandas'

## Preprocess Dataset

Now let's preprocess the dataset for fine-tuning:

In [None]:
from datasets import Dataset, DatasetDict
import numpy as np

# Wrap the DataFrame in a Hugging Face Dataset.
dataset = Dataset.from_pandas(df)

# Two-stage split with a fixed seed: hold out 20% as the test set, then carve
# 12.5% of the remaining 80% (roughly 10% of all rows) off as validation.
train_test = dataset.train_test_split(test_size=0.2, seed=42)
train_val = train_test["train"].train_test_split(test_size=0.125, seed=42)

# Collect the three splits under their conventional names.
splits = {
    "train": train_val["train"],
    "validation": train_val["test"],
    "test": train_test["test"],
}
dataset_dict = DatasetDict(splits)

# Report the size of each split.
for label, split_name in (("Train", "train"), ("Validation", "validation"), ("Test", "test")):
    print(f"{label}: {len(dataset_dict[split_name])} examples")

## Format Dataset for Instruction Tuning

Let's format our dataset for instruction tuning:

In [None]:
def format_instruction(example):
    """Render one record as a two-turn chat transcript.

    Args:
        example: Mapping with an 'input' entry (the user's question) and an
            'output' entry (the model's answer).

    Returns:
        A dict with a single 'text' key holding the user turn followed by the
        model turn, each wrapped in <start_of_turn>/<end_of_turn> markers.
    """
    user_turn = f"<start_of_turn>user\n{example['input']}\n<end_of_turn>"
    model_turn = f"<start_of_turn>model\n{example['output']}\n<end_of_turn>"
    # The two turns are separated by exactly one newline.
    return {"text": user_turn + "\n" + model_turn}

# Run format_instruction over the splits.
formatted_dataset = dataset_dict.map(format_instruction)

# Preview the formatted text of the first training record.
first_train_record = formatted_dataset["train"][0]
print(first_train_record["text"])

## Load Tokenizer and Model

Now let's load the Gemma tokenizer and model:

In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Access token used to download the checkpoint. Never paste the secret into the
# notebook: export it before starting the kernel, or enter it at the prompt.
# An already-exported value is left untouched instead of being overwritten.
# NOTE(review): the variable is named GOOGLE_API_KEY, but its value is passed as
# the `token=` argument of from_pretrained below, so presumably a Hugging Face
# access token is what is needed here - confirm. The name is kept because the
# final test cell also reads os.environ["GOOGLE_API_KEY"].
if not os.environ.get("GOOGLE_API_KEY"):
    from getpass import getpass
    os.environ["GOOGLE_API_KEY"] = getpass("Access token for the Gemma checkpoint: ")

# Load tokenizer
model_id = "google/gemma-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ["GOOGLE_API_KEY"])

# Fall back to EOS only when the tokenizer ships without a pad token. Aliasing
# pad to EOS unconditionally makes the two ids identical, and the
# language-modeling collator used later masks every pad-id label, so EOS
# positions would be dropped from the loss as well.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# 4-bit NF4 quantization with double quantization; matmuls computed in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True
)

# Load the quantized base model, letting device_map="auto" place the weights.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    token=os.environ["GOOGLE_API_KEY"]
)

## Tokenize Dataset

Let's tokenize our dataset:

In [None]:
def tokenize_function(examples):
    """Tokenize the 'text' entry of a batch with the notebook's tokenizer.

    Args:
        examples: Batch mapping whose 'text' entry holds the formatted strings.

    Returns:
        Whatever the module-level `tokenizer` returns for that text when asked
        to truncate and pad every sequence to 512 tokens.
    """
    # Fixed-length settings: truncate longer inputs, pad shorter ones.
    tokenizer_kwargs = {
        "truncation": True,
        "max_length": 512,
        "padding": "max_length",
    }
    return tokenizer(examples["text"], **tokenizer_kwargs)

# The raw string columns are dropped once the tokenizer output has been added.
raw_text_columns = ["text", "input", "output"]

# Tokenize in batches.
tokenized_dataset = formatted_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=raw_text_columns,
)

## Configure LoRA

Now let's set up LoRA for parameter-efficient fine-tuning:

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

# Ready the quantized base model for k-bit training.
model = prepare_model_for_kbit_training(model)

# LoRA hyperparameters, gathered in one place.
lora_settings = {
    "r": 16,                                 # adapter rank
    "lora_alpha": 32,                        # alpha parameter
    "lora_dropout": 0.1,                     # dropout probability
    "bias": "none",
    "task_type": TaskType.CAUSAL_LM,
    "target_modules": ["q_proj", "v_proj"],  # attention modules that receive adapters
}
lora_config = LoraConfig(**lora_settings)

# Wrap the model with the adapters and report the trainable parameter count.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

## Set Up Training Arguments

Let's configure the training arguments:

In [None]:
from transformers import TrainingArguments

from dataclasses import fields

# The evaluation-schedule argument was renamed from `evaluation_strategy` to
# `eval_strategy`, and recent transformers releases reject the old spelling.
# The install cell does not pin a version, so use whichever name this install
# of TrainingArguments declares.
_training_arg_names = {field.name for field in fields(TrainingArguments)}
_eval_strategy_key = (
    "eval_strategy" if "eval_strategy" in _training_arg_names else "evaluation_strategy"
)

# Set up training arguments
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    # Evaluate and checkpoint on the same 50-step cadence.
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    logging_steps=10,
    load_best_model_at_end=True,
    fp16=True,
    optim="paged_adamw_8bit",
    report_to="tensorboard",
    **{_eval_strategy_key: "steps"},
)

## Set Up Trainer

Now let's set up the trainer:

In [None]:
from transformers import Trainer, DataCollatorForLanguageModeling

# Collator for causal language modeling (masked-LM mode switched off).
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Splits used during training.
train_split = tokenized_dataset["train"]
validation_split = tokenized_dataset["validation"]

# Assemble the trainer from the model, arguments, data and collator.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=validation_split,
    data_collator=data_collator,
)

## Train the Model

Now let's train the model:

In [None]:
# Run fine-tuning with the trainer configured above. As the last expression in
# the cell, the value returned by train() is shown as the cell output.
trainer.train()

## Evaluate the Model

Let's evaluate the model on the test set:

In [None]:
# Score the held-out test split with the trained model.
test_split = tokenized_dataset["test"]
eval_results = trainer.evaluate(test_split)

# Report the loss to four decimal places.
test_loss = eval_results["eval_loss"]
print(f"Test loss: {test_loss:.4f}")

## Save the Model

Let's save the fine-tuned model:

In [None]:
# Write the model and the tokenizer side by side in one directory.
save_dir = "./healthcare-gemma-2b-it"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

## Test the Model

Let's test the fine-tuned model with some sample questions:

In [None]:
from transformers import pipeline

from peft import PeftModel

# Reload the quantized base model, then attach the adapter saved earlier.
# get_peft_model() expects a config object and creates fresh adapters, so it
# cannot restore a saved adapter from a directory path;
# PeftModel.from_pretrained() is the loader for that.
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    token=os.environ["GOOGLE_API_KEY"]
)
fine_tuned_model = PeftModel.from_pretrained(base_model, "./healthcare-gemma-2b-it")

# Create a text generation pipeline (sampling, capped at 512 tokens in total).
generator = pipeline(
    "text-generation",
    model=fine_tuned_model,
    tokenizer=tokenizer,
    max_length=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.9
)

# Test with a sample question, using the same turn markers as the training text
# and leaving the model turn open for generation.
question = "What are the early signs of a heart attack?"
prompt = f"<start_of_turn>user\n{question}\n<end_of_turn>\n<start_of_turn>model\n"

response = generator(prompt)[0]["generated_text"]
print(response)