In [None]:
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import himyb.datasets.biased_mnist as biased_mnist

%load_ext autoreload
%autoreload 2

In [None]:
# Root directory of the Biased MNIST dataset. Set the BIASED_MNIST_ROOT environment
# variable to override it. The default is `data/` inside the parent directory (the
# directory added to sys.path in the imports cell). The previous `...` placeholder
# was passed straight through as `root=`, so a fresh Run All failed.
ROOT_DIR = os.environ.get("BIASED_MNIST_ROOT", os.path.join(os.pardir, "data"))

In [None]:
# Tensor conversion only. ToTensor turns HWC images into CHW float tensors (uint8
# input is scaled to [0, 1]). No Normalize step is applied here, so the preview below
# can plot images directly. To map to [-1, 1] instead, append
# transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)).
transform = transforms.Compose([transforms.ToTensor()])

## Original resolution 28x28

In [None]:
# Training split of Colour-Biased MNIST restricted to two digit classes.
# `classes_to_use` is named so the per-class statistic below stays correct if the
# selection changes (the previous hard-coded `/ 2` silently assumed two classes).
classes_to_use = [0, 1]

dataset = biased_mnist.ColourBiasedMNIST(
    root=ROOT_DIR,
    train=True,
    download=True,
    rho=0.9,  # requested bias strength; compare with `effective_rho` reported below
    transform=transform,
    no_digit=False,
    classes_to_use=classes_to_use,
    n_confusing_labels=1,
    class_size=5000,
)

# Labelled summary of what the dataset actually realised.
{
    "effective_rho": dataset.effective_rho,
    "mean_samples_per_class": len(dataset) / len(classes_to_use),
    "conflict_count": dataset.conflict_count,
}

In [None]:
# Total number of samples in the training split constructed above.
len(dataset)

In [None]:
# Preview the first samples of `dataset` on a 9x9 grid. Each panel is titled with
# the target (T) and bias (B) labels.
# Each sample is fetched once per panel. Indexing a dataset built with `transform=`
# presumably re-runs the transform, so the previous triple `dataset[i]` did 3x the work.
fig, axes = plt.subplots(9, 9, figsize=(12, 12))
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= len(dataset):  # dataset smaller than the grid: leave the panel blank
        continue
    sample = dataset[i]
    image, target_label, bias_label = sample[0], sample[1], sample[2]
    ax.set_title(f"T:{target_label}, B:{bias_label}")
    ax.imshow(image.permute(1, 2, 0))  # CHW tensor -> HWC layout expected by imshow

plt.tight_layout()
plt.show()

## Resolution 32x32

In [None]:
# Loader over the same digit classes and rho, with images resized to 32x32
# (resize_before_colouring=False). Only the first batch of 1000 is used for the
# preview below. If the loader shuffles, this batch differs between runs (no seed is set).
custom_loader = biased_mnist.get_dataloader(
    root=ROOT_DIR,
    batch_size=1000,
    train=True,
    rho=0.9,
    classes_to_use=[0, 1],
    n_confusing_labels=1,
    resolution=(32, 32),
    resize_before_colouring=False,
)
images, target_labels, bias_labels = next(iter(custom_loader))

In [None]:
# Preview the first images of the 32x32 batch on a 9x9 grid, titled with target (T)
# and bias (B) labels.
# NOTE(review): `x / 2 + 0.5` maps [-1, 1] back to [0, 1]. It assumes get_dataloader
# applies Normalize(mean=0.5, std=0.5); confirm in biased_mnist. clamp(0, 1) stops
# imshow from clipping (and warning) if the values overshoot that range.
fig, axes = plt.subplots(9, 9, figsize=(12, 12))
axes = axes.flatten()

for i, ax in enumerate(axes):
    ax.axis('off')
    if i >= len(images):  # batch smaller than the grid: leave the panel blank
        continue
    # CHW -> HWC for imshow, then undo the assumed normalisation.
    image = (images[i].permute(1, 2, 0) / 2 + 0.5).clamp(0, 1)
    ax.set_title(f"T:{target_labels[i]}, B:{bias_labels[i]}")
    ax.imshow(image)

plt.tight_layout()
plt.show()

## Plot the colours of the colour map

In [None]:
# RGB triplets (0-255), one per colour index; the swatch plot below titles each one
# "<index>:<RGB>".
# NOTE(review): this is a local copy, presumably mirroring the colour map used inside
# biased_mnist. Entries 3 and 4 use 225 rather than 255, so verify against the module.
COLOUR_MAP = [
    [255, 0, 0],      # 0: red
    [0, 255, 0],      # 1: green
    [0, 0, 255],      # 2: blue
    [225, 225, 0],    # 3: yellow
    [225, 0, 225],    # 4: magenta
    [0, 255, 255],    # 5: cyan
    [255, 130, 0],    # 6: orange
    [255, 0, 128],    # 7: rose
    [128, 0, 255],    # 8: violet
    [128, 128, 128],  # 9: grey
]

In [None]:
# Draw one solid swatch per COLOUR_MAP entry, titled "<index>:<RGB>".
# squeeze=False keeps `axes` a 2-D array even when the map has a single entry.
# Without it a lone Axes object is returned and zip() would fail.
fig, axes = plt.subplots(1, len(COLOUR_MAP), figsize=(15, 2), squeeze=False)

for i, (ax, color) in enumerate(zip(axes[0], COLOUR_MAP)):
    # 10x10 uint8 RGB patch; np.full broadcasts the 3-element colour over every pixel.
    swatch = np.full((10, 10, 3), color, dtype=np.uint8)
    ax.imshow(swatch)
    ax.axis('off')
    ax.set_title(f"{i}:{color}")

plt.tight_layout()
plt.show()