## Fracture detection with Complex Shearlet Transform (adapted from https://github.com/rahulprabhakaran/Automatic-Fracture-Detection-Code)
## Uses PyCoShREM, the Python port of Rafael Reisenhofer's MATLAB toolbox *Complex Shearlet-Based Ridge and Edge Measurement*: https://github.com/rgcda/PyCoShREM

In [None]:
import sys
import cv2
import easygui
sys.path.append('py_modules')
import ipywidgets as widgets
from IPython.display import display
from matplotlib import pyplot as plt
from processing import ReadImage, SplitInput, GenerateSystems, ImgSizes, DetectFeatures, CheckDetectionParams, WriteImage
%matplotlib inline

## Select image(s) and load them

In [None]:
# Ask the user for one or more JPEG files and load each one with ReadImage.
# easygui expects `filetypes` as a list of patterns, and fileopenbox returns
# None when the dialog is cancelled -- check that before iterating so the
# failure is explicit instead of a TypeError from `for f in None`.
filenames = easygui.fileopenbox(
    "select image file(s)", "CoSh_ensemble", filetypes=["*.jpg"], multiple=True
)
if not filenames:
    raise RuntimeError("No image file selected; re-run this cell and pick at least one image.")
img_list = [ReadImage(f) for f in filenames]
print(f"selected {len(img_list)} images")

In [None]:
# Input widgets for the shearlet-system parameters. Each text field takes a
# comma-separated list of values; its placeholder repeats the default value.
style = {'description_width': 'initial'}


def _shearlet_param_field(default, label):
    """Return an enabled Text widget pre-filled and placeholder-hinted with `default`."""
    return widgets.Text(value=default, placeholder=default, description=label,
                        style=style, disabled=False)


waveletEffSupp = _shearlet_param_field('25,50,150', 'waveletEffSupp:')
gaussianEffSupp = _shearlet_param_field('12,25,75', 'gaussianEffSupp:')
scalesPerOctave = _shearlet_param_field('2', 'scalesPerOctave:')
shearLevel = _shearlet_param_field('3', 'shearLevel:')
alpha = _shearlet_param_field('1', 'alpha:')
octaves = _shearlet_param_field('3.5', 'octaves:')
# Checked -> ridge detection; its value is forwarded to GenerateSystems/DetectFeatures.
ridges = widgets.Checkbox(True, description='Ridges')

## Define shearlet parameter combinations (comma-delimited)

In [None]:
# Render the shearlet-parameter inputs, followed by the ridge/edge toggle.
for _param_widget in (waveletEffSupp, gaussianEffSupp, scalesPerOctave,
                      shearLevel, alpha, octaves, ridges):
    display(_param_widget)

In [None]:
# Turn each comma-separated text field into a list of values. The boolean
# second argument to SplitInput looks like an "integers only" switch
# (True for the scale/shear counts) -- TODO confirm in processing.SplitInput.
(wavelet_eff_supp, gaussian_eff_supp, scales_per_octave,
 shear_level, ALPHA, OCTAVES) = (
    SplitInput(field.value, as_int)
    for field, as_int in (
        (waveletEffSupp, False),
        (gaussianEffSupp, False),
        (scalesPerOctave, True),
        (shearLevel, True),
        (alpha, False),
        (octaves, False),
    )
)
EDGE = False  # not referenced elsewhere in this notebook; kept for compatibility
# GenerateSystems receives the image sizes alongside every parameter list --
# presumably so each shearlet system matches the image dimensions (TODO confirm).
i_size = ImgSizes(img_list)
systems = GenerateSystems(
    i_size,
    wavelet_eff_supp,
    gaussian_eff_supp,
    scales_per_octave,
    shear_level,
    ALPHA,
    OCTAVES,
    ridges.value,
)

## Detection Parameters

In [None]:
# Input widgets for the feature-detection thresholds. Text fields accept
# comma-separated lists of values.
style = {'description_width': 'initial'}
min_contrast = widgets.Text(
    value='5,10,25,50',
    placeholder='5,10,25,50',
    description='minContrast:',
    style=style,
    disabled=False,
)
offset = widgets.Text(
    value='1',
    placeholder='1,1.5',
    description='offset:',
    style=style,
    disabled=False,
)
pivoting_scales = widgets.Dropdown(
    description='scalesUsedForPivotSearch',
    style=style,
    options=['all', 'highest', 'lowest'],
    value='all',
    layout=widgets.Layout(width='50%'),
)
# Ridge-only toggles; the display cell shows them only when `ridges` is checked.
negative = widgets.Checkbox(True, description='negative')
positive = widgets.Checkbox(False, description='positive')

In [None]:
# Show the detection-parameter inputs. The negative/positive toggles are
# displayed only when ridge detection is selected.
_detection_widgets = [min_contrast, offset, pivoting_scales]
if ridges.value:
    _detection_widgets += [negative, positive]
for _detection_widget in _detection_widgets:
    display(_detection_widget)

In [None]:
# Replace the detection widgets with their parsed values.
# NOTE(review): this rebinds the widget names, so re-running this cell without
# first re-running the widget cell fails (the parsed values have no `.value`).
min_contrast = SplitInput(min_contrast.value, False)
offset = CheckDetectionParams(SplitInput(offset.value, False))
pivoting_scales = pivoting_scales.value

## Detect features in images

In [None]:
# Run the shearlet-based detection over every loaded image with the chosen
# systems and thresholds; the display cell indexes the result in parallel
# with img_list (features[i] for img_list[i]).
features = DetectFeatures(
    img_list,
    systems,
    min_contrast,
    offset,
    pivoting_scales,
    negative.value,
    positive.value,
    ridges.value,
    i_size,
)

## Display and write images

In [None]:
# Write all detection results once. WriteImage takes the complete image and
# feature lists, so calling it inside the loop rewrote every file N times.
if img_list:
    WriteImage(img_list, features, "test")

n_images = len(img_list)
for i, img in enumerate(img_list):
    _, ax = plt.subplots(nrows=1, figsize=(25, 25))
    # The input image gets weight 0.001 and the feature map 0.99, so the
    # overlay is essentially the feature map. img[0] is presumably the image
    # array from ReadImage's return value -- TODO confirm.
    overlay = cv2.addWeighted(img[0], 0.001, features[i], 0.99, 0, dtype=cv2.CV_64F)
    ax.imshow(overlay, cmap="gray")
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    # With `%matplotlib inline`, figures are only rendered when the cell
    # finishes unless plt.show() is called, so flush this figure before
    # blocking on input(); otherwise nothing appears while paging.
    plt.show()
    # input() returns on Enter, not on an arbitrary key; skip the prompt after the last image.
    if i < n_images - 1:
        input("Next image. (press Enter)")