# CIFAR-100 DenseNet plots

Update the following variables with values corresponding to your own experiments:

In [None]:
# Checkpoints of the two DenseNet runs to compare: CE and LS
# (presumably cross-entropy and label-smoothing training).
MODEL_CE_PATH = (
    "../experiments/classification/cifar100/logs/densenet/version_0/checkpoints/last.ckpt"
)
MODEL_LS_PATH = (
    "../experiments/classification/cifar100/logs/densenet_ls/version_0/checkpoints/last.ckpt"
)

# Inference device, e.g. "cuda:0" or "cpu".
DEVICE = "cuda:0"

# When True, the figure is also written to disk at the end of the notebook.
SAVE_IMG = False

Prepare the datamodule and disable gradients:

In [None]:
import torch
from torch_uncertainty.datamodules import CIFAR100DataModule

# CIFAR-100 datamodule rooted at ./data, batches of 128 samples; only the
# "test" stage is set up since this notebook does evaluation only.
dm = CIFAR100DataModule("./data", batch_size=128)
dm.prepare_data()
dm.setup("test")

# Evaluation only: switch autograd off for the rest of the session so the
# forward passes below do not record a computation graph.
torch.set_grad_enabled(mode=False)

Instantiate the CE- and LS-trained models, load their checkpoint weights, and move them to the chosen device:

In [None]:
from torch_uncertainty_ls.densenet import DenseNetBC

# Prefix carried by the parameter names stored in the checkpoints
# (presumably added by the training wrapper module — confirm if keys change).
_CKPT_PREFIX = "model."


def load_densenet(ckpt_path: str, device: str) -> torch.nn.Module:
    """Build a 100-class DenseNetBC and load its weights from a checkpoint.

    Args:
        ckpt_path: path to a checkpoint file containing a ``"state_dict"`` entry.
        device: device the stored tensors are mapped to and the model is moved to.

    Returns:
        The model on ``device``, switched to eval mode.
    """
    # map_location lets a checkpoint saved from GPU load on a CPU-only machine
    # (DEVICE = "cpu"); without it torch.load tries to restore CUDA tensors.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    # Strip only the leading prefix; str.replace would also remove any inner
    # "model." occurring inside a parameter name.
    state_dict = {
        (key[len(_CKPT_PREFIX):] if key.startswith(_CKPT_PREFIX) else key): value
        for key, value in checkpoint["state_dict"].items()
    }
    net = DenseNetBC(num_classes=100)
    net.load_state_dict(state_dict)
    # nn.Module.to / .eval return the module itself, so they can be chained.
    return net.to(device).eval()


model = load_densenet(MODEL_CE_PATH, DEVICE)
model_ls = load_densenet(MODEL_LS_PATH, DEVICE)

Compute the softmax probabilities and per-sample correctness of the CE-based and LS-based models on the test set.

In [None]:
scores, scores_ls = [], []
correct_samples, correct_samples_ls = [], []

# Single pass over the test set: both models see exactly the same batches.
# NOTE(review): index 0 is assumed to be the in-distribution CIFAR-100 test
# loader returned by the datamodule — confirm.
for x, y in dm.test_dataloader()[0]:
    x = x.to(DEVICE)
    for net, score_chunks, correct_chunks in (
        (model, scores, correct_samples),
        (model_ls, scores_ls, correct_samples_ls),
    ):
        # Softmax probabilities, moved back to CPU to be compared with `y`.
        probs = net(x).softmax(dim=-1).cpu()
        score_chunks.append(probs)
        correct_chunks.append(probs.argmax(-1) == y)

# Stack the per-batch chunks into single tensors.
scores, scores_ls, correct_samples, correct_samples_ls = (
    torch.cat(chunks)
    for chunks in (scores, scores_ls, correct_samples, correct_samples_ls)
)

In [None]:
from torch_uncertainty_ls.utils import risk_coverage_curve

# Confidence of each prediction = its maximum softmax probability.
ce_risk, ce_cov, thresholds = risk_coverage_curve(
    correct_samples,
    scores.max(dim=1).values,
)
ls_risk, ls_cov, thresholds_ls = risk_coverage_curve(
    correct_samples_ls,
    scores_ls.max(dim=1).values,
)

Plot the risk–coverage curves of both models:

In [None]:
import matplotlib.pyplot as plt
import seaborn

# set the style
seaborn.set_theme()


def _plot_risk_coverage(ax, risk, cov, name, **line_kwargs):
    """Draw one risk-coverage curve in percent on ``ax``.

    The legend entry reads ``name (final risk %, AURC %)``. The final risk is
    ``risk[-1]`` and AURC is the trapezoidal integral of ``risk`` over ``cov``.
    Extra keyword arguments are forwarded to ``ax.plot``.
    """
    # NOTE(review): torch.trapz integrates in the order of `cov`; if
    # risk_coverage_curve returns coverage in decreasing order the AURC sign
    # flips — confirm.
    aurc = torch.trapz(risk, cov).item()
    ax.plot(
        cov * 100,
        risk * 100,
        label=f"{name} ({risk[-1] * 100:.1f}, {aurc * 100:.2f})",
        alpha=0.6,
        **line_kwargs,
    )


# Compute and show the risk-coverage curves
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
_plot_risk_coverage(ax, ce_risk, ce_cov, "CE", color="black")
_plot_risk_coverage(ax, ls_risk, ls_cov, "LS", linestyle="dotted")

ax.set_xlabel("%coverage")
# Raw string: "\l" is not a valid escape (SyntaxWarning on Python >= 3.12).
# The y-label is rotated 90° by default, so this arrow renders pointing down
# (presumably "lower is better").
ax.set_ylabel(r"%risk$\leftarrow$")
# Backslashes escaped explicitly; the "\n" must remain a real newline.
ax.legend(
    title="DenseNet (%error$\\downarrow$, %AURC$\\downarrow$)\nCIFAR-100",
    loc="upper left",
)
ax.grid(visible=True, which="both")
ax.set_xlim(0, 100)
ax.set_ylim(0, 25)
ax.minorticks_on()
fig.tight_layout()

if SAVE_IMG:
    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig("cifar100.pdf", dpi=300)

plt.show()