In [None]:
# Shell escape: clone the morpholayers repository into the current working
# directory. Presumably the `morpholayers` imports in the next cell resolve to
# this checkout — confirm against the execution environment.
!git clone https://github.com/Jacobiano/morpholayers.git

In [None]:
from imageio import imread
from matplotlib import pyplot as plt
import numpy as np
import numpy.matlib
import tensorflow.keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras import backend as K
from morpholayers.layers import Dilation2D, Erosion2D, Opening2D
from morpholayers.constraints import NonPositiveExtensive

## Learning a dilation structuring element by adjunction

Load an image.

In [None]:
# Read the "chelsea" sample picture through imageio, then keep only channel 0
# so that the rest of the notebook works on a single 2-D grey-level array.
imCol = imread("imageio:chelsea.png")
imCat = imCol[..., 0]

In [None]:
# Display the single-channel input image in grey levels on a fixed 0..255
# range, with the axes hidden.
plt.figure(figsize=(15, 7))
plt.imshow(imCat, cmap="gray", vmin=0, vmax=255)
plt.axis("off")
plt.show()

Function to create structuring elements.

In [None]:
def mkSE(p, se_type):
    """Build a square structuring element with ``2 * p + 1`` rows and columns.

    The array is first filled with -255; the entries selected by ``se_type``
    are then overwritten:

    - ``"cross"``: centre row and centre column are set to 0.
    - ``"half_cross"``: the centre row from the centre to the right edge and
      the centre column from the centre to the bottom edge are set to 0.
    - ``"x-shaped"``: main diagonal and anti-diagonal are set to 0.
    - ``"diag1"``: the anti-diagonal is set to 0.
    - ``"diag2"``: the main diagonal is set to 0.
    - ``"quad"``: every entry becomes the rounded value of
      ``-0.5 * 255 * d2 / (2 * p) ** 2`` where ``d2`` is the squared distance
      (in pixels) to the centre.

    Any other ``se_type`` returns the array still filled with -255.

    Args:
        p: Half-size of the structuring element.
        se_type: Name of the shape, one of the strings listed above.

    Returns:
        A float array of shape ``(2 * p + 1, 2 * p + 1)``.
    """
    side = 2 * p + 1
    se = np.full((side, side), -255.0)
    diag = np.arange(side)  # row indices used for both diagonals
    anti = 2 * p - diag  # matching column indices on the anti-diagonal

    if se_type == "cross":
        se[p, :] = 0
        se[:, p] = 0
    elif se_type == "half_cross":
        se[p, p:] = 0
        se[p:, p] = 0
    elif se_type == "x-shaped":
        se[diag, diag] = 0
        se[diag, anti] = 0
    elif se_type == "diag1":
        se[diag, anti] = 0
    elif se_type == "diag2":
        se[diag, diag] = 0
    elif se_type == "quad":
        lambd = 2 * p
        for row in range(side):
            for col in range(side):
                sq_dist = (row - p) ** 2 + (col - p) ** 2
                se[row, col] = int(np.round(-0.5 * 255 * sq_dist / lambd**2))
    return se

Let's first try a $3\times 3$ cross structuring element (that is, $p=1$ and hence $2p+1 = 3$).

In [None]:
# Half-size and shape of the structuring element used in the adjunction
# experiment; `p` and `SE` are reused by later cells.
p = 1
SE_id = "cross"  # other shapes handled by mkSE: 'diag2', 'diag1', 'x-shaped', 'half_cross', 'quad'
SE = mkSE(p, SE_id)
# Show the structuring element on the -255..0 range.
plt.figure()
plt.imshow(SE, cmap="gray", vmax=0, vmin=-255)
plt.colorbar()
plt.axis("off")
plt.show()

The following functions implement the learning setting presented in the course:
- reshape the $M\times N$ image as a matrix $\tilde{X}$ containing $M\cdot N$ rows and $(2p+1)^2$ columns (each row $i$ is the reshaped $(2p+1)^2$ neighbourhood of pixel $i$.)
- the $(2p+1)\times (2p+1)$ structuring element is reshaped as a $(2p+1)^2$ column $W$
- the dilation is the max-plus matrix product $\tilde{X}.W$
- the function DilMaxPlus is an additional layer that does all the reshaping to take as input an image and a square structuring element, and returns an image.

In [None]:
def block_reshape(im_in, p):
    """Unfold a 2-D image into one row per pixel neighbourhood.

    The image is padded by ``p`` pixels on every side (symmetric padding), and
    for each pixel ``(i, j)`` the surrounding ``(2p+1) x (2p+1)`` window is
    flattened column by column. Rows of the result follow a column-major
    ordering of the pixels: pixel ``(i, j)`` is stored at row ``m * j + i``,
    where ``m`` is the number of image rows.

    Args:
        im_in: 2-D image array.
        p: Half-size of the neighbourhood window.

    Returns:
        A float array of shape ``(m * n, (2 * p + 1) ** 2)``.
    """
    n_rows, n_cols = im_in.shape[0], im_in.shape[1]
    side = 2 * p + 1
    padded = np.pad(im_in, ((p, p), (p, p)), mode="symmetric")
    X = np.zeros((n_rows * n_cols, side**2))
    for row in range(n_rows):
        for col in range(n_cols):
            patch = padded[row : row + side, col : col + side]
            # Column-major flattening of the window, stored at the
            # column-major linear index of the pixel.
            X[n_rows * col + row, :] = patch.flatten(order="F")
    return X

In [None]:
def dil_max_plus(X, W):
    """Max-plus product of the rows of ``X`` with the column ``W``.

    For every row ``i`` the result is ``max_k (X[i, k] + W[k])``.

    The addition relies on NumPy broadcasting instead of tiling ``W`` with
    ``np.matlib.repmat``: the result is the same, but no ``X``-sized copy of
    ``W`` is allocated and the deprecated ``numpy.matlib`` module is not
    needed.

    Args:
        X: 2-D array of shape ``(n, k)``.
        W: Weights of shape ``(k, 1)`` (or ``(k,)``).

    Returns:
        1-D array of length ``n`` holding the row-wise maxima.
    """
    # (k, 1) becomes (1, k) and is broadcast against every row of X.
    return (X + np.transpose(W)).max(axis=1)

In [None]:
def DilMaxPlus(im_input, SE):
    """Dilate a 2-D image by a square structuring element (max-plus form).

    The structuring element is flattened column by column, the image is
    unfolded with ``block_reshape``, the two are combined by ``dil_max_plus``
    and the resulting column is folded back (column-major) into an image.

    Args:
        im_input: 2-D image array.
        SE: Square structuring element; its half-size is derived from
            ``SE.shape[0]``.

    Returns:
        Array with the same number of rows and columns as ``im_input``.
    """
    n_rows, n_cols = im_input.shape[0], im_input.shape[1]
    p = int((SE.shape[0] - 1) / 2)
    se_column = np.reshape(np.transpose(SE), ((2 * p + 1) ** 2, 1))
    neighbourhoods = block_reshape(im_input, p)
    dilated_column = dil_max_plus(neighbourhoods, se_column)
    return np.reshape(np.transpose(dilated_column), (n_rows, n_cols), order="F")

In [None]:
# Reference dilation of the input image by the chosen structuring element.
catDil = DilMaxPlus(im_input=imCat, SE=SE)

In [None]:
# Side-by-side view: input image (left) and its dilation (right), both on the
# same 0..255 grey scale.
plt.figure(figsize=(30, 12))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.imshow(imCat, cmap="gray", vmin=0, vmax=255)
plt.title("Original", fontsize=20)
plt.subplot(1, 2, 2)
plt.imshow(catDil, cmap="gray", vmin=0, vmax=255)
plt.title("Dilation by " + SE_id + " SE", fontsize=20)
plt.axis("off")
plt.show()

Reshape output image into a column vector of size $m\cdot n$:

In [None]:
# Stack the columns of the dilated image into one (m*n, 1) vector; the
# column-major order matches the row ordering produced by block_reshape.
m, n = catDil.shape
Y = catDil.reshape((m * n, 1), order="F")
print(Y.shape)

Define the erosion adjoint to the max-plus dilation:

In [None]:
def erod_min_plus(X, W):
    """Min-plus product of the rows of ``X`` with the column ``W``.

    For every row ``i`` the result is ``min_k (X[i, k] + W[k])``.

    The addition relies on NumPy broadcasting instead of tiling ``W`` with
    ``np.matlib.repmat``: the result is the same, but no ``X``-sized copy of
    ``W`` is allocated and the deprecated ``numpy.matlib`` module is not
    needed.

    Args:
        X: 2-D array of shape ``(n, k)``.
        W: Weights of shape ``(k, 1)`` (or ``(k,)``).

    Returns:
        1-D array of length ``n`` holding the row-wise minima.
    """
    # (k, 1) becomes (1, k) and is broadcast against every row of X.
    return (X + np.transpose(W)).min(axis=1)

Reshape input image into a matrix $X$ of size $(m\cdot n) \times (2p+1)^2$:

In [None]:
# Unfold the input image: one row per pixel, one column per neighbour.
X = block_reshape(im_in=imCat, p=p)
print(X.shape)

Define $X^* = -\tilde{X}^T$ and apply the adjoint erosion to recover the structuring element: $\hat{W} = \varepsilon_{X^*}(Y)$.

In [None]:
# Adjoint estimate of the structuring element: min-plus product of
# X* = -X^T with the dilated column Y, folded back (column-major) into a
# (2p+1) x (2p+1) array.
X_star = -X.T
What = erod_min_plus(X_star, Y).reshape((2 * p + 1, 2 * p + 1), order="F")
print(What)

Compare original and recovered structuring elements.

In [None]:
# Estimated (left) versus original (right) structuring element, both shown on
# the -255..0 range with a colour bar.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(What, cmap="gray", vmax=0, vmin=-255)
plt.title("Estimated SE", fontsize=20)
plt.colorbar()
plt.subplot(1, 2, 2)
plt.imshow(SE, cmap="gray", vmax=0, vmin=-255)
plt.title("Original SE", fontsize=20)
plt.colorbar()
plt.show()

Compute the dilation of input image with the recovered structuring element and compare with original dilation.

In [None]:
# Dilate the input again, this time with the recovered structuring element.
catDilEstSE = DilMaxPlus(im_input=imCat, SE=What)

In [None]:
# Dilation by the original SE (left) versus dilation by the estimated SE
# (right), on the same 0..255 grey scale.
plt.figure(figsize=(30, 12))
plt.subplot(1, 2, 1)
plt.axis("off")
plt.imshow(catDil, cmap="gray", vmin=0, vmax=255)
plt.title("Dilation by " + SE_id + " SE", fontsize=20)
plt.subplot(1, 2, 2)
plt.imshow(catDilEstSE, cmap="gray", vmin=0, vmax=255)
plt.title("Dilation by estimated SE", fontsize=20)
plt.axis("off")
plt.show()

Compute the mean squared error between original and estimated dilation:

In [None]:
# Mean squared error between the two dilations. The squared differences are
# averaged over all pixels (np.mean); np.sum would give the total squared
# error, which grows with the image size and is not an MSE.
errEst = np.mean((catDil - catDilEstSE) ** 2)
print(errEst)

## Learning a dilation structuring element by error minimization with gradient descent

Load Fashion MNIST images.

In [None]:
# Image size and number of classes of the dataset.
img_rows, img_cols = 28, 28
num_classes = 10

# Fashion-MNIST, already split between train and test sets.
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Per-image shape expected by the backend: the single channel axis goes first
# or last depending on the configured image data format.
if K.image_data_format() == "channels_first":
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)

# Add the channel axis, convert to float32 and divide by 255.
x_train = x_train.reshape((x_train.shape[0],) + input_shape).astype("float32") / 255
x_test = x_test.reshape((x_test.shape[0],) + input_shape).astype("float32") / 255
print("x_train shape:", x_train.shape)

Define a $7\times 7$ quadratic structuring element.

In [None]:
# 7x7 quadratic structuring element; mkSE output is divided by 255 so that it
# is on the same scale as the images, which were divided by 255 when loaded.
p = 3
SE_id = "quad"  # other shapes handled by mkSE: 'diag2', 'diag1', 'x-shaped', 'cross', 'half_cross'
SE = mkSE(p, SE_id) / 255
# Show the structuring element on the -1..0 range.
plt.figure()
plt.imshow(SE, cmap="gray", vmax=0, vmin=-1)
plt.colorbar()
plt.axis("off")
plt.show()

Show examples of dilation results.

In [None]:
# First figure: `nsamp` training images picked at random; second figure: the
# dilation of those same images by SE.
# NOTE(review): np.random.randint draws with replacement and no seed is set in
# the visible cells, so the displayed examples change from run to run.
print("Examples from the training set:")
nsamp = 10
plt.figure(figsize=(30, 5))
ridxs = []  # indices of the displayed images, reused for the second figure
for i in range(nsamp):
    plt.subplot(1, nsamp, i + 1)
    randidx = np.random.randint(x_train.shape[0])
    ridxs.append(randidx)
    plt.imshow(x_train[randidx, :, :, 0], vmin=0, vmax=1, cmap="gray")
    plt.axis("off")
plt.show()
plt.figure(figsize=(30, 5))
for i in range(nsamp):
    plt.subplot(1, nsamp, i + 1)
    imdil = DilMaxPlus(x_train[ridxs[i], :, :, 0], SE)
    plt.imshow(imdil, vmin=0, vmax=1, cmap="gray")
    plt.axis("off")
plt.show()

Split images into train and test sets.

In [None]:
# Draw small training / validation subsets from the Fashion-MNIST splits.
# A seeded Generator makes the experiment reproducible across kernel
# restarts, and replace=False guarantees that no image is picked twice
# (np.random.randint sampled with replacement from the unseeded global state).
ntrain = 150
ntest = 50
rng = np.random.default_rng(0)  # fixed seed, change it to draw other subsets
randidxs = rng.choice(x_train.shape[0], size=ntrain, replace=False)
Xtrain = x_train[randidxs, :, :, :]
randidxs = rng.choice(x_test.shape[0], size=ntest, replace=False)
Xtest = x_test[randidxs, :, :, :]

In [None]:
# Sanity check: shapes of the sampled training and test subsets.
print(Xtrain.shape)
print(Xtest.shape)

Create ground truth labelling (by dilating train and test images).

In [None]:
# Regression targets: each image of the subsets dilated by the reference SE.
# The target arrays have the same shape as the inputs.
Ytrain = np.zeros(Xtrain.shape)
for i in range(ntrain):
    Ytrain[i, ..., 0] = DilMaxPlus(Xtrain[i, ..., 0], SE)

Ytest = np.zeros(Xtest.shape)
for i in range(ntest):
    Ytest[i, ..., 0] = DilMaxPlus(Xtest[i, ..., 0], SE)

Define the dilation neural network containing one dilation layer.

In [None]:
# Network made of a single Dilation2D layer whose kernel has the same
# (2p+1) x (2p+1) size as the reference structuring element.
inputIm = Input(shape=input_shape)
xout = Dilation2D(
    1, strides=(1, 1), padding="same", kernel_size=(2 * p + 1, 2 * p + 1)
)(inputIm)
modelDilation = Model(inputs=inputIm, outputs=xout, name="dilationModel")
modelDilation.summary()

Set optimization parameters.

In [None]:
# Mean squared error as both loss and reported metric; Adam with its default
# settings as optimizer.
modelDilation.compile(
    optimizer=tensorflow.keras.optimizers.Adam(),
    loss=tensorflow.keras.losses.mse,
    metrics=["mse"],
)

In [None]:
# Stop training once the validation loss has not improved by at least 1e-5
# for 20 epochs, and restore the best weights seen so far.
# (CSV logging of the learning curves is left disabled.)
earlyStop = EarlyStopping(
    monitor="val_loss",
    mode="auto",
    min_delta=0.00001,
    patience=20,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

Start training!

In [None]:
# Train for at most 500 epochs; the early-stopping callback usually ends the
# run before that. The test subset is used as validation data.
modelDilation.fit(
    x=Xtrain,
    y=Ytrain,
    validation_data=(Xtest, Ytest),
    epochs=500,
    batch_size=10,
    callbacks=[earlyStop],
    verbose=1,
)

Monitor the weights of the dilation layer (that is, the learned structuring element).

In [None]:
# First weight array of modelDilation, i.e. the kernel of its Dilation2D layer
# (the learned structuring element).
W = modelDilation.get_weights()[0]
print(W.shape)
# 2-D slice of the kernel for the first (and only) input channel and filter.
print(W[:, :, 0, 0])

In [None]:
# Learned kernel (left) versus reference structuring element (right), both on
# the -1..0 range with a colour bar.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(W[:, :, 0, 0], cmap="gray", vmax=0, vmin=-1)
plt.title("Estimated SE", fontsize=20)
plt.colorbar()
plt.subplot(1, 2, 2)
plt.imshow(SE, cmap="gray", vmax=0, vmin=-1)
plt.title("Original SE", fontsize=20)
plt.colorbar()
plt.show()

## Learning an opening structuring element by error minimization with gradient descent

The functions ErodMinPlus and OpenMaxPlus are analogous to DilMaxPlus: they do the reshaping and apply the proper max/min-plus operators.

In [None]:
def ErodMinPlus(im_input, SE):
    """Erode a 2-D image by a square structuring element (min-plus form).

    The image is unfolded with ``block_reshape``, the negated structuring
    element is combined with it by ``erod_min_plus`` and the resulting column
    is folded back (column-major) into an image.

    Args:
        im_input: 2-D image array.
        SE: Square structuring element; its half-size is derived from
            ``SE.shape[0]``.

    Returns:
        Array with the same number of rows and columns as ``im_input``.
    """
    n_rows, n_cols = im_input.shape[0], im_input.shape[1]
    p = int((SE.shape[0] - 1) / 2)
    # NOTE(review): the SE is flattened row-major here, whereas block_reshape
    # flattens every window column-major (and DilMaxPlus transposes the SE
    # before flattening). Window entry [r, c] is therefore paired with
    # SE[c, r]; this only matters for structuring elements that differ from
    # their transpose — confirm which orientation is intended.
    se_column = np.reshape(SE, ((2 * p + 1) ** 2, 1))
    neighbourhoods = block_reshape(im_input, p)
    eroded_column = erod_min_plus(neighbourhoods, -se_column)
    return np.reshape(np.transpose(eroded_column), (n_rows, n_cols), order="F")

In [None]:
def OpenMaxPlus(im_input, SE):
    """Apply ``ErodMinPlus`` then ``DilMaxPlus`` with the same ``SE``.

    Args:
        im_input: 2-D image array.
        SE: Square structuring element passed to both operators.

    Returns:
        The dilation of the eroded image.
    """
    return DilMaxPlus(ErodMinPlus(im_input, SE), SE)

Show the chosen structuring element (try symmetric and non-symmetric SEs).

In [None]:
# 3x3 "half cross" structuring element, divided by 255 so that it is on the
# same scale as the images; `p` and `SE` are reused by the cells below.
p = 1
SE_id = "half_cross"  # other shapes handled by mkSE: 'diag2', 'diag1', 'x-shaped', 'cross', 'quad'
SE = mkSE(p, SE_id) / 255
# Show the structuring element on the -1..0 range.
plt.figure()
plt.imshow(SE, cmap="gray", vmax=0, vmin=-1)
plt.colorbar()
plt.axis("off")
plt.show()

Show examples of opening.

In [None]:
# First figure: `nsamp` training images picked at random; second figure: the
# result of OpenMaxPlus on those same images.
# NOTE(review): np.random.randint draws with replacement and no seed is set in
# the visible cells, so the displayed examples change from run to run.
print("Examples from the training set:")
nsamp = 10
plt.figure(figsize=(30, 5))
ridxs = []  # indices of the displayed images, reused for the second figure
for i in range(nsamp):
    plt.subplot(1, nsamp, i + 1)
    randidx = np.random.randint(x_train.shape[0])
    ridxs.append(randidx)
    plt.imshow(x_train[randidx, :, :, 0], vmin=0, vmax=1, cmap="gray")
    plt.axis("off")
plt.show()
plt.figure(figsize=(30, 5))
for i in range(nsamp):
    plt.subplot(1, nsamp, i + 1)
    imopen = OpenMaxPlus(x_train[ridxs[i], :, :, 0], SE)
    plt.imshow(imopen, vmin=0, vmax=1, cmap="gray")
    plt.axis("off")
plt.show()

Split into train and test sets.

In [None]:
# Draw the training / validation subsets for the opening experiment.
# A seeded Generator makes the experiment reproducible across kernel
# restarts, and replace=False guarantees that no image is picked twice
# (np.random.randint sampled with replacement from the unseeded global state).
ntrain = 500  # 150
ntest = 100  # 50
rng = np.random.default_rng(0)  # fixed seed, change it to draw other subsets
randidxs = rng.choice(x_train.shape[0], size=ntrain, replace=False)
Xtrain = x_train[randidxs, :, :, :]
randidxs = rng.choice(x_test.shape[0], size=ntest, replace=False)
Xtest = x_test[randidxs, :, :, :]

Ground truth labelling.

In [None]:
# Regression targets: OpenMaxPlus applied to each image of the subsets with
# the reference SE. The target arrays have the same shape as the inputs.
Ytrain = np.zeros(Xtrain.shape)
for i in range(ntrain):
    Ytrain[i, ..., 0] = OpenMaxPlus(Xtrain[i, ..., 0], SE)

Ytest = np.zeros(Xtest.shape)
for i in range(ntest):
    Ytest[i, ..., 0] = OpenMaxPlus(Xtest[i, ..., 0], SE)

Define a "one layer opening" architecture.

In [None]:
# Network made of a single Opening2D layer with one (2p+1) x (2p+1) kernel,
# constrained by NonPositiveExtensive.
inputIm = Input(shape=input_shape)
xout = Opening2D(
    1,
    strides=(1, 1),
    padding="same",
    kernel_size=(2 * p + 1, 2 * p + 1),
    kernel_constraint=NonPositiveExtensive(),
)(inputIm)
modelOpen = Model(inputs=inputIm, outputs=xout, name="openingWrapped")
modelOpen.summary()

In [None]:
# Mean squared error as both loss and reported metric; Adam with a learning
# rate of 1e-4.
modelOpen.compile(
    optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tensorflow.keras.losses.mse,
    metrics=["mse"],
)

In [None]:
# Stop training once the validation loss has not improved by at least 1e-5
# for 200 epochs, and restore the best weights seen so far.
earlyStop = EarlyStopping(
    monitor="val_loss",
    mode="auto",
    min_delta=0.00001,
    patience=200,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

Start training!

In [None]:
# Train for at most 1500 epochs, with the test subset as validation data and
# the early-stopping callback defined above.
modelOpen.fit(
    x=Xtrain,
    y=Ytrain,
    validation_data=(Xtest, Ytest),
    epochs=1500,
    batch_size=10,
    callbacks=[earlyStop],
    verbose=1,
)

Monitor the learned weights (the structuring element).

In [None]:
# First weight array of modelOpen, i.e. the kernel of its Opening2D layer
# (the learned structuring element).
W = modelOpen.get_weights()[0]
print(W.shape)
# 2-D slice of the kernel for the first (and only) input channel and filter.
print(W[:, :, 0, 0])

In [None]:
# Learned kernel (left) versus reference structuring element (right), both on
# the -1..0 range with a colour bar.
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(W[:, :, 0, 0], cmap="gray", vmax=0, vmin=-1)
plt.title("Estimated SE", fontsize=20)
plt.colorbar()
plt.subplot(1, 2, 2)
plt.imshow(SE, cmap="gray", vmax=0, vmin=-1)
plt.title("Original SE", fontsize=20)
plt.colorbar()
plt.show()

Now define a "two layers" opening architecture.

In [None]:
# Opening written as two separate layers: an Erosion2D layer followed by a
# Dilation2D layer. Each layer has its own (2p+1) x (2p+1) kernel constrained
# by NonPositiveExtensive, and a name so its weights can be fetched later.
inputIm = Input(shape=input_shape)
xero = Erosion2D(
    1,
    name="myErosion",
    strides=(1, 1),
    padding="same",
    kernel_size=(2 * p + 1, 2 * p + 1),
    kernel_constraint=NonPositiveExtensive(),
)(inputIm)
xout = Dilation2D(
    1,
    name="myDilation",
    strides=(1, 1),
    padding="same",
    kernel_size=(2 * p + 1, 2 * p + 1),
    kernel_constraint=NonPositiveExtensive(),
)(xero)
modelOpen2 = Model(inputs=inputIm, outputs=xout, name="OpeningSequential")
modelOpen2.summary()

In [None]:
# Mean squared error as both loss and reported metric; Adam with a learning
# rate of 1e-4.
modelOpen2.compile(
    optimizer=tensorflow.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tensorflow.keras.losses.mse,
    metrics=["mse"],
)

In [None]:
# Stop training once the validation loss has not improved by at least 1e-5
# for 100 epochs, and restore the best weights seen so far.
earlyStop = EarlyStopping(
    monitor="val_loss",
    mode="auto",
    min_delta=0.00001,
    patience=100,
    baseline=None,
    restore_best_weights=True,
    verbose=0,
)

In [None]:
# Train for at most 1500 epochs, with the test subset as validation data and
# the early-stopping callback defined above.
modelOpen2.fit(
    x=Xtrain,
    y=Ytrain,
    validation_data=(Xtest, Ytest),
    epochs=1500,
    batch_size=10,
    callbacks=[earlyStop],
    verbose=1,
)

Monitor the weights of each layer.

In [None]:
# Look the two layers up by the names given at construction time and read the
# first weight array (the kernel) of each.
L1 = modelOpen2.get_layer("myErosion")
L2 = modelOpen2.get_layer("myDilation")
W1 = L1.get_weights()[0]
W2 = L2.get_weights()[0]
print(W1.shape)
print(W2.shape)

In [None]:
# Three panels on the -1..0 range: kernel learned by the erosion layer, kernel
# learned by the dilation layer, and the reference structuring element.
plt.figure(figsize=(12, 6))
plt.subplot(1, 3, 1)
plt.imshow(W1[:, :, 0, 0], cmap="gray", vmax=0, vmin=-1)
plt.title("Estimated SE (erosion layer)", fontsize=15)
plt.colorbar()
plt.subplot(1, 3, 2)
plt.imshow(W2[:, :, 0, 0], cmap="gray", vmax=0, vmin=-1)
plt.title("Estimated SE (dilation layer)", fontsize=15)
plt.colorbar()
plt.subplot(1, 3, 3)
plt.imshow(SE, cmap="gray", vmax=0, vmin=-1)
plt.title("Original SE", fontsize=15)
plt.colorbar()
plt.show()