In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import math, re, os, gc
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from kaggle_datasets import KaggleDatasets
from tensorflow import keras
from functools import partial
from sklearn.model_selection import train_test_split
import tensorflow_addons as tfa

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Fetch this notebook's Google Cloud credential from Kaggle's secrets client and
# register it with TensorFlow (presumably so tf.io can read the GCS-hosted data — confirm).
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

In [None]:
# Detect and initialise a TPU; fall back to the default (CPU/GPU) strategy otherwise.
# `tpu` is always bound afterwards (None when no TPU is usable) because the
# cross-validation cells re-initialise the TPU system between folds.
tpu = None
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as exc:
    # Best-effort detection: any failure means "no usable TPU". Catch Exception (not a
    # bare except) so KeyboardInterrupt/SystemExit still propagate, and report why.
    print('TPU not available, using default strategy:', exc)
    tpu = None
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
# Initialize Variables

# --- tf.data and storage ---
AUTOTUNE = tf.data.experimental.AUTOTUNE
GCS_PATH = KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification')

# --- batching: 16 images per replica ---
BATCH_SIZE = strategy.num_replicas_in_sync * 16
AUG_BATCH = BATCH_SIZE  # batch size that cutmix() operates on

# --- images and labels ---
IMAGE_SIZE = [512] * 2
CLASSES = [str(class_id) for class_id in range(5)]

# --- cross-validation and training ---
FOLDS, SEED, EPOCHS = 5, 1, 25

In [None]:
# Augmentation Definitions
# https://www.kaggle.com/cdeotte/rotation-augmentation-gpu-tpu-0-96
# https://www.kaggle.com/cdeotte/cutmix-and-mixup-on-gpu-tpu

def get_mat(rotation, shear, height_zoom, width_zoom, height_shift, width_shift):
    """Return the 3x3 matrix (rotation @ shear @ zoom @ shift) used to remap pixel indices.

    Args:
        rotation: rotation angle in degrees.
        shear: shear angle in degrees.
        height_zoom, width_zoom: zoom factors; the matrix holds their reciprocals.
        height_shift, width_shift: translations in pixels.
        All arguments are expected to be float32 tensors of shape [1] (as built by
        `transform`), since they are concatenated with shape-[1] constants.

    Returns:
        A [3, 3] float32 tensor.
    """
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    def as_matrix(*entries):
        # Nine shape-[1] tensors in row-major order -> one 3x3 matrix.
        return tf.reshape(tf.concat(list(entries), axis=0), [3, 3])

    # Degrees -> radians (same arithmetic order as the reference implementation).
    rot_rad = math.pi * rotation / 180.
    shear_rad = math.pi * shear / 180.

    cos_r, sin_r = tf.math.cos(rot_rad), tf.math.sin(rot_rad)
    rotation_matrix = as_matrix(cos_r, sin_r, zero,
                                -sin_r, cos_r, zero,
                                zero, zero, one)

    cos_s, sin_s = tf.math.cos(shear_rad), tf.math.sin(shear_rad)
    shear_matrix = as_matrix(one, sin_s, zero,
                             zero, cos_s, zero,
                             zero, zero, one)

    zoom_matrix = as_matrix(one / height_zoom, zero, zero,
                            zero, one / width_zoom, zero,
                            zero, zero, one)

    shift_matrix = as_matrix(one, zero, height_shift,
                             zero, one, width_shift,
                             zero, zero, one)

    # Keep the (R @ S) @ (Z @ T) grouping so floating-point results are unchanged.
    left = K.dot(rotation_matrix, shear_matrix)
    right = K.dot(zoom_matrix, shift_matrix)
    return K.dot(left, right)

def transform(image,label):
    """Randomly rotate, shear, zoom and shift one image by remapping pixel indices.

    Args:
        image: a single image tensor (not a batch); the final reshape requires it to
            be IMAGE_SIZE[0] x IMAGE_SIZE[0] x 3.
        label: passed through unchanged, so this can be used directly in Dataset.map.

    Returns:
        (transformed image of shape [DIM, DIM, 3], label)
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated, sheared, zoomed, and shifted
    DIM = IMAGE_SIZE[0]
    XDIM = DIM%2 #fix for size 331
    
    # Random parameters: rotation ~ N(0, 15) degrees, shear ~ N(0, 5) degrees,
    # zoom ~ 1 + N(0, 0.1), shifts ~ N(0, 16) pixels. Each is a shape-[1] tensor.
    rot = 15. * tf.random.normal([1],dtype='float32')
    shr = 5. * tf.random.normal([1],dtype='float32') 
    h_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    w_zoom = 1.0 + tf.random.normal([1],dtype='float32')/10.
    h_shift = 16. * tf.random.normal([1],dtype='float32') 
    w_shift = 16. * tf.random.normal([1],dtype='float32') 
  
    # GET TRANSFORMATION MATRIX
    m = get_mat(rot,shr,h_zoom,w_zoom,h_shift,w_shift) 

    # LIST DESTINATION PIXEL INDICES
    # Centred coordinates for every output pixel: x (row) counts down from DIM//2,
    # y (column) counts up from -DIM//2; z = 1 makes them homogeneous coordinates.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    # Cast to int32 truncates to integer coordinates; clipping keeps them inside the
    # image, so out-of-range lookups reuse border pixels instead of failing.
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES           
    # Convert centred coordinates back to (row, col) array indices and gather.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image,tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3]),label

def cutmix(image, label, PROBABILITY = 1.0):
    """Apply CutMix to a batch: paste a random square from another image in the batch.

    Args:
        image: a batch of AUG_BATCH images, each IMAGE_SIZE[0] x IMAGE_SIZE[0] x 3
            (not a single image).
        label: either integer class ids of shape [AUG_BATCH] (one-hot encoded here)
            or per-class label vectors of shape [AUG_BATCH, len(CLASSES)].
        PROBABILITY: chance that each image receives a pasted patch.

    Returns:
        (images [AUG_BATCH, DIM, DIM, 3], labels [AUG_BATCH, len(CLASSES)]), where each
        label is mixed in proportion to the pixel area actually pasted.
    """
    DIM = IMAGE_SIZE[0]
    
    imgs = []; labs = []
    for j in range(AUG_BATCH):
        # DO CUTMIX WITH PROBABILITY DEFINED ABOVE
        P = tf.cast( tf.random.uniform([],0,1)<=PROBABILITY, tf.int32)
        # CHOOSE RANDOM IMAGE TO CUTMIX WITH
        k = tf.cast( tf.random.uniform([],0,AUG_BATCH),tf.int32)
        # CHOOSE RANDOM LOCATION
        x = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        y = tf.cast( tf.random.uniform([],0,DIM),tf.int32)
        b = tf.random.uniform([],0,1) # this is beta dist with alpha=1.0
        WIDTH = tf.cast( DIM * tf.math.sqrt(1-b),tf.int32) * P
        # Box corners, clipped to the image borders.
        ya = tf.math.maximum(0,y-WIDTH//2)
        yb = tf.math.minimum(DIM,y+WIDTH//2)
        xa = tf.math.maximum(0,x-WIDTH//2)
        xb = tf.math.minimum(DIM,x+WIDTH//2)
        # MAKE CUTMIX IMAGE
        one = image[j,ya:yb,0:xa,:]
        two = image[k,ya:yb,xa:xb,:]
        three = image[j,ya:yb,xb:DIM,:]
        middle = tf.concat([one,two,three],axis=1)
        img = tf.concat([image[j,0:ya,:,:],middle,image[j,yb:DIM,:,:]],axis=0)
        imgs.append(img)
        # MAKE CUTMIX LABEL
        # Weight by the area actually pasted. The box was clipped at the borders
        # (and halved with //2), so WIDTH*WIDTH would overstate it near the edges.
        a = tf.cast((yb-ya)*(xb-xa)/DIM/DIM,tf.float32)
        if len(label.shape)==1:
            lab1 = tf.one_hot(label[j],len(CLASSES))
            lab2 = tf.one_hot(label[k],len(CLASSES))
        else:
            lab1 = label[j,]
            lab2 = label[k,]
        labs.append((1-a)*lab1 + a*lab2)
            
    # RESHAPE HACK SO TPU COMPILER KNOWS SHAPE OF OUTPUT TENSOR (maybe use Python typing instead?)
    image2 = tf.reshape(tf.stack(imgs),(AUG_BATCH,DIM,DIM,3))        
    label2 = tf.reshape(tf.stack(labs),(AUG_BATCH,len(CLASSES)))
    return image2,label2

In [None]:
# https://towardsai.net/p/machine-learning/building-complex-image-augmentation-pipelines-with-tensorflow-bed1914278d2

def data_augment(image, label):
    """Randomly flip, rotate (by multiples of 90 degrees), colour-jitter and crop one image.

    Intended for use inside `Dataset.map`. There the Python `if` statements on
    tensors below are converted to graph conditionals by AutoGraph. The pipeline
    runs on the host CPU while the accelerator computes gradients, so it is
    overlapped by the prefetch at the end of the input pipeline.

    Args:
        image: a single [H, W, 3] float image; decode_image scales pixels to [0, 1].
        label: passed through unchanged.

    Returns:
        (image resized to IMAGE_SIZE, label)
    """
    DIM = IMAGE_SIZE[0]

    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_1 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_2 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel_3 = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    if p_spatial > .75:
        image = tf.image.transpose(image)

    # Rotations by multiples of 90 degrees
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90º

    # Pixel-level transforms
    if p_pixel_1 >= .4:
        image = tf.image.random_saturation(image, lower=.7, upper=1.3)
    if p_pixel_2 >= .4:
        image = tf.image.random_contrast(image, lower=.8, upper=1.2)
    if p_pixel_3 >= .4:
        image = tf.image.random_brightness(image, max_delta=.1)
    # Contrast/brightness adjustments do not clip float images; keep pixels in the
    # same [0, 1] range that decode_image produces (and validation images use).
    image = tf.clip_by_value(image, 0.0, 1.0)

    # Crops
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.6)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.7)
        else:
            image = tf.image.central_crop(image, central_fraction=.8)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(DIM*.6), DIM, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, 3])

    # Crops change the spatial size; restore the configured model input size.
    image = tf.image.resize(image, size=IMAGE_SIZE)
    return image, label

In [None]:
# Data Processing Functions

def decode_image(image):
    """Decode JPEG bytes into a float32 [IMAGE_SIZE[0], IMAGE_SIZE[1], 3] tensor in [0, 1].

    The final reshape fixes the tensor's static shape; it fails if the JPEG does not
    have exactly IMAGE_SIZE pixels.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    pixels = pixels / 255.0
    return tf.reshape(pixels, [*IMAGE_SIZE, 3])

# Test image does not have a label so if condition is required
def read_tfrecord(example, labeled):
    """Parse one serialized TFRecord example into an image and its label or name.

    Test records carry no "target" field, so the feature spec depends on `labeled`.

    Args:
        example: a scalar string tensor holding one serialized example.
        labeled: True for training/validation records that include "target".

    Returns:
        (image, one-hot label of depth len(CLASSES)) when labeled,
        otherwise (image, image_name).
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string),
    }
    if labeled:
        feature_spec["target"] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if labeled:
        label = tf.one_hot(tf.cast(parsed['target'], tf.int32), depth=len(CLASSES))
        return image, label
    return image, parsed['image_name']

def load_dataset(filenames, labeled=True, ordered=False):
    """Stream TFRecord files and parse them with read_tfrecord.

    Args:
        filenames: TFRecord path(s); reads are interleaved across files.
        labeled: forwarded to read_tfrecord.
        ordered: when False, records may arrive out of order
            (deterministic=False), trading reproducibility for throughput.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    parse_fn = partial(read_tfrecord, labeled=labeled)
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
    return records.with_options(options).map(parse_fn, num_parallel_calls=AUTOTUNE)

def get_training_dataset(dataset, do_init_aug=True, do_aug=True, shuffleBuffer=2048):
    """Turn a parsed (image, one-hot label) dataset into an infinite augmented batch stream.

    Args:
        dataset: output of load_dataset(..., labeled=True).
        do_init_aug: apply per-image data_augment followed by transform.
        do_aug: apply batch-level cutmix.
        shuffleBuffer: number of elements held in the shuffle buffer. Each element is
            a decoded float32 image (about 3 MB at 512x512x3), so large values need
            a lot of host memory.

    Returns:
        A prefetched dataset of BATCH_SIZE batches that repeats forever; callers must
        pass steps_per_epoch to fit().
    """
    if do_init_aug:
        dataset = dataset.map(data_augment, num_parallel_calls=AUTOTUNE)
        dataset = dataset.map(transform, num_parallel_calls=AUTOTUNE)

    dataset = dataset.repeat()
    if do_aug:
        # cutmix indexes exactly AUG_BATCH elements, so every batch must be full.
        dataset = dataset.batch(AUG_BATCH, drop_remainder=True)
        dataset = dataset.map(cutmix, num_parallel_calls=AUTOTUNE)
        dataset = dataset.unbatch()
    dataset = dataset.shuffle(shuffleBuffer)
    # The stream is infinite, so dropping the remainder never discards data; it only
    # gives batches a static leading dimension, which TPU training expects.
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

def get_validation_dataset(dataset, ordered=False, do_aug=False):
    """Optionally augment, then batch, cache and prefetch a parsed validation dataset.

    `ordered` is accepted for call compatibility but unused; record order is fixed
    when the dataset is created by load_dataset.

    Augmentation (when do_aug is True) runs before cache(). The random
    augmentations drawn on the first pass are therefore replayed on every later pass.
    """
    pipeline = dataset
    if do_aug:
        for augment_fn in (data_augment, transform):
            pipeline = pipeline.map(augment_fn, num_parallel_calls=AUTOTUNE)
    return pipeline.batch(BATCH_SIZE).cache().prefetch(AUTOTUNE)

def get_test_dataset(ordered=False, filenames=None):
    """Build the batched, unlabeled test dataset of (images, image_names).

    Args:
        ordered: forwarded to load_dataset; True keeps records in file order.
        filenames: TFRecord paths to read. Defaults to the global TEST_FILENAMES.
            This notebook currently leaves that global commented out, so pass the
            paths explicitly until it is defined.

    Raises:
        NameError: if `filenames` is omitted and TEST_FILENAMES is not defined.
    """
    if filenames is None:
        try:
            filenames = TEST_FILENAMES
        except NameError:
            raise NameError(
                "TEST_FILENAMES is not defined; pass `filenames` explicitly or define it, "
                "e.g. tf.io.gfile.glob(GCS_PATH + '/test_tfrecords/ld_test*.tfrec')"
            ) from None
    dataset = load_dataset(filenames, labeled=False, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTOTUNE)
    return dataset

def count_data_items(filenames):
    """Return the total number of records across TFRecord files, read from their names.

    Each file name encodes its record count as ``-<count>.`` (e.g.
    ``ld_train00-1338.tfrec`` holds 1338 records). Only the base name is searched,
    so hyphenated numbers in bucket or directory names are never picked up.

    Args:
        filenames: iterable of file paths (local or gs://).

    Returns:
        The summed record count as an int (0 for an empty list).

    Raises:
        ValueError: if a file name does not contain a record count.
    """
    count_pattern = re.compile(r"-(\d+)\.")
    total = 0
    for filename in filenames:
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError("Cannot read record count from file name: {!r}".format(filename))
        total += int(match.group(1))
    return total

def countItems(dataset):
    """Count the elements of any iterable (e.g. a tf.data.Dataset) by iterating it once."""
    return sum(1 for _ in dataset)

In [None]:
# Glob pattern matching every training TFRecord in the competition's GCS bucket.
MY_DATA_PATH = "{}/train_tfrecords/*.tfrec".format(
    KaggleDatasets().get_gcs_path('cassava-leaf-disease-classification'))

In [None]:
# Step 1 - Read filenames for tfrecords

# Local copies of the training TFRecords.
# NOTE(review): train_tfrecords_names is not used by any later cell of this notebook.
train_tfrecords_names = []
for dirname, _, filenames in os.walk('/kaggle/input/cassava-leaf-disease-classification/train_tfrecords/'):
    for filename in filenames:
        train_tfrecords_names.append(os.path.join(dirname, filename))

# GCS copies used for training. KFold later splits this list by position, and glob
# ordering is not guaranteed, so sort it to make the folds reproducible for a fixed SEED.
TRAINING_FILENAMES = sorted(tf.io.gfile.glob(MY_DATA_PATH))

# TEST_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test_tfrecords/ld_test*.tfrec')

In [None]:
# Approximate per-fold sizes: KFold splits whole files, so each fold holds out
# roughly 1/FOLDS of all records. Steps per epoch are derived from the training share.
NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES = (
    int(count_data_items(TRAINING_FILENAMES) * (FOLDS - 1.) / FOLDS),  # + count_data_items(TRAINING_EXTERNAL_FILENAMES)
    int(count_data_items(TRAINING_FILENAMES) * (1. / FOLDS)),
)
# NUM_TEST_IMAGES = count_data_items(TEST_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE

summary_template = 'Dataset: {} training images, {} validation images, unlabeled test images'
print(summary_template.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES))

# **Building the Model**

In [None]:
!pip install -U efficientnet
from keras import applications
import efficientnet.keras as efn
from efficientnet.keras import EfficientNetB3
from keras import callbacks
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import BatchNormalization
from keras.optimizers import Adam

In [None]:
# Install the bundled TensorFlow-ResNets package (imported next as `tf2_resnets`, which
# supplies the ResNeSt101 backbone). os.system returns the shell exit status; a
# non-zero value means the install failed and the following import will error.
os.system('pip install ../input/tf2cvmodels/TensorFlow-ResNets-master/ -q --no-deps')

In [None]:
from tf2_resnets import models

In [None]:
class CategoricalFocalLossLabelSmoothing(tf.keras.losses.Loss):
    """Categorical focal loss computed on label-smoothed predictions.

    Smoothing is applied to y_pred (not y_true): p' = (1 - ls) * p + ls / classes.
    The per-sample loss is sum_c alpha * y_true_c * (1 - p'_c)**gamma * -log(p'_c).

    Args:
        gamma: focusing exponent; larger values down-weight confident predictions.
        alpha: constant weight applied to every class.
        ls: smoothing amount mixed into the predictions.
        classes: number of classes (as a float) for the uniform smoothing term.
        **kwargs: forwarded to tf.keras.losses.Loss (e.g. `name`, `reduction`).
    """
    def __init__(self, gamma=2.0, alpha=0.25, ls=0.1, classes=5.0, **kwargs):
        super(CategoricalFocalLossLabelSmoothing, self).__init__(**kwargs)
        self.gamma = gamma
        self.alpha = alpha
        self.ls = ls
        self.classes = classes
        
    def focal_loss(self, y_true, y_pred, gamma, alpha, ls, classes):
        """Return the per-sample focal loss, summed over the class axis (axis=1)."""
        # Define epsilon so that the backpropagation will not result in NaN
        # for 0 divisor case
        epsilon = K.epsilon()
        # Label smoothing on the predicted distribution
        y_pred_ls = (1 - ls) * y_pred + ls / classes
        # Clip so log() and the modulating factor stay finite
        y_pred_ls = K.clip(y_pred_ls, epsilon, 1.0-epsilon)
        # Calculate cross entropy
        cross_entropy = -y_true*K.log(y_pred_ls)
        # Calculate weight that consists of modulating factor and weighting factor
        weight = alpha * y_true * K.pow((1-y_pred_ls), gamma)
        # Calculate focal loss
        loss = weight * cross_entropy
        # Sum over classes for each sample in the mini-batch
        loss = K.sum(loss, axis=1)
        return loss
        
    def call(self, y_true, y_pred):
        return self.focal_loss(y_true, y_pred, gamma=self.gamma, alpha=self.alpha, ls=self.ls, classes=self.classes)

    def get_config(self):
        """Include the hyper-parameters so a saved model can rebuild this loss exactly."""
        config = super(CategoricalFocalLossLabelSmoothing, self).get_config()
        config.update({
            'gamma': self.gamma,
            'alpha': self.alpha,
            'ls': self.ls,
            'classes': self.classes,
        })
        return config

class TaylorCrossEntropyLoss(tf.keras.losses.Loss):
    """Cross entropy with -log(p) replaced by its order-n Taylor expansion around p = 1.

    -log(p) ~= sum_{k=1..n} (1 - p)**k / k, which is bounded for small p and therefore
    less sensitive to badly mislabeled samples than the exact log.

    Args:
        n: order of the Taylor expansion.
        label_smoothing: amount of target smoothing in [0, 1].
        **kwargs: forwarded to tf.keras.losses.Loss (e.g. `name`, `reduction`).
    """
    def __init__(self, n=3, label_smoothing=0.0, **kwargs):
        super(TaylorCrossEntropyLoss, self).__init__(**kwargs)
        self.n = n
        self.label_smoothing = label_smoothing
        
    def taylor_cross_entropy_loss(self, y_pred, y_true, n=3, label_smoothing=0.0):
        """Taylor Cross Entropy Loss.
        Args:
        y_pred: A multi-dimensional probability tensor with last dimension `num_classes`.
            NOTE(review): the [n, 1, 1] denominator below assumes 2-D (batch, classes) input.
        y_true: A tensor with shape and dtype as y_pred.
        n: An order of taylor expansion.
        label_smoothing: A float in [0, 1] for label smoothing. For one-hot targets
            the true class becomes 1 - label_smoothing and every other class
            label_smoothing / (num_classes - 1).
        Returns:
        A per-sample loss tensor.
        """
        y_pred = tf.cast(y_pred, tf.float32)
        y_true = tf.cast(y_true, tf.float32)

        if label_smoothing > 0.0:
            num_classes = tf.cast(tf.shape(y_true)[-1], tf.float32)
            y_true = (1 - num_classes /(num_classes - 1) * label_smoothing) * y_true + label_smoothing / (num_classes - 1)

        # (1 - p)**k for k = 1..n via a cumulative product over n stacked copies.
        y_pred_n_order = tf.math.maximum(tf.stack([1 - y_pred] * n), 1e-7) # avoid being too small value
        numerator = tf.math.maximum(tf.math.cumprod(y_pred_n_order, axis=0), 1e-7) # avoid being too small value
        # Divisors 1..n shaped [n, 1, 1] to broadcast over (batch, classes).
        denominator = tf.expand_dims(tf.expand_dims(tf.range(1, n+1, dtype="float32"), axis=1), axis=1)
        # Truncated Taylor series approximating -log(p) per class.
        y_pred_taylor = tf.math.maximum(tf.math.reduce_sum(tf.math.divide(numerator, denominator), axis=0), 1e-7) # avoid being too small value
        loss_values = tf.math.reduce_sum(y_true * y_pred_taylor, axis=1, keepdims=True)
        return tf.math.reduce_sum(loss_values, -1)

    def call(self, y_true, y_pred):
        return self.taylor_cross_entropy_loss(y_pred, y_true, n=self.n, label_smoothing=self.label_smoothing)

    def get_config(self):
        """Include the hyper-parameters so a saved model can rebuild this loss exactly."""
        config = super(TaylorCrossEntropyLoss, self).get_config()
        config.update({'n': self.n, 'label_smoothing': self.label_smoothing})
        return config

In [None]:
# Build Model

def get_model():
    """Build and compile an EfficientNetB3 (noisy-student weights) classifier.

    Created inside `strategy.scope()` so variables are placed for the active
    distribution strategy. Head: global average pooling -> dropout(0.5) ->
    softmax over len(CLASSES). Compiled with Adam(1e-3), the label-smoothed
    focal loss and categorical accuracy.

    Returns:
        The compiled model.
    """
    with strategy.scope():
        efficient_net = EfficientNetB3(
            weights='noisy-student',
            input_shape=(*IMAGE_SIZE, 3),
            include_top=False,
        )
        model = Sequential()
        model.add(efficient_net)
        model.add(tf.keras.layers.GlobalAveragePooling2D())
        model.add(Dropout(0.5))
        model.add(Dense(units = len(CLASSES), activation='softmax'))
        
        model.summary()
        
        model.compile(
            # `learning_rate` is the documented argument; `lr` is a deprecated alias.
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
            loss=CategoricalFocalLossLabelSmoothing(gamma=2.0, alpha=0.25, ls=0.3,
                                                    classes=float(len(CLASSES))),
            metrics=['categorical_accuracy'])
    
    return model

In [None]:
# Build Model

def get_model2():
    """Build and compile a ResNeSt101 (ImageNet weights) classifier.

    Created inside `strategy.scope()` so variables are placed for the active
    distribution strategy. Head: global average pooling -> dropout(0.5) ->
    softmax over len(CLASSES). Compiled with Adam(1e-3), categorical
    cross-entropy and categorical accuracy.

    Returns:
        The compiled model.
    """
    with strategy.scope():
        resnet = models.ResNeSt101(
            weights='imagenet',
            input_shape=(*IMAGE_SIZE, 3),
            include_top=False,
        )
        model = Sequential()
        model.add(resnet)
        model.add(tf.keras.layers.GlobalAveragePooling2D())
        model.add(Dropout(0.5))
        model.add(Dense(units = len(CLASSES), activation='softmax'))
        
        model.summary()

        model.compile(
            # `learning_rate` is the documented argument; `lr` is a deprecated alias.
            optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
            loss='categorical_crossentropy',
            metrics=['categorical_accuracy'])
    
    return model

In [None]:
# Multiply the learning rate by 0.3 whenever validation accuracy fails to improve
# by at least 0.001 for 2 consecutive epochs.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_categorical_accuracy',
    mode='auto',
    factor=0.3,
    patience=2,
    min_delta=0.001,
    verbose=1,
)

In [None]:
from sklearn.model_selection import KFold

# Epochs for this EfficientNet run (note: differs from the global EPOCHS).
EFFNET_EPOCHS = 20

histories = []
modelslist = []
kfold = KFold(FOLDS, shuffle = True, random_state = SEED)
FILENAMES = pd.DataFrame({'TRAINING_FILENAMES': TRAINING_FILENAMES})

for fold, (trn_ind, val_ind) in enumerate(kfold.split(TRAINING_FILENAMES)):
    
    # Keep only the best epoch of this fold, judged by validation accuracy.
    model_save = tf.keras.callbacks.ModelCheckpoint('./model%i.h5'%fold,
                             save_best_only = True, 
                             monitor = 'val_categorical_accuracy', 
                             mode = 'auto', verbose = 1)
    
    print(); print('#'*25)
    print('### FOLD', fold+1)
    print('#'*25)
    
    TRAIN_NAMES = list(FILENAMES.loc[trn_ind]['TRAINING_FILENAMES'])
    VAL_NAMES = list(FILENAMES.loc[val_ind]['TRAINING_FILENAMES'])
    train_dataset = load_dataset(TRAIN_NAMES, labeled = True)
    val_dataset = load_dataset(VAL_NAMES, labeled = True, ordered = True)
    
    model = get_model()
    # NOTE(review): shuffleBuffer=NUM_TRAINING_IMAGES buffers that many decoded float32
    # images (~3 MB each at 512x512x3); confirm the input host has that much memory.
    history = model.fit(
        get_training_dataset(train_dataset, do_init_aug=True, do_aug=False, shuffleBuffer=NUM_TRAINING_IMAGES),
        steps_per_epoch = STEPS_PER_EPOCH,
        epochs = EFFNET_EPOCHS,
        callbacks=[model_save, reduce_lr],
        validation_data = get_validation_dataset(val_dataset),
        verbose=2
    )
    # NOTE(review): modelslist keeps a reference to every fold's model, so the
    # `del model` below does not actually release it.
    modelslist.append(model)
    histories.append(history)

    # Learning curves: accuracy on the left y-axis, loss on a twin right y-axis.
    hist = history.history
    epoch_idx = np.arange(len(hist['categorical_accuracy']))
    fig, acc_ax = plt.subplots(figsize=(15,5))
    acc_ax.plot(epoch_idx, hist['categorical_accuracy'], '-o', label='Train categorical_accuracy', color='#ff7f0e')
    acc_ax.plot(epoch_idx, hist['val_categorical_accuracy'], '-o', label='Val categorical_accuracy', color='#1f77b4')
    best_epoch = np.argmax(hist['val_categorical_accuracy']); best_acc = np.max(hist['val_categorical_accuracy'])
    xdist = acc_ax.get_xlim()[1] - acc_ax.get_xlim()[0]; ydist = acc_ax.get_ylim()[1] - acc_ax.get_ylim()[0]
    acc_ax.scatter(best_epoch, best_acc, s=200, color='#1f77b4')
    acc_ax.text(best_epoch-0.03*xdist, best_acc-0.13*ydist, 'max categorical_accuracy\n%.2f'%best_acc, size=14)
    acc_ax.set_ylabel('categorical_accuracy', size=14); acc_ax.set_xlabel('Epoch', size=14)
    acc_ax.legend(loc=2)
    loss_ax = acc_ax.twinx()
    loss_ax.plot(epoch_idx, hist['loss'], '-o', label='Train Loss', color='#2ca02c')
    loss_ax.plot(epoch_idx, hist['val_loss'], '-o', label='Val Loss', color='#d62728')
    best_loss_epoch = np.argmin(hist['val_loss']); best_loss = np.min(hist['val_loss'])
    ydist = loss_ax.get_ylim()[1] - loss_ax.get_ylim()[0]
    loss_ax.scatter(best_loss_epoch, best_loss, s=200, color='#d62728')
    loss_ax.text(best_loss_epoch-0.03*xdist, best_loss+0.05*ydist, 'min loss', size=14)
    loss_ax.set_ylabel('Loss', size=14)
    # Report the model and batch size actually used (get_model builds EfficientNetB3).
    acc_ax.set_title('Fold: %i | Image Size: %i | model: EfficientNetB%i |  Batch_size: %i'%(fold+1, IMAGE_SIZE[0], 3, BATCH_SIZE))
    loss_ax.legend(loc=3)
    plt.show()
    plt.close(fig)

    del model; gc.collect()
    # Reset the TPU between folds; skipped when no TPU was detected (`tpu` unset or None).
    if globals().get('tpu') is not None:
        tf.tpu.experimental.initialize_tpu_system(tpu)

In [None]:
from sklearn.model_selection import KFold

# `lr_callback` is not defined anywhere in this notebook (the original cell raised
# NameError). Use it if a later edit defines it; otherwise fall back to the
# ReduceLROnPlateau callback `reduce_lr` defined above.
lr_schedule = globals().get('lr_callback', reduce_lr)

histories = []
modelslist = []
kfold = KFold(FOLDS, shuffle = True, random_state = SEED)
FILENAMES = pd.DataFrame({'TRAINING_FILENAMES': TRAINING_FILENAMES})

for fold, (trn_ind, val_ind) in enumerate(kfold.split(TRAINING_FILENAMES)):
    
    # Keep only the best epoch of this fold, judged by validation accuracy.
    model_save = tf.keras.callbacks.ModelCheckpoint('./resnetmodel%i.h5'%fold,
                             save_best_only = True, 
                             monitor = 'val_categorical_accuracy', 
                             mode = 'auto', verbose = 1)
    
    print(); print('#'*25)
    print('### FOLD', fold+1)
    print('#'*25)
    
    TRAIN_NAMES = list(FILENAMES.loc[trn_ind]['TRAINING_FILENAMES'])
    VAL_NAMES = list(FILENAMES.loc[val_ind]['TRAINING_FILENAMES'])
    train_dataset = load_dataset(TRAIN_NAMES, labeled = True)
    val_dataset = load_dataset(VAL_NAMES, labeled = True, ordered = True)
    
    model = get_model2()
    # NOTE(review): shuffleBuffer=NUM_TRAINING_IMAGES buffers that many decoded float32
    # images (~3 MB each at 512x512x3); confirm the input host has that much memory.
    history = model.fit(
        get_training_dataset(train_dataset, do_init_aug=True, do_aug=True, shuffleBuffer=NUM_TRAINING_IMAGES),
        steps_per_epoch = STEPS_PER_EPOCH,
        epochs = EPOCHS,
        callbacks=[model_save, lr_schedule],
        validation_data = get_validation_dataset(val_dataset),
        verbose=2
    )
    # NOTE(review): modelslist keeps a reference to every fold's model, so the
    # `del model` below does not actually release it.
    modelslist.append(model)
    histories.append(history)

    # Learning curves: accuracy on the left y-axis, loss on a twin right y-axis.
    hist = history.history
    epoch_idx = np.arange(len(hist['categorical_accuracy']))
    fig, acc_ax = plt.subplots(figsize=(15,5))
    acc_ax.plot(epoch_idx, hist['categorical_accuracy'], '-o', label='Train categorical_accuracy', color='#ff7f0e')
    acc_ax.plot(epoch_idx, hist['val_categorical_accuracy'], '-o', label='Val categorical_accuracy', color='#1f77b4')
    best_epoch = np.argmax(hist['val_categorical_accuracy']); best_acc = np.max(hist['val_categorical_accuracy'])
    xdist = acc_ax.get_xlim()[1] - acc_ax.get_xlim()[0]; ydist = acc_ax.get_ylim()[1] - acc_ax.get_ylim()[0]
    acc_ax.scatter(best_epoch, best_acc, s=200, color='#1f77b4')
    acc_ax.text(best_epoch-0.03*xdist, best_acc-0.13*ydist, 'max categorical_accuracy\n%.2f'%best_acc, size=14)
    acc_ax.set_ylabel('categorical_accuracy', size=14); acc_ax.set_xlabel('Epoch', size=14)
    acc_ax.legend(loc=2)
    loss_ax = acc_ax.twinx()
    loss_ax.plot(epoch_idx, hist['loss'], '-o', label='Train Loss', color='#2ca02c')
    loss_ax.plot(epoch_idx, hist['val_loss'], '-o', label='Val Loss', color='#d62728')
    best_loss_epoch = np.argmin(hist['val_loss']); best_loss = np.min(hist['val_loss'])
    ydist = loss_ax.get_ylim()[1] - loss_ax.get_ylim()[0]
    loss_ax.scatter(best_loss_epoch, best_loss, s=200, color='#d62728')
    loss_ax.text(best_loss_epoch-0.03*xdist, best_loss+0.05*ydist, 'min loss', size=14)
    loss_ax.set_ylabel('Loss', size=14)
    # Report the model and batch size actually used (get_model2 builds ResNeSt101).
    acc_ax.set_title('Fold: %i | Image Size: %i | model: %s |  Batch_size: %i'%(fold+1, IMAGE_SIZE[0], 'ResNeSt101', BATCH_SIZE))
    loss_ax.legend(loc=3)
    plt.show()
    plt.close(fig)

    del model; gc.collect()
    # Reset the TPU between folds; skipped when no TPU was detected (`tpu` unset or None).
    if globals().get('tpu') is not None:
        tf.tpu.experimental.initialize_tpu_system(tpu)