In [None]:
# Clone the morpholayers repository into the current working directory.
# NOTE(review): presumably this is what makes the `morpholayers` imports used
# by the following cells resolvable -- confirm against the runtime environment.
!git clone https://github.com/Jacobiano/morpholayers.git

In [None]:
from imageio import imread
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from matplotlib import pyplot as plt
import numpy as np
from morpholayers.layers import Dilation2D, Erosion2D, Opening2D, Closing2D
from morpholayers.initializers import Quadratic, SEinitializer

## Images

In [None]:
# Read imageio's "chelsea" sample picture, then keep only channel 0 of the
# last axis so that imCat is a single-channel 2-D image.
imCol = imread("imageio:chelsea.png")
imCat = imCol[..., 0]

In [None]:
# Show the single-channel cat image on a fixed 0..255 gray scale, without axes.
fig, ax = plt.subplots(figsize=(15, 7))
ax.imshow(imCat, cmap="gray", vmin=0, vmax=255)
ax.axis("off")
plt.show()

In [None]:
from tensorflow.keras.datasets import fashion_mnist

# load_data() returns a (train, test) pair of (images, labels) tuples.
train_split, test_split = fashion_mnist.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
print(x_train.shape)

In [None]:
print("Examples from the training set:")
nsamp = 5
# Seeded generator: a fresh "Restart & Run All" always picks the same samples,
# so every figure derived from `ridxs` further down is reproducible.
sample_rng = np.random.default_rng(0)
# Draw all indices at once and without replacement so that no training image
# is shown twice. `ridxs` stays a plain list of ints for the cells that reuse it.
ridxs = sample_rng.choice(x_train.shape[0], size=nsamp, replace=False).tolist()
# squeeze=False keeps `axes` two-dimensional even if nsamp is set to 1.
fig, axes = plt.subplots(1, nsamp, figsize=(30, 5), squeeze=False)
for ax, randidx in zip(axes[0], ridxs):
    ax.imshow(x_train[randidx, :, :], cmap="gray", vmin=0, vmax=255)
    ax.axis("off")
plt.show()

In [None]:
# Stack the images selected through `ridxs` along a new leading axis, in the
# same order as the indices.
Batch = np.stack([x_train[i, :, :] for i in ridxs], axis=0)
print("Batch size", Batch.shape)

## Flat operators

### Square structuring element (7 x 7)

In [None]:
# Dilation model
# Input: single-channel images whose height and width are left unspecified.
xin = Input(shape=(None, None, 1))
# Dilation2D layer with a 7 x 7 kernel, created first and then applied.
flatSquareDilation = Dilation2D(1, kernel_size=(7, 7))
x = flatSquareDilation(xin)
modelDilFlatSquare = Model(inputs=xin, outputs=x)
modelDilFlatSquare.summary()

In [None]:
# Prepend a batch axis of length 1 to the cat image and run the dilation model.
catDil = modelDilFlatSquare.predict(imCat[np.newaxis, ...])

In [None]:
# Left panel: input image. Right panel: first sample / first channel of the
# dilation output. Both use the same fixed 0..255 gray scale.
fig, (axOriginal, axDilated) = plt.subplots(1, 2, figsize=(30, 12))
axOriginal.imshow(imCat, cmap="gray", vmin=0, vmax=255)
axOriginal.set_title("Original", fontsize=20)
axOriginal.axis("off")
axDilated.imshow(catDil[0, :, :, 0], cmap="gray", vmin=0, vmax=255)
axDilated.set_title("Dilation by 7 x 7 square", fontsize=20)
axDilated.axis("off")
plt.show()

In [None]:
# Erosion model
# Input: single-channel images whose height and width are left unspecified.
xin = Input(shape=(None, None, 1))
# Erosion2D layer with a 7 x 7 kernel, created first and then applied.
flatSquareErosion = Erosion2D(1, kernel_size=(7, 7))
x = flatSquareErosion(xin)
modelEroFlatSquare = Model(inputs=xin, outputs=x)
modelEroFlatSquare.summary()

In [None]:
# Prepend a batch axis of length 1 to the cat image and run the erosion model.
catEro = modelEroFlatSquare.predict(imCat[np.newaxis, ...])

In [None]:
# Left panel: input image. Right panel: first sample / first channel of the
# erosion output. Both use the same fixed 0..255 gray scale.
fig, (axOriginal, axEroded) = plt.subplots(1, 2, figsize=(30, 12))
axOriginal.imshow(imCat, cmap="gray", vmin=0, vmax=255)
axOriginal.set_title("Original", fontsize=20)
axOriginal.axis("off")
axEroded.imshow(catEro[0, :, :, 0], cmap="gray", vmin=0, vmax=255)
axEroded.set_title("Erosion by 7 x 7 square", fontsize=20)
axEroded.axis("off")
plt.show()

### Cross structuring element (3 x 3)

In [None]:
import skimage.morphology as skm

# skm.disk(1) is a 3 x 3 footprint holding 1 on the cross and 0 in the corners.
# Build the structuring element from it explicitly: 0.0 on the cross and
# -255.0 in the corners.
#
# The previous expression, np.round(-1.0 * (cross - 1)), produced the same
# numbers only because the footprint is an unsigned 8-bit array, so `0 - 1`
# silently wrapped around to 255 before the sign flip. Spelling the values out
# removes that hidden dependency on the footprint dtype and avoids the `-0.`
# entries in the printed array.
# NOTE(review): -255 presumably serves as the "outside the structuring element"
# value for 8-bit images -- confirm against the morpholayers documentation.
cross = np.where(skm.disk(1) > 0, 0.0, -255.0)
print(cross)

In [None]:
# Dilation model
# The kernel has the same height/width as `cross` and its weights are
# initialised from `cross` through SEinitializer.
xin = Input(shape=(None, None, 1))
crossDilation = Dilation2D(
    1,
    kernel_size=cross.shape[:2],
    kernel_initializer=SEinitializer(SE=cross, minval=0),
)
x = crossDilation(xin)
modelDilCross = Model(inputs=xin, outputs=x)

In [None]:
# The last array returned by get_weights() is displayed as the structuring
# element: its shape is printed, then the [:, :, 0, 0] slice is drawn on a
# -255..0 gray scale with a colorbar, and finally its max and min are printed.
listW = modelDilCross.get_weights()
SE = listW[-1]
print(SE.shape)
fig, ax = plt.subplots()
seImage = ax.imshow(SE[:, :, 0, 0], cmap="gray", vmax=0, vmin=-255)  # RdBu
fig.colorbar(seImage, ax=ax)
ax.axis("off")
print(SE.max(), SE.min())

In [None]:
# Run the cross dilation model on the whole batch of Fashion-MNIST samples.
fmnistDilCross = modelDilCross.predict(x=Batch)

In [None]:
# Two figures of nsamp panels each: first the input samples, then channel 0 of
# their dilations, all on the same 0..255 gray scale.
for imageStack in (Batch, fmnistDilCross[:, :, :, 0]):
    fig, axes = plt.subplots(1, nsamp, figsize=(30, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(imageStack[i], cmap="gray", vmin=0, vmax=255)
        ax.axis("off")
    plt.show()

In [None]:
# Erosion model
# The kernel has the same height/width as `cross` and its weights are
# initialised from `cross` through SEinitializer.
xin = Input(shape=(None, None, 1))
crossErosion = Erosion2D(
    1,
    kernel_size=cross.shape[:2],
    kernel_initializer=SEinitializer(SE=cross, minval=0),
)
x = crossErosion(xin)
modelEroCross = Model(inputs=xin, outputs=x)

In [None]:
# Opening model
# The kernel has the same height/width as `cross` and its weights are
# initialised from `cross` through SEinitializer.
xin = Input(shape=(None, None, 1))
crossOpening = Opening2D(
    1,
    kernel_size=cross.shape[:2],
    kernel_initializer=SEinitializer(SE=cross, minval=0),
)
x = crossOpening(xin)
modelOpenCross = Model(inputs=xin, outputs=x)

In [None]:
# Closing model
# The kernel has the same height/width as `cross` and its weights are
# initialised from `cross` through SEinitializer.
xin = Input(shape=(None, None, 1))
crossClosing = Closing2D(
    1,
    kernel_size=cross.shape[:2],
    kernel_initializer=SEinitializer(SE=cross, minval=0),
)
x = crossClosing(xin)
modelClosCross = Model(inputs=xin, outputs=x)

In [None]:
# Run the cross erosion, opening and closing models (in that order) on the
# whole batch of Fashion-MNIST samples.
fmnistEroCross, fmnistOpenCross, fmnistClosCross = (
    crossModel.predict(Batch)
    for crossModel in (modelEroCross, modelOpenCross, modelClosCross)
)

In [None]:
# For each (caption, images) pair: print the caption, then draw nsamp panels
# on a 0..255 gray scale. Inputs come first, channel 0 of the erosions second.
for caption, imageStack in (
    ("Original", Batch),
    ("Erosion", fmnistEroCross[:, :, :, 0]),
):
    print(caption)
    fig, axes = plt.subplots(1, nsamp, figsize=(30, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(imageStack[i], cmap="gray", vmin=0, vmax=255)
        ax.axis("off")
    plt.show()

In [None]:
# For each (caption, images) pair: print the caption, then draw nsamp panels
# on a 0..255 gray scale. Inputs come first, channel 0 of the openings second.
for caption, imageStack in (
    ("Original", Batch),
    ("Opening", fmnistOpenCross[:, :, :, 0]),
):
    print(caption)
    fig, axes = plt.subplots(1, nsamp, figsize=(30, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(imageStack[i], cmap="gray", vmin=0, vmax=255)
        ax.axis("off")
    plt.show()

In [None]:
# For each (caption, images) pair: print the caption, then draw nsamp panels
# on a 0..255 gray scale. Inputs come first, channel 0 of the closings second.
for caption, imageStack in (
    ("Original", Batch),
    ("Closing", fmnistClosCross[:, :, :, 0]),
):
    print(caption)
    fig, axes = plt.subplots(1, nsamp, figsize=(30, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(imageStack[i], cmap="gray", vmin=0, vmax=255)
        ax.axis("off")
    plt.show()

### Quadratic structuring element

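Structuring function: $$b(x) = -c \cdot \frac{\|x\|^2}{t^2}$$ where $c$ and $t$ are the values passed below as `cvalue` and `tvalue` to the `Quadratic` initializer.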

In [None]:
# p is the half-width of the structuring element: the kernels built from it
# have (2*p + 1) x (2*p + 1) entries.
p = 5
# tval and cval are handed to the Quadratic initializer as tvalue / cvalue.
tval = p * 2
cval = 255 / 2

In [None]:
# Dilation model
# (2*p + 1) x (2*p + 1) kernel whose weights come from the Quadratic
# initializer configured with tval and cval.
xin = Input(shape=(None, None, 1))
quadDilation = Dilation2D(
    1,
    kernel_size=(2 * p + 1, 2 * p + 1),
    kernel_initializer=Quadratic(tvalue=tval, cvalue=cval),
)
x = quadDilation(xin)
modelDilQuad = Model(inputs=xin, outputs=x)
modelDilQuad.summary()

In [None]:
# The last array returned by get_weights() is displayed as the structuring
# element: its shape is printed, then the [:, :, 0, 0] slice is drawn on a
# -255..0 gray scale with a colorbar, and finally its max and min are printed.
listW = modelDilQuad.get_weights()
SE = listW[-1]
print(SE.shape)
fig, ax = plt.subplots()
seImage = ax.imshow(SE[:, :, 0, 0], cmap="gray", vmax=0, vmin=-255)  # RdBu
fig.colorbar(seImage, ax=ax)
ax.axis("off")
print(SE.max(), SE.min())

In [None]:
# Erosion model
# (2*p + 1) x (2*p + 1) kernel whose weights come from the Quadratic
# initializer configured with tval and cval.
xin = Input(shape=(None, None, 1))
quadErosion = Erosion2D(
    1,
    kernel_size=(2 * p + 1, 2 * p + 1),
    kernel_initializer=Quadratic(tvalue=tval, cvalue=cval),
)
x = quadErosion(xin)
modelEroQuad = Model(inputs=xin, outputs=x)

In [None]:
# Opening model
# (2*p + 1) x (2*p + 1) kernel whose weights come from the Quadratic
# initializer configured with tval and cval.
xin = Input(shape=(None, None, 1))
quadOpening = Opening2D(
    1,
    kernel_size=(2 * p + 1, 2 * p + 1),
    kernel_initializer=Quadratic(tvalue=tval, cvalue=cval),
)
x = quadOpening(xin)
modelOpenQuad = Model(inputs=xin, outputs=x)

In [None]:
# Closing model
# (2*p + 1) x (2*p + 1) kernel whose weights come from the Quadratic
# initializer configured with tval and cval.
xin = Input(shape=(None, None, 1))
quadClosing = Closing2D(
    1,
    kernel_size=(2 * p + 1, 2 * p + 1),
    kernel_initializer=Quadratic(tvalue=tval, cvalue=cval),
)
x = quadClosing(xin)
modelClosQuad = Model(inputs=xin, outputs=x)

In [None]:
# Apply operators
# The cat image gets a leading batch axis of length 1 once, and that same
# array is fed to the four quadratic models in turn.
catBatch = imCat[np.newaxis, ...]
catDilQuad = modelDilQuad.predict(catBatch)
catEroQuad = modelEroQuad.predict(catBatch)
catOpenQuad = modelOpenQuad.predict(catBatch)
catClosQuad = modelClosQuad.predict(catBatch)

In [None]:
# Reference view of the unprocessed cat image on the 0..255 gray scale.
fig, ax = plt.subplots(figsize=(15, 7))
ax.imshow(imCat, cmap="gray", vmin=0, vmax=255)
ax.set_title("Original", fontsize=20)
ax.axis("off")
plt.show()

In [None]:
# 2 x 2 grid filled row by row: erosion, dilation, opening, closing. Each
# panel shows the first sample / first channel of the corresponding output.
quadPanels = (
    ("Erosion", catEroQuad),
    ("Dilation", catDilQuad),
    ("Opening", catOpenQuad),
    ("Closing", catClosQuad),
)
fig, axes = plt.subplots(2, 2, figsize=(30, 20))
for ax, (panelTitle, panelResult) in zip(axes.flat, quadPanels):
    ax.imshow(panelResult[0, :, :, 0], cmap="gray", vmin=0, vmax=255)
    ax.set_title(panelTitle, fontsize=20)
    ax.axis("off")
plt.show()