In [1]:
# Unmasked attention

In [2]:
import torch
import math
import torch.nn.functional as F

def self_attention(q, k, v, verbose=False):
    """Scaled dot-product attention over per-head inputs, without masking.

    The einsum spec fixes the layout: q is (b, t, h, dh) and k is (b, s, h, dh).
    v is permuted to (b, h, s, dh) before the final product.

    Args:
        q, k, v: 4-D query, key and value tensors.
        verbose: when True, print the shape of the attention weights.

    Returns:
        Tensor of shape (b, h, t, dh).
    """
    # 3-D alternative for (b, t, d) inputs: torch.einsum("btd, bsd -> bts", q, k)
    head_dim = torch.tensor(q.shape[-1])
    scores = torch.einsum("bthd, bshd -> bhts", q, k) / torch.sqrt(head_dim)
    weights = scores.softmax(dim=-1)
    if verbose:
        print(weights.shape)
    # Swap the time and head axes of v so it lines up with the (b, h, t, s) weights.
    values = v.permute(0, 2, 1, 3)
    return torch.matmul(weights, values)


# Smoke test: one random (b=2, t=3, h=4, dh=5) tensor is used as q, k and v.
x = torch.rand([2, 3, 4, 5])
# verbose=True prints the attention-weight shape; the cell value is the output shape.
self_attention(x, x, x, verbose=True).shape

torch.Size([2, 4, 3, 3])


torch.Size([2, 4, 3, 5])

In [3]:
from torch import nn

class MHSA(nn.Module):
    """Multi-head attention without masking.

    Inputs are projected with separate query/key/value linear layers, split into
    `h` heads of width `d // h`, attended with `self_attention`, merged back and
    passed through an output projection.

    Args:
        d: model width; must be divisible by `h`.
        h: number of attention heads.
    """

    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)

    def forward(self, q, k, v):
        """Attend q over k/v.

        Args:
            q: (b, tq, d) queries.
            k, v: (b, tk, d) keys and values; tk may differ from tq.

        Returns:
            Tensor of shape (b, tq, d).
        """
        b, tq, _ = q.size()
        # Keys/values are reshaped with their own length instead of the query's,
        # so inputs of different lengths no longer fail in view().
        tk = k.size(1)
        # (b, t, d) -> (b, t, h, dh)
        wq = self.wq(q).view(b, tq, self.h, self.dh)
        wk = self.wk(k).view(b, tk, self.h, self.dh)
        wv = self.wv(v).view(b, tk, self.h, self.dh)
        # Folding heads into the batch axis (b * h, t, dh) is not necessary:
        # einsum and @ both accept the 4-D layout.
        attn = self_attention(wq, wk, wv)
        # (b, h, tq, dh) -> (b, tq, h, dh) -> (b, tq, d)
        attn = attn.transpose(1, 2).contiguous().view(b, tq, self.d)
        return self.wo(attn)

# Smoke test with the default sizes (d=512, h=8): output shape should equal input shape.
mhsa = MHSA()
x = torch.rand(2, 3, 512)
mhsa(x, x, x).shape

torch.Size([2, 3, 512])

In [4]:
# PE

In [5]:
from torch import nn

class PE(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Even feature columns hold sin(pos / 10000^(2i/d)), odd columns hold the
    matching cosine. The table is precomputed for `max_len` positions.

    Args:
        d: feature width of the input (last dimension); may be odd.
        max_len: number of positions precomputed.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.dropout = nn.Dropout(p=dropout)

        # One exponent per sin/cos column pair: 0, 2, 4, ...
        twoi = torch.arange(0, self.d, 2)
        pow_ = torch.pow(10000, twoi / self.d)
        position = torch.arange(0, max_len).unsqueeze(1)
        angles = position / pow_  # (max_len, ceil(d / 2))
        pe = torch.zeros(max_len, self.d)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d there is one fewer odd column than even columns, so the
        # last cosine column is dropped instead of raising a shape mismatch.
        pe[:, 1::2] = torch.cos(angles)[:, : self.d // 2]
        # A buffer is stored in the state dict and moves with the module, but is
        # not a parameter, so it is never trained.
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding of the first t positions to x of shape (b, t, d)."""
        _, t, _ = x.size()
        x = x + self.pe[:, :t, :]
        return self.dropout(x)
# Shape check on a small integer input of shape (2, 3, 4).
print(PE(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [6]:
from torch import nn

class PEEmbed(nn.Module):
    """Learned positional encoding: one trainable vector per position, added to the input.

    Args:
        d: feature width of the input (last dimension).
        max_len: number of positions in the embedding table.
        dropout: dropout probability applied after adding the encoding.
    """

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.pe = nn.Embedding(max_len, d)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Add the embeddings of positions 0..t-1 to x of shape (b, t, d)."""
        _, t, _ = x.size()
        # Build the indices on the embedding table's device; a bare
        # torch.arange(t) is always on CPU and breaks the lookup after .to(device).
        positions = torch.arange(t, device=self.pe.weight.device)
        x = x + self.pe(positions)
        return self.dropout(x)
# Shape check on a small integer input of shape (2, 3, 4).
print(PEEmbed(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [7]:
# Encoder without mask

In [8]:
import torch.nn.functional as F

class EncoderLayerWithoutMask(nn.Module):
    """Encoder block without masking.

    Two sub-layers, each followed by dropout, a residual add and LayerNorm:
    multi-head self-attention, then a feed-forward network that expands to
    4 * d with ReLU and projects back to d.

    Args:
        d: model width.
        h: number of attention heads.
        dropout: dropout probability used after both sub-layers.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        hidden_width = d * 4
        # Sub-modules are created in the same order as before so that seeded
        # initialisation gives the same weights.
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, hidden_width)
        self.ff2 = nn.Linear(hidden_width, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Transform x of shape (b, t, d); the result has the same shape."""
        # The three-way unpacking doubles as a check that x is 3-D.
        _batch, _steps, _width = x.size()
        attended = self.attn_dropout(self.mhsa(x, x, x))
        x = self.norm1(x + attended)
        expanded = F.relu(self.ff1(x))
        projected = self.resid_dropout(self.ff2(expanded))
        return self.norm2(x + projected)


# Smoke test with the default sizes: output shape should equal input shape.
encoder_layer = EncoderLayerWithoutMask()
x = torch.rand(2, 3, 512)
encoder_layer(x).shape

torch.Size([2, 3, 512])

In [9]:
from torch import nn

class EncoderWithoutMask(nn.Module):
    """Token embedding + positional encoding + a stack of unmasked encoder layers.

    Args:
        vocab_size: size of the token embedding table.
        n: number of encoder layers.
        d: model width.
        h: number of attention heads per layer.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        # nn.ModuleList registers the layers as sub-modules. A plain Python list
        # hides their parameters from .parameters(), .to(device) and state_dict(),
        # so they would be neither trained, moved nor saved.
        self.layers = nn.ModuleList([EncoderLayerWithoutMask(d, h) for _ in range(n)])

    def forward(self, x):
        """Encode token ids x of shape (b, t) into features of shape (b, t, d)."""
        b, t = x.size()  # also checks that x is 2-D
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x)
        return x

# Smoke test: random token ids of shape (2, 3) drawn from the default vocabulary size.
encoder = EncoderWithoutMask()
x = torch.randint(0, 2**13, (2, 3))
encoder(x).shape

torch.Size([2, 3, 512])

In [10]:
import torch

def reshape_mask(mask):
    """Insert two singleton axes after the first axis of `mask`.

    For a (b, t) mask this gives (b, 1, 1, t), which broadcasts against
    attention scores of shape (b, h, t, t).
    """
    # Indexing with None adds an axis, exactly like unsqueeze(1) applied twice.
    return mask[:, None, None]

# x is not used by the call below; the mask is given as a literal (b=2, t=3).
x = torch.rand(2, 3)
print(reshape_mask(torch.tensor([[1, 0, 0], [1, 1, 0]])))

tensor([[[[1, 0, 0]]],


        [[[1, 1, 0]]]])


In [11]:
# With masks
import torch
import math
import torch.nn.functional as F

def self_attention_masked(q, k, v, mask=None, verbose=False):
    """Scaled dot-product attention over per-head inputs with an optional mask.

    The einsum spec fixes the layout: q is (b, t, h, dh) and k is (b, s, h, dh).
    v is permuted to (b, h, s, dh) before the final product.

    Args:
        q, k, v: 4-D query, key and value tensors.
        mask: optional tensor whose zeros mark keys to ignore. A mask with the
            same number of dimensions as the scores is used as is (it must
            broadcast to (b, h, t, s)); any other mask is passed through
            `reshape_mask` first, e.g. (b, s) -> (b, 1, 1, s).
        verbose: when True, print the intermediate tensors.

    Returns:
        Tensor of shape (b, h, t, dh).
    """
    # 3-D alternative for (b, t, d) inputs: torch.einsum("btd, bsd -> bts", q, k)
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod / torch.sqrt(torch.tensor(q.shape[-1]))
    if verbose:
        print(f"scaled_prod.shape: \n {scaled_prod.shape}")
    if mask is not None:
        mask = mask if scaled_prod.dim() == mask.dim() else reshape_mask(mask)
        # Fill with the most negative finite value instead of -inf. Masked keys
        # still get exactly zero weight when at least one key is visible, but a
        # row whose keys are ALL masked now yields uniform weights instead of NaN.
        fill_value = torch.finfo(scaled_prod.dtype).min
        scaled_prod = scaled_prod.masked_fill(mask == 0, fill_value)
    if verbose:
        print(f"scaled_prod: \n {scaled_prod}")
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    if verbose:
        print(f"softmaxed_prod: \n {softmaxed_prod}")
    # Swap the time and head axes of v so it lines up with the (b, h, t, s) weights.
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


In [12]:
# Mask

In [13]:
# play with mask

# Toy input of shape (b=2, t=3, h=2, dh=4), used to compare a wrongly shaped
# mask with a correctly shaped one.
x = torch.rand([2, 3, 2, 4])
print(x)
# Padding-style mask for 2 batch elements x 3 time steps (0 = ignore this key).
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
print(f"mask: \n {mask}")
# Adding a single axis gives shape (2, 1, 3), which does not line up with the
# (b, h, t, s) attention scores.
mask = mask.unsqueeze(1)


# mask = mask.permute(0, 2, 1)
# Open question from the author: is this the mask that is needed, i.e. are keys ignored?
print(f"wrong mask: \n {mask}")
# Broadcasting aligns from the right: (2, 1, 3) is treated as (1, 2, 1, 3), so the
# batch axis of the mask lands on the head axis, t is broadcast from 1 and s keeps its 3 entries.
print(f"wrong mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
# NOTE(review): this mask is 3-D while the scores are 4-D, so self_attention_masked
# passes it through reshape_mask first (giving 5 dims) - confirm this call runs as intended.
a = self_attention_masked(x, x, x, mask=mask, verbose=True)
print(f"wrong a: \n {a}" )
print(f"wrong a.shape: \n {a.shape}")
# The attention above is wrong because the mask has shape (2, 1, 3).

# Correct mask: same 2 x 3 padding pattern, reshaped to (2, 1, 1, 3).
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
mask = mask.unsqueeze(1).unsqueeze(1)

print(f"mask: \n {mask}")
# (2, 1, 1, 3): b stays 2, h and t are broadcast from 1, s keeps its 3 entries.
print(f"mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask, verbose=True)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")


tensor([[[[0.9161, 0.5284, 0.3874, 0.0296],
          [0.2724, 0.6993, 0.6833, 0.8454]],

         [[0.9751, 0.4556, 0.5572, 0.5613],
          [0.5171, 0.4448, 0.5494, 0.1597]],

         [[0.1415, 0.6000, 0.2259, 0.7091],
          [0.8315, 0.2869, 0.9799, 0.6451]]],


        [[[0.8202, 0.5965, 0.0961, 0.8646],
          [0.5152, 0.0899, 0.2010, 0.3476]],

         [[0.3940, 0.4795, 0.6472, 0.3459],
          [0.3430, 0.9845, 0.6830, 0.2335]],

         [[0.4451, 0.0086, 0.3697, 0.7003],
          [0.1723, 0.1358, 0.8003, 0.7740]]]])
mask: 
 tensor([[1., 1., 0.],
        [1., 0., 0.]])
wrong mask: 
 tensor([[[1., 1., 0.]],

        [[1., 0., 0.]]])
wrong mask broadcast: 
 tensor([[[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]]])
scaled_prod.shape

In [14]:
# Masking is equivalent to setting the keys at the masked positions to -inf (not 0):
# the result in terms of masked symbols is the same.
# `x` is the (2, 3, 2, 4) tensor created in the previous cell.
k = x.clone()
k[0, 2, 0, :] = float("-inf")
k[0, 2, 1, :] = float("-inf")
k[1, 2, 0, :] = float("-inf")
k[1, 1, 0, :] = float("-inf")
k[1, 2, 1, :] = float("-inf")
k[1, 1, 1, :] = float("-inf")
print(f"k: \n {k}")
a = self_attention_masked(x, k, x, verbose=True)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")
# a is the same shape as if mask was applied in q * k:

# Show how zeroed rows of a (b, t, d) tensor move when it is split into heads and permuted.
test = torch.rand([2, 3, 4])
test[0, 2, :] = 0
test[1, 1, :] = 0
test[1, 2, :] = 0

print(f"test: \n {test}")
test_v = test.view(2, 3, 2, 2)
print(f"test_v: \n {test_v}")
test_perm = test_v.permute(0, 2, 1, 3)
print(f"test_perm: \n {test_perm}")

# or like that:
test_q = torch.rand([2, 3, 4])
test_k = test_q.clone()
test_k[0, 2, :] = float("-inf")
test_k[1, 1, :] = float("-inf")
test_k[1, 2, :] = float("-inf")
print(f"test_k: \n {test_k}")

test_q_view = test_q.view(2, 3, 2, 2)
test_k_view = test_k.view(2, 3, 2, 2)
print(f"test_k_view: \n {test_k_view}")
test_q_perm = test_q_view.permute(0, 2, 1, 3)
test_k_perm = test_k_view.permute(0, 2, 1, 3)
print(f"test_k_perm: \n {test_k_perm}")
# Single quotes inside the replacement field: reusing the outer double quotes
# is a SyntaxError on Python versions before 3.12.
print(f"q * k: \n {torch.einsum('bhtd, bhsd -> bhts', test_q_perm, test_k_perm)}")

k: 
 tensor([[[[0.9161, 0.5284, 0.3874, 0.0296],
          [0.2724, 0.6993, 0.6833, 0.8454]],

         [[0.9751, 0.4556, 0.5572, 0.5613],
          [0.5171, 0.4448, 0.5494, 0.1597]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]],


        [[[0.8202, 0.5965, 0.0961, 0.8646],
          [0.5152, 0.0899, 0.2010, 0.3476]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]]])
scaled_prod.shape: 
 torch.Size([2, 2, 3, 3])
scaled_prod: 
 tensor([[[[0.6347, 0.6833,   -inf],
          [0.6833, 0.8920,   -inf],
          [0.2776, 0.4676,   -inf]],

         [[0.8724, 0.4811,   -inf],
          [0.4811, 0.3963,   -inf],
          [0.8210, 0.5995,   -inf]]],


        [[[0.8927,   -inf,   -inf],
          [0.4852,   -inf,   -inf],
          [0.5056,   -inf,   -inf]],

         [[0.2174,   -inf,   -inf],
          [0.2418,   -i

In [15]:
import torch

def build_padding_mask(x, pad_token):
    """Return a mask with the shape and dtype of x: 0 where x equals pad_token, 1 elsewhere.

    x is expected to be (b, t), but the comparison is element-wise.
    """
    is_real_token = x != pad_token
    return is_real_token.to(x.dtype)

# Demo: 100 plays the role of the pad value; torch.rand values are below 1 so they never collide.
x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -2:] = 100
x[2, -1] = 100
x[3, :] = 100  # a fully padded row -> an all-zero mask row
print(x)
print(build_padding_mask(x, 100))

tensor([[6.2186e-01, 5.9362e-01, 7.6366e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [9.9286e-02, 4.0288e-01, 9.4348e-01, 6.2622e-01, 1.0000e+02, 1.0000e+02],
        [7.6083e-03, 8.4606e-01, 4.7163e-01, 4.3363e-01, 6.3603e-02, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [1.9430e-01, 5.8806e-01, 5.0770e-02, 9.1432e-01, 1.7868e-02, 6.9063e-02]])
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [16]:
import torch

def build_causal_mask(x):
    """Return the lower triangle of an all-ones tensor shaped and typed like x.

    x is expected to be (b, t). The triangle is taken over that matrix itself,
    so row i (batch element i) keeps its first i + 1 positions.

    NOTE(review): a conventional causal mask is (t, t), indexed by query and
    key position rather than by batch row - confirm this shape is intended.
    """
    return torch.ones_like(x).tril()
# Demo on a (5, 6) input: only the shape and dtype of x matter here.
x = torch.rand(5, 6)

print(build_causal_mask(x))

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [17]:
import torch

def merge_masks(m1, m2):
    """Element-wise product of two masks; for 0/1 masks this acts as a logical AND."""
    combined = m1 * m2
    return combined

# Demo: combine the padding mask (100 = pad value) with the triangular mask of the same input.
x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -1] = 100
x[2, -4:] = 100
x[3, :] = 100
print(x)
m1 = build_padding_mask(x, 100)
m2 = build_causal_mask(x)
print(merge_masks(m1, m2))

tensor([[7.1616e-01, 8.2191e-01, 9.6550e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [5.9104e-01, 7.7429e-01, 8.2443e-02, 8.7411e-01, 9.7018e-03, 1.0000e+02],
        [4.9149e-01, 7.8321e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [7.4295e-01, 3.3757e-01, 9.9551e-01, 9.0420e-01, 5.5733e-02, 8.4830e-02]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [18]:
from torch import nn

class MHSAMasked(nn.Module):
    """Multi-head attention with an optional mask; also usable as cross-attention.

    Queries and keys/values may have different batch-local lengths, which is
    the case when decoder states attend over encoder states.

    Args:
        d: model width; must be divisible by `h`.
        h: number of attention heads.
    """

    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        # Query, key, value and output projections, created in this order.
        for name in ("wq", "wk", "wv", "wo"):
            setattr(self, name, nn.Linear(self.d, self.d))

    def forward(self, q, k, v, mask = None):
        """Attend q over k/v.

        Args:
            q: (bq, tq, d) queries.
            k, v: keys and values; both are reshaped with k's batch and length.
            mask: forwarded unchanged to `self_attention_masked`.

        Returns:
            Tensor of shape (bq, tq, d).
        """
        bq, tq, dq = q.size()
        bk, tk, _ = k.size()
        # Project, then split the feature axis into heads: (b, t, d) -> (b, t, h, dh).
        # Folding heads into the batch axis is not needed; @ supports 4-D operands.
        heads_q = self.wq(q).view(bq, tq, self.h, self.dh)
        heads_k = self.wk(k).view(bk, tk, self.h, self.dh)
        heads_v = self.wv(v).view(bk, tk, self.h, self.dh)
        context = self_attention_masked(heads_q, heads_k, heads_v, mask=mask)
        # (b, h, tq, dh) -> (b, tq, h, dh) -> (b, tq, d)
        per_head = context.view(bq, self.h, tq, self.dh)
        merged = per_head.transpose(1, 2).contiguous().view(bq, tq, dq)
        return self.wo(merged)

# Demo with a small model (2 heads, width 6) and a batch of 4 sequences of length 5.
mhsa_masked = MHSAMasked(h = 2, d = 6)
# This first x only provides the (4, 5) shape for the mask.
x = torch.rand(4, 5)
mask = reshape_mask(build_causal_mask(x))
print(mask)
# The actual input: (b=4, t=5, d=6).
x = torch.rand(4, 5, 6)
print(mhsa_masked(x, x, x, mask=mask))
print(mhsa_masked(x, x, x, mask=mask).shape)

tensor([[[[1., 0., 0., 0., 0.]]],


        [[[1., 1., 0., 0., 0.]]],


        [[[1., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 0.]]]])
tensor([[[ 0.1825,  0.4952, -0.2641,  0.0693, -0.0135, -0.0194],
         [ 0.1825,  0.4952, -0.2641,  0.0693, -0.0135, -0.0194],
         [ 0.1825,  0.4952, -0.2641,  0.0693, -0.0135, -0.0194],
         [ 0.1825,  0.4952, -0.2641,  0.0693, -0.0135, -0.0194],
         [ 0.1825,  0.4952, -0.2641,  0.0693, -0.0135, -0.0194]],

        [[ 0.0600,  0.3445, -0.2016,  0.0677, -0.0371,  0.1043],
         [ 0.0600,  0.3446, -0.2015,  0.0677, -0.0371,  0.1043],
         [ 0.0599,  0.3446, -0.2014,  0.0678, -0.0370,  0.1044],
         [ 0.0599,  0.3446, -0.2014,  0.0678, -0.0370,  0.1044],
         [ 0.0600,  0.3445, -0.2016,  0.0677, -0.0371,  0.1043]],

        [[ 0.1816,  0.3811, -0.2190,  0.1883, -0.1887,  0.1405],
         [ 0.1831,  0.3807, -0.2171,  0.1906, -0.1894,  0.1390],
         [ 0.1822,  0.3804, -0.2175,  0.1902, -0.1886,  0.1432],
        

In [19]:
# Transformer implementation 

In [20]:
import torch.nn.functional as F

class EncoderLayer(nn.Module):
    """Encoder block with an optional attention mask.

    Two sub-layers, each followed by dropout, a residual add and LayerNorm:
    masked multi-head self-attention, then a feed-forward network that expands
    to 4 * d with ReLU and projects back to d.

    Args:
        d: model width.
        h: number of attention heads.
        dropout: dropout probability used after both sub-layers.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        hidden_width = d * 4
        # Same creation order as before, so seeded initialisation is unchanged.
        self.mhsa = MHSAMasked(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, hidden_width)
        self.ff2 = nn.Linear(hidden_width, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, self_mask=None):
        """Transform x of shape (b, t, d); self_mask is forwarded to the attention."""
        # The three-way unpacking doubles as a check that x is 3-D.
        _batch, _steps, _width = x.size()
        attended = self.mhsa(x, x, x, mask=self_mask)
        x = self.norm1(x + self.attn_dropout(attended))
        expanded = F.relu(self.ff1(x))
        x = self.norm2(x + self.resid_dropout(self.ff2(expanded)))
        return x


# Smoke test: padding mask built from toy token ids (0 = pad), reshaped to (b, 1, 1, t).
encoder_layer = EncoderLayer()
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
x = torch.rand(2, 3, 512)

encoder_layer(x, self_mask=self_mask).shape

torch.Size([2, 3, 512])

In [21]:
from torch import nn

class Encoder(nn.Module):
    """Token embedding + positional encoding + a stack of `EncoderLayer` blocks.

    Args:
        vocab_size: size of the token embedding table.
        n: number of encoder layers.
        d: model width.
        h: number of attention heads per layer.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        blocks = []
        for _ in range(n):
            blocks.append(EncoderLayer(d, h))
        # ModuleList registers every block as a sub-module.
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, self_mask = None):
        """Encode token ids x of shape (b, t); the same mask is given to every layer."""
        # The two-way unpacking doubles as a check that x is 2-D.
        _batch, _steps = x.size()
        hidden = self.pe(self.embed(x))
        for block in self.layers:
            hidden = block(hidden, self_mask=self_mask)
        return hidden

# Smoke test: random token ids; the padding mask comes from a separate toy id tensor of the same shape.
encoder = Encoder()
x = torch.randint(0, 2**13, (2, 3))
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
encoder(x, self_mask).shape

torch.Size([2, 3, 512])

In [22]:
import torch.nn.functional as F

class DecoderLayer(nn.Module):
    """Decoder block: self-attention, cross-attention over the encoder output, feed-forward.

    Each of the three sub-layers is followed by dropout, a residual add and
    LayerNorm. The feed-forward network expands to 4 * d with ReLU.

    Args:
        d: model width.
        h: number of attention heads.
        dropout: dropout probability used after every sub-layer.
    """

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        hidden_width = d * 4

        # Self-attention sub-layer.
        self.mhsa = MHSAMasked(d=d, h=h)
        self.attn_norm = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)

        # Cross-attention sub-layer (queries from the decoder, keys/values from the encoder).
        self.mhca = MHSAMasked(d=d, h=h)
        self.cross_attn_norm = nn.LayerNorm(d)
        self.cross_attn_dropout = nn.Dropout(dropout)

        # Feed-forward sub-layer.
        self.ff1 = nn.Linear(d, hidden_width)
        self.ff2 = nn.Linear(hidden_width, d)
        self.resid_dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d)

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        """Run one decoder block.

        Args:
            dec_x: (b, t, d) decoder states.
            enc_x: encoder output used as keys and values in cross-attention.
            self_mask: mask for self-attention; the notebook passes the merged
                decoder padding and causal masks.
            cross_mask: mask for cross-attention; the notebook passes the
                encoder padding mask so padded encoder tokens are not attended to.

        Returns:
            Tensor with the shape of dec_x.
        """
        # The three-way unpacking doubles as a check that dec_x is 3-D.
        _batch, _steps, _width = dec_x.size()

        self_attended = self.mhsa(dec_x, dec_x, dec_x, mask=self_mask)
        x = self.attn_norm(dec_x + self.attn_dropout(self_attended))

        cross_attended = self.mhca(x, enc_x, enc_x, mask=cross_mask)
        x = self.cross_attn_norm(x + self.cross_attn_dropout(cross_attended))

        expanded = F.relu(self.ff1(x))
        x = self.norm(x + self.resid_dropout(self.ff2(expanded)))
        return x


# Smoke test with a small block: x plays the decoder states, y the encoder output.
decoder_layer = DecoderLayer(h=2, d=16)
x = torch.rand(3, 3, 16)
y = torch.rand(3, 3, 16)
# Self mask = padding mask (0 = pad) merged with the triangular mask of the same toy ids.
self_mask1 = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]), pad_token=0)
self_mask2 = build_causal_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]))
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

# Cross mask = padding mask of toy "encoder" ids.
cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
decoder_layer(x, y, self_mask=self_mask, cross_mask=cross_mask).shape

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])


torch.Size([3, 3, 16])

In [23]:
from torch import nn

class Decoder(nn.Module):
    """Token embedding + positional encoding + a stack of `DecoderLayer` blocks.

    Args:
        vocab_size: size of the token embedding table.
        n: number of decoder layers.
        d: model width.
        h: number of attention heads per layer.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        blocks = []
        for _ in range(n):
            blocks.append(DecoderLayer(d, h))
        # ModuleList registers every block as a sub-module.
        self.layers = nn.ModuleList(blocks)

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        """Decode token ids dec_x of shape (b, t) against the encoder output enc_x.

        Both masks are handed unchanged to every layer.
        """
        # The two-way unpacking doubles as a check that dec_x is 2-D.
        _batch, _steps = dec_x.size()
        hidden = self.pe(self.embed(dec_x))
        for block in self.layers:
            hidden = block(hidden, enc_x, self_mask=self_mask, cross_mask=cross_mask)
        return hidden

    def get_embed_weights(self):
        """Return the token embedding weight (the parameter itself, not a copy)."""
        embedding = self.embed
        return embedding.weight

# Smoke test with a small decoder: x holds token ids (0 = pad), y plays the encoder output.
decoder = Decoder(vocab_size=32, n=2, d=16, h=2)
# x = torch.randint(0, 32, (2, 3))
x = torch.tensor([[15, 7, 0], [10, 0, 0], [1, 3, 0]])
y = torch.rand(3, 3, 16)

# Self mask = padding mask merged with the triangular mask of the same ids.
self_mask1 = build_padding_mask(x, pad_token=0)
self_mask2 = build_causal_mask(x)
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

# Cross mask = padding mask of toy "encoder" ids.
cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
print(decoder(x, y, self_mask=self_mask, cross_mask=cross_mask).shape)

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
torch.Size([3, 3, 16])


In [24]:
from torch import nn

class Output(nn.Module):
    """Final linear projection from model width to vocabulary size.

    Args:
        vocab_size: number of output features.
        d: number of input features.
        ff_weight: optional parameter that replaces the layer's own weight;
            the notebook passes the decoder embedding weight to tie the two.
    """

    def __init__(self, vocab_size: int = 2**13, d: int = 512, ff_weight = None):
        super().__init__()
        projection = nn.Linear(d, vocab_size)
        if ff_weight is not None:
            # Share the given parameter object instead of the freshly created weight.
            projection.weight = ff_weight
        self.ff = projection

    def forward(self, x):
        """Apply the projection to the last dimension of x."""
        logits = self.ff(x)
        return logits

In [25]:
from torch import nn

class Transformer(nn.Module):
    """Encoder-decoder model with a linear output projection over the vocabulary.

    Args:
        vocab_size: vocabulary size shared by encoder, decoder and output.
        n: number of layers in both the encoder and the decoder.
        d: model width.
        h: number of attention heads.
        embed_tying: when truthy, the output projection reuses the decoder
            embedding weight.
    """

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8, embed_tying=True):
        super().__init__()
        self.encoder = Encoder(vocab_size=vocab_size, n=n, d=d, h=h)
        self.decoder = Decoder(vocab_size=vocab_size, n=n, d=d, h=h)
        # None makes Output keep its own, independently initialised weight.
        shared_weight = self.decoder.get_embed_weights() if embed_tying else None
        self.output = Output(vocab_size=vocab_size, d=d, ff_weight=shared_weight)

    def forward(self, enc_x, dec_x, enc_mask=None, dec_mask=None):
        """Return the output projection of the decoder states.

        enc_mask is used both for encoder self-attention and for the decoder's
        cross-attention; dec_mask is used for decoder self-attention.
        """
        memory = self.encoder(enc_x, enc_mask)
        states = self.decoder(dec_x=dec_x, enc_x=memory, self_mask=dec_mask, cross_mask=enc_mask)
        return self.output(states)

# Smoke test with a small, untied model; 0 is the pad id in both id tensors.
transformer = Transformer(vocab_size=32, n=2, d=16, h=2, embed_tying=False)
enc_x = torch.tensor([[15, 7, 3], [10, 10, 0], [1, 0, 0]])
dec_x = torch.tensor([[21, 8, 0, 0], [25, 0, 0, 0], [8, 1, 2, 3]])
# dec_x = torch.tensor([[21, 8], [25, 0], [8, 1]])

# Encoder mask: padding only.
enc_mask = build_padding_mask(enc_x, pad_token=0)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

# Decoder mask: padding mask merged with the triangular mask.
dec_mask1 = build_padding_mask(dec_x, pad_token=0)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

print(transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask).shape)

enc_mask: 
 tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 1, 0]])
torch.Size([3, 4, 32])


In [26]:
# Inference

In [27]:
import tiktoken
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


# Tokenizer used by the rest of the notebook (later cells read `encoding` as a global).
encoding = tiktoken.get_encoding("cl100k_base")
sents = ["Hello World", "This is a simple sentence", "Me"]
encoded_sents = [encoding.encode(s) for s in sents]
# Pad every sentence to the longest one with the end-of-text token id.
enc_x = pad_sequence([torch.tensor(es) for es in encoded_sents], batch_first=True, padding_value=encoding.eot_token)
print(enc_x)
dec_sents = ["Bonjour", "C'est une phrase", "START"]
dec_encoded_sents = [encoding.encode(s) for s in dec_sents]
dec_x = pad_sequence([torch.tensor(es) for es in dec_encoded_sents], batch_first=True, padding_value=encoding.eot_token)
print(dec_x)

# Untrained model; embed_tying defaults to True.
transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)

# 100257 is the padding value visible in the printed tensors (encoding.eot_token).
enc_mask = build_padding_mask(enc_x, pad_token=100257)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

dec_mask1 = build_padding_mask(dec_x, pad_token=100257)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

output = transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask)
print(f"output shape: {output.shape}")
softmaxed = F.softmax(output, dim=-1)
print(f"softmaxed[0, 0, :10]: {softmaxed[0, 0, :10]}")
# Most likely token id at every decoder position.
# NOTE(review): in the recorded output the untrained model reproduces dec_x exactly,
# presumably an effect of the tied embedding/output weights - confirm.
predicted = softmaxed.argmax(dim=-1)
print(f"predicted: \n {predicted}")

predicted_list = predicted.tolist()
predicted_decoded = [encoding.decode(l) for l in predicted_list]
print(f"predicted decoded: \n {predicted_decoded}")

tensor([[  9906,   4435, 100257, 100257, 100257],
        [  2028,    374,    264,   4382,  11914],
        [  7979, 100257, 100257, 100257, 100257]])
tensor([[ 82681, 100257, 100257, 100257],
        [    34,  17771,   6316,  17571],
        [ 23380, 100257, 100257, 100257]])
enc_mask: 
 tensor([[1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]])
output shape: torch.Size([3, 4, 100277])
softmaxed[0, 0, :10]: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)
predicted: 
 tensor([[ 82681, 100257, 100257, 100257],
        [    34,  17771,   6316,  17571],
        [ 23380, 100257, 100257, 100257]])
predicted decoded: 
 ['Bonjour<|endoftext|><|endoftext|><|endoftext|>', "C'est une phrase", 'START<|endoftext|><|endoftext|><|endoftext|>']


In [28]:
# Predicting next words
# Greedy next-token prediction with an untrained model.
# `encoding` is the tokenizer created in the previous cell.
sent = "This is a simple sentence"
encoded_sent = encoding.encode(sent)
enc_x = torch.tensor(encoded_sent).unsqueeze(0)
# The decoder is seeded with the tokens of "C".
dec_x = torch.tensor(encoding.encode("C")).unsqueeze(0)

transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)
# Inference mode: switch dropout off so predictions are deterministic.
transformer.eval()

predicted_tokens = []
# No gradients are needed for decoding.
with torch.no_grad():
    for _ in range(5):
        output = transformer(enc_x=enc_x, dec_x=dec_x)
        softmaxed = F.softmax(output, dim=-1)
        predicted = softmaxed.argmax(dim=-1)
        predicted_tokens.append(predicted.tolist()[-1][-1])
        # Append only the prediction for the LAST position. Concatenating the
        # whole `predicted` tensor appended one token per existing position and
        # doubled the decoder input on every step.
        next_token = predicted[:, -1:]
        dec_x = torch.cat((dec_x, next_token), dim=-1)

print(predicted_tokens)
print(f"predicted sentence: \n {encoding.decode(predicted_tokens)}")

[34, 34, 34, 34, 34]
predicted sentence: 
 CCCCC


In [29]:
# Dataset
import torch
from torch.utils.data import Dataset
from pathlib import Path
import csv
from enum import Enum



class Partition(Enum):
    """Which split a dataset currently serves through len() and indexing."""
    TRAIN = "train"
    VAL = "val"

class Tokens(Enum):
    """Special marker strings and the integer ids used for them."""
    # Prefix put in front of every French sentence by EnFrDataset (note the trailing space).
    START = "START "
    END = "<|endoftext|>"
    PAD = " PAD"
    # NOTE(review): presumably the cl100k_base ids of the strings above - confirm against the tokenizer.
    START_NUM = 23380
    # Appended to the decoder tokens in TokEnFrDataset.build_train_sample.
    END_NUM = 100257
    # Padding value used by collate.
    PAD_NUM = 62854
    

class EnFrDataset(Dataset):
    """English/French sentence pairs read from a CSV file, split into train and validation.

    The first CSV row is treated as a header. Column 0 is the English sentence,
    column 1 the French one, which is prefixed with `Tokens.START`. Row i goes
    to the validation split when int(i * val_ratio) differs from
    int((i - 1) * val_ratio), i.e. roughly every 1 / val_ratio rows.

    len() and indexing follow the current `partition`. Iterating with
    `for pair in dataset` yields EVERY pair, regardless of the partition.

    Args:
        file: path of the CSV file (read as UTF-8).
        partition: split served by len() and indexing.
        val_ratio: approximate fraction of rows put in the validation split.
    """

    def __init__(self, file: Path | str, partition: Partition = Partition.TRAIN, val_ratio: float = 0.1):
        self._partition = partition
        self._val_ratio = val_ratio

        self._data = []
        # Position in the split -> index into self._data. Lists (rather than
        # dicts) make an out-of-range index raise IndexError, as the sequence
        # protocol expects.
        self._train_map: list[int] = []
        self._val_map: list[int] = []
        # Explicit encoding: the French text contains non-ASCII characters and
        # the platform default is not UTF-8 everywhere.
        with open(file, newline='', encoding="utf-8") as csvfile:
            reader = csv.reader(csvfile)
            next(reader, None)  # skip the header row
            for i, row in enumerate(reader):
                en = row[0]
                fr = Tokens.START.value + row[1]
                self._data.append((en, fr))
                if int(i * val_ratio) == int((i - 1) * val_ratio):
                    self._train_map.append(i)
                else:
                    self._val_map.append(i)

    class Iterator():
        """Walks over all pairs of the outer dataset in file order."""

        def __init__(self, outer):
            self.cur = 0
            self.outer = outer

        def __iter__(self):
            # Iterators must be iterable themselves.
            return self

        def __next__(self):
            if self.cur == len(self.outer._data):
                raise StopIteration()
            cur = self.outer._data[self.cur]
            self.cur += 1
            return cur

    def __iter__(self):
        return EnFrDataset.Iterator(self)

    @property
    def partition(self):
        """Split currently served by len() and indexing."""
        return self._partition

    @partition.setter
    def partition(self, partition):
        self._partition = partition

    def _active_map(self):
        # Index list of the current partition.
        return self._train_map if self._partition == Partition.TRAIN else self._val_map

    def __len__(self):
        return len(self._active_map())

    def __getitem__(self, idx):
        """Return the (english, french) pair at position idx of the current partition.

        Raises:
            IndexError: if idx is out of range for the current partition.
        """
        return self._data[self._active_map()[idx]]

# Load the sentence pairs and check that the two partitions serve different first samples.
dataset = EnFrDataset("../data/eng_-french.csv", val_ratio=0.1)
train_sample = dataset[0]
dataset.partition = Partition.VAL
val_sample = dataset[0]
assert train_sample != val_sample
print(train_sample)
print(val_sample)

# Iteration goes through EnFrDataset.__iter__, which walks all rows and ignores
# the partition (still VAL here) - hence the first printed pair is the first row.
for i, d in enumerate(dataset):
    if i > 2:
        break
    print(d)

('Hi.', 'START Salut!')
('Stop!', 'START Arrête-toi !')
('Hi.', 'START Salut!')
('Run!', 'START Cours\u202f!')
('Run!', 'START Courez\u202f!')


In [30]:
class TokEnFrDataset(Dataset):
    """Tokenized next-token samples derived from `EnFrDataset`.

    Every sentence pair is expanded into one sample per decoder prefix (see
    `build_train_sample`). Sample i goes to the validation split when
    int(i * val_ratio) differs from int((i - 1) * val_ratio). len() and
    indexing follow the current `partition`.

    Tokenization uses the module-level `encoding` object, which must exist
    before this class is used.

    Args:
        file: path of the CSV file handed to `EnFrDataset`.
        partition: split served by len() and indexing.
        val_ratio: approximate fraction of samples put in the validation split.
    """

    @staticmethod
    def build_train_sample(en_str: str, dec_str: str):
        """Expand one sentence pair into (encoder ids, decoder prefix, targets) triples.

        The decoder string is encoded and `Tokens.END_NUM` is appended. For
        every prefix length i >= 1 that leaves at least one following token,
        the prefix is the first i tokens and the targets are the same tokens
        shifted left by one. All triples share the same encoder id list.
        """
        en_encoded = encoding.encode(en_str)
        dec_encoded = encoding.encode(dec_str)
        dec_encoded.append(Tokens.END_NUM.value)
        return [
            (en_encoded, dec_encoded[:i], dec_encoded[1: i + 1])
            for i in range(1, len(dec_encoded))
        ]

    def __init__(self, file: Path | str, partition: Partition = Partition.TRAIN, val_ratio: float = 0.1):
        # val_ratio=0 for the inner dataset: the split is made here, on the expanded samples.
        self._dataset = EnFrDataset(file, partition, val_ratio=0)
        self._partition = partition
        self._val_ratio = val_ratio

        self._data = []
        # Position in the split -> index into self._data. Lists (rather than
        # dicts) make an out-of-range index raise IndexError, which is what
        # ends a plain `for sample in dataset` loop; a KeyError would crash it.
        self._train_map: list[int] = []
        self._val_map: list[int] = []
        i = 0
        for en, fr in self._dataset:
            for sample in self.build_train_sample(en, fr):
                self._data.append(sample)
                if int(i * val_ratio) == int((i - 1) * val_ratio):
                    self._train_map.append(i)
                else:
                    self._val_map.append(i)
                i += 1

    @property
    def partition(self):
        """Split currently served by len() and indexing."""
        return self._partition

    @partition.setter
    def partition(self, partition):
        self._partition = partition

    def _active_map(self):
        # Index list of the current partition.
        return self._train_map if self._partition == Partition.TRAIN else self._val_map

    def __len__(self):
        return len(self._active_map())

    def __getitem__(self, idx):
        """Return the sample at position idx of the current partition.

        Raises:
            IndexError: if idx is out of range for the current partition.
        """
        return self._data[self._active_map()[idx]]

# Build the tokenized dataset and check that the two partitions serve different first samples.
dataset = TokEnFrDataset("../data/eng_-french.csv", val_ratio=0.1)
train_sample = dataset[0]
dataset.partition = Partition.VAL
val_sample = dataset[0]
assert train_sample != val_sample
print(train_sample)
print(val_sample)

# TokEnFrDataset defines no __iter__, so this loop presumably falls back to
# indexing with 0, 1, 2, ... on the current (VAL) partition - consistent with the printed output.
for i, d in enumerate(dataset):
    if i > 2:
        break
    print(d)

([13347, 13], [23380], [8375])
([6869, 0], [23380], [18733])
([6869, 0], [23380], [18733])
([36981, 0], [23380, 64105], [64105, 64])
([12978, 0], [23380], [65381])


In [31]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence


def collate(batch):
    """Pad a list of ``(encoder_ids, decoder_ids, target_ids)`` samples into batch tensors.

    Returns:
        ``(enc_x, dec_x, label, enc_mask, dec_mask)``. The three id tensors have
        one row per sample and are padded with ``Tokens.PAD_NUM.value``; the
        masks come from ``build_padding_mask`` on the padded encoder and
        decoder inputs.
    """
    pad_id = Tokens.PAD_NUM.value

    def _pad(sequences):
        # batch_first=True -> result is (batch, longest sequence in the batch)
        return pad_sequence([torch.tensor(seq) for seq in sequences], batch_first=True, padding_value=pad_id)

    enc_ids, dec_ids, target_ids = zip(*batch)
    enc_x = _pad(enc_ids)
    dec_x = _pad(dec_ids)
    # NOTE(review): labels are padded with the same id; whether the loss
    # should skip those positions is decided where the loss is computed.
    label = _pad(target_ids)
    enc_mask = build_padding_mask(enc_x, pad_token=pad_id)
    dec_mask = build_padding_mask(dec_x, pad_token=pad_id)
    return enc_x, dec_x, label, enc_mask, dec_mask

# Peek at a single collated batch to check the padded tensors and masks.
# NOTE(review): if the cells run in order, `dataset` was left on
# Partition.VAL by the previous cell, so this batch comes from validation.
training_generator = DataLoader(dataset, collate_fn=collate, batch_size=5, num_workers=0)
for batch in training_generator:
    print(batch)
    # Only the first batch is needed for inspection.
    break

(tensor([[ 6869,     0],
        [36981,     0],
        [12978,     0],
        [35079,    13],
        [10903,     0]]), tensor([[23380, 62854, 62854],
        [23380, 64105, 62854],
        [23380, 62854, 62854],
        [23380, 16233,  1088],
        [23380, 62854, 62854]]), tensor([[18733, 62854, 62854],
        [64105,    64, 62854],
        [65381, 62854, 62854],
        [16233,  1088,    13],
        [14549, 62854, 62854]]), tensor([[1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1]]), tensor([[1, 0, 0],
        [1, 1, 0],
        [1, 0, 0],
        [1, 1, 1],
        [1, 0, 0]]))


In [None]:
# Training
import torch
import math

# Use the MPS backend when PyTorch reports it as available, otherwise the CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
# torch.set_default_device("cpu")
print(f"Device: {device}")

# batch = 16 if memory is not enough
training_params = dict(
    collate_fn=collate,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)
max_epochs = 5
model_dir = "../data/"

# Generators
# One dataset/loader pair serves both partitions; the training code flips
# `dataset.partition` to choose which samples the loader yields.
dataset = TokEnFrDataset("../data/eng_-french.csv", val_ratio=0.05)
dataloader = DataLoader(dataset, **training_params)

# Model, reported before it is moved to the selected device.
transformer = Transformer(vocab_size=encoding.n_vocab, n=2, d=128, h=4)
print(f"Number of model's params: {sum(p.numel() for p in transformer.parameters())}")
transformer = transformer.to(device)

# Objective and optimiser.
loss_fn = F.cross_entropy
optimizer = torch.optim.AdamW(transformer.parameters(), lr=1e-3, weight_decay=1e-4)

def save_model(epoch, i):
    """Write the transformer's ``state_dict`` to ``<model_dir>model_<epoch>_<i>.pt``.

    Args:
        epoch: first component of the file name.
        i: second component of the file name (callers pass a tag such as "final").
    """
    torch.save(transformer.state_dict(), f"{model_dir}model_{epoch}_{i}.pt")
    

def train_epoch(epoch):
    """Run one optimisation pass over the training partition.

    Every ``report_every`` batches the mean loss of that window is printed.
    When it is the lowest window mean seen so far in this epoch, the weights
    are checkpointed through ``save_model(epoch, "intermediate")``.

    Args:
        epoch: epoch number, used only to name the intermediate checkpoint.

    Returns:
        Mean loss over all batches of the epoch (0.0 if the loader was empty).
    """
    report_every = 1000
    window_loss = 0.
    epoch_loss = 0.
    n_batches = 0
    best_loss = math.inf

    transformer.train(True)
    dataset.partition = Partition.TRAIN

    for i, data in enumerate(dataloader, 1):
        enc_x, dec_x, label, enc_mask, dec_mask = (t.to(device) for t in data)
        # Clear grads
        optimizer.zero_grad()

        output = transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask)
        # Flatten to (tokens, vocab) vs (tokens,) as cross-entropy expects.
        # NOTE(review): padded label positions are scored like real tokens;
        # consider ignore_index=<pad id> if that is not intended.
        loss = loss_fn(output.view(-1, encoding.n_vocab), label.view(-1))
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        window_loss += batch_loss
        epoch_loss += batch_loss
        n_batches = i
        if i % report_every == 0:
            last_loss = window_loss / report_every
            print('Average batch loss: {}'.format(last_loss))
            window_loss = 0.
            if last_loss < best_loss:
                # Remember the best window; otherwise every window would
                # overwrite the checkpoint regardless of its loss.
                best_loss = last_loss
                save_model(epoch, "intermediate")

    # Average over the whole epoch, so short epochs (< report_every batches)
    # no longer report 0 and long ones do not report only the final window.
    avg_epoch_loss = epoch_loss / n_batches if n_batches else 0.
    print('Average epoch loss: {}'.format(avg_epoch_loss))
    return avg_epoch_loss

def validate_epoch():
    """Compute the mean loss over the validation partition without updating weights.

    Returns:
        Mean per-batch loss as a Python float, or NaN when the validation
        partition produced no batches.
    """
    running_vloss = 0.0
    n_batches = 0
    transformer.eval()
    # The enum is `Partition`; `PARTITION` is not defined anywhere.
    dataset.partition = Partition.VAL

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        # There is no separate validation loader: `dataloader` wraps `dataset`,
        # so after the partition switch above it yields validation samples.
        for n_batches, vdata in enumerate(dataloader, 1):
            enc_x, dec_x, label, enc_mask, dec_mask = (t.to(device) for t in vdata)
            output = transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask)
            # Same flattening as in training: (tokens, vocab) vs (tokens,).
            vloss = loss_fn(output.view(-1, encoding.n_vocab), label.view(-1))
            # .item() keeps the running sum a plain float instead of a device tensor.
            running_vloss += vloss.item()

    # NaN never compares as "better", so an empty partition cannot trigger a checkpoint.
    avg_vloss = running_vloss / n_batches if n_batches else float("nan")
    print('Average valid loss {}'.format(avg_vloss))

    return avg_vloss

def train():
    """Train for ``max_epochs`` epochs, checkpointing whenever validation loss improves.

    After each training epoch the validation loss is computed; a new minimum
    triggers ``save_model(epoch, "final")``.
    """
    best_vloss = math.inf
    for epoch in range(max_epochs):
        print('EPOCH {}:'.format(epoch + 1))
        train_epoch(epoch)

        avg_val_loss = validate_epoch()
        # Compare the value just computed (the old name `avg_vloss` was never bound here).
        if avg_val_loss < best_vloss:
            best_vloss = avg_val_loss
            save_model(epoch, "final")

train()

Device: mps
Number of model's params: 26696885
EPOCH 1:
Average batch loss: 0.030220237731933594
Average batch loss: 20.92554894065857
Average batch loss: 7.629583003997802
Average batch loss: 5.8763357303142545
Average batch loss: 5.845679868459701
Average batch loss: 5.8248831729888915
Average batch loss: 6.160519654512405
Average batch loss: 6.20350581240654
Average batch loss: 10.101644933462143
Average batch loss: 7.434640023708344
Average batch loss: 6.75776136803627
Average batch loss: 6.577257626056671
Average batch loss: 5.954030453681946


In [None]:
# Scratch check: flatten (batch, seq, classes) logits and (batch, seq) targets
# before passing them to F.cross_entropy.
# Named `logits` rather than `input` so the builtin input() is not shadowed
# for the rest of the kernel session.
logits = torch.randn((2, 3, 32), requires_grad=True)
print(logits.view(-1, 32).shape)
# random_(32) fills the tensor with class ids drawn from 0..31.
target = torch.empty((2, 3), dtype=torch.long).random_(32)
print(target)
loss = F.cross_entropy(logits.view(-1, 32), target.view(-1))
loss

In [None]:
import tiktoken

# Tokenizer used by the sample builder below.
encoding = tiktoken.get_encoding("cl100k_base")

def build_train_sample(en_str: str, dec_str: str):
    """Expand one sentence pair into one training triple per decoder prefix.

    ``dec_str`` is tokenised and ``Tokens.END_NUM.value`` is appended. For each
    prefix length ``i`` from 1 up to (not including) the number of decoder
    tokens, a tuple is emitted holding the encoder ids, the first ``i`` decoder
    ids, and the same window shifted one token to the right as the target.

    Returns:
        List of ``(encoder_ids, decoder_prefix, target_ids)`` tuples; every
        tuple shares the same ``encoder_ids`` list object.
    """
    en_encoded = encoding.encode(en_str)
    dec_encoded = encoding.encode(dec_str)
    dec_encoded.append(Tokens.END_NUM.value)
    return [
        (en_encoded, dec_encoded[:i], dec_encoded[1: i + 1])
        for i in range(1, len(dec_encoded))
    ]

build_train_sample('Hi.', 'START Salut!')

In [None]:
from torch.utils.data import DataLoader

# def collate_fn

dataset.partition = Partition.TRAIN
# Samples are variable-length id lists, which DataLoader's default collate
# cannot batch (it raises on unequal lengths); the padding `collate` function
# is required to build a batch.
training_generator = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate)
for i, s in enumerate(training_generator):
    print(s)
    # Stops after the fourth batch has been printed.
    if i > 2:
        break
        
    

In [None]:
# Scratch check: a (2, 3) integer tensor filled in place by random_(5).
torch.empty((2, 3), dtype=torch.long).random_(5)


In [None]:

# Show which device holds the first encoder/decoder attention query weights,
# each followed by the device of the model's first parameter.
print(
    transformer.encoder.layers[0].mhsa.wq.weight.device,
    next(transformer.parameters()).device,
    transformer.decoder.layers[0].mhsa.wq.weight.device,
    next(transformer.parameters()).device,
    sep="\n",
)

In [None]:
import torch
# Prints whether PyTorch reports the MPS backend as available.
print(torch.backends.mps.is_available())

In [None]:
# training

# First line: MPS available at runtime; second line: this PyTorch build includes MPS.
print(
    torch.backends.mps.is_available(),
    torch.backends.mps.is_built(),
    sep="\n",
)


In [None]:
class Partition(Enum):
    """Selects which split of a dataset is being served."""

    # NOTE(review): `Partition` is already used by earlier cells; re-running
    # this cell creates a new enum class whose members do not compare equal to
    # those held by existing objects -- confirm this cell is still needed.
    TRAIN = "train"
    VAL = "val"

In [None]:
# Scratch checks. Only the final expression is displayed by the notebook; the
# first two results are computed and discarded.
F.softmax(
    torch.tensor([[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]]),
    dim=-1,
)
torch.tensor([[0.1, 0.2, 0.3], [0.1, 0.2, 0.4]]).argmax(dim=-1)
# torch.max(torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.4]]), dim=-1)

# Append one column to a row and read back the last element as a Python float.
t1 = torch.tensor([[0.1, 0.2]])
t2 = torch.tensor([[0.3]])
torch.cat([t1, t2], dim=1)[-1, -1].item()


In [None]:
# int() truncates toward zero, so this evaluates to 0.
int(0.1)

In [None]:
# NOTE(review): unconditional failure -- presumably a deliberate stop so that
# "Run All" halts here; confirm it is still wanted or remove the cell.
assert False

In [None]:
# Scratch check of padding-mask broadcasting with masked_fill.
# NOTE(review): presumably `x` stands in for attention scores shaped
# (batch, query, key) and `mask` for a (batch, key) padding mask -- confirm.
# The last dimension of `x` must equal the mask length (2): the earlier shape
# (1, 2, 3) cannot broadcast against a (1, 1, 2) mask and makes masked_fill raise.
x = torch.rand([1, 3, 2])
mask = torch.ones([1, 2])
mask[0, 1] = 0
# (batch, key) -> (batch, 1, key) so the mask broadcasts over every query row.
mask = mask.unsqueeze(1)
print(mask == 0)
masked = x.masked_fill(mask == 0, float("-inf"))
masked