This notebook will contain the final version of all code written to train the transformer.

In [1]:
import torch
import torch.nn.functional as F

import numpy as np

In [2]:
class MLP(torch.nn.Module):
    def __init__(self, de, in_d, debug=True):
        super().__init__()
        self.W1 = torch.nn.Linear(de, in_d)
        self.W2 = torch.nn.Linear(in_d, de)
        self.layernorm = torch.nn.LayerNorm(de)
        self.dropout = torch.nn.Dropout(p=0.1)
        self.debug = debug
        
    def forward(self, x):
        if self.debug:
            print('MLP')
            print('Before W')
            print('x', x.shape)
            
        x = x + self.W2(torch.nn.functional.relu(self.W1(x)))
        
        if self.debug:
            print('After W')
            print('x', x.shape)
            
        x = self.layernorm(x)
        x = self.dropout(x)
        return x

In [3]:
class MultiHeadedAttention(torch.nn.Module):
    def __init__(self, h, de, dq, dk, dv, debug=True):
        super().__init__()
        self.h, self.de, self.dq, self.dk, self.dv, self.debug = h, de, dq, dk, dv, debug
        self.W_q = torch.nn.Linear(de, dq*h)
        self.W_k = torch.nn.Linear(de, dk*h)
        self.W_v = torch.nn.Linear(de, dv*h)
        self.W_o = torch.nn.Linear(h*dv, de) 
        self.layernorm = torch.nn.LayerNorm(de)
        self.dropout = torch.nn.Dropout(p=0.1)
        
    def scaled_dot_prod_attention(self, q, k, v, masks=None):
        qkt = torch.einsum('bhqd, bhkd -> bhqk', q, k)
        scale = qkt/(q.shape[-1]**-2)
        
        if self.debug:
            print('qkt/scale', scale.shape)
        
        if masks is not None:
            if self.debug:
                print('masks', masks[:, None, None, :].shape)
            
            scale = scale.masked_fill(masks[:, None, None, :], -torch.inf)
        
        soft = torch.nn.Softmax(dim=-1)(scale)
        return torch.einsum('bhij, bhjk -> bhik', soft, v)
        
    def forward(self, x, encoder_output=None, masks=None):
        if self.debug:
            print('x', x.shape)
        
        if encoder_output is not None:
            # I don't know if the encoder_output should be linearly transformed
            # So I need to test this later on 
            q, k, v = self.W_q(x), self.W_k(encoder_output), self.W_v(encoder_output)
        else:
            q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        
        if self.debug:
            print('Before splitting')
            print('q', q.shape)
            print('k', k.shape)
            print('v', v.shape)
        
        # Split into different heads
        bs = q.shape[0]
        ln_q = q.shape[1]
        ln_k = k.shape[1]
        ln_v = v.shape[1]

        q = q.view(bs, ln_q, self.h, self.dq)
        k = k.view(bs, ln_k, self.h, self.dk)
        v = v.view(bs, ln_v, self.h, self.dv)
        
        if self.debug:
            print('After splitting')
            print('q', q.shape)
            print('k', k.shape)
            print('v', v.shape)
        
        # Transpose ln and h for scaled dot product attention
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        
        if self.debug:
            print('After Transpose')
            print('q', q.shape)
            print('k', k.shape)
            print('v', v.shape)
        
        sdpa = self.scaled_dot_prod_attention(q, k, v, masks)
        
        if self.debug:
            print('After scaled dot product')
            print('sdpa', sdpa.shape)
        
        # Linear projection of attention 
        sdpa = self.W_o(sdpa.view(bs, ln_q, -1))
        
        if self.debug:
            print('After linear projection')
            print('sdpa', sdpa.shape)
        
        # Add & Norm
        x = (sdpa + x).view(bs, ln_q, -1)
        x = self.layernorm(x)
        
        if self.debug:
            print('After add & norm')
            print('x', x.shape)
        
        x = self.dropout(x.contiguous().view(bs, ln_q, self.de))
        
        if self.debug:
            print('After concat')
            print('x', x.shape)

        return x

In [4]:
class EncoderBlock(torch.nn.Module):
    def __init__(self, in_d, h, de, dq, dk, dv, debug=True):
        super().__init__()
        self.mla = MultiHeadedAttention(h, de, dq, dk, dv, debug)
        self.mlp = MLP(de, in_d, debug)
        
    def forward(self, x):
        return self.mlp(self.mla(x))

In [5]:
class DecoderBlock(torch.nn.Module):
    def __init__(self, in_d, h, de, dq, dk, dv, debug=True):
        super().__init__()
        self.mla1 = MultiHeadedAttention(h, de, dq, dk, dv, debug)
        self.mla2 = MultiHeadedAttention(h, de, dq, dk, dv, debug)
        self.mlp = MLP(de, in_d, debug)
        
    def forward(self, x, encoder_output, masks):
        x = self.mla1(x, masks=masks)
        x = self.mla2(x, encoder_output=encoder_output)
        x = self.mlp(x)
        return x

In [6]:
class Transformer(torch.nn.Module):
    def __init__(self, n_vocab_inp, dmodel=512, dmiddle=2048, dq=64, dk=64, dv=64, h=8, n_encoders=6, n_decoders=6, debug=True):
        super().__init__()
        
        self.dmodel = dmodel
        self.input_embeddings = torch.nn.Embedding(num_embeddings=n_vocab_inp, embedding_dim=dmodel)
        self.encoder_blocks = torch.nn.ModuleList([EncoderBlock(dmiddle, h, dmodel, dq, dk, dv, debug) for _ in range(n_encoders)])
        self.decoder_blocks = torch.nn.ModuleList([DecoderBlock(dmiddle, h, dmodel, dq, dk, dv, debug) for _ in range(n_decoders)])
        
    def forward(self, x, y, masks):
        x = self.input_embeddings(x)
        x = self.add_positional_encodings(x)
        for encoder in self.encoder_blocks:
            x = encoder(x)
        
        
        y = self.input_embeddings(y)
        y = self.add_positional_encodings(y)
        for decoder in self.decoder_blocks:
            y = decoder(y, x, masks)
            
        final = self.final_linear(y)
        
        return final
    
    def add_positional_encodings(self, x):
        pos, i = torch.where(x[0])
        sin_enc = torch.sin(pos/(10000**(2*i/self.dmodel)))
        cos_enc = torch.cos(pos/(10000**(2*i/self.dmodel)))
        pos_enc = torch.where(i % 2 == 0, sin_enc, cos_enc).view(x.shape[1:])
        return x + pos_enc
    
    def final_linear(self, y):
        return torch.einsum('bmd, dn -> bmn', y, self.input_embeddings.weight.T)
    

In [7]:
from torch.utils.data import Dataset, DataLoader

class TransformerDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):            
        tokenized_input = tokenizer.encode(self.data[idx]['translation']['en'])
        tokenized_output = tokenizer.encode(self.data[idx]['translation']['ar'])

        encoder_input_ids = torch.tensor([tokenized_input.ids for _ in range(len(tokenized_output.ids)-1)])[:, 1:-1]
        
        tokenized_output_ids = torch.tensor([tokenized_output.ids for _ in range(len(tokenized_output.ids)-1)])
        
        decoder_input_ids = tokenized_output_ids[:, :-1]
        decoder_input_masks = self.get_decoder_masks(decoder_input_ids)
        
        
        target_ids = tokenized_output_ids[:, 1:]
        target_masks = self.get_target_masks(target_ids)
        
        return encoder_input_ids, decoder_input_ids, target_ids, decoder_input_masks, target_masks
        

    def get_decoder_masks(self, ids):
        masks = torch.ones_like(ids)
        return torch.triu(masks, diagonal=1)
    
    def get_target_masks(self, ids):
        masks = torch.ones_like(ids)
        return torch.tril(masks)

In [8]:
from datasets import load_dataset

In [9]:
dataset = load_dataset("news_commentary", "ar-en", split=["train"], cache_dir="data")

Reusing dataset news_commentary (data/news_commentary/ar-en/11.0.0/cfab724ce975dc2da51cdae45302389860badc88b74db8570d561ced6004f8b4)


  0%|          | 0/1 [00:00<?, ?it/s]

In [10]:
from tokenizers import Tokenizer
from tokenizers.models import BPE

tokenizer = Tokenizer(BPE(unk_token="[UNK]"))

In [11]:
from tokenizers import normalizers

tokenizer.normalizer = normalizers.Sequence(
    [
        normalizers.NFD(), 
        normalizers.Lowercase(), 
        normalizers.StripAccents()
    ]
)

In [12]:
from tokenizers.trainers import BpeTrainer

trainer = BpeTrainer(special_tokens=["[PAD]", "[SOS]", "[EOS]", "[UNK]"])

In [13]:
from tokenizers.pre_tokenizers import Whitespace

tokenizer.pre_tokenizer = Whitespace()

In [14]:
from itertools import chain

list(chain.from_iterable(map(lambda x: x.values(), dataset[0][0:10]['translation'])))

['الذهب بعشرة آلاف دولار؟',
 '$10,000 Gold?',
 'سان فرانسيسكو ـ لم يكن من السهل قط أن ينخرط المرء في محادثة عقلانية حول قيمة الذهب. ومؤخراً، مع ارتفاع أسعار الذهب بما يزيد على 300% في غضون الأعوام العشرة الماضية، فقد أصبح الأمر أصعب من أي وقت مضى. ففي شهر ديسمبر/كانون الأول الماضي، نشر كل من مارتن فيلدشتاين ونورييل روبيني ـ وهما من أبرز خبراء الاقتصاد ـ مقالاً تحريرياً، حيث شكك كل منهما في مقاله بشجاعة في مشاعر السوق الصاعدة، مشيراً بحكمة إلى المجازفات والمخاطر المحيطة بالذهب.',
 'Lately, with gold prices up more than 300% over the last decade, it is harder than ever. Just last December, fellow economists Martin Feldstein and Nouriel Roubini each penned op-eds bravely questioning bullish market sentiment, sensibly pointing out gold’s risks.',
 'ولكن من المؤسف أن أسعار الذهب واصلت ارتفاعها منذ نشر المقالين. حتى أن أسعار الذهب سجلت رقماً قياسياً بلغ 1300 دولار مؤخرا. وفي ديسمبر/كانون الأول الماضي كان العديد من أنصار الذهب يزعمون أن الأسعار تتجه حتماً نحو 2000 دولار. والآن، وبتشجيع من الا

In [15]:
batch_size = 1000
def batch_iterator():
    for i in range(len(dataset)):
        for j in range(0, len(dataset[i]), batch_size):
            yield list(
                chain.from_iterable(
                    map(lambda x: x.values(),
                        dataset[0][0:batch_size]['translation'])
                )
            )

In [16]:
# next(batch_iterator())

In [17]:
tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)






In [18]:
from tokenizers.processors import TemplateProcessing

In [19]:
tokenizer.post_processor = TemplateProcessing(
    single="[SOS] $A [EOS]",
    special_tokens=[
        ("[SOS]", tokenizer.token_to_id("[SOS]")),
        ("[EOS]", tokenizer.token_to_id("[EOS]")),
    ],
)

In [20]:
tokenizer.save("data/tokenizer-news_commentary.json")

In [21]:
tokenizer = tokenizer.from_file("data/tokenizer-news_commentary.json")

#### Add number of tokens per translation to each row

In [22]:
def count_tokens(row):
    row['translation']['count'] = len(tokenizer.encode(row['translation']['ar']).ids) + 1 # Add one because we use [SOS] and [EOS]
    return row

In [23]:
new_dataset = dataset[0].map(count_tokens)



  0%|          | 0/83187 [00:00<?, ?ex/s]

We can now get an array containing the number of tokens in the arabic translation per row, and we can use it to create a new id system.

In [24]:
counts = np.array(new_dataset.flatten()['translation.count'])
counts[:10]

array([  8,  92,  73,  66, 131,  72, 138,  62,  62,  60])

In [25]:
cumsum = np.cumsum(np.pad(counts, (1, 0), 'constant'))
cumsum[:10]

array([  0,   8, 100, 173, 239, 370, 442, 580, 642, 704])

In [26]:
idx_map = {}
for i, (x1, x2) in enumerate(zip(cumsum[:-1], cumsum[1:])):
    for n in np.arange(x1, x2):
        idx_map[n] = i

In [27]:
idx_map

{0: 0,
 1: 0,
 2: 0,
 3: 0,
 4: 0,
 5: 0,
 6: 0,
 7: 0,
 8: 1,
 9: 1,
 10: 1,
 11: 1,
 12: 1,
 13: 1,
 14: 1,
 15: 1,
 16: 1,
 17: 1,
 18: 1,
 19: 1,
 20: 1,
 21: 1,
 22: 1,
 23: 1,
 24: 1,
 25: 1,
 26: 1,
 27: 1,
 28: 1,
 29: 1,
 30: 1,
 31: 1,
 32: 1,
 33: 1,
 34: 1,
 35: 1,
 36: 1,
 37: 1,
 38: 1,
 39: 1,
 40: 1,
 41: 1,
 42: 1,
 43: 1,
 44: 1,
 45: 1,
 46: 1,
 47: 1,
 48: 1,
 49: 1,
 50: 1,
 51: 1,
 52: 1,
 53: 1,
 54: 1,
 55: 1,
 56: 1,
 57: 1,
 58: 1,
 59: 1,
 60: 1,
 61: 1,
 62: 1,
 63: 1,
 64: 1,
 65: 1,
 66: 1,
 67: 1,
 68: 1,
 69: 1,
 70: 1,
 71: 1,
 72: 1,
 73: 1,
 74: 1,
 75: 1,
 76: 1,
 77: 1,
 78: 1,
 79: 1,
 80: 1,
 81: 1,
 82: 1,
 83: 1,
 84: 1,
 85: 1,
 86: 1,
 87: 1,
 88: 1,
 89: 1,
 90: 1,
 91: 1,
 92: 1,
 93: 1,
 94: 1,
 95: 1,
 96: 1,
 97: 1,
 98: 1,
 99: 1,
 100: 2,
 101: 2,
 102: 2,
 103: 2,
 104: 2,
 105: 2,
 106: 2,
 107: 2,
 108: 2,
 109: 2,
 110: 2,
 111: 2,
 112: 2,
 113: 2,
 114: 2,
 115: 2,
 116: 2,
 117: 2,
 118: 2,
 119: 2,
 120: 2,
 121: 2,
 122: 2,
 12

Now we have a dict that maps the new id to the original dataset id, but we need a way to figure out which is the cutoff token to process in the dataset.

In [28]:
token_map = {}
for x1, x2 in zip(cumsum[:-1], cumsum[1:]):
    for i, n in enumerate(np.arange(x1, x2)):
        token_map[n] = i

In [29]:
token_map

{0: 0,
 1: 1,
 2: 2,
 3: 3,
 4: 4,
 5: 5,
 6: 6,
 7: 7,
 8: 0,
 9: 1,
 10: 2,
 11: 3,
 12: 4,
 13: 5,
 14: 6,
 15: 7,
 16: 8,
 17: 9,
 18: 10,
 19: 11,
 20: 12,
 21: 13,
 22: 14,
 23: 15,
 24: 16,
 25: 17,
 26: 18,
 27: 19,
 28: 20,
 29: 21,
 30: 22,
 31: 23,
 32: 24,
 33: 25,
 34: 26,
 35: 27,
 36: 28,
 37: 29,
 38: 30,
 39: 31,
 40: 32,
 41: 33,
 42: 34,
 43: 35,
 44: 36,
 45: 37,
 46: 38,
 47: 39,
 48: 40,
 49: 41,
 50: 42,
 51: 43,
 52: 44,
 53: 45,
 54: 46,
 55: 47,
 56: 48,
 57: 49,
 58: 50,
 59: 51,
 60: 52,
 61: 53,
 62: 54,
 63: 55,
 64: 56,
 65: 57,
 66: 58,
 67: 59,
 68: 60,
 69: 61,
 70: 62,
 71: 63,
 72: 64,
 73: 65,
 74: 66,
 75: 67,
 76: 68,
 77: 69,
 78: 70,
 79: 71,
 80: 72,
 81: 73,
 82: 74,
 83: 75,
 84: 76,
 85: 77,
 86: 78,
 87: 79,
 88: 80,
 89: 81,
 90: 82,
 91: 83,
 92: 84,
 93: 85,
 94: 86,
 95: 87,
 96: 88,
 97: 89,
 98: 90,
 99: 91,
 100: 0,
 101: 1,
 102: 2,
 103: 3,
 104: 4,
 105: 5,
 106: 6,
 107: 7,
 108: 8,
 109: 9,
 110: 10,
 111: 11,
 112: 12,
 113: 13

Now I have two dicts ready at my disposal to enable the pytorch dataset to process the sequences from the dataset on a certain token instead of generating the whole sequence data.

I'll try this pipeline outside of the class to make sure it can work.

In [30]:
idx = 8

In [31]:
new_idx = idx_map[idx]

In [32]:
tokenized_input = tokenizer.encode(new_dataset[new_idx]['translation']['en'])
tokenized_output = tokenizer.encode(new_dataset[new_idx]['translation']['ar'])

In [33]:
encoder_input_ids = torch.tensor([tokenized_input.ids[1:-1]])
encoder_input_ids

tensor([[24500,    14,   328,  1997,  1932,   613,   407,   545,  3796,     8,
           643,   154,  1542,  2632,    14,   175,   176, 10639,   545,  3361,
            16,   915,  1542,  6587,    14, 21073,  3477, 13587, 26576,   183,
         28628, 27685,  2659, 27087,   372,    15, 18718, 27588, 24816, 24408,
          1941, 27048,    14, 25632,  8795,   464,  1997,   143,    51,  4292,
            16]])

In [34]:
tokenized_output_ids = torch.tensor(tokenized_output.ids)
decoder_input_ids = tokenized_output_ids[:-1]
target_ids = tokenized_output_ids[1:]
decoder_input_ids, target_ids

(tensor([    1,  1621, 15566,    86,   378,  1906,   163,  9773,  2196,   151,
         18638,  3732,   160, 21807,  8931,  1736,  2329,  1552,    16, 15342,
            62,   281,  2380,  1148,  1552,   893,  3719,   189,  3796,     8,
           160,  5808,  2630,  6413,  3737,    62,   636,  1426,   629,  9950,
           163,   201,  1195,  5762,    16,  1259,  3589,  6629,    17,  3921,
           920,  1236,    62,  3223,   245,   163, 14133, 26173, 29258, 28797,
            86,  8671,   163,  6603,  2673,   426,    86, 10338, 25380,    62,
           738, 13855,   245, 14403,   160, 22567, 23057,   160,  4694,  2834,
         13002,    62, 23037, 13369,   210, 22386,  8402,  8808,  9191,    16]),
 tensor([ 1621, 15566,    86,   378,  1906,   163,  9773,  2196,   151, 18638,
          3732,   160, 21807,  8931,  1736,  2329,  1552,    16, 15342,    62,
           281,  2380,  1148,  1552,   893,  3719,   189,  3796,     8,   160,
          5808,  2630,  6413,  3737,    62,   636,

In [35]:
token = token_map[idx]

In [36]:
token

0

In [37]:
decoder_input_masks = torch.ones_like(decoder_input_ids)
decoder_input_masks[:token+1] = 0
decoder_input_masks

tensor([0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [38]:
random = torch.randn(decoder_input_masks.shape)
random.masked_fill(decoder_input_masks, 0)

tensor([1.8135, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])

In [39]:
target_masks = torch.zeros_like(target_ids)
target_masks[:token+1] = 1
target_masks

tensor([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [40]:
target_ids * target_masks

tensor([1621,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0])

So I have tested it with multiple idx and it seems to work pretty well. I'll modify the dataset class to use the new pipeline.

In [84]:
class TransformerDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer
        self.init_maps()

    def __len__(self):
        return len(self.idx_map.keys())

    def __getitem__(self, idx):     
        new_idx = self.idx_map[idx]
        token = self.token_map[idx]
        
        tokenized_input = tokenizer.encode(self.data[new_idx]['translation']['en'])
        tokenized_output = tokenizer.encode(self.data[new_idx]['translation']['ar'])

        encoder_input_ids = torch.tensor([tokenized_input.ids[1:-1]])
        
        tokenized_output_ids = torch.tensor([tokenized_output.ids])

        decoder_input_ids = tokenized_output_ids[:, :-1]
        decoder_input_masks = self.get_decoder_masks(decoder_input_ids, token)
        
        
        target_ids = tokenized_output_ids[:, 1:]
        target_masks = self.get_target_masks(target_ids, token)
        
        return encoder_input_ids, decoder_input_ids, target_ids, decoder_input_masks, target_masks
        

    def get_decoder_masks(self, ids, token):
        masks = torch.ones_like(ids)
        masks[:, :token+1] = 0
        return masks
    
    def get_target_masks(self, ids, token):
        masks = torch.zeros_like(ids)
        masks[:, :token+1] = 1
        return masks
    
    def count_tokens(self, row):
        row['translation']['count'] = len(self.tokenizer.encode(row['translation']['ar']).ids) + 1 # Add one because we use [SOS] and [EOS]
        return row
    
    def init_maps(self):
        counts = np.array(self.data.flatten()['translation.count'])
        cumsum = np.cumsum(np.pad(counts, (1, 0), 'constant'))
        
        self.idx_map = {}
        for i, (x1, x2) in enumerate(zip(cumsum[:-1], cumsum[1:])):
            for n in np.arange(x1, x2):
                self.idx_map[n] = i
                
        self.token_map = {}
        for x1, x2 in zip(cumsum[:-1], cumsum[1:]):
            for i, n in enumerate(np.arange(x1, x2)):
                self.token_map[n] = i
    
        

In [85]:
dataset = TransformerDataset(new_dataset, tokenizer)
len(dataset)

6891959

In [86]:
dataset[0]

(tensor([[   7, 2261,   14, 2427, 1997,   30]]),
 tensor([[    1,  1552, 20144,  1005,  1316,    64]]),
 tensor([[ 1552, 20144,  1005,  1316,    64,     2]]),
 tensor([[0, 1, 1, 1, 1, 1]]),
 tensor([[1, 0, 0, 0, 0, 0]]))

In [87]:
dataset[1]

(tensor([[   7, 2261,   14, 2427, 1997,   30]]),
 tensor([[    1,  1552, 20144,  1005,  1316,    64]]),
 tensor([[ 1552, 20144,  1005,  1316,    64,     2]]),
 tensor([[0, 0, 1, 1, 1, 1]]),
 tensor([[1, 1, 0, 0, 0, 0]]))

In [88]:
def pad_row(r, max_seq_len):
    "Pads each row from Dataset __getitem__ to maximum sequence length in the current batch"
    return [F.pad(r[i], (0, max_seq_len[i] - r[i].shape[1]), "constant", 0) for i in range(len(r))]

In [89]:
def collate_fn(batch):
    "Collates multiples tuples into one tensor by padding their sequence length and concatenating them"
    max_seq_len = [max([t.shape[1] for t in ts]) for ts in zip(*batch)]
    padded_batch = [pad_row(r, max_seq_len) for r in batch]
    return [torch.concat(ts, dim=0) for ts in zip(*padded_batch)]

In [90]:
batch = [dataset[i] for i in range(6, 9)]

In [91]:
[[t.shape for t in r] for r in batch]

[[torch.Size([1, 6]),
  torch.Size([1, 6]),
  torch.Size([1, 6]),
  torch.Size([1, 6]),
  torch.Size([1, 6])],
 [torch.Size([1, 6]),
  torch.Size([1, 6]),
  torch.Size([1, 6]),
  torch.Size([1, 6]),
  torch.Size([1, 6])],
 [torch.Size([1, 51]),
  torch.Size([1, 90]),
  torch.Size([1, 90]),
  torch.Size([1, 90]),
  torch.Size([1, 90])]]

In [92]:
[t.shape for t in collate_fn(batch)]

[torch.Size([3, 51]),
 torch.Size([3, 90]),
 torch.Size([3, 90]),
 torch.Size([3, 90]),
 torch.Size([3, 90])]

In [93]:
import gc
import wandb

def train_one_epoch(epoch_index, model, optimizer, loss_fn, training_loader, validation_loader, save_path):
    running_loss = 0.
    last_loss = 0.
    epoch_loss = 0.
    
    for i, data in enumerate(training_loader):
        encoder_input_ids, decoder_input_ids, target_ids, decoder_input_masks, target_masks = data
        
        optimizer.zero_grad()
        outputs = model(encoder_input_ids.to('cuda'), decoder_input_ids.to('cuda'), decoder_input_masks.to('cuda').bool())
        
        outputs = outputs.view(-1, tokenizer.get_vocab_size())

        loss = loss_fn(outputs[target_masks.flatten().to('cuda').bool()],
                       target_ids.flatten().to('cuda')[target_masks.flatten().to('cuda').bool()])
        wandb.log({"loss": loss.item()})

        loss.backward()

        optimizer.step()
        
        running_loss += loss.item()
        epoch_loss += loss.item()
        
        n = 1000
        
        if i % n == n - 1:    # print every n mini-batches
            print('  batch {} loss: {}'.format(i + 1, running_loss/n))
            
            wandb.log({"running_loss": running_loss/n})
            
            if running_loss < last_loss:
                torch.save({
                    'epoch': epoch_index,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': running_loss,
                }, save_path)
                
            last_loss = running_loss
            running_loss = 0.0
            
        
        del encoder_input_ids
        del decoder_input_ids
        del target_ids
        del decoder_input_masks
        del target_masks
        
        torch.cuda.empty_cache()
        gc.collect()
        
    wandb.log({"train_loss": epoch_loss/len(training_loader)})
    
    
    model.eval()
    
    with torch.no_grad():
        
        val_loss = 0.
    
        for i, data in enumerate(validation_loader):
            encoder_input_ids, decoder_input_ids, target_ids, decoder_input_masks, target_masks = data

            outputs = model(encoder_input_ids.to('cuda'), decoder_input_ids.to('cuda'), decoder_input_masks.to('cuda').bool())

            outputs = outputs.view(-1, tokenizer.get_vocab_size())

            loss = loss_fn(outputs[target_masks.flatten().to('cuda').bool()],
                           target_ids.flatten().to('cuda')[target_masks.flatten().to('cuda').bool()])
            
            val_loss += loss.item()
            
            del encoder_input_ids
            del decoder_input_ids
            del target_ids
            del decoder_input_masks
            del target_masks

            torch.cuda.empty_cache()
            gc.collect()

        wandb.log({"val_loss": val_loss/len(validation_loader)})

In [63]:
import wandb

In [57]:
wandb.init(project="attention-is-all-you-need", entity="ahmedsamirio")

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,▇█▇▅▅▄▄▄▃▃▂▃▂▃▃▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,8.14026


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [94]:
new_dataset.shuffle()

Dataset({
    features: ['id', 'translation'],
    num_rows: 83187
})

In [98]:
new_dataset.select([0, 1])

Dataset({
    features: ['id', 'translation'],
    num_rows: 2
})

In [102]:
train_len = int(len(new_dataset)*0.8)
train_ids = range(0, train_len)
valid_ids = range(train_len, len(new_dataset))

In [103]:
TrainDataset = TransformerDataset(new_dataset.select(train_ids), tokenizer)
ValidDataset = TransformerDataset(new_dataset.select(valid_ids), tokenizer)

In [104]:
transformer = Transformer(tokenizer.get_vocab_size(), debug=False).to('cuda')
optimizer = torch.optim.Adam(transformer.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss(reduction='mean')
training_loader = DataLoader(TrainDataset, batch_size=64, collate_fn=collate_fn, shuffle=False)
validation_loader = DataLoader(ValidDataset, batch_size=64, collate_fn=collate_fn, shuffle=False)

In [None]:
for i in range(100):
    train_one_epoch(i, transformer, optimizer, loss_fn, training_loader, validation_loader, 'model.pt')

  batch 1000 loss: 12.10929033613205
  batch 2000 loss: 8.59061136817932
  batch 3000 loss: 8.168880687713623
  batch 4000 loss: 7.952878517627716
  batch 5000 loss: 7.993534821033478
  batch 6000 loss: 8.048997934818267
  batch 7000 loss: 7.8977878718376155
  batch 8000 loss: 7.84813807964325
  batch 9000 loss: 7.89777515411377
  batch 10000 loss: 7.861690134525299
  batch 11000 loss: 7.831814856529236
  batch 12000 loss: 7.81308338546753
  batch 13000 loss: 7.821840600967407
  batch 14000 loss: 7.821160840988159
  batch 15000 loss: 7.759237428188324
  batch 16000 loss: 7.706367583751678
  batch 17000 loss: 7.715282855033874
  batch 18000 loss: 7.748741548538208
  batch 19000 loss: 7.849751022338867
  batch 20000 loss: 7.809839839935303
  batch 21000 loss: 7.815505774021148
  batch 22000 loss: 7.832446142673493
  batch 23000 loss: 7.793722000598907
  batch 24000 loss: 7.763119186401367
  batch 25000 loss: 7.776265001773834
  batch 26000 loss: 7.802763346195221
  batch 27000 loss: 7.79