In [None]:
# Run configuration. EPOCHS, NUM_RETRIES, LEARNING_RATE, FEATURE_MAPPING, DEBUG
# and LR_REDUCE are not referenced in the cells visible here; presumably they
# are consumed by later training cells -- TODO confirm.
EPOCHS = 300
NUM_RETRIES = 10 
# Modalities kept by deserialize_example; the trailing comments list the
# single-modality alternatives.
MODALITIES = ['FLAIR'] # ['T1w'] # ['T1wCE'] # ['T2w'] 
BATCH_SIZE = 64
# DATA_* is the shape every decoded TFRecord volume is reshaped to.
DATA_HEIGHT = 256 
DATA_WIDTH = 256 
DATA_DEPTH = 32 # 16 
DATA_CHANNELS = 4 
# INPUT_* is the shape deserialize_example emits after resizing each slice.
INPUT_HEIGHT = 120 # 256 # 50 # 64 # 
INPUT_WIDTH = 120 # 256 # 50 # 64 # 
INPUT_CHANNELS = len(MODALITIES) 
INPUT_DEPTH = DATA_DEPTH 
LEARNING_RATE = 1e-4 
# Also used as the file-name prefix when looking up pre-trained weights.
MODEL_TYPE = 'EfficientNet3D_B0S'

FEATURE_MAPPING = True 
DEBUG = True 
LR_REDUCE = True 
# True: pick the checkpoint with the best AUC suffix; False: use '<MODEL_TYPE>.h5'.
BEST_WEIGHTS = True

ROOT_DIR = '/kaggle/input'

WEIGHTS_DIR = f'{ROOT_DIR}/rsna-brain-tumor-classification-external-weights'

In [None]:
import os, glob, numpy as np
def get_best_weight(weights_dir, model_type):
  """Return the path of the `.h5` checkpoint with the highest AUC suffix.

  Checkpoints are expected to be named ``<model_type>*-<auc>.h5``: the text
  after the last ``-`` of the file name (directory part excluded) is parsed
  as a float score. Files without the ``.h5`` extension, or whose suffix is
  not a number, are skipped instead of aborting the search.

  Args:
    weights_dir: directory to search.
    model_type: file-name prefix of the checkpoints to consider.

  Returns:
    The path, exactly as produced by ``glob`` (so it already includes
    ``weights_dir``), of the best-scoring checkpoint, or ``None`` when no
    candidate carries a parseable score. Ties keep the first file seen.
  """
  best_auc, best_path = None, None
  for path in glob.glob(os.path.join(weights_dir, f'{model_type}*')):
    # Parse only the file name: directory names may contain '-' or '.h5' too.
    stem, ext = os.path.splitext(os.path.basename(path))
    if ext != '.h5':
      continue
    try:
      auc = float(stem.split('-')[-1])
    except ValueError:
      continue  # e.g. '<model_type>.h5' with no score suffix
    if best_auc is None or auc > best_auc:
      best_auc, best_path = auc, path
  return best_path

In [None]:
# Best-scoring checkpoint when BEST_WEIGHTS is set, otherwise the un-suffixed
# default file for this model type.
WEIGHTS_FILE = (get_best_weight(WEIGHTS_DIR, MODEL_TYPE) if BEST_WEIGHTS
                else f'{WEIGHTS_DIR}/{MODEL_TYPE}.h5')

In [None]:
if os.path.exists(WEIGHTS_FILE):
  !cp {WEIGHTS_FILE} ./
  print(f'Copied best weights file to training kernel: {WEIGHTS_FILE} ...')

In [None]:
import os, re, cv2, math, glob, string, collections, numpy as np, pandas as pd, \
       matplotlib.pyplot as plt, tensorflow as tf, tensorflow_addons as tfa

from tqdm import tqdm
from six.moves import xrange
from tensorflow.keras import layers
from kaggle_datasets import KaggleDatasets

In [None]:
AUTO = tf.data.AUTOTUNE

# Connect to a TPU when one is attached; otherwise fall back to the default
# distribution strategy.
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
  print('Device:', tpu.master())
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  # Non-experimental name of the strategy; the `experimental` alias is deprecated.
  strategy = tf.distribute.TPUStrategy(tpu)
  TPU = True
except Exception as err:
  # `Exception` rather than a bare `except:` so that KeyboardInterrupt and
  # SystemExit still propagate; the reason for the fallback is reported.
  print(f'TPU not available ({type(err).__name__}: {err}); using default strategy.')
  strategy = tf.distribute.get_strategy()
  TPU = False
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
# Two Kaggle datasets (raw and resampled TFRecords) per stored volume depth:
# one pair for depth 32, an "-x16" pair for any other depth.
if INPUT_DEPTH == 32:
  GCS_BUCKETS = ['rsna-brain-tumor-classification-raw-tfrecords',
                 'rsna-brain-tumor-classification-resample-tfrecords']
else:
  GCS_BUCKETS = ['rsna-brain-tumor-raw-tfrecords-x16',
                 'rsna-brain-tumor-resampled-tfrecords-x16']
GCS_PATHS = [KaggleDatasets().get_gcs_path(bucket) for bucket in GCS_BUCKETS]
# Plain loop instead of a list comprehension used only for its side effect,
# which made the cell echo a list of None values.
for bucket, gcs_path in zip(GCS_BUCKETS, GCS_PATHS):
  print(f'Reading tfrecords from GCS bucket {bucket}: {gcs_path} ...')

In [None]:
def deserialize_example(serialized_string):
  """Decode one serialized TFRecord example into an (image, label) pair.

  The raw float32 volume is reshaped to
  [DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, DATA_CHANNELS]; every depth slice of
  every channel is resized to [INPUT_HEIGHT, INPUT_WIDTH]; only the channels
  whose names appear in MODALITIES are kept.

  Args:
    serialized_string: serialized example holding an 'image' bytes feature
      and an 'MGMT_value' float feature.

  Returns:
    Tuple (image, label). image is reshaped to
    [INPUT_HEIGHT, INPUT_WIDTH, INPUT_DEPTH, INPUT_CHANNELS]; label is the
    parsed 'MGMT_value'.
  """
  image_feature_description = {'image': tf.io.FixedLenFeature([], tf.string),
                               'MGMT_value': tf.io.FixedLenFeature([], tf.float32)}
  parsed_record = tf.io.parse_single_example(serialized_string, image_feature_description)
  image = tf.io.decode_raw(parsed_record['image'], tf.float32)
  image = tf.reshape(image,[DATA_HEIGHT, DATA_WIDTH, DATA_DEPTH, DATA_CHANNELS])
    
  # One [H, W, D] tensor per stored channel.
  splitted_modalities = tf.split(tf.cast(image, tf.float32), DATA_CHANNELS, axis=-1)  
  splitted_modalities = [tf.squeeze(i, axis=-1) for i in splitted_modalities] 
    
  # Per-modality lists of resized depth slices, filled in the loop below.
  flair_augment_img = []
  t1w_augment_img = []
  t1wce_augment_img = []
  t2w_augment_img = []
    
  # Channel index -> modality list. Assumes the stored channel order is
  # FLAIR, T1w, T1wCE, T2w -- TODO confirm against the TFRecord writer.
  for j, modality in enumerate(splitted_modalities):
    # Split the volume into single slices along the depth axis so each one
    # can be resized as a 2D image with a trailing singleton channel.
    splitted_frames = tf.split(tf.cast(modality, tf.float32), modality.shape[-1], axis=-1)
    for i, img in enumerate(splitted_frames):
      img = tf.image.resize(img, [INPUT_HEIGHT, INPUT_WIDTH])  
      if j == 0:
        flair_augment_img.append(img)
      elif j == 1: 
        t1w_augment_img.append(img)
      elif j == 2:
        t1wce_augment_img.append(img)
      elif j == 3:
        t2w_augment_img.append(img)
  # Keep only the requested modalities, always in FLAIR/T1w/T1wCE/T2w order
  # regardless of the order of names inside MODALITIES.
  image = []
  if 'FLAIR' in MODALITIES:
    image.append(flair_augment_img)
  if 'T1w' in MODALITIES:
    image.append(t1w_augment_img)
  if 'T1wCE' in MODALITIES:
    image.append(t1wce_augment_img)
  if 'T2w' in MODALITIES:
    image.append(t2w_augment_img)
  # NOTE(review): tf.transpose without `perm` reverses all axes, so the nested
  # [channel][depth] list of [H, W, 1] slices presumably ends up ordered as
  # [1, W, H, depth, channel] before the reshape drops the leading 1. That
  # would leave height and width swapped relative to the resized slices --
  # confirm this is intended before changing INPUT_HEIGHT != INPUT_WIDTH.
  image = tf.transpose(image)
  image = tf.reshape(image, [INPUT_HEIGHT, INPUT_WIDTH, INPUT_DEPTH, INPUT_CHANNELS])          
  label = parsed_record['MGMT_value']
  return image, label

In [None]:
def _build_dataset(file_groups):
  """Build the batched, prefetched dataset for a list of per-bucket file lists.

  The groups are flattened before being handed to TFRecordDataset: a nested
  list is only convertible to a tensor when every bucket holds the same
  number of files, so the previous direct pass-through could fail on buckets
  of different sizes.
  """
  files = [path for group in file_groups for path in group]
  return (tf.data.TFRecordDataset(files, compression_type='GZIP',
                                  num_parallel_reads=AUTO)
          .map(deserialize_example, num_parallel_calls=AUTO)
          .batch(BATCH_SIZE)
          .prefetch(AUTO))


# One list of matching files per GCS path (kept nested, as before).
train_gcs_files = [tf.io.gfile.glob(os.path.join(GCS_PATH,
                    'brain_tumor_classification_*train*.tfrec')) for GCS_PATH in GCS_PATHS]
val_gcs_files = [tf.io.gfile.glob(os.path.join(GCS_PATH,
                    'brain_tumor_classification_*val*.tfrec')) for GCS_PATH in GCS_PATHS]
train_set = _build_dataset(train_gcs_files)
val_set = _build_dataset(val_gcs_files)

In [None]:
# Pull a single batch and show every depth slice of one randomly chosen volume.
d = train_set.take(1)
image, label = next(iter(d))
# Sample within the batch that was actually returned: it can hold fewer than
# BATCH_SIZE examples, in which case an index drawn from BATCH_SIZE would be
# out of range.
img_id = np.random.randint(0, image.shape[0])
channel = np.random.randint(0, INPUT_CHANNELS)

rows = 4
cols = math.ceil(INPUT_DEPTH / rows)  # still enough cells when depth % 4 != 0

fig = plt.figure(figsize=(20,10), facecolor=(0,0,0))
for layer_idx in range(INPUT_DEPTH):
  ax = fig.add_subplot(rows, cols, layer_idx+1)
  ax.imshow(np.squeeze(image[img_id,:,:,layer_idx,channel]), cmap='gray')
  ax.axis("off")
  ax.set_title(str(layer_idx+1),color='r',y=-0.01)

# Adjacent f-strings instead of a backslash continuation, which embedded the
# source indentation in the title text.
fig.suptitle(f"Batch image no.: {img_id}, MRI modality: {MODALITIES[channel]}, "
             f"Shape: {image[img_id].shape}", color="w")
fig.subplots_adjust(wspace=0, hspace=0)
plt.show()

In [None]:
# Close the current matplotlib figure (the slice preview drawn above).
plt.close()

In [None]:
def generate_legacy_interface(allowed_positional_args=None,
                              conversions=None,
                              preprocessor=None,
                              value_conversions=None):
    """Build a decorator that maps legacy layer-constructor arguments to new ones.

    Args:
        allowed_positional_args: names of the positional arguments the wrapped
            callable accepts besides its first argument (the layer instance).
        conversions: iterable of ``(old_name, new_name)`` keyword renames.
        preprocessor: optional callable ``(args, kwargs) -> (args, kwargs,
            converted)`` run before any other conversion.
        value_conversions: mapping ``kwarg -> {old_value: new_value}`` applied
            to keyword values before the renames.

    Returns:
        A decorator. The wrapped callable raises ``TypeError`` when given more
        positional arguments than allowed, and emits a warning showing the
        converted call whenever any conversion was recorded.
    """
    # Local imports: the notebook never binds `six` or `warnings` (it only does
    # `from six.moves import xrange`), so the previous references raised
    # NameError as soon as the decorator was applied or triggered.
    import functools
    import warnings

    allowed_positional_args = allowed_positional_args or []
    conversions = conversions or []
    value_conversions = value_conversions or []

    def _format_value(value):
        # Strings are shown quoted, everything else through str().
        return '"' + value + '"' if isinstance(value, str) else str(value)

    def legacy_support(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            layer_name = args[0].__class__.__name__
            if preprocessor:
                args, kwargs, converted = preprocessor(args, kwargs)
            else:
                converted = []
            # args[0] is the layer instance, hence the +1.
            if len(args) > len(allowed_positional_args) + 1:
                raise TypeError('Layer `' + layer_name +
                                '` can accept only ' +
                                str(len(allowed_positional_args)) +
                                ' positional arguments (' +
                                str(allowed_positional_args) + '), but '
                                'you passed the following '
                                'positional arguments: ' +
                                str(args[1:]))
            for key in value_conversions:
                if key in kwargs:
                    old_value = kwargs[key]
                    if old_value in value_conversions[key]:
                        kwargs[key] = value_conversions[key][old_value]
            for old_name, new_name in conversions:
                if old_name in kwargs:
                    value = kwargs.pop(old_name)
                    kwargs[new_name] = value
                    converted.append((new_name, old_name))
            if converted:
                # Positional values are each followed by ', '; keyword pairs
                # are separated by ', ' with no trailing separator.
                positional = ''.join(_format_value(value) + ', '
                                     for value in args[1:])
                keywords = ', '.join(name + '=' + _format_value(value)
                                     for name, value in kwargs.items())
                signature = '`' + layer_name + '(' + positional + keywords + ')`'
                warnings.warn('Update your `' + layer_name +
                              '` layer call to the Keras 2 API: ' + signature)
            return func(*args, **kwargs)
        return wrapper
    return legacy_support


def conv3d_args_preprocessor(args, kwargs):
    """Fold legacy per-dimension kernel arguments into one `kernel_size` tuple.

    `args[0]` is the layer instance and `args[1]` the filter count; the three
    kernel dimensions may arrive as further positional integers and/or as the
    `kernel_dim1` / `kernel_dim2` / `kernel_dim3` keywords. When a full
    kernel size can be assembled, `args` becomes
    `[args[0], args[1], kernel_size]` and the consumed keywords are removed
    from `kwargs`; otherwise both are returned untouched.

    Returns:
        `(args, kwargs, [('kernel_size', 'kernel_dim*')])`.

    Raises:
        TypeError: more than five positional arguments were given.
        ValueError: two integer positional kernel arguments were combined with
            a `padding`, `strides` or `data_format` keyword.
    """
    argc = len(args)
    if argc > 5:
        raise TypeError('Layer can receive at most 4 positional arguments.')

    kernel_size = None
    if argc == 5:
        if all(isinstance(dim, int) for dim in args[2:5]):
            kernel_size = tuple(args[2:5])
    elif argc == 4 and isinstance(args[3], int):
        if isinstance(args[2], int):
            # Two bare integers are ambiguous next to new-style keywords.
            if any(kwd in kwargs for kwd in ('padding', 'strides', 'data_format')):
                raise ValueError(
                    'It seems that you are using the Keras 2 '
                    'and you are passing both `kernel_size` and `strides` '
                    'as integer positional arguments. For safety reasons, '
                    'this is disallowed. Pass `strides` '
                    'as a keyword argument instead.')
        if 'kernel_dim3' in kwargs:
            kernel_size = (args[2], args[3], kwargs.pop('kernel_dim3'))
    elif argc == 3:
        if 'kernel_dim2' in kwargs and 'kernel_dim3' in kwargs:
            dim2 = kwargs.pop('kernel_dim2')
            dim3 = kwargs.pop('kernel_dim3')
            kernel_size = (args[2], dim2, dim3)
    elif argc == 2:
        dim_keys = ('kernel_dim1', 'kernel_dim2', 'kernel_dim3')
        if all(key in kwargs for key in dim_keys):
            kernel_size = tuple(kwargs.pop(key) for key in dim_keys)

    if kernel_size is not None:
        args = [args[0], args[1], kernel_size]
    return args, kwargs, [('kernel_size', 'kernel_dim*')]


def _preprocess_padding(padding):
    if padding == 'same':
        padding = 'SAME'
    elif padding == 'valid':
        padding = 'VALID'
    else:
        raise ValueError('Invalid padding: ' + str(padding))
    return padding


def dtype(x):
    """Return the name of the base dtype of `x` as a string."""
    base = x.dtype.base_dtype
    return base.name


def _has_nchw_support():
    return True


def _preprocess_conv3d_input(x, data_format):
    if (dtype(x) == 'float64' and
            StrictVersion(tf.__version__.split('-')[0]) < StrictVersion('1.8.0')):
        x = tf.cast(x, 'float32')
    tf_data_format = 'NDHWC'
    return x, tf_data_format

In [None]:
def depthwise_conv3d_args_preprocessor(args, kwargs):
    """Convert legacy DepthwiseConv3D constructor arguments.

    Renames the legacy `init` keyword to `depthwise_initializer`, then
    delegates kernel-size handling to `conv3d_args_preprocessor`.

    Returns:
        `(args, kwargs, converted)` where `converted` lists the conversions
        recorded here followed by those reported by
        `conv3d_args_preprocessor`.
    """
    converted = []

    if 'init' in kwargs:
        init = kwargs.pop('init')
        kwargs['depthwise_initializer'] = init
        converted.append(('init', 'depthwise_initializer'))

    args, kwargs, _converted = conv3d_args_preprocessor(args, kwargs)
    return args, kwargs, converted + _converted


# Decorator for legacy DepthwiseConv3D constructor calls. This assignment was
# previously indented under the function above, after its `return`, so it was
# unreachable and the name was never defined.
legacy_depthwise_conv3d_support = generate_legacy_interface(
    allowed_positional_args=['filters', 'kernel_size'],
    conversions=[('nb_filter', 'filters'),
                 ('subsample', 'strides'),
                 ('border_mode', 'padding'),
                 ('dim_ordering', 'data_format'),
                 ('b_regularizer', 'bias_regularizer'),
                 ('b_constraint', 'bias_constraint'),
                 ('bias', 'use_bias')],
    value_conversions={'dim_ordering': {'tf': 'channels_last',
                                        'th': 'channels_first',
                                        'default': None}},
    preprocessor=depthwise_conv3d_args_preprocessor)

In [None]:
# Default Keras submodules returned by get_submodules_from_kwargs when the
# caller does not override them.
_KERAS_BACKEND = tf.keras.backend
_KERAS_LAYERS = tf.keras.layers
_KERAS_MODELS = tf.keras.models
_KERAS_UTILS = tf.keras.utils

In [None]:
def get_submodules_from_kwargs(kwargs):
  """Resolve the Keras submodules to use from a keyword-argument dict.

  Args:
    kwargs: dict that may carry 'backend', 'layers', 'models' and 'utils'
      overrides; missing entries fall back to the `_KERAS_*` defaults.

  Returns:
    Tuple `(backend, layers, models, utils)`.

  Raises:
    TypeError: if `kwargs` contains any other key.
  """
  allowed = ('backend', 'layers', 'models', 'utils')
  # Validate first so a bad key fails fast. The message is now formatted: the
  # key used to be passed as a second exception argument, leaving a literal
  # '%s' in the text.
  for key in kwargs:
    if key not in allowed:
      raise TypeError(f'Invalid keyword argument: {key}')
  backend = kwargs.get('backend', _KERAS_BACKEND)
  layers  = kwargs.get('layers', _KERAS_LAYERS)
  models  = kwargs.get('models', _KERAS_MODELS)
  utils   = kwargs.get('utils', _KERAS_UTILS)
  return backend, layers, models, utils

In [None]:
class DepthwiseConv3D(tf.keras.layers.Conv3D):
    """Depthwise (grouped) 3D convolution layer.

    The input channels are cut into `groups` contiguous slices; each slice is
    convolved with the matching slice of the kernel by `tf.nn.conv3d` and the
    results are concatenated along the channel axis, which yields
    `groups * depth_multiplier` output channels. With the default
    `groups=None` every input channel forms its own group.

    NOTE(review): the op layout is fixed to "NDHWC" in `__init__`, so although
    `data_format` is accepted, `channels_first` inputs do not look supported
    end to end -- confirm before relying on it.
    """

    def __init__(self,
                 kernel_size,
                 strides=(1, 1, 1),
                 padding='valid',
                 depth_multiplier=1,
                 groups=None,
                 data_format=None,
                 activation=None,
                 use_bias=True,
                 depthwise_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 dilation_rate = (1, 1, 1),
                 depthwise_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 depthwise_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        """Store the layer configuration; weights are created in `build`.

        `filters` is passed to the parent as None because the number of
        output channels is derived from the input in `build`.
        """
        super(DepthwiseConv3D, self).__init__(
            filters=None,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            activation=activation,
            use_bias=use_bias,
            bias_regularizer=bias_regularizer,
            dilation_rate=dilation_rate,
            activity_regularizer=activity_regularizer,
            bias_constraint=bias_constraint,
            **kwargs)
        self.depth_multiplier = depth_multiplier
        self.groups = groups
        self.depthwise_initializer = tf.keras.initializers.get(depthwise_initializer)
        self.depthwise_regularizer = tf.keras.regularizers.get(depthwise_regularizer)
        self.depthwise_constraint = tf.keras.constraints.get(depthwise_constraint)
        self.bias_initializer = tf.keras.initializers.get(bias_initializer)
        # Always keep a 3-tuple: `call` concatenates this with other tuples,
        # which failed when the raw int/list argument was stored instead.
        if isinstance(dilation_rate, int):
            self.dilation_rate = (dilation_rate,) * 3
        else:
            self.dilation_rate = tuple(dilation_rate)
        self._padding = _preprocess_padding(self.padding)
        self._strides = (1,) + tuple(self.strides) + (1,)
        self._data_format = "NDHWC"
        self.input_dim = None

    def build(self, input_shape):
        """Create the depthwise kernel (and bias) for the given input shape.

        Raises:
            ValueError: if the input rank is below 5, the channel dimension is
                undefined, or `groups` exceeds / does not divide the number of
                input channels.
        """
        if len(input_shape) < 5:
            # Single formatted message; the shape used to be passed as a
            # second exception argument.
            raise ValueError('Inputs to `DepthwiseConv3D` should have rank 5. '
                             'Received input shape: ' + str(input_shape))
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        if input_shape[channel_axis] is None:
            raise ValueError('The channel dimension of the inputs to '
                             '`DepthwiseConv3D` '
                             'should be defined. Found `None`.')
        self.input_dim = int(input_shape[channel_axis])

        # Default: one group per input channel (pure depthwise convolution).
        if self.groups is None:
            self.groups = self.input_dim

        if self.groups > self.input_dim:
            raise ValueError('The number of groups cannot exceed the number of channels')

        if self.input_dim % self.groups != 0:
            raise ValueError('Warning! The channels dimension is not divisible by the group size chosen')

        depthwise_kernel_shape = (self.kernel_size[0],
                                  self.kernel_size[1],
                                  self.kernel_size[2],
                                  self.input_dim,
                                  self.depth_multiplier)

        self.depthwise_kernel = self.add_weight(
            shape=depthwise_kernel_shape,
            initializer=self.depthwise_initializer,
            name='depthwise_kernel',
            regularizer=self.depthwise_regularizer,
            constraint=self.depthwise_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.groups * self.depth_multiplier,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        self.input_spec = tf.keras.layers.InputSpec(ndim=5, axes={channel_axis: self.input_dim})
        self.built = True

    def call(self, inputs, training=None):
        """Apply the grouped convolution, then the optional bias and activation."""
        inputs, _ = _preprocess_conv3d_input(inputs, self.data_format)

        if self.data_format == 'channels_last':
            dilation = (1,) + self.dilation_rate + (1,)
        else:
            dilation = self.dilation_rate + (1,) + (1,)

        # Number of input channels handled by each group.
        group_width = self.input_dim // self.groups
        channels_first_op = self._data_format == 'NCDHW'

        group_outputs = []
        for start in range(0, self.input_dim, group_width):
            stop = start + group_width
            if channels_first_op:
                input_slice = inputs[:, start:stop, :, :, :]
            else:
                input_slice = inputs[:, :, :, :, start:stop]
            group_outputs.append(
                tf.nn.conv3d(input_slice,
                             self.depthwise_kernel[:, :, :, start:stop, :],
                             strides=self._strides,
                             padding=self._padding,
                             dilations=dilation,
                             data_format=self._data_format))
        outputs = tf.concat(group_outputs, axis=1 if channels_first_op else -1)

        if self.bias is not None:
            # `K` was never defined in this notebook; call the backend directly.
            outputs = tf.keras.backend.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    @staticmethod
    def _conv_output_length(input_length, filter_size, padding, stride, dilation=1):
        """Length of one spatial output dimension; None if the input length is None.

        Stand-in for the `conv_utils.conv_output_length` helper, which was
        referenced without ever being imported.
        """
        if input_length is None:
            return None
        dilated_filter_size = filter_size + (filter_size - 1) * (dilation - 1)
        if padding == 'same':
            output_length = input_length
        else:  # 'valid': the only other padding `_preprocess_padding` accepts
            output_length = input_length - dilated_filter_size + 1
        return (output_length + stride - 1) // stride

    def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape` in the layer's data format."""
        if self.data_format == 'channels_first':
            depth = input_shape[2]
            rows = input_shape[3]
            cols = input_shape[4]
            out_filters = self.groups * self.depth_multiplier
        elif self.data_format == 'channels_last':
            depth = input_shape[1]
            rows = input_shape[2]
            cols = input_shape[3]
            out_filters = self.groups * self.depth_multiplier

        depth = self._conv_output_length(depth, self.kernel_size[0],
                                         self.padding,
                                         self.strides[0],
                                         self.dilation_rate[0])

        rows = self._conv_output_length(rows, self.kernel_size[1],
                                        self.padding,
                                        self.strides[1],
                                        self.dilation_rate[1])

        cols = self._conv_output_length(cols, self.kernel_size[2],
                                        self.padding,
                                        self.strides[2],
                                        self.dilation_rate[2])

        if self.data_format == 'channels_first':
            return (input_shape[0], out_filters, depth, rows, cols)

        elif self.data_format == 'channels_last':
            return (input_shape[0], depth, rows, cols, out_filters)

    def get_config(self):
        """Return the layer config with the depthwise-specific entries.

        The parent's `filters` and `kernel_*` entries are dropped because the
        constructor of this layer does not accept them.
        """
        config = super(DepthwiseConv3D, self).get_config()
        config.pop('filters')
        config.pop('kernel_initializer')
        config.pop('kernel_regularizer')
        config.pop('kernel_constraint')
        config['depth_multiplier'] = self.depth_multiplier
        config['depthwise_initializer'] = tf.keras.initializers.serialize(self.depthwise_initializer)
        config['depthwise_regularizer'] = tf.keras.regularizers.serialize(self.depthwise_regularizer)
        config['depthwise_constraint'] = tf.keras.constraints.serialize(self.depthwise_constraint)
        return config

In [None]:
# Long-form alias for the depthwise layer.
DepthwiseConvolution3D = DepthwiseConv3D

# Module-level Keras submodule handles read by the model-building functions.
backend = tf.keras.backend
layers = tf.keras.layers
models = tf.keras.models
keras_utils = tf.keras.utils

# Code of this model implementation is mostly written by
# Björn Barz ([@Callidior](https://github.com/Callidior))

# Description of one stage of MBConv blocks.
BlockArgs = collections.namedtuple(
    'BlockArgs',
    'kernel_size num_repeat input_filters output_filters '
    'expand_ratio id_skip strides se_ratio')
# Every field is optional and defaults to None.
BlockArgs.__new__.__defaults__ = tuple(None for _ in BlockArgs._fields)

# Baseline stage table. Each row is
# (kernel_size, num_repeat, input_filters, output_filters, expand_ratio, stride);
# the stride is applied on all three spatial axes, and every stage uses
# id_skip=True and se_ratio=0.25.
DEFAULT_BLOCKS_ARGS = [
    BlockArgs(kernel_size=kernel, num_repeat=repeat,
              input_filters=filters_in, output_filters=filters_out,
              expand_ratio=expand, id_skip=True,
              strides=[stride, stride, stride], se_ratio=0.25)
    for kernel, repeat, filters_in, filters_out, expand, stride in (
        (3, 1, 32, 16, 1, 1),
        (3, 2, 16, 24, 6, 2),
        (5, 2, 24, 40, 6, 2),
        (3, 3, 40, 80, 6, 2),
        (5, 3, 80, 112, 6, 1),
        (5, 4, 112, 192, 6, 2),
        (3, 1, 192, 320, 6, 1),
    )
]

# Serialized initializer config used for convolution kernels.
CONV_KERNEL_INITIALIZER = dict(
    class_name='VarianceScaling',
    config=dict(scale=2.0, mode='fan_out', distribution='normal'))

# Serialized initializer config used for the dense classification layer.
DENSE_KERNEL_INITIALIZER = dict(
    class_name='VarianceScaling',
    config=dict(scale=1. / 3., mode='fan_out', distribution='uniform'))


def preprocess_input(x, **kwargs):
    """Forward `x` to `_preprocess_input` in 'tensorflow' mode.

    Only the submodule keywords ('backend', 'layers', 'models', 'utils') are
    passed on; every other keyword argument is dropped.

    NOTE(review): `_preprocess_input` is not defined in the cells visible
    here -- confirm it is provided elsewhere before calling this.
    """
    submodule_keys = ['backend', 'layers', 'models', 'utils']
    forwarded = {}
    for key, value in kwargs.items():
        if key in submodule_keys:
            forwarded[key] = value
    return _preprocess_input(x, mode='tensorflow', **forwarded)


def get_swish(**kwargs):
    """Return the swish activation as a plain function named `swish`.

    Args:
        **kwargs: accepted for call-site compatibility and ignored (the
            previous body bound four Keras submodules it never used).

    Returns:
        A one-argument function applying `tf.nn.swish`.
    """
    def swish(x):
        return tf.nn.swish(x)
    return swish

def get_dropout(**kwargs):
    """Return a Dropout subclass whose `noise_shape` may contain `None` entries.

    Args:
        **kwargs: accepted for call-site compatibility and ignored; the class
            is always built on `tf.keras` (the unused `models` / `utils`
            bindings of the previous body were removed).

    Returns:
        The `FixedDropout` layer class.
    """
    backend, layers = tf.keras.backend, tf.keras.layers

    class FixedDropout(layers.Dropout):
        """Dropout that fills `None` noise-shape entries from the input's shape."""

        def _get_noise_shape(self, inputs):
            if self.noise_shape is None:
                return self.noise_shape

            # Replace each None with the corresponding symbolic dimension.
            symbolic_shape = backend.shape(inputs)
            noise_shape = [symbolic_shape[axis] if shape is None else shape
                           for axis, shape in enumerate(self.noise_shape)]
            return tuple(noise_shape)

    return FixedDropout


def round_filters(filters, width_coefficient, depth_divisor):
    """Scale a filter count by the width coefficient and snap it to the divisor.

    The scaled count is rounded to the nearest multiple of `depth_divisor`
    (never below one multiple); if that rounding lost more than 10% of the
    scaled value, one extra multiple is added.

    Returns:
        The adjusted filter count as an int.
    """
    scaled = filters * width_coefficient
    nearest_multiple = int(scaled + depth_divisor / 2) // depth_divisor * depth_divisor
    rounded = max(depth_divisor, nearest_multiple)
    if rounded < 0.9 * scaled:
        return int(rounded + depth_divisor)
    return int(rounded)


def round_repeats(repeats, depth_coefficient):
    """Scale a repeat count by the depth coefficient, rounding up to an int."""
    scaled = depth_coefficient * repeats
    return int(math.ceil(scaled))


def mb_conv_block(inputs, block_args, activation, drop_rate=None, prefix='', ):
    """Build one mobile inverted-bottleneck (MBConv) block on top of `inputs`.

    Stages, in order: optional 1x1x1 expansion convolution, depthwise 3D
    convolution, optional squeeze-and-excitation gating, 1x1x1 projection,
    and -- when the block keeps shape -- an identity skip connection with
    optional dropout on the block branch.

    Args:
        inputs: input tensor of the block.
        block_args: `BlockArgs` describing kernel size, filters, expand ratio,
            strides, `id_skip` and `se_ratio`.
        activation: activation passed to the `Activation` layers and to the
            squeeze-and-excitation reduce convolution.
        drop_rate: dropout rate applied before the residual add; skipped when
            falsy or not positive.
        prefix: string prepended to every layer name created here.

    Returns:
        The output tensor of the block.
    """
    has_se = (block_args.se_ratio is not None) and (0 < block_args.se_ratio <= 1)
    # Batch-norm axis is hard-coded to the last axis of a 5D tensor; the
    # data-format dependent choice is left commented out.
    bn_axis = 4 #if backend.image_data_format() == 'channels_last' else 1

    # workaround over non working dropout with None in noise_shape in tf.keras
    Dropout = get_dropout(
        backend=backend,
        layers=layers,
        models=models,
        utils=keras_utils
    )

    # Expansion phase
    filters = block_args.input_filters * block_args.expand_ratio
    if block_args.expand_ratio != 1:
        x = tf.keras.layers.Conv3D(filters, 1,
                          padding='same',
                          use_bias=False,
                          kernel_initializer=CONV_KERNEL_INITIALIZER,
                          name=prefix + 'expand_conv')(inputs)
        x = tf.keras.layers.BatchNormalization(axis=bn_axis, name=prefix + 'expand_bn')(x)
        x = tf.keras.layers.Activation(activation, name=prefix + 'expand_activation')(x)
    else:
        # With an expand ratio of 1 the expansion convolution is skipped.
        x = inputs

    # Depthwise Convolution
    x = DepthwiseConv3D(block_args.kernel_size,
                               strides=block_args.strides,
                               padding='same',
                               use_bias=False,
                               depthwise_initializer=CONV_KERNEL_INITIALIZER,
                               name=prefix + 'dwconv')(x)
    x = tf.keras.layers.BatchNormalization(axis=bn_axis, name=prefix + 'bn')(x)
    x = tf.keras.layers.Activation(activation, name=prefix + 'activation')(x)

    # Squeeze and Excitation phase
    if has_se:
        # The bottleneck width is derived from the block's input filters (not
        # the expanded filter count) and is at least 1.
        num_reduced_filters = max(1, int(
            block_args.input_filters * block_args.se_ratio
        ))
        se_tensor = tf.keras.layers.GlobalAveragePooling3D(name=prefix + 'se_squeeze')(x)

        # Restore singleton spatial axes so the gate broadcasts over the volume.
        target_shape = (1, 1, 1, filters) if backend.image_data_format() == 'channels_last' else (filters, 1, 1, 1)
        se_tensor = tf.keras.layers.Reshape(target_shape, name=prefix + 'se_reshape')(se_tensor)
        se_tensor = tf.keras.layers.Conv3D(num_reduced_filters, 1,
                                  activation=activation,
                                  padding='same',
                                  use_bias=True,
                                  kernel_initializer=CONV_KERNEL_INITIALIZER,
                                  name=prefix + 'se_reduce')(se_tensor)
        se_tensor = tf.keras.layers.Conv3D(filters, 1,
                                  activation='sigmoid',
                                  padding='same',
                                  use_bias=True,
                                  kernel_initializer=CONV_KERNEL_INITIALIZER,
                                  name=prefix + 'se_expand')(se_tensor)
        if backend.backend() == 'theano':
            # For the Theano backend, we have to explicitly make
            # the excitation weights broadcastable.
            pattern = ([True, True, True, True, False] if backend.image_data_format() == 'channels_last'
                       else [True, False, True, True, True])
            se_tensor = layers.Lambda(
                lambda x: backend.pattern_broadcast(x, pattern),
                name=prefix + 'se_broadcast')(se_tensor)
        x = layers.multiply([x, se_tensor], name=prefix + 'se_excite')

    # Output phase
    x = layers.Conv3D(block_args.output_filters, 1,
                      padding='same',
                      use_bias=False,
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      name=prefix + 'project_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=prefix + 'project_bn')(x)
    # The residual connection is only added when the block preserves shape:
    # unit strides and equal input/output filter counts.
    if block_args.id_skip and all(
            s == 1 for s in block_args.strides
    ) and block_args.input_filters == block_args.output_filters:
        if drop_rate and (drop_rate > 0):
            # noise_shape keeps only the batch axis free, so the dropout mask
            # has one value per sample.
            x = Dropout(drop_rate,
                        noise_shape=(None, 1, 1, 1, 1),
                        name=prefix + 'drop')(x)
        x = layers.add([x, inputs], name=prefix + 'add')

    return x

In [None]:
def EfficientNet3D(width_coefficient, depth_coefficient, default_resolution,
                   dropout_rate=0.2, drop_connect_rate=0.2, depth_divisor=8,
                   blocks_args=DEFAULT_BLOCKS_ARGS, include_top=False,
                   model_name='efficientnet3d', weights=None, input_tensor=None, 
                   input_shape=None, pooling=None, classes=1000, **kwargs):
    """Assemble a 3D EfficientNet Keras model.

    Args:
        width_coefficient: multiplier applied to every filter count.
        depth_coefficient: multiplier applied to every stage's repeat count.
        default_resolution: not used inside this function.
        dropout_rate: dropout before the classifier (only with `include_top`).
        drop_connect_rate: maximum block dropout rate; each block receives a
            rate growing linearly with its index.
        depth_divisor: filter counts are rounded to multiples of this value.
        blocks_args: list of `BlockArgs`, one per stage.
        include_top: whether to append pooling, dropout and a softmax layer.
        model_name: name given to the returned model.
        weights: None, or a path to an existing weights file to load.
        input_tensor: optional tensor to use as the model input.
        input_shape: shape used to create the input when `input_tensor` is None.
        pooling: 'avg', 'max' or None; only used when `include_top` is False.
        classes: number of output units of the classifier.
        **kwargs: optional 'backend', 'layers', 'models', 'utils' overrides.

    Returns:
        The constructed model, with `weights` loaded when given.

    Raises:
        ValueError: if `weights` is given but the path does not exist.
    """
    
    # Rebinds the module-level submodule handles, so other functions reading
    # these globals see the values resolved here as well.
    global backend, layers, models, keras_utils
    backend, layers, models, keras_utils = get_submodules_from_kwargs(kwargs)

    if weights is not None and not os.path.exists(weights):
        raise ValueError('The `weights` argument should be '
                         'a valid path to the weights file ...')

    if input_tensor is None:
        img_input = layers.Input(shape=input_shape)
    else:
        img_input = input_tensor

    bn_axis = 4 if backend.image_data_format() == 'channels_last' else 1
    activation = get_swish(**kwargs)

    # Build stem
    x = img_input
    x = layers.Conv3D(round_filters(32, width_coefficient, depth_divisor), 3,
                      strides=(2, 2, 2),
                      padding='same',
                      use_bias=False,
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      name='stem_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='stem_bn')(x)
    x = layers.Activation(activation, name='stem_activation')(x)

    # Build blocks
    # Note: the total is computed from the unscaled repeat counts, before
    # `depth_coefficient` is applied in the loop below.
    num_blocks_total = sum(block_args.num_repeat for block_args in blocks_args)
    block_num = 0
    for idx, block_args in enumerate(blocks_args):
        assert block_args.num_repeat > 0
        # Update block input and output filters based on depth multiplier.
        block_args = block_args._replace(
            input_filters=round_filters(block_args.input_filters,
                                        width_coefficient, depth_divisor),
            output_filters=round_filters(block_args.output_filters,
                                         width_coefficient, depth_divisor),
            num_repeat=round_repeats(block_args.num_repeat, depth_coefficient))

        # The first block needs to take care of stride and filter size increase.
        drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
        x = mb_conv_block(x, block_args,
                          activation=activation,
                          drop_rate=drop_rate,
                          prefix='block{}a_'.format(idx + 1))
        block_num += 1
        if block_args.num_repeat > 1:
            # Remaining repeats of the stage keep the shape: input filters
            # equal the stage's output filters and strides are 1.
            # pylint: disable=protected-access
            block_args = block_args._replace(
                input_filters=block_args.output_filters, strides=[1, 1, 1])
            # pylint: enable=protected-access
            for bidx in xrange(block_args.num_repeat - 1):
                drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
                # Repeats are lettered b, c, ... after the stage's 'a' block.
                block_prefix = 'block{}{}_'.format(
                    idx + 1,
                    string.ascii_lowercase[bidx + 1]
                )
                x = mb_conv_block(x, block_args,
                                  activation=activation,
                                  drop_rate=drop_rate,
                                  prefix=block_prefix)
                block_num += 1

    # Build top
    x = layers.Conv3D(round_filters(1280, width_coefficient, depth_divisor), 1,
                      padding='same',
                      use_bias=False,
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      name='top_conv')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='top_bn')(x)
    x = layers.Activation(activation, name='top_activation')(x)
    if include_top:
        x = layers.GlobalAveragePooling3D(name='avg_pool')(x)
        if dropout_rate and dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name='top_dropout')(x)
        x = layers.Dense(classes,
                         activation='softmax',
                         kernel_initializer=DENSE_KERNEL_INITIALIZER,
                         name='probs')(x)
    else:
        # Without the classifier, optionally reduce the feature volume to a
        # vector; with pooling=None the 5D feature map is returned as is.
        if pooling == 'avg':
            x = layers.GlobalAveragePooling3D(name='avg_pool')(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling3D(name='max_pool')(x)

    # Ensure that the model takes into account
    # any potential predecessors of `input_tensor`.
    if input_tensor is not None:
        inputs = keras_utils.get_source_inputs(input_tensor)
    else:
        inputs = img_input

    # Create model.
    model = models.Model(inputs, x, name=model_name)

    if weights is not None and os.path.exists(weights):
      model.load_weights(weights)

    return model

In [None]:
def EfficientNet3D_B0(include_top=False, input_tensor=None, input_shape=None,
                      weights=None, pooling=None, classes=1000, **kwargs):
    """Build the B0 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (1.0, 1.0, 224, 0.2) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b0'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        weights: optional path to a weights file. It is now forwarded to
            `EfficientNet3D` (previously it was accepted but silently replaced
            by None, so the requested weights were never loaded).
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(1.0, 1.0, 224, 0.2, model_name='efficientnet3d-b0',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_B1(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the B1 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (1.0, 1.1, 240, 0.2) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b1'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(1.0, 1.1, 240, 0.2, model_name='efficientnet3d-b1',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_B2(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the B2 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (1.1, 1.2, 260, 0.3) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b2'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(1.1, 1.2, 260, 0.3, model_name='efficientnet3d-b2',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_B3(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the B3 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (1.2, 1.4, 300, 0.3) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b3'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(1.2, 1.4, 300, 0.3, model_name='efficientnet3d-b3',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_B4(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the B4 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (1.4, 1.8, 380, 0.4) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b4'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(1.4, 1.8, 380, 0.4, model_name='efficientnet3d-b4',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_B5(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the B5 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (1.6, 2.2, 456, 0.4) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b5'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(1.6, 2.2, 456, 0.4, model_name='efficientnet3d-b5',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_B6(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the B6 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (1.8, 2.6, 528, 0.5) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b6'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(1.8, 2.6, 528, 0.5, model_name='efficientnet3d-b6',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_B7(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the B7 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (2.0, 3.1, 600, 0.5) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-b7'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(2.0, 3.1, 600, 0.5, model_name='efficientnet3d-b7',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

def EfficientNet3D_L2(include_top=False, input_tensor=None, input_shape=None,
                      pooling=None, classes=1000, weights=None, **kwargs):
    """Build the L2 variant of the 3D EfficientNet.

    Thin wrapper around `EfficientNet3D` that pins the four leading positional
    arguments to (4.3, 5.3, 800, 0.5) -- presumably width coefficient, depth
    coefficient, default resolution and dropout rate; confirm against the
    `EfficientNet3D` signature -- and names the model 'efficientnet3d-l2'.

    Args:
        include_top: forwarded to `EfficientNet3D`; when true the builder
            appends its own pooling/dropout/softmax head.
        input_tensor: optional existing tensor to use as the model input.
        input_shape: optional input shape forwarded to the builder.
        pooling: optional 'avg' / 'max' pooling used when `include_top` is
            false.
        classes: number of output classes for the optional top.
        weights: optional path to a weights file, forwarded to the builder.
            Added (last, so positional callers are unaffected) because the
            body referenced `weights` without it being a parameter, which
            raised NameError on every call.
        **kwargs: any further keyword arguments passed straight through.

    Returns:
        The model returned by `EfficientNet3D`.
    """
    return EfficientNet3D(4.3, 5.3, 800, 0.5, model_name='efficientnet3d-l2',
                          include_top=include_top, weights=weights,
                          input_tensor=input_tensor, input_shape=input_shape,
                          pooling=pooling, classes=classes, **kwargs)

In [None]:
def get_efn3D(input_height, input_width, input_depth, input_channels,
              model_type='EfficientNet3D_B0', feature_mapping=True):
  """Assemble the binary classifier: optional 3D feature map -> EfficientNet3D -> head.

  Args:
    input_height, input_width, input_depth, input_channels: per-sample shape
      of the 'input3D' input layer.
    model_type: accepted for API symmetry but NOT used here; the backbone is
      always built with EfficientNet3D(0.45, 0.45, input_height, 0.2, ...).
    feature_mapping: when truthy, a 3x3x3 'same'-padded Conv3D named
      'feature3D' first maps the input to 3 channels and the backbone is
      sized for 3 channels; otherwise the backbone consumes the raw input.

  Returns:
    A tf.keras.Model from the 'input3D' tensor to a single sigmoid unit.
  """
  volume_in = tf.keras.Input(
      (input_height, input_width, input_depth, input_channels),
      name='input3D')

  # Decide what the backbone sees and how many channels it must accept.
  if feature_mapping:
    backbone_in = tf.keras.layers.Conv3D(
        3, (3,3,3), strides=(1, 1, 1), padding='same',
        name='feature3D', use_bias=True)(volume_in)
    backbone_channels = 3
  else:
    backbone_in = volume_in
    backbone_channels = input_channels

  # Headless backbone: no built-in top, no pooling, no weights loaded here.
  backbone = EfficientNet3D(
      0.45, 0.45, input_height, 0.2, model_name='efficientnet3d-b0',
      input_shape=(input_height, input_width, input_depth, backbone_channels),
      include_top=False, weights=None, pooling=None)

  # Classification head: dropout -> global average pool -> dropout -> sigmoid.
  head = backbone(backbone_in)
  head = tf.keras.layers.Dropout(0.5)(head)
  head = tf.keras.layers.GlobalAveragePooling3D()(head)
  head = tf.keras.layers.Dropout(0.5)(head)
  probability = tf.keras.layers.Dense(1, activation='sigmoid')(head)
  return tf.keras.Model(volume_in, probability)

In [None]:
# Build and compile the model, and create the training callbacks, inside the
# distribution strategy's scope.
with strategy.scope():
  model = get_efn3D(INPUT_HEIGHT, INPUT_WIDTH, INPUT_DEPTH, INPUT_CHANNELS,
                    model_type=MODEL_TYPE, feature_mapping=FEATURE_MAPPING)

  # `learning_rate` is the supported keyword; the old `lr` alias is deprecated
  # and rejected by newer Keras releases.
  model.compile(loss='binary_crossentropy',
                optimizer=tf.keras.optimizers.Nadam(learning_rate=LEARNING_RATE),
                metrics=['accuracy', 'AUC'])

  # Save a checkpoint only when validation AUC improves; the file name embeds
  # the epoch number and the val_auc value.
  ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=f'{MODEL_TYPE}'+'.{epoch:02d}-{val_auc:.4f}.h5',
    monitor='val_auc', mode='max', save_best_only=True, verbose=1)

  # Stop after 5 epochs without a val_auc improvement and roll back to the
  # best weights seen.
  es_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_auc', min_delta=0, patience=5,
    verbose=1, mode='max', restore_best_weights=True)

  # Divide the learning rate by 10 after 3 epochs without a val_loss
  # improvement of at least 1e-4.
  lr_reduce_cb = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.1, patience=3, verbose=1,
    mode='min', min_delta=0.0001, cooldown=0, min_lr=1e-16)

  def get_lr_callback(batch_size=8):
    """Return a LearningRateScheduler with a ramp-up / exponential-decay schedule.

    The rate rises linearly from 5e-6 to 1.25e-6 * batch_size over the first
    5 epochs, is held for `lr_sus_ep` epochs (currently 0), then decays
    geometrically by a factor of 0.8 per epoch towards a floor of 1e-6.

    Args:
      batch_size: scales the peak learning rate.

    Returns:
      A tf.keras.callbacks.LearningRateScheduler wrapping the schedule.
    """
    lr_start   = 0.000005
    lr_max     = 0.00000125 * batch_size
    lr_min     = 0.000001
    lr_ramp_ep = 5
    lr_sus_ep  = 0
    lr_decay   = 0.8
    def lrfn(epoch):
      """Learning rate for the given 0-based epoch index."""
      if epoch < lr_ramp_ep:
        # Linear warm-up.
        lr = (lr_max - lr_start) / lr_ramp_ep * epoch + lr_start
      elif epoch < lr_ramp_ep + lr_sus_ep:
        # Plateau at the peak rate.
        lr = lr_max
      else:
        # Exponential decay towards lr_min.
        lr = (lr_max - lr_min) * lr_decay**(epoch - lr_ramp_ep - lr_sus_ep) + lr_min
      return lr
    return tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)

In [None]:
# Print the architecture, then reset Keras' global state.
model.summary()
tf.keras.backend.clear_session()

if DEBUG:
  # Smoke test. `d` is not defined in this cell -- it must come from an
  # earlier cell (presumably a small dataset of (x, y) batches; verify).
  for x, y in d:
    print(x.shape, y.shape)
  # `x` / `y` now hold the last batch yielded by `d`.
  if not TPU:
    model.predict(x, verbose=1)
  else:
    model.fit(x, y)

In [None]:
# Warm-start from a previously saved checkpoint, but only when a weights file
# was resolved and actually exists on disk.
if WEIGHTS_FILE is not None:
  if os.path.exists(WEIGHTS_FILE):
    model.load_weights(WEIGHTS_FILE)
    print(f'Loaded weights from: {WEIGHTS_FILE} ...')

In [None]:
# Run NUM_RETRIES consecutive training sessions on the same `model` object,
# each for at most EPOCHS epochs.
for i in tqdm(range(NUM_RETRIES)):
  # Callbacks: checkpointing, early stopping, and one learning-rate policy --
  # the plateau-based reducer when LR_REDUCE is set, otherwise the epoch-based
  # schedule from get_lr_callback(BATCH_SIZE).
  # `history` is overwritten on every pass, so only the last session's
  # history is kept in that name.
  history = model.fit(train_set, validation_data=val_set, epochs=EPOCHS,
              callbacks=[ckpt_cb, es_cb, lr_reduce_cb  \
                if LR_REDUCE else get_lr_callback(BATCH_SIZE)])