In [1]:
import torch
import torch.nn as nn
from transformers import LlamaForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import rotate_half, repeat_kv
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import math

  return torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count


目前需要完成的事情有两件: 

1. 验证对于两个 token1 和 token2, 若 query(token1) $\sim$ key(token2), 则 key(token1) $\sim$ key(token2). 进一步我希望可以有: 对任意 query, 若 query(token1) $\sim$ key(token2), 且 query $\sim$ key(token1), 那么有 query $\sim$ key(token2), 并且这个二元关系可以远程传递而不衰减.


2. block 分块的超参数大小

In [None]:
# Local Hugging Face snapshot of Llama-2-7b-hf (cluster-specific absolute path).
llama_dir = '/mntcephfs/data/ruoyusun/liziniu/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/8a0442e81540efaeb1a0fe3e95477b5e0edfd423'
# The "eager" attention implementation is requested explicitly.
llama = LlamaForCausalLM.from_pretrained(llama_dir, attn_implementation="eager")
# Largest sequence length covered by the pre-built causal mask below.
max_positions = 4096
# Boolean lower-triangular mask of shape (1, 1, max_positions, max_positions):
# entry [.., i, j] is True exactly when j <= i.
attn_bias = torch.ones(max_positions, max_positions, dtype=torch.bool).tril()[None, None]
def attention_score_wo_rotary(layer_idx, hidden_states, num_heads=32, head_dim = 128):
    """Causally-masked, pre-softmax attention scores of one layer, without RoPE.

    The input is passed through the layer's ``input_layernorm`` and its
    ``q_proj`` / ``k_proj`` projections; no rotary position embedding is
    applied, so the scores reflect query/key content similarity only.

    Args:
        layer_idx: Index into ``llama.model.layers`` selecting the decoder layer.
        hidden_states: Tensor of shape ``(bsz, q_len, hidden)`` fed to the layer.
        num_heads: Number of heads the projections are split into. The key
            projection is reshaped with the same count, so the model must not
            use fewer key/value heads than query heads.
        head_dim: Per-head dimension; also used for the ``1/sqrt(head_dim)`` scale.

    Returns:
        Tensor of shape ``(num_heads, q_len, q_len)`` for the FIRST batch
        element only. Entries above the diagonal (future keys) are set to the
        most negative finite value of the score dtype.
    """
    bsz, q_len, _ = hidden_states.size()

    decoder_layer = llama.model.layers[layer_idx]
    self_attn = decoder_layer.self_attn

    normed_states = decoder_layer.input_layernorm(hidden_states)

    # Only queries and keys are needed for the scores; the value projection
    # is intentionally not computed.
    query_states = self_attn.q_proj(normed_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    key_states = self_attn.k_proj(normed_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    # Build the causal mask for the actual sequence length on the scores'
    # device, so the function works for any q_len and any device placement.
    causal_mask = torch.ones(q_len, q_len, dtype=torch.bool, device=attn_weights.device).tril()
    mask_value = torch.finfo(attn_weights.dtype).min
    attn_weights = attn_weights.masked_fill(~causal_mask, mask_value)
    return attn_weights[0]

# Load the tokenizer from the same local snapshot directory used for the model weights.
tokenizer = AutoTokenizer.from_pretrained(llama_dir)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

#### 在推理阶段, query位置增加的速度要改变, 就像cope中keys距离增加的速度改变那样.
比如在序列[a,b,c,d]中, 他们的位置是[0,1,2,3], 这个时候新来一个e, 则他的位置是4.


现在对这个形式进行改变, 还是[a,b,c,d]这个序列, 我们在计算position的时候, 考虑一个映射$f:\mathbb{R}\to (0,1)$, 用pos(a,b)表示token a,b之间的距离, 则定义:
$$
pos(a,b) = f(q_b^T\cdot k_a),
$$
则他们之间的距离可以用序列 $[pos(a,b), pos(b,c), pos(c,d)]$ 来表示. 那么此时计算position变为:

$$
[0,0+pos(a,b), pos(a,b)+pos(b,c), pos(a,b)+pos(b,c)+pos(c,d)]
$$
这些距离可以是 $(0,1)$ 内的任意数, 比如距离 $[0.3,0.7,0.5]$ 对应的位置为 $[0,0.3, 1.0, 1.5]$. 这样每当新添加一个 token e 时, 只需要额外计算一次 $q_e^T k_d$ 就能得到它的位置.

In [3]:
class LlamaRotaryEmbedding(nn.Module):
    """Rotary position embedding tables (cos / sin) for a given head dimension.

    ``forward`` computes the angles directly from ``position_ids`` (cast to
    float32), so positions are not restricted to the cached integer range.
    ``scaling_factor`` is applied only to the cached tables built in
    ``__init__``; ``forward`` does not use it.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        super().__init__()
        self.scaling_factor = scaling_factor
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Inverse frequencies: base ** (-2i / dim) for i = 0 .. dim/2 - 1.
        exponents = torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim
        inv_freq = 1.0 / (self.base ** exponents)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Cached cos/sin tables, kept for backward compatibility.
        self.max_seq_len_cached = max_position_embeddings
        positions = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        positions = positions / self.scaling_factor
        angles = torch.outer(positions, self.inv_freq)
        # The two halves are duplicated (rather than interleaved as in the
        # paper); this permutation yields the same computation.
        table = torch.cat((angles, angles), dim=-1)
        default_dtype = torch.get_default_dtype()
        self.register_buffer("_cos_cached", table.cos().to(default_dtype), persistent=False)
        self.register_buffer("_sin_cached", table.sin().to(default_dtype), persistent=False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return ``(cos, sin)`` of shape ``(batch, seq_len, dim)`` in ``x``'s dtype.

        ``x`` is used only for its device type and dtype; ``position_ids`` has
        shape ``(batch, seq_len)``.
        """
        batch = position_ids.shape[0]
        inv_freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1)
        pos_row = position_ids[:, None, :].float()

        # Angles are computed in float32 with autocast disabled, because
        # bfloat16 loses precision on long contexts.
        # See https://github.com/huggingface/transformers/pull/29285
        device_type = x.device.type
        if not isinstance(device_type, str):
            device_type = "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            angles = torch.matmul(inv_freq_col.float(), pos_row.float()).transpose(1, 2)
            table = torch.cat((angles, angles), dim=-1)
            cos, sin = table.cos(), table.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
        
def position_update(self, query, key):
    