In [14]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LSTM, Flatten, Lambda, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError
import numpy as np
from numpy.random import randint


然后，我们定义一个名为 “AttentionLayer” 的新层，该层将实现注重注意力机制。

In [15]:
class AttentionLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(AttentionLayer, self).__init__()

    def build(self, input_shape):
        self.w1 = self.add_weight(
            shape=(input_shape[-1], 1), initializer="random_normal", trainable=True
        )
        self.b1 = self.add_weight(
            shape=(input_shape[1], 1), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        e = tf.matmul(inputs, self.w1) + self.b1
        alpha = tf.nn.softmax(e, axis=1)
        output = inputs * alpha
        return tf.reduce_sum(output, axis=1)

In [18]:
from tensorflow.keras.losses import CategoricalCrossentropy

inputs = Input(shape=(10, 8))

lstm_out = LSTM(50, return_sequences=True, dropout=0.1)(inputs)
attention_out = AttentionLayer()(lstm_out)

concatenated_out = Concatenate()([attention_out, Flatten()(lstm_out[:, -1, :])])
dense_out = Dense(10, activation="relu")(concatenated_out)
outputs = Dense(8, activation="softmax")(dense_out)

model = Model(inputs, outputs)
model.compile(optimizer=Adam(learning_rate=0.01), loss=CategoricalCrossentropy())

X = np.random.uniform(size=(1000,8))
y = X[:,:]

model.fit(X, y, epochs=10, batch_size=32)

Epoch 1/10


ValueError: ignored

接下来，我们定义模型的输入层，输入大小为(10,8)。然后，我们定义一个基于LSTM的双向RNN，以捕获输入序列中的时间依赖关系。随后，我们添加一层AttentionLayer（上面定义的自定义层），它将加强我们模型的注意力机制。最后，我们将添加一些全连接层，以将模型的输出映射到所需维数。

In [9]:
inputs = Input(shape=(10, 8))

lstm_out = LSTM(50, return_sequences=True, dropout=0.1)(inputs)
attention_out = AttentionLayer()(lstm_out)

concatenated_out = Concatenate()([attention_out, Flatten()(lstm_out[:, -1, :])])
dense_out = Dense(10, activation="relu")(concatenated_out)
outputs = Dense(8, activation="softmax")(dense_out)

model = Model(inputs, outputs)


这个模型的训练是基于一些模拟数据的，我们先生成了一组大小为 (1000,10,8)的随机数据用来训练我们的模型。接下来，反转了这些数据中的输入序列，然后用反转后的序列作为模型期望的输出。最后，我们使用Adam优化器和均方误差损失函数编译该模型，并通过模型的 fit() 方法对其进行训练。

In [None]:
model.compile(optimizer=Adam(learning_rate=0.01), loss=CategoricalCrossentropy())

X = np.random.uniform(size=(1000, 10, 8))
y = np.reshape(X[:, ::-1, :], (-1, 8))
X = np.repeat(X, repeats=10, axis=0)

model.fit(X, y, epochs=10, batch_size=32)


In [3]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Dense, LSTM, Flatten, Lambda, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import CategoricalCrossentropy

class AttentionLayer(tf.keras.layers.Layer):

    def __init__(self):
        super(AttentionLayer, self).__init__()

    def build(self, input_shape):
        self.w1 = self.add_weight(
            shape=(input_shape[-1], 1), initializer="random_normal", trainable=True
        )
        self.b1 = self.add_weight(
            shape=(input_shape[1], 1), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        e = tf.matmul(inputs, self.w1) + self.b1
        alpha = tf.nn.softmax(e, axis=1)
        output = inputs * alpha
        return tf.reduce_sum(output, axis=1)

inputs = Input(shape=(10, 8))

lstm_out = LSTM(50, return_sequences=True, dropout=0.1)(inputs)
attention_out = AttentionLayer()(lstm_out)

concatenated_out = Concatenate()([attention_out, Flatten()(lstm_out[:, -1, :])])
dense_out = Dense(10, activation="relu")(concatenated_out)
outputs = Dense(8, activation="softmax")(dense_out)

model = Model(inputs, outputs)
model.compile(optimizer=Adam(learning_rate=0.01), loss=CategoricalCrossentropy())

X = np.random.uniform(size=(1000, 10, 8))
y = np.reshape(X[:, ::-1, :], (-1, 8))
X = np.repeat(X, repeats=10, axis=0)

model.fit(X, y, epochs=10, batch_size=32)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x7fca4a18f670>