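# PyTorch Cross Entropy, Information Entropy, Binary Cross Entropy, Negative Log-Likelihood, KL Divergence, and Cosine Similarity: Principles and Code Walkthrough

From the Bilibili creator deep_thoughts, in the series "PyTorch Source Code Tutorials and Cutting-Edge AI Algorithm Reproductions" (PyTorch源码教程与前沿人工智能算法复现讲解).

P_55: PyTorch cross entropy, information entropy, binary cross entropy, negative log-likelihood, KL divergence, and cosine similarity, principles and code walkthrough: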

https://www.bilibili.com/video/BV1Sv4y1A7dz/?spm_id_from=333.788&vd_source=18e91d849da09d846f771c89a366ed40

Official loss function documentation: https://pytorch.org/docs/stable/nn.html#distance-functions

## 1. Cross Entropy Loss (CE loss)

Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html#torch.nn.CrossEntropyLoss

In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# logits shape: [BS, NC]
batchsize = 2
num_class = 4

logits = torch.randn(batchsize, num_class)
target_indices = torch.randint(num_class, size=(batchsize,))  # delta (one-hot) target distribution, given as class indices
target_logits = torch.randn(batchsize, num_class)  # non-delta (soft) target distribution, before softmax

## 1. Call Cross Entropy loss

### method 1 for CE loss
ce_loss_fun = torch.nn.CrossEntropyLoss()
ce_loss = ce_loss_fun(logits, target_indices)
print(f"cross entropy loss1: {ce_loss}")

### method 2 for CE loss
ce_loss = ce_loss_fun(logits, torch.softmax(target_logits, -1))
print(f"cross entropy loss2: {ce_loss}")


cross entropy loss1: 2.4557747840881348
cross entropy loss2: 2.339287519454956
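
To make the two call styles above less of a black box, here is a minimal sketch that recomputes both results by hand from the log-softmax of the logits. The seed and the variable names (`manual_ce_index`, `manual_ce_prob`, and so on) are illustrative and not taken from the video. It assumes a PyTorch version that accepts class-probability targets, as the cell above already does.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the check is reproducible
logits = torch.randn(2, 4)                                # [batch, classes]
target_indices = torch.randint(4, size=(2,))              # class-index targets
target_probs = torch.softmax(torch.randn(2, 4), dim=-1)   # soft targets

log_probs = F.log_softmax(logits, dim=-1)

# Class-index targets: take -log p[target] per sample, then average over the batch.
manual_ce_index = -log_probs[torch.arange(2), target_indices].mean()

# Probability targets: -sum_c q_c * log p_c per sample, then average over the batch.
manual_ce_prob = -(target_probs * log_probs).sum(dim=-1).mean()

ce_index_matches = torch.allclose(manual_ce_index, F.cross_entropy(logits, target_indices))
ce_prob_matches = torch.allclose(manual_ce_prob, F.cross_entropy(logits, target_probs))
print(ce_index_matches, ce_prob_matches)
```

Both comparisons should agree up to floating-point tolerance, since the default `reduction="mean"` of the cross entropy loss averages over the batch dimension.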


## 2. Negative Log Likelihood Loss (NLL loss)

Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html#torch.nn.NLLLoss

In [23]:
nll_fn = torch.nn.NLLLoss()
nll_loss = nll_fn(torch.log(torch.softmax(logits, dim=-1)), target_indices)  # same result as the CE loss
print(f"negative log-likelihood loss: {nll_loss}")
### cross entropy value = NLL value

negative log-likelihood loss: 2.4557747840881348
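
The cell above builds log-probabilities with `torch.log(torch.softmax(...))`. A possible refinement, sketched here under the assumption that numerical robustness matters for your inputs, is to use `F.log_softmax` instead. The extreme logits below are a made-up example chosen to show where the two formulations differ.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4)
target_indices = torch.randint(4, size=(2,))

# NLL on log-softmax outputs reproduces the cross entropy loss.
nll = F.nll_loss(F.log_softmax(logits, dim=-1), target_indices)
ce = F.cross_entropy(logits, target_indices)
nll_equals_ce = torch.allclose(nll, ce)

# Hypothetical extreme logits: softmax underflows to exactly 0 for the
# second class, so log(softmax) yields -inf, while log_softmax stays finite.
extreme_logits = torch.tensor([[1000.0, 0.0]])
naive_log_probs = torch.log(torch.softmax(extreme_logits, dim=-1))
stable_log_probs = F.log_softmax(extreme_logits, dim=-1)
naive_has_inf = bool(torch.isinf(naive_log_probs).any())
stable_is_finite = bool(torch.isfinite(stable_log_probs).all())
print(nll_equals_ce, naive_has_inf, stable_is_finite)
```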


## 3. Kullback-Leibler divergence loss (KL loss)

Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html#torch.nn.KLDivLoss

In [24]:
kld_loss_fn = torch.nn.KLDivLoss()
kld_loss = kld_loss_fn(torch.log(torch.softmax(logits, dim=-1)), torch.softmax(target_logits, dim=-1))
print(f'Kullback-Leibler divergence loss:{kld_loss}')

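Kullback-Leibler divergence loss:0.3004379868507385


Note: the default `reduction='mean'` divides the total loss by both the batch size and the support size (here 2 × 4 = 8), so the value above is not the per-sample KL divergence averaged over the batch. The PyTorch documentation recommends `reduction='batchmean'`, which divides by the batch size only and matches the mathematical definition.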

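The effect of the reduction mode can be checked with a short sketch like the one below. The seed and the variable names are illustrative. The element-wise mean is computed from `reduction="none"` to mirror what `reduction="mean"` does without triggering the warning.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batchsize, num_class = 2, 4
log_p = F.log_softmax(torch.randn(batchsize, num_class), dim=-1)  # predicted log-probabilities
q = torch.softmax(torch.randn(batchsize, num_class), dim=-1)      # target probabilities

# Mathematical definition: KL(q || p) = sum_c q_c * (log q_c - log p_c), per sample.
per_sample_kl = (q * (q.log() - log_p)).sum(dim=-1)

# "batchmean" sums over classes and averages over the batch.
kl_batchmean = F.kl_div(log_p, q, reduction="batchmean")

# Averaging over every element (what "mean" computes) also divides by num_class.
kl_elementwise_mean = F.kl_div(log_p, q, reduction="none").mean()

batchmean_matches_definition = torch.allclose(kl_batchmean, per_sample_kl.mean())
mean_is_scaled_down = torch.allclose(kl_elementwise_mean * num_class, kl_batchmean)
print(batchmean_matches_definition, mean_is_scaled_down)
```

Note the argument order: `F.kl_div` expects log-probabilities of the prediction first and probabilities of the target second.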

## 4. Verify CE = IE + KLD
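
The identity in this section's heading follows from adding and subtracting the entropy term. As a short derivation, with $q$ denoting the target distribution and $p$ the predicted distribution (notation introduced here for illustration; the notes themselves use only the abbreviations CE, IE, and KLD):

$$
\begin{aligned}
\mathrm{CE}(q, p) &= -\sum_{c} q_c \log p_c \\
&= -\sum_{c} q_c \log q_c + \sum_{c} q_c \log \frac{q_c}{p_c} \\
&= \mathrm{IE}(q) + \mathrm{KLD}(q \,\|\, p)
\end{aligned}
$$

Because $\mathrm{IE}(q)$ does not depend on the logits, minimizing CE and minimizing KLD with respect to the model give the same gradients.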

In [26]:
ce_loss_fn_sample = torch.nn.CrossEntropyLoss(reduction="none")
ce_loss_sample = ce_loss_fn_sample(logits, torch.softmax(target_logits, dim=-1))
print(f"cross entropy loss sample: {ce_loss_sample}")

kld_loss_fn_sample = torch.nn.KLDivLoss(reduction="none")
kld_loss_sample = kld_loss_fn_sample(torch.log(torch.softmax(logits, dim=-1)), torch.softmax(target_logits, dim=-1)).sum(-1)
print(f'Kullback-Leibler divergence loss sample:{kld_loss_sample}')

target_information_entropy = torch.distributions.Categorical(probs=torch.softmax(target_logits, dim=-1)).entropy()
print(f'information entropy sample:{target_information_entropy}')  # IE is constant with respect to the logits; if the target is a delta distribution, IE = 0

print(torch.allclose(ce_loss_sample, kld_loss_sample+target_information_entropy))

cross entropy loss sample: tensor([1.9413, 2.7372])
Kullback-Leibler divergence loss sample:tensor([0.6749, 1.7286])
information entropy sample:tensor([1.2665, 1.0086])
True
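
The delta-distribution case mentioned in the code comment can be checked the same way. This is a minimal sketch with an illustrative seed and variable names: for one-hot targets the entropy term vanishes, so the cross entropy equals the KL divergence sample by sample. `torch.xlogy` is used so that the `0 * log(0)` terms evaluate to 0.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2, 4)
target_indices = torch.randint(4, size=(2,))
one_hot = F.one_hot(target_indices, num_classes=4).float()  # delta target distribution

log_p = F.log_softmax(logits, dim=-1)

# torch.xlogy(q, q) evaluates q * log(q) and returns 0 where q == 0.
entropy = -torch.xlogy(one_hot, one_hot).sum(dim=-1)

# KL(q || p) = sum_c [q_c * log q_c - q_c * log p_c]
kld = (torch.xlogy(one_hot, one_hot) - one_hot * log_p).sum(dim=-1)
ce = F.cross_entropy(logits, target_indices, reduction="none")

entropy_is_zero = bool(torch.all(entropy == 0))
ce_equals_kld = torch.allclose(ce, kld)
print(entropy_is_zero, ce_equals_kld)
```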


## 5. Binary Cross Entropy loss (BCE loss)

Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.BCELoss.html#torch.nn.BCELoss

In [30]:
bce_loss_fn = torch.nn.BCELoss()
logits = torch.randn(batchsize)
prob_1 = torch.sigmoid(logits)
target = torch.randint(2, size=(batchsize, ))
bce_loss = bce_loss_fn(prob_1, target.float())
print(f"binary cross entropy loss: {bce_loss}")

### Use NLL loss in place of BCE loss for binary classification
# Column 0 holds P(y=0) and column 1 holds P(y=1), so the class index selects the matching log-probability
prob_0 = 1 - prob_1.unsqueeze(-1)
prob = torch.cat([prob_0, prob_1.unsqueeze(-1)], dim=-1)
nll_loss_binary = nll_fn(torch.log(prob), target)
print(f"negative log-likelihood loss binary: {nll_loss_binary}")
print(torch.allclose(bce_loss, nll_loss_binary))

binary cross entropy loss: 0.6156336665153503
negative log-likelihood loss binary: 0.6156336665153503
True
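
As a complement to the cell above, the following sketch writes out the BCE formula by hand and compares it with two library calls. `F.binary_cross_entropy_with_logits` is not used in the original notes; it is shown here as an alternative that takes raw logits and applies the sigmoid internally. The seed and variable names are illustrative.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(2)
target = torch.randint(2, size=(2,)).float()  # BCE expects float targets

prob_1 = torch.sigmoid(logits)

# BCE = -[y * log(p) + (1 - y) * log(1 - p)], averaged over the batch.
manual_bce = -(target * prob_1.log() + (1 - target) * (1 - prob_1).log()).mean()

bce = F.binary_cross_entropy(prob_1, target)                        # takes probabilities
bce_with_logits = F.binary_cross_entropy_with_logits(logits, target)  # takes raw logits

manual_matches = torch.allclose(manual_bce, bce)
logits_version_matches = torch.allclose(bce, bce_with_logits)
print(manual_matches, logits_version_matches)
```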


## 6. Cosine Similarity loss

Official documentation: https://pytorch.org/docs/stable/generated/torch.nn.CosineEmbeddingLoss.html#torch.nn.CosineEmbeddingLoss

In [31]:
cosine_loss_fn = torch.nn.CosineEmbeddingLoss()
v1 = torch.randn(batchsize, 512)
v2 = torch.randn(batchsize, 512)
target = torch.randint(2, size=(batchsize, ))*2-1  # map {0, 1} to {-1, +1}: +1 marks a similar pair, -1 a dissimilar pair
cosine_loss = cosine_loss_fn(v1, v2, target)
print(f"cosine similarity loss: {cosine_loss}")

cosine similarity loss: 0.01581306755542755
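
What `CosineEmbeddingLoss` computes can be sketched by hand from `F.cosine_similarity`, following the formula in the official documentation: `1 - cos` for pairs labeled `+1`, and `max(0, cos - margin)` for pairs labeled `-1`. The fixed targets, the seed, and the variable names below are illustrative choices, not taken from the video.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
v1 = torch.randn(2, 512)
v2 = torch.randn(2, 512)
target = torch.tensor([1, -1])  # one similar pair, one dissimilar pair
margin = 0.0                    # default margin of the loss

cos = F.cosine_similarity(v1, v2, dim=-1)

# Similar pairs are pushed toward cos = 1; dissimilar pairs are only
# penalized when their cosine similarity exceeds the margin.
per_pair = torch.where(target == 1, 1 - cos, torch.clamp(cos - margin, min=0))
manual_loss = per_pair.mean()

builtin_loss = F.cosine_embedding_loss(v1, v2, target, margin=margin)
cosine_loss_matches = torch.allclose(manual_loss, builtin_loss, atol=1e-6)
print(cosine_loss_matches)
```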
