In [None]:
import setGPU
import nn_utils.misc.keras_not_full_memory_importer

In [None]:
import glob
import inspect
import json
import math
import os
import os.path as osp
import imageio as io
import minimg
import copy
import cv2

import tensorflow as tf

from tensorflow.keras.callbacks import Callback, ModelCheckpoint, TensorBoard
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *
from tensorflow.keras import losses
from tensorflow.keras.utils import Sequence
from tensorflow.keras import backend as K, regularizers

from Tensorflow.KiHoughTF import KiHoughLayerTF

import numpy as np 
import pandas as pd
import minimg
import matplotlib.pyplot as plt

In [None]:
FRAGMENT_SIZE = 512  # spatial size (H = W) of the fragments stored in the *.bin memmaps
BS = 4  # batch size passed to model.fit
REG_PARAM = 0.000025  # L2 weight for conv kernel and bias regularizers in simple_conv
MAX_NUM_EPOCHS = 100  # number of epochs passed to model.fit
INITIAL_LEARNING_RATE = 0.0001  # Adam learning rate used in model.compile

In [None]:
# Directory holding train/valid_filenames.csv and the {x,y,z,p}_{train,valid}.bin memmaps.
data_path = "data/81k_preprocessed_glare"
# Experiment name; used as the sub-directory of output_dir for this run.
name = "unet_glare_equal"
# Root directory for saved network weights.
output_dir = "nn_weights"
output_nn_dir = osp.join(output_dir, name)

In [None]:
def my_loss(y_true, y_pred):
    """Sum of squared errors over every element of the batch.

    The result is a plain sum (not a mean), so its scale grows with batch size
    and image size.
    """
    residual = K.flatten(y_true) - K.flatten(y_pred)
    return K.sum(residual * residual)

def ssim_mse(y_true, y_pred):
    """MSE (averaged to a scalar) plus half of the SSIM dissimilarity ``1 - SSIM``.

    ``tf.image.ssim`` is called with ``max_val=1.0``, so inputs are expected in [0, 1].
    Per the TF docs it returns one SSIM value per image, so the sum has one entry
    per batch element.
    """
    mse_term = tf.reduce_mean(tf.keras.losses.MSE(y_true, y_pred))
    ssim_term = (1. - tf.image.ssim(y_true, y_pred, 1.0)) * 0.5
    return mse_term + ssim_term

In [None]:
def simple_conv(kernels, kernel_size = (3,3), _padding= 'same', activation=None):
    """Create (but do not apply) an L2-regularized Conv2D layer.

    Both the kernel and the bias use an L2 penalty of ``REG_PARAM``, which is
    read when the layer is created.
    """
    return Conv2D(
        kernels,
        kernel_size,
        padding=_padding,
        activation=activation,
        kernel_regularizer=regularizers.l2(REG_PARAM),
        bias_regularizer=regularizers.l2(REG_PARAM),
    )

def conv_with_BN_down(x, kernels, padding, kernel_size = (3,3), dropout = 0, bn=True, is_bn = True):
    """Encoder unit: conv -> optional BatchNorm -> ReLU -> Dropout.

    Args:
        x: input tensor.
        kernels: number of conv filters.
        padding: conv padding mode.
        kernel_size: conv kernel size.
        dropout: dropout rate.
        bn: value passed as ``training`` to BatchNormalization. It is fixed at
            build time, so True means batch statistics are used even at predict.
        is_bn: whether to insert a BatchNormalization layer at all.
    """
    y = simple_conv(kernels, kernel_size, _padding=padding, activation=None)(x)
    y = BatchNormalization()(y, training=bn) if is_bn else y
    y = Activation('relu')(y)
    return Dropout(dropout)(y)

def conv_with_BN_up(x, kernels, padding, kernel_size = (3,3), dropout = 0, bn=True, is_bn = True):
    """Decoder unit: conv -> optional BatchNorm -> LeakyReLU(0.2) -> Dropout.

    Args:
        x: input tensor.
        kernels: number of conv filters.
        padding: conv padding mode.
        kernel_size: conv kernel size.
        dropout: dropout rate.
        bn: value passed as ``training`` to BatchNormalization (fixed at build time).
        is_bn: whether to insert a BatchNormalization layer at all.
    """
    y = simple_conv(kernels, kernel_size, _padding=padding, activation=None)(x)
    y = BatchNormalization()(y, training=bn) if is_bn else y
    y = tf.keras.layers.LeakyReLU(alpha=0.2)(y)
    return Dropout(dropout)(y)

def mse_loss(y_true, y_pred):
    """Mean squared error, reduced to one scalar with ``tf.reduce_mean``."""
    per_element_mse = tf.keras.losses.MSE(y_true, y_pred)
    return tf.reduce_mean(per_element_mse)

# Per-output loss functions, keyed by the output layer names used in unet_model
# ("out", "br1", "br2").
# NOTE: this rebinds `losses`, shadowing the `tensorflow.keras.losses` module imported
# at the top; the name is kept because model.compile(loss=losses, ...) refers to it.
losses = {
    "out": my_loss,
    "br1": my_loss,
    "br2": my_loss
}

# Per-output weights in the total loss. "br2" is now listed explicitly so this dict
# covers the same outputs as `losses`. tf.keras applies no weight to an output that is
# missing from this dict, which is the same as 1.0.
lossWeights = {
    "out": 1.0,
    "br1": 1.0,
    "br2": 1.0
}

In [None]:
def _upsample_and_merge(x, skip, filters):
    """Upsample ``x`` 2x with a transposed conv and concatenate the skip tensor.

    If the static height/width of the upsampled tensor differs from ``skip``
    (odd-sized encoder maps), one zero row at the bottom and/or one zero column
    on the right is appended. With unknown (None) dimensions the comparison is
    ``None != None`` -> False, so no padding is added.
    """
    up = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same')(x)
    if up.shape[1] != skip.shape[1]:
        up = ZeroPadding2D(padding=((0, 1), (0, 0)))(up)
    if up.shape[2] != skip.shape[2]:
        up = ZeroPadding2D(padding=((0, 0), (0, 1)))(up)
    return concatenate([up, skip], axis=-1)


def _decoder_stage(x, skip, filters, dropout, **conv_kwargs):
    """One decoder level: upsample + merge, then two ``conv_with_BN_up`` units."""
    y = _upsample_and_merge(x, skip, filters)
    y = conv_with_BN_up(y, filters, padding='same', dropout=dropout, **conv_kwargs)
    return conv_with_BN_up(y, filters, padding='same', dropout=dropout, **conv_kwargs)


def unet_model(bn = True, pic=False, img_shape=(None, None, 3)):
    """Build the glare-removal U-Net.

    Architecture:
        * Encoder: seven levels (32, 64, 96, 128, 160, 192, 224 filters), each with
          two conv units followed by 2x2 max pooling. The encoder ends in a
          256-filter bottleneck.
        * "mask" decoder: always built; its 1x1 ReLU head is named "br1".
        * "pic" decoder: built only when ``pic`` is True. It has the same structure,
          and its 1x1 ReLU head is named "br2".
        * Output "out" = clip(inputs - br1 [+ br2], 0, 1).

    Layers are created in the same order as the original hand-unrolled version,
    so previously saved weight files still load.

    Args:
        bn: passed as the BatchNormalization ``training`` flag to the inner conv units.
        pic: whether to build the second ("br2") decoder branch.
        img_shape: input shape without the batch dimension. Give concrete H/W to
            enable zero-padding for odd-sized feature maps.

    Returns:
        A ``Model`` with outputs ``[out, br1, br2]`` if ``pic`` is True. Otherwise it
        returns ``[out, br1]``; previously that case raised NameError on ``conv_pic``.
    """
    droprate = 0.25
    filters_per_level = [32, 64, 96, 128, 160, 192, 224]
    # Dropout grows with depth: droprate * 1/8, 2/8, ..., 7/8; the bottleneck uses droprate.
    dropout_fracs = [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]

    inputs = Input(img_shape)

    # Encoder: keep the pre-pooling features of every level as skip connections.
    skips = []
    x = inputs
    for level, (filters, frac) in enumerate(zip(filters_per_level, dropout_fracs)):
        # The outermost level has no BatchNormalization.
        enc_kwargs = {'is_bn': False} if level == 0 else {'bn': bn}
        x = conv_with_BN_down(x, filters, padding='same', dropout=droprate * frac, **enc_kwargs)
        x = conv_with_BN_down(x, filters, padding='same', dropout=droprate * frac, **enc_kwargs)
        skips.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    bottleneck = conv_with_BN_down(x, 256, padding='same', dropout=droprate, bn=bn)
    bottleneck = conv_with_BN_down(bottleneck, 256, padding='same', dropout=droprate, bn=bn)

    # Decoders, deepest level first. The mask and pic stages of a level are built
    # back to back, which matches the original layer creation order.
    mask_branch = bottleneck
    pic_branch = bottleneck
    for level in reversed(range(len(filters_per_level))):
        filters = filters_per_level[level]
        dropout = droprate * dropout_fracs[level]
        skip = skips[level]
        if level == 0:
            mask_kwargs = {'is_bn': False}
            # NOTE(review): the original pic branch passed bn=False here, which keeps a
            # BatchNormalization layer but runs it in inference mode. It is not
            # is_bn=False like the mask branch. Preserved so saved checkpoints still load.
            pic_kwargs = {'bn': False}
        else:
            mask_kwargs = {'bn': bn}
            pic_kwargs = {'bn': bn}
        mask_branch = _decoder_stage(mask_branch, skip, filters, dropout, **mask_kwargs)
        if pic:
            pic_branch = _decoder_stage(pic_branch, skip, filters, dropout, **pic_kwargs)

    conv_mask = Conv2D(3, (1, 1), activation='relu', name="br1")(mask_branch)
    conv_pic = None
    if pic:
        conv_pic = Conv2D(3, (1, 1), activation='relu', name="br2")(pic_branch)

    if pic:
        out = inputs - conv_mask + conv_pic
    else:
        out = inputs - conv_mask
    out = tf.math.maximum(out, 0.)
    out = tf.math.minimum(out, 1.)

    # Identity layer that gives the final tensor the output name "out".
    out = Lambda(lambda x: x, name="out")(out)

    outputs = [out, conv_mask, conv_pic] if pic else [out, conv_mask]
    model = Model(inputs=inputs, outputs=outputs)

    return model

In [None]:
# Create the run's output directory. makedirs also creates the parent `output_dir`
# and is a no-op when the directory already exists (os.mkdir failed in both cases).
os.makedirs(output_nn_dir, exist_ok=True)

fs = FRAGMENT_SIZE

# Number of samples per split = number of rows in the accompanying CSV.
train_len = pd.read_csv(osp.join(data_path, 'train_filenames.csv')).shape[0]
valid_len = pd.read_csv(osp.join(data_path, 'valid_filenames.csv')).shape[0]


def _open_split_array(prefix, split, length):
    """Memory-map ``<data_path>/<prefix>_<split>.bin`` as float32 of shape (length, fs, fs, 3).

    The file is opened read-only (mode='r'). numpy's default mode 'r+' would let an
    accidental in-place write modify the dataset file on disk.
    """
    return np.memmap(
        osp.join(data_path, '{}_{}.bin'.format(prefix, split)),
        dtype=np.float32, mode='r', shape=(length, fs, fs, 3),
    )


# x_*: network inputs. y_*, z_*, p_*: targets for the "out", "br1" and "br2" outputs in model.fit.
x_train = _open_split_array('x', 'train', train_len)
x_valid = _open_split_array('x', 'valid', valid_len)
y_train = _open_split_array('y', 'train', train_len)
y_valid = _open_split_array('y', 'valid', valid_len)
z_train = _open_split_array('z', 'train', train_len)
z_valid = _open_split_array('z', 'valid', valid_len)
p_train = _open_split_array('p', 'train', train_len)
p_valid = _open_split_array('p', 'valid', valid_len)

In [None]:
weight_dir = osp.join(output_nn_dir, 'weights')
# makedirs also creates missing parents and tolerates an existing directory.
os.makedirs(weight_dir, exist_ok=True)
# `{val_loss}` in the file name requires validation data to be passed to model.fit.
weight_file = osp.join(weight_dir, "weights-improvement-{loss:.0f}-{val_loss:.0f}.hdf5")
# With save_best_only=False a file is written every epoch; `monitor` and `mode` only
# matter when save_best_only=True.
checkpoint = ModelCheckpoint(weight_file, monitor='loss', verbose=1, save_best_only=False, save_weights_only=True, mode='min')
my_callbacks_list = [checkpoint]

In [None]:
# Build the two-branch model. pic=True adds the "br2" head, so the outputs match the
# out/br1/br2 keys of `losses` and of the fit targets.
model = unet_model(bn = True, pic = True)

In [None]:
# Resume from a checkpoint of an earlier run. The file name is hard-coded and follows the
# weights-improvement-{loss}-{val_loss}.hdf5 pattern of the ModelCheckpoint cell.
model.load_weights("nn_weights/unet_glare_equal/weights/weights-improvement-10708-10561.hdf5")

In [None]:
# `learning_rate` is the supported keyword. `lr` is a deprecated alias in tf.keras and
# is not accepted by recent Keras versions.
model.compile(optimizer=Adam(learning_rate=INITIAL_LEARNING_RATE), loss=losses, loss_weights=lossWeights)

In [None]:
# Targets are keyed by output layer name: "out" (clean image), "br1" and "br2" (branch targets).
train_targets = {"out": y_train, "br1": z_train, "br2": p_train}
valid_targets = {"out": y_valid, "br1": z_valid, "br2": p_valid}
model.fit(
    x_train,
    train_targets,
    validation_data=(x_valid, valid_targets),
    epochs=MAX_NUM_EPOCHS,
    shuffle=True,
    callbacks=my_callbacks_list,
    batch_size=BS,
)

In [None]:
# Predict the first 30 validation fragments one at a time. The pic=True model returns
# [out, br1, br2], which are unpacked into three arrays.
preds, br1, br2 = model.predict(x_valid[:30], batch_size=1)

In [None]:
#preds2 = model.predict([preds, preds])
#[preds3, preds4] = model.predict([preds2, preds2])
#preds2 = model.predict(preds)

In [None]:
num = 2  # index of the validation fragment displayed by the plotting cells below

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(x_valid[num])  # network input

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(y_valid[num])  # "out" target

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(preds[num])  # predicted "out"

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(br1[num])  # predicted "br1" (subtracted from the input)

In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(br2[num])  # predicted "br2" (added back to the input)

In [None]:
os.makedirs("ch_real_equal", exist_ok=True)


def _to_uint8(img):
    """Scale a float image with nominal range [0, 1] by 255 and convert to uint8.

    Values are clipped to [0, 255] first. br1/br2 come from ReLU heads and are not
    bounded above by 1, and np.uint8() on out-of-range floats does not saturate
    (the result is platform-dependent, typically wrapped).
    """
    return np.uint8(np.clip(img * 255, 0, 255))


for i in range(106):
    a = io.imread("real_flared_100/pic" + str(i) + ".jpg") / 255.
    # Only images up to 2000 x 2000 are processed.
    if a.shape[0] <= 2000 and a.shape[1] <= 2000:
        # A new model is built for every image. Reset the Keras global state first so
        # memory from previous iterations' models does not accumulate.
        K.clear_session()
        # A concrete input shape lets unet_model's static shape checks pad odd-sized
        # feature maps before concatenation.
        model = unet_model(bn = True, pic = True, img_shape=a.shape)
        model.load_weights("nn_weights/unet_glare_equal/weights/weights-improvement-10708-10561.hdf5")
        batch = a.reshape(1, a.shape[0], a.shape[1], a.shape[2]).astype(np.float32)
        pred, br1, br2 = model.predict(batch, batch_size=1)

        io.imwrite("ch_real_equal/ch" + str(i) + "_x.jpg", _to_uint8(a))
        io.imwrite("ch_real_equal/ch" + str(i) + "_pred.jpg", _to_uint8(pred[0, :, :, :]))
        io.imwrite("ch_real_equal/ch" + str(i) + "_br1.jpg", _to_uint8(br1[0, :, :, :]))
        io.imwrite("ch_real_equal/ch" + str(i) + "_br2.jpg", _to_uint8(br2[0, :, :, :]))

        print(i)

In [None]:
import math
import numpy as np
import imageio as io

def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two images with pixel range [0, 255].

    Both images are cast to float64 before differencing. Identical images give
    ``float('inf')``.
    """
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse != 0:
        return 20 * math.log10(255.0 / math.sqrt(mse))
    return float('inf')

all_psnrs = []
# Image indices excluded from the evaluation.
l = [6, 11, 15, 21, 26, 30, 40, 42, 46, 68, 83]
# NOTE(review): this compares the network input (_x) with its prediction (_pred), so it
# measures how much the model changed the image, not the distance to a clean reference.
for i in range(106):
    if i in l:
        continue
    stem = "ch_real_equal/ch" + str(i)
    cur_psnr = calculate_psnr(io.imread(stem + "_x.jpg"), io.imread(stem + "_pred.jpg"))
    all_psnrs.append(cur_psnr)

In [None]:
# Mean PSNR (dB) over the evaluated images. An identical pair contributes inf, and an
# empty list raises ZeroDivisionError.
aver_psnr = sum(all_psnrs) / len(all_psnrs)
aver_psnr

In [None]:
import cv2
def ssim(img1, img2):
    """Mean SSIM between two images with pixel range [0, 255].

    Local statistics use an 11x11 Gaussian window (sigma 1.5). A 5-pixel border is
    cropped from every filtered map, keeping only the valid region. OpenCV's
    filter2D processes each channel of a multi-channel image independently, and the
    final mean is taken over all remaining positions and channels.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    gauss_1d = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss_1d, gauss_1d.transpose())

    def local_mean(arr):
        # Gaussian-weighted local average, cropped to the valid region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    mu_x_sq = mu_x**2
    mu_y_sq = mu_y**2
    mu_xy = mu_x * mu_y
    var_x = local_mean(x**2) - mu_x_sq
    var_y = local_mean(y**2) - mu_y_sq
    cov_xy = local_mean(x * y) - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_x_sq + mu_y_sq + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()

all_ssims = []
# Image indices excluded from the evaluation.
l = [6, 11, 15, 21, 26, 30, 40, 42, 46, 68, 83]
# NOTE(review): like the PSNR loop, this compares the input (_x) with the prediction (_pred).
for i in range(106):
    if i in l:
        continue
    stem = "ch_real_equal/ch" + str(i)
    cur_ssim = ssim(io.imread(stem + "_x.jpg"), io.imread(stem + "_pred.jpg"))
    all_ssims.append(cur_ssim)

In [None]:
# Mean SSIM over the evaluated images. An empty list raises ZeroDivisionError.
aver_ssim = sum(all_ssims) / len(all_ssims)
aver_ssim

In [None]:
#%%timeit
#model.predict(x_valid[:1, :, :, :])

In [None]:
num = 30
os.makedirs("circles", exist_ok=True)

# NOTE(review): `model` is whatever an earlier cell left in the kernel. After the
# real-image loop it was rebuilt for that image's concrete input shape, which may not
# accept the fs x fs validation fragments. If predict rejects the input shape, rebuild
# with unet_model(...) + load_weights(...) first.
outputs = model.predict(x_valid[:num, :, :, :], batch_size=1)
# A multi-output model returns a list ([out, br1, br2] when built with pic=True).
# Indexing that list with [i, :, :, :] raised TypeError, so keep only `out`.
preds = outputs[0] if isinstance(outputs, (list, tuple)) else outputs


def _unit_float_to_uint8(img):
    """Convert a float image assumed to lie in [0, 1] to rounded uint8 for JPEG output."""
    return np.round(np.clip(img, 0., 1.) * 255.).astype(np.uint8)


for i in range(num):
    io.imwrite("circles/ch" + str(i) + "_initial.jpg", _unit_float_to_uint8(x_valid[i, :, :, :]))
    io.imwrite("circles/ch" + str(i) + "_needed.jpg", _unit_float_to_uint8(y_valid[i, :, :, :]))
    io.imwrite("circles/ch" + str(i) + "_predicted.jpg", _unit_float_to_uint8(preds[i, :, :, :]))

In [None]:
num = 2  # index of the real image inspected by the next two cells

In [None]:
# NOTE(review): reads from "ch_real_el_bigger_zero", which no visible cell writes (the
# inference loop writes "ch_real_equal"). Confirm which run's outputs are intended.
source, br1, br2, out = (
    np.array(io.imread("ch_real_el_bigger_zero/ch" + str(num) + suffix)).astype(np.int32)
    for suffix in ("_x.jpg", "_br1.jpg", "_br2.jpg", "_pred.jpg")
)



In [None]:
res = source - br1 + br2
# 1 for each pixel with at least one negative channel after applying both branches, else 0.
bz = np.any(res < 0, axis=-1).astype(np.int64)
# Fraction of such pixels in the image.
print(np.sum(bz) / res.shape[0] / res.shape[1])

In [None]:
l = [6, 11, 15, 21, 26, 30, 40, 42, 46, 68, 83]
os.makedirs("ch_real_coef", exist_ok=True)


def _negative_fraction(img):
    """Count of negative entries (each channel counted separately) divided by H * W.

    NOTE(review): the single-image check earlier counts a pixel once if any channel is
    negative. This one counts channels, so it can exceed 1. Confirm which definition
    the 5% threshold below is meant for.
    """
    return np.count_nonzero(img < 0) / img.shape[0] / img.shape[1]


for num in range(0, 106):
    if num in l:
        continue
    stem = "ch_real_el_bigger_zero/ch" + str(num)
    source = np.array(io.imread(stem + "_x.jpg")).astype(np.int32)
    br1 = np.array(io.imread(stem + "_br1.jpg")).astype(np.int32)
    br2 = np.array(io.imread(stem + "_br2.jpg")).astype(np.int32)
    out = np.array(io.imread(stem + "_pred.jpg")).astype(np.int32)

    # Start from the full correction source - (br1 - br2). Then shrink the correction
    # in 0.05 steps until at most 5% (by the measure above) of entries are negative.
    res = source - br1 + br2
    summ = _negative_fraction(res)
    alp = 1.
    while summ > 0.05:
        # Clamp at 0: repeated float subtraction of 0.05 can step past 0, and a
        # negative alp can keep producing negatives indefinitely. At alp == 0,
        # res == source (non-negative image data), so the loop always ends.
        alp = max(alp - 0.05, 0.)
        res = source - alp * (br1 - br2)
        summ = _negative_fraction(res)

    res = np.clip(res, 0., 255.).astype(np.uint8)

    io.imwrite("ch_real_coef/ch" + str(num) + "_init.jpg", source.astype(np.uint8))
    io.imwrite("ch_real_coef/ch" + str(num) + "_pred.jpg", out.astype(np.uint8))
    io.imwrite("ch_real_coef/ch" + str(num) + "_using_coefs.jpg", res)

    print(num)
    
        

In [None]:
#%%timeit
#model.predict(x_valid[:1, :, :, :])