In [6]:
from torchtext.data import Field, BucketIterator
from torchtext.datasets import Multi30k
SRC = Field(tokenize = tokenize_de, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

TRG = Field(tokenize = tokenize_en, 
            init_token = '<sos>', 
            eos_token = '<eos>', 
            lower = True)

train_data, valid_data, test_data = Multi30k.splits(exts = ('.de', '.en'),
                                                    fields = (SRC, TRG))

print(vars(train_data.examples[0]))


<class 'torchtext.datasets.translation.Multi30k'>
{'src': ['.', 'büsche', 'vieler', 'nähe', 'der', 'in', 'freien', 'im', 'sind', 'männer', 'weiße', 'junge', 'zwei'], 'trg': ['two', 'young', ',', 'white', 'males', 'are', 'outside', 'near', 'many', 'bushes', '.']}


In [7]:
print(type(train_data))

<class 'torchtext.datasets.translation.Multi30k'>


In [7]:
"""
将一串英文转换成torchtext.datasets的格式
"""
import torch as t
from torchtext.data import Field, Example, TabularDataset
from torchtext.data import BucketIterator
from torchtext import data
import spacy
import jieba

# 构建分词标准
def tokenize_zh(text):
    return list(jieba.cut(text))

spacy_en = spacy.load('en')
def tokenize_en(text):
    # [::-1]  的含义就是：从下标-1开始，然后每次递增（步长是-1），直到最后
    return [tok.text for tok in spacy_en.tokenizer(text)][::-1]

# 构建Field
TEXT = Field(sequential=True, tokenize=tokenize_en,
                     use_vocab=True, batch_first=True,
                     fix_length=50,
                     eos_token=None, init_token=None,
                     include_lengths=True, pad_token=0)

LABEL = Field(sequential=False,
              tokenize=tokenize_zh,
              use_vocab=True,
              batch_first=True)


fields = [("label", LABEL), ("text", TEXT)]
train, valid = TabularDataset.splits(
    path=".",
    train="/home/lawson/program/wheels/seq2seq/data/test.tsv",
    validation="/home/lawson/program/wheels/seq2seq/data/test.tsv",
    format='tsv',
    skip_header=False,
    fields=fields)

# 根据数据构建字典
TEXT.build_vocab(train)
LABEL.build_vocab(train)

train_iter, val_iter = BucketIterator.splits((train, valid),
                                             batch_sizes=(5,5),
                                             device = t.device("cpu"),
                                             sort_key=lambda x: len(x.text), # field sorted by len
                                             sort_within_batch=True,
                                             repeat=False)
print(type(train_iter))
print(len(train_iter))
for x in train_iter:
    print(x)

<class 'torchtext.data.iterator.BucketIterator'>
1

[torchtext.data.batch.Batch of size 1]
	[.label]:[torch.LongTensor of size 1]
	[.text]:('[torch.LongTensor of size 1x50]', '[torch.LongTensor of size 1]')


<class 'torchtext.data.dataset.TabularDataset'>
{'label': 'I love you', 'text': ['我爱你']}


In [31]:
"""
将一串英文转换成torchtext.datasets的格式
"""
import torch as t
from torchtext.data import Field, Example, TabularDataset
from torchtext.data import BucketIterator
from torchtext import data
import spacy
import jieba

spacy_en = spacy.load('en') # 英文分词

# 构建分词标准
def tokenize_zh(text):
    return list(jieba.cut(text))

def tokenize_en(text):
    # [::-1]  的含义就是：从下标-1开始，然后每次递增（步长是-1），直到最后
    return [tok.text for tok in spacy_en.tokenizer(text)]

# 构建Field
SRC = Field(sequential=True, tokenize=tokenize_en,
                     use_vocab=True, batch_first=True,
                     fix_length=50,
                     eos_token=None, init_token=None,
                     include_lengths=True, pad_token=0)

TRG = Field(sequential=False,
              tokenize=tokenize_zh,
              use_vocab=True,
              batch_first=True)

fields = [("label", SRC), ("text", TRG)]

# 根据数据构建字典
SRC.build_vocab(train)
TRG.build_vocab(train)

src_file = ['I love ShangHai.','What\'s your name?'] # source
trg_file = ["我爱上海","你叫什么？"] # target

examples=[]
for src_line, trg_line in zip(src_file, trg_file):
    src_line, trg_line = src_line.strip(), trg_line.strip()
    if src_line != '' and trg_line != '':
        temp = data.Example.fromlist([src_line, trg_line], fields)
        examples.append(temp)

print(type(train))
print(vars(train.examples[0]))

<class 'torchtext.data.dataset.TabularDataset'>
{'label': 'I love you', 'text': ['我爱你']}
