In [None]:
import numpy as np
import matplotlib.pyplot as plt
import modsim2.utils.config as config
from PIL import Image as pil_img

In [None]:
# Load the CIFAR-10 sample stored under the 'cifar_10' key of the .npz archive.
# np.load on an .npz returns a lazy NpzFile that keeps a file handle open, so
# read the array inside a context manager to guarantee the handle is closed.
with np.load('../data/cifar_10_sample.npz') as cifar_10_archive:
    np_cifar_10_sample = cifar_10_archive['cifar_10']

In [None]:
def sample_image(images, index):
    """Pick one entry from a stacked image array and move its first axis last.

    Selects ``images[index, :]`` and moves axis 0 of that entry to position 2,
    so a 3-D entry of shape (A, B, C) comes back as (B, C, A) -- presumably
    channels-first to channels-last for plotting; confirm against the data.

    Args:
        images: Array indexable as ``images[index, :]``.
        index: Position of the entry to extract along the first axis.

    Returns:
        The selected entry with its leading axis moved to position 2.
    """
    return np.moveaxis(images[index, :], 0, 2)

# Pull the three showcase images (dataset indices 0, 2 and 5).
image_0, image_1, image_2 = (
    sample_image(np_cifar_10_sample, sample_index) for sample_index in (0, 2, 5)
)
images = [image_0, image_1, image_2]

# PIL versions of the same images: scale by 255 and cast to uint8 before
# handing them to PIL (assumes the arrays hold floats in [0, 1] -- TODO confirm).
images_pil = [pil_img.fromarray(np.uint8(array * 255)) for array in images]
image_pil_0, image_pil_1, image_pil_2 = images_pil

In [None]:
# Transform spec: grayscale conversion that keeps three output channels,
# followed by ToTensor. Each step is a {"name", optional "kwargs"} dict that is
# later passed to config.create_transforms.
grayscale_transforms = [
    {"name": "Grayscale", "kwargs": {"num_output_channels": 3}},
    {"name": "ToTensor"},
]

# Transform spec: mild Gaussian blur (3-pixel kernel, sigma 1), then ToTensor.
littleblur_transforms = [
    {"name": "GaussianBlur", "kwargs": {"kernel_size": 3, "sigma": 1}},
    {"name": "ToTensor"},
]

# Transform spec: stronger Gaussian blur (same 3-pixel kernel, sigma 3), then
# ToTensor.
bigblur_transforms = [
    {"name": "GaussianBlur", "kwargs": {"kernel_size": 3, "sigma": 3}},
    {"name": "ToTensor"},
]

# Transform spec: vertical flip applied with probability 1, then ToTensor.
# NOTE(review): a vertical flip mirrors the image top-to-bottom; that is not
# the same as a 180-degree rotation (which would also need a horizontal flip),
# despite the "rotate" name -- confirm which one is intended.
rotate_transforms = [
    {"name": "RandomVerticalFlip", "kwargs": {"p": 1}},
    {"name": "ToTensor"},
]

In [None]:
def transform_images(images, transform):
    """Apply one transform pipeline to every image and return channel-last arrays.

    Args:
        images: Iterable of inputs accepted by the pipeline (this notebook
            passes PIL images).
        transform: Transform specification handed to
            ``config.create_transforms`` (in this notebook, a list of
            ``{"name", "kwargs"}`` dicts).

    Returns:
        list: One NumPy array per input image, with axis 0 of the pipeline
        output moved to position 2.
    """
    # The pipeline depends only on the spec, so build it once rather than
    # once per image.
    pipeline = config.create_transforms(transform)
    # The pipeline output must expose .numpy() (presumably a torch tensor,
    # since every spec in this notebook ends with ToTensor).
    return [np.moveaxis(pipeline(image).numpy(), 0, 2) for image in images]

# Run each transform spec over the same PIL images (evaluated left to right).
images_grayscale, images_littleblur, images_bigblur, images_rotate = (
    transform_images(images_pil, transform_spec)
    for transform_spec in (
        grayscale_transforms,
        littleblur_transforms,
        bigblur_transforms,
        rotate_transforms,
    )
)


In [None]:
def show_image(ax, image, title, cmap=None):
    """Render ``image`` on the given axes and set the panel title.

    Args:
        ax: Axes-like object providing ``imshow`` and ``set_title``.
        image: Image data forwarded unchanged to ``ax.imshow``.
        title: Text for the panel title (an empty string is passed through
            as-is).
        cmap: Optional colormap forwarded to ``ax.imshow``; defaults to
            ``None``.
    """
    ax.imshow(image, cmap=cmap)
    ax.set_title(title)

In [None]:
# Grid of examples: one column per transform, one row per sample image.
# Each entry is (column title, images for that column), left to right.
# The last column is produced by RandomVerticalFlip(p=1), which mirrors the
# image top-to-bottom rather than rotating it by 180 degrees, so it is labelled
# by what it actually shows.
panel_columns = [
    ('No transform', images),
    ('Grayscale', images_grayscale),
    ('Little Blur', images_littleblur),
    ('Big Blur', images_bigblur),
    ('Vertical Flip', images_rotate),
]
n_rows, n_cols = len(images), len(panel_columns)

fig, ax = plt.subplots(n_rows, n_cols, figsize=(20, 10))

for row_idx in range(n_rows):
    for col_idx, (column_title, column_images) in enumerate(panel_columns):
        # Only the top row carries the column titles; other panels are untitled.
        panel_title = column_title if row_idx == 0 else ''
        show_image(ax[row_idx, col_idx], column_images[row_idx], panel_title)

fig.savefig('../output/example_transforms.png', bbox_inches='tight')