In [None]:
import numpy as np
import tensorflow as tf
import math
import cmath
import scipy
from tensorflow.keras import layers
from tensorflow.keras.layers import Input, Add, Dense, Activation, ZeroPadding2D, BatchNormalization, Dropout
from tensorflow.keras.layers import Flatten, Conv2D, AveragePooling2D, MaxPooling2D, GlobalMaxPooling2D, concatenate
from tensorflow.keras.models import Model, load_model
from tensorflow.keras import metrics
from tensorflow.keras import optimizers
from tqdm import tqdm
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import imageio.v3 as iio
import pandas as pd

In [None]:
# Side length, in pixels, of the square images; image_shape is the
# (height, width) pair used for the model inputs below.
resolution = 200
image_shape = (resolution,) * 2

In [None]:
# Dataset root; expected layout: <data_loc>/{Train,Test}/{Images,Labels}.
data_loc = "enter dataset location"
train_dir = os.path.join(data_loc, "Train")

# os.listdir returns entries in arbitrary order, so sort both listings to make
# images[i] and labels[i] refer to the same sample.
# NOTE(review): assumes image and label files have names that sort in the same
# order (e.g. matching stems) — confirm against the dataset.
images = sorted(os.listdir(os.path.join(train_dir, "Images")))
labels = sorted(os.listdir(os.path.join(train_dir, "Labels")))
assert len(images) == len(labels), "Train image/label counts differ"

X_train_orig = list()
Y_train_orig = list()
n_train_images = len(images)

for image_name, label_name in tqdm(zip(images, labels), total=n_train_images):
    X_train_orig.append(iio.imread(os.path.join(train_dir, "Images", image_name)))
    Y_train_orig.append(pd.read_csv(os.path.join(train_dir, "Labels", label_name), header=None))

# Each image must be resolution x resolution single-channel; each label file
# must hold exactly 12 numbers.
X_train = np.reshape(X_train_orig, [n_train_images, resolution, resolution, 1]).astype(np.float32)
Y_train = np.reshape(Y_train_orig, [n_train_images, 12, 1]).astype(np.float32)

# One-off random geometric augmentation: the whole set is transformed once, as
# a single batch of n_train_images, when this cell runs (not re-sampled per
# epoch). The same generator is reused for the test set.
image_augmenter = ImageDataGenerator(rotation_range=4, width_shift_range=5., height_shift_range=5., zoom_range=[0.93, 1.07])
X_train = image_augmenter.flow(X_train, shuffle=False, batch_size=n_train_images)[0]

In [None]:
test_dir = os.path.join(data_loc, "Test")

# os.listdir returns entries in arbitrary order, so sort both listings to make
# images[i] and labels[i] refer to the same sample.
# NOTE(review): assumes image and label files have names that sort in the same
# order (e.g. matching stems) — confirm against the dataset.
images = sorted(os.listdir(os.path.join(test_dir, "Images")))
labels = sorted(os.listdir(os.path.join(test_dir, "Labels")))
assert len(images) == len(labels), "Test image/label counts differ"

X_test_orig = list()
Y_test_orig = list()
n_test_images = len(images)

for image_name, label_name in tqdm(zip(images, labels), total=n_test_images):
    X_test_orig.append(iio.imread(os.path.join(test_dir, "Images", image_name)))
    Y_test_orig.append(pd.read_csv(os.path.join(test_dir, "Labels", label_name), header=None))

# Each image must be resolution x resolution single-channel; each label file
# must hold exactly 12 numbers.
X_test = np.reshape(X_test_orig, [n_test_images, resolution, resolution, 1]).astype(np.float32)
Y_test = np.reshape(Y_test_orig, [n_test_images, 12, 1]).astype(np.float32)

# NOTE(review): the random augmentation from the training cell is also applied
# to the test images, so the evaluation data differ on every run — confirm
# this is intended.
X_test = image_augmenter.flow(X_test, shuffle=False, batch_size=n_test_images)[0]

In [None]:
# Fixed seed so the per-epoch shuffle order is reproducible.
SHUFFLE_SEED = 0

# The in-memory training set is reshuffled every epoch before batching;
# without this, every epoch would see samples in file order.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    .shuffle(buffer_size=len(X_train), seed=SHUFFLE_SEED, reshuffle_each_iteration=True)
    .batch(64, drop_remainder=False)
    .prefetch(8)
)

In [None]:
# Sanity check: display the batched training dataset (in TF 2.x its repr shows
# the element shapes and dtypes).
train_dataset

In [None]:
# Evaluation pipeline: batches of 64 (last partial batch kept), prefetched,
# no shuffling.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((X_test, Y_test))
    .batch(64, drop_remainder=False)
    .prefetch(8)
)

In [None]:
def conv_block(inputs, n_filters, k_size, stride, pad="same", max_pool=False):
    """Conv2D (ReLU) -> BatchNormalization, optionally followed by 4x4 max-pooling.

    Args:
        inputs: input tensor of the block.
        n_filters: number of convolution filters.
        k_size: convolution kernel size.
        stride: convolution stride.
        pad: padding mode of the convolution (default "same").
        max_pool: when True, append MaxPooling2D with a 4x4 window and
            "same" padding.

    Returns:
        The block's output tensor.
    """
    features = Conv2D(n_filters, k_size, stride, activation="relu", padding=pad)(inputs)
    features = BatchNormalization(axis=-1)(features)
    if not max_pool:
        return features
    return MaxPooling2D(pool_size=(4, 4), padding="same")(features)
    

# Encoder ("Embedding") network: 18 conv blocks with 5x5 kernels. Every third
# block has stride 2, and the filter count doubles every three blocks
# (16 -> 512). With "same" padding the 200x200 input shrinks to 4x4x512. A
# fully connected head with skip-concatenations then produces a 128-d
# embedding.
# NOTE: pretrained weights are loaded from "encoder.h5" below. An .h5
# load_weights call matches layers by network topology, so do not reorder,
# insert or remove layers here.
input_img = tf.keras.Input(shape = image_shape + (1,))
conv1 = conv_block(input_img,16,5,1)
conv2 = conv_block(conv1,16,5,1)
conv3 = conv_block(conv2,16,5,2)
conv4 = conv_block(conv3,32,5,1)
conv5 = conv_block(conv4,32,5,1)
conv6 = conv_block(conv5,32,5,2)
conv7 = conv_block(conv6,64,5,1)
conv8 = conv_block(conv7,64,5,1)
conv9 = conv_block(conv8,64,5,2)
conv10 = conv_block(conv9,128,5,1)
conv11 = conv_block(conv10,128,5,1)
conv12 = conv_block(conv11,128,5,2)
conv13 = conv_block(conv12,256,5,1)
conv14 = conv_block(conv13,256,5,1)
conv15 = conv_block(conv14,256,5,2)
conv16 = conv_block(conv15,512,5,1)
conv17 = conv_block(conv16,512,5,1)
conv18 = conv_block(conv17,512,5,2)


flat5 = Flatten()(conv18)

# Dense head: each later layer also sees the concatenated earlier features
# (flat5 -> merge6 -> merge7). fc8 is linear and is the embedding output.
fc6 = Dense(1000,activation = "relu")(flat5)
drop6 = Dropout(0.5)(fc6)
merge6 = concatenate([drop6,flat5])
fc7 = Dense(256,activation = "relu")(merge6)
drop7 = Dropout(0.5)(fc7)
merge7 = concatenate([drop7,merge6])
fc8 = Dense(128,activation = None)(merge7)

embedding = tf.keras.Model(inputs = input_img, outputs = fc8, name="Embedding")
embedding.load_weights("encoder.h5")
# Freeze the pretrained encoder: it is used as a fixed feature extractor.
embedding.trainable = False
embedding.summary()

In [None]:
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar ones.

    Args:
        x, y: scalars or equally-shaped arrays of Cartesian coordinates.

    Returns:
        Tuple ``(rho, phi)``: radius sqrt(x**2 + y**2) and angle arctan2(y, x)
        in radians.
    """
    phi = np.arctan2(y, x)
    rho = np.sqrt(x * x + y * y)
    return rho, phi

def generate_orders(start,final):
    """List the (n, m) indices of every Zernike term with start <= n <= final.

    For each radial order n, the azimuthal index m runs over -n, -n+2, ..., n
    (n + 1 values).

    Returns:
        Tuple ``(N, M)`` of equal-length 1-D arrays: N holds the radial orders
        as float64, and M holds the matching integer azimuthal indices.
    """
    orders = range(start, final + 1)
    radial_parts = [np.full(order + 1, order, dtype=np.float64) for order in orders]
    azimuthal_parts = [np.arange(-order, order + 1, 2) for order in orders]
    return np.concatenate(radial_parts), np.concatenate(azimuthal_parts)
        

def myzernike(n,m,r,theta):
    """Evaluate unnormalised Zernike polynomials at polar sample points.

    Args:
        n: radial orders, one per term.
        m: azimuthal indices, one per term (indexed in step with ``n``).
        r: 1-D array of sample radii.
        theta: 1-D array of sample angles (radians), same length as ``r``.

    Returns:
        Array of shape (len(r), len(n)). Column j is the radial polynomial
        R_n^|m|(r) multiplied by cos(m * theta) when m >= 0, or by
        sin(|m| * theta) when m < 0.
    """
    n_terms = len(n)
    basis = np.zeros((len(r), n_terms))
    for col in range(n_terms):
        order = n[col]
        freq = m[col]

        # Radial part: explicit finite sum over s = 0 .. (n - |m|) / 2.
        radial = basis[:, col]
        for s in range(int((order - abs(freq)) / 2) + 1):
            coeff = (-1)**s * math.factorial(int(order - s)) / (
                math.factorial(int(s))
                * math.factorial(int((order + freq) / 2 - s))
                * math.factorial(int((order - freq) / 2 - s))
            )
            radial = radial + coeff * r**(order - 2*s)

        # Angular part; -sin(m*theta) == sin(|m|*theta) for negative m.
        if freq < 0:
            basis[:, col] = -radial * np.sin(theta * freq)
        else:
            basis[:, col] = radial * np.cos(theta * freq)

    return basis

# Sample the Zernike basis on the pixels of the unit-circle pupil.
# `resolution` comes from the configuration cell near the top of the notebook.
X = np.linspace(-1,1,resolution)
[x,y] = np.meshgrid(X,X)
[r,theta] = cart2pol(x,y)
pupil = (r<=1)

# Select pupil pixels with the mask itself. Boolean indexing walks the grid in
# row-major order, the same order tf.where(pupil) produces later. Filtering r
# and theta by their own non-zero entries would drop the centre pixel (r == 0)
# and the positive-x ray (theta == 0) for odd resolutions, leaving r, theta and
# the pupil indices with different lengths.
# NOTE: MATLAB orders pixels column-major; transpose r and theta before
# masking if Z must match a MATLAB-generated basis row for row.
r = r[pupil]
theta = theta[pupil]
assert r.shape == theta.shape == (np.count_nonzero(pupil),)

N,M = generate_orders(2,9)
Z = myzernike(N,M,r,theta)
Z.shape

In [None]:
class FarField:
    """Propagates a pupil function to its far-field amplitude with a padded FFT.

    Attributes:
        wavelength: wavelength given to the constructor.
        radian_inator: 1e-6 * 2*pi / wavelength, the factor that turns a
            wavefront value into a phase in radians (presumably wavefront in
            microns and wavelength in metres — confirm units).
        aperture_size: fraction of the padded FFT field occupied by the pupil
            grid; each spatial axis is zero-padded to about size / aperture_size.
        noise_dev, noise_mean: stored noise parameters (not used by this class).
    """

    def __init__(self,wavelength):
        self.wavelength = wavelength
        self.radian_inator = 1e-6 * 2 * np.pi / wavelength
        self.aperture_size = 0.2
        self.noise_dev = 0.05
        self.noise_mean = 0.05

    def generate_farfield(self,pupil_func):
        """Return the far-field amplitude of a batch of pupil functions.

        Args:
            pupil_func: complex tensor of shape (batch, rows, cols) with static
                rows/cols.

        Returns:
            float32 tensor of shape (batch, rows, cols): |FFT| of the
            zero-padded pupil, centred with fftshift and cropped back to the
            central rows x cols window.
        """
        row_col = np.array(np.shape(pupil_func)[1:])  # (rows, cols)
        sz2 = row_col / self.aperture_size
        padwidth = (np.round(sz2 - row_col) / 2).astype(int)
        rows, cols = int(row_col[0]), int(row_col[1])
        pad_r, pad_c = int(padwidth[0]), int(padwidth[1])

        paddings = tf.constant([[0, 0], [pad_r, pad_r], [pad_c, pad_c]])
        padded_pupil_func = tf.pad(pupil_func, paddings)
        # Shift only the two spatial axes. The default axes=None would also
        # roll the batch axis and move each image to another sample's slot.
        image = tf.signal.fftshift(tf.signal.fft2d(padded_pupil_func), axes=(1, 2))
        # Explicit end indices keep the crop correct when a pad width is 0
        # (a `p:-p` slice would then be empty).
        image = image[:, pad_r:pad_r + rows, pad_c:pad_c + cols]
        # Amplitude |E|. This equals sqrt(E * conj(E)), but tf.abs has a finite
        # gradient where E == 0, whereas sqrt's derivative blows up there.
        return tf.cast(tf.math.abs(image), tf.float32)

FF = FarField(636e-9)

In [None]:
class generate_image(layers.Layer):
    """Keras layer rendering far-field images from Zernike coefficients.

    The coefficients are projected onto the sampled Zernike basis ``Z`` to give
    a wavefront on the pupil pixels. That wavefront is scattered onto a square
    grid, turned into a phase-only pupil function and propagated with the
    module-level ``FF`` (a ``FarField`` instance).

    Args:
        Z: 2-D array with one row per pupil pixel and one column per
            polynomial (as produced by ``myzernike``).
        grid_size: side length of the square wavefront grid. Defaults to 200,
            the size that was previously hard-coded.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, Z, grid_size=200, **kwargs):
        super().__init__(**kwargs)
        self.Z = Z
        self.grid_size = grid_size

    def call(self, zernike_coeffs, indices):
        """Render one far-field image per coefficient vector.

        Args:
            zernike_coeffs: (batch, n_coeffs) tensor. Only the first
                min(n_coeffs, Z.shape[1]) columns are used.
            indices: (n_pupil_pixels, 2) integer tensor of (row, col) grid
                positions, in the same order as the rows of ``Z``
                (e.g. ``tf.where(pupil)``).

        Returns:
            float32 tensor of shape (batch, grid_size, grid_size).
        """
        num_polynomials = np.minimum(np.shape(zernike_coeffs)[1], np.shape(self.Z)[1])

        # Slice into locals instead of reassigning self.Z, so calling the
        # layer never truncates the stored basis as a side effect.
        basis = tf.cast(self.Z[:, :num_polynomials], tf.float32)
        coeffs = zernike_coeffs[:, :num_polynomials]

        # Wavefront on the pupil pixels, in radians: (batch, n_pupil_pixels).
        pupil_phase = tf.matmul(coeffs, basis, transpose_b=True) * FF.radian_inator

        # Scatter the whole batch onto the grid in one op. Updates are laid out
        # as (n_pupil_pixels, batch) so each pupil index writes one value per
        # sample, and pixels outside the pupil stay 0. The batch axis is never
        # squeezed, so a batch of one works too.
        indices = tf.cast(indices, tf.int64)
        batch_size = tf.shape(coeffs, out_type=tf.int64)[0]
        side = tf.constant(self.grid_size, dtype=tf.int64)
        grid_shape = tf.stack([side, side, batch_size])
        wavefront = tf.scatter_nd(indices, tf.transpose(pupil_phase), grid_shape)
        wavefront = tf.transpose(wavefront, perm=[2, 0, 1])  # (batch, H, W)

        # Phase-only pupil function exp(i * phase).
        # NOTE(review): outside the circular pupil the phase is 0, so the
        # amplitude there is 1 and the effective aperture is the whole square
        # grid — confirm this matches how the real images were produced.
        pupil_func = tf.math.exp(tf.complex(tf.zeros_like(wavefront), wavefront))
        # Output has no channel axis; callers add one if a model needs it.
        return FF.generate_farfield(pupil_func)

In [None]:
class cosine_similarity(layers.Layer):
    """Per-sample cosine similarity between two batches of embeddings.

    For row i the output is dot(a_i, b_i) / (||a_i|| * ||b_i||), reduced over
    the last axis, so there is one value per sample, in [-1, 1].
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, real_embedding, generated_embedding):
        # Normalise each embedding vector separately. tf.norm without an axis
        # would divide by the norm of the whole batch tensor instead, so the
        # result would not be a cosine similarity. l2_normalize also guards
        # against all-zero vectors with a small epsilon.
        real_unit = tf.math.l2_normalize(real_embedding, axis=-1)
        generated_unit = tf.math.l2_normalize(generated_embedding, axis=-1)
        return tf.math.reduce_sum(real_unit * generated_unit, axis=-1)
    

In [None]:
# Z (the sampled Zernike basis) was computed in the basis cell above and is
# reused here rather than recomputed.
real_image = layers.Input(name="real_image_input",shape = image_shape+(1,))
# NOTE(review): `coefficient_network` is not defined anywhere in this notebook.
# It must be built before this cell, presumably as a model mapping real_image
# to a (batch, n_coeffs) tensor of Zernike coefficients.
zernike_coeffs = coefficient_network(real_image)

# (row, col) positions of the pupil pixels in row-major order, which is the
# same order as the rows of Z.
indexes = tf.where(pupil)
assert indexes.shape[0] == Z.shape[0], "pupil pixel count must equal the number of rows of Z"

generated_image = generate_image(Z)(zernike_coeffs,indexes)
# generate_image returns (batch, H, W); add the channel axis the embedding expects.
generated_image = layers.Reshape(image_shape + (1,))(generated_image)

# The embedding is a frozen feature extractor: run it in inference mode so its
# Dropout layers stay inactive while the coefficient network trains.
similarity = cosine_similarity()(
    embedding(real_image, training=False),
    embedding(generated_image, training=False),
)

full_network = Model(inputs = real_image,outputs = similarity,name = "full_network")
full_network.summary()

In [None]:
# Diagram of the assembled network with tensor shapes (plot_model needs pydot
# and graphviz installed).
tf.keras.utils.plot_model(model=full_network, show_shapes=True)

In [None]:
class FullModel(Model):
    """Custom training wrapper around ``full_network``.

    ``full_network`` outputs one cosine similarity per sample. The loss is the
    batch mean of ``1 - similarity``, so dataset labels are not used.
    """

    def __init__(self, full_network):
        super().__init__()
        self.full_network = full_network
        self.loss_tracker = metrics.Mean(name="loss")

    def call(self, inputs, training=None):
        return self.full_network(inputs, training=training)

    def train_step(self, data):
        # GradientTape records the forward pass so the gradients of the loss
        # with respect to the trainable weights can be taken afterwards.
        with tf.GradientTape() as tape:
            # training=True so layers such as BatchNorm/Dropout in the
            # trainable part of the network run in training mode; without it
            # they would run in inference mode during fit().
            loss = self._compute_loss(data, training=True)

        gradients = tape.gradient(loss, self.full_network.trainable_weights)
        self.optimizer.apply_gradients(
            zip(gradients, self.full_network.trainable_weights)
        )

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def test_step(self, data):
        loss = self._compute_loss(data, training=False)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    def _compute_loss(self, data, training=False):
        """Scalar loss for one batch.

        Args:
            data: batch from the dataset; only ``data[0]`` (the images) is used.
            training: forwarded to ``full_network``.

        Returns:
            Scalar tensor, mean over the batch of ``1 - similarity``.
        """
        similarity = self.full_network(data[0], training=training)
        # Reduce to a scalar so gradient magnitude does not scale with batch size.
        return tf.reduce_mean(1.0 - similarity)

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it automatically at the
        # start of each epoch/evaluation.
        return [self.loss_tracker]

In [None]:
full_model = FullModel(full_network)
# FullModel reports its own loss through loss_tracker in train_step/test_step,
# so compile only needs the optimizer.
full_model.compile(optimizer=optimizers.Adam(0.0001))
history = full_model.fit(train_dataset, epochs=5, validation_data=test_dataset)