In [None]:
import datasets
import transformers
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
from utils import load_untrained_llama2_model
# Project helper (utils) hands back the Llama-2 tokenizer and model as a pair;
# per its name the weights are presumably untrained — confirm in utils.
(tokenizer, model) = load_untrained_llama2_model()

In [None]:

# Only the zero-shot variant of the dataset is used for now. The directory is
# read with `datasets.load_from_disk`, i.e. it is expected to be a saved
# DatasetDict containing the "train" and "validation_prompted" splits used below.
ZERO_SHOT_DATASET_DIR = "zero_shot_dataset"
zero_shot_dataset_dict: datasets.DatasetDict = datasets.load_from_disk(ZERO_SHOT_DATASET_DIR)

def tokenize_function(examples):
    """Turn a batch of raw prompts into token ids for ``Dataset.map(batched=True)``.

    Args:
        examples: batch mapping whose ``"prompt"`` entry holds the prompt texts.

    Returns:
        ``{"input_ids": ...}`` holding the ids produced by the module-level
        ``tokenizer``. Prompts are not truncated, and every other tokenizer
        output (e.g. an attention mask) is dropped.
    """
    prompts = examples["prompt"]
    encoding = tokenizer(prompts, truncation=False)
    token_ids = encoding["input_ids"]
    return {"input_ids": token_ids}

def _tokenize_split(split: "datasets.Dataset") -> "datasets.Dataset":
    """Tokenize one split with ``tokenize_function`` and keep only ``input_ids``.

    Every original column is removed (not just ``"prompt"``). The LM collator
    used below is meant to pad token-id fields, so any other raw column left in
    the split (e.g. extra text fields) could break batching.
    NOTE(review): confirm no downstream step needs another raw column.
    """
    return split.map(
        tokenize_function, batched=True, remove_columns=split.column_names
    )


tokenized_train_dataset = _tokenize_split(zero_shot_dataset_dict["train"])
tokenized_eval_dataset = _tokenize_split(zero_shot_dataset_dict["validation_prompted"])
# LoRA adapter settings for a causal LM: rank 16, alpha 16, 10% adapter
# dropout, and no bias parameters trained (bias="none").
peft_params = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)
# The stock Llama-2 tokenizer defines no pad token. This collator must pad
# variable-length prompts into a batch, and tokenizer padding raises without one.
# Reuse EOS only if the loader has not already configured a pad token.
# NOTE(review): with mlm=False, pad-token positions are set to -100 in the
# labels, so any genuine EOS token inside a prompt would also be excluded from
# the loss. Llama-2 does not append EOS by default — confirm for this tokenizer.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Causal-LM collator: with mlm=False the labels are the inputs, with padding ignored.
collator = transformers.DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False
)
training_params = SFTConfig(
    # Where checkpoints go and how often they / logs are written.
    output_dir="output",
    save_steps=1000,
    save_total_limit=2,
    logging_steps=100,
    # Run length: max_steps=-1 leaves the schedule to num_train_epochs.
    num_train_epochs=50,
    max_steps=-1,
    # Batching: bucket samples of similar length to reduce padding.
    group_by_length=True,
    gradient_accumulation_steps=1,
    # Optimizer / schedule hyperparameters.
    learning_rate=5e-5,
    warmup_steps=10,
    weight_decay=0.01,
    adam_epsilon=1e-8,
    max_grad_norm=1.0,
)
# Passing peft_config makes SFTTrainer train LoRA adapters on top of `model`
# instead of updating all base weights.
trainer = SFTTrainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_eval_dataset,
    peft_config=peft_params,
    args=training_params,
    data_collator=collator,
)
trainer.train()

# `training_params` sets no evaluation strategy, so the Trainer default ("no")
# never evaluates `eval_dataset` during training. Run one pass over the
# validation split afterwards so it is actually used, and show the metrics.
# TODO: for periodic evaluation set eval_strategy/eval_steps in SFTConfig
# (older transformers releases call it `evaluation_strategy`).
eval_metrics = trainer.evaluate()
eval_metrics
