In [None]:
import os
from pathlib import Path
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np
from dotenv import find_dotenv, load_dotenv

# Populate os.environ from the project's .env file so later cells can read
# PROJECT_DIR (python-dotenv's find_dotenv locates the file; load_dotenv reads it).
load_dotenv(dotenv_path=find_dotenv())

def display_image_grid(images_filenames, images_directory, images_labels, cols=4):
    """Show ``.npy`` images in a labelled grid with ``cols`` images per row.

    Parameters
    ----------
    images_filenames : sequence of str
        File names, relative to ``images_directory``. Each file is loaded with
        ``np.load`` and drawn with ``imshow``.
    images_directory : str or os.PathLike
        Directory that contains the files.
    images_labels : sequence
        Values shown as ``"Label: <value>"`` above each image. They are paired
        positionally with ``images_filenames``; extra entries in the longer
        sequence are ignored.
    cols : int, optional
        Number of images per row (default 4).

    Returns
    -------
    None
        The figure is rendered with ``plt.show()``. Nothing is drawn when
        ``images_filenames`` is empty.

    Raises
    ------
    ValueError
        If ``cols`` is smaller than 1.
    """
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")
    n_images = len(images_filenames)
    if n_images == 0:
        # ImageGrid cannot be built with zero rows, and there is nothing to show.
        return

    # Round up so that a partial last row is still drawn. This also covers
    # fewer than `cols` images.
    rows = int(np.ceil(n_images / cols))
    # Scale the height with the row count to avoid large empty margins.
    figure = plt.figure(figsize=(6 * cols, 6 * rows))
    grid = ImageGrid(figure, 111,
                     nrows_ncols=(rows, cols),  # creates grid of axes
                     axes_pad=(0.1, 0.3),  # (horizontal, vertical) pad between axes
                     )

    for ax, img_name, img_label in zip(grid, images_filenames, images_labels):
        image = np.load(os.path.join(images_directory, img_name))
        ax.imshow(image)
        ax.set_title(f"Label: {img_label}")

    # Hide every axis frame, including the unused cells of a partial last row.
    for ax in grid:
        ax.set_axis_off()

    # No plt.tight_layout(): ImageGrid positions its own axes, and matplotlib
    # warns that such axes are not compatible with tight_layout.
    plt.show()

In [None]:
# Preview the first few samples of the training split. Each non-blank line of
# the list file is parsed as "<file name>,<integer label>".
N_PREVIEW = 8  # number of samples to display

# os.environ[...] fails with a KeyError naming PROJECT_DIR when it is unset,
# instead of an opaque TypeError from Path(None).
path = Path(os.environ['PROJECT_DIR']) / 'data'
images_filenames = []
images_labels = []
with open(path / 'MNIST_training_data.txt', 'r') as file_stream:
    # Iterate lazily: only the first N_PREVIEW records are needed.
    for line in file_stream:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        # The label is the last field, so split once from the right.
        file_name, label = line.rsplit(',', 1)
        images_filenames.append(file_name.strip())
        images_labels.append(int(label))
        if len(images_filenames) >= N_PREVIEW:
            break

display_image_grid(images_filenames, path / 'processed/mnist', images_labels)

In [None]:
# Preview the first few samples of the validation split. Each non-blank line of
# the list file is parsed as "<file name>,<integer label>".
N_PREVIEW = 8  # number of samples to display

# os.environ[...] fails with a KeyError naming PROJECT_DIR when it is unset,
# instead of an opaque TypeError from Path(None).
path = Path(os.environ['PROJECT_DIR']) / 'data'
images_filenames = []
images_labels = []
with open(path / 'MNIST_validation_data.txt', 'r') as file_stream:
    # Iterate lazily: only the first N_PREVIEW records are needed.
    for line in file_stream:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        # The label is the last field, so split once from the right.
        file_name, label = line.rsplit(',', 1)
        images_filenames.append(file_name.strip())
        images_labels.append(int(label))
        if len(images_filenames) >= N_PREVIEW:
            break

display_image_grid(images_filenames, path / 'processed/mnist', images_labels)

In [None]:
# Preview the first few samples of the test split. Each non-blank line of
# the list file is parsed as "<file name>,<integer label>".
N_PREVIEW = 8  # number of samples to display

# os.environ[...] fails with a KeyError naming PROJECT_DIR when it is unset,
# instead of an opaque TypeError from Path(None).
path = Path(os.environ['PROJECT_DIR']) / 'data'
images_filenames = []
images_labels = []
with open(path / 'MNIST_test_data.txt', 'r') as file_stream:
    # Iterate lazily: only the first N_PREVIEW records are needed.
    for line in file_stream:
        if not line.strip():
            continue  # tolerate blank / trailing lines
        # The label is the last field, so split once from the right.
        file_name, label = line.rsplit(',', 1)
        images_filenames.append(file_name.strip())
        images_labels.append(int(label))
        if len(images_filenames) >= N_PREVIEW:
            break

display_image_grid(images_filenames, path / 'processed/mnist', images_labels)
