In [1]:
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt
import glob, os

import ipywidgets as widgets
import IPython.display as disp
from tqdm.notebook import tqdm

# list available backends
# 'widget' backend allows interactive plots within the notebook
# Matplotlib is convenient for zooming into the image and viewing pixel coordinates/values
%matplotlib -l      
%matplotlib widget


Available matplotlib backends: ['tk', 'gtk', 'gtk3', 'gtk4', 'wx', 'qt4', 'qt5', 'qt6', 'qt', 'osx', 'nbagg', 'notebook', 'agg', 'svg', 'pdf', 'ps', 'inline', 'ipympl', 'widget']


In [2]:
# Input configuration.
# SOURCE_PATH : sorted list of every frame to be merged into the composite.
# BASE_PATH   : the frame that seeds the composite.
SOURCE_PATH = sorted(glob.glob("./data/JPEG/*.jpg"))
BASE_PATH   = "./data/JPEG/4J7A6511.jpg"

# The paths above are relative, so show what they are resolved against.
print("Current Working Directory :", os.getcwd())

# Fail early, with an actionable message, if either input is missing.
if not SOURCE_PATH:
    raise FileNotFoundError("Please set a valid image sequence path for SOURCE_PATH")
if not os.path.isfile(BASE_PATH):
    raise FileNotFoundError("Please set a valid image path for BASE_PATH")

# The base frame already seeds the composite, so drop it from the sequence
# if the glob picked it up. Paths are compared in absolute form so that
# "./data/x.jpg" and "data/x.jpg" are recognised as the same file.
_base_abs = os.path.abspath(BASE_PATH)
SOURCE_PATH = [path for path in SOURCE_PATH if os.path.abspath(path) != _base_abs]


Current Working Directory : /Users/gautamd/Home/github/startrail-merger


In [3]:
# Load the saved Keras model that predicts the per-pixel mask for each tile.
MODEL_PATH = 'data/CNNmodel'
nn = tf.keras.models.load_model(MODEL_PATH)

# NOTE(review): tiles are fed to the model as full[rows, cols, bands], so
# axes 1 and 2 of the batch are (height, width); the "(w, h)" order in the
# label below looks swapped — confirm before relying on it.
model_input_shape = nn.input_shape
print("Expected dimensions of image batch (batchsize, w, h, bands) - ", model_input_shape)

# Spatial size of one tile: the two axes after the batch axis.
SLICE_SIZE = model_input_shape[1:3]


def preprocess(img, nw_offset=0, slice_size=SLICE_SIZE,):
    """Scale an image to [0, 1] and cut it into a batch of equally sized tiles.

    The image is padded with `nw_offset` pixels on its top and left edges,
    and with as many pixels as needed on its bottom and right edges to make
    the padded size an exact multiple of the tile size. The padding value is
    the mean of the scaled image taken over all pixels and bands (a single
    scalar, not a per-band colour).

    Args:
        img: 3-D array/tensor indexed as (rows, cols, bands). Bands beyond the
            third are dropped; images with three or fewer bands are used as-is.
        nw_offset: number of padding pixels inserted before the first row and
            the first column, which shifts the tile grid.
        slice_size: (tile_height, tile_width), the same axis order as
            `nn.input_shape[1:3]` and as used by `reconstruct`.

    Returns:
        A 4-D float32 tensor of shape (n_tiles, tile_height, tile_width, bands)
        with tiles ordered row by row, left to right within each row.
    """
    h, w, bands = img.shape
    # Tiles are cut as full[rows, cols], so the first entry of slice_size is
    # the step along the row axis. (The previous `dx, dy = slice_size`
    # transposed the two, which only went unnoticed because tiles are square.)
    tile_h, tile_w = slice_size

    if bands > 3:
        img = img[:, :, :3]
    img = tf.cast(img, tf.float32) / 255.0

    # Extent of the image once the top/left offset has been added.
    padded_h = h + nw_offset
    padded_w = w + nw_offset

    # `-n % d` is the distance from n up to the next multiple of d
    # (0 when n is already a multiple).
    padding = tf.constant([
        [nw_offset, -padded_h % tile_h],
        [nw_offset, -padded_w % tile_w],
        [0, 0],
    ])
    full = tf.pad(img, padding,
                  constant_values=tf.math.reduce_mean(img))

    slices = [
        full[y:y + tile_h, x:x + tile_w, :]
        for y in range(0, padded_h, tile_h)
        for x in range(0, padded_w, tile_w)
    ]
    return tf.stack(slices)


def reconstruct(img_stack, w, h, nw_offset=0, norm_factor=1, slice_size=None):
    """Stitch a batch of tiles back into a single image (reverse of `preprocess`).

    The tiles are laid out row by row, the padding added by `preprocess` is
    cropped away, and the result is contrast-stretched: its minimum maps to 0
    and its maximum to `norm_factor`, then values are clipped to [0, 1]. A
    `norm_factor` above 1 therefore saturates the brighter part of the range.

    Args:
        img_stack: 4-D tensor of tiles indexed as (tile, rows, cols, bands),
            in the order produced by `preprocess`.
        w: width (columns) of the original, unpadded image.
        h: height (rows) of the original, unpadded image.
        nw_offset: top/left offset that was passed to `preprocess`.
        norm_factor: value the maximum is scaled to before clipping.
        slice_size: (tile_height, tile_width) that was passed to `preprocess`;
            defaults to the module-level `SLICE_SIZE`.

    Returns:
        A tensor of shape (h, w, bands) with values in [0, 1]. It is all
        zeros when the cropped image is constant.
    """
    if slice_size is None:
        slice_size = SLICE_SIZE
    tile_h, tile_w = slice_size

    # Number of tile rows / columns that cover the offset image.
    ny = int(np.ceil((h + nw_offset) / tile_h))
    nx = int(np.ceil((w + nw_offset) / tile_w))

    # Join each run of nx tiles side by side, then stack those rows vertically.
    reconstructed = tf.concat([
        tf.concat([
            img_stack[i] for i in range(nx * row, nx * (row + 1))
        ], axis=1)
        for row in range(ny)
    ], axis=0)

    # Crop away the padding to recover the original footprint.
    cropped = reconstructed[nw_offset : h + nw_offset,
                            nw_offset : w + nw_offset, :]

    # Convert to numpy once and take both extrema from the same array.
    values = cropped.numpy()
    mi, ma = values.min(), values.max()
    if ma == mi:
        # A constant image has no range to stretch; avoid dividing by zero.
        return tf.zeros_like(cropped)
    return tf.clip_by_value((cropped - mi) * norm_factor / (ma - mi),
                            0.0, 1.0)


Expected dimensions of image batch (batchsize, w, h, bands) -  (None, 128, 128, 3)


In [8]:
# Seed the composite with the base frame; every other frame is merged into it.
imgStack = tf.io.decode_image(tf.io.read_file(BASE_PATH))

# Each frame is masked twice, on tile grids shifted by these offsets.
# NOTE(review): 64 is presumably half of the 128-pixel tile, so that tile
# borders of one pass fall inside tiles of the other — confirm.
TILE_OFFSETS = (0, 64)

for p in tqdm(SOURCE_PATH):
    img = tf.io.decode_image(tf.io.read_file(p))
    h, w, bands = img.shape

    # One mask per tile grid, combined by keeping the larger value per pixel.
    mask_a, mask_b = (
        reconstruct(nn(preprocess(img, offset)), w, h, offset, 2)
        for offset in TILE_OFFSETS
    )
    mask = tf.math.maximum(mask_a, mask_b)

    # Debug aid: plt.figure(); plt.imshow(tf.keras.utils.array_to_img(mask))

    # Weight the frame by its mask in the mask's dtype, convert back to the
    # frame's dtype, and keep the brighter of composite and weighted frame.
    weighted = tf.cast(img, mask.dtype) * mask
    select = tf.cast(weighted, img.dtype)
    imgStack = tf.math.maximum(imgStack, select)

plt.figure()
plt.imshow(tf.keras.utils.array_to_img(imgStack))


  0%|          | 0/31 [00:00<?, ?it/s]

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.image.AxesImage at 0x7fde406ea890>