In [5]:
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt

(train_data, val_data), meta = tfds.load('tf_flowers', split=['train[:70%]', 'train[70%:]'], 
                                        as_supervised=True, with_info=True)
num_train=meta.splits['train'].num_examples

EPOCHS=5
RES=299
def define_model():
    URL = "https://tfhub.dev/google/tf2-preview/inception_v3/feature_vector/4"

    model=tf.keras.Sequential([hub.KerasLayer(URL,
                        input_shape=(RES,RES,3), trainable=False),
                        tf.keras.layers.Dense(5, activation='softmax')])
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
        optimizer=tf.keras.optimizers.Adam()
    )
    return model
model=define_model()

BATCH=50
classes=meta.features['label'].names

print('\n{}\n{}'.format(classes, num_train))

normalize= lambda data,labels: (tf.cast(data/255, tf.float32), labels)
train_data, val_data = train_data.shuffle(buffer_size=num_train).map(normalize), val_data.map(normalize)

train_batch, val_batch=train_data.batch(BATCH).prefetch(1), val_data.batch(BATCH).prefetch(1)

hist=model.fit(
    train_batch,
    verbose=2,
    validation_data=val_batch,
    epochs=5
)

def summarize_model_diagnostics(history):
    loss, acc = history.history['loss'], history.history['accuracy']
    val_loss, val_acc = history.history['val_loss'], history.history['val_accuracy']
    fig, ax=plt.subplots(1,2, figsize=(20,10))
    ax[1].plot(range(EPOCHS), loss, label='loss',color='blue')
    ax[1].plot(range(EPOCHS), val_loss, label='val loss', color='r')
    ax[1].legend(loc='upper right')
    ax[1].set_title('Loss')

    ax[0].plot(range(EPOCHS), acc, label='accuracy',color='blue')
    ax[0].plot(range(EPOCHS), val_acc, label='val accuracy', color='r')
    ax[0].legend(loc='upper right')
    ax[0].set_title('Accuracy')
    plt.title('Training & Validation Performance')
    plt.tight_layout()
    plt.show()

summarize_model_diagnostics(hist)


['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']
3670
