## LLM Supervised Finetuning

In this notebook, we will efficiently finetune a pretrained large language model (LLM) on an instruction-following dataset. Specifically, we will finetune OLMo-1B using LoRA on the Alpaca dataset, a collection of instruction-response pairs.

In [None]:
!pip install datasets


In [None]:
import math
import torch
import torch.nn as nn

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

### Part I: Loading and prompting the base model

We will start with loading [OLMo](https://huggingface.co/allenai/OLMo-1B-0724-hf), an open-source English-only pretrained language model.

In [None]:
model_name = "allenai/OLMo-1B-0724-hf"
# For efficiency, we will use a small max length
model_max_length = 128

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


OlmoForCausalLM(
  (model): OlmoModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoDecoderLayer(
        (self_attn): OlmoSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): OlmoRotaryEmbedding()
        )
        (mlp): OlmoMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): OlmoLayerNorm()
        (post_attention_layernorm): OlmoLayerNorm()
      )
    )
    (norm): OlmoLayerNorm()
  )
  (

In [None]:
@torch.no_grad()
def run_inference(prompt, max_new_tokens=100):
  inputs = tokenizer([prompt], return_tensors='pt', return_token_type_ids=False)
  # Move the inputs to the same device as the model
  inputs = {k: v.to(device) for k, v in inputs.items()}
  input_len = inputs["input_ids"].shape[1]
  output = model.generate(
      **inputs,
      max_new_tokens=max_new_tokens,
      # We sample from the model's output distribution, so rerun the
      # prompts to see how much the outputs vary.
      do_sample=True,
      top_k=50,
      top_p=0.95,
      # Optionally use the EOS token as the pad token during generation
      # to remove the related generation warning.
      # pad_token_id=tokenizer.eos_token_id,
  )
  # Uncomment to drop the prompt tokens and return only the generated continuation
  # output = output[:, input_len:]
  return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

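With `do_sample=True, top_k=50, top_p=0.95`, `generate` draws each new token from a truncated version of the model's next-token distribution. The sketch below shows that filtering for a single decoding step in plain PyTorch. It is a simplified illustration, not the Hugging Face implementation; the function name `top_k_top_p_sample` is hypothetical, and here top-k is applied before top-p.

```python
import torch

def top_k_top_p_sample(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.95) -> torch.Tensor:
    """Sample one token id per row from next-token logits of shape (batch, vocab)."""
    # Top-k: mask every token scoring below the k-th highest logit.
    k = min(top_k, logits.size(-1))
    kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
    logits = logits.masked_fill(logits < kth_value, float("-inf"))

    # Top-p (nucleus): sort by probability and drop a token once the probability
    # mass of the tokens ranked above it already exceeds top_p.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs  # 0 for the top token, so it is always kept
    remove_sorted = mass_before > top_p
    # Map the removal mask from sorted order back to vocabulary order.
    remove = remove_sorted.scatter(-1, sorted_idx, remove_sorted)
    logits = logits.masked_fill(remove, float("-inf"))

    # Renormalize over the surviving tokens and draw one sample per row.
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)
```

Lower `top_k` or `top_p` values make the output more conservative; `top_k=1` reduces to greedy decoding.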
In [None]:
print(run_inference("Language modeling is "))

Language modeling is 
the problem of defining similarity relations between sentences in a document.  
Many methods have been proposed to model the similarity between sentences \cite{Gleison1999}.
Although sentence similarity measurement has attracted much attention, it has its own limitations.
One such limitation is that sentence similarity depends only on its content, but not the surrounding context 
or the style of the sentences. 
Some sentence characteristics, such as co-occurrence in a document, can indicate more about the 


### Part II: Adding LoRA layers


Now, we will implement [Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685). LoRA is an efficient finetuning method in which trainable low-rank matrices are added in parallel to the frozen model weights to adapt the pretrained model to a specific task.
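Written with the names used in the code below (frozen weight $W_0 \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$, `lora_down` as $A \in \mathbb{R}^{r \times d_{\text{in}}}$, `lora_up` as $B \in \mathbb{R}^{d_{\text{out}} \times r}$, and ignoring dropout), the adapted layer computes

$$
h = W_0 x + \frac{\alpha}{r}\, B A x .
$$

Because $B$ is initialized to zero, $h = W_0 x$ at the start of finetuning, so training begins from the pretrained model's behaviour. Only $A$ and $B$ are trained: $r(d_{\text{in}} + d_{\text{out}})$ parameters per layer instead of $d_{\text{in}} d_{\text{out}}$. For a $2048 \times 2048$ projection with $r = 16$, that is $65{,}536$ instead of $4{,}194{,}304$.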

In [None]:
class LoRALinear(nn.Module):
    def __init__(
        self,
        pretrained: nn.Linear,
        r: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        """
        LoRA-enhanced linear layer.

        Args:
            pretrained (nn.Linear):
                The pretrained linear layer to adapt with LoRA.
            r (int):
                The low-rank dimension for the adapter matrices.
            alpha (float):
                The overall scaling factor for the LoRA update.
            dropout (float):
                Optional dropout to apply to the intermediate activations
                in LoRA (default: 0.0).
        """
        super().__init__()

        self.pretrained = pretrained

        # Define LoRA down and up projection matrices
        self.lora_down = nn.Linear(pretrained.in_features, r, bias=False)
        self.lora_up = nn.Linear(r, pretrained.out_features, bias=False)

        # Initialize LoRA layers
        # Taken from the microsoft github:
        # https://github.com/microsoft/LoRA/blob/c4593f060e6a368d7bb5af5273b8e42810cdef90/loralib/layers.py#L124
        nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_up.weight)

        # Compute scaling factor: alpha / r
        # alpha is a hyperparameter controlling the overall scaling
        # r is the low-rank dimension
        self.scaling = alpha / r

        # Optional dropout
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the base output from the frozen pretrained layer
        base_out = self.pretrained(x)

        # Compute the LoRA adaptation (nn.Linear acts on the last dimension):
        # down: (..., in_features) -> (..., r)
        # up:   (..., r) -> (..., out_features)
        # then scale by alpha / r
        lora_out = self.lora_up(self.dropout(self.lora_down(x))) * self.scaling

        # Combine frozen linear output + LoRA adaptation
        return base_out + lora_out

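A quick sanity check of the zero initialization, written with bare `nn.Linear` layers so it runs on its own (the toy sizes are arbitrary, and dropout is left out). At initialization the adapter adds exactly nothing, and once `lora_up` is non-zero the layer behaves like a single linear map with weight $W_0 + \frac{\alpha}{r} B A$:

```python
import math

import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, d_out, r, alpha = 8, 6, 2, 4.0  # toy sizes, chosen arbitrarily

# Stand-in for a frozen pretrained projection (bias-free, like OLMo's attention projections)
pretrained = nn.Linear(d_in, d_out, bias=False)

# LoRA factors initialized as in LoRALinear: lora_down Kaiming-uniform, lora_up zeros
lora_down = nn.Linear(d_in, r, bias=False)
lora_up = nn.Linear(r, d_out, bias=False)
nn.init.kaiming_uniform_(lora_down.weight, a=math.sqrt(5))
nn.init.zeros_(lora_up.weight)
scaling = alpha / r

# nn.Linear acts on the last dimension, so (batch, seq_len, d_in) inputs work as-is
x = torch.randn(3, 5, d_in)

with torch.no_grad():
    # At initialization the adapter's contribution is zero
    out = pretrained(x) + lora_up(lora_down(x)) * scaling
    same_at_init = torch.allclose(out, pretrained(x))

    # Once lora_up is non-zero, the layer equals one linear map with
    # weight W0 + scaling * (lora_up.weight @ lora_down.weight)
    lora_up.weight.normal_()
    out = pretrained(x) + lora_up(lora_down(x)) * scaling
    merged_weight = pretrained.weight + scaling * lora_up.weight @ lora_down.weight
    matches_merged = torch.allclose(out, x @ merged_weight.T, atol=1e-5)

print(same_at_init, matches_merged)  # True True
```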
After defining the LoRA layer, we will implement a function that replaces linear layers with LoRA layers.

Following the original paper, which only adapts the attention weights, we will apply LoRA to the attention projection matrices (query, key, value and output).


In [None]:
import torch
import torch.nn as nn

def replace_linear_with_lora(
    module: nn.Module,
    suffixes: list,
    *,
    r: int = 4,
    alpha: float = 1.0,
    dropout: float = 0.0,
):
    """
    Recursively traverses `module` and replaces all nn.Linear layers
    whose names end with one of the given `suffixes` with LoRALinear.

    Args:
        module (nn.Module):
            The PyTorch module to modify in-place.
        suffixes (list of str):
            List of suffixes to match against submodule names. If a submodule's
            name ends with any of these suffixes and is an nn.Linear, it will
            be replaced.
        r (int):
            Low-rank dimension for LoRA.
        alpha (float):
            Scaling factor for LoRA.
        dropout (float):
            Dropout probability to apply between LoRA down and up.

    Returns:
        nn.Module:
            The model with LoRA layers replacing the targeted Linear submodules.
    """
    for child_name, child_module in module.named_children():
        # If this is a Linear with a matching suffix, replace it with LoRALinear
        if isinstance(child_module, nn.Linear) and any(child_name.endswith(sfx) for sfx in suffixes):
            lora_module = LoRALinear(
                pretrained=child_module,
                r=r,
                alpha=alpha,
                dropout=dropout,
            )
            setattr(module, child_name, lora_module)
        else:
            # Recurse into children
            replace_linear_with_lora(
                child_module,
                suffixes,
                r=r,
                alpha=alpha,
                dropout=dropout,
            )

    return module

# We apply LoRA to the attention matrices, as recommended by the paper authors.
lora_suffixes = ["q_proj", "k_proj", "v_proj", "o_proj"]
model = replace_linear_with_lora(
    model, lora_suffixes, r=16, alpha=32, dropout=0.05,
)
model

OlmoForCausalLM(
  (model): OlmoModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoDecoderLayer(
        (self_attn): OlmoSdpaAttention(
          (q_proj): LoRALinear(
            (pretrained): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_down): Linear(in_features=2048, out_features=16, bias=False)
            (lora_up): Linear(in_features=16, out_features=2048, bias=False)
            (dropout): Dropout(p=0.05, inplace=False)
          )
          (k_proj): LoRALinear(
            (pretrained): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_down): Linear(in_features=2048, out_features=16, bias=False)
            (lora_up): Linear(in_features=16, out_features=2048, bias=False)
            (dropout): Dropout(p=0.05, inplace=False)
          )
          (v_proj): LoRALinear(
            (pretrained): Linear(in_features=2048, out_features=2048, bias=False)
            (l

Finally, since we only want to train the LoRA adapters, we will define a function that freezes all other parameters (i.e., sets `requires_grad = False`).

In [None]:
def freeze_except_lora(model: nn.Module):
    """
    Sets `requires_grad = False` for all parameters in `model`,
    except for those belonging to LoRA adapters (lora_down, lora_up).

    This ensures that only the low-rank LoRA layers are trained.

    Args:
        model (nn.Module):
            A PyTorch model that may contain LoRALinear submodules.
    """
    # First, freeze everything
    for param in model.parameters():
        param.requires_grad = False

    # Then, unfreeze LoRA parameters
    for module in model.modules():
        # If you named your LoRA class differently, change LoRALinear accordingly
        if module.__class__.__name__ == "LoRALinear":
            for param in module.lora_down.parameters():
                param.requires_grad = True
            for param in module.lora_up.parameters():
                param.requires_grad = True

    return model

model = freeze_except_lora(model)
print("Total parameters:", sum(p.numel() for p in model.parameters()))
print("Trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

Total parameters: 1283981312
Trainable parameters: 4194304

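The trainable count can be checked by hand: each of the 16 decoder layers has 4 adapted attention projections, and each adapter holds a `lora_down` matrix ($2048 \times 16$) and a `lora_up` matrix ($16 \times 2048$). The total printed above includes these adapter weights.

```python
num_layers = 16             # OlmoDecoderLayer blocks
num_adapted_per_layer = 4   # q_proj, k_proj, v_proj, o_proj
d_model = 2048
r = 16

# Each adapter: lora_down (d_model -> r) plus lora_up (r -> d_model)
params_per_adapter = d_model * r + r * d_model
trainable = num_layers * num_adapted_per_layer * params_per_adapter
total = 1_283_981_312  # "Total parameters" printed above

print(trainable)                          # 4194304
print(f"{100 * trainable / total:.2f}%")  # 0.33%
```

So roughly a third of a percent of the weights receive gradients and optimizer state.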

### Part III: Prepare the dataset

Now, we will prepare the dataset. We will train on the [alpaca dataset](https://crfm.stanford.edu/2023/03/13/alpaca.html), an open-source instruction tuning dataset.

The dataset has three fields:

* **instruction** - The instruction given to the model;
* **input (optional)** - Additional context for the task, e.g. the paragraph for an instruction such as "Summarize the following paragraph";
* **output** - The target answer the model is finetuned to produce.

In [None]:
dataset = load_dataset("tatsu-lab/alpaca")
dataset = dataset["train"]


We will now prepare the dataset for instruction tuning by formatting each record as:

`Instruction:\n\n{instruction}\n\nInput:\n\n{input}\n\nAnswer:\n\n`

where the Input part only appears if an input is provided, and the answer (`output`) is appended after `Answer:\n\n`.
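To make the template concrete, here is a minimal sketch that renders it for a made-up record. The helper name `format_example` and the example text are hypothetical, and the sketch assumes the input appears once, under its own `Input:` header:

```python
def format_example(instruction: str, input_text: str = "") -> str:
    """Hypothetical helper mirroring the prompt template described above."""
    if not input_text:
        # No input: the Input section is omitted entirely
        return f"Instruction:\n\n{instruction}\n\nAnswer:\n\n"
    return f"Instruction:\n\n{instruction}\n\nInput:\n\n{input_text}\n\nAnswer:\n\n"

# Made-up record, for illustration only
prompt = format_example("Summarize the following sentence.", "The cat sat on the mat.")
target = "A cat sat on a mat."
print(prompt + target)
```

During training the model sees the prompt followed by the target; at inference time we pass only the prompt and let the model continue after `Answer:`.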

In [None]:
def create_prompt(instruction, input_text=None):
  # Only include the Input section when the record has a non-empty input
  if not input_text:
    return f"Instruction:\n\n{instruction}\n\nAnswer:\n\n"
  return f"Instruction:\n\n{instruction}\n\nInput:\n\n{input_text}\n\nAnswer:\n\n"

def tokenize_text(record):
  # Get the instruction and input and create a prompt
  instruction = record["instruction"].strip()
  input_text = record["input"].strip()
  prompt = create_prompt(instruction, input_text)
  target = record["output"].strip()
  # Concatenate the prompt and target, then tokenize the full text
  text = f"{prompt}{target}"
  input_ids = tokenizer(text)["input_ids"]
  # Append the EOS token so the model learns when to stop generating
  input_ids = input_ids + [tokenizer.eos_token_id]

  # The labels are a copy of the input ids (the loss covers prompt and answer tokens).
  # No need to shift the labels by one: HF does this inside the model.
  labels = list(input_ids)
  return {"input_ids": input_ids, "labels": labels}

dataset = dataset.map(tokenize_text)


To ensure that no record exceeds the model's maximum length, we filter out the records with more tokens. We could truncate these records instead, but since the dataset is large, it is better to keep only the short ones.

In [None]:
# Alternatively, we could truncate each text to its first model_max_length tokens.
# However, truncation would cut off the end of the answer (and the EOS token), and
# since we have many records, we simply discard the longer ones.
dataset = dataset.filter(lambda x: len(x["input_ids"]) <= model_max_length)


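We will pad every record to the model's maximum length so that records can be stacked into batches. Since the padding is added on the right and the model is causal, the real tokens never attend to the padding tokens, so we do not pass an attention mask during training.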

In [None]:
def pad_to_max_length(record):
  pad_len = model_max_length - len(record["input_ids"])
  record["input_ids"] = record["input_ids"] + [tokenizer.pad_token_id] * pad_len
  # In the labels, we pad with -100, which tells the cross-entropy loss
  # to ignore these entries.
  record["labels"] = record["labels"] + [-100] * pad_len
  assert len(record["input_ids"]) == model_max_length
  return record

dataset = dataset.map(pad_to_max_length)


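The labels rely on two conventions: the model shifts them internally so that the logits at position $t$ are scored against the token at position $t + 1$, and label `-100` is skipped by the loss. A small sketch of that computation with made-up logits, which to our understanding mirrors what Hugging Face causal-LM models do when `labels` is passed:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, seq_len = 10, 6
logits = torch.randn(1, seq_len, vocab_size)       # model output: (batch, seq, vocab)
labels = torch.tensor([[4, 2, 7, 1, -100, -100]])  # last two positions are padding

# Shift: position t predicts token t + 1, so drop the last logit and the first label
shift_logits = logits[:, :-1, :].reshape(-1, vocab_size)
shift_labels = labels[:, 1:].reshape(-1)
# Mean cross-entropy over the targets that are not -100
loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

# The same value by hand: average negative log-likelihood of targets 2, 7 and 1
log_probs = F.log_softmax(logits[0], dim=-1)
manual = -(log_probs[0, 2] + log_probs[1, 7] + log_probs[2, 1]) / 3
print(torch.allclose(loss, manual))  # True
```

In our records, the token just before EOS is trained to predict EOS, while the position of EOS itself predicts a `-100` label and contributes nothing.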
In [None]:
dataset = dataset.select_columns(["input_ids", "labels"])
dataset.set_format("torch")

In [None]:
print(dataset)
for record in dataset.select(range(1)):
  print(record["input_ids"])
  print(record["labels"])
  print(tokenizer.batch_decode(record["input_ids"], skip_special_tokens=False))
  print()

Dataset({
    features: ['input_ids', 'labels'],
    num_rows: 41015
})
tensor([10548,  2705,    27,   187,   187, 19735,  1264, 12192,   323, 14596,
         5875,    15,   187,   187, 32869,    27,   187,   187,    18,    15,
           38,   255,   247, 16645,  6196,   285,  1056,  2119,   281,  2486,
         9828,   273, 18098,   285, 15737,    15,   209,   187,    19,    15,
        40626, 11719,   281,  1978,   634,  2133,  3939,   285,  2266,    15,
          209,   187,    20,    15,  5057,  2217,  4600,   285,  6558,   247,
         5185,  4600, 10130,    15, 50279,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,

### Part IV: Finetuning the model

Now, we will train the model. Besides [mixed precision training](https://arxiv.org/abs/1710.03740), we will use gradient accumulation. This technique splits each batch into smaller micro-batches and accumulates (sums) their gradients before taking a single optimizer step. This enables training with a larger effective batch size than would fit in GPU memory.
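For the accumulated gradient to match a single pass over the full batch, each micro-batch loss (a mean) should be divided by the number of micro-batches. A toy check with a made-up linear model and random data:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = nn.MSELoss()  # mean over the batch

# Reference: one backward pass over the full batch of 8
toy_model.zero_grad()
loss_fn(toy_model(x), y).backward()
full_grad = toy_model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2, each loss divided by the number of micro-batches
toy_model.zero_grad()
num_micro_batches = 4
for xb, yb in zip(x.chunk(num_micro_batches), y.chunk(num_micro_batches)):
    # .grad is summed across backward calls until it is zeroed
    (loss_fn(toy_model(xb), yb) / num_micro_batches).backward()
accum_grad = toy_model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```

With padded language-model batches, the number of non-ignored label tokens differs between micro-batches, so the accumulated gradient is only approximately equal to the full-batch one.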

In [None]:
def train_step(
    *,
    model,
    optimizer,
    scaler,
    batch,
    micro_batch_size,
    device,
):
  input_ids = batch["input_ids"]
  labels = batch["labels"]

  batch_size = input_ids.shape[0]
  assert batch_size % micro_batch_size == 0, "Batch size must be divisible by micro batch size"
  num_micro_batches = batch_size // micro_batch_size

  # Split batch into smaller tensors
  input_ids_chunks = torch.split(input_ids, micro_batch_size)
  labels_chunks = torch.split(labels, micro_batch_size)

  # Iterate over micro-batches:
  # 1. Use mixed precision with float16
  # 2. Scale the loss before calling backward
  # 3. Don't call optimizer.step() until all micro-batches are processed
  total_loss = 0.0
  for i in range(num_micro_batches):
    # Perform computations with float16
    with torch.autocast(device_type=device, dtype=torch.float16):
      outputs = model(
          input_ids=input_ids_chunks[i].to(device),
          labels=labels_chunks[i].to(device),
      )
    # Divide by the number of micro-batches so the accumulated gradient is the
    # average over the full batch rather than the sum.
    loss = outputs.loss / num_micro_batches
    # Call backward multiple times to accumulate gradients over micro-batches.
    scaler.scale(loss).backward()
    total_loss += loss.item()

  # With gradient accumulation, we only step after we accumulate all gradients.
  scaler.step(optimizer)
  scaler.update()
  optimizer.zero_grad()
  # Mean loss over the micro-batches of this step
  return total_loss

def train(
    *,
    model,
    dataset,
    max_steps,
    total_batch_size,
    micro_batch_size,
    learning_rate,
    device,
    log_every=1,
):
  model.train()
  model.to(device)
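  # Only the LoRA parameters require gradients, so only pass those to the optimizer
  optimizer = torch.optim.Adam(
      [p for p in model.parameters() if p.requires_grad], lr=learning_rate
  )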
  scaler = torch.amp.GradScaler()

  dataloader = torch.utils.data.DataLoader(
      dataset, batch_size=total_batch_size, shuffle=True, pin_memory=True,
  )

  def make_dataiter():
    while True:
      for batch in dataloader:
        yield batch

  dataiter = make_dataiter()

  for step in range(1, max_steps + 1):
    batch = next(dataiter)

    loss = train_step(
        model=model,
        optimizer=optimizer,
        scaler=scaler,
        batch=batch,
        micro_batch_size=micro_batch_size,
        device=device,
    )
    if step % log_every == 0:
      progress = (step / max_steps) * 100
      print(f"[Step {step}/{max_steps} ({progress:.0f}%)] loss: {loss:.4f}")

  return model

train(
    model=model,
    dataset=dataset,
    total_batch_size=32,
    micro_batch_size=8,
    learning_rate=2e-4,
    max_steps=400,
    device=device,
)




In [None]:
print(run_inference(create_prompt("What are some common tourist places in Portugal?"), max_new_tokens=128))

Instruction:

What are some common tourist places in Portugal?

Answer:


Some common tourist places in Portugal include Lisbon, Porto, Sintra, and Algarve. Lisbon is known for its colonial architecture, Porto for its port wine, Sintra for its gorgeous mountain scenery, and the Algarve for its long beach.
