## Pre-calculate the underwater params

In [None]:
import Pude_training_loop.loss_functions_torch as loss_functions
import Pude_training_loop.pude_utils as pude_utils
import Pude_training_loop.model_training as model_training
import Pude_training_loop.dataset_loader as data_loader
from Pude_training_loop.data_logger import Data_logger
from Pude_training_loop.physics_parameter_estmation import UnderwaterParameterFinder
import torch
import numpy as np

# Notebook cell: run the depth model over every dataset item, estimate the
# underwater parameters from its (rescaled) depth map and log them.

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Directory handed to Data_logger as results_dir.
data_log_results_dir = "Results/SeaThru_Combined/Datalogger_params/"

model_name = "depth_anything"

# Model and matching image processor; the model path is looked up by name in model_training.models.
depth_anything_model, depth_anything_image_processor = model_training.get_model_image_processor_pair(model_name=model_name, model_path=model_training.models[model_name], device=device)

dataset_loader = data_loader.DatasetLoader() # Initialize dataset loader with default parameters
# Original note: for PUDE, tau_thresholds=(0.08, 0.6); for Depth Anything it is 1
underwater_parameter_estimator = UnderwaterParameterFinder() # Initialize underwater parameter finder with default parameters
data_logger = Data_logger(results_dir=data_log_results_dir)
# Define training parameters
batch_size = 1

# seed the torch
torch.manual_seed(model_training.seed)

# storing underwater params for efficient training
# NOTE(review): this list is never appended to in this cell; results go to data_logger instead.
underwater_params = []


# NOTE(review): the next three names are not read anywhere in this cell.
is_first_epoch = True
skipped_images = []
max_images_to_process = 182

# Visit every dataset item in index order.
for i in range(len(dataset_loader)):
    # Each item unpacks into a (non-linear, linear) pair of images.
    non_linear_images, linear_images = dataset_loader[i]
    # linear_images = linear_images.to(device)
    # Similarly, convert non_linear_images to a PyTorch tensor
    non_linear_images_tensor = torch.tensor(non_linear_images)
    # Forward pass (requires_grad=False: inference only)
    depth_anything_output = model_training.get_model_output(model=depth_anything_model, 
                                                            image_processor=depth_anything_image_processor, 
                                                            raw_image=non_linear_images_tensor, device=device, requires_grad=False)
    
    # normalize the depth map: min-max scale the prediction, then multiply by 20
    # NOTE(review): a constant prediction (max == min) would divide by zero here.
    model_output = depth_anything_output.cpu().detach().numpy()
    depth_anything_output =(((model_output-np.min(model_output)) / (np.max(model_output)-np.min(model_output)))*20.0)

    # algorithm_1 returns three estimates plus a validity flag; all four are printed and logged.
    hat_nu, hat_mu, hat_B_infty, valid_estimate = underwater_parameter_estimator.algorithm_1(d_D=depth_anything_output, I=linear_images, data_logger=data_logger, non_linear_image=non_linear_images)
    print("Estimated parameters: ", hat_nu, hat_mu, hat_B_infty, valid_estimate)
    data_logger.log_data((hat_nu, hat_mu, hat_B_infty, valid_estimate))


# Persist everything that was logged above.
data_logger.save_data()


Estimated parameters:  [9.56667519 4.6959245 ] 1.0 [0.20937552 0.35893391] True

## New PUDE Training Loop

In [None]:
import Pude_training_loop.loss_functions_torch as loss_functions
import Pude_training_loop.pude_utils as pude_utils
import Pude_training_loop.model_training as model_training
import Pude_training_loop.dataset_loader as data_loader
from Pude_training_loop.physics_parameter_estmation import UnderwaterParameterFinder
import torch
import numpy as np
from PIL import Image
import cv2

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import linregress
from PIL import Image


device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyper-parameter sweep: one complete training run (followed by the evaluation
# code further down in this loop body) per value of the third alpha weight of
# the PUDE loss.
for alpha3 in [2,4,6]:
    print("----------------------------------")
    print(f"Training for alpha3: {alpha3}")
    # Re-seed torch and numpy at the start of every run.
    torch.manual_seed(model_training.seed)
    np.random.seed(model_training.seed)

    # Two model/processor pairs are requested for every run:
    # depth_anything_model is only queried with requires_grad=False below,
    # new_pude_model is the one handed to the optimizer.
    depth_anything_model, new_pude_model, depth_anything_image_processor, new_pude_mode_image_processor = model_training.get_two_separate_model_pairs(model_path=model_training.models["depth_anything"], device=device)
    dataset_loader = data_loader.DatasetLoader() # Initialize dataset loader with default parameters #1e-15 smoothness
    # Previously tried loss settings, kept for reference:
    #   PUDELoss(betas=[1, 0.1, 0], alphas=[1, 0.1, 4])
    #   PUDELoss(betas=[20, 60], alphas=[1,0.1,1])
    #   PUDELoss(betas=[5, 0.0005])
    pude_loss_fn = loss_functions.PUDELoss(betas=[1, 0.1,0], alphas=[1, 0.1, alpha3])

    new_pude_model.train()

    # Define training parameters
    epochs = 3
    learning_rate = 1e-6
    batch_size = 1

    # Define optimizer
    optimizer = torch.optim.AdamW(new_pude_model.parameters(), lr=learning_rate)
    # Step decay: with step_size=1 and gamma=0.1 the learning rate is multiplied
    # by 0.1 after every epoch (scheduler.step() is called once per epoch below).
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    print("Training started")

    # Load the underwater params from the pickle file; they are indexed by
    # dataset position below (presumably written by the parameter
    # pre-calculation cell - confirm).
    underwater_param_file = "Results/SeaThru_Combined/Datalogger_params/params/parameter_results.pickle"
    underwater_params = pude_utils.load_pickle(underwater_param_file)
    # Dataset indices that are excluded from training.
    # (Earlier note: the SeaThruNeRF list was [5, 14, 25, 59, 60, 61].)
    skipped_images = [5,  14, 23, 24, 25, 28, 37, 38, 39, 41, 
                    42, 43, 45, 46, 51, 52, 53, 55, 56, 57, 59, 
                    61, 63, 64, 66, 67, 68, 69, 70, 71, 72, 73,  74, 75, 76, 77, 78, 81, 82, 
                    83, 85,  86, 87, 88, 89, 90, 93,94,  95, 96, 97, 107, 108, 112, 113, 118, 119, 122, 123, 124, 127, 130, 131, 171, 172, 173, 174,
                    175, 177, 179 ]

    max_images_to_process = len(dataset_loader)-len(skipped_images)
    print("Total images to process: ", max_images_to_process)

    # One random visiting order, drawn once and reused for every epoch.
    indices = np.random.permutation(len(dataset_loader))

    for epoch in range(epochs):
        n = 0  # number of non-skipped images visited so far in this epoch
        for i in indices:
            if i in skipped_images:
                continue
            n+=1

            non_linear_images, linear_images = dataset_loader[i]
            # linear_images = linear_images.to(device)
            # Similarly, convert non_linear_images to a PyTorch tensor
            non_linear_images_tensor = torch.tensor(non_linear_images)
            # Forward pass: the reference model without gradients, the trained model with.
            depth_anything_output = model_training.get_model_output(model=depth_anything_model,
                                                                    image_processor=depth_anything_image_processor,
                                                                    raw_image=non_linear_images_tensor, device=device, requires_grad=False)
            pude_output = model_training.get_model_output(model=new_pude_model,
                                                        image_processor=new_pude_mode_image_processor,
                                                        raw_image=non_linear_images_tensor, device=device)
            # Pre-computed parameter estimate for this image (fourth entry unused here).
            hat_nu, hat_mu, hat_B_infty, _ = underwater_params[i]

            # Optional per-image debug dump (disabled):
            # if (n)%1==0 or n==1:
            #     result_images = pude_utils.pude_display_image_with_depth(image=Image.fromarray(non_linear_images),
            #                                                              depth1=depth_anything_output.cpu().detach().numpy(),
            #                                                              depth2=pude_output.cpu().detach().numpy())
            #     result_images.save(f"Results/SeaThru_Combined/Pude_results/epoch{epoch}_{n}_image_{i}.png")
            # Loss calculation
            loss = pude_loss_fn( pude_output, depth_anything_output,
                                torch.tensor(linear_images, device=device, requires_grad=True),
                                torch.tensor(hat_nu, device=device, requires_grad=True),
                                torch.tensor(hat_mu, device=device, requires_grad=True),
                                torch.tensor(hat_B_infty, device=device, requires_grad=True).unsqueeze(1))
            # Skip the update when the loss is NaN. This must be a value test:
            # `loss.item() is torch.nan` compares object identity and is never
            # true for a freshly created float, so NaN losses used to slip
            # through to backward()/step().
            if torch.isnan(loss).any():
                print(f"Loss is nan for image {i}")
                continue

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Progress report plus a visual check on the last dataset item.
            if ((n)%20==0  or n==1):
                print(f"Epoch: {epoch}, Image: {n}, Loss: {loss.item()}")
                eval_image = model_training.evaluate(depth_anything_model=depth_anything_model, new_pude_model=new_pude_model,
                                depth_anything_image_processor=depth_anything_image_processor, new_pude_image_processor=new_pude_mode_image_processor,
                                non_linear_images=dataset_loader[-1][0])
                display(eval_image)
        print(f"Epoch: {epoch+1}")
        scheduler.step()
        # Save the model
        # torch.save(new_pude_model.state_dict(), f"new_pude_model_combined_edge_aware_betas_1_1_1_alphas_1_0-1_1_lr_{str(learning_rate)}_epoch_{str(epoch)}.pth")

    print("Training completed")
    print(f"Number of skipped images: {len(skipped_images)}")
    # print(f"Number of processed images: {min(60,len(dataset_loader)-len(skipped_images))}")

    # Save the model
    # torch.save(new_pude_model.state_dict(), "new_pude_model_2.pth")

    def plot_one_over_z_vs_d(actual_depth, model_output, save_folder=None, img_name="test image"):
        """
        Scatter-plot the normalised prediction ``d`` against ``1/z`` with a fitted line.

        Parameters:
            actual_depth (numpy.ndarray): Reference depth values ``z``.
            model_output (numpy.ndarray): Model prediction ``d``; must flatten to
                the same length as ``actual_depth``.
            save_folder: Currently unused (saving the figure is commented out).
            img_name (str): Label inserted into the plot title.

        Returns:
            float: ``rvalue`` of the linear regression of ``d`` on ``1/z`` after
            outlier removal.
        """
        depth_flat = actual_depth.flatten()
        pred_flat = model_output.flatten()

        # Scale each signal by its own maximum.
        pred_flat = pred_flat / np.max(pred_flat)
        depth_flat = depth_flat / np.max(depth_flat)

        # Only pixels whose normalised depth exceeds 0.15 take part
        # (this also drops zero-depth pixels before the inversion).
        keep = depth_flat > 0.15
        inv_depth = 1 / depth_flat[keep]
        pred = pred_flat[keep]

        plt.figure()

        def _line_fit(x, y):
            # Least-squares line plus 100 sample points of it for drawing.
            fit = linregress(x, y)
            xs = np.linspace(np.min(x), np.max(x), 100)
            ys = fit.slope * xs + fit.intercept
            return fit, xs, ys

        def _drop_outliers(x, y):
            # Discard points whose absolute residual to a first fit reaches
            # four standard deviations of those absolute residuals.
            fit, _, _ = _line_fit(x, y)
            deviation = np.abs(y - fit.slope * x - fit.intercept)
            limit = 4 * np.std(deviation)
            inliers = deviation < limit
            return x[inliers], y[inliers]

        inv_depth, pred = _drop_outliers(inv_depth, pred)
        fit, line_x, line_y = _line_fit(inv_depth, pred)

        # Scatter plot
        plt.scatter(inv_depth, pred, s=1)
        plt.xlabel("1/z")
        plt.ylabel("d")

        # The legend shows the correlation coefficient r (not r squared).
        plt.plot(line_x, line_y, '-r', label='Line of best fit, r = {:.3f}'.format(fit.rvalue))
        plt.ylim(bottom=0)
        plt.legend()
        plt.title(f"1/z vs d for {img_name}")
        plt.show()
        # plt.savefig(save_folder, dpi=100, bbox_inches='tight')
        plt.close()
        return fit.rvalue
    # correct image
    def correct_image(img):
        """
        Brighten an image whose mean brightness is below a minimum level.

        Parameters:
            img (numpy.ndarray): Input image on a 0-255 scale (the mean is
                normalised by 255). Any number of dimensions is accepted.

        Returns:
            numpy.ndarray: ``img`` itself when its normalised mean brightness is
            already at least 0.3, or when it contains no signal at all (all
            zeros / empty); otherwise the result of ``cv2.convertScaleAbs``
            with a gain of ``0.3 / brightness``.
        """
        # Define the target minimum brightness
        minimum_brightness = 0.3

        total = np.sum(img)
        # An all-black (or empty) image cannot be brightened by a gain, and its
        # brightness ratio would be 0, making the gain below a division by zero.
        if total == 0:
            return img

        # Mean brightness over every element, normalised to [0, 1].
        brightness = total / (255 * img.size)

        # Compute the brightness ratio
        ratio = brightness / minimum_brightness

        # If the ratio is greater than or equal to 1, the image is already bright enough
        if ratio >= 1:
            print("Image already bright enough")
            return img

        # Otherwise scale every element by the same gain to reach the target brightness
        corrected_img = cv2.convertScaleAbs(img, alpha=1 / ratio, beta=0)

        return corrected_img


    def white_balance_linear(img):
        """
        Brighten a dim image with one global gain and clip to [0, 255].

        Despite the name, no per-channel balancing is performed: a single
        gain derived from the overall mean brightness is applied to every element.

        Parameters:
            img (numpy.ndarray): Input image on a 0-255 scale (the mean is
                normalised by 255). Any number of dimensions is accepted.

        Returns:
            numpy.ndarray: ``img`` itself when its normalised mean brightness is
            already at least 0.25, or when it contains no signal at all (all
            zeros / empty); otherwise a new ``uint8`` array holding
            ``img * 0.25 / brightness`` clipped to [0, 255].
        """
        # Define the target minimum brightness
        minimum_brightness = 0.25

        total = np.sum(img)
        # An all-black (or empty) image would give a ratio of 0 and therefore
        # 0/0 = NaN pixels below; no gain can brighten it, so return it as is.
        if total == 0:
            return img

        # Mean brightness over every element, normalised to [0, 1].
        brightness = total / (255 * img.size)

        # Compute the brightness ratio
        ratio = brightness / minimum_brightness

        # If the ratio is greater than or equal to 1, the image is already bright enough
        if ratio >= 1:
            print("Image already bright enough")
            return img

        # Clip the scaled image to ensure pixel values remain in the valid range [0, 255]
        corrected_img = np.clip(img / ratio, 0, 255)

        return corrected_img.astype(np.uint8)


    def _prepare_depth(model_output):
        """Render a depth array as an inferno-coloured PIL image for display."""
        lowest = np.min(model_output)
        highest = np.max(model_output)
        # Min-max scale into 0..255, then quantise to 8 bits for the colour map.
        scaled = ((model_output - lowest) / (highest - lowest)) * 255
        eight_bit = scaled.astype("uint8")
        # The colour-mapped result is reversed along its last (channel) axis
        # before being wrapped in a PIL image.
        recoloured = cv2.applyColorMap(eight_bit, cv2.COLORMAP_INFERNO)[:, :, ::-1]
        return Image.fromarray(recoloured)
    

    # ---- Evaluate the freshly trained model on one SeaThru image with ground-truth depth ----
    Dataset_name = "D5"
    image_name = "LFT_3395"

    # Dataset_name = "D3"
    # image_name = "T_S04857"
    # Raw f-strings: the Windows-style backslashes are kept literally (same
    # runtime value as before) without relying on invalid escape sequences such
    # as "\S" or "\d", which newer Python versions warn about.
    test_image_path = rf"Datasets\SeaThru\{Dataset_name}\linearPNG\{image_name}.png"
    # test_image_path = f"Datasets\SeaThru_old\{Dataset_name}\Raw\{image_name}.NEF"
    # test_image_path = f"Datasets\SeaThru_old\{Dataset_name}\Raw\{image_name}.ARW"
    # test_image_path = f"Datasets\SeaThruNeRF\Curasao\images_wb\MTN_1288.png"
    ground_truth_path = rf"Datasets\SeaThru\{Dataset_name}\depth\depth{image_name}.tif"
    # Note: the depth map is opened with the two image dimensions in swapped order.
    actual_depth = dataset_loader.open_depth_map(path=ground_truth_path, img_dim=(model_training.default_image_dim[1], model_training.default_image_dim[0]))
    # Only the first element returned by _open_image is used.
    non_linear_image = dataset_loader._open_image(test_image_path, img_dim=model_training.default_image_dim)[0]

    # Brighten the image if it is too dark, then keep both array and PIL forms.
    non_linear_image_np = correct_image(non_linear_image)
    # non_linear_image_np = non_linear_image
    non_linear_image = Image.fromarray(non_linear_image_np)

    display(pude_utils.make_image_grid([non_linear_image, _prepare_depth(actual_depth)], rows=1, cols=2))

    non_linear_image_tensor = torch.tensor(non_linear_image_np)

    # Inference with both models (no gradients needed).
    pude_output = model_training.get_model_output(model=new_pude_model,
                                                        image_processor=new_pude_mode_image_processor,
                                                        raw_image=non_linear_image_tensor, device=device, requires_grad=False)
    depth_anything_output = model_training.get_model_output(model=depth_anything_model,
                                                                    image_processor=depth_anything_image_processor,
                                                                    raw_image=non_linear_image_tensor, device=device, requires_grad=False)
    pude_utils.display_image_with_depth(image=_prepare_depth(actual_depth), depth1=depth_anything_output.cpu().detach().numpy(), depth2=pude_output.cpu().detach().numpy())

    # Correlation of each prediction with inverse ground-truth depth.
    plot_one_over_z_vs_d(actual_depth, pude_output.cpu().detach().numpy(), img_name="pude_output")
    plot_one_over_z_vs_d(actual_depth, depth_anything_output.cpu().detach().numpy(), img_name="depth_anything_output")







In [None]:
# Persist the weights of the current new_pude_model (the one left over from the last sweep iteration).
# NOTE(review): the file name encodes alphas_0.5_0.1_4, while the training cell uses
# alphas=[1, 0.1, alpha3] with alpha3 in [2, 4, 6] - confirm the name matches the run.
torch.save(new_pude_model.state_dict(), "best_new_pude_model_Adam_betas_1_0.1_0_alphas_0.5_0.1_4_lr_1e-6_epoch_3.pth")

## New stereo + pude training loop

In [None]:
import Pude_training_loop.loss_functions_torch as loss_functions
import Pude_training_loop.pude_utils as pude_utils
import Pude_training_loop.model_training as model_training
import Pude_training_loop.dataset_loader as data_loader
from monocular_stereo_matching.stereo_pair_gen import Stereo_Pair_Generator
import torch
from PIL import Image
import pickle
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"


# Weights file of the trained PUDE model; loaded below as the depth model.
pude_path = "best_new_pude_model_Adam_betas_1_0.1_0_alphas_0.5_0.1_4_lr_1e-6_epoch_3.pth"

unimatch_model, unimatch_image_processor = model_training.get_model_image_processor_pair(model_name="unimatch", model_path=model_training.models["unimatch"], device=device)
# NOTE: despite its name, depth_anything_model holds the "new_pude" model loaded from pude_path.
depth_anything_model, depth_anything_image_processor = model_training.get_model_image_processor_pair(model_name="new_pude", model_path=pude_path, device=device)
# depth_anything_model, new_pude_model, depth_anything_image_processor, new_pude_mode_image_processor = model_training.get_two_separate_model_pairs(model_path=model_training.models["depth_anything"], device=device)
dataset_loader = data_loader.DatasetLoader() # Initialize dataset loader with default parameters
# Previously tried loss settings, kept for reference:
#   PUDELoss(betas=[20, 60], alphas=[1,0.1,1])
#   PUDELoss(betas=[6, 0.0005])
loss_fn = loss_functions.PUDELoss(betas=[1, 20, 4], alphas=[1,0.1,4])
stereo_pair_gen = Stereo_Pair_Generator(image_dim=model_training.default_image_dim)


# Define training parameters
epochs = 5
learning_rate = 5e-6
batch_size = 1

# Only the stereo-matching model is optimised in this cell.
unimatch_model.train()
# Define optimizer
optimizer = torch.optim.AdamW(unimatch_model.parameters(), lr=learning_rate)
# Cosine annealing over the full number of epochs (stepped once per epoch below).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)



# seed the torch
torch.manual_seed(model_training.seed)
np.random.seed(model_training.seed)

print("Training started")

# unimatch_outputs = []

#  load the underwater params from pickle file; indexed by dataset position below
underwater_param_file = "Results/SeaThru_Combined/Datalogger_params/params/parameter_results.pickle"
underwater_params = pude_utils.load_pickle(underwater_param_file)
# Dataset indices that are excluded from training.
skipped_images = [5,  14, 23, 24, 25, 28, 37, 38, 39, 41, 
                  42, 43, 45, 46, 51, 52, 53, 55, 56, 57, 59, 
                  61, 63, 64, 66, 67, 68, 69, 70, 71, 72, 73,  74, 75, 76, 77, 78, 81, 82, 
                  83, 85,  86, 87, 88, 89, 90, 93,94,  95, 96, 97, 107, 108, 112, 113, 118, 119, 122, 123, 124, 127, 130, 131, 171, 172, 173, 174,
                  175, 177, 179 ]

print("Total images to process: ", len(dataset_loader)-len(skipped_images))

# create a random list of indices to shuffle the images in the dataset
# (drawn once; the same order is reused for every epoch)
indices = np.random.permutation(len(dataset_loader))

# max_images_to_process = 64
for epoch in range(epochs):
    n=0  # number of non-skipped images visited so far in this epoch
    for i in indices:
        # if (i+1 - len(skipped_images))>max_images_to_process:
        #     break
        if i in skipped_images:
            continue
        n+=1
        optimizer.zero_grad()
        non_linear_images, linear_images = dataset_loader[i]
        # linear_images = linear_images.to(device)
        # Similarly, convert non_linear_images to a PyTorch tensor
        non_linear_images_tensor = torch.tensor(non_linear_images)
        # Forward pass of the (frozen) depth model
        depth_anything_output = model_training.get_model_output(model=depth_anything_model,
                                                                image_processor=depth_anything_image_processor,
                                                                raw_image=non_linear_images_tensor, device=device, requires_grad=False)
        # parameter estimation
        # depth_anything_output =(((depth_anything_output-torch.min(depth_anything_output)) / (torch.max(depth_anything_output)-torch.min(depth_anything_output)))*20.0)

        # for pude model rescale it between 0 and 20

        # generate stereo pair #depth anything 1.5 scaling #pude 0.005 scaling
        # A fresh scaling factor is drawn uniformly from [0.5, 5) for every image.
        scaling_factor = np.random.uniform(0.5, 5)
        # scaling_factor = np.random.randint(0.8, 5) # [2, 5)
        image_2 = stereo_pair_gen.generate_stereo_pair(non_linear_images, depth_anything_output.cpu().detach().numpy(), scaling_factor=scaling_factor)

        # The original image and the generated second view form the stereo input.
        unimatch_input = {"image1": non_linear_images, "image2": image_2}
        unimatch_output = model_training.get_model_output(model=unimatch_model,
                                                     image_processor=unimatch_image_processor,
                                                     raw_image=unimatch_input, device=device)

        # unimatch_outputs.append(unimatch_output.cpu().detach().numpy())
        # A comparison image is written for every processed image; the former
        # guard `n % 1 == 0 or n == 1` was always true, so it has been dropped.
        result_images = pude_utils.unimatch_display_image_with_depth(image=Image.fromarray(non_linear_images),
                                                        depth1=depth_anything_output.cpu().detach().numpy(),
                                                        depth2=unimatch_output.cpu().detach().numpy(),
                                                        shifted_image=Image.fromarray(image_2), scaling_factor=scaling_factor)
        # save the image
        result_images.save(f"Results/SeaThru_Combined/Unimatch_results/epoch{epoch}_{n}_image_{i}.png")

        hat_nu, hat_mu, hat_B_infty, _ = underwater_params[i]
        # Loss calculation
        loss = loss_fn( unimatch_output, depth_anything_output,
                            torch.tensor(linear_images, device=device, requires_grad=True),
                            torch.tensor(hat_nu, device=device, requires_grad=True),
                            torch.tensor(hat_mu, device=device, requires_grad=True),
                            torch.tensor(hat_B_infty, device=device, requires_grad=True).unsqueeze(1))

        # Same safeguard the PUDE training loop is meant to have: a NaN loss
        # would propagate NaN gradients into the weights, so skip the update.
        if torch.isnan(loss).any():
            print(f"Loss is nan for image {i}")
            continue

        # Backward pass
        loss.backward()
        optimizer.step()


        # Periodic visual check on the last dataset item.
        if ((n)%50==0):
            eval_image = model_training.evaluate_unimatch(depth_anything_model=depth_anything_model, unimatch_model=unimatch_model,
                            depth_anything_image_processor=depth_anything_image_processor, unimatch_image_processor=unimatch_image_processor,
                            stereo_pair_gen=stereo_pair_gen,non_linear_images=dataset_loader[-1][0], scaling_factor=2)
            display(eval_image)
    scheduler.step()
    # Checkpoint after every epoch; the file name carries the zero-based epoch index.
    torch.save(unimatch_model.state_dict(), f"new_unimatch_model_edge_aware_loss_cosine_scheduler_AdamW_betas_1_20_4_alphas_1_0.1_4_lr_{str(learning_rate)}_epoch_{str(epoch)}.pth")


    # The reported loss is the one of the last image processed in this epoch.
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
    # evaluate the model and show image grid
    eval_image = model_training.evaluate_unimatch(depth_anything_model=depth_anything_model, unimatch_model=unimatch_model,
                depth_anything_image_processor=depth_anything_image_processor, unimatch_image_processor=unimatch_image_processor,
                stereo_pair_gen=stereo_pair_gen,non_linear_images=dataset_loader[-1][0], scaling_factor=2)
    display(eval_image)

# save the unimatch outputs as pickle using pickle 
# pickle_file = "Results/SeaThru_Combined/unimatch_outputs.pickle"
# with open(pickle_file, 'wb') as f:
#     pickle.dump(unimatch_outputs, f)

print("Training completed")
print(f"Number of skipped images: {len(skipped_images)}")
# print(f"Number of processed images: {min(60,len(dataset_loader)-len(skipped_images))}")



# Save the model
# torch.save(new_pude_model.state_dict(), "new_pude_model_2.pth")

In [None]:
# Persist the current unimatch_model weights.
# NOTE(review): the file name encodes betas_60_60 / alphas_4_0-4_8 / epochs_3, while the
# training cell uses betas=[1, 20, 4], alphas=[1,0.1,4] and epochs = 5 - confirm the name matches the run.
torch.save(unimatch_model.state_dict(), "new_unimatch_model_cosine_scheduler_AdamW_betas_60_60,_alphas_4_0-4_8_lr_5e-6_epochs_3.pth")

## eval

In [None]:
def eval_new_pude():
    """
    Compare the trained PUDE model with the reference depth model on one image.

    Reads the notebook globals ``dataset_loader``, ``model_training``,
    ``pude_utils``, ``torch``, ``device``, ``new_pude_model``,
    ``new_pude_mode_image_processor``, ``depth_anything_model`` and
    ``depth_anything_image_processor``, so the training cells must have run
    first. Shows the input next to the ground-truth depth, calls the
    side-by-side depth display helper and plots ``d`` against ``1/z`` for both
    models. Returns nothing.
    """
    import cv2

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.stats import linregress
    from PIL import Image


    def plot_one_over_z_vs_d(actual_depth, model_output, save_folder=None, img_name="test image"):
        """
        Scatter-plot the normalised prediction ``d`` against ``1/z`` with a fitted line.

        ``save_folder`` is currently unused (saving is commented out). Returns
        the ``rvalue`` of the regression after outlier removal.
        """
        z = actual_depth.flatten()
        d = model_output.flatten()

        # Scale each signal by its own maximum.
        d = d/np.max(d)
        z = z/np.max(z)

        # Only pixels whose normalised depth exceeds 0.15 take part
        # (this also drops zero-depth pixels before the inversion).
        keep = np.where(z > 0.15)
        z = 1/z[keep]
        d = d[keep]

        plt.figure()

        def _fit_lingress(z,d):
            # Least-squares line plus 100 sample points of it for drawing.
            result = linregress(z, d)
            fit_x = np.linspace(np.min(z), np.max(z), 100)
            fit_y = result.slope * fit_x + result.intercept
            return result, fit_x, fit_y

        def _remove_outlier(z,d):
            # Discard points whose absolute residual to a first fit reaches
            # four standard deviations of those absolute residuals.
            result, _, _ = _fit_lingress(z,d)
            abs_residuals = np.abs(d - result.slope * z - result.intercept)
            threshold = 4 * np.std(abs_residuals)
            non_outliers_mask = abs_residuals < threshold
            return z[non_outliers_mask], d[non_outliers_mask]

        z, d = _remove_outlier(z,d)
        result, fit_x, fit_y = _fit_lingress(z,d)

        # Scatter plot
        plt.scatter(z, d, s=1)
        plt.xlabel("1/z")
        plt.ylabel("d")

        # The legend shows the correlation coefficient r (not r squared).
        plt.plot(fit_x, fit_y, '-r', label='Line of best fit, r = {:.3f}'.format(result.rvalue))
        plt.ylim(bottom=0)
        plt.legend()
        plt.title(f"1/z vs d for {img_name}")
        plt.show()
        # plt.savefig(save_folder, dpi=100, bbox_inches='tight')
        plt.close()
        return result.rvalue

    def correct_image(img):
        """
        Brighten an image whose mean brightness is below a minimum level.

        Parameters:
            img (numpy.ndarray): Input image on a 0-255 scale.

        Returns:
            numpy.ndarray: ``img`` itself when its normalised mean brightness is
            already at least 0.3 or when it holds no signal (all zeros / empty);
            otherwise the result of ``cv2.convertScaleAbs`` with a gain of
            ``0.3 / brightness``.
        """
        minimum_brightness = 0.3

        total = np.sum(img)
        # All-black or empty input: the ratio would be 0 and the gain a division by zero.
        if total == 0:
            return img

        # Mean brightness over every element, normalised to [0, 1].
        brightness = total / (255 * img.size)
        ratio = brightness / minimum_brightness

        # If the ratio is greater than or equal to 1, the image is already bright enough
        if ratio >= 1:
            print("Image already bright enough")
            return img

        # Otherwise scale every element by the same gain to reach the target brightness
        return cv2.convertScaleAbs(img, alpha=1 / ratio, beta=0)


    def white_balance_linear(img):
        """
        Brighten a dim image with one global gain and clip to [0, 255].

        Not called in this function at present. Despite the name, no
        per-channel balancing is performed. Returns ``img`` itself when it is
        bright enough (normalised mean >= 0.25) or holds no signal, otherwise a
        new ``uint8`` array.
        """
        minimum_brightness = 0.25

        total = np.sum(img)
        # All-black or empty input would give 0/0 = NaN pixels below.
        if total == 0:
            return img

        brightness = total / (255 * img.size)
        ratio = brightness / minimum_brightness

        if ratio >= 1:
            print("Image already bright enough")
            return img

        # Clip the scaled image to ensure pixel values remain in the valid range [0, 255]
        corrected_img = np.clip(img / ratio, 0, 255)

        # # Alternative: cv2.convertScaleAbs(img, alpha=1 / ratio, beta=0)

        return corrected_img.astype(np.uint8)


    def _prepare_depth(model_output):
        """Render a depth array as an inferno-coloured PIL image for display."""
        # Min-max scale into 0..255 and quantise to 8 bits for the colour map.
        formatted = (((model_output-np.min(model_output)) / (np.max(model_output)-np.min(model_output)))*255).astype("uint8")
        # The colour-mapped result is reversed along its last (channel) axis.
        colored_depth = cv2.applyColorMap(formatted, cv2.COLORMAP_INFERNO)[:, :, ::-1]
        return Image.fromarray(colored_depth)


    Dataset_name = "D5"
    image_name = "LFT_3395"

    # Dataset_name = "D3"
    # image_name = "T_S04857"
    # Raw f-strings: the Windows-style backslashes are kept literally (same
    # runtime value as before) without relying on invalid escape sequences such
    # as "\S" or "\d", which newer Python versions warn about.
    test_image_path = rf"Datasets\SeaThru\{Dataset_name}\linearPNG\{image_name}.png"
    # test_image_path = f"Datasets\SeaThru_old\{Dataset_name}\Raw\{image_name}.NEF"
    # test_image_path = f"Datasets\SeaThru_old\{Dataset_name}\Raw\{image_name}.ARW"
    # test_image_path = f"Datasets\SeaThruNeRF\Curasao\images_wb\MTN_1288.png"
    ground_truth_path = rf"Datasets\SeaThru\{Dataset_name}\depth\depth{image_name}.tif"
    # Note: the depth map is opened with the two image dimensions in swapped order.
    actual_depth = dataset_loader.open_depth_map(path=ground_truth_path, img_dim=(model_training.default_image_dim[1], model_training.default_image_dim[0]))
    # Only the first element returned by _open_image is used.
    non_linear_image = dataset_loader._open_image(test_image_path, img_dim=model_training.default_image_dim)[0]

    # Brighten the image if it is too dark, then keep both array and PIL forms.
    non_linear_image_np = correct_image(non_linear_image)
    # non_linear_image_np = non_linear_image
    non_linear_image = Image.fromarray(non_linear_image_np)

    display(pude_utils.make_image_grid([non_linear_image, _prepare_depth(actual_depth)], rows=1, cols=2))

    non_linear_image_tensor = torch.tensor(non_linear_image_np)

    # Inference with both models (no gradients needed).
    pude_output = model_training.get_model_output(model=new_pude_model,
                                                        image_processor=new_pude_mode_image_processor,
                                                        raw_image=non_linear_image_tensor, device=device, requires_grad=False)
    depth_anything_output = model_training.get_model_output(model=depth_anything_model,
                                                                    image_processor=depth_anything_image_processor,
                                                                    raw_image=non_linear_image_tensor, device=device, requires_grad=False)
    pude_utils.display_image_with_depth(image=_prepare_depth(actual_depth), depth1=depth_anything_output.cpu().detach().numpy(), depth2=pude_output.cpu().detach().numpy())

    # Correlation of each prediction with inverse ground-truth depth.
    plot_one_over_z_vs_d(actual_depth, pude_output.cpu().detach().numpy(), img_name="pude_output")
    plot_one_over_z_vs_d(actual_depth, depth_anything_output.cpu().detach().numpy(), img_name="depth_anything_output")



In [None]:
import cv2

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import linregress
from PIL import Image


def plot_one_over_z_vs_d(actual_depth, model_output, save_folder=None, img_name="test image"):
    """
    Scatter-plot inverse ground-truth depth against predicted depth with a fitted line.

    Both inputs are flattened and divided by their own maximum. Only positions
    whose normalised ground-truth depth exceeds 0.15 are kept; for those the
    x value is ``1/z`` and the y value is the normalised prediction ``d``.
    A first least-squares fit is used to drop points whose absolute residual is
    not below four standard deviations of the absolute residuals, and a second
    fit on the remaining points is drawn over the scatter plot.

    Parameters:
        actual_depth (numpy.ndarray): Ground-truth depth map.
        model_output (numpy.ndarray): Predicted depth map. It is indexed with
            the positions selected from ``actual_depth``, so both are expected
            to hold the same number of elements.
        save_folder: Accepted but currently unused; the figure is only shown.
        img_name (str): Label inserted into the plot title.

    Returns:
        float: ``rvalue`` (correlation coefficient) of the final regression.
    """
    truth_flat = actual_depth.flatten()
    pred_flat = model_output.flatten()

    # Bring both signals to a common scale by dividing by their own maximum.
    pred_flat = pred_flat / np.max(pred_flat)
    truth_flat = truth_flat / np.max(truth_flat)

    # Keep only positions whose normalised ground-truth depth is above 0.15,
    # then switch the ground truth to inverse depth.
    keep = np.where(truth_flat > 0.15)
    inv_depth = 1 / truth_flat[keep]
    pred = pred_flat[keep]

    plt.figure()

    # Pass 1: fit on everything that survived the depth threshold and measure
    # how far each point lies from that line.
    first_fit = linregress(inv_depth, pred)
    deviation = np.abs(pred - first_fit.slope * inv_depth - first_fit.intercept)

    # Discard points whose deviation is not below 4x the spread of deviations.
    inliers = deviation < 4 * np.std(deviation)
    inv_depth, pred = inv_depth[inliers], pred[inliers]

    # Pass 2: final fit on the inliers, sampled at 100 points for drawing.
    final_fit = linregress(inv_depth, pred)
    line_x = np.linspace(np.min(inv_depth), np.max(inv_depth), 100)
    line_y = final_fit.slope * line_x + final_fit.intercept

    plt.scatter(inv_depth, pred, s=1)
    plt.xlabel("1/z")
    plt.ylabel("d")

    # The legend reports the correlation coefficient r (not r squared).
    plt.plot(line_x, line_y, '-r', label='Line of best fit, r = {:.3f}'.format(final_fit.rvalue))
    plt.ylim(bottom=0)
    plt.legend()
    plt.title(f"1/z vs d for {img_name}")
    plt.show()
    # NOTE: saving to disk is disabled; `save_folder` is ignored.
    plt.close()
    return final_fit.rvalue


# Select the SeaThru sample to evaluate.
Dataset_name = "D5"
image_name = "LFT_3395"

# Alternative sample:
# Dataset_name = "D3"
# image_name = "T_S04857"

# Raw f-strings (rf"...") keep every backslash literal. In a plain f-string
# sequences such as "\S", "\l" and "\d" are invalid escapes, which newer Python
# versions report as a SyntaxWarning. The resulting path strings are unchanged.
test_image_path = rf"Datasets\SeaThru\{Dataset_name}\linearPNG\{image_name}.png"
# test_image_path = rf"Datasets\SeaThru_old\{Dataset_name}\Raw\{image_name}.NEF"
# test_image_path = rf"Datasets\SeaThru_old\{Dataset_name}\Raw\{image_name}.ARW"
# test_image_path = r"Datasets\SeaThruNeRF\Curasao\images_wb\MTN_1288.png"
ground_truth_path = rf"Datasets\SeaThru\{Dataset_name}\depth\depth{image_name}.tif"

# The depth map is requested with the two default image dimensions swapped
# relative to the image load below.
# NOTE(review): presumably one loader expects (width, height) and the other
# (height, width) - confirm against dataset_loader.
actual_depth = dataset_loader.open_depth_map(
    path=ground_truth_path,
    img_dim=(model_training.default_image_dim[1], model_training.default_image_dim[0]),
)
# Only the first element returned by _open_image is used here.
non_linear_image = dataset_loader._open_image(test_image_path, img_dim=model_training.default_image_dim)[0]
# correct image
def correct_image(img, minimum_brightness=0.3):
    """
    Brighten an image whose mean intensity is below a target level.

    The mean brightness is the sum of all samples divided by
    ``255 * img.size`` (the 255 divisor assumes 8-bit sample values). When
    that brightness is below ``minimum_brightness`` every sample is multiplied
    by ``minimum_brightness / brightness`` via ``cv2.convertScaleAbs``.

    Parameters:
        img (numpy.ndarray): Input image, e.g. H x W x C colour or H x W
            grayscale.
        minimum_brightness (float): Target mean brightness as a fraction of
            255. Defaults to 0.3.

    Returns:
        numpy.ndarray: ``img`` itself when nothing can or needs to be done
        (empty image, all-black image, or already bright enough); otherwise
        the rescaled image produced by ``cv2.convertScaleAbs``.
    """
    # Nothing to measure or scale in an empty array.
    if img.size == 0:
        return img

    # Average brightness over every sample, independent of the number of axes.
    brightness = np.sum(img) / (255 * img.size)

    # An all-black image stays black under any gain; returning early also
    # avoids dividing by a zero ratio below.
    if brightness <= 0:
        return img

    # Fraction of the target brightness the image currently reaches.
    ratio = brightness / minimum_brightness

    # If the ratio is greater than or equal to 1, the image is already bright enough
    if ratio >= 1:
        print("Image already bright enough")
        return img

    # Otherwise apply a uniform gain of 1/ratio so the mean reaches the target.
    corrected_img = cv2.convertScaleAbs(img, alpha=1 / ratio, beta=0)

    return corrected_img


def white_balance_linear(img, minimum_brightness=0.25):
    """
    Brighten an image with a single uniform gain, using NumPy only.

    Despite its name, this function does not balance colour channels against
    each other: it measures the mean brightness over all samples (sum divided
    by ``255 * img.size``, i.e. 8-bit sample values are assumed) and, when
    that is below ``minimum_brightness``, divides every sample by
    ``brightness / minimum_brightness``. The result is clipped to [0, 255] and
    truncated to ``uint8``.

    Parameters:
        img (numpy.ndarray): Input image, e.g. H x W x C colour or H x W
            grayscale.
        minimum_brightness (float): Target mean brightness as a fraction of
            255. Defaults to 0.25.

    Returns:
        numpy.ndarray: ``img`` itself when nothing can or needs to be done
        (empty image, all-black image, or already bright enough); otherwise a
        new ``uint8`` array with the gain applied.
    """
    # Nothing to measure or scale in an empty array.
    if img.size == 0:
        return img

    # Average brightness over every sample, independent of the number of axes.
    brightness = np.sum(img) / (255 * img.size)

    # An all-black image stays black under any gain; returning early also
    # avoids the 0/0 division that would otherwise produce NaNs.
    if brightness <= 0:
        return img

    # Fraction of the target brightness the image currently reaches.
    ratio = brightness / minimum_brightness

    # If the ratio is greater than or equal to 1, the image is already bright enough
    if ratio >= 1:
        print("Image already bright enough")
        return img

    # Apply the gain, then clip so the values stay in the valid range [0, 255].
    corrected_img = np.clip(img / ratio, 0, 255)

    # astype truncates the fractional part rather than rounding.
    return corrected_img.astype(np.uint8)


def _prepare_depth(model_output):
    """
    Render a depth map as a colour image for visualisation.

    The values are min-max scaled to 0..255, truncated to ``uint8``, coloured
    with OpenCV's INFERNO colormap, and the channel axis is reversed before
    the array is wrapped in a PIL image.

    Parameters:
        model_output (numpy.ndarray): Depth values to visualise.

    Returns:
        PIL.Image.Image: Colour-mapped rendering of ``model_output``.
    """
    lowest = np.min(model_output)
    value_range = np.max(model_output) - lowest

    if value_range == 0:
        # A constant map has no range to stretch; scaling it would compute 0/0
        # and cast NaNs to uint8. Render it as a uniform image instead.
        formatted = np.zeros(np.shape(model_output), dtype="uint8")
    else:
        # format the image to be between 0 and 255
        formatted = (((model_output - lowest) / value_range) * 255).astype("uint8")

    # Reverse the last axis so OpenCV's channel order is flipped before
    # handing the array to PIL.
    colored_depth = cv2.applyColorMap(formatted, cv2.COLORMAP_INFERNO)[:, :, ::-1]
    return Image.fromarray(colored_depth)

# Brighten the loaded image if needed and keep both array and PIL versions.
non_linear_image_np = correct_image(non_linear_image)
# non_linear_image_np = non_linear_image
non_linear_image = Image.fromarray(non_linear_image_np)

# Show the input image next to the colour-mapped ground-truth depth.
display(pude_utils.make_image_grid([non_linear_image, _prepare_depth(actual_depth)], rows=1, cols=2))

non_linear_image_tensor = torch.tensor(non_linear_image_np)

# Earlier PUDE vs. Depth Anything comparison, kept for reference:
# pude_output = model_training.get_model_output(model=new_pude_model, 
#                                                      image_processor=new_pude_mode_image_processor, 
#                                                      raw_image=non_linear_image_tensor, device=device, requires_grad=False)
# depth_anything_output = model_training.get_model_output(model=depth_anything_model, 
#                                                                 image_processor=depth_anything_image_processor, 
#                                                                 raw_image=non_linear_image_tensor, device=device, requires_grad=False)
# pude_utils.display_image_with_depth(image=_prepare_depth(actual_depth), depth1=depth_anything_output.cpu().detach().numpy(), depth2=pude_output.cpu().detach().numpy())

# plot_one_over_z_vs_d(actual_depth, pude_output.cpu().detach().numpy(), img_name="pude_output")
# plot_one_over_z_vs_d(actual_depth, depth_anything_output.cpu().detach().numpy(), img_name="depth_anything_output")


# get model outputs
depth_anything_output = model_training.get_model_output(model=depth_anything_model, 
                                        image_processor=depth_anything_image_processor, 
                                        raw_image=non_linear_image_tensor, requires_grad=False)

# Build the second view of a stereo pair from the image and the Depth Anything output.
image_2 = stereo_pair_gen.generate_stereo_pair(non_linear_image_np, depth_anything_output.cpu().detach().numpy(), scaling_factor=1)

# UniMatch receives both views as a dict.
unimatch_input = {"image1": non_linear_image_np, "image2": image_2}

unimatch_output = model_training.get_model_output(model=unimatch_model,
                                image_processor=unimatch_image_processor,
                                raw_image=unimatch_input, requires_grad=False)

image_grid = pude_utils.unimatch_display_image_with_depth(image=non_linear_image,
                                                            depth1=depth_anything_output.cpu().detach().numpy(),
                                                            depth2=unimatch_output.cpu().detach().numpy(),
                                                            shifted_image=Image.fromarray(image_2))
display(image_grid)

# Compare each model's output against inverse ground-truth depth. Each plot is
# titled after the output it actually shows (the second one was previously
# labelled "pude_output" although it plots the Depth Anything output).
plot_one_over_z_vs_d(actual_depth, unimatch_output.cpu().detach().numpy(), img_name="unimatch_output")
plot_one_over_z_vs_d(actual_depth, depth_anything_output.cpu().detach().numpy(), img_name="depth_anything_output")



## Save the UniMatch model weights

In [None]:

# torch.save(unimatch_model.state_dict(), "new_unimatch_model_pude_AdamW_beta_60_600_alphas_4_0_4_epochs_3_lr_1e-6.pth")