In [None]:
import os

import numpy as np
import torch
import matplotlib.pyplot as plt

from utils import set_matplotlib_configuration

In [None]:
# Apply the project's shared matplotlib configuration. The argument 8.1 is
# interpreted by `utils.set_matplotlib_configuration` (presumably a base font
# size or figure width — TODO confirm against utils). The helper returns two
# kwargs dicts: SAVEFIG_KWARGS is forwarded to `savefig` in `plot` below;
# PLOTTING_KWARGS is not used in the cells shown here.
PLOTTING_KWARGS, SAVEFIG_KWARGS = set_matplotlib_configuration(8.1)

In [None]:
def sample_images(dataset, num_samples, rng=None):
    """Draw ``num_samples`` images at random from ``dataset``.

    Sampling is without replacement when the dataset holds at least
    ``num_samples`` images. Otherwise it falls back to sampling with
    replacement, so a full row can still be drawn (duplicates may then appear).

    Args:
        dataset: Object exposing ``.tensors`` (e.g. a
            ``torch.utils.data.TensorDataset``). Images are taken from
            ``dataset.tensors[0]`` and indexed along its first dimension.
        num_samples: Number of images to draw.
        rng: Optional ``numpy.random.Generator`` (or ``RandomState``) used for
            sampling, making the selection reproducible. When ``None`` the
            global ``np.random`` state is used, as before.

    Returns:
        ``dataset.tensors[0]`` indexed by the sampled indices.

    Raises:
        ValueError: if the dataset is empty but samples are requested.
    """
    images = dataset.tensors[0]
    num_images = images.shape[0]
    if num_images == 0 and num_samples > 0:
        raise ValueError("cannot sample images from an empty dataset")
    # Only allow repeats when there are too few distinct images to fill a row.
    replace = num_images < num_samples
    sampler = np.random if rng is None else rng
    idxs = sampler.choice(num_images, num_samples, replace=replace)
    return images[idxs]


def plot(folder, num_samples=8):
    """Plot a grid of sample images from the three concept-drift datasets.

    Loads ``target_dataset.pth``, ``similar_dataset.pth`` and
    ``different_dataset.pth`` from ``folder``. It draws ``num_samples``
    random images from each into one row of a 3 x ``num_samples`` grid,
    ordered reference, similar, different. The figure is saved to
    ``images/concept_drift_<label>.pdf``.

    Args:
        folder: Experiment output directory whose last path component is the
            integer target label, e.g. ``"../outputs/concept_drift/3"``.
            A trailing path separator is tolerated.
        num_samples: Number of images per row (columns of the grid).

    Raises:
        ValueError: if the last path component of ``folder`` is not an integer.
        AssertionError: if any dataset holds 5 or fewer items.
    """
    # normpath strips a trailing separator so basename is never empty.
    target_label = int(os.path.basename(os.path.normpath(folder)))

    names = ("target", "similar", "different")
    row_labels = ("Reference\ndataset", "Similar\ndataset", "Different\ndataset")

    # weights_only=False: these files hold pickled dataset objects rather than
    # plain tensors, which torch>=2.6 refuses to load by default. Only load
    # files produced by this project's own experiments — unpickling untrusted
    # files can execute arbitrary code.
    datasets = [
        torch.load(
            os.path.join(folder, f"{name}_dataset.pth"),
            map_location="cpu",
            weights_only=False,
        )
        for name in names
    ]
    for name, dataset in zip(names, datasets):
        assert len(dataset) > 5, (
            f"{name} dataset in {folder} has only {len(dataset)} items (need > 5)"
        )

    # Sampled in the same order as before: target, similar, different.
    image_sets = [sample_images(dataset, num_samples) for dataset in datasets]

    # squeeze=False keeps `axes` 2-D even when num_samples == 1.
    fig, axes = plt.subplots(3, num_samples, figsize=(6.0, 2.3), squeeze=False)
    for row, (images, label) in enumerate(zip(image_sets, row_labels)):
        axes[row, 0].set_ylabel(label, fontsize=8.2)
        for col in range(num_samples):
            ax_ = axes[row, col]
            ax_.imshow(images[col].cpu().numpy().squeeze(), cmap="gray")
            ax_.set_xticks([])
            ax_.set_yticks([])
            # Force a frame around each tile, in case the shared style hides spines.
            for spine in ax_.spines.values():
                spine.set_visible(True)

    # savefig fails if the output directory does not exist yet.
    os.makedirs("images", exist_ok=True)
    fig.savefig(f"images/concept_drift_{target_label}.pdf", **SAVEFIG_KWARGS)

In [None]:
# Plot every digit the experiment was run for (1, 3, 7 in the `all.sh` script)
# instead of copy-pasting this cell per digit; digits whose output folder has
# not been generated are skipped with a message.
CONCEPT_DRIFT_DIGITS = (1, 3, 7)
for digit in CONCEPT_DRIFT_DIGITS:
    folder = f"../outputs/concept_drift/{digit}"
    if not os.path.isdir(folder):
        print(f"skipping digit {digit}: {folder} not found")
        continue
    plot(folder)