# Cross Entropy Loss

本notebook演示了交叉熵损失的三种等价计算方式：
1. 使用PyTorch的CrossEntropyLoss
2. 使用LogSoftmax + NLLLoss组合
3. 手动实现交叉熵计算过程

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Toy classification batch: three samples scored against three classes.
num_classes = 3
# Raw model outputs (logits); one row per sample, one column per class.
raw_predictions = torch.tensor(
    [
        [1.2, 0.5, -0.3],   # sample 1
        [-0.1, 2.0, -0.5],  # sample 2
        [0.3, -0.2, 1.5],   # sample 3
    ]
)
# Ground-truth class index of each sample: 0, 1 and 2 respectively.
targets = torch.tensor([0, 1, 2], dtype=torch.long)

## 方法1：直接使用CrossEntropyLoss

In [3]:
# Method 1: a default-constructed nn.CrossEntropyLoss is fed the raw logits
# and the integer class targets directly.
criterion = nn.CrossEntropyLoss()
loss1 = criterion(input=raw_predictions, target=targets)
print(f"方法1 - CrossEntropyLoss: {loss1.item():.6f}")

方法1 - CrossEntropyLoss: 0.374305


## 方法2：LogSoftmax + NLLLoss组合

In [4]:
# Method 2: the same loss in two explicit stages.
log_softmax, nll_loss = nn.LogSoftmax(dim=1), nn.NLLLoss()

# Stage 1: log-softmax over dim=1 of the raw logits.
log_probs = log_softmax(raw_predictions)
# Stage 2: negative log-likelihood of the target classes.
loss2 = nll_loss(input=log_probs, target=targets)
print(f"方法2 - LogSoftmax + NLLLoss: {loss2.item():.6f}")

方法2 - LogSoftmax + NLLLoss: 0.374305


## 方法3：手动实现交叉熵计算

In [5]:
# One-hot target distribution: row i is the identity-matrix row selected by targets[i].
target_dist = torch.eye(num_classes).index_select(0, targets)

# Cross entropy by hand: -sum(target_dist * log_probs) per sample, then the batch mean.
loss3 = (-target_dist * log_probs).sum(dim=1).mean()
print(f"方法3 - 手动实现: {loss3.item():.6f}")

方法3 - 手动实现: 0.374305


In [6]:
# Sanity check: the three computations must agree pairwise
# (within torch.allclose's default tolerances).
assert torch.allclose(loss1, loss2)
assert torch.allclose(loss2, loss3)

# KL Divergence

KL散度用于衡量两个概率分布之间的差异。
- P: 真实分布
- Q: 预测分布
- KL(P||Q) = Σ P(x) * log(P(x)/Q(x))

In [7]:
# Minimal nn.KLDivLoss example on random data.
kl_loss = nn.KLDivLoss(reduction="batchmean")
# The first argument must be a distribution in log space, so the random
# logits go through log_softmax. It is bound to `pred_log_probs` rather than
# `input` so the Python builtin `input()` is not shadowed for the rest of the
# session.
pred_log_probs = F.log_softmax(torch.randn(3, 5, requires_grad=True), dim=1)
# A batch of target distributions (plain probabilities). Usually this would
# come from the dataset.
target = F.softmax(torch.rand(3, 5), dim=1)
output = kl_loss(pred_log_probs, target)

In [8]:
# Example data: two batches of three discrete distributions, one per row.
P = torch.tensor(
    [
        [0.9000, 0.1000, 0.0000],
        [0.2000, 0.7000, 0.1000],
        [0.1000, 0.2000, 0.7000],
    ]
)
Q = torch.tensor(
    [
        [0.8000, 0.1500, 0.0500],
        [0.2500, 0.6000, 0.1500],
        [0.1500, 0.2500, 0.6000],
    ]
)

# Every row has to sum to 1 to be a valid probability distribution.
assert torch.allclose(P.sum(dim=1), torch.ones(P.shape[0]))
assert torch.allclose(Q.sum(dim=1), torch.ones(Q.shape[0]))

### 方法1：使用PyTorch的KLDivLoss

In [9]:
# Note: KLDivLoss expects its first argument as log-probabilities, so Q is
# passed through log while P is given as plain probabilities.
kl_div = nn.KLDivLoss(reduction="batchmean")
loss_kl = kl_div(input=Q.log(), target=P)
print(f"PyTorch KLDivLoss: {loss_kl.item():.6f}")

PyTorch KLDivLoss: 0.036973


### 方法2：手动实现KL散度

In [10]:
# Add a tiny constant to both distributions to avoid log(0).
epsilon = 1e-15
P, Q = P + epsilon, Q + epsilon

# Renormalise every row so that the probabilities sum to 1 again.
P, Q = (dist / dist.sum(dim=1, keepdim=True) for dist in (P, Q))

# KL(P||Q) = sum_x P(x) * log(P(x) / Q(x)), computed per row and then averaged
# over the rows.
kl_div_manual = (P * (P.log() - Q.log())).sum(dim=1).mean()
print(f"手动实现 KL 散度: {kl_div_manual.item():.6f}")

手动实现 KL 散度: 0.036973


In [11]:
# Sanity check: the KLDivLoss result and the manual computation must match
# within a relative tolerance of 1e-3.
assert torch.isclose(loss_kl, kl_div_manual, rtol=1e-3).all()

### KL散度的特性示例

In [12]:
# Demonstrate that KL divergence is asymmetric: evaluate it in both directions.
kl_pq = (P * (P.log() - Q.log())).sum(dim=1).mean()
kl_qp = (Q * (Q.log() - P.log())).sum(dim=1).mean()

print(f"KL(P||Q): {kl_pq.item():.6f}")
print(f"KL(Q||P): {kl_qp.item():.6f}")
print("注意：KL(P||Q) ≠ KL(Q||P)，说明KL散度是不对称的")

KL(P||Q): 0.036973
KL(Q||P): 0.530659
注意：KL(P||Q) ≠ KL(Q||P)，说明KL散度是不对称的


# Knowledge Distillation

In [13]:
class DistillationLoss(nn.Module):
    """Knowledge-distillation loss mixing a hard-label and a soft-label term.

    The result is ``alpha * hard_loss + (1 - alpha) * soft_loss`` where
    ``hard_loss`` is the cross entropy between the student logits and the
    ground-truth labels, and ``soft_loss`` is the ``batchmean`` KL divergence
    between the temperature-softened teacher and student distributions,
    multiplied by ``temperature ** 2``.

    Args:
        temperature: Divisor applied to both sets of logits before the
            (log-)softmax. Must be strictly positive.
        alpha: Weight of the hard-label term; the soft term is weighted by
            ``1 - alpha``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, temperature: float = 2.0, alpha: float = 0.5) -> None:
        super().__init__()
        # The logits are divided by the temperature in forward(); a zero (or
        # NaN) temperature would silently produce a NaN loss, so reject it
        # here. `not (t > 0)` also catches NaN, unlike `t <= 0`.
        if not temperature > 0:
            raise ValueError(
                f"temperature must be strictly positive, got {temperature}"
            )
        self.temperature = temperature
        self.alpha = alpha
        self.criterion = nn.CrossEntropyLoss()
        self.kl_div = nn.KLDivLoss(reduction="batchmean")

    def forward(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        targets: torch.Tensor,
    ) -> torch.Tensor:
        """Return the combined distillation loss.

        Args:
            student_logits: Raw student outputs; softmax is taken over dim 1.
            teacher_logits: Raw teacher outputs; softmax is taken over dim 1.
            targets: Ground-truth labels passed to ``nn.CrossEntropyLoss``.

        Returns:
            ``alpha * hard_loss + (1 - alpha) * soft_loss``.
        """
        # Hard loss: student predictions against the ground truth.
        hard_loss = self.criterion(student_logits, targets)

        # Soft loss: KLDivLoss takes the student as log-probabilities and the
        # teacher as plain probabilities, both softened by the temperature.
        soft_student = F.log_softmax(student_logits / self.temperature, dim=1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=1)
        # The KL term is rescaled by temperature squared before mixing.
        soft_loss = self.kl_div(soft_student, soft_teacher) * (self.temperature**2)

        # Weighted combination of the two terms.
        loss = self.alpha * hard_loss + (1 - self.alpha) * soft_loss
        return loss


# Example usage.
# A batch of 3 samples with 10 classes each.
batch_size = 3
num_classes = 10

# Distillation loss with an explicit temperature and hard-loss weight.
criterion = DistillationLoss(temperature=3.0, alpha=0.7)

# Randomly generated logits and labels, for demonstration purposes only
# (no seed is set in this cell).
student_logits = torch.randn(batch_size, num_classes)
teacher_logits = torch.randn(batch_size, num_classes)
true_labels = torch.randint(num_classes, (batch_size,))

# Evaluate the loss and report it.
loss = criterion(student_logits, teacher_logits, true_labels)
print(f"Total Distillation Loss: {loss.item()}")

Total Distillation Loss: 1.6067055463790894
