In [None]:
# Enable IPython's autoreload extension so edits to imported project modules
# are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import torch
import dotenv
from accelerate import PartialState
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
from absl import logging

from llm_ol.experiments.llm.finetune.training.utils import GenerateSamplesCallback
from llm_ol.experiments.llm.templates import (
    _MISTRAL_TEMPLATE,
    PROMPT_TEMPLATE,
    RESPONSE_TEMPLATE,
)

# Emit absl log messages at INFO level and above.
logging.set_verbosity(logging.INFO)
# Load variables from a dotenv file into the process environment
# (python-dotenv; NOTE(review): which variables are expected is not visible here).
dotenv.load_dotenv()


def dataset_from_file(
    data_file: str | Path, size: int | None = None, seed: int = 0
) -> Dataset:
    """Load a JSON dataset from disk and add a chat-style ``messages`` column.

    Args:
        data_file: Path to the JSON / JSON-lines file. Each record must provide
            the ``title``, ``abstract`` and ``paths`` fields.
        size: If given, the dataset is shuffled with ``seed`` and at most this
            many examples are kept. A value larger than the dataset is clamped
            to the dataset length (with a warning) instead of failing inside
            ``Dataset.select`` with an out-of-range index.
        seed: Seed for the shuffle performed when ``size`` is given.

    Returns:
        The (optionally subsampled) dataset with an added ``messages`` column.
        Each entry is a two-element list: a ``user`` message rendered from
        ``PROMPT_TEMPLATE`` and an ``assistant`` message rendered from
        ``RESPONSE_TEMPLATE``.
    """
    dataset = load_dataset("json", data_files=str(data_file), split="train")
    assert isinstance(dataset, Dataset)
    if size is not None:
        available = len(dataset)
        if size > available:
            # select(range(size)) would raise on indices past the end.
            logging.warning(
                "Requested %d examples from %s but only %d are available; "
                "using all of them.",
                size,
                data_file,
                available,
            )
            size = available
        dataset = dataset.shuffle(seed=seed).select(range(size))

    def make_messages(examples: dict[str, list]) -> dict[str, list]:
        """Render one user/assistant message pair per example of a batch."""
        return {
            "messages": [
                [
                    {
                        "role": "user",
                        "content": PROMPT_TEMPLATE.render(
                            title=title, abstract=abstract
                        ),
                    },
                    {
                        "role": "assistant",
                        "content": RESPONSE_TEMPLATE.render(paths=paths),
                    },
                ]
                for title, abstract, paths in zip(
                    examples["title"], examples["abstract"], examples["paths"]
                )
            ]
        }

    dataset = dataset.map(make_messages, batched=True, num_proc=16)
    return dataset

In [None]:
# Checkpoint directory the base model is loaded from (also used for the
# tokenizer in a later cell).
model_name = "out/experiments/finetune/v9/train/checkpoint-15000/merged"

# Index of the current process as reported by accelerate; used as the target
# device for the entire model via the "" key of `device_map`.
device_string = PartialState().process_index

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map={"": device_string},
    torch_dtype="auto",
    use_cache=False,
)

# Projection layers that receive LoRA adapters.
lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=lora_target_modules,
)

# Wrap the base model with the LoRA adapters and report how many parameters
# are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Tokenizer from the same checkpoint as the model, with the project's chat
# template and right-side padding.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.chat_template = _MISTRAL_TEMPLATE
tokenizer.padding_side = "right"

# Fall back to the unknown token for padding when no pad token is defined.
pad_token_missing = getattr(tokenizer, "pad_token", None) is None
if pad_token_missing:
    tokenizer.pad_token = tokenizer.unk_token

# Token-id sequences the collator uses to locate the response and instruction
# markers in a tokenized conversation.
# NOTE(review): presumably the ids of "[/INST]" and "[INST]" under the Mistral
# tokenizer -- confirm against the tokenizer actually loaded above.
response_template_ids = [733, 28748, 16289, 28793]
instruction_template_ids = [733, 16289, 28793]

collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template_ids,
    instruction_template=instruction_template_ids,
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
)

In [None]:
# Small subsamples of the train / eval files: 1024 and 128 examples.
train_dataset = dataset_from_file(
    "out/experiments/llm/arxiv/train_dataset.jsonl", 1024
)
eval_dataset = dataset_from_file("out/experiments/llm/arxiv/eval_dataset.jsonl", 128)

# NOTE(review): the second argument repeats the collator's response-template
# token ids; the meaning of the first argument (5) is defined by
# GenerateSamplesCallback -- confirm there.
sample_callback = GenerateSamplesCallback(5, [733, 28748, 16289, 28793])

# Mixed precision: bf16 when CUDA supports it, fp16 on other CUDA devices,
# neither when CUDA is unavailable.
cuda_available = torch.cuda.is_available()
use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
use_fp16 = cuda_available and not use_bf16

training_args = TrainingArguments(
    # Output
    output_dir="/tmp/llm",
    overwrite_output_dir=True,
    report_to=[],
    logging_steps=10,
    # Optimisation
    optim="adamw_torch_fused",
    learning_rate=1e-5,
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=0,
    num_train_epochs=2,
    gradient_accumulation_steps=1,
    per_device_train_batch_size=16,
    group_by_length=True,
    # Memory / precision
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    ddp_find_unused_parameters=False,
    fp16=use_fp16,
    bf16=use_bf16,
    # Evaluation
    evaluation_strategy="steps",
    eval_steps=20,
    per_device_eval_batch_size=32,
    # Reproducibility
    seed=0,
    data_seed=0,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_seq_length=2048,
    dataset_num_proc=16,
    dataset_kwargs={"add_special_tokens": False},
    callbacks=[sample_callback],
    args=training_args,
)

In [None]:
# Evaluate on the eval dataset before any training step has run
# (presumably as a baseline for the metrics reported during training).
trainer.evaluate()

In [None]:
# Run the fine-tuning loop configured on the trainer above.
trainer.train()