In [None]:
from typing import List
import tensorflow as tf
from tensorflow.keras.layers import Layer
from Recommender_System.algorithm.未跑通.KGCN.layer import SumAggregator, ConcatAggregator, NeighborAggregator
from Recommender_System.utility.decorator import logger

class ExpandDimsLayer(Layer):
    """Keras layer that inserts a size-1 dimension at axis 1 of its input."""

    def call(self, inputs):
        """Return ``inputs`` with a new length-1 axis added at position 1."""
        expanded = tf.expand_dims(inputs, 1)
        return expanded

class GatherLayer(Layer):
    """Keras layer that looks up entries of a table by integer index."""

    def call(self, inputs):
        """Gather slices of ``params`` selected by ``indices``.

        ``inputs`` is an ``(indices, params)`` pair. Any negative index is
        replaced by 0 before the lookup; indices past the end of ``params``
        are not adjusted here.
        """
        indices, params = inputs
        # Replace negative indices with 0 so no lookup uses a negative index.
        is_negative = tf.less(indices, 0)
        safe_indices = tf.where(is_negative, tf.zeros_like(indices), indices)
        return tf.gather(params, safe_indices)

@logger('初始化KGCN模型：', ('n_user', 'n_entity', 'n_relation', 'neighbor_size', 'iter_size', 'dim', 'l2', 'aggregator'))
def KGCN_model(n_user: int, n_entity: int, n_relation: int, adj_entity: List[List[int]], adj_relation: List[List[int]],
               neighbor_size: int, iter_size=2, dim=16, l2=1e-7, aggregator='sum') -> tf.keras.Model:
    """Build a KGCN Keras model that scores (user, item) pairs.

    The model takes two scalar int32 inputs per sample, ``user_id`` and
    ``item_id``. Starting from the item, it looks up ``iter_size`` hops of
    neighbours through ``adj_entity`` / ``adj_relation``, embeds them, folds
    the hops back into a single item vector with the chosen aggregator, and
    outputs ``sigmoid(sum(user_embedding * item_vector, axis=1))``.

    Args:
        n_user: Vocabulary size of the user embedding table.
        n_entity: Vocabulary size of the entity embedding table.
        n_relation: Vocabulary size of the relation embedding table.
        adj_entity: Per-entity list of neighbour entity ids; each row must
            have ``neighbor_size`` entries.
        adj_relation: Per-entity list of relation ids, parallel to
            ``adj_entity``.
        neighbor_size: Number of neighbours per entity.
        iter_size: Number of neighbour hops / aggregation rounds.
        dim: Embedding dimension.
        l2: L2 regularisation factor applied to the embeddings and passed to
            the aggregators as ``kernel_regularizer``.
        aggregator: One of ``'sum'``, ``'concat'`` or ``'neighbor'``.

    Returns:
        A ``tf.keras.Model`` with inputs ``[user_id, item_id]`` and the
        sigmoid score as output.

    Raises:
        ValueError: If ``neighbor_size`` does not match the width of the
            first row of ``adj_entity`` / ``adj_relation``, or if
            ``aggregator`` is not a known name.
    """
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and a width mismatch would otherwise be silently mis-grouped downstream.
    if not (neighbor_size == len(adj_entity[0]) == len(adj_relation[0])):
        raise ValueError(
            f"neighbor_size ({neighbor_size}) must equal the row width of adj_entity "
            f"({len(adj_entity[0])}) and adj_relation ({len(adj_relation[0])})"
        )

    # Resolve the aggregator class up front so a bad name fails before any layer is built.
    aggregator_classes = {
        'sum': SumAggregator,
        'concat': ConcatAggregator,
        'neighbor': NeighborAggregator,
    }
    aggregator_class = aggregator_classes.get(aggregator)
    if aggregator_class is None:
        raise ValueError(f"Unknown aggregator: {aggregator}")

    # Kept under a separate name so the ``l2`` argument is not shadowed.
    regularizer = tf.keras.regularizers.l2(l2)

    user_id = tf.keras.Input(shape=(), name='user_id', dtype=tf.int32)
    item_id = tf.keras.Input(shape=(), name='item_id', dtype=tf.int32)

    user_embedding = tf.keras.layers.Embedding(n_user, dim, embeddings_initializer='glorot_uniform', embeddings_regularizer=regularizer)(user_id)
    entity_embedding = tf.keras.layers.Embedding(n_entity, dim, embeddings_initializer='glorot_uniform', embeddings_regularizer=regularizer)
    relation_embedding = tf.keras.layers.Embedding(n_relation, dim, embeddings_initializer='glorot_uniform', embeddings_regularizer=regularizer)

    flatten = tf.keras.layers.Flatten()
    expand_dims_layer = ExpandDimsLayer()
    gather_layer = GatherLayer()

    # entities[k] holds the entity ids reached after k hops from the item;
    # relations[k] holds the relation ids on the edges from hop k to hop k + 1.
    entities = [expand_dims_layer(item_id)]
    relations = []
    for _ in range(iter_size):
        frontier = entities[-1]
        neighbor_entities = flatten(gather_layer(inputs=(frontier, adj_entity)))
        neighbor_relations = flatten(gather_layer(inputs=(frontier, adj_relation)))
        entities.append(neighbor_entities)
        relations.append(neighbor_relations)

    entity_vectors = [entity_embedding(entity) for entity in entities]
    relation_vectors = [relation_embedding(relation) for relation in relations]

    # Each round merges hop k + 1 into hop k, so the list shrinks by one per round.
    for it in range(iter_size):
        is_last_round = it == iter_size - 1
        # One aggregator instance per round, shared by every hop in that round;
        # the last round uses tanh, the earlier ones relu.
        round_aggregator = aggregator_class(activation='tanh' if is_last_round else 'relu',
                                            kernel_regularizer=regularizer)
        entity_vectors = [
            round_aggregator((entity_vectors[hop], entity_vectors[hop + 1], relation_vectors[hop]),
                             neighbor_size=neighbor_size)
            for hop in range(iter_size - it)
        ]
    assert len(entity_vectors) == 1  # internal invariant: only the item's own vector remains

    # NOTE(review): the two raw tf.* calls below operate on symbolic Keras tensors;
    # Keras 3 reportedly rejects that outside a Layer -- confirm against the target Keras version.
    item_vector = tf.reshape(entity_vectors[0], shape=(-1, dim))  # batch, dim
    score = tf.sigmoid(tf.reduce_sum(user_embedding * item_vector, axis=1))

    return tf.keras.Model(inputs=[user_id, item_id], outputs=score)

if __name__ == '__main__':
    # Smoke test: build the model on a 3-entity toy graph in which every
    # entity has the other two as neighbours; the same table is reused as
    # the relation adjacency.
    toy_adjacency = [[1, 2], [0, 2], [0, 1]]
    model = KGCN_model(3, 3, 3, toy_adjacency, toy_adjacency, neighbor_size=2)
