# In [None]:
from typing import List, Tuple, Callable, Dict  # 导入类型注解，用于代码类型检查和文档目的。
import tensorflow as tf  # 导入TensorFlow库，用于机器学习和神经网络。

# Project-local helpers: metric logging (log), top-k evaluation (topk), the TopkData container and the logger decorator.
from Recommender_System.algorithm.common import log, topk
from Recommender_System.utility.evaluation import TopkData
from Recommender_System.utility.decorator import logger

# 函数prepare_ds：准备训练和测试数据集
def prepare_ds(train_data: List[Tuple[int, int, int]], test_data: List[Tuple[int, int, int]],
               batch: int) -> Tuple[tf.data.Dataset, tf.data.Dataset]:
    """Build batched ``tf.data`` pipelines for training and testing.

    Each record is read positionally: field 0 becomes the int32 ``'user_id'`` feature,
    field 1 the int32 ``'item_id'`` feature, and field 2 the label in Keras' default
    float dtype.

    :param train_data: training records; shuffled with a buffer as large as the data set
    :param test_data: test records; kept in their original order
    :param batch: batch size used for both pipelines
    :return: ``(train_ds, test_ds)``
    """
    def to_tensors(records):
        # Pull one positional field out of every record and wrap the column as a tensor.
        def column(index, dtype):
            return tf.constant([record[index] for record in records], dtype=dtype)

        features = {'user_id': column(0, tf.int32), 'item_id': column(1, tf.int32)}
        return features, column(2, tf.keras.backend.floatx())

    train_ds = tf.data.Dataset.from_tensor_slices(to_tensors(train_data))
    train_ds = train_ds.shuffle(len(train_data)).batch(batch)
    test_ds = tf.data.Dataset.from_tensor_slices(to_tensors(test_data)).batch(batch)
    return train_ds, test_ds

# 函数_evaluate：评估模型的性能
def _evaluate(model, dataset, loss_object, mean_metric=None, auc_metric=None,
              precision_metric=None, recall_metric=None):
    """Evaluate ``model`` on ``dataset`` and return its loss, AUC, precision and recall.

    :param model: Keras model; its squeezed output is used as the predicted score
    :param dataset: batched dataset yielding ``(features, label)`` pairs
    :param loss_object: loss callable invoked as ``loss_object(label, score)``
    :param mean_metric: accumulator for the loss; a new ``tf.keras.metrics.Mean`` if omitted
    :param auc_metric: AUC metric; a new ``tf.keras.metrics.AUC`` if omitted
    :param precision_metric: precision metric; a new ``tf.keras.metrics.Precision`` if omitted
    :param recall_metric: recall metric; a new ``tf.keras.metrics.Recall`` if omitted
    :return: ``(loss, auc, precision, recall)`` as scalar tensors
    """
    # The default metric objects used to be created in the signature. That happens once,
    # at import time, and the objects were then shared by every call. Build them per call.
    if mean_metric is None:
        mean_metric = tf.keras.metrics.Mean()
    if auc_metric is None:
        auc_metric = tf.keras.metrics.AUC()
    if precision_metric is None:
        precision_metric = tf.keras.metrics.Precision()
    if recall_metric is None:
        recall_metric = tf.keras.metrics.Recall()

    for metric in (mean_metric, auc_metric, precision_metric, recall_metric):
        # Caller-supplied metrics may still hold state from earlier use, so clear them.
        # Newer Keras names the method ``reset_state``; older versions only have ``reset_states``.
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

    @tf.function
    def evaluate_batch(ui, label):
        score = tf.squeeze(model(ui))
        # Regularisation losses registered on the model are included, as in training.
        loss = loss_object(label, score) + sum(model.losses)
        return score, loss

    for ui, label in dataset:
        score, loss = evaluate_batch(ui, label)

        # Weight each batch's mean loss by its size. Otherwise a short final batch counts
        # as much as a full one and skews the average. This matches the batch-size
        # weighting Keras applies to the loss it reports from ``fit``.
        batch_size = tf.cast(tf.shape(label)[0], loss.dtype)
        mean_metric.update_state(loss, sample_weight=batch_size)
        auc_metric.update_state(label, score)
        precision_metric.update_state(label, score)
        recall_metric.update_state(label, score)

    return mean_metric.result(), auc_metric.result(), precision_metric.result(), recall_metric.result()

# 函数_train_graph：使用TensorFlow图模式进行模型训练
def _train_graph(model, train_ds, test_ds, topk_data, optimizer, loss_object, epochs):
    """Train with a custom ``tf.function`` + ``GradientTape`` loop (execution='graph').

    After every epoch the model is evaluated on both the training and the test set.
    The eight metrics are passed to ``log``, and a top-k evaluation is run.
    """
    scorer = get_score_fn(model)

    @tf.function
    def step(features, target):
        # Forward pass in training mode. The model's regularisation losses are added
        # to the data loss before differentiating.
        with tf.GradientTape() as tape:
            prediction = tf.squeeze(model(features, training=True))
            objective = loss_object(target, prediction) + sum(model.losses)
        variables = model.trainable_variables
        grads = tape.gradient(objective, variables)
        optimizer.apply_gradients(zip(grads, variables))

    for epoch in range(epochs):
        for features, target in train_ds:
            step(features, target)

        # Each evaluation yields (loss, auc, precision, recall): train first, then test.
        train_stats = _evaluate(model, train_ds, loss_object)
        test_stats = _evaluate(model, test_ds, loss_object)
        log(epoch, *train_stats, *test_stats)
        topk(topk_data, scorer)

# _train_eager: trains with Keras model.compile / model.fit (selected by execution='eager').
def _train_eager(model, train_ds, test_ds, topk_data, optimizer, loss_object, epochs):
    """Train through Keras ``compile``/``fit`` (execution='eager').

    ``test_ds`` is passed as validation data. ``RsCallback`` logs the metrics and runs
    the top-k evaluation at the end of every epoch.
    """
    tracked = ['AUC', 'Precision', 'Recall']
    model.compile(optimizer=optimizer, loss=loss_object, metrics=tracked)

    epoch_hook = RsCallback(topk_data, get_score_fn(model))
    model.fit(train_ds, epochs=epochs, verbose=0, validation_data=test_ds, callbacks=[epoch_hook])

# RsCallback: custom Keras callback that, at the end of every epoch, logs the training/validation metrics and runs the top-k evaluation.
class RsCallback(tf.keras.callbacks.Callback):
    """Keras callback that logs metrics and runs a top-k evaluation after each epoch."""

    # Keys read from the Keras ``logs`` dict, in the order ``log`` expects them.
    _LOG_KEYS = ('loss', 'auc', 'precision', 'recall',
                 'val_loss', 'val_auc', 'val_precision', 'val_recall')

    def __init__(self, topk_data: TopkData, score_fn: Callable[[Dict[str, List[int]]], List[float]]):
        """
        :param topk_data: data used for the top-k evaluation
        :param score_fn: maps a dict of feature-name -> list of ids to predicted scores
        """
        super().__init__()
        self.topk_data = topk_data
        self.score_fn = score_fn

    def on_epoch_end(self, epoch, logs=None):
        """Log this epoch's training/validation metrics, then run the top-k evaluation.

        Metrics missing from ``logs`` are reported as 0.
        """
        # ``logs`` defaults to None in the signature. Treat that as "no metrics" rather
        # than failing with AttributeError on ``None.get``.
        logs = logs or {}
        values = [logs.get(key, 0) for key in self._LOG_KEYS]
        log(epoch, *values)
        topk(self.topk_data, self.score_fn)


@logger('开始训练，', ('epochs', 'batch', 'execution'))
def train(model: tf.keras.Model, train_data: List[Tuple[int, int, int]], test_data: List[Tuple[int, int, int]],
          topk_data: TopkData, optimizer=None, loss_object=None, epochs=100, batch=512, execution='eager') -> None:
    """
    Generic training pipeline.

    :param model: model to train
    :param train_data: training records
    :param test_data: test records
    :param topk_data: data used for the top-k evaluation
    :param optimizer: optimizer; defaults to Adam
    :param loss_object: loss function; defaults to BinaryCrossentropy
    :param epochs: number of epochs
    :param batch: batch size
    :param execution: 'eager' trains with model.fit; 'graph' uses a custom loop built on
        tf.function and GradientTape
    :raises ValueError: if ``execution`` is neither 'eager' nor 'graph'
    """
    # Reject unknown modes instead of silently falling back to the graph loop. Otherwise
    # a typo such as 'Eager' quietly changes how the model is trained.
    trainers = {'eager': _train_eager, 'graph': _train_graph}
    if execution not in trainers:
        raise ValueError(f"execution must be 'eager' or 'graph', got {execution!r}")

    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam()
    if loss_object is None:
        loss_object = tf.keras.losses.BinaryCrossentropy()

    train_ds, test_ds = prepare_ds(train_data, test_data, batch)
    trainers[execution](model, train_ds, test_ds, topk_data, optimizer, loss_object, epochs)


@logger('开始测试，', ('batch',))
def test(model: tf.keras.Model, train_data: List[Tuple[int, int, int]], test_data: List[Tuple[int, int, int]],
         topk_data: TopkData, loss_object=None, batch=512) -> None:
    """
    Generic evaluation pipeline.

    Evaluates ``model`` on the training set and then the test set, and logs the eight
    metrics with epoch index -1. A top-k evaluation follows.

    :param model: model to evaluate
    :param train_data: training records
    :param test_data: test records
    :param topk_data: data used for the top-k evaluation
    :param loss_object: loss function; defaults to BinaryCrossentropy
    :param batch: batch size
    """
    loss_fn = tf.keras.losses.BinaryCrossentropy() if loss_object is None else loss_object

    # prepare_ds returns (train_ds, test_ds); evaluate them in that order.
    train_stats, test_stats = [_evaluate(model, ds, loss_fn) for ds in prepare_ds(train_data, test_data, batch)]
    log(-1, *train_stats, *test_stats)
    topk(topk_data, get_score_fn(model))


# get_score_fn函数：获取用于评分的函数
def get_score_fn(model):
    """Wrap ``model`` into a plain Python scoring function.

    The returned callable takes a dict mapping feature names to lists of ints. It converts
    every list to an int32 tensor and runs the model through a ``tf.function``. The
    squeezed model output is returned as a NumPy value.
    """
    # experimental_relax_shapes asks TensorFlow to generalise traced input shapes. Calls
    # with different list lengths can then reuse a trace instead of retracing each time.
    @tf.function(experimental_relax_shapes=True)
    def squeezed_forward(features):
        return tf.squeeze(model(features))

    # score_fn: the scoring callable handed back to the caller.
    def score_fn(ui):
        tensors = {}
        for name, values in ui.items():
            tensors[name] = tf.constant(values, dtype=tf.int32)
        return squeezed_forward(tensors).numpy()

    return score_fn
