In [3]:
import numpy as np
import cv2
import os
import glob
import zipfile
from os.path import basename
import shutil
import numpy as np
import sys
import pandas as pd 

# Whether to generate 100x120 pixel images (False) or 1000x1200 pixel images (True).
HD = False

# Grid of cells in an image: `height` rows (vertical) by `width` columns (horizontal).
# HD mode makes the grid ten times denser in each direction.
height, width = (100, 120) if HD else (10, 12)

N = height * width  # total number of cells in the grid
size = 10  # side length of one cell, in pixels

# Rules generated one after the other when the user picks rule 1 ("all").
rules = [2, 3, 4]

## Internal image representation: each cell stores only a shape code and a color code.
# Color code (as a string key) -> color tuple in OpenCV's BGR channel order.
colorDict = {
    "0": (200, 0, 0),      # blue
    "1": (0, 200, 0),      # green
    "2": (0, 0, 200),      # red
    "3": (100, 0, 100),    # purple
    "4": (0, 250, 250),    # yellow
    "5": (125, 0, 250),    # pink
    "6": (0, 0, 0),        # black
    "7": (0, 125, 250),    # orange
    "8": (50, 50, 125),    # brown
    "9": (125, 125, 125),  # gray
}

# Shape codes: 1 circle, 2 square, 3 cross, 4 triangle (0 marks an empty cell).
all_shapes = [1, 2, 3, 4]
all_colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
line_thickness = 1  # pixel thickness of the two strokes of a cross

draw = False  # debugging aid: display every image as it is saved

project_path = "C:/Users/alexf/Interpretability/Generating"

def main():
    """ Ask the user for the generation parameters, then build and save the dataset.
    """
    (rule,
     n_train_0, n_train_1,
     n_val_0, n_val_1,
     n_test_0, n_test_1,
     random_placement, noise) = getUserInput()

    # The grid dimensions come from the module-level `height` / `width`.
    main_functions(rule,
                   n_train_0, n_train_1,
                   n_val_0, n_val_1,
                   n_test_0, n_test_1,
                   height, width,
                   random_placement, noise)
    print("FINISHED GENERATING THE DATASET")

def main_functions(rule, nTr0, nTr1, nV0, nV1, nTs0, nTs1, height, width, rd_pl, n):
    """ Create and save a dataset.

    Rule 1 means "all rules": the function calls itself once for every rule in the
    module-level ``rules`` list. Any other rule generates ``nTr0 + nV0 + nTs0``
    label-0 images and ``nTr1 + nV1 + nTs1`` label-1 images. It splits them into
    train/val/test and writes them under ``data<rule>`` (``data<rule>HD`` in HD
    mode), after deleting that directory from any previous run.

    Parameters
    ----------
    rule : int in {1,2,3,4}
        the chosen rule
    nTr0 : int
        number of images with label 0 in training set
    nTr1 : int
        number of images with label 1 in training set
    nV0 : int
        number of images with label 0 in validation set
    nV1 : int
        number of images with label 1 in validation set
    nTs0 : int
        number of images with label 0 in test set
    nTs1 : int
        number of images with label 1 in test set
    height : int
        number of cells in an image (vertical direction)
    width : int
        number of cells in an image (horizontal direction)
    rd_pl : bool
        whether placement should be random
    n : float
        scale (standard deviation) of the Gaussian noise added to the images
    """
    if rule == 1:
        # "All rules": generate one independent dataset per rule.
        for r in rules:
            main_functions(r, nTr0, nTr1, nV0, nV1, nTs0, nTs1, height, width, rd_pl, n)
        return

    images_0, images_1 = createData(rule, nTr0 + nV0 + nTs0, nTr1 + nV1 + nTs1, height, width)
    data_path = "data" + str(rule)
    if HD:
        data_path += "HD"
    # saveDataset creates data_path with os.makedirs, which fails if it already exists.
    removePreviousData(data_path)
    train_0, train_1, val_0, val_1, test_0, test_1 = split_data(
        nTr0, nTr1, nV0, nV1, nTs0, nTs1, images_0, images_1)
    saveDataset(data_path, train_0, train_1, val_0, val_1, test_0, test_1, rd_pl, n, rule)
        

def split_data(nTr0, nTr1, nV0, nV1, nTs0, nTs1, img0, img1):
    """ Randomly assign the images of each label to the training, validation and test sets.

    For each label, ``nTr`` slots are tagged training, the next ``nV`` slots are
    tagged validation, and every remaining slot is tagged test. The tags are then
    shuffled with ``np.random.shuffle``. ``nTs0`` / ``nTs1`` are not read: whatever
    is left after training and validation goes to the test set.

    Parameters
    ----------
    nTr0 : int
        number of images with label 0 in training set
    nTr1 : int
        number of images with label 1 in training set
    nV0 : int
        number of images with label 0 in validation set
    nV1 : int
        number of images with label 1 in validation set
    nTs0 : int
        number of images with label 0 in test set (unused, see above)
    nTs1 : int
        number of images with label 1 in test set (unused, see above)
    img0 : list of images
        the images with label 0
    img1 : list of images
        the images with label 1

    Returns
    -------
    Train0, Train1, Val0, Val1, Test0, Test1 : lists of images
    """

    def tag_images(count, n_train, n_val):
        # Tag codes: 0 -> train, 1 -> validation, 2 -> test.
        tags = np.zeros(count)
        for idx in range(n_train, count):
            tags[idx] = 1 if idx < n_train + n_val else 2
        np.random.shuffle(tags)
        return tags

    def partition(images, tags):
        train, val, test = [], [], []
        for tag, image in zip(tags, images):
            if tag == 1:
                val.append(image)
            elif tag == 2:
                test.append(image)
            else:
                train.append(image)
        return train, val, test

    # Label 0 is tagged (and shuffled) before label 1.
    tags0 = tag_images(len(img0), nTr0, nV0)
    tags1 = tag_images(len(img1), nTr1, nV1)

    train0, val0, test0 = partition(img0, tags0)
    train1, val1, test1 = partition(img1, tags1)
    return train0, train1, val0, val1, test0, test1

def status(p):
    """ Report progress on the current console line.

    Only whole-number percentages are printed. The line ends with a carriage
    return, so each report overwrites the previous one.

    Parameters
    ----------
    p : float
        an estimation of the percentage of tasks finished
    """
    if p % 1 != 0:
        return
    print(str(p) + "% done", end="\r")

def getUserInput():
    """ Ask the user which rule and which parameters to generate the dataset with.

    Note: when random placement is chosen, this also sets the module-level
    ``dist_fract``. It is the minimum allowed distance between the centers of two
    shapes, expressed as a fraction of a shape's size.

    Returns
    -------
    rule : int in {1,2,3,4}
        the chosen rule
    nTr0 : int
        number of images with label 0 in training set
    nTr1 : int
        number of images with label 1 in training set
    nV0 : int
        number of images with label 0 in validation set
    nV1 : int
        number of images with label 1 in validation set
    nTs0 : int
        number of images with label 0 in test set
    nTs1 : int
        number of images with label 1 in test set
    rd : bool
        whether placement should be random
    noise : float
        scale (standard deviation) of the Gaussian noise added to the images
    """
    global dist_fract

    label_question_string = "Labelling rule: \n 1: all \n 2: there is a cross"
    label_question_string += "\n 3: there is a red triangle \n 4: there is no cross and there is a black circle \n"
    rule = int(input(label_question_string))

    nTr0 = int(input("Number of images in the training set with label 0 : "))
    nTr1 = int(input("Number of images in the training set with label 1 : "))
    nV0 = int(input("Number of images in the validation set with label 0 : "))
    nV1 = int(input("Number of images in the validation set with label 1 : "))
    nTs0 = int(input("Number of images in the testing set with label 0 : "))
    nTs1 = int(input("Number of images in the testing set with label 1 : "))
    # Accept "Y", " y " etc. as a yes.
    rd = input("Random placement ? [y/n]").strip().lower() == 'y'
    if rd:
        msg = "What is the minimal distance between two centers ? (expressed as a fraction of the size of a shape)"
        dist_fract = float(input(msg))
    # A standard deviation need not be an integer.
    noise = float(input("Standard deviation of Gaussian noise:"))

    return rule, nTr0, nTr1, nV0, nV1, nTs0, nTs1, rd, noise

def createData(rule, N0, N1, h, w):
    """ Generate all the images for one rule, first those labelled 0 and then those labelled 1.

    Parameters
    ----------
    rule : int in {1,2,3,4}
        the chosen rule
    N0 : int
        the number of images with label 0
    N1 : int
        the number of images with label 1
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)

    Returns
    -------
    IM0 : list of images
        the images with label 0
    IM1 : list of images
        the images with label 1
    """
    print("Creating the data")
    total = N0 + N1
    done = 0
    by_label = {0: [], 1: []}
    for label, count in ((0, N0), (1, N1)):
        for _ in range(count):
            by_label[label].append(createImage(rule, label, h, w))
            done += 1
            status(100 * done / total)
    return by_label[0], by_label[1]

def createImage(rule, label, h, w):
    """ Create one image following the given rule by dispatching to createImage2/3/4.

    Parameters
    ----------
    rule : int in {2,3,4}
        the chosen rule (rule 1, "all", is expanded by the caller)
    label : int in {0,1}
        the label of the image to create
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)

    Returns
    -------
    the created image

    Raises
    ------
    SystemExit
        with status 2 when ``rule`` is not 2, 3 or 4
    """
    if rule == 2:
        return createImage2(label, h, w)
    if rule == 3:
        return createImage3(label, h, w)
    if rule == 4:
        return createImage4(label, h, w)
    # `quit` is installed by the `site` module for the interactive shell only;
    # sys.exit is the supported way for a program to stop with a status code.
    print("Key error", file=sys.stderr)
    sys.exit(2)

def removePreviousData(data_path):
    """ Delete ``data_path`` and everything below it; do nothing if it does not exist.

    Parameters
    ----------
    data_path : string
        the path to remove
    """
    if not os.path.exists(data_path):
        return
    shutil.rmtree(data_path)

def saveDataset(data_path, train_0, train_1, val_0, val_1, test_0, test_1, rd, noise, rule):
    """ Save every split as PNG images, then bundle them into a zip archive.

    Layout: ``<data_path>/<split>/<label>/`` for split in train/val/test and label
    in 0/1. For rule 2 there is also ``<data_path>/<split>/1/positions/``, which
    holds the cross centres written by ``saveImage``. ``data_path`` must not
    exist yet, because ``os.makedirs`` raises ``FileExistsError`` otherwise; clear
    it first with ``removePreviousData``.

    The archive is written inside ``data_path``. Its name is ``random.zip`` when
    placement is random, otherwise ``noisy.zip`` when ``noise > 0``, and
    otherwise ``data.zip``.

    Parameters
    ----------
    data_path : string
        base directory where to save the data
    train_0 : list of images
        the training images with label 0
    train_1 : list of images
        the training images with label 1
    val_0 : list of images
        the validation images with label 0
    val_1 : list of images
        the validation images with label 1
    test_0 : list of images
        the testing images with label 0
    test_1 : list of images
        the testing images with label 1
    rd : bool
        whether placement should be random
    noise : float
        scale of the Gaussian noise added to the images
    rule : int in {1,2,3,4}
        the chosen rule
    """
    splits = (
        ("train", train_0, train_1),
        ("val", val_0, val_1),
        ("test", test_0, test_1),
    )

    os.makedirs(data_path)
    for split, _, _ in splits:
        os.makedirs(data_path + "/" + split + "/0")
        os.makedirs(data_path + "/" + split + "/1")
        if rule == 2:
            os.makedirs(data_path + "/" + split + "/1/positions")

    print("Saving dataset")
    n = sum(len(imgs_0) + len(imgs_1) for _, imgs_0, imgs_1 in splits)
    k = 0
    for split, imgs_0, imgs_1 in splits:
        set_path = data_path + "/" + split
        for label, images in ((False, imgs_0), (True, imgs_1)):
            for i, im in enumerate(images):
                saveImage(set_path, i, im, label, draw, rd, noise, rule)
                k += 1
                status(100 * k / n)

    if rd:
        zip_name = "random.zip"
    elif noise > 0:
        zip_name = "noisy.zip"
    else:
        zip_name = "data.zip"

    print("Moving to zip")
    k = 0
    # Read the files back from data_path itself, which is exactly where saveImage
    # wrote them (relative to the current directory). A project_path-based glob
    # silently yields an empty archive whenever the script runs from elsewhere.
    with zipfile.ZipFile(data_path + "/" + zip_name, "w") as archive:
        for split, _, _ in splits:
            for label in ("0", "1"):
                pattern = data_path + "/" + split + "/" + label + "/*.png"
                for file in sorted(glob.glob(pattern)):
                    archive.write(file, split + "/" + label + "/" + basename(file))
                    k += 1
                    status(100 * k / n)

        if rule == 2:
            print("Adding the positions to zip")
            for split, _, _ in splits:
                pattern = data_path + "/" + split + "/1/positions/*.csv"
                for file in sorted(glob.glob(pattern)):
                    archive.write(file, split + "/1/positions/" + basename(file))


                

def _pick_random_center(centers):
    """ Draw random centres until ``centerAccepted`` accepts one, record it in ``centers`` and return it.

    Never returns if the image is too crowded for the current ``dist_fract``.
    """
    while True:
        c1, c2 = getRandomCenter()
        center = (int(c1), int(c2))  # plain ints for OpenCV point arguments
        if centerAccepted(center, centers):
            centers.append(center)
            return center


def saveImage(set_path, number, im, label, draw, random, noise, rule):
    """ Render one image from its summary representation to pixels and save it as a PNG.

    Each cell ``im[i, j]`` of the module-level ``height`` x ``width`` grid holds a
    shape code (0 empty, 1 circle, 2 square, 3 cross, 4 triangle) and a color key
    of ``colorDict``. Shapes are drawn in their own cell. When ``random`` is true
    they are instead drawn at random centres that keep the ``dist_fract`` spacing.
    The PNG goes to ``<set_path>/<0|1>/im<0|1>_<number>.png``. For rule 2 and
    label 1, the cross centres are also written to
    ``<set_path>/1/positions/positions_<number>.csv``.

    Parameters
    ----------
    set_path : string
        base directory of the set where to save the data
    number : int
        the index of the image within its set
    im : array
        the image to save, in summary representation
    label : int in {0,1}
        the label of the image
    draw : bool
        whether or not to show the image (for debugging purposes)
    random : bool
        whether or not to place shapes randomly
    noise : float
        the scale of the Gaussian noise to add to the image
    rule : int in {1,2,3,4}
        the chosen rule

    Raises
    ------
    OSError
        if OpenCV fails to write the PNG file
    """
    image = 255 * np.ones((height * size, width * size, 3), dtype=np.uint8)
    filled = -1  # negative thickness: OpenCV fills the shape
    extent = size * 0.8  # side of the bounding box of squares, crosses and triangles
    half_extent = int(size * 0.4)  # centre -> top-left offset of that bounding box
    centers = []  # centres already used by random placement
    cross_centers = []  # cross centres, saved for rule 2

    for i in range(height):
        for j in range(width):
            shape = im[i, j, 0]
            if shape not in (1, 2, 3, 4):
                continue
            color = colorDict[str(im[i, j, 1])]
            if random:
                center = _pick_random_center(centers)
                # Centre the 0.8*size box on the chosen centre, as in grid placement.
                top_left = (center[0] - half_extent, center[1] - half_extent)
            else:
                center = (int(size * (j + 0.5)), int(size * (i + 0.5)))  # (x, y)
                top_left = (int(size * (j + 0.1)), int(size * (i + 0.1)))
            bottom_right = (int(top_left[0] + extent), int(top_left[1] + extent))

            if shape == 1:  # circle
                radius = int(0.9 * (size / 2))
                image = cv2.circle(image, center, radius, color, filled)
            elif shape == 2:  # square
                image = cv2.rectangle(image, top_left, bottom_right, color, filled)
            elif shape == 3:  # cross: the two diagonals of the bounding box
                if rule == 2:
                    cross_centers.append(np.array(center))
                top_right = (int(top_left[0] + extent), int(top_left[1]))
                bottom_left = (int(top_left[0]), int(top_left[1] + extent))
                image = cv2.line(image, top_left, bottom_right, color, line_thickness)
                image = cv2.line(image, top_right, bottom_left, color, line_thickness)
            else:  # triangle, apex at the middle of the top edge
                bottom_left = (int(top_left[0]), int(top_left[1] + extent))
                top = (int(0.5 * (bottom_left[0] + bottom_right[0])), top_left[1])
                # OpenCV contours must be int32 points (np.array of ints is int64 on Linux).
                triangle_cnt = np.array([bottom_left, bottom_right, top], dtype=np.int32)
                cv2.drawContours(image, [triangle_cnt], 0, color, filled)

    # Adding gaussian noise
    if noise > 0:
        noisy = image + np.random.normal(0, noise, size=image.shape)
        # Clip to the 8-bit range and keep uint8: cv2.imwrite cannot encode a
        # platform `int` (int64) array as PNG.
        image = np.clip(noisy, 0, 255).astype(np.uint8)

    label_dir = "1" if label else "0"
    filename = set_path + "/" + label_dir + "/im" + label_dir + "_" + str(number).zfill(5) + ".png"
    if not cv2.imwrite(filename, image):
        # imwrite reports failures (e.g. a missing directory) only through its return value.
        raise OSError("could not write image " + filename)

    if draw:
        cv2.imshow('Window', image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    if rule == 2 and label:
        fn = set_path + "/1/positions/positions_" + str(number).zfill(5) + ".csv"
        pd.DataFrame(np.array(cross_centers)).to_csv(fn, header=None, index=None)
        
    
        
def getRandomCenter():
    """ Draw a random (x, y) pixel position for the centre of a new shape.

    x is drawn from ``[size/2, (width - 0.5) * size)`` and y from
    ``[size/2, (height - 0.5) * size)`` (bounds truncated to int), with x drawn first.

    Returns
    -------
    A tuple of ints.
    """
    lowest = int(size / 2)
    x = np.random.randint(low=lowest, high=int((width - 0.5) * size))
    y = np.random.randint(low=lowest, high=int((height - 0.5) * size))
    return (x, y)
           
def centerAccepted(c, centers):
    """ Tell whether a new centre keeps the minimum spacing to every existing centre.

    The minimum spacing is ``dist_fract * size`` (module-level values).

    Parameters
    ----------
    c : tuple of ints
        the coordinates of the new center
    centers : list of tuples of int
        the pre-existing centers

    Returns
    -------
    acc : bool
        Whether or not the center has been accepted
    """
    too_close = any(distance(other, c) < dist_fract * size for other in centers)
    return not too_close
        
def distance(c1, c2):
    """ Return the Euclidean distance between two 2-D points.

    Parameters
    ----------
    c1 : tuple of ints
        the coordinates of the first center
    c2 : tuple of ints
        the coordinates of the second center

    Returns
    -------
    the distance between c1 and c2
    """
    dx = c1[0] - c2[0]
    dy = c1[1] - c2[1]
    return np.sqrt(dx ** 2 + dy ** 2)


def createImage2(label, h, w):
    """ Create an image following rule 2: label 1 iff there is at least one cross.

    Parameters
    ----------
    label : int in {0,1}
        the label of the image to create
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)

    Returns
    -------
    im : array of dimensions (h,w,2)
        the created image
    """
    density = 10  # mean of the exponential draw of the shape count
    # Cap with this image's own cell count rather than the global N.
    nb_shapes = numberOfShapes(density, h * w)
    cross = 3
    if label:
        # Random image; add a cross of random color at a random cell if there is none.
        im = createRandomImage(h, w, nb_shapes, all_shapes, all_colors)
        if not is_shape(im, cross):
            color = np.random.choice(all_colors)
            x = np.random.randint(0, len(im))
            y = np.random.randint(0, len(im[0]))
            im = add_element(im, x, y, cross, color)
    else:
        # Random image drawn from every shape except the cross.
        shapes = [s for s in all_shapes if s != cross]
        im = createRandomImage(h, w, nb_shapes, shapes, all_colors)
    return im

def is_shape(image, shape):
    """ Tell whether at least one cell of ``image`` holds the given shape code.

    Parameters
    ----------
    image : array
        the image to check, indexed as image[i, j, 0] for the shape code
    shape : int
        the shape to check

    Returns
    -------
    Whether or not the shape was found in the image at least once.
    """
    return any(cell[0] == shape for row in image for cell in row)

def createImage3(label, h, w):
    """ Create an image following rule 3: label 1 iff there is at least one red triangle.

    Parameters
    ----------
    label : int in {0,1}
        the label of the image to create
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)

    Returns
    -------
    im : array of dimensions (h,w,2)
        the created image
    """
    density = 10  # mean of the exponential draw of the shape count
    # Cap with this image's own cell count rather than the global N.
    nb_shapes = numberOfShapes(density, h * w)
    triangle = 4
    red = 2
    im = createRandomImage(h, w, nb_shapes, all_shapes, all_colors)
    if label:
        # Add a red triangle at a random cell if there is none.
        if not is_element(im, triangle, red):
            x = np.random.randint(0, len(im))
            y = np.random.randint(0, len(im[0]))
            im = add_element(im, x, y, triangle, red)
    else:
        im = remove_all_elements(im, triangle, red)  # no red triangle may remain
    return im

def createImage4(label, h, w):
    """ Create an image following rule 4: label 1 iff there is no cross and at least one black circle.

    Parameters
    ----------
    label : int in {0,1}
        the label of the image to create
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)

    Returns
    -------
    im : array of dimensions (h,w,2)
        the created image
    """
    density = 10  # mean of the exponential draw of the shape count
    # Cap with this image's own cell count rather than the global N.
    nb_shapes = numberOfShapes(density, h * w)
    cross = 3
    circle = 1
    black = 6
    if label:
        # Random image without crosses; add a black circle if there is none.
        shapes = [s for s in all_shapes if s != cross]
        im = createRandomImage(h, w, nb_shapes, shapes, all_colors)
        if not is_element(im, circle, black):
            x = np.random.randint(0, len(im))
            y = np.random.randint(0, len(im[0]))
            im = add_element(im, x, y, circle, black)
    else:
        # Random image. A cross already makes it label 0; without one, the black
        # circles must be removed so that the image is label 0.
        im = createRandomImage(h, w, nb_shapes, all_shapes, all_colors)
        if not is_shape(im, cross):
            im = remove_all_elements(im, circle, black)
    return im



def is_element(im, shape, color):
    """ Tell whether some cell holds the given shape with the given color.

    Parameters
    ----------
    im : array
        the image to check, indexed as im[i, j, 0] (shape) and im[i, j, 1] (color)
    shape : int
        the shape to check
    color : int
        the color to check

    Returns
    -------
    Whether or not the shape of the given color was found in the image at least once.
    """
    im = np.asarray(im)
    if im.size == 0:
        return False
    # Both conditions are evaluated per cell. The former `a == s & b == c` parsed as
    # the chained comparison `a == (s & b) == c`, which was practically never true.
    matches = (im[..., 0] == shape) & (im[..., 1] == color)
    return bool(matches.any())

def remove_all_elements(im, shape, color):
    """ Clear, in place, the shape code of every cell holding the given shape and color.

    Only the shape channel is reset to 0; the color code is left as it was.

    Parameters
    ----------
    im : array
        the image to filter (modified in place)
    shape : int
        the shape to remove
    color : int
        the color to remove

    Returns
    -------
    im : array
        the same (modified) image object
    """
    if len(im) == 0:
        return im
    matches = (im[:, :, 0] == shape) & (im[:, :, 1] == color)
    im[matches, 0] = 0
    return im


def add_element(image, x, y, shape, color):
    """ Write a shape code and a color code into cell (x, y) of ``image``, in place.

    Parameters
    ----------
    image : array
        the image to modify
    x : int
        the first (row) coordinate of the cell to change
    y : int
        the second (column) coordinate of the cell to change
    shape : int
        the shape to add
    color : int
        the color to add

    Returns
    -------
    image : array
        the same (modified) image object
    """
    image[x, y, 0], image[x, y, 1] = shape, color
    return image



def numberOfShapes(density, nmax):
    """ Draw how many shapes to put in one image.

    The count is one draw of ``np.random.exponential(density)``, whose mean
    (numpy's ``scale``) is ``density``. The draw is rounded up and capped at
    ``nmax``.

    Parameters
    ----------
    density : float
        mean (scale) of the exponential distribution to sample from
    nmax : int
        the maximum bound on the number of shapes

    Returns
    -------
    n : float or int
        the number of shapes to add in an image (``nmax`` itself when capped)
    """
    sample = np.random.exponential(density)
    return min(np.ceil(sample), nmax)

def createRandomImage(h, w, n, shapes, colors):
    """ Create a random image with ``n`` shapes in distinct, uniformly chosen cells.

    Parameters
    ----------
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)
    n : int or float
        number of shapes to place. It is rounded up and capped at ``h*w``;
        ``n <= 0`` gives an empty image.
    shapes : list
        shapes allowed in the image (each cell draws one uniformly)
    colors : list
        colors allowed in the image (each cell draws one uniformly)

    Returns
    -------
    im : array of dimensions (h,w,2)
        the image created
    """
    im = np.zeros((h, w, 2), dtype=int)
    n_cells = h * w
    count = int(min(n_cells, max(0, np.ceil(n))))
    if count == 0:
        return im
    # Sample distinct flat cell indices so every cell is equally likely to be used.
    # Filling the first cells and then shuffling rows and then columns never puts
    # two shapes in the same column when n <= w.
    for flat in np.random.choice(n_cells, size=count, replace=False):
        i, j = divmod(int(flat), w)
        im[i, j, 0] = np.random.choice(shapes)
        im[i, j, 1] = np.random.choice(colors)
    return im

def onlyRedTriangles(h, w):
    """ Create and save an image in which every cell holds a randomly placed red triangle.

    The image is written to ``<project_path>/custom/1/im1_00001.png``. This sets
    the module-level ``dist_fract`` to 0.5 because random placement reads it.
    ``saveImage`` renders the module-level ``height`` x ``width`` grid, so ``h``
    and ``w`` are expected to match those values.

    Parameters
    ----------
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)
    """
    global dist_fract
    dist_fract = 0.5  # minimum centre distance, as a fraction of a cell's size

    triangle, red = 4, 2
    im = np.zeros((h, w, 2), dtype=int)
    im[:, :, 0] = triangle
    im[:, :, 1] = red

    # Clear, create and save into the same directory. Saving to the cwd-relative
    # "custom" while creating it under project_path fails outside project_path.
    base = project_path + "/custom"
    path = base + "/1"
    removePreviousData(path)
    os.makedirs(path)
    saveImage(base, 1, im, True, draw, True, 0, None)

def backgroundNoise(h, w, noise):
    """ Create and save a shape-free image made only of background noise.

    The image is written to ``<project_path>/custom/0/im0_00001.png``.
    ``saveImage`` renders the module-level ``height`` x ``width`` grid, so ``h``
    and ``w`` are expected to match those values.

    Parameters
    ----------
    h : int
        number of cells in an image (vertical direction)
    w : int
        number of cells in an image (horizontal direction)
    noise : float
        the scale of the Gaussian noise in the image
    """
    im = np.zeros((h, w, 2), dtype=int)  # every cell empty

    # Clear, create and save into the same directory. Saving to the cwd-relative
    # "custom" while creating it under project_path fails outside project_path.
    base = project_path + "/custom"
    path = base + "/0"
    removePreviousData(path)
    os.makedirs(path)
    saveImage(base, 1, im, False, draw, False, noise, None)
    
    
    
    
# Interactive entry point. The guard keeps imports of this module from prompting;
# in a notebook __name__ is "__main__", so the cell still runs it.
if __name__ == "__main__":
    a = int(input("Call main (1) or create custom image (2) ?"))

    if a == 1:
        main()
    elif a == 2:
        b = int(input("Red triangles (1) or background noise (2)"))
        if b == 1:
            onlyRedTriangles(height, width)
        elif b == 2:
            # A standard deviation need not be an integer.
            noise = float(input("Standard deviation of Gaussian noise:"))
            backgroundNoise(height, width, noise)
        else:
            print("Invalid key", file=sys.stderr)
    else:
        print("Invalid key", file=sys.stderr)

Call main (1) or create custom image (2) ? 1
Labelling rule: 
 1: all 
 2: there is a cross
 3: there is a red triangle 
 4: there is no cross and there is a black circle 
 1
Number of images in the training set with label 0 :  1
Number of images in the training set with label 1 :  1
Number of images in the validation set with label 0 :  1
Number of images in the validation set with label 1 :  1
Number of images in the testing set with label 0 :  1
Number of images in the testing set with label 1 :  1
Random placement ? [y/n] n
Standard deviation of Gaussian noise: 0


Creating the data
Saving dataset
Moving to zip
Adding the positions to zip
Creating the data
Saving dataset
Moving to zip
Creating the data
Saving dataset
Moving to zip
FINISHED GENERATING THE DATASET
