In [None]:
from datetime import datetime
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import Mean, MulticlassAccuracy
import torchvision
from tqdm import tqdm

from aum_ranking import assign_threshold_samples, ThresholdSamplesDataset, AUM, compute_aum_threshold, flag_mislabeled_examples, combine_mislabeled_examples
from models import ResNet32

In [None]:
# Define how to corrupt the training dataset labels

def assign_corrupted_samples(dataset, fraction):
    """
    Pick a random ``fraction`` of ``dataset`` and give each picked example a wrong label.

    Args:
        dataset: Map-style dataset whose items are ``(x, y)`` pairs with integer
            labels and which exposes a ``classes`` sequence. When it also exposes
            a ``targets`` sequence of matching length and no ``target_transform``
            (as torchvision's CIFAR10 does), labels are read from ``targets``;
            otherwise every item is loaded, which runs the dataset's transforms.
        fraction: Fraction of examples to corrupt; must lie in ``[0, 1]``.

    Returns:
        ``(corrupted_labels, corrupted_flags)``. ``corrupted_flags`` is a
        ``uint8`` array with ``round(fraction * len(dataset))`` ones at random
        positions. At those positions ``corrupted_labels`` holds a class chosen
        uniformly among the *other* classes; elsewhere it holds the original label.

    Raises:
        ValueError: if ``fraction`` is outside ``[0, 1]``, or if some labels
            must be corrupted but the dataset has fewer than two classes.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError(f"fraction must be in [0, 1], got {fraction!r}")

    num_examples = len(dataset)
    num_classes = len(dataset.classes)

    corrupted_flags = np.zeros((num_examples,), dtype=np.uint8)
    num_corrupted_samples = round(fraction * num_examples)
    corrupted_flags[:num_corrupted_samples] = 1
    np.random.shuffle(corrupted_flags)

    real_labels = _read_labels(dataset)

    if num_classes < 2:
        # There is no "other" class to switch to (and randint(1, 1) would raise).
        if num_corrupted_samples > 0:
            raise ValueError("cannot corrupt labels of a dataset with fewer than two classes")
        return real_labels, corrupted_flags

    # For labels in [0, num_classes), a non-zero offset modulo num_classes always
    # lands on a different class, uniformly over the remaining num_classes - 1.
    noise = np.random.randint(1, num_classes, size=num_examples)
    fake_labels = (real_labels + noise) % num_classes
    corrupted_labels = np.where(corrupted_flags, fake_labels, real_labels)
    return corrupted_labels, corrupted_flags


def _read_labels(dataset):
    """Return the label of every example in ``dataset`` as a NumPy array."""
    targets = getattr(dataset, "targets", None)
    if (
        targets is not None
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    ):
        # Fast path: avoids loading (and augmenting) every image just to read its label.
        return np.asarray(targets)
    return np.array([dataset[i][1] for i in range(len(dataset))])


class CorruptedDataset(Dataset):
    """
    Wrap a dataset and serve caller-supplied labels in place of its own.

    Each item is ``(input, labels[index])``. The input comes unchanged from the
    wrapped dataset, whose own label is discarded. ``classes`` is copied from
    the wrapped dataset, and the length matches it.
    """

    def __init__(self, dataset, labels):
        self.dataset = dataset
        self.labels = labels
        self.classes = dataset.classes

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Unpack to keep the input and drop the wrapped dataset's label.
        sample, _original_label = self.dataset[index]
        replacement_label = self.labels[index]
        return sample, replacement_label


def corrupt_dataset(dataset, fraction):
    """
    Mislabel ``fraction`` of ``dataset`` and wrap it so the new labels are served.

    Returns:
        ``(wrapped_dataset, corrupted_flags)``. ``wrapped_dataset`` is a
        ``CorruptedDataset``. ``corrupted_flags`` marks which examples received
        a new label, as returned by ``assign_corrupted_samples``.
    """
    new_labels, flags = assign_corrupted_samples(dataset, fraction)
    wrapped = CorruptedDataset(dataset, new_labels)
    return wrapped, flags

In [None]:
# Define datasets

# Seed Python, NumPy and torch RNGs up front. This makes the label corruption
# below (NumPy), the plotted test samples (`random`) and torch-side randomness
# such as augmentation repeat across Restart & Run All. GPU kernels may still
# introduce some nondeterminism.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fraction of training labels replaced with a wrong class.
CORRUPTION_FRACTION = 0.2

pixel_means = torch.tensor([0.4914, 0.4822, 0.4465])
pixel_stds = torch.tensor([0.2470, 0.2435, 0.2616])

preprocess = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=pixel_means, std=pixel_stds)
])

# Inverse of `preprocess` (x * std + mean, then back to a PIL image);
# used later for visualizing images.
unprocess = torchvision.transforms.Compose([
    torchvision.transforms.Normalize(mean=-pixel_means / pixel_stds, std=1.0 / pixel_stds),
    torchvision.transforms.ToPILImage()
])

# Augmentation runs on the normalised tensor, so Pad(4)'s zero border equals
# the per-channel mean colour rather than black.
train_dataset = torchvision.datasets.CIFAR10(root="data", train=True, download=True, transform=torchvision.transforms.Compose([
    preprocess,
    torchvision.transforms.Pad(4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomCrop(32)
]))

train_dataset, corrupted_flags = corrupt_dataset(train_dataset, CORRUPTION_FRACTION)

test_dataset = torchvision.datasets.CIFAR10(root="data", train=False, download=True, transform=preprocess)

In [None]:
# Assign two sets of threshold samples (one for each AUM pass)
threshold_sample_flags_1, threshold_sample_flags_2 = assign_threshold_samples(
    num_examples=len(train_dataset),
    num_classes=len(train_dataset.classes),
)

In [None]:
# Summarise dataset size and how many examples each pass reserves as threshold samples.
print("\n".join([
    f"Number of classes: {len(train_dataset.classes)}",
    f"Number of samples: {len(train_dataset)}",
    f"Number of threshold samples (first pass): {threshold_sample_flags_1.sum()}",
    f"Number of threshold samples (second pass): {threshold_sample_flags_2.sum()}",
]))

In [None]:
# Set device: prefer CUDA, then Apple's MPS backend, and fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Helper functions for visualizing predictions

def predict_with_probs(model, x):
    """
    Return the top-1 class and its softmax probability for each image in a batch.

    Args:
        model: Network mapping a batch of images to per-class logits; the
            softmax and max are taken over dim 1, so logits are assumed to be
            ``(batch, num_classes)``.
        x: Batch of images as a tensor; it is moved to the module-level ``device``.

    Returns:
        ``(preds, probs)`` as NumPy arrays: the argmax class index per image
        and the softmax probability of that class.

    Gradients are not tracked, so this is safe to call outside ``torch.no_grad()``.
    The model's train/eval mode is left untouched; callers set it.
    """
    with torch.no_grad():
        logits = model(x.to(device))
        probabilities = F.softmax(logits, dim=1)
        probs, preds = torch.max(probabilities, 1)
    return preds.cpu().numpy(), probs.cpu().numpy()


def plot_classes_preds(model, x, y, classes, num_images=4):
    """
    Plot the first few images of a batch with the model's top prediction.

    Each panel is titled with the predicted class and its probability, plus the
    label from ``y``. The title is green when the prediction matches the label
    and red otherwise. Predictions come from ``predict_with_probs``.

    Args:
        model: Trained network (the caller is responsible for eval mode).
        x: Batch of normalised images (tensor).
        y: Labels for ``x`` (class indices; a tensor or a sequence of ints).
        classes: Sequence mapping a class index to its name.
        num_images: Maximum number of panels; fewer are drawn for a smaller batch.

    Returns:
        The matplotlib Figure.
    """
    with torch.no_grad():
        preds, probs = predict_with_probs(model, x)

    count = min(num_images, len(x))
    # squeeze=False keeps `axs` 2-D even when only one panel is drawn.
    fig, axs = plt.subplots(1, count, figsize=(10, 3), squeeze=False)
    for i, ax in enumerate(axs[0]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        # `unprocess` ends in ToPILImage, which needs a CPU tensor.
        ax.imshow(unprocess(x[i].cpu()))
        label = int(y[i])
        ax.set_title(
            f"{classes[preds[i]]}, {probs[i]:.1%}\n(actual: {classes[label]})",
            color=("green" if preds[i] == label else "red"),
        )
    return fig

In [None]:
def train(loader, progress, model, loss_fn, optimizer, epoch, epochs, writer, loss_metric, accuracy_metric, aum):
    """
    Run one training epoch, updating the AUM tracker and logging epoch metrics.

    Args:
        loader: Yields ``(x, y, indexes)`` batches; ``indexes`` are passed to ``aum``.
        progress: tqdm bar reused across epochs; reset at the start of each call.
        model, loss_fn, optimizer: Network, criterion and optimizer to train.
        epoch: Zero-based epoch index, shown on the bar and used as TensorBoard step.
        epochs: Total number of epochs (only shown on the bar).
        writer: TensorBoard writer receiving ``loss/train`` and ``accuracy/train``.
        loss_metric, accuracy_metric: torcheval metrics; reset after being logged.
        aum: AUM tracker updated with every batch's logits, labels and indexes.
    """
    progress.reset()
    progress.desc = f"Epoch {epoch+1}/{epochs}"

    model.train()
    for x, y, indexes in loader:
        x = x.to(device)
        y = y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Everything below only reads the outputs. Dropping the autograd history
        # stops the metric/AUM state from keeping this batch's graph reachable.
        pred = pred.detach()
        # Weight by batch size (assuming loss_fn averages over the batch, as
        # CrossEntropyLoss does by default) so a short final batch is not over-weighted.
        loss_metric.update(loss.detach(), weight=len(y))
        accuracy_metric.update(pred, y)

        progress.set_postfix_str(
            f"loss={loss_metric.compute().item():.4f}, accuracy={accuracy_metric.compute().item():.2%}",
            refresh=False
        )

        aum.update(pred, y, indexes)

        progress.update()

    writer.add_scalar("loss/train", scalar_value=loss_metric.compute(), global_step=epoch)
    loss_metric.reset()
    writer.add_scalar("accuracy/train", scalar_value=accuracy_metric.compute(), global_step=epoch)
    accuracy_metric.reset()

    progress.refresh()

In [None]:
def test(loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric, classes):
    """
    Evaluate the model on ``loader`` and log test metrics plus a sample figure.

    Args:
        loader: Yields ``(x, y)`` batches; ``loader.dataset`` is also sampled for the figure.
        model, loss_fn: Network and criterion to evaluate.
        epoch: Epoch index used as the TensorBoard step.
        writer: TensorBoard writer receiving ``loss/test``, ``accuracy/test`` and ``predictions``.
        loss_metric, accuracy_metric: torcheval metrics; reset after being logged.
        classes: Class names used to title the prediction figure.
    """
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            # Weight by batch size (assuming loss_fn averages over the batch) so
            # the logged value is a per-example mean despite a smaller last batch.
            loss_metric.update(loss, weight=len(y))
            accuracy_metric.update(pred, y)

    writer.add_scalar("loss/test", scalar_value=loss_metric.compute(), global_step=epoch)
    loss_metric.reset()
    writer.add_scalar("accuracy/test", scalar_value=accuracy_metric.compute(), global_step=epoch)
    accuracy_metric.reset()

    # Plot up to four distinct, randomly chosen examples with their predictions.
    dataset = loader.dataset
    sample_indexes = random.sample(range(len(dataset)), k=min(4, len(dataset)))
    x, y = zip(*(dataset[i] for i in sample_indexes))
    x = torch.stack(x)
    y = torch.tensor(y)
    writer.add_figure("predictions", plot_classes_preds(model, x, y, classes), global_step=epoch)

In [None]:
def predict_on_dataset(model, loader):
    """
    Predict a class and its probability for every example served by ``loader``.

    ``loader`` must yield ``(x, y, indexes)``. Results are written at each
    batch's ``indexes``, so the output is in dataset order even when the loader
    shuffles. Images pass through the dataset's own transforms; in this notebook
    the training transforms include random padding/crop/flip.

    Returns:
        ``(preds, probs)``: an ``int32`` array of predicted classes and a
        ``float32`` array of their softmax probabilities, each of length
        ``len(loader.dataset)``. Entries whose index never appears in a batch
        are left uninitialised.
    """
    num_examples = len(loader.dataset)
    preds = np.empty(num_examples, dtype=np.int32)
    # Probabilities lie in [0, 1]; an integer array would truncate them to 0.
    probs = np.empty(num_examples, dtype=np.float32)

    model.eval()
    with torch.no_grad():
        for x, y, indexes in loader:
            batch_preds, batch_probs = predict_with_probs(model, x)
            batch_indexes = np.asarray(indexes)
            preds[batch_indexes] = batch_preds
            probs[batch_indexes] = batch_probs

    return preds, probs

In [None]:
def training_run(pass_index, threshold_dataset, test_dataset, epochs, aum, batch_size=64):
    """
    Train a fresh ResNet32 on ``threshold_dataset``, then predict on it.

    Args:
        pass_index: Pass number, included in the TensorBoard log directory name.
        threshold_dataset: Training dataset yielding ``(x, y, index)`` and exposing ``classes``.
        test_dataset: Dataset yielding ``(x, y)``, evaluated after every epoch.
        epochs: Number of training epochs.
        aum: AUM tracker updated by ``train`` on every batch.
        batch_size: Mini-batch size for both loaders (default 64).

    Returns:
        ``(model, suggested_labels, suggested_label_probs)``; the last two come
        from ``predict_on_dataset`` over the (shuffled) training loader.
    """
    # Make data loaders
    threshold_loader = DataLoader(threshold_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Build model
    model = ResNet32(num_classes=len(threshold_dataset.classes)).to(device)

    # Set loss function and optimizer
    # NOTE(review): the learning rate stays at 0.1 for every epoch (no scheduler).
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9, nesterov=True)

    # Make TensorBoard writer
    now = datetime.now().strftime("%b%d_%H-%M-%S")
    log_dir = os.path.join("runs", "find_mislabeled_data", f"{now}_pass_{pass_index}")
    writer = SummaryWriter(log_dir=log_dir)

    # Define metrics
    loss_metric = Mean(device=device)
    accuracy_metric = MulticlassAccuracy(device=device)

    # Train the model
    print("Training model")
    print("-------------------------------")
    try:
        with tqdm(total=len(threshold_loader)) as train_progress:
            for epoch in range(epochs):
                train(threshold_loader, train_progress, model, loss_fn, optimizer, epoch, epochs, writer, loss_metric, accuracy_metric, aum)
                test(test_loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric, classes=threshold_loader.dataset.classes)
    finally:
        # Flush buffered events and release the event file even if training is interrupted.
        writer.close()
    suggested_labels, suggested_label_probs = predict_on_dataset(model, threshold_loader)
    print("-------------------------------")
    print("Done!")

    return model, suggested_labels, suggested_label_probs

In [None]:
def find_mislabeled_examples_pass(pass_index, threshold_sample_flags, train_dataset, test_dataset, epochs=150):
    """
    Run one AUM pass: train on a threshold-sample dataset and flag suspicious examples.

    Args:
        pass_index: Pass number, used in log output and the TensorBoard run name.
        threshold_sample_flags: Flags selecting this pass's threshold samples.
        train_dataset: Training data (possibly with noisy labels).
        test_dataset: Held-out data evaluated after every epoch.
        epochs: Number of training epochs; also passed to ``aum.compute`` (default 150).

    Returns:
        ``(mislabeled_example_flags, suggested_labels, suggested_label_probs)``.
        The flags come from ``flag_mislabeled_examples``; the suggestions are
        the trained model's top-1 predictions and their probabilities.
    """
    print(f"Performing pass {pass_index}")

    # Make threshold samples dataset
    threshold_dataset = ThresholdSamplesDataset(train_dataset, threshold_sample_flags)

    # Create AUM calculator
    aum = AUM(num_examples=len(threshold_dataset))

    # Train a model to populate the margin values (the model itself is not needed here)
    _, suggested_labels, suggested_label_probs = training_run(pass_index, threshold_dataset, test_dataset, epochs, aum)

    # Compute AUM values. `.cpu()` is a no-op for CPU tensors and keeps the
    # NumPy conversion working should AUM hold its state on an accelerator.
    aum_values = aum.compute(epochs).cpu().numpy()
    print(f"AUM values: {aum_values.shape}, {aum_values.dtype}")
    print(f"mean: {np.mean(aum_values):.4f}, min: {np.min(aum_values):.4f}, max: {np.max(aum_values):.4f}, std: {np.std(aum_values):.4f}")

    # Compute AUM threshold
    aum_threshold = compute_aum_threshold(aum_values, threshold_sample_flags)
    print(f"AUM threshold: {aum_threshold}")

    # Flag (potentially) mislabeled examples
    mislabeled_example_flags = flag_mislabeled_examples(aum_values, threshold_sample_flags, aum_threshold)
    print(f"Potentially mislabeled examples: {np.sum(mislabeled_example_flags)}")
    print(f"Finished pass {pass_index}")
    print("===============================")
    return mislabeled_example_flags, suggested_labels, suggested_label_probs


def combine_suggested_labels(suggested_labels_1, suggested_label_probs_1, suggested_labels_2, suggested_label_probs_2):
    """
    Merge two passes' label suggestions element-wise, keeping the more confident one.

    Where the pass-1 probability is greater than or equal to the pass-2
    probability, the pass-1 label is kept (so ties favour pass 1); otherwise
    the pass-2 label is used.
    """
    keep_first_pass = suggested_label_probs_1 >= suggested_label_probs_2
    return np.where(keep_first_pass, suggested_labels_1, suggested_labels_2)


def find_mislabeled_examples(train_dataset, test_dataset, threshold_sample_flags_1, threshold_sample_flags_2):
    """
    Run both AUM passes and merge their results.

    Pass 1 uses ``threshold_sample_flags_1`` and pass 2 uses
    ``threshold_sample_flags_2``. The per-pass flags are merged with
    ``combine_mislabeled_examples`` and the suggestions with
    ``combine_suggested_labels``.

    Returns:
        ``(mislabeled_example_flags, suggested_labels)``.
    """
    first_pass = find_mislabeled_examples_pass(1, threshold_sample_flags_1, train_dataset, test_dataset)
    second_pass = find_mislabeled_examples_pass(2, threshold_sample_flags_2, train_dataset, test_dataset)

    flags_1, labels_1, probs_1 = first_pass
    flags_2, labels_2, probs_2 = second_pass

    merged_flags = combine_mislabeled_examples(flags_1, flags_2)
    merged_labels = combine_suggested_labels(labels_1, probs_1, labels_2, probs_2)
    return merged_flags, merged_labels

In [None]:
# Run both passes over the (corrupted) training set and report how many examples were flagged.
mislabeled_example_flags, suggested_labels = find_mislabeled_examples(
    train_dataset,
    test_dataset,
    threshold_sample_flags_1,
    threshold_sample_flags_2,
)
print(
    f"Mislabeled example flags: {mislabeled_example_flags.shape}, "
    f"{mislabeled_example_flags.dtype}"
)
print(f"Potentially mislabeled examples: {np.sum(mislabeled_example_flags)}")

In [None]:
def plot_mislabeled_examples(dataset, mislabeled_example_flags, suggested_labels, num_examples=6):
    """
    Show a random sample of flagged examples with their given and suggested labels.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs and exposing ``classes``.
        mislabeled_example_flags: Per-example flags; non-zero marks a flagged example.
        suggested_labels: Per-example suggested class index.
        num_examples: Maximum number of examples to show (default 6). Fewer
            are shown if fewer were flagged.

    Returns:
        The matplotlib Figure, or ``None`` if no example was flagged.
    """
    (mislabeled_example_indexes,) = np.nonzero(mislabeled_example_flags)
    if mislabeled_example_indexes.size == 0:
        print("No potentially mislabeled examples to plot")
        return None

    count = min(num_examples, mislabeled_example_indexes.size)
    # Sample without replacement so the same example is not shown twice.
    indexes = np.random.choice(mislabeled_example_indexes, count, replace=False)
    # squeeze=False keeps `axs` 2-D even when only one panel is drawn.
    fig, axs = plt.subplots(1, count, figsize=(2.5 * count, 3.6), squeeze=False)
    fig.suptitle("Potentially mislabeled examples")
    for i, ax in zip(indexes, axs[0]):
        # Fetched via dataset[i], so the dataset's transforms (including any
        # random augmentation) apply to the displayed image.
        x, y = dataset[i]
        suggested = int(suggested_labels[i])
        # NOTE(review): the suggesting model is sized to the threshold-sample
        # dataset's classes; if that has more classes than `dataset`, show the
        # raw index instead of raising IndexError.
        if 0 <= suggested < len(dataset.classes):
            suggested_name = dataset.classes[suggested]
        else:
            suggested_name = f"class {suggested}"
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(unprocess(x))
        ax.set_title(f"{i}\n{dataset.classes[y]}\n(suggested: {suggested_name})")
    return fig


# Trailing `;` stops Jupyter from displaying the returned Figure a second time.
plot_mislabeled_examples(train_dataset, mislabeled_example_flags, suggested_labels);

In [None]:
# Score the flagged examples against the ground-truth corruption mask.
# Both masks are cast to bool so `~` is a logical NOT; on the uint8
# `corrupted_flags` it would be a bitwise NOT (0 -> 255, 1 -> 254).
flagged_mask = np.asarray(mislabeled_example_flags).astype(bool)
corrupted_mask = np.asarray(corrupted_flags).astype(bool)

mislabeled_true_positives = np.count_nonzero(flagged_mask & corrupted_mask)
mislabeled_false_positives = np.count_nonzero(flagged_mask & ~corrupted_mask)
mislabeled_false_negatives = np.count_nonzero(~flagged_mask & corrupted_mask)

num_flagged = mislabeled_true_positives + mislabeled_false_positives
num_corrupted = mislabeled_true_positives + mislabeled_false_negatives
# Report NaN instead of raising ZeroDivisionError when nothing was flagged/corrupted.
precision = mislabeled_true_positives / num_flagged if num_flagged else float("nan")
recall = mislabeled_true_positives / num_corrupted if num_corrupted else float("nan")

print("Mislabeled example identification")
print(f"precision: {precision:.2%}")
print(f"recall:    {recall:.2%}")