In [None]:
import sys

import h5py
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import scipy.fftpack as ftp
from scipy import signal
from scipy.signal import medfilt2d
from skimage.color import label2rgb
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from skimage.morphology import opening, closing
from skimage.segmentation import clear_border
# NOTE(review): `pad` appears unused in this notebook and skimage.util.pad is deprecated/removed
# in newer scikit-image releases (numpy.pad replaces it) — confirm and drop if so.
from skimage.util import pad

import freqCutting as fcut
import realDataProcess as rdp

# Defining functions for Bragg filtering

In [None]:
def gkern(kernlen=21, std=3):
    """Return a ``kernlen`` x ``kernlen`` 2D Gaussian kernel array.

    The kernel is the outer product of a 1D Gaussian window of length ``kernlen``
    and standard deviation ``std`` (in samples) with itself. It is NOT normalised
    to unit sum; for odd ``kernlen`` the centre value is exactly 1.
    """
    # scipy.signal.gaussian was deprecated and then removed (SciPy 1.13);
    # the window function lives in scipy.signal.windows.
    gkern1d = signal.windows.gaussian(kernlen, std=std)
    return np.outer(gkern1d, gkern1d)

def _center_crop(arr, out_shape):
    """Return the centred ``out_shape`` window of ``arr``.

    Trims the padding added by ``signal.convolve2d`` (default 'full' mode) so the
    result lines up pixel-for-pixel with the FFT it is multiplied with. Unlike a
    ``[c:-c]`` slice this also works when the padding is 0 or odd.
    """
    off_r = (arr.shape[0] - out_shape[0]) // 2
    off_c = (arr.shape[1] - out_shape[1]) // 2
    return arr[off_r:off_r + out_shape[0], off_c:off_c + out_shape[1]]


def _zero_edges(arr, cut_edge):
    """Zero ``cut_edge[0]`` rows and ``cut_edge[1]`` columns at every border of ``arr`` in place.

    Non-positive counts are skipped: ``arr[-0:]`` would otherwise select the whole axis.
    """
    rows, cols = cut_edge
    if rows > 0:
        arr[:rows, :] = 0
        arr[-rows:, :] = 0
    if cols > 0:
        arr[:, :cols] = 0
        arr[:, -cols:] = 0


def bragg_filter(img, ksize=9, sig=1, shape=None, cut_edge=(50, 50), plot=False, morph=False, morph2=False):
    """Mask the centred FFT of ``img`` to its Bragg peaks, and to their complement.

    Candidate peaks are pixels of the log power spectrum whose excess over a 25x25
    median-filtered background is at least twice the background's standard deviation.
    They are blurred with a ``ksize`` x ``ksize`` Gaussian kernel (std ``sig``, see
    ``gkern``), edge-trimmed, binarised and blurred again into a smooth mask.

    Parameters
    ----------
    img : 2D array
        Input image (reshaped to ``(img.shape[0], img.shape[1])`` before the FFT).
    ksize, sig : int, float
        Size and standard deviation of the Gaussian kernel used to grow the peaks.
    shape : tuple of int, optional
        Shape of the FFT and of the masks. Defaults to the FFT's own shape; if given,
        it must equal it (the masks are multiplied element-wise with the FFT).
    cut_edge : (int, int)
        Rows / columns zeroed at every border of the peak map; 0 disables trimming.
    plot : bool
        Show intermediate images.
    morph : bool
        Apply a morphological closing then opening to the thresholded difference map.
    morph2 : bool
        Apply a morphological closing then opening to the smoothed peak mask.

    Returns
    -------
    filtered_fft : 2D complex array
        Centred (``fftshift``-ed) FFT multiplied by the smooth peak mask.
    inv_filtered_fft : 2D complex array
        Centred FFT multiplied by a blurred mask of the pixels where the peak mask is zero.
    """
    fft = ftp.fftshift(ftp.fft2(img.reshape(img.shape[0], img.shape[1])))
    out_shape = fft.shape if shape is None else tuple(shape)
    # log power spectrum and a smooth background estimate of it
    # NOTE(review): zero-magnitude FFT bins give -inf here (and NaN in `dif`).
    intf = np.log(abs(fft)**2)
    intfilt = medfilt2d(intf, kernel_size=25)
    if plot:
        plt.figure()
        rdp.imm(img)
        plt.figure()
        rdp.imm(intf)
        plt.figure()
        rdp.imm(intfilt)
    # keep only pixels standing well above the background: the candidate peaks
    dif = intf - intfilt
    dif[dif < 2 * intfilt.std()] = 0
    if morph:
        dif = closing(dif)
        dif = opening(dif)
    if plot:
        plt.figure()
        rdp.imm(dif)
    kernel = gkern(ksize, sig)
    if plot:
        plt.figure()
        plt.imshow(kernel)
    # grow each candidate into a blob, then trim the 'full' convolution padding
    peaks = _center_crop(signal.convolve2d(dif, kernel), out_shape)
    if plot:
        plt.figure()
        rdp.imm(peaks)
    # discard the outer frequency border, then binarise
    _zero_edges(peaks, cut_edge)
    peaks[peaks > 0] = 1
    smoother_peaks = _center_crop(signal.convolve2d(peaks, kernel), out_shape)
    if morph2:
        smoother_peaks = closing(smoother_peaks)
        smoother_peaks = opening(smoother_peaks)
    # complementary mask: 1 where smoother_peaks + 1 does not exceed 1 (i.e. the mask is zero), else 0
    inv_peaks = smoother_peaks + 1
    inv_peaks[inv_peaks > 1] = 0
    inv_peaks = _center_crop(signal.convolve2d(inv_peaks, kernel), out_shape)
    filtered_fft = fft * smoother_peaks
    inv_filtered_fft = fft * inv_peaks
    return filtered_fft, inv_filtered_fft

def bragg_seg(filt_fft):
    """Keep one half of a centred spectrum and transform it back to real space.

    Parameters
    ----------
    filt_fft : 2D complex array
        Spectrum in centred (``fftshift``-ed) layout, e.g. ``bragg_filter(img)[0]``.
        It is not modified.

    Returns
    -------
    ffft : 2D complex array
        The spectrum moved back to standard (unshifted) FFT layout, with rows
        ``filt_fft.shape[0] // 2`` and above (the negative-frequency half of axis 0)
        set to zero.
    seg_map : 2D complex array
        ``ifft2(ffft)``.
    """
    # ifftshift is the exact inverse of fftshift; applying fftshift again only
    # undoes it for even-sized axes. ifftshift returns a new array, so the zeroing
    # below does not touch the caller's data.
    ffft = ftp.ifftshift(filt_fft)
    half_point = filt_fft.shape[0] // 2
    ffft[half_point:, :] = 0
    seg_map = ftp.ifft2(ffft)
    return ffft, seg_map

def isolate_bragg_peaks(filt_fft, peak_thresh=100, plot=False):
    """Split a Bragg-filtered spectrum into one masked spectrum per detected peak.

    Parameters
    ----------
    filt_fft : 2D complex array
        Centred spectrum, typically ``bragg_filter(img)[0]``; only the half kept by
        ``bragg_seg`` is used.
    peak_thresh : int, optional
        A labelled region is kept only if its area (in pixels) exceeds this value.
    plot : bool, optional
        If true, show the labelled regions with a red box around each kept peak.

    Returns
    -------
    np.ndarray
        Stack of centred spectra, one per kept region, each zero outside that region's
        bounding box. Order follows ``regionprops``; shape ``(0,)`` if nothing is kept.
    """
    half_fft, _ = bragg_seg(filt_fft)
    # back to centred layout so region coordinates match the usual display
    centred = ftp.fftshift(half_fft)

    # NOTE(review): real(sqrt(z**2)) is |Re z| rather than |z|, and the uint8 cast
    # wraps/clips large magnitudes. Kept unchanged because the hand-picked peak
    # indices used later in this notebook depend on this exact labelling.
    image = np.real(np.sqrt(centred**2)).astype('uint8')

    # Otsu threshold, 3x3 closing (same footprint as square(3)),
    # drop blobs touching the border, label the rest
    thresh = threshold_otsu(image)
    bw = closing(image > thresh, np.ones((3, 3), dtype=np.uint8))
    label_image = label(clear_border(bw))

    if plot:
        _, ax = plt.subplots(figsize=(10, 6))
        ax.imshow(label2rgb(label_image, image=image))

    bragg_spots = []
    for region in regionprops(label_image):
        # skip regions that are too small to be a peak
        if region.area <= peak_thresh:
            continue
        minr, minc, maxr, maxc = region.bbox
        # copy the spectrum inside the bounding box, zero everywhere else
        spot = np.zeros_like(centred)
        spot[minr:maxr, minc:maxc] = centred[minr:maxr, minc:maxc]
        bragg_spots.append(spot)
        if plot:
            rect = mpatches.Rectangle((minc, minr), maxc - minc, maxr - minr,
                                      fill=False, edgecolor='red', linewidth=2)
            ax.add_patch(rect)

    if plot:
        ax.set_axis_off()
        plt.tight_layout()
        plt.show()

    return np.asarray(bragg_spots)

def plot_ifft3(fft):
    """Display the normalised real-space intensity of a spectrum and return it.

    Computes ``|ifft2(fft)|**2``, shifts it so its minimum is 0, scales it so its
    maximum is 1, and shows it in grayscale on a new 10x10 figure without axes.

    Parameters
    ----------
    fft : 2D complex array
        Spectrum to invert, e.g. one entry returned by ``isolate_bragg_peaks``.

    Returns
    -------
    np.ndarray
        The rescaled intensity image. If the intensity is constant (e.g. an all-zero
        spectrum), it is returned as all zeros instead of NaN.
    """
    img = abs(ftp.ifft2(fft))**2
    img -= img.min()
    peak = img.max()
    if peak > 0:
        # guard: a constant image would otherwise divide 0/0 and become all-NaN
        img = img / peak
    plt.figure(figsize=(10, 10))
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    return img
    

# Implementing Bragg filtering on the Pd test image

In [None]:
# Load the Pd test image and lay it out as a 512 x 512 array
Pd = np.load('Pdimg.npy').reshape((512, 512))

In [None]:
#chack filtering 
fcut.immFFT(Pd_ffft[0])

In [None]:
#create filtered version of image fft
Pd_ffft = bragg_filter(Pd)

In [None]:
# one segmentation map per isolated Bragg peak of the filtered Pd FFT
Pd_peaks = isolate_bragg_peaks(filt_fft=Pd_ffft[0].copy(), peak_thresh=30, plot=True)

In [None]:
#plot the segmentation maps one at a time to decide which are good:
#change the index and re-run this cell to step through Pd_peaks
Pd_check_idx = 61
a3 = plot_ifft3(Pd_peaks[Pd_check_idx])
plt.title('Pd segmentation map {}'.format(Pd_check_idx))
plt.colorbar();

In [None]:
# indices into Pd_peaks judged good by inspecting the maps plotted above
good_Pd = [
    5, 9, 15, 19, 24,
    31, 32, 37, 46, 54,
]

In [None]:
# Sum the good maps: each is inverted to real-space intensity, rescaled to [0, 1],
# and cleared below 0.1 before being accumulated.
Pd_final_seg = np.zeros((512, 512))
for idx, img in enumerate(Pd_peaks):
    if idx not in good_Pd:
        continue
    intensity = np.abs(ftp.ifft2(img)) ** 2
    intensity = intensity - intensity.min()
    intensity = intensity / intensity.max()
    intensity[intensity < 0.1] = 0
    Pd_final_seg += intensity

In [None]:
# Final map: binarise the summed maps in place (1 above 0.6, 0 elsewhere) and show it
np.putmask(Pd_final_seg, Pd_final_seg > 0.6, 1)
np.putmask(Pd_final_seg, Pd_final_seg < 1, 0)
rdp.imm(Pd_final_seg)

In [None]:
# persist the binary Pd segmentation map
np.save(file='Pd_bragg_map.npy', arr=Pd_final_seg)

# Bragg filtering on Au

In [None]:
# Load the Au image and lay it out as a 512 x 512 array
Au = np.load('Auimg.npy').reshape((512, 512))

In [None]:
# Bragg-filter Au with a small kernel and morphological clean-up of the difference map
Au_ffft, _ = bragg_filter(
    Au,
    ksize=3,
    sig=1,
    morph=True,
)

In [None]:
#check filtering: display the peak-masked Au FFT computed in the previous cell
fcut.immFFT(Au_ffft)

In [None]:
# one segmentation map per isolated Bragg peak of the filtered Au FFT
Au_peaks = isolate_bragg_peaks(Au_ffft, plot=True)

In [None]:
#plot the segmentation maps one at a time to decide which are bad:
#change the index and re-run this cell to step through Au_peaks
Au_check_idx = 6
Au0 = plot_ifft3(Au_peaks[Au_check_idx])
plt.title('Au segmentation map {}'.format(Au_check_idx))
plt.colorbar();

In [None]:
# Sum every Au map except the rejected ones, rescale the sum to [0, 1] and binarise at 0.4.
# NOTE(review): the original cells referenced an undefined `Au_ffft_2`; these indices were
# presumably chosen against that run — confirm they still match the intended Au_peaks maps.
bad_Au = {32, 35, 41, 53, 82, 86}
Au_final_seg = np.zeros((512,512))
for idx, img in enumerate(Au_peaks):
    if idx in bad_Au:
        continue
    img = abs(ftp.ifft2(img))**2
    img -= img.min()
    img = img/img.max()
    img[img<0.07] = 0
    Au_final_seg += img
Au_final_seg = Au_final_seg/Au_final_seg.max()
# threshold the normalised sum itself (`Au_final_seg2` was never defined)
Au_final_seg[Au_final_seg > 0.4] = 1
Au_final_seg[Au_final_seg < 1] = 0

In [None]:
#display segmentation over the original Au image (loaded as `Au`; `Auimg` is never defined)
rdp.immOverlay(Au,Au_final_seg,0)

In [None]:
# persist the binary Au segmentation map
np.save(file='Au_bragg_v2.npy', arr=Au_final_seg)

# Bragg filtering on CdSe

In [None]:
# Load the CdSe image and lay it out as a 512 x 512 array
CdSe = np.load('CdSeimg.npy').reshape((512, 512))

In [None]:
# Bragg-filter CdSe; morphological clean-up is applied to the smoothed peak mask
CdSe_ffft, _ = bragg_filter(
    CdSe,
    ksize=5,
    sig=1,
    morph2=True,
)

In [None]:
#check bragg filtering: display the peak-masked CdSe FFT
fcut.immFFT(CdSe_ffft)

In [None]:
# isolate the peaks and build one segmentation map per peak
CdSe_peaks = isolate_bragg_peaks(filt_fft=CdSe_ffft.copy(), peak_thresh=30, plot=True)

In [None]:
#plot the segmentation maps one at a time to decide which are bad:
#change the index and re-run this cell to step through CdSe_peaks
#(bound to a name instead of `_`, which would shadow IPython's last-output variable)
CdSe_check_idx = 0
CdSe0 = plot_ifft3(CdSe_peaks[CdSe_check_idx])
plt.title('CdSe segmentation map {}'.format(CdSe_check_idx))
plt.colorbar();

In [None]:
#create final segmentation map
# Indices of CdSe maps rejected by inspection (sorted, de-duplicated).
# NOTE(review): the original list contained 31 twice — one entry may be a typo for another index; confirm.
bad_index = [1, 2, 3, 7, 16, 20, 23, 25, 28, 29, 30, 31]
CdSe_final_seg = np.zeros((512,512))
for idx, img in enumerate(CdSe_peaks):
    if idx in bad_index:
        continue
    img = abs(ftp.ifft2(img))**2
    img -= img.min()
    img = img/img.max()
    # zero everything below 70% of this map's maximum
    img[img<0.7] = 0
    CdSe_final_seg += img
# CdSe_final_seg3: the summed maps rescaled by their maximum (continuous values)
CdSe_final_seg3 = CdSe_final_seg/CdSe_final_seg.max()
# CdSe_final_seg: binary union of the thresholded maps
CdSe_final_seg[CdSe_final_seg>0] = 1

In [None]:
#display segmentation over the original CdSe image (loaded as `CdSe`; `CdSeimg` is never defined)
rdp.immOverlay(CdSe,CdSe_final_seg,0)

In [None]:
# NOTE(review): this saves the continuous normalised sum (CdSe_final_seg3), not the binary
# CdSe_final_seg shown in the overlay, whereas the Pd/Au cells save binary maps — confirm intended.
np.save(file='CdSe_bragg_v2.npy', arr=CdSe_final_seg3)