In [5]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# 1. Build the toy dataset: three samples with three features each. The
# regression target for every sample is the first value of its row.
input_data = np.array(
    [
        [1, 2, 3],
        [4, 5, 6],
        [7, 8, 9],
    ]
)
expected_output = np.array([[1], [4], [7]])

# Insert a length-1 time axis so the array is (samples, timesteps, features),
# i.e. every sample is treated as a single timestep: (3, 3) -> (3, 1, 3).
input_data = np.expand_dims(input_data, axis=1)

# Practical Tip: Ensure the input data has clear relationships for the model to learn effectively.
# 2. Define the Memory Network model
def build_memory_network(input_shape):
    """Assemble an LSTM + self-attention regression model.

    Args:
        input_shape: Per-sample shape passed to ``layers.Input``; the
            script in this file uses ``(timesteps, features)``.

    Returns:
        An uncompiled ``models.Model`` mapping the input tensor to the
        output of a single-unit ``Dense`` layer.
    """
    sequence_in = layers.Input(shape=input_shape)

    # Memory component: the LSTM is asked for its output at every timestep
    # (return_sequences=True), so the time axis is preserved.
    lstm = layers.LSTM(32, return_sequences=True)
    encoded = lstm(sequence_in)

    # Attention mechanism: the encoded sequence is passed as both elements
    # of the input list, i.e. it attends over itself.
    self_attention = layers.Attention()
    attended = self_attention([encoded, encoded])

    # Readout: one Dense unit on top of the attended sequence.
    readout = layers.Dense(1)
    prediction = readout(attended)

    return models.Model(inputs=sequence_in, outputs=prediction)

# Build the model
# Take the per-sample shape from the prepared training array instead of
# repeating it as a literal, so the model cannot drift out of sync with the
# data preparation. For the data in this file this is (1, 3):
# 1 time step, 3 features.
input_shape = tuple(input_data.shape[1:])
memory_network = build_memory_network(input_shape)
memory_network.compile(optimizer='adam', loss='mean_squared_error')

# 3. Train the model
# NOTE(review): build_memory_network keeps the time axis
# (return_sequences=True), so predictions appear to be
# (samples, timesteps, 1) while expected_output is (samples, 1) -- confirm
# that Keras aligns these ranks the way you intend for the loss.
memory_network.fit(input_data, expected_output, epochs=500, verbose=1)

# 4. Test the memory network
test_data = np.array([[1, 2, 3]])
# Give test_data the same per-sample layout as input_data; -1 lets NumPy
# infer the number of test samples.
predicted_output = memory_network.predict(test_data.reshape(-1, *input_shape))
print("Predicted Output:", predicted_output)

Epoch 1/500




1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step - loss: 18.1188
Epoch 2/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 49ms/step - loss: 17.8332
Epoch 3/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 17.5494
Epoch 4/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step - loss: 17.2675
Epoch 5/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - loss: 16.9876
Epoch 6/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 16.7097
Epoch 7/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 27ms/step - loss: 16.4341
Epoch 8/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 16.1606
Epoch 9/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step - loss: 15.8895
Epoch 10/500
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 28ms/step - loss: 15.6208
Epoch 11/500


