In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer
from Recommender_System.utility.decorator import logger


class FM_Layer(Layer):
    """Combine a user/item latent interaction with per-user and per-item biases.

    As wired in ``FM_model`` below, the two latent inputs come from
    ``Embedding(n, dim)`` layers and the two bias inputs from
    ``Embedding(n, 1)`` layers applied to scalar ids (presumably
    ``(batch, dim)`` and ``(batch, 1)`` tensors -- confirm for other callers).
    """

    def call(self, inputs):
        """Return ``<u, i> + b_u + b_i`` for each example.

        Args:
            inputs: Sequence of four tensors in the order
                ``(user_embedding, item_embedding, user_bias, item_bias)``.

        Returns:
            The inner product of the two embeddings, kept as a column, plus
            both bias tensors.
        """
        u_vec, i_vec, u_bias, i_bias = inputs
        elementwise = u_vec * i_vec
        # keepdims=True keeps a trailing axis of size 1, so the interaction
        # lines up with the bias tensors when they are added.
        interaction = tf.reduce_sum(elementwise, axis=1, keepdims=True)
        return interaction + u_bias + i_bias


@logger('初始化FM模型：', ('n_user', 'n_item', 'dim', 'l2'))
def FM_model(n_user: int, n_item: int, dim=8, l2=1e-6) -> tf.keras.Model:
    """Build a factorization-machine style user/item scoring model.

    Args:
        n_user: Number of distinct user ids (rows of the user embedding tables).
        n_item: Number of distinct item ids (rows of the item embedding tables).
        dim: Size of each user/item latent vector.
        l2: L2 penalty factor applied to the latent-vector embeddings only;
            the bias embeddings are zero-initialized and unregularized.

    Returns:
        A ``tf.keras.Model`` with inputs ``[user_id, item_id]`` (scalar int32
        ids per example) whose output is the sigmoid of the FM score.
    """
    # Keep the float `l2` argument distinct from the regularizer object built
    # from it instead of rebinding the parameter name to a different type.
    regularizer = tf.keras.regularizers.l2(l2)

    user_id = tf.keras.Input(shape=(), name='user_id', dtype=tf.int32)
    user_embedding = tf.keras.layers.Embedding(n_user, dim, embeddings_regularizer=regularizer)(user_id)
    user_bias = tf.keras.layers.Embedding(n_user, 1, embeddings_initializer='zeros')(user_id)

    item_id = tf.keras.Input(shape=(), name='item_id', dtype=tf.int32)
    item_embedding = tf.keras.layers.Embedding(n_item, dim, embeddings_regularizer=regularizer)(item_id)
    item_bias = tf.keras.layers.Embedding(n_item, 1, embeddings_initializer='zeros')(item_id)

    x = FM_Layer()([user_embedding, item_embedding, user_bias, item_bias])

    # Use an explicit Activation layer rather than calling the activation
    # function directly on a symbolic functional-API tensor.
    out = tf.keras.layers.Activation('sigmoid')(x)
    return tf.keras.Model(inputs=[user_id, item_id], outputs=out)


if __name__ == '__main__':
    # Build the smallest possible model (one user, one item) only to render
    # its architecture diagram with tensor shapes.
    demo_model = FM_model(1, 1)
    tf.keras.utils.plot_model(demo_model, to_file='graph.png', show_shapes=True)
