In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import itertools

In [2]:
from lib.utils import split_file_into_two

split_file_into_two(
    'data/he-pron-wiktionary.txt', 
    'data/train-he-pron-wiktionary.txt', 
    'data/test-he-pron-wiktionary.txt', 
    test_size=0.3
)

In [3]:
TRAIN_FILE = 'data/train-he-pron-wiktionary.txt'
TEST_FILE = 'data/test-he-pron-wiktionary.txt'

In [4]:
from lib.utils import Alphabet


he = Alphabet()
en = Alphabet()

In [5]:
from lib.utils import load_pair_dataset

X, Y = load_pair_dataset(TRAIN_FILE, he, en)

In [6]:
from sklearn.model_selection import train_test_split

train_X, val_single_X, train_Y, val_single_Y = train_test_split(X, Y, test_size=0.8, random_state=42)
val_X, single_X, val_Y, single_Y = train_test_split(val_single_X, val_single_Y, test_size=0.85, random_state=42)

In [7]:
from lib.models import SimpleGRUSupervisedSeq2Seq

model = SimpleGRUSupervisedSeq2Seq(he, en, 65, 256)
opt = optim.Adam(model.parameters(), lr=1e-3)

In [8]:
val_src_words = [he.index2letter(x, with_start_end=False) for x in val_X]
val_trg_words = [en.index2letter(y, with_start_end=False) for y in val_Y]

In [9]:
HE_GENERATOR_CHECKPOINTS = './checkpoints/he_generators_checkpoints'

! mkdir -p {HE_GENERATOR_CHECKPOINTS}

In [10]:
from lib.trainer import train_generator

train_generator(
    model, opt, en, 
    train_X, train_Y, 
    val_src_words, val_trg_words, 
    checkpoints_folder=HE_GENERATOR_CHECKPOINTS, 
    metrics_compute_freq=5, n_epochs=30
)

epoch: 0 iter: 4 loss: 1.3346859925436974
epoch: 0 iter: 9 loss: 1.9718228925061139
epoch: 0 iter: 14 loss: 2.297269074638102
epoch: 0 iter: 19 loss: 2.4181559898462184
epoch: 0 iter: 24 loss: 2.4675153013372952
epoch: 0 iter: 29 loss: 2.4286769315309


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().



epoch: 0 val_score: 0.6143066873067135 time: 17.012142181396484
epoch: 1 iter: 4 loss: 2.3562851676747596
epoch: 1 iter: 9 loss: 2.266869288866778
epoch: 1 iter: 14 loss: 2.148956287760393
epoch: 1 iter: 19 loss: 2.040578934973147
epoch: 1 iter: 24 loss: 1.9324265433328565
epoch: 1 iter: 29 loss: 1.8245586345128135


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 1 val_score: 0.5998235442923345 time: 19.32804846763611
epoch: 2 iter: 4 loss: 1.6907895577647123
epoch: 2 iter: 9 loss: 1.5734793256867277
epoch: 2 iter: 14 loss: 1.4857802590058238
epoch: 2 iter: 19 loss: 1.3500058039039868
epoch: 2 iter: 24 loss: 1.2497571455618104
epoch: 2 iter: 29 loss: 1.1892788952462408


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 2 val_score: 0.6429489637176095 time: 20.840585470199585
epoch: 3 iter: 4 loss: 1.082006394977475
epoch: 3 iter: 9 loss: 0.9906909244479422
epoch: 3 iter: 14 loss: 0.9173344862236215
epoch: 3 iter: 19 loss: 0.8682639059548988
epoch: 3 iter: 24 loss: 0.7921729618370519
epoch: 3 iter: 29 loss: 0.7248490907159528


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 3 val_score: 0.7390385163099678 time: 17.222690105438232
epoch: 4 iter: 4 loss: 0.6276286919513622
epoch: 4 iter: 9 loss: 0.5868373981435675
epoch: 4 iter: 14 loss: 0.5363700061192611
epoch: 4 iter: 19 loss: 0.5134208062031547
epoch: 4 iter: 24 loss: 0.5143241805921597
epoch: 4 iter: 29 loss: 0.483628120917357


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 4 val_score: 0.8169762512458367 time: 20.010468006134033
epoch: 5 iter: 4 loss: 0.4749239948260459
epoch: 5 iter: 9 loss: 0.4307114473287953
epoch: 5 iter: 14 loss: 0.3980397546921679
epoch: 5 iter: 19 loss: 0.36170489944994777
epoch: 5 iter: 24 loss: 0.3335760578376433
epoch: 5 iter: 29 loss: 0.3473637018428036


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 5 val_score: 0.8424275443978427 time: 22.116737365722656
epoch: 6 iter: 4 loss: 0.32703710506662637
epoch: 6 iter: 9 loss: 0.3090136063476001
epoch: 6 iter: 14 loss: 0.29759025450935245
epoch: 6 iter: 19 loss: 0.3057330496025056
epoch: 6 iter: 24 loss: 0.29910094730695586
epoch: 6 iter: 29 loss: 0.299316227049155


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 6 val_score: 0.8524881977636152 time: 21.40464997291565
epoch: 7 iter: 4 loss: 0.2874946476196805
epoch: 7 iter: 9 loss: 0.27310137205266516
epoch: 7 iter: 14 loss: 0.28144452908700346
epoch: 7 iter: 19 loss: 0.26137145849365867
epoch: 7 iter: 24 loss: 0.28422275992778506
epoch: 7 iter: 29 loss: 0.2649027725308757


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 7 val_score: 0.8531493476878835 time: 17.761500120162964
epoch: 8 iter: 4 loss: 0.2343357197779371
epoch: 8 iter: 9 loss: 0.2211774026387571
epoch: 8 iter: 14 loss: 0.22716117297073338
epoch: 8 iter: 19 loss: 0.2152888489005443
epoch: 8 iter: 24 loss: 0.21233996161988453
epoch: 8 iter: 29 loss: 0.20397565435207318


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 8 val_score: 0.8619276977671353 time: 18.387447595596313
epoch: 9 iter: 4 loss: 0.18790985181092207
epoch: 9 iter: 9 loss: 0.186331383270875
epoch: 9 iter: 14 loss: 0.17523797258422505
epoch: 9 iter: 19 loss: 0.19546660698292534
epoch: 9 iter: 24 loss: 0.2135245271568358
epoch: 9 iter: 29 loss: 0.18737732624509032


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 9 val_score: 0.8521432126592983 time: 17.774501085281372
epoch: 10 iter: 4 loss: 0.20556806874931519
epoch: 10 iter: 9 loss: 0.1854284183677792
epoch: 10 iter: 14 loss: 0.18659030061166204
epoch: 10 iter: 19 loss: 0.20121493838546092
epoch: 10 iter: 24 loss: 0.1951636281848889
epoch: 10 iter: 29 loss: 0.18715009953595238


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 10 val_score: 0.8427501854360808 time: 16.98882293701172
epoch: 11 iter: 4 loss: 0.172860702412816
epoch: 11 iter: 9 loss: 0.1712343601056772
epoch: 11 iter: 14 loss: 0.17338004898680964
epoch: 11 iter: 19 loss: 0.17703076431209452
epoch: 11 iter: 24 loss: 0.17245841013892949
epoch: 11 iter: 29 loss: 0.19608140424983217


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 11 val_score: 0.8598090298410276 time: 16.53945231437683
epoch: 12 iter: 4 loss: 0.18887970847450758
epoch: 12 iter: 9 loss: 0.16582351361109288
epoch: 12 iter: 14 loss: 0.15423064888743065
epoch: 12 iter: 19 loss: 0.1539745215813342
epoch: 12 iter: 24 loss: 0.15356014465941903
epoch: 12 iter: 29 loss: 0.16172901710848692


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 12 val_score: 0.8736158259943589 time: 16.93425440788269
epoch: 13 iter: 4 loss: 0.14205365109771392
epoch: 13 iter: 9 loss: 0.14570094451597854
epoch: 13 iter: 14 loss: 0.1399663047014335
epoch: 13 iter: 19 loss: 0.12409384534438037
epoch: 13 iter: 24 loss: 0.11082517435033061
epoch: 13 iter: 29 loss: 0.12135112489998555


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 13 val_score: 0.8684223047176375 time: 17.409035444259644
epoch: 14 iter: 4 loss: 0.11022509427834339
epoch: 14 iter: 9 loss: 0.11644585952678967
epoch: 14 iter: 14 loss: 0.11452644041591982
epoch: 14 iter: 19 loss: 0.10926388416367584
epoch: 14 iter: 24 loss: 0.10838426624460384
epoch: 14 iter: 29 loss: 0.10049567803054454


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 14 val_score: 0.8820249969781627 time: 16.86269497871399
epoch: 15 iter: 4 loss: 0.0933624327677148
epoch: 15 iter: 9 loss: 0.10141141702750445
epoch: 15 iter: 14 loss: 0.10427418133453917
epoch: 15 iter: 19 loss: 0.10172892285196565
epoch: 15 iter: 24 loss: 0.09398891441729407
epoch: 15 iter: 29 loss: 0.09591148165190423


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 15 val_score: 0.8798469376494075 time: 19.2204487323761
epoch: 16 iter: 4 loss: 0.09303081941628846
epoch: 16 iter: 9 loss: 0.0861036995691437
epoch: 16 iter: 14 loss: 0.08314037301062326
epoch: 16 iter: 19 loss: 0.09323637819793087
epoch: 16 iter: 24 loss: 0.09101960212919201
epoch: 16 iter: 29 loss: 0.08196266031352906


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 16 val_score: 0.8697150436626074 time: 23.101334810256958
epoch: 17 iter: 4 loss: 0.07422633539171285
epoch: 17 iter: 9 loss: 0.06546642168198287
epoch: 17 iter: 14 loss: 0.06589640776767126
epoch: 17 iter: 19 loss: 0.07039079970649839
epoch: 17 iter: 24 loss: 0.09179329029329675
epoch: 17 iter: 29 loss: 0.09262089561860305


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 17 val_score: 0.8745011249874425 time: 19.566139459609985
epoch: 18 iter: 4 loss: 0.08314920532491947
epoch: 18 iter: 9 loss: 0.07625175190646333
epoch: 18 iter: 14 loss: 0.08241461695643916
epoch: 18 iter: 19 loss: 0.07042305447318267
epoch: 18 iter: 24 loss: 0.07790831851911387
epoch: 18 iter: 29 loss: 0.06805684283266997


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 18 val_score: 0.8795945056969183 time: 20.41543960571289
epoch: 19 iter: 4 loss: 0.07023485852885661
epoch: 19 iter: 9 loss: 0.06266780338185265
epoch: 19 iter: 14 loss: 0.056787472592954365
epoch: 19 iter: 19 loss: 0.054045857784114165
epoch: 19 iter: 24 loss: 0.05261883808671238
epoch: 19 iter: 29 loss: 0.06166194390132336


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 19 val_score: 0.8864896906239248 time: 21.212124586105347
epoch: 20 iter: 4 loss: 0.0597299218600795
epoch: 20 iter: 9 loss: 0.05445573401297219
epoch: 20 iter: 14 loss: 0.055065412899159456
epoch: 20 iter: 19 loss: 0.05878420115918981
epoch: 20 iter: 24 loss: 0.0632072601514441
epoch: 20 iter: 29 loss: 0.06582375148333273


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 20 val_score: 0.8855564108467219 time: 18.592735528945923
epoch: 21 iter: 4 loss: 0.07011103860366699
epoch: 21 iter: 9 loss: 0.06342688280379849
epoch: 21 iter: 14 loss: 0.06074022246293251
epoch: 21 iter: 19 loss: 0.055746990161864074
epoch: 21 iter: 24 loss: 0.05435711913676687
epoch: 21 iter: 29 loss: 0.05613395620153063


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 21 val_score: 0.8822883096534071 time: 19.949179649353027
epoch: 22 iter: 4 loss: 0.05300345854635761
epoch: 22 iter: 9 loss: 0.056335493055966
epoch: 22 iter: 14 loss: 0.05594361900824781
epoch: 22 iter: 19 loss: 0.04973048067813541
epoch: 22 iter: 24 loss: 0.04817061152304606
epoch: 22 iter: 29 loss: 0.05297466301685963


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 22 val_score: 0.8766622499244584 time: 18.43692684173584
epoch: 23 iter: 4 loss: 0.04782627146814139
epoch: 23 iter: 9 loss: 0.04847353558242379
epoch: 23 iter: 14 loss: 0.05184079630232533
epoch: 23 iter: 19 loss: 0.054589429246492585
epoch: 23 iter: 24 loss: 0.05432916464866892
epoch: 23 iter: 29 loss: 0.05042847690319823


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 23 val_score: 0.8745691814057533 time: 19.79112696647644
epoch: 24 iter: 4 loss: 0.05012104033582311
epoch: 24 iter: 9 loss: 0.05983808087402182
epoch: 24 iter: 14 loss: 0.05433323811643371
epoch: 24 iter: 19 loss: 0.06780970312776728
epoch: 24 iter: 24 loss: 0.12647812632246763
epoch: 24 iter: 29 loss: 0.14463591289552272


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 24 val_score: 0.8162691775399161 time: 19.16445755958557
epoch: 25 iter: 4 loss: 0.1690364998597018
epoch: 25 iter: 9 loss: 0.18616987110833078
epoch: 25 iter: 14 loss: 0.18699552335393096
epoch: 25 iter: 19 loss: 0.2190072554756436
epoch: 25 iter: 24 loss: 0.20128255904659312
epoch: 25 iter: 29 loss: 0.201387875213222


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 25 val_score: 0.8471117881935597 time: 20.44273066520691
epoch: 26 iter: 4 loss: 0.16570608863949374
epoch: 26 iter: 9 loss: 0.13511493790481413
epoch: 26 iter: 14 loss: 0.1408242796631535
epoch: 26 iter: 19 loss: 0.11986437512145913
epoch: 26 iter: 24 loss: 0.11008286300862778
epoch: 26 iter: 29 loss: 0.10484076535410551


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 26 val_score: 0.8712336129355943 time: 19.066054105758667
epoch: 27 iter: 4 loss: 0.10666924460948213
epoch: 27 iter: 9 loss: 0.10871441668875402
epoch: 27 iter: 14 loss: 0.11025149882735408
epoch: 27 iter: 19 loss: 0.09757921901563589
epoch: 27 iter: 24 loss: 0.09604290043778928
epoch: 27 iter: 29 loss: 0.09899977781531244


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 27 val_score: 0.8571747852851563 time: 18.632996320724487
epoch: 28 iter: 4 loss: 0.08808884535724905
epoch: 28 iter: 9 loss: 0.07684516640105989
epoch: 28 iter: 14 loss: 0.0852334044574398
epoch: 28 iter: 19 loss: 0.08432113997525567
epoch: 28 iter: 24 loss: 0.07431362872173947
epoch: 28 iter: 29 loss: 0.0698215233704071


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 28 val_score: 0.8758459973422538 time: 20.12304997444153
epoch: 29 iter: 4 loss: 0.05352125897499866
epoch: 29 iter: 9 loss: 0.057762820821650226
epoch: 29 iter: 14 loss: 0.049612251502050495
epoch: 29 iter: 19 loss: 0.04603245136222388
epoch: 29 iter: 24 loss: 0.04473251560093111
epoch: 29 iter: 29 loss: 0.050305631435042825


Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"



epoch: 29 val_score: 0.8762013511240251 time: 21.839410305023193


In [12]:
import os

best_score_model = "state_dict_14_0.8820249969781627.pth"
model.load_state_dict(torch.load(os.path.join(HE_GENERATOR_CHECKPOINTS, best_score_model)))

In [13]:
model.translate("אני צריך להזמין מונית ", with_start_end=False)

'ni tsrikh lam nim ton t tim t'

In [None]:
from lib.metrics import compute_accuracy

#compute_accuracy(model, val_src_words, val_trg_words)

In [14]:
src_words = [he.index2letter(x, with_start_end=False) for x in train_X[:8000]]
trg_words = [en.index2letter(y, with_start_end=False) for y in train_Y[:8000]]

In [15]:
from lib.metrics import compute_bleu_score

compute_bleu_score(model, val_src_words, val_trg_words)

Widget Javascript not detected.  It may not be installed properly. Did you enable the widgetsnbextension? If not, then run "jupyter nbextension enable --py --sys-prefix widgetsnbextension"
Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().





0.8820249969781627

In [None]:
for ru_word, be_word in zip(src_words, trg_words):
    print(ru_word, model.translate(ru_word, with_start_end=False), be_word)

In [16]:
from lib.models import BiLSTMDiscriminator

disc = BiLSTMDiscriminator(en, 32, 128)
disc_opt = optim.Adam(disc.parameters(), lr=5e-5)

In [24]:
from lib.trainer import train_discriminator


train_discriminator(disc, model, disc_opt, single_X, single_Y, n_epochs=5)

1.326971411705017
1.1243849992752075
1.216160535812378
1.2075817584991455
1.313183307647705
1.3533127307891846
1.2849944829940796
1.150315761566162
1.1133830547332764
1.0920817852020264
1.1260046621595368
1.1998040676116943
1.1257209777832031
0.9887277483940125
1.1098346710205078
1.0364594459533691
1.210181474685669
1.206236481666565
1.0589314699172974
1.3024959564208984
1.0467264652252197
1.170199982722218
1.0443423986434937
1.1883742809295654
1.0689692497253418
1.0560312271118164
1.0201339721679688
1.226714015007019
1.2931761741638184
1.220949649810791
1.1514885425567627
1.1608054637908936
1.1682259033583755
1.2183563709259033
1.0659983158111572
1.116222620010376
1.119957685470581
1.088942050933838
1.1327602863311768
1.0117628574371338
1.1385940313339233
1.0778934955596924
1.0856103897094727
1.2057596987441275
1.235271692276001
1.1090869903564453
1.207736611366272
1.1369129419326782
1.0215330123901367
1.223803162574768
1.1199064254760742
1.3189902305603027
1.0294286012649536
1.219107

In [23]:
for g in disc_opt.param_groups:
    g['lr'] = 1e-4