In [2]:
import numpy as np
import cv2
import keras
from keras import optimizers
from keras.optimizers import *
from keras import regularizers
from keras.models import *
from keras.layers import Input, merge, Conv2D, MaxPooling2D, UpSampling2D, Dropout, concatenate, Add, Lambda, Subtract
from math import pi as pi
from scipy import signal
from keras import backend as K
%matplotlib qt
import matplotlib.pyplot as plt


def spiral_kxky(filename, ledNum):
    kxky = [[], []]
    with open(filename, 'r') as file:
        for line in file:
            for j, value in enumerate(line.split(",")):
                kxky[j].append(np.float(value))
    kxky = np.asarray(kxky)
    kxky = kxky.T
    return kxky[:ledNum, :]


def show_result(model, show=0, noShow=10):
    w_conv_Or = model.get_layer('O_FTr').get_weights()
    w_conv_Oi = model.get_layer('O_FTi').get_weights()
    w_conv_Or_array = np.asarray(w_conv_Or)
    w_conv_Oi_array = np.asarray(w_conv_Oi)
    c_real = w_conv_Or_array[0, :, :, 0].reshape((imSize, imSize))
    c_imag = w_conv_Oi_array[0, :, :, 0].reshape((imSize, imSize))
    
    c_complex = c_real + 1j * c_imag
    c_abs = np.abs(c_complex)
    c_phase = np.angle(c_complex)
    im_spatial = np.abs(np.fft.ifft2(np.fft.ifftshift(c_complex)))
    im_phase = np.angle(np.fft.ifft2(np.fft.ifftshift(c_complex)))
    objFTLog = np.log(np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(obj)))))
    
    if show:
        plt.figure()
        plt.subplot(233),plt.imshow(np.log(c_abs[noShow:imSize-noShow, noShow:imSize-noShow]+1), cmap='gray'),plt.title('recover (abs)')
        plt.subplot(232),plt.imshow(im_phase[noShow:imSize-noShow, noShow:imSize-noShow], cmap='gray'),plt.title('recover (phase)')
        plt.subplot(231),plt.imshow(im_spatial[noShow:imSize-noShow, noShow:imSize-noShow], cmap='gray'),plt.title('recover FT')
        plt.subplot(234),plt.imshow(np.abs(obj[noShow:imSize-noShow, noShow:imSize-noShow]), cmap='gray'),plt.title('high res (abs)')
        plt.subplot(235),plt.imshow(np.angle(obj[noShow:imSize-noShow, noShow:imSize-noShow]), cmap='gray'),plt.title('high res (phase)')
        plt.subplot(236),plt.imshow(objFTLog[noShow:imSize-noShow, noShow:imSize-noShow], cmap='gray'),plt.title('high res FT')
        plt.show()
        
    return c_complex


class MyLayer(Layer):

    def __init__(self, output_dims, **kwargs):
        self.output_dims = output_dims

        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=self.output_dims,
                                      initializer='ones',
                                      trainable=True)

        super(MyLayer, self).build(input_shape)  # Be sure to call this somewhere!

    def call(self, x):
        return x*self.kernel

    def compute_output_shape(self, input_shape):
        return (self.output_dims)

In [3]:
# Set parameters
wlength = 0.532*1e-6
NA = 0.1
k0 = 2 * pi / wlength
spsize = (3.45*1e-6)/2
psize = spsize/4
imSize = 128
imCenter = int(imSize / 2)
arraysize = 15
NAstep = 0.05

In [4]:
# Load image
imgAmp = cv2.imread('cameraman.bmp', 0)+10
imgAmp = cv2.resize(imgAmp, (imSize, imSize), interpolation=cv2.INTER_CUBIC).astype(float) # input amplitude
imgPhase = cv2.imread('westconcordorthophoto.bmp', 0)
imgPhase = cv2.resize(imgPhase, (imSize, imSize), interpolation=cv2.INTER_CUBIC).astype(float) # input phase
imgPhase = cv2.normalize(imgPhase, None, -1, 1.0, cv2.NORM_MINMAX)
obj = imgAmp * np.exp(1j * 0.5 * pi * imgPhase)

# Generate CTF
dkxy = 2*pi/psize/(imSize-1)
cutoffFrequency = (NA * k0 / dkxy)
center = [imCenter, imCenter]
kYY, kXX = np.ogrid[:imSize, :imSize]
CTF = np.sqrt((kXX - center[0]) ** 2 + (kYY - center[1]) ** 2) <= cutoffFrequency
CTF = CTF.astype(float)

# Show input image and CTF
plt.figure()
plt.subplot(1, 3, 1),plt.imshow(imgAmp, cmap='gray'),plt.title('Amplitude')
plt.subplot(1, 3, 2),plt.imshow(imgPhase, cmap='gray'),plt.title('Phase')
plt.subplot(1, 3, 3),plt.imshow(CTF, cmap='gray'),plt.title('CTF')
plt.show()

In [5]:
# Generate low res images
imgs_train_input1 = np.ndarray((arraysize ** 2, imSize, imSize, 1),dtype=np.complex64) # input CTF
imgs_train_input2 = np.ndarray((arraysize ** 2, imSize, imSize, 1),dtype=np.float32) # input measurement
kxky = spiral_kxky('spiral_kxky.txt', arraysize ** 2) # load kx, ky here
print('kxky shape:',kxky.shape)
for i in range(arraysize ** 2):
    kx = kxky[i,0] * NAstep
    ky = kxky[i,1] * NAstep
    kxIllu = int(kx * k0 / dkxy)
    kyIllu = int(ky * k0 / dkxy)
    ctfIllu = np.roll(CTF, [kxIllu, kyIllu], axis=(0, 1))  
    imFT = np.fft.fftshift(np.fft.fft2(obj))
    imLowpassFT = imFT * ctfIllu   
    imLowpass = np.fft.ifft2(np.fft.ifftshift(imLowpassFT))
    imMeasurement = np.square(np.abs(imLowpass))
    
    imgs_train_input1[i, :, :, 0] = ctfIllu.astype(np.complex64)
    imgs_train_input2[i, :, :, 0] = imMeasurement 
    
# show result    
plt.figure()
plt.subplot(1, 2, 1),plt.imshow(np.real(imgs_train_input1[0, :, :, 0]), cmap='gray'),plt.title('Input CTF')
plt.subplot(1, 2, 2),plt.imshow(imgs_train_input2[0, :, :, 0], cmap='gray'),plt.title('Input measurement')
plt.show()

kxky shape: (225, 2)


In [7]:
# input layer
input_CTF = Input((imSize, imSize, 1),dtype='complex64', name='input_CTF')  # CTF
input_measurement = Input((imSize, imSize, 1), name='input_measurement')  # measurement
# define O (FT)
O_FTr = MyLayer((imSize, imSize, 1), input_shape= (imSize, imSize, 1), name='O_FTr')
O_FTi = MyLayer((imSize, imSize, 1), input_shape= (imSize, imSize, 1), name='O_FTi')
# CTF * O (FT)
CTFr = Lambda(lambda x: tf.real(x))(input_CTF)
CTFi = Lambda(lambda x: tf.imag(x))(input_CTF)
CrOr = O_FTr(CTFr)
CiOi = O_FTi(CTFi)
CrOi = O_FTi(CTFr)
CiOr = O_FTr(CTFi)
# generate low resolution image (FT)
lowFT_r = Subtract()([CrOr, CiOi])
lowFT_i = Add()([CrOi, CiOr])
lowFT = Lambda(lambda x: tf.cast(x[0], tf.complex64) + 1j * tf.cast(x[1], tf.complex64))([lowFT_r, lowFT_i])
# do ifft
im_iFT = Lambda(lambda x: tf.ifft3d(tf.manip.roll(tf.cast(x, tf.complex64),[64, 64], axis=[0, 1])))(lowFT)
# keep angle, and use sqrt(I) to change the anplitude
iFT_angle = Lambda(lambda x: tf.angle(tf.cast(x, tf.complex64)))(im_iFT)
sqrtI = Lambda(lambda x: tf.sqrt(x))(input_measurement)
sinAngle = Lambda(lambda x: tf.sin(x[0]) * x[1]) ([iFT_angle, sqrtI])
cosAngle = Lambda(lambda x: tf.cos(x[0]) * x[1]) ([iFT_angle, sqrtI])
im_iFT_2 = Lambda(lambda x: tf.cast(x[0], tf.complex64) + 1j * tf.cast(x[1], tf.complex64))([cosAngle, sinAngle])
# do fft
lowFT_2 = Lambda(lambda x: K.tf.manip.roll(tf.fft3d(tf.cast(x, tf.complex64)),[64, 64],axis=[0, 1]))(im_iFT_2)
# calculate the difference between lowFT and lowFT_2
output = Lambda(lambda x: tf.square(tf.abs(tf.cast(x[0], tf.complex64)-tf.cast(x[1], tf.complex64))))([lowFT, lowFT_2])

model = Model(inputs=[input_CTF,input_measurement], outputs=[output])
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_CTF (InputLayer)          (None, 128, 128, 1)  0                                            
__________________________________________________________________________________________________
lambda_12 (Lambda)              (None, 128, 128, 1)  0           input_CTF[0][0]                  
__________________________________________________________________________________________________
lambda_13 (Lambda)              (None, 128, 128, 1)  0           input_CTF[0][0]                  
__________________________________________________________________________________________________
O_FTr (MyLayer)                 (128, 128, 1)        16384       lambda_12[0][0]                  
                                                                 lambda_13[0][0]                  
__________

In [8]:
# set high resolution image as the conv2D layer's weight
weight_or = np.ndarray((1, imSize, imSize, 1))
weight_oi = np.ndarray((1, imSize, imSize, 1))
weight_or[0, :, :, 0] = np.real(imFT)
weight_oi[0, :, :, 0] = np.imag(imFT)
model.get_layer('O_FTr').set_weights(weight_or)
model.get_layer('O_FTi').set_weights(weight_oi)

# predict to get output sequences (ground truth)
model.compile(loss='mean_absolute_error', optimizer=Adam(lr = 0,decay = 0.0))
imgs_test_predict = model.predict([imgs_train_input1,imgs_train_input2], batch_size=1, verbose=1)
plt.figure()  
plt.imshow(np.abs(imgs_test_predict[0, :, :, 0]),cmap='gray'),plt.title('output'),plt.colorbar()
plt.show()



In [9]:
# set low res image FT as the initial weight
imlowFT1 = np.fft.fftshift(np.fft.fft2(np.sqrt(imgs_train_input2[0, :, :, 0])))
weight_or[0, :, :, 0] = np.real(imlowFT1)
weight_oi[0, :, :, 0] = np.imag(imlowFT1)
model.get_layer('O_FTr').set_weights(weight_or)
model.get_layer('O_FTi').set_weights(weight_oi)

# train net
adam = Adam(lr = 3e2,amsgrad = True)
model.compile(loss='mean_absolute_error', optimizer=adam)
history = model.fit([imgs_train_input1, imgs_train_input2], imgs_test_predict, batch_size=1, epochs=5, verbose=1, shuffle=False)
imRecover = show_result(model, 1)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
