# PEFT LLaMA-2 + Quantization + FlashAttention
_________________
## Installing required libraries

In [None]:
!apt update -y
!pip install --upgrade pip
!pip install "transformers==4.31.0" "datasets==2.13.0" "peft==0.4.0" "accelerate==0.21.0" "bitsandbytes==0.40.2" "trl==0.4.7" "safetensors>=0.3.1" "huggingface_hub>=0.16.4" "python-dotenv==1.0.0" "openai>=0.27.8" "langchain[llm]" "git-lfs" --upgrade

## Installing Flash Attention Optimization
If you are running from a newly created environment where the latest version (v2) of the flash-attn PyTorch library is not compiled, you MUST uncomment the following lines and compile the library. This takes around 45-60 minutes.

In [None]:
# !pip install ninja packaging
# # Timing from one build run: started 2:21 pm, ended 3:11 (roughly 50 minutes)
# !MAX_JOBS=4 pip install flash-attn==2.0.4 --no-build-isolation

If you want to check that the flash-attn v2 library is already compiled and working, uncomment the following line of code and check that `FlashAttnVarlenQKVPackedFunc` is defined in the `flash_attn_interface.py` script.

In [None]:
# Prints the installed flash_attn_interface.py so you can verify that FlashAttnVarlenQKVPackedFunc
# is defined in it. The path below points at a conda Python 3.10 site-packages directory;
# adjust it if your environment lives elsewhere.
# !cat /opt/conda/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py

## Log in to HuggingFace Hub
-------------
Provide your HuggingFace User Token in order to upload or download models, tokenizers and datasets to/from the Hub.

In [None]:
# !huggingface-cli login --token {}
!huggingface-cli login --token <YOUR HUGGING FACE TOKEN>

## Defining training parameters

In [None]:
import shlex

########################################
# Flash Attention parameters
########################################
use_flash_attention = True

########################################
# TrainingArguments parameters
########################################
model_name = 'meta-llama/Llama-2-7b-hf'
dataset_name = "mlabonne/guanaco-llama2-1k"
output_dir = "./LLaMA2-QLoRA"
output_model_name = "Llama-2-7b-hf_guanaco-llama2-1k_3-epochs_FastAtt-custom"

# Training iterations. Could be done by epochs OR by steps
# NOTE(review): max_steps = -1 presumably means "no step limit, use num_train_epochs" —
# confirm against llama2_peft_training.py.
num_train_epochs = 3
max_steps = -1

# Batch size settings (a larger training batch is used when Flash Attention is enabled)
per_device_train_batch_size = 6 if use_flash_attention else 4
per_device_eval_batch_size = 4
group_by_length = True

# Learning rate settings
learning_rate = 2e-4
lr_scheduler_type = "constant"

# Weights initialization and regularization
# NOTE(review): with lr_scheduler_type = "constant" the warmup_ratio may have no effect —
# verify how the training script builds its scheduler.
weight_decay = 0.001
warmup_ratio = 0.03

# Optimizer and gradient settings
optim = "paged_adamw_32bit"
gradient_accumulation_steps = 2
max_grad_norm = 0.3
gradient_checkpointing = True

# Mixed precision settings
fp16 = False
bf16 = True
tf32 = True

# Save and logging settings.
# The checkpoint save strategy to adopt during training. Possible values are:
# - "no": No save is done during training.
# - "epoch": Save is done at the end of each epoch.
# - "steps": Save is done every save_steps.
# NOTE(review): these four values are defined here but are NOT forwarded in peft_command
# below, so changing them in this cell does not reach the training script.
save_strategy = "epoch"
save_steps = 10
logging_strategy = "steps"
logging_steps = 10

########################################
# QLoRA parameters
########################################
lora_alpha = 16
lora_dropout = 0.1
lora_r = 64

########################################
# bitsandbytes parameters
########################################
use_4bit = True
use_nested_quant = True
bnb_4bit_compute_dtype = "bfloat16"
bnb_4bit_quant_type = "nf4"

########################################
# SFT parameters
########################################
max_seq_length = 2048
merge_weights = True
push_weights = False
# NOTE(review): packing is defined but not forwarded in peft_command below.
packing = False


def build_peft_command(script, arguments):
    """Build a shell-safe ``python <script> --name value ...`` command line.

    Args:
        script: Path of the Python script to run.
        arguments: Mapping of option name (without the leading ``--``) to its value.
            Options are emitted in the mapping's iteration order and every value is
            rendered with ``str()``.

    Returns:
        A single command string in which every token is quoted with ``shlex.join``,
        so values containing spaces or shell metacharacters stay one argument.
        Tokens that need no quoting are left untouched.
    """
    tokens = ["python", script]
    for option, value in arguments.items():
        tokens.append(f"--{option}")
        tokens.append(str(value))
    return shlex.join(tokens)


# PEFT Command
# Options forwarded to the training script, in the order they appear on the command line.
peft_arguments = {
    "model_name": model_name,
    "dataset_name": dataset_name,
    "output_dir": output_dir,
    "output_model_name": output_model_name,
    "num_train_epochs": num_train_epochs,
    "max_steps": max_steps,
    "per_device_train_batch_size": per_device_train_batch_size,
    "per_device_eval_batch_size": per_device_eval_batch_size,
    "group_by_length": group_by_length,
    "learning_rate": learning_rate,
    "lr_scheduler_type": lr_scheduler_type,
    "weight_decay": weight_decay,
    "warmup_ratio": warmup_ratio,
    "optim": optim,
    "gradient_accumulation_steps": gradient_accumulation_steps,
    "max_grad_norm": max_grad_norm,
    "gradient_checkpointing": gradient_checkpointing,
    "fp16": fp16,
    "bf16": bf16,
    "tf32": tf32,
    "lora_alpha": lora_alpha,
    "lora_dropout": lora_dropout,
    "lora_r": lora_r,
    "use_4bit": use_4bit,
    "use_nested_quant": use_nested_quant,
    "bnb_4bit_compute_dtype": bnb_4bit_compute_dtype,
    "bnb_4bit_quant_type": bnb_4bit_quant_type,
    "max_seq_length": max_seq_length,
    "merge_weights": merge_weights,
    "push_weights": push_weights,
    "use_flash_attention": use_flash_attention,
}
peft_command = build_peft_command("llama2_peft_training.py", peft_arguments)

# For inspecting if peft command is well-formed
print(peft_command)

## Launching PEFT 

In [None]:
import os

# Executing peft command through the system shell; os.system returns the command's status.
exit_status = os.system(peft_command)

# Fail loudly on a non-zero status so that "Run All" does not silently continue
# past a training run that did not complete.
if exit_status != 0:
    raise RuntimeError(f"PEFT training command failed (os.system returned {exit_status})")