# Caustics Removal Notebook

In this notebook we apply the processing pipeline described in the paper *Computer Vision Corrections Enhance UAV-Based Retrievals in Shallow Waters* to remove caustics and obtain a clear image of the sea bottom. The pipeline takes a video as input and processes its individual frames.

Please cite our work if this notebook is of any benefit to your work!

## Initial Declarations and Constants

In [1]:
from color_transfer import read_file, get_mean_and_std, color_transfer
from skimage.metrics import structural_similarity as ssim
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import imageio.v2 as imageio
from PIL import Image
import seaborn as sns
import pandas as pd
import numpy as np
import os, PIL
import shutil
import glob
import cv2
import re

In [2]:
## FOLDER STRUCTURE
# Be sure to follow this layout, or change the pattern at your convenience.
# Paths are built with os.path.join so they carry a single, platform-correct
# separator (a raw f-string with '\\' yields two literal backslashes).

PARENT_FOLDER = 'Caustics_Removal' # The main folder

FRAMES_DIRECTORY = os.path.join(PARENT_FOLDER, '1_Frame_Extraction')  # raw frames per video
CT = os.path.join(PARENT_FOLDER, '2_Color_Transfer')                  # color-corrected frames per video
MEDIAN = os.path.join(PARENT_FOLDER, '3_Median')                      # median/average images per video
RESULTS = os.path.join(PARENT_FOLDER, '4_Results')                    # figures and final images per video

TARGET_IMAGE = os.path.join(PARENT_FOLDER, '5_Target')                # folder holding the target image
TARGET_IAMGE = TARGET_IMAGE  # misspelled name kept because other cells reference it

## Functions

In [3]:
def sorted_alphanumeric(data):
    '''Return the items of ``data`` sorted in natural (human) order.

    Each item is split into runs of ASCII digits and runs of other text.
    Digit runs are compared as integers and the remaining text is compared
    case-insensitively, so ``Frame2`` sorts before ``Frame10``.
    '''
    def natural_key(name):
        # re.split with a capturing group keeps the digit runs in the result.
        pieces = []
        for chunk in re.split('([0-9]+)', name):
            pieces.append(int(chunk) if chunk.isdigit() else chunk.lower())
        return pieces

    return sorted(data, key=natural_key)

def find_mp4_files(root_dir):
    '''Recursively collect the paths of all MP4 videos below ``root_dir``.

    The extension is matched case-insensitively, so ``.MP4``, ``.mp4`` and
    mixed-case variants are all found. Directories and files are visited in
    sorted order so the result does not depend on filesystem ordering.

    Parameters
    ----------
    root_dir : str
        Directory to walk. A missing directory yields an empty list.

    Returns
    -------
    list of str
        Paths built as ``os.path.join(<containing directory>, <file name>)``.
    '''
    mp4_files = []
    for root, dirs, files in os.walk(root_dir):
        dirs.sort()  # in-place sort makes os.walk descend in a stable order
        for file in sorted(files):
            if file.lower().endswith('.mp4'):
                mp4_files.append(os.path.join(root, file))
    return mp4_files

def frame_extr(path_video,saving_path):
    '''Extract every frame of a video into ``saving_path`` as ``Frame<N>.png``.

    Frames are numbered from 0 in reading order. If ``Frame0.png`` already
    exists in ``saving_path`` the extraction is skipped and the video is not
    opened at all.

    Parameters
    ----------
    path_video : str
        Path of the video file to read.
    saving_path : str
        Directory that receives the frames; created if missing.
    '''
    # Creating Directory
    os.makedirs(saving_path, exist_ok=True)

    # Sanity Check: done before opening the video so a re-run costs nothing
    if os.path.exists(os.path.join(saving_path, 'Frame0.png')):
        print(f'Frames were already extracted from the video, check folder: {saving_path}')
        return

    # Frames Extraction
    video = cv2.VideoCapture(path_video)
    count = 0
    try:
        success, image = video.read()
        while success:
            cv2.imwrite(os.path.join(saving_path, f'Frame{count}.png'), image) # Extracting Frames
            success, image = video.read()
            count += 1
    finally:
        # Always free the capture handle, even if writing a frame fails.
        video.release()

    if count == 0:
        print(f'WARNING: no frames could be read from: {path_video}')
    print(f'Finished Writing Frames...')
        
def color_transf_frames(load_path,save_path,frames=None):
    '''Color-correct the extracted frames of one video against its first frame.

    ``Frame0.png`` is used as the target and copied unchanged to
    ``Frame0_color_corrected.png``. Every following frame is passed, together
    with the target, to ``color_transfer`` and the result is written as
    ``Frame<i>_color_corrected.png``, where ``i`` is the frame's position in
    the naturally sorted listing of ``load_path``.

    Parameters
    ----------
    load_path : str
        Directory containing ``Frame<N>.png`` files.
    save_path : str
        Directory that receives the corrected frames; created if missing.
    frames : int or None
        Maximum number of frames to process (including frame 0). ``None`` or
        0 processes every frame found.
    '''
    os.makedirs(save_path, exist_ok=True)

    # color_transfer is called with single-element lists of paths
    # (NOTE(review): its exact contract lives in the color_transfer module).
    target = [os.path.join(load_path, 'Frame0.png')]

    # Keep only real frame files so stray entries (e.g. Thumbs.db) are never
    # handed to color_transfer and cannot shift the output numbering.
    frame_files = [name for name in sorted_alphanumeric(os.listdir(load_path))
                   if re.fullmatch(r'Frame\d+\.png', name)]
    if frames:
        frame_files = frame_files[:frames]

    for i, img in enumerate(frame_files):
        if i == 0:
            shutil.copyfile(target[0], os.path.join(save_path, 'Frame0_color_corrected.png'))
        else:
            source = [os.path.join(load_path, img)]
            s = color_transfer(source,target)
            cv2.imwrite(os.path.join(save_path, f'Frame{i}_color_corrected.png'), s)
    print('Fisnihed Color Transfer...')

def get_mean_image(save_path,load_path,frames=60,median=True):
    '''Combine the first ``frames`` color-corrected frames into a single image.

    Reads ``Frame0_color_corrected.png`` .. ``Frame<frames-1>_color_corrected.png``
    from ``load_path`` and combines them pixel-wise along the frame axis.

    Parameters
    ----------
    save_path : str
        Output directory; created if missing.
    load_path : str
        Directory holding the color-corrected frames.
    frames : int
        Number of frames to combine. Must be between 1 and the number of
        entries in ``load_path``; otherwise an error is printed and the
        function returns without writing anything.
    median : bool
        If True the per-pixel median is written to ``Median<frames>_CT.png``;
        otherwise the per-pixel mean is written to ``Average<frames>.png``.

    Returns
    -------
    None
    '''
    available = len(os.listdir(load_path))
    if frames < 1 or frames > available:
        print(f'ERROR: Number of frames out of boundaries! Maximum number allowed: {available}. Quitting...')
        return

    os.makedirs(save_path, exist_ok=True)

    arrs = []
    for i in range(frames):
        # One f-string only: combining an f-string with %-formatting breaks
        # as soon as load_path itself contains a '%' character.
        frame_path = os.path.join(load_path, f'Frame{i}_color_corrected.png')
        with Image.open(frame_path) as im:
            arrs.append(np.array(im))

    if median:
        combined = np.median(arrs, axis=0)
        out_name = f'Median{frames}_CT.png'
    else:
        combined = np.mean(arrs, axis=0)
        out_name = f'Average{frames}.png'

    # The cast to uint8 truncates fractional values, as the original code did.
    Image.fromarray(combined.astype('uint8')).save(os.path.join(save_path, out_name))
    
def histogram_similarity(video_name,chi=True,frames=60,corrected=False):
    '''Pairwise histogram comparison between the first ``frames`` frames of a video.

    Each frame is converted to CIE L*a*b* and one histogram per channel is
    computed (100 bins for L, 256 bins for a and for b). For every frame pair
    the three per-channel comparison scores are averaged, and the resulting
    matrix is min-max normalised to [0, 1].

    Parameters
    ----------
    video_name : str
        Sub-folder name of the video inside ``FRAMES_DIRECTORY`` / ``CT``.
    chi : bool
        True  -> ``cv2.HISTCMP_CHISQR`` (chi-squared distance).
        False -> ``cv2.HISTCMP_CORREL`` (correlation).
    frames : int
        Number of frames (``Frame0`` .. ``Frame<frames-1>``) to compare.
    corrected : bool
        If True read ``Frame<i>_color_corrected.png`` from ``CT``, otherwise
        read ``Frame<i>.png`` from ``FRAMES_DIRECTORY``.

    Returns
    -------
    (numpy.ndarray, float)
        The normalised ``frames`` x ``frames`` matrix and its mean value.

    Raises
    ------
    FileNotFoundError
        If a frame cannot be read.
    '''
    # (channel index, number of bins, value range) for L, a and b.
    # The upper L bound sits just above 100 because calcHist ranges are
    # upper-exclusive and a pure-white pixel has L == 100.
    channel_specs = (
        (0, 100, [0, 100 + 1e-3]),
        (1, 256, [-128, 128]),
        (2, 256, [-128, 128]),
    )
    histograms = [[] for _ in channel_specs]

    for i in range(frames):
        if corrected:
            frame_path = os.path.join(CT, video_name, f'Frame{i}_color_corrected.png')
        else:
            frame_path = os.path.join(FRAMES_DIRECTORY, video_name, f'Frame{i}.png')

        image = cv2.imread(frame_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f'Could not read frame: {frame_path}')

        # Convert to float32 in [0, 1] BEFORE the colour conversion: on 8-bit
        # input OpenCV rescales Lab to 0..255 per channel, which does not match
        # the [0, 100] / [-128, 128] ranges used below and would silently drop
        # most pixels from the histograms.
        lab = cv2.cvtColor(image.astype(np.float32) / 255.0, cv2.COLOR_BGR2LAB)

        # Only the histograms are kept, so frames need not stay in memory.
        for hist_list, (channel, bins, value_range) in zip(histograms, channel_specs):
            hist_list.append(cv2.calcHist([lab], [channel], None, [bins], value_range))

    method = cv2.HISTCMP_CHISQR if chi else cv2.HISTCMP_CORREL
    similarities = np.zeros((frames, frames))

    for i in range(frames):
        for j in range(i, frames):
            scores = [cv2.compareHist(hist_list[i], hist_list[j], method)
                      for hist_list in histograms]
            average_similarity = sum(scores) / len(scores)
            # The score of (i, j) is mirrored to (j, i).
            similarities[i, j] = average_similarity
            similarities[j, i] = average_similarity

    lowest, highest = similarities.min(), similarities.max()
    if highest > lowest:
        similarities = (similarities - lowest) / (highest - lowest)
    else:
        # Constant matrix (e.g. a single frame): min-max scaling is undefined,
        # return zeros instead of dividing by zero and producing NaNs.
        similarities = np.zeros_like(similarities)
    mean_similarity = np.mean(similarities)

    return similarities, mean_similarity

def ssim_calc(video_name, frames=60,corrected=False):
    '''Mean pairwise SSIM between the first ``frames`` frames of a video.

    Frames are converted with ``cv2.COLOR_BGR2LAB`` and compared pairwise with
    ``structural_similarity``; the last array axis is treated as the colour
    channel. The mean is taken over the full symmetric matrix, including the
    diagonal (each frame compared with itself).

    Parameters
    ----------
    video_name : str
        Sub-folder name of the video inside ``FRAMES_DIRECTORY`` / ``CT``.
    frames : int
        Number of frames (``Frame0`` .. ``Frame<frames-1>``) to compare.
    corrected : bool
        If True read ``Frame<i>_color_corrected.png`` from ``CT``, otherwise
        read ``Frame<i>.png`` from ``FRAMES_DIRECTORY``.

    Returns
    -------
    float
        Mean of the ``frames`` x ``frames`` SSIM matrix.

    Raises
    ------
    FileNotFoundError
        If a frame cannot be read.
    '''
    ssim_values = np.zeros((frames, frames))
    images = []
    for i in range(frames):
        if corrected:
            frame_path = os.path.join(CT, video_name, f'Frame{i}_color_corrected.png')
        else:
            frame_path = os.path.join(FRAMES_DIRECTORY, video_name, f'Frame{i}.png')

        image = cv2.imread(frame_path)
        if image is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f'Could not read frame: {frame_path}')
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))

    for i in range(frames):
        for j in range(i, frames):
            # channel_axis replaces the deprecated `multichannel=True` switch
            # (requires scikit-image >= 0.19). Without it the 3-wide colour
            # axis would be treated as a spatial dimension.
            ssim_score = ssim(images[i], images[j], channel_axis=-1)
            ssim_values[i, j] = ssim_score
            ssim_values[j, i] = ssim_score

    mean_ssim = np.mean(ssim_values)

    return mean_ssim

def plot_matrices(similarity_original_chi, similarity_original_corr, 
                   similarity_corrected_chi, similarity_corrected_corr, 
                   mean_sim_original_chi, mean_sim_original_corr, 
                   mean_sim_corrected_chi, mean_sim_corrected_corr, 
                   mean_ssim, mean_ssim_corrected,video_name):
    '''Plot the four similarity matrices of a video in a 2x2 grid and save it.

    Top row: original frames, bottom row: color-corrected frames.
    Left column: chi-squared matrices, right column: correlation matrices.
    Each panel carries a text box with the matching mean value and mean SSIM.

    The figure is written twice into ``<RESULTS>/<video_name>/`` (created if
    missing): ``Matrices_<video_name>.png`` and
    ``Transparent_Matrices_<video_name>.png`` (transparent background).

    Parameters
    ----------
    similarity_* : 2-D array-like
        Matrices shown with ``imshow``.
    mean_sim_* , mean_ssim, mean_ssim_corrected : float
        Values printed (two decimals) in the annotation boxes.
    video_name : str
        Used for the output folder and file names.
    '''
    # Annotation anchor expressed as a fraction of the axes. For a 60x60
    # matrix this is exactly the former data position (36, 6.5); using axes
    # fractions keeps the box in place for any matrix size.
    text_x = 36.5 / 60
    text_y = 1 - 7 / 60
    box_style = dict(facecolor='black', alpha=0.5, edgecolor='black', boxstyle='round,pad=0.5')

    fig, axs = plt.subplots(2, 2, figsize=(10, 8))

    # (axes, matrix, title, statistic label, statistic value, SSIM value, colorbar label)
    panels = [
        (axs[0, 0], similarity_original_chi, 'χ² Matrix - Original',
         'Mean χ²', mean_sim_original_chi, mean_ssim, 'Chi-Squared Distance'),
        (axs[0, 1], similarity_original_corr, 'Correlation Matrix - Original',
         'Mean Corr', mean_sim_original_corr, mean_ssim, 'Correlation'),
        (axs[1, 0], similarity_corrected_chi, 'χ² Matrix - Corrected',
         'Mean χ²', mean_sim_corrected_chi, mean_ssim_corrected, 'Chi-Squared Distance'),
        (axs[1, 1], similarity_corrected_corr, 'Correlation Matrix - Corrected',
         'Mean Corr', mean_sim_corrected_corr, mean_ssim_corrected, 'Correlation'),
    ]

    for ax, matrix, title, stat_label, stat_value, ssim_value, cbar_label in panels:
        im = ax.imshow(matrix, cmap='viridis', interpolation='nearest')
        ax.set_title(title)
        ax.text(text_x, text_y, f'{stat_label}: {stat_value:.2f}\nMean SSIM: {ssim_value:.2f}',
                transform=ax.transAxes, color='white', fontsize=10, bbox=box_style)
        fig.colorbar(im, ax=ax, label=cbar_label)

    # Axis labels only on the outer edges of the grid.
    axs[0, 0].set_ylabel('Frame Index')
    axs[1, 0].set_xlabel('Frame Index')
    axs[1, 0].set_ylabel('Frame Index')
    axs[1, 1].set_xlabel('Frame Index')

    plt.tight_layout()

    out_dir = os.path.join(RESULTS, video_name)
    os.makedirs(out_dir, exist_ok=True)

    plt.savefig(os.path.join(out_dir, f'Matrices_{video_name}.png'), dpi=600, bbox_inches='tight')
    plt.savefig(os.path.join(out_dir, f'Transparent_Matrices_{video_name}.png'), dpi=600, bbox_inches='tight',
                transparent = True)
    

## Processing

In [None]:
root_directory = PARENT_FOLDER # The videos are expected to be in the parent folder. Change this line if you
                               # want to change the location

N_FRAMES = 60     # frames used for the median image, the histogram matrices and SSIM
N_CT_FRAMES = 70  # frames that get color-transferred (must be >= N_FRAMES)

mp4_files_list = find_mp4_files(root_directory)

for mp4_file in mp4_files_list:

    # Name used for the per-video sub-folders: file name up to its first dot.
    # os.path.basename handles whichever separator the path was built with.
    video_name = os.path.basename(mp4_file).split('.')[0]
    print(f'Processing Video: {video_name}')

    frames_dir = os.path.join(FRAMES_DIRECTORY, video_name)
    ct_dir = os.path.join(CT, video_name)
    median_dir = os.path.join(MEDIAN, video_name)

    #PROCESS
    frame_extr(mp4_file, frames_dir) # Frame Extraction

    color_transf_frames(frames_dir, ct_dir, N_CT_FRAMES) # Color Transfer between frames

    get_mean_image(median_dir, ct_dir, frames=N_FRAMES, median=True)

    #HISTOGRAMS
    similarity_original_chi,mean_sim_original_chi = histogram_similarity(video_name,True,N_FRAMES,False)
    similarity_original_corr,mean_sim_original_corr = histogram_similarity(video_name,False,N_FRAMES,False)

    similarity_corrected_chi,mean_sim_corrected_chi = histogram_similarity(video_name,True,N_FRAMES,True)
    similarity_corrected_corr,mean_sim_corrected_corr = histogram_similarity(video_name,False,N_FRAMES,True)
    print('Fisnihed Histograms Comparison...')

    # SSIM
    mean_ssim = ssim_calc(video_name,N_FRAMES,False)
    mean_ssim_corrected = ssim_calc(video_name,N_FRAMES,True)
    print('Fisnihed SSIM Calculation...')

    #Results
    plot_matrices(similarity_original_chi, similarity_original_corr, 
              similarity_corrected_chi, similarity_corrected_corr, 
              mean_sim_original_chi, mean_sim_original_corr, 
              mean_sim_corrected_chi, mean_sim_corrected_corr, 
              mean_ssim, mean_ssim_corrected,video_name)

    # Final Enhancement: color transfer between the median image and the
    # target image GT.JPG; the file name follows get_mean_image's
    # 'Median<frames>_CT.png' pattern, so it is derived from N_FRAMES.
    target = [os.path.join(TARGET_IAMGE, 'GT.JPG')]
    source = [os.path.join(median_dir, f'Median{N_FRAMES}_CT.png')]
    s = color_transfer(source,target)
    cv2.imwrite(os.path.join(RESULTS, video_name, f'Final_Result_{video_name}.png'), s)
    print('Fisnihed Histograms Results...')