In [6]:
import stanza
from tqdm import tqdm

stanza.download('zh', processors='tokenize')
nlp = stanza.Pipeline('zh', processors='tokenize')

Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.8.0.json: 379kB [00:00, 50.1MB/s]                    
2024-04-07 16:09:27 INFO: Downloaded file to /Users/seeusim/stanza_resources/resources.json
2024-04-07 16:09:27 INFO: "zh" is an alias for "zh-hans"
2024-04-07 16:09:27 INFO: Downloading these customized packages for language: zh-hans (Simplified_Chinese)...
| Processor | Package |
-----------------------
| tokenize  | gsdsimp |

2024-04-07 16:09:27 INFO: File exists: /Users/seeusim/stanza_resources/zh-hans/tokenize/gsdsimp.pt
2024-04-07 16:09:27 INFO: Finished downloading models and saved to /Users/seeusim/stanza_resources
2024-04-07 16:09:27 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.8.0.json: 379kB [0

In [7]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import stanza

In [None]:
SRC_LANGUAGE = 'en'
TGT_LANGUAGE = 'zh'

DATA_DIR = '../data'

# Define special symbols and indices
UNK_IDX, BOS_IDX, EOS_IDX, PAD_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<s>', '</s>', '<pad>']


In [8]:
nlp = stanza.Pipeline(lang='zh', processors='tokenize')

2024-04-07 16:09:27 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES
Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.8.0.json: 379kB [00:00, 58.0MB/s]                    
2024-04-07 16:09:27 INFO: Downloaded file to /Users/seeusim/stanza_resources/resources.json
2024-04-07 16:09:27 INFO: "zh" is an alias for "zh-hans"
2024-04-07 16:09:27 INFO: Loading these models for language: zh-hans (Simplified_Chinese):
| Processor | Package |
-----------------------
| tokenize  | gsdsimp |

2024-04-07 16:09:27 INFO: Using device: cpu
2024-04-07 16:09:27 INFO: Loading: tokenize
2024-04-07 16:09:27 INFO: Done loading processors!


In [9]:
# Place-holders
token_transform = {}
vocab_transform = {}

def tokenise_chinese(sent):
    return [word.text for sentence in nlp(sent).sentences for word in sentence.words]

token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_lg')
token_transform[TGT_LANGUAGE] = get_tokenizer(tokenise_chinese)

In [12]:
from tqdm import tqdm


def yield_tokens(data_iter, language: str):
    for data_sample in tqdm(data_iter):
        yield token_transform[language](data_sample)


for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    path = f'{DATA_DIR}/iwslt2017-en-zh-train.{ln}'
    with open(path, 'r') as f:
        # Training data Iterator
        train_iter = f.readlines()
        # Create torchtext's Vocab object
        vocab_transform[ln] = build_vocab_from_iterator(
            yield_tokens(train_iter, ln),
            min_freq=1,
            specials=special_symbols,
            special_first=True
        )

# Set ``UNK_IDX`` as the default index. This index is returned when the token is not found.
# If not set, it throws ``RuntimeError`` when the queried token is not found in the Vocabulary.
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    vocab_transform[ln].set_default_index(UNK_IDX)

100%|██████████| 231266/231266 [00:04<00:00, 56006.36it/s]
100%|██████████| 231266/231266 [08:48<00:00, 437.82it/s]


In [13]:
import torch

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    torch.save(vocab_transform[ln], f'word-{ln}.vocab')

# Testing

## Encoding

In [26]:
en_vocab = torch.load('./word-en.vocab')

In [27]:
en_ids = en_vocab(token_transform['en'](
    'I want to go to space.'
))

en_ids

[13, 111, 8, 112, 8, 294, 6]

In [33]:
zh_vocab = torch.load('./word-zh.vocab')

zh_ids = zh_vocab(token_transform['zh'](
    '我想去欧洲。'
))

zh_ids

[7, 45, 80, 806, 6]

## Decoding

In [34]:
' '.join(en_vocab.lookup_tokens(en_ids))

'I want to go to space .'

In [35]:
''.join(zh_vocab.lookup_tokens(zh_ids))

'我想去欧洲。'