In [1]:
import os
import sys
import re
import tarfile
import time
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from typing import Optional

In [5]:
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import Embedding, TransformerDecoderLayer, TransformerDecoder, TransformerEncoderLayer, TransformerEncoder
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
# from torchtext.data import Field, BucketIterator, TabularDataset
from torch.nn.init import xavier_uniform_

In [16]:
# Compute device name as a plain string: use the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [63]:
# Extract the raw parallel corpus and load it as a DataFrame of (English, Japanese) pairs.
# Each line of the extracted file is "english<TAB>japanese".
path = './data/raw.tar.gz'
savedir = './data/'
N_ROWS = 10000  # keep only the first N_ROWS pairs for faster experimentation

with tarfile.open(path, 'r:*') as tar:
    if hasattr(tarfile, 'data_filter'):
        # Python versions with extraction filters: reject absolute paths, `..` traversal,
        # links pointing outside savedir, special files, etc.
        tar.extractall(savedir, filter='data')
    else:
        tar.extractall(savedir)

filename = 'raw'
# Explicit UTF-8: the file holds Japanese text, and open()'s default encoding is
# platform/locale dependent (e.g. cp932 or cp1252 on Windows).
with open(os.path.join(savedir, 'raw', filename), 'r', encoding='utf-8') as f:
    raw_data = f.readlines()

# Strip the trailing newline and skip blank lines, which would otherwise become rows with a None column.
# Split on the first tab only, so a stray tab on the Japanese side cannot create a third
# column (which makes the DataFrame constructor raise).
raw_list = [line.rstrip('\n').split('\t', 1) for line in raw_data if line.strip()]
raw_df = pd.DataFrame(raw_list, columns=['英語', '日本語'])
raw_df = raw_df.iloc[:N_ROWS]

In [64]:
# Preview the loaded (English, Japanese) sentence pairs; long frames are displayed truncated to head/tail rows.
raw_df

Unnamed: 0,英語,日本語
0,"you are back, aren't you, harold?",あなたは戻ったのね ハロルド?
1,my opponent is shark.,俺の相手は シャークだ。
2,this is one thing in exchange for another.,引き換えだ ある事とある物の
3,"yeah, i'm fine.",もういいよ ごちそうさま ううん
4,don't come to the office anymore. don't call m...,もう会社には来ないでくれ 電話もするな
...,...,...
9995,"so, i'm going to start asking those questions",次にお話しするdna構造で この問題に 取り組んでみたいと思います
9996,ak: it's about a boy who falls in love with a ...,少年とその愛馬の物語です
9997,got to take the perks where i can get 'em.,特権は利用しないと
9998,i do not fall.,大事な仲間に そんなことはできない。


In [65]:
import en_core_web_sm
import ja_core_news_sm
# Load the small spaCy pipelines; only their tokenizers are used below.
nlp_en = en_core_web_sm.load()
nlp_ja = ja_core_news_sm.load()


def _spacy_token_texts(nlp, sentence):
    """Return the text of every token produced by ``nlp.tokenizer`` for ``sentence``.

    Only the pipeline's tokenizer runs; the other pipeline components are not applied.
    """
    tokens = nlp.tokenizer(sentence)
    return [token.text for token in tokens]


def tokenizer_ja(sentence):
    """Split a Japanese sentence into token strings using the ja_core_news_sm tokenizer."""
    return _spacy_token_texts(nlp_ja, sentence)


def tokenizer_en(sentence):
    """Split an English sentence into token strings using the en_core_web_sm tokenizer."""
    return _spacy_token_texts(nlp_en, sentence)

In [66]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List


# The raw_df column names double as the language keys used by every per-language dict below.
# Translation direction: Japanese (source) -> English (target).
SRC_LANGUAGE = '日本語'
TGT_LANGUAGE = '英語'

In [70]:
# Per-language lookup tables, keyed by SRC_LANGUAGE / TGT_LANGUAGE.
#   vocab_transform: filled in later with a torchtext Vocab per language.
#   token_transform: raw sentence -> list of token strings (spaCy tokenizer via torchtext).
#
# Dependencies:
#   pip install -U torchdata spacy
#   python -m spacy download ja_core_news_sm
#   python -m spacy download en_core_web_sm
vocab_transform = {}
token_transform = {
    SRC_LANGUAGE: get_tokenizer('spacy', language='ja_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
}


# helper function that lazily yields the token list of every sentence in one language column
def yield_tokens(data_iter, language: str, tokenizer=None) -> Iterable[List[str]]:
    """Tokenize each sentence of one language, one sentence at a time.

    This is a generator, so the return value is an iterable of token lists rather than a
    single ``List[str]``.

    Args:
        data_iter: Object indexable by ``language`` whose value iterates over raw sentence
            strings. In this notebook it is ``raw_df``, whose columns are the language names.
        language: Column / language key, e.g. ``SRC_LANGUAGE`` or ``TGT_LANGUAGE``.
        tokenizer: Optional callable ``str -> List[str]``. Defaults to
            ``token_transform[language]``; the lookup happens lazily, on the first ``next()``.

    Yields:
        The token list of each sentence, in the order ``data_iter[language]`` iterates.
    """
    if tokenizer is None:
        tokenizer = token_transform[language]
    for sentence in data_iter[language]:
        yield tokenizer(sentence)

# Special symbols and their indices. With special_first=True, build_vocab_from_iterator puts
# `specials` at the front of the vocab in list order, so this list must match the indices.
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    # torchtext Vocab built from this language's tokenized raw_df column; min_freq=1 keeps every token.
    # NOTE(review): this uses all of raw_df. If a train/test split happens later, consider
    # building the vocab from the training split only.
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(raw_df, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )
    # Unknown tokens map to UNK_IDX; without a default index, lookups of unseen tokens raise RuntimeError.
    vocab_transform[ln].set_default_index(UNK_IDX)

In [77]:
# Vocabulary indices of the second target-language sentence. Only that one sentence is tokenized;
# list(yield_tokens(...))[1] ran the spaCy tokenizer over every row just to pick one.
vocab_transform[TGT_LANGUAGE](token_transform[TGT_LANGUAGE](raw_df[TGT_LANGUAGE].iloc[1]))

[32, 1480, 16, 1537, 4]

In [79]:
# Tokens of the third target-language sentence (tokenizes just that row, not the whole column).
token_transform[TGT_LANGUAGE](raw_df[TGT_LANGUAGE].iloc[2])

['this', 'is', 'one', 'thing', 'in', 'exchange', 'for', 'another', '.']

In [80]:
# Tokens of the third source-language sentence (tokenizes just that row, not the whole column).
token_transform[SRC_LANGUAGE](raw_df[SRC_LANGUAGE].iloc[2])

['引き換え', 'だ', 'ある', '事', 'と', 'ある', '物', 'の']

In [84]:
# Seed torch's global RNG so later random draws in this notebook are reproducible.
torch.manual_seed(0)

# Vocabulary sizes of the vocabs built earlier (these include the 4 special symbols).
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])

# Model dimensions (presumably for a Transformer encoder-decoder; the names suggest so).
EMB_SIZE = 512            # token embedding size
NHEAD = 8                 # number of attention heads
FFN_HID_DIM = 512         # hidden width of the feed-forward sub-layers
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3

# Training
BATCH_SIZE = 128

In [81]:
from torch.nn.utils.rnn import pad_sequence

# helper function that chains several callables into one
def sequential_transforms(*transforms):
    """Compose ``transforms`` into a single callable applied left to right.

    ``sequential_transforms(f, g, h)(x)`` evaluates ``h(g(f(x)))``. With no transforms,
    the returned callable returns its input unchanged.
    """
    def func(txt_input):
        value = txt_input
        for step in transforms:
            value = step(value)
        return value

    return func

# function to add BOS/EOS and create an index tensor for one input sequence
def tensor_transform(token_ids: List[int],
                     bos_idx: Optional[int] = None,
                     eos_idx: Optional[int] = None):
    """Wrap a sentence's token ids with BOS/EOS markers as a 1-D int64 tensor.

    Args:
        token_ids: Vocabulary indices of one sentence (may be empty).
        bos_idx: Index to prepend. Defaults to the notebook-level ``BOS_IDX``.
        eos_idx: Index to append. Defaults to the notebook-level ``EOS_IDX``.

    Returns:
        A ``torch.long`` tensor ``[bos_idx, *token_ids, eos_idx]`` of length ``len(token_ids) + 2``.
    """
    if bos_idx is None:
        bos_idx = BOS_IDX
    if eos_idx is None:
        eos_idx = EOS_IDX
    # Build the tensor in one shot with an explicit dtype. torch.tensor([]) is float32, so
    # concatenating it for an empty sentence can promote the result to a float tensor,
    # which is not usable as embedding indices.
    return torch.tensor([bos_idx, *token_ids, eos_idx], dtype=torch.long)

# Per-language text pipeline: raw string -> token strings -> vocab indices -> BOS ... EOS tensor.
text_transform = {}
for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    text_transform[ln] = sequential_transforms(
        token_transform[ln],   # tokenization
        vocab_transform[ln],   # numericalization
        tensor_transform,      # add BOS/EOS and build the tensor
    )

In [103]:
# Run the full source-language pipeline on the first Japanese sentence.
text_transform[SRC_LANGUAGE](raw_df[SRC_LANGUAGE].iloc[0])

tensor([   2,   55,    6,  290,    8,    4,   35, 1659,   12,    3])

In [83]:
def collate_fn(batch):
    """Convert a batch of (source, target) raw-string pairs into two padded index tensors.

    Each string has its trailing newline stripped, then goes through ``text_transform``
    for its language (tokenize -> vocab ids -> BOS/EOS tensor). ``pad_sequence`` pads the
    per-sentence tensors with ``PAD_IDX``. It uses the default ``batch_first=False``, so each
    returned tensor has shape (longest_sequence, batch_size).
    """
    encoded = [
        (text_transform[SRC_LANGUAGE](src_text.rstrip("\n")),
         text_transform[TGT_LANGUAGE](tgt_text.rstrip("\n")))
        for src_text, tgt_text in batch
    ]
    src_tensors = [src for src, _ in encoded]
    tgt_tensors = [tgt for _, tgt in encoded]
    return (pad_sequence(src_tensors, padding_value=PAD_IDX),
            pad_sequence(tgt_tensors, padding_value=PAD_IDX))

In [85]:
from torch.utils.data import DataLoader

In [82]:
# Displays only the repr of the target-language pipeline function produced by sequential_transforms;
# call it on a sentence string to see the resulting index tensor.
text_transform[TGT_LANGUAGE]

<function __main__.sequential_transforms.<locals>.func(txt_input)>