# SAT Tutor - Fine-tuned with Gemma 3N

This project fine-tunes **Gemma 3N** to create a subject-specific SAT tutoring assistant, capable of solving and explaining math and reasoning problems.

Special thanks to **Unsloth AI** for their excellent [Gemma 3N fine-tuning notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Gemma3N_(4B)-Conversational.ipynb), which served as the foundation for this workflow.


### Installation

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

In [None]:
%%capture
# Install the latest timm release, which Gemma 3N requires
!pip install --no-deps --upgrade timm # Only for Gemma 3N

### Unsloth

`FastModel` supports loading any Gemma model.

In [None]:
from unsloth import FastModel
import torch

fourbit_models = [
    # 4bit dynamic quants for superior accuracy and low memory use
    "unsloth/gemma-3n-E4B-it-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit",
    # Pretrained models
    "unsloth/gemma-3n-E4B-unsloth-bnb-4bit",
    "unsloth/gemma-3n-E2B-unsloth-bnb-4bit",

    # Other Gemma 3 quants
    "unsloth/gemma-3-1b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-4b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-12b-it-unsloth-bnb-4bit",
    "unsloth/gemma-3-27b-it-unsloth-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3n-E4B-it",
    dtype = None, # None for auto detection
    max_seq_length = 1024, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.1: Fast Gemma3N patching. Transformers: 4.54.0.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3N does not support SDPA - switching to eager!
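As a rough, back-of-the-envelope illustration of why `load_in_4bit = True` matters, the sketch below estimates raw weight storage at several precisions. The parameter count is the total reported in this notebook's training banner further down; the helper name `weight_memory_gib` is hypothetical. The estimate ignores quantization metadata, layers kept in higher precision, activations, and optimizer state, so actual GPU usage will differ.

```python
# Rough estimate only: counts raw weight storage and nothing else.
TOTAL_PARAMS = 7_869_188_432  # total parameters reported in the training banner


def weight_memory_gib(num_params: int, bits_per_param: float) -> float:
    """Return the storage needed for `num_params` weights, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3


for label, bits in [("16-bit", 16), ("8-bit", 8), ("4-bit", 4)]:
    print(f"{label}: {weight_memory_gib(TOTAL_PARAMS, bits):.2f} GiB")
```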



# Finetuning Gemma 3N!

We now add LoRA adapters so we only need to update a small number of parameters!

In [None]:
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Turn off for just text!
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Attention good for GRPO
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 8,           # Larger = higher accuracy, but might overfit
    lora_alpha = 8,  # Recommended alpha == r at least
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

Unsloth: Making `model.base_model.model.model.language_model` require gradients
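To make the effect of `r` and `lora_alpha` more concrete, here is a minimal sketch of the standard LoRA bookkeeping: each adapted weight matrix of shape `d_out x d_in` gains two small matrices, and the low-rank update is scaled by `lora_alpha / r`. The layer size of 2048 is a hypothetical value chosen for illustration, not a dimension read from Gemma 3N, and both helper names are made up for this example.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) next to a frozen d_out x d_in weight."""
    return r * d_in + d_out * r


def lora_scaling(lora_alpha: int, r: int) -> float:
    """Standard LoRA scales the low-rank update by lora_alpha / r."""
    return lora_alpha / r


d_in = d_out = 2048  # hypothetical layer size, for illustration only
full = d_in * d_out
for r in (8, 16, 32):
    added = lora_param_count(d_in, d_out, r)
    print(f"r={r}: {added:,} adapter params ({added / full:.2%} of the frozen matrix)")

print("scaling with r=8, lora_alpha=8:", lora_scaling(8, 8))
```

Doubling `r` doubles the adapter size, which is one way to read the "larger = higher accuracy, but might overfit" trade-off noted in the cell above.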


<a name="Data"></a>
### Data Prep
We now use the `Gemma-3` format for conversation-style finetunes. The training data combines the [GSM8K](https://huggingface.co/datasets/openai/gsm8k), [RACE](https://huggingface.co/datasets/ehovy/race), [OpenBookQA](https://huggingface.co/datasets/allenai/openbookqa), and [TruthfulQA](https://huggingface.co/datasets/truthfulqa/truthful_qa) datasets; each question/answer pair is later wrapped in a two-turn conversation. Gemma-3 renders multi-turn conversations like below:

```
<bos><start_of_turn>user
Hello!<end_of_turn>
<start_of_turn>model
Hey there!<end_of_turn>
```

We use our `get_chat_template` function to get the correct chat template.

In [None]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)
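For intuition, the rendering shown above can be approximated in a few lines of plain Python. This is a simplified sketch, not the template that `get_chat_template` installs: the real template also handles system prompts, multimodal content, and whitespace details. The function name `render_gemma3` is hypothetical.

```python
ROLE_NAMES = {"user": "user", "assistant": "model"}  # Gemma labels the assistant turn "model"


def render_gemma3(messages: list[dict], add_generation_prompt: bool = False) -> str:
    """Render text-only messages in the Gemma-3 turn format (simplified)."""
    text = "<bos>"
    for message in messages:
        role = ROLE_NAMES[message["role"]]
        text += f"<start_of_turn>{role}\n{message['content']}<end_of_turn>\n"
    if add_generation_prompt:
        # Open a model turn so generation continues from here.
        text += "<start_of_turn>model\n"
    return text


chat = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hey there!"},
]
print(render_gemma3(chat))
```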

Loading and merging all datasets.

In [None]:
from datasets import load_dataset, concatenate_datasets

# Load datasets
dataset1 = load_dataset("openai/gsm8k", "main", split="train")
dataset2 = load_dataset("ehovy/race", "all", split="train")
dataset3 = load_dataset("allenai/openbookqa", "main", split="train")
dataset4 = load_dataset("truthfulqa/truthful_qa", "generation", split="validation")

# Combine all datasets
dataset = concatenate_datasets([dataset1, dataset2, dataset3, dataset4])

README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/2.08M [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/37.4M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/2.05M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4934 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/87866 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/4887 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/496k [00:00<?, ?B/s]

main/validation-00000-of-00001.parquet:   0%|          | 0.00/58.2k [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/55.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/4957 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/500 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/500 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/223k [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/817 [00:00<?, ? examples/s]

We now use `standardize_data_formats` to try converting datasets to the correct format for finetuning purposes!

In [None]:
from unsloth.chat_templates import standardize_data_formats
dataset = standardize_data_formats(dataset)
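As I understand it, this helper targets conversation datasets such as ShareGPT-style ones, where each turn is stored as `{"from": ..., "value": ...}`, and rewrites them into `{"role": ..., "content": ...}` messages. The sketch below illustrates that kind of normalization in plain Python; it is not Unsloth's implementation, and `sharegpt_to_messages` is a hypothetical name. Note that the row printed later in this notebook has plain `question`/`answer` columns rather than a conversation list.

```python
SHAREGPT_ROLES = {"human": "user", "gpt": "assistant", "system": "system"}


def sharegpt_to_messages(conversation: list[dict]) -> list[dict]:
    """Convert ShareGPT-style turns into role/content messages."""
    return [
        {"role": SHAREGPT_ROLES[turn["from"]], "content": turn["value"]}
        for turn in conversation
    ]


sharegpt_example = [
    {"from": "human", "value": "What is 7 squared?"},
    {"from": "gpt", "value": "49"},
]
print(sharegpt_to_messages(sharegpt_example))
```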

In [None]:
from datasets import Dataset
from typing import Union

def preprocess_dataset(dataset: Union[Dataset, list], seed: int = 42):
    """
    Filters out bad rows (empty questions/answers), removes duplicates, and shuffles the dataset.
    Assumes 'question' and 'answer' fields exist.
    """

    # Step 1: Check and filter out empty question or answer
    print("Checking and filtering empty or invalid questions/answers...")
    dataset = dataset.filter(
        lambda x: isinstance(x["question"], str) and x["question"].strip() != "" and
                  isinstance(x["answer"], str) and x["answer"].strip() != ""
    )

    # Step 2: Deduplicate using string key of question + answer
    seen = set()
    def dedup(example):
        key = (example["question"].strip(), example["answer"].strip())
        if key in seen:
            return False
        seen.add(key)
        return True

    print("Removing duplicates...")
    dataset = dataset.filter(dedup)

    # Step 3: Shuffle
    print("Shuffling dataset...")
    dataset = dataset.shuffle(seed=seed)

    print(f"Final dataset size: {len(dataset)}")
    return dataset

dataset = preprocess_dataset(dataset)

🔍 Checking and filtering empty or invalid questions/answers...


Filter:   0%|          | 0/101113 [00:00<?, ? examples/s]

🔄 Removing duplicates...


Filter:   0%|          | 0/95339 [00:00<?, ? examples/s]

🔀 Shuffling dataset...
✅ Final dataset size: 83366
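One thing worth checking in the counts above: the first filter reduces 101,113 rows to 95,339, a drop of 5,774 rows, which equals the OpenBookQA train split (4,957) plus the TruthfulQA validation split (817) from the download output. That suggests those two datasets may be filtered out entirely, plausibly because they do not expose `question` and `answer` columns under those names. The sketch below shows one way to map every source to a common schema before concatenating. The column names (`question_stem`, `choices`, `answerKey`, `best_answer`, `options`, `article`) are assumptions based on the dataset cards, and `to_question_answer` is a hypothetical helper.

```python
def to_question_answer(row: dict, source: str) -> dict:
    """Map one raw row to the `question`/`answer` schema used by the rest of the notebook."""
    if source == "gsm8k":
        return {"question": row["question"], "answer": row["answer"]}
    if source == "race":
        # Include the passage and lettered options so the answer letter has context.
        options = "\n".join(f"{label}. {text}" for label, text in zip("ABCD", row["options"]))
        question = f"{row['article']}\n\n{row['question']}\n{options}"
        return {"question": question, "answer": row["answer"]}
    if source == "openbookqa":
        choices = row["choices"]
        options = "\n".join(
            f"{label}. {text}" for label, text in zip(choices["label"], choices["text"])
        )
        return {"question": f"{row['question_stem']}\n{options}", "answer": row["answerKey"]}
    if source == "truthful_qa":
        return {"question": row["question"], "answer": row["best_answer"]}
    raise ValueError(f"Unknown source: {source}")


# Made-up row in the assumed OpenBookQA layout.
example = {
    "question_stem": "Which object is attracted to a magnet?",
    "choices": {"label": ["A", "B"], "text": ["an iron nail", "a wooden spoon"]},
    "answerKey": "A",
}
print(to_question_answer(example, "openbookqa"))
```

With `datasets`, a call along the lines of `dataset3.map(lambda row: to_question_answer(row, "openbookqa"), remove_columns=dataset3.column_names)` before `concatenate_datasets` would apply this mapping.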


Let's see what row 100 looks like!

In [None]:
dataset[100]

{'question': 'If you suggest someone for the awards, you should   _  .',
 'answer': 'C',
 'example_id': 'high12584.txt',
 'article': 'Do you know a child who has used first aid to save a life or help an injured person?\nSt.John Ambulance is seeking young people who have acted quickly, calmly and effectively at a real emergency for its annual Young First Aider of the Year awards.\nThe awards are open to all those under 18, and the closing date for nomination   is April 30, 2016.The winners will be invited to attend a special ceremony in June, 2016.\n"St.John Ambulance believes it is essential for young people to learn first aid so that they can help anyone who is injured," said Sandra Stocker, director of St.John Ambulance Awards Committee."The Young First Aider of the Year is a wonderful way to celebrate their bravery and quick-thinking."\nNomination for the Young First Aider of the Year is now open.Please complete and return the nomination forms as soon as possible and certainly no la

We now apply the `Gemma-3` chat template to the conversations and save the result to `text`. We remove the `<bos>` token with `removeprefix('<bos>')` since we're finetuning: the processor adds this token before training, and the model expects only one.

In [None]:
def formatting_prompts_func(examples):
    texts = []
    for q, a in zip(examples["question"], examples["answer"]):
        # Ensure both question and answer are strings, defaulting to empty string if None
        q = str(q) if q is not None else ""
        a = str(a) if a is not None else ""

        chat = [
            {"role": "user", "content": q},
            {"role": "assistant", "content": a}
        ]
        try:
            text = tokenizer.apply_chat_template(
                chat,
                tokenize=False,
                add_generation_prompt=False
            ).removeprefix("<bos>")
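        except Exception as e:
            # Batched `map` needs one output per input row, so keep the row
            # with empty text instead of skipping it; filter these out afterwards.
            print(f"Formatting failed, writing empty text: {e}")
            text = ""
        texts.append(text)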

    return {"text": texts}

dataset = dataset.map(formatting_prompts_func, batched=True)

Map:   0%|          | 0/83366 [00:00<?, ? examples/s]

Let's see how the chat template did! Notice there is no `<bos>` token as the processor tokenizer will be adding one.

In [None]:
dataset[100]["text"]

'<start_of_turn>user\nIf you suggest someone for the awards, you should   _  .<end_of_turn>\n<start_of_turn>model\nC<end_of_turn>\n'

<a name="Train"></a>
### Train the model
Now let's use Hugging Face TRL's `SFTTrainer`! We train for 100 steps.

In [None]:
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # Can set up evaluation!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4, # Use GA to mimic batch size
        warmup_steps = 5,
        max_steps = 100,
        learning_rate = 2e-5,
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        report_to = "none",
    ),
)

Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/83366 [00:00<?, ? examples/s]
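To put these settings in perspective, the sketch below works out the effective batch size, how much of the dataset 100 steps actually covers, and what a linear schedule with warmup looks like. The schedule formula is the usual linear warmup followed by linear decay to zero; I am assuming that matches `lr_scheduler_type = "linear"` here. Both function names are hypothetical.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Examples contributing to each optimizer step."""
    return per_device * grad_accum * num_gpus


def linear_lr(step: int, base_lr: float, warmup_steps: int, max_steps: int) -> float:
    """Linear warmup from 0 to base_lr, then linear decay to 0 at max_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))


batch = effective_batch_size(per_device=1, grad_accum=4)
examples_seen = batch * 100  # max_steps = 100
print(f"Effective batch size: {batch}")
print(f"Examples seen: {examples_seen} of 83,366 ({examples_seen / 83_366:.2%})")

for step in (0, 5, 50, 100):
    print(f"step {step:>3}: lr = {linear_lr(step, 2e-5, 5, 100):.2e}")
```

With these values the run touches only a small fraction of the combined dataset, so `max_steps` is the main knob to raise for a longer run.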

We also use Unsloth's `train_on_responses_only` method to train only on the assistant outputs and ignore the loss on the user's inputs. This helps increase the accuracy of finetunes!

In [None]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

Map (num_proc=12):   0%|          | 0/83366 [00:00<?, ? examples/s]

Let's verify that the instruction part is masked. We print row 100 again; notice that the sample has only a single `<bos>`, as expected!

In [None]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

'<bos><start_of_turn>user\nIf you suggest someone for the awards, you should   _  .<end_of_turn>\n<start_of_turn>model\nC<end_of_turn>\n'

Now let's print the masked-out example. Only the answer should remain:

In [None]:
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

'                       C<end_of_turn>\n'
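The masking itself can be pictured with a small pure-Python sketch: every label up to and including the response marker is set to `-100`, the index that PyTorch's cross-entropy loss ignores by default. This is a simplified single-turn illustration with made-up token ids, not Unsloth's implementation, and `mask_before_response` is a hypothetical name.

```python
IGNORE_INDEX = -100  # label value skipped by the loss


def mask_before_response(input_ids: list[int], response_marker: list[int]) -> list[int]:
    """Return labels that keep only the tokens after the first response marker."""
    labels = [IGNORE_INDEX] * len(input_ids)
    n = len(response_marker)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_marker:
            # Copy the tokens that follow the marker; everything before stays masked.
            labels[start + n:] = input_ids[start + n:]
            break
    return labels


# Made-up ids: [bos, user-turn tokens..., response marker (20, 21), answer tokens...]
ids = [2, 10, 11, 12, 20, 21, 30, 31]
print(mask_before_response(ids, response_marker=[20, 21]))
```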

In [None]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA L4. Max memory = 22.161 GB.
9.305 GB of memory reserved.


# Let's train the model!

To resume a training run, call `trainer.train(resume_from_checkpoint = True)`.

In [None]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 83,366 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 4 x 1) = 4
 "-____-"     Trainable parameters = 19,210,240 of 7,869,188,432 (0.24% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


| Step | Training Loss |
|------|---------------|
| 1    | 18.5224       |
| 2    | 6.613         |
| 3    | 5.3722        |
| 4    | 4.2073        |
| 5    | 3.5166        |
| 6    | 2.1083        |
| 7    | 1.3548        |
| 8    | 3.2441        |
| 9    | 1.075         |
| 10   | 0.7298        |


Unsloth: Will smartly offload gradients to save VRAM!


### Large Losses During Fine-tuning

High initial training losses (e.g., 6–18) are expected when fine-tuning models like **Gemma 3N**, especially given its multimodal nature. The loss typically drops and stabilizes quickly, which matches this run: it fell from 18.52 at step 1 to 0.73 by step 10.

For more details, see [Unsloth's explanation](https://unsloth.ai/blog/gemma-3n#large-losses-during-finetuning).

In [None]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

684.6371 seconds used for training.
11.41 minutes used for training.
Peak reserved memory = 10.062 GB.
Peak reserved memory for training = 0.757 GB.
Peak reserved memory % of max memory = 45.404 %.
Peak reserved memory for training % of max memory = 3.416 %.
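These numbers also allow a rough projection of what a full pass over the data would cost. The sketch below reuses the figures reported in this notebook (684.6371 seconds for 100 steps, 83,366 examples, effective batch size 4) and assumes the per-step time stays constant, which is only approximately true because it varies with sequence length.

```python
import math

train_runtime_s = 684.6371  # reported by trainer_stats
max_steps = 100
batch = 4                   # 1 per device x 4 accumulation steps
num_examples = 83_366

seconds_per_step = train_runtime_s / max_steps
steps_per_epoch = math.ceil(num_examples / batch)
epoch_hours = steps_per_epoch * seconds_per_step / 3600

print(f"{seconds_per_step:.2f} s per optimizer step")
print(f"{steps_per_epoch:,} steps per epoch")
print(f"about {epoch_hours:.1f} hours for one full epoch at this rate")
```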


<a name="Inference"></a>
### Inference
Let's run the model via Unsloth native inference! According to the `Gemma-3` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`

In [None]:
from unsloth.chat_templates import get_chat_template

# Re-attach the correct chat template (same as used for training)
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)

# Format your inference prompt
messages = [
    {
        "role": "user",
        "content": [{
            "type": "text",
            "text": "A square has an area of 49 square centimeters. What is the perimeter of the square?",
        }]
    }
]

# Apply template for generation
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # Required for generation!
    tokenize=True,
    return_tensors="pt",
    return_dict=True,
).to("cuda")

# Generate output
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    temperature=1.0,
    top_p=0.95,
    top_k=64,
)

# Decode and print the model's response
tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

'user\nA square has an area of 49 square centimeters. What is the perimeter of the square?\nmodel\nLet $s$ be the side length of the square.\nThe area of the square is given by $A = s^2$.\nWe are given that the area of the square is 49 square centimeters.\nSo, $s^2 = 49$.\nTaking the square root of both sides, we get $s = \\sqrt{49} = 7$ centimeters.\nThe perimeter of a square is given by $P = 4s$.\nSince $s = 7$ centimeters, the perimeter is $P = 4(7) = 28$ centimeters.\n\nFinal Answer: The final answer is $\\boxed{28}$'
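To illustrate what `temperature`, `top_k`, and `top_p` do to the next-token distribution, here is a conceptual sketch in plain Python. It is not the `transformers` implementation: it scales the logits by the temperature, keeps the `top_k` most likely tokens, then keeps the smallest prefix whose cumulative probability reaches `top_p`, and renormalizes. The function name `filter_probs` and the toy logits are made up for the example.

```python
import math


def filter_probs(logits: list[float], temperature: float = 1.0,
                 top_k: int = 64, top_p: float = 0.95) -> dict[int, float]:
    """Return {token_index: probability} after temperature, top-k, and top-p filtering."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]

    # Top-k: keep the k most likely tokens, most likely first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in ranked:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}


toy_logits = [math.log(p) for p in (0.5, 0.3, 0.15, 0.05)]
print(filter_probs(toy_logits, temperature=1.0, top_k=3, top_p=0.7))
```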

You can also use a `TextStreamer` to stream the output, so you can watch the generation token by token instead of waiting for the full response!

In [None]:
# Define the prompt in Unsloth-compatible format
messages = [{
    "role": "user",
    "content": [{
        "type": "text",
        "text": "A triangle has angles measuring 35° and 75°. What is the measure of the third angle?",
    }]
}]

# Tokenize the prompt
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True,  # Appends <start_of_turn>model\n
    return_tensors = "pt",
    tokenize = True,
    return_dict = True,
).to("cuda")

from transformers import TextStreamer
# Generate and stream
_ = model.generate(
    **inputs,
    max_new_tokens = 256,
    temperature = 1.0,
    top_p = 0.95,
    top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

Let the three angles of the triangle be $A$, $B$, and $C$. We are given that two of the angles are $35^\circ$ and $75^\circ$. Let $A = 35^\circ$ and $B = 75^\circ$.
The sum of the angles in a triangle is $180^\circ$. Therefore, we have
$$A + B + C = 180^\circ$$
Substituting the given values, we get
$$35^\circ + 75^\circ + C = 180^\circ$$
$$110^\circ + C = 180^\circ$$
$$C = 180^\circ - 110^\circ$$
$$C = 70^\circ$$
The measure of the third angle is $70^\circ$.

Final Answer: The final answer is $\boxed{70}$<end_of_turn>


<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, use `save_pretrained` for a local save.

#### GGUF / llama.cpp Conversion
Unsloth now natively supports saving any model to `GGUF` for `llama.cpp`. We can convert to `Q8_0`, `F16`, or `BF16` precision.

In [None]:
model.save_pretrained_merged("sat-tutor-gemma3n-merged", tokenizer)

model.save_pretrained_gguf(
    "sat-tutor-gemma3n-merged",  # must be the merged folder
    quantization_type="Q8_0",    # or "F16"
)

Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...
Cache check failed: model-00001-of-00004.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.
Downloading safetensors index for unsloth/gemma-3n-e4b-it...


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Unsloth: Merging weights into 16bit:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.08G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  25%|██▌       | 1/4 [00:21<01:05, 21.86s/it]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.97G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  50%|█████     | 2/4 [00:57<00:59, 29.74s/it]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit:  75%|███████▌  | 3/4 [01:41<00:36, 36.50s/it]

model-00004-of-00004.safetensors:   0%|          | 0.00/2.66G [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit: 100%|██████████| 4/4 [02:07<00:00, 31.75s/it]


Unsloth: Updating system package directories
Unsloth: Install GGUF and other packages
Unsloth GGUF:hf-to-gguf:Loading model: sat-tutor-gemma3n-merged
Unsloth GGUF:hf-to-gguf:Model architecture: Gemma3nForConditionalGeneration
Unsloth GGUF:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
Unsloth GGUF:hf-to-gguf:Exporting model...
Unsloth GGUF:hf-to-gguf:gguf: loading model weight map from 'model.safetensors.index.json'
Unsloth GGUF:hf-to-gguf:gguf: loading model part 'model-00001-of-00004.safetensors'
Unsloth GGUF:hf-to-gguf:altup_proj.weight,                 torch.bfloat16 --> Q8_0, shape = {2048, 2048, 3}
Unsloth GGUF:hf-to-gguf:altup_unembd_proj.weight,          torch.bfloat16 --> Q8_0, shape = {2048, 2048, 3}
Unsloth GGUF:hf-to-gguf:token_embd.weight,                 torch.bfloat16 --> Q8_0, shape = {2048, 262144}
Unsloth GGUF:hf-to-gguf:gguf: loading model part 'model-00002-of-00004.safetensors'
Unsloth GGUF:hf-to-gguf:per_layer_token_embd.weight,       torch.bfloat1

Unsloth: GGUF conversion:   0%|          | 0/100 [00:00<?, ?it/s]

Unsloth GGUF:hf-to-gguf:Model successfully exported to ./
Unsloth: Converted to sat-tutor-gemma3n-merged.Q8_0.gguf with size = 7.3G
Unsloth: Successfully saved GGUF to:
sat-tutor-gemma3n-merged.Q8_0.gguf


['sat-tutor-gemma3n-merged.Q8_0.gguf']
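For a sense of what `Q8_0` costs per weight, the sketch below models the block layout as I understand it from llama.cpp: weights are stored in blocks of 32 signed 8-bit values plus one 16-bit scale, which works out to 8.5 bits per weight. Actual file sizes also depend on which tensors are exported and which are kept in higher precision, so treat this as a rough guide rather than a prediction of the 7.3G file above. The constant and function names are hypothetical.

```python
Q8_0_BLOCK_WEIGHTS = 32      # weights per block
Q8_0_BLOCK_BYTES = 2 + 32    # one fp16 scale + 32 int8 values


def q8_0_bytes(num_weights: int) -> int:
    """Bytes needed to store `num_weights` values in Q8_0 blocks."""
    blocks = -(-num_weights // Q8_0_BLOCK_WEIGHTS)  # ceiling division
    return blocks * Q8_0_BLOCK_BYTES


def q8_0_bits_per_weight() -> float:
    """Average storage cost per weight, including the per-block scale."""
    return Q8_0_BLOCK_BYTES * 8 / Q8_0_BLOCK_WEIGHTS


print(f"{q8_0_bits_per_weight()} bits per weight")
print(f"{q8_0_bytes(2048 * 2048):,} bytes for a hypothetical 2048 x 2048 matrix")
```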

In [None]:
# Save the model as a zip file and download
!zip sat-tutor-gemma3n.Q8_0.zip /content/sat-tutor-gemma3n-merged.Q8_0.gguf
from google.colab import files
files.download("sat-tutor-gemma3n.Q8_0.zip")

  adding: content/sat-tutor-gemma3n-merged.Q8_0.gguf (deflated 4%)



Now use the `sat-tutor-gemma3n-merged.Q8_0.gguf` file in llama.cpp or in a UI-based system such as Jan or Open WebUI.