In [None]:
!pip install stanza torch torchtext spacy[cuda]
!python -m spacy download en_core_web_lg

In [None]:
import stanza
import torch
from torchtext.data.utils import get_tokenizer

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fetch the Chinese ('zh') Stanza resources for the `tokenize` processor, then
# build a pipeline that runs only that processor.
stanza.download('zh', processors='tokenize')
nlp = stanza.Pipeline(lang='zh', processors='tokenize')

# Setting up the Vocab and Transforms

In [None]:
# Language codes: 'en' is the source side, 'zh' the target side.
SRC_LANGUAGE, TGT_LANGUAGE = 'en', 'zh'

# Length that every encoded sequence is truncated / padded to.
MAX_SEQ_LEN = 288

# Special symbols. Keep this list in index order so the symbols are inserted
# into the vocab at the positions named by the *_IDX constants below.
special_symbols = ['<unk>', '<s>', '</s>', '<pad>']
# Indices of the special symbols: 0, 1, 2, 3 respectively.
UNK_IDX, BOS_IDX, EOS_IDX, PAD_IDX = range(len(special_symbols))

In [None]:
# Empty per-language registries, keyed by language code; the tokenizers and
# vocabularies are stored into them further down.
token_transform, vocab_transform = dict(), dict()

def tokenise_chinese(sent):
    """Tokenise ``sent`` with the module-level Stanza pipeline ``nlp``.

    Args:
        sent: Raw text to tokenise.

    Returns:
        A flat list with the ``text`` of every word, taken sentence by
        sentence in the order the pipeline produced them.
    """
    doc = nlp(sent)
    tokens = []
    for sentence in doc.sentences:
        # Flatten: sentence boundaries are not kept in the result.
        tokens.extend(word.text for word in sentence.words)
    return tokens

# Tokenizers: the spaCy 'en_core_web_lg' model for the source language and the
# Stanza-backed `tokenise_chinese` helper for the target language.
token_transform.update({
    SRC_LANGUAGE: get_tokenizer('spacy', language='en_core_web_lg'),
    TGT_LANGUAGE: get_tokenizer(tokenise_chinese),
})

# Vocabularies are read from ./word-<lang>.vocab, one file per language.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load vocab files that come from a trusted source.
for lang in (SRC_LANGUAGE, TGT_LANGUAGE):
    vocab_path = f'./word-{lang}.vocab'
    vocab_transform[lang] = torch.load(vocab_path)

# Helper that chains several single-argument transforms into one callable.
def sequential_transforms(*transforms):
    """Build a function that applies ``transforms`` one after another.

    Args:
        *transforms: Callables taking one argument; each one receives the
            return value of the previous one.

    Returns:
        A function ``f(txt_input)`` that feeds ``txt_input`` through every
        transform in the order given and returns the final value. With no
        transforms it returns its argument unchanged.
    """
    def apply_pipeline(txt_input):
        value = txt_input
        for step in transforms:
            value = step(value)
        return value
    return apply_pipeline

def tensor_transform_src(token_ids: list[int]):
    """Turn source token ids into a fixed-length index tensor on ``DEVICE``.

    The ids are truncated to ``MAX_SEQ_LEN`` and right-padded with ``PAD_IDX``
    so the result always has exactly ``MAX_SEQ_LEN`` elements. No BOS/EOS
    marker is added on the source side.

    Args:
        token_ids: Vocabulary indices of one tokenised source sentence.

    Returns:
        A 1-D ``torch.long`` tensor of length ``MAX_SEQ_LEN`` on ``DEVICE``.
    """
    ids = list(token_ids[:MAX_SEQ_LEN])
    ids.extend([PAD_IDX] * (MAX_SEQ_LEN - len(ids)))
    # Build one tensor with an explicit dtype. torch.tensor([]) defaults to
    # float32, so concatenating a separately built empty padding tensor (or an
    # empty sentence) would promote the whole index tensor to float.
    return torch.tensor(ids, dtype=torch.long, device=DEVICE)

def tensor_transform_tgt_inp(token_ids: list[int]):
    """Turn target token ids into the BOS-prefixed, fixed-length input tensor.

    Layout: ``[BOS_IDX, t_0, t_1, ..., PAD_IDX, ...]``. At most
    ``MAX_SEQ_LEN - 1`` token ids are kept so that, together with the leading
    ``BOS_IDX``, the result has exactly ``MAX_SEQ_LEN`` elements.

    Args:
        token_ids: Vocabulary indices of one tokenised target sentence.

    Returns:
        A 1-D ``torch.long`` tensor of length ``MAX_SEQ_LEN`` on ``DEVICE``.
    """
    ids = [BOS_IDX, *token_ids[:MAX_SEQ_LEN - 1]]
    ids.extend([PAD_IDX] * (MAX_SEQ_LEN - len(ids)))
    # Build one tensor with an explicit dtype. torch.tensor([]) defaults to
    # float32, so concatenating a separately built empty padding tensor (or an
    # empty sentence) would promote the whole index tensor to float.
    return torch.tensor(ids, dtype=torch.long, device=DEVICE)

def tensor_transform_tgt_out(token_ids: list[int]):
    """Turn target token ids into the EOS-terminated, fixed-length output tensor.

    Layout: ``[t_0, t_1, ..., EOS_IDX, PAD_IDX, ...]``. At most
    ``MAX_SEQ_LEN - 1`` token ids are kept so that, together with the trailing
    ``EOS_IDX``, the result has exactly ``MAX_SEQ_LEN`` elements.

    Args:
        token_ids: Vocabulary indices of one tokenised target sentence.

    Returns:
        A 1-D ``torch.long`` tensor of length ``MAX_SEQ_LEN`` on ``DEVICE``.
    """
    ids = [*token_ids[:MAX_SEQ_LEN - 1], EOS_IDX]
    ids.extend([PAD_IDX] * (MAX_SEQ_LEN - len(ids)))
    # Build one tensor with an explicit dtype. torch.tensor([]) defaults to
    # float32, so concatenating a separately built empty padding tensor (or an
    # empty sentence) would promote the whole index tensor to float.
    return torch.tensor(ids, dtype=torch.long, device=DEVICE)


# Per-language text pipelines that turn a raw string into a tensor of vocab indices.
text_transform = {
    SRC_LANGUAGE: sequential_transforms(
        token_transform[SRC_LANGUAGE],  # split the sentence into tokens
        vocab_transform[SRC_LANGUAGE],  # map tokens to vocabulary indices
        tensor_transform_src,           # truncate / pad to MAX_SEQ_LEN and build the tensor (no BOS/EOS)
    ),
}

# The target side gets two encodings of the same sentence, stored as a pair:
#   [0] the input form  (BOS-prefixed, via tensor_transform_tgt_inp)
#   [1] the output form (EOS-terminated, via tensor_transform_tgt_out)
# Both must use the *target* tokenizer and vocabulary; using the source ones
# would tokenise Chinese text with the English tokenizer and look the tokens
# up in the English vocab.
text_transform[TGT_LANGUAGE] = (
    sequential_transforms(
        token_transform[TGT_LANGUAGE],  # split the sentence into tokens
        vocab_transform[TGT_LANGUAGE],  # map tokens to vocabulary indices
        tensor_transform_tgt_inp,       # prepend BOS, truncate / pad, build the tensor
    ),
    sequential_transforms(
        token_transform[TGT_LANGUAGE],  # split the sentence into tokens
        vocab_transform[TGT_LANGUAGE],  # map tokens to vocabulary indices
        tensor_transform_tgt_out,       # append EOS, truncate / pad, build the tensor
    ),
)


# Collate raw (source, target) string pairs into per-field lists of encoded samples.
def collate_fn(batch):
    """Encode every sample of ``batch`` with the ``text_transform`` pipelines.

    Args:
        batch: Iterable of ``(src_sample, tgt_sample)`` string pairs. Leading
            and trailing whitespace is stripped from each string before it is
            encoded.

    Returns:
        A tuple ``(src_batch, tgt_inp_batch, tgt_out_batch)`` of three lists,
        each holding one encoded item per sample in batch order. The items
        are returned as lists; they are not stacked into a single tensor.
    """
    encoded = [
        (
            text_transform[SRC_LANGUAGE](src_sample.strip()),
            text_transform[TGT_LANGUAGE][0](tgt_sample.strip()),
            text_transform[TGT_LANGUAGE][1](tgt_sample.strip()),
        )
        for src_sample, tgt_sample in batch
    ]
    src_batch = [row[0] for row in encoded]
    tgt_inp_batch = [row[1] for row in encoded]
    tgt_out_batch = [row[2] for row in encoded]
    return src_batch, tgt_inp_batch, tgt_out_batch

# Example: loading the data in batches for training

In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 80
# Path template for the corpus files; `file` is the split name, `lang` the language code.
DATA_DIR = '../data/iwslt2017-en-zh-{file}.{lang}'

en_train = DATA_DIR.format(file='train', lang='en')
zh_train = DATA_DIR.format(file='train', lang='zh')
en_validation = DATA_DIR.format(file='validation', lang='en')
zh_validation = DATA_DIR.format(file='validation', lang='zh')


def load_parallel_corpus(src_path, tgt_path):
    """Read two line-aligned text files and pair them up line by line.

    Args:
        src_path: Path of the source-language file, one sentence per line.
        tgt_path: Path of the target-language file, one sentence per line.

    Returns:
        A list of ``(src_line, tgt_line)`` string pairs. Lines keep their
        trailing newline.

    Raises:
        ValueError: If the two files do not have the same number of lines,
            since the pairs would then be misaligned.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform default,
    # which cannot be trusted for Chinese text.
    # NOTE(review): assumes the corpus files are UTF-8 encoded -- confirm.
    with open(src_path, 'r', encoding='utf-8') as src_file, \
            open(tgt_path, 'r', encoding='utf-8') as tgt_file:
        src_lines = src_file.readlines()
        tgt_lines = tgt_file.readlines()

    if len(src_lines) != len(tgt_lines):
        raise ValueError(
            f'Parallel files are misaligned: {src_path} has {len(src_lines)} lines, '
            f'{tgt_path} has {len(tgt_lines)} lines'
        )

    # Materialise the pairs: DataLoader treats a non-IterableDataset object as a
    # map-style dataset and needs len() and integer indexing, which a bare
    # `zip` object does not support.
    return list(zip(src_lines, tgt_lines))


train_iter = load_parallel_corpus(en_train, zh_train)
val_iter = load_parallel_corpus(en_validation, zh_validation)

train_loader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn, pin_memory=False)
val_loader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn, pin_memory=False)