In [None]:
import math
import numpy as np
import skimage.io as io
import scipy.stats as stats

In [None]:
IMAGE_PATH = 'image.jpg'

raw_image = io.imread(IMAGE_PATH)
if raw_image.ndim == 2:
    # Already greyscale; cast so later arithmetic such as `1 + image` cannot
    # overflow an integer dtype.
    image = raw_image.astype(np.float64)
else:
    # Average the colour channels into one intensity plane. A trailing alpha
    # channel (grey+alpha or RGBA) carries no intensity, so leave it out.
    n_channels = raw_image.shape[2]
    n_color = n_channels - 1 if n_channels in (2, 4) else n_channels
    image = raw_image[..., :n_color].mean(axis=2)
io.imshow(image, cmap='gray')

In [None]:
shape = image.shape
print(shape)
print(image.flatten().shape)
# Shannon entropy (bits) of the 256-bin intensity histogram. Passing the raw
# pixel values to stats.entropy would treat them as a probability vector over
# pixel positions, which is not the image's intensity entropy.
intensity_counts, _ = np.histogram(image, bins=256, range=(0, 256))
print(stats.entropy(intensity_counts, base=2))

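# Polynomial check

The radial gain $g(r) = 1 + a r^2 + b r^4 + c r^6$ is only admissible if it increases on $r \in [0, 1]$; `check_polynom` tests this condition for given coefficients $(a, b, c)$.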

In [None]:
def check_polynom(coeffs):
    """Return whether g(r) = 1 + a r^2 + b r^4 + c r^6 is strictly increasing on [0, 1].

    g'(r) = 2r * p(r^2) with p(q) = 3c q^2 + 2b q + a, so g is strictly
    increasing on [0, 1] exactly when p(q) > 0 for every q in the open
    interval (0, 1). The branches below are the case analysis C1-C9 of that
    condition on the sign of c and of the discriminant b^2 - 3ac.

    Parameters
    ----------
    coeffs : sequence of three numbers
        The coefficients (a, b, c).

    Returns
    -------
    bool
        May be a ``numpy.bool_`` when the coefficients are numpy scalars.
    """
    a, b, c = coeffs

    if c == 0:
        # p(q) = 2b q + a is linear: positive on (0, 1) iff p(0) >= 0,
        # p(1) >= 0 and p is not identically zero.
        if a > 0 and b == 0:  # C1
            return True
        if a >= 0 and b > 0:  # C2
            return True
        return b < 0 and -a <= 2 * b  # C3: p(1) = a + 2b >= 0

    if c > 0:
        if b ** 2 < 3 * a * c:  # C4: no real roots, p > 0 everywhere
            return True
        if b ** 2 == 3 * a * c:
            # Double root at q0 = -b / (3c); it must lie outside (0, 1).
            return b >= 0 or -b >= 3 * c  # C5: q0 <= 0, C6: q0 >= 1

    if b ** 2 <= 3 * a * c:
        # Only reachable with c < 0: a downward parabola without two distinct
        # real roots is <= 0 everywhere. Returning here also keeps math.sqrt
        # away from a negative discriminant.
        return False

    sqrt_discriminant = math.sqrt(4 * b ** 2 - 12 * a * c)
    root_plus = (-2 * b + sqrt_discriminant) / (6 * c)
    root_minus = (-2 * b - sqrt_discriminant) / (6 * c)

    if c > 0:
        # Upward parabola: root_plus is the larger root. p > 0 on (0, 1) iff
        # both roots are <= 0 (C7) or both roots are >= 1 (C8).
        return root_plus <= 0 or root_minus >= 1
    # Downward parabola: dividing by 6c < 0 makes root_plus the *smaller* root.
    # p > 0 on (0, 1) iff the roots bracket the whole interval (C9).
    return root_plus <= 0 and root_minus >= 1

# Center of mass and log-intensity entropy

In [None]:
def compute_center_of_mass(image):
    """Intensity-weighted centroid of a 2-D image in 0-based (row, col) coordinates.

    The coordinates are the plain array indices, i.e. the same grid obtained by
    iterating ``range(rows)`` / ``range(cols)``. The result is sub-pixel: it is
    not rounded or floored.

    Parameters
    ----------
    image : 2-D array_like
        Non-negative intensities used as weights.

    Returns
    -------
    tuple of float
        (row_center, col_center). If all intensities sum to zero, the geometric
        center ((rows - 1) / 2, (cols - 1) / 2) is returned instead of dividing
        by zero.
    """
    weights = np.asarray(image, dtype=np.float64)
    rows, cols = weights.shape
    total = weights.sum()
    if total == 0:
        return ((rows - 1) / 2, (cols - 1) / 2)
    # Project the mass onto each axis and take the weighted mean of the
    # 0-based indices (1-based indices would shift the centroid by one pixel).
    row_center = np.dot(weights.sum(axis=1), np.arange(rows)) / total
    col_center = np.dot(weights.sum(axis=0), np.arange(cols)) / total
    return (row_center, col_center)

In [None]:
from scipy.ndimage import gaussian_filter

def compute_entropy(image, histogram_length=256, sigma=2.25):
    """Shannon entropy (bits) of the smoothed log-intensity histogram of ``image``.

    Each intensity I is mapped to a fractional bin position

        i = (N - 1) / log2(N) * log2(1 + I),    N = histogram_length,

    so I in [0, N - 1] lands in [0, N - 1]. Each pixel's unit weight is split
    linearly between bins floor(i) (weight 1 - frac(i)) and ceil(i) (weight
    frac(i)). The histogram is then smoothed with a Gaussian of ``sigma`` bins
    and normalised to p. The function returns H = -sum(p * log2(p)), which is
    non-negative, so minimising this value minimises entropy.

    Parameters
    ----------
    image : array_like
        Non-negative intensities of any shape. Values whose bin position
        exceeds N - 1 (e.g. after a gain > 1 has been applied) get extra bins
        instead of wrapping around.
    histogram_length : int, default 256
        Minimum number of bins N; also fixes the log mapping above.
    sigma : float, default 2.25
        Standard deviation of the Gaussian smoothing, in bins.

    Returns
    -------
    float
        The entropy H >= 0, or 0.0 for an empty image.

    Raises
    ------
    ValueError
        If an intensity maps to a negative bin (intensities below 0).
    """
    scale = (histogram_length - 1) / np.log2(histogram_length)
    log_intensity = np.ravel(scale * np.log2(1 + np.asarray(image, dtype=np.float64)))
    if log_intensity.size == 0:
        return 0.0

    lower = np.floor(log_intensity).astype(np.intp)
    upper = np.ceil(log_intensity).astype(np.intp)
    # Fraction of the pixel assigned to the upper bin; 0 when i is an integer,
    # in which case lower == upper and the whole unit lands in one bin.
    upper_weight = log_intensity - lower
    n_bins = max(histogram_length, int(upper.max()) + 1)
    histogram = (np.bincount(lower, weights=1 - upper_weight, minlength=n_bins)
                 + np.bincount(upper, weights=upper_weight, minlength=n_bins))

    histogram = gaussian_filter(histogram, sigma=sigma)
    probabilities = histogram[histogram > 0] / histogram.sum()
    return float(-np.sum(probabilities * np.log2(probabilities)))

# Evaluate the gain polynomial

In [None]:
def compute_polynom(radius, coeffs):
    """Evaluate the radial gain g(r) = 1 + a*r**2 + b*r**4 + c*r**6 element-wise.

    Parameters
    ----------
    radius : scalar or numpy array
        Radius value(s) at which to evaluate the gain.
    coeffs : sequence of exactly three numbers
        The coefficients (a, b, c).

    Returns
    -------
    Same shape as ``radius``: the gain at each radius.
    """
    a, b, c = coeffs
    gain = 1
    # Accumulate the terms left to right: ((1 + a r^2) + b r^4) + c r^6.
    for coefficient, exponent in ((a, 2), (b, 4), (c, 6)):
        gain = gain + coefficient * np.power(radius, exponent)
    return gain

# Compute normalized radius

In [None]:
def compute_radius(image, image_center):
    """Normalised distance of every pixel from ``image_center``, flattened row-major.

    Parameters
    ----------
    image : 2-D array
        Only its shape is used.
    image_center : (float, float)
        (row, col) of the center in 0-based array-index coordinates.

    Returns
    -------
    numpy.ndarray of float32, shape (rows * cols,)
        Distances in ``flatten`` order, divided by the largest distance from
        the center to any pixel. Every value therefore lies in [0, 1], the
        interval on which the gain's monotonicity is checked. The result is all
        zeros when that largest distance is 0 (a 1x1 image centered on its
        pixel) and empty for an empty image.
    """
    rows, cols = image.shape
    row_center, col_center = image_center
    row_offsets = np.arange(rows) - row_center
    col_offsets = np.arange(cols) - col_center
    distances = np.hypot(row_offsets[:, np.newaxis], col_offsets[np.newaxis, :])
    # Normalise by the farthest pixel (always a corner), not by the center's
    # distance to (0, 0): the latter lets radii exceed 1 when the center lies
    # toward the top-left, and is 0 when the center is at (0, 0).
    max_distance = distances.max() if distances.size else 0.0
    if max_distance == 0:
        return np.zeros(rows * cols, dtype=np.float32)
    return (distances / max_distance).astype(np.float32).ravel()

# Vignetting correction

In [None]:
def correct_vignetting(image):
    """Correct radial vignetting with a greedy search over a polynomial gain.

    The gain is g(r) = 1 + a r^2 + b r^4 + c r^6 (``compute_polynom``), with r
    measured from the image's intensity center of mass (``compute_radius``).
    The search works as follows:

    - It starts from (a, b, c) = (0, 0, 0) with step ``delta`` = 8.
    - Each sweep tries every coefficient +/- delta and skips candidates that
      ``check_polynom`` rejects.
    - A candidate whose corrected image has a lower ``compute_entropy`` value
      than the best so far becomes the new best, and the step restarts at 8.
    - The step is halved after every sweep until it is no longer greater
      than 1/256.

    Side effect: displays the gain map of the selected coefficients with
    ``io.imshow``.

    Parameters
    ----------
    image : 2-D numpy array
        Greyscale image.

    Returns
    -------
    numpy.ndarray
        ``image`` multiplied by the gain of the best coefficients found, with
        the same shape as ``image``. If no candidate beats the input, the gain
        is 1 everywhere.
    """
    base_shape = image.shape
    a, b, c, delta = 0.0, 0.0, 0.0, 8.0

    center_of_mass = compute_center_of_mass(image)
    radius = compute_radius(image, center_of_mass)
    best_entropy = compute_entropy(image)
    pixels = image.flatten()

    while delta > 1 / 256:
        candidates = [(a + delta, b, c), (a - delta, b, c),
                      (a, b + delta, c), (a, b - delta, c),
                      (a, b, c + delta), (a, b, c - delta)]
        for vector in candidates:
            if not check_polynom(vector):
                continue
            entropy = compute_entropy(pixels * compute_polynom(radius, vector))
            if entropy < best_entropy:
                # 16 is halved to 8 right after this sweep, restarting the
                # coarse search around the new best coefficients.
                best_entropy, delta, (a, b, c) = entropy, 16.0, vector
        delta /= 2

    # Rebuild the gain from the accepted (a, b, c). The last candidate evaluated
    # inside the loop is not necessarily the best one.
    gain = compute_polynom(radius, (a, b, c))
    io.imshow(gain.reshape(base_shape))
    return (pixels * gain).reshape(base_shape)

In [None]:
# Apply the vignetting correction to the loaded image; the result is plotted below.
test = correct_vignetting(image=image)

In [None]:
import matplotlib.pyplot as plt

# `cm_result` is only assigned by a later cell (`cm_result = test.copy()`), so
# it does not exist on a fresh kernel; show that panel only when available.
panels = [(test, 'Resulting image')]
if 'cm_result' in globals():
    panels.append((cm_result, 'Resulting image with -H'))
panels.append((image, 'Original image'))

fig, axes = plt.subplots(1, len(panels), figsize=(40, 40), squeeze=False)
for ax, (panel_image, title) in zip(axes[0], panels):
    ax.imshow(panel_image, cmap='gray')
    ax.set_title(title)
plt.show()

In [None]:
# Keep an independent (C-ordered) copy of the current correction result.
cm_result = np.copy(test, order='C')