In [None]:
import numpy as np
import cv2
import os
import glob
import zipfile
from os.path import basename
import shutil


height = 4  # number of squares in the vertical direction of the image
width = 4  # number of squares in the horizontal direction of the image
N = height * width  # total number of squares in the grid
size = 8  # side length of one square, in pixels
rules = [1, 2, 3, 4]  # every labelling rule; rule 0 at the prompt means "all of these"

# Color index (stored as a string key) -> BGR tuple, in OpenCV's channel order
# (e.g. (200, 0, 0) is blue).
colorDict = {
    "0": (200, 0, 0),  # blue
    "1": (0, 200, 0),  # green
    "2": (0, 0, 200),  # red
    "3": (100, 0, 100),  # purple
    "4": (0, 250, 250),  # yellow
    "5": (125, 0, 250),  # pink
    "6": (0, 0, 0),  # black
    "7": (0, 125, 250),  # orange
    "8": (50, 50, 125),  # brown
    "9": (125, 125, 125),  # gray
}

# Shape codes stored in channel 0 of a grid cell:
# 0 = empty, 1 = circle, 2 = square, 3 = cross, 4 = triangle.
all_shapes = [1, 2, 3, 4]
all_colors = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]  # keys of colorDict, as ints
line_thickness = 1  # stroke width in pixels used when drawing crosses

draw = False  # For testing, draw every image created
# Root directory under which the dataset folders are looked up when zipping.
# The folders are created relative to the current working directory, so default
# to it instead of a machine-specific hard-coded path.
project_path = os.getcwd()

def main():
    """Ask the user for the generation settings, then build and save the dataset(s)."""
    n_label0, n_label1, chosen_rule, holdout = getUserInput()
    main_functions(chosen_rule, n_label0, n_label1, height, width, holdout)
    print("FINISHED GENERATING THE DATASET")

def main_functions(rule, n0, n1, height, width, f):
    """Generate, split and save the dataset for ``rule``; ``rule == 0`` runs every rule.

    Parameters
    ----------
    rule : int
        Labelling rule (1-4), or 0 to recurse over every entry of ``rules``.
    n0, n1 : int
        Number of images to create with label 0 and label 1.
    height, width : int
        Grid dimensions passed to ``createData``.
    f : float
        Fraction of each class moved to the test split (see ``split_train_test``).
    """
    if rule == 0:
        for single_rule in rules:
            main_functions(single_rule, n0, n1, height, width, f)
        return

    negatives, positives = createData(rule, n0, n1, height, width)
    target_dir = "data" + str(rule)
    # Start from a clean folder: saveDataset refuses to create an existing one.
    removePreviousData(target_dir)
    train_0, train_1, test_0, test_1 = split_train_test(negatives, positives, f)
    saveDataset(target_dir, train_0, train_1, test_0, test_1, rule)
        

def _shuffled_train_mask(count, fraction):
    """Return a shuffled bool mask of length ``count``; ``int(fraction * count)`` entries are False (test)."""
    mask = np.full(count, True)
    for position in range(int(fraction * count)):
        mask[position] = False
    np.random.shuffle(mask)
    return mask


def split_train_test(img0, img1, p):
    """Randomly split each class into a train list and a test list.

    For each class, ``int(p * len(class))`` items go to the test list and the rest
    to the train list. Items keep their original relative order inside each list.

    Returns
    -------
    tuple of list
        ``(train_0, train_1, test_0, test_1)``.
    """
    # Build both masks before selecting, in the same order as before (class 0 then
    # class 1), so the global NumPy RNG is consumed identically.
    keep0 = _shuffled_train_mask(len(img0), p)
    keep1 = _shuffled_train_mask(len(img1), p)

    train_0 = [item for item, is_train in zip(img0, keep0) if is_train]
    test_0 = [item for item, is_train in zip(img0, keep0) if not is_train]
    train_1 = [item for item, is_train in zip(img1, keep1) if is_train]
    test_1 = [item for item, is_train in zip(img1, keep1) if not is_train]
    return train_0, train_1, test_0, test_1

def status(p):
    """Print ``<p>% done`` (carriage-return terminated) when ``p`` is an exact multiple of 10."""
    if p % 10:
        return
    print(str(p) + "% done", end="\r")

def getUserInput():
    """Prompt on stdin for the dataset settings.

    Returns
    -------
    tuple
        ``(n0, n1, rule, validation_fraction)``: the number of label-0 and label-1
        images (int), the labelling rule (int, 0 meaning all rules) and the fraction
        of each class to hold out (float).
    """
    n0 = int(input("Number of images with label 0 : "))
    n1 = int(input("Number of images with label 1 : "))
    rule_prompt = (
        "Labelling rule: \n 0: all \n 1: there is a shape \n 2: there is a cross"
        "\n 3: there is a red triangle \n 4: there is no cross and there is a black circle \n"
    )
    rule = int(input(rule_prompt))
    validation_fraction = float(input("Input validation fraction: "))
    return n0, n1, rule, validation_fraction

def createData(rule, n0, n1, h, w):
    """Create ``n0`` label-0 and ``n1`` label-1 grids of size ``h`` x ``w`` for ``rule``.

    Progress is reported through ``status`` after every image.

    Returns
    -------
    tuple of list
        ``(label_0_images, label_1_images)``.
    """
    print("Creating the data")
    total = n0 + n1
    produced = 0
    per_label = []
    for label, count in ((0, n0), (1, n1)):
        batch = []
        for _ in range(count):
            batch.append(createImage(rule, label, h, w))
            produced += 1
            status(100 * produced / total)
        per_label.append(batch)
    return per_label[0], per_label[1]

def createImage(rule, label, h, w):
    """Build one (h, w, 2) grid whose label under ``rule`` equals ``label``.

    Dispatches to ``createImage1`` .. ``createImage4``.

    Raises
    ------
    ValueError
        If ``rule`` is not one of 1, 2, 3, 4. Previously this printed "Key error"
        and called ``quit(2)``. ``quit`` is an interactive-only helper, and in a
        notebook it terminates the kernel instead of reporting the bad input.
    """
    if rule == 1:
        return createImage1(label, h, w)  # contains information about the shapes and colors in the image
    if rule == 2:
        return createImage2(label, h, w)
    if rule == 3:
        return createImage3(label, h, w)
    if rule == 4:
        return createImage4(label, h, w)
    raise ValueError("Unknown labelling rule: " + str(rule) + " (expected 1, 2, 3 or 4)")

def removePreviousData(data_path):
    """Recursively delete ``data_path`` if it exists; do nothing otherwise."""
    if not os.path.exists(data_path):
        return
    shutil.rmtree(data_path)

def saveDataset(data_path, train_0, train_1, test_0, test_1, rule):
    """Write every split as PNG files under ``data_path`` and archive them.

    Produces ``<data_path>/{train,test}/{0,1}/im<label>_<i>.png`` and the archive
    ``<data_path>/data<rule>.zip``, which holds the same ``{train,test}/{0,1}`` tree.

    Raises
    ------
    FileExistsError
        If ``data_path`` already exists. Callers are expected to remove old data
        first, so stale images never end up in the archive.
    """
    os.makedirs(data_path)
    for subset in ("train", "test"):
        for label_dir in ("0", "1"):
            os.makedirs(os.path.join(data_path, subset, label_dir))

    splits = (
        ("train", False, train_0),
        ("train", True, train_1),
        ("test", False, test_0),
        ("test", True, test_1),
    )
    n = sum(len(images) for _, _, images in splits)

    print("Saving dataset")
    k = 0
    for subset, label, images in splits:
        for i, im in enumerate(images):
            saveImage(data_path + "/" + subset, i, im, label, draw)
            k += 1
            status(100 * k / n)

    print("Moving to zip")
    k = 0
    archive_path = os.path.join(data_path, "data" + str(rule) + ".zip")
    # ``with`` closes the archive even if a write fails.
    with zipfile.ZipFile(archive_path, "w") as archive:
        for subset in ("train", "test"):
            for label_dir in ("0", "1"):
                # Glob relative to data_path, the same place saveImage wrote to,
                # so this works whatever the current working directory is.
                pattern = os.path.join(data_path, subset, label_dir, "*.png")
                for file in sorted(glob.glob(pattern)):
                    archive.write(file, subset + "/" + label_dir + "/" + basename(file))
                    k += 1
                    status(100 * k / n)


def _draw_shape(image, shape, i, j, color):
    """Draw shape code ``shape`` into grid cell (row ``i``, column ``j``) of ``image``, in place."""
    if shape == 1:  # filled circle
        center = (int(size * (j + 0.5)), int(size * (i + 0.5)))  # (x, y)
        radius = int(0.9 * (size / 2))
        cv2.circle(image, center, radius, color, -1)
        return
    # The remaining shapes share a box inset by 10% of the cell on each side.
    top_left = (int(size * (j + 0.1)), int(size * (i + 0.1)))
    bottom_right = (int(top_left[0] + size * 0.8), int(top_left[1] + size * 0.8))
    if shape == 2:  # filled square
        cv2.rectangle(image, top_left, bottom_right, color, -1)
    elif shape == 3:  # cross: the two diagonals of the box
        top_right = (int(top_left[0] + size * 0.8), int(top_left[1]))
        bottom_left = (int(top_left[0]), int(top_left[1] + size * 0.8))
        cv2.line(image, top_left, bottom_right, color, line_thickness)
        cv2.line(image, top_right, bottom_left, color, line_thickness)
    elif shape == 4:  # filled triangle, apex at the top-center of the box
        bottom_left = (int(top_left[0]), int(top_left[1] + size * 0.8))
        top = (int(0.5 * (bottom_left[0] + bottom_right[0])), top_left[1])
        # OpenCV only accepts int32 (or float32) contour points. NumPy's default
        # int is 64-bit on Linux/macOS (and on Windows since NumPy 2), so a
        # dtype-less array would be rejected by drawContours there.
        triangle_cnt = np.array([bottom_left, bottom_right, top], dtype=np.int32)
        cv2.drawContours(image, [triangle_cnt], 0, color, -1)


def saveImage(set_path, number, im, label, draw):
    """Render a shape grid to a PNG file and optionally display it.

    Parameters
    ----------
    set_path : str
        Split directory (e.g. ``data1/train``). The file goes into its ``1`` or
        ``0`` subfolder, as ``im1_<number>.png`` or ``im0_<number>.png``.
    number : int
        Index used in the file name.
    im : np.ndarray
        Grid indexed ``im[row, col, channel]``. Channel 0 is the shape code
        (0 = empty) and channel 1 is a color index into ``colorDict``. Only the
        first ``height`` x ``width`` cells are drawn.
    label : bool
        True selects the ``1`` subfolder/prefix, False the ``0`` one.
    draw : bool
        If true, show the image in an OpenCV window and wait for a key press.

    Raises
    ------
    OSError
        If OpenCV fails to write the file. ``cv2.imwrite`` signals failure by
        returning False rather than raising.
    """
    image = 255 * np.ones((height * size, width * size, 3), dtype=np.uint8)  # white canvas
    for i in range(height):
        for j in range(width):
            shape = im[i, j, 0]
            if shape != 0:
                _draw_shape(image, shape, i, j, colorDict[str(im[i, j, 1])])

    filename = set_path + "/" + ("1/im1_" if label else "0/im0_") + str(number) + ".png"
    if not cv2.imwrite(filename, image):
        raise OSError("could not write image " + filename)
    if draw:
        cv2.imshow('Window', image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

def createImage1(label, h, w):  # RULE: 1 if there is at least one shape
    """Return an (h, w, 2) grid for rule 1: label 1 iff the grid contains a shape.

    Label 0 is an all-zero (empty) grid. Label 1 holds a random number of shapes,
    and ``createRandomImage`` always places at least one on a non-empty grid.
    """
    if not label:
        return np.zeros((h, w, 2), dtype=int)
    density = 5
    # Cap by this grid's own cell count rather than the module-level N, so the
    # cap stays right when h/w differ from the global height/width.
    nb_shapes = numberOfShapes(density, h * w)
    return createRandomImage(h, w, nb_shapes, all_shapes, all_colors)

def createImage2(label, h, w):  # RULE: 1 if there is at least one cross
    """Return an (h, w, 2) grid for rule 2: label 1 iff the grid contains a cross (shape 3).

    Label 1 adds a random-color cross in a random cell when the random fill has
    none. Label 0 draws shapes only from the non-cross shapes.
    """
    density = 10
    # Cap by this grid's own cell count rather than the module-level N.
    nb_shapes = numberOfShapes(density, h * w)
    cross = 3
    if label:
        im = createRandomImage(h, w, nb_shapes, all_shapes, all_colors)
        if not is_shape(im, cross):
            color = np.random.choice(all_colors)
            x = np.random.randint(0, len(im))
            y = np.random.randint(0, len(im[0]))
            im = add_element(im, x, y, cross, color)
    else:
        shapes = [s for s in all_shapes if s != cross]  # There can't be a cross
        im = createRandomImage(h, w, nb_shapes, shapes, all_colors)
    return im

def is_shape(image, shape):
    """Return True if any cell of the (rows, cols, 2) grid has shape code ``shape`` in channel 0."""
    return bool(np.any(image[:, :, 0] == shape))

def createImage3(label, h, w):  # RULE: 1 if there is at least one red triangle
    """Return an (h, w, 2) grid for rule 3: label 1 iff it contains a red (color 2) triangle (shape 4).

    Starts from a random fill. For label 0 every red triangle is removed. For
    label 1 a red triangle is placed in a random cell if ``is_element`` finds none.
    """
    density = 10
    # Cap by this grid's own cell count rather than the module-level N.
    nb_shapes = numberOfShapes(density, h * w)
    triangle = 4
    red = 2
    im = createRandomImage(h, w, nb_shapes, all_shapes, all_colors)
    if label:
        if not is_element(im, triangle, red):
            x = np.random.randint(0, len(im))
            y = np.random.randint(0, len(im[0]))
            im = add_element(im, x, y, triangle, red)
    else:
        im = remove_all_elements(im, triangle, red)
    return im

def createImage4(label, h, w):  # RULE: 1 if there is no cross and a black circle
    """Return an (h, w, 2) grid for rule 4: label 1 iff there is no cross and a black circle exists.

    Label 1 is filled without crosses, and a black (color 6) circle (shape 1) is
    added in a random cell if ``is_element`` finds none. Label 0 is a random fill.
    If that fill happens to contain no cross, every black circle is removed so the
    rule does not hold.
    """
    density = 10
    # Cap by this grid's own cell count rather than the module-level N.
    nb_shapes = numberOfShapes(density, h * w)
    cross = 3
    circle = 1
    black = 6
    if label:
        shapes = [s for s in all_shapes if s != cross]
        im = createRandomImage(h, w, nb_shapes, shapes, all_colors)
        if not is_element(im, circle, black):
            x = np.random.randint(0, len(im))
            y = np.random.randint(0, len(im[0]))
            im = add_element(im, x, y, circle, black)
    else:
        im = createRandomImage(h, w, nb_shapes, all_shapes, all_colors)
        if not is_shape(im, cross):
            im = remove_all_elements(im, circle, black)
    return im



def is_element(im, shape, color):
    """Return True if some cell of the (rows, cols, 2) grid has both this shape code and this color index.

    The previous body used ``a == shape & b == color``. Because ``&`` binds tighter
    than ``==``, that parsed as the chained comparison ``a == (shape & b) == color``
    and essentially never matched. The two tests are now combined element-wise.
    """
    matches = (im[:, :, 0] == shape) & (im[:, :, 1] == color)
    return bool(np.any(matches))

def remove_all_elements(im, shape, color):
    """Clear, in place, every cell holding this shape code with this color index; return ``im``.

    Cleared cells get shape 0 and color 0, the same as an untouched empty cell.
    The old body set the shape to 1 (circle), as its own "SHOULD BE 0" note said.
    That turned e.g. red triangles into red circles, and for rule 4 it left every
    black circle in place, so some label-0 images actually satisfied the rule.
    """
    mask = (im[:, :, 0] == shape) & (im[:, :, 1] == color)
    im[mask] = 0  # a 2-D mask over a 3-D array zeroes both channels of each matching cell
    return im


def add_element(image, x, y, shape, color):
    """Set cell (row ``x``, column ``y``) to ``shape``/``color`` in place and return ``image``."""
    image[x, y, 0], image[x, y, 1] = shape, color
    return image



def numberOfShapes(density, nmax):
    """Draw a shape count: ceil of an exponential sample with mean ``density``, capped at ``nmax``."""
    sample = np.random.exponential(density)
    return min(np.ceil(sample), nmax)

def createRandomImage(h, w, n, shapes, colors):  # n is typically drawn from numberOfShapes
    """Return an (h, w, 2) int grid holding about ``n`` random shapes.

    The first ``max(1, ceil(n))`` cells in row-major order (at most ``h * w``) get a
    random shape from ``shapes`` and a random color from ``colors``. The cells of
    each row are then shuffled, followed by the cells of each column. At least one
    shape is placed on any non-empty grid, even when ``n <= 0``.
    """
    grid = np.zeros((h, w, 2), dtype=int)
    placed = 0
    for cell in range(h * w):
        row, col = divmod(cell, w)
        grid = add_element(grid, row, col, np.random.choice(shapes), np.random.choice(colors))
        placed += 1
        if placed >= n:
            break
    for row in range(h):
        np.random.shuffle(grid[row])  # permute the cells of this row
    for col in range(w):
        np.random.shuffle(grid[:, col])  # permute the cells of this column
    return grid

# Run only when executed as a script or notebook cell (both have __name__ == "__main__"),
# not when imported, so importing this module does not start the interactive prompts.
if __name__ == "__main__":
    main()

Number of images with label 0 :  100000
Number of images with label 1 :  100000
Labelling rule: 
 0: all 
 1: there is a shape 
 2: there is a cross
 3: there is a red triangle 
 4: there is no cross and there is a black circle 
 0
Input validation fraction:  0.2


Creating the data
Saving dataset
Moving to zip
20.0% done