# Testing the orthorectification workflow

In [None]:
import os
from glob import glob
import cv2
import datetime
from tqdm import tqdm
import pandas as pd
import numpy as np
import ast
import subprocess
import multiprocessing
import shutil
import rioxarray as rxr
import xarray as xr
import rasterio as rio

# Input data locations
data_folder = '/Users/rdcrlrka/Research/Soo_locks/20251001_imagery/'
inputs_folder = '/Users/rdcrlrka/Research/Soo_locks/inputs/'
video_folder = os.path.join(data_folder, 'video')

# Static inputs: reference DEM, lens parameters, calibrated cameras and the
# closest-camera map. Camera files are sorted because they are later paired
# by position with the (also sorted) image lists.
refdem_file = os.path.join(inputs_folder, 'lidar_DSM_filled_cropped.tif')
distort_param_file = os.path.join(inputs_folder, 'initial_undistortion_params.csv')
camera_files = glob(os.path.join(inputs_folder, '*_calibrated_camera.tsai'))
camera_files.sort()
closest_cam_map_file = os.path.join(inputs_folder, 'closest_camera_map.tiff')

# Output locations, one sub-folder per workflow stage
out_folder = os.path.join(data_folder, 'testing')
image_folder = os.path.join(out_folder, 'images')
undistorted_folder = os.path.join(out_folder, 'images_undistorted')
ortho_folder = os.path.join(out_folder, 'orthoimages')
os.makedirs(out_folder, exist_ok=True)

## Extract image frames from video

In [None]:
def string_to_datetime(datetime_string):
    """Parse a compact ``YYYYMMDDhhmmss`` timestamp into a naive datetime.

    Only the first 14 characters are read; anything after them is ignored.
    Each field goes through ``int``, so a short or non-numeric field raises
    ``ValueError``.
    """
    # (start, stop) slice positions of year, month, day, hour, minute, second
    field_slices = ((0, 4), (4, 6), (6, 8), (8, 10), (10, 12), (12, 14))
    fields = [int(datetime_string[start:stop]) for start, stop in field_slices]
    return datetime.datetime(*fields)


def extract_frame_at_clock_time(
        video_file: str = None, 
        target_time_string: str = None, 
        output_folder: str = None, 
        output_format: str = 'tiff'
        ):
    """Extract a frame from a video at a specific clock time and save as {video_file_name}.ext"""
    # parse start and end times from video file name
    start_time_string = os.path.basename(video_file).split('_')[3]
    end_time_string = os.path.splitext(os.path.basename(video_file))[0].split('_')[4].split('(')[0]

    # convert datetime strings to datetime objects
    target_time = string_to_datetime(target_time_string)
    start_time = string_to_datetime(start_time_string)
    end_time = string_to_datetime(end_time_string)

    print(f"\nProcessing {video_file}")
    print(f'Detected video time range: {start_time} to {end_time}')

    # Open video file
    cap = cv2.VideoCapture(video_file)
    if not cap.isOpened():
        print(f"Error: Could not open video file '{video_file}'.")
        return False

    # Determine the video time duration
    fps = cap.get(cv2.CAP_PROP_FPS)
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    duration = total_frames / fps if fps > 0 else 0

    # Check if the target time is beyond the video coverage
    time_offset = (target_time - start_time).total_seconds()
    if time_offset < 0 or (time_offset > duration):
        print(f"Error: Target time {time_offset:.2f}s is outside video duration ({duration:.2f}s)")
        cap.release()
        return False

    # Otherwise, get the appropriate frame
    frame_number = int(time_offset * fps)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    ret, frame = cap.read()
    if not ret:
        print(f"Error: Could not extract frame at {time_offset:.2f}s")
        cap.release()
        return False
    
    # Determine the camera number
    ch = os.path.basename(video_file).split('ch')[1][0:2]
    if (frame.shape[1] > 4000) & (ch=='1_'):
        ch = '09'
    elif (frame.shape[1] > 4000):
        ch = str(int(ch[0]) + 8)
    else:
        ch = '0' + ch[0]

    # Save to file
    output_image_file = os.path.join(
        output_folder, 
        f"ch{ch}_{target_time_string}.{output_format}"
        )
    # determine save settings based on output format
    save_params = []
    if output_format == 'png':
        save_params = [cv2.IMWRITE_PNG_COMPRESSION, 3]
    elif output_format in ['jpg', 'jpeg']:
        save_params = [cv2.IMWRITE_JPEG_QUALITY, 100]

    if cv2.imwrite(output_image_file, frame, save_params):
        print(f"Extracted frame -> {output_image_file}")
        cap.release()
        return True
    else:
        print(f"Failed to save frame")
        cap.release()
        return False
    

def process_video_files(
        video_files: list[str] = None, 
        target_time_string: str = None, 
        output_folder: str = None, 
        output_format: str = 'tiff'
        ):
    """Extract the frame at ``target_time_string`` from every video file.

    Parameters
    ----------
    video_files : list[str]
        Video paths to process; ``None`` is treated as an empty list.
    target_time_string : str
        Target clock time as ``YYYYMMDDhhmmss``.
    output_folder : str
        Folder for the extracted frames; created if missing.
    output_format : str
        Image extension passed to ``extract_frame_at_clock_time``.

    Returns
    -------
    int
        Number of frames written successfully. Videos whose extraction
        failed are listed at the end instead of being silently skipped.
    """
    os.makedirs(output_folder, exist_ok=True)

    print('Target time:', string_to_datetime(target_time_string))

    video_files = list(video_files or [])
    n_ok = 0
    failed = []
    for video_file in tqdm(video_files):
        if extract_frame_at_clock_time(video_file, target_time_string, output_folder, output_format):
            n_ok += 1
        else:
            failed.append(video_file)

    print(f'Extracted {n_ok} of {len(video_files)} frame(s).')
    if failed:
        print('Failed video files:', *failed, sep='\n  ')
    return n_ok


# Pull the frame at the target clock time from every .avi, in sorted order
video_files = glob(os.path.join(video_folder, '*.avi'))
video_files.sort()
process_video_files(
    video_files=video_files,
    target_time_string="20251001171500",
    output_folder=image_folder,
)

## Apply initial distortion correction

In [None]:
def correct_initial_image_distortion(
        params_file: str = None, 
        image_files: list[str] = None, 
        output_folder: str = None,
        full_fov: bool = True
        ):
    os.makedirs(output_folder, exist_ok=True)

    # Load the camera and distortion parameters file
    params = pd.read_csv(params_file)
    params['K'] = params['K'].apply(ast.literal_eval)
    params['K_full'] = params['K_full'].apply(ast.literal_eval)
    params['dist'] = params['dist'].apply(ast.literal_eval)

    # Iterate over image files
    for image_file in image_files:
        # Read image
        image = cv2.imread(image_file, cv2.IMREAD_UNCHANGED)
        h,w = image.shape[:2]

        # Determine the camera number
        ch = os.path.basename(image_file).split('_')[0][2:]

        # Get the respective distortion parameters
        params_im = params.loc[params['camera']==int(ch)].reset_index().iloc[0]
        K = np.array(params_im['K']).reshape(3,3)
        K_full = np.array(params_im['K_full']).reshape(3,3)
        dist = np.array(params_im['dist']).reshape(-1,1)

        # Undistort
        if full_fov:
            # must do some remapping to maintain no data values
            map1, map2 = cv2.initUndistortRectifyMap(K, dist, None, K_full, (w, h), cv2.CV_32FC1)
            # apply undistortion to the image
            image_undistorted = cv2.remap(image, map1, map2, interpolation=cv2.INTER_LINEAR)
            # create a white mask and remap it the same way to find valid areas
            mask = np.ones((h, w), dtype=np.uint8) * 255
            mask_undistorted = cv2.remap(mask, map1, map2, interpolation=cv2.INTER_NEAREST)
            # convert mask to boolean
            valid_mask = mask_undistorted > 0
            # Now set invalid pixels to NaN
            image_undistorted_nodata = image_undistorted.astype(np.float32)
            image_undistorted_nodata[~valid_mask] = np.nan
        else:
            image_undistorted = cv2.undistort(image, K, dist, None, K)

        # Save to file
        image_undistorted_file = os.path.join(
            output_folder, 
            os.path.splitext(os.path.basename(image_file))[0] + '_undistorted.tiff'
            )
        if full_fov:
            image_undistorted_file = os.path.splitext(image_undistorted_file)[0] + '_full.tiff'
        cv2.imwrite(image_undistorted_file, image_undistorted)
        print('Saved undistorted image:', image_undistorted_file)
    print('Done with intial undistortion.')

    return

# Undistort every extracted frame (sorted), keeping the full field of view
image_files = glob(os.path.join(image_folder, '*.tiff'))
image_files.sort()
correct_initial_image_distortion(
    params_file=distort_param_file,
    image_files=image_files,
    output_folder=undistorted_folder,
    full_fov=True,
)


## Orthorectification

In [None]:
def run_cmd(bin: str = None, 
            args: list = None, **kw) -> str:
    """
    Wrapper for subprocess function to execute a command and capture its output.

    Parameters
    ----------
    bin: str
        command to be executed (e.g., stereo or gdalwarp); resolved on PATH
    args: list
        arguments to the command as a list
    **kw
        accepted for backward compatibility; currently ignored

    Returns
    ----------
    out: str
        stdout of the command if it exited with status 0. Otherwise a message
        starting with "the command ... failed to run" that includes the exit
        code and the captured stdout/stderr, so a saved log explains the failure.
    """
    binpath = shutil.which(bin)
    if binpath is None:
        return f"the command {bin!r} failed to run: executable not found on PATH"

    call = [binpath, *(args or [])]
    try:
        # errors='replace' keeps undecodable tool output from raising
        result = subprocess.run(
            call, check=True, capture_output=True, encoding='UTF-8', errors='replace'
        )
    except subprocess.CalledProcessError as err:
        return (
            f"the command {call} failed to run (exit code {err.returncode})\n"
            f"--- stdout ---\n{err.stdout}\n"
            f"--- stderr ---\n{err.stderr}\n"
        )
    except OSError as err:
        # e.g. permission denied or an invalid executable format
        return f"the command {call} failed to run: {err}"
    return result.stdout


def write_log_file(log, output_prefix):
    """Write ``log`` to ``{output_prefix}_log_{YYYYMMDDhhmmssffffff}.txt``.

    The timestamp always includes microseconds, so the name has a fixed
    20-digit format. ``str(datetime)`` would drop the fractional part when it
    is exactly zero. The file is written as UTF-8.

    Returns
    -------
    str
        Path of the written log file.
    """
    now_string = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')

    log_file = output_prefix + '_log_' + now_string + '.txt'

    with open(log_file, 'w', encoding='utf-8') as f:
        f.write(log)

    return log_file


def orthorectify(
        image_list: list[str] = None, 
        camera_list: list[str] = None, 
        refdem_file: str = None, 
        output_folder: str = None,
        nodata_value: str = 'NaN',
        out_res: float = 0.003
        ):
    """Orthorectify each image onto ``refdem_file`` with the ``mapproject`` tool.

    Images and camera files are paired by position, so both lists must be in
    matching order (e.g. both sorted) and have the same length. Each output
    keeps the input image's file name inside ``output_folder``. Each run's
    output is saved to a timestamped ``<stem>_mapproject_log_*.txt`` file
    next to it.

    Parameters
    ----------
    image_list, camera_list : list[str]
        Images and their camera models, in matching order. ``None`` means empty.
    refdem_file : str
        DEM to project onto.
    output_folder : str
        Output folder, created if missing.
    nodata_value : str
        Passed to ``--nodata-value``.
    out_res : float
        Output pixel size passed to ``--tr`` (presumably in DEM CRS units — confirm).

    Raises
    ------
    ValueError
        If the image and camera lists differ in length. ``zip`` would
        otherwise silently drop the extras.
    """
    image_list = list(image_list or [])
    camera_list = list(camera_list or [])
    if len(image_list) != len(camera_list):
        raise ValueError(
            f'Got {len(image_list)} image(s) but {len(camera_list)} camera file(s); '
            'images and cameras are paired by position, so the counts must match.'
        )

    os.makedirs(output_folder, exist_ok=True)

    threads = multiprocessing.cpu_count()
    print(f'Will use up to {threads} threads for each process.')

    for image_file, cam_file in zip(image_list, camera_list):
        image_out_file = os.path.join(output_folder, os.path.basename(image_file))

        print('Orthorectifying:', image_file)
        args = [
            '--threads', str(threads),
            '--nodata-value', nodata_value,
            '--tr', str(out_res),
            refdem_file, image_file, cam_file, image_out_file
        ]
        log = run_cmd('mapproject', args)

        log_prefix = os.path.join(
            output_folder, 
            os.path.splitext(os.path.basename(image_file))[0] + '_mapproject'
            )
        log_file = write_log_file(log, log_prefix)
        print('Saved log:', log_file)

        # run_cmd returns an error message instead of raising, so flag it here
        if log.startswith('the command ') and 'failed to run' in log:
            print('Warning: mapproject failed for', image_file, '- see', log_file)

    print('Done orthorectifying.')


# Undistorted images and calibrated cameras are paired by sorted position.
# NOTE: this glob matches both '_undistorted.tiff' and '_undistorted_full.tiff'
# outputs, so keep only one variant in the folder or the pairing will shift.
image_files = glob(os.path.join(undistorted_folder, '*.tiff'))
image_files.sort()
orthorectify(
    image_list=image_files,
    camera_list=camera_files,
    refdem_file=refdem_file,
    output_folder=ortho_folder,
)

## Mosaic orthorectified images

In [None]:
def mosaic_orthoimages(
        image_files: list[str] = None, 
        closest_cam_map_file: str = None, 
        output_folder: str = None
        ):
    # Load the map of closest camera
    print("Reading closest camera map")
    closest_cam_map = rxr.open_rasterio(closest_cam_map_file)
    crs = closest_cam_map.rio.crs

    # Load orthoimages
    print("Reading orthoimages")
    datasets = [rxr.open_rasterio(f, masked=True) for f in image_files]

    # Verify consistent CRS
    for ds in datasets:
        if ds.rio.crs != crs:
            raise ValueError(f"CRS mismatch in {ds.rio.nodata}")

    # Determine number of bands (use from first image)
    num_bands = datasets[0].rio.count
    print(f"Detected {num_bands} band(s) per image")

    # Determine target resolution (average or min pixel size)
    res_x = np.mean([abs(ds.rio.resolution()[0]) for ds in datasets])
    res_y = np.mean([abs(ds.rio.resolution()[1]) for ds in datasets])
    print(f"Using target resolution: {res_x:.3f}, {res_y:.3f}")

    # Determine output bounds and grid
    bounds = closest_cam_map.rio.bounds()
    width = int((bounds[2] - bounds[0]) / res_x)
    height = int((bounds[3] - bounds[1]) / res_y)
    transform = rio.transform.from_bounds(*bounds, width=width, height=height)

    # Create a dummy grid (reference for reprojection)
    dummy_grid = xr.DataArray(
        np.nan*np.zeros((height, width), dtype=np.uint8),
        dims=("y", "x"),
        coords={
            "y": np.linspace(bounds[3], bounds[1], height),
            "x": np.linspace(bounds[0], bounds[2], width),
        },
    ).rio.write_crs(crs).rio.write_transform(transform)

    # Reproject images
    print("Reprojecting images to target grid...")
    reprojected = [
        ds.rio.reproject_match(dummy_grid, resampling=rio.enums.Resampling.nearest)
        for ds in datasets
    ]

    # Stack all reprojected images along a "camera" dimension
    stack = xr.concat(reprojected, dim="camera")

    # Reproject closest_cam_map
    print("Reprojecting closest_cam_map to target grid...")
    closest_cam_map = closest_cam_map.rio.reproject_match(dummy_grid, resampling=rio.enums.Resampling.nearest)

    # Initialize mosaic with NaNs for all bands
    print("Creating mosaic...")
    mosaic_shape = (num_bands, height, width)
    mosaic = xr.DataArray(
        np.full(mosaic_shape, np.nan, dtype=np.float32),
        dims=("band", "y", "x"),
        coords={"band": np.arange(1, num_bands + 1), "y": dummy_grid.y, "x": dummy_grid.x},
    ).rio.write_crs(crs).rio.write_transform(transform)

    # Fill mosaic by selecting pixels based on closest_cam_map
    for i in range(len(stack.camera)):
        mask = closest_cam_map.squeeze() == i
        if num_bands == 1:
            mosaic = xr.where(mask, stack.isel(camera=i)[0], mosaic)
        else:
            for b in range(num_bands):
                mosaic[b] = xr.where(mask, stack.isel(camera=i, band=b), mosaic[b])

    # Save mosaic
    os.makedirs(output_folder, exist_ok=True)
    mosaic_file = os.path.join(output_folder, "orthomosaic.tif")
    mosaic.rio.to_raster(mosaic_file)
    print("Saved orthomosaic:", mosaic_file)


# Mosaic the orthoimages in sorted order; position i in this list is what
# value i in the closest-camera map selects.
image_files = glob(os.path.join(ortho_folder, '*.tiff'))
image_files.sort()
mosaic_orthoimages(
    image_files=image_files,
    closest_cam_map_file=closest_cam_map_file,
    output_folder=out_folder,
)