In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, LayerNormalization, Add, Activation

class GatedResidualNetwork(tf.keras.layers.Layer):
    def __init__(self, units, output_size=None, dropout_rate=0.1):
        super(GatedResidualNetwork, self).__init__()
        self.units = units
        self.output_size = output_size if output_size else units
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        self.dense1 = Dense(self.units)
        self.dense2 = Dense(self.output_size)
        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)
        self.layer_norm = LayerNormalization()
        self.gate = Dense(self.output_size, activation='sigmoid')
        self.activation = Activation('elu')

    def call(self, inputs, **kwargs):
        x = self.dense1(inputs)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        gate = self.gate(inputs)
        return self.layer_norm(inputs + gate * x)
class VariableSelectionNetwork(tf.keras.layers.Layer):
    def __init__(self, input_size, units, dropout_rate=0.1):
        super(VariableSelectionNetwork, self).__init__()
        self.input_size = input_size
        self.units = units
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        self.grn = GatedResidualNetwork(self.units, dropout_rate=self.dropout_rate)
        self.dense = Dense(self.input_size, activation='softmax')

    def call(self, inputs, **kwargs):
        flattened_inputs = tf.concat([tf.expand_dims(i, axis=-1) for i in inputs], axis=-1)
        grn_outputs = self.grn(flattened_inputs)
        weights = self.dense(grn_outputs)
        weighted_sum = tf.reduce_sum(grn_outputs * weights, axis=-1)
        return weighted_sum
class TemporalAttention(tf.keras.layers.Layer):
    def __init__(self, units, num_heads, dropout_rate=0.1):
        super(TemporalAttention, self).__init__()
        self.units = units
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        self.multi_head_attention = tf.keras.layers.MultiHeadAttention(num_heads=self.num_heads, key_dim=self.units)
        self.dropout = tf.keras.layers.Dropout(self.dropout_rate)
        self.layer_norm = LayerNormalization()

    def call(self, inputs, **kwargs):
        attn_output = self.multi_head_attention(inputs, inputs)
        attn_output = self.dropout(attn_output)
        return self.layer_norm(inputs + attn_output)


In [2]:
class TemporalFusionTransformer(tf.keras.Model):
    def __init__(self, input_size, units, num_heads, output_size, dropout_rate=0.1):
        super(TemporalFusionTransformer, self).__init__()
        self.units = units
        self.dropout_rate = dropout_rate
        self.vsn = VariableSelectionNetwork(input_size, units, dropout_rate)
        self.temporal_attention = TemporalAttention(units, num_heads, dropout_rate)
        self.grn = GatedResidualNetwork(units, output_size, dropout_rate)
        self.output_layer = Dense(output_size)

    def call(self, inputs, **kwargs):
        selected_features = self.vsn(inputs)
        attn_output = self.temporal_attention(selected_features)
        grn_output = self.grn(attn_output)
        return self.output_layer(grn_output)

# Example usage
input_size = 10
units = 64
num_heads = 4
output_size = 1
dropout_rate = 0.1

model = TemporalFusionTransformer(input_size, units, num_heads, output_size, dropout_rate)
inputs = tf.random.normal([32, 10, input_size])  # Batch size 32, sequence length 10, input size 10
outputs = model(inputs)
print(outputs.shape)  # Expected output shape: (32, 10, output_size)


Metal device set to: Apple M1

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



InvalidArgumentError: Exception encountered when calling layer 'gated_residual_network_1' (type GatedResidualNetwork).

{{function_node __wrapped__AddV2_device_/job:localhost/replica:0/task:0/device:GPU:0}} Incompatible shapes: [10,10,32] vs. [10,10,64] [Op:AddV2]

Call arguments received by layer 'gated_residual_network_1' (type GatedResidualNetwork):
  • inputs=tf.Tensor(shape=(10, 10, 32), dtype=float32)
  • kwargs={'training': 'None'}