In [1]:
import torch
import torch.nn.functional as F


def _zero_line_penalty(P, X, dim, eps):
    """Average mass P puts on lines (rows/columns) whose X-sum along `dim` is zero."""
    zero_mask = (X.sum(dim=dim) == 0).to(P.dtype)  # 1 where the ground-truth line is empty
    # eps keeps the ratio finite (and equal to 0) when X has no empty line at all.
    return (P.sum(dim=dim) * zero_mask).sum() / (zero_mask.sum() + eps)


def specialized_loss(P, X, lambda_row=1.0, lambda_col=1.0, eps=1e-6):
    """
    Matching loss between a predicted soft assignment and a 0/1 ground truth.

    The loss is the sum of three terms:
      1. element-wise binary cross-entropy between P and X (mean-reduced);
      2. a row penalty: the average row-sum of P over rows where X is all zero;
      3. a column penalty: the average column-sum of P over columns where X is
         all zero.

    Args:
        P: predicted (doubly-stochastic) matrix of shape (..., n, m), typically
           (batch, n, m). Values must lie in [0, 1], as required by
           ``F.binary_cross_entropy``.
        X: ground-truth 0/1 matrix with the same shape as P. Bool and integer
           tensors are accepted and cast to ``P.dtype``.
        lambda_row: weight of the row penalty.
        lambda_col: weight of the column penalty.
        eps: small constant added to the penalty denominators to avoid 0/0.

    Returns:
        Scalar tensor holding the total loss.
    """
    # binary_cross_entropy requires the target dtype to match the input dtype.
    X = X.to(P.dtype)

    # 1. Base matching loss.
    base_loss = F.binary_cross_entropy(P, X, reduction='mean')

    # 2. Rows: reduce over the last axis (columns) -> one value per row.
    row_penalty = _zero_line_penalty(P, X, dim=-1, eps=eps)

    # 3. Columns: reduce over the second-to-last axis (rows) -> one value per column.
    col_penalty = _zero_line_penalty(P, X, dim=-2, eps=eps)

    return base_loss + lambda_row * row_penalty + lambda_col * col_penalty


In [7]:
# Smoke test: one 3x3 example where row 1 and column 1 of X are all zero,
# so both penalty terms are active.
# The dtype is set at construction rather than via a trailing .float(): a dtype
# conversion applied after requires_grad=True yields a non-leaf tensor whenever
# it actually converts, and P.grad would then stay None after backward().
P = torch.tensor(
    [[[0.1, 0.2, 0.7], [0.3, 0.3, 0.4], [0.8, 0.05, 0.15]]],
    dtype=torch.float32,
    requires_grad=True,
)
X = torch.tensor([[[0, 0, 1], [0, 0, 0], [1, 0, 0]]], dtype=torch.float32)
loss = specialized_loss(P, X)
loss.backward()  # populates P.grad