<center><img src="https://raw.githubusercontent.com/dimitreOliveira/MachineLearning/master/Kaggle/Cassava%20Leaf%20Disease%20Classification/banner.png" width="1000"></center>
<br>
<center><h1>Cassava Leaf Disease - TPU v2 Pods - Inference</h1></center>
<br>

- This is the inference part of the work; the training notebook can be found here: [Cassava Leaf Disease - Training with TPU v2 Pods](https://www.kaggle.com/dimitreoliveira/cassava-leaf-disease-training-with-tpu-v2-pods)
- keras-applications GitHub repository can be found [here](https://www.kaggle.com/dimitreoliveira/kerasapplications)
- efficientnet GitHub repository can be found [here](https://www.kaggle.com/dimitreoliveira/efficientnet-git)
- Dataset source `center cropped` [512x512](https://www.kaggle.com/dimitreoliveira/cassava-leaf-disease-50-tfrecords-center-512x512)
- Dataset source `external data` `center cropped` [512x512](https://www.kaggle.com/dimitreoliveira/cassava-leaf-disease-50-tfrecords-external-512x512)
- Dataset source [discussion thread](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/198744)
- Dataset [creation source](https://www.kaggle.com/dimitreoliveira/cassava-leaf-disease-stratified-tfrecords-256x256)
- Model experiments [discussion thread](https://www.kaggle.com/c/cassava-leaf-disease-classification/discussion/203594)

## Dependencies

In [None]:
# Install the required packages (keras-applications and efficientnet) from local copies under /kaggle/input
!pip install --quiet /kaggle/input/kerasapplications
!pip install --quiet /kaggle/input/efficientnet-git

In [None]:
import math, os, re, warnings, random, glob
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras.layers as L
import tensorflow.keras.backend as K
from tensorflow.keras import Sequential, Model
import efficientnet.tfkeras as efn

def seed_everything(seed=0):
    """Seed the Python, NumPy and TensorFlow random generators and set determinism flags.

    Args:
        seed: Integer seed handed to every generator (defaults to 0).
    """
    # Same seed for each library-level generator, in the order Python -> NumPy -> TensorFlow.
    for apply_seed in (random.seed, np.random.seed, tf.random.set_seed):
        apply_seed(seed)

    # Environment flags: Python hash seed and the TensorFlow deterministic-ops switch.
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })

# Seed every random generator once, before any dataset or model is built
seed = 1234
seed_everything(seed)
# Suppress all Python warnings for the rest of the notebook
warnings.filterwarnings('ignore')

In [None]:
# Hardware detection: choose the distribution strategy for the available accelerator.
# If a TPU cluster can be resolved it is used; otherwise TensorFlow's current strategy is kept.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print(f'Running on TPU {tpu.master()}')
except ValueError:
    # Raised when no TPU can be resolved
    tpu = None

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    # Connect to the TPU cluster and initialize it before creating the strategy
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)

# Shared handles used by later cells: autotune marker for tf.data and the replica count
AUTO = tf.data.experimental.AUTOTUNE
REPLICAS = strategy.num_replicas_in_sync
print(f'REPLICAS: {REPLICAS}')

In [None]:
# Model / inference parameters
BATCH_SIZE = 32 * REPLICAS  # batch size, scaled by the number of replicas in sync
HEIGHT = 512                # image height fed to the model
WIDTH = 512                 # image width fed to the model
CHANNELS = 3                # number of image channels
N_CLASSES = 5               # number of target classes
TTA_STEPS = 10              # number of test-time-augmentation passes; TTA is done if > 0

In [None]:
# Augmentation (partly taken from the first/training notebook)
def data_augment(image, label):
    """Randomly flip, transpose, rotate, colour-jitter and crop a single image.

    Each family of transforms is gated by its own uniform draw in [0, 1).

    Args:
        image: Image tensor. The random-crop branch requests a
            [crop_size, crop_size, CHANNELS] window, so the input is
            presumably height x width x CHANNELS -- TODO confirm against caller.
        label: Value paired with the image; returned untouched.

    Returns:
        Tuple (image, label). The image's spatial size may be smaller than the
        input's because of the crop branches; no resizing is done here.
    """
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
            
    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # Transpose when the draw exceeds .75
    if p_spatial > .75:
        image = tf.image.transpose(image)
        
    # Rotations: one draw selects 270, 180, 90 degrees or no rotation (equal-width bands)
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90º
        
    # Pixel-level transforms (each applied independently when its draw is >= .4)
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
        
    # Crops
    # draw > .7: central crop keeping 70%, 80% or 90% of the image, chosen by sub-band
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.7)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.8)
        else:
            image = tf.image.central_crop(image, central_fraction=.9)
    # .4 < draw <= .7: random square crop with a side drawn from [int(HEIGHT*.8), HEIGHT)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.8), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])

    return image, label

In [None]:
# File name
def get_name(file_path):
    """Return the last path component (the file name) of `file_path` as a string tensor."""
    # Split on the OS path separator and keep the final piece
    return tf.strings.split(file_path, os.path.sep)[-1]

# Decoding
def decode_image(image_data):
    """Decode JPEG bytes into a 3-channel float32 image scaled to [0, 1].

    Args:
        image_data: Encoded JPEG contents.

    Returns:
        float32 image tensor; no cropping or resizing is applied here.
    """
    pixels = tf.image.decode_jpeg(image_data, channels=3)  # decode as 3 channels
    # Cast to float and divide by 255 to normalize the pixel values
    return tf.cast(pixels, tf.float32) / 255.0


def center_crop(image):
    """Crop the centred square out of a 600x800 image and resize it to HEIGHT x WIDTH.

    Args:
        image: Image tensor reshapeable to [600, 800, CHANNELS].

    Returns:
        Image tensor resized to [HEIGHT, WIDTH].
    """
    image = tf.reshape(image, [600, 800, CHANNELS])  # original image size

    rows, cols = image.shape[0], image.shape[1]
    # The square side is the shorter edge; the longer edge is trimmed equally on both ends
    side = min(rows, cols)
    offset_rows = (rows - side) // 2
    offset_cols = (cols - side) // 2
    image = tf.image.crop_to_bounding_box(image, offset_rows, offset_cols, side, side)

    return tf.image.resize(image, [HEIGHT, WIDTH])  # expected model input shape

def resize_image(image, label):
    """Resize `image` to HEIGHT x WIDTH and pin its static shape; `label` passes through."""
    target_hw = [HEIGHT, WIDTH]
    resized = tf.image.resize(image, target_hw)
    # Reshape to the full [HEIGHT, WIDTH, CHANNELS] shape
    return tf.reshape(resized, target_hw + [CHANNELS]), label


def process_path(file_path):
    """Read and decode the image at `file_path`.

    Returns:
        Tuple (decoded image, file name).
    """
    raw_bytes = tf.io.read_file(file_path)
    return decode_image(raw_bytes), get_name(file_path)


def get_dataset(files_path, shuffled=False, tta=False, extension='jpg'):
    """Build a batched, prefetched dataset of (image, file name) pairs.

    Args:
        files_path: Path prefix; files matching f'{files_path}*{extension}' are listed.
        shuffled: Passed to `list_files` as `shuffle`; False keeps a fixed file order.
        tta: When True, `data_augment` is applied to every image (test-time
            augmentation). When False, images are only decoded and resized.
        extension: File-name suffix used in the glob pattern.

    Returns:
        A `tf.data.Dataset` of batches of size BATCH_SIZE with images resized to
        HEIGHT x WIDTH.
    """
    dataset = tf.data.Dataset.list_files(f'{files_path}*{extension}', shuffle=shuffled)
    dataset = dataset.map(process_path, num_parallel_calls=AUTO)
    if tta:
        # Random augmentation is only wanted for TTA passes; plain prediction and
        # file-name collection must see the unmodified images.
        dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.map(resize_image, num_parallel_calls=AUTO)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)
    return dataset


def count_data_items(filenames):
    """Sum the item counts encoded in a list of file names.

    Each name must contain a '-<digits>.' fragment; the digits are read as the
    number of items stored in that file.

    Args:
        filenames: Iterable of file-name strings.

    Returns:
        The sum of the parsed counts, as returned by `np.sum`.
    """
    count_pattern = re.compile(r"-([0-9]*)\.")
    counts = []
    for filename in filenames:
        counts.append(int(count_pattern.search(filename).group(1)))
    return np.sum(counts)

In [None]:
# Load the datasets: read the sample submission and locate the test TFRecord files
database_base_path = '/kaggle/input/cassava-leaf-disease-classification/'
submission = pd.read_csv(f'{database_base_path}sample_submission.csv')
display(submission.head())

# The number of test images is parsed from the TFRecord file names by count_data_items
TEST_FILENAMES = tf.io.gfile.glob(f'{database_base_path}test_tfrecords/ld_test*.tfrec')
NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
print(f'GCS: test: {NUM_TEST_IMAGES}')

In [None]:
# Collect every saved .h5 weight file, sorted so the models are always processed in the same order
model_path_list = sorted(glob.glob('/kaggle/input/cassava-leaf-disease-training-with-tpu-v2-pods/*.h5'))

print('Models to predict:')
print('\n'.join(model_path_list))

In [None]:

def model_fn(input_shape, N_CLASSES):
    """Build the EfficientNetB4 classifier used for inference.

    Args:
        input_shape: Shape of one input image, passed to the `Input` layer.
        N_CLASSES: Number of units in the output layer (one per class).

    Returns:
        A Keras `Model` mapping an image to N_CLASSES class probabilities.
    """
    inputs = L.Input(shape=input_shape, name='input_image')
    # EfficientNet backbone without its classification top; weights are left
    # uninitialized here and loaded later from the saved checkpoints.
    base_model = efn.EfficientNetB4(input_tensor=inputs, 
                                    include_top=False, 
                                    weights=None, 
                                    pooling='avg')

    x = L.Dropout(0.4)(base_model.output)
    # Softmax gives a probability distribution over the classes, which is what
    # the TTA / ensemble averaging and the final argmax expect.
    output = L.Dense(N_CLASSES, activation='softmax', name='output')(x)
    model = Model(inputs=inputs, outputs=output)

    return model

# Build the model inside the distribution strategy's scope; spatial dimensions are left unspecified (None)
with strategy.scope():
    model = model_fn((None, None, CHANNELS), N_CLASSES)
    
model.summary()

In [None]:
# Predict the test set with every saved model and average the results (model ensemble)
files_path = f'{database_base_path}test_images/'
test_size = len(os.listdir(files_path))
test_preds = np.zeros((test_size, N_CLASSES))  # running ensemble average of the class scores


for model_path in model_path_list:
    print(model_path)
    K.clear_session()
    model.load_weights(model_path)

    if TTA_STEPS > 0:
        test_ds = get_dataset(files_path, tta=True).repeat()
        # One pass over the test set takes ceil(test_size / BATCH_SIZE) batches
        # (the last batch may be partial), so this integer step count yields
        # exactly TTA_STEPS full passes.
        ct_steps = TTA_STEPS * math.ceil(test_size / BATCH_SIZE)
        preds = model.predict(test_ds, steps=ct_steps, verbose=1)[:(test_size * TTA_STEPS)]
        # Rows are ordered pass by pass; the Fortran-order reshape groups the
        # TTA_STEPS predictions of each image together before averaging them.
        preds = np.mean(preds.reshape(test_size, TTA_STEPS, N_CLASSES, order='F'), axis=1)
        test_preds += preds / len(model_path_list)
    else:
        test_ds = get_dataset(files_path, tta=False)
        # Drop the file names so the model only receives the images
        x_test = test_ds.map(lambda image, image_name: image)
        test_preds += model.predict(x_test) / len(model_path_list)
    
# Final label = class with the highest averaged score
test_preds = np.argmax(test_preds, axis=-1)
# File names are collected from a dataset built the same way (unshuffled), so they line up with the predictions
test_names_ds = get_dataset(files_path)
image_names = [img_name.numpy().decode('utf-8') for img, img_name in iter(test_names_ds.unbatch())]

In [None]:
# Build the submission CSV: one row per test image with its predicted label
submission = pd.DataFrame({'image_id': image_names, 'label': test_preds})
submission.to_csv('submission.csv', index=False)
display(submission.head())