In [1]:
# Unmasked attention

In [2]:
import torch
import math
import torch.nn.functional as F

def self_attention(q, k, v):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b t h dh
    # q = q.permute(0, 2, 1, 3)
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    print(softmaxed_prod.shape)
    # print(softmaxed_prod)
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


x = torch.rand([2, 3, 4, 5])
self_attention(x, x, x)
self_attention(x, x, x).shape

  cpu = _conversion_method_template(device=torch.device("cpu"))


torch.Size([2, 4, 3, 3])
torch.Size([2, 4, 3, 3])


torch.Size([2, 4, 3, 5])

In [3]:
from torch import nn

class MHSA(nn.Module):
    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)
 
    def forward(self, q, k, v):
        # b, t, d
        b, t, d = q.size()
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        wq = wq.view(b, t, self.h, self.dh)
        wk = wk.view(b, t, self.h, self.dh)
        wv = wv.view(b, t, self.h, self.dh)
        # b, t, h, dh
        # if changing from 4 dim -> 3 dim: b*h, t, dh
        # wq = wq.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wk = wk.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wv = wv.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # another option 4 dim -> 3 dim
        # wq = wq.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wk = wk.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wv = wv.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # changing the number of dims is not necessary as @ supports 4 dims
        attn = self_attention(wq, wk, wv)
        # b * h, t, dh
        # attn = attn.view(b, self.h, t, self.dh).permute(0, 2, 1, 3).reshape(b, t, d)
        attn = attn.view(b, self.h, t, self.dh).transpose(1, 2).contiguous().view(b, t, d)
        wo = self.wo(attn)
        return wo
        # # 1 2 3 4
        # x = F.relu(self.conv1(x))
        # return F.relu(self.conv2(x))

mhsa = MHSA()
x = torch.rand(2, 3, 512)
mhsa(x, x, x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [4]:
# PE

In [5]:
from torch import nn

class PE(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.dropout = nn.Dropout(p=dropout)

        twoi = torch.arange(0, self.d, 2)
        pow_ = torch.pow(10000, twoi / self.d)
        position = torch.arange(0, max_len).unsqueeze(1)
        sin_p = torch.sin(position / pow_)
        cos_p = torch.cos(position / pow_)
        pe = torch.zeros(max_len, self.d, requires_grad=False) # Explicit, register buffer insures requires grad = False
        pe[:, 0::2] = sin_p
        pe[:, 1::2] = cos_p
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe) 

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.pe[:, : t, :]
        return self.dropout(x)
print(PE(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [6]:
from torch import nn

class PEEmbed(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.pe = nn.Embedding(max_len, d)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        b, t, d = x.size()
        pos = self.pe(torch.arange(t))
        x = x + pos
        return self.dropout(x)
print(PEEmbed(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [7]:
# Encoder without mask

In [8]:
import torch.nn.functional as F

class EncoderLayerWithoutMask(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.attn_dropout(self.mhsa(x, x, x))
        x = self.norm1(x)
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm2(x)
        return x


encoder_layer = EncoderLayerWithoutMask()
x = torch.rand(2, 3, 512)
encoder_layer(x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [9]:
from torch import nn

class EncoderWithoutMask(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [EncoderLayerWithoutMask(d, h) for _ in range(n)]

    def forward(self, x):
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x)
        return x

encoder = EncoderWithoutMask()
x = torch.randint(0, 2**13, (2, 3))
encoder(x).shape

torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [10]:
# With masks
import torch
import math
import torch.nn.functional as F

def self_attention_masked(q, k, v, mask=None):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b t h dh:
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    print(f"scaled_prod.shape: \n {scaled_prod.shape}")
    # mask should be in shape to be broadcastable to bhts and lead to masked keys only (last s dim)
    if mask is not None:
        scaled_prod = scaled_prod.masked_fill(mask == 0, float("-inf"))
    print(f"scaled_prod: \n {scaled_prod}")
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    # print(softmaxed_prod.shape)
    print(f"softmaxed_prod: \n {softmaxed_prod}")
    # swap h and t in v
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


In [11]:
# Mask

In [12]:
# play with mask

x = torch.rand([2, 3, 2, 4])
print(x)
# mask 2 batches 3 timeseries
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
print(f"mask: \n {mask}")
# add head dim to make mask broatcastable to q x k.T prod. mask shape 2, 1, 3
mask = mask.unsqueeze(1)


# mask = mask.permute(0, 2, 1)
# is the mask that I need? keys are ignored?
print(f"wrong mask: \n {mask}")
#  mask = 2 1 3 -> b prepended before broadcasting (1!!!) h (remains since already 2) t (broadcasted from 1) d (remains since already 3) 
print(f"wrong mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"wrong a: \n {a}" )
print(f"wrong a.shape: \n {a.shape}")
# leads to wrong attention since the shape of mask is wrong 2 1 3 

# correct mask
# mask 2 batches 3 timeseries
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
mask = mask.unsqueeze(1).unsqueeze(1)

print(f"mask: \n {mask}")
#  mask = 2 1 1 3 -> b (remains already 2) h (broadcasted from 1) t (broadcasted from 1) d (remains since already 3) 
print(f"mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")


tensor([[[[0.7577, 0.9060, 0.7230, 0.9500],
          [0.6063, 0.0449, 0.7067, 0.5544]],

         [[0.6978, 0.1751, 0.0026, 0.9557],
          [0.9310, 0.8582, 0.4305, 0.2082]],

         [[0.2590, 0.9259, 0.9320, 0.9326],
          [0.4594, 0.2069, 0.3897, 0.9627]]],


        [[[0.1810, 0.9472, 0.1688, 0.4276],
          [0.3366, 0.1658, 0.4613, 0.7097]],

         [[0.1066, 0.2280, 0.1131, 0.0694],
          [0.8893, 0.6127, 0.1256, 0.5475]],

         [[0.8238, 0.5106, 0.8772, 0.0728],
          [0.4804, 0.8187, 0.2267, 0.7607]]]])
mask: 
 tensor([[1., 1., 0.],
        [1., 0., 0.]])
wrong mask: 
 tensor([[[1., 1., 0.]],

        [[1., 0., 0.]]])
wrong mask broadcast: 
 tensor([[[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]]])
scaled_prod.shape

In [13]:
# mask is equal to making keys on masked places 0:
# the result in terms of masked symbols is the same
k = x.clone()
k[0, 2, 0, :] = float("-inf")
k[0, 2, 1, :] = float("-inf")
k[1, 2, 0, :] = float("-inf")
k[1, 1, 0, :] = float("-inf")
k[1, 2, 1, :] = float("-inf")
k[1, 1, 1, :] = float("-inf")
print(f"k: \n {k}")
a = self_attention_masked(x, k, x)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")
# a is the same shape as if mask was applied in q * k:

test = torch.rand([2, 3, 4])
test[0, 2, :] = 0
test[1, 1, :] = 0
test[1, 2, :] = 0

print(f"test: \n {test}")
test_v = test.view(2, 3, 2, 2)
print(f"test_v: \n {test_v}")
test_perm = test_v.permute(0, 2, 1, 3)
print(f"test_perm: \n {test_perm}")

# or like that:
test_q = torch.rand([2, 3, 4])
test_k = test_q.clone()
test_k[0, 2, :] = float("-inf")
test_k[1, 1, :] = float("-inf")
test_k[1, 2, :] = float("-inf")
print(f"test_k: \n {test_k}")

test_q_view = test_q.view(2, 3, 2, 2)
test_k_view = test_k.view(2, 3, 2, 2)
print(f"test_k_view: \n {test_k_view}")
test_q_perm = test_q_view.permute(0, 2, 1, 3)
test_k_perm = test_k_view.permute(0, 2, 1, 3)
print(f"test_k_perm: \n {test_k_perm}")
print(f"q * k: \n {torch.einsum("bhtd, bhsd -> bhts", test_q_perm, test_k_perm)}")

k: 
 tensor([[[[0.7577, 0.9060, 0.7230, 0.9500],
          [0.6063, 0.0449, 0.7067, 0.5544]],

         [[0.6978, 0.1751, 0.0026, 0.9557],
          [0.9310, 0.8582, 0.4305, 0.2082]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]],


        [[[0.1810, 0.9472, 0.1688, 0.4276],
          [0.3366, 0.1658, 0.4613, 0.7097]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]]])
scaled_prod.shape: 
 torch.Size([2, 2, 3, 3])
scaled_prod: 
 tensor([[[[1.4100, 0.7986,   -inf],
          [0.7986, 0.7155,   -inf],
          [1.2974, 0.6183,   -inf]],

         [[0.5882, 0.5113,   -inf],
          [0.5113, 0.9160,   -inf],
          [0.5485, 0.4867,   -inf]]],


        [[[0.5707,   -inf,   -inf],
          [0.1420,   -inf,   -inf],
          [0.4060,   -inf,   -inf]],

         [[0.4286,   -inf,   -inf],
          [0.4237,   -i

In [14]:
import torch

def build_padding_mask(x, pad_token):
    # x: b t shape
    mask = torch.ones_like(x)
    return mask.masked_fill(x == pad_token, 0)

x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -2:] = 100
x[2, -1] = 100
x[3, :] = 100
print(x)
print(build_padding_mask(x, 100))

tensor([[5.6423e-02, 2.8367e-01, 2.9734e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [2.7809e-01, 3.4594e-01, 7.0030e-01, 5.5790e-01, 1.0000e+02, 1.0000e+02],
        [2.8416e-01, 5.6292e-01, 2.7772e-01, 2.1171e-01, 4.1438e-01, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [5.8055e-02, 6.6687e-01, 8.2529e-01, 7.1399e-01, 2.7075e-01, 4.5091e-01]])
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [15]:
import torch

def build_causal_mask(x):
    # x: b t shape
    m = torch.ones_like(x)
    return torch.tril(m)
x = torch.rand(5, 6)

print(build_causal_mask(x))

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [16]:
import torch

def merge_masks(m1, m2):
    return m1 * m2

x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -1] = 100
x[2, -4:] = 100
x[3, :] = 100
print(x)
m1 = build_padding_mask(x, 100)
m2 = build_causal_mask(x)
print(merge_masks(m1, m2))

tensor([[  0.8224,   0.1996,   0.5486, 100.0000, 100.0000, 100.0000],
        [  0.7127,   0.1006,   0.3603,   0.2332,   0.1146, 100.0000],
        [  0.3179,   0.7870, 100.0000, 100.0000, 100.0000, 100.0000],
        [100.0000, 100.0000, 100.0000, 100.0000, 100.0000, 100.0000],
        [  0.6792,   0.4480,   0.2321,   0.4819,   0.8868,   0.4640]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [17]:
import torch

def reshape_mask(mask):
    # b t -> b 1 1 t (to be broadcastable to b h t t)
    return mask.unsqueeze(1).unsqueeze(1)

x = torch.rand(2, 3)
print(reshape_mask(build_causal_mask(x)))

tensor([[[[1., 0., 0.]]],


        [[[1., 1., 0.]]]])


In [18]:
from torch import nn

class MHSAMasked(nn.Module):
    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)
 
    def forward(self, q, k, v, mask = None):
        # b, t, d
        b, t, d = q.size()
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        wq = wq.view(b, t, self.h, self.dh)
        wk = wk.view(b, t, self.h, self.dh)
        wv = wv.view(b, t, self.h, self.dh)
        # b, t, h, dh
        # if changing from 4 dim -> 3 dim: b*h, t, dh
        # wq = wq.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wk = wk.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wv = wv.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # another option 4 dim -> 3 dim
        # wq = wq.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wk = wk.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wv = wv.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # changing the number of dims is not necessary as @ supports 4 dims
        attn = self_attention_masked(wq, wk, wv, mask=mask)
        # b * h, t, dh
        # attn = attn.view(b, self.h, t, self.dh).permute(0, 2, 1, 3).reshape(b, t, d)
        attn = attn.view(b, self.h, t, self.dh).transpose(1, 2).contiguous().view(b, t, d)
        wo = self.wo(attn)
        return wo
        # # 1 2 3 4
        # x = F.relu(self.conv1(x))
        # return F.relu(self.conv2(x))

mhsa_masked = MHSAMasked(h = 2, d = 6)
x = torch.rand(4, 5)
mask = reshape_mask(build_causal_mask(x))
print(mask)
x = torch.rand(4, 5, 6)
print(mhsa_masked(x, x, x, mask=mask))
print(mhsa_masked(x, x, x, mask=mask).shape)

tensor([[[[1., 0., 0., 0., 0.]]],


        [[[1., 1., 0., 0., 0.]]],


        [[[1., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 0.]]]])
scaled_prod.shape: 
 torch.Size([4, 2, 5, 5])
scaled_prod: 
 tensor([[[[-0.0447,    -inf,    -inf,    -inf,    -inf],
          [-0.0756,    -inf,    -inf,    -inf,    -inf],
          [-0.1176,    -inf,    -inf,    -inf,    -inf],
          [-0.0998,    -inf,    -inf,    -inf,    -inf],
          [-0.0753,    -inf,    -inf,    -inf,    -inf]],

         [[-0.1483,    -inf,    -inf,    -inf,    -inf],
          [-0.0733,    -inf,    -inf,    -inf,    -inf],
          [-0.1115,    -inf,    -inf,    -inf,    -inf],
          [-0.1080,    -inf,    -inf,    -inf,    -inf],
          [-0.0453,    -inf,    -inf,    -inf,    -inf]]],


        [[[ 0.0149, -0.0304,    -inf,    -inf,    -inf],
          [-0.0232, -0.0135,    -inf,    -inf,    -inf],
          [-0.0262, -0.0398,    -inf,    -inf,    -inf],
          [-0.0208, -0.0575,    -inf,    -inf,   

In [19]:
# Transformer implementation 

In [20]:
import torch.nn.functional as F

class EncoderLayer(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSAMasked(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, self_mask=None):
        b, t, d = x.size()
        x = x + self.attn_dropout(self.mhsa(x, x, x, mask=self_mask))
        x = self.norm1(x)
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm2(x)
        return x


encoder_layer = EncoderLayer()
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
x = torch.rand(2, 3, 512)

encoder_layer(x, self_mask=self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[-0.0138, -0.1131,    -inf],
          [-0.0717, -0.1310,    -inf],
          [-0.1991, -0.2871,    -inf]],

         [[ 0.0512, -0.0554,    -inf],
          [ 0.0345, -0.0757,    -inf],
          [ 0.0630, -0.0499,    -inf]],

         [[-0.0876, -0.0904,    -inf],
          [-0.0636, -0.0532,    -inf],
          [-0.0091, -0.0932,    -inf]],

         [[ 0.1765,  0.1944,    -inf],
          [ 0.0701,  0.1181,    -inf],
          [ 0.1089,  0.1160,    -inf]],

         [[-0.0338,  0.0017,    -inf],
          [ 0.0580,  0.0787,    -inf],
          [ 0.0451,  0.1084,    -inf]],

         [[ 0.0717,  0.0508,    -inf],
          [ 0.0883,  0.0564,    -inf],
          [ 0.0369,  0.0327,    -inf]],

         [[ 0.2495,  0.1865,    -inf],
          [ 0.0967,  0.1040,    -inf],
          [ 0.0930,  0.1004,    -inf]],

         [[ 0.2367,  0.2364,    -inf],
          [ 0.1689,  0.2307,    -inf],
          [ 0.0794,  0.1427,

torch.Size([2, 3, 512])

In [21]:
from torch import nn

class Encoder(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [EncoderLayer(d, h) for _ in range(n)]

    def forward(self, x, self_mask = None):
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, self_mask=self_mask)
        return x

encoder = Encoder()
x = torch.randint(0, 2**13, (2, 3))
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
encoder(x, self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[-3.4819e-01, -9.6599e-01,        -inf],
          [-1.0084e+00, -1.0910e+00,        -inf],
          [ 5.0857e-02, -3.7348e-01,        -inf]],

         [[-1.1522e-01, -3.9379e-01,        -inf],
          [-1.7539e-01, -1.0004e+00,        -inf],
          [ 1.8132e-01, -9.6912e-02,        -inf]],

         [[ 1.4321e+00, -3.4864e-01,        -inf],
          [ 1.2077e-01,  3.2830e-01,        -inf],
          [ 5.4350e-01,  4.3532e-01,        -inf]],

         [[-1.0962e+00, -3.3000e-01,        -inf],
          [-6.7391e-01,  5.8167e-02,        -inf],
          [-5.0507e-01,  3.7956e-01,        -inf]],

         [[ 7.8002e-01, -4.6078e-01,        -inf],
          [-8.8403e-01,  3.4761e-01,        -inf],
          [-3.6121e-01, -4.7940e-01,        -inf]],

         [[ 3.8456e-01,  6.8160e-01,        -inf],
          [ 4.2822e-01,  1.1415e+00,        -inf],
          [ 2.4242e-01,  7.8037e-01,        -inf]],

         

torch.Size([2, 3, 512])

In [22]:
import torch.nn.functional as F

class DecoderLayer(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSAMasked(d=d, h=h)
        self.attn_norm = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)

        self.mhca = MHSAMasked(d=d, h=h)
        self.cross_attn_norm = nn.LayerNorm(d)
        self.cross_attn_dropout = nn.Dropout(dropout)
        
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.resid_dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d)
        

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        # self_mask is merged decoders padding and causal masks
        # cross_mask is equal to endcoders padding mask because we don't want to attend to encoded padded tokens
        b, t, d = dec_x.size()
        x = dec_x + self.attn_dropout(self.mhsa(dec_x, dec_x, dec_x, mask=self_mask))
        x = self.attn_norm(x)

        x = x + self.cross_attn_dropout(self.mhca(x, enc_x, enc_x, mask=cross_mask))
        x = self.cross_attn_norm(x)
        
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm(x)
        return x


decoder_layer = DecoderLayer(h=2, d=16)
x = torch.rand(3, 3, 16)
y = torch.rand(3, 3, 16)
self_mask1 = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]), pad_token=0)
self_mask2 = build_causal_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]))
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
decoder_layer(x, y, self_mask=self_mask, cross_mask=cross_mask).shape

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[ 0.1762,    -inf,    -inf],
          [ 0.0689,    -inf,    -inf],
          [ 0.1165,    -inf,    -inf]],

         [[ 0.0080,    -inf,    -inf],
          [ 0.0103,    -inf,    -inf],
          [ 0.0217,    -inf,    -inf]]],


        [[[ 0.0650,    -inf,    -inf],
          [ 0.2666,    -inf,    -inf],
          [-0.0464,    -inf,    -inf]],

         [[-0.0322,    -inf,    -inf],
          [-0.1587,    -inf,    -inf],
          [-0.0562,    -inf,    -inf]]],


        [[[ 0.1190,  0.0704,    -inf],
          [ 0.0547,  0.0566,    -inf],
          [ 0.0734,  0.0607,    -inf]],

         [[-0.0423, -0.0187,    -inf],
          [-0.0498, -0.0490,    -inf],
          [-0.0530, -0.0532,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[1

torch.Size([3, 3, 16])

In [23]:
from torch import nn

class Decoder(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [DecoderLayer(d, h) for _ in range(n)]

    def forward(self, dec_x, enc_x, self_mask=self_mask, cross_mask=cross_mask):
        b, t = dec_x.size()
        x = self.embed(dec_x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, enc_x, self_mask=self_mask, cross_mask=cross_mask)
        return x

    def get_embed_weights(self):
        return self.embed.weight

decoder = Decoder(vocab_size=32, n=2, d=16, h=2)
# x = torch.randint(0, 32, (2, 3))
x = torch.tensor([[15, 7, 0], [10, 0, 0], [1, 3, 0]])
y = torch.rand(3, 3, 16)

self_mask1 = build_padding_mask(x, pad_token=0)
self_mask2 = build_causal_mask(x)
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
print(decoder(x, y, self_mask=self_mask, cross_mask=cross_mask).shape)

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[-0.5455,    -inf,    -inf],
          [-0.7486,    -inf,    -inf],
          [ 0.4464,    -inf,    -inf]],

         [[-0.2762,    -inf,    -inf],
          [-0.4434,    -inf,    -inf],
          [ 0.6578,    -inf,    -inf]]],


        [[[-1.0599,    -inf,    -inf],
          [-1.4474,    -inf,    -inf],
          [-1.6861,    -inf,    -inf]],

         [[-0.2073,    -inf,    -inf],
          [ 0.5623,    -inf,    -inf],
          [ 0.4858,    -inf,    -inf]]],


        [[[-0.2773,  0.2538,    -inf],
          [-0.3311,  0.3636,    -inf],
          [ 0.2938,  0.4350,    -inf]],

         [[ 0.0601,  0.0843,    -inf],
          [ 0.7677,  0.7050,    -inf],
          [ 0.6450, -0.0937,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[1

In [24]:
from torch import nn

class Output(nn.Module):

    def __init__(self, vocab_size: int = 2**13, d: int = 512, ff_weight = None):
        super().__init__()
        self.ff = nn.Linear(d, vocab_size)
        # weight tying with the decoder embedding
        if ff_weight is not None:
            self.ff.weight = ff_weight

    def forward(self, x):
        return self.ff(x)

In [25]:
from torch import nn

class Transformer(nn.Module):
    
    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8, embed_tying=True):
        super().__init__()
        self.encoder = Encoder(vocab_size=vocab_size, n=n, d=d, h=h)
        self.decoder = Decoder(vocab_size=vocab_size, n=n, d=d, h=h)
        if embed_tying:
            self.output = Output(vocab_size=vocab_size, d=d, ff_weight = self.decoder.get_embed_weights())
        else:
            self.output = Output(vocab_size=vocab_size, d=d)

    def forward(self, enc_x, dec_x, enc_mask, dec_mask):
        encoded = self.encoder(enc_x, enc_mask)
        decoded = self.decoder(dec_x=dec_x, enc_x=encoded, self_mask=dec_mask, cross_mask=enc_mask)
        return self.output(decoded)

transformer = Transformer(vocab_size=32, n=2, d=16, h=2)
decoder = Decoder(vocab_size=32, n=2, d=16, h=2)
enc_x = torch.tensor([[15, 7, 3], [10, 10, 0], [1, 0, 0]])
dec_x = torch.tensor([[21, 8, 0], [25, 0, 0], [8, 1, 2]])

enc_mask = build_padding_mask(enc_x, pad_token=0)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

dec_mask1 = build_padding_mask(dec_x, pad_token=0)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

print(transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask).shape)

enc_mask: 
 tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 1]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[ 0.3163, -0.4204,  0.2447],
          [-0.2666, -1.2326, -0.5926],
          [-0.1814,  0.0949, -0.4774]],

         [[ 0.0432,  0.1339,  0.0038],
          [ 0.2903, -0.2122,  0.1938],
          [ 0.2361,  0.1003,  0.2385]]],


        [[[ 1.2112,  0.8207,    -inf],
          [ 0.9672,  0.4356,    -inf],
          [-0.0881, -0.5046,    -inf]],

         [[ 0.1077, -0.1028,    -inf],
          [-0.0283, -0.2030,    -inf],
          [ 0.5097,  1.0081,    -inf]]],


        [[[-1.1653,    -inf,    -inf],
          [ 0.6043,    -inf,    -inf],
          [ 0.6063,    -inf,    -inf]],

         [[-0.5251,    -inf,    -inf],
          [-0.7017,    -inf,    -inf],
          [-0.4575,    -inf,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[0.4150, 0.1987, 0.38

In [26]:
# Inference

In [34]:
import tiktoken
import torch
from torch.nn.utils.rnn import pad_sequence


encoding = tiktoken.get_encoding("cl100k_base")
print(encoding.encode("tiktoken is great!"))

sents = ["Hello World", "This is a simple sentence", "Me"]
encoded_sents = [encoding.encode(s) for s in sents]
inputs = pad_sequence([torch.tensor(es) for es in encoded_sents], batch_first=True, padding_value=encoding.eot_token)
print(inputs)

[83, 1609, 5963, 374, 2294, 0]
tensor([[  9906,   4435, 100257, 100257, 100257],
        [  2028,    374,    264,   4382,  11914],
        [  7979, 100257, 100257, 100257, 100257]])


In [28]:
assert False

AssertionError: 

In [None]:
x = torch.rand([1, 2, 3])
mask = torch.ones([1, 2])
mask[0, 1] = 0
mask = mask.unsqueeze(1)
print(mask == 0)
x.masked_fill(mask == 0, float("-inf"))