In [1]:
import itertools
import os
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable

In [2]:
# Seed every global RNG up front so that Restart & Run All repeats the same
# weight initialisation and sampling. 42 mirrors the random_state already used
# for the train/validation split further down.
# NOTE(review): code inside lib.* may draw randomness from sources not visible
# here, and GPU kernels can still be non-deterministic — confirm if exact
# reproducibility is required.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Paths to the ru-be data files, relative to the notebook's working directory.
TRAIN_FILE = 'data/ru-be-train.txt'
TEST_FILE = 'data/ru-be-test.txt'

In [3]:
from lib.utils import Alphabet

# One Alphabet per side of the language pair: `ru` for the source words,
# `be` for the target words. Both start out freshly constructed.
ru, be = Alphabet(), Alphabet()

In [4]:
from lib.utils import load_pair_dataset

# Load the training pairs; X holds the source side and Y the target side.
# NOTE(review): presumably the loader also fills the `ru` / `be` alphabets
# it is handed — confirm against lib.utils.
X, Y = load_pair_dataset(
    TRAIN_FILE,
    ru,
    be,
)

In [5]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the pairs for validation; the fixed random_state keeps the
# split identical between runs.
train_X, val_X, train_Y, val_Y = train_test_split(
    X,
    Y,
    test_size=0.1,
    random_state=42,
)

In [6]:
from lib.models import SimpleGRUSupervisedSeq2Seq

# Supervised ru -> be generator.
# NOTE(review): the positional 65 and 256 are not explained in this notebook;
# presumably embedding and hidden sizes — confirm against lib.models.
model = SimpleGRUSupervisedSeq2Seq(ru, be, 65, 256)

# Adam over all generator parameters.
opt = optim.Adam(
    model.parameters(),
    lr=1e-3,
)

In [7]:
def _decode_words(alphabet, encoded_words):
    """Convert each encoded word back to letters, without start/end markers."""
    return [
        alphabet.index2letter(encoded, with_start_end=False)
        for encoded in encoded_words
    ]


# Validation pairs in letter form, as passed to training and to the metrics.
val_src_words = _decode_words(ru, val_X)
val_trg_words = _decode_words(be, val_Y)

In [8]:
# Folder that receives the generator checkpoints written during training.
RU_GENERATOR_CHECKPOINTS = './checkpoints/ru_generators_checkpoints'

# Create the folder (and any missing parents) from Python instead of shelling
# out to `mkdir -p`, so the cell does not depend on a POSIX shell.
# exist_ok=True keeps re-running the cell a no-op.
os.makedirs(RU_GENERATOR_CHECKPOINTS, exist_ok=True)

In [11]:
from lib.trainer import train_generator

# Train the generator on the training pairs, validating on the held-out words.
# Checkpoints go to RU_GENERATOR_CHECKPOINTS; metrics_compute_freq=50 matches
# the "iter: 49, 99, ..." cadence of the logged loss lines.
train_generator(
    model,
    opt,
    be,
    train_X,
    train_Y,
    val_src_words,
    val_trg_words,
    checkpoints_folder=RU_GENERATOR_CHECKPOINTS,
    metrics_compute_freq=50,
    n_epochs=10,
)

epoch: 0 iter: 49 loss: 0.18749146784438972
epoch: 0 iter: 99 loss: 0.1925157384710806
epoch: 0 iter: 149 loss: 0.1971139955831735
epoch: 0 iter: 199 loss: 0.19273636193160368
epoch: 0 iter: 249 loss: 0.1800166141711831
epoch: 0 iter: 299 loss: 0.17534948138655812
epoch: 0 iter: 349 loss: 0.16808016918203736
epoch: 0 iter: 399 loss: 0.19870731300857897
epoch: 0 iter: 449 loss: 0.18669938045001172
epoch: 0 iter: 499 loss: 0.17915205913272433
epoch: 0 iter: 549 loss: 0.17623469553709434
epoch: 0 iter: 599 loss: 0.17824407318836993
epoch: 0 iter: 649 loss: 0.19722229769145933
epoch: 0 iter: 699 loss: 0.19358745286883056
epoch: 0 iter: 749 loss: 0.21682424548138
epoch: 0 iter: 799 loss: 0.2037420646300935
epoch: 0 iter: 849 loss: 0.20755405683640238
epoch: 0 iter: 899 loss: 0.21673287047922946
epoch: 0 iter: 949 loss: 0.21825829383429313
epoch: 0 iter: 999 loss: 0.19722665209180223
epoch: 0 iter: 1049 loss: 0.2088876228990186
epoch: 0 iter: 1099 loss: 0.21210584177204816


Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().



epoch: 0 val_score: 0.7993138246330833 time: 444.489666223526
epoch: 1 iter: 49 loss: 0.14538353563933032
epoch: 1 iter: 99 loss: 0.15838343080189057
epoch: 1 iter: 149 loss: 0.16282482438438467
epoch: 1 iter: 199 loss: 0.15753384125834352
epoch: 1 iter: 249 loss: 0.1541912020099127
epoch: 1 iter: 299 loss: 0.14959332736390032
epoch: 1 iter: 349 loss: 0.1613433498939563
epoch: 1 iter: 399 loss: 0.1618052241266536
epoch: 1 iter: 449 loss: 0.17411149292950762
epoch: 1 iter: 499 loss: 0.1830906354064941
epoch: 1 iter: 549 loss: 0.15706067688288924
epoch: 1 iter: 599 loss: 0.18614956338251967
epoch: 1 iter: 649 loss: 0.19174604884268717
epoch: 1 iter: 699 loss: 0.19625491330518802
epoch: 1 iter: 749 loss: 0.21506548780049592
epoch: 1 iter: 799 loss: 0.20193613857077772
epoch: 1 iter: 849 loss: 0.20712392192739468
epoch: 1 iter: 899 loss: 0.20319327611458343
epoch: 1 iter: 949 loss: 0.2047374716869786
epoch: 1 iter: 999 loss: 0.23052755541578493
epoch: 1 iter: 1049 loss: 0.2147749537298360

KeyboardInterrupt: 

In [16]:
import os

# Checkpoint to restore. The number in the file name equals the val_score
# printed for epoch 0 in the training log, so the name appears to encode
# "<epoch>_<validation score>" — TODO confirm against lib.trainer.
best_score_model = "state_dict_0_0.7993138246330833.pth"

# Resolve the file inside the checkpoints folder, read it, then copy the
# weights into the existing model instance.
checkpoint_path = os.path.join(RU_GENERATOR_CHECKPOINTS, best_score_model)
checkpoint_state = torch.load(checkpoint_path)
model.load_state_dict(checkpoint_state)

In [20]:
# Spot-check the restored generator on a single word; the returned
# translation is shown as the cell output.
sample_word = "федеральным"
model.translate(sample_word, with_start_end=False)

'федэральным'

In [11]:
from lib.metrics import compute_accuracy

# Accuracy evaluation on the validation words is currently disabled: the call
# below is left commented out, so this cell only performs the import.
#compute_accuracy(model, val_src_words, val_trg_words)

In [26]:
# Letter-form versions of the first 8000 training pairs (start/end markers
# dropped), used below to eyeball the model's translations.
src_words = list(map(
    lambda encoded: ru.index2letter(encoded, with_start_end=False),
    train_X[:8000],
))
trg_words = list(map(
    lambda encoded: be.index2letter(encoded, with_start_end=False),
    train_Y[:8000],
))

In [21]:
from lib.metrics import compute_bleu_score

# BLEU of the restored generator on the validation words; the returned score
# is shown as the cell output.
compute_bleu_score(
    model,
    val_src_words,
    val_trg_words,
)

Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().





0.7993138246330833

In [9]:
# Print one line per training pair: source word, model output, reference word.
for source_word, reference_word in zip(src_words, trg_words):
    predicted_word = model.translate(source_word, with_start_end=False)
    print(source_word, predicted_word, reference_word)

NameError: name 'src_words' is not defined

In [22]:
from lib.models import BiLSTMDiscriminator

# Discriminator built over the `be` alphabet.
# NOTE(review): the positional 32 and 128 are not explained in this notebook;
# presumably embedding and hidden sizes — confirm against lib.models.
disc = BiLSTMDiscriminator(be, 32, 128)

# Separate Adam optimizer for the discriminator parameters.
disc_opt = optim.Adam(
    disc.parameters(),
    lr=5e-5,
)

In [23]:
from lib.trainer import train_discriminator

# Train the discriminator for 5 epochs; the generator `model` is passed
# alongside the training pairs.
# NOTE(review): the values printed by the recorded run stay near 1.386
# (about 2*ln 2) throughout the visible output — worth checking whether the
# discriminator is actually learning.
train_discriminator(
    disc,
    model,
    disc_opt,
    train_X,
    train_Y,
    n_epochs=5,
)

1.3864200115203857
1.3856408596038818
1.3859609365463257
1.3855679035186768
1.3860799074172974
1.3875696659088135
1.3844671249389648
1.3863332271575928
1.385830044746399
1.3852546215057373
1.386328935623169
1.3852815628051758
1.3844716548919678
1.385976791381836
1.3853299617767334
1.3839329481124878
1.3843560218811035
1.3858451843261719
1.3852136135101318
1.3830749988555908
1.382631778717041
1.3860793113708496
1.3807296752929688
1.3858046531677246
1.3817212581634521
1.3851286172866821
1.3851951360702515
1.3835530281066895
1.3865728378295898
1.3857605457305908
1.3793292045593262
1.3826868534088135
1.3832042217254639
1.3856315612792969
1.385282039642334
1.3836541175842285
1.3857414722442627
1.3847401142120361
1.3891725540161133
1.385475754737854
1.3864716291427612
1.38339364528656
1.3872170448303223
1.3846826553344727
1.3861864805221558
1.381763219833374
1.382333755493164
1.3850464820861816
1.386406660079956
1.3868886232376099
1.3869279623031616
1.3828113079071045
1.384963035583496
1.382