# Llama3 8b training using PEFT model with torchtune training loop

Hyper-parameters and data have been matched to the corresponding torchtune recipe so the two runs are comparable.

torchtune quantizes the base model with torchao (NF4; whether 8-bit is also involved is unclear — **TODO** confirm).

This notebook uses bitsandbytes v0.43.3 for the 4-bit (NF4) quantization.

## imports

In [1]:
import os
import time
from types import SimpleNamespace
from functools import partial

In [2]:
import torch
from torch import nn
from tqdm.notebook import trange, tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from torchtune import utils
from torchtune.modules import get_cosine_schedule_with_warmup
from torchtune.datasets import InstructDataset
from torchtune.datasets._instruct import _get_component_from_path
from torchtune.utils import padded_collate, get_memory_stats
from torchtune.utils.metric_logging import DiskLogger
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.modules.common_utils import reparametrize_as_dtype_state_dict_post_hook

## params

In [3]:
# --- precision / device ---
dtype = torch.bfloat16
device = 0  # GPU index used for device_map when loading the model

# --- LoRA ---
rank = 8
alpha = 16
dropout = 0.05
target_modules = [
    'q_proj', 'v_proj', 'k_proj', 'o_proj',  # attention projections
    'gate_proj', 'up_proj', 'down_proj',     # MLP projections
]

# --- optimization ---
lr = 3e-4
weight_decay = 0.01
num_warmup_steps = 100
gradient_accumulation_steps = 4
total_epochs = 1
max_steps_per_epoch = None  # None means: run the whole epoch

# --- data ---
# dataset: torchtune.datasets.alpaca_cleaned_dataset
shuffle = True
batch_size = 2
max_seq_len = 512
template = _get_component_from_path('torchtune.data.AlpacaInstructTemplate')

# --- logging ---
log_every_n_steps = 1

## setup

In [4]:
# Seed torch's RNGs on every device (CPU and all CUDA devices). Seeding only CUDA would leave the
# CPU generator unseeded, so any CPU-side random initialisation (e.g. of adapter weights, if it
# happens before they are moved to the GPU) would not be reproducible.
# Data shuffling is seeded separately by the sampler (seed=0).
torch.manual_seed(0)

In [5]:
# Root directory under which the Llama 3 tokenizer file is looked up. The default is a
# machine-specific checkout location; set TORCHTUNE_CONFIGS_DIR to point elsewhere without
# editing the notebook. `~` is expanded in either case.
base_path = os.path.expanduser(
    os.environ.get("TORCHTUNE_CONFIGS_DIR", "~/work/clones/torchtune/recipes/configs")
)

In [6]:
# Build torchtune's Llama 3 tokenizer from the original tokenizer.model file.
tokenizer_path = os.path.join(
    base_path, "Meta-Llama-3-8B-Instruct/original/tokenizer.model"
)
tokenizer = llama3_tokenizer(tokenizer_path)

In [7]:
# Alpaca-cleaned instruction data, formatted with the Alpaca template and truncated to
# max_seq_len. train_on_input=True: prompt tokens are not masked out of the labels.
ds = InstructDataset(
    tokenizer,
    'yahma/alpaca-cleaned',
    template=template,
    train_on_input=True,
    max_seq_len=max_seq_len,
    split="train",
)

In [8]:
# Single-replica DistributedSampler (rank 0 sees every sample). seed=0 makes the shuffle order
# reproducible; the training loop calls set_epoch() to reshuffle each epoch.
sampler = torch.utils.data.DistributedSampler(
    ds, num_replicas=1, rank=0, shuffle=shuffle, seed=0,
)

In [10]:
packed = False

# Unpacked samples have different lengths, so they are right-padded per batch: tokens with
# id 128004 and labels with -100 (ignored by the loss). Packed samples use the default collate.
if packed:
    collate_fn = None
else:
    collate_fn = partial(padded_collate, padding_idx=128004, ignore_idx=-100)

dataloader = torch.utils.data.DataLoader(
    dataset=ds,
    sampler=sampler,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

In [11]:
# bitsandbytes quantization settings: weights stored as 4-bit NF4 with double quantization,
# and the compute performed in bf16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [12]:
# Load the instruct checkpoint onto GPU `device`, quantized on load according to `bnb_config`;
# tensors that are not quantized are loaded in `dtype`.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=dtype,
    device_map=device,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [13]:
# LoRA adapters (rank, scaling alpha, dropout) injected into every module named in target_modules.
lora_config = LoraConfig(
    target_modules=target_modules,
    r=rank,
    lora_alpha=alpha,
    lora_dropout=dropout,
)

In [14]:
# Wrap the quantized base model with LoRA adapters. autocast_adapter_dtype=False stops PEFT from
# upcasting bf16/fp16 adapter weights to fp32 (torchtune uses bf16).
model = get_peft_model(
    model,
    lora_config,
    autocast_adapter_dtype=False,
)
model.print_trainable_parameters()

trainable params: 20,971,520 || all params: 8,051,232,768 || trainable%: 0.2605


In [15]:
# Same state-dict post-hook torchtune registers on its model; it runs whenever
# `model.state_dict()` is called (e.g. when saving a checkpoint).
# NOTE(review): torchtune's hook is written for torchao NF4 tensors; it presumably leaves
# bitsandbytes 4-bit weights untouched — verify before relying on it for checkpoint export.
state_dict_post_hook = partial(
    reparametrize_as_dtype_state_dict_post_hook, offload_to_cpu=True
)
model._register_state_dict_hook(state_dict_post_hook)

<torch.utils.hooks.RemovableHandle at 0x793c05f0ab50>

In [16]:
# AdamW over the trainable (LoRA adapter) parameters only. The frozen 4-bit base weights never
# receive gradients, so AdamW would skip them anyway. Filtering them out keeps the optimizer's
# param groups, and its state_dict, limited to what is actually trained.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, weight_decay=weight_decay, lr=lr)

In [17]:
# Token-level cross-entropy that skips label -100 (the ignore_idx used by the collate function).
# Kept for parity with the torchtune recipe. The Hugging Face model used in the training loop
# computes its own loss from `labels`, so this object is stored but not called there.
criterion = nn.CrossEntropyLoss(ignore_index=-100)

In [18]:
# Optimizer steps per epoch: one step per `gradient_accumulation_steps` micro-batches, capped by
# `max_steps_per_epoch` when it is set. The training loop stops each epoch at that cap, so the
# cosine schedule (and the progress bar total) must be sized to match.
steps_per_epoch = len(dataloader) // gradient_accumulation_steps
if max_steps_per_epoch is not None and max_steps_per_epoch < steps_per_epoch:
    steps_per_epoch = max_steps_per_epoch
num_training_steps = total_epochs * steps_per_epoch
lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
print(f"{total_epochs=}, {steps_per_epoch=}, {num_training_steps=}")

total_epochs=1, steps_per_epoch=6470, num_training_steps=6470


In [19]:
# Per-step metrics (loss, lr, throughput, memory) are appended to a timestamped log file here.
log_dir = "/tmp/peft/llama3-8b-qlora-4bit"
metric_logger = DiskLogger(log_dir)

Writing logs to /tmp/peft/llama3-8b-qlora-4bit/log_1724160813.txt


## training

### Emulate the recipe's `self` object

In [20]:
# Stand-in for the recipe object. The training loop below is copied from a torchtune recipe
# method and reads all of its state through `self.*`, so the same attribute names are exposed.
self = SimpleNamespace(
    epochs_run=0,
    global_step=0,
    total_epochs=total_epochs,
    _sampler=sampler,
    _steps_per_epoch=steps_per_epoch,
    _dataloader=dataloader,
    _gradient_accumulation_steps=gradient_accumulation_steps,
    max_steps_per_epoch=max_steps_per_epoch,
    # Use the same GPU index the model was loaded onto (device_map=device), so input batches
    # and weights stay on one device if `device` is changed in the params cell.
    _device=torch.device("cuda", device),
    _model=model,
    _loss_fn=criterion,
    _optimizer=optimizer,
    _log_every_n_steps=log_every_n_steps,
    _log_peak_memory_stats=True,
    _lr_scheduler=lr_scheduler,
    _metric_logger=metric_logger,
)

### train loop

Adapted from the following recipe (the forward pass and loss are replaced by the Hugging Face/PEFT model call):

https://github.com/pytorch/torchtune/blob/bc6b7e9132542e2f6d47d28fab338d42f9b2242d/recipes/lora_dpo_single_device.py#L479

The run was interrupted early because the loss was no longer changing much.

In [21]:
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0

# `from_pretrained` returns the model in eval mode; switch to train mode so that dropout
# layers (including the LoRA dropout) are active while training.
self._model.train()

for curr_epoch in range(self.epochs_run, self.total_epochs):
    # Update the sampler to ensure data is correctly shuffled across epochs
    # in case shuffle is True
    self._sampler.set_epoch(curr_epoch)

    # The context manager closes the progress bar even when the loop is interrupted
    # (e.g. KeyboardInterrupt), instead of leaving cleanup to the garbage collector.
    with tqdm(total=self._steps_per_epoch) as pbar:
        for idx, batch in enumerate(self._dataloader):
            # Stop once `max_steps_per_epoch` optimizer steps have been taken, if a cap is set
            if (
                self.max_steps_per_epoch is not None
                and (idx // self._gradient_accumulation_steps)
                == self.max_steps_per_epoch
            ):
                break

            # Both are shape [b, s]
            tokens, labels = batch["tokens"], batch["labels"]
            # Get the attention mask and position ids from the dataset if they
            # exist. Currently, only sample packing in PackedDataset returns these
            mask = batch.get("mask", None)  # shape [b, s, s]
            input_pos = batch.get("input_pos", None)  # shape [b, s]

            tokens = tokens.to(self._device)
            num_tokens += tokens.numel()
            labels = labels.to(self._device)
            mask = mask.to(self._device) if mask is not None else None
            input_pos = (
                input_pos.to(self._device) if input_pos is not None else None
            )

            # Forward through the Hugging Face (PEFT-wrapped) model, which computes the
            # language-modeling loss from `labels` itself (label -100 is ignored).
            # NOTE(review): only the unpacked path (mask/input_pos are None) is exercised here;
            # a packed [b, s, s] mask is not the [b, s] padding mask HF expects for
            # `attention_mask`, and `input_pos` is not forwarded to the model.
            loss = self._model(tokens, attention_mask=mask, labels=labels).loss

            loss = loss / self._gradient_accumulation_steps
            # Accumulate a detached value for logging so the running total does not keep
            # references to every micro-batch's autograd graph.
            running_loss += loss.detach()
            loss.backward()

            # Step with optimizer
            if (idx + 1) % self._gradient_accumulation_steps == 0:
                self._optimizer.step()
                self._optimizer.zero_grad(set_to_none=True)
                self._lr_scheduler.step()
                # Update the number of steps when the weights are updated
                self.global_step += 1

                loss_to_log = running_loss.item()
                pbar.update(1)
                pbar.set_description(
                    f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                )

                # Log per-step metrics
                if self.global_step % self._log_every_n_steps == 0:
                    time_per_step = time.perf_counter() - t0
                    log_dict = {
                        "loss": loss_to_log,
                        "lr": self._optimizer.param_groups[0]["lr"],
                        "tokens_per_second_per_gpu": num_tokens / time_per_step,
                    }
                    if (
                        self._device.type == "cuda"
                        and self._log_peak_memory_stats
                    ):
                        log_dict.update(
                            utils.get_memory_stats(device=self._device)
                        )
                    self._metric_logger.log_dict(
                        log_dict,
                        step=self.global_step,
                    )

                # Reset running stats for the next step
                running_loss = 0
                num_tokens = 0
                t0 = time.perf_counter()

    self.epochs_run += 1
    # Checkpointing (save_checkpoint in the torchtune recipe) is not emulated here.

  0%|          | 0/6470 [00:00<?, ?it/s]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
Exception ignored in: <function tqdm.__del__ at 0x793d585d3f60>
Traceback (most recent call last):
  File "/home/vinh/anaconda3/envs/peft/lib/python3.11/site-packages/tqdm/std.py", line 1147, in __del__
    def __del__(self):

KeyboardInterrupt: 

KeyboardInterrupt

