# Load the model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

In [None]:
# Quantize the base model's weights to 8 bits (LLM.int8) on load to reduce
# GPU memory use during LoRA fine-tuning.
#
# NF4 and nested ("double") quantization exist only for 4-bit loading in
# bitsandbytes: `load_in_4bit=True` together with `bnb_4bit_quant_type="nf4"`
# and `bnb_4bit_use_double_quant=True`. There are no `bnb_8bit_*`
# counterparts; such keywords are not BitsAndBytesConfig options and are only
# reported as unused, so they are not passed here.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
)

In [None]:
# Hugging Face Hub ID of the model to fine-tune, and the target device name.
checkpoint, device = "meta-llama/Llama-3.2-1B-Instruct", "cuda"

In [None]:

# Tokenizer for the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Pad on the right side of each sequence.
tokenizer.padding_side = "right"
# Reuse the end-of-sequence token as the padding token.
# NOTE(review): if the data collator masks pad-token positions out of the
# labels, every EOS is masked as well and the model may not learn to stop;
# consider a dedicated pad token — confirm against the collator in use.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the base model, quantized on load according to `bnb_config`;
# device_map="auto" lets the weights be placed on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    device_map="auto",
    quantization_config=bnb_config,
)

In [None]:
# Alternative: load the model without quantization.
# model = AutoModelForCausalLM.from_pretrained(checkpoint,device_map='auto')

# Load the data

In [None]:
import pickle as pkl

In [None]:
# Load the pickled GSM8K train/validation splits from disk.
# `with` closes each file handle as soon as it has been read.
# NOTE(review): pickle.load can execute arbitrary code — only load pickle
# files you created yourself or otherwise trust.
with open("Dataset/GSM8k_train.pkl", "rb") as f:
    train_data = pkl.load(f)
with open("Dataset/GSM8k_val.pkl", "rb") as f:
    val_data = pkl.load(f)

train_data, val_data

# Set up training infrastructure

In [None]:
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

In [None]:
# LoRA adapter hyperparameters.
rank_dimension = 6  # r: rank of the low-rank update matrices (commonly 4-32)
# Scaling numerator: LoRA updates are scaled by lora_alpha / r (here 8 / 6).
# A common heuristic is lora_alpha = 2 * r, which would be 12 for this rank.
lora_alpha = 8
lora_dropout = 0.05  # dropout probability applied inside the LoRA layers

peft_config = LoraConfig(
    r=rank_dimension,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    # "none": no bias parameters are trained, only the LoRA matrices.
    # ("all" or "lora_only" would also update the corresponding biases.)
    bias="none",
    # Adapt only the attention query/key/value projections; other linear
    # layers (e.g. attention output and MLP projections) stay frozen.
    target_modules=['q_proj', 'k_proj', 'v_proj'],
    # Causal language modelling (next-token prediction).
    task_type="CAUSAL_LM",
)

In [None]:
# Training arguments for supervised fine-tuning.
# Device placement (CUDA or Apple MPS) is chosen automatically by the
# trainer, so the deprecated `use_mps_device` flag is not needed.
sft_config = SFTConfig(
    output_dir="./TrainingCheckpoints",
    max_steps=1000,  # number of optimizer steps
    # Effective batch per optimizer step (single device): 1 x 8 = 8 samples.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    # Recompute activations during backward to trade compute for memory.
    gradient_checkpointing=True,
    learning_rate=5e-5,
    logging_steps=50,
    save_steps=60,
    eval_strategy="steps",
    # Evaluate at the logging cadence: 20 validation passes over 1000 steps.
    # A very small interval would make evaluation dominate the run time.
    eval_steps=50,
    report_to="none",
)

In [None]:
# Combine model, tokenizer, datasets, training arguments and LoRA config.
# Passing `peft_config` makes the trainer attach LoRA adapters to `model`.
# NOTE(review): newer TRL releases rename the `tokenizer` argument to
# `processing_class`; confirm against the installed TRL version.
trainer = SFTTrainer(
    args=sft_config,
    model=model,
    tokenizer=tokenizer,
    peft_config=peft_config,
    train_dataset=train_data,
    eval_dataset=val_data,
)

# Train the model

In [None]:
# Run fine-tuning for `max_steps` optimizer steps, logging, evaluating and
# checkpointing according to the SFTConfig settings.
trainer.train()
# To continue from the latest checkpoint in output_dir instead, use:
# trainer.train(resume_from_checkpoint=True)

In [None]:
# Display the trainer's state after training (e.g. global step, log history).
trainer.state

In [None]:
# Return cached-but-unused GPU memory held by PyTorch's caching allocator.
# Tensors that are still referenced (e.g. through `model` or `trainer`) are
# not freed by this call.
import torch
torch.cuda.empty_cache()

# Save the model

In [None]:
# Save the trained result to a directory named after the configured step
# count (currently "1000_steps"), so the name follows `max_steps` if changed.
# NOTE(review): with a PEFT-wrapped model this presumably writes only the
# LoRA adapter weights, not the full base model — confirm before reloading.
trainer.save_model(f"{sft_config.max_steps}_steps")