<a href="https://colab.research.google.com/github/SEOYUNJE/Lung-Image-Analysis/blob/main/Untitled9_ipynb%EC%9D%98_%EC%82%AC%EB%B3%B8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf

In [None]:
class ExponentialGate(tf.keras.layers.Layer):
    """Dense projection followed by an element-wise exponential.

    Computes ``exp(W x + b)`` with ``units`` output features, so every output
    is strictly positive.

    NOTE(review): the pre-activation is neither clipped nor stabilised, so a
    dense output above roughly 88.7 overflows to ``inf`` in float32. No
    stabiliser state is applied here; confirm that is acceptable for the
    expected input ranges.

    Args:
        units: Number of output features of the gate.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``,
            ``dtype``). Accepting these also lets ``from_config`` rebuild
            the layer from the dict produced by ``get_config``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.dense = tf.keras.layers.Dense(units)

    def call(self, inputs):
        return tf.exp(self.dense(inputs))

    def get_config(self):
        # Keras needs the constructor arguments to re-create this layer when
        # a model containing it is saved and reloaded.
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [None]:
class mLSTMCell(tf.keras.layers.Layer):
    """Single-step matrix-memory LSTM (mLSTM) cell.

    The current input is projected to a key ``k_t``, value ``v_t`` and query
    ``q_t`` (each ``units`` wide), and a ``(units x units)`` memory is updated
    and read as::

        C_t = f_t * C_{t-1} + i_t * (v_t k_t^T)
        h_t = o_t * (C_t q_t)

    The gates ``i_t`` (exponential), ``f_t`` and ``o_t`` are ``(batch, units)``
    vectors computed from the current input only. ``i_t`` and ``f_t`` scale
    the rows (value dimension) of the memory.

    NOTE(review): ``f_t`` is ``sigmoid(exp(.))`` because ``forget_gate`` is an
    ``ExponentialGate``. That keeps the forget gate in (0.5, 1), so the memory
    can never decay by more than half per step. No normaliser state or key
    scaling is applied either, so with the unbounded ``exp`` input gate the
    memory magnitude can grow. Confirm this is intended.

    Args:
        units: Width of keys, values, queries, gates and the output.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

        self.key_transform = tf.keras.layers.Dense(units)
        self.value_transform = tf.keras.layers.Dense(units)
        self.query_transform = tf.keras.layers.Dense(units)

        self.input_gate = ExponentialGate(units)
        self.forget_gate = ExponentialGate(units)
        self.output_gate = tf.keras.layers.Dense(units)

    def call(self, inputs, states):
        """Advance the cell by one step.

        Args:
            inputs: Input at the current step; assumed ``(batch, features)``
                (TODO confirm with callers).
            states: ``[h_prev, c_prev]``. ``h_prev`` is not used by this cell.
                ``c_prev`` is the matrix memory ``(batch, units, units)``. A
                rank-2 ``(batch, units)`` state is broadcast across the key
                dimension, so an all-zero vector acts as an empty memory.

        Returns:
            ``(h_t, [h_t, c_t])`` with ``h_t`` of shape ``(batch, units)`` and
            ``c_t`` of shape ``(batch, units, units)``.
        """
        _, c_prev = states
        k_t = self.key_transform(inputs)
        v_t = self.value_transform(inputs)
        q_t = self.query_transform(inputs)

        i_t = self.input_gate(inputs)
        f_t = tf.sigmoid(self.forget_gate(inputs))
        o_t = tf.sigmoid(self.output_gate(inputs))

        if c_prev.shape.rank == 2:
            # Vector-shaped initial state: lift it to (batch, units, 1) so it
            # broadcasts against the (batch, units, units) memory.
            c_prev = c_prev[:, :, tf.newaxis]

        # Outer product v_t k_t^T: (batch, units, 1) @ (batch, 1, units).
        write = tf.matmul(v_t[:, :, tf.newaxis], k_t[:, tf.newaxis, :])
        # Gates act per value row, so broadcast them over the key dimension.
        c_t = f_t[:, :, tf.newaxis] * c_prev + i_t[:, :, tf.newaxis] * write

        # Read the memory with the query: (batch, units, units) @ (batch, units, 1).
        read = tf.squeeze(tf.matmul(c_t, q_t[:, :, tf.newaxis]), axis=-1)
        h_t = o_t * read

        return h_t, [h_t, c_t]

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.units)

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [None]:
class sLSTMCell(tf.keras.layers.Layer):
    """Single-step scalar-memory LSTM (sLSTM) cell.

    Update rule (all tensors ``(batch, units)``)::

        z_t = tanh(W_x x_t + W_h h_{t-1})    # input_transform, hidden_transform
        c_t = f_t * c_{t-1} + i_t * z_t
        h_t = o_t * tanh(c_t)

    with ``i_t = exp(.)``, ``f_t = sigmoid(exp(.))`` and ``o_t = sigmoid(.)``,
    each computed from the current input ``x_t``.

    NOTE(review): ``sigmoid(exp(.))`` confines the forget gate to (0.5, 1).
    The ``exp`` input gate is unbounded and there is no normaliser state, so
    ``c_t`` can grow large; ``tanh`` only bounds ``h_t``. Confirm intent.

    Args:
        units: Width of the hidden state, cell state and gates.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

        self.input_gate = ExponentialGate(units)
        self.forget_gate = ExponentialGate(units)
        self.output_gate = tf.keras.layers.Dense(units)

        self.input_transform = tf.keras.layers.Dense(units)
        self.hidden_transform = tf.keras.layers.Dense(units)

    def call(self, inputs, states):
        """Advance the cell by one step.

        Args:
            inputs: Input at the current step; assumed ``(batch, features)``.
            states: ``[h_prev, c_prev]``, each ``(batch, units)``.

        Returns:
            ``(h_t, [h_t, c_t])``.
        """
        h_prev, c_prev = states
        i_t = self.input_gate(inputs)
        f_t = tf.sigmoid(self.forget_gate(inputs))
        o_t = tf.sigmoid(self.output_gate(inputs))

        # hidden_transform feeds the previous hidden state back into the
        # candidate. Without it, h_prev never influenced the next step.
        z_t = tf.tanh(self.input_transform(inputs) + self.hidden_transform(h_prev))

        c_t = f_t * c_prev + i_t * z_t
        h_t = o_t * tf.tanh(c_t)
        return h_t, [h_t, c_t]

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config

In [None]:
class xLSTMBlock(tf.keras.layers.Layer):
    """Wrap a single sLSTM or mLSTM cell selected by ``block_type``.

    Args:
        units: Width passed to the underlying cell.
        block_type: ``'sLSTM'`` (default) or ``'mLSTM'``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.

    Raises:
        ValueError: If ``block_type`` is not one of the two supported names.
    """

    def __init__(self, units, block_type='sLSTM', **kwargs):
        super().__init__(**kwargs)
        # Stored so get_config can reproduce the constructor call.
        self.units = units
        self.block_type = block_type

        if block_type == 'sLSTM':
            self.cell = sLSTMCell(units)
        elif block_type == 'mLSTM':
            self.cell = mLSTMCell(units)
        else:
            raise ValueError("block_type must be 'sLSTM' or 'mLSTM'")

    def call(self, inputs, states):
        """Run one step of the wrapped cell; returns ``(h_t, [h_t, c_t])``."""
        return self.cell(inputs, states)

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units, "block_type": self.block_type})
        return config

In [None]:
class xLSTMLayer(tf.keras.layers.Layer):
    """Stack of xLSTM blocks unrolled over time, returning the final output.

    At every time step the input slice is passed through the blocks in order,
    each block consuming the previous block's hidden output. Only the last
    block's hidden state after the last step is returned.

    Args:
        hidden_dim: Width of every block's hidden and cell state.
        num_layers: Number of stacked blocks.
        block_types: Sequence with at least ``num_layers`` entries, each
            ``'sLSTM'`` or ``'mLSTM'``; entry ``i`` configures block ``i``.
            Extra entries are ignored.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``. ``name`` defaults
            to ``'lstm_layer'``; pass a distinct name to use more than one
            instance in the same model.

    Raises:
        ValueError: If ``block_types`` has fewer than ``num_layers`` entries.
    """

    def __init__(self, hidden_dim, num_layers, block_types, **kwargs):
        kwargs.setdefault('name', 'lstm_layer')
        super().__init__(**kwargs)
        if len(block_types) < num_layers:
            raise ValueError(
                f"block_types has {len(block_types)} entries, "
                f"but num_layers={num_layers}"
            )
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.block_types = list(block_types[:num_layers])

        self.blocks = [xLSTMBlock(hidden_dim, bt) for bt in self.block_types]

    def _initial_cell_state(self, block, batch_size, dtype):
        """Return an all-zero cell state shaped for ``block``'s cell type."""
        if block.block_type == 'mLSTM':
            # mLSTMCell returns a (batch, units, units) matrix memory. Starting
            # with the same shape keeps AutoGraph loop variables shape-stable.
            return tf.zeros((batch_size, self.hidden_dim, self.hidden_dim), dtype=dtype)
        return tf.zeros((batch_size, self.hidden_dim), dtype=dtype)

    def call(self, inputs, training=False):
        """Run the stack over ``inputs``.

        Args:
            inputs: Tensor indexed as ``inputs[:, t, :]``; assumed
                ``(batch, seq_len, features)``.
            training: Accepted for Keras compatibility; unused.

        Returns:
            The last block's hidden state after the final step,
            ``(batch, hidden_dim)``. For an empty sequence this is the
            all-zero initial state.
        """
        batch_size = tf.shape(inputs)[0]
        seq_len = tf.shape(inputs)[1]
        # Match the dtype the sublayers compute in instead of forcing float32.
        dtype = inputs.dtype

        h = [tf.zeros((batch_size, self.hidden_dim), dtype=dtype) for _ in self.blocks]
        c = [self._initial_cell_state(block, batch_size, dtype) for block in self.blocks]

        for t in tf.range(seq_len):
            x_t = inputs[:, t, :]
            for i, block in enumerate(self.blocks):
                x_t, (h[i], c[i]) = block(x_t, [h[i], c[i]])

        # The final hidden state of the top block is exactly the last time
        # step's output, so no per-step buffer is needed.
        return h[-1]

    def get_config(self):
        config = super().get_config()
        config.update({
            'hidden_dim': self.hidden_dim,
            'num_layers': self.num_layers,
            'block_types': list(self.block_types),
        })
        return config
