## Importing the libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import re
import math

### Connect to TPU

In [None]:
# Use a TPU when one is reachable; otherwise fall back to the default
# (non-distributed) strategy so the rest of the notebook runs unchanged.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    tpu = None
    print("Failed to connect to TPU.")

if not tpu:
    strategy = tf.distribute.get_strategy()
else:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.TPUStrategy(tpu)

print("REPLICAS: ", strategy.num_replicas_in_sync)

## Helper functions for reading and parsing the TFRecord files

In [None]:
IMAGE_SIZE = (512, 512)
AUTO = tf.data.AUTOTUNE
INPUT_PATH = "/kaggle/input/tpu-getting-started/tfrecords-jpeg-512x512"
# CLASSES[i] is the display name of class id i. len(CLASSES) also sets the number
# of output units of every classifier, so it must match the dataset's class count.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen',         'watercress',       'canna lily',            # 80 - 89
           'hippeastrum',      'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103
assert len(CLASSES) == 104, "expected one display name per class id 0-103"

def decode_image(data):
    """Decode one JPEG byte string into an RGB image tensor of shape (*IMAGE_SIZE, 3).

    Pixels stay uint8 in [0, 255]; scaling is left to each model's
    ``preprocess_input`` layer.
    """
    rgb = tf.io.decode_jpeg(data, channels=3)
    # Pin the static shape so downstream batching/layers see a fixed image size.
    return tf.reshape(rgb, [*IMAGE_SIZE, 3])

def read_labelled_tfrecord(record):
    """Parse one serialized labelled example into ``(image, label)``.

    The example carries a JPEG-encoded ``"image"`` bytes feature and an int64
    ``"class"`` feature; the label is returned as int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(record, feature_spec)
    return decode_image(example["image"]), tf.cast(example["class"], tf.int32)

def read_unlabelled_tfrecord(record):
    """Parse one serialized unlabelled (test) example into ``(image, id)``.

    ``id`` is the raw string feature stored under ``"id"``; there is no class label.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(record, feature_spec)
    image = decode_image(example["image"])
    return image, example["id"]

def load_dataset(folder, labeled=True, ordered=False):
    """Build a dataset of parsed examples from every ``*.tfrec`` file in ``INPUT_PATH/<folder>``.

    Args:
        folder: sub-directory of ``INPUT_PATH`` (e.g. ``"train"``, ``"val"``, ``"test"``).
        labeled: parse ``(image, class)`` records when True, ``(image, id)`` records otherwise.
        ordered: when False, let tf.data emit elements out of order for throughput.
            Keep True whenever outputs must line up with a second pass over the data
            (e.g. matching test predictions to ids).

    Returns:
        An unbatched ``tf.data.Dataset`` of ``(image, label)`` or ``(image, id)`` pairs.

    Raises:
        FileNotFoundError: if no ``.tfrec`` file matches (e.g. a wrong ``INPUT_PATH``),
            so the problem is reported here with a clear message.
    """
    pattern = f"{INPUT_PATH}/{folder}/*.tfrec"
    filenames = tf.io.gfile.glob(pattern)
    if not filenames:
        raise FileNotFoundError(f"No TFRecord files match {pattern}")

    options = tf.data.Options()
    if not ordered:
        # `deterministic` superseded the deprecated `experimental_deterministic`;
        # fall back to the old name on TF versions that predate it.
        if hasattr(options, "deterministic"):
            options.deterministic = False
        else:
            options.experimental_deterministic = False

    parse_fn = read_labelled_tfrecord if labeled else read_unlabelled_tfrecord
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.with_options(options)
    return dataset.map(parse_fn, num_parallel_calls=AUTO)


### Loading the datasets

In [None]:
PER_REPLICA_BATCH_SIZE = 16
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync  # global batch across all replicas

def augment_image(image, label):
    """Randomly mirror ``image`` left-right; ``label`` passes through unchanged."""
    flipped = tf.image.random_flip_left_right(image)
    return flipped, label

def load_train_data():
    """Infinite, shuffled, augmented and batched training pipeline.

    Because of ``repeat()`` the dataset never ends, so ``fit`` needs ``steps_per_epoch``.
    """
    return (
        load_dataset("train")
        .map(augment_image, num_parallel_calls=AUTO)
        .shuffle(1024)
        .repeat()
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )
    
def load_val_data():
    """Finite, batched validation pipeline; batches are cached after the first pass."""
    return (
        load_dataset("val")
        .batch(BATCH_SIZE)
        .cache()
        .prefetch(AUTO)
    )

def load_test_data():
    """Batched ``(image, id)`` test pipeline read in a deterministic order."""
    return (
        load_dataset("test", labeled=False, ordered=True)
        .batch(BATCH_SIZE)
        .prefetch(AUTO)
    )

def count_entries(folder):
    """Total number of examples stored in ``INPUT_PATH/<folder>/*.tfrec``.

    Reads the count from each file name rather than iterating the data. This
    assumes the last number before the ``.tfrec`` extension is the file's record
    count (presumably names like ``00-512x512-798.tfrec``; confirm for other
    datasets).

    Raises:
        ValueError: if a file name does not end in ``<digits>.tfrec``.
    """
    paths = tf.io.gfile.glob(f"{INPUT_PATH}/{folder}/*.tfrec")
    total = 0
    for path in paths:
        # Match on the file name only so digits in directory names are never picked up.
        filename = path.rsplit("/", 1)[-1]
        match = re.search(r"(\d+)\.tfrec$", filename)
        if match is None:
            raise ValueError(f"Cannot read an example count from file name {filename!r}")
        total += int(match.group(1))
    return total
    
    
# Build the three input pipelines and count their examples from the shard file names.
train_dataset, TRAIN_ENTRIES = load_train_data(), count_entries("train")
val_dataset, VAL_ENTRIES = load_val_data(), count_entries("val")
test_dataset, TEST_ENTRIES = load_test_data(), count_entries("test")
print("Found", TRAIN_ENTRIES, "train entries,", VAL_ENTRIES, "val_entries,", TEST_ENTRIES, "test_entries.")

### Functions for exploring the images

In [None]:
import math

def tfdata_to_numpy(data):
    """Convert a batch into ``(numpy_images, labels)`` for plotting.

    Args:
        data: either an ``(images, labels)`` tuple or a bare image batch. Each
            element may be an eager TensorFlow tensor (anything with ``.numpy()``)
            or something ``np.asarray`` accepts.

    Returns:
        ``(images, labels)``. ``images`` is a numpy array. ``labels`` is a numpy
        array of class ids, or a list of ``None`` (one per image) when the batch
        has no labels or its "labels" are non-numeric, such as test-set id strings.
    """
    def _as_numpy(value):
        return value.numpy() if hasattr(value, "numpy") else np.asarray(value)

    if isinstance(data, tuple) and len(data) == 2:
        images, labels = data
    else:
        # Bare image batch: nothing to title the images with.
        images, labels = data, None

    numpy_images = _as_numpy(images)
    if labels is None:
        return numpy_images, [None] * len(numpy_images)

    numpy_labels = _as_numpy(labels)
    # String tensors come back as object arrays of bytes; numpy may also give S/U dtypes.
    if numpy_labels.dtype.kind in "OSU":
        numpy_labels = [None] * len(numpy_labels)
    return numpy_images, numpy_labels

# simpler version just to display images without comparing labels
def display_images(batch):
    """Show a batch of images in a square-ish grid, titled with class names when labelled.

    Only ``rows * cols`` images are drawn, where ``rows = floor(sqrt(n))`` and
    ``cols = n // rows``, so a few trailing images may be dropped. An empty
    batch draws nothing.
    """
    images, labels = tfdata_to_numpy(batch)
    if len(images) == 0:
        return  # avoids a zero division when computing the grid
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows
    plt.figure(figsize=(16, 16))
    for i in range(rows * cols):
        plt.subplot(rows, cols, i + 1)
        if labels[i] is not None:
            plt.title(CLASSES[labels[i]], fontsize=10)
        plt.axis("off")
        plt.imshow(images[i])
    # tight_layout only has an effect once the subplots exist.
    plt.tight_layout()
    plt.show()

def title_from_label_and_target(label, correct_label):
    """Build a subplot title for predicted class ``label``.

    Returns:
        ``(title, correct)``. With no ground truth (``correct_label is None``), the
        title is just the predicted class name and ``correct`` is True. Otherwise
        the title is ``"<pred> [OK]"`` when right, or ``"<pred> [NO→<truth>]"``
        when wrong.
    """
    predicted_name = CLASSES[label]
    if correct_label is None:
        return predicted_name, True
    correct = (label == correct_label)
    verdict = 'OK' if correct else 'NO'
    arrow = '' if correct else u"\u2192"
    truth = '' if correct else CLASSES[correct_label]
    return "{} [{}{}{}]".format(predicted_name, verdict, arrow, truth), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw ``image`` into grid cell ``subplot`` and return the spec of the next cell.

    Args:
        image: anything ``plt.imshow`` accepts.
        title: caption text; an empty string means no title is drawn.
        subplot: ``(rows, cols, index)`` as accepted by ``plt.subplot``.
        red: draw the title in red at ``titlesize / 1.2`` (used for wrong predictions).
        titlesize: base title font size.

    Returns:
        ``(rows, cols, index + 1)``.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title):
        if red:
            fontsize, color = int(titlesize / 1.2), 'red'
        else:
            fontsize, color = int(titlesize), 'black'
        plt.title(title, fontsize=fontsize, color=color,
                  fontdict={'verticalalignment': 'center'}, pad=int(titlesize / 1.5))
    next_index = subplot[2] + 1
    return (subplot[0], subplot[1], next_index)
    
def display_batch_of_images(databatch, predictions=None):
    """Plot a batch of flower images in a square-ish grid with class-name titles.

    This will work with:
    display_batch_of_images(images)
    display_batch_of_images(images, predictions)
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)

    Args:
        databatch: an ``(images, labels)`` tuple or a bare image batch (eager
            tensors or numpy arrays). Non-numeric labels, such as test-set ids,
            are treated as missing.
        predictions: optional sequence of predicted class ids, one per image.
            When given, each title shows the prediction and, if labels are known,
            whether it is right. Wrong predictions are titled in red.

    Only the first ``rows * cols`` images are shown, with ``rows = floor(sqrt(n))``
    and ``cols = n // rows``. An empty batch draws nothing.
    """
    # data
    if isinstance(databatch, tuple):
        images, labels = tfdata_to_numpy(databatch)
    else:
        # bare image batch: no labels available
        images = databatch.numpy() if hasattr(databatch, "numpy") else np.asarray(databatch)
        labels = [None] * len(images)
    if len(images) == 0:
        return  # avoids a zero division when computing the grid

    # auto-squaring: this will drop data that does not fit into square
    # or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    subplot = (rows, cols, 1)
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE / cols * rows))
    else:
        plt.figure(figsize=(FIGSIZE / rows * cols, FIGSIZE))
    # magic formula tested to work from 1x1 to 10x10 images; depends only on the grid size
    dynamic_titlesize = FIGSIZE * SPACING / max(rows, cols) * 40 + 3

    # display
    for i, (image, label) in enumerate(zip(images[:rows * cols], labels[:rows * cols])):
        title = '' if label is None else CLASSES[label]
        correct = True
        if predictions is not None:
            title, correct = title_from_label_and_target(predictions[i], label)
        subplot = display_one_flower(image, title, subplot, not correct, titlesize=dynamic_titlesize)

    # layout: untitled grids are packed tightly, titled ones get SPACING between cells
    plt.tight_layout()
    if predictions is None and all(label is None for label in labels):
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()
    
# Peek at 20 training images (the training pipeline has already randomly flipped them).
one_batch = next(iter(train_dataset.unbatch().batch(20)))
display_batch_of_images(one_batch)

## Defining the model

In [None]:
# Peak learning rate, scaled linearly with the replica count (and hence with the global batch size).
MAX_LR = 5e-5 * strategy.num_replicas_in_sync

def exponential_lr(epoch,
                   start_lr = 0.00001, min_lr = 0.00001, max_lr = MAX_LR,
                   rampup_epochs = 5, sustain_epochs = 2,
                   exp_decay = 0.8):
    """Warm-up / hold / exponential-decay learning-rate schedule.

    - ``epoch < rampup_epochs``: linear ramp from ``start_lr`` towards ``max_lr``.
    - next ``sustain_epochs`` epochs: constant ``max_lr``.
    - afterwards: ``(max_lr - min_lr) * exp_decay**k + min_lr``, where ``k``
      counts epochs since the hold phase ended (decays towards ``min_lr``).

    NOTE: Keras' ``LearningRateScheduler`` calls ``schedule(epoch, current_lr)``.
    Wrap this function (e.g. ``lambda epoch, lr: exponential_lr(epoch)``) so the
    current learning rate is not bound to ``start_lr``.

    Returns:
        The learning rate for ``epoch``.
    """
    if epoch < rampup_epochs:
        return (max_lr - start_lr) / rampup_epochs * epoch + start_lr
    if epoch < rampup_epochs + sustain_epochs:
        return max_lr
    epochs_since_hold = epoch - rampup_epochs - sustain_epochs
    return (max_lr - min_lr) * exp_decay**epochs_since_hold + min_lr

# Keras' LearningRateScheduler calls schedule(epoch, current_lr). Passing
# exponential_lr directly would bind the optimizer's current learning rate to its
# `start_lr` parameter and distort the warm-up ramp, so wrap it and ignore the
# current rate. The schedule then matches the one plotted from exponential_lr(epoch).
lr_callback = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch, current_lr: exponential_lr(epoch), verbose=True)

### Plotting the learning rate for 30 epochs

In [None]:
# Keras passes 0-based epoch indices to the schedule, so evaluate it at 0..29
# (the original range started at 1 and skipped the first epoch).
epoch_indices = np.arange(30)
plt.plot(epoch_indices, [exponential_lr(int(i)) for i in epoch_indices])
plt.title("Learning-rate schedule")
plt.ylabel("Learning rate")
plt.xlabel("Epoch index (0-based)")
plt.show()

In [None]:
def compile_model(preprocess_layer, base_model):
    """Build and compile a transfer-learning classifier around a frozen ImageNet backbone.

    Args:
        preprocess_layer: name of the ``tf.keras.applications`` submodule whose
            ``preprocess_input`` matches the backbone (e.g. ``"densenet"``).
        base_model: name of the ``tf.keras.applications`` backbone constructor
            (e.g. ``"DenseNet201"``); also used as the model's name.

    Returns:
        A compiled ``tf.keras.Model`` with layers
        [Input, preprocess Lambda, backbone, GlobalAveragePooling2D, Dense]
        (so the frozen backbone is ``model.layers[2]``).
    """
    with strategy.scope():
        preprocess_fn = getattr(tf.keras.applications, preprocess_layer).preprocess_input
        # Distinct local name so the `base_model` argument (a string) is not shadowed.
        backbone = getattr(tf.keras.applications, base_model)(weights="imagenet", include_top=False)
        backbone.trainable = False

        inputs = tf.keras.layers.Input(shape=(*IMAGE_SIZE, 3))
        x = tf.keras.layers.Lambda(preprocess_fn)(inputs)
        # training=False keeps the backbone's BatchNorm layers in inference mode,
        # also after the backbone is unfrozen for fine-tuning.
        x = backbone(x, training=False)
        x = tf.keras.layers.GlobalAveragePooling2D()(x)
        outputs = tf.keras.layers.Dense(len(CLASSES), activation="softmax")(x)

        model = tf.keras.Model(inputs, outputs, name=base_model)

        model.compile(optimizer="adam",
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                      metrics=["sparse_categorical_accuracy"])
        print("done")
        return model

### Transfer learning: training a new classifier head on each frozen backbone

In [None]:
# The training set repeats forever, so flooring its step count is harmless. The
# validation set is finite: round up so the last partial batch is evaluated too.
STEPS_PER_EPOCH = TRAIN_ENTRIES // BATCH_SIZE
VAL_STEPS = math.ceil(VAL_ENTRIES / BATCH_SIZE)
# (tf.keras.applications submodule providing preprocess_input, backbone constructor name)
model_layers = [("densenet", "DenseNet201"),
                ("resnet50", "ResNet50"),
                ("xception", "Xception")]

models = [compile_model(preprocess_name, backbone_name)
          for preprocess_name, backbone_name in model_layers]

print(len(models))


In [None]:
# Sanity check before head training: every backbone (model.layers[2]) should be frozen (False).
{name: model.layers[2].trainable for (_, name), model in zip(model_layers, models)}

In [None]:
HEAD_EPOCHS = 15

histories = []
for model in models:
    # Backbones are frozen here, so only the new Dense head is trained.
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        steps_per_epoch=STEPS_PER_EPOCH,
        validation_steps=VAL_STEPS,
        epochs=HEAD_EPOCHS,
        callbacks=[lr_callback],
    )
    histories.append(history)

### Plotting each model's training curves

In [None]:
def plot_graphs(axs, history, metric):
    """Draw the training and validation curves of ``metric`` onto ``axs``.

    Args:
        axs: a matplotlib Axes (uses ``plot``, ``set_xlabel``, ``set_ylabel``, ``legend``).
        history: a Keras ``History``; ``history.history`` must contain both
            ``metric`` and ``"val_" + metric``.
        metric: metric name, e.g. ``"loss"``.
    """
    curves = history.history
    val_metric = 'val_' + metric
    axs.plot(curves[metric])
    axs.plot(curves[val_metric], '')
    axs.set_xlabel("Epochs")
    axs.set_ylabel(metric)
    axs.legend([metric, val_metric])

In [None]:
def plot_histories(histories):
    """Plot accuracy and loss curves (train vs. validation), one row per trained model.

    Rows are matched to the backbone names in ``model_layers`` by position. If
    there are more histories than names, the extra rows are titled
    ``"model <i>"``. Nothing is drawn for an empty list.
    """
    if not histories:
        return
    model_names = [model_tuple[1] for model_tuple in model_layers]
    # squeeze=False keeps `axs` 2-D even for a single history, so axs[i, j] always works.
    _, axs = plt.subplots(len(histories), 2, figsize=(12, 4 * len(histories)), squeeze=False)
    for i, history in enumerate(histories):
        name = model_names[i] if i < len(model_names) else f"model {i}"
        for col, metric in enumerate(("sparse_categorical_accuracy", "loss")):
            axs[i, col].set_title(name + " " + metric)
            plot_graphs(axs[i, col], history, metric)
    plt.tight_layout()
    plt.show()
plot_histories(histories)

Based on these training curves, I am going to fine-tune Xception and DenseNet201.

### Fine-tuning the models

In [None]:
def tune_model(model):
    """Unfreeze the pretrained backbone of ``model`` and fine-tune the whole network.

    Recompiles with Adam at a small fixed learning rate (1e-5) and trains for 5 epochs
    on ``train_dataset``, validating on ``val_dataset``.

    Returns:
        The ``History`` returned by ``model.fit``.
    """
    # The backbone is the only nested Keras Model in the classifier built by
    # compile_model (model.layers[2]); look it up by type rather than by position.
    backbone = next(layer for layer in model.layers if isinstance(layer, tf.keras.Model))
    backbone.trainable = True
    # Create the new optimizer and recompile under the distribution strategy, as
    # compile_model does, so its variables are distributed like the model's.
    with strategy.scope():
        model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                      metrics=["sparse_categorical_accuracy"])

    return model.fit(train_dataset, steps_per_epoch=STEPS_PER_EPOCH,
                     validation_data=val_dataset, validation_steps=VAL_STEPS, epochs=5)

In [None]:
# Fine-tune the two strongest frozen models: DenseNet201 (models[0]) and Xception (models[2]).
for model_index in (0, 2):
    tune_model(models[model_index])

## Predicting the classes of images in the validation set to see how each model performs individually

In [None]:
# Compare every model on the same 30 random validation images. Label each figure
# with the model's name and its accuracy on this sample, and show the figures in order.
val_sample = val_dataset.unbatch().shuffle(350).batch(30)
images, labels = next(iter(val_sample))
model_names = [name for _, name in model_layers]
for name, model in zip(model_names, models):
    preds = np.argmax(model.predict(images), axis=-1)
    print(f"{name}: {np.mean(preds == labels.numpy()):.0%} correct on this sample")
    display_batch_of_images((images, labels), preds)
    plt.show()

## Making predictions on the test set

In [None]:
# Drop the ids and collect each model's predicted class id per test image (same order as test_dataset).
test_images_ds = test_dataset.map(lambda image, idnum: image)
predictions = [np.argmax(model.predict(test_images_ds), axis=-1) for model in models]


### Combining the ensemble's predictions by majority vote

In [None]:
def vote_preds(predictions):
    """Combine several models' class predictions by majority vote.

    Args:
        predictions: sequence of equally long 1-D sequences of non-negative
            integer class ids, one per model (voter), in priority order.

    Returns:
        A list with one winning class id (int) per sample. Ties are broken in
        favour of the earliest voter whose prediction is among the most-voted
        classes. (``np.bincount(...).argmax()`` would always pick the lowest
        class id, e.g. when all voters disagree.) An empty ``predictions``
        returns ``[]``.

    Raises:
        ValueError: if the voters' prediction arrays differ in length.
    """
    if len(predictions) == 0:
        return []
    # np.stack raises ValueError when the voters disagree on the number of samples.
    votes = np.stack([np.asarray(p) for p in predictions])  # (voters, samples)
    best = []
    for column in votes.T:
        counts = np.bincount(column)
        top = counts.max()
        best.append(next(int(c) for c in column if counts[c] == top))
    return best

### Creating a submission.csv file

In [None]:
# Collect all test ids as strings. test_dataset is read with ordered=True, so they
# line up with the per-model predictions computed above.
test_ids_ds = test_dataset.map(lambda image, idnum: idnum).unbatch()
ids = next(iter(test_ids_ds.batch(TEST_ENTRIES))).numpy().astype("U")
# Store the vote under a new name: overwriting `predictions` made this cell crash on re-run.
final_predictions = vote_preds(predictions)
assert len(ids) == len(final_predictions), "number of test ids and predictions differ"
subm_df = pd.DataFrame({"id": ids, "label": final_predictions})
subm_df.head()

In [None]:
SUBMISSION_PATH = "/kaggle/working/submission.csv"
subm_df.to_csv(SUBMISSION_PATH, index=False)  # the "id,label" file Kaggle scores
print("done!")