# 0. Importing libraries

In [1]:
import os
import random
import uuid
import warnings
import xml.dom.minidom
import xml.etree.ElementTree as ET

import cv2 as cv
import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
from tqdm import tqdm

# 1. Paths to directories and files

### 1.1 Downloading the MNIST dataset from TF and defining the directories where the new dataset is saved

In [2]:
# Download MNIST through Keras; the "_ori" suffix marks the untouched source arrays.
(x_train_ori, y_train_ori), (x_test_ori, y_test_ori) = tf.keras.datasets.mnist.load_data()

# Sanity-check the split sizes (60k train / 10k test) and the 28x28 image shape.
for actual_shape, expected_shape in (
    (x_train_ori.shape, (60000, 28, 28)),
    (x_test_ori.shape, (10000, 28, 28)),
    (y_train_ori.shape, (60000,)),
    (y_test_ori.shape, (10000,)),
):
    assert actual_shape == expected_shape

In [3]:
# Directory layout of the generated dataset, kept in one dictionary.
# The project root can be overridden with the MNIST_OD_DIR environment variable,
# so the notebook runs on both the Linux and the Windows machine without editing
# code; the default is the original Linux location.
paths = dict()
paths['main'] = os.environ.get('MNIST_OD_DIR', '/home/javi/Desktop/Python/MNIST_OD')
paths['dataset'] = os.path.join(paths['main'], 'MNIST_dataset')
paths['train_data'] = os.path.join(paths['dataset'], 'train')
paths['test_data'] = os.path.join(paths['dataset'], 'test')

In [4]:
# Report the dataset root and whether the train folder already exists on disk
# (the bare expression on the last line is what Jupyter displays).
dataset_root = paths['dataset']
train_dir = paths['train_data']
print(dataset_root)
os.path.exists(train_dir)

d:\javi\python\mnist_od\MNIST_dataset


True

### 1.2 Preparing the dataset to generate new data

In [4]:
# Wrap the raw arrays as tf.data pipelines of (image, label) pairs.
train = tf.data.Dataset.from_tensor_slices((x_train_ori, y_train_ori))
test = tf.data.Dataset.from_tensor_slices((x_test_ori, y_test_ori))

# Pipeline parameters
BATCH = 5        # digits per batch; each generated image uses 1..BATCH of them
FETCH = 6        # batches prepared ahead by prefetch
SHUFFLE = 3000   # shuffle buffer size in single images (= 600 batches of 5)

# Shuffle single images *before* batching so the groups of digits that can share
# a generated image change every epoch; shuffling after batching only reorders
# fixed groups of BATCH consecutive images.
train = train.shuffle(SHUFFLE).batch(BATCH).prefetch(FETCH)
test = test.batch(BATCH).prefetch(FETCH)

2023-03-16 01:36:42.785091: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


# 2. Functions to preprocess images into the new dataset

## 2.1 Functions to transform the image

In [5]:
# function to rotate images

def rotate_image(image, angle, not_print = True):
    '''
    Rotate an image about its centre, keeping the original canvas size,
    and optionally plot the result.
    Input params:
        image = image array,
        angle = rotation angle in degrees (counterclockwise),
        not_print = if True (the default) the rotated image is not plotted
    Return:
        image = rotated image [array like], same width and height as the input
    '''
    height, width = image.shape[:2]
    output_size = (width, height)  # OpenCV expects (width, height)
    pivot = (width / 2, height / 2)
    rotation = cv.getRotationMatrix2D(pivot, angle, 1.0)
    rotated = cv.warpAffine(image, rotation, output_size, flags=cv.INTER_LINEAR)
    if not_print:
        return rotated
    plt.imshow(rotated, cmap='gray')
    plt.show()
    return rotated

## 2.2 Functions to create a Pascal VOC annotation file

In [6]:
# Create a pascal voc format file from the dictionary data

def dict_to_xml(tag, d, attribute=None):
    '''
    Create a new XML element whose children are built from a dictionary.
    attributes:
    - tag : str name of the top (returned) element
    - d : dict like; every key becomes a child element whose text is str(value),
          in the dictionary's iteration order
    - attribute : optional dict of XML attributes for the top element (default=None).
          The mapping is copied, so later changes to the element (e.g. elem.set)
          do not leak back into the caller's dictionary.
    return: xml.etree.ElementTree.Element
    '''
    # ET.Element copies the attrib mapping; plain assignment to .attrib would alias it.
    elem = ET.Element(tag, dict(attribute) if attribute is not None else {})
    for key, val in d.items():
        ET.SubElement(elem, key).text = str(val)
    return elem


def add_dict_to_xml(xml, d):
    '''
    Append one child element per dictionary entry to an existing XML element.
    attributes:
    - xml : parent element that receives the new children (modified in place)
    - d : dict like; each key becomes a child tag whose text is str(value),
          appended in the dictionary's iteration order
    return: None
    '''
    parent = xml
    for tag_name, value in d.items():
        ET.SubElement(parent, tag_name).text = str(value)


def toPascalVocFormat(name_of_the_file, directory_name, dictionary_data,
                      width=144, height=144, depth=1):
    '''
    Write a Pascal VOC annotation file <name_of_the_file>.xml for one image.
    Attributes:
    - name_of_the_file: file name without extension (str); the annotation refers
      to the image <name_of_the_file>.jpg in the same directory
    - directory_name: directory path where the .xml is saved (str); also recorded
      as the image location in the <folder> and <path> tags
    - dictionary_data: list of dicts, one per object in the image, e.g.
      [{'name': 'Class_name', 'xmin': int, 'ymin': int, 'xmax': int, 'ymax': int}]
    - width, height, depth: image size written to the <size> tag
      (defaults match the 144x144 single-channel images generated in this notebook)
    return: None
    '''
    image_name = name_of_the_file + '.jpg'
    # <folder> holds the name of the directory that contains the image
    # (e.g. 'train'), not the path of that directory's parent.
    folder = os.path.basename(os.path.normpath(directory_name))
    basicXmlFileDict = {'folder': folder,
                        'filename': image_name,
                        'path': os.path.join(directory_name, image_name)}
    sizeDict = {'width': width, 'height': height, 'depth': depth}

    annotation = dict_to_xml('annotation', basicXmlFileDict,
                             attribute={'verified': 'no'})
    source = ET.SubElement(annotation, 'source')
    ET.SubElement(source, 'database').text = 'MNIST dataset for OD by Javi'
    size = ET.SubElement(annotation, 'size')
    ET.SubElement(annotation, 'segmented').text = '0'
    add_dict_to_xml(size, sizeDict)

    for item in dictionary_data:
        obj = ET.SubElement(annotation, 'object')
        add_dict_to_xml(obj, {'name': item['name'], 'pose': 'Unspecified',
                              'truncated': '0', 'difficult': '0'})
        bndbox = ET.SubElement(obj, 'bndbox')
        add_dict_to_xml(bndbox, {key: item[key]
                                 for key in ('xmin', 'ymin', 'xmax', 'ymax')})

    # Build the whole document before touching the disk so a failure cannot
    # leave an empty or truncated .xml behind; `with` guarantees the file is closed.
    pretty_xml_as_string = xml.dom.minidom.parseString(
        ET.tostring(annotation)).toprettyxml()
    with open(os.path.join(directory_name, name_of_the_file + '.xml'), 'w',
              encoding='utf-8') as f:
        f.write(pretty_xml_as_string)

## 2.3 Function to fine-tune the bounding boxes

In [7]:
# Creating a function to increase the accuracy of the bounding box on the image

def clossing_bnb(image, possition_list, threshold=255):
    '''
    Tighten square regions of an image to the bounding box of their non-dark content.
    Attributes:
    - image: ndarray of the image (rows x cols)
    - possition_list: list of the initial squares of the objects (list of lists)
      example point inside possition list: [initial_row, initial_col, side];
      the region is image[row:row+side, col:col+side]
    - threshold: a row/column of the region whose pixel sum is below this value
      is treated as background and trimmed (default 255, i.e. less than one
      fully white pixel)
    return: new position of every object as [x_min, y_min, x_max, y_max] in image
      coordinates (x = column, y = row). x_min/y_min are the first content pixel,
      x_max/y_max are one past the last content pixel. If a region has no content
      along rows or columns, its untrimmed square is returned and a warning issued.
    '''
    opt_bnb = list()
    for point in possition_list:
        row, col, side = point[0], point[1], point[2]
        roi = image[row:row + side, col:col + side]
        # Indices (inside the region) of rows/columns that hold content.
        content_rows = np.flatnonzero(roi.sum(axis=1) >= threshold)
        content_cols = np.flatnonzero(roi.sum(axis=0) >= threshold)
        if content_rows.size == 0 or content_cols.size == 0:
            warnings.warn(f'Region {point} has no content above threshold '
                          f'{threshold}; keeping the untrimmed box.')
            opt_bnb.append([col, row, col + side, row + side])
            continue
        opt_bnb.append([int(col + content_cols[0]), int(row + content_rows[0]),
                        int(col + content_cols[-1] + 1), int(row + content_rows[-1] + 1)])
    return opt_bnb

## 2.4 Main function to create the dataset

In [8]:
# function to create a dataset image

def _boxes_overlap(box_a, box_b):
    '''Return True when two [row, col, side] squares share at least one pixel
    (a square covers rows row..row+side-1 and columns col..col+side-1).'''
    row_a, col_a, side_a = box_a
    row_b, col_b, side_b = box_b
    return (row_a < row_b + side_b and row_b < row_a + side_a and
            col_a < col_b + side_b and col_b < col_a + side_a)


def _endless_batches(tf_dataset):
    '''Yield (images, labels) numpy batches from tf_dataset forever, starting a
    new pass over the dataset each time the previous pass is exhausted.'''
    while True:
        produced = False
        for batch in tf_dataset.as_numpy_iterator():
            produced = True
            yield batch
        if not produced:
            raise ValueError('tf_dataset does not yield any batch')


def image_generator(tf_dataset, number_of_images, directory_name):
    '''
    Create number_of_images synthetic 144x144 detection images from digit batches.

    Every image is a black canvas holding 1..batch_size digits taken from one
    batch. Each digit is resized to 14..56 px (0.5x to 2x of 28 px), rotated by a
    random angle in [-45, 45] degrees and placed at a random position that does
    not overlap a digit already on the canvas; a digit that finds no free spot
    within the allowed tries is dropped. For every image, <uuid>.jpg and a Pascal
    VOC <uuid>.xml annotation with the tightened boxes are written.
    Attributes:
    - tf_dataset: dataset yielding (images, labels) batches of MNIST digits
    - number_of_images: number of images to generate (int)
    - directory_name: existing directory where the .jpg/.xml files are saved (str)
    return: None
    raises: IOError if an image file cannot be written
    '''
    canvas_size = 144   # must match the <size> written by toPascalVocFormat
    digit_size = 28     # MNIST digit side length
    max_tries = 6       # placement attempts per digit before it is dropped
    # Keep a single iterator: re-creating it for every image restarts at the
    # first batch, so an unshuffled dataset would always give the same digits.
    batches = _endless_batches(tf_dataset)
    for _ in tqdm(range(number_of_images)):
        images, labels = next(batches)
        batch_size = len(images)  # per batch: the last batch may be smaller
        chosen = random.sample(range(batch_size), random.randint(1, batch_size))

        canvas = np.zeros((canvas_size, canvas_size), dtype=np.uint8)
        boxes = list()        # placed squares as [row, col, side]
        box_labels = list()
        for idx in chosen:
            # Random rotation and a scale of 0.5x to 2x the original size
            angle = random.randint(-45, 45)
            side = round((1.5 * random.random() + 0.5) * digit_size)
            for _attempt in range(max_tries):
                row = random.randint(0, canvas_size - side)
                col = random.randint(0, canvas_size - side)
                candidate = [row, col, side]
                if any(_boxes_overlap(candidate, placed) for placed in boxes):
                    continue
                canvas[row:row + side, col:col + side] = rotate_image(
                    cv.resize(images[idx], (side, side)), angle, not_print=True)
                boxes.append(candidate)
                box_labels.append(labels[idx])
                break

        tight_boxes = clossing_bnb(canvas, boxes)
        data_dic = [{'name': label, 'xmin': box[0], 'ymin': box[1],
                     'xmax': box[2], 'ymax': box[3]}
                    for label, box in zip(box_labels, tight_boxes)]
        fileName = str(uuid.uuid4())
        image_path = os.path.join(directory_name, fileName + '.jpg')
        # Write the image first so an annotation never points at a missing file.
        if not cv.imwrite(image_path, canvas):
            raise IOError(f'Could not write image {image_path}')
        toPascalVocFormat(fileName, directory_name, data_dic)

In [9]:
# Creating the dataset
NUMBER_OF_TRAINING_IMAGES = 2000
NUMBER_OF_TEST_IMAGES = 500

# The generator writes files straight into these folders and does not create
# them, so make sure they exist before generating.
for split_dir in (paths['train_data'], paths['test_data']):
    os.makedirs(split_dir, exist_ok=True)

image_generator(train, NUMBER_OF_TRAINING_IMAGES, paths['train_data'])
image_generator(test, NUMBER_OF_TEST_IMAGES, paths['test_data'])

100%|███████████████████████████████████████| 2000/2000 [01:30<00:00, 21.99it/s]
100%|████████████████████████████████████████| 500/500 [00:03<00:00, 140.48it/s]
