In [None]:
import re
import os
import pprint
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation
print(tf.__version__) # Check the version of tensorflow used

%matplotlib inline

In [None]:
from kaggle_datasets import KaggleDatasets

# Resolve the GCS bucket backing the 'tpu-getting-started' Kaggle dataset and
# show it, so the gs:// prefix the TPU reads from is visible in the output.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
print(GCS_DS_PATH)

# Use the 224x224 JPEG TFRecords; each split is a glob over its shard files.
GCS_PATH = f'{GCS_DS_PATH}/tfrecords-jpeg-224x224'
train_dir, val_dir, test_dir = (
    f'{GCS_PATH}/train/*.tfrec',
    f'{GCS_PATH}/val/*.tfrec',
    f'{GCS_PATH}/test/*.tfrec',
)

In [None]:
# Detect a TPU and return the matching tf.distribute strategy; otherwise fall
# back to the default strategy (whatever single device the kernel has).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Running on TPU ', tpu.master())
except ValueError:
    # The resolver is expected to raise ValueError when no TPU is attached.
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # tf.distribute.experimental.TPUStrategy is deprecated; TPUStrategy is the
    # stable replacement (the rest of this notebook already needs TF >= 2.4).
    strategy = tf.distribute.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

print("REPLICAS: ", strategy.num_replicas_in_sync)

In [None]:
# Global batch size: 16 examples per replica, multiplied by the number of
# replicas the distribution strategy keeps in sync.
BATCH_SIZE = 16 * strategy.num_replicas_in_sync
# (height, width) of decoded images; must match the tfrecords-jpeg-224x224
# shards selected by GCS_PATH.
IMAGE_SIZE = (224, 224)
# Flower class names (104 entries), indexed by the integer "class" feature of
# the labeled TFRecords. Order is significant and must not be changed.
# NOTE: 'cyclamen ' and 'hippeastrum ' carry a trailing space in this list.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103


def decode_image(image_data):
    """Decode a JPEG byte string into a float32 RGB image tensor.

    Args:
        image_data: scalar string tensor holding JPEG-encoded bytes.

    Returns:
        float32 tensor of shape [*IMAGE_SIZE, 3] with pixel values divided
        by 255, i.e. in the [0, 1] range.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image_data, channels=3), tf.float32)
    # The TPU compiler needs a fully static shape, so pin it with a reshape.
    return tf.reshape(pixels / 255.0, [*IMAGE_SIZE, 3])

def read_labeled_tfrecord(example):
    """Parse one serialized labeled TFRecord into an (image, label) pair.

    The record is expected to carry a JPEG "image" bytestring and a scalar
    int64 "class" index. Returns the decoded float image (via decode_image)
    and the class index cast to int32.
    """
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG-encoded bytes
        "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class index
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def read_unlabeled_tfrecord(example):
    """Parse one serialized test-set TFRecord into an (image, id) pair.

    Test records have no "class" feature (predicting it is the competition
    task); they carry a JPEG "image" bytestring and a string "id" instead.
    Returns the decoded float image (via decode_image) and the raw id tensor.
    """
    parsed = tf.io.parse_single_example(example, {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG-encoded bytes
        "id": tf.io.FixedLenFeature([], tf.string),     # id written to the submission
    })
    return decode_image(parsed['image']), parsed['id']

def augment(image_label, seed):
    """Randomly augment one (image, label) pair using only stateless TF ops.

    Args:
        image_label: (image, label) tuple; image is a float tensor in [0, 1]
            of shape [IMAGE_SIZE[0], IMAGE_SIZE[1], 3] as built by decode_image.
        seed: length-2 integer seed (here a (counter, counter) pair zipped in
            by the training pipeline) that drives every random op below.

    Returns:
        (image, label) with the augmented image clipped back into [0, 1].
    """
    image, label = image_label
    # Independent seed streams for the branch choice, the crop and the brightness
    # jitter, so those decisions are not correlated with each other.
    split_seeds = tf.random.experimental.stateless_split(seed, num=3)
    choice_seed, crop_seed, brightness_seed = split_seeds[0], split_seeds[1], split_seeds[2]

    # Choose one of four options *per element*. This must be a tensor op: a
    # NumPy/Python random call would run only once, when tf.data traces this
    # function, and every image would then get the same branch forever.
    branch = tf.random.stateless_uniform(
        [], seed=choice_seed, minval=0, maxval=4, dtype=tf.int32)
    source = image
    image = tf.switch_case(branch, branch_fns=[
        lambda: source,
        lambda: tf.image.stateless_random_flip_left_right(source, seed=seed),
        lambda: tf.image.stateless_random_flip_up_down(source, seed=seed),
        lambda: tf.image.stateless_random_hue(source, 0.2, seed=seed),
    ])

    height, width = IMAGE_SIZE
    # Pad to 6 px larger on each axis, then randomly crop back to the original
    # size: a small random translation.
    image = tf.image.resize_with_crop_or_pad(image, height + 6, width + 6)
    image = tf.image.stateless_random_crop(image, size=[height, width, 3], seed=crop_seed)
    image = tf.image.stateless_random_brightness(image, max_delta=0.5, seed=brightness_seed)
    # Hue/brightness shifts can leave [0, 1]; clamp back to a valid pixel range.
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label

In [None]:
AUTOTUNE = tf.data.AUTOTUNE
train_files = tf.io.gfile.glob(train_dir)

# Shuffle the raw serialized records *before* pairing them with the seed counter.
# tf.data.experimental.Counter restarts at 0 each time the dataset is iterated
# (every epoch), so pairing seeds with images in file order would give every
# image the exact same augmentation seed in every epoch. Shuffling first (it is
# reshuffled each iteration by default) gives each image a fresh seed per epoch.
# Shuffling compact JPEG bytes instead of decoded floats also shrinks the buffer.
train_records = (
    tf.data.TFRecordDataset(train_files, num_parallel_reads=AUTOTUNE)
    .shuffle(buffer_size=2048)
    .map(read_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
)
counter = tf.data.experimental.Counter()
train_ds = (
    tf.data.Dataset.zip((train_records, (counter, counter)))
    .map(augment, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
val_files = tf.io.gfile.glob(val_dir)
# Decode records in parallel; tf.data keeps element order deterministic by default.
val_ds = (
    tf.data.TFRecordDataset(val_files, num_parallel_reads=AUTOTUNE)
    .map(read_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
test_files = tf.io.gfile.glob(test_dir)
# Decode records in parallel; each element carries its own id, so
# predictions stay matched to ids regardless of record order.
test_ds = (
    tf.data.TFRecordDataset(test_files, num_parallel_reads=AUTOTUNE)
    .map(read_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
    .batch(batch_size=BATCH_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

In [None]:
def show_batch(images, labels, predictions=None, num_images=12, cols=4):
    """Plot the first images of a batch with class names as titles.

    Args:
        images: batch of images exposing .numpy() (e.g. an eager tf.Tensor
            of shape [batch, H, W, 3]).
        labels: integer class ids aligned with images (indices into CLASSES).
        predictions: optional predicted class ids. When given, each title is
            the predicted class, green if it matches the label and red if not.
        num_images: maximum number of images to draw; clipped to batch size.
        cols: number of columns in the plot grid.
    """
    pixels = images.numpy()
    count = min(num_images, len(pixels))
    rows = max(1, (count + cols - 1) // cols)

    # Rescale the whole batch jointly into [0, 1] for imshow; fall back to a
    # span of 1 for a constant batch, which would otherwise divide by zero.
    low, high = pixels.min(), pixels.max()
    span = (high - low) or 1.0

    plt.figure(figsize=(5 * cols, 5 * rows))
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow((pixels[i] - low) / span)
        label = int(labels[i])
        if predictions is None:
            plt.title(CLASSES[label])
        else:
            predicted = int(predictions[i])
            plt.title(CLASSES[predicted], color='g' if predicted == label else 'r')
        plt.axis("off")
    plt.show()

In [None]:
# Pull the first augmented training batch and display it.
images, labels = next(iter(train_ds.take(1)))
show_batch(images, labels)

In [None]:
def count_data_items(filenames):
    """Return the total number of examples across a list of TFRecord shards.

    Each shard's item count is encoded in its file name as the last
    "-<digits>" before the extension, e.g. flowers00-230.tfrec holds 230
    items and gs://bucket/train/00-224x224-798.tfrec holds 798.

    Args:
        filenames: iterable of local paths or gs:// URLs.

    Returns:
        The summed count as an int (0 for an empty iterable).

    Raises:
        ValueError: if a file name does not end in "-<digits>.<ext>".
    """
    pattern = re.compile(r"-([0-9]+)\.[^.]+$")
    total = 0
    for filename in filenames:
        # Only look at the base name so hyphens in bucket/dir names can't match.
        basename = filename.rsplit("/", 1)[-1]
        match = pattern.search(basename)
        if match is None:
            raise ValueError("Cannot read item count from file name: {!r}".format(filename))
        total += int(match.group(1))
    return total

# Item counts come from the shard file names, so no data has to be read.
NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES = (
    count_data_items(split_files) for split_files in (train_files, val_files, test_files)
)
print('Dataset: {} training images, {} validation images, {} unlabeled test images'.format(
    NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

In [None]:
EPOCHS = 25

with strategy.scope():
    # MobileNetV2's own input normalisation (expects [0, 255] pixels and
    # rescales them to [-1, 1]).
    preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

    # Create the base model from the pre-trained model MobileNet V2
    IMG_SHAPE = IMAGE_SIZE + (3,)
    pretrained_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                                         include_top=False,
                                                         weights='imagenet')
    pretrained_model.trainable = False
    global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
    prediction_layer = tf.keras.layers.Dense(len(CLASSES), activation='softmax')

    inputs = tf.keras.Input(shape=IMG_SHAPE)
    # decode_image yields pixels in [0, 1]; scale back to [0, 255] so that
    # preprocess_input produces the [-1, 1] range the ImageNet weights expect.
    x = preprocess_input(inputs * 255.0)
    # training=False keeps the base's BatchNorm layers in inference mode even
    # after pretrained_model.trainable is set to True for fine-tuning, so their
    # ImageNet statistics are not overwritten by the small flower batches.
    x = pretrained_model(x, training=False)
    # ... attach a new head to act as a classifier.
    x = global_average_layer(x)
    outputs = prediction_layer(x)
    model = tf.keras.Model(inputs, outputs)

In [None]:
base_learning_rate = 0.0001
# Compile (and build the optimizer) under the same distribution strategy scope
# the model was created in, as Keras recommends for distributed training.
with strategy.scope():
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
                  metrics=['sparse_categorical_accuracy'])

In [None]:
# Train the new classification head for EPOCHS epochs, validating after each.
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)

In [None]:
def display_training_curves(training, validation, title, subplot):
    """Plot one training/validation metric curve into a panel of a shared figure.

    Args:
        training: per-epoch metric values on the training set.
        validation: per-epoch metric values on the validation set.
        title: metric name, used for the panel title and y-axis label.
        subplot: 3-digit matplotlib subplot spec (e.g. 211). A spec whose last
            digit is 1 starts a new figure, so call with 211 and then 212.
    """
    if subplot % 10 == 1:
        # Start a fresh figure. plt.figure (unlike plt.subplots) does not add an
        # extra full-size Axes that would sit underneath the panels.
        plt.figure(figsize=(10, 10), facecolor='#F0F0F0')
    ax = plt.subplot(subplot)
    ax.set_facecolor('#F8F8F8')
    ax.plot(training)
    ax.plot(validation)
    ax.set_title('model ' + title)
    ax.set_ylabel(title)
    ax.set_xlabel('epoch')
    ax.legend(['train', 'valid.'])
    # Re-run the layout once the panel has content; running it on an empty
    # figure (as before) has nothing to arrange.
    plt.tight_layout()

In [None]:
# (train key, validation key, panel title, subplot spec) for each panel.
curve_specs = [
    ('loss', 'val_loss', 'loss', 211),
    ('sparse_categorical_accuracy', 'val_sparse_categorical_accuracy', 'accuracy', 212),
]
for train_key, val_key, panel_title, panel_spec in curve_specs:
    display_training_curves(
        history.history[train_key],
        history.history[val_key],
        panel_title,
        panel_spec,
    )

In [None]:
# Predict on one validation batch and color each title against *that batch's*
# labels (previously the stale training-batch labels were used).
for images, labels in val_ds.take(1):
    predictions = tf.argmax(model.predict(images), axis=-1)
    show_batch(images, labels, tf.cast(predictions, tf.int32))

In [None]:
# Unfreeze the whole base first; the loop below re-freezes its lower layers.
pretrained_model.trainable = True

base_layers = pretrained_model.layers
print("Number of layers in the base model: ", len(base_layers))

# Index of the first base-model layer that stays trainable.
fine_tune_at = 100
for frozen_layer in base_layers[:fine_tune_at]:
    frozen_layer.trainable = False

In [None]:
# Recompile so the new trainable flags take effect, with a 10x smaller learning
# rate for fine-tuning. Done under the strategy scope the model was built in.
with strategy.scope():
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
                  metrics=['sparse_categorical_accuracy'])

In [None]:
num_trainable_variables = len(model.trainable_variables)
num_trainable_variables

In [None]:
fine_tune_epochs = 20
total_epochs = EPOCHS + fine_tune_epochs

# Continue epoch numbering right after the head-training run. history.epoch is
# 0-based, so history.epoch[-1] is the last *completed* epoch; starting there
# would redo it and run fine_tune_epochs + 1 epochs. len(history.epoch) is the
# index of the next epoch.
history_fine = model.fit(train_ds,
                         epochs=total_epochs,
                         initial_epoch=len(history.epoch),
                         validation_data=val_ds)

In [None]:
# Collect (id, predicted class) pairs batch by batch; ids travel with their
# images, so the pairing does not depend on record order.
pred_list = []
for batch_images, batch_ids in test_ds:
    batch_classes = tf.argmax(model.predict(batch_images), axis=-1)
    pred_list += zip(batch_ids.numpy(), batch_classes.numpy())
pred_list[:5]

In [None]:
df = pd.DataFrame(pred_list, columns=['id', 'label'])
# Ids come out of the TFRecords as raw bytes; decode them instead of slicing
# the "b'...'" repr, which leaks escape sequences for backslashes or
# non-ASCII characters.
df['id'] = df['id'].str.decode('utf-8')
df.head()

In [None]:
# Write the submission file without the DataFrame index column.
df.to_csv(path_or_buf='submission.csv', index=False)