In [1]:
# Unmasked attention

In [2]:
import torch
import math
import torch.nn.functional as F

def self_attention(q, k, v):
    """Scaled dot-product attention over per-head tensors (no masking).

    The einsum fixes the layout: ``q`` is (b, t, h, dh) and ``k`` is
    (b, s, h, dh). ``v`` is permuted with (0, 2, 1, 3) before the final
    matmul, so it has to be (b, s, h, dh) as well.

    For 3-D (b, t, d) inputs the score step would instead be
    ``torch.einsum("btd, bsd -> bts", q, k)``.

    Side effect: prints the shape of the attention weights.

    Returns:
        Tensor of shape (b, h, t, dh).
    """
    head_dim = q.shape[-1]
    # Query/key similarity per head: (b, h, t, s), scaled by sqrt(dh).
    scores = torch.einsum("bthd, bshd -> bhts", q, k)
    scores = scores / torch.sqrt(torch.tensor(head_dim))
    # Normalise over the key axis so every query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    print(weights.shape)
    # Bring heads in front of time so v lines up with the weights: (b, h, s, dh).
    values = v.permute(0, 2, 1, 3)
    return weights @ values


# Smoke test: 2 batches, 3 time steps, 4 heads, 5 features per head.
x = torch.rand([2, 3, 4, 5])
# One call is enough: the cell displays the shape of its result. The earlier
# extra call discarded its result and only repeated the debug print.
self_attention(x, x, x).shape

  cpu = _conversion_method_template(device=torch.device("cpu"))


torch.Size([2, 4, 3, 3])
torch.Size([2, 4, 3, 3])


torch.Size([2, 4, 3, 5])

In [3]:
from torch import nn

class MHSA(nn.Module):
    """Multi-head attention without masking.

    q, k and v are projected by separate linear layers, the model width ``d``
    is split into ``h`` heads of ``dh = d // h`` features, ``self_attention``
    runs on the per-head tensors and ``wo`` mixes the heads back together.

    Args:
        d: model width; must be divisible by ``h`` (checked with ``assert``).
        h: number of heads.
    """

    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)

    def forward(self, q, k, v):
        """Attend from ``q`` (b, t, d) to ``k``/``v`` (b, s, d); returns (b, t, d)."""
        b, t, d = q.size()
        # Keys/values take their length from k rather than q, so s may differ
        # from t (previously q's length was used for all three views).
        s = k.size(1)
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        # Split the feature axis into heads: (b, len, h, dh).
        wq = wq.view(b, t, self.h, self.dh)
        wk = wk.view(b, s, self.h, self.dh)
        wv = wv.view(b, s, self.h, self.dh)
        # No need to fold heads into the batch axis: einsum and @ handle 4-D
        # tensors. The result is laid out as (b, h, t, dh).
        attn = self_attention(wq, wk, wv)
        # Put time back in front of heads, then concatenate the heads: (b, t, d).
        attn = attn.transpose(1, 2).contiguous().view(b, t, d)
        return self.wo(attn)

# Smoke test with the default d=512, h=8 on a (2, 3, 512) input; the cell
# displays the shape of the attention output.
mhsa = MHSA()
x = torch.rand(2, 3, 512)
mhsa(x, x, x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [4]:
# PE

In [5]:
from torch import nn

class PE(nn.Module):
    """Fixed sinusoidal positional encoding added to a (b, t, d) input.

    For position ``p`` and even feature index ``2i`` the table holds
    ``sin(p / 10000^(2i/d))`` in column ``2i`` and the matching cosine in
    column ``2i + 1``. The table is stored as a buffer (not a parameter) of
    shape (1, max_len, d) and is followed by dropout.

    Args:
        d: feature width of the input; odd values are supported.
        max_len: number of positions precomputed in the table.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.dropout = nn.Dropout(p=dropout)

        # Even feature indices 0, 2, 4, ...: there are ceil(d / 2) of them.
        twoi = torch.arange(0, self.d, 2)
        pow_ = torch.pow(10000, twoi / self.d)
        position = torch.arange(0, max_len).unsqueeze(1)
        angles = position / pow_  # (max_len, ceil(d / 2))

        # register_buffer below keeps the table out of the parameters, so it
        # never receives gradients.
        pe = torch.zeros(max_len, self.d)
        pe[:, 0::2] = torch.sin(angles)
        # Only floor(d / 2) odd columns exist; for odd d the last cosine has no
        # slot, so it is dropped instead of raising a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : self.d // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add the first ``t`` rows of the table to ``x`` (b, t, d), then apply dropout."""
        b, t, d = x.size()
        x = x + self.pe[:, :t, :]
        return self.dropout(x)
# Smoke test: a (2, 3, 4) integer input keeps its shape after the encoding is added.
print(PE(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [6]:
from torch import nn

class PEEmbed(nn.Module):
    """Learned positional encoding: one trainable vector per position.

    Args:
        d: feature width of the input.
        max_len: number of positions in the embedding table.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.pe = nn.Embedding(max_len, d)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Add the embeddings of positions 0..t-1 to ``x`` (b, t, d), then apply dropout."""
        b, t, d = x.size()
        # Build the indices on the input's device. A bare torch.arange(t) is
        # always created on CPU and fails once module and input live on an
        # accelerator.
        positions = torch.arange(t, device=x.device)
        pos = self.pe(positions)  # (t, d), broadcast over the batch axis
        x = x + pos
        return self.dropout(x)
# Smoke test: a (2, 3, 4) integer input keeps its shape after the learned encoding is added.
print(PEEmbed(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [7]:
# Encoder without mask

In [8]:
import torch.nn.functional as F

class EncoderLayerWithoutMask(nn.Module):
    """One post-norm encoder block: self-attention, then a two-layer feed-forward net.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``. The
    feed-forward net expands the width to ``4 * d`` with a ReLU in between.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        # Construction order is kept so parameter initialisation is unchanged.
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform ``x`` (b, t, d) into a tensor of the same shape."""
        # Unpacking doubles as a check that the input is 3-D.
        batch, seq_len, width = x.size()

        # Sub-layer 1: self-attention with residual connection and layer norm.
        attended = self.mhsa(x, x, x)
        x = self.norm1(x + self.attn_dropout(attended))

        # Sub-layer 2: position-wise feed-forward with residual and layer norm.
        hidden = F.relu(self.ff1(x))
        x = self.norm2(x + self.resid_dropout(self.ff2(hidden)))
        return x


# Smoke test with default sizes: the layer maps (2, 3, 512) to the same shape.
encoder_layer = EncoderLayerWithoutMask()
x = torch.rand(2, 3, 512)
encoder_layer(x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [9]:
from torch import nn

class EncoderWithoutMask(nn.Module):
    """Token embedding + positional encoding + a stack of ``n`` unmasked encoder layers.

    Args:
        vocab_size: number of rows in the token embedding table.
        n: number of encoder layers.
        d: model width.
        h: number of attention heads per layer.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        # nn.ModuleList (unlike a plain Python list) registers every layer as a
        # submodule, so its parameters appear in parameters()/state_dict() and
        # .to()/.train()/.eval() reach it.
        self.layers = nn.ModuleList([EncoderLayerWithoutMask(d, h) for _ in range(n)])

    def forward(self, x):
        """Encode token ids ``x`` of shape (b, t) into features of shape (b, t, d)."""
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x)
        return x

# Smoke test with default sizes: (2, 3) token ids drawn from the default
# vocabulary of 2**13 entries; the cell displays the output shape.
encoder = EncoderWithoutMask()
x = torch.randint(0, 2**13, (2, 3))
encoder(x).shape

torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [10]:
# With masks
import torch
import math
import torch.nn.functional as F

def self_attention_masked(q, k, v, mask=None, verbose=True):
    """Scaled dot-product attention with an optional mask.

    Args:
        q: queries, (b, t, h, dh).
        k: keys, (b, s, h, dh).
        v: values, (b, s, h, dh).
        mask: optional tensor broadcastable to the (b, h, t, s) scores.
            Positions where ``mask == 0`` are excluded from the softmax. To
            mask keys, the mask must vary along the last (s) axis, e.g. shape
            (b, 1, 1, s).
        verbose: when True (the default, matching the original behaviour)
            print the score shape, the masked scores and the attention
            weights. Pass False to silence these debugging dumps.

    Returns:
        Tensor of shape (b, h, t, dh). A query whose keys are all masked gets
        an all-zero output instead of NaN.
    """
    # For 3-D (b, t, d) inputs this would be torch.einsum("btd, bsd -> bts", q, k).
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod / math.sqrt(q.shape[-1])
    if verbose:
        print(f"scaled_prod.shape: \n {scaled_prod.shape}")

    fully_masked = None
    if mask is not None:
        blocked = mask == 0
        scaled_prod = scaled_prod.masked_fill(blocked, float("-inf"))
        # Rows in which every key is blocked: softmax over all -inf gives NaN.
        fully_masked = blocked.all(dim=-1, keepdim=True)
    if verbose:
        print(f"scaled_prod: \n {scaled_prod}")

    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    if fully_masked is not None:
        # Replace the NaN rows with zero weights so they contribute nothing.
        softmaxed_prod = softmaxed_prod.masked_fill(fully_masked, 0.0)
    if verbose:
        print(f"softmaxed_prod: \n {softmaxed_prod}")

    # Swap h and t in v so it lines up with the (b, h, t, s) weights.
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


In [11]:
# Mask

In [12]:
# play with mask

x = torch.rand([2, 3, 2, 4])
print(x)
# Key mask for 2 batches x 3 time steps: 1 = keep, 0 = masked.
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
print(f"mask: \n {mask}")
# Attempt 1 (wrong): add a single axis, giving shape (2, 1, 3), and hope it
# broadcasts to the (b, h, t, s) scores of q @ k.T.
mask = mask.unsqueeze(1)


# Broadcasting aligns shapes from the right, so (2, 1, 3) is read as
# (1, 2, 1, 3): the mask's batch axis ends up on the head axis.
print(f"wrong mask: \n {mask}")
# (2, 1, 3) -> (2, 2, 3, 3): b is prepended (broadcast from 1), h takes the
# mask's batch axis (already 2), t is broadcast from 1, s stays 3.
print(f"wrong mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"wrong a: \n {a}" )
print(f"wrong a.shape: \n {a.shape}")
# The attention is wrong: head 0 uses batch 0's mask and head 1 uses batch 1's
# mask in every batch. It only runs without an error because b == h == 2 here.

# Correct mask
# Same key mask for 2 batches x 3 time steps, now with two extra axes.
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
mask = mask.unsqueeze(1).unsqueeze(1)

print(f"mask: \n {mask}")
# (2, 1, 1, 3) -> (2, 2, 3, 3): b stays 2, h and t are broadcast from 1, and
# s (the key axis) stays 3, so each batch masks its own keys for every head/query.
print(f"mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")


tensor([[[[0.5197, 0.5411, 0.6381, 0.3260],
          [0.7154, 0.1061, 0.3673, 0.7780]],

         [[0.2145, 0.1403, 0.0992, 0.6875],
          [0.0835, 0.2168, 0.8869, 0.8104]],

         [[0.6843, 0.3754, 0.7641, 0.3368],
          [0.0599, 0.4743, 0.4382, 0.2703]]],


        [[[0.5229, 0.5762, 0.5786, 0.8980],
          [0.6177, 0.1171, 0.0077, 0.5496]],

         [[0.8276, 0.2185, 0.5424, 0.0732],
          [0.0893, 0.0777, 0.2497, 0.6546]],

         [[0.8811, 0.8675, 0.5580, 0.9198],
          [0.3234, 0.7859, 0.1576, 0.1278]]]])
mask: 
 tensor([[1., 1., 0.],
        [1., 0., 0.]])
wrong mask: 
 tensor([[[1., 1., 0.]],

        [[1., 0., 0.]]])
wrong mask broadcast: 
 tensor([[[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]]])
scaled_prod.shape

In [13]:
# Masking is equivalent to setting the key vectors at masked positions to -inf
# (not 0): as long as the query features are strictly positive, every score
# against such a key is -inf, so the softmax gives it zero weight, just as the
# explicit mask does. (A query feature of exactly 0 would give 0 * -inf = NaN.)
k = x.clone()
k[0, 2] = float("-inf")   # batch 0: position 2, both heads
k[1, 1:] = float("-inf")  # batch 1: positions 1 and 2, both heads
print(f"k: \n {k}")
a = self_attention_masked(x, k, x)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")
# a has the same shape as when the mask is applied to q * k.

# How zeroed time steps move through view/permute: (b, t, d) -> (b, t, h, dh) -> (b, h, t, dh).
test = torch.rand([2, 3, 4])
test[0, 2, :] = 0
test[1, 1, :] = 0
test[1, 2, :] = 0

print(f"test: \n {test}")
test_v = test.view(2, 3, 2, 2)
print(f"test_v: \n {test_v}")
test_perm = test_v.permute(0, 2, 1, 3)
print(f"test_perm: \n {test_perm}")

# The same idea with -inf keys, followed by the q * k product.
test_q = torch.rand([2, 3, 4])
test_k = test_q.clone()
test_k[0, 2, :] = float("-inf")
test_k[1, 1, :] = float("-inf")
test_k[1, 2, :] = float("-inf")
print(f"test_k: \n {test_k}")

test_q_view = test_q.view(2, 3, 2, 2)
test_k_view = test_k.view(2, 3, 2, 2)
print(f"test_k_view: \n {test_k_view}")
test_q_perm = test_q_view.permute(0, 2, 1, 3)
test_k_perm = test_k_view.permute(0, 2, 1, 3)
print(f"test_k_perm: \n {test_k_perm}")
# Compute the product outside the f-string: reusing double quotes inside an
# f-string expression is a SyntaxError before Python 3.12.
qk = torch.einsum("bhtd, bhsd -> bhts", test_q_perm, test_k_perm)
print(f"q * k: \n {qk}")

k: 
 tensor([[[[0.5197, 0.5411, 0.6381, 0.3260],
          [0.7154, 0.1061, 0.3673, 0.7780]],

         [[0.2145, 0.1403, 0.0992, 0.6875],
          [0.0835, 0.2168, 0.8869, 0.8104]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]],


        [[[0.5229, 0.5762, 0.5786, 0.8980],
          [0.6177, 0.1171, 0.0077, 0.5496]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]]])
scaled_prod.shape: 
 torch.Size([2, 2, 3, 3])
scaled_prod: 
 tensor([[[[0.5382, 0.2374,   -inf],
          [0.2374, 0.2741,   -inf],
          [0.5781, 0.2535,   -inf]],

         [[0.6316, 0.5194,   -inf],
          [0.5194, 0.7486,   -inf],
          [0.2322, 0.3578,   -inf]]],


        [[[0.8734,   -inf,   -inf],
          [0.4691,   -inf,   -inf],
          [1.0547,   -inf,   -inf]],

         [[0.3487,   -inf,   -inf],
          [0.2130,   -i

In [14]:
import torch

def build_padding_mask(x, pad_token):
    """Mark which entries of ``x`` are real tokens.

    Args:
        x: tensor of shape (b, t).
        pad_token: value that denotes padding.

    Returns:
        Tensor shaped and typed like ``x`` holding 1 where ``x`` differs from
        ``pad_token`` and 0 where it equals it.
    """
    keep = torch.ones_like(x)
    keep[x == pad_token] = 0
    return keep

# Demo: 100 plays the role of the pad value. Rows 0-2 are padded at the end,
# row 3 is entirely padding and row 4 has none.
x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -2:] = 100
x[2, -1] = 100
x[3, :] = 100
print(x)
print(build_padding_mask(x, 100))

tensor([[6.6317e-01, 2.9040e-01, 4.2508e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [7.7778e-01, 1.1218e-01, 1.0981e-01, 4.4139e-02, 1.0000e+02, 1.0000e+02],
        [2.3179e-01, 2.6273e-01, 6.1778e-01, 3.1230e-01, 1.3051e-01, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [4.9919e-01, 8.8737e-01, 7.5314e-01, 8.6046e-01, 4.2014e-01, 7.3911e-01]])
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [15]:
import torch

def build_causal_mask(x, square=False):
    """Build a lower-triangular mask of ones (keep) and zeros (block) from ``x``.

    Args:
        x: tensor of shape (b, t); only its shape, dtype and device are used.
        square: selects which triangle is built.

            * False (default, the original behaviour): ``tril`` of a (b, t)
              block of ones, so entry [i, j] is 1 iff ``j <= i`` where ``i``
              is the *batch* index. Reshaped to (b, 1, 1, t) it hides the same
              keys from every query of a batch element, so it is not a
              per-query causal mask.
            * True: a (t, t) mask whose entry [i, j] is 1 iff key position
              ``j`` is not after query position ``i``. It broadcasts directly
              against (b, h, t, t) attention scores.

    Returns:
        A (b, t) tensor (``square=False``) or a (t, t) tensor (``square=True``)
        with the dtype and device of ``x``.
    """
    if square:
        t = x.size(-1)
        return torch.tril(torch.ones(t, t, dtype=x.dtype, device=x.device))
    return torch.tril(torch.ones_like(x))
# Demo on a (5, 6) input. Note that the rows of the printed triangle are
# indexed by batch element, not by query position.
x = torch.rand(5, 6)

print(build_causal_mask(x))

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [16]:
import torch

def merge_masks(m1, m2):
    """Combine two 0/1 masks by element-wise multiplication.

    With 0/1 entries the product is 1 only where both masks are 1. The usual
    broadcasting rules of ``*`` apply.
    """
    combined = m1 * m2
    return combined

# Demo: combine the padding mask (pad value 100) with the triangular mask
# built from the same (5, 6) tensor.
x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -1] = 100
x[2, -4:] = 100
x[3, :] = 100
print(x)
m1 = build_padding_mask(x, 100)
m2 = build_causal_mask(x)
print(merge_masks(m1, m2))

tensor([[2.8126e-01, 7.2617e-01, 9.9193e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [9.7744e-01, 2.9643e-01, 6.3065e-02, 4.9401e-01, 2.6612e-01, 1.0000e+02],
        [2.2052e-01, 1.2273e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [9.4130e-01, 7.0709e-01, 9.1821e-02, 9.1652e-01, 2.5230e-01, 4.8051e-01]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [17]:
import torch

def reshape_mask(mask):
    """Insert two singleton axes after the batch axis: (b, t) -> (b, 1, 1, t).

    The result broadcasts against (b, h, t, t) attention scores, applying the
    same key mask to every head and every query of a batch element.
    """
    return mask[:, None, None]

# Demo: the (2, 3) triangular mask becomes (2, 1, 1, 3).
x = torch.rand(2, 3)
print(reshape_mask(build_causal_mask(x)))

tensor([[[[1., 0., 0.]]],


        [[[1., 1., 0.]]]])


In [18]:
from torch import nn

class MHSAMasked(nn.Module):
    """Multi-head attention with an optional mask.

    Works for self-attention and for cross-attention: keys/values take their
    batch and length from ``k``, so they may be longer or shorter than the
    queries.

    Args:
        d: model width; must be divisible by ``h`` (checked with ``assert``).
        h: number of heads; each head sees ``dh = d // h`` features.
    """

    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d, self.h = d, h
        self.dh = d // h
        # Projection order is kept so parameter initialisation is unchanged.
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)

    def forward(self, q, k, v, mask = None):
        """Attend from ``q`` (bq, tq, d) to ``k``/``v`` (bk, tk, d).

        ``mask`` is handed unchanged to ``self_attention_masked``. Returns a
        tensor of shape (bq, tq, d).
        """
        bq, tq, dq = q.size()
        bk, tk, dk = k.size()

        proj_q, proj_k, proj_v = self.wq(q), self.wk(k), self.wv(v)

        # Split the feature axis into heads: (batch, len, h, dh). Folding the
        # heads into the batch axis is unnecessary because @ supports 4-D input.
        head_shape = (self.h, self.dh)
        proj_q = proj_q.view(bq, tq, *head_shape)
        proj_k = proj_k.view(bk, tk, *head_shape)
        proj_v = proj_v.view(bk, tk, *head_shape)

        mixed = self_attention_masked(proj_q, proj_k, proj_v, mask=mask)

        # (bq, h, tq, dh) -> (bq, tq, h, dh) -> (bq, tq, d): concatenate the heads.
        merged = mixed.view(bq, self.h, tq, self.dh).transpose(1, 2).contiguous().view(bq, tq, dq)
        return self.wo(merged)

# Demo: the mask is built from a (4, 5) placeholder (only its shape matters),
# then the module runs on a (4, 5, 6) input with 2 heads of 3 features each.
mhsa_masked = MHSAMasked(h = 2, d = 6)
x = torch.rand(4, 5)
mask = reshape_mask(build_causal_mask(x))
print(mask)
x = torch.rand(4, 5, 6)
print(mhsa_masked(x, x, x, mask=mask))
print(mhsa_masked(x, x, x, mask=mask).shape)

tensor([[[[1., 0., 0., 0., 0.]]],


        [[[1., 1., 0., 0., 0.]]],


        [[[1., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 0.]]]])
scaled_prod.shape: 
 torch.Size([4, 2, 5, 5])
scaled_prod: 
 tensor([[[[-0.0373,    -inf,    -inf,    -inf,    -inf],
          [ 0.0716,    -inf,    -inf,    -inf,    -inf],
          [ 0.0067,    -inf,    -inf,    -inf,    -inf],
          [-0.0635,    -inf,    -inf,    -inf,    -inf],
          [-0.1106,    -inf,    -inf,    -inf,    -inf]],

         [[-0.0179,    -inf,    -inf,    -inf,    -inf],
          [ 0.0067,    -inf,    -inf,    -inf,    -inf],
          [-0.0570,    -inf,    -inf,    -inf,    -inf],
          [-0.0437,    -inf,    -inf,    -inf,    -inf],
          [-0.0307,    -inf,    -inf,    -inf,    -inf]]],


        [[[ 0.0274,  0.0643,    -inf,    -inf,    -inf],
          [ 0.0649,  0.1165,    -inf,    -inf,    -inf],
          [ 0.1164,  0.1189,    -inf,    -inf,    -inf],
          [ 0.1601,  0.1688,    -inf,    -inf,   

In [19]:
# Transformer implementation 

In [20]:
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    """One post-norm encoder block with optional attention masking.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``. The
    feed-forward net expands the width to ``4 * d`` with a ReLU in between.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        # Construction order is kept so parameter initialisation is unchanged.
        self.mhsa = MHSAMasked(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, self_mask=None):
        """Transform ``x`` (b, t, d); ``self_mask`` is forwarded to the attention module."""
        # Unpacking doubles as a check that the input is 3-D.
        batch, seq_len, width = x.size()

        # Sub-layer 1: (masked) self-attention with residual and layer norm.
        attended = self.mhsa(x, x, x, mask=self_mask)
        x = self.norm1(x + self.attn_dropout(attended))

        # Sub-layer 2: position-wise feed-forward with residual and layer norm.
        hidden = F.relu(self.ff1(x))
        x = self.norm2(x + self.resid_dropout(self.ff2(hidden)))
        return x


# Smoke test: padding mask for two sequences with pad id 0, reshaped to
# (2, 1, 1, 3), applied to a random (2, 3, 512) input.
encoder_layer = EncoderLayer()
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
x = torch.rand(2, 3, 512)

encoder_layer(x, self_mask=self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[-0.0823, -0.1225,    -inf],
          [-0.0527, -0.0749,    -inf],
          [-0.0348, -0.0261,    -inf]],

         [[-0.0282, -0.1440,    -inf],
          [ 0.0651, -0.1101,    -inf],
          [ 0.0382, -0.1288,    -inf]],

         [[-0.1386, -0.1153,    -inf],
          [-0.0730, -0.0167,    -inf],
          [-0.1262, -0.1576,    -inf]],

         [[-0.0189, -0.0116,    -inf],
          [ 0.0058,  0.0429,    -inf],
          [ 0.0375, -0.0341,    -inf]],

         [[ 0.0412,  0.0224,    -inf],
          [-0.0168, -0.0034,    -inf],
          [-0.0546, -0.0286,    -inf]],

         [[-0.2706, -0.3284,    -inf],
          [-0.0762, -0.0907,    -inf],
          [-0.1119, -0.1862,    -inf]],

         [[ 0.0689,  0.0890,    -inf],
          [ 0.0651,  0.0416,    -inf],
          [ 0.1561,  0.1816,    -inf]],

         [[-0.1237, -0.1848,    -inf],
          [ 0.0351, -0.0222,    -inf],
          [-0.0299, -0.0595,

torch.Size([2, 3, 512])

In [21]:
from torch import nn

class Encoder(nn.Module):
    """Token embedding + positional encoding + a stack of ``n`` encoder layers.

    Args:
        vocab_size: number of rows in the token embedding table.
        n: number of encoder layers.
        d: model width.
        h: number of attention heads per layer.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        # nn.ModuleList (unlike a plain Python list) registers every layer as a
        # submodule, so its parameters appear in parameters()/state_dict() and
        # .to()/.train()/.eval() reach it.
        self.layers = nn.ModuleList([EncoderLayer(d, h) for _ in range(n)])

    def forward(self, x, self_mask = None):
        """Encode token ids ``x`` (b, t) into (b, t, d); ``self_mask`` goes to every layer."""
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, self_mask=self_mask)
        return x

# Smoke test with default sizes. The padding mask comes from a separate
# hand-written id tensor (pad id 0), not from the random ids in x.
encoder = Encoder()
x = torch.randint(0, 2**13, (2, 3))
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
encoder(x, self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[-7.5255e-02,  1.5785e+00,        -inf],
          [-5.1758e-01,  6.9589e-01,        -inf],
          [-2.1072e-01,  6.3987e-02,        -inf]],

         [[-3.6289e-01, -2.6427e-01,        -inf],
          [ 4.2333e-01, -7.2971e-02,        -inf],
          [-4.3477e-01,  2.8622e-01,        -inf]],

         [[ 1.3705e-01, -4.3642e-02,        -inf],
          [-4.6413e-02,  4.1261e-01,        -inf],
          [-3.3340e-01,  8.8788e-01,        -inf]],

         [[ 1.5946e-01, -7.2807e-01,        -inf],
          [ 9.1806e-01,  8.0123e-02,        -inf],
          [ 1.6280e-01, -4.1546e-01,        -inf]],

         [[ 6.8802e-01,  4.9811e-03,        -inf],
          [ 2.0877e-01, -3.8896e-01,        -inf],
          [-1.4905e+00, -9.8413e-01,        -inf]],

         [[ 3.0520e-01,  4.2186e-01,        -inf],
          [ 1.0437e-03,  2.7047e-01,        -inf],
          [ 5.9591e-01,  5.4857e-01,        -inf]],

         

torch.Size([2, 3, 512])

In [22]:
import torch.nn.functional as F

class DecoderLayer(nn.Module):
    """One post-norm decoder block: self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``. The
    feed-forward net expands the width to ``4 * d`` with a ReLU in between.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        # Construction order is kept so parameter initialisation is unchanged.
        # Attention over the decoder's own sequence.
        self.mhsa = MHSAMasked(d=d, h=h)
        self.attn_norm = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)

        # Attention from the decoder sequence to the encoder output.
        self.mhca = MHSAMasked(d=d, h=h)
        self.cross_attn_norm = nn.LayerNorm(d)
        self.cross_attn_dropout = nn.Dropout(dropout)

        # Position-wise feed-forward network.
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.resid_dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d)

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        """Run one decoder block.

        Args:
            dec_x: decoder features, (b, t, d).
            enc_x: encoder output used as keys/values in cross-attention.
            self_mask: mask for decoder self-attention; intended to be the
                merged decoder padding and causal masks.
            cross_mask: mask for cross-attention; intended to be the encoder
                padding mask so padded encoder tokens are not attended to.

        Returns:
            Tensor with the same shape as ``dec_x``.
        """
        # Unpacking doubles as a check that the decoder input is 3-D.
        batch, seq_len, width = dec_x.size()

        self_attended = self.mhsa(dec_x, dec_x, dec_x, mask=self_mask)
        x = self.attn_norm(dec_x + self.attn_dropout(self_attended))

        cross_attended = self.mhca(x, enc_x, enc_x, mask=cross_mask)
        x = self.cross_attn_norm(x + self.cross_attn_dropout(cross_attended))

        ff_out = self.ff2(F.relu(self.ff1(x)))
        x = self.norm(x + self.resid_dropout(ff_out))
        return x


# Smoke test on random (3, 3, 16) decoder and encoder features.
decoder_layer = DecoderLayer(h=2, d=16)
x = torch.rand(3, 3, 16)
y = torch.rand(3, 3, 16)
# Decoder self-attention mask: padding mask (pad id 0) merged with the
# triangular mask, both built from the same hand-written id tensor.
self_mask1 = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]), pad_token=0)
self_mask2 = build_causal_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]))
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

# Cross-attention mask: padding mask of the (hand-written) encoder ids.
cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
decoder_layer(x, y, self_mask=self_mask, cross_mask=cross_mask).shape

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[-0.0805,    -inf,    -inf],
          [ 0.0826,    -inf,    -inf],
          [-0.0313,    -inf,    -inf]],

         [[ 0.1963,    -inf,    -inf],
          [ 0.1676,    -inf,    -inf],
          [ 0.1429,    -inf,    -inf]]],


        [[[-0.0006,    -inf,    -inf],
          [ 0.0163,    -inf,    -inf],
          [-0.0237,    -inf,    -inf]],

         [[ 0.1370,    -inf,    -inf],
          [ 0.0370,    -inf,    -inf],
          [ 0.0983,    -inf,    -inf]]],


        [[[ 0.0224,  0.0335,    -inf],
          [ 0.0327,  0.0037,    -inf],
          [ 0.0672,  0.0623,    -inf]],

         [[ 0.0396,  0.1272,    -inf],
          [-0.0054,  0.1835,    -inf],
          [-0.0210,  0.1624,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[1

torch.Size([3, 3, 16])

In [23]:
from torch import nn

class Decoder(nn.Module):
    """Token embedding + positional encoding + a stack of ``n`` decoder layers.

    Args:
        vocab_size: number of rows in the token embedding table.
        n: number of decoder layers.
        d: model width.
        h: number of attention heads per layer.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        # nn.ModuleList (unlike a plain Python list) registers every layer as a
        # submodule, so its parameters appear in parameters()/state_dict() and
        # .to()/.train()/.eval() reach it.
        self.layers = nn.ModuleList([DecoderLayer(d, h) for _ in range(n)])

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        """Decode token ids ``dec_x`` (b, t) against encoder output ``enc_x``.

        Both masks are passed unchanged to every layer. Returns (b, t, d).
        """
        b, t = dec_x.size()
        x = self.embed(dec_x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, enc_x, self_mask=self_mask, cross_mask=cross_mask)
        return x

    def get_embed_weights(self):
        """Return the token embedding matrix (the Parameter itself, not a copy)."""
        return self.embed.weight

# Smoke test with a small decoder; id 0 is treated as padding.
decoder = Decoder(vocab_size=32, n=2, d=16, h=2)
# x = torch.randint(0, 32, (2, 3))
x = torch.tensor([[15, 7, 0], [10, 0, 0], [1, 3, 0]])
# Stand-in for the encoder output: random (3, 3, 16) features.
y = torch.rand(3, 3, 16)

# Decoder self-attention mask: padding mask merged with the triangular mask.
self_mask1 = build_padding_mask(x, pad_token=0)
self_mask2 = build_causal_mask(x)
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

# Cross-attention mask: padding mask of hand-written encoder ids.
cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
print(decoder(x, y, self_mask=self_mask, cross_mask=cross_mask).shape)

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[-0.3712,    -inf,    -inf],
          [-0.3560,    -inf,    -inf],
          [-0.3503,    -inf,    -inf]],

         [[ 0.3399,    -inf,    -inf],
          [ 0.0327,    -inf,    -inf],
          [ 0.1566,    -inf,    -inf]]],


        [[[-0.3375,    -inf,    -inf],
          [ 0.3308,    -inf,    -inf],
          [ 0.0699,    -inf,    -inf]],

         [[-0.3200,    -inf,    -inf],
          [ 0.2670,    -inf,    -inf],
          [ 0.0720,    -inf,    -inf]]],


        [[[-0.1735,  0.0916,    -inf],
          [ 0.0646,  0.2738,    -inf],
          [-0.3096, -0.6128,    -inf]],

         [[ 0.1896, -0.1227,    -inf],
          [ 0.2528,  0.3859,    -inf],
          [-0.4951, -0.2353,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[1

In [24]:
from torch import nn

class Output(nn.Module):
    """Final linear projection from model width ``d`` to ``vocab_size`` logits.

    Args:
        vocab_size: number of output logits.
        d: width of the incoming features.
        ff_weight: optional weight to share with the projection (weight tying
            with the decoder embedding). When given it replaces the freshly
            initialised weight of the linear layer; the bias stays separate.
    """

    def __init__(self, vocab_size: int = 2**13, d: int = 512, ff_weight = None):
        super().__init__()
        projection = nn.Linear(d, vocab_size)
        if ff_weight is not None:
            # Share the very same Parameter object instead of copying it.
            projection.weight = ff_weight
        self.ff = projection

    def forward(self, x):
        """Map features ``x`` (..., d) to logits (..., vocab_size)."""
        logits = self.ff(x)
        return logits

In [25]:
from torch import nn

class Transformer(nn.Module):
    """Encoder-decoder model that ends in a projection to vocabulary logits.

    Args:
        vocab_size: vocabulary size shared by encoder, decoder and output.
        n: number of layers in the encoder and in the decoder.
        d: model width.
        h: number of attention heads.
        embed_tying: when true, the output projection shares its weight with
            the decoder's token embedding.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8, embed_tying=True):
        super().__init__()
        self.encoder = Encoder(vocab_size=vocab_size, n=n, d=d, h=h)
        self.decoder = Decoder(vocab_size=vocab_size, n=n, d=d, h=h)

        output_kwargs = {"vocab_size": vocab_size, "d": d}
        if embed_tying:
            output_kwargs["ff_weight"] = self.decoder.get_embed_weights()
        self.output = Output(**output_kwargs)

    def forward(self, enc_x, dec_x, enc_mask=None, dec_mask=None):
        """Return logits for ``dec_x`` given the source sequence ``enc_x``.

        ``enc_mask`` is used both for encoder self-attention and as the
        decoder's cross-attention mask; ``dec_mask`` is the decoder's
        self-attention mask.
        """
        memory = self.encoder(enc_x, enc_mask)
        hidden = self.decoder(dec_x=dec_x, enc_x=memory, self_mask=dec_mask, cross_mask=enc_mask)
        return self.output(hidden)

# Smoke test of the full model with separate (untied) output weights; id 0 is padding.
transformer = Transformer(vocab_size=32, n=2, d=16, h=2, embed_tying=False)
enc_x = torch.tensor([[15, 7, 3], [10, 10, 0], [1, 0, 0]])
dec_x = torch.tensor([[21, 8, 0, 0], [25, 0, 0, 0], [8, 1, 2, 3]])
# dec_x = torch.tensor([[21, 8], [25, 0], [8, 1]])

# Encoder padding mask, reshaped to (b, 1, 1, t_enc).
enc_mask = build_padding_mask(enc_x, pad_token=0)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

# Decoder mask: padding mask merged with the triangular mask, reshaped to (b, 1, 1, t_dec).
dec_mask1 = build_padding_mask(dec_x, pad_token=0)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

print(transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask).shape)

enc_mask: 
 tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 1, 0]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[ 1.6742e+00, -9.9728e-04,  1.0976e+00],
          [ 1.5374e+00,  8.7139e-01,  1.7909e-01],
          [ 2.0514e+00,  1.0262e+00,  5.0747e-01]],

         [[ 2.6828e-01,  5.7621e-01,  3.8235e-02],
          [ 1.3697e+00,  1.6423e+00,  2.0728e+00],
          [ 8.7461e-01,  1.0193e-01,  6.4389e-01]]],


        [[[ 4.5975e-01,  3.5872e-01,        -inf],
          [ 7.1827e-01,  5.7135e-01,        -inf],
          [ 4.1149e-01,  4.6960e-01,        -inf]],

         [[ 2.8763e-01,  3.9224e-01,        -inf],
          [ 1.1159e+00,  9.6830e-01,        -inf],
          [-1.2512e-01, -1.8677e-01,        -inf]]],


        [[[ 1.8417e+00,        -inf,        -inf],
          [ 2.9988e-01,        -inf,        -inf],
          [ 4.3545e-01,        -inf,        -inf]],

         [[

In [26]:
# Inference

In [27]:
import tiktoken
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


encoding = tiktoken.get_encoding("cl100k_base")
# The end-of-text id doubles as the padding id. One name is used for both the
# padding and the padding masks so they cannot drift apart (the masks used to
# hard-code the number).
PAD_TOKEN = encoding.eot_token

sents = ["Hello World", "This is a simple sentence", "Me"]
encoded_sents = [encoding.encode(s) for s in sents]
enc_x = pad_sequence([torch.tensor(es) for es in encoded_sents], batch_first=True, padding_value=PAD_TOKEN)
print(enc_x)
dec_sents = ["Bonjour", "C'est une phrase", "START"]
dec_encoded_sents = [encoding.encode(s) for s in dec_sents]
dec_x = pad_sequence([torch.tensor(es) for es in dec_encoded_sents], batch_first=True, padding_value=PAD_TOKEN)
print(dec_x)

transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)
# Inference only: eval() puts registered submodules (e.g. dropout) into
# inference behaviour, and no_grad() below skips building the autograd graph.
transformer.eval()

enc_mask = build_padding_mask(enc_x, pad_token=PAD_TOKEN)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

dec_mask1 = build_padding_mask(dec_x, pad_token=PAD_TOKEN)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

with torch.no_grad():
    output = transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask)
print(f"output shape: {output.shape}")
# Softmax is only needed to display probabilities; argmax over the logits
# would select the same tokens.
softmaxed = F.softmax(output, dim=-1)
print(f"softmaxed[0, 0, :10]: {softmaxed[0, 0, :10]}")
predicted = softmaxed.argmax(dim=-1)
print(f"predicted: \n {predicted}")

predicted_list = predicted.tolist()
predicted_decoded = [encoding.decode(token_ids) for token_ids in predicted_list]
print(f"predicted decoded: \n {predicted_decoded}")

tensor([[  9906,   4435, 100257, 100257, 100257],
        [  2028,    374,    264,   4382,  11914],
        [  7979, 100257, 100257, 100257, 100257]])
tensor([[ 82681, 100257, 100257, 100257],
        [    34,  17771,   6316,  17571],
        [ 23380, 100257, 100257, 100257]])
enc_mask: 
 tensor([[1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]])
scaled_prod.shape: 
 torch.Size([3, 4, 5, 5])
scaled_prod: 
 tensor([[[[-3.2755e-01,  1.3563e-01,        -inf,        -inf,        -inf],
          [ 3.5796e-01,  2.8945e-01,        -inf,        -inf,        -inf],
          [ 2.2562e-01, -9.1705e-02,        -inf,        -inf,        -inf],
          [ 2.8101e-02, -1.2289e-02,        -inf,        -inf,        -inf],
          [ 2.7247e-01,  2.5949e-01,        -inf,        -inf,        -inf]],

         [[ 3.9820e-01,  9.5041e-01,        -inf,        -inf,        -inf],
          [-2.7562e-02,  2.061

In [36]:
# Scratch checks for the decoding loop. Only the last expression of the cell
# is displayed; the softmax and argmax results below are computed and dropped.
F.softmax(torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.3]]), dim=-1)
torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.4]]).argmax(dim=-1)
# torch.max(torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.4]]), dim=-1)

# Concatenate along the last axis and read the final element of the last row.
t1 = torch.tensor([[0.1, 0.2]])
t2 = torch.tensor([[0.3]])
torch.cat((t1, t2), dim=1).tolist()[-1][-1]


0.30000001192092896

In [39]:
# Predicting next words
# Greedy next-token prediction (the model is freshly initialised, so the
# generated text is not meaningful).
sent = "This is a simple sentence"
encoded_sent = encoding.encode(sent)
enc_x = torch.tensor(encoded_sent).unsqueeze(0)          # (1, t_enc)
dec_x = torch.tensor(encoding.encode("C")).unsqueeze(0)  # (1, t_dec): decoder prefix

transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)
# Inference only: eval() puts registered submodules (e.g. dropout) into
# inference behaviour; no_grad() skips building the autograd graph.
transformer.eval()

predicted_tokens = []
with torch.no_grad():
    for _ in range(5):
        t_dec = dec_x.size(1)
        # (t_dec, t_dec) lower-triangular mask: query i may only look at keys
        # j <= i. It broadcasts against the (b, h, t, t) attention scores.
        dec_mask = torch.tril(torch.ones(t_dec, t_dec))
        output = transformer(enc_x=enc_x, dec_x=dec_x, dec_mask=dec_mask)  # (1, t_dec, vocab)
        # Only the last position predicts the *next* token. Appending the
        # predictions for every position (as before) doubled dec_x each step.
        # argmax over the logits equals argmax over their softmax.
        next_token = output[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        predicted_tokens.append(next_token.item())
        dec_x = torch.cat((dec_x, next_token), dim=-1)

print(predicted_tokens)
print(f"predicted sentence: \n {encoding.decode(predicted_tokens)}")

scaled_prod.shape: 
 torch.Size([1, 4, 5, 5])
scaled_prod: 
 tensor([[[[ 7.9512e-01,  9.0694e-01,  4.5273e-01, -3.0508e-01, -3.6818e-01],
          [ 4.1684e-01,  4.8840e-01,  4.2983e-01, -8.7179e-02, -5.3285e-02],
          [ 9.3448e-01,  3.5573e-04, -1.9414e-02,  6.3703e-01, -4.1700e-01],
          [ 2.8722e-02,  7.6172e-01, -9.0842e-02, -1.7895e-01,  5.8761e-01],
          [-5.0746e-01,  1.9543e-01, -2.0671e-01, -1.4951e-01, -5.1422e-01]],

         [[ 2.5287e-01,  8.4620e-03, -2.4208e-01, -1.3033e-01,  7.7651e-01],
          [-8.5219e-01,  1.9928e-02, -6.2521e-01,  8.4488e-02, -7.2770e-01],
          [ 5.0926e-01,  5.2909e-01,  5.9945e-01,  3.4674e-01,  1.4391e+00],
          [ 6.4368e-01,  4.1816e-01,  4.9041e-01, -2.7250e-01,  4.0249e-01],
          [-2.5678e-01,  2.8006e-01,  5.6157e-01,  9.0215e-01,  3.0465e-01]],

         [[-1.2031e+00, -1.2954e+00, -8.8688e-01,  7.7508e-03, -7.1005e-01],
          [-8.4900e-01, -1.1717e+00, -7.0141e-01, -1.9813e-01, -5.9729e-01],
          [

In [30]:
# NOTE(review): this always raises AssertionError. Presumably it is a deliberate
# stop so that "Run All" does not continue into the scratch cell below;
# confirm, or delete this cell.
assert False

AssertionError: 

In [None]:
# Scratch check of masked_fill broadcasting on a (1, 2, 3) tensor.
x = torch.rand([1, 2, 3])
mask = torch.ones([1, 2])
mask[0, 1] = 0
# NOTE(review): after unsqueeze(1) the mask is (1, 1, 2) while x is (1, 2, 3).
# The trailing sizes 2 and 3 do not broadcast, so masked_fill below is expected
# to raise. unsqueeze(-1) would give (1, 2, 1) and mask whole time steps.
mask = mask.unsqueeze(1)
print(mask == 0)
x.masked_fill(mask == 0, float("-inf"))