# Related module docs
* https://github.com/huggingface/tokenizers

# Note
* Add ".to(device)" to the tensors when moving to the transformer

# Hyperparameters

In [1]:
# Ids of the <bos>, <eos> and <pad> special tokens. They follow the order in
# which the special tokens are handed to the BPE trainer later in the notebook.
bos_id, eos_id, pad_id = 0, 1, 2

# Number of GPUs the batches are prepared for.
gpu_num = 1

# Download dataset

In [2]:
# from datasets import load_dataset
# dataset = load_dataset("wmt14", 'de-en', split='train')

In [3]:
from datasets import load_from_disk
# Load the dataset from the local 'dataset' directory instead of downloading it
# again (the load_dataset call is kept commented out in the previous cell).
dataset = load_from_disk('dataset')

  from .autonotebook import tqdm as notebook_tqdm


# Write raw sentences into txt files

In [4]:
# with open("en.txt",'w') as f:
#     for i in range(len(dataset)):
#         f.write(dataset[i]['translation']['en']+'\n')
        
# with open("de.txt",'w') as f:
#     for i in range(len(dataset)):
#         f.write(dataset[i]['translation']['de']+'\n')

In [5]:
# dataset['translation']

In [6]:
# Peek at one English sentence; 4508784 is the last valid index (len(dataset) - 1).
dataset[4508784]['translation']['en']

'Somehow Zuma must find a way to honor his own generation’s commitment to racial justice and national liberation, while empowering the masses who daily suffer the sting of class differences and yearn for material gain.'

In [7]:
# Number of sentence pairs in the training split.
len(dataset)

4508785

# train a tokenizer

In [8]:
# ! pip install tokenizers

### model

In [9]:
from tokenizers import Tokenizer
from tokenizers.models import BPE

# Start from an empty (untrained) BPE model; its vocabulary and merges are
# learned by the trainer further down in the notebook.
tokenizer = Tokenizer(BPE())

### Normalization

### pre_tokenizer

In [10]:
from tokenizers.pre_tokenizers import ByteLevel

# Use the byte-level pre-tokenizer to split raw text before the BPE model sees it.
tokenizer.pre_tokenizer = ByteLevel()

### post_processor

In [11]:
from tokenizers.processors import TemplateProcessing

# Wrap every encoded sentence as "<bos> sentence <eos>".
# The ids come from the hyperparameter cell (bos_id / eos_id) rather than being
# repeated as literals, so the template cannot drift from the constants.
tokenizer.post_processor = TemplateProcessing(
    single="<bos> $A <eos>",
    special_tokens=[("<bos>", bos_id), ("<eos>", eos_id)],
)

### decoder

In [12]:
# Import under an alias: the plain name `ByteLevel` is already bound to the
# pre-tokenizer class of the same name, and re-importing it here would silently
# rebind that name to the decoder class for the rest of the session.
from tokenizers.decoders import ByteLevel as ByteLevelDecoder

# Decoder that turns byte-level tokens back into readable text.
tokenizer.decoder = ByteLevelDecoder()

# other setting

In [13]:
# Pad every sequence of an encoded batch with <pad>. The id comes from the
# hyperparameter cell so it stays in sync with the shifted-target construction.
tokenizer.enable_padding(pad_id=pad_id, pad_token='<pad>')
# tokenizer.enable_truncation(max_length=513)

# Train by raw file

In [14]:
from tokenizers.trainers import BpeTrainer

# Train on the raw text dumps "de.txt" / "en.txt" (produced by the
# "write raw sentences" cell, which must have been run at least once).
trainer = BpeTrainer(vocab_size=37000, show_progress=True, special_tokens=["<bos>", "<eos>", "<pad>"])
tokenizer.train(files=["de.txt", "en.txt"], trainer=trainer)

# The post-processor, the padding setup and the batch generator all rely on the
# special tokens having exactly these ids, so fail loudly if training assigned
# different ones instead of saving an inconsistent tokenizer.
for special_token, expected_id in (("<bos>", bos_id), ("<eos>", eos_id), ("<pad>", pad_id)):
    actual_id = tokenizer.token_to_id(special_token)
    assert actual_id == expected_id, (
        f"{special_token} has id {actual_id}, expected {expected_id}"
    )

tokenizer.save("tokenizer.json")






# train from memory

In [15]:
# def batch_iterator(batch_size=65536):
#     for i in range(0, len(dataset), batch_size):
#         yield dataset[i : i + batch_size]['translation']['en']
#         yield dataset[i : i + batch_size]['translation']['de']
        
        
# from tokenizers.trainers import BpeTrainer

# trainer = BpeTrainer(vocab_size=37000, show_progress=True, special_tokens=["<bos>", "<eos>", "<pad>"])
# tokenizer.train_from_iterator(batch_iterator(),trainer=trainer,length=len(dataset))

# tokenizer.save("tokenizer.json")

# use tokenizer

### add data to memory

In [16]:
# Materialise every (German, English) sentence pair in memory, in dataset order.
de_en_pairs = [
    (example['translation']['de'], example['translation']['en'])
    for example in dataset
]

In [17]:
# Order the pairs by combined character length (German + English), shortest
# first, so that neighbouring sentences need little padding when batched.
de_en_pairs.sort(key=lambda pair: len(pair[0]) + len(pair[1]))

In [18]:
# Split the (German, English) pairs back into two parallel lists.
en_sents = [english for _, english in de_en_pairs]
de_sents = [german for german, _ in de_en_pairs]

In [19]:
# en_sents[-1],de_sents[-1]

# Train DataLoader

### In the paper each batch has 25000 source tokens and 25000 target tokens; here 4096 source tokens and 4096 target tokens for 1 GPU
* To be safe, use 3125 for 1 GPU and gradient accumulation to update every 8 batches
* To be safe, use 1024 for 1 GPU with 24 batches

In [20]:
import torch
import numpy as np
from tokenizers import Tokenizer
# Reload the trained tokenizer from "tokenizer.json" (written by the training cell).
tokenizer = Tokenizer.from_file("tokenizer.json")
def _encode_pairs(en_batch, de_batch):
    """Tokenize two parallel sentence lists and pack them into batch tensors.

    Uses the module-level ``tokenizer`` and ``pad_id``. The tokenizer must pad
    all sequences of one ``encode_batch`` call to a common length; otherwise
    the ragged id lists cannot be converted to tensors.

    Returns:
        Tuple ``(en_ids, de_ids, target_en_ids, target_de_ids,
        en_padding_mask, de_padding_mask)``. The id/target tensors are
        transposed to (sequence, batch); the masks stay (batch, sequence) and
        are True where the tokenizer added padding.
    """
    def pack(encodings):
        ids = [enc.ids for enc in encodings]
        # Targets are the inputs shifted left by one position; the freed last
        # slot is refilled with <pad> so inputs and targets keep the same length.
        targets = [enc.ids[1:] + [pad_id] for enc in encodings]
        masks = [enc.attention_mask for enc in encodings]
        return (
            torch.LongTensor(ids).t().contiguous(),
            torch.LongTensor(targets).t().contiguous(),
            # attention_mask is 1 on real tokens; invert it so padding is True.
            torch.tensor(masks, dtype=torch.long).eq(0),
        )

    en_ids, target_en_ids, en_padding_mask = pack(tokenizer.encode_batch(en_batch))
    de_ids, target_de_ids, de_padding_mask = pack(tokenizer.encode_batch(de_batch))
    return en_ids, de_ids, target_en_ids, target_de_ids, en_padding_mask, de_padding_mask


def batch_generator(dataset,gpu_num=1,max_len=3125):
    """Yield tokenized batches whose combined sentence length fits a budget.

    Sentences are buffered in order. Whenever a multiple of ``gpu_num`` pairs
    has been consumed and the buffered length exceeds ``max_len * gpu_num``,
    everything except the newest ``gpu_num`` pairs is encoded and yielded; the
    newest pairs open the next batch.

    The budget is measured with ``len()`` on the raw sentences (characters for
    ``str`` inputs), summed over both languages -- not in tokens.

    Args:
        dataset: iterable of ``(german, english)`` sentence pairs.
        gpu_num: pairs are checked, and carried over, in groups of this size.
        max_len: length budget per GPU (see above).

    Yields:
        ``(en_ids, de_ids, target_en_ids, target_de_ids, en_padding_mask,
        de_padding_mask)`` as produced by ``_encode_pairs``. The final batch
        holds whatever is left in the buffer and may therefore contain a
        number of pairs that is not a multiple of ``gpu_num``.

    Raises:
        ValueError: if ``gpu_num`` is smaller than 1.
    """
    if gpu_num < 1:
        raise ValueError("gpu_num must be a positive integer")

    en_batch = []
    de_batch = []
    en_cnt = 0
    de_cnt = 0
    batch_size = 0  # running number of consumed pairs, only used for the modulo check

    for pairs in dataset:
        en_batch.append(pairs[1])
        de_batch.append(pairs[0])
        en_cnt += len(pairs[1])
        de_cnt += len(pairs[0])
        batch_size += 1

        # Only consider flushing on group boundaries.
        if batch_size % gpu_num != 0:
            continue
        if en_cnt + de_cnt <= max_len * gpu_num:
            continue
        # A single group that is over budget on its own cannot be split: keep
        # it buffered (instead of emitting an empty batch) and let the next
        # boundary flush it as a batch of its own.
        if len(en_batch) <= gpu_num:
            continue

        yield _encode_pairs(en_batch[:-gpu_num], de_batch[:-gpu_num])

        # Carry the newest group over and account for its length, so the next
        # batch's budget includes the sentences it already contains.
        en_batch = en_batch[-gpu_num:]
        de_batch = de_batch[-gpu_num:]
        en_cnt = sum(len(sentence) for sentence in en_batch)
        de_cnt = sum(len(sentence) for sentence in de_batch)

    # Emit the sentences still buffered after the input is exhausted.
    if en_batch:
        yield _encode_pairs(en_batch, de_batch)
    

In [21]:
#

### Check (the padding mask in PyTorch requires the positions of \<pad> to be True)
* The shape of the mask is (B,S): yes, it is batch-first, unlike the input, which is (S,B,E)

In [22]:

# Print training batches until one whose English side is at least 7 tokens long
# has been shown, then stop.
for batch in batch_generator(dataset=de_en_pairs,gpu_num=1):
    en_ids, de_ids, target_en_ids, target_de_ids, en_padding_mask, de_padding_mask = batch

    # English side: inputs, padding mask, shifted targets.
    for tensor in (en_ids, en_padding_mask, target_en_ids):
        print(tensor.shape)
        print(tensor)

    print("*"*30)

    # German side, same order.
    for tensor in (de_ids, de_padding_mask, target_de_ids):
        print(tensor.shape)
        print(tensor)

    # Transpose the id tensors so that each row holds one sentence, then decode
    # the last sentence of the batch without and with special tokens.
    en_sent = en_ids.t().tolist()
    de_sent = de_ids.t().tolist()
    print(tokenizer.decode(en_sent[-1]))
    print(tokenizer.decode(de_sent[-1]))
    print(tokenizer.decode(en_sent[-1], skip_special_tokens=False))
    print(tokenizer.decode(de_sent[-1], skip_special_tokens=False))

    if en_ids.shape[0] >= 7:
        break

torch.Size([5, 1115])
tensor([[   0,    0,    0,  ...,    0,    0,    0],
        [ 828,  828,  828,  ..., 2045,  418,  933],
        [   1,    1,    1,  ...,   16,   16,   16],
        [   2,    2,    2,  ...,   19,   23,   19],
        [   2,    2,    2,  ...,    1,    1,    1]])
torch.Size([1115, 5])
tensor([[False, False, False,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False,  True,  True],
        ...,
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])
torch.Size([5, 1115])
tensor([[ 828,  828,  828,  ..., 2045,  418,  933],
        [   1,    1,    1,  ...,   16,   16,   16],
        [   2,    2,    2,  ...,   19,   23,   19],
        [   2,    2,    2,  ...,    1,    1,    1],
        [   2,    2,    2,  ...,    2,    2,    2]])
******************************
torch.Size([5, 1115])
tensor([[   0,    0,    0,  ...,    0,    0,    0],
        [ 828,  828,  8

# Validation dataloader

In [23]:
from datasets import load_dataset

# Validation split of WMT14 de-en.
valid_dataset = load_dataset("wmt14", 'de-en', split='validation')

# (German, English) sentence pairs in dataset order.
valid_de_en_pairs = [
    (example['translation']['de'], example['translation']['en'])
    for example in valid_dataset
]

Found cached dataset wmt14 (/data2/zrs/.cache/huggingface/datasets/wmt14/de-en/1.0.0/2de185b074515e97618524d69f5e27ee7545dcbed4aa9bc1a4235710ffca33f4)


### check

In [24]:

# Print validation batches until one whose English side is at least 7 tokens
# long has been shown, then stop.
for batch in batch_generator(dataset=valid_de_en_pairs,gpu_num=1):
    en_ids, de_ids, target_en_ids, target_de_ids, en_padding_mask, de_padding_mask = batch

    # English side: inputs, padding mask, shifted targets.
    for tensor in (en_ids, en_padding_mask, target_en_ids):
        print(tensor.shape)
        print(tensor)

    print("*"*30)

    # German side, same order.
    for tensor in (de_ids, de_padding_mask, target_de_ids):
        print(tensor.shape)
        print(tensor)

    # Transpose the id tensors so that each row holds one sentence, then decode
    # the last sentence of the batch without and with special tokens.
    en_sent = en_ids.t().tolist()
    de_sent = de_ids.t().tolist()
    print(tokenizer.decode(en_sent[-1]))
    print(tokenizer.decode(de_sent[-1]))
    print(tokenizer.decode(en_sent[-1], skip_special_tokens=False))
    print(tokenizer.decode(de_sent[-1], skip_special_tokens=False))

    if en_ids.shape[0] >= 7:
        break

torch.Size([38, 12])
tensor([[    0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0],
        [  264,  5864,  2218,  7993,  3804,   368, 23837,   711,   791,  1089,
           401,  6729],
        [ 5864,   228,    14,    14,  4776,   438,   235,   278, 16021,   213,
           836,    14],
        [  228,  5395,   229,  5864,   278,  5409, 12283,   235, 14078,  2351,
         10165,   981],
        [ 3979, 14585, 12961,   228,  2011,    14,    14,   438, 28485,    14,
          7805,  7805],
        [  262,   762,   228, 23392,    28,   229,   229,  6831,  3246, 17776,
          1870,   617],
        [ 7203,  1413,  6718, 13655,   981,  2322,  4830,   344,   229, 36027,
         20431,  5889],
        [  229,   471, 16279,   914,   836,   491,  1111,   213,  4229, 34801,
           262,  4838],
        [  305,   229,   438,  5975,  6393, 23112,   389,  4810,  2686,   229,
          2902,  7216],
        [   15,   898,   213,  4490,   491, 20479,  

# other info

In [25]:
# Vocabulary size, followed by the ids of the three special tokens.
print(tokenizer.get_vocab_size())
for special_token in ('<bos>', '<eos>', '<pad>'):
    print(tokenizer.token_to_id(special_token))

37000
0
1
2
