# Supplementary Material - Towards Understanding Why Label Smoothing Degrades Selective Classification and How to Fix It

This is the code that was included in the supplementary material of our ICLR submission **487**. Namely, the code:
- **trains two ResNet-20**, one with label smoothing and one without, on CIFAR-10
- **plots the Risk-Coverage (RC) curves** of the two models and computes the areas under these curves
- **shows the effect of logit normalization** on the Risk-Coverage curves

You will need to be on a machine with at least a small GPU to run the code properly. The whole notebook will run in approximately **15 minutes** with a consumer-grade GPU including training. If you have no time, you can set TRAIN_MODELS to False and use the networks that we trained in advance to compute the RC curves. This code was designed for linux/MacOS, but small changes (e.g. in the requirements to install the GPU version of PyTorch) should make it work on Windows.

The rest of the code (ImageNet & segmentation training and models) will be made available online after the anonymity period.

### Create the virtual environment

We start by creating the virtual environment and downloading the necessary packages (torch, torchvision, & ipykernel). If you already have an environment (conda or venv) including PyTorch and Torchvision, you can skip these steps. Otherwise, create a venv and install the packages with:

```bash
python3 -m venv .reproduce_env
source .reproduce_env/bin/activate
pip install -r requirements.txt
```

The code was tested with Python 3.10. Please tell us if you encounter any issues.

Then, select .reproduce_env to run your notebook. Now, you can simply run the whole notebook to reproduce the results. You may delete the environment with `rm -rf .reproduce_env` when you are done to save disk space.

In [None]:
# Notebook-wide configuration.
import torch

# Set TRAIN_MODELS to False to skip training and rely on existing checkpoint
# files (the ones trained in advance) when the models are loaded later on.
TRAIN_MODELS = True
# Number of training epochs for each of the two models.
num_epochs = 75
# Prefer the first CUDA device, but fall back to the CPU so that the notebook
# still runs (slowly) on machines without a CUDA-capable GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

## Training the models

We start by defining the ResNet-20 architecture and the training loop. We will train new models depending on the value of the TRAIN_MODELS variable.

In [None]:
from collections.abc import Callable
from typing import Literal

from torch import Tensor, nn
from torch.nn.functional import relu


class _BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    The main branch is conv -> norm -> dropout -> activation -> conv -> norm.
    Its output is added to the shortcut and passed through the activation.
    """

    expansion = 1

    def __init__(
        self,
        in_planes: int,
        planes: int,
        stride: int,
        dropout_rate: float,
        groups: int,
        activation_fn: Callable,
        normalization_layer: nn.Module,
        conv_bias: bool,
    ) -> None:
        """Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of channels produced by both convolutions.
            stride: Stride of the first convolution (and of the projection).
            dropout_rate: Probability used by the ``Dropout2d`` layer.
            groups: Number of groups of every convolution in the block.
            activation_fn: Callable applied as the non-linearity.
            normalization_layer: Called with a channel count to create each
                normalization layer.
            conv_bias: Whether the convolutions carry a bias term.
        """
        super().__init__()
        self.activation_fn = activation_fn
        out_planes = self.expansion * planes

        # Keyword arguments common to the two 3x3 convolutions.
        conv3x3 = {
            "kernel_size": 3,
            "padding": 1,
            "groups": groups,
            "bias": conv_bias,
        }

        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv3x3)
        self.bn1 = normalization_layer(planes)

        # As in timm
        self.dropout = nn.Dropout2d(p=dropout_rate)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv3x3)
        self.bn2 = normalization_layer(planes)

        # A 1x1 projection is only required when the main branch changes the
        # spatial resolution or the channel count; otherwise the shortcut is
        # an empty Sequential, i.e. the identity.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    groups=groups,
                    bias=conv_bias,
                ),
                normalization_layer(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual block to ``x``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.dropout(hidden)
        hidden = self.activation_fn(hidden)

        out = self.conv2(hidden)
        out = self.bn2(out)
        out += self.shortcut(x)
        return self.activation_fn(out)


class _ResNet(nn.Module):
    """ResNet backbone with three or four residual stages and a linear head."""

    def __init__(
        self,
        block,
        num_blocks,
        in_channels: int,
        num_classes: int,
        conv_bias: bool,
        dropout_rate: float,
        groups: int,
        style: Literal["imagenet", "cifar"] = "imagenet",
        in_planes: int = 64,
        activation_fn: Callable = relu,
        normalization_layer: nn.Module = nn.BatchNorm2d,
    ) -> None:
        """ResNet from `Deep Residual Learning for Image Recognition`.

        Args:
            block: Residual block class; must expose an ``expansion``
                attribute and accept the keyword arguments passed in
                ``_make_layer``.
            num_blocks: Number of blocks per stage. A fourth stage is built
                only when this sequence has exactly four entries.
            in_channels: Number of channels of the input images.
            num_classes: Size of the output of the linear head.
            conv_bias: Whether the convolutions carry a bias term.
            dropout_rate: Dropout probability, used inside the blocks and
                before the linear head.
            groups: Number of groups of the convolutions inside the blocks.
            style: ``"imagenet"`` uses a 7x7 stride-2 stem followed by max
                pooling; ``"cifar"`` uses a 3x3 stride-1 stem without pooling.
            in_planes: Width of the stem and of the first stage.
            activation_fn: Callable applied as the non-linearity.
            normalization_layer: Called with a channel count to create each
                normalization layer.

        Raises:
            ValueError: If ``style`` is neither ``"imagenet"`` nor ``"cifar"``.
        """
        super().__init__()

        self.in_planes = in_planes
        self.dropout_rate = dropout_rate
        self.activation_fn = activation_fn
        width = in_planes

        if style not in ("imagenet", "cifar"):
            raise ValueError(f"Unknown style. Got {style}.")
        imagenet_stem = style == "imagenet"

        # Stem convolution; it never uses groups.
        kernel_size, stride, padding = (7, 2, 3) if imagenet_stem else (3, 1, 1)
        self.conv1 = nn.Conv2d(
            in_channels,
            width,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=1,
            bias=conv_bias,
        )
        self.bn1 = normalization_layer(width)
        self.optional_pool = (
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            if imagenet_stem
            else nn.Identity()
        )

        # Arguments that are identical for every stage.
        shared = {
            "dropout_rate": dropout_rate,
            "groups": groups,
            "activation_fn": activation_fn,
            "normalization_layer": normalization_layer,
            "conv_bias": conv_bias,
        }
        self.layer1 = self._make_layer(block, width, num_blocks[0], stride=1, **shared)
        self.layer2 = self._make_layer(block, width * 2, num_blocks[1], stride=2, **shared)
        self.layer3 = self._make_layer(block, width * 4, num_blocks[2], stride=2, **shared)

        has_fourth_stage = len(num_blocks) == 4
        if has_fourth_stage:
            self.layer4 = self._make_layer(
                block, width * 8, num_blocks[3], stride=2, **shared
            )
        else:
            self.layer4 = nn.Identity()
        # The head width follows the width of the last stage actually built.
        linear_multiplier = 8 if has_fourth_stage else 4

        self.dropout = nn.Dropout(p=dropout_rate)
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.flatten = nn.Flatten(1)

        self.linear = nn.Linear(
            width * linear_multiplier * block.expansion,
            num_classes,
        )

    def _make_layer(
        self,
        block,
        planes: int,
        num_blocks: int,
        stride: int,
        dropout_rate: float,
        groups: int,
        activation_fn: Callable,
        normalization_layer: nn.Module,
        conv_bias: bool,
    ) -> nn.Module:
        """Stack the blocks of one stage and update ``self.in_planes``.

        Only the first block of the stage uses ``stride``; the following ones
        use a stride of 1.
        """
        block_strides = [stride, *([1] * (num_blocks - 1))]
        stage = []
        for block_stride in block_strides:
            stage.append(
                block(
                    in_planes=self.in_planes,
                    planes=planes,
                    stride=block_stride,
                    dropout_rate=dropout_rate,
                    groups=groups,
                    activation_fn=activation_fn,
                    normalization_layer=normalization_layer,
                    conv_bias=conv_bias,
                )
            )
            # The next block consumes what this one produces.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def feats_forward(self, x: Tensor) -> Tensor:
        """Return the pooled, flattened features fed to the linear head."""
        out = self.activation_fn(self.bn1(self.conv1(x)))
        pipeline = (
            self.optional_pool,
            self.layer1,
            self.layer2,
            self.layer3,
            self.layer4,
            self.pool,
            self.flatten,
            self.dropout,
        )
        for stage in pipeline:
            out = stage(out)
        return out

    def forward(self, x: Tensor) -> Tensor:
        """Return the logits for ``x``."""
        features = self.feats_forward(x)
        return self.linear(features)


def resnet20(
    in_channels: int,
    num_classes: int,
    conv_bias: bool = True,
    dropout_rate: float = 0.0,
    groups: int = 1,
    style: Literal["imagenet", "cifar"] = "imagenet",
    activation_fn: Callable = relu,
    normalization_layer: nn.Module = nn.BatchNorm2d,
) -> _ResNet:
    """Build a ResNet-20.

    The network has three stages of three ``_BasicBlock`` each and a first
    stage that is 16 channels wide. Every argument is forwarded unchanged to
    ``_ResNet``.
    """
    blocks_per_stage = [3, 3, 3]
    first_stage_width = 16
    net = _ResNet(
        _BasicBlock,
        blocks_per_stage,
        in_channels,
        num_classes,
        conv_bias,
        dropout_rate,
        groups,
        style=style,
        in_planes=first_stage_width,
        activation_fn=activation_fn,
        normalization_layer=normalization_layer,
    )
    return net

Now, we create the datasets and prepare the optimisation loop. We create two models and their corresponding optimizers. One will be trained with Cross-Entropy and the other one with label smoothing $\alpha=0.2$. We train with label smoothing using the dedicated parameter in PyTorch's CrossEntropyLoss.


In [None]:
import copy

import torch
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomCrop,
    RandomHorizontalFlip,
    ToTensor,
)

# Two identically configured ResNet-20s. `model` is later trained with plain
# cross-entropy and `model_ls` with label smoothing.
model = resnet20(in_channels=3, num_classes=10, conv_bias=False, style="cifar").to(device)
model_ls = resnet20(in_channels=3, num_classes=10, conv_bias=False, style="cifar").to(device)

# Both models share the same optimisation recipe: SGD with momentum and a
# learning rate divided by 10 at epochs 25 and 50.
sgd_kwargs = {"lr": 0.1, "momentum": 0.9, "weight_decay": 5e-4}
lr_schedule = {"milestones": [25, 50], "gamma": 0.1}

optimizer = optim.SGD(model.parameters(), **sgd_kwargs)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, **lr_schedule)
optimizer_ls = optim.SGD(model_ls.parameters(), **sgd_kwargs)
scheduler_ls = optim.lr_scheduler.MultiStepLR(optimizer_ls, **lr_schedule)

# Per-channel normalisation statistics, shared by training and evaluation.
# NOTE(review): these values look like the commonly quoted CIFAR-100
# statistics rather than CIFAR-10's; left untouched since existing
# checkpoints presumably rely on them - confirm.
norm_mean = (0.5071, 0.4867, 0.4408)
norm_std = (0.2675, 0.2565, 0.2761)

# Training adds random crops and horizontal flips; evaluation only normalises.
train_transform = Compose(
    [
        RandomCrop(32, padding=4),
        RandomHorizontalFlip(),
        ToTensor(),
        Normalize(norm_mean, norm_std),
    ]
)
test_transform = Compose([ToTensor(), Normalize(norm_mean, norm_std)])

# The last `val_size` entries of the stored index list form the validation
# split; everything before them is used for training.
# (presumably a fixed permutation of the training-set indices - confirm)
val_size = 5000
indices = torch.load("cifar10_index.pth")
train_indices, val_indices = indices[:-val_size], indices[-val_size:]

train_set = CIFAR10(root="data", train=True, download=True, transform=train_transform)
test_set = CIFAR10(root="data", train=False, download=True, transform=test_transform)

train = Subset(train_set, train_indices)
# Deep-copy the subset so that switching the validation data to the
# evaluation transform does not alter the transform of the training data.
val = copy.deepcopy(Subset(train_set, val_indices))
val.dataset.transform = test_transform

loader_kwargs = {"batch_size": 128, "num_workers": 4, "pin_memory": True}
train_dl = DataLoader(train, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val, **loader_kwargs)
test_dl = DataLoader(test_set, **loader_kwargs)

In [None]:
criterion_no_ls = nn.CrossEntropyLoss()
criterion_ls = nn.CrossEntropyLoss(label_smoothing=0.2)


def _validation_accuracy(net):
    """Return the fraction of samples of `val_dl` that `net` classifies correctly."""
    net.eval()
    total = 0
    correct = 0
    with torch.no_grad():
        for x, y in val_dl:
            x = x.to(device)
            y = y.to(device)
            _, predicted = torch.max(net(x), 1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    return correct / total


def _train_and_checkpoint(net, opt, sched, criterion, checkpoint_path):
    """Train `net` for `num_epochs` epochs on `train_dl`, keeping the best weights.

    The validation accuracy is measured on the first epoch, every fifth epoch
    after it, and the last epoch. Whenever it improves on the best value seen
    so far, the state dict of `net` is written to `checkpoint_path`.
    """
    best_acc = 0
    for epoch in range(num_epochs):
        net.train()
        for x, y in train_dl:
            x = x.to(device)
            y = y.to(device)
            opt.zero_grad()
            loss = criterion(net(x), y)
            loss.backward()
            opt.step()
        sched.step()

        if epoch % 5 == 0 or epoch == num_epochs - 1:
            accuracy = _validation_accuracy(net)
            # The reported loss is the one of the last training batch of the epoch.
            print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss.item():.2f}, Accuracy: {accuracy:.2%}")
            if accuracy > best_acc:
                best_acc = accuracy
                print("Saving best model")
                torch.save(net.state_dict(), checkpoint_path)


# Both models go through exactly the same loop; only the loss, the optimizer
# state and the checkpoint file differ.
if TRAIN_MODELS:
    print("Training model without Label Smoothing")
    _train_and_checkpoint(model, optimizer, scheduler, criterion_no_ls, "cifar10_resnet20.pth")

    print("Training model with Label Smoothing")
    _train_and_checkpoint(model_ls, optimizer_ls, scheduler_ls, criterion_ls, "cifar10_resnet20_ls.pth")

In [None]:
# Restore the saved weights of each model from its checkpoint file, move the
# model to the target device, and switch it to evaluation mode.
for net, checkpoint in (
    (model, "cifar10_resnet20.pth"),
    (model_ls, "cifar10_resnet20_ls.pth"),
):
    net.load_state_dict(torch.load(checkpoint, map_location=device))
    net.to(device)
    net.eval()

## Generate the Risk-Coverage curves

In this section, we generate the Risk-Coverage curves for the two models to compare them. We first define a function to compute the curve and apply it to the confidence scores computed for both models, one trained with Cross-Entropy and one with Label Smoothing $\alpha=0.2$. 

In [None]:
def risk_coverage_curve(y_true, y_score, sample_weight=None, dev="cpu"):
    """Compute the risk-coverage curve of a selective classifier.

    Samples are ranked by decreasing confidence. Entry ``k`` of the curve
    describes the situation where only the ``k`` most confident samples are
    kept.

    Args:
        y_true: Tensor that is truthy where the prediction is correct.
        y_score: Tensor of confidence scores, one per sample.
        sample_weight: Optional per-sample cost multiplier (defaults to 1).
        dev: Device on which the helper tensors are created; the inputs are
            expected to live on the same device.

    Returns:
        A tuple ``(risk, coverage, thresholds)``. ``risk`` and ``coverage``
        have ``len(y_score) + 1`` entries, the first one corresponding to
        zero coverage (whose risk is defined as zero). ``thresholds`` holds
        the scores in decreasing order (a sample is selected when its score
        is >= the threshold).
    """
    n_samples = len(y_score)
    weights = 1 if sample_weight is None else sample_weight

    # Most confident samples first.
    order = y_score.argsort(descending=True)

    # A wrong prediction costs its weight, a correct one costs nothing.
    errors = torch.logical_not(y_true.to(bool))
    sorted_cost = (errors * weights)[order]

    # Risk after keeping the k most confident samples: accumulated cost / k.
    accumulated_cost = sorted_cost.cumsum(0)
    n_selected = torch.arange(1, n_samples + 1).to(dev)
    selective_risk = accumulated_cost / n_selected

    # Prepend the zero-coverage point, where the risk is taken to be zero.
    risk = torch.cat([torch.zeros(1).to(dev), selective_risk])
    coverage = torch.linspace(0, 1, n_samples + 1).to(dev)
    thresholds = y_score[order]
    return risk, coverage, thresholds

Now, we compute the logits for both models to get the confidence scores and the correctness of the predictions. We use these tensors to compute the Risk-Coverage curves.

In [None]:
# Per-batch accumulators for the CE model and the label-smoothing (LS) model.
logits, logits_ls = [], []
scores, scores_ls = [], []
correct_samples, correct_samples_ls = [], []

with torch.no_grad():
    for x, y in test_dl:
        x = x.to(device)
        logit = model(x)
        logit_ls = model_ls(x)

        # Confidence score = maximum softmax probability (MSP), moved to the
        # CPU so that it can be compared with the CPU labels.
        score, predicted = logit.softmax(1).cpu().max(1)
        score_ls, predicted_ls = logit_ls.softmax(1).cpu().max(1)

        logits.append(logit)
        logits_ls.append(logit_ls)
        scores.append(score)
        scores_ls.append(score_ls)
        correct_samples.append(predicted == y)
        correct_samples_ls.append(predicted_ls == y)

# Merge the batches into one tensor per quantity.
logits, logits_ls = torch.cat(logits), torch.cat(logits_ls)
scores, scores_ls = torch.cat(scores), torch.cat(scores_ls)
correct_samples = torch.cat(correct_samples)
correct_samples_ls = torch.cat(correct_samples_ls)

# Risk-coverage curves using the MSP as the confidence score.
ce_risk, ce_cov, thresholds = risk_coverage_curve(correct_samples, scores)
ls_risk, ls_cov, thresholds_ls = risk_coverage_curve(correct_samples_ls, scores_ls)

In [None]:
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import seaborn

# Use seaborn's default theme for the figures
seaborn.set_theme()

# Compute and show the risk-coverage curves
colors = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
# Left plot: the whole coverage range. Each legend entry reports the risk at
# full coverage and the area under the RC curve, both in percent.
ax[0].plot(
    ce_cov * 100,
    ce_risk * 100,
    label=f"CE ({ce_risk[-1]*100:.1f}, {torch.trapz(ce_risk,ce_cov).item()*100:.2f})",
    alpha=0.6,
    color="black",
)
ax[0].plot(
    ls_cov * 100,
    ls_risk * 100,
    label=f"LS $\\alpha=0.2$ ({ls_risk[-1]*100:.1f}, {torch.trapz(ls_risk, ls_cov).item()*100:.2f})",
    alpha=0.6,
    linestyle="dotted",
)
ax[0].set_xlabel("%coverage")
# Backslashes of the LaTeX commands are escaped explicitly: "\l" and "\d" are
# invalid escape sequences in a regular string literal.
ax[0].set_ylabel("%risk$\\leftarrow$")
ax[0].legend(title="ResNet-20 (%error$\\downarrow$, %AURC$\\downarrow$)\nCIFAR-10")
ax[0].grid(visible=True, which="both")
ax[0].set_xlim([0, 100])
ax[0].set_ylim(ymin=0)
ax[0].minorticks_on()
# Dashed rectangle marking the region that the right plot zooms into
ax[0].add_patch(patches.Rectangle((0, 0), 60, 6, fill=False, edgecolor="black", linestyle="--", alpha=0.5))
# Right plot: zoom on coverage in [0, 60] and risk in [0, 6]
ax[1].plot(ce_cov * 100, ce_risk * 100, label="CE", alpha=0.5, color="black")
ax[1].plot(ls_cov * 100, ls_risk * 100, label="LS $\\alpha=0.2$", alpha=0.7, linestyle="dotted")
ax[1].grid(visible=True, which="both")
ax[1].set_xlim([0, 60])
ax[1].set_ylim([0, 6])
ax[1].minorticks_on()
ax[1].set_xlabel("%coverage")

fig.tight_layout()
plt.show()

## Risk-Coverage curves after logit normalization

We also show the effect of logit normalization on the Risk-Coverage curves. We first define the logit normalization function and apply it to the confidence scores computed for both models. We optimize the order of the p-norm to reduce the area under the curve using validation logits. We then plot the Risk-Coverage curves for the two models and compare them.

In [None]:
# The temperature of the normalization is left out: it only had a small effect
def norm_logits(logits, p=2):
    """Divide each row of ``logits`` by its p-norm taken over dimension 1."""
    row_norms = logits.norm(p=p, dim=1)
    return logits / row_norms.unsqueeze(1)

#### Optimisation of $p$ for logit normalization

We now optimize the order of the p-norm to reduce the area under the curve using validation logits. We select the best value for the plots below. However, please note that the variability among the best values is very low and the results are not very sensitive to the choice of $p$.

In [None]:
from tqdm.auto import tqdm

# Gather the logits of the CE model and the labels over the validation set
logits, labels = [], []
with torch.no_grad():
    for x, y in val_dl:
        x = x.to(device)
        y_hat = model(x).cpu()
        logits.append(y_hat)
        labels.append(y)

logits = torch.cat(logits)
labels = torch.cat(labels)

# Candidate orders of the p-norm
ps = list(range(1, 9))
best_aurc = 1.0
best_p = None
best_aurcs = []
# for plotting
best_risk = None
best_coverage = None

# Normalising the logits does not change the predicted class, so correctness
# is computed once, outside of the loop.
correct_idx = logits.argmax(dim=-1) == labels
logits_double = logits.double().cpu()
for p in tqdm(ps):
    # Confidence score: largest p-normalised logit
    normalised = norm_logits(logits_double, p=p)
    msp_scaled = normalised.max(dim=-1).values
    risk_scaled, cov_scaled, _ = risk_coverage_curve(correct_idx, msp_scaled)
    aurc = torch.trapz(risk_scaled, cov_scaled)
    print(f"p={p}, AURC={aurc:.4f}")
    # Keep the order giving the smallest area under the RC curve
    if aurc < best_aurc:
        best_aurc, best_p = aurc, p
        best_risk, best_coverage = risk_scaled, cov_scaled

In [None]:
# Per-batch accumulators for the CE model and the label-smoothing (LS) model.
scores, scores_ls = [], []
correct_samples, correct_samples_ls = [], []

with torch.no_grad():
    for x, y in test_dl:
        x = x.to(device)
        # No softmax here: the confidence score is the largest logit after
        # p-norm normalisation, with the order selected on the validation set.
        y_hat = norm_logits(model(x).cpu(), p=best_p)
        y_hat_ls = norm_logits(model_ls(x).cpu(), p=best_p)
        score, predicted = y_hat.max(1)
        score_ls, predicted_ls = y_hat_ls.max(1)

        scores.append(score)
        scores_ls.append(score_ls)
        correct_samples.append(predicted == y)
        correct_samples_ls.append(predicted_ls == y)

# Merge the batches into one tensor per quantity.
scores, scores_ls = torch.cat(scores), torch.cat(scores_ls)
correct_samples = torch.cat(correct_samples)
correct_samples_ls = torch.cat(correct_samples_ls)

# Risk-coverage curves using the normalised max logit as confidence score.
ce_risk, ce_cov, thresholds = risk_coverage_curve(correct_samples, scores)
ls_risk, ls_cov, thresholds_ls = risk_coverage_curve(correct_samples_ls, scores_ls)

To finish, we plot the Risk-Coverage curves for the two models using normalised logits and compare them. We see that the RC curve of the model trained with LS $\alpha=0.2$ has been much improved.

In [None]:
import matplotlib.patches as patches
import matplotlib.pyplot as plt

colors = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
# Left plot: the whole coverage range. Each legend entry reports the risk at
# full coverage and the area under the RC curve, both in percent.
ax[0].plot(
    ce_cov * 100,
    ce_risk * 100,
    label=f"CE ({ce_risk[-1]*100:.1f}, {torch.trapz(ce_risk,ce_cov).item()*100:.2f})",
    alpha=0.6,
    color="black",
)
ax[0].plot(
    ls_cov * 100,
    ls_risk * 100,
    label=f"LS $\\alpha=0.2$ ({ls_risk[-1]*100:.1f}, {torch.trapz(ls_risk, ls_cov).item()*100:.2f})",
    alpha=0.6,
    linestyle="dotted",
)
ax[0].set_xlabel("%coverage")
# Backslashes of the LaTeX commands are escaped explicitly: "\l" and "\d" are
# invalid escape sequences in a regular string literal.
ax[0].set_ylabel("%risk$\\leftarrow$")
ax[0].legend(title="ResNet-20 (%error$\\downarrow$, %AURC$\\downarrow$)\nCIFAR-10")
ax[0].grid(visible=True, which="both")
ax[0].set_xlim([0, 100])
ax[0].set_ylim(ymin=0)
ax[0].minorticks_on()
# Dashed rectangle marking the region that the right plot zooms into
ax[0].add_patch(patches.Rectangle((0, 0), 60, 6, fill=False, edgecolor="black", linestyle="--", alpha=0.5))
# Right plot: zoom on coverage in [0, 60] and risk in [0, 6]
ax[1].plot(ce_cov * 100, ce_risk * 100, label="CE", alpha=0.5, color="black")
ax[1].plot(ls_cov * 100, ls_risk * 100, label="LS $\\alpha=0.2$", alpha=0.7, linestyle="dotted")
ax[1].grid(visible=True, which="both")
ax[1].set_xlim([0, 60])
ax[1].set_ylim([0, 6])
ax[1].minorticks_on()
ax[1].set_xlabel("%coverage")

fig.tight_layout()
plt.show()

## Max logits given MSP plots

In this part, we plot the distribution of the max logit $v_\text{max}$ *given* the MSP $\pi_\text{max}$ for correct and incorrect predictions of the ResNet-20 on the CIFAR-10 evaluation subset. $v_\text{max}$ is *lower* for the LS model, whilst the distributions are roughly similar for CE. This empirically matches the imbalanced max logit regularisation described in the paper. We calculate the mean $\pm$ std. in a 0.05-wide sliding window.

In [None]:
# Recompute the MSP scores, raw logits and correctness masks on the test set
# for the CE model and the label-smoothing (LS) model.
logits, logits_ls = [], []
scores, scores_ls = [], []
correct_samples, correct_samples_ls = [], []

with torch.no_grad():
    for x, y in test_dl:
        x = x.to(device)
        logit = model(x)
        logit_ls = model_ls(x)

        # Maximum softmax probability and predicted class, on the CPU
        probs = logit.softmax(1).cpu()
        probs_ls = logit_ls.softmax(1).cpu()
        score, predicted = probs.max(1)
        score_ls, predicted_ls = probs_ls.max(1)

        # The raw logits are stored as returned by the models, i.e. they stay
        # on `device`.
        logits.append(logit)
        logits_ls.append(logit_ls)
        scores.append(score)
        scores_ls.append(score_ls)
        correct_samples.append(predicted == y)
        correct_samples_ls.append(predicted_ls == y)

# Merge the batches into one tensor per quantity.
logits, logits_ls = torch.cat(logits), torch.cat(logits_ls)
scores, scores_ls = torch.cat(scores), torch.cat(scores_ls)
correct_samples = torch.cat(correct_samples)
correct_samples_ls = torch.cat(correct_samples_ls)

ce_risk, ce_cov, thresholds = risk_coverage_curve(correct_samples, scores)
ls_risk, ls_cov, thresholds_ls = risk_coverage_curve(correct_samples_ls, scores_ls)

In [None]:
import numpy as np


def window_average(
    softmax,
    logits,
    window_size=0.05,
    step=0.005,
    name="",
    empty_value=float("nan"),
) -> tuple[list[float], list[float], list[float]]:
    """Sliding-window mean and standard deviation of ``logits`` given ``softmax``.

    A window of width ``window_size`` slides over [0, 1] in increments of
    ``step``. For each window centre ``t``, the statistics are computed over
    the entries of ``logits`` whose matching ``softmax`` value lies strictly
    inside ``(t - window_size / 2, t + window_size / 2)``.

    Args:
        softmax: Tensor of values used to assign samples to windows
            (assumed 1-D and aligned with ``logits``).
        logits: Tensor of values to average (assumed 1-D).
        window_size: Width of the sliding window.
        step: Distance between two consecutive window centres.
        name: Unused; kept for backward compatibility.
        empty_value: Value stored for both statistics when a window contains
            no sample. The default (NaN) is what taking the mean/std of an
            empty selection produced before, and leaves a gap in line plots;
            pass ``0.0`` to store zeros instead.

    Returns:
        ``(centres, stds, means)``: three lists with one entry per window.
        Note the order: the standard deviations come before the means.
    """
    lhat_mean, lhat_std = [], []
    iterator = []
    half_window = window_size / 2
    t = half_window
    end = 1 - half_window

    while t <= end:
        # Boolean mask of the samples falling strictly inside the window.
        in_window = (t - half_window < softmax) & (softmax < t + half_window)
        selected = logits[in_window]
        # `Tensor.size()` returns a `torch.Size`, which never equals the
        # integer 0; the element count is what tells whether the window is
        # empty.
        if selected.numel() != 0:
            lhat_mean.append(torch.mean(selected).item())
            lhat_std.append(torch.std(selected).item())
        else:
            print("No data at ", t)
            lhat_mean.append(empty_value)
            lhat_std.append(empty_value)

        iterator.append(t)
        t += step
    return iterator, lhat_std, lhat_mean

In [None]:
# Split the MSP scores and the max logits of each model into correct and
# incorrect predictions; everything is moved to the CPU for plotting.
msp_ce_correct = scores[correct_samples].cpu()
msp_ls_correct = scores_ls[correct_samples_ls].cpu()
ce_correct_max_logits = logits.max(-1).values[correct_samples].cpu()
ls_correct_max_logits = logits_ls.max(-1).values[correct_samples_ls].cpu()
msp_ce_err = scores[~correct_samples].cpu()
msp_ls_err = scores_ls[~correct_samples_ls].cpu()
ce_err_max_logits = logits.max(-1).values[~correct_samples].cpu()
ls_err_max_logits = logits_ls.max(-1).values[~correct_samples_ls].cpu()

fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharex=True, sharey=False)
id_correct_color = [i / 255 for i in [79, 156, 227]]
id_incorrect_color = [i / 255 for i in [227, 102, 79]]
ood_color = [i / 255 for i in [164, 212, 68]]
# CE
# Left panel: max logit given the MSP (mean +/- std) for the CE model.
# `window_average` returns (window centres, stds, means), in that order.
unc_range = np.array([[0, 1], [3, 25]])
x, std, y = window_average(msp_ce_correct, ce_correct_max_logits)
axes[0].plot(x, y, label="Test ✓", color=id_correct_color, alpha=0.8)
axes[0].fill_between(
    x,
    np.array(y) - np.array(std),
    np.array(y) + np.array(std),
    alpha=0.15,
    color="blue",
)
x, std, y = window_average(msp_ce_err, ce_err_max_logits)
axes[0].plot(x, y, label="Test ✗", color=id_incorrect_color, alpha=0.8)
axes[0].fill_between(x, np.array(y) - np.array(std), np.array(y) + np.array(std), alpha=0.15, color="red")
axes[0].set_title("CE")
# Backslashes of the LaTeX commands are escaped explicitly: "\p" is an invalid
# escape sequence in a regular string literal.
axes[0].set_xlabel("$\\pi_{max}$")
axes[0].set_ylabel("$v_{max}$ | $\\pi_{max}$    (mean$\\pm$std)")
axes[0].set_ylim(0, 15)
axes[0].set_xlim(0, 1)
axes[0].legend(title="ResNet-20\nCIFAR-10")

# Middle panel: same quantities for the label-smoothing model.
unc_range = np.array([[0, 1], [0, 15]])
x, std, y = window_average(msp_ls_correct, ls_correct_max_logits)
axes[1].plot(x, y, label="Test ✓", color=id_correct_color, alpha=0.8)
axes[1].fill_between(
    x,
    np.array(y) - np.array(std),
    np.array(y) + np.array(std),
    alpha=0.15,
    color="blue",
)
x, std, y = window_average(msp_ls_err, ls_err_max_logits)
axes[1].plot(x, y, label="Test ✗", color=id_incorrect_color, alpha=0.8)
axes[1].fill_between(x, np.array(y) - np.array(std), np.array(y) + np.array(std), alpha=0.15, color="red")


axes[1].set_ylim(0, 6)

axes[1].set_title("LS $\\alpha=0.2$")


axes[1].set_xlabel("$\\pi_{max}$")
axes[1].set_xlim(0, 1)


# Right panel: label-smoothing model again, with every curve shifted by the
# windowed mean of the correct predictions, which therefore sits at zero.
axes[2].set_ylabel("relative to mean of ✓")
x, std, y = window_average(msp_ls_correct, ls_correct_max_logits)
axes[2].plot(x, np.zeros(np.array(x).size), label="Test ✓", color=id_correct_color, alpha=0.8)
axes[2].fill_between(x, -np.array(std), np.array(std), alpha=0.15, color="blue")
x, std, y_ls = window_average(msp_ls_err, ls_err_max_logits)
# Mean max logit of the errors minus the one of the correct predictions
diff_y = np.array(y_ls) - np.array(y)
axes[2].plot(x, diff_y, label="Test ✗", color=id_incorrect_color, alpha=0.8)
axes[2].fill_between(
    x,
    np.array(diff_y) - np.array(std),
    np.array(diff_y) + np.array(std),
    alpha=0.15,
    color="red",
)


axes[2].set_title("LS $\\alpha=0.2$")


axes[2].set_xlabel("$\\pi_{max}$")
axes[2].set_xlim(0, 1)
axes[2].set_ylim(None, 0.3)
fig.tight_layout()