<a href="https://colab.research.google.com/github/DeekshithaDPrakash/LLM_Notebooks/blob/main/My_finds/Multi_GPU_DPO_Training_with_FSDP_and_QLoRA_for_Qwen2_5_72B_Instruct.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

*More details in this article: [Multi-GPU DPO Training with FSDP: Full Training, LoRA and QLoRA](https://kaitchup.substack.com/p/multi-gpu-dpo-training-with-fsdp)*


This notebook shows how to train a 70B-class LLM, e.g., Qwen2.5 72B, with DPO on multiple GPUs. It uses FSDP for multi-GPU training and QLoRA for parameter-efficient fine-tuning.

This code runs on four 24 GB GPUs and requires at least 170 GB of CPU RAM.
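
The 24 GB-per-GPU budget is the main reason the base model is quantized. The following is a rough estimate of weight memory alone. It assumes about 72 billion parameters and ignores quantization constants, LoRA weights, activations, and optimizer states, so real usage is higher:

```python
# Back-of-the-envelope estimate of weight memory only.
# Assumption: ~72e9 parameters; quantization constants, LoRA weights,
# activations, and optimizer states are ignored.

def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Memory needed to store the weights, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


def per_gpu_gb(total_gb: float, num_gpus: int) -> float:
    """With FULL_SHARD, each GPU holds roughly 1/num_gpus of the weights."""
    return total_gb / num_gpus


NUM_PARAMS = 72e9  # assumed parameter count for Qwen2.5 72B
NUM_GPUS = 4

for label, bits in [("bf16", 16), ("nf4", 4)]:
    total = weight_memory_gb(NUM_PARAMS, bits)
    print(f"{label}: {total:.0f} GB total, {per_gpu_gb(total, NUM_GPUS):.0f} GB per GPU")
# bf16: 144 GB total, 36 GB per GPU
# nf4: 36 GB total, 9 GB per GPU
```

In bf16, the sharded weights alone would exceed 24 GB per GPU. In 4-bit, they leave room for activations. The large CPU RAM requirement is likely related to loading the checkpoint and to `fsdp_offload_params: true` in the configuration below.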

For supervised fine-tuning, the step before DPO training, check this article: [Multi-GPU Fine-tuning for Llama 3.1 70B with FSDP and QLoRA](https://kaitchup.substack.com/p/multi-gpus-fine-tuning-for-llama)


*Note: This code was not tested in a Jupyter notebook. Copy the training code into a Python file and run that file with Accelerate.*


First, install the required libraries:

*Note: You need Transformers 4.46.3 (or more recent)*

In [1]:
!pip install --upgrade bitsandbytes transformers peft accelerate datasets trl flash_attn

Collecting bitsandbytes
  Downloading bitsandbytes-0.45.3-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting transformers
  Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)
Collecting accelerate
  Downloading accelerate-1.5.2-py3-none-any.whl.metadata (19 kB)
Collecting datasets
  Downloading datasets-3.4.1-py3-none-any.whl.metadata (19 kB)
Collecting trl
  Downloading trl-0.15.2-py3-none-any.whl.metadata (11 kB)
Collecting flash_attn
  Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)
  Preparing metadata (setup.py) ... done
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-ma
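
To confirm that the installed Transformers meets the 4.46.3 minimum mentioned above, a small check based on `importlib.metadata` can help. This is a sketch. The simple parser only compares the leading numeric components and ignores suffixes such as `.post1` or `.dev0`:

```python
import importlib.metadata


def version_tuple(version: str) -> tuple:
    """'4.49.0' -> (4, 49, 0); keep each component's leading digits and stop
    at the first component that has none (e.g. 'post1', 'dev0')."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def version_at_least(installed: str, minimum: str) -> bool:
    """Compare two version strings after padding them to the same length with zeros."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def meets_minimum(package: str, minimum: str) -> bool:
    """True if `package` is installed with a version >= `minimum`."""
    try:
        installed = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        return False
    return version_at_least(installed, minimum)


print("transformers >= 4.46.3:", meets_minimum("transformers", "4.46.3"))
```

If the `packaging` library is available, `packaging.version.Version` handles pre-release and post-release suffixes properly and is the more robust choice.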

Then, configure Accelerate:

In [3]:
!accelerate config


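Alternatively, copy the following configuration into a file named `config_fsdp.yaml`. In a notebook, you can add `%%writefile config_fsdp.yaml` as the first line of the cell so that it is written to disk instead of being executed as Python: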

In [4]:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
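
`num_processes: 4` launches one training process per GPU. Each process runs `bs` preference pairs per forward pass and accumulates gradients over `gas` steps (both are set in the training script below). The number of pairs per optimizer update is therefore the product of the three. With `max_steps=10`, the run sees only a small slice of the dataset:

```python
# Effective batch size for data-parallel training with gradient accumulation.
# Values mirror the config file (num_processes) and the training script (bs, gas, max_steps).

def effective_batch_size(per_device_bs: int, grad_accum_steps: int, num_processes: int) -> int:
    """Number of training examples (here: preference pairs) per optimizer update."""
    return per_device_bs * grad_accum_steps * num_processes


bs, gas, num_processes, max_steps = 1, 16, 4, 10
per_update = effective_batch_size(bs, gas, num_processes)
print("pairs per optimizer update:", per_update)                # 64
print("pairs seen in max_steps updates:", per_update * max_steps)  # 640
```

If you lower `bs` to save memory, raising `gas` by the same factor keeps the effective batch size unchanged.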


The training code below must be launched with Accelerate. Copy it into a file, e.g., `fsdp+QLoRA.py`, and then run:


```
accelerate launch --config_file config_fsdp.yaml fsdp+QLoRA.py
```
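
The `process` function in the script below treats `chosen` and `rejected` in `mlabonne/orpo-dpo-mix-40k` as lists of chat messages whose first element is the user prompt. It turns each row into the three string columns used for DPO training: `prompt`, `chosen`, and `rejected`. The following sketch mimics that split without downloading a tokenizer. It relies on a few assumptions:

* `toy_chat_template` is a hypothetical formatter that only approximates Qwen2.5's ChatML-style format. The real `tokenizer.apply_chat_template` may, for example, add a default system message.
* `EOS` is a placeholder string.
* The example row is invented.

```python
# Sketch of the prompt/chosen/rejected split done by `process` in the script below.
# Assumptions: `toy_chat_template` is a hypothetical, simplified ChatML-style
# formatter standing in for tokenizer.apply_chat_template, and EOS is a placeholder.

EOS = "<EOS>"  # placeholder for tokenizer.eos_token


def toy_chat_template(messages):
    """Join messages in a simplified ChatML-like format (no default system message)."""
    return "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )


def process(row):
    # The first message of the conversation is the prompt
    prompt = toy_chat_template([row["chosen"][0]])
    # The remaining messages are the preferred / dispreferred responses, plus EOS
    chosen = toy_chat_template(row["chosen"][1:]) + EOS
    rejected = toy_chat_template(row["rejected"][1:]) + EOS
    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}


# Invented example row in the dataset's conversational format
row = {
    "chosen": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "rejected": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "5"},
    ],
}

for column, text in process(row).items():
    print(column, repr(text))
```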



In [None]:

import torch, os, multiprocessing
from datasets import load_dataset
from peft import PeftModel
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    set_seed
)
from peft.utils.other import fsdp_auto_wrap_policy
from accelerate import Accelerator
from trl import DPOTrainer, DPOConfig
accelerator = Accelerator()
set_seed(1234)




model_name = "Qwen/Qwen2.5-72B-Instruct"
sft_adapter = "./SFT_LoRA/" #a LoRA adapter fine-tuned with SFT

compute_dtype = torch.bfloat16

#If you have trouble with FlashAttention, use 'sdpa' instead
attn_implementation = 'flash_attention_2'

#If you run out of memory, lower bs or mseqlen (and raise gas to keep the same effective batch size)
bs = 1 #Batch size per device (training and validation)
gas = 16 #Gradient accumulation steps
mseqlen = 512 #Maximum sequence length


lr = 1e-5 #Learning rate
QLoRA = True #Quantize the base model. I don't recommend it if you have enough memory to run LoRA
lora_alpha = 16
lora_dropout = 0.0
lora_r = 16

output_dir = "/workspace/DPO_LoRA"

#Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = "<|image_pad|>"
tokenizer.pad_token_id = 151655
tokenizer.padding_side = 'right' #Right or left padding doesn't seem to matter for Qwen2.5 (unlike Llama 3.1, which works better with right-padding)

#A dataset to test DPO training
ds = load_dataset("mlabonne/orpo-dpo-mix-40k", split="train").train_test_split(test_size=0.01)
ds_train = ds['train']
ds_test = ds['test']

#Split each conversation into prompt/chosen/rejected strings and append the EOS token to the responses
def process(row):
    #The first message is the prompt
    prompt_messages = tokenizer.apply_chat_template([row["chosen"][0]], tokenize=False)
    chosen_messages = tokenizer.apply_chat_template(row["chosen"][1:], tokenize=False)+tokenizer.eos_token
    rejected_messages = tokenizer.apply_chat_template(row["rejected"][1:], tokenize=False)+tokenizer.eos_token
    row["prompt"] = prompt_messages
    row["chosen"] = chosen_messages
    row["rejected"] = rejected_messages
    return row

ds_train = ds_train.map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)

ds_test = ds_test.map(
    process,
    num_proc= multiprocessing.cpu_count(),
    load_from_cache_file=False,
)


if QLoRA:
    bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_storage=compute_dtype,
    )


    model = AutoModelForCausalLM.from_pretrained(
              model_name, quantization_config=bnb_config, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
    )
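    # Freeze the base model's layers
    for param in model.parameters():
        param.requires_grad = False

    # With reentrant gradient checkpointing, the output of the frozen embeddings must
    # require grad, otherwise no gradient reaches the LoRA adapters
    # (same idea as enable_input_require_grads() in Transformers)
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)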
else:
    model = AutoModelForCausalLM.from_pretrained(
              model_name, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation
    )
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant':True})

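# Load the SFT adapter twice: "DPO" is trained, while "reference" stays frozen
# and serves as the reference model (see model_adapter_name/ref_adapter_name below)
model = PeftModel.from_pretrained(model, sft_adapter, is_trainable=True, adapter_name="DPO")
model.load_adapter(sft_adapter, adapter_name="reference")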

training_arguments = DPOConfig(
        output_dir=output_dir,
        eval_strategy="steps",
        do_eval=True,
        optim="adamw_torch",
        per_device_train_batch_size=bs,
        gradient_accumulation_steps=gas,
        per_device_eval_batch_size=bs,
        log_level="debug",
        save_strategy="steps",
        save_steps=5,
        logging_steps=2,
        learning_rate=lr,
        bf16 = True,
        beta = 0.1,
        eval_steps=2,
        max_steps=10,
        warmup_ratio=0.1,
        lr_scheduler_type="linear",
        max_length=mseqlen,
        max_prompt_length=512,
        dataset_num_proc=multiprocessing.cpu_count(),
        model_adapter_name="DPO",
        ref_adapter_name="reference",
)


trainer = DPOTrainer(
    model,
    args=training_arguments,
    train_dataset=ds_train,
    eval_dataset=ds_test,
    processing_class=tokenizer,
)


# The LoRA parameters are float32; downcast them to bfloat16.
# FSDP can only flatten tensors that share the same dtype when it prepares the model.
for param in model.parameters():
    if param.dtype == torch.float32:
        param.data = param.data.to(torch.bfloat16)

if trainer.ref_model is not None:
    fsdp_plugin = trainer.accelerator.state.fsdp_plugin
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.ref_model)
    trainer.ref_model = trainer.accelerator.prepare_model(trainer.ref_model)

fsdp_plugin = trainer.accelerator.state.fsdp_plugin
fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)

prepared_model = trainer._wrap_model(
    trainer.model, training=True, dataloader=None
)

(
    prepared_model,
    trainer.optimizer,
    trainer.lr_scheduler,
) = trainer.accelerator.prepare(
    prepared_model, trainer.optimizer, trainer.lr_scheduler
)
trainer.model_wrapped = prepared_model
if trainer.is_fsdp_enabled:
    trainer.model = prepared_model


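# The model is already prepared: turn later prepare_model calls (e.g., inside trainer.train()) into no-ops so it is not wrapped twice
trainer.accelerator.prepare_model = lambda model, *args, **kwargs: model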

trainer.train()

if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")

trainer.save_model(output_dir)
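
`beta = 0.1` in `DPOConfig` is the β of the DPO objective, which controls how far the trained policy may drift from the reference. Here, the reference policy is the frozen `reference` adapter. For a single preference pair, the standard sigmoid DPO loss is computed from the summed log-probabilities of the chosen and rejected responses under both models. The following is a plain-Python illustration of that formula, not TRL's implementation, and the log-probability values are made up:

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """Sigmoid DPO loss for one preference pair.

    Each argument is the log-probability of a full response (summed over its
    tokens) under the trained policy or the reference model.
    """
    chosen_logratio = policy_chosen_logp - ref_chosen_logp
    rejected_logratio = policy_rejected_logp - ref_rejected_logp
    margin = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(margin)) == softplus(-margin)
    return softplus(-margin)


# Policy identical to the reference: both log-ratios are 0, so the loss is log(2)
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))  # ≈ 0.6931
# Policy favors the chosen response more than the reference does: loss < log(2)
print(dpo_loss(-10.0, -15.0, -12.0, -15.0))
```

The `DPO` and `reference` adapters start from the same SFT weights, and `lora_dropout` is 0.0. The first logged training loss should therefore be close to log 2 ≈ 0.693. A clearly different value may point to a problem with how the reference adapter is set up.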