In [None]:
from datetime import datetime
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import Mean, Metric, MulticlassAccuracy
import torchvision
from tqdm import tqdm

In [None]:
# Download datasets

# CIFAR-10 train and test splits (downloaded into ./data on first run), each
# converting images to float tensors. The generator is consumed in order, so
# the train split is fetched before the test split.
train_dataset, test_dataset = (
    torchvision.datasets.CIFAR10(
        root="data",
        train=is_train_split,
        download=True,
        transform=torchvision.transforms.ToTensor(),
    )
    for is_train_split in (True, False)
)

In [None]:
def make_threshold_sample_flags(num_examples, num_classes, rng=None):
    """Randomly mark ``num_examples // (num_classes + 1)`` examples as threshold samples.

    With ``num_classes`` real classes plus one extra "fake" class, this makes the
    fake class about as large as each class would be in a balanced dataset.

    Args:
        num_examples (int): Number of flags to produce (one per dataset example).
        num_classes (int): Number of real classes in the dataset.
        rng (numpy.random.Generator, optional): Randomness source for the shuffle.
            Defaults to NumPy's global RNG (``np.random``), so seeding
            ``np.random`` still controls the result.

    Returns:
        numpy.ndarray: ``uint8`` array of shape ``(num_examples,)``; 1 marks a
        threshold sample and 0 a regular example.

    Raises:
        ValueError: If ``num_examples`` or ``num_classes`` is negative.
    """
    if num_examples < 0:
        raise ValueError(f"num_examples must be non-negative, got {num_examples}")
    if num_classes < 0:
        # num_classes == -1 would otherwise divide by zero below.
        raise ValueError(f"num_classes must be non-negative, got {num_classes}")

    threshold_sample_flags = np.zeros((num_examples,), dtype=np.uint8)
    num_threshold_samples = num_examples // (num_classes + 1)
    threshold_sample_flags[:num_threshold_samples] = 1
    # Shuffle in place so the flagged examples are spread randomly over the dataset.
    (np.random if rng is None else rng).shuffle(threshold_sample_flags)
    return threshold_sample_flags


class ThresholdSamplesDataset(Dataset):
    """A Dataset wrapper that relabels flagged examples as an extra "fake_label" class.

    Flagged ("threshold") examples keep their inputs but are returned with label
    ``len(dataset.classes)``, the index of the appended ``"fake_label"`` class.
    All other examples pass through unchanged.

    Args:
        dataset: Indexable dataset returning ``(x, y)`` pairs and exposing a
            ``classes`` sequence of class names.
        threshold_sample_flags: Sequence with one entry per example of
            ``dataset``; a truthy entry marks that example as a threshold sample.

    Raises:
        ValueError: If ``dataset`` has no ``classes`` attribute, or if the number
            of flags differs from ``len(dataset)``.
    """
    def __init__(self, dataset, threshold_sample_flags):
        if not hasattr(dataset, "classes"):
            raise ValueError("dataset must have 'classes' attribute.")
        if len(threshold_sample_flags) != len(dataset):
            raise ValueError(
                f"threshold_sample_flags has {len(threshold_sample_flags)} entries "
                f"but dataset has {len(dataset)} examples."
            )

        self.dataset = dataset
        self.threshold_sample_flags = threshold_sample_flags
        # list() so datasets that expose `classes` as a tuple also work.
        self.classes = list(self.dataset.classes) + ["fake_label"]

    def __getitem__(self, index):
        """Return ``(x, y)``, with ``y`` replaced by the fake class index for threshold samples."""
        x, y = self.dataset[index]
        if self.threshold_sample_flags[index]:
            return x, len(self.dataset.classes)
        return x, y

    def __len__(self):
        """Number of examples, identical to the wrapped dataset's."""
        return len(self.dataset)

In [None]:
# Make threshold samples dataset

# Seed the RNGs this notebook draws from before any randomness is used:
# NumPy (threshold-sample shuffle here, example picks in later plots), Python's
# `random` (figure samples in test()), and torch (model weight initialization in
# later cells). This makes the threshold samples, and therefore the AUM
# threshold, reproducible across Restart & Run All. GPU kernels may still
# introduce small nondeterminism in training.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

threshold_sample_flags = make_threshold_sample_flags(num_examples=len(train_dataset), num_classes=len(train_dataset.classes))
threshold_dataset = ThresholdSamplesDataset(train_dataset, threshold_sample_flags)

In [None]:
print(f"Number of classes: {len(train_dataset.classes)}")
print(f"Number of samples: {len(threshold_dataset)}")
# Count from the flags instead of iterating threshold_dataset. Iterating would load
# and convert every training image just to read its label. ThresholdSamplesDataset
# relabels exactly the flagged examples as the extra class (index
# len(train_dataset.classes)), so both counts agree.
print(f"Number of threshold samples: {int(np.count_nonzero(threshold_sample_flags))}")

In [None]:
# Make data loaders

batch_size = 64

# shuffle=False is DataLoader's default and is spelled out on purpose. When
# accumulating AUM margins, train() derives each example's dataset index from
# its position within the epoch, so threshold_loader must yield examples in order.
threshold_loader = torch.utils.data.DataLoader(threshold_dataset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Peek at one test batch to confirm tensor shapes and dtypes.
x, y = next(iter(test_loader))
print(f"x: {x.shape}, {x.dtype}", f"y: {y.shape}, {y.dtype}", sep="\n")

In [None]:
# Set device

# Prefer CUDA, then Apple's Metal backend (MPS), and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Build model

# One output per entry in threshold_dataset.classes: the real classes plus "fake_label".
model = torchvision.models.resnet34(num_classes=len(threshold_dataset.classes))
model.to(device)  # nn.Module.to moves parameters/buffers in place and returns the module itself
print(model)

In [None]:
# Set loss function and optimizer

# Plain SGD (no momentum) over all model parameters.
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)
# Batch-mean cross-entropy on the raw logits (the default reduction).
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Make TensorBoard writer

# Each run logs to its own timestamped directory under runs/find_mislabeled_data/.
now = datetime.now().strftime("%b%d_%H-%M-%S")
log_dir = os.path.join("runs", "find_mislabeled_data", now)
writer = SummaryWriter(log_dir)

In [None]:
# Define metrics

# Both metrics keep their running state on `device`. train() and test() reset
# them after logging each epoch.
accuracy_metric = MulticlassAccuracy(device=device)
loss_metric = Mean(device=device)

In [None]:
# Helper functions for visualizing predictions

# Converts an image tensor into a PIL image for matplotlib's imshow
# (mode=None lets torchvision infer the mode, which is its default).
to_pil_image = torchvision.transforms.ToPILImage(mode=None)


def predict_with_probs(model, x):
    """
    Returns the top-1 class index and its softmax probability for each input.

    Args:
        model: Callable mapping a batch ``x`` to logits of shape
            (batch_size, num_classes).
        x (Tensor): Batch of inputs for ``model``.

    Returns:
        tuple[Tensor, Tensor]: ``(preds, probs)``. Each has shape (batch_size,):
        the argmax class index and the probability assigned to it.
    """
    probabilities = F.softmax(model(x), dim=1)
    best = probabilities.max(dim=1)
    return best.indices, best.values


def plot_classes_preds(model, x, y, classes, num_images=4):
    """
    Generates a matplotlib Figure showing the model's top prediction for the
    first few images of a batch.

    Each panel shows one image titled with the predicted class, its softmax
    probability, and the actual label. The title is green when the prediction
    matches the label and red otherwise. Predictions come from
    ``predict_with_probs``.

    Args:
        model: Classifier mapping a batch of images to logits. It is run in eval
            mode and then restored to its previous train/eval mode.
        x (Tensor): Batch of images. They are plotted as given and moved to the
            module-level ``device`` for inference.
        y (Tensor or sequence of int): Ground truth labels for ``x``.
        classes (sequence of str): Class names indexed by label/prediction.
        num_images (int): Maximum number of leading images to plot (default 4).

    Returns:
        matplotlib.figure.Figure: The figure. The caller displays or closes it.

    Raises:
        ValueError: If there is no image to plot.
    """
    num_shown = min(num_images, len(x))
    if num_shown < 1:
        raise ValueError("plot_classes_preds needs at least one image to plot.")
    x = x[:num_shown]

    # Eval mode keeps layers such as BatchNorm from updating running statistics
    # during this visualization pass; restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            preds, probs = predict_with_probs(model, x.to(device))
    finally:
        model.train(was_training)
    preds = preds.cpu().numpy()
    probs = probs.cpu().numpy()

    # squeeze=False keeps `axs` 2-D even when only one image is plotted.
    fig, axs = plt.subplots(1, num_shown, figsize=(2.5 * num_shown, 3), squeeze=False)
    for ax, image, pred, prob, label in zip(axs[0], x, preds, probs, y):
        label = int(label)
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(to_pil_image(image))
        ax.set_title(
            "{0}, {1:.1%}\n(actual: {2})".format(
                classes[pred],
                prob,
                classes[label]),
            color=("green" if pred == label else "red")
        )
    return fig

In [None]:
# Define AUM (area under margin)

class AUM:
    """Accumulates per-example margins for Area Under the Margin (AUM) ranking.

    A margin is the logit of an example's assigned label minus the largest logit
    among all other classes. :meth:`update` adds each batch's margins to a running
    total indexed by dataset position. :meth:`compute` divides the totals by the
    number of epochs to give each example's average margin. A negative AUM means
    that, on average, some other class had a higher logit than the assigned label.
    """

    def __init__(self, num_examples, device=None):
        """
        Args:
            num_examples (int): Number of examples tracked (one running total each).
            device (optional): Device for the running totals. Incoming logits and
                labels are moved here in :meth:`update`.
        """
        self.num_examples = num_examples
        self.device = device
        self.reset()

    @torch.inference_mode()
    def update(self, logits, y, start):
        """
        Adds one batch's margins to the running totals.

        Args:
            logits (Tensor): Prediction logits of shape (batch_size, num_classes).
            y (Tensor): Integer (int64) ground truth labels of shape (batch_size,).
            start (int): Dataset index of the first example in the batch. The
                batch covers indices ``start`` .. ``start + batch_size - 1``.

        Returns:
            AUM: ``self``, for chaining.

        Raises:
            IndexError: If the batch would fall outside ``[0, num_examples)``.
        """
        logits = logits.to(self.device)
        y = y.to(self.device)

        batch_size = y.shape[0]
        stop = start + batch_size
        # Validate before touching state so a bad offset cannot partially update totals.
        if start < 0 or stop > self.num_examples:
            raise IndexError(
                f"batch [{start}, {stop}) is outside the tracked range [0, {self.num_examples})"
            )

        label_index = y.unsqueeze(1)

        # Logit of each example's assigned label. gather() keeps all indexing on
        # the logits' device (a bare torch.arange would be created on the CPU).
        assigned_logits = logits.gather(1, label_index).squeeze(1)

        # Largest logit among the other classes: hide the assigned label with -inf.
        masked_logits = torch.scatter(logits, dim=1, index=label_index, value=-torch.inf)
        largest_other_logits, _ = torch.max(masked_logits, dim=1)

        margins = assigned_logits - largest_other_logits
        self.margin_totals[start:stop] += margins

        return self

    @torch.inference_mode()
    def compute(self, epochs):
        """
        Returns the AUM values: the accumulated margins averaged over epochs.

        Args:
            epochs (int): The number of training epochs that have occurred.

        Returns:
            Tensor: Shape (num_examples,), on ``self.device``.

        Raises:
            ValueError: If ``epochs`` is not positive. Dividing by it would give
                inf/nan instead of an error.
        """
        if epochs <= 0:
            raise ValueError(f"epochs must be positive, got {epochs}")
        return self.margin_totals / epochs

    @torch.inference_mode()
    def reset(self):
        """
        Resets the state: one zeroed running total (default float dtype) per
        example, on ``self.device``.
        """
        self.margin_totals = torch.zeros((self.num_examples,), device=self.device)


# One running margin total per example of threshold_dataset (threshold samples included).
aum = AUM(device=device, num_examples=len(threshold_dataset))

In [None]:
def train(loader, model, loss_fn, optimizer, epoch, writer, loss_metric, accuracy_metric, aum):
    """Train ``model`` for one epoch over ``loader`` and accumulate AUM margins.

    After the epoch, the mean training loss and accuracy are written to
    TensorBoard under ``loss/train`` and ``accuracy/train``, and both metrics
    are reset.

    Args:
        loader (DataLoader): Training data. It must iterate the dataset in order
            (no shuffling). AUM margins are stored by dataset index, and that
            index is derived from each example's position within the epoch.
        model: Network being trained. Batches are moved to the module-level ``device``.
        loss_fn: Loss returning the batch-mean loss (e.g. ``CrossEntropyLoss``
            with the default ``reduction="mean"``).
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch (int): Epoch number used as the TensorBoard step.
        writer (SummaryWriter): TensorBoard writer.
        loss_metric: torcheval ``Mean`` tracking the per-example training loss.
        accuracy_metric: torcheval ``MulticlassAccuracy`` for training accuracy.
        aum (AUM): Accumulator receiving each batch's logits and labels.

    Raises:
        ValueError: If ``loader`` does not iterate with a ``SequentialSampler``.
    """
    # A shuffled loader would silently attribute margins to the wrong examples.
    if not isinstance(getattr(loader, "sampler", None), torch.utils.data.SequentialSampler):
        raise ValueError(
            "train() needs a non-shuffled DataLoader (SequentialSampler) so AUM "
            "margins can be mapped back to dataset indices."
        )

    model.train()
    start = 0  # dataset index of the first example in the current batch
    with tqdm(loader) as progress:
        for x, y in progress:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            batch_size = y.shape[0]
            # Weight by batch size so a short final batch doesn't skew the epoch mean.
            loss_metric.update(loss.detach(), weight=batch_size)
            accuracy_metric.update(pred, y)

            progress.set_postfix_str(
                f"loss={loss_metric.compute().item():.4f}, accuracy={accuracy_metric.compute().item():.2%}",
                refresh=False
            )

            # Margins use this batch's forward-pass logits, computed before the optimizer step.
            aum.update(pred, y, start=start)
            # Advance by the actual batch size, which also covers a short final batch.
            start += batch_size

    writer.add_scalar("loss/train", scalar_value=loss_metric.compute(), global_step=epoch)
    loss_metric.reset()
    writer.add_scalar("accuracy/train", scalar_value=accuracy_metric.compute(), global_step=epoch)
    accuracy_metric.reset()

In [None]:
def test(loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric, classes):
    """Evaluate ``model`` on ``loader``, log the results, and log a sample-prediction figure.

    Prints the test loss and accuracy and writes them to TensorBoard under
    ``loss/test`` and ``accuracy/test``, then resets both metrics. It also writes
    a figure of four randomly chosen examples from ``loader.dataset``, with their
    predictions, under ``predictions``.

    Args:
        loader (DataLoader): Evaluation data.
        model: Network to evaluate. It is switched to eval mode and left there.
        loss_fn: Loss returning the batch-mean loss (e.g. ``CrossEntropyLoss``
            with the default ``reduction="mean"``).
        epoch (int): Epoch number used as the TensorBoard step.
        writer (SummaryWriter): TensorBoard writer.
        loss_metric: torcheval ``Mean`` tracking the per-example test loss.
        accuracy_metric: torcheval ``MulticlassAccuracy`` for test accuracy.
        classes (sequence of str): Class names for the figure. It must cover
            every class index the model can predict.
    """
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            # Weight by batch size so the short final batch doesn't skew the mean.
            loss_metric.update(loss, weight=y.shape[0])
            accuracy_metric.update(pred, y)

    test_loss = loss_metric.compute()
    test_accuracy = accuracy_metric.compute()
    print(f"test_loss={test_loss.item():.4f}, test_accuracy={test_accuracy.item():.2%}")

    writer.add_scalar("loss/test", scalar_value=test_loss, global_step=epoch)
    writer.add_scalar("accuracy/test", scalar_value=test_accuracy, global_step=epoch)
    loss_metric.reset()
    accuracy_metric.reset()

    # Pick 4 distinct examples to visualize. Fall back to sampling with
    # replacement only if the dataset is too small to supply 4 distinct ones.
    dataset = loader.dataset
    num_shown = 4
    population = range(len(dataset))
    if len(dataset) >= num_shown:
        indices = random.sample(population, k=num_shown)
    else:
        indices = random.choices(population, k=num_shown)
    x, y = zip(*(dataset[i] for i in indices))
    x = torch.stack(x)
    y = torch.tensor(y)
    writer.add_figure("predictions", plot_classes_preds(model, x, y, classes), global_step=epoch)

In [None]:
# Train the model

epochs = 15
# Start from zero margins so aum.compute(epochs) averages over exactly this
# run's epochs, even if this cell is re-executed in the same kernel. The model
# itself is not re-initialized, so a re-run continues from the current weights.
aum.reset()
for epoch in range(epochs):
    print("-------------------------------")
    print(f"Epoch {epoch}")
    train(threshold_loader, model, loss_fn, optimizer, epoch, writer, loss_metric, accuracy_metric, aum)
    test(test_loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric, classes=threshold_loader.dataset.classes)
print("-------------------------------")
print("Done!")

In [None]:
# Average margin per example over all training epochs, as a NumPy array on the CPU.
aum_values = aum.compute(epochs).cpu().numpy()

print(f"AUM values: {aum_values.shape}, {aum_values.dtype}")
print(
    f"mean: {aum_values.mean():.4f}, min: {aum_values.min():.4f}, "
    f"max: {aum_values.max():.4f}, std: {aum_values.std():.4f}"
)

In [None]:
# Compute AUM threshold

# Threshold samples were deliberately relabeled to the extra "fake_label" class,
# so their AUM values show what a mislabeled example looks like. The next cell
# flags regular examples whose AUM is at or below this percentile of the
# threshold samples' AUM.
threshold_sample_aum_values = aum_values[threshold_sample_flags == 1]
# np.percentile takes q in [0, 100]. The previous value 0.99 selected the 0.99th
# percentile (nearly the minimum) rather than the 99th percentile used by AUM.
threshold_sample_aum_percentile = 99
aum_threshold = np.percentile(threshold_sample_aum_values, threshold_sample_aum_percentile)

print(f"AUM threshold: {aum_threshold}")

In [None]:
# Flag regular (non-threshold) examples whose AUM is at or below the threshold,
# stored with the same dtype as threshold_sample_flags.
mislabeled_example_flags = (
    (threshold_sample_flags == 0) & (aum_values <= aum_threshold)
).astype(threshold_sample_flags.dtype)
print(f"Mislabeled example flags: {mislabeled_example_flags.shape}, {mislabeled_example_flags.dtype}")
print(f"Potentially mislabeled examples: {mislabeled_example_flags.sum()}")

In [None]:
def plot_mislabeled_examples(model, dataset, mislabeled_example_flags, classes, num_examples=6):
    """Plot a random selection of examples flagged as potentially mislabeled.

    Each panel shows the image, titled with its dataset label, its AUM value,
    and the model's current prediction.

    Args:
        model: Trained classifier. It is run in eval mode on the module-level
            ``device`` and then restored to its previous train/eval mode.
        dataset: Indexable dataset returning ``(image_tensor, label)``.
        mislabeled_example_flags: Array with a nonzero entry for each flagged example.
        classes (sequence of str): Class names indexed by label/prediction.
        num_examples (int): Maximum number of examples to show (default 6).

    Note:
        The AUM shown in each title is read from the module-level ``aum_values`` array.
    """
    (mislabeled_example_indexes,) = np.nonzero(mislabeled_example_flags)
    if len(mislabeled_example_indexes) == 0:
        print("No potentially mislabeled examples to plot.")
        return
    num_shown = min(num_examples, len(mislabeled_example_indexes))
    # Without replacement, so no example appears twice.
    indexes = np.random.choice(mislabeled_example_indexes, num_shown, replace=False)

    # squeeze=False keeps `axs` 2-D even when only one example is plotted.
    fig, axs = plt.subplots(1, num_shown, figsize=(2.5 * num_shown, 3.8), squeeze=False)
    fig.suptitle("Potentially mislabeled examples")

    # Eval mode keeps layers such as BatchNorm from updating running statistics
    # while we visualize; restore whatever mode the caller had.
    was_training = model.training
    model.eval()
    try:
        for ax, index in zip(axs[0], indexes):
            x, y = dataset[index]
            with torch.no_grad():
                pred = model(x.to(device)[None])[0].argmax(0).item()
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.imshow(to_pil_image(x))
            ax.set_title(f"{classes[y]}\nAUM={aum_values[index]:.4f}\n(predicted: {classes[pred]})")
    finally:
        model.train(was_training)


plot_mislabeled_examples(model, threshold_dataset, mislabeled_example_flags, threshold_dataset.classes)