# SpaceNet Optical-SAR Fusion
#### Created by Jason Brown and Jake Shermeyer

### Imports

In [None]:
# Make the notebook interactive
%matplotlib notebook

# Required libraries
import numpy as np
from scipy import fftpack
from skimage import io 
from skimage import transform
import skimage.color as color
from matplotlib import pyplot as plt 
from scipy.ndimage.filters import uniform_filter
from scipy.ndimage.measurements import variance
import spectral as spy
import gdal
import os
import glob
from tqdm.notebook import tqdm

The input imagery should be co-located SpaceNet 6 PS-RGB and SAR-Intensity tiles, available from https://spacenet.ai/sn6-challenge/

### Contrast Stretch function

In [None]:
def stretch(bands, lower_percent=1, higher_percent=98):
    """Linearly contrast-stretch an array to the 8-bit range [0, 255].

    Values at or below the ``lower_percent`` percentile map to 0 and values at
    or above the ``higher_percent`` percentile map to 255.

    Arguments
    ---------
    bands : numpy array
        Image data of any shape. NaN pixels are ignored when computing the
        percentiles and are written as 0 in the output.
    lower_percent : float
        Lower clipping percentile. Defaults to 1.
    higher_percent : float
        Upper clipping percentile. Defaults to 98.

    Returns
    -------
    numpy array (uint8)
        The stretched image, same shape as ``bands``. When the percentile
        range is empty (constant or all-NaN input), all zeros are returned.
    """
    a = 0
    b = 255
    bands = np.asarray(bands)
    # nanpercentile so that NaN nodata pixels do not turn the limits into NaN
    c = np.nanpercentile(bands, lower_percent)
    d = np.nanpercentile(bands, higher_percent)
    if not d > c:
        # constant or all-NaN input: there is no range to stretch (and d - c would be 0)
        return np.zeros(bands.shape, dtype=np.uint8)
    t = a + (bands - c) * (b - a) / (d - c)
    t = np.clip(t, a, b)
    # NaN survives clip and has no defined uint8 cast; write it as the low value
    t = np.where(np.isnan(t), a, t)
    return t.astype(np.uint8)

### Lee Filter Function for Speckle Filtering

In [None]:
def lee_filter(img, size):
    """Reduce speckle noise in a 2-D image with a Lee filter.

    Each output pixel is the local mean plus a weighted share of the pixel's
    deviation from that mean. The weight is local_var / (local_var + var(img)).

    Arguments
    ---------
    img : numpy array
        2-D image to filter. It is converted to float64 first; otherwise
        ``uniform_filter`` keeps the input dtype and squaring integer pixels
        (e.g. uint8) would overflow.
    size : int
        Edge length of the square moving window used for local statistics.

    Returns
    -------
    numpy array (float64)
        The filtered image, same shape as ``img``.
    """
    img = np.asarray(img, dtype=np.float64)
    img_mean = uniform_filter(img, (size, size))
    img_sqr_mean = uniform_filter(img**2, (size, size))
    # local variance E[x^2] - E[x]^2; clamp tiny negatives caused by float round-off
    img_variance = np.maximum(img_sqr_mean - img_mean**2, 0.0)
    overall_variance = variance(img)
    denominator = img_variance + overall_variance
    # a perfectly flat image gives 0/0 here; use weight 0 (output = local mean) instead of NaN
    img_weights = np.divide(img_variance, denominator,
                            out=np.zeros_like(img_variance), where=denominator > 0)
    img_output = img_mean + img_weights * (img - img_mean)
    return img_output

### Principal Component Transform (PCT) function for polarimetric bands

In [None]:
def pct_image(img, eigs):
    """Project an image onto a chosen subset of its principal components.

    The principal components are computed with the Spectral Python package
    (https://github.com/spectralpython/spectral).

    Arguments
    ---------
    img : image array accepted by ``spectral.principal_components``
        The multi-band image to decompose.
    eigs :
        Forwarded unchanged as the ``eigs`` keyword of ``reduce`` to select
        which components are kept.

    Returns
    -------
    The result of applying the reduced transform to ``img``.
    """
    reduced_pcs = spy.principal_components(img).reduce(eigs=eigs)
    return reduced_pcs.transform(img)

### Build the Span image

In [None]:
def span_image(img):
    """Build a single-band polarimetric span image.

    Computes |b0|^2 + 2*|b1|^2 + |b3|^2 per pixel. The original code squared
    the two co-pol terms but not the cross-pol term; all three are now
    squared magnitudes, as in the standard span definition.

    NOTE(review): assumes bands are ordered HH, HV, VH, VV (band 2, VH, is not
    used, and HV is doubled on the assumption HV == VH) — confirm band order
    for the input product.

    Arguments
    ---------
    img : numpy array
        Array of shape (rows, cols, bands) with at least 4 bands; real or complex.

    Returns
    -------
    numpy array
        The span image, shape (rows, cols).
    """
    hh = img[:, :, 0]
    hv = img[:, :, 1]
    vv = img[:, :, 3]
    img_sp = np.abs(hh)**2 + 2 * np.abs(hv)**2 + np.abs(vv)**2
    return img_sp

### Fuse data sources

In [None]:
def fusion(rgb, sar, method='hsv'):
    """Fuse an optical RGB image with a single-band SAR image.

    Arguments
    ---------
    rgb : numpy array
        Optical image of shape (rows, cols, >=3); the first three bands are
        treated as red, green and blue.
    sar : numpy array
        Single-band SAR image of shape (rows, cols).
    method : str
        'hsv' converts RGB to HSV, replaces the value channel with ``sar`` and
        converts back. 'simple_mean' blends each optical band with ``sar``.
        Defaults to 'hsv'.

    Returns
    -------
    numpy array
        Fused image of shape (rows, cols, 3).

    Raises
    ------
    ValueError
        If ``method`` is not 'hsv' or 'simple_mean' (previously this silently
        returned None).
    """
    if method == 'simple_mean':
        # NOTE(review): 0.5 * (0.5*band + sar) weights each optical band by 0.25
        # and SAR by 0.5 — confirm this weighting is intended.
        blended = [0.5 * (0.5 * rgb[:, :, band] + sar) for band in range(3)]  # R, G, B
        return np.stack(blended, axis=2)
    elif method == 'hsv':
        hsv = color.rgb2hsv(rgb)
        hsv[:, :, 2] = sar  # SAR intensity replaces the value (brightness) channel
        return color.hsv2rgb(hsv)
    raise ValueError(f"Unknown fusion method {method!r}; use 'hsv' or 'simple_mean'")

### Main processing code block

In [None]:
def create_multiband_geotiff(array, out_name, proj, geo, nodata=0, out_format=gdal.GDT_Byte, verbose=False):
    """Convert an array to an output georegistered geotiff.
    Arguments
    ---------
    array : numpy array
        A numpy array with the shape [bands, rows, cols] or [rows, cols].
        The raster width is taken from the last axis and the height from the
        axis before it.
    out_name : str
        The output name and path for your image. Missing parent directories are created.
    proj : gdal projection
        A projection, can be extracted from an image opened with gdal with image.GetProjection().  Can be set to None if no georeferencing is required.
    geo : gdal geotransform
        A gdal geotransform which indicates the position of the image on the earth in projection units. Can be set to None if no georeferencing is required.
        Can be extracted from an image opened with gdal with image.GetGeoTransform()
    nodata : int, default - 0
        A value to set transparent for GIS systems. Can be set to None if the nodata value is not required.
    out_format : gdalconst
        https://gdal.org/python/osgeo.gdalconst-module.html
        Must be one of the variables listed in the docs above
    verbose : bool
        A verbose output, printing all inputs and outputs to the function.  Useful for debugging.
    Raises
    ------
    OSError
        If GDAL cannot create the output raster.
    """
    driver = gdal.GetDriverByName('GTiff')
    if array.ndim == 2:
        # a single band: add a leading band axis so the write loop treats it uniformly
        array = array[np.newaxis, ...]
    os.makedirs(os.path.dirname(os.path.abspath(out_name)), exist_ok=True)
    n_bands, n_rows, n_cols = array.shape
    # Create takes (xsize=columns, ysize=rows, band count, data type)
    dataset = driver.Create(out_name, n_cols, n_rows, n_bands, out_format)
    if dataset is None:
        # GDAL signals creation failure by returning None rather than raising
        raise OSError(f"GDAL could not create output raster: {out_name}")
    if verbose is True:
        print("Array Shape, should be [Channels, X, Y] or [X,Y]:", array.shape)
        print("Output Name:", out_name)
        print("Projection:", proj)
        print("GeoTransform:", geo)
        print("NoData Value:", nodata)
        print("Bit Depth:", out_format)
    try:
        if proj is not None:
            dataset.SetProjection(proj)
        if geo is not None:
            dataset.SetGeoTransform(geo)
        for i, band_data in enumerate(array, 1):  # GDAL band indices start at 1
            band = dataset.GetRasterBand(i)
            band.WriteArray(band_data)
            if nodata is not None:
                band.SetNoDataValue(nodata)
        dataset.FlushCache()
    finally:
        # dropping the last reference closes the dataset, even if a write failed
        dataset = None

def _read_sar_tile(path):
    """Read a SAR raster with GDAL; return (pixels as (rows, cols, bands), projection, geotransform)."""
    dataset = gdal.Open(path)
    if dataset is None:
        raise OSError(f"GDAL could not open SAR tile: {path}")
    proj = dataset.GetProjection()
    geo = dataset.GetGeoTransform()
    pixels = dataset.ReadAsArray()
    dataset = None  # release the GDAL file handle
    if pixels.ndim == 2:
        # single-band rasters are read as (rows, cols); give them a band axis
        pixels = pixels[np.newaxis, ...]
    # ReadAsArray gives (bands, rows, cols); the pipeline works on (rows, cols, bands)
    return np.transpose(pixels, (1, 2, 0)), proj, geo


def _sar_to_single_band(sarimg, span_or_pca):
    """Collapse a (rows, cols, bands) SAR array to one 2-D band via span or PCA."""
    if sarimg.shape[2] == 1:
        return sarimg[:, :, 0]
    if span_or_pca == 'span':
        return span_image(sarimg)
    if span_or_pca == 'pca':
        # NOTE(review): eigs=0 is forwarded exactly as before; confirm spectral's
        # reduce() accepts a bare int here (it may expect a list of indices).
        return pct_image(sarimg, 0)
    raise ValueError("Choose 'span' or 'pca' for span_or_pca")


def color_sar(sar_tiles_dir, rgb_tiles_dir, tiles_out_dir, search='.tif', method='hsv', span_or_pca='span', filtersize=3, nodata=0, add_4th=False):
    """Colorizes SAR data using co-located RGB imagery and outputs it to the tiles_out_dir.

    Each SAR tile whose name contains "SAR-Intensity" is paired with the RGB
    tile of the same name with "SAR-Intensity" replaced by "PS-RGB". The output
    file gets "SAR-Intensity-Colorized-HSV" in that position, regardless of
    ``method``. All directories are resolved relative to the current working
    directory. The process working directory is no longer changed, so relative
    paths stay consistent.

    Arguments
    ---------
    sar_tiles_dir : str
        Path to the directory that contains SAR data
    rgb_tiles_dir : str
        Path to the directory that contains RGB data
    tiles_out_dir : str
        Path to the directory where colorized SAR will be output
    search : str
        File extension of imagery, defaults to '.tif'
    method : str
        Colorization method, should be 'hsv' or 'simple_mean'.  Defaults to 'hsv.'
    span_or_pca : str
        Calculate a span image or use pca to convert multi-band SAR to single channel.  Should be 'span' or 'pca.'
        Defaults to 'span.'  Ignored for single-band SAR tiles.
    nodata : int
        A value to set transparent for GIS systems. Can be set to None if the nodata value is not required.
        Defaults to 0.
    filtersize : int
        The filter size for a lee filter to reduce image noise.
    add_4th : bool
        Add an artificial 4th channel (a copy of the first band) for working with networks that expect 4 channel inputs. Defaults to 'False'.
    Raises
    ------
    ValueError
        If a SAR tile name lacks "SAR-Intensity", or ``span_or_pca`` is invalid for a multi-band tile.
    """
    sar_key = "SAR-Intensity"
    os.makedirs(tiles_out_dir, exist_ok=True)
    sar_paths = sorted(glob.glob(os.path.join(sar_tiles_dir, "*" + search)))
    for sar_path in tqdm(sar_paths):
        image = os.path.basename(sar_path)
        if sar_key not in image:
            raise ValueError(f"SAR tile name does not contain '{sar_key}': {image}")
        sarimg, proj, geo = _read_sar_tile(sar_path)
        span_img = _sar_to_single_band(sarimg, span_or_pca)
        rgb_file = os.path.join(rgb_tiles_dir, image.replace(sar_key, "PS-RGB", 1))
        rgbimg = io.imread(rgb_file).astype(np.float32)
        # histogram-match the SAR band to the optical green band (band index 1)
        sarimg_hgm = transform.match_histograms(span_img, rgbimg[:, :, 1])
        lee_filt_img = lee_filter(sarimg_hgm, filtersize)
        eo_sar_fusion = fusion(rgbimg, lee_filt_img, method=method)  # see above for various methods available
        # (rows, cols, 3) -> (3, rows, cols), the band-first layout create_multiband_geotiff expects
        eo_sar_fusion = np.transpose(eo_sar_fusion, (2, 0, 1))
        if add_4th is True:
            eo_sar_fusion = np.stack([eo_sar_fusion[0], eo_sar_fusion[1], eo_sar_fusion[2], eo_sar_fusion[0]])
        output_image = os.path.join(tiles_out_dir, image.replace(sar_key, "SAR-Intensity-Colorized-HSV", 1))
        create_multiband_geotiff(eo_sar_fusion, output_image, proj, geo, nodata=nodata)

Set the input and output directories, then run the colorization:

In [None]:
# Input SAR tiles, co-located optical tiles, and the destination for fused output.
sar_tiles_dir, rgb_tiles_dir, tiles_out_dir = (
    "../train/AOI_11_Rotterdam/SAR-Intensity",
    "../train/AOI_11_Rotterdam/PS-RGB",
    "../train/AOI_11_Rotterdam/Fused-RGB-SAR",
)

In [None]:
# Colorize every '.tif' tile in sar_tiles_dir using HSV fusion and span reduction.
color_sar(
    sar_tiles_dir,
    rgb_tiles_dir,
    tiles_out_dir,
    search='.tif',
    method='hsv',
    span_or_pca='span',
    filtersize=3,
    nodata=0,
    add_4th=False,
)