In [1]:
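"""Oxford Pet dataset: preview of preprocessed images and ground-truth masks.

Loads the first few entries listed in the training index file and shows each
image (stored as a NumPy array) next to its mask.
"""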
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from dotenv import find_dotenv, load_dotenv

# Populate os.environ from the nearest .env file located by find_dotenv();
# presumably this is where PROJECT_DIR (read in a later cell) is defined.
# By default python-dotenv does not override variables that are already set.
load_dotenv(dotenv_path=find_dotenv())

def display_image_grid(images_filenames, images_directory, masks_directory, predicted_masks=None, figsize=None):
    """Show each image next to its ground-truth mask, optionally with a predicted mask.

    One row is drawn per entry of ``images_filenames``. For every name, the image is
    loaded with ``np.load`` from ``images_directory`` and the mask with ``np.load``
    from ``masks_directory``; the same file name is used in both directories.

    Args:
        images_filenames: Sequence of file names present in both directories.
        images_directory: Directory (str or path-like) containing the image arrays.
        masks_directory: Directory (str or path-like) containing the mask arrays.
        predicted_masks: Optional sequence (list or array) of predicted masks, aligned
            by position with ``images_filenames``. When non-empty, a third
            "Predicted mask" column is drawn.
        figsize: Optional ``(width, height)`` in inches. By default the figure is
            5 inches wide per column and 3 inches tall per row, i.e. ``(10, 24)``
            for 8 rows and 2 columns.

    Raises:
        ValueError: If ``predicted_masks`` is non-empty but has fewer entries than
            ``images_filenames``.
    """
    # `is not None` + `len` instead of plain truthiness: a NumPy array of masks would
    # otherwise raise "truth value of an array is ambiguous". An empty sequence still
    # means "no predictions".
    has_predictions = predicted_masks is not None and len(predicted_masks) > 0
    rows = len(images_filenames)
    if rows == 0:
        # Nothing to draw, and plt.subplots(nrows=0) would raise.
        return
    if has_predictions and len(predicted_masks) < rows:
        # Fail before drawing anything rather than with an IndexError mid-grid.
        raise ValueError(
            f"predicted_masks has {len(predicted_masks)} entries but "
            f"{rows} images were requested"
        )
    cols = 3 if has_predictions else 2
    if figsize is None:
        figsize = (5 * cols, 3 * rows)

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i] is always a row.
    _, axes = plt.subplots(nrows=rows, ncols=cols, figsize=figsize, squeeze=False)
    for i, image_filename in enumerate(images_filenames):
        image = np.load(os.path.join(images_directory, image_filename))
        mask = np.load(os.path.join(masks_directory, image_filename))

        # (title, array, interpolation). Masks use "nearest" so that (presumably
        # integer) label values are not blended; None keeps matplotlib's default.
        panels = [
            ("Image", image, None),
            ("Ground truth mask", mask, "nearest"),
        ]
        if has_predictions:
            panels.append(("Predicted mask", predicted_masks[i], "nearest"))

        for ax, (title, data, interpolation) in zip(axes[i], panels):
            ax.imshow(data, interpolation=interpolation)
            ax.set_title(title)
            ax.set_axis_off()
    plt.tight_layout()
    plt.show()

In [2]:
# Number of training examples to preview in the grid below.
N_PREVIEW = 8

# Path(None) would fail with a cryptic TypeError; name the missing variable instead.
project_dir = os.environ.get('PROJECT_DIR')
if project_dir is None:
    raise RuntimeError('PROJECT_DIR is not set; define it in the environment or in a .env file')
path = Path(project_dir) / 'data'

# Only the part of each line before the first comma (the file name) is used.
images_filenames = []
with open(path / 'OxfordPet_training_data.txt', 'r') as file_stream:
    # Iterate lazily: only the first N_PREVIEW non-blank lines are needed.
    for line in file_stream:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing empty line
        file_name, _, _ = line.partition(',')
        images_filenames.append(file_name.strip())
        if len(images_filenames) >= N_PREVIEW:
            break

In [None]:
# Preprocessed arrays live under <PROJECT_DIR>/data/processed/oxford-pet/{images,masks}.
images_dir = path / 'processed' / 'oxford-pet' / 'images'
masks_dir = path / 'processed' / 'oxford-pet' / 'masks'
display_image_grid(images_filenames, images_dir, masks_dir)