PyTorch implementation of Online Label Smoothing (OLS), presented in Delving Deep into Label Smoothing (Zhang et al., 2020).
As the abstract states, OLS is a strategy to generate soft labels based on the statistics of the model predictions for the target category. The core idea is that, instead of using fixed soft labels for every epoch, the soft labels are updated every epoch from the statistics of correctly predicted samples.
More details and experimental results can be found in the paper.
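To make the core idea concrete, here is a minimal sketch of how such per-class statistics could be accumulated; the names (accumulate_stats, stats, soft_labels) are illustrative and do not correspond to the library's internals.

import torch
import torch.nn.functional as F

def accumulate_stats(stats, logits, y):
    # Add the softmax scores of correctly classified samples to the
    # running per-class statistics (one row per target class).
    probs = F.softmax(logits.detach(), dim=1)
    correct = probs.argmax(dim=1) == y
    for c in y[correct].unique():
        stats[c] += probs[correct & (y == c)].sum(dim=0)

k = 4                        # number of classes
stats = torch.zeros(k, k)    # running per-class statistics
# ... call accumulate_stats(stats, logits, y) for every training batch ...
# At the end of the epoch, row-normalize to obtain next epoch's soft labels.
soft_labels = stats / stats.sum(dim=1, keepdim=True).clamp(min=1e-8)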
Usage of OnlineLabelSmoothing is pretty straightforward: use it just as you would use PyTorch CrossEntropyLoss. The only difference is that at the end of each epoch you should call OnlineLabelSmoothing.next_epoch(), which updates the OnlineLabelSmoothing.supervise matrix that provides the soft labels for the next epoch.
from ols import OnlineLabelSmoothing
import torch
k = 4 # Number of classes
b = 32 # Batch size
criterion = OnlineLabelSmoothing(alpha=0.5, n_classes=k, smoothing=0.1)
logits = torch.randn(b, k) # Predictions
y = torch.randint(k, (b,)) # Ground truth
loss = criterion(logits, y)
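For intuition, the returned loss blends a hard and a soft term. The sketch below shows the assumed combination (a simplification, not the library's exact code), where supervise holds one soft-label row per class built from the previous epoch's statistics; smoothing only seeds the very first epoch with uniform label smoothing, before any statistics exist.

import torch
import torch.nn.functional as F

def ols_loss_sketch(logits, y, supervise, alpha=0.5):
    # alpha balances the usual hard-label cross-entropy against the
    # soft loss built from last epoch's per-class statistics (assumed form).
    hard = F.cross_entropy(logits, y)
    soft = torch.mean(torch.sum(-supervise[y] * F.log_softmax(logits, dim=1), dim=1))
    return alpha * hard + (1 - alpha) * soft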
from ols import OnlineLabelSmoothing

criterion = OnlineLabelSmoothing(alpha=..., n_classes=...)

for epoch in range(...):  # loop over the dataset multiple times
    for i, data in enumerate(...):
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    print(f'Epoch {epoch} finished!')
    # Update the soft labels for next epoch
    criterion.next_epoch()
Note: criterion.eval() should be called before eval/testing and criterion.train() should be called before training. Otherwise, test/val statistics will be used in step().
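A minimal epoch skeleton that respects this toggle could look as follows; net, optimizer, train_loader, val_loader and epochs are placeholders for your own objects.

for epoch in range(epochs):
    criterion.train()   # training batches update the running statistics
    net.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        loss.backward()
        optimizer.step()

    criterion.eval()    # val/test batches no longer update the statistics
    net.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            val_loss = criterion(net(inputs), labels)

    criterion.next_epoch()  # roll the accumulated stats into next epoch's soft labels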
With PyTorch Lightning you can simply call next_epoch() at the end of the epoch with:
import pytorch_lightning as pl
from ols import OnlineLabelSmoothing

class LitClassification(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.criterion = OnlineLabelSmoothing(alpha=..., n_classes=...)

    def forward(self, x):
        pass

    def configure_optimizers(self):
        pass

    def training_step(self, train_batch, batch_idx):
        pass

    def on_train_epoch_end(self, **kwargs):
        self.criterion.next_epoch()
Note: In this case criterion.eval() and criterion.train() aren't needed, since PyTorch Lightning handles this automatically.
Install directly from the repository with:

pip install -U git+https://github.com/ankandrew/online-label-smoothing-pt.git
If you find this work useful, please cite the paper:

@misc{zhang2020delving,
    title={Delving Deep into Label Smoothing},
    author={Chang-Bin Zhang and Peng-Tao Jiang and Qibin Hou and Yunchao Wei and Qi Han and Zhen Li and Ming-Ming Cheng},
    year={2020},
    eprint={2011.12562},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}