<a href="https://colab.research.google.com/github/Xuliang-Guo/data/blob/main/xlstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### CausalConv1D

In [None]:
class CausalConv1D(tf.keras.layers.Layer):
    """Thin wrapper layer around a Keras ``Conv1D`` built with causal padding.

    Args:
        filters: Number of output filters, forwarded to ``Conv1D``.
        kernel_size: Kernel size, forwarded to ``Conv1D``.
        dilation: Forwarded to ``Conv1D`` as ``dilation_rate`` (default 1).
    """

    def __init__(self, filters, kernel_size, dilation=1):
        super().__init__()
        # Gather the fixed convolution options in one place before building.
        conv_options = {'padding': 'causal', 'dilation_rate': dilation}
        self.conv = layers.Conv1D(filters, kernel_size, **conv_options)

    def call(self, x):
        """Run ``x`` through the wrapped convolution and return the result."""
        convolved = self.conv(x)
        return convolved


### BlockDiagonal

In [None]:
class BlockDiagonal(tf.keras.layers.Layer):
    """Linear map assembled from ``num_blocks`` independent Dense layers.

    The last axis of the input is split into ``num_blocks`` equal chunks, each
    chunk is passed through its own Dense layer with
    ``out_features // num_blocks`` units, and the per-block outputs are
    concatenated back along the last axis.

    Args:
        in_features: Expected size of the input's last axis. It is only used
            for validation; each Dense block infers its input width when built.
        out_features: Total size of the output's last axis.
        num_blocks: Number of independent blocks.

    Raises:
        ValueError: If ``num_blocks`` is not positive, or if ``in_features`` or
            ``out_features`` is not divisible by ``num_blocks``.
    """

    def __init__(self, in_features, out_features, num_blocks):
        super(BlockDiagonal, self).__init__()

        # Explicit exceptions instead of ``assert``: asserts are removed under
        # ``python -O``, and a non-divisible ``out_features`` would then
        # silently produce an output narrower than requested.
        if num_blocks <= 0:
            raise ValueError(
                f"num_blocks must be positive, got {num_blocks}")
        if in_features % num_blocks != 0:
            raise ValueError(
                f"in_features ({in_features}) must be divisible by "
                f"num_blocks ({num_blocks})")
        if out_features % num_blocks != 0:
            raise ValueError(
                f"out_features ({out_features}) must be divisible by "
                f"num_blocks ({num_blocks})")

        self.in_features = in_features
        self.out_features = out_features
        self.num_blocks = num_blocks

        block_out_features = out_features // num_blocks

        # One bias-enabled linear Dense layer per diagonal block.
        self.blocks = [
            tf.keras.layers.Dense(block_out_features, activation='linear')
            for _ in range(num_blocks)
        ]

    def call(self, x):
        """Apply each block to its slice of ``x`` and concatenate the results."""
        x_chunks = tf.split(x, num_or_size_splits=self.num_blocks, axis=-1)
        x_chunks = [block(chunk) for block, chunk in zip(self.blocks, x_chunks)]
        return tf.concat(x_chunks, axis=-1)

### sLSTMBlock

In [None]:
class sLSTMBlock(tf.keras.layers.Layer):
    """Single-time-step sLSTM block with a gated projection and a residual.

    One call consumes the input for one time step together with the previous
    recurrent state ``(h, c, n, m)`` and returns the block output plus the
    updated state. The input and forget gates are exponential and are
    stabilised by the running maximum ``m``.

    Args:
        input_size: Size of the last axis of ``x``; also the output width,
            because the block ends with a residual connection to ``x``.
            ``BlockDiagonal`` additionally requires it to be divisible by
            ``num_heads``.
        hidden_size: Width of the recurrent state tensors.
        num_heads: Number of diagonal blocks in the gate projections and
            number of groups used by the group normalisation.
        proj_factor: Multiplier applied to ``hidden_size`` to obtain the width
            of the two up-projections.

    Raises:
        ValueError: If ``num_heads`` is not positive, ``hidden_size`` is not
            divisible by ``num_heads``, or ``proj_factor`` is not positive.
    """

    def __init__(self, input_size, hidden_size, num_heads, proj_factor=4/3):
        super(sLSTMBlock, self).__init__()

        # Explicit exceptions instead of ``assert`` so the checks survive
        # ``python -O`` and run before any sub-layer is created.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})")
        if not proj_factor > 0:
            raise ValueError(f"proj_factor must be positive, got {proj_factor}")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.proj_factor = proj_factor

        self.layer_norm = layers.LayerNormalization()
        self.causal_conv = CausalConv1D(1, 4)

        # Input-to-gate projections.
        self.Wz = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wi = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wf = BlockDiagonal(input_size, hidden_size, num_heads)
        self.Wo = BlockDiagonal(input_size, hidden_size, num_heads)

        # Recurrent (hidden-to-gate) projections.
        self.Rz = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Ri = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Rf = BlockDiagonal(hidden_size, hidden_size, num_heads)
        self.Ro = BlockDiagonal(hidden_size, hidden_size, num_heads)

        self.group_norm = layers.GroupNormalization(groups=num_heads)

        self.up_proj_left = layers.Dense(int(hidden_size * proj_factor), activation='linear')
        self.up_proj_right = layers.Dense(int(hidden_size * proj_factor), activation='linear')
        self.down_proj = layers.Dense(input_size, activation='linear')

    def call(self, x, prev_state):
        """Advance the block by one time step.

        Args:
            x: Input for the current step; presumably of shape
                ``(batch, input_size)`` (TODO confirm against the caller).
            prev_state: Sequence ``(h_prev, c_prev, n_prev, m_prev)`` holding
                the previous hidden, cell, normaliser and stabiliser states.

        Returns:
            ``(final_output, (h_t, c_t, n_t, m_t))`` where ``final_output`` is
            the projected block output plus the residual ``x``.
        """
        h_prev, c_prev, n_prev, m_prev = prev_state
        x_norm = self.layer_norm(x)
        # A trailing channel axis of size 1 is added for the convolution and
        # removed afterwards, so the kernel slides over the axis that was last
        # in ``x_norm``.
        x_conv = tf.nn.silu(tf.squeeze(self.causal_conv(tf.expand_dims(x_norm, axis=-1)), axis=-1))

        # NOTE(review): z and o read the raw ``x`` rather than ``x_norm``;
        # confirm this is intended.
        z = tf.tanh(self.Wz(x) + self.Rz(h_prev))
        o = tf.sigmoid(self.Wo(x) + self.Ro(h_prev))
        i_tilde = self.Wi(x_conv) + self.Ri(h_prev)
        f_tilde = self.Wf(x_conv) + self.Rf(h_prev)

        # Stabilised exponential gates: subtracting the running maximum keeps
        # both exponents non-positive.
        m_t = tf.maximum(f_tilde + m_prev, i_tilde)
        i = tf.exp(i_tilde - m_t)
        f = tf.exp(f_tilde + m_prev - m_t)

        c_t = f * c_prev + i * z
        n_t = f * n_prev + i
        # ``i`` can underflow to exactly 0 while ``n_prev`` is still 0, which
        # would make a plain division produce 0/0 = NaN; return 0 there.
        h_t = tf.math.divide_no_nan(o * c_t, n_t)

        # Gated up-projection, down-projection, then the residual connection.
        output_norm = self.group_norm(h_t)
        output_left = self.up_proj_left(output_norm)
        output_right = self.up_proj_right(output_norm)
        output_gated = tf.nn.gelu(output_right)
        output = self.down_proj(output_left * output_gated)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)


### mLSTMBlock

In [None]:
class mLSTMBlock(tf.keras.layers.Layer):
    """Single-time-step mLSTM block with an up-projection and a residual.

    One call consumes the input for one time step together with the previous
    recurrent state ``(h, c, n, m)`` and returns the block output plus the
    updated state. The memory ``c`` is kept as a tensor with the same shape
    as the hidden state and is updated element-wise with ``v * k``.

    Args:
        input_size: Size of the last axis of ``x``; also the output width,
            because the block ends with a residual connection to ``x``.
        hidden_size: Width of the recurrent state tensors.
        num_heads: Number of diagonal blocks in the q/k/v projections and
            number of groups used by the group normalisation.
        proj_factor: Multiplier applied to ``input_size`` to obtain the width
            of the left up-projection. ``BlockDiagonal`` requires that width
            to be divisible by ``num_heads``.

    Raises:
        ValueError: If ``num_heads`` is not positive, ``hidden_size`` is not
            divisible by ``num_heads``, or ``proj_factor`` is not positive.
    """

    def __init__(self, input_size, hidden_size, num_heads, proj_factor=2):
        super(mLSTMBlock, self).__init__()

        # Explicit exceptions instead of ``assert`` so the checks survive
        # ``python -O`` and run before any sub-layer is created.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})")
        if not proj_factor > 0:
            raise ValueError(f"proj_factor must be positive, got {proj_factor}")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.proj_factor = proj_factor

        self.layer_norm = layers.LayerNormalization()
        self.up_proj_left = layers.Dense(int(input_size * proj_factor), activation='linear')
        self.up_proj_right = layers.Dense(hidden_size, activation='linear')
        self.down_proj = layers.Dense(input_size, activation='linear')

        self.causal_conv = CausalConv1D(1, 4)
        self.skip_connection = layers.Dense(hidden_size, activation='linear')

        self.Wq = BlockDiagonal(int(input_size * proj_factor), hidden_size, num_heads)
        self.Wk = BlockDiagonal(int(input_size * proj_factor), hidden_size, num_heads)
        self.Wv = BlockDiagonal(int(input_size * proj_factor), hidden_size, num_heads)
        self.Wi = layers.Dense(hidden_size, activation='linear')
        self.Wf = layers.Dense(hidden_size, activation='linear')
        self.Wo = layers.Dense(hidden_size, activation='linear')

        self.group_norm = layers.GroupNormalization(groups=num_heads)

    def call(self, x, prev_state):
        """Advance the block by one time step.

        Args:
            x: Input for the current step; presumably of shape
                ``(batch, input_size)`` (TODO confirm against the caller).
            prev_state: Sequence ``(h_prev, c_prev, n_prev, m_prev)``. Only
                ``c_prev``, ``n_prev`` and ``m_prev`` are read; ``h_prev`` is
                unpacked but unused.

        Returns:
            ``(final_output, (h_t, c_t, n_t, m_t))`` where ``final_output`` is
            the projected block output plus the residual ``x``.
        """
        h_prev, c_prev, n_prev, m_prev = prev_state
        x_norm = self.layer_norm(x)
        x_up_left = self.up_proj_left(x_norm)
        x_up_right = self.up_proj_right(x_norm)

        # A trailing channel axis of size 1 is added for the convolution and
        # removed afterwards, so the kernel slides over the axis that was last
        # in ``x_up_left``.
        x_conv = tf.nn.silu(tf.squeeze(self.causal_conv(tf.expand_dims(x_up_left, axis=-1)), axis=-1))

        x_skip = self.skip_connection(x_conv)

        q = self.Wq(x_conv)
        k = self.Wk(x_conv) / (self.head_size ** 0.5)
        v = self.Wv(x_up_left)

        i_tilde = self.Wi(x_conv)
        f_tilde = self.Wf(x_conv)
        o = tf.sigmoid(self.Wo(x_up_left))

        # Stabilised exponential gates: subtracting the running maximum keeps
        # both exponents non-positive.
        m_t = tf.math.maximum(f_tilde + m_prev, i_tilde)
        i = tf.exp(i_tilde - m_t)
        f = tf.exp(f_tilde + m_prev - m_t)

        c_t = f * c_prev + i * (v * k)  # element-wise stand-in for v @ k.T
        n_t = f * n_prev + i * k

        # Normaliser max{|n.T @ q|, 1}, computed per sample as a dot product
        # over the feature axis. The previous matmul(transpose(n_t), q)
        # contracted over the batch axis, so one sample's output depended on
        # the rest of the batch, and the lower bound of 1 was missing.
        n_dot_q = tf.reduce_sum(n_t * q, axis=-1, keepdims=True)
        denom = tf.maximum(tf.abs(n_dot_q), 1.0)
        h_t = o * (c_t * q) / denom

        # Normalise, add the skip path, gate with the right branch, project.
        output_norm = self.group_norm(h_t)
        output = output_norm + x_skip
        output = output * tf.nn.silu(x_up_right)
        output = self.down_proj(output)
        final_output = output + x

        return final_output, (h_t, c_t, n_t, m_t)

## XLSTM

1. sLSTM

In [None]:

class sLSTM(tf.keras.layers.Layer):
    """Stack of ``sLSTMBlock`` layers unrolled over the time axis of a batch.

    Args:
        input_size: Forwarded to every ``sLSTMBlock``.
        hidden_size: Forwarded to every ``sLSTMBlock``; also the width of the
            zero-initialised recurrent states.
        num_heads: Forwarded to every ``sLSTMBlock``.
        num_layers: Number of stacked blocks.
        proj_factor: Forwarded to every ``sLSTMBlock``.
    """

    def __init__(self, input_size, hidden_size, num_heads, num_layers=1, proj_factor=4/3):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        self.num_heads, self.num_layers = num_heads, num_layers
        self.proj_factor_slstm = proj_factor

        # One sLSTMBlock per stacked layer, all built with the same settings.
        blocks = []
        for _ in range(num_layers):
            blocks.append(sLSTMBlock(input_size, hidden_size, num_heads, proj_factor))
        self.layers = blocks

    def call(self, x):
        """Run the stack over every time step and return the last step's output.

        ``x`` is indexed as ``x[:, t, :]``, i.e. it is treated as
        ``(batch, seq_len, features)``; nothing here validates that shape.
        """
        dims = tf.shape(x)
        batch_size = dims[0]
        seq_len = dims[1]

        # All four recurrent states start at zero, one tensor per layer.
        state_shape = (batch_size, self.hidden_size)
        h = [tf.zeros(state_shape) for _ in range(self.num_layers)]
        c = [tf.zeros(state_shape) for _ in range(self.num_layers)]
        n = [tf.zeros(state_shape) for _ in range(self.num_layers)]
        m = [tf.zeros(state_shape) for _ in range(self.num_layers)]

        outputs = tf.TensorArray(dtype=tf.float32, size=seq_len)

        for t in tf.range(seq_len):
            step = x[:, t, :]
            # Feed the step through the stack; each block updates its own state.
            for i in range(self.num_layers):
                block = self.layers[i]
                step, [h[i], c[i], n[i], m[i]] = block(step, [h[i], c[i], n[i], m[i]])
            outputs = outputs.write(t, step)

        # stack() is time-major; move batch to the front, then keep the last step.
        batch_major = tf.transpose(outputs.stack(), [1, 0, 2])
        return batch_major[:, -1, :]


2. mLSTM

In [None]:
class mLSTM(tf.keras.layers.Layer):
    """Stack of ``mLSTMBlock`` layers unrolled over the time axis of a batch.

    Args:
        input_size: Forwarded to every ``mLSTMBlock``.
        hidden_size: Forwarded to every ``mLSTMBlock``; also the width of the
            zero-initialised recurrent states.
        num_heads: Forwarded to every ``mLSTMBlock``.
        num_layers: Number of stacked blocks.
        proj_factor: Forwarded to every ``mLSTMBlock``. It is stored on the
            instance as ``proj_factor_slstm`` (the attribute name is kept
            as-is for compatibility).
    """

    def __init__(self, input_size, hidden_size, num_heads, num_layers=1, proj_factor=2):
        super().__init__()
        self.input_size, self.hidden_size = input_size, hidden_size
        self.num_heads, self.num_layers = num_heads, num_layers
        self.proj_factor_slstm = proj_factor

        # One mLSTMBlock per stacked layer, all built with the same settings.
        blocks = []
        for _ in range(num_layers):
            blocks.append(mLSTMBlock(input_size, hidden_size, num_heads, proj_factor))
        self.layers = blocks

    def call(self, x):
        """Run the stack over every time step and return the last step's output.

        ``x`` is indexed as ``x[:, t, :]``, i.e. it is treated as
        ``(batch, seq_len, features)``; nothing here validates that shape.
        """
        dims = tf.shape(x)
        batch_size = dims[0]
        seq_len = dims[1]

        # All four recurrent states start at zero, one tensor per layer.
        state_shape = (batch_size, self.hidden_size)
        h = [tf.zeros(state_shape) for _ in range(self.num_layers)]
        c = [tf.zeros(state_shape) for _ in range(self.num_layers)]
        n = [tf.zeros(state_shape) for _ in range(self.num_layers)]
        m = [tf.zeros(state_shape) for _ in range(self.num_layers)]

        outputs = tf.TensorArray(dtype=tf.float32, size=seq_len)

        for t in tf.range(seq_len):
            step = x[:, t, :]
            # Feed the step through the stack; each block updates its own state.
            for i in range(self.num_layers):
                block = self.layers[i]
                step, [h[i], c[i], n[i], m[i]] = block(step, [h[i], c[i], n[i], m[i]])
            outputs = outputs.write(t, step)

        # stack() is time-major; move batch to the front, then keep the last step.
        batch_major = tf.transpose(outputs.stack(), [1, 0, 2])
        return batch_major[:, -1, :]
