In [335]:
import tensorflow as tf
import numpy as np
generator = tf.random.Generator.from_seed(1)
data = generator.normal(shape=[40, 3])

In [515]:
SAMPLES = 200
TIME = 6
FEATURES = 5
X = generator.normal(shape=[SAMPLES, TIME, FEATURES])

y = generator.normal(shape=[SAMPLES, 1, 1])

In [516]:
class Attention(tf.keras.layers.Layer):
    def __init__(self):
        super().__init__()
        
        self.weighter = tf.keras.layers.Softmax(axis=2)

    def build(self, input_shape):
        if len(input_shape) != 3:
            raise Exception("Wrong dimensions")
        
        self.batch_size, self.time_steps, self.number_of_features = input_shape
        
        self.weight = self.add_weight(
            name="att_weight",
            shape=(self.number_of_features, self.number_of_features),
            initializer=tf.keras.initializers.RandomUniform(-0.1, 0.1),
            trainable=True,
        )

        self.bias = self.add_weight(
            name="att_bias",
            shape=(self.number_of_features, ),
            initializer='zeros',
            trainable=True,
        )

        self.score_calculator = lambda x: x @ self.weight + self.bias
        self.averager = lambda x: tf.math.reduce_mean(x, axis=1, keepdims=True)
        self.repeater = lambda x: tf.repeat(x, repeats=self.time_steps, axis=1)

        super().build(input_shape)

    def call(self, inputs):
        attention_score = self.score_calculator(inputs)
        attention_weights = self.weighter(attention_score)
        averaged_attention_weight = self.averager(attention_weights)
        averaged_attention_weights = self.repeater(averaged_attention_weight)
        feature_representation = inputs * averaged_attention_weights
        
        return feature_representation, averaged_attention_weight

In [517]:
class STLayer(tf.keras.layers.Layer):
    def __init__(self, freq, model="add", smoothing_param=None, **kwargs):
        super(STLayer, self).__init__(**kwargs)
        self.freq = freq
        self.model = model
        self.smoothing_param = smoothing_param

    def call(self, inputs):
        batch_size = inputs.shape[0]
        num_timesteps = inputs.shape[1]

        trend = np.zeros((batch_size, num_timesteps))
        seasonal = np.zeros((batch_size, num_timesteps))
        residual = np.zeros((batch_size, num_timesteps))

        for i in range(batch_size):
            # Decompose the time series using statsmodels
            stl = sm.tsa.seasonal_decompose(inputs[i,:], model=self.model, period=self.freq, filt=self.smoothing_param)

            trend[i, :] = stl.trend
            seasonal[i, :] = stl.seasonal
            residual[i, :] = stl.resid

        return tf.convert_to_tensor([trend, seasonal, residual])

# # Example usage
# inputs = tf.keras.layers.Input(shape=(timesteps,))
# stl = STLayer(freq=12)(inputs)

# # Use the output of the STLayer to build a neural network for time series forecasting
# # ...

# model = tf.keras.Model(inputs=inputs, outputs=stl)
# model.compile(optimizer="adam", loss="mse")

In [548]:
class IterativeFilterCallback(tf.keras.callbacks.Callback):
    def __init__(self, X, y, model_class, iterations=2, iteration=0, threshold=0.2):
        self.X = X
        self.y = y
        self.iteration = iteration
        self.iterations = iterations
        self.model_class = model_class
        self.threshold = threshold
        
    def on_train_begin(self, logs={}):
        self.attention_scores = []
        self.filtered_indices = []

    # def on_batch_end(self, batch, logs={}):
    #     model = self.model
    #     self.attention_scores.append(tf.math.reduce_mean(model.attention_weights, axis=0))
        
    # def on_epoch_end(self, epoch, logs={}):
    #     model = self.model
    #     pred = model.predict(self.X)
    #     print(pred.shape)

    def on_train_end(self, logs={}):
        if self.iteration >= self.iterations:
            return
        model = self.model
        attention_weights = model.get_attention_weights(self.X)
        averaged_attention_weights = tf.math.reduce_mean(tf.math.reduce_mean(attention_weights, axis=0), axis=0)

        self.mask = averaged_attention_weights >= self.threshold
        self.indices = tf.where(self.mask)
        self.filtered_indices.append(self.indices)

        # Perform the iterative filtering and training process
        filtered_X = tf.squeeze(tf.gather(self.X, self.indices, axis=-1))
        filtered_model = self.model_class()

        filtered_model.compile(
            loss='mse', 
            metrics='mean_squared_error'
            )
        filtered_model.build(filtered_X.shape)

        for orig_layer, filtered_layer in zip(self.model.layers, filtered_model.layers):
            try:
                filtered_layer.set_weights(orig_layer.get_weights())
                print('it worked' + str(filtered_layer))
            except Exception as e:
                print(orig_layer, filtered_layer)
                print(e)
            

        filter_callback = IterativeFilterCallback(filtered_X, y, Model, iterations=self.iterations, iteration=self.iteration+1, threshold=0.1)
        filtered_model.fit(filtered_X, self.y, batch_size=16, epochs=10, callbacks=[filter_callback])
        self.filtered_indices.extend(filter_callback.filtered_indices)


In [549]:
class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        
        self.attention = Attention()
        self.lstm = tf.keras.layers.LSTM(32, activation='softmax', return_sequences=False)
        self.dense = tf.keras.layers.Dense(1)
    
    def call(self, inputs):
        # print("inputs", inputs.shape)
        
        x, attention_weights = self.attention(inputs)
        # print("attention", x.shape)
        # print(attention_weights)

        x = self.lstm(x)
        # print("h", x.shape)
        x = self.dense(x)
        # print("ans", x.shape)
        return x
    
    def get_attention_weights(self, X):
        return self.attention(X)[1]


In [550]:
# Initialize the callback
filter_callback = IterativeFilterCallback(X, y, Model, iterations=4)

In [551]:
model = Model()
model.compile(
    loss='mse', 
    metrics='mean_squared_error',
    run_eagerly=False,
    )

model.fit(X, y, batch_size=32, epochs=10, callbacks=[filter_callback])

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
<__main__.Attention object at 0x00000194BC6A6590> <__main__.Attention object at 0x00000194C8C16110>
Layer attention_221 weight shape (2, 2) is not compatible with provided weight shape (5, 5).
<keras.layers.rnn.lstm.LSTM object at 0x00000194C8994CA0> <keras.layers.rnn.lstm.LSTM object at 0x00000194C8F36BC0>
Layer lstm_221 weight shape (2, 128) is not compatible with provided weight shape (5, 128).
it worked<keras.layers.core.dense.Dense object at 0x00000194C8F375E0>
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
it worked<__main__.Attention object at 0x00000194CB09FCA0>
it worked<keras.layers.rnn.lstm.LSTM object at 0x00000194BEBDEBF0>
it worked<keras.layers.core.dense.Dense object at 0x00000194BEBDF460>
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x194c89b9330>

In [552]:
model.summary()

Model: "model_213"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 attention_220 (Attention)   multiple                  30        
                                                                 
 lstm_220 (LSTM)             multiple                  4864      
                                                                 
 dense_220 (Dense)           multiple                  33        
                                                                 
Total params: 4,927
Trainable params: 4,927
Non-trainable params: 0
_________________________________________________________________


In [553]:
filter_callback.filtered_indices

[<tf.Tensor: shape=(2, 1), dtype=int64, numpy=
 array([[1],
        [4]], dtype=int64)>,
 <tf.Tensor: shape=(2, 1), dtype=int64, numpy=
 array([[0],
        [1]], dtype=int64)>,
 <tf.Tensor: shape=(2, 1), dtype=int64, numpy=
 array([[0],
        [1]], dtype=int64)>,
 <tf.Tensor: shape=(2, 1), dtype=int64, numpy=
 array([[0],
        [1]], dtype=int64)>]