In [2]:
import numpy as np
import cv2
import keras
from keras import optimizers
from keras.models import *
from keras.layers import Input, merge, Conv2D, concatenate, Add
from keras.optimizers import *
from math import pi as pi
from scipy import signal
from keras import backend as K
from math import pi as pi
%matplotlib qt
import matplotlib.pyplot as plt


def activation_square(x):
    """Element-wise square, used as the conv_O activation to turn a field component into intensity.

    Written as ``x * x`` so it works on Keras/TF backend tensors as well as NumPy
    arrays. ``np.square`` on a backend tensor only works through NumPy's fallback
    paths (object arrays / array conversion), and may fail for symbolic tensors.

    Parameters
    ----------
    x : tensor or array-like
        Values to square.

    Returns
    -------
    Same type as ``x`` with every element squared.
    """
    return x * x


def spiral_kxky(filename, ledNum):
    """Load the first `ledNum` (kx, ky) LED positions from a comma-separated text file.

    Each non-blank line must hold exactly two comma-separated numbers: kx, then ky.
    Blank lines (e.g. a trailing newline at the end of the file) are skipped.

    Parameters
    ----------
    filename : str
        Path of the text file.
    ledNum : int
        Maximum number of rows to return.

    Returns
    -------
    numpy.ndarray of float, shape (min(ledNum, number of rows), 2)
        Column 0 is kx, column 1 is ky.

    Raises
    ------
    ValueError
        If a non-blank line does not contain exactly two numeric fields.
    """
    rows = []
    with open(filename, 'r') as fh:
        for lineno, line in enumerate(fh, start=1):
            if not line.strip():
                continue
            fields = line.split(",")
            if len(fields) != 2:
                raise ValueError('{}:{}: expected 2 comma-separated values, got {}'.format(
                    filename, lineno, len(fields)))
            # Builtin float: np.float was only an alias of it and was removed in NumPy 1.24.
            rows.append([float(value) for value in fields])
    # reshape keeps the (N, 2) shape even for an empty file.
    return np.asarray(rows, dtype=float).reshape(-1, 2)[:ledNum, :]


def show_result(model, show=0, noShow=10, reference=None):
    """Read the complex object stored in the 'conv_O' kernel and optionally plot it.

    The kernel of layer 'conv_O' holds the object's real part in input channel 0 and
    its imaginary part in input channel 1. It is flipped in both spatial axes, which
    is how the weights are written before prediction and training.

    Parameters
    ----------
    model : Keras model
        Must contain a layer named 'conv_O' whose first weight is a kernel of shape
        (size, size, 2, 1).
    show : bool-like
        If truthy, draw a 2x3 figure comparing the recovery with `reference`.
    noShow : int
        Number of border pixels cropped on each side in the figure.
    reference : complex ndarray, optional
        Ground-truth object for the bottom row of the figure. Defaults to the
        notebook-global ``obj`` and is only read when `show` is truthy.

    Returns
    -------
    complex ndarray, shape (size, size)
        The kernel as stored, i.e. still flipped in both axes relative to the
        displayed images.
    """
    kernel = np.asarray(model.get_layer('conv_O').get_weights()[0])
    size = kernel.shape[0]
    c_complex = (kernel[:, :, 0, :] + 1j * kernel[:, :, 1, :]).reshape((size, size))

    if show:
        if reference is None:
            reference = obj  # notebook-global ground-truth object
        # Undo the double flip so every recovered panel, including its FT, is in the
        # same orientation as the reference it is compared against.
        recovered = c_complex[::-1, ::-1]

        def log_spectrum(field):
            # Log-magnitude of the centred 2-D Fourier transform.
            return np.log(np.abs(np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(field)))))

        crop = np.s_[noShow:size - noShow, noShow:size - noShow]
        panels = [
            (np.abs(recovered), 'recover (abs)'),
            (np.angle(recovered), 'recover (phase)'),
            (log_spectrum(recovered), 'recover FT'),
            (np.abs(reference), 'high res (abs)'),
            (np.angle(reference), 'high res (phase)'),
            (log_spectrum(reference), 'high res FT'),
        ]
        plt.figure()
        for position, (image, title) in enumerate(panels, start=1):
            plt.subplot(2, 3, position)
            plt.imshow(image[crop], cmap='gray')
            plt.title(title)
        plt.show()

    return c_complex

In [3]:
# Set parameters
# -- illumination / objective --
wlength = 0.532*1e-6        # wavelength; presumably metres (532 nm)
NA = 0.1                    # objective numerical aperture
k0 = 2 * pi / wlength       # free-space wavenumber

# -- sampling grid --
spsize = (3.45*1e-6)/2      # low-res pixel size; presumably 3.45 um sensor pixel / 2x magnification
psize = spsize/4            # pixel size of the simulated object grid (4x finer than spsize)
imSize = 128                # side length of the imSize x imSize object grid, in pixels
imCenter = imSize // 2      # centre index of that grid

# -- illumination pattern / forward model --
arraysize = 15              # arraysize**2 illumination positions are used
NAstep = 0.05               # scale applied to the spiral (kx, ky) coordinates
index_downSample = 1        # conv_O stride; set to 4 to simulate downsampled measurements

In [4]:
# Load image
def _load_gray(path):
    """Read `path` as a single-channel 8-bit image, failing loudly if it cannot be read."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError(path)
    return img

# Convert to float before adding the +10 offset (presumably there to keep the amplitude
# strictly positive). On the raw uint8 image, pixels above 245 would wrap around to 0-9.
imgAmp = cv2.resize(_load_gray('cameraman.bmp'), (imSize, imSize),
                    interpolation=cv2.INTER_CUBIC).astype(float) + 10  # input amplitude
imgPhase = cv2.resize(_load_gray('westconcordorthophoto.bmp'), (imSize, imSize),
                      interpolation=cv2.INTER_CUBIC).astype(float)  # input phase
imgPhase = cv2.normalize(imgPhase, None, -1, 1.0, cv2.NORM_MINMAX)  # min-max rescale to [-1, 1]
obj = imgAmp * np.exp(1j * 0.5 * pi * imgPhase)  # complex object, phase in [-pi/2, pi/2]

# Generate CTF: binary disk of radius cutoffFrequency (in frequency pixels) centred on the grid.
# NOTE(review): the DFT frequency step is normally 2*pi/(psize*imSize). The (imSize-1) form is
# left unchanged so that results stay comparable; confirm which is intended.
dkxy = 2*pi/psize/(imSize-1)
cutoffFrequency = (NA * k0 / dkxy)
center = [imCenter, imCenter]
kYY, kXX = np.ogrid[:imSize, :imSize]
CTF = np.sqrt((kXX - center[0]) ** 2 + (kYY - center[1]) ** 2) <= cutoffFrequency
CTF = CTF.astype(float)

# Show input image and CTF
fig, axes = plt.subplots(1, 3)
for ax, image, title in zip(axes, (imgAmp, imgPhase, CTF), ('Amplitude', 'Phase', 'CTF')):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
plt.show()

In [5]:
# Generate low res images
# For each illumination position, shift the CTF by the illumination wave-vector and take
# its PSF P. P is packed into two 2-channel inputs so that the channel sums of the shared
# real-valued conv_O kernel (channel 0 = Re O, channel 1 = Im O) give the real and
# imaginary parts of the complex convolution O*P:
#   input 1 = (Re P, -Im P)  ->  Re O*Re P - Im O*Im P = Re(O*P)
#   input 2 = (Im P,  Re P)  ->  Re O*Im P + Im O*Re P = Im(O*P)
n_leds = arraysize ** 2
imgs_train_input1 = np.empty((n_leds, imSize, imSize, 2))  # input real(PSF), -imag(PSF)
imgs_train_input2 = np.empty((n_leds, imSize, imSize, 2))  # input imag(PSF), real(PSF)
kxky = spiral_kxky('spiral_kxky.txt', n_leds)   # load kx, ky here
print('kxky shape:', kxky.shape)
# Every slot of the uninitialised input arrays is written below, so one (kx, ky) row per
# illumination is required. Fail here rather than with an IndexError mid-loop.
assert kxky.shape == (n_leds, 2), \
    'expected {} (kx, ky) rows in spiral_kxky.txt, got shape {}'.format(n_leds, kxky.shape)
for i in range(n_leds):
    kx = kxky[i, 0] * NAstep
    ky = kxky[i, 1] * NAstep
    # Illumination shift in frequency pixels; int() truncates toward zero.
    kxIllu = int(kx * k0 / dkxy)
    kyIllu = int(ky * k0 / dkxy)
    # NOTE(review): kxIllu shifts along axis 0 (rows), while the CTF grid indexes x with
    # axis 1 (kXX), so x/y may be swapped. Left unchanged because simulation and
    # reconstruction both use these same PSFs.
    ctfIllu = np.roll(CTF, [kxIllu, kyIllu], axis=(0, 1))
    psfIllu = np.fft.fftshift(np.fft.ifft2(np.fft.ifftshift(ctfIllu)))
    psfIlluReal = np.real(psfIllu)
    psfIlluImag = np.imag(psfIllu)

    imgs_train_input1[i, :, :, 0] = psfIlluReal
    imgs_train_input1[i, :, :, 1] = -psfIlluImag
    imgs_train_input2[i, :, :, 0] = psfIlluImag
    imgs_train_input2[i, :, :, 1] = psfIlluReal

# show result
fig, (ax_re, ax_im) = plt.subplots(1, 2)
ax_re.imshow(imgs_train_input1[0, :, :, 0], cmap='gray')
ax_re.set_title('Input real(PSF)')
ax_im.imshow(imgs_train_input2[0, :, :, 0], cmap='gray')
ax_im.set_title('Input imag(PSF)')
plt.show()

kxky shape: (225, 2)


In [6]:
# input layer
input_1 = Input((imSize, imSize, 2), name='input_1')  # channel 1: Pr, channel 2: -Pi
input_2 = Input((imSize, imSize, 2), name='input_2')  # channel 1: Pi, channel 2: Pr
# define O: one Conv2D layer shared by both inputs. Its imSize x imSize x 2 kernel holds the
# object (channel 0 = real part, channel 1 = imaginary part), and the squared activation
# turns each convolved field component into an intensity term.
# 'ones' is the canonical initializer name ('one' is a legacy alias). No bias is used.
conv_O = Conv2D(1, imSize, activation=activation_square, padding='same', strides=index_downSample,
                kernel_initializer='ones', use_bias=False, name='conv_O')
# generate low res images: Re(O*P)^2 + Im(O*P)^2 = |O*P|^2
conv1_1 = conv_O(input_1)
conv1_2 = conv_O(input_2)
addLayer = Add()([conv1_1, conv1_2])

model = Model(inputs=[input_1, input_2], outputs=addLayer)
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 128, 128, 2)  0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            (None, 128, 128, 2)  0                                            
__________________________________________________________________________________________________
conv_O (Conv2D)                 (None, 128, 128, 1)  32768       input_1[0][0]                    
                                                                 input_2[0][0]                    
__________________________________________________________________________________________________
add_1 (Add)                     (None, 128, 128, 1)  0           conv_O[0][0]                     
          

In [7]:
# set high resolution image as the conv2D layer's weight
# Keras Conv2D computes a cross-correlation, so the object is flipped in both axes to make
# conv_O perform a true convolution with the PSF inputs.
weight_o = np.zeros((1, imSize, imSize, 2, 1))
weight_o[0, :, :, 0, 0] = np.real(obj)[::-1, ::-1]
weight_o[0, :, :, 1, 0] = np.imag(obj)[::-1, ::-1]
model.get_layer('conv_O').set_weights([weight_o[0]])  # conv_O's only weight is its kernel

# predict to get low resolution image sequences.
# predict() does not require a compiled model, so no optimizer is created here.
# (The deprecated Adam(lr=..., decay=...) arguments are rejected by newer Keras.)
imgs_test_predict = model.predict([imgs_train_input1, imgs_train_input2], batch_size=1, verbose=1)
plt.figure()
plt.imshow(imgs_test_predict[0, :, :, 0], cmap='gray')
plt.title('measurement')
plt.show()



In [8]:
# set low res image as the initial weight
# Initial guess: sqrt of the first simulated intensity, brought onto the imSize x imSize
# object grid, scaled by 1/index_downSample**2, and flipped like the object in the kernel.
# np.resize would tile (repeat) a smaller measurement instead of upsampling it, so it is
# interpolated. Bilinear interpolation keeps the non-negative intensity non-negative
# before the sqrt, and a same-size resize (index_downSample == 1) leaves it unchanged.
meas0 = cv2.resize(np.ascontiguousarray(imgs_test_predict[0, :, :, 0]), (imSize, imSize),
                   interpolation=cv2.INTER_LINEAR)
init_amp = np.sqrt(meas0 / (index_downSample ** 2))[::-1, ::-1]
# NOTE(review): real and imaginary parts both start at init_amp, i.e. a constant pi/4
# initial phase; confirm this starting phase is intended.
weight_o[0, :, :, 0, 0] = init_amp
weight_o[0, :, :, 1, 0] = init_amp
model.get_layer('conv_O').set_weights([weight_o[0]])  # conv_O's only weight is its kernel

# train net
# NOTE: `lr` is the older Keras argument name; Keras >= 2.3 spells it `learning_rate`.
adam = Adam(lr = 1, amsgrad=True)
model.compile(loss='mean_absolute_error', optimizer=adam)
history = model.fit([imgs_train_input1, imgs_train_input2], imgs_test_predict, batch_size=1, epochs=10, verbose=1, shuffle=False)
imRecover = show_result(model, 1)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
