# Fine-tune the Gemma-3-12B-it model using the Axolotl framework


In [1]:
# Check if GPU is available
import torch

# Query CUDA availability once and only ask about BF16 when a device is
# present: on older PyTorch releases torch.cuda.is_bf16_supported() raises
# on machines without a usable CUDA device instead of returning False.
gpu_available = torch.cuda.is_available()
bf16_supported = gpu_available and torch.cuda.is_bf16_supported()
print('GPU available?', gpu_available)
print('BF16 is supported?', bf16_supported)

KeyboardInterrupt: 

In [None]:
# Show the value of the CUDA_HOME environment variable as seen by the notebook's shell
!printenv CUDA_HOME

In [None]:
# Model identifiers shared by the notebook cells.

MODEL_NAME = "google/gemma-3-12b-it"
# Short name: the part of the model id that follows the last "/"
MODEL_SHORT_NAME = MODEL_NAME.rpartition('/')[2]
SUFFIX = "FinGreyLit"
#SLICE = 1

In [None]:
# Load and prepare fine-tuning dataset

import json
import glob
import random

random.seed(42)  # for deterministic sampling of test set

# glob.glob() returns matches in a file-system-dependent order. Sort the file
# lists so that the records are always loaded in the same order; otherwise the
# seeded random operations would not be reproducible across machines.
train_files = sorted(glob.glob("../../llm-dataset/*-train.jsonl"))
test_files = sorted(glob.glob("../../llm-dataset/*-test.jsonl"))

EVAL_SIZE = 32  # how many documents to evaluate (i.e. calculate loss) on during fine-tuning
# Prompt texts used for the system and user turns of every conversation
SYSTEM_PROMPT = "You are a skilled librarian specialized in meticulous cataloguing of digital documents."
INSTRUCTION = "Extract metadata from this document. Return as JSON."

def preprocess_sample(sample):
    """Turn one dataset sample into a ShareGPT-style conversation record.

    The sample's "ground_truth" and "content" values are each serialized to
    a JSON string. The serialized content is appended to INSTRUCTION (after
    a blank line) to form the user turn; the serialized ground truth is the
    assistant turn. SYSTEM_PROMPT is the system turn.

    Returns a dict with a single "conversations" key holding the three
    messages, each a dict with 'from' and 'value' keys.
    """
    assistant_text = json.dumps(sample["ground_truth"])
    user_text = INSTRUCTION + "\n\n" + json.dumps(sample["content"])
    # ShareGPT format: (speaker, text) pairs in conversation order
    turns = (
        ('system', SYSTEM_PROMPT),
        ('user', user_text),
        ('assistant', assistant_text),
    )
    return {
        "conversations": [
            {'from': speaker, 'value': text} for speaker, text in turns
        ]
    }

def dataset_to_records(files):
    """Read JSON Lines files and convert every sample into a training record.

    Args:
        files: iterable of paths to JSONL files (one JSON object per line).

    Returns:
        A list containing the preprocess_sample() result for each sample,
        in file order and, within a file, in line order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    for filename in files:
        # Read as UTF-8 explicitly instead of relying on the platform's
        # default encoding, which can mis-decode non-ASCII text.
        with open(filename, encoding="utf-8") as infile:
            for line in infile:
                if not line.strip():
                    # Tolerate blank lines (e.g. an extra newline at the end
                    # of the file) instead of failing in json.loads().
                    continue
                records.append(preprocess_sample(json.loads(line)))
    return records

def write_jsonl(records, filename):
    """Write records to filename in JSON Lines format.

    Every record is serialized as one JSON document followed by a newline.
    The file is opened in write mode, so existing content is replaced.
    """
    with open(filename, "w") as sink:
        sink.writelines(json.dumps(entry) + "\n" for entry in records)

# Training set: all train records, written in shuffled order
train_recs = dataset_to_records(train_files)
random.shuffle(train_recs)
write_jsonl(train_recs, "axolotl-train.jsonl")
print(f"Wrote {len(train_recs)} train records")

# Test set: all test records, in the order they were read
test_recs = dataset_to_records(test_files)
write_jsonl(test_recs, "axolotl-test.jsonl")
print(f"Wrote {len(test_recs)} test records")

# Evaluation set: a random subset of the test records. The sample size is
# capped at the number of available records because random.sample() raises
# ValueError when asked for more items than the population contains.
eval_recs = random.sample(test_recs, min(EVAL_SIZE, len(test_recs)))
write_jsonl(eval_recs, "axolotl-eval.jsonl")
print(f"Wrote {len(eval_recs)} eval records")

In [None]:
# Generate the Axolotl configuration and save it as a YAML file

# The configuration file is named after the short model name
CONFIG_FILE = f"config-{MODEL_SHORT_NAME}.yml"


# YAML text of the configuration; the model name and the output directory are
# filled in from the notebook variables, and surrounding whitespace is stripped.
CONFIG = f"""
base_model: {MODEL_NAME}
model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: axolotl-train.jsonl
    type: chat_template
    ds_type: json
    split: train
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

test_datasets:
  - path: axolotl-eval.jsonl
    type: chat_template
    ds_type: json
    split: train
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

output_dir: ./out-{MODEL_SHORT_NAME}

chat_template: gemma3
eot_tokens:
  - <end_of_turn>

peft_use_dora: true
adapter: lora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true

sequence_len: 4096
sample_packing: true
eval_sample_packing: false
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 2
eval_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true  # true: saves VRAM but is slower to train
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 2
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
fsdp_config:


""".strip()

# Save the configuration, terminated by a single newline
with open(CONFIG_FILE, "w") as config_out:
    config_out.write(CONFIG + "\n")

In [None]:
%%time

# Start Axolotl training through accelerate, passing the configuration file named by CONFIG_FILE
!venv/bin/accelerate launch -m axolotl.cli.train {CONFIG_FILE}

# Merge the LoRA/DoRA adapter into the base model (for inference & quantization)

In [None]:
%%time

# Run Axolotl's merge-lora command with the configuration file named by CONFIG_FILE
# NOTE(review): the evaluation cell reads the model from out-<short name>/merged - confirm that is where this command writes
!venv/bin/axolotl merge-lora {CONFIG_FILE}

In [None]:
%%time
resultfile = f"../../eval/results-{MODEL_SHORT_NAME.replace('.','_')}.md"

# evaluate using the evaluate-model script, which needs venv with vLLM installed
!../dspy/venv/bin/python evaluate-model.py out-{MODEL_SHORT_NAME}/merged axolotl-test.jsonl {resultfile}
!cat {resultfile}