接下来，我们定义Recorder类（预测记录器），用于在训练过程中逐样本记录模型输出的预测概率、logits以及真实值。

可以通过summary方法获取本轮训练的分类性能摘要（以DataFrame形式返回），也可以通过distribution和average_score方法向TrainAnimator类传送绘图所需的训练数据。

In [None]:
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import  confusion_matrix

class Recorder:
    """
    Collect per-sample predictions and ground-truth values during training.

    Each call to :meth:`add` appends one row per sample to ``self.records``
    containing the three class probabilities, the three logits, the raw
    ground-truth value and the argmax class. :meth:`summary` turns the rows
    into a per-class precision/recall table, while :meth:`distribution` and
    :meth:`average_score` return compact tuples suitable for plotting.

    Class indices: 0 = negative, 1 = abstain, 2 = positive.
    """

    _COLUMNS = [
        'pred_neg', 'pred_abstain', 'pred_pos',
        'logit_neg', 'logit_abstain', 'logit_pos',
        'real', 'predicted_class',
    ]

    def __init__(self, is_logits=True):
        """
        Args:
            is_logits: True if tensors passed to :meth:`add` are raw logits
                (softmax is applied to obtain probabilities); False if they
                are already probabilities (log is applied to obtain logits).
        """
        self.records = pd.DataFrame(columns=self._COLUMNS)
        self.is_logits = is_logits

    def add(self, predict: torch.Tensor, real: torch.Tensor):
        """
        Append one batch of predictions and ground-truth values.

        Args:
            predict: (batch_size, 3) tensor of logits if ``is_logits`` is True,
                otherwise of probabilities.
            real: (batch_size, 1) tensor of raw ground-truth values (not class
                indices; :meth:`summary` buckets them with a threshold).

        Raises:
            ValueError: if either tensor has the wrong shape or the batch
                sizes differ.
        """
        # Check ranks before comparing batch sizes so a 0-d tensor yields the
        # intended ValueError rather than an IndexError on ``shape[0]``.
        if predict.dim() != 2 or predict.shape[1] != 3:
            raise ValueError("预测张量的形状必须是 (batch_size, 3)。")
        if real.dim() != 2 or real.shape[1] != 1:
            raise ValueError("真实值张量的形状必须是 (batch_size, 1)。")
        if predict.shape[0] != real.shape[0]:
            raise ValueError("预测张量和真实值张量的batch_size必须相同。")

        # Detach once so softmax/log below do not extend the autograd graph.
        predict = predict.detach()
        if self.is_logits:
            logits_t = predict
            prob_t = torch.softmax(predict, dim=1)
        else:
            prob_t = predict
            # The epsilon avoids log(0) = -inf for zero probabilities.
            logits_t = torch.log(predict + 1e-9)

        prob = prob_t.cpu().numpy()
        logits = logits_t.cpu().numpy()
        # argmax is the same whether taken over logits or probabilities.
        predicted_class = torch.argmax(predict, dim=1).cpu().numpy()
        # ``[:, 0]`` keeps a 1-D array even for batch_size == 1, whereas
        # ``squeeze()`` would collapse that case to a 0-d scalar.
        real_values = real.detach()[:, 0].cpu().numpy()

        new_records_df = pd.DataFrame({
            'pred_neg': prob[:, 0],
            'pred_abstain': prob[:, 1],
            'pred_pos': prob[:, 2],
            'logit_neg': logits[:, 0],
            'logit_abstain': logits[:, 1],
            'logit_pos': logits[:, 2],
            'real': real_values,
            'predicted_class': predicted_class,
        })
        if self.records.empty:
            # Concatenating onto the empty, object-dtype seed frame is
            # deprecated in recent pandas (FutureWarning) because its dtypes
            # would affect the result; start from the first batch instead.
            self.records = new_records_df
        else:
            self.records = pd.concat([self.records, new_records_df], ignore_index=True)

    def clear(self):
        """Drop all recorded rows while keeping the configured ``is_logits`` mode."""
        self.__init__(is_logits=self.is_logits)

    def summary(self, threshold: float = 0.0) -> pd.DataFrame:
        """
        Build a per-class classification performance table.

        Ground-truth classes are derived from ``real``: values below
        ``-|threshold|`` are class 0, values above ``|threshold|`` are class 2,
        and everything in between (including NaN) is class 1. Predictions are
        the stored argmax classes.

        Returns:
            A DataFrame with one row per class plus a '总计' (total) row.
            Columns are the predicted count, precision, true count, recall
            (labelled 'Accuracy (召回率)') and the "severe" error rate. The
            severe error rate is the share of predictions for class 0/2 whose
            truth is the opposite extreme, and is always 0 for class 1. The
            total row reports overall accuracy in both the precision and
            recall columns, plus the overall severe-error rate. When nothing
            has been recorded, a notice is printed and an empty DataFrame is
            returned.
        """
        if self.records.empty:
            print("记录为空，无法生成摘要。")
            return pd.DataFrame()

        bound = abs(threshold)
        real = self.records['real'].to_numpy(dtype=float)
        # NaN compares False on both sides and therefore falls into class 1.
        y_true = np.where(real < -bound, 0, np.where(real > bound, 2, 1))
        y_pred = self.records['predicted_class'].to_numpy().astype(int)

        # Rows = true class, columns = predicted class.
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

        # A "severe" error predicts one extreme class while the truth is the other.
        opposite = {0: 2, 2: 0}
        results = []
        for i in range(3):
            tp = cm[i, i]
            predicted_count = cm[:, i].sum()
            true_count = cm[i, :].sum()
            precision = tp / predicted_count if predicted_count > 0 else 0
            recall = tp / true_count if true_count > 0 else 0
            severe_error = 0
            if i in opposite and predicted_count > 0:
                severe_error = cm[opposite[i], i] / predicted_count
            results.append({
                '预测为该分类的个数': predicted_count,
                'Precision (精确率)': precision,
                '真实为该分类的个数': true_count,
                'Accuracy (召回率)': recall,
                'Severe (严重错误率)': severe_error
            })

        total_samples = cm.sum()
        total_correct = np.trace(cm)
        total_severe_errors = cm[2, 0] + cm[0, 2]
        overall_accuracy = total_correct / total_samples if total_samples > 0 else 0
        overall_severe_rate = total_severe_errors / total_samples if total_samples > 0 else 0
        results.append({
            '预测为该分类的个数': total_samples,
            'Precision (精确率)': overall_accuracy,
            '真实为该分类的个数': total_samples,
            'Accuracy (召回率)': overall_accuracy,
            'Severe (严重错误率)': overall_severe_rate
        })
        return pd.DataFrame(results, index=['分类 0 (负)', '分类 1 (放弃)', '分类 2 (正)', '总计'])

    def distribution(self) -> tuple[float, float, float]:
        """
        Return the fraction of samples predicted as class 0, 1 and 2.

        Classes that were never predicted get 0.0; an empty recorder returns
        ``(0.0, 0.0, 0.0)``.
        """
        if self.records.empty:
            return (0.0, 0.0, 0.0)
        props = self.records['predicted_class'].value_counts(normalize=True).reindex([0, 1, 2]).fillna(0)
        return (props[0], props[1], props[2])

    def average_score(self) -> tuple[float, float, float]:
        """
        Return the mean of each logit column over all recorded samples.

        Returns:
            ``(mean logit_neg, mean logit_abstain, mean logit_pos)``, or
            ``(0.0, 0.0, 0.0)`` when nothing has been recorded.
        """
        if self.records.empty:
            return (0.0, 0.0, 0.0)

        logit_columns = ['logit_neg', 'logit_abstain', 'logit_pos']
        avg_logits = self.records[logit_columns].mean()

        return (avg_logits['logit_neg'], avg_logits['logit_abstain'], avg_logits['logit_pos'])

Recorder.distribution() 返回各预测类别所占的比例，Recorder.average_score() 返回三个类别 logits 的平均值，二者用于向TrainAnimator传递绘图数据。

In [None]:
import matplotlib.pyplot as plt
from IPython import display
import numpy as np

class TrainAnimator:
    """
    Live 2x3 grid of plots for monitoring training in a Jupyter notebook.

    Subplot indices (row-major): 0 train loss, 1 train class prob,
    2 train class logits, 3 test loss, 4 test class prob, 5 test class logits.
    Every call to :meth:`add` appends points and redraws the whole figure.
    """

    _TITLES = ('train loss', 'train classes prob', 'train classes logits',
               'test loss', 'test classes prob', 'test classes logits')
    # Line formats cycled per series within a subplot.
    _FMTS = ('-', 'm--', 'g-.', 'r:')

    def __init__(self, figsize=(12, 6)):
        self.num_subplots = 6
        self.reset()
        self.fig, self.axes = plt.subplots(2, 3, figsize=figsize)
        self.axes = self.axes.flatten()
        for ax, title in zip(self.axes, self._TITLES):
            self._decorate(ax, title)
        self.fig.tight_layout()

    @staticmethod
    def _decorate(ax, title):
        """Apply the subplot title and grid (must be re-applied after every ``cla()``)."""
        ax.set_title(title)
        # grid(True) explicitly enables the grid; a bare grid() toggles it.
        ax.grid(True)

    def add(self, x, y, subplot_idx=0, labels=None):
        """
        Append data points to one subplot and redraw.

        Args:
            x: Current epoch, either a scalar shared by every series or a
                sequence with one x value per series.
            y: A scalar for single-series plots (e.g. a loss) or a sequence
                with one value per series (e.g. per-class values). Points whose
                x or y is ``None`` are skipped.
            subplot_idx: Index of the target subplot, 0 to ``num_subplots - 1``.
            labels: Optional legend labels, one per series. They are stored
                for the subplot and reused on later redraws.

        Raises:
            ValueError: if ``subplot_idx`` is out of range.
        """
        if not 0 <= subplot_idx < self.num_subplots:
            raise ValueError(f"subplot_idx must be between 0 and {self.num_subplots - 1}.")

        target_plot = self.data[subplot_idx]

        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n

        # Grow the per-series lists on demand, so a later call may carry
        # more series than the first one did.
        while len(target_plot['X']) < n:
            target_plot['X'].append([])
            target_plot['Y'].append([])
        if labels is not None:
            target_plot['labels'] = list(labels)

        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                target_plot['X'][i].append(a)
                target_plot['Y'][i].append(b)

        self.draw()

    def draw(self):
        """Redraw every subplot from the stored data and refresh the notebook output."""
        display.clear_output(wait=True)
        for ax, title, plot_data in zip(self.axes, self._TITLES, self.data):
            ax.cla()
            # cla() also wipes the title and grid, so restore them.
            self._decorate(ax, title)
            labels = plot_data.get('labels') or []
            for j, (xs, ys) in enumerate(zip(plot_data['X'], plot_data['Y'])):
                kwargs = {'label': labels[j]} if j < len(labels) else {}
                ax.plot(xs, ys, self._FMTS[j % len(self._FMTS)], **kwargs)
            # Calling legend() without labelled lines only emits a warning.
            if labels and plot_data['X']:
                ax.legend()
        self.fig.tight_layout()
        display.display(self.fig)

    def reset(self):
        """Clear the recorded data and legend labels of every subplot."""
        self.data = [{'X': [], 'Y': [], 'labels': []} for _ in range(self.num_subplots)]
        print("Animator data has been reset.")