## <font color = "seagreen">Abstract</font>

[**About Competition**]: In this competition, we’ll detect the presence and position of **catheters** and **lines** on chest x-rays. We need to use machine learning to train and test our model on **40,000** images to categorize a tube that is poorly placed.

[**About Data**]: This is a code-only competition, so there is also a hidden test set (approximately 4x larger, with ~14k images). The files provided for this competition are listed below.

    - train.csv - contains image IDs, binary labels, and patient IDs.
    - sample_submission.csv - a sample submission file in the correct format
    - test - test images
    - train - training images
    - train_annotations.csv - segmentation annotations (additional information for competitors).

Here is the essential information from the metadata (`train.csv`).

- `StudyInstanceUID` - unique ID for each image
- `ETT - Abnormal` - endotracheal tube placement abnormal
- `ETT - Borderline` - endotracheal tube placement borderline abnormal
- `ETT - Normal` - endotracheal tube placement normal
- `NGT - Abnormal` - nasogastric tube placement abnormal
- `NGT - Borderline` - nasogastric tube placement borderline abnormal
- `NGT - Incompletely Imaged` - nasogastric tube placement inconclusive due to imaging
- `NGT - Normal` - nasogastric tube placement normal
- `CVC - Abnormal` - central venous catheter placement abnormal
- `CVC - Borderline` - central venous catheter placement borderline abnormal
- `CVC - Normal` - central venous catheter placement normal
- `Swan Ganz Catheter Present`
- `PatientID` - unique ID for each patient in the dataset

It's a **multi-label** classification problem: the target labels are the 11 columns from `ETT - Abnormal` to `Swan Ganz Catheter Present`. Submissions are evaluated on the **area under the ROC curve (AUC)** between the predicted probability and the observed target. The **AUC** is calculated separately for each of the **11** labels, and the final score is the average of these individual **AUCs**.
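
As a rough illustration of the metric described above, here is a minimal NumPy sketch of a column-wise mean AUC. The helper names (`column_auc`, `mean_columnwise_auc`) and the toy arrays are hypothetical and are not part of the competition code. In practice `sklearn.metrics.roc_auc_score` would normally be used; the pairwise formulation below is only meant to make the definition explicit.

```python
import numpy as np

def column_auc(y_true, y_score):
    """AUC for one label via pairwise comparison of positive and negative scores."""
    pos = y_score[y_true == 1]
    neg = y_score[y_true == 0]
    # Fraction of (positive, negative) pairs ranked correctly; ties count as half.
    diff = pos[:, None] - neg[None, :]
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

def mean_columnwise_auc(y_true, y_score):
    """Average the per-label AUCs over all label columns."""
    return float(np.mean([column_auc(y_true[:, j], y_score[:, j])
                          for j in range(y_true.shape[1])]))

# Toy data (hypothetical): 4 images, 2 labels.
y_true  = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
y_score = np.array([[0.9, 0.2], [0.1, 0.4], [0.8, 0.7], [0.3, 0.6]])

score = mean_columnwise_auc(y_true, y_score)
print(score)  # label 0 scores 1.0, label 1 scores 0.75, so the mean is 0.875
```

Because each label is scored independently, a label that the model ranks poorly lowers the final score by the same amount regardless of how rare that label is.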

## <font color = "seagreen">Content of Notebook</font>

Here is an outline of this notebook. It is configured to run on **TPU** as well as **GPU** hardware with [**mixed precision**](https://www.tensorflow.org/guide/mixed_precision) training. It also works with both the **[JPEG](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification/data)** samples and the **[TF-Record](https://www.kaggle.com/c/ranzcr-clip-catheter-line-classification/data)** format; both options are provided. To learn more about **TPU** and TF-Record, please refer to [this TPU guidebook](https://www.kaggle.com/docs/tpu), hosted on Kaggle, and the [TFRecord example](https://www.tensorflow.org/tutorials/load_data/tfrecord#:~:text=The%20TFRecord%20format%20is%20a,to%20understand%20a%20message%20type.), hosted on TensorFlow.org.

Along with the **device set-up** and **data**-specific configuration above, we've also tried to develop a research idea: **multi-branch soft attention modules**. We've implemented it in `tf.keras` with the model [subclassing API](https://keras.io/guides/making_new_layers_and_models_via_subclassing/) and inspect it with **GradCAM** at the end of an epoch using a **Callback**. The **Multi-Attention** approach presented in this notebook is a simple research idea, meant to provide ideas or encouragement. In practice, it may need extensive experiments and careful integration into ImageNet-pretrained or custom models. In this notebook, for demonstration purposes, we have simply added it as a custom neck on top of the pre-trained model. As a result, it gives promising outputs for this task. More theoretical and mathematical information about it can be found in the relevant section of this notebook. Here is the summary:

```
1 Clean Set-Up for Training (on TPU) and Inference (on GPU).
2 Data Modeling  
    - Data preprocessing : tf.data API with 
        - JPEG and, 
        - TF-RECORD
    - Data augmentation : tf.image modules
3 Hyperparameter Setting:
    - Warmup Learning Rate Schedule (exponential, cosine, constant, cosine-restart).
4 Deep Network Modeling
    - Backbone: EfficientNet  
    - Neck: Custom Top Layers (Multi-Attention Mechanisms)
    - Head: Sigmoid Classifier (Multi-label)
5 Training
    - GradCAM set up
    - Model training
6 Inference
```

## Imports

In [None]:
import numpy as np 
import pandas as pd 
import matplotlib.cm as cm
import matplotlib.pyplot as plt 
from kaggle_datasets import KaggleDatasets

import os, warnings
from sklearn import model_selection
warnings.simplefilter('ignore')
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf; print(tf.__version__)
from tensorflow import keras; print(keras.__version__)
from tensorflow.keras import layers
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import schedules

## Device Config

As mentioned, we can use either the host-provided `jpeg` files or the host-provided `tfrecord` files. This notebook is configured to train the model on **TPU** and run inference on **GPU**, but feel free to pick whichever options you prefer.

I've trained the model with **TPU + JPEG** and saved the best weights. For demonstration purposes, I will train the model again with the **TPU + TFRecord** config. Please also read the comments attached to the code in the following cells to understand the whole set-up precisely.

In [None]:
# set: 'TF_RECORD' for using .tfrec format 
# set: 'JPEG' for using jpeg format
FILE_TYPE = 'TF_RECORD' # OPTION: 'JPEG', 'TF_RECORD'    

MIXED_PRECISION = False  # Faster and use Less memory
XLA_ACCELERATE  = False # XLA: Optimizing Compiler

try:  # detect TPUs
    tpu      = tf.distribute.cluster_resolver.TPUClusterResolver.connect()  # TPU detection
    strategy = tf.distribute.TPUStrategy(tpu)
    DEVICE   = 'TPU'
except ValueError:  # detect GPUs
    strategy = tf.distribute.get_strategy() # default strategy operates on CPU and GPU
    DEVICE   = 'GPU'
    
if DEVICE == "GPU":
    physical_devices = tf.config.list_physical_devices('GPU')
    print("Num GPUs Available: ", len(physical_devices))
    try: 
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
        assert tf.config.experimental.get_memory_growth(physical_devices[0])
    except: # Invalid device or cannot modify virtual devices once initialized.
        pass 
    
if MIXED_PRECISION:
    dtype = 'mixed_bfloat16' if DEVICE == "TPU" else 'mixed_float16'
    tf.keras.mixed_precision.set_global_policy(dtype)
    print('Mixed precision enabled')

if XLA_ACCELERATE:
    tf.config.optimizer.set_jit(True)
    print('Accelerated Linear Algebra enabled')
    
AUTO  = tf.data.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync

print('REPLICAS           : ', REPLICAS)
print('Eager Mode Status  : ', tf.executing_eagerly())
print('TF Cuda Built Test : ', tf.test.is_built_with_cuda())
print('TF Device Detected : ', 'Running on TPU' if DEVICE == "TPU" else tf.test.gpu_device_name())

try:
    print('TF System Cuda V.  : ', tf.sysconfig.get_build_info()["cuda_version"])
    print('TF System CudNN V. : ', tf.sysconfig.get_build_info()["cudnn_version"])
except:
    pass

**Base Config**

In [None]:
ROOT_DIR  = "ranzcr-clip-catheter-line-classification"
TRAIN_DF  = pd.read_csv(f"/kaggle/input/{ROOT_DIR}/" + 'train.csv')
SUBMIT    = pd.read_csv(f"/kaggle/input/{ROOT_DIR}/" + 'sample_submission.csv')
CLASS_MAP = pd.read_csv(f"/kaggle/input/{ROOT_DIR}/" + 'train_annotations.csv')
TARGET    = TRAIN_DF[SUBMIT.columns[1:]].values
display(TRAIN_DF.head())

if DEVICE == "TPU":
    GCS_DS_PATH    = KaggleDatasets().get_gcs_path(ROOT_DIR)
    if FILE_TYPE == 'TF_RECORD':
        TRAIN_IMG_PATH = sorted(tf.io.gfile.glob(GCS_DS_PATH + '/train_tfrecords/*.tfrec'))
        TEST_IMG_PATH  = sorted(tf.io.gfile.glob(GCS_DS_PATH + '/test_tfrecords/*.tfrec'))
    elif FILE_TYPE == 'JPEG':
        TRAIN_IMG_PATH = GCS_DS_PATH + "/train/" + TRAIN_DF['StudyInstanceUID'] + '.jpg'
        TEST_IMG_PATH  = GCS_DS_PATH + "/test/"  + SUBMIT['StudyInstanceUID']   + '.jpg'
else:
    if FILE_TYPE == 'TF_RECORD':
        TRAIN_IMG_PATH = sorted(tf.io.gfile.glob(f"/kaggle/input/{ROOT_DIR}" + '/train_tfrecords/*.tfrec'))
        TEST_IMG_PATH  = sorted(tf.io.gfile.glob(f"/kaggle/input/{ROOT_DIR}" + '/test_tfrecords/*.tfrec'))
    elif FILE_TYPE == 'JPEG':
        TRAIN_IMG_PATH = f"/kaggle/input/{ROOT_DIR}" + "/train/" + TRAIN_DF['StudyInstanceUID'] + '.jpg'
        TEST_IMG_PATH  = f"/kaggle/input/{ROOT_DIR}" + "/test/"  + SUBMIT['StudyInstanceUID']   + '.jpg'

## <font color = "seagreen">Param</font>

In [None]:
# training params 
EPOCHS       = 20
VERBOSITY    = 1
LABEL_SMOOTH = 0.01

# data params 
BATCH_SIZE   = REPLICAS * 8
IMG_SIZE     = 512
SEED         = 101

# model params 
BASE_NETS    = tf.keras.applications.EfficientNetB5

## <font color = "seagreen">RANZCR-CLiP Dataloader - tf.data API: JPEG and TFRecord</font> 

In [None]:
def get_target_size(target_size):
    if isinstance(target_size, int):
        return (target_size, target_size)
    if isinstance(target_size, (list, tuple)):
        return target_size
    raise ValueError('target_size must be an int, or (height, width) but got %r' % (target_size,))

@tf.function
def augment_image(img):
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_flip_up_down(img)
    return img

@tf.function
def read_image(file_bytes, target_size):
    img = tf.image.decode_jpeg(file_bytes, channels=3)  # source images are .jpg files
    img = tf.image.resize(img, get_target_size(target_size))
    return img 

**JPEG Builder --- Option 1**

In [None]:
# normalize target_size to a (height, width) pair
def get_target_size(target_size):
    if isinstance(target_size, int):
        return (target_size, target_size)
    if isinstance(target_size, (list, tuple)):
        return target_size
    raise ValueError('target_size must be an int, or (height, width) but got %r' % (target_size,))

# read image and preprocessing 
def build_decoder(with_labels=True, target_size=(256, 256)):
    def decode(path):
        file_bytes = tf.io.read_file(path)
        return read_image(file_bytes, target_size=target_size)
    def decode_with_labels(path, label):
        return decode(path), label
    return decode_with_labels if with_labels else decode

# data augmentation 
def build_augmenter(with_labels=True):
    def augment_with_labels(img, label):
        return augment_image(img), label
    return augment_with_labels if with_labels else augment_image

In [None]:
# bind all the previous cell functions
def jpeg_loader(paths,
                labels     = None, 
                bsize      = 32, 
                cache      = True,
                decode_fn  = None, 
                augment_fn = None,
                augment    = True, 
                repeat     = True, 
                shuffle    = 1024, 
                cache_dir  = ""):
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)
    
    if decode_fn is None:
        decode_fn = build_decoder(labels is not None)
    
    if augment_fn is None:
        augment_fn = build_augmenter(labels is not None)
    
    slices = paths if labels is None else (paths, labels)
    dataset = tf.data.Dataset.from_tensor_slices(slices)
    dataset = dataset.map(decode_fn, num_parallel_calls=AUTO)
    dataset = dataset.cache(cache_dir) if cache else dataset
    dataset = dataset.map(augment_fn, num_parallel_calls=AUTO) if augment else dataset
    dataset = dataset.repeat() if repeat else dataset
    dataset = dataset.shuffle(8 * bsize) if shuffle else dataset
    dataset = dataset.batch(bsize).prefetch(AUTO)
    return dataset

**TFRecord Builder --- Option 2**

In [None]:
def fn_read_tfrecords(
    tfrecords, 
    tfrecords_schema, 
    apply_augment = False, 
    target_size   = 256,
    inference     = False
):
    read_tfrecord = tf.io.parse_single_example(tfrecords, tfrecords_schema)
    x_train = read_image(read_tfrecord['image'], target_size=target_size)

    if apply_augment:
        x_train = augment_image(x_train)
    
    if not inference:
        y_train = tf.stack([
            read_tfrecord['ETT - Abnormal'],
            read_tfrecord['ETT - Borderline'],
            read_tfrecord['ETT - Normal'],
            read_tfrecord["NGT - Abnormal"],
            read_tfrecord['NGT - Borderline'],
            read_tfrecord['NGT - Incompletely Imaged'],
            read_tfrecord['NGT - Normal'],
            read_tfrecord['CVC - Abnormal'],
            read_tfrecord['CVC - Borderline'],
            read_tfrecord['CVC - Normal'],
            read_tfrecord['Swan Ganz Catheter Present']
        ], axis=-1)
        
        return x_train, y_train
    else:
        return x_train

In [None]:
def tfrecords_loader(
    files_path,
    tfschemes    = None,
    shuffle      = True,
    cache        = True,
    repeat       = True,
    augment      = False,
    ignore_order = True,
    inference    = False
):
    dataset = tf.data.TFRecordDataset(
        files_path, num_parallel_reads=AUTO
    ) 
    
    # Only relax ordering when requested (the flag is no longer shadowed).
    if ignore_order:
        options = tf.data.Options()
        options.experimental_deterministic = False # disable order, increase speed
        dataset = dataset.with_options(options)
    
    # Fall back to the module-level schema when none is passed in.
    schema = tfrecords_schema if tfschemes is None else tfschemes
    dataset = dataset.map(
        lambda x: fn_read_tfrecords(x, 
                                    schema,
                                    apply_augment = augment, 
                                    target_size   = IMG_SIZE,
                                    inference     = inference),
        num_parallel_calls=AUTO
    )
    
    # Note: `cache` is accepted for symmetry with jpeg_loader but is not applied here.
    dataset = dataset.repeat() if repeat else dataset
    dataset = dataset.shuffle(8 * BATCH_SIZE) if shuffle else dataset
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(buffer_size=AUTO)
    return dataset


tfrecords_schema = {
    'StudyInstanceUID'           : tf.io.FixedLenFeature([], tf.string),
    'image'                      : tf.io.FixedLenFeature([], tf.string),
    'ETT - Abnormal'             : tf.io.FixedLenFeature([], tf.int64),
    'ETT - Borderline'           : tf.io.FixedLenFeature([], tf.int64),
    'ETT - Normal'               : tf.io.FixedLenFeature([], tf.int64),
    "NGT - Abnormal"             : tf.io.FixedLenFeature([], tf.int64),
    'NGT - Borderline'           : tf.io.FixedLenFeature([], tf.int64),
    'NGT - Incompletely Imaged'  : tf.io.FixedLenFeature([], tf.int64),
    'NGT - Normal'               : tf.io.FixedLenFeature([], tf.int64),
    'CVC - Abnormal'             : tf.io.FixedLenFeature([], tf.int64),
    'CVC - Borderline'           : tf.io.FixedLenFeature([], tf.int64),
    'CVC - Normal'               : tf.io.FixedLenFeature([], tf.int64),
    'Swan Ganz Catheter Present' : tf.io.FixedLenFeature([], tf.int64)
}

## <font color = "seagreen">Build Data Set</font>

In [None]:
if FILE_TYPE == 'TF_RECORD':
    TRAIN_SET, VALID_SET = model_selection.train_test_split(TRAIN_IMG_PATH, test_size=0.2, random_state=SEED)

    # The following two line may take a while to compute
    train_set_len = sum(1 for record in tf.data.TFRecordDataset(TRAIN_SET))
    valid_set_len = sum(1 for record in tf.data.TFRecordDataset(VALID_SET))

    TRAIN_STEPS_PER_EPOCH = int(np.ceil(train_set_len / float(BATCH_SIZE)))
    VALID_STEPS_PER_EPOCH = int(np.ceil(valid_set_len / float(BATCH_SIZE)))
    
    train_datasets = tfrecords_loader(
        TRAIN_SET, 
        tfschemes    = tfrecords_schema,
        shuffle      = True,
        repeat       = True,
        augment      = True,
        ignore_order = True,
        inference    = False
    )
    
    valid_datasets = tfrecords_loader(
        VALID_SET,
        tfschemes    = tfrecords_schema,
        shuffle      = False,
        repeat       = False,
        augment      = False,
        ignore_order = False,
        inference    = False
    )
elif FILE_TYPE == 'JPEG':
    (
        train_paths, valid_paths, 
        train_labels, valid_labels
    ) = model_selection.train_test_split(TRAIN_IMG_PATH, TARGET, test_size=0.2, random_state=SEED)
    
    TRAIN_STEPS_PER_EPOCH = int(np.ceil(len(train_paths) / float(BATCH_SIZE)))
    VALID_STEPS_PER_EPOCH = int(np.ceil(len(valid_paths) / float(BATCH_SIZE)))
    
    decoder = build_decoder(with_labels=True, target_size=IMG_SIZE)
    
    train_datasets = jpeg_loader(
        train_paths, 
        train_labels, 
        bsize     = BATCH_SIZE, 
        decode_fn = decoder,
        repeat    = True, 
        shuffle   = True, 
        augment   = True
    )

    valid_datasets = jpeg_loader(
        valid_paths, 
        valid_labels, 
        bsize     = BATCH_SIZE, 
        decode_fn = decoder,
        repeat    = False, 
        shuffle   = False, 
        augment   = False
    )
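
The steps-per-epoch values computed in the cell above are simply the sample count divided by the global batch size, rounded up. Here is a small sketch with hypothetical dataset sizes (the real counts come from the train/validation split, not from these numbers):

```python
import math

def steps_per_epoch(num_samples, batch_size):
    """Number of batches needed to see every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)

# Hypothetical sizes: 8 TPU replicas x 8 images per replica.
batch_size = 8 * 8
train_steps = steps_per_epoch(24_000, batch_size)  # 24000 / 64 = 375 full batches
valid_steps = steps_per_epoch(6_010, batch_size)   # 93 full batches + 1 partial batch

print(train_steps, valid_steps)  # 375 94
```

Rounding up matters for the validation set in particular: with `repeat=False`, rounding down would drop the final partial batch from the evaluation.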

## Train Data Visualization

In [None]:
image_batch, label_batch = next(iter(train_datasets))
print(image_batch.shape, label_batch.shape)

def show_batch(image_batch, label_batch=None, title=''):
    fig = plt.figure(figsize=(15, 15))
    plt.title(title)
    plt.yticks([])
    plt.xticks([])
    
    # smallest square grid that fits the whole batch
    xy = int(np.ceil(np.sqrt(image_batch.shape[0])))
    for n in range(image_batch.shape[0]):
        ax = fig.add_subplot(xy, xy, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        plt.tight_layout()
        plt.axis("off")

show_batch(image_batch.numpy(), label_batch=label_batch.numpy(), title='Augmented Training Set')

## Validation Data Visualization

In [None]:
image_batch, label_batch = next(iter(valid_datasets))
print(image_batch.shape, label_batch.shape)
show_batch(image_batch.numpy(), label_batch=label_batch.numpy(), title='Validation Set (No Augmentation)')

# <font color = "seagreen">Modeling</font>

Here, we've tried to integrate a **Multi-Attention** mechanism on top of the base model. In essence, we've added the [**Convolutional Block Attention Module (CBAM)**](https://arxiv.org/abs/1807.06521) and [**DeepMoji**](https://github.com/bfelbo/DeepMoji/blob/master/deepmoji/attlayer.py)'s attention mechanism in parallel and merged them at the end. Additionally, we've slightly modified the **CBAM** mechanism: the end of its **spatial** module integrates the **Global Weighted Average Pooling (GWAP)** method, defined mathematically as follows. The idea of **GWAP** is inspired by [Dr. Kevin](https://www.kaggle.com/kmader)'s great work; [check it out](https://www.kaggle.com/kmader/attention-on-pretrained-vgg16-for-bone-age).


$$ \text{GWAP}(d) = \frac{ \sum\limits_{x}\sum\limits_{y} \text{Attention}(x,y,d) \, \text{Feature}(x,y,d)} {\sum\limits_{x}\sum\limits_{y} \text{Attention}(x,y,d)} $$

[Attention Learning in CV (!ViT) - Details](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/203083). 
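
To make the GWAP formula concrete, here is a minimal NumPy sketch for a single feature map. The function name `gwap`, the `eps` guard against a zero denominator, and the toy arrays are assumptions added for illustration; the model in this notebook is built from Keras layers, not from this code.

```python
import numpy as np

def gwap(features, attention, eps=1e-7):
    """Global Weighted Average Pooling over the spatial axes.

    features:  (H, W, D) feature map
    attention: (H, W, D) or (H, W, 1) non-negative attention weights
    returns:   (D,) pooled vector
    """
    attention = np.broadcast_to(attention, features.shape)
    num = (attention * features).sum(axis=(0, 1))   # sum_x sum_y A(x,y,d) * F(x,y,d)
    den = attention.sum(axis=(0, 1)) + eps          # sum_x sum_y A(x,y,d)
    return num / den

rng = np.random.default_rng(0)
feat = rng.normal(size=(4, 4, 3))

# Uniform attention reduces GWAP to plain global average pooling.
uniform = np.ones((4, 4, 1))
pooled_uniform = gwap(feat, uniform)

# Attention concentrated on one location returns that location's features.
one_hot = np.zeros((4, 4, 1))
one_hot[2, 1, 0] = 1.0
pooled_peak = gwap(feat, one_hot)
```

The two toy cases show the range of behaviour: the more peaked the attention map, the more the pooled vector reflects a few locations rather than the whole image.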

In [None]:
class SpatialAttentionModule(keras.layers.Layer):
    def __init__(self, kernel_size=3):
        '''
        paper: https://arxiv.org/abs/1807.06521
        code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
        '''
        super().__init__()
        self.conv1 = keras.layers.Conv2D(64, 
                                            kernel_size=kernel_size, 
                                            use_bias=False, 
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same', 
                                            activation=tf.nn.relu)
        self.conv2 = keras.layers.Conv2D(32, kernel_size=kernel_size, 
                                            use_bias=False, 
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same', 
                                            activation=tf.nn.relu)
        self.conv3 = keras.layers.Conv2D(16, kernel_size=kernel_size, 
                                            use_bias=False, 
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same', 
                                            activation=tf.nn.relu)
        self.conv4 = keras.layers.Conv2D(1, 
                                            kernel_size=(1, 1),  
                                            use_bias=False,
                                            kernel_initializer='he_normal',
                                            strides=1, padding='same', 
                                            activation=tf.math.sigmoid)

    def call(self, inputs):
        avg_out = tf.reduce_mean(inputs, axis=3)
        max_out = tf.reduce_max(inputs,  axis=3)
        x = tf.stack([avg_out, max_out], axis=3) 
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.conv4(x)

In [None]:
# A custom layer
class ChannelAttentionModule(keras.layers.Layer):
    def __init__(self, ratio=8):
        '''paper: https://arxiv.org/abs/1807.06521
        code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
        '''
        super(ChannelAttentionModule, self).__init__()
        self.ratio = ratio
        self.gapavg = keras.layers.GlobalAveragePooling2D()
        self.gmpmax = keras.layers.GlobalMaxPooling2D()
        
    def build(self, input_shape):
        self.conv2 = keras.layers.Conv2D(input_shape[-1], 
                                         kernel_size=1,
                                         strides=1, 
                                         padding='same',
                                         use_bias=False, 
                                         activation=tf.nn.elu)
        super(ChannelAttentionModule, self).build(input_shape)

    def call(self, inputs):
        # compute gap and gmp pooling 
        gapavg = self.gapavg(inputs)
        gmpmax = self.gmpmax(inputs)
        gapavg = keras.layers.Reshape((1, 1, gapavg.shape[1]))(gapavg)   
        gmpmax = keras.layers.Reshape((1, 1, gmpmax.shape[1]))(gmpmax)   
        # forward passing to the respected layers
        gapavg_out = self.conv2(gapavg)
        gmpmax_out = self.conv2(gmpmax)
        return tf.math.sigmoid(gapavg_out + gmpmax_out)

In [None]:
# Original Src: https://github.com/bfelbo/DeepMoji/blob/master/deepmoji/attlayer.py
class AttentionWeightedAverage2D(keras.layers.Layer):
    def __init__(self, **kwargs):
        self.init = keras.initializers.get('uniform')
        super(AttentionWeightedAverage2D, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [layers.InputSpec(ndim=4)]
        assert len(input_shape) == 4
        self.W = self.add_weight(shape=(input_shape[3], 1),
                                 name='{}_W'.format(self.name),
                                 initializer=self.init)
        self._trainable_weights = [self.W]
        super(AttentionWeightedAverage2D, self).build(input_shape)

    def call(self, x):
        # computes a probability distribution over the spatial positions
        # uses 'max trick' for numerical stability
        # reshape is done to avoid issue with Tensorflow
        # and 2-dimensional weights
        logits  = K.dot(x, self.W)
        x_shape = K.shape(x)
        logits  = K.reshape(logits, (x_shape[0], x_shape[1], x_shape[2]))
        ai      = K.exp(logits - K.max(logits, axis=[1,2], keepdims=True))
        
        att_weights    = ai / (K.sum(ai, axis=[1,2], keepdims=True) + K.epsilon())
        weighted_input = x * K.expand_dims(att_weights)
        result         = K.sum(weighted_input, axis=[1,2])
        return result

    def get_output_shape_for(self, input_shape):
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        output_len = input_shape[3]
        return (input_shape[0], output_len)
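
The layer above boils down to a softmax over all spatial positions followed by a weighted sum. Here is a NumPy sketch of that computation for a single feature map. The function name and toy inputs are hypothetical, and the small epsilon that the layer adds to the denominator is omitted here for brevity.

```python
import numpy as np

def attention_weighted_average(x, w):
    """x: (H, W, C) feature map, w: (C,) projection vector. Returns ((C,), (H, W))."""
    logits = x @ w                       # (H, W): one score per spatial location
    logits = logits - logits.max()       # 'max trick' for numerical stability
    weights = np.exp(logits)
    weights = weights / weights.sum()    # softmax over all H*W locations
    pooled = (x * weights[..., None]).sum(axis=(0, 1))
    return pooled, weights

rng = np.random.default_rng(1)
x = rng.normal(size=(3, 3, 5))

pooled, weights = attention_weighted_average(x, rng.normal(size=5))

# A zero projection gives equal weights, i.e. plain global average pooling.
pooled_zero, weights_zero = attention_weighted_average(x, np.zeros(5))
```

With a zero projection the layer therefore falls back to global average pooling, and training moves it away from that baseline as `W` is learned.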

# <font color = "seagreen">Keras Model Sub-Classing</font>

So far, we've seen some of the building blocks of the complete model. Now we will build the entire model from these blocks. In the following, we use the `keras` model [subclassing API](https://keras.io/guides/making_new_layers_and_models_via_subclassing/). Although this could easily be done with the functional API, we choose the subclassing approach here.

In [None]:
class RANZCRClassifier(keras.Model):
    def __init__(self, dim):
        super(RANZCRClassifier, self).__init__()
        # Defining all trainable layers in __init__ / build
        self.base  = BASE_NETS(
            input_shape=(IMG_SIZE, IMG_SIZE, 3),
            weights='imagenet' if DEVICE == "TPU" else None, 
            include_top=False
        )
        
        # Keras Built-in
        self.batch_norm  = layers.BatchNormalization()
        self.dropout     = layers.Dropout(rate=0.5)
        
        # Neck
        self.can_module   = ChannelAttentionModule()
        self.san_module_x = SpatialAttentionModule()
        self.san_module_y = SpatialAttentionModule()
        self.awn_module   = AttentionWeightedAverage2D()
        
        # Head
        self.dense_layer = layers.Dense(512, activation=tf.nn.relu)
        self.classifier  = layers.Dense(len(SUBMIT.columns[1:]), activation='sigmoid', dtype=tf.float32)
    
    def call(self, input_tensor, training=False):
        if training is None:
            training = K.learning_phase()
            
        # Base Inputs
        base_out = self.base(input_tensor)

        # Attention Modules 1
        # Channel Attention + Spatial Attention 
        canx   = self.can_module(base_out)*base_out
        spnx   = self.san_module_x(canx)*canx
        spny   = self.san_module_y(canx)

        # Global Weighted Average Pooling
        gapx   = layers.GlobalAveragePooling2D()(spnx)
        wvgx   = layers.GlobalAveragePooling2D()(spny)
        gapavg = layers.Average()([gapx, wvgx])
        
        # Attention Modules 2
        # Attention Weighted Average (AWG)
        awgavg = self.awn_module(base_out)
        # Summation of Attentions
        attns_adds = layers.Add()([gapavg, awgavg])
        
        # Tails
        x = self.batch_norm(attns_adds)
        x = self.dense_layer(x)
        x = self.dropout(x, training=training)
        x = self.classifier(x)
        
        if not training:
            return x, base_out, canx
        return x
        
    # AFAIK: The most convenient method to print model.summary() in a subclassed model
    def build_graph(self):
        x = keras.Input(shape=(IMG_SIZE, IMG_SIZE,3))
        return keras.Model(inputs=[x], outputs=self.call(x))

**Build Model and Plot**

In [None]:
with strategy.scope():
    model = RANZCRClassifier((IMG_SIZE, IMG_SIZE, 3))
    model.build((None, *(IMG_SIZE, IMG_SIZE, 3)))


keras.utils.plot_model(
    model.build_graph(), 
    show_shapes      = True, 
    show_layer_names = True, 
    expand_nested    = False,                      
)

# <font color = "seagreen">Learning Rate Schedule Config</font>

**Learning Rate Schedule (LRS).** A learning rate schedule is a function that takes the current optimizer step as input and returns the learning rate to use at that step. There are many types of **LRS**, [see](https://keras.io/api/optimizers/learning_rate_schedules/). Here, we will subclass the [schedules.LearningRateSchedule](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/LearningRateSchedule) API to build a **WarmUp Learning Rate Schedule** on top of the following **LRS**.

- ExponentialDecay
- CosineDecay
- Constant
- **CosineDecay-Restart** (We will use it)
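
Before the TensorFlow implementation, here is a framework-free sketch of how a warm-up combines with one of the decay types listed above, using plain cosine decay as the example. The function name `warmup_cosine_lr` and the step counts are hypothetical; the batch-size scaling of the base rate (`0.016 * batch / 256`) follows the values used later in this notebook.

```python
import math

def warmup_cosine_lr(step, base_lr, warmup_steps, total_steps, min_lr=0.0):
    """Cosine decay over the whole run, overridden by a linear ramp during warm-up."""
    # Cosine decay from base_lr at step 0 down to 0 at total_steps.
    lr = 0.5 * base_lr * (1.0 + math.cos(math.pi * step / total_steps))
    lr = max(lr, min_lr)
    # During warm-up the decayed value is replaced by a linear ramp starting at 0.
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    return lr

base_lr = 0.016 * (64 / 256.0)          # scaled for a global batch size of 64
warmup_steps, total_steps = 100, 1000   # hypothetical step counts

schedule = [warmup_cosine_lr(s, base_lr, warmup_steps, total_steps)
            for s in range(total_steps + 1)]
```

In this sketch the rate climbs linearly for the first 100 steps, jumps onto the cosine curve at step 100, and reaches zero at the final step. The restart variant used in the notebook repeats the cosine cycle instead of decaying once.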

In [None]:
class WarmupLearningRateSchedule(schedules.LearningRateSchedule):
    """Provides a variety of learning rate decay schedules with warm up."""
    def __init__(self,
                 initial_lr,
                 steps_per_epoch=None,
                 lr_decay_type='exponential',
                 decay_factor=0.97,
                 decay_epochs=2.4,
                 total_steps=None,
                 warmup_epochs=5,
                 minimal_lr=0):
        super(WarmupLearningRateSchedule, self).__init__()
        self.initial_lr      = initial_lr
        self.steps_per_epoch = steps_per_epoch
        self.lr_decay_type   = lr_decay_type
        self.decay_factor    = decay_factor
        self.decay_epochs    = decay_epochs
        self.total_steps     = total_steps
        self.warmup_epochs   = warmup_epochs
        self.minimal_lr      = minimal_lr

    def __call__(self, step):
        if self.lr_decay_type == 'exponential':
            assert self.steps_per_epoch is not None
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = schedules.ExponentialDecay(self.initial_lr, decay_steps,  self.decay_factor, staircase=True)(step)
        elif self.lr_decay_type == 'cosine':
            assert self.total_steps is not None
            lr = 0.5 * self.initial_lr * (1 + tf.cos(np.pi * tf.cast(step,  tf.float32) / self.total_steps))
            
        elif self.lr_decay_type == 'linear':
            assert self.total_steps is not None
            lr = (1.0 - tf.cast(step, tf.float32) / self.total_steps) * self.initial_lr
            
        elif self.lr_decay_type == 'constant':
            lr = self.initial_lr
        
        elif self.lr_decay_type == 'cosine_restart':
            assert self.steps_per_epoch is not None
            decay_steps = self.steps_per_epoch * self.decay_epochs
            lr = tf.keras.experimental.CosineDecayRestarts(self.initial_lr, decay_steps)(step)
        else:
            raise ValueError('Unknown lr_decay_type : %s' % self.lr_decay_type)

        if self.minimal_lr:
            lr = tf.math.maximum(lr, self.minimal_lr)

        if self.warmup_epochs:
            warmup_steps = int(self.warmup_epochs * self.steps_per_epoch)
            warmup_lr = (
              self.initial_lr * tf.cast(step, tf.float32) /
              tf.cast(warmup_steps, tf.float32))
            lr = tf.cond(step < warmup_steps, lambda: warmup_lr, lambda: lr)

        return lr

    def get_config(self):
        return {
            'initial_lr'     : self.initial_lr,
            'steps_per_epoch': self.steps_per_epoch,
            'lr_decay_type'  : self.lr_decay_type,
            'decay_factor'   : self.decay_factor,
            'decay_epochs'   : self.decay_epochs,
            'total_steps'    : self.total_steps,
            'warmup_epochs'  : self.warmup_epochs,
            'minimal_lr'     : self.minimal_lr,
        }


lr_sched = 'cosine_restart'
lr_base  = 0.016
lr_min   = 0
lr_decay_epoch  = 2.4
lr_warmup_epoch = 5
lr_decay_factor = 0.97

scaled_lr     = lr_base * (BATCH_SIZE / 256.0)
scaled_lr_min = lr_min * (BATCH_SIZE / 256.0)
total_steps   = TRAIN_STEPS_PER_EPOCH * EPOCHS

learning_rate = WarmupLearningRateSchedule(
    scaled_lr,
    steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
    decay_epochs=lr_decay_epoch,
    warmup_epochs=lr_warmup_epoch,
    decay_factor=lr_decay_factor,
    lr_decay_type=lr_sched,
    total_steps=total_steps,
    minimal_lr=scaled_lr_min)

rng = list(range(total_steps))
lr_y = [learning_rate(x) for x in rng]
plt.figure(figsize=(10, 4))
plt.plot(rng, lr_y)
plt.xlabel('Iteration',size=14)
plt.ylabel('Learning Rate',size=14)

# <font color = "seagreen">Callbacks</font>


**Callback.** A callback is an object that can perform actions at various stages of training. Here we will use some built-in callback objects such as `ModelCheckpoint`, `CSVLogger`, etc. Along with these built-in callbacks, we will also build a custom callback that computes **GradCAM** on a few samples picked from the validation set. We will set an interval so that the **GradCAM** is shown at the end of every few epochs. We'll refer to this callback as **GradCAMCallback**.

See more details about [callback](https://keras.io/api/callbacks/). 
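
The core of the GradCAM computation used by the callback can be sketched in NumPy, assuming the activations of a convolutional layer and the gradients of a class score with respect to them are already available. The function name, the toy arrays, and the guard against an all-zero heatmap are assumptions added for illustration.

```python
import numpy as np

def gradcam_heatmap(activations, grads):
    """activations, grads: (H, W, K) arrays for one image. Returns (H, W) in [0, 1]."""
    channel_weights = grads.mean(axis=(0, 1))        # (K,) average gradient per channel
    heatmap = activations @ channel_weights          # (H, W) weighted sum of channels
    heatmap = np.maximum(heatmap, 0.0)               # keep positive evidence only
    peak = heatmap.max()
    return heatmap / peak if peak > 0 else heatmap   # avoid dividing by zero

# Toy example (hypothetical values): two channels on a 2x2 map.
acts = np.zeros((2, 2, 2))
acts[0, 0, 0] = 2.0       # channel 0 fires at the top-left
acts[1, 1, 1] = 4.0       # channel 1 fires at the bottom-right
grads = np.zeros((2, 2, 2))
grads[..., 0] = 1.0       # the class score increases with channel 0
grads[..., 1] = -1.0      # and decreases with channel 1

heatmap = gradcam_heatmap(acts, grads)
print(heatmap)  # only the top-left cell is highlighted
```

Only locations whose active channels push the class score up survive the ReLU, which is why the bottom-right cell is zeroed out here even though its activation is larger.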

In [None]:
class GradCAMCallback(keras.callbacks.Callback):
    def __init__(self, epoch_interval=None):
        super().__init__()
        self.epoch_interval = epoch_interval

    # ref: https://keras.io/examples/vision/grad_cam/
    def make_gradcam_heatmap(self, img_array, grad_model, pred_index=None):
        # a persistent tape is needed because gradients are taken twice below
        with tf.GradientTape(persistent=True) as tape:
            preds, base_top, swin_top = grad_model(img_array)
            if pred_index is None:
                pred_index = tf.argmax(preds[0])
            class_channel = preds[:, pred_index]

        # Grad-CAM for the first feature map: channel weights are the
        # averaged gradients of the class score
        grads = tape.gradient(class_channel, base_top)
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        base_top  = base_top[0]
        heatmap_a = base_top @ pooled_grads[..., tf.newaxis]
        heatmap_a = tf.squeeze(heatmap_a)
        # ReLU, then scale to [0, 1]; divide_no_nan avoids NaNs for an all-zero map
        heatmap_a = tf.math.divide_no_nan(tf.maximum(heatmap_a, 0), tf.math.reduce_max(heatmap_a))
        heatmap_a = heatmap_a.numpy()

        # same computation for the second feature map
        grads = tape.gradient(class_channel, swin_top)
        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
        swin_top = swin_top[0]
        heatmap_b = swin_top @ pooled_grads[..., tf.newaxis]
        heatmap_b = tf.squeeze(heatmap_b)
        heatmap_b = tf.math.divide_no_nan(tf.maximum(heatmap_b, 0), tf.math.reduce_max(heatmap_b))
        heatmap_b = heatmap_b.numpy()

        del tape  # release the resources held by the persistent tape
        return heatmap_a, heatmap_b

    def on_epoch_end(self, epoch, logs=None):
        # skip epoch 0 and do nothing when no interval is set
        if self.epoch_interval and epoch and epoch % self.epoch_interval == 0:
            # pick some samples (i.e here 5) to compute grad-cam and plot
            image_batch, label_batch = next(iter(valid_datasets))
            image_batch, label_batch = image_batch[:5], label_batch[:5]

            for sample, label in zip(image_batch, label_batch):
                # make grad-cam 
                img_array = sample[tf.newaxis, ...] 
                heatmap_a, heatmap_b = self.make_gradcam_heatmap(img_array, model)

                # overlay heatmap and input sample 
                overlay_a = self.save_and_display_gradcam(sample, heatmap_a)
                overlay_b = self.save_and_display_gradcam(sample, heatmap_b)

                # pass the input and the two overlays to plot
                self.plot_stuff(img_array, overlay_a, overlay_b)


    # ref: https://keras.io/examples/vision/grad_cam/
    def save_and_display_gradcam(self, 
                                 img, 
                                 heatmap, 
                                 target=None, 
                                 pred=None,
                                 cam_path="cam.jpg", 
                                 alpha=0.6, 
                                 plot=None):
        # Rescale heatmap to a range 0-255
        heatmap = np.uint8(255 * heatmap)

        # Use jet colormap to colorize heatmap
        jet = cm.get_cmap("jet") 

        # Use RGB values of the colormap
        jet_colors  = jet(np.arange(256))[:, :3]
        jet_heatmap = jet_colors[heatmap]

        # Create an image with RGB colorized heatmap
        # jet_heatmap = keras.utils.array_to_img(jet_heatmap)
        jet_heatmap = keras.preprocessing.image.array_to_img(jet_heatmap)
        
        # resize to input image; PIL expects (width, height)
        jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    
        # jet_heatmap = keras.utils.img_to_array(jet_heatmap)
        jet_heatmap = keras.preprocessing.image.img_to_array(jet_heatmap)

        # Superimpose the heatmap on original image
        superimposed_img = img + jet_heatmap * alpha
        
        # superimposed_img = keras.utils.array_to_img(superimposed_img)
        superimposed_img = keras.preprocessing.image.array_to_img(superimposed_img)
        return superimposed_img
        
    def plot_stuff(self, inputs, features_a, features_b):
        plt.figure(figsize=(25, 25))
        
        plt.subplot(1, 3, 1)
        plt.axis('off')
        plt.imshow(tf.squeeze(inputs/255, axis=0))
        plt.title('Input')
        
        plt.subplot(1, 3, 2)
        plt.axis('off')
        plt.imshow(features_a)
        plt.title('BaseModule')
        
        plt.subplot(1, 3, 3)
        plt.axis('off')
        plt.imshow(features_b)
        plt.title('HeadModule (CAN)')
        plt.show() 
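
The arithmetic inside `make_gradcam_heatmap` can be followed on plain arrays. The sketch below is a simplified NumPy version for a single feature map without the batch dimension; the function name `gradcam_heatmap` and the toy numbers are assumptions for illustration, and the gradients are supplied by hand rather than computed by a tape.

```python
import numpy as np

def gradcam_heatmap(feature_map, grads):
    """Grad-CAM on plain arrays.

    feature_map, grads: arrays of shape (H, W, C); `grads` is the gradient of
    the class score with respect to `feature_map`.
    """
    # One weight per channel: the spatial average of its gradients.
    weights = grads.mean(axis=(0, 1))          # shape (C,)
    # Weighted sum of the channels gives a single (H, W) map.
    heatmap = feature_map @ weights
    # Keep only positive evidence, then scale to [0, 1].
    heatmap = np.maximum(heatmap, 0.0)
    peak = heatmap.max()
    if peak <= 0.0:
        # Nothing supports the class: return zeros instead of dividing by zero.
        return np.zeros_like(heatmap)
    return heatmap / peak

# Toy 2x2 feature map with 2 channels (made-up numbers).
feature_map = np.arange(8, dtype=float).reshape(2, 2, 2)
grads = np.ones((2, 2, 2))
print(gradcam_heatmap(feature_map, grads))
```

Each output cell is the gradient-weighted channel sum at that location divided by the largest such sum, so the brightest location is always 1.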

In [None]:
def get_callbacks():
    # save model checkpoint based on monitored metrics 
    checkpoint = keras.callbacks.ModelCheckpoint('model.h5', 
                                                 save_best_only    = True, 
                                                 save_weights_only = True,
                                                 monitor  = 'val_auc', 
                                                 mode     = 'max')
    # stop training safely if nan loss occurs
    stop_if_nan = keras.callbacks.TerminateOnNaN()
    
    # save training logs for post processing or post eda 
    save_log = keras.callbacks.CSVLogger('history.csv')
    
    return [checkpoint, stop_if_nan, save_log, GradCAMCallback(epoch_interval=4)]
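
To see which epochs trigger the plots with `epoch_interval=4`, the condition used in `on_epoch_end` can be checked in isolation. This is a small sketch; the helper name `fires_on` is hypothetical and only mirrors the `epoch and epoch % epoch_interval == 0` test.

```python
def fires_on(epoch, epoch_interval):
    """Mirror of the check in `on_epoch_end`: skip epoch 0, then every N-th epoch."""
    return bool(epoch) and epoch % epoch_interval == 0

# Keras passes zero-based epoch indices to `on_epoch_end`.
fired = [e for e in range(13) if fires_on(e, epoch_interval=4)]
print(fired)  # [4, 8, 12]
```

Because the indices are zero-based, these correspond to the ends of the 5th, 9th and 13th epochs as shown in the training progress output.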

# <font color = "seagreen">Compile & Training</font>

In `compile`, we pass the `learning_rate` schedule built with `WarmupLearningRateSchedule` and apply label smoothing in the loss. Since this is a multi-label classification task and the competition metric is `AUC`, we use binary cross-entropy as the loss and a multi-label `AUC` as the metric.
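
As a rough illustration of what label smoothing does to binary targets, the sketch below applies the same transformation Keras uses for `BinaryCrossentropy(label_smoothing=...)`, namely `y * (1 - s) + 0.5 * s`, in plain Python. The helper names and the smoothing value `0.1` are assumptions for illustration; the notebook's `LABEL_SMOOTH` may differ.

```python
import math

def smooth_labels(y_true, label_smoothing):
    """Pull hard 0/1 targets towards 0.5 by the given amount."""
    return [y * (1.0 - label_smoothing) + 0.5 * label_smoothing for y in y_true]

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over a flat list of targets and probabilities."""
    losses = []
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        losses.append(-(y * math.log(p) + (1.0 - y) * math.log(1.0 - p)))
    return sum(losses) / len(losses)

y_true = [1.0, 0.0, 0.0]
y_pred = [0.9, 0.2, 0.1]
smoothed = smooth_labels(y_true, label_smoothing=0.1)  # illustrative value
print(smoothed)
print(binary_crossentropy(y_true, y_pred), binary_crossentropy(smoothed, y_pred))
```

The smoothed targets sit at roughly 0.95 and 0.05, so confident predictions are penalised slightly more than with hard labels, which discourages over-confident outputs.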

In [None]:
model.compile(
    optimizer = keras.optimizers.Adam(learning_rate),
    loss = keras.losses.BinaryCrossentropy(label_smoothing=LABEL_SMOOTH), 
    metrics = [keras.metrics.AUC(multi_label=True)],
    steps_per_execution=REPLICAS
)

if DEVICE == "TPU":
    history = model.fit(
        train_datasets, 
        epochs    = EPOCHS,
        verbose   = VERBOSITY,
        callbacks = get_callbacks(),
        steps_per_epoch  = TRAIN_STEPS_PER_EPOCH,
        validation_data  = valid_datasets,
        validation_steps = VALID_STEPS_PER_EPOCH
    )
else:
    print('Please use TPU for training.\nReloading from saved weights for Inference.')
    model.load_weights('../input/multiattentioncheckwg/model_new.h5')
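
The metric monitored during training, `AUC(multi_label=True)`, computes an AUC per label and averages them, which matches the shape of the competition score. Keras approximates each AUC with a fixed set of thresholds, so the value can differ slightly from an exact computation. The sketch below computes the exact column-wise mean AUC in plain Python on made-up data; the function names are hypothetical.

```python
def binary_auc(y_true, y_score):
    """Exact AUC: fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC is undefined without both classes")
    # Ties between a positive and a negative score count as half a win.
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def mean_columnwise_auc(y_true_rows, y_score_rows):
    """Average the per-label AUCs over all label columns."""
    n_labels = len(y_true_rows[0])
    aucs = [
        binary_auc([row[j] for row in y_true_rows],
                   [row[j] for row in y_score_rows])
        for j in range(n_labels)
    ]
    return sum(aucs) / n_labels

# Four samples, two labels (toy data, not competition data).
y_true  = [[1, 0], [0, 1], [1, 1], [0, 0]]
y_score = [[0.9, 0.2], [0.1, 0.8], [0.4, 0.3], [0.6, 0.1]]
print(mean_columnwise_auc(y_true, y_score))  # 0.875
```

Here the first label scores 0.75 and the second 1.0, giving a mean of 0.875.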

# <font color = "seagreen">Inference</font>

In [None]:
if FILE_TYPE != 'TF_RECORD':
    test_decoder = build_decoder(with_labels=False, target_size=IMG_SIZE)
    test_dataset = jpeg_loader(
        TEST_IMG_PATH, 
        bsize     = BATCH_SIZE, 
        repeat    = False, 
        shuffle   = False, 
        augment   = False, 
        cache     = False, 
        decode_fn = test_decoder
    )
    num_files = len(TEST_IMG_PATH)
    pred_step = int(np.ceil(num_files / float(BATCH_SIZE)))

else:
    tfrecords_schema = {
        "StudyInstanceUID" : tf.io.FixedLenFeature([], tf.string),
        "image"            : tf.io.FixedLenFeature([], tf.string)
    }
    test_dataset = tfrecords_loader(
            TEST_IMG_PATH, 
            tfschemes   = tfrecords_schema,
            shuffle     = False,
            repeat      = False,
            augment     = False,
            ignore_order= False,
            inference   = True
        )
    num_files = sum(1 for record in tf.data.TFRecordDataset(TEST_IMG_PATH))
    pred_step = int(np.ceil(num_files / float(BATCH_SIZE / REPLICAS)))
    
    
image_batch = next(iter(test_dataset))
print(image_batch.shape)
show_batch(image_batch.numpy(), title='Test Sets')
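
The `pred_step` computation above rounds up so that the final, partially filled batch is still predicted. A minimal sketch of that arithmetic, with a hypothetical helper name and made-up counts:

```python
import math

def prediction_steps(num_examples, batch_size):
    """Number of batches needed to cover every example exactly once."""
    return math.ceil(num_examples / batch_size)

# Illustrative numbers: 1000 images in batches of 64.
steps = prediction_steps(1000, 64)
leftover = 1000 - (steps - 1) * 64
print(steps, leftover)  # 16 40
```

Rounding down instead would give 15 steps and silently drop the last 40 images from the submission.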

In [None]:
if DEVICE != 'TPU':
    # This is a code competition: to submit, internet access must be disabled,
    # so model inference has to run on GPU or CPU without internet.
    SUBMIT[SUBMIT.columns[1:]] = model.predict(test_dataset, steps=pred_step, verbose=1)
    SUBMIT.to_csv('submission11.csv', index=False)
    display(SUBMIT.head())

# Final Note

1. The best score the **multi-branch soft attention net** achieved is `0.95656` on the test dataset, without any ensembling or test-time augmentation. With more extensive experiments and careful integration, it can likely achieve better results while still producing strong visual interpretations.
2. How to use it on my own dataset?
    - First, understand the competition task and its data format, and relate them to your own.
    - Second, run this notebook successfully on the competition data.
    - Lastly, replace the dataset with yours.
3. The competition data also provides segmentation annotations for potential segmentation modeling, and most of the top solutions used them. Check out the following series of notebooks by [tt195361](https://www.kaggle.com/tt195361), who reproduces the 1st place solution in `TensorFlow.Keras`.
    - [Make Mask](https://www.kaggle.com/tt195361/ranzcr-1st-place-solution-by-tf-1-make-masks) - [Segmentation Model](https://www.kaggle.com/tt195361/ranzcr-1st-place-solution-by-tf-2-seg-model)
    - [Gen Mask](https://www.kaggle.com/tt195361/ranzcr-1st-place-solution-by-tf-3-gen-masks) - [Cls Model](https://www.kaggle.com/tt195361/ranzcr-1st-place-solution-by-tf-4-cls-model)
    - [Inference](https://www.kaggle.com/tt195361/ranzcr-1st-place-solution-by-tf-5-inference)