In [11]:
import cv2
import rawpy
import os
import os.path as path
import glob
import numpy as np
import scipy.spatial
import ipdb
import os
import platform
import multiprocessing
import json
from functools import lru_cache
from collections import namedtuple
from matplotlib import pyplot as plot

In [12]:
%matplotlib qt

In [13]:
def visualise_clusters(clusters):
    """Draw every cluster's medium icons on one canvas and show it with OpenCV.

    Icons are placed top to bottom in columns of at most four. A vertical
    offset (``hr_offset``) grows by ``hr_height`` after every cluster and is
    reset whenever a new column starts, which leaves a visible gap between
    clusters that share a column. The canvas is shown in a window titled
    "Clusters".

    Args:
        clusters: Sequence of clusters, each a sequence of objects exposing an
            ``icon_medium`` array. Icons are assumed to be square, three
            channel, and all the same size as the first one.

    Returns:
        None. Returns without opening a window when there are no images.
    """
    images = [image for cluster in clusters for image in cluster]
    if not images:
        # Nothing to draw (also covers an empty first cluster, which used to
        # crash on clusters[0][0]).
        return

    image_size = images[0].icon_medium.shape[0]
    number_of_images = len(images)
    hr_height = 25
    vr_length = 25

    max_images_per_column = 4
    columns = number_of_images // max_images_per_column + 1

    width = (image_size + vr_length) * columns
    # At most max_images_per_column icons are ever stacked in one column, so
    # the canvas only needs that many rows plus room for the cluster gaps.
    # Sizing it by the total image count allocated an enormous float buffer.
    rows = min(number_of_images, max_images_per_column)
    height = image_size * rows + hr_height * len(clusters)
    buffer = np.zeros((height, width, 3))

    i = 0
    hr_offset = 0
    for cluster in clusters:
        for iImage, image in enumerate(cluster):
            column, row = divmod(i, max_images_per_column)
            # A new column starts at every multiple of max_images_per_column
            # except the very first image (i == 0), which keeps offset 0.
            if i > 0 and row == 0:
                hr_offset = hr_height if iImage == 0 else 0
            row_offset = row * image_size + hr_offset
            column_offset = column * image_size
            buffer[row_offset:row_offset + image_size,
                   column_offset:column_offset + image_size, :] = image.icon_medium
            i = i + 1
        hr_offset = hr_offset + hr_height

    # Scale to [0, 1] for float display; skip the division for an all-black
    # canvas, where it would produce NaNs.
    peak = np.max(buffer)
    if peak > 0:
        buffer = buffer / peak
    # The icons are built from rawpy's RGB output, while cv2.imshow interprets
    # channels as BGR, so flip the channel axis before displaying.
    cv2.imshow("Clusters", np.ascontiguousarray(buffer[:, :, ::-1]))
    cv2.waitKey(1)

# Load images defined in groups generated in Image Grouping notebook
This time, a high-resolution version of each image is also loaded, because sharpness cannot be measured reliably on small images.

In [14]:
print("Loading images")

# Groups of image file names written by the Image Grouping notebook.
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's default text encoding.
with open("grouped_images.json", encoding="utf-8") as file:
    groups = json.load(file)
    
# Record describing one loaded photo: the square high-resolution crop, three
# icon sizes, the feature array, the source file name and its timestamp.
Image = namedtuple(
    'Image',
    [
        'cropped_high_resolution',
        'icon',
        'icon_medium',
        'icon_large',
        'features',
        'filename',
        'timestamp',
    ],
)

def creation_date(path_to_file):
    """
    Return the creation timestamp of a file, or the closest available substitute.

    On Windows the value of ``os.path.getctime`` is returned. On other
    platforms the ``st_birthtime`` field of ``os.stat`` is used when it is
    available; otherwise the last-modification time is returned instead.
    Background: http://stackoverflow.com/a/39501288/1709587
    """
    if platform.system() == 'Windows':
        return os.path.getctime(path_to_file)

    file_stat = os.stat(path_to_file)
    if hasattr(file_stat, 'st_birthtime'):
        return file_stat.st_birthtime
    # Probably Linux: there is no easy way to get a creation date there, so
    # fall back to the time the content was last modified.
    return file_stat.st_mtime

# Worker pool sized at twice the CPU count reported by multiprocessing.
# NOTE(review): nothing in the visible cells submits work to this pool or
# closes it, so its worker processes presumably stay alive for the life of
# the kernel — confirm whether it is still needed.
cpu_pool = multiprocessing.Pool(multiprocessing.cpu_count() * 2)

def load_image(image_filename):
    """Decode a RAW file and build the ``Image`` record used by this notebook.

    The RAW file is developed to 8-bit RGB using the camera white balance.
    A centred square crop of the full-resolution picture is kept for the
    sharpness measurement, and three square icons (100, 200 and 300 pixels)
    are resized from the full, uncropped frame.

    Args:
        image_filename: Path of the RAW file to load.

    Returns:
        An ``Image`` with the square high-resolution crop, the three icons,
        the feature array (the 100-pixel icon), the file name and the
        timestamp reported by ``creation_date``.
    """
    # The context manager closes the RAW file once it has been developed;
    # the array returned by postprocess() is used after the file is closed.
    with rawpy.imread(image_filename) as raw_image:
        rgb_image = raw_image.postprocess(use_camera_wb=True, output_bps=8)

    # Centre-crop to a square along the longer side. Landscape frames get the
    # same crop as before; portrait and square frames, which previously
    # aborted the whole load with Exception("Oops"), are now handled too.
    height, width = rgb_image.shape[:2]
    side = min(height, width)
    row_start = height // 2 - side // 2
    column_start = width // 2 - side // 2
    cropped_rgb_image = rgb_image[row_start:row_start + side,
                                  column_start:column_start + side]

    # NOTE(review): the icons are resized from the uncropped frame, so they
    # do not keep the aspect ratio — left unchanged in case other notebooks
    # rely on it; confirm whether that is intended.
    icon = cv2.resize(rgb_image, (100, 100))
    icon_medium = cv2.resize(rgb_image, (200, 200))
    icon_large = cv2.resize(rgb_image, (300, 300))
    features = icon
    timestamp = creation_date(image_filename)

    return Image(cropped_rgb_image, icon, icon_medium, icon_large, features, image_filename, timestamp)

# Swap every file name in every group for its loaded Image record,
# keeping the group structure and order.
groups = [
    [load_image(image_filename) for image_filename in group]
    for group in groups
]

print("Done!")

Loading images
Done!


# Measure image sharpness
This function was derived from the paper "Image Sharpness Measure for Blurred Images in Frequency
Domain" by Kanjar De and V. Masilamani. See https://ac.els-cdn.com/S1877705813016007/1-s2.0-S1877705813016007-main.pdf?_tid=f95122a8-c9be-45ec-90a7-e15fb5ecaed3&acdnat=1551129901_433e1bc6f9e722f8250c67330dc43d4a

Input: Image I of size M×N.

Output: Image Quality measure (FM) where FM stands for Frequency Domain Image Blur Measure

1. Compute F which is the Fourier Transform representation of image I
2. Find Fc which is obtained by shifting the origin of F to centre.
3. Calculate AF = abs (Fc) where AF is the absolute value of the centered Fourier transform of image I.
4. Calculate M = max (AF) where M is the maximum value of the frequency component in F.
5. Calculate TH = the total number of pixels in F whose pixel value > thres, where thres = M/1000.
6. Calculate Image Quality measure (FM) from equation (1). 

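Image Quality Measure (FM) = TH / (M * N), where M × N is the size of the image in pixels (here M is the number of rows from the input definition, not the maximum value computed in step 4).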

In [15]:
def image_sharpness(image_mat):
    """Frequency-domain sharpness measure (FM) of an image.

    The measure is the fraction of coefficients of the centred 2-D Fourier
    spectrum whose magnitude exceeds one thousandth of the largest magnitude.

    Args:
        image_mat: Greyscale image of shape (rows, columns), or a
            multi-channel image of shape (rows, columns, channels). Channels
            are averaged to a single greyscale plane before the transform.

    Returns:
        The FM value, between 0 and 1.

    Raises:
        ValueError: If ``image_mat`` is neither 2-D nor 3-D.
    """
    image = np.asarray(image_mat, dtype=np.float64)
    if image.ndim == 3:
        # np.fft.fft2 transforms the LAST two axes, so a (rows, columns,
        # channels) array would be transformed over columns x channels
        # instead of the spatial axes. Collapse the channels first.
        image = image.mean(axis=2)
    elif image.ndim != 2:
        raise ValueError(
            "image_sharpness expects a 2-D or 3-D image, got {0} dimensions".format(image.ndim)
        )

    fourier_transform = np.abs(np.fft.fftshift(np.fft.fft2(image)))
    threshold = np.max(fourier_transform) / 1000
    th = np.sum(fourier_transform > threshold)

    # Normalise by the number of spectrum coefficients, which equals
    # rows * columns of the greyscale image.
    return th / fourier_transform.size


In [16]:
# sharpness = list(map(lambda group: list(map(lambda image: image_sharpness(image.cropped_high_resolution), group)), groups))

In [17]:
# Order each group from highest to lowest sharpness score. The ascending sort
# is reversed afterwards (rather than sorting with reverse=True) so that
# images with equal scores keep the same relative order as before.
sorted_groups = [
    sorted(group, key=lambda image: image_sharpness(image.cropped_high_resolution))[::-1]
    for group in groups
]

In [18]:
# Display the groups, each one ordered by descending sharpness score.
visualise_clusters(sorted_groups)

In [19]:
# Keep only the file names so the sorted groups can be serialised as JSON.
textified_groups = [
    [image.filename for image in group]
    for group in sorted_groups
]

with open('grouped_and_sorted_images.json', 'w') as outfile:
    json.dump(textified_groups, outfile)

In [None]:
target_folder = "target_folder"

# Dry run: print (do not execute) the shell commands that would create one
# folder per group and move each group's images into it.
for i, group in enumerate(sorted_groups):
    group_folder = path.join(target_folder, str(i))
    print("mkdir {0}".format(group_folder))
    for image in group:
        # The Image record stores the source path in `filename`; it has no
        # `filepath` field, so the previous attribute access raised
        # AttributeError.
        print("mv {0} {1}".format(image.filename, group_folder + "/"))