## Install Libraries

In [None]:
# Install (or upgrade, -U) the libraries this notebook imports; -q keeps pip's output quiet.
# NOTE(review): versions are unpinned, so library APIs may differ between runs.
!pip install -q -U transformers datasets accelerate peft trl bitsandbytes

## Load Dataset

In [None]:
from datasets import load_dataset

dataset_name = "XenArcAI/MathX-5M"
# To speed up training, we'll just use a small fraction of the data:
# the "train[:1%]" split expression selects the first 1% of the train split.
# Remove the "[:1%]" slice to use the full dataset.
dataset = load_dataset(dataset_name, split="train[:1%]")

# Inspect the dataset: print the Dataset object, then its first record.
print(dataset)
print(dataset[0])

## Load the Model and Tokenizer

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# NOTE(review): the model id is inferred from the "llama3-8b-..." output paths
# used later in this notebook, and the Instruct variant is assumed because the
# preprocessing step calls tokenizer.apply_chat_template -- confirm before running.
# The repository is gated, so a Hugging Face login with access is required.
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# Load the base weights in 4-bit NF4 so the 8B model fits on a single GPU for
# LoRA training. The compute dtype is float16 to match fp16=True in the
# training arguments. The merge step later reloads the model unquantized.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Batched training needs a pad token; fall back to EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

## Preprocess data

In [None]:
def format_prompt(sample):
    """Render one dataset row through the tokenizer's chat template.

    Builds a two-turn conversation (a fixed system instruction followed by a
    user turn holding ``sample['text']``) and passes it to the notebook-level
    ``tokenizer.apply_chat_template`` with ``tokenize=False``.

    Args:
        sample: A dataset row; only its ``'text'`` key is read.
            NOTE(review): assumes the dataset exposes a 'text' column --
            verify against the printed sample.

    Returns:
        A dict with a single ``"text"`` key holding the rendered template.
    """
    system_turn = {"role": "system", "content": "You are a helpful assistant that solves math problems."}
    user_turn = {"role": "user", "content": sample['text']}
    rendered = tokenizer.apply_chat_template([system_turn, user_turn], tokenize=False)
    return {"text": rendered}


# Apply format_prompt to every row; the returned "text" value becomes the
# row's "text" column in the resulting dataset.
formatted_dataset = dataset.map(format_prompt)

## Training

In [None]:
from peft import LoraConfig
from trl import SFTConfig

# LoRA configuration: rank-16 adapters with alpha 32 and 5% dropout on the
# adapter path; bias terms are left untouched.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Training arguments. SFTConfig is TRL's TrainingArguments subclass; in current
# TRL the SFT-specific options (text column, maximum sequence length) are set
# here rather than passed to SFTTrainer as keyword arguments.
training_args = SFTConfig(
    output_dir="./llama3-8b-math-tuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # 4 x 4 = 16 sequences per optimizer step per device
    learning_rate=2e-4,
    logging_steps=10,
    max_steps=100,      # Increase this for a more thorough training
    save_steps=50,
    fp16=True,
    dataset_text_field="text",   # column written by format_prompt
    max_length=1024,             # truncate training sequences to 1024 tokens
)

In [None]:
from trl import SFTTrainer

# Current TRL releases no longer accept `tokenizer`, `dataset_text_field` or
# `max_seq_length` as SFTTrainer keyword arguments. The tokenizer is passed as
# `processing_class`; the text column and sequence-length limit are SFTConfig
# options whose defaults ("text", 1024) match the values this notebook uses.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_dataset,
    processing_class=tokenizer,
    peft_config=peft_config,   # SFTTrainer wraps the model with these LoRA adapters
)

In [None]:
# Run fine-tuning with the trainer configured above (max_steps bounds the run).
trainer.train()

## Export adapters

In [None]:
# Directory that receives the trained model files; the merge step below
# reads the adapters back from this same path.
adapter_path = "./llama3-8b-math-tuned-adapters"
trainer.save_model(adapter_path)

print(f"LoRA adapters saved to {adapter_path}")

## Merge with base model

In [None]:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

# torch and AutoModelForCausalLM are imported here so this cell does not
# depend on the import state left behind by earlier cells. It still reads
# model_name, tokenizer and adapter_path from the notebook session.

# --- Reload the base model without quantization ---
# This is important for merging and for Ollama compatibility
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# --- Load the PeftModel with the adapters ---
model = PeftModel.from_pretrained(base_model, adapter_path)

# --- Merge the weights and save the new model ---
# merge_and_unload folds the LoRA weights into the base layers and returns the
# plain model without the PEFT wrapper.
model = model.merge_and_unload()

merged_model_path = "./llama3-8b-math-merged"
model.save_pretrained(merged_model_path)
# Save the tokenizer alongside the weights so the directory is self-contained.
tokenizer.save_pretrained(merged_model_path)

print(f"Merged model saved to {merged_model_path}")

In [None]:
# Recursively (-r) archive the merged model directory into a single zip file.
!zip -r llama3-8b-math-merged.zip ./llama3-8b-math-merged