# FSDP example with Mistral-7B and Guanaco dataset
In this example, a network is trained on multiple GPUs using FSDP (Fully Sharded Data Parallel). FSDP shards the model parameters, gradients, and optimizer states across the GPUs, which makes it possible to train networks that are too large to fit into the memory of a single GPU.
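
To give a feeling for why sharding helps, here is a back-of-the-envelope sketch of the memory needed for weights, gradients, and Adam optimizer states in full mixed-precision fine-tuning. It assumes the commonly quoted 16 bytes per parameter (2 for fp16 weights, 2 for fp16 gradients, 12 for fp32 master weights and the two Adam moments) and an even split across GPUs. The function name and the byte count are illustrative assumptions. Activations, buffers, and fragmentation are ignored, and the script below uses 4-bit quantization with LoRA, so its real footprint is much smaller.

```python
def model_state_memory_gib(n_params: float, n_gpus: int = 1, bytes_per_param: int = 16) -> float:
    """Rough per-GPU memory in GiB for weights, gradients and optimizer
    states, assuming they are sharded evenly across n_gpus."""
    total_bytes = n_params * bytes_per_param
    return total_bytes / n_gpus / 1024**3


n_params = 7.25e9  # approximate parameter count of Mistral-7B (assumption)
for n_gpus in (1, 2, 8):
    per_gpu = model_state_memory_gib(n_params, n_gpus)
    print(f"{n_gpus} GPU(s): ~{per_gpu:.0f} GiB per GPU")
```

Under these assumptions, the model states alone exceed the memory of a single GPU and only fit once they are sharded over several devices.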

To use multiple GPUs, we need to write the code to a file and submit a job to the SLURM scheduler, because JupyterHub at VSC is configured with access to at most one GPU. This example uses two GPUs on one node, but it can be extended to multiple nodes simply by adjusting the number of nodes in the line
```
#SBATCH --nodes=1
```
in the SLURM script.

#### First, we write the Python code to a file:

In [1]:
%%writefile mistral7b_train.py
import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer
import pynvml


def print_gpu_utilization():
    pynvml.nvmlInit()
    device_count = pynvml.nvmlDeviceGetCount()
    memory_used = []
    for device_index in range(device_count):
        device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
        device_info = pynvml.nvmlDeviceGetMemoryInfo(device_handle)
        memory_used.append(device_info.used/1024**3)
    print('Memory occupied on GPUs: ' + ' + '.join([f'{mem:.1f}' for mem in memory_used]) + ' GB.')


model_id = 'mistralai/Mistral-7B-Instruct-v0.3'

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = 'right'

data = load_dataset('timdettmers/openassistant-guanaco', split='train')

ps = PartialState()
num_processes = ps.num_processes
process_index = ps.process_index
local_process_index = ps.local_process_index

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_storage=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    attn_implementation='sdpa',  # 'eager', 'sdpa', or "flash_attention_2"
    torch_dtype=torch.float16,
)

model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False
model.config.pretraining_tp = 1  # disable tensor parallelism

peft_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=16,
    lora_alpha=32,  # rule: lora_alpha should be 2*r
    lora_dropout=0.05,
    bias='none',
    target_modules='all-linear',
)

project_name = 'mistral7b-guanaco'
run_name = '1'

training_arguments = TrainingArguments(
    output_dir=f'{project_name}-{run_name}',
    per_device_train_batch_size=8,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,  # Gradient checkpointing improves memory efficiency, but slows down training,
        # e.g. Mistral 7B with PEFT using bitsandbytes:
        # - enabled: 11 GB GPU RAM and 8 samples/second
        # - disabled: 40 GB GPU RAM and 12 samples/second
    gradient_checkpointing_kwargs={'use_reentrant': False},  # Use newer implementation that will become the default.
    optim='adamw_torch',  # 'paged_adamw_32bit' can save GPU memory
    learning_rate=2e-4,  # QLoRA suggestions: 2e-4 for 7B or 13B, 1e-4 for 33B or 65B
    warmup_steps=200,
    lr_scheduler_type='cosine',
    logging_strategy='steps',  # 'no', 'epoch' or 'steps'
    logging_steps=50,
    save_strategy='no',  # 'no', 'epoch' or 'steps'
    max_steps=10,
    fp16=True,  # mixed precision training
    report_to='none',  # disable wandb
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=data,
    peft_config=peft_config,
    tokenizer=tokenizer,
    packing=False,
    dataset_text_field='text',
    max_seq_length=1024,
)

if process_index == 0:  # Only print in first process.
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

result = trainer.train()

# Print statistics in first process only:
if process_index == 0:
    print(f"Run time: {result.metrics['train_runtime']:.2f} seconds")
    print(f"{num_processes} GPUs used.")
    print(f"Training speed: {result.metrics['train_samples_per_second']:.1f} samples/s (={result.metrics['train_samples_per_second'] / num_processes:.1f} samples/s/GPU)")

# Print memory usage once per node:
if local_process_index == 0:
    print_gpu_utilization()

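# Save the model. Call this in all processes: with FSDP, gathering the full
# state dict is a collective operation, and the trainer itself writes the
# files in the main process only.
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model()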

Writing mistral7b_train.py
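
The number of trainable parameters that the script prints can be estimated by hand. The sketch below assumes the published Mistral-7B dimensions (hidden size 4096, MLP size 14336, 32 decoder layers, 8 key/value heads of dimension 128) and assumes that `target_modules='all-linear'` adapts the seven linear projections of every decoder layer but not the output layer. The helper name `lora_param_count` is made up for this illustration.

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * (in_features + out_features)


# Mistral-7B dimensions (assumed from the published model configuration).
hidden, intermediate, n_layers = 4096, 14336, 32
kv_dim = 8 * 128  # 8 key/value heads of dimension 128 (grouped-query attention)
r = 16            # LoRA rank, as in the LoraConfig above

# (in_features, out_features) of the linear layers in one decoder layer.
linear_shapes = {
    'q_proj': (hidden, hidden),
    'k_proj': (hidden, kv_dim),
    'v_proj': (hidden, kv_dim),
    'o_proj': (hidden, hidden),
    'gate_proj': (hidden, intermediate),
    'up_proj': (hidden, intermediate),
    'down_proj': (intermediate, hidden),
}

per_layer = sum(lora_param_count(i, o, r) for i, o in linear_shapes.values())
total = n_layers * per_layer
print(f"{total:,} trainable LoRA parameters")  # 41,943,040
```

This agrees with the `trainable params` value in the job output shown at the end of this example.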


#### Next, we write a file with the configuration for FSDP:

In [2]:
%%writefile fsdp_config.yml
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 1
rdzv_backend: c10d
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Writing fsdp_config.yml


#### Finally, we write the SLURM script:

In [3]:
%%writefile run_vsc5a100_fsdp.slurm
#!/bin/bash

#SBATCH --partition=zen3_0512_a100x2
# #SBATCH --qos=zen3_0512_a100x2
#SBATCH --qos=admin

## Specify resources:
#SBATCH --nodes=1
#SBATCH --gres=gpu:2  # up to 2 on VSC5/A100
#SBATCH --ntasks-per-node=1
## No need to specify RAM on VSC5, as it will be automatically
## allocated depending on the number of GPUs requested.

#SBATCH --time=0:30:00

# Load conda:
module purge
module load miniconda3

# Include commands in output:
set -x

# Print current time and date:
date

# Print host name:
hostname

# List available GPUs:
nvidia-smi

# Set environment variables for communication between nodes:
export MASTER_PORT=24998
export MASTER_ADDR=$(scontrol show hostnames ${SLURM_JOB_NODELIST} | head -n 1)

# Print statistics:
echo "Using $((SLURM_NNODES * SLURM_GPUS_ON_NODE)) GPUs on $SLURM_NNODES nodes."

# Run AI scripts:
srun bash -c "conda run -n finetuning --no-capture-output accelerate launch \
    --num_machines $SLURM_NNODES \
    --num_processes $((SLURM_NNODES * SLURM_GPUS_ON_NODE)) \
    --num_cpu_threads_per_process 8 \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --config_file \"fsdp_config.yml\" \
    mistral7b_train.py"

Writing run_vsc5a100_fsdp.slurm
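
The launch command derives `--num_machines` and `--num_processes` from SLURM environment variables, which is why changing `--nodes` is enough to scale out. The same arithmetic is sketched below in Python. The function `launch_topology` is hypothetical and only mirrors the shell expressions in the script; it takes the environment as a plain mapping so that it can be tried outside a SLURM job.

```python
import os


def launch_topology(env=None):
    """Mirror the shell arithmetic of the SLURM script: one process per GPU,
    with the same number of GPUs on every node."""
    env = os.environ if env is None else env
    n_nodes = int(env.get('SLURM_NNODES', 1))
    gpus_per_node = int(env.get('SLURM_GPUS_ON_NODE', 1))
    return {
        'num_machines': n_nodes,
        'num_processes': n_nodes * gpus_per_node,
    }


# One node with two GPUs, as in this example:
print(launch_topology({'SLURM_NNODES': '1', 'SLURM_GPUS_ON_NODE': '2'}))
# Four nodes with two GPUs each:
print(launch_topology({'SLURM_NNODES': '4', 'SLURM_GPUS_ON_NODE': '2'}))
```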


#### We can now submit the SLURM script and, once the job has finished, look at the output:

In [None]:
!sbatch run_vsc5a100_fsdp.slurm

In [None]:
!squeue --me

In [None]:
!tail -c +0 slurm-3990535.out
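
The job ID in the output file name changes with every submission. A minimal sketch for deriving the file name, assuming `sbatch` prints its usual `Submitted batch job <id>` message and the default output pattern `slurm-<id>.out` is in use (the helper `parse_job_id` is hypothetical):

```python
import re


def parse_job_id(sbatch_output: str) -> str:
    """Extract the numeric job ID from the message printed by sbatch."""
    match = re.search(r'Submitted batch job (\d+)', sbatch_output)
    if match is None:
        raise ValueError(f'No job ID found in: {sbatch_output!r}')
    return match.group(1)


job_id = parse_job_id('Submitted batch job 3990535')
output_file = f'slurm-{job_id}.out'
print(output_file)  # slurm-3990535.out
```

In a notebook, the text passed to `parse_job_id` could be captured from the submission cell, for example with IPython's `out = !sbatch run_vsc5a100_fsdp.slurm` and then `' '.join(out)`.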

In [None]:
!rm fsdp_config.yml mistral7b_train.py run_vsc5a100_fsdp.slurm

In [None]:
!rm slurm-*.out

### Output of the same script executed on Leonardo:

```
Unloading profile/base
  ERROR: Module evaluation aborted
+ date
Wed Sep 25 19:27:04 CEST 2024
+ hostname
lrdn3361.leonardo.local
+ nvidia-smi
Wed Sep 25 19:27:04 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM-64GB            On | 00000000:8F:00.0 Off |                    0 |
| N/A   43C    P0               61W / 455W|      0MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM-64GB            On | 00000000:C8:00.0 Off |                    0 |
| N/A   42C    P0               61W / 458W|      0MiB / 65536MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
+ export MASTER_PORT=24998
+ MASTER_PORT=24998
++ scontrol show hostnames lrdn3361
++ head -n 1
+ export MASTER_ADDR=lrdn3361
+ MASTER_ADDR=lrdn3361
+ echo 'Using 2 GPUs on 1 nodes.'
Using 2 GPUs on 1 nodes.
+ srun bash -c 'conda run -n finetuning --no-capture-output accelerate launch     --num_machines 1     --num_processes 2     --num_cpu_threads_per_process 8     --main_process_ip lrdn3361     --main_process_port 24998     --machine_rank $SLURM_PROCID     --config_file "fsdp_config.yml"     mistral7b_train.py'
Repo card metadata block was not found. Setting CardData to empty.
Repo card metadata block was not found. Setting CardData to empty.
Loading checkpoint shards: 100%|__________| 3/3 [02:56<00:00, 58.79s/it]
Loading checkpoint shards: 100%|__________| 3/3 [02:55<00:00, 58.48s/it]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs
max_steps is given, it will override any value given in num_train_epochs
trainable params: 41,943,040 || all params: 7,289,966,592 || trainable%: 0.5754
{'train_runtime': 30.6162, 'train_samples_per_second': 5.226, 'train_steps_per_second': 0.327, 'train_loss': 1.3052967071533204, 'epoch': 0.02}
100%|__________| 10/10 [00:30<00:00,  3.06s/it]
Run time: 30.62 seconds
2 GPUs used.
Training speed: 5.2 samples/s (=2.6 samples/s/GPU)
Memory occupied on GPUs: 18.1 + 14.4 GB.
```
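
The reported training speed can be cross-checked against the settings in the training script. The short calculation below uses only values from the script (`per_device_train_batch_size`, `gradient_accumulation_steps`, `max_steps`) and from the log above (`train_runtime`, two GPUs); the variable names are chosen for this illustration.

```python
per_device_batch_size = 8
gradient_accumulation_steps = 1
num_gpus = 2
max_steps = 10
train_runtime = 30.6162  # seconds, from the log

# Every optimizer step consumes one batch per GPU.
samples = per_device_batch_size * gradient_accumulation_steps * num_gpus * max_steps
throughput = samples / train_runtime

print(f"{samples} samples in {train_runtime:.2f} s")
print(f"{throughput:.3f} samples/s ({throughput / num_gpus:.1f} samples/s/GPU)")
```

The result, 160 samples at about 5.226 samples/s, matches `train_samples_per_second` in the log.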