In [1]:
# imports
from pathlib import Path
import sys  

# Get my_package directory path from Notebook
parent_dir = str(Path().resolve().parents[1])

# Add to sys.path
sys.path.insert(0, parent_dir)

In [2]:
import torch 
import math
from torch import nn
import torch.nn.functional as F


In [3]:
from src.transformers.models.functionals import (
        generate_square_subsequent_mask, 
        mask_fill_combined,    
        create_causal_mask,                                     
        create_cross_attention_mask)

from src.transformers.models.attentions import MultiHeadAttention as MhA
MhA.scaled_dot_product_attention

def clear_nan(tensor: torch.Tensor):
        return torch.where(torch.isnan(tensor), 
                               torch.zeros_like(tensor), 
                               tensor)


In [4]:
embed_dim = 128
num_heads = 8
hidden_dim = 200
max_len = 5000
batch_size = 32
d_ff = hidden_dim
num_head = 8
src_seq_len = 4
tgt_seq_len = 5

In [5]:
src = torch.tensor(
    [[ 1,3,4,2,0,0,0],
    [ 1,3,4,4,2,0,0],
    [ 1,2,0,0,0,0,0]]
)
    
# en_tensor = torch.randint(1, 10**4,[batch_size, seq_len], dtype=int)
tgt = torch.tensor(
    [[ 1],
    [ 1],
    [ 1]]
)

print(src.shape, tgt.shape)

embedding = nn.Embedding(10**4, embedding_dim=embed_dim)
src_embedding = embedding(src)
tgt_embedding = embedding(tgt)

print(src_embedding.shape, tgt_embedding.shape)

batch_size, seq_len_src, embed_dim = src_embedding.shape
_, seq_len_tgt, _ = tgt_embedding.shape

linear = nn.Linear(embed_dim, embed_dim)

Q = linear(tgt_embedding).view(batch_size, seq_len_tgt, num_head, embed_dim//num_head).transpose(1, 2)
K = linear(src_embedding).view(batch_size, seq_len_src, num_head, embed_dim//num_head).transpose(1, 2)
V = linear(src_embedding).view(batch_size, seq_len_src, num_head, embed_dim//num_head).transpose(1, 2)

print(Q.shape, V.shape)

torch.Size([3, 7]) torch.Size([3, 1])
torch.Size([3, 7, 128]) torch.Size([3, 1, 128])
torch.Size([3, 8, 1, 16]) torch.Size([3, 8, 7, 16])


In [6]:
print(generate_square_subsequent_mask(src.shape[-1]))
print(create_cross_attention_mask(tgt, src))

tensor([[ 0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09, -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09,
         -1.0000e+09, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -1.0000e+09, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00, -1.0000e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          0.0000e+00,  0.0000e+00]])
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -1.0000e+09,
           -1.0000e+09, -1.0000e+09]]],


        [[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -1.0000e+09, -1.0000e+09

In [7]:
tgt_key_padding_mask = create_cross_attention_mask(tgt, tgt)
src_key_padding_mask = create_cross_attention_mask(src, src)
memory_key_padding_mask = src_key_padding_mask

print(tgt_key_padding_mask.shape)
print(src_key_padding_mask.shape)

torch.Size([3, 1, 1, 1])
torch.Size([3, 1, 7, 7])


In [12]:
scores = K @ K.transpose(-2,-1) / math.sqrt(K.size(-1))
print('ScoresShape:', scores.shape)

src_key_padding_mask = ( src == 0)

mask_fill_combined(
    attention_scores=scores,
    attn_mask=create_causal_mask(seq_len=seq_len_src),
    padding_mask=src_key_padding_mask
)


ScoresShape: torch.Size([3, 8, 7, 7])


tensor([[[[-1.0000e+09, -1.0000e+09, -5.0582e-02,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          ...,
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09]],

         [[-1.0000e+09, -1.0000e+09,  8.9814e-02,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          [-1.0000e+09, -1.0000e+09, -1.0000e+09,  ..., -1.0000e+09,
           -1.0000e+09, -1.0000e+09],
          ...,
     

In [9]:
scores = K @ K.transpose(-2,-1) / math.sqrt(K.size(-1))
print('ScoresShape:', scores.shape)

src_key_padding_mask = ( src == 0)

print('AttentionMaskShape:',src_key_padding_mask.unsqueeze(1).unsqueeze(2).shape)
scores = scores.masked_fill(src_key_padding_mask.unsqueeze(1).unsqueeze(2), -1e9)

print(scores.shape)

ScoresShape: torch.Size([3, 8, 7, 7])
AttentionMaskShape: torch.Size([3, 1, 1, 7])
torch.Size([3, 8, 7, 7])


In [10]:
# Self-Attention Encoder

scores = K @ K.transpose(-2,-1) / math.sqrt(K.size(-1))
print('ScoresShape:', scores.shape)


print('AttentionMaskShape:',src_key_padding_mask.shape)

attn_weight = torch.dropout(torch.softmax(scores + src_key_padding_mask, -1), .1, train=True)
print('AttentionWeightShape:', attn_weight.shape)
context = (attn_weight @ V ).transpose(1, 2).contiguous().view(batch_size, seq_len_src, embed_dim)
print('ContextShape:', context.shape)
print(attn_weight[0][0])



ScoresShape: torch.Size([3, 8, 7, 7])
AttentionMaskShape: torch.Size([3, 7])


RuntimeError: The size of tensor a (7) must match the size of tensor b (3) at non-singleton dimension 2

In [None]:
print(
    MhA.scaled_dot_product_attention(
        K, K, V,
        attn_mask=src_key_padding_mask
    )[0]
    .transpose(1, 2)
    .contiguous()
    .view(batch_size, seq_len_src, embed_dim)[0][0]
)

print(
    MhA.scaled_dot_product_attention(
        Q, Q, Q,
        attn_mask=tgt_key_padding_mask
    )[0]
    .transpose(1, 2)
    .contiguous()
    .view(batch_size, seq_len_tgt, embed_dim)[0][0]
)

print(
    MhA.scaled_dot_product_attention(
        Q, K, V,
        attn_mask=memory_key_padding_mask
    )[0]
    .transpose(1, 2)
    .contiguous()
    .view(batch_size, seq_len_tgt, embed_dim)[0][0]
)

tensor([ 0.1347,  0.0133, -0.1113,  0.5884, -0.4262, -0.3468, -0.8244, -0.1808,
         0.6137, -0.1443,  0.0688, -0.1342, -0.5640, -0.5474, -0.0077,  0.0428,
         0.1654,  0.3951, -0.0045, -0.0851, -0.0373,  0.3932, -0.1892, -0.1142,
        -0.5534,  0.2787,  0.5795, -0.1955, -0.6816,  0.1760, -0.3149,  0.2758,
        -0.5144, -0.7355,  0.0139, -0.5996, -0.7749,  0.0522, -0.0306, -0.4824,
         0.7699, -0.0389, -0.3984, -0.0889, -0.3997, -0.4470, -0.4969,  0.0485,
        -0.5229,  0.3915, -0.1164,  0.1595, -0.2984, -0.4077,  0.2486,  0.6055,
        -0.0638,  0.5479,  0.3315, -0.2772,  0.2612, -0.1980,  0.0249, -0.6667,
         0.2648,  0.0666,  0.2216, -0.3258, -0.0243,  0.2828, -0.4516, -0.1699,
         0.3944, -0.2315,  0.2097, -0.1331, -0.1458, -0.0117, -0.2269,  0.0645,
        -0.0752, -0.5169,  0.2199,  0.2410,  0.0887, -0.1910, -0.7426,  0.3091,
        -0.1739,  0.5805, -0.4018,  0.2240, -0.1021, -0.1421, -0.5046, -0.1238,
        -0.4964, -0.1431,  0.6364, -0.30

In [None]:
(attn_weight @ V).shape

torch.Size([3, 8, 7, 16])

In [None]:
# Self-Attention 

scores = Q @ Q.transpose(-2,-1) / math.sqrt(Q.size(-1))
print('ScoresShape:', scores.shape)

attn_mask = generate_square_subsequent_mask(tgt.shape[-1])
print('AttentionMaskShape:',attn_mask.shape)

attn_weight = torch.dropout(torch.softmax(scores + attn_mask, -1), .1, train=True)
print('AttentionWeightShape:', attn_weight.shape)
context = (attn_weight @ Q ).transpose(1, 2).contiguous().view(batch_size, seq_len_tgt, embed_dim)
print('ContextShape:', context.shape)
print(attn_weight[0][0])

ScoresShape: torch.Size([3, 8, 1, 1])
AttentionMaskShape: torch.Size([1, 1])
AttentionWeightShape: torch.Size([3, 8, 1, 1])
ContextShape: torch.Size([3, 1, 128])
tensor([[1.1111]], grad_fn=<SelectBackward0>)


In [None]:
# Cross-Attention
scores = Q @ K.transpose(-2,-1) / math.sqrt(Q.size(-1))
print('ScoresShape:', scores.shape)

attn_mask = F._canonical_mask(
    mask=create_cross_attention_mask(tgt, src),
    mask_name='attn_mask',
    other_type=F._none_or_dtype(attn_mask),
    other_name="mask",
    target_type=torch.float,

)    
print('AttentionMaskShape:',attn_mask.shape)

attn_weight = clear_nan(
    torch.dropout(torch.softmax(scores + attn_mask, -1), .1, train=True)
)
print('AttentionWeightShape:', attn_weight.shape)

context = (
    (attn_weight @ V ).transpose(1, 2).contiguous().view(batch_size, seq_len_tgt, embed_dim)
)
print('ContextShape:', context.shape)
print(attn_weight[0][0])

ScoresShape: torch.Size([3, 8, 1, 7])
AttentionMaskShape: torch.Size([3, 1, 1, 7])
AttentionWeightShape: torch.Size([3, 8, 1, 7])
ContextShape: torch.Size([3, 1, 128])
tensor([[0.5467, 0.2879, 0.0000, 0.1352, 0.0000, 0.0000, 0.0000]],
       grad_fn=<SelectBackward0>)


In [None]:
attn_mask

tensor([[[[0., 0., 0., 0., -inf, -inf, -inf]]],


        [[[0., 0., 0., 0., 0., -inf, -inf]]],


        [[[0., 0., -inf, -inf, -inf, -inf, -inf]]]])