In [1]:
# Unmasked attention

In [2]:
import torch
import math
import torch.nn.functional as F

def self_attention(q, k, v):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b t h dh
    # q = q.permute(0, 2, 1, 3)
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    print(softmaxed_prod.shape)
    # print(softmaxed_prod)
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


x = torch.rand([2, 3, 4, 5])
self_attention(x, x, x)
self_attention(x, x, x).shape

  cpu = _conversion_method_template(device=torch.device("cpu"))


torch.Size([2, 4, 3, 3])
torch.Size([2, 4, 3, 3])


torch.Size([2, 4, 3, 5])

In [3]:
from torch import nn

class MHSA(nn.Module):
    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)
 
    def forward(self, q, k, v):
        # b, t, d
        b, t, d = q.size()
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        wq = wq.view(b, t, self.h, self.dh)
        wk = wk.view(b, t, self.h, self.dh)
        wv = wv.view(b, t, self.h, self.dh)
        # b, t, h, dh
        # if changing from 4 dim -> 3 dim: b*h, t, dh
        # wq = wq.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wk = wk.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wv = wv.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # another option 4 dim -> 3 dim
        # wq = wq.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wk = wk.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wv = wv.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # changing the number of dims is not necessary as @ supports 4 dims
        attn = self_attention(wq, wk, wv)
        # b * h, t, dh
        # attn = attn.view(b, self.h, t, self.dh).permute(0, 2, 1, 3).reshape(b, t, d)
        attn = attn.view(b, self.h, t, self.dh).transpose(1, 2).contiguous().view(b, t, d)
        wo = self.wo(attn)
        return wo
        # # 1 2 3 4
        # x = F.relu(self.conv1(x))
        # return F.relu(self.conv2(x))

mhsa = MHSA()
x = torch.rand(2, 3, 512)
mhsa(x, x, x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [4]:
# PE

In [5]:
from torch import nn

class PE(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.dropout = nn.Dropout(p=dropout)

        twoi = torch.arange(0, self.d, 2)
        pow_ = torch.pow(10000, twoi / self.d)
        position = torch.arange(0, max_len).unsqueeze(1)
        sin_p = torch.sin(position / pow_)
        cos_p = torch.cos(position / pow_)
        pe = torch.zeros(max_len, self.d, requires_grad=False) # Explicit, register buffer insures requires grad = False
        pe[:, 0::2] = sin_p
        pe[:, 1::2] = cos_p
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe) 

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.pe[:, : t, :]
        return self.dropout(x)
print(PE(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [6]:
from torch import nn

class PEEmbed(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.pe = nn.Embedding(max_len, d)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        b, t, d = x.size()
        pos = self.pe(torch.arange(t))
        x = x + pos
        return self.dropout(x)
print(PEEmbed(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [7]:
# Encoder without mask

In [8]:
import torch.nn.functional as F

class EncoderLayerWithoutMask(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.attn_dropout(self.mhsa(x, x, x))
        x = self.norm1(x)
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm2(x)
        return x


encoder_layer = EncoderLayerWithoutMask()
x = torch.rand(2, 3, 512)
encoder_layer(x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [9]:
from torch import nn

class EncoderWithoutMask(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [EncoderLayerWithoutMask(d, h) for _ in range(n)]

    def forward(self, x):
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x)
        return x

encoder = EncoderWithoutMask()
x = torch.randint(0, 2**13, (2, 3))
encoder(x).shape

torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [10]:
# With masks
import torch
import math
import torch.nn.functional as F

def self_attention_masked(q, k, v, mask=None):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b t h dh:
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    print(f"scaled_prod.shape: \n {scaled_prod.shape}")
    # mask should be in shape to be broadcastable to bhts and lead to masked keys only (last s dim)
    if mask is not None:
        scaled_prod = scaled_prod.masked_fill(mask == 0, float("-inf"))
    print(f"scaled_prod: \n {scaled_prod}")
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    # print(softmaxed_prod.shape)
    print(f"softmaxed_prod: \n {softmaxed_prod}")
    # swap h and t in v
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


In [11]:
# Mask

In [12]:
# play with mask

x = torch.rand([2, 3, 2, 4])
print(x)
# mask 2 batches 3 timeseries
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
print(f"mask: \n {mask}")
# add head dim to make mask broatcastable to q x k.T prod. mask shape 2, 1, 3
mask = mask.unsqueeze(1)


# mask = mask.permute(0, 2, 1)
# is the mask that I need? keys are ignored?
print(f"wrong mask: \n {mask}")
#  mask = 2 1 3 -> b prepended before broadcasting (1!!!) h (remains since already 2) t (broadcasted from 1) d (remains since already 3) 
print(f"wrong mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"wrong a: \n {a}" )
print(f"wrong a.shape: \n {a.shape}")
# leads to wrong attention since the shape of mask is wrong 2 1 3 

# correct mask
# mask 2 batches 3 timeseries
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
mask = mask.unsqueeze(1).unsqueeze(1)

print(f"mask: \n {mask}")
#  mask = 2 1 1 3 -> b (remains already 2) h (broadcasted from 1) t (broadcasted from 1) d (remains since already 3) 
print(f"mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")


tensor([[[[0.6164, 0.1778, 0.9353, 0.4388],
          [0.6152, 0.1659, 0.4717, 0.4772]],

         [[0.3524, 0.0751, 0.3100, 0.9328],
          [0.0744, 0.9578, 0.5731, 0.7566]],

         [[0.1704, 0.1385, 0.5422, 0.5292],
          [0.7538, 0.3801, 0.4732, 0.1390]]],


        [[[0.6504, 0.3352, 0.7541, 0.7502],
          [0.6255, 0.1920, 0.5531, 0.2186]],

         [[0.5849, 0.3492, 0.8659, 0.2505],
          [0.1802, 0.2976, 0.4366, 0.5256]],

         [[0.1058, 0.1050, 0.5552, 0.5796],
          [0.2178, 0.4177, 0.6545, 0.5954]]]])
mask: 
 tensor([[1., 1., 0.],
        [1., 0., 0.]])
wrong mask: 
 tensor([[[1., 1., 0.]],

        [[1., 0., 0.]]])
wrong mask broadcast: 
 tensor([[[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]]])
scaled_prod.shape

In [13]:
# mask is equal to making keys on masked places 0:
# the result in terms of masked symbols is the same
k = x.clone()
k[0, 2, 0, :] = float("-inf")
k[0, 2, 1, :] = float("-inf")
k[1, 2, 0, :] = float("-inf")
k[1, 1, 0, :] = float("-inf")
k[1, 2, 1, :] = float("-inf")
k[1, 1, 1, :] = float("-inf")
print(f"k: \n {k}")
a = self_attention_masked(x, k, x)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")
# a is the same shape as if mask was applied in q * k:

test = torch.rand([2, 3, 4])
test[0, 2, :] = 0
test[1, 1, :] = 0
test[1, 2, :] = 0

print(f"test: \n {test}")
test_v = test.view(2, 3, 2, 2)
print(f"test_v: \n {test_v}")
test_perm = test_v.permute(0, 2, 1, 3)
print(f"test_perm: \n {test_perm}")

# or like that:
test_q = torch.rand([2, 3, 4])
test_k = test_q.clone()
test_k[0, 2, :] = float("-inf")
test_k[1, 1, :] = float("-inf")
test_k[1, 2, :] = float("-inf")
print(f"test_k: \n {test_k}")

test_q_view = test_q.view(2, 3, 2, 2)
test_k_view = test_k.view(2, 3, 2, 2)
print(f"test_k_view: \n {test_k_view}")
test_q_perm = test_q_view.permute(0, 2, 1, 3)
test_k_perm = test_k_view.permute(0, 2, 1, 3)
print(f"test_k_perm: \n {test_k_perm}")
print(f"q * k: \n {torch.einsum("bhtd, bhsd -> bhts", test_q_perm, test_k_perm)}")

k: 
 tensor([[[[0.6164, 0.1778, 0.9353, 0.4388],
          [0.6152, 0.1659, 0.4717, 0.4772]],

         [[0.3524, 0.0751, 0.3100, 0.9328],
          [0.0744, 0.9578, 0.5731, 0.7566]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]],


        [[[0.6504, 0.3352, 0.7541, 0.7502],
          [0.6255, 0.1920, 0.5531, 0.2186]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]]])
scaled_prod.shape: 
 torch.Size([2, 2, 3, 3])
scaled_prod: 
 tensor([[[[0.7395, 0.4649,   -inf],
          [0.4649, 0.5480,   -inf],
          [0.4345, 0.3661,   -inf]],

         [[0.4281, 0.4180,   -inf],
          [0.4180, 0.9119,   -inf],
          [0.4082, 0.3983,   -inf]]],


        [[[0.8334,   -inf,   -inf],
          [0.6692,   -inf,   -inf],
          [0.4787,   -inf,   -inf]],

         [[0.3909,   -inf,   -inf],
          [0.2631,   -i

In [14]:
import torch

def build_padding_mask(x, pad_token):
    # x: b t shape
    mask = torch.ones_like(x)
    return mask.masked_fill(x == pad_token, 0)

x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -2:] = 100
x[2, -1] = 100
x[3, :] = 100
print(x)
print(build_padding_mask(x, 100))

tensor([[7.4148e-01, 2.2639e-02, 1.2796e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [3.5107e-01, 3.3189e-02, 6.5396e-01, 6.8129e-01, 1.0000e+02, 1.0000e+02],
        [7.0389e-01, 6.7112e-01, 2.1144e-01, 4.4746e-02, 1.7409e-01, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [6.0043e-01, 8.4909e-01, 9.0186e-02, 8.3443e-01, 5.7014e-01, 1.9058e-01]])
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [15]:
import torch

def build_causal_mask(x):
    # x: b t shape
    m = torch.ones_like(x)
    return torch.tril(m)
x = torch.rand(5, 6)

print(build_causal_mask(x))

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [16]:
import torch

def merge_masks(m1, m2):
    return m1 * m2

x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -1] = 100
x[2, -4:] = 100
x[3, :] = 100
print(x)
m1 = build_padding_mask(x, 100)
m2 = build_causal_mask(x)
print(merge_masks(m1, m2))

tensor([[  0.1822,   0.9773,   0.3937, 100.0000, 100.0000, 100.0000],
        [  0.6297,   0.4831,   0.5250,   0.1028,   0.3756, 100.0000],
        [  0.3837,   0.6903, 100.0000, 100.0000, 100.0000, 100.0000],
        [100.0000, 100.0000, 100.0000, 100.0000, 100.0000, 100.0000],
        [  0.4342,   0.5863,   0.4891,   0.4150,   0.5842,   0.6702]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [17]:
import torch

def reshape_mask(mask):
    # b t -> b 1 1 t (to be broadcastable to b h t t)
    return mask.unsqueeze(1).unsqueeze(1)

x = torch.rand(2, 3)
print(reshape_mask(build_causal_mask(x)))

tensor([[[[1., 0., 0.]]],


        [[[1., 1., 0.]]]])


In [18]:
from torch import nn

class MHSAMasked(nn.Module):
    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)
 
    def forward(self, q, k, v, mask = None):
        # q and k/v might be of different sizes if lengths of decoder and encoders inputs are different
        bq, tq, dq = q.size()
        bk, tk, dk = k.size()
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        wq = wq.view(bq, tq, self.h, self.dh)
        wk = wk.view(bk, tk, self.h, self.dh)
        wv = wv.view(bk, tk, self.h, self.dh)
        # b, t, h, dh
        # if changing from 4 dim -> 3 dim: b*h, t, dh
        # wq = wq.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wk = wk.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wv = wv.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # another option 4 dim -> 3 dim
        # wq = wq.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wk = wk.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wv = wv.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # changing the number of dims is not necessary as @ supports 4 dims
        attn = self_attention_masked(wq, wk, wv, mask=mask)
        # b * h, t, dh
        # attn = attn.view(b, self.h, t, self.dh).permute(0, 2, 1, 3).reshape(b, t, d)
        attn = attn.view(bq, self.h, tq, self.dh).transpose(1, 2).contiguous().view(bq, tq, dq)
        wo = self.wo(attn)
        return wo
        # # 1 2 3 4
        # x = F.relu(self.conv1(x))
        # return F.relu(self.conv2(x))

mhsa_masked = MHSAMasked(h = 2, d = 6)
x = torch.rand(4, 5)
mask = reshape_mask(build_causal_mask(x))
print(mask)
x = torch.rand(4, 5, 6)
print(mhsa_masked(x, x, x, mask=mask))
print(mhsa_masked(x, x, x, mask=mask).shape)

tensor([[[[1., 0., 0., 0., 0.]]],


        [[[1., 1., 0., 0., 0.]]],


        [[[1., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 0.]]]])
scaled_prod.shape: 
 torch.Size([4, 2, 5, 5])
scaled_prod: 
 tensor([[[[ 0.0537,    -inf,    -inf,    -inf,    -inf],
          [ 0.1387,    -inf,    -inf,    -inf,    -inf],
          [ 0.1328,    -inf,    -inf,    -inf,    -inf],
          [ 0.0958,    -inf,    -inf,    -inf,    -inf],
          [ 0.2381,    -inf,    -inf,    -inf,    -inf]],

         [[-0.0215,    -inf,    -inf,    -inf,    -inf],
          [-0.0674,    -inf,    -inf,    -inf,    -inf],
          [-0.2111,    -inf,    -inf,    -inf,    -inf],
          [-0.1404,    -inf,    -inf,    -inf,    -inf],
          [-0.2935,    -inf,    -inf,    -inf,    -inf]]],


        [[[ 0.0078,  0.0492,    -inf,    -inf,    -inf],
          [ 0.0057,  0.0968,    -inf,    -inf,    -inf],
          [-0.0754,  0.0486,    -inf,    -inf,    -inf],
          [-0.0710,  0.0382,    -inf,    -inf,   

In [19]:
# Transformer implementation 

In [20]:
import torch.nn.functional as F

class EncoderLayer(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSAMasked(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, self_mask=None):
        b, t, d = x.size()
        x = x + self.attn_dropout(self.mhsa(x, x, x, mask=self_mask))
        x = self.norm1(x)
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm2(x)
        return x


encoder_layer = EncoderLayer()
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
x = torch.rand(2, 3, 512)

encoder_layer(x, self_mask=self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[-0.0754, -0.1314,    -inf],
          [-0.0970, -0.1964,    -inf],
          [ 0.0464, -0.0945,    -inf]],

         [[ 0.0843,  0.0848,    -inf],
          [-0.0227, -0.0228,    -inf],
          [ 0.0956,  0.0830,    -inf]],

         [[-0.0293, -0.0763,    -inf],
          [-0.0226, -0.0955,    -inf],
          [-0.0309, -0.0531,    -inf]],

         [[-0.0876,  0.0441,    -inf],
          [-0.0258,  0.0927,    -inf],
          [-0.0357,  0.0175,    -inf]],

         [[ 0.0636,  0.1356,    -inf],
          [ 0.1351,  0.1470,    -inf],
          [-0.0188,  0.0551,    -inf]],

         [[-0.1990,  0.0022,    -inf],
          [-0.1659,  0.0083,    -inf],
          [-0.2187, -0.0992,    -inf]],

         [[ 0.0008, -0.0380,    -inf],
          [ 0.0258, -0.0592,    -inf],
          [-0.0053, -0.0550,    -inf]],

         [[ 0.0972,  0.2530,    -inf],
          [ 0.1754,  0.2615,    -inf],
          [ 0.0913,  0.2150,

torch.Size([2, 3, 512])

In [21]:
from torch import nn

class Encoder(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [EncoderLayer(d, h) for _ in range(n)]

    def forward(self, x, self_mask = None):
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, self_mask=self_mask)
        return x

encoder = Encoder()
x = torch.randint(0, 2**13, (2, 3))
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
encoder(x, self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[-0.9491, -0.3482,    -inf],
          [-0.0322,  0.4981,    -inf],
          [-0.5286, -0.2884,    -inf]],

         [[-0.0396,  0.2085,    -inf],
          [ 0.4922,  0.1172,    -inf],
          [ 0.9767, -0.5614,    -inf]],

         [[-0.7110,  0.8074,    -inf],
          [ 0.8500,  0.3379,    -inf],
          [-0.1594,  0.0954,    -inf]],

         [[ 0.5741,  0.1972,    -inf],
          [ 0.0464,  1.7148,    -inf],
          [ 0.4335,  0.0948,    -inf]],

         [[-0.0058,  0.3594,    -inf],
          [ 0.7504,  0.3286,    -inf],
          [ 0.9123,  0.7980,    -inf]],

         [[-0.3001,  0.3517,    -inf],
          [ 0.6769,  0.3782,    -inf],
          [-0.0225, -0.4206,    -inf]],

         [[-0.2361,  0.3447,    -inf],
          [ 0.8986,  0.4357,    -inf],
          [ 0.5661, -0.0671,    -inf]],

         [[-0.1737, -0.1760,    -inf],
          [ 0.3013,  0.0356,    -inf],
          [ 0.1016, -0.2116,

torch.Size([2, 3, 512])

In [22]:
import torch.nn.functional as F

class DecoderLayer(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSAMasked(d=d, h=h)
        self.attn_norm = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)

        self.mhca = MHSAMasked(d=d, h=h)
        self.cross_attn_norm = nn.LayerNorm(d)
        self.cross_attn_dropout = nn.Dropout(dropout)
        
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.resid_dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d)
        

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        # self_mask is merged decoders padding and causal masks
        # cross_mask is equal to endcoders padding mask because we don't want to attend to encoded padded tokens
        b, t, d = dec_x.size()
        x = dec_x + self.attn_dropout(self.mhsa(dec_x, dec_x, dec_x, mask=self_mask))
        x = self.attn_norm(x)

        x = x + self.cross_attn_dropout(self.mhca(x, enc_x, enc_x, mask=cross_mask))
        x = self.cross_attn_norm(x)
        
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm(x)
        return x


decoder_layer = DecoderLayer(h=2, d=16)
x = torch.rand(3, 3, 16)
y = torch.rand(3, 3, 16)
self_mask1 = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]), pad_token=0)
self_mask2 = build_causal_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]))
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
decoder_layer(x, y, self_mask=self_mask, cross_mask=cross_mask).shape

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[ 0.0310,    -inf,    -inf],
          [-0.1026,    -inf,    -inf],
          [-0.0142,    -inf,    -inf]],

         [[ 0.1056,    -inf,    -inf],
          [ 0.0850,    -inf,    -inf],
          [ 0.0925,    -inf,    -inf]]],


        [[[-0.0730,    -inf,    -inf],
          [-0.0086,    -inf,    -inf],
          [-0.0795,    -inf,    -inf]],

         [[ 0.0661,    -inf,    -inf],
          [-0.0332,    -inf,    -inf],
          [ 0.0974,    -inf,    -inf]]],


        [[[-0.0357, -0.0447,    -inf],
          [-0.0492, -0.0687,    -inf],
          [ 0.0827,  0.0655,    -inf]],

         [[ 0.0153,  0.0570,    -inf],
          [-0.0090, -0.0014,    -inf],
          [-0.1045,  0.0419,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[1

torch.Size([3, 3, 16])

In [23]:
from torch import nn

class Decoder(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [DecoderLayer(d, h) for _ in range(n)]

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        b, t = dec_x.size()
        x = self.embed(dec_x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, enc_x, self_mask=self_mask, cross_mask=cross_mask)
        return x

    def get_embed_weights(self):
        return self.embed.weight

decoder = Decoder(vocab_size=32, n=2, d=16, h=2)
# x = torch.randint(0, 32, (2, 3))
x = torch.tensor([[15, 7, 0], [10, 0, 0], [1, 3, 0]])
y = torch.rand(3, 3, 16)

self_mask1 = build_padding_mask(x, pad_token=0)
self_mask2 = build_causal_mask(x)
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
print(decoder(x, y, self_mask=self_mask, cross_mask=cross_mask).shape)

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[-0.3470,    -inf,    -inf],
          [ 0.2680,    -inf,    -inf],
          [ 0.0472,    -inf,    -inf]],

         [[-0.4876,    -inf,    -inf],
          [-0.5241,    -inf,    -inf],
          [-0.0243,    -inf,    -inf]]],


        [[[-0.3846,    -inf,    -inf],
          [ 0.0702,    -inf,    -inf],
          [ 0.0368,    -inf,    -inf]],

         [[ 1.5445,    -inf,    -inf],
          [ 0.2934,    -inf,    -inf],
          [ 0.2163,    -inf,    -inf]]],


        [[[-1.3250, -0.1522,    -inf],
          [-1.0179, -0.2387,    -inf],
          [-0.2871,  0.0590,    -inf]],

         [[-0.2381,  0.2220,    -inf],
          [-0.9814,  0.1479,    -inf],
          [ 0.0563,  0.3812,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[1

In [24]:
from torch import nn

class Output(nn.Module):

    def __init__(self, vocab_size: int = 2**13, d: int = 512, ff_weight = None):
        super().__init__()
        self.ff = nn.Linear(d, vocab_size)
        # weight tying with the decoder embedding
        if ff_weight is not None:
            self.ff.weight = ff_weight

    def forward(self, x):
        return self.ff(x)

In [25]:
from torch import nn

class Transformer(nn.Module):
    
    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8, embed_tying=True):
        super().__init__()
        self.encoder = Encoder(vocab_size=vocab_size, n=n, d=d, h=h)
        self.decoder = Decoder(vocab_size=vocab_size, n=n, d=d, h=h)
        if embed_tying:
            self.output = Output(vocab_size=vocab_size, d=d, ff_weight = self.decoder.get_embed_weights())
        else:
            self.output = Output(vocab_size=vocab_size, d=d)

    def forward(self, enc_x, dec_x, enc_mask=None, dec_mask=None):
        encoded = self.encoder(enc_x, enc_mask)
        decoded = self.decoder(dec_x=dec_x, enc_x=encoded, self_mask=dec_mask, cross_mask=enc_mask)
        return self.output(decoded)

transformer = Transformer(vocab_size=32, n=2, d=16, h=2, embed_tying=False)
enc_x = torch.tensor([[15, 7, 3], [10, 10, 0], [1, 0, 0]])
dec_x = torch.tensor([[21, 8, 0, 0], [25, 0, 0, 0], [8, 1, 2, 3]])
# dec_x = torch.tensor([[21, 8], [25, 0], [8, 1]])

enc_mask = build_padding_mask(enc_x, pad_token=0)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

dec_mask1 = build_padding_mask(dec_x, pad_token=0)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

print(transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask).shape)

enc_mask: 
 tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 1, 0]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[ 5.9970e-01, -1.1968e-01,  4.6869e-01],
          [-9.8441e-01,  2.4761e-01, -2.0759e-01],
          [-2.5555e-01,  3.8735e-01, -3.5294e-01]],

         [[-3.0122e-02, -1.0659e-04,  1.3670e-01],
          [-4.6867e-01,  8.3376e-02, -1.7783e-01],
          [-5.8456e-01,  1.0852e+00, -1.1018e-01]]],


        [[[ 4.5371e-01,  4.1200e-01,        -inf],
          [ 6.5218e-01,  6.4440e-01,        -inf],
          [ 2.1543e-01,  4.5420e-01,        -inf]],

         [[ 4.3294e-01,  8.5197e-02,        -inf],
          [ 3.1179e-01,  2.4474e-01,        -inf],
          [ 1.6698e-01, -3.2434e-02,        -inf]]],


        [[[ 1.3633e+00,        -inf,        -inf],
          [-2.3976e-01,        -inf,        -inf],
          [-9.8775e-02,        -inf,        -inf]],

         [[

In [26]:
# Inference

In [27]:
import tiktoken
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


encoding = tiktoken.get_encoding("cl100k_base")
sents = ["Hello World", "This is a simple sentence", "Me"]
encoded_sents = [encoding.encode(s) for s in sents]
enc_x = pad_sequence([torch.tensor(es) for es in encoded_sents], batch_first=True, padding_value=encoding.eot_token)
print(enc_x)
dec_sents = ["Bonjour", "C'est une phrase", "START"]
dec_encoded_sents = [encoding.encode(s) for s in dec_sents]
dec_x = pad_sequence([torch.tensor(es) for es in dec_encoded_sents], batch_first=True, padding_value=encoding.eot_token)
print(dec_x)

transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)

enc_mask = build_padding_mask(enc_x, pad_token=100257)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

dec_mask1 = build_padding_mask(dec_x, pad_token=100257)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

output = transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask)
print(f"output shape: {output.shape}")
softmaxed = F.softmax(output, dim=-1)
print(f"softmaxed[0, 0, :10]: {softmaxed[0, 0, :10]}")
predicted = softmaxed.argmax(dim=-1)
print(f"predicted: \n {predicted}")

predicted_list = predicted.tolist()
predicted_decoded = [encoding.decode(l) for l in predicted_list]
print(f"predicted decoded: \n {predicted_decoded}")

tensor([[  9906,   4435, 100257, 100257, 100257],
        [  2028,    374,    264,   4382,  11914],
        [  7979, 100257, 100257, 100257, 100257]])
tensor([[ 82681, 100257, 100257, 100257],
        [    34,  17771,   6316,  17571],
        [ 23380, 100257, 100257, 100257]])
enc_mask: 
 tensor([[1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]])
scaled_prod.shape: 
 torch.Size([3, 4, 5, 5])
scaled_prod: 
 tensor([[[[-3.7867e-02, -8.2838e-01,        -inf,        -inf,        -inf],
          [-2.8830e-01,  8.3117e-02,        -inf,        -inf,        -inf],
          [-1.2055e-01, -1.0235e+00,        -inf,        -inf,        -inf],
          [-3.5413e-01, -9.0736e-01,        -inf,        -inf,        -inf],
          [-4.2189e-01, -8.9929e-01,        -inf,        -inf,        -inf]],

         [[-3.6490e-02, -2.1649e-01,        -inf,        -inf,        -inf],
          [-1.2192e-01, -1.553

In [28]:
# Predicting next words
sent = "This is a simple sentence"
encoded_sent = encoding.encode(sent)
enc_x = torch.tensor(encoded_sent).unsqueeze(0)
dec_x = torch.tensor(encoding.encode("C")).unsqueeze(0)

transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)

predicted_tokens = []
for _ in range(5):
    output = transformer(enc_x=enc_x, dec_x=dec_x)
    softmaxed = F.softmax(output, dim=-1)
    predicted = softmaxed.argmax(dim=-1)
    predicted_tokens.append(predicted.tolist()[-1][-1]) 
    dec_x = torch.cat((dec_x, predicted), dim=-1)

print(predicted_tokens)
print(f"predicted sentence: \n {encoding.decode(predicted_tokens)}")

scaled_prod.shape: 
 torch.Size([1, 4, 5, 5])
scaled_prod: 
 tensor([[[[ 0.4056, -0.0592,  0.3210,  0.7641,  0.2315],
          [ 0.1923,  0.2160,  0.6901, -0.3807,  0.4211],
          [ 0.1536, -1.1140,  0.2217, -0.5141, -0.7577],
          [-0.2799,  0.7452, -0.6044, -0.7963,  0.2692],
          [ 0.4924, -0.5034, -0.1884, -0.9747,  0.0758]],

         [[ 0.2996,  0.9493,  0.4883,  0.8463,  0.5434],
          [ 0.9331,  1.1882, -0.0688, -1.1243, -0.4603],
          [ 0.6546,  0.1881,  1.0822, -0.6032, -0.0829],
          [ 0.0524,  0.4807, -0.3223, -0.4373,  0.4061],
          [ 1.3109,  0.8021,  1.0551, -0.3777,  0.3677]],

         [[-0.0879, -0.0493, -0.0073, -0.2372,  0.5121],
          [-0.6886, -0.5000, -0.7529, -0.1972,  0.2821],
          [-0.1928,  0.6222, -0.1332, -0.7815, -0.3502],
          [ 0.4074,  0.9669,  0.2027,  0.9032,  1.0341],
          [ 0.7860,  0.3283,  0.2983,  0.1211,  0.0942]],

         [[ 0.5253,  0.6203,  0.1955,  0.1431,  0.6613],
          [-0.6797,  

In [72]:
# Dataset
import torch
from torch.utils.data import Dataset
from pathlib import Path
import csv
from enum import Enum



class Partition(Enum):
    TRAIN = "train"
    VAL = "val"

class Tokens(Enum):
    START = "START "
    END = "<|endoftext|>"
    PAD = " PAD"
    START_NUM = 23380
    END_NUM = 100257
    PAD_NUM = 62854
    

class EnFrDataset(Dataset):

    def __init__(self, file: Path | str, partition: Partition = Partition.TRAIN, val_ratio: float = 0.1):
        # partition = TRAIN | VAL
        self._partition = partition
        self._val_ratio = val_ratio

        self._data = []
        self._train_map: dict[int, int] = {}
        self._val_map: dict[int, int] = {}
        train_id = 0
        val_id = 0
        with open(file, newline='') as csvfile:
            reader = csv.reader(csvfile)
            # we want data indexes start from 0, but filter out the first header row
            for i, row in enumerate(reader, start=-1):
                if i == -1:
                    continue
                en = row[0]
                fr = Tokens.START.value + row[1] + Tokens.END.value
                self._data.append(tuple([en, fr]))
                if int(i * val_ratio) == int((i - 1) * val_ratio):
                    self._train_map[train_id] = i
                    train_id += 1
                else:
                    self._val_map[val_id] = i
                    val_id += 1

    class Iterator():

        def __init__(self, outer):
            self.cur = 0
            self.outer = outer

        def __next__(self):
            if self.cur == len(self.outer._data):
                raise StopIteration()
            cur = self.outer._data[self.cur]
            self.cur += 1
            return cur

    def __iter__(self):
        return EnFrDataset.Iterator(self)
    
    @property
    def partition(self):
        return self._partition

    @partition.setter
    def partition(self, partition):
        self._partition = partition
    
    def __len__(self):
        return len(self._train_map) if self._partition == Partition.TRAIN else len(self._val_map)

    def __getitem__(self, idx):
        return self._data[self._train_map[idx]] if self._partition == Partition.TRAIN else self._data[self._val_map[idx]]

dataset = EnFrDataset("../data/eng_-french.csv")
train_sample = dataset[0]
dataset.partition = Partition.VAL
val_sample = dataset[0]
assert train_sample != val_sample
print(train_sample)
print(val_sample)

for i, d in enumerate(dataset):
    if i > 2:
        break
    print(d)

('Hi.', 'START Salut!<|endoftext|>')
('Stop!', 'START Arrête-toi !<|endoftext|>')
('Hi.', 'START Salut!<|endoftext|>')
('Run!', 'START Cours\u202f!<|endoftext|>')
('Run!', 'START Courez\u202f!<|endoftext|>')
('Hi.', 'START Salut!<|endoftext|>')
('Run!', 'START Cours\u202f!<|endoftext|>')
('Run!', 'START Courez\u202f!<|endoftext|>')


In [None]:
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

def build_train_sample(en_str: str, dec_str: str):
    en_encoded = encoding.encode(en_str)
    dec_encoded = encoding.encode(dec_str)
    en_sents = []
    dec_sents = []
    target_sents = []
    for 
    

In [65]:
from torch.utils.data import DataLoader

# def collate_fn

dataset.partition = Partition.TRAIN
training_generator = DataLoader(dataset, batch_size=2, shuffle=True)
for i, s in enumerate(training_generator):
    print(s)
    if i > 2:
        break
        
    

[('They are five in all.', 'You have to live with it.'), ('Elles sont cinq en tout.', 'Il vous faut vivre avec.')]
[("It's fairly warm today.", 'We never speak French at home.'), ("Il fait assez chaud, aujourd'hui.", 'Nous ne parlons jamais français chez nous.')]
[("This rope is strong, isn't it?", "Tom and I don't usually agree."), ("Cette corde est résistante, n'est-ce pas\u202f?", "Tom et moi ne sommes habituellement pas d'accord.")]
[('I noticed a change.', 'It was too difficult for me.'), ("J'ai remarqué un changement.", "C'était trop difficile pour moi.")]


In [None]:
# training

print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())


In [30]:
class Partition(Enum):
    TRAIN = "train"
    VAL = "val"

In [None]:
F.softmax(torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.3]]), dim=-1)
torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.4]]).argmax(dim=-1)
# torch.max(torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.4]]), dim=-1)

t1 = torch.tensor([[0.1, 0.2]])
t2 = torch.tensor([[0.3]])
torch.cat((t1, t2), dim=1).tolist()[-1][-1]


In [55]:
int(0.1)

0

In [None]:
assert False

In [None]:
x = torch.rand([1, 2, 3])
mask = torch.ones([1, 2])
mask[0, 1] = 0
mask = mask.unsqueeze(1)
print(mask == 0)
x.masked_fill(mask == 0, float("-inf"))