In [1]:
!pip install efficientnet -q



In [2]:
import os

import efficientnet.tfkeras as efn
import numpy as np
import pandas as pd
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split
import tensorflow as tf
from sklearn.model_selection import GroupKFold
import math

In [3]:
import tensorflow_hub as hub

In [4]:
def random_float(minval=0.0, maxval=1.0):
    """Draw a single float32 scalar tensor uniformly between minval and maxval."""
    return tf.random.uniform(
        shape=[], minval=minval, maxval=maxval, dtype=tf.float32)

def choice(p, image1, mask1, image2, mask2):
    """Return (image1, mask1) when a random draw is <= p, else (image2, mask2).

    One draw decides both outputs, so the image and its mask are always
    taken from the same candidate pair.
    """
    take_first = random_float() <= p
    return (tf.where(take_first, image1, image2),
            tf.where(take_first, mask1, mask2))

In [5]:
# Square image side length in pixels read by the cutout mask helpers below;
# keep it equal to the target size the images are decoded to for training.
image_size = 768

In [6]:
def HorizontalFlip(p):
    """Build an augmenter that mirrors image and mask left-right with probability p."""
    def _do_horizontal_flip(image, mask):
        flipped = (tf.image.flip_left_right(image),
                   tf.image.flip_left_right(mask))
        return choice(p, *flipped, image, mask)
    return _do_horizontal_flip
def RandomBrightness(max_delta, p):
    """Build an augmenter that randomly shifts image brightness with probability p.

    The mask is passed through unchanged.
    """
    def _do_random_brightness(image, mask):
        brightened = tf.image.random_brightness(image, max_delta)
        return choice(p, brightened, mask, image, mask)
    return _do_random_brightness
def RandomContrast(lower, upper, p):
    """Build an augmenter that randomly rescales image contrast with probability p.

    `lower` and `upper` are forwarded to tf.image.random_contrast; the mask is
    passed through unchanged.
    """
    def _do_random_contrast(image, mask):
        recontrasted = tf.image.random_contrast(image, lower, upper)
        return choice(p, recontrasted, mask, image, mask)
    return _do_random_contrast
def initUndistortRectifyMap(height, width, k, dx, dy):
    """Build sampling maps for a radial (lens-like) distortion.

    Args:
        height, width: image size; cast to float32 here.
        k: distortion coefficient, applied to both the r^2 and the r^4 term.
        dx, dy: offsets added to the centre the distorted coordinates are
            mapped back around.

    Returns:
        (map_x, map_y) float32 grids of source coordinates.
    """
    h = tf.cast(height, dtype=tf.float32)
    w = tf.cast(width, dtype=tf.float32)

    # Centre used when projecting the distorted coordinates back to pixels.
    src_cx = w * 0.5 + dx
    src_cy = h * 0.5 + dy
    # Centre used when normalising the output pixel grid.
    dst_cx = (w - 1.0) * 0.5
    dst_cy = (h - 1.0) * 0.5

    # NOTE(review): tf.meshgrid defaults to 'xy' indexing, so these grids are
    # laid out [width, height]; callers reshape to [height, width] afterwards,
    # which only lines up for square images -- confirm before using non-square
    # inputs.
    v, u = tf.meshgrid(
        tf.range(h, dtype=tf.float32),
        tf.range(w, dtype=tf.float32))

    # Normalised coordinates (the image size acts as the focal length).
    x_norm = (u - dst_cx) / w
    y_norm = (v - dst_cy) / h

    r_2 = x_norm * x_norm + y_norm * y_norm
    r_4 = r_2 * r_2
    radial = 1 + k*r_2 + k*r_4

    map_x = x_norm * radial * w + src_cx
    map_y = y_norm * radial * h + src_cy
    return map_x, map_y

def OpticalDistortion(distort_limit, shift_limit, p=1.0):
    """Build an augmenter applying a random radial distortion with probability p.

    The coefficient is drawn from [-distort_limit, distort_limit) and the
    centre offsets from [-shift_limit, shift_limit); image and mask are warped
    with the same maps, mirroring at the borders.
    """
    def _do_optical_distortion(image, mask):
        k = random_float(-distort_limit, distort_limit)
        dx = random_float(-shift_limit, shift_limit)
        dy = random_float(-shift_limit, shift_limit)

        shape = tf.shape(image)
        height, width = shape[0], shape[1]
        map_x, map_y = initUndistortRectifyMap(height, width, k, dx, dy)

        def _warp(tensor):
            return remap(tensor, height, width, map_x, map_y, mode='mirror')

        return choice(p, _warp(image), _warp(mask), image, mask)
    return _do_optical_distortion

def make_grid_distorted_maps(height, width, num_steps, xsteps, ysteps):
    """Build sampling maps for a piecewise-linear grid distortion.

    Each axis is cut into `num_steps` cells of `size // num_steps` output
    pixels plus a remainder cell. Cell j covers a source span of
    `cell_size * steps[j]`, so factors above 1 squeeze more source pixels into
    the cell and factors below 1 stretch fewer.

    Args:
        height, width: integer image size.
        num_steps: number of equally sized cells per axis.
        xsteps, ysteps: per-cell stretch factors; only the first `num_steps`
            entries of each are used (the last entry is sliced off).

    Returns:
        (map_x, map_y) grids built with tf.meshgrid from the two 1-D maps.
    """
    def _make_maps_before_last(size, step, steps): # size=512, step=102,
                                                   # steps.shape=[num_steps]
        """Return the concatenated 1-D source positions of the regular cells."""
        step_rep = tf.repeat(step, num_steps)  # [102, 102, 102, 102, 102]
        step_rep_f = tf.cast(step_rep, dtype=tf.float32)
        step_inc = step_rep_f * steps          # [102*s_0, ..., 102*s_4]
        cur = tf.math.cumsum(step_inc)         # [si_0, si_0 + si_1, ... ]
        zero = tf.zeros([1], dtype=tf.float32)
        prev = tf.concat([ zero, cur[ :-1] ], axis=0) # [0, c_0, ..., c_3]
        prev_cur = tf.stack([prev, cur])       # [[p_0, p_1, ...], [c_0, c_1, ...]]
        ranges = tf.transpose(prev_cur)        # [[p_0, c_0], [p_1, c_1], ... ]

        def _linspace_range(rng):
            # `step` evenly spaced source positions across one cell's span.
            return tf.linspace(rng[0], rng[1], step)
 
        maps_stack = tf.map_fn(_linspace_range, ranges)
        maps = tf.reshape(maps_stack, [-1])    # [-1] flatten into 1-D
        return maps
    
    def _make_last_map(size, step, last_start):
        """Return source positions for the remainder cell, ending at size - 1."""
        # NOTE(review): when size is an exact multiple of num_steps, last_step
        # is 0 -- confirm tf.linspace accepts num=0 before relying on that case.
        last_step = size - step * num_steps  # 512 - 102*5 = 2 
        size_f = tf.cast(size, dtype=tf.float32)
        last_map = tf.linspace(last_start, size_f-1.0, last_step)
        return last_map
    
    def _make_distorted_map(size, steps):
        """Return the full 1-D source-position map (length `size`) for one axis."""
        step = size // num_steps               # step=102 
        maps_before_last = _make_maps_before_last(size, step, steps[ :-1 ])
        # The remainder cell starts where the last regular cell ended.
        last_map = _make_last_map(size, step, maps_before_last[-1])
        distorted_map = tf.concat([maps_before_last, last_map], axis=0)
        return distorted_map

    xx = _make_distorted_map(width, xsteps)
    yy = _make_distorted_map(height, ysteps)
    # The grid built from xx is bound to map_y and the one built from yy to
    # map_x; the pair is then returned as (map_x, map_y).
    map_y, map_x = tf.meshgrid(xx, yy)
    return map_x, map_y

def GridDistortion(num_steps, distort_limit, p=1.0):
    """Build an augmenter applying a random grid distortion with probability p.

    Per-cell stretch factors are drawn from
    [1 - distort_limit, 1 + distort_limit) independently for each axis; image
    and mask are warped with the same maps, mirroring at the borders.
    """
    def _do_grid_distortion(image, mask):
        def _draw_steps():
            return tf.random.uniform(
                [num_steps + 1],
                minval=1.0 - distort_limit,
                maxval=1.0 + distort_limit)

        xsteps = _draw_steps()
        ysteps = _draw_steps()

        shape = tf.shape(image)
        height, width = shape[0], shape[1]
        map_x, map_y = make_grid_distorted_maps(
            height, width, num_steps, xsteps, ysteps)

        def _warp(tensor):
            return remap(tensor, height, width, map_x, map_y, mode='mirror')

        return choice(p, _warp(image), _warp(mask), image, mask)
    return _do_grid_distortion
def OneOf(trans1, trans2, p):
    """Build an augmenter that applies trans1 or trans2 (50/50) with probability p.

    Both transforms are always evaluated; the random draws only select which
    result is kept.
    """
    def _do_one_of(image, mask):
        first = trans1(image, mask)
        second = trans2(image, mask)
        picked = choice(0.5, *first, *second)
        return choice(p, *picked, image, mask)
    return _do_one_of
def HueSaturationValue(
        hue_shift_limit, sat_shift_limit, val_shift_limit, p):
    """Build an augmenter that randomly shifts hue, saturation and value.

    Each shift is drawn from [-limit, limit). Hue wraps around modulo 1,
    saturation and value are clipped to [0, 1]. The mask is passed through
    unchanged.
    """
    def _do_hue_saturation_value(image, mask):
        hsv = tf.image.rgb_to_hsv(image)
        hue_shift = random_float(-hue_shift_limit, hue_shift_limit)
        sat_shift = random_float(-sat_shift_limit, sat_shift_limit)
        val_shift = random_float(-val_shift_limit, val_shift_limit)

        hue = (hsv[..., :1] + hue_shift) % 1.0
        sat = tf.clip_by_value(hsv[..., 1:2] + sat_shift, 0.0, 1.0)
        val = tf.clip_by_value(hsv[..., 2:] + val_shift, 0.0, 1.0)

        shifted_rgb = tf.image.hsv_to_rgb(
            tf.concat([hue, sat, val], axis=-1))
        return choice(p, shifted_rgb, mask, image, mask)
    return _do_hue_saturation_value
def affine_transform(height, width, tx, ty, z, theta):
    """Build flattened sampling maps for a rotate / shift / zoom transform.

    Args:
        height, width: image size as floats.
        tx, ty: translation in pixels.
        z: zoom factor (coordinates are scaled by 1 / z).
        theta: rotation angle in degrees.

    Returns:
        (map_x, map_y): 1-D tensors of transformed coordinates, one entry per
        grid point.
    """
    cx = (width - 1.0) * 0.5
    cy = (height - 1.0) * 0.5

    def _mat3(rows):
        return tf.convert_to_tensor(rows, dtype=tf.float32)

    angle = -2.0 * math.pi * theta / 360.0
    cos_a = tf.math.cos(angle)
    sin_a = tf.math.sin(angle)
    inv_zoom = 1.0 / z

    to_origin = _mat3([
        [1.0, 0.0, -cx],
        [0.0, 1.0, -cy],
        [0.0, 0.0, 1.0]])
    rotation = _mat3([
        [cos_a, sin_a, 0.0],
        [-sin_a, cos_a, 0.0],
        [0.0, 0.0, 1.0]])
    from_origin = _mat3([
        [1.0, 0.0, cx - tx],
        [0.0, 1.0, cy - ty],
        [0.0, 0.0, 1.0]])
    zoom = _mat3([
        [inv_zoom, 0.0, 0.0],
        [0.0, inv_zoom, 0.0],
        [0.0, 0.0, 1.0]])

    # Compose right-to-left: centre -> rotate -> un-centre/shift -> zoom.
    # NOTE(review): the zoom is applied after shifting back, i.e. about the
    # coordinate origin rather than the image centre -- confirm this is intended.
    transform = to_origin
    for step in (rotation, from_origin, zoom):
        transform = tf.linalg.matmul(step, transform)

    grid_y, grid_x = tf.meshgrid(
        tf.range(height, dtype=tf.float32),
        tf.range(width, dtype=tf.float32))
    flat_x = tf.reshape(grid_x, [-1])
    flat_y = tf.reshape(grid_y, [-1])
    homogeneous = tf.stack([flat_x, flat_y, tf.ones_like(flat_x)])

    mapped = tf.linalg.matmul(transform, homogeneous)
    return mapped[0], mapped[1]

def ShiftScaleRotate(
        shift_limit, scale_limit, rotate_limit, p):
    """Build an augmenter applying a random shift, zoom and rotation with probability p.

    Shifts are fractions of the image size drawn from
    [-shift_limit, shift_limit), the zoom from
    [1 - scale_limit, 1 + scale_limit) and the angle (degrees) from
    [-rotate_limit, rotate_limit). Pixels sampled from outside the image
    become 0.
    """
    def _do_shift_scale_rotate(image, mask):
        shape = tf.shape(image)
        height_i, width_i = shape[0], shape[1]
        height_f = tf.cast(height_i, dtype=tf.float32)
        width_f = tf.cast(width_i, dtype=tf.float32)

        tx = width_f * random_float(-shift_limit, shift_limit)
        ty = height_f * random_float(-shift_limit, shift_limit)
        z = random_float(1.0 - scale_limit, 1.0 + scale_limit)
        theta = random_float(-rotate_limit, rotate_limit)

        map_x, map_y = affine_transform(
            height_f, width_f, tx, ty, z, theta)

        def _warp(tensor):
            return remap(
                tensor, height_i, width_i, map_x, map_y, mode='constant')

        return choice(p, _warp(image), _warp(mask), image, mask)
    return _do_shift_scale_rotate

def randints(shape, minval, maxval):
    """Draw int32 values of the given shape uniformly from minval..maxval inclusive."""
    # tf.random.uniform excludes its upper bound, so extend it by one to make
    # `maxval` itself a possible result.
    exclusive_upper = maxval + 1
    return tf.random.uniform(
        shape=shape, minval=minval, maxval=exclusive_upper, dtype=tf.int32)

def make_range_masks(size, starts, ends):
    """Return a boolean [len(starts), size] mask, True where start <= index <= end.

    Row i marks the closed index range [starts[i], ends[i]] along an axis of
    length `size`.
    """
    positions = tf.range(size, dtype=tf.int32)[tf.newaxis, :]
    lower = starts[:, tf.newaxis]
    upper = ends[:, tf.newaxis]
    return (lower <= positions) & (positions <= upper)

def make_region_mask(tops, lefts, bottoms, rights, height=None, width=None):
    """Return a boolean [height, width, 1] mask covering the union of rectangles.

    Args:
        tops, lefts, bottoms, rights: int32 vectors with one entry per
            rectangle; bounds are inclusive.
        height, width: size of the mask. Both default to the module-level
            `image_size`, which keeps existing four-argument calls working.

    Returns:
        Boolean tensor that is True inside any of the rectangles.
    """
    if height is None:
        height = image_size
    if width is None:
        width = image_size

    row_masks = make_range_masks(height, tops, bottoms)
    col_masks = make_range_masks(width, lefts, rights)
    # Outer "and" of row and column ranges gives one rectangle per cut.
    region_masks = \
        row_masks[ : , : , tf.newaxis ] & \
        col_masks[ : , tf.newaxis, : ]
    region_mask = tf.math.reduce_any(region_masks, axis=0)
    region_mask = region_mask[ : , : , tf.newaxis]
    return region_mask

def Cutout(num_cuts, mask_factor, p):
    """Build an augmenter that zeroes `num_cuts` random rectangles with probability p.

    Each rectangle spans roughly `mask_factor` of the image height and width.
    Cut centres and the mask size come from the actual image shape, so the
    augmenter does not depend on the global `image_size`. The mask argument is
    passed through unchanged.
    """
    def _do_cutout(image, mask):
        image_shape = tf.shape(image)
        height_i = image_shape[0]
        width_i = image_shape[1]
        height_f = tf.cast(height_i, dtype=tf.float32)
        width_f = tf.cast(width_i, dtype=tf.float32)
        cut_h = tf.cast(height_f * mask_factor, dtype=tf.int32)
        cut_w = tf.cast(width_f * mask_factor, dtype=tf.int32)

        # Centres are drawn inside the real image, matching the clamping below.
        y_centers = randints([num_cuts], 0, height_i - 1)
        x_centers = randints([num_cuts], 0, width_i - 1)
        tops = tf.math.maximum(y_centers - cut_h//2, 0)
        lefts = tf.math.maximum(x_centers - cut_w//2, 0)
        bottoms = tf.math.minimum(tops + cut_h, height_i - 1)
        rights = tf.math.minimum(lefts + cut_w, width_i - 1)

        cut_region = make_region_mask(
            tops, lefts, bottoms, rights, height=height_i, width=width_i)
        mask_value = tf.constant(0.0, dtype=tf.float32)
        aug_image = tf.where(cut_region, mask_value, image)
        return choice(p, aug_image, mask, image, mask)
    return _do_cutout

In [7]:
 def mirror_boundary(v, max_v):
    """Reflect coordinates back into [0, max_v - 1].

    Worked example for max_v = 512: v is wrapped into [0, 1022), shifted to
    [-511, 511), its absolute value is negated to land in [-511, 0], and
    adding 511 yields a position in [0, 511].
    """
    period = max_v*2.0 - 2.0
    distance = tf.math.abs(v % period - (max_v - 1.0))
    return -1.0 * distance + max_v - 1.0

def clip_boundary(v, max_v):
    """Clamp coordinates into the valid range [0, max_v - 1]."""
    return tf.clip_by_value(
        v, clip_value_min=0.0, clip_value_max=max_v-1.0)

def interpolate_bilinear(image, map_x, map_y):
    """Sample `image` at fractional coordinates using bilinear interpolation.

    Args:
        image: tensor indexed by the two leading axes.
        map_x, map_y: float grids of equal shape. After truncation to int32,
            map_x becomes the first and map_y the second component of the
            tf.gather_nd index.

    Returns:
        Tensor with the grid's shape plus the image's trailing axes.
    """
    def _gather(coord_x, coord_y):
        indices = tf.cast(
            tf.stack([coord_x, coord_y], axis=-1), dtype=tf.int32)
        return tf.gather_nd(image, indices)

    x_lo, x_hi = tf.math.floor(map_x), tf.math.ceil(map_x)
    y_lo, y_hi = tf.math.floor(map_y), tf.math.ceil(map_y)

    ll = _gather(x_lo, y_lo)
    lr = _gather(x_hi, y_lo)
    ul = _gather(x_lo, y_hi)
    ur = _gather(x_hi, y_hi)

    # Fractional parts act as blend weights; the trailing axis broadcasts
    # over channels.
    weight_x = tf.expand_dims(map_x % 1.0, axis=-1)
    weight_y = tf.expand_dims(map_y % 1.0, axis=-1)

    blend_lo = (lr - ll) * weight_x + ll
    blend_hi = (ur - ul) * weight_x + ul
    return (blend_hi - blend_lo) * weight_y + blend_lo

def remap(image, height, width, map_x, map_y, mode):
    """Resample `image` at the coordinates given by map_x / map_y.

    Args:
        image: tensor to sample from.
        height, width: integer grid size; both maps are reshaped to
            [height, width].
        map_x, map_y: source coordinates (any shape with height * width
            elements).
        mode: 'mirror' reflects out-of-range coordinates back into the image;
            'constant' clamps them for sampling and then writes 0 wherever
            the original coordinate fell outside the image.

    Raises:
        AssertionError: if mode is neither 'mirror' nor 'constant'.
    """
    assert mode in ('mirror', 'constant'), \
        "mode is neither 'mirror' nor 'constant'"

    height_f = tf.cast(height, dtype=tf.float32)
    width_f = tf.cast(width, dtype=tf.float32)
    grid_shape = [height, width]
    map_x = tf.reshape(map_x, shape=grid_shape)
    map_y = tf.reshape(map_y, shape=grid_shape)

    to_valid = mirror_boundary if mode == 'mirror' else clip_boundary
    image_remap = interpolate_bilinear(
        image, to_valid(map_x, width_f), to_valid(map_y, height_f))

    if mode == 'constant':
        # Blank out every output pixel whose unclamped source coordinate lies
        # outside the image.
        inside_boundary = (
            (0.0 <= map_x) & (map_x < width_f) &
            (0.0 <= map_y) & (map_y < height_f))
        image_remap = tf.where(
            inside_boundary[:, :, tf.newaxis], image_remap, 0.0)

    return image_remap

In [8]:
# Augmenter instances used by the training pipeline. Each is a callable
# taking (image, mask) and returning an (image, mask) pair; `p` is the
# probability that the augmentation is applied.
horizontal_flip = HorizontalFlip(p=0.5)
random_brightness = RandomBrightness(max_delta=0.2, p=0.75)
# NOTE(review): both contrast bounds are below 1, which presumably only ever
# lowers contrast -- confirm against tf.image.random_contrast.
random_contrast = RandomContrast(lower=0.2, upper=0.8, p=0.75)
optical_distortion = OpticalDistortion(
    distort_limit=1.0, shift_limit=0.05, p=0.75)
grid_distortion = GridDistortion(
    num_steps=5, distort_limit=1.0, p=0.75)
# Picks one of the two distortions; each inner augmenter also applies its own
# p, on top of the p given here.
one_of_opt_grid_distortion = OneOf(
    optical_distortion, grid_distortion, p=0.75)


hue_saturation_value = HueSaturationValue(
    hue_shift_limit=0.2, sat_shift_limit=0.3,
    val_shift_limit=0.2, p=0.75)


shift_scale_rotate = ShiftScaleRotate(
    shift_limit=0.2, scale_limit=0.3, rotate_limit=30, p=0.75)


# A single cut spanning 40% of each image dimension.
cut_out = Cutout(num_cuts=1, mask_factor=0.4, p=0.75)

In [9]:
def auto_select_accelerator():
    """Return a TPU distribution strategy when available, else the default one.

    Any ValueError raised while resolving, connecting to or initialising the
    TPU triggers the fallback to tf.distribute.get_strategy(). The number of
    replicas is printed in both cases.
    """
    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        print("Running on TPU:", resolver.master())
    except ValueError:
        strategy = tf.distribute.get_strategy()

    print(f"Running on {strategy.num_replicas_in_sync} replicas")
    return strategy


def build_decoder(with_labels=True, target_size=(256, 256), ext='jpg'):
    """Build a function that loads an image file into a float32 tensor.

    Args:
        with_labels: when True the returned function takes (path, label) and
            returns (image, label); otherwise it takes and returns one value.
        target_size: (height, width) the decoded image is resized to.
        ext: file format, one of 'png', 'jpg' or 'jpeg'.

    Returns:
        The decoding function, suitable for `Dataset.map`.

    Raises:
        ValueError: if `ext` is not a supported format. This is raised as
            soon as the decoder is built rather than later while the dataset
            pipeline is being traced.
    """
    if ext not in ('png', 'jpg', 'jpeg'):
        raise ValueError("Image extension not supported")

    def decode(path):
        file_bytes = tf.io.read_file(path)

        if ext == 'png':
            img = tf.image.decode_png(file_bytes, channels=3)
        else:
            img = tf.image.decode_jpeg(file_bytes, channels=3)
        # Scale 8-bit pixel values to [0, 1] before resizing.
        img = tf.cast(img, tf.float32) / 255.0
        img = tf.image.resize(img, target_size)

        return img

    def decode_with_labels(path, label):
        return decode(path), label

    return decode_with_labels if with_labels else decode


def build_augmenter(with_labels=True):
    """Build the training augmentation function.

    Args:
        with_labels: when True the returned function takes (img, label) and
            returns (augmented_img, label); otherwise it takes and returns
            the image only.
    """
    def augment(img):
        # The augmenters expect an (image, mask) pair. There is no real mask
        # here, so the un-augmented input image is passed as a placeholder
        # and the returned mask is discarded.
        mask = img
        pipeline = (
            horizontal_flip,
            random_brightness,
            random_contrast,
            one_of_opt_grid_distortion,
            hue_saturation_value,
            shift_scale_rotate,
            cut_out,
        )
        for transform in pipeline:
            img, _ = transform(img, mask)
        return img

    def augment_with_labels(img, label):
        return augment(img), label

    if with_labels:
        return augment_with_labels
    return augment


def build_dataset(paths, labels=None, bsize=128, cache=True,
                  decode_fn=None, augment_fn=None,
                  augment=True, repeat=True, shuffle=1024,
                  cache_dir=""):
    """Assemble a batched tf.data pipeline from image paths (and labels).

    Stages, in order: decode -> cache -> augment -> repeat -> shuffle ->
    batch -> prefetch. Each optional stage is skipped when its flag is falsy.

    Args:
        paths: image file paths.
        labels: optional labels aligned with `paths`.
        bsize: batch size.
        cache: whether to cache decoded images; `cache_dir` is handed to
            `Dataset.cache` and is created first when it is non-empty.
        decode_fn, augment_fn: overrides for the default decoder/augmenter.
        augment, repeat: stage switches.
        shuffle: shuffle buffer size, or a falsy value to disable shuffling.
        cache_dir: cache location passed to `Dataset.cache`.

    Returns:
        The resulting tf.data.Dataset.
    """
    # NOTE(review): caching and shuffling both operate on decoded images, so
    # memory use grows with image size -- confirm this fits the host.
    has_labels = labels is not None

    if cache is True and cache_dir != "":
        os.makedirs(cache_dir, exist_ok=True)
    if decode_fn is None:
        decode_fn = build_decoder(has_labels)
    if augment_fn is None:
        augment_fn = build_augmenter(has_labels)

    AUTO = tf.data.experimental.AUTOTUNE
    source = (paths, labels) if has_labels else paths

    dset = tf.data.Dataset.from_tensor_slices(source)
    dset = dset.map(decode_fn, num_parallel_calls=AUTO)
    if cache:
        dset = dset.cache(cache_dir)
    if augment:
        dset = dset.map(augment_fn, num_parallel_calls=AUTO)
    if repeat:
        dset = dset.repeat()
    if shuffle:
        dset = dset.shuffle(shuffle)

    return dset.batch(bsize).prefetch(AUTO)

In [10]:
# Name of the Kaggle dataset holding the pre-resized PNG images.
COMPETITION_NAME = "siimcovid19-512-img-png-600-study-png"
# Distribution strategy (TPU when available) used to build and train the models.
strategy = auto_select_accelerator()
# Global batch size: 16 samples per replica.
BATCH_SIZE = strategy.num_replicas_in_sync * 16
# GCS location of the dataset; image paths for training are built from it.
GCS_DS_PATH = KaggleDatasets().get_gcs_path(COMPETITION_NAME)

Running on TPU: grpc://10.0.0.2:8470
Running on 8 replicas


In [11]:
# Local mount point of the image dataset.
load_dir = f"/kaggle/input/{COMPETITION_NAME}/"
# Study-level training annotations.
df = pd.read_csv('../input/siim-covid19-detection/train_study_level.csv')
# Columns at positions 1-4 are used as the label columns
# (presumably one column per study-level class -- confirm against the CSV).
label_cols = df.columns[1:5]


In [12]:
# Assign every row a validation fold in 0..4, grouping rows by `id` so that
# rows sharing an id land in the same fold.
gkf  = GroupKFold(n_splits = 5)
df['fold'] = -1
for fold, (train_idx, val_idx) in enumerate(gkf.split(df, groups = df.id.tolist())):
    # val_idx holds positional indices; using them with .loc relies on df
    # still having the default 0..n-1 index from read_csv.
    df.loc[val_idx, 'fold'] = fold

In [13]:
from tensorflow.keras import backend as K

In [14]:
class SpatialAttentionModule(tf.keras.layers.Layer):
    """Spatial attention: turns a feature map into a one-channel [0, 1] mask.

    paper: https://arxiv.org/abs/1807.06521
    code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
    """

    def __init__(self, kernel_size=3):
        super().__init__()

        def _conv(filters, activation):
            # All four convolutions share every setting except width and
            # activation.
            return tf.keras.layers.Conv2D(
                filters, kernel_size=kernel_size,
                use_bias=False,
                kernel_initializer='he_normal',
                strides=1, padding='same',
                activation=activation)

        self.conv1 = _conv(64, tf.nn.relu6)
        self.conv2 = _conv(32, tf.nn.relu6)
        self.conv3 = _conv(16, tf.nn.relu6)
        self.conv4 = _conv(1, tf.math.sigmoid)

    def call(self, inputs):
        # Summarise the channel axis by its mean and max, stacked as two
        # channels.
        x = tf.stack(
            [tf.reduce_mean(inputs, axis=3), tf.reduce_max(inputs, axis=3)],
            axis=3)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = conv(x)
        return x
    
# A custom layer
class ChannelAttentionModule(tf.keras.layers.Layer):
    """Channel attention: produces one weight per channel, shaped (batch, 1, 1, C).

    paper: https://arxiv.org/abs/1807.06521
    code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70
    """

    def __init__(self, ratio=8):
        """
        Args:
            ratio: channel reduction factor of the bottleneck convolution.
        """
        super(ChannelAttentionModule, self).__init__()
        self.ratio = ratio
        self.gapavg = tf.keras.layers.GlobalAveragePooling2D()
        self.gmpmax = tf.keras.layers.GlobalMaxPooling2D()

    def build(self, input_shape):
        # Bottleneck of two 1x1 convolutions: C -> C // ratio -> C.
        self.conv1 = tf.keras.layers.Conv2D(input_shape[-1]//self.ratio,
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=tf.nn.relu)

        # NOTE(review): ReLU here makes the summed input to the sigmoid
        # non-negative, so weights never drop below 0.5 -- confirm this
        # matches the referenced implementation.
        self.conv2 = tf.keras.layers.Conv2D(input_shape[-1],
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=tf.nn.relu)
        super(ChannelAttentionModule, self).build(input_shape)

    def call(self, inputs):
        # Global average and max pooling, each reshaped to (batch, 1, 1, C).
        gapavg = self.gapavg(inputs)
        gmpmax = self.gmpmax(inputs)
        gapavg = tf.keras.layers.Reshape((1, 1, gapavg.shape[1]))(gapavg)
        gmpmax = tf.keras.layers.Reshape((1, 1, gmpmax.shape[1]))(gmpmax)
        # Both pooled descriptors go through the same (shared) bottleneck.
        gapavg_out = self.conv2(self.conv1(gapavg))
        gmpmax_out = self.conv2(self.conv1(gmpmax))
        return tf.math.sigmoid(gapavg_out + gmpmax_out)

    def get_output_shape_for(self, input_shape):
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        # `call` returns per-channel weights with singleton spatial axes.
        return (input_shape[0], 1, 1, input_shape[3])

# Original source: https://github.com/bfelbo/DeepMoji/blob/master/deepmoji/attlayer.py
# Adopted and modified from: https://www.kaggle.com/c/human-protein-atlas-image-classification/discussion/77269#454482
class AttentionWeightedAverage2D(tf.keras.layers.Layer):
    """Pool a 4-D feature map into (batch, channels) with learned spatial weights.

    A single weight vector scores every spatial position; the scores are
    softmax-normalised over height and width and used to average the
    features.
    """

    def __init__(self, **kwargs):
        self.init = tf.keras.initializers.get('uniform')
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [tf.keras.layers.InputSpec(ndim=4)]
        assert len(input_shape) == 4
        self.W = self.add_weight(shape=(input_shape[3], 1),
                                 name=f'{self.name}_W',
                                 initializer=self.init)
        self._trainable_weights = [self.W]
        super().build(input_shape)

    def call(self, x):
        # One score per spatial position; the reshape drops the trailing
        # singleton axis left by the dot product.
        dims = K.shape(x)
        scores = K.reshape(K.dot(x, self.W), (dims[0], dims[1], dims[2]))

        # Softmax over height and width, subtracting the max for numerical
        # stability.
        exp_scores = K.exp(scores - K.max(scores, axis=[1,2], keepdims=True))
        normaliser = K.sum(exp_scores, axis=[1,2], keepdims=True) + K.epsilon()
        att_weights = exp_scores / normaliser

        return K.sum(x * K.expand_dims(att_weights), axis=[1,2])

    def get_output_shape_for(self, input_shape):
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[3])

class RANZCRClassifier(tf.keras.Model):
    """EfficientNetB5 backbone with channel/spatial attention and a 4-way softmax head."""

    def __init__(self, dim):
        """
        Args:
            dim: currently unused; the backbone input size is fixed at
                640x640. NOTE(review): presumably meant to set the input
                size -- confirm before wiring it in.
        """
        super(RANZCRClassifier, self).__init__()
        # Backbone (randomly initialised, no classification head).
        self.Base  = efn.EfficientNetB5(
            input_shape=(640,
                         640, 3),
            weights=None,
            include_top=False)
        # Keras built-in layers
        self.GAP1 = tf.keras.layers.GlobalAveragePooling2D()
        self.GAP2 = tf.keras.layers.GlobalAveragePooling2D()
        self.BAT  = tf.keras.layers.BatchNormalization()
        self.ADD  = tf.keras.layers.Add()
        self.AVG  = tf.keras.layers.Average()
        self.DROP = tf.keras.layers.Dropout(rate=0.5)
        # Custom attention layers
        self.CAN  = ChannelAttentionModule()
        self.SPN1 = SpatialAttentionModule()
        self.SPN2 = SpatialAttentionModule()
        self.AWG  = AttentionWeightedAverage2D()
        # Tail
        self.DENS = tf.keras.layers.Dense(512, activation=tf.nn.relu)
        self.OUT  = tf.keras.layers.Dense(4,
                                          activation='softmax',
                                          #kernel_regularizer=tf.keras.regularizers.l2(0.0001),
                                          dtype=tf.float32)

    def call(self, input_tensor, training=False):
        # Backbone features
        x      = self.Base(input_tensor)
        # Attention branch 1: channel attention followed by spatial attention
        canx   = self.CAN(x)*x
        spnx   = self.SPN1(canx)*canx
        # Global average pooling of the attended features and of a second
        # spatial attention map (single channel, so one value per sample)
        gapx   = self.GAP1(spnx)
        wvgx   = self.GAP2(self.SPN2(canx))
        gapavg = self.AVG([gapx, wvgx])
        # Attention branch 2: attention-weighted average of the raw features
        awgavg = self.AWG(x)
        # Sum of both attention branches
        x = self.ADD([gapavg, awgavg])
        # Tail
        x = self.BAT(x)
        x = self.DENS(x)
        x  = self.DROP(x, training=training)
        return self.OUT(x)

    # Convenience wrapper so that model.summary() works for this subclassed model
    def build_graph(self):
        """Return a functional Model wrapping `call`, e.g. for `.summary()`.

        The input shape is taken from the backbone, so it always matches what
        `self.Base` was built with and needs no external configuration.
        """
        x = tf.keras.layers.Input(shape=self.Base.input_shape[1:])
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
    

 model = tf.keras.Sequential([
            efn.EfficientNetB7(
                input_shape=(IMSIZE[IMS], IMSIZE[IMS], 3),
                weights='imagenet',
                include_top=False),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dense(n_labels, activation='softmax')
        ])

In [15]:
# Fold indices to train in this session; the training loop below skips every
# fold that is not listed here.
trainnow = [0,1,2]

In [16]:
# Candidate square input sizes; IMS selects the one used for training.
IMSIZE = (224, 240, 260, 300, 380, 456, 528, 768)
IMS = 7

# The decoders do not depend on the fold, so build them once.
decoder = build_decoder(with_labels=True, target_size=(IMSIZE[IMS], IMSIZE[IMS]), ext='png')
test_decoder = build_decoder(with_labels=False, target_size=(IMSIZE[IMS], IMSIZE[IMS]), ext='png')

for i in range(5):
    if i not in trainnow:
        continue

    # Fold i is held out for validation; all other folds are used for training.
    is_valid_fold = df['fold'] == i
    valid_paths = GCS_DS_PATH + '/study/' + df.loc[is_valid_fold, 'id'] + '.png'
    train_paths = GCS_DS_PATH + '/study/' + df.loc[~is_valid_fold, 'id'] + '.png'
    valid_labels = df.loc[is_valid_fold, label_cols].values
    train_labels = df.loc[~is_valid_fold, label_cols].values

    train_dataset = build_dataset(
        train_paths, train_labels, bsize=BATCH_SIZE, decode_fn=decoder
    )

    valid_dataset = build_dataset(
        valid_paths, valid_labels, bsize=BATCH_SIZE, decode_fn=decoder,
        repeat=False, shuffle=False, augment=False
    )

    # One output unit per label column; a 1-D label array means one output.
    n_labels = train_labels.shape[1] if train_labels.ndim > 1 else 1

    with strategy.scope():

        model = tf.keras.Sequential([
            efn.EfficientNetB5(
                input_shape=(IMSIZE[IMS], IMSIZE[IMS], 3),
                weights='imagenet',
                include_top=False),
            tf.keras.layers.GlobalAveragePooling2D(),
            tf.keras.layers.Dropout(.1),
            tf.keras.layers.Dense(n_labels, activation='softmax')
        ])

        model.compile(
            optimizer=tf.keras.optimizers.Adam(),
            loss='categorical_crossentropy',
            metrics=[tf.keras.metrics.AUC(multi_label=True)])

        model.summary()

    # The training dataset repeats indefinitely, so an epoch is a fixed step count.
    steps_per_epoch = train_paths.shape[0] // BATCH_SIZE
    checkpoint = tf.keras.callbacks.ModelCheckpoint(
        f'model{i}.h5', save_best_only=True, monitor='val_loss', mode='min')
    lr_reducer = tf.keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", patience=3, min_lr=1e-6, mode='min')

    history = model.fit(
        train_dataset,
        epochs=20,
        verbose=1,
        callbacks=[checkpoint, lr_reducer],
        steps_per_epoch=steps_per_epoch,
        validation_data=valid_dataset)

    # Keep the per-epoch metrics of this fold next to its checkpoint.
    hist_df = pd.DataFrame(history.history)
    hist_df.to_csv(f'history{i}.csv')

Downloading data from https://github.com/Callidior/keras-applications/releases/download/efficientnet/efficientnet-b5_weights_tf_dim_ordering_tf_kernels_autoaugment_notop.h5
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
efficientnet-b5 (Functional) (None, 24, 24, 2048)      28513520  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 4)                 8196      
Total params: 28,521,716
Trainable params: 28,348,980
Non-trainable params: 172,736
_________________________________________________________________
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/