# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import unicodedata
import re
import matplotlib.pyplot as plt
import os
import imageio
import requests
from zipfile import ZipFile
# Mode can be either 'train' or 'infer'
# Setting it to 'infer' will skip training
MODE = 'train'
URL = 'http://www.manythings.org/anki/fra-eng.zip'
FILENAME = 'fra-eng.zip'
BATCH_SIZE = 64
EMBEDDING_SIZE = 256
RNN_SIZE = 512
NUM_EPOCHS = 15
# Set the score function to compute alignment vectors
# Can choose between 'dot', 'general' or 'concat'
ATTENTION_FUNC = 'concat'
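# The three Luong-style score functions implemented below
# (h_t = current decoder state, h_s = encoder states):
#   dot:     score(h_t, h_s) = h_t . h_s
#   general: score(h_t, h_s) = h_t . (Wa h_s)
#   concat:  score(h_t, h_s) = va . tanh(Wa [h_t; h_s])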
def maybe_download_and_read_file(url, filename):
    # Download the zip archive only if it is not already on disk
    if not os.path.exists(filename):
        session = requests.Session()
        response = session.get(url, stream=True)

        CHUNK_SIZE = 32768
        with open(filename, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:
                    f.write(chunk)

    zipf = ZipFile(filename)
    with zipf.open('fra.txt') as f:
        lines = f.read()

    return lines
lines = maybe_download_and_read_file(URL, FILENAME)
lines = lines.decode('utf-8')
raw_data = []
for line in lines.split('\n'):
    raw_data.append(line.split('\t'))
print(raw_data[-5:])
# The last element is empty, so omit it
raw_data = raw_data[:-1]
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )
def normalize_string(s):
    s = unicode_to_ascii(s)
    s = re.sub(r'([!.?])', r' \1', s)
    s = re.sub(r'[^a-zA-Z.!?]+', r' ', s)
    s = re.sub(r'\s+', r' ', s)
    return s
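# For example, normalize_string('Déjà vu!') yields 'Deja vu !':
# accents are stripped, sentence-ending punctuation gets its own token,
# and any remaining non-letter characters collapse into single spaces.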
raw_data_en, raw_data_fr = list(zip(*raw_data))
raw_data_en = [normalize_string(data) for data in raw_data_en]
raw_data_fr_in = ['<start> ' + normalize_string(data) for data in raw_data_fr]
raw_data_fr_out = [normalize_string(data) + ' <end>' for data in raw_data_fr]
en_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
en_tokenizer.fit_on_texts(raw_data_en)
data_en = en_tokenizer.texts_to_sequences(raw_data_en)
data_en = tf.keras.preprocessing.sequence.pad_sequences(
    data_en, padding='post')
print('English sequences')
print(data_en[:2])
fr_tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
fr_tokenizer.fit_on_texts(raw_data_fr_in)
fr_tokenizer.fit_on_texts(raw_data_fr_out)
data_fr_in = fr_tokenizer.texts_to_sequences(raw_data_fr_in)
data_fr_in = tf.keras.preprocessing.sequence.pad_sequences(
    data_fr_in, padding='post')
print('French input sequences')
print(data_fr_in[:2])
data_fr_out = fr_tokenizer.texts_to_sequences(raw_data_fr_out)
data_fr_out = tf.keras.preprocessing.sequence.pad_sequences(
    data_fr_out, padding='post')
print('French output sequences')
print(data_fr_out[:2])
dataset = tf.data.Dataset.from_tensor_slices(
    (data_en, data_fr_in, data_fr_out))
dataset = dataset.shuffle(len(raw_data_en)).batch(
    BATCH_SIZE, drop_remainder=True)
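# Each batch yields three tensors:
#   data_en     -> encoder input (English)
#   data_fr_in  -> decoder input, prefixed with '<start>' (teacher forcing)
#   data_fr_out -> decoder target, suffixed with '<end>' (shifted by one step)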
class Encoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, rnn_size):
        super(Encoder, self).__init__()
        self.rnn_size = rnn_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            rnn_size, return_sequences=True, return_state=True)

    def call(self, sequence, states):
        embed = self.embedding(sequence)
        output, state_h, state_c = self.lstm(embed, initial_state=states)
        return output, state_h, state_c

    def init_states(self, batch_size):
        return (tf.zeros([batch_size, self.rnn_size]),
                tf.zeros([batch_size, self.rnn_size]))
en_vocab_size = len(en_tokenizer.word_index) + 1
encoder = Encoder(en_vocab_size, EMBEDDING_SIZE, RNN_SIZE)
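# A quick shape check (a minimal sketch; the script builds the models the
# same way further below before training):
#   output, state_h, state_c = encoder(tf.constant([[1, 2, 3]]),
#                                      encoder.init_states(1))
#   output.shape  == (1, 3, RNN_SIZE)   # one hidden state per source token
#   state_h.shape == state_c.shape == (1, RNN_SIZE)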
class LuongAttention(tf.keras.Model):
    def __init__(self, rnn_size, attention_func):
        super(LuongAttention, self).__init__()
        self.attention_func = attention_func

        if attention_func not in ['dot', 'general', 'concat']:
            raise ValueError(
                'Unknown attention score function! Must be either dot, general or concat.')

        if attention_func == 'general':
            # General score function
            self.wa = tf.keras.layers.Dense(rnn_size)
        elif attention_func == 'concat':
            # Concat score function
            self.wa = tf.keras.layers.Dense(rnn_size, activation='tanh')
            self.va = tf.keras.layers.Dense(1)

    def call(self, decoder_output, encoder_output):
        if self.attention_func == 'dot':
            # Dot score function: decoder_output (dot) encoder_output
            # decoder_output has shape: (batch_size, 1, rnn_size)
            # encoder_output has shape: (batch_size, max_len, rnn_size)
            # => score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, encoder_output, transpose_b=True)
        elif self.attention_func == 'general':
            # General score function: decoder_output (dot) (Wa (dot) encoder_output)
            # decoder_output has shape: (batch_size, 1, rnn_size)
            # encoder_output has shape: (batch_size, max_len, rnn_size)
            # => score has shape: (batch_size, 1, max_len)
            score = tf.matmul(decoder_output, self.wa(
                encoder_output), transpose_b=True)
        elif self.attention_func == 'concat':
            # Concat score function: va (dot) tanh(Wa (dot) concat(decoder_output, encoder_output))
            # Decoder output must be broadcast to encoder output's shape first
            decoder_output = tf.tile(
                decoder_output, [1, encoder_output.shape[1], 1])

            # Concat => Wa => va
            # (batch_size, max_len, 2 * rnn_size) => (batch_size, max_len, rnn_size) => (batch_size, max_len, 1)
            score = self.va(
                self.wa(tf.concat((decoder_output, encoder_output), axis=-1)))

            # Transpose the score vector to have the same shape as the other two above
            # (batch_size, max_len, 1) => (batch_size, 1, max_len)
            score = tf.transpose(score, [0, 2, 1])

        # alignment a_t = softmax(score)
        alignment = tf.nn.softmax(score, axis=2)

        # context vector c_t is the attention-weighted sum of the encoder output
        context = tf.matmul(alignment, encoder_output)

        return context, alignment
class Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embedding_size, rnn_size, attention_func):
        super(Decoder, self).__init__()
        self.attention = LuongAttention(rnn_size, attention_func)
        self.rnn_size = rnn_size
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_size)
        self.lstm = tf.keras.layers.LSTM(
            rnn_size, return_sequences=True, return_state=True)
        self.wc = tf.keras.layers.Dense(rnn_size, activation='tanh')
        self.ws = tf.keras.layers.Dense(vocab_size)

    def call(self, sequence, state, encoder_output):
        # Remember that the input to the decoder
        # is now a batch of one-word sequences,
        # which means that its shape is (batch_size, 1)
        embed = self.embedding(sequence)

        # Therefore, lstm_out has shape (batch_size, 1, rnn_size)
        lstm_out, state_h, state_c = self.lstm(embed, initial_state=state)

        # Use self.attention to compute the context and alignment vectors
        # context vector's shape: (batch_size, 1, rnn_size)
        # alignment vector's shape: (batch_size, 1, source_length)
        context, alignment = self.attention(lstm_out, encoder_output)

        # Combine the context vector and the LSTM output.
        # Before combining, both have shape (batch_size, 1, rnn_size),
        # so let's squeeze axis 1 first.
        # After concatenation, the result has shape (batch_size, 2 * rnn_size)
        lstm_out = tf.concat(
            [tf.squeeze(context, 1), tf.squeeze(lstm_out, 1)], 1)

        # wc projects the combined vector back down to (batch_size, rnn_size)
        lstm_out = self.wc(lstm_out)

        # Finally, it is converted back to vocabulary space: (batch_size, vocab_size)
        logits = self.ws(lstm_out)

        return logits, state_h, state_c, alignment
fr_vocab_size = len(fr_tokenizer.word_index) + 1
decoder = Decoder(fr_vocab_size, EMBEDDING_SIZE, RNN_SIZE, ATTENTION_FUNC)
# These lines can be used for debugging purposes,
# or as a way to build the models (i.e. create their weights)
initial_state = encoder.init_states(1)
encoder_outputs = encoder(tf.constant([[1]]), initial_state)
decoder_outputs = decoder(tf.constant(
    [[1]]), encoder_outputs[1:], encoder_outputs[0])
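# After this forward pass both models have built their weights. For the
# single-token batch of size 1 above:
#   encoder_outputs[0] has shape (1, 1, RNN_SIZE)
#   encoder_outputs[1] and encoder_outputs[2] (state_h, state_c) have shape (1, RNN_SIZE)
#   decoder_outputs[0] (the logits) has shape (1, fr_vocab_size)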
def loss_func(targets, logits):
    crossentropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)
    # Mask out padded positions (token index 0) so they
    # do not contribute to the loss
    mask = tf.math.logical_not(tf.math.equal(targets, 0))
    mask = tf.cast(mask, dtype=tf.int64)
    loss = crossentropy(targets, logits, sample_weight=mask)

    return loss
optimizer = tf.keras.optimizers.Adam(clipnorm=5.0)
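# clipnorm=5.0 clips each gradient tensor to an L2 norm of at most 5,
# which helps keep the recurrent training stable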
def predict(test_source_text=None):
    if test_source_text is None:
        test_source_text = raw_data_en[np.random.choice(len(raw_data_en))]
    print(test_source_text)
    test_source_seq = en_tokenizer.texts_to_sequences([test_source_text])
    print(test_source_seq)

    en_initial_states = encoder.init_states(1)
    en_outputs = encoder(tf.constant(test_source_seq), en_initial_states)

    de_input = tf.constant([[fr_tokenizer.word_index['<start>']]])
    de_state_h, de_state_c = en_outputs[1:]
    out_words = []
    alignments = []

    while True:
        de_output, de_state_h, de_state_c, alignment = decoder(
            de_input, (de_state_h, de_state_c), en_outputs[0])
        de_input = tf.expand_dims(tf.argmax(de_output, -1), 0)
        out_words.append(fr_tokenizer.index_word[de_input.numpy()[0][0]])
        alignments.append(alignment.numpy())

        if out_words[-1] == '<end>' or len(out_words) >= 20:
            break

    print(' '.join(out_words))
    return np.array(alignments), test_source_text.split(' '), out_words
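# predict() decodes greedily: at every step the argmax token is fed back in
# as the next decoder input, and decoding stops at '<end>' or after 20 tokens.
# The returned alignments are used later to plot the attention heatmaps.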
@tf.function
def train_step(source_seq, target_seq_in, target_seq_out, en_initial_states):
    loss = 0
    with tf.GradientTape() as tape:
        en_outputs = encoder(source_seq, en_initial_states)
        en_states = en_outputs[1:]
        de_state_h, de_state_c = en_states

        # We need a loop to iterate through the target sequences
        for i in range(target_seq_out.shape[1]):
            # Input to the decoder must have shape (batch_size, length),
            # so we need to expand one dimension
            decoder_in = tf.expand_dims(target_seq_in[:, i], 1)
            logit, de_state_h, de_state_c, _ = decoder(
                decoder_in, (de_state_h, de_state_c), en_outputs[0])

            # The loss is accumulated over the decoding time steps
            loss += loss_func(target_seq_out[:, i], logit)

    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return loss / target_seq_out.shape[1]
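# train_step uses teacher forcing: at every step the decoder is fed the
# ground-truth previous token (target_seq_in[:, i]) rather than its own
# prediction, and the returned loss is the average loss per decoding step.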
if not os.path.exists('checkpoints_luong/encoder'):
    os.makedirs('checkpoints_luong/encoder')
if not os.path.exists('checkpoints_luong/decoder'):
    os.makedirs('checkpoints_luong/decoder')

# Restore the latest saved weights, if any exist
# (this is what makes MODE = 'infer' work without retraining)
encoder_checkpoint = tf.train.latest_checkpoint('checkpoints_luong/encoder')
decoder_checkpoint = tf.train.latest_checkpoint('checkpoints_luong/decoder')

if encoder_checkpoint is not None and decoder_checkpoint is not None:
    encoder.load_weights(encoder_checkpoint)
    decoder.load_weights(decoder_checkpoint)
if MODE == 'train':
    for e in range(NUM_EPOCHS):
        en_initial_states = encoder.init_states(BATCH_SIZE)

        for batch, (source_seq, target_seq_in, target_seq_out) in enumerate(dataset.take(-1)):
            loss = train_step(source_seq, target_seq_in,
                              target_seq_out, en_initial_states)

            if batch % 100 == 0:
                print('Epoch {} Batch {} Loss {:.4f}'.format(
                    e + 1, batch, loss.numpy()))

        # Save the weights at the end of every epoch, in the TF checkpoint
        # format (no '.h5' suffix) so tf.train.latest_checkpoint can find them
        encoder.save_weights(
            'checkpoints_luong/encoder/encoder_{}'.format(e + 1))
        decoder.save_weights(
            'checkpoints_luong/decoder/decoder_{}'.format(e + 1))

        try:
            predict()
            predict("How are you today ?")
        except Exception:
            continue
if not os.path.exists('heatmap'):
    os.makedirs('heatmap')
test_sents = (
    'What a ridiculous concept!',
    'Your idea is not entirely crazy.',
    "A man's worth lies in what he is.",
    'What he did is very wrong.',
    "All three of you need to do that.",
    "Are you giving me another chance?",
    "Both Tom and Mary work as models.",
    "Can I have a few minutes, please?",
    "Could you close the door, please?",
    "Did you plant pumpkins this year?",
    "Do you ever study in the library?",
    "Don't be deceived by appearances.",
    "Excuse me. Can you speak English?",
    "Few people know the true meaning.",
    "Germany produced many scientists.",
    "Guess whose birthday it is today.",
    "He acted like he owned the place.",
    "Honesty will pay in the long run.",
    "How do we know this isn't a trap?",
    "I can't believe you're giving up.",
)
filenames = []
for i, test_sent in enumerate(test_sents):
    test_sequence = normalize_string(test_sent)
    alignments, source, prediction = predict(test_sequence)
    attention = np.squeeze(alignments, (1, 2))

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    ax.matshow(attention, cmap='jet')
    ax.set_xticklabels([''] + source, rotation=90)
    ax.set_yticklabels([''] + prediction)

    filenames.append('heatmap/test_{}.png'.format(i))
    plt.savefig('heatmap/test_{}.png'.format(i))
    plt.close()
with imageio.get_writer('translation_heatmaps.gif', mode='I', duration=2) as writer:
    for filename in filenames:
        image = imageio.imread(filename)
        writer.append_data(image)