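# Segmentation by filtering

Each filter is applied in the frequency domain. The image is transformed with a 2-D FFT and shifted so the zero frequency sits at the centre. It is then multiplied element-wise by a mask and transformed back. Simple binary ring masks (high-pass and band-pass) are compared with smooth Gaussian filters.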

In [None]:
# Static inline figures. For interactive (pan/zoom/rotate) figures, which are
# handy for the 3-D surface plots further down, use this backend instead:
# %matplotlib notebook
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
from math import exp
from skimage import io
from skimage.measure import block_reduce

## Sample image

In [None]:
# Load the sample image. All filters below assume a single-channel (grayscale)
# image; the stored PNG was presumably produced by keeping one channel of an
# RGB original.
image = io.imread('../../SampleData/SampleExpImage.png')
assert image.ndim == 2, f"expected a 2-D grayscale image, got shape {image.shape}"

# skimage.io.imshow is deprecated; show the image with matplotlib directly.
fig, ax = plt.subplots()
ax.imshow(image, cmap='gray')
ax.set_title('Sample image')
ax.set_xticks([]); ax.set_yticks([])
print(image.shape)

In [None]:
# Move to the frequency domain and centre the zero-frequency (DC) term, so
# low frequencies sit in the middle of the array and high ones at the edges.
image_fft = np.fft.fftshift(np.fft.fft2(image))

# Magnitude spectrum (real-valued)
image_fft_mag = np.abs(image_fft)

for value in (image_fft.shape, image_fft.dtype, image_fft_mag.dtype):
    print(value)

## Basic binary filters

### Create filter

In [None]:
def ring_mask(height, width, radius_outer, radius_inner):
    """Boolean annulus centred on the middle pixel of a (height, width) array.

    For an fftshift-ed spectrum the middle pixel is the zero-frequency bin, so
    the annulus selects a band of spatial frequencies.

    Parameters
    ----------
    height, width : int
        Shape of the mask (rows, columns).
    radius_outer, radius_inner : float
        A pixel is True when its Euclidean distance (in pixels) from the
        centre pixel (row ``height // 2``, column ``width // 2``) lies in
        ``[radius_inner, radius_outer]``, both ends inclusive.

    Returns
    -------
    numpy.ndarray of bool, shape (height, width)
    """
    centre_row = int(height / 2)
    centre_col = int(width / 2)

    row_idx, col_idx = np.ogrid[:height, :width]
    radius = np.sqrt((col_idx - centre_col) ** 2 + (row_idx - centre_row) ** 2)

    inside_outer = radius <= radius_outer
    outside_inner = radius >= radius_inner
    return inside_outer & outside_inner
# ---

# Build the two binary frequency-domain masks for the (2-D) sample image.
# Both are rings centred on the DC term of the fftshift-ed spectrum:
#   * HPF: removes frequencies closer than rad1_hpf to the centre. rad0_hpf is
#     presumably chosen to exceed the spectrum's half-diagonal so nothing is
#     cut at the outer edge -- confirm for larger images.
#   * BPF: keeps only the annulus rad1_bpf <= r <= rad0_bpf.
shape = image.shape
assert image.ndim == 2, f"expected a 2-D grayscale image, got shape {shape}"
rows, cols = shape

# Filter params (radii in frequency-domain pixels)
rad0_hpf = 5000  # outer radius
rad1_hpf = 100   # inner radius

rad0_bpf = 250   # outer radius
rad1_bpf = 75    # inner radius

# Ring masks
mask_hpf = ring_mask(rows, cols, rad0_hpf, rad1_hpf)
mask_bpf = ring_mask(rows, cols, rad0_bpf, rad1_bpf)

# Show masks side by side
fig, (ax_hpf, ax_bpf) = plt.subplots(1, 2, figsize=(9, 5))
ax_hpf.imshow(mask_hpf, cmap='gray')
ax_hpf.set_title('High-pass mask')
ax_bpf.imshow(mask_bpf, cmap='gray')
ax_bpf.set_title('Band-pass mask');

### Visualize filter

In [None]:
# Z is masks
Z1 = mask_hpf
Z2 = mask_bpf

# Crop (fft is very big because image is big)
Z1 = Z1[200:-200, 200:-200]
# Z2 = Z2[500:-500, 500:-500]

# Downsample masks
# (too big to do surface plot)
Z1 = block_reduce(Z1, block_size=10)
Z2 = block_reduce(Z2, block_size=10)

# Threshold
Z1 = 1*(Z1>0)
Z2 = 1*(Z2>0)

size1 = Z1.shape[0]
size2 = Z2.shape[0]

# Set X, Y grids
x1 = np.array(np.arange(0, size1, 1))
y1 = np.array(np.arange(0, size1, 1))
xc1 = yc1 = size1//2
X1, Y1 = np.meshgrid(x1, y1)

x2 = np.array(np.arange(0, size2, 1))
y2 = np.array(np.arange(0, size2, 1))
xc2 = yc2 = size2//2
X2, Y2 = np.meshgrid(x2, y2)

# Show masks
fig = plt.figure(figsize=(12,8))

# HPF
ax1 = fig.add_subplot(121, projection='3d')

ax1.set_xticks([x1[1], xc1, x1[-3]], ['High\nfreq', 'Low\nfreq', 'High\nfreq'], rotation=0)
ax1.set_yticks([y1[1], yc1, y1[-3]], ['High\nfreq', 'Low\nfreq', 'High\nfreq'], rotation=0)

ax1.view_init(50, -80)
ax1.plot_surface(X1, Y1, Z1, rstride=1, cstride=1, cmap='viridis', linewidth=0, antialiased=False, alpha=0.8)

# BPF
ax2 = fig.add_subplot(122, projection='3d')

ax2.set_xticks([x2[1], xc2, x2[-3]], ['High\nfreq', 'Low\nfreq', 'High\nfreq'], rotation=0)
ax2.set_yticks([y2[1], yc2, y2[-3]], ['High\nfreq', 'Low\nfreq', 'High\nfreq'], rotation=0)

ax2.view_init(30, -80)

ax2.plot_surface(X2, Y2, Z2, rstride=1, cstride=1, cmap='viridis', linewidth=0, antialiased=False, alpha=0.8)

# fig.savefig('basic_filters.png', bbox_inches='tight')

### Applying the filter

In [None]:
# Apply the masks in the frequency domain (element-wise product with the
# centred spectrum), then undo the shift and transform back.
image_fft_hpf = mask_hpf * image_fft
image_fft_bpf = mask_bpf * image_fft

image_hpf = np.fft.ifft2(np.fft.ifftshift(image_fft_hpf))
image_bpf = np.fft.ifft2(np.fft.ifftshift(image_fft_bpf))

# The masks are radially symmetric about the DC bin, so the inverse FFT is
# essentially real: display its real part and let imshow scale it to gray.
# Casting the complex result straight to uint8 would drop the imaginary part
# with a ComplexWarning and turn the negative values a high-pass produces into
# platform-dependent garbage (out-of-range float->uint8 casts are undefined).
# Use np.abs(...) instead to display the magnitude.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))
for ax, filtered, title in ((ax1, image_hpf, 'High-pass'),
                            (ax2, image_bpf, 'Band-pass')):
    ax.imshow(filtered.real, cmap='gray')
    ax.set_title(title)
    ax.set_xticks([]); ax.set_yticks([])

# fig.savefig('basic_filters_seg.png', bbox_inches='tight')

## Gaussian filters

### Create filter

In [None]:
# Smooth (Gaussian) alternatives to the binary masks, on the same fftshift-ed
# frequency grid. With d the distance (pixels) from the DC bin:
#   HP(d) = A_HP * (1 - exp(-d^2 / (2 sigmaHP^2)))   suppresses low frequencies
#   LP(d) = A_LP *      exp(-d^2 / (2 sigmaLP^2))    suppresses high frequencies
shape = image_fft.shape
rows, cols = shape

# Centre on the DC bin, which fftshift places at index n // 2 (the same
# convention ring_mask uses); rows/2 would be half a pixel off for odd sizes.
yc, xc = rows // 2, cols // 2

# Set HP filter params
A_HP = 1
sigmaHP = 150

# Set LP filter params
A_LP = 0.8
sigmaLP = 250

# Evaluate both filters with broadcasting over an open grid rather than a
# per-pixel Python loop, which is very slow for a large image.
grid_y, grid_x = np.ogrid[:rows, :cols]
dist_sq = (grid_x - xc) ** 2 + (grid_y - yc) ** 2
baseHP = A_HP * (1 - np.exp(-dist_sq / (2 * sigmaHP**2)))
baseLP = A_LP * np.exp(-dist_sq / (2 * sigmaLP**2))

### Visualize filters

In [None]:
# Downsample the filters for plot_surface (full size is too big). Averaging
# each 20x20 block keeps the plotted heights in the filters' own units
# (gains up to A_HP and A_LP); block_reduce's default reduction sums the
# block, which would scale both surfaces by 400.
Z1 = block_reduce(baseHP, block_size=20, func=np.mean)
Z2 = block_reduce(baseLP, block_size=20, func=np.mean)

r, c = Z1.shape

# Index grid matching the downsampled filters (both have the same shape)
x = np.arange(c)
y = np.arange(r)
X, Y = np.meshgrid(x, y)
xc, yc = c // 2, r // 2  # low-frequency centre along each axis

# Plot
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')

labels = ['High\nfreq', 'Low\nfreq', 'High\nfreq']
ax.set_xticks([x[1], xc, x[-3]], labels, rotation=0)
ax.set_yticks([y[1], yc, y[-3]], labels, rotation=0)

ax.view_init(50, 70)

# HP (plasma) and LP (viridis) overlaid; their product, applied below, passes
# the band where both are high.
ax.plot_surface(X, Y, Z1, rstride=1, cstride=1, cmap='plasma', linewidth=0, antialiased=False, alpha=0.5)
ax.plot_surface(X, Y, Z2, rstride=1, cstride=1, cmap='viridis', linewidth=0, antialiased=False, alpha=0.5);

# fig.savefig('gaussian_filters.png', bbox_inches='tight')

### Apply filters

In [None]:
# Combined Gaussian filter: multiplying by both the HP and LP filters passes
# only the mid-frequency band between them.
image_fft_filt = baseHP * baseLP * image_fft

# Back to the spatial domain
image_filt = np.fft.ifft2(np.fft.ifftshift(image_fft_filt))

# The filters are radially symmetric about their centre, so the result is
# essentially real: display the real part and let imshow scale it to gray.
# Casting the complex array straight to uint8 would drop the imaginary part
# with a ComplexWarning and turn negative values into platform-dependent
# garbage (out-of-range float->uint8 casts are undefined).
fig, ax = plt.subplots(figsize=(8, 8))
ax.set_xticks([]); ax.set_yticks([])
ax.imshow(image_filt.real, cmap='gray')
ax.set_title('Gaussian band-pass');

# fig.savefig('gaussian_filter_seg.png', bbox_inches='tight')