In [5]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Bidirectional, LSTM, Dense, TimeDistributed, Dropout, concatenate
from tensorflow.keras.models import Sequential
import numpy as np


In [10]:
# --- Sequence dimensions -------------------------------------------------
# n_channels: feature width of each time step fed to the Input layer.
# n_dipoles:  number of units produced per time step by the output layer.
n_channels, n_dipoles = 61, 1284
# NOTE(review): n_time is not used by the model definition below (the Input
# has a variable-length time axis) — confirm whether later cells rely on it.
n_time = 100

# --- Layer hyper-parameters ----------------------------------------------
# Width of the FC1/FC2 dense layers and units per LSTM direction
# (Bidirectional doubles the LSTM output width).
n_dense_units, n_lstm_units = 300, 100
activation_function = 'relu'  # activation for FC1 and FC2
dropout = 0.1                 # rate for the Dropout layers and the LSTM input dropout

In [11]:
# Context network:
#   per-time-step dense encoder (FC1)
#   -> bidirectional LSTM over the sequence
#   -> skip connection concatenating the LSTM output with the FC1 features
#   -> dense layer (FC2) -> linear output layer with n_dipoles units per step.
# The time axis is None, so sequences of any length are accepted.
inputs = tf.keras.Input(shape=(None, n_channels), name='Input')

# FC1: the same Dense layer applied independently to every time step.
fc1 = TimeDistributed(
    Dense(n_dense_units, activation=activation_function),
    name='FC1')(inputs)
fc1 = Dropout(dropout, name='Dropout1')(fc1)

# LSTM path: bidirectional, so the output width is 2 * n_lstm_units.
# No `input_shape` is passed: in the functional API the shape is inferred
# from `fc1`, and the argument is ignored on non-input layers (newer Keras
# versions warn about it).
# `dropout=` drops LSTM *inputs*, i.e. it stacks on top of Dropout1 above.
lstm1 = Bidirectional(
    LSTM(n_lstm_units, return_sequences=True, dropout=dropout),
    name='LSTM')(fc1)

# Skip connection: join LSTM output and FC1 features along the last axis.
concat = concatenate([lstm1, fc1], name='Concat')

# FC2: per-time-step dense layer over the concatenated features.
fc2 = TimeDistributed(
    Dense(n_dense_units, activation=activation_function),
    name='FC2')(concat)
fc2 = Dropout(dropout, name='Dropout2')(fc2)

# Output: linear activation, n_dipoles values per time step.
output = TimeDistributed(
    Dense(n_dipoles, activation='linear'),
    name='output')(fc2)

model = tf.keras.Model(inputs=inputs, outputs=output, name='context_net')

model.summary()


Model: "context_net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Input (InputLayer)              [(None, None, 61)]   0                                            
__________________________________________________________________________________________________
FC1 (TimeDistributed)           (None, None, 300)    18600       Input[0][0]                      
__________________________________________________________________________________________________
Dropout1 (Dropout)              (None, None, 300)    0           FC1[0][0]                        
__________________________________________________________________________________________________
LSTM (Bidirectional)            (None, None, 200)    320800      Dropout1[0][0]                   
________________________________________________________________________________________