学習用notebook → https://www.kaggle.com/hutch1221/training-cassava-classification

## 変更仕様(from baseline)
* augmentationのクラス化、private dataset に退避 https://github.com/Taichicchi1221/tf-image-classification
* random_rotation, FMixの実装
* bi-tempered lossを追加(LOSS_TYPE="BTL"で指定) https://github.com/Diulhio/bitemperedloss-tf
* External datasetを追加（2019年cassavaコンペのもの）https://www.kaggle.com/c/cassava-disease
* TTAの実装(oofで評価)
* Image Normalization実装（フラグ"IMAGE_NORMALIZATION"で管理）
* ViTモデルの指定に対応（ViT使用時はバッチサイズを小さめに設定する必要あり）

In [1]:
# Directory holding the outputs of the training notebook linked above
# (configuration.json, result.json and the per-fold weights_*.h5 files read below).
INPUT_DIR = "../input/result-exp40"

In [2]:
# Batch size for the inference datasets (get_ds_for_id / get_ds_for_prediction);
# independent of the training CFG["BATCH_SIZE"].
TEST_BATCH_SIZE = 32

In [3]:
!pip install -q /kaggle/input/keras-efficientnet-whl/Keras_Applications-1.0.8-py3-none-any.whl
!pip install -q /kaggle/input/keras-efficientnet-whl/efficientnet-1.1.1-py3-none-any.whl
!pip install -q /kaggle/input/keras-pretrained-imagenet-weights/image_classifiers-1.0.0-py3-none-any.whl
!pip install -q /kaggle/input/vit-keras/validators-0.18.2-py3-none-any.whl
!pip install -q /kaggle/input/vit-keras/vit_keras-0.0.10-py3-none-any.whl

In [4]:
import os
import sys
import re
import math
import time
import gc
import random
import json
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_addons as tfa
from kaggle_datasets import KaggleDatasets
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix, accuracy_score
from sklearn.model_selection import KFold, StratifiedKFold

In [5]:
sys.path.append("/kaggle/input/tf-bi-tempered-loss")
from tf_bi_tempered_loss import BiTemperedLogisticLoss
import efficientnet.tfkeras as efn
from classification_models.tfkeras import Classifiers
from vit_keras import vit, utils

In [6]:
# codes from private dataset
sys.path.append("/kaggle/input/tf-augmentation-class/")
from augmentation import SingleImageAugmentator, MixImageAugmentator

# Create Test TFRecords

In [7]:
import cv2


# Create TF Records
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature`` (BytesList).

    Args:
        value: ``bytes``, or an eager string tensor (converted to bytes first).

    Returns:
        ``tf.train.Feature`` whose bytes_list holds ``[value]``.
    """
    # BytesList won't unpack a string from an EagerTensor, so pull out the raw bytes.
    # tf.is_tensor is a pure type check. The former ``type(tf.constant(0))`` idiom built a
    # tensor on every call, which initialises TF's devices as a side effect — likely why
    # the later GPU setup cell prints "Physical devices cannot be modified after being initialized".
    if tf.is_tensor(value):
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    """Wrap one integer-like value (bool / enum / int / uint) in a ``tf.train.Feature`` (Int64List)."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

def serialize_example(image, target, image_name):
    """Serialize one sample as a ``tf.train.Example`` byte string.

    Args:
        image: encoded image bytes.
        target: integer class label.
        image_name: image file name as bytes.

    Returns:
        The serialized Example, ready for ``TFRecordWriter.write``.
    """
    features = tf.train.Features(feature={
        'image': _bytes_feature(image),
        'target': _int64_feature(target),
        'image_name': _bytes_feature(image_name),
    })
    return tf.train.Example(features=features).SerializeToString()

IMG_PATH = "/kaggle/input/cassava-leaf-disease-classification/test_images/"
test = pd.read_csv("../input/cassava-leaf-disease-classification/sample_submission.csv", dtype={"image_id": "object", "label": "uint8"})

IMGS = os.listdir(IMG_PATH)
N_FILES = 15 # maximum number of TFRecord shards the test images are split into
IMG_QUALITY = 100
os.makedirs("/kaggle/working/tfrecs", exist_ok = True)

print(f'Image samples: {len(IMGS)}')


# Assign rows to shards round-robin by row index so shard sizes stay balanced.
test["file"] = test.index.values%N_FILES
test_filenames = []

for tfrec_num in range(N_FILES):
    samples = test[test['file'] == tfrec_num]
    n_samples = len(samples)
    # With round-robin assignment over a 0..n-1 index, the first empty shard means
    # every later shard is empty too (fewer images than N_FILES).
    if n_samples == 0: break
    # The record count is embedded in the name ("-<n>.tfrec"); count_data_items parses it.
    fname = '/kaggle/working/tfrecs/Id_test%.2i-%i.tfrec'%(tfrec_num, n_samples)
    print('\nWriting TFRecord %i of %i...'%(tfrec_num + 1, N_FILES))
    print(f"filename: {fname}")
    print(f'{n_samples} samples')
    test_filenames.append(fname)
    with tf.io.TFRecordWriter(fname) as writer:
        for row in samples.itertuples():
            label = row.label  # placeholder label taken from sample_submission
            image_name = row.image_id
            img_path = f'{IMG_PATH}{image_name}'

            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread signals failure by returning None; fail with the path instead
                # of an opaque "'NoneType' object is not subscriptable".
                raise FileNotFoundError(f"could not read image: {img_path}")
            # Fixed column crop 100..699: the centred 600-px band of an 800-px-wide image.
            # NOTE(review): assumes test images have the same size as the training images
            # (crop is not computed from img.shape) — must match the training preprocessing.
            img = img[:, 100:700] # center cropping
            img = cv2.resize(img, (512, 512))
            img = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, IMG_QUALITY))[1].tobytes()

            example = serialize_example(img, label, str.encode(image_name))
            writer.write(example)

print("Complete")

Image samples: 1

Writing TFRecord 1 of 15...
filename: /kaggle/working/tfrecs/Id_test00-1.tfrec
1 samples
Complete


# Configurations

In [8]:
# tf.data autotune sentinel, used below for num_parallel_calls / num_parallel_reads / prefetch.
AUTO = tf.data.experimental.AUTOTUNE

# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', tpu.master())
except ValueError:
    # Presumably raised when no TPU is available; fall back to the default strategy below.
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        # Enable memory growth on every GPU, then restrict TensorFlow to only use the first GPU.
        # NOTE(review): both calls must run before TF initialises its devices. The saved output
        # shows them failing ("Physical devices cannot be modified after being initialized"),
        # probably because an earlier cell already created a tensor. The error is only printed,
        # so these settings may silently not be applied.
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
            logical_gpus = tf.config.experimental.list_logical_devices('GPU')
            print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
        except RuntimeError as e:
            # Visible devices must be set before GPUs have been initialized
            print(e)

# Number of model replicas the strategy runs in parallel (1 on CPU / single GPU).
REPLICAS = strategy.num_replicas_in_sync
print("REPLICAS: ", REPLICAS)

Physical devices cannot be modified after being initialized
REPLICAS:  1


In [9]:
def seed_everything(seed=13):
    """Seed Python's ``random``, NumPy and TensorFlow, and set TF-related env flags.

    Args:
        seed: value applied to every generator (default 13).

    NOTE(review): PYTHONHASHSEED / TF_DETERMINISTIC_OPS are only written to os.environ;
    hash randomisation is fixed at interpreter start-up, and TF may read its flag when
    it initialises — confirm these take effect in an already-running kernel.
    """
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
        'TF_KERAS': '1',
    })

In [10]:
# Artifacts of the training run: hyper-parameters (CFG) and cross-validation results (RESULT).
with open(os.path.join(INPUT_DIR, "configuration.json")) as f:
    CFG = json.load(f)
with open(os.path.join(INPUT_DIR, "result.json")) as f:
    RESULT = json.load(f)

# Settings derived from CFG.
OUTPUT_DIR = f"/kaggle/working/{CFG['EXPERIMENT_TAG']}"
INPUT_SHAPE = tuple(CFG["IMAGE_SIZE"]) + (3,)  # (height, width, RGB)

print(CFG)

{'EXPERIMENT_TAG': 'exp40', 'DATA_TYPE': 'CROP-RESIZED', 'USE_EXTERNAL_DATA': True, 'SEED': 40, 'NUM_CLASSES': 5, 'CHANNELS': 3, 'IMAGE_SIZE': [512, 512], 'IMAGE_NORMALIZATION': True, 'NORMALIZATION_MEAN': [0.42984136, 0.49624753, 0.3129598], 'NORMALIZATION_STD': [0.21417203, 0.21910103, 0.19542212], 'BATCH_SIZE': 256, 'AUG_BATCH': 64, 'DO_AUG': True, 'RANDOM_FLIP_LEFT_RIGHT': True, 'RANDOM_FLIP_UP_DOWN': True, 'RANDOM_ROTATION': True, 'RANDOM_ROTATION_RANGE': 45, 'RANDOM_ROTATION_FILL_MODE': 'reflect', 'RANDOM_BRIGHTNESS': True, 'RANDOM_BRIGHTNESS_MAX_DELTA': 0.2, 'RANDOM_CONTRAST': False, 'RANDOM_CONTRAST_LOWER': 0.6, 'RANDOM_CONTRAST_UPPER': 1.4, 'RANDOM_HUE': False, 'RANDOM_HUE_MAX_DELTA': 0.07, 'RANDOM_SATURATION': False, 'RANDOM_SATURATION_LOWER': 0.5, 'RANDOM_SATURATION_UPPER': 1.5, 'DO_MIX_AUG': True, 'MIXUP_PROB': 0.0, 'MIXUP_ALPHA': 1.0, 'CUTMIX_PROB': 1.0, 'CUTMIX_ALPHA': 1.0, 'FMIX_PROB': 1.0, 'FMIX_ALPHA': 1.0, 'FMIX_DECAY': 3.0, 'FOLDS': 5, 'FOLDS_SEED': 40, 'EPOCHS': 25,

## Base Model Tags
* ResNet18
* ResNet50
* ResNet50V2
* ResNet101
* ResNet101V2
* ResNet152
* ResNet152V2
* EfficientNetB0
* EfficientNetB1
* EfficientNetB2
* EfficientNetB3
* EfficientNetB4
* EfficientNetB5
* EfficientNetB6
* EfficientNetB7
* ResNeXt50
* ResNeXt101
* SeResNet50
* SeResNet101
* SeResNeXt50
* SeResNeXt101
* vit_b16
* vit_b32
* vit_l16
* vit_l32

In [11]:
class BaseModel:
    """Factory for a backbone network without a classification head.

    Usage: ``BaseModel(model_tag, weights, freeze_bn, input_shape)()`` builds and returns
    a Keras model. CNN backbones end in global average pooling (``pooling="avg"`` or a
    ``GlobalAveragePooling2D`` layer). ViT backbones are built by ``vit_keras`` with
    ``include_top=False``.

    Args:
        model_tag: backbone tag, e.g. "ResNet50", "ResNet101V2", "EfficientNetB4",
            "SeResNeXt50", "vit_b16" (see the tag tables below).
        weights: None (random init) or "imagenet"; EfficientNets also accept
            "noisy-student". For ViT, any value other than "imagenet" means no
            pretrained weights (not validated, as before).
        freeze_bn: if True, set ``trainable=False`` on BatchNormalization layers that
            are direct children of the built model.
        input_shape: (height, width, channels) of the input images.

    Raises:
        NotImplementedError: unknown ``model_tag`` or unsupported ``weights`` for it.
    """

    # tf.keras.applications backbones; the tag equals the constructor name.
    _KERAS_APPLICATION_TAGS = (
        "ResNet50", "ResNet50V2",
        "ResNet101", "ResNet101V2",
        "ResNet152", "ResNet152V2",
    )
    # `efficientnet.tfkeras` backbones; the tag equals the constructor name.
    _EFFICIENTNET_TAGS = (
        "EfficientNetB0", "EfficientNetB1", "EfficientNetB2", "EfficientNetB3",
        "EfficientNetB4", "EfficientNetB5", "EfficientNetB6", "EfficientNetB7",
    )
    # `classification_models` backbones: tag -> Classifiers name. The name also selects
    # the local no-top ImageNet weight file "<dir>/<name>_imagenet_1000_no_top.h5".
    _CLASSIFIERS_NAMES = {
        "ResNet18": "resnet18",
        "ResNeXt50": "resnext50",
        "ResNeXt101": "resnext101",
        "SeResNet50": "seresnet50",
        "SeResNet101": "seresnet101",
        "SeResNeXt50": "seresnext50",
        "SeResNeXt101": "seresnext101",
    }
    _CLASSIFIERS_WEIGHTS_DIR = "/kaggle/input/keras-pretrained-imagenet-weights"
    # vit_keras backbones; the tag equals the constructor name in `vit`.
    _VIT_TAGS = ("vit_b16", "vit_b32", "vit_l16", "vit_l32")

    def __init__(self, model_tag, weights, freeze_bn, input_shape):
        self.model_tag = model_tag
        self.weights = weights
        self.freeze_bn = freeze_bn
        self.INPUT_SHAPE = input_shape
        self.ERR_MSG_WEIGHT = "weightの指定が不適です！"       # "invalid weights"
        self.ERR_MSG_MODEL_TAG = "Model Tagの指定が不適です！"  # "invalid model tag"

    def __call__(self):
        tag = self.model_tag
        if tag in self._CLASSIFIERS_NAMES:
            base_model = self._build_classifiers_backbone(self._CLASSIFIERS_NAMES[tag])
        elif tag in self._KERAS_APPLICATION_TAGS:
            self._check_weights((None, "imagenet"))
            base_model = getattr(tf.keras.applications, tag)(
                input_shape = self.INPUT_SHAPE, weights = self.weights, include_top = False, pooling = "avg"
            )
        elif tag in self._EFFICIENTNET_TAGS:
            self._check_weights((None, "imagenet", "noisy-student"))
            base_model = getattr(efn, tag)(
                input_shape = self.INPUT_SHAPE, weights = self.weights, include_top = False, pooling = "avg"
            )
        elif tag in self._VIT_TAGS:
            base_model = getattr(vit, tag)(
                image_size=self.INPUT_SHAPE[0],
                activation='softmax',
                pretrained=(self.weights == "imagenet"),
                include_top=False,
                pretrained_top=False,
                weights = 'imagenet21k'
            )
        else:
            raise NotImplementedError(self.ERR_MSG_MODEL_TAG)

        if self.freeze_bn:
            # Only direct children are inspected. For the Sequential-wrapped
            # classification_models backbones, the BN layers sit inside the nested model and
            # are not frozen. NOTE(review): the trainable state can change the order in which
            # Keras saves/loads nested-model weights, so this must stay identical to training.
            for l in base_model.layers:
                if type(l) is tf.keras.layers.BatchNormalization:
                    l.trainable = False

        return base_model

    def _check_weights(self, allowed):
        """Raise NotImplementedError unless ``self.weights`` is one of ``allowed``."""
        if self.weights not in allowed:
            raise NotImplementedError(self.ERR_MSG_WEIGHT)

    def _build_classifiers_backbone(self, name):
        """Build a `classification_models` backbone followed by global average pooling.

        The "imagenet" option resolves to a local weight file held in a local variable,
        so ``self.weights`` is left untouched and the factory can be called repeatedly.
        """
        if self.weights is None:
            weights = None
        elif self.weights == "imagenet":
            weights = f"{self._CLASSIFIERS_WEIGHTS_DIR}/{name}_imagenet_1000_no_top.h5"
        else:
            raise NotImplementedError(self.ERR_MSG_WEIGHT)
        return tf.keras.Sequential(
            [
                Classifiers.get(name)[0](input_shape = self.INPUT_SHAPE, weights = weights, include_top = False),
                tf.keras.layers.GlobalAveragePooling2D()
            ]
        )

In [12]:
# Competition metadata. image_id is kept as a string and the class label as uint8.
# NOTE(review): train_df / test_df / sample_submission are not used further in this notebook.
train_df = pd.read_csv(
    "../input/cassava-leaf-disease-classification/train.csv",
    dtype={"image_id": "object", "label": "uint8"},
)
test_df = pd.read_csv(
    "../input/cassava-leaf-disease-classification/sample_submission.csv",
    dtype={"image_id": "object", "label": "uint8"},
)
sample_submission = pd.read_csv(
    "../input/cassava-leaf-disease-classification/sample_submission.csv"
)

# Datasets Functions

In [13]:
def count_data_items(filenames):
    """Total number of records in a set of TFRecord files, read from their file names.

    Each file name is expected to carry its record count as ``-<count>.<ext>``
    (e.g. ``Id_test00-1.tfrec``). Only the base name is inspected and the last
    ``-<digits>.`` occurrence is used, so hyphens or dots in directory names
    (e.g. ``../input/data-2.0/``) cannot be mistaken for the count.

    Args:
        filenames: iterable of TFRecord paths.

    Returns:
        The summed count as a Python ``int`` (0 for an empty iterable).

    Raises:
        ValueError: if a file name carries no ``-<digits>.`` count.
    """
    pattern = re.compile(r"-(\d+)\.")
    total = 0
    for filename in filenames:
        matches = pattern.findall(os.path.basename(filename))
        if not matches:
            raise ValueError(f"cannot read item count from TFRecord file name: {filename!r}")
        total += int(matches[-1])
    return total
    
def display_items(ds, labeled = True, row = 6, col = 4):
    """Show the first batch of ``ds`` as a ``row`` x ``col`` grid of titled images.

    Args:
        ds: iterable of ``(images, labels)`` batches of tensors (e.g. a batched
            ``tf.data.Dataset``); only the first batch is drawn.
        labeled: True if ``labels`` are numeric (titles rounded to 2 decimals);
            False if they are byte strings such as image names (titles decoded).
        row, col: grid shape. At most ``row * col`` images are drawn. A batch smaller
            than the grid leaves the remaining cells empty instead of raising IndexError.
    """
    for img, label in ds:
        n_show = min(row * col, int(label.shape[0]))
        plt.figure(figsize=(15, int(15 * row / col)))
        for j in range(n_show):
            _label = label[j, ].numpy()
            if labeled:
                _label = np.round(_label, 2)
            else:
                _label = _label.decode()
            plt.subplot(row, col, j + 1)
            plt.axis('off')
            plt.title(str(_label))
            # Undo the dataset's channel normalisation so imshow gets displayable values.
            if CFG["IMAGE_NORMALIZATION"]:
                _img = inverse_preprocess_normalization(img[j, ])
            else:
                _img = img[j, ]
            plt.imshow(_img.numpy())
        plt.show()
        break

In [14]:
# Augmentation objects (single-image and batch-mix), configured from the training CFG.
# Every option whose keyword equals its CFG key is forwarded unchanged.
single_augment = SingleImageAugmentator(
    seed = CFG["SEED"],
    **{key: CFG[key] for key in (
        "RANDOM_FLIP_LEFT_RIGHT",
        "RANDOM_FLIP_UP_DOWN",
        "RANDOM_ROTATION",
        "RANDOM_ROTATION_RANGE",
        "RANDOM_ROTATION_FILL_MODE",
        "RANDOM_BRIGHTNESS",
        "RANDOM_BRIGHTNESS_MAX_DELTA",
        "RANDOM_CONTRAST",
        "RANDOM_CONTRAST_LOWER",
        "RANDOM_CONTRAST_UPPER",
        "RANDOM_HUE",
        "RANDOM_HUE_MAX_DELTA",
        "RANDOM_SATURATION",
        "RANDOM_SATURATION_LOWER",
        "RANDOM_SATURATION_UPPER",
    )},
)
mix_augment = MixImageAugmentator(
    seed = CFG["SEED"],
    AUG_BATCH = CFG["AUG_BATCH"],
    IMAGE_SIZE_0 = CFG["IMAGE_SIZE"][0],
    IMAGE_SIZE_1 = CFG["IMAGE_SIZE"][1],
    CHANNELS = CFG["CHANNELS"],
    CLASSES = CFG["NUM_CLASSES"],
    **{key: CFG[key] for key in (
        "MIXUP_PROB",
        "MIXUP_ALPHA",
        "CUTMIX_PROB",
        "CUTMIX_ALPHA",
        "FMIX_PROB",
        "FMIX_ALPHA",
        "FMIX_DECAY",
    )},
)

In [15]:
def _normalization_stats():
    """Return CFG's per-channel (mean, std) as float32 tensors."""
    return (
        tf.convert_to_tensor(CFG["NORMALIZATION_MEAN"], dtype=tf.float32),
        tf.convert_to_tensor(CFG["NORMALIZATION_STD"], dtype=tf.float32),
    )

def preprocess_normalization(image):
    """Standardise ``image`` channel-wise as ``(image - mean) / std`` using CFG's statistics."""
    mean, std = _normalization_stats()
    return (image - mean) / std

def inverse_preprocess_normalization(image):
    """Invert preprocess_normalization: ``image * std + mean``."""
    mean, std = _normalization_stats()
    return image * std + mean

In [16]:
def decode_image(image):
    """Decode JPEG bytes into a float32 image of shape (*CFG["IMAGE_SIZE"], CFG["CHANNELS"]).

    The image is resized to CFG["IMAGE_SIZE"] and divided by 255. If
    CFG["IMAGE_NORMALIZATION"] is set, it is also standardised with preprocess_normalization.
    """
    pixels = tf.image.resize(tf.image.decode_jpeg(image, channels=3), CFG["IMAGE_SIZE"])
    pixels = tf.reshape(tf.cast(pixels, tf.float32) / 255.0, [*CFG["IMAGE_SIZE"], CFG["CHANNELS"]])
    return preprocess_normalization(pixels) if CFG["IMAGE_NORMALIZATION"] else pixels

def onehot(image,label):
    """Replace the integer ``label`` with its one-hot vector of length CFG["NUM_CLASSES"]."""
    encoded = tf.one_hot(label, CFG["NUM_CLASSES"])
    return image, encoded

def read_tfrecord(example, labeled):
    """Parse one serialized Example into ``(image, label)`` or ``(image, image_name)``.

    Args:
        example: serialized ``tf.train.Example``.
        labeled: if True, read the int64 "target" feature; otherwise read the
            "image_name" string feature.
    """
    key, dtype = ("target", tf.int64) if labeled else ("image_name", tf.string)
    parsed = tf.io.parse_single_example(example, {
        "image": tf.io.FixedLenFeature([], tf.string),
        key: tf.io.FixedLenFeature([], dtype),
    })
    return decode_image(parsed["image"]), parsed[key]

def load_dataset(filenames, labeled=True, ordered=False):
    """Read TFRecord files (interleaved) and parse every record with read_tfrecord.

    Args:
        filenames: TFRecord paths.
        labeled: forwarded to read_tfrecord (label vs. image-name second element).
        ordered: if False, non-deterministic ordering is allowed for speed.
    """
    options = tf.data.Options()
    if not ordered:
        # Use records as soon as they stream in rather than in file order.
        options.experimental_deterministic = False
    parse = lambda record: read_tfrecord(record, labeled=labeled)
    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .with_options(options)
        .map(parse, num_parallel_calls=AUTO)
    )

def get_training_dataset(filenames, do_aug = True, do_mix_aug = True):
    """Infinite, shuffled, one-hot-labelled training pipeline batched to CFG["BATCH_SIZE"].

    Args:
        filenames: TFRecord paths.
        do_aug: apply single_augment to each image.
        do_mix_aug: apply mix_augment to CFG["AUG_BATCH"]-sized batches, then unbatch.
    """
    ds = (
        load_dataset(filenames, labeled = True, ordered = False)
        .repeat()  # the training dataset must repeat for several epochs
        .shuffle(512, seed = CFG["SEED"])
        .map(onehot, num_parallel_calls=AUTO)
    )
    if do_aug:
        ds = ds.map(single_augment, num_parallel_calls=AUTO)
    if do_mix_aug:
        # mix_augment is applied to whole batches, so batch first and unbatch afterwards.
        ds = ds.batch(CFG["AUG_BATCH"]).map(mix_augment, num_parallel_calls=AUTO).unbatch()
    return ds.batch(CFG["BATCH_SIZE"]).prefetch(AUTO)

def get_validation_dataset(filenames):
    """Ordered, one-hot-labelled evaluation pipeline batched to CFG["BATCH_SIZE"]."""
    return (
        load_dataset(filenames, labeled = True, ordered = True)
        .map(onehot, num_parallel_calls = AUTO)
        .batch(CFG["BATCH_SIZE"])
        .prefetch(AUTO)
    )

def get_test_dataset(filenames):
    """Ordered ``(image, image_name)`` pipeline batched to CFG["BATCH_SIZE"]."""
    return (
        load_dataset(filenames, labeled = False, ordered = True)
        .batch(CFG["BATCH_SIZE"])
        .prefetch(AUTO)
    )

# Model

In [17]:
def create_model():
    """Backbone chosen by CFG['MODEL_TAG'] plus a Dense head with CFG["NUM_CLASSES"] linear outputs.

    The backbone is built with ``weights=None``; trained weights are loaded per fold later.
    Layers are created in the same order as before (Input, backbone, Dense) so that
    auto-generated layer names are unchanged.
    """
    inputs = tf.keras.Input(INPUT_SHAPE)
    backbone = BaseModel(
        model_tag=CFG['MODEL_TAG'],
        weights=None,
        freeze_bn=CFG["FREEZE_BN"],
        input_shape=INPUT_SHAPE,
    )()
    features = backbone(inputs)
    logits = tf.keras.layers.Dense(CFG["NUM_CLASSES"])(features)
    return tf.keras.models.Model(inputs=inputs, outputs=logits)

# Inference

In [18]:
def read_tfrecord_ids(example):
    """Parse only the "image_name" string feature of one serialized Example."""
    schema = {"image_name": tf.io.FixedLenFeature([], tf.string)}
    return tf.io.parse_single_example(example, schema)["image_name"]

def get_ds_for_id(filenames):
    """Dataset of image-name batches (TEST_BATCH_SIZE) that never decodes the images."""
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.map(read_tfrecord_ids, num_parallel_calls=AUTO).batch(TEST_BATCH_SIZE)
    
def get_ds_for_prediction(filenames):
    """Ordered inference pipeline covering all files CFG["TTA"] times, batched to TEST_BATCH_SIZE.

    single_augment is applied (to every pass) only when CFG["DO_AUG"] is set and
    TTA >= 2, so TTA == 1 means plain inference. Elements are ``(image, target)``;
    the target is presumably ignored by ``model.predict``.
    """
    n_passes = CFG["TTA"]
    augment = CFG["DO_AUG"] and n_passes >= 2
    dataset = load_dataset(filenames, labeled = True, ordered = True).repeat(n_passes)
    if augment:
        dataset = dataset.map(single_augment, num_parallel_calls=AUTO)
    return dataset.batch(TEST_BATCH_SIZE).prefetch(AUTO)

def make_prediction(filenames, model):
    """Run TTA inference over ``filenames`` and average the outputs per image.

    Returns:
        DataFrame indexed by image_id with columns label_0..label_{NUM_CLASSES-1}, holding
        the mean model output over the CFG["TTA"] passes. With a model from create_model
        (a Dense head without activation), these are averaged logits, not probabilities.
    """
    n_passes = CFG["TTA"]
    # Ids for one pass, repeated to mirror dataset.repeat(TTA) in get_ds_for_prediction.
    # NOTE(review): assumes both pipelines yield records in the same order.
    id_batches = [batch.astype("str") for batch in get_ds_for_id(filenames).as_numpy_iterator()]
    image_ids = np.concatenate(id_batches * n_passes)

    preds = model.predict(get_ds_for_prediction(filenames), verbose = 1)

    columns = {"image_id": image_ids}
    for i in range(CFG["NUM_CLASSES"]):
        columns[f"label_{i}"] = preds[:, i]
    return pd.DataFrame(columns).groupby("image_id").mean()

In [19]:
%%time
# Fixed seeds so the random TTA augmentations, and therefore the submission, are reproducible
# (seed_everything was defined earlier but never called).
seed_everything(CFG["SEED"])

# One model object is reused for every fold; only its weights are swapped.
# No tf.keras.backend.clear_session() here: it resets Keras' global state while this
# model is still in use, and frees nothing because the model stays referenced.
model = create_model()

fold_predictions = []
for fold in range(CFG["FOLDS"]):
    print("#" * 30, f"fold {fold + 1}/{CFG['FOLDS']}", "#" * 30)
    weights_file_path = f"weights_{CFG['MODEL_TAG']}_fold{fold}.h5"
    model.load_weights(os.path.join(INPUT_DIR, weights_file_path))

    # make sub prediction (one row per image for this fold)
    fold_predictions.append(make_prediction(test_filenames, model))

# Stack once instead of concatenating inside the loop; folds are averaged per image later.
sub_prob = pd.concat(fold_predictions) if fold_predictions else pd.DataFrame()

############################## fold 1/5 ##############################
############################## fold 2/5 ##############################
############################## fold 3/5 ##############################
############################## fold 4/5 ##############################
############################## fold 5/5 ##############################
CPU times: user 16.2 s, sys: 1.31 s, total: 17.5 s
Wall time: 26.5 s


In [20]:
# Per-fold outputs (one row per fold and image), before averaging across folds.
sub_prob

Unnamed: 0_level_0,label_0,label_1,label_2,label_3,label_4
image_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2216849948.jpg,-0.815305,-0.370056,1.159971,-1.831684,1.691907
2216849948.jpg,-1.27922,-0.633256,2.01014,-1.904642,1.533733
2216849948.jpg,-1.062366,-0.645648,1.112429,-1.065102,1.69814
2216849948.jpg,-1.107527,-0.105813,1.738636,-1.896744,1.971731
2216849948.jpg,-0.704563,-0.697859,1.349013,-1.480617,1.700505


In [21]:
# Average the per-fold rows into one row per image, then pick the highest-scoring class.
# NOTE(review): with create_model's linear head, "probabilities" here are averaged logits.
sub_prob = sub_prob.groupby("image_id").mean()
sub_prob.to_csv("submission_probabilities.csv")
sub = pd.DataFrame(
    {"label": np.argmax(sub_prob.values, axis = 1)},
    index = sub_prob.index,
)
sub.to_csv("submission.csv")

In [22]:
# Final submission: predicted class index per image_id.
sub

Unnamed: 0_level_0,label
image_id,Unnamed: 1_level_1
2216849948.jpg,4


In [23]:
# Keep the training configuration and CV results alongside this run's outputs.
# makedirs(exist_ok=True) keeps the cell re-runnable; os.mkdir raised if the directory existed.
os.makedirs(OUTPUT_DIR, exist_ok=True)

with open(os.path.join(OUTPUT_DIR, "configuration.json"), "w") as f:
    json.dump(CFG, f)

with open(os.path.join(OUTPUT_DIR, "result.json"), "w") as f:
    json.dump(RESULT, f)