In [64]:
import pickle as pkl
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np, pandas as pd

In [98]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cpu'

## Data Pre-processing

In [107]:
SOS_token = 0
EOS_token = 1
PAD_IDX = 2
UNK_IDX = 3
class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS", 2:"UNK", 3:"PAD"}
        self.n_words = 4  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence:
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

In [4]:
def normalizeString(s):
#     s = s.lower().strip()
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"&apos;m", r"am", s)
    s = re.sub(r"&apos;s", r"is", s)
    s = re.sub(r"&apos;re", r"are", s)
    s = re.sub(r"&apos;", r"", s)
    return s

In [None]:
def getPairsCount(arr, n, sum): 
      
    count = 0 # Initialize result 
  
    # Consider all possible pairs 
    # and check their sums 
    for i in range(0, n): 
        for j in range(i + 1, n): 
            if arr[i] + arr[j] == sum: 
                count += 1
      
    return count 

In [5]:
def loadingLangs(sourcelang, targetlang, setname):
    input_ls = []
    output_ls = []
    print('Reading lines...')
    # Read the file 
    with open('iwslt-%s-%s/%s.tok.%s'%(sourcelang, targetlang, setname,sourcelang)) as f:
        for line in f.readlines():
            input_ls.append([normalizeString(word) for word in line.split()])
    with open('iwslt-%s-%s/%s.tok.%s'%(sourcelang, targetlang, setname,targetlang)) as f:
        for line in f.readlines():
            output_ls.append([normalizeString(word) for word in line.split()])
    pairs = list(zip(input_ls, output_ls))
    print('Read %s sentence pairs'%(len(input_ls)))
    input_lang = Lang(sourcelang)
    output_lang = Lang(targetlang)
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs
    

In [48]:
source_tra, target_tra, pairs_tra = loadingLangs('zh', 'en', 'train')
# source_val, target_val, pairs_val = loadingLangs('zh', 'en', 'dev')
# source_tes, target_tes, pairs_tes = loadingLangs('zh', 'en', 'test')

Reading lines...
Read 213377 sentence pairs
Counting words...
Counted words:
zh 88423
en 60635
Reading lines...
Read 1261 sentence pairs
Counting words...
Counted words:
zh 6130
en 3772


In [7]:
# def index_dataset(sourcelang, targetlang, dataset_pairs):
#     pair_idx= []
    
#     for pair in dataset_pairs:
#         input_index = [sourcelang.word2index[word] for word in pair[0]]
#         input_index.append(EOS_token)
#         output_index = [targetlang.word2index[word] for word in pair[1]]
#         output_index.append(EOS_token)
#         pair_index = input_index.append(output_index)
#         pair_idx.append(pair_index)
#     return pair_idx

## Dataset 

In [54]:
MAX_SENT_LEN = 40

In [100]:
def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] if word in lang.word2index else UNK_IDX for word in sentence]


def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)


def tensorsFromPair(pair,source,target):
    input_lang = source
    output_lang = target
    input_tensor = tensorFromSentence(input_lang, pair[0]).reshape((-1))
    target_tensor = tensorFromSentence(output_lang, pair[1]).reshape((-1))
    return (input_tensor, input_tensor.shape[0], target_tensor, target_tensor.shape[0])

In [101]:
class NMTDataset(Dataset):
    def __init__(self, source, target, pairs):
        self.source = source
        self.target = target
        self.pairs = pairs
        
    def __len__(self):
        return len(pairs)
    
    def __getitem__(self, key):
        """
        Triggered when you call dataset[i]
        """
        inp_ten, inp_len, tar_ten, tar_len = tensorsFromPair(self.pairs[key], self.source, self.target)
        item = {}
        item['inputtensor'] = inp_ten[:MAX_SENT_LEN]
        item['inputlen'] = min(inp_len, MAX_SENT_LEN)
        item['targettensor'] = tar_ten[:MAX_SENT_LEN]
        item['targetlen'] = min(tar_len, MAX_SENT_LEN)
        return item

In [102]:
train_data = NMTDataset(source_tra, target_tra, pairs_tra)

In [103]:
train_data.__getitem__(234)

{'inputtensor': tensor([  47,  869,   14, 1233,  452, 1110,    4,   82,   83,  320,  396,  308,
            4, 1234, 1235,  733,   55, 1236,  389,  619,  609,  610,  611,   82,
           83,   14, 1237,   16,  883,    4, 1238,    1]),
 'inputlen': 32,
 'targettensor': tensor([ 45,  83,  48,  49, 545, 193,  71,  28,  42,  23,  50,  19, 197, 788,
          19,  20,  21,  52,  71, 193,  48,  49, 872, 867,  41,   1]),
 'targetlen': 26}

## Dataloader

In [104]:
#collate function

def collate_func(batch):
    """
    Customized function for DataLoader that dynamically pads the batch so that all
    data have the same length
    """
    src_data, tar_data, src_len, tar_len = [], [], [], []
    for datum in batch:        
        src_datum = np.pad(np.array(datum['inputtensor']),
                                pad_width=((0,MAX_SENT_LEN-datum['inputlen'])),
                                mode="constant", constant_values=PAD_IDX)
        tar_datum = np.pad(np.array(datum['targettensor']),
                                pad_width=((0,MAX_SENT_LEN-datum['targetlen'])),
                                mode="constant", constant_values=PAD_IDX)
        src_data.append(src_datum)
        tar_data.append(tar_datum)
        src_len.append(datum['inputlen'])
        tar_len.append(datum['targetlen'])
    return [torch.from_numpy(np.array(src_data)),torch.from_numpy(np.array(tar_data)),
               torch.from_numpy(np.array(src_len)),torch.from_numpy(np.array(tar_len))]

In [105]:
BATCH_SIZE = 32
train_loader = torch.utils.data.DataLoader(train_data,
                                           batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_func)

In [108]:
# sample data loader
for data in train_loader:
    print('input sentence batch: ')
    print(data[0])
    print('target sentence batch: ')
    print(data[1])
    print('input sentence len: ')
    print(data[2])
    print('target sentence len: ')
    print(data[3])
    break

input sentence batch: 
tensor([[12912,    14,   270,  ...,     2,     2,     2],
        [   14,   238,   379,  ...,     2,     2,     2],
        [ 1027,     1,     2,  ...,     2,     2,     2],
        ...,
        [  328,   956,    47,  ...,     2,     2,     2],
        [  955,   206,   250,  ...,     2,     2,     2],
        [  186,   110,  1801,  ...,     2,     2,     2]])
target sentence batch: 
tensor([[  48,   57,   63,  ...,    2,    2,    2],
        [  48,  290, 8552,  ...,    2,    2,    2],
        [ 745,   52,   41,  ...,    2,    2,    2],
        ...,
        [  28,  173,   61,  ...,   12,  284, 3254],
        [ 102,   49,  214,  ...,    2,    2,    2],
        [ 190,  689,   21,  ...,    2,    2,    2]])
input sentence len: 
tensor([11, 24,  2,  7, 32, 40, 19, 11, 37, 10, 12, 21, 12, 12, 27, 24, 30, 29,
        18, 40, 19,  3,  6,  4, 40, 36, 10, 11, 17, 33,  9, 15])
target sentence len: 
tensor([12, 25,  4,  6, 33, 40, 22, 13, 33, 12, 19, 36, 14, 12, 28, 21, 24, 2

In [115]:
source.word2index

{'深海': 2,
 '海中': 3,
 '的': 4,
 '生命': 5,
 '大卫': 6,
 '盖罗': 7,
 '通过': 8,
 '潜水': 9,
 '潜水艇': 10,
 '拍下': 11,
 '影片': 12,
 '把': 13,
 '我们': 14,
 '带到': 15,
 '了': 16,
 '地球': 17,
 '最': 18,
 '黑暗': 19,
 '险恶': 20,
 '同时': 21,
 '也': 22,
 '最美': 23,
 '美丽': 24,
 '生物': 25,
 '栖息': 26,
 '栖息地': 27,
 '这里': 28,
 '是': 29,
 '海洋': 30,
 '深处': 31,
 '峡谷': 32,
 '和': 33,
 '火山': 34,
 '山脊': 35,
 '怪诞': 36,
 '适应': 37,
 '适应力': 38,
 '应力': 39,
 '强': 40,
 '而且': 41,
 '数量': 42,
 '惊人': 43,
 '这位': 44,
 '比尔': 45,
 '兰格': 46,
 '我': 47,
 '将': 48,
 '用': 49,
 '一些': 50,
 '来讲': 51,
 '讲述': 52,
 '海里': 53,
 '故事': 54,
 '这': 55,
 '有': 56,
 '不少': 57,
 '精彩': 58,
 '泰坦': 59,
 '泰坦尼克': 60,
 '坦尼': 61,
 '尼克': 62,
 '可惜': 63,
 '您': 64,
 '今天': 65,
 '看不到': 66,
 '不到': 67,
 '泰坦尼克号': 68,
 '号': 69,
 '拿': 70,
 '票房': 71,
 '冠军': 72,
 '但': 73,
 '事实': 74,
 '事实上': 75,
 '它': 76,
 '并': 77,
 '不是': 78,
 '关于': 79,
 '于海洋': 80,
 '刺激': 81,
 '原因': 82,
 '在于': 83,
 '一直': 84,
 '没': 85,
 '当回事': 86,
 '回事': 87,
 '回事儿': 88,
 '事儿': 89,
 '大家': 90,
 '想想': 91,
 '占': 92,
 '球面': 93,
 '面积

In [116]:
target.index2word

{0: 'SOS',
 1: 'EOS',
 2: 'life',
 3: 'in',
 4: 'the',
 5: 'deep',
 6: 'oceans',
 7: 'with',
 8: 'vibrant',
 9: 'video',
 10: 'clips',
 11: 'captured',
 12: 'by',
 13: 'submarines',
 14: ',',
 15: 'david',
 16: 'gallo',
 17: 'takes',
 18: 'us',
 19: 'to',
 20: 'some',
 21: 'of',
 22: 'earth',
 23: 'is',
 24: 'darkest',
 25: 'most',
 26: 'violent',
 27: 'toxic',
 28: 'and',
 29: 'beautiful',
 30: 'habitats',
 31: 'valleys',
 32: 'volcanic',
 33: 'ridges',
 34: '',
 35: 'depths',
 36: 'where',
 37: 'bizarre',
 38: 'resilient',
 39: 'shockingly',
 40: 'abundant',
 41: ' .',
 42: 'this',
 43: 'bill',
 44: 'lange',
 45: 'i',
 46: 'am',
 47: 'dave',
 48: 'we',
 49: 'are',
 50: 'going',
 51: 'tell',
 52: 'you',
 53: 'stories',
 54: 'from',
 55: 'sea',
 56: 'here',
 57: 've',
 58: 'got',
 59: 'incredible',
 60: 'titanic',
 61: 'that',
 62: 'ever',
 63: 'been',
 64: 'seen',
 65: 'not',
 66: 'show',
 67: 'any',
 68: 'it',
 69: 'truth',
 70: 'matter',
 71: '--',
 72: 'even',
 73: 'though',
 74: '