**Table of contents**<a id='toc0_'></a>    
- [ML notebook](#toc1_)    
  - [Assemble dataset](#toc1_1_)    
  - [Prepare the dataset](#toc1_2_)    
- [Unet](#toc2_)    
  - [Save the weights](#toc2_1_)    
- [Training results](#toc3_)    
- [Inference](#toc4_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[ML notebook](#toc0_)
17.02.2025 -> 04.06.2025 - Dominique Humbert
Initial version.

Useful links used to build this implementation:
- https://medium.com/coinmonks/learn-how-to-train-u-net-on-your-dataset-8e3f89fbd623

Model U-net from:
Vanberg, P.-O. (2019). Machine learning for image-based wavefront sensing. Université de Liège, Liège, Belgique.
https://explore.lib.uliege.be/permalink/32ULG_INST/oao96e/alma9919993582302321

In [None]:
%matplotlib widget
import matplotlib.pyplot as plt
from ows import ows
from ows import fouriertransform as mathft
from astropy.io import fits
from astropy.visualization import SqrtStretch
import numpy as np


# Inputs

In [None]:
# Notebook-wide settings.

# Matplotlib font sizes used by the plotting cells further down.
SMALL_SIZE, MEDIUM_SIZE = 20, 20
BIGGER_SIZE = 25

# Selects the GPU branch of the training cell.
GPU = True

# Output directory for the weights, plots and FITS products of this run.
foldername = "unet_outputs/Dw_N16_3"

## <a id='toc1_1_'></a>[Assemble dataset](#toc0_)

In [None]:
# Read the simulated FITS frames 0 .. n-1 from data/ and cache them as .npz files.
#   train_Dw[..., 0] / train_Dw[..., 1]: DwDx / DwDy maps passed through
#       ows.pixel_adder(scale_factor=[4, 4]) (assumed to yield dim_data x dim_data — TODO confirm)
#   train_lightfield, train_psf: read as-is.
# Frame index n itself (99999) is not read here; the inference cell uses it as a held-out sample.
n = 99999
dim_data = 64

# NOTE: the former `train_ps` buffer (n x dim_data x dim_data float64, ~3.3 GB) was
# allocated but never filled, saved or read, so it is no longer created.
train_Dw = np.zeros((n, dim_data, dim_data, 2))
train_lightfield = np.zeros((n, dim_data, dim_data))
train_psf = np.zeros((n, dim_data, dim_data))
print(train_Dw.shape)

filename0 = "data/DwDx_"
filename1 = "data/DwDy_"
filename2 = "data/lightfield_"
filename3 = "data/psf_"
filename4 = "data/ps_"  # ps frames are currently not loaded


def _read_frame(prefix, index):
    """Return the image stored in ``{prefix}{index}.fits`` as a NumPy array."""
    return np.array(fits.getdata(f"{prefix}{index}.fits"))


status = 0
for i in range(n):
    train_Dw[i, :, :, 0] = ows.pixel_adder(_read_frame(filename0, i), scale_factor=[4, 4], final_shape=None)
    train_Dw[i, :, :, 1] = ows.pixel_adder(_read_frame(filename1, i), scale_factor=[4, 4], final_shape=None)
    train_lightfield[i, :, :] = _read_frame(filename2, i)
    train_psf[i, :, :] = _read_frame(filename3, i)
    # Report progress each time a new 10 % step is reached.
    if 100 * i // n == status * 10:
        print('Loading training set: ', status * 10, '% done')
        status += 1
print('Loading training set: ', 100, '% done')

# Cache the assembled arrays; the dataset-preparation cell reads them back.
np.savez("data/train_Dw.npz", train_Dw)
np.savez("data/train_lightfield.npz", train_lightfield)
np.savez("data/train_psf.npz", train_psf)

# Quick visual check of the first sample: PSF, light field, DwDx channel, DwDy channel.
plt.close()
plt.figure(1)
first_sample_images = (
    train_psf[0, :, :],
    train_lightfield[0, :, :],
    train_Dw[0, :, :, 0],
    train_Dw[0, :, :, 1],
)
for position, image in enumerate(first_sample_images, start=1):
    plt.subplot(2, 2, position)
    plt.imshow(image)

## <a id='toc1_2_'></a>[Prepare the dataset](#toc0_)

1. Randomly assign each sample to the training set (80 %) or the validation set (20 %).
2. Min-max normalize the inputs and split the input and target datasets accordingly.

In [None]:
# `n` is used as the RNG seed by the dataset-preparation cell (which then
# overwrites it with the dataset size); `rng` here is an unseeded generator.
n = 1999
rng = np.random.default_rng()

In [None]:
# Training / validation split and input normalisation.
# `n` from the previous cell seeds the RNG, so the split is reproducible.
rng = np.random.default_rng(n)
temp = np.load("data/train_Dw.npz")["arr_0"]
n = temp.shape[0]
k = 0.8  # fraction of the samples assigned to the training set

# Boolean mask over samples: True -> training, False -> validation.
# Built-in `bool` is used because the `np.bool` alias was removed in NumPy 1.24
# (it only returned in NumPy 2.0); `bool` works on every NumPy version.
train_id = np.zeros(n, dtype=bool)
train_id[rng.choice(n, int(k * n), replace=False)] = True

# Min-max normalisation with a single global min/max over the whole input array
# (both channels, training and validation samples). norm_min / norm_max are reused
# by the inference cell to normalise new inputs the same way.
# NOTE(review): computing these on the training subset only would avoid using
# validation statistics during training.
norm_max = np.max(temp)
norm_min = np.min(temp)
temp = (temp - norm_min) / (norm_max - norm_min)
print(temp.shape)

x_train = temp[train_id, :, :]
x_val = temp[~train_id, :, :]

# Targets: PSF images, split with the same mask so inputs and targets stay paired.
temp = np.load("data/train_psf.npz")["arr_0"]
y_train = temp[train_id, :, :]
y_val = temp[~train_id, :, :]

# Visual check of one normalised training input (channel 0).
plt.close(0)
plt.figure(0)
plt.imshow(x_train[4, :, :, 0])
plt.colorbar()
plt.show()

print(x_val.shape, x_train.shape)
print(y_val.shape, y_train.shape)

# <a id='toc2_'></a>[Unet](#toc0_)

In [None]:
import ows.unet as unet

# Each input sample stacks the two wavefront-derivative maps (DwDx, DwDy) as
# channels: x_train is (N, 64, 64, 2) and the inference input is (1, 64, 64, 2),
# so the network needs 2 input channels (it was previously built with 1, which
# does not match the data fed to fit()).
# The target is a single-channel 64x64 PSF image, hence n_channels_out=1.
model = unet.build_unet(input_shape=(64, 64, 2), n_channels_out=1)
# model.summary()

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, LearningRateScheduler, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import SGD
from IPython.display import clear_output
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_curve, auc

# Keep only the weights of the epoch with the lowest validation loss.
# NOTE(review): this path is relative to the working directory, not `foldername`.
weight_path = "{}_weights.best.weights.h5".format('cxr_reg')
checkpoint = ModelCheckpoint(weight_path, monitor='val_loss', verbose=1,
                             save_best_only=True, mode='min', save_weights_only=True)

# Multiply the learning rate by 0.05 when val_loss has not improved by more than
# 0.05 for 3 epochs, wait 2 epochs (cooldown) before monitoring again, and never go below 1e-6.
# `min_delta` replaces the deprecated `epsilon` keyword used previously (same meaning).
reduceLROnPlat = ReduceLROnPlateau(monitor='val_loss', factor=0.05, patience=3,
                                   verbose=1, mode='min', min_delta=0.05, cooldown=2, min_lr=1e-6)

# Stop training after 15 epochs without val_loss improvement.
early = EarlyStopping(monitor="val_loss", mode="min", patience=15)
callbacks_list = [checkpoint, early, reduceLROnPlat]

# Names used by the training cell.
train_vol = x_train
validation_vol = x_val
train_seg = y_train
validation_seg = y_val

print(train_seg.shape)

In [None]:
import contextlib

import tensorflow as tf

# Pick the device scope and batch size, then compile and train once inside that scope.
# Previously `model.compile` was not indented under `with tf.device(...)` (an
# IndentationError), and the non-GPU branch called fit() on an uncompiled model.
if GPU:
    print("GPUs visible to TensorFlow:", tf.config.list_physical_devices('GPU'))
    device_scope = tf.device('/GPU:0')
    batch_size = 32
else:
    device_scope = contextlib.nullcontext()
    batch_size = 12

with device_scope:
    model.compile(
        optimizer=SGD(learning_rate=0.001, momentum=0.9),
        loss=[unet.mse_loss],
        metrics=[unet.dice_coef, 'binary_accuracy', "AUC"],
    )
    loss_history = model.fit(
        x=train_vol,
        y=train_seg,
        batch_size=batch_size,
        epochs=100,
        validation_data=(validation_vol, validation_seg),
        callbacks=callbacks_list,
    )


## <a id='toc2_1_'></a>[Save the weights](#toc0_)

In [None]:
import os

# Save the final weights of this run; create the run directory first so the
# write does not fail on a fresh checkout.
os.makedirs(foldername, exist_ok=True)
model.save_weights(f"{foldername}/model.weights.h5", overwrite=True)

# <a id='toc3_'></a>[Training results](#toc0_)

In [None]:
import os

# Font sizes are set *before* the figure is created: titles, axis labels and legends
# take their size from rcParams when they are built, so the rc calls that used to
# follow the plotting code did not apply to this figure.
plt.rc('font', size=SMALL_SIZE)          # default text size
plt.rc('axes', titlesize=SMALL_SIZE)     # axes titles
plt.rc('axes', labelsize=MEDIUM_SIZE)    # x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)  # figure title

# Left: training vs. validation loss. Right: binary accuracy in percent.
history = loss_history.history
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

ax1.plot(history['loss'], '-', label='Loss')
ax1.plot(history['val_loss'], '-', label='Validation Loss')
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Loss')
ax1.set_title('Loss History')
ax1.legend()

ax2.plot(100 * np.array(history['binary_accuracy']), '-', label='Accuracy')
ax2.plot(100 * np.array(history['val_binary_accuracy']), '-', label='Validation Accuracy')
ax2.set_xlabel('Epochs')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Accuracy History')
ax2.legend()

os.makedirs(foldername, exist_ok=True)
fig.savefig(f"{foldername}/loss.png")

# <a id='toc4_'></a>[Inference](#toc0_)

In [None]:
# Restore the weights saved at the end of training for this run.
model.load_weights(f"{foldername}/model.weights.h5")

In [None]:
import os

# Inference on frame 99999, which the dataset-assembly cell does not load
# (it reads frames 0 .. 99998), compared against the true PSF of that frame.
DwDx = ows.pixel_adder(np.array(fits.getdata("data/DwDx_99999.fits")), scale_factor=[4, 4], final_shape=None)
DwDy = ows.pixel_adder(np.array(fits.getdata("data/DwDy_99999.fits")), scale_factor=[4, 4], final_shape=None)
psf = np.array(fits.getdata("data/psf_99999.fits"))
lightfield = np.array(fits.getdata("data/lightfield_99999.fits"))
print(DwDx.shape)

# Build the input exactly like the training inputs: DwDx / DwDy stacked as channels
# 0 / 1, then the same min-max normalisation (norm_min / norm_max come from the
# dataset-preparation cell).
Dw = np.stack([DwDx, DwDy], axis=-1).astype(float)
Dw = (Dw - norm_min) / (norm_max - norm_min)

# Add the batch axis, run the model, and keep output channel 0 of the single sample.
im = np.expand_dims(Dw, axis=0)
predictions = model(im)[0, :, :, 0]

# Font sizes must be set before the figures are created; titles and colorbar labels
# read rcParams when they are built, so setting them afterwards had no effect.
plt.rc('font', size=SMALL_SIZE)          # default text size
plt.rc('axes', titlesize=SMALL_SIZE)     # axes titles
plt.rc('axes', labelsize=MEDIUM_SIZE)    # x and y labels
plt.rc('xtick', labelsize=SMALL_SIZE)    # x tick labels
plt.rc('ytick', labelsize=SMALL_SIZE)    # y tick labels
plt.rc('legend', fontsize=SMALL_SIZE)    # legend text
plt.rc('figure', titlesize=BIGGER_SIZE)  # figure title

os.makedirs(foldername, exist_ok=True)


def _show_and_save(image, fig_num, title, stem):
    """Show `image` in figure `fig_num` and save it as <foldername>/<stem>.png and .fits."""
    image = np.asarray(image)
    plt.close(fig_num)
    plt.figure(fig_num, figsize=(10, 10))
    plt.imshow(image)
    plt.title(title)
    plt.colorbar()
    plt.savefig(f"{foldername}/{stem}.png")
    fits.writeto(f"{foldername}/{stem}.fits", image, overwrite=True)


_show_and_save(predictions, 2, "Inference", "inference")
_show_and_save(psf[:, :], 3, "True image", "true_image")
plt.show()
