# **Imports**

Below are all the imports used in the **Notebook**.

In [None]:
# Common
import os
from glob import glob

import keras
import tensorflow as tf
from numpy import zeros, random
from tqdm import tqdm

# Data
from tensorflow.image import resize
from keras.preprocessing.image import load_img, img_to_array

# Data viz
import matplotlib.pyplot as plt

# Model
from keras.models import Model, Sequential, load_model
from keras.layers import Conv2D, Conv2DTranspose, concatenate, MaxPool2D, Dropout, BatchNormalization, Layer, Input, add, multiply, UpSampling2D

# Model Viz
from tensorflow.keras.utils import plot_model

# Callback
from keras.callbacks import Callback


# **Data**

The **foremost** thing that we need to accomplish is to **load the data**.

In [None]:
def load_image(path, size=(256,256), method='bilinear'):
    """Load an image from disk, scale its pixels to [0, 1] and resize it.

    Args:
        path: path of the image file, read with Keras' ``load_img``.
        size: ``(height, width)`` to resize to. Defaults to ``(256, 256)``,
            the resolution the loading cell pre-allocates its arrays for.
        method: interpolation passed to ``tensorflow.image.resize``. The
            default ``'bilinear'`` is TensorFlow's own default; ``'nearest'``
            avoids blending label values when resizing segmentation masks.

    Returns:
        The resized image as produced by ``tensorflow.image.resize``
        (a tensor of shape ``(height, width, channels)``).
    """
    # Scale to [0, 1] before resizing so interpolation works on float values.
    pixels = img_to_array(load_img(path)) / 255.
    return resize(pixels, size, method=method)

This function takes in the **path of an image**, loads it using **Keras functions**, scales the pixel values to $[0, 1]$ and resizes it to $256 \times 256$.

In [None]:
root_path = '../input/butterfly-dataset/leedsbutterfly/images/'
image_paths = sorted(glob(os.path.join(root_path, "*.png")))


def mask_path_for(image_path):
    """Return the segmentation-mask path that belongs to an image path.

    ``<root>/images/<stem>.png`` -> ``<root>/segmentations/<stem>_seg0.png``.

    Only the image's parent directory and its file extension are rewritten.
    A plain ``str.replace`` would also rewrite any other "images" or ".png"
    substring that appears elsewhere in the path.
    """
    images_dir, file_name = os.path.split(image_path)
    stem, _ = os.path.splitext(file_name)
    dataset_dir = os.path.dirname(images_dir)
    return os.path.join(dataset_dir, 'segmentations', stem + '_seg0.png')


mask_paths = [mask_path_for(path) for path in image_paths]
print(f"Total Number of Images  : {len(image_paths)}")

This way, by **replacing the text** from the **path** we can easily get the **exact segmentation mask** for a **particular image**.

In [None]:
# Pre-allocate the arrays once and fill them in place. float32 halves the
# memory of NumPy's default float64, and the whole dataset is held in RAM.
n_samples = len(image_paths)
images = zeros(shape=(n_samples, 256, 256, 3), dtype='float32')
masks = zeros(shape=(n_samples, 256, 256, 3), dtype='float32')
pairs = enumerate(zip(image_paths, mask_paths))
# enumerate/zip have no len(), so pass `total` for tqdm to show a real progress bar.
for n, (img_path, mask_path) in tqdm(pairs, total=n_samples, desc="Loading"):
    images[n] = load_image(img_path)
    masks[n] = load_image(mask_path)

Now, our **images and masks** are loaded. It's time to **visualize them**, so that we can gain some insights about the data.

# **Data Visualization**

In [None]:
def show_image(image, title=None, alpha=1):
    """Draw ``image`` into the currently active subplot.

    Args:
        image: array-like accepted by ``plt.imshow``.
        title: optional subplot title; ``None`` leaves it empty.
        alpha: opacity of the drawn image, useful for overlaying a mask.
    """
    plt.imshow(X=image, alpha=alpha)
    plt.title(label=title)
    # Hide ticks and frame so grids of images stay uncluttered.
    plt.axis('off')

The **function below** will **plot the masks** for us in **several different layouts**, and we can also use it inside a **callback**.

In [None]:
def show_mask(GRID, fig_size=(8,20), model=None, join=False, alpha=0.5):
    """Plot randomly chosen samples from the global ``images``/``masks`` arrays on a grid.

    Consecutive grid cells are filled according to the mode:

    * ``model is None, join=False``: image, true mask, image, true mask, ...
    * ``model is None, join=True``: one image with its true mask overlaid per cell.
    * ``model`` given, ``join=True``: image + true-mask overlay, then
      image + predicted-mask overlay, repeating.
    * ``model`` given, ``join=False``: image, true mask, predicted mask, repeating.

    A sample that does not fit in the remaining cells is drawn only partially.
    For example, a ``[1, 1]`` grid with a model shows only the image.

    Args:
        GRID: ``(n_rows, n_cols)`` of the subplot grid.
        fig_size: figure size passed to ``plt.figure``.
        model: optional Keras model used to predict a mask for each sampled image.
        join: overlay masks on their image instead of drawing them separately.
        alpha: opacity of an overlaid mask (only used when ``join`` is True).
    """
    n_rows, n_cols = GRID
    n_images = n_rows * n_cols

    # Number of consecutive cells that each randomly sampled image occupies.
    if model is None:
        cells_per_sample = 1 if join else 2
    else:
        cells_per_sample = 2 if join else 3

    plt.figure(figsize=fig_size)
    for i in range(n_images):
        slot = i % cells_per_sample
        if slot == 0:
            # First cell of a new sample: pick a random image/mask pair.
            idx = random.randint(len(images))
            image, mask = images[idx], masks[idx]
            if model is not None:
                # verbose=0 keeps Keras from printing a progress bar per call,
                # which floods the output when this runs inside a callback.
                pred_mask = model.predict(tf.expand_dims(image, axis=0), verbose=0)[0]

        plt.subplot(n_rows, n_cols, i + 1)
        if model is None:
            if join:
                show_image(image)
                show_image(mask, alpha=alpha)
            elif slot == 0:
                show_image(image)
            else:
                show_image(mask)
        elif join:
            show_image(image)
            if slot == 0:
                show_image(mask, alpha=alpha, title="Original Mask")
            else:
                show_image(pred_mask, alpha=alpha, title="Predicted Mask")
        elif slot == 0:
            show_image(image, title="Original Image")
        elif slot == 1:
            show_image(mask, title="Original Mask")
        else:
            show_image(pred_mask, title="Predicted Mask")
    plt.show()

In [None]:
# 5 x 4 grid of alternating image / ground-truth mask cells.
GRID = [5, 4]
show_mask(GRID, (15, 20))

This can be a **tough task** for the model because the **image background** contains **a lot of objects**. Thus, using an **Attention UNet** would be a good idea.

In [None]:
# Same grid size, but each cell overlays a random image's mask on the image itself.
GRID = [5, 4]
show_mask(GRID, join=True, fig_size=(15, 20))

Plotting the mask over the image gives us a **better visualization**.

# **Attention UNet - Encoder**

* The **main task** of the **Encoder** is to **downsample the images** by a **factor of 2**, but at the same time **learn the features** present in the image. 

* The idea behind encoder is that it will gradually learn all the **useful features** and preserve them in a **latent representation**, which is present in the **last encoding layer**.

* A **small amount of dropout** is also added between the **convolutional layers** in the encoder so that each **layer is forced to learn the most useful features**.

In [None]:
class Encoder(Layer):
    """Two 3x3 ReLU convolutions with dropout in between, optionally followed by max-pooling.

    Args:
        filters: number of output channels of both convolutions.
        rate: dropout rate applied between the two convolutions.
        pooling: when True, ``call`` returns ``(pooled, features)`` so the
            un-pooled features can feed a skip connection; otherwise it
            returns only ``features``.
    """

    def __init__(self, filters, rate, pooling=True, **kwargs):
        super().__init__(**kwargs)

        # Hyper-parameters, kept so get_config can rebuild the layer.
        self.filters = filters
        self.rate = rate
        self.pooling = pooling

        conv_kwargs = dict(kernel_size=3, strides=1, padding='same',
                           kernel_initializer='he_normal', activation='relu')
        # Attribute names and creation order are part of the saved-weights layout.
        self.c1 = Conv2D(filters, **conv_kwargs)
        self.drop = Dropout(rate)
        self.c2 = Conv2D(filters, **conv_kwargs)
        self.pool = MaxPool2D()

    def call(self, X):
        features = self.c1(X)
        features = self.drop(features)
        features = self.c2(features)
        if not self.pooling:
            return features
        return self.pool(features), features

    def get_config(self):
        config = dict(super().get_config())
        config.update(filters=self.filters, rate=self.rate, pooling=self.pooling)
        return config

# **Attention UNet - Decoder**

* The **decoder** is just the **opposite** of the **encoder** in terms of **functioning** because it **Upsamples** the **input images** or the **input feature maps** by a **factor of 2**.

* The inputs to **the decoder** are the **latent representations** learned by the encoder. This means the **decoder** has access only to the **most useful features**, and it **tries to reconstruct the segmentation mask** from these features.

* One **major reason** behind the **success of the UNet architecture** is the **skip connections** from the **encoder to the decoder layers**. They allow the **decoder to recover** the **spatial information** present in the original image.

In [None]:
class Decoder(Layer):
    """Upsampling block: a stride-2 transposed convolution, concatenation with a
    skip tensor, then an ``Encoder`` block without pooling.

    Args:
        filters: channels of the transposed convolution and of the inner convolutions.
        rate: dropout rate forwarded to the inner ``Encoder``.
    """

    def __init__(self, filters, rate, **kwargs):
        super().__init__(**kwargs)

        # Stored for get_config / deserialisation.
        self.filters = filters
        self.rate = rate

        # Attribute names and creation order are part of the saved-weights layout.
        self.cT = Conv2DTranspose(
            filters,
            kernel_size=3,
            strides=2,
            padding='same',
            kernel_initializer='he_normal',
            activation='relu',
        )
        self.net = Encoder(filters, rate, pooling=False)

    def call(self, X):
        # X is a pair: (features to upsample, skip-connection features).
        low_res, skip_x = X
        upsampled = self.cT(low_res)
        return self.net(concatenate([upsampled, skip_x]))

    def get_config(self):
        config = dict(super().get_config())
        config.update(filters=self.filters, rate=self.rate)
        return config

# **Attention UNet - Attention Gate**

The **idea behind the attention gate** is to add a **particular gate or a layer** between the **skip connections** so that the **skip connections can be refined** and only the **most important spatial information is fed to the decoder**.

In [None]:
class AttentionGate(Layer):
    """Re-weights a skip connection with a learned one-channel sigmoid map.

    ``call`` expects ``[x, skip_x]``. ``x`` passes through a stride-1 conv and
    ``skip_x`` through a stride-2 conv. The two results are added, so this
    assumes ``x`` has half the spatial size of ``skip_x``; verify at call sites.
    A 1x1 sigmoid conv then produces the attention map, which is upsampled and
    multiplied with the original ``skip_x``.

    Args:
        filters: channels of the two 3x3 convolutions.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)

        self.filters = filters

        conv_kwargs = dict(kernel_size=3, padding='same',
                           kernel_initializer='he_normal', activation='relu')
        # Attribute names and creation order are part of the saved-weights layout.
        self.normal = Conv2D(filters, strides=1, **conv_kwargs)
        self.down = Conv2D(filters, strides=2, **conv_kwargs)

        self.learn = Conv2D(1, kernel_size=1, strides=1, activation='sigmoid')
        self.resample = UpSampling2D()

    def call(self, X):
        gate, skip_x = X

        combined = add([self.normal(gate), self.down(skip_x)])
        attention = self.resample(self.learn(combined))

        return multiply([attention, skip_x])

    def get_config(self):
        config = dict(super().get_config())
        config.update(filters=self.filters)
        return config

# **Attention UNet**

So the **Encoder, Decoder and Attention Gate** are ready. It's time to combine all of them to complete our **Attention UNet** architecture.

In [None]:
# Input: 3-channel images at the 256x256 resolution used when loading the data.
image_input = Input(shape=(256, 256, 3), name="InputImage")

# Encoder: each pooling block returns (pooled output, pre-pooling features for the skip).
p1, c1 = Encoder(filters=32, rate=0.1, name="EncoderBlock1")(image_input)
p2, c2 = Encoder(filters=64, rate=0.1, name="EncoderBlock2")(p1)
p3, c3 = Encoder(filters=128, rate=0.2, name="EncoderBlock3")(p2)
p4, c4 = Encoder(filters=256, rate=0.2, name="EncoderBlock4")(p3)

# Bottleneck / latent representation (no pooling).
encoding = Encoder(filters=512, rate=0.3, pooling=False, name="Encoding")(p4)

# Decoder: gate each skip connection with the current decoder features, then upsample + merge.
a1 = AttentionGate(filters=256, name="Attention1")([encoding, c4])
d1 = Decoder(filters=256, rate=0.2, name="DecoderBlock1")([encoding, a1])

a2 = AttentionGate(filters=128, name="Attention2")([d1, c3])
d2 = Decoder(filters=128, rate=0.2, name="DecoderBlock2")([d1, a2])

a3 = AttentionGate(filters=64, name="Attention3")([d2, c2])
d3 = Decoder(filters=64, rate=0.2, name="DecoderBlock3")([d2, a3])

a4 = AttentionGate(filters=32, name="Attention4")([d3, c1])
d4 = Decoder(filters=32, rate=0.1, name="DecoderBlock4")([d3, a4])

# 1x1 sigmoid head producing a 3-channel mask, matching the shape of the loaded masks.
mask_out = Conv2D(
    filters=3, kernel_size=1, strides=1, padding='same', activation='sigmoid', name="MaskOut"
)(d4)

att_unet = Model(inputs=[image_input], outputs=[mask_out], name="AttentionUNet")
att_unet.compile(optimizer='adam', loss='binary_crossentropy')

# **Attention UNet - Visualization**

In [None]:
# Render the architecture (with tensor shapes) to an image file and display it.
plot_model(att_unet, to_file="AttentionUNet.png", show_shapes=True)

# **Custom Callback**

It is a **good idea to visualize the model's performance after each epoch.**

In [None]:
class ShowProgress(Callback):
    """Keras callback that plots one prediction and saves the model after every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # With a model and join=False, show_mask fills three consecutive cells per
        # sample (image, true mask, predicted mask). A 1x1 grid would only ever
        # draw the input image and never the prediction, so use a 1x3 grid.
        show_mask(GRID=[1, 3], model=self.model, join=False, fig_size=(20, 8))
        # The same file is overwritten each epoch, so it always holds the latest model.
        self.model.save("AttentionUnet.h5")

# **Training**

**Training Attention UNet** is simple. Just train it like we train other models.

In [None]:
# Training is disabled here: the Evaluation section loads a previously trained
# checkpoint instead. Uncomment to retrain; ShowProgress overwrites
# "AttentionUnet.h5" in the working directory after each epoch.
# att_unet.fit(
#     images, masks,
#     validation_split=0.1,
#     epochs=20, 
#     callbacks=[ShowProgress()]
# )

# **Evaluation**

In [None]:
# The saved model contains custom layers, so Keras must be told which class each
# serialized layer name maps to.
att_unet = load_model(
    '../input/attention-unet-butterfly-segmentation/AttentionUnet.h5',
    custom_objects=dict(Encoder=Encoder, Decoder=Decoder, AttentionGate=AttentionGate),
)

In [None]:
# 10 x 6 grid: image, true mask, predicted mask for 20 random samples.
show_mask([10, 6], fig_size=(20, 30), model=att_unet)

In [None]:
# Same grid, but masks are overlaid on their image (true, then predicted).
show_mask([10, 6], fig_size=(20, 30), model=att_unet, alpha=0.8, join=True)

$Observations :$

* During the **initial stage** of **training**, the **model's prediction** will be **very bad**. This is because the **attention model** is still trying to capture or learn the **important spatial information**.

* The **generated masks** for **white butterflies** are **not good**. One possible reason is that **the attention part of the model** has learned to suppress the **white flowers** present in the background, and it may be treating white butterflies as flowers too. This is only a hypothesis and could be wrong.

* I also tried to **add more decoder and encoder** layers, but then it was going out of **RAM**.

* The generated Mask will always be a little blurry. It is **always a good idea to add a post processing function.**

* Overall, the **model's predictions are very satisfying**.

**Thanks!!**

---
**DeepNets**