# Ridge detection with the Complex Shearlet Transform

In [33]:
import os
import cv2
import math
import numpy as np
from osgeo import gdal
import ipywidgets as widgets
from IPython.display import display
from matplotlib import pyplot as plt
%matplotlib inline
from coshrem.shearletsystem import RidgeSystem
%matplotlib inline

## Read georeferencing information, if available

In [34]:
georef = False
filename = 'img/TEST.tif'
#filename = 'img/gawler_dem.tif'
#filename = 'img/Ortho_3_061.png'
outfile = 'Test_cs_ref.tif'

# Try GDAL first. If the file carries a projection, its first raster band is
# used as the grayscale input and `georef` is set so the result can later be
# written with the same georeferencing. Otherwise (no projection, or GDAL
# cannot open the file at all) the image is read with OpenCV instead.
gray = None
dataset = gdal.Open(filename, gdal.GA_ReadOnly)
if dataset:
    print("Driver: {}/{}".format(dataset.GetDriver().ShortName,
                                dataset.GetDriver().LongName))
    print("Size is {} x {} x {}".format(dataset.RasterXSize,
                                        dataset.RasterYSize,
                                        dataset.RasterCount))
    if dataset.GetProjection():
        print("Projection is: {}".format(dataset.GetProjection()))
        geotransform = dataset.GetGeoTransform()
        georef = True
        if geotransform:
            print("Geotransform:", geotransform)
            print("Origin = ({}, {})".format(geotransform[0], geotransform[3]))
            print("Pixel Size = ({}, {})".format(geotransform[1], geotransform[5]))
        # Read the band regardless of the geotransform so that `gray` is
        # always defined on the georeferenced path.
        gray = np.array(dataset.GetRasterBand(1).ReadAsArray())

if gray is None:
    image = cv2.imread(filename)
    if image is None:
        raise IOError("Could not read image file: {}".format(filename))
    # cv2.imread returns colour images with channels in BGR order.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

# Accumulator for the detected ridge map (filled by InteracRidges, written to
# `outfile` at the end) and the image shape passed to the ridge system.
MASTER = np.zeros(gray.shape, np.double)
p = gray.shape

Driver: GTiff/GeoTIFF
Size is 368 x 339 x 3


In [35]:
# Animated GIF shown in the `out` area while ridge detection is running.
with open('img/CS6l6tB.gif', 'rb') as gif_handle:
    img = gif_handle.read()
loading_bar = widgets.Image(value=img)

# Output area in which InteracRidges displays (and then clears) the animation.
out = widgets.Output()
display(out)
def CSridges(p, wavelet_eff_supp=60, scales_per_octave=4, shear_level=3, alpha=0.2, octaves=3.5,
             min_contrast=10, offset=1, positive=True, negative=False):
    """Run CoShREM ridge detection on the notebook-global ``gray`` image.

    A ``RidgeSystem`` is built for an image of shape ``p`` and applied to
    ``gray``; only the ridge (feature) map is returned and the orientation map
    is discarded.

    Parameters
    ----------
    p : tuple
        Image shape, unpacked positionally into ``RidgeSystem``.
    wavelet_eff_supp : int
        Effective wavelet support. Also used to derive the Gaussian support
        as ``ceil(ceil(p[0] / wavelet_eff_supp) / 2)``.
    scales_per_octave : int
        Forwarded to ``RidgeSystem``.
    min_contrast : number
        Forwarded to ``detect``.
    positive, negative : bool
        Forwarded to ``detect`` as ``positive_only`` / ``negative_only``.
    shear_level, alpha, octaves, offset
        Accepted for interface compatibility but currently NOT forwarded to
        ``RidgeSystem`` or ``detect``; changing them has no effect.

    Returns
    -------
    The feature map returned by ``RidgeSystem.detect``.
    """
    gaussian_support = math.ceil(math.ceil(p[0] / wavelet_eff_supp) / 2)
    ridge_system = RidgeSystem(
        *p,
        wavelet_eff_supp=wavelet_eff_supp,
        gaussian_eff_supp=gaussian_support,
        scales_per_octave=scales_per_octave,
    )
    detected, _orientations = ridge_system.detect(
        gray,
        min_contrast=min_contrast,
        positive_only=positive,
        negative_only=negative,
    )
    return detected

def InteracRidges(wavelet_eff_supp=60, scales_per_octave=4, shear_level=3, alpha=0.2, octaves=3.5,
                  min_contrast=10, offset=1, positive=True, negative=False):
    """Callback for ``widgets.interactive``: detect ridges and plot them.

    Shows the loading animation in the ``out`` widget while ``CSridges`` runs
    on the global image shape ``p``, then plots the ridge map without axes.
    It copies the map into the global ``MASTER`` array, which a later cell
    writes to ``outfile``, and prints a warning if ``MASTER`` contains NaN.

    Parameters mirror ``CSridges`` and are forwarded to it unchanged.
    """
    _, ax1 = plt.subplots(nrows=1, figsize=(25, 25))
    ridges = None
    with out:
        display(loading_bar)
        try:
            ridges = CSridges(p, wavelet_eff_supp=wavelet_eff_supp, scales_per_octave=scales_per_octave,
                              shear_level=shear_level, alpha=alpha, octaves=octaves,
                              min_contrast=min_contrast, offset=offset,
                              positive=positive, negative=negative)
        finally:
            # Always remove the loading animation, even if detection fails.
            out.clear_output()
    if ridges is None:
        # NOTE(review): in IPython the Output context manager displays and
        # suppresses exceptions raised inside it. Stop here instead of failing
        # with a NameError/TypeError on `ridges` below.
        return
    ax1.imshow(ridges)
    ax1.get_xaxis().set_visible(False)
    ax1.get_yaxis().set_visible(False)
    x1, x2 = ridges.shape[:2]
    MASTER[:x1, :x2] = ridges[:x1, :x2]
    # A NaN anywhere in MASTER propagates to the sum, so one check suffices.
    if np.isnan(np.sum(MASTER)):
        print("NaN in array, Check parameter combination.")

# Build the interactive UI. Each slider drives the InteracRidges keyword of
# the same name; parameters without an explicit widget (`positive`,
# `negative`) get ipywidgets' default widget for their boolean defaults.
# The call is left as the cell's last expression so the notebook renders it.
widgets.interactive(
    InteracRidges,
    wavelet_eff_supp=widgets.IntSlider(value=60, min=1, max=1000, step=1, continuous_update=False),
    scales_per_octave=widgets.IntSlider(value=4, min=1, max=100, step=1, continuous_update=False),
    shear_level=widgets.IntSlider(value=3, min=1, max=10, step=1, continuous_update=False),
    # `min` must not exceed the default value 0.2; ipywidgets clamps a value
    # below `min`, so with min=1 the slider started at 1.0 instead of 0.2.
    alpha=widgets.FloatSlider(value=0.2, min=0.0, max=10, step=0.1, continuous_update=False),
    octaves=widgets.FloatSlider(value=3.5, min=1, max=10, step=0.1, continuous_update=False),
    min_contrast=widgets.IntSlider(value=10, min=1, max=100, step=1, continuous_update=False),
    offset=widgets.FloatSlider(value=1, min=1, max=100, step=0.1, continuous_update=False),
)

Output()

interactive(children=(IntSlider(value=60, continuous_update=False, description='wavelet_eff_supp', max=1000, m…

## write final image to file

In [23]:
# Save the ridge map currently held in MASTER.
# - Georeferenced input: write a single-band Float32 GeoTIFF that copies the
#   input dataset's geotransform and projection.
# - Otherwise: write the array with OpenCV, also as 32-bit float.
if georef:
    driver = gdal.GetDriverByName("GTiff")
    outdata = driver.Create(outfile, dataset.RasterXSize, dataset.RasterYSize, 1, gdal.GDT_Float32)
    if outdata is None:
        raise IOError("Could not create output file: {}".format(outfile))
    outdata.SetGeoTransform(dataset.GetGeoTransform())
    outdata.SetProjection(dataset.GetProjection())
    outdata.GetRasterBand(1).WriteArray(MASTER)
    outdata.FlushCache()
else:
    # Use float32 to match the GeoTIFF branch; 32-bit float is the float depth
    # OpenCV documents for its TIFF writer. imwrite reports failure via its
    # return value rather than raising, so check it explicitly.
    if not cv2.imwrite(outfile, MASTER.astype(np.float32)):
        raise IOError("Could not write output file: {}".format(outfile))

In [None]:
# Drop the references to the GDAL objects. In the GDAL Python bindings,
# dereferencing a dataset closes it and flushes any pending writes to disk.
outdata = band = dataset = None