# Jsonfarmer

In [None]:
# Load the `dotenv` IPython extension and invoke its `%dotenv` magic
# (presumably reads variables from a local .env file -- confirm for your setup).
%load_ext dotenv
%dotenv
# Environment variables for this kernel session.
# NOTE(review): 0.0 presumably lifts PyTorch's MPS memory cap, and the second
# flag presumably enables hf_transfer-accelerated Hub downloads -- confirm
# against the PyTorch / huggingface_hub documentation.
%env PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
%env HF_HUB_ENABLE_HF_TRANSFER=1
# Shell out to the LitGPT CLI to download the base model referenced later
# in this notebook.
!litgpt download meta-llama/Llama-3.2-1B

In [None]:
from pathlib import Path
import torch
import torch.nn.functional as F
import lightning as L

import litgpt
from litgpt.data import JSON
from litgpt.lora import GPT, merge_lora_weights
from lightning.pytorch import Trainer, seed_everything

# Fix the global random seed so repeated runs of the notebook are comparable.
# NOTE(review): `workers=True` presumably also seeds dataloader workers --
# confirm against Lightning's `seed_everything` documentation.
seed_everything(42, workers=True)

In [None]:
class Jsonfarmer(L.LightningModule):
    """LightningModule that fine-tunes a LitGPT ``GPT`` model with LoRA.

    The model is built from its config name with LoRA adapters on the query
    and value projections (rank 4, alpha 8, dropout 0.05), and only the LoRA
    parameters are marked trainable. The pretrained weights are loaded from
    ``checkpoints/<model_name>/lit_model.pth`` when training starts.

    Args:
        model_name: Model identifier passed to ``GPT.from_name``; it is also
            used as the sub-directory of ``checkpoints/`` that holds the
            downloaded ``lit_model.pth``.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.model = GPT.from_name(
            name=model_name,
            lora_r=4,
            lora_alpha=8,
            lora_dropout=0.05,
            lora_query=True,
            lora_key=False,
            lora_value=True,
        )
        litgpt.lora.mark_only_lora_as_trainable(self.model)

    def on_train_start(self):
        """Load the pretrained weights into the LoRA-wrapped model."""
        state_dict = torch.load(
            f"checkpoints/{self.model_name}/lit_model.pth", mmap=True
        )
        # strict=False: the checkpoint is loaded non-strictly so that keys
        # present in the model but absent from the file (presumably the LoRA
        # adapter matrices) do not raise.
        self.model.load_state_dict(state_dict, strict=False)

    def _next_token_loss(self, batch):
        """Return the next-token cross-entropy loss for one batch.

        ``batch`` must provide ``"input_ids"`` and ``"labels"``.
        """
        input_ids, targets = batch["input_ids"], batch["labels"]
        logits = self.model(input_ids)
        # Shift by one: the logits at position t are scored against the
        # label at position t + 1.
        return litgpt.utils.chunked_cross_entropy(
            logits[..., :-1, :], targets[..., 1:]
        )

    def training_step(self, batch):
        """Compute and log the training loss for one batch."""
        loss = self._next_token_loss(batch)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch):
        """Compute and log the validation loss for one batch."""
        loss = self._next_token_loss(batch)
        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def configure_optimizers(self):
        """Return AdamW with a linear learning-rate warmup.

        The multiplier ramps linearly over the first ``warmup_steps``
        optimizer steps and is then held at 1.0. It is clamped because an
        unclamped ``step / warmup_steps`` keeps growing past the configured
        learning rate. The scheduler is registered with ``interval="step"``;
        Lightning otherwise advances it once per epoch, which would leave the
        multiplier at 0 for the whole first epoch.
        """
        warmup_steps = 10
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=2e-4, weight_decay=0.0, betas=(0.9, 0.95)
        )
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            # (step + 1) so the very first update already uses a non-zero LR.
            lambda step: min(1.0, (step + 1) / warmup_steps),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

In [None]:
# --- Data -------------------------------------------------------------------
model_name = "meta-llama/Llama-3.2-1B"
data = JSON(
    json_path=Path("dataset.json"),
    prompt_style="alpaca",
    val_split_fraction=0.2,
)
# Attach the tokenizer and batching options. No explicit `data.setup()` call
# is made here: it used to run before the tokenizer was connected, and
# `trainer.fit` invokes the datamodule's setup hook itself.
tokenizer = litgpt.Tokenizer(f"checkpoints/{model_name}")
data.connect(tokenizer, batch_size=1, max_seq_length=512)

# --- Trainer ----------------------------------------------------------------
trainer = Trainer(
    devices=1,
    max_epochs=10,
    # accumulate_grad_batches=8,
    precision="bf16-mixed",
    deterministic=True,
    log_every_n_steps=1,
)
# empty_init must stay False. With empty_init=True Lightning leaves parameters
# as uninitialized memory, and the pretrained checkpoint loaded in
# `Jsonfarmer.on_train_start` (strict=False) only fills the base weights, so
# the LoRA adapter matrices would never be initialized.
with trainer.init_module(empty_init=False):
    model = Jsonfarmer(model_name)

# --- Train, merge the adapters into the base weights, and save ---------------
trainer.fit(model, data)
merge_lora_weights(model.model)
trainer.save_checkpoint("checkpoints/jsonfarmer.ckpt", weights_only=True)