# Cross Correlation Check

### Import Python libraries

In [None]:
import time
import pathlib
import glob
from datetime import datetime
from pprint import pprint

import dask.distributed
import rasterio
from rasterio.windows import from_bounds
from shapely import geometry
from rasterio.mask import mask
from osgeo import gdal
import shapely
from shapely import wkt
from ipyfilechooser import FileChooser
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation
from PIL import Image as PILImage

PILImage.MAX_IMAGE_PIXELS = None

import opensarlab_lib as asfn

### Define some later-used methods and variables

In [None]:
def save_tiff_as_png(save_path_and_name: str):
    """
    Save an RGB PNG copy of the image file at ``save_path_and_name``.

    The PNG is written next to the source file with ".png" appended to the
    complete original name (e.g. "tile.tif" becomes "tile.tif.png").

    Args:
        save_path_and_name: Path of the image (tiff) file to convert.
    """
    # Pillow's Image module is imported in this file under the alias
    # ``PILImage``; the bare name ``Image`` is not defined here.
    with PILImage.open(save_path_and_name) as img:
        img.convert('RGB').save(save_path_and_name + ".png")

def flatten_df(df):
    """
    Zero out the extreme values of ``df`` in place and return it.

    Cells strictly below the 1st percentile or strictly above the 99th
    percentile are set to 0. NaNs are ignored when the percentiles are
    computed and are left untouched (they fail both comparisons).

    Args:
        df: Numeric pandas DataFrame. It is modified in place.

    Returns:
        The same DataFrame object that was passed in.
    """
    # Compute both thresholds from the untouched data. Zeroing the low tail
    # first would shift the 99th percentile whenever the data contain
    # negative values, because the new zeros would then sit in the top tail.
    lower = np.nanpercentile(df, 1)
    upper = np.nanpercentile(df, 99)
    df[(df < lower) | (df > upper)] = 0
    return df
    
def convert_rast_to_df(rasterio_obj, window=None):
    """
    Read band 1 of ``rasterio_obj`` and wrap the result in a DataFrame.

    Args:
        rasterio_obj: Object exposing ``read(band, window=...)``, e.g. an
            opened rasterio dataset.
        window: Optional window forwarded to ``read``. When it is falsy the
            band is read without a ``window`` argument.

    Returns:
        pandas DataFrame built from the array returned by ``read``.
    """
    # Only forward ``window`` when one was actually supplied
    read_kwargs = {"window": window} if window else {}
    band = rasterio_obj.read(1, **read_kwargs)
    return pd.DataFrame(band)
    
# Absolute path of the current working directory; every data/ and work/ path
# used below is built from it.
CWD = pathlib.Path().absolute()
CWD

# Number of tiles each scene is split into along x (width) and y (height).
# Note: 8x8 tiles will use about 120 GB of RAM
# ~2GB+ per tile to process
X_NUM = 8
Y_NUM = 8

### Choose tiff files to compare.

In [None]:
# Choose the reference tiff (the selection is read back from fc1 below)
fc1 = FileChooser(f'{CWD}/data/')
display(fc1)

In [None]:
# Choose the secondary tiff (the selection is read back from fc2 below)
fc2 = FileChooser(f'{CWD}/data/')
display(fc2)

In [None]:
# Directory and file name chosen in the first file chooser (reference scene)
reference_path = fc1.selected_path
reference_file = fc1.selected_filename
print(reference_path, reference_file)

# Directory and file name chosen in the second file chooser (secondary scene)
secondary_path = fc2.selected_path
secondary_file = fc2.selected_filename
print(secondary_path, secondary_file)

### <span style="color:green">The following cells can be run automatically via the "Run Selected Cell and All Below" menu option.</span> 

In [None]:
!mkdir -p {CWD}/work/
!cp {reference_path}/{reference_file} {CWD}/work/reference.tif
!cp {secondary_path}/{secondary_file} {CWD}/work/secondary.tif

### Convert tiffs to rasterio objects

In [None]:
# Open both staged scenes. The datasets are deliberately left open: later
# cells read their meta, bounds and crs attributes.
# NOTE(review): asfn.work_dir presumably switches into the given directory
# for the duration of the block - confirm against opensarlab_lib.
with asfn.work_dir(f"{CWD}/work/"):
    reference = rasterio.open('reference.tif')    
    secondary = rasterio.open('secondary.tif')
    
    print(reference.meta)
    print(secondary.meta)

### Plot original reference and secondary scenes

In [None]:
# Read band 1 of each scene into a DataFrame
df_reference = convert_rast_to_df(reference)
df_secondary = convert_rast_to_df(secondary)

# np.histogram returns (counts per bin, bin edges)
ref_counts, ref_bins = np.histogram(df_reference)
print(ref_counts)

sec_counts, sec_bins = np.histogram(df_secondary)
print(sec_counts)

# Top row: the scenes; bottom row: their value histograms (log-scaled counts)
fig = plt.figure(figsize=(16, 8))
ax1 = fig.add_subplot(221, title="reference")
ax2 = fig.add_subplot(222, title="secondary")
ax3 = fig.add_subplot(223, title="reference hist")
ax4 = fig.add_subplot(224, title="secondary hist")

ax1.imshow(df_reference)
ax2.imshow(df_secondary)
# The histograms are already computed, so draw them by passing one sample per
# bin (its left edge) weighted by that bin's count. Passing the counts as the
# data would histogram the count values themselves.
ax3.hist(ref_bins[:-1], ref_bins, weights=ref_counts, log=True)
ax4.hist(sec_bins[:-1], sec_bins, weights=sec_counts, log=True)

### Flatten Reference and Secondary Scenes


In [None]:
# Clip the extreme values of each scene and write the result as a new GeoTIFF
# that reuses the source scene's metadata.
# flatten_df modifies its argument in place, so df_reference_flatten and
# df_reference (likewise for secondary) are the same object afterwards.
with asfn.work_dir(f"{CWD}/work/"):
    
    df_reference_flatten = flatten_df(df_reference)
    filepath = "flat_reference.tif"   
    with rasterio.open(filepath, 'w', **reference.meta) as out:
        # Write into band 1
        out.write(df_reference_flatten, 1)
        
    df_secondary_flatten = flatten_df(df_secondary)
    filepath = "flat_secondary.tif"
    with rasterio.open(filepath, "w", **secondary.meta) as out:
        # Write into band 1
        out.write(df_secondary_flatten, 1)

### Plot Flattened Reference and Secondary Scenes

In [None]:
# np.histogram returns (counts per bin, bin edges)
ref_counts, ref_bins = np.histogram(df_reference_flatten)
print(ref_counts)

sec_counts, sec_bins = np.histogram(df_secondary_flatten)
print(sec_counts)

# Top row: the flattened scenes; bottom row: their value histograms
fig = plt.figure(figsize=(16, 8))
ax1 = fig.add_subplot(221, title="flatten reference")
ax2 = fig.add_subplot(222, title="flatten secondary")
ax3 = fig.add_subplot(223, title="flatten reference hist")
ax4 = fig.add_subplot(224, title="flatten secondary hist")

ax1.imshow(df_reference_flatten)
ax2.imshow(df_secondary_flatten)
# The histograms are already computed, so draw them by passing one sample per
# bin (its left edge) weighted by that bin's count. Passing the counts as the
# data would histogram the count values themselves.
ax3.hist(ref_bins[:-1], ref_bins, weights=ref_counts, log=True)
ax4.hist(sec_bins[:-1], sec_bins, weights=sec_counts, log=True)

### Find smallest common superset area and transform the scenes

In [None]:
ref_bound = reference.bounds
sec_bound = secondary.bounds

# Smallest axis-aligned box that contains both scenes
superset = dict(
    left=min(ref_bound.left, sec_bound.left),
    bottom=min(ref_bound.bottom, sec_bound.bottom),
    right=max(ref_bound.right, sec_bound.right),
    top=max(ref_bound.top, sec_bound.top),
)

print(ref_bound)
print(sec_bound)
print(superset)

# Output extent handed to gdal.Warp, in the order left, bottom, right, top
superset_bounds = (
    superset['left'],
    superset['bottom'],
    superset['right'],
    superset['top'],
)

# (output file, input file, SRS in which the bounds are expressed)
warp_jobs = (
    ('flat_reference_superset.tif', 'flat_reference.tif', reference.crs),
    ('flat_secondary_superset.tif', 'flat_secondary.tif', secondary.crs),
)

with asfn.work_dir(f"{CWD}/work/"):
    # Warp each flattened scene onto the common superset extent
    for warp_dst, warp_src, bounds_srs in warp_jobs:
        gdal.Warp(
            warp_dst,
            warp_src,
            outputBounds=superset_bounds,
            outputBoundsSRS=bounds_srs,
        )

### Show superset-ed scenes

In [None]:
# Open the warped (superset-extent) scenes; they stay open because later cells
# read from them and tile them.
with asfn.work_dir(f"{CWD}/work/"):
    reference_superset = rasterio.open('flat_reference_superset.tif')    
    secondary_superset = rasterio.open('flat_secondary_superset.tif')

In [None]:
# Read band 1 of each superset scene and clip its extreme values again
df_reference_superset = convert_rast_to_df(reference_superset)
df_reference_superset = flatten_df(df_reference_superset)

df_secondary_superset = convert_rast_to_df(secondary_superset)
df_secondary_superset = flatten_df(df_secondary_superset)


# np.histogram returns (counts per bin, bin edges)
ref_counts, ref_bins = np.histogram(df_reference_superset)
print(ref_counts)

sec_counts, sec_bins = np.histogram(df_secondary_superset)
print(sec_counts)

# Top row: the superset scenes; bottom row: their value histograms
fig = plt.figure(figsize=(16, 8))
ax1 = fig.add_subplot(221, title="reference")
ax2 = fig.add_subplot(222, title="secondary")
ax3 = fig.add_subplot(223, title="reference hist")
ax4 = fig.add_subplot(224, title="secondary hist")

ax1.imshow(df_reference_superset)
ax2.imshow(df_secondary_superset)
# The histograms are already computed, so draw them by passing one sample per
# bin (its left edge) weighted by that bin's count. Passing the counts as the
# data would histogram the count values themselves.
ax3.hist(ref_bins[:-1], ref_bins, weights=ref_counts, log=True)
ax4.hist(sec_bins[:-1], sec_bins, weights=sec_counts, log=True)

### Perform cross-correlation on the tiffs (replacing NaNs with zeros)

In [None]:
# Find cross correlation
# https://scikit-image.org/docs/stable/api/skimage.registration.html#skimage.registration.phase_cross_correlation

print("Phase Cross Correlation....")

# NaNs are replaced with zeros before correlating, exactly as the per-tile
# calculation does; a NaN anywhere in the input would otherwise propagate
# through the FFT and turn the whole result into NaN.
shift, error, phase = phase_cross_correlation(
    df_reference_superset.replace(np.nan, 0),
    df_secondary_superset.replace(np.nan, 0),
    normalization=None
)

print(f"Shift vector (in pixels) required to register moving_image with reference_image: {shift}")
print(f"Translation invariant normalized RMS error between reference_image and moving_image: {error}")
print(f"Global phase difference between the two images (should be zero if images are non-negative).: {phase}\n")

### Tile scenes into equal rectangles/squares

Tiles can also be saved as PNGs (that step is currently commented out in the code below).

In [None]:
# Adapted from https://gis.stackexchange.com/a/306862
def splitImageIntoCells(img, filename, x_num=1, y_num=1):
    """
    Split a raster into a grid of ``y_num`` rows by ``x_num`` columns of tiles.

    Every tile spans ``img.shape[1] // x_num`` pixels in x and
    ``img.shape[0] // y_num`` pixels in y. When the raster size is not evenly
    divisible, the leftover pixels at the far edges are not covered by any
    tile. Each tile is written as a GeoTIFF to
    ``{filename}_{row}_{col}.tif``, and a progress line is printed per tile.

    Args:
        img: Opened rasterio dataset to split.
        filename: Path prefix for the tile files (no extension).
        x_num: Number of tiles along the x (width) direction.
        y_num: Number of tiles along the y (height) direction.
    """
    x_dim = img.shape[1] // x_num
    y_dim = img.shape[0] // y_num

    for y_iter in range(y_num):
        y = y_iter * y_dim
        for x_iter in range(x_num):
            x = x_iter * x_dim

            filepath = f'{filename}_{y_iter}_{x_iter}.tif'
            print(f"Creating tile {filepath}...")

            # Map the tile's pixel corners to coordinates in the raster's CRS
            corner1 = img.transform * (x, y)
            corner2 = img.transform * (x + x_dim, y + y_dim)
            geom = geometry.box(corner1[0], corner1[1], corner2[0], corner2[1])

            # Cut the tile out of the raster
            crop, cropTransform = mask(img, [geom], crop=True)

            # ``img.meta`` builds a fresh dict on every access, so calling
            # ``img.meta.update(...)`` is silently lost and the tile would be
            # created with the full-scene height, width and transform. Keep
            # an explicit copy and update that instead.
            tile_meta = img.meta.copy()
            tile_meta.update(
                {
                    "driver": "GTiff",
                    "height": crop.shape[1],
                    "width": crop.shape[2],
                    "transform": cropTransform,
                    "crs": img.crs
                }
            )

            with rasterio.open(filepath, "w", **tile_meta) as out:
                out.write(crop)

            # PNG export of each tile is currently disabled:
            #rbg = PILImage.open(filepath).convert('RGB')
            #rbg.save(filepath + ".png")

In [None]:
# Tile the reference superset scene into X_NUM x Y_NUM files and report how
# long it took.
start_time = datetime.now()
print(f"Start time is {start_time}\n")

# Tiles are written with relative names, so run inside the tile directory
!mkdir -p {CWD}/work/reference_tiles/
with asfn.work_dir(f"{CWD}/work/reference_tiles"):
    splitImageIntoCells(reference_superset, 'flat_reference', x_num=X_NUM, y_num=Y_NUM)
    
end_time = datetime.now()
print(f"\nEnd time is {end_time}")
print(f"Time elapsed is {end_time - start_time}\n")

In [None]:
# Tile the secondary superset scene into X_NUM x Y_NUM files and report how
# long it took.
start_time = datetime.now()
print(f"Start time is {start_time}\n")

# Tiles are written with relative names, so run inside the tile directory
!mkdir -p {CWD}/work/secondary_tiles/
with asfn.work_dir(f"{CWD}/work/secondary_tiles"):
    splitImageIntoCells(secondary_superset, 'flat_secondary', x_num=X_NUM, y_num=Y_NUM)
    
end_time = datetime.now()
print(f"\nEnd time is {end_time}")
print(f"Time elapsed is {end_time - start_time}\n")

### Using Dask parallelization, calculate the cross correlation of the tiles

In [None]:
# Sizing of the local Dask cluster: NUM_WORKERS workers, each capped at
# RAM_PER_WORKER_GB of memory and running NUM_THREADS_PER_WORKER thread(s).
RAM_PER_WORKER_GB = 20
NUM_WORKERS = 6
NUM_THREADS_PER_WORKER = 1

# processes=True asks for separate worker processes rather than threads
cluster = dask.distributed.LocalCluster(
    threads_per_worker=NUM_THREADS_PER_WORKER,
    n_workers=NUM_WORKERS,
    memory_limit=f"{RAM_PER_WORKER_GB}GB",
    processes=True
)

# Client used by do_dask() to submit the per-tile work
client = dask.distributed.Client(cluster)
display(client)

In [None]:
def calc_data(args):
    """
    Compute the phase cross-correlation for one reference/secondary tile pair.

    Band 1 of both tiles is loaded, NaNs are replaced with zeros, and
    ``phase_cross_correlation`` is run without normalization. Progress and
    timing information is printed along the way.

    Args:
        args: Dict with the keys ``'count'`` (tile number used in the log
            lines), ``'ref_file_path'`` and ``'sec_file_path'``.

    Returns:
        Dict with the keys ``count``, ``ref_file``, ``sec_file``, ``shift``,
        ``error`` and ``phase``.
    """
    count = args['count']
    ref_file_path = args['ref_file_path']
    sec_file_path = args['sec_file_path']

    ###### Reference
    stime = datetime.now()
    print(f"\nTile {count}: Rendering {ref_file_path}...")
    # Context manager so the dataset handle is closed once the band is read
    with rasterio.open(ref_file_path) as rast:
        df_ref = convert_rast_to_df(rast)
    print(f"Tile {count}: Time to complete ref: {datetime.now() - stime}")

    ###### Secondary
    stime = datetime.now()
    print(f"\nTile {count}: Rendering {sec_file_path}...")
    with rasterio.open(sec_file_path) as rast:
        df_sec = convert_rast_to_df(rast)
    print(f"Tile {count}: Time to complete sec: {datetime.now() - stime}")

    ####### Cross corr without masking
    stime = datetime.now()
    print(f"\nTile {count}: Finding phase correlation with nans set to zero....")
    shift, error, phase = phase_cross_correlation(
        df_ref.replace(np.nan, 0),
        df_sec.replace(np.nan, 0),
        normalization=None
    )
    print(f"Tile {count}: Shift vector (in pixels) required to register moving_image with reference_image: {shift}")
    print(f"Tile {count}: Translation invariant normalized RMS error between reference_image and moving_image: {error}")
    print(f"Tile {count}: Global phase difference between the two images (should be zero if images are non-negative).: {phase}\n")

    print(f"Tile {count}:  Time to complete correlation: {datetime.now() - stime}")

    return {
        "count": count,
        "ref_file": ref_file_path,
        "sec_file": sec_file_path,
        "shift": shift,
        "error": error,
        "phase": phase
    }

def get_cross_corr_args() -> list:
    """
    Build the argument dicts for every reference/secondary tile pair.

    Tile files are named ``<prefix>_{row}_{col}.tif`` where ``row`` counts the
    Y_NUM tiles along y and ``col`` counts the X_NUM tiles along x, so the
    first index must range over Y_NUM and the second over X_NUM.

    Returns:
        List of dicts with the keys ``'count'`` (1-based running number, in
        row-major order), ``'ref_file_path'`` and ``'sec_file_path'``.
    """
    cross_corr_args = []

    count = 1
    for row in range(Y_NUM):
        for col in range(X_NUM):
            cross_corr_args.append({
                'count': count,
                'ref_file_path': f'{CWD}/work/reference_tiles/flat_reference_{row}_{col}.tif',
                'sec_file_path': f'{CWD}/work/secondary_tiles/flat_secondary_{row}_{col}.tif'
            })

            count += 1

    return cross_corr_args
    
def do_dask(cross_corr_args: list) -> list:
    """
    Run ``calc_data`` over every entry of ``cross_corr_args`` on the Dask client.

    Args:
        cross_corr_args: List of argument dicts, one per tile pair.

    Returns:
        The gathered ``calc_data`` results.
    """
    # One future per tile pair, submitted through the module-level client
    pending = client.map(calc_data, cross_corr_args)
    # Progress display for the submitted futures
    dask.distributed.progress(pending)

    results = client.gather(pending)
    return results

In [None]:
# Run the per-tile cross-correlation on the Dask cluster and time the whole run
start_time = datetime.now()
print(f"Global start time is {start_time}\n")

cross_corr_args = get_cross_corr_args()
cross_corr_results = do_dask(cross_corr_args)

# Results are gathered at this point, so the Dask client can be shut down
client.shutdown()

end_time = datetime.now()
print(f"\nGlobal end time is {end_time}")
print(f"\nGlobal time elapsed is {end_time - start_time}")

### Print results

In [None]:
# Pretty-print the per-tile result dicts returned by calc_data
pprint(cross_corr_results)

### Plot Reference tiles, Secondary tiles and Correlation tiles

In [None]:
# One figure each for the reference tiles, the secondary tiles and the
# textual correlation results
ref_plt = plt.figure(figsize=(10,10))
sec_plt = plt.figure(figsize=(10,10))
corr_plt = plt.figure(figsize=(10,10))

for cross_corr_result in cross_corr_results:

    count = cross_corr_result['count']
    ref_file_path = cross_corr_result['ref_file']
    sec_file_path = cross_corr_result['sec_file']
    shift = cross_corr_result['shift']
    error = cross_corr_result['error']
    phase = cross_corr_result['phase']

    # add_subplot takes (nrows, ncols, index): the tile grid has Y_NUM rows
    # and X_NUM columns, and ``count`` is the 1-based position in that grid.

    # Plot references (the dataset is closed as soon as its band is read)
    with rasterio.open(ref_file_path) as rast:
        df_ref = convert_rast_to_df(rast)

    ax_ref = ref_plt.add_subplot(Y_NUM, X_NUM, count, xticks=[], yticks=[])
    ax_ref.spines[:].set_color('blue')
    ax_ref.imshow(df_ref)

    # Plot secondary
    with rasterio.open(sec_file_path) as rast:
        df_sec = convert_rast_to_df(rast)

    ax_sec = sec_plt.add_subplot(Y_NUM, X_NUM, count, xticks=[], yticks=[])
    ax_sec.spines[:].set_color('orange')
    ax_sec.imshow(df_sec)

    # Plot result texts: shift, error and phase stacked top to bottom
    ax_corr = corr_plt.add_subplot(Y_NUM, X_NUM, count, xticks=[], yticks=[])
    ax_corr.spines[:].set_color('green')
    ax_corr.text(0.1, 0.8, str(shift), transform = ax_corr.transAxes, fontsize='small')
    ax_corr.text(0.1, 0.4, str(error), transform = ax_corr.transAxes, fontsize='x-small')
    ax_corr.text(0.1, 0.1, str(phase), transform = ax_corr.transAxes, fontsize='xx-small')

### Graph some of the results

In [None]:
# One row per tile pair, one column per key of the calc_data result dicts
df_cross_corr_results = pd.DataFrame(cross_corr_results)

# Split out shift into x and y.
# phase_cross_correlation reports the shift in the axis order of its input
# arrays, which is (row, column) for these 2-D rasters: index 0 is the
# vertical (y) shift and index 1 is the horizontal (x) shift.
df_cross_corr_results['xshift'] = df_cross_corr_results['shift'].apply(lambda s: s[1])
df_cross_corr_results['yshift'] = df_cross_corr_results['shift'].apply(lambda s: s[0])

# Magnitude (Euclidean length) of the shift vector; the column keeps its
# existing name because the plots refer to it.
df_cross_corr_results['shift_perimeter'] = np.hypot(
    df_cross_corr_results['xshift'],
    df_cross_corr_results['yshift'],
)

df_cross_corr_results

In [None]:
# Correlation error against shift magnitude, one point per tile pair
df_cross_corr_results.plot.scatter(x = 'shift_perimeter', y = 'error')

In [None]:
# The two shift components against each other, one point per tile pair
df_cross_corr_results.plot.scatter(x = 'xshift', y = 'yshift')