In [None]:
import tensorflow as tf

import os
import numpy as np
import matplotlib.pyplot as plt
import statistics
import pickle

import math
import cv2
import numpy as np
from numpy import cos as cos
from numpy import sin as sin
from numpy import sqrt as sqrt
from numpy import arctan2 as arctan2
from matplotlib import pyplot as plt
import os
import datetime
import time

import sys
sys.path.append('../common/')

from functions import *
from tf_functions import *
from zernike_functions import *

In [None]:
def init_param():
    """Return the fixed simulation parameters used throughout this notebook.

    Returns a 9-tuple in this exact order:
        N            -- number of pixels per side
        z            -- distance from aperture to screen (presumably metres, like p — confirm)
        p            -- pixel size [m]
        l_ambda      -- wavelength (520e-9, i.e. 520 nm assuming metres)
        pupil_r_m    -- pupil radius [m], half of the sampled width N * p
        pupil_r_mm   -- the same pupil radius in millimetres
        lens_f       -- focal length of the lens (units not stated; presumably metres)
        nm_arr       -- Zernike (n, m) index pairs
        nm_coeff_arr -- one Zernike coefficient per (n, m) pair, same order as nm_arr
    """
    # Sampling grid
    pixel_count = 1024
    pixel_pitch = 6.4e-6

    # Optics
    wavelength = 520e-9
    distance = 0.2
    focal_length = 0.1

    # The pupil spans the full sampled aperture.
    radius_m = pixel_count * pixel_pitch / 2
    radius_mm = radius_m * 1000

    # Zernike coefficient settings (all zero -> no aberration)
    zernike_orders = [[2, 0], [3, 1]]
    zernike_coeffs = [0.0, 0.0]

    return (pixel_count, distance, pixel_pitch, wavelength,
            radius_m, radius_mm, focal_length,
            zernike_orders, zernike_coeffs)

In [None]:
# Select a single GPU and enable on-demand memory growth on it.
gpu_id = 1
print(tf.__version__)

# Compare versions numerically: a plain string comparison such as
# tf.__version__ >= "2.1.0" is lexicographic and breaks for e.g. "10.0.0".
_tf_version = tuple(int(part) for part in tf.__version__.split('.')[:2])

if _tf_version >= (2, 1):
    physical_devices = tf.config.list_physical_devices('GPU')
    if gpu_id < len(physical_devices):
        # Must run before TensorFlow initializes the GPUs (i.e. before any op is placed on them).
        tf.config.set_visible_devices(physical_devices[gpu_id], 'GPU')
        tf.config.experimental.set_memory_growth(physical_devices[gpu_id], True)
        print('set gpu_id')
    else:
        # Indexing physical_devices[gpu_id] would raise IndexError on machines
        # with fewer GPUs; keep TensorFlow's default device placement instead.
        print(f'GPU {gpu_id} not available ({len(physical_devices)} GPU(s) found); '
              'using default device placement')

In [None]:
##### Initialization #####
### Initial Params
N, z, p, l_ambda, pupil_r_m, pupil_r_mm, lens_f, nm_arr, nm_coeff_arr = init_param()
back_z = z * -1.0        # distance for propagating in the reverse direction
k = 2*np.pi/l_ambda      # wavenumber

### Pupil Function
p_xy = CGH.pupil_func(N, p, pupil_r_m)

### Wavefront Aberration: coefficient-weighted sum of Zernike terms on the N x N grid
assert len(nm_arr) == len(nm_coeff_arr), 'need exactly one coefficient per Zernike (n, m) pair'
W_xy = np.zeros((N, N))
for nm, nm_coeff in zip(nm_arr, nm_coeff_arr):
    temp_W_xy = CGH.norm_wave_aberration(N, nm)
    temp_W_xy = CGH.resize_and_add_pad(temp_W_xy, N, p, pupil_r_m)
    W_xy += temp_W_xy * nm_coeff

### Reverse Wavefront Aberration
rev_W_xy = W_xy * -1.0

### Generalized Pupil Function
P_xy = CGH.general_pupil(p_xy, W_xy, l_ambda)
rev_P_xy = CGH.general_pupil(p_xy, rev_W_xy, l_ambda)

### Directory and Path for Save
dir_fn = ''
fn = ''
out_path = PreProcess.mkdir_out_dir(dir_fn, fn, N, p, z)


##### Check Aliasing #####
CGH.check_aliasing(lens_f, N, l_ambda, p, z)


##### Input Image #####
in_dir = '../input/'
in_fn = 'baboon_gray.png'   # alternatives: 'peppers_gray.png', 'USAF_1951.png'
in_path = os.path.join(in_dir, in_fn)
amp_img_np = cv2.imread(in_path, cv2.IMREAD_GRAYSCALE)
if amp_img_np is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of raising,
    # which would otherwise surface as an obscure error inside cv2.resize.
    raise FileNotFoundError(f'Could not read input image: {in_path}')
size = (N, N)   # cv2.resize expects (width, height); the grid is square, so order is moot
amp_img_np = cv2.resize(amp_img_np, size)
amp_img_np = amp_img_np / 255.0   # scale 8-bit pixel values to [0, 1]
amp_img = tf.constant(amp_img_np)
amp_img = tf.dtypes.cast(amp_img, tf.float64)


##### Initial Random Phase #####
# Seeded so that Restart & Run All reproduces the same optimization.
# Set PHASE_SEED = None to draw a fresh random start on every run.
PHASE_SEED = 0
phase_rng = np.random.default_rng(PHASE_SEED)
random_phase = phase_rng.uniform(0.0, 2.0*np.pi, size)
init_phase = tf.constant(random_phase)
phase = tf.Variable(random_phase)   # the quantity optimized later


##### Incident Light toward SLM #####
# Uniform unit-amplitude illumination.
light_img = tf.fill(size, 1.0)
light_img = tf.dtypes.cast(light_img, tf.complex128)


##### Target Image #####
# The image's pixel values are used as the target intensity; sqrt gives the amplitude.
source_amp = np.sqrt(amp_img)
source_img = source_amp

# When using Intensity Distribution
goal_img = amp_img
goal_img_tf = tf.constant(goal_img)
goal_img_tf = tf.dtypes.cast(goal_img_tf, tf.complex128)

target = amp_img

In [None]:
# Adam on the SLM phase (learning_rate=1.0 was also tried).
opt = tf.keras.optimizers.Adam(learning_rate=0.1)

NUM_STEPS = 50    # optimizer iterations
SHOW_EVERY = 10   # render diagnostic images every SHOW_EVERY steps (and on the last step)


def _show_fields(prop_rm):
    """Display intensity, amplitude and phase of the propagated field and of the goal image."""
    prop_np = prop_rm.numpy()
    goal_np = goal_img_tf.numpy()
    imgs = [
        CGH.amp_abs(CGH.intensity(prop_np)),
        CGH.amp_abs(CGH.amplitude(prop_np)),
        CGH.phase_norm(CGH.phase(prop_np)),

        CGH.amp_abs(CGH.intensity(goal_np)),
        CGH.amp_abs(CGH.amplitude(goal_np)),
        CGH.phase_norm(CGH.phase(goal_np)),
    ]
    ImageProcess.show_imgs(imgs)


def loss_func(show=True):
    """Per-pixel loss between the normalized propagated intensity and the target.

    Reads the notebook globals phase, light_img, k, l_ambda, z, p, goal_img_tf and target.
    Also prints the PSNR and the mean loss as side effects.

    Args:
        show: when True (the default, as before), display the propagated and goal
            images. The optimization loop passes False on most steps so that figures
            are not drawn on every iteration.

    Returns:
        A float64 tensor of per-pixel losses 0.5 * (I_norm - target) ** 2. It is not
        reduced; ``opt.minimize`` differentiates a non-scalar target as the gradient
        of its sum.
    """
    # Input complex amplitude: illumination * exp(i * phase)
    phase_exp = tf.dtypes.complex(tf.math.cos(phase), tf.math.sin(phase))
    init_plane = tf.math.multiply(light_img, phase_exp)

    # Propagate SLM -> image plane on the zero-padded grid, then crop the padding away.
    # (tf_CGH.angular_spectrum is the non-band-limited alternative.)
    init_plane_add = tf_CGH.add_zero_padding(init_plane)
    n_pad = init_plane_add.shape[0]
    prop = tf_CGH.band_limited_angular_spectrum(init_plane_add, k, n_pad, l_ambda, z, p)
    prop_rm = tf_CGH.remove_zero_padding(prop)

    if show:
        _show_fields(prop_rm)

    # Intensity U * conj(U); casting complex -> float64 keeps the (only non-zero) real part.
    prop_inten = tf.dtypes.cast(prop_rm * tf.math.conj(prop_rm), tf.float64)
    prop_inten_norm = tf_CGH.normalize_amp_one(prop_inten)

    # PSNR value (computed outside the differentiable path via .numpy())
    psnr_val = cv2.PSNR(prop_inten_norm.numpy(), target.numpy(), R=1)
    print('PSNR: ', psnr_val)

    err = prop_inten_norm - target
    loss = 1/2 * (err ** 2)
    print(tf.reduce_mean(loss))

    return loss


for i in range(NUM_STEPS):
    show_this_step = (i % SHOW_EVERY == 0) or (i == NUM_STEPS - 1)
    opt.minimize(lambda: loss_func(show=show_this_step), [phase])
    # Read the counter from opt.iterations: what minimize() itself returns varies
    # across TensorFlow/Keras optimizer versions.
    step_count = opt.iterations.numpy()
    print(step_count)

In [None]:
##### Validate Optimization Result #####
# Re-run the forward model on the optimized phase with the NumPy CGH helpers.

# Input field on the SLM: illumination * exp(i * phase)
phase_exp = tf.dtypes.complex(tf.math.cos(phase), tf.math.sin(phase))
input_plane = tf.math.multiply(light_img, phase_exp)
input_plane = input_plane.numpy()


# Propagation from SLM to Image Plane on the zero-padded grid.
# The padded size gets its own name: reassigning the global N here would leave it
# at the padded size if anything below failed.
input_plane_add = ImageProcess.add_zero_padding(input_plane)
N_pad = input_plane_add.shape[0]
recon_img = CGH.band_limited_angular_spectrum(input_plane_add, N_pad, l_ambda, z, p)
recon_img = ImageProcess.remove_zero_padding(recon_img)
assert recon_img.shape[0] == N, f'unexpected reconstruction size {recon_img.shape} (expected {N})'


# PSNR value against the target intensity
recon_img_inten = ImageProcess.normalize_amp_one(CGH.intensity(recon_img))
psnr_val = cv2.PSNR(recon_img_inten, target.numpy(), R=1)
print('PSNR: ', psnr_val)


##### Check Images #####
imgs = [
    CGH.amp_abs(CGH.intensity(input_plane)),
    CGH.amp_abs(CGH.amplitude(input_plane)),
    CGH.phase_norm(CGH.phase(input_plane)),

    CGH.amp_abs(CGH.intensity(recon_img)),
    CGH.amp_abs(CGH.amplitude(recon_img)),
    CGH.phase_norm(CGH.phase(recon_img)),
]
ImageProcess.show_imgs(imgs)
opt_out_path = out_path + 'opt_'
ImageProcess.save_imgs(opt_out_path, imgs)