# Cross Correlation Check On Whole Stack with VV

### Import Python libraries

In [None]:
import pathlib
import math
import traceback

from osgeo import gdal
import shapely
from shapely import wkt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation
import dask.distributed
from PIL import Image

import opensarlab_lib as asfn

### Define the search area WKT

In [None]:
# Search area as a WKT polygon: a closed five-point ring (first and last vertex are the same).
# Its bounding box is used further down as the gdal.Warp output bounds, with SRS EPSG:4326.
SEARCH_WKT = "POLYGON((-151.4766 62.3946,-143.781 62.3946,-143.781 65.6086,-151.4766 65.6086,-151.4766 62.3946))"

### Define helper functions and variables used later in the notebook

In [None]:
def tiff_to_df(tiff_filename: str):
    """
    Convert the first band of a tiff file to a pandas DataFrame.

    Pixels at or below the 1st percentile, or at or above the 99th percentile,
    are masked to NaN to reduce saturation (so features stay visible when plotted).

    arg: tiff_filename
    type: str, path of the raster to open with GDAL

    returns: DataFrame with the same shape as the band; masked pixels are NaN

    raises: RuntimeError if GDAL cannot open the file
    """
    img = gdal.Open(tiff_filename)
    if img is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning None.
        raise RuntimeError(f"GDAL could not open raster: {tiff_filename}")

    raster0 = img.GetRasterBand(1).ReadAsArray()
    df = pd.DataFrame(raster0)

    # Use the NaN-aware percentile: with plain np.percentile a single NaN pixel
    # makes both cutoffs NaN, every comparison False, and the whole frame NaN.
    lower_cutoff, upper_cutoff = np.nanpercentile(raster0, [1, 99])
    df = df[(df > lower_cutoff) & (df < upper_cutoff)]

    return df

def plot_tiffs(df_tiffs: list, width=3):
    """
    Plot every tiff in its own subplot on a grid that is ``width`` columns wide.

    arg: df_tiffs
    type: list of dicts, each with a 'data' entry (handed to imshow) and a
          'name' entry (used as the subplot title)

    arg: width
    type: int, number of subplot columns (default 3)

    raises: ValueError if width is smaller than 1
    """
    number_of_x_figs = int(width)
    if number_of_x_figs < 1:
        raise ValueError(f"width must be at least 1, got {width}")

    if not df_tiffs:
        # plt.subplots() rejects a grid with zero rows, so there is nothing to draw.
        print("No tiffs to plot.")
        return

    number_of_y_figs = math.ceil(len(df_tiffs) / number_of_x_figs)

    print(f"A grid of ({number_of_y_figs}, {number_of_x_figs}).")

    # squeeze=False keeps axs two-dimensional even for a single row or column,
    # so the [row, column] indexing below works for any number of tiffs.
    fig, axs = plt.subplots(number_of_y_figs, number_of_x_figs, squeeze=False)

    for num, df_tiff in enumerate(df_tiffs):
        yy, xx = divmod(num, number_of_x_figs)

        axs[yy, xx].imshow(df_tiff['data'])
        axs[yy, xx].set_title(df_tiff['name'])

def save_tiff_as_png(save_path_and_name: str):
    """
    Save an RGB PNG copy of the image file at ``save_path_and_name``.

    The PNG is written to the same path with ".png" appended to the full
    file name (the original extension is kept).

    arg: save_path_and_name
    type: str, path of the image to convert
    """
    # The context manager closes the source file once the PNG has been written.
    with Image.open(save_path_and_name) as img:
        img.convert('RGB').save(save_path_and_name + ".png")
        
# Absolute path of the directory the notebook is running from
CWD = pathlib.Path.cwd()
CWD

### Get all VV tiff files to compare

In [None]:
import shutil

tiff_paths = []

# Get all VV's. Pure-Python globbing replaces `!ls`, so paths containing spaces
# work and a failed listing cannot leak its error text into the file list.
# Sorted so list positions (used for pairing further down) are reproducible.
vvs = sorted(str(found) for found in pathlib.Path(CWD, "data").glob("S1A_*/S1A_*_VV.tif"))

# Copy every VV tiff into a flat work/ directory next to the notebook
pathlib.Path(CWD, "work").mkdir(parents=True, exist_ok=True)
for filepath in vvs:
    tiff_path = pathlib.Path(filepath)
    shutil.copy(filepath, pathlib.Path(CWD, "work", tiff_path.name))
    # Stored as plain strings: later cells call str.split / str.replace on them
    tiff_paths.append(f"{CWD}/work/{tiff_path.name}")

tiff_paths

### Display the original tiffs

In [None]:
# Load every copied tiff and draw them together on one grid
with asfn.work_dir(f"{CWD}/work/"):
    df_tiffs = [
        dict(name=path_str.rsplit('/', 1)[-1], data=tiff_to_df(str(path_str)))
        for path_str in tiff_paths
    ]
    plot_tiffs(df_tiffs)
# Drop the reference to the full-size rasters once they have been plotted
del df_tiffs

### Subset to the AOI and show

In [None]:
# Clip every tiff to the search area, then load, export to PNG and plot the clipped rasters
with asfn.work_dir(f"{CWD}/work/"):

    shape = wkt.loads(SEARCH_WKT)
    bounds = shapely.bounds(shape)
    print(f"Shapely bounds: {bounds}")

    # Each output sits next to its input, with "_subset" inserted before the extension
    tiff_subset_paths = [
        tiff_path.replace("_VV.tif", "_VV_subset.tif") for tiff_path in tiff_paths
    ]

    for tiff_path, tiff_subset_path in zip(tiff_paths, tiff_subset_paths):
        # Destination first, source second; the bounds are given in EPSG:4326
        gdal.Warp(
            str(tiff_subset_path),
            str(tiff_path),
            outputBounds=(bounds[0], bounds[1], bounds[2], bounds[3]),
            outputBoundsSRS="EPSG:4326",
        )

    print(f"Subset paths: {tiff_subset_paths}")

    df_subset_tiffs = []
    for tiff_subset_path in tiff_subset_paths:
        # Read the clipped raster first, then write its PNG copy
        data = tiff_to_df(str(tiff_subset_path))
        save_tiff_as_png(tiff_subset_path)

        name = tiff_subset_path.rsplit('/', 1)[-1]
        df_subset_tiffs.append({"name": name, "data": data})

    plot_tiffs(df_subset_tiffs)

### Show where the NaNs are in the subset tiffs

In [None]:
#df_tiff_nans = [{"name": tiff_subset_path.split('/')[-1], "data": np.isnan(tiff_to_df(str(tiff_subset_path)))} for tiff_subset_path in tiff_subset_paths]
#plot_tiffs(df_tiff_nans)
#del df_tiff_nans

### Perform cross-correlation on the tiffs, one pair at a time

In [None]:
def get_corr(df_subset):
    """
    Compute the phase cross-correlation for one pair of subset tiffs.

    Progress and errors are reported through the current dask worker's event
    log (topics "my_info" and "my_error"), so this is meant to be run as a
    dask task.

    arg: df_subset
    type: sequence of (index_1, tiff_1, index_2, tiff_2), where each tiff is a
          dict with a 'name' entry and a 'data' entry (a DataFrame)

    returns: tuple (index_1, index_2, name_1, name_2, shift, error, phase).
             When anything fails, shift and phase are '' and error holds the
             formatted traceback instead of the correlation error value.
    """
    import dask.distributed

    def logg(level: str, msg: str):
        dask.distributed.get_worker().log_event(level, msg)

    index_1 = df_subset[0]
    df_subset_1 = df_subset[1].get('data', None)
    df_name_1 = df_subset[1].get('name', None)

    index_2 = df_subset[2]
    df_subset_2 = df_subset[3].get('data', None)
    df_name_2 = df_subset[3].get('name', None)

    try:
        logg("my_info", f"Replace Nans with zeroes for ({index_1}, {index_2})")
        df_subset_1 = df_subset_1.replace(np.nan, 0)
        df_subset_2 = df_subset_2.replace(np.nan, 0)

        logg("my_info", f"Calculate the phase cross corr for ({index_1}, {index_2})")
        shift, error, phase = phase_cross_correlation(np.array(df_subset_1), np.array(df_subset_2))
        return (index_1, index_2, df_name_1, df_name_2, shift, error, phase)

    except Exception as e:
        # Capture the traceback once, log it with the pair's indices, and hand
        # it back in the error slot so one bad pair does not fail the whole map.
        # (The log message previously formatted undefined names and raised
        # NameError from inside this handler.)
        error_text = str(traceback.format_exc())
        logg("my_error", f"{e}: {error_text} for ({index_1}, {index_2})")
        return (index_1, index_2, df_name_1, df_name_2, '', error_text, '')

In [None]:
import logging

# Build the work list: each entry is (index, tiff dict, index, tiff dict),
# pairing tiffs at positions 2-4 with tiffs at positions 5-7
df_subset = []
for first, df_subset_tiff_1 in enumerate(df_subset_tiffs):
    for second, df_subset_tiff_2 in enumerate(df_subset_tiffs):
        if first not in [2,3,4] or second not in [5,6,7]:
            continue
        df_subset.append((first, df_subset_tiff_1, second, df_subset_tiff_2))

upper_ram_limit = 115
RAM_PER_WORKER_GB = 25

# When the cell is re-run, cancel the futures left over from the previous run
if 'futures' in locals():
    for f in futures:
        f.cancel()

# Worker count is the number of whole per-worker RAM shares that fit under the limit.
# Optional settings left disabled: threads_per_worker=1, nthreads=1, silence_logs=logging.ERROR
client = dask.distributed.Client(
    n_workers=upper_ram_limit // RAM_PER_WORKER_GB,
    memory_limit=f"{RAM_PER_WORKER_GB}GB",
    memory_target_fraction=0.95,
    memory_pause_fraction=0.95,
    processes=False,
)

#client.restart()

print(client)

# One get_corr task per pair; the progress call is the cell's displayed value
futures = client.map(get_corr, df_subset)
dask.distributed.progress(futures)

### Results

In [None]:
# Collect the tuples returned by get_corr for every submitted pair
results = client.gather(futures)
# One row per pair; the column order follows the tuple returned by get_corr.
# For failed pairs 'Shift' and 'Phase' are '' and 'Error' holds a traceback string.
df = pd.DataFrame(results, columns =['Ind 1', 'Ind 2', 'Name 1', 'Name 2', 'Shift', 'Error', 'Phase'])
df
#path = f"/home/jovyan/calval-RTC/cross_correlation_relative_geolocation_evaluation/work/{df['Name 1'][1]}"
#print(path)
#print(df_subset_tiffs)
#images = plt.imshow(path)
#images

In [None]:
# 'Error' entry of the result row labelled 2
df.loc[2, 'Error']

In [None]:
# Events that get_corr logged under the "my_info" topic
print("Info logs: ")
print(client.get_events("my_info"))

# Events logged under "my_error"; left as the last expression so the notebook displays it
print("\nError logs: ")
client.get_events("my_error")

In [None]:
#print(futures)

# TO CANCEL ALL WORKERS
#[f.cancel() for f in futures]