In [1]:
from transformers import T5ForConditionalGeneration, T5Tokenizer,T5Config
from transformers.models.t5.modeling_t5 import T5Attention, T5Config, T5Block
from copy import deepcopy
from typing import List, Optional
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.t5.modeling_t5 import T5Attention, T5LayerSelfAttention, T5LayerCrossAttention

# Load the pretrained "t5-small" checkpoint and its tokenizer by name.
# `t5` is the model whose attention modules are inspected in later cells.
t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small", legacy=False)

In [37]:
from config import *
from collections import defaultdict

def get_tf_attention_dict(module,kv_heads:int=4):
    """Collect the ``T5Attention`` modules found under the decoder of ``module``.

    The module tree is walked depth-first. A subtree becomes "active" once a
    child named ``"decoder"`` or ``"EncDecAttention"`` is entered; every
    ``T5Attention`` inside an active subtree is recorded under the most recent
    such name on its path (its own name included).

    The walk reads the module tree in place: no copy of the model is made, so
    the returned entries are the live sub-modules of ``module``. The active
    state is computed per child, so it never carries over to sibling subtrees
    that are not themselves listed.

    Args:
        module: Root ``torch.nn.Module`` to search.
        kv_heads (int): Currently unused; kept so existing callers keep working.

    Returns:
        defaultdict(list): group name -> list of attention modules in
        traversal order.
    """
    transfer_to_gqa: List[str] = ["decoder","EncDecAttention"]
    tf_attention_dict = defaultdict(list)

    def _collect(node, inside_target: bool, group: str) -> None:
        """Record ``node`` if it qualifies, then recurse into its children."""
        if inside_target and isinstance(node, T5Attention):
            print(group)
            tf_attention_dict[group].append(node)

        for child_name, child in node.named_children():
            is_target = child_name in transfer_to_gqa
            # Derive the child's state without touching this level's variables,
            # otherwise a match would also affect every later sibling.
            _collect(
                child,
                inside_target or is_target,
                child_name if is_target else group,
            )

    _collect(module, False, '')
    return tf_attention_dict

In [38]:
# Collect the attention modules of `t5` (default kv_heads); the call also prints
# the group name of every module it records.
tf_attn_dict = get_tf_attention_dict(t5)

decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention
decoder
EncDecAttention


In [5]:
# Group names under which attention modules were collected.
tf_attn_dict.keys()

dict_keys(['decoder', 'EncDecAttention'])

In [7]:
# Number of attention modules stored under the 'EncDecAttention' key.
len(tf_attn_dict['EncDecAttention'])

6

In [8]:
# Take the first module stored under the 'decoder' key and display its layout.
first_qkv = tf_attn_dict['decoder'][0]
first_qkv

T5Attention(
  (q): Linear(in_features=512, out_features=512, bias=False)
  (k): Linear(in_features=512, out_features=512, bias=False)
  (v): Linear(in_features=512, out_features=512, bias=False)
  (o): Linear(in_features=512, out_features=512, bias=False)
  (relative_attention_bias): Embedding(32, 8)
)

In [14]:
# Pull the query/key/value projection weights out of `first_qkv`, each one
# transposed relative to how the layer stores it.
q, k, v = (getattr(first_qkv, proj).weight.data.T for proj in ("q", "k", "v"))

In [15]:
# Shapes of the transposed value, query and key weight matrices (in that order).
v.shape,q.shape,k.shape

(torch.Size([512, 512]), torch.Size([512, 512]), torch.Size([512, 512]))

In [17]:
# Head count used by the later cells when splitting weight matrices per head.
num_heads = 8
# Plan: define 8 weights, one for each key head, and 8 weights, one for each value head.

## DECODER 

In [18]:
import torch

# Demo: merge the 8 heads of a (512, 512) matrix into a single 64-wide head
# using one scalar weight per head.
input_matrix = torch.rand((512, 512))

# One scalar weight per head: w1, w2, ..., w8.
w_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]

# Split the 512 columns into 8 heads of 64 columns each -> (512, 8, 64).
input_matrix_reshaped = input_matrix.view(512, 8, 64)

# Scale every head by its scalar. The (1, 8, 1) weight tensor broadcasts over
# the rows and over the 64 columns inside each head.
output_heads = input_matrix_reshaped * torch.tensor(w_values).view(1, 8, 1)

# Summing over the head axis collapses the 8 heads into one, so the result is
# (512, 64) -- not (512, 512).
output_matrix = output_heads.sum(dim=1)

assert output_matrix.shape == (512, 64)
print(output_matrix.shape)


torch.Size([512, 64])


In [35]:
# Toy check of per-head scaling via broadcasting on a 4x4 matrix of ones.
w = torch.ones(4, 4)
print(w)

# View the 4 columns as 2 heads of width 2 -> (4, 2, 2).
w = w.view(4, 2, 2)

# One scale per head, shaped (2, 1) so it broadcasts across rows and across
# the columns inside each head.
mul_val = torch.tensor([0.2, 0.3]).unsqueeze(-1)
final_val = w * mul_val
print(final_val.view(4, 4))

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
tensor([[0.2000, 0.2000, 0.3000, 0.3000],
        [0.2000, 0.2000, 0.3000, 0.3000],
        [0.2000, 0.2000, 0.3000, 0.3000],
        [0.2000, 0.2000, 0.3000, 0.3000]])


In [33]:
# Column tensor of shape (8, 1) with every entry set to 0.5.
tensor_0_5 = torch.full(size=(8, 1), fill_value=0.5)

# Show only the shape, not the values.
print(tensor_0_5.shape)

torch.Size([8, 1])


In [36]:
# Random (8, 1) column tensor; the cell displays its shape.
w = torch.randn(8, 1)
w.shape

torch.Size([8, 1])

In [None]:
# For every attention module stored under the 'decoder' key, create one
# learnable scalar per head for the key and value projections and apply it to
# the head-split weights.
# enumerate() supplies the layer index used in the parameter names; iterating
# the list directly yields bare modules, which cannot be unpacked into
# (idx, attn_module).
for idx,attn_module in enumerate(tf_attn_dict['decoder']):
    q = attn_module.q.weight.data.T
    k = attn_module.k.weight.data.T
    v = attn_module.v.weight.data.T

    # One (num_heads, 1) parameter each for keys and values, initialised to 0.5.
    params = nn.ParameterDict({
        f"key_{idx}": nn.Parameter(torch.full((num_heads,1),0.5)),
        f"value_{idx}": nn.Parameter(torch.full((num_heads,1),0.5)),
    })

    # Split the second axis into num_heads equal chunks; sizes are taken from
    # the tensors themselves instead of a hard-coded 512.
    k = k.view(k.shape[0],num_heads,k.shape[1]//num_heads)
    v = v.view(v.shape[0],num_heads,v.shape[1]//num_heads)
    # The (num_heads, 1) scalars broadcast over rows and over each head's columns.
    k_mod = torch.multiply(k,params[f"key_{idx}"])
    v_mod = torch.multiply(v,params[f"value_{idx}"])

    

In [None]:
# Multi-query attention written as einsum pseudocode (reference sketch only).
# NOTE(review): `tf`, X, M, mask and the P_* projections are not defined in any
# visible cell, so this cell cannot run as-is.
#
# Subscript legend:
# b: batch_size
# n: query seq_length
# m: key/value (memory) seq_length
# d: d_model
# h: heads
# k: key/query feature width (earlier note said "same as d_model" -- TODO confirm)
# v: value feature width
Q = tf.einsum("bnd,hdk->bhnk", X, P_q)  # queries keep a separate h axis
K = tf.einsum("bmd,dk->bmk", M, P_k)  # keys have no h axis: shared by all heads
V = tf.einsum("bmd,dv->bmv", M, P_v)  # values have no h axis: shared by all heads
logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
weights = tf.softmax(logits + mask)
O = tf.einsum("bhnm,bmv->bhnv", weights, V)
# P_o must carry the v axis ("hdv") so the contraction over v is a projection;
# with "hd" the attended values were summed over v with no weights at all.
Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
# return Y