# Jsonfarmer

In [None]:
# Notebook environment setup (IPython magics and a shell command).
# Load the `dotenv` IPython extension and run it; presumably this reads
# variables from a local .env file -- confirm against python-dotenv.
%load_ext dotenv
%dotenv
# Per the PyTorch MPS documentation, 0.0 removes the high-watermark cap on
# MPS memory allocations (at the risk of system-wide OOM).
%env PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
# !HF_HUB_ENABLE_HF_TRANSFER=1 litgpt download google/gemma-3-1b-it --tokenizer_only true
# NOTE(review): `--tokenizer_only true` is passed, so presumably no model
# weights end up under checkpoints/ -- confirm this is intended.
!HF_HUB_ENABLE_HF_TRANSFER=1 litgpt download meta-llama/Llama-3.2-1B --tokenizer_only true

In [None]:
import torch
import litgpt
from litgpt.lora import GPT
import lightning as L


# Model identifier shared across the notebook: it is passed to
# GPT.from_name(...) and used to build the tokenizer path under checkpoints/.
# The commented-out line below is an alternative model that is disabled.
# model_name = "google/gemma-3-1b-it"
model_name = "meta-llama/Llama-3.2-1B"


class Jsonfarmer(L.LightningModule):
    """LightningModule wrapping a litgpt LoRA ``GPT`` model for fine-tuning.

    The network is created from the module-level ``model_name`` with LoRA
    (rank 8, alpha 32, dropout 0.05) enabled for query and value but not key.
    ``mark_only_lora_as_trainable`` is then applied so only the LoRA
    parameters receive gradients.
    """

    def __init__(self):
        super().__init__()
        self.model = GPT.from_name(
            name=model_name,
            lora_r=8,
            lora_alpha=32,
            lora_dropout=0.05,
            lora_query=True,
            lora_key=False,
            lora_value=True,
        )
        litgpt.lora.mark_only_lora_as_trainable(self.model)

    # NOTE(review): nothing in this class loads pretrained weights; the hook
    # below is disabled and GPT.from_name only receives a config name. Since
    # the notebook instantiates this module under
    # Trainer.init_module(empty_init=True), the frozen base weights are
    # presumably left uninitialised. Confirm whether a full checkpoint should
    # be downloaded and loaded before training.
    # def on_train_start(self):
    #     state_dict = torch.load("checkpoints/google/gemma-3-1b-it/", mmap=True)
    #     self.model.load_state_dict(state_dict, strict=False)

    def training_step(self, batch):
        """Compute and log the next-token cross-entropy loss for one batch.

        Args:
            batch: Mapping providing ``"input_ids"`` and ``"labels"``
                (assumed to be aligned token tensors -- TODO confirm shapes).

        Returns:
            The loss produced by ``litgpt.utils.chunked_cross_entropy``.
        """
        input_ids, targets = batch["input_ids"], batch["labels"]
        logits = self.model(input_ids)
        # Shift by one: the logits at position t are scored against label t+1.
        loss = litgpt.utils.chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:])
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Create AdamW with a linear warmup that is stepped per optimizer step.

        Returns:
            A ``([optimizer], [scheduler_config])`` pair in the two-list form
            accepted by Lightning.
        """
        warmup_steps = 10
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=2e-4, weight_decay=0.0, betas=(0.9, 0.95)
        )
        # Ramp linearly up to the base LR, then hold it. Without the clamp the
        # multiplier keeps growing past 1.0; starting at (step + 1) avoids a
        # first update performed with a learning rate of exactly zero.
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda step: min(1.0, (step + 1) / warmup_steps)
        )
        # Lightning steps a bare scheduler once per *epoch*; a warmup measured
        # in steps has to be advanced on every optimizer step instead.
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

: 

In [None]:
from litgpt.data import Alpaca2k
from litgpt.lora import merge_lora_weights
from lightning.pytorch import Trainer, seed_everything


# Seed everything first (workers=True is forwarded to Lightning's seeding
# helper) so the run below starts from a fixed random state.
seed_everything(42, workers=True)

# Tokenizer comes from the local checkpoint directory of the chosen model;
# it is attached to the Alpaca2k data module together with the batching limits.
tokenizer = litgpt.Tokenizer(f"checkpoints/{model_name}")
data = Alpaca2k()
data.connect(tokenizer, batch_size=8, max_seq_length=512)

# Trainer settings gathered in one place: a single device, two epochs,
# gradient accumulation over 8 batches, bf16-true precision, deterministic mode.
trainer_kwargs = {
    "devices": 1,
    "max_epochs": 2,
    "accumulate_grad_batches": 8,
    "precision": "bf16-true",
    "deterministic": True,
}
trainer = Trainer(**trainer_kwargs)

# Build the module inside the trainer's init context with empty_init=True.
with trainer.init_module(empty_init=True):
    model = Jsonfarmer()

# Train, merge the LoRA weights into the wrapped model, then write a
# weights-only checkpoint.
trainer.fit(model, data)
merge_lora_weights(model.model)
trainer.save_checkpoint(filepath="checkpoints/jsonfarmer.ckpt", weights_only=True)
