In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# config
random_seed = 0
num_pos = 6 # 64
pad_id = 0
eos_id = 1
heads = 8
vocab_size = 128 # 64000
width = 768
batch_size = 10

# initialization
token_embedding = nn.Embedding(vocab_size, width)
cls_emb = nn.Parameter(torch.empty(width))
nn.init.normal_(cls_emb, std=0.01)

# create random text input
text = torch.randint(0, vocab_size, (batch_size, num_pos),generator=torch.Generator().manual_seed(random_seed))
eos_indices = torch.randint(0, num_pos, (batch_size, 1), generator=torch.Generator().manual_seed(random_seed))
text = text.scatter_(1, eos_indices, eos_id)
for i in range(batch_size): text[i, eos_indices[i, 0]+1:] = pad_id
print(f"text.shape: {text.shape}")
print(f"text: {text}")

text.shape: torch.Size([10, 6])
text: tensor([[ 44,  47,   1,   0,   0,   0],
        [ 67, 103,   9,   1,   0,   0],
        [ 36,  87,  70,  88,  88,   1],
        [  1,   0,   0,   0,   0,   0],
        [ 88,   1,   0,   0,   0,   0],
        [  9,  20, 115,   1,   0,   0],
        [126,   1,   0,   0,   0,   0],
        [ 88,   1,   0,   0,   0,   0],
        [ 14,   1,   0,   0,   0,   0],
        [127,  32,  31,   1,   0,   0]])


In [3]:
def build_causal_mask():
	# lazily create causal attention mask, with full attention between the tokens
	# pytorch uses additive attention mask; fill with -inf
	mask = torch.empty(num_pos+1, num_pos+1)
	mask.fill_(float("-inf"))
	mask.triu_(1)  # zero out the lower diagonal
	return mask

def _expand_token(token, batch_size: int):
    return token.view(1, 1, -1).expand(batch_size, -1, -1)

def build_cls_mask(text, cast_dtype: torch.dtype):
	cls_mask = (text != pad_id).unsqueeze(1) # (batch_size, 1, num_pos)
	cls_mask = F.pad(cls_mask, (0, 1, cls_mask.shape[2], 0), value=True) # (batch_size, num_pos+1, num_pos+1)
	additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
	additive_mask.fill_(0)
	additive_mask.masked_fill_(~cls_mask, float("-inf"))
	additive_mask = torch.repeat_interleave(additive_mask, heads, 0)
	return additive_mask

In [4]:
seq_len = text.shape[1] # num_pos
x = token_embedding(text) # (batch_size, num_pos, width)
attn_mask = build_causal_mask() # (num_pos, num_pos)

# appending cls_emb to the txt_emb
seq_len += 1 # num_pos+1
x = torch.cat([x, _expand_token(cls_emb, x.shape[0])], dim=1) # (batch_size, num_pos+1, width)
cls_mask = build_cls_mask(text, x.dtype) # (batch_size, num_pos+1, num_pos+1)

In [11]:
text[1],attn_mask,cls_mask[1],attn_mask+cls_mask[1]

(tensor([ 67, 103,   9,   1,   0,   0]),
 tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf],
         [0., 0., 0., 0., 0., -inf, -inf],
         [0., 0., 0., 0., 0., 0., -inf],
         [0., 0., 0., 0., 0., 0., 0.]]),
 tensor([[0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., -inf, -inf, -inf, 0.]]),
 tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf],
         [0., 0., 0., 0., 0., -inf, -inf],
         [0., 0., 0., 0., 0., 0., -inf],
         [0., 0., 0., -inf, -inf, -inf, 0.]]))

In [7]:
attn_mask[None, :seq_len, :seq_len].shape, cls_mask[:, :seq_len, :seq_len].shape

(torch.Size([1, 7, 7]), torch.Size([80, 7, 7]))

In [8]:
final_attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] # (batch_size, num_pos+1, num_pos+1)

In [9]:
final_attn_mask.shape

torch.Size([80, 7, 7])

In [10]:
"""
# build_cls_mask
print(f"text.shape = {text.shape}")
print()
cls_mask = (text != pad_id).unsqueeze(1)
print(f"cls_mask.shape = {cls_mask.shape}")
print(f"cls_mask[0] = {cls_mask[0]}")
print()
cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)
print(f"cls_mask.shape = {cls_mask.shape}")
print(f"cls_mask[0] = {cls_mask[0]}")
print()
additive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)
additive_mask.fill_(0)
additive_mask.masked_fill_(~cls_mask, float("-inf"))
print(f"additive_mask.shape = {additive_mask.shape}")
print(f"additive_mask[0] = {additive_mask[0]}")
print()
additive_mask = torch.repeat_interleave(additive_mask, heads, 0)
print(f"additive_mask.shape = {additive_mask.shape}")
print(f"additive_mask[0] = {additive_mask[0]}")
print()
"""

'\n# build_cls_mask\nprint(f"text.shape = {text.shape}")\nprint()\ncls_mask = (text != pad_id).unsqueeze(1)\nprint(f"cls_mask.shape = {cls_mask.shape}")\nprint(f"cls_mask[0] = {cls_mask[0]}")\nprint()\ncls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True)\nprint(f"cls_mask.shape = {cls_mask.shape}")\nprint(f"cls_mask[0] = {cls_mask[0]}")\nprint()\nadditive_mask = torch.empty(cls_mask.shape, device=cls_mask.device)\nadditive_mask.fill_(0)\nadditive_mask.masked_fill_(~cls_mask, float("-inf"))\nprint(f"additive_mask.shape = {additive_mask.shape}")\nprint(f"additive_mask[0] = {additive_mask[0]}")\nprint()\nadditive_mask = torch.repeat_interleave(additive_mask, heads, 0)\nprint(f"additive_mask.shape = {additive_mask.shape}")\nprint(f"additive_mask[0] = {additive_mask[0]}")\nprint()\n'