In [2]:
import torch
from torch import nn
import numpy as np
import random
from matplotlib import pyplot as plt
from typing import Any, Dict, List, Optional, Tuple
from transformers.cache_utils import Cache

In [3]:
def distance_fun(value):
    """Squash raw similarity scores into per-step distances.

    A sigmoid is used as a provisional choice, so every returned distance
    lies strictly between 0 and 1.
    """
    squashed = torch.sigmoid(value)
    return squashed
    
    
def adaptive_position( query:torch.Tensor, key:torch.Tensor,position_modified_ids=None)->torch.Tensor:
    '''Compute content-dependent ("adaptive") positions from query/key similarity.

    The step from token ``i-1`` to token ``i`` is
    ``distance_fun(query[i] . key[i-1])``; a token's position is the running
    sum of these steps, with the first token pinned to position 0.

    Args:
        query: ``[batch, num_heads, seq_length, hidden_dim]`` in the prefill
            stage, ``[batch, num_heads, 1, hidden_dim]`` in the decoding stage.
        key: ``[batch, num_heads, seq_length, hidden_dim]``. In the decoding
            stage the second-to-last key is paired with the new query, so the
            key tensor is presumably the cache that already contains the new
            token's key -- confirm against the caller.
        position_modified_ids: ``None`` selects the prefill stage. Otherwise
            the positions computed so far, ``[batch, num_heads, 1, cache_length]``.

    Returns:
        Prefill stage: ``[batch, num_heads, 1, seq_length]``.
        Decoding stage: ``[batch, num_heads, 1, 1]`` (position of the new token).
    '''
    batch,num_heads = query.shape[0], query.shape[1]
    if position_modified_ids is None:
        # Prefill stage. Queries and keys are multiplied with an offset of one:
        # drop the first query and the last key so query[i] meets key[i-1].
        shifted_query = query[:,:,1:,:].unsqueeze(-2)
        shifted_key = key[:,:,:-1,:].unsqueeze(-2)
        distance = distance_fun(
            torch.matmul(shifted_query,shifted_key.transpose(-2,-1)).squeeze(-1)
        ).transpose(-1,-2)

        # The first token has no predecessor; its position is forced to 0.
        # The zero must live on the same device as `distance`, otherwise
        # torch.cat fails for inputs that are not on the default device.
        distance_0 = torch.zeros((batch,num_heads,1,1),dtype=distance.dtype,device=distance.device)
        distance = torch.cat((distance_0,distance),dim=-1)

        position_modified_ids = distance.cumsum(dim=-1)

    else:
        # Decoding stage: usually a single new query. Its step is measured
        # against the second-to-last key and added to the last known position.
        new_distance = distance_fun(torch.matmul(query, key[:,:,-2,:].unsqueeze(-2).transpose(-2,-1)))
        position_modified_ids = new_distance + position_modified_ids[:,:,:,-1].unsqueeze(-2)
    return position_modified_ids


'''下面几个函数记得区分query_status和query_cache'''

def query_select(query_status, block_ends, key_cache, block_lst:list):
    """Pick, among consecutive candidate blocks, the one the query scores highest.

    Args:
        query_status: query vector(s), ``[1, hidden_dim]`` (or ``[hidden_dim]``).
        block_ends: list of boundaries; each adjacent pair ``(block_ends[k-1],
            block_ends[k])`` delimits one candidate block of ``key_cache`` rows.
        key_cache: ``[cache_length, hidden_dim]``.
        block_lst: full list of block end positions for this head.

    Returns:
        The index in ``block_lst`` of the end position of the best candidate,
        where a candidate's score is the mean of ``query @ keys.T``.
    """
    scores = [
        torch.mean(query_status @ key_cache[lo:hi, :].T)
        for lo, hi in zip(block_ends[:-1], block_ends[1:])
    ]
    # +1 because candidate k is closed by block_ends[k + 1].
    best = torch.argmax(torch.tensor(scores)) + 1
    return block_lst.index(block_ends[best])

def block_trace(block_dependency,
                select_block_id,
                left_length,
                block_lst,
                position_method = 'origin',
                block_size = 64):
    """Follow the dependency chain of a block and collect the key ranges on it.

    Args:
        block_dependency: per-block dependency list for one batch entry and
            one head; entry ``k`` is the block that block ``k`` depends on, or
            ``None`` when there is none.
        select_block_id: index (into ``block_lst``) of the block to start from.
        left_length: remaining budget, i.e. the length left after the sink
            block and the neighbouring ("skip") blocks have been accounted for.
        block_lst: block end positions; block ``k`` spans
            ``[block_lst[k-1], block_lst[k])``.
        position_method: ``'origin'`` charges each block its real length,
            ``'modified'`` charges ``block_size`` per block.
        block_size: nominal block size used by the ``'modified'`` method.

    Returns:
        A list of ``[start, end]`` pairs, earliest block on the chain first,
        always terminated by the empty placeholder ``[0, 0]``. The block whose
        dependency is ``None`` is never included; that leaves room for the
        sink block, which the caller adds separately.
    """
    visited = []  # ranges in visiting order: the selected block comes first
    current = select_block_id
    budget = left_length

    while True:
        parent = block_dependency[current]
        if parent is None or not (budget > 0):
            break
        start, end = block_lst[current - 1], block_lst[current]
        visited.append([start, end])
        if position_method == 'origin':
            budget -= (end - start)
        elif position_method == 'modified':
            budget -= block_size
        current = parent

    return visited[::-1] + [[0, 0]]

def query_aware(query_status:torch.Tensor, # [batch, num_heads, seq_length, hidden_dim]
                key_cache, # [batch, num_heads, cache_length, hidden_dim]
                value_cache,# [batch, num_heads, cache_length, hidden_dim]
                position_ids, # important: keys are selected for every query listed here; shape [1, seq_length]
                # NOTE: these are the plain integer position_ids, NOT position_modified_ids; there is
                # exactly one entry per query, so the length is seq_length rather than cache_length
                #layer_idx,
                position_modified_ids, # [batch, num_heads, 1, cache_length]
                block_status, 
                block_dependency, # block_dependency and block_status carry no layer_idx level: they are the per-layer outputs of block_cache.update
                max_length,
                skip_num=5, # number of neighbouring blocks kept right before the query ("neighbor_block" would be a better name than "skip_block")
                find_num = 5,
                position_method = 'origin',
                block_size = 64):
    """Select, for every query, the keys/values (and their positions) to attend to.

    Per query the selection is ``[sink block] + [selected block and the blocks
    on its dependency chain] + [skip (neighbouring) blocks] + [left-over keys
    between the last block end and the query]``. When at most ``skip_num``
    block ends lie before the query, all keys up to and including the query
    position are used instead.

    Args:
        query_status: ``[batch, num_heads, seq_length, hidden_dim]``.
        key_cache / value_cache: ``[batch, num_heads, cache_length, hidden_dim]``.
        position_ids: ``[1, seq_length]`` integer positions; only
            ``position_ids[0, 0]`` and the length are read, so the positions
            are assumed to be consecutive -- TODO confirm.
        position_modified_ids: ``[batch, num_heads, 1, cache_length]``.
        block_status: ``{batch: {head: [block_end, ...]}}``.
        block_dependency: ``{batch: {head: [dependency, ...]}}``.
        max_length: length budget from which ``left_length`` (the budget
            handed to ``block_trace``) is derived.
        skip_num: number of neighbouring blocks before the query that are kept.
        find_num: number of candidate blocks ``query_select`` chooses from.
        position_method: ``'origin'`` or ``'modified'``.
        block_size: nominal block size, used by the ``'modified'`` method.

    Returns:
        ``(keys_to_be_used, values_to_be_used, position_to_be_used)``: three
        dicts of the form ``{batch: {head: [tensor, ...]}}`` whose lists hold
        one entry per query (``seq_length`` entries). Position entries have a
        leading dimension of size 1.
    """
    # query_status is [batch, num_head, seq_length, hidden_dim]
    # block_status is {batch1:{head1:[...],head2:[...],...},batch2:{head1:[...],head2:[...],...},...}
    # position_ids is [1, seq_length]
    # the outputs are nested dicts {batch1:{head1:[[a1,a2,...],[b1,b2,...]]}}: each head maps to a list
    # holding the selection made for each query, so each list has seq_length elements
    batch_size, head_num = query_status.shape[:2]
    position_min = position_ids[0,0]
    seq_length = position_ids.shape[-1]
    keys_to_be_used = {}
    values_to_be_used = {}
    position_to_be_used = {}
    for batch in range(batch_size):
        keys_to_be_used[batch] = {} # ideal layout of the selected keys: [sink_block] + [select_block] + [skip_block]
        values_to_be_used[batch] = {}
        position_to_be_used[batch] = {}
        for head_idx in range(head_num):
            key_cache_to_be_used = key_cache[batch,head_idx,:,:]
            values_cache_to_be_used = value_cache[batch,head_idx,:,:]
            keys_to_be_used[batch][head_idx] = []
            values_to_be_used[batch][head_idx] = []
            position_to_be_used[batch][head_idx] = []

            # i counts the block ends lying strictly before the current query; it is never reset
            # between queries because query positions only grow
            i = 0
            #block_lst = block_status[layer_idx][batch][head_idx]
            block_lst = block_status[batch][head_idx]
            element_num = len(block_lst)
            for delta in range(seq_length):
                query_position = position_min+delta # query_position: absolute position of the query in the sequence; delta: its index inside query_status
                while i < element_num and block_lst[i] < query_position: # afterwards entry i-1 is the last block end before the query, i.e. entries 0..i-1 (i of them) precede it
                    i += 1
                if i <= skip_num: # every cached key before the query is used: select block, sink block, skip blocks and left-over keys all coincide
                    keys_to_be_used[batch][head_idx].append(key_cache_to_be_used[:query_position+1,:])
                    values_to_be_used[batch][head_idx].append(values_cache_to_be_used[:query_position+1,:])
                    if position_method == 'modified':
                        position_to_be_used[batch][head_idx].append(position_modified_ids[batch,head_idx,0,:query_position+1].unsqueeze(0))
                    else:
                        position_to_be_used[batch][head_idx].append(torch.arange(query_position+1).long().unsqueeze(0))
                else: # here the selection is [sink_block] + [select_block] + [skip_block] + left_keys
                    end = i-skip_num 
                    if end > find_num+1: 
                        start = end-find_num-1 #start >= 1
                        # block_ends only serves to let query_select pick the most useful candidate block
                        block_ends = block_lst[start:end] # len(block_ends) = end-start = find_num+1; the first element is the end of the preceding block
                    else: # some blocks are skipped, but too few remain to offer find_num candidates; select block and sink block may coincide
                        block_ends = [0] + block_lst[:end]
                    selected_block_idx = query_select(
                            query_status = query_status[batch,head_idx,delta,:],
                            block_ends = block_ends, 
                            key_cache = key_cache[batch,head_idx,:,:],
                            block_lst=block_lst) # returns the index of the chosen block within block_lst; look up block_lst[idx] for its actual end position
                    # selected_block_idx is the index (in block_lst) of the block this query attends to most
                    if position_method == 'origin':
                        skip_block_length = block_lst[i-1] - block_lst[end-1] 
                        sink_block_length = block_lst[0]
                        left_keys_length = query_position-block_lst[i-1]
                        left_length = max_length - (skip_block_length+sink_block_length+left_keys_length)
                    elif position_method == 'modified':
                        skip_block_length = block_size * skip_num
                        sink_block_length = block_size
                        
                        query_modified_position = position_modified_ids[batch][head_idx][0][query_position]
                        left_keys_length = query_modified_position-position_modified_ids[batch][head_idx][0][block_lst[i-1]]
                        left_length = max_length - (skip_block_length+sink_block_length+left_keys_length)
                
                    # The string below is the author's note (kept verbatim): the computation of
                    # left_length could be changed, e.g. using position_modified_ids would allow more
                    # tokens at inference time, similar to interpolation; this still needs experiments.
                    '''
                    left_length的计算可以修改, 比如使用position_modified_ids可以使推理时使用更多的tokens, 类似于插值的方法, 但是还需要进一步实验
                    '''
                    
                    select_keys_position = block_trace(
                        #block_dependency = block_dependency[layer_idx][batch][head_idx],
                        block_dependency = block_dependency[batch][head_idx],
                        select_block_id = selected_block_idx,
                        left_length = left_length,
                        block_lst= block_lst,
                        position_method = position_method,
                        block_size=block_size
                    )# yields [[0,0]] when selected_block_idx == 0, otherwise something like [[79,123],[261,337],...]
                    

                    # the skip (neighbouring) blocks are the last skip_num blocks before the query
                    skip_end = block_lst[i-1]
                    skip_start = block_lst[i-1-skip_num]

                    sink_block_key = key_cache_to_be_used[:block_lst[0],:] # always the content of the very first block
                    sink_block_value = values_cache_to_be_used[:block_lst[0],:]
                    select_blocks_key = torch.cat([key_cache_to_be_used[block[0]:block[1],:] for block in select_keys_position]
                                                ,dim=-2)
                    select_blocks_value = torch.cat([values_cache_to_be_used[block[0]:block[1],:] for block in select_keys_position]
                                                ,dim=-2)
                    skip_blocks_key = key_cache_to_be_used[skip_start:skip_end,:]
                    skip_blocks_value = values_cache_to_be_used[skip_start:skip_end,:]
                    left_keys = key_cache_to_be_used[block_lst[i-1]:query_position+1,:]# keys in front of the query that do not belong to any closed block
                    left_values = values_cache_to_be_used[block_lst[i-1]:query_position+1,:]
                    keys_to_be_used[batch][head_idx].append(torch.cat([sink_block_key,select_blocks_key,skip_blocks_key,left_keys],dim=-2))
                    v = torch.cat([sink_block_value,select_blocks_value,
                                                                    skip_blocks_value,left_values],dim=-2)
                    values_to_be_used[batch][head_idx].append(v)
                    
                    
                    if position_method == 'modified':
                        # Re-base the adaptive positions of every selected segment so that it starts
                        # at the last position of the segment placed before it.
                        position_cache = position_modified_ids[batch,head_idx,0,:]
                        sink_block_position = position_cache[:block_lst[0]]
                        select_blocks_position = sink_block_position
                        for block in select_keys_position:
                            select = position_cache[block[0]:block[1]] - position_cache[block[0]]+select_blocks_position[-1]
                            select_blocks_position = torch.cat([select_blocks_position,select], dim = 0)
                        skip_blocks_position = position_cache[skip_start:skip_end] - position_cache[skip_start] + select_blocks_position[-1]
                        left_postion = position_cache[block_lst[i-1]:query_position+1] - position_cache[block_lst[i-1]] + skip_blocks_position[-1]
                        # select_blocks_position already starts with the sink block positions
                        position_to_be_used[batch][head_idx].append(torch.cat([select_blocks_position,skip_blocks_position,left_postion],dim=0).unsqueeze(0))
                        
                    elif position_method == 'origin':
                        # plain consecutive integer positions over the selected values
                        position_to_be_used[batch][head_idx].append(
                            torch.arange(len(v)).long().unsqueeze(0)
                            )
    return keys_to_be_used, values_to_be_used, position_to_be_used


def get_key_value():
    """Placeholder for a helper that has not been written yet.

    The original header had neither a colon nor a body, which made the whole
    cell a syntax error. Until the function is implemented, calling it fails
    loudly instead of silently returning ``None``.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("get_key_value is not implemented yet")
            

In [6]:
def get_block_score(query_cache,key_cache,obj_block, end_block,norm = 0):
    """Score how strongly the queries of ``end_block`` attend to ``obj_block``.

    Args:
        query_cache: ``[num_queries, hidden_dim]``.
        key_cache: ``[cache_length, hidden_dim]``.
        obj_block: ``[start, end]`` row range of the candidate block in ``key_cache``.
        end_block: ``[start, end]`` range of the block whose queries are used.
        norm: offset subtracted from ``end_block`` before indexing
            ``query_cache``; used when ``query_cache`` does not start at
            sequence position 0.

    Returns:
        Scalar tensor: the mean of ``queries @ keys.T``.
    """
    obj_key = key_cache[obj_block[0]:obj_block[1],:] # [obj_block_length, hidden_dim]
    end_query = query_cache[end_block[0]-norm:end_block[1]-norm,:] # [end_block_length, hidden_dim]
    return torch.mean(end_query@obj_key.T)


def block_select(query_cache,key_cache,block_lst:List,neighbor_block_num,norm = 0):
    """Find the earlier block that the last block of ``block_lst`` attends to most.

    Block ``k`` spans ``[block_lst[k-1], block_lst[k])`` and block 0 spans
    ``[0, block_lst[0])``. The last block is compared against at most
    ``neighbor_block_num`` of the blocks directly before it.

    Args:
        query_cache: ``[num_queries, hidden_dim]``.
        key_cache: ``[cache_length, hidden_dim]``.
        block_lst: block end positions; needs at least two entries.
        neighbor_block_num: maximum number of preceding blocks to examine.
        norm: offset of ``query_cache`` relative to sequence position 0.

    Returns:
        The index of the best-scoring block, expressed in ``block_lst``
        indexing, or ``None`` if there was no candidate to examine
        (``neighbor_block_num`` is 0).

    Raises:
        IndexError: if ``block_lst`` has fewer than two entries.
    """
    last_block = [block_lst[-2], block_lst[-1]]
    num_blocks = len(block_lst)

    if num_blocks - 1 <= neighbor_block_num:
        # Few blocks so far: every earlier block, including block 0, is a candidate.
        # The original prepended 0 to the list and returned the index into that
        # prefixed list, which is one too large in block_lst indexing (block 1
        # ended up depending on itself and block 0 could never be chosen).
        candidates = range(0, num_blocks - 1)
    else:
        # Only the neighbor_block_num blocks right before the last one,
        # visited from the nearest to the farthest.
        candidates = range(num_blocks - 2, num_blocks - 2 - neighbor_block_num, -1)

    max_id = None
    max_score = None
    for block_id in candidates:
        start = block_lst[block_id - 1] if block_id > 0 else 0
        score = get_block_score(query_cache, key_cache, [start, block_lst[block_id]],
                                end_block=last_block, norm=norm)
        # ">=" keeps the original tie-breaking: on equal scores the candidate
        # visited later wins.
        if max_score is None or score >= max_score:
            max_score = score
            max_id = block_id

    return max_id



def dependency_update(query_cache,key_cache,block_lst,block_dependency,neighbor_block_num,norm = 0):
    """Extend ``block_dependency`` so that it has one entry per block of ``block_lst``.

    Entry 0 is always ``None`` (the first block depends on nothing). Every
    later entry ``k`` is the block chosen by ``block_select`` for the prefix
    ``block_lst[:k + 1]``.

    Only the missing entries are computed: starting from an empty list this
    builds the whole dependency list, and after one new block has been
    appended to ``block_lst`` it adds exactly one entry. Unlike before, an
    empty ``block_lst`` no longer produces a stray ``None`` entry, so the two
    lists always keep the same length.

    Args:
        query_cache: ``[num_queries, hidden_dim]``, forwarded to ``block_select``.
        key_cache: ``[cache_length, hidden_dim]``, forwarded to ``block_select``.
        block_lst: block end positions.
        block_dependency: dependency list to extend; modified in place.
        neighbor_block_num: forwarded to ``block_select``.
        norm: forwarded to ``block_select``.

    Returns:
        The same ``block_dependency`` list object, after extension.
    """
    for block_id in range(len(block_dependency), len(block_lst)):
        if block_id == 0:
            block_dependency.append(None)
        else:
            depend_idx = block_select(query_cache,key_cache,block_lst[:block_id+1],neighbor_block_num,norm = norm)
            block_dependency.append(depend_idx)
    return block_dependency

In [7]:
class Block_Cache():
    """Per-layer cache of adaptive positions, block boundaries and block dependencies.

    A block boundary is recorded at every token whose adaptive position,
    taken modulo ``block_size``, is smaller than that of the previous token.
    The recorded value is the index of that token, i.e. the exclusive end of
    the block that has just been closed.
    """

    def __init__(self,block_size=64) -> None:
        # One tensor per layer, as returned by adaptive_position: [batch, num_heads, 1, cache_length]
        self.modified_position_cache:list[torch.Tensor] = []
        # One nested dict per layer: {batch: {head_idx: [78, 155, 283, ...]}} (block end positions)
        self.block_cache = []
        self.block_size = block_size
        # One nested dict per layer: {batch: {head_idx: tensor [n, hidden_dim]}} holding the
        # queries of the tokens that come after the last recorded block end
        self.query_cache = []
        # One nested dict per layer: {batch: {head_idx: [None, 0, 1, ...]}}; each list is meant
        # to have the same length as the matching list in block_cache
        self.dependency_cache = []


    def update(self, 
               layer_idx: int,
               query_status,#[batch, head_idx, seq_length,hidden_dim]
               key_cache,#[batch, head_idx, cache_length,hidden_dim]
               neighbor_block_num):
        """Update the cached state of one layer with new queries/keys.

        The first call for a layer is treated as the prefill stage; every
        later call is treated as a decoding step that adds a single token.

        Args:
            layer_idx: index of the layer being updated.
            query_status: ``[batch, num_heads, seq_length, hidden_dim]``.
            key_cache: ``[batch, num_heads, cache_length, hidden_dim]``.
            neighbor_block_num: forwarded to ``dependency_update``.

        Returns:
            ``(modified positions, block ends, block dependencies)`` of the layer.
        """
        if len(self.modified_position_cache) <= layer_idx:
            position_modified_ids = adaptive_position(query_status,key_cache)
            self.modified_position_cache.append(position_modified_ids)
        else:
            position_modified_ids = adaptive_position(query_status, key_cache, self.modified_position_cache[layer_idx])
            self.modified_position_cache[layer_idx] = torch.cat([self.modified_position_cache[layer_idx], position_modified_ids], dim = -1)

        if len(self.block_cache) <= layer_idx:
            self._prefill_layer(position_modified_ids, query_status, key_cache, neighbor_block_num)
        else:
            assert position_modified_ids.shape[2] == 1 , f'在推理阶段假设一次只新增一个token, 而现在新增了{position_modified_ids.shape[2]}个'
            self._decode_layer(layer_idx, query_status, key_cache, neighbor_block_num)
        return self.modified_position_cache[layer_idx], self.block_cache[layer_idx], self.dependency_cache[layer_idx]


    def _prefill_layer(self, positions, query_status, key_cache, neighbor_block_num):
        """Build block ends, dependencies and the query cache of a new layer."""
        phase = positions % self.block_size
        batch_size, num_heads = phase.shape[0:2]
        block_layer = {}
        dependency_layer = {}
        query_cache_layer = {}
        for batch in range(batch_size):
            block_layer[batch] = {}
            dependency_layer[batch] = {}
            query_cache_layer[batch] = {}
            for head_idx in range(num_heads):
                head_phase = phase[batch,head_idx,0,:]
                # A drop of the phase marks the first token of a new block. Its index is
                # stored, which is the exclusive end of the previous block, so that later
                # slicing is uniform. Vectorised form of the former per-token loop.
                wrapped = torch.nonzero(head_phase[1:] < head_phase[:-1]).flatten() + 1
                block_lst = wrapped.tolist()
                block_layer[batch][head_idx] = block_lst

                block_dependency = []
                if block_lst:
                    # Without any block there is nothing to depend on; skipping the call
                    # keeps the dependency list as long as the block list.
                    block_dependency = dependency_update(query_status[batch,head_idx,:,:],
                                                         key_cache[batch,head_idx,:,:],
                                                         block_lst,block_dependency,neighbor_block_num)
                dependency_layer[batch][head_idx] = block_dependency

                # Keep only the queries of the block that is still open.
                open_start = block_lst[-1] if block_lst else 0
                query_cache_layer[batch][head_idx] = query_status[batch,head_idx,open_start:,:]

        self.block_cache.append(block_layer)
        self.dependency_cache.append(dependency_layer)
        self.query_cache.append(query_cache_layer)


    def _decode_layer(self, layer_idx, query_status, key_cache, neighbor_block_num):
        """Account for one newly decoded token in an existing layer."""
        positions = self.modified_position_cache[layer_idx] # already contains the new token
        batch_size, num_heads = positions.shape[0:2]
        cache_length = positions.shape[-1]
        block_layer = self.block_cache[layer_idx]
        dependency_layer = self.dependency_cache[layer_idx]
        query_cache_layer = self.query_cache[layer_idx]

        for batch in range(batch_size):
            for head_idx in range(num_heads):
                # The head axis must be indexed with head_idx (it used to be layer_idx).
                queries = torch.cat([query_cache_layer[batch][head_idx],query_status[batch,head_idx,:,:]],dim=0)
                query_cache_layer[batch][head_idx] = queries

                # Cumulative positions only ever grow, so a block boundary can only be
                # detected on the positions taken modulo block_size, as in the prefill stage.
                newest_phase = positions[batch,head_idx,0,-1] % self.block_size
                previous_phase = positions[batch,head_idx,0,-2] % self.block_size
                if newest_phase < previous_phase:
                    block_lst = block_layer[batch][head_idx]
                    previous_end = block_lst[-1] if block_lst else 0
                    # Same convention as the prefill stage: store the index of the token
                    # that opens the new block.
                    new_end = cache_length - 1
                    block_lst.append(new_end)
                    dependency_layer[batch][head_idx] = dependency_update(query_cache = queries,
                                                                          key_cache = key_cache[batch,head_idx,:,:],
                                                                          block_lst=block_lst,
                                                                          block_dependency=dependency_layer[batch][head_idx],
                                                                          neighbor_block_num=neighbor_block_num,
                                                                          norm = previous_end)
                    # Drop the queries of the block that has just been closed, so the cache
                    # again starts at the last block end and `norm` stays aligned next time.
                    query_cache_layer[batch][head_idx] = queries[new_end - previous_end:]


    def get_status(self,layer_idx):
        """Return ``(modified positions, block ends)`` of a layer, or ``None`` if it has no state yet."""
        if len(self.block_cache) <= layer_idx:
            return None
        else: 
            return self.modified_position_cache[layer_idx], self.block_cache[layer_idx]
        



In [13]:
# Smoke test: prefill a small random cache for a few layers, then run the
# key/value selection and display the values selected for batch 0, head 0.
block_cache = Block_Cache(3)
layers = 3
query_status = torch.randn((2,3,30,3))
key_cache = torch.randn((2,3,30,3))
# The value cache must be as long as the key cache (it used to be 20 entries
# long, so value slices were silently truncated and no longer matched the keys).
value_cache = torch.randn((2,3,30,3))
position_modified_ids = adaptive_position(query_status,key_cache)
for layer_idx in range(layers):
    a = block_cache.update(layer_idx,query_status,key_cache,2)
# Inspect the last layer explicitly instead of relying on the loop variable
# that leaked out of the loop above.
last_layer_idx = layers - 1
query_aware(
    query_status=query_status,
    key_cache=key_cache,
    position_ids=torch.arange(30).reshape(1,-1).long(),
    position_modified_ids= position_modified_ids,
    value_cache=value_cache,
    #layer_idx=0,
    block_status = block_cache.block_cache[last_layer_idx],
    block_dependency = block_cache.dependency_cache[last_layer_idx],
    max_length = 100,
    skip_num = 2,
    find_num = 5,
    position_method='origin'
    
)[1][0][0]

[tensor([[ 0.0759,  0.1761, -0.5493]]),
 tensor([[ 0.0759,  0.1761, -0.5493],
         [-0.6124,  0.5323, -1.2284]]),
 tensor([[ 0.0759,  0.1761, -0.5493],
         [-0.6124,  0.5323, -1.2284],
         [ 0.0125,  0.4428, -0.8751]]),
 tensor([[ 0.0759,  0.1761, -0.5493],
         [-0.6124,  0.5323, -1.2284],
         [ 0.0125,  0.4428, -0.8751],
         [-0.5447,  0.2373,  1.7358]]),
 tensor([[ 0.0759,  0.1761, -0.5493],
         [-0.6124,  0.5323, -1.2284],
         [ 0.0125,  0.4428, -0.8751],
         [-0.5447,  0.2373,  1.7358],
         [-1.4266,  0.4551, -0.0601]]),
 tensor([[ 0.0759,  0.1761, -0.5493],
         [-0.6124,  0.5323, -1.2284],
         [ 0.0125,  0.4428, -0.8751],
         [-0.5447,  0.2373,  1.7358],
         [-1.4266,  0.4551, -0.0601],
         [-0.0516, -0.3702, -0.5804]]),
 tensor([[ 0.0759,  0.1761, -0.5493],
         [-0.6124,  0.5323, -1.2284],
         [ 0.0125,  0.4428, -0.8751],
         [-0.5447,  0.2373,  1.7358],
         [-1.4266,  0.4551, -0.0601],


In [15]:
# Display the cached adaptive positions; the list holds one tensor per layer updated so far.
block_cache.modified_position_cache

[tensor([[[[ 0.0000,  0.7484,  1.0772,  1.4092,  1.6217,  1.8003,  1.9236,
             2.1574,  2.5313,  3.5135,  3.6902,  4.4828,  4.9626,  5.9117,
             6.8460,  6.9550,  7.7628,  7.9340,  8.2262,  8.9921,  9.4382,
             9.9781, 10.7287, 11.5100, 11.8646, 12.0024, 12.7541, 12.8277,
            13.0167, 13.1253]],
 
          [[ 0.0000,  0.2425,  1.0857,  1.5335,  2.2735,  3.0268,  3.6878,
             3.8046,  4.6013,  5.2623,  5.7119,  5.9063,  6.2766,  6.4702,
             6.8770,  7.0581,  7.1913,  7.2035,  7.7763,  8.7611,  8.9160,
             9.7278,  9.9662, 10.5574, 10.8194, 11.3474, 11.4863, 11.5088,
            12.3580, 12.9850]],
 
          [[ 0.0000,  0.4564,  1.4409,  2.0489,  2.9918,  3.3788,  4.0007,
             4.5170,  4.6228,  5.3165,  5.8892,  6.7288,  7.1764,  7.7302,
             8.5494,  9.2882,  9.9641, 10.3002, 11.1738, 11.6385, 12.1458,
            12.4396, 12.7708, 13.6429, 13.6493, 13.9483, 14.4054, 15.0909,
            15.4665, 16.1269]]],

In [None]:
import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutputWithPast,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from .configuration_llama import LlamaConfig

In [None]:
def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None, # expected shape: [1, seq_length]
        past_key_value: Optional[Cache] = None,
        block_cache: Optional[Block_Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_method = 'origin',
        neighbor_block_num = 5,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Attention forward pass extended with block-cache based key selection (work in progress).

        Projects the hidden states to queries/keys/values, updates the block
        cache, runs ``query_aware`` to select keys/values per query, and then
        performs standard scaled dot-product attention.

        NOTE(review): the tensors returned by ``query_aware`` are not used by
        the attention computation below; the attention still runs over the
        full key/value states. This looks unfinished -- confirm the intended
        wiring before relying on this method.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``; ``attn_weights``
            is ``None`` unless ``output_attentions`` is true.
        """
        batch_size, seq_length, hidden_dim = hidden_states.size()
        

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        # split the projections into heads: [batch, heads, seq_length, head_dim]
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)


        # an attribute stored on the module takes precedence over the argument
        block_cache = getattr(self,"block_cache",block_cache)
        past_key_value = getattr(self,"past_key_value", past_key_value)
        

        # NOTE(review): the KV cache is updated here (before rotary embedding) and again further
        # below (after rotary embedding); this looks like the same step is appended twice -- confirm
        if past_key_value is not None:
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
        
        modified_position_ids, block_states,block_dependency = block_cache.update(self.layer_idx,
                                                                                  query_status = query_states,
                                                                                  key_cache = key_states,
                                                                                  neighbor_block_num = neighbor_block_num)
        

        
        # per-query selection of keys, values and positions
        keys_to_be_used, values_to_be_used,position_to_be_used = query_aware(
            query_status = query_states,
            key_cache = key_states,
            value_cache = value_states,
            position_ids = position_ids,
            position_modified_ids = modified_position_ids,
            block_status = block_states,
            block_dependency = block_dependency,
            max_length = self.max_position_embeddings,
            skip_num = neighbor_block_num,
            find_num = neighbor_block_num,
            position_method = position_method,
            block_size = block_cache.block_size
        )

        # NOTE(review): cos/sin are computed again a few lines below with the same arguments
        cos, sin = self.rotary_emb(value_states, position_ids)
        
        # next step (TODO): pick out the important blocks for each query


       
        # NOTE(review): value_states was already reshaped and transposed above (and possibly
        # extended by the cache update); this second view looks like a leftover -- confirm
        value_states = value_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        past_key_value = getattr(self, "past_key_value", past_key_value)
        cos, sin = self.rotary_emb(value_states, position_ids)
        # apply_rotary_pos_emb and repeat_kv are not defined in the visible part of this file;
        # presumably they come from the surrounding modeling module -- verify
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # scaled dot-product attention scores
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask
            if cache_position is not None:
                causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (batch_size, self.num_heads, seq_length, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(batch_size, self.num_heads, seq_length, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # merge the heads back: [batch, seq_length, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)

        if self.config.pretraining_tp > 1:
            # output projection computed slice by slice and summed
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value