# Training notebook using TPU for "Cassava leaf disease classification" competition

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
import json
import re
from pathlib import Path
import numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
import matplotlib.image as mpimg
import tensorflow as tf

warnings.filterwarnings("ignore")  # hides every Python warning for the rest of the session -- NOTE(review): can mask real problems
sns.set()  # apply seaborn's default styling to all matplotlib figures

In [3]:
# BASE PATHS
BASE_DIR = Path("../input/cassava-leaf-disease-classification") #Path to data directory
MODELS_DIR = Path("../input/training") #Path to saved models
IMAGE_DIR = Path(BASE_DIR, "train_images") #Path to images directory
OUTPUT_DIR = Path("./") #Path to 'output' directory

In [4]:
label_map_path = Path(BASE_DIR, "label_num_to_disease_map.json")
with label_map_path.open('r') as label_file:
    raw_label_map = json.load(label_file)
# JSON object keys are always strings; turn them back into integer class ids
map_classes = {int(class_id): disease for class_id, disease in raw_label_map.items()}
map_classes

{0: 'Cassava Bacterial Blight (CBB)',
 1: 'Cassava Brown Streak Disease (CBSD)',
 2: 'Cassava Green Mottle (CGM)',
 3: 'Cassava Mosaic Disease (CMD)',
 4: 'Healthy'}

In [5]:
data = pd.read_csv(Path(BASE_DIR, "train.csv"))
data['class_name'] = data['label'].map(map_classes)
# Share of each class, indexed by the actual label value (not by position), so a weight
# always stays attached to the right class and no class count has to be hard-coded.
class_share = data['label'].value_counts(normalize=True).sort_index()
freqs = 1 - class_share.to_numpy()
# Rarer classes get larger weights: 10 * (1 - share). Keys are the integer labels.
LOSS_WEIGHTS = {int(label): 10 * (1 - share) for label, share in class_share.items()} #Weights for loss function

In [6]:
# For TPU training
TPU = tf.distribute.cluster_resolver.TPUClusterResolver.connect() #Detect and init the TPU
print('Device:', TPU.master())
TPU_STRATEGY = tf.distribute.experimental.TPUStrategy(TPU) #Instantiate a distribution strategy
REPLICAS = TPU_STRATEGY.num_replicas_in_sync
AUTO = tf.data.experimental.AUTOTUNE

Device: grpc://10.0.0.2:8470


In [7]:
from tensorflow.keras.backend import clear_session
from tensorflow.keras.mixed_precision import experimental as mixed_precision

from tensorflow.keras import Input
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Flatten, Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.applications import Xception, InceptionV3, EfficientNetB0, EfficientNetB4, EfficientNetB7
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy, CategoricalCrossentropy, SparseCategoricalCrossentropy
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint, LearningRateScheduler


# policy = mixed_precision.Policy('mixed_bfloat16')
# mixed_precision.set_policy(policy) #shortens training time by 2x

In [8]:
#Global variables
DEBUG = False
SEED = 117
IMAGE_SIZE = 512 #Parameter to choose wisely
TARGET_SIZE = (IMAGE_SIZE, IMAGE_SIZE)
INPUT_SHAPE = (IMAGE_SIZE, IMAGE_SIZE, 3)
N_CLASSES = 5

In [9]:
from kaggle_datasets import KaggleDatasets
from sklearn.model_selection import train_test_split

# GCS location of the competition data; the tf.data pipeline reads the TFRecords from here
GCS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')
FILENAMES = tf.io.gfile.glob(GCS_PATH+'/train_tfrecords/*.tfrec')
# Hold out 10% of the TFRecord *files* (not individual images) as a final test set;
# the remaining files are the ones passed to the k-fold training below.
FILENAMES, TEST_FILENAMES = train_test_split(FILENAMES, test_size=0.1, random_state=SEED)

In [10]:
# Datasets utility functions
def one_hot(image, label):
    """Return `(image, label)` with the integer `label` expanded to a float32 one-hot vector of length N_CLASSES."""
    return image, tf.one_hot(label, N_CLASSES, dtype=tf.float32)

def decode_image(image_data):
    """
        Decode a JPEG byte string into a model-ready image tensor.

        1. Decode the JPEG to a uint8 tensor with 3 channels.
        2. Cast to float32 and apply the backbone-specific PREPROCESS_FUNC
           (e.g. Xception's scaling). No separate /255 normalization happens here.
        3. Resize to IMAGE_SIZE x IMAGE_SIZE and pin the static shape.
    """
    image = tf.image.decode_jpeg(image_data, channels=3)
    # float32 is plenty for pixel values; float64 only doubled memory and compute in
    # the input pipeline (tf.image.resize returns float32 anyway).
    image = tf.cast(image, tf.float32)
    image = PREPROCESS_FUNC(image)

    image = tf.image.resize(image, [IMAGE_SIZE, IMAGE_SIZE])
    # Static shape so the TPU compiler knows the exact tensor dimensions.
    image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3])

    return image

def read_tfrecord(example, labeled=True):
    """
        Parse one serialized TFRecord example.

        Every record carries a JPEG 'image'. Labeled records add an integer 'target',
        unlabeled ones an 'image_name'. Returns (image, int32 target) when `labeled`,
        otherwise (image, image_name).
    """
    second_key = 'target' if labeled else 'image_name'
    second_dtype = tf.int64 if labeled else tf.string
    feature_spec = {
        'image': tf.io.FixedLenFeature([], tf.string),
        second_key: tf.io.FixedLenFeature([], second_dtype),
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if not labeled:
        return image, parsed['image_name']
    return image, tf.cast(parsed['target'], tf.int32)


def load_dataset(filenames, labeled=True, ordered=False):
    """
        Build a tf.data pipeline of parsed (image, label) pairs from TFRecord files.

        When `ordered` is False, deterministic ordering is switched off so records from
        the parallel file reads are emitted as soon as they are ready.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False

    def _parse(record):
        return read_tfrecord(record, labeled=labeled)

    return (
        tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
        .with_options(options)
        .map(_parse, num_parallel_calls=AUTO)
    )


def get_dataset(filenames, labeled=True, ordered=False, repeated=False, augment=True,
                drop_remainder=False, shuffle_buffer=2048):
    """
        Return a Tensorflow dataset ready for training or inference.

        Pipeline: parse TFRecords -> one-hot labels (labeled only) -> optional repeat
        -> shuffle (unless ordered) -> optional CutMix/MixUp/flip augmentation
        -> batch -> prefetch.

        Args:
            filenames: TFRecord paths to read.
            labeled: records carry an integer 'target' (True) or an 'image_name' (False).
            ordered: keep record order; when False the samples are also shuffled.
            repeated: repeat the dataset indefinitely (used with `steps_per_epoch`).
            augment: apply `transform` to groups of AUG_BATCH samples; needs labels.
            drop_remainder: drop the last incomplete BATCH_SIZE batch.
            shuffle_buffer: number of samples held in the shuffle buffer.

        Raises:
            ValueError: if `augment` is requested for unlabeled data.
    """
    if augment and not labeled:
        raise ValueError("augment=True requires labeled=True (CutMix/MixUp mix the labels).")

    dataset = load_dataset(filenames, labeled=labeled, ordered=ordered)
    if labeled:
        # Unlabeled records carry the image name (a string), which cannot be one-hot encoded.
        dataset = dataset.map(one_hot, num_parallel_calls=AUTO)
    if repeated:
        dataset = dataset.repeat()
    if not ordered:
        # Shuffle *before* augmentation so CutMix/MixUp partners come from a well-mixed
        # buffer. The buffer size is its own parameter; SEED seeds the shuffle.
        dataset = dataset.shuffle(shuffle_buffer, seed=SEED)
    if augment:
        # `transform` indexes exactly AUG_BATCH images, so a trailing partial group
        # (possible when the dataset is not repeated) must be dropped.
        dataset = dataset.batch(AUG_BATCH, drop_remainder=True)
        dataset = dataset.map(transform, num_parallel_calls=AUTO)
        dataset = dataset.unbatch()
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=drop_remainder)
    dataset = dataset.prefetch(AUTO)

    return dataset

In [11]:
def cutmix(image, label, prob=1.0):
    """
        CutMix augmentation over a batch.

        For each position j, with probability `prob`, a random square patch of another
        image k from the same batch is pasted into image j. The labels are then mixed
        in proportion to the area actually pasted.

        image: a batch of images of size [AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3]
        label: integer labels of shape [AUG_BATCH] or one-hot labels of shape
               [AUG_BATCH, N_CLASSES]
        Returns (images, soft labels) of shapes [AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3]
        and [AUG_BATCH, N_CLASSES].
    """
    imgs = []; labs = []
    for j in range(AUG_BATCH):
        # DO CUTMIX WITH PROBABILITY `prob` (P is 0 or 1; P = 0 yields an empty box)
        P = tf.cast(tf.random.uniform([], 0, 1)<=prob, tf.int32)
        # CHOOSE RANDOM IMAGE TO CUTMIX WITH
        k = tf.cast(tf.random.uniform([], 0, AUG_BATCH), tf.int32)
        # CHOOSE RANDOM BOX CENTER
        x = tf.cast(tf.random.uniform([], 0, IMAGE_SIZE), tf.int32)
        y = tf.cast(tf.random.uniform([], 0, IMAGE_SIZE), tf.int32)
        b = tf.random.uniform([], 0, 1) # this is beta dist with alpha=1.0
        width = tf.cast(IMAGE_SIZE*tf.math.sqrt(1-b), tf.int32)*P
        # Box corners, clipped to the image borders
        ya = tf.math.maximum(0, y-width//2)
        yb = tf.math.minimum(IMAGE_SIZE, y+width//2)
        xa = tf.math.maximum(0, x-width//2)
        xb = tf.math.minimum(IMAGE_SIZE, x+width//2)
        # MAKE CUTMIX IMAGE: rows ya:yb, columns xa:xb come from image k
        one = image[j, ya:yb, 0:xa, :]
        two = image[k, ya:yb, xa:xb, :]
        three = image[j, ya:yb, xb:IMAGE_SIZE, :]
        middle = tf.concat([one, two, three], axis=1)
        img = tf.concat([image[j, 0:ya, :, :], middle, image[j, yb:IMAGE_SIZE, :, :]], axis=0)
        imgs.append(img)
        # MAKE CUTMIX LABEL: weight by the area actually pasted. The box is clipped at the
        # image border (and width//2 rounds odd widths down), so width*width overstated
        # image k's share of the mixed label.
        a = tf.cast((yb-ya)*(xb-xa), tf.float32)/float(IMAGE_SIZE*IMAGE_SIZE)
        if len(label.shape) == 1:
            lab1 = tf.one_hot(label[j], N_CLASSES)
            lab2 = tf.one_hot(label[k], N_CLASSES)
        else:
            lab1 = label[j, ]
            lab2 = label[k, ]
        labs.append((1-a)*lab1 + a*lab2)

    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR
    image2 = tf.reshape(tf.stack(imgs), (AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3))
    label2 = tf.reshape(tf.stack(labs), (AUG_BATCH, N_CLASSES))

    return image2, label2


def mixup(image, label, prob=1.0):
    """
        MixUp augmentation over a batch.

        For every position j, with probability `prob`, image j is blended with a
        randomly chosen image k of the same batch using a weight drawn from U(0, 1)
        (i.e. Beta(1, 1)). The labels are blended with the same weight. When the
        draw fails the weight is 0 and the sample is returned unchanged.

        image: a batch of images of size [AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3]
        label: integer labels of shape [AUG_BATCH] or one-hot labels of shape
               [AUG_BATCH, N_CLASSES]
    """
    sparse_labels = len(label.shape) == 1
    mixed_images, mixed_labels = [], []
    for j in range(AUG_BATCH):
        use_mix = tf.cast(tf.random.uniform([], 0, 1) <= prob, tf.float32)
        partner = tf.cast(tf.random.uniform([], 0, AUG_BATCH), tf.int32)
        weight = tf.random.uniform([], 0, 1) * use_mix
        # MixUp image
        mixed_images.append((1 - weight) * image[j, ] + weight * image[partner, ])
        # MixUp label (same weight as the image)
        if sparse_labels:
            own = tf.one_hot(label[j], N_CLASSES)
            other = tf.one_hot(label[partner], N_CLASSES)
        else:
            own = label[j, ]
            other = label[partner, ]
        mixed_labels.append((1 - weight) * own + weight * other)

    # Reshape with static sizes so the TPU compiler knows the output shape.
    out_images = tf.reshape(tf.stack(mixed_images), (AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3))
    out_labels = tf.reshape(tf.stack(mixed_labels), (AUG_BATCH, N_CLASSES))
    return out_images, out_labels

In [12]:
def apply_flips(image):
    """Randomly flip a single image left/right and, independently, up/down."""
    flipped = tf.image.random_flip_left_right(image)
    return tf.image.random_flip_up_down(flipped)

def transform(image, label):
    """
        Augment a batch of AUG_BATCH samples.

        Every image is first randomly flipped. Both a CutMix and a MixUp version of the
        batch are then built, and for each sample one of the two is kept: the CutMix one
        with probability `switch`, the MixUp one otherwise.

        image: a batch of images of size [AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3]
        label: labels accepted by `cutmix` / `mixup`
        Returns (images, soft labels) of shapes [AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3]
        and [AUG_BATCH, N_CLASSES].
    """
    switch = 0.5         # share of samples taking the CutMix branch
    cutmix_prob = 0.666  # chance CutMix actually mixes a sample
    mixup_prob = 0.666   # chance MixUp actually mixes a sample

    flipped = [apply_flips(image[j, ]) for j in range(AUG_BATCH)]
    flipped = tf.reshape(tf.stack(flipped), (AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3))

    cut_images, cut_labels = cutmix(flipped, label, cutmix_prob)
    mix_images, mix_labels = mixup(flipped, label, mixup_prob)

    out_images, out_labels = [], []
    for j in range(AUG_BATCH):
        # 1.0 keeps the CutMix sample, 0.0 keeps the MixUp sample
        pick = tf.cast(tf.random.uniform([], 0, 1) <= switch, tf.float32)
        out_images.append(pick * cut_images[j, ] + (1 - pick) * mix_images[j, ])
        out_labels.append(pick * cut_labels[j, ] + (1 - pick) * mix_labels[j, ])

    # Reshape with static sizes so the TPU compiler knows the output shape.
    return (tf.reshape(tf.stack(out_images), (AUG_BATCH, IMAGE_SIZE, IMAGE_SIZE, 3)),
            tf.reshape(tf.stack(out_labels), (AUG_BATCH, N_CLASSES)))

In [13]:
# Data preprocessing functions
# def augment_train(image, label):
#     p_flip_lr = np.random.uniform()
#     p_flip_ud = np.random.uniform()
#     if p_flip_lr >= 0.5:
#         image = tf.image.random_flip_left_right(image)
#     if p_flip_ud >= 0.5:
#         image = tf.image.random_flip_up_down(image)
    
#     return image, label

# def augment_test(image, label):
#     return image, label

def to_float32(image, label):
    """Cast `image` to float32 and pass `label` through unchanged."""
    image_f32 = tf.cast(image, tf.float32)
    return image_f32, label

def count_data_items(filenames):
    """
        Total number of examples across TFRecord files, read from their file names.

        Files are expected to be named '<prefix>-<count>.<ext>', e.g.
        'train00-1338.tfrec' -> 1338. Returns a Python int (0 for an empty list).

        Raises:
            ValueError: if a filename does not end with '-<digits>.<ext>'.
    """
    # Anchored at the end of the path so '-<digits>.' patterns in bucket or directory
    # names are never picked up, and at least one digit is required.
    pattern = re.compile(r'-([0-9]+)\.[^./]+$')
    total = 0
    for filename in filenames:
        match = pattern.search(str(filename))
        if match is None:
            raise ValueError("Cannot read item count from filename: {}".format(filename))
        total += int(match.group(1))

    return total

In [14]:
def build_model(model_name, num_classes=N_CLASSES, pretrained=True, freeze=False, dropout_rate=0.2):
    """
        Build a classifier on top of a `tf.keras.applications` backbone.

        Args:
            model_name: name of a tf.keras.applications constructor, e.g. "Xception".
            num_classes: size of the softmax output.
            pretrained: load ImageNet weights into the backbone.
            freeze: when pretrained, make the backbone non-trainable (ignored otherwise).
            dropout_rate: dropout applied to the pooled features while training.

        Returns:
            An uncompiled `tf.keras.Model` mapping INPUT_SHAPE images to class probabilities.
    """
    weights = "imagenet" if pretrained else None
    trainable = not (pretrained and freeze)

    base_model = getattr(tf.keras.applications, model_name)(include_top=False,
                                                            weights=weights,
                                                            input_shape=INPUT_SHAPE)
    base_model.trainable = trainable

    inputs = Input(shape=INPUT_SHAPE)
    x = base_model(inputs)
    x = GlobalAveragePooling2D()(x)
    # Let Keras decide the phase: dropout is active during fit() only. Forcing
    # training=True also applied it in evaluate()/predict(), making validation and
    # test predictions random.
    x = Dropout(dropout_rate)(x)
    # float32 output keeps the softmax stable if a mixed-precision policy is enabled.
    outputs = Dense(num_classes, activation="softmax", dtype='float32')(x)
    model = Model(inputs, outputs)

    return model

In [15]:
#LR scheduler (rampup) HP
LR_START = 1e-5
LR_MAX = 5e-5*REPLICAS
LR_MIN = 1e-5
LR_RAMPUP_EPOCHS = 5
LR_SUSTAIN_EPOCHS = 0
LR_EXP_DECAY = .8

def lrfn(epoch):
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        lr = LR_MAX
    else:
        lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    return lr
    
lr_scheduler = LearningRateScheduler(lrfn, verbose = True)

In [16]:
from keras.callbacks import Callback
from keras.backend import set_value, get_value
from keras import backend as Kk


class CosineAnnealingScheduler(tf.keras.callbacks.Callback):
    """
        Cosine annealing learning-rate scheduler.

        At the start of every epoch it sets
            lr = eta_min + (eta_max - eta_min) * (1 + cos(pi * epoch / T_max)) / 2
        so the rate falls from eta_max (epoch 0) to eta_min (epoch T_max). Past T_max the
        cosine continues and the rate rises again (period 2 * T_max).

        Built on tf.keras (not the standalone `keras` package) so it matches the tf.keras
        models and callbacks used everywhere else in this notebook.

        Args:
            T_max: half-period of the cosine, in epochs (must be > 0).
            eta_max: learning rate at epoch 0.
            eta_min: learning rate at epoch T_max.
            verbose: print the new rate every epoch when > 0.
    """
    def __init__(self, T_max, eta_max, eta_min=0, verbose=0):
        super(CosineAnnealingScheduler, self).__init__()
        if T_max <= 0:
            raise ValueError('T_max must be a positive number of epochs.')
        self.T_max = T_max
        self.eta_max = eta_max
        self.eta_min = eta_min
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        lr = self.eta_min + (self.eta_max - self.eta_min) * (1 + np.cos(np.pi * epoch / self.T_max)) / 2
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        if self.verbose > 0:
            print('\nEpoch %05d: CosineAnnealingScheduler setting learning '
                  'rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        # Record the current LR in the epoch logs. Write into the dict Keras passed in,
        # even if it is empty (`logs or {}` would replace an empty dict and lose the value).
        if logs is None:
            logs = {}
        logs['lr'] = tf.keras.backend.get_value(self.model.optimizer.lr)

### Single model

In [17]:
# BATCH_SIZE = 64*REPLICAS #NUMBER IN FRONT OF REPLIAS CAN BE CHANGED (MAX. 128)
# AUG_BATCH = BATCH_SIZE//2
# STEPS_PER_EXECUTION = 16
# # LR = 5e-5*REPLICAS #SHOULDN'T BE CHANGED
# LR = 1e-3
# EPOCHS = 20
# PATIENCE = 5

# PREPROCESS_FUNC = tf.keras.applications.xception.preprocess_input
# TRAIN_FILENAMES, VAL_FILENAMES = train_test_split(FILENAMES, test_size=0.3, random_state=SEED)
# train_dataset = get_dataset(TRAIN_FILENAMES, 
#                             labeled=True, 
#                             ordered=False, 
#                             repeated=True, 
#                             augment=True,
#                             drop_remainder=False)
# val_dataset = get_dataset(VAL_FILENAMES,
#                          labeled=True, 
#                          ordered=True, 
#                          repeated=True, 
#                          augment=False, 
#                          drop_remainder=False)

# STEPS_PER_EPOCH = count_data_items(TRAIN_FILENAMES)//BATCH_SIZE
# VAL_STEPS = count_data_items(VAL_FILENAMES)//BATCH_SIZE

In [18]:
# if DEBUG: 
#     row = 6
#     col = 4
#     row = min(row, AUG_BATCH//col)
#     all_elements = get_dataset(TRAIN_FILENAMES, 
#                                labeled=True, 
#                                ordered=False,
#                                repeated=True,
#                                augment=False).unbatch()
#     aug_elements = all_elements.repeat().batch(AUG_BATCH).map(transform)

#     for (img, label) in aug_elements:
#         plt.figure(figsize=(15, int(15*row/col)))
#         for j in range(row*col):
#             plt.subplot(row, col, j+1)
#             plt.axis('off')
#             plt.imshow(img[j, ])
#         plt.show()
#         break

In [19]:
# loss = CategoricalCrossentropy(label_smoothing=0.2)
# # loss = SparseCategoricalCrossentropy(from_logits=False)
# optimizer = Adam(learning_rate=LR)


# clear_session()
# model_name = "Xception" #CHOOSE DESIRED MODEL
# with TPU_STRATEGY.scope():
#     model = build_model(model_name)
#     model.compile(
#         optimizer=optimizer,
#         loss=loss,
#         metrics=["accuracy"],
# #         steps_per_execution=STEPS_PER_EXECUTION
#     )
#     model.summary()

In [20]:
# history = model.fit(
#     train_dataset,
#     validation_data=val_dataset,
#     class_weight=LOSS_WEIGHTS,
#     epochs=EPOCHS,
#     steps_per_epoch=STEPS_PER_EPOCH,
#     validation_steps=VAL_STEPS,
#     callbacks=[
# #         CosineAnnealingScheduler(T_max=3, eta_min=1e-4, eta_max=1e-3, verbose=1),
#         ReduceLROnPlateau(monitor='val_loss', patience=1, verbose=1, factor=0.3, min_delta=0.001),
# #         lr_scheduler,
#         EarlyStopping(monitor='val_loss', patience=PATIENCE, verbose=1, min_delta=0.0001, restore_best_weights=True),
#         ModelCheckpoint(filepath=Path(OUTPUT_DIR, model_name+"_512.h5"), monitor='val_loss', save_best_only=True)
#     ]
# )

In [21]:
# fig, ax = plt.subplots(1, 2, figsize=(16, 5))

# ax[0].plot(history.history['accuracy'])
# ax[0].plot(history.history['val_accuracy'])
# ax[0].set(xlabel="epoch", ylabel="accuracy", title="Model accuracy")
# ax[0].legend(['train', 'test'])

# ax[1].plot(history.history['loss'])
# ax[1].plot(history.history['val_loss'])
# ax[1].set(xlabel="epoch", ylabel="loss", title="Model loss")
# ax[1].legend(['train', 'test'], loc='upper left')

In [22]:
# test_dataset = get_dataset(TEST_FILENAMES,
#                            labeled=True,
#                            ordered=True,
#                            repeated=False,
#                            augment=False,
#                            drop_remainder=False)
# test_images = test_dataset.map(lambda image, label: image)

# STEPS = count_data_items(TEST_FILENAMES)//BATCH_SIZE
# preds = model.predict(
#     test_images,
#     steps=STEPS,
#     verbose=1
# )
# preds = np.argmax(preds, axis=1)
# labels = np.argmax([target.numpy() for img, target in iter(test_dataset.unbatch())], axis=1)

In [23]:
# from sklearn.metrics import accuracy_score

# accuracy_score(labels, preds)

In [24]:
# from toolbox import plot_confusion_matrix

# plot_confusion_matrix(preds, labels, normalize=True)

### K-folds

In [25]:
from sklearn.model_selection import KFold

def kfold_training(model_name, filenames, K=3):
    """
        Train one model per fold of a K-fold split over TFRecord files.

        For every fold the files are split into K-1 training parts and one validation
        part. The checkpoint with the lowest val_loss is written to
        OUTPUT_DIR/<model_name>_512_fold<i>.h5, with i starting at 1.

        Reads the globals BATCH_SIZE, LR, EPOCHS and PATIENCE, which must be set
        before the call (they are defined in a later cell).

        Args:
            model_name: tf.keras.applications backbone name passed to `build_model`.
            filenames: TFRecord file paths to split into folds.
            K: number of folds.

        Returns:
            The model trained on the last fold. Earlier folds are only available
            through their saved checkpoints.
    """
    kf = KFold(n_splits=K)
    filenames = np.array(filenames)
    model = None
    for fold, (train_index, val_index) in enumerate(kf.split(filenames), start=1):
        print("Fold #", fold)
        # Drop the previous fold's model and Keras global state so folds do not
        # accumulate graphs / memory.
        clear_session()
        savepath = Path(OUTPUT_DIR, "{0:s}_512_fold{1:d}.h5".format(model_name, fold))

        train_files = filenames[train_index]
        val_files = filenames[val_index]
        train_dataset = get_dataset(train_files,
                                    labeled=True,
                                    ordered=False,
                                    repeated=True,
                                    augment=True,
                                    drop_remainder=False)
        val_dataset = get_dataset(val_files,
                                  labeled=True,
                                  ordered=True,
                                  repeated=True,
                                  augment=False,
                                  drop_remainder=False)
        # Both datasets repeat forever, so fit() needs explicit step counts.
        steps_per_epoch = count_data_items(train_files)//BATCH_SIZE
        val_steps = count_data_items(val_files)//BATCH_SIZE

        # Model, optimizer and loss are all created inside the strategy scope so every
        # variable (including Adam's slot variables) is placed on the TPU replicas.
        with TPU_STRATEGY.scope():
            model = build_model(
                model_name=model_name,
                num_classes=N_CLASSES
            )
            model.compile(
                optimizer=Adam(learning_rate=LR),
                loss=CategoricalCrossentropy(label_smoothing=0.2),
                metrics=["accuracy"],
            )

        model.fit(
            train_dataset,
            validation_data=val_dataset,
            steps_per_epoch=steps_per_epoch,
            validation_steps=val_steps,
            class_weight=LOSS_WEIGHTS,
            epochs=EPOCHS,
            callbacks=[
                ReduceLROnPlateau(monitor='val_loss', patience=1, verbose=1, factor=0.3, min_delta=0.001),
                EarlyStopping(monitor='val_loss', patience=PATIENCE, verbose=1, min_delta=0.0001, restore_best_weights=True),
                ModelCheckpoint(filepath=savepath, monitor='val_loss', save_best_only=True)
            ]
        )

    return model

In [26]:
BATCH_SIZE = 64*REPLICAS #NUMBER IN FRONT OF REPLICAS CAN BE CHANGED (MAX. 128); global batch across all replicas
AUG_BATCH = BATCH_SIZE//2 # group size passed to `transform`; CutMix/MixUp partners are drawn within a group
STEPS_PER_EXECUTION = 16 # only referenced by the commented-out `steps_per_execution` compile argument
# LR = 5e-5*REPLICAS #SHOULDN'T BE CHANGED
LR = 1e-3 # initial Adam learning rate; ReduceLROnPlateau lowers it during training
EPOCHS = 15
PATIENCE = 5 # EarlyStopping patience, in epochs without val_loss improvement

# Backbone-specific preprocessing applied in decode_image; must match the model name below
PREPROCESS_FUNC = tf.keras.applications.xception.preprocess_input
model = kfold_training("Xception", FILENAMES, K=3)

Fold # 1
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
Epoch 1/15
Epoch 2/15
Epoch 3/15

Epoch 00003: ReduceLROnPlateau reducing learning rate to 0.0003000000142492354.
Epoch 4/15
Epoch 5/15
Epoch 6/15

Epoch 00006: ReduceLROnPlateau reducing learning rate to 9.000000427477062e-05.
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15
Epoch 11/15
Epoch 12/15

Epoch 00012: ReduceLROnPlateau reducing learning rate to 2.700000040931627e-05.
Epoch 13/15
Epoch 14/15
Epoch 15/15
Fold # 2
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15

Epoch 00004: ReduceLROnPlateau reducing learning rate to 0.0003000000142492354.
Epoch 5/15
Epoch 6/15
Epoch 7/15
Epoch 8/15
Epoch 9/15
Epoch 10/15

Epoch 00010: ReduceLROnPlateau reducing learning rate to 9.000000427477062e-05.
Epoch 11/15
Epoch 12/15

Epoch 00012: ReduceLROnPlateau reducing learning rate to 2.700000040931627e-05.
Epoch 13/15
Epoch 14/15
Epoch 15/15

Epoch 00