# In [None]:
import numpy as np
import cv2
import os
from PIL import Image


def read_image_and_name(path):
    """Load every image in a directory with OpenCV, ordered by file-name number.

    Files are sorted by the integer formed by ALL leading digits of their
    name (``01_test.tif`` -> 1, ``21_test.tif`` -> 21, ``100_x.png`` -> 100),
    with the full name as a tie-breaker so the order is deterministic.

    Args:
        path: Directory containing the images.

    Returns:
        A tuple ``(imglst, imgs)``: the full path of each file and the
        corresponding arrays returned by ``cv2.imread``, in the same order.

    Raises:
        ValueError: if a file name does not start with a digit.
        OSError: if OpenCV cannot read/decode one of the files.
    """
    import re

    def _sort_key(name):
        # Use every leading digit, not just the first two characters, so
        # one-digit or three-digit prefixes still sort numerically.
        match = re.match(r"\d+", name)
        if match is None:
            raise ValueError(f"file name {name!r} does not start with a number")
        return int(match.group()), name

    imgdir = sorted(os.listdir(path), key=_sort_key)
    # os.path.join works whether or not `path` ends with a separator.
    imglst = [os.path.join(path, name) for name in imgdir]
    imgs = []
    for file_path in imglst:
        img = cv2.imread(file_path)
        if img is None:  # cv2.imread reports failure by returning None, not raising
            raise OSError(f"could not read image {file_path!r}")
        imgs.append(img)
    print(imglst)
    print('original images shape: ' + str(np.array(imgs).shape))
    return imglst, imgs


def read_label_and_name(path):
    """Load every label image in a directory with PIL, ordered by file-name number.

    Files are sorted by the integer formed by ALL leading digits of their
    name (``01_manual1.gif`` -> 1), with the full name as a tie-breaker so
    the order is deterministic.

    Args:
        path: Directory containing the label images.

    Returns:
        A tuple ``(labellst, labels)``: the full path of each file and the
        pixel data of each file as a NumPy array, in the same order.

    Raises:
        ValueError: if a file name does not start with a digit.
    """
    import re

    def _sort_key(name):
        # Use every leading digit, not just the first two characters, so
        # one-digit or three-digit prefixes still sort numerically.
        match = re.match(r"\d+", name)
        if match is None:
            raise ValueError(f"file name {name!r} does not start with a number")
        return int(match.group()), name

    labeldir = sorted(os.listdir(path), key=_sort_key)
    # os.path.join works whether or not `path` ends with a separator.
    labellst = [os.path.join(path, name) for name in labeldir]
    labels = []
    for file_path in labellst:
        # np.asarray copies the pixels out of the image, so the file can be
        # closed right away instead of leaking the handle until GC.
        with Image.open(file_path) as label_img:
            labels.append(np.asarray(label_img))
    print(labellst)
    print('original labels shape: ' + str(np.array(labels).shape))
    return labellst, labels


def resize(imgs, resize_height, resize_width):
    """Resize every image to ``resize_height`` rows by ``resize_width`` columns.

    Args:
        imgs: Iterable of images accepted by ``cv2.resize``.
        resize_height: Target number of rows.
        resize_width: Target number of columns.

    Returns:
        A list of the resized images, in input order.
    """
    # cv2.resize takes dsize as (width, height). Passing (height, width)
    # would transpose the target size for non-square outputs.
    dsize = (resize_width, resize_height)
    return [cv2.resize(img, dsize) for img in imgs]


def crop(image, dx):
    """Cut a stack of images into non-overlapping ``dx`` x ``dx`` patches.

    Example: 20 images of 576x576 with ``dx=48`` give 20 * 12 * 12 = 2880
    patches, i.e. an array of shape ``(2880, 48, 48)``.

    Within each image, patches are emitted column by column: all patches of
    the first column strip top-to-bottom, then the next column strip, and so
    on. Pixels that do not fill a whole patch at the bottom/right edge are
    dropped.

    Args:
        image: Array indexed as ``image[i, row, col, ...]``; any trailing
            axes (e.g. channels) are kept in each patch.
        dx: Patch side length in pixels.

    Returns:
        Array of shape ``(N * n_rows * n_cols, dx, dx, ...)``. When no
        patch fits, an empty array with that trailing shape is returned.
    """
    n_rows = image.shape[1] // dx  # patches along axis 1 (rows)
    n_cols = image.shape[2] // dx  # patches along axis 2 (columns)
    patches = [
        image[i, y * dx:(y + 1) * dx, x * dx:(x + 1) * dx]
        for i in range(image.shape[0])
        for x in range(n_cols)  # column strip (outer) ...
        for y in range(n_rows)  # ... then rows within it (inner)
    ]
    if not patches:
        return np.empty((0, dx, dx) + tuple(image.shape[3:]), dtype=image.dtype)
    return np.array(patches)


def pred_to_imgs(pred, patch_height, patch_width, mode="original"):
    """Convert per-pixel network output back into image-shaped patches.

    Args:
        pred: Array of shape ``(Npatches, patch_height * patch_width, 2)``.
            ``pred[:, :, 0]`` is the non-segmentation output and
            ``pred[:, :, 1]`` the segmentation output.
        patch_height: Rows per patch.
        patch_width: Columns per patch.
        mode: ``"original"`` keeps the segmentation values as they are;
            ``"threshold"`` maps values ``>= 0.5`` to 1 and all others to 0.

    Returns:
        float64 array of shape ``(Npatches, 1, patch_height, patch_width)``.

    Raises:
        ValueError: if ``pred`` is not 3-D with exactly 2 classes, or if
            ``mode`` is not recognized.
    """
    if pred.ndim != 3:
        raise ValueError(
            f"pred must be 3-D (Npatches, height*width, 2), got shape {pred.shape}")
    if pred.shape[2] != 2:
        raise ValueError(f"pred must have 2 classes, got {pred.shape[2]}")
    segmentation = pred[:, :, 1]
    if mode == "original":
        # Raw segmentation output, copied into a fresh float64 array.
        pred_images = segmentation.astype(np.float64)
    elif mode == "threshold":
        # Binarize at 0.5 (NaN compares False, so it maps to 0).
        pred_images = (segmentation >= 0.5).astype(np.float64)
    else:
        # Raise instead of exit(): library code must not kill the interpreter.
        raise ValueError(
            "mode " + str(mode) + " not recognized, it can be 'original' or 'threshold'")
    # Add a singleton channel axis: (Npatches, 1, patch_height, patch_width).
    return np.reshape(pred_images, (pred_images.shape[0], 1, patch_height, patch_width))