In [12]:
import tensorflow as tf
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os, time, scipy.io
import pathlib
import rawpy
from IPython.display import Image
import glob
import PIL
import PIL.Image as plim
import keras

In [2]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.list_logical_devices('GPU')
    print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
  except RuntimeError as e:
    print(e)
tf.data.Dataset.cache
tf.data.Dataset.prefetch

1 Physical GPUs, 1 Logical GPUs


<function tensorflow.python.data.ops.dataset_ops.DatasetV2.prefetch(self, buffer_size, name=None)>

In [3]:
from PIL import Image


_errstr = "Mode is unknown or incompatible with input array shape."


def bytescale(data, cmin=None, cmax=None, high=255, low=0):
    """
    Byte scales an array (image).
    Byte scaling means converting the input image to uint8 dtype and scaling
    the range to ``(low, high)`` (default 0-255).
    If the input image already has dtype uint8, no scaling is done.
    This function is only available if Python Imaging Library (PIL) is installed.
    Parameters
    ----------
    data : ndarray
        PIL image data array.
    cmin : scalar, optional
        Bias scaling of small values. Default is ``data.min()``.
    cmax : scalar, optional
        Bias scaling of large values. Default is ``data.max()``.
    high : scalar, optional
        Scale max value to `high`.  Default is 255.
    low : scalar, optional
        Scale min value to `low`.  Default is 0.
    Returns
    -------
    img_array : uint8 ndarray
        The byte-scaled array.
    Examples
    --------
    >>> from scipy.misc import bytescale
    >>> img = np.array([[ 91.06794177,   3.39058326,  84.4221549 ],
    ...                 [ 73.88003259,  80.91433048,   4.88878881],
    ...                 [ 51.53875334,  34.45808177,  27.5873488 ]])
    >>> bytescale(img)
    array([[255,   0, 236],
           [205, 225,   4],
           [140,  90,  70]], dtype=uint8)
    >>> bytescale(img, high=200, low=100)
    array([[200, 100, 192],
           [180, 188, 102],
           [155, 135, 128]], dtype=uint8)
    >>> bytescale(img, cmin=0, cmax=255)
    array([[91,  3, 84],
           [74, 81,  5],
           [52, 34, 28]], dtype=uint8)
    """
    if data.dtype == np.uint8:
        return data

    if high > 255:
        raise ValueError("`high` should be less than or equal to 255.")
    if low < 0:
        raise ValueError("`low` should be greater than or equal to 0.")
    if high < low:
        raise ValueError("`high` should be greater than or equal to `low`.")

    if cmin is None:
        cmin = data.min()
    if cmax is None:
        cmax = data.max()

    cscale = cmax - cmin
    if cscale < 0:
        raise ValueError("`cmax` should be larger than `cmin`.")
    elif cscale == 0:
        cscale = 1

    scale = float(high - low) / cscale
    bytedata = (data - cmin) * scale + low
    return (bytedata.clip(low, high) + 0.5).astype(np.uint8)


def toimage(arr, high=255, low=0, cmin=None, cmax=None, pal=None,
            mode=None, channel_axis=None):
    """Takes a numpy array and returns a PIL image.
    This function is only available if Python Imaging Library (PIL) is installed.
    The mode of the PIL image depends on the array shape and the `pal` and
    `mode` keywords.
    For 2-D arrays, if `pal` is a valid (N,3) byte-array giving the RGB values
    (from 0 to 255) then ``mode='P'``, otherwise ``mode='L'``, unless mode
    is given as 'F' or 'I' in which case a float and/or integer array is made.
    .. warning::
        This function uses `bytescale` under the hood to rescale images to use
        the full (0, 255) range if ``mode`` is one of ``None, 'L', 'P', 'l'``.
        It will also cast data for 2-D images to ``uint32`` for ``mode=None``
        (which is the default).
    Notes
    -----
    For 3-D arrays, the `channel_axis` argument tells which dimension of the
    array holds the channel data.
    For 3-D arrays if one of the dimensions is 3, the mode is 'RGB'
    by default or 'YCbCr' if selected.
    The numpy array must be either 2 dimensional or 3 dimensional.
    """
    data = np.asarray(arr)
    if np.iscomplexobj(data):
        raise ValueError("Cannot convert a complex-valued array.")
    shape = list(data.shape)
    valid = len(shape) == 2 or ((len(shape) == 3) and
                                ((3 in shape) or (4 in shape)))
    if not valid:
        raise ValueError("'arr' does not have a suitable array shape for "
                         "any mode.")
    if len(shape) == 2:
        shape = (shape[1], shape[0])  # columns show up first
        if mode == 'F':
            data32 = data.astype(np.float32)
            image = Image.frombytes(mode, shape, data32.tostring())
            return image
        if mode in [None, 'L', 'P']:
            bytedata = bytescale(data, high=high, low=low,
                                 cmin=cmin, cmax=cmax)
            image = Image.frombytes('L', shape, bytedata.tostring())
            if pal is not None:
                image.putpalette(np.asarray(pal, dtype=np.uint8).tostring())
                # Becomes a mode='P' automagically.
            elif mode == 'P':  # default gray-scale
                pal = (np.arange(0, 256, 1, dtype=np.uint8)[:, np.newaxis] *
                       np.ones((3,), dtype=np.uint8)[np.newaxis, :])
                image.putpalette(np.asarray(pal, dtype=np.uint8).tostring())
            return image
        if mode == '1':  # high input gives threshold for 1
            bytedata = (data > high)
            image = Image.frombytes('1', shape, bytedata.tostring())
            return image
        if cmin is None:
            cmin = np.amin(np.ravel(data))
        if cmax is None:
            cmax = np.amax(np.ravel(data))
        data = (data*1.0 - cmin)*(high - low)/(cmax - cmin) + low
        if mode == 'I':
            data32 = data.astype(np.uint32)
            image = Image.frombytes(mode, shape, data32.tostring())
        else:
            raise ValueError(_errstr)
        return image

    # if here then 3-d array with a 3 or a 4 in the shape length.
    # Check for 3 in datacube shape --- 'RGB' or 'YCbCr'
    if channel_axis is None:
        if (3 in shape):
            ca = np.flatnonzero(np.asarray(shape) == 3)[0]
        else:
            ca = np.flatnonzero(np.asarray(shape) == 4)
            if len(ca):
                ca = ca[0]
            else:
                raise ValueError("Could not find channel dimension.")
    else:
        ca = channel_axis

    numch = shape[ca]
    if numch not in [3, 4]:
        raise ValueError("Channel axis dimension is not valid.")

    bytedata = bytescale(data, high=high, low=low, cmin=cmin, cmax=cmax)
    if ca == 2:
        strdata = bytedata.tostring()
        shape = (shape[1], shape[0])
    elif ca == 1:
        strdata = np.transpose(bytedata, (0, 2, 1)).tostring()
        shape = (shape[2], shape[0])
    elif ca == 0:
        strdata = np.transpose(bytedata, (1, 2, 0)).tostring()
        shape = (shape[2], shape[1])
    if mode is None:
        if numch == 3:
            mode = 'RGB'
        else:
            mode = 'RGBA'

    if mode not in ['RGB', 'RGBA', 'YCbCr', 'CMYK']:
        raise ValueError(_errstr)

    if mode in ['RGB', 'YCbCr']:
        if numch != 3:
            raise ValueError("Invalid array shape for mode.")
    if mode in ['RGBA', 'CMYK']:
        if numch != 4:
            raise ValueError("Invalid array shape for mode.")

    # Here we know data and mode is correct
    image = Image.frombytes(mode, shape, strdata)
    return image

In [6]:
input_dir = './Sony/short/'
gt_dir = './Sony/long/'
checkpoint_dir = './result_Sony/'
result_dir = './result_Sony/keras_model_test/'

target_fns = glob.glob(gt_dir + '0*.ARW')
target_ids = [int(os.path.basename(target_fn)[0:5]) for target_fn in target_fns]

NameError: name 'glob' is not defined

In [3]:
def double_conv(x, n_filters):
    x1 = tf.keras.layers.Conv2D(n_filters, [3,3], padding = 'SAME', activation = leaky_relu_layer, kernel_initializer = 'random_normal')(x)
    x2 = tf.keras.layers.Conv2D(n_filters, [3,3], padding = 'SAME', activation = leaky_relu_layer, kernel_initializer = 'random_normal')(x1)
    return x2

def downsample_block(x, n_filters):
    f = double_conv(x, n_filters)
    p = tf.keras.layers.MaxPool2D(2, padding = 'SAME')(f)
    #  p = tf.keras.layers.Dropout(0.3)(p)
    return f, p

def upsample_block(x1, n_filters, x2 = None):
    pool_size = 2 
    # output_chanel = n_filter
    input_channels = n_filters*2
    
    deconv_filter = tf.Variable(tf.random.truncated_normal([pool_size, pool_size, n_filters, input_channels], stddev = 0.02))
    # print(deconv_filter)
    # print('x2 down_conw',tf.shape(x2))
    # print('x1 up_conw',tf.shape(x1))
    x = tf.keras.layers.Conv2DTranspose(n_filters, pool_size, strides = (pool_size, pool_size),padding='same')(x1)
    print(np.shape(x))
    x = tf.concat([x, x2], 3)
    x.set_shape([None, None, None, n_filters * 2])
    
    #  x = tf.keras.layers.Dropout(0.3)(x)
    x = double_conv(x, n_filters)
    return x

def build_unet_model():
    inputs = tf.keras.layers.Input(shape=(None,None,4))
    
    conv_1, p1 = downsample_block(inputs, 32)
    conv_2, p2 = downsample_block(p1, 64)
    conv_3, p3 = downsample_block(p2, 128)
    conv_4, p4 = downsample_block(p3, 256)
    
    conv_5 = double_conv(p4, 512)
    
    conv_6 = upsample_block(conv_5, 256, conv_4)
    conv_7 = upsample_block(conv_6, 128,  conv_3)
    conv_8 = upsample_block(conv_7, 64, conv_2)
    conv_9 = upsample_block(conv_8, 32, conv_1)
    conv_10 = tf.keras.layers.Conv2D(12, 1, padding="same", activation = None)(conv_9)
    out = tf.nn.depth_to_space(conv_10, 2)
    unet_model = tf.keras.Model(inputs, out, name="U-Net")
    return unet_model

In [5]:
leaky_relu_layer = tf.keras.layers.LeakyReLU(0.2)
learning_rate = 1e-4
ps = 512
save_freq = 50

NameError: name 'tf' is not defined

In [7]:
unet_model = build_unet_model()
learning_rate = 1e-4
optimizer = tf.keras.optimizers.Adam(learning_rate)
unet_model.compile(optimizer, keras.losses.MeanAbsoluteError())



<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x24622a7ce20>

In [8]:
def pack_raw(raw):
  im = raw.raw_image_visible.astype(np.float32)
  im = np.maximum(im - 512, 0)/(16383 - 512)
  im = np.expand_dims(im, axis = 2)
  # print(np.shape(im))
  img_shape = im.shape
  H = img_shape[0]
  W = img_shape[1]

  out = np.concatenate((im[0:H:2, 0:W:2, :],
                          im[0:H:2, 1:W:2, :],
                          im[1:H:2, 1:W:2, :],
                          im[1:H:2, 0:W:2, :]), axis=2)
  # print(np.shape(out))
  return out
#разделяет raw на RGBG каналы
def divide_layers(layer, channel):
  only_one_layer = []
  for i in range(len(layer)):
    first_layer = []
    for j in range(len(layer[0])):
      first_layer.append(layer[i][j][channel])
    only_one_layer.append(first_layer)
  return only_one_layer

def print_subplot(ax, channel, color_space, name):
  ax.imshow(channel, cmap = color_space)
  ax.set_title(name)
def print_image(image):
  image = np.array(image)
  R, G_1, B, G= cv2.split(image)
  fig, (ax_1, ax_2, ax_3, ax_4) = plt.subplots(1, 4, figsize = [20, 10])
  print_subplot(ax_1, R, 'Reds', 'Red_channel')
  print_subplot(ax_2, G_1, 'Greens', 'Green_channel')
  print_subplot(ax_3, B, 'Blues', 'Blue_channel')
  print_subplot(ax_4, G, 'Greens', 'Green_2_channel')

In [9]:
optimizer = tf.keras.optimizers.Adam(learning_rate)
def loss_1(y_pred, y_true):
    return tf.reduce_mean(tf.abs(y_pred - y_true))

In [10]:
@tf.function
def run_optimization(x,y):
    with tf.GradientTape() as g:
        prediction = unet_model(x)
        loss = tf.reduce_mean(tf.abs(prediction - y))
    trainable_variables = unet_model.trainable_variables
    gradients = g.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))

    return loss

In [11]:
gt_images = [None] * 6000
input_images = {}
input_images['300'] = [None] * len(target_ids)
input_images['250'] = [None] * len(target_ids)
input_images['100'] = [None] * len(target_ids)
cnt = 0

for epoch in range(1+903, 2001):
    if epoch > 1500:
        learning_rate = 1e-5
    st = time.time()
    loss_epoch = []
    st = time.time()
    for ind in np.random.permutation(len(target_ids)):
        target_id = target_ids[ind]
        in_files = glob.glob(input_dir + '%05d_0?*.ARW' % target_id)
        in_path = in_files[np.random.randint(0, len(in_files))]
        in_fn = os.path.basename(in_path)

        gt_files = glob.glob(gt_dir + '%05d_00*.ARW' % target_id)
        gt_path = gt_files[0]
        gt_fn = os.path.basename(gt_path)
        in_exposure = float(in_fn[9:-5])
        gt_exposure = float(gt_fn[9:-5])
        ratio = min(gt_exposure / in_exposure, 300)

        if input_images[str(ratio)[0:3]][ind] is None:
            raw = rawpy.imread(in_path)
            input_images[str(ratio)[0:3]][ind] = np.expand_dims(pack_raw(raw), axis=0) * ratio

            gt_raw = rawpy.imread(gt_path)
            im = gt_raw.postprocess(use_camera_wb=True, half_size=False, no_auto_bright=True, output_bps=16)
            gt_images[ind] = np.expand_dims(np.float32(im / 65535.0), axis=0)
        
        H = input_images[str(ratio)[0:3]][ind].shape[1]
        W = input_images[str(ratio)[0:3]][ind].shape[2]
        xx = np.random.randint(0, W - ps)
        yy = np.random.randint(0, H - ps)
        input_patch = input_images[str(ratio)[0:3]][ind][:, yy:yy + ps, xx:xx + ps, :]
        gt_patch = gt_images[ind][:, yy * 2:yy * 2 + ps * 2, xx * 2:xx * 2 + ps * 2, :]
        
        if np.random.randint(2, size=1)[0] == 1:  # random flip
            input_patch = np.flip(input_patch, axis=1)
            gt_patch = np.flip(gt_patch, axis=1)
        if np.random.randint(2, size=1)[0] == 1:
            input_patch = np.flip(input_patch, axis=2)
            gt_patch = np.flip(gt_patch, axis=2)
        if np.random.randint(2, size=1)[0] == 1:  # random transpose
            input_patch = np.transpose(input_patch, (0, 2, 1, 3))
            gt_patch = np.transpose(gt_patch, (0, 2, 1, 3))
        # print("Min and max pixel values:",np.min(input_patch), np.max(input_patch))
        input_patch = np.minimum(input_patch, 1.0)
        loss = run_optimization(input_patch, gt_patch)
        loss_epoch.append(loss)
        if epoch % 100 == 0:
            output = unet_model(input_patch)  
            output = np.minimum(np.maximum(output, 0), 1)
            if not os.path.isdir(result_dir + '%04d' % epoch):
                os.makedirs(result_dir + '%04d' % epoch)
            temp = np.concatenate((gt_patch[0, :, :, :], output[0, :, :, :]), axis=1)
            toimage(temp * 255, high=255, low=0, cmin=0, cmax=255, mode = 'RGB').save(
                result_dir + '%04d/%05d_00_train_%d.png' % (epoch, target_id, ratio))

    if epoch % save_freq == 0 or epoch == 5:
        cnt +=1
        output = unet_model(input_patch)  
        output = np.minimum(np.maximum(output, 0), 1)
        if not os.path.isdir(result_dir + '%04d' % epoch):
            os.makedirs(result_dir + '%04d' % epoch)
        temp = np.concatenate((gt_patch[0, :, :, :], output[0, :, :, :]), axis=1)
        toimage(temp * 255, high=255, low=0, cmin=0, cmax=255, mode = 'RGB').save(
            result_dir + '%04d/%05d_00_train_%d.png' % (epoch, target_id, ratio))
        unet_model.save(filepath=result_dir + "check_points/keras_model/kerasmodel%03d.keras" % cnt)
    print(f'epoch: {epoch}, time: { "{:.4f}".format(time.time() - st)}s, mean loss {"{:.3f}".format(np.mean(loss_epoch))}')

epoch: 1, time: 212.1296s, mean loss 0.046
epoch: 2, time: 99.4767s, mean loss 0.037
epoch: 3, time: 83.1773s, mean loss 0.036
epoch: 4, time: 87.1225s, mean loss 0.034


  strdata = bytedata.tostring()


TypeError: Layer tf.nn.conv2d_transpose was passed non-JSON-serializable arguments. Arguments had types: {'filters': <class 'tensorflow.python.ops.resource_variable_ops.ResourceVariable'>, 'output_shape': [<class 'str'>, <class 'int'>, <class 'int'>], 'strides': [<class 'int'>, <class 'int'>, <class 'int'>, <class 'int'>]}. They cannot be serialized out when saving the model.

In [None]:
im = rawpy.imread(f'{target_fns[14]}')
raw = pack_raw(im)
crop = tf.image.stateless_random_crop(
      raw, size=[ps, ps, 4], seed=(3,8))
test = np.expand_dims(crop, axis=0)
out_im = unet_model(test)

In [None]:
out_im = np.minimum(np.maximum(out_im, 0), 1)
plt.imshow(out_im[0])

In [None]:
test_fn = glob.glob('./Sony/' + '*.ARW')
print(test_fn)
im = rawpy.imread(f'{test_fn[0]}')
raw = pack_raw(im)
for i in range(4):
    for j in range(6): 
        crop = raw[512*i:512*(i+1), 512*j:512*(j+1), :]
        test = np.expand_dims(crop, axis=0) * 50
        out_im = unet_model(test)
        out_im = np.minimum(np.maximum(out_im, 0), 1)
        toimage(out_im[0] * 255, high=255, low=0, cmin=0, cmax=255, mode = 'RGB').save(
                    result_dir + "test_my_data/" + f'{int(os.path.basename(test_fn[0])[4:8])}_{i}{j}.png')