In [None]:
from datasets import load_dataset
from peft import LoraConfig
import torch
from trl import SFTTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    BitsAndBytesConfig,
)

In [None]:
# Important variables — the knobs most likely to be tuned between runs.
OUTPUT_DIR = "./outputs"   # passed to TrainingArguments.output_dir; checkpoints/metrics land here
MAX_SEQ_LENGTH = 2048      # maximum tokenized sequence length (set on the tokenizer below)

In [None]:
# Training hyperparameters, built directly as a TrainingArguments object.
training_config = TrainingArguments(
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=True,
    # optimizer and learning-rate schedule
    optim="paged_adamw_32bit",
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_steps=20,
    weight_decay=0.001,
    bf16=True,
    # epochs and batching
    num_train_epochs=2,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,  # effective train batch per device = 4 * 2
    # logging and checkpointing
    logging_strategy="steps",
    logging_steps=20,
    save_steps=50,
    save_total_limit=1,  # keep at most one checkpoint on disk
)

In [None]:
# LoRA adapter configuration, built directly as a LoraConfig object.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,               # adapter rank
    lora_alpha=16,     # scaling numerator (rank-stabilized below: alpha / sqrt(r))
    lora_dropout=0.1,
    bias="none",       # bias parameters are left untouched
    use_rslora=True,
    # Layers to adapt. These names target "TinyLlama/TinyLlama-1.1B-Chat-v1.0".
    # For "microsoft/phi-1_5" the original notes suggest target_modules=["qkv_proj"].
    # NOTE(review): verify module names against model.named_modules() before switching models.
    target_modules=["v_proj", "k_proj", "q_proj", "o_proj"],
)

In [None]:
# Base checkpoint to fine-tune. Only TinyLlama is active; the others are kept for reference.
# checkpoint_path = "microsoft/Phi-3-mini-4k-instruct"  # ships a chat_template
# checkpoint_path = "microsoft/phi-1_5"                 # does not ship a chat_template
checkpoint_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # ships a chat_template

# 8-bit quantization config.
# `bnb_8bit_compute_dtype` was removed: it is not a BitsAndBytesConfig option
# (only `bnb_4bit_compute_dtype` exists, for 4-bit loading). It was silently
# ignored with an "Unused kwargs" warning, so dropping it changes nothing.
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    # only relevant when a device_map places some modules on the CPU
    llm_int8_enable_fp32_cpu_offload=True,
)

# load model
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_path,
    use_cache=False,  # the KV cache is only useful for generation, not training
    # NOTE(review): executes code from the hub repo; TinyLlama does not need it
    # (kept for the commented-out Phi checkpoints).
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map=None,
    attn_implementation='eager',
)

In [None]:
# load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

# cap the sequence length used when tokenizing/truncating
tokenizer.model_max_length = MAX_SEQ_LENGTH

# Pad with the unk token rather than EOS to prevent endless generation.
# The default language-modeling collator excludes pad-id positions from the loss.
# With pad == eos, every EOS the chat template emits would also be masked, so
# the model never learns to stop. Fall back to EOS only when the tokenizer has
# no unk token distinct from EOS.
if tokenizer.unk_token is not None and tokenizer.unk_token != tokenizer.eos_token:
    tokenizer.pad_token = tokenizer.unk_token
else:
    tokenizer.pad_token = tokenizer.eos_token
# assigning pad_token also updates pad_token_id, so no separate id assignment is needed

# right padding for training; a later cell switches to left padding
tokenizer.padding_side = 'right'

In [None]:
# Load a small slice of UltraChat (train and test SFT splits) so the example runs quickly.
ultrachat_id = "HuggingFaceH4/ultrachat_200k"
train_dataset, test_dataset = (
    load_dataset(ultrachat_id, split=split_spec)
    for split_spec in ("train_sft[:5000]", "test_sft[:500]")
)

# original columns, all of which are removed once the chat template is applied
column_names = list(train_dataset.features)

In [None]:
# # add your custom chat template

# def apply_chat_template(messages, tokenizer):
#     prompt = ""
#     for m in messages["messages"]:
#         prompt+= f"{m['role']}: {m['content']}\n"
#     messages["text"] = prompt
#     return messages

In [None]:
# use the model's own chat template

def apply_chat_template(
    example,
    tokenizer,
):
    """Render a conversation into a single training string.

    Args:
        example: a dataset row with a "messages" field (list of role/content dicts).
        tokenizer: tokenizer whose chat template is used for rendering.

    Returns:
        The same ``example`` mapping, with the rendered string stored under "text".
    """
    rendered = tokenizer.apply_chat_template(
        example["messages"],
        add_generation_prompt=False,  # the assistant turns are already in the data
        tokenize=False,               # return a string, not token ids
    )
    example["text"] = rendered
    return example

In [None]:
# Apply the chat template to both splits. Every original column is removed,
# leaving only the rendered "text" column produced by apply_chat_template.

def _render_split(dataset, split_name):
    """Map apply_chat_template over `dataset`, dropping the original columns."""
    return dataset.map(
        apply_chat_template,
        fn_kwargs={"tokenizer": tokenizer},
        num_proc=10,
        remove_columns=column_names,
        desc=f"Applying chat template to {split_name}",
    )

processed_train_dataset = _render_split(train_dataset, "train_sft")
processed_test_dataset = _render_split(test_dataset, "test_sft")

In [None]:
# Initialize the SFTTrainer.
# The customized tokenizer (pad token, padding side) and MAX_SEQ_LENGTH are now
# passed explicitly. Previously they were never handed to the trainer, so TRL
# created its own tokenizer from the checkpoint and the setup above was unused.
# NOTE(review): these kwargs match the TRL releases whose SFTTrainer takes
# `TrainingArguments` plus `tokenizer`/`dataset_text_field`/`max_seq_length`.
# Newer TRL moves the last two into `SFTConfig` and renames `tokenizer` to
# `processing_class`; adjust if TRL is upgraded.
trainer = SFTTrainer(
    model=model,
    args=training_config,
    peft_config=lora_config,
    train_dataset=processed_train_dataset,
    eval_dataset=processed_test_dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",  # the column produced by apply_chat_template
    max_seq_length=MAX_SEQ_LENGTH,
)

In [None]:
# Run the fine-tuning loop from scratch. The returned object carries `.metrics` (used below).
train_result = trainer.train(resume_from_checkpoint=None)

In [None]:
# Report and persist the training summary, then save the trainer state.
metrics = train_result.metrics
for report in (trainer.log_metrics, trainer.save_metrics):
    report("train", metrics)
trainer.save_state()

In [None]:
# NOTE(review): evaluate() below only computes loss. Left padding is normally
# needed for batched generation, not loss, and it only matters if the trainer
# was given this tokenizer. Confirm this switch is intended.
tokenizer.padding_side = 'left'

# evaluation metrics plus the number of evaluated examples
metrics = {**trainer.evaluate(), "eval_samples": len(processed_test_dataset)}

trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [None]:
# Save the trained model and the tokenizer to the configured output directory.
# For a PEFT-wrapped model, save_model writes the adapter weights.
# (The previous `train_conf` name was undefined and raised NameError; the
# TrainingArguments object is `training_config`.)
trainer.save_model(training_config.output_dir)
# Persist the padding setup so it can be reloaded at inference time.
tokenizer.save_pretrained(training_config.output_dir)

In [None]:
# remove modules for memory from huggingface cache

# !rm -r ~/.cache/huggingface/modules/
# !rm -r ~/.cache/huggingface/datasets/