In [1]:
from transformers import T5ForConditionalGeneration, T5Tokenizer,T5Config
from transformers.models.t5.modeling_t5 import T5Attention, T5Config, T5Block
from copy import deepcopy
from typing import List, Optional
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.t5.modeling_t5 import T5Attention, T5LayerSelfAttention, T5LayerCrossAttention

# Load the "t5-small" checkpoint and its tokenizer; `legacy=False` is passed to the tokenizer.
t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)

In [2]:
from config import *
from collections import defaultdict

def get_tf_attention_dict(module,kv_heads:int=4):
    """Collect the ``T5Attention`` modules found under the decoder, grouped by scope.

    The module tree is walked depth-first. Once a child named ``"decoder"`` or
    ``"EncDecAttention"`` is entered, every ``T5Attention`` below it is stored
    under the most recently entered of those two names. For a T5 model this
    puts self-attention modules under ``"decoder"`` and cross-attention
    modules under ``"EncDecAttention"``.

    Args:
        module: Root module to search (e.g. a ``T5ForConditionalGeneration``).
        kv_heads (int, optional): Not used by the collection logic; kept so
            existing calls keep working. Defaults to 4.

    Returns:
        defaultdict(list): scope name -> attention modules in traversal order.
        The entries are the live sub-modules of ``module``; no copy is made,
        so edits to them affect the model that was passed in.
    """
    scope_names = ("decoder", "EncDecAttention")
    tf_attention_dict = defaultdict(list)

    def _collect(node, in_scope: bool, scope: str) -> None:
        """Record ``node`` if it is an in-scope attention module, then recurse."""
        if in_scope and isinstance(node, T5Attention):
            print(scope)
            tf_attention_dict[scope].append(node)

        for name, child in node.named_children():
            # Decide scope per child instead of mutating loop state, so a
            # match on one child never leaks onto the siblings that follow it.
            if name in scope_names:
                _collect(child, True, name)
            else:
                _collect(child, in_scope, scope)

    _collect(module, False, '')
    return tf_attention_dict

In [3]:
# Collect the attention modules of `t5`; the helper prints the key each module is stored under.
tf_attn_dict = get_tf_attention_dict(t5)

decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention


In [4]:
# Which scope names ended up as keys.
tf_attn_dict.keys()

dict_keys(['decoder', 'EncDecAttention'])

In [5]:
# Number of modules collected under 'EncDecAttention'.
len(tf_attn_dict['EncDecAttention'])

6

In [8]:
# Take the first module stored under 'decoder' and display its layout.
first_qkv = tf_attn_dict['decoder'][0]
first_qkv

T5Attention(
  (q): Linear(in_features=512, out_features=512, bias=False)
  (k): Linear(in_features=512, out_features=512, bias=False)
  (v): Linear(in_features=512, out_features=512, bias=False)
  (o): Linear(in_features=512, out_features=512, bias=False)
  (relative_attention_bias): Embedding(32, 8)
)

In [14]:
# Transposed weight matrices of the q/k/v projections of the selected attention module.
q = first_qkv.q.weight.data.T
k = first_qkv.k.weight.data.T
v = first_qkv.v.weight.data.T

In [15]:
# Shape check for the three transposed projection weights.
v.shape,q.shape,k.shape

(torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512]))

In [17]:
num_heads = 8
# Plan: define 8 scalar weights for the key heads and 8 for the value heads (one per head).

## DECODER 

In [18]:
import torch

# Toy check of per-head weighting on a random (512, 512) matrix.
input_matrix = torch.rand((512, 512))

# One scalar weight per head (w1 ... w8).
w_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

# Split the 512 columns into len(w_values) equally wide heads: (512, 8, 64).
input_matrix_reshaped = input_matrix.view(512, len(w_values), -1)

# Scale every head by its weight in a single broadcast multiply; the weights
# are shaped (1, 8, 1) so they line up with the head axis.
head_weights = torch.tensor(w_values).view(1, -1, 1)
output_heads = input_matrix_reshaped * head_weights

# Summing over the head axis merges the 8 weighted heads into one head-sized
# matrix, so the result is (512, 64) -- not (512, 512).
output_matrix = output_heads.sum(dim=1)

print(output_matrix.shape)


torch.Size([512, 64])


In [12]:
# Per-head scaling via broadcasting, checked on an all-ones matrix.
w = torch.ones(512, 512)
print(w)
# Split the 512 columns into 8 heads of width 64.
w = w.view(512, 8, 64)
# One factor per head; the (8, 1) shape broadcasts against w's trailing (8, 64) dims.
mul_val = torch.tensor([0.1, 0.2, 0.4, 0.3, 0.5, 0.6, 0.7, 0.8]).unsqueeze(1)
final_val = w * mul_val
# Flatten the heads back into 512 columns and look at the first 65 of them.
fv = final_val.view(512, 512)
fv[:, :65]

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])


torch.Size([512, 65])

In [17]:
import torch

# Same per-head scaling check, with the factors shaped explicitly for broadcasting.
w = torch.ones(512, 512)
print(w)

# (512, 512) -> (512, 8, 64): 8 heads of 64 columns each.
w = w.view(512, 8, 64)

# One scale factor per head, laid out as (1, 8, 1) so it aligns with w's head axis.
mul_val = torch.tensor([0.1, 0.2, 0.4, 0.3, 0.5, 0.6, 0.7, 0.8]).view(1, 8, 1)

# Element-wise (broadcast) product -- this is not a matrix multiplication.
final_val = w * mul_val

# Merge the heads back into a (512, 512) matrix.
fv = final_val.view(512, 512)

print(fv)


tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]])
tensor([[0.1000, 0.1000, 0.1000,  ..., 0.8000, 0.8000, 0.8000],
        [0.1000, 0.1000, 0.1000,  ..., 0.8000, 0.8000, 0.8000],
        [0.1000, 0.1000, 0.1000,  ..., 0.8000, 0.8000, 0.8000],
        ...,
        [0.1000, 0.1000, 0.1000,  ..., 0.8000, 0.8000, 0.8000],
        [0.1000, 0.1000, 0.1000,  ..., 0.8000, 0.8000, 0.8000],
        [0.1000, 0.1000, 0.1000,  ..., 0.8000, 0.8000, 0.8000]])


In [15]:
# `.T` on a tensor whose ndim is not 2 is deprecated in PyTorch and emits a
# UserWarning; reverse the dimensions explicitly with permute instead.
print(mul_val.permute(*range(mul_val.ndim - 1, -1, -1)).shape)

torch.Size([1, 8, 1])


  print(mul_val.T.shape)


In [33]:
# An (8, 1) column in which every entry is 0.5.
tensor_0_5 = torch.full(size=(8, 1), fill_value=0.5)

# Show its shape.
print(tensor_0_5.shape)

torch.Size([8, 1])


In [36]:
# Standard-normal (8, 1) column; only its shape is inspected.
w = torch.randn(8, 1)
w.shape

torch.Size([8, 1])

In [None]:
# For every attention module stored under 'decoder', create one learnable
# scalar per key head and per value head (initialised to 0.5) and scale the
# transposed k/v projection weights head by head.
# `enumerate` supplies the index; iterating the list directly would try to
# unpack each module into (idx, attn_module) and fail.
for idx,attn_module in enumerate(tf_attn_dict['decoder']):
    # q is not used further in this loop; it is still assigned so the
    # notebook-level `q` keeps tracking the current module.
    q = attn_module.q.weight.data.T
    k = attn_module.k.weight.data.T
    v = attn_module.v.weight.data.T

    params = nn.ParameterDict({
        f"key_{idx}": nn.Parameter(torch.full((num_heads,1),0.5)),
        f"value_{idx}": nn.Parameter(torch.full((num_heads,1),0.5)),
    })

    # Split the second axis into heads: (rows, num_heads, columns // num_heads).
    # Sizes come from the tensor itself instead of a hard-coded 512.
    k = k.view(k.shape[0],num_heads,-1)
    v = v.view(v.shape[0],num_heads,-1)
    # The (num_heads, 1) parameters broadcast over the trailing
    # (num_heads, head_dim) axes, i.e. one scale factor per head.
    k_mod = torch.multiply(k,params[f"key_{idx}"])
    v_mod = torch.multiply(v,params[f"value_{idx}"])

    

In [None]:
# Einsum index legend for the subscripts below:
#   b: batch size
#   n: sequence length of X (queries)
#   m: sequence length of M (source of keys/values)
#   d: d_model
#   h: heads (only Q and the output carry this axis)
#   k, v: key / value projection width -- the original note says "same as d_model"; TODO confirm
# K and V are projected without an `h` axis, so every query head attends over the same keys/values.
# NOTE(review): `tf`, X, M, mask and the P_* tensors are not defined in the cells visible here,
# so this reads as reference pseudo-code rather than a runnable cell.
Q= tf .einsum("bnd,hdk->bhnk" , X, P_q) 
K= tf .einsum("bmd,dk->bmk" , M, P_k) 
V= tf .einsum("bmd,dv->bmv" , M, P_v) 
logits = tf .einsum("bhnk,bmk->bhnm", Q, K) 
weights= tf .softmax(logits +mask) 
O= tf .einsum("bhnm,bmv->bhnv", weights , V)
# NOTE(review): "hd" gives P_o no `v` axis, so `v` is summed out of O on its own; the multi-query
# formulation this appears to follow would use "bhnv,hdv->bnd" -- confirm before relying on it.
Y= tf .einsum("bhnv,hd->bnd" , O, P_o) 
# return Y

In [3]:
# Scratch check: prefix every key of a dict with "hello_" via a dict comprehension.
# Note that the comprehension variables k and v are local to it and do not rebind the notebook-level k / v.
c = {"there":1,"jackass":2}
{"hello_"+k:v for k,v in c.items()}

{'hello_there': 1, 'hello_jackass': 2}

In [7]:
# Swapping the two axes of a (2, 3) tensor gives shape (3, 2).
x = torch.randn(2, 3)
x.transpose(0, 1).shape

torch.Size([3, 2])

In [18]:
"SHORT_sum".upper()

'SHORT_SUM'

## OBJECTIVE
1. Take a GQA model, expand its key/value heads with `repeat_interleave`, and then save it.
2. After that, redo the similarity-based re-ordering and rebuild the GQA model.

In [3]:
from t5_SGQA import convert_t5_to_gqa
from config import * 
from transformers import T5ForConditionalGeneration

# Reload the base model. MODEL_NAME is not defined in this cell -- presumably it is
# supplied by the wildcard `from config import *` (verify).
t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
        MODEL_NAME
    )

# Convert with 4 key/value heads using the helper imported from t5_SGQA
# (its implementation is not shown in this notebook).
t5_gqa = convert_t5_to_gqa(t5,kv_heads=4)

In [4]:
# Display the converted model's module tree.
t5_gqa

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [None]:
# NOTE(review): looks like an unfinished attribute lookup (the cell was never executed);
# no `.d` attribute appears in the module tree printed above -- TODO complete or delete.
t5.decoder.block.d

In [52]:
from torch import nn

d_model = 512
kv_heads = 4
n_heads = 8


def _expand_kv_projection(proj, d_model, n_heads, kv_heads):
    """Return a full-width ``nn.Linear`` built by repeating a grouped k/v projection.

    The weight of ``proj`` is viewed as ``(d_model // n_heads, kv_heads, d_model)``,
    repeated ``n_heads // kv_heads`` times along the ``kv_heads`` axis and
    flattened back to ``(-1, d_model)``.

    A projection whose weight already has ``n_heads * (d_model // n_heads)``
    rows is returned unchanged, so the cell can be re-run safely.

    NOTE(review): the view treats the weight rows as ``(head_dim, kv_heads)``.
    Whether that matches how ``convert_t5_to_gqa`` packs grouped heads is not
    visible in this notebook -- confirm against t5_SGQA.
    """
    head_dim = d_model // n_heads
    if proj.weight.shape[0] == n_heads * head_dim:
        return proj  # already expanded

    repeats = n_heads // kv_heads
    grouped = proj.weight.data.view(head_dim, kv_heads, d_model)
    expanded_weight = torch.repeat_interleave(grouped, repeats, dim=1).view(-1, d_model)

    expanded = nn.Linear(in_features=d_model, out_features=expanded_weight.shape[0], bias=False)
    expanded.weight.data = expanded_weight
    return expanded


for layer in t5_gqa.decoder.block:
    # layer[0] holds the self-attention, layer[1] the encoder-decoder attention.
    # All four projections go through the same helper, so the cross-attention
    # value weights get the same head-wise view as the other three.
    for attention in (layer.layer[0].SelfAttention, layer.layer[1].EncDecAttention):
        attention.k = _expand_kv_projection(attention.k, d_model, n_heads, kv_heads)
        attention.v = _expand_kv_projection(attention.v, d_model, n_heads, kv_heads)

In [53]:
# Display the module tree again to check the k/v projection sizes after the expansion loop.
t5_gqa

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [50]:
import torch

#               (d_model,kv_heads,d_model//n_heads)
#(512,256) --> (512,4,64) --> repeat interleave along dimension one --> (512, 8, 64)

# Scaled-down stand-ins for the real sizes (real values in the trailing comments).
d_model = 8  # 512
n_heads = 4  # 8
key_value_dim = d_model // n_heads  # 64
kv_heads = 2  # 4

a = torch.ones((4, 8))
print(f"a = {a}")
print(f"Shape = {a.shape}")

# Give each of the 4 rows its own constant so repeated rows are easy to spot.
a[:] = torch.tensor([[0.1], [0.2], [0.3], [0.4]])

# (4, 8) -> (2, 2, 8); repeat each entry on axis 1 twice -> (2, 4, 8); flatten -> (8, 8).
a = a.view(key_value_dim, kv_heads, d_model).repeat_interleave(2, dim=1).view(d_model, -1)

print("Expanded a = ",a)
print("Shape: ",a.shape)

a = tensor([[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])
Shape = torch.Size([4, 8])
Expanded a =  tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
        [0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
        [0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000, 0.3000],
        [0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000],
        [0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000]])
Shape:  torch.Size([8, 8])


In [51]:
# First two rows of the expanded toy matrix.
a[:2,:]

tensor([[0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000],
        [0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000, 0.1000]])

In [26]:
# NOTE(review): the recorded output is (4, 1), which matches none of the `mul_val`
# definitions visible in this notebook -- it likely comes from stale kernel state.
mul_val.shape

torch.Size([4, 1])

In [27]:
# NOTE(review): `b` is not defined in any cell visible here; this relies on earlier kernel state.
b.shape

torch.Size([8, 2, 2])

In [29]:
# Collapse everything after the first axis: the recorded shapes go from (8, 2, 2) to (8, 4).
# Not re-runnable in isolation, since it rebinds `b` from its own previous value.
b = b.view(8,-1)

In [31]:
# Shape after flattening.
b.shape

torch.Size([8, 4])

In [32]:
# Fill each of the 4 columns with its own constant (in place) so columns can be told apart.
b[:,0] = 0.1
b[:,1] = 0.2
b[:,2] = 0.3
b[:,3] = 0.4

b

tensor([[0.1000, 0.2000, 0.3000, 0.4000],
        [0.1000, 0.2000, 0.3000, 0.4000],
        [0.1000, 0.2000, 0.3000, 0.4000],
        [0.1000, 0.2000, 0.3000, 0.4000],
        [0.1000, 0.2000, 0.3000, 0.4000],
        [0.1000, 0.2000, 0.3000, 0.4000],
        [0.1000, 0.2000, 0.3000, 0.4000],
        [0.1000, 0.2000, 0.3000, 0.4000]])