In [1]:
from keras.api.layers import Input, Dense, Flatten, MultiHeadAttention, Concatenate, Add, Permute, Dropout, LayerNormalization, Reshape, Layer
from keras.api.models import Model, Sequential
from keras.api.optimizers import Adam
from keras.api.losses import MeanSquaredError
from keras.api.metrics import MeanAbsoluteError
from keras.api.callbacks import EarlyStopping, ModelCheckpoint
from mne.channels import make_standard_montage, get_builtin_montages
import tensorflow as tf
import numpy as np

In [43]:
class SAB(Layer):
    # TODO: embed dim for temporal is the number of channels
    # TODO: embed dim for spatial is the number of encoded channels feature map, i.e. d_model
    def __init__(self, embed_dim, num_heads, mlp_dim, spatial_or_temporal="spatial", dropout_rate=0.1):
        super(SAB, self).__init__()

        if spatial_or_temporal not in ["spatial", "temporal"]:
            raise ValueError("spatial_or_temporal must be either 'spatial' or 'temporal'")
        
        self.spatial_or_temporal = spatial_or_temporal
        
        self.norm1 = LayerNormalization(epsilon=1e-6)
        self.attn = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=dropout_rate)
        self.dropout1 = Dropout(dropout_rate)
        self.norm2 = LayerNormalization(epsilon=1e-6)
        self.mlp = Sequential([
            Dense(mlp_dim, activation='gelu'),  # (None, 60, 128)
            Dropout(dropout_rate),
            Dense(embed_dim),                   # (None, 60, 14)
            Dropout(dropout_rate)               
        ])

    def call(self, inp, L=1, training=False):
        # x shape: (batch_size, channels, time_steps)
        # (None, 14, 60)
        out = inp
        is_temporal = self.spatial_or_temporal == "temporal"

        if is_temporal:
            out = Permute((2, 1))(out) # Change to (batch_size, time_steps, channels) (None, 60, 14)
            
        for _ in range(L):
            out_norm = self.norm1(out)
            attn_output = self.attn(out_norm, out_norm)
            out1 = out + self.dropout1(attn_output, training=training)

            out1_norm = self.norm2(out1) # (None, 60, 14)
            mlp_output = self.mlp(out1_norm, training=training)
            out2 = out1 + mlp_output

            out += out2
            
        if is_temporal:
            out = Permute((2, 1))(out) # Change back to (batch_size, channels, time_steps)

        return out

In [44]:
class CAB(Layer):
    def __init__(self, embed_dim, d_model, num_heads, mlp_dim, dropout_rate=0.1):
        super(CAB, self).__init__()
        self.tsab1 = SAB(embed_dim, num_heads, mlp_dim, spatial_or_temporal="temporal", dropout_rate=dropout_rate)
        self.ssab = SAB(d_model, num_heads, mlp_dim, spatial_or_temporal="spatial", dropout_rate=dropout_rate)
        self.tsab2 = SAB(embed_dim, num_heads, mlp_dim, spatial_or_temporal="temporal", dropout_rate=dropout_rate)

    def call(self, inp, L=1, training=False):
        # inp shape: (batch_size, channels, time_steps)
        # (None, 14, 60)
        out = inp
        for _ in range(L):
            print("[CAB]\n")

            out1 = self.tsab1(out, training=training)
            out1 = out + out1
            print("TSAB1")

            out2 = self.ssab(out1, training=training)
            out2 = out1 + out2
            print("SSAB")

            out3 = self.tsab2(out2, training=training)
            out3 = out2 + out3
            out += out3
            print("TSAB2\n")

        return out

In [47]:
def generate_3d_positional_encoding(channel_names, d_model, builtin_montage=None, positions=[]):
    num_channels = len(channel_names)
    builtin_montages = get_builtin_montages()

    if num_channels == 0:
        raise ValueError("The number of channels must be greater than 0.")
    
    if builtin_montage and positions:
        raise ValueError("You can only use either builtin_montage or positions, not both.")
    
    if not builtin_montage and positions and len(positions) != num_channels:
        raise ValueError("The number of positions must match the number of channels.")
    
    if not builtin_montage and positions and len(positions) == num_channels:
        positions = np.array(positions)

    if builtin_montage and not positions and builtin_montage not in builtin_montages:
        raise ValueError(f"Montage '{builtin_montage}' is not available. Please choose from {builtin_montages}.")
    
    if builtin_montage and not positions and builtin_montage in builtin_montages:
        builtin_montage = make_standard_montage(builtin_montage)
        pos_dict = builtin_montage.get_positions()['ch_pos']
        positions = np.array([pos_dict[ch] for ch in channel_names])  # shape: (num_channels, 3)

    assert d_model % 3 == 0, "d_model must be divisible by 3."
    d_model_per_axis = d_model // 3
    print(type(d_model_per_axis))

    pos_encoding = []

    for axis in range(3):
        pos = positions[:, axis]
        pe = np.zeros((num_channels, d_model_per_axis))
        for i in range(d_model_per_axis):
            div_term = np.power(10000.0, (2 * i) / d_model_per_axis)
            pe[:, i] = np.where(i % 2 == 0, np.sin(pos / div_term), np.cos(pos / div_term))
    
        pos_encoding.append(pe)

    pos_encoding = np.concatenate(pos_encoding, axis=-1)    # shape: (num_channels, d_model)
    pos_encoding = np.expand_dims(pos_encoding, axis=0)     # shape: (1, num_channels, d_model)
    return tf.constant(pos_encoding, dtype=tf.float32)      # shape: (1, num_channels, d_model)

In [46]:
class SIM(Layer):
    def __init__(self, embed_dim, d_model, num_heads, mlp_dim, low_res_ch_names, high_res_ch_names, dropout_rate=0.1, builtin_montage=None, positions=[]):
        super(SIM, self).__init__()
        self.cab = CAB(embed_dim, d_model, num_heads, mlp_dim, dropout_rate=dropout_rate)
        self.norm = LayerNormalization(epsilon=1e-6)
        self.dense = Dense(d_model, activation='gelu')
        self.low_res_3d_pos_encoding = generate_3d_positional_encoding(low_res_ch_names, d_model, builtin_montage=builtin_montage, positions=positions)
        self.high_res_3d_pos_encoding = generate_3d_positional_encoding(high_res_ch_names, d_model, builtin_montage=builtin_montage, positions=positions)
        self.mask_token = tf.Variable(initial_value=tf.zeros([1, d_model]), trainable=True)
        
    def call(self, inp, L=1, training=False):
        # inp shape: (batch_size, channels, time_steps)
        # (None, 14, 57)
        print("[SIM]\n")
        out = inp                                       # channel embedding
        out = self.dense(out)                           # (None, 14, 60)
        out = out + self.low_res_3d_pos_encoding        # (None, 14, 60)
        print("CAB1")
        out = self.cab(out, L=L, training=training)
        out = self.norm(out)                            # feature projection
        
        out = Concatenate([out, self.mask_token])
        out = self.dense(out)
        out = out + self.high_res_3d_pos_encoding
        print("CAB2\n")
        out = self.cab(out, L=L, training=training)
        out = self.norm(out)

        return out                                      # time projection

In [48]:
def generate_1d_positional_encoding(time_steps, d_model):
    assert d_model % 2 == 0, "d_model must be even for sin/cos encoding."
    
    pos_encoding = np.zeros((time_steps, d_model))  # Shape: (time_steps, d_model)
    
    for pos in range(time_steps):
        for i in range(d_model // 2):
            div_term = np.power(10000.0, (2 * i) / d_model)
            pos_encoding[pos, 2 * i] = np.sin(pos / div_term)
            pos_encoding[pos, 2 * i + 1] = np.cos(pos / div_term)

    pos_encoding = np.expand_dims(pos_encoding, axis=0)  # Shape: (1, time_steps, d_model)
    return tf.constant(pos_encoding, dtype=tf.float32)

In [49]:
class TRM(Layer):
    def __init__(self, embed_dim, d_model, num_heads, mlp_dim, dropout_rate=0.1):
        super(TRM, self).__init__()
        self.tsab = SAB(embed_dim, num_heads, mlp_dim, spatial_or_temporal="temporal", dropout_rate=dropout_rate)
        self.norm = LayerNormalization(epsilon=1e-6)
        self.dense = Dense(d_model, activation='gelu')
        self._1d_pos_encoding = generate_1d_positional_encoding()
        
    def call(self, inp, L=1, training=False):
        # inp shape: (batch_size, channels, time_steps)
        out = Permute((2, 1))(inp)                      # time embedding (batch_size, time_steps, channels)
        out = self.dense(out)
        out = out + self._1d_pos_encoding
        out = self.tsab(out, L=L, training=training)
        out = self.norm(out)                            # feature projection

        out = self.dense(out)
        out = out + self._1d_pos_encoding
        out = self.tsab(out, L=L, training=training)
        out = self.norm(out)                            # feature projection
        out = Permute((2, 1))(out)                      # reshape to (batch_size, channels, time_steps)
        return out                                      # channel projection

In [50]:
high_res_ch_names = ['Fp1', 'Fz', 'F3', 'F7', 'FT9', 'FC5', 'FC1', 'C3', 'T7', 'TP9', 'CP5', 'CP1', 'Pz', 'P3', 'P7', 'O1', 'Oz', 'O2', 'P4', 'P8', 'TP10', 'CP6', 'CP2', 'C4', 'T8', 'FT10', 'FC6', 'FC2', 'F4', 'F8', 'Fp2', 'AF7', 'AF3', 'AFz', 'F1', 'F5', 'FT7', 'FC3', 'C1', 'C5', 'TP7', 'CP3', 'P1', 'P5', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'P6', 'P2', 'CPz', 'CP4', 'TP8', 'C6', 'C2', 'FC4', 'FT8', 'F6', 'AF8', 'AF4', 'F2', 'FCz']
low_res_ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
builtin_montage = 'standard_1020'
time_steps = 57
embed_dim = 14
d_model = 60
num_heads = 4
mlp_dim = 128
dropout_rate = 0.1

input_epoch = Input(shape=(len(low_res_ch_names), time_steps)) # (batch_size, channels, time_steps)
print("SIM")
sim = SIM(embed_dim=embed_dim, d_model=d_model, num_heads=num_heads, mlp_dim=mlp_dim, dropout_rate=dropout_rate, low_res_ch_names=low_res_ch_names, high_res_ch_names=high_res_ch_names, builtin_montage=builtin_montage)(input_epoch, L=1, training=True)
trm = TRM(embed_dim=embed_dim, d_model=d_model, num_heads=num_heads, mlp_dim=mlp_dim, dropout_rate=dropout_rate)(sim, L=1, training=True)
model = sim + trm

model.summary()

SIM
<class 'int'>
<class 'int'>
[SIM]

CAB1
[CAB]

TSAB1
SSAB
TSAB2

[CAB]

TSAB1
SSAB
TSAB2

[SIM]

CAB1
[CAB]

TSAB1


1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)`
2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
Exception encountered: ''Only input tensors may be passed as positional arguments. The following argument value should be passed as a keyword argument: <Concatenate name=concatenate_4, built=False> (of type <class 'keras.src.layers.merging.concatenate.Concatenate'>)''


SSAB
TSAB2



ValueError: Exception encountered when calling SIM.call().

[1mCould not automatically infer the output shape / dtype of 'sim_18' (of type SIM). Either the `SIM.call()` method is incorrect, or you need to implement the `SIM.compute_output_spec() / compute_output_shape()` method. Error encountered:

Only input tensors may be passed as positional arguments. The following argument value should be passed as a keyword argument: <Concatenate name=concatenate_5, built=False> (of type <class 'keras.src.layers.merging.concatenate.Concatenate'>)[0m

Arguments received by SIM.call():
  • args=('<KerasTensor shape=(None, 14, 57), dtype=float32, sparse=False, ragged=False, name=keras_tensor_168>',)
  • kwargs={'L': '1', 'training': 'True'}