In [101]:
import sys
sys.path.append('../')
sys.path

['/home/andrecosta/sideprojects/transformer/notebooks',
 '/home/andrecosta/anaconda3/lib/python312.zip',
 '/home/andrecosta/anaconda3/lib/python3.12',
 '/home/andrecosta/anaconda3/lib/python3.12/lib-dynload',
 '',
 '/home/andrecosta/anaconda3/lib/python3.12/site-packages',
 '../',
 '../']

In [213]:
from transformer.mask import subsequent_mask
from transformer.utils import greedy_decode
from transformer.model import make_model
from transformer.batch import Batch

import altair as alt
import pandas as pd
import torch
import math

import seaborn as sns

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [214]:
LS_data = pd.concat(
    [
        pd.DataFrame(
            {
                "Subsequent Mask": subsequent_mask(20)[0][x, y].flatten(),
                "Window": y,
                "Masking": x,
            }
        )
        for y in range(20)
        for x in range(20)
    ]
)
LS_data

Unnamed: 0,Subsequent Mask,Window,Masking
0,True,0,0
0,True,0,1
0,True,0,2
0,True,0,3
0,True,0,4
...,...,...,...
0,False,19,15
0,False,19,16
0,False,19,17
0,False,19,18


In [215]:
(
    alt.Chart(LS_data)
    .mark_rect()
    .properties(height=250, width=250)
    .encode(
        alt.X("Window:O"),
        alt.Y("Masking:O"),
        alt.Color("Subsequent Mask:Q", scale=alt.Scale(scheme="viridis")),
    )
    .interactive()
)

In [216]:
subsequent_mask(5).unsqueeze(1)

tensor([[[[ True, False, False, False, False],
          [ True,  True, False, False, False],
          [ True,  True,  True, False, False],
          [ True,  True,  True,  True, False],
          [ True,  True,  True,  True,  True]]]])

In [217]:
from transformer.mha import MultiHeadedAttention
from transformer.attention import attention

In [218]:
h = 5
d_model = 10
attn = MultiHeadedAttention(h, d_model)

In [219]:
_in = torch.rand(3, 10)
q, k, v = (_in, _in, _in)

In [220]:
nbatches = q.size(0)
nbatches

3

In [221]:
q, k, v = [
    lin(x).view(
        nbatches,
        -1,
        h,
        d_model // h
    ).transpose(1, 2) for lin, x in zip(attn.linears, (q, v, k))
]
q.shape, k.shape, v.shape

(torch.Size([3, 5, 1, 2]), torch.Size([3, 5, 1, 2]), torch.Size([3, 5, 1, 2]))

In [222]:
xline, p_attn = attention(q, k, v)
xline.shape, p_attn.shape

(torch.Size([3, 5, 1, 2]), torch.Size([3, 5, 1, 1]))

In [223]:
# def subsequent_mask(size):
#     "Mask out subsequent positions."
#     attn_shape = (1, size, size)
#     subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
#         torch.uint8
#     )
#     return subsequent_mask == 0


In [224]:
V = 10
batch_size = 3
pad = 2

In [225]:
data = torch.randint(1, V, size=(batch_size, 10))
data[:, 0] = 1
src = data.requires_grad_(False).clone().detach()
tgt = data.requires_grad_(False).clone().detach()
batch = Batch(src, tgt, 0)

In [226]:
batch.tgt_mask.shape

torch.Size([3, 9, 9])

In [256]:
n_rows = 3
n_cols = 6
_in = torch.rand(n_rows, n_cols)
q, v, k = (_in, _in, _in)

In [257]:
d_k = q.size(-1)
d_k

6

In [258]:
_in.shape

torch.Size([3, 6])

In [259]:
_in.transpose(-2, -1).shape

torch.Size([6, 3])

In [260]:
mask = subsequent_mask(n_rows)
mask.shape

torch.Size([1, 3, 3])

In [261]:
scores = (torch.matmul(_in, _in.transpose(-2, -1))) / math.sqrt(d_k)
scores

tensor([[1.0398, 1.0323, 0.3253],
        [1.0323, 1.5931, 0.5853],
        [0.3253, 0.5853, 0.3364]])

In [262]:
scores = scores.masked_fill(mask == 0, 0.01)
scores

tensor([[[1.0398, 0.0100, 0.0100],
         [1.0323, 1.5931, 0.0100],
         [0.3253, 0.5853, 0.3364]]])

In [263]:
p_attn = scores.softmax(dim=-1)
p_attn

tensor([[[0.5834, 0.2083, 0.2083],
         [0.3213, 0.5630, 0.1156],
         [0.3023, 0.3920, 0.3056]]])

In [264]:
p_attn.shape, v.shape

(torch.Size([1, 3, 3]), torch.Size([3, 6]))

In [271]:
p_attn

tensor([[[0.5834, 0.2083, 0.2083],
         [0.3213, 0.5630, 0.1156],
         [0.3023, 0.3920, 0.3056]]])

In [272]:
v

tensor([[0.5277, 0.6852, 0.1710, 0.0384, 0.9455, 0.9350],
        [0.8526, 0.6190, 0.9002, 0.8291, 0.9773, 0.5825],
        [0.1036, 0.0788, 0.6422, 0.3863, 0.1212, 0.4802]])

In [280]:
p_attn[0][0][0]*v[0][0] + p_attn[0][0][1]*v[1][0] + p_attn[0][0][2]*v[2][0]

tensor(0.5070)

In [273]:
newx = torch.matmul(p_attn, v)
newx

tensor([[[0.5070, 0.5451, 0.4211, 0.2756, 0.7804, 0.7668],
         [0.6616, 0.5778, 0.6360, 0.5238, 0.8681, 0.6839],
         [0.5255, 0.4739, 0.6009, 0.4548, 0.7060, 0.6578]]])

In [269]:
# src = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]])
# max_len = src.shape[1]
# src_mask = torch.ones(1, 1, max_len)
# print(greedy_decode(model, src, src_mask, max_len=max_len, start_symbol=0))
# src_mask