评估器 metric 的基类。
用法:
m = SomeMetric() for prediction, label in ...: m.update(prediction, label) m.accumulate()
compute 接口的进阶用法:
在 compute 中可以使用 PaddlePaddle 内置的算子进行评估器的状态,而不是通过 Python/NumPy,这样可以加速计算。`update` 接口将 compute 的输出作为 输入,内部采用 Python/NumPy 计算。
Metric 计算流程如下 (在{}中的表示模型和评估器的计算):
inputs & labels || ------------------ | || {model} || | ||
- outputs & labels ||
|| tensor data- {Metric.compute} ||
||
- metric states(tensor) ||
||{fetch as numpy} || ------------------ | ||
- metric states(numpy) || numpy data
||{Metric.update} / ------------------
以 计算正确率的 Accuracy 为例,该评估器的输入为 pred 和 label,可以在 compute 中通过 pred 和 label`先计算正确预测的矩阵。 例如,预测结果包含 10 类,`pred 的 shape 是[N, 10],`label` 的 shape 是[N, 1],N 是 batch size,我们需要计算 top-1 和 top-5 的准确率, 可以在 compute 中计算每个样本的 top-5 得分,正确预测的矩阵的 shape 是[N, 5]。
COPY-FROM: paddle.metric.Metric:code-compute-example
在 compute 中的计算,使用内置的算子(可以跑在 GPU 上,使得速度更快)。作为 update 的输入,该接口计算如下:
COPY-FROM: paddle.metric.Metric:code-update-example
清空状态和计算结果。
返回
无。
update(*args) '''''''''
更新状态。如果定义了 compute , update 的输入是 compute 的输出。如果没有定义,则输入是网络的输出*output和标签label*, 如:`update(output1, output2, ..., label1, label2,...)` 。
也可以参考 update 。
accumulate() '''''''''
累积的统计指标,计算和返回评估结果。
返回
评估结果,一般是 一个标量 或 多个标量。
返回 Metric 的名字,一般通过__init__构造函数传入。
返回
评估的名字,string 类型。
此接口可以通过 PaddlePaddle 内置的算子计算 metric 的状态,可以加速 metric 的计算,为可选的高阶接口。
- 如果这个接口定义了,输入是网络的输出 outputs 和 标签 labels,定义如:`compute(output1, output2, ..., label1, label2,...)` 。
- 如果这个接口没有定义,默认的行为是直接将输入参数返回给 update,则其定义如:`update(output1, output2, ..., label1, label2,...)` 。
也可以参考 compute 。