# Calibrate cameras, create initial orthoimage and partial DSM

In [None]:
import os
from glob import glob
import subprocess
import numpy as np
from tqdm import tqdm
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import rioxarray as rxr
import xarray as xr
from shapely.geometry import Point
import geopandas as gpd
import pyproj
import shutil
# Ignore warnings (rasterio throws a warning whenever an image is not georeferenced. Annoying in this case.)
import warnings
warnings.filterwarnings('ignore')

# Raw imagery: IR frames from the 2025-10-01 acquisition
data_folder = '/Users/rdcrlrka/Research/Soo_locks'
image_folder = os.path.join(data_folder, '20251001_imagery', 'frames_IR')
image_list = sorted(glob(os.path.join(image_folder, '*.tiff')))
print(f"{len(image_list)} images located")

# Other inputs (reference DEM, GCP folder, camera location file) live in ../inputs
inputs_folder = os.path.join(os.getcwd(), '..', 'inputs')
refdem_file = os.path.join(inputs_folder, '20251001_Soo_Model_1cm_mean_UTM19N-fake_filled.tif')
gcp_folder = os.path.join(inputs_folder, 'gcp')
cams_file = os.path.join(inputs_folder, 'cams_lonlat-fake.txt')

# All processing outputs go to "<image_folder>_proc_out", one sub-folder per step
out_folder = f"{image_folder}_proc_out"
os.makedirs(out_folder, exist_ok=True)
(
    new_image_folder,
    undistorted_folder,
    init_ortho_folder,
    init_stereo_folder,
    ba_folder,
    final_ortho_folder,
    final_stereo_folder,
) = (
    os.path.join(out_folder, step_name)
    for step_name in (
        'single_band_images',
        'undistorted_images_cams',
        'init_ortho',
        'init_stereo',
        'bundle_adjust',
        'final_ortho',
        'final_stereo',
    )
)

## Merge GCP

In [None]:
gcp_merged_file = os.path.join(gcp_folder, 'GCP_merged.csv')
if not os.path.exists(gcp_merged_file):

    gcp_list = sorted(glob(os.path.join(gcp_folder, '*.gcp')))
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate"
    if not gcp_list:
        raise FileNotFoundError(f'No .gcp files found in {gcp_folder}')

    # Each .gcp file: one header line, then comma-separated rows with these columns
    gcp_columns = [
        'point_index', 'lat', 'lon', 'Z', 'lat_sigma', 'lon_sigma', 'Z_sigma',
        'image_path', 'col_sample', 'row_sample', 'use_lat', 'use_lon']
    dfs = pd.concat(
        [
            pd.read_csv(gcp_file, sep=',', header=None, skiprows=[0], names=gcp_columns)
            for gcp_file in gcp_list
        ]
    ).reset_index(drop=True)

    # reproject lon/lat (EPSG:4326) to UTM zone 19N; Z is carried over unchanged
    gdf = gpd.GeoDataFrame(
        dfs,
        geometry=gpd.points_from_xy(dfs['lon'], dfs['lat']),
        crs="EPSG:4326"
    ).to_crs("EPSG:32619")
    gdf['X'] = gdf.geometry.x
    gdf['Y'] = gdf.geometry.y

    # use just the image file name
    gdf['image_name'] = gdf['image_path'].map(os.path.basename)

    # select relevant columns and save to file
    gdf = gdf[['image_name', 'X', 'Y', 'Z', 'col_sample', 'row_sample']]
    gdf.to_csv(gcp_merged_file, sep=',', index=False)
    print('Saved merged GCP:', gcp_merged_file)
else:
    print('Merged GCP already exists in file, skipping merge.')


## Convert images to single band in case they're RGB

A couple of IR images (near the windows) were captured in RGB, so only the first band of every image is kept.

In [None]:
os.makedirs(new_image_folder, exist_ok=True)

# iterate over images, keeping only band 1 of each
print('Saving single-band images to:', new_image_folder)
for image_file in tqdm(image_list):
    out_fn = os.path.join(new_image_folder, os.path.basename(image_file))
    # Existing outputs are skipped on re-runs, so a failed conversion must not
    # leave a partial file behind. A missing image would also shift the
    # index-based calibration groups ([0:8] / [8:]) downstream.
    if os.path.exists(out_fn):
        continue
    cmd = [
        "gdal_translate",
        "-b", "1",
        image_file, out_fn
    ]
    result = subprocess.run(cmd)
    if result.returncode != 0:
        if os.path.exists(out_fn):
            os.remove(out_fn)
        raise RuntimeError(f"gdal_translate failed for {image_file}")

## Calibrate cameras using GCP

In [None]:
def save_tsai(tsai_dict, filename):
    """
    Write a pinhole camera dictionary to a TSAI (.tsai) text file.

    ``tsai_dict`` must provide the scalars fu, fv, cu, cv, pitch, the 3-vectors
    u_direction, v_direction, w_direction, C, and the 3x3 array R (written
    row-major on one line). The file ends with a ``NULL`` line (no distortion).
    """
    def _space_join(values):
        return ' '.join(str(value) for value in values)

    lines = ["VERSION_4", "PINHOLE"]
    lines.extend(f"{key} = {tsai_dict[key]}" for key in ('fu', 'fv', 'cu', 'cv'))
    lines.extend(
        f"{key} = {_space_join(tsai_dict[key])}"
        for key in ('u_direction', 'v_direction', 'w_direction', 'C')
    )
    lines.append("R = " + _space_join(tsai_dict['R'].flatten()))
    lines.append(f"pitch = {tsai_dict['pitch']}")
    lines.append("NULL")

    with open(filename, 'w') as out_file:
        out_file.write('\n'.join(lines) + '\n')


def opencv_to_tsai_cam(K, rvec, tvec, object_points_mean, utm_crs="EPSG:32619"):
    """
    Convert an OpenCV camera pose to an ASP-compatible pinhole (TSAI) camera.
    See the ASP docs for more info: https://stereopipeline.readthedocs.io/en/latest/pinholemodels.html

    Parameters
    ----------
    K : np.ndarray (3x3)
        intrinsic matrix
    rvec : np.ndarray (3x1)
        rotation vector (world -> camera), e.g. from cv2.solvePnP
    tvec : np.ndarray (3x1)
        translation vector (world -> camera) in meters, relative to the
        mean-centered object points (world coordinates minus object_points_mean)
    object_points_mean: np.ndarray (3,)
        mean subtracted from the object points before solvePnP (to restore absolute position)
    utm_crs : str
        CRS of the camera / object points (default EPSG:32619)

    Returns
    -------
    tsai_dict : dict
        TSAI camera parameters for save_tsai: camera center C in ECEF
        (EPSG:4978) and camera -> ECEF rotation matrix R.
    """
    # OpenCV gives the world -> camera rotation; TSAI needs camera -> world
    R_wc = cv2.Rodrigues(np.asarray(rvec, dtype=np.float64))[0]
    R_cw = R_wc.T

    # Camera center in UTM: C = -R^T t, shifted back by the mean removed before solvePnP
    C_utm = (
        -R_cw @ np.asarray(tvec, dtype=np.float64).reshape(3)
        + np.asarray(object_points_mean, dtype=np.float64).reshape(3)
    )

    # Transform C to ECEF
    transformer = pyproj.Transformer.from_crs(utm_crs, "EPSG:4978", always_xy=True)
    C_ecef = np.array(transformer.transform(C_utm[0], C_utm[1], C_utm[2]), dtype=float)

    # Camera axes in ECEF: step 1 m along each camera axis in UTM and transform the
    # end points. Row i of R_cw.T is the i-th camera axis (column i of R_cw).
    delta = 1.0  # meters
    axis_ends_utm = C_utm + R_cw.T * delta
    axis_ends_ecef = np.array([transformer.transform(*p) for p in axis_ends_utm])
    axes_ecef = axis_ends_ecef - C_ecef
    axes_ecef /= np.linalg.norm(axes_ecef, axis=1, keepdims=True)
    R_ecef = axes_ecef.T  # columns = camera axes in ECEF -> camera -> ECEF

    # The finite-difference axes are only approximately orthogonal (e.g. UTM's
    # horizontal scale factor is not applied to heights), so project onto the
    # nearest rotation matrix (polar decomposition via SVD).
    U, _, Vt = np.linalg.svd(R_ecef)
    R_ecef = U @ Vt

    # Build TSAI dictionary
    tsai_dict = {
        'VERSION': 4,
        'TYPE': 'PINHOLE',
        'fu': float(K[0,0]),
        'fv': float(K[1,1]),
        'cu': float(K[0,2]),
        'cv': float(K[1,2]),
        'u_direction': np.array([1,0,0]),
        'v_direction': np.array([0,1,0]),
        'w_direction': np.array([0,0,1]),
        'C': C_ecef.astype(float),
        'R': R_ecef.astype(float),
        'pitch': 1
    }

    return tsai_dict


def calibrate_shared_intrinsics(image_files, gcp_file, output_folder=None, file_prefix=None):
    """
    Estimate one intrinsic matrix and distortion model shared by a group of images.

    GCP object coordinates are mean-centered (in float64) before calibration.

    Parameters
    ----------
    image_files : list of str
        Images sharing the same intrinsics. The first image sets the image size.
        Images without any GCP are skipped (with a message).
    gcp_file : str
        Merged GCP CSV with columns image_name, X, Y, Z, col_sample, row_sample.
    output_folder : str, optional
        Folder for ``<file_prefix>camera_calibration_params.csv``. Nothing is saved if None.
    file_prefix : str, optional
        Prefix for the saved CSV file name (default: no prefix).

    Returns
    -------
    rms : float
        Overall RMS re-projection error returned by cv2.calibrateCamera.
    K : np.ndarray (3x3)
        Shared camera matrix.
    dist : np.ndarray
        Shared distortion coefficients.
    object_points_mean : np.ndarray (3,)
        float64 mean of all GCP object points; subtract it from the object points
        before solvePnP (see undistort_calibrate_extrinsics).

    Raises
    ------
    ValueError
        If none of the images has GCPs.
    FileNotFoundError
        If the first image cannot be read.
    """
    # --- Load merged GCP ---
    gcp = pd.read_csv(gcp_file, sep=',')

    # --- Compile GCP (object) and image (pixel) points ---
    image_size = None
    object_points_utm = []   # per image: (N, 3) float64 GCP coordinates
    image_points_list = []   # per image: (N, 1, 2) float32 pixel coordinates
    for image_file in image_files:
        # get image size from the first image
        if image_size is None:
            image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
            if image is None:
                raise FileNotFoundError(f"Could not read image: {image_file}")
            image_size = (image.shape[1], image.shape[0])

        # Subset GCP to image; a view without points cannot be calibrated
        gcp_image = gcp.loc[gcp['image_name'] == os.path.basename(image_file)]
        if gcp_image.empty:
            print(f"No GCP found for {os.path.basename(image_file)}; excluded from calibration")
            continue

        object_points_utm.append(gcp_image[['X', 'Y', 'Z']].to_numpy(dtype=np.float64))
        image_points_list.append(
            gcp_image[['col_sample', 'row_sample']].to_numpy(dtype=np.float32).reshape(-1, 1, 2)
        )

    if not object_points_utm:
        raise ValueError("No valid images for calibration")

    # --- Mean-center the object points ---
    # Done in float64 *before* the float32 cast OpenCV needs: large UTM coordinates
    # lose precision in float32 (~7 significant digits), far coarser than the cm-level
    # GCP accuracy. Intrinsics do not depend on where the world origin is.
    object_points_mean = np.vstack(object_points_utm).mean(axis=0)
    object_points_list = [
        (pts - object_points_mean).astype(np.float32).reshape(-1, 1, 3)
        for pts in object_points_utm
    ]

    # --- Initialize intrinsics ---
    fx = fy = 2000
    cx = image_size[0] / 2
    cy = image_size[1] / 2
    K_init = np.array([
        [fx,0,cx],
        [0,fy,cy],
        [0,0,1]
        ], dtype=np.float64)
    dist_init = np.zeros(8)
    flags = (
        cv2.CALIB_USE_INTRINSIC_GUESS
        # | cv2.CALIB_RATIONAL_MODEL
        | cv2.CALIB_FIX_PRINCIPAL_POINT
        | cv2.CALIB_ZERO_TANGENT_DIST
        )

    # --- Calibrate cameras ---
    rms, K, dist, _, _ = cv2.calibrateCamera(
        object_points_list,
        image_points_list,
        image_size,
        K_init,
        dist_init,
        flags=flags
    )

    print("Shared calibration done")
    print("RMS reprojection error:", rms)
    print("Camera matrix K:\n", K)
    print("Distortion coefficients:", dist.ravel())

    # --- Save results to file ---
    if output_folder is not None:
        calib_file = os.path.join(output_folder, (file_prefix or '') + 'camera_calibration_params.csv')
        calib_df = pd.DataFrame({
            'image_name': image_files,
            'K': [K]*len(image_files),
            'distortions_coefficients': [dist]*len(image_files),
            'RMS': [rms]*len(image_files)
        })
        calib_df.to_csv(calib_file, index=False, header=True)
        print('Saved camera calibration parameters:', calib_file)

    return rms, K, dist, object_points_mean


def undistort_calibrate_extrinsics(image_file, gcp_file, K, dist, object_points_mean, output_folder, plot_results=True):
    """
    Undistort an image, estimate its extrinsics from GCP, and save ASP inputs.

    Files written to ``output_folder`` (``<stem>`` = image name without extension):
    ``<stem>_undistorted.tiff`` (undistorted image), ``<stem>_undistorted.gcp``
    (space-separated GCP with undistorted pixel coordinates and lat/lon),
    ``<stem>_undistorted.tsai`` (pinhole camera) and, if ``plot_results``,
    ``<stem>_undistorted.png`` (before/after figure).

    Parameters
    ----------
    image_file : str
        Path to the (single-band) image.
    gcp_file : str
        Merged GCP CSV with columns image_name, X, Y, Z (EPSG:32619), col_sample, row_sample.
    K : np.ndarray (3x3)
        Shared camera matrix; also used as the camera matrix of the undistorted image.
    dist : np.ndarray
        Shared distortion coefficients.
    object_points_mean : array-like (3,)
        Offset subtracted from the GCP coordinates before solvePnP (as returned
        by calibrate_shared_intrinsics).
    output_folder : str
        Destination folder.
    plot_results : bool
        Save a figure of the original and undistorted image with GCP.

    Raises
    ------
    ValueError
        If the image has fewer than 4 GCPs.
    FileNotFoundError
        If the image cannot be read.
    RuntimeError
        If solvePnP fails.
    """
    image_name = os.path.basename(image_file)
    image_stem = os.path.splitext(image_name)[0]

    # --- Construct the image (pixel) and object (GCP) points ---
    gcp = pd.read_csv(gcp_file, sep=',')
    # .copy(): columns are added below, so work on an independent frame, not a slice
    gcp = gcp.loc[gcp['image_name'] == image_name].copy()
    if len(gcp) < 4:
        raise ValueError(f"Need at least 4 GCPs to solve the pose of {image_file}, found {len(gcp)}")

    # Mean-center in float64 *before* casting to float32; casting raw UTM
    # coordinates first would already throw away their sub-meter digits.
    offset = np.asarray(object_points_mean, dtype=np.float64).reshape(3)
    obj_pts = (gcp[['X', 'Y', 'Z']].to_numpy(dtype=np.float64) - offset).astype(np.float32)
    image_pts = gcp[['col_sample', 'row_sample']].to_numpy(dtype=np.float32).reshape(-1, 1, 2)

    # --- Solve for camera extrinsics ---
    success, rvec, tvec = cv2.solvePnP(obj_pts, image_pts, K, dist, flags=cv2.SOLVEPNP_ITERATIVE)
    if not success:
        raise RuntimeError(f"solvePnP failed for {image_file}")

    # --- Undistort the image and the GCP pixel coordinates ---
    image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_file}")
    # Keep K as the new camera matrix so the (distortion-free) TSAI camera matches the image
    image_undistorted = cv2.undistort(image, K, dist, None, K)

    # New GCP pixel coordinates in the undistorted image
    undistorted_pts = cv2.undistortPoints(image_pts, K, dist, P=K).reshape(-1, 2)
    gcp['col_sample_undistorted'] = undistorted_pts[:, 0]
    gcp['row_sample_undistorted'] = undistorted_pts[:, 1]

    # --- Save results to file ---
    # Undistorted image
    image_undist_file = os.path.join(output_folder, image_stem + '_undistorted.tiff')
    cv2.imwrite(image_undist_file, image_undistorted)

    # Prepare undistorted GCP for ASP-friendly saving: lat/lon from UTM
    lonlat = gpd.GeoSeries(
        gpd.points_from_xy(gcp['X'], gcp['Y']), crs='EPSG:32619'
    ).to_crs("EPSG:4326")
    gcp_reformat = gcp.copy()
    gcp_reformat['lon'] = lonlat.x.to_numpy()
    gcp_reformat['lat'] = lonlat.y.to_numpy()
    # point the GCP at the undistorted image
    gcp_reformat['image_name'] = gcp_reformat['image_name'].str.replace(
        '.tiff', '_undistorted.tiff', regex=False)
    # add other relevant columns
    gcp_reformat['lat_std'] = 0.2
    gcp_reformat['lon_std'] = 0.2
    gcp_reformat['Z_std'] = 0.2
    gcp_reformat['use_lat'] = 1
    gcp_reformat['use_lon'] = 1
    # reorder and select appropriate columns
    gcp_reformat = gcp_reformat[[
        'lat', 'lon', 'Z', 'lat_std', 'lon_std', 'Z_std',
        'image_name', 'col_sample_undistorted', 'row_sample_undistorted',
        'use_lat', 'use_lon'
        ]].reset_index(drop=True)

    # Undistorted GCPs (the row index is written as the point id)
    gcp_undistorted_file = os.path.join(output_folder, image_stem + '_undistorted.gcp')
    gcp_reformat.to_csv(
        gcp_undistorted_file,
        sep=' ',
        index=True,
        header=False
        )

    # TSAI camera model
    tsai_file = os.path.join(output_folder, image_stem + '_undistorted.tsai')
    tsai_dict = opencv_to_tsai_cam(K, rvec, tvec, offset)
    save_tsai(tsai_dict, tsai_file)

    # --- Plot results ---
    if plot_results:
        fig, ax = plt.subplots(1, 2, figsize=(14, 5))
        # 'Greys_r' is the long-standing spelling; 'Grays_r' is presumably only known to newer Matplotlib
        ax[0].imshow(image, cmap='Greys_r')
        ax[0].plot(
            gcp['col_sample'], gcp['row_sample'], 'xr',
            markersize=5, linewidth=1.5
            )
        ax[0].set_title('Original')
        ax[1].imshow(image_undistorted, cmap='gray')
        ax[1].plot(
            gcp['col_sample_undistorted'], gcp['row_sample_undistorted'], 'xr',
            markersize=5, linewidth=1.5
            )
        ax[1].set_title('Undistorted')
        for axis in ax:
            axis.set_xticks([])
            axis.set_yticks([])
        plt.tight_layout()

        # save to file
        fig_file = os.path.join(output_folder, image_stem + '_undistorted.png')
        fig.savefig(fig_file, dpi=300, bbox_inches='tight')

        plt.close(fig)


os.makedirs(undistorted_folder, exist_ok=True)

# Cameras are calibrated in two groups (sorted order: first 8 images, then the
# rest), each with its own shared intrinsics.
single_band_list = sorted(glob(os.path.join(new_image_folder, '*.tiff')))
calibration_groups = [
    ('GROUP 1: ch01-08', 'group1-', single_band_list[0:8]),
    ('GROUP 2: ch09-16', 'group2-', single_band_list[8:]),
]
for group_label, group_prefix, image_list in calibration_groups:
    print(f'\n{group_label}\n----------')
    # calculate shared calibration
    print('Optimizing shared camera intrinsics using GCP...')
    rms, K, dist, object_points_mean = calibrate_shared_intrinsics(
        image_list, gcp_merged_file, undistorted_folder, file_prefix=group_prefix)
    # per-image processing
    print('Estimating camera extrinsics for individual images...')
    for image_file in tqdm(image_list):
        undistort_calibrate_extrinsics(image_file, gcp_merged_file, K, dist, object_points_mean, undistorted_folder)

## Initial orthorectification

In [None]:
os.makedirs(init_ortho_folder, exist_ok=True)

image_list = sorted(glob(os.path.join(undistorted_folder, '*_undistorted.tiff')))

# Mapproject. The calibration step writes <stem>_undistorted.tiff next to
# <stem>_undistorted.tsai, so derive each camera from its image name instead of
# relying on two independently sorted lists lining up.
for image_file in tqdm(image_list):
    cam_file = os.path.splitext(image_file)[0] + '.tsai'
    if not os.path.exists(cam_file):
        print('No camera found for', image_file, '- skipping')
        continue
    image_out_file = os.path.join(init_ortho_folder, os.path.basename(image_file).replace('.tiff', '_map.tiff'))
    cmd = [
        'mapproject',
        '--threads', '12',
        '--nodata-value', '0',
        '--tr', '0.003',
        refdem_file, image_file, cam_file, image_out_file
    ]
    subprocess.run(cmd)

# Mosaic orthoimages
print('\nMosaicking orthoimages')
image_list = sorted(glob(os.path.join(init_ortho_folder, '*.tiff')))
mosaic_file = os.path.join(init_ortho_folder, 'orthomosaic.tif')
fnc = shutil.which('gdal_merge.py')
if fnc is None:
    raise FileNotFoundError('gdal_merge.py not found on PATH')
cmd = [
    'python', fnc,
    '-o', mosaic_file,
    '-n', '0',
    '-a_nodata', '-9999',
    # fill uncovered areas with the declared nodata value (not 0)
    '-init', '-9999'
] + image_list
subprocess.run(cmd)

## Run stereo preprocessing to create dense match files

In [None]:
init_stereo_folder = os.path.join(out_folder, 'init_stereo')
os.makedirs(init_stereo_folder, exist_ok=True)

image_list = sorted(glob(os.path.join(undistorted_folder, '*.tiff')))
cam_list = sorted(glob(os.path.join(undistorted_folder, '*.tsai')))
# Images and cameras are paired by sorted position, so the counts must agree
if len(image_list) != len(cam_list):
    raise ValueError(f'Found {len(image_list)} images but {len(cam_list)} cameras in {undistorted_folder}')

# Consecutive image pairs: (image1, image2, cam1, cam2)
pairs = list(zip(image_list[:-1], image_list[1:], cam_list[:-1], cam_list[1:]))

# skip the 8/9 cams pair (different intrinsics solving during bundle adjust)
n_pairs = len(pairs)
pairs = [p for p in pairs if 'ch08' not in os.path.basename(p[0])]
if len(pairs) == n_pairs:
    print('Note: no ch08 image found, no pair skipped')

# Iterate over pairs
for image1, image2, cam1, cam2 in tqdm(pairs):
    pair_prefix = os.path.join(
        init_stereo_folder,
        os.path.splitext(os.path.basename(image1))[0] + '__' + os.path.splitext(os.path.basename(image2))[0],
        'run'
        )

    cmd = [
        'parallel_stereo',
        '--threads-singleprocess', '12',
        '--threads-multiprocess', '12',
        '--stop-point', '1',
        image1, image2,
        cam1, cam2,
        pair_prefix,
    ]
    subprocess.run(cmd)

## Bundle adjust

In [None]:
ba_folder = os.path.join(out_folder, 'bundle_adjust')
os.makedirs(ba_folder, exist_ok=True)

image_list = sorted(glob(os.path.join(undistorted_folder, '*.tiff')))
cam_list = sorted(glob(os.path.join(undistorted_folder, '*.tsai')))
# Images and cameras are sliced into groups by position, so the counts must agree
if len(image_list) != len(cam_list):
    raise ValueError(f'Found {len(image_list)} images but {len(cam_list)} cameras in {undistorted_folder}')

# Add small initial TSAI distortion values to the cameras (replacing the NULL model)
tsai_distortion_lines = ['TSAI', 'k1 = -1e-6', 'k2 = 1e-6', 'p1 = 0', 'p2 = 0', 'k3 = 1e-6']
for cam_file in cam_list:
    with open(cam_file, 'r') as f:
        cam_lines = f.read().split('\n')
    # check if it's already been added
    if 'TSAI' in cam_lines:
        continue
    # Drop trailing blank lines and the terminating NULL line without assuming an
    # exact trailing newline (otherwise the 'pitch' line could be overwritten)
    while cam_lines and cam_lines[-1].strip() == '':
        cam_lines.pop()
    if cam_lines and cam_lines[-1].strip() == 'NULL':
        cam_lines.pop()
    cam_lines += tsai_distortion_lines
    with open(cam_file, 'w') as f:
        f.write('\n'.join(cam_lines) + '\n')

# Copy dense matches to bundle adjust folder
# NOTE(review): bundle_adjust looks for match files named after its own output
# prefix (run_group1-/run_group2-); confirm these copies are actually picked up
# (ASP's --match-files-prefix option may be needed).
match_list = sorted(glob(os.path.join(init_stereo_folder, '*', '*.match')))
for match_file in match_list:
    match_out_file = os.path.join(ba_folder, os.path.basename(match_file))
    _ = shutil.copy2(match_file, match_out_file)

# Options shared by both groups
ba_options = [
    '--threads', '12',
    '--num-iterations', '2000',
    '--num-passes', '2',
    '--inline-adjustments',
    '--force-reuse-match-files',
    '--heights-from-dem', refdem_file,
    '--heights-from-dem-uncertainty', '0.01',
    '--solve-intrinsics',
    '--intrinsics-to-share', 'optical_center,other_intrinsics',
    '--intrinsics-to-float', 'all',
]
# Same groups as the calibration step: first 8 images, then the rest
ba_groups = [
    ('GROUP 1: ch01-08', 'run_group1', slice(0, 8)),
    ('GROUP 2: ch09-16', 'run_group2', slice(8, None)),
]
for group_label, run_name, group_slice in ba_groups:
    print(f'\n{group_label}\n----------')
    cmd = (
        ['parallel_bundle_adjust']
        + ba_options
        + ['-o', os.path.join(ba_folder, run_name)]
        + image_list[group_slice]
        + cam_list[group_slice]
    )
    subprocess.run(cmd)

## Final orthorectification

In [None]:
os.makedirs(final_ortho_folder, exist_ok=True)

image_list = sorted(glob(os.path.join(undistorted_folder, '*_undistorted.tiff')))
cam_list = sorted(glob(os.path.join(ba_folder, '*.tsai')))
print(len(image_list), len(cam_list))
# Images and bundle-adjusted cameras are paired by sorted position, so the counts must agree
if len(image_list) != len(cam_list):
    raise ValueError(f'Found {len(image_list)} images but {len(cam_list)} bundle-adjusted cameras')

# Mapproject
for image_file, cam_file in tqdm(zip(image_list, cam_list), total=len(image_list)):
    image_out_file = os.path.join(final_ortho_folder, os.path.basename(image_file).replace('.tiff', '_map.tiff'))
    cmd = [
        'mapproject',
        '--threads', '12',
        '--nodata-value', '0',
        '--tr', '0.003',
        refdem_file, image_file, cam_file, image_out_file
    ]
    subprocess.run(cmd)

# Mosaic orthoimages
print('\nMosaicking orthoimages')
image_list = sorted(glob(os.path.join(final_ortho_folder, '*.tiff')))
mosaic_file = os.path.join(final_ortho_folder, 'orthomosaic.tif')
fnc = shutil.which('gdal_merge.py')
if fnc is None:
    raise FileNotFoundError('gdal_merge.py not found on PATH')
cmd = [
    'python', fnc,
    '-o', mosaic_file,
    '-n', '0',
    '-a_nodata', '-9999',
    # fill uncovered areas with the declared nodata value (not 0)
    '-init', '-9999'
] + image_list
subprocess.run(cmd)


## Try mosaicking by identifying the closest camera at each pixel

In [None]:
import rasterio as rio

# Final orthoimages and bundle-adjusted cameras, leaving out ch02 for now
image_list = [
    fn for fn in sorted(glob(os.path.join(final_ortho_folder, '*.tiff')))
    if 'ch02' not in fn
]
cam_list = [
    fn for fn in sorted(glob(os.path.join(ba_folder, '*.tsai')))
    if 'ch02' not in fn
]

# The reference DEM defines the output grid.
# NOTE: upsampling it to 3 mm (rio.reproject with resolution=(0.003, 0.003),
# bilinear resampling) crashed, so the native DEM grid is used.
refdem = rxr.open_rasterio(refdem_file).squeeze()

def read_camera_center(tsai_path):
    """
    Return the camera center of a .tsai pinhole camera in UTM 19N (EPSG:32619).

    The ``C = x y z`` line holds the center in ECEF (EPSG:4978); it is
    transformed to (easting, northing, height). If several ``C`` lines are
    present, the last one is used.

    Raises
    ------
    ValueError
        If the file has no ``C`` line.
    """
    center_ecef = None
    with open(tsai_path) as f:
        for line in f:
            key, sep, value = line.partition('=')
            if sep and key.strip() == 'C':
                center_ecef = [float(v) for v in value.split()[:3]]
    if center_ecef is None or len(center_ecef) != 3:
        raise ValueError(f"No valid 'C = x y z' line found in {tsai_path}")

    # reproject ECEF -> UTM
    transformer = pyproj.Transformer.from_crs("EPSG:4978", "EPSG:32619", always_xy=True)
    cx, cy, cz = transformer.transform(*center_ecef)

    return np.array([cx, cy, cz])

def mosaic_from_stack(stack, closest_idx_img):
    """
    Select each pixel's value from the camera given by ``closest_idx_img``.

    Parameters
    ----------
    stack : xarray.DataArray (or any object exposing a ``.data`` array)
        Images stacked along the first axis, either (camera, band, y, x) or
        (camera, y, x) for single-band images (e.g. squeezed rasters).
    closest_idx_img : np.ndarray of int, shape (y, x)
        Index of the camera to take at each pixel.

    Returns
    -------
    np.ndarray
        (band, y, x) for a 4-D stack or (y, x) for a 3-D stack, same dtype as the stack.
    """
    data = np.asarray(stack.data)
    idx = np.asarray(closest_idx_img)
    if data.ndim not in (3, 4):
        raise ValueError(f"Expected a (camera, [band,] y, x) stack, got shape {data.shape}")
    # Prepend singleton camera (and band) axes; take_along_axis broadcasts the
    # per-pixel index across every band in one vectorized call.
    idx = idx.reshape((1,) * (data.ndim - 2) + idx.shape)
    return np.take_along_axis(data, idx, axis=0)[0]

datasets = [rxr.open_rasterio(f).squeeze() for f in image_list]
# match refdem grid
datasets = [f.rio.reproject_match(refdem) for f in datasets]

camera_centers = np.array([read_camera_center(f) for f in cam_list])
# camera index i must refer to datasets[i]
if len(camera_centers) != len(datasets):
    raise ValueError(f'{len(datasets)} orthoimages but {len(camera_centers)} cameras')
print("Loaded camera centers")
print(camera_centers)

# Per-pixel 3D coordinates from the reference DEM (rows follow y, columns follow x)
print('Creating 3D reference grid from DEM')
x_coords = refdem.x.values
y_coords = refdem.y.values
Z = refdem.data

# Closest camera per pixel. Keep a running minimum of the squared distance, one
# camera at a time, instead of materializing an (n_pixels, n_cameras, 3) array.
# Strict '<' keeps the first camera on ties (like argmin), and NaN heights never
# compare smaller, so those pixels stay at index 0 (argmin also returns 0 there).
print('Identifying closest camera to each pixel')
best_dist2 = np.full(Z.shape, np.inf)
closest_idx_img = np.zeros(Z.shape, dtype=np.intp)
for cam_idx, (cam_x, cam_y, cam_z) in enumerate(camera_centers):
    dist2 = (
        (x_coords[None, :] - cam_x) ** 2
        + (y_coords[:, None] - cam_y) ** 2
        + (Z - cam_z) ** 2
    )
    closer = dist2 < best_dist2
    closest_idx_img[closer] = cam_idx
    best_dist2[closer] = dist2[closer]
closest_idx_img

# # Create mosaic
# stack = xr.concat(datasets, dim="camera")
# mosaic_data = mosaic_from_stack(stack, closest_idx_img)

# Convert to DataArray for easy export
# mosaic = xr.DataArray(
#     mosaic_data,
#     dims=("band", "y", "x"),
#     coords={"band": stack.band, "y": stack.y, "x": stack.x},
#     attrs=stack[0].attrs
# )
# mosaic.rio.write_crs(stack.rio.crs, inplace=True)

# mosaic
# mosaic.rio.to_raster(OUTPUT_TIF)


## Test running stereo for DEM construction

### Run stereo

In [None]:
os.makedirs(final_stereo_folder, exist_ok=True)

image_list = sorted(glob(os.path.join(final_ortho_folder, '*_map.tiff')))
cam_list = sorted(glob(os.path.join(ba_folder, '*.tsai')))
# Orthoimages and cameras are paired by sorted position, so the counts must agree
if len(image_list) != len(cam_list):
    raise ValueError(f'Found {len(image_list)} orthoimages but {len(cam_list)} bundle-adjusted cameras')

# Consecutive image pairs: (image1, image2, cam1, cam2)
pairs = list(zip(image_list[:-1], image_list[1:], cam_list[:-1], cam_list[1:]))

# Iterate over pairs; inputs are mapprojected, so the DEM is passed last
for image1, image2, cam1, cam2 in tqdm(pairs):
    pair_prefix = os.path.join(
        final_stereo_folder,
        os.path.splitext(os.path.basename(image1))[0] + '__' + os.path.splitext(os.path.basename(image2))[0],
        'run'
        )

    cmd = [
        'parallel_stereo',
        '--threads-singleprocess', '12',
        '--threads-multiprocess', '12',
        image1, image2,
        cam1, cam2,
        pair_prefix,
        refdem_file
    ]
    subprocess.run(cmd)

### Rasterize point clouds

In [None]:
# Grid every stereo point cloud (<pair>/run-PC.tif) into a DEM.
# NOTE(review): --tr is in the output DEM's projection units — presumably meters here, confirm.
pc_files = sorted(glob(os.path.join(final_stereo_folder, '*', '*-PC.tif')))
for pc_file in tqdm(pc_files):
    subprocess.run(['point2dem', '--threads', '12', '--tr', '0.01', pc_file])

### Mosaic DEMs

In [None]:
# Combine the per-pair DEMs into a single DSM mosaic
dem_fns = sorted(glob(os.path.join(final_stereo_folder, '*', '*DEM.tif')))
mosaic_fn = os.path.join(final_stereo_folder, 'DEM_mosaic.tif')
subprocess.run(['dem_mosaic', '--threads', '12', '-o', mosaic_fn, *dem_fns])

### Plot results

In [None]:
# Load the input files
def load_raster(raster_file, nodata_threshold=-100):
    """
    Open a raster as a squeezed DataArray with nodata pixels set to NaN.

    Pixels equal to the file's declared nodata value, and pixels below
    ``nodata_threshold`` (e.g. -9999 fill values), become NaN. The CRS is kept.

    Parameters
    ----------
    raster_file : str
        Path to the raster.
    nodata_threshold : float
        Values strictly below this are treated as nodata (default -100).
    """
    # masked=True also turns the declared nodata value into NaN
    raster = rxr.open_rasterio(raster_file, masked=True).squeeze()
    crs = raster.rio.crs
    # Keep values >= threshold; NaN stays NaN (same result as xr.where(raster < thr, nan, raster))
    raster = raster.where(raster >= nodata_threshold)
    raster.rio.write_crs(crs, inplace=True)
    return raster
ortho_file = os.path.join(final_ortho_folder, 'orthomosaic.tif')
ortho = load_raster(ortho_file)
dem_file = os.path.join(final_stereo_folder, 'DEM_mosaic.tif')
dem = load_raster(dem_file)
refdem = load_raster(refdem_file)
refdem = refdem.rio.reproject_match(dem)


def _extent(raster):
    """Return the imshow extent (left, right, bottom, top) of a raster's x/y coordinates."""
    return (
        float(raster.x.min()), float(raster.x.max()),
        float(raster.y.min()), float(raster.y.max()),
    )


plt.rcParams.update({'font.sans-serif': 'Verdana', 'font.size': 12})
fig, ax = plt.subplots(1, 3, figsize=(18,8))
# Ortho ('Greys_r' is the long-standing spelling of the reversed grey colormap)
ax[0].imshow(ortho, cmap='Greys_r', extent=_extent(ortho))
ax[0].set_title('IR image mosaic')
# DEM: an imshow extent is exactly 4 values; 'meters' belongs on the colorbar
im = ax[1].imshow(dem, cmap='terrain', extent=_extent(dem))
fig.colorbar(im, ax=ax[1], shrink=0.5, label='meters')
ax[1].set_title('DSM mosaic')
# DEM - refdem
im = ax[2].imshow(
    dem - refdem,
    cmap='coolwarm_r',
    clim=(-1,1),
    extent=_extent(dem)
    )
fig.colorbar(im, ax=ax[2], shrink=0.5, label='meters')
ax[2].set_title('DSM mosaic - Lidar mean')

for axis in ax:
    axis.set_xticks([])
    axis.set_yticks([])

fig.tight_layout()

# Save before showing: with interactive backends the figure may be closed/emptied by show()
fig_file = os.path.join(out_folder, 'result.jpg')
fig.savefig(fig_file, dpi=300, bbox_inches='tight')
print('Figure saved to file:', fig_file)
plt.show()