This is the first experiment where the model is given all the 98 focal slices to be trained for. The model is also saved for quick use.

In [1]:
autofocus_path='/mnt/Data/Autofocus/'

autofocus_train_path=autofocus_path+"Train/"
autofocus_test_path=autofocus_path+"Test/"
autofocus_cache_path=autofocus_path+"Cache/"
autofocus_exp_path=autofocus_path+"Exp1/"

In [2]:
import os
import numpy as np
import time
from tqdm import tqdm
import tensorflow as tf
import shutil
from tensorflow import keras
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

print(tf.version.VERSION)

2.16.2


Load the dataset from the Cache folder with only read permision

In [3]:
def load_variable_from_file(filename):
    try:
        with open(filename, 'r') as file:
            variable = file.read()
        print(f"Variable loaded successfully from {filename}")
        return variable
    except Exception as e:
        print(f"Error occurred while loading variable from {filename}: {e}")
        return None
    
_dataset_len=int(load_variable_from_file(autofocus_cache_path+"length.dat"))

shutil.copy(autofocus_cache_path+"dataset.dat",autofocus_exp_path+"dataset.dat")
shutil.copy(autofocus_cache_path+"labels.dat",autofocus_exp_path+"labels.dat")
shutil.copy(autofocus_cache_path+"patch.dat",autofocus_exp_path+"patch.dat")

dataset=np.memmap(autofocus_exp_path+"dataset.dat", dtype='int8', mode='r', shape=(_dataset_len,128,128,98))
labels=np.memmap(autofocus_exp_path+"labels.dat", dtype='int8', mode='r', shape=(_dataset_len,))


Variable loaded successfully from /mnt/Data/Autofocus/Cache/length.dat


Incase the dataset is too large, this function effectively uses every kind of image scene with a reduced patch size making the dataset smaller. The highest threshold value is 15, anything below that will decrease the dataset by that percentage. 

In [5]:
def patch_threshold (patch_count,threshold):
    patch_index=[]
    index=0

    for count in patch_count:
        indices=[index+i for i in range(min(threshold,count))]
        index+=count
        patch_index=patch_index+indices
    
    return patch_index

def finalize_files(file_path):
    try:
        os.remove(file_path+"dataset.dat")
    except OSError as e:
        print(f"Error deleting file: {e}")

    try:
        os.rename(file_path+"dataset_temp.dat", file_path+"dataset.dat")
    except OSError as e:
        print(f"Error renaming file: {e}")

    try:
        os.remove(file_path+"labels.dat")
    except OSError as e:
        print(f"Error deleting file: {e}")

    try:
        os.rename(file_path+"labels_temp.dat", file_path+"labels.dat")
    except OSError as e:
        print(f"Error renaming file: {e}")

def dataset_threshold(file_path,length,threshold):
    dataset=np.memmap(file_path+"dataset.dat", dtype='int8', mode='r', shape=(length,128,128,98))
    labels=np.memmap(file_path+"labels.dat", dtype='int8', mode='r', shape=(length,))
    patch_count=np.memmap(file_path+"patch.dat", dtype='int8', mode='r', shape=(1775,))

    patch_index=patch_threshold(patch_count,threshold)
    new_len=len(patch_index)

    updated_dataset=np.memmap(file_path+"dataset_temp.dat", dtype='int8', mode='w+', shape=(new_len,128,128,98))
    updated_labels=np.memmap(file_path+"labels_temp.dat", dtype='int8', mode='w+', shape=(new_len,))

    for pos,index in tqdm(enumerate(patch_index),total=len(patch_index)):
        updated_dataset[pos]=dataset[index]
        updated_labels[pos]=labels[index]
        updated_dataset.flush()
        updated_labels.flush()

    dataset.flush()
    labels.flush()
    patch_count.flush()
    updated_dataset.flush()
    updated_labels.flush()

    finalize_files(file_path)

    return new_len


dataset.flush()
labels.flush()

_dataset_len=dataset_threshold(autofocus_exp_path,_dataset_len,7)

def save_variable_to_file(variable, filename):
    try:
        with open(filename, 'w+') as file:
            file.write(str(variable))
        print(f"Variable saved successfully to {filename}")
    except Exception as e:
        print(f"Error occurred while saving variable to {filename}: {e}")

save_variable_to_file(_dataset_len,autofocus_exp_path+"length.dat")


dataset=np.memmap(autofocus_exp_path+"dataset.dat", dtype='int8', mode='r', shape=(_dataset_len,128,128,98))
labels=np.memmap(autofocus_exp_path+"labels.dat", dtype='int8', mode='r', shape=(_dataset_len,))


100%|██████████| 10772/10772 [25:52<00:00,  6.94it/s]


Variable saved successfully to /mnt/Data/Autofocus/Exp1/length.dat


Initializing the Test and Train Dataset

In [3]:
def load_variable_from_file(filename):
    try:
        with open(filename, 'r') as file:
            variable = file.read()
        print(f"Variable loaded successfully from {filename}")
        return variable
    except Exception as e:
        print(f"Error occurred while loading variable from {filename}: {e}")
        return None
    
_dataset_len=int(load_variable_from_file(autofocus_exp_path+"length.dat"))

shutil.copy(autofocus_cache_path+"dataset.dat",autofocus_exp_path+"dataset.dat")
shutil.copy(autofocus_cache_path+"labels.dat",autofocus_exp_path+"labels.dat")
shutil.copy(autofocus_cache_path+"patch.dat",autofocus_exp_path+"patch.dat")

dataset=np.memmap(autofocus_exp_path+"dataset.dat", dtype='int8', mode='r', shape=(_dataset_len,128,128,98))
labels=np.memmap(autofocus_exp_path+"labels.dat", dtype='int8', mode='r', shape=(_dataset_len,))


Variable loaded successfully from /mnt/Data/Autofocus/Exp1/length.dat


In [4]:
autofocus_test_path_cache=autofocus_test_path+"Cache/"

shutil.copy(autofocus_test_path_cache+"dataset.dat",autofocus_exp_path+"Test/dataset.dat")
shutil.copy(autofocus_test_path_cache+"labels.dat",autofocus_exp_path+"Test/labels.dat")

_test_len=int(load_variable_from_file(autofocus_test_path_cache+"length.dat"))

test_data=np.memmap(autofocus_exp_path+"Test/dataset.dat", dtype='int8', mode='r', shape=(_test_len,128,128,98))
test_labels=np.memmap(autofocus_exp_path+"Test/labels.dat", dtype='int8', mode='r', shape=(_test_len,))

Variable loaded successfully from /mnt/Data/Autofocus/Test/Cache/length.dat


The MobileNetV2 model is used. The input shape is converted to (128,128,98) to be able to take the dual pixel focal stack slices as individual channels of an image. Then we use ordinal regression loss (L2). The Adam optimizer is used to build the model. 

In [None]:


# Step 1: Modify MobileNetV2 to accept 128x128x98 input
def create_modified_mobilenetv2(input_shape=(128, 128, 98)):
    base_model = MobileNetV2(input_shape=input_shape, include_top=False, weights=None)
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    # Assuming the output for the ordinal regression problem is a single value
    outputs = Dense(1)(x)
    model = Model(inputs=base_model.input, outputs=outputs)
    return model

# Step 2: Implement the ordinal regression loss
def ordinal_regression_loss(y_true, y_pred):
    # L2 loss (mean squared error)
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)

# Step 3: Set up the training loop
def train_model(model, train_dataset, steps_per_epoch=20000, epochs=128, initial_lr=1e-5, beta1=0.5, beta2=0.999):
    optimizer = Adam(learning_rate=initial_lr, beta_1=beta1, beta_2=beta2)
    model.compile(optimizer=optimizer, loss=ordinal_regression_loss,metrics=['accuracy'])
    
    model.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch)

# Create the model
input_shape = (128, 128, 98)
model = create_modified_mobilenetv2(input_shape=input_shape)

#model.summary()

The steps_per_epoch is the only thing the user running the model needs to think about. The formula is len(dataset)=batch_size * steps_per_epoch * epochs. So choose a good steps_per_epoch size. Choosing a higher number means consuming more RAM but less time and choosing a lesser number number means less RAM but more time. 

In [6]:
batch_size=128

train_dataset = tf.data.Dataset.from_tensor_slices((dataset, labels)).batch(batch_size)
print("Train Dataset Length :",len(train_dataset))

Train Dataset Length : 84


In [7]:
steps_per_epoch=7
epochs=len(train_dataset)//steps_per_epoch

train_model(model, train_dataset, steps_per_epoch=steps_per_epoch, epochs=epochs)

Epoch 1/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 1001ms/step - accuracy: 0.915 - loss: 2.263
Epoch 2/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 960ms/step - accuracy: 0.939 - loss: 1.351
Epoch 3/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 956ms/step - accuracy: 0.973 - loss: 2.419
Epoch 4/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 1034ms/step - accuracy: 0.964 - loss: 2.47
Epoch 5/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 1024ms/step - accuracy: 0.923 - loss: 2.372
Epoch 6/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 911ms/step - accuracy: 0.999 - loss: 1.501
Epoch 7/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 911ms/step - accuracy: 0.906 - loss: 2.982
Epoch 8/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m21s[0m 917ms/step - accuracy: 0.964 - loss: 1.273
Epoch 9/12
[1m7/7[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 

In [None]:
model.save(autofocus_exp_path+"Experiment_1_model.keras")

Test the model

In [9]:
batch_size=32

test_dataset = tf.data.Dataset.from_tensor_slices((test_data, test_labels)).batch(batch_size)
print("Test Dataset Length :",len(test_dataset))

Test Dataset Length : 40


In [10]:
results = model.evaluate(test_dataset)
loss, accuracy = results
print(f"Loss: {loss}, Accuracy: {accuracy*100}%")


[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 1s/step - accuracy: 0.965 - loss: 0.234
Loss: 0.54836130142212, Accuracy: 96.457846284716289%


Clear the RAM for immediate clean up

In [3]:

def clear_ram():
    global_vars = list(globals().keys())  # Get a list of global variable names
    vars_to_delete = [var for var in global_vars]
    
    # Delete selected variables
    for var in vars_to_delete:
        del globals()[var]
    import gc
    # Invoke garbage collector
    gc.collect()
    
    # Print confirmation
    print('RAM cleared')

time.sleep(5)

clear_ram()

RAM cleared
