In [None]:
import torch
from torch import nn
import numpy as np
import random
from matplotlib import pyplot as plt
from typing import Any, Dict, List, Optional, Tuple
from transformers.cache_utils import Cache

In [None]:
def distance_fun(value):
    """Turn a raw query-key score into a step along the adaptive position axis.

    Tentatively the logistic sigmoid, applied elementwise, so each step lies
    in [0, 1].
    """
    step = torch.sigmoid(value)
    return step
    
    
def adaptive_position( query:torch.Tensor, key:torch.Tensor,position_modified_ids=None, distance_fn=None)->torch.Tensor:
    '''Compute cumulative "adaptive" positions from consecutive query/key scores.

    Each token's step is ``distance_fn(<query_t, key_{t-1}>)``: the score of a
    query against the key of the token just before it. Positions are the running
    sum of those steps, and the first token is pinned to position 0.

    Expected shapes:
        key:   [batch, num_heads, seq_length, hidden_dim]
        query: [batch, num_heads, seq_length, hidden_dim] in the prefill stage,
               [batch, num_heads, 1, hidden_dim] in the decode stage.

    Args:
        query, key: see shapes above.
        position_modified_ids: ``None`` in the prefill stage. In the decode stage,
            the previously returned positions ([batch, num_heads, 1, L]); only
            the last one is used.
        distance_fn: elementwise map from score to step. Defaults to
            ``distance_fun`` (currently sigmoid).

    Returns:
        Prefill: [batch, num_heads, 1, seq_length] cumulative positions.
        Decode:  [batch, num_heads, 1, 1] position of the new token.
    '''
    if distance_fn is None:
        distance_fn = distance_fun
    batch, num_heads = query.shape[0], query.shape[1]
    if position_modified_ids is None:
        # Prefill: pair query t with key t-1, so drop the first query and the last key.
        next_queries = query[:, :, 1:, :].unsqueeze(-2)  # [B, H, S-1, 1, D]
        prev_keys = key[:, :, :-1, :].unsqueeze(-1)      # [B, H, S-1, D, 1]
        scores = torch.matmul(next_queries, prev_keys).squeeze(-1)  # [B, H, S-1, 1]
        distance = distance_fn(scores).transpose(-1, -2)  # [B, H, 1, S-1]

        # The first token has no predecessor; force its step to 0.
        # Allocate on the same device as the scores so torch.cat works on GPU too.
        distance_0 = torch.zeros((batch, num_heads, 1, 1), dtype=distance.dtype, device=distance.device)
        distance = torch.cat((distance_0, distance), dim=-1)  # [B, H, 1, S]
        return distance.cumsum(dim=-1)

    # Decode: score the new query against the second-to-last key, i.e. the token
    # before the new one (this assumes key already includes the new token — TODO confirm).
    prev_key = key[:, :, -2, :].unsqueeze(-1)  # [B, H, D, 1]
    new_distance = distance_fn(torch.matmul(query, prev_key))  # [B, H, 1, 1]
    return new_distance + position_modified_ids[:, :, :, -1:]
        
def query_select(query_status, block_ends, key_cache, block_lst:list):
    """Pick the candidate block whose keys score highest against one query.

    Candidate blocks are the spans between consecutive entries of
    ``block_ends``; the first entry is only the start of the first candidate.
    Each block is scored by the mean of ``query_status @ keys.T``.

    Args:
        query_status: query vector, [hidden_dim] or [1, hidden_dim].
        block_ends: list[int] of boundaries; needs at least two entries.
        key_cache: [cache_length, hidden_dim] keys for one batch/head.
        block_lst: full list of block boundaries; the chosen block's end must
            appear in it.

    Returns:
        Index in ``block_lst`` of the winning block's end boundary.

    Raises:
        ValueError: if ``block_ends`` describes no candidate block.
    """
    if len(block_ends) < 2:
        raise ValueError("query_select needs at least two entries in block_ends (one candidate block)")
    scores = [
        torch.mean(query_status @ key_cache[start:end, :].T)
        for start, end in zip(block_ends[:-1], block_ends[1:])
    ]
    # +1 because block k ends at block_ends[k + 1].
    max_id = int(torch.argmax(torch.stack(scores))) + 1
    return block_lst.index(block_ends[max_id])

def get_block_score(query_status,key_cache,obj_block, end_block):
    """Mean attention logit of one block's queries against another block's keys.

    Args:
        query_status: query matrix indexed as [position, hidden_dim].
        key_cache: key matrix indexed as [position, hidden_dim].
        obj_block: ``[start, end)`` span of the keys being attended to.
        end_block: ``[start, end)`` span of the queries doing the attending.

    Returns:
        0-d tensor: the mean over all entries of ``queries @ keys.T``.
    """
    obj_start, obj_stop = obj_block[0], obj_block[1]
    end_start, end_stop = end_block[0], end_block[1]
    keys = key_cache[obj_start:obj_stop, :]           # [obj_length, hidden_dim]
    queries = query_status[end_start:end_stop, :]     # [end_length, hidden_dim]
    logits = queries @ keys.T
    return logits.mean()


def block_select(query_status,key_cache,block_lst,neighbor_block_num):
    """Pick the neighbouring block that the last block attends to most.

    Blocks are the spans between consecutive boundaries in ``block_lst``:
    block 0 is ``[0, block_lst[0])`` and block ``b`` is
    ``[block_lst[b-1], block_lst[b])``. The "last block" is the final span.
    The candidates are the up-to ``neighbor_block_num`` blocks immediately
    before it, or all preceding blocks when fewer exist. Each candidate is
    scored with ``get_block_score``, using the last block's rows of
    ``query_status`` as queries and the candidate's rows of ``key_cache`` as keys.

    NOTE(review): the unfinished original iterated over every block including
    the last one. This completion excludes the last block itself; confirm intent.

    Args:
        query_status: [seq_length, hidden_dim] queries for one batch/head.
        key_cache: [cache_length, hidden_dim] keys for one batch/head.
        block_lst: list[int] of block end boundaries (exclusive).
        neighbor_block_num: how many preceding blocks to consider.

    Returns:
        Index into ``block_lst`` of the highest-scoring candidate block.

    Raises:
        ValueError: if there is no candidate block.
    """
    num_blocks = len(block_lst)
    first_candidate = max(0, num_blocks - 1 - neighbor_block_num)
    candidates = range(first_candidate, num_blocks - 1)
    if len(candidates) == 0:
        raise ValueError("block_select needs at least one block before the last block")

    starts = [0] + list(block_lst[:-1])
    last_block = [starts[-1], block_lst[-1]]
    scores = [
        get_block_score(query_status, key_cache, [starts[b], block_lst[b]], last_block)
        for b in candidates
    ]
    best = int(torch.argmax(torch.stack(scores)))
    return candidates[best]


def block_trace(block_dependency,
                select_block_id,
                left_length,block_lst):
    """Follow a block-dependency chain and collect the token spans along it.

    Starting from ``select_block_id``, repeatedly jump to
    ``block_dependency[select_block_id]`` and prepend that block's
    ``[start, end)`` span. The starting block itself is not included. The walk
    stops when the current block's dependency is ``None`` or ``left_length``
    is no longer positive. The budget is checked before each step, so the last
    span added may overshoot it.

    Args:
        block_dependency: list for a single batch/head, mapping block id to the
            id it depends on (``None`` ends the chain).
        select_block_id: id of the block to start from.
        left_length: token budget remaining after the sink and neighbor blocks.
        block_lst: list[int] of block end boundaries (exclusive).

    Returns:
        list of ``[start, end]`` spans, ordered from earliest to latest in the chain.
    """
    select_keys_position = []
    while block_dependency[select_block_id] is not None and left_length > 0:
        select_block_id = block_dependency[select_block_id]
        end = block_lst[select_block_id]
        # Block 0 starts at token 0; block_lst[-1] would wrap to the last boundary.
        start = block_lst[select_block_id - 1] if select_block_id > 0 else 0
        select_keys_position.insert(0, [start, end])
        left_length -= end - start
    return select_keys_position
    

def query_aware(query:torch.Tensor, #[batch,num_heads, seq_length,hidden_dim]
                key_cache, #[batch, num_heads,cache_length, hidden_dim]
                position_ids,
                block_status,
                block_dependency,
                max_length,
                skip_num=5, # really the number of "neighbor" blocks that are always kept
                find_num = 5,
                sink_num = 1):
    """Decide, per query token, which cached keys it should attend to.

    The intended key layout per query is [sink_block] + [select_block] +
    [neighbor (skip) blocks].

    Args:
        query: [batch, num_heads, seq_length, hidden_dim].
        key_cache: [batch, num_heads, cache_length, hidden_dim].
        position_ids: [1, seq_length] absolute positions of the queries.
        block_status: {batch: {head_idx: [block_end, ...]}} with sorted,
            exclusive block end boundaries.
        block_dependency, max_length, sink_num: not used yet.
        skip_num: number of most recent blocks always kept.
        find_num: number of earlier blocks searched by ``query_select``.

    Returns:
        {batch: {head_idx: [entry per query]}}, where each list has length
        seq_length. If the query has at most ``skip_num`` blocks before it, the
        entry is the key slice ``key_cache[batch, head_idx, :query_position, :]``.
        Otherwise it is, for now, the ``block_lst`` index chosen by
        ``query_select``. TODO: assembling the sink, selected and neighbor keys
        is not implemented yet.
    """
    batch_size, head_num = query.shape[:2]
    position_min = int(position_ids[0, 0])
    seq_length = position_ids.shape[-1]
    keys_to_be_used = {}
    for batch in range(batch_size):
        keys_to_be_used[batch] = {}
        for head_idx in range(head_num):
            per_query = []
            keys_to_be_used[batch][head_idx] = per_query
            block_lst = block_status[batch][head_idx]
            element_num = len(block_lst)
            i = 0  # number of block ends strictly before the current query; monotone across queries
            for delta in range(seq_length):
                query_position = position_min + delta
                # Bounds check first: indexing block_lst[element_num] would raise IndexError.
                while i < element_num and block_lst[i] < query_position:
                    i += 1
                if i <= skip_num:
                    # Every block before this query is a neighbor block: keep all earlier keys.
                    per_query.append(key_cache[batch, head_idx, :query_position, :])
                    continue

                end = i - skip_num
                if end > find_num:
                    # find_num + 1 boundaries; the first is the end of the preceding block.
                    start = end - find_num - 1
                    block_ends = block_lst[start:end]
                else:
                    # Not enough earlier blocks: search all of them, starting from token 0.
                    block_ends = [0] + list(block_lst[:end])
                # query's sequence axis is relative to this call, so index by delta,
                # not by the absolute position.
                query_status = query[batch, head_idx, delta, :]
                selected_block_id = query_select(query_status, block_ends, key_cache[batch, head_idx, :, :], block_lst)
                per_query.append(selected_block_id)
    return keys_to_be_used

In [None]:
class Block_Cache():
    """Per-layer cache of adaptive positions and the block boundaries derived from them.

    Positions are stored modulo ``block_size``. A new block starts wherever the
    stored position is lower than its predecessor's. Each boundary recorded is
    the index of the first token of the new block, which is also the exclusive
    end of the previous block.
    """

    def __init__(self,block_size=256) -> None:
        # One entry per layer: tensor [batch, num_heads, 1, cached_length] of positions mod block_size.
        self.position_cache:list[torch.Tensor] = []
        # One entry per layer: {batch: {head_idx: [boundary, ...]}}.
        self.block_cache = []
        self.block_size = block_size
        # One entry per layer: {batch: {head_idx: [...]}}. The lists are created
        # empty and not filled anywhere in this class yet.
        self.block_dependency = []

    def get_block_token_ids(self,query_ids:torch.Tensor, block_num = 5):
        """Unfinished: always returns None.

        Intended input is [batch, num_heads, 1, query_length] and intended output
        [batch, num_heads, query_length, block_num]. TODO: implement.
        """
        return None

    def update(self,
               position_modified_ids: torch.Tensor,
               layer_idx: int,):
        """Append new positions for ``layer_idx`` and update its block boundaries.

        The first call for a layer (prefill) may carry any number of tokens.
        Later calls (decode) must add exactly one token.

        Args:
            position_modified_ids: [batch, num_heads, 1, new_length] positions.
            layer_idx: layer whose cache to update.

        Returns:
            (position_cache[layer_idx], block_cache[layer_idx]).
        """
        position_modified_ids = position_modified_ids % self.block_size
        new_length = position_modified_ids.shape[-1]
        is_prefill = len(self.block_cache) <= layer_idx
        # Validate before touching the cache so a rejected call leaves it unchanged.
        if not is_prefill:
            assert new_length == 1 , f'在推理阶段假设一次只新增一个token, 而现在新增了{new_length}个'

        if len(self.position_cache) <= layer_idx:
            self.position_cache.append(position_modified_ids)
        else:
            self.position_cache[layer_idx] = torch.cat([self.position_cache[layer_idx], position_modified_ids], dim = -1)

        batch_size, num_heads = position_modified_ids.shape[0:2]
        total_length = self.position_cache[layer_idx].shape[-1]

        if is_prefill:
            # Token i starts a new block when its position wrapped below token i-1's.
            drops = position_modified_ids[:, :, 0, 1:] < position_modified_ids[:, :, 0, :-1]
            block_record = {}
            block_dependency = {}
            for batch in range(batch_size):
                block_record[batch] = {}
                block_dependency[batch] = {}
                for head in range(num_heads):
                    block_record[batch][head] = (torch.nonzero(drops[batch, head]).flatten() + 1).tolist()
                    block_dependency[batch][head] = []
            self.block_cache.append(block_record)
            self.block_dependency.append(block_dependency)
        else:
            block_record = self.block_cache[layer_idx]
            position_ids = self.position_cache[layer_idx]
            wrapped = position_ids[:, :, 0, -1] < position_ids[:, :, 0, -2]
            for batch in range(batch_size):
                for head_idx in range(num_heads):
                    if wrapped[batch, head_idx]:
                        # The new token sits at index total_length - 1 and starts a new block.
                        block_record[batch][head_idx].append(total_length - 1)
        return self.position_cache[layer_idx], self.block_cache[layer_idx]

    def get_status(self,layer_idx):
        """Return (positions, blocks) for ``layer_idx``, or None if the layer has no blocks yet."""
        if len(self.block_cache) <= layer_idx:
            return None
        return self.position_cache[layer_idx], self.block_cache[layer_idx]