## AIM
1. Along the head dimension, run KNN and order the heads according to the number of groups, so that the heads in the closest (most similar) cluster end up next to each other
    1. Getting the neighbours based on similarity and group number
    2. Take it to the GPU, by building the index there
2. Integrate the logic in the Grouped Query Attention module

## SIMILARITY BASED GQA

## STEPS
1. Go through the architecture and based on encoder, decoder and EncDecoder Attention (cross-attention) get the queries, keys and values (have the attention layer name as attribute)
2. Apply KNN and rearrange the keys, queries and values accordingly. Do this on the GPU
3. Return the model with the shuffled K, Q and V

In [8]:
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import T5Attention, T5Config, T5Block
from copy import deepcopy
from typing import List
# Load the pretrained "t5-small" checkpoint; this model object is the input to
# both convert_t5_to_gqa calls further down in the notebook.
t5: T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(
        "t5-small"
    )

# Module-level accumulator: every selected T5Attention module found by
# convert_t5_to_gqa is appended here (it is never cleared by the function).
tf_attention_list = []
# Child-module names that switch selection on for everything beneath them.
transfer_to_gqa: List[str] = ["encoder", "decoder", "EncDecAttention"]


def convert_t5_to_gqa(module, kv_heads: int, similarity_flag: bool = False, inplace: bool = False):
    """Walk a module tree and collect the attention modules selected for similarity GQA.

    A ``T5Attention`` module is appended to the module-level ``tf_attention_list``
    when it is reached with ``similarity_flag`` set. The flag is switched on for a
    child (and everything below it) when the child's name is listed in
    ``transfer_to_gqa``; it is inherited from the parent otherwise.

    Args:
        module: Transformer module/unit to traverse.
        kv_heads (int): Number of key-value heads. Only forwarded to the
            recursive calls; it does not influence the traversal yet.
        similarity_flag (bool, optional): Whether ``module`` itself is already
            inside a selected sub-tree. Defaults to False.
        inplace (bool, optional): Traverse ``module`` itself instead of a deep
            copy. Defaults to False.

    Returns:
        ``module`` itself when ``inplace`` is True, otherwise a deep copy of it.
        No weights are modified; the collected attention modules belong to the
        returned object.
    """
    out = module if inplace else deepcopy(module)

    # Record the object that is actually returned, so that a non-inplace call
    # never leaks a reference to the caller's original module into the list.
    if isinstance(out, T5Attention) and similarity_flag:
        tf_attention_list.append(out)

    for name, child in out.named_children():
        # Compute the flag per child. Overwriting the inherited flag here would
        # make it "sticky": every sibling visited after the first matching name
        # would be selected as well, even when its own name is not listed.
        child_flag = similarity_flag or name in transfer_to_gqa
        out._modules[name] = convert_t5_to_gqa(
            child, kv_heads=kv_heads, similarity_flag=child_flag, inplace=True
        )
    return out

## INDEX ONE OF THE MODULES TO DO SIMILARITY-BASED GROUPING

In [9]:
# Traverse a deep copy of t5 (inplace defaults to False) with kv_heads=8;
# tf_attention_list is filled as a side effect of the traversal.
out = convert_t5_to_gqa(t5,8)

In [10]:
# Pick the first attention module that the traversal collected.
first_attn = tf_attention_list[0]

In [13]:
# first_attn.q.weight.data
# Inspect the shape of the value-projection weight of that module.
first_attn.v.weight.data.shape

torch.Size([512, 512])

In [7]:
# Number of attention heads configured on the module.
first_attn.n_heads

8

In [8]:
# Width of one head: the 512-wide projection divided across 8 heads.
512//8

64

## SPLITTING THE PROJECTION WEIGHTS INTO THEIR RESPECTIVE HEADS

In [33]:
# import torch
# num_heads = first_attn.n_heads
# query_heads = torch.tensor_split(first_attn.q.weight,num_heads,dim=1)
# key_heads = torch.tensor_split(first_attn.k.weight.data,num_heads,dim=1)
# value_heads = torch.tensor_split(first_attn.v.weight.data,num_heads,dim=1)

In [43]:

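# query
# (n_seq x d_model) @ (d_model x d_model//num_heads) = (n_seq x d_model//num_heads)

#key
# (n_seq x d_model) @ (d_model x d_model//num_heads)  = (n_seq x d_model//num_heads)

#value
# (n_seq x d_model) @ (d_model x d_model//num_heads) =  (n_seq x d_model//num_heads)

# NOTE: torch.nn.Linear stores its weight as (out_features, in_features), i.e. the
# transpose of the right-hand matrices written above. In `.weight`, each head is
# therefore a block of d_model//num_heads consecutive ROWS, not columns.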




### Cosine Similarity

In [157]:
import torch
import torch.nn.functional as F

def cosine_similarity(query_heads, key_heads, value_heads, num_heads=8, split_dim=0):
    """Reorder attention heads so that the most similar query heads are adjacent.

    Each projection weight is cut into ``num_heads`` equal chunks along
    ``split_dim``. The query chunks are flattened and compared pairwise with
    cosine similarity; pairs are then picked greedily from most to least
    similar, each head being used at most once. The resulting head order is
    applied identically to the query, key and value weights.

    Args:
        query_heads: Query projection weight (2-D tensor).
        key_heads: Key projection weight, same layout as ``query_heads``.
        value_heads: Value projection weight, same layout as ``query_heads``.
        num_heads (int, optional): Number of heads packed into each weight.
            Defaults to 8.
        split_dim (int, optional): Axis along which the heads are laid out.
            Defaults to 0, because ``torch.nn.Linear`` stores its weight as
            ``(out_features, in_features)`` and the heads of a projection are
            blocks of output features, i.e. rows. Pass 1 for weights stored
            as ``(in_features, out_features)``.

    Returns:
        tuple: ``(query, key, value, head_order)`` where the first three are the
        reordered weights (same shapes as the inputs) and ``head_order`` is the
        list of original head indices in their new order.

    Raises:
        ValueError: If ``num_heads`` is smaller than 1 or does not evenly
            divide the size of ``query_heads`` along ``split_dim``.
    """
    if num_heads < 1:
        raise ValueError(f"num_heads must be >= 1, got {num_heads}")
    if query_heads.shape[split_dim] % num_heads != 0:
        raise ValueError(
            f"dimension {split_dim} of size {query_heads.shape[split_dim]} "
            f"cannot be split evenly into {num_heads} heads"
        )

    query_chunks = torch.tensor_split(query_heads, num_heads, dim=split_dim)
    key_chunks = torch.tensor_split(key_heads, num_heads, dim=split_dim)
    value_chunks = torch.tensor_split(value_heads, num_heads, dim=split_dim)

    # One unit-length row per head, normalised once instead of once per pair.
    # The similarity is only used for ranking, so it is computed detached and
    # in float32 regardless of the weight dtype.
    flat_heads = torch.stack([chunk.reshape(-1) for chunk in query_chunks])
    unit_heads = F.normalize(flat_heads.detach().float(), p=2, dim=1)
    similarity = (unit_heads @ unit_heads.T).tolist()

    pair_similarities = [
        (similarity[i][j], i, j)
        for i in range(num_heads)
        for j in range(i + 1, num_heads)
    ]
    # Most similar pairs first; the sort is stable, so ties keep (i, j) order.
    pair_similarities.sort(reverse=True, key=lambda pair: pair[0])

    # Greedy matching: take a pair only if neither head is already placed.
    grouped_pairs = []
    used_heads = set()
    for _, head1, head2 in pair_similarities:
        if head1 not in used_heads and head2 not in used_heads:
            grouped_pairs.extend((head1, head2))
            used_heads.update((head1, head2))

    # With an odd head count one head has no partner; keep it at the end so
    # the result is a full permutation and the weights keep their shape.
    grouped_pairs.extend(head for head in range(num_heads) if head not in used_heads)

    query_heads_grouped = torch.cat([query_chunks[i] for i in grouped_pairs], dim=split_dim)
    key_heads_grouped = torch.cat([key_chunks[i] for i in grouped_pairs], dim=split_dim)
    value_heads_grouped = torch.cat([value_chunks[i] for i in grouped_pairs], dim=split_dim)

    return query_heads_grouped, key_heads_grouped, value_heads_grouped, grouped_pairs

# q_grp, k_grp, v_grp  = cosine_similarity(query_heads,key_heads,value_heads)

In [183]:
# Module-level accumulator, reset whenever this cell is run: every attention
# module touched by convert_t5_to_gqa is appended here.
tf_attention_list = []
# Names of the top-level model components whose attention heads are reordered.
transfer_to_gqa: List[str] = ["decoder"]


def _attention_modules(component_name, block):
    """Return the attention modules of one transformer block, self-attention first."""
    if component_name == "encoder":
        return [block.layer[0].SelfAttention]
    if component_name == "decoder":
        return [block.layer[0].SelfAttention, block.layer[1].EncDecAttention]
    # Unknown component names are ignored.
    return []


def _reorder_bias_by_head(bias, head_order):
    """Return ``bias`` with its per-head chunks rearranged to follow ``head_order``."""
    chunks = torch.tensor_split(bias, len(head_order))
    return torch.cat([chunks[head] for head in head_order])


def _regroup_attention_heads(attention):
    """Reorder the q/k/v projections of one attention module by head similarity."""
    projections = (attention.q, attention.k, attention.v)
    # Capture the biases before the weights are replaced.
    biases = [projection.bias for projection in projections]

    query_w, key_w, value_w, head_order = cosine_similarity(
        attention.q.weight, attention.k.weight, attention.v.weight
    )
    for projection, weight in zip(projections, (query_w, key_w, value_w)):
        projection.weight = torch.nn.Parameter(weight)

    # A bias has one entry per output feature, so it must be permuted in
    # per-head chunks. Indexing it directly with the head ids would select
    # len(head_order) scalars and shrink the bias.
    for projection, bias in zip(projections, biases):
        if bias is not None:
            projection.bias = torch.nn.Parameter(_reorder_bias_by_head(bias, head_order))

    # NOTE(review): only q, k and v are permuted. The output projection and
    # any per-head relative position bias keep their original head order, so
    # the reordered model presumably no longer computes the same function as
    # the input model — confirm whether that is intended before evaluating it.


def convert_t5_to_gqa(module, kv_heads: int, similarity_flag: bool = True, inplace: bool = False):
    """Reorder the attention heads of the components listed in ``transfer_to_gqa``.

    For every block of every listed component, the attention modules are
    appended to the module-level ``tf_attention_list`` and their query, key and
    value projections are reordered with ``cosine_similarity`` so that similar
    heads become adjacent. ``"decoder"`` covers self-attention and
    cross-attention of each block, ``"encoder"`` covers self-attention; other
    names are ignored.

    Args:
        module: Model exposing the listed components as attributes, each with a
            ``block`` sequence.
        kv_heads (int): Number of key-value heads. Currently unused.
        similarity_flag (bool, optional): When False, no reordering is done and
            the module (or its copy) is returned untouched. Defaults to True.
        inplace (bool, optional): Modify ``module`` itself instead of a deep
            copy. Defaults to False.

    Returns:
        ``module`` itself when ``inplace`` is True, otherwise a deep copy of it,
        with the selected attention projections reordered.
    """
    out = module if inplace else deepcopy(module)
    if not similarity_flag:
        return out

    for component_name in transfer_to_gqa:
        component = getattr(out, component_name)
        for block in component.block:
            attentions = _attention_modules(component_name, block)
            tf_attention_list.extend(attentions)
            for attention in attentions:
                _regroup_attention_heads(attention)

    return out

In [184]:
# Produce a reordered copy of t5 (inplace defaults to False, so t5 itself is left as is).
grouped_model = convert_t5_to_gqa(t5,8)

## FAISS KNN WITH GPU

In [1]:
import torch

# Toy dimensions for the FAISS experiment. h=8 and d=512 match the head count
# and weight width printed for t5-small earlier in this notebook.
# NOTE(review): b and n are not used in the cells shown here — presumably batch
# size and sequence length; confirm or remove.
b = 32
h = 8
n = 256
d = 512

# Random (d, d) half-precision matrix created on the GPU, then cut into h
# equal slices along dim 1.
vals = torch.randn(d,d, device="cuda", dtype=torch.float16)
split_vals = torch.tensor_split(vals,h,dim=1)

In [2]:
# torch.tensor_split returns a tuple of tensors, not a single tensor.
type(split_vals)

tuple

In [3]:
# Each element of that tuple is a tensor.
type(split_vals[0])

torch.Tensor

In [4]:
from faiss_knn_wrapper import FaissKNNClassifier

# Create the FAISS KNN classifier on the GPU.
# NOTE(review): the meaning of the positional 3 and of n_cells/algorithm is
# defined by faiss_knn_wrapper, which is not visible here — presumably 3 is the
# number of neighbours; confirm against the wrapper.
fknn = FaissKNNClassifier(3,device="cuda",n_cells=4,algorithm="voronoi")

In [5]:
def tuple_of_tensors_to_tensor(tuple_of_tensors):
    """Combine equally-shaped tensors into one tensor with a new leading axis.

    Args:
        tuple_of_tensors: Iterable of tensors that all share the same shape.

    Returns:
        A tensor whose first dimension indexes the input tensors in order.
    """
    tensors = [*tuple_of_tensors]
    stacked = torch.stack(tensors, dim=0)
    return stacked

In [6]:
# Stack the h per-head slices into a single tensor with a new leading axis;
# the result is 3-dimensional (see the len(fit_vals.shape) check further down).
fit_vals = tuple_of_tensors_to_tensor(split_vals)

In [7]:
# This call fails with the TypeError shown below.
# NOTE(review): fit_vals is a 3-D float16 CUDA tensor. FAISS indexes generally
# take a 2-D float32 array of shape (n_vectors, dim), so flattening each head to
# one row (e.g. fit_vals.reshape(h, -1)) and converting dtype/device is probably
# required — confirm against the wrapper's fit() signature, which may also
# expect labels.
fknn.fit(fit_vals)

TypeError: Wrong number or type of arguments for overloaded function 'new_GpuIndexIVFFlat'.
  Possible C/C++ prototypes are:
    faiss::gpu::GpuIndexIVFFlat::GpuIndexIVFFlat(faiss::gpu::GpuResourcesProvider *,faiss::IndexIVFFlat const *,faiss::gpu::GpuIndexIVFFlatConfig)
    faiss::gpu::GpuIndexIVFFlat::GpuIndexIVFFlat(faiss::gpu::GpuResourcesProvider *,faiss::IndexIVFFlat const *)
    faiss::gpu::GpuIndexIVFFlat::GpuIndexIVFFlat(faiss::gpu::GpuResourcesProvider *,int,int,faiss::MetricType,faiss::gpu::GpuIndexIVFFlatConfig)
    faiss::gpu::GpuIndexIVFFlat::GpuIndexIVFFlat(faiss::gpu::GpuResourcesProvider *,int,int,faiss::MetricType)


In [25]:
# Number of dimensions of the tensor that was passed to fknn.fit.
len(fit_vals.shape)

3