In [13]:
from keras.api.layers import Input, Dense, Flatten, MultiHeadAttention, Concatenate, Add, Permute, Dropout, LayerNormalization, Reshape, Layer
from keras.api.models import Model, Sequential
from keras.api.optimizers import Adam
from keras.api.losses import MeanSquaredError
from keras.api.metrics import MeanAbsoluteError
from keras.api.callbacks import EarlyStopping, ModelCheckpoint
from mne.channels import make_standard_montage, get_builtin_montages
import tensorflow as tf
import numpy as np

In [44]:
class SAB(Layer):
    # TODO: embed dim for temporal is the number of channels
    # TODO: embed dim for spatial is the number of encoded channels feature map, i.e. d_model
    def __init__(self, embed_dim, num_heads, mlp_dim, spatial_or_temporal="spatial", dropout_rate=0.1):
        super(SAB, self).__init__()

        if spatial_or_temporal not in ["spatial", "temporal"]:
            raise ValueError("spatial_or_temporal must be either 'spatial' or 'temporal'")
        
        self.spatial_or_temporal = spatial_or_temporal
        
        self.norm1 = LayerNormalization(epsilon=1e-6)
        self.attn = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=dropout_rate)
        self.dropout1 = Dropout(dropout_rate)
        self.norm2 = LayerNormalization(epsilon=1e-6)
        self.mlp = Sequential([
            Dense(mlp_dim, activation='gelu'),
            Dropout(dropout_rate),
            Dense(embed_dim),
            Dropout(dropout_rate)               
        ])

    def generate(self, inp, L=1, training=False):
        # x shape: (batch_size, channels, time_steps)
        print("\t\tSAB called")
        out = inp
        is_temporal = self.spatial_or_temporal == "temporal"

        if is_temporal:
            out = Permute((2, 1))(out) # Change to (batch_size, time_steps, channels)
            
        for _ in range(L):
            out_norm = self.norm1(out)
            attn_output = self.attn(out_norm, out_norm)
            out1 = out + self.dropout1(attn_output, training=training)

            out1_norm = self.norm2(out1)
            mlp_output = self.mlp(out1_norm, training=training)
            out2 = out1 + mlp_output

            out += out2
            
        if is_temporal:
            out = Permute((2, 1))(out) # Change back to (batch_size, channels, time_steps)

        return out

In [40]:
class CAB(Layer):
    def __init__(self, embed_dim, d_model, num_heads, mlp_dim, dropout_rate=0.1):
        super(CAB, self).__init__()
        self.tsab1 = SAB(embed_dim, num_heads, mlp_dim, spatial_or_temporal="temporal", dropout_rate=dropout_rate)
        self.ssab = SAB(d_model, num_heads, mlp_dim, spatial_or_temporal="spatial", dropout_rate=dropout_rate)
        self.tsab2 = SAB(embed_dim, num_heads, mlp_dim, spatial_or_temporal="temporal", dropout_rate=dropout_rate)

    def generate(self, inp, L=1, training=False):
        # inp shape: (batch_size, channels, time_steps)
        out = inp
        for _ in range(L):
            out1 = self.tsab1.generate(out, training=training)
            out1 = out + out1

            out2 = self.ssab.generate(out1, training=training)
            out2 = out1 + out2

            out3 = self.tsab2.generate(out2, training=training)
            out3 = out2 + out3
            out += out3
            print("[CAB] TSAB1 SSAB TSAB2")

        return out

In [22]:
def generate_3d_positional_encoding(channel_names, d_model, builtin_montage=None, positions=[]):
    num_channels = len(channel_names)
    builtin_montages = get_builtin_montages()

    if num_channels == 0:
        raise ValueError("The number of channels must be greater than 0.")
    
    if builtin_montage and positions:
        raise ValueError("You can only use either builtin_montage or positions, not both.")
    
    if not builtin_montage and positions and len(positions) != num_channels:
        raise ValueError("The number of positions must match the number of channels.")
    
    if not builtin_montage and positions and len(positions) == num_channels:
        positions = np.array(positions)

    if builtin_montage and not positions and builtin_montage not in builtin_montages:
        raise ValueError(f"Montage '{builtin_montage}' is not available. Please choose from {builtin_montages}.")
    
    if builtin_montage and not positions and builtin_montage in builtin_montages:
        builtin_montage = make_standard_montage(builtin_montage)
        pos_dict = builtin_montage.get_positions()['ch_pos']
        positions = np.array([pos_dict[ch] for ch in channel_names])  # shape: (num_channels, 3)

    assert d_model % 3 == 0, "d_model must be divisible by 3."
    d_model_per_axis = d_model // 3

    pos_encoding = []

    for axis in range(3):
        pos = positions[:, axis]
        pe = np.zeros((num_channels, d_model_per_axis))
        for i in range(d_model_per_axis):
            div_term = np.power(10000.0, (2 * i) / d_model_per_axis)
            pe[:, i] = np.where(i % 2 == 0, np.sin(pos / div_term), np.cos(pos / div_term))
    
        pos_encoding.append(pe)

    pos_encoding = np.concatenate(pos_encoding, axis=-1)    # shape: (num_channels, d_model)
    pos_encoding = np.expand_dims(pos_encoding, axis=0)     # shape: (1, num_channels, d_model)
    return tf.constant(pos_encoding, dtype=tf.float32)      # shape: (1, num_channels, d_model)

In [41]:
class SIM(Layer):
    def __init__(self, embed_dim, d_model, num_heads, mlp_dim, low_res_ch_names, high_res_ch_names, dropout_rate=0.1, builtin_montage=None, positions=[]):
        super(SIM, self).__init__()
        self.dense1 = Dense(d_model, activation='gelu')
        self.low_res_3d_pos_encoding = generate_3d_positional_encoding(low_res_ch_names, d_model, builtin_montage=builtin_montage, positions=positions)
        self.cab1 = CAB(embed_dim, d_model, num_heads, mlp_dim, dropout_rate=dropout_rate)
        self.norm1 = LayerNormalization(epsilon=1e-6)

        self.mask_token = tf.Variable(initial_value=tf.zeros([1, d_model]), trainable=True)
        
        self.dense2 = Dense(d_model, activation='gelu')
        self.high_res_3d_pos_encoding = generate_3d_positional_encoding(high_res_ch_names, d_model, builtin_montage=builtin_montage, positions=positions)
        self.cab2 = CAB(embed_dim, d_model, num_heads, mlp_dim, dropout_rate=dropout_rate)
        self.norm2 = LayerNormalization(epsilon=1e-6)
        
    def generate(self, inp, L=1, training=False):
        # inp shape: (batch_size, channels, time_steps)
        print("[SIM]")
        out = inp                                       # channel embedding
        out = self.dense1(out)
        out = out + self.low_res_3d_pos_encoding
        print("\tCAB1")
        out = self.cab1.generate(out, L=L, training=training)
        out = self.norm1(out)                            # feature projection
        
        out = Concatenate([out, self.mask_token])
        
        out = self.dense2(out)
        out = out + self.high_res_3d_pos_encoding
        print("\tCAB2")
        out = self.cab2.generate(out, L=L, training=training)
        out = self.norm2(out)

        return out                                      # time projection

In [18]:
def generate_1d_positional_encoding(time_steps, d_model):
    assert d_model % 2 == 0, "d_model must be even for sin/cos encoding."
    
    pos_encoding = np.zeros((time_steps, d_model))  # Shape: (time_steps, d_model)
    
    for pos in range(time_steps):
        for i in range(d_model // 2):
            div_term = np.power(10000.0, (2 * i) / d_model)
            pos_encoding[pos, 2 * i] = np.sin(pos / div_term)
            pos_encoding[pos, 2 * i + 1] = np.cos(pos / div_term)

    pos_encoding = np.expand_dims(pos_encoding, axis=0)  # Shape: (1, time_steps, d_model)
    return tf.constant(pos_encoding, dtype=tf.float32)

In [42]:
class TRM(Layer):
    def __init__(self, embed_dim, d_model, num_heads, mlp_dim, dropout_rate=0.1):
        super(TRM, self).__init__()
        self.dense1 = Dense(d_model, activation='gelu')
        self._1d_pos_encoding = generate_1d_positional_encoding()
        self.tsab1 = SAB(embed_dim, num_heads, mlp_dim, spatial_or_temporal="temporal", dropout_rate=dropout_rate)
        self.norm1 = LayerNormalization(epsilon=1e-6)

        self.dense2 = Dense(d_model, activation='gelu')
        self.tsab2 = SAB(embed_dim, num_heads, mlp_dim, spatial_or_temporal="temporal", dropout_rate=dropout_rate)
        self.norm2 = LayerNormalization(epsilon=1e-6)
        
        
    def call(self, inp, L=1, training=False):
        # inp shape: (batch_size, channels, time_steps)
        out = Permute((2, 1))(inp)                      # time embedding (batch_size, time_steps, channels)
        out = self.dense1(out)
        out = out + self._1d_pos_encoding
        out = self.tsab1(out, L=L, training=training)
        out = self.norm1(out)                            # feature projection

        out = self.dense2(out)
        out = out + self._1d_pos_encoding
        out = self.tsab2(out, L=L, training=training)
        out = self.norm2(out)                            # feature projection
        out = Permute((2, 1))(out)                      # reshape to (batch_size, channels, time_steps)
        return out                                      # channel projection

In [45]:
high_res_ch_names = ['Fp1', 'Fz', 'F3', 'F7', 'FT9', 'FC5', 'FC1', 'C3', 'T7', 'TP9', 'CP5', 'CP1', 'Pz', 'P3', 'P7', 'O1', 'Oz', 'O2', 'P4', 'P8', 'TP10', 'CP6', 'CP2', 'C4', 'T8', 'FT10', 'FC6', 'FC2', 'F4', 'F8', 'Fp2', 'AF7', 'AF3', 'AFz', 'F1', 'F5', 'FT7', 'FC3', 'C1', 'C5', 'TP7', 'CP3', 'P1', 'P5', 'PO7', 'PO3', 'POz', 'PO4', 'PO8', 'P6', 'P2', 'CPz', 'CP4', 'TP8', 'C6', 'C2', 'FC4', 'FT8', 'F6', 'AF8', 'AF4', 'F2', 'FCz']
low_res_ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4']
builtin_montage = 'standard_1020'
time_steps = 57
embed_dim = 14
d_model = 60
num_heads = 4
mlp_dim = 128
dropout_rate = 0.1

input_epoch = Input(shape=(len(low_res_ch_names), time_steps)) # (batch_size, channels, time_steps)
sim = SIM(embed_dim=embed_dim, d_model=d_model, num_heads=num_heads, mlp_dim=mlp_dim, dropout_rate=dropout_rate, low_res_ch_names=low_res_ch_names, high_res_ch_names=high_res_ch_names, builtin_montage=builtin_montage).generate(input_epoch, L=1, training=True)
# trm = TRM(embed_dim=embed_dim, d_model=d_model, num_heads=num_heads, mlp_dim=mlp_dim, dropout_rate=dropout_rate)(sim, L=1, training=True)
# model = sim + trm

# model.summary()

[SIM]
	CAB1
		SAB called
		SAB called
		SAB called
[CAB] TSAB1 SSAB TSAB2


ValueError: Only input tensors may be passed as positional arguments. The following argument value should be passed as a keyword argument: <Concatenate name=concatenate_9, built=False> (of type <class 'keras.src.layers.merging.concatenate.Concatenate'>)