In [2]:
import time
import os
import glob
import json
import cv2
import math
import numpy as np
from PIL import Image
from IPython.display import clear_output
from statistics import mode
from statistics import mean
from keras.models import load_model

In [3]:
# set parameters
# location of the saved Keras model that load_model() reads
checkpoint_path = r'E:/USGS/checkpoints/final'

# process path: folder holding the *.TIFF scans; tile folders and result logs
# are written here too. Keep the trailing slash, file names are appended to
# this string directly.
process_path = r'E:/USGS/test/'

# set the confidence threshold for the network: a tile prediction is only kept
# when its highest class score is strictly greater than this value
threshold = 0.85

In [4]:
# load the saved model from checkpoint_path; tile_scan reads it as the
# module-level name `model` when classifying tiles
model = load_model(checkpoint_path)

In [5]:
# rescale an 8bit tile to the 0-1 range and give it the 4D layout the network consumes
def eight_to_one(crop):
    """Scale an 8-bit array by 1/255 and add a band axis and a sample axis.

    Parameters
    ----------
    crop : numpy array with at least two dimensions, values on a 0-255 scale.

    Returns
    -------
    Float array laid out as <samples, x_dimension, y_dimension, bands>; a 2D
    input of shape (h, w) comes back as (1, h, w, 1). A thresholded tile
    holding only 0 and 255 therefore becomes exactly 0.0 and 1.0.
    """
    # bring the pixel values down from 0-255 to 0-1
    scaled = crop / 255

    # the band axis goes in at position 2 ...
    with_band = np.expand_dims(scaled, axis=2)

    # ... and the sample axis is put in front
    return with_band[np.newaxis, ...]

In [6]:
# tile every scan, classify randomly chosen tiles and log the results, note the log_tiles flag
def _candidate_tiles(rangex, rangey):
    """Return every samplable (x, y) tile index of a scan, in random order.

    Tiles near the border are excluded with the same margins the sampling has
    always used: x in [2, rangex-3] and y in [6, rangey-4]. (Presumably the
    scan borders hold margins/headers rather than data; verify if the scan
    layout changes.) The list is empty when the scan is too small to leave any
    tile inside those margins.
    """
    coords = [(x, y) for x in range(2, rangex - 2) for y in range(6, rangey - 3)]

    # shuffle with numpy's global random state so np.random.seed() still
    # controls which tiles get sampled
    order = np.random.permutation(len(coords))
    return [coords[i] for i in order]


def tile_scan(source_path, desired_samples_per_image, desired_tile_dimension, log_tiles, max_attempts=250):
    """Classify randomly sampled tiles of every ``*.TIFF`` scan in a folder.

    For each scan, square tiles are drawn at random (without repetition),
    binarised, and passed to the network when they are between 50% and 95%
    white. Predictions whose top score exceeds the module-level ``threshold``
    are recorded. All results are written to
    ``tile_classification_results0.json`` inside ``source_path``.

    Relies on names defined elsewhere in the notebook: ``model``,
    ``threshold``, ``eight_to_one``, ``Image``, ``cv2`` and ``clear_output``.

    Parameters
    ----------
    source_path : folder containing the scans; tile folders and the JSON log
        are written here as well. A trailing slash is optional.
    desired_samples_per_image : number of accepted tile predictions to collect
        per scan before moving on.
    desired_tile_dimension : tile edge length in pixels (tiles are square).
    log_tiles : when true, every accepted binary tile is saved as a PNG in a
        sub-folder named after the scan.
    max_attempts : upper bound on the number of tiles examined per scan.

    Returns
    -------
    None. The JSON log maps each image name to a list of
    ``{'x', 'y', 'class_label', 'confidence'}`` dicts (empty when no tile
    qualified).

    Notes
    -----
    * Each examined tile counts as exactly one attempt, and sampling stops
      once the scan runs out of unseen tiles, so the loop always terminates.
    * The log is overwritten on every run; appending a second JSON document
      would leave a file that ``json.loads`` cannot parse.
    """
    # get start time
    start_time = time.time()

    # make a dict to store overall results
    results = {}

    # number of pixels per tile, identical for every scan
    pixel_count = desired_tile_dimension * desired_tile_dimension

    # get a list of scans to process
    images_to_process = glob.glob(os.path.join(source_path, '*.TIFF'))

    # loop through all the images
    for image in images_to_process:

        # image name = file name without folder or extension, whatever the extension is
        image_name = os.path.splitext(os.path.basename(image))[0]

        # set the tile save path
        tile_path = os.path.join(source_path, image_name)

        # print name
        print('Processing:', image_name)

        # create a subfolder for the tiles if they are to be saved
        if log_tiles:
            os.makedirs(tile_path, exist_ok=True)

        # register the scan up front so it appears in the log even with zero samples
        sample_list = []
        results[image_name] = sample_list

        # open the file; the context manager releases the file handle afterwards
        with Image.open(image) as img:

            # get the dimensions of the image and the resulting tile grid
            imageWidth, imageHeight = img.size
            rangex = math.ceil(imageWidth / desired_tile_dimension)
            rangey = math.ceil(imageHeight / desired_tile_dimension)

            # zero padding used in saved tile names
            padding_x_zeros = len(str(rangex))
            padding_y_zeros = len(str(rangey))

            # examine unique random tiles until enough samples are found or attempts run out
            for rand_x, rand_y in _candidate_tiles(rangex, rangey)[:max_attempts]:

                # only generate the minimum samples
                if len(sample_list) >= desired_samples_per_image:
                    break

                # set the crop coordinates. box = (<start x>, <start y>, <end x>, <end y>)
                left = rand_x * desired_tile_dimension
                upper = rand_y * desired_tile_dimension
                box = (left, upper, left + desired_tile_dimension, upper + desired_tile_dimension)

                # crop the tile and get it as an array
                # NOTE(review): the steps below assume a single-band 8bit scan - confirm for new data
                tile_array = np.asarray(img.crop(box))

                # convert to binary
                _, binary_tile8 = cv2.threshold(tile_array, 127, 255, cv2.THRESH_BINARY)

                # format the crop for network processing
                binary_tile = eight_to_one(binary_tile8)

                # ignore mostly white images. We want between 50% and 95% white
                white_pixels = np.sum(binary_tile)
                if not 0.5 * pixel_count < white_pixels < 0.95 * pixel_count:
                    continue

                # get the classification
                prediction = model.predict(binary_tile)

                # check that the tile exceeds our confidence threshold
                confidence = np.max(prediction)
                if not confidence > threshold:
                    continue

                # convert the one-hot encoding to the actual class label
                label = int(np.argmax(prediction, axis=1)[0])

                # save the prediction and confidence to the sample list
                sample_list.append({'x': int(rand_x),
                                    'y': int(rand_y),
                                    'class_label': label,
                                    'confidence': float(confidence)})

                # check if tiles are to be saved
                if log_tiles:

                    # set the tile save name
                    filename = '{}_{}_{}-{}.png'.format(image_name,
                                                        str(rand_y).zfill(padding_y_zeros),
                                                        str(rand_x).zfill(padding_x_zeros),
                                                        str(label))
                    # save the tile
                    Image.fromarray(binary_tile8).save(os.path.join(tile_path, filename))

        # clear the output
        clear_output(wait=True)

    # save to a log (overwrite, so the file always holds one valid JSON document)
    with open(os.path.join(source_path, 'tile_classification_results0.json'), 'w') as outfile:
        json.dump(results, outfile)

    # get end time
    end_time = time.time()

    print('Processing time: {:.3f}s'.format(end_time - start_time))

In [7]:
# scan everything in process_path: collect up to 200 accepted samples per scan
# from 200 px square tiles, without saving the tile images
tile_scan(source_path=process_path,
          desired_samples_per_image=200,
          desired_tile_dimension=200,
          log_tiles=False)

Processing time: 3.36e+02s


Took 455 s for 74 scans with 100 samples each (~6 s per scan).

# Classify each scan

In [8]:
# folder that holds the scans and the logs written for them
process_path = r'E:/USGS/test/'

# full path of the tile classification log inside that folder
json_path = os.path.join(process_path, 'tile_classification_results0.json')

In [9]:
# parse the tile classification log straight from the open file
with open(json_path, 'r') as data:
    json_data = json.load(data)

In [10]:
# derive one overall label per scan from its tile predictions and log it

# open the log once instead of once per scan; append mode keeps rows from earlier runs
with open(process_path + 'scan_classification_results0.csv', 'a') as outfile:

    # iterate through the image dictionary: image name -> list of tile predictions
    for name, prediction_list in json_data.items():

        # collect the class label of every recorded tile of this scan
        class_label_list = [prediction['class_label'] for prediction in prediction_list]

        # check for guesses
        if class_label_list:

            # get some basic statistics
            scan_mean = mean(class_label_list)
            scan_mode = mode(class_label_list)
            scan_max = max(class_label_list)

            # apply logic to get an overall label for each scan, this can and should be modified
            if scan_max == 0:
                label = '0' # no interest

            elif scan_max == 1:
                label = '1' # little interest

            elif scan_max == 2:
                # a mean of exactly 1.5 is treated as high interest so that
                # every scan with a class-2 tile receives a label
                if scan_mode == 2 or scan_mean >= 1.5:
                    label = '3' # high interest
                else:
                    label = '2' # interest

            # class labels outside 0-2 are not covered by the rules above
            else:
                label = 'ERROR: could not determine'

        # no tile was classified: write placeholders rather than the previous
        # scan's statistics, and add an error label
        else:
            scan_mean = scan_mode = scan_max = 'NA'
            label = 'ERROR: could not determine'

        # save the results to the log
        outfile.write('{}, {}, {}, {}, {}\n'.format(name, scan_mean, scan_mode, scan_max, label))