# Testing bundle adjust approaches

In [None]:
import os
from glob import glob
import subprocess
import rioxarray as rxr
import numpy as np
from tqdm import tqdm
from p_tqdm import p_map
import shutil
import pandas as pd
import rasterio as rio
import re
import multiprocessing
import itertools
import geopandas as gpd
from shapely.geometry import Point, Polygon
import json
import networkx as nx
from datetime import datetime

# Define inputs: acquisition folder, raw SkySat scenes and the reference DEM
data_path = '/Users/rdcrlrka/Research/SkySat-Stereo/study-sites/MCS/20240420/'
img_folder = os.path.join(data_path, 'SkySatScene')
refdem_file = os.path.join(data_path, '..', 'refdem', 'MCS_refdem_lidar_COPDEM_merged.tif')

# Define output folders. Only the top-level proc_out folder is created here;
# the step-specific subfolders below are just paths at this point.
out_folder = os.path.join(data_path, 'proc_out')
os.makedirs(out_folder, exist_ok=True)
(new_img_folder,
 camgen_folder,
 init_ortho_folder,
 ba_triplet_folder,
 ba_global_folder,
 final_stereo_folder,
 final_ortho_folder) = [
    os.path.join(out_folder, sub_name) for sub_name in (
        'single_band_images',
        'camgen_cam_gcp',
        'init_ortho',
        'bundle_adjust_triplets',
        'bundle_adjust_global',
        'final_stereo',
        'final_ortho',
    )
]

# Get images (multiband analytic scenes), in a stable sorted order
img_pattern = os.path.join(img_folder, '*analytic.tif')
img_list = sorted(glob(img_pattern))
print(f"{len(img_list)} images located")

## Helper functions

In [None]:
def run_cmd(bin: str = None,
            args: list = None, **kw) -> str:
    """
    Run a command-line tool and return its standard output.

    Parameters
    ----------
    bin: str
        name (or path) of the executable, resolved with ``shutil.which``.
        Scripts ending in ``.py`` are run through the ``python`` found on PATH.
    args: list
        command-line arguments appended after the executable
    **kw:
        accepted for backward compatibility; currently ignored

    Returns
    ----------
    out: str
        the captured stdout on success; otherwise a failure message naming the
        command, followed by the command's stderr when there is any

    Raises
    ----------
    FileNotFoundError
        if ``bin`` cannot be resolved to an executable
    """
    binpath = shutil.which(bin)
    if binpath is None:
        # Fail loudly: a missing binary is a setup error, not a per-job failure
        raise FileNotFoundError(f"Executable '{bin}' was not found on PATH")

    call = ['python', binpath] if binpath.endswith('.py') else [binpath]
    if args is not None:
        call.extend(args)

    failure_msg = "the command {} failed to run, see corresponding asp log".format(call)
    try:
        # errors='replace' keeps odd bytes in tool output from masquerading as a failure
        result = subprocess.run(call, check=True, capture_output=True,
                                encoding='UTF-8', errors='replace')
    except subprocess.CalledProcessError as err:
        # Keep the tool's own error text so the compiled logs explain the failure
        return failure_msg + ("\n" + err.stderr if err.stderr else "")
    except OSError:
        # e.g. the resolved file exists but cannot be executed
        return failure_msg
    return result.stdout


def run_gdal_merge(img_list: list = None,
                   mosaic_fn: str = None,
                   t_res: float | int = None,
                   t_nodata: float | int = 0,
                   t_dtype: str = 'UInt16') -> str:
    """
    Mosaic a list of rasters into one file with ``gdal_merge.py``.

    Parameters
    ----------
    img_list: list of str
        input raster file names; later files are drawn over earlier ones
        wherever they overlap (gdal_merge behaviour)
    mosaic_fn: str
        output mosaic file name; its parent folder is created if needed
    t_res: float | int, optional
        output pixel size, passed as ``-ps t_res t_res``
    t_nodata: float | int
        nodata value assigned to the output (``-a_nodata``)
    t_dtype: str
        GDAL output data type (``-ot``)

    Returns
    ----------
    log: str
        output of ``run_cmd`` for the gdal_merge call
    """
    # makedirs handles nested folders; skip when mosaic_fn has no directory part
    out_dir = os.path.dirname(mosaic_fn)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    mos_args = ['-ot', t_dtype, '-a_nodata', str(t_nodata)]
    if t_res:
        mos_args += ['-ps', str(t_res), str(t_res)]
    mos_args += ['-o', mosaic_fn, *img_list]

    return run_cmd('gdal_merge.py', mos_args)


def generate_frame_cameras(
        img_list = None,
        dem_file: str = None, 
        product_level: str = 'l1b', 
        gcp_std: float = 1,
        out_folder: str = None
        ) -> str:
    """
    Generate ASP camera models and GCPs for a list of images using cam_gen.

    Parameters
    ----------
    img_list: list
        list of image file names. Each image is also passed to cam_gen as its
        own ``--input-camera`` (presumably through its embedded RPC model).
    dem_file: str
        file name of the reference DEM
    product_level: str
        product level of the images, either 'l1b' or 'l1a'. Sets the pixel
        pitch: 0.8 for 'l1b', 1.0 otherwise.
    gcp_std: float
        GCP standard deviation passed to cam_gen (``--gcp-std``)
    out_folder: str
        folder where output camera models and GCPs will be saved

    Returns
    ----------
    cam_gen_log: str
        file name of the compiled cam_gen log (the captured output of every
        cam_gen call, one after another)

    Notes
    -----
    For each image ``<frame>`` this writes ``<frame>.tsai`` and ``<frame>.gcp``,
    plus ``<frame>_clean.gcp``. The clean file has basename-only image names and
    renumbered GCP ids.
    """
    os.makedirs(out_folder, exist_ok=True)

    # Output camera and GCP files, named by image identifier (basename without extension)
    frames = [os.path.splitext(os.path.basename(x))[0] for x in img_list]
    out_cam_list = [os.path.join(out_folder, f'{frame}.tsai') for frame in frames]
    out_gcp_list = [os.path.join(out_folder, f'{frame}.gcp') for frame in frames]

    # Reference height used where the DEM has no data. masked=True converts the
    # DEM's nodata value to NaN, so nanmedian skips it instead of averaging in e.g. -9999.
    with rxr.open_rasterio(dem_file, masked=True) as dem:
        ht_datum = np.nanmedian(dem.squeeze().data)

    threads = os.cpu_count()
    pixel_pitch = str(0.8) if product_level == 'l1b' else str(1.0)

    # cam_gen runs sequentially; each call uses all available threads
    log_list = []
    for img, out_cam, out_gcp in zip(img_list, out_cam_list, out_gcp_list):
        args = [
            '--threads', str(threads),
            '--focal-length', str(553846.153846),
            '--optical-center', str(1280), str(540),
            '--height-above-datum', str(ht_datum),
            '--gcp-std', str(gcp_std),
            '--datum', 'WGS84',
            '--reference-dem', dem_file,
            '--refine-camera',
            '--input-camera', img,
            '-o', out_cam,
            '--gcp-file', out_gcp,
            img,
            '--pixel-pitch', pixel_pitch,
        ]
        log_list.append(run_cmd('cam_gen', args))

    # Save compiled cam_gen log
    cam_gen_log = os.path.join(out_folder, 'cam_gen.log')
    print("Saving cam_gen log at {}".format(cam_gen_log))
    with open(cam_gen_log, 'w') as f:
        for log in log_list:
            f.write(log + '\n')

    # ASP's cam_gen writes full image paths into the GCP files, which does not play
    # well during bundle adjustment. Write copies whose image column (column 7)
    # holds only basenames.
    # See ASP's gcp logic: https://stereopipeline.readthedocs.io/en/latest/tools/bundle_adjust.html#bagcp
    print("Writing GCPs with dirname removed")
    for out_gcp in tqdm(out_gcp_list):
        # cam_gen may have failed for this image (see the log); skip instead of crashing
        if not os.path.exists(out_gcp) or os.path.getsize(out_gcp) == 0:
            print(f"Missing or empty GCP file, skipping: {out_gcp}")
            continue
        df = pd.read_csv(out_gcp, header=None, delimiter=r"\s+")
        df[7] = df[7].map(os.path.basename)
        # renumber GCP ids from 0
        df[0] = np.arange(len(df))
        out_fn = os.path.join(out_folder, os.path.basename(out_gcp).replace('.gcp', '_clean.gcp'))
        df.to_csv(out_fn, sep=' ', index=False, header=False)

    return cam_gen_log


def calculate_baseline_to_height_ratio(
        img1: str = None, 
        img2: str = None, 
        utm_epsg: str = None
        ) -> float:
    """
    Calculate the baseline to height ratio for a pair of images from their RPC metadata.

    The "baseline" is the distance, in ``utm_epsg`` coordinates, between the RPC
    ``(long_off, lat_off)`` points of the two images. The "height" is the mean of
    their RPC ``height_off`` values.

    NOTE(review): in the RPC model these offsets are the ground-coordinate
    normalization offsets (roughly scene centre and terrain height), not
    satellite positions. The ratio is therefore a proxy, not a true camera B/H.
    It also grows without bound for terrain near height 0. Confirm this is intended.

    Parameters
    ----------
    img1: str
        file name of the first image
    img2: str
        file name of the second image
    utm_epsg: str
        EPSG code for the optimal UTM zone, e.g. "EPSG:32601"

    Returns
    ----------
    b_h_ratio: float
        baseline / mean height, as described above

    Raises
    ----------
    ValueError
        if either image has no RPC metadata
    """
    points, heights = [], []
    for img in (img1, img2):
        with rio.open(img) as src:
            rpcs = src.rpcs
        if not rpcs:
            raise ValueError(f"Image does not contain RPC metadata: {img}")
        points.append(Point(rpcs.long_off, rpcs.lat_off))
        heights.append(rpcs.height_off)

    # Reproject both points at once to UTM for a metric distance
    utm_points = gpd.GeoSeries(points, crs="EPSG:4326").to_crs(utm_epsg)
    # Use .x / .y explicitly (the previous code read the x coordinate twice)
    xy = np.array([[pt.x, pt.y] for pt in utm_points])

    baseline = np.linalg.norm(xy[0] - xy[1])
    h_mean = np.nanmean(np.array(heights, dtype=float))
    return float(baseline / h_mean)


def rpc_image_latlon_bounds(
        img_fn: str = None, 
        height: float = 0.0
        ) -> tuple[float, float, float, float]:
    """
    Get the bounding box (min lon, min lat, max lon, max lat) of an image with RPC metadata.

    The four corner pixels (pixel centres) are projected to the ground at a
    constant ``height`` with the image's RPC model. Their extrema form the box.

    Parameters
    ----------
    img_fn: str
        Path to image file with RPCs.
    height: float
        Ground height in meters used for projection.

    Returns
    ----------
    tuple: [min_lon, min_lat, max_lon, max_lat]

    Raises
    ----------
    ValueError
        if the image has no RPC metadata
    """
    with rio.open(img_fn) as src:
        if not src.rpcs:
            raise ValueError("Image does not contain RPC metadata.")
        rpcs = src.rpcs
        width = src.width
        height_px = src.height

    # Image corners: top-left, top-right, bottom-right, bottom-left
    rows = [0, 0, height_px - 1, height_px - 1]
    cols = [0, width - 1, width - 1, 0]
    zs = [height] * len(rows)

    # The transformer holds GDAL resources, so close it via the context manager.
    # Note: TransformerBase.xy takes (rows, cols), in that order.
    with rio.transform.RPCTransformer(rpcs) as transformer:
        lons, lats = transformer.xy(rows, cols, zs)

    min_lon = float(np.min(lons))
    max_lon = float(np.max(lons))
    min_lat = float(np.min(lats))
    max_lat = float(np.max(lats))

    return min_lon, min_lat, max_lon, max_lat
    

def find_matching_camera_file(
        image_fn: str = None, 
        cam_folder: str = None
        ) -> str:
    """
    Find the camera file in ``cam_folder`` matching the image file's unique identifier.

    The identifier is ``<date>_<time>_<sat/cam id>_<frame id>``, taken from the
    image *basename*. It is matched against camera file *basenames*, so directory
    names never produce false matches. Candidates are ``*.tsai`` files (sorted),
    followed by ``*.TXT`` files (sorted). When several match, the first is returned.

    Parameters
    ----------
    image_fn: str
        file name of the image
    cam_folder: str
        folder containing camera files

    Returns
    ----------
    matched_fn: str
        file name of the matching camera file

    Raises
    ----------
    ValueError
        if no identifier can be extracted from the image name, or no camera matches
    """
    # File naming convention SkySatScenes (https://developers.planet.com/docs/data/skysat/): 
    # <acquisition date>_<acquisition time>_<satellite_id><camera_id>_<frame_id>_<bandProduct>
    match = re.search(r"\d{8}_\d{6}_[a-zA-Z0-9]+_\d{4}", os.path.basename(image_fn))
    if match is None:
        raise ValueError(f"Could not extract identifier from image: {image_fn}")
    identifier = match.group(0)

    # Sort each glob so "first match" is reproducible; keep .tsai ahead of .TXT
    cam_list = (sorted(glob(os.path.join(cam_folder, "*.tsai")))
                + sorted(glob(os.path.join(cam_folder, '*.TXT'))))
    matched_fns = [f for f in cam_list if identifier in os.path.basename(f)]

    # ideally, only one match, otherwise it's ambiguous
    if not matched_fns:
        raise ValueError(f"No matching camera file found for image: {image_fn}")
    if len(matched_fns) > 1:
        print(f"Multiple matching camera files found for image: {image_fn}. " 
              "Returning the first one.")

    return matched_fns[0]


def setup_parallel_jobs(
        total_jobs: int = None,
        verbose: bool = True
        ) -> tuple[int, int]:
    """
    Choose how many jobs to run concurrently and how many threads each gets.

    A single job (or fewer) runs alone. Up to 10 jobs run at most 2 at a time,
    and larger batches run at most 4 at a time. The machine's CPU count is
    then divided evenly among the concurrent jobs, with at least 1 thread each.

    Parameters
    ----------
    total_jobs: int
        The total number of jobs to run (e.g., number of stereo pairs).
    verbose: bool
        Print the chosen configuration.

    Returns
    -------
    njobs: int
        Number of parallel processes to run.
    threads_per_job: int
        Number of threads to allocate per process.
    """
    if total_jobs <= 1:
        njobs = 1
    else:
        job_cap = 2 if total_jobs <= 10 else 4
        njobs = min(job_cap, total_jobs)

    threads_per_job = max(1, multiprocessing.cpu_count() // njobs)

    if verbose:
        print(f"Will run {total_jobs} jobs across {njobs} CPU with {threads_per_job} threads per CPU")

    return njobs, threads_per_job


def run_mapproject(
        img_list: list = None, 
        cam_folder: str = None, 
        ba_prefix: str = None,
        out_folder: str = None, 
        dem: str = 'WGS84', 
        t_res: float = None, 
        t_crs: str = None, 
        session: str = None, 
        orthomosaic: bool = False
        ) -> None:
    """
    Mapproject images onto a reference DEM with ASP ``mapproject``, and optionally mosaic the results.

    Parameters
    ----------
    img_list: list of str
        list of image file names
    cam_folder: str
        folder containing camera files (matched per image with ``find_matching_camera_file``)
    ba_prefix: str
        bundle adjust prefix used to grab cameras or adjustments
    out_folder: str
        path to the folder where mapprojected images and the log will be saved
    dem: str (default="WGS84")
        reference DEM (or datum name) passed as mapproject's first positional argument
    t_res: float | str
        target spatial resolution of the mapprojected images (``--tr``), also used
        as the orthomosaic pixel size
    t_crs: str
        target coordinate reference system of the mapprojected images (e.g., "EPSG:4326")
    session: str
        ASP session type (e.g., "pinhole"). Usually, ASP can determine this automatically based on the inputs. 
    orthomosaic: bool
        whether to merge the mapprojected images into one file with
        ``gdal_merge.py``, named ``<datetimes>_orthomosaic.tif`` after the
        unique acquisition datetimes of the inputs

    Returns
    ----------
    None
    """
    os.makedirs(out_folder, exist_ok=True)
    if not img_list:
        print("No images provided to mapproject; nothing to do.")
        return

    # Image-specific arguments: output file names and cameras
    frames_list = [os.path.splitext(os.path.basename(img))[0] for img in img_list]
    out_list = [os.path.join(out_folder, frame + '.tif') for frame in frames_list]
    cam_list = [find_matching_camera_file(img, cam_folder) for img in img_list]

    # Each image is mapprojected as a single tile, so one thread per image job is
    # enough; parallelism comes from running 12 image jobs at once.
    ncpu, threads_per_job = 12, 1

    map_opts = [
        '--threads', str(threads_per_job),
        # limit to integer values, with 0 as no-data
        '--nodata-value', '0',
        '--ot', 'UInt16'
        ]
    if t_res:
        map_opts += ['--tr', str(t_res)]
    if t_crs:
        map_opts += ['--t_srs', str(t_crs)]
    if session:
        map_opts += ['--session', session]
    if ba_prefix:
        map_opts += ['--bundle-adjust-prefix', ba_prefix]

    jobs_list = [map_opts + [dem, img, cam, out]
                 for img, cam, out in zip(img_list, cam_list, out_list)]
    print('\nmapproject arguments for first job:')
    print(jobs_list[0])

    log_list = p_map(run_cmd, ['mapproject'] * len(jobs_list), jobs_list, num_cpus=ncpu)

    # Save compiled ortho log
    ortho_log = os.path.join(out_folder, 'ortho.log')
    print("Saving compiled orthorectification log at {}".format(ortho_log))
    with open(ortho_log, 'w') as f:
        for log in log_list:
            f.write(log + '\n')

    if orthomosaic:
        print("\nCreating orthomosaic")
        # Unique image datetimes. sorted(set(...)) gives a stable, reproducible
        # prefix; the previous set(sorted(...)) discarded the sort order.
        dt_list = sorted({'_'.join(os.path.basename(im).split('_')[0:2]) for im in out_list})
        mos_prefix = '__'.join(dt_list)
        mosaic_fn = os.path.join(out_folder, f'{mos_prefix}_orthomosaic.tif')

        # Only merge outputs mapproject actually produced (failed jobs write nothing)
        existing_outs = [fn for fn in out_list if os.path.exists(fn)]
        if not existing_outs:
            print("No mapprojected images were produced; skipping orthomosaic.")
            return

        run_gdal_merge(
            img_list=existing_outs,
            mosaic_fn=mosaic_fn,
            t_res=t_res
            )


def identify_overlapping_image_pairs(
        img_list: list = None, 
        overlap_perc: float = 10, 
        bh_ratio_range: tuple = None,
        true_stereo: bool = True,
        utm_epsg: str = None,
        out_folder: str = None,
        write_basename: bool = False
        ) -> str:
    """
    Find image pairs with overlapping footprints and write them to a text file.

    Parameters
    ----------
    img_list: list of str
        image file names. The first two ``_``-separated fields of each basename
        (``<date>_<time>`` in SkySat naming) are used as the acquisition datetime.
    overlap_perc: float
        minimum overlap, as a percentage of the smaller footprint's area
    bh_ratio_range: tuple (min, max), optional
        keep only pairs whose baseline-to-height ratio (see
        ``calculate_baseline_to_height_ratio``) lies within this range
    true_stereo: bool
        if True, drop pairs whose two images share the same acquisition datetime
    utm_epsg: str
        projected CRS (e.g. "EPSG:32611") used for footprint areas and distances
    out_folder: str
        folder for the output file ``overlapping_image_pairs.txt``
    write_basename: bool
        write image basenames instead of full paths

    Returns
    ----------
    out_fn: str
        path to the space-delimited output file with header
        ``img1 img2 datetime_identifier overlap_percent bh_ratio``
    """
    os.makedirs(out_folder, exist_ok=True)

    def acquisition_datetime(img_fn):
        return '_'.join(os.path.basename(img_fn).split('_')[0:2])

    def get_image_polygon(img_fn):
        # Read CRS and bounds in one open and release the file handle right away
        with rxr.open_rasterio(img_fn) as da:
            crs = da.rio.crs
            bounds = da.rio.bounds() if crs else None
        if not crs:
            # No CRS: the image is likely raw / ungeoregistered, so estimate its footprint from the RPCs
            bounds = rpc_image_latlon_bounds(img_fn)
            crs = "EPSG:4326"
        min_x, min_y, max_x, max_y = bounds
        bounds_poly = Polygon([[min_x, min_y], [max_x, min_y],
                               [max_x, max_y], [min_x, max_y],
                               [min_x, min_y]])
        # make sure bounds are in UTM projection
        bounds_gdf = gpd.GeoDataFrame(index=[0], geometry=[bounds_poly], crs=crs).to_crs(utm_epsg)
        return bounds_gdf.geometry[0]

    polygons = {img: get_image_polygon(img) for img in img_list}

    print('Identifying stereo image pairs...')
    print('Requirements:')
    print(f'\toverlap >= {overlap_perc} %')
    if bh_ratio_range:
        print(f'\tbaseline to height ratio = {bh_ratio_range[0]} to {bh_ratio_range[1]}')
    print(f'\ttrue stereo = {true_stereo}')

    pairs = []  # (img1, img2, overlap_percent, bh_ratio)
    n = len(img_list)
    total = n * (n - 1) // 2
    for img1, img2 in tqdm(itertools.combinations(img_list, 2), total=total):
        poly1, poly2 = polygons[img1], polygons[img2]
        intersection = poly1.intersection(poly2)
        if intersection.is_empty:
            continue
        overlap_percent = intersection.area / min(poly1.area, poly2.area) * 100
        if overlap_percent < overlap_perc:
            continue
        # Cheap datetime check first: the B/H ratio needs to open both images
        if true_stereo and acquisition_datetime(img1) == acquisition_datetime(img2):
            continue
        bh_ratio = calculate_baseline_to_height_ratio(img1, img2, utm_epsg)
        if bh_ratio_range and (bh_ratio < bh_ratio_range[0] or bh_ratio > bh_ratio_range[1]):
            continue
        pairs.append((img1, img2, overlap_percent, bh_ratio))
    print('Number of overlapping stereo pairs identified =', len(pairs))

    # Write header and all pairs in one pass
    out_fn = os.path.join(out_folder, 'overlapping_image_pairs.txt')
    if pairs:
        print('\nWriting image pairs with basename only.' if write_basename
              else 'Writing image pairs with full path name.')
    with open(out_fn, 'w') as f:
        f.write("img1 img2 datetime_identifier overlap_percent bh_ratio\n")
        for img1, img2, overlap_percent, bh_ratio in pairs:
            dt_text = acquisition_datetime(img1) + '__' + acquisition_datetime(img2)
            if write_basename:
                name1, name2 = os.path.basename(img1), os.path.basename(img2)
            else:
                name1, name2 = img1, img2
            f.write(f"{name1} {name2} {dt_text} {overlap_percent} {bh_ratio}\n")

    print('Overlapping stereo pairs saved to file:', out_fn)

    return out_fn


def _group_summary(df, imgs):
    """Build the JSON-serializable summary of one image group."""
    imgs = list(imgs)
    # Statistics only over pairs whose two images are both in the group
    df_group = df[df['img1'].isin(imgs) & df['img2'].isin(imgs)]
    if df_group.empty:
        overlap_mean, bh_mean = 0.0, 0.0
    else:
        overlap_mean = float(df_group['overlap_percent'].mean())
        bh_mean = float(df_group['bh_ratio'].mean())
    return {
        'file_names': imgs,
        'length': len(imgs),
        'overlap_percent_mean': overlap_mean,
        'bh_ratio_mean': bh_mean,
    }


def _split_group(df, files, max_group_size):
    """Split one oversized group into overlap-connected subgroups of at most max_group_size images."""
    # Adjacency map restricted to this group: image -> {overlapping image: overlap_percent}
    adjacency = {f: {} for f in files}
    df_sub = df[df['img1'].isin(files) & df['img2'].isin(files)]
    for a, b, weight in zip(df_sub['img1'], df_sub['img2'], df_sub['overlap_percent']):
        if a != b:
            adjacency[a][b] = weight
            adjacency[b][a] = weight

    subgroups = []
    placed = set()
    remaining = set(files)
    while remaining:
        if len(remaining) <= max_group_size:
            subgroups.append(sorted(remaining))
            break
        # Most-connected image first; scanning a sorted list makes ties deterministic
        node = max(sorted(remaining), key=lambda n: len(adjacency[n]))
        neighbors = sorted(
            (n for n in adjacency[node] if n in remaining),
            key=lambda n: adjacency[node][n],
            reverse=True,
        )
        sub_group = [node] + neighbors[:max_group_size - 1]
        remaining -= set(sub_group)
        if len(sub_group) > 1:
            # Keep the last image available so the next subgroup can share it.
            # A lone node is never re-added: that would leave `remaining`
            # unchanged and loop forever.
            remaining.add(sub_group[-1])
        elif node in placed:
            # A shared connector left without unplaced neighbours already has a subgroup
            continue
        subgroups.append(sub_group)
        placed.update(sub_group)
    return subgroups


def create_triplets_pairs(overlap_txt, output_folder, high_overlap_thresh=65, max_group_size=4):
    """
    Create overlapping image groups (for bundle adjust) and stereo pairs from overlap data.

    Steps:
      1. Seed groups from pairs with ``overlap_percent > high_overlap_thresh``,
         taken in decreasing order of overlap, using only pairs of not-yet-assigned
         images. Each seed is extended with the not-yet-assigned image that best
         overlaps either member.
      2. Attach every remaining image to the group of its best-overlapping
         partner. If no partner is grouped yet, pair it with its best partner.
      3. Split groups larger than ``max_group_size`` into overlap-connected
         subgroups. Consecutive subgroups may share one connecting image, so after
         splitting an image can appear in more than one group.
      4. Emit every overlapping pair found within a group as a stereo pair.

    Parameters
    ----------
    overlap_txt: str
        space-delimited file with header
        ``img1 img2 datetime_identifier overlap_percent bh_ratio``
        (as written by ``identify_overlapping_image_pairs``)
    output_folder: str
        folder for the output files (created if needed)
    high_overlap_thresh: float
        overlap percentage a pair must exceed to seed a group in step 1
    max_group_size: int
        maximum number of images per group after splitting (>= 1)

    Returns
    ----------
    triplets_json: str
        JSON file mapping ``group_<k>`` to ``file_names``, ``length``,
        ``overlap_percent_mean`` and ``bh_ratio_mean``. The means are taken over
        the pairs inside the group (0.0 if there are none).
    stereo_pairs_txt: str
        space-delimited file with the overlap-file rows selected as stereo pairs

    Raises
    ----------
    ValueError
        if ``max_group_size`` < 1
    """
    if max_group_size < 1:
        raise ValueError("max_group_size must be >= 1")
    os.makedirs(output_folder, exist_ok=True)

    df = pd.read_csv(overlap_txt, sep=' ', header=0)
    image_set = set(df[['img1', 'img2']].values.flatten())
    print(f"Total images: {len(image_set)}")

    groups = []      # image-name lists, in creation order
    assigned = set()

    # --- 1. Seed groups from high-overlap pairs ---
    df_high = df[df['overlap_percent'] > high_overlap_thresh].sort_values(by='overlap_percent', ascending=False)
    for _, row in df_high.iterrows():
        pair = [row['img1'], row['img2']]
        if pair[0] in assigned or pair[1] in assigned:
            continue  # skip already used
        # Rows linking exactly one image of the pair to another image (xor). The old
        # filter "(a | b) & ~a & ~b" was always empty, so no third image was ever added.
        linked = df['img1'].isin(pair) ^ df['img2'].isin(pair)
        group_imgs = list(pair)
        for _, row3 in df[linked].sort_values(by='overlap_percent', ascending=False).iterrows():
            img3 = row3['img2'] if row3['img1'] in pair else row3['img1']
            if img3 not in assigned:
                group_imgs.append(img3)
                break
        groups.append(group_imgs)
        assigned.update(group_imgs)

    unassigned = image_set - assigned
    print(f"After high-overlap grouping: {len(assigned)} assigned, {len(unassigned)} unassigned.")

    # --- 2. Assign remaining images to the best overlapping group ---
    while unassigned:
        ref_img = min(unassigned)  # deterministic (set iteration order varies between runs)
        df_img = df[(df['img1'] == ref_img) | (df['img2'] == ref_img)].sort_values(
            by='overlap_percent', ascending=False)
        partners = [b if a == ref_img else a for a, b in zip(df_img['img1'], df_img['img2'])]
        if not partners:
            # no overlaps: the image forms its own group
            groups.append([ref_img])
        else:
            target = next((g for p in partners for g in groups if p in g), None)
            if target is not None:
                target.append(ref_img)
            else:
                # no overlapping image is grouped yet: start a new pair
                groups.append([ref_img, partners[0]])
                assigned.add(partners[0])
        assigned.add(ref_img)
        unassigned = image_set - assigned

    print(f"After assignment: {len(assigned)} assigned (should equal total images).")

    # --- 3. Split large groups into overlap-preserving subgroups ---
    final_groups = []
    for files in groups:
        if len(files) <= max_group_size:
            final_groups.append(files)
        else:
            final_groups.extend(_split_group(df, files, max_group_size))
    # Summaries are computed on the final membership, so appended images are included
    groups_dict = {f'group_{k}': _group_summary(df, imgs) for k, imgs in enumerate(final_groups)}
    print(f"Final groups after splitting: {len(groups_dict)}")

    # --- 4. Stereo pairs: every overlap row whose two images share a group ---
    rows_by_pair = {}
    for idx, a, b in zip(df.index, df['img1'], df['img2']):
        rows_by_pair.setdefault(frozenset((a, b)), []).append(idx)
    selected = [
        idx
        for imgs in final_groups
        for a, b in itertools.combinations(imgs, 2)
        for idx in rows_by_pair.get(frozenset((a, b)), [])
    ]
    # df.loc[[]] yields an empty frame with the same columns (no concat-of-nothing error)
    stereo_pairs = df.loc[selected].drop_duplicates(subset=['img1', 'img2']).reset_index(drop=True)

    # --- Save outputs ---
    # JSON accommodates groups of different sizes
    triplets_json = os.path.join(output_folder, 'bundle_adjust_triplets.json')
    with open(triplets_json, 'w') as f:
        json.dump(groups_dict, f, indent=2)
    print(f"Saved {len(groups_dict)} triplets to file:", triplets_json)

    stereo_pairs_txt = os.path.join(output_folder, 'stereo_pairs.txt')
    stereo_pairs.to_csv(stereo_pairs_txt, sep=' ', header=True, index=False)
    print(f"Saved {len(stereo_pairs)} pairs to file:", stereo_pairs_txt)

    return triplets_json, stereo_pairs_txt


def get_stereo_opts(
        session: str = None, 
        threads: int = None, 
        texture: str = 'normal', 
        stop_point: int = -1, 
        unalign_disparity: bool = False
        ):
    """
    Get the stereo options for the ASP parallel_stereo command.

    Parameters
    ----------
    session: str (default=None)
        The session type to use for stereo matching. Options include 'rpc', 'pinhole', etc. 
        ASP can often figure this out automatically. 
    threads: int (default=None)
        Number of threads for both multi- and single-process steps. If None, uses
        ``os.cpu_count()`` (or 1 if that is unknown).
    texture: str (default='normal')
        This is used for determining the correlation and refinement kernel. Options = "low", "normal".
        Any other value adds no kernel options, leaving ASP's defaults.
    stop_point: int (default=-1)
        Stopping point for stereo processing. Set to -1 to run all steps. Useful if only creating feature matches 
        or running image correlation, for example. See the ASP docs on stereo entry points for more information: 
        https://stereopipeline.readthedocs.io/en/latest/tools/parallel_stereo.html#entrypoints
    unalign_disparity: bool (default=False)
        Whether to generate disparity maps without alignment. This can be used for debugging or testing purposes.
    
    Returns
    ----------
    stereo_opt: list
        A list of command line options for the ASP parallel_stereo command.
    """
    # Without this, str(None) would pass the literal 'None' as a thread count
    if threads is None:
        threads = os.cpu_count() or 1

    stereo_opts = []
    # session_args
    if session:
        stereo_opts.extend(['-t', session])
    stereo_opts.extend(['--threads-multiprocess', str(threads)])
    stereo_opts.extend(['--threads-singleprocess', str(threads)])
    # stereo_pprc args : This is for preprocessing (adjusting image dynamic range, 
    # alignment using ip matches, etc.)
    stereo_opts.extend(['--individually-normalize'])
    stereo_opts.extend(['--ip-per-tile', '8000'])
    stereo_opts.extend(['--ip-num-ransac-iterations', '2000'])
    stereo_opts.extend(['--force-reuse-match-files'])
    stereo_opts.extend(['--skip-rough-homography'])
    stereo_opts.extend(['--alignment-method', 'Affineepipolar'])
    # intended to mask out completely featureless areas using a std filter, to avoid gross MGM errors
    # this is experimental and needs more testing
    stereo_opts.extend(['--stddev-mask-thresh', '0.5'])
    stereo_opts.extend(['--stddev-mask-kernel', '-1'])
    # stereo_corr_args
    stereo_opts.extend(['--stereo-algorithm', 'asp_mgm'])
    # correlation kernel size depends on the texture
    if texture == 'low':
        stereo_opts.extend(['--corr-kernel', '9', '9'])
    elif texture == 'normal':
        stereo_opts.extend(['--corr-kernel', '7', '7'])
    stereo_opts.extend(['--corr-tile-size', '1024'])
    stereo_opts.extend(['--cost-mode', '4'])
    stereo_opts.extend(['--corr-max-levels', '5'])
    # stereo_rfne_args:
    stereo_opts.extend(['--subpixel-mode', '9'])
    if texture == 'low':
        stereo_opts.extend(['--subpixel-kernel', '21', '21'])
    elif texture == 'normal':
        stereo_opts.extend(['--subpixel-kernel', '15', '15'])
    stereo_opts.extend(['--xcorr-threshold', '2'])
    stereo_opts.extend(['--num-matches-from-disparity', '10000'])
    # add stopping point if specified
    if stop_point != -1:
        stereo_opts.extend(['--stop-point', str(stop_point)])
    # get the disparity map without any alignment
    if unalign_disparity:
        stereo_opts.extend(['--unalign-disparity'])

    return stereo_opts


def run_stereo(
        stereo_pairs_fn: str = None, 
        cam_folder: str = None, 
        dem_file: str = None,
        out_folder: str = None, 
        session: str = None,
        texture: str = 'normal', 
        stop_point: int = -1,
        verbose: bool = True
        ) -> None:
    """
    Execute stereo matching for SkySat images using the ASP parallel_stereo command.

    Each pair is written under ``<out_folder>/<datetime_identifier>/<IMG1>__<IMG2>/run``.

    Parameters
    ----------
    stereo_pairs_fn: str (default=None)
        Path to the space-delimited stereo-pairs file. Needs the columns ``img1``,
        ``img2`` and ``datetime_identifier``.
    cam_folder: str (default=None)
        Path to the folder containing camera files. Required if using 'pinhole' session.
        If omitted, only the images are passed to parallel_stereo.
    dem_file: str (default=None)
        If given, appended as the last positional argument of every
        parallel_stereo job (presumably the DEM the inputs were mapprojected onto).
    out_folder: str
        Path to the folder where the output stereo results will be saved.
    session: str (default=None)
        The session type to use for stereo matching. Options include 'rpc', 'pinhole', etc. ASP can often figure this out automatically.
    texture: str (default='normal')
        How much relative texture there is in your images. This is used for determining the correlation and refinement kernel. 
        Options = "low", "normal". For example, a flat, snowy landscape likely has "low" texture. 
    stop_point: int (default=-1)
        parallel_stereo stopping point; -1 runs all steps (see ``get_stereo_opts``).
    verbose: bool (default=True)
        Print the parallelization setup, the first job's arguments and the log location.

    Returns
    ----------
    None
    """
    os.makedirs(out_folder, exist_ok=True)

    stereo_pairs_df = pd.read_csv(stereo_pairs_fn, sep=' ', header=0)
    if stereo_pairs_df.empty:
        print(f"No stereo pairs found in {stereo_pairs_fn}; nothing to run.")
        return

    # Number of parallel jobs and threads per job
    ncpu, threads_per_job = setup_parallel_jobs(total_jobs=len(stereo_pairs_df), verbose=verbose)

    stereo_opts = get_stereo_opts(
        session=session, 
        threads=threads_per_job, 
        texture=texture, 
        stop_point=stop_point
        )

    # One parallel_stereo job per pair
    job_list = []
    for _, row in stereo_pairs_df.iterrows():
        img1, img2 = row['img1'], row['img2']
        name1 = os.path.splitext(os.path.basename(img1))[0]
        name2 = os.path.splitext(os.path.basename(img2))[0]
        out_prefix = os.path.join(out_folder, row['datetime_identifier'], name1 + '__' + name2, 'run')

        if cam_folder:
            cam1 = find_matching_camera_file(img1, cam_folder)
            cam2 = find_matching_camera_file(img2, cam_folder)
            job = stereo_opts + [img1, cam1, img2, cam2, out_prefix]
        else:
            job = stereo_opts + [img1, img2, out_prefix]
        # DEM goes last
        if dem_file:
            job.append(dem_file)
        job_list.append(job)

    if verbose:
        print('stereo arguments for first job:')
        print(job_list[0])

    stereo_logs = p_map(run_cmd, ['parallel_stereo'] * len(job_list), job_list, num_cpus=ncpu)

    # Save the consolidated log
    stereo_log_fn = os.path.join(out_folder, 'stereo_log.log')
    with open(stereo_log_fn, 'w') as f:
        for log in stereo_logs:
            f.write(log + '\n')
    if verbose:
        print("Consolidated stereo log saved at {}".format(stereo_log_fn))

    return


def copy_match_files(matches_folder, image_files, output_prefix, verbose=True):
    """Copy ASP ``.match`` files for pairs of the given images next to a bundle-adjust prefix.

    Match files are searched for directly in ``matches_folder``. If none are
    found there, the search goes one, then two, sub-directory levels deeper,
    stopping at the first level that yields any files. Each match file's parent
    directory name is expected to be ``<IMG1>__<IMG2>``. Only pairs whose two
    names both appear in ``image_files`` are kept. Each kept file is copied to
    ``<output_prefix>-<IMG1>__<IMG2>.match``.

    Parameters
    ----------
    matches_folder : str
        Root folder to search for ``*.match`` files.
    image_files : list of str
        Image (or camera) paths. Their base names, without extension and with
        any ``run-`` text removed, define which pairs are kept.
    output_prefix : str
        Bundle-adjust output prefix. Its parent folder is created if missing.
    verbose : bool
        Print how many match files are copied.

    Returns
    -------
    list of str
        Paths of the copied match files (empty if nothing matched).
    """
    # Search the folder itself, then up to two sub-directory levels deeper,
    # stopping at the first level that contains any match files.
    match_list = []
    for depth in range(3):
        pattern = os.path.join(matches_folder, *(['*'] * depth), '*.match')
        match_list = sorted(glob(pattern))
        if match_list:
            break

    # Set for O(1) membership; strip 'run-' so ASP-prefixed camera names match.
    image_names = {
        os.path.splitext(os.path.basename(x))[0].replace('run-', '')
        for x in image_files
    }

    def _pair_name(match_file):
        # Name of the folder holding the match file, e.g. 'IMG1__IMG2'.
        return os.path.basename(os.path.dirname(match_file))

    selected = []
    for match_file in match_list:
        parts = _pair_name(match_file).split('__')
        # Folders not named '<IMG1>__<IMG2>' cannot be attributed to a pair.
        if len(parts) < 2:
            continue
        if parts[0] in image_names and parts[1] in image_names:
            selected.append(match_file)

    if verbose:
        print(f'Copying {len(selected)} matches to output folder')

    # Make sure the destination folder exists before copying into it.
    output_folder = os.path.dirname(output_prefix)
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    copied = []
    for match_file in selected:
        match_out_file = output_prefix + '-' + _pair_name(match_file) + '.match'
        copied.append(shutil.copy2(match_file, match_out_file))
    return copied


def reduce_asp_match_file(
        match_file: str = None, 
        out_file: str = None,
        num_pts: int = 100
        ):
    """Subsample an ASP binary match file to at most ``num_pts`` matches.

    Steps:
    1. Convert the binary file to text with ``parse_match_file.py``, writing
       ``<match stem>_match.txt``. An existing text copy is reused.
    2. Keep every k-th match, choosing k so that no more than ``num_pts``
       matches remain. The reduced text is written to
       ``<match stem>_match_reduced.txt``.
    3. Convert the reduced text back to binary with
       ``parse_match_file.py -rev``.

    The text format read here is a header line ``"<n1> <n2>"``, followed by
    ``n1`` lines for image 1 and then ``n2`` lines for image 2.

    Parameters
    ----------
    match_file : str
        Input binary ``.match`` file.
    out_file : str
        Output binary ``.match`` file for the reduced matches.
    num_pts : int
        Maximum number of matches to keep (must be >= 1).

    Raises
    ------
    ValueError
        If ``num_pts`` is not a positive integer.

    Notes
    -----
    Problems with the input (failed conversion, empty, malformed or incomplete
    text) are reported with ``print``. In those cases the function returns
    without writing ``out_file``.
    """
    if num_pts is None or num_pts < 1:
        raise ValueError(f"num_pts must be a positive integer, got {num_pts!r}")

    # --- Convert match file from binary -> text ---
    # splitext only touches the extension; str.replace could also rewrite
    # '.match' occurring in a directory name.
    match_txt = os.path.splitext(match_file)[0] + '_match.txt'
    if not os.path.exists(match_txt):
        run_cmd('parse_match_file.py', [match_file, match_txt])
    if not os.path.exists(match_txt):
        print(f"Could not convert {match_file} to text ({match_txt} missing). Exiting.")
        return

    # --- Read the text match file, ignoring blank lines ---
    with open(match_txt, 'r') as f:
        lines = [l.strip() for l in f if l.strip()]

    if not lines:
        print(f"Empty match file: {match_txt}. Exiting.")
        return

    # Header: number of match points listed for each image
    try:
        n1, n2 = (int(tok) for tok in lines[0].split())
    except ValueError:
        print(f"Invalid header in {match_txt}: {lines[0]}. Exiting.")
        return

    img1_matches = lines[1:1 + n1]
    img2_matches = lines[1 + n1:1 + n1 + n2]

    if not img1_matches or not img2_matches:
        print(f"Incomplete match file: {match_txt}. Exiting.")
        return

    # --- Subsample ---
    # Ceiling division ensures at most num_pts matches remain. Floor division
    # (the previous behaviour) could keep close to 2 * num_pts.
    step = max(1, (n1 + num_pts - 1) // num_pts)
    img1_matches = img1_matches[::step]
    img2_matches = img2_matches[::step]

    header_line = f"{len(img1_matches)} {len(img2_matches)}"

    # --- Write reduced match text ---
    match_txt_reduced = os.path.splitext(match_txt)[0] + '_reduced.txt'
    with open(match_txt_reduced, 'w') as f:
        f.write(header_line + '\n')
        f.write('\n'.join(img1_matches) + '\n')
        f.write('\n'.join(img2_matches) + '\n')

    # --- Convert reduced text -> binary ---
    run_cmd('parse_match_file.py', ['-rev', match_txt_reduced, out_file])

    return


def run_ba(
        image_list: list[str] = None, 
        cam_folder: str = None, 
        output_prefix: str = None,
        refdem_file: str = None, 
        refdem_uncertainty: float = 5, 
        skip_matching: bool = False,
        threads_string: str = "all",
        verbose: bool = True
        ):
    """Run ASP ``parallel_bundle_adjust`` on a set of images and their cameras.

    Cameras are looked up per image in ``cam_folder`` with
    ``find_matching_camera_file``. The command's output is saved to
    ``<output_prefix>-parallel_bundle_adjust_<YYYYmmddHHMMSSffffff>.log``.

    Parameters
    ----------
    image_list : list of str
        Images to adjust.
    cam_folder : str
        Folder containing the input camera for each image.
    output_prefix : str
        Bundle-adjust output prefix (``-o``). Its folder is created if missing.
    refdem_file : str, optional
        If given, passed as ``--heights-from-dem``.
    refdem_uncertainty : float
        Value for ``--heights-from-dem-uncertainty`` (used only with ``refdem_file``).
    skip_matching : bool
        Add ``--force-reuse-match-files`` and ``--skip-matching``.
    threads_string : str or int
        ``"all"`` to use every CPU, otherwise a thread count.
    verbose : bool
        Print progress messages.
    """
    output_folder = os.path.dirname(output_prefix)
    # A bare prefix such as 'run' has no folder part; os.makedirs('') would raise.
    if output_folder:
        os.makedirs(output_folder, exist_ok=True)

    # Determine how many threads to use (os.cpu_count() may return None)
    if threads_string == "all":
        threads = os.cpu_count() or 1
    else:
        threads = int(threads_string)
    if verbose:
        print(f'Using up to {threads} threads for each process.')

    # Get the camera for each image
    cam_list = [find_matching_camera_file(x, cam_folder) for x in image_list]

    # Construct the arguments
    args = [
        "--threads", str(threads),
        "--num-iterations", "500",
        "--num-passes", "2",
        "--inline-adjustments",
        "--save-cnet-as-csv",
        "--min-matches", "4",
        "--disable-tri-ip-filter",
        "--ip-per-tile", "4000",
        "--ip-inlier-factor", "0.2",
        "--ip-num-ransac-iterations", "1000",
        "--skip-rough-homography", 
        "--min-triangulation-angle", "0.0001",
        "--remove-outliers-params", "75 3 5 6",
        "--individually-normalize",
        "-o", output_prefix
        ] + image_list + cam_list

    if skip_matching:
        args += ["--force-reuse-match-files", "--skip-matching"]
    if refdem_file:
        args += ["--heights-from-dem", refdem_file]
        args += ["--heights-from-dem-uncertainty", str(refdem_uncertainty)]

    # Run bundle adjust
    log = run_cmd('parallel_bundle_adjust', args)

    # Write log to a timestamped file. strftime always includes microseconds;
    # str(datetime) omits them when they are 0.
    dt_now_string = datetime.now().strftime('%Y%m%d%H%M%S%f')
    log_file = output_prefix + f'-parallel_bundle_adjust_{dt_now_string}.log'
    with open(log_file, 'w') as f:
        f.write(log)

    if verbose:
        print('Saved compiled log to file:', log_file)
        print('Bundle adjust complete.')

    return



## Convert images to single band

In [None]:
# ASP only accepts single-band imagery: extract band 1 of every scene with
# gdal_translate, skipping outputs that already exist from a previous run.
os.makedirs(new_img_folder, exist_ok=True)

print('Saving first band of each image. Only single-band images allowed by ASP.')
for img in tqdm(img_list):
    out_img = os.path.join(new_img_folder, os.path.basename(img))
    if not os.path.exists(out_img):
        band1_args = ["-b", "1", img, out_img]
        run_cmd("gdal_translate", band1_args)

## Generate frame cameras and GCP

In [None]:
os.makedirs(camgen_folder, exist_ok=True)

# Build frame cameras and GCPs for each single-band image, using the
# reference DEM (product level 'l1b', gcp_std=5).
single_band_imgs = sorted(glob(os.path.join(new_img_folder, '*.tif')))
generate_frame_cameras(
    img_list=single_band_imgs,
    dem_file=refdem_file,
    product_level='l1b',
    out_folder=camgen_folder,
    gcp_std=5,
)

## Initial orthorectification

In [None]:
os.makedirs(init_ortho_folder, exist_ok=True)

# Only mapproject images that do not yet have an output in init_ortho_folder
img_list = []
for img_fn in sorted(glob(os.path.join(new_img_folder, '*.tif'))):
    ortho_fn = os.path.join(init_ortho_folder, os.path.basename(img_fn))
    if not os.path.exists(ortho_fn):
        img_list.append(img_fn)

if len(img_list) > 0:
    run_mapproject(
        img_list=img_list,
        cam_folder=camgen_folder,
        out_folder=init_ortho_folder,
        dem=refdem_file,
        t_res=1,
        t_crs="EPSG:32611",
    )

## Incremental bundle adjust

In [None]:
ba_increm_folder = os.path.join(out_folder, 'ba_incremental')
os.makedirs(ba_increm_folder, exist_ok=True)

# First, identify all overlapping image pairs from the initial orthoimages (overlap_perc=40)
identify_overlapping_image_pairs(
    img_list=sorted(glob(os.path.join(init_ortho_folder, '*.tif'))),
    overlap_perc=40,
    utm_epsg="EPSG:32611",
    out_folder=ba_increm_folder,
    true_stereo=False,
)

# Pairs were found on the orthoimages, but the later steps should use the
# single-band originals: point both image columns at new_img_folder.
overlap_txt = os.path.join(ba_increm_folder, 'overlapping_image_pairs.txt')
overlap = pd.read_csv(overlap_txt, sep=' ', header=0)
for col in ('img1', 'img2'):
    overlap[col] = [os.path.join(new_img_folder, os.path.basename(p)) for p in overlap[col]]
overlap.to_csv(overlap_txt, sep=' ', header=True, index=False)

In [None]:
# Helper function to plan the incremental adjustments before running
def plan_incremental_bundle_adjust(
    overlap_txt,
    out_plan_json,
    min_overlap_percent=None,
    top_k=2,
    verbose = True
):
    """Plan the rounds of an incremental bundle adjustment and save them as JSON.

    Round 1 uses the most-overlapping pair: its first image is held fixed as an
    anchor and its second image is adjusted. Each later round adjusts up to
    ``top_k`` not-yet-adjusted images that overlap an already-adjusted image.
    Candidates are ranked by their largest overlap with the adjusted set.
    Every already-adjusted image that overlaps a newly adjusted image is held
    fixed in that round.

    Parameters
    ----------
    overlap_txt : str
        Space-separated table with at least ``img1``, ``img2`` and
        ``overlap_percent`` columns.
    out_plan_json : str
        Output path for the JSON plan. The plan is a list of dicts with keys
        ``round``, ``stereo_pairs`` (list of ``[img1, img2]``),
        ``adjust_images`` and ``fixed_images``.
    min_overlap_percent : float, optional
        If given, pairs with ``overlap_percent`` below this value are ignored.
    top_k : int
        Maximum number of images newly adjusted per round after round 1.
    verbose : bool
        Print per-round details.

    Returns
    -------
    list of dict
        The planned rounds (also written to ``out_plan_json``).

    Raises
    ------
    ValueError
        If ``top_k < 1`` or no overlapping pairs remain after filtering.
    """
    print('\nPlanning incremental bundle adjustment rounds')
    print('--------------------------------------------------')

    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    # Load overlap table, most-overlapping pairs first
    overlap = pd.read_csv(overlap_txt, sep=' ', header=0)
    overlap = overlap.sort_values(by='overlap_percent', ascending=False)

    # Compare against None so that a threshold of 0 is still honoured
    if min_overlap_percent is not None:
        overlap = overlap[overlap['overlap_percent'] >= min_overlap_percent].copy()

    if overlap.empty:
        threshold_msg = (
            f" with overlap_percent >= {min_overlap_percent}"
            if min_overlap_percent is not None else ""
        )
        raise ValueError(f"No overlapping image pairs to plan from in {overlap_txt}{threshold_msg}")

    # Start with the most-overlapping pair
    img1 = overlap.iloc[0]['img1']
    img2 = overlap.iloc[0]['img2']

    # Keep the first image fixed to prevent drift; adjust the second
    adjusted_img_list = [img1]
    fixed_this_round = [img1]
    current_new_imgs = [img2]
    # Images not yet adjusted (only used for membership tests)
    unadjusted = set(overlap[['img1', 'img2']].values.ravel()) - {img1}
    rounds = []
    round_num = 1

    # Main planning loop
    while current_new_imgs:
        # --- Plan the current round ---
        # Stereo pairs: a newly adjusted image paired with another new image or a fixed image
        eligible = set(current_new_imgs + fixed_this_round)
        pairs_this_round = overlap[
            (overlap['img1'].isin(current_new_imgs) & overlap['img2'].isin(eligible))
            | (overlap['img2'].isin(current_new_imgs) & overlap['img1'].isin(eligible))
        ]
        pairs_list = pairs_this_round[['img1', 'img2']].values.tolist()

        rounds.append({
            'round': round_num,
            'stereo_pairs': pairs_list,
            'adjust_images': list(current_new_imgs),
            'fixed_images': list(fixed_this_round),
        })

        # Update global image bookkeeping
        adjusted_img_list.extend(current_new_imgs)
        unadjusted.difference_update(current_new_imgs)

        if verbose:
            print(f"\n Planned Round {round_num}")
            print(f"  Fixed: {len(fixed_this_round)} images")
            print(f"  Adjusting: {len(current_new_imgs)} images -> {current_new_imgs}")
            print(f"  Stereo pairs: {len(pairs_this_round)}")

        # --- Set up the next round ---
        # Candidates must overlap an already adjusted image; score = best overlap with that set
        adjusted_set = set(adjusted_img_list)
        candidates = overlap[
            overlap['img1'].isin(adjusted_set) | overlap['img2'].isin(adjusted_set)
        ]
        scores = {}
        for imgA, imgB, ov in candidates[['img1', 'img2', 'overlap_percent']].itertuples(index=False, name=None):
            if imgA in adjusted_set and imgB in unadjusted:
                scores[imgB] = max(scores.get(imgB, 0), ov)
            elif imgB in adjusted_set and imgA in unadjusted:
                scores[imgA] = max(scores.get(imgA, 0), ov)
        if not scores:
            if verbose:
                print("\nNo more overlapping unadjusted images. Planning complete.")
            break

        # Stable sort: ties keep the order in which candidates were encountered
        sorted_imgs = sorted(scores.items(), key=lambda x: x[1], reverse=True)
        current_new_imgs = [img for img, _ in sorted_imgs[:top_k]]

        # Fix every already adjusted image that overlaps an image about to be adjusted
        overlapping_fixed = overlap[
            (overlap['img1'].isin(current_new_imgs) & overlap['img2'].isin(adjusted_set))
            | (overlap['img2'].isin(current_new_imgs) & overlap['img1'].isin(adjusted_set))
        ]
        fixed_this_round = sorted(
            img for img in set(overlapping_fixed['img1']).union(overlapping_fixed['img2'])
            if img in adjusted_set
        )

        round_num += 1

    # Images in disconnected parts of the overlap graph never enter the plan
    if unadjusted:
        print(f"WARNING: {len(unadjusted)} image(s) do not connect to the planned "
              f"network and are not in the plan: {sorted(unadjusted)}")

    # Save plan as JSON
    with open(out_plan_json, 'w') as f:
        json.dump(rounds, f, indent=2)

    print('Total number of rounds planned:', len(rounds))
    print(f"Saved incremental bundle adjust plan to:\n{out_plan_json}")

    return rounds


# Helper for run_incremental_bundle_adjust: bookkeeping of one round's output cameras
def _register_round_cameras(round_prefix, round_images, fixed_images, adjusted_cams):
    """Record the cameras adjusted in one IBA round and delete the fixed-image copies.

    Output cameras are found with ``<round_prefix>*.tsai``. Each name is
    expected to be the input camera name with one ``<prefix basename>-`` added
    per round the camera has passed through (e.g. ``run-run-IMG.tsai``). The
    input camera name is expected to equal the image stem. All leading
    prefixes are stripped, and the remaining stem is mapped back to the full
    image path used in the plan. As a result, ``adjusted_cams`` keys match
    the plan's image entries.

    Parameters
    ----------
    round_prefix : str
        Bundle-adjust output prefix of the round (e.g. ``.../round01/run``).
    round_images : list of str
        Image paths (as listed in the plan) processed this round.
    fixed_images : list of str
        Subset of ``round_images`` passed via ``--fixed-image-list``.
    adjusted_cams : dict
        Mapping image path -> latest adjusted camera file. Updated in place.

    Returns
    -------
    set of str
        Image paths whose adjusted camera was recorded this round.
    """
    prefix_token = os.path.basename(round_prefix) + '-'
    stem_to_image = {os.path.splitext(os.path.basename(img))[0]: img for img in round_images}
    fixed = set(fixed_images)
    registered = set()
    for cam_file in sorted(glob(round_prefix + '*.tsai')):
        stem = os.path.splitext(os.path.basename(cam_file))[0]
        while stem.startswith(prefix_token):
            stem = stem[len(prefix_token):]
        img = stem_to_image.get(stem)
        if img is None:
            # Not attributable to an image of this round: leave the file alone
            continue
        if img in fixed:
            # Camera of an image held fixed this round: redundant copy, remove it
            os.remove(cam_file)
        else:
            adjusted_cams[img] = cam_file
            registered.add(img)
    return registered


# Helper function to run the incremental adjustment plan
def run_incremental_bundle_adjust(
        ba_plan_json: str = None, 
        overlap_txt: str = None,
        cam_folder: str = None,
        output_folder: str = None,
        reduce_matches: bool = True,
        num_matches: int = 100,
        max_rounds: int | None = None,
        ):
    """Execute an incremental bundle-adjustment plan round by round.

    For each round in ``ba_plan_json`` (as written by
    ``plan_incremental_bundle_adjust``) the function:

    1. Writes the round's stereo pairs (rows of ``overlap_txt``) to
       ``<output_folder>/roundNN/roundNN_overlapping_image_pairs.txt``. It then
       runs ``run_stereo`` with ``stop_point=1`` to produce feature matches.
    2. Copies each match file to ``<round prefix>-<IMG1>__<IMG2>.match``. If
       ``reduce_matches`` is set, the file is subsampled to ``num_matches``
       matches instead.
    3. Runs ``parallel_bundle_adjust`` on the round's images. The round's fixed
       images are held fixed, and cameras adjusted in earlier rounds are reused.
       The log is saved to ``<round prefix>-parallel_bundle_adjust.log``.
    4. Records the newly adjusted cameras and deletes the copies written for
       the fixed images.

    Parameters
    ----------
    ba_plan_json : str
        JSON plan (list of rounds).
    overlap_txt : str
        Space-separated overlapping-pairs table (``img1``, ``img2``, ...) whose
        image entries match those in the plan.
    cam_folder : str
        Folder with the original cameras.
    output_folder : str
        Folder receiving one ``roundNN`` sub-folder per round.
    reduce_matches : bool
        Subsample match files before bundle adjust.
    num_matches : int
        Maximum matches per pair when ``reduce_matches`` is True.
    max_rounds : int, optional
        Stop after this many rounds (useful for testing). ``None`` runs all rounds.
    """
    # Load the incremental BA plan
    with open(ba_plan_json, 'r') as f:
        ba_plan = json.load(f)
    if max_rounds is not None:
        ba_plan = ba_plan[:max_rounds]

    # Load the overlapping image pairs (global)
    overlap = pd.read_csv(overlap_txt, sep=' ', header=0)

    # Bundle adjust base options
    ba_opts = [
        "--threads", str(os.cpu_count() or 1),
        "--num-iterations", "500",
        "--num-passes", "2",
        "--inline-adjustments",
        "--save-cnet-as-csv",
        "--min-matches", "4",
        "--disable-tri-ip-filter",
        "--ip-per-tile", "4000",
        "--ip-inlier-factor", "0.2",
        "--ip-num-ransac-iterations", "1000",
        "--skip-rough-homography",
        "--min-triangulation-angle", "0.0001",
        "--remove-outliers-params", "75 3 5 6",
        "--individually-normalize",
        "--force-reuse-match-files",
        "--skip-matching"
    ]

    # Image path (as in the plan) -> latest adjusted camera file
    adjusted_cams = {}

    print('\nRunning incremental bundle adjust (IBA)')
    print('--------------------------------------------------')
    for round_plan in ba_plan:
        round_num = round_plan['round']
        print(f'\nIBA round {round_num}')

        # --- Set up output folder ---
        round_string = f"{round_num:02d}"
        round_folder = os.path.join(output_folder, f"round{round_string}")
        round_prefix = os.path.join(round_folder, 'run')
        os.makedirs(round_folder, exist_ok=True)

        # --- Stereo preprocessing ---
        # Select the planned pairs regardless of img1/img2 order
        round_pairs = {tuple(sorted(pair)) for pair in round_plan['stereo_pairs']}
        in_round = [
            tuple(sorted((a, b))) in round_pairs
            for a, b in zip(overlap['img1'], overlap['img2'])
        ]
        round_overlap = overlap.loc[in_round].reset_index(drop=True)
        round_overlap_txt = os.path.join(round_folder, f'round{round_string}_overlapping_image_pairs.txt')
        round_overlap.to_csv(round_overlap_txt, sep=' ', index=False)

        print(f'Running stereo preprocessing for {len(round_overlap)} image pairs')
        run_stereo(
            stereo_pairs_fn = round_overlap_txt,
            cam_folder = cam_folder,
            out_folder = round_folder,
            stop_point = 1,
            verbose = False
        )

        # Copy / reduce match files to '<round prefix>-<IMG1>__<IMG2>.match'
        match_files = sorted(glob(os.path.join(round_folder, '*', '*', '*.match')))
        if reduce_matches and match_files:
            print(f'Reducing number of per-pair feature matches to {num_matches}')
        for match_file in match_files:
            pair_name = os.path.basename(os.path.dirname(match_file))
            match_out_file = f"{round_prefix}-{pair_name}.match"
            if reduce_matches:
                reduce_asp_match_file(match_file, match_out_file, num_pts=num_matches)
            else:
                shutil.copy2(match_file, match_out_file)

        # --- Bundle adjust ---
        round_fixed_images = list(round_plan['fixed_images'])
        round_images = list(round_plan['adjust_images']) + round_fixed_images
        # Reuse cameras adjusted in earlier rounds; otherwise start from the originals
        round_cams = [
            adjusted_cams[img] if img in adjusted_cams else find_matching_camera_file(img, cam_folder)
            for img in round_images
        ]

        round_args = ba_opts + round_images + round_cams + ['-o', round_prefix]
        if round_fixed_images:
            round_fixed_images_txt = os.path.join(round_folder, f'round{round_string}_fixed_images.txt')
            with open(round_fixed_images_txt, 'w') as f:
                for img in round_fixed_images:
                    f.write(img + '\n')
            round_args += ['--fixed-image-list', round_fixed_images_txt]

        n_fixed = len(round_fixed_images)
        print(f'Running bundle adjust for {len(round_images) - n_fixed} moveable and {n_fixed} fixed images')
        ba_log = run_cmd('parallel_bundle_adjust', round_args)
        with open(round_prefix + '-parallel_bundle_adjust.log', 'w') as f:
            f.write(ba_log)

        # --- Update adjusted cameras ---
        registered = _register_round_cameras(round_prefix, round_images, round_fixed_images, adjusted_cams)
        missing = [img for img in round_plan['adjust_images'] if img not in registered]
        if missing:
            print(f'WARNING: no adjusted camera found for {len(missing)} image(s); later rounds '
                  f'will use their previous (or original) camera: {missing}')

        print('Done')

    print('\nIBA complete!')

    return


# Plan the incremental bundle adjustment, then execute that plan
overlap_txt = os.path.join(ba_increm_folder, 'overlapping_image_pairs.txt')
ba_plan_json = os.path.join(ba_increm_folder, 'bundle_adjust_incremental_plan.json')
plan_incremental_bundle_adjust(overlap_txt, ba_plan_json, verbose=False)

# Match files are copied unchanged (reduce_matches=False)
iba_kwargs = dict(
    cam_folder=camgen_folder,
    output_folder=ba_increm_folder,
    reduce_matches=False,
)
run_incremental_bundle_adjust(ba_plan_json, overlap_txt, **iba_kwargs)

In [None]:
# Mapproject to check for success (only the first camera folder found is checked)
cam_list = sorted(glob(os.path.join(ba_increm_folder, '*', '*.tsai')))
cam_folder_list = sorted({os.path.dirname(c) for c in cam_list})
for f in cam_folder_list[:1]:
    cam_f_list = [c for c in cam_list if os.path.dirname(c) == f]
    # Recover each image path from its camera name by dropping the 'run-' prefixes
    img_list = sorted(
        os.path.join(new_img_folder, os.path.splitext(os.path.basename(c))[0].replace('run-', '') + '.tif')
        for c in cam_f_list
    )
    run_mapproject(
        img_list=img_list,
        cam_folder=f,
        out_folder=ba_increm_folder,
        dem=refdem_file,
        t_res=1,
        t_crs="EPSG:32611",
    )

----------

In [None]:
##### 1. Triplet bundle adjustment #####

ba_triplet_folder = os.path.join(out_folder, 'ba_triplets')
os.makedirs(ba_triplet_folder, exist_ok=True)

# First, find every overlapping pair of initial orthoimages (overlap_perc=10)
init_ortho_list = sorted(glob(os.path.join(init_ortho_folder, '*.tif')))
identify_overlapping_image_pairs(
    img_list=init_ortho_list,
    overlap_perc=10,
    utm_epsg="EPSG:32611",
    out_folder=ba_triplet_folder,
    true_stereo=False,
)


In [None]:
# Group overlapping images into triplets (at most 6 images per group)
# and list the stereo pairs inside each group
overlap_txt = os.path.join(ba_triplet_folder, 'overlapping_image_pairs.txt')
print('\nIdentifying triplets')
triplets_json, pairs_txt = create_triplets_pairs(overlap_txt, ba_triplet_folder, max_group_size=6)

In [None]:
# Bundle adjust each triplet group independently

# Load the triplets file
with open(triplets_json, 'r') as f:
    triplets = json.load(f)

# PICKING UP AFTER CANCELLING - only run groups that have no adjusted cameras yet
triplets = {
    g: group for g, group in triplets.items()
    if not glob(os.path.join(ba_triplet_folder, g, '*.tsai'))
}

# The stereo-pair table is the same for every group: read it once, outside the loop
pairs = pd.read_csv(pairs_txt, sep=' ', header=0)

# NOTE: the [11:12] slice restricts this run to a single remaining group (testing);
# widen it to process every group.
for g in tqdm(list(triplets.keys())[11:12], desc='Triplets-BA'):
    # Define output folder and prefix
    g_folder = os.path.join(ba_triplet_folder, g)
    g_prefix = os.path.join(g_folder, 'run')
    os.makedirs(g_folder, exist_ok=True)

    # Keep the stereo pairs whose two images both belong to this group
    g_images = triplets[g]['file_names']
    g_pairs = pairs.loc[
        pairs['img1'].isin(g_images) & pairs['img2'].isin(g_images)
    ].drop_duplicates()
    g_pairs_txt = os.path.join(g_folder, f"{g}_overlapping_image_pairs.txt")
    g_pairs.to_csv(g_pairs_txt, header=True, index=False, sep=' ')

    # Stereo preprocessing is currently disabled. Bundle adjust below runs with
    # --force-reuse-match-files / --skip-matching, so match files for this
    # group must already exist under g_prefix.
    # print('Running stereo preprocessing for dense feature matching')
    # run_stereo(
    #     stereo_pairs_fn = g_pairs_txt,
    #     cam_folder = camgen_folder,
    #     dem_file = refdem_file,
    #     out_folder = g_folder,
    #     stop_point = 1
    # )

    # # reduce number of matches
    # reduce_asp_match_files(
    #     match_folder = g_folder,
    #     output_folder = g_folder
    # )

    # Get camera list
    g_cams = [find_matching_camera_file(x, camgen_folder) for x in g_images]

    # Bundle adjust, constrained to the reference DEM heights
    args = [
        "--threads", "9",
        "--num-iterations", "500",
        "--num-passes", "2",
        "--inline-adjustments",
        "--min-matches", "4",
        "--individually-normalize",
        "--heights-from-dem", refdem_file,
        "--heights-from-dem-uncertainty", "1",
        "--force-reuse-match-files",
        "--skip-matching",
        "-o", g_prefix
        ] + g_images + g_cams

    # Run bundle adjust and keep its log (previously discarded)
    log = run_cmd('parallel_bundle_adjust', args)
    with open(g_prefix + '-parallel_bundle_adjust.log', 'w') as f:
        f.write(log)



In [None]:
## TESTING MAPPROJECT AFTER BA TRIPLETS

# Orthorectify the images of group_11 that received an adjusted camera
g_folder = os.path.join(ba_triplet_folder, 'group_11')
cam_list = glob(os.path.join(g_folder, '*.tsai'))
cam_list_base = [os.path.splitext(os.path.basename(c))[0].replace('run-', '') for c in cam_list]
stems_with_cam = set(cam_list_base)
img_list = [
    x for x in sorted(glob(os.path.join(new_img_folder, '*.tif')))
    if os.path.splitext(os.path.basename(x))[0] in stems_with_cam
]

run_mapproject(
    img_list=img_list,
    cam_folder=g_folder,
    out_folder=g_folder,
    dem=refdem_file,
    t_res=0.8,
    t_crs="EPSG:32611",
)

In [None]:
# Gather every group's adjusted cameras into the root ba_triplet_folder
cam_list = sorted(glob(os.path.join(ba_triplet_folder, '*', '*.tsai')))
for cam in tqdm(cam_list):
    shutil.copy2(cam, os.path.join(ba_triplet_folder, os.path.basename(cam)))

In [None]:
##### 2. Global bundle adjustment #####

ba_global_folder = os.path.join(out_folder, 'ba_global')
os.makedirs(ba_global_folder, exist_ok=True)

# --- Stereo preprocessing ---
# Overlapping pairs not already processed inside a triplet: a left anti-join
# of the overlap table against the within-triplet stereo pairs.
overlap = pd.read_csv(overlap_txt, sep=' ', header=0)
stereo_pairs = pd.read_csv(pairs_txt, sep=' ', header=0)
merged_df = overlap.merge(stereo_pairs, how='left', indicator=True)
is_new_pair = merged_df['_merge'] == 'left_only'
overlap_filtered = merged_df.loc[is_new_pair].drop(columns='_merge').reset_index(drop=True)

# Keep only strongly overlapping pairs (overlap_percent > 40)
overlap_filtered = overlap_filtered[overlap_filtered['overlap_percent'] > 40].reset_index(drop=True)

overlap_filtered_txt = os.path.join(ba_global_folder, 'stereo_pairs.txt')
overlap_filtered.to_csv(overlap_filtered_txt, sep=' ', header=True, index=False)
print('Saved stereo pairs to file:', overlap_filtered_txt)

# Feature matching via stereo preprocessing, using the triplet-adjusted cameras
ba_global_stereo_folder = os.path.join(ba_global_folder, 'feature_matches')
run_stereo(
    stereo_pairs_fn=overlap_filtered_txt,
    cam_folder=ba_triplet_folder,
    dem_file=refdem_file,
    out_folder=ba_global_stereo_folder,
    stop_point=1,
)

# --- Reduce number of matches ---
reduce_asp_match_files(
    match_folder=ba_global_stereo_folder,
    output_folder=ba_global_folder,
)

In [None]:
# --- Bundle adjust with one triplet fixed ---
cam_list = sorted(glob(os.path.join(ba_triplet_folder, '*.tsai')))
cam_list_base = [os.path.splitext(os.path.basename(c))[0].replace('run-', '') for c in cam_list]
# NOTE(review): the images come from init_ortho_folder (mapprojected images),
# not the single-band originals in new_img_folder. Confirm this is intended.
image_list = [os.path.join(init_ortho_folder, f'{b}.tif') for b in cam_list_base]

# Anchor the solution by fixing the largest triplet group (first one on ties)
with open(triplets_json, 'r') as f:
    triplets = json.load(f)
max_length = max(group['length'] for group in triplets.values())
fixed_g = next(g for g in triplets if triplets[g]['length'] == max_length)
print(f'Using triplet {fixed_g} as an anchor.')
fixed_images = triplets[fixed_g]['file_names']
fixed_image_txt = os.path.join(ba_global_folder, 'fixed_images.txt')
with open(fixed_image_txt, 'w') as f:
    f.write(' '.join(fixed_images))
print('Saved list of fixed images to file:', fixed_image_txt)

# Run bundle adjust
print('Running global bundle adjust')
global_opts = [
    "--threads", "9",
    "--num-iterations", "500",
    "--num-passes", "2",
    "--inline-adjustments",
    "--min-matches", "4",
    "--individually-normalize",
    "--fixed-image-list", fixed_image_txt,
    "--force-reuse-match-files",
    "--skip-matching",
    "-o", os.path.join(ba_global_folder, 'run'),
]
args = global_opts + image_list + cam_list

log = run_cmd('parallel_bundle_adjust', args)

----------

## Intermediate ortho

In [None]:
interm_ortho_folder = os.path.join(out_folder, 'interm_ortho')
os.makedirs(interm_ortho_folder, exist_ok=True)

# Only orthorectify images for which global bundle adjust produced a camera
# (those cameras are expected to be named 'run-run-<image stem>.tsai')
cam_list = glob(os.path.join(ba_global_folder, '*.tsai'))
cam_list_base = [os.path.splitext(os.path.basename(c))[0].replace('run-run-', '') for c in cam_list]
adjusted_stems = set(cam_list_base)
img_list = [
    x for x in sorted(glob(os.path.join(new_img_folder, '*.tif')))
    if os.path.splitext(os.path.basename(x))[0] in adjusted_stems
]

run_mapproject(
    img_list=img_list,
    cam_folder=ba_global_folder,
    out_folder=interm_ortho_folder,
    dem=refdem_file,
    t_res=0.8,
    t_crs="EPSG:32611",
)

## Final stereo

In [None]:
os.makedirs(final_stereo_folder, exist_ok=True)
img_list = sorted(glob(os.path.join(interm_ortho_folder, '*.tif')))

# Create list of stereo pairs from the intermediate orthoimages
identify_overlapping_image_pairs(
    img_list = img_list, 
    overlap_perc = 40, 
    utm_epsg = "EPSG:32611",
    out_folder = final_stereo_folder,
    )
overlap_txt = os.path.join(final_stereo_folder, 'overlapping_image_pairs.txt')

# Run stereo on every overlapping pair with the globally adjusted cameras.
# run_stereo takes the reference DEM as `dem_file` (the name its body and the
# other calls in this notebook use); `dem_fn` does not match it.
# NOTE(review): `correlator_mode` is not passed by any other run_stereo call
# here. Confirm it is an accepted parameter.
run_stereo(
    stereo_pairs_fn = overlap_txt,
    cam_folder = ba_global_folder,
    dem_file = refdem_file,
    out_folder = final_stereo_folder,
    correlator_mode = False
)

## DSM mosaicing

## DSM coregistration

## Final ortho

In [None]:
# TODO: the final orthorectification step is not implemented yet.
# The commented-out lines below look like a copy of run_mapproject's signature,
# kept as a reference. Confirm against its definition before use.
# run_mapproject(img_list: List[str] = None, 
#                    cam_folder: str = None, 
#                    ba_prefix: str = None,
#                    out_folder: str = None, 
#                    dem: str = 'WGS84', 
#                    t_res: float = None, 
#                    t_crs: str = None, 
#                    session: str = None, 
#                    orthomosaic: bool = False)

## Plot results