In [1]:
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import sys
sys.path.append("..")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Number of query positions and number of key (memory) positions.
qlen = 10
klen = 10
# Query indices 0..qlen-1 laid out as a column vector of shape [qlen, 1].
context_position = torch.arange(qlen, dtype=torch.long).unsqueeze(1)
# Key indices 0..klen-1 laid out as a row vector of shape [1, klen].
memory_position = torch.arange(klen, dtype=torch.long).unsqueeze(0)

In [6]:
# Broadcasting a [1, klen] row against a [qlen, 1] column gives a
# [qlen, klen] grid whose entry [i, j] is (key index j) - (query index i).
torch.sub(memory_position, context_position)

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [-1,  0,  1,  2,  3,  4,  5,  6,  7,  8],
        [-2, -1,  0,  1,  2,  3,  4,  5,  6,  7],
        [-3, -2, -1,  0,  1,  2,  3,  4,  5,  6],
        [-4, -3, -2, -1,  0,  1,  2,  3,  4,  5],
        [-5, -4, -3, -2, -1,  0,  1,  2,  3,  4],
        [-6, -5, -4, -3, -2, -1,  0,  1,  2,  3],
        [-7, -6, -5, -4, -3, -2, -1,  0,  1,  2],
        [-8, -7, -6, -5, -4, -3, -2, -1,  0,  1],
        [-9, -8, -7, -6, -5, -4, -3, -2, -1,  0]])

In [11]:
from src.model.layers.relative_position_bias import RelativePositionBias

# Instantiate the project's RelativePositionBias layer.
# NOTE(review): the meaning of each argument is defined in
# src/model/layers/relative_position_bias.py, which is not shown here.
pos_embed = RelativePositionBias(
    n_heads=8,
    num_buckets=20,
    max_distance=10,
    bidirectional=False,
)

In [15]:
# Map every (key - query) offset to a bucket id.
# NOTE(review): this call passes bidirectional=True and num_buckets=19,
# which differ from the values given to the constructor
# (bidirectional=False, num_buckets=20) -- confirm this is intentional.
pos_embed._relative_position_bucket(
    memory_position - context_position,
    bidirectional=True,
    num_buckets=19,
    max_distance=10,
)

tensor([[ 0, 10, 11, 12, 13, 14, 15, 16, 16, 17],
        [ 1,  0, 10, 11, 12, 13, 14, 15, 16, 16],
        [ 2,  1,  0, 10, 11, 12, 13, 14, 15, 16],
        [ 3,  2,  1,  0, 10, 11, 12, 13, 14, 15],
        [ 4,  3,  2,  1,  0, 10, 11, 12, 13, 14],
        [ 5,  4,  3,  2,  1,  0, 10, 11, 12, 13],
        [ 6,  5,  4,  3,  2,  1,  0, 10, 11, 12],
        [ 7,  6,  5,  4,  3,  2,  1,  0, 10, 11],
        [ 7,  7,  6,  5,  4,  3,  2,  1,  0, 10],
        [ 8,  7,  7,  6,  5,  4,  3,  2,  1,  0]])

In [2]:
import torch

# Create a 3-D tensor, here of size [2, 4, 4].
x = torch.randn(2, 4, 4)

# Take the maximum along the second dimension (dim=1, the "channels" axis);
# both returned tensors have shape [2, 4].
values, indices = x.max(dim=1)
print(x)
print("Max values:", values)
print("Indices of max values:", indices)

tensor([[[-0.3324, -1.4082, -0.0784,  0.9798],
         [-1.5731,  2.0689,  0.8974,  1.5193],
         [-0.4568,  0.1283,  0.5386, -2.4981],
         [-1.6869,  0.1335,  0.8823, -0.2645]],

        [[-0.6141,  0.9644, -0.2441, -1.4504],
         [ 0.7161, -0.5827,  0.6977,  1.3610],
         [ 0.8721, -0.4602,  1.0663,  0.3069],
         [-1.1508, -0.7910,  1.0085, -0.4246]]])
Max values: tensor([[-0.3324,  2.0689,  0.8974,  1.5193],
        [ 0.8721,  0.9644,  1.0663,  1.3610]])
Indices of max values: tensor([[0, 1, 1, 1],
        [2, 0, 2, 1]])


In [10]:
## BN implement
# Compare nn.BatchNorm2d against a hand-written normalisation.
# eps=0 and affine=False remove the epsilon term and the learnable
# scale/shift, and track_running_stats=False makes the layer use the
# statistics of the current batch, so the layer reduces to (x - mean) / std.
bn = nn.BatchNorm2d(num_features=3, eps=0, affine=False, track_running_stats=False)

x = torch.rand(10, 3, 5, 5) * 10000
offical_bn = bn(x)

# Gather all values of each channel into one row.
x_1 = x.permute(1, 0, 2, 3).reshape(3, -1)  # [c, n*h*w]
mu_x = x_1.mean(dim=-1).view(1, 3, 1, 1)  # [1, c, 1, 1]
# BatchNorm normalises with the *biased* (population) standard deviation,
# i.e. it divides by N rather than N - 1. Using unbiased=True here would
# scale every output by sqrt((N - 1) / N) relative to the official layer.
std_x = x_1.std(dim=-1, unbiased=False).view(1, 3, 1, 1)  # [1, c, 1, 1]

my_bn = (x - mu_x) / std_x  # no epsilon

# Element-wise check: a signed sum of the differences would be ~0 even for a
# wrong std, because the centred values of each channel sum to zero.
assert torch.allclose(offical_bn, my_bn, atol=1e-4), "manual BN does not match nn.BatchNorm2d"

In [11]:
# NOTE(review): this is a signed sum, so positive and negative element-wise
# differences cancel; a value close to zero does not by itself show that the
# two tensors agree element-wise.
torch.sum(offical_bn - my_bn)

tensor(-5.9740e-06)

In [4]:
import torch
from typing import Tuple

def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):
    """Precompute the complex rotation factors used by rotary embeddings.

    The embedding dimension is split into consecutive pairs; pair ``k`` is
    assigned the angular frequency ``theta ** (-2k / dim)``. For position
    ``p`` the rotation angle is ``p * frequency``.

    Args:
        dim: Size of the embedding (last) dimension.
        seq_len: Number of positions to generate factors for.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``[seq_len, dim // 2]`` whose entries are
        ``cos(angle) + i * sin(angle)`` (unit magnitude).
    """
    n_pairs = dim // 2
    # One exponent per pair: 0/dim, 2/dim, 4/dim, ...
    exponents = torch.arange(0, dim, 2)[:n_pairs].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Token positions 0, 1, ..., seq_len - 1.
    positions = torch.arange(seq_len, device=inv_freq.device)
    # angles[p, k] = positions[p] * inv_freq[k]; shape [seq_len, dim // 2].
    angles = torch.outer(positions, inv_freq).float()

    # torch.polar(abs, angle) builds abs * (cos(angle) + i * sin(angle));
    # abs is all ones and has the same shape as angle.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embedding to query and key tensors.

    The last dimension of each input is grouped into consecutive pairs, each
    pair is treated as one complex number and multiplied by the matching
    entry of ``freqs_cis``, and the result is converted back to real pairs.

    Args:
        xq: Query tensor of shape ``[..., dim]`` with ``dim`` even, e.g.
            ``[batch_size, seq_len, dim]``.
        xk: Key tensor, same layout as ``xq``.
        freqs_cis: Complex rotation factors that broadcast against
            ``[..., dim // 2]``, e.g. shape ``[seq_len, dim // 2]``.

    Returns:
        ``(xq_out, xk_out)``: rotated tensors with the same shape and dtype
        as ``xq`` and ``xk`` respectively.
    """
    # Group the last dimension into pairs: [..., dim] -> [..., dim // 2, 2].
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)

    # Reinterpret each pair as one complex number: [..., dim // 2].
    xq_ = torch.view_as_complex(xq_)
    xk_ = torch.view_as_complex(xk_)

    # Rotate, then go back to real pairs and merge the trailing
    # [dim // 2, 2] axes into [dim]. flatten(-2) is used instead of a fixed
    # start index so that inputs of any rank keep their leading shape.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)

    return xq_out.type_as(xq), xk_out.type_as(xk)

if __name__ == '__main__':
    # Small smoke run: 3 positions, embedding size 4, batch size 1.
    seq_len = 3
    dim = 4
    freqs_cis = precompute_freqs_cis(dim, seq_len, 10000.0)
    xq = torch.rand(1, seq_len, dim)
    xk = torch.rand(1, seq_len, dim)
    res = apply_rotary_emb(xq=xq, xk=xk, freqs_cis=freqs_cis)
    # res is a (query, key) pair; each tensor has shape [1, seq_len, dim].
    print(res)


(tensor([[[ 0.7552,  0.8442,  0.1087,  0.4682],
         [ 0.2172,  0.7370,  0.9776,  0.8823],
         [-0.7359,  0.6933,  0.3336,  0.8447]]]), tensor([[[ 0.3748,  0.9994,  0.5744,  0.6639],
         [-0.1345,  0.5654,  0.3115,  0.5989],
         [-0.8560, -0.2329,  0.7621,  0.7718]]]))
