/
train_simple_tf2.py
223 lines (172 loc) · 8.46 KB
/
train_simple_tf2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import tensorflow as tf
import numpy as np
import unicodedata
import re
# Toy parallel corpus. Every entry is an (English, French) sentence pair; the
# English side is the model input and the French side is the translation
# target.
raw_data = (
    ('What a ridiculous concept!', 'Quel concept ridicule !'),
    ('Your idea is not entirely crazy.', "Votre idée n'est pas complètement folle."),
    ("A man's worth lies in what he is.", "La valeur d'un homme réside dans ce qu'il est."),
    ('What he did is very wrong.', "Ce qu'il a fait est très mal."),
    ("All three of you need to do that.", "Vous avez besoin de faire cela, tous les trois."),
    ("Are you giving me another chance?", "Me donnez-vous une autre chance ?"),
    ("Both Tom and Mary work as models.", "Tom et Mary travaillent tous les deux comme mannequins."),
    ("Can I have a few minutes, please?", "Puis-je avoir quelques minutes, je vous prie ?"),
    ("Could you close the door, please?", "Pourriez-vous fermer la porte, s'il vous plaît ?"),
    ("Did you plant pumpkins this year?", "Cette année, avez-vous planté des citrouilles ?"),
    ("Do you ever study in the library?", "Est-ce que vous étudiez à la bibliothèque des fois ?"),
    ("Don't be deceived by appearances.", "Ne vous laissez pas abuser par les apparences."),
    ("Excuse me. Can you speak English?", "Je vous prie de m'excuser ! Savez-vous parler anglais ?"),
    ("Few people know the true meaning.", "Peu de gens savent ce que cela veut réellement dire."),
    ("Germany produced many scientists.", "L'Allemagne a produit beaucoup de scientifiques."),
    ("Guess whose birthday it is today.", "Devine de qui c'est l'anniversaire, aujourd'hui !"),
    ("He acted like he owned the place.", "Il s'est comporté comme s'il possédait l'endroit."),
    ("Honesty will pay in the long run.", "L'honnêteté paye à la longue."),
    ("How do we know this isn't a trap?", "Comment savez-vous qu'il ne s'agit pas d'un piège ?"),
    ("I can't believe you're giving up.", "Je n'arrive pas à croire que vous abandonniez."),
)
def unicode_to_ascii(s):
    """Return ``s`` with its combining marks removed.

    The text is first decomposed with NFD normalization, which splits an
    accented letter into a base letter followed by combining marks. Every
    character whose Unicode category is ``Mn`` (nonspacing mark) is then
    dropped; all other characters are kept unchanged.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)
def normalize_string(s):
    """Clean a sentence so it can be split into tokens on single spaces.

    Accents are removed with ``unicode_to_ascii`` and then three regex
    substitutions are applied in order. Letter case is left untouched and the
    result is not stripped, so it may start or end with a space.
    """
    text = unicode_to_ascii(s)
    substitutions = (
        # Put a space in front of '!', '.' and '?' so each becomes a token.
        (r'([!.?])', r' \1'),
        # Replace every run of other non-letter characters with one space.
        (r'[^a-zA-Z.!?]+', r' '),
        # Collapse the whitespace runs produced by the steps above.
        (r'\s+', r' '),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text
# --- Text preparation ------------------------------------------------------
# Split the corpus into its French and English columns. The French column is
# kept raw; the English column is normalized in place.
raw_data_fr = [fr for _, fr in raw_data]
raw_data_en = [normalize_string(en) for en, _ in raw_data]

# The decoder input starts with '<start>' and the decoder target ends with
# '<end>', so the target is the input shifted by one token.
raw_data_fr_in = ['<start> {}'.format(normalize_string(fr))
                  for fr in raw_data_fr]
raw_data_fr_out = ['{} <end>'.format(normalize_string(fr))
                   for fr in raw_data_fr]

# --- English side: vocabulary and padded id sequences ----------------------
# filters='' disables the tokenizer's own punctuation stripping; the text was
# already cleaned by normalize_string.
en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
en_tokenizer.fit_on_texts(raw_data_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(
    en_tokenizer.texts_to_sequences(raw_data_en), padding='post')

# --- French side: one vocabulary shared by decoder input and target --------
# Fitting on both lists makes sure '<start>' and '<end>' are both indexed.
fr_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
fr_tokenizer.fit_on_texts(raw_data_fr_in)
fr_tokenizer.fit_on_texts(raw_data_fr_out)

data_fr_in = tf.keras.preprocessing.sequence.pad_sequences(
    fr_tokenizer.texts_to_sequences(raw_data_fr_in), padding='post')
data_fr_out = tf.keras.preprocessing.sequence.pad_sequences(
    fr_tokenizer.texts_to_sequences(raw_data_fr_out), padding='post')

# --- Input pipeline --------------------------------------------------------
BATCH_SIZE = 5
dataset = tf.data.Dataset.from_tensor_slices(
    (data_en, data_fr_in, data_fr_out))
dataset = dataset.shuffle(20)
dataset = dataset.batch(BATCH_SIZE)
class Encoder(tf.keras.Model):
    """Embedding layer followed by an LSTM that encodes the source sentence.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_size: Width of each embedding vector.
        lstm_size: Number of LSTM units; also the width of the states built
            by ``init_states``.
    """

    def __init__(self, vocab_size, embedding_size, lstm_size):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        # return_sequences / return_state make the layer return the output of
        # every step plus the final hidden and cell states.
        self.lstm = tf.keras.layers.LSTM(
            lstm_size, return_sequences=True, return_state=True)

    def call(self, sequence, states):
        """Encode ``sequence`` starting from ``states``.

        Returns:
            ``(output, state_h, state_c)`` exactly as produced by the LSTM.
        """
        embedded = self.embedding(sequence)
        step_outputs, hidden, cell = self.lstm(embedded, initial_state=states)
        return step_outputs, hidden, cell

    def init_states(self, batch_size):
        """Return an all-zero ``(hidden, cell)`` pair for ``batch_size`` rows."""
        shape = [batch_size, self.lstm_size]
        return tf.zeros(shape), tf.zeros(shape)
class Decoder(tf.keras.Model):
    """Embedding, LSTM and dense projection that score target-side tokens.

    Args:
        vocab_size: Number of rows in the embedding table and number of
            logits produced per step.
        embedding_size: Width of each embedding vector.
        lstm_size: Number of LSTM units; also the width of the states built
            by ``init_states``.
    """

    def __init__(self, vocab_size, embedding_size, lstm_size):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            lstm_size, return_sequences=True, return_state=True)
        # No activation: the layer emits raw logits over the vocabulary.
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, sequence, state):
        """Run the decoder over ``sequence`` starting from ``state``.

        Returns:
            ``(logits, state_h, state_c)`` where ``logits`` is the dense
            projection of every LSTM step output.
        """
        embedded = self.embedding(sequence)
        step_outputs, hidden, cell = self.lstm(embedded, initial_state=state)
        return self.dense(step_outputs), hidden, cell

    def init_states(self, batch_size):
        """Return an all-zero ``(hidden, cell)`` pair for ``batch_size`` rows."""
        shape = [batch_size, self.lstm_size]
        return tf.zeros(shape), tf.zeros(shape)
# Model hyper-parameters.
EMBEDDING_SIZE = 32
LSTM_SIZE = 64

# One extra slot on top of the tokenizer vocabulary: word indices start at 1
# and id 0 is what pad_sequences fills with.
en_vocab_size = 1 + len(en_tokenizer.word_index)
fr_vocab_size = 1 + len(fr_tokenizer.word_index)

encoder = Encoder(en_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)
decoder = Decoder(fr_vocab_size, EMBEDDING_SIZE, LSTM_SIZE)

# Push one dummy batch through both models. The decoder is seeded with the
# encoder's final (hidden, cell) states, i.e. everything after the first
# element of the encoder's return value.
initial_states = encoder.init_states(1)
encoder_outputs = encoder(tf.constant([[1, 2, 3]]), initial_states)
decoder_outputs = decoder(tf.constant([[1, 2, 3]]), encoder_outputs[1:])
def loss_func(targets, logits):
    """Sparse categorical cross-entropy that ignores positions with id 0.

    Args:
        targets: Integer token ids; entries equal to 0 receive weight 0.
        logits: Unnormalized scores (the loss is built with
            ``from_logits=True``).

    Returns:
        The loss value produced by ``SparseCategoricalCrossentropy`` with the
        0/1 mask passed as ``sample_weight``.
    """
    # NOTE(review): the reduction is the Keras default; confirm whether the
    # average is taken over all positions or only the unmasked ones.
    is_real_token = tf.math.not_equal(targets, 0)
    weights = tf.cast(is_real_token, dtype=tf.int64)
    scce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    return scce(targets, logits, sample_weight=weights)
# Adam optimizer built with no arguments (library defaults); train_step uses
# this single instance to apply gradients to both the encoder and the decoder.
optimizer = tf.keras.optimizers.Adam()
def predict(test_source_text=None, max_len=20):
    """Greedily translate one English sentence and print the result.

    Prints the source text, its token ids and then the generated French
    tokens joined by spaces. Nothing is returned.

    Args:
        test_source_text: Sentence to translate, already passed through
            ``normalize_string``. When ``None``, a random entry of
            ``raw_data_en`` is used.
        max_len: Maximum number of tokens to generate when ``<end>`` is never
            produced. Defaults to 20.
    """
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)

    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)
    if not test_source_seq[0]:
        # None of the input words is in the English vocabulary, so there is
        # nothing to feed the encoder. Report an empty translation instead of
        # pushing an empty sequence through the model.
        print('')
        return

    en_initial_states = encoder.init_states(1)
    en_outputs = encoder(tf.constant(test_source_seq), en_initial_states)

    # Decoding starts from '<start>' and the encoder's final states.
    de_input = tf.constant([[fr_tokenizer.word_index['<start>']]])
    de_state_h, de_state_c = en_outputs[1:]
    out_words = []

    while len(out_words) < max_len:
        de_output, de_state_h, de_state_c = decoder(
            de_input, (de_state_h, de_state_c))
        # Greedy choice; the chosen id is also the next decoder input.
        de_input = tf.argmax(de_output, -1)
        token_id = int(de_input.numpy()[0][0])
        # The decoder also scores id 0 (padding), which has no entry in
        # index_word. An untrained model can pick it, so map it to a
        # placeholder instead of raising KeyError.
        word = fr_tokenizer.index_word.get(token_id, '<pad>')
        out_words.append(word)
        if word == '<end>':
            break

    print(' '.join(out_words))
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out, en_initial_states):
    """Run one optimization step on a single batch and return its loss.

    Args:
        source_seq: Batch of source token ids fed to the encoder.
        target_seq_in: Batch of decoder input token ids.
        target_seq_out: Batch of target token ids the logits are scored on.
        en_initial_states: Initial ``(hidden, cell)`` states for the encoder.

    Returns:
        The loss computed by ``loss_func`` for this batch.
    """
    with tf.GradientTape() as tape:
        # The encoder's final states seed the decoder; its per-step outputs
        # are not used.
        _, enc_hidden, enc_cell = encoder(source_seq, en_initial_states)
        logits, _, _ = decoder(target_seq_in, (enc_hidden, enc_cell))
        loss = loss_func(target_seq_out, logits)

    # Both models are updated together from the same loss.
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss
NUM_EPOCHS = 300

for e in range(NUM_EPOCHS):
    # Print a sample translation before each epoch to eyeball progress.
    predict()

    loss = None
    for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset):
        # Size the zero states from the batch itself. The dataset is batched
        # without drop_remainder, so a final batch may hold fewer than
        # BATCH_SIZE rows and fixed-size states would not match it.
        en_initial_states = encoder.init_states(source_seq.shape[0])
        loss = train_step(source_seq, target_seq_in,
                          target_seq_out, en_initial_states)

    # Report the loss of the last batch of the epoch. The guard skips the
    # report when the dataset yielded no batch at all.
    if loss is not None:
        print('Epoch {} Loss {:.4f}'.format(e + 1, loss.numpy()))
# Sentences to translate once training has finished.
test_sents = (
    'What a ridiculous concept!',
    'Your idea is not entirely crazy.',
    "A man's worth lies in what he is.",
    'What he did is very wrong.',
    "All three of you need to do that.",
    "Are you giving me another chance?",
    "Both Tom and Mary work as models.",
    "Can I have a few minutes, please?",
    "Could you close the door, please?",
    "Did you plant pumpkins this year?",
    "Do you ever study in the library?",
    "Don't be deceived by appearances.",
    "Excuse me. Can you speak English?",
    "Few people know the true meaning.",
    "Germany produced many scientists.",
    "Guess whose birthday it is today.",
    "He acted like he owned the place.",
    "Honesty will pay in the long run.",
    "How do we know this isn't a trap?",
    "I can't believe you're giving up.",
)

# Normalize each sentence the same way the training text was normalized, then
# print its translation.
for test_sent in test_sents:
    test_sequence = normalize_string(test_sent)
    predict(test_sequence)