In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.legacy.datasets import TranslationDataset, Multi30k
from torchtext.legacy.data import Field, BucketIterator

import spacy

import random
import math
import time

import matplotlib
matplotlib.rcParams.update({'figure.figsize': (16, 12), 'font.size': 14})
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import clear_output

from nltk.tokenize import WordPunctTokenizer
from subword_nmt.learn_bpe import learn_bpe
from subword_nmt.apply_bpe import BPE

In [10]:
# Location of the EN-RU parallel corpus, relative to this notebook's directory.
path_to_data = "../../datasets/Machine_translation_EN_RU/data.txt"
from data_preprocessing import get_dataset

In [11]:
# Load the corpus found at `path_to_data`. `get_dataset` returns a pair that is
# unpacked here into the dataset splits (`data`) and the vocabularies (`vocab`).
data, vocab = get_dataset(path_to_data)

In [12]:
# `data` unpacks into exactly three splits: training, validation and test.
train_data, valid_data, test_data = data

In [20]:
# Sanity check: show the source and target token lists of one test example,
# each on its own line.
print(test_data[4].src, test_data[4].trg, sep='\n')

['апартаменты', 'shenzhen', 'dameisha', 'sea', 'world', 'holiday', 'расположены', 'в', 'городе', 'шэньчжэнь', ',', 'в', '5', 'минутах', 'ходьбы', 'от', 'приморского', 'парка', 'дамейша', '.']
['shenzhen', 'dameisha', 'sea', 'world', 'holiday', 'apartment', 'offers', 'accommodation', 'in', 'shenzhen', '.', 'it', 'is', 'located', 'a', '5', '-', 'minute', 'walk', 'from', 'dameisha', 'seaside', 'park', '.']


In [21]:
# Report how many examples ended up in each split, one line per split.
print(
    f"Number of training examples: {len(train_data.examples)}",
    f"Number of validation examples: {len(valid_data.examples)}",
    f"Number of testing examples: {len(test_data.examples)}",
    sep="\n",
)

Number of training examples: 40000
Number of validation examples: 2500
Number of testing examples: 7500


In [22]:
# `vocab` unpacks into two vocabularies: source-side first, then target-side.
src_vocab, trg_vocab = vocab

In [24]:
# Display the vocabulary sizes as a (source, target) tuple.
tuple(map(len, (src_vocab, trg_vocab)))

(9273, 6704)

In [34]:
def _len_sort_key(x):
    """Return the length of an example's ``src`` sequence.

    Passed to ``BucketIterator.splits`` as ``sort_key`` so examples can be
    ordered by source length.
    """
    return len(x.src)

BATCH_SIZE = 512

# No earlier cell defines `device`, so it is defined here: use the GPU when one
# is available, otherwise fall back to the CPU.
# NOTE(review): the model has to live on the same device as the batches —
# confirm how get_base_line_model / train place the model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One iterator per split; all share the batch size, device and sort key.
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=BATCH_SIZE,
    device=device,
    sort_key=_len_sort_key,
)

In [31]:
# Index of the '<pad>' token in the target vocabulary; targets equal to this
# index are excluded from the loss through `ignore_index`.
PAD_IDX = trg_vocab.stoi['<pad>']
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# NOTE(review): `model` is not defined in any visible cell — confirm which
# model this optimizer is supposed to wrap.
optimizer = optim.Adam(model.parameters())

In [33]:
from train_model import train

In [38]:
from base_line_model import get_base_line_model

In [39]:
# Build the baseline model from the two vocabulary sizes.
base_line_model = get_base_line_model(
    len(src_vocab),  # size of the source vocabulary
    len(trg_vocab),  # size of the target vocabulary
)

In [40]:
# Bind the optimizer to the parameters of the model that is actually trained.
# The optimizer built earlier wraps `model.parameters()`; `model` is not defined
# in the visible cells and is presumably not the `base_line_model` instance, so
# stepping that optimizer would leave this model's weights unchanged.
optimizer = optim.Adam(base_line_model.parameters())

# Train the baseline model on the training iterator, validating on the
# validation iterator, with the padding-aware loss in `criterion`.
train(base_line_model, train_iterator, valid_iterator, optimizer, criterion)

KeyboardInterrupt: 