# Explore bundle adjust outputs

In [None]:
import os
import glob
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import shutil
import subprocess
import xdem
from tqdm.auto import tqdm

In [None]:
## Coregister individual DEMs to each other

import geoutils as gu
from shapely.geometry import Polygon

# Paths for this run: the folder holding the input DEMs (also used for outputs),
# the ASP binaries directory, and the reference DEM that later defines the common
# grid (coregdem) and the georegistration target.
folder = '/Users/raineyaberle/Downloads/final_pinhole_stereo'
asp_dir = '/Users/raineyaberle/Research/PhD/SnowDEMs/StereoPipeline-3.5.0-alpha-2024-10-05-x86_64-OSX/bin'
coregdem_fn = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS/refdem/MCS_REFDEM_WGS84.tif'
in_dir = folder
out_dir = folder
# makedirs(exist_ok=True) also creates missing parent directories and avoids the
# race between an exists() check and mkdir().
os.makedirs(out_dir, exist_ok=True)
    
def get_dem_bbox(dem_fn):
    """Return the rectangular footprint of a DEM file as a shapely Polygon.

    The rectangle is built from the four values of ``gu.Raster(dem_fn).bounds``
    (left, bottom, right, top). It is traced counter-clockwise starting at the
    lower-left corner and explicitly closed.
    """
    left, bottom, right, top = gu.Raster(dem_fn).bounds
    corners = [
        [left, bottom],
        [right, bottom],
        [right, top],
        [left, top],
        [left, bottom],
    ]
    return Polygon(corners)
    
def get_dem_area(dem_fn):
    """Return the area of a DEM's bounding-box footprint.

    This is the area of the rectangle produced by ``get_dem_bbox``, expressed in
    squared units of the raster coordinates. It is not the area covered by valid
    pixels.
    """
    return get_dem_bbox(dem_fn).area

def get_overlapping_dems(reference_fn, dem_list):
    """Find DEMs that share valid data with a reference DEM.

    Candidates are first screened cheaply by bounding-box intersection. Each
    remaining DEM is then reprojected onto the reference grid, and the pixels
    where the difference (DEM - reference) is unmasked are counted.

    Parameters
    ----------
    reference_fn : str
        Path to the reference DEM.
    dem_list : list of str
        Paths to the candidate DEMs.

    Returns
    -------
    dict
        Maps each candidate with at least one overlapping valid pixel to its
        overlap area: the number of overlapping pixels times the reference pixel
        area (``res[0] * res[1]``). Returns an empty dict when nothing overlaps.
    """
    overlapping_dem_dict = {}
    # Nothing to compare: avoid opening any raster at all.
    if not dem_list:
        return overlapping_dem_dict

    reference_bbox = get_dem_bbox(reference_fn)
    overlapping_dem_list = [dem_fn for dem_fn in dem_list
                            if get_dem_bbox(dem_fn).intersects(reference_bbox)]
    # Only load the (possibly large) reference DEM if some footprint intersects it.
    if not overlapping_dem_list:
        return overlapping_dem_dict

    # Filter to those with overlapping data values
    reference = xdem.DEM(reference_fn)
    pixel_area = reference.res[0] * reference.res[1]
    for dem_fn in overlapping_dem_list:
        dem = xdem.DEM(dem_fn).reproject(reference)
        ddem = dem - reference
        # Count unmasked dDEM pixels, i.e. where both DEMs hold data.
        # np.ma.getmaskarray always yields a full boolean array, so this also
        # works when a mask is stored as the scalar `nomask`.
        n_overlap = int(np.count_nonzero(~np.ma.getmaskarray(ddem.data)))
        if n_overlap > 0:
            overlapping_dem_dict[dem_fn] = n_overlap * pixel_area

    return overlapping_dem_dict
    
def align_dems(reference_fn, source_fn, source_out_fn, grid_dem):
    """Coregister ``source_fn`` to ``reference_fn`` with Nuth & Kääb and save it.

    Both DEMs are first reprojected onto ``grid_dem``. A ``NuthKaab(subsample=1)``
    model is fitted with the reference first and the source second. The fitted
    model is applied to the reprojected source, and the corrected DEM is written
    to ``source_out_fn``. Nothing is returned.
    """
    reference = xdem.DEM(reference_fn).reproject(grid_dem)
    source = xdem.DEM(source_fn).reproject(grid_dem)

    coreg = xdem.coreg.NuthKaab(subsample=1)
    coreg = coreg.fit(reference, source)

    aligned_source = coreg.apply(source)
    aligned_source.save(source_out_fn)
    
def mosaic_dems(dem_list, out_fn, stat, print_output=False, *, tr=2, threads=12):
    """Mosaic DEMs with ASP ``dem_mosaic`` using a per-pixel statistic.

    Parameters
    ----------
    dem_list : iterable of str
        Paths of the DEMs to mosaic.
    out_fn : str
        Output mosaic path (passed to ``-o``).
    stat : str
        ``dem_mosaic`` statistic flag without the leading dashes, e.g.
        ``'median'``, ``'nmad'`` or ``'count'``.
    print_output : bool, optional
        If True, print the ``CompletedProcess`` returned by ``subprocess.run``.
    tr : float, optional
        Output resolution passed to ``--tr`` (default 2, as before).
    threads : int, optional
        Thread count passed to ``--threads`` (default 12, as before).

    Returns
    -------
    subprocess.CompletedProcess
        Result of the ``dem_mosaic`` call (stdout/stderr captured as bytes).

    Raises
    ------
    ValueError
        If ``dem_list`` is empty.
    RuntimeError
        If ``dem_mosaic`` exits with a non-zero status. Such failures used to be
        silent, letting callers go on to read a stale or missing mosaic.
    """
    dem_list = list(dem_list)
    if not dem_list:
        raise ValueError('mosaic_dems: dem_list is empty, nothing to mosaic')
    cmd = [os.path.join(asp_dir, 'dem_mosaic'),
           '--tr', str(tr),
           f'--{stat}',
           '--threads', str(threads),
           '-o', out_fn] + dem_list
    out = subprocess.run(cmd, shell=False, capture_output=True)
    if print_output:
        print(out)
    if out.returncode != 0:
        raise RuntimeError(
            f'dem_mosaic --{stat} failed (exit code {out.returncode}) for {out_fn}:\n'
            + out.stdout.decode(errors='replace')
            + out.stderr.decode(errors='replace')
        )
    return out

# Load the coreg DEM resampled to res=2 (in its CRS units); it is the common grid
# every DEM is reprojected onto in align_dems.
coregdem = xdem.DEM(coregdem_fn).reproject(res=2)

# Start with the largest DEM (by bounding-box area). Sort the glob result so the
# processing order and tie-breaking are deterministic, because glob order is arbitrary.
dem_list = sorted(glob.glob(os.path.join(out_dir, '20*_map_run-DEM.tif')))
if not dem_list:
    raise FileNotFoundError(f'No DEMs matching "20*_map_run-DEM.tif" found in {out_dir}')
dem_areas = np.array([get_dem_area(dem_fn) for dem_fn in dem_list])
imax = int(np.argmax(dem_areas))  # first index attaining the maximum area
reference_dem = dem_list[imax]
coregistered_list = [reference_dem]
print(f'Starting with {reference_dem} as the reference')
tba_dem_list = [dem_fn for dem_fn in dem_list if dem_fn != reference_dem]
print(f'{len(tba_dem_list)} DEMs to be coregistered')

# Start a progress bar (one step per input DEM, including the reference)
pbar = tqdm(total=len(dem_list))

# Copy the reference DEM to "<name>_coreg.tif" (e.g. "..._map_run-DEM_coreg.tif"),
# the same naming used for the DEMs aligned in the loop below.
reference_dem_coreg = os.path.splitext(reference_dem)[0] + '_coreg.tif'
shutil.copy2(reference_dem, reference_dem_coreg)
reference_dem = reference_dem_coreg
pbar.update(1)

# Greedy chain: repeatedly align the remaining DEM with the largest data overlap to
# the current median mosaic, then rebuild the mosaic to include it.
while tba_dem_list:
    # Remaining DEMs that share valid pixels with the reference, mapped to overlap area
    overlapping_dems = get_overlapping_dems(reference_dem, tba_dem_list)
    if overlapping_dems:
        # Largest overlap wins. max() keeps the first of any ties, just like the
        # previous stable descending sort followed by taking the first key.
        source_dem = max(overlapping_dems, key=overlapping_dems.get)
        dem_out_fn = os.path.splitext(source_dem)[0] + '_coreg.tif'
        align_dems(reference_dem, source_dem, dem_out_fn, grid_dem=coregdem)
        coregistered_list.append(dem_out_fn)
    else:
        # Nothing overlaps the current mosaic: add the largest remaining DEM as-is
        # (without alignment) so later DEMs can be chained to it.
        print('No overlapping DEMs, restarting with the largest')
        tba_dem_areas = [get_dem_area(dem_fn) for dem_fn in tba_dem_list]
        source_dem = tba_dem_list[int(np.argmax(tba_dem_areas))]
        coregistered_list.append(source_dem)

    # Create an intermediate DEM from all coregistered DEMs using median statistic
    dem_intermediate = os.path.join(out_dir, 'mosaic_median.tif')
    mosaic_dems(coregistered_list, dem_intermediate, stat='median')

    # The intermediate mosaic is the reference for the next iteration
    tba_dem_list.remove(source_dem)
    reference_dem = dem_intermediate

    pbar.update(1)

pbar.close()

# Create NMAD and count mosaics from the coregistered DEMs, reusing the
# mosaic_dems helper instead of duplicating the dem_mosaic command line.
# print_output=True keeps printing each dem_mosaic result, as before.
print('Creating NMAD and count mosaics')
nmad_mos_fn = os.path.join(out_dir, 'mosaic_nmad.tif')
mosaic_dems(coregistered_list, nmad_mos_fn, 'nmad', print_output=True)
count_mos_fn = os.path.join(out_dir, 'mosaic_count.tif')
mosaic_dems(coregistered_list, count_mos_fn, 'count', print_output=True)


In [None]:
# Create median, NMAD, and count mosaics
print('Creating median, NMAD, and count mosaics')
median_mos_fn = os.path.join(out_dir, 'mosaic_median.tif')
mosaic_dems(coregistered_list, median_mos_fn, 'median', print_output=False)
nmad_mos_fn = os.path.join(out_dir, 'mosaic_nmad.tif')
mosaic_dems(coregistered_list, nmad_mos_fn, 'nmad', print_output=False)
count_mos_fn = os.path.join(out_dir, 'mosaic_count.tif')
mosaic_dems(coregistered_list, count_mos_fn, 'count', print_output=False)

# Filename for the median mosaic georegistered to the coreg DEM.
# Only the path is defined here; nothing is written in this cell.
georegistered_median_mos_fn = os.path.join(out_dir, 'mosaic_median_georegistered.tif')



In [None]:
# Coregister median DEM to coregdem: run ASP pc_align (coregdem_fn as the
# reference, median_mos_fn as the source), then grid the transformed source with
# point2dem. Skipped if the expected output DEM already exists.

out_prefix = os.path.join(out_dir, 'mosaic_median_georegistered')
align_out_fn = out_prefix + '-trans_source.tif'    # transformed source read by point2dem
dem_out_fn = out_prefix + '-trans_source-DEM.tif'  # expected point2dem output
if not os.path.exists(dem_out_fn):
    # run pc_align
    cmd = [os.path.join(asp_dir, 'pc_align'),
           '--max-displacement', '100',
           '--save-transformed-source-points',
           '--highest-accuracy',
           '--threads', '12',
           '-o', out_prefix,
           coregdem_fn, median_mos_fn]
    out = subprocess.run(cmd, shell=False, capture_output=True)
    # pc_align's output is captured. Without this check a failure would go
    # unnoticed, and point2dem would run on a missing or stale point cloud.
    if out.returncode != 0:
        raise RuntimeError(f'pc_align failed (exit code {out.returncode}):\n'
                           + out.stdout.decode(errors='replace')
                           + out.stderr.decode(errors='replace'))
    # run point2dem (its output is not captured)
    cmd = [os.path.join(asp_dir, 'point2dem'),
           '--threads', '12',
           '-o', os.path.splitext(align_out_fn)[0],
           align_out_fn]
    out = subprocess.run(cmd, shell=False, capture_output=False)
    if out.returncode != 0:
        raise RuntimeError(f'point2dem failed (exit code {out.returncode})')


In [None]:
# Coregister final DEM mosaic to coregdem
coregdem_fn = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS/refdem/MCS_REFDEM_WGS84.tif'

# Output paths: the transformed point cloud and the DEM point2dem makes from it
out_prefix = os.path.join(out_dir, 'georegistered_dem_mos')
align_out_fn = f'{out_prefix}-trans_source.tif'
dem_out_fn = f'{out_prefix}-trans_source-DEM.tif'

# NOTE: the pc_align step is disabled (kept below for reference). This cell therefore
# only grids an already existing "<out_prefix>-trans_source.tif" with point2dem.
# cmd = [os.path.join(asp_dir, 'pc_align'),
#         '--max-displacement', '100',
#         '--save-transformed-source-points',
#         '--highest-accuracy',
#         '--threads', '16',
#         '-o', out_prefix,
#         coregdem_fn, dem_intermediate]
# out = subprocess.run(cmd, shell=False, capture_output=True)
# print(out)
if not os.path.exists(dem_out_fn):
    cmd = [
        os.path.join(asp_dir, 'point2dem'),
        '--threads', '16',
        '-o', os.path.splitext(align_out_fn)[0],
        align_out_fn,
    ]
    out = subprocess.run(cmd, shell=False, capture_output=True)
    print(out)
        

In [2]:
# folder = '/Users/raineyaberle/Downloads/final_pinhole_stereo'
# dems = glob.glob(os.path.join(folder, '20*', '20*', 'run-DEM.tif'))
# for dem in dems:
#     new_dem = os.path.join(folder, dem.split('/')[-2] + '_run-DEM.tif')
#     shutil.move(dem, new_dem)

# asp_dir = '/Users/raineyaberle/Research/PhD/SnowDEMs/StereoPipeline-3.5.0-alpha-2024-10-05-x86_64-OSX/bin'
# dem_list = sorted(glob.glob(os.path.join(folder, '*-DEM.tif')))

# mapproj_stats_fn = '/Users/raineyaberle/Downloads/run-mapproj_match_offset_stats.txt'
# gsd_thresh = 4.9
# ddem_thresh = 5

# # Read the mapproj stats file
# with open(mapproj_stats_fn, 'r') as file:
#     lines = file.readlines()
# column_names = lines[1].lstrip('#').strip().split()
# mapproj_df = pd.read_csv(mapproj_stats_fn, skiprows=2, names=column_names, sep=' ')
# mapproj_df['image_name_base'] = [os.path.basename(x) for x in mapproj_df['image_name']]

# # Filter images using the median GSD difference w.r.t. other images
# # mapproj_df_filt = mapproj_df.loc[mapproj_df['50%'] < gsd_thresh]
# # img_list_filt = sorted(mapproj_df_filt['image_name'].values)
# # print(f'{len(mapproj_df)-len(mapproj_df_filt)} images removed due to less reliable bundle adjustment results, {len(mapproj_df_filt)} images remain for stereo workflow.')
# # # Get DEMs that use the filtered images
# # img_list_filt_prefix = [os.path.splitext(os.path.basename(x))[0] for x in img_list_filt]
# # dem_mosaic_list = []
# # for dem in dem_list:
# #     im1 = os.path.basename(dem).split('__')[0]
# #     im2 = os.path.basename(dem).split('__')[1].split('_map')[0]
# #     if (im1 in img_list_filt_prefix) & (im2 in img_list_filt_prefix):
# #         dem_mosaic_list.append(dem)

# # Filter images using the median difference from the reference DEM
# refdem_fn = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS/refdem/MCS_REFDEM_WGS84.tif'
# refdem = xdem.DEM(refdem_fn).reproject(res=5)

# print('DEM name\tdDEM median [m]\tim1 diff [m]\tim2 diff[m]')
# dem_mosaic_list = []
# for dem_fn in dem_list:
#     dem = xdem.DEM(dem_fn).reproject(refdem)
#     ddem = dem - refdem
#     ddem_median = np.ma.median(ddem.data)
#     im1 = os.path.basename(dem_fn).split('__')[0]
#     im2 = os.path.basename(dem_fn).split('__')[1].split('_map')[0]
#     im1_diff = mapproj_df.loc[mapproj_df['image_name_base']==im1+'.tif', '50%'].values[0]
#     im2_diff = mapproj_df.loc[mapproj_df['image_name_base']==im2+'.tif', '50%'].values[0]
#     # print(f"{os.path.basename(dem_fn)}"
#     #       f"\t{np.round(float(ddem_median),3)}"
#     #       f"\t{np.round(float(im1_diff),3)}" 
#     #       f"\t{np.round(float(im2_diff),3)}")
#     if np.abs(np.ma.median(ddem.data)) < 5:
#         dem_list.append(dem_fn)
        
# # Mosaic DEMs
# print(f"Mosaicking {len(dem_mosaic_list)} DEMs using median and NMAD stats")
# mos_median_fn = os.path.join(folder, 'mos_median.tif')
# cmd = [os.path.join(asp_dir, 'dem_mosaic'),
#        '--tr', '2',
#        '--median',
#        '--threads', '16',
#        '-o', mos_median_fn] + dem_mosaic_list
# out = subprocess.run(cmd, shell=False, capture_output=True)
# print(out)
# mos_nmad_fn = os.path.join(folder, 'mos_nmad.tif')
# cmd = [os.path.join(asp_dir, 'dem_mosaic'),
#        '--tr', '2',
#        '--nmad',
#        '--threads', '16'
#        '-o', mos_nmad_fn] + dem_mosaic_list
# out = subprocess.run(cmd, shell=False, capture_output=True)
# print(out)

# # Coregister to reference DEM
# print('Coregistering to reference DEM')
# nk = xdem.coreg.NuthKaab().fit(refdem, dem)
# dem_corr = nk.apply(dem)
# ddem = dem_corr - refdem

# # Plot results
# mos_median_fn = os.path.join(folder, 'mos_median.tif')
# mos_median = xdem.DEM(mos_median_fn)
# mos_nmad_fn = os.path.join(folder, 'mos_nmad.tif')
# mos_nmad = xdem.DEM(mos_nmad_fn)

# fig, ax = plt.subplots(1, 3, figsize=(14,5))
# mos_median.plot(cmap='terrain', ax=ax[0])
# mos_nmad.plot(cmap='Reds', ax=ax[1])
# ddem.plot(cmap='coolwarm_r', vmin=-5, vmax=5, ax=ax[2])
# plt.show()

In [13]:
data_dir = "/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/"

def load_plot_pointmap_res(res_pointmap_fn, step=10, hue_norm=(0, 1)):
    """Load a residual pointmap CSV, scatter-plot it and print residual statistics.

    Column names are normalized by stripping whitespace and a leading "#"
    (e.g. "# lon" -> "lon", " mean_residual" -> "mean_residual"). This way the
    header's exact spacing no longer has to match a hard-coded rename map.

    Parameters
    ----------
    res_pointmap_fn : str
        Path to the pointmap CSV. It must contain lon, lat and mean_residual
        columns. The first data row is dropped before converting to float.
        NOTE(review): this presumably skips a non-numeric metadata row; confirm
        against the file format.
    step : int, optional
        Plot every ``step``-th point (default 10). The statistics use all points.
    hue_norm : tuple of float, optional
        Colour range for ``mean_residual`` (default (0, 1)).

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the scatter plot.
    """
    res_pointmap = pd.read_csv(res_pointmap_fn, header=0)
    res_pointmap.columns = [str(col).strip().lstrip('#').strip()
                            for col in res_pointmap.columns]
    res_pointmap = res_pointmap[['lon', 'lat', 'mean_residual']]
    res_pointmap = res_pointmap.iloc[1:].astype(float).reset_index(drop=True)
    res_pointmap_sub = res_pointmap.iloc[::step]

    fig, ax = plt.subplots(figsize=(10, 10))
    sns.scatterplot(data=res_pointmap_sub, x='lon', y='lat', hue='mean_residual', s=1,
                    palette='coolwarm', hue_norm=hue_norm, ax=ax)

    residuals = res_pointmap['mean_residual']
    print('Mean residual:', np.nanmean(residuals))
    print('Median residual:', np.nanmedian(residuals))
    print('Min residual:', np.nanmin(residuals))
    print('Max residual:', np.nanmax(residuals))
    return fig

In [None]:
# MCS_20241003: plot every "*pointmap.csv" for this date. Each figure is titled
# with the part of its file name before "_run".
fns = sorted(glob.glob(os.path.join(data_dir, 'MCS', '20241003', '*pointmap.csv')))
for fn in fns:
    title = os.path.basename(fn).partition('_run')[0]
    fig = load_plot_pointmap_res(fn)
    fig.suptitle(title)
    plt.show()

In [None]:
import xdem

# Difference a 20241003 MCS DEM against the MCS reference DEM and plot the result.
dem_fn = os.path.join(data_dir, 'MCS', '20241003', 'coregStable_ba-u0.1m_camweight-0_DEM.tif')
refdem_fn = os.path.join(data_dir, 'MCS', 'refdem', 'MCS_REFDEM_WGS84.tif')
dem = xdem.DEM(dem_fn)
# Reproject the reference onto the DEM's grid so the two rasters can be subtracted
refdem = xdem.DEM(refdem_fn).reproject(dem)
# dDEM = DEM - reference, shown with a diverging colormap limited to [-5, 5]
ddem = dem - refdem
ddem.plot(cmap='coolwarm_r', vmin=-5, vmax=5)

In [None]:
import geoutils as gu

# Load the "...-stats.tif" raster for one image of the 20241003 MCS run, with pixel
# values loaded up front (load_data=True).
fn = os.path.join(data_dir, 'MCS', '20241003', 'run-20241003_221111_ssc9d3_0009_basic_panchromatic-stats.tif')
f = gu.Raster(fn, load_data=True)

# Bare expression: the notebook displays the Raster object as the cell output
f