## 1、参数 key_padding_mask

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F

image1 = torch.rand(3, 3, 2)  # c, h, w
image2 = torch.rand(3, 2, 3)  # c, h, w
max_size = [max(n1, n2) for n1, n2 in zip(image1.shape, image2.shape)]

batch_shape = [2] + max_size
b, c, h, w = batch_shape

batch_image = torch.zeros(batch_shape)
key_padding_mask = torch.full((b, h, w), True, dtype=torch.bool)

batch_image[0, :image1.shape[0], :image1.shape[1], :image1.shape[2]] = image1
key_padding_mask[0, :image1.shape[1], :image1.shape[2]] = False

batch_image[1, :image2.shape[0], :image2.shape[1], :image2.shape[2]] = image2
key_padding_mask[1, :image2.shape[1], :image2.shape[2]] = False
print(key_padding_mask)

batch_image = batch_image.flatten(-2).permute(0, 2, 1)
key_padding_mask = key_padding_mask.flatten(-2)


q = k = v = batch_image
self_attn = nn.MultiheadAttention(embed_dim=3, num_heads=1, batch_first=True)
attn_output, attn_output_weights = self_attn(q, k, v, key_padding_mask=key_padding_mask)

print(attn_output.shape)   # torch.Size([2, 9, 3])
print(attn_output_weights.shape)   # torch.Size([2, 9, 9])

tensor([[[False, False,  True],
         [False, False,  True],
         [False, False,  True]],

        [[False, False, False],
         [False, False, False],
         [ True,  True,  True]]])
torch.Size([2, 9, 3])
torch.Size([2, 9, 9])


## 2、参数 att_mask

In [35]:
import torch
import torch.nn as nn


tgt = torch.rand(2, 4, 3)  # b, len_q, embed_dim
memory = torch.rand(2, 5, 3)  # b, len_k, embed_dim

b, len_q, _ = tgt.shape
_, len_k, _ = memory.shape

attn_mask = torch.full([len_q, len_k], False, dtype=torch.bool)
attn_mask[1:, :2] = True
print(attn_mask)

self_attn = nn.MultiheadAttention(embed_dim=3, num_heads=1, batch_first=True)
attn_output, attn_output_weights = self_attn(tgt, memory, memory, attn_mask=attn_mask)

print(attn_output.shape)
print(attn_output_weights.shape)

tensor([[False, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True, False, False, False]])
torch.Size([2, 4, 3])
torch.Size([2, 4, 5])


## 3、key_padding_mask 和 attn_mask 可以等效使用

In [36]:
q = k = v = torch.rand(3, 3)
key_padding_mask = torch.tensor([False, False, True], dtype=torch.bool)
attn_mask = torch.tensor([[False, False, True], [False, False, True], [False, False, True]], dtype=torch.bool)

self_attn = nn.MultiheadAttention(embed_dim=3, num_heads=1, batch_first=True)
attn_output1, attn_output_weights1 = self_attn(q, k, v, key_padding_mask=key_padding_mask)
attn_output2, attn_output_weights2 = self_attn(q, k, v, attn_mask=attn_mask)
print(attn_output1)
print(attn_output2)
print(attn_output_weights1)
print(attn_output_weights2)

tensor([[-0.1327, -0.3230, -0.0430],
        [-0.1277, -0.3282, -0.0476],
        [-0.1281, -0.3278, -0.0473]], grad_fn=<SqueezeBackward1>)
tensor([[-0.1327, -0.3230, -0.0430],
        [-0.1277, -0.3282, -0.0476],
        [-0.1281, -0.3278, -0.0473]], grad_fn=<SqueezeBackward1>)
tensor([[0.5110, 0.4890, 0.0000],
        [0.4933, 0.5067, 0.0000],
        [0.4947, 0.5053, 0.0000]], grad_fn=<SqueezeBackward1>)
tensor([[0.5110, 0.4890, 0.0000],
        [0.4933, 0.5067, 0.0000],
        [0.4947, 0.5053, 0.0000]], grad_fn=<SqueezeBackward1>)


# 4、小实验

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F

image1 = torch.rand(3, 3, 2)  # c, h, w
image2 = torch.rand(3, 2, 3)  # c, h, w
max_size = [max(n1, n2) for n1, n2 in zip(image1.shape, image2.shape)]

batch_shape = [2] + max_size
b, c, h, w = batch_shape

batch_image = torch.zeros(batch_shape)
padding_mask = torch.ones((b, h, w))

batch_image[0, :image1.shape[0], :image1.shape[1], :image1.shape[2]] = image1
padding_mask[0, :image1.shape[1], :image1.shape[2]] = False

batch_image[1, :image2.shape[0], :image2.shape[1], :image2.shape[2]] = image2
padding_mask[1, :image2.shape[1], :image2.shape[2]] = False

batch_image = batch_image.flatten(-2).permute(0, 2, 1)
padding_mask = padding_mask.flatten(-2)
q = k = v = batch_image

self_attn = nn.MultiheadAttention(embed_dim=3, num_heads=1, batch_first=True)
attn_output, attn_output_weights = self_attn(q, k, v, key_padding_mask=padding_mask)

print(attn_output)
print(attn_output.shape)   # torch.Size([2, 9, 3])
print(attn_output_weights.shape)   # torch.Size([2, 9, 9])

# --------------------------------------------------------------------------------------

parm_info = [{param_name:param.shape} for param_name, param in self_attn.named_parameters()]
print(parm_info)

input_linear = nn.Linear(3, 3)
input_linear.weight = nn.Parameter(self_attn.in_proj_weight[6:, :])
input_linear.bias = nn.Parameter(self_attn.in_proj_bias[6:])
v = v.contiguous().view(b*h*w, c)
v = input_linear(v).view(b, h*w, c)

o = torch.bmm(attn_output_weights, v)
o = o.contiguous().view(b*h*w, c)

output_linear = nn.Linear(3, 3)
output_linear.weight = self_attn.out_proj.weight
output_linear.bias = self_attn.out_proj.bias
o = output_linear(o)

o = o.view(b, h*w, c)
print(o)

tensor([[[-0.0569, -0.0939, -0.0323],
         [-0.0557, -0.0919, -0.0316],
         [-0.0530, -0.0872, -0.0294],
         [-0.0560, -0.0924, -0.0320],
         [-0.0599, -0.0992, -0.0350],
         [-0.0530, -0.0872, -0.0294],
         [-0.0581, -0.0960, -0.0337],
         [-0.0577, -0.0955, -0.0335],
         [-0.0530, -0.0872, -0.0294]],

        [[-0.0552, -0.0888, -0.0238],
         [-0.0625, -0.1012, -0.0288],
         [-0.0648, -0.1051, -0.0301],
         [-0.0594, -0.0959, -0.0266],
         [-0.0618, -0.1001, -0.0285],
         [-0.0596, -0.0963, -0.0268],
         [-0.0554, -0.0891, -0.0238],
         [-0.0554, -0.0891, -0.0238],
         [-0.0554, -0.0891, -0.0238]]], grad_fn=<TransposeBackward0>)
torch.Size([2, 9, 3])
torch.Size([2, 9, 9])
[{'in_proj_weight': torch.Size([9, 3])}, {'in_proj_bias': torch.Size([9])}, {'out_proj.weight': torch.Size([3, 3])}, {'out_proj.bias': torch.Size([3])}]
tensor([[[-0.0569, -0.0939, -0.0323],
         [-0.0557, -0.0919, -0.0316],
         