In [16]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from multi_head import MultiHead

In [8]:
# Symbolic model input: one length-100 sequence per sample (the batch axis is implicit).
input_layer = tf.keras.layers.Input(shape=(100,))

In [9]:
# Embed every position into a 50-dim vector (Embedding input_dim=1000, output_dim=50),
# giving a (None, 100, 50) tensor.
embeddings_layer = Embedding(1000, 50)(input_layer)

In [10]:
# Display the symbolic tensor to check its static shape.
embeddings_layer

<tf.Tensor 'embedding_1/Identity:0' shape=(None, 100, 50) dtype=float32>

In [11]:
# Split the embedding along the feature axis (axis=2) into two equal halves.
# tf.split is wrapped in a Lambda so the op becomes a Keras layer; the result is a list of two tensors.
splits = tf.keras.layers.Lambda(lambda x: tf.split(x, num_or_size_splits=2, axis=2))(embeddings_layer)

In [12]:
# Display both halves to confirm each carries half of the embedding features.
splits

[<tf.Tensor 'lambda_1/Identity:0' shape=(None, 100, 25) dtype=float32>,
 <tf.Tensor 'lambda_1/Identity_1:0' shape=(None, 100, 25) dtype=float32>]

In [13]:
# Encode each embedding half with its own (unshared) sequence-returning BiLSTM.
# The generator builds and applies the layers in the same order as two explicit statements would.
x1, x2 = (
    Bidirectional(LSTM(128, return_sequences=True))(half)
    for half in splits
)

In [17]:
# Wrap an (uncalled) BiLSTM in MultiHead with 3 heads and apply it to the first embedding half.
# MultiHead comes from the `multi_head` module, which is not shown here.
multi_head_x1 = MultiHead(Bidirectional(LSTM(128, return_sequences=True)), 3)(splits[0])

In [19]:
# Dense acts on the last axis only, so this mixes the per-head outputs down to a single
# channel; the result keeps a trailing axis of size 1.
dense_layer = Dense(1)(multi_head_x1)

In [22]:
# Display the tensor to confirm the trailing singleton axis.
dense_layer

<tf.Tensor 'dense/Identity:0' shape=(None, 100, 256, 1) dtype=float32>

In [23]:
# Drop the singleton last axis left by Dense(1): (None, 100, 256, 1) -> (None, 100, 256).
# The result is only displayed here; it is not assigned to a name.
tf.squeeze(dense_layer,axis=3)

<tf.Tensor 'Squeeze_1:0' shape=(None, 100, 256) dtype=float32>

In [27]:
# Join the two BiLSTM outputs with Keras' `concatenate` (default axis is the last one).
conct_x = concatenate([x1,x2])

In [28]:
# Fuse the concatenated streams with a BiGRU that still returns the full sequence.
conct_layer = Bidirectional(GRU(128, return_sequences=True))(conct_x)

In [30]:
# Collapse the time axis by taking the per-feature maximum, leaving one vector per sample.
max_pool_layer = GlobalMaxPooling1D()(conct_layer)

In [2]:
def build_model(max_seq_len, n_words=1000, embed_size=50, rnn_units=128):
    """Build the baseline split-embedding encoder (BiLSTM x2 -> BiGRU -> max-pool).

    The embedding is split into two equal halves along the feature axis; each
    half is encoded by its own bidirectional LSTM, the two sequences are
    concatenated, passed through a bidirectional GRU and max-pooled over time.

    Args:
        max_seq_len: Length of every input sequence (the model input has
            shape ``(max_seq_len,)`` per sample).
        n_words: ``input_dim`` of the Embedding layer.
        embed_size: ``output_dim`` of the Embedding layer. Must be even,
            because the embedding is split into two equal halves.
        rnn_units: Units per direction for the LSTM and GRU layers. The
            default of 128 reproduces the original hard-coded architecture.

    Returns:
        An uncompiled ``Model`` mapping the input sequence to the max-pooled
        BiGRU output (one vector per sample).

    Raises:
        ValueError: If ``embed_size`` is not even.
    """
    # Fail early with a clear message instead of an opaque shape error raised
    # from inside the Lambda/tf.split layer.
    if embed_size % 2 != 0:
        raise ValueError(
            "embed_size must be even so the embedding can be split into two "
            "equal halves, got {}".format(embed_size)
        )

    input_layer = tf.keras.layers.Input(shape=(max_seq_len,))
    embeddings_layer = Embedding(n_words, embed_size)(input_layer)
    # tf.split is wrapped in a Lambda so it becomes a Keras layer; it returns
    # a list with the two halves of the feature axis.
    splits = tf.keras.layers.Lambda(
        lambda x: tf.split(x, num_or_size_splits=2, axis=2)
    )(embeddings_layer)

    # Each half gets its own (unshared) sequence-returning BiLSTM.
    x1 = Bidirectional(LSTM(rnn_units, return_sequences=True))(splits[0])
    x2 = Bidirectional(LSTM(rnn_units, return_sequences=True))(splits[1])

    conct_x = concatenate([x1, x2])
    conct_layer = Bidirectional(GRU(rnn_units, return_sequences=True))(conct_x)
    # Max over the time axis yields a fixed-size vector per sample.
    max_pool_layer = GlobalMaxPooling1D()(conct_layer)

    return Model(input_layer, max_pool_layer)
    
    

In [26]:
def build_model_multihead(max_seq_len, n_words=1000, embed_size=50, n_head=3, rnn_units=128):
    """Build the multi-head variant of the split-embedding encoder.

    Same overall layout as the baseline (split embedding -> two branches ->
    concatenate -> BiGRU -> max-pool over time), but each branch runs
    ``n_head`` parallel BiLSTMs via ``MultiHead`` and mixes the heads down to
    one sequence with a ``Dense(1)`` over the head axis.

    Args:
        max_seq_len: Length of every input sequence (the model input has
            shape ``(max_seq_len,)`` per sample).
        n_words: ``input_dim`` of the Embedding layer.
        embed_size: ``output_dim`` of the Embedding layer. Must be even,
            because the embedding is split into two equal halves.
        n_head: Number of heads passed to ``MultiHead`` in each branch.
        rnn_units: Units per direction for the LSTM and GRU layers. The
            default of 128 reproduces the original hard-coded architecture.

    Returns:
        An uncompiled ``Model`` mapping the input sequence to the max-pooled
        BiGRU output (one vector per sample).

    Raises:
        ValueError: If ``embed_size`` is not even.
    """
    # Fail early with a clear message instead of an opaque shape error raised
    # from inside the Lambda/tf.split layer.
    if embed_size % 2 != 0:
        raise ValueError(
            "embed_size must be even so the embedding can be split into two "
            "equal halves, got {}".format(embed_size)
        )

    def multi_head_branch(branch_input):
        """Encode one embedding half with multi-head BiLSTMs merged into one sequence."""
        # MultiHead is given the *uncalled* layer and applies its own copies;
        # the heads end up stacked on a trailing axis.
        heads = MultiHead(
            Bidirectional(LSTM(rnn_units, return_sequences=True)), n_head
        )(branch_input)
        # Dense works on the last axis only, so this is a learned mix of the
        # heads that leaves a singleton trailing axis.
        mixed = Dense(1)(heads)
        # Remove that singleton axis. The op is wrapped in a Lambda (as
        # tf.split is below) so it is a proper Keras layer rather than a raw
        # TF op applied to a symbolic Keras tensor.
        return tf.keras.layers.Lambda(lambda t: tf.squeeze(t, axis=3))(mixed)

    input_layer = tf.keras.layers.Input(shape=(max_seq_len,))
    embeddings_layer = Embedding(n_words, embed_size)(input_layer)
    splits = tf.keras.layers.Lambda(
        lambda x: tf.split(x, num_or_size_splits=2, axis=2)
    )(embeddings_layer)

    # Branch order is kept (first half, then second half) so layers are
    # created in the same sequence as before.
    multi_head_x1_squeeze = multi_head_branch(splits[0])
    multi_head_x2_squeeze = multi_head_branch(splits[1])

    conct_x = concatenate([multi_head_x1_squeeze, multi_head_x2_squeeze])
    conct_layer = Bidirectional(GRU(rnn_units, return_sequences=True))(conct_x)
    # Max over the time axis yields a fixed-size vector per sample.
    max_pool_layer = GlobalMaxPooling1D()(conct_layer)

    return Model(input_layer, max_pool_layer)
    
    

In [3]:
# Instantiate the baseline model for sequences of length 120 (default vocabulary/embedding sizes).
model = build_model(120)

In [4]:
# Print the layer table (output shapes and parameter counts) for the baseline model.
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 120)]        0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 120, 50)      50000       input_1[0][0]                    
__________________________________________________________________________________________________
lambda (Lambda)                 [(None, 120, 25), (N 0           embedding[0][0]                  
__________________________________________________________________________________________________
bidirectional (Bidirectional)   (None, 120, 256)     157696      lambda[0][0]                     
______________________________________________________________________________________________

In [27]:
# Rebuild `model` as the multi-head variant (this overwrites the baseline model bound to
# the same name) and print its layer table for comparison.
model = build_model_multihead(120)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 120)]        0                                            
__________________________________________________________________________________________________
embedding_3 (Embedding)         (None, 120, 50)      50000       input_4[0][0]                    
__________________________________________________________________________________________________
lambda_3 (Lambda)               [(None, 120, 25), (N 0           embedding_3[0][0]                
__________________________________________________________________________________________________
multi_head_2 (MultiHead)        (None, 120, 256, 3)  473088      lambda_3[0][0]                   
____________________________________________________________________________________________

In [30]:
# NOTE(review): appears to be the total parameter count of the multi-head model divided by
# that of the baseline model (~1.73x); the totals are not visible in the truncated
# summaries, so confirm against the full `model.summary()` outputs.
1489240/858448

1.7348051367118333