In [1]:
import torch

# 假设张量已经定义
k_states_reloaded = torch.rand(1, 4, 1620, 128)
k_activation_part1 = torch.rand(1, 4, 14, 128)
k_activation_part2 = torch.rand(1, 4, 202, 128)

# 按指定顺序在 dim=2 上拼接
result = torch.cat((k_activation_part1, k_states_reloaded, k_activation_part2), dim=2)

# 检查拼接后的形状
print(result.shape)


torch.Size([1, 4, 1836, 128])


In [8]:
import torch
import time

# 测试参数
batch_size, seq_len, embed_dim = 16, 2512, 4096  # 修改序列长度和维度测试
device = "cuda" if torch.cuda.is_available() else "cpu"

query = torch.randn(batch_size, seq_len, embed_dim, device=device)
key = torch.randn(batch_size, seq_len, embed_dim, device=device)
value = torch.randn(batch_size, seq_len, embed_dim, device=device)
attn_mask = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()

# 手动实现
def manual_scaled_dot_product_attention(query, key, value, attn_mask=None):
    scale_factor = 1 / query.size(-1) ** 0.5
    attn_weights = query @ key.transpose(-2, -1) * scale_factor
    if attn_mask is not None:
        attn_weights = attn_weights.masked_fill(attn_mask.logical_not(), float("-inf"))
    attn_weights = torch.softmax(attn_weights, dim=-1)
    return attn_weights @ value

# 测试时间
def benchmark(func, *args):
    torch.cuda.synchronize()
    start = time.time()
    func(*args)
    torch.cuda.synchronize()
    return time.time() - start

# 手动实现时间
manual_time = benchmark(manual_scaled_dot_product_attention, query, key, value, attn_mask)
print(f"Manual time: {manual_time:.4f} seconds")

# PyTorch SPA 时间
spa_time = benchmark(
    torch.nn.functional.scaled_dot_product_attention,
    query, key, value, attn_mask
)
print(f"SPA time: {spa_time:.4f} seconds")

# 计算加速比
print(f"Speedup: {manual_time / spa_time:.2f}×")


Manual time: 0.1141 seconds
SPA time: 0.1138 seconds
Speedup: 1.00×


In [25]:
def scaled_dot_product_attention_for_reload_step2(attn_logits, chunk_infos, top_k=3):
    # 将每个视觉token对于文本token的attention进行求和
    print(f'attn_logits: {attn_logits}')
    print(f'attn_logits.shape : {attn_logits.shape}')
    attn_logits_sum = attn_logits.sum(dim=-2) # (batch_size, nheads, seqlen)
    print(f'attn_logits_sum: {attn_logits_sum}')
    print(f'attn_logits_sum.shape : {attn_logits_sum.shape}')
    attn_logits_avg = attn_logits_sum.mean(dim=1) # 将每个 head 求平均 (batch_size, seqlen)
    print(f'attn_logits_avg: {attn_logits_avg}')
    print(f'attn_logits_avg.shape : {attn_logits_avg.shape}')

    batch_chunk_avg = []

    # 对 batch 维度逐一处理
    for batch in attn_logits_avg:
        # 保存当前 batch 的 chunk 平均值
        chunk_avg = []
        print(f'batch.shape: {batch.shape}')
        for chunk_idx, start, end in chunk_infos:
            print(f'start: {start}, end: {end}')
            # 对当前 chunk 的元素取均值
            chunk_score = batch[start: end].mean()  
            chunk_avg.append(chunk_score)
        batch_chunk_avg.append(chunk_avg)

    # 转换为张量，形状为 (batch_size, n_chunks)
    batch_chunk_avg = torch.tensor(batch_chunk_avg)
    print(f'batch_chunk_avg: {batch_chunk_avg}')
    print(f'batch_chunk_avg.shape: {batch_chunk_avg.shape}')

    # chunk_infos 的数量 小于 k，全部纳入
    if batch_chunk_avg.shape[-1] < top_k:
        topk_values = batch_chunk_avg
        topk_indices = torch.arange(batch_chunk_avg.shape[-1])
        return topk_values, topk_indices

    topk_values, topk_indices = torch.topk(batch_chunk_avg, k=top_k, dim=1)
    return topk_values, topk_indices[0]

import torch

attn_logits = torch.rand(1,4,8,256)
chunk_infos = [(1, 16,32),(2, 64,128), (3, 129,139),(4,200, 210)]

topk_values, topk_indices = scaled_dot_product_attention_for_reload_step2(attn_logits, chunk_infos)

attn_logits: tensor([[[[4.6599e-01, 9.3160e-01, 6.0031e-01,  ..., 8.3951e-01,
           3.9051e-01, 3.6991e-01],
          [9.1105e-01, 9.6870e-01, 1.9888e-01,  ..., 4.7841e-01,
           7.0689e-01, 8.2734e-01],
          [9.7203e-01, 9.3085e-03, 9.1803e-01,  ..., 2.8146e-01,
           9.7307e-01, 3.3683e-01],
          ...,
          [5.2449e-01, 1.8395e-01, 5.4980e-01,  ..., 8.5767e-01,
           6.9943e-01, 6.4617e-01],
          [9.8120e-01, 1.6080e-01, 2.5963e-01,  ..., 9.2691e-01,
           6.9573e-01, 3.1241e-01],
          [6.1523e-01, 6.5969e-01, 1.0094e-01,  ..., 7.5538e-01,
           2.7946e-01, 3.8473e-01]],

         [[7.6536e-01, 7.7824e-01, 9.7660e-01,  ..., 3.1278e-01,
           8.5102e-01, 6.7716e-01],
          [2.7048e-01, 9.2083e-01, 7.0344e-01,  ..., 4.6432e-01,
           9.5548e-01, 3.5643e-01],
          [1.9052e-01, 5.7030e-04, 1.8468e-01,  ..., 4.0759e-01,
           2.8387e-01, 8.3967e-01],
          ...,
          [2.6715e-01, 6.0837e-01, 1.7380e-01,

In [26]:
for indice in topk_indices:
    chunk_idx_selected = indice 
    print(f'随机选中的 chunk: {chunk_idx_selected}')
    chunk_info = chunk_infos[chunk_idx_selected]
    print(chunk_info)

随机选中的 chunk: 1
(2, 64, 128)
随机选中的 chunk: 0
(1, 16, 32)
随机选中的 chunk: 3
(4, 200, 210)


In [24]:
topk_indices

tensor([[1, 3, 0]])

In [11]:
import torch
tgt_size = 4
mem_size = 2
batch_size = 1
src_size = 6
causal_mask = torch.full((tgt_size, tgt_size), False, dtype=torch.bool)

mask_cond = torch.arange(causal_mask.size(-1))
causal_mask.masked_fill_(mask_cond < (mask_cond + 1).view(causal_mask.size(-1), -1), True)

causal_mask = torch.cat([torch.ones(tgt_size, mem_size, dtype=torch.bool), causal_mask], dim=-1)

causal_mask = causal_mask[None, None, ...].expand(batch_size, 1, tgt_size, src_size)
causal_mask

tensor([[[[ True,  True,  True, False, False, False],
          [ True,  True,  True,  True, False, False],
          [ True,  True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True,  True]]]])

In [12]:
from loguru import logger as eval_logger
import sys

# 清空默认的 logger 配置
eval_logger.remove()

# 添加一个新的 handler，指定输出到控制台，级别为 INFO
eval_logger.add(sys.stderr, level="INFO")

# 测试日志
eval_logger.debug("This is a DEBUG message")  # 不会显示
eval_logger.info("This is an INFO message")   # 会显示
eval_logger.error("This is an ERROR message")  # 会显示


[32m2024-11-28 10:30:59.981[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m12[0m - [1mThis is an INFO message[0m
[32m2024-11-28 10:30:59.982[0m | [31m[1mERROR   [0m | [36m__main__[0m:[36m<module>[0m:[36m13[0m - [31m[1mThis is an ERROR message[0m


In [11]:
eval_logger

<loguru.logger handlers=[(id=5, level=20, sink='stderr')]>

In [7]:



# L = 4
# S = 8
# res = torch.zeros(1,1,L, S, dtype=torch.float32)

tensor([[[[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]],

         [[True, True, True, True],
          [True, True, True, True],
          [True, True, True, True],
          [True, True, True, True]]]])

In [None]:
import torch


attn_mask = torch.zeros((1,3,2,4),dtype=torch.bool)
attn_bias = torch.zeros(attn_mask.shape, dtype=torch.float32)
attn_bias = attn_bias.to(attn_mask.device)


if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
        attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
    else:
        attn_bias += attn_mask
# attn_weight = query @ key.transpose(-2, -1) * scale_factor
    
# attn_weight += attn_bias.to(query.device)
# attn_weight = torch.softmax(attn_weight, dim=-1)
# attn_logits = attn_weight
attn_bias

In [1]:
import torch

test = [-14464.6035]

torch.tensor(test, dtype=torch.float16)

tensor([-14464.], dtype=torch.float16)

In [5]:
import random

effective_chunk_infos = [1,2,3]

bsz = 2

range(len(effective_chunk_infos))

topk_indices = random.sample(range(len(effective_chunk_infos)), 2)
topk_indices

[0, 2]

In [7]:
str(list(range(0,28)))

'[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]'

In [None]:
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]

if layer_idx not in [0, 1, 2, 14, 15, 16, 25, 26, 27]:
    continue
