## Create Datasets

In [None]:
def prep_data(band,target,cropx,cropy):
    """Centre-crop one band, split it into train/test/val parts and rescale them.

    Parameters
    ----------
    band : 2-D array holding a single band; passed straight to ``crop_center``.
    target : split fraction forwarded to the ``crop_*set`` helpers.
    cropx, cropy : crop extents forwarded to ``crop_center``.

    Returns
    -------
    tuple ``(train, test, val)``; each part is multiplied by 255 and cast
    to ``uint16``.
    """
    import numpy as np
    centred = crop_center(band,cropx,cropy)

    # Split first (train, test, val), then rescale every part the same way.
    raw_train = crop_trainset(centred,target)
    raw_test = crop_testset(centred,target)
    raw_val = crop_valset(centred,target)
    scaled = [(part*255).astype('uint16') for part in (raw_train,raw_test,raw_val)]

    return scaled[0],scaled[1],scaled[2]

    
def split_dataset(img,target,cropx,cropy):
    """Split a multi-band image into train/test/val stacks of its first three bands.

    The last axis of ``img`` is moved to the front so that each band can be
    handed to ``prep_data`` on its own. Every band is processed, but only the
    first three results of each split are stacked and returned.

    Parameters
    ----------
    img : 3-D array whose last axis is iterated band by band.
    target, cropx, cropy : forwarded unchanged to ``prep_data``.

    Returns
    -------
    tuple ``(train_stk, test_stk, val_stk, total_time)`` where each stack has
    the band axis moved back to the end and ``total_time`` is the elapsed
    wall-clock time in seconds.
    """
    import numpy as np
    import time

    started = time.time()
    bands_first = np.einsum('ijk->kij',img)

    train_parts, test_parts, val_parts = [], [], []
    for band in bands_first:
        train, test, val = prep_data(band,target,cropx,cropy)
        # Wrap each 2-D part in a list so vstack gains a leading band axis.
        train_parts.append([train])
        test_parts.append([test])
        val_parts.append([val])

    def _stack_first_three(parts):
        # Keep bands 0..2 only, stack them band-first, then move bands last.
        return np.einsum('kij->ijk',np.vstack(parts[0:3]))

    train_stk = _stack_first_three(train_parts)
    test_stk = _stack_first_three(test_parts)
    val_stk = _stack_first_three(val_parts)

    total_time = time.time()-started

    return train_stk, test_stk,val_stk, total_time

## Cropping Functions

In [2]:
def get_crop_dims(x,patch_size):
    """Return ``x / patch_size`` truncated toward zero, multiplied by ``patch_size``.

    For non-negative ``x`` this is the largest multiple of ``patch_size`` that
    does not exceed ``x``.
    """
    whole_patches = int(x/patch_size)
    return whole_patches*patch_size

def crop_center(img,cropx,cropy):
    """Return the central window of a 2-D array.

    Parameters
    ----------
    img : 2-D array indexed as ``img[row, col]``.
    cropx : width of the window in columns.
    cropy : height of the window in rows.

    Returns
    -------
    The ``cropy`` x ``cropx`` slice centred on the middle of ``img``. The
    slice bounds are also printed as ``startx endx starty endy``.
    """
    # A NumPy shape is (rows, cols), i.e. (height, width). Unpacking it as
    # (x, y) swaps the axes and mis-centres the crop on non-square images.
    height, width = img.shape
    startx = width//2-(cropx//2)
    starty = height//2-(cropy//2)
    print(startx, startx+cropx, starty, starty+cropy)
    return img[starty:starty+cropy,startx:startx+cropx]

def crop_trainset(img,target):
    """Return the training part of a 2-D array: the rows after the first ``2*target`` fraction.

    Parameters
    ----------
    img : 2-D array indexed as ``img[row, col]``.
    target : fraction of the rows reserved for each of the two other splits
        (e.g. 0.2 leaves the last 60% of the rows for training).

    Returns
    -------
    Rows ``int(rows*2*target)`` to the end, at full width. The shape and the
    slice bounds are printed.
    """
    # A NumPy shape is (rows, cols). Rows are what gets split, so the split
    # point must come from the row count and the width from the column count.
    height, width = img.shape
    print(img.shape)
    startx = 0
    endx = width
    starty = int(height*(2*target))
    endy = height
    print(startx, endx, starty, endy)
    return img[starty:endy,startx:endx]

def crop_testset(img,target):
    """Return the testing part of a 2-D array: the first ``target`` fraction of its rows.

    Parameters
    ----------
    img : 2-D array indexed as ``img[row, col]``.
    target : fraction of the rows to return (e.g. 0.2 for 20%).

    Returns
    -------
    Rows ``0`` to ``int(rows*target)``, at full width. The slice bounds are
    printed.
    """
    # A NumPy shape is (rows, cols). Rows are what gets split, so the split
    # point must come from the row count and the width from the column count.
    height, width = img.shape
    startx = 0
    endx = width
    starty = 0
    endy = int(height*target)
    print(startx, endx, starty, endy)
    return img[starty:endy,startx:endx]

def crop_valset(img,target):
    """Return the validation part of a 2-D array: the second ``target`` fraction of its rows.

    Parameters
    ----------
    img : 2-D array indexed as ``img[row, col]``.
    target : fraction of the rows to return (e.g. 0.2 for 20%).

    Returns
    -------
    Rows ``int(rows*target)`` to ``int(rows*2*target)``, at full width. The
    slice bounds are printed.
    """
    # A NumPy shape is (rows, cols). Rows are what gets split, so the split
    # points must come from the row count and the width from the column count.
    height, width = img.shape
    startx = 0
    endx = width
    starty = int(height*target)
    endy = int(height*(2*target))
    print(startx, endx, starty, endy)
    return img[starty:endy,startx:endx]

## Tiling Functions

In [None]:
def make_folders(out_path):
    """Create the ``train``, ``test`` and ``val`` folders, each with an ``organized`` subfolder.

    ``out_path`` is used as a plain string prefix, so it must already end with
    a path separator. Folders that already exist are left alone.
    """
    import os
    for split in ("train/","test/","val/"):
        for leaf in ("","organized/"):
            folder = out_path+split+leaf
            if not os.path.exists(folder):
                os.makedirs(folder)

In [1]:
def get_pts(mask):
    """Return a list of ``(axis-1 index, axis-0 index)`` pairs for every element of ``mask`` equal to 1."""
    import numpy as np
    hits = np.where(np.asarray(mask)==1)
    # np.where returns one index array per axis, axis 0 first. The pairs are
    # emitted with the axis-1 index leading.
    return list(zip(hits[1],hits[0]))

In [1]:
def make_tiles(in_path,out_path,good_pts,image,annotation,xmax,ymax,patch_size,in_pad):
    """Cut a padded tile around every point and save it in a folder named after its annotation count.

    For each point a square window is sliced from ``image`` and ``annotation``.
    The first three bands of the image window are validated with
    ``check_NoData``. Valid windows are stacked in reverse band order
    (band 3, 2, 1), surrounded by a zero border of ``pad`` pixels and written
    to ``out_path/<count>/<count>_<x>_<y>_pad_<pad>.png``, where ``count`` is
    the number of non-zero annotation pixels in the window.

    Parameters
    ----------
    in_path : path; normalised but not otherwise used.
    out_path : root folder for the per-count output folders.
    good_pts : iterable of 2-item points; item 1 indexes the first image axis
        and item 0 the second.
    image : array sliced on its first two axes and then indexed by band on the
        third; assumed to be (rows, cols, bands) with at least three bands,
        TODO confirm against the caller.
    annotation : array sliced with the same window as ``image``.
    xmax, ymax : unused.
    patch_size : window edge length; overridden when ``in_pad == -1``.
    in_pad : border width in pixels. ``-1`` draws a random border in
        ``range(33)`` per point and sets the window to ``128 - 2*pad``.

    Returns
    -------
    tuple ``(patch_list, total_time)``: the list of ``str(count)+name``
    entries for the tiles written, and the elapsed seconds.
    """
    import numpy as np
    import os, time, cv2, random
    patch_list = []

    in_path = os.path.abspath(in_path)
    out_path = os.path.abspath(out_path)

    t1 = time.time()
    for pt in good_pts:
        x = int(pt[1])
        y = int(pt[0])

        if in_pad == -1:
            pad = random.randrange(33)
            patch_size = int(128-(pad*2))
        else:
            pad = in_pad
        filename = str(x)+"_"+str(y)+"_pad_"+str(pad)

        step = int(patch_size/2)
        x1, x2 = x-step, x+step
        y1, y2 = y-step, y+step

        img = image[x1:x2,y1:y2]
        tc_img = annotation[x1:x2,y1:y2]

        # Move the band axis to the front (same layout clip_pts works with),
        # so img[0], img[1], img[2] are the individual bands.
        img = np.einsum('kij->jik',img)

        # check_NoData returns -9999 for an unusable band, so its result has
        # to be kept for the validity test below to mean anything.
        b1 = check_NoData(img[0])
        b2 = check_NoData(img[1])
        b3 = check_NoData(img[2])

        if not (np.all(b1 != -9999) and np.all(b2 != -9999) and np.all(b3 != -9999)):
            continue

        # Bands are stacked in reverse order (3, 2, 1) before writing.
        patch = np.dstack([b3,b2,b1])
        padded_patch = cv2.copyMakeBorder(patch,pad,pad,pad,pad,cv2.BORDER_CONSTANT,value=0)
        width, height = padded_patch.shape[:2]
        full_tile = patch_size+(2*pad)

        # Windows that run off the image edge come out too small; skip them.
        if width != full_tile:
            print("tile size invalid, width ",width)
            continue
        if height != full_tile:
            print("tile size invalid, height ",height)
            continue

        # single channel input = annotation file
        ones = np.count_nonzero(tc_img)
        # NOTE(review): this entry has no "_" after the count, unlike the
        # file name written below; confirm which form callers expect.
        patch_list.append(str(ones)+filename)
        out_path_join = os.path.join(out_path,str(ones))
        os.makedirs(out_path_join,exist_ok=True)
        img_name = str(ones)+"_"+filename+'.png'
        cv2.imwrite(os.path.join(out_path_join,img_name),padded_patch)

    t2 = time.time()
    total_time = t2-t1
    return patch_list,total_time

## Data Checks

In [3]:
def check_NoData(array):
    """Return ``array`` unchanged when it holds usable data, otherwise the sentinel ``-9999``.

    An array is rejected when it contains any NaN or when every element is zero.
    """
    import numpy as np
    contains_nan = np.isnan(array).any()
    if contains_nan or np.all(array == 0):
        return -9999
    return array

## Other Functions (possibly unused)

In [None]:
def refine_pts(in_pts,size,xmax,ymax):
    """Keep the points that lie at least ``size`` away from both ends of each axis.

    A point ``(x, y)`` is kept when ``size <= x <= xmax-size`` and
    ``size <= y <= ymax-size``. Kept points are returned as ``[x, y]`` lists
    in their original order.
    """
    def _inside(value,upper):
        return size <= value <= (upper-size)

    # Read both coordinates before testing, so malformed points fail the same way.
    coords = ((pt[0],pt[1]) for pt in in_pts)
    return [[px,py] for px,py in coords if _inside(px,xmax) and _inside(py,ymax)]

In [None]:
def get_sum(in_path,out_path,datalist):
    """Write the non-zero pixel count of each listed ``.png`` tile to a ``.txt`` file.

    Parameters
    ----------
    in_path : folder scanned for ``.png`` files.
    out_path : string prefix for the output files; the result goes to
        ``out_path + <name> + ".txt"``, so include a trailing separator.
    datalist : container of tile names (file name minus its last four
        characters); only tiles whose name is in it are processed.
    """
    import numpy as np
    import os
    import ntpath
    from PIL import Image
    in_path = os.path.abspath(in_path)
    for entry in os.listdir(in_path):
        if not entry.endswith(".png"):
            continue
        filename = ntpath.basename(entry)[:-4]
        if filename not in datalist:
            continue
        # Open through a context manager so the image's file handle is
        # released after counting instead of leaking one handle per tile.
        with Image.open(os.path.join(in_path,entry)) as img:
            ones = np.count_nonzero(img)
        with open(out_path+str(filename)+".txt",'w') as f:
            f.write(str(ones))

In [None]:
def get_list(in_path):
    """Return the name, minus its last four characters, of every ``.png`` file in ``in_path``."""
    import os
    import ntpath
    data_path = os.path.abspath(in_path)
    return [
        str(ntpath.basename(entry)[:-4])
        for entry in os.listdir(data_path)
        if entry.endswith(".png")
    ]

In [None]:
def remove_files(path):
    """Delete every ``.png`` and ``.txt`` file directly inside ``path``."""
    import os
    folder = os.path.abspath(path)
    for entry in os.listdir(folder):
        if entry.endswith((".png",".txt")):
            os.remove(folder+"/"+entry)

In [1]:
def reorganize_patches(in_path,out_path,zipped_list):
    """Move tiles with a non-zero tree count into per-count folders, prefixing the count.

    For every ``(name, count)`` item whose count is not ``0.0``, the file
    ``in_path/<name>.png`` (if present) is moved to
    ``out_path/<str(count)>/<count>_<name>.png``. The prefix is ``str(count)``
    with all non-word characters removed.

    Parameters
    ----------
    in_path : folder holding the ``.png`` tiles.
    out_path : root folder for the per-count folders (created on demand).
    zipped_list : iterable of ``(tile name without extension, tree count)``.

    Returns
    -------
    Elapsed wall-clock time in seconds.
    """
    import os, shutil, re, time
    in_path = os.path.abspath(in_path)
    out_path = os.path.abspath(out_path)
    time1 = time.time()

    for item in zipped_list:
        treecount = item[1]
        if treecount == 0.0:
            continue
        treename = re.sub(r'\W+','',str(treecount))
        filename = str(item[0] + ".png")

        # Match the tile by exact name. A substring test would also fire for
        # unrelated files such as "a.png" when looking for "aa.png".
        source = os.path.join(in_path,filename)
        if not os.path.isfile(source):
            continue

        dest_dir = os.path.join(out_path,str(treecount))
        os.makedirs(dest_dir,exist_ok=True)
        # Build the destination with os.path.join, not a hard-coded "\\", so
        # the move also works on non-Windows systems. Move and rename in one step.
        shutil.move(source,os.path.join(dest_dir,treename+"_"+filename))

    time2 = time.time()
    total_time = time2-time1
    return total_time

In [2]:
def clip_pts(out_path,good_pts,size,image,xmax,ymax,chan,patch_size,padding,suffix):
    """Clip a window around every point and write either an image tile or a count file.

    With ``chan == 3`` the window's first three bands are validated with
    ``check_NoData``, stacked, zero-padded by ``padding`` pixels and written
    to ``out_path + <x>_<y><suffix>.png`` when the padded size equals
    ``patch_size + 2*padding``. For any other ``chan`` the number of non-zero
    elements in the window is written to ``out_path + <x>_<y><suffix>.txt``.

    Parameters
    ----------
    out_path : string prefix for output files (include a trailing separator).
    good_pts : iterable of 2-item points; item 1 indexes the first image axis
        and item 0 the second.
    size : half-width of the window.
    image : array sliced on its first two axes.
    xmax, ymax : unused.
    chan : 3 selects the image branch, anything else the annotation branch.
    patch_size, padding : expected tile size and border width.
    suffix : appended (via ``str``) to every output name.

    Returns
    -------
    tuple ``(patch_list, TC_list, total_time)``: names written, non-zero
    counts (annotation branch only) and elapsed seconds.
    """
    import numpy as np
    import cv2, time
    patch_list = []
    TC_list = []
    started = time.time()
    for pt in good_pts:
        x = int(pt[1])
        y = int(pt[0])
        img_name = str(x)+"_"+str(y)

        window = image[x-size:x+size,y-size:y+size]

        if chan != 3:
            # single channel input = annotation file
            filename = img_name+str(suffix)
            ones = np.count_nonzero(window)
            with open(out_path+str(filename)+'.txt','w') as f:
                f.write(str(ones))
            patch_list.append(img_name+str(suffix))
            TC_list.append(ones)
            continue

        # multi-channel input = image for tiling; bring the band axis to the front
        bands_first = np.einsum('kij->jik',window)
        b1, b2, b3 = [check_NoData(bands_first[k]) for k in range(3)]

        # check_NoData yields -9999 for an unusable band; skip the point then.
        if not (np.all(b1 != -9999) and np.all(b2 != -9999) and np.all(b3 != -9999)):
            continue

        patch = np.dstack([b1,b2,b3])
        padded_patch = cv2.copyMakeBorder(patch, padding,padding,padding,padding,cv2.BORDER_CONSTANT,value=0)
        width, height, depth = padded_patch.shape
        full_tile = patch_size+(2*padding)

        if width != full_tile:
            print("tile size invalid, width ",width)
        elif height != full_tile:
            print("tile size invalid, height ",height)
        else:
            patch_list.append(img_name+str(suffix))
            cv2.imwrite((out_path+str(img_name)+str(suffix)+'.png'),padded_patch)

    total_time = time.time()-started
    return patch_list,TC_list,total_time

In [None]:
def randomize_patches(in_path,in_list):
    """Rename the ``.png`` files in ``in_path`` to the names in ``in_list``, assigned in shuffled order.

    The names are shuffled with a fixed seed (4) and paired one-to-one with
    the folder's ``.png`` files in sorted order. If the two differ in length,
    the surplus files or names are left untouched. ``in_list`` itself is not
    modified.

    Parameters
    ----------
    in_path : folder containing the ``.png`` files.
    in_list : new file names, used exactly as given (no extension is added).

    Raises
    ------
    ValueError : two files would receive the same new name.
    FileExistsError : a new name belongs to a file that is not being renamed.
    """
    import os
    import random

    # Random.shuffle works in place and returns None, so shuffle a copy and
    # iterate over that copy rather than over shuffle's return value.
    shuffled = list(in_list)
    random.Random(4).shuffle(shuffled)

    listing = os.listdir(in_path)
    png_files = sorted(name for name in listing if name.endswith(".png"))
    pairs = list(zip(png_files,shuffled))

    targets = [new_name for _, new_name in pairs]
    if len(set(targets)) != len(targets):
        raise ValueError("in_list contains duplicate names")
    untouched = set(listing) - {old_name for old_name, _ in pairs}
    clashes = untouched.intersection(targets)
    if clashes:
        raise FileExistsError("new names collide with existing files: "+", ".join(sorted(clashes)))

    # Rename in two passes through temporary names. New names may equal the
    # current names of other tiles, and a direct rename would overwrite them.
    staged = []
    for index, (old_name, new_name) in enumerate(pairs):
        temp_file = os.path.join(in_path,".randomize_"+str(index)+".tmp")
        os.rename(os.path.join(in_path,old_name),temp_file)
        staged.append((temp_file,os.path.join(in_path,new_name)))
    for temp_file, new_file in staged:
        os.rename(temp_file,new_file)

In [None]:
#identify and move suitable patches based on tree counts

def get_TC(in_path,datalist):
    """Read the ``.txt`` files in ``in_path`` whose name (minus extension) is in ``datalist``.

    Each matching file's full path is printed as it is read.

    Returns
    -------
    list of ``(name, lines)`` tuples, where ``lines`` is the list returned by
    ``readlines()`` for that file.
    """
    tc = []
    import os
    import ntpath
    data_path = os.path.abspath(in_path)
    for entry in os.listdir(data_path):
        if not entry.endswith(".txt"):
            continue
        if str(entry[:-4]) not in datalist:
            continue
        full_name = data_path+'/' + entry
        print(full_name)
        with open(full_name,'r') as infile:
            lines = infile.readlines()
            TC_id = ntpath.basename(entry)[:-4]
            tc.append((str(TC_id),lines))
    return tc

In [None]:
def get_tiles(in_path,patch_size,padding):
    """List the ``.png`` tiles in ``in_path`` with their minimum pixel value, deleting out-of-range ones.

    Every tile is read with ``cv2.imread(..., 0)`` and must contain exactly
    ``(patch_size + 2*padding)**2`` values, otherwise the reshape raises.
    Tiles whose minimum is in ``[0, 9999)`` are kept; all others are removed
    from disk. Each tile's name is printed as it is processed.

    Returns
    -------
    list of ``(tile name without extension, minimum value)`` tuples.
    """
    import numpy as np
    import os
    import cv2
    tiles = []
    data_path = os.path.abspath(in_path)
    for entry in os.listdir(data_path):
        if not entry.endswith(".png"):
            continue
        tile_file = data_path+"/"+entry
        img = cv2.imread(tile_file,0)
        padded = patch_size+(2*padding)
        tile_id = entry[:-4]
        print(tile_id)
        # The reshaped array is unused; the call only enforces the tile size.
        img.reshape((padded)*(padded))
        min_val = np.amin(img)
        if not (0<= min_val < 9999):
            os.remove(tile_file)
        else:
            tiles.append((str(tile_id),min_val))
    return tiles