# DDP example with Mistral-7B and the MedMCQA dataset
In this example, a network is trained on multiple GPUs using DDP (Distributed Data Parallel). DDP places a full copy of the model on every GPU and lets each GPU process a different part of the training data in parallel, which speeds up training. It is therefore suited to models that fit into the memory of a single GPU.

To use multiple GPUs, we have to write the code to a file and submit it as a job to the SLURM scheduler, because JupyterHub at VSC is configured to give access to at most one GPU.

This example uses 2 nodes with 2 GPUs each.

#### First, we write the Python code to a file:

In [None]:
%%writefile mistral7b_train_ddp.py
import torch
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from trl import SFTTrainer
import random
from textwrap import dedent  # Remove leading whitespace from multiline strings

def print_gpu_utilization():
    """Print the memory currently in use on every GPU that NVML reports on this node.

    Values come from ``pynvml.nvmlDeviceGetMemoryInfo(...).used`` (device-wide usage,
    not just this process) and are divided by 1024**3 before printing, e.g.
    ``Memory occupied on GPUs: 10.9 + 11.0 GB.``
    """
    # pynvml is not imported at module level; importing it here avoids the
    # NameError the original raised when this was called after training.
    import pynvml

    pynvml.nvmlInit()
    try:
        memory_used = []
        for device_index in range(pynvml.nvmlDeviceGetCount()):
            device_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index)
            device_info = pynvml.nvmlDeviceGetMemoryInfo(device_handle)
            memory_used.append(device_info.used / 1024**3)  # bytes -> GiB
    finally:
        # Balance nvmlInit() even if a query fails.
        pynvml.nvmlShutdown()
    print('Memory occupied on GPUs: ' + ' + '.join(f'{mem:.1f}' for mem in memory_used) + ' GB.')

def medmcqa_get_answer(entry):
    """Return the text of the correct option of a MedMCQA entry.

    ``entry['cop']`` holds the index 0..3 of the correct option, which selects the
    field ``'opa'``, ``'opb'``, ``'opc'`` or ``'opd'``. Any other index raises
    ``KeyError``.
    """
    field_by_index = dict(enumerate(('opa', 'opb', 'opc', 'opd')))
    return entry[field_by_index[entry['cop']]]

def medmcqa_add_prompt(entry, tokenizer, include_answer, shuffle_options=False):
    """Render a MedMCQA entry as a chat prompt and store it in ``entry['text']``.

    Args:
        entry: MedMCQA record with the fields ``question`` and ``opa``..``opd``
            (plus ``cop`` when ``include_answer`` is true). Modified in place.
        tokenizer: object providing ``apply_chat_template(messages, tokenize=False)``.
        include_answer: if true, append an assistant turn
            ``'Answer: <text of the correct option>'``.
        shuffle_options: if true, list the four options in random order (uses the
            global ``random`` state). The answer is looked up by field, not by
            position, so it stays correct.

    Returns:
        The same ``entry`` dict, now containing ``'text'``.
    """
    options = [entry['opa'], entry['opb'], entry['opc'], entry['opd']]
    if shuffle_options:
        random.shuffle(options)
    # Build the prompt line by line instead of dedent()-ing an f-string. dedent runs
    # after interpolation, so a question/option containing a newline followed by
    # unindented text collapses the common margin and leaves every template line
    # indented by 12 spaces.
    prompt = '\n'.join([
        'You are a medical student taking a multiple-choice exam. Four options are provided for each question. Only one of these options is the correct answer.',
        f'Question: {entry["question"]}',
        'Options:',
        *(f'{number}. {option}' for number, option in enumerate(options, start=1)),
        'Solve this multiple-choice exam question and provide the correct answer.',
    ])
    messages = [{'role': 'user', 'content': prompt}]
    if include_answer:
        answer = medmcqa_get_answer(entry)
        messages.append({'role': 'assistant', 'content': f'Answer: {answer}'})
    entry['text'] = tokenizer.apply_chat_template(messages, tokenize=False)
    return entry


model_id = 'mistralai/Mistral-7B-Instruct-v0.3'
max_seq_length = 1024  # maximum tokenized length of a training example (passed to SFTTrainer)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the <unk> token for padding (presumably because the tokenizer defines no
# pad token of its own — TODO confirm) and pad on the right for training.
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id
tokenizer.padding_side = 'right'

# torchrun starts one copy of this script per GPU; PartialState tells each copy its
# place in the job. It is created before the dataset is prepared so that the
# preparation below can be coordinated across processes.
ps = PartialState()
num_processes = ps.num_processes  # total number of processes (GPUs) in the job
process_index = ps.process_index  # global rank of this process
local_process_index = ps.local_process_index  # rank within this node; selects the GPU below

# Let the main process download and preprocess the dataset first; the others wait
# and can then reuse the datasets cache instead of all writing the same cache files
# at the same time. Assumes the Hugging Face cache directory is on a filesystem
# shared by all nodes — if it is node-local, use ps.local_main_process_first().
with ps.main_process_first():
    data = load_dataset('medmcqa', split='train')
    data = data.map(lambda entry: medmcqa_add_prompt(entry, tokenizer, include_answer=True), load_from_cache_file=True)

# 4-bit NF4 weights with double quantization; matmuls are computed in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    attn_implementation='sdpa',  # 'eager', 'sdpa', or "flash_attention_2"
    device_map={'': local_process_index},  # put the whole model on this process's GPU
)

model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False  # the KV cache only helps generation, not training
model.config.pretraining_tp = 1  # disable tensor parallelism

# LoRA adapters: only these small low-rank matrices are trained; the quantized
# base weights stay frozen.
peft_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=16,  # rank of the LoRA update matrices
    lora_alpha=32,  # scaling factor; a common heuristic is lora_alpha = 2*r
    lora_dropout=0.05,
    bias='none',
    target_modules='all-linear',  # adapt all linear layers (PEFT excludes the output layer)
)

project_name = 'mistral7b-medmcqa'
run_name = '1'

training_arguments = TrainingArguments(
    # When using newer versions of `trl`, use SFTConfig(...) instead of TrainingArguments(...).
    output_dir=f'{project_name}-{run_name}',
    per_device_train_batch_size=8,  # global batch = 8 * number of GPUs * gradient_accumulation_steps
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,  # Recompute activations in the backward pass to save GPU memory.
        # Measured for Mistral 7B with PEFT using bitsandbytes:
        # - enabled: 11 GB GPU RAM and 12 samples/second
        # - disabled: 40 GB GPU RAM and 8 samples/second
        # NOTE(review): checkpointing usually trades speed for memory, but these
        # figures show it as faster here; re-check the measurements.
    gradient_checkpointing_kwargs={'use_reentrant': False},  # Use newer implementation that will become the default.
    ddp_find_unused_parameters=False,  # Set to False when using gradient checkpointing to suppress warning message.
    log_level_replica='error',  # Disable warnings in all but the first process.
    optim='adamw_torch',  # 'paged_adamw_32bit' can save GPU memory
    learning_rate=2e-4,  # QLoRA suggestions: 2e-4 for 7B or 13B, 1e-4 for 33B or 65B
    warmup_steps=200,  # NOTE(review): exceeds max_steps=20, so this run ends mid-warm-up and never reaches learning_rate
    lr_scheduler_type='cosine',
    logging_strategy='steps',  # 'no', 'epoch' or 'steps'
    logging_steps=50,  # NOTE(review): exceeds max_steps=20, so no intermediate loss values are logged
    save_strategy='no',  # 'no', 'epoch' or 'steps'
    # save_steps=2000,
    # num_train_epochs=5,
    max_steps=20,  # presumably a short test run; use num_train_epochs for real training
    # bf16 mixed precision matches bnb_4bit_compute_dtype=torch.bfloat16 of the 4-bit
    # layers (fp16 autocast would mix fp16 and bf16 activations). A100 GPUs support bf16.
    bf16=True,
    # hub_private_repo=True,
    report_to='none',  # disable wandb
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=data,
    peft_config=peft_config,
    tokenizer=tokenizer,
    packing=False,
    # When using newer versions of `trl`, the argument `training_arguments` should be given as
    # an instance of SFTConfig(...) instead of TrainingArguments(...) and the following
    # parameters should be specified there instead of here:
    dataset_text_field='text',
    max_seq_length=max_seq_length,
)

if process_index == 0:  # Only print in first process.
    if hasattr(trainer.model, "print_trainable_parameters"):
        trainer.model.print_trainable_parameters()

result = trainer.train()

# Print statistics in first process only:
if process_index == 0:
    metrics = result.metrics
    samples_per_second = metrics['train_samples_per_second']
    print(f"Run time: {metrics['train_runtime']:.2f} seconds")
    print(f"{num_processes} GPUs used.")
    print(f"Training speed: {samples_per_second:.1f} samples/s (={samples_per_second / num_processes:.1f} samples/s/GPU)")

# Print memory usage once per node:
if local_process_index == 0:
    print_gpu_utilization()

# Save the model (to training_arguments.output_dir). Call this on every process:
# Trainer.save_model() itself only writes from the main process, and with sharded
# strategies (FSDP/DeepSpeed) it may need all ranks to participate, so a rank-0 guard
# would hang there.
trainer.save_model()

#### Next, we write the SLURM script:

In [None]:
%%writefile run_vsc5a100_ddp.slurm
#!/bin/bash

#SBATCH --partition=zen3_0512_a100x2
#SBATCH --qos=zen3_0512_a100x2
# The original run used '--qos=admin', presumably a QOS restricted to VSC staff;
# re-enable it only if your account has access to it:
# #SBATCH --qos=admin
#SBATCH --gres=gpu:2  # Number of GPUs per node (1 or 2)
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1  # one torchrun launcher per node; it starts one process per GPU

#SBATCH --time=1:00:00

# Load conda:
module purge
module load miniconda3

# Include commands in output:
set -x

# Print current time and date:
date

# Print host name (of the node running this batch script):
hostname

# List available GPUs (on the node running this batch script):
nvidia-smi

# Set environment variables for communication between nodes. The first node of the
# allocation hosts the c10d rendezvous used by torchrun below:
export MASTER_PORT=24998
export MASTER_ADDR=$(scontrol show hostnames ${SLURM_JOB_NODELIST} | head -n 1)

# Run AI scripts. srun starts one torchrun per node (--ntasks-per-node=1); each torchrun
# starts $SLURM_GPUS_ON_NODE training processes and they all meet at the rendezvous endpoint.
# Single-node alternative:
# time conda run -n finetuning --no-capture-output torchrun --nproc_per_node 2 mistral7b_train_ddp.py
time srun conda run -n finetuning --no-capture-output torchrun \
    --nnodes=$SLURM_JOB_NUM_NODES \
    --nproc_per_node=$SLURM_GPUS_ON_NODE \
    --rdzv_id=$SLURM_JOB_ID \
    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    --rdzv_backend=c10d \
    mistral7b_train_ddp.py

#### We can now submit the SLURM script and, once the job has finished, look at its output:

In [None]:
# Submit the job; sbatch prints the job ID, which also appears in the output file name slurm-<jobid>.out.
!sbatch run_vsc5a100_ddp.slurm

In [None]:
# Show your own jobs in the SLURM queue; the job is no longer listed once it has finished.
!squeue --me

In [None]:
# Print the whole job output. The job ID differs for every submission, so match the
# file with a glob instead of hard-coding the ID of one particular run.
!tail -c +0 slurm-*.out

In [None]:
# Clean up: delete the generated Python and SLURM scripts.
!rm mistral7b_train_ddp.py run_vsc5a100_ddp.slurm

In [None]:
# Delete the SLURM output files (the saved model in mistral7b-medmcqa-1/ is kept).
!rm slurm-*.out