In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F




This minimal example implements mean squared error, $\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N} (\hat{y}_i - y_i)^2$, as a custom `nn.Module` and compares it with PyTorch's built-in `F.mse_loss` to confirm that both produce the same result.

In [None]:
class CustomMSELoss(nn.Module):
    """Hand-written mean squared error loss.

    Squares the element-wise difference between predictions and targets and
    averages it over every element, producing a 0-dim tensor -- the same
    quantity ``F.mse_loss`` returns with its default ``reduction="mean"``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        """Return the mean of the squared residuals ``y_pred - y_true``."""
        residual = y_pred - y_true
        return torch.mean(torch.square(residual))

# Example usage -- seed the RNG so the random inputs (and printed values) are reproducible.
torch.manual_seed(0)
y_pred = torch.randn(5, 1, requires_grad=True)
y_true = torch.randn(5, 1)

# Custom MSE
custom_mse = CustomMSELoss()(y_pred, y_true)

# PyTorch built-in MSE
builtin_mse = F.mse_loss(y_pred, y_true)

# Fail loudly if the two implementations disagree (within float tolerance)
# instead of relying on the reader to compare the printed numbers.
torch.testing.assert_close(custom_mse, builtin_mse)

print(f"Custom MSE: {custom_mse.item():.6f}")
print(f"Built-in MSE: {builtin_mse.item():.6f}")

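This custom loss is a weighted combination of binary cross-entropy (BCE) and mean squared error (MSE), $\mathcal{L} = \alpha \cdot \mathrm{BCE} + (1 - \alpha) \cdot \mathrm{MSE}$, where the weighting factor $\alpha$ balances the two terms. Because `F.binary_cross_entropy` expects probabilities, `y_pred` must lie in $[0, 1]$; in the example below it is produced by a sigmoid.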

In [2]:
class CustomLoss(nn.Module):
    """Weighted sum of binary cross-entropy and mean squared error.

    ``loss = alpha * BCE(y_pred, y_true) + (1 - alpha) * MSE(y_pred, y_true)``

    Args:
        alpha: Weight of the BCE term, in ``[0, 1]``; the MSE term gets
            ``1 - alpha``. Values outside that range would give one term a
            negative weight (training would then *maximize* it), so they are
            rejected.
        from_logits: If ``False`` (default, original behavior), ``y_pred`` must
            already be probabilities in ``[0, 1]``, e.g. a sigmoid output. If
            ``True``, ``y_pred`` is treated as raw logits: BCE is computed with
            the numerically stable ``F.binary_cross_entropy_with_logits`` and
            the MSE term is computed on ``torch.sigmoid(y_pred)``.

    Raises:
        ValueError: If ``alpha`` is not in ``[0, 1]``.
    """

    def __init__(self, alpha=0.5, from_logits=False):
        super().__init__()
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")
        self.alpha = alpha
        self.from_logits = from_logits

    def forward(self, y_pred, y_true):
        """Return the scalar weighted BCE + MSE loss.

        ``y_true`` holds float targets in ``[0, 1]`` with the same shape as
        ``y_pred``, as the BCE functions require.
        """
        if self.from_logits:
            # Fuses sigmoid + BCE, avoiding log(0) for saturated logits.
            bce_loss = F.binary_cross_entropy_with_logits(y_pred, y_true)
            probs = torch.sigmoid(y_pred)
        else:
            bce_loss = F.binary_cross_entropy(y_pred, y_true)
            probs = y_pred
        # MSE is always measured in probability space so both modes agree.
        mse_loss = F.mse_loss(probs, y_true)
        return self.alpha * bce_loss + (1 - self.alpha) * mse_loss

# Example usage -- seed the RNG so the random draw (and the printed loss) is reproducible.
torch.manual_seed(0)
# Keep a handle on the leaf tensor so its .grad is reachable after loss.backward().
logits = torch.randn(5, 1, requires_grad=True)
y_pred = torch.sigmoid(logits)  # probabilities in (0, 1), as F.binary_cross_entropy requires
y_true = torch.randint(0, 2, (5, 1)).float()  # random 0/1 targets, cast to float for BCE

criterion = CustomLoss(alpha=0.7)
loss = criterion(y_pred, y_true)
print(f"Loss: {loss.item()}")

Loss: 1.0910006761550903
