In [8]:
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm, LlamaModel, LlamaSdpaAttention, LlamaPreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
import torch.nn as nn

from typing import List, Optional, Tuple, Union
import torch

class LlamaForCausalLMResearch(LlamaForCausalLM):
    """Llama causal-LM wrapper that uses `LlamaModelResearch` as its backbone.

    Besides the language-model head it keeps a forward-call counter (`fwcall`) and
    prints the key-cache shape of the first layer on every call to `forward`.
    """

    def __init__(self, config):
        # Two-argument super() resumes the MRO lookup *after* LlamaForCausalLM, so that
        # class's own __init__ is bypassed and the sub-modules are created here instead.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlamaModelResearch(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        # Empty scratch dict; nothing in this class writes to it.
        self.middleware = {}
        # Number of forward() calls made on this instance so far.
        self.fwcall = 0

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the backbone, project its output to vocabulary logits and optionally compute the LM loss.

        Every call increments `self.fwcall` and prints the call number together with the shape of
        `past_key_values[0][0]` (or `None` when no cache was passed).

        Args:
            labels (`torch.LongTensor`, *optional*):
                Target token ids. They are shifted by one position inside this method so that the
                logits at position `i` are scored against the label at position `i + 1`. The loss
                is a default `CrossEntropyLoss`, which ignores targets equal to `-100`.

            All other arguments are forwarded unchanged to `self.model`. `output_attentions`,
            `output_hidden_states` and `return_dict` fall back to the corresponding values on
            `self.config` when left as `None`.

        Returns:
            A `CausalLMOutputWithPast` when `return_dict` resolves to true; otherwise a tuple
            `(logits, *backbone_outputs[1:])`, prefixed with the loss when `labels` was given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        self.fwcall += 1
        print(f"fwcall: {self.fwcall}, key cache size(one layer): {past_key_values[0][0].shape if past_key_values is not None else None}")
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Split the head weight along dim 0 into `pretraining_tp` slices, apply each slice
            # separately and concatenate the partial logits on the last dimension.
            # `nn.functional` is used because no `F` alias is imported in this file.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [nn.functional.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        # Logits are always returned in float32, whatever dtype the head produced.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class LlamaModelResearch(LlamaModel):
    """`LlamaModel` subclass whose decoder stack is built from `LlamaDecoderLayerResearch`."""

    def __init__(self, config):
        # Two-argument super() starts the MRO lookup after LlamaModel, so LlamaModel's own
        # __init__ is not executed; every sub-module is created explicitly below.
        super(LlamaModel, self).__init__(config)

        pad_id = config.pad_token_id
        self.padding_idx = pad_id
        self.vocab_size = config.vocab_size

        # Token embedding table, created before the layers as in the original ordering.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, pad_id)

        # One research decoder layer per hidden layer, each given its own index.
        decoder_stack = []
        for layer_idx in range(config.num_hidden_layers):
            decoder_stack.append(LlamaDecoderLayerResearch(config, layer_idx))
        self.layers = nn.ModuleList(decoder_stack)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

        # Empty scratch dict; nothing in this class writes to it.
        self.middleware = {}

class LlamaDecoderLayerResearch(LlamaDecoderLayer):
    """Decoder layer that uses `LlamaSdpaAttentionResearch` as its self-attention module."""

    def __init__(self, config, layer_idx):
        # Two-argument super() starts the MRO lookup after LlamaDecoderLayer, so that class's
        # __init__ is skipped and the sub-modules are assembled here.
        super(LlamaDecoderLayer, self).__init__()

        width = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.hidden_size = width

        # Research attention module; swap in the stock `LlamaSdpaAttention(config, layer_idx)`
        # here to compare against the unmodified implementation.
        self.self_attn = LlamaSdpaAttentionResearch(config, layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(width, eps=norm_eps)

class LlamaSdpaAttentionResearch(LlamaSdpaAttention):
    """
    SDPA-based Llama attention that records intermediate values for inspection.

    On every forward pass that does not request attention weights, the module stores in
    `self.middleware`:

    - `"query_states"`, `"key_states"`, `"value_states"`: the outputs of `q_proj`, `k_proj` and
      `v_proj`, captured before they are reshaped into heads and before rotary embedding.
    - `"past_key_value"`: the cache object used for this call (after it has been updated with the
      new keys/values), or `None` when no cache is in use.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Holds the values captured by the most recent forward pass.
        self.middleware = {}

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """
        Compute self-attention with `scaled_dot_product_attention` and record intermediates.

        Args:
            hidden_states: 3-D input; its sizes are read as `(bsz, q_len, _)`.
            attention_mask: optional mask; its last dimension is cropped to the key length.
            position_ids: positions passed to the rotary embedding.
            past_key_value: optional cache; when present it is updated with the new keys/values.
            output_attentions: when true, the call is delegated to the parent implementation and
                `self.middleware` is not touched by this method.
            use_cache: only used on the delegated path.
            cache_position: passed to the cache update.

        Returns:
            `(attn_output, None, past_key_value)` on the SDPA path; attention weights are never
            returned from this path.
        """
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Capture the raw projections (still flat, no rotary embedding applied yet).
        self.middleware.update({"query_states": query_states, "key_states": key_states, "value_states": value_states})

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # In case static cache is used, it is an instance attribute.
        past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expose the cache so it can be inspected through `middleware["past_key_value"]`
        # after a forward pass (stored even when it is None).
        self.middleware["past_key_value"] = past_key_value

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and causal_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
        # relying on the `is_causal` argument.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=causal_mask is None and q_len > 1,
        )

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden_size)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

In [9]:
# Load the research model and the matching tokenizer from the same "llama-2-7b-hf" identifier.
model = LlamaForCausalLMResearch.from_pretrained(
    pretrained_model_name_or_path="llama-2-7b-hf",
)
tokenizer = LlamaTokenizer.from_pretrained(
    pretrained_model_name_or_path="llama-2-7b-hf",
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.21it/s]


In [10]:
# Display the decoder layer stack; the repr shows which layer and attention classes are in use.
model.model.layers

ModuleList(
  (0-31): 32 x LlamaDecoderLayerResearch(
    (self_attn): LlamaSdpaAttentionResearch(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)

先用4096个token生成KV Cache，用来训练PQ（Product Quantization，乘积量化）。模型共32层、每层32个head，因此一共需要32×32=1024个PQ索引。

In [11]:
from transformers import pipeline
import torch

# device = "cpu"
# Device the pipeline runs on.
device = "cuda:0"

# Text-generation pipeline built around the already-loaded `model` and `tokenizer`.
# NOTE(review): `model` is passed as an instantiated object, so `torch_dtype` may not
# convert its weights to float16 — confirm the dtype actually used at inference time.
pipe = pipeline(
    task='text-generation',
    model=model,
    tokenizer=tokenizer,
    device=device,
    torch_dtype=torch.float16,
)

# Prompt that is fed to the pipeline.
sequence = "SJTU SEIEE is"
# sequence =  "In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models outperform open-source chat models on most benchmarks we tested, and based on our human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. We provide a detailed description of our approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on our work and contribute to the responsible development of LLMs."
# sequence += "Large Language Models (LLMs) have shown great promise as highly capable AI assistants that excel in complex reasoning tasks requiring expert knowledge across a wide range of fields, including in specialized domains such as programming and creative writing. They enable interaction with humans through intuitive chat interfaces, which has led to rapid and widespread adoption among the general public."
# sequence += "The capabilities of LLMs are remarkable considering the seemingly straightforward nature of the training methodology. Auto-regressive transformers are pretrained on an extensive corpus of self-supervised data, followed by alignment with human preferences via techniques such as Reinforcement Learning with Human Feedback (RLHF). Although the training methodology is simple, high computational requirements have limited the development of LLMs to a few players. There have been public releases of pretrained LLMs (such as BLOOM (Scao et al., 2022), LLaMa-1 (Touvron et al., 2023), and Falcon (Penedo et al., 2023)) that match the performance of closed pretrained competitors like GPT-3 (Brown et al., 2020) and Chinchilla (Hoffmann et al., 2022), but none of these models are suitable substitutes for closed “product” LLMs, such as ChatGPT, BARD, and Claude. These closed product LLMs are heavily fine-tuned to align with human preferences, which greatly enhances their usability and safety. This step can require significant costs in compute and human annotation, and is often not transparent or easily reproducible, limiting progress within the community to advance AI alignment research."
# sequence += "In this work, we develop and release Llama 2, a family of pretrained and fine-tuned LLMs, Llama 2 and Llama 2-Chat, at scales up to 70B parameters. On the series of helpfulness and safety benchmarks we tested, Llama 2-Chat models generally perform better than existing open-source models. They also appear to be on par with some of the closed-source models, at least on the human evaluations we performed (see Figures 1 and 3). We have taken measures to increase the safety of these models, using safety-specific data annotation and tuning, as well as conducting red-teaming and employing iterative evaluations. Additionally, this paper contributes a thorough description of our fine-tuning methodology and approach to improving LLM safety. We hope that this openness will enable the community to reproduce fine-tuned LLMs and continue to improve the safety of those models, paving the way for more responsible development of LLMs. We also share novel observations we made during the development of Llama 2 and Llama 2-Chat, such as the emergence of tool usage and temporal organization of knowledge."
# sequence += "We are releasing the following models to the general public for research and commercial use‡: 1. Llama 2, an updated version of Llama 1, trained on a new mix of publicly available data. We also increased the size of the pretraining corpus by 40%, doubled the context length of the model, and adopted grouped-query attention (Ainslie et al., 2023). We are releasing variants of Llama 2 with 7B, 13B, and 70B parameters. We have also trained 34B variants, which we report on in this paper but are not releasing.§ 2. Llama 2-Chat, a fine-tuned version of Llama 2 that is optimized for dialogue use cases. We release variants of this model with 7B, 13B, and 70B parameters as well."
# sequence += "We believe that the open release of LLMs, when done safely, will be a net benefit to society. Like all LLMs, Llama 2 is a new technology that carries potential risks with use (Bender et al., 2021b; Weidinger et al., 2021; Solaiman et al., 2023). Testing conducted to date has been in English and has not — and could not — cover all scenarios. Therefore, before deploying any applications of Llama 2-Chat, developers should perform safety testing and tuning tailored to their specific applications of the model. We provide a responsible use guide¶ and code examples‖ to facilitate the safe deployment of Llama 2 and Llama 2-Chat. More details of our responsible release strategy can be found in Section 5.3."
# sequence += "To create the new family of Llama 2 models, we began with the pretraining approach described in Touvron et al. (2023), using an optimized auto-regressive transformer, but made several changes to improve performance. Specifically, we performed more robust data cleaning, updated our data mixes, trained on 40% more total tokens, doubled the context length, and used grouped-query attention (GQA) to improve inference scalability for our larger models. Table 1 compares the attributes of the new Llama 2 models with the Llama 1 models."
# sequence += "Our training corpus includes a new mix of data from publicly available sources, which does not include data from Meta’s products or services. We made an effort to remove data from certain sites known to contain a high volume of personal information about private individuals. We trained on 2 trillion tokens of data as this provides a good performance–cost trade-off, up-sampling the most factual sources in an effort to increase knowledge and dampen hallucinations. We performed a variety of pretraining data investigations so that users can better understand the potential capabilities and limitations of our models; results can be found in Section 4.1."
# sequence += "We adopt most of the pretraining setting and model architecture from Llama 1. We use the standard transformer architecture (Vaswani et al., 2017), apply pre-normalization using RMSNorm (Zhang and Sennrich, 2019), use the SwiGLU activation function (Shazeer, 2020), and rotary positional embeddings (RoPE, Su et al. 2022). The primary architectural differences from Llama 1 include increased context length and grouped-query attention (GQA). We detail in Appendix Section A.2.1 each of these differences with ablation experiments to demonstrate their importance."
# sequence += "We trained using the AdamW optimizer (Loshchilov and Hutter, 2017), with β1 = 0.9, β2 = 0.95, eps = 10−5. We use a cosine learning rate schedule, with warmup of 2000 steps, and decay final learning rate down to 10% of the peak learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. Figure 5 (a) shows the training loss for Llama 2 with these hyperparameters."
# sequence += "Training Hardware. We pretrained our models on Meta’s Research Super Cluster (RSC) (Lee and Sengupta, 2022) as well as internal production clusters. Both clusters use NVIDIA A100s. There are two key differences between the two clusters, with the first being the type of interconnect available: RSC uses NVIDIA Quantum InfiniBand while our production cluster is equipped with a RoCE (RDMA over converged Ethernet) solution based on commodity ethernet Switches. Both of these solutions interconnect 200 Gbps end-points. The second difference is the per-GPU power consumption cap — RSC uses 400W while our production cluster uses 350W. With this two-cluster setup, we were able to compare the suitability of these different types of interconnect for large scale training. RoCE (which is a more affordable, commercial interconnect network)"
# sequence += "Table 2: CO2 emissions during pretraining. Time: total GPU time required for training each model. Power Consumption: peak power capacity per GPU device for the GPUs used adjusted for power usage efficiency. 100% of the emissions are directly offset by Meta’s sustainability program, and because we are openly releasing these models, the pretraining costs do not need to be incurred by others. can scale almost as well as expensive Infiniband up to 2000 GPUs, which makes pretraining even more democratizable."
# sequence += "Carbon Footprint of Pretraining. Following preceding research (Bender et al., 2021a; Patterson et al., 2021; Wu et al., 2022; Dodge et al., 2022) and using power consumption estimates of GPU devices and carbon efficiency, we aim to calculate the carbon emissions resulting from the pretraining of Llama 2 models. The actual power usage of a GPU is dependent on its utilization and is likely to vary from the Thermal Design Power (TDP) that we employ as an estimation for GPU power. It is important to note that our calculations do not account for further power demands, such as those from interconnect or non-GPU server power consumption, nor from datacenter cooling systems. Additionally, the carbon output related to the production of AI hardware, like GPUs, could add to the overall carbon footprint as suggested by Gupta et al. (2022b,a). Table 2 summarizes the carbon emission for pretraining the Llama 2 family of models. A cumulative of 3.3M GPU hours of computation was performed on hardware of type A100-80GB (TDP of 400W or 350W). We estimate the total emissions for training to be 539 tCO2eq, of which 100% were directly offset by Meta’s sustainability program.∗∗ Our open release strategy also means that these pretraining costs will not need to be incurred by other companies, saving more global resources."
# sequence += "In this section, we report the results for the Llama 1 and Llama 2 base models, MosaicML Pretrained Transformer (MPT)†† models, and Falcon (Almazrouei et al., 2023) models on standard academic benchmarks. For all the evaluations, we use our internal evaluations library. We reproduce results for the MPT and Falcon models internally. For these models, we always pick the best score between our evaluation framework and any publicly reported results. In Table 3, we summarize the overall performance across a suite of popular benchmarks. Note that safety benchmarks are shared in Section 4.1. The benchmarks are grouped into the categories listed below. The results for all the individual benchmarks are available in Section A.2.2."
# sequence += "• Code. We report the average pass@1 scores of our models on HumanEval (Chen et al., 2021) and MBPP (Austin et al., 2021). • Commonsense Reasoning. We report the average of PIQA (Bisk et al., 2020), SIQA (Sap et al., 2019), HellaSwag (Zellers et al., 2019a), WinoGrande (Sakaguchi et al., 2021), ARC easy and challenge (Clark et al., 2018), OpenBookQA (Mihaylov et al., 2018), and CommonsenseQA (Talmor et al., 2018). We report 7-shot results for CommonSenseQA and 0-shot results for all other benchmarks. • World Knowledge. We evaluate the 5-shot performance on NaturalQuestions (Kwiatkowski et al., 2019) and TriviaQA (Joshi et al., 2017) and report the average. • Reading Comprehension. For reading comprehension, we report the 0-shot average on SQuAD (Rajpurkar et al., 2018), QuAC (Choi et al., 2018), and BoolQ (Clark et al., 2019). • MATH. We report the average of the GSM8K (8 shot) (Cobbe et al., 2021) and MATH (4 shot) (Hendrycks et al., 2021) benchmarks at top 1."
# sequence += "As shown in Table 3, Llama 2 models outperform Llama 1 models. In particular, Llama 2 70B improves the results on MMLU and BBH by ≈5 and ≈8 points, respectively, compared to Llama 1 65B. Llama 2 7B and 30B models outperform MPT models of the corresponding size on all categories besides code benchmarks. For the Falcon models, Llama 2 7B and 34B outperform Falcon 7B and 40B models on all categories of benchmarks. Additionally, Llama 2 70B model outperforms all open-source models. In addition to open-source models, we also compare Llama 2 70B results to closed-source models. As shown in Table 4, Llama 2 70B is close to GPT-3.5 (OpenAI, 2023) on MMLU and GSM8K, but there is a significant gap on coding benchmarks. Llama 2 70B results are on par or better than PaLM (540B) (Chowdhery et al., 2022) on almost all benchmarks. There is still a large gap in performance between Llama 2 70B and GPT-4 and PaLM-2-L. We also analysed the potential data contamination and share the details in Section A.6."
# sequence += "Llama 2-Chat is the result of several months of research and iterative applications of alignment techniques, including both instruction tuning and RLHF, requiring significant computational and annotation resources. In this section, we report on our experiments and findings using supervised fine-tuning (Section 3.1), as well as initial and iterative reward modeling (Section 3.2.2) and RLHF (Section 3.2.3). We also share a new technique, Ghost Attention (GAtt), which we find helps control dialogue flow over multiple turns (Section 3.3). See Section 4.2 for safety evaluations on fine-tuned models."
# sequence += "Getting Started. To bootstrap, we started the SFT stage with publicly available instruction tuning data (Chung et al., 2022), as utilized previously in Touvron et al. (2023). Quality Is All You Need. Third-party SFT data is available from many different sources, but we found that many of these have insufficient diversity and quality — in particular for aligning LLMs towards dialogue-style instructions. As a result, we focused first on collecting several thousand examples of high-quality SFT data, as illustrated in Table 5. By setting aside millions of examples from third-party datasets and using fewer but higher-quality examples from our own vendor-based annotation efforts, our results notably improved. These findings are similar in spirit to Zhou et al. (2023), which also finds that a limited set of clean instruction-tuning data can be sufficient to reach a high level of quality. We found that SFT annotations in the order of tens of thousands was enough to achieve a high-quality result. We stopped annotating SFT after collecting a total of 27,540 annotations. Note that we do not include any Meta user data. We also observed that different annotation platforms and vendors can result in markedly different downstream model performance, highlighting the importance of data checks even when using vendors to source annotations. To validate our data quality, we carefully examined a set of 180 examples, comparing the annotations provided by humans with the samples generated by the model through manual scrutiny. Surprisingly, we found that the outputs sampled from the resulting SFT model were often competitive with SFT data handwritten by human annotators, suggesting that we could reprioritize and devote more annotation effort to preference-based annotation for RLHF. Fine-Tuning Details. For supervised fine-tuning, we use a cosine learning rate schedule with an initial learning rate of 2 × 10−5, a weight decay of 0.1, a batch size of 64, and a sequence length of 4096 tokens. For the fine-tuning process, each sample consists of a prompt and an answer. To ensure the model sequence length is properly filled, we concatenate all the prompts and answers from the training set. A special token is utilized to separate the prompt and answer segments. We utilize an autoregressive objective and zero-out the loss on tokens from the user prompt, so as a result, we backpropagate only on answer tokens. Finally, we fine-tune the model for 2 epochs."

# Sample a continuation of `sequence` (sampling enabled, temperature 0.9, max_length 256)
# and print the pipeline's result.
output = pipe(
    sequence,
    max_length=256,
    do_sample=True,
    temperature=0.9,
)
print(output)

The model 'LlamaForCausalLMResearch' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelo

fwcall: 1, key cache size(one layer): None
fwcall: 2, key cache size(one layer): torch.Size([1, 32, 8, 128])
fwcall: 3, key cache size(one layer): torch.Size([1, 32, 9, 128])
fwcall: 4, key cache size(one layer): torch.Size([1, 32, 10, 128])
fwcall: 5, key cache size(one layer): torch.Size([1, 32, 11, 128])
fwcall: 6, key cache size(one layer): torch.Size([1, 32, 12, 128])
fwcall: 7, key cache size(one layer): torch.Size([1, 32, 13, 128])
fwcall: 8, key cache size(one layer): torch.Size([1, 32, 14, 128])
fwcall: 9, key cache size(one layer): torch.Size([1, 32, 15, 128])
fwcall: 10, key cache size(one layer): torch.Size([1, 32, 16, 128])
fwcall: 11, key cache size(one layer): torch.Size([1, 32, 17, 128])
fwcall: 12, key cache size(one layer): torch.Size([1, 32, 18, 128])
fwcall: 13, key cache size(one layer): torch.Size([1, 32, 19, 128])
fwcall: 14, key cache size(one layer): torch.Size([1, 32, 20, 128])
fwcall: 15, key cache size(one layer): torch.Size([1, 32, 21, 128])
fwcall: 16, key

TextGenerationPipeline中使用了自回归生成：首次前向传播以完整的prompt token作为输入，之后每次只输入一个新生成的token。key cache的形状含义为`(batch_size, num_heads, seq_len, head_dim)`。

In [7]:
# Keep a handle on the self-attention module of the first decoder layer for inspection.
attn = model.model.layers[0].self_attn

注意力层中的`q_proj`、`k_proj`、`v_proj`分别是生成$Q$、$K$、$V$的线性层。

In [8]:
# IPython help lookup (trailing `?`): show signature, type and docstring of the query projection.
attn.q_proj?

[0;31mSignature:[0m      [0mattn[0m[0;34m.[0m[0mq_proj[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m           Linear
[0;31mString form:[0m    Linear(in_features=4096, out_features=4096, bias=False)
[0;31mFile:[0m           ~/miniconda3/envs/faiss/lib/python3.11/site-packages/torch/nn/modules/linear.py
[0;31mDocstring:[0m     
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

Args:
    in_features: size of each input sample
    out_features: size of each output sample
    bias: If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``

Shape:
    - Input: :math:`(*, H_{in})` where :math:`*` means any number of
      dimensions including none 

在transformers的Llama实现中，KV cache由传入注意力层`forward`的参数`past_key_value`表示。对于生成式模型，它是`DynamicCache`类的实例，该类定义如下：
```python
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
```
Cache在LlamaModel前向传播之前被创建，并作为forward的参数流经所有层。它的生命周期本来只有一次前向传播，并在整个模型完成一次推理之后随输出一起返回。  
Pipeline通过自回归生成（应当是内部调用了`model.generate()`，TODO: 查清pipeline的源码）把一次推理的输出（包括缓存）作为下一次推理的输入，这使得Cache的生命周期扩展到了一次完整的pipeline调用。
```python
class LlamaForCausalLM(LlamaPreTrainedModel):
    ...

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        ...

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```

In [9]:
# sadly past_key_value is not present as a property
getattr(attn, "past_key_value", None)

In [10]:
m = attn.middleware
m['past_key_value']

DynamicCache()

In [11]:
m['past_key_value'].key_cache

[tensor([[[[-0.5439,  0.2430, -0.2144,  ...,  0.3013, -0.2847,  0.3466],
           [-0.2758,  0.0278, -0.0080,  ...,  0.4342, -0.1943,  0.4467],
           [ 0.7921, -0.3864, -0.3918,  ..., -0.3684,  0.3204, -0.3939],
           ...,
           [-0.0603, -0.0925,  0.0043,  ...,  0.4392, -0.2186,  0.4145],
           [ 0.2962, -0.1703, -0.2377,  ..., -0.0684,  0.1394, -0.0229],
           [-0.1052,  0.2297,  0.1839,  ...,  0.4951, -0.2811,  0.4878]],
 
          [[ 0.6405,  0.8949,  0.4003,  ..., -0.8481,  0.2465, -0.6571],
           [ 0.7500,  0.6643, -0.3054,  ..., -0.0711,  0.3391, -0.0927],
           [-1.1822, -0.9251,  0.2281,  ...,  0.5307, -0.5518,  0.5414],
           ...,
           [ 0.7189, -0.5203, -0.2361,  ...,  0.5054, -0.0025,  0.4217],
           [-0.6590,  0.7698,  0.0271,  ..., -0.5157,  0.0932, -0.3848],
           [-0.6171,  0.5436,  0.5091,  ...,  0.2413,  0.1192,  0.1818]],
 
          [[-0.1716,  0.1786,  0.3422,  ...,  1.6794,  1.7912,  1.6603],
           [-

In [12]:
m['past_key_value'].value_cache

[tensor([[[[ 1.1425e-02, -1.3230e-02,  2.6524e-03,  ...,  1.1242e-02,
            -4.3062e-03, -7.1702e-03],
           [ 7.1862e-03, -7.9505e-04,  7.7968e-03,  ...,  8.3490e-04,
             5.7170e-03, -1.1819e-02],
           [ 3.0127e-03, -4.1507e-03, -6.8740e-03,  ...,  1.4032e-03,
             2.2372e-03,  5.5290e-04],
           ...,
           [-3.5807e-03, -8.5695e-03,  9.7189e-03,  ...,  7.8387e-03,
             3.2028e-03, -9.4971e-03],
           [ 1.0221e-03, -3.2174e-03,  8.4431e-03,  ...,  3.1652e-03,
            -3.3442e-03,  9.0078e-03],
           [ 5.3439e-03,  2.7460e-03,  6.0733e-04,  ...,  1.2366e-02,
            -3.0356e-03,  2.4451e-03]],
 
          [[ 6.7095e-03,  2.1722e-03,  6.0971e-03,  ...,  4.5968e-03,
             1.4262e-03, -2.3720e-03],
           [ 8.7663e-03,  4.5856e-03, -4.4579e-03,  ...,  5.2764e-03,
             4.9614e-03, -5.3013e-03],
           [-1.5921e-04,  5.0845e-04, -5.8811e-04,  ...,  1.5965e-03,
             3.6433e-04, -1.0041e-03],


训练PQ并保存到硬盘

In [None]:
from faiss import IndexPQ
import numpy as np

# Product-quantizer settings: M sub-quantizers with nbits bits each.
M = 8
nbits = 4
# Per-head dimension; integer division keeps it an exact int.
dim = model.config.hidden_size // model.config.num_attention_heads

# One IndexPQ per (layer, head).
indices: List[List[IndexPQ]] = [
    [IndexPQ(dim, M, nbits) for _ in range(model.config.num_attention_heads)]
    for _ in range(model.config.num_hidden_layers)
]

for i in range(model.config.num_hidden_layers):
    # Value states captured by this layer's attention module on its last forward call.
    layer_values = model.model.layers[i].self_attn.middleware['value_states']
    for j in range(model.config.num_attention_heads):
        # Use at most the first 4096 positions of head j as training vectors.
        # `reshape` (not `view`) because the per-head slice may be non-contiguous,
        # and faiss expects a C-contiguous float32 matrix of shape (n, dim).
        tmp = layer_values[:, j, :4096, :].reshape(-1, dim).float().contiguous().cpu().numpy()
        # print(tmp.shape)
        indices[i][j].train(tmp)

In [None]:
# save indices to disk
import os

from faiss import write_index

# write_index does not create missing directories, so make sure the target exists.
os.makedirs("./pq_index", exist_ok=True)

for i in range(model.config.num_hidden_layers):
    for j in range(model.config.num_attention_heads):
        # One file per (layer, head) index.
        index_filename = f"./pq_index/pq_{i}_{j}.index"
        write_index(indices[i][j], index_filename)

# 实现Cache

In [1]:
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm, LlamaModel, LlamaSdpaAttention, LlamaPreTrainedModel
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional, Tuple, Union
import torch

from faiss import IndexPQ, IndexFlatIP

class KeyStateTensorMocker:
    """Stand-in for one layer's key tensor, backed by one faiss index per head.

    Keys are not kept as a tensor: head ``i``'s key vectors live in the
    ``IndexFlatIP`` returned by ``self[i]``. The object still exposes a
    tensor-like ``shape`` of ``(bsz, num_heads, seq_len, head_dim)``, where
    ``seq_len`` grows with every ``cat`` call.

    NOTE(review): all batch elements are added to the same per-head index, so
    results are only meaningful for ``bsz == 1`` (``DatabaseCache.query``
    asserts this).
    """

    def __init__(self, key_states: Optional[torch.Tensor] = None) -> None:
        """Create the mock, optionally seeding it with an initial key tensor.

        Args:
            key_states: optional tensor of shape
                ``(bsz, num_heads, seq_len, head_dim)`` to store immediately.
        """
        # Per-head indexes; created lazily by the first `cat` call so that an
        # instance built without `key_states` can still be filled later.
        self._cache: Optional[List[IndexFlatIP]] = None
        self._shape: List[int] = [0] * 4

        if key_states is not None:
            self.cat(key_states)

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """Logical shape ``(bsz, num_heads, seq_len, head_dim)``; all zeros while empty."""
        return tuple(self._shape)

    def cat(self, key_state: torch.Tensor) -> None:
        """Append ``key_state`` along the sequence axis.

        Args:
            key_state: tensor of shape ``(bsz, num_heads, seq_len, head_dim)``.
        """
        bsz, num_heads, seq_len, head_dim = key_state.shape

        if self._cache is None:
            # First insertion: one exact inner-product index per head.
            self._cache = [IndexFlatIP(head_dim) for _ in range(num_heads)]
            self._shape = [bsz, num_heads, 0, head_dim]

        # faiss works on C-contiguous float32 arrays; the incoming tensor may be
        # half precision, on an accelerator, or a strided (transposed) view.
        vectors = key_state.detach().to(dtype=torch.float32, device="cpu").contiguous().numpy()
        for b in range(bsz):
            for i in range(num_heads):
                self._cache[i].add(vectors[b, i])

        self._shape[2] += seq_len

    def __getitem__(self, idx: int) -> IndexFlatIP:
        """Return the faiss index that stores the keys of head ``idx``."""
        return self._cache[idx]

class DatabaseCache(DynamicCache):
    """KV cache whose keys live in faiss indexes instead of tensors.

    For every layer, ``key_cache[layer]`` is a ``KeyStateTensorMocker`` (one
    ``IndexFlatIP`` per head) and ``value_cache[layer]`` is a regular tensor of
    shape ``(bsz, num_heads, seq_len, head_dim)``. Attention is evaluated by
    ``query``, which looks up inner products in the indexes.

    NOTE(review): `query` indexes the key cache by query head, so it assumes the
    number of key/value heads equals the number of query heads (no grouped-query
    attention) — confirm before using with other model configs.
    """

    def __init__(self) -> None:
        super().__init__()
        self.key_cache: List[KeyStateTensorMocker] = []
        self.value_cache: List[torch.Tensor] = []
        # Only used by the (commented-out) debugging comparison against plain SDPA.
        self._debug_key_cache: List[torch.Tensor] = []
        self._seen_tokens = 0

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Beam-search reordering is not implemented for this cache."""
        raise NotImplementedError("Reordering the cache is not currently supported")

    def query(self, query_states, layer_idx, is_causal=True):
        '''
        Scaled dot-product attention evaluated against the cached keys/values.

        Args:
            query_states: tensor of shape ``(bsz, num_heads, query_len, head_dim)``.
                The queries are taken to be the LAST ``query_len`` positions
                already stored in the cache (i.e. ``update`` was called first).
            layer_idx: index of the layer whose cache is queried.
            is_causal: when True (default) query ``i`` only attends to keys at
                positions ``<= seq_len - query_len + i``. Without this, a
                multi-token prefill would attend to future tokens.

        Returns:
            Tensor of shape ``(bsz, num_heads, query_len, head_dim)`` in the
            dtype of the cached values.
        '''
        bsz, num_heads, query_len, head_dim = query_states.shape
        seq_len = self._seen_tokens

        assert bsz == 1, "Batch size > 1 is not currently supported"

        device = query_states.device
        values = self.value_cache[layer_idx]

        # Start from -inf so that any key faiss does not return contributes
        # nothing after the softmax. Scores are kept in float32 for stability.
        attn_score = torch.full(
            (bsz, num_heads, query_len, seq_len), float("-inf"), dtype=torch.float32, device=device
        )

        # faiss needs C-contiguous float32 input on the host.
        queries = query_states.detach().to(dtype=torch.float32, device="cpu").contiguous().numpy()
        row_ids = torch.arange(query_len, device=device).unsqueeze(1)

        # query the cache
        # TODO: parallelize this
        for b in range(bsz):
            for h in range(num_heads):
                D, I = self.key_cache[layer_idx][h].search(queries[b, h], seq_len)  # TODO: specify k
                D = torch.from_numpy(D).to(device)
                I = torch.from_numpy(I).to(device)
                # faiss pads missing neighbours with label -1; skip those slots.
                found = I >= 0
                rows = row_ids.expand_as(I)
                attn_score[b, h, rows[found], I[found]] = D[found]

        if is_causal:
            # Absolute position of each query; clamped so every row keeps at
            # least key 0 visible even if more queries than cached keys are passed.
            q_pos = torch.arange(seq_len - query_len, seq_len, device=device).clamp(min=0).unsqueeze(-1)
            k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
            attn_score = attn_score.masked_fill(k_pos > q_pos, float("-inf"))

        # scale & softmax
        attn_score = torch.softmax(attn_score / (head_dim ** 0.5), dim=-1)

        # weighted sum; cast so half-precision value caches do not raise a dtype mismatch
        attn_output = attn_score.to(values.dtype) @ values

        return attn_output

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        """Append new key/value states for ``layer_idx``.

        Args:
            key_states: tensor of shape ``(bsz, num_heads, seq_len, head_dim)``.
            value_states: tensor with the same layout as ``key_states``.
            layer_idx: layer the states belong to.
            cache_kwargs: accepted for interface compatibility; unused.

        NOTE: unlike ``DynamicCache.update`` this returns ``None`` — the keys
        are not available as a tensor; use ``query`` instead.
        """
        # key_states is shaped (bsz, num_heads, seq_len, head_dim)
        # Count tokens once per forward pass (layer 0 sees every token first).
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        if len(self.key_cache) <= layer_idx:
            # First time this layer is seen: create its storage.
            self.key_cache.append(KeyStateTensorMocker(key_states))
            self.value_cache.append(value_states)
        else:
            # Extend the existing storage along the sequence axis.
            self.key_cache[layer_idx].cat(key_states)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

       
class LlamaForCausalLMDB(LlamaForCausalLM):
    """Llama causal LM whose backbone is ``LlamaModelDB`` (faiss-backed KV cache).

    Attributes:
        middleware: scratch dict for captured intermediates (empty here).
        fwcall: number of ``forward`` calls so far, used by the debug print.
    """

    def __init__(self, config):
        # Skip LlamaForCausalLM.__init__ on purpose (it would build a stock
        # LlamaModel) and initialise from the next class in the MRO.
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlamaModelDB(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
        self.middleware = {}
        self.fwcall = 0

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the backbone, project to vocabulary logits and optionally compute the LM loss.

        Args:
            past_key_values: cache from a previous step, forwarded unchanged to
                ``self.model``.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the language modeling loss. Logits at position `n` are
                scored against the label at position `n + 1` (shifted inside this method).
                Tokens with label `-100` are ignored by `CrossEntropyLoss`.
            (all remaining arguments are forwarded unchanged to ``self.model``)

        Returns:
            `CausalLMOutputWithPast` when `return_dict` is true, otherwise a tuple
            `(loss?, logits, *backbone_outputs[1:])`.

        Side effects:
            Increments `self.fwcall` and prints the call count together with the
            shape of the first layer's key cache (or `None` when no usable cache
            was passed in).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        self.fwcall += 1
        # Debug trace. An empty cache (no layers yet) or a cache type that cannot
        # be indexed must not abort the forward pass, so fall back to None.
        try:
            cache_shape = past_key_values[0][0].shape if past_key_values is not None else None
        except (KeyError, IndexError, TypeError):
            cache_shape = None
        print(f"fwcall: {self.fwcall}, key cache size(one layer): {cache_shape}")

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Tensor-parallel pretraining layout: apply the LM head slice by slice
            # and concatenate along the vocabulary axis.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
class LlamaModelDB(LlamaModel):
    """Llama decoder stack built from ``LlamaDecoderLayerDB`` layers.

    When caching is enabled and the caller supplies no cache, ``forward``
    creates a ``DatabaseCache`` and threads it through every layer.
    """

    def __init__(self, config):
        # Start initialisation from the class after LlamaModel in the MRO, so
        # LlamaModel.__init__ is not run; the submodules are created below
        # with the DB decoder layers.
        super(LlamaModel, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayerDB(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

        # Scratch dict for captured intermediates (nothing is stored here in this class).
        self.middleware = {}
    
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the inputs, run every decoder layer and apply the final norm.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given,
        otherwise ``ValueError`` is raised. Flags left as ``None`` fall back to
        the values in ``self.config``.

        Returns:
            ``BaseModelOutputWithPast`` when ``return_dict`` is true, otherwise
            a tuple of the non-``None`` entries among
            ``(hidden_states, cache, all_hidden_states, all_self_attns)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR: true when both or neither of the two input forms are supplied.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        # Caching is switched off while training with gradient checkpointing.
        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                # No cache supplied: start a fresh faiss-backed cache.
                if past_key_values is None:
                    past_key_values = DatabaseCache()
                past_seen_tokens = past_key_values.get_seq_length()
        # NOTE(review): with use_cache False and no cache supplied, past_key_values
        # stays None and is passed to the layers as-is — confirm the attention
        # module handles a None cache.

        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            # Positions of the new tokens, continuing after what the cache already holds.
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            # Record the input of each layer (the final output is appended after the loop).
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            # The cache sits after the attention weights when those are returned.
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # The cache object is returned as-is (no conversion to another format).
        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class LlamaDecoderLayerDB(LlamaDecoderLayer):
    """Decoder layer identical to ``LlamaDecoderLayer`` except that self-attention
    is provided by ``LlamaSdpaAttentionDB``."""

    def __init__(self, config, layer_idx):
        # Initialise from the class after LlamaDecoderLayer in the MRO, so the
        # stock layer's __init__ is not run and no stock attention module is built.
        super(LlamaDecoderLayer, self).__init__()

        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size = width

        # Submodules are registered in the same order as in the stock layer.
        self.self_attn = LlamaSdpaAttentionDB(config, layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(width, eps=norm_eps)

class LlamaSdpaAttentionDB(LlamaSdpaAttention):
    """
    Llama attention whose score computation is delegated to a ``DatabaseCache``.

    The projection weights are those of the parent class. Instead of calling
    ``torch.nn.functional.scaled_dot_product_attention``, the forward pass stores
    the new key/value states in the cache and asks the cache (``cache.query``)
    for the attention output.

    The most recent query/key/value states and the cache object are kept in
    ``self.middleware`` for inspection from the notebook.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.middleware = {}

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention for ``hidden_states`` of shape ``(bsz, q_len, hidden)``.

        Args:
            attention_mask: not used on the cache-backed path.
            past_key_value: cache to update and query. When ``None`` (caching
                disabled) a throw-away ``DatabaseCache`` is used for this call.
            output_attentions: when true the call is delegated to the parent
                implementation.
            (other arguments follow the parent signature)

        Returns:
            ``(attn_output, None, past_key_value)`` — attention weights are
            never returned on the cache-backed path.
        """
        if output_attentions:
            # NOTE(review): the parent path expects `cache.update` to return
            # key/value tensors, which DatabaseCache.update does not — confirm
            # this fallback works before relying on it with a DatabaseCache.
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is None:
            # No cache was supplied: use a temporary single-layer cache so the
            # lookup below still works instead of dereferencing None. Slot 0 is
            # used because a fresh cache only counts tokens added for layer 0.
            cache, cache_layer = DatabaseCache(), 0
        else:
            cache, cache_layer = past_key_value, self.layer_idx

        cache.update(key_states, value_states, cache_layer)

        # sdpa is integrated with the cache
        attn_output = cache.query(query_states, cache_layer)

        # Keep the latest intermediates around for inspection / debugging.
        self.middleware.update({"query_states" : query_states.clone()})
        self.middleware.update({"key_states" : key_states.clone()})
        self.middleware.update({"value_states" : value_states.clone()})
        self.middleware.update({"past_key_value" : past_key_value})

        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
    

  from .autonotebook import tqdm as notebook_tqdm


## 使用IndexFlatIP验证正确性

In [11]:
# Pull the captured attention inputs out of the middleware dict.
query_states = m['query_states']
key_states = m['key_states']
value_states = m['value_states']

# Reference output from PyTorch's built-in attention (no mask passed).
sdpa = torch.nn.functional.scaled_dot_product_attention(query_states, key_states, value_states)

# Fresh faiss-backed cache holding the same keys/values in layer slot 0.
dbcache = DatabaseCache()
dbcache.update(key_states, value_states, 0)

In [12]:
layer_idx = 0

bsz, num_heads, query_len, head_dim = query_states.shape
seq_len = dbcache._seen_tokens

# Raw (unscaled) attention scores, rebuilt from the faiss search results.
attn_score = torch.zeros(bsz, num_heads, query_len, seq_len)
for batch_idx in range(bsz):
    for head_idx in range(num_heads):
        head_queries = query_states[batch_idx, head_idx].cpu().numpy()
        # Ask for all seq_len keys so every column of the score row is filled.
        D, I = dbcache.key_cache[layer_idx][head_idx].search(head_queries, seq_len)  # TODO: specify k
        # convert D to tensor
        D = torch.tensor(D, device=query_states.device)
        # I[row] lists key positions, D[row] the matching inner products.
        for row, (key_positions, products) in enumerate(zip(I, D)):
            for key_pos, product in zip(key_positions, products):
                attn_score[batch_idx, head_idx, row, key_pos] = product

In [13]:
query_states[0, 0, :, :].shape

torch.Size([1, 128])

In [14]:
index = IndexFlatIP(head_dim)

index.add(key_states[0, 0, :, :].cpu().numpy())

In [15]:
D, I = index.search(query_states[0, 0, :, :].cpu().numpy(), 5)

print(query_states[0, 0, :, :] @ key_states[0, 0, I[0][0], :])
print(D[0][0])

tensor([17.8612], device='cuda:0')
17.861155


In [16]:
attn_score.shape

torch.Size([1, 32, 1, 1])

In [17]:
torch.matmul(query_states, key_states.transpose(-1, -2))

tensor([[[[ 17.8612]],

         [[  2.9589]],

         [[  4.6459]],

         [[ 34.8256]],

         [[ 26.3393]],

         [[  4.4898]],

         [[ 16.9595]],

         [[-29.1442]],

         [[ 15.6031]],

         [[-16.4950]],

         [[ 16.2660]],

         [[ 18.6018]],

         [[ 40.9244]],

         [[-12.6539]],

         [[ -3.8843]],

         [[ 12.1046]],

         [[ 12.3422]],

         [[  9.4559]],

         [[ 21.0365]],

         [[ 36.2610]],

         [[ 11.7032]],

         [[ -4.3956]],

         [[ 24.3023]],

         [[ 17.2784]],

         [[  3.4854]],

         [[ 10.2540]],

         [[ 16.6599]],

         [[ -8.6022]],

         [[ 16.4774]],

         [[ 16.5451]],

         [[-14.7055]],

         [[  1.9440]]]], device='cuda:0')

In [18]:
attn_score

tensor([[[[ 17.8612]],

         [[  2.9589]],

         [[  4.6459]],

         [[ 34.8256]],

         [[ 26.3393]],

         [[  4.4898]],

         [[ 16.9595]],

         [[-29.1442]],

         [[ 15.6031]],

         [[-16.4950]],

         [[ 16.2660]],

         [[ 18.6018]],

         [[ 40.9244]],

         [[-12.6539]],

         [[ -3.8843]],

         [[ 12.1046]],

         [[ 12.3422]],

         [[  9.4559]],

         [[ 21.0365]],

         [[ 36.2610]],

         [[ 11.7032]],

         [[ -4.3956]],

         [[ 24.3023]],

         [[ 17.2784]],

         [[  3.4854]],

         [[ 10.2540]],

         [[ 16.6599]],

         [[ -8.6022]],

         [[ 16.4774]],

         [[ 16.5451]],

         [[-14.7055]],

         [[  1.9440]]]])

In [19]:
# Normalise: divide by sqrt(head_dim), then softmax over the key axis.
scaling_factor = torch.sqrt(torch.tensor(head_dim))
attn_score = torch.softmax(attn_score / scaling_factor, dim=-1)

# Explicit weighted sum of value vectors, written as plain loops so it can be
# checked against the matmul / SDPA results.
attn_output = torch.zeros_like(query_states)
layer_values = dbcache.value_cache[layer_idx]
for b in range(bsz):
    for h in range(num_heads):
        for i in range(query_len):
            # Output i accumulates every value vector j, weighted by attn_score[b, h, i, j].
            for j, weight in enumerate(attn_score[b, h, i]):
                attn_output[b, h, i, :] += weight * layer_values[b, h, j, :]

In [24]:
attn_output2 = attn_score @ dbcache.value_cache[layer_idx].cpu()

In [20]:
print(attn_output)

tensor([[[[ 7.5211e-03, -1.0375e-02,  3.5426e-03,  ..., -1.8492e-03,
            3.9925e-03, -4.3669e-03]],

         [[ 2.1980e-03,  3.3925e-03, -2.4338e-03,  ...,  2.2310e-03,
            2.4552e-03,  5.0936e-03]],

         [[-4.0987e-03, -3.7784e-03,  3.0119e-03,  ...,  7.2256e-04,
           -4.6644e-05,  9.7175e-03]],

         ...,

         [[-3.7973e-02, -6.9194e-03, -1.7842e-02,  ...,  7.5553e-03,
            3.7649e-03, -1.2968e-02]],

         [[-1.0510e-02,  1.2955e-03,  8.9638e-04,  ...,  7.3972e-04,
            3.3017e-04,  4.8898e-03]],

         [[-7.9546e-03,  8.9661e-03,  1.2278e-03,  ...,  4.6792e-05,
            1.1711e-03, -9.9407e-03]]]], device='cuda:0')


In [21]:
print(sdpa)

tensor([[[[ 7.5211e-03, -1.0375e-02,  3.5426e-03,  ..., -1.8492e-03,
            3.9925e-03, -4.3669e-03]],

         [[ 2.1980e-03,  3.3925e-03, -2.4338e-03,  ...,  2.2310e-03,
            2.4552e-03,  5.0936e-03]],

         [[-4.0987e-03, -3.7784e-03,  3.0119e-03,  ...,  7.2256e-04,
           -4.6644e-05,  9.7175e-03]],

         ...,

         [[-3.7973e-02, -6.9194e-03, -1.7842e-02,  ...,  7.5553e-03,
            3.7649e-03, -1.2968e-02]],

         [[-1.0510e-02,  1.2955e-03,  8.9638e-04,  ...,  7.3972e-04,
            3.3017e-04,  4.8898e-03]],

         [[-7.9546e-03,  8.9661e-03,  1.2278e-03,  ...,  4.6792e-05,
            1.1711e-03, -9.9407e-03]]]], device='cuda:0')


In [25]:
print(attn_output2)

tensor([[[[ 7.5211e-03, -1.0375e-02,  3.5426e-03,  ..., -1.8492e-03,
            3.9925e-03, -4.3669e-03]],

         [[ 2.1980e-03,  3.3925e-03, -2.4338e-03,  ...,  2.2310e-03,
            2.4552e-03,  5.0936e-03]],

         [[-4.0987e-03, -3.7784e-03,  3.0119e-03,  ...,  7.2256e-04,
           -4.6644e-05,  9.7175e-03]],

         ...,

         [[-3.7973e-02, -6.9194e-03, -1.7842e-02,  ...,  7.5553e-03,
            3.7649e-03, -1.2968e-02]],

         [[-1.0510e-02,  1.2955e-03,  8.9638e-04,  ...,  7.3972e-04,
            3.3017e-04,  4.8898e-03]],

         [[-7.9546e-03,  8.9661e-03,  1.2278e-03,  ...,  4.6792e-05,
            1.1711e-03, -9.9407e-03]]]])


## 使用pipeline进行文本生成

In [2]:
modelDB = LlamaForCausalLMDB.from_pretrained("llama-2-7b-hf")
tokenizer = LlamaTokenizer.from_pretrained("llama-2-7b-hf")

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.28it/s]


In [3]:
from transformers import pipeline
import torch

# Device used for generation; set to "cpu" to run without a GPU.
device = "cuda:0"

# Wrap the already-loaded DB model and tokenizer in a text-generation pipeline.
pipe = pipeline(
    task='text-generation',
    model=modelDB,
    tokenizer=tokenizer,
    device=device,
    torch_dtype=torch.float16,
)

sequence = "SJTU SEIEE is"

# Sampling settings for this run.
generation_kwargs = dict(max_length=256, do_sample=True, temperature=0.9)
output = pipe(sequence, **generation_kwargs)
print(output)

The model 'LlamaForCausalLMDB' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForC

fwcall: 1, key cache size(one layer): None
fwcall: 2, key cache size(one layer): (1, 32, 8, 128)
fwcall: 3, key cache size(one layer): (1, 32, 9, 128)
fwcall: 4, key cache size(one layer): (1, 32, 10, 128)
fwcall: 5, key cache size(one layer): (1, 32, 11, 128)
fwcall: 6, key cache size(one layer): (1, 32, 12, 128)
fwcall: 7, key cache size(one layer): (1, 32, 13, 128)
fwcall: 8, key cache size(one layer): (1, 32, 14, 128)
fwcall: 9, key cache size(one layer): (1, 32, 15, 128)
fwcall: 10, key cache size(one layer): (1, 32, 16, 128)
fwcall: 11, key cache size(one layer): (1, 32, 17, 128)
fwcall: 12, key cache size(one layer): (1, 32, 18, 128)
fwcall: 13, key cache size(one layer): (1, 32, 19, 128)
fwcall: 14, key cache size(one layer): (1, 32, 20, 128)
fwcall: 15, key cache size(one layer): (1, 32, 21, 128)
fwcall: 16, key cache size(one layer): (1, 32, 22, 128)
fwcall: 17, key cache size(one layer): (1, 32, 23, 128)
fwcall: 18, key cache size(one layer): (1, 32, 24, 128)
fwcall: 19, key

In [10]:
m = modelDB.model.layers[0].self_attn.middleware
cache = m['past_key_value']
query_states = m['query_states']
key_states = m['key_states']
value_states = m['value_states']

In [5]:
query_states.shape

torch.Size([1, 32, 8, 128])

In [6]:
cache.key_cache[0].shape

(1, 32, 8, 128)

In [10]:
to_compare = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states
)

In [11]:
to_compare.shape

torch.Size([1, 32, 1, 128])

In [14]:
query_result = cache.query(query_states, 0)

In [13]:
print(to_compare)

tensor([[[[ 3.6696e-03, -7.9229e-03,  8.4558e-03,  ..., -4.5834e-03,
           -2.3658e-03, -1.7671e-04]],

         [[ 4.3845e-04, -5.5317e-03,  5.8105e-03,  ...,  8.9691e-03,
           -4.0985e-04, -5.5041e-03]],

         [[-6.8097e-03,  3.5033e-03,  4.1462e-03,  ...,  5.9926e-03,
            2.3143e-04, -2.0504e-03]],

         ...,

         [[-2.8282e-02,  4.3157e-02, -1.2844e-02,  ...,  3.6885e-02,
            7.5483e-03, -1.6709e-03]],

         [[ 9.6602e-05,  6.6361e-03,  8.9406e-03,  ..., -1.1882e-02,
           -5.2853e-03,  1.6323e-02]],

         [[ 3.4548e-03,  4.0896e-03, -1.1464e-02,  ..., -3.3808e-04,
            4.1257e-03,  6.8261e-03]]]], device='cuda:0')


In [16]:
print(query_result)

tensor([[[[ 1.3781e-03, -5.7980e-04,  2.9330e-03,  ...,  7.4228e-04,
           -1.4491e-03, -1.3022e-04]],

         [[-2.8401e-03, -1.2746e-03, -3.7840e-03,  ...,  2.6666e-03,
           -4.3657e-04, -1.0268e-03]],

         [[-3.1803e-04,  1.7788e-03,  2.1568e-03,  ...,  1.2615e-03,
            2.9185e-03,  8.2977e-04]],

         ...,

         [[-1.4362e-02, -2.4401e-03, -1.8101e-02,  ...,  2.2499e-02,
           -1.3045e-02,  7.9051e-03]],

         [[ 4.3792e-03,  4.5733e-03, -9.5786e-04,  ...,  1.0060e-03,
           -8.3106e-05,  2.5171e-03]],

         [[ 1.1982e-04,  3.7192e-03, -1.3928e-03,  ...,  3.4574e-03,
            1.4829e-03,  1.6591e-03]]]], device='cuda:0')


In [None]:
# Scratch product-quantization index from faiss.
# NOTE(review): the arguments are presumably (vector dim=128, sub-quantizers=8,
# bits per code=4) -- confirm against the faiss IndexPQ documentation.
# The name `index` is rebound by the per-head loop in the following cell.
from faiss import IndexPQ
index = IndexPQ(128, 8, 4)

In [None]:
print(m['value_states'].shape)

# Build one product-quantization index per head over that head's value vectors.
# Head count and head width are read from the tensor instead of being
# hard-coded as 32 / 128, so the cell keeps working for other configurations.
# Assumes value_states is laid out as (batch, heads, seq, head_dim), which is
# how it is indexed below -- TODO confirm for other model variants.
num_value_heads = m['value_states'].shape[1]
value_head_dim = m['value_states'].shape[-1]
indices = [IndexPQ(value_head_dim, 8, 4) for _ in range(num_value_heads)]

for i, index in enumerate(indices):
    # reshape (not view): the per-head slice is not guaranteed to have strides
    # that allow a zero-copy flatten (e.g. batch size > 1), and view() would
    # raise in that case.
    tmp = m['value_states'][:, i, :, :].reshape(-1, value_head_dim).cpu().numpy()
    index.train(tmp)
    index.add(tmp)

In [None]:
m['value_states'][:, i, :, :].view(-1, 128).cpu().numpy().shape

In [None]:
# Adds the captured value states to whatever `index` currently refers to
# (after the per-head loop above, that is the last index in `indices`).
# NOTE(review): the array passed here is the whole value_states tensor without
# the per-head slicing / flattening used in the loop cell; faiss presumably
# expects a 2-D (n, d) array -- confirm whether this cell is still meant to run.
index.add(m['value_states'].cpu().numpy())

In [None]:
# Run a fresh sampled generation through the pipeline so the captured cache
# state can be compared against the previous prompt.
sequence = "SEIEE in SJTU is"
output = pipe(
    sequence,
    max_length=512,
    do_sample=True,
    temperature=0.9,
)
print(output)

从这里我们可以看到，一个新的文本生成任务会使用一个新的缓存。

In [None]:
print(m['past_key_value']._seen_tokens)

In [None]:
m['hidden_states_input'].shape

In [None]:
model.__call__?

In [None]:
model.forward?

In [None]:
model

# 乘法器优化

In [None]:
def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    tensor = tensor.permute(0, 2, 1, 3).contiguous()
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)

In [None]:
# Show the shapes of the first entries of the captured cache's key and value
# lists before they are cloned and head-merged in the next cell.
print("k.shape", m['past_key_value'].key_cache[0].shape)
print("v.shape", m['past_key_value'].value_cache[0].shape)

In [None]:
# Head layout used to flatten the cached tensors.
num_heads = 32
attn_head_size = 128

# Clone the cached tensors (so the live cache is never touched) and merge
# their head axes in one step each.
# NOTE(review): `q` is taken from the *key* cache and `k` from the *value*
# cache, not from actual query/key projections -- confirm this proxy is
# intended for the "Q*K^T" analysis below.
q = _merge_heads(m['past_key_value'].key_cache[0].clone(), num_heads, attn_head_size)
k = _merge_heads(m['past_key_value'].value_cache[0].clone(), num_heads, attn_head_size)

print("q.shape", q.shape)
print("k.shape", k.shape)

In [None]:
# q and k are the head-merged 3-D tensors from the previous cell,
# e.g. both of shape [1, 357, 4096].

# Insert singleton axes so the two tensors broadcast against each other:
# q becomes [1, rows, 1, features], k becomes [1, 1, rows, features].
q_expanded = q[:, :, None, :]
k_expanded = k[:, None, :, :]

# Broadcasted element-wise product: every row of q against every row of k,
# kept per feature (e.g. [1, 357, 357, 4096]) rather than summed.
multiplication_results = q_expanded * k_expanded

# Sanity-check the resulting shape.
print("Shape of multiplication_results:", multiplication_results.shape)

In [None]:
multiplication_results.reshape(-1)

In [None]:
len(multiplication_results.reshape(-1))

In [None]:
import matplotlib.pyplot as plt

# Cumulative, normalised histogram (an empirical CDF) of the absolute value of
# every individual product, drawn on a logarithmic x-axis.
# (A plain log-count histogram of the signed values was the earlier variant.)
plt.hist(
    multiplication_results.flatten().abs().cpu().numpy(),
    bins=100,
    cumulative=True,
    density=True,
    histtype='step',
)
plt.xscale("log")

# Title records the operand sizes taken from the merged q / k tensors.
plt.title(f"Q*K^T, {q.shape[1]} * {q.shape[2]} * {k.shape[1]}")

In [None]:
num_heads = 32
attn_head_size = 128

# Per-layer pairwise products, accumulated on the CPU.
multiplication_results = []

# Only three decoder layers are sampled (presumably first / middle / last of a
# 32-layer model) instead of enumerating every layer in model.model.layers.
for i in [0, 15, 31]:
    layer = model.model.layers[i]
    m = layer.self_attn.middleware

    # NOTE(review): index 0 of key_cache / value_cache is read for every
    # sampled layer. If these lists are indexed by layer, this should
    # probably be [i] -- confirm against the cache implementation.
    q = m['past_key_value'].key_cache[0].clone()
    k = m['past_key_value'].value_cache[0].clone()

    q = _merge_heads(q, num_heads, attn_head_size)
    k = _merge_heads(k, num_heads, attn_head_size)

    q_expanded = q.unsqueeze(2)
    k_expanded = k.unsqueeze(1)

    # Offload to CPU. Tensor.cpu() returns a new tensor rather than moving the
    # tensor in place, so the result must be kept; the original bare
    # `tmp.cpu()` call discarded it and left the large product on the GPU.
    tmp = (q_expanded * k_expanded).cpu()

    # Drop the device-side operands before the next layer is processed.
    del q, k, q_expanded, k_expanded

    multiplication_results.append(tmp)

# Flatten every per-layer product and join them into one 1-D tensor.
multiplication_results = torch.cat([chunk.reshape(-1) for chunk in multiplication_results])

In [None]:
import matplotlib.pyplot as plt

# Empirical CDF of |product| over the concatenated multi-layer results,
# drawn on a logarithmic x-axis.
plt.hist(multiplication_results.reshape(-1).abs().cpu().numpy(), bins=100, cumulative=True, density=True, histtype='step')
plt.xscale("log")

# The collection loop in the preceding cell deletes `q` and `k`, so reading
# q.shape / k.shape here raised NameError. Rebuild the same three numbers from
# the cached tensors instead.
# Assumes the cache entries are (batch, heads, seq, head_dim), the layout that
# _merge_heads is applied to -- TODO confirm.
key_shape = m['past_key_value'].key_cache[0].shape
value_shape = m['past_key_value'].value_cache[0].shape
q_rows = key_shape[2]                     # equals q.shape[1] after merging
q_features = key_shape[1] * key_shape[3]  # equals q.shape[2] after merging
k_rows = value_shape[2]                   # equals k.shape[1] after merging

plt.title(f"Q*K^T, {q_rows} * {q_features} * {k_rows}")

# Evaluation

In [None]:
from perplexity import perplexity

# Evaluation settings handed straight to the project's perplexity helper.
dataset = 'PTB'
root = '~/'
device = 'cuda:0'

# Use the model's maximum context length as the evaluation stride
# (noted as 4096 for this model in the original cell).
stride = model.config.max_position_embeddings
ppl_baseline = perplexity(
    model,
    tokenizer,
    dataset,
    device,
    verbose=True,
    stride=stride,
    root=root,
)
print(ppl_baseline)
