## LLM Supervised Finetuning

In this notebook, we will efficiently finetune a pretrained large language model (LLM) on an instruction-following dataset. Specifically, we will finetune OLMo-1B with LoRA on the Alpaca dataset, a collection of instruction-response pairs.

In [None]:
!pip install datasets

In [None]:
import math
import torch
import torch.nn as nn

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

### Part I: Loading and prompting the base model

We will start with loading [OLMo](https://huggingface.co/allenai/OLMo-1B-0724-hf), an open-source English-only pretrained language model.

In [None]:
model_name = "allenai/OLMo-1B-0724-hf"
# For efficiency, we will use a small max length
model_max_length = 128

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

In [None]:
@torch.no_grad()
def run_inference(prompt, max_new_tokens=100):
  inputs = tokenizer([prompt], return_tensors='pt', return_token_type_ids=False)
  # Move the inputs to the model's device (works on both CPU and GPU)
  inputs = {k: v.to(device) for k, v in inputs.items()}
  input_len = inputs["input_ids"].shape[1]
  output = model.generate(
      **inputs,
      max_new_tokens=max_new_tokens,
      # We sample from the model's output distribution, so rerun the
      # prompts to see how the outputs vary.
      do_sample=True,
      top_k=50,
      top_p=0.95,
      # Uncomment to silence the warning generate() emits when it has to
      # fall back to eos_token_id as the padding token id.
      # pad_token_id=tokenizer.eos_token_id,
  )
  # The output contains the prompt tokens followed by the generated ones.
  # To return only the generated continuation, uncomment the next line.
  # output = output[:, input_len:]
  return tokenizer.batch_decode(output, skip_special_tokens=True)[0]
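
The `top_k=50` and `top_p=0.95` arguments in `run_inference` restrict sampling to a subset of the vocabulary. The sketch below is a simplified, single-vector illustration of that filtering (top-k first, then nucleus/top-p on what remains), written from scratch; it is not the `transformers` implementation, and `top_k_top_p_filter` is a hypothetical helper.

```python
import torch

def top_k_top_p_filter(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.95) -> torch.Tensor:
    """Return a copy of a 1-D logits vector where filtered-out tokens are set to -inf."""
    logits = logits.clone()
    # Top-k: keep only the k highest-scoring tokens (ties with the k-th value are kept)
    if 0 < top_k < logits.numel():
        kth_value = torch.topk(logits, top_k).values[-1]
        logits[logits < kth_value] = float("-inf")
    # Top-p: keep the smallest set of most likely tokens whose probability mass reaches top_p
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    sorted_probs = torch.softmax(sorted_logits, dim=-1)
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs  # mass of strictly more likely tokens
    remove = mass_before >= top_p  # the most likely token always survives
    logits[sorted_idx[remove]] = float("-inf")
    return logits

# Usage: filter, renormalize, and sample one token id
example_logits = torch.randn(10)
probs = torch.softmax(top_k_top_p_filter(example_logits, top_k=5, top_p=0.9), dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```

Lower `top_k` or `top_p` values make sampling more conservative; `top_k=0` disables the top-k stage in this sketch.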

In [None]:
print(run_inference("Language modeling is "))

### Part II: Adding LoRA layers


Now, we will implement [Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685). LoRA is an efficient finetuning method: the pretrained weights are frozen, and trainable low-rank matrices are added in parallel to them to adapt the model to a specific task.
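
In the notation of the LoRA paper, a frozen weight $W_0 \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}}$ is adapted by a low-rank product $BA$, scaled by $\alpha / r$ (the `alpha / r` factor mentioned in the `LoRALinear` stub). The paper initializes $A$ randomly and $B$ to zero, so training starts from the unmodified pretrained layer, and each adapter adds far fewer trainable parameters than the matrix it adapts:

$$
h = W_0 x + \Delta W x = W_0 x + \frac{\alpha}{r} B A x,
\qquad B \in \mathbb{R}^{d_{\text{out}} \times r},\quad A \in \mathbb{R}^{r \times d_{\text{in}}},\quad r \ll \min(d_{\text{in}}, d_{\text{out}})
$$

$$
\underbrace{r\,(d_{\text{in}} + d_{\text{out}})}_{\text{trainable}} \quad \text{vs.} \quad \underbrace{d_{\text{in}}\, d_{\text{out}}}_{\text{frozen}}
$$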

In [None]:
class LoRALinear(nn.Module):
    def __init__(
        self,
        pretrained: nn.Linear,
        r: int = 4,
        alpha: float = 1.0,
        dropout: float = 0.0,
    ):
        """
        LoRA-enhanced linear layer.

        Args:
            pretrained (nn.Linear):
                The pretrained linear layer to adapt with LoRA.
            r (int):
                The low-rank dimension for the adapter matrices.
            alpha (float):
                The overall scaling factor for the LoRA update.
            dropout (float):
                Optional dropout to apply to the intermediate activations
                in LoRA (default: 0.0).
        """
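        super().__init__()
        # Keep the frozen pretrained layer and create the adapters as nn.Linear
        # layers (typically without bias) named `lora_down` (in_features -> r)
        # and `lora_up` (r -> out_features); freeze_except_lora below relies
        # on these names.
        raise NotImplementedError("Implement me!")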


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the base output from the frozen pretrained layer

        # Compute the LoRA adaptation (inputs are (batch_size, seq_len, in_features)):
        # down: (..., in_features) -> (..., r)
        # up:   (..., r) -> (..., out_features)
        # scale by alpha / r

        # Combine frozen linear output + LoRA adaptation
        raise NotImplementedError("Implement me!")
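
As a framework-free illustration of the computation described in the stub (not a `LoRALinear` implementation), the sketch below writes the LoRA forward pass with plain tensors, using hypothetical helpers `lora_forward` and `merge_lora`. It checks two properties: with `B` initialized to zero the layer reproduces the pretrained layer, and after training the update can be folded into `W0` so inference costs no extra matmuls.

```python
import torch

def lora_forward(x, W0, b0, A, B, alpha, r):
    """LoRA forward pass with plain tensors, following nn.Linear's y = x W^T + b convention.
    x: (..., d_in), W0: (d_out, d_in), b0: (d_out,), A: (r, d_in), B: (d_out, r)."""
    base = x @ W0.T + b0          # frozen pretrained projection
    update = (x @ A.T) @ B.T      # down-project to r, then up-project to d_out
    return base + (alpha / r) * update

def merge_lora(W0, A, B, alpha, r):
    """Fold the scaled low-rank update into a single weight matrix."""
    return W0 + (alpha / r) * (B @ A)

torch.manual_seed(0)
d_in, d_out, r, alpha = 8, 6, 2, 4.0
x = torch.randn(3, 5, d_in)       # (batch, seq_len, d_in)
W0, b0 = torch.randn(d_out, d_in), torch.randn(d_out)
A = torch.randn(r, d_in) * 0.01   # small random init
B = torch.zeros(d_out, r)         # zero init: no change at the start

# With B = 0 the adapted layer matches the pretrained one exactly
assert torch.allclose(lora_forward(x, W0, b0, A, B, alpha, r), x @ W0.T + b0)

# Once B is nonzero (here: random), the merged weight gives the same output
B = torch.randn(d_out, r)
merged = merge_lora(W0, A, B, alpha, r)
assert torch.allclose(lora_forward(x, W0, b0, A, B, alpha, r), x @ merged.T + b0, atol=1e-5)
```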

After defining the LoRA layer, we will implement a function that replaces linear layers with LoRA layers.

Following the original paper, we will apply LoRA to the matrices involved in attention.
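
Before swapping layers, it can help to check which submodules a suffix match will actually hit. The sketch below uses `named_modules()`, which yields dotted, fully qualified names, on a small hypothetical toy model; `find_lora_targets`, `ToyAttention`, and `ToyBlock` are illustrative names. On the real model you would call `find_lora_targets(model, lora_suffixes)` before `replace_linear_with_lora`, since afterwards the matched layers are no longer plain `nn.Linear` modules.

```python
import torch.nn as nn

def find_lora_targets(module: nn.Module, suffixes: list) -> list:
    """Qualified names of nn.Linear submodules whose name ends with one of `suffixes`."""
    return [
        name for name, sub in module.named_modules()
        if isinstance(sub, nn.Linear) and name.endswith(tuple(suffixes))
    ]

# Hypothetical toy model with attention-style projection names
class ToyAttention(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.q_proj, self.k_proj = nn.Linear(d, d), nn.Linear(d, d)
        self.v_proj, self.o_proj = nn.Linear(d, d), nn.Linear(d, d)

class ToyBlock(nn.Module):
    def __init__(self, d=8):
        super().__init__()
        self.self_attn = ToyAttention(d)
        self.mlp = nn.Linear(d, d)  # not an attention projection, so never matched

toy = nn.Sequential(ToyBlock(), ToyBlock())
print(find_lora_targets(toy, ["q_proj", "v_proj"]))
# ['0.self_attn.q_proj', '0.self_attn.v_proj', '1.self_attn.q_proj', '1.self_attn.v_proj']
```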


In [None]:
import torch
import torch.nn as nn

def replace_linear_with_lora(
    module: nn.Module,
    suffixes: list,
    *,
    r: int = 4,
    alpha: float = 1.0,
    dropout: float = 0.0,
):
    """
    Recursively traverses `module` and replaces all nn.Linear layers
    whose names end with one of the given `suffixes` with LoRALinear.

    Args:
        module (nn.Module):
            The PyTorch module to modify in-place.
        suffixes (list of str):
            List of suffixes to match against submodule names. If a submodule's
            name ends with any of these suffixes and is an nn.Linear, it will
            be replaced.
        r (int):
            Low-rank dimension for LoRA.
        alpha (float):
            Scaling factor for LoRA.
        dropout (float):
            Dropout probability to apply between LoRA down and up.

    Returns:
        nn.Module:
            The model with LoRA layers replacing the targeted Linear submodules.
    """
    for child_name, child_module in module.named_children():
        # If this is a Linear with a matching suffix, replace it with LoRALinear
        if isinstance(child_module, nn.Linear) and any(child_name.endswith(sfx) for sfx in suffixes):
            lora_module = LoRALinear(
                pretrained=child_module,
                r=r,
                alpha=alpha,
                dropout=dropout,
            )
            setattr(module, child_name, lora_module)
        else:
            # Recurse into children
            replace_linear_with_lora(
                child_module,
                suffixes,
                r=r,
                alpha=alpha,
                dropout=dropout,
            )

    return module

# Following the paper, we only adapt the attention projections (query, key, value, output).
lora_suffixes = ["q_proj", "k_proj", "v_proj", "o_proj"]
model = replace_linear_with_lora(
    model, lora_suffixes, r=16, alpha=32, dropout=0.05,
)
model

Finally, since we only want to train the LoRA adapters, we will define a method to freeze (setting requires_grad = False) all other parameters.
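
The trainable parameter count printed below can be estimated by hand: each adapter adds $r(d_{\text{in}} + d_{\text{out}})$ parameters. The sketch assumes bias-free adapters, square attention projections, and a hidden size of 2048 with 16 layers for OLMo-1B; these dimensions are assumptions to check against `model.config` (e.g. `hidden_size`, `num_hidden_layers`), and `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters of one bias-free adapter: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)

# Assumed dimensions for illustration; read the real ones from model.config
hidden_size, num_layers, r = 2048, 16, 16
projections_per_layer = 4  # q_proj, k_proj, v_proj, o_proj, each assumed hidden_size x hidden_size

trainable = num_layers * projections_per_layer * lora_param_count(hidden_size, hidden_size, r)
print(trainable)  # 4194304
# Fraction of a single projection matrix that one adapter adds
print(lora_param_count(hidden_size, hidden_size, r) / (hidden_size * hidden_size))  # 0.015625
```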

In [None]:
def freeze_except_lora(model: nn.Module):
    """
    Sets `requires_grad = False` for all parameters in `model`,
    except for those belonging to LoRA adapters (lora_down, lora_up).

    This ensures that only the low-rank LoRA layers are trained.

    Args:
        model (nn.Module):
            A PyTorch model that may contain LoRALinear submodules.
    """
    # First, freeze everything
    for param in model.parameters():
        param.requires_grad = False

    # Then, unfreeze LoRA parameters
    for module in model.modules():
        # If you named your LoRA class differently, change LoRALinear accordingly
        if module.__class__.__name__ == "LoRALinear":
            for param in module.lora_down.parameters():
                param.requires_grad = True
            for param in module.lora_up.parameters():
                param.requires_grad = True

    return model

model = freeze_except_lora(model)
print("Total parameters:", sum(p.numel() for p in model.parameters()))
print("Trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

### Part III: Prepare the dataset

Now, we will prepare the dataset. We will train on the [Alpaca dataset](https://crfm.stanford.edu/2023/03/13/alpaca.html), an open-source instruction-tuning dataset.

The dataset has three fields:

* **instruction** - The user instruction given to the model;
* **input (optional)** - Additional context for the instruction, e.g. the paragraph for a task such as "Summarize the following paragraph";
* **output** - The target response the model is finetuned to produce.

In [None]:
dataset = load_dataset("tatsu-lab/alpaca")
dataset = dataset["train"]

We will now prepare the dataset for instruction tuning by formatting each record as:

`Instruction:\n\n{instruction}\n\nInput:\n\n{input}\n\nAnswer:\n\n`

where the Input part only appears if an input is provided.
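
A minimal sketch of this formatting step, assuming the Input block is laid out as `Input:\n\n{input}\n\n` and that a missing input arrives as an empty string (as in the Alpaca records). `format_alpaca_prompt` is a hypothetical name, separate from the `create_prompt` exercise below; it uses `input_text` to avoid shadowing Python's built-in `input`.

```python
from typing import Optional

def format_alpaca_prompt(instruction: str, input_text: Optional[str] = None) -> str:
    """Build the instruction prompt; the Input block is omitted when input_text is empty or None."""
    prompt = f"Instruction:\n\n{instruction}\n\n"
    if input_text:  # skips both None and ""
        prompt += f"Input:\n\n{input_text}\n\n"
    return prompt + "Answer:\n\n"

# Hypothetical example record
print(format_alpaca_prompt("Summarize the following paragraph.", "LoRA adds low-rank adapters to frozen weights."))
```

The target `output` is then appended after `Answer:\n\n` to form the full training text.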

In [None]:
def create_prompt(instruction, input=None):
  raise NotImplementedError("Implement me!")

def tokenize_text(record):
  # Get the instruction and input and create a prompt

  # Create a text with the prompt and target concatenated and tokenize it

  # Add bos and eos

  # Return the input ids and labels.
  # In this case, no need to shift the labels by one! HF will take care of it.
  raise NotImplementedError("Implement me!")

dataset = dataset.map(tokenize_text)
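
The comment about not shifting the labels refers to how Hugging Face causal LM models compute the loss when `labels` are passed: the logits at position t are scored against the token at position t+1, and positions labelled -100 are ignored. The sketch below reproduces that idea by hand on random tensors; it illustrates the computation rather than copying the library's code, and `shifted_lm_loss` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

def shifted_lm_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy. logits: (batch, seq_len, vocab), labels: (batch, seq_len)."""
    shift_logits = logits[:, :-1, :]  # predictions made at positions 0..T-2
    shift_labels = labels[:, 1:]      # targets are the tokens at positions 1..T-1
    return F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        ignore_index=-100,            # masked/padded positions contribute nothing
    )

torch.manual_seed(0)
vocab, T = 11, 6
logits = torch.randn(2, T, vocab)
labels = torch.randint(0, vocab, (2, T))
labels[:, 4:] = -100                  # pretend the last two positions are padding
print(shifted_lm_loss(logits, labels))
```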

To ensure no record exceeds the model's maximum length, we filter out records with more than `model_max_length` tokens. We could truncate them instead, but since the dataset is large, it is better to keep only the records that fit entirely.

In [None]:
# Alternatively, we could truncate each record to its first model_max_length
# tokens. Since the dataset is large, we instead discard the longer records,
# because truncation would leave incomplete texts (e.g. answers cut off mid-way).
dataset = dataset.filter(lambda x: len(x["input_ids"]) <= model_max_length)

We will pad each record to the model's maximum length so that records can be stacked into batches.
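
A minimal sketch of right-padding a single tokenized record with plain Python lists. `pad_record` and the token ids are hypothetical; in the notebook the pad id would come from `tokenizer.pad_token_id` (falling back to `tokenizer.eos_token_id` if the tokenizer defines no pad token is an assumption about your setup).

```python
def pad_record(input_ids: list, labels: list, max_length: int, pad_token_id: int) -> dict:
    """Right-pad a record to max_length: input_ids with pad_token_id, labels with -100."""
    n_pad = max_length - len(input_ids)
    assert n_pad >= 0, "record is longer than max_length"
    return {
        "input_ids": input_ids + [pad_token_id] * n_pad,
        "labels": labels + [-100] * n_pad,  # -100 is ignored by the loss
    }

# Hypothetical token ids and pad id, for illustration only
print(pad_record([5, 17, 42], [5, 17, 42], max_length=6, pad_token_id=1))
# {'input_ids': [5, 17, 42, 1, 1, 1], 'labels': [5, 17, 42, -100, -100, -100]}
```

Since only `input_ids` and `labels` are kept, no attention mask is passed during training. With right padding and causal attention this is harmless: real tokens only attend to earlier positions, which are never padding, and the -100 labels keep the padding out of the loss.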

In [None]:
def pad_to_max_length(record):
  # Pad input_ids with the tokenizer's pad token id up to model_max_length.
  # In the labels, we pad with -100, which tells the cross-entropy loss to
  # ignore these entries.
  raise NotImplementedError("Implement me!")

dataset = dataset.map(pad_to_max_length)

In [None]:
dataset = dataset.select_columns(["input_ids", "labels"])
dataset.set_format("torch")

In [None]:
print(dataset)
for record in dataset.select(range(1)):
  print(record["input_ids"])
  print(record["labels"])
  print(tokenizer.batch_decode(record["input_ids"], skip_special_tokens=False))
  print()

### Part IV: Finetuning the model

Now, we will train the model. Besides [mixed precision training](https://arxiv.org/abs/1710.03740), we will use gradient accumulation. This technique splits each batch into smaller micro-batches and accumulates (sums) their gradients before taking a single optimizer step. If each micro-batch loss is divided by the number of micro-batches, the accumulated gradient equals the gradient of the full-batch loss. This lets us train with a larger effective batch size than fits in GPU memory.

In [None]:
def train_step(
    *,
    model,
    optimizer,
    scaler,
    batch,
    micro_batch_size,
    device,
):
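  # Split the batch into micro-batches of size micro_batch_size

  # Iterate over the micro-batches:
  # 1. Run the forward pass under autocast with float16
  # 2. Divide the loss by the number of micro-batches, then scale it with
  #    the GradScaler before calling backward
  # 3. Don't call optimizer.step() until all micro-batches are processed

  # With gradient accumulation, we only step (scaler.step and scaler.update)
  # after all gradients are accumulated, and reset the gradients for the next batch.
  # Return the batch loss as a Python float for logging.
  raise NotImplementedError("Implement me!")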

def train(
    *,
    model,
    dataset,
    max_steps,
    total_batch_size,
    micro_batch_size,
    learning_rate,
    device,
    log_every=1,
):
  model.train()
  model.to(device)
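  # Only the LoRA parameters require gradients, so only pass those to the optimizer
  optimizer = torch.optim.Adam(
      [p for p in model.parameters() if p.requires_grad], lr=learning_rate
  )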
  scaler = torch.amp.GradScaler()

  dataloader = torch.utils.data.DataLoader(
      dataset, batch_size=total_batch_size, shuffle=True, pin_memory=True,
  )

  def make_dataiter():
    while True:
      for batch in dataloader:
        yield batch

  dataiter = make_dataiter()

  for step in range(1, max_steps + 1):
    batch = next(dataiter)

    loss = train_step(
        model=model,
        optimizer=optimizer,
        scaler=scaler,
        batch=batch,
        micro_batch_size=micro_batch_size,
        device=device,
    )
    if step % log_every == 0:
      progress_perc = (step / max_steps) * 100
      print(f"[Step {step}/{max_steps} ({progress_perc:.0f}%)] loss: {loss:.4f}")

  return model

train(
    model=model,
    dataset=dataset,
    total_batch_size=32,
    micro_batch_size=8,
    learning_rate=2e-4,
    max_steps=400,
    device=device,
)

In [None]:
model.eval()  # disable dropout (e.g. the LoRA dropout) before generating
print(run_inference(create_prompt("What are some common tourist places in Portugal?"), max_new_tokens=128))