In [None]:
import os
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from tqdm import tqdm

# Load data

In [None]:
(Xtrain, ytrain), (Xtest, ytest) = tf.keras.datasets.mnist.load_data()
# shapes of the raw train/test images and labels
tuple(split.shape for split in (Xtrain, ytrain, Xtest, ytest))

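# Convert MNIST data to bounding-box data

Each digit is stretched to a random size and pasted at a random position on a blank 128×128 canvas; its box corners `[x1, y1, x2, y2]` are kept alongside the canvas and the digit label.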

In [None]:
# Append a trailing channels axis (axis=-1) to both image stacks.
# Guarded on ndim so re-running this cell does not keep stacking extra axes;
# assumes the raw load_data arrays are 3-D (N, H, W) -- see the shapes shown above.
if Xtrain.ndim == 3:
    Xtrain = np.expand_dims(Xtrain, axis=-1)
if Xtest.ndim == 3:
    Xtest = np.expand_dims(Xtest, axis=-1)

In [None]:
# example image: the first training digit, titled with its label
plt.imshow(Xtrain[0], cmap='gray')
plt.title("Label: {}".format(ytrain[0]));

In [None]:
def draw_mnist_to_image(img_mnist, canvas_size=128, min_size=20, max_size=36):
    """Paste a randomly resized digit at a random position on a blank square canvas.

    Height and width are drawn independently with
    ``np.random.randint(min_size, max_size)`` (so ``max_size`` is exclusive), and the
    digit is stretched to that size without preserving its aspect ratio. The
    top-left corner is then drawn so that every position keeping the digit fully on
    the canvas is possible, including flush against the right/bottom edge.

    Args:
        img_mnist: image passed to ``tf.image.resize``; it needs a trailing channels
            axis (presumably ``(28, 28, 1)`` for MNIST -- confirm against caller).
            The canvas has a single channel, so the resized digit must too.
        canvas_size: side length of the square output canvas.
        min_size: inclusive lower bound for the digit height/width.
        max_size: exclusive upper bound for the digit height/width.

    Returns:
        ``(full_image, x1, y1, x2, y2)``: a float64 canvas of shape
        ``(canvas_size, canvas_size, 1)`` that is zero except for the digit, and the
        box corners with ``x2``/``y2`` exclusive, i.e. the digit occupies
        ``full_image[y1:y2, x1:x2]``.

    Raises:
        ValueError: if the size bounds cannot produce a digit that fits on the canvas.
    """
    if not 1 <= min_size < max_size <= canvas_size + 1:
        raise ValueError(
            "Need 1 <= min_size < max_size <= canvas_size + 1, got "
            "min_size={}, max_size={}, canvas_size={}".format(min_size, max_size, canvas_size)
        )
    height = np.random.randint(min_size, max_size)
    width = np.random.randint(min_size, max_size)
    img_mnist_resized = tf.image.resize(img_mnist, (height, width), preserve_aspect_ratio=False)
    full_image = np.zeros((canvas_size, canvas_size, 1))
    # randint's upper bound is exclusive: +1 so the last valid offset
    # (canvas_size - width) can be drawn and the digit may touch the edge.
    x1 = np.random.randint(0, canvas_size - width + 1)
    y1 = np.random.randint(0, canvas_size - height + 1)
    x2 = x1 + width
    y2 = y1 + height
    full_image[y1:y2, x1:x2] = img_mnist_resized
    return full_image, x1, y1, x2, y2


def mnist_to_bbox_data(xdata):
    """Turn a stack of digit images into (canvas, bounding box) pairs.

    Each element of ``xdata`` is passed through ``draw_mnist_to_image``. The
    resulting canvases are collected into one uint8 array and the boxes into a
    ``(len(xdata), 4)`` array of ``[x1, y1, x2, y2]`` rows.

    The uint8 output is allocated once, after the first canvas reveals its shape,
    and filled in place. Collecting a Python list of the float64 canvases first
    would keep all of them alive at once -- 8x the memory of the uint8 result.

    Args:
        xdata: sized iterable of images accepted by ``draw_mnist_to_image``.

    Returns:
        images: uint8 array of shape ``(len(xdata),) + canvas.shape``. Pixel values
            are cast from float exactly as ``np.array(..., dtype=np.uint8)`` would
            (fractional parts are dropped, not rounded). Shape ``(0,)`` if
            ``xdata`` is empty.
        coordinates: uint8 array of shape ``(len(xdata), 4)``; uint8 holds corner
            values up to 255, which covers the default 128-pixel canvas.
    """
    n_samples = len(xdata)
    coordinates = np.zeros((n_samples, 4), dtype=np.uint8)
    images = None
    for i, img in enumerate(tqdm(xdata)):
        new_img, x1, y1, x2, y2 = draw_mnist_to_image(img)
        if images is None:
            # canvas shape is only known once the first image has been drawn
            images = np.empty((n_samples,) + new_img.shape, dtype=np.uint8)
        images[i] = new_img
        coordinates[i] = (x1, y1, x2, y2)
    if images is None:
        # empty input: same result as np.array([], dtype=np.uint8)
        images = np.zeros((0,), dtype=np.uint8)
    return images, coordinates

In [None]:
# Seed NumPy's global RNG (draw_mnist_to_image uses it for digit size and position)
# so this cell regenerates the same training set on every run.
np.random.seed(0)
train_images, train_coordinates = mnist_to_bbox_data(Xtrain)
train_images.shape, train_coordinates.shape

In [None]:
# Own seed: the cell is reproducible on its own, and a seed different from the
# training cell keeps test placements from replaying the first training placements.
np.random.seed(1)
test_images, test_coordinates = mnist_to_bbox_data(Xtest)
test_images.shape, test_coordinates.shape

# Visualize processed data

In [None]:
def plot_processed_data(images, coordinates, labels, n_rows=4, n_cols=4, cmap='gray'):
    """Plot a grid of samples with their bounding box in red and the label as title.

    Args:
        images: indexable collection of images; ``images[i]`` goes to ``Axes.imshow``.
        coordinates: 2-D array indexed as ``coordinates[i, k]`` whose first four
            columns are ``x1, y1, x2, y2`` (box corners in pixel coordinates).
        labels: per-sample labels shown in each subplot title.
        n_rows, n_cols: grid layout; the first ``n_rows * n_cols`` samples are shown.
        cmap: colormap for single-channel images; ``'gray'`` matches the
            example-image cell (matplotlib ignores it for RGB(A) images).
    """
    # squeeze=False: always a 2-D array of Axes, even for a 1x1 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    n_shown = min(len(images), n_rows * n_cols)

    for i, ax in enumerate(axes.ravel()):
        if i >= n_shown:
            # fewer samples than grid cells: blank the leftover panel instead of IndexError
            ax.axis('off')
            continue
        ax.imshow(images[i], cmap=cmap)
        # int() avoids unsigned wrap-around in x2 - x1 when coordinates are uint8
        x1, y1, x2, y2 = (int(c) for c in coordinates[i, :4])
        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, ec='red', fc='none')
        ax.add_patch(rect)
        ax.set_title("Label: {}".format(labels[i]), fontsize=18)

    fig.tight_layout()

In [None]:
# training samples: default 4x4 grid
plot_processed_data(images=train_images, coordinates=train_coordinates, labels=ytrain);

In [None]:
# test samples: two rows of the default four columns
plot_processed_data(images=test_images, coordinates=test_coordinates, labels=ytest, n_rows=2);

# Store processed data

In [None]:
def store_data(store_dir, filename_prefix, images, coordinates, labels):
    """Save the three arrays as ``<prefix>images.npy``, ``<prefix>coordinates.npy``
    and ``<prefix>labels.npy`` inside ``store_dir``.

    Args:
        store_dir: directory to write to; created (with parents) if missing.
        filename_prefix: prepended verbatim to each file name, e.g. ``'train_'``.
        images, coordinates, labels: array-likes with one entry per sample.

    Raises:
        ValueError: if the three inputs do not have the same length (nothing is
            written in that case).
    """
    # explicit check rather than `assert`, which is stripped when Python runs with -O
    if not len(images) == len(coordinates) == len(labels):
        raise ValueError(
            "Provided data don't have the same lengths: {} images, {} coordinates, "
            "{} labels.".format(len(images), len(coordinates), len(labels))
        )
    os.makedirs(store_dir, exist_ok=True)
    arrays = {'images': images, 'coordinates': coordinates, 'labels': labels}
    for name, array in arrays.items():
        # allow_pickle=False: these are plain numeric arrays; refusing object arrays
        # keeps the files loadable with np.load's default allow_pickle=False.
        np.save(os.path.join(store_dir, filename_prefix + name + '.npy'), array, allow_pickle=False)

In [None]:
# write the training split to ../data with a 'train_' prefix
store_data(store_dir='../data', filename_prefix='train_',
           images=train_images, coordinates=train_coordinates, labels=ytrain)

In [None]:
# write the test split to ../data with a 'test_' prefix
store_data(store_dir='../data', filename_prefix='test_',
           images=test_images, coordinates=test_coordinates, labels=ytest)