# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

# Scratch constant. Written as an explicit product: `1(1-0)` would try to call
# the integer literal and raise TypeError as soon as the module is imported.
# NOTE(review): multiplication is the presumed intent -- confirm with the author.
num = 1 * (1 - 0)
class ImprovedMultiLabelFocalLoss(nn.Module):
    """Multi-label focal loss with majority/minority class re-weighting.

    On top of the usual sigmoid focal loss, positive labels are scaled by a
    per-group factor (``gamma_maj`` or ``gamma_min``).  Positive labels of
    minority classes are additionally scaled by a "typicality" term derived
    from the density of the sample's features under a per-class diagonal
    Gaussian described by ``mu`` and ``var``.
    """

    def __init__(self, Maj_list, Min_list, gamma=2, gamma_min=2.0, gamma_maj=1.0,
                 alpha=None, reduction='mean', eps=1e-8):
        """Configure the loss.

        Args:
            Maj_list: Indices of the majority classes, e.g. ``[0, 2, 4]``.
            Min_list: Indices of the minority classes, e.g. ``[1, 3, 5]``.
                An index listed in both is treated as a minority class.
            gamma: Exponent of the focal modulating factor ``(1 - p_t)``.
            gamma_min: Multiplier applied to positive labels of minority
                classes.
            gamma_maj: Multiplier applied to positive labels of majority
                classes.
            alpha: Optional class-balance weight as in the original focal
                loss; positives get ``alpha`` and negatives ``1 - alpha``.
            reduction: ``'mean'`` or ``'sum'``; any other value returns the
                unreduced per-element loss.
            eps: Added to the variance before the square root for numerical
                stability.
        """
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction
        self.eps = eps
        self.gamma_min = gamma_min
        self.gamma_maj = gamma_maj

        # Class index -> group (0 = majority, 1 = minority).  Minority entries
        # are written last so they win when an index appears in both lists.
        self.class_type = {c: 0 for c in Maj_list}
        self.class_type.update({c: 1 for c in Min_list})

    def forward(self, logits, targets, features, mu, var):
        """Compute the loss.

        Args:
            logits: Raw model outputs, shape ``(batch_size, num_classes)``.
            targets: Ground-truth labels with the same shape as ``logits``,
                encoded as 0/1.  Entries that are neither 0 nor 1 contribute
                zero loss.
            features: Sample features, shape ``(batch_size, feature_dim)``.
            mu: Per-class feature means, indexed as ``mu[class, feature]``.
            var: Per-class feature variances, indexed like ``mu``.

        Returns:
            A scalar tensor for ``reduction`` of ``'mean'`` or ``'sum'``,
            otherwise the ``(batch_size, num_classes)`` loss tensor.
        """
        _, num_classes = logits.shape
        feature_dim = features.shape[1]

        # 1. Plain sigmoid focal loss, element-wise: (batch_size, num_classes).
        sigmoid_p = torch.sigmoid(logits)
        p_t = targets * sigmoid_p + (1 - targets) * (1 - sigmoid_p)
        ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
        focal_loss = torch.pow(1 - p_t, self.gamma) * ce_loss
        if self.alpha is not None:
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            focal_loss = focal_loss * alpha_t

        # 2. Density of every sample under every class's diagonal Gaussian.
        #    One broadcast log_prob call replaces a Python loop over each
        #    (class, feature) pair; summing log-densities over the feature
        #    axis is the log of the product of the per-feature densities.
        class_dist = Normal(
            loc=mu[:num_classes, :feature_dim],
            scale=torch.sqrt(var[:num_classes, :feature_dim] + self.eps),
        )
        log_pdf = class_dist.log_prob(features.unsqueeze(1)).sum(dim=-1)
        pdf_weight = torch.exp(log_pdf.to(focal_loss.dtype))  # (batch_size, num_classes)

        # 3. Build the per-element weight.
        #    Classes missing from both lists default to the majority group.
        minority_cols = torch.tensor(
            [self.class_type.get(c, 0) == 1 for c in range(num_classes)],
            dtype=torch.bool,
            device=logits.device,
        )
        pos_mask = targets == 1
        minority_pos = pos_mask & minority_cols
        majority_pos = pos_mask & ~minority_cols

        # Minority positives: a correct prediction (p >= 0.5) is scaled by the
        # density itself, a wrong one by (2 - density).
        # NOTE(review): (2 - density) shrinks as the density grows and turns
        # negative once the density exceeds 2 (a density is not bounded by 1).
        # The original comment described wrong predictions on typical samples
        # as being penalised more heavily -- confirm the intended formula.
        typicality = torch.where(sigmoid_p >= 0.5, pdf_weight, 2 - pdf_weight)

        # Negatives keep the plain focal loss (weight 1).  The minority term is
        # selected with torch.where rather than multiplied by a 0/1 mask so a
        # non-finite density on an unrelated entry cannot leak in as 0 * inf.
        weight = (targets == 0).to(focal_loss.dtype)
        weight = weight + majority_pos.to(focal_loss.dtype) * self.gamma_maj
        weight = weight + torch.where(
            minority_pos,
            self.gamma_min * typicality,
            torch.zeros_like(focal_loss),
        )
        final_loss = focal_loss * weight

        # 4. Reduction.
        if self.reduction == 'mean':
            return final_loss.mean()
        if self.reduction == 'sum':
            return final_loss.sum()
        return final_loss