# Attention Mechanism Demo on Keras: Machine Translation Example (Many-to-Many, encoder-decoder)

In this demo, we show how to build a machine translator with Keras. It is inspired by the programming assignment "Neural Machine Translation with Attention" from Andrew Ng's deeplearning.ai course on sequence models. The translator converts dates written in various formats into dates in ISO format.

In [None]:
%matplotlib inline

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

import tensorflow as tf
print(tf.__version__)

from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.layers import Bidirectional, Concatenate, Permute, Dot, Input, LSTM, Multiply
from tensorflow.keras.layers import RepeatVector, Dense, Activation, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model, Model
import tensorflow.keras.backend as K
import numpy as np

import random


Colab only includes TensorFlow 2.x; %tensorflow_version has no effect.
2.11.0


## Generate Dataset
We generate a toy dataset using the `datetime` library. Every target output uses a single format (ISO format), while each input is written in one of three different date formats.

In [None]:
# Generate the toy dataset: 15000 consecutive calendar days, newest first.
import datetime

# Today's calendar date (no time-of-day part) is the most recent sample.
base = datetime.date.today()
date_list = [base - datetime.timedelta(days=offset) for offset in range(15000)]

In [None]:
# Targets: every date rendered with date.isoformat(), e.g. "2023-03-30".
target_date_list = [day.isoformat() for day in date_list]
print(target_date_list[0])

2023-03-30


In [None]:
from random import randint

random.seed(42)

# One strftime pattern per input style; the random draw below indexes this tuple.
input_formats = (
    "%d/%m/%y",     # "11/03/02"
    "%A %d %B %Y",  # "Monday 11 March 2002"
    "%d %B %Y",     # "11 March 2002"
)

# Draw exactly one random style per date, in dataset order.
input_date_list = [
    day.strftime(input_formats[randint(0, 2)]) for day in date_list
]

In [None]:
# Preview the first ten (input, target) pairs side by side.
preview_pairs = zip(input_date_list[:10], target_date_list[:10])
for raw_text, iso_text in preview_pairs:
    print(raw_text, iso_text)

30 March 2023 2023-03-30
29/03/23 2023-03-29
28/03/23 2023-03-28
27 March 2023 2023-03-27
Sunday 26 March 2023 2023-03-26
25/03/23 2023-03-25
24/03/23 2023-03-24
23/03/23 2023-03-23
22 March 2023 2023-03-22
21/03/23 2023-03-21


In [None]:
# Preprocessing: collect the distinct characters seen on each side.
input_chars = list(set("".join(input_date_list)))
output_chars = list(set("".join(target_date_list)))

# Each vocabulary size reserves one extra slot for the padding symbol.
data_size = len(input_date_list)
vocab_size = len(input_chars) + 1
output_vocab_size = len(output_chars) + 1

# Note: the character count printed here is vocab_size, i.e. it includes the padding slot.
print('There are %d lines and %d unique characters in your input data.' % (data_size, vocab_size))
# Length of the longest input string.
maxlen = max(map(len, input_date_list))

There are 15000 lines and 42 unique characters in your input data.


In [None]:
# Report the length of the longest input string (used as the padded input length).
print("Max input length:", maxlen)

Max input length: 27


In [None]:
# Index 0 is reserved for the padding symbol; real characters follow in sorted order.
sorted_chars = ["<PAD>"] + sorted(input_chars)  # input vocabulary
sorted_output_chars = ["<PAD>"] + sorted(output_chars)  # output vocabulary

# Input side: index -> character, then the reverse lookup
ix_to_char = dict(enumerate(sorted_chars))
char_to_ix = {ch: ix for ix, ch in ix_to_char.items()}

# Output side: index -> character, then the reverse lookup
ix_to_output_char = dict(enumerate(sorted_output_chars))
output_char_to_ix = {ch: ix for ix, ch in ix_to_output_char.items()}

print(ix_to_char)
print(ix_to_output_char)

{0: '<PAD>', 1: ' ', 2: '/', 3: '0', 4: '1', 5: '2', 6: '3', 7: '4', 8: '5', 9: '6', 10: '7', 11: '8', 12: '9', 13: 'A', 14: 'D', 15: 'F', 16: 'J', 17: 'M', 18: 'N', 19: 'O', 20: 'S', 21: 'T', 22: 'W', 23: 'a', 24: 'b', 25: 'c', 26: 'd', 27: 'e', 28: 'g', 29: 'h', 30: 'i', 31: 'l', 32: 'm', 33: 'n', 34: 'o', 35: 'p', 36: 'r', 37: 's', 38: 't', 39: 'u', 40: 'v', 41: 'y'}
{0: '<PAD>', 1: '-', 2: '0', 3: '1', 4: '2', 5: '3', 6: '4', 7: '5', 8: '6', 9: '7', 10: '8', 11: '9'}


In [None]:
# Derive the dimensions from the data instead of repeating literal values,
# so they cannot drift out of sync with the generated dataset.
m = data_size  # number of samples
Tx = maxlen  # input sequence length (longest input string)
Ty = len(target_date_list[0])  # output sequence length (length of an ISO date string)

In [None]:
# Encode every string as a list of vocabulary indices.
X = [[char_to_ix[ch] for ch in line] for line in input_date_list]
Y = [[output_char_to_ix[ch] for ch in line] for line in target_date_list]

# Inputs vary in length, so pad them to maxlen with index 0 (<PAD>).
# pad_sequences pads at the front by default. The targets are not padded.
X = pad_sequences(X, maxlen=maxlen)

# One-hot encode both sides. The reshapes pin the expected shapes:
#   X -> (data_size, maxlen, vocab_size)
#   Y -> (data_size, Ty, output_vocab_size)
X = to_categorical(X, vocab_size)
X = X.reshape(data_size, maxlen, vocab_size)

Y = to_categorical(Y, output_vocab_size)
Y = Y.reshape(data_size, Ty, output_vocab_size)  # Ty rather than a literal 10
print(X.shape, Y.shape)

(15000, 27, 42) (15000, 10, 12)


# Attention Mechanism
--> https://drive.google.com/file/d/1xY2_yGARtR8MDw231j7OmH-zCB4XiOkl/view?usp=share_link 

In [None]:
from tensorflow.keras.activations import softmax
def softMaxAxis1(x):
    """Apply softmax along axis 1 of ``x`` and return the result."""
    normalised = softmax(x, axis=1)
    return normalised

In [None]:
# Shared module-level layers: each is created once here and reused on every
# call to one_step_attention.
repeator = RepeatVector(n=Tx)
concatenator = Concatenate(axis=-1)
# Attention scoring function: two stacked Dense layers
fattn_1 = Dense(units=10, activation="tanh")
fattn_2 = Dense(units=1, activation="relu")
# Softmax over axis 1 turns the energies into attention scores
activator = Activation(softMaxAxis1, name="attention_scores")
dotor = Dot(axes=1)

In [None]:
def one_step_attention(a, s_prev):
    """
    Compute the attention context vector for one decoding step.

    Arguments:
    a -- encoder hidden states
    s_prev -- previous hidden state of the decoder

    Returns:
    context -- combination of the encoder hidden states weighted by the attention scores
    """
    # Repeat the decoder state Tx times so it can be concatenated with the
    # encoder hidden states.
    s_tiled = repeator(s_prev)
    paired = concatenator([a, s_tiled])

    # Attention function: two Dense layers applied back to back.
    energies = fattn_2(fattn_1(paired))

    # Softmax over axis 1 converts the energies into attention scores.
    attention_scores = activator(energies)

    # Context vector: dot product of the scores and encoder states along axis 1.
    return dotor([attention_scores, a])

# The model
--> https://drive.google.com/file/d/1dcBMZG_fxfawQChmM6b8OsWtX7jR6cI9/view?usp=share_link

In [None]:
n_h = 32  # hidden units per direction in the encoder Bi-LSTM
n_s = 64  # hidden units in the decoder LSTM
# No input_shape argument: the layer is called on a Keras Input tensor inside
# model(), and that tensor already defines the input shape (Tx, vocab_size).
encoder_LSTM = Bidirectional(LSTM(n_h, return_sequences=True))
# return_state=True makes each call return three tensors, unpacked in model()
decoder_LSTM_cell = LSTM(n_s, return_state=True)
output_layer = Dense(output_vocab_size, activation="softmax")  # softmax output layer

In [None]:
def model(Tx, Ty, n_h, n_s, vocab_size, machine_vocab_size):
    """
    Assemble the attention-based encoder-decoder model.

    Arguments:
    Tx -- length of the input sequence
    Ty -- length of the output sequence
    n_h -- hidden state size of the Bi-LSTM (not read in this function; the
           shared encoder_LSTM layer is created at module level)
    n_s -- hidden state size of the post-attention LSTM
    vocab_size -- size of the input vocab
    machine_vocab_size -- size of the output vocab (not read in this function;
                          the shared output_layer is created at module level)

    Returns:
    model -- Keras model instance with inputs [X, s0, c0] and a list of Ty outputs
    """
    # Model inputs: the source sequence plus the decoder's initial hidden
    # state and cell state.
    source = Input(shape=(Tx, vocab_size))
    s0 = Input(shape=(n_s,), name='s0')
    c0 = Input(shape=(n_s,), name='c0')

    # Encoder Bi-LSTM over the whole source sequence.
    h = encoder_LSTM(source)

    # Decode for Ty steps, carrying the decoder state from step to step.
    s, c = s0, c0
    outputs = []
    for _step in range(Ty):
        # Context vector for this step from the attention mechanism.
        context = one_step_attention(h, s)

        # Feed the context vector to the decoder LSTM, starting from the
        # previous hidden and cell states.
        s, _, c = decoder_LSTM_cell(context, initial_state=[s, c])

        # Softmax output layer applied to the decoder hidden output.
        outputs.append(output_layer(s))

    return Model(inputs=[source, s0, c0], outputs=outputs)

In [None]:
# NOTE: this rebinds the name `model` from the builder function to the Keras
# model it returns, so the builder is no longer reachable under that name
# after this line runs.
model = model(Tx, Ty, n_h, n_s, vocab_size, output_vocab_size)

In [None]:
# Print the layer-by-layer summary (output shapes, parameter counts, connections).
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 27, 42)]     0           []                               
                                                                                                  
 s0 (InputLayer)                [(None, 64)]         0           []                               
                                                                                                  
 bidirectional (Bidirectional)  (None, 27, 64)       19200       ['input_1[0][0]']                
                                                                                                  
 repeat_vector (RepeatVector)   (None, 27, 64)       0           ['s0[0][0]',                     
                                                                  'lstm_1[0][0]',             

In [None]:
# Pass the step size as `learning_rate`: `lr` is a deprecated alias, and newer
# Keras optimizers ignore it with a warning or reject it, so the intended
# 0.01 would not take effect.
opt = Adam(learning_rate=0.01, clipvalue=0.5)
# The single loss is applied to each of the model's outputs.
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])



In [None]:
# Zero initial hidden and cell states for the decoder, one row per training sample.
s0 = np.zeros((m, n_s))
c0 = np.zeros_like(s0)
# The model has a list of outputs, so swap Y's first two axes and split it
# into a list along the new leading axis.
outputs = list(np.swapaxes(Y, 0, 1))

In [None]:
# Train for 20 epochs in batches of 120; `outputs` is the list of target arrays
# matching the model's list of outputs.
model.fit([X, s0, c0], outputs, epochs=20, batch_size=120)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.callbacks.History at 0x7f8e50040520>

# Let's do some "translation"

In [None]:
def prep_input(input_list):
    """
    Encode raw date strings into the one-hot format the model takes as input.

    Arguments:
    input_list -- list of strings; every character must be a key of char_to_ix

    Returns:
    X -- array of shape (len(input_list), maxlen, vocab_size)
    """
    # Characters -> vocabulary indices
    encoded = [[char_to_ix[ch] for ch in text] for text in input_list]
    # Bring every sequence to length maxlen
    padded = pad_sequences(encoded, maxlen=maxlen)
    # One-hot encode, then pin the final shape
    one_hot = to_categorical(padded, vocab_size)
    return one_hot.reshape(len(input_list), maxlen, vocab_size)

EXAMPLES = ['3 May 1999', '05 October 2009', '30 August 2016', '11 July 2000', 'Saturday 19 May 2018', '3 March 2001', '1 March 2001']

# Use inference-only names. Rebinding s0/c0 here would replace the
# training-sized state arrays (breaking a re-run of model.fit), and rebinding
# EXAMPLES would discard the readable input strings.
example_inputs = prep_input(EXAMPLES)
s0_examples = np.zeros((len(EXAMPLES), n_s))
c0_examples = np.zeros((len(EXAMPLES), n_s))

# The model has one output per decoding step. Swap the first two axes so the
# examples come first, then pick the most likely character index per step.
prediction = model.predict([example_inputs, s0_examples, c0_examples])
prediction = np.swapaxes(prediction, 0, 1)
prediction = np.argmax(prediction, axis=-1)

# Map each row of character indices back to a string.
for char_indices in prediction:
    output = "".join(ix_to_output_char[int(i)] for i in char_indices)
    print(output)

1999-05-23
2009-10-06
2016-08-33
2000-07-11
2018-05-19
2001-03-33
2001-03-11
