In [1]:
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm, LlamaModel, LlamaSdpaAttention, LlamaPreTrainedModel
from transformers.cache_utils import Cache
import torch.nn as nn

from typing import List, Optional, Tuple, Union
import torch

class LlamaForCausalLMResearch(LlamaForCausalLM):
    """`LlamaForCausalLM` whose decoder is a `LlamaModelResearch` (research layers exposing `middleware`)."""

    def __init__(self, config):
        # Skip LlamaForCausalLM.__init__ (which presumably builds the stock decoder) and run the
        # LlamaPreTrainedModel initializer directly; the submodules are assembled below instead.
        LlamaPreTrainedModel.__init__(self, config)

        hidden, vocab = config.hidden_size, config.vocab_size
        self.model = LlamaModelResearch(config)
        self.vocab_size = vocab
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

        # Weight initialization and final processing.
        self.post_init()

        # Scratch dict for captured intermediates (not written by this class itself).
        self.middleware = {}

class LlamaModelResearch(LlamaModel):
    """`LlamaModel` whose decoder stack is made of `LlamaDecoderLayerResearch` layers."""

    def __init__(self, config):
        # Bypass LlamaModel.__init__ so the stock decoder layers are never built; initialize the
        # LlamaPreTrainedModel base directly and construct every submodule here.
        LlamaPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            LlamaDecoderLayerResearch(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Weight initialization and final processing.
        self.post_init()

        # Scratch dict for captured intermediates (not written by this class itself).
        self.middleware = {}

class LlamaDecoderLayerResearch(LlamaDecoderLayer):
    """`LlamaDecoderLayer` whose self-attention is always `LlamaSdpaAttentionResearch`; forward is inherited."""

    def __init__(self, config, layer_idx):
        # Initialize only nn.Module, skipping LlamaDecoderLayer.__init__ (which presumably picks the
        # stock attention class), so the research attention can be installed instead.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size

        eps = config.rms_norm_eps
        self.self_attn = LlamaSdpaAttentionResearch(config, layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=eps)

class LlamaSdpaAttentionResearch(LlamaSdpaAttention):
    """
    Research variant of `LlamaSdpaAttention` that records intermediate state.

    Projections and weights are inherited unchanged. The forward pass mirrors the SDPA implementation and,
    after every call, stores the cache object it used in `self.middleware["past_key_value"]` so it can be
    inspected from outside (e.g. `model.model.layers[0].self_attn.middleware`).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Intermediates captured during `forward`, keyed by name (currently only "past_key_value").
        self.middleware = {}

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Self-attention via `torch.nn.functional.scaled_dot_product_attention`.

        Args:
            hidden_states: input of shape (bsz, q_len, hidden_size).
            attention_mask: optional 4-D mask; its last dimension is cropped to the key length.
            position_ids: positions fed to the rotary embedding.
            past_key_value: optional `Cache`, updated via `.update(..., self.layer_idx, ...)` with this
                call's keys/values.
            output_attentions: when True the parent class's `forward` is used instead, since the SDPA
                path below cannot return attention weights.
            use_cache: only forwarded on the `output_attentions` fallback path.
            cache_position: passed to the cache update (needed by static caches).

        Returns:
            `(attn_output, attn_weights, past_key_value)`; `attn_weights` is None on the SDPA path.
            On both paths the returned cache is also stored in `self.middleware["past_key_value"]`.
        """
        if output_attentions:
            outputs = super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            # Keep `middleware` in sync on this path too; otherwise it would still hold the cache from an
            # earlier SDPA call. The cache is the last element of the returned tuple.
            self.middleware.update({"past_key_value": outputs[-1]})
            return outputs

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # In case static cache is used, it is an instance attribute.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand grouped key/value heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Drop mask columns beyond the current key length.
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # Without a mask SDPA would attend bidirectionally. When query and key lengths match (no cached
        # prefix), `is_causal=True` is exactly the lower-triangular causal mask, so request it for
        # multi-token inputs. With an explicit mask this stays False and the mask alone decides.
        is_causal = causal_mask is None and q_len > 1 and key_states.shape[-2] == q_len

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # Expose the cache used by this call for external inspection.
        self.middleware.update({"past_key_value": past_key_value})
        return attn_output, None, past_key_value

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model and tokenizer come from the same checkpoint (local path or hub id).
checkpoint = "llama-2-7b-hf"
model = LlamaForCausalLMResearch.from_pretrained(checkpoint)
tokenizer = LlamaTokenizer.from_pretrained(checkpoint)

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.24it/s]


In [3]:
# Display the decoder stack to confirm the research layer/attention classes are in place.
model.get_submodule("model.layers")

ModuleList(
  (0-31): 32 x LlamaDecoderLayerResearch(
    (self_attn): LlamaSdpaAttentionResearch(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)

In [4]:
from transformers import pipeline as make_pipeline
import torch

device = "cuda:0"
SEED = 0

# The factory is imported under an alias so that re-running this cell does not call the pipeline
# *object* that `pipeline` is bound to below (later cells use that name for the object).
# NOTE(review): `torch_dtype` presumably has no effect when `model` is an already-loaded module
# (its dtype is fixed by `from_pretrained`); confirm, and move the dtype there if fp16 is intended.
pipeline = make_pipeline(
    'text-generation',
    model=model,
    torch_dtype=torch.float16,
    device=device,
    tokenizer=tokenizer,
)

sequence = "The temperature in transformers.pipeline means"
# Seed sampling so the generated text (and the cache inspected afterwards) is reproducible
# on the same hardware/software stack.
torch.manual_seed(SEED)
output = pipeline(sequence, max_length=512, do_sample=True, temperature=0.9)
print(output)

The model 'LlamaForCausalLMResearch' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForCausalLM', 'MvpForCausalLM', 'Open

[{'generated_text': 'The temperature in transformers.pipeline means, to 110°F (43°C) on the coldest day of the year. The warmest temperature in transformers.pipeline falls to 55°F (12°C) on the warmest days, with a median temperature of 73°F (23°C). Weeks with chilly weather are listed below for the city. Weather is more likely to be rain. Snow is rare in this location. It rarely snows in transformers.pipeline. If you are planning to visit transformers.pipeline, you can assume that it will be somewhat on the chilly side and somewhat rainy.\nWhat is the rainiest month in transformers.pipeline?\nThe wettest months in transformers.pipeline are June and July, during which time transformers.pipeline regularly aggregates up to 4" (102mm) of precipitation.\nIs it cold in transformers.pipeline in November?\nIn transformers.pipeline, the average high-temperature in November is essentially the same as in October - a still moderate 69.8°F (21°C). The average low-temperature, in transformers.pipel

In [5]:
# Attention module of the first decoder layer; its `middleware` dict is inspected below.
attn = model.get_submodule("model.layers.0.self_attn")

在 Llama 的 transformers 实现中，KV cache 由注意力层 `forward` 的参数（同时也是返回值）`past_key_value` 表示。对于生成式模型，它是一个 `DynamicCache` 实例，该类的定义（节选）如下：
```python
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
```

In [23]:
# A static cache would be stored as an attribute on the attention module itself; check for one.
static_cache = getattr(attn, "past_key_value", None)
static_cache

In [7]:
# Alias the dict into which the research attention layer writes `past_key_value` during forward.
m = attn.middleware
m["past_key_value"]

DynamicCache()

In [15]:
# Cached keys: a list with one tensor per layer (see the `DynamicCache` excerpt above).
m["past_key_value"].key_cache

[tensor([[[[-0.5439,  0.2430, -0.2144,  ...,  0.3013, -0.2847,  0.3466],
           [-0.2758,  0.0278, -0.0080,  ...,  0.4342, -0.1943,  0.4467],
           [ 0.7921, -0.3864, -0.3918,  ..., -0.3684,  0.3204, -0.3939],
           ...,
           [-0.0125, -0.4789, -0.0823,  ...,  0.1941,  0.1527,  0.2773],
           [ 0.0556, -0.8214,  0.4442,  ..., -0.4149,  0.2700, -0.3972],
           [ 0.0697, -0.1749,  0.4911,  ...,  0.2468,  0.0599,  0.2941]],
 
          [[ 0.6405,  0.8949,  0.4003,  ..., -0.8481,  0.2465, -0.6571],
           [ 0.7500,  0.6643, -0.3054,  ..., -0.0711,  0.3391, -0.0927],
           [-1.1822, -0.9251,  0.2281,  ...,  0.5307, -0.5518,  0.5414],
           ...,
           [ 0.2070,  0.5067,  0.3514,  ..., -0.3669,  0.4428, -0.3299],
           [ 0.4137, -0.0047, -0.2174,  ...,  0.5443, -0.5142,  0.5229],
           [ 0.7394,  0.4232, -0.4000,  ..., -0.3109,  0.3955, -0.2921]],
 
          [[-0.1716,  0.1786,  0.3422,  ...,  1.6794,  1.7912,  1.6603],
           [-

In [17]:
# Cached values: a list with one tensor per layer (see the `DynamicCache` excerpt above).
m["past_key_value"].value_cache

[tensor([[[[ 1.1425e-02, -1.3230e-02,  2.6524e-03,  ...,  1.1242e-02,
            -4.3062e-03, -7.1702e-03],
           [ 7.1862e-03, -7.9505e-04,  7.7968e-03,  ...,  8.3490e-04,
             5.7170e-03, -1.1819e-02],
           [ 3.0127e-03, -4.1507e-03, -6.8740e-03,  ...,  1.4032e-03,
             2.2372e-03,  5.5290e-04],
           ...,
           [ 8.7727e-03,  9.9741e-03,  1.5922e-04,  ...,  3.1478e-03,
            -4.5694e-03, -7.9014e-03],
           [ 3.0127e-03, -4.1507e-03, -6.8740e-03,  ...,  1.4032e-03,
             2.2372e-03,  5.5290e-04],
           [ 1.5638e-02,  4.9682e-03,  2.6889e-03,  ...,  1.0554e-02,
             1.4971e-02, -9.4085e-03]],
 
          [[ 6.7095e-03,  2.1722e-03,  6.0971e-03,  ...,  4.5968e-03,
             1.4262e-03, -2.3720e-03],
           [ 8.7663e-03,  4.5856e-03, -4.4579e-03,  ...,  5.2764e-03,
             4.9614e-03, -5.3013e-03],
           [-1.5921e-04,  5.0845e-04, -5.8811e-04,  ...,  1.5965e-03,
             3.6433e-04, -1.0041e-03],


In [24]:
# Number of tokens held in the cache, via the public API instead of the private `_seen_tokens` counter.
m['past_key_value'].get_seq_length()

511


In [29]:
sequence = "SEIEE in SJTU is"
# Seed sampling so the generated text (and the cache length inspected next) is reproducible
# on the same hardware/software stack.
torch.manual_seed(0)
output = pipeline(sequence, max_length=512, do_sample=True, temperature=0.9)
print(output)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


[{'generated_text': 'SEIEE in SJTU is the only international program that meets the needs of the development of this discipline and the demand of international talents and is a prerequisite for the Chinese government to become a member of the global innovation network.\nJia Xingcheng, professor and dean of the department said, “the internationalization of Chinese engineering education is a new topic that everyone needs to pay attention to. According to the trend of engineering talent internationalization, the internationalization of engineering education is a new direction of the development of professional engineering. The internationalization of Chinese engineering education should not only focus on international mobility and transnational collaboration, but also include the establishment of international education systems, the development and improvement of international curriculum and teaching modes, the establishment of international engineering talents and competency evaluation s

In [30]:
# Number of tokens held in the cache, via the public API instead of the private `_seen_tokens` counter.
m['past_key_value'].get_seq_length()

458


# Evaluation: perplexity on PTB

In [11]:
from perplexity import perplexity

dataset = 'PTB'
device = 'cuda:0'
# NOTE(review): '~' is passed through unexpanded; confirm `perplexity` applies os.path.expanduser.
root = '~/'

# Stride handed to `perplexity`: the model's context size (4096 for llama-2-7b-hf).
stride = model.config.max_position_embeddings
ppl_baseline = perplexity(
    model,
    tokenizer,
    dataset,
    device,
    verbose=True,
    stride=stride,
    root=root,
)
print(ppl_baseline)


[117.388s]     Loading dataset PTB
[117.398s]     Encoding dataset


 96%|█████████▋| 27/28 [01:35<00:03,  3.55s/it]


28.380826950073242
