# Check Cross Correlation of a Stack

**This notebook is best run on a machine with at least 120 GB of RAM and 32 CPUs.**

This notebook assumes that a stack of RTC scenes has been created and placed in the `./work/original/` directory (possibly through the accompanying `1b_get_prepared_data_from_s3` notebook).

Once the imports have run, Dask is set up (section 0), and the VVs are selected (section 1), the rest of the notebook can be run automatically. Once a section has been run, any later section can be rerun independently. This reduces the non-linear aspects of notebooks and makes it easier to experiment with the code.

The following procedures will be applied in this notebook:

1. Select VV tiffs from the prepared stack. Any `*_VV.tif` within the selected directory or its child directories will be copied to `./work/original/`.
1. Superset tiffs to a common AOI and save re-formatted tiffs in `./work/superset/`.
1. Because the data contain spikes, flatten the tiffs by setting values below the 1st percentile and above the 99th percentile to NaN. Save tiffs in `./work/flatten/`.
1. Because of NaNs and other areas of no-data, evenly tile the tiffs (default to 8x8). To speed things up, we use a Dask LocalCluster to multiprocess. Tiles are saved in `./work/tiles/`.
1. Apply the cross-correlation function to nearest-chronological pairs of corresponding tiles. If more than 10% of a tile is NaN, treat the whole tile as NaN. Any remaining NaNs are converted to zero. The correlation is upsampled by a factor of ten for sub-pixel precision. The results include the shift in x and y and the RMSE, converted from pixels to meters using `METERS_PER_PIXEL` (30). The results are saved as JSON files for each tile pair in `./work/correlation/`.
1. Perform analysis on json results. Results are read into a Pandas DataFrame. A statistical description and graph of the results are shown in two ways: all tiles in a scene are averaged and all tiles are averaged in time. 
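
The NaN rule and the unit conversion from step 5 can be sketched in plain Python. This is a minimal illustration, not the notebook's implementation: `NAN_THRESHOLD`, `nan_fraction`, `is_usable`, and `pixels_to_meters` are hypothetical names, and tiles are represented as lists of rows rather than rasters.

```python
import math

METERS_PER_PIXEL = 30  # same value the notebook defines
NAN_THRESHOLD = 0.10   # hypothetical name for the 10% cutoff


def nan_fraction(tile):
    """Fraction of NaN pixels in a tile given as a list of rows."""
    values = [v for row in tile for v in row]
    return sum(math.isnan(v) for v in values) / len(values)


def is_usable(tile):
    """A tile is correlated only if at most 10% of its pixels are NaN."""
    return nan_fraction(tile) <= NAN_THRESHOLD


def pixels_to_meters(shift_in_pixels):
    """Convert a shift measured in pixels to meters."""
    return shift_in_pixels * METERS_PER_PIXEL


nan = float("nan")

# 10 x 10 tile with a single NaN pixel
mostly_valid = [[1.0] * 10 for _ in range(10)]
mostly_valid[0][0] = nan

# 10 x 10 tile whose top half is NaN
half_empty = [[nan] * 10 for _ in range(5)] + [[1.0] * 10 for _ in range(5)]

print(is_usable(mostly_valid), is_usable(half_empty))
```

With an upsampling factor of ten, shifts are resolved in steps of a tenth of a pixel, which corresponds to 3 m at 30 m per pixel.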

### Some Prerequisites

In [None]:
import pathlib
import math
from datetime import datetime
import re
import json

from ipyfilechooser import FileChooser
import rasterio
from rasterio.mask import mask
from shapely import geometry
from osgeo import gdal
import pandas as pd
import numpy as np
import dask.distributed
from skimage.registration import phase_cross_correlation
import matplotlib.pyplot as plt

import opensarlab_lib as asfn

%matplotlib inline

# Get working directory of notebook
CWD = pathlib.Path().cwd()
CWD

METERS_PER_PIXEL = 30

### 0. Setup Dask Methods

Dask on a LocalCluster is used for multiprocessing to make some operations go faster. It is assumed that only one Dask client is used at one time.

In [None]:
def setup_dask(ram_per_worker_gb:int=20, num_workers:int=20, num_threads_per_worker:int=1) -> dask.distributed.Client:
    cluster = dask.distributed.LocalCluster(
        threads_per_worker=num_threads_per_worker,
        n_workers=num_workers,
        memory_limit=f"{ram_per_worker_gb}GB",
        processes=True
    )

    return dask.distributed.Client(cluster)

def teardown_dask(client: dask.distributed.Client) -> None:
    client.shutdown()

def do_dask(client: dask.distributed.Client, callback, args: list):
    try:
        futures = client.map(callback, args)
        dask.distributed.progress(futures)
    except Exception as e:
        print(f"Error in dask: {e}")
        teardown_dask(client)
        return
    
    _  = client.gather(futures)

### 1. Select VVs

Choose the parent directory of all child directories that contain the desired stack of VVs.
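
As a rough illustration of how the recursive `**/*_VV.tif` pattern behaves, the sketch below builds a throwaway directory tree and globs it. The folder and file names are made up for the example.

```python
import pathlib
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)

    # Hypothetical product folders; only the *_VV.tif files should match
    for rel in ["sceneA/a_VV.tif", "sceneA/a_VH.tif", "sceneB/nested/b_VV.tif"]:
        path = root / rel
        path.parent.mkdir(parents=True, exist_ok=True)
        path.touch()

    # `**` descends into every child directory of the chosen parent
    found = sorted(p.name for p in root.glob("**/*_VV.tif"))

print(found)
```

Because only the file name is kept when copying into `./work/original/`, two products with the same name in different child directories would overwrite each other.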

In [None]:
fc = FileChooser(f'{CWD}/data/')
fc.show_only_dirs = True
display(fc)

In [None]:
# Remove any staged intermediate files to work in a clean area
!mkdir -p "{CWD}/work/original"
for filepath in pathlib.Path(f"{CWD}/work/original").glob("*.tif"):
    filepath.unlink()

In [None]:
all_vv_paths = pathlib.Path(fc.selected_path).glob("**/*_VV.tif")

# Copy desired products to work directory.
for source_path in all_vv_paths:
    print(f"Copying {source_path} to {CWD}/work/original/{source_path.name}")
    !cp "{source_path}" "{CWD}/work/original/{source_path.name}"

### 2. Superset VVs

Scene frames tend to shift over time, so the extent covered by each scene differs from frame to frame. For the cross-correlation to work properly and for a more accurate comparison, all scenes need to be "normalized" by growing or shrinking each one to a common rectangular extent.

From the extent metadata, get the full superset coordinates for all stack scenes.
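
The superset is the union of the individual bounding boxes. A minimal sketch with made-up coordinates, assuming bounds objects that expose `left`, `bottom`, `right`, and `top` (the `Bounds` tuple and `union_bounds` helper here are hypothetical stand-ins, not rasterio types):

```python
import math
from collections import namedtuple

# Hypothetical stand-in for a raster's bounding box
Bounds = namedtuple("Bounds", ["left", "bottom", "right", "top"])


def union_bounds(all_bounds):
    """Smallest box that contains every box in `all_bounds`."""
    left = bottom = math.inf
    right = top = -math.inf
    for b in all_bounds:
        left = min(left, b.left)
        bottom = min(bottom, b.bottom)
        right = max(right, b.right)
        top = max(top, b.top)
    return Bounds(left, bottom, right, top)


# Two frames that are offset from each other by 30 units in each direction
frames = [
    Bounds(100.0, 200.0, 400.0, 500.0),
    Bounds(130.0, 170.0, 430.0, 470.0),
]
superset = union_bounds(frames)
print(superset)
```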

In [None]:
# Open all the tiffs and get overall coords.
superset = {
    'left': math.inf,
    'bottom': math.inf,
    'right': -math.inf,
    'top': -math.inf
}

# The SRS is set to the first raster. It is assumed that the SRSs are the same (or close enough) for all.
output_srs = None

vv_original_paths = pathlib.Path(f"{CWD}/work/original").glob(f"*_VV.tif")

for i, original_path in enumerate(vv_original_paths):

    # Open each raster only long enough to read its bounds and CRS
    with rasterio.open(original_path) as raster:
        raster_bounds = raster.bounds
        if i == 0:
            output_srs = raster.crs
    print(raster_bounds)
    
    superset = {
        'left': min(superset['left'], raster_bounds.left),
        'bottom': min(superset['bottom'], raster_bounds.bottom), 
        'right': max(superset['right'], raster_bounds.right), 
        'top': max(superset['top'], raster_bounds.top)
    }

print(f"Superset box coords: {superset}")
print(f"Output SRS: {output_srs}")

In [None]:
# Remove any staged intermediate files to work in a clean area
!mkdir -p "{CWD}/work/superset"
for filepath in pathlib.Path(f"{CWD}/work/superset").glob("*.tif"):
    filepath.unlink()

In [None]:
output_bounds = (
            superset['left'], 
            superset['bottom'],
            superset['right'],
            superset['top'],
        )

print(f"Output bounds (superset) set to '{output_bounds}'")
print(f"Output SRS set to '{output_srs}'")

# Superset and save VVs
vv_original_paths = pathlib.Path(f"{CWD}/work/original").glob(f"*_VV.tif")
for original_path in vv_original_paths:
    
    superset_path = pathlib.Path(str(original_path).replace('original', 'superset'))
    print(f"Taking {original_path} and supersetting to {superset_path}")
    
    gdal.Warp(
        str(superset_path),
        str(original_path), 
        outputBounds=output_bounds,
        outputBoundsSRS=output_srs
    )
    

### 3. Flatten and Save VVs

Often the VVs have extraneous high and low values that make matching difficult. So we need to get rid of these and save the intermediate results.
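
To make the trimming concrete, here is a dependency-free sketch of the same idea. It swaps NumPy's `nanpercentile` for a hand-written linear-interpolation percentile, so the helper names `percentile` and `trim_tails` are illustrative only.

```python
import math


def percentile(sorted_values, q):
    """Linear-interpolation percentile (q in 0..100) of a pre-sorted list."""
    position = (len(sorted_values) - 1) * q / 100
    lower = math.floor(position)
    upper = math.ceil(position)
    fraction = position - lower
    return sorted_values[lower] + (sorted_values[upper] - sorted_values[lower]) * fraction


def trim_tails(values, low_q=1, high_q=99):
    """Replace values outside the [low_q, high_q] percentile range with NaN."""
    valid = sorted(v for v in values if not math.isnan(v))
    # Compute both cutoffs before replacing anything
    low = percentile(valid, low_q)
    high = percentile(valid, high_q)
    # Comparisons with NaN are False, so existing NaNs stay NaN
    return [v if low <= v <= high else math.nan for v in values]


values = [float(i) for i in range(101)]  # 0.0, 1.0, ..., 100.0
trimmed = trim_tails(values)
print(sum(math.isnan(v) for v in trimmed), "values set to NaN")
```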

In [None]:
# Remove any staged intermediate files to work in a clean area
!mkdir -p "{CWD}/work/flatten"
for filepath in pathlib.Path(f"{CWD}/work/flatten").glob("*.tif"):
    filepath.unlink()

In [None]:
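def flatten(df: pd.DataFrame) -> pd.DataFrame:
    """
    Set values below the 1st percentile or above the 99th percentile to NaN.
    Both cutoffs are computed before truncating so that removing the low tail
    does not shift the high cutoff. The input DataFrame is modified in place.
    """
    low, high = np.nanpercentile(df, [1, 99])
    df[df < low] = np.nan
    df[df > high] = np.nan
    return df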

# Flatten and save VVs
superset_vv_paths = pathlib.Path(f"{CWD}/work/superset").glob(f"*_VV.tif")

for superset_path in superset_vv_paths:
    print(f"Flattening {superset_path}")
    
    # Read the first band into a DataFrame and close the file afterwards
    with rasterio.open(superset_path) as raster:
        raster_metadata = raster.meta
        raster0 = raster.read(1)
    df_superset = pd.DataFrame(raster0)
    
    # Flatten raster data
    df_flatten = flatten(df_superset)
    
    flatten_path = pathlib.Path(str(superset_path).replace('superset', 'flatten'))
    
    # Write the flattened band back out with the original metadata
    with rasterio.open(flatten_path, 'w', **raster_metadata) as out:
        out.write(df_flatten.to_numpy(), 1)

### 4. Tile and Save VVs
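
Each flattened scene is cut into an `X_NUM` by `Y_NUM` grid of equally sized tiles. The pixel arithmetic can be sketched as follows; `tile_windows` is a hypothetical helper written for this illustration, and the scene size is made up.

```python
def tile_windows(height, width, y_num, x_num):
    """Pixel offsets and sizes for a y_num by x_num grid of equal tiles."""
    y_dim = height // y_num  # tile height in pixels
    x_dim = width // x_num   # tile width in pixels
    windows = []
    for y_iter in range(y_num):
        for x_iter in range(x_num):
            windows.append({
                "y_iter": y_iter,
                "x_iter": x_iter,
                "row_off": y_iter * y_dim,
                "col_off": x_iter * x_dim,
                "height": y_dim,
                "width": x_dim,
            })
    return windows


# A made-up 1003 x 805 pixel scene split 8 x 8
windows = tile_windows(height=1003, width=805, y_num=8, x_num=8)
print(len(windows), windows[-1])
```

Because the tile size uses floor division, any leftover rows and columns at the bottom and right edges (3 rows and 5 columns in this example) fall outside the grid.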

In [None]:
def split_into_cells_args(x_num: int, y_num: int) -> list:
    """
    return list of dict of args for `split_into_cells` dask function callback.
    """
    
    args = []
    for i, flatten_path in enumerate(flatten_vv_paths):
        args.append({
            'input_number': i, 
            'input_file': flatten_path, 
            'output_dir': f"{CWD}/work/tiles", 
            'x_num': x_num, 
            'y_num': y_num
        })
    
    return args 

# https://gis.stackexchange.com/a/306862
# Takes a raster file and splits it into an x_num by y_num grid of tiles
def split_into_cells(args):
    """
    input_number: A sequential number representing the ordering of the scenes. This is to make later scene pairing easier.
    input_file: Full file path of scene to be tiled.
    output_dir: Full path of directory to place tiles.
    x_num: Number of tiles formed in the x direction per scene.
    y_num: Number of tiles formed in the y direction per scene.
    """
    
    input_number: int = args['input_number']
    input_file: str = args['input_file']
    output_dir: str = args['output_dir']
    x_num: int = args.get('x_num', 1)
    y_num: int = args.get('y_num', 1)
    
    print(f"Tiling {input_file}")

    
    raster = rasterio.open(input_file)
    
    x_dim = raster.shape[1] // x_num
    y_dim = raster.shape[0] // y_num

    x, y = 0, 0
    for y_iter in range(y_num):
        y = y_iter * y_dim
        for x_iter in range(x_num):
            x = x_iter * x_dim
            
            input_filestem = pathlib.Path(input_file).stem
            
            output_file = f'{input_filestem}_{input_number}_{y_iter}_{x_iter}.tif'
            print(f"Creating tile {output_file}...")
            
            # Get tile geometry
            corner1 = raster.transform * (x, y)
            corner2 = raster.transform * (x + x_dim, y + y_dim)
            geom = geometry.box(corner1[0], corner1[1], corner2[0], corner2[1])
            
            # Get cell. `raster.meta` may return a new dict on each access, so
            # update a copy and pass that copy to the writer.
            crop, cropTransform = mask(raster, [geom], crop=True)
            tile_meta = raster.meta.copy()
            tile_meta.update(
                {
                    "driver": "GTiff",
                    "height": crop.shape[1],
                    "width": crop.shape[2],
                    "transform": cropTransform,
                    "crs": raster.crs
                }
            )
                        
            output_filepath = f"{output_dir}/{output_file}"
            with rasterio.open(output_filepath, "w", **tile_meta) as out:
                out.write(crop)

In [None]:
# Remove any staged intermediate files to work in a clean area
!mkdir -p "{CWD}/work/tiles"
for filepath in pathlib.Path(f"{CWD}/work/tiles").glob("*.tif"):
    filepath.unlink()

In [None]:
X_NUM = 8
Y_NUM = 8

flatten_vv_paths = pathlib.Path(f"{CWD}/work/flatten").glob(f"*_VV.tif")
start_time = datetime.now()
print(f"\nStart time is {start_time}")

client = setup_dask(ram_per_worker_gb=20, num_workers=100, num_threads_per_worker=1)
do_dask(client, split_into_cells, split_into_cells_args(x_num=X_NUM, y_num=Y_NUM))

teardown_dask(client)

end_time = datetime.now()
print(f"\nEnd time is {end_time}")
print(f"Time elapsed is {end_time - start_time}\n")  

### 5. Correlate Tiles and Save Results
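
Tiles are paired with the next scene's tile at the same grid position. The pairing can be sketched as below, using hypothetical `(scene_index, tile_y, tile_x)` tuples in place of tile files. Indices are compared as integers here, since sorting them as strings would place `'10'` before `'2'`.

```python
from itertools import groupby


def nearest_pairs(tiles):
    """Pair each tile with the same tile position in the next scene.

    `tiles` is an iterable of (scene_index, tile_y, tile_x) tuples.
    """
    # Sort by tile position first, then by scene index within each position
    ordered = sorted(tiles, key=lambda t: (t[1], t[2], t[0]))
    pairs = []
    for _, group in groupby(ordered, key=lambda t: (t[1], t[2])):
        members = list(group)
        # Consecutive members are neighbours in the stack ordering
        pairs.extend(zip(members, members[1:]))
    return pairs


# Four scenes (indices 0, 1, 2, 10), each split into a 2 x 2 grid
tiles = [(i, y, x) for i in (0, 1, 2, 10) for y in range(2) for x in range(2)]
pairs = nearest_pairs(tiles)
print(len(pairs), pairs[:3])

# String sorting is lexicographic, which breaks the ordering past index 9
string_order = sorted(["0", "1", "2", "10"])
print(string_order)
```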

In [None]:
def get_correlation_args() -> list:
    """
    return [
        {
            'reference_index': '',
            'secondary_index': '',
            'tile_number_x': '',
            'tile_number_y': '',
            'ref_file_path': '',
            'sec_file_path': ''
        },
    ]
    """
    
    tiles_paths = pathlib.Path(f"{CWD}/work/tiles").glob(f"*.tif")
    tiles = []
    
    # Get index and tile numbers from path
    for tiles_path in tiles_paths:

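        # Tile files are named `<stem>_<scene index>_<y tile>_<x tile>.tif`
        m = re.match(r".*_([0-9]+)_([0-9]+)_([0-9]+)\.tif", tiles_path.name)

        # Store the numbers as integers so that sorting is numeric (2 before 10)
        tiles.append({
            'index': int(m.group(1)),
            'tile_number_y': int(m.group(2)),
            'tile_number_x': int(m.group(3)),
            'file_path': tiles_path
        })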

    tiles_df = pd.DataFrame(tiles).sort_values(by=['tile_number_x', 'tile_number_y', 'index'])

    pairs = []
    for i in range(len(tiles_df.index) - 1):

        #if i > 10:
        #    continue
        
        ref_row = tiles_df.iloc[i]
        sec_row = tiles_df.iloc[i+1]

        # If the next row in the sorted dataframe has different tile numbers, then we are at a new set
        if ref_row['tile_number_x'] != sec_row['tile_number_x'] or ref_row['tile_number_y'] != sec_row['tile_number_y']:
            continue

        pairs.append({
            'reference_index': ref_row['index'],
            'secondary_index': sec_row['index'],
            'tile_number_x': ref_row['tile_number_x'],
            'tile_number_y': ref_row['tile_number_y'],
            'ref_file_path': ref_row['file_path'],
            'sec_file_path': sec_row['file_path']
        })

    return pairs

def correlation_callback(args: dict) -> None:
    """
    args = {
        'reference_index': '',
        'secondary_index': '',
        'tile_number_x': '',
        'tile_number_y': '',
        'ref_file_path': '',
        'sec_file_path': ''
    }
    """
    
    try:
        reference_index = args['reference_index']
        secondary_index = args['secondary_index']
        tile_number_x = args['tile_number_x']
        tile_number_y = args['tile_number_y']
        ref_file_path = args['ref_file_path']
        sec_file_path = args['sec_file_path']
        
        ###### Reference 
        stime = datetime.now()
        print(f"\nIndex {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Rendering {ref_file_path}...")
        rast = rasterio.open(ref_file_path)
        raster0 = rast.read(1)
        df_ref = pd.DataFrame(raster0)
        print(f"Index {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Time to complete ref: {datetime.now() - stime}")


        ###### Secondary
        stime = datetime.now()
        print(f"\nIndex {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Rendering {sec_file_path}...")
        rast = rasterio.open(sec_file_path)
        raster0 = rast.read(1)
        df_sec = pd.DataFrame(raster0)
        print(f"Index {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Time to complete sec: {datetime.now() - stime}")


        ###### If crop tile is more than 10% NANs, skip correlation and set return values to NaN 
        def get_percent_nans(df):
            number_of_elements = df.size
            number_of_nans = df.isnull().sum().sum()

            return number_of_nans / number_of_elements

        percent_nans_ref = get_percent_nans(df_ref)
        percent_nans_sec = get_percent_nans(df_sec)

        if percent_nans_ref > 0.10 or percent_nans_sec > 0.10:
            print(f"\nIndex {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Too many NaNs. Skipping correlation....")

            result = {
                "reference_index": int(reference_index),
                "secondary_index": int(secondary_index),
                "tile_number_x": int(tile_number_x),
                "tile_number_y": int(tile_number_y),
                "ref_file": str(ref_file_path),
                "sec_file": str(sec_file_path),
                "shift_x": np.nan,
                "shift_y": np.nan, 
                "error": np.nan, 
                "phase": np.nan,
                "message": "Too many NaNs"
            }

        else:
            ####### Cross corr without masking (only when both tiles have few enough NaNs)
            stime = datetime.now()
            print(f"\nIndex {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Finding phase correlation with nans set to zero....")
            shift, error, phase = phase_cross_correlation(
                df_ref.replace(np.nan, 0).to_numpy(), 
                df_sec.replace(np.nan, 0).to_numpy(),
                normalization=None,
                upsample_factor=10
            )

            shift = shift * METERS_PER_PIXEL
            error = error * METERS_PER_PIXEL

            print(f"Index {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Shift vector (in meters) required to register moving_image with reference_image: {shift}")
            print(f"Index {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Translation invariant normalized RMS error between reference_image and moving_image: {error}")
            print(f"Index {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}: Global phase difference between the two images (should be zero if images are non-negative).: {phase}\n")

            print(f"Index {reference_index} {secondary_index}, Tile {tile_number_x} {tile_number_y}:  Time to complete correlation: {datetime.now() - stime}")

            ####### Build the correlation result record
            # The shift is returned in array-axis order (row, column), i.e. (y, x)
            shift_ok = len(list(shift)) == 2

            result = {
                "reference_index": int(reference_index),
                "secondary_index": int(secondary_index),
                "tile_number_x": int(tile_number_x),
                "tile_number_y": int(tile_number_y),
                "ref_file": str(ref_file_path),
                "sec_file": str(sec_file_path),
                "shift_x": np.float64(shift[1]) if shift_ok else np.nan,
                "shift_y": np.float64(shift[0]) if shift_ok else np.nan,
                "error": np.float64(error) if shift_ok else np.nan, 
                "phase": np.float64(phase) if shift_ok else np.nan,
                "message": "Correlation successful" if shift_ok else "Shift is not a two element array"
            }

    except Exception as e:
        print(f"An error occurred: {e}")
        result = {
            "reference_index": int(reference_index),
            "secondary_index": int(secondary_index),
            "tile_number_x": int(tile_number_x),
            "tile_number_y": int(tile_number_y),
            "ref_file": str(ref_file_path),
            "sec_file": str(sec_file_path),
            "shift_x": np.nan, 
            "shift_y": np.nan,
            "error": np.nan, 
            "phase": np.nan,
            "message": f"Error: {e}"
        }
        
    try:
        result_file = pathlib.Path(f"{CWD}/work/correlation/index_{reference_index}_{secondary_index}-tile_{tile_number_x}_{tile_number_y}.json")
        with open(result_file, 'w') as f:
            json.dump(result, f)
    except Exception as e:
        print(f"An error occurred: {e}")

In [None]:
# Remove any staged intermediate files to work in a clean area
!mkdir -p "{CWD}/work/correlation"
for filepath in pathlib.Path(f"{CWD}/work/correlation").glob("*.json"):
    filepath.unlink()

In [None]:
start_time = datetime.now()
print(f"\nStart time is {start_time}")

# ram_per_worker_gb:int=20, num_workers:int=20, num_threads_per_worker:int=1
client = setup_dask(ram_per_worker_gb=11, num_workers=10)
do_dask(client, correlation_callback, get_correlation_args())

teardown_dask(client)

end_time = datetime.now()
print(f"\nEnd time is {end_time}")
print(f"Time elapsed is {end_time - start_time}\n")  

### 6. Do Analysis on Tiles

Read the correlation result files from the previous section into a Pandas DataFrame.

Then perform various statistical analyses.

In [None]:
# Read the JSON results into a Pandas DataFrame (one row per tile pair)
correlation_paths = pathlib.Path(f"{CWD}/work/correlation").glob(f"*.json")

results = []

for corr_path in correlation_paths:
    with open(corr_path, 'r') as f:
        results.append(json.load(f))

results_df = pd.DataFrame(results)

In [None]:
display(results_df)

Display the stats for the shift, error, and phase of the cross-correlation between two scenes in the stack.

The `reference_index` is the order number of the reference scene within the stack. The stack scenes are ordered from newest to oldest.
The `secondary_index` is the order number of the secondary scene within the stack.

The `mean` value is a simple mean of all tile values. Similarly for `std`, etc.

#### A. Combine all tiles per scene pair correlation results

The cross-correlation results of all tiles in Scene 1 and Scene 2 are re-assembled together into one result and statistically analyzed. Repeat for all pairs. 
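
The per-pair error is combined as a root mean square that ignores NaN tiles. A minimal pure-Python sketch of that aggregation, using made-up records in place of the JSON results:

```python
import math
from collections import defaultdict


def rms(values):
    """Root mean square of the non-NaN values; NaN if none are valid."""
    valid = [v for v in values if not math.isnan(v)]
    if not valid:
        return math.nan
    return math.sqrt(sum(v * v for v in valid) / len(valid))


# Made-up tile results: three tiles for pair 0, one skipped tile for pair 1
records = [
    {"reference_index": 0, "error": 1.0},
    {"reference_index": 0, "error": 7.0},
    {"reference_index": 0, "error": math.nan},
    {"reference_index": 1, "error": math.nan},
]

# Group the tile errors by scene pair, then reduce each group to one number
errors_by_pair = defaultdict(list)
for record in records:
    errors_by_pair[record["reference_index"]].append(record["error"])

summary = {pair: rms(errors) for pair, errors in errors_by_pair.items()}
print(summary)
```

A pair whose tiles were all skipped keeps a NaN error instead of being reported as zero.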

In [None]:
paired_df = results_df.groupby(by=['reference_index'])

## Uncomment to display statistical descriptions of DataFrames
#display(paired_df['shift_x'].describe())
#display(paired_df['shift_y'].describe())
#display(paired_df['error'].describe())

display(paired_df)

# Take the root mean square (RMS) of the individual tile errors to get the overall error.
def rms(series):
    if np.isnan(series).all():
        return np.nan  
    return np.sqrt(np.nanmean(np.square(series)))
    #return np.nanmean(np.abs(series))

pdf = pd.DataFrame()
pdf['tile_mean_x'] = paired_df['shift_x'].mean()
pdf['tile_mean_y'] = paired_df['shift_y'].mean()
pdf['error'] = paired_df['error'].agg(rms)
display(pdf)

plt.grid(color='grey', alpha=0.4)
plt.errorbar(pdf['tile_mean_x'], pdf['tile_mean_y'], yerr=pdf['error'], xerr=pdf['error'], ecolor='grey', alpha=0.6, ls='none')
plt.scatter(pdf['tile_mean_x'], pdf['tile_mean_y'], color='black')
plt.xlabel("X Offset (Meters)")
plt.ylabel("Y Offset (Meters)")
plt.title("Cross-correlation Offset Per Stack Pair w/ RMSE")
plt.show()

#### B. Combine all tile correlation results temporally

The cross-correlation result of each individual tile for Scene 1 and Scene 2 is combined temporally with the results for the corresponding tile in later pairs. This creates a time series for each tile, which is then statistically analyzed. Repeat for all tiles.
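
One way to picture the per-tile time series is to bucket results by tile position and order each bucket by scene pair. The records below are made up for illustration.

```python
from collections import defaultdict

# Made-up results, deliberately out of order
records = [
    {"reference_index": 1, "tile_number_x": 0, "tile_number_y": 0, "shift_x": 6.0},
    {"reference_index": 0, "tile_number_x": 0, "tile_number_y": 0, "shift_x": 3.0},
    {"reference_index": 0, "tile_number_x": 1, "tile_number_y": 0, "shift_x": -3.0},
]

# Key each series by tile position; sort first so every series is in pair order
series = defaultdict(list)
for record in sorted(records, key=lambda r: r["reference_index"]):
    key = (record["tile_number_x"], record["tile_number_y"])
    series[key].append(record["shift_x"])

for tile, shifts in series.items():
    print(tile, shifts, sum(shifts) / len(shifts))
```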

In [None]:
temporal_df = results_df.groupby(by=['tile_number_x', 'tile_number_y'])

## Uncomment to display statistical descriptions of DataFrames
#display(temporal_df['shift_x'].describe())
#display(temporal_df['shift_y'].describe())
#display(temporal_df['error'].describe())

def rms(series):
    if np.isnan(series).all():
        return np.nan  
    return np.sqrt(np.nanmean(np.square(series)))
    #return np.nanmean(np.abs(series))

tdf = pd.DataFrame()
tdf['tile_mean_x'] = temporal_df['shift_x'].mean()
tdf['tile_mean_y'] = temporal_df['shift_y'].mean()
tdf['error'] = temporal_df['error'].agg(rms)
display(tdf)

plt.grid(color='grey', alpha=0.4)
plt.errorbar(tdf['tile_mean_x'], tdf['tile_mean_y'], yerr=tdf['error'], xerr=tdf['error'], ecolor='grey', alpha=0.6, ls='none')
plt.scatter(tdf['tile_mean_x'], tdf['tile_mean_y'], color='black')
plt.xlabel("X Offset (Meters)")
plt.ylabel("Y Offset (Meters)")
plt.title("Cross-correlation Offset Per Tile For Whole Stack w/ RMSE")
plt.show()