# Cross Correlation Check

### Import Python libraries

In [None]:
import pathlib
import shutil

from osgeo import gdal
import shapely
from shapely import wkt
from ipyfilechooser import FileChooser
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation

import opensarlab_lib as asfn

### Define helper functions and variables used later in the notebook

In [None]:
def tiff_to_df(tiff_filename: str, lower_percentile: float = 1, upper_percentile: float = 99):
    """
    Read band 1 of a tiff file into a pandas DataFrame, masking extreme values.

    Pixels whose value is at or below the ``lower_percentile`` cutoff, or at or
    above the ``upper_percentile`` cutoff, are replaced with NaN. This keeps a
    few very bright/dark pixels from saturating the colour scale when the
    DataFrame is plotted, so features remain visible.

    Parameters
    ----------
    tiff_filename : str
        Path of the raster to open with GDAL.
    lower_percentile, upper_percentile : float
        Cutoffs in percent (0-100). Defaults are 1 and 99.

    Returns
    -------
    pandas.DataFrame
        Band-1 pixel values; masked pixels are NaN.

    Raises
    ------
    OSError
        If GDAL cannot open ``tiff_filename``.
    """
    img = gdal.Open(tiff_filename)
    if img is None:
        # Without gdal.UseExceptions(), Open returns None instead of raising;
        # fail here with a clear message rather than an AttributeError below.
        raise OSError(f"GDAL could not open {tiff_filename!r}")
    raster = img.GetRasterBand(1).ReadAsArray()
    df = pd.DataFrame(raster)

    # nanpercentile ignores NaN pixels already present in the raster. Plain
    # np.percentile would return NaN there, every comparison below would be
    # False, and the entire image would be masked out.
    low, high = np.nanpercentile(raster, [lower_percentile, upper_percentile])
    return df[(df > low) & (df < high)]
    
# Absolute path of the directory the notebook is running from; the data/ and
# work/ folders used below are resolved relative to it.
CWD = pathlib.Path.cwd()
CWD

### Choose the two tiff files to compare (reference first, then secondary)

In [None]:
# File chooser for the reference tiff, opened in the notebook's data/ folder
fc1 = FileChooser(path=f'{CWD}/data/')
display(fc1)

In [None]:
# File chooser for the secondary tiff, opened in the notebook's data/ folder
fc2 = FileChooser(path=f'{CWD}/data/')
display(fc2)

In [None]:
# Read the selections from the two file choosers. Fail early with a clear
# message if one was left empty; otherwise the path/filename are unset (None)
# and the copy step would look for a file literally named "None/None".
reference_path = fc1.selected_path
reference_file = fc1.selected_filename
if not reference_path or not reference_file:
    raise ValueError("No reference tiff selected: choose a file in the first file chooser and re-run this cell.")
print(reference_path, reference_file)

secondary_path = fc2.selected_path
secondary_file = fc2.selected_filename
if not secondary_path or not secondary_file:
    raise ValueError("No secondary tiff selected: choose a file in the second file chooser and re-run this cell.")
print(secondary_path, secondary_file)

In [None]:
# Copy the chosen tiffs into a scratch work/ directory under fixed names, so
# the rest of the notebook can refer to them as reference.tif / secondary.tif.
# Done in Python rather than with `!cp` so that paths containing spaces or
# shell metacharacters are handled correctly (the shell form was unquoted).
work_path = CWD / "work"
work_path.mkdir(parents=True, exist_ok=True)
shutil.copy(pathlib.Path(reference_path, reference_file), work_path / "reference.tif")
shutil.copy(pathlib.Path(secondary_path, secondary_file), work_path / "secondary.tif")

### Display the original two tiffs

In [None]:
# Load both full-size rasters (percentile-masked by tiff_to_df) and plot them
# side by side for a quick visual comparison.
with asfn.work_dir(f"{CWD}/work/"):
    df_reference, df_secondary = (
        tiff_to_df(tif_name) for tif_name in ("reference.tif", "secondary.tif")
    )

    fig = plt.figure(figsize=(16, 8))
    ax1, ax2 = (
        fig.add_subplot(1, 2, position, title=panel_title)
        for position, panel_title in enumerate(("reference", "secondary"), start=1)
    )
    ax1.imshow(df_reference)
    ax2.imshow(df_secondary)

### Subset both images to the area of interest (AOI) and display them

In [None]:
# Subset reference and secondary to the area of interest (AOI)
with asfn.work_dir(f"{CWD}/work/"):

    # AOI polygon in lon/lat coordinates (EPSG:4326, matching outputBoundsSRS below)
    SEARCH_WKT = "POLYGON((-118.3968 33.9967,-117.9123 33.9967,-117.9123 34.3065,-118.3968 34.3065,-118.3968 33.9967))"

    shape = wkt.loads(SEARCH_WKT)
    bounds = shapely.bounds(shape)  # (minx, miny, maxx, maxy)
    print(f"Shapely bounds: {bounds}")

    # Warp each image to the AOI extent with the GDAL Python API instead of
    # `!rm` + `!gdalwarp`. This avoids the "No such file" error `rm` printed on
    # the first run. A failed warp now stops the cell instead of leaving a
    # stale or missing subset for tiff_to_df to read.
    for name in ("reference", "secondary"):
        subset_file = f"{name}_subset.tif"
        pathlib.Path(subset_file).unlink(missing_ok=True)  # start clean, as the original `rm` did
        subset_ds = gdal.Warp(
            subset_file,
            f"{name}.tif",
            outputBounds=[float(b) for b in bounds],
            outputBoundsSRS="EPSG:4326",
        )
        if subset_ds is None:
            raise RuntimeError(f"gdal.Warp failed to subset {name}.tif")
        subset_ds = None  # close the dataset so it is flushed to disk before being read back

    df_reference = tiff_to_df("reference_subset.tif")
    df_secondary = tiff_to_df("secondary_subset.tif")

    fig = plt.figure(figsize=(16, 8))
    ax1 = fig.add_subplot(121, title='reference')
    ax2 = fig.add_subplot(122, title='secondary')
    ax1.imshow(df_reference)
    ax2.imshow(df_secondary)

### Show where the NaNs are in the subset tiffs

In [None]:
# Boolean masks that are True at every NaN pixel, including pixels masked out
# by tiff_to_df's percentile cutoff. The cross-correlation cell reuses them.
nan_reference_mask, nan_moving_mask = (np.isnan(frame) for frame in (df_reference, df_secondary))

fig = plt.figure(figsize=(16, 8))
ax1, ax2 = (
    fig.add_subplot(1, 2, position, title=panel_title)
    for position, panel_title in enumerate(('NaNs - reference', 'NaNs - secondary'), start=1)
)
ax1.imshow(nan_reference_mask)
ax2.imshow(nan_moving_mask)

### Perform cross-correlation on the tiffs

In [None]:
with asfn.work_dir(f"{CWD}/work/"):

    # Find cross correlation
    # https://scikit-image.org/docs/stable/api/skimage.registration.html#skimage.registration.phase_cross_correlation

    # 1) Unmasked: NaNs must be removed first, so they are set to zero.
    print("Replacing NaNs with zero....")
    shift, error, phase = phase_cross_correlation(df_reference.fillna(0), df_secondary.fillna(0))
    print(f"Shift vector (in pixels) required to register moving_image with reference_image: {shift}")
    print(f"Translation invariant normalized RMS error between reference_image and moving_image: {error}")
    print(f"Global phase difference between the two images (should be zero if images are non-negative).: {phase}\n")

    # 2) Masked: only pixels that are valid (not NaN) in each image take part.
    # The masks are recomputed here so this cell does not depend on the
    # NaN-plotting cell having been run first.
    print("Using NaN mask...")
    reference_valid = ~np.isnan(df_reference)
    secondary_valid = ~np.isnan(df_secondary)
    masked_result = phase_cross_correlation(
        df_reference, df_secondary, reference_mask=reference_valid, moving_mask=secondary_valid
    )
    # Recent scikit-image releases return (shift, nan, nan) for masked
    # registration; older releases return only the shift array. Accept both.
    shift = masked_result[0] if isinstance(masked_result, tuple) else masked_result
    print(f"Shift vector (in pixels) required to register moving_image with reference_image: {shift}")
    print("No error or phase given with masks.")