In [None]:
from datetime import datetime
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import Mean, MulticlassAccuracy
import torchvision
from tqdm import tqdm

from aum_ranking import assign_threshold_samples, ThresholdSamplesDataset, AUM, compute_aum_threshold, flag_mislabeled_examples, combine_mislabeled_examples
from models import ResNet32

In [None]:
# Download datasets

# Training split: every image is converted to a tensor, padded by 4 pixels,
# randomly flipped horizontally and randomly cropped back to 32x32.
train_dataset = torchvision.datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Pad(4),
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomCrop(32),
        ]
    ),
)

# Test split: tensor conversion only, no augmentation.
test_dataset = torchvision.datasets.CIFAR10(
    root="data",
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [None]:
# Assign two sets of threshold samples

# Seed the Python, NumPy and PyTorch generators before the first stochastic
# step so that Restart & Run All is repeatable (best effort: some GPU kernels
# are still non-deterministic).
# NOTE(review): assumes assign_threshold_samples draws from one of these
# generators -- confirm against aum_ranking.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

threshold_sample_flags_1, threshold_sample_flags_2 = assign_threshold_samples(
    num_examples=len(train_dataset),
    num_classes=len(train_dataset.classes),
)

In [None]:
# Report the dataset size and the sum of each pass's threshold-sample flags.
print(
    f"Number of classes: {len(train_dataset.classes)}",
    f"Number of samples: {len(train_dataset)}",
    f"Number of threshold samples (first pass): {threshold_sample_flags_1.sum()}",
    f"Number of threshold samples (second pass): {threshold_sample_flags_2.sum()}",
    sep="\n",
)

In [None]:
# Set device

# Preference order: CUDA GPU, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

In [None]:
# Helper functions for visualizing predictions

# Shared converter used by the plotting helpers in this notebook to turn an
# image into a PIL image before handing it to `ax.imshow`.
to_pil_image = torchvision.transforms.ToPILImage()


def predict_with_probs(model, x):
    """
    Runs `model` on the batch `x` and returns its top prediction per row.

    The model output is passed through a softmax over dimension 1; for every
    row the largest probability and its position are returned.

    Returns:
        (preds, probs): the index of the highest-probability entry and that
        probability, one per row of the model output.
    """
    class_probs = F.softmax(model(x), dim=1)
    best = class_probs.max(dim=1)
    return best.indices, best.values


def plot_classes_preds(model, x, y, classes):
    """
    Generates a matplotlib Figure using a trained network, along with images
    and labels from a batch, that shows the network's top prediction along
    with its probability, alongside the actual label, coloring this
    information based on whether the prediction was correct or not.
    Uses the "predict_with_probs" function.

    At most the first four images of the batch are drawn; a smaller batch
    produces a correspondingly smaller figure instead of failing.

    Args:
        model: callable network; it is invoked under `torch.no_grad()` but its
            train/eval mode is left untouched, so the caller chooses the mode.
        x: batch of images (tensor); moved to `device` for inference, while
            the original tensor is used for display. Must hold at least one image.
        y: tensor of true label indexes, aligned with `x`.
        classes: sequence mapping a label index to its display name.

    Returns:
        The matplotlib Figure containing one panel per displayed image.
    """
    with torch.no_grad():
        preds, probs = predict_with_probs(model, x.to(device))
    preds = preds.cpu().numpy()
    probs = probs.cpu().numpy()
    # Never create more panels than there are images in the batch.
    num_panels = min(4, len(x))
    # squeeze=False keeps `axs` two-dimensional even for a single panel, so
    # the row can always be iterated.
    fig, axs = plt.subplots(1, num_panels, figsize=(10, 3), squeeze=False)
    # Plot the images in the batch, along with predicted and true labels
    for i, ax in enumerate(axs[0]):
        actual = y[i].item()
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(to_pil_image(x[i]))
        ax.set_title(
            "{0}, {1:.1%}\n(actual: {2})".format(
                classes[preds[i]],
                probs[i],
                classes[actual]),
            color=("green" if preds[i] == actual else "red")
        )
    return fig

In [None]:
def train(loader, model, loss_fn, optimizer, epoch, writer, loss_metric, accuracy_metric, aum):
    """
    Trains `model` for one epoch and records loss, accuracy and AUM statistics.

    Args:
        loader: iterable yielding `(x, y, indexes)` batches; `indexes` is
            forwarded unchanged to `aum.update`.
        model: network to train; switched to train mode here.
        loss_fn: callable `(pred, y) -> scalar loss`. Assumed to return the
            mean over the batch (true for the default `CrossEntropyLoss`
            used in this notebook).
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, used as the TensorBoard `global_step`.
        writer: TensorBoard writer receiving "loss/train" and "accuracy/train".
        loss_metric: running-mean metric; reset before returning.
        accuracy_metric: accuracy metric; reset before returning.
        aum: accumulator exposing `update(pred, y, indexes)`.
    """
    model.train()
    with tqdm(loader) as progress:
        for x, y, indexes in progress:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # Everything below is bookkeeping: drop the autograd graph so that
            # neither the metrics nor the AUM accumulator can keep it alive.
            pred = pred.detach()

            # Weight each batch mean by its size so that a short final batch
            # does not skew the epoch average.
            loss_metric.update(loss.detach(), weight=y.size(0))
            accuracy_metric.update(pred, y)

            progress.set_postfix_str(
                f"loss={loss_metric.compute().item():.4f}, accuracy={accuracy_metric.compute().item():.2%}",
                refresh=False
            )

            aum.update(pred, y, indexes)

    # Log the epoch aggregates, then clear the metrics for their next use.
    writer.add_scalar("loss/train", scalar_value=loss_metric.compute(), global_step=epoch)
    loss_metric.reset()
    writer.add_scalar("accuracy/train", scalar_value=accuracy_metric.compute(), global_step=epoch)
    accuracy_metric.reset()

In [None]:
def test(loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric, classes):
    """
    Evaluates `model` on `loader` and logs loss, accuracy and a sample figure.

    Args:
        loader: data loader yielding `(x, y)` batches; its `.dataset` is also
            sampled for the prediction figure.
        model: network to evaluate; switched to eval mode here.
        loss_fn: callable `(pred, y) -> scalar loss`. Assumed to return the
            mean over the batch (true for the default `CrossEntropyLoss`
            used in this notebook).
        epoch: epoch number, used as the TensorBoard `global_step`.
        writer: TensorBoard writer receiving "loss/test", "accuracy/test" and
            the "predictions" figure.
        loss_metric: running-mean metric; reset before returning.
        accuracy_metric: accuracy metric; reset before returning.
        classes: sequence mapping a label index to its display name.
    """
    model.eval()
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = loss_fn(pred, y)

            # Weight each batch mean by its size so that a short final batch
            # does not skew the reported average.
            loss_metric.update(loss, weight=y.size(0))
            accuracy_metric.update(pred, y)

    print(f"test_loss={loss_metric.compute().item():.4f}, test_accuracy={accuracy_metric.compute().item():.2%}")

    # Log the aggregates, then clear the metrics for their next use.
    writer.add_scalar("loss/test", scalar_value=loss_metric.compute(), global_step=epoch)
    loss_metric.reset()
    writer.add_scalar("accuracy/test", scalar_value=accuracy_metric.compute(), global_step=epoch)
    accuracy_metric.reset()

    # Draw four random (image, label) pairs, with replacement, for the figure.
    x, y = zip(*random.choices(loader.dataset, k=4))
    x = torch.stack(x)
    y = torch.tensor(y)
    writer.add_figure("predictions", plot_classes_preds(model, x, y, classes), global_step=epoch)

In [None]:
def training_run(pass_index, threshold_dataset, test_dataset, epochs, aum, batch_size=64):
    """
    Trains a fresh ResNet32 on `threshold_dataset`, evaluating after each epoch.

    Args:
        pass_index: identifier of the pass; only used in the log directory name.
        threshold_dataset: training dataset; must expose `.classes` and yield
            the `(x, y, indexes)` items that `train` unpacks.
        test_dataset: dataset yielding `(x, y)` items for evaluation.
        epochs: number of epochs to train for.
        aum: AUM accumulator, updated on every training batch.
        batch_size: batch size for both loaders (default 64).

    Returns:
        The trained model.
    """
    # Make data loaders
    threshold_loader = DataLoader(threshold_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    # Build model
    model = ResNet32(num_classes=len(threshold_dataset.classes)).to(device)

    # Set loss function and optimizer
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4, momentum=0.9, nesterov=True)

    # Define metrics (shared by train and test; each resets them when done)
    loss_metric = Mean(device=device)
    accuracy_metric = MulticlassAccuracy(device=device)

    # Make TensorBoard writer
    now = datetime.now().strftime("%b%d_%H-%M-%S")
    log_dir = os.path.join("runs", "find_mislabeled_data", f"{now}_pass_{pass_index}")
    writer = SummaryWriter(log_dir=log_dir)

    try:
        # Train the model
        print("Training model")
        for epoch in range(epochs):
            print("-------------------------------")
            print(f"Epoch {epoch}")
            train(threshold_loader, model, loss_fn, optimizer, epoch, writer, loss_metric, accuracy_metric, aum)
            test(test_loader, model, loss_fn, epoch, writer, loss_metric, accuracy_metric, classes=threshold_loader.dataset.classes)
        print("-------------------------------")
        print("Done!")
    finally:
        # Flush pending events and release the event file, even when training
        # is interrupted or fails part-way through.
        writer.close()

    return model

In [None]:
def find_mislabeled_examples_pass(pass_index, threshold_sample_flags, train_dataset, test_dataset, epochs=150):
    """
    Runs one AUM pass: train a model, compute AUM values, flag suspect examples.

    Args:
        pass_index: identifier of the pass, used in messages and log names.
        threshold_sample_flags: per-example flags marking this pass's
            threshold samples.
        train_dataset: dataset whose labels are being audited.
        test_dataset: held-out dataset used for evaluation during training.
        epochs: number of training epochs (default 150); the same value is
            passed to `aum.compute`.

    Returns:
        The flags produced by `flag_mislabeled_examples` for this pass.
    """
    print(f"Performing pass {pass_index}")

    # Make threshold samples dataset
    threshold_dataset = ThresholdSamplesDataset(train_dataset, threshold_sample_flags)

    # Create AUM calculator
    aum = AUM(num_examples=len(threshold_dataset), device=device)

    # Train a model to populate the margin values. The trained model itself is
    # not needed afterwards, so no reference to it is kept.
    training_run(pass_index, threshold_dataset, test_dataset, epochs, aum)

    # Compute AUM values
    aum_values = aum.compute(epochs).cpu().numpy()
    print(f"AUM values: {aum_values.shape}, {aum_values.dtype}")
    print(f"mean: {np.mean(aum_values):.4f}, min: {np.min(aum_values):.4f}, max: {np.max(aum_values):.4f}, std: {np.std(aum_values):.4f}")

    # Compute AUM threshold
    aum_threshold = compute_aum_threshold(aum_values, threshold_sample_flags)
    print(f"AUM threshold: {aum_threshold}")

    # Flag (potentially) mislabeled examples
    mislabeled_example_flags = flag_mislabeled_examples(aum_values, threshold_sample_flags, aum_threshold)
    print(f"Potentially mislabeled examples: {np.sum(mislabeled_example_flags)}")
    print(f"Finished pass {pass_index}")
    print("===============================")
    return mislabeled_example_flags


def find_mislabeled_examples(train_dataset, test_dataset, threshold_sample_flags_1, threshold_sample_flags_2):
    """
    Runs pass 1 and pass 2 in order, each with its own threshold-sample flags,
    and returns the result of combining the two sets of flags.
    """
    flags_per_pass = [
        find_mislabeled_examples_pass(pass_number, pass_flags, train_dataset, test_dataset)
        for pass_number, pass_flags in enumerate(
            (threshold_sample_flags_1, threshold_sample_flags_2), start=1
        )
    ]
    return combine_mislabeled_examples(*flags_per_pass)

In [None]:
# Run both passes (each one trains a model from scratch, so this is the
# expensive cell) and summarise the combined flags.
mislabeled_example_flags = find_mislabeled_examples(
    train_dataset,
    test_dataset,
    threshold_sample_flags_1,
    threshold_sample_flags_2,
)
print(
    f"Mislabeled example flags: {mislabeled_example_flags.shape}, {mislabeled_example_flags.dtype}",
    f"Potentially mislabeled examples: {np.sum(mislabeled_example_flags)}",
    sep="\n",
)

In [None]:
def plot_mislabeled_examples(dataset, mislabeled_example_flags, num_examples=6):
    """
    Plots a random selection of the examples flagged as potentially mislabeled.

    Each panel is titled with the example's index in `dataset` and the name of
    its recorded label.

    Args:
        dataset: dataset yielding `(image, label)` items and exposing `.classes`.
        mislabeled_example_flags: one-dimensional array of per-example flags;
            non-zero entries mark flagged examples.
        num_examples: maximum number of examples to show (default 6). Fewer
            panels are drawn when fewer examples are flagged.
    """
    (mislabeled_example_indexes,) = np.nonzero(mislabeled_example_flags)
    num_shown = min(num_examples, len(mislabeled_example_indexes))
    if num_shown <= 0:
        # np.random.choice raises on an empty population; report instead.
        print("No potentially mislabeled examples to plot")
        return
    # Sample without replacement so the same example never appears twice.
    indexes = np.random.choice(mislabeled_example_indexes, num_shown, replace=False)
    # 2.5 inches per panel reproduces the 15-inch width for six panels;
    # squeeze=False keeps `axs` two-dimensional even for a single panel.
    fig, axs = plt.subplots(1, num_shown, figsize=(2.5 * num_shown, 3.2), squeeze=False)
    fig.suptitle("Potentially mislabeled examples")
    for i, ax in zip(indexes, axs[0]):
        x, y = dataset[i]
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(to_pil_image(x))
        ax.set_title(f"{i}\n{dataset.classes[y]}")


# Show a random handful of the flagged training examples. `train_dataset` was
# built with random flip/crop augmentation, so each image is an augmented view.
plot_mislabeled_examples(train_dataset, mislabeled_example_flags)