# In [None]:
import tensorflow as tf
from Recommender_System.utility.decorator import logger
from Recommender_System.algorithm.未跑通.RippleNet.layer import Embedding2D

class GatherLayer(tf.keras.layers.Layer):
    """Look up rows of a fixed table by index so ``tf.gather`` can run on Keras inputs.

    Args:
        ripple_set: table indexed along its first axis (in this file: the per-user
            ripple sets, indexed by user id). It is converted to a tensor once here
            rather than on every call.
    """

    def __init__(self, ripple_set, **kwargs):
        super(GatherLayer, self).__init__(**kwargs)
        # Convert once: stops Keras from attribute-tracking (wrapping) a potentially
        # large nested Python list, and avoids re-converting it on every call.
        self.ripple_set = tf.convert_to_tensor(ripple_set)

    def call(self, inputs):
        # Returns ripple_set[inputs] row-wise; inputs presumably are integer ids — confirm at call site.
        return tf.gather(self.ripple_set, inputs)

class ExpandDimsLayer(tf.keras.layers.Layer):
    """Keras layer wrapping ``tf.expand_dims`` so it can be applied to functional-API tensors.

    Args:
        axis: index at which a new size-1 dimension is inserted.
    """

    def __init__(self, axis, **kwargs):
        super(ExpandDimsLayer, self).__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, axis=self.axis)

    def get_config(self):
        # `axis` is a required constructor argument, so it must be in the config for
        # `from_config` (model save/load, cloning) to be able to re-create the layer.
        config = super(ExpandDimsLayer, self).get_config()
        config['axis'] = self.axis
        return config

class MatMulLayer(tf.keras.layers.Layer):
    """Keras layer returning ``tf.matmul`` of a pair of tensors.

    ``inputs`` is a sequence whose first item is the left operand and whose second
    item is the right operand; any further items are ignored.
    """

    def call(self, inputs):
        lhs, rhs = inputs[0], inputs[1]
        product = tf.matmul(lhs, rhs)
        return product

class SqueezeLayer(tf.keras.layers.Layer):
    """Keras layer wrapping ``tf.squeeze`` so it can be applied to functional-API tensors.

    Args:
        axis: default dimension (or list of dimensions) of size 1 to remove.
    """

    def __init__(self, axis, **kwargs):
        super(SqueezeLayer, self).__init__(**kwargs)
        self.axis = axis

    def call(self, inputs, axis=None):
        """Remove size-1 dimension(s) from ``inputs``.

        Args:
            inputs: tensor to squeeze.
            axis: optional per-call override of the constructor ``axis``; lets callers
                write ``layer(x, axis=2)`` instead of failing on an unexpected keyword.
        """
        return tf.squeeze(inputs, axis=self.axis if axis is None else axis)

    def get_config(self):
        # `axis` is a required constructor argument, so it must be in the config for
        # `from_config` (model save/load, cloning) to be able to re-create the layer.
        config = super(SqueezeLayer, self).get_config()
        config['axis'] = self.axis
        return config

@logger('初始化RippleNet模型：', ('n_entity', 'n_relation', 'hop_size', 'ripple_size', 'dim', 'kge_weight', 'l2', 'item_update_mode'))
def RippleNet_model(n_entity: int, n_relation: int, ripple_set: list, hop_size=2, ripple_size=32, dim=16,
                    kge_weight=0.01, l2=1e-7, item_update_mode='plus_transform', use_all_hops=True) -> tf.keras.Model:
    """Build the RippleNet model as a functional Keras model.

    At every hop, the user's ripple set (head, relation, tail index triples) is used
    to compute a softmax attention over the tails, conditioned on the current item
    embedding. The attended tail vector ``o`` of that hop then updates the item
    embedding for the next hop. The user vector is the sum of all hop outputs, or
    only the last one. The score is ``sigmoid(sum(item * user))``.

    Args:
        n_entity: size of the entity embedding table (item ids are looked up in it too).
        n_relation: size of the relation embedding table.
        ripple_set: per-user ripple sets, gathered by user id. After gathering, they
            are indexed as ``[batch, hop, {0: head, 1: relation, 2: tail}, slot]``.
        hop_size: number of hops to propagate; must be >= 1.
        ripple_size: only reported through the ``logger`` decorator. The ripple size
            actually used is whatever ``ripple_set`` contains.
        dim: embedding dimension.
        kge_weight: weight of the knowledge-graph-embedding term added with ``add_loss``.
        l2: L2 coefficient. It is used as the embedding/kernel regularizer, and also
            to scale the explicit L2 term on the looked-up h/r/t embeddings.
        item_update_mode: how the item embedding changes between hops. One of
            ``'replace'`` (o), ``'plus'`` (i + o), ``'replace_transform'`` (W o) or
            ``'plus_transform'`` (W (i + o)).
        use_all_hops: if True, the user vector is the sum of all hop outputs;
            otherwise only the last hop output is used.

    Returns:
        A ``tf.keras.Model`` with inputs ``[user_id, item_id]`` and output ``score``.

    Raises:
        ValueError: if ``item_update_mode`` is unknown or ``hop_size`` < 1.
    """
    item_update_modes = ('replace', 'plus', 'replace_transform', 'plus_transform')
    if item_update_mode not in item_update_modes:
        raise ValueError(f'item_update_mode must be one of {item_update_modes}, got {item_update_mode!r}')
    if hop_size < 1:
        # With no hops there is no user representation (o_list would be empty).
        raise ValueError(f'hop_size must be >= 1, got {hop_size}')

    l2_weight = l2
    regularizer = tf.keras.regularizers.l2(l2_weight)
    user_id = tf.keras.Input(shape=(), name='user_id', dtype=tf.int32)
    item_id = tf.keras.Input(shape=(), name='item_id', dtype=tf.int32)

    entity_embedding = tf.keras.layers.Embedding(n_entity, dim, embeddings_initializer='glorot_uniform', embeddings_regularizer=regularizer)
    relation_embedding = Embedding2D(n_relation, dim, dim, embeddings_initializer='glorot_uniform', embeddings_regularizer=regularizer)
    transform_matrix = tf.keras.layers.Dense(dim, use_bias=False, kernel_initializer='glorot_uniform', kernel_regularizer=regularizer)

    # Weightless shape helpers; one instance per axis, reused on every hop.
    expand_axis2 = ExpandDimsLayer(axis=2)
    expand_axis3 = ExpandDimsLayer(axis=3)
    squeeze_axis2 = SqueezeLayer(axis=2)
    squeeze_axis3 = SqueezeLayer(axis=3)
    matmul = MatMulLayer()

    def update_item(item, o):
        """Return the item embedding used by the next hop."""
        if item_update_mode == 'replace':
            return o
        if item_update_mode == 'plus':
            return item + o
        if item_update_mode == 'replace_transform':
            return transform_matrix(o)
        return transform_matrix(item + o)  # 'plus_transform'

    i = entity_embedding(item_id)  # batch, dim
    ripple_sets = GatherLayer(ripple_set)(user_id)  # batch, hop_size, hrt, ripple_size

    h, r, t, o_list = [], [], [], []
    for hop in range(hop_size):
        h.append(entity_embedding(ripple_sets[:, hop, 0]))  # batch, ripple_size, dim
        r.append(relation_embedding(ripple_sets[:, hop, 1]))  # batch, ripple_size, dim, dim (assumed Embedding2D output)
        t.append(entity_embedding(ripple_sets[:, hop, 2]))  # batch, ripple_size, dim

        # R h: (batch, ripple, dim, dim) @ (batch, ripple, dim, 1) -> (batch, ripple, dim)
        Rh = squeeze_axis3(matmul([r[hop], expand_axis3(h[hop])]))
        # (R h) . v: (batch, ripple, dim) @ (batch, dim, 1) -> (batch, ripple)
        probs = squeeze_axis2(matmul([Rh, expand_axis2(i)]))
        probs_normalized = tf.keras.activations.softmax(probs)  # over the last (ripple) axis
        # Attention-weighted sum of tails: (batch, ripple, dim) * (batch, ripple, 1) -> (batch, dim)
        o = tf.reduce_sum(t[hop] * expand_axis2(probs_normalized), axis=1)
        i = update_item(i, o)
        o_list.append(o)

    u = sum(o_list) if use_all_hops else o_list[-1]
    score = tf.keras.layers.Activation('sigmoid', name='score')(tf.reduce_sum(i * u, axis=1))

    # KGE term: mean sigmoid(h^T R t) per hop. It is subtracted in add_loss below,
    # so training pushes it up.
    kge_loss = 0
    for hop in range(hop_size):
        # (batch, ripple, 1, dim) @ [(batch, ripple, dim, dim) @ (batch, ripple, dim, 1)] -> (batch, ripple, 1, 1)
        hRt = matmul([expand_axis2(h[hop]), matmul([r[hop], expand_axis3(t[hop])])])
        hRt = squeeze_axis2(squeeze_axis3(hRt))  # batch, ripple
        kge_loss += tf.reduce_mean(tf.sigmoid(hRt))

    l2_loss = 0
    for hop in range(hop_size):
        l2_loss += tf.reduce_sum(tf.square(h[hop]))
        l2_loss += tf.reduce_sum(tf.square(r[hop]))
        l2_loss += tf.reduce_sum(tf.square(t[hop]))

    model = tf.keras.Model(inputs=[user_id, item_id], outputs=score)
    model.add_loss(l2_weight * l2_loss)
    model.add_loss(kge_weight * -kge_loss)
    return model

if __name__ == '__main__':
    # Script entry point intentionally left empty; put ad-hoc test code here.
    ...
