In [2]:
from torch import nn 
import torch 
# Demo: nn.AdaptiveAvgPool2d pools inputs of any spatial size down to a fixed
# target size. The input tensor is named `feature_map` (not `input`) so this
# cell does not shadow the built-in input() for every later cell in the kernel.

# target output size of 5x7
m = nn.AdaptiveAvgPool2d((5, 7))
feature_map = torch.randn(1, 64, 8, 9)
output = m(feature_map)
print(output.shape)

# target output size of 7x7 (square): a single int is used for both dimensions
m = nn.AdaptiveAvgPool2d(7)
feature_map = torch.randn(1, 64, 10, 9)
output = m(feature_map)
print(output.shape)

# target output size of 10x7: None keeps that dimension at the input's size
m = nn.AdaptiveAvgPool2d((None, 7))
feature_map = torch.randn(1, 64, 10, 9)
output = m(feature_map)
print(output.shape)

torch.Size([1, 64, 5, 7])
torch.Size([1, 64, 7, 7])
torch.Size([1, 64, 10, 7])


In [36]:
import torch 
from einops import rearrange
def generate_relative_positions_matrix(length, max_relative_positions,
                                       cache=False):
    """Build clipped, non-negative relative-position indices.

    Args:
        length: sequence length.
        max_relative_positions: offsets are clipped to
            ``[-max_relative_positions, max_relative_positions]``.
        cache: when True, only a single row holding the offsets
            ``-length + 1 .. 0`` is produced.

    Returns:
        Integer tensor of shape ``(1, length)`` if ``cache`` else
        ``(length, length)``. Without ``cache``, entry ``[i, j]`` is
        ``clip(j - i) + max_relative_positions``, so all values are >= 0.
    """
    if cache:
        offsets = torch.arange(1 - length, 1).unsqueeze(0)
    else:
        positions = torch.arange(length)
        # Broadcasting a row against a column gives offsets[i, j] = j - i.
        offsets = positions.unsqueeze(0) - positions.unsqueeze(1)
    clipped = offsets.clamp(min=-max_relative_positions,
                            max=max_relative_positions)
    # Shift values to be >= 0
    return clipped + max_relative_positions


def calc_rel_pos(n):
    """Relative (row, col) offsets between all cells of an n x n grid.

    Args:
        n: side length of the square grid.

    Returns:
        Integer tensor of shape ``(n*n, n*n, 2)``. ``rel_pos[a, b]`` is the
        (row, col) offset from cell ``a`` to cell ``b``, shifted by ``n - 1``
        so that every value lies in ``[0, 2n - 2]``.
    """
    # 'ij' is the layout torch.meshgrid uses when no indexing is given; naming
    # it explicitly avoids the warning about the missing argument.
    pos = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    print(pos)  # debug output: the raw row / column coordinate grids
    pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')  # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]                  # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                       # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos

# Smallest case (length / grid side of 1): print the clipped relative-position
# matrix, then let the notebook display calc_rel_pos(1) as the cell's result.
print(generate_relative_positions_matrix(1,1))
calc_rel_pos(1)

tensor([[1]])
(tensor([[0]]), tensor([[0]]))


tensor([[[0, 0]]])

In [24]:
class LambdaLayer(nn.Module):
    """Work-in-progress lambda layer that computes only the content lambda.

    Queries, keys and values are produced by 1x1 convolutions over the input
    feature map. The positional lambda is still a TODO, so ``forward``
    returns the content lambda alone.

    Args:
        dim: number of input channels.
        dim_k: key/query depth.
        dim_out: number of output channels; must be divisible by ``heads``.
        heads: number of query heads.
    """

    def __init__(self, dim, dim_k, dim_out, heads):
        super().__init__()

        assert dim_out % heads == 0
        self.heads = heads
        value_channels = dim_out // heads
        query_channels = dim_k * heads

        # (Page 5) BN after Q and V are helpful
        self.get_q = nn.Sequential(
            nn.Conv2d(dim, query_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(query_channels),
        )
        self.get_v = nn.Sequential(
            nn.Conv2d(dim, value_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(value_channels),
        )
        # NOTE(review): keys get dim_k * heads channels, like the queries, but
        # forward() splits only the queries into `heads` groups of dim_k while
        # the keys keep all channels as one key axis -- confirm this is intended.
        self.get_k = nn.Conv2d(dim, dim_k * heads, kernel_size=1, bias=False)

        # TODO: relative position embedding
        # self.embedding = ()
        # self.relative_position = generate_relative_positions_matrix(dim,dim)

    def forward(self, x):
        """Compute the content lambda for a batch of feature maps.

        Args:
            x: tensor whose shape unpacks as (batch, channels, height, width).

        Returns:
            Tensor of shape ``(batch, dim_k * heads, dim_out // heads)``.
        """
        # Unpacking doubles as a check that the input is 4-D.
        batch, channels, im_h, im_w = x.shape
        queries, keys, values = self.get_q(x), self.get_k(x), self.get_v(x)

        # Flatten the spatial grid into a single position axis.
        queries = rearrange(queries, 'b (h k) im_h im_w -> b h k (im_h im_w)', h=self.heads)
        print(queries.size())
        keys = rearrange(keys, 'b k im_h im_w -> b k (im_h im_w)')
        print(keys.size())
        values = rearrange(values, 'b v im_h im_w -> b v (im_h im_w)')
        print(values.size())

        # Normalise the keys over positions, then contract positions away.
        keys_over_positions = keys.softmax(dim=-1)
        content_lambda = einsum('b k m, b v m -> b k v', keys_over_positions, values)
        # position_lambda = einsum('b ')  # TODO: add the embedding term

        # lambda_total = content_lambda + position_lambda
        return content_lambda

In [27]:
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsum
from torchinfo import summary

# Optional torchinfo summary of the layer (disabled).
# summary(LambdaLayer(dim=3,dim_k=3,dim_out=3,heads=3), (32,3,255,255))

# Shape check: push a batch of 32 three-channel 255x255 inputs through the
# layer; the size of its output is displayed as the cell's result.
X = torch.rand((32,3,255,255))
LambdaLayer(dim=3,dim_k=3,dim_out=3,heads=3)(X).size()


torch.Size([32, 3, 3, 65025])
torch.Size([32, 9, 65025])
torch.Size([32, 1, 65025])


torch.Size([32, 9, 1])

In [5]:
from einops import rearrange
def calc_rel_pos(n):
    """Relative (row, col) offsets between all cells of an n x n grid.

    Args:
        n: side length of the square grid.

    Returns:
        Integer tensor of shape ``(n*n, n*n, 2)``. ``rel_pos[a, b]`` is the
        (row, col) offset from cell ``a`` to cell ``b``, shifted by ``n - 1``
        so that every value lies in ``[0, 2n - 2]``.
    """
    # 'ij' is the layout torch.meshgrid uses when no indexing is given; naming
    # it explicitly avoids the warning about the missing argument.
    pos = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')  # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]                  # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                       # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos

In [10]:
import torch 
def rel_pos_indices(size):
    """Shifted relative (row, col) offsets between all cells of a 2-D grid.

    Args:
        size: side length of a square grid, or a ``(height, width)`` pair
            for a rectangular one.

    Returns:
        Integer tensor of shape ``(2, H*W, H*W)``. ``out[0, a, b]`` is the row
        offset and ``out[1, a, b]`` the column offset from cell ``a`` to cell
        ``b`` (cells flattened row-major), shifted by ``H - 1`` / ``W - 1`` so
        the values lie in ``[0, 2H - 2]`` / ``[0, 2W - 2]``.
    """
    try:
        height, width = size
    except TypeError:
        # A scalar cannot be unpacked: treat it as the side of a square grid.
        height = width = size

    # 'ij' is the layout torch.meshgrid uses when no indexing is given; naming
    # it explicitly avoids the warning about the missing argument.
    rows, cols = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
    pos = torch.stack((rows, cols)).flatten(1)   # 2, H * W
    rel_pos = pos[:, None, :] - pos[:, :, None]  # [c, a, b] = pos[c, b] - pos[c, a]
    rel_pos[0] += height - 1
    rel_pos[1] += width - 1
    return rel_pos  # 2, H * W, H * W

# Print both index layouts for a 2x2 grid side by side: rel_pos_indices puts
# the (row, col) component first, calc_rel_pos puts it last.
print(rel_pos_indices(2))
print(calc_rel_pos(2))

tensor([[[1, 1, 2, 2],
         [1, 1, 2, 2],
         [0, 0, 1, 1],
         [0, 0, 1, 1]],

        [[1, 2, 1, 2],
         [0, 1, 0, 1],
         [1, 2, 1, 2],
         [0, 1, 0, 1]]])
tensor([[[1, 1],
         [1, 2],
         [2, 1],
         [2, 2]],

        [[1, 0],
         [1, 1],
         [2, 0],
         [2, 1]],

        [[0, 1],
         [0, 2],
         [1, 1],
         [1, 2]],

        [[0, 0],
         [0, 1],
         [1, 0],
         [1, 1]]])


In [41]:
import torch 
def rel_pos_indices(size):
    """Shifted relative (row, col) offsets between all cells of a 2-D grid.

    Args:
        size: side length of a square grid, or a ``(height, width)`` pair
            for a rectangular one.

    Returns:
        Integer tensor of shape ``(2, H*W, H*W)``. ``out[0, a, b]`` is the row
        offset and ``out[1, a, b]`` the column offset from cell ``a`` to cell
        ``b`` (cells flattened row-major), shifted by ``H - 1`` / ``W - 1`` so
        the values lie in ``[0, 2H - 2]`` / ``[0, 2W - 2]``.
    """
    try:
        height, width = size
    except TypeError:
        # A scalar cannot be unpacked: treat it as the side of a square grid.
        height = width = size

    # 'ij' is the layout torch.meshgrid uses when no indexing is given; naming
    # it explicitly avoids the warning about the missing argument.
    rows, cols = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
    pos = torch.stack((rows, cols)).flatten(1)   # 2, H * W
    rel_pos = pos[:, None, :] - pos[:, :, None]  # [c, a, b] = pos[c, b] - pos[c, a]
    rel_pos[0] += height - 1
    rel_pos[1] += width - 1
    return rel_pos  # 2, H * W, H * W

  
from einops import rearrange
def get_relative_position_matrix(size):
    """Shifted relative (row, col) offsets between all cells of a size x size grid.

    Args:
        size: side length of the square grid.

    Returns:
        Integer tensor of shape ``(2, size*size, size*size)``. ``out[0, a, b]``
        is the row offset and ``out[1, a, b]`` the column offset from cell
        ``a`` to cell ``b`` (cells flattened row-major), shifted by
        ``size - 1`` so every value lies in ``[0, 2*size - 2]``.
    """
    rows = torch.repeat_interleave(torch.arange(size), size, dim=0)  # 0,0,..,1,1,..
    cols = torch.arange(size).repeat(size)                           # 0,1,..,0,1,..
    coords = torch.stack((rows, cols))                               # (2, size*size)
    # Offsets are bounded by +/-(size - 1) by construction, so shifting by
    # size - 1 alone yields valid non-negative indices; no clamping is needed.
    distance_mat = coords[:, None, :] - coords[:, :, None]
    distance_mat += size - 1
    return distance_mat

def calc_rel_pos(n):
    """Relative (row, col) offsets between all cells of an n x n grid.

    Args:
        n: side length of the square grid.

    Returns:
        Integer tensor of shape ``(n*n, n*n, 2)``. ``rel_pos[a, b]`` is the
        (row, col) offset from cell ``a`` to cell ``b``, shifted by ``n - 1``
        so that every value lies in ``[0, 2n - 2]``.
    """
    # 'ij' is the layout torch.meshgrid uses when no indexing is given; naming
    # it explicitly avoids the warning about the missing argument.
    pos = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')  # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]                  # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                       # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos

# Print the three variants for a 2x2 grid, then check on a 10x10 grid that
# get_relative_position_matrix reproduces rel_pos_indices exactly (the result
# of torch.equal is displayed as the cell's output).
print(get_relative_position_matrix(2))
print(rel_pos_indices(2))
print(calc_rel_pos(2))
torch.equal(get_relative_position_matrix(10), rel_pos_indices(10))

tensor([[[1, 1, 2, 2],
         [1, 1, 2, 2],
         [0, 0, 1, 1],
         [0, 0, 1, 1]],

        [[1, 2, 1, 2],
         [0, 1, 0, 1],
         [1, 2, 1, 2],
         [0, 1, 0, 1]]])
tensor([[[1, 1, 2, 2],
         [1, 1, 2, 2],
         [0, 0, 1, 1],
         [0, 0, 1, 1]],

        [[1, 2, 1, 2],
         [0, 1, 0, 1],
         [1, 2, 1, 2],
         [0, 1, 0, 1]]])
tensor([[[1, 1],
         [1, 2],
         [2, 1],
         [2, 2]],

        [[1, 0],
         [1, 1],
         [2, 0],
         [2, 1]],

        [[0, 1],
         [0, 2],
         [1, 1],
         [1, 2]],

        [[0, 0],
         [0, 1],
         [1, 0],
         [1, 1]]])


True

In [39]:
# Step-by-step scratch version of the relative-position index computation
# for a 2 x 2 grid; the final tensor is displayed as the cell's result.
size = 2

x = torch.arange(size).repeat_interleave(size)  # row coordinate of each flattened cell
y = torch.arange(size).repeat(size)             # column coordinate of each flattened cell
distance_mat = torch.stack((x, y))              # (2, size*size)
# Pairwise differences: [c, a, b] = coord[c, b] - coord[c, a]
distance_mat = distance_mat.unsqueeze(1) - distance_mat.unsqueeze(2)
# (a rearrange to 'n i j -> (i j) n' was tried here and left disabled)
distance_mat = distance_mat.clamp(-size, size) + (size - 1)
distance_mat

tensor([[[1, 1, 2, 2],
         [1, 1, 2, 2],
         [0, 0, 1, 1],
         [0, 0, 1, 1]],

        [[1, 2, 1, 2],
         [0, 1, 0, 1],
         [1, 2, 1, 2],
         [0, 1, 0, 1]]])

In [37]:
import torch
import torch.nn as nn
from einops import rearrange
from torch import einsum

def setup_seed(seed):
    """Seed PyTorch's random number generators and request deterministic cuDNN.

    Args:
        seed: seed passed to the CPU generator and to all CUDA generators.

    NumPy's and Python's ``random`` generators are not seeded here.
    """
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    torch.backends.cudnn.deterministic = True
# Set the random seed
setup_seed(20)

def calc_rel_pos(n):
    """Relative (row, col) offsets between all cells of an n x n grid.

    Args:
        n: side length of the square grid.

    Returns:
        Integer tensor of shape ``(n*n, n*n, 2)``. ``rel_pos[a, b]`` is the
        (row, col) offset from cell ``a`` to cell ``b``, shifted by ``n - 1``
        so that every value lies in ``[0, 2n - 2]``.
    """
    # 'ij' is the layout torch.meshgrid uses when no indexing is given; naming
    # it explicitly avoids the warning about the missing argument.
    pos = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    pos = rearrange(torch.stack(pos), 'n i j -> (i j) n')  # [n*n, 2] pos[n] = (i, j)
    rel_pos = pos[None, :] - pos[:, None]                  # [n*n, n*n, 2] rel_pos[n, m] = (rel_i, rel_j)
    rel_pos += n - 1                                       # shift value range from [-n+1, n-1] to [0, 2n-2]
    return rel_pos

def calc_embeddings(dim_k, n):
    """Gather relative-position embeddings for every pair of cells on an n x n grid.

    The table holds consecutive ``arange`` values (the random initialisation
    is left disabled below), so the gathered result is deterministic.

    Args:
        dim_k: embedding depth (last dimension of the table).
        n: side length of the square grid.

    Returns:
        The table of shape ``(2n-1, 2n-1, dim_k)`` indexed with the
        (row, col) offsets from ``calc_rel_pos(n)``.
    """
    table_side = 2 * n - 1
    # table = nn.Parameter(torch.randn(table_side, table_side, dim_k))
    table = nn.Parameter(
        torch.arange(table_side * table_side * dim_k, dtype=float)
        .reshape(table_side, table_side, dim_k)
    )
    # The last axis of calc_rel_pos holds the (row, col) offset pair.
    row_idx, col_idx = calc_rel_pos(n).unbind(dim=-1)
    return table[row_idx, col_idx]

def get_relative_position_matrix(size):
    """Shifted relative (row, col) offsets between all cells of a size x size grid.

    Args:
        size: side length of the square grid.

    Returns:
        Integer tensor of shape ``(2, size*size, size*size)``. ``out[0, a, b]``
        is the row offset and ``out[1, a, b]`` the column offset from cell
        ``a`` to cell ``b`` (cells flattened row-major), shifted by
        ``size - 1`` so every value lies in ``[0, 2*size - 2]``.
    """
    rows = torch.repeat_interleave(torch.arange(size), size, dim=0)  # 0,0,..,1,1,..
    cols = torch.arange(size).repeat(size)                           # 0,1,..,0,1,..
    coords = torch.stack((rows, cols))                               # (2, size*size)
    # Offsets are bounded by +/-(size - 1) by construction, so shifting by
    # size - 1 alone yields valid non-negative indices; no clamping is needed.
    distance_mat = coords[:, None, :] - coords[:, :, None]
    distance_mat += size - 1
    return distance_mat

def get_embedding(dim_k=1, n=2):
    """Gather relative-position embeddings for every pair of cells on an n x n grid.

    The table holds consecutive ``arange`` values (the random initialisation
    is left disabled below), so the gathered result is deterministic.

    Args:
        dim_k: embedding depth (last dimension of the table).
        n: side length of the square feature map.

    Returns:
        The table of shape ``(2n-1, 2n-1, dim_k)`` indexed with the row and
        column offsets from ``get_relative_position_matrix(n)``.
    """
    table_side = 2 * n - 1  # n = im_h = im_w the feature map size
    # table = nn.Parameter(torch.randn(table_side, table_side, dim_k))  # 2*n-1 2*n-1 k
    table = nn.Parameter(
        torch.arange(table_side * table_side * dim_k, dtype=float)
        .reshape(table_side, table_side, dim_k)
    )
    # Unpacking along the first axis yields the row and column index tensors.
    row_idx, col_idx = get_relative_position_matrix(n)
    return table[row_idx, col_idx]

# Disabled equivalence check between calc_embeddings and get_embedding:
# print(calc_embeddings(1,10))
# print(get_embedding(1,10))
# torch.equal(calc_embeddings(1,10), get_embedding(1,10))
# Print the shape of the gathered embeddings for dim_k=1 on a 10x10 grid.
print(get_embedding(1,10).shape)

torch.Size([100, 100, 1])


In [9]:
# Build the same layer configuration with two implementations so they can be
# compared: the `lambda_networks` package and the local `layers.Lambda_layer`
# module. Both modules export the name `LambdaLayer`, so the local class is
# imported under an alias; without it the later import would rebind the name
# and both layers would be built from `lambda_networks`.
from layers.Lambda_layer import LambdaLayer as LocalLambdaLayer
from lambda_networks import LambdaLayer

# Reference implementation from the lambda_networks package.
layer1 = LambdaLayer(
    dim = 32,       # channels going in
    dim_out = 32,   # channels out
    n = 64,         # size of the receptive window - max(height, width)
    dim_k = 16,     # key dimension
    heads = 4,      # number of heads, for multi-query
    dim_u = 1       # 'intra-depth' dimension
)

# Local implementation; its signature is noted as:
#  dim, dim_k, n, dim_out, heads, r=None
layer2 = LocalLambdaLayer(
    dim = 32,       # channels going in
    dim_k = 16,     # key dimension
    n = 64,         # size of the receptive window - max(height, width)
    dim_out = 32,   # channels out
    heads = 4,      # number of heads, for multi-query
    # r = 23,
)

# Shared input: one 32-channel 64x64 feature map.
x = torch.randn(1, 32, 64, 64)


torch.Size([1, 32, 64, 64])
torch.Size([1, 32, 64, 64])


In [6]:
import torch.nn as nn
def get_relative_position_matrix(size):
    """Shifted relative (row, col) offsets between all cells of a size x size grid.

    Args:
        size: side length of the square grid.

    Returns:
        Integer tensor of shape ``(2, size*size, size*size)``. ``out[0, a, b]``
        is the row offset and ``out[1, a, b]`` the column offset from cell
        ``a`` to cell ``b`` (cells flattened row-major), shifted by
        ``size - 1`` so every value lies in ``[0, 2*size - 2]``.
    """
    rows = torch.repeat_interleave(torch.arange(size), size, dim=0)  # 0,0,..,1,1,..
    cols = torch.arange(size).repeat(size)                           # 0,1,..,0,1,..
    coords = torch.stack((rows, cols))                               # (2, size*size)
    # Offsets are bounded by +/-(size - 1) by construction, so shifting by
    # size - 1 alone yields valid non-negative indices; no clamping is needed.
    distance_mat = coords[:, None, :] - coords[:, :, None]
    distance_mat += size - 1
    return distance_mat

def get_embedding(dim_k, n):
    """Gather relative-position embeddings for every pair of cells on an n x n grid.

    A new randomly initialised table of shape ``(2n-1, 2n-1, dim_k)`` is
    created on every call and indexed with the row and column offsets from
    ``get_relative_position_matrix(n)``.

    Args:
        dim_k: embedding depth (last dimension of the table).
        n: side length of the square feature map.

    Returns:
        The gathered embeddings; the last dimension has size ``dim_k``.
    """
    table_side = 2 * n - 1  # n = im_h = im_w the feature map size
    table = nn.Parameter(torch.randn(table_side, table_side, dim_k))  # 2*n-1 2*n-1 k
    # Unpacking along the first axis yields the row and column index tensors.
    row_idx, col_idx = get_relative_position_matrix(n)
    return table[row_idx, col_idx]


# NOTE(review): this line raises "ValueError: too many values to unpack
# (expected 2)". get_embedding(16, 64) returns the gathered embeddings, whose
# last dimension is dim_k = 16, so unbind(dim=-1) yields 16 tensors rather
# than the two that `n, m` expects. Presumably the index matrix was meant to
# be unbound instead -- confirm the intent before changing it.
n, m = get_embedding(16,64).unbind(dim = -1)

ValueError: too many values to unpack (expected 2)

In [10]:
from einops import rearrange
from torch import einsum
import torch.nn as nn
def get_relative_position_matrix(size):
    """Shifted relative (row, col) offsets between all cells of a size x size grid.

    Args:
        size: side length of the square grid.

    Returns:
        Integer tensor of shape ``(2, size*size, size*size)``. ``out[0, a, b]``
        is the row offset and ``out[1, a, b]`` the column offset from cell
        ``a`` to cell ``b`` (cells flattened row-major), shifted by
        ``size - 1`` so every value lies in ``[0, 2*size - 2]``.
    """
    rows = torch.repeat_interleave(torch.arange(size), size, dim=0)  # 0,0,..,1,1,..
    cols = torch.arange(size).repeat(size)                           # 0,1,..,0,1,..
    coords = torch.stack((rows, cols))                               # (2, size*size)
    # Offsets are bounded by +/-(size - 1) by construction, so shifting by
    # size - 1 alone yields valid non-negative indices; no clamping is needed.
    distance_mat = coords[:, None, :] - coords[:, :, None]
    distance_mat += size - 1
    return distance_mat

def get_embedding(dim_k, n):
    """Gather relative-position embeddings for every pair of cells on an n x n grid.

    A new randomly initialised table of shape ``(2n-1, 2n-1, dim_k)`` is
    created on every call and indexed with the row and column offsets from
    ``get_relative_position_matrix(n)``.

    Args:
        dim_k: embedding depth (last dimension of the table).
        n: side length of the square feature map.

    Returns:
        The gathered embeddings; the last dimension has size ``dim_k``.
    """
    table_side = 2 * n - 1  # n = im_h = im_w the feature map size
    table = nn.Parameter(torch.randn(table_side, table_side, dim_k))  # 2*n-1 2*n-1 k
    # Unpacking along the first axis yields the row and column index tensors.
    row_idx, col_idx = get_relative_position_matrix(n)
    return table[row_idx, col_idx]

# Positional-lambda shape experiment. V follows the 'b m v' layout used in the
# einsum below, with m = 64*64 context positions and v = 32 // 4 value channels.
V = torch.randn(1, 64*64, 32//4)
embeddings = get_embedding(16, 64) # n m k
# Contract over the context positions m: (n, m, k) x (b, m, v) -> (b, n, k, v).
λp = einsum('n m k, b m v -> b n k v', embeddings, V)


torch.Size([1, 4096, 16, 8])
