In [1]:
from glob import glob
import random
import os
import xml.etree.ElementTree as ET
import shutil
import numpy as np
import pandas as pd

In [2]:
# Root folder of the project. It can be overridden without editing the notebook
# by setting the TFI_CAZCARRA_PATH environment variable; when the variable is
# not set, the original hard-coded location is used.
PATH = os.environ.get("TFI_CAZCARRA_PATH", 'C:\\Users\\Diego\\Desktop\\TFI-Cazcarra')

In [3]:
# Name of the element set to process; it is interpolated into the image, XML
# and CSV paths built further down.
ELEMENT_TO_TRAIN = "diagramas"

# Fraction of all images that sep_images() puts in the training subset.
TRAIN_TEST_SPLIT = 0.7
# Fraction of the remaining (non-training) images that sep_images() puts in the
# test subset; whatever is left after that becomes the validation subset.
VAL_SPLIT = 0.5

In [4]:
# All inputs and outputs live under "<PATH>/data".
_DATA_DIR = f"{PATH}/data"
_CSV_DIR = f"{_DATA_DIR}/csv"

# Input folders (note the trailing slash: other cells concatenate file names
# and glob patterns directly onto these strings).
BASE_DIR_IMG = f"{_DATA_DIR}/imagenes_{ELEMENT_TO_TRAIN}/"
BASE_DIR_XML = f"{_DATA_DIR}/xml_{ELEMENT_TO_TRAIN}/"

# Output CSV files: the class list plus one annotation file per subset.
CLASSES_CSV = f"{_CSV_DIR}/classes_{ELEMENT_TO_TRAIN}.csv"
TRAIN_CSV = f"{_CSV_DIR}/train_{ELEMENT_TO_TRAIN}.csv"
TEST_CSV = f"{_CSV_DIR}/test_{ELEMENT_TO_TRAIN}.csv"
VAL_CSV = f"{_CSV_DIR}/val_{ELEMENT_TO_TRAIN}.csv"

In [5]:
# Glob patterns for the annotation files and the images.
_xml_pattern = BASE_DIR_XML + "*.xml"
_png_pattern = BASE_DIR_IMG + "*.png"

# Every annotation file followed by every image.
FILES = glob(_xml_pattern) + glob(_png_pattern)
# Images only, globbed separately so this is an independent list.
IMAGES = glob(os.path.join(BASE_DIR_IMG, "*.png"))

In [6]:
# Labels that are exported to the CSV files. The order matters: write_classes()
# numbers them starting at 1, and generate_train_test() emits rows in this order.
CLASSES = [
    'tabla',
    'muchos_opcional',
    'muchos_obligatorio',
    'uno_opcional',
    'uno_obligatorio',
]

In [19]:
def sep_images(images, train_test_split=None, val_split=None):
    """Randomly partition ``images`` into train, test and validation subsets.

    Parameters
    ----------
    images : sequence
        Items (image paths) to split. The sequence itself is left untouched:
        a shuffled copy is sliced instead.
    train_test_split : float, optional
        Fraction of all items placed in the training subset. Defaults to the
        module-level ``TRAIN_TEST_SPLIT``.
    val_split : float, optional
        Fraction of the non-training items placed in the *test* subset; the
        remainder becomes the validation subset. Defaults to the module-level
        ``VAL_SPLIT``.

    Returns
    -------
    tuple of lists
        ``(train_images, test_images, val_images)``; together they contain
        every input item exactly once.
    """
    if train_test_split is None:
        train_test_split = TRAIN_TEST_SPLIT
    if val_split is None:
        val_split = VAL_SPLIT

    # Work on the argument (not on a global) and on a copy, so the function
    # gives the right answer for any list and does not reorder the caller's data.
    shuffled = list(images)
    random.shuffle(shuffled)

    n_train = int(len(shuffled) * train_test_split)
    n_test = int((len(shuffled) - n_train) * val_split)

    train_images = shuffled[:n_train]
    test_images = shuffled[n_train:n_train + n_test]
    val_images = shuffled[n_train + n_test:]

    return train_images, test_images, val_images

In [43]:
def _annotation_rows(image_path, xml_dir, classes):
    """Return the CSV rows (lists of strings) for one image's annotation file."""
    # The annotation shares the image's base name. basename/splitext work with
    # the separators of the running OS and with names that have no extension.
    stem = os.path.splitext(os.path.basename(image_path))[0]
    annot_path = os.path.join(xml_dir, stem + ".xml")

    root = ET.parse(annot_path).getroot()
    size = root.find("size")
    h = int(size.find("height").text)
    w = int(size.find("width").text)

    objects = list(root.iter("object"))
    abs_image_path = os.path.abspath(image_path)

    rows = []
    # Outer loop over classes keeps the rows grouped in class order.
    for label in classes:
        for obj in objects:
            if obj.find("name").text != label:
                continue

            box = obj.find("bndbox")
            # Truncate any bounding box coordinates that fall outside
            # the boundaries of the image.
            xmin = max(0, int(box.find("xmin").text))
            ymin = max(0, int(box.find("ymin").text))
            xmax = min(w, int(box.find("xmax").text))
            ymax = min(h, int(box.find("ymax").text))

            # Ignore boxes with no area (minimum not strictly below the
            # maximum), which come from annotation errors.
            if xmin >= xmax or ymin >= ymax:
                continue

            rows.append([abs_image_path, str(xmin), str(ymin), str(xmax), str(ymax), str(label)])
    return rows


def generate_train_test(dataset, xml_dir=None, classes=None):
    """Write one bounding-box CSV per subset described in ``dataset``.

    For every image the XML file with the same base name is read from
    ``xml_dir``; each ``object`` element whose ``name`` is one of ``classes``
    produces a row ``image_path,xmin,ymin,xmax,ymax,label``. Coordinates are
    clamped to the image size read from the ``size`` element and boxes without
    area are skipped. Fields are joined with plain commas (no quoting).

    Parameters
    ----------
    dataset : iterable of (str, list, str)
        ``(subset_name, image_paths, output_csv)`` triples.
    xml_dir : str, optional
        Folder holding the annotation files. Defaults to ``BASE_DIR_XML``.
    classes : list of str, optional
        Labels to export, in output order. Defaults to ``CLASSES``.

    Raises
    ------
    OSError, xml.etree.ElementTree.ParseError
        If an annotation file is missing or cannot be parsed.
    """
    if xml_dir is None:
        xml_dir = BASE_DIR_XML
    if classes is None:
        classes = CLASSES

    header_row = ["image_path","xmin", "ymin", "xmax", "ymax", "label"]

    for (dtype, image_paths, output_csv) in dataset:
        print ("[INFO] creating '{}' set...".format(dtype))
        print ("[INFO] {} total images in '{}' set".format(len(image_paths), dtype))

        # The context manager closes the file even if an annotation fails to parse.
        with open(output_csv, "w") as out_file:
            out_file.write("{}\n".format(",".join(header_row)))
            for image_path in image_paths:
                for row in _annotation_rows(image_path, xml_dir, classes):
                    out_file.write("{}\n".format(",".join(row)))

In [14]:
def write_classes(output_csv=None, classes=None):
    """Write the class-name to integer-encoding table as a CSV file.

    The file starts with the header ``nombre,encoding`` followed by one
    ``<name>,<number>`` line per class, numbered from 1 in list order. No
    newline is written after the last row.

    Parameters
    ----------
    output_csv : str, optional
        Destination path. Defaults to ``CLASSES_CSV``.
    classes : list of str, optional
        Class names to write. Defaults to ``CLASSES``.
    """
    if output_csv is None:
        output_csv = CLASSES_CSV
    if classes is None:
        classes = CLASSES

    print("[INFO] writing classes...")

    header_row = ["nombre", "encoding"]
    rows = [",".join([name, str(number)]) for number, name in enumerate(classes, start=1)]

    # The context manager guarantees the file is closed even if a write fails.
    with open(output_csv, "w") as out_file:
        out_file.write("{}\n".format(",".join(header_row)))
        out_file.write("\n".join(rows))

In [45]:
def split_wrapper():
    """Run one full split: partition IMAGES, write the three subset CSVs, then the classes CSV."""
    subset_names = ("train", "test", "val")
    subset_csvs = (TRAIN_CSV, TEST_CSV, VAL_CSV)

    # sep_images returns the subsets in (train, test, val) order, matching
    # the two tuples above.
    subsets = sep_images(IMAGES)
    dataset = list(zip(subset_names, subsets, subset_csvs))

    # Annotation CSVs first, class table afterwards.
    generate_train_test(dataset)
    write_classes()

In [120]:
def get_val_count(df):
    """Return the number of occurrences of each distinct value in the last column of ``df``."""
    last_column = df.columns[-1]
    counts = df[last_column].value_counts()
    return counts

# Re-split until every subset contains at least one box of every class.
# The number of attempts is bounded so the cell fails with a clear message,
# instead of looping forever, when such a split cannot be found.
MAX_SPLIT_ATTEMPTS = 100

for _attempt in range(MAX_SPLIT_ATTEMPTS):
    split_wrapper()

    train = pd.read_csv(TRAIN_CSV)
    test = pd.read_csv(TEST_CSV)
    val = pd.read_csv(VAL_CSV)

    # A subset is complete when its label column shows as many distinct
    # values as there are classes.
    if all(len(get_val_count(subset)) >= len(CLASSES) for subset in (train, test, val)):
        break
else:
    raise RuntimeError(
        "No split with all {} classes in train, test and val was found after {} attempts".format(
            len(CLASSES), MAX_SPLIT_ATTEMPTS
        )
    )

[INFO] creating 'train' set...
[INFO] 28 total images in 'train' set
[INFO] creating 'test' set...
[INFO] 6 total images in 'test' set
[INFO] creating 'val' set...
[INFO] 6 total images in 'val' set
[INFO] writing classes...


In [121]:
get_val_count(train)

tabla                 208
uno_obligatorio       197
muchos_obligatorio    175
uno_opcional           34
muchos_opcional        21
Name: label, dtype: int64

In [122]:
get_val_count(test)

tabla                 33
muchos_obligatorio    25
uno_obligatorio       25
uno_opcional          12
muchos_opcional        9
Name: label, dtype: int64

In [123]:
get_val_count(val)

tabla                 31
uno_obligatorio       25
muchos_obligatorio    15
uno_opcional           9
muchos_opcional        8
Name: label, dtype: int64