In [1]:
# BSC THESIS - MACHINE LEARNING FOR UNEXPLODED ORDNANCE (UXO)
# THIS NOTEBOOK IS DEVELOPED ENTIRELY BY JONAS KNUDSEN

In [None]:
# %%%%%%%%%%%%%%%%      MODULES     %%%%%%%%%%%%%%%% #

# import osgeo.gdal as gdal # doesn't work on the GPU
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import geopandas as gpd
# import piexif
import os
import time

In [1]:
# %%%%%%%%%%%%%%%% GLOBAL VARIABLES %%%%%%%%%%%%%%%% #
# Module-level settings that the functions in the next cell read as globals.
# The data-generation cell ("PREPARING DATA") reassigns most of these once per survey.

project_root = '/scratch/s204219/' #'C:/Users/nowax/OneDrive - Danmarks Tekniske Universitet/Skrivebord/DTU/Bachelor projekt/Code/Data'

survey_location = 'randomField' # 'randomField' or 'footballPitch' 
survey_number = '4'

# e.g. 'randomField_survey_4'; used as the file-name prefix for this survey's files
dataset = f'{survey_location}_survey_{survey_number}'
path_to_dataset = f'{project_root}/{survey_location}/survey_{survey_number}' 

orthophoto_path = f'{path_to_dataset}/{dataset}_orthophoto.tif'

# Target polygons drawn in QGIS (read by export_annotations_from_qgis)
shapefile_path = f'{path_to_dataset}/qgis/{dataset}_polygons.shp'

path_to_targets = f'{path_to_dataset}/targets'
# GPS bounding boxes of all targets, one row per target (written by export_annotations_from_qgis)
BB_path = f'{path_to_targets}/{dataset}_BB.npy'
# Prefix for per-target files: {prefix}_{i}.npy (polygon), {prefix}_BB_{i}.npy (crop), {prefix}_mask_{i}.npy (mask)
target_annotations = f'{path_to_targets}/{dataset}_target'

save_destination_augmented    = f'{project_root}/augmentedData_hard'
save_destination_validation   = f'{project_root}/validationData'
save_destination_test         = f'{project_root}/testData'

# GDAL reportedly does not work on the GPU machine (see the import cell), so the
# geotransform is loaded from a cached .npy instead of being read from the .tif.
#dataset = gdal.Open(orthophoto_path)
#geotransform_global = dataset.GetGeoTransform()
geotransform_global = tuple(np.load(f'{path_to_dataset}/geotransform_global.npy'))


orthophoto = Image.open(orthophoto_path)

# Saved samples are numbered dsn-1+count (see create_augmented_data)
dataset_start_number = 1000001
dsn = dataset_start_number

In [3]:
# %%%%%%%%%%%%%%%%     FUNCTIONS    %%%%%%%%%%%%%%%% #

def gps_to_pixel(lon, lat, geotransform, rowcol_format=True):
    """Convert map coordinates to integer pixel indices of a north-up raster.

    Only the origin (geotransform[0], geotransform[3]) and the pixel size
    (geotransform[1], geotransform[5]) are used; the rotation terms
    geotransform[2] and geotransform[4] are ignored
    (see gps_to_pixel_with_rotation for the full affine inverse).

    Parameters
    ----------
    lon, lat : float or numpy.ndarray
        Easting/northing (the comments elsewhere say UTM zone 33N).
    geotransform : sequence of 6 floats
        GDAL-style geotransform.
    rowcol_format : bool
        If True return (row, col); otherwise return (col, row).

    Returns
    -------
    tuple
        Python ints for scalar input, or integer numpy arrays for array input.
        Values are rounded half-to-even by np.round.
    """
    col = np.round((lon - geotransform[0]) / geotransform[1])
    row = np.round((lat - geotransform[3]) / geotransform[5])

    if col.shape == ():
        col, row = int(col), int(row)
    else:
        col = np.asarray([int(value) for value in col])
        row = np.asarray([int(value) for value in row])

    if rowcol_format == True:
        return row, col
    return col, row


def gps_to_pixel_with_rotation(lon, lat, geotransform,rowcol_format=True):
    """Invert the full affine geotransform to get integer pixel indices.

    Solves  lon = g0 + col*g1 + row*g2,  lat = g3 + col*g4 + row*g5
    (the forward mapping used by pixel_to_gps) for (col, row) with
    Cramer's rule, then rounds each with Python's round().

    Parameters
    ----------
    lon, lat : float
        Map coordinates (scalars; round() does not accept arrays).
    geotransform : sequence of 6 floats
        GDAL-style geotransform, rotation terms included.
    rowcol_format : bool
        If True return (row, col); otherwise (col, row).
    """
    g0, g1, g2, g3, g4, g5 = (geotransform[k] for k in range(6))
    d_lon = lon - g0
    d_lat = lat - g3
    determinant = g1 * g5 - g2 * g4

    col = round((g5 * d_lon - g2 * d_lat) / determinant)
    row = round((g1 * d_lat - g4 * d_lon) / determinant)

    return (row, col) if rowcol_format == True else (col, row)


def pixel_to_gps(row, col, geotransform, rowcol_format=True):
    """Map a pixel position to map coordinates with the full affine geotransform.

    lon = g0 + col*g1 + row*g2
    lat = g3 + col*g4 + row*g5

    Parameters
    ----------
    row, col : int or float
        Pixel position. If `rowcol_format` is not True, the first two
        arguments are read as (col, row) instead.
    geotransform : sequence of 6 floats
        GDAL-style geotransform.

    Returns
    -------
    (lon, lat)
    """
    if rowcol_format == True:
        r, c = row, col
    else:
        r, c = col, row

    g = geotransform
    lon = g[0] + c * g[1] + r * g[2]
    lat = g[3] + c * g[4] + r * g[5]
    return lon, lat


def export_annotations_from_qgis():
    """Read the QGIS target polygons and save them as .npy annotation files.

    Reads the shapefile at the global `shapefile_path` and writes:

    * `{target_annotations}_{i}.npy` -- the exterior ring of polygon i as an
      (k, 2) array of (x, y) map coordinates;
    * `BB_path` -- an (n, 4) array with one row per polygon:
      (center_x, center_y, extent_y, extent_x), i.e. column 2 is the extent
      along the second coordinate and column 3 along the first.

    Only the planar (x, y) part of each vertex is kept. PolygonZ exports
    carry a third coordinate, which would otherwise break the bounding-box
    computation.
    """
    gdf = gpd.read_file(shapefile_path)

    # Print georeference system
    # print(gdf.crs)

    # Exterior ring of every polygon, restricted to (x, y)
    polygons = [np.asarray(geom.exterior.coords)[:, :2] for geom in gdf.geometry]

    for i, poly in enumerate(polygons):
        np.save(f'{target_annotations}_{i}.npy', poly)

    # Bounding boxes in map coordinates
    BB = np.zeros((len(polygons), 4))
    for i, poly in enumerate(polygons):
        min_poly = poly.min(axis=0)
        max_poly = poly.max(axis=0)
        # Center point
        BB[i, 0], BB[i, 1] = 0.5 * (min_poly + max_poly)
        # Column 3 = extent along x, column 2 = extent along y
        BB[i, 3], BB[i, 2] = max_poly - min_poly

    np.save(BB_path, BB)

    return


def load_BB_and_targets(geotransform):
    """Load the saved GPS bounding boxes and per-target polygons.

    Parameters
    ----------
    geotransform : any
        Unused; kept for call compatibility.

    Returns
    -------
    BB : numpy.ndarray
        Array loaded from the global `BB_path`, one row per target.
    targets : list of numpy.ndarray
        `targets[i]` is loaded from `{target_annotations}_{i}.npy`, for every row i of BB.
    """
    BB = np.load(f'{BB_path}')
    targets = [np.load(f'{target_annotations}_{idx}.npy') for idx in range(BB.shape[0])]
    return BB, targets


def save_target_BB_and_masks():
    """Cut every annotated target out of the current orthophoto and save crop + mask.

    Uses the globals `orthophoto`, `geotransform_global` and `target_annotations`.
    For target i it writes:

    * `{target_annotations}_BB_{i}.npy` -- the orthophoto pixels inside the
      bounding box of the rasterised polygon;
    * `{target_annotations}_mask_{i}.npy` -- the polygon rasterised over the same
      box (1 inside, 0 outside).

    copy_paste later loads both files to paste the target.

    Raises
    ------
    ValueError
        If a polygon covers no orthophoto pixel (e.g. it lies outside the raster).
        The function raises rather than saving an uncropped image.
    """
    _, targets = load_BB_and_targets(geotransform_global)

    for i, target in enumerate(targets):
        # Polygon corners as (col, row) pixel positions, i.e. the (x, y) ImageDraw expects
        corners = [gps_to_pixel(lon, lat, geotransform_global, rowcol_format=False)
                   for lon, lat in target[:, :2]]

        # Binary mask of the polygon in the orthophoto
        mask = Image.new('1', orthophoto.size, 0)
        ImageDraw.Draw(mask).polygon(corners, fill=1)

        # (col0, row0, col1, row1) of the non-zero mask pixels, None if empty
        bbox = mask.getbbox()
        if bbox is None:
            # Image.crop(None) would return the whole image, so fail loudly instead
            raise ValueError(f'Target {i} does not cover any pixel of the orthophoto')

        target_BB = orthophoto.crop(bbox)  # orthophoto restricted to the BB
        paste_mask = mask.crop(bbox)       # polygon mask restricted to the BB

        np.save(f'{target_annotations}_BB_'+str(i)+'.npy',target_BB)
        np.save(f'{target_annotations}_mask_'+str(i)+'.npy',paste_mask)

    return


def convert_BB_to_pixel_coords(im, BB, geotransform):
    """Convert GPS bounding boxes to pixel boxes and keep those fully inside `im`.

    Parameters
    ----------
    im : PIL.Image.Image or numpy.ndarray
        Image the boxes are expressed relative to. Only its height and width are used.
    BB : numpy.ndarray, shape (n, 4)
        Boxes in map coordinates as (center_x, center_y, extent_y, extent_x).
        This is the layout written by export_annotations_from_qgis.
    geotransform : sequence of 6 floats
        GDAL-style geotransform whose origin is the top-left corner of `im`.

    Returns
    -------
    numpy.ndarray of int, shape (k, 4)
        (center_row, center_col, height_px, width_px) for the k boxes that lie
        completely inside the image, in input order. Heights and widths are
        rounded up to an even number of pixels. When no box fits, an empty
        (0, 4) int array is returned, so callers always get a 2-D array.
    """
    if isinstance(im, Image.Image):
        # PIL gives (width, height) without converting the whole image to an array
        n_col, n_row = im.size
    else:
        n_row, n_col = np.shape(im)[:2]

    n, m = BB.shape

    BB_pixel = np.zeros((n, m), dtype=int)
    BB_pixel[:, 0], BB_pixel[:, 1] = gps_to_pixel(BB[:, 0], BB[:, 1], geotransform, rowcol_format=True)
    # geotransform[5] (pixel height) is expected to be negative, hence the sign flip
    BB_pixel[:, 2] = (np.ceil((-BB[:, 2] / geotransform[5]) / 2) * 2).astype(int)
    BB_pixel[:, 3] = (np.ceil(( BB[:, 3] / geotransform[1]) / 2) * 2).astype(int)

    # Keep only boxes that lie entirely inside the frame; partly visible targets are dropped
    half_h = BB_pixel[:, 2] / 2
    half_w = BB_pixel[:, 3] / 2
    inside = ((BB_pixel[:, 0] - half_h >= 0) & (BB_pixel[:, 0] + half_h < n_row) &
              (BB_pixel[:, 1] - half_w >= 0) & (BB_pixel[:, 1] + half_w < n_col))

    return BB_pixel[inside]


def draw_bounding_boxes(im,BB_pixel):
    """Display a copy of `im` with a red rectangle around every pixel bounding box.

    Parameters
    ----------
    im : PIL.Image.Image
        Image to draw on (not modified; a copy is used).
    BB_pixel : numpy.ndarray
        Rows of (center_row, center_col, height_px, width_px), as returned by
        convert_BB_to_pixel_coords.
    """
    BB_image = im.copy()
    draw = ImageDraw.Draw(BB_image)

    for k in range(BB_pixel.shape[0]):
        row_top    = int(BB_pixel[k, 0] - BB_pixel[k, 2] / 2)
        row_bottom = int(BB_pixel[k, 0] + BB_pixel[k, 2] / 2)
        col_left   = int(BB_pixel[k, 1] - BB_pixel[k, 3] / 2)
        col_right  = int(BB_pixel[k, 1] + BB_pixel[k, 3] / 2)

        # Pillow requires (x0, y0) to be the top-left corner: x1 >= x0 and y1 >= y0
        draw.rectangle([(col_left, row_top), (col_right, row_bottom)], outline="red", width=1)

    plt.figure()
    plt.imshow(BB_image)
    plt.show()

    return


def crop_image(im,geotransform,top_left,down_right,show_cropped_image=True):
    """Crop `im` to the map window spanned by two GPS corners.

    Parameters
    ----------
    im : PIL.Image.Image or numpy.ndarray
        Source image with a channel axis (indexed as [row, col, channel]).
    geotransform : sequence of 6 floats
        Geotransform of `im`.
    top_left, down_right : (lon, lat)
        Window corners in map coordinates. Both corner pixels are included.
    show_cropped_image : bool
        Plot the crop when True.

    Returns
    -------
    (new_geotransform, cropped_image)
        `new_geotransform` is a copy of `geotransform` re-anchored at the
        top-left pixel that was actually kept. `cropped_image` is a PIL image.
    """
    lon1, lat1 = top_left
    lon2, lat2 = down_right

    row1, col1 = gps_to_pixel(lon1, lat1, geotransform)
    row2, col2 = gps_to_pixel(lon2, lat2, geotransform)

    # A negative start would wrap around in numpy slicing, so clamp to the raster edge
    row1, col1 = max(row1, 0), max(col1, 0)

    cropped_image = Image.fromarray(np.asarray(im)[row1:row2+1, col1:col2+1, :])
    if show_cropped_image == True:
        plt.figure()
        plt.imshow(cropped_image)
        plt.show()

    # Anchor at the map position of pixel (row1, col1), not the raw GPS corner.
    # This makes later GPS->pixel conversions on the crop agree with the parent
    # image instead of being off by the sub-pixel rounding of the corner.
    new_lon, new_lat = pixel_to_gps(row1, col1, geotransform, rowcol_format=True)
    new_geotransform = list(geotransform)
    new_geotransform[0] = new_lon
    new_geotransform[3] = new_lat

    return tuple(new_geotransform), cropped_image


def divide_image(im, square_size, show_image=False):
    """Split an image into a `square_size` x `square_size` grid of equal tiles.

    `square_size` is the number of tiles along each side, not a tile size in pixels.
    Each tile is height // square_size rows by width // square_size columns.
    Leftover rows/columns at the bottom/right are dropped.

    Parameters
    ----------
    im : PIL.Image.Image or numpy.ndarray
        Image with shape (height, width, channels) once converted to an array.
    square_size : int
        Number of tiles per side.
    show_image : bool
        If True, plot the grid with show_square_grid_images.

    Returns
    -------
    numpy.ndarray, shape (square_size, square_size, tile_h, tile_w, channels)
        Element [r, c] is the tile in grid row r and grid column c
        (a fresh array with the input's dtype).
    """
    pixels = np.asarray(im)
    height, width, channels = pixels.shape

    tile_h = height // square_size
    tile_w = width // square_size

    usable = pixels[:tile_h * square_size, :tile_w * square_size, :]
    tiles = usable.reshape(square_size, tile_h, square_size, tile_w, channels)
    sub_images = tiles.transpose(0, 2, 1, 3, 4).copy()

    if show_image == True:
        show_square_grid_images(sub_images)

    return sub_images


def show_square_grid_images(sub_images):
    """Plot a tile array from divide_image as a square grid of subplots.

    Parameters
    ----------
    sub_images : numpy.ndarray, shape (g, g, tile_h, tile_w, channels)
        Tiles indexed [grid_row, grid_col]. The tile at [j, i] is drawn in
        subplot row j, column i.
    """
    grid_size = sub_images.shape[0]

    # squeeze=False keeps `axes` 2-D, so a 1x1 grid can still be indexed as axes[j, i]
    fig, axes = plt.subplots(grid_size, grid_size, figsize=(1.3*10, 10), squeeze=False)

    for i in range(grid_size):
        for j in range(grid_size):
            axes[j, i].imshow(sub_images[j, i, :, :, :])
            axes[j, i].axis('off')

    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    plt.show()

    return


def copy_paste(im,BB_pixel_global, BB_pixel,spacing,target_idx,num_points,rng,show_image=True):
    """Paste stored target cut-outs into `im` at random, well-separated positions.

    Parameters
    ----------
    im : PIL.Image.Image
        Image to augment (a copy is modified).
    BB_pixel_global : numpy.ndarray
        Rows of (center_row, center_col, height_px, width_px) for every target
        of the survey. Only the height/width of the pasted targets are read.
    BB_pixel : numpy.ndarray
        Boxes of the targets already visible in `im` (same layout). They also
        act as obstacles for new positions.
    spacing : int
        New centres must differ from every other centre by more than `spacing`
        pixels along BOTH axes.
    target_idx : int or sequence of int
        A scalar pastes that target for every new point. A sequence gives
        one target index per pasted point.
    num_points : int
        Desired total number of targets (existing + pasted).
    rng : numpy.random.Generator
        Source of the random positions.
    show_image : bool
        Plot the result when True.

    Returns
    -------
    (im_augmented, BB_pixel_updated)
        If `im` already holds at least `num_points` targets, `im` and `BB_pixel`
        are returned unchanged (the same objects). Otherwise a pasted copy is
        returned, along with BB_pixel extended by one row per pasted target.

    Cut-outs are read from `{target_annotations}_BB_{idx}.npy` and
    `{target_annotations}_mask_{idx}.npy` (see save_target_BB_and_masks).
    If no valid position is found within `max_attempts` draws, fewer targets
    are pasted instead of looping forever.
    """

    def find_random_points(BB_pixel, spacing, num_points, im_shape, rng, max_attempts=100000):
        # Returns new (col, row) centres; existing targets count towards num_points
        n_row, n_col = im_shape[:2]
        n = BB_pixel.shape[0]

        if n >= num_points:
            return []

        # Centres of the targets already in the image, stored as (col, row)
        coords = [(BB_pixel[k, 1], BB_pixel[k, 0]) for k in range(n)]

        attempts = 0
        while len(coords) < num_points and attempts < max_attempts:
            attempts += 1
            x = rng.integers(low=spacing/2, high=n_col-spacing/2)
            y = rng.integers(low=spacing/2, high=n_row-spacing/2)

            # Accept only if far enough (on both axes) from every centre so far
            if all(abs(x - prev_x) > spacing and abs(y - prev_y) > spacing for prev_x, prev_y in coords):
                coords.append((x, y))

        return coords[n:]

    def update_BB_pixel(BB_pixel_global, BB_pixel, random_points, target_idx):
        # Append one (row, col, height, width) row per pasted point
        n_old = BB_pixel.shape[0]
        n_new = len(random_points)

        temp = np.zeros((n_old + n_new, 4), dtype=BB_pixel.dtype)
        if n_old > 0:
            temp[:n_old, :] = BB_pixel

        points = np.array(random_points, dtype=BB_pixel.dtype)
        temp[n_old:, 0] = points[:, 1]  # row
        temp[n_old:, 1] = points[:, 0]  # col

        if np.shape(target_idx) == ():
            # The same target size broadcasts to every pasted row
            temp[n_old:, 2:4] = BB_pixel_global[target_idx, 2:4]
        else:
            for k in range(n_new):
                temp[n_old + k, 2:4] = BB_pixel_global[target_idx[k], 2:4]

        return temp

    scalar_target = np.shape(target_idx) == ()

    rp = find_random_points(BB_pixel, spacing, num_points, np.shape(im), rng)

    if len(rp) == 0:
        im_augmented = im
        BB_pixel_updated = BB_pixel
    else:
        BB_pixel_updated = update_BB_pixel(BB_pixel_global, BB_pixel, rp, target_idx)
        im_augmented = im.copy()

        if scalar_target:
            target_BB  = Image.fromarray(np.load(f'{target_annotations}_BB_'+str(target_idx)+'.npy'))
            paste_mask = Image.fromarray(np.load(f'{target_annotations}_mask_'+str(target_idx)+'.npy'))

        for k, (col, row) in enumerate(rp):
            idx = target_idx if scalar_target else target_idx[k]
            if not scalar_target:
                target_BB  = Image.fromarray(np.load(f'{target_annotations}_BB_'+str(idx)+'.npy'))
                paste_mask = Image.fromarray(np.load(f'{target_annotations}_mask_'+str(idx)+'.npy'))

            # paste() takes the upper-left corner as (x, y) = (col, row)
            upper_left = (int(col - BB_pixel_global[idx, 3]/2), int(row - BB_pixel_global[idx, 2]/2))
            im_augmented.paste(target_BB, upper_left, mask=paste_mask)

    if show_image == True:
        plt.figure()
        plt.imshow(im_augmented)

    return im_augmented, BB_pixel_updated


def find_spacing(BB_pixel_global,buffer = 10):
    """Return the minimum separation used when placing pasted targets.

    The separation is the largest pixel height or width (columns 2 and 3)
    over all boxes, plus `buffer` extra pixels.
    """
    largest_extent = BB_pixel_global[:, 2:4].max()
    return largest_extent + buffer


def create_augmented_data(im,BB_pixel_global,BB_pixel,geotransform,spacing,target_idx,num_points,sub_shape,count,rng,save=True,show_images=True):
    """Tile `im` into `sub_shape` sub-images, paste targets into each, optionally save.

    Parameters
    ----------
    im : PIL.Image.Image
        Cropped orthophoto to tile.
    BB_pixel_global : numpy.ndarray
        Pixel boxes of all targets of the survey (passed to copy_paste).
    BB_pixel : numpy.ndarray
        Unused; kept for call compatibility.
    geotransform : sequence of 6 floats
        Geotransform of `im`.
    spacing, num_points, rng
        Forwarded to copy_paste.
    target_idx : int, sequence of int, or 'random'
        With 'random' (or a sequence), a fresh set of `num_points` random target
        indices is drawn for every sub-image.
    sub_shape : (rows, cols)
        Tile size in pixels. Leftover strips smaller than a tile are skipped.
    count : int
        Running sample counter. Sample files are named dsn-1+count.
    save : bool
        Write `{save_destination_augmented}/images/image_N.npy` and
        `{save_destination_augmented}/labels/BB_N.npy` for every tile.
    show_images : bool
        Plot all augmented tiles in one grid.

    Returns
    -------
    int
        `count` plus the number of tiles processed.

    NOTE(review): per-tile boxes are computed from the module-level `BB` (GPS
    boxes of the current survey), not from an argument.
    """
    im_np = np.asarray(im)
    n_row, n_col = im_np.shape[:2]
    n_row_sub, n_col_sub = sub_shape

    # isinstance guard: comparing a numpy array with a string is ambiguous in `if`
    if isinstance(target_idx, str) and target_idx == 'random':
        target_idx = rng.integers(low=0, high=BB_pixel_global.shape[0], size=num_points)

    # Top-left corners of the tiles. The "+ 1" keeps a tile that ends exactly
    # on the image border (e.g. an image exactly one tile tall).
    idx_y = np.arange(0, n_row - n_row_sub + 1, n_row_sub)
    idx_x = np.arange(0, n_col - n_col_sub + 1, n_col_sub)

    show_grid = show_images == True and idx_y.size > 0 and idx_x.size > 0
    if show_grid:
        # squeeze=False keeps `axes` 2-D even for a single row/column of tiles
        fig, axes = plt.subplots(len(idx_y), len(idx_x), figsize=(10, 10), squeeze=False)

    for j, y in enumerate(idx_y):
        for i, x in enumerate(idx_x):
            sub_image = Image.fromarray(im_np[y:y+n_row_sub, x:x+n_col_sub])

            # Re-anchor the geotransform at the tile's top-left pixel
            lon, lat = pixel_to_gps(y, x, geotransform, rowcol_format=True)
            geotransform_sub = list(geotransform)
            geotransform_sub[0] = lon
            geotransform_sub[3] = lat
            geotransform_sub = tuple(geotransform_sub)

            BB_pixel_sub = convert_BB_to_pixel_coords(sub_image, BB, geotransform_sub)

            im_augmented, BB_pixel_updated = copy_paste(sub_image, BB_pixel_global, BB_pixel_sub, spacing,
                                                        target_idx, num_points, rng, show_image=False)
            count += 1

            if save == True:
                # Drops the last channel; presumably the orthophoto is RGBA -- TODO confirm
                rgb_image = np.array(im_augmented)[:,:,:-1]

                np.save(f'{save_destination_augmented}/images/image_'+str(dsn-1+count)+'.npy',rgb_image)
                np.save(f'{save_destination_augmented}/labels/BB_'+str(dsn-1+count)+'.npy',BB_pixel_updated.astype(int))

            if show_grid:
                axes[j, i].imshow(im_augmented)
                axes[j, i].axis('off')

            # Random mode: draw new targets for the next tile
            if np.shape(target_idx) != ():
                target_idx = rng.integers(low=0, high=BB_pixel_global.shape[0], size=num_points)

    if show_grid:
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plt.show()

    return count


def create_validation_data(im,BB_pixel_global,BB_pixel,geotransform,sub_shape,count,save=True,show_images=True):
    """Tile `im` into `sub_shape` sub-images without augmentation and optionally save them.

    Parameters
    ----------
    im : PIL.Image.Image
        Image to tile.
    BB_pixel_global, BB_pixel : numpy.ndarray
        Unused; kept for call compatibility.
    geotransform : sequence of 6 floats
        Geotransform of `im`.
    sub_shape : (rows, cols)
        Tile size in pixels. Leftover strips smaller than a tile are skipped.
    count : int
        Running sample counter. Files are named with 2000+count.
    save : bool
        Write `{save_destination_validation}/image_N.npy` and `.../BB_N.npy`.
    show_images : bool
        Plot all tiles in one grid.

    Returns
    -------
    int
        `count` plus the number of tiles processed.

    Progress is printed as a percentage of the tiles in this call.
    NOTE(review): per-tile boxes come from the module-level `BB`.
    """
    print('##### CREATING VALIDATION DATA #####\n')

    im_np = np.asarray(im)
    n_row, n_col = im_np.shape[:2]
    n_row_sub, n_col_sub = sub_shape

    # The "+ 1" keeps a tile that ends exactly on the image border
    idx_y = np.arange(0, n_row - n_row_sub + 1, n_row_sub)
    idx_x = np.arange(0, n_col - n_col_sub + 1, n_col_sub)
    number_of_iterations = len(idx_y)*len(idx_x)

    show_grid = show_images == True and idx_y.size > 0 and idx_x.size > 0
    if show_grid:
        fig, axes = plt.subplots(len(idx_y), len(idx_x), figsize=(10, 10), squeeze=False)

    done = 0
    for j, y in enumerate(idx_y):
        for i, x in enumerate(idx_x):
            sub_image = Image.fromarray(im_np[y:y+n_row_sub, x:x+n_col_sub])

            # Re-anchor the geotransform at the tile's top-left pixel
            lon, lat = pixel_to_gps(y, x, geotransform, rowcol_format=True)
            geotransform_sub = list(geotransform)
            geotransform_sub[0] = lon
            geotransform_sub[3] = lat
            geotransform_sub = tuple(geotransform_sub)

            BB_pixel_sub = convert_BB_to_pixel_coords(sub_image, BB, geotransform_sub)

            count += 1
            done += 1
            print(str(np.round(100*done/number_of_iterations,2)) + '%')

            if save == True:
                # Drops the last channel; presumably the orthophoto is RGBA -- TODO confirm
                rgb_image = np.array(sub_image)[:,:,:-1]

                np.save(f'{save_destination_validation}/image_'+str(2000+count)+'.npy',rgb_image)
                np.save(f'{save_destination_validation}/BB_'+str(2000+count)+'.npy',BB_pixel_sub.astype(int))

            if show_grid:
                axes[j, i].imshow(sub_image)
                axes[j, i].axis('off')

    if show_grid:
        plt.subplots_adjust(wspace=0.05, hspace=0.05)
        plt.show()

    print('\n##### FINISHED #####')
    return count


def delete_files(root,numbers_list):
    """Delete sample pairs `root/labels/BB_{n}.npy` and `root/images/image_{n}.npy`.

    Numbers are processed in the given order. Deletion stops at the first
    number for which neither file exists, because excess samples are written
    with consecutive numbers.

    Each file of a pair is removed independently. A half-written pair (e.g. the
    image saved but the label not yet) is therefore cleaned up instead of
    raising FileNotFoundError or stopping before it.
    """
    for number in numbers_list:
        paths = (f"{root}/labels/BB_{number}.npy", f"{root}/images/image_{number}.npy")
        existing = [path for path in paths if os.path.exists(path)]
        if not existing:
            return
        for path in existing:
            os.remove(path)
    return


def create_folders(dataset_name, root='/scratch/s204219'):
    """Create the output folder tree for an augmented dataset.

    Layout::

        {root}/{dataset_name}/data_split
        {root}/{dataset_name}/survey_{1..4}/images
        {root}/{dataset_name}/survey_{1..4}/labels

    If `{root}/{dataset_name}` already exists, the user is asked for
    confirmation. Answering 'y' does not delete anything; existing files may
    simply be written over later. Missing folders are created.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset folder.
    root : str
        Parent directory (defaults to the original hard-coded location).

    Returns
    -------
    bool
        False if the user declined to reuse an existing folder (nothing is
        created), True otherwise.
    """
    path = f'{root}/{dataset_name}'

    if os.path.exists(path):
        choice = input(f"{dataset_name} already exists. Do you want to overwrite it? (y/n)")
        if choice.lower() != 'y':
            return False
        print('overwriting')

    os.makedirs(f'{path}/data_split', exist_ok=True)

    for i in range(1, 5):
        for sub_folder in ('images', 'labels'):
            os.makedirs(f'{path}/survey_{i}/{sub_folder}', exist_ok=True)

    return True


In [2]:
# PREPARING DATA
number_of_images = 25000 # per survey
save = True # set to False to run without writing any files
show_images = False
save_split = True
if save_split == True:
    val_ratio  = 1/10.
    test_ratio = 1/10.


old_count = 0
count = old_count

dataset_name = 'augmentedData_hard_big'
# NOTE: create_folders may return False when the user declines to reuse an
# existing dataset folder; stop instead of writing into it anyway.
if create_folders(dataset_name=dataset_name) is False:
    raise RuntimeError(f'Aborted: {dataset_name} already exists and was not overwritten')


target_idx = 'random' # 'random' or a fixed target index, e.g. 0
num_points = 2        # targets (existing + pasted) per sub-image
sub_shape = (256,256)

# Hand-picked crop windows (top_left list, down_right list) of (lon, lat) in
# UTM zone 33N for each randomField survey. Each is extended with jittered copies.
CROP_WINDOWS = {
    1: ([(344185.387,6184645.055),(344209.161,6184640.912),(344195.639,6184640.526),
         (344210.651,6184625.403),(344217.361,6184633.518),(344185.984,6184638.312),
         (344210.503,6184615.221),(344189.153,6184642.381),(344209.504,6184640.334)],
        [(344217.391,6184625.633),(344237.806,6184613.757),(344218.544,6184622.147),
         (344232.121,6184600.014),(344241.263,6184610.532),(344210.028,6184621.602),
         (344237.767,6184598.515),(344205.644,6184623.083),(344228.450,6184618.873)]),
    2: ([(344143.609,6184664.736),(344150.996,6184666.034),(344141.997,6184663.169),
         (344151.936,6184654.438),(344140.161,6184663.393),(344167.204,6184648.125),
         (344164.070,6184647.812),(344145.803,6184661.199),(344138.281,6184657.483)],
        [(344166.891,6184645.887),(344165.324,6184654.483),(344171.771,6184643.155),
         (344180.949,6184640.290),(344168.816,6184645.484),(344191.157,6184630.932),
         (344191.829,6184629.947),(344165.726,6184648.797),(344161.876,6184644.678)]),
    3: ([(344142.193,6184665.129),(344141.508,6184661.134),(344155.741,6184661.857),
         (344164.531,6184654.017),(344142.840,6184664.406),(344155.208,6184664.216),
         (344165.064,6184645.950),(344145.314,6184663.569),(344138.921,6184664.863)],
        [(344162.895,6184645.075),(344163.009,6184642.982),(344174.387,6184645.113),
         (344185.423,6184638.872),(344167.347,6184645.189),(344175.111,6184646.254),
         (344184.320,6184629.929),(344158.747,6184644.504),(344169.098,6184644.580)]),
    4: ([(344144.567,6184661.301),(344140.790,6184661.016),(344151.279,6184657.120),
         (344166.501,6184643.642),(344151.298,6184658.645),(344165.164,6184650.307),
         (344162.489,6184645.270),(344139.751,6184663.103),(344153.840,6184656.549)],
        [(344162.098,6184648.366),(344162.561,6184646.264),(344172.252,6184639.205),
         (344189.105,6184625.319),(344182.952,6184638.047),(344189.774,6184634.168),
         (344187.232,6184626.455),(344162.756,6184641.926),(344171.762,6184641.837)]),
}
N_JITTER_ROUNDS = 300        # jittered copies of every hand-picked window
JITTER_SIGMA = 1.5           # std. dev. of the corner jitter, in map units
MAX_CROPS_PER_SURVEY = 1000  # upper bound on crop windows used per survey


def _jittered_crop_windows(top_left_base, down_right_base, rng):
    """Return the base windows followed by N_JITTER_ROUNDS jittered copies of each.

    Random draws happen in the order top-left, then down-right, for each
    window, round after round. This keeps the generated sequence reproducible
    for a given seed.
    """
    top_left_list = list(top_left_base)
    down_right_list = list(down_right_base)
    for _ in range(N_JITTER_ROUNDS):
        for top_left, down_right in zip(top_left_base, down_right_base):
            top_left_list.append(tuple(np.round(np.array(top_left) + rng.standard_normal(2)*JITTER_SIGMA, 3)))
            down_right_list.append(tuple(np.round(np.array(down_right) + rng.standard_normal(2)*JITTER_SIGMA, 3)))
    return top_left_list, down_right_list


start_time_global = time.perf_counter()
print('###### CREATING AUGMENTED DATA ######')

for k in range(4):

    # ---- Per-survey globals (read by the functions defined above) ----
    project_root = '/scratch/s204219/' #'C:/Users/nowax/OneDrive - Danmarks Tekniske Universitet/Skrivebord/DTU/Bachelor projekt/Code/Data'

    survey_location = 'randomField' # 'randomField' or 'footballPitch'
    survey_number = str(k+1)

    dataset = f'{survey_location}_survey_{survey_number}'
    path_to_dataset = f'{project_root}/{survey_location}/survey_{survey_number}'

    orthophoto_path = f'{path_to_dataset}/{dataset}_orthophoto.tif'

    shapefile_path = f'{path_to_dataset}/qgis/{dataset}_polygons.shp'

    path_to_targets = f'{path_to_dataset}/targets'
    BB_path = f'{path_to_targets}/{dataset}_BB.npy'
    target_annotations = f'{path_to_targets}/{dataset}_target'

    save_destination_augmented    = f'{project_root}/{dataset_name}/survey_{survey_number}'

    geotransform_global = tuple(np.load(f'{path_to_dataset}/geotransform_global.npy'))

    orthophoto = Image.open(orthophoto_path)

    dataset_start_number = 1000001
    dsn = dataset_start_number

    # ---- Crop windows for this survey; the same rng is then used for augmentation ----
    rng = np.random.default_rng(seed=1)
    top_left_base, down_right_base = CROP_WINDOWS[k+1]
    n = len(top_left_base)
    top_left_list, down_right_list = _jittered_crop_windows(top_left_base, down_right_base, rng)
    m = len(top_left_list)

    BB, targets = load_BB_and_targets(geotransform_global)

    BB_pixel_global = convert_BB_to_pixel_coords(orthophoto,BB,geotransform_global)
    spacing = find_spacing(BB_pixel_global, buffer = 5)

    #############################################

    plt.close('all')
    target_count = number_of_images*(k+1)  # cumulative over surveys
    for i in range(min(MAX_CROPS_PER_SURVEY, m)):
        geotransform_new, cropped_image = crop_image(orthophoto,geotransform_global,
                                                     top_left_list[i],down_right_list[i],
                                                     show_cropped_image=False)

        BB_pixel = convert_BB_to_pixel_coords(cropped_image, BB, geotransform_new)

        count = create_augmented_data(cropped_image,BB_pixel_global,BB_pixel,
                                      geotransform_new,spacing,target_idx,
                                      num_points,sub_shape,count,rng,
                                      save=save,show_images=show_images)

        if count >= target_count:
            break

        if show_images == True:
            print(str(i+1)+' - images = '+ str(count) )

    if count < target_count:
        # Forcing the counter below would otherwise hide gaps in the numbering
        print(f'WARNING: survey {survey_number} produced only {count} of {target_count} images; '
              f'the file numbering will contain gaps')

    # The last crop usually overshoots the target; delete the excess samples
    if save == True:
        numbers_list = np.arange(dsn+target_count,dsn+2*target_count)
        delete_files(save_destination_augmented,numbers_list)

    count = target_count
    old_count = count

    if k == 0:
        start_time = start_time_global

    end_time = time.perf_counter()
    elapsed_time = end_time - start_time

    print('Created '+ str(count) + ' images. ' + f'Elapsed time {elapsed_time} seconds' )
    start_time = time.perf_counter()


# SAVING THE SPLIT
if save_split == True:

    dataset_size = number_of_images*(k+1)

    file_root = f'{project_root}/{dataset_name}'
    file_list = np.arange(dsn,dsn+dataset_size)

    np.random.seed(1)

    # Random permutation of all sample numbers
    file_list = np.random.choice(file_list, size=dataset_size, replace=False)

    val_split  = int(np.round(val_ratio  * dataset_size))
    test_split = int(np.round(test_ratio * dataset_size))
    train_split= int(dataset_size-(val_split+test_split))

    train_indices = np.array(file_list[0:train_split])
    val_indices   = np.array(file_list[train_split:train_split+val_split])
    test_indices  = np.array(file_list[(train_split+val_split):])

    np.save(f'{file_root}/data_split/Index_train',train_indices)
    np.save(f'{file_root}/data_split/Index_val',val_indices)
    np.save(f'{file_root}/data_split/Index_test',test_indices)


end_time_global = time.perf_counter()

elapsed_time_global = end_time_global - start_time_global


print('##### FINISHED #####')
print("Elapsed time: ", elapsed_time_global, "seconds")