Skip to content

Latest commit

 

History

History
112 lines (70 loc) · 3.32 KB

Metric_cn.rst

File metadata and controls

112 lines (70 loc) · 3.32 KB

Metric

评估器 metric 的基类。

用法:

m = SomeMetric()
for prediction, label in ...:
    m.update(prediction, label)
m.accumulate()

compute 接口的进阶用法:

compute 中可以使用 PaddlePaddle 内置的算子进行评估器的状态,而不是通过 Python/NumPy,这样可以加速计算。`update` 接口将 compute 的输出作为 输入,内部采用 Python/NumPy 计算。

Metric 计算流程如下 (在{}中的表示模型和评估器的计算):

inputs & labels              || ------------------
      |                      ||
   {model}                   ||
      |                      ||
outputs & labels ||
                     || tensor data
{Metric.compute} ||
                     ||
metric states(tensor) ||
                     ||

{fetch as numpy} || ------------------ | ||

metric states(numpy) || numpy data
                     ||

{Metric.update} / ------------------

代码示例 1

以 计算正确率的 Accuracy 为例,该评估器的输入为 predlabel,可以在 compute 中通过 predlabel`先计算正确预测的矩阵。 例如,预测结果包含 10 类,`pred 的 shape 是[N, 10],`label` 的 shape 是[N, 1],N 是 batch size,我们需要计算 top-1 和 top-5 的准确率, 可以在 compute 中计算每个样本的 top-5 得分,正确预测的矩阵的 shape 是[N, 5]。

COPY-FROM: paddle.metric.Metric:code-compute-example

代码示例 2

compute 中的计算,使用内置的算子(可以跑在 GPU 上,使得速度更快)。作为 update 的输入,该接口计算如下:

COPY-FROM: paddle.metric.Metric:code-update-example

方法

reset()

清空状态和计算结果。

返回

无。

update(*args) '''''''''

更新状态。如果定义了 computeupdate 的输入是 compute 的输出。如果没有定义,则输入是网络的输出*output和标签label*, 如:`update(output1, output2, ..., label1, label2,...)` 。

也可以参考 update

accumulate() '''''''''

累积的统计指标,计算和返回评估结果。

返回

评估结果,一般是 一个标量 或 多个标量。

name()

返回 Metric 的名字,一般通过__init__构造函数传入。

返回

评估的名字,string 类型。

compute()

此接口可以通过 PaddlePaddle 内置的算子计算 metric 的状态,可以加速 metric 的计算,为可选的高阶接口。

  • 如果这个接口定义了,输入是网络的输出 outputs 和 标签 labels,定义如:`compute(output1, output2, ..., label1, label2,...)` 。
  • 如果这个接口没有定义,默认的行为是直接将输入参数返回给 update,则其定义如:`update(output1, output2, ..., label1, label2,...)` 。

也可以参考 compute