In [30]:
# 使用线性插值算法构造相对位置编码
import torch
import torch.nn.functional as F

# 初始化相对位置编码
def initRPE(size_src: tuple, size_tar: tuple, device: torch.device='cpu', gain: float=1.0):
    '''
    Build a distance-decayed relative position encoding between two pixel grids.

    Both grids are expressed in one shared coordinate frame whose height/width
    is the element-wise maximum of the two sizes; the centre of the top-left
    pixel of that frame is (0.5, 0.5). A grid that is smaller than the frame
    obtains its pixel-centre coordinates by bilinearly resampling the frame's
    coordinate map. The encoding for a (source pixel, target pixel) pair is
    ``gain / exp(d)`` where ``d`` is the Euclidean distance between the two
    pixel centres.

    Args:
        size_src: (h, w) of the source feature map; L = h * w.
        size_tar: (h, w) of the target feature map; S = h * w.
        device: device the returned tensor is moved to.
        gain: multiplicative factor applied to the encoding.

    Returns:
        Float tensor of shape [L, S]. Entries equal ``gain`` where the two
        pixel centres coincide and decay towards 0 as the distance grows.

    Raises:
        AssertionError: if either size is not a tuple of length 2.
    '''
    for size in (size_src, size_tar):
        # isinstance instead of an exact type comparison so that tuple
        # subclasses such as torch.Size (e.g. x.shape[-2:]) are accepted.
        assert isinstance(size, tuple), 'Data type of size in function <initRPE> should be <tuple>!'
        assert len(size) == 2, 'Length of size should be 2!'
    # Plain int tuples keep the equality checks below and F.interpolate uniform.
    size_src = (int(size_src[0]), int(size_src[1]))
    size_tar = (int(size_tar[0]), int(size_tar[1]))

    # Shared coordinate frame; rows and columns are sized independently so
    # neither grid has to dominate the other in both dimensions.
    H_max = max(size_src[0], size_tar[0])
    W_max = max(size_src[1], size_tar[1])
    coords_h = torch.arange(H_max) + 0.5
    coords_w = torch.arange(W_max) + 0.5
    # indexing='ij' is the behaviour the original call relied on implicitly;
    # stating it avoids the deprecation warning for the missing argument.
    coords_base = torch.stack(
        torch.meshgrid(coords_h, coords_w, indexing='ij')   # [2, H_max, W_max]
    ).unsqueeze(0).float()                                  # [1, 2, H_max, W_max], batch dim needed by interpolate

    # Pixel-centre coordinates of the source grid.
    if size_src == (H_max, W_max):
        coords_src = coords_base
    else:
        coords_src = F.interpolate(coords_base, size_src, mode='bilinear', align_corners=False)
    # Pixel-centre coordinates of the target grid.
    if size_tar == (H_max, W_max):
        coords_tar = coords_base
    else:
        coords_tar = F.interpolate(coords_base, size_tar, mode='bilinear', align_corners=False)

    # Flatten the spatial dimensions so pairwise differences can be broadcast.
    coords_src = coords_src.reshape(2, -1).to(device)       # [2, L]
    coords_tar = coords_tar.reshape(2, -1).to(device)       # [2, S]

    # Channel 0 holds the row (h) offsets, channel 1 the column (w) offsets.
    relative_coords = coords_src[:, :, None] - coords_tar[:, None, :]   # [2, L, S]
    distance = torch.sqrt(                                              # [L, S]
        torch.square(relative_coords[0, :, :]) + torch.square(relative_coords[1, :, :])
    )
    # exp() maps a zero distance to 1, so the reciprocal below stays finite.
    distance_exp = torch.exp(distance)                                  # [L, S]
    # Additive attention boost: gain at zero distance, approaching 0 far away.
    rpe = (1 / distance_exp) * gain                                     # [L, S]
    return rpe

# Demo: encoding between two identical 3x3 grids.
size_src = size_tar = (3, 3)

rpe = initRPE(size_src, size_tar)
print(rpe.shape, rpe, sep='\n')

torch.Size([9, 9])
tensor([[1.0000, 0.3679, 0.1353, 0.3679, 0.2431, 0.1069, 0.1353, 0.1069, 0.0591],
        [0.3679, 1.0000, 0.3679, 0.2431, 0.3679, 0.2431, 0.1069, 0.1353, 0.1069],
        [0.1353, 0.3679, 1.0000, 0.1069, 0.2431, 0.3679, 0.0591, 0.1069, 0.1353],
        [0.3679, 0.2431, 0.1069, 1.0000, 0.3679, 0.1353, 0.3679, 0.2431, 0.1069],
        [0.2431, 0.3679, 0.2431, 0.3679, 1.0000, 0.3679, 0.2431, 0.3679, 0.2431],
        [0.1069, 0.2431, 0.3679, 0.1353, 0.3679, 1.0000, 0.1069, 0.2431, 0.3679],
        [0.1353, 0.1069, 0.0591, 0.3679, 0.2431, 0.1069, 1.0000, 0.3679, 0.1353],
        [0.1069, 0.1353, 0.1069, 0.2431, 0.3679, 0.2431, 0.3679, 1.0000, 0.3679],
        [0.0591, 0.1069, 0.1353, 0.1069, 0.2431, 0.3679, 0.1353, 0.3679, 1.0000]])


In [12]:
import torch
from typing import Union

def func(a=None, b=None) -> Union[int, None]:
    '''
    Return whichever of the two mutually exclusive arguments was supplied.

    Args:
        a: first optional value.
        b: second optional value.

    Returns:
        ``a`` if it is not None, otherwise ``b`` (which is None when neither
        argument was given).

    Raises:
        AssertionError: if both ``a`` and ``b`` are given.
    '''
    assert a is None or b is None, '不能同时指定a和b'
    if a is not None:
        return a
    # Either b was supplied, or nothing was and None is returned.
    return b
# func(a=1, b=2)

In [35]:
import torch
import torch.nn.functional as F

shape_x = (5, 5)
shape_c = (2, 2)
num_c = shape_c[0] * shape_c[1]     # number of coarse cells, S

# Give every coarse cell its flat index, then spread it over the fine grid
# with nearest-neighbour upsampling.
delta_index_c = torch.arange(num_c).view(1, 1, *shape_c).float()                        # [1, 1, H_c, W_c]
delta_index_x = F.interpolate(delta_index_c, size=shape_x, mode='nearest').long()       # [1, 1, H_x, W_x]
# One-hot over the coarse-cell index, moved to the channel dimension.
delta_onehot_x = F.one_hot(delta_index_x.squeeze(1), num_c).permute(0, 3, 1, 2).float()  # [1, S, H_x, W_x]

# Reflect-pad two columns on the right and one row at the bottom, then
# recover the index map from the padded one-hot tensor.
delta_onehot_pad_x = F.pad(delta_onehot_x, pad=(0, 2, 0, 1), mode='reflect')
delta_index_pad_x = torch.argmax(delta_onehot_pad_x, dim=1)

# print('delta_index_c: \n', delta_index_c)
print('delta_index_x: \n', delta_index_x)
print('delta_onehot_x: \n', delta_onehot_x.shape)
# print('delta_onehot_x: \n', delta_onehot_x)

print('delta_onehot_pad_x: \n', delta_onehot_pad_x.shape)
print('delta_index_pad_x: \n', delta_index_pad_x)


delta_index_x: 
 tensor([[[[0, 0, 0, 1, 1],
          [0, 0, 0, 1, 1],
          [0, 0, 0, 1, 1],
          [2, 2, 2, 3, 3],
          [2, 2, 2, 3, 3]]]])
delta_onehot_x: 
 torch.Size([1, 4, 5, 5])
delta_onehot_pad_x: 
 torch.Size([1, 4, 6, 7])
delta_index_pad_x: 
 tensor([[[0, 0, 0, 1, 1, 1, 0],
         [0, 0, 0, 1, 1, 1, 0],
         [0, 0, 0, 1, 1, 1, 0],
         [2, 2, 2, 3, 3, 3, 2],
         [2, 2, 2, 3, 3, 3, 2],
         [2, 2, 2, 3, 3, 3, 2]]])


In [40]:
import torch
import torch.nn.functional as F

# Check what F.normalize returns for an all-zero row vector, using both the
# column mean and the column sum of an all-zero matrix as input.
a = torch.zeros((4, 2)) * 10
c = torch.sum(a, dim=0, keepdim=True)       # column sums, [1, 2]
b = c / a.size(0)                           # column means, [1, 2]

for vec in (b, c):
    print(F.normalize(vec))

tensor([[0., 0.]])
tensor([[0., 0.]])


In [3]:
import torch

# Check the reciprocal of the Euclidean distance to a point at infinity.
a = torch.tensor([0, 1], dtype=torch.float32)
b = torch.tensor([torch.inf] * 2, dtype=torch.float32)

diff = a - b
dist = torch.sqrt(torch.sum(torch.square(diff)))
print(1/dist)


tensor(0.)
