In [1]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import sys
sys.path.append("..")

  from .autonotebook import tqdm as notebook_tqdm


In [15]:
import torch
import torch.nn as nn

class PositionEmbedding(nn.Module):
    """Pairwise distance / angle-difference features between the rows of ``x``.

    Two linear heads (``fc_dist`` and ``fc_angle``) are created in the
    constructor, but ``forward`` does not apply them: it returns the raw
    pairwise matrices.

    Args:
        input_dim: number of columns per row of the input; ``fc_dist`` is built
            with ``input_dim - 1`` input features.
        output_dim: output width of both linear heads.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Head meant for the distance feature.
        self.fc_dist = nn.Linear(input_dim - 1, output_dim)
        # Head meant for the angle feature.
        self.fc_angle = nn.Linear(1, output_dim)

    def forward(self, x):
        """Return ``(dist, angle_diff)`` for a 2-D input ``x``.

        Columns 0-1 of ``x`` are used as planar coordinates and column 2 as an
        angle.  For an input with ``A`` rows both results are ``(A, A)``:

        * ``dist[i, j]`` is the Euclidean distance between rows ``i`` and ``j``;
        * ``angle_diff[i, j]`` is ``x[i, 2] - x[j, 2]`` (not wrapped).
        """
        coords = x[:, :2]
        angles = x[:, 2]

        # Broadcasting (A, 1, 2) against (1, A, 2) yields every pairwise offset.
        offsets = coords[:, None, :] - coords[None, :, :]
        dist = offsets.pow(2).sum(-1).sqrt()

        # Same trick for the angle column: (A, 1) minus (1, A).
        angle_diff = angles[:, None] - angles[None, :]

        return dist, angle_diff


In [16]:
# Smoke-test PositionEmbedding on a small random input.
A = 4
input_dim = 3
output_dim = 1
model = PositionEmbedding(input_dim, output_dim)

# Create a random tensor of shape (A, input_dim)
x = torch.randn(A, input_dim)

# Forward pass: returns the tuple (dist, angle_diff)
out = model(x)

# Print the input followed by the two pairwise matrices
print(x)
print(out)

tensor([[ 1.5500, -0.9624, -1.5492],
        [-0.5705,  0.6033, -0.3426],
        [-0.1369, -1.2059, -1.2736],
        [-1.7726,  0.3708, -0.2174]])
(tensor([[0.0000, 2.6359, 1.7043, 3.5802],
        [2.6359, 0.0000, 1.8605, 1.2245],
        [1.7043, 1.8605, 0.0000, 2.2720],
        [3.5802, 1.2245, 2.2720, 0.0000]]), tensor([[ 0.0000, -1.2066, -0.2756, -1.3318],
        [ 1.2066,  0.0000,  0.9310, -0.1252],
        [ 0.2756, -0.9310,  0.0000, -1.0562],
        [ 1.3318,  0.1252,  1.0562,  0.0000]]))


In [10]:
# Hand computation of the Euclidean distance between the 2-D points
# (1.1384, 1.3972) and (0.7685, 1.2253).
math.sqrt((1.1384 - 0.7685)**2 + (1.3972-1.2253)**2)

0.40789167679667115

In [31]:
import math
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor


def _scaled_dot_product_attention(q, k, v, attn_mask=None, dropout=0.0):
    # q           (B * nhead, tgt_len, head_dim)    
    # kv          (B * nhead, src_len, head_dim)    
    # attn_mask   (B * nhead, 1 or tgt_len, src_len)
    # out         (B * nhead, tgt_len, head_dim)

    B, Nt, E = q.shape
    q = q / math.sqrt(E)

    # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns)
    attn = torch.bmm(q, k.transpose(-2, -1))

    # attn mask will set -inf to attn positions that must be masked
    # mask is 0 by default so no masking takes place
    if attn_mask is not None:
        attn += attn_mask

    attn = F.softmax(attn, dim=-1)

    if dropout > 0.0:
        attn = F.dropout(attn, p=dropout)

    # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E)
    output = torch.bmm(attn, v)

    return output, attn


def _in_projection_packed(q: torch.Tensor,
                          k: torch.Tensor,
                          v: torch.Tensor,
                          w: torch.Tensor,
                          b: Optional[torch.Tensor] = None):
    r"""
    Performs the in-projection step of the attention operation, using packed weights.
    Output is a triple containing projection tensors for query, key and value.

    Args:
        q, k, v: query, key and value tensors to be projected. For self-attention,
            these are typically the same tensor; for encoder-decoder attention,
            k and v are typically the same tensor. (We take advantage of these
            identities for performance if they are present.) Regardless, q, k and v
            must share a common embedding dimension; otherwise their shapes may vary.
        w: projection weights for q, k and v, packed into a single tensor. Weights
            are packed along dimension 0, in q, k, v order.
        b: optional projection biases for q, k and v, packed into a single tensor
            in q, k, v order.

    Shape:
        Inputs:
        - q: :math:`(..., E)` where E is the embedding dimension
        - k: :math:`(..., E)` where E is the embedding dimension
        - v: :math:`(..., E)` where E is the embedding dimension
        - w: :math:`(E * 3, E)` where E is the embedding dimension
        - b: :math:`E * 3` where E is the embedding dimension

        Output:
        - in output list :math:`[q', k', v']`, each output tensor will have the
            same shape as the corresponding input tensor.
    """
    E = q.shape[-1]
    if k is v:
        if q is k:
            # self-attention
            return F.linear(q, w, b).chunk(3, dim=-1)
            # q:        (B, *, in_features)         -> (..., E)
            # w:        (out_features, in_features) -> (E * 3, E)
            # b:        (out_features)              -> (E * 3)
            # lin_out:  (B, *, out_features)        -> (..., E * 3)
            # chunk_out:                            -> 3 * (..., E)
        else:
            # encoder-decoder attention
            w_q, w_kv = w.split([E, E * 2])
            if b is None:
                b_q = b_kv = None
            else:
                b_q, b_kv = b.split([E, E * 2])
            # will concat q_out with k_out v_out
            #                            |
            #                            V
            return (F.linear(q, w_q, b_q), ) + F.linear(k, w_kv, b_kv).chunk(
                2, dim=-1)
    else:
        w_q, w_k, w_v = w.chunk(3)
        if b is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = b.chunk(3)
        return F.linear(q, w_q, b_q), F.linear(k, w_k,
                                               b_k), F.linear(v, w_v, b_v)


class myMultiheadAttention(nn.Module):
    """Multi-head attention with a single packed q/k/v input projection.

    Parameters are stored as ``in_proj_weight`` / ``in_proj_bias`` (the packed
    q, k, v projection) and ``out_proj`` (the output linear layer);
    ``q_proj_weight``, ``k_proj_weight`` and ``v_proj_weight`` are registered
    as ``None``.

    Args:
        d_model: embedding size of inputs and outputs; must be divisible by
            ``nhead``.
        nhead: number of attention heads.
        dropout: dropout probability applied to the attention weights while
            the module is in training mode.
        batch_first: if True, 3-D inputs and the attention output use the
            layout ``(B, seq, C)`` instead of ``(seq, B, C)``.
        bias: whether the input and output projections have a bias.
    """

    def __init__(self, d_model, nhead, dropout=0.0, batch_first=False, bias=True):
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.dropout = dropout
        self.batch_first = batch_first

        self.head_dim = d_model // nhead
        assert (self.head_dim * nhead == d_model), "d_model % nhead != 0"

        # Packed projection: rows [0:E] -> q, [E:2E] -> k, [2E:3E] -> v.
        self.in_proj_weight = nn.Parameter(torch.empty((3 * d_model, d_model)))
        self.register_parameter('q_proj_weight', None)
        self.register_parameter('k_proj_weight', None)
        self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty(3 * d_model))
        else:
            self.register_parameter("in_proj_bias", None)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise the packed weight and zero both biases (if any)."""
        nn.init.xavier_uniform_(self.in_proj_weight)
        # in_proj_bias and out_proj.bias are both controlled by the same
        # ``bias`` flag, so one check covers both.
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                attn_mask: Optional[Tensor] = None,
                key_padding_mask: Optional[Tensor] = None):
        """Compute multi-head attention.

        Args:
            query: ``(tgt_len, B, C)``, or ``(B, tgt_len, C)`` when
                ``batch_first`` is set.
            key: ``(src_len, B, C)``, or ``(B, src_len, C)`` when
                ``batch_first`` is set.
            value: same shape as ``key``.
            attn_mask: optional mask, 2-D ``(tgt_len, src_len)`` or 3-D
                ``(B * nhead, tgt_len, src_len)``.  Boolean (or uint8) masks
                mark positions to block with True; floating-point masks are
                added to the attention scores.
            key_padding_mask: optional ``(B, src_len)`` mask; True marks key
                positions to ignore (uint8 is converted to bool).

        Returns:
            ``(attn_output, attn_output_weights)``: ``attn_output`` has the
            same layout as ``query``; ``attn_output_weights`` are the per-head
            weights, ``(B * nhead, tgt_len, src_len)``.
        """
        # Example shapes (seq-first layout):
        #                     Enc             Dec tgt         Dec mem
        # query, key, value:  [672, 2, 256]   [100, 2, 256]   [100, 2, 256], [672, 2, 256], [672, 2, 256]
        # attn_mask:          None            None            None
        # key_padding_mask:   [2, 672]        None            [2, 672]
        # output:             [672, 2, 256]   [100, 2, 256]   [100, 2, 256]

        is_batched = query.dim() == 3
        if self.batch_first and is_batched:
            query, key, value = [
                x.transpose(1, 0) for x in (query, key, value)
            ]

        tgt_len, batch_size, embed_dim = query.shape
        src_len, _, _ = key.shape

        assert (embed_dim == self.d_model
                ), f"expected hidden dim = {self.d_model}, but got {embed_dim}"
        assert (
            key.shape == value.shape
        ), f"key shape {key.shape} does not match value shape {value.shape}"

        # compute in-projection
        q, k, v = _in_projection_packed(query, key, value, self.in_proj_weight,
                                        self.in_proj_bias)

        # prep attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.uint8:
                attn_mask = attn_mask.to(torch.bool)
            assert attn_mask.is_floating_point(
            ) or attn_mask.dtype == torch.bool, "wrong attn_mask type"

            if attn_mask.dim() == 2:
                assert (tgt_len,
                        src_len) == attn_mask.shape, "wrong attn_mask shape"
                # add artificial batch_size=1 so the mask broadcasts over heads
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                assert (batch_size * self.nhead, tgt_len,
                        src_len) == attn_mask.shape, "wrong attn_mask shape"
            else:
                # Explicit raise (same exception type as the asserts above)
                # so the check is not stripped when Python runs with -O.
                raise AssertionError("wrong attn_mask shape")

        # prep key padding mask
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            key_padding_mask = key_padding_mask.to(torch.bool)

        # reshape q, k, v for multihead attention and make them batch first
        # q:    (tgt_len, B, C)->(tgt_len, B, nhead * head_dim)->
        #       (tgt_len, B * nhead, head_dim)->(B * nhead, tgt_len, head_dim)
        q = q.contiguous().view(tgt_len, batch_size * self.nhead,
                                self.head_dim).transpose(0, 1)

        # kv:   (src_len, B, C)->(src_len, B, nhead * head_dim)->
        #       (src_len, B * nhead, head_dim)->(B * nhead, src_len, head_dim)
        # .view(-1, ...) infers the first dim from the other dims specified
        k = k.contiguous().view(-1, batch_size * self.nhead,
                                self.head_dim).transpose(0, 1)
        v = v.contiguous().view(-1, batch_size * self.nhead,
                                self.head_dim).transpose(0, 1)

        # update source sequence length after adjustments
        src_len = k.shape[1]

        # merge key padding and attention masks
        if key_padding_mask is not None:
            assert key_padding_mask.shape == (
                batch_size, src_len
            ), f"expecting key_padding_mask shape of {(batch_size, src_len)}, but got {key_padding_mask.shape}"

            # (B, src_len) -> (B, nhead, 1, src_len) -> (B * nhead, 1, src_len)
            key_padding_mask = key_padding_mask.view(batch_size, 1, 1, src_len)
            # -1 means not changing the size of that dimension
            key_padding_mask = key_padding_mask.expand(-1, self.nhead, -1, -1)
            key_padding_mask = key_padding_mask.reshape(
                batch_size * self.nhead, 1, src_len)

            if attn_mask is None:
                attn_mask = key_padding_mask
            elif attn_mask.dtype == torch.bool:
                attn_mask = attn_mask.logical_or(key_padding_mask)
            else:
                attn_mask = attn_mask.masked_fill(key_padding_mask,
                                                  float("-inf"))

        # convert boolean mask to an additive float mask (True -> -inf)
        if attn_mask is not None and attn_mask.dtype == torch.bool:
            new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
            new_attn_mask.masked_fill_(attn_mask, float("-inf"))
            attn_mask = new_attn_mask

        # Dropout is active only in training mode.  Use a local value instead
        # of overwriting self.dropout, otherwise a single eval-mode forward
        # would permanently disable dropout for later training.
        dropout_p = self.dropout if self.training else 0.0

        # calculate attention and out projection
        attn_output, attn_output_weights = _scaled_dot_product_attention(
            q, k, v, attn_mask, dropout_p)
        # e.g. attn_output            [16, 100, 32]
        #      attn_output_weights    [16, 100, 672]

        # (B * nhead, tgt_len, head_dim) -> (tgt_len, B * nhead, head_dim)
        # -> (tgt_len, B, C), e.g. [16, 100, 32]->[100, 16, 32]->[100, 2, 256]
        attn_output = attn_output.transpose(0, 1).contiguous().view(
            tgt_len, batch_size, embed_dim)

        # out projection: weight [C, C], bias [C]; output keeps (tgt_len, B, C)
        attn_output = F.linear(attn_output, self.out_proj.weight,
                               self.out_proj.bias)

        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy

# Create random inputs
query = torch.randn(10, 32, 128)
key = torch.randn(10, 32, 128)
value = torch.randn(10, 32, 128)

# Build and run nn.MultiheadAttention
multihead_attn = nn.MultiheadAttention(128, 8, batch_first=True)
output1, _ = multihead_attn(query, key, value)

# Build the hand-written myMultiheadAttention
manual_multihead_attn = myMultiheadAttention(128, 8, batch_first=True)

# Copy the parameters of nn.MultiheadAttention into myMultiheadAttention
manual_multihead_attn.load_state_dict(copy.deepcopy(multihead_attn.state_dict()))

output2, _ = manual_multihead_attn(query, key, value)

# Check whether the two outputs agree
print(torch.allclose(output1, output2))


In [2]:
# Position indices as a column vector (qlen, 1) and a row vector (1, klen),
# shaped so that subtracting them broadcasts to a (qlen, klen) matrix.
qlen = 10
klen = 10
context_position = torch.arange(qlen, dtype=torch.long,
                                        )[:, None]
memory_position = torch.arange(klen, dtype=torch.long,
                                      )[None, :]

In [6]:
# Relative offset matrix: entry [i, j] equals j - i.
memory_position - context_position

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [-1,  0,  1,  2,  3,  4,  5,  6,  7,  8],
        [-2, -1,  0,  1,  2,  3,  4,  5,  6,  7],
        [-3, -2, -1,  0,  1,  2,  3,  4,  5,  6],
        [-4, -3, -2, -1,  0,  1,  2,  3,  4,  5],
        [-5, -4, -3, -2, -1,  0,  1,  2,  3,  4],
        [-6, -5, -4, -3, -2, -1,  0,  1,  2,  3],
        [-7, -6, -5, -4, -3, -2, -1,  0,  1,  2],
        [-8, -7, -6, -5, -4, -3, -2, -1,  0,  1],
        [-9, -8, -7, -6, -5, -4, -3, -2, -1,  0]])

In [11]:
from src.model.layers.relative_position_bias import RelativePositionBias

# Instantiate the project's RelativePositionBias with explicit settings.
pos_embed = RelativePositionBias(bidirectional=False,
                                 num_buckets=20,
                                 max_distance=10,
                                 n_heads=8)

In [15]:
# Map each relative offset to a bucket index.
# NOTE(review): bidirectional=True and num_buckets=19 are passed here, which
# differ from the bidirectional=False / num_buckets=20 given when pos_embed was
# constructed — confirm this is intentional.
pos_embed._relative_position_bucket(memory_position - context_position, bidirectional=True,
                                    num_buckets=19, max_distance=10)

tensor([[ 0, 10, 11, 12, 13, 14, 15, 16, 16, 17],
        [ 1,  0, 10, 11, 12, 13, 14, 15, 16, 16],
        [ 2,  1,  0, 10, 11, 12, 13, 14, 15, 16],
        [ 3,  2,  1,  0, 10, 11, 12, 13, 14, 15],
        [ 4,  3,  2,  1,  0, 10, 11, 12, 13, 14],
        [ 5,  4,  3,  2,  1,  0, 10, 11, 12, 13],
        [ 6,  5,  4,  3,  2,  1,  0, 10, 11, 12],
        [ 7,  6,  5,  4,  3,  2,  1,  0, 10, 11],
        [ 7,  7,  6,  5,  4,  3,  2,  1,  0, 10],
        [ 8,  7,  7,  6,  5,  4,  3,  2,  1,  0]])

In [2]:
import torch

# Create a 3-D tensor, e.g. of size [2, 4, 4]
x = torch.randn(2, 4, 4)

# Apply torch.max() along the second dimension (dim=1, the "channels" axis);
# it returns both the maxima and the indices where they occur.
values, indices = torch.max(x, dim=1)
print(x)
print("Max values:", values)
print("Indices of max values:", indices)

tensor([[[-0.3324, -1.4082, -0.0784,  0.9798],
         [-1.5731,  2.0689,  0.8974,  1.5193],
         [-0.4568,  0.1283,  0.5386, -2.4981],
         [-1.6869,  0.1335,  0.8823, -0.2645]],

        [[-0.6141,  0.9644, -0.2441, -1.4504],
         [ 0.7161, -0.5827,  0.6977,  1.3610],
         [ 0.8721, -0.4602,  1.0663,  0.3069],
         [-1.1508, -0.7910,  1.0085, -0.4246]]])
Max values: tensor([[-0.3324,  2.0689,  0.8974,  1.5193],
        [ 0.8721,  0.9644,  1.0663,  1.3610]])
Indices of max values: tensor([[0, 1, 1, 1],
        [2, 0, 2, 1]])


In [10]:
## BN implement
# Reference: BatchNorm2d with no epsilon, no affine parameters and no running
# statistics, so the output is exactly (x - batch_mean) / batch_std per channel.
bn = nn.BatchNorm2d(num_features=3, eps=0, affine=False, track_running_stats=False)

x = torch.rand(10, 3, 5, 5)* 10000
offical_bn = bn(x)

# Manual version: gather every value of a channel into one row.
x_1 = x.permute(1,0,2,3).reshape(3,-1) # [c, n*h*w]
mu_x = x_1.mean(dim=-1).view(1,3,1,1) # [1,c,1,1]
# BatchNorm normalises with the *biased* estimator (divide by N, not N - 1),
# so the manual version must use unbiased=False to match it element-wise.
std_x = x_1.std(dim=-1, unbiased=False).view(1,3,1,1) # [1,c,1,1]

my_bn = (x-mu_x)/std_x # no epsilon

In [11]:
# NOTE(review): a signed sum lets positive and negative element-wise errors
# cancel, so a value near zero does not prove the tensors match;
# (offical_bn - my_bn).abs().max() would be a stricter check.
(offical_bn-my_bn).sum()

tensor(-5.9740e-06)

In [4]:
import torch
from typing import Tuple

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Build the table of complex rotation factors for rotary embeddings.

    The embedding channels are rotated in pairs, so there are ``dim // 2``
    rotation frequencies; pair ``j`` uses ``theta ** (-2j / dim)``.

    Args:
        dim: embedding size.
        seq_len: number of token positions to cover.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, dim // 2)`` whose entry ``[t, j]``
        has magnitude 1 and angle ``t * theta ** (-2j / dim)``.
    """
    # Rotation frequency of each channel pair.
    pair_exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** pair_exponents)

    # Token position indices t = [0, 1, ..., seq_len - 1].
    positions = torch.arange(seq_len, device=inv_freq.device)

    # angles[t, j] = t * inv_freq[j], shape [seq_len, dim // 2].
    angles = torch.outer(positions, inv_freq).float()

    # torch.polar(abs, angle) returns abs * (cos(angle) + i * sin(angle));
    # abs is all ones here and has the same shape as the angles.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key vectors by the given complex rotation factors.

    Args:
        xq: queries, ``[batch_size, seq_len, dim]``.
        xk: keys, ``[batch_size, seq_len, dim]``.
        freqs_cis: complex rotation factors that broadcast against
            ``[batch_size, seq_len, dim // 2]``.

    Returns:
        ``(xq_out, xk_out)``, each with the shape and dtype of its input.
    """

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # Group neighbouring channels into pairs: [..., dim] -> [..., dim // 2, 2].
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        # View each pair as one complex number, rotate by complex multiplication.
        rotated = torch.view_as_complex(pairs) * freqs_cis
        # Back to the real domain; flattening from dim 2 onward restores
        # [batch_size, seq_len, dim] for 3-D inputs.
        return torch.view_as_real(rotated).flatten(2).type_as(x)

    return _rotate(xq), _rotate(xk)

if __name__ == '__main__':
    # Tiny demo: rotate random queries and keys for 3 positions of width 4.
    seq_len,dim=3,4
    freqs_cis = precompute_freqs_cis(dim=dim, seq_len=seq_len, theta=10000.0)
    xq = torch.rand(1, seq_len, dim)
    xk = torch.rand(1, seq_len, dim)
    res = apply_rotary_emb(xq, xk, freqs_cis)
    # res is a pair of tensors, each of shape (1, seq_len, dim)
    print(res)


(tensor([[[ 0.7552,  0.8442,  0.1087,  0.4682],
         [ 0.2172,  0.7370,  0.9776,  0.8823],
         [-0.7359,  0.6933,  0.3336,  0.8447]]]), tensor([[[ 0.3748,  0.9994,  0.5744,  0.6639],
         [-0.1345,  0.5654,  0.3115,  0.5989],
         [-0.8560, -0.2329,  0.7621,  0.7718]]]))
