In [1]:
# imports
from pathlib import Path
import sys  

# Resolve the directory two levels above the notebook's working directory
# (parents[0] is the immediate parent, parents[1] the grandparent).
# Presumably this is the project root that holds the `models` package
# imported further down - verify against the repository layout.
parent_dir = str(Path().resolve().parents[1])

# Put it at the front of sys.path. Any existing entry is removed first so that
# re-running this cell does not accumulate duplicate entries.
if parent_dir in sys.path:
    sys.path.remove(parent_dir)
sys.path.insert(0, parent_dir)

In [2]:
import torch 
import math
from torch import nn
import torch.nn.functional as F


In [None]:
from models.components.functional import (generate_square_subsequent_mask, 
                               create_causal_mask,
                               create_cross_attention_mask)

def clear_nan(tensor: torch.Tensor) -> torch.Tensor:
    """Return a new tensor equal to ``tensor`` with every NaN entry set to 0.

    Entries that are not NaN (including +/-inf) are passed through unchanged.
    The input tensor itself is not modified.
    """
    nan_positions = torch.isnan(tensor)
    zero_fill = torch.zeros_like(tensor)
    return torch.where(nan_positions, zero_fill, tensor)


In [4]:
# Hyperparameters for the attention experiments in this notebook.
embed_dim = 128       # embedding width; split evenly across heads, so it must be divisible by the head count
num_heads = 8
hidden_dim = 200
max_len = 5000
batch_size = 32       # placeholder: the toy-batch cell rebinds this from the embedded tensor's shape
d_ff = hidden_dim
num_head = num_heads  # alias read by the attention cells; tied to num_heads so the two cannot drift apart
# NOTE(review): the two lengths below do not match the toy batches defined in
# the next cell (7 source / 8 target positions) and are not read by any of the
# visible cells - confirm whether they are still needed.
src_seq_len = 4
tgt_seq_len = 5

In [5]:
# Seed the global RNG so the randomly initialised embedding / projection
# weights (and the dropout applied in later cells) repeat across
# Restart & Run All.
torch.manual_seed(0)

# Toy source batch of token ids: 3 sequences, 7 positions each.
# NOTE(review): 0 looks like padding and 1 / 2 like start / end markers -
# confirm against the tokenizer.
src = torch.tensor(
    [[1, 3, 4, 2, 0, 0, 0],
     [1, 3, 4, 4, 2, 0, 0],
     [1, 2, 0, 0, 0, 0, 0]]
)

# Toy target batch of token ids: 3 sequences, 8 positions each.
tgt = torch.tensor(
    [[1, 3, 2, 0, 0, 0, 0, 0],
     [1, 3, 3, 3, 2, 0, 0, 0],
     [1, 3, 2, 0, 0, 0, 0, 0]]
)

print(src.shape, tgt.shape)

# A single embedding table is shared by the source and target sides.
embedding = nn.Embedding(10**4, embedding_dim=embed_dim)
src_embedding = embedding(src)
tgt_embedding = embedding(tgt)

print(src_embedding.shape, tgt_embedding.shape)

# Rebind the size globals from the actual tensors; the attention cells below
# read batch_size, seq_len_tgt and embed_dim when merging heads.
batch_size, seq_len_src, embed_dim = src_embedding.shape
_, seq_len_tgt, _ = tgt_embedding.shape

# One projection is reused for Q, K and V, so K and V hold identical values.
linear = nn.Linear(embed_dim, embed_dim)


def _project_to_heads(x: torch.Tensor) -> torch.Tensor:
    """Apply ``linear`` and reshape (batch, seq, embed) -> (batch, heads, seq, embed // heads)."""
    n_batch, n_tokens, width = x.shape
    return linear(x).view(n_batch, n_tokens, num_head, width // num_head).transpose(1, 2)


Q = _project_to_heads(tgt_embedding)
K = _project_to_heads(src_embedding)
V = _project_to_heads(src_embedding)

print(Q.shape, V.shape)

torch.Size([3, 7]) torch.Size([3, 8])
torch.Size([3, 7, 128]) torch.Size([3, 8, 128])
torch.Size([3, 8, 8, 16]) torch.Size([3, 8, 7, 16])


In [6]:
# Print the three mask helpers side by side for the toy batch.
src_len = src.shape[-1]

# Additive causal mask (recorded output: 0 on/below the diagonal, -inf above it).
print(generate_square_subsequent_mask(src_len))

# 0/1 causal mask (recorded output: lower-triangular ones).
print(create_causal_mask(src_len))

# Boolean mask built from the target and source token ids; the cross-attention
# cell converts its True entries to -inf, i.e. True marks a blocked pair.
print(create_cross_attention_mask(tgt, src))

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0.]])
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1.]])
tensor([[[[False, False, False, False,  True,  True,  True],
          [False, False, False, False,  True,  True,  True],
          [False, False, False, False,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True,  True,  True

In [7]:
# Self-Attention

# Scaled dot-product scores; Q serves as query, key and value in this cell.
head_dim = Q.size(-1)
scores = (Q @ Q.transpose(-2, -1)) / math.sqrt(head_dim)
print('ScoresShape:', scores.shape)

# Additive causal mask sized to the target length.
attn_mask = generate_square_subsequent_mask(tgt.shape[-1])
print('AttentionMaskShape:',attn_mask.shape)

# The mask is added before the softmax. Dropout (p=0.1) is forced on, so the
# surviving weights are rescaled by 1 / (1 - p) and rows need not sum to 1.
attn_probs = torch.softmax(scores + attn_mask, dim=-1)
attn_weight = F.dropout(attn_probs, p=0.1, training=True)
print('AttentionWeightShape:', attn_weight.shape)

# Weighted sum of the values, then merge the heads back into one embedding axis.
merged_heads = (attn_weight @ Q).transpose(1, 2).contiguous()
context = merged_heads.view(batch_size, seq_len_tgt, embed_dim)
print('ContextShape:', context.shape)
print(attn_weight[0][0])

ScoresShape: torch.Size([3, 8, 8, 8])
AttentionMaskShape: torch.Size([8, 8])
AttentionWeightShape: torch.Size([3, 8, 8, 8])
ContextShape: torch.Size([3, 8, 128])
tensor([[1.1111, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2879, 0.8232, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2624, 0.1422, 0.7065, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2580, 0.2004, 0.2273, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1866, 0.1449, 0.1644, 0.3076, 0.3076, 0.0000, 0.0000, 0.0000],
        [0.1461, 0.1135, 0.1287, 0.2409, 0.2409, 0.0000, 0.0000, 0.0000],
        [0.1201, 0.0933, 0.1058, 0.1980, 0.1980, 0.1980, 0.1980, 0.0000],
        [0.1019, 0.0792, 0.0898, 0.1680, 0.1680, 0.1680, 0.1680, 0.1680]],
       grad_fn=<SelectBackward0>)


In [8]:
# Probe PyTorch's private mask-canonicalisation helper with no mask supplied.
# The notebook records no output for this call.
mask = None
mask_dtype = F._none_or_dtype(mask)

F._canonical_mask(mask=mask, mask_name='attn_mask', other_type=mask_dtype,
                  other_name="mask", target_type=torch.float)


In [None]:
# Cross-Attention
# Target-derived queries (Q) attend over source-derived keys / values (K, V).
scores = Q @ K.transpose(-2,-1) / math.sqrt(Q.size(-1))
print('ScoresShape:', scores.shape)

# Convert the boolean mask into an additive float mask (True -> -inf).
# `other_type` is the dtype of a companion mask used only for a dtype
# consistency check. There is no companion mask here, so pass None instead of
# reading the dtype of whatever `attn_mask` an earlier cell left behind.
attn_mask = F._canonical_mask(
    mask=create_cross_attention_mask(tgt, src),
    mask_name='attn_mask',
    other_type=None,
    other_name="mask",
    target_type=torch.float,
)
print('AttentionMaskShape:',attn_mask.shape)

# A row whose entries are all masked is all -inf, so its softmax is NaN;
# clear_nan replaces those NaNs with 0 after dropout (p=0.1, forced on).
attn_weight = clear_nan(
    F.dropout(torch.softmax(scores + attn_mask, dim=-1), p=0.1, training=True)
)
print('AttentionWeightShape:', attn_weight.shape)

# Weighted sum of the values, then merge the heads back into one embedding axis.
context = (
    (attn_weight @ V).transpose(1, 2).contiguous().view(batch_size, seq_len_tgt, embed_dim)
)
print('ContextShape:', context.shape)
print(attn_weight[0][0])

ScoresShape: torch.Size([3, 8, 8, 7])
AttentionMaskShape: torch.Size([3, 1, 8, 7])
AttentionWeightShape: torch.Size([3, 8, 8, 7])
ContextShape: torch.Size([3, 8, 128])
tensor([-0.2351,  0.0437,  0.3506,  0.1932,  0.2827, -0.0188, -0.5101,  0.0163,
         0.7627,  0.2617,  0.2117, -0.3608,  0.0725, -0.0302,  0.0799,  0.3650,
         0.2956, -0.4211, -0.0770, -0.2030, -0.0968,  0.1160,  0.2713, -0.5688,
         0.6390,  0.5940, -0.4506, -0.1929, -0.4062,  0.5581, -0.0551,  0.3531,
         0.5743,  0.6006,  0.5474,  0.0894, -0.0363,  0.0387,  0.8682, -0.2338,
         0.6029,  0.5158, -0.0798,  0.3372, -0.0107, -0.5761,  0.4835, -0.2120,
         0.2287,  0.2663, -0.3287, -0.2782, -0.4811,  0.2382,  0.0768,  0.3470,
        -0.3056,  0.2837, -0.0549, -0.0228, -0.0104,  0.1383, -0.3286,  0.4399,
        -0.6087,  0.0828, -0.3126, -0.0319,  0.4788, -0.9044,  0.5025,  0.3697,
        -0.7963, -0.4134,  0.0427, -0.4408, -0.1149,  0.3260, -0.6283, -0.4193,
         0.0424, -0.3924,  0.015

