In [1]:
import torch
import torch.nn.functional as F

def smooth_labels(labels, classes, smoothing=0.1, dtype=None):
    """
    Build label-smoothed target distributions from integer class labels.

    The true class of each row receives ``1 - smoothing`` and every other
    class receives ``smoothing / (classes - 1)``, so each row sums to 1.

    labels: Tensor of shape (batch_size,) holding class indices in
        ``[0, classes)``. Any integer dtype is accepted; it is converted to
        int64 because ``scatter_`` requires int64 indices.
    classes: Total number of classes; must be at least 2.
    smoothing: Smoothing factor in [0, 1] (0 means plain one-hot targets).
    dtype: Floating dtype of the result; ``None`` keeps torch's default
        floating dtype.

    Returns a tensor of shape (batch_size, classes) on ``labels.device``.

    Raises ValueError if ``labels`` is not 1-D, ``classes < 2``, or
    ``smoothing`` is outside [0, 1].
    """
    if labels.dim() != 1:
        raise ValueError(
            f"labels must be 1-D (batch_size,), got shape {tuple(labels.shape)}"
        )
    if classes < 2:
        # The off-target mass is split over (classes - 1) entries; with a single
        # class there is nowhere to put it (and the division would be by zero).
        raise ValueError(f"classes must be >= 2, got {classes}")
    if not 0.0 <= smoothing <= 1.0:
        raise ValueError(f"smoothing must be in [0, 1], got {smoothing}")

    confidence = 1.0 - smoothing
    smooth_value = smoothing / (classes - 1)

    # Allocate directly on the labels' device instead of building on CPU and copying.
    one_hot = torch.full(
        (labels.size(0), classes), smooth_value, dtype=dtype, device=labels.device
    )
    one_hot.scatter_(1, labels.long().unsqueeze(1), confidence)
    return one_hot

class LabelSmoothingCrossEntropy(torch.nn.Module):
    """Cross-entropy loss computed against label-smoothed targets.

    Targets come from ``smooth_labels``: ``1 - label_smoothing`` on the true
    class and ``label_smoothing / (C - 1)`` on each of the other ``C - 1``
    classes. With ``label_smoothing=0`` this equals
    ``torch.nn.CrossEntropyLoss()`` (mean reduction). For
    ``label_smoothing > 0`` it is NOT the same as
    ``CrossEntropyLoss(label_smoothing=...)``, which spreads the smoothing
    mass as ``label_smoothing / C`` over all C classes.
    """

    def __init__(self, label_smoothing=0.1):
        super().__init__()
        if not 0.0 <= label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in [0, 1], got {label_smoothing}"
            )
        self.smoothing = label_smoothing

    def forward(self, preds, target):
        """Return the smoothed cross-entropy averaged over the batch.

        preds: logits of shape (batch_size, classes).
        target: integer class indices of shape (batch_size,).
        """
        if preds.dim() != 2:
            # smooth_labels only builds (batch_size, classes) targets, so the
            # extra spatial dims that CrossEntropyLoss accepts are unsupported.
            raise ValueError(
                f"expected 2-D logits of shape (batch_size, classes), "
                f"got shape {tuple(preds.shape)}"
            )
        classes = preds.size(1)
        smoothed_targets = smooth_labels(target, classes, self.smoothing)
        # Normalize over the same axis the class count was read from.
        log_probs = F.log_softmax(preds, dim=1)
        per_sample_loss = torch.sum(-smoothed_targets * log_probs, dim=1)
        return per_sample_loss.mean()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed so the random logits/labels (and thus this check) are reproducible on re-run.
torch.manual_seed(0)
NUM_CLASSES = 10

x, labels = torch.randn(4, NUM_CLASSES), torch.randint(0, NUM_CLASSES, (4,))

# With no smoothing, the custom loss must reduce to plain cross-entropy.
crit1 = LabelSmoothingCrossEntropy(label_smoothing=0)
crit2 = torch.nn.CrossEntropyLoss()
assert torch.allclose(crit1(x, labels), crit2(x, labels))

# With smoothing s, the custom targets put 1 - s on the true class and s / (C - 1)
# on every other class. PyTorch's built-in label_smoothing=e puts 1 - e + e / C and
# e / C instead, so the two losses coincide exactly when e = s * C / (C - 1).
SMOOTHING = 0.1
crit_smooth = LabelSmoothingCrossEntropy(label_smoothing=SMOOTHING)
crit_builtin = torch.nn.CrossEntropyLoss(
    label_smoothing=SMOOTHING * NUM_CLASSES / (NUM_CLASSES - 1)
)
assert torch.allclose(crit_smooth(x, labels), crit_builtin(x, labels))