# Research

In [1]:
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm, LlamaModel, LlamaSdpaAttention, LlamaPreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
import torch.nn as nn

from typing import List, Optional, Tuple, Union
import torch

class LlamaForCausalLMResearch(LlamaForCausalLM):
    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlamaModelResearch(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

        self.middleware = {}
        self.fwcall = 0

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        self.fwcall += 1
        print(f"fwcall: {self.fwcall}, key cache size(one layer): {past_key_values[0][0].shape if past_key_values is not None else None}")
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

class LlamaModelResearch(LlamaModel):
    def __init__(self, config):
        super(LlamaModel, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayerResearch(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

        self.middleware = {}

class LlamaDecoderLayerResearch(LlamaDecoderLayer):
    def __init__(self, config, layer_idx):
        super(LlamaDecoderLayer, self).__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LlamaSdpaAttentionResearch(config, layer_idx)
        # self.self_attn = LlamaSdpaAttention(config, layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

class LlamaSdpaAttentionResearch(LlamaSdpaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.middleware = {}

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        self.middleware.update({"query_states": query_states, "key_states": key_states, "value_states": value_states})

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # In case static cache is used, it is an instance attribute.
        # past_key_value = getattr(self, "past_key_value", past_key_value)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)

        # key_states = repeat_kv(key_states, self.num_key_value_groups)
        # value_states = repeat_kv(value_states, self.num_key_value_groups)

        # causal_mask = attention_mask
        # if attention_mask is not None:
        #     causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]

        # # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # # Reference: https://github.com/pytorch/pytorch/issues/112577.
        # if query_states.device.type == "cuda" and causal_mask is not None:
        #     query_states = query_states.contiguous()
        #     key_states = key_states.contiguous()
        #     value_states = value_states.contiguous()
        
        causal_mask = None

        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
        # relying on the `is_causal` argument.
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=causal_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        self.middleware.update({
            "past_key_value" : past_key_value,
            "key_states": key_states
        })

        return attn_output, None, past_key_value

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import AutoModelForCausalLM
issubclass(LlamaForCausalLM, AutoModelForCausalLM)

False

In [2]:
model = LlamaForCausalLMResearch.from_pretrained("llama-2-7b-hf")
tokenizer = LlamaTokenizer.from_pretrained("llama-2-7b-hf")

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.46it/s]


In [3]:
model.model.layers

ModuleList(
  (0-31): 32 x LlamaDecoderLayerResearch(
    (self_attn): LlamaSdpaAttentionResearch(
      (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
      (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm()
    (post_attention_layernorm): LlamaRMSNorm()
  )
)

先用4096个token生成KV Cache，用来训练PQ。32层、每层32个head，一共需要1024个PQ索引。

In [14]:
from transformers import pipeline
import torch

device = "cpu"
# device = "cuda:0"

pipe = pipeline(
    'text-generation',
    model=model,
    torch_dtype=torch.float16,
    device = device,
    tokenizer=tokenizer,
)

# sequence = "SJTU SEIEE is"
sequence =  "In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models outperform open-source chat models on most benchmarks we tested, and based on our human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. We provide a detailed description of our approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on our work and contribute to the responsible development of LLMs."
sequence += "Large Language Models (LLMs) have shown great promise as highly capable AI assistants that excel in complex reasoning tasks requiring expert knowledge across a wide range of fields, including in specialized domains such as programming and creative writing. They enable interaction with humans through intuitive chat interfaces, which has led to rapid and widespread adoption among the general public."
sequence += "The capabilities of LLMs are remarkable considering the seemingly straightforward nature of the training methodology. Auto-regressive transformers are pretrained on an extensive corpus of self-supervised data, followed by alignment with human preferences via techniques such as Reinforcement Learning with Human Feedback (RLHF). Although the training methodology is simple, high computational requirements have limited the development of LLMs to a few players. There have been public releases of pretrained LLMs (such as BLOOM (Scao et al., 2022), LLaMa-1 (Touvron et al., 2023), and Falcon (Penedo et al., 2023)) that match the performance of closed pretrained competitors like GPT-3 (Brown et al., 2020) and Chinchilla (Hoffmann et al., 2022), but none of these models are suitable substitutes for closed “product” LLMs, such as ChatGPT, BARD, and Claude. These closed product LLMs are heavily fine-tuned to align with human preferences, which greatly enhances their usability and safety. This step can require significant costs in compute and human annotation, and is often not transparent or easily reproducible, limiting progress within the community to advance AI alignment research."
sequence += "In this work, we develop and release Llama 2, a family of pretrained and fine-tuned LLMs, Llama 2 and Llama 2-Chat, at scales up to 70B parameters. On the series of helpfulness and safety benchmarks we tested, Llama 2-Chat models generally perform better than existing open-source models. They also appear to be on par with some of the closed-source models, at least on the human evaluations we performed (see Figures 1 and 3). We have taken measures to increase the safety of these models, using safety-specific data annotation and tuning, as well as conducting red-teaming and employing iterative evaluations. Additionally, this paper contributes a thorough description of our fine-tuning methodology and approach to improving LLM safety. We hope that this openness will enable the community to reproduce fine-tuned LLMs and continue to improve the safety of those models, paving the way for more responsible development of LLMs. We also share novel observations we made during the development of Llama 2 and Llama 2-Chat, such as the emergence of tool usage and temporal organization of knowledge."
sequence += "We are releasing the following models to the general public for research and commercial use‡: 1. Llama 2, an updated version of Llama 1, trained on a new mix of publicly available data. We also increased the size of the pretraining corpus by 40%, doubled the context length of the model, and adopted grouped-query attention (Ainslie et al., 2023). We are releasing variants of Llama 2 with 7B, 13B, and 70B parameters. We have also trained 34B variants, which we report on in this paper but are not releasing.§ 2. Llama 2-Chat, a fine-tuned version of Llama 2 that is optimized for dialogue use cases. We release variants of this model with 7B, 13B, and 70B parameters as well."
sequence += "We believe that the open release of LLMs, when done safely, will be a net benefit to society. Like all LLMs, Llama 2 is a new technology that carries potential risks with use (Bender et al., 2021b; Weidinger et al., 2021; Solaiman et al., 2023). Testing conducted to date has been in English and has not — and could not — cover all scenarios. Therefore, before deploying any applications of Llama 2-Chat, developers should perform safety testing and tuning tailored to their specific applications of the model. We provide a responsible use guide¶ and code examples‖ to facilitate the safe deployment of Llama 2 and Llama 2-Chat. More details of our responsible release strategy can be found in Section 5.3."
sequence += "To create the new family of Llama 2 models, we began with the pretraining approach described in Touvron et al. (2023), using an optimized auto-regressive transformer, but made several changes to improve performance. Specifically, we performed more robust data cleaning, updated our data mixes, trained on 40% more total tokens, doubled the context length, and used grouped-query attention (GQA) to improve inference scalability for our larger models. Table 1 compares the attributes of the new Llama 2 models with the Llama 1 models."
sequence += "Our training corpus includes a new mix of data from publicly available sources, which does not include data from Meta’s products or services. We made an effort to remove data from certain sites known to contain a high volume of personal information about private individuals. We trained on 2 trillion tokens of data as this provides a good performance–cost trade-off, up-sampling the most factual sources in an effort to increase knowledge and dampen hallucinations. We performed a variety of pretraining data investigations so that users can better understand the potential capabilities and limitations of our models; results can be found in Section 4.1."
sequence += "We adopt most of the pretraining setting and model architecture from Llama 1. We use the standard transformer architecture (Vaswani et al., 2017), apply pre-normalization using RMSNorm (Zhang and Sennrich, 2019), use the SwiGLU activation function (Shazeer, 2020), and rotary positional embeddings (RoPE, Su et al. 2022). The primary architectural differences from Llama 1 include increased context length and grouped-query attention (GQA). We detail in Appendix Section A.2.1 each of these differences with ablation experiments to demonstrate their importance."
sequence += "We trained using the AdamW optimizer (Loshchilov and Hutter, 2017), with β1 = 0.9, β2 = 0.95, eps = 10−5. We use a cosine learning rate schedule, with warmup of 2000 steps, and decay final learning rate down to 10% of the peak learning rate. We use a weight decay of 0.1 and gradient clipping of 1.0. Figure 5 (a) shows the training loss for Llama 2 with these hyperparameters."
sequence += "Training Hardware. We pretrained our models on Meta’s Research Super Cluster (RSC) (Lee and Sengupta, 2022) as well as internal production clusters. Both clusters use NVIDIA A100s. There are two key differences between the two clusters, with the first being the type of interconnect available: RSC uses NVIDIA Quantum InfiniBand while our production cluster is equipped with a RoCE (RDMA over converged Ethernet) solution based on commodity ethernet Switches. Both of these solutions interconnect 200 Gbps end-points. The second difference is the per-GPU power consumption cap — RSC uses 400W while our production cluster uses 350W. With this two-cluster setup, we were able to compare the suitability of these different types of interconnect for large scale training. RoCE (which is a more affordable, commercial interconnect network)"
sequence += "Table 2: CO2 emissions during pretraining. Time: total GPU time required for training each model. Power Consumption: peak power capacity per GPU device for the GPUs used adjusted for power usage efficiency. 100% of the emissions are directly offset by Meta’s sustainability program, and because we are openly releasing these models, the pretraining costs do not need to be incurred by others. can scale almost as well as expensive Infiniband up to 2000 GPUs, which makes pretraining even more democratizable."
sequence += "Carbon Footprint of Pretraining. Following preceding research (Bender et al., 2021a; Patterson et al., 2021; Wu et al., 2022; Dodge et al., 2022) and using power consumption estimates of GPU devices and carbon efficiency, we aim to calculate the carbon emissions resulting from the pretraining of Llama 2 models. The actual power usage of a GPU is dependent on its utilization and is likely to vary from the Thermal Design Power (TDP) that we employ as an estimation for GPU power. It is important to note that our calculations do not account for further power demands, such as those from interconnect or non-GPU server power consumption, nor from datacenter cooling systems. Additionally, the carbon output related to the production of AI hardware, like GPUs, could add to the overall carbon footprint as suggested by Gupta et al. (2022b,a). Table 2 summarizes the carbon emission for pretraining the Llama 2 family of models. A cumulative of 3.3M GPU hours of computation was performed on hardware of type A100-80GB (TDP of 400W or 350W). We estimate the total emissions for training to be 539 tCO2eq, of which 100% were directly offset by Meta’s sustainability program.∗∗ Our open release strategy also means that these pretraining costs will not need to be incurred by other companies, saving more global resources."
sequence += "In this section, we report the results for the Llama 1 and Llama 2 base models, MosaicML Pretrained Transformer (MPT)†† models, and Falcon (Almazrouei et al., 2023) models on standard academic benchmarks. For all the evaluations, we use our internal evaluations library. We reproduce results for the MPT and Falcon models internally. For these models, we always pick the best score between our evaluation framework and any publicly reported results. In Table 3, we summarize the overall performance across a suite of popular benchmarks. Note that safety benchmarks are shared in Section 4.1. The benchmarks are grouped into the categories listed below. The results for all the individual benchmarks are available in Section A.2.2."
sequence += "• Code. We report the average pass@1 scores of our models on HumanEval (Chen et al., 2021) and MBPP (Austin et al., 2021). • Commonsense Reasoning. We report the average of PIQA (Bisk et al., 2020), SIQA (Sap et al., 2019), HellaSwag (Zellers et al., 2019a), WinoGrande (Sakaguchi et al., 2021), ARC easy and challenge (Clark et al., 2018), OpenBookQA (Mihaylov et al., 2018), and CommonsenseQA (Talmor et al., 2018). We report 7-shot results for CommonSenseQA and 0-shot results for all other benchmarks. • World Knowledge. We evaluate the 5-shot performance on NaturalQuestions (Kwiatkowski et al., 2019) and TriviaQA (Joshi et al., 2017) and report the average. • Reading Comprehension. For reading comprehension, we report the 0-shot average on SQuAD (Rajpurkar et al., 2018), QuAC (Choi et al., 2018), and BoolQ (Clark et al., 2019). • MATH. We report the average of the GSM8K (8 shot) (Cobbe et al., 2021) and MATH (4 shot) (Hendrycks et al., 2021) benchmarks at top 1."
sequence += "As shown in Table 3, Llama 2 models outperform Llama 1 models. In particular, Llama 2 70B improves the results on MMLU and BBH by ≈5 and ≈8 points, respectively, compared to Llama 1 65B. Llama 2 7B and 30B models outperform MPT models of the corresponding size on all categories besides code benchmarks. For the Falcon models, Llama 2 7B and 34B outperform Falcon 7B and 40B models on all categories of benchmarks. Additionally, Llama 2 70B model outperforms all open-source models. In addition to open-source models, we also compare Llama 2 70B results to closed-source models. As shown in Table 4, Llama 2 70B is close to GPT-3.5 (OpenAI, 2023) on MMLU and GSM8K, but there is a significant gap on coding benchmarks. Llama 2 70B results are on par or better than PaLM (540B) (Chowdhery et al., 2022) on almost all benchmarks. There is still a large gap in performance between Llama 2 70B and GPT-4 and PaLM-2-L. We also analysed the potential data contamination and share the details in Section A.6."
sequence += "Llama 2-Chat is the result of several months of research and iterative applications of alignment techniques, including both instruction tuning and RLHF, requiring significant computational and annotation resources. In this section, we report on our experiments and findings using supervised fine-tuning (Section 3.1), as well as initial and iterative reward modeling (Section 3.2.2) and RLHF (Section 3.2.3). We also share a new technique, Ghost Attention (GAtt), which we find helps control dialogue flow over multiple turns (Section 3.3). See Section 4.2 for safety evaluations on fine-tuned models."
sequence += "Getting Started. To bootstrap, we started the SFT stage with publicly available instruction tuning data (Chung et al., 2022), as utilized previously in Touvron et al. (2023). Quality Is All You Need. Third-party SFT data is available from many different sources, but we found that many of these have insufficient diversity and quality — in particular for aligning LLMs towards dialogue-style instructions. As a result, we focused first on collecting several thousand examples of high-quality SFT data, as illustrated in Table 5. By setting aside millions of examples from third-party datasets and using fewer but higher-quality examples from our own vendor-based annotation efforts, our results notably improved. These findings are similar in spirit to Zhou et al. (2023), which also finds that a limited set of clean instruction-tuning data can be sufficient to reach a high level of quality. We found that SFT annotations in the order of tens of thousands was enough to achieve a high-quality result. We stopped annotating SFT after collecting a total of 27,540 annotations. Note that we do not include any Meta user data. We also observed that different annotation platforms and vendors can result in markedly different downstream model performance, highlighting the importance of data checks even when using vendors to source annotations. To validate our data quality, we carefully examined a set of 180 examples, comparing the annotations provided by humans with the samples generated by the model through manual scrutiny. Surprisingly, we found that the outputs sampled from the resulting SFT model were often competitive with SFT data handwritten by human annotators, suggesting that we could reprioritize and devote more annotation effort to preference-based annotation for RLHF. Fine-Tuning Details. For supervised fine-tuning, we use a cosine learning rate schedule with an initial learning rate of 2 × 10−5, a weight decay of 0.1, a batch size of 64, and a sequence length of 4096 tokens. For the fine-tuning process, each sample consists of a prompt and an answer. To ensure the model sequence length is properly filled, we concatenate all the prompts and answers from the training set. A special token is utilized to separate the prompt and answer segments. We utilize an autoregressive objective and zero-out the loss on tokens from the user prompt, so as a result, we backpropagate only on answer tokens. Finally, we fine-tune the model for 2 epochs."
sequence += "Transformer-based large language models (LLMs) have achieved great success with the growing model size. LLMs’ size grows by 240× every two years, which outpaces the hardware progress and makes model inference increasingly costly. Model quantization is a promising approach to mitigate the widening gap between LLM size and hardware capacity. However, the existence of outliers, values with significant magnitudes, in LLMs makes existing quantization methods less effective. Prior outlier-aware quantization schemes adopt sparsity encoding techniques to separate outliers from normal values where the process requires global coordination (e.g., a global sparsity coordination list). This incurs complex encoding/decoding hardware logics and an extra orchestration controller for the computation between outlier and normal values. As such, it is not hardware-efficient and hence only achieves sub-optimal quantization benefits."
sequence += "We propose OliVe, an algorithm/architecture co-designed solution that adopts an outlier-victim pair (OVP) quantization and handles outlier values locally with low hardware overheads and high performance gains. The key insight of OliVe is that outliers are important while the normal values next to them are not. Thus those normal values (called victims) can be sacrificed to accommodate outliers. This enables a memory-aligned OVP encoding scheme, which can be efficiently integrated to the existing hardware accelerators like systolic array and tensor core. As a result, OliVe-based accelerator surpasses the existing outlier-aware accelerator, GOBO, by 4.5× speedup and 4.0× energy reduction, respectively, with a superior model accuracy."
sequence += "Large language models (LLMs) have shown great promise as highly capable AI assistants that excel in complex reasoning tasks requiring expert knowledge across a wide range of fields, including in specialized domains such as programming and creative writing. They enable interaction with humans through intuitive chat interfaces, which has led to rapid and widespread adoption among the general public."
sequence += "The capabilities of LLMs are remarkable considering the seemingly straightforward nature of the training methodology. Auto-regressive transformers are pretrained on an extensive corpus of self-supervised data, followed by alignment with human preferences"
sequence += "The aforementioned outlier-aware architectures separate normal values from outliers in a global way. For instance, GOBO [85] involves a global sparse coordinate list in the quantization and computation, leading to a large hardware overhead and low performance benefits. In this work, we aim to design an architecture to handle outliers in a localized way with high hardware efficiency. To achieve that, we group two consecutive fixed-size values in a tensor and analyze their impact to model accuracy. There can be three kinds of pairs: i) a normal pair with two normal values, ii) one-outlier pair with one normal value and one outlier value, iii) two-outlier pair with two outlier values. We observe that the third two-outlier pair almost never shows up in well-trained LLMs. For the second one-outlier pair, we find that only keeping its outlier value while pruning its normal value (i.e., treating it as zero) is sufficient to maintain the model accuracy. Based on the above observations, we propose a novel outlieraware quantization architecture, called OliVe, based on the outliervictim pair (OVP) encoding. The salient feature of OliVe is memoryaligned and therefore hardware-friendly. As illustrated in Fig. 1b, OliVe first prunes normal values that are adjacent to the outliers as zero. These pruned normal values are called victims, which sacrifice themselves and make space for outliers. Then, we exploit the extra space provided by victims and embed the outliers into the low-precision matrix."
sequence += "OliVe is able to maintain a high accuracy for large Transformer models with a low hardware overhead due to the following reasons. First, OliVe incorporates victims to tackle outliers in LLMs. The effects of victims resemble model pruning [36]. Although clipping a few (0.1%) outliers will lead to a disastrous accuracy drop [18, 82], pruning the same amount of “normal” values will only impact model accuracy slightly (< 0.1% drop). Therefore, OliVe sacrifices (“prunes”) those insignificant values as victims for the outliers, allowing a more aggressive encoding scheme to accommodate extremely significant values. Second, the OVP encoding follows a specific outlier-victim (or victim-outlier) pattern to achieve memory alignment with little hardware overheads. Each victim is adjacent to an outlier, and the outlier-victim pair must align the memory access pattern. For example, in Fig. 1b, right outlier −98 in the OV pair needs a left victim, and left outliers 17.6 and 30.7 require the right victims. That can align 8-bit (1-byte) memory accesses with high efficiency. This design enables a completely localized outlier decoding/encoding process."
sequence += "The secure container that hosts a single container in a micro virtual machine (VM) is now used in serverless computing, as the containers are isolated through the microVMs. There are high demands on the high-density container deployment and high-concurrency container startup to improve both the resource utilization and user experience, as user functions are fine-grained in serverless platforms. Our investigation shows that the entire software stacks, containing the cgroups in the host operating system, the guest operating system, and the container rootfs for the function workload, together result in low deployment density and slow startup performance at high-concurrency. We propose and implement a lightweight secure container runtime, named RunD, to resolve the above problems through a holistic guest-tohost solution. With RunD, over 200 secure containers can be started in a second, and over 2,500 secure containers can be deployed on a node with 384GB of memory. RunD is adopted as Alibaba serverless container runtime to support high-density deployment and high-concurrency startup."
sequence += "Based on different levels of security/isolation requirements, there are generally two categories of secure containers in the production environments. Figure 2(a) shows the multi-container-per-VM secure container model that only isolates functions. In the model, a virtual machine (VM) hosts the containers for the invocations of the same function. The containers in the same VM share the guest operating system of the VM. In this case, the invocations to different functions are isolated, but the invocations to the same function are not isolated. Since the number of required containers for each function varies, this model results in memory fragmentations [34]. Though the memory fragmentations can be reclaimed at runtime, it may significantly affect the function performance, and even crash the VM when the memory hot-unplug fails. Figure 2(b) shows the single-container-per-VM secure container model that isolates each function invocation. Current serverless computing providers [1, 20] mainly use this secure container model. In this model, each invocation is served with a container in a microVM. This model does not introduce memory fragmentations, but the microVMs themselves show heavy memory overhead. It is obvious that each microVM needs to run its exclusive guest operating system, multiplying the memory footprints."
sequence += "Requirement on high-concurrency container startup. In serverless platforms, each function invocation is short, and a large number of function invocations may arrive in a short time. For example, in Alibaba serverless platform, more than 200 container-launch requests arrive nearly simultaneously on a node. The latency until all containers have entered main() can swell super-proportionally due to resource contention among the simultaneously launching VMs. Meanwhile, emerging internet services often show a diurnal load pattern and have bursty loads [18]. A large number of containers are required to be created when the load bursts. Some techniques, such as prewarming containers [31, 42, 49], are able to alleviate container cold startups. However, bursty loads are inevitable can easily exhaust the limited prewarmed containers. The ability to startup containers at high-concurrency is crucial for serverless platforms. Requirement on high-density container deployment. The small container specification in a serverless computing platform brings the requirement to deploy containers densely on a node. For instance, 47% of lambda functions run with the minimum memory specification of 128MB in AWS [5]. The actual memory usage of a container may also be smaller than its specifications. As Azure reports [49], about 90% of the applications never consume more than 400MB of memory. A node with 256GB of memory can host 8 × 256 = 2048 containers if there is no other overhead. In Alibaba serverless platform, over 2,500 secure containers that 128MB-sized can be deployed on a node with 384GB memory. Without proactive customizations, secure containers incur extra memory overhead, reducing deployment density in serverless computing. Increasing deployment density greatly improves resource utilization and multi-tenant serving efficiency with the same infrastructure."
sequence += "With the default configuration (cache enabled), virtio-blk performs best at random/sequential writing. However, the device-mapper who prepares the block device in the host cannot meet the high-concurrency requirement [59]. According to our measurement, it takes as high as 10 seconds to prepare a rootfs when 200 containers are started concurrently, while it only takes about 30 milliseconds for a single container startup. In this case, the operation of preparing rootfs timeouts, resulting in the container breakdown. Moreover, virtio-blk inherently does not support the page cache sharing between host and guest operating systems. When virtio-blk backend reads rootfs files into the host page cache, the mapped content reproduces the same page cache in the guest. The issue of duplicated page cache brings a high memory footprint overhead."
sequence += "Except for the memory used by the user function, the memory footprint of other components in the secure container is the memory overhead. The 5MB memory overhead reported in FireCracker [52] is the overhead of the FireCracker VMM itself. In the microVM of a secure container, the guest operating system, the struct page for memory management, and other components (e.g., baseOS, shimv2, agent) also consume additional memory space [52]. Figure 5 shows the per-container memory overhead of secure containers with different memory specifications and at different deployment densities. In the figure, Kata-qemu is the secure container that uses qemu as the hypervisor, and Kata-FireCracker uses FireCracker as the hypervisor. As observed in Figure 5(a), the memory overheads of a 128MB container are 94MB and 168MB with Kata-FireCracker and Kata-qemu, respectively. The overhead increases with the memory specification of the container. The average memory footprint of a single microVM can be reduced by sharing the text/rodata segment among multiple microVMs. Mainstream MicroVMs achieve it by mapping the kernel file to the guest memory directly using mmap. As shown in Figure 5(b), the per-microVM memory overhead of kata-qemu and kata-FireCracker reduce to 145MB and 71MB when 1,000 VMs are deployed on a node. However, the overhead is still too large for a serverless container with only 128MB memory specification."
sequence += "However, the template technique is not as efficient as we thought, due to the self-modifying codes in the operating system kernel [24, 25]. The self-modifying code technique alters the instructions on-demand as it runs, and the Linux kernel relies heavily on self-modification code to improve performance on boot and during runtime. We start a clean microVM with CentOS 4.19 guest kernel from a template to investigate the impact of self-modifying codes. The investigation shows that 10,012KB of the code and the read-only data is accessed in the memory, but 7,928KB of them were modified during boot. This case in point reveals that the selfmodifying codes degrade the efficiency when using mmap for less memory consumption of kernel image files."
sequence += "Cgroup is designed for resource control and abstraction of processes. In serverless computing, the frequency of function invocations shows high variation. In this case, the corresponding secure containers are frequently created and recycled. For instance, in our serverless platform, at most 200 containers would be created and recycled on a physical node concurrently in a second. The frequent creating and recycling challenge the cgroup mechanism on the host. We measure the performance of cgroup operations when creating 2, 000 containers concurrently. In the experiment, we use different numbers of threads to perform cgroup operations. Figure 6(a) shows the cumulative distribution of container creating latencies. Counter-intuitively, the latency increases when more threads are used, even if each thread needs to create fewer containers. The reason behind the above fact is that the Linux kernel introduces several global locks (e.g., cgroup_mutex, css_set_lock, freezer_mutex) to serialize cgroup operations. The global locks are used to coordinate more than 10 resource subsystems (aka. the cgroup subsys) involved in cgroup. Figure 6(b) shows the flame graph of creating 2, 000 cgroups using 10 threads concurrently. In the figure, the red parts show the case that “mutex locks” are active. When the cgroup mutex uses the optimistic spinning by default, the spinner cgroups experience the optimistic spinning if they fail to acquire the lock. It will lead to heavy CPU consumption and belated exiting of the critical section in the multi-"
sequence += "Figure 7 shows the RunD design and summarizes our methodologies. RunD runtime makes a read/write splitting by providing the read-only layer to virtio-fs, using the builtin storage file to create a volatile writeable layer to virtioblk, and mounting the former and latter as the final container rootfs using overlayfs. RunD leverages the microVM template that integrates the condensed kernel and adopts the prepatched image to create a new microVM, further amortizing the overhead across different microVMs. RunD renames and attaches a lightweight cgroup from the cgroup pool for management when a secure container is created. Based on the above optimizations, a secure container (referred to as a “sandbox”) is started in the following steps, when RunD is used as the secure container runtime. • In the first step, once containerd receives a user invocation, it forwards the request to RunD runtime. • Second, RunD prepares the runc-container rootfs for the virtual machine hypervisor. The rootfs is separated into read-only layer and writable layer. (Section 4.2). • Third, the hypervisor uses the microVM template to create the required sandbox (Section 4.3), and mount the runc-container rootfs into the sandbox by overlayfs. • Lastly, a lightweight cgroup is attached to the sandbox (Section 4.4), to manage the resource allocation for this sandbox in the host."
sequence += "We investigate the data in a sandbox in the serverless computing scenario, and find that user-provided code/data is read-only for the operating system, and the systemprovided runtime files are also read-only for user functions. Meanwhile, the data in the local memory or storage generated in a sandbox will not be used by subsequent function invocations, due to the stateless feature of serverless computing. The temporary and intermediate data generated during the function execution is not required to be persisted. Based on the above finding, it is possible to split the rootfs into a read-only layer and a writable layer, and then handle them in different ways [32]. The sandboxes can share the read-only layer on the same node, and the writable layer has to be prepared separately for each sandbox. Figure 8 shows the way to split rootfs into a read-only layer and a volatile writable layer. According to the investigation in Section 3.1, virtio-fs is used to handle the read-only layer, and virtio-blk is used to handle the volatile writable layer for better performance. The read-only layer is stored in the host and can be prepared in negligible time when using the overlay snapshotter provided by the container runtime. However, it is challenging to handle the volatile writable layer efficiently. By default, the host operating system needs to prepare a logic storage volume for the sandbox. This operation is time-consuming and is one of the most important reasons that result in the long latency of preparing rootfs."
sequence += "Following the abstraction premise in current serverless platforms, the guest environment management for serverless containers is offloaded to the cloud provider. Meanwhile, RunD depends on the security model of hardware virtualization and VMM, explicitly treating the guest kernel as untrusted through syscall inspections. Based on this fact, there is an opportunity to condense the guest kernel for the lightweight characteristic of serverless functions. Considering that several features in the guest kernel are redundant and memory intensive in the serverless context, RunD condenses these features at compile-time. When customizing the condensed guest kernel, the principles behind it are as follows: - Minimize kernel memory footprint and image size. - Retain features required in the serverless context. - Without runtime performance degradation. Following the above principles, we build the condensed kernel for the guest operating system based on Linux kernel, by disabling features: - Do not pre-create loop device (2.2MB Mem reduced). - Disable acpi and ftrace (2MB and 6MB Mem reduced). - Disable graphics-related items (2MB Mem reduced). - Disable i2c and ceph (3MB Mem reduced, and 4MB reduced of kernel image size). - Kernel files (560K Mem and 571K image size reduced). Validating all features at compile-time case by case, RunD effectively reduces the memory footprint of a CentOS 4.19 Linux kernel by about 16MB and condenses the kernel image by about 4MB. Based on this condensed guest kernel, we"
sequence += "As mentioned before, cloud providers manage and maintain the underlying hardware and execution runtimes in serverless context, standing for that all microVMs on the same node generally use the same guest kernel. In this scenario, the sandboxes on the same node generate the same patched kernel code, even if they execute the self-modification patch logic. This is because the self-modifying code of kernel text segments only occurs at the startup phase, after which the kernel code area becomes “read-only after initialization”. In this case, sandboxes experience the same initialization phase and generate predicable self-modifying code segments. Based on the above observation, there is an opportunity to generate a pre-patch guest kernel image file already patched with self-modified code segments. The MicroVM template technique discussed in Section 3.2 may work efficiently without self-modifying code. Adapting to this optimization, we also resolve the potential kernel panic issues when loading the pre-patched kernel image for higher stability. RunD tries to share as many kernel files as possible across different secure containers. With a pre-patched microVM template, RunD not only reduces the memory footprint of a single container for higherdensity deployment, but also allows to quickly fork multiple instances [29, 52]."
sequence += "The cgroup pool with renaming mechanism eliminates the time-consuming cgroup creation and initialization. RunD pre-creates corresponding lightweight cgroups and maintains them in a cgroup pool based on the pre-defined node capacity. These cgroups are marked idle when initialized, and are protected in a linked list. For each created container, RunD simply allocates an idle cgroup, updates the state to busy, performs the cgroup rename operation, and then attaches the container to this renamed cgroup when a container is started. If a container triggers recycling, RunD will take the cgroup back to the pool, kill the corresponding instance process, and then update the returned cgroup state to idle for subsequent allocating and renaming. Adopting the above optimizations in kernel mode, we replay the evaluation in Section 3.3. The cgroups creation only consumes 0.09s (1 thread), 0.1s (50 threads), and 0.14s (200 threads), respectively. Compared with the default mechanism, the lightweight cgroup and the rename-based cgroup pool reduce 94% of the cgroups creation time."
sequence += "Baselines: we compare RunD with the state-of-theart secure container, Kata Containers [19]. Specifically, we use three popular configurations of Kata containers: Kata-qemu, Kata-template, and Kata-FC. Kata-qemu uses QEMU [15, 23] as the microVM hypervisor, Kata-template uses QEMU while integrating container template, Kata-FC uses lightweight FireCracker [20] as the microVM hypervisor. Kata-qemu and kata-template use an old version of Kata Containers, as the new version has some bugs that result in poor performance. Table 1 shows the detailed setups. Testbed: we run the experiments on a node with 104 virtual cores, 384GB memory, and two SSD drives of 100GB and 500GB. Such specification is widely-used in production clouds. The 100GB drive is used as the root filesystem of the host operating system, and the 500GB drive is used by the secure containers. We use Alibaba Cloud Linux 2 for RunD and Alpine Linux [3] for others, as the guest operating systems in the microVM for a low memory footprint. Measurement: in the CRI specification [6], a pod sandbox refers to a microVM with a lightweight pause container [12]. In all the tests, we only create the pod sandboxes without other containers inside, through the crictl command. In the following evaluations, the memory specification of a container denotes the size of memory that can be used by itself. The actual memory usage of a container is collected using the smem command. As RunD is proposed to maximize the supported container startup concurrency and deployment density, in the experiment, we start empty secure containers without user codes or data considering that it is a common practice in FaaS to start empty containers concurrently for prewarming. The inproduction results show the performance of RunD for actual workloads with all the steps involved."
sequence += "As shown in the figure, RunD uses the shortest time to start a large number of sandboxes for all concurrency levels. When 200 containers are created concurrently (we already observe such high-concurrency in Alibaba serverless platform), Kata-FC, kata-qemu, kata-template, and RunD needs 47.6s, 6.85s and 2.98s and 1s to create them. Kata-FC requires a much longer time to startup the sandboxes when the concurrency is high. This is because Kata-FC uses virtioblk to create rootfs, and the performance is poor at highconcurrency, as we measured in Section 3. There is no such bottleneck in Kata-template and Kata-qemu. Kata-template simply uses template to reduce the overhead of guest kernel and rootfs loading, but the inefficient rootfs mapping, code self-modification and high host-side overhead of the cgroup operations still exists. As a result, it performs worse than RunD at high startup concurrency. The overall optimizations suggest that RunD provides the performance improvement of about 40% over its nearest baseline, Kata-template, at highconcurrency (e.g., 400-way) startup. As for the second metric, Figure 10(b) shows the latency distribution of starting each sandbox, when 200 sandboxes are started concurrently. RunD and Kata-template are able to start sandboxes in a stable short time, but the latencies of starting sandboxes with others are out of expected. Users can have identical good experiences with RunD. As for the CPU overhead, Figure 10(c) shows the CPU time needed on the host to startup sandboxes. When the concurrency is high, RunD greatly reduces the CPU overhead. For instance, when 200 sandboxes are started concurrently, RunD reduces 89.3%, 74.5% and 62.1% CPU overhead compared with Kata-qemu, Kata-template, and KataFC, respectively. In addition, the CPU overhead of RunD only increases slightly, when the concurrency increases. This is due to the read/write split policy and the reduction of compute-intensive operations in cgroups. Therefore, RunD"
sequence += "The memory overhead of RunD is 5MB, which is the overhead of the RunD runtime itself. The memory overhead of the microVM is 94MB for a 128MB container, and 168MB for a 256MB container. The memory overhead increases with the memory specification of the container. The average memory overhead of a single microVM can be reduced by sharing the text/rodata segment among multiple microVMs. As shown in Figure 5(b), the per-microVM memory overhead of kata-qemu and kata-FireCracker reduce to 145MB and 71MB when 1,000 VMs are deployed on a node. However, the overhead is still too large for a serverless container with only 128MB memory specification."
sequence += "Currently, Alibaba serverless computing platform has adopted RunD. The platform serves almost 4 billion invocations from more than 1 million different functions per day. Figure 14 reports the sandbox startup concurrency and the corresponding startup latency from six nodes. The specification of each node is the same as our experimental setup in Table 1. The data is collected between 08:00 and 18:00 of Jan 10th, 2022. There are about 800 active sandboxes on each node, when the concurrency data is collected. The inproduction startup latency of sandboxes at high-concurrency is consistent with that reported in Section 5.4. As observed from the figure, the startup concurrency bursts at the beginning of each hour. At most 191 sandboxes are started concurrently around 10:00. RunD starts the 191 sandboxes in 1.6 seconds. We look into the function invocation logs, and find that the periodic burst is caused by the an-hour time trigger and cluster-level load balancing. The periodical burst is pervasive, as the Azure serverless platform traces [14] show the same pattern. In the figure, the sandbox startup latency occasionally increases when the concurrency is low. The long time results from the operation in loading large-scale workloads from the tenants. Although the startup"
sequence += "The most closely related work to RunD is FireCracker [20], which proposes a lightweight VMM for serverless runtime. It provides fast startup within 125ms, allowing 150 VMs to start concurrently per second per node, with less than 5MB footprint per VM. However, FireCracker only serves as the hypervisor stack in the Security Container model, without other complex related processes, e.g., rootfs [52]. By contrast, RunD investigates the guest-to-host solution through all stacks and provides higher concurrency and density. Higher-density deployment. Regarding serverless computing, in the space of higher function deployment density of Secure Containers and VMs [57], the key is designing a more lightweight container runtime both in guest and host. Unikernel [36, 37, 43, 47] runs as a built-in GuestOS without necessary add-ons, demonstrating great potential for deploying containers with less overhead. Kuo [33] Explores lightweight guest kernel configurations for use in Unikernel environments, which has similarity to the approach towards reducing guest kernel size. However, Unikernel is hard to be changed once after compilation with the application. Its compile-time invariance results in poor flexibility in practice. SAND [21] adopts the multi-container-per-VM model to amortize the memory footprint of sandboxing. However, they do not further investigate the utilization impact of memory fragmentations in a real-system with high-density deployment. Gsight [61] observes that fine-grained functionlevel profiling can expose more predictability system-level features in the partial interference. With a more accurate interference predicting [27, 44], the function density can get improved with QoS guaranteed. The above studies make sense in improving the effective density with less interference for serverless. They are orthogonal to our work, because RunD is motivated to improve the maximum deployment density on a signe node. Higher-concurrency startup. In the space of higher function startup concurrency, recent approaches leverage the container prewarm pool [9, 40, 49, 58]. The state-of-the-art on container prewarming, SOCK [42], uses a benefit-to-cost model to select packages pre-installed in zygotes, and builds a tree cache to ensure that the forked zygote container does not import any additional packages other than the private ones the handler specifies. The C/R (Checkpoint/Restore) [7, 31, 39] supporting the VM snapshotting [10, 28, 29, 41, 54] captures the state of a running instance as a checkpoint, and then restores it once cold startup. Observing that most functions only access a small fraction of the files and mem-"


output = pipe(sequence, max_new_tokens=2, do_sample=True, temperature=0.9)
print(output)

The model 'LlamaForCausalLMResearch' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelo

fwcall: 3, key cache size(one layer): None


This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (4096). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


fwcall: 4, key cache size(one layer): torch.Size([1, 32, 10869, 128])
[{'generated_text': 'In this work, we develop and release Llama 2, a collection of pretrained and fine-tuned large language models (LLMs) ranging in scale from 7 billion to 70 billion parameters. Our fine-tuned LLMs, called Llama 2-Chat, are optimized for dialogue use cases. Our models outperform open-source chat models on most benchmarks we tested, and based on our human evaluations for helpfulness and safety, may be a suitable substitute for closedsource models. We provide a detailed description of our approach to fine-tuning and safety improvements of Llama 2-Chat in order to enable the community to build on our work and contribute to the responsible development of LLMs.Large Language Models (LLMs) have shown great promise as highly capable AI assistants that excel in complex reasoning tasks requiring expert knowledge across a wide range of fields, including in specialized domains such as programming and creative 

TextGenerationPipeline中使用了自回归。首次运行时使用给定token作为输入，后面每次输入一个token。key cache size的解释为`(num_batch, num_heads, seq_len, head_dim)`。

In [15]:
attn = model.model.layers[0].self_attn

注意力层中的`q_proj`, `k_proj`, `v_proj`分别是生成$Q$、$K$、$V$的线性层。

In [16]:
attn.q_proj?

[0;31mSignature:[0m      [0mattn[0m[0;34m.[0m[0mq_proj[0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m           Linear
[0;31mString form:[0m    Linear(in_features=4096, out_features=4096, bias=False)
[0;31mFile:[0m           ~/miniconda3/envs/faiss/lib/python3.11/site-packages/torch/nn/modules/linear.py
[0;31mDocstring:[0m     
Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward.

Args:
    in_features: size of each input sample
    out_features: size of each output sample
    bias: If set to ``False``, the layer will not learn an additive bias.
        Default: ``True``

Shape:
    - Input: :math:`(*, H_{in})` where :math:`*` means any number of
      dimensions including none 

在Llama的transformers实现中, kv cache由注意力层的中间变量`past_key_value`表示。对于生成式模型，这是一个`DynamicCache`类，定义如下：
```python
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
```
Cache在LlamaModel前向传播之前被创建，并作为forward的参数流经所有层。生命周期为一次前向传播，在整个模型一次推理之后会被返回。  
Pipeline使用了某种自回归手段（TODO: 查清pipeline的源码）使得一次推理的输出（包括缓存）传给了下一次输入。这使得Cache的生命周期扩展到了一次pipeline。
```python
class LlamaForCausalLM(LlamaPreTrainedModel):
    ...

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        ...

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```

In [17]:
# sadly past_key_value is not present as a property
getattr(attn, "past_key_value", None)

In [18]:
m = attn.middleware
m['past_key_value']

DynamicCache()

In [19]:
m['past_key_value'].key_cache

[tensor([[[[-0.3384,  0.0469,  0.0529,  ...,  0.2462, -0.1481,  0.2926],
           [-0.0552, -0.1907,  0.0224,  ..., -0.0031,  0.0241,  0.0113],
           [ 0.1668, -0.2733, -0.3675,  ...,  0.0711,  0.1794,  0.1150],
           ...,
           [-0.1707,  0.5088,  0.7048,  ..., -0.0091,  0.0510, -0.0450],
           [-0.5672, -0.1113,  0.3188,  ..., -0.3986, -0.5565, -0.1698],
           [ 0.1198, -0.0046, -0.0380,  ...,  0.4629,  0.6372,  0.3343]],
 
          [[ 0.6409,  0.8173,  0.1992,  ..., -0.5353,  0.1221, -0.4024],
           [-0.2616,  0.2207,  0.2423,  ..., -0.8128,  0.2036, -0.6366],
           [-0.2250,  0.2389,  0.1041,  ..., -0.2533,  0.4262, -0.2529],
           ...,
           [-0.5541, -0.4107, -0.2449,  ..., -0.2505, -0.4789, -0.1813],
           [ 0.7244,  1.0885,  0.3598,  ..., -0.2154, -0.0510, -0.4382],
           [ 0.0048,  0.2987, -0.1986,  ..., -0.0262, -0.3108,  0.2186]],
 
          [[-0.4030,  0.0613,  0.3968,  ...,  1.6628,  1.7025,  1.5833],
           [ 

In [20]:
m['past_key_value'].value_cache

[tensor([[[[-3.0323e-03,  2.9014e-03, -2.5995e-03,  ...,  3.7256e-04,
            -6.8101e-03,  2.4478e-03],
           [ 4.9791e-03,  1.8254e-03,  4.7753e-04,  ..., -1.6144e-04,
             6.5963e-04,  1.7619e-03],
           [ 3.3913e-03, -4.6228e-03, -1.4544e-02,  ...,  4.5442e-03,
             6.1396e-03, -7.7402e-03],
           ...,
           [ 5.6766e-03,  1.4015e-04,  1.0991e-03,  ..., -3.7779e-03,
            -5.6904e-04,  1.0857e-02],
           [-1.0003e-03, -7.0993e-03,  3.6075e-03,  ..., -5.5194e-04,
             1.9539e-03,  4.3085e-03],
           [-2.4467e-03,  1.1709e-02,  8.7955e-03,  ..., -1.6282e-03,
            -8.2406e-03, -1.0132e-02]],
 
          [[ 3.1156e-03,  3.2496e-04,  1.1357e-02,  ...,  1.8874e-03,
             7.7895e-03, -5.8534e-03],
           [-1.5936e-03, -3.9710e-03, -5.3878e-03,  ..., -7.5102e-04,
            -3.9153e-03, -1.0004e-03],
           [-3.3756e-03, -4.5582e-03, -1.3133e-02,  ...,  1.0132e-02,
            -2.0331e-03, -7.9861e-03],


训练PQ并保存到硬盘

In [21]:
from faiss import IndexPQ, METRIC_INNER_PRODUCT
import numpy as np

M = 16
nbits = 8
dim = int(model.config.hidden_size / model.config.num_attention_heads)

train_data = np

indices: List[List[IndexPQ]] = [
    [IndexPQ(dim, M, nbits, METRIC_INNER_PRODUCT) for _ in range(model.config.num_attention_heads)] for _ in range(model.config.num_hidden_layers)
]

for i in range(model.config.num_hidden_layers):
    for j in range(model.config.num_attention_heads):
        tmp = model.model.layers[i].self_attn.middleware['key_states'][:, j, :10000, :].view(-1, 128).cpu().numpy()
        # print(tmp.shape)
        indices[i][j].train(tmp)

In [22]:
# save indices to disk
from faiss import write_index
for i in range(model.config.num_hidden_layers):
    for j in range(model.config.num_attention_heads):
        index_filename = f"./pq_index/pq_{i}_{j}.index"
        write_index(indices[i][j], index_filename)

# 实现Cache

In [23]:
from faiss import read_index

def restore_index(layer_idx: int, head_idx: int):
    """
    Restore a FAISS IndexPQ from disk.
    
    Args:
    layer_idx (int): The index of the layer in the model.
    head_idx (int): The index of the attention head within the layer.

    Returns:
    faiss.IndexPQ: The restored FAISS index.
    """
    index_filename = f"./pq_index/pq_{layer_idx}_{head_idx}.index"
    return read_index(index_filename)



In [24]:
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv, LlamaDecoderLayer, LlamaMLP, LlamaRMSNorm, LlamaModel, LlamaSdpaAttention, LlamaPreTrainedModel
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional, Tuple, Union
import torch
import math

from faiss import IndexPQ, IndexFlatIP

class KeyStateTensorMocker:
    def __init__(self, key_states: Union[torch.Tensor, 'KeyStateTensorMocker'], layer_idx: int) -> None:
        self._cache = None
        self._shape = None
        self._debug_cache = None
        
        self.layer_idx = layer_idx
        if key_states is not None:
            if isinstance(key_states, KeyStateTensorMocker):
                # This is used when from_legacy_cache is called
                self._cache = key_states._cache
                self._shape = key_states._shape
                self._debug_cache = key_states._debug_cache
                return

            bsz, num_heads, seq_len, head_dim = key_states.shape
            # self._cache = [IndexFlatIP(head_dim) for _ in range(num_heads)]
            self._cache = [restore_index(layer_idx, i) for i in range(num_heads)]
            self._shape = [bsz, num_heads, 0, head_dim]

            self.cat(key_states)          

    @property
    def shape(self) -> Optional[Tuple[int, int, int, int]]:
        # Return the shape if available
        return tuple(self._shape)
    
    def cat(self, key_states: torch.Tensor) -> None:
        '''
        Update or init the cache with key_states.
        '''
        bsz, num_heads, seq_len, head_dim = key_states.shape
        assert head_dim == self._shape[-1], "The head dimension of the key_states does not match the cache's head dimension"
        assert num_heads == self._shape[1], "The number of heads of the key_states does not match the cache's number of heads"

        for b in range(bsz):
            for i in range(num_heads):
                self._cache[i].add(key_states[b, i, :, :].cpu().numpy())

        self._shape[2] += seq_len

        if self._debug_cache is None:
            self._debug_cache = key_states
        else:
            self._debug_cache = torch.cat([self._debug_cache, key_states], dim=-2)
 
        pass
    def __getitem__(self, idx: int) -> IndexFlatIP:
        # print("idx", idx)
        return self._cache[idx]

    def reconstruct(self) -> torch.Tensor:
        '''
        Reconstruct the mocker to a torch.Tensor
        Inefficient, only for debugging
        '''
        bsz, num_heads, seq_len, head_dim = self._shape
        key_states = torch.zeros(bsz, num_heads, seq_len, head_dim, device='cuda:0')

        for b in range(bsz):
            for i in range(num_heads):
                key_states[b, i, :, :] = torch.tensor(self._cache[i].reconstruct_n(0, seq_len), device=key_states.device)
        
        return key_states

class DatabaseCache(DynamicCache):
    def __init__(self) -> None:
        self.key_cache : List[KeyStateTensorMocker] = [] # indexed by layer_idx
        self.value_cache : List[torch.Tensor] = []
        self._debug_key_cache : List[torch.Tensor] = []
        self._seen_tokens = 0
    
    def reorder_cache(self, beam_idx: torch.LongTensor):
        raise NotImplementedError("Reordering the cache is not currently supported")

    def query(self, query_states, layer_idx, *, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
        '''
        Basically implements SDPA with cache
        '''
        bsz, num_heads, query_len, head_dim = query_states.shape
        seq_len = self._seen_tokens

        assert bsz == 1, "Batch size > 1 is not currently supported"
        
        # scale
        scaling_factor = 1 / math.sqrt(head_dim) if scale is None else scale
        # scaling_factor = 1

        # bias
        attn_bias = torch.zeros(query_len, seq_len, device=query_states.device, dtype=query_states.dtype)
        if is_causal:
            assert attn_mask is None, "is_causal and attn_mask cannot be used together"
            temp_mask = torch.ones(query_len, seq_len, device=query_states.device, dtype=torch.bool).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float('-inf'))
            attn_bias.to(query_states.dtype)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill(attn_mask.logical_not(), float('-inf'))
            else:
                attn_bias += attn_mask

        # score
        if seq_len > 0:
            attn_score = torch.zeros(bsz, num_heads, query_len, seq_len, device=query_states.device)
            top_k = int(seq_len * 1)
            # top_k = seq_len
            for b in range(bsz): # TODO: parallelize this
                for h in range(num_heads):
                    D, I = self.key_cache[layer_idx][h].search(query_states[b, h, :, :].cpu().numpy(), top_k) # TODO: specify k
                    # convert D to tensor
                    D = torch.tensor(D, device=query_states.device)
                    for (idx, cols) in enumerate(I):
                        for (jdx, col) in enumerate(cols):
                            attn_score[b, h, idx, col] = D[idx, jdx]

        else:
            attn_score = query_states @ self._debug_key_cache[layer_idx].transpose(-1, -2)


        attn_score[torch.abs(attn_score) < 1e-5] = -1e10

        attn_score = attn_score * scaling_factor + attn_bias

        # softmax
        attn_score = torch.softmax(attn_score, dim=-1)

        # dropout
        attn_score = torch.dropout(attn_score, dropout_p, train=True)

        # weighted sum
        return attn_score @ self.value_cache[layer_idx]

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None) -> None:
        '''
        Broken change: returns None instead of updated states
        '''
        # key_states is shaped (bsz, num_heads, seq_len, head_dim)
        if layer_idx == 0:
            self._seen_tokens += key_states.shape[-2]

        bsz, num_heads, seq_len, head_dim = key_states.shape

        # initialize the cache if it doesn't exist
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(KeyStateTensorMocker(key_states, layer_idx))
            self.value_cache.append(value_states)

            if isinstance(key_states, KeyStateTensorMocker):
                # key_states = key_states.reconstruct()
                key_states = key_states._debug_cache
            self._debug_key_cache.append(key_states)
        else:
            # update the cache
            self.key_cache[layer_idx].cat(key_states)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            self._debug_key_cache[layer_idx] = torch.cat([self._debug_key_cache[layer_idx], key_states], dim=-2)
     
class LlamaForCausalLMDB(LlamaForCausalLM):
    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = LlamaModelDB(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()
        self.middleware = {}
        self.fwcall = 0

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
        >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        self.fwcall += 1
        print(f"fwcall: {self.fwcall}, key cache size(one layer): {past_key_values[0][0].shape if past_key_values is not None else None}")
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    
class LlamaModelDB(LlamaModel):
    def __init__(self, config):
        super(LlamaModel, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayerDB(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

        self.middleware = {}
    
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        '''
        Only difference is that we use DatabaseCache instead of DynamicCache
        '''

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DatabaseCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens)

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class LlamaDecoderLayerDB(LlamaDecoderLayer):
    def __init__(self, config, layer_idx):
        super(LlamaDecoderLayer, self).__init__()
        self.hidden_size = config.hidden_size

        # self.self_attn = LlamaSdpaAttention(config, layer_idx)
        self.self_attn = LlamaSdpaAttentionDB(config, layer_idx)

        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

class LlamaSdpaAttentionDB(LlamaSdpaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.middleware = {}

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        causal_mask = attention_mask
        if attention_mask is not None:
            causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
            
        attn_output = past_key_value.query(
            query_states, 
            self.layer_idx,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=causal_mask is None and q_len > 1,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        self.middleware.update({
            "query_states" : query_states,
            "key_states" : key_states,
            "value_states" : value_states,
            "past_key_value" : past_key_value
        })

        return attn_output, None, past_key_value
    

## 使用IndexFlatIP验证正确性

In [11]:
key_states = m['key_states']
value_states = m['value_states']
query_states = m['query_states']

sdpa = torch.nn.functional.scaled_dot_product_attention(
    m['query_states'],
    m['key_states'],
    m['value_states']
)

dbcache = DatabaseCache()
dbcache.update(key_states, value_states, 0)

In [12]:
layer_idx = 0

bsz, num_heads, query_len, head_dim = query_states.shape
seq_len = dbcache._seen_tokens

attn_score = torch.zeros(bsz, num_heads, query_len, seq_len)
for b in range(bsz):
    for h in range(num_heads):
        D, I = dbcache.key_cache[layer_idx][h].search(query_states[b, h, :, :].cpu().numpy(), seq_len) # TODO: specify k
        # convert D to tensor
        D = torch.tensor(D, device=query_states.device)
        for (idx, cols) in enumerate(I):
            for (jdx, col) in enumerate(cols):
                attn_score[b, h, idx, col] = D[idx, jdx]

In [13]:
query_states[0, 0, :, :].shape

torch.Size([1, 128])

In [14]:
index = IndexFlatIP(head_dim)

index.add(key_states[0, 0, :, :].cpu().numpy())

In [15]:
D, I = index.search(query_states[0, 0, :, :].cpu().numpy(), 5)

print(query_states[0, 0, :, :] @ key_states[0, 0, I[0][0], :])
print(D[0][0])

tensor([17.8612], device='cuda:0')
17.861155


In [16]:
attn_score.shape

torch.Size([1, 32, 1, 1])

In [17]:
torch.matmul(query_states, key_states.transpose(-1, -2))

tensor([[[[ 17.8612]],

         [[  2.9589]],

         [[  4.6459]],

         [[ 34.8256]],

         [[ 26.3393]],

         [[  4.4898]],

         [[ 16.9595]],

         [[-29.1442]],

         [[ 15.6031]],

         [[-16.4950]],

         [[ 16.2660]],

         [[ 18.6018]],

         [[ 40.9244]],

         [[-12.6539]],

         [[ -3.8843]],

         [[ 12.1046]],

         [[ 12.3422]],

         [[  9.4559]],

         [[ 21.0365]],

         [[ 36.2610]],

         [[ 11.7032]],

         [[ -4.3956]],

         [[ 24.3023]],

         [[ 17.2784]],

         [[  3.4854]],

         [[ 10.2540]],

         [[ 16.6599]],

         [[ -8.6022]],

         [[ 16.4774]],

         [[ 16.5451]],

         [[-14.7055]],

         [[  1.9440]]]], device='cuda:0')

In [18]:
attn_score

tensor([[[[ 17.8612]],

         [[  2.9589]],

         [[  4.6459]],

         [[ 34.8256]],

         [[ 26.3393]],

         [[  4.4898]],

         [[ 16.9595]],

         [[-29.1442]],

         [[ 15.6031]],

         [[-16.4950]],

         [[ 16.2660]],

         [[ 18.6018]],

         [[ 40.9244]],

         [[-12.6539]],

         [[ -3.8843]],

         [[ 12.1046]],

         [[ 12.3422]],

         [[  9.4559]],

         [[ 21.0365]],

         [[ 36.2610]],

         [[ 11.7032]],

         [[ -4.3956]],

         [[ 24.3023]],

         [[ 17.2784]],

         [[  3.4854]],

         [[ 10.2540]],

         [[ 16.6599]],

         [[ -8.6022]],

         [[ 16.4774]],

         [[ 16.5451]],

         [[-14.7055]],

         [[  1.9440]]]])

In [19]:
# scale & softmax
scaling_factor = torch.sqrt(torch.tensor(head_dim))
attn_score = torch.softmax(attn_score / scaling_factor, dim=-1)

# weighted sum
attn_output = torch.zeros_like(query_states)
for b in range(bsz):
    for h in range(num_heads):
        for i in range(query_len):
            # Each output vector is a sum over all value vectors, weighted by the attention scores
            for j in range(seq_len):
                attn_output[b, h, i, :] += attn_score[b, h, i, j] * dbcache.value_cache[layer_idx][b, h, j, :]

In [24]:
attn_output2 = attn_score @ dbcache.value_cache[layer_idx].cpu()

In [20]:
print(attn_output)

tensor([[[[ 7.5211e-03, -1.0375e-02,  3.5426e-03,  ..., -1.8492e-03,
            3.9925e-03, -4.3669e-03]],

         [[ 2.1980e-03,  3.3925e-03, -2.4338e-03,  ...,  2.2310e-03,
            2.4552e-03,  5.0936e-03]],

         [[-4.0987e-03, -3.7784e-03,  3.0119e-03,  ...,  7.2256e-04,
           -4.6644e-05,  9.7175e-03]],

         ...,

         [[-3.7973e-02, -6.9194e-03, -1.7842e-02,  ...,  7.5553e-03,
            3.7649e-03, -1.2968e-02]],

         [[-1.0510e-02,  1.2955e-03,  8.9638e-04,  ...,  7.3972e-04,
            3.3017e-04,  4.8898e-03]],

         [[-7.9546e-03,  8.9661e-03,  1.2278e-03,  ...,  4.6792e-05,
            1.1711e-03, -9.9407e-03]]]], device='cuda:0')


In [21]:
print(sdpa)

tensor([[[[ 7.5211e-03, -1.0375e-02,  3.5426e-03,  ..., -1.8492e-03,
            3.9925e-03, -4.3669e-03]],

         [[ 2.1980e-03,  3.3925e-03, -2.4338e-03,  ...,  2.2310e-03,
            2.4552e-03,  5.0936e-03]],

         [[-4.0987e-03, -3.7784e-03,  3.0119e-03,  ...,  7.2256e-04,
           -4.6644e-05,  9.7175e-03]],

         ...,

         [[-3.7973e-02, -6.9194e-03, -1.7842e-02,  ...,  7.5553e-03,
            3.7649e-03, -1.2968e-02]],

         [[-1.0510e-02,  1.2955e-03,  8.9638e-04,  ...,  7.3972e-04,
            3.3017e-04,  4.8898e-03]],

         [[-7.9546e-03,  8.9661e-03,  1.2278e-03,  ...,  4.6792e-05,
            1.1711e-03, -9.9407e-03]]]], device='cuda:0')


In [25]:
print(attn_output2)

tensor([[[[ 7.5211e-03, -1.0375e-02,  3.5426e-03,  ..., -1.8492e-03,
            3.9925e-03, -4.3669e-03]],

         [[ 2.1980e-03,  3.3925e-03, -2.4338e-03,  ...,  2.2310e-03,
            2.4552e-03,  5.0936e-03]],

         [[-4.0987e-03, -3.7784e-03,  3.0119e-03,  ...,  7.2256e-04,
           -4.6644e-05,  9.7175e-03]],

         ...,

         [[-3.7973e-02, -6.9194e-03, -1.7842e-02,  ...,  7.5553e-03,
            3.7649e-03, -1.2968e-02]],

         [[-1.0510e-02,  1.2955e-03,  8.9638e-04,  ...,  7.3972e-04,
            3.3017e-04,  4.8898e-03]],

         [[-7.9546e-03,  8.9661e-03,  1.2278e-03,  ...,  4.6792e-05,
            1.1711e-03, -9.9407e-03]]]])


## 使用pipeline进行文本生成

In [25]:
modelDB = LlamaForCausalLMDB.from_pretrained("llama-2-7b-hf")
tokenizer = LlamaTokenizer.from_pretrained("llama-2-7b-hf")

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.17it/s]


In [27]:
from transformers import pipeline
import torch

# device = "cpu"
device = "cuda:0"

pipe = pipeline(
    'text-generation',
    model=modelDB,
    torch_dtype=torch.float16,
    device = device,
    tokenizer=tokenizer,
)

sequence = "In this work, we develop"

output = pipe(sequence, max_length=128, do_sample=True, temperature=0.9)
print(output)

The model 'LlamaForCausalLMDB' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'JambaForCausalLM', 'LlamaForCausalLM', 'MambaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MusicgenMelodyForC

fwcall: 15, key cache size(one layer): None
fwcall: 16, key cache size(one layer): (1, 32, 6, 128)
fwcall: 17, key cache size(one layer): (1, 32, 7, 128)
fwcall: 18, key cache size(one layer): (1, 32, 8, 128)
fwcall: 19, key cache size(one layer): (1, 32, 9, 128)
fwcall: 20, key cache size(one layer): (1, 32, 10, 128)
fwcall: 21, key cache size(one layer): (1, 32, 11, 128)
fwcall: 22, key cache size(one layer): (1, 32, 12, 128)
fwcall: 23, key cache size(one layer): (1, 32, 13, 128)
fwcall: 24, key cache size(one layer): (1, 32, 14, 128)
fwcall: 25, key cache size(one layer): (1, 32, 15, 128)
fwcall: 26, key cache size(one layer): (1, 32, 16, 128)
fwcall: 27, key cache size(one layer): (1, 32, 17, 128)
fwcall: 28, key cache size(one layer): (1, 32, 18, 128)
fwcall: 29, key cache size(one layer): (1, 32, 19, 128)
fwcall: 30, key cache size(one layer): (1, 32, 20, 128)
fwcall: 31, key cache size(one layer): (1, 32, 21, 128)
fwcall: 32, key cache size(one layer): (1, 32, 22, 128)
fwcall: 

In [21]:
m = modelDB.model.layers[0].self_attn.middleware
cache = m['past_key_value']
query_states = m['query_states']
key_states = m['key_states']
value_states = m['value_states']

In [5]:
query_states.shape

torch.Size([1, 32, 8, 128])

In [6]:
cache.key_cache[0].shape

(1, 32, 8, 128)

In [10]:
to_compare = torch.nn.functional.scaled_dot_product_attention(
    query_states, key_states, value_states
)

In [11]:
to_compare.shape

torch.Size([1, 32, 1, 128])

In [14]:
query_result = cache.query(query_states, 0)

In [13]:
print(to_compare)

tensor([[[[ 3.6696e-03, -7.9229e-03,  8.4558e-03,  ..., -4.5834e-03,
           -2.3658e-03, -1.7671e-04]],

         [[ 4.3845e-04, -5.5317e-03,  5.8105e-03,  ...,  8.9691e-03,
           -4.0985e-04, -5.5041e-03]],

         [[-6.8097e-03,  3.5033e-03,  4.1462e-03,  ...,  5.9926e-03,
            2.3143e-04, -2.0504e-03]],

         ...,

         [[-2.8282e-02,  4.3157e-02, -1.2844e-02,  ...,  3.6885e-02,
            7.5483e-03, -1.6709e-03]],

         [[ 9.6602e-05,  6.6361e-03,  8.9406e-03,  ..., -1.1882e-02,
           -5.2853e-03,  1.6323e-02]],

         [[ 3.4548e-03,  4.0896e-03, -1.1464e-02,  ..., -3.3808e-04,
            4.1257e-03,  6.8261e-03]]]], device='cuda:0')


In [16]:
print(query_result)

tensor([[[[ 1.3781e-03, -5.7980e-04,  2.9330e-03,  ...,  7.4228e-04,
           -1.4491e-03, -1.3022e-04]],

         [[-2.8401e-03, -1.2746e-03, -3.7840e-03,  ...,  2.6666e-03,
           -4.3657e-04, -1.0268e-03]],

         [[-3.1803e-04,  1.7788e-03,  2.1568e-03,  ...,  1.2615e-03,
            2.9185e-03,  8.2977e-04]],

         ...,

         [[-1.4362e-02, -2.4401e-03, -1.8101e-02,  ...,  2.2499e-02,
           -1.3045e-02,  7.9051e-03]],

         [[ 4.3792e-03,  4.5733e-03, -9.5786e-04,  ...,  1.0060e-03,
           -8.3106e-05,  2.5171e-03]],

         [[ 1.1982e-04,  3.7192e-03, -1.3928e-03,  ...,  3.4574e-03,
            1.4829e-03,  1.6591e-03]]]], device='cuda:0')


In [None]:
from faiss import IndexPQ
index = IndexPQ(128, 8, 4)

In [None]:
print(m['value_states'].shape)

indices = [IndexPQ(128, 8, 4) for _ in range(32)]

for i, index in enumerate(indices):
    tmp = m['value_states'][:, i, :, :].view(-1, 128).cpu().numpy()
    index.train(tmp)
    index.add(tmp)

In [None]:
m['value_states'][:, i, :, :].view(-1, 128).cpu().numpy().shape

In [None]:
index.add(m['value_states'].cpu().numpy())

In [None]:
sequence = "SEIEE in SJTU is"
output = pipe(sequence, max_length=512, do_sample=True, temperature=0.9)
print(output)

从这里我们可以看到，一个新的文本生成任务会使用一个新的缓存。

In [None]:
print(m['past_key_value']._seen_tokens)

In [None]:
m['hidden_states_input'].shape

In [None]:
model.__call__?

In [None]:
model.forward?

In [None]:
model

# 乘法器优化

In [None]:
def _merge_heads(tensor, num_heads, attn_head_size):
    """
    Merges attn_head_size dim and num_attn_heads dim into hidden_size
    """
    tensor = tensor.permute(0, 2, 1, 3).contiguous()
    new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
    return tensor.view(new_shape)

In [None]:
print("k.shape", m['past_key_value'].key_cache[0].shape)
print("v.shape", m['past_key_value'].value_cache[0].shape)

In [None]:
# make a copy of q, k
q = m['past_key_value'].key_cache[0].clone()
k = m['past_key_value'].value_cache[0].clone()

# merge heads
num_heads = 32
attn_head_size = 128
q = _merge_heads(q, num_heads, attn_head_size)
k = _merge_heads(k, num_heads, attn_head_size)

print("q.shape", q.shape)
print("k.shape", k.shape)

In [None]:
# Assuming q and k are defined as torch.Tensor with the given shapes
# q.shape == torch.Size([1, 357, 4096])
# k.shape == torch.Size([1, 357, 4096])

# Expand q and k to prepare for element-wise multiplication
# Expand q to [1, 357, 1, 4096] and k to [1, 1, 357, 4096]
q_expanded = q.unsqueeze(2)  # Adding a singleton dimension for broadcasting
k_expanded = k.unsqueeze(1)  # Adding a different singleton dimension for broadcasting

# Element-wise multiplication
# Result shape will be [1, 357, 357, 4096], capturing each multiplication
multiplication_results = q_expanded * k_expanded

# To verify, let's examine the shape
print("Shape of multiplication_results:", multiplication_results.shape)

In [None]:
multiplication_results.reshape(-1)

In [None]:
len(multiplication_results.reshape(-1))

In [None]:
import matplotlib.pyplot as plt

# plt.hist(multiplication_results.reshape(-1).cpu().numpy(), bins=100, log=True)
# cdf of abs(multiplication_results)

plt.hist(multiplication_results.reshape(-1).abs().cpu().numpy(), bins=100, cumulative=True, density=True, histtype='step')
plt.xscale("log")

plt.title(f"Q*K^T, {q.shape[1]} * {q.shape[2]} * {k.shape[1]}")

In [None]:
num_heads = 32
attn_head_size = 128

# store multiplication results in a list
multiplication_results = []


# for i, layer in enumerate(model.model.layers):
for i in [0,15,31]:
    layer = model.model.layers[i]
    

    m = layer.self_attn.middleware
    
    q = m['past_key_value'].key_cache[0].clone()
    k = m['past_key_value'].value_cache[0].clone()

    q = _merge_heads(q, num_heads, attn_head_size)
    k = _merge_heads(k, num_heads, attn_head_size)

    q_expanded = q.unsqueeze(2)
    k_expanded = k.unsqueeze(1)
    tmp = q_expanded * k_expanded

    # offload to cpu
    tmp.cpu()

    # free cuda memory
    del q, k, q_expanded, k_expanded

    multiplication_results.append(tmp)

# flatten the list
multiplication_results = torch.cat([tmp.reshape(-1) for tmp in multiplication_results])

In [None]:
import matplotlib.pyplot as plt

# plt.hist(multiplication_results.reshape(-1).cpu().numpy(), bins=100, log=True)
# cdf of abs(multiplication_results)

plt.hist(multiplication_results.reshape(-1).abs().cpu().numpy(), bins=100, cumulative=True, density=True, histtype='step')
plt.xscale("log")

plt.title(f"Q*K^T, {q.shape[1]} * {q.shape[2]} * {k.shape[1]}")

# Evaluation

In [None]:
from perplexity import perplexity

device = 'cuda:0'
root = '~/'
dataset = 'PTB'

stride = model.config.max_position_embeddings # 4096
ppl_baseline = perplexity(model, tokenizer, dataset, device, verbose=True, stride=stride, root=root)
print(ppl_baseline)


In [3]:
import os

# os.environ['http_proxy'] = '127.0.0.1:7897'
# os.environ['https_proxy'] = '127.0.0.1:7897'

model = LlamaForCausalLM.from_pretrained("llama-2-7b-hf")
tokenizer = LlamaTokenizer.from_pretrained("llama-2-7b-hf")


from deepeval.benchmarks import MMLU
from deepeval.benchmarks.tasks import MMLUTask

# Define benchmark with specific tasks and shots
benchmark = MMLU(
    tasks=[MMLUTask.HIGH_SCHOOL_COMPUTER_SCIENCE, MMLUTask.ASTRONOMY],
    n_shots=3
)

benchmark.evaluate(model=model)

# unset http_proxy and https_proxy
# os.environ.pop('http_proxy')
# os.environ.pop('https_proxy')

# Replace 'mistral_7b' with your own custom model
print(benchmark.overall_score)

Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.30it/s]
Using the latest cached version of the module from /home/xupeng/.cache/huggingface/modules/datasets_modules/datasets/lukaemon--mmlu/5407247256b75097c6ed96d65e9673eaf8cb7522ab67e1ea65e7bb85b44be036 (last modified on Thu Apr 25 10:08:46 2024) since it couldn't be found locally at lukaemon/mmlu, or remotely on the Hugging Face Hub.
Processing high_school_computer_science:   0%|          | 0/100 [00:00<?, ?it/s]The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Processing high_school_computer_science:   0%|          | 0/100 [00:00<?, ?it/s]


AttributeError: 'str' object has no attribute 'shape'