In [1]:
from transformers import AutoModelForSeq2SeqLM,AutoTokenizer , T5ForConditionalGeneration,Seq2SeqTrainer,Seq2SeqTrainingArguments
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention , T5Attention

In [2]:
from TALib import TALib

In [3]:
# Load the tokenizer and the seq2seq model from the checkpoint identifiers
# exposed by the project-local TALib helper (TK_ckpt for the tokenizer,
# CHECKPOINT for the model weights).
tokenizer = AutoTokenizer.from_pretrained(TALib.TK_ckpt)
model = AutoModelForSeq2SeqLM.from_pretrained(TALib.CHECKPOINT)

In [4]:
# Show the module tree of the freshly loaded model (nothing has been pruned yet).
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [5]:
import torch

import torch.nn.utils.prune as prune

In [6]:
from transformers.models.t5.modeling_t5 import T5LayerSelfAttention, T5LayerCrossAttention, T5LayerFF

In [7]:
# Bucket every prunable Linear weight by the kind of T5 sub-layer that owns it.
# Each bucket holds (module, "weight") tuples, which is the format that
# torch.nn.utils.prune.global_unstructured accepts.
parameters_to_prune = {
    "self_attention": [],
    "cross_attention": [],
    "ffn": [],
    "lm_head": [],
}

# Sub-layer class -> bucket it feeds. Every class is tested independently
# (no early exit), so a module matching several classes lands in each bucket.
layer_buckets = (
    (T5LayerSelfAttention, "self_attention"),
    (T5LayerCrossAttention, "cross_attention"),
    (T5LayerFF, "ffn"),
)

for name, module in model.named_modules():
    for layer_cls, bucket in layer_buckets:
        if not isinstance(module, layer_cls):
            continue
        # Collect every Linear nested anywhere below this sub-layer.
        parameters_to_prune[bucket].extend(
            (child, "weight")
            for child in module.modules()
            if isinstance(child, torch.nn.Linear)
        )

    # The output projection is matched by its exact top-level module name.
    # NOTE(review): some T5 checkpoints tie lm_head.weight to the shared
    # embedding; confirm how pruning lm_head interacts with that tie.
    if name == "lm_head" and isinstance(module, torch.nn.Linear):
        parameters_to_prune["lm_head"].append((module, "weight"))


In [8]:
# Inspect the collected (module, "weight") pairs, grouped by sub-layer kind.
parameters_to_prune

{'self_attention': [(Linear(in_features=512, out_features=512, bias=False),
   'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_features=512, out_features=512, bias=False), 'weight'),
  (Linear(in_

In [9]:
# Per-group pruning amounts, passed as `amount` to
# prune.global_unstructured by prune_model(). A float in [0, 1] is the
# fraction of the group's weights to prune.
(
    self_attention_prune_amount,
    cross_attention_prune_amount,
    ff_prune_amount,
    lm_head_amount,
) = (0.7, 0.7, 0.7, 0.7)

In [10]:
def prune_model(groups=None, amounts=None):
    """Apply global L1 unstructured pruning to each group of weights.

    Every group is pruned independently: within a group, the weights with the
    smallest absolute value across *all* of the group's tensors are masked
    until ``amount`` of them are pruned.

    Pruning is done in place through ``torch.nn.utils.prune``: each module gets
    a ``weight_mask`` buffer and a ``weight_orig`` parameter, and ``weight``
    becomes their product. Tensor shapes do not change; call ``prune.remove``
    on a module to make the pruning permanent.

    Args:
        groups: Mapping of group name -> list of ``(module, "weight")`` tuples.
            Defaults to the notebook-level ``parameters_to_prune``.
        amounts: Mapping of group name -> pruning amount (float fraction in
            [0, 1] or an absolute int count). Only groups named here are
            pruned. Defaults to the four notebook-level ``*_amount`` values.

    Returns:
        None. The referenced modules are modified in place.
    """
    if groups is None:
        groups = parameters_to_prune
    if amounts is None:
        amounts = {
            "self_attention": self_attention_prune_amount,
            "cross_attention": cross_attention_prune_amount,
            "ffn": ff_prune_amount,
            "lm_head": lm_head_amount,
        }

    for group_name, amount in amounts.items():
        targets = groups.get(group_name) or []
        if not targets:
            # global_unstructured concatenates the group's tensors and raises
            # on an empty list, so a group with no matches is skipped
            # instead of aborting the whole run.
            continue
        prune.global_unstructured(
            targets,
            pruning_method=prune.L1Unstructured,
            amount=amount,
        )
    return None

In [11]:
# Prune the model in place. NOTE: this cell is not idempotent — running it
# again applies a further pruning pass on top of the existing masks.
prune_model()

In [12]:
# Show the module tree again after pruning; the printed layer shapes are
# unchanged because pruning masks weights rather than resizing layers.
model

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [13]:
# Report the parameter ratio computed by the project-local helper for the
# pruned model.
# NOTE(review): the exact definition of this ratio lives in
# TALib.show_param_ratio — confirm what it counts (e.g. how masked weights
# and un-pruned embeddings are treated) before interpreting the number.
TALib.show_param_ratio(model)

0.8544285297393799