In [1]:
import os
import json

In [2]:
import tensorflow as tf
tf.__version__

'2.6.0'

In [3]:
# Side length (in pixels) of the square tiles: images are split into 256x256x3 pieces.
# Used below both for decoding the raw files and for the model input shape.
NUMR = 256

In [4]:
# Datasets of file paths (one string per matching file) for the "short" tiles.
# The matching "medium" file path is derived from each of these in processfile().
# NOTE(review): no `shuffle` argument is given for the validation set; per the tf.data
# docs list_files shuffles by default — confirm that is acceptable for validation.
trainset = tf.data.Dataset.list_files(r"train/short/*.bin", shuffle=True)
testset = tf.data.Dataset.list_files(r"val/short/*.bin")

In [6]:
def processfile(f):
    '''
    Data pair

    Given the path of a "short" tile, load it together with the tile of the
    same file name from the sibling "medium" directory.

    The path is split on '/' and is expected to have the form
    <root>/short/<name>; the partner path is built as <root>/medium/<name>.
    Each file is read as raw uint8 bytes, scaled to float32 in [0, 1] and
    reshaped to (NUMR, NUMR, 3).

    Returns a tuple (short_tile, medium_tile).
    '''
    def _read_tile(path):
        # raw bytes -> uint8 vector -> float32 in [0, 1] -> (NUMR, NUMR, 3)
        raw = tf.io.decode_raw(tf.io.read_file(path), tf.uint8)
        scaled = tf.cast(raw, tf.float32) / 255
        return tf.reshape(scaled, (NUMR, NUMR, 3))

    parts = tf.strings.split(f, sep='/')
    medium_path = tf.strings.join([parts[0], 'medium', parts[2]], separator='/')

    short_tile = _read_tile(f)
    medium_tile = _read_tile(medium_path)
    return short_tile, medium_tile

In [7]:
# Map every "short" file path to a (short, medium) pair of float32 tiles.
datset = trainset.map(processfile)
valset = testset.map(processfile)

In [10]:
def safepowneg(x):
    '''
    Cube root of |x|; taking the absolute value first keeps the fractional
    power defined for negative inputs.
    '''
    magnitude = tf.abs(x)
    return tf.pow(magnitude, 1/3)
def safepowpos(x):
    '''
    |x| raised to the power 2.4; taking the absolute value first keeps the
    fractional power defined for negative inputs.
    '''
    magnitude = tf.abs(x)
    return tf.pow(magnitude, 2.4)

def labconvert(rgb):
    '''
    RGB to LAB (Standard Illuminant D65)

    Args:
        rgb: float tensor whose last axis holds the 3 RGB channels.
             Values are clipped to [0, 1] before conversion.

    Returns:
        Tensor of the same shape with (L, a, b) on the last axis.

    Steps: gamma expansion (sRGB -> linear), linear RGB -> white-point
    normalised XYZ, the CIELAB f() non-linearity, then the affine map to
    L = 116*f(Y) - 16, a = 500*(f(X) - f(Y)), b = 200*(f(Y) - f(Z)).
    '''
    # Linear RGB -> XYZ, rows already divided by the reference white.
    matrix = tf.constant([[0.43388193, 0.37622739, 0.18990225],
       [0.2126    , 0.7152    , 0.0722    ],
       [0.01772529, 0.1094743 , 0.87294736]], dtype=tf.float32)
    # (f(X), f(Y), f(Z)) -> (L + 16, a, b)
    shmatrix = tf.constant([[0, 116, 0], [500, -500, 0], [0, 200, -200]], dtype=tf.float32)
    # Constant term of L = 116*f(Y) - 16; keeps L consistent with labconvertL.
    shoffset = tf.constant([16, 0, 0], dtype=tf.float32)
    val1 = tf.constant(0.04045, dtype=tf.float32)               # gamma-curve breakpoint
    val2 = tf.constant(0.008856451679035631, dtype=tf.float32)  # f() breakpoint

    p = tf.clip_by_value(rgb, clip_value_min=0, clip_value_max=1)
    f = tf.where(p <= val1, 0.07739938080495357 * p, safepowpos(0.9478672985781991 * (p + 0.055)))
    x = tf.einsum('ij,...j->...i', matrix, f)

    # tf.where back-propagates through BOTH branches. The cube root has an
    # infinite derivative at 0, so evaluating it on values that end up in the
    # linear branch yields 0 * inf = NaN gradients. Feed the cube root a
    # harmless 1.0 wherever its result is discarded; forward values are unchanged.
    use_cuberoot = x > val2
    x_cuberoot = tf.where(use_cuberoot, x, tf.ones_like(x))
    die = tf.where(use_cuberoot, safepowneg(x_cuberoot), 0.13793103448275862 + 7.787037037037036 * x)

    fie = tf.einsum('ij,...j->...i', shmatrix, die) - shoffset
    return fie

In [12]:
def labconvertL(rgb):
    '''
    RGB to the L (lightness) channel of LAB only.

    Args:
        rgb: float tensor whose last axis holds the 3 RGB channels.
             Values are clipped to [0, 1] before conversion.

    Returns:
        Tensor with the same leading dimensions and a last axis of size 1
        holding L = 116*f(Y) - 16.
    '''
    # Single row: linear RGB -> relative luminance Y.
    matrix = tf.constant([[0.2126    , 0.7152    , 0.0722    ]], dtype=tf.float32)
    val1 = tf.constant(0.04045, dtype=tf.float32)               # gamma-curve breakpoint
    val2 = tf.constant(0.008856451679035631, dtype=tf.float32)  # f() breakpoint
    
    p = tf.clip_by_value(rgb, clip_value_min=0, clip_value_max=1)
    f = tf.where(p <= val1, 0.07739938080495357 * p, safepowpos(0.9478672985781991 * (p + 0.055)))
    x = tf.einsum('ij,...j->...i', matrix, f)

    # tf.where back-propagates through BOTH branches. The cube root has an
    # infinite derivative at 0, so evaluating it on values that end up in the
    # linear branch yields 0 * inf = NaN gradients. Feed the cube root a
    # harmless 1.0 wherever its result is discarded; forward values are unchanged.
    use_cuberoot = x > val2
    x_cuberoot = tf.where(use_cuberoot, x, tf.ones_like(x))
    die = tf.where(use_cuberoot, safepowneg(x_cuberoot), 0.13793103448275862 + 7.787037037037036 * x)

    fie = 116 * die - 16
    return fie

In [18]:
import keras

In [19]:
from keras.models import Model
from keras.layers import Input, Conv2D, MaxPooling2D, GlobalAveragePooling2D, Dense, Lambda

In [20]:
from keras.layers import Add, Concatenate

In [21]:
import tensorflow.keras.backend as K

In [22]:
def mean_ssim(y_true, y_pred):
    '''
    Multi-scale SSIM between target and prediction, computed on all three
    LAB channels with a dynamic range of 255.
    '''
    reference, estimate = (labconvert(image) for image in (y_true, y_pred))
    return tf.image.ssim_multiscale(img1=reference, img2=estimate, max_val=255)

In [63]:
def mean_ssimL(y_true, y_pred):
    '''
    Multi-scale SSIM between target and prediction, computed on the LAB
    lightness channel only with a dynamic range of 100.
    '''
    reference, estimate = (labconvertL(image) for image in (y_true, y_pred))
    return tf.image.ssim_multiscale(img1=reference, img2=estimate, max_val=100)

In [24]:
def mean_psnr(y_true, y_pred):
    '''
    Peak signal-to-noise ratio (dB) between target and prediction in RGB space.

    `max_val` is the dynamic range of the images. The tiles are scaled to
    [0, 1] when loaded (uint8 / 255), so the range is 1.0; the previous value
    of 2 overstated every PSNR by 20*log10(2) ~= 6.02 dB.
    '''
    return tf.image.psnr(y_true, y_pred, max_val=1.0)

In [25]:
def custom_loss_1(y_true, y_pred):
    '''
    l1 loss

    Mean absolute difference between target and prediction after both have
    been converted to LAB; the mean runs over every element, so the result
    is a scalar.
    '''
    lab_residual = labconvert(y_true) - labconvert(y_pred)
    return K.mean(K.abs(lab_residual))

In [64]:
def custom_loss_2(y_true, y_pred):
    '''
    Loss with ssim_multiscale -- fails to converge :(

    Sum of two terms:
      * the scalar mean absolute error in LAB space, and
      * 1 - multi-scale SSIM of the LAB lightness channel (dynamic range 100).
    '''
    # Term 1: L1 distance over all three LAB channels.
    lab_residual = labconvert(y_true) - labconvert(y_pred)
    mae_term = K.mean(K.abs(lab_residual))

    # Term 2: structural dissimilarity on lightness only.
    lightness_true = labconvertL(y_true)
    lightness_pred = labconvertL(y_pred)
    ssim_term = 1 - tf.image.ssim_multiscale(lightness_true, lightness_pred, 100)

    return mae_term + ssim_term

In [27]:
def lowlevelblock(inputdata, prefix, previous=None):
    '''
    Low-level

    One residual refinement stage working at full resolution.

    Args:
        inputdata: current 3-channel image tensor.
        prefix:    value embedded in the layer names to keep them unique.
        previous:  feature map returned by the preceding block, or None for
                   the first block. When given it is concatenated with
                   `inputdata` along the channel axis before the convolution.

    Returns:
        (img, x): `img` is `inputdata` plus a tanh-bounded 3-channel
        correction; `x` is the 61-channel ReLU feature map to pass on as
        `previous` to the next block.
    '''
    # Identity test: `== None` on a tensor-like object may be routed to an
    # overloaded element-wise `__eq__` instead of a plain None check.
    if previous is None:
        x = inputdata
    else:
        x = Concatenate(name=f'concat_{prefix}')([inputdata, previous])
    x = Conv2D(filters=61,
               kernel_size=(3, 3),
               activation='relu',
               padding='same',
               strides=(1,1),
               name=f'll_relu_{prefix}')(x)
    y = Conv2D(filters=3,
               kernel_size=(3, 3),
               activation='tanh',
               padding='same',
               strides=(1,1),
               name=f'll_th_{prefix}')(x)
    # Residual connection: the block predicts a correction to its input image.
    img = Add(name=f'll_add_{prefix}')([inputdata, y])
    return img, x

In [28]:
def highlevelblock(inputdata, prefix):
    '''
    High-level

    One downsampling stage: a stride-2 3x3 convolution with 64 ReLU filters
    followed by 2x2 max pooling. `prefix` is embedded in the layer names to
    keep them unique.
    '''
    strided_conv = Conv2D(64, (3, 3),
                          strides=(2, 2),
                          padding='same',
                          activation='relu',
                          name=f'hl_conv_{prefix}')
    pooling = MaxPooling2D(pool_size=(2, 2), name=f'hl_pool_{prefix}')
    return pooling(strided_conv(inputdata))

In [29]:
def comp(x):
    '''
    bilinear transformation

    `x` is a pair (matrices, vectors). For every batch item `a`, vector `i`
    and matrix `p` this evaluates the quadratic form
        out[a, i, p] = sum_{k, j} v[a, i, k] * M[a, p, k, j] * v[a, i, j]
    '''
    matrices, vectors = x[0], x[1]
    return tf.einsum('aik,apkj,aij->aip', vectors, matrices, vectors)

def mxmake(m):
    '''
    Three 4x4 matrices from one 30-component vector

    The input is reshaped to (-1, 3, 10); each group of 10 coefficients fills
    the upper triangle (diagonal included) of one 4x4 matrix in row-major
    order, and everything below the diagonal stays zero.
    '''
    coefficients = tf.reshape(m, (-1, 3, 10))

    # Row-major (row, col) positions of the 10 upper-triangle entries.
    slots = [(row, col) for row in range(4) for col in range(row, 4)]
    # One one-hot 4x4 basis matrix per coefficient.
    basis = tf.constant(
        [[[1 if (row, col) == slot else 0 for col in range(4)]
          for row in range(4)]
         for slot in slots],
        dtype=tf.float32)

    return tf.einsum('...j,jkl->...kl', coefficients, basis)

def vecmake(x):
    '''
    Extending vector with a constant

    Appends a single trailing 1 along the last axis of a rank-3 tensor,
    leaving the first two axes untouched.
    '''
    # (before, after) padding per axis: only one element after the last axis.
    paddings = tf.constant([[0, 0], [0, 0], [0, 1]])
    return tf.pad(x, paddings, mode="CONSTANT", constant_values=1)

### Defining model

In [30]:
# Network input: one NUMR x NUMR tile with 3 channels.
input_img = Input(shape=(NUMR, NUMR, 3))

# `img` tracks the progressively corrected image; `out` tracks the feature map
# of the most recent block (None before the first block).
img = input_img
out = None

# Low-level stack: 10 full-resolution residual blocks. Each block after the
# first also receives the previous block's 61-channel feature map.
for lowstage in range(10):
    img, out = lowlevelblock(img, lowstage, out)

# High-level stack: 4 downsampling blocks applied to the last low-level feature map.
for highstage in range(4):
    out = highlevelblock(out, highstage)

# Collapse the spatial axes, then predict 30 coefficients per image (linear, no activation).
out = GlobalAveragePooling2D(name="hl_glob_pool")(out)
out = Dense(30, name="dense")(out)

# 30 coefficients -> three upper-triangular 4x4 matrices, one per output channel.
out = Lambda(lambda x:mxmake(x), name="lambda4x4")(out)

# Flatten the pixels to a list of 3-vectors and append a constant 1 to each,
# so that the quadratic form below also covers linear and constant terms.
img = tf.keras.layers.Reshape((NUMR*NUMR,3), name="reshape_1")(img)
img = Lambda(lambda x:vecmake(x), name="lambda_addconst")(img)
# Per-pixel quadratic colour transform: v^T M_p v for each of the three matrices.
res = Lambda(lambda x:comp(x), name="lambda_bilinear")((out,img))
# Back to image layout.
res = tf.keras.layers.Reshape((NUMR,NUMR,3), name="reshape_2")(res)

model = Model(input_img, res, name="coder")

In [31]:
# Print the layer-by-layer overview (output shapes, parameter counts, connectivity).
model.summary()

Model: "coder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
ll_relu_0 (Conv2D)              (None, 256, 256, 61) 1708        input_1[0][0]                    
__________________________________________________________________________________________________
ll_th_0 (Conv2D)                (None, 256, 256, 3)  1650        ll_relu_0[0][0]                  
__________________________________________________________________________________________________
ll_add_0 (Add)                  (None, 256, 256, 3)  0           input_1[0][0]                    
                                                                 ll_th_0[0][0]                

### Training

In [37]:
# Optimise the LAB-space L1 loss with Adam (learning rate 1e-4); track PSNR and
# multi-scale SSIM on lightness only (mean_ssimL) and on all LAB channels (mean_ssim).
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001), loss=custom_loss_1,
              metrics=[mean_psnr, mean_ssimL, mean_ssim])

In [None]:
# Train for 4 epochs on batches of 8 pairs; validation uses batches of 4.
model_history = model.fit(datset.batch(8),
                          epochs=4,
                          validation_data=valset.batch(4))

In [105]:
# Serialise the model configuration to JSON; the weights are saved separately below.
# NOTE(review): to_json() presumably returns a str (per the Keras docs), not a dict — confirm.
heavycfg = model.to_json()

In [189]:
# open(..., 'w') does not create missing parent directories, so make sure the
# output directory exists before writing.
os.makedirs('model', exist_ok=True)
# NOTE(review): if heavycfg is already a JSON string, json.dump() encodes it a
# second time (the file then holds one quoted, escaped string). A reader must
# json.load() the file and pass the resulting string to model_from_json().
# The format is kept as-is so that existing loaders keep working.
with open('model/heavy_256.json', 'w') as outfile:
    json.dump(heavycfg, outfile)

In [190]:
# Save the trained weights under the path prefix 'model/heavy' (architecture is in the JSON file).
model.save_weights('model/heavy')