In [1]:
# Unmasked attention

In [2]:
import torch
import math
import torch.nn.functional as F

def self_attention(q, k, v):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b t h dh
    # q = q.permute(0, 2, 1, 3)
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    print(softmaxed_prod.shape)
    # print(softmaxed_prod)
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


x = torch.rand([2, 3, 4, 5])
self_attention(x, x, x)
self_attention(x, x, x).shape

torch.Size([2, 4, 3, 3])
torch.Size([2, 4, 3, 3])


torch.Size([2, 4, 3, 5])

In [3]:
from torch import nn

class MHSA(nn.Module):
    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)
 
    def forward(self, q, k, v):
        # b, t, d
        b, t, d = q.size()
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        wq = wq.view(b, t, self.h, self.dh)
        wk = wk.view(b, t, self.h, self.dh)
        wv = wv.view(b, t, self.h, self.dh)
        # b, t, h, dh
        # if changing from 4 dim -> 3 dim: b*h, t, dh
        # wq = wq.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wk = wk.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wv = wv.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # another option 4 dim -> 3 dim
        # wq = wq.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wk = wk.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wv = wv.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # changing the number of dims is not necessary as @ supports 4 dims
        attn = self_attention(wq, wk, wv)
        # b * h, t, dh
        # attn = attn.view(b, self.h, t, self.dh).permute(0, 2, 1, 3).reshape(b, t, d)
        attn = attn.view(b, self.h, t, self.dh).transpose(1, 2).contiguous().view(b, t, d)
        wo = self.wo(attn)
        return wo
        # # 1 2 3 4
        # x = F.relu(self.conv1(x))
        # return F.relu(self.conv2(x))

mhsa = MHSA()
x = torch.rand(2, 3, 512)
mhsa(x, x, x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [4]:
# PE

In [5]:
from torch import nn

class PE(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.dropout = nn.Dropout(p=dropout)

        twoi = torch.arange(0, self.d, 2)
        pow_ = torch.pow(10000, twoi / self.d)
        position = torch.arange(0, max_len).unsqueeze(1)
        sin_p = torch.sin(position / pow_)
        cos_p = torch.cos(position / pow_)
        pe = torch.zeros(max_len, self.d, requires_grad=False) # Explicit, register buffer insures requires grad = False
        pe[:, 0::2] = sin_p
        pe[:, 1::2] = cos_p
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe) 

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.pe[:, : t, :]
        return self.dropout(x)
print(PE(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [6]:
from torch import nn

class PEEmbed(nn.Module):

    def __init__(self, d: int = 512, max_len: int = 1024, dropout: float = 0.1):
        super().__init__()
        self.d = d
        self.pe = nn.Embedding(max_len, d)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        b, t, d = x.size()
        pos = self.pe(torch.arange(t))
        x = x + pos
        return self.dropout(x)
print(PEEmbed(d=4)(torch.arange(24).view(-1, 3, 4)).size()) # torch.Size([2, 3, 4])

torch.Size([2, 3, 4])


In [7]:
# Encoder without mask

In [8]:
import torch.nn.functional as F

class EncoderLayerWithoutMask(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSA(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x):
        b, t, d = x.size()
        x = x + self.attn_dropout(self.mhsa(x, x, x))
        x = self.norm1(x)
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm2(x)
        return x


encoder_layer = EncoderLayerWithoutMask()
x = torch.rand(2, 3, 512)
encoder_layer(x).shape

torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [9]:
from torch import nn

class EncoderWithoutMask(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [EncoderLayerWithoutMask(d, h) for _ in range(n)]

    def forward(self, x):
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x)
        return x

encoder = EncoderWithoutMask()
x = torch.randint(0, 2**13, (2, 3))
encoder(x).shape

torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])
torch.Size([2, 8, 3, 3])


torch.Size([2, 3, 512])

In [10]:
# With masks
import torch
import math
import torch.nn.functional as F

def self_attention_masked(q, k, v, mask=None):
    # if 3 dim: b t d
    #prod = Q.bmm(K.permute(0, 2, 1))
    # or
    # prod = torch.einsum("btd, bsd -> bts", q, k)
    # if 4 dim: b t h dh:
    prod = torch.einsum("bthd, bshd -> bhts", q, k)
    scaled_prod = prod/torch.sqrt(torch.tensor(q.shape[-1]))
    print(f"scaled_prod.shape: \n {scaled_prod.shape}")
    # mask should be in shape to be broadcastable to bhts and lead to masked keys only (last s dim)
    if mask is not None:
        scaled_prod = scaled_prod.masked_fill(mask == 0, float("-inf"))
    print(f"scaled_prod: \n {scaled_prod}")
    softmaxed_prod = F.softmax(scaled_prod, dim=-1)
    # print(softmaxed_prod.shape)
    print(f"softmaxed_prod: \n {softmaxed_prod}")
    # swap h and t in v
    return softmaxed_prod @ v.permute(0, 2, 1, 3)


In [11]:
# Mask

In [12]:
# play with mask

x = torch.rand([2, 3, 2, 4])
print(x)
# mask 2 batches 3 timeseries
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
print(f"mask: \n {mask}")
# add head dim to make mask broatcastable to q x k.T prod. mask shape 2, 1, 3
mask = mask.unsqueeze(1)


# mask = mask.permute(0, 2, 1)
# is the mask that I need? keys are ignored?
print(f"wrong mask: \n {mask}")
#  mask = 2 1 3 -> b prepended before broadcasting (1!!!) h (remains since already 2) t (broadcasted from 1) d (remains since already 3) 
print(f"wrong mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"wrong a: \n {a}" )
print(f"wrong a.shape: \n {a.shape}")
# leads to wrong attention since the shape of mask is wrong 2 1 3 

# correct mask
# mask 2 batches 3 timeseries
mask = torch.ones([2, 3])
mask[0, 2] = 0
mask[1, 2] = 0
mask[1, 1] = 0
mask = mask.unsqueeze(1).unsqueeze(1)

print(f"mask: \n {mask}")
#  mask = 2 1 1 3 -> b (remains already 2) h (broadcasted from 1) t (broadcasted from 1) d (remains since already 3) 
print(f"mask broadcast: \n {mask.broadcast_to([2, 2, 3, 3])}") 
a = self_attention_masked(x, x, x, mask=mask)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")


tensor([[[[0.9857, 0.7291, 0.6403, 0.0374],
          [0.7394, 0.3984, 0.7374, 0.5068]],

         [[0.5068, 0.0564, 0.7493, 0.8834],
          [0.8811, 0.7136, 0.0284, 0.0695]],

         [[0.2602, 0.4339, 0.2987, 0.8893],
          [0.4672, 0.3633, 0.7519, 0.5286]]],


        [[[0.4021, 0.0263, 0.4238, 0.1005],
          [0.1994, 0.8508, 0.2777, 0.1168]],

         [[0.4193, 0.7159, 0.8073, 0.5569],
          [0.3141, 0.4790, 0.3625, 0.9635]],

         [[0.2780, 0.7656, 0.6667, 0.0436],
          [0.2631, 0.8770, 0.0948, 0.5818]]]])
mask: 
 tensor([[1., 1., 0.],
        [1., 0., 0.]])
wrong mask: 
 tensor([[[1., 1., 0.]],

        [[1., 0., 0.]]])
wrong mask broadcast: 
 tensor([[[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 0.],
          [1., 1., 0.]],

         [[1., 0., 0.],
          [1., 0., 0.],
          [1., 0., 0.]]]])
scaled_prod.shape

In [13]:
# mask is equal to making keys on masked places 0:
# the result in terms of masked symbols is the same
k = x.clone()
k[0, 2, 0, :] = float("-inf")
k[0, 2, 1, :] = float("-inf")
k[1, 2, 0, :] = float("-inf")
k[1, 1, 0, :] = float("-inf")
k[1, 2, 1, :] = float("-inf")
k[1, 1, 1, :] = float("-inf")
print(f"k: \n {k}")
a = self_attention_masked(x, k, x)
print(f"a: \n {a}" )
print(f"a.shape: \n {a.shape}")
# a is the same shape as if mask was applied in q * k:

test = torch.rand([2, 3, 4])
test[0, 2, :] = 0
test[1, 1, :] = 0
test[1, 2, :] = 0

print(f"test: \n {test}")
test_v = test.view(2, 3, 2, 2)
print(f"test_v: \n {test_v}")
test_perm = test_v.permute(0, 2, 1, 3)
print(f"test_perm: \n {test_perm}")

# or like that:
test_q = torch.rand([2, 3, 4])
test_k = test_q.clone()
test_k[0, 2, :] = float("-inf")
test_k[1, 1, :] = float("-inf")
test_k[1, 2, :] = float("-inf")
print(f"test_k: \n {test_k}")

test_q_view = test_q.view(2, 3, 2, 2)
test_k_view = test_k.view(2, 3, 2, 2)
print(f"test_k_view: \n {test_k_view}")
test_q_perm = test_q_view.permute(0, 2, 1, 3)
test_k_perm = test_k_view.permute(0, 2, 1, 3)
print(f"test_k_perm: \n {test_k_perm}")
print(f"q * k: \n {torch.einsum("bhtd, bhsd -> bhts", test_q_perm, test_k_perm)}")

k: 
 tensor([[[[0.9857, 0.7291, 0.6403, 0.0374],
          [0.7394, 0.3984, 0.7374, 0.5068]],

         [[0.5068, 0.0564, 0.7493, 0.8834],
          [0.8811, 0.7136, 0.0284, 0.0695]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]],


        [[[0.4021, 0.0263, 0.4238, 0.1005],
          [0.1994, 0.8508, 0.2777, 0.1168]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]],

         [[  -inf,   -inf,   -inf,   -inf],
          [  -inf,   -inf,   -inf,   -inf]]]])
scaled_prod.shape: 
 torch.Size([2, 2, 3, 3])
scaled_prod: 
 tensor([[[[0.9573, 0.5267,   -inf],
          [0.5267, 0.8009,   -inf],
          [0.3987, 0.5829,   -inf]],

         [[0.7530, 0.4960,   -inf],
          [0.4960, 0.6457,   -inf],
          [0.6563, 0.3645,   -inf]]],


        [[[0.1761,   -inf,   -inf],
          [0.2928,   -inf,   -inf],
          [0.2094,   -inf,   -inf]],

         [[0.4272,   -inf,   -inf],
          [0.3417,   -i

In [14]:
import torch

def build_padding_mask(x, pad_token):
    # x: b t shape
    mask = torch.ones_like(x)
    return mask.masked_fill(x == pad_token, 0)

x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -2:] = 100
x[2, -1] = 100
x[3, :] = 100
print(x)
print(build_padding_mask(x, 100))

tensor([[3.8186e-01, 6.4062e-01, 1.9850e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [7.0082e-01, 9.3461e-01, 4.5837e-01, 7.5353e-01, 1.0000e+02, 1.0000e+02],
        [5.2100e-01, 2.6563e-01, 7.7681e-01, 6.7967e-01, 1.0554e-01, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [8.5110e-02, 7.2043e-01, 4.4257e-01, 6.6892e-01, 4.1118e-01, 9.9156e-02]])
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [15]:
import torch

def build_causal_mask(x):
    # x: b t shape
    m = torch.ones_like(x)
    return torch.tril(m)
x = torch.rand(5, 6)

print(build_causal_mask(x))

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [16]:
import torch

def merge_masks(m1, m2):
    return m1 * m2

x = torch.rand(5, 6)
x[0, -3:] = 100
x[1, -1] = 100
x[2, -4:] = 100
x[3, :] = 100
print(x)
m1 = build_padding_mask(x, 100)
m2 = build_causal_mask(x)
print(merge_masks(m1, m2))

tensor([[6.5627e-01, 1.9587e-01, 3.2327e-02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [1.2564e-01, 1.4291e-01, 2.9585e-01, 7.7335e-02, 3.9172e-01, 1.0000e+02],
        [7.5887e-01, 6.5872e-01, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02, 1.0000e+02],
        [8.3057e-01, 1.9838e-01, 4.1828e-01, 7.9825e-01, 3.7982e-01, 3.1589e-01]])
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])


In [17]:
import torch

def reshape_mask(mask):
    # b t -> b 1 1 t (to be broadcastable to b h t t)
    return mask.unsqueeze(1).unsqueeze(1)

x = torch.rand(2, 3)
print(reshape_mask(build_causal_mask(x)))

tensor([[[[1., 0., 0.]]],


        [[[1., 1., 0.]]]])


In [18]:
from torch import nn

class MHSAMasked(nn.Module):
    def __init__(self, d: int = 512, h: int = 8):
        super().__init__()
        assert d % h == 0
        self.d = d
        self.dh = d // h
        self.h = h
        self.wq = nn.Linear(self.d, self.d)
        self.wk = nn.Linear(self.d, self.d)
        self.wv = nn.Linear(self.d, self.d)
        self.wo = nn.Linear(self.d, self.d)
 
    def forward(self, q, k, v, mask = None):
        # q and k/v might be of different sizes if lengths of decoder and encoders inputs are different
        bq, tq, dq = q.size()
        bk, tk, dk = k.size()
        wq = self.wq(q)
        wk = self.wk(k)
        wv = self.wv(v)
        wq = wq.view(bq, tq, self.h, self.dh)
        wk = wk.view(bk, tk, self.h, self.dh)
        wv = wv.view(bk, tk, self.h, self.dh)
        # b, t, h, dh
        # if changing from 4 dim -> 3 dim: b*h, t, dh
        # wq = wq.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wk = wk.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # wv = wv.permute(0, 2, 1, 3).reshape(b * self.h, t, self.dh)
        # another option 4 dim -> 3 dim
        # wq = wq.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wk = wk.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # wv = wv.transpose(1, 2).contiguous().view(b * self.h, t, self.dh)
        # changing the number of dims is not necessary as @ supports 4 dims
        attn = self_attention_masked(wq, wk, wv, mask=mask)
        # b * h, t, dh
        # attn = attn.view(b, self.h, t, self.dh).permute(0, 2, 1, 3).reshape(b, t, d)
        attn = attn.view(bq, self.h, tq, self.dh).transpose(1, 2).contiguous().view(bq, tq, dq)
        wo = self.wo(attn)
        return wo
        # # 1 2 3 4
        # x = F.relu(self.conv1(x))
        # return F.relu(self.conv2(x))

mhsa_masked = MHSAMasked(h = 2, d = 6)
x = torch.rand(4, 5)
mask = reshape_mask(build_causal_mask(x))
print(mask)
x = torch.rand(4, 5, 6)
print(mhsa_masked(x, x, x, mask=mask))
print(mhsa_masked(x, x, x, mask=mask).shape)

tensor([[[[1., 0., 0., 0., 0.]]],


        [[[1., 1., 0., 0., 0.]]],


        [[[1., 1., 1., 0., 0.]]],


        [[[1., 1., 1., 1., 0.]]]])
scaled_prod.shape: 
 torch.Size([4, 2, 5, 5])
scaled_prod: 
 tensor([[[[-0.0823,    -inf,    -inf,    -inf,    -inf],
          [-0.1181,    -inf,    -inf,    -inf,    -inf],
          [-0.0663,    -inf,    -inf,    -inf,    -inf],
          [-0.0437,    -inf,    -inf,    -inf,    -inf],
          [-0.1643,    -inf,    -inf,    -inf,    -inf]],

         [[ 0.0333,    -inf,    -inf,    -inf,    -inf],
          [ 0.0413,    -inf,    -inf,    -inf,    -inf],
          [ 0.0197,    -inf,    -inf,    -inf,    -inf],
          [-0.0091,    -inf,    -inf,    -inf,    -inf],
          [ 0.0641,    -inf,    -inf,    -inf,    -inf]]],


        [[[-0.1610, -0.0980,    -inf,    -inf,    -inf],
          [-0.1552, -0.1273,    -inf,    -inf,    -inf],
          [-0.0323, -0.1289,    -inf,    -inf,    -inf],
          [-0.1566, -0.1981,    -inf,    -inf,   

In [19]:
# Transformer implementation 

In [20]:
import torch.nn.functional as F

class EncoderLayer(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSAMasked(d, h)
        self.norm1 = nn.LayerNorm(d)
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.norm2 = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(self, x, self_mask=None):
        b, t, d = x.size()
        x = x + self.attn_dropout(self.mhsa(x, x, x, mask=self_mask))
        x = self.norm1(x)
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm2(x)
        return x


encoder_layer = EncoderLayer()
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
x = torch.rand(2, 3, 512)

encoder_layer(x, self_mask=self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[ 0.0290,  0.0526,    -inf],
          [-0.1071, -0.0288,    -inf],
          [-0.0313,  0.0395,    -inf]],

         [[ 0.1718,  0.1367,    -inf],
          [ 0.2660,  0.2261,    -inf],
          [ 0.2388,  0.1550,    -inf]],

         [[-0.0152,  0.1000,    -inf],
          [-0.0247, -0.0017,    -inf],
          [-0.0012, -0.0629,    -inf]],

         [[-0.0116, -0.0373,    -inf],
          [-0.0524, -0.0300,    -inf],
          [-0.0198, -0.0917,    -inf]],

         [[-0.0982, -0.0242,    -inf],
          [-0.1178, -0.0027,    -inf],
          [-0.0132,  0.0302,    -inf]],

         [[-0.0034, -0.0034,    -inf],
          [-0.0394, -0.0643,    -inf],
          [-0.1337, -0.0938,    -inf]],

         [[-0.0652, -0.0678,    -inf],
          [-0.1324, -0.1568,    -inf],
          [-0.1193, -0.1518,    -inf]],

         [[-0.1068,  0.0435,    -inf],
          [-0.1516, -0.0182,    -inf],
          [-0.1953, -0.0652,

torch.Size([2, 3, 512])

In [21]:
from torch import nn

class Encoder(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [EncoderLayer(d, h) for _ in range(n)]

    def forward(self, x, self_mask = None):
        b, t = x.size()
        x = self.embed(x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, self_mask=self_mask)
        return x

encoder = Encoder()
x = torch.randint(0, 2**13, (2, 3))
self_mask = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0]]), pad_token=0)
self_mask = reshape_mask(self_mask)
encoder(x, self_mask).shape

scaled_prod.shape: 
 torch.Size([2, 8, 3, 3])
scaled_prod: 
 tensor([[[[-0.0803, -0.1262,    -inf],
          [-0.5837,  0.9467,    -inf],
          [-0.0425,  0.4973,    -inf]],

         [[-0.2168, -0.1899,    -inf],
          [-0.4836,  0.0629,    -inf],
          [ 0.4864,  0.2861,    -inf]],

         [[-0.5599,  0.0706,    -inf],
          [ 0.0283,  0.1414,    -inf],
          [ 0.2566,  0.1005,    -inf]],

         [[-0.1452, -0.6408,    -inf],
          [ 0.0286, -0.6266,    -inf],
          [ 0.7881,  0.4475,    -inf]],

         [[ 0.2094,  0.5539,    -inf],
          [-0.2821, -0.1401,    -inf],
          [-0.0837, -0.3041,    -inf]],

         [[ 0.2413, -0.4109,    -inf],
          [ 0.1411, -0.4013,    -inf],
          [ 0.7659, -0.4850,    -inf]],

         [[ 0.3142,  0.2689,    -inf],
          [-0.6572,  0.0899,    -inf],
          [-0.4500,  0.6181,    -inf]],

         [[-0.4630,  1.2604,    -inf],
          [-0.0337,  0.1610,    -inf],
          [-0.2285,  0.7066,

torch.Size([2, 3, 512])

In [22]:
import torch.nn.functional as F

class DecoderLayer(nn.Module): 

    def __init__(self, d: int = 512, h: int = 8, dropout: float = 0.1):
        super().__init__()
        self.mhsa = MHSAMasked(d=d, h=h)
        self.attn_norm = nn.LayerNorm(d)
        self.attn_dropout = nn.Dropout(dropout)

        self.mhca = MHSAMasked(d=d, h=h)
        self.cross_attn_norm = nn.LayerNorm(d)
        self.cross_attn_dropout = nn.Dropout(dropout)
        
        self.ff1 = nn.Linear(d, d * 4)
        self.ff2 = nn.Linear(d * 4, d)
        self.resid_dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d)
        

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        # self_mask is merged decoders padding and causal masks
        # cross_mask is equal to endcoders padding mask because we don't want to attend to encoded padded tokens
        b, t, d = dec_x.size()
        x = dec_x + self.attn_dropout(self.mhsa(dec_x, dec_x, dec_x, mask=self_mask))
        x = self.attn_norm(x)

        x = x + self.cross_attn_dropout(self.mhca(x, enc_x, enc_x, mask=cross_mask))
        x = self.cross_attn_norm(x)
        
        x = x + self.resid_dropout(self.ff2(F.relu(self.ff1(x))))
        x = self.norm(x)
        return x


decoder_layer = DecoderLayer(h=2, d=16)
x = torch.rand(3, 3, 16)
y = torch.rand(3, 3, 16)
self_mask1 = build_padding_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]), pad_token=0)
self_mask2 = build_causal_mask(torch.tensor([[2, 2, 0], [2, 0, 0], [2, 2, 0]]))
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
decoder_layer(x, y, self_mask=self_mask, cross_mask=cross_mask).shape

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[ 2.6476e-03,        -inf,        -inf],
          [ 3.9293e-02,        -inf,        -inf],
          [ 3.0088e-02,        -inf,        -inf]],

         [[-1.8792e-02,        -inf,        -inf],
          [-9.6770e-05,        -inf,        -inf],
          [-7.5240e-03,        -inf,        -inf]]],


        [[[ 1.8202e-02,        -inf,        -inf],
          [ 6.0391e-02,        -inf,        -inf],
          [ 4.1730e-02,        -inf,        -inf]],

         [[ 3.2263e-02,        -inf,        -inf],
          [ 5.2906e-02,        -inf,        -inf],
          [ 2.4623e-02,        -inf,        -inf]]],


        [[[ 4.1940e-02,  3.7628e-02,        -inf],
          [ 7.4598e-02,  8.3658e-02,        -inf],
          [ 2.8784e-02,  1.5011e-02,        -inf]],

 

torch.Size([3, 3, 16])

In [23]:
from torch import nn

class Decoder(nn.Module): 

    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d)
        self.pe = PE(d=d)
        self.layers = [DecoderLayer(d, h) for _ in range(n)]

    def forward(self, dec_x, enc_x, self_mask=None, cross_mask=None):
        b, t = dec_x.size()
        x = self.embed(dec_x)
        x = self.pe(x)
        for layer in self.layers:
            x = layer(x, enc_x, self_mask=self_mask, cross_mask=cross_mask)
        return x

    def get_embed_weights(self):
        return self.embed.weight

decoder = Decoder(vocab_size=32, n=2, d=16, h=2)
# x = torch.randint(0, 32, (2, 3))
x = torch.tensor([[15, 7, 0], [10, 0, 0], [1, 3, 0]])
y = torch.rand(3, 3, 16)

self_mask1 = build_padding_mask(x, pad_token=0)
self_mask2 = build_causal_mask(x)
self_mask = merge_masks(self_mask1, self_mask2)
print(f"self_mask: \n {self_mask}")
self_mask = reshape_mask(self_mask)

cross_mask = build_padding_mask(torch.tensor([[2, 2, 2], [2, 0, 0], [2, 2, 0]]), pad_token=0)
cross_mask = reshape_mask(cross_mask)
print(f"cross_mask: \n {cross_mask}")
print(decoder(x, y, self_mask=self_mask, cross_mask=cross_mask).shape)

self_mask: 
 tensor([[1, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])
cross_mask: 
 tensor([[[[1, 1, 1]]],


        [[[1, 0, 0]]],


        [[[1, 1, 0]]]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[-1.0749,    -inf,    -inf],
          [-0.4137,    -inf,    -inf],
          [-0.1183,    -inf,    -inf]],

         [[ 0.1288,    -inf,    -inf],
          [-0.8952,    -inf,    -inf],
          [-0.7050,    -inf,    -inf]]],


        [[[ 0.1889,    -inf,    -inf],
          [ 0.6978,    -inf,    -inf],
          [ 1.0497,    -inf,    -inf]],

         [[-0.5553,    -inf,    -inf],
          [-0.4653,    -inf,    -inf],
          [-0.5273,    -inf,    -inf]]],


        [[[-0.0230, -0.2875,    -inf],
          [ 0.3267, -0.7232,    -inf],
          [ 0.5471, -0.3151,    -inf]],

         [[-0.6063, -0.5643,    -inf],
          [ 0.2344, -0.4607,    -inf],
          [-0.8175,  0.6083,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[1

In [24]:
from torch import nn

class Output(nn.Module):

    def __init__(self, vocab_size: int = 2**13, d: int = 512, ff_weight = None):
        super().__init__()
        self.ff = nn.Linear(d, vocab_size)
        # weight tying with the decoder embedding
        if ff_weight is not None:
            self.ff.weight = ff_weight

    def forward(self, x):
        return self.ff(x)

In [25]:
from torch import nn

class Transformer(nn.Module):
    
    def __init__(self, vocab_size: int = 2**13, n: int = 6, d: int = 512, h: int = 8, embed_tying=True):
        super().__init__()
        self.encoder = Encoder(vocab_size=vocab_size, n=n, d=d, h=h)
        self.decoder = Decoder(vocab_size=vocab_size, n=n, d=d, h=h)
        if embed_tying:
            self.output = Output(vocab_size=vocab_size, d=d, ff_weight = self.decoder.get_embed_weights())
        else:
            self.output = Output(vocab_size=vocab_size, d=d)

    def forward(self, enc_x, dec_x, enc_mask=None, dec_mask=None):
        encoded = self.encoder(enc_x, enc_mask)
        decoded = self.decoder(dec_x=dec_x, enc_x=encoded, self_mask=dec_mask, cross_mask=enc_mask)
        return self.output(decoded)

transformer = Transformer(vocab_size=32, n=2, d=16, h=2, embed_tying=False)
enc_x = torch.tensor([[15, 7, 3], [10, 10, 0], [1, 0, 0]])
dec_x = torch.tensor([[21, 8, 0, 0], [25, 0, 0, 0], [8, 1, 2, 3]])
# dec_x = torch.tensor([[21, 8], [25, 0], [8, 1]])

enc_mask = build_padding_mask(enc_x, pad_token=0)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

dec_mask1 = build_padding_mask(dec_x, pad_token=0)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

print(transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask).shape)

enc_mask: 
 tensor([[1, 1, 1],
        [1, 1, 0],
        [1, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 0, 0, 0],
        [1, 1, 1, 0]])
scaled_prod.shape: 
 torch.Size([3, 2, 3, 3])
scaled_prod: 
 tensor([[[[ 1.2182,  1.4764,  0.3400],
          [ 1.2922,  1.9110,  0.7214],
          [ 0.2941,  0.0643, -0.1778]],

         [[-0.0115, -0.0088, -0.1482],
          [-0.4553, -0.3255, -0.4281],
          [-0.2283, -1.7122, -0.9400]]],


        [[[ 0.3308,  0.2568,    -inf],
          [ 0.1895,  0.1410,    -inf],
          [ 0.3867,  0.4825,    -inf]],

         [[-0.2505, -0.2779,    -inf],
          [-0.0279, -0.0726,    -inf],
          [-0.3503, -0.3737,    -inf]]],


        [[[ 0.0320,    -inf,    -inf],
          [ 0.2436,    -inf,    -inf],
          [ 0.1757,    -inf,    -inf]],

         [[-0.3702,    -inf,    -inf],
          [-0.2347,    -inf,    -inf],
          [-0.1736,    -inf,    -inf]]]], grad_fn=<MaskedFillBackward0>)
softmaxed_prod: 
 tensor([[[[0.3690, 0.4

In [26]:
# Inference

In [27]:
import tiktoken
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


encoding = tiktoken.get_encoding("cl100k_base")
sents = ["Hello World", "This is a simple sentence", "Me"]
encoded_sents = [encoding.encode(s) for s in sents]
enc_x = pad_sequence([torch.tensor(es) for es in encoded_sents], batch_first=True, padding_value=encoding.eot_token)
print(enc_x)
dec_sents = ["Bonjour", "C'est une phrase", "START"]
dec_encoded_sents = [encoding.encode(s) for s in dec_sents]
dec_x = pad_sequence([torch.tensor(es) for es in dec_encoded_sents], batch_first=True, padding_value=encoding.eot_token)
print(dec_x)

transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)

enc_mask = build_padding_mask(enc_x, pad_token=100257)
print(f"enc_mask: \n {enc_mask}")
enc_mask = reshape_mask(enc_mask)

dec_mask1 = build_padding_mask(dec_x, pad_token=100257)
dec_mask2 = build_causal_mask(dec_x)
dec_mask = merge_masks(dec_mask1, dec_mask2)
print(f"dec_mask: \n {dec_mask}")
dec_mask = reshape_mask(dec_mask)

output = transformer(enc_x, dec_x, enc_mask=enc_mask, dec_mask=dec_mask)
print(f"output shape: {output.shape}")
softmaxed = F.softmax(output, dim=-1)
print(f"softmaxed[0, 0, :10]: {softmaxed[0, 0, :10]}")
predicted = softmaxed.argmax(dim=-1)
print(f"predicted: \n {predicted}")

predicted_list = predicted.tolist()
predicted_decoded = [encoding.decode(l) for l in predicted_list]
print(f"predicted decoded: \n {predicted_decoded}")

tensor([[  9906,   4435, 100257, 100257, 100257],
        [  2028,    374,    264,   4382,  11914],
        [  7979, 100257, 100257, 100257, 100257]])
tensor([[ 82681, 100257, 100257, 100257],
        [    34,  17771,   6316,  17571],
        [ 23380, 100257, 100257, 100257]])
enc_mask: 
 tensor([[1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1],
        [1, 0, 0, 0, 0]])
dec_mask: 
 tensor([[1, 0, 0, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0]])
scaled_prod.shape: 
 torch.Size([3, 4, 5, 5])
scaled_prod: 
 tensor([[[[-1.2155e-01,  2.8055e-01,        -inf,        -inf,        -inf],
          [-2.5560e-01,  3.5178e-01,        -inf,        -inf,        -inf],
          [-1.3659e-03,  3.8009e-01,        -inf,        -inf,        -inf],
          [-1.2051e-01,  9.6084e-04,        -inf,        -inf,        -inf],
          [-3.6533e-01,  3.3893e-01,        -inf,        -inf,        -inf]],

         [[ 5.7374e-01, -2.6784e-01,        -inf,        -inf,        -inf],
          [-6.6826e-01, -6.037

In [28]:
# Predicting next words
sent = "This is a simple sentence"
encoded_sent = encoding.encode(sent)
enc_x = torch.tensor(encoded_sent).unsqueeze(0)
dec_x = torch.tensor(encoding.encode("C")).unsqueeze(0)

transformer = Transformer(vocab_size=encoding.n_vocab, n=3, d=256, h=4)

predicted_tokens = []
for _ in range(5):
    output = transformer(enc_x=enc_x, dec_x=dec_x)
    softmaxed = F.softmax(output, dim=-1)
    predicted = softmaxed.argmax(dim=-1)
    predicted_tokens.append(predicted.tolist()[-1][-1]) 
    dec_x = torch.cat((dec_x, predicted), dim=-1)

print(predicted_tokens)
print(f"predicted sentence: \n {encoding.decode(predicted_tokens)}")

scaled_prod.shape: 
 torch.Size([1, 4, 5, 5])
scaled_prod: 
 tensor([[[[-0.0444, -0.4916,  0.0421,  0.0309, -0.6455],
          [-0.5642, -0.2418, -0.6000, -0.3910,  0.1223],
          [-0.7563, -0.6709, -0.6949,  0.4502, -0.5530],
          [-0.9588,  0.1864,  0.0130, -0.3373, -0.6868],
          [-0.1787,  0.8517,  0.0945, -0.6182,  0.1224]],

         [[-0.5480,  0.3953, -0.1576, -0.0710, -0.0051],
          [ 0.1514,  0.4362, -0.5188, -0.0307, -0.4017],
          [ 0.4295,  0.4128,  0.1835,  0.0398,  0.3154],
          [ 0.9346,  1.3585,  0.3010,  0.1890, -0.5747],
          [ 0.8918,  0.7975,  0.6512,  0.3392,  0.2650]],

         [[-0.1668,  0.0657, -0.4285, -0.0558, -0.4906],
          [ 0.6594,  0.2603, -0.4549, -0.5408,  0.8364],
          [ 1.0529,  0.5128, -0.0778, -0.1687,  0.3484],
          [ 0.9814,  0.5878,  0.7433, -0.0950,  0.9049],
          [ 0.3427,  0.8149,  0.1604,  0.2167,  0.5411]],

         [[ 0.6841,  0.1972,  0.2608,  0.8136,  0.6876],
          [ 0.2523,  

In [29]:
# Dataset
import torch
from torch.utils.data import Dataset
from pathlib import Path
import csv
from enum import Enum



class Partition(Enum):
    TRAIN = "train"
    VAL = "val"

class Tokens(Enum):
    START = "START "
    END = "<|endoftext|>"
    PAD = " PAD"
    START_NUM = 23380
    END_NUM = 100257
    PAD_NUM = 62854
    

class EnFrDataset(Dataset):

    def __init__(self, file: Path | str, partition: Partition = Partition.TRAIN, val_ratio: float = 0.1):
        # partition = TRAIN | VAL
        self._partition = partition
        self._val_ratio = val_ratio

        self._data = []
        self._train_map: dict[int, int] = {}
        self._val_map: dict[int, int] = {}
        train_id = 0
        val_id = 0
        with open(file, newline='') as csvfile:
            reader = csv.reader(csvfile)
            # we want data indexes start from 0, but filter out the first header row
            for i, row in enumerate(reader, start=-1):
                if i == -1:
                    continue
                en = row[0]
                fr = Tokens.START.value + row[1]
                self._data.append(tuple([en, fr]))
                if int(i * val_ratio) == int((i - 1) * val_ratio):
                    self._train_map[train_id] = i
                    train_id += 1
                else:
                    self._val_map[val_id] = i
                    val_id += 1

    class Iterator():

        def __init__(self, outer):
            self.cur = 0
            self.outer = outer

        def __next__(self):
            if self.cur == len(self.outer._data):
                raise StopIteration()
            cur = self.outer._data[self.cur]
            self.cur += 1
            return cur

    def __iter__(self):
        return EnFrDataset.Iterator(self)
    
    @property
    def partition(self):
        return self._partition

    @partition.setter
    def partition(self, partition):
        self._partition = partition
    
    def __len__(self):
        return len(self._train_map) if self._partition == Partition.TRAIN else len(self._val_map)

    def __getitem__(self, idx):
        return self._data[self._train_map[idx]] if self._partition == Partition.TRAIN else self._data[self._val_map[idx]]

dataset = EnFrDataset("../data/eng_-french.csv", val_ratio=0.1)
train_sample = dataset[0]
dataset.partition = Partition.VAL
val_sample = dataset[0]
assert train_sample != val_sample
print(train_sample)
print(val_sample)

for i, d in enumerate(dataset):
    if i > 2:
        break
    print(d)

('Hi.', 'START Salut!')
('Stop!', 'START ArrÃªte-toi !')
('Hi.', 'START Salut!')
('Run!', 'START Cours\u202f!')
('Run!', 'START Courez\u202f!')


In [30]:
class TokEnFrDataset(Dataset):

    @staticmethod
    def build_train_sample(en_str: str, dec_str: str):
        en_encoded = encoding.encode(en_str)
        dec_encoded = encoding.encode(dec_str)
        dec_encoded.append(Tokens.END_NUM.value)
        en_sents = []
        dec_sents = []
        target_sents = []
        
        for i in range(1, len(dec_encoded)):
            dec_sents.append(dec_encoded[:i])
            target_sents.append(dec_encoded[1: i + 1])
        en_sents.extend([en_encoded] * len(dec_sents))
        return list(zip(en_sents, dec_sents, target_sents))

    def __init__(self, file: Path | str, partition: Partition = Partition.TRAIN, val_ratio: float = 0.1):
        self._dataset = EnFrDataset(file, partition, val_ratio=0)
        # partition = TRAIN | VAL
        self._partition = partition
        self._val_ratio = val_ratio

        self._data = []
        self._train_map: dict[int, int] = {}
        self._val_map: dict[int, int] = {}
        train_id = 0
        val_id = 0
        i = 0
        for en, fr in self._dataset:
            for sample in self.build_train_sample(en, fr):
                self._data.append(sample)
                if int(i * val_ratio) == int((i - 1) * val_ratio):
                    self._train_map[train_id] = i
                    train_id += 1
                else:
                    self._val_map[val_id] = i
                    val_id += 1
                i += 1

    @property
    def partition(self):
        return self._partition

    @partition.setter
    def partition(self, partition):
        self._partition = partition
    
    def __len__(self):
        return len(self._train_map) if self._partition == Partition.TRAIN else len(self._val_map)

    def __getitem__(self, idx):
        return self._data[self._train_map[idx]] if self._partition == Partition.TRAIN else self._data[self._val_map[idx]]

dataset = TokEnFrDataset("../data/eng_-french.csv", val_ratio=0.1)
train_sample = dataset[0]
dataset.partition = Partition.VAL
val_sample = dataset[0]
assert train_sample != val_sample
print(train_sample)
print(val_sample)

for i, d in enumerate(dataset):
    if i > 2:
        break
    print(d)

([13347, 13], [23380], [8375])
([6869, 0], [23380], [18733])
([6869, 0], [23380], [18733])
([36981, 0], [23380, 64105], [64105, 64])
([12978, 0], [23380], [65381])


In [32]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence


def collate(batch):
    print(batch)
    _x, _y, _label = list(zip(*batch))
    x = pad_sequence([torch.tensor(t) for t in _x], batch_first=True, padding_value=Tokens.PAD_NUM.value)
    y = pad_sequence([torch.tensor(t) for t in _y], batch_first=True, padding_value=Tokens.PAD_NUM.value)
    label = pad_sequence([torch.tensor(t) for t in _label], batch_first=True, padding_value=Tokens.PAD_NUM.value)
    enc_mask = build_padding_mask(x, pad_token=Tokens.PAD_NUM.value)
    dec_mask = build_padding_mask(y, pad_token=Tokens.PAD_NUM.value)
    return x, y, label, enc_mask, dec_mask

training_generator = DataLoader(dataset, collate_fn=collate, batch_size=5, num_workers=0)
for batch in training_generator:
    print(batch)
    break

[([6869, 0], [23380], [18733]), ([36981, 0], [23380, 64105], [64105, 64]), ([12978, 0], [23380], [65381]), ([35079, 13], [23380, 16233, 1088], [16233, 1088, 13]), ([10903, 0], [23380], [14549])]
(tensor([[ 6869,     0],
        [36981,     0],
        [12978,     0],
        [35079,    13],
        [10903,     0]]), tensor([[23380, 62854, 62854],
        [23380, 64105, 62854],
        [23380, 62854, 62854],
        [23380, 16233,  1088],
        [23380, 62854, 62854]]), tensor([[18733, 62854, 62854],
        [64105,    64, 62854],
        [65381, 62854, 62854],
        [16233,  1088,    13],
        [14549, 62854, 62854]]), tensor([[1, 1],
        [1, 1],
        [1, 1],
        [1, 1],
        [1, 1]]), tensor([[1, 0, 0],
        [1, 1, 0],
        [1, 0, 0],
        [1, 1, 1],
        [1, 0, 0]]))


In [None]:
# Training
import torch

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# torch.set_default_device("cpu")
print(device)

training_params = {
    'collate_fn': collate,
    'batch_size': 32,
    'shuffle': True,
    'num_workers': 0}
max_epochs = 5

# Generators
dataset = TokEnFrDataset("../data/eng_-french.csv", val_ratio=0.05)
dataloader = DataLoader(dataset, **training_params)


# Loop over epochs
for epoch in range(max_epochs):
    # Training
    for local_batch, local_labels in training_generator:
        # Transfer to GPU
        local_batch, local_labels = local_batch.to(device), local_labels.to(device)

        # Model computations
        [...]

    # Validation
    with torch.set_grad_enabled(False):
        for local_batch, local_labels in validation_generator:
            # Transfer to GPU
            local_batch, local_labels = local_batch.to(device), local_labels.to(device)

            # Model computations
            [...]

In [None]:
import tiktoken

encoding = tiktoken.get_encoding("cl100k_base")

def build_train_sample(en_str: str, dec_str: str):
    en_encoded = encoding.encode(en_str)
    dec_encoded = encoding.encode(dec_str)
    dec_encoded.append(Tokens.END_NUM.value)
    en_sents = []
    dec_sents = []
    target_sents = []
    
    for i in range(1, len(dec_encoded)):
        dec_sents.append(dec_encoded[:i])
        target_sents.append(dec_encoded[1: i + 1])
    en_sents.extend([en_encoded] * len(dec_sents))
    return list(zip(en_sents, dec_sents, target_sents))

build_train_sample('Hi.', 'START Salut!')

In [None]:
from torch.utils.data import DataLoader

# def collate_fn

dataset.partition = Partition.TRAIN
training_generator = DataLoader(dataset, batch_size=2, shuffle=True)
for i, s in enumerate(training_generator):
    print(s)
    if i > 2:
        break
        
    

In [None]:
import torch
print(torch.backends.mps.is_available())

In [None]:
# training

print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())


In [None]:
class Partition(Enum):
    TRAIN = "train"
    VAL = "val"

In [None]:
F.softmax(torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.3]]), dim=-1)
torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.4]]).argmax(dim=-1)
# torch.max(torch.tensor([[0.1,0.2,0.3],[0.1, 0.2, 0.4]]), dim=-1)

t1 = torch.tensor([[0.1, 0.2]])
t2 = torch.tensor([[0.3]])
torch.cat((t1, t2), dim=1).tolist()[-1][-1]


In [None]:
int(0.1)

In [None]:
assert False

In [None]:
x = torch.rand([1, 2, 3])
mask = torch.ones([1, 2])
mask[0, 1] = 0
mask = mask.unsqueeze(1)
print(mask == 0)
x.masked_fill(mask == 0, float("-inf"))