In [1]:
import tensorflow as tf
import numpy as np
from scipy import misc
import matplotlib.pyplot as plt

  from ._conv import register_converters as _register_converters


In [2]:
class convolution:
    
    def __init__(self,in_channel,ot_channel,kernel_size,strides):
        
        self.in_channel = in_channel//4
        self.ot_channel = ot_channel//4
        self.weight_size = [self.ot_channel,self.in_channel,kernel_size,kernel_size]
        self.kernel  = (kernel_size,kernel_size)
        self.kernel_size  =  kernel_size
        self.stride  =  strides
             
        #initialize the weights
        self.glorat_normal =  1./np.sqrt(2*(self.in_channel+self.ot_channel))
        self.r_weight  =  tf.random_normal(self.weight_size,stddev=self.glorat_normal)
        self.i_weight  = tf.random_normal(self.weight_size,stddev=self.glorat_normal)
        self.j_weight  =  tf.random_normal(self.weight_size,stddev=self.glorat_normal)
        self.k_weight  =  tf.random_normal(self.weight_size,stddev=self.glorat_normal)
        
        
        
     
        
    def conv2D(self,x):
        
        qr_kernelR = tf.concat([self.r_weight,-self.i_weight,-self.j_weight,-self.k_weight],axis=1)
        qr_kerneli = tf.concat([self.i_weight,self.r_weight,-self.j_weight,self.k_weight],axis=1)
        qr_kernelj = tf.concat([self.j_weight,self.i_weight,self.r_weight,-self.k_weight],axis=1)
        qr_kernelk = tf.concat([self.k_weight,-self.i_weight,self.j_weight,self.r_weight],axis=1)
        
        weight =tf.concat([qr_kernelR,qr_kerneli,qr_kernelj,qr_kernelk],axis=0)
        
        
        
        
        weight  = tf.reshape(weight,(weight.shape[3],weight.shape[2],weight.shape[1],weight.shape[0]))
        
        weight  = tf.Variable(weight)
      
        bias= tf.random_uniform([weight.shape[3].value])
              
        
        Z = tf.nn.conv2d(x,weight,strides=[1,2,2,1],padding="SAME")
       
        Z = tf.layers.batch_normalization(Z,trainable=True,momentum=0.1)
        
        Z = tf.add(Z,bias)
        
        output  = tf.nn.sigmoid(Z)
        
        return output
    
    
    def calculateOutputShape(self,input_shape,stride,kernel_size):
        
        x =  (input_shape -1)*stride +2
        return x.value

    
    def conv2DTranspose(self,x):
        
        qr_kernelR = tf.concat([self.r_weight,-self.i_weight,-self.j_weight,-self.k_weight],axis=1)
        qr_kerneli = tf.concat([self.i_weight,self.r_weight,-self.j_weight,self.k_weight],axis=1)
        qr_kernelj = tf.concat([self.j_weight,self.i_weight,self.r_weight,-self.k_weight],axis=1)
        qr_kernelk = tf.concat([self.k_weight,-self.i_weight,self.j_weight,self.r_weight],axis=1)
        
        weight = tf.concat([qr_kernelR,qr_kerneli,qr_kernelj,qr_kernelk],axis=0)
        weight  = tf.reshape(weight,(weight.shape[3],weight.shape[2],weight.shape[0],weight.shape[1]))
        
        weight  =  tf.Variable(weight)
        
        #H1(input_size)*stride  if padding == SAME
    
        H1  = [1,self.calculateOutputShape(x.shape[1],self.stride,self.kernel_size),self.calculateOutputShape(x.shape[2],self.stride,self.kernel_size),weight.shape[2].value]
        
        Z = tf.nn.conv2d_transpose(x,weight,strides=[1,2,2,1],output_shape=H1,padding="SAME")
       
        Z = tf.layers.batch_normalization(Z,trainable=True,momentum=0.1)
        bias  =  tf.random_normal([weight.shape[2].value])        
        Z  = tf.add(Z,bias)

        output = tf.nn.sigmoid(Z)
     
        return output
    
    
    @staticmethod
    def forwardModel(x):
        #encoder
        cnv = convolution(4,32,3,2)
        output1  = cnv.conv2D(x)
        
      
        cnv =  convolution(32,64,3,2)
        output2 = cnv.conv2D(output1)
       
        #decoder
        cnv =  convolution(64,32,3,2)
        output3 = cnv.conv2DTranspose(output2)
      
        cnv =  convolution(32,4,3,2)
        y = cnv.conv2DTranspose(output3)
       
        return y

    


        
        

In [3]:
def rgb2gray(rgb):
    return np.dot(rgb[...,:3], [0.2989, 0.5870, 0.1140])

In [5]:
from PIL import Image
size = 400

train_original = Image.open("kodim05.png")
train_original = train_original.resize((size,size))
train_original = np.array(train_original)
train_original = rgb2gray(train_original)


train = np.zeros((size,size,4))
train[:,:,0]= np.zeros(size)
train[:,:,1]= train_original
train[:,:,2]= train_original
train[:,:,3]= train_original
train = np.expand_dims(train,0)

train = train.astype('float32')
train /= 255


test_original = Image.open("kodim14.png")
test_original = test_original.resize((size,size))
test_original = np.asarray(test_original)

test = np.zeros((size,size,4))
test[:,:,0]= np.zeros(size)
test[:,:,1]= test_original[:,:,0]
test[:,:,2]= test_original[:,:,1]
test[:,:,3]= test_original[:,:,2]
test = np.expand_dims(test,0)

test = test.astype('float32')

test /=  255


X  = tf.placeholder('float32',[None,train.shape[1],train.shape[2],train.shape[3]])
Y  = tf.placeholder('float32',[None,train.shape[1],train.shape[2],train.shape[3]])
                    

In [6]:
def evaluation():
    sess = tf.Session()
    print(sess.run(tf.image.psnr(np.squeeze(test_output,0)[:,:,1:],test_original/255,max_val=1)))
    print(sess.run(tf.image.ssim(tf.convert_to_tensor(np.squeeze(test_output,0)[:,:,1:]),tf.convert_to_tensor(test_original.astype('float32')/255),max_val=1)))
    sess.close()
    

In [7]:
# forward output 
pred = convolution.forwardModel(X)

#MSE
loss  =  tf.losses.mean_squared_error(Y,pred)
optimizer =  tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

In [None]:
init =  tf.global_variables_initializer()

y_predict = []
with tf.Session() as sess:
    sess.run(init)
    for i in range(3000):
        loss_value,opt = sess.run([loss,optimizer],feed_dict={X:train,Y:train})
        if i % 500 ==0:
            print(loss_value)
    train_output = sess.run(pred,feed_dict={X:train})
    test_output =  sess.run(pred,feed_dict={X:test})

plt.imshow(np.squeeze(train_output,0)[:,:,1:])
plt.show()
plt.imshow(train_original)
plt.show()
plt.imshow(np.squeeze(test_output,0)[:,:,1:])
plt.show()
plt.imshow(test_original)
plt.show()
evaluation()

0.1487344
0.027891656
0.013148128
0.020367214
0.02821776
0.017417014
