# Cross Correlation Check

### Import Python libraries

In [None]:
import pathlib
import glob 

import rasterio
from rasterio.windows import from_bounds
from osgeo import gdal
import shapely
from shapely import wkt
from ipyfilechooser import FileChooser
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.registration import phase_cross_correlation
from PIL import Image as PILImage

import opensarlab_lib as asfn

### Define helper functions and variables used later

In [None]:
def save_tiff_as_png(save_path_and_name: str):
    """
    Save an RGB PNG copy of the image file at ``save_path_and_name``.

    The PNG is written next to the input as ``<save_path_and_name>.png``
    (e.g. ``tile.tif`` -> ``tile.tif.png``). A ``pathlib.Path`` is accepted
    as well as a ``str``.
    """
    # The import cell brings in Pillow's Image as `PILImage`; the bare name
    # `Image` used previously is not imported there and raised NameError.
    # The `with` block closes the source file once the PNG is written.
    with PILImage.open(save_path_and_name) as img:
        img.convert('RGB').save(f"{save_path_and_name}.png")
    
def convert_rast_to_df(rasterio_obj, window=None):
    """
    Read band 1 of a raster into a DataFrame, masking the extreme tails.

    Pixels at or below the 1st percentile, or at or above the 99th percentile,
    become NaN. Pixels that are already NaN are ignored when computing the
    percentiles and stay NaN in the result.

    Args:
        rasterio_obj: an open rasterio dataset (anything with a
            ``read(band, window=...)`` method).
        window: optional rasterio window restricting the read; ``None`` reads
            the whole band.

    Returns:
        pandas.DataFrame of band-1 values with the tails replaced by NaN.
    """
    # rasterio treats window=None as "read the full band", so pass it through.
    band = rasterio_obj.read(1, window=window)
    df = pd.DataFrame(band)
    # Use nanpercentile: plain np.percentile returns NaN if any pixel is NaN,
    # which makes every comparison below False and blanks the whole image.
    low, high = np.nanpercentile(band, [1, 99])
    return df[(df > low) & (df < high)]
    
# Absolute path of the directory the notebook runs from; the data/ and work/
# paths used throughout the notebook are built relative to it.
CWD = pathlib.Path.cwd()
CWD

### Choose the reference and secondary tiff files to compare

In [None]:
# File chooser for the reference scene (the image the secondary is registered to).
# The title distinguishes it from the otherwise identical secondary chooser.
fc1 = FileChooser(f'{CWD}/data/', title='<b>Reference tiff</b>')
display(fc1)

In [None]:
# File chooser for the secondary scene (the image registered against the reference).
# The title distinguishes it from the otherwise identical reference chooser.
fc2 = FileChooser(f'{CWD}/data/', title='<b>Secondary tiff</b>')
display(fc2)

In [None]:
# Fail loudly if a chooser was left empty. Otherwise the copy cell below fails
# silently and a stale work/reference.tif or work/secondary.tif from an
# earlier run would be picked up by the following cells.
for chooser, role in ((fc1, 'reference'), (fc2, 'secondary')):
    if not chooser.selected_filename:
        raise ValueError(f"No {role} tiff selected; choose one with the file chooser above.")

reference_path = fc1.selected_path
reference_file = fc1.selected_filename
print(reference_path, reference_file)

secondary_path = fc2.selected_path
secondary_file = fc2.selected_filename
print(secondary_path, secondary_file)

In [None]:
import shutil

# Copy the chosen scenes into work/ under fixed names so later cells can refer
# to them as reference.tif / secondary.tif. This is done in Python rather than
# with `!mkdir`/`!cp` for two reasons: paths containing spaces work, and a
# failed copy raises instead of being ignored.
work_path = CWD / 'work'
work_path.mkdir(parents=True, exist_ok=True)
shutil.copy(pathlib.Path(reference_path) / reference_file, work_path / 'reference.tif')
shutil.copy(pathlib.Path(secondary_path) / secondary_file, work_path / 'secondary.tif')

### Open the tiffs as rasterio datasets

In [None]:
# Open the working copies. The datasets are intentionally left open because
# later cells read pixel data and bounds from them.
with asfn.work_dir(f"{CWD}/work/"):
    reference = rasterio.open('reference.tif')
    secondary = rasterio.open('secondary.tif')

    for dataset in (reference, secondary):
        print(dataset.meta)

### Plot original reference and secondary scenes

In [None]:
# Band 1 of each scene, with the extreme tails masked to NaN.
df_reference = convert_rast_to_df(reference)
df_secondary = convert_rast_to_df(secondary)

# Side-by-side preview of the scenes as supplied, before any warping.
fig = plt.figure(figsize=(16, 8))
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("reference")
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("secondary")

ax1.imshow(df_reference)
ax2.imshow(df_secondary)

### Find the smallest bounding box that contains both scenes and warp both scenes onto it

In [None]:
from rasterio.coords import BoundingBox
from rasterio.warp import transform_bounds

# Assumes both scenes are georeferenced (have a CRS).
ref_bound = reference.bounds
# Express the secondary footprint in the reference CRS before taking the union.
# Combining coordinates from two different CRSs would give a meaningless box.
if secondary.crs == reference.crs:
    sec_bound = secondary.bounds
else:
    sec_bound = BoundingBox(*transform_bounds(secondary.crs, reference.crs, *secondary.bounds))

# Smallest box, in the reference CRS, that contains both footprints.
superset = {
    'left': min(ref_bound.left, sec_bound.left),
    'bottom': min(ref_bound.bottom, sec_bound.bottom),
    'right': max(ref_bound.right, sec_bound.right),
    'top': max(ref_bound.top, sec_bound.top),
}

print(ref_bound)
print(sec_bound)
print(superset)

# Warp both scenes onto the same CRS, extent and pixel size so the outputs have
# identical shapes. phase_cross_correlation rejects arrays of different shapes.
ref_crs_wkt = reference.crs.to_wkt()
x_res, y_res = reference.res
output_bounds = (superset['left'], superset['bottom'], superset['right'], superset['top'])

with asfn.work_dir(f"{CWD}/work/"):
    for src_name, dst_name in (
        ('reference.tif', 'reference_superset.tif'),
        ('secondary.tif', 'secondary_superset.tif'),
    ):
        # The returned dataset is discarded right away, which closes it and
        # flushes it to disk.
        gdal.Warp(
            dst_name,
            src_name,
            outputBounds=output_bounds,
            outputBoundsSRS=ref_crs_wkt,
            dstSRS=ref_crs_wkt,
            xRes=x_res,
            yRes=y_res,
        )

In [None]:
# Re-open the scenes that were warped onto the common extent. From here on they
# replace the original datasets.
with asfn.work_dir(f"{CWD}/work/"):
    reference = rasterio.open('reference_superset.tif')
    secondary = rasterio.open('secondary_superset.tif')

# Band 1 of each warped scene, with the extreme tails masked to NaN.
df_reference = convert_rast_to_df(reference)
df_secondary = convert_rast_to_df(secondary)

# Side-by-side preview after warping.
fig = plt.figure(figsize=(16, 8))
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("reference")
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("secondary")

ax1.imshow(df_reference)
ax2.imshow(df_secondary)

### Show where the NaNs are in the tiffs

In [None]:
# Boolean frames that are True wherever a pixel is NaN: values removed by the
# percentile filter in convert_rast_to_df, or NaN in the source data.
nan_reference_mask = df_reference.isna()
nan_moving_mask = df_secondary.isna()

fig = plt.figure(figsize=(16, 8))
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title('NaNs - reference')
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title('NaNs - secondary')
ax1.imshow(nan_reference_mask)
ax2.imshow(nan_moving_mask)

### Perform cross-correlation on the tiffs

In [None]:
# Find cross correlation
# https://scikit-image.org/docs/stable/api/skimage.registration.html#skimage.registration.phase_cross_correlation
# No files are read or written here, so no working-directory change is needed.
reference_array = df_reference.to_numpy()
moving_array = df_secondary.to_numpy()

print("Replacing NaNs with zero....")
shift, error, phase = phase_cross_correlation(
    df_reference.fillna(0).to_numpy(),
    df_secondary.fillna(0).to_numpy(),
)
print(f"Shift vector (in pixels) required to register moving_image with reference_image: {shift}")
print(f"Translation invariant normalized RMS error between reference_image and moving_image: {error}")
print(f"Global phase difference between the two images (should be zero if images are non-negative).: {phase}\n")

print("Using NaN mask...")
# Masks are derived from the arrays being registered (True = valid pixel),
# so they can never be out of sync with an earlier cell's state.
masked_result = phase_cross_correlation(
    reference_array,
    moving_array,
    reference_mask=~np.isnan(reference_array),
    moving_mask=~np.isnan(moving_array),
)
# Older scikit-image versions return only the shift when masks are given;
# newer ones return a (shift, error, phasediff) tuple. Accept both.
shift = masked_result[0] if isinstance(masked_result, tuple) else masked_result
print(f"Shift vector (in pixels) required to register moving_image with reference_image: {shift}")
print("No error or phase given with masks.")

### Tile scenes

In [None]:
# https://gis.stackexchange.com/a/306862
from shapely import geometry
from rasterio.mask import mask

# Split a rasterio dataset into a y_num x x_num grid of equal-sized tiles.
def splitImageIntoCells(img, filename, x_num=1, y_num=1):
    """
    Write ``img`` out as ``y_num`` rows by ``x_num`` columns of tiles.

    Each tile is saved in the current directory as a GeoTIFF named
    ``{filename}_{row}_{col}.tif``, with its own transform, plus an RGB
    ``.png`` copy next to it. Tiles are ``width // x_num`` by
    ``height // y_num`` pixels. Leftover columns and rows at the right and
    bottom edges that do not fill a whole tile are not written.
    """
    x_dim = img.shape[1] // x_num
    y_dim = img.shape[0] // y_num

    for y_iter in range(y_num):
        y = y_iter * y_dim
        for x_iter in range(x_num):
            x = x_iter * x_dim

            # Read this cell directly via a pixel window (col_off, row_off, width, height).
            window = rasterio.windows.Window(x, y, x_dim, y_dim)
            crop = img.read(window=window)

            # img.meta builds a new dict on every access, so the previous
            # in-place img.meta.update(...) was a no-op. Tiles were then written
            # with the full image's size and transform. Build the profile explicitly.
            tile_meta = img.meta.copy()
            tile_meta.update(
                driver="GTiff",
                height=crop.shape[1],
                width=crop.shape[2],
                transform=img.window_transform(window),
                crs=img.crs,
            )

            filepath = f'{filename}_{y_iter}_{x_iter}.tif'
            with rasterio.open(filepath, "w", **tile_meta) as out:
                out.write(crop)

            # Close the tile once the PNG preview has been written.
            with PILImage.open(filepath) as tile_image:
                tile_image.convert('RGB').save(filepath + ".png")

In [None]:
# Create the tile directory with pathlib rather than `!mkdir`, so paths with
# spaces work and a failure raises instead of being ignored.
(CWD / 'work' / 'reference_tiles').mkdir(parents=True, exist_ok=True)
with asfn.work_dir(f"{CWD}/work/reference_tiles"):
    splitImageIntoCells(reference, 'reference', x_num=10, y_num=10)

In [None]:
# Create the tile directory with pathlib rather than `!mkdir`, so paths with
# spaces work and a failure raises instead of being ignored.
(CWD / 'work' / 'secondary_tiles').mkdir(parents=True, exist_ok=True)
with asfn.work_dir(f"{CWD}/work/secondary_tiles"):
    splitImageIntoCells(secondary, 'secondary', x_num=10, y_num=10)

### Plot tiles

In [None]:
import math
import re


def _tile_grid_position(path):
    """Sort key (row, col, path) for a '<name>_<row>_<col>.tif.png' tile path."""
    match = re.search(r'_(\d+)_(\d+)\.tif\.png$', path)
    if match is None:
        # Anything not named like a tile goes after all real tiles.
        return (math.inf, math.inf, path)
    return (int(match.group(1)), int(match.group(2)), path)


with asfn.work_dir(f"{CWD}/work/"):
    # glob order is arbitrary. Sort numerically by (row, col) so subplot i shows
    # tile i in row-major order, and both lists line up tile-for-tile.
    reference_tile_paths = sorted(glob.glob('./reference_tiles/*.png'), key=_tile_grid_position)
    secondary_tile_paths = sorted(glob.glob('./secondary_tiles/*.png'), key=_tile_grid_position)

    cnt = len(reference_tile_paths)
    print(cnt)
    
    plt.figure(figsize=(10,5))
    columns = 10
    # ceil avoids the extra empty row that int(cnt / columns) + 1 produced
    # when cnt is a multiple of columns.
    rows = math.ceil(cnt / columns)
    for i, reference_tile_path in enumerate(reference_tile_paths):
        rast = rasterio.open(reference_tile_path)
        rast_df = convert_rast_to_df(rast)
        plt.subplot(rows, columns, i + 1, xticks=[], yticks=[])
        plt.imshow(rast_df)