In [1]:
# Mount Google Drive so the project folder is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Work inside the project folder; relative paths used later in the notebook
# (e.g. the 'illuminated.png' / 'dehazed.png' outputs) resolve here.
%cd /content/drive/My\ Drive/Colab\ Notebooks/Rootee

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/Colab Notebooks/Rootee


In [2]:
import numpy as np
import scipy.ndimage as ndi
from scipy import signal
import cv2
import logging
log = logging.getLogger(__name__)
import PIL
from matplotlib import pyplot as plt
import scipy.stats as stats
import glob
from PIL import Image


def get_background(img, is_01_normalized=True):
    """Boolean mask that is True outside the detected circular field of view.

    This is the element-wise complement of ``get_foreground`` called with the
    same arguments.
    """
    foreground = get_foreground(img, is_01_normalized)
    return np.invert(foreground)

def get_foreground(img, is_01_normalized=True):
    """Boolean mask that is True inside the detected circular field of view.

    The image itself is not cropped; only the mask is returned.
    """
    _image, foreground_mask = center_crop_and_get_foreground_mask(
        img, crop=False, is_01_normalized=is_01_normalized)
    return foreground_mask

def get_center_circle_coords(im, is_01_normalized: bool):
    """Locate the circular field of view in ``im`` with a Hough transform.

    Every channel is smoothed with ``signal.cspline2d``.  The channel-wise max
    of the normalized result is then searched for circles whose radius lies
    between a quarter of the shorter side and three quarters of the longer
    side.  If that finds nothing, the search is retried on the coarse
    foreground mask from ``get_foreground_slow`` with a lower accumulator
    threshold (``param2``).

    Args:
        im: image indexed as ``im[:, :, channel]``.
        is_01_normalized: True if ``im`` is scaled to [0, 1]; it is then
            multiplied by 255 before smoothing.

    Returns:
        ``(x, y, r)``: integer center column, center row and radius of the
        largest circle found.

    Raises:
        RuntimeError: if neither detection pass finds a circle.
    """
    scale = 255 if is_01_normalized else 1
    A = np.dstack([
        signal.cspline2d(im[:, :, ch] * scale, 200.0)
        for ch in range(im.shape[-1])])
    min_r = int(min(im.shape[0], im.shape[1]) / 4)
    max_r = int(max(im.shape[0], im.shape[1]) / 4*3)
    min_dist = min(A.shape[:2])
    # HoughCircles returns None (it does not raise) when no circle is found,
    # so test for that explicitly instead of hiding it behind a bare `except`.
    circles = cv2.HoughCircles(
        (norm01(A).max(-1)*255).astype('uint8'), cv2.HOUGH_GRADIENT, .8,
        min_dist, param1=20, param2=50, minRadius=min_r, maxRadius=max_r)
    if circles is None:
        log.warning('center_crop_and_get_foreground_mask failed to get background - trying again with looser parameters')
        A2 = get_foreground_slow(im)
        circles = cv2.HoughCircles(
            (A2*255).astype('uint8'), cv2.HOUGH_GRADIENT, .8,
            min_dist, param1=20, param2=10, minRadius=min_r, maxRadius=max_r)
        if circles is None:
            raise RuntimeError(
                'get_center_circle_coords: no circle found, even with looser parameters')
    circles = circles[0]
    # keep the circle with the largest radius
    x, y, r = circles[circles[:, -1].argmax()].round().astype('int')
    return x, y, r

def get_foreground_mask_from_center_circle_coords(shape, x, y, r):
    """Rasterize a filled circle into a boolean mask of the given shape.

    ``(x, y)`` is the center as (column, row), the order ``cv2.circle``
    expects; ``r`` is the radius in pixels.
    """
    canvas = np.zeros(shape, dtype=np.uint8)
    center = (x, y)
    cv2.circle(canvas, center, r, 255, cv2.FILLED)
    return canvas.astype(bool)

def center_crop_and_get_foreground_mask(im, crop=True, is_01_normalized=True, center_circle_coords=None, label_img=None):
    """Find the circular field of view and return its mask, optionally cropping.

    Args:
        im: image whose first two axes are (rows, columns).
        crop: if True, ``im``, the mask and ``label_img`` are cropped to the
            square ``[y-r, y+r) x [x-r, x+r)`` clipped to the image bounds.
            If False, they are returned uncropped.
        is_01_normalized: forwarded to ``get_center_circle_coords``.
        center_circle_coords: optional precomputed ``(x, y, r)``; when given,
            circle detection is skipped.
        label_img: optional array indexed with the same row/column slices as
            ``im``.

    Returns:
        ``[im, mask]``, or ``[im, mask, label_img]`` when ``label_img`` is
        given (all cropped when ``crop``).  ``mask`` is a boolean array that
        is True inside the circle.
    """
    if center_circle_coords is not None:
        x, y, r = center_circle_coords
    else:
        x, y, r = get_center_circle_coords(im, is_01_normalized)
    # Needed for cropping whether or not the coords were supplied by the caller.
    h, w = im.shape[:2]
    mask = get_foreground_mask_from_center_circle_coords((h, w), x, y, r)
    if crop:
        crop_slice = np.s_[max(0, y-r):min(h, y+r), max(0, x-r):min(w, x+r)]
        rv = [im[crop_slice], mask[crop_slice]]
        if label_img is not None:
            rv.append(label_img[crop_slice])
    else:  # don't crop.  just get the mask.
        rv = [im, mask]
        if label_img is not None:
            rv.append(label_img)
    return rv

def get_background_slow(img):
    """Threshold-based background mask (no circle fitting).

    The image is scaled by its maximum.  In each channel, pixels darker than
    20/255 are marked, the marks are cleaned with a per-channel binary
    closing, and a 3-pixel frame around the image is forced to background.
    A pixel is background only if it is background in every channel.

    Args:
        img: array of shape (h, w, channels) or (h, w); any number of
            channels is accepted.

    Returns:
        Boolean array of shape (h, w), True for background.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    peak = img.max()
    if peak != 0:  # an all-black image would otherwise become 0/0 = NaN
        img = img / peak
    background = img < 20/255
    # (5, 5, 1) structure: close spatially, never across channels.
    background = ndi.binary_closing(background, np.ones((5, 5, 1)))
    # binary_closing erodes with border_value=0, which clears pixels near the
    # image edge; force a 3-pixel frame back to background.
    frame = np.ones(background.shape, dtype=bool)
    frame[3:-3, 3:-3, :] = False
    background |= frame
    return background.all(axis=2)

def get_foreground_slow(img):
    """Complement of ``get_background_slow``: True for foreground pixels."""
    background_mask = get_background_slow(img)
    return np.logical_not(background_mask)

def zero_mean(img, fg):
    """Shift ``img`` so that its mean over the ``fg`` mask becomes 0.5.

    Despite the name, the result is centered on 0.5, not on 0.
    """
    fg_mean = img[fg].mean()
    centered = img - fg_mean
    return centered + 0.5

def norm01(img, background=None):
    """Normalize ``img`` into [0, 1] using its global min and max.

    Args:
        img: numeric array.
        background: optional boolean mask used to index ``img``.  Masked
            pixels are excluded when computing min/max and keep their original
            values in the output.

    Returns:
        Float array shaped like ``img``.  A constant input (max == min) maps
        to 0 instead of NaN.
    """
    if background is not None:
        tmp = img[~background]
        min_, max_ = tmp.min(), tmp.max()
    else:
        min_, max_ = img.min(), img.max()
    value_range = max_ - min_
    # Divide by 1 for constant input: keeps the same float dtype promotion as
    # the normal path while avoiding 0/0.
    rv = (img - min_) / (value_range if value_range != 0 else 1)
    if background is not None:
        rv[background] = img[background]
    return rv

In [3]:
def get_dark_channel(
        img: np.ndarray, filter_size: int):
    """Dark channel: per-pixel minimum over channels, then a local minimum.

    The local window is not the full ``filter_size`` square.  It keeps the
    cells where the log of a separable Gaussian (mean 0.5, std 0.125, sampled
    on ``linspace(0, 1, filter_size)``) exceeds -6.  That is roughly the disk
    inscribed in the window: the edge midpoints are included, the corners
    are not.
    """
    weights_1d = stats.norm.pdf(np.linspace(0, 1, filter_size), .5, .25/2)
    footprint = np.log(np.outer(weights_1d, weights_1d)) > -6
    per_pixel_min = img.min(-1)
    return ndi.minimum_filter(per_pixel_min, footprint=footprint)

def get_atmosphere(img: np.ndarray, dark: np.ndarray, quantile: float = 0.999):
    """Estimate the atmospheric light vector of an image.

    The pixels whose dark-channel value is at or above the ``quantile``
    quantile are selected (the brightest 0.1% with the default).  Their
    intensities in ``img`` are looked up and the per-channel maximum is
    taken.  The vector is then shifted so its largest component equals 1.

    Args:
        img: image of shape (h, w, 3).
        dark: dark channel of shape (h, w).
        quantile: dark-channel quantile used as the selection threshold.

    Returns:
        Array of shape (3,).

    Raises:
        AssertionError: if ``img`` is not of shape (h, w, 3).
    """
    # Sanity check before any channel indexing, so bad input fails here.
    assert img.ndim == 3 and img.shape[2] == 3
    # Subtract a small tolerance so pixels tied with the quantile are included.
    q = np.quantile(dark.ravel(), quantile) - 1e-6
    mask = dark >= q
    rv = img[mask].max(axis=0)  # per-channel max over the selected pixels
    # Shift so the brightest channel is exactly 1 (author: "seems to make img
    # brighter").  Not in-place, so integer images don't hit a casting error.
    rv = rv + (1 - rv.max())
    return rv

def dehaze(img, dark_channel_filter_size=15, guided_filter_radius=50,
           guided_eps=1e-2):
    """Remove haze using the dark channel prior and a guided-filter refinement.

    Args:
        img: image of shape (h, w, 3); it is first rescaled by its maximum.
        dark_channel_filter_size: ``filter_size`` passed to ``get_dark_channel``.
        guided_filter_radius: radius for ``cv2.ximgproc.guidedFilter``.
        guided_eps: regularization ``eps`` for the guided filter.

    Returns:
        ``locals()``: a dict keyed by every local name of this function, i.e.
        the parameters plus ``darkch_unnorm``, ``A``, ``t_unrefined``,
        ``t_refined`` and ``radiance``.  ``radiance`` is the dehazed image,
        clipped to [0, 1].  NOTE: renaming any local changes these keys.
    """
    img = img / img.max()
    darkch_unnorm = get_dark_channel(img, dark_channel_filter_size)
    # atmospheric light, reshaped to broadcast over (h, w, 3)
    A = get_atmosphere(img, darkch_unnorm).reshape(1, 1, 3)

    # coarse transmission map: 1 - dark channel of the A-normalized image
    t_unrefined = 1 - get_dark_channel(img / A, dark_channel_filter_size)

    # refine the transmission with a guided filter, using the image as guide
    t_refined = cv2.ximgproc.guidedFilter(
        img.astype('float32'),
        t_unrefined.astype('float32'), guided_filter_radius, guided_eps)
    # guided filter can make slightly >1; the lower bound also keeps the
    # division below away from zero
    t_refined = t_refined.clip(0.0001, 1)

    # recover scene radiance: (I - A) / t + A
    radiance = (  # Eq. 22 of paper
        img.astype('float')-A) \
        / np.expand_dims(t_refined, -1).astype('float') \
        + A

    radiance = radiance.clip(0, 1)
    return locals()

def illumination_correction(img, dark_channel_filter_size=25,
                            guided_filter_radius=80, guided_eps=1e-2, A=1):
    """Correct uneven illumination by dehazing the inverted image.

    Illumination correction is basically just inverted dehazing.  The dark
    channel of ``1 - img`` gives a transmission estimate, which is refined with
    a guided filter.  The radiance is recovered in the inverted domain and
    then inverted back.

    Args:
        img: image of shape (h, w, channels); it is first rescaled by its max.
        dark_channel_filter_size: ``filter_size`` passed to ``get_dark_channel``.
        guided_filter_radius: radius for ``cv2.ximgproc.guidedFilter``.
        guided_eps: regularization ``eps`` for the guided filter.
        A: atmospheric light in the inverted domain (default 1).

    Returns:
        ``locals()``: a dict keyed by the parameters plus ``t_unrefined``,
        ``t_refined`` and ``radiance``.  Unlike ``dehaze``, ``radiance`` is
        not clipped and can exceed 1 (with ``A=1`` it equals
        ``img / t_refined``).  NOTE: renaming any local changes these keys.
    """
    img = img / img.max()

    # dark channel of the inverted image; the `1 -` that turns it into a
    # transmission map is applied after guided filtering below
    t_unrefined = get_dark_channel((1-img) / A, dark_channel_filter_size)
    # transmission = 1 - guided-filtered dark channel (inverted image as guide)
    t_refined = 1-cv2.ximgproc.guidedFilter(
        1-img.astype('float32'),
        t_unrefined.astype('float32'), guided_filter_radius, guided_eps)
    # the filter output may overshoot 1, making t slightly negative; clip to a
    # small positive floor so the division below stays finite
    t_refined = t_refined.clip(0.00001, 1)
    # recover radiance of the inverted image, then invert it back
    radiance = 1 - (((1-img.astype('float')) - A)/np.expand_dims(t_refined, -1) + A)

    return locals()

def dehaze_from_fp(fp):
    """Load an image file and dehaze it, treating it as a retinal fundus image.

    The file is decoded as RGB and scaled to [0, 1].  Pixels outside the
    detected circular field of view are set to 1 before calling ``dehaze``.

    Returns:
        The dict returned by ``dehaze``.
    """
    # Convert while the file is still open: reading pixels after the context
    # manager has closed the image is not guaranteed to work across Pillow
    # versions.  RGB conversion also guarantees the 3 channels dehaze needs.
    with PIL.Image.open(fp) as pil_img:
        img = np.array(pil_img.convert('RGB')) / 255
    # remove background, assuming retinal fundus image
    background = get_background(img)
    img[background] = 1
    return dehaze(img)

def illuminate_from_fp(fp):
    """Load an image file as RGB in [0, 1] and run ``illuminate_dehaze`` on it.

    Returns:
        The ``(illumination_dict, dehaze_dict)`` tuple from ``illuminate_dehaze``.
    """
    # Decode inside the `with` block so the pixels are read before the file is
    # closed; RGB conversion guarantees the 3 channels the pipeline expects.
    with PIL.Image.open(fp) as pil_img:
        img = np.array(pil_img.convert('RGB')) / 255
    return illuminate_dehaze(img)

def illuminate_dehaze(img):
    """Illumination correction (to remove shadows) followed by dehazing.

    ``img`` is treated as [0, 1]-normalized, because ``get_background`` is
    called with its default.  Background pixels (outside the detected circle)
    are set to 1 before correction.  In the corrected radiance they are reset
    to 1/255 before dehazing.  The caller's array is not modified.

    Returns:
        ``(d, d2)``.  ``d`` is the dict from ``illumination_correction`` with
        an added ``'background'`` mask.  ``d2`` is the dict from ``dehaze``
        applied to ``d['radiance']``.
    """
    # Work on a copy: the background fill below would otherwise overwrite the
    # caller's image.
    img = np.array(img, copy=True)
    # compute a background mask to clean up noise from the guided filter
    background = get_background(img)
    img[background] = 1

    d = illumination_correction(img)
    # reset the background
    d['radiance'][background] = 1/255

    d2 = dehaze(d['radiance'])

    d['background'] = background
    return d, d2


if __name__ == "__main__":
    fps_grade3 = glob.glob('/content/drive/My Drive/Colab Notebooks/Rootee/poorimage.jpg')
    if not fps_grade3:
        # Fail with a clear message instead of an IndexError on [0].
        raise FileNotFoundError(
            'poorimage.jpg not found in /content/drive/My Drive/Colab Notebooks/Rootee')
    fp = fps_grade3[0]

    # Original image, kept only for the side-by-side figure (illuminate_from_fp
    # re-reads the file).  Decode inside the `with` so pixels are read before
    # the file is closed.
    with PIL.Image.open(fp) as pil_img:
        img = np.array(pil_img.convert('RGB')) / 255
    d, d2 = illuminate_from_fp(fp)

    illuminated = d['radiance']
    print(illuminated.shape)
    # illumination_correction does not clip its radiance (values can exceed 1);
    # clip before the uint8 cast so bright pixels don't wrap around.
    illuminated_display = np.clip(illuminated, 0, 1)
    im = Image.fromarray((illuminated_display * 255).astype(np.uint8))
    im.save('illuminated.png')
    dehazed = d2['radiance']  # dehaze() already clips to [0, 1]
    im2 = Image.fromarray((dehazed * 255).astype(np.uint8))
    im2.save('dehazed.png')

    f, axs = plt.subplots(1, 3)
    panels = [(img, 'Original Image'),
              (illuminated_display, "Illuminated"),
              (dehazed, "Dehazed")]
    for ax, (panel, title) in zip(axs, panels):
        ax.imshow(panel)
        ax.set_title(title)
    plt.show()

OSError: ignored