Test and Train archives webpage - https://learntoautofocus-google.github.io/

Path to your downloaded files. Store the files in a folder, say 'Autofocus'. Then create a directory 'Train' as Autofocus/Train and store inside it the downloaded folders named 'train<number>', as Autofocus/Train/train<number>. Then create a directory called 'Test' as Autofocus/Test and store the downloaded folder 'test' in it as Autofocus/Test/test.

The code is written so that the user running the model needs to download the test archive and at least one train dataset. There is no need to download all train datasets, although increasing the number of train folders will increase accuracy. Copying a train folder and using the copy as a separate folder (for example, copying train1 and naming it train2) will not make any difference: it is the same as using only train1, but will of course take more time to run.

In [None]:
# Root folder of the downloaded dataset. Alternative locations used on other machines:
#autofocus_path='S:/Personal/Projects/Python/Autofocus/'
#autofocus_path='/data/Autofocus/'
autofocus_path = '/mnt/Velocity Vault/Autofocus/'

# Sub-folders are built by plain string concatenation, so the root must end with '/'
# and every derived path keeps its own trailing '/'.
autofocus_train_path = f"{autofocus_path}Train/"
autofocus_test_path = f"{autofocus_path}Test/"
autofocus_temp_path = f"{autofocus_path}Temp/"

Necessary imports

In [None]:
import os
import pprint
import numpy as np
import random
import cv2
import OpenEXR
import Imath
from tqdm import tqdm
import copy
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

print(tf.version.VERSION)

Make dictionaries which contain the locations of the depth maps and depth confidence maps. Instead of storing the files themselves, we store just their locations so that the data doesn't use up RAM; the files can be accessed whenever needed.

In [None]:

def build_depth_map(super_path):
    """Index every ``.png`` file below *super_path* in a nested dictionary.

    Each sub-folder on the way to a file becomes one dictionary level. The
    leaf key is the file name with the first ``len('result_merged_depth_')``
    characters and the trailing ``.png`` removed; the leaf value is the path
    of the file as produced by joining the walked folder and the file name.

    Returns an empty dictionary when nothing matches.
    """
    prefix_length = len('result_merged_depth_')
    suffix_length = len('.png')
    index = {}

    for folder, _, names in os.walk(super_path):
        for name in names:
            if not name.endswith('.png'):
                continue

            file_path = os.path.join(folder, name)
            # Folder components relative to super_path, followed by the file name.
            *folders, leaf = os.path.relpath(file_path, super_path).split(os.sep)

            # Descend, creating intermediate levels on demand.
            node = index
            for sub_folder in folders:
                node = node.setdefault(sub_folder, {})

            node[leaf[prefix_length:-suffix_length]] = file_path

    return index

def build_depth_map_confidence(super_path):
    """Index every ``.exr`` file below *super_path* in a nested dictionary.

    Each sub-folder on the way to a file becomes one dictionary level. The
    leaf key is the file name with the first ``len('result_merged_conf_')``
    characters and the trailing ``.exr`` removed; the leaf value is the path
    of the file as produced by joining the walked folder and the file name.

    Returns an empty dictionary when nothing matches.
    """
    prefix_length = len('result_merged_conf_')
    suffix_length = len('.exr')
    index = {}

    for folder, _, names in os.walk(super_path):
        for name in names:
            if not name.endswith('.exr'):
                continue

            file_path = os.path.join(folder, name)
            # Folder components relative to super_path, followed by the file name.
            *folders, leaf = os.path.relpath(file_path, super_path).split(os.sep)

            # Descend, creating intermediate levels on demand.
            node = index
            for sub_folder in folders:
                node = node.setdefault(sub_folder, {})

            node[leaf[prefix_length:-suffix_length]] = file_path

    return index



It is possible that someone hasn't downloaded all the Train folders for some reason, so we find how many folders have been downloaded and then load them into the dictionary as needed.

In [None]:

def find_train_folders(directory):
    """Return the paths of all folders below *directory* whose name starts with 'train'.

    The whole tree is walked, so a matching folder nested inside another
    folder is reported as well. The order follows ``os.walk``.
    """
    return [
        os.path.join(parent, child)
        for parent, children, _ in os.walk(directory)
        for child in children
        if child.startswith('train')
    ]

train_path = find_train_folders(autofocus_train_path)
print(len(train_path))

# One merged index per map type, collected over every downloaded train folder.
# NOTE(review): dict.update is shallow, so if two train folders share a
# top-level key the later folder replaces the earlier entry — confirm the
# archives never overlap.
depth_map_path = {}
depth_map_path_temp = {}
depth_map_confidence_path = {}
depth_map_confidence_path_temp = {}

for path in train_path:
    depth_map_path_temp = build_depth_map(path + "/merged_depth")
    depth_map_path.update(depth_map_path_temp)

    depth_map_confidence_path_temp = build_depth_map_confidence(path + "/merged_conf")
    depth_map_confidence_path.update(depth_map_confidence_path_temp)

pprint.pprint(depth_map_path)
pprint.pprint(depth_map_confidence_path)

print(len(depth_map_path))


Generate random patches in the image of size 32x32 (we need 128x128 patches from the images, but the depth map is 4 times smaller than the actual image), with a distance of 40 around each patch.

In [None]:

def generate_random_patches(size=(504,378)):
    """Draw up to 100 random, well-separated patch positions inside a map.

    Up to 100 candidate centres are drawn with ``random``. A candidate is
    rejected when any cell within 26 cells of its centre (clipped to the
    map) already belongs to an accepted patch; otherwise the 33x33 block
    starting 16 cells before the centre is marked as occupied.

    Args:
        size: ``(rows, cols)`` of the map. All bounds are derived from it,
            so maps other than the default 504x378 are supported.

    Returns:
        A list of ``(x, y)`` top-left corners, ``x`` indexing the first
        axis and ``y`` the second.

    Raises:
        ValueError: if the map is too small to hold a single patch.
    """
    rows, cols = size
    half = 16        # distance from the patch centre to its top-left corner
    clearance = 26   # half-width of the neighbourhood that must be free
    footprint = 33   # side of the block marked as occupied

    if rows < footprint or cols < footprint:
        raise ValueError(f"size {size} is too small for a {footprint}x{footprint} patch")

    occupied = np.zeros(size, dtype=bool)
    patches = []

    for _ in range(100):
        # Centres are kept far enough from the border for the footprint to fit.
        x = random.randint(half, rows - half - 1)
        y = random.randint(half, cols - half - 1)

        # Slicing clips the neighbourhood to the map, replacing the per-cell bounds test.
        neighbourhood = occupied[
            max(x - clearance, 0):min(x + clearance, rows),
            max(y - clearance, 0):min(y + clearance, cols),
        ]
        if neighbourhood.any():
            continue

        x -= half
        y -= half
        occupied[x:x + footprint, y:y + footprint] = True
        patches.append((x, y))

    return patches

Given the depth maps, we take every image, generate patches, and determine which of the patches have a higher median confidence; we then calculate the median depth for those patches. The median depth is converted to actual depth, which is our predicted focal length. Then we select the slice whose focal length is nearest to the predicted focal length.

In [None]:


def load_exr(file_path):
    """Read the 'R' channel of an OpenEXR file as a 2-D float32 array.

    Args:
        file_path: path of the ``.exr`` file.

    Returns:
        Array of shape ``(height, width)`` taken from the file's data window.
    """
    exr_file = OpenEXR.InputFile(file_path)
    try:
        dw = exr_file.header()['dataWindow']
        width = dw.max.x - dw.min.x + 1
        height = dw.max.y - dw.min.y + 1
        pt = Imath.PixelType(Imath.PixelType.FLOAT)
        r = np.frombuffer(exr_file.channel('R', pt), dtype=np.float32)
    finally:
        # Release the handle even if the header or channel read fails; this
        # function is called once per image.
        exr_file.close()
    return r.reshape(height, width)

def predict_patches(depth_confidence_path):
    """Select random patches of a confidence map that look reliable.

    Random patch corners come from ``generate_random_patches``; each is
    scored by the median confidence of its 32x32 block. Patches scoring at
    least 0.98 are kept. When more than 4 patches qualify, the selection is
    replaced by the 15 highest-scoring patches.

    Returns:
        List of ``(x, y)`` patch corners.
    """
    max_patches = 15
    min_confidence = 0.98

    confidence_map = load_exr(depth_confidence_path)
    candidates = generate_random_patches()

    scores = []
    for (x, y) in candidates:
        block = confidence_map[x:x + 32, y:y + 32]
        scores.append(np.median(block.flatten()))

    chosen = [i for i, score in enumerate(scores) if score >= min_confidence]

    if len(chosen) > 4:
        # NOTE(review): the ranking runs over ALL candidates, so when fewer
        # than 15 patches reach 0.98 the result also contains patches below
        # that confidence — confirm this is intended.
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        chosen = ranked[:max_patches]

    return [candidates[i] for i in chosen]

def predict_focal_length(depth_map_path,patch):
    """Estimate the focus value for one patch of a depth map.

    The depth map is read as a grayscale image, the 32x32 block at *patch*
    is mapped from pixel values to depth with
    ``(max * min) / (max - (max - min) * value / 255)`` (0.2 at value 0,
    100 at value 255), and the median of the block is returned multiplied
    by 1000.

    Args:
        depth_map_path: path of the depth map image.
        patch: ``(x, y)`` corner of the block; ``x`` indexes the first axis.

    Returns:
        The median depth of the block times 1000 (presumably metres to
        millimetres, going by the variable names — TODO confirm).

    Raises:
        FileNotFoundError: if the depth map cannot be read.
    """
    raw = cv2.imread(depth_map_path, cv2.IMREAD_GRAYSCALE)
    if raw is None:
        # cv2.imread reports an unreadable file by returning None, not by raising.
        raise FileNotFoundError(f"Could not read depth map: {depth_map_path}")
    depth_map = raw.astype(np.float32)

    depth_values=depth_map[patch[0]:patch[0]+32,patch[1]:patch[1]+32]
    depth_values=depth_values.flatten()

    # Define max and min values
    max_depth = 100.0
    min_depth = 0.2

    depth_map_in_meters = (max_depth * min_depth) / (max_depth - (max_depth - min_depth) * (depth_values / 255.0))

    # Median over the patch only, not over the whole depth map.
    median_depth = np.median(depth_map_in_meters)

    final_focus=median_depth*1000

    return final_focus

# Lookup table with one entry per focal-stack slice (49 values, strictly decreasing).
# predict_slice uses the list index of the nearest entry as the slice label.
# Presumably these are focus distances in millimetres, since predict_focal_length
# multiplies its result by 1000 — TODO confirm the source of the values.
slice_focal_length=[3910.92,2289.27,1508.71,1185.83,935.91,801.09,700.37,605.39,546.23,486.87,447.99,407.40,379.91,350.41,329.95,307.54,291.72,274.13,261.53,247.35,237.08,225.41,216.88,207.10,198.18,191.60,183.96,178.29,171.69,165.57,160.99,155.61,150.59,146.81,142.35,138.98,134.99,131.23,127.69,124.99,121.77,118.73,116.40,113.63,110.99,108.47,106.54,104.23,102.01]

def find_closest(value, num_list):
    """Return the element of *num_list* nearest to *value*.

    On a tie the element that appears first in *num_list* wins.
    """
    def distance(candidate):
        return abs(candidate - value)

    return min(num_list, key=distance)

def predict_slice(depth_map_path,depth_confidence_path):
    """Label every selected patch of one image with its best-matching slice.

    Patches come from ``predict_patches``; for each the predicted focus is
    matched to the nearest entry of ``slice_focal_length`` and that entry's
    index becomes the label.

    Returns:
        List of ``(x, y, slice_index)`` tuples, one per selected patch.
    """
    def slice_for(patch):
        focus = predict_focal_length(depth_map_path, patch)
        nearest = find_closest(focus, slice_focal_length)
        # manual annotation
        # true_slice=true_slice-1 if true_slice !=0 else true_slice
        return slice_focal_length.index(nearest)

    return [
        (patch[0], patch[1], slice_for(patch))
        for patch in predict_patches(depth_confidence_path)
    ]

Make the Ground Truth dictionary which contains the predicted focal slice and its respective patches

In [None]:

# Start from a copy of the depth-map index so ground_truth has the same
# [image_type][pos] layout, then replace every path by its labelled patches.
ground_truth = copy.deepcopy(depth_map_path)

for image_type in tqdm(ground_truth):
    for pos, depth_file in depth_map_path[image_type].items():
        confidence_file = depth_map_confidence_path[image_type][pos]
        ground_truth[image_type][pos] = predict_slice(depth_file, confidence_file)

pprint.pprint(ground_truth)

Counting the average number of patches. Increasing the threshold of 4 in the original function will increase the size of the training dataset.

In [None]:
# Number of labelled patches for every [image_type][pos] entry.
patch_count = [
    len(selected)
    for positions in ground_truth.values()
    for selected in positions.values()
]

# Smallest, average and largest patch count.
print(min(patch_count)," ",np.mean(patch_count)," ",max(patch_count))

Function that takes the image path and returns the image with its own patch

In [None]:
def generate_image_paths(super_path):
    """Index the ``.png`` files below *super_path* by folder and position.

    For a file in ``<super_path>/<string>/<integer>/`` the result holds
    ``result[<string>][int(<integer>)][<position>] = <path of the file>``,
    where ``<position>`` is the text between the last ``_`` of the file
    name and the first ``.`` that follows it.

    Raises:
        ValueError: if the folder directly containing a ``.png`` does not
            have an integer name.
    """
    index = {}

    for folder, _, names in os.walk(super_path):
        for name in names:
            if not name.endswith('.png'):
                continue

            # The last folder component is the integer key; everything
            # before it is the string key.
            string_part, integer_part = os.path.split(os.path.relpath(folder, super_path))
            integer_part = int(integer_part)

            # e.g. '..._center.png' -> 'center'
            position = name.split('_')[-1].split('.')[0]

            by_position = index.setdefault(string_part, {}).setdefault(integer_part, {})
            by_position[position] = os.path.join(folder, name)

    return index



Storing only the locations of the left and right dual pixel images

In [None]:

# Merged [<string>][<integer>][<position>] indexes of the left and right
# dual-pixel images over every downloaded train folder.
left_image_path, left_image_path_temp = {}, {}
right_image_path, right_image_path_temp = {}, {}

train_path = find_train_folders(autofocus_train_path)
print(len(train_path))

for path in train_path:
    left_image_path_temp = generate_image_paths(f"{path}/raw_up_left_pd")
    left_image_path.update(left_image_path_temp)

    right_image_path_temp = generate_image_paths(f"{path}/raw_up_right_pd")
    right_image_path.update(right_image_path_temp)


Since the dictionary stores paths as [...][<focal_stack_number>][<position>]=path, for ease of creating the dataset we need to reorder it into [...][<position>][<focal_stack_number>]=path.

In [None]:
def reorder_image_path(image_path):
    """Regroup ``[key1][stack_number][position] = path`` by position.

    Args:
        image_path: nested dictionary as built by ``generate_image_paths``.

    Returns:
        ``{key1: {position: [path, ...]}}`` where each list is ordered by
        ascending stack number.
    """
    ip = {}

    for key1, level2_dict in image_path.items():
        # Sort by stack number: the incoming dictionary follows directory
        # listing order, which is arbitrary. The list position later becomes
        # the channel index and is zipped against the other dual-pixel view,
        # so both must follow the same, well-defined order.
        for key2 in sorted(level2_dict):
            for key3, value in level2_dict[key2].items():
                ip.setdefault(key1, {}).setdefault(key3, []).append(value)

    return ip

# Switch both indexes to the [<string>][<position>] -> list-of-paths layout.
left_image_path = reorder_image_path(left_image_path)
right_image_path = reorder_image_path(right_image_path)

for reordered in (left_image_path, right_image_path):
    pprint.pprint(reordered)
for reordered in (left_image_path, right_image_path):
    print(len(reordered))



For the second experiment we use the dual-pixel (DP) data, so the images are loaded separately, creating 49 left DP slices and 49 right DP slices, thus making 98 slices. So we need to merge the paths into a single dictionary for ease of creating the dataset. The paths are then stored as tuples of (left DP image, right DP image).

In [None]:
def experminet_combine(left_path,right_path):
    """Pair the two sequences element-wise into a list of ``(left, right)`` tuples.

    The result is as long as the shorter input; surplus elements are dropped.
    """
    return [(left, right) for left, right in zip(left_path, right_path)]

def combine_image_paths(left_image_path,right_image_path):
    """Merge two ``[key1][key2] -> list`` indexes into one of paired tuples.

    The keys are taken from *left_image_path*; every list is replaced by
    ``experminet_combine(left_list, right_list)``.

    Raises:
        KeyError: if *right_image_path* lacks a key present on the left.
    """
    return {
        key1: {
            key2: experminet_combine(left_list, right_image_path[key1][key2])
            for key2, left_list in by_position.items()
        }
        for key1, by_position in left_image_path.items()
    }

image_paths = combine_image_paths(left_image_path, right_image_path)

pprint.pprint(image_paths)
print(len(image_paths))

# Spot-check one entry; this requires the 'apt1_0' entry to be among the
# downloaded data.
sample_stack = image_paths['apt1_0']['center']
print(type(sample_stack))
print(len(sample_stack))

Since the images are of format AxBx3, we need to make them AxB; then, to create the dataset, we take multiple such images, making it sxAxB, and convert it into AxBxs.

In [None]:
def check_dimension(arr):
    """Warn on stdout when *arr* is not 2-D, then return *arr* unchanged.

    Nothing is raised for a wrong rank. An object without array attributes
    (for example ``None``) raises ``AttributeError``.
    """
    if arr.ndim != 2:
        print("Input array must have shape AxB")
    return arr

def combine_last_dimension(arrays):
    """Stack equally shaped 2-D arrays along a new last axis.

    Args:
        arrays: non-empty sequence of ``s`` arrays of shape ``(A, B)``.

    Returns:
        Array of shape ``(A, B, s)`` with ``result[..., k] == arrays[k]``.
        The element type follows the inputs instead of being a generic
        object array, and the inputs no longer have to be square.

    Raises:
        ValueError: if *arrays* is empty or the shapes differ.
    """
    return np.stack(arrays, axis=-1)


The dataset and the ground-truth labels are created corresponding to each patch of each focal stack slice; the left and right DP values are added and their mean is taken to simulate only the green-channel data of an image.

In [None]:
def build_labels(ground_truth):
    """Flatten the slice labels of *ground_truth* into one list.

    *ground_truth* maps ``[image_type][pos]`` to a list of
    ``(x, y, slice_index)`` tuples; the labels are emitted in dictionary
    iteration order, one per tuple.
    """
    labels = []

    for image_type in tqdm(ground_truth):
        for entries in ground_truth[image_type].values():
            labels.extend(slice_index for (_, _, slice_index) in entries)

    return labels

# In-memory list with one slice label per patch.
labels_RAM = build_labels(ground_truth)

# pprint.pprint(labels)
label_total = len(labels_RAM)
print(label_total)

In [None]:
def build_patch_count(ground_truth):
    """Return the number of patches of every ``[image_type][pos]`` entry.

    The counts are emitted in dictionary iteration order, the same order
    in which ``build_labels`` emits the labels.
    """
    counts = []

    for image_type in tqdm(ground_truth):
        counts.extend(len(entries) for entries in ground_truth[image_type].values())

    return counts

# In-memory list with one patch count per [image_type][pos] entry.
patch_RAM = build_patch_count(ground_truth)

# pprint.pprint(labels)
entry_total = len(patch_RAM)
print(entry_total)

In [None]:
# Total number of patches, i.e. samples in the dataset.
_dataset_len = len(labels_RAM)

# Labels cache: one int8 per sample, backed by a file on disk.
filename = f"{autofocus_path}exp2_labels_cache.dat"
shape = _dataset_len
labels = np.memmap(filename, dtype='int8', mode='w+', shape=shape)

for index, label in enumerate(labels_RAM):
    labels[index] = label

labels.flush()

# Patch-count cache: one int8 per [image_type][pos] entry.
filename = f"{autofocus_path}exp2_patch_cache.dat"
shape = len(patch_RAM)
patch_count = np.memmap(filename, dtype='int8', mode='w+', shape=shape)

for index, patch in enumerate(patch_RAM):
    patch_count[index] = patch

patch_count.flush()

In [None]:

def sleep_with_progress(seconds,index):
    """Pause for *seconds*, printing a line tagged with *index* before and after."""
    print("Start Sleeping", seconds, "seconds -", index)
    time.sleep(seconds)
    print("End Sleeping -", index)

def wait_for_file(filepath, check_interval=1):
    """Block until *filepath* is an existing regular file, then return 0.

    The path is polled every *check_interval* seconds, without a timeout.
    """
    while True:
        if os.path.isfile(filepath):
            return 0
        time.sleep(check_interval)

def build_dataset(filename,shape,image_paths,ground_truth):
    """Fill the on-disk dataset cache with one image stack per labelled patch.

    For every ``[image_type][pos]`` entry of *ground_truth* the left/right
    image pairs listed in *image_paths* are read as grayscale, averaged, and
    a 128x128 crop is taken at four times each patch corner. The crops of
    all slices are stacked on the last axis and written to the next free
    row of the memmap.

    Args:
        filename: path of an existing cache file (opened ``r+`` as int8).
        shape: shape of the cache.
        image_paths: ``[image_type][pos] -> [(left_path, right_path), ...]``.
        ground_truth: ``[image_type][pos] -> [(x, y, slice_index), ...]``.

    Values are stored bit-for-bit as 8-bit pixels. The cache dtype is int8
    throughout this file, so averages above 127 read back as negative
    numbers; use ``.view(np.uint8)`` to recover the 0..255 range.
    """
    def _load_pairs(path_pairs):
        # check_dimension raises AttributeError when cv2.imread returns None
        # (unreadable file), which is what triggers the retry below.
        return [
            (check_dimension(cv2.imread(left, 0)), check_dimension(cv2.imread(right, 0)))
            for left, right in path_pairs
        ]

    wait_for_file(filename)
    dataset=np.memmap(filename, dtype='int8', mode='r+', shape=shape)
    data_count=0
    sleep_index=60
    for count, image_type in enumerate(tqdm(ground_truth)):
        # Pause for a minute after every `sleep_index` image types.
        if count%sleep_index==0 and count!=0:
            sleep_with_progress(60,count//sleep_index)
        for pos in ground_truth[image_type]:
            patches=[(x,y) for (x,y,_) in ground_truth[image_type][pos]]
            path_pairs=image_paths[image_type][pos]

            try:
                all_images=_load_pairs(path_pairs)
            except Exception:
                # Best effort: assume the drive dropped out, wait until every
                # file is visible again, then read once more.
                print("Reconnect Drive")
                for left, right in path_pairs:
                    wait_for_file(left)
                    wait_for_file(right)
                all_images=_load_pairs(path_pairs)

            # Average in 16 bits: adding two 8-bit images directly wraps
            # around at 256 and corrupts every bright pixel.
            all_images=[
                ((left.astype(np.uint16)+right)//2).astype(np.uint8)
                for left, right in all_images
            ]

            for (x, y) in patches:
                # Patch corners are in depth-map cells, a quarter of the image resolution.
                row=x*4
                col=y*4

                images=[image[row:row+128,col:col+128] for image in all_images]

                image_set=combine_last_dimension(images)
                # Explicit, bit-preserving conversion to the int8 cache dtype.
                image_array=np.asarray(image_set).astype(np.uint8).view(np.int8)

                wait_for_file(filename)
                dataset[data_count]=image_array
                wait_for_file(filename)
                dataset.flush()
                data_count+=1


filename = f"{autofocus_path}exp2_dataset_cache.dat"
shape = (_dataset_len, 128, 128, 49)

# mode='w+' creates the cache file (or overwrites an existing one) at full
# size before build_dataset reopens it for writing.
dataset = np.memmap(filename, dtype='int8', mode='w+', shape=shape)

build_dataset(filename, shape, image_paths, ground_truth)

# pprint.pprint(dataset)
print(len(dataset))


Just checking if the dataset is in the desired shape

In [None]:

def check_shapes(lst):
    """Return True when every element of *lst* is an ndarray of shape (128, 128, 49).

    An empty *lst* yields True.
    """
    target_shape = (128, 128, 49)
    return all(
        isinstance(element, np.ndarray) and element.shape == target_shape
        for element in lst
    )

# Sanity checks: shape of the first sample, shape of every sample, sample count.
first_sample = dataset[0]
print(first_sample.shape)
print(check_shapes(dataset))

print(len(dataset))

In [None]:
# dataset_len=22678

def save_variable_to_file(variable, filename):
    """Write ``str(variable)`` to *filename*, replacing any previous content.

    Success or failure is reported on stdout; errors are not propagated.
    """
    try:
        with open(filename, 'w') as handle:
            text = str(variable)
            handle.write(text)
        print(f"Variable saved successfully to {filename}")
    except Exception as e:
        print(f"Error occurred while saving variable to {filename}: {e}")

# Persist the sample count so a later session can reopen the caches without rebuilding them.
save_variable_to_file(_dataset_len, f"{autofocus_path}exp2_length.dat")


In [None]:
def exp1_to_exp2(exp1_filename, exp1_shape, exp2_filename, exp2_shape):
    """Halve the channel count of a cached dataset by averaging channel pairs.

    Channels ``2k`` and ``2k+1`` of every pixel are averaged and truncated
    to int8, so a ``(N, H, W, 2C)`` input becomes a ``(N, H, W, C)`` output.

    Args:
        exp1_filename: existing int8 cache, opened read-only.
        exp1_shape: its shape ``(N, H, W, 2C)``.
        exp2_filename: output cache, created or overwritten.
        exp2_shape: shape of the output, ``(N, H, W, C)``.

    NOTE(review): averaging adjacent channels is only a left/right average
    if the input interleaves the two views channel by channel — confirm the
    channel layout of the exp1 cache.
    """
    # Open the input memmap file in read-only mode with dtype int8
    exp1_dataset = np.memmap(exp1_filename, dtype='int8', mode='r', shape=exp1_shape)

    # Create the output memmap file with the given shape and dtype int8
    exp2_dataset = np.memmap(exp2_filename, dtype='int8', mode='w+', shape=exp2_shape)

    # View the channel axis as (pairs, 2) so a single mean averages each pair.
    pair_shape = tuple(exp1_shape[1:3]) + (exp1_shape[3] // 2, 2)

    # One sample at a time keeps memory bounded while avoiding a Python-level
    # loop over every pixel.
    for i in tqdm(range(exp1_shape[0])):
        sample = np.asarray(exp1_dataset[i]).reshape(pair_shape)
        exp2_dataset[i] = sample.mean(axis=-1).astype('int8')

        # Flush changes to the output file
        exp2_dataset.flush()

# Derive the 49-channel cache from an existing 98-channel cache.
# This overwrites exp2_dataset_cache.dat.
exp1_to_exp2(
    autofocus_path+"exp1_dataset_cache.dat",
    (_dataset_len,128,128,98),
    autofocus_path+"exp2_dataset_cache.dat",
    (_dataset_len,128,128,49),
)


In [None]:
def patch_threshold(patch_count, threshold):
    """Return the dataset rows to keep when capping the patches per entry.

    The dataset stores the patches of each entry contiguously, so entry
    ``k`` occupies ``patch_count[k]`` consecutive rows. At most the first
    *threshold* rows of every entry are kept.

    Args:
        patch_count: iterable of per-entry patch counts (may be NumPy ints).
        threshold: maximum number of patches kept per entry.

    Returns:
        List of row indices as Python ints.
    """
    patch_index = []
    index = 0

    for count in patch_count:
        # Counts read from the int8 cache are np.int8; accumulating them
        # as-is overflows once the running offset passes 127.
        count = int(count)
        patch_index.extend(range(index, index + min(threshold, count)))
        index += count

    return patch_index

def dataset_threshold(data_name,shape_data,patch_count_name,shape_patch,threshold):
    """Write a copy of the dataset cache keeping at most *threshold* patches per entry.

    Args:
        data_name: path of the source int8 dataset cache (read-only).
        shape_data: shape of the source cache; the per-sample shape of the
            output is taken from ``shape_data[1:]``.
        patch_count_name: path of the int8 patch-count cache.
        shape_patch: shape of the patch-count cache.
        threshold: maximum number of patches kept per entry.

    Returns:
        Number of samples written to ``autofocus_path+"exp2_dataset_cache.dat"``.

    Raises:
        ValueError: if *data_name* is the output file itself.

    NOTE(review): the labels cache and the saved dataset length are not
    reduced here, so they no longer match the file written by this function.
    """
    output_name = autofocus_path+"exp2_dataset_cache.dat"
    # Opening the output with mode='w+' overwrites it, so it must not be the source.
    if os.path.abspath(data_name) == os.path.abspath(output_name):
        raise ValueError("data_name must differ from the output cache: " + output_name)

    dataset=np.memmap(data_name, dtype='int8', mode='r', shape=shape_data)
    patch_count=np.memmap(patch_count_name, dtype='int8', mode='r', shape=shape_patch)

    patch_index=patch_threshold(patch_count,threshold)
    new_len=len(patch_index)

    updated_dataset=np.memmap(output_name, dtype='int8', mode='w+', shape=(new_len,)+tuple(shape_data[1:]))

    for pos,index in tqdm(enumerate(patch_index),total=new_len):
        updated_dataset[pos]=dataset[index]

    # Make sure everything is on disk before the caller reopens the file.
    updated_dataset.flush()
    return new_len

#basically 1775 is 355*5 which is number of scenes*5
dataset_threshold(
    data_name=autofocus_temp_path+"exp2_dataset_cache.dat",
    shape_data=(_dataset_len,128,128,49),
    patch_count_name=autofocus_temp_path+"exp2_patch_cache.dat",
    shape_patch=(1775,),
    threshold=7,
)


In case someone is low on RAM (which I was), run this code: it deletes all the global variables, thus freeing RAM, then loads the needed variables from the previous instance and loads the dataset from the file, which takes very little time compared to creating it.

In [None]:
def load_variable_from_file(filename):
    """Return the whole content of *filename* as a string, or None on failure.

    Success or failure is reported on stdout; errors are not propagated.
    """
    try:
        with open(filename, 'r') as handle:
            content = handle.read()
        print(f"Variable loaded successfully from {filename}")
    except Exception as e:
        print(f"Error occurred while loading variable from {filename}: {e}")
        return None
    return content
    
# Restore the sample count saved by an earlier session. int() raises
# TypeError if the file could not be read (the loader then returns None).
_saved_length = load_variable_from_file(autofocus_path+"exp2_length.dat")
_dataset_len = int(_saved_length)

In [None]:
import os

# Set the environment variables before TensorFlow is imported below.
os.environ.update({
    'HSA_OVERRIDE_GFX_VERSION': '10.3.0',
    'TF_GPU_ALLOCATOR': 'cuda_malloc_async',
})

import os
import pprint
import numpy as np
import random
import cv2
import OpenEXR
import Imath
from tqdm import tqdm
import copy
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

print(tf.version.VERSION)

# Reopen the caches read-only; nothing is loaded until the arrays are accessed.
cache_shape = (_dataset_len, 128, 128, 49)
dataset = np.memmap(f"{autofocus_path}exp2_dataset_cache.dat", dtype='int8', mode='r', shape=cache_shape)
labels = np.memmap(f"{autofocus_path}exp2_labels_cache.dat", dtype='int8', mode='r', shape=cache_shape[:1])

The MobileNetV2 model is used. The input shape is changed to (128,128,49) so that the focal stack slices can be taken as individual channels of an image. Then we use ordinal regression loss (L2). The Adam optimizer is used to build the model.

In [None]:


# Step 1: Modify MobileNetV2 to accept 128x128x49 input
def create_modified_mobilenetv2(input_shape=(128, 128, 49)):
    """Build an untrained MobileNetV2 backbone with a single-value regression head.

    The backbone is created without its classification top and without
    pretrained weights; its output is globally average-pooled and fed to
    one ``Dense(1)`` unit.
    """
    backbone = MobileNetV2(input_shape=input_shape, include_top=False, weights=None)
    pooled = GlobalAveragePooling2D()(backbone.output)
    # A single output value for the ordinal regression problem.
    prediction = Dense(1)(pooled)
    return Model(inputs=backbone.input, outputs=prediction)

# Step 2: Implement the ordinal regression loss
def ordinal_regression_loss(y_true, y_pred):
    """Return the L2 loss: the mean squared difference over the last axis."""
    residual = y_true - y_pred
    return tf.reduce_mean(tf.square(residual), axis=-1)

# Step 3: Set up the training loop
def train_model(model, train_dataset, steps_per_epoch=20000, epochs=128, initial_lr=1e-3, beta1=0.5, beta2=0.999):
    """Compile *model* with Adam and the L2 loss, then fit it on *train_dataset*.

    Args:
        model: Keras model to train in place.
        train_dataset: batched dataset passed to ``model.fit``.
        steps_per_epoch: number of batches per epoch.
        epochs: number of epochs.
        initial_lr: Adam learning rate.
        beta1: Adam ``beta_1``.
        beta2: Adam ``beta_2``.
    """
    model.compile(
        optimizer=Adam(learning_rate=initial_lr, beta_1=beta1, beta_2=beta2),
        loss=ordinal_regression_loss,
    )
    model.fit(train_dataset, steps_per_epoch=steps_per_epoch, epochs=epochs)

# Create the model: one input channel per focal-stack slice.
input_shape = (128, 128, 49)
model = create_modified_mobilenetv2(input_shape)

#model.summary()

The steps_per_epoch is the only thing the user running the model needs to think about. The formula is len(dataset)=batch_size*steps_per_epoch*epochs, so choose a good steps_per_epoch size. Choosing a higher number means consuming more RAM but less time, and choosing a lower number means less RAM but more time.

In [None]:
batch_size = 128

# NOTE(review): from_tensor_slices presumably converts the whole memmap into
# an in-memory tensor — confirm the dataset fits in RAM.
sample_stream = tf.data.Dataset.from_tensor_slices((dataset, labels))
train_dataset = sample_stream.batch(batch_size)
print("Train Dataset Length :",len(train_dataset))

In [None]:
steps_per_epoch = 3
# Choose the epoch count so that steps_per_epoch * epochs covers the batches once.
epochs = len(train_dataset) // steps_per_epoch

train_model(model, train_dataset, steps_per_epoch=steps_per_epoch, epochs=epochs)

In [None]:
# Save the trained model next to the dataset caches.
model.save(f"{autofocus_path}Experiment_2_model.keras")

#model=tf.keras.models.load_model(autofocus_path+"autofocus.keras")


In [None]:
def clear_ram():
    """Delete the user-defined globals of this module and run the garbage collector.

    Names starting with an underscore (``__name__``, ``__builtins__``, ...)
    and the names listed in ``keep`` are left in place, so the namespace
    stays usable afterwards. Everything else is removed, including
    imported modules and this function itself.
    """
    import gc

    # Names presumably injected by IPython/Jupyter — TODO confirm the list
    # for the front end in use.
    keep = {'In', 'Out', 'get_ipython', 'exit', 'quit', 'open'}

    namespace = globals()
    # Snapshot the keys: the dictionary is modified while looping.
    for var in list(namespace):
        if var.startswith('_') or var in keep:
            continue
        del namespace[var]

    # Invoke garbage collector
    gc.collect()

    # Print confirmation
    print('RAM cleared')

# Imported here because the second import cell does not import `time`, so
# a session restarted from the cached files would otherwise fail on this line.
import time

# Short pause before the globals are dropped.
time.sleep(10)

clear_ram()