# Jsonfarmer

In [None]:
# Notebook environment setup. Order matters: change directory first, then load
# .env, then set env vars before anything downloads or imports torch.

# Move from the notebook's directory to its parent. Not idempotent: re-running
# this cell moves up one more level.
%cd ..
# Load variables from a .env file via the python-dotenv IPython extension
# (presumably the .env in the directory just entered — confirm).
%load_ext dotenv
%dotenv
# NOTE(review): a ratio of 0.0 disables PyTorch's MPS memory upper limit, which
# can let training exhaust system memory on macOS — confirm this is intended.
%env PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
# Ask huggingface_hub to use the hf_transfer download backend (assumes the
# hf_transfer package is installed).
%env HF_HUB_ENABLE_HF_TRANSFER=1
# Silence the debugger's frozen-module file-validation warning.
%env PYDEVD_DISABLE_FILE_VALIDATION=1
# Fetch the base model checkpoint; later cells read it from
# checkpoints/<model_name>. Keep this repo id in sync with `model_name`.
# !litgpt download meta-llama/Llama-3.2-1B
!litgpt download Qwen/Qwen2.5-0.5B-Instruct

In [None]:
from pathlib import Path

import torch

import litgpt
from litgpt.data import JSON
from litgpt.lora import merge_lora_weights
from lightning.pytorch import Trainer, seed_everything
from jsonfarmer.model import Jsonfarmer

# Base model to fine-tune. Also used to locate the downloaded checkpoint
# under checkpoints/<model_name>, so it must match the `litgpt download` repo id.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# Permit faster reduced-precision (e.g. TF32) kernels for float32 matmuls
# where the hardware supports them.
torch.set_float32_matmul_precision("high")

# Seed all RNGs (including dataloader workers) so runs are reproducible.
seed_everything(seed=42, workers=True)

In [None]:
# Instruction-tuning data in the "alpaca" prompt format; 10% held out for validation.
data = JSON(
    json_path=Path("jsonfarmer/dataset.json"),
    mask_prompt=True,  # do not train on the prompt tokens, only on the response
    prompt_style="alpaca",
    val_split_fraction=0.1,
)
tokenizer = litgpt.Tokenizer(f"checkpoints/{model_name}")
# Connect the tokenizer and batching limits BEFORE setup(): setup() constructs
# the train/val datasets from whatever tokenizer and max_seq_length the data
# module holds at that moment. Calling it first builds them without a tokenizer.
data.connect(tokenizer, batch_size=4, max_seq_length=512)
# Lightning also runs prepare_data()/setup() inside trainer.fit(). Calling them
# here only makes the datasets available for inspection before training.
data.prepare_data()
data.setup()

trainer = Trainer(
    devices=1,
    max_epochs=10,
    # NOTE(review): deterministic mode raises on ops that lack deterministic
    # kernels (more common on MPS) — confirm it works on the target device.
    deterministic=True,
    accumulate_grad_batches=3,  # effective batch size: 4 samples x 3 steps = 12
    log_every_n_steps=1,
    # fast_dev_run=True,
)
# empty_init=True skips parameter initialization inside this context.
# NOTE(review): only safe if Jsonfarmer overwrites every parameter with
# pretrained weights after construction — confirm in jsonfarmer.model.
with trainer.init_module(empty_init=True):
    model = Jsonfarmer(
        model_name=model_name,
        # NOTE(review): litgpt's HF-backed Tokenizer.vocab_size may exclude added
        # special tokens (e.g. Qwen chat/eos tokens); confirm it is large enough
        # for however Jsonfarmer uses num_tokens.
        num_tokens=tokenizer.vocab_size,
    )

trainer.fit(model, data)
# NOTE(review): if the model uses LoRA, the checkpoint below keeps the adapter
# weights unmerged; re-enable this line to merge them before saving.
# merge_lora_weights(model.model)
# weights_only=True stores only model weights (no optimizer/scheduler state),
# so this file can be used for inference but not to resume training.
trainer.save_checkpoint("checkpoints/jsonfarmer.ckpt", weights_only=True)