In [5]:
!pip install huggingface

Collecting huggingface
  Downloading huggingface-0.0.1-py3-none-any.whl.metadata (2.9 kB)
Downloading huggingface-0.0.1-py3-none-any.whl (2.5 kB)
Installing collected packages: huggingface
Successfully installed huggingface-0.0.1


In [6]:
from typing import Any
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from pathlib import Path

In [7]:
def get_sentences(dataset, lang):
    """Yield the ``lang`` side of every translation pair in ``dataset``.

    Each item is expected to be a mapping with a ``'translation'`` dict keyed by
    language code (the access pattern used below).

    Args:
        dataset: iterable of translation records.
        lang: language key to extract from each record.

    Yields:
        The sentence string for ``lang`` from each record, in iteration order.
    """
    # The loop variable must be the name read in the body (was `data` vs `item` -> NameError).
    for item in dataset:
        yield item['translation'][lang]

In [8]:
 def get_or_build_tokenizer(config, ds, lang):
     """Load the word-level tokenizer for ``lang``, training and saving it first if absent.

     Args:
         config: mapping whose 'tokenizer_file' entry is a format string; it is
             ``.format(lang)``-ed to get the tokenizer's path on disk.
         ds: iterable of translation records, consumed through ``get_sentences``.
         lang: language key selecting which side of each pair to train on.

     Returns:
         A ``tokenizers.Tokenizer`` loaded from, or freshly saved to, that path.
     """
     tokenizer_path = Path(config['tokenizer_file'].format(lang))
     if tokenizer_path.exists():
         # Was `Tokenzier` (typo), which raised NameError whenever a cached file existed.
         return Tokenizer.from_file(str(tokenizer_path))

     tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
     tokenizer.pre_tokenizer = Whitespace()
     trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'], min_frequency=2)
     # The trainer must be passed explicitly; otherwise the special tokens are never
     # added to the vocabulary and min_frequency is ignored.
     tokenizer.train_from_iterator(get_sentences(ds, lang), trainer=trainer)
     # Create the target directory so save() does not fail on a fresh checkout.
     tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
     tokenizer.save(str(tokenizer_path))
     return tokenizer

In [14]:
def get_dataset(config):
    """Load opus_books for the configured language pair and return train/val DataLoaders.

    Args:
        config: mapping read for 'lang_src', 'lang_tgt', 'seq_len' and 'batch_size'
            (plus 'tokenizer_file' via get_or_build_tokenizer). If it also contains
            'seed', the 90/10 train/validation split is made reproducible with it.

    Returns:
        Tuple ``(train_loader, val_loader, tokenizer_src, tokenizer_tgt)``.
    """
    src_lang = config['lang_src']
    tgt_lang = config['lang_tgt']
    raw = load_dataset('opus_books', f'{src_lang}-{tgt_lang}', split='train')

    tokenizer_src = get_or_build_tokenizer(config, raw, src_lang)
    tokenizer_tgt = get_or_build_tokenizer(config, raw, tgt_lang)

    # 90/10 split by count. int() is required: len() of a float raises TypeError.
    train_size = int(0.9 * len(raw))
    val_size = len(raw) - train_size
    split_kwargs = {}
    if 'seed' in config:
        split_kwargs['generator'] = torch.Generator().manual_seed(config['seed'])
    train_raw, val_raw = random_split(raw, [train_size, val_size], **split_kwargs)

    train_ds = BilingualDataset(train_raw, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])
    val_ds = BilingualDataset(val_raw, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, config['seq_len'])

    # Report the longest tokenized sentence on each side so config['seq_len'] can be
    # sized to fit them plus the special tokens BilingualDataset adds.
    max_src = 0
    max_tgt = 0
    for item in raw:
        src_ids = tokenizer_src.encode(item['translation'][src_lang]).ids
        # Target text must be measured with the target tokenizer (was tokenizer_src).
        tgt_ids = tokenizer_tgt.encode(item['translation'][tgt_lang]).ids
        max_src = max(max_src, len(src_ids))
        max_tgt = max(max_tgt, len(tgt_ids))

    print(f'Max length of source: {max_src}')
    print(f'Max length of target: {max_tgt}')

    train_loader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_loader, val_loader, tokenizer_src, tokenizer_tgt

In [15]:
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Construct the transformer through ``build_transformer`` (defined outside this view).

    The source and target vocabulary sizes are forwarded as given. ``config['seq_len']``
    is passed for both sequence-length positions, followed by ``config['d_model']``.

    NOTE(review): the positional order is assumed to be
    (src_vocab, tgt_vocab, src_seq_len, tgt_seq_len, d_model) — confirm against
    build_transformer's signature.
    """
    shared_seq_len = config['seq_len']
    d_model = config['d_model']
    return build_transformer(
        vocab_src_len,
        vocab_tgt_len,
        shared_seq_len,
        shared_seq_len,
        d_model,
    )

In [10]:
def causal_mask(size):
    """Return a boolean tensor of shape (1, size, size) that is True where j <= i.

    Entry ``[0, i, j]`` is True on and below the main diagonal and False strictly
    above it. This lets position ``i`` see only positions up to itself.
    """
    ones = torch.ones(1, size, size)
    # 1 strictly above the diagonal (the "future" positions), 0 elsewhere.
    future = torch.triu(ones, diagonal=1).int()
    return future.eq(0)

In [11]:
class BilingualDataset(Dataset):
    """Dataset turning raw translation pairs into fixed-length tensors for a seq2seq model.

    Each item of ``ds`` is expected to look like
    ``{'translation': {src_lang: str, tgt_lang: str}}`` (the access pattern used in
    ``__getitem__``). Every returned sequence has length ``seq_len``:

    * encoder_input: [SOS] src_tokens [EOS] [PAD]...
    * decoder_input: [SOS] tgt_tokens [PAD]...
    * label:         tgt_tokens [EOS] [PAD]...
    """

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len) -> None:
        super().__init__()
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        # Previously never stored, although __getitem__ relies on it for padding.
        self.seq_len = seq_len

        # Special-token ids are looked up in the source tokenizer and reused on the decoder side.
        # NOTE(review): assumes both tokenizers were trained with the same special_tokens
        # list in the same order, so the ids coincide — confirm if tokenizers change.
        self.sos_token = self._special_token(tokenizer_src, '[SOS]')
        self.eos_token = self._special_token(tokenizer_src, '[EOS]')
        self.pad_token = self._special_token(tokenizer_src, '[PAD]')

    @staticmethod
    def _special_token(tokenizer, token):
        """Return ``token``'s id as a 1-element int64 tensor.

        Raises:
            ValueError: if the tokenizer's vocabulary does not contain ``token``.
        """
        token_id = tokenizer.token_to_id(token)
        if token_id is None:
            raise ValueError(f'Tokenizer vocabulary has no {token} token; was it trained with special_tokens?')
        # torch.tensor (lowercase) accepts dtype; the torch.Tensor constructor does not.
        return torch.tensor([token_id], dtype=torch.int64)

    def _padding(self, count):
        """Return ``count`` copies of the [PAD] id as a 1-D int64 tensor (empty when count == 0)."""
        return self.pad_token.repeat(count)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, index: Any) -> Any:
        """Build encoder/decoder inputs, masks and labels for pair ``index``.

        Raises:
            ValueError: if either sentence does not fit in ``seq_len`` once its special tokens are added.
        """
        src_pair = self.ds[index]
        src_text = src_pair['translation'][self.src_lang]
        tgt_text = src_pair['translation'][self.tgt_lang]

        enc_tokens = self.tokenizer_src.encode(src_text).ids
        # Target text is *encoded* with the target tokenizer (was decode() on the source one).
        dec_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Encoder input carries both [SOS] and [EOS]; decoder input and label carry one each.
        enc_padding = self.seq_len - len(enc_tokens) - 2
        dec_padding = self.seq_len - len(dec_tokens) - 1

        if enc_padding < 0 or dec_padding < 0:
            raise ValueError('Too long')

        enc_input = torch.cat([
            self.sos_token,
            torch.tensor(enc_tokens, dtype=torch.int64),
            self.eos_token,
            self._padding(enc_padding),
        ])

        dec_token_tensor = torch.tensor(dec_tokens, dtype=torch.int64)
        dec_input = torch.cat([
            self.sos_token,
            dec_token_tensor,
            self._padding(dec_padding),
        ])

        # Label is the decoder input shifted left by one: same tokens, ending in [EOS].
        label = torch.cat([
            dec_token_tensor,
            self.eos_token,
            self._padding(dec_padding),
        ])

        assert enc_input.size(0) == self.seq_len
        assert dec_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": enc_input,
            "decoder_input": dec_input,
            # (1, 1, seq_len): 1 for real tokens, 0 for padding.
            "encoder_mask": (enc_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),
            # Padding mask combined with the causal mask; broadcasts to (1, seq_len, seq_len).
            "decoder_mask": (dec_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & causal_mask(dec_input.size(0)),
            "label": label,
            "src_text": src_text,
            "tgt_text": tgt_text
        }