# In [None]:
import numpy as np
import h5py as h5
import matplotlib.pyplot as plt

# Input configuration: one HDF5 capture per exposure setting, each paired with
# the exposure time it was recorded with. Edit the values below to point the
# script at a different data set.
_DATA_DIR = '/Users/allisondennis/Library/CloudStorage/OneDrive-NortheasternUniversity/AMD/IR VIVO data/231216_mirrors/'

file1_path = _DATA_DIR + '231216_capillaries_BP_air_0_1.h5'
file2_path = _DATA_DIR + '231216_capillaries_BP_air_0_05.h5'
file3_path = _DATA_DIR + '231216_capillaries_BP_air_0_01.h5'

# Exposure times, listed in the same order as the paths above
file1_exposure_time = 0.1
file2_exposure_time = 0.05
file3_exposure_time = 0.01

# Destination of the aggregated cube in NumPy .npy format
output_file = '231216_capillaries_BP_air_hdr.npy'

def aggregate_exposures(images, exposure_times):
    """Blend same-shaped exposures into one image by a weighted average.

    Every image gets the weight ``exp(log(t) - max(log(t)))``, where ``t`` is
    its exposure time; for positive times this is the exposure time relative
    to the longest one. The weighted mean is then taken independently at each
    pixel. The computation is vectorized, so images of any dimensionality are
    accepted as long as they all share one shape.

    Args:
        images: Non-empty sequence of equally shaped arrays, one per exposure.
        exposure_times: Exposure time of each entry of ``images``, in the same
            order. Values are expected to be positive.

    Returns:
        ``np.float32`` array with the shape of a single input image.

    Raises:
        ValueError: If ``images`` is empty, if its length differs from that of
            ``exposure_times``, or if the images do not share one shape.
    """
    if len(images) == 0:
        raise ValueError("images must contain at least one exposure")
    if len(images) != len(exposure_times):
        # Without this check a shorter exposure list could silently broadcast
        # against the pixel values and yield a wrong result.
        raise ValueError(
            f"got {len(images)} images but {len(exposure_times)} exposure times"
        )

    print("Aggregating exposures...")
    print(f"Number of images: {len(images)}")

    # np.stack raises ValueError when the images disagree on shape
    stack = np.stack([np.asarray(img, dtype=np.float32) for img in images])
    print(f"Image shape: {stack.shape[1:]}")

    log_times = np.log(np.asarray(exposure_times, dtype=np.float64))
    print(f"Radiance values: {log_times}")

    # Subtracting the maximum before exponentiating pins the largest weight
    # at exactly 1.
    weights = np.exp(log_times - np.max(log_times))

    # One weighted mean over the exposure axis replaces the per-pixel loop
    hdr_image = np.average(stack, axis=0, weights=weights).astype(np.float32)

    print(f"Aggregated HDR image: {hdr_image}")
    return hdr_image

def _normalize_to_uint16(image):
    """Linearly rescale *image* so its min/max span the uint16 range."""
    lowest = np.min(image)
    span = np.max(image) - lowest
    if span == 0:
        # A constant image has no range to stretch; return zeros instead of
        # dividing by zero and casting NaN to an integer type.
        return np.zeros(image.shape, dtype=np.uint16)
    return ((image - lowest) / span * 65535).astype(np.uint16)


def process_files(image_files, exposure_times):
    """Merge several exposures of an image cube into one cube and save it.

    For each input file the datasets ``Cube/Images`` and ``Cube/Wavelength``
    are read; the wavelengths are taken from the first file only. The first
    axis of ``Images`` is indexed by wavelength. For every wavelength the
    per-exposure images are displayed, combined with ``aggregate_exposures``,
    rescaled to the uint16 range, displayed again and stored in the output
    cube.

    The finished cube is written with ``np.save`` to the module-level
    ``output_file`` and to ``231216_capillaries_BP_air_hdr.h5`` (datasets
    ``Cube/Images`` and ``Cube/Wavelength``).

    Args:
        image_files: Non-empty sequence of paths to the input HDF5 files.
        exposure_times: Exposure time of each file, in the same order.

    Returns:
        None.

    Raises:
        ValueError: If ``image_files`` is empty, if its length differs from
            that of ``exposure_times``, or if the image cubes read from the
            files do not all share one shape.
    """
    if len(image_files) == 0:
        raise ValueError("image_files must contain at least one file path")
    if len(image_files) != len(exposure_times):
        raise ValueError(
            f"got {len(image_files)} files but {len(exposure_times)} exposure times"
        )

    print("Processing files...")

    wavelengths = None

    # Extract image data and wavelengths from each file
    images_list = []
    for file_path in image_files:
        with h5.File(file_path, 'r') as h5_file:
            img_data = h5_file['Cube']['Images'][:]
            print(f"Extracted image data with shape {img_data.shape} from file {file_path}")
            images_list.append(img_data)

            if wavelengths is None:
                wavelengths = h5_file['Cube']['Wavelength'][:]
                print(f"Extracted wavelengths with shape {wavelengths.shape}")

    # Fail early with a clear message instead of deep inside the aggregation
    cube_shape = images_list[0].shape
    for file_path, img_data in zip(image_files, images_list):
        if img_data.shape != cube_shape:
            raise ValueError(
                f"image cube in {file_path} has shape {img_data.shape}, "
                f"expected {cube_shape}"
            )

    # Initialize the HDR image cube
    hdr_image_cube = np.zeros((len(wavelengths), cube_shape[1], cube_shape[2]), dtype=np.float32)
    print(f"Initialized HDR image cube with shape: {hdr_image_cube.shape}")

    # Process each wavelength
    for i, wavelength in enumerate(wavelengths):
        print(f"Processing wavelength: {wavelength}")

        # Extract images for the current wavelength from each exposure time
        images = [img_data[i] for img_data in images_list]

        # Display the original images for the current wavelength.
        # squeeze=False keeps the axes array 2-D, so this also works when
        # only a single file is processed.
        _, axs = plt.subplots(1, len(image_files), figsize=(15, 5), squeeze=False)
        for ax, img, exposure_time in zip(axs[0], images, exposure_times):
            ax.imshow(img, cmap='gray')
            ax.set_title(f'Wavelength {wavelength}, Exposure Time {exposure_time}')
            ax.axis('off')
        plt.tight_layout()
        plt.show()

        # Aggregate pixel values from multiple exposures for the current wavelength
        hdr_image = aggregate_exposures(images, exposure_times)

        # Normalize the HDR image for display
        hdr_image_normalized = _normalize_to_uint16(hdr_image)
        print(f"Normalized HDR image for wavelength {wavelength}: {hdr_image_normalized}")

        # Assign the HDR image to the corresponding slice in the HDR image cube.
        # NOTE(review): every slice is rescaled by its own min/max before being
        # stored, so the saved cube does not preserve relative intensities
        # between wavelengths - confirm this is intended.
        hdr_image_cube[i] = hdr_image_normalized

        # Display the normalized HDR image for the current wavelength
        plt.figure(figsize=(8, 8))
        plt.imshow(hdr_image_normalized, cmap='gray')
        plt.title(f'Normalized HDR Image - Wavelength {wavelength}')
        plt.axis('off')
        plt.show()

    # Save the HDR image cube
    np.save(output_file, hdr_image_cube)
    print(f"Saved HDR image cube to {output_file}")

    # Save the HDR image cube as an h5 file
    output_file_h5 = '231216_capillaries_BP_air_hdr.h5'
    with h5.File(output_file_h5, 'w') as h5_file:
        cube_group = h5_file.create_group('Cube')
        cube_group.create_dataset('Images', data=hdr_image_cube)
        cube_group.create_dataset('Wavelength', data=wavelengths)
    print(f"Saved HDR image cube as h5 file: {output_file_h5}")

# Split the (path, exposure time) pairs into the two parallel lists that
# process_files expects, then run the whole pipeline.
image_files, exposure_times = (
    list(column)
    for column in zip(
        (file1_path, file1_exposure_time),
        (file2_path, file2_exposure_time),
        (file3_path, file3_exposure_time),
    )
)
process_files(image_files, exposure_times)