In [1]:
# Authenticates with the Hugging Face Hub using provided API key

import os
from pathlib import Path

from huggingface_hub import login

# Read the token from the environment (never hard-code it). HF_TOKEN is the variable
# huggingface_hub itself recognises, so accept it as a fallback. If neither is set,
# login(token=None) falls back to huggingface_hub's interactive login.
api_key = os.getenv('HUGGINGFACE_API_KEY') or os.getenv('HF_TOKEN')

login(token=api_key)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /home/drossini/.cache/huggingface/token
Login successful


In [2]:
# Loading the model to test the answers with the dataset

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, pipeline

# Source checkpoint: the multi-head-attention model whose key/value heads are pooled below.
model_name = "microsoft/Phi-3-mini-128k-instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Allow transformers to execute the modelling code published alongside the checkpoint.
mha_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [3]:
# Inspect the module tree of the loaded source (MHA) model.
# NOTE(review): the saved output below shows a Qwen2ForCausalLM, not the Phi-3 checkpoint
# named in the loading cell — presumably stale from an earlier run; re-run to refresh it.
mha_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Line

**Try: convert the MHA checkpoint to GQA by mean-pooling groups of key/value heads**

In [4]:
# Converted weights are written into this dict; it starts as a shallow copy of the MHA state dict
# (converted entries are replaced, never modified in place).
attentions_wts = mha_model.state_dict().copy()

# Read the head count from the loaded model instead of hard-coding it: a fixed value (previously 16)
# silently splits K/V into wrongly sized "heads" whenever the checkpoint has a different head count.
num_heads = mha_model.config.num_attention_heads
num_kv_heads = 4  # target number of shared key/value heads in the GQA model

# The pooling below assumes every query head has its own K/V head in the source model.
source_kv_heads = getattr(mha_model.config, "num_key_value_heads", None) or num_heads
if source_kv_heads != num_heads:
    raise ValueError(
        f"source model already uses {source_kv_heads} KV heads for {num_heads} query heads; expected MHA"
    )
if num_heads % num_kv_heads != 0:
    raise ValueError(f"num_heads={num_heads} is not divisible by num_kv_heads={num_kv_heads}")

# Number of consecutive heads averaged into each shared key/value head (heads per group,
# despite the name).
gqa_groups = num_heads // num_kv_heads

In [5]:
def split_attention_to_heads(input_tensor, num_splits):
    """Split a projection tensor into ``num_splits`` equal chunks along dim 0.

    Used to separate a key/value projection into per-head slices. Dim 0 is the only
    dimension inspected, so 2-D weight matrices and 1-D bias vectors both work.

    Args:
        input_tensor: tensor with at least one dimension; dim 0 is the one split.
        num_splits: number of chunks (heads); must be positive and divide
            ``input_tensor.shape[0]``.

    Returns:
        Tuple of ``num_splits`` views of ``input_tensor``, each with
        ``input_tensor.shape[0] // num_splits`` rows.

    Raises:
        ValueError: if the tensor is 0-dimensional, ``num_splits`` is not positive,
            or the number of rows is not divisible by ``num_splits``.
    """
    if input_tensor.dim() == 0:
        raise ValueError("Cannot split a 0-dimensional tensor")
    if num_splits <= 0:
        raise ValueError(f"num_splits must be positive, got {num_splits}")

    rows = input_tensor.shape[0]
    if rows % num_splits != 0:
        raise ValueError(
            "Number of rows is not divisible by the number of splits "
            f"({rows} rows, {num_splits} splits)"
        )

    # With an exact division, chunk returns exactly num_splits equally sized views.
    return input_tensor.chunk(num_splits, dim=0)

In [6]:
def average_heads(tensor_tuple, group_size, dtype=None):
    """Mean-pool consecutive groups of per-head tensors into shared heads.

    Heads ``[0, group_size)`` become the first output, ``[group_size, 2*group_size)``
    the second, and so on. This contiguous grouping matches Hugging Face's ``repeat_kv``,
    where KV head ``j`` serves query heads ``j*group_size .. (j+1)*group_size - 1``.

    Args:
        tensor_tuple: sequence of equally shaped tensors, one per head.
        group_size: number of consecutive heads averaged into one output; must be
            positive and divide ``len(tensor_tuple)``.
        dtype: dtype passed to ``torch.mean`` (``None`` keeps torch's default).

    Returns:
        Tuple of ``len(tensor_tuple) // group_size`` averaged tensors.

    Raises:
        ValueError: if ``group_size`` is not positive, or the heads cannot be split
            into equal groups (a short final group would create an unequal KV head).
    """
    if group_size <= 0:
        raise ValueError(f"group_size must be positive, got {group_size}")
    if len(tensor_tuple) % group_size != 0:
        raise ValueError(
            f"{len(tensor_tuple)} heads cannot be split into equal groups of {group_size}"
        )

    return tuple(
        torch.mean(torch.stack(tensor_tuple[start:start + group_size]), dim=0, dtype=dtype)
        for start in range(0, len(tensor_tuple), group_size)
    )

In [7]:
# # Process the weights for non-Phi models (separate k_proj / v_proj projections)

# for name_wts in list(attentions_wts.keys()):
#     if len(attentions_wts[name_wts].shape) >= 2:
#         tensor_to_process = attentions_wts[name_wts].clone()
#         torch_dtype = tensor_to_process.dtype
        
#         # Process k_proj weights
#         if "k_proj" in name_wts:
#             attn_heads = split_attention_to_heads(tensor_to_process, num_splits=num_heads)
#             gqa_tensors_grouped = average_heads(attn_heads, gqa_groups, dtype=torch_dtype)
#             new_key = torch.cat(gqa_tensors_grouped)
#             attentions_wts[name_wts] = new_key
        
#         # Process v_proj weights
#         elif "v_proj" in name_wts:
#             attn_heads = split_attention_to_heads(tensor_to_process, num_splits=num_heads)
#             gqa_tensors_grouped = average_heads(attn_heads, gqa_groups, dtype=torch_dtype)
#             new_value = torch.cat(gqa_tensors_grouped)
#             attentions_wts[name_wts] = new_value

# # Process the biases
# for name_bias in list(attentions_wts.keys()):
#     if "bias" in name_bias:
#         bias_tensor_to_process = attentions_wts[name_bias].clone()
#         torch_dtype = bias_tensor_to_process.dtype
        
#         # Process k_proj biases
#         if "k_proj" in name_bias:
#             # Assumes biases can be split similarly, typically biases are 1D
#             attn_heads = split_attention_to_heads(bias_tensor_to_process.unsqueeze(1), num_splits=num_heads)
#             gqa_tensors_grouped = average_heads(attn_heads, gqa_groups, dtype=torch_dtype)
#             new_key_bias = torch.cat(gqa_tensors_grouped).squeeze(1)  # Remove the added dimension
#             attentions_wts[name_bias] = new_key_bias
        
#         # Process v_proj biases
#         elif "v_proj" in name_bias:
#             attn_heads = split_attention_to_heads(bias_tensor_to_process.unsqueeze(1), num_splits=num_heads)
#             gqa_tensors_grouped = average_heads(attn_heads, gqa_groups, dtype=torch_dtype)
#             new_value_bias = torch.cat(gqa_tensors_grouped).squeeze(1)  # Remove the added dimension
#             attentions_wts[name_bias] = new_value_bias

In [8]:
# Process the weights for Phi models.
# Each `qkv_proj` weight stacks the query, key and value projections along dim 0 as [Q; K; V].
# Q is kept unchanged; K and V are split into `num_heads` heads and every consecutive group of
# `gqa_groups` heads is mean-pooled into one shared head.

# Read from the untouched source model (not from `attentions_wts`) so that re-running this cell
# converts the original MHA weights again instead of pooling already-pooled ones.
mha_state_dict = mha_model.state_dict()

for name_wts, qkv_tensor in mha_state_dict.items():
    if "qkv_proj" not in name_wts:
        continue

    # Derive the Q/K/V boundaries from the tensor instead of hard-coding 3072-row slices.
    if qkv_tensor.shape[0] % 3 != 0:
        raise ValueError(
            f"{name_wts}: expected equally sized Q, K and V blocks, got shape {tuple(qkv_tensor.shape)}"
        )
    query_mha, key_mha, value_mha = qkv_tensor.chunk(3, dim=0)

    pooled_kv = []
    for proj_mha in (key_mha, value_mha):
        attn_heads = split_attention_to_heads(proj_mha, num_splits=num_heads)
        gqa_tensors_grouped = average_heads(attn_heads, gqa_groups, dtype=proj_mha.dtype)
        pooled_kv.append(torch.cat(gqa_tensors_grouped))
    new_key, new_value = pooled_kv

    # torch.cat allocates a new tensor, so the source model's parameters are never modified.
    attentions_wts[name_wts] = torch.cat((query_mha, new_key, new_value), dim=0)

In [9]:
# Build an untrained GQA skeleton of the same architecture, with fewer key/value heads;
# the mean-pooled weights are loaded into it further below.
# `model_name` and `tokenizer` come from the cell that loads the source model.

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# Derive the KV-head count from the same numbers used for pooling, so the skeleton's
# qkv_proj shape always matches the converted `attentions_wts`.
gqa_config = AutoConfig.from_pretrained(
    model_name,
    num_key_value_heads=num_heads // gqa_groups,
    trust_remote_code=True,  # same code path as `mha_model`, so state-dict keys line up
)
gqa_model = AutoModelForCausalLM.from_config(gqa_config, trust_remote_code=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
# Inspect the GQA skeleton; its k/v projections should be smaller than the MHA model's.
# NOTE(review): the saved output below shows a Qwen2ForCausalLM, not Phi-3 — presumably
# stale from an earlier run; re-run to refresh it.
gqa_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=256, bias=True)
          (v_proj): Linear(in_features=1024, out_features=256, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear

In [12]:
# GQA Phi-3: original weights, except that each qkv_proj has its K and V heads mean-pooled.
# load_state_dict is strict by default, so any missing/unexpected key or shape mismatch raises.

gqa_model.load_state_dict(attentions_wts)

<All keys matched successfully>

In [13]:
# Put the converted model in evaluation mode (training=False); the call also returns the module.
gqa_model.eval() 

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 1024)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2SdpaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
          (k_proj): Linear(in_features=1024, out_features=256, bias=True)
          (v_proj): Linear(in_features=1024, out_features=256, bias=True)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (rotary_emb): Qwen2RotaryEmbedding()
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2816, bias=False)
          (down_proj): Linear(in_features=2816, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm()
        (post_attention_layernorm): Qwen2RMSNorm()
      )
    )
    (norm): Qwen2RMSNorm()
  )
  (lm_head): Linear

In [15]:
# Save the converted model (weights + config) together with its tokenizer so the directory can be
# reloaded directly with `from_pretrained`. Path.home() replaces a relative "../home/<user>/..."
# path, which resolved against the notebook's working directory rather than the home directory.
output_dir = Path.home() / "models" / "phi3_mean_pooled"
gqa_model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)