<h1 align="center" style="color:green;font-size: 3em;">
Implementing Fine-tuning Techniques</h1>


We implement fine-tuning methods described in the literature, specifically LoRA and IA3.

This notebook explores two IA3 implementations:
- by modifying activations
- by manipulating weights and biases

### Install Dependencies

In [1]:
%pip install datasets transformers -q

Note: you may need to restart the kernel to use updated packages.


### Import Libraries

In [1]:

# importing required libraries
import torch
import torch.nn as nn
import collections
import random
import numpy as np
import math
import matplotlib.pyplot as plt
import warnings

from torch.optim import AdamW
from typing import List
from torch.nn import functional as F
from tqdm import tqdm
from datasets import load_dataset, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    T5ForSequenceClassification,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)
from torch.utils.data import DataLoader

warnings.simplefilter("ignore")
print(torch.__version__)

2.6.0+cu124


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

### IA3 Memory Analysis

In this section, we analyze the memory cost of IA3 on Gemma (`google/gemma-2b`), a decoder-only language model from Google that is hosted on Hugging Face.

Gemma is a "gated model", so you must request access and authenticate before you can download it.

To read more about getting access to gated models, see the [Hugging Face documentation](https://huggingface.co/docs/hub/en/models-gated).

In [2]:
%pip install huggingface_hub -q

Note: you may need to restart the kernel to use updated packages.


In [3]:
from huggingface_hub import login
login()

In [4]:
## Load the Gemma model (requires authentication)
model_name = "google/gemma-2b"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


In [5]:
## Check the model architecture
model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): GELUActivation()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
      )
    )
    (norm): GemmaRMSNorm((2048,), 

### Implement IA3 with activations

Next, just as we integrated LoRA adapters into our OPT model, we incorporate IA3 into the Gemma model.

More about IA3 [here](https://arxiv.org/pdf/2205.05638).

IA3 can be implemented in two ways: one modifies activations (as traditional PEFT frameworks do), and the other adjusts weights and biases to align with the CLAM framework.

Modifying activations means that, during the forward pass, the outputs (also known as activations) of certain layers are rescaled element-wise by additional learned vectors.

Manipulating weights and biases, on the other hand, means folding those same scaling vectors into the layer's weight matrix and bias, so that a plain linear layer produces the same result as the activation-based version. Below is an implementation that modifies activations. Review the code carefully to understand the approach.

In [6]:
class IA3Adapter1(nn.Module):
    def __init__(self, existing_layer, in_features, out_features, ia3_lr, is_feedforward=False):
        super().__init__()
        # Use the `device` selected earlier instead of hard-coding "cuda"
        self.existing_layer = existing_layer.to(device)
        self.is_feedforward = is_feedforward
        self.in_features = in_features
        self.out_features = out_features

        # The trainable scaling vector, initialized to ones so the adapter
        # starts out as the identity. Attention projections scale their output
        # (out_features); the feed-forward projection scales its input (in_features).
        ia3_dim = in_features if is_feedforward else out_features
        self.ia3_lw = nn.Parameter(torch.ones((1, ia3_dim), device=device))

        self.ia3_lr = ia3_lr

    def forward(self, x: torch.Tensor):
        if self.is_feedforward:
            # Scale the input first, then apply the wrapped layer
            return self.existing_layer(torch.mul(x, self.ia3_lw))
        # Apply the wrapped layer first, then scale its output
        return torch.mul(self.existing_layer(x), self.ia3_lw)

In [7]:
def ia3_params(model: nn.Module) -> tuple[int, int]:
  """Freeze everything except the IA3 vectors; return (total, trainable) parameter counts."""
  total_params = 0
  trainable_params = 0
  for name, param in model.named_parameters():
    total_params += param.numel()
    # Enable gradients only for IA3 parameters and freeze all others
    is_ia3 = "ia3_lw" in name
    param.requires_grad = is_ia3
    if is_ia3:
      trainable_params += param.numel()
  return total_params, trainable_params

In [8]:
def match_submodules(model: nn.Module, key:str) -> List[str]:
  matching_layers = []
  for name, module in model.named_modules():
    if key in name:
      matching_layers.append(name)
  return matching_layers


def get_submodule(model: nn.Module, module_name:str):
    return model.get_submodule(module_name)


def replace_submodule(model: nn.Module, module_path: str, new_module):
  modules = module_path.split('.')
  parent_module = model
  for sub in modules[:-1]:
    parent_module = getattr(parent_module, sub)
  setattr(parent_module, modules[-1], new_module)


def inject_adapter(model: nn.Module, match_on: List[str], adapter_fn):
  processed_modules = set()
  for key in match_on:
    matching_layers = match_submodules(model, key)
    for module_path in matching_layers:
      if module_path in processed_modules:
        continue
      current_module = get_submodule(model, module_path)
      new_module = adapter_fn(current_module) # New IA3 module
      new_module = new_module.to(device)  # Move to gpu
      replace_submodule(model, module_path, new_module) # Replace
      processed_modules.add(module_path)

In [9]:
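# NOTE: is_feedforward receives a non-empty list, which is always truthy, so every
# adapter (k_proj and v_proj included) takes the feed-forward branch and gets an
# in_features-sized vector: 18 * (2048 + 2048 + 16384) = 368,640 trainable parameters.
inject_adapter(model, ["k_proj","v_proj","down_proj"], lambda x: IA3Adapter1(x, in_features=x.in_features, out_features=x.out_features, ia3_lr=1e-3, is_feedforward = ["False", "False", "True"]))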
total_params, trainable_params = ia3_params(model)

print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")

Total Parameters: 2506541056
Trainable Parameters: 368640
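
The trainable-parameter count above can be reproduced by hand from the layer shapes in the printed architecture. In the call above, `is_feedforward` receives a non-empty list, which Python treats as true, so each adapter allocates an `in_features`-sized vector. The sketch below redoes that arithmetic and, for comparison, shows the count if only `down_proj` scaled its input. That second setting is an assumption about the intended per-layer configuration, not something this notebook ran, and the helper name `ia3_vector_size` is hypothetical.

```python
# Layer shapes (in_features, out_features) taken from the printed Gemma architecture.
SHAPES = {"k_proj": (2048, 256), "v_proj": (2048, 256), "down_proj": (16384, 2048)}
NUM_BLOCKS = 18


def ia3_vector_size(shape, is_feedforward):
    """Hypothetical helper: length of the IA3 scaling vector for one layer."""
    in_features, out_features = shape
    return in_features if is_feedforward else out_features


# Every adapter on the feed-forward branch (what a truthy list selects).
all_ff = NUM_BLOCKS * sum(ia3_vector_size(s, True) for s in SHAPES.values())

# Assumed per-layer setting: only down_proj scales its input.
per_layer = NUM_BLOCKS * sum(
    ia3_vector_size(s, name == "down_proj") for name, s in SHAPES.items()
)

print(all_ff)     # 368640
print(per_layer)  # 304128
```

The first number matches the recorded output, which is why the memory analysis below uses 368,640 trainable parameters.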


In [10]:
# Check architecture after IA3 injection
model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): IA3Adapter1(
            (existing_layer): Linear(in_features=2048, out_features=256, bias=False)
          )
          (v_proj): IA3Adapter1(
            (existing_layer): Linear(in_features=2048, out_features=256, bias=False)
          )
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): IA3Adapter1(
            (existing_layer): Linear(in_features=16384, out_features=2048, bias=False)
          )
          (act_fn): GELUActivation()
        )
        (inp

### Memory analysis


Using the activations method, we now calculate the memory required (in GB) by each forward pass when IA3 is injected into the **key and value matrices** of **each transformer block** in Gemma and into the **down_proj** layer of each MLP.


First, calculate the memory per block, keeping the K/V calculation separate from the down_proj calculation.

Gemma's parameter weights are stored in Brain Float 16 (bfloat16), which takes 2 bytes (16 bits) per value. Gradients are stored in 4 bytes (32 bits). The optimizer state of the base model can be disregarded, but the additional trainable parameters must be accounted for. Throughout, 1 GB means 2^30 bytes.

Normal IA3 (IA3 Adapter 1)

`For K,V:`

x (batched): `[32,216,2048]` (batch size, sequence length, hidden dim) (not counted towards memory)

self.existing_layer(x)’s shape: `[32,216,256]`

torch.mul(result, self.ia3_lw) result’s shape: `[32,216,256]`

`Total` = 2 matrices (K and V) * 2 activations each * 32 (batch size) * 216 (sequence length) * 256 (output dim) * 16 bits per number / 8 bits per byte = `14,155,776 bytes`


`For down_proj:`

x (batched): `[32,216,16384]` (not counted towards memory)

torch.mul(x, self.ia3_lw): `[32, 216, 16384]`

self.existing_layer(result): `[32, 216, 2048]`

`Total` = (32 * 216 * 16384 + 32 * 216 * 2048) * 16 bits per number / 8 bits per byte = `254,803,968 bytes`


`Total activations size per block` = 14,155,776 bytes + 254,803,968 bytes = 268,959,744 bytes ≈ `0.2505 GB`


`Total IA3 activation size for model` = 268,959,744 bytes * 18 blocks = 4,841,275,392 bytes = `4.5087890625 GB`
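
The activation arithmetic above can be checked with a few lines of plain Python. This is a minimal sketch using the same batch size, sequence length, and layer widths as the text; the helper name `activation_bytes` is hypothetical, and 1 GB is taken as 2^30 bytes.

```python
import math

BATCH, SEQ_LEN = 32, 216
BYTES_PER_VALUE = 2      # bfloat16
NUM_BLOCKS = 18
GB = 2**30               # the analysis uses binary gigabytes


def activation_bytes(*feature_dims):
    """Hypothetical helper: bytes held by activations of shape [BATCH, SEQ_LEN, dim]."""
    return sum(math.prod((BATCH, SEQ_LEN, dim)) * BYTES_PER_VALUE for dim in feature_dims)


# K and V each keep the layer output and the rescaled output (256 wide).
kv_bytes = 2 * activation_bytes(256, 256)

# down_proj keeps the rescaled input (16384 wide) and the layer output (2048 wide).
down_proj_bytes = activation_bytes(16384, 2048)

per_block = kv_bytes + down_proj_bytes
model_total = per_block * NUM_BLOCKS

print(kv_bytes, down_proj_bytes, per_block)
print(per_block / GB, model_total / GB)
```

Multiplying the exact byte count by 18, rather than the rounded 0.25 GB, is what yields 4.5087890625 GB.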

We also introduce more trainable parameters, whose weights, gradients, and gradient moments must be stored for the optimizer. Assume we are using the Adam optimizer, which keeps 2 gradient moments per trainable parameter.

Total trainable parameters: 368,640

Param weights (bf16) = 368,640 params * 2 bytes = 737,280 bytes

Gradients (32 bit) = 368,640 params * 4 bytes = 1,474,560 bytes

Gradient moments (32 bit) = 368,640 params * 2 moments * 4 bytes = 2,949,120 bytes

`Total additional memory for params: 5,160,960 bytes ≈ 0.00480651855 GB`
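
The same bookkeeping for the trainable parameters, as a short sketch in plain Python. It assumes bf16 weights, 32-bit gradients, and two 32-bit Adam moments per parameter, exactly as stated above; the variable names are illustrative.

```python
TRAINABLE_PARAMS = 368_640
GB = 2**30

weight_bytes = TRAINABLE_PARAMS * 2        # bf16 weights
grad_bytes = TRAINABLE_PARAMS * 4          # 32-bit gradients
moment_bytes = TRAINABLE_PARAMS * 2 * 4    # two 32-bit Adam moments

param_total = weight_bytes + grad_bytes + moment_bytes

print(param_total)        # 5160960
print(param_total / GB)   # 0.0048065185546875
```

Compared with the roughly 4.5 GB of IA3 activations, the parameter-side overhead is about three orders of magnitude smaller.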


### Implement IA3 in CLAM

In this section, the implementation modifies the weights and biases directly instead of manipulating activations.

In [11]:
## Part of the CLAM abstraction (should not compile)
class IA3Adapter2(nn.Module):
    def __init__(self, existing_layer, in_features, out_features, ia3_lr, is_feedforward=False):
        nn.Module.__init__(self)
        self.existing_layer = existing_layer
        self.is_feedforward = is_feedforward
        self.in_features = in_features
        self.out_features = out_features
        self.ia3_lw = (
            nn.Parameter(
                torch.ones((1, out_features), device="cuda")
            )
            if not is_feedforward
            else nn.Parameter(
                torch.ones((1, in_features), device="cuda")
            )
        )
        nn.init.ones_(self.ia3_lw)

        self.ia3_lr = ia3_lr


    ## Every time the forward method runs, we recompute the equivalent weights and biases
    def forward(self, x: torch.Tensor):
        return F.linear(
                x,
                self.get_equivalent_weight(),
                self.get_equivalent_bias()
        )

    def get_equivalent_weight(self):
        """
        Converts IA3 layer to equivalent nn.Linear weight tensor
        """
        mat = self.get_weight()
        ret_weight = None
        if not self.is_feedforward:
            ret_weight = torch.diag(self.ia3_lw.view(-1)) @ mat
        else:
            ret_weight = mat @ torch.diag(self.ia3_lw.view(-1))

        return ret_weight


    def get_equivalent_bias(self):
        """
        Gets equivalent nn.Linear bias data
        """
        mat = self.get_bias()
        if mat is None:
            return None
        ret_bias = None

        if not self.is_feedforward:
            ret_bias = torch.mul(mat, self.ia3_lw.squeeze())
        else:
            ret_bias = mat

        return ret_bias
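
The weight and bias formulas in `get_equivalent_weight` and `get_equivalent_bias` can be sanity-checked against the activation-based version on a small stand-in layer. The sketch below is not part of CLAM: it uses a random weight matrix `W` and bias `b` in place of `get_weight()` and `get_bias()`, runs on CPU, and uses float64 so rounding differences stay negligible.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Small stand-in sizes; W and b play the role of the wrapped layer's parameters.
in_features, out_features = 6, 4
W = torch.randn(out_features, in_features, dtype=torch.float64)
b = torch.randn(out_features, dtype=torch.float64)
x = torch.randn(3, in_features, dtype=torch.float64)

# Attention-style (is_feedforward=False): scale the layer output.
l_out = torch.rand(1, out_features, dtype=torch.float64)
by_activation = F.linear(x, W, b) * l_out
by_weights = F.linear(x, torch.diag(l_out.view(-1)) @ W, b * l_out.squeeze())
attn_match = torch.allclose(by_activation, by_weights)

# Feed-forward-style (is_feedforward=True): scale the layer input; bias is unchanged.
l_in = torch.rand(1, in_features, dtype=torch.float64)
by_activation_ff = F.linear(x * l_in, W, b)
by_weights_ff = F.linear(x, W @ torch.diag(l_in.view(-1)), b)
ff_match = torch.allclose(by_activation_ff, by_weights_ff)

print(attn_match, ff_match)
```

Scaling the output multiplies the rows of the weight matrix (and the bias), while scaling the input multiplies its columns and leaves the bias alone, which is the distinction the two branches in the class encode.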

In [12]:
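# NOTE: the model still contains the IA3Adapter1 wrappers from In [9], and the substring
# match also hits the nested `...existing_layer` paths, so each target is wrapped twice
# more (see the printed architecture below) and the ia3_lw count triples to 1,105,920.
# As in In [9], the list passed as is_feedforward is truthy for every layer.
inject_adapter(model, ["k_proj","v_proj","down_proj"], lambda x: IA3Adapter2(x, in_features=x.in_features, out_features=x.out_features, ia3_lr=1e-3, is_feedforward = ["False", "False", "True"]))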
total_params, trainable_params = ia3_params(model)

print(f"Total Parameters: {total_params}")
print(f"Trainable Parameters: {trainable_params}")

Total Parameters: 2507278336
Trainable Parameters: 1105920
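
The trainable count is now three times the earlier 368,640 because the adapters were injected into a model that was already wrapped. `match_submodules` tests `key in name`, so once `k_proj` contains an `existing_layer` child, both paths match. The sketch below illustrates this with plain strings shaped like the module paths in the printed architecture; `match_leaf` is a hypothetical alternative that compares only the last path component and is not part of the notebook.

```python
# Module paths as they look after the first injection (illustrative subset).
names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.k_proj.existing_layer",
]


def match_substring(names, key):
    """Same test as match_submodules: the key may appear anywhere in the path."""
    return [n for n in names if key in n]


def match_leaf(names, key):
    """Hypothetical alternative: match only when the last path component equals the key."""
    return [n for n in names if n.split(".")[-1] == key]


substring_hits = match_substring(names, "k_proj")
leaf_hits = match_leaf(names, "k_proj")

print(len(substring_hits))  # 2
print(len(leaf_hits))       # 1
```

Reloading the base model before the second injection would be another way to avoid stacking the wrappers.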


In [13]:
model

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): IA3Adapter2(
            (existing_layer): IA3Adapter2(
              (existing_layer): IA3Adapter1(
                (existing_layer): Linear(in_features=2048, out_features=256, bias=False)
              )
            )
          )
          (v_proj): IA3Adapter2(
            (existing_layer): IA3Adapter2(
              (existing_layer): IA3Adapter1(
                (existing_layer): Linear(in_features=2048, out_features=256, bias=False)
              )
            )
          )
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(