<a href="https://colab.research.google.com/github/arunvellat/fineTuning/blob/main/Compressing_T5_Via_Low_Rank_Decomposition_Of_Attention_Matrices.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook performs Singular Value Decomposition on Attention Matrices of T5 Base And creates replaces original weight matrices by their Low Rank Approximation. We are able to **reduce parameters by 25-34% (model size)** with allmost no impact on model's performance (Samsum summarization dataset)

In [None]:
!pip install rouge
!pip install -q bitsandbytes datasets accelerate loralib
!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git
# install additional dependencies needed for training
!pip install rouge-score tensorboard py7zr
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from torch import nn
from dataclasses import dataclass
from torch.nn import functional as F
import copy
import torch

In [None]:
from transformers import AutoModelForSeq2SeqLM
from transformers import AutoTokenizer

# huggingface hub model id
model_id_t5_base = "google/flan-t5-base"
#model_id_t5_large = "google/flan-t5-large"

# load model from the hub
model_t5_base = AutoModelForSeq2SeqLM.from_pretrained(model_id_t5_base,  device_map="auto")
#model_t5_large = AutoModelForSeq2SeqLM.from_pretrained(model_id_t5_large,  device_map="auto")

#print_trainable_parameters(model)
tokenizer = AutoTokenizer.from_pretrained(model_id_t5_base)

# Create Model With Low Rank Weight Matrices

In [None]:
@dataclass
class LowRankConfig:
    rank:int
    target_modules: list[str]

In [None]:
#low rank decomposition of SelfAttention Key, Query and Value Matrices
config = LowRankConfig(
    rank= 384,
    target_modules=["k", "q", "v"]
)

In [None]:
class LowRankLayer(nn.Module):
    """given a linear layer find low rank decomposition"""
    def __init__(self, rank, full_rank_layer):
        super().__init__()
        self.rank = rank
        U, S, Vh = torch.linalg.svd(full_rank_layer.weight)
        S_diag = torch.diag(S)
        self.U = U[:, :self.rank]
        self.S = S_diag[:self.rank, :self.rank]
        self.Vh = Vh[:self.rank, :]

    def forward(self, x):
        aprox_weight_matrix = self.U @ self.S @ self.Vh
        output = F.linear(x, aprox_weight_matrix)
        return output


Helper Functions

In [None]:
()
#find the module that ends target suffix
def get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name

# this function replaces a target layer with low rank layer
def recursive_setattr(obj, attr, value):
    attr = attr.split('.', 1)
    if len(attr) == 1:
        setattr(obj, attr[0], value)
    else:
        recursive_setattr(getattr(obj, attr[0]), attr[1], value)


**create low rank replica of original model (model_t5_base)**

In [None]:
#create a copy of original model
model_t5_base_lr = copy.deepcopy(model_t5_base)


# SVD: Low rank decomposition of SelfAttention Key, Query and Value Matrices

In [None]:
for key, module in model_t5_base.named_modules():
    target_module_found = any(key.endswith("." + target_key) for target_key in config.target_modules)
    if target_module_found:
        low_rank_layer = LowRankLayer(config.rank, module)
        #replace target layer with low rank layer
        recursive_setattr(model_t5_base_lr, key, low_rank_layer)

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

***Number of Parameters : Original Model vs Low Rank Model ***

In [None]:
print_trainable_parameters(model_t5_base), print_trainable_parameters(model_t5_base_lr)

trainable params: 247577856 || all params: 247577856 || trainable%: 100.0
trainable params: 183876864 || all params: 183876864 || trainable%: 100.0


(None, None)

**Model Size Reduction (945M to 702M) ~ 25%**

In [None]:
model_t5_base.save_pretrained("model_t5_base", from_pt=True)
model_t5_base_lr.save_pretrained("model_t5_base_lr", from_pt=True)


In [None]:
!ls -lh model_t5_base/pytorch_model.bin

-rw-r--r-- 1 root root 945M May 22 22:08 model_t5_base/pytorch_model.bin


In [None]:
!ls -lh model_t5_base_lr/pytorch_model.bin

-rw-r--r-- 1 root root 702M May 22 22:08 model_t5_base_lr/pytorch_model.bin


**Looking into the layers**

In [None]:
model_t5_base_lr

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): LowRankLayer()
              (k): LowRankLayer()
              (v): LowRankLayer()
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=768, bias=False)
              (dropout): Dropout(p=0.1, inplace=False)
      

In [None]:
model_t5_base_lr.decoder.block[11].layer[0].SelfAttention.q

LowRankLayer()

In [None]:
model_t5_base_lr.decoder.block[11].layer[0].SelfAttention.q.U.shape,   model_t5_base_lr.decoder.block[11].layer[0].SelfAttention.q.S.shape, model_t5_base_lr.decoder.block[11].layer[0].SelfAttention.q.Vh.shape

(torch.Size([768, 384]), torch.Size([384, 384]), torch.Size([384, 768]))

# Projecting Random Vector On Original SelfAttention Matrix vs its Low Rank Aproximation

In [None]:
#low rank approximation of model_t5_base.encoder.block[0].layer[0].SelfAttention.q
# 768 to 384 dim reduction
query_attention_layer = model_t5_base.encoder.block[0].layer[0].SelfAttention.q
low_rank_query_attention_layer = LowRankLayer(384, model_t5_base.encoder.block[0].layer[0].SelfAttention.q)

In [None]:
random_vector = torch.rand(768)
low_rank_projection = low_rank_query_attention_layer(random_vector)
original_projection = query_attention_layer(random_vector)

**Cosine Distance Similarity**

In [None]:
cosine_sim = torch.nn.CosineSimilarity(dim=0)
cosine_sim(low_rank_projection, original_projection)

tensor(0.9663, grad_fn=<SumBackward1>)

**0.9663 Cosine Similarity between random vector projection on original vs low rank approximation of query weight matrix**

# Evaluation

In [None]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("samsum")
metric = load_metric("rouge")



  0%|          | 0/3 [00:00<?, ?it/s]

  metric = load_metric("rouge")


In [None]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 14732
    })
    test: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 819
    })
    validation: Dataset({
        features: ['id', 'dialogue', 'summary'],
        num_rows: 818
    })
})

In [None]:
raw_datasets["test"]

Dataset({
    features: ['id', 'dialogue', 'summary'],
    num_rows: 819
})

**Tokenize Test Dataset**

In [None]:
prefix = "Summarize: "
max_input_length = 512
max_target_length = 64

def preprocess_function(examples):
  # encode the documents
  dialogues = examples['dialogue']
  summaries = examples['summary']

  inputs = [prefix + dialogue for dialogue in dialogues]
  model_inputs = tokenizer(inputs, max_length=max_input_length, padding="max_length", truncation=True)

  # encode the summaries
  labels = tokenizer(summaries, max_length=max_target_length, padding="max_length", truncation=True).input_ids

  # important: we need to replace the index of the padding tokens by -100
  # such that they are not taken into account by the CrossEntropyLoss
  labels_with_ignore_index = []
  for labels_example in labels:
    labels_example = [label if label != 0 else -100 for label in labels_example]
    labels_with_ignore_index.append(labels_example)

  model_inputs["labels"] = labels_with_ignore_index

  return model_inputs

In [None]:
tokenized_datasets = raw_datasets["test"].map(preprocess_function, batched=True, remove_columns=["dialogue", "summary", "id"])


Map:   0%|          | 0/819 [00:00<?, ? examples/s]

In [None]:
len(tokenized_datasets)

819

In [None]:
tokenized_datasets

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 819
})

In [None]:
!pip install nltk

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install evaluate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import evaluate
import nltk
import numpy as np
from nltk.tokenize import sent_tokenize
nltk.download("punkt")

# Metric
metric = evaluate.load("rouge")

# helper function to postprocess text
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(sent_tokenize(label)) for label in labels]

    return preds, labels

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return result

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


***`Evaluate google/flan-t5-base`***

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq


In [None]:

# we want to ignore tokenizer pad token in the loss
label_pad_token_id = -100
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model_t5_base,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)


This is a dummy Trainer object, it will be only used for evaluation

In [None]:

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir="dummy",
    predict_with_generate=True,
)
trainer = Seq2SeqTrainer(
    model=model_t5_base,
    args=training_args,
    data_collator=data_collator,
    #train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_datasets,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.evaluate()


{'eval_loss': 1.4673314094543457,
 'eval_rouge1': 46.1409,
 'eval_rouge2': 22.3112,
 'eval_rougeL': 38.5783,
 'eval_rougeLsum': 42.1144,
 'eval_gen_len': 16.631257631257633,
 'eval_runtime': 537.2978,
 'eval_samples_per_second': 1.524,
 'eval_steps_per_second': 0.192}

**Inference**

In [None]:
raw_datasets["test"][0]

{'id': '13862856',
 'dialogue': "Hannah: Hey, do you have Betty's number?\nAmanda: Lemme check\nHannah: <file_gif>\nAmanda: Sorry, can't find it.\nAmanda: Ask Larry\nAmanda: He called her last time we were at the park together\nHannah: I don't know him well\nHannah: <file_gif>\nAmanda: Don't be shy, he's very nice\nHannah: If you say so..\nHannah: I'd rather you texted him\nAmanda: Just text him 🙂\nHannah: Urgh.. Alright\nHannah: Bye\nAmanda: Bye bye",
 'summary': "Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry."}

In [None]:
formatted_query = "Summarize: " + raw_datasets["test"][0]["dialogue"]
tokenized_text = tokenizer(formatted_query, truncation=True, return_tensors='pt')

generated_ids = model_t5_base.generate(**tokenized_text)
predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
predictions

"Amanda can't find Betty's number. Amanda will ask Larry if she can"

**Evaluate Compressed flan-t5-base**

In [None]:
label_pad_token_id = -100
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model_t5_base_lr,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8
)

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir="dummy",
    predict_with_generate=True,
)
trainer = Seq2SeqTrainer(
    model=model_t5_base_lr,
    args=training_args,
    data_collator=data_collator,
    #train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_datasets,
    compute_metrics=compute_metrics,
)
trainer.evaluate()


{'eval_loss': 1.4673314094543457,
 'eval_rouge1': 46.1134,
 'eval_rouge2': 22.3523,
 'eval_rougeL': 38.5734,
 'eval_rougeLsum': 42.0596,
 'eval_gen_len': 16.631257631257633,
 'eval_runtime': 537.9899,
 'eval_samples_per_second': 1.522,
 'eval_steps_per_second': 0.191}

**Inference**

In [None]:
raw_datasets["test"][0]

{'id': '13862856',
 'dialogue': "Hannah: Hey, do you have Betty's number?\nAmanda: Lemme check\nHannah: <file_gif>\nAmanda: Sorry, can't find it.\nAmanda: Ask Larry\nAmanda: He called her last time we were at the park together\nHannah: I don't know him well\nHannah: <file_gif>\nAmanda: Don't be shy, he's very nice\nHannah: If you say so..\nHannah: I'd rather you texted him\nAmanda: Just text him 🙂\nHannah: Urgh.. Alright\nHannah: Bye\nAmanda: Bye bye",
 'summary': "Hannah needs Betty's number but Amanda doesn't have it. She needs to contact Larry."}

In [None]:
formatted_query = "Summarize: " + raw_datasets["test"][0]["dialogue"]
tokenized_text = tokenizer(formatted_query, truncation=True, return_tensors='pt')

generated_ids = model_t5_base_lr.generate(**tokenized_text)
predictions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

In [None]:
predictions

"Amanda can't find Betty's number. Amanda will ask Larry if she can"

# Compression V2: Low Rank Approximation Decomposition of SelfAttention Key, Query, Value and Output Matrices

In [None]:
#low rank decomposition of SelfAttention Key, Query and Value Matrices
config = LowRankConfig(
    rank= 384,
    target_modules=["k", "q", "v", "o"]
)

In [None]:
#create a copy of original model
model_t5_base_lr_v2 = copy.deepcopy(model_t5_base)


**Create low rank aprox layers**

In [None]:
for key, module in model_t5_base.named_modules():
    target_module_found = any(key.endswith("." + target_key) for target_key in config.target_modules)
    if target_module_found:
        low_rank_layer = LowRankLayer(config.rank, module)
        #replace target layer with low rank layer
        recursive_setattr(model_t5_base_lr_v2, key, low_rank_layer)

In [None]:
print_trainable_parameters(model_t5_base), print_trainable_parameters(model_t5_base_lr_v2)

trainable params: 247577856 || all params: 247577856 || trainable%: 100.0
trainable params: 162643200 || all params: 162643200 || trainable%: 100.0


(None, None)

**34.31% Compression in Model size**

# Evaluating compressed model on Samsum Summarization Dataset

In [None]:
# label_pad_token_id = -100
# data_collator_3 = DataCollatorForSeq2Seq(
#     tokenizer,
#     model=model_t5_base_lr_v2,
#     label_pad_token_id=label_pad_token_id,
#     pad_to_multiple_of=8
# )

# Define training args
training_args_3 = Seq2SeqTrainingArguments(
    output_dir="dummy",
    predict_with_generate=True,
)
trainer_3 = Seq2SeqTrainer(
    model=model_t5_base_lr_v2,
    args=training_args_3,
    # data_collator=data_collator_3,
    #train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_datasets,
    compute_metrics=compute_metrics,
)
trainer_3.evaluate()

{'eval_loss': 1.4673314094543457,
 'eval_rouge1': 46.1134,
 'eval_rouge2': 22.3523,
 'eval_rougeL': 38.5734,
 'eval_rougeLsum': 42.0596,
 'eval_gen_len': 16.631257631257633,
 'eval_runtime': 570.1044,
 'eval_samples_per_second': 1.437,
 'eval_steps_per_second': 0.181}