In [None]:
import os
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image
from tqdm import tqdm
from efficientdet.constants import PROJECT_PATH

# Load data

In [None]:
# Fetch the MNIST train/test splits via Keras, then unpack images and labels.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
Xtrain, ytrain = mnist_train
Xtest, ytest = mnist_test
# Cell output: the shapes of all four arrays.
tuple(arr.shape for arr in (Xtrain, ytrain, Xtest, ytest))

# Convert MNIST data to bounding box data

In [None]:
# Preview the first training digit in grayscale.
fig, ax = plt.subplots()
ax.imshow(Xtrain[0], cmap='gray');

In [None]:
def check_if_intersect(boxA, boxB):
    """Return the overlap area of two boxes, or 0 if they do not overlap.

    Both boxes are indexable as ``[x1, y1, x2, y2]``. The corner coordinates
    are treated as inclusive (hence the ``+ 1``), so boxes that merely share
    an edge coordinate still yield a positive area. Callers can use the
    result directly as a truth value.
    """
    # Extent of the shared region along each axis, floored at zero.
    overlap_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]) + 1)
    overlap_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]) + 1)
    return overlap_w * overlap_h


def draw_mnist_to_image(imgs_mnist, labels_mnist, h=512, w=512):
    """Compose several MNIST digits into one noisy grayscale image.

    Each digit is cropped to its non-zero pixels (plus a one pixel margin
    where the source image has room for it), rescaled to fit inside a random
    square with a side of 75-149 pixels, and pasted at a random position that
    does not collide with any digit placed before it. Uniform integer noise in
    [0, 90) is then added to the whole canvas and the result is clipped to
    the 0-255 range.

    Args:
        imgs_mnist: sequence of 2-D digit arrays; pixels equal to 0 are
            treated as background.
        labels_mnist: labels aligned one-to-one with ``imgs_mnist``.
        h: canvas height in pixels.
        w: canvas width in pixels.

    Returns:
        Tuple ``(full_image, locations, labels)`` where ``full_image`` is an
        ``(h, w)`` uint8 array, ``locations`` is a list of ``[x1, y1, x2, y2]``
        boxes (one per drawn digit) and ``labels`` holds the labels of exactly
        those digits, in the same order. Digits that are blank or cannot be
        resized are skipped with a warning and appear in neither list.
    """
    import warnings  # local import so this cell does not rely on another cell's imports

    # crop every digit to its content and rescale it to a random size
    digits = []
    labels2 = []
    for img_mnist, label in zip(imgs_mnist, labels_mnist):
        rows, cols = np.where(img_mnist > 0)
        if rows.size == 0:
            # nothing to crop or annotate; .min() on an empty array would raise
            warnings.warn("Skipping blank digit image (label {})".format(label))
            continue
        # Clamp the lower bounds at 0: a start index of -1 would be read by
        # NumPy as "last row/column" and yield an empty or wrong crop for
        # digits touching the top/left border. Upper bounds past the array
        # edge are already clamped by slicing.
        top = max(rows.min() - 1, 0)
        left = max(cols.min() - 1, 0)
        crop = img_mnist[top:rows.max() + 2, left:cols.max() + 2]

        size = np.random.randint(75, 150)
        try:
            resized = tf.image.resize(np.expand_dims(crop, axis=-1), (size, size), preserve_aspect_ratio=True)
        except Exception as exc:  # best effort: skip this digit but say so
            warnings.warn("Skipping digit with label {}: resize failed ({})".format(label, exc))
            continue
        digits.append(resized)
        labels2.append(label)

    # choose locations to draw in full image
    locations = []
    for digit in digits:
        height, width = digit.shape[:2]
        # NOTE: retries until a free spot is found; assumes the canvas is
        # large enough to hold all digits without overlap.
        while True:
            x1 = np.random.randint(0, w - width)
            y1 = np.random.randint(0, h - height)
            candidate = [x1, y1, x1 + width, y1 + height]
            if not any(check_if_intersect(candidate, loc) for loc in locations):
                break
        locations.append(candidate)

    # draw the images into the full image and add some noise
    # (uint16 canvas so that digit + noise cannot overflow before clipping)
    full_image = np.zeros((h, w), dtype=np.uint16)
    for digit, (x1, y1, x2, y2) in zip(digits, locations):
        full_image[y1:y2, x1:x2] = np.squeeze(digit)
    distortion = np.random.randint(0, 90, size=(h, w), dtype=np.uint16)
    full_image = np.clip(full_image + distortion, a_min=0, a_max=255)
    full_image = full_image.astype(np.uint8)
    return full_image, locations, labels2


def mnist_to_bbox_data(xdata, ydata, store_dir, min_digits=1, max_digits=3, seed=None):
    """Turn single-digit samples into multi-digit images with bounding boxes.

    Consecutive samples are grouped into batches of a random size between
    ``min_digits`` and ``max_digits`` (inclusive). Every batch is rendered
    with ``draw_mnist_to_image`` and saved to ``store_dir`` as
    ``img00000.jpg``, ``img00001.jpg``, ...

    Args:
        xdata: sequence of digit images (must support ``len``).
        ydata: labels aligned with ``xdata``.
        store_dir: output directory; created if it does not exist.
        min_digits: smallest number of digits per generated image.
        max_digits: largest number of digits per generated image.
        seed: optional seed for NumPy's global random generator, to make the
            generated dataset reproducible. ``None`` leaves the generator
            state untouched.

    Returns:
        Tuple ``(image_paths, coordinates, targets)`` with one entry per drawn
        digit: the path of the image containing it, its ``[x1, y1, x2, y2]``
        box (int16 array of shape ``(n, 4)``) and its label.
    """
    if seed is not None:
        np.random.seed(seed)

    image_paths = []
    coordinates = []
    targets = []
    os.makedirs(store_dir, exist_ok=True)

    n_total = len(xdata)
    n_images = np.random.randint(low=min_digits, high=max_digits+1)  # number of images to draw
    imgs, labels = [], []
    i = 0
    for n_seen, (img, label) in enumerate(tqdm(zip(xdata, ydata), total=n_total), start=1):
        imgs.append(img)
        labels.append(label)
        # Keep collecting until the batch is full, but always flush on the
        # last sample. The count of consumed samples is used here because
        # len(image_paths) falls behind whenever a digit was skipped while
        # drawing, which would leave the final partial batch unsaved.
        if len(imgs) < n_images and n_seen < n_total:
            continue
        new_img, locs, drawn_labels = draw_mnist_to_image(imgs, labels)
        pil_img = Image.fromarray(new_img, mode='L')
        output_path = os.path.join(store_dir, "img{:05d}.jpg".format(i))
        pil_img.save(output_path)
        image_paths += len(locs) * [output_path]
        coordinates += locs
        targets += drawn_labels
        # setup for next image
        n_images = np.random.randint(low=min_digits, high=max_digits+1)  # number of images to draw
        imgs, labels = [], []
        i += 1
    # reshape keeps the (n, 4) layout even when no digit was drawn at all
    coordinates = np.array(coordinates, dtype=np.int16).reshape(-1, 4)
    return image_paths, coordinates, targets


def create_data_df(image_paths, coordinates, labels):
    """Assemble the annotation table: one row per bounding box.

    Columns, in order: ``img_path``, ``x1``, ``y1``, ``x2``, ``y2``, ``label``.
    The three inputs are joined side by side on their positional index.
    """
    path_part = pd.DataFrame({'img_path': image_paths})
    box_part = pd.DataFrame(coordinates, columns=['x1', 'y1', 'x2', 'y2'])
    label_part = pd.DataFrame({'label': labels})
    return pd.concat([path_part, box_part, label_part], axis='columns')

In [None]:
# Render the training digits into composite images on disk.
train_image_dir = os.path.join(PROJECT_PATH, 'data/images_train')
train_image_paths, train_coordinates, train_targets = mnist_to_bbox_data(
    xdata=Xtrain, ydata=ytrain, store_dir=train_image_dir
)
# Cell output: sizes of the three aligned results.
len(train_image_paths), train_coordinates.shape, len(train_targets)

In [None]:
# Render the test digits into composite images on disk.
test_image_dir = os.path.join(PROJECT_PATH, 'data/images_test')
test_image_paths, test_coordinates, test_targets = mnist_to_bbox_data(
    xdata=Xtest, ydata=ytest, store_dir=test_image_dir
)
# Cell output: sizes of the three aligned results.
len(test_image_paths), test_coordinates.shape, len(test_targets)

In [None]:
# One row per bounding box in the training images.
df_train = create_data_df(
    image_paths=train_image_paths,
    coordinates=train_coordinates,
    labels=train_targets,
)
df_train.iloc[:2]

In [None]:
# Persist the training annotations next to the generated images.
train_csv_path = os.path.join(PROJECT_PATH, 'data/train.csv')
df_train.to_csv(train_csv_path, index=False)

In [None]:
# One row per bounding box in the test images.
df_test = create_data_df(
    image_paths=test_image_paths,
    coordinates=test_coordinates,
    labels=test_targets,
)
df_test.iloc[:2]

In [None]:
# Persist the test annotations next to the generated images.
test_csv_path = os.path.join(PROJECT_PATH, 'data/test.csv')
df_test.to_csv(test_csv_path, index=False)

# Visualize processed data

In [None]:
def plot_processed_data(df, n_rows=4, n_cols=4):
    """Show a grid of generated images with their bounding boxes and labels.

    Args:
        df: annotation table with the columns ``img_path``, ``x1``, ``y1``,
            ``x2``, ``y2`` and ``label`` (one row per box).
        n_rows: number of subplot rows.
        n_cols: number of subplot columns.

    At most ``n_rows * n_cols`` images are shown, taken in the order in which
    ``groupby('img_path')`` yields them.
    """
    _, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))

    for ax, (img_path, boxes) in zip(np.ravel(axes), df.groupby('img_path')):
        # The generated files are single-channel, so without an explicit
        # colormap imshow would render them in false colour. cmap is ignored
        # for RGB(A) input.
        ax.imshow(plt.imread(img_path), cmap='gray')
        for box in boxes.itertuples(index=False):
            rect = Rectangle((box.x1, box.y1), box.x2 - box.x1, box.y2 - box.y1, ec='red', fc='None')
            ax.add_patch(rect)
        ax.set_title("Label(s): {}".format(", ".join(boxes['label'].astype(str).tolist())), fontsize=18)

In [None]:
# training data
plot_processed_data(df=df_train, n_rows=4, n_cols=4)

In [None]:
# test data
plot_processed_data(df=df_test, n_rows=2, n_cols=4)