# Cross Correlation Check On Whole Stack with VV

### Import Python libraries

In [None]:
import pathlib
import math

from osgeo import gdal
import shapely
from shapely import wkt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation

import opensarlab_lib as asfn

### Define the search area WKT

In [None]:
# Search area as a WKT polygon (closed ring: the first and last vertex are the same).
# The bounds of this polygon are later passed to gdal.Warp with outputBoundsSRS="EPSG:4326".
SEARCH_WKT = (
    "POLYGON(("
    "-151.4766 62.3946,"
    "-143.781 62.3946,"
    "-143.781 65.6086,"
    "-151.4766 65.6086,"
    "-151.4766 62.3946"
    "))"
)

### Define helper functions and variables used later

In [None]:
def tiff_to_df(tiff_filename: str) -> pd.DataFrame:
    """
    Convert the first band of a tiff file to a pandas DataFrame.

    Pixel values at or outside the 1st/99th percentiles are replaced by NaN
    to reduce saturation (so we can see features when plotting).

    arg: tiff_filename
    type: str, path to a raster that GDAL can open

    returns: DataFrame with one cell per pixel of band 1; clipped pixels are NaN

    raises: RuntimeError if GDAL cannot open the file
    """
    img = gdal.Open(tiff_filename)
    if img is None:
        # gdal.Open returns None instead of raising unless gdal.UseExceptions() is enabled
        raise RuntimeError(f"GDAL could not open raster: {tiff_filename}")

    raster0 = img.GetRasterBand(1).ReadAsArray()
    # Drop the dataset reference now that the pixels are in memory
    img = None

    # nanpercentile ignores NaN pixels. np.percentile returns NaN as soon as one
    # pixel is NaN, which makes both comparisons below False everywhere and
    # turns the whole image into NaN.
    low, high = np.nanpercentile(raster0, [1, 99])

    df = pd.DataFrame(raster0)

    # Cutoff pixel values to 1 and 99 percentile; everything else becomes NaN
    return df[(df > low) & (df < high)]

def plot_tiffs(df_tiffs: list):
    """
    Show each dataframe as an image in a grid that is 4 plots wide.

    The grid has as many rows as needed to hold all dataframes; the grid
    dimensions are printed before plotting. Nothing is drawn for an empty list.

    arg: df_tiffs
    type: list of dataframes
    """
    len_tiffs = len(df_tiffs)
    number_of_x_figs = 4
    number_of_y_figs = math.ceil(len_tiffs / number_of_x_figs)

    print(f"A grid of ({number_of_y_figs}, {number_of_x_figs}).")

    if not df_tiffs:
        # plt.subplots rejects a grid with zero rows, so there is nothing to draw
        return

    # squeeze=False keeps axs two-dimensional even when there is a single row;
    # otherwise axs[yy, xx] fails for four or fewer images.
    _, axs = plt.subplots(number_of_y_figs, number_of_x_figs, squeeze=False)

    for num, df in enumerate(df_tiffs):
        # Fill the grid row by row
        yy, xx = divmod(num, number_of_x_figs)

        axs[yy, xx].imshow(df)
        axs[yy, xx].set_title('')

# Absolute path of the current working directory (where the notebook runs)
CWD = pathlib.Path.cwd()
CWD

### Get all VV tiff files to compare

In [None]:
import shutil

# Get all VV's: every file matching data/S1A_*/S1A_*_VV.tif below the notebook directory.
# Globbing in Python (instead of `!ls`) copes with spaces in paths and only ever
# yields real paths, giving an empty list when nothing matches.
tiff_paths = sorted(CWD.glob("data/S1A_*/S1A_*_VV.tif"))

# Kept for backward compatibility: the same paths as plain strings
vvs = [str(tiff_path) for tiff_path in tiff_paths]

# Copy each tiff into work/ under its own file name.
# NOTE(review): tiff_paths still points at the originals in data/, exactly as
# before; confirm whether the work/ copies were meant to be used instead.
work_dir = CWD / "work"
work_dir.mkdir(parents=True, exist_ok=True)
for tiff_path in tiff_paths:
    shutil.copy(tiff_path, work_dir / tiff_path.name)

tiff_paths

### Display the original tiffs

In [None]:
# Load every tiff as a dataframe and plot them in one grid,
# then drop the reference to the dataframes.
# NOTE(review): asfn.work_dir presumably switches into work/ for the duration
# of the block; confirm against opensarlab_lib.
with asfn.work_dir(f"{CWD}/work/"):
    df_tiffs = list(map(tiff_to_df, map(str, tiff_paths)))
    plot_tiffs(df_tiffs)
del df_tiffs

### Subset to the AOI and show

In [None]:
# Clip every tiff to the bounding box of SEARCH_WKT, then load and plot the subsets
with asfn.work_dir(f"{CWD}/work/"):

    shape = wkt.loads(SEARCH_WKT)
    bounds = shapely.bounds(shape)
    print(f"Shapely bounds: {bounds}")

    # gdal.Warp receives the first four bound values as a tuple
    warp_bounds = (bounds[0], bounds[1], bounds[2], bounds[3])

    tiff_subset_paths = []

    for tiff_path in tiff_paths:
        # <name>.tif -> <name>_subset.tif, written next to the source file
        subset_name = f"{tiff_path.stem}_subset{tiff_path.suffix}"
        tiff_subset_path = str(tiff_path.with_name(subset_name))
        tiff_subset_paths.append(tiff_subset_path)

        gdal.Warp(
            tiff_subset_path,
            str(tiff_path),
            outputBounds=warp_bounds,
            outputBoundsSRS="EPSG:4326",
        )

    df_subset_tiffs = list(map(tiff_to_df, tiff_subset_paths))
    plot_tiffs(df_subset_tiffs)


### Show where the NaNs are in the subset tiffs

In [None]:
# One boolean frame per subset tiff: True wherever tiff_to_df left a NaN
df_tiff_nans = [
    np.isnan(df_subset)
    for df_subset in map(tiff_to_df, map(str, tiff_subset_paths))
]
plot_tiffs(df_tiff_nans)
del df_tiff_nans

### Perform pairwise cross-correlation on the subset tiffs

In [None]:
with asfn.work_dir(f"{CWD}/work/"):

    # Find cross correlation
    # https://scikit-image.org/docs/stable/api/skimage.registration.html#skimage.registration.phase_cross_correlation

    # Replace NaNs with 0 once per image (not twice per pair inside the nested
    # loop) and hand plain arrays to phase_cross_correlation.
    filled_images = [df_subset.fillna(0).to_numpy() for df_subset in df_subset_tiffs]

    # One (reference index, moving index, shift, error, phase) tuple per ordered
    # pair, including each image compared with itself.
    results = []

    for x, reference_image in enumerate(filled_images):
        for y, moving_image in enumerate(filled_images):
            shift, error, phase = phase_cross_correlation(reference_image, moving_image)
            results.append((x, y, shift, error, phase))

    print("Shift: Shift vector (in pixels) required to register moving_image with reference_image.")
    print("Error: Translation invariant normalized RMS error between reference_image and moving_image.")
    print("Phase: Global phase difference between the two images (should be zero if images are non-negative).\n")

    print(results)

    # Disabled alternative that masks the NaNs instead of zero-filling them:
    #print("Using NaN mask...")
    #shift = phase_cross_correlation(df_reference, df_secondary, reference_mask=~nan_reference_mask, moving_mask=~nan_moving_mask)
    #print(f"Shift vector (in pixels) required to register moving_image with reference_image: {shift}")
    #print(f"No error or phase given with masks.")