In [11]:
#!pip install trl

In [12]:
import os
import warnings

# Credentials must never be hardcoded in a notebook: they end up in version
# control, shared copies and rendered outputs. (A token previously written
# here should be treated as leaked and revoked on huggingface.co.)
# Export HF_TOKEN in the shell before starting Jupyter, or run
# `huggingface-cli login` so the hub client can use its cached token.
if not os.environ.get("HF_TOKEN"):
    warnings.warn(
        "HF_TOKEN is not set; gated models such as meta-llama/* will only load "
        "if a token is cached via `huggingface-cli login`."
    )

In [7]:
import warnings

import torch
from datasets import Dataset, load_dataset, load_from_disk
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    TrainingArguments,
)
from trl import SFTTrainer

In [8]:
warnings.filterwarnings("ignore")

In [13]:
# Base checkpoint that the tokenizer and model cells below load.
model_name = "meta-llama/Llama-3.2-3B-Instruct"
# Token budget per example: used as the tokenizer truncation length and
# passed to the trainer as max_seq_length.
max_seq_length = 8192
# Passed to from_pretrained; "auto" lets Transformers/Accelerate choose placement.
device_map = "auto"


In [14]:
# Fast tokenizer for the base checkpoint. Right padding is used for training.
# The token is read from the environment and never stored in the notebook.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    padding_side="right",
    token=os.getenv("HF_TOKEN"),
)

# Batched training requires a pad token, and Llama 3 tokenizers may not define one.
# Prefer the reserved Llama 3.1+/3.2 padding token when the vocabulary contains
# it, so padding never shares an id with EOS. Otherwise fall back to EOS.
if tokenizer.pad_token is None:
    llama3_pad_token = "<|finetune_right_pad_id|>"
    tokenizer.pad_token = (
        llama3_pad_token
        if llama3_pad_token in tokenizer.get_vocab()
        else tokenizer.eos_token
    )

In [15]:
# Load the base weights in bfloat16, placed according to `device_map`.
# attn_implementation="flash_attention_2" needs the flash-attn package and a
# supported GPU. Auth comes from the HF_TOKEN environment variable.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    token=os.getenv("HF_TOKEN"),
    device_map=device_map,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)

Downloading shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [03:50<00:00, 115.17s/it]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:01<00:00,  1.45it/s]


In [20]:
# LoRA adapters on every attention projection (q/k/v/o) and MLP projection
# (gate/up/down) listed below. With lora_alpha == r the adapter scaling
# factor is 1. rsLoRA (alpha / sqrt(r) scaling) stays disabled; set
# use_rslora=True to experiment with it.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    use_rslora=False,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "o_proj",
    ],
)

In [21]:
# Wrap the base model with the LoRA adapters defined in peft_config.
model = get_peft_model(model=model, peft_config=peft_config)

# Needed: training_args enables gradient_checkpointing. The base weights are
# frozen, so without this the embedding outputs do not require grad. With
# (reentrant) checkpointing the LoRA weights inside checkpointed blocks would
# then receive no gradients.
model.enable_input_require_grads()


In [None]:
# ---- Data ----
# `datasets.load_from_disk` returns whichever object was saved (Dataset or
# DatasetDict). `Dataset.load_from_disk` only accepts a single saved Dataset,
# and indexing a Dataset with "train" would look up a *column*, not a split.
# Assumes the directory holds a DatasetDict with a "train" split. TODO confirm.
DATASET_DIR = "/root/data/dataset1"
data = load_from_disk(DATASET_DIR)
train_dataset = data["train"]

def formatting_prompt(examples, tok=None, max_length=None):
    """Tokenize a batch of question/answer pairs for completion-only SFT.

    Each pair is encoded as ``question + bos_token + answer + eos_token``
    (plus any special tokens the tokenizer adds itself) and truncated to
    ``max_length`` tokens. Label positions covering the prompt are set to
    -100, so only the separator, answer and EOS tokens contribute to the loss.

    Args:
        examples: Batched ``datasets`` row dict with parallel lists under
            "question" and "answer".
        tok: Tokenizer to use. Defaults to the notebook-level ``tokenizer``.
        max_length: Truncation length. Defaults to ``max_seq_length``.

    Returns:
        Dict with "input_ids" and "labels": one token-id list per pair.
    """
    tok = tokenizer if tok is None else tok
    max_length = max_seq_length if max_length is None else max_length

    input_ids_list = []
    labels_list = []
    for question, answer in zip(examples["question"], examples["answer"]):
        prompt = question
        # End every answer with the tokenizer's own EOS so the model learns to stop.
        # NOTE(review): bos_token as the prompt/answer separator is an ad-hoc
        # format. For an *-Instruct checkpoint, tokenizer.apply_chat_template
        # may match the model's training format better.
        full_ids = tok(
            prompt + tok.bos_token + answer + tok.eos_token,
            truncation=True,
            max_length=max_length,
        )["input_ids"]
        # The separator is a special token, so (with HF's default special-token
        # splitting) the prompt tokenizes identically alone and inside full_ids.
        prompt_len = len(
            tok(prompt, truncation=True, max_length=max_length)["input_ids"]
        )
        # Clamp so the mask can never run past a truncated sequence.
        prompt_len = min(prompt_len, len(full_ids))
        labels = [-100] * prompt_len + full_ids[prompt_len:]

        input_ids_list.append(full_ids)
        labels_list.append(labels)

    return {"input_ids": input_ids_list, "labels": labels_list}

# Tokenize, dropping the raw text columns so the trainer only sees model
# inputs. Then discard rows whose answer was truncated away entirely: with
# every label at -100 they supervise nothing, and a batch made only of such
# rows can produce a NaN loss.
training_data = train_dataset.map(
    formatting_prompt,
    batched=True,
    remove_columns=train_dataset.column_names,
).filter(lambda example: any(label != -100 for label in example["labels"]))

In [None]:
# Effective batch per device = 16 * 4 (grad accumulation) = 64 sequences.
# NOTE(review): 40 epochs is very high for LoRA SFT; watch for overfitting.
training_args = TrainingArguments(
    output_dir="./llama-3.2-lora-checkpoints",
    # Optimisation and schedule.
    optim="adamw_torch",
    learning_rate=3e-4,
    weight_decay=0.01,
    max_grad_norm=0.3,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    num_train_epochs=40,
    # Batching and memory.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    bf16=True,
    tf32=True,
    ddp_find_unused_parameters=False,
    # Logging and checkpointing.
    logging_steps=10,
    save_steps=500,
    report_to="wandb",
)

In [None]:
# training_data is already tokenized and carries completion-only `labels`, so:
#  * pad with DataCollatorForSeq2Seq, which pads labels with -100. TRL's
#    default LM collator derives labels from input_ids and would undo the
#    prompt masking.
#  * disable packing, which concatenates examples and ignores the label mask;
#  * skip TRL's own dataset preparation, since there is no text column to format.
#    dataset_text_field is kept only because some TRL releases insist on it
#    when packing=False.
# Padding needs a pad token. Reusing EOS is safe here because padded label
# positions get -100, not the pad id.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,  # NOTE: newer TRL releases name this `processing_class`
    args=training_args,
    train_dataset=training_data,
    data_collator=DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding=True,
        label_pad_token_id=-100,
    ),
    max_seq_length=max_seq_length,
    dataset_text_field="text",
    dataset_kwargs={"skip_prepare_dataset": True},
    packing=False,
    dataset_num_proc=4,
    # NEFTune: noise added to input embeddings during training only. It is an
    # optional regulariser; set to None to disable.
    neftune_noise_alpha=5,
)