In [1]:
# imports
from pathlib import Path
import sys  

# Make the project importable from this notebook: take the directory two
# levels above the current working directory (presumably the one that holds
# the `models` package imported in a later cell) and give it top priority on
# sys.path.
parent_dir = str(Path().resolve().parents[1])

# Drop an existing entry before inserting so that re-running this cell does
# not stack duplicate copies of the same directory on sys.path.
if parent_dir in sys.path:
    sys.path.remove(parent_dir)
sys.path.insert(0, parent_dir)

In [2]:
import torch 
import math
from torch import nn
import torch.nn.functional as F


In [3]:
from models.components.functional import (generate_square_subsequent_mask, 
                               create_causal_mask,
                               create_cross_attention_mask)

def clear_nan(tensor: torch.Tensor) -> torch.Tensor:
    """Return a new tensor equal to ``tensor`` with every NaN entry set to zero.

    Entries that are not NaN (including +/-inf) are passed through unchanged,
    and the input tensor itself is not modified.
    """
    nan_positions = tensor.isnan()
    zeros = torch.zeros_like(tensor)
    return torch.where(nan_positions, zeros, tensor)


In [4]:
# Hyperparameters shared by the cells below.
embed_dim = 128        # width of each token embedding
num_heads = 8          # number of attention heads
hidden_dim = 200
d_ff = hidden_dim      # alias of hidden_dim
max_len = 5000         # presumably a maximum sequence length; unused in the cells shown here
batch_size = 32        # NOTE: a later cell rebinds this to the toy batch's actual size

# Later cells spell the head count `num_head`; derive it from `num_heads`
# instead of repeating the literal so the two names cannot drift apart.
num_head = num_heads

# The per-head reshape done later splits embed_dim into num_heads equal parts.
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

# Nominal sequence lengths (not referenced by the cells shown here).
src_seq_len = 4
tgt_seq_len = 5

In [10]:
# Seed the global RNG so the randomly initialised embedding / projection
# weights (and the dropout draws in later cells) are identical on every
# Restart & Run All.
torch.manual_seed(0)

# Toy batch of source token ids. The trailing 0s look like padding: they are
# the positions flagged by create_cross_attention_mask in a later cell.
src = torch.tensor(
    [[1, 3, 4, 2, 0, 0, 0],
     [1, 3, 4, 4, 2, 0, 0],
     [1, 2, 0, 0, 0, 0, 0]]
)

# One token per target sequence.
tgt = torch.tensor(
    [[1],
     [1],
     [1]]
)

print(src.shape, tgt.shape)

# Source and target share a single embedding table.
vocab_size = 10**4
embedding = nn.Embedding(vocab_size, embedding_dim=embed_dim)
src_embedding = embedding(src)
tgt_embedding = embedding(tgt)

print(src_embedding.shape, tgt_embedding.shape)

# From here on batch_size / embed_dim describe the toy batch above.
batch_size, seq_len_src, embed_dim = src_embedding.shape
_, seq_len_tgt, _ = tgt_embedding.shape

head_dim = embed_dim // num_head

# Separate projections for queries, keys and values. Pushing all three through
# one shared layer would make K and V the very same tensor.
q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)
linear = q_proj  # original name kept as an alias for any cell that still uses it

# (batch, seq, embed) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
Q = q_proj(tgt_embedding).view(batch_size, seq_len_tgt, num_head, head_dim).transpose(1, 2)
K = k_proj(src_embedding).view(batch_size, seq_len_src, num_head, head_dim).transpose(1, 2)
V = v_proj(src_embedding).view(batch_size, seq_len_src, num_head, head_dim).transpose(1, 2)

print(Q.shape, V.shape)

torch.Size([3, 7]) torch.Size([3, 1])
torch.Size([3, 7, 128]) torch.Size([3, 1, 128])
torch.Size([3, 8, 1, 16]) torch.Size([3, 8, 7, 16])


In [11]:
# Print the three mask helpers one after another for the source length.
# Per the output recorded below: the first is an additive float mask
# (0 / -inf), the second a 0/1 lower-triangular mask, and the third a boolean
# mask that is True at the source positions holding token id 0.
print(
    generate_square_subsequent_mask(src.size(-1)),
    create_causal_mask(src.size(-1)),
    create_cross_attention_mask(tgt, src),
    sep="\n",
)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0.]])
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1.]])
tensor([[[[False, False, False, False,  True,  True,  True]]],


        [[[False, False, False, False, False,  True,  True]]],


        [[[False, False,  True,  True,  True,  True,  True]]]])


In [12]:
# Self-Attention
# Q is reused as keys and values here, so the target sequence attends to itself.

scores = (Q @ Q.transpose(-2, -1)) / math.sqrt(Q.size(-1))
print('ScoresShape:', scores.shape)

# Additive mask sized by the target length; it is added to the scores before softmax.
attn_mask = generate_square_subsequent_mask(tgt.shape[-1])
print('AttentionMaskShape:', attn_mask.shape)

# Dropout is active (training=True): surviving weights are scaled by
# 1 / (1 - p), which is why a printed weight can exceed 1 (1.1111 for p = 0.1).
attn_weight = torch.softmax(scores + attn_mask, dim=-1)
attn_weight = F.dropout(attn_weight, p=0.1, training=True)
print('AttentionWeightShape:', attn_weight.shape)

# Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim).
context = attn_weight @ Q
context = context.transpose(1, 2).reshape(batch_size, seq_len_tgt, embed_dim)
print('ContextShape:', context.shape)
print(attn_weight[0][0])

ScoresShape: torch.Size([3, 8, 1, 1])
AttentionMaskShape: torch.Size([1, 1])
AttentionWeightShape: torch.Size([3, 8, 1, 1])
ContextShape: torch.Size([3, 1, 128])
tensor([[1.1111]], grad_fn=<SelectBackward0>)


In [13]:
# Probe: feed "no mask" through torch's mask canonicaliser.
# NOTE: _canonical_mask and _none_or_dtype are underscore-prefixed (private)
# helpers of torch.nn.functional, so they may change between releases.
mask = None

# The notebook records no output for this cell, so the call presumably
# returns None when given mask=None.
F._canonical_mask(
    mask=mask,
    mask_name="attn_mask",
    other_type=F._none_or_dtype(mask),
    other_name="mask",
    target_type=torch.float,
)


In [14]:
# Cross-Attention: target queries (Q) attend over source keys/values (K, V).
scores = Q @ K.transpose(-2, -1) / math.sqrt(Q.size(-1))
print('ScoresShape:', scores.shape)

# Convert the boolean mask (True = source position to ignore) into an additive
# float mask: 0 where attention is allowed, -inf where it is blocked.
# This is done with public tensor ops and depends only on tgt/src, so the cell
# no longer reads whatever `attn_mask` an earlier cell happened to leave behind.
bool_mask = create_cross_attention_mask(tgt, src)
attn_mask = torch.zeros_like(bool_mask, dtype=torch.float).masked_fill(
    bool_mask, float("-inf")
)
print('AttentionMaskShape:', attn_mask.shape)

# A row whose positions are all masked becomes NaN after softmax; clear_nan
# turns those entries into zeros. Dropout is active (training=True), so the
# surviving weights are rescaled by 1 / (1 - p).
attn_weight = clear_nan(
    F.dropout(torch.softmax(scores + attn_mask, dim=-1), p=0.1, training=True)
)
print('AttentionWeightShape:', attn_weight.shape)

# Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, embed_dim).
context = (
    (attn_weight @ V).transpose(1, 2).contiguous().view(batch_size, seq_len_tgt, embed_dim)
)
print('ContextShape:', context.shape)
print(attn_weight[0][0])

ScoresShape: torch.Size([3, 8, 1, 7])
AttentionMaskShape: torch.Size([3, 1, 1, 7])
AttentionWeightShape: torch.Size([3, 8, 1, 7])
ContextShape: torch.Size([3, 1, 128])
tensor([[0.7817, 0.0392, 0.1920, 0.0983, 0.0000, 0.0000, 0.0000]],
       grad_fn=<SelectBackward0>)
