# Empowering Young Minds with Gemma

In [None]:
try:
    import accelerate
except Exception as e:
    !pip install -q -U transformers
    !pip install -q datasets accelerate
    !pip install -q lomo-optim

!pip install peft trl lomo-optim
!pip install bitsandbytes>=0.39.0
!pip install --upgrade accelerate transformers

In [None]:
# Log in to the Hugging Face Hub from inside the notebook.
# NOTE(review): presumably required because the Gemma checkpoint loaded
# further down is access-gated — confirm.
import huggingface_hub
huggingface_hub.notebook_login()

In [None]:
import json
from datasets import Dataset

from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, BitsAndBytesConfig, TrainingArguments
import transformers
import torch

from peft import LoraConfig, PeftModel
from trl import SFTTrainer

from tools.dataset import *

## Load and preprocess data

In [None]:
# Load the raw training examples from the JSON file at `data_path`.
# NOTE(review): `data_path` is not defined in the visible cells — presumably
# it comes from `from tools.dataset import *`; confirm.
# The encoding is given explicitly so decoding does not depend on the
# platform's locale default (JSON interchange is UTF-8).
with open(data_path, "r", encoding="utf-8") as f:
    data_dict = json.load(f)

In [None]:
# Convert the loaded JSON into the object that is later handed to SFTTrainer
# as `train_dataset`.
# NOTE(review): `generate_dataset` is not defined in the visible cells —
# presumably provided by `from tools.dataset import *`; confirm.
dataset = generate_dataset(data_dict)

## Define model

In [None]:
# LoRA adapter settings for causal-LM fine-tuning: rank 6, alpha 8, 5% dropout,
# applied to the modules named in `target_modules`.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=6,
    lora_alpha=8,
    lora_dropout=0.05,
    target_modules=[
        "q_proj",
        "o_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# 4-bit NF4 quantization settings for loading the base model.
# NOTE(review): the compute dtype here is float16, while the model is loaded
# with torch_dtype=torch.bfloat16 — confirm the mismatch is intentional.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [None]:
# Checkpoint to fine-tune; the tokenizer and the model are both loaded from it.
model_id = "google/gemma-2-2b-it"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load the base model with the quantization settings in `bnb_config`,
# letting `device_map="auto"` decide device placement.
# NOTE(review): torch_dtype is bfloat16 here while `bnb_config` uses a
# float16 compute dtype — confirm the mismatch is intentional.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # Optional, left disabled in this notebook: attn_implementation='eager'
)

## Fine-tuning

In [None]:
# Hyperparameters for the supervised fine-tuning run.
training_args = TrainingArguments(
    output_dir="outputs",
    # Run length is bounded by a step count; no epoch count is set.
    max_steps=1000,
    # 1 sequence per device x 4 accumulation steps = 4 sequences per
    # optimizer step on each device.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    optim="paged_adamw_8bit",
    fp16=True,
    logging_steps=100,
    # Keep the run local: no experiment tracker, no upload to the Hub.
    report_to="none",
    push_to_hub=False,
)

# Build the trainer: the quantized base model plus the LoRA adapter config,
# with each example rendered to text by `generate_chat_prompts`.
# NOTE(review): `max_seq_length` is passed directly to SFTTrainer; whether
# that keyword is accepted depends on the installed TRL version — confirm.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    formatting_func=generate_chat_prompts,
    peft_config=lora_config,
    max_seq_length=512,
)

In [None]:
# Run the fine-tuning loop on the SFTTrainer configured above.
trainer.train()