In [14]:
import torch
from torch import nn, Tensor
from torch.nn import functional as F
import warnings

# 均方误差 (mean squared error, MSE / L2 loss)

In [3]:
def mse_loss(input: Tensor, target: Tensor, reduction = 'mean') -> Tensor:
    """Compute the mean squared error (squared L2) loss.

    The element-wise squared difference ``(input - target) ** 2`` is computed
    and then reduced according to ``reduction``. ``input`` and ``target``
    should have the same shape; if they differ, a warning is emitted and the
    subtraction relies on broadcasting.

    Args:
        input (Tensor):  predicted values
        target (Tensor): target values
        reduction (str, optional): 'mean' | 'sum' | 'none'. Defaults to 'mean'.

    Returns:
        Tensor: scalar tensor for 'mean' / 'sum', or the element-wise squared
        errors for 'none'.

    Raises:
        ValueError: if ``reduction`` is not one of 'mean', 'sum' or 'none'.
    """
    # Reject unknown reductions explicitly instead of silently returning None.
    if reduction not in ("mean", "sum", "none"):
        raise ValueError("{} is not a valid value for reduction".format(reduction))

    if target.size() != input.size():
        warnings.warn(
            "Using a target size ({}) that is different to the input size ({}). "
            "This will likely lead to incorrect results due to broadcasting. "
            "Please ensure they have the same size.".format(target.size(), input.size()),
            stacklevel=2,
        )

    result: Tensor = (input - target) ** 2
    if reduction == "mean":
        return result.mean()
    if reduction == "sum":
        return result.sum()
    # reduction == "none": hand back the unreduced, element-wise squared errors.
    return result

## 一维数据

In [22]:
# Ground-truth targets for the 1-D example. Written as floats so the dtype
# matches the float predictions; an integer target can trigger a dtype error
# in the backward pass of the built-in MSE loss.
y_true = torch.tensor([1.0, 0.0, 1.0, 0.0])

In [23]:
# Predicted values for the 1-D example.
y_pred = torch.tensor(
    (0.8, 0.1, 0.7, 0.3),
)

In [24]:
# 'mean' reduction: module API, functional API and the hand-written version,
# printed one per line for comparison.
print(
    nn.MSELoss(reduction='mean')(y_pred, y_true),
    F.mse_loss(y_pred, y_true, reduction='mean'),
    mse_loss(y_pred, y_true, reduction='mean'),
    sep='\n',
)

tensor(0.0575)
tensor(0.0575)
tensor(0.0575)


In [25]:
# 'sum' reduction: module API, functional API and the hand-written version,
# printed one per line for comparison.
print(
    nn.MSELoss(reduction='sum')(y_pred, y_true),
    F.mse_loss(y_pred, y_true, reduction='sum'),
    mse_loss(y_pred, y_true, reduction='sum'),
    sep='\n',
)

tensor(0.2300)
tensor(0.2300)
tensor(0.2300)


In [26]:
# 'none' reduction: each implementation returns the element-wise squared
# errors; printed one result per line for comparison.
print(
    nn.MSELoss(reduction='none')(y_pred, y_true),
    F.mse_loss(y_pred, y_true, reduction='none'),
    mse_loss(y_pred, y_true, reduction='none'),
    sep='\n',
)

tensor([0.0400, 0.0100, 0.0900, 0.0900])
tensor([0.0400, 0.0100, 0.0900, 0.0900])
tensor([0.0400, 0.0100, 0.0900, 0.0900])


## 二维数据

In [6]:
# Ground-truth targets for the 2-D example (two rows of four values).
# Written as floats so the dtype matches the float predictions; an integer
# target can trigger a dtype error in the backward pass of the built-in MSE loss.
y_true = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
])

In [7]:
# Predicted values for the 2-D example (two rows of four values).
y_pred = torch.tensor(
    (
        (0.8, 0.1, 0.7, 0.3),
        (0.9, 0.6, 0.5, 0.3),
    )
)

In [19]:
# 'mean' reduction on the 2-D data: module API, functional API and the
# hand-written version, printed one per line for comparison.
print(
    nn.MSELoss(reduction='mean')(y_pred, y_true),
    F.mse_loss(y_pred, y_true, reduction='mean'),
    mse_loss(y_pred, y_true, reduction='mean'),
    sep='\n',
)

tensor(0.0925)
tensor(0.0925)
tensor(0.0925)


In [20]:
# 'sum' reduction on the 2-D data: module API, functional API and the
# hand-written version, printed one per line for comparison.
print(
    nn.MSELoss(reduction='sum')(y_pred, y_true),
    F.mse_loss(y_pred, y_true, reduction='sum'),
    mse_loss(y_pred, y_true, reduction='sum'),
    sep='\n',
)

tensor(0.7400)
tensor(0.7400)
tensor(0.7400)


In [21]:
# 'none' reduction on the 2-D data: each implementation returns the
# element-wise squared errors; printed one result after another.
print(
    nn.MSELoss(reduction='none')(y_pred, y_true),
    F.mse_loss(y_pred, y_true, reduction='none'),
    mse_loss(y_pred, y_true, reduction='none'),
    sep='\n',
)

tensor([[0.0400, 0.0100, 0.0900, 0.0900],
        [0.0100, 0.1600, 0.2500, 0.0900]])
tensor([[0.0400, 0.0100, 0.0900, 0.0900],
        [0.0100, 0.1600, 0.2500, 0.0900]])
tensor([[0.0400, 0.0100, 0.0900, 0.0900],
        [0.0100, 0.1600, 0.2500, 0.0900]])
