In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.optim as optim
import itertools

In [2]:
import os

from lib.utils import split_file_into_two

# Split the raw pronunciation list into train/test files ONCE.
# Re-running this cell on every Restart & Run All regenerated the split.
# No seed is passed to split_file_into_two here, so if it shuffles, a re-run
# silently changes which words land in train vs. test. That invalidates
# previously saved checkpoints and the scores embedded in their filenames.
# Delete the two output files to force a fresh split.
_split_outputs = (
    'data/train-he-pron-wiktionary.txt',
    'data/test-he-pron-wiktionary.txt',
)
if not all(os.path.exists(path) for path in _split_outputs):
    split_file_into_two('data/he-pron-wiktionary.txt', *_split_outputs, test_size=0.3)

In [2]:
# Paths of the train/test halves produced by the split_file_into_two cell.
TRAIN_FILE, TEST_FILE = (
    'data/train-he-pron-wiktionary.txt',
    'data/test-he-pron-wiktionary.txt',
)

In [3]:
from lib.utils import Alphabet

# One symbol table per script: Hebrew (source of he_gen_model) and Latin/English.
# Both start empty; presumably load_pair_dataset fills them while reading —
# confirm in lib.utils.
he, en = Alphabet(), Alphabet()

In [4]:
from lib.utils import load_pair_dataset

# Load the paired training file. X is decoded later with `he` and Y with `en`
# (see the val_src_words / val_trg_words cell), so X holds Hebrew-side
# sequences and Y English-side ones.
# NOTE(review): presumably this also populates the (still empty) he/en
# alphabets — confirm in lib.utils.
X, Y = load_pair_dataset(TRAIN_FILE, he, en)

In [5]:
from sklearn.model_selection import train_test_split

# Resulting fractions of TRAIN_FILE (both splits are seeded for reproducibility):
#   train  : 20%                 -> supervised generator training
#   val    : 80% * 15% = 12%     -> validation words for BLEU / checkpointing
#   single : 80% * 85% = 68%     -> pool used for discriminator training
train_X, val_single_X, train_Y, val_single_Y = train_test_split(
    X, Y,
    test_size=0.8,
    random_state=42,
)
val_X, single_X, val_Y, single_Y = train_test_split(
    val_single_X, val_single_Y,
    test_size=0.85,
    random_state=42,
)

In [6]:
from lib.models import SimpleGRUSupervisedSeq2Seq

# Seed torch before building the models so weight initialisation is
# reproducible under Restart & Run All. This matches the data splits,
# which already use random_state=42.
torch.manual_seed(42)

# he_gen_model maps Hebrew -> Latin transliteration; en_gen_model maps the reverse.
# NOTE(review): 65 and 256 are presumably embedding / hidden sizes — confirm
# against SimpleGRUSupervisedSeq2Seq's signature.
he_gen_model = SimpleGRUSupervisedSeq2Seq(he, en, 65, 256)
en_gen_model = SimpleGRUSupervisedSeq2Seq(en, he, 65, 256)
he_gen_opt = optim.Adam(he_gen_model.parameters(), lr=1e-3)
en_gen_opt = optim.Adam(en_gen_model.parameters(), lr=1e-3)

In [47]:
# Decode the validation index sequences back to plain strings (no start/end
# markers); these are the references used for BLEU during and after training.
val_src_words = [he.index2letter(seq, with_start_end=False) for seq in val_X]  # Hebrew side
val_trg_words = [en.index2letter(seq, with_start_end=False) for seq in val_Y]  # English side

In [7]:
import os

HE_GENERATOR_CHECKPOINTS = './checkpoints/he_generators_checkpoints'
EN_GENERATOR_CHECKPOINTS = './checkpoints/en_generators_checkpoints'

# Pure-Python equivalent of `mkdir -p`. It needs no POSIX shell and is safe
# for paths with spaces, unlike the unquoted `!mkdir -p {...}` interpolation.
os.makedirs(HE_GENERATOR_CHECKPOINTS, exist_ok=True)
os.makedirs(EN_GENERATOR_CHECKPOINTS, exist_ok=True)

In [49]:
%load_ext autoreload
%autoreload 2

from lib.trainer import train_generator

train_generator(
    he_gen_model, he_gen_opt, en, 
    train_X, train_Y, 
    val_src_words, val_trg_words, 
    checkpoints_folder=HE_GENERATOR_CHECKPOINTS, 
    metrics_compute_freq=5, n_epochs=30
)

KeyboardInterrupt: 

In [12]:
# Mirror of the he->en run: every source/target argument is swapped because
# en_gen_model maps English -> Hebrew.
train_generator(
    en_gen_model,
    en_gen_opt,
    he,
    train_Y,
    train_X,
    val_trg_words,
    val_src_words,
    checkpoints_folder=EN_GENERATOR_CHECKPOINTS,
    metrics_compute_freq=5,
    n_epochs=30,
)

epoch: 0 iter: 4 loss: 1.0221201269555094
epoch: 0 iter: 9 loss: 1.542006659432749
epoch: 0 iter: 14 loss: 1.7994971765463932
epoch: 0 iter: 19 loss: 1.9018823672941463
epoch: 0 iter: 24 loss: 1.9349514036341193
epoch: 0 iter: 29 loss: 1.935029066656149


Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().



epoch: 0 val_score: 0.5316933730302896 time: 23.15217638015747
epoch: 1 iter: 4 loss: 1.8257912569714356
epoch: 1 iter: 9 loss: 1.7232718833424587
epoch: 1 iter: 14 loss: 1.698344552062926
epoch: 1 iter: 19 loss: 1.6016687208383211
epoch: 1 iter: 24 loss: 1.4975072693713618
epoch: 1 iter: 29 loss: 1.4532978820646223

epoch: 1 val_score: 0.5386251118060478 time: 24.224265813827515
epoch: 2 iter: 4 loss: 1.3686532765515973
epoch: 2 iter: 9 loss: 1.2992207131861806
epoch: 2 iter: 14 loss: 1.1759098039516336
epoch: 2 iter: 19 loss: 1.146051263419111
epoch: 2 iter: 24 loss: 1.1303204772960804
epoch: 2 iter: 29 loss: 1.0650707889434048

epoch: 2 val_score: 0.5714936906355055 time: 23.54473304748535
epoch: 3 iter: 4 loss: 0.981791930525283
epoch: 3 iter: 9 loss: 0.9548921381439499
epoch: 3 iter: 14 loss: 0.921052612331801
epoch: 3 iter: 19 loss: 0.9021065417489034
epoch: 3 iter: 24 loss: 0.8973082271564714
epoch: 3 iter: 29 loss: 0.8874920415220401

epoch: 3 val_score: 0.565961733246865 time

In [8]:
import os

# Hand-picked checkpoints. The trailing number in each filename appears to be
# the validation score at save time: the en one equals the en->he BLEU printed
# by the BLEU cell.
# NOTE(review): the he filename's score (0.880) does not match the he->en BLEU
# reported by the BLEU cell (0.757). The validation split may have changed
# between sessions — verify.
best_score_he_model = "state_dict_11_0.8800413397572596.pth"
best_score_en_model = "state_dict_28_0.643515051192157.pth"


def _load_checkpoint(model, folder, filename):
    """Load weights from `folder`/`filename` into `model` in place.

    Returns whatever `model.load_state_dict` returns.
    """
    checkpoint_path = os.path.join(folder, filename)
    return model.load_state_dict(torch.load(checkpoint_path))


_load_checkpoint(he_gen_model, HE_GENERATOR_CHECKPOINTS, best_score_he_model)
_load_checkpoint(en_gen_model, EN_GENERATOR_CHECKPOINTS, best_score_en_model)

In [51]:
# Sanity check: transliterate a Hebrew sentence with the loaded he->en model.
he_gen_model.translate("אני צריך להזמין מונית ", with_start_end=False)

'ni tsari kham nim yamim'

In [52]:
# Sanity check: map a Latin transliteration back to (voweled) Hebrew with the en->he model.
en_gen_model.translate("oved", with_start_end=False)

'עוֹבֵד'

In [None]:
from lib.metrics import compute_accuracy

# Validation accuracy for both directions. The previous commented-out call
# referenced an undefined `model`. Argument order (model, src_words, trg_words)
# follows that call and mirrors compute_bleu_score in the BLEU cell.
print("he->en accuracy:", compute_accuracy(he_gen_model, val_src_words, val_trg_words))
print("en->he accuracy:", compute_accuracy(en_gen_model, val_trg_words, val_src_words))

In [32]:
from lib.metrics import compute_bleu_score

# BLEU on the validation words for each direction. The en->he call swaps the
# word lists because en_gen_model translates English -> Hebrew.
he_en_bleu = compute_bleu_score(he_gen_model, val_src_words, val_trg_words)
print("he->en bleu:", he_en_bleu)

en_he_bleu = compute_bleu_score(en_gen_model, val_trg_words, val_src_words)
print("en->he bleu:", en_he_bleu)

Corpus/Sentence contains 0 counts of 3-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().
Corpus/Sentence contains 0 counts of 2-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().



he->en bleu: 0.756867651006124

en->he bleu: 0.643515051192157


In [11]:
from lib.models import BiLSTMDiscriminator

# Discriminator over the English alphabet; it is trained alongside he_gen_model
# (he->en) in the train_discriminator cell.
# NOTE(review): 32 / 128 are presumably embedding / hidden sizes — confirm.
he_en_disc = BiLSTMDiscriminator(en, 32, 128)
he_en_disc_opt = optim.Adam(params=he_en_disc.parameters(), lr=1e-4)

In [11]:
from lib.utils import batch_iterator

# Peek at the first batch drawn from the "single" pool.
# NOTE(review): x and y are not used by any later cell — looks like a leftover
# debugging step; consider removing.
x, y = next(batch_iterator(single_X, single_Y))

In [70]:
# Removed a stray `a` debugging expression: `a` is not assigned in any cell of
# this notebook, so the cell raised NameError on a fresh kernel. Its stale
# output was a 1-element Variable holding -0.6931 (= ln 0.5).

In [17]:
from lib.trainer import train_discriminator

# Train the English-alphabet discriminator together with the he->en generator
# on the "single" pool (Hebrew-side X, English-side Y).
train_discriminator(
    he_en_disc,
    he_gen_model,
    he_en_disc_opt,
    single_X,
    single_Y,
    n_epochs=5,
)

1.2384521961212158
1.3718072175979614
1.3718334436416626
1.3148930072784424
1.2624657154083252
1.2915847301483154
1.2727792263031006
1.1567726135253906
1.3493173122406006
1.2651690244674683
1.2297322319317114
1.3737601041793823
1.2953300476074219
1.2669183015823364
1.2240954637527466
1.212420105934143
1.289250135421753
1.1618685722351074
1.2122358083724976
1.342937707901001
1.3290832042694092
1.234034298357258
1.18357515335083
1.2539525032043457
1.2623798847198486
1.2066926956176758
1.382946252822876
1.2576000690460205
1.3047951459884644
1.3303977251052856
1.1666460037231445
1.1743628978729248
1.2420799208252453
1.226090908050537
1.3146021366119385
1.3053488731384277
1.2869923114776611
1.1714167594909668
1.2247607707977295
1.2718067169189453
1.2853238582611084
1.2806005477905273
1.2753772735595703
1.2550788223872757
1.132434368133545
1.2214456796646118
1.2251617908477783
1.3123784065246582
1.2538001537322998
1.1897834539413452
1.3106296062469482
1.1576265096664429
1.2624542713165283
1.

In [17]:
# FIXME(review): en_he_disc_opt is only created further down, in the
# `en_he_disc = BiLSTMDiscriminator(he, ...)` cell. This line therefore raises
# NameError on Restart & Run All. Going by the execution counts, it originally
# ran after that cell (In [15]) and before the en->he discriminator training
# (In [21]). Move it there.
for param_group in en_he_disc_opt.param_groups:
    param_group['lr'] = 1e-3

In [36]:
# Only a cell's last expression is displayed, so a bare `len(single_X)` line
# was silently discarded. Check the pairing explicitly, then show the size.
assert len(single_X) == len(single_Y), "single_X and single_Y must stay aligned"
len(single_Y)

3336

In [14]:
# Optional: persist the trained he->en discriminator weights. Left disabled —
# presumably so a full re-run doesn't overwrite a saved file; uncomment to save.
#torch.save(he_en_disc.state_dict(), 'checkpoints/he_en_disc.pth')

In [15]:
# Discriminator over the Hebrew alphabet, paired with en_gen_model (en->he).
# Same sizes and learning rate as the he->en discriminator.
en_he_disc = BiLSTMDiscriminator(he, 32, 128)
en_he_disc_opt = optim.Adam(params=en_he_disc.parameters(), lr=1e-4)

In [21]:
# autoreload is already loaded and set to mode 2 by the train_generator cell.
# Repeating `%load_ext autoreload` here only printed
# "The autoreload extension is already loaded", so the magics are dropped.
from lib.trainer import train_discriminator

# Train the Hebrew-alphabet discriminator together with the en->he generator.
# The pool arguments are swapped relative to the he->en run: single_Y
# (English side) first, then single_X (Hebrew side).
train_discriminator(en_he_disc, en_gen_model, en_he_disc_opt, single_Y, single_X, n_epochs=5)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
0.2804229259490967
0.3591843247413635
0.5119056701660156
0.38617441058158875
0.4110684096813202
0.4386439323425293
0.3661508560180664
0.41029417514801025
0.19577345252037048
0.35182929039001465
0.43271239670753664
0.5297037959098816
0.5909342765808105
0.47149455547332764
0.4378165006637573
0.42431938648223877
0.32680660486221313
0.39451146125793457
0.23801937699317932
0.5421937108039856
0.19427376985549927
0.36238883628009205
0.23387694358825684
0.440614253282547
0.5886380672454834
0.41957834362983704
0.27277401089668274
0.5400180220603943
0.5062800645828247
0.5183104872703552
0.4915462136268616
0.3166813850402832
0.3975979069490849
0.3233836889266968
0.19026336073875427
0.19132040441036224
0.2645258903503418
0.3994094133377075
0.5878828763961792
0.20943179726600647
0.21955306828022003
0.3330232799053192
0.6479833126068115
0.40701333767769265
0.29019033908843994
0.697598934173584
0.6357625722885132
