In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import jieba
import os
import pickle
from tqdm.notebook import tqdm
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
from pathlib import Path

In [None]:
# Load the WMT19 Chinese-English translation dataset.
# Offline switch: comment the next line out to let Hugging Face download the data.
# NOTE(review): `datasets` has already been imported by the time this runs, so the
# library may not re-read the variable -- TODO confirm the flag is actually honoured.
os.environ.update({'HF_DATASETS_OFFLINE': '1'})
dataset = load_dataset('wmt19', 'zh-en')
print(dataset)

# Language codes; they double as keys into each sample's 'translation' mapping.
SRC_LANGUAGE, TGT_LANGUAGE = 'zh', 'en'

# Number of leading training rows used downstream (vocabulary and training subset).
SUBSET_SIZE = 30000

# `min_freq` handed to `build_vocab_from_iterator` when the vocabulary is built.
VOCAB_MIN_FREQ = 10

In [None]:
def get_token_transform_en():
    """Return the tokenizer callable used for English text.

    The callable comes from torchtext's ``get_tokenizer`` with the
    ``'basic_english'`` preset.
    """
    # Function-scope import keeps this helper self-contained.
    from torchtext.data.utils import get_tokenizer as make_tokenizer

    english_tokenizer = make_tokenizer('basic_english')
    return english_tokenizer

def get_token_transform_zh():
    """Return the tokenizer callable used for Chinese text.

    The returned function segments ``text`` with ``jieba.lcut`` and drops every
    token that consists solely of whitespace. Besides plain spaces and tabs this
    also removes newlines and full-width spaces (``'\\u3000'``), which would
    otherwise end up as tokens in the vocabulary.

    Returns:
        A callable ``text -> list[str]``.
    """
    import jieba

    def tokenize(text):
        # `str.strip()` is empty (falsy) for whitespace-only tokens, including
        # Unicode whitespace such as the full-width space.
        return [token for token in jieba.lcut(text) if token.strip()]

    return tokenize

# Per-language tokenizer callables, keyed by language code.
token_transform = {
    TGT_LANGUAGE: get_token_transform_en(),
    SRC_LANGUAGE: get_token_transform_zh(),
}

In [None]:
def yield_tokens(data_iter, language):
    """Lazily tokenize one language side of every sample.

    Args:
        data_iter: Iterable of samples, each indexable by ``language``.
        language: Key used both to pull the text out of a sample and to select
            the tokenizer from ``token_transform``.

    Yields:
        The result of applying the selected tokenizer to each sample's text.
    """
    yield from (token_transform[language](sample[language]) for sample in data_iter)

# Indices of the special tokens. They equal the positions in SPECIAL_SYMBOLS because
# the vocabularies below are built with `specials=SPECIAL_SYMBOLS, special_first=True`.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
SPECIAL_SYMBOLS = ['<UNK>', '<PAD>', '<BOS>', '<EOS>']

# Vocabulary cache. NOTE: the cache is keyed only by this path, so it is NOT rebuilt
# when SUBSET_SIZE, VOCAB_MIN_FREQ or the tokenizers change -- delete the file to refresh.
VOCAB_PATH = './Data/Vocab.pkl'
if Path(VOCAB_PATH).exists():
    # SECURITY: pickle.load can execute arbitrary code; only load a cache file
    # that was produced locally by this notebook.
    with open(VOCAB_PATH, 'rb') as f:
        vocab_transform = pickle.load(f)
else:
    # Materialise the training slice once; it is the same for both languages.
    train_pairs = dataset['train'][:SUBSET_SIZE]['translation']
    vocab_transform = {}
    for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
        vocab_transform[ln] = build_vocab_from_iterator(
            yield_tokens(iter(train_pairs), ln),
            min_freq=VOCAB_MIN_FREQ,
            specials=SPECIAL_SYMBOLS,
            special_first=True,
        )
        # Tokens missing from the vocabulary map to <UNK> instead of raising.
        vocab_transform[ln].set_default_index(UNK_IDX)
    # Create the cache directory first so the dump cannot fail with
    # FileNotFoundError after the (expensive) vocabulary build.
    Path(VOCAB_PATH).parent.mkdir(parents=True, exist_ok=True)
    with open(VOCAB_PATH, 'wb') as f:
        pickle.dump(vocab_transform, f)

SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
print(f'Vocab-{TGT_LANGUAGE} Size: {TGT_VOCAB_SIZE}')
print(f'Vocab-{SRC_LANGUAGE} Size: {SRC_VOCAB_SIZE}')
# print(vocab_transform['en'](['it', 'was', 'later', 'realized', 'that', 'the', 'signal', 'they', 'had', 'detected', 'could', 'be', 'entirely', 'attributed', 'to', 'interstellar', 'dust', '.']))
# print(vocab_transform['zh'](['但', '后来', '他们', '逐渐', '意识', '到', '所', '探测', '到', '的', '信号', '可能', '完全', '来源于', '星际', '尘埃', '。']))

In [None]:
def sequential_transforms(*transforms):
    """Compose ``transforms`` into a single callable applied left to right.

    Args:
        *transforms: One-argument callables; each receives the previous one's output.

    Returns:
        A function that feeds its argument through every transform in order and
        returns the final value (the argument itself when no transforms are given).
    """
    def apply_pipeline(txt_input):
        current = txt_input
        for step in transforms:
            current = step(current)
        return current

    return apply_pipeline

def tensor_transform(token_ids, bos_idx=None, eos_idx=None):
    """Wrap a sequence of token ids with BOS/EOS markers as a 1-D ``torch.long`` tensor.

    The dtype is forced to ``torch.long``: an empty ``token_ids`` list would
    otherwise be inferred as a floating-point tensor, and the concatenated
    result could then stop being an integer index tensor.

    Args:
        token_ids: Sequence of integer token ids (may be empty).
        bos_idx: Id prepended to the sequence; defaults to the module-level ``BOS_IDX``.
        eos_idx: Id appended to the sequence; defaults to the module-level ``EOS_IDX``.

    Returns:
        Tensor laid out as ``[bos, *token_ids, eos]`` with dtype ``torch.long``.
    """
    # Globals are only consulted when no explicit override is supplied.
    bos = BOS_IDX if bos_idx is None else bos_idx
    eos = EOS_IDX if eos_idx is None else eos_idx
    return torch.cat((torch.tensor([bos], dtype=torch.long),
                      torch.as_tensor(token_ids, dtype=torch.long),
                      torch.tensor([eos], dtype=torch.long)))

# Text -> tensor pipelines, one per language: tokenize, map tokens to vocabulary
# indices, then add BOS/EOS and convert to a tensor.
tokenizer = dict()

for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    tokenizer[ln] = sequential_transforms(
        token_transform[ln],
        vocab_transform[ln],
        tensor_transform,
    )

def collate_fn(batch):
    """Turn a list of (source, target) text pairs into two padded tensors.

    Each text has trailing newlines stripped, is encoded with the matching entry
    of ``tokenizer``, and the encoded sequences are padded with ``PAD_IDX`` via
    ``pad_sequence`` (called with its default layout).

    Args:
        batch: Iterable of ``(src_text, tgt_text)`` pairs.

    Returns:
        ``(src_batch, tgt_batch)`` padded tensors.
    """
    encoded_pairs = [
        (
            tokenizer[SRC_LANGUAGE](src_text.rstrip("\n")),
            tokenizer[TGT_LANGUAGE](tgt_text.rstrip("\n")),
        )
        for src_text, tgt_text in batch
    ]
    src_padded = pad_sequence([pair[0] for pair in encoded_pairs], padding_value=PAD_IDX)
    tgt_padded = pad_sequence([pair[1] for pair in encoded_pairs], padding_value=PAD_IDX)
    return src_padded, tgt_padded

# Smoke test: encode one example sentence per language and show the resulting tensors.
for lang_code, sample_text in (
    (TGT_LANGUAGE, "It was later realized that the signal they had detected could be entirely attributed to interstellar dust."),
    (SRC_LANGUAGE, "但后来他们逐渐意识到所探测到的信号可能完全来源于星际尘埃。"),
):
    print(tokenizer[lang_code](sample_text))

In [None]:
class WMT19Dataset(Dataset):
    """Map-style dataset yielding ``(source_text, target_text)`` pairs.

    Wraps an indexable collection whose rows look like
    ``{'translation': {<language code>: <text>, ...}}`` and optionally exposes
    only its first ``subset_size`` rows.

    Args:
        dataset: Indexable, sized collection of rows in the layout above.
        subset_size: Maximum number of leading rows to expose; ``None`` exposes
            every row.
    """

    def __init__(self, dataset, subset_size = None):
        self.dataset = dataset
        self.subset_size = subset_size

    def __len__(self):
        """Return the number of exposed rows, never more than the wrapped dataset holds."""
        total = len(self.dataset)
        if self.subset_size is None:
            return total
        # Clamp so an oversized (or negative) subset_size cannot make a loader
        # request rows that do not exist.
        return max(0, min(self.subset_size, total))

    def __getitem__(self, idx):
        """Return the ``(source, target)`` texts of row ``idx``.

        Negative indices count from the end of the exposed subset.

        Raises:
            IndexError: If ``idx`` lies outside the exposed subset.
        """
        length = len(self)
        if not -length <= idx < length:
            raise IndexError(f'index {idx} is out of range for dataset of size {length}')
        if idx < 0:
            idx += length
        # Fetch the row once instead of once per language.
        pair = self.dataset[idx]['translation']
        return pair[SRC_LANGUAGE], pair[TGT_LANGUAGE]

# Wrap the raw splits: training is limited to SUBSET_SIZE rows, validation is used in full.
train_dataset = WMT19Dataset(dataset['train'], SUBSET_SIZE)
valid_dataset = WMT19Dataset(dataset['validation'])

print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(valid_dataset)}')

BATCH_SIZE = 64
# Both loaders share the same batching and collation settings.
# NOTE(review): the training loader is not shuffled -- confirm this is intended.
loader_options = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn)
train_dataloader = DataLoader(train_dataset, **loader_options)
valid_dataloader = DataLoader(valid_dataset, **loader_options)