In [None]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
def crop_patch(image, y, x, size=112):
    """Extract a ``size`` x ``size`` patch centred on the pixel at row ``y``, column ``x``.

    The patch spans rows ``[y - size // 2, y - size // 2 + size)`` and columns
    ``[x - size // 2, x - size // 2 + size)``, so it has exactly ``size`` rows and
    columns whenever that window lies inside the image. Default patch size is 112 x 112.

    Parameters
    ----------
    image : numpy.ndarray
        Image indexed as (row, column, ...). Both 2-D (grayscale) and 3-D
        (row, column, channel) arrays are accepted.
    y, x : int
        Row and column of the patch centre.
    size : int, optional
        Side length of the square patch in pixels (default 112).

    Returns
    -------
    numpy.ndarray
        A basic-slicing view into ``image`` of shape ``(size, size, ...)``.

    Notes
    -----
    No bounds checking is done here. A window that runs past the bottom/right edge
    is silently truncated by NumPy slicing. A negative start index is interpreted
    relative to the end of the axis, which yields an empty or wrong patch. Validate
    the centre first (e.g. with ``check_dimensions``).
    """
    half = size // 2
    top, left = y - half, x - half

    # Slice only the two spatial axes so any trailing (channel) axes are kept whole.
    return image[top : top + size, left : left + size]


def check_dimensions(image, y, x, size=112):
    """Check that a ``size`` x ``size`` patch centred on (``y``, ``x``) lies fully inside ``image``.

    The patch is taken to span rows ``[y - size // 2, y - size // 2 + size)`` and
    columns ``[x - size // 2, x - size // 2 + size)``. Call this before extracting a
    patch so that it does not extend past the dimensions of the image.

    Parameters
    ----------
    image : numpy.ndarray
        Image whose first two axes are (height, width); only ``image.shape`` is read.
    y, x : int
        Row and column of the proposed patch centre.
    size : int, optional
        Side length of the square patch in pixels (default 112).

    Returns
    -------
    bool
        ``True`` if the whole patch fits inside the image, ``False`` otherwise.
    """
    height, width = image.shape[:2]

    half = size // 2
    top, left = y - half, x - half

    # Compare the exclusive end (`start + size`) against the image extent. For even sizes
    # this equals the `centre + size // 2` bound. For odd sizes it also counts the extra
    # pixel on the far side, which the `centre + size // 2` form missed.
    fits = top >= 0 and left >= 0 and top + size <= height and left + size <= width
    return bool(fits)
    
    
def create_patches(img, percent, offset=112, size=112):
    """Extract square patches from ``img`` at the points of a uniform grid.

    About ``percent`` of the image's pixel positions are used as grid points.
    ``percent`` is a fraction: 0.001 means 0.1 % of pixels. Increase it to make the
    grid denser. The grid has ``int(sqrt(num_points))`` evenly spaced positions along
    each axis, running from ``offset`` to ``dimension - offset``. Any point whose
    ``size`` x ``size`` patch would fall outside ``img`` is dropped.

    Parameters
    ----------
    img : numpy.ndarray
        Image indexed as (row, column, ...).
    percent : float
        Fraction of all pixel positions to sample as grid points.
    offset : int, optional
        Distance in pixels of the outermost grid lines from the image border.
    size : int, optional
        Side length of each extracted square patch (default 112).

    Returns
    -------
    sparse_points : numpy.ndarray
        Integer array of shape ``(n_patches, 2)`` holding the ``[x, y]`` centre of each
        kept patch. It has shape ``(0, 2)`` when no patch is kept.
    patches : list of numpy.ndarray
        The extracted patches, in the same order as ``sparse_points``.

    Side effects
    ------------
    Prints the number of patches extracted.
    """
    height, width = img.shape[:2]

    # Use the same number of samples along each axis so the grid holds ~num_points points.
    num_points = int(height * width * percent)
    density = int(np.sqrt(num_points))

    x_, y_ = np.meshgrid(np.linspace(offset, width - offset, density),
                         np.linspace(offset, height - offset, density))

    # One [x, y] row per grid point, truncated to integer pixel coordinates.
    xy = np.dstack([x_, y_]).reshape(-1, 2).astype(int)

    # Drop points whose patch would extend past the edges of *this* image.
    sparse_points = [p.tolist() for p in xy if check_dimensions(img, p[1], p[0], size=size)]

    # Extract the patches (crop_patch takes the row first, then the column).
    patches = [crop_patch(img, y, x, size=size) for x, y in sparse_points]

    print("Number of patches: ", len(patches))

    # reshape keeps the documented (n, 2) shape even when no points survive.
    return np.array(sparse_points, dtype=int).reshape(-1, 2), patches

In [None]:
# Load the example coral image. os.path.join builds the relative path with the
# platform's separator; a literal "..\\Figures\\..." string only resolves on Windows.
image_file = os.path.join("..", "Figures", "example_coral_patch.png")
image = plt.imread(image_file)

fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
ax.set_title(os.path.basename(image_file))
plt.show()

In [None]:
# Extract grid points and the corresponding patches from the image.
# SAMPLE_FRACTION is a fraction of all pixel positions: 0.00035 == 0.035 % of pixels.
# GRID_OFFSET = 56 is half of the default 112-px patch size, so the outermost grid
# points sit one half-patch in from the image border.
SAMPLE_FRACTION = 0.00035
GRID_OFFSET = 56

sparse_points, patches = create_patches(image, SAMPLE_FRACTION, offset=GRID_OFFSET)

num_patches = len(patches)

In [None]:
# Display the image with the location of each patch centre overlaid in red.
# Reshape to (n, 2) of [x, y] so an empty result plots nothing instead of raising.
points_xy = np.asarray(sparse_points).reshape(-1, 2)

fig, ax = plt.subplots(figsize=(20, 10))
ax.imshow(image)
ax.scatter(points_xy[:, 0], points_xy[:, 1], c='red')
ax.set(title=f"Patch centres ({len(points_xy)} points)",
       xlabel="x (column)", ylabel="y (row)")
plt.show()

In [None]:
# Save every patch as a PNG under PATCH_DIR and record its centre and file name in a CSV.
PATCH_DIR = "Patches"
os.makedirs(PATCH_DIR, exist_ok=True)  # plt.imsave does not create missing directories

# Image file name without directory or extension, e.g. "example_coral_patch".
image_name = os.path.splitext(os.path.basename(image_file))[0]

# Patch centres as an (n, 2) array of [x, y]; stays valid when there are no patches.
points_xy = np.asarray(sparse_points).reshape(-1, 2)

patch_names = []
for (x, y), patch in zip(points_xy, patches):
    patch_name = os.path.join(PATCH_DIR, image_name + "_patch_" + str(x) + "_" + str(y) + ".png")
    patch_names.append(patch_name)
    plt.imsave(patch_name, patch)

# Use the full coordinate columns here; the loop variables x, y only hold the last point.
pd.DataFrame({'X': points_xy[:, 0], 'Y': points_xy[:, 1], 'file': patch_names}).to_csv(
    image_name + "_Sparse_Points.csv"
)