In [None]:
import cv2
import numpy as np
from matplotlib import pyplot as plt
from collections import namedtuple
%matplotlib inline

In [None]:
# Field names of DetectorParams, in positional order.
det_params_names = [
    'min_edge_thres',  # lower edge threshold
    'max_edge_thres',  # upper edge threshold
    'retr_type',       # contour retrieval mode passed to cv2.findContours
    'retr_approx',     # contour approximation method passed to cv2.findContours
    'min_poly',        # polygon vertex-count limit used by ContoursDetector
    'max_area',        # upper limit used when filtering contours
]
# Immutable bundle of ContoursDetector settings.
DetectorParams = namedtuple('DetectorParams', det_params_names)

class ContoursDetector:
    """Find, filter and visualise contours in an OpenCV (BGR) image.

    Parameters
    ----------
    det_params : DetectorParams
        Edge thresholds, cv2.findContours retrieval mode / approximation
        method, polygon vertex limit and maximum contour area.
    use_canny : bool, optional
        If True, create_edged_image runs cv2.Canny with the configured edge
        thresholds. If False (default, the previous behaviour) it only
        converts the image to grayscale.
    check_poly : bool, optional
        If True, is_contour_valid additionally requires the contour's
        polygonal approximation to have more than ``min_poly`` vertices.
        Off by default to keep the previous behaviour.
    """

    def __init__(self, det_params, use_canny=False, check_poly=False):
        self.det_params = det_params
        self.use_canny = use_canny
        self.check_poly = check_poly

    def create_edged_image(self, image):
        """Return the single-channel image that contour detection runs on."""
        if self.use_canny:
            return cv2.Canny(image, self.det_params.min_edge_thres,
                             self.det_params.max_edge_thres)
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def is_contour_valid(self, contour):
        """Return True if ``contour`` passes the size (and optional polygon) filters."""
        max_area = self.det_params.max_area
        if cv2.contourArea(contour) > max_area:
            return False
        # NOTE(review): this compares the number of contour points against
        # max_area (not an area); confirm this is the intended limit.
        if len(contour) > max_area:
            return False
        if not self.check_poly:
            return True
        # Approximate the closed contour with a tolerance of 1% of its perimeter.
        poly_tolerance = 0.01 * cv2.arcLength(contour, True)
        poly_approx = cv2.approxPolyDP(contour, poly_tolerance, True)
        return len(poly_approx) > self.det_params.min_poly

    def get_contours(self, image):
        """Return a list of the contours in ``image`` that pass is_contour_valid."""
        edged_image = self.create_edged_image(image)
        # 127 / 255 binarisation; threshold type 0 is cv2.THRESH_BINARY.
        _, thresh = cv2.threshold(edged_image, 127, 255, 0)
        found = cv2.findContours(thresh, self.det_params.retr_type,
                                 self.det_params.retr_approx)
        # OpenCV 3 returns (image, contours, hierarchy); OpenCV 2/4 return
        # (contours, hierarchy). The contours are second-to-last in both.
        contours = found[-2]
        # A list (not a lazy filter) so callers can take len() and iterate twice.
        return [cnt for cnt in contours if self.is_contour_valid(cnt)]

    def create_contours_image(self, image):
        """Return a copy of ``image`` with every valid contour filled in a random colour."""
        contours_image = image.copy()
        for cnt in self.get_contours(image):
            # cv2 expects the colour as a tuple of Python numbers; a numpy
            # int array is rejected by recent bindings.
            contour_color = tuple(int(c) for c in np.random.randint(0, 256, size=3))
            # thickness -1 fills the contour interior.
            cv2.drawContours(contours_image, [cnt], 0, contour_color, -1)
        return contours_image

    def plot_contours_image(self, image):
        """Show the edged image and the filled-contours image as two figures."""
        edged_image = self.create_edged_image(image)
        plt.figure(figsize=(14, 9))
        plt.imshow(edged_image, cmap='gray')
        plt.show()
        contours_image = self.create_contours_image(image)
        plt.figure(figsize=(14, 9))
        # matplotlib expects RGB, while the image is in OpenCV's BGR order
        # (the same order create_edged_image assumes).
        plt.imshow(cv2.cvtColor(contours_image, cv2.COLOR_BGR2RGB))
        plt.show()

In [None]:
class ImageFragmentExtractor:
    """Cut square, scaled-up fragments around contours out of a fixed image.

    Parameters
    ----------
    image : numpy.ndarray
        Source image indexed as ``image[y, x]``: either 2-D (grayscale) or
        3-D with the channels on the last axis.
    fragment_scale : float
        Factor by which each contour's square bounding box is enlarged
        about its centre.
    """

    def __init__(self, image, fragment_scale):
        self.image = image
        self.fragment_scale = fragment_scale

    def get_scaled_size(self, start_point, end_point):
        """Scale the interval [start_point, end_point] about its centre.

        Returns the scaled (start, end) truncated to ints. The start is
        clamped at 0 because a negative start would make numpy slicing count
        from the end of the axis and return a wrong, often empty, fragment.
        The end is left as is; slicing already truncates it at the border.
        """
        center_point = (start_point + end_point) / 2.0
        start_scaled = center_point - self.fragment_scale * (center_point - start_point)
        end_scaled = center_point + self.fragment_scale * (end_point - center_point)
        return max(0, int(start_scaled)), int(end_scaled)

    def extract_image_fragment(self, contour):
        """Return ``((start_x, start_y), fragment)`` for ``contour``.

        The contour's bounding box is made square (side = larger of width
        and height), scaled by ``fragment_scale`` about its centre and
        cropped from the image. Near the image border the fragment may be
        smaller than the scaled square, and it may be empty.
        """
        start_x, start_y, width, height = cv2.boundingRect(contour)
        side = max(width, height)
        start_x, end_x = self.get_scaled_size(start_x, start_x + side)
        start_y, end_y = self.get_scaled_size(start_y, start_y + side)
        # Index only the spatial axes so grayscale and colour images both work.
        return (start_x, start_y), self.image[start_y:end_y, start_x:end_x]

In [None]:
class ContourSizeEstimator:
    """Estimate a typical contour size from a collection of contours.

    A contour's size is the length of its first axis (``x.shape[0]``), which
    for OpenCV contours is the number of points. ``est_type`` selects the
    statistic: MEAN_EST (mean) or MEDIAN_EST (median, the default).
    """
    MEAN_EST, MEDIAN_EST = 'MEAN_EST', 'MEDIAN_EST'

    def __init__(self, est_type=MEDIAN_EST):
        self.est_type = est_type

    def get_estimated_size(self, contour_list):
        """Return the mean or median contour size.

        ``contour_list`` may be any iterable (e.g. a Python 3 ``filter``
        object); it is materialised once before use.

        Raises
        ------
        ValueError
            If ``contour_list`` is empty or ``est_type`` is not recognised.
        """
        contours = list(contour_list)
        if not contours:
            raise ValueError("List of contours must be non-empty")
        # A list comprehension, not map(): np.array(map(...)) yields a 0-d
        # object array on Python 3.
        sizes = np.array([cnt.shape[0] for cnt in contours])
        if self.est_type == self.MEAN_EST:
            return np.mean(sizes)
        if self.est_type == self.MEDIAN_EST:
            return np.median(sizes)
        raise ValueError("Invalid estimation type")

In [None]:
class ImagePointFilter:
    """Greedy point thinning by distance.

    A point is kept only if it is farther than ``filter_radius`` from every
    point kept so far. Accepted points accumulate in
    ``occupied_point_list`` across calls on the same instance.
    """

    def __init__(self, filter_radius):
        self.filter_radius = filter_radius
        self.occupied_point_list = []

    def get_point_distance_square(self, point1, point2):
        """Return the squared Euclidean distance between two (x, y) points."""
        dx = point1[0] - point2[0]
        dy = point1[1] - point2[1]
        return dx ** 2 + dy ** 2

    def validate_point(self, point):
        """Accept and record ``point`` unless it is within ``filter_radius``
        (inclusive) of an already accepted point; return whether it was accepted."""
        radius_square = self.filter_radius ** 2
        if any(self.get_point_distance_square(point, occupied) <= radius_square
               for occupied in self.occupied_point_list):
            return False
        self.occupied_point_list.append(point)
        return True

    def filter_points(self, point_list):
        """Return, as a list, the points of ``point_list`` accepted in order."""
        # Built eagerly: validate_point has side effects, and a lazy Python 3
        # filter object would only record points when (and if) iterated.
        return [point for point in point_list if self.validate_point(point)]

In [None]:
class ImageFragmentFilter:
    """Drop empty fragments and fragments whose anchor point lies within
    ``filter_radius`` of an already kept fragment's point (via ImagePointFilter)."""

    def __init__(self, filter_radius):
        self.image_point_filter = ImagePointFilter(filter_radius)

    def filter_fragments(self, fragment_list):
        """Return the kept ``(point, image)`` pairs in input order.

        ``fragment_list`` may be any iterable of ``(point, image)`` pairs
        (e.g. a Python 3 ``map`` object); it is traversed exactly once.
        Empty images are discarded *before* point filtering, so they cannot
        claim a point and suppress a nearby non-empty fragment.
        """
        filtered_fragment_list = []
        for point, image in fragment_list:
            if image.shape[0] == 0 or image.shape[1] == 0:
                continue
            if self.image_point_filter.validate_point(point):
                filtered_fragment_list.append((point, image))
        return filtered_fragment_list

In [None]:
def plot_fragments(fragment_list, plot_size=(16, 16), one_dim_subplot_count=8):
    """Plot fragment images in grayscale on square grids of subplots.

    Parameters
    ----------
    fragment_list : iterable
        Items whose second element (``fragment[1]``) is the image to plot.
        3-D images are treated as BGR (presumably loaded with cv2.imread,
        as in test_params) and converted to grayscale; 2-D images are
        plotted as-is.
    plot_size : tuple
        ``figsize`` for each new figure.
    one_dim_subplot_count : int
        Side length of the grid; each figure holds its square of subplots.
    """
    subplot_count = one_dim_subplot_count ** 2
    plot_counter = 0
    for fragment in fragment_list:
        fragment_image = fragment[1]
        slot = plot_counter % subplot_count
        if slot == 0:
            plt.figure(figsize=plot_size)
        plt.subplot(one_dim_subplot_count, one_dim_subplot_count, slot + 1)
        if fragment_image.ndim == 3:
            fragment_image = cv2.cvtColor(fragment_image, cv2.COLOR_BGR2GRAY)
        plt.imshow(fragment_image, cmap='gray')
        plot_counter += 1
        if plot_counter % subplot_count == 0:
            plt.show()
    # Show the last, partially filled figure as well; otherwise it is only
    # rendered at the end of the cell, out of order with later output.
    if plot_counter % subplot_count != 0:
        plt.show()

def get_det_params(image, min_edge_thres=75, max_edge_thres=200, min_poly=10):
    """Build the DetectorParams used for ``image``.

    ``max_area`` is the larger of the image's first two dimensions. Contours
    are retrieved with cv2.RETR_LIST and approximated with
    cv2.CHAIN_APPROX_SIMPLE. The keyword arguments default to the values
    that were previously hard-coded.
    """
    max_area = max(image.shape[0], image.shape[1])
    print("Max area: %d" % (max_area))
    return DetectorParams(min_edge_thres, max_edge_thres, cv2.RETR_LIST,
                          cv2.CHAIN_APPROX_SIMPLE, min_poly, max_area)

def test_image(image, fragment_scale=2.5):
    """Run the contour -> fragment pipeline on one image and plot the results.

    Plots the edged and contour images, prints the estimated contour size,
    cuts fragments scaled by ``fragment_scale`` around each contour, thins
    them with a radius of min(height, width) and plots what remains.
    """
    det_params = get_det_params(image)
    contours_detector = ContoursDetector(det_params)
    contours_detector.plot_contours_image(image)
    # Materialise: get_contours may return a one-shot iterator (a Python 3
    # filter object), and the contours are used more than once below.
    contour_list = list(contours_detector.get_contours(image))
    if not contour_list:
        print("No contours found")
        return
    contour_size_estimator = ContourSizeEstimator()
    # Keep the fractional value; it is printed with two decimals.
    est_contour_size = contour_size_estimator.get_estimated_size(contour_list)
    print("Estimated contour size: %.2f" % (est_contour_size))
    fragment_extractor = ImageFragmentExtractor(image, fragment_scale)
    filter_size = min(image.shape[0], image.shape[1])
    print("Filter size: %d" % (filter_size))
    image_fragment_filter = ImageFragmentFilter(filter_size)
    fragment_list = [fragment_extractor.extract_image_fragment(cnt) for cnt in contour_list]
    fragment_list = list(image_fragment_filter.filter_fragments(fragment_list))
    print("Fragments count: %d" % (len(fragment_list)))
    plot_fragments(fragment_list)
        
def test_params(filepath_list):
    for filepath in filepath_list:
        image = cv2.imread(filepath, 1)
        test_image(image)

In [None]:
# Run the pipeline on keyboard0.jpg .. keyboard18.jpg (relative paths).
KEYBOARD_IMAGE_COUNT = 19
# A list rather than map(): on Python 3 a map object is exhausted after one
# pass, so re-running test_params(filepath_list) would silently do nothing.
filepath_list = ["keyboard{0}.jpg".format(i) for i in range(KEYBOARD_IMAGE_COUNT)]
test_params(filepath_list)