In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaModel, LlamaConfig

import torch
from torch import nn
from torch.functional import F

from transformers.models.llama.modeling_llama import LlamaConfig, LlamaModel, LlamaRotaryEmbedding, LlamaLinearScalingRotaryEmbedding, \
LlamaDynamicNTKScalingRotaryEmbedding, LlamaAttention, LlamaRMSNorm, LlamaMLP

from transformers.models.llama.configuration_llama import LlamaConfig

from typing import Optional, Tuple


from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token="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",)

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [7]:
lc = LlamaConfig()
# lm = LlamaModel(config=lc)

In [4]:
custom_llama_config = LlamaConfig(hidden_size=2048)
CustomLLamaAttn = LlamaAttention(config=custom_llama_config)


class CustomLLamaMLP(nn.Module):
    def __init__(self, original_mlp):
        super(CustomLLamaMLP, self).__init__()
        # Initialize the layers using the existing MLP layers
        self.gate_proj = original_mlp.gate_proj
        self.up_proj = original_mlp.up_proj
        # Copying down_proj from the original MLP
        self.down_proj = original_mlp.down_proj
        # Inserting your new layer here, scaling down to 2048
        self.downscale_proj = nn.Linear(4096, 2048)
        # Keeping the original activation function
        self.act_fn = original_mlp.act_fn

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        downscale_proj = self.act_fn(self.downscale_proj(down_proj))
        return downscale_proj
    

class CustomLlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LlamaAttention(config=lc)

        self.mlp = CustomLLamaMLP(LlamaMLP(config))
        self.residual_mlp = CustomLLamaMLP(LlamaMLP(config)) # new MLP to compress the residuals
        self.input_layernorm = LlamaRMSNorm(hidden_size=4096, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(hidden_size=4096, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        if "padding_mask" in kwargs:
            pass
            # warnings.warn(
            #     "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            # )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )

        hidden_states = residual + hidden_states
        
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)

        residual = self.residual_mlp(residual) # scale down the residuals to the shape of the hidden_states
        
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs


# custom_llama_config = LlamaConfig(hidden_size=2048)
# CustomLLamaAttn = LlamaAttention(config=custom_llama_config)

Instantiating LlamaAttention without passing a `layer_idx` is not recommended and will lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` when creating this class.


In [22]:
# model.model.layers[0] = CustomLlamaDecoderLayer(config=lc, layer_idx=0)

In [5]:
model.model

LlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

In [8]:
# get layer name in string

from transformers.models.llama.modeling_llama import LlamaMLP, LlamaDecoderLayer
# from from transformers.models.llama.modeling_llama.LlamaMLP import down

# add a linear layer to the end of every MLP, before the last activation named act_fn

for i, layer in enumerate(model.model.layers):
    if isinstance(layer, LlamaDecoderLayer):
        if i != 0:
            model.model.layers[i] = LlamaDecoderLayer(config=custom_llama_config, layer_idx=i)
            layer.self_attn = CustomLLamaAttn
            layer.mlp = CustomLLamaMLP(layer.mlp)
        if i == 0:
            print("Here")
            model.model.layers[0] = CustomLlamaDecoderLayer(config=lc, layer_idx=i)

# upscale to 4069 from 2048
model.lm_head = nn.Sequential(
    nn.Linear(2048, 4096, bias=False),
    nn.SiLU(),
    model.lm_head
)

model.norm = LlamaRMSNorm(hidden_size=2048, eps=1e-6)

Here


In [18]:
# model.model.layers[0] = CustomLlamaDecoderLayer(config=lc, layer_idx=i)
model.model.norm = LlamaRMSNorm(hidden_size=2048, eps=1e-6)

In [13]:
model.model


LlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0): CustomLlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): CustomLLamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (downscale_proj): Linear(in_features=4096, out_features=2048, bias=True)
        (act_fn): SiLU()
      )
      (residual_mlp): CustomLLamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_pr

In [10]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

In [21]:
tokens = tokenizer("Hello, my dog is cute", return_tensors="pt")
out = model.generate(**tokens, max_length=10, use_cache=False)

print(tokenizer.decode(out[0]))

<s> Hello, my dog is cuteios commercial


In [None]:
sum(p.numel() for p in model.parameters())

In [102]:
model.model.layers[0].self_attn

LlamaAttention(
  (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
  (rotary_emb): LlamaRotaryEmbedding()
)