In [None]:
from IPython.display import clear_output
!pip install tf_explain
clear_output()

In [None]:
# common
import os
import time
import keras
import numpy as np
import pandas as pd
from glob import glob
import tensorflow as tf
import tensorflow.image as tfi

# Data
from keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.utils import to_categorical

# Data Viz
import matplotlib.pyplot as plt

# Model 
from keras.models import Model
from keras.layers import Layer
from keras.layers import Conv2D
from keras.layers import Dropout
from keras.layers import UpSampling2D
from keras.layers import concatenate
from keras.layers import Add
from keras.layers import Multiply
from keras.layers import Input
from keras.layers import MaxPooling2D
from keras.layers import Activation
from keras.layers import BatchNormalization

# Callbacks 
from keras.callbacks import Callback
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
from tf_explain.core.grad_cam import GradCAM

# Metrics
from keras.metrics import MeanIoU

In [None]:
def load_image(image, SIZE):
    """
    Read one image file, scale it by 1/255, resize it to SIZE x SIZE and
    round the result to 4 decimals.

    Args:
        image: path of the image file, handed to ``load_img``.
        SIZE: target height and width passed to ``tfi.resize``.

    Returns:
        The array produced by ``np.round`` on the resized image.
    """
    scaled = img_to_array(load_img(image)) / 255.
    resized = tfi.resize(scaled, (SIZE, SIZE))
    return np.round(resized, 4)

def load_images(image_paths, SIZE, mask=False, trim=None):
    """
    Load several image files into one pre-allocated array.

    Args:
        image_paths: sequence of file paths, each loaded with ``load_image``.
        SIZE: height and width every image is resized to.
        mask: when True only the first channel of each file is kept and the
            result has 1 channel; otherwise the result has 3 channels.
        trim: if not None, only the first ``trim`` paths are loaded.

    Returns:
        Array of shape (number of loaded paths, SIZE, SIZE, 1 or 3).
    """
    selected = image_paths if trim is None else image_paths[:trim]
    channels = 1 if mask else 3
    batch = np.zeros(shape=(len(selected), SIZE, SIZE, channels))

    for idx, path in enumerate(selected):
        pixels = load_image(path, SIZE)
        # Masks keep only their first channel; images are stored as loaded.
        batch[idx] = pixels[:, :, :1] if mask else pixels

    return batch

In [None]:
def show_image(image, title=None, cmap=None, alpha=1):
    """
    Draw ``image`` on the current matplotlib axes with the axis hidden.

    Args:
        image: array-like accepted by ``plt.imshow``.
        title: optional title; nothing is set when it is None.
        cmap: colormap forwarded to ``plt.imshow``.
        alpha: opacity forwarded to ``plt.imshow``.
    """
    plt.imshow(image, cmap=cmap, alpha=alpha)
    plt.axis('off')
    if title is None:
        return
    plt.title(title)

def show_mask(image, mask, cmap=None, alpha=0.4):
    """
    Draw ``image`` and overlay ``mask`` on top of it, with the axis hidden.

    Args:
        image: array-like accepted by ``plt.imshow``; drawn first.
        mask: mask to overlay; size-1 dimensions are removed with ``tf.squeeze``.
        cmap: colormap used for the overlay only.
        alpha: opacity of the overlay.
    """
    overlay = tf.squeeze(mask)
    plt.imshow(image)
    plt.imshow(overlay, cmap=cmap, alpha=alpha)
    plt.axis('off')

In [None]:
# Side length in pixels: load_image resizes every image and mask to (SIZE, SIZE).
SIZE = 256

In [None]:
# Dataset root. The trailing slash matters: later cells build paths by plain
# string concatenation (root_path + name + ...).
root_path = '../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/'

# Every entry found directly under the root, in sorted order.
classes = os.listdir(root_path)
classes.sort()
classes

In [None]:
# For every class folder collect its sorted mask files: names ending in
# "mask.png" and names ending in "mask_1.png". Each result is a list of
# per-class lists, itself sorted.
single_mask_paths = []
double_mask_paths = []
for name in classes:
    class_dir = root_path + name
    single_mask_paths.append(sorted(glob(class_dir + "/*mask.png")))
    double_mask_paths.append(sorted(glob(class_dir + "/*mask_1.png")))
single_mask_paths.sort()
double_mask_paths.sort()

In [None]:
# Flatten the per-class mask lists into one list, then derive each image path
# from its mask path by removing '_mask', so both lists stay index-aligned.
mask_paths = [path for class_path in single_mask_paths for path in class_path]
image_paths = [path.replace('_mask','') for path in mask_paths]

In [None]:
show_image(load_image(image_paths[0], SIZE))

In [None]:
show_image(load_image('../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/benign/benign (100).png', SIZE))

In [None]:
show_image(load_image('../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/benign/benign (100)_mask_1.png', SIZE))

In [None]:
show_image(load_image('../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/benign/benign (100)_mask.png', SIZE))

In [None]:
img = np.zeros((1,SIZE,SIZE,3))
mask1 = load_image('../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/benign/benign (100)_mask_1.png', SIZE)
mask2 = load_image('../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/benign/benign (100)_mask.png', SIZE)

img = img + mask1 + mask2
img = img[0,:,:,0]
show_image(img, cmap='gray')

In [None]:
show_image(load_image('../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/benign/benign (100).png', SIZE))
plt.imshow(img, cmap='binary', alpha=0.4)
plt.axis('off')
plt.show()

In [None]:
show_image(load_image('../input/breast-ultrasound-images-dataset/Dataset_BUSI_with_GT/benign/benign (100).png', SIZE))
plt.imshow(img, cmap='gray', alpha=0.4)
plt.axis('off')
plt.show()

In [None]:
# Load the whole dataset into memory. mask_paths and image_paths are
# index-aligned, so masks[k] belongs to images[k].
masks = load_images(mask_paths, SIZE, mask=True)
images = load_images(image_paths, SIZE)

In [None]:
plt.figure(figsize=(13,8))
for i in range(15):
    plt.subplot(3,5,i+1)
    id = np.random.randint(len(images))
    show_mask(images[id], masks[id], cmap='jet')
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(13,8))
for i in range(15):
    plt.subplot(3,5,i+1)
    id = np.random.randint(len(images))
    show_mask(images[id], masks[id], cmap='binary')
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(13,8))
for i in range(15):
    plt.subplot(3,5,i+1)
    id = np.random.randint(len(images))
    show_mask(images[id], masks[id], cmap='afmhot')
plt.tight_layout()
plt.show()

In [None]:
plt.figure(figsize=(13,8))
for i in range(15):
    plt.subplot(3,5,i+1)
    id = np.random.randint(len(images))
    show_mask(images[id], masks[id], cmap='copper')
plt.tight_layout()
plt.show()

# Model

In [None]:

def attention_block_2d(input, input_channels=None, output_channels=None, encoder_depth=1, name='at'):
    """
    Residual attention block: a trunk branch gated by a soft-mask branch.
    https://arxiv.org/abs/1704.06904

    The soft-mask branch down-samples ``encoder_depth`` times, up-samples back
    and ends in a sigmoid; the block outputs ``(1 + mask) * trunk`` followed by
    a final residual block.

    Args:
        input: Keras tensor; its channel count is read from the last axis.
        input_channels: filters of the two 1x1 convolutions that produce the
            mask. Defaults to the channel count of ``input``.
        output_channels: accepted for interface compatibility; this
            implementation does not use it.
        encoder_depth: number of pooling steps in the soft-mask branch.
            NOTE(review): pooling uses padding='same' and up-sampling doubles
            the size, so spatial dims presumably need to be divisible by
            2 ** encoder_depth for the branches to line up - confirm.
        name: name given to the final Add layer of the last residual block.

    Returns:
        The attended feature tensor.
    """
    p = 1  # residual blocks before the branch split and after the merge
    t = 2  # residual blocks in the trunk branch
    r = 1  # residual blocks next to each pooling / up-sampling step of the mask branch

    if input_channels is None:
        # TF2 / Keras give a plain int here; TF1 gives a Dimension exposing `.value`.
        last_dim = input.shape[-1]
        input_channels = getattr(last_dim, 'value', last_dim)

    # First Residual Block
    for _ in range(p):
        input = residual_block_2d(input)

    # Trunk Branch
    output_trunk = input
    for _ in range(t):
        output_trunk = residual_block_2d(output_trunk)

    # Soft Mask Branch

    ## encoder: first down sampling
    output_soft_mask = MaxPooling2D(padding='same')(input)
    for _ in range(r):
        output_soft_mask = residual_block_2d(output_soft_mask)

    skip_connections = []
    for _ in range(encoder_depth - 1):
        ## skip connection taken before the next down sampling
        skip_connections.append(residual_block_2d(output_soft_mask))

        ## down sampling
        output_soft_mask = MaxPooling2D(padding='same')(output_soft_mask)
        for _ in range(r):
            output_soft_mask = residual_block_2d(output_soft_mask)

    ## decoder: consume the skips deepest-first
    skip_connections.reverse()
    for skip in skip_connections:
        for _ in range(r):
            output_soft_mask = residual_block_2d(output_soft_mask)
        output_soft_mask = UpSampling2D()(output_soft_mask)
        # `Add` layer instead of the functional `add`, which this file never imports.
        output_soft_mask = Add()([output_soft_mask, skip])

    ## last upsampling
    for _ in range(r):
        output_soft_mask = residual_block_2d(output_soft_mask)
    output_soft_mask = UpSampling2D()(output_soft_mask)

    ## Output
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Conv2D(input_channels, (1, 1))(output_soft_mask)
    output_soft_mask = Activation('sigmoid')(output_soft_mask)

    # Attention: (1 + mask) * trunk, written as trunk + mask * trunk so that no
    # Lambda layer is needed (Lambda is not imported in this file).
    attended = Multiply()([output_soft_mask, output_trunk])
    output = Add()([output_trunk, attended])

    # Last Residual Block
    for _ in range(p):
        output = residual_block_2d(output, name=name)

    return output


def residual_block_2d(input, input_channels=None, output_channels=None, kernel_size=(3, 3), stride=1, name='out'):
    """
    Full pre-activation bottleneck residual block (BN -> ReLU -> Conv, three times).
    https://arxiv.org/pdf/1603.05027.pdf

    Args:
        input: Keras tensor; its channel count is read from the last axis.
        input_channels: filters of the first two (bottleneck) convolutions.
            Defaults to ``output_channels // 4``.
        output_channels: filters of the last 1x1 convolution. Defaults to the
            channel count of ``input``.
        kernel_size: kernel size of the middle convolution.
        stride: stride of the middle convolution and of the shortcut projection.
        name: name of the final Add layer; the default 'out' leaves the layer
            with an auto-generated name.

    Returns:
        Sum of the residual branch and the (possibly projected) shortcut.
    """
    if output_channels is None:
        # TF2 / Keras give a plain int here; TF1 gives a Dimension exposing `.value`.
        last_dim = input.shape[-1]
        output_channels = getattr(last_dim, 'value', last_dim)
    if input_channels is None:
        input_channels = output_channels // 4

    strides = (stride, stride)

    x = BatchNormalization()(input)
    x = Activation('relu')(x)
    x = Conv2D(input_channels, (1, 1))(x)

    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(input_channels, kernel_size, padding='same', strides=strides)(x)

    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(output_channels, (1, 1), padding='same')(x)

    # NOTE(review): `input_channels` is the bottleneck width, not the channel
    # count of `input`, so with the default (output_channels // 4) this test is
    # always true and the shortcut is always projected. Kept as-is to preserve
    # the architecture - confirm whether an identity shortcut was intended.
    if input_channels != output_channels or stride != 1:
        input = Conv2D(output_channels, (1, 1), padding='same', strides=strides)(input)

    # `Add` layer instead of the functional `add`, which this file never imports.
    merge = Add() if name == 'out' else Add(name=name)
    return merge([x, input])


def build_res_atten_unet_2d(input_shape, filter_num=8):
    """
    Assemble a U-Net whose stages are residual blocks and whose skip
    connections pass through attention blocks before being concatenated.

    Args:
        input_shape: shape of one input sample, passed to ``Input``.
        filter_num: base width; stage widths are 4x .. 64x this value.

    Returns:
        A ``Model`` mapping the input to a 1-channel sigmoid output.
    """
    merge_axis = -1  # Feature maps are concatenated along last axis (for tf backend)
    data = Input(shape=input_shape)

    # Stem at full resolution; reused as the outermost skip connection.
    conv1 = Conv2D(filter_num * 4, 3, padding='same')(data)
    conv1 = BatchNormalization()(conv1)
    conv1 = Activation('relu')(conv1)

    # Encoder: four residual stages, each followed by 2x2 pooling.
    x = MaxPooling2D(pool_size=(2, 2))(conv1)
    encoder_features = []
    for width in (4, 8, 16, 32):
        x = residual_block_2d(x, output_channels=filter_num * width)
        encoder_features.append(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck: two residual blocks at the lowest resolution.
    for _ in range(2):
        x = residual_block_2d(x, output_channels=filter_num * 64)

    # Decoder: deepest encoder feature first; the attention depth grows as the
    # skip connection gets shallower.
    stage_specs = ((1, 'atten1', 32), (2, 'atten2', 16), (3, 'atten3', 8), (4, 'atten4', 4))
    for skip, (depth, gate_name, width) in zip(reversed(encoder_features), stage_specs):
        gated_skip = attention_block_2d(skip, encoder_depth=depth, name=gate_name)
        upsampled = UpSampling2D(size=(2, 2))(x)
        merged = concatenate([upsampled, gated_skip], axis=merge_axis)
        x = residual_block_2d(merged, output_channels=filter_num * width)

    # Back to full resolution and merge with the stem features.
    x = UpSampling2D(size=(2, 2))(x)
    x = concatenate([x, conv1], axis=merge_axis)

    conv9 = Conv2D(filter_num * 4, 3, padding='same')(x)
    conv9 = BatchNormalization()(conv9)
    conv9 = Activation('relu')(conv9)

    output = Conv2D(1, 3, padding='same', activation='sigmoid')(conv9)
    return Model(data, output)