## Login to Hugging Face

In [1]:
from dotenv import load_dotenv
import os
from huggingface_hub import login

# Call load_dotenv() before reading the environment, then fetch the Hub token
# from HUGGINGFACE_TOKEN rather than writing it into the notebook.
load_dotenv()
token = os.environ.get("HUGGINGFACE_TOKEN")

# add_to_git_credential=True additionally hands the token to the configured
# git credential helper.
login(token=token, add_to_git_credential=True)

Token is valid (permission: write).
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /home/pathfinder/.cache/huggingface/token
Login successful


In [2]:
# Coordinates of the fine-tuned model: edit the two values below.
username = "PathFinderKR"  # ADD YOUR USERNAME HERE
model_name = "Waktaverse-Llama-3-KO-8B-Instruct"  # ADD YOUR MODEL NAME HERE

# Repository id in the form "<username>/<model_name>"
repo_id = "/".join((username, model_name))

## Login to Weights & Biases

In [3]:
import wandb

# Read the Weights & Biases API key from WANDB_API_KEY instead of hard-coding it.
api_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=api_key)

# Start a run; the W&B project is named after the model being trained.
wandb.init(project=model_name)

[34m[1mwandb[0m: Currently logged in as: [33mpathfinderkr[0m ([33mwaktaverse[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/pathfinder/.netrc


## Downloads

In [4]:
# Optional dependency setup: every install command below is commented out.
# Uncomment the lines you need when the packages are not installed yet.
# NOTE(review): when re-enabling, prefer `%pip install` over `!pip install` so the
# packages land in the running kernel's environment, and consider pinning versions.
#!pip install huggingface_hub
#!pip install wandb
#!pip install transformers
#!pip install bitsandbytes
#!pip install peft
#!pip install trl
#!pip install accelerate
#!pip install datasets
#!pip install scikit-learn
#!pip install packaging
#!pip install ninja
#!pip install flash-attn --no-build-isolation

## Imports

In [5]:
from IPython.display import display, Markdown

# pytorch
import torch

# huggingface
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments
)
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

# datasets
from datasets import load_dataset

## Device

In [6]:
# Pick the compute device in order of preference:
# first CUDA GPU -> Apple Silicon GPU (MPS) -> CPU.
if torch.cuda.is_available():
    device = "cuda:0"  # Nvidia GPU
elif torch.backends.mps.is_available():
    device = "mps"  # Apple Silicon GPU
else:
    device = "cpu"
print(f"Device = {device}")

Device = cuda:0


In [7]:
# Attention implementation and compute dtype selection
#
# Sets two names based on `device`:
#   attn_implementation -- "flash_attention_2" or "eager"
#   torch_dtype         -- bfloat16 / float16 / float32
#
# "flash_attention_2" is only selected when the GPU has compute capability >= 8
# AND the optional `flash_attn` package can be found. The install command for
# that package is commented out in the Downloads cell, so it may be missing;
# in that case fall back to "eager" here instead of failing at model load.
import importlib.util  # stdlib; used only to probe for the optional package

if device == "cuda:0":
    if torch.cuda.get_device_capability()[0] >= 8: # Ampere, Ada, or Hopper GPUs
        torch_dtype = torch.bfloat16
        # find_spec() locates the package without importing it
        if importlib.util.find_spec("flash_attn") is not None:
            attn_implementation = "flash_attention_2"
        else:
            attn_implementation = "eager"
            print("flash_attn is not installed; falling back to eager attention")
    else:
        # Older GPUs: no flash attention, half precision via float16
        attn_implementation = "eager"
        torch_dtype = torch.float16
else:
    # MPS / CPU: full precision with the default attention code path
    attn_implementation = "eager"
    torch_dtype = torch.float32
print(f"Attention Implementation = {attn_implementation}")

Attention Implementation = flash_attention_2


## Hyperparameters

In [8]:
################################################################################
# seed
################################################################################
seed = 42
torch.manual_seed(seed)  # seed PyTorch's random number generator

################################################################################
# Tokenizer parameters
################################################################################
max_length = 8192
padding = "do_not_pad"  # one of: "max_length", "longest", "do_not_pad"
truncation = True

################################################################################
# Generation parameters
################################################################################
num_return_sequences = 1
max_new_tokens = 512
do_sample = True  # True -> sampling, False -> greedy decoding
temperature = 0.6
top_k = 0  # not recommended
top_p = 0.9
repetition_penalty = 1.1

################################################################################
# bitsandbytes parameters
################################################################################
load_in_4bit = True
bnb_4bit_compute_dtype = torch_dtype  # dtype selected in the attention/dtype cell
bnb_4bit_quant_type = "nf4"  # one of: "nf4", "fp4"
bnb_4bit_use_double_quant = False

################################################################################
# LoRA parameters
################################################################################
task_type = "CAUSAL_LM"
# Projection layers that receive adapters
target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj",
]
r = 8
lora_alpha = 16
lora_dropout = 0.05
bias = "none"

################################################################################
# TrainingArguments parameters
################################################################################
# Output, checkpointing and logging
output_dir = "./results"
logging_dir = "./logs"
save_strategy = "epoch"
save_total_limit = 1
report_to = "wandb"
logging_strategy = "steps"  # one of: "steps", "epoch"
# A step interval is only set for step-based logging
logging_steps = 10 if logging_strategy == "steps" else None

# Optimisation schedule
num_train_epochs = 2
per_device_train_batch_size = 1
gradient_accumulation_steps = 1
gradient_checkpointing = True
learning_rate = 2e-5
lr_scheduler_type = "cosine"  # one of: "constant", "linear", "cosine"
warmup_ratio = 0.1
optim = "adamw_torch"
weight_decay = 0.01

################################################################################
# SFT parameters
################################################################################
max_seq_length = 1024
packing = True  # packing not supported for masking instructions

## Model

In [9]:
# Model List

# gemma variants
# "google/gemma-1.1-7b-it"
# "google/codegemma-7b-it"

# llama variants
# "meta-llama/Meta-Llama-3-8B" // downloaded
# "meta-llama/Meta-Llama-3-8B-Instruct" // downloaded
# "codellama/CodeLlama-7b-Instruct-hf"
# "PathFinderKR/Waktaverse-Llama-3-KO-8B-Instruct"

# mistral variants
# "mistralai/Mistral-7B-Instruct-v0.2"

# solar variants
# "upstage/SOLAR-10.7B-Instruct-v1.0" // downloaded
# "PathFinderKR/Waktaverse-SOLAR-KO-10.7B-Instruct"

In [10]:
# Model ID for base model
# One of the entries from the model list; the tokenizer is loaded from this ID.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

In [11]:
# Load the tokenizer that belongs to the base model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Register a dedicated padding token and pad on the right-hand side.
# NOTE(review): this call reports that special tokens were added to the
# vocabulary; presumably the model's embedding matrix has to be resized to
# match -- confirm in the model-loading code.
tokenizer.add_special_tokens(dict(pad_token="<|pad_id|>"))
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
# Show the tokenizer's string representation, wrapped in a Markdown code fence
display(Markdown("```" + str(tokenizer) + "```"))

```PreTrainedTokenizerFast(name_or_path='meta-llama/Meta-Llama-3-8B-Instruct', vocab_size=128000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|begin_of_text|>', 'eos_token': '<|end_of_text|>', 'pad_token': '<|pad_id|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	128000: AddedToken("<|begin_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128001: AddedToken("<|end_of_text|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128002: AddedToken("<|reserved_special_token_0|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128003: AddedToken("<|reserved_special_token_1|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128004: AddedToken("<|reserved_special_token_2|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128005: AddedToken("<|reserved_special_token_3|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128006: AddedToken("<|start_header_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128007: AddedToken("<|end_header_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128008: AddedToken("<|reserved_special_token_4|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128009: AddedToken("<|eot_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128010: AddedToken("<|reserved_special_token_5|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128011: AddedToken("<|reserved_special_token_6|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128012: AddedToken("<|reserved_special_token_7|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128013: AddedToken("<|reserved_special_token_8|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128014: AddedToken("<|reserved_special_token_9|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128015: AddedToken("<|reserved_special_token_10|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128016: AddedToken("<|reserved_special_token_11|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128017: AddedToken("<|reserved_special_token_12|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128018: AddedToken("<|reserved_special_token_13|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128019: AddedToken("<|reserved_special_token_14|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128020: AddedToken("<|reserved_special_token_15|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128021: AddedToken("<|reserved_special_token_16|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128022: AddedToken("<|reserved_special_token_17|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128023: AddedToken("<|reserved_special_token_18|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128024: AddedToken("<|reserved_special_token_19|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128025: AddedToken("<|reserved_special_token_20|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128026: AddedToken("<|reserved_special_token_21|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128027: AddedToken("<|reserved_special_token_22|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128028: AddedToken("<|reserved_special_token_23|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128029: AddedToken("<|reserved_special_token_24|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128030: AddedToken("<|reserved_special_token_25|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128031: AddedToken("<|reserved_special_token_26|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128032: AddedToken("<|reserved_special_token_27|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128033: AddedToken("<|reserved_special_token_28|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128034: AddedToken("<|reserved_special_token_29|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128035: AddedToken("<|reserved_special_token_30|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128036: AddedToken("<|reserved_special_token_31|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128037: AddedToken("<|reserved_special_token_32|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128038: AddedToken("<|reserved_special_token_33|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128039: AddedToken("<|reserved_special_token_34|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128040: AddedToken("<|reserved_special_token_35|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128041: AddedToken("<|reserved_special_token_36|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128042: AddedToken("<|reserved_special_token_37|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128043: AddedToken("<|reserved_special_token_38|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128044: AddedToken("<|reserved_special_token_39|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128045: AddedToken("<|reserved_special_token_40|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128046: AddedToken("<|reserved_special_token_41|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128047: AddedToken("<|reserved_special_token_42|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128048: AddedToken("<|reserved_special_token_43|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128049: AddedToken("<|reserved_special_token_44|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128050: AddedToken("<|reserved_special_token_45|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128051: AddedToken("<|reserved_special_token_46|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128052: AddedToken("<|reserved_special_token_47|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128053: AddedToken("<|reserved_special_token_48|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128054: AddedToken("<|reserved_special_token_49|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128055: AddedToken("<|reserved_special_token_50|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128056: AddedToken("<|reserved_special_token_51|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128057: AddedToken("<|reserved_special_token_52|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128058: AddedToken("<|reserved_special_token_53|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128059: AddedToken("<|reserved_special_token_54|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128060: AddedToken("<|reserved_special_token_55|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128061: AddedToken("<|reserved_special_token_56|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128062: AddedToken("<|reserved_special_token_57|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128063: AddedToken("<|reserved_special_token_58|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128064: AddedToken("<|reserved_special_token_59|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128065: AddedToken("<|reserved_special_token_60|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128066: AddedToken("<|reserved_special_token_61|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128067: AddedToken("<|reserved_special_token_62|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128068: AddedToken("<|reserved_special_token_63|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128069: AddedToken("<|reserved_special_token_64|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128070: AddedToken("<|reserved_special_token_65|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128071: AddedToken("<|reserved_special_token_66|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128072: AddedToken("<|reserved_special_token_67|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128073: AddedToken("<|reserved_special_token_68|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128074: AddedToken("<|reserved_special_token_69|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128075: AddedToken("<|reserved_special_token_70|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128076: AddedToken("<|reserved_special_token_71|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128077: AddedToken("<|reserved_special_token_72|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128078: AddedToken("<|reserved_special_token_73|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128079: AddedToken("<|reserved_special_token_74|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128080: AddedToken("<|reserved_special_token_75|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128081: AddedToken("<|reserved_special_token_76|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128082: AddedToken("<|reserved_special_token_77|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128083: AddedToken("<|reserved_special_token_78|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128084: AddedToken("<|reserved_special_token_79|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128085: AddedToken("<|reserved_special_token_80|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128086: AddedToken("<|reserved_special_token_81|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128087: AddedToken("<|reserved_special_token_82|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128088: AddedToken("<|reserved_special_token_83|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128089: AddedToken("<|reserved_special_token_84|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128090: AddedToken("<|reserved_special_token_85|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128091: AddedToken("<|reserved_special_token_86|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128092: AddedToken("<|reserved_special_token_87|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128093: AddedToken("<|reserved_special_token_88|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128094: AddedToken("<|reserved_special_token_89|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128095: AddedToken("<|reserved_special_token_90|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128096: AddedToken("<|reserved_special_token_91|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128097: AddedToken("<|reserved_special_token_92|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128098: AddedToken("<|reserved_special_token_93|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128099: AddedToken("<|reserved_special_token_94|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128100: AddedToken("<|reserved_special_token_95|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128101: AddedToken("<|reserved_special_token_96|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128102: AddedToken("<|reserved_special_token_97|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128103: AddedToken("<|reserved_special_token_98|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128104: AddedToken("<|reserved_special_token_99|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128105: AddedToken("<|reserved_special_token_100|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128106: AddedToken("<|reserved_special_token_101|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128107: AddedToken("<|reserved_special_token_102|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128108: AddedToken("<|reserved_special_token_103|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128109: AddedToken("<|reserved_special_token_104|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128110: AddedToken("<|reserved_special_token_105|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128111: AddedToken("<|reserved_special_token_106|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128112: AddedToken("<|reserved_special_token_107|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128113: AddedToken("<|reserved_special_token_108|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128114: AddedToken("<|reserved_special_token_109|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128115: AddedToken("<|reserved_special_token_110|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128116: AddedToken("<|reserved_special_token_111|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128117: AddedToken("<|reserved_special_token_112|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128118: AddedToken("<|reserved_special_token_113|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128119: AddedToken("<|reserved_special_token_114|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128120: AddedToken("<|reserved_special_token_115|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128121: AddedToken("<|reserved_special_token_116|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128122: AddedToken("<|reserved_special_token_117|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128123: AddedToken("<|reserved_special_token_118|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128124: AddedToken("<|reserved_special_token_119|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128125: AddedToken("<|reserved_special_token_120|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128126: AddedToken("<|reserved_special_token_121|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128127: AddedToken("<|reserved_special_token_122|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128128: AddedToken("<|reserved_special_token_123|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128129: AddedToken("<|reserved_special_token_124|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128130: AddedToken("<|reserved_special_token_125|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128131: AddedToken("<|reserved_special_token_126|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128132: AddedToken("<|reserved_special_token_127|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128133: AddedToken("<|reserved_special_token_128|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128134: AddedToken("<|reserved_special_token_129|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128135: AddedToken("<|reserved_special_token_130|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128136: AddedToken("<|reserved_special_token_131|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128137: AddedToken("<|reserved_special_token_132|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128138: AddedToken("<|reserved_special_token_133|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128139: AddedToken("<|reserved_special_token_134|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128140: AddedToken("<|reserved_special_token_135|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128141: AddedToken("<|reserved_special_token_136|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128142: AddedToken("<|reserved_special_token_137|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128143: AddedToken("<|reserved_special_token_138|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128144: AddedToken("<|reserved_special_token_139|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128145: AddedToken("<|reserved_special_token_140|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128146: AddedToken("<|reserved_special_token_141|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128147: AddedToken("<|reserved_special_token_142|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128148: AddedToken("<|reserved_special_token_143|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128149: AddedToken("<|reserved_special_token_144|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128150: AddedToken("<|reserved_special_token_145|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128151: AddedToken("<|reserved_special_token_146|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128152: AddedToken("<|reserved_special_token_147|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128153: AddedToken("<|reserved_special_token_148|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128154: AddedToken("<|reserved_special_token_149|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128155: AddedToken("<|reserved_special_token_150|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128156: AddedToken("<|reserved_special_token_151|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128157: AddedToken("<|reserved_special_token_152|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128158: AddedToken("<|reserved_special_token_153|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128159: AddedToken("<|reserved_special_token_154|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128160: AddedToken("<|reserved_special_token_155|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128161: AddedToken("<|reserved_special_token_156|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128162: AddedToken("<|reserved_special_token_157|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128163: AddedToken("<|reserved_special_token_158|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128164: AddedToken("<|reserved_special_token_159|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128165: AddedToken("<|reserved_special_token_160|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128166: AddedToken("<|reserved_special_token_161|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128167: AddedToken("<|reserved_special_token_162|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128168: AddedToken("<|reserved_special_token_163|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128169: AddedToken("<|reserved_special_token_164|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128170: AddedToken("<|reserved_special_token_165|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128171: AddedToken("<|reserved_special_token_166|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128172: AddedToken("<|reserved_special_token_167|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128173: AddedToken("<|reserved_special_token_168|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128174: AddedToken("<|reserved_special_token_169|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128175: AddedToken("<|reserved_special_token_170|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128176: AddedToken("<|reserved_special_token_171|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128177: AddedToken("<|reserved_special_token_172|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128178: AddedToken("<|reserved_special_token_173|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128179: AddedToken("<|reserved_special_token_174|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128180: AddedToken("<|reserved_special_token_175|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128181: AddedToken("<|reserved_special_token_176|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128182: AddedToken("<|reserved_special_token_177|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128183: AddedToken("<|reserved_special_token_178|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128184: AddedToken("<|reserved_special_token_179|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128185: AddedToken("<|reserved_special_token_180|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128186: AddedToken("<|reserved_special_token_181|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128187: AddedToken("<|reserved_special_token_182|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128188: AddedToken("<|reserved_special_token_183|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128189: AddedToken("<|reserved_special_token_184|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128190: AddedToken("<|reserved_special_token_185|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128191: AddedToken("<|reserved_special_token_186|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128192: AddedToken("<|reserved_special_token_187|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128193: AddedToken("<|reserved_special_token_188|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128194: AddedToken("<|reserved_special_token_189|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128195: AddedToken("<|reserved_special_token_190|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128196: AddedToken("<|reserved_special_token_191|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128197: AddedToken("<|reserved_special_token_192|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128198: AddedToken("<|reserved_special_token_193|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128199: AddedToken("<|reserved_special_token_194|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128200: AddedToken("<|reserved_special_token_195|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128201: AddedToken("<|reserved_special_token_196|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128202: AddedToken("<|reserved_special_token_197|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128203: AddedToken("<|reserved_special_token_198|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128204: AddedToken("<|reserved_special_token_199|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128205: AddedToken("<|reserved_special_token_200|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128206: AddedToken("<|reserved_special_token_201|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128207: AddedToken("<|reserved_special_token_202|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128208: AddedToken("<|reserved_special_token_203|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128209: AddedToken("<|reserved_special_token_204|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128210: AddedToken("<|reserved_special_token_205|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128211: AddedToken("<|reserved_special_token_206|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128212: AddedToken("<|reserved_special_token_207|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128213: AddedToken("<|reserved_special_token_208|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128214: AddedToken("<|reserved_special_token_209|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128215: AddedToken("<|reserved_special_token_210|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128216: AddedToken("<|reserved_special_token_211|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128217: AddedToken("<|reserved_special_token_212|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128218: AddedToken("<|reserved_special_token_213|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128219: AddedToken("<|reserved_special_token_214|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128220: AddedToken("<|reserved_special_token_215|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128221: AddedToken("<|reserved_special_token_216|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128222: AddedToken("<|reserved_special_token_217|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128223: AddedToken("<|reserved_special_token_218|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128224: AddedToken("<|reserved_special_token_219|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128225: AddedToken("<|reserved_special_token_220|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128226: AddedToken("<|reserved_special_token_221|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128227: AddedToken("<|reserved_special_token_222|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128228: AddedToken("<|reserved_special_token_223|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128229: AddedToken("<|reserved_special_token_224|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128230: AddedToken("<|reserved_special_token_225|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128231: AddedToken("<|reserved_special_token_226|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128232: AddedToken("<|reserved_special_token_227|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128233: AddedToken("<|reserved_special_token_228|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128234: AddedToken("<|reserved_special_token_229|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128235: AddedToken("<|reserved_special_token_230|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128236: AddedToken("<|reserved_special_token_231|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128237: AddedToken("<|reserved_special_token_232|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128238: AddedToken("<|reserved_special_token_233|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128239: AddedToken("<|reserved_special_token_234|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128240: AddedToken("<|reserved_special_token_235|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128241: AddedToken("<|reserved_special_token_236|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128242: AddedToken("<|reserved_special_token_237|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128243: AddedToken("<|reserved_special_token_238|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128244: AddedToken("<|reserved_special_token_239|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128245: AddedToken("<|reserved_special_token_240|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128246: AddedToken("<|reserved_special_token_241|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128247: AddedToken("<|reserved_special_token_242|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128248: AddedToken("<|reserved_special_token_243|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128249: AddedToken("<|reserved_special_token_244|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128250: AddedToken("<|reserved_special_token_245|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128251: AddedToken("<|reserved_special_token_246|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128252: AddedToken("<|reserved_special_token_247|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128253: AddedToken("<|reserved_special_token_248|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128254: AddedToken("<|reserved_special_token_249|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128255: AddedToken("<|reserved_special_token_250|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	128256: AddedToken("<|pad_id|>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}```

In [13]:
# Quantization
# Build the bitsandbytes configuration from the notebook-level 4-bit settings;
# it is handed to from_pretrained when the model is loaded.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=load_in_4bit,
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_use_double_quant=bnb_4bit_use_double_quant
)

In [14]:
# Load model
# Loads the checkpoint `model_id` onto `device`, applying the quantization
# config together with the notebook-level attention implementation and dtype.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    attn_implementation=attn_implementation,
    torch_dtype=torch_dtype,
    quantization_config=quantization_config
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [15]:
# Prepare model for kbit training
# PEFT helper applied to the quantized model before the LoRA adapters are
# attached in the fine-tuning section. Rebinds `model` to the returned object.
model = prepare_model_for_kbit_training(model)

In [16]:
# Display the model architecture
# Keep the fences on their own lines: with the repr glued to the opening fence,
# its first line is parsed as the fence's info string and the trailing fence
# is not recognised as a closing fence.
display(Markdown(f'```\n{model}\n```'))

```LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)```

## Dataset

In [17]:
# Dataset ID: identifier passed to load_dataset in the next cell
dataset_id = "MarkrAI/KoCommercial-Dataset"

In [18]:
# Load the dataset
# The result is indexed by split name below (only "train" is used in this notebook).
dataset = load_dataset(dataset_id)

In [19]:
# Dataset information
# Bare expression so the notebook renders the object's repr (splits, features, row counts).
dataset

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 175454
    })
})

In [20]:
# Dataset example: print the three fields of the first training row, one per line
print(
    *(dataset["train"][0][field] for field in ("instruction", "input", "output")),
    sep="\n"
)

보드 게임 스피너는 $A$, $B$, $C$로 표시된 세 부분으로 나뉩니다. 스피너가 $A$에 떨어질 확률은 $\frac{1}{3}$이고, 스피너가 $B$에 떨어질 확률은 $\frac{5}{12}$입니다.  스피너가 $C$에 착륙할 확률은 얼마입니까? 답을 공통 분수로 표현하세요.

모든 가능한 결과의 확률의 합이 1$이므로, 스피너가 $C$에 착륙할 확률을 구하려면 스피너가 $A$와 $B$에 착륙할 확률을 1$에서 빼야 합니다. 이를 방정식으로 쓸 수 있습니다: $P(C) = 1 - P(A) - P(B)$. P(A) = \frac{1}{3}$, $P(B) = \frac{5}{12}$라는 것을 알고 있으므로 이 값을 방정식에 대입하여 단순화할 수 있습니다. 결과는 다음과 같습니다: P(C) = 1 - \frac{1}{3} - frac{5}{12} = \frac{12}{12} - frac{4}{12} - frac{5}{12} = \frac{3}{12}$. 분자와 분모를 $3$로 나누면 이 분수를 줄일 수 있습니다: P(C) = \frac{1}{4}$입니다.


## Preprocessing

In [21]:
# Train on only a subset of the dataset for demonstration purposes
#dataset["train"] = dataset["train"].shuffle(seed=seed).select(range(100))

In [22]:
# Alpaca dataset format: 
# {"instruction": [str],
#   "input": [str],
#   "output": [str]}

# Korean
def prompt_without_input(example):
    """Render the Korean Alpaca-style prompt for an example without an input field.

    Args:
        example: Mapping providing the "instruction" and "output" keys.

    Returns:
        A single string: preamble, instruction section and response section,
        with the expected output appended after the response header.
    """
    preamble = "다음은 작업을 설명하는 지시사항입니다. 요청을 적절하게 완료하는 응답을 작성하십시오.\n\n"
    instruction_section = "### 지시사항:\n" + f"{example['instruction']}\n\n"
    response_section = "### 응답:\n" + f"{example['output']}"
    return preamble + instruction_section + response_section
    

def prompt_with_input(example):
    """Render the Korean Alpaca-style prompt for an example that has an input field.

    Args:
        example: Mapping providing the "instruction", "input" and "output" keys.

    Returns:
        A single string: preamble, instruction section, input section and
        response section, with the expected output appended after the
        response header.
    """
    sections = [
        "다음은 작업을 설명하는 지시사항과, 함께 쌍을 이루어 제공되는 입력입니다. 요청을 적절하게 완료하는 응답을 작성하십시오.\n\n",
        "### 지시사항:\n",
        f"{example['instruction']}\n\n",
        "### 입력:\n",
        f"{example['input']}\n\n",
        "### 응답:\n",
        f"{example['output']}",
    ]
    return "".join(sections)

In [23]:
def formatting_func(row):
    """Build the training prompt for one dataset row.

    Rows whose "input" is None or a blank (empty / whitespace-only) string use
    the instruction-only template; every other row uses the template that
    includes an input section.

    Args:
        row: Mapping with "instruction", "input" and "output" keys.

    Returns:
        The formatted prompt string.
    """
    extra_input = row["input"]
    # Treat None and whitespace-only strings like "" so the prompt never ends
    # up with an empty or literal "None" input section. Non-string values are
    # left to the with-input template, as before.
    if extra_input is None or (isinstance(extra_input, str) and not extra_input.strip()):
        return prompt_without_input(row)
    return prompt_with_input(row)

In [24]:
# formatting function example: render the first training row as a full prompt
print(formatting_func(dataset["train"][0]))

다음은 작업을 설명하는 지시사항입니다. 요청을 적절하게 완료하는 응답을 작성하십시오.

### 지시사항:
보드 게임 스피너는 $A$, $B$, $C$로 표시된 세 부분으로 나뉩니다. 스피너가 $A$에 떨어질 확률은 $\frac{1}{3}$이고, 스피너가 $B$에 떨어질 확률은 $\frac{5}{12}$입니다.  스피너가 $C$에 착륙할 확률은 얼마입니까? 답을 공통 분수로 표현하세요.

### 응답:
모든 가능한 결과의 확률의 합이 1$이므로, 스피너가 $C$에 착륙할 확률을 구하려면 스피너가 $A$와 $B$에 착륙할 확률을 1$에서 빼야 합니다. 이를 방정식으로 쓸 수 있습니다: $P(C) = 1 - P(A) - P(B)$. P(A) = \frac{1}{3}$, $P(B) = \frac{5}{12}$라는 것을 알고 있으므로 이 값을 방정식에 대입하여 단순화할 수 있습니다. 결과는 다음과 같습니다: P(C) = 1 - \frac{1}{3} - frac{5}{12} = \frac{12}{12} - frac{4}{12} - frac{5}{12} = \frac{3}{12}$. 분자와 분모를 $3$로 나누면 이 분수를 줄일 수 있습니다: P(C) = \frac{1}{4}$입니다.


## Inference before Fine-Tuning

In [25]:
def generate_response(system, user):
    """Generate a chat completion for one system/user message pair.

    Relies on notebook globals: `tokenizer`, `model`, `device` and the
    tokenization / sampling settings (`max_length`, `padding`, `truncation`,
    `max_new_tokens`, `temperature`, ...).

    Args:
        system: Content of the system message.
        user: Content of the user message.

    Returns:
        The first returned sequence decoded to text with special tokens kept.
    """
    messages = [
        {"role": "system", "content": system},
        {"role": "user", "content": user}
    ]
    # For inference the prompt must end with the assistant header, so the model
    # starts answering right away instead of having to emit the header itself.
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # The rendered template already carries its special tokens, so none are
    # added again here. Calling the tokenizer (instead of .encode) also returns
    # the attention mask.
    encoded = tokenizer(
        prompt,
        max_length=max_length,
        padding=padding,
        truncation=truncation,
        add_special_tokens=False,
        return_tensors="pt"
    ).to(device)

    outputs = model.generate(
        input_ids=encoded["input_ids"],
        # Explicit mask: it cannot be inferred reliably when the pad id passed
        # below is the same as the EOS id.
        attention_mask=encoded["attention_mask"],
        pad_token_id=tokenizer.eos_token_id,
        num_return_sequences=num_return_sequences,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty
    )

    return tokenizer.decode(outputs[0], skip_special_tokens=False)

In [26]:
# System prompt; the active Korean text means "Please write a response to the following instruction."
#system_prompt = "You are a helpful assistant. Respond to the following user prompt."
system_prompt = "다음 지시사항에 대한 응답을 작성해주세요."

In [27]:
# User prompt; the active Korean text means "Please write a poem about machine learning."
#user_prompt = "Write me a poem about Machine Learning."
user_prompt = "머신러닝에 대한 시를 써주세요."

In [28]:
# Baseline generation before any fine-tuning; the printed text keeps the special tokens.
response = generate_response(system_prompt, user_prompt)
print(response)

The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


<|begin_of_text|><|start_header_id|>system<|end_header_id|>

다음 지시사항에 대한 응답을 작성해주세요.<|eot_id|><|start_header_id|>user<|end_header_id|>

머신러닝에 대한 시를 써주세요.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Here's a poem about machine learning:

In silicon halls, where data reigns
A secret art is woven in its trains
Machine learning's subtle might
Unfolds the mysteries of day and night

With algorithms keen and bright
It seeks to grasp the hidden light
Of patterns, trends, and causality too
And from them, wisdom anew

Through neural networks' winding paths
It navigates the complexities of facts
By weights and biases, it makes its way
To predict the unknown, come what may

The models grow, like trees so tall
As data flows, and knowledge enthralls
Their accuracy, a beacon's call
Guiding us through life's intricate hall

Yet, amidst this digital sea
Lies a mystery, for you and me
How do we know, when we're astray?
Or if our models are leading us away?

The dance of chance and human han

## Supervised Fine-Tuning (LoRA)

In [29]:
# LoRA adapter configuration, filled from the notebook-level settings.
# It is used both by get_peft_model and by the SFTTrainer below.
lora_config = LoraConfig(
    task_type=task_type,
    target_modules=target_modules,
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias=bias
)

In [30]:
# Wrap the model with the LoRA adapters, then print the number of trainable parameters
# NOTE: this cell rebinds `model`; re-running it calls get_peft_model on the
# already-wrapped model, so run it once per kernel session.
model = get_peft_model(model, lora_config)
# Count only the parameters that will receive gradients.
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters (in billions): {num_params / 1e9:.2f}")
print(f"Percentage of trainable parameters: {num_params / model.num_parameters() * 100:.2f}%")

Number of trainable parameters (in billions): 0.02
Percentage of trainable parameters: 0.26%


In [31]:
# Training configuration; every value comes from a notebook-level setting.
training_args = TrainingArguments(
    # Output location, checkpointing and logging
    output_dir=output_dir,
    logging_dir=logging_dir,
    save_strategy=save_strategy,
    logging_strategy=logging_strategy,
    logging_steps=logging_steps,
    save_total_limit=save_total_limit,
    report_to=report_to,
    
    # Schedule, batching and optimizer
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=gradient_checkpointing,
    learning_rate=learning_rate,
    lr_scheduler_type=lr_scheduler_type,
    warmup_ratio=warmup_ratio,
    optim=optim,
    weight_decay=weight_decay,
    seed=seed
)

In [32]:
# Supervised fine-tuning trainer over the train split; formatting_func renders
# each example into the prompt text that is trained on.
# NOTE(review): `model` was already wrapped by get_peft_model in an earlier
# cell, so passing peft_config again looks redundant. Confirm how the
# installed TRL version treats an already-wrapped model.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    peft_config=lora_config,
    args=training_args,
    train_dataset=dataset["train"],
    formatting_func=formatting_func,
    max_seq_length=max_seq_length,
    packing=packing
)

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
# Run the fine-tuning loop with the trainer configured above.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,2.1212
20,2.4049
30,2.2869
40,2.1428
50,2.3271
60,2.2926
70,2.2387
80,2.2485
90,2.1493
100,2.1428


Step,Training Loss
10,2.1212
20,2.4049
30,2.2869
40,2.1428
50,2.3271
60,2.2926
70,2.2387
80,2.2485
90,2.1493
100,2.1428


In [None]:
# Close the W&B run, then save the trained model to the local path `model_name`.
# The upload section loads the adapter back from that same path.
wandb.finish()
trainer.save_model(model_name)

## Inference after Fine-Tuning

In [None]:
# Same system prompt as in the pre-fine-tuning section, so both generations are comparable.
#system_prompt = "You are a helpful assistant. Respond to the following user prompt."
system_prompt = "다음 지시사항에 대한 응답을 작성해주세요."

In [None]:
# Same user prompt as in the pre-fine-tuning section ("Please write a poem about machine learning.").
#user_prompt = "Write me a poem about Machine Learning."
user_prompt = "머신러닝에 대한 시를 써주세요."

In [None]:
# Generation with the fine-tuned model, for comparison with the baseline output above.
response = generate_response(system_prompt, user_prompt)
print(response)

## Upload Model

In [None]:
# Flush memory
import gc

# Drop the references first: collecting before `del` cannot reclaim objects
# that are still bound to notebook variables.
del trainer, model
gc.collect()
# Release the cached GPU blocks once the Python objects are gone.
torch.cuda.empty_cache()

In [None]:
# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch.float16
)
model = PeftModel.from_pretrained(base_model, model_name)
model = model.merge_and_unload()

In [None]:
# Reload tokenizer to save it
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [None]:
# Push model and tokenizer to Hugging Face Hub
model.push_to_hub(
    repo_id=repo_id,
    use_temp_dir=False
)
tokenizer.push_to_hub(
    repo_id=repo_id,
    use_temp_dir=False
)