# Jsonfarmer

In [None]:
# Load environment variables from a local .env file via the python-dotenv
# IPython extension (presumably credentials such as a Hugging Face token for
# gated repos like the Llama download below — confirm against .env).
%load_ext dotenv
%dotenv
# Apple-silicon (MPS) only: 0.0 disables the MPS allocator's high-watermark
# memory cap, letting PyTorch use more unified memory (risk: system-wide OOM).
%env PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0
# Use the hf_transfer backend for faster Hugging Face Hub downloads
# (requires the `hf_transfer` package to be installed).
%env HF_HUB_ENABLE_HF_TRANSFER=1
# Silence the pydevd debugger's frozen-modules file-validation warning.
%env PYDEVD_DISABLE_FILE_VALIDATION=1
# Alternative base model, kept for easy switching.
# !litgpt download meta-llama/Llama-3.2-1B
# Download the base model; later cells read the weights and tokenizer from
# checkpoints/Qwen/Qwen2.5-0.5B-Instruct.
!litgpt download Qwen/Qwen2.5-0.5B-Instruct

In [None]:
from pathlib import Path

import torch
import torch.nn.functional as F

import litgpt
from litgpt.data import JSON
from litgpt.lora import GPT, merge_lora_weights
import lightning as L
from lightning.pytorch import Trainer, seed_everything

# Hugging Face repo id of the base model. It doubles as the litgpt config name
# and as the sub-directory of ./checkpoints that later cells read from.
model_name = "Qwen/Qwen2.5-0.5B-Instruct"

# Permit faster, reduced-precision internal float32 matmuls (e.g. TF32)
# where the backend supports them.
torch.set_float32_matmul_precision("high")

# Seed every RNG Lightning manages (including dataloader workers) so runs
# are reproducible.
seed_everything(42, workers=True)

In [None]:
class Jsonfarmer(L.LightningModule):
    """Fine-tune a litgpt ``GPT`` on prompt/response data with next-token loss.

    The network is built from its litgpt config name. Pretrained weights are
    read from ``checkpoints/<model_name>/lit_model.pth``.

    Args:
        model_name: litgpt config name, also used as the checkpoint
            sub-directory (e.g. ``"Qwen/Qwen2.5-0.5B-Instruct"``).
        lr: Peak AdamW learning rate, reached at the end of warm-up.
        warmup_steps: Number of optimizer steps of linear LR warm-up.
    """

    def __init__(self, model_name: str, *, lr: float = 1e-3, warmup_steps: int = 100):
        super().__init__()
        self.model_name = model_name
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.model = GPT.from_name(
            name=model_name,
            # lora_r=4,
            # lora_alpha=8,
            # lora_dropout=0.05,
            # lora_query=True,
            # lora_key=False,
            # lora_value=True,
            # lora_projection=False,
            # lora_mlp=False,
            # lora_head=False,
        )
        # litgpt.lora.mark_only_lora_as_trainable(self.model)
        # True once the pretrained checkpoint has been copied into self.model.
        self._pretrained_loaded = False

    def setup(self, stage: str) -> None:
        """Copy the pretrained checkpoint into ``self.model`` (only once).

        Lightning calls ``setup`` before the sanity-check validation and
        before restoring a ``ckpt_path``. As a result:

        * the model is never evaluated on the uninitialized tensors left by
          ``empty_init=True``;
        * a resumed checkpoint overrides the base weights rather than the
          other way round.

        The flag prevents a later ``trainer.validate``/``test`` from
        reverting fine-tuned weights to the base checkpoint.
        """
        if self._pretrained_loaded:
            return
        state_dict = torch.load(
            f"checkpoints/{self.model_name}/lit_model.pth",
            mmap=True,
            weights_only=True,  # plain tensor state dict; never unpickle code
        )
        # strict=False: LoRA adapter parameters (when enabled) have no
        # counterpart in the base checkpoint.
        self.model.load_state_dict(state_dict, strict=False)
        self._pretrained_loaded = True

    def loop_step(self, batch):
        """Return the mean next-token cross-entropy for one batch.

        ``batch`` must provide ``input_ids`` and ``labels`` tensors. Label
        positions equal to -100 are ignored by the loss. Presumably these
        are masked prompt tokens and padding; -100 is assumed to match the
        data module's ``ignore_index``, so keep the two in sync.
        """
        input_ids, targets = batch["input_ids"], batch["labels"]
        logits = self.model(input_ids)
        # Shift so the logits at position t are scored against token t + 1.
        logits = logits[..., :-1, :]
        targets = targets[..., 1:]
        return F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1),
            ignore_index=-100,
        )

    def training_step(self, batch):
        """Compute, log and return the training loss."""
        loss = self.loop_step(batch)
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch):
        """Compute, log and return the validation loss."""
        loss = self.loop_step(batch)
        self.log(
            "val_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def configure_optimizers(self):
        """Build AdamW with a linear LR warm-up stepped every optimizer step.

        Only parameters with ``requires_grad`` are optimized, so freezing
        weights (e.g. LoRA-only training) is respected. After
        ``warmup_steps`` steps the LR stays at ``self.lr`` instead of
        growing without bound.
        """
        params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.AdamW(
            params, lr=self.lr,  # weight_decay=0.0, betas=(0.9, 0.95)
        )
        warmup_steps = max(1, self.warmup_steps)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            # (step + 1): the very first optimizer step already gets a
            # non-zero LR.
            optimizer, lambda step: min(1.0, (step + 1) / warmup_steps)
        )
        # Lightning's default scheduler interval is "epoch". A per-step
        # warm-up must request "step" explicitly.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

In [None]:
# Prompt/response records from dataset.json, formatted with the Alpaca prompt
# template. Prompt tokens are masked out of the loss, and 10% of records are
# held out for validation.
data = JSON(
    json_path=Path("dataset.json"),
    mask_prompt=True,
    prompt_style="alpaca",
    val_split_fraction=0.1,
)
tokenizer = litgpt.Tokenizer(f"checkpoints/{model_name}")
# Connect the tokenizer/batch size/sequence length BEFORE setup(). In litgpt,
# JSON.setup() builds its datasets from the values connect() stores, so
# calling setup() first produces datasets without a tokenizer.
data.connect(tokenizer, batch_size=1, max_seq_length=512)
data.prepare_data()
data.setup()

trainer = Trainer(
    devices=1,
    max_epochs=10,
    deterministic=True,
    log_every_n_steps=1,
)
# empty_init skips random weight initialization. The Jsonfarmer module loads
# the pretrained checkpoint into these tensors itself.
with trainer.init_module(empty_init=True):
    model = Jsonfarmer(model_name)

trainer.fit(model, data)
# merge_lora_weights(model.model)
trainer.save_checkpoint("checkpoints/jsonfarmer.ckpt", weights_only=True)