In [4]:
import torch
import torch.nn as nn
class CrossAttention(nn.Module):
    def __init__(self, input_dim):
        super(CrossAttention, self).__init__()
        self.input_dim = input_dim
        self.linear = nn.Linear(input_dim, input_dim)
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, query, key, value):
        # query: (batch_size, seq_len_q, input_dim)
        # key: (batch_size, seq_len_k, input_dim)
        # value: (batch_size, seq_len_v, input_dim)

        # 计算注意力权重
        scores = torch.matmul(query, key.transpose(1,2))    # (batch_size, seq_len_q, seq_len_k)
        attn_weights = self.softmax(scores)                 # (batch_size, seq_len_q, seq_len_k)

        # 使用注意力权重加权求和
        weighted_values = torch.matmul(attn_weights, value) # (batch_size, seq_len_q, input_dim)

        # 应用线性变换
        output = self.linear(weighted_values)               # (batch_size, seq_len_q, input_dim)

        return output

input_dim = 128

cross_attention = CrossAttention(input_dim)
batch_size = 10
seq_len_q = 5
seq_len_k = 7
seq_len_v = 7

query = torch.randn(batch_size, seq_len_q, input_dim)
key = torch.randn(batch_size, seq_len_k, input_dim)
value = torch.randn(batch_size, seq_len_v, input_dim)


# 最后的输出长度大小会和query保持一致
output = cross_attention(query, key, value)
output.shape


torch.Size([10, 5, 128])

In [8]:
import numpy as np
def quantize_to_8int(matrix):
    """"简化的量化函数，将矩阵元素量化为8位整数"""  
    min_val = matrix.min()
    max_val = matrix.max()
    scale = 255 / (max_val - min_val)
    quantized = np.round(scale * (matrix - min_val)).astype(np.uint8)
    return quantized
def matrix_multiply(A, B):
    """矩阵乘法, 使用8位整数量化"""
    A_quantized = quantize_to_8int(A)
    B_quantized = quantize_to_8int(B)
    return np.matmul(A_quantized, B_quantized)

# 测试矩阵乘法
A = np.random.rand(10, 10)
B = np.random.rand(10, 10)
# 执行8位量化矩阵乘法
result_float = np.matmul(A, B)
result = matrix_multiply(A, B)
result, result_float

(array([[ 10, 215, 191, 141, 134, 175, 145,  61, 137, 214],
        [ 15,  58,  20, 233, 209, 239,  83, 108,  40, 134],
        [ 10, 118, 205, 238,  12, 139,  11,  65, 163, 139],
        [216, 104,  48,   6,  19, 252,  35, 229, 225,  12],
        [135,  94,  78, 140,  95,  37,  89,  78,   6,  83],
        [104, 116, 145,  48, 222, 142,  27, 188, 144,  60],
        [ 41, 155, 247,  68,  31, 229, 238, 208, 242, 193],
        [132,  49, 244, 194, 125, 187,  21, 215,  13, 243],
        [ 54, 135,  13, 106, 133, 194, 125,  37, 135, 160],
        [216, 152, 174, 199, 111,  67, 175, 246, 157, 177]], dtype=uint8),
 array([[2.2502748 , 2.74346383, 2.00174007, 2.75396128, 3.60554111,
         2.79526385, 1.66875728, 2.52235823, 2.98987255, 2.39911944],
        [2.27275922, 3.01173176, 2.18492908, 2.24665088, 3.45180009,
         2.89761788, 2.03289828, 3.47495021, 3.06831405, 2.85119874],
        [2.76123377, 3.33120267, 2.23834836, 2.37674244, 3.64615766,
         2.59872596, 2.06849547, 3.465