In [1]:
import sys
sys.path.append('../')
sys.path

['/home/andrecosta/sideprojects/transformer/notebooks',
 '/home/andrecosta/anaconda3/lib/python312.zip',
 '/home/andrecosta/anaconda3/lib/python3.12',
 '/home/andrecosta/anaconda3/lib/python3.12/lib-dynload',
 '',
 '/home/andrecosta/anaconda3/lib/python3.12/site-packages',
 '../']

In [2]:
from transformer.mask import subsequent_mask
from transformer.translator import greedy_decode
from transformer.model import make_model
from transformer.batch import Batch

import altair as alt
import pandas as pd
import torch
import math

import seaborn as sns

%load_ext autoreload
%autoreload 2

In [3]:
LS_data = pd.concat(
    [
        pd.DataFrame(
            {
                "Subsequent Mask": subsequent_mask(20)[0][x, y].flatten(),
                "Window": y,
                "Masking": x,
            }
        )
        for y in range(20)
        for x in range(20)
    ]
)
LS_data

Unnamed: 0,Subsequent Mask,Window,Masking
0,True,0,0
0,True,0,1
0,True,0,2
0,True,0,3
0,True,0,4
...,...,...,...
0,False,19,15
0,False,19,16
0,False,19,17
0,False,19,18


In [4]:
(
    alt.Chart(LS_data)
    .mark_rect()
    .properties(height=250, width=250)
    .encode(
        alt.X("Window:O"),
        alt.Y("Masking:O"),
        alt.Color("Subsequent Mask:Q", scale=alt.Scale(scheme="viridis")),
    )
    .interactive()
)

In [5]:
subsequent_mask(5).unsqueeze(1)

tensor([[[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  True,  True, False, False],
          [ True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True]]]])

In [6]:
from transformer.mha import MultiHeadedAttention
from transformer.attention import attention

In [7]:
h = 5
d_model = 10
attn = MultiHeadedAttention(h, d_model)

In [8]:
_in = torch.rand(3, 10)
q, k, v = (_in, _in, _in)

In [9]:
nbatches = q.size(0)
nbatches

3

In [10]:
q, k, v = [
    lin(x).view(
        nbatches,
        -1,
        h,
        d_model // h
    ).transpose(1, 2) for lin, x in zip(attn.linears, (q, v, k))
]
q.shape, k.shape, v.shape

(torch.Size([3, 5, 1, 2]), torch.Size([3, 5, 1, 2]), torch.Size([3, 5, 1, 2]))

In [11]:
xline, p_attn = attention(q, k, v)
xline.shape, p_attn.shape

(torch.Size([3, 5, 1, 2]), torch.Size([3, 5, 1, 1]))

In [12]:
# def subsequent_mask(size):
#     "Mask out subsequent positions."
#     attn_shape = (1, size, size)
#     subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
#         torch.uint8
#     )
#     return subsequent_mask == 0


In [13]:
V = 10
batch_size = 3
pad = 2

In [14]:
data = torch.randint(1, V, size=(batch_size, 10))
data[:, 0] = 1
src = data.requires_grad_(False).clone().detach()
tgt = data.requires_grad_(False).clone().detach()
batch = Batch(src, tgt, 0)

In [15]:
batch.tgt_mask.shape

torch.Size([3, 9, 9])

In [16]:
n_rows = 3
n_cols = 6
_in = torch.rand(n_rows, n_cols)
q, v, k = (_in, _in, _in)

In [17]:
d_k = q.size(-1)
d_k

6

In [18]:
_in.shape

torch.Size([3, 6])

In [19]:
_in.transpose(-2, -1).shape

torch.Size([6, 3])

In [20]:
mask = subsequent_mask(n_rows)
mask.shape

torch.Size([1, 3, 3])

In [21]:
scores = (torch.matmul(_in, _in.transpose(-2, -1))) / math.sqrt(d_k)
scores

tensor([[0.7070, 0.6221, 0.4751],
        [0.6221, 0.9932, 0.3157],
        [0.4751, 0.3157, 0.6994]])

In [22]:
scores = scores.masked_fill(mask == 0, 0.01)
scores

tensor([[[0.7070, 0.0100, 0.0100],
         [0.6221, 0.9932, 0.0100],
         [0.4751, 0.3157, 0.6994]]])

In [23]:
p_attn = scores.softmax(dim=-1)
p_attn

tensor([[[0.5010, 0.2495, 0.2495],
         [0.3343, 0.4845, 0.1812],
         [0.3221, 0.2747, 0.4032]]])

In [24]:
p_attn.shape, v.shape

(torch.Size([1, 3, 3]), torch.Size([3, 6]))

In [25]:
p_attn

tensor([[[0.5010, 0.2495, 0.2495],
         [0.3343, 0.4845, 0.1812],
         [0.3221, 0.2747, 0.4032]]])

In [26]:
v

tensor([[0.1646, 0.4215, 0.2770, 0.6286, 0.4390, 0.9288],
        [0.6200, 0.7655, 0.9439, 0.1105, 0.3253, 0.6734],
        [0.2386, 0.0098, 0.1722, 0.9513, 0.8419, 0.1131]])

In [27]:
p_attn[0][0][0]*v[0][0] + p_attn[0][0][1]*v[1][0] + p_attn[0][0][2]*v[2][0]

tensor(0.2967)

In [28]:
newx = torch.matmul(p_attn, v)
newx

tensor([[[0.2967, 0.4046, 0.4172, 0.5798, 0.5112, 0.6615],
         [0.3987, 0.5135, 0.5811, 0.4361, 0.4570, 0.6572],
         [0.3195, 0.3500, 0.4179, 0.6164, 0.5702, 0.5298]]])

In [29]:
# src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
# max_len = src.shape[1]
# src_mask = torch.ones(1, 1, max_len)
# print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))
# src_mask