In [None]:
import numpy as np
import tensorflow as tf
import math
import cmath
import scipy

from keras.utils.vis_utils import plot_model
from tensorflow.keras import layers,regularizers
from tensorflow.keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Dropout
from tensorflow.keras.layers import Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D, concatenate
from tensorflow.keras.models import Model, load_model
from tensorflow.keras import metrics
from tensorflow.keras import optimizers
from tqdm import tqdm
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import imageio.v3 as iio
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error

In [None]:
# Dataset root with Train/ and Test/ sub-folders, each holding Images/ and Labels/.
# Set IMQUAL_DATA_DIR to override the default (machine-specific) location.
data_loc = os.environ.get(
    "IMQUAL_DATA_DIR",
    "C:/Users/yrq64132/OneDrive - Science and Technology Facilities Council/Documents/MATLAB/Petros/DeepLearning/ImQual_Dataset_nonroot2",
)
# os.listdir returns entries in arbitrary order, so sort both listings before pairing
# images[i] with labels[i].
# NOTE(review): this assumes image and label filenames sort into the same order
# (e.g. matching numeric suffixes) — confirm against the dataset naming scheme.
images = sorted(os.listdir(data_loc + "/Train/Images"))
labels = sorted(os.listdir(data_loc + "/Train/Labels"))
assert len(images) == len(labels), (
    f"Found {len(images)} training images but {len(labels)} label files"
)

X_train_orig = list()
Y_train_orig = list()
n_train_images = len(images)
for image_name, label_name in tqdm(zip(images, labels), total=n_train_images):
    X_train_orig.append(iio.imread(data_loc + "/Train/Images/" + image_name))
    Y_train_orig.append(pd.read_csv(data_loc + "/Train/Labels/" + label_name, header = None))

In [None]:
# Image side length (second axis of the stacked images) and number of coefficients per
# label file (second axis of the stacked label frames); images are treated as square.
resolution, n_coeffs = np.shape(X_train_orig)[1], np.shape(Y_train_orig)[1]
image_shape = (resolution,) * 2

In [None]:
# Stack the per-file arrays: images gain an explicit single channel axis, labels become a
# (n_train_images, n_coeffs) float32 matrix.
X_train = np.asarray(X_train_orig).reshape([n_train_images, resolution, resolution, 1]).astype(int)
Y_train = np.asarray(Y_train_orig).reshape([n_train_images, n_coeffs]).astype(np.single)

# Pair images with their coefficient vectors, batch by 64 (keeping the last partial batch)
# and prefetch 8 batches.
# NOTE(review): the training pipeline is not shuffled — consider .shuffle(...) before .batch.
images_train_dataset, labels_train_dataset = (
    tf.data.Dataset.from_tensor_slices(array) for array in (X_train, Y_train)
)
train_dataset = (
    tf.data.Dataset.zip((images_train_dataset, labels_train_dataset))
    .batch(64, drop_remainder=False)
    .prefetch(8)
)

In [None]:
# Load the test split like the training split. os.listdir order is arbitrary, so both
# listings are sorted before pairing images[i] with labels[i].
# NOTE(review): assumes image and label filenames sort into the same order — confirm.
images = sorted(os.listdir(data_loc + "/Test/Images"))
labels = sorted(os.listdir(data_loc + "/Test/Labels"))
assert len(images) == len(labels), (
    f"Found {len(images)} test images but {len(labels)} label files"
)

X_test_orig = list()
Y_test_orig = list()
n_test_images = len(images)
for image_name, label_name in tqdm(zip(images, labels), total=n_test_images):
    X_test_orig.append(iio.imread(data_loc + "/Test/Images/" + image_name))
    Y_test_orig.append(pd.read_csv(data_loc + "/Test/Labels/" + label_name, header = None))

# Reuse the training resolution / coefficient count; reshape fails if the test data disagree.
X_test = np.reshape(X_test_orig,[n_test_images, resolution, resolution, 1]).astype(int)
Y_test = np.reshape(Y_test_orig,[n_test_images, n_coeffs]).astype(np.single)

images_test_dataset = tf.data.Dataset.from_tensor_slices(X_test)
labels_test_dataset = tf.data.Dataset.from_tensor_slices(Y_test)
test_dataset = tf.data.Dataset.zip((images_test_dataset, labels_test_dataset))
test_dataset = test_dataset.batch(64, drop_remainder = False)
test_dataset = test_dataset.prefetch(8)

In [None]:
def conv_block(inputs, n_filters, k_size, stride,layer_num, pad = "same",max_pool = False, pool_size = (4, 4)):
    """Conv2D (ReLU) -> BatchNormalization [-> MaxPooling2D] building block.

    Parameters
    ----------
    inputs : Keras tensor fed to the convolution.
    n_filters, k_size, stride : forwarded to ``Conv2D`` (filters, kernel size, strides).
    layer_num : suffix for the layer names ``conv<n>``, ``batch_norm<n>`` and
        ``max_pool<n>``; must be unique within a model.
    pad : ``Conv2D`` padding mode (default ``"same"``).
    max_pool : when True, append a ``MaxPooling2D`` layer with ``"same"`` padding.
    pool_size : pooling window used when ``max_pool`` is True (default ``(4, 4)``,
        the previously hard-coded value).

    Returns
    -------
    The output Keras tensor of the block.
    """
    suffix = str(layer_num)
    X = Conv2D(n_filters, k_size, stride, activation="relu", padding=pad, name='conv' + suffix)(inputs)
    # NOTE(review): ReLU is applied before batch normalisation here; kept as-is so the
    # layer graph matches existing saved weights.
    X = BatchNormalization(axis=-1, name='batch_norm' + suffix)(X)
    if max_pool:
        X = MaxPooling2D(pool_size=pool_size, padding="same", name='max_pool' + suffix)(X)

    return X
    

# --- Coefficient-prediction CNN ------------------------------------------------
# Input: one single-channel far-field image of shape image_shape + (1,).
input_img = tf.keras.Input(shape = image_shape + (1,),name='input_image')
# Six stages of three 5x5 conv blocks. Filters double per stage (16 -> 512) and the third
# block of each stage uses stride 2, halving the spatial size six times ("same" padding,
# so a 200x200 input ends at 4x4 — TODO confirm against `resolution`).
conv1 = conv_block(input_img,16,5,1,1)
conv2 = conv_block(conv1,16,5,1,2)
conv3 = conv_block(conv2,16,5,2,3)
conv4 = conv_block(conv3,32,5,1,4)
conv5 = conv_block(conv4,32,5,1,5)
conv6 = conv_block(conv5,32,5,2,6)
conv7 = conv_block(conv6,64,5,1,7)
conv8 = conv_block(conv7,64,5,1,8)
conv9 = conv_block(conv8,64,5,2,9)
conv10 = conv_block(conv9,128,5,1,10)
conv11 = conv_block(conv10,128,5,1,11)
conv12 = conv_block(conv11,128,5,2,12)
conv13 = conv_block(conv12,256,5,1,13)
conv14 = conv_block(conv13,256,5,1,14)
conv15 = conv_block(conv14,256,5,2,15)
conv16 = conv_block(conv15,512,5,1,16)
conv17 = conv_block(conv16,512,5,1,17)
conv18 = conv_block(conv17,512,5,2,18)


# Flatten the final feature maps for the fully connected head.
flat19 = Flatten(name='flat19')(conv18)

# Dense head with dropout. Each hidden layer's output is concatenated with that layer's
# input (skip connections), so fc22 sees drop21, drop20 and flat19 together.
fc20 = Dense(1000,activation = "relu",name='fc20')(flat19)
drop20 = Dropout(0.5,name='Dropout20')(fc20)
merge20 = concatenate([drop20,flat19],name='merge20')
# NOTE(review): these regularisation penalties only influence training if the training
# loop adds the model's `losses` to the optimised objective.
fc21 = Dense(256,
            activation = "relu",
            kernel_regularizer=regularizers.L2(1e-4),
            bias_regularizer=regularizers.L2(1e-4),
            activity_regularizer=regularizers.L2(1e-5),name='fc21')(merge20)

drop21 = Dropout(0.5,name='Dropout21')(fc21)
merge21 = concatenate([drop21,merge20],name='merge21')
# Linear output layer: 12 predicted Zernike coefficients.
# NOTE(review): generate_shape uses only the first min(n_coeffs, n_polynomials) columns of
# its coefficient input, so the wider true-coefficient input is compared on fewer modes —
# confirm that 12 is intended.
fc22 = Dense(12,activation = None,name='fc22')(merge21)

coefficient_network = tf.keras.Model(inputs = input_img, outputs = fc22, name="coefficient_network")
plot_model(coefficient_network, show_shapes=True)

In [None]:
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar form (element-wise, like MATLAB's cart2pol).

    Parameters
    ----------
    x, y : scalars or NumPy arrays of matching/broadcastable shape.

    Returns
    -------
    tuple
        ``(rho, phi)``: the radius ``sqrt(x**2 + y**2)`` and the angle ``arctan2(y, x)``
        in radians.
    """
    return (np.sqrt(x**2 + y**2), np.arctan2(y, x))

def generate_orders(start,final):
    """List the (n, m) index pairs of all Zernike modes with radial order start..final.

    For each radial order ``n`` the azimuthal frequencies are ``m = -n, -n+2, ..., n``
    (``n + 1`` values), so ``n`` is repeated ``n + 1`` times in the first array.

    Returns
    -------
    tuple of ndarray
        ``(N, M)``: radial orders as float64 and matching azimuthal frequencies as
        integers, concatenated over all orders.

    Raises
    ------
    ValueError
        If ``start > final`` (nothing to concatenate).
    """
    orders = range(start, final + 1)
    radial_orders = np.concatenate([order * np.ones((order + 1,)) for order in orders])
    azimuthal_freqs = np.concatenate([np.arange(-order, order + 1, 2) for order in orders])
    return (radial_orders, azimuthal_freqs)
        

def myzernike(n,m,r,theta):
    """Evaluate unnormalised Zernike polynomials at a set of polar sample points.

    Parameters
    ----------
    n, m : sequences of radial orders and azimuthal frequencies (e.g. from
        ``generate_orders``); ``n - |m|`` is expected to be even and non-negative.
    r, theta : 1-D arrays with the radius and angle of every sample point.

    Returns
    -------
    ndarray, shape (len(r), len(n))
        Column ``j`` is the radial polynomial ``R_{n_j}^{|m_j|}(r)`` multiplied by
        ``cos(m_j * theta)`` when ``m_j >= 0`` and by ``sin(|m_j| * theta)`` when ``m_j < 0``.
    """
    Z = np.zeros((len(r), len(n)))
    for col in range(len(n)):
        order, freq = n[col], m[col]
        # Radial part: explicit factorial series for R_n^|m|(r).
        radial = np.zeros(len(r))
        for s in range(int((order - abs(freq)) / 2) + 1):
            numerator = math.factorial(int(order - s))
            denominator = (math.factorial(int(s))
                           * math.factorial(int((order + freq) / 2 - s))
                           * math.factorial(int((order - freq) / 2 - s)))
            radial = radial + (-1)**s * numerator / denominator * r**(order - 2*s)

        # Angular part: negative m selects the sine term (sign flipped so it is sin(|m| theta)).
        Z[:, col] = -radial * np.sin(theta * freq) if freq < 0 else radial * np.cos(theta * freq)

    return Z

# Sample the unit-disc pupil on a square grid and evaluate the Zernike basis there.
# NOTE(review): this overwrites the data-derived `resolution`; the two must agree.
resolution = 200
X = np.linspace(-1,1,resolution)
[x,y] = np.meshgrid(X,X)
[r,theta] = cart2pol(x,y)
pupil = (r<=1)  # boolean aperture mask on the grid
# NOTE: Python's meshgrid ordering differs from MATLAB's; transpose theta (and Z later) if
# Z must match the MATLAB layout.
# Keep the in-pupil samples using the pupil mask itself (row-major, the same order as
# tf.where(pupil)). Filtering r and theta with their own np.nonzero masks dropped different
# points wherever r or theta is exactly 0 inside the pupil (grid centre, positive x axis for
# odd resolutions), leaving the two arrays misaligned.
r = r[pupil]
theta = theta[pupil]

N,M = generate_orders(2,9)
Z = myzernike(N,M,r,theta) # matrix of zernike polynomials, one row per in-pupil sample

In [None]:
class FarField: # helper for making far-field images
    """Simulate far-field intensity images from batches of complex pupil functions.

    Parameters
    ----------
    wavelength : float
        Wavelength (metres) used for the phase conversion factor ``radian_inator``.
        Previously this argument was ignored and 636e-9 was hard-coded.
    aperture_size : float, optional
        Fraction of the zero-padded FFT grid spanned by the pupil (default 0.2).
    """

    def __init__(self, wavelength, aperture_size=0.2):
        self.wavelength = wavelength
        # Factor applied to wavefront values to obtain phase in radians: 1e-6 * 2*pi / wavelength.
        # NOTE(review): the 1e-6 factor suggests wavefront values in micrometres — confirm.
        self.radian_inator = 1e-6 * 2 * np.pi / wavelength
        self.aperture_size = aperture_size
        # NOTE(review): the noise parameters are stored but not used by generate_farfield.
        self.noise_dev = 0.05
        self.noise_mean = 0.05

    def generate_farfield(self, pupil_func):
        """Return |FFT|^2 of each zero-padded pupil function, cropped to the input size.

        Parameters
        ----------
        pupil_func : complex tensor of shape (batch, rows, cols).

        Returns
        -------
        float32 tensor of shape (batch, rows, cols).
        """
        row_col = np.array(np.shape(pupil_func)[1:])  # [rows, cols]
        # Zero-pad so the pupil occupies `aperture_size` of the FFT grid.
        padded_size = row_col / self.aperture_size
        padwidth = (np.round(padded_size - row_col) / 2).astype(int)
        paddings = tf.constant([[0, 0], [padwidth[0], padwidth[0]], [padwidth[1], padwidth[1]]])
        padded_pupil_func = tf.pad(pupil_func, paddings)
        # Shift only the spatial axes: without `axes`, fftshift also rolls the batch axis and
        # returns the examples of a multi-example batch out of order.
        image = tf.signal.fftshift(tf.signal.fft2d(padded_pupil_func), axes=(1, 2))
        # Crop the central rows x cols window. Explicit end indices keep this correct when
        # padwidth is 0 (a `p:-p` slice would then be empty).
        image = image[:,
                      padwidth[0]:padwidth[0] + row_col[0],
                      padwidth[1]:padwidth[1] + row_col[1]]
        # Intensity = field * conj(field); casting complex -> float32 keeps the real part.
        image = image * tf.math.conj(image)
        image = tf.cast(image, tf.float32)
        return image

FF = FarField(636e-9)

In [None]:
class generate_shape(layers.Layer): # custom layer to generate wavefront shapes from zernike coefficients
    """Keras layer mapping Zernike coefficients to 2-D wavefront maps.

    The wavefront is evaluated at the in-pupil sample points (the rows of ``Z``), scaled by
    ``FF.radian_inator``, and scattered onto a zero-filled grid at ``indices``.

    Parameters
    ----------
    Z : array-like, shape (n_points, n_polynomials)
        Zernike basis sampled at the in-pupil points (as built by ``myzernike``).
    grid_shape : tuple of int, optional
        Height and width of the output grid (default ``(200, 200)``, the previously
        hard-coded size).
    """

    def __init__(self, Z, grid_shape=(200, 200), **kwargs):
        super().__init__(**kwargs)
        self.Z = Z
        self.grid_shape = tuple(int(d) for d in grid_shape)

    def call(self, zernike_coeffs, indices):
        """Build one wavefront map per example.

        Parameters
        ----------
        zernike_coeffs : float tensor (batch, n_coeffs). Only the first
            ``min(n_coeffs, n_polynomials)`` columns are used.
        indices : integer tensor (n_points, 2) of (row, col) grid positions, presumably
            listing the same points in the same order as the rows of ``Z`` (both are
            built from the pupil mask in this notebook).

        Returns
        -------
        float32 tensor (batch, grid_rows, grid_cols), zero outside the pupil.
        """
        num_polynomials = min(int(zernike_coeffs.shape[-1]), int(np.shape(self.Z)[1]))
        # Truncate a local copy. Assigning back to self.Z would permanently shrink the basis
        # for every later call of this layer.
        basis = tf.cast(np.asarray(self.Z)[:, :num_polynomials], tf.float32)
        coeffs = zernike_coeffs[:, :num_polynomials]

        # (n_points, batch): one column of in-pupil phase values per example.
        phase = tf.matmul(basis, coeffs, transpose_b=True) * FF.radian_inator

        # Scatter every example at once into a (rows, cols, batch) grid, then move batch first.
        # The indices are unique, so scatter_nd's accumulation equals a plain update.
        batch_size = tf.shape(zernike_coeffs)[0]
        out_shape = tf.stack([self.grid_shape[0], self.grid_shape[1], batch_size])
        grid = tf.scatter_nd(tf.cast(indices, tf.int32), phase, out_shape)
        return tf.transpose(grid, [2, 0, 1])

In [None]:
class generate_image(layers.Layer): # layer for combining shapes and outputting far-field
    """Form the residual wavefront (true minus predicted) and render its far field.

    The network may have found the "reverse" shape (the prediction rotated by 180 degrees
    and negated). Both residuals are therefore formed, and for each example the one with
    the lower mean-square value is kept. Its far field is computed with
    ``FF.generate_farfield``.

    Returns ``[far_field, residual]`` with the residual shaped like the input wavefronts.
    """

    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def call(self,generated_shape,real_shape):
        flipped = -tf.reverse(generated_shape, [1, 2])
        residual_direct = real_shape - generated_shape
        residual_flipped = real_shape - flipped

        mse_direct = tf.reduce_mean(tf.square(residual_direct), [-2, -1])
        mse_flipped = tf.reduce_mean(tf.square(residual_flipped), [-2, -1])

        # 1.0 where the direct residual is no worse, else 0.0; broadcast over the image axes.
        keep_direct = tf.cast(mse_direct <= mse_flipped, tf.float32)[:, None, None]
        better_shape = residual_direct * keep_direct + residual_flipped * (1 - keep_direct)

        pupil_func = pupil*tf.math.exp(tf.dtypes.complex(0.,better_shape))
        return [FF.generate_farfield(pupil_func), better_shape]

In [None]:
class measure_quality(layers.Layer): # final layer for image quality
    """Score each far-field image by its total intensity relative to its peak.

    For an input of shape (batch, rows, cols) the output has shape (batch,), equal to
    ``sum(image / max(image))`` per example (total / peak intensity), so more concentrated
    far fields score lower. QualityModel uses this output as its loss.
    """

    def __init__(self,**kwargs):
        super().__init__(**kwargs)

    def call(self,generated_image):
        peak = tf.reduce_max(generated_image, [1, 2])[:, None, None]
        return tf.reduce_sum(generated_image / peak, [1, 2])
        

In [None]:
# End-to-end pipeline: image -> predicted coefficients -> wavefronts -> residual far field -> quality.
Z = myzernike(N,M,r,theta)  # Zernike basis at the in-pupil sample points
real_image = layers.Input(name="real_image_input",shape = image_shape+(1,)) # inputs
# Width of the true-coefficient input follows the label files (n_coeffs) rather than a
# hard-coded 36, so it always matches the Y_train / Y_test arrays fed to it.
real_coeffs = layers.Input(name="real_coeffs_input",shape = (n_coeffs,))
zernike_coeffs = coefficient_network(real_image) # predict coefficients
indices = tf.where(pupil)  # (row, col) grid positions of the in-pupil samples
generated_shape = generate_shape(Z,name='generated_shape')(zernike_coeffs,indices) # predicted wavefront shape
real_shape = generate_shape(Z,name='real_shape')(real_coeffs,indices)
[generated_image,corrected] = generate_image(name='generated_image')(generated_shape,real_shape) # resultant farfield
image_quality = measure_quality(name='image_quality')(generated_image) # evaluate image quality
qualityNet = Model(inputs = [real_image,real_coeffs],outputs=[image_quality,zernike_coeffs,generated_image,generated_shape,real_shape,corrected],name="qualityNet_untrained")
qualityNet.summary()


In [None]:
class QualityModel(Model):
    """Training wrapper that minimises the image-quality output of ``qualityNet``.

    ``qualityNet`` returns a list whose first element is a per-example quality score. The
    loss is the batch mean of that score plus any regularisation losses registered by the
    wrapped model (e.g. the L2 penalties on ``fc21``).
    """

    def __init__(self, qualityNet):
        super().__init__()
        self.qualityNet = qualityNet
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs):
        return self.qualityNet(inputs)

    def train_step(self, data):
        with tf.GradientTape() as tape:
            # training=True: Dropout is active and BatchNormalization uses/updates batch
            # statistics. Without it both layers run in inference mode during fit().
            loss = self._compute_loss(data, training=True)
        gradients = tape.gradient(loss, self.qualityNet.trainable_weights)
        self.optimizer.apply_gradients(
            zip(gradients, self.qualityNet.trainable_weights)
        )

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        loss = self._compute_loss(data, training=False)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data, training=False):
        """Scalar loss: mean generated-image quality over the batch plus regularisation.

        ``data`` is the ``(images, coefficients)`` tuple yielded by the datasets and is
        passed straight to ``qualityNet`` as its two inputs.
        """
        quality = self.qualityNet(data, training=training)[0]
        # Reduce to a scalar so the gradient scale does not depend on batch size.
        loss = tf.reduce_mean(quality)
        # Kernel/bias/activity regularisers only take effect when added here explicitly.
        if self.losses:
            loss = loss + tf.add_n(self.losses)
        return loss

    @property
    def metrics(self):
        # Listing the tracker lets Keras reset it automatically at the start of each epoch.
        return [self.loss_tracker]

In [None]:
# Render the layer graph of the full pipeline with tensor shapes (needs pydot + graphviz).
plot_model(qualityNet,show_shapes=True)

In [None]:
quality_model = QualityModel(qualityNet)
# The custom train_step/test_step compute and track the loss themselves, so only the
# optimizer needs configuring. The former `weighted_metrics=[quality_model.metrics]` (a
# nested list holding the model's own loss tracker) was never used by those steps.
quality_model.compile(optimizer=optimizers.Adam(0.0001))

In [None]:
# Load previously saved weights into qualityNet (path relative to the working directory).
# NOTE(review): the file must come from the same architecture/layer order; this replaces
# any weights trained earlier in the session.
qualityNet.load_weights('q_model_nonrooted.h5')

In [None]:
# Train for 4 epochs, evaluating the same loss on the test split after each epoch.
history = quality_model.fit(train_dataset,epochs = 4,validation_data=test_dataset)

In [None]:
# Per-epoch training and validation losses recorded by fit().
loss, val_loss = (history.history[key] for key in ('loss', 'val_loss'))

In [None]:
def strehl_ratio(farfield, I_0=0.03134885):
    """Estimate the Strehl ratio of a far-field intensity image.

    The image is normalised to unit total intensity and its peak is divided by ``I_0``.

    Parameters
    ----------
    farfield : array-like
        Intensity image (NumPy array or eager tensor, any numeric dtype).
    I_0 : float, optional
        Peak of the identically normalised reference far field. The default is the value
        previously hard-coded here.
        NOTE(review): presumably the unaberrated peak for the current grid/padding — it
        must be recomputed if ``resolution`` or ``FarField.aperture_size`` change.

    Returns
    -------
    float
        ``max(farfield / sum(farfield)) / I_0``.
    """
    intensity = np.asarray(farfield, dtype=np.float64)
    peak_fraction = np.max(intensity / np.sum(intensity))
    return peak_fraction / I_0

In [None]:
# Pick test example 321 and prepend a batch axis of size 1 so it can be fed to qualityNet.
in1, in2 = (tf.expand_dims(sample, 0) for sample in (X_test[321], Y_test[321]))

In [None]:
# Quick look at the selected input far-field image (batch axis removed).
plt.imshow(in1[0])

In [None]:
# Run the full pipeline on the single selected example. `outputs` follows qualityNet's output
# list: [image_quality, zernike_coeffs, generated_image, generated_shape, real_shape, corrected].
# NOTE(review): check whether tf.signal.fftshift in FarField.generate_farfield shifts the
# batch axis (it shifts all axes when `axes` is not given) — that would explain the
# reordering seen with multiple examples.
outputs = qualityNet.predict([in1,in2]) # if predicting multiple examples at once the output order is very strange

In [None]:
# Unpack the per-example results (positions follow qualityNet's outputs list).
correctedff, gen_shape, real_shape, corrected_shape = outputs[2:6]

# Wavefront error relative to a flat (all-zero) 200x200 wavefront, before and after correction.
mse_real, mse_corr = (
    mean_squared_error(wavefront[0], np.zeros((200, 200)))
    for wavefront in (real_shape, corrected_shape)
)
sr_real = strehl_ratio(in1[0])
sr_corr = strehl_ratio(correctedff[0])

# 2x2 summary: far fields (peak-normalised) on top, wavefronts underneath.
fig = plt.figure(figsize=(10,7))
panels = [
    (in1[0] / tf.reduce_max(in1[0]), 'Input far-field, Strehl ratio: %.4f' % sr_real),
    (correctedff[0] / np.max(correctedff[0]), 'Corrected far-field, Strehl ratio: %.4f' % sr_corr),
    (real_shape[0], 'Input shape, MSE: %0.4f' % mse_real),
    (corrected_shape[0], 'Corrected shape, MSE: %0.4f' % mse_corr),
]
for position, (panel_image, panel_title) in enumerate(panels, start=1):
    fig.add_subplot(2, 2, position)
    fig.colorbar(plt.imshow(panel_image))
    plt.axis('off')
    plt.title(panel_title)
plt.show()