# Test using autoRIFT for horizontal coregistration

In [None]:
import os

import geoutils as gu
import matplotlib.pyplot as plt
import numpy as np
import xdem
from autoRIFT import autoRIFT
from geogrid import GeogridOptical

In [None]:
# Define inputs
# Site-level data: reference DEM, the source DEM to be aligned, and its orthomosaic
data_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/study-sites/MCS'
refdem_fn = os.path.join(data_dir, 'refdem', 'MCS_REFDEM_WGS84_CHM.tif')
sourcedem_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_DEM.tif')
ortho_fn = os.path.join(data_dir, '20240420', 'MCS_20240420-1_4band_orthomosaic.tif')

# Job-level data: land cover masks stored under the job directory
job_dir = '/Volumes/LaCie/raineyaberle/Research/PhD/SkySat-Stereo/snow_depth_maps/MCS_20240420-1_ref+chm_sourcetrees'
masks_dir = os.path.join(job_dir, 'land_cover_masks')
trees_mask_fn = os.path.join(masks_dir, 'trees_mask.tif')
roads_mask_fn = os.path.join(masks_dir, 'roads_mask.tif')
snow_mask_fn = os.path.join(masks_dir, 'snow_mask.tif')
ss_mask_fn = os.path.join(masks_dir, 'stable_surfaces_mask.tif')

# Output directory for everything written by the cells below
out_dir = os.path.join(job_dir, 'testing_tree_coreg')
if not os.path.exists(out_dir):
    # makedirs (rather than mkdir) also creates any missing parent directories
    os.makedirs(out_dir)
    print('Made directory for outputs:', out_dir)

In [None]:
# Deramp input DEM on stable surfaces
# Limits used for the dDEM colour scale and for the histogram bins / x-range
vmin, vmax = -10, 10

# Define output file names
sourcedem_deramped_fn = os.path.join(out_dir, os.path.basename(sourcedem_fn).replace('.tif', '_deramped.tif'))
fig_fn = os.path.join(out_dir, os.path.basename(sourcedem_fn).replace('.tif', '_deramp_correction.png'))

if not os.path.exists(sourcedem_deramped_fn):
    # Load input files and put the source DEM and mask on the reference DEM grid
    refdem = xdem.DEM(refdem_fn)
    sourcedem = xdem.DEM(sourcedem_fn)
    sourcedem = sourcedem.reproject(refdem)
    ss_mask = gu.Raster(ss_mask_fn, load_data=True)
    ss_mask = ss_mask.reproject(refdem)
    ss_mask = (ss_mask == 1) # convert to boolean mask

    # Calculate difference before correction
    diff_before = sourcedem - refdem

    # Fit and apply Deramp object (fit restricted to stable-surface pixels)
    print('Fitting deramper...')
    deramp = xdem.coreg.Deramp(poly_order=2)
    deramp.fit(refdem, sourcedem, inlier_mask=ss_mask)
    meta = deramp.meta
    print(meta)
    sourcedem_deramped = deramp.apply(sourcedem)

    # Save corrected DEM
    sourcedem_deramped.save(sourcedem_deramped_fn)
    print('Deramped DEM saved to file:', sourcedem_deramped_fn)

    # Calculate difference after
    diff_after = sourcedem_deramped - refdem

    # Plot results
    print('Plotting deramp correction results...')
    bins = np.linspace(vmin, vmax, num=100)
    fig, ax = plt.subplots(2, 2, figsize=(10,10))
    ax = ax.flatten()
    diff_before.plot(cmap='coolwarm_r', vmin=vmin, vmax=vmax, ax=ax[0])
    ax[0].set_title('dDEM')
    diff_after.plot(cmap='coolwarm_r', vmin=vmin, vmax=vmax, ax=ax[1])
    ax[1].set_title('Deramped dDEM')
    # Histogram only the unmasked pixels: np.ma.compressed drops masked entries
    # so values hidden behind the raster mask are not counted.
    ax[2].hist(np.ma.compressed(diff_before.data), color='grey', bins=bins)
    ax[2].set_xlim(vmin,vmax)
    ax[2].set_xlabel('Elevation differences (all surfaces) [m]')
    ax[3].hist(np.ma.compressed(diff_after.data), color='grey', bins=bins)
    ax[3].set_xlim(vmin, vmax)
    ax[3].set_xlabel('Elevation differences (all surfaces) [m]')
    fig.tight_layout()

    # Save figure (before showing it, so the file is written from the intact figure)
    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    plt.show()

else:
    print('Deramped DEM already exists in file, skipping.')

In [None]:
# Calculate hillshade, mask snow-covered pixels, save hillshades to file

# Define output files
source_hillshade_fn = os.path.join(out_dir, 'source_hillshade_snowfree.tif')
ref_hillshade_fn = os.path.join(out_dir, 'ref_hillshade.tif')
fig_fn = os.path.join(out_dir, 'hillshade_inputs.png')

# Regenerate unless BOTH hillshades are already on disk; checking only the
# source file would leave a missing reference hillshade unnoticed.
if not (os.path.exists(source_hillshade_fn) and os.path.exists(ref_hillshade_fn)):

    # Load input files
    print('Loading input files...')

    source_dem = xdem.DEM(sourcedem_deramped_fn)
    ref_dem = xdem.DEM(refdem_fn)
    snow_mask = gu.Raster(snow_mask_fn, load_data=True)
    snow_mask = snow_mask.reproject(ref_dem)
    snow_mask = (snow_mask == 1)

    # Calculate hillshades
    print('Calculating hillshades...')
    source_hillshade = source_dem.hillshade()
    ref_hillshade = ref_dem.hillshade()

    # Mask snow-covered pixels
    print('Masking snow-covered pixels in source hillshade...')
    # Parentheses are required: `|` binds tighter than `==`, so without them the
    # expression is evaluated as `(mask | snow) == 1`.
    new_mask = source_hillshade.data.mask | (snow_mask.data.data == 1)
    source_hillshade.set_mask(new_mask)

    # Save hillshades to file
    source_hillshade.save(source_hillshade_fn)
    print('Source hillshade saved to file:', source_hillshade_fn)
    ref_hillshade.save(ref_hillshade_fn)
    print('Reference hillshade saved to file:', ref_hillshade_fn)

    # Plot hillshades
    print('Plotting hillshades...')
    fig, ax = plt.subplots(1, 2, figsize=(12,5))
    ref_hillshade.plot(ax=ax[0], cmap='Greys')
    ax[0].set_title('Reference hillshade')
    source_hillshade.plot(ax=ax[1], cmap='Greys')
    ax[1].set_title('Source hillshade')
    fig.tight_layout()

    # Save figure (before showing it, so the file is written from the intact figure)
    fig.savefig(fig_fn, dpi=300, bbox_inches='tight')
    print('Figure saved to file:', fig_fn)
    plt.show()

else:
    print('Input hillshades already exist in file, skipping.')


In [None]:
# Run autoRIFT on input hillshades

# Create autoRIFT object
obj = autoRIFT()

# Define data inputs and settings
# TODO: this assignment was left unfinished (a bare `I1 =` is a SyntaxError and
# stops the cell from running). None is a placeholder only; replace it with the
# first input image array once the inputs have been decided.
I1 = None

In [None]:
def _read_geogrid_value(label, log_fn='testGeogrid.txt'):
    '''
    Return, as a float, the last whitespace-separated token of the last line in
    `log_fn` that contains `label`. Raises ValueError if no line contains
    `label` or if the token is not a number.
    '''
    matched_line = None
    with open(log_fn) as log_file:
        for line in log_file:
            if label in line:
                matched_line = line
    if matched_line is None:
        raise ValueError(f'"{label}" not found in {log_fn}')
    return float(matched_line.split()[-1])


def runAutorift(I1, I2, xGrid, yGrid, Dx0, Dy0, SRx0, SRy0, CSMINx0, CSMINy0, CSMAXx0, CSMAXy0, noDataMask, optflag,
                nodata, mpflag, geogrid_run_info=None, preprocessing_methods=('hps', 'hps'),
                preprocessing_filter_width=5):
    '''
    Wire and run autoRIFT.

    Builds an autoRIFT object, configures it from the inputs, preprocesses the
    two images, and runs the offset search.

    Parameters
    ----------
    I1, I2 : arrays
        The two images to match; assigned to obj.I1 / obj.I2. When optflag == 0
        their absolute value is taken first.
    xGrid, yGrid : arrays or None
        Pixel-index grids at which offsets are searched. If xGrid is None, a
        regular grid is built from the image shape and obj.SkipSampleX/Y, and
        noDataMask is replaced accordingly.
    Dx0, Dy0 : arrays or None
        Initial search offsets. If Dx0 is None, the object's defaults are kept
        (zeroed where noDataMask is set).
    SRx0, SRy0 : arrays or None
        Search limits. If SRx0 is None, the object's defaults are kept (zeroed
        where noDataMask is set).
    CSMINx0, CSMINy0, CSMAXx0, CSMAXy0 : arrays or None
        Minimum / maximum chip sizes. CSMAXy0 is accepted but not read here.
    noDataMask : boolean array
        Grid points at which the search is skipped.
    optflag : int
        1 for optical data, 0 for radar data.
    nodata : scalar
        Value marking invalid entries in the grid and chip-size arrays.
    mpflag : int
        Assigned to obj.MultiThread.
    geogrid_run_info : dict or None
        Must provide 'gridspacingx', 'chipsizex0' and 'XPixelSize'. If None,
        these values are read from 'testGeogrid.txt' in the working directory.
    preprocessing_methods : tuple of str
        'wallis_fill', 'wallis' or 'fft' select the corresponding branch;
        anything else falls through to obj.preprocess_filt_hps().
    preprocessing_filter_width : int
        Assigned to obj.WallisFilterWidth.

    Returns
    -------
    tuple
        (obj.Dx, obj.Dy, obj.InterpMask, obj.ChipSizeX, obj.GridSpacingX,
        obj.ScaleChipSizeY, obj.SearchLimitX, obj.SearchLimitY, obj.origSize,
        noDataMask), where noDataMask has been dilated once with a 3x3 kernel.

    Notes
    -----
    Input arrays that are assigned to the autoRIFT object (xGrid, yGrid, Dx0,
    Dy0, SRx0, SRy0, CSMINx0, CSMAXx0) are modified in place where noDataMask
    is set.
    '''

    from autoRIFT import autoRIFT
    import numpy as np
    import time
    import warnings
    # Imported up front so a missing OpenCV install fails before the long run
    import cv2

    obj = autoRIFT()

    obj.WallisFilterWidth = preprocessing_filter_width
    print(f'Setting Wallis Filter Width to {preprocessing_filter_width}')

    obj.MultiThread = mpflag

    # take the amplitude only for the radar images
    if optflag == 0:
        I1 = np.abs(I1)
        I2 = np.abs(I2)

    obj.I1 = I1
    obj.I2 = I2

    # create the grid if it does not exist
    if xGrid is None:
        m,n = obj.I1.shape
        xGrid = np.arange(obj.SkipSampleX+10,n-obj.SkipSampleX,obj.SkipSampleX)
        yGrid = np.arange(obj.SkipSampleY+10,m-obj.SkipSampleY,obj.SkipSampleY)
        nd = len(xGrid)
        md = len(yGrid)
        obj.xGrid = np.int32(np.dot(np.ones((md,1)),np.reshape(xGrid,(1,nd))))
        obj.yGrid = np.int32(np.dot(np.reshape(yGrid,(md,1)),np.ones((1,nd))))
        noDataMask = np.logical_not(obj.xGrid)
    else:
        obj.xGrid = xGrid
        obj.yGrid = yGrid

    # NOTE: This assumes the zero values in the image are only outside the valid image "frame",
    #        but is not true for Landsat-7 after the failure of the Scan Line Corrector, May 31, 2003.
    #        We should not mask based on zero values in the L7 images as this percolates into SearchLimit{X,Y}
    #        and prevents autoRIFT from looking at large parts of the images, but untangling the logic here
    #        has proved too difficult, so lets just turn it off if `wallis_fill` preprocessing is going to be used.
    # generate the nodata mask where offset searching will be skipped based on 1) imported nodata mask and/or 2) zero values in the image
    # NOTE(review): grid values appear to be 1-based pixel indices (hence the -1) - confirm.
    if 'wallis_fill' not in preprocessing_methods:
        for ii in range(obj.xGrid.shape[0]):
            for jj in range(obj.xGrid.shape[1]):
                if (obj.yGrid[ii,jj] != nodata)&(obj.xGrid[ii,jj] != nodata):
                    if (I1[obj.yGrid[ii,jj]-1,obj.xGrid[ii,jj]-1]==0)|(I2[obj.yGrid[ii,jj]-1,obj.xGrid[ii,jj]-1]==0):
                        noDataMask[ii,jj] = True

    ######### mask out nodata to skip the offset searching using the nodata mask (by setting SearchLimit to be 0)

    if SRx0 is None:
#        ###########     uncomment to customize SearchLimit based on velocity distribution (i.e. Dx0 must not be None)
#        obj.SearchLimitX = np.int32(4+(25-4)/(np.max(np.abs(Dx0[np.logical_not(noDataMask)]))-np.min(np.abs(Dx0[np.logical_not(noDataMask)])))*(np.abs(Dx0)-np.min(np.abs(Dx0[np.logical_not(noDataMask)]))))
#        obj.SearchLimitY = 5
#        ###########
        obj.SearchLimitX = obj.SearchLimitX * np.logical_not(noDataMask)
        obj.SearchLimitY = obj.SearchLimitY * np.logical_not(noDataMask)
    else:
        obj.SearchLimitX = SRx0
        obj.SearchLimitY = SRy0
#        ############ add buffer to search range
#        obj.SearchLimitX[obj.SearchLimitX!=0] = obj.SearchLimitX[obj.SearchLimitX!=0] + 2
#        obj.SearchLimitY[obj.SearchLimitY!=0] = obj.SearchLimitY[obj.SearchLimitY!=0] + 2

    if CSMINx0 is not None:
        obj.ChipSizeMaxX = CSMAXx0
        obj.ChipSizeMinX = CSMINx0

        if geogrid_run_info is None:
            # No run info supplied: read the values from the geogrid log file
            gridspacingx = _read_geogrid_value('Grid spacing in m:')
            chipsizex0 = _read_geogrid_value('Smallest Allowable Chip Size in m:')
            try:
                pixsizex = _read_geogrid_value('Ground range pixel size:')
            except ValueError:
                # Fall back to the alternative label when the first is absent
                pixsizex = _read_geogrid_value('X-direction pixel size:')
        else:
            gridspacingx = geogrid_run_info['gridspacingx']
            chipsizex0 = geogrid_run_info['chipsizex0']
            pixsizex = geogrid_run_info['XPixelSize']

        # Smallest chip size in pixels, rounded up to a multiple of 4
        obj.ChipSize0X = int(np.ceil(chipsizex0/pixsizex/4)*4)
        obj.GridSpacingX = int(obj.ChipSize0X*gridspacingx/chipsizex0)

        # obj.ChipSize0X = np.min(CSMINx0[CSMINx0!=nodata])
        RATIO_Y2X = CSMINy0/CSMINx0
        obj.ScaleChipSizeY = np.median(RATIO_Y2X[(CSMINx0!=nodata)&(CSMINy0!=nodata)])
        # obj.ChipSizeMaxX = obj.ChipSizeMaxX / obj.ChipSizeMaxX * 544
        # obj.ChipSizeMinX = obj.ChipSizeMinX / obj.ChipSizeMinX * 68
    else:
        # NOTE(review): xGrid is reassigned above when it was passed as None, so
        # `xGrid is not None` is always true by this point.
        if ((optflag == 1)&(xGrid is not None)):
            obj.ChipSizeMaxX = 32
            obj.ChipSizeMinX = 16
            obj.ChipSize0X = 16

    # create the downstream search offset if not provided as input
    if Dx0 is not None:
        obj.Dx0 = Dx0
        obj.Dy0 = Dy0
    else:
        obj.Dx0 = obj.Dx0 * np.logical_not(noDataMask)
        obj.Dy0 = obj.Dy0 * np.logical_not(noDataMask)

    # replace the nodata value with zero
    obj.xGrid[noDataMask] = 0
    obj.yGrid[noDataMask] = 0
    obj.Dx0[noDataMask] = 0
    obj.Dy0[noDataMask] = 0
    if SRx0 is not None:
        obj.SearchLimitX[noDataMask] = 0
        obj.SearchLimitY[noDataMask] = 0
    if CSMINx0 is not None:
        obj.ChipSizeMaxX[noDataMask] = 0
        obj.ChipSizeMinX[noDataMask] = 0

    # convert azimuth offset to vertical offset as used in autoRIFT convention
    if optflag == 0:
        obj.Dy0 = -1 * obj.Dy0



    ######## preprocessing
    t1 = time.time()
    print("Pre-process Start!!!")
    print(f"Using Wallis Filter Width: {obj.WallisFilterWidth}")
#    obj.zeroMask = 1

    # TODO: Allow different filters to be applied images independently
    # default to most stringent filtering
    if 'wallis_fill' in preprocessing_methods:
        obj.preprocess_filt_wal_nodata_fill()
    elif 'wallis' in preprocessing_methods:
        obj.preprocess_filt_wal()
    elif 'fft' in preprocessing_methods:
        # FIXME: The Landsat 4/5 FFT preprocessor looks for the image corners to
        #        determine the scene rotation, but Geogrid + autoRIFT round the
        #        corners when co-registering and chop the non-overlapping corners
        #        when subsetting to the common image overlap. FFT filter needs to
        #        be applied to the native images before they are processed by
        #        Geogrid or autoRIFT.
        # obj.preprocess_filt_wal()
        # obj.preprocess_filt_fft()
        warnings.warn('FFT filtering must be done before processing with geogrid! Be careful when using this method', UserWarning)
    else:
        obj.preprocess_filt_hps()
#    obj.I1 = np.abs(I1)
#    obj.I2 = np.abs(I2)
    print("Pre-process Done!!!")
    print(time.time()-t1)

    t1 = time.time()
#    obj.DataType = 0
    obj.uniform_data_type()
    print("Uniform Data Type Done!!!")
    print(time.time()-t1)

#    pdb.set_trace()

#    obj.sparseSearchSampleRate = 16

    obj.OverSampleRatio = 64
#    obj.colfiltChunkSize = 4

    #   OverSampleRatio can be assigned as a scalar (such as the above line) or as a Python dictionary below for intelligent use (ChipSize-dependent).
    #   Here, four chip sizes are used: ChipSize0X*[1,2,4,8] and four OverSampleRatio are considered [16,32,64,128]. The intelligent selection of OverSampleRatio (as a function of chip size) was determined by analyzing various combinations of (OverSampleRatio and chip size) and comparing the resulting image quality and statistics with the reference scenario (where the largest OverSampleRatio of 128 and chip size of ChipSize0X*8 are considered).
    #   The selection for the optical data flag is based on Landsat-8 data over an inland region (thus stable and not moving much) of Greenland, while that for the radar flag (optflag = 0) is based on Sentinel-1 data over the same region of Greenland.
    if CSMINx0 is not None:
        if (optflag == 1):
            obj.OverSampleRatio = {obj.ChipSize0X:16,obj.ChipSize0X*2:32,obj.ChipSize0X*4:64,obj.ChipSize0X*8:64}
        else:
            obj.OverSampleRatio = {obj.ChipSize0X:32,obj.ChipSize0X*2:64,obj.ChipSize0X*4:128,obj.ChipSize0X*8:128}



#    ########## export preprocessed images to files; can be commented out if not debugging
#
#    t1 = time.time()
#
#    I1 = obj.I1
#    I2 = obj.I2
#
#    length,width = I1.shape
#
#    filename1 = 'I1_uint8_hpsnew.off'
#
#    slcFid = open(filename1, 'wb')
#
#    for yy in range(length):
#        data = I1[yy,:]
#        data.astype(np.float32).tofile(slcFid)
#
#    slcFid.close()
#
#    img = isceobj.createOffsetImage()
#    img.setFilename(filename1)
#    img.setBands(1)
#    img.setWidth(width)
#    img.setLength(length)
#    img.setAccessMode('READ')
#    img.renderHdr()
#
#
#    filename2 = 'I2_uint8_hpsnew.off'
#
#    slcFid = open(filename2, 'wb')
#
#    for yy in range(length):
#        data = I2[yy,:]
#        data.astype(np.float32).tofile(slcFid)
#
#    slcFid.close()
#
#    img = isceobj.createOffsetImage()
#    img.setFilename(filename2)
#    img.setBands(1)
#    img.setWidth(width)
#    img.setLength(length)
#    img.setAccessMode('READ')
#    img.renderHdr()
#
#    print("output Done!!!")
#    print(time.time()-t1)


    ########## run Autorift
    t1 = time.time()
    print("AutoRIFT Start!!!")
    obj.runAutorift()
    print("AutoRIFT Done!!!")
    print(time.time()-t1)

    # Grow the nodata mask by one pixel in every direction (3x3 dilation)
    kernel = np.ones((3,3),np.uint8)
    noDataMask = cv2.dilate(noDataMask.astype(np.uint8),kernel,iterations = 1)
    # Builtin bool: the np.bool alias was removed in NumPy 1.24
    noDataMask = noDataMask.astype(bool)


    return obj.Dx, obj.Dy, obj.InterpMask, obj.ChipSizeX, obj.GridSpacingX, obj.ScaleChipSizeY, obj.SearchLimitX, obj.SearchLimitY, obj.origSize, noDataMask
