In [None]:
from pathlib import Path

In [None]:
import numpy as np
from PIL import Image
from scipy import fft
from skimage.restoration import inpaint

In [None]:
# Directory of reference photos; their per-pixel quantiles (below) characterise the overlay.
base_dir = Path("photos")

In [None]:
# Only images of exactly this PIL size are used; PIL reports size as (width, height),
# so matching arrays have shape (1000, 666, channels).
size = (666, 1000)

In [None]:
# Load up to `max_count` reference photos of exactly `size` into one (N, H, W, 3) array.
max_count = 1000
# Report progress roughly every 5 %; max(1, ...) avoids a modulo-by-zero when max_count < 20.
progress_step = max(1, max_count // 20)
statistics = []
# Sorted so the same subset is picked on every run (iterdir order is filesystem-dependent).
for path in sorted(base_dir.iterdir()):
    if not path.is_file():
        continue
    try:
        with Image.open(path) as im:
            if im.size != size:
                continue
            # Force 3-channel RGB: np.stack needs identical shapes, and grayscale()
            # later contracts the last axis against exactly three weights.
            statistics.append(np.asarray(im.convert("RGB")))
    except OSError as error:  # also covers PIL's UnidentifiedImageError (non-image files)
        print(f"Skipping {path.name}: {error}")
        continue
    if len(statistics) % progress_step == 0:
        print(f"Loaded {int(100.0 * (len(statistics) / max_count))} %")
    if len(statistics) >= max_count:
        break
if not statistics:
    raise ValueError(f"No {size[0]}x{size[1]} images found in {base_dir}")
statistics = np.stack(statistics)
print(f"Loaded {len(statistics)} images")

In [None]:
# Per-pixel (and per-channel) 1st percentile across the photo stack: the darkest level
# each pixel reliably reaches. Shown as an 8-bit image for a quick visual check.
dark = np.quantile(statistics, q=0.01, axis=0)
Image.fromarray(dark.astype("uint8"))

In [None]:
# Per-pixel (and per-channel) 99th percentile across the photo stack: the brightest level
# each pixel reliably reaches, ignoring the top 1 %. Shown as an 8-bit image.
light = np.quantile(statistics, q=0.99, axis=0)
Image.fromarray(light.astype("uint8"))

In [None]:
# Release the stacked photos (the largest object here); only `dark` and `light` are
# needed from this point on.
del statistics

In [None]:
# Rec. 709 luma coefficients for the (R, G, B) channels.
_LUMA_WEIGHTS = np.array([0.2126, 0.7152, 0.0722])


def grayscale(image):
    """Collapse the last (RGB) axis of ``image`` into a single weighted luma value.

    The result has the last axis removed, e.g. ``(H, W, 3) -> (H, W)``, and is float.
    """
    return np.inner(image, _LUMA_WEIGHTS)

In [None]:
# Estimate the overlay as an RGBA image. Assumed model (implied by the arithmetic): an
# overlay of opacity a compresses a pixel's 0..255 range to (1 - a) * 255, so the
# opacity in 0..255 units is 255 minus the observed light/dark span.
alpha = (255.0 - (light - dark))

# Collapse alpha to a single channel, shape (H, W, 1), using luma weights.
# Alternative: a plain channel mean, np.mean(alpha, axis=2, keepdims=True).
alpha = np.expand_dims(grayscale(alpha), axis=2)

# Un-premultiply the overlay colour. Where alpha == 0 the colour is irrelevant (fully
# transparent); `where=` leaves 0 there instead of 0/0 = NaN, which would otherwise be
# cast to an undefined uint8 value below.
color = np.divide(
    255.0 * dark, alpha, out=np.zeros_like(dark, dtype=float), where=alpha > 0
).clip(0.0, 255.0)
overlay = np.concatenate([color, alpha], axis=2)
Image.fromarray(overlay.astype(np.uint8))  # append .save("mask.png") to export it

In [None]:
# Pixels whose estimated overlay opacity exceeds this fraction are treated as unrecoverable
# and are later filled by inpainting (fill_mask). Result is a boolean (H, W) array.
mask_threshold = 0.66
mask = np.squeeze(alpha / 255.0 > mask_threshold, axis=2)
Image.fromarray(mask.astype(np.uint8) * 255)

In [None]:
# Images to clean are read from in_dir; cleaned PNGs are written to out_dir.
in_dir, out_dir = Path("input"), Path("output")
out_dir.mkdir(exist_ok=True)  # relative to the working directory; no-op if it exists

In [None]:
# Load every input image whose size matches the reference photos, keyed by file stem.
images = {}
# Sorted so processing (and the image displayed at the end) is the same on every run.
for path in sorted(in_dir.iterdir()):
    if not path.is_file():
        continue
    try:
        with Image.open(path) as im:
            if im.size != size:
                continue
            # 3-channel RGB, matching the layout grayscale() assumes for dark/light,
            # so subtract() broadcasts instead of failing on L/P/RGBA inputs.
            pixels = np.asarray(im.convert("RGB"))
    except OSError as error:  # also covers PIL's UnidentifiedImageError (non-image files)
        print(f"Skipping {path.name}: {error}")
        continue
    # Outputs are saved as <stem>.png, so a second file with the same stem would
    # silently replace the first; keep the first and report the clash instead.
    if path.stem in images:
        print(f"Skipping {path.name}: another input already uses the name {path.stem!r}")
        continue
    images[path.stem] = pixels
print(f"Loaded {len(images)} images")

In [None]:
def subtract(image, low=None, high=None):
    """Stretch each pixel so that ``low`` maps to 0 and ``high`` maps to 255.

    Computes ``255 * (image - low) / (high - low)`` element-wise; results are not clipped.

    Parameters
    ----------
    image : array-like
        Image to correct. Converted to float first, so integer input cannot wrap around.
    low, high : array-like, optional
        Per-pixel reference levels broadcastable against ``image``. They default to the
        notebook-level ``dark`` and ``light`` quantile frames.

    Returns
    -------
    numpy.ndarray
        Float array. Where ``high == low`` the ratio is undefined and 0 is returned
        instead of inf/NaN, which would otherwise leak into inpainting and the uint8 cast.
    """
    low = dark if low is None else low
    high = light if high is None else high
    offset = 255.0 * (np.asarray(image, dtype=float) - low)
    span = np.asarray(high, dtype=float) - low
    result = np.zeros(np.broadcast(offset, span).shape)
    np.divide(offset, span, out=result, where=span != 0)
    return result

In [None]:
def circle(size, radius):
    """Boolean disc mask of shape ``(H, W, 1)`` for a PIL-style ``size = (W, H)``.

    Pixel centres are measured from the array centre in units of ``sqrt(W * H)`` pixels
    (the same scale on both axes, so the disc is round in pixel space). A pixel is True
    when that distance is below ``radius / 2``, so ``radius`` acts as the disc's
    diameter in those units.

    NOTE(review): for an even W or H the centre sits half a pixel away from the index
    ``fftshift`` uses for the zero frequency; likely negligible, but confirm.
    """
    width, height = size[0], size[1]
    scale = np.sqrt(width * height)
    x = ((np.arange(width, dtype="float") + 0.5) / width - 0.5) * (width / scale)
    y = ((np.arange(height, dtype="float") + 0.5) / height - 0.5) * (height / scale)
    distance = np.sqrt(x[np.newaxis, :] ** 2 + y[:, np.newaxis] ** 2)
    return (distance < radius / 2)[:, :, np.newaxis]

# Preview the disc:
#Image.fromarray(255 * circle(size, 0.5).squeeze(2).astype(np.uint8))

In [None]:
def low_pass_filter(image, cutoff=0.66):
    """Remove high spatial frequencies from ``image`` with a centred circular mask.

    Parameters
    ----------
    image : numpy.ndarray
        ``(H, W)`` or ``(H, W, C)`` array; the 2-D FFT runs over the first two axes.
    cutoff : float, default 0.66
        Passed to ``circle`` as its ``radius`` argument (relative size of the kept disc).

    Returns
    -------
    numpy.ndarray
        Real-valued filtered image with the same shape as ``image``.
    """
    axes = (0, 1)
    height, width = image.shape[:2]
    # Size the mask from the image itself rather than the global `size`, so images of
    # any resolution work. circle() takes PIL-style (width, height) and returns (H, W, 1).
    keep = circle((width, height), cutoff)
    if image.ndim == 2:
        keep = keep[..., 0]
    # Centre the zero frequency so the centred disc selects the low frequencies.
    freqs = fft.fftshift(fft.fft2(image, axes=axes), axes=axes)
    freqs *= keep
    return np.real(fft.ifft2(fft.ifftshift(freqs, axes=axes), axes=axes))

In [None]:
def fill_mask(image):
    """Biharmonically inpaint the pixels selected by the notebook-level ``mask``.

    ``image`` must keep its channels on axis 2 (``channel_axis=2``). ``mask`` is the
    boolean ``(H, W)`` array derived from the overlay alpha.
    """
    filled = inpaint.inpaint_biharmonic(image, mask, channel_axis=2)
    return filled

In [None]:
def interpolate(x, y, alpha):
    """Linear blend of ``x`` and ``y``: ``alpha == 0`` gives ``x``, ``alpha == 1`` gives ``y``.

    Element-wise under NumPy broadcasting, so ``alpha`` may be a per-pixel array.
    """
    keep = 1.0 - alpha
    return keep * x + alpha * y

In [None]:
# Clean every loaded input and save it to out_dir as PNG; the last result is displayed.
output = None  # stays None (nothing displayed) when no inputs were loaded
for name, image in images.items():
    image = subtract(image)
    # Optional (disabled): blend in a low-pass-filtered copy, weighted per pixel by the
    # overlay alpha.
    #image = interpolate(image, low_pass_filter(image), alpha / 255.0)
    image = fill_mask(image)
    output = Image.fromarray(image.clip(0, 255).astype(np.uint8))
    output.save(out_dir / (name + ".png"))
output