# Test mosaicking methods

In [None]:
import os
from glob import glob
import rioxarray as rxr
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib

# Project root and the processed-output tree that holds the orthoimages
data_folder = '/Users/rdcrlrka/Research/Soo_locks'
out_folder = os.path.join(data_folder, '20251001_imagery', 'frames_IR_proc_out')
ortho_folder = os.path.join(out_folder, 'final_ortho')

# Space-delimited camera-position table
cam_positions_file = os.path.join(data_folder, 'inputs', 'cams.txt')

# NOTE(review): the reference DEM is resolved relative to the current working
# directory rather than data_folder -- confirm the notebook is launched from
# the expected location.
refdem_name = '20251001_Soo_Model_1cm_mean_UTM19N-fake_filled_cropped.tif'
refdem_file = os.path.join(os.getcwd(), '..', 'inputs', refdem_name)

# Orthoimages in lexical filename order; the position in this list becomes
# the image's index along the "camera" dimension of the stack built later.
ortho_pattern = os.path.join(ortho_folder, '*.tiff')
image_list = sorted(glob(ortho_pattern))
n_images = len(image_list)
print(f'Located {n_images} images')

# Destination for mosaicking outputs (created if it does not exist yet)
mosaic_folder = os.path.join(out_folder, 'mosaic_testing')
os.makedirs(mosaic_folder, exist_ok=True)

In [None]:
# Load camera positions: a whitespace-delimited table with a header row that
# includes X, Y and Z columns. sep=r'\s+' tolerates runs of spaces, tabs and
# trailing whitespace; a single-space separator would turn those into
# spurious empty columns and misalign the header.
# NOTE(review): row i is later paired with image i of the sorted image list --
# confirm the table rows follow the same order.
cam_positions = pd.read_csv(cam_positions_file, sep=r'\s+')
camera_centers = cam_positions[['X', 'Y', 'Z']].values

# Load reference DEM; elevations below -1000 are presumably nodata fill values
# and are replaced with NaN so they drop out of later computations.
refdem = rxr.open_rasterio(refdem_file).squeeze()
refdem = xr.where(refdem < -1e3, np.nan, refdem)
# Assign the CRS explicitly (EPSG:32619 = WGS 84 / UTM zone 19N)
refdem = refdem.rio.write_crs("EPSG:32619")

# Create a grid of the closest camera to each pixel.
# First build one 3D point (x, y, z) per DEM pixel, in row-major raster order.
print('Creating 3D reference grid from DEM')
xv, yv = np.meshgrid(refdem.x.values, refdem.y.values)
Z = refdem.data
xyz_points = np.stack([xv.ravel(), yv.ravel(), Z.ravel()], axis=1)

# Read and stack images
print('Reading and stacking images')
datasets = [rxr.open_rasterio(f).squeeze() for f in image_list]
# match refdem grid so every image shares the DEM's pixel grid
datasets = [f.rio.reproject_match(refdem) for f in datasets]
# create stack; position along "camera" is the image's index in image_list
stack = xr.concat(datasets, dim="camera")

# Camera index i refers both to row i of the camera table and to image i of
# the stack, so the two must have the same length or the mosaic would pull
# pixels from the wrong image.
# NOTE(review): this also assumes the cams.txt rows are in the same order as
# the sorted image filenames -- confirm.
n_cams = len(camera_centers)
if stack.sizes['camera'] != n_cams:
    raise ValueError(
        f"Found {stack.sizes['camera']} images but {n_cams} camera positions; "
        "camera indices would not line up with the image stack"
    )

# Distance from every DEM point to every camera, shape (n_pixels, n_cams).
# Computed one camera at a time, which avoids materialising the full
# (n_pixels, n_cams, 3) broadcast; the values are identical.
print('Identifying closest camera to each pixel')
distances = np.stack(
    [np.linalg.norm(xyz_points - center, axis=1) for center in camera_centers],
    axis=1,
)
# Pixels with NaN elevation have all-NaN rows; their argmin is meaningless and
# they are masked to NaN below.
closest_idx = np.argmin(distances, axis=1)
closest_idx_img = closest_idx.reshape(refdem.shape)
# convert to data array on the stack's x/y grid
closest_idx_img_da = xr.DataArray(
    data=closest_idx_img,
    dims=['y', 'x'],
    coords={
        'y': stack.y,
        'x': stack.x
    }
)
# set no data values to NaN (the index map becomes float as a result)
closest_idx_img_da = xr.where(np.isnan(refdem), np.nan, closest_idx_img_da)

# Assemble the mosaic: every pixel takes its value from the image of the
# closest camera; pixels with a NaN index (no DEM data) stay NaN.
print('Creating mosaic')
mosaic = closest_idx_img_da.copy(data=np.full(closest_idx_img_da.shape, np.nan))
for cam in range(stack.sizes['camera']):
    mosaic = xr.where(closest_idx_img_da == cam, stack.isel(camera=cam), mosaic)

# Plot results: closest-camera map (with camera positions) next to the
# orthomosaic, with a horizontal colourbar underneath the map.
plt.rcParams.update({'font.size': 12, 'font.sans-serif': 'Verdana'})
n_cams = len(camera_centers)
fig, ax = plt.subplots(2, 2, figsize=(10,12), gridspec_kw=dict(height_ratios=[20,1]))
ax = ax.flatten()

# closest camera map: one discrete tab20 colour per camera index. tab20 only
# has 20 colours, so indices wrap beyond that instead of all collapsing onto
# the colormap's "over" colour.
cmap = matplotlib.colors.ListedColormap(
    [plt.cm.tab20(i % plt.cm.tab20.N) for i in range(n_cams)]
)
# Colour limits at half-integers centre each colour bin on an integer camera
# index, for any number of cameras (previously hard-coded to 16).
im = ax[0].imshow(
    closest_idx_img_da.data, cmap=cmap, clim=(-0.5, n_cams - 0.5),
    extent=(float(refdem.x.min()), float(refdem.x.max()),
            float(refdem.y.min()), float(refdem.y.max())),
)
for i, c in enumerate(camera_centers):
    ax[0].plot(
        c[0], c[1], '*', markersize=10,
        markerfacecolor=cmap(i),
        markeredgecolor='w', markeredgewidth=0.5
    )
cbar = fig.colorbar(im, cax=ax[2], orientation='horizontal')
cbar.set_ticks(np.arange(n_cams))
# Labels are 1-based camera numbers
cbar.set_ticklabels([i + 1 for i in range(n_cams)])
cbar.ax.set_xlabel('camera #')
ax[0].set_title('Closest camera')

# orthomosaic ('Greys_r' is matplotlib's canonical name for this colormap)
ax[1].imshow(
    mosaic.data, cmap='Greys_r',
    extent=(float(mosaic.x.min()), float(mosaic.x.max()),
            float(mosaic.y.min()), float(mosaic.y.max())),
)
ax[1].set_title('Orthomosaic')

ax[0].set_ylabel('meters')
for axis in ax[0:2]:
    axis.set_xlabel('meters')

# The fourth axis (under the mosaic) is unused
ax[-1].remove()

fig.tight_layout()
# Save the figure before showing it so the saved file never depends on the
# backend's behaviour after show().
fig_file = os.path.join(mosaic_folder, 'closest_camera_orthomosaic.png')
fig.savefig(fig_file, dpi=300, bbox_inches='tight')
print('Saved figure:', fig_file)
plt.show()

# Save results as GeoTIFFs in the DEM's CRS, declaring NaN as nodata so GIS
# tools treat the masked pixels as missing rather than as values.
crs = refdem.rio.crs
closest_idx_img_da_file = os.path.join(mosaic_folder, 'closest_camera_map.tiff')
closest_idx_img_da = closest_idx_img_da.rio.write_crs(crs).rio.write_nodata(np.nan)
closest_idx_img_da.rio.to_raster(closest_idx_img_da_file)
print('Saved closest camera map:', closest_idx_img_da_file)
mosaic_file = os.path.join(mosaic_folder, 'orthomosaic_sample_closest_camera.tiff')
mosaic = mosaic.rio.write_crs(crs).rio.write_nodata(np.nan)
mosaic.rio.to_raster(mosaic_file)
print('Saved orthomosaic:', mosaic_file)


## Try a distance-based reflectance correction: scale each image by a multiplier that grows with pixel distance from its camera, then re-mosaic

In [None]:
# define the correction function
def reflectance_correction(
    distance_map,
    d_target,
    m_min=1.0,
    m_max=1.5,
    alpha=3.0,
    eps=1e-8,
    clip_above_target=True
):
    """Return a per-pixel brightness multiplier that grows with camera distance.

    The multiplier rises from ``m_min`` at the smallest valid distance in
    ``distance_map`` to ``m_max`` at ``d_target``, following
    ``g(t) = (exp(alpha * t) - 1) / (exp(alpha) - 1)`` of the normalised
    distance ``t = (d - d_min) / (d_target - d_min)`` (a linear ramp when
    ``alpha`` is ~0).

    Parameters
    ----------
    distance_map : array_like
        Distances from the camera; NaN marks invalid pixels.
    d_target : float
        Distance at which the multiplier reaches ``m_max``.
    m_min, m_max : float
        Multiplier at the minimum distance and at ``d_target``
        (assumes ``m_min <= m_max``).
    alpha : float
        Curvature of the ramp. Positive values keep the multiplier close to
        ``m_min`` for longer and rise steeply near ``d_target``; ~0 is linear.
    eps : float
        Tolerance below which ``d_target`` is treated as not exceeding the
        minimum distance (every valid pixel then gets ``m_max``).
    clip_above_target : bool
        If True, distances beyond ``d_target`` receive ``m_max``. If False,
        the curve is extrapolated so the multiplier keeps growing past
        ``m_max`` for those distances.

    Returns
    -------
    numpy.ndarray
        Float multiplier array shaped like ``distance_map``; NaN wherever
        ``distance_map`` is NaN, and never below ``m_min`` elsewhere.
    """
    # Work in float so m_max is not truncated and NaN can be stored.
    distance_map = np.asarray(distance_map, dtype=float)
    invalid = np.isnan(distance_map)
    d_min = np.nanmin(distance_map)

    # Degenerate target at or below the minimum distance: constant m_max.
    if d_target <= d_min + eps:
        out = np.full(distance_map.shape, m_max, dtype=float)
        out[invalid] = np.nan
        return out

    # Normalised distance: 0 at d_min, 1 at d_target.
    d_norm = (distance_map - d_min) / (d_target - d_min)
    if clip_above_target:
        d_norm = np.clip(d_norm, 0.0, 1.0)
    else:
        # Values > 1 extrapolate the ramp beyond m_max.
        d_norm = np.maximum(d_norm, 0.0)

    if abs(alpha) < 1e-12:
        g = d_norm
    else:
        # expm1 keeps precision when alpha * d_norm is small.
        g = np.expm1(alpha * d_norm) / np.expm1(alpha)

    multiplier = m_min + (m_max - m_min) * g
    if clip_above_target:
        multiplier = np.clip(multiplier, m_min, m_max)
    else:
        # Only bound from below; an unconditional clip to m_max would make
        # clip_above_target=False indistinguishable from True.
        multiplier = np.maximum(multiplier, m_min)

    # Preserve NaNs from the input distance map.
    return np.where(invalid, np.nan, multiplier)

# reshape distances to match the DEM grid for each camera: column i of
# `distances` becomes a raster-shaped distance map for camera i
distance_maps = [
    distances[:, i].reshape(refdem.shape) for i in range(len(camera_centers))
]

# Correction parameters: the multiplier reaches its maximum at this distance
# (same units as the camera/DEM coordinates) and is capped at this value.
correction_d_target = 5
correction_m_max = 1.3

# apply correction to each image before mosaicking; image i of the stack is
# scaled by the multiplier derived from camera i's distance map
print('Applying reflection correction')
corrected_stack = xr.concat(
    [
        img * reflectance_correction(
            distance_maps[i], d_target=correction_d_target, m_max=correction_m_max
        )
        for i, img in enumerate(stack)
    ],
    dim="camera",
)

# Create corrected mosaic: same closest-camera selection as the original
# mosaic, drawing from the corrected images; NaN-index pixels stay NaN
print('Creating mosaic')
mosaic_corr = closest_idx_img_da.copy()
mosaic_corr.data = np.nan * np.ones(closest_idx_img_da.data.shape)
for i in range(stack.sizes['camera']):
    mosaic_corr = xr.where(closest_idx_img_da == i, corrected_stack.isel(camera=i), mosaic_corr)

# Plot original vs corrected side by side on a fixed 0-255 display range
# (assumes 8-bit imagery -- corrected values above 255 saturate)
print('Plotting results')
orig_extent = (float(mosaic.x.min()), float(mosaic.x.max()),
               float(mosaic.y.min()), float(mosaic.y.max()))
corr_extent = (float(mosaic_corr.x.min()), float(mosaic_corr.x.max()),
               float(mosaic_corr.y.min()), float(mosaic_corr.y.max()))
fig, ax = plt.subplots(1, 2, figsize=(10,12))
ax = ax.flatten()
# original image ('Greys_r' is matplotlib's canonical name for this colormap)
ax[0].imshow(mosaic.data, cmap='Greys_r', vmin=0, vmax=255, extent=orig_extent)
ax[0].set_title('Original')
# corrected image
ax[1].imshow(mosaic_corr.data, cmap='Greys_r', vmin=0, vmax=255, extent=corr_extent)
ax[1].set_title('Reflectance-corrected')

ax[0].set_ylabel('meters')
for axis in ax[0:2]:
    axis.set_xlabel('meters')

fig.tight_layout()
plt.show()