#### imports

In [None]:
import tensorflow as tf
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np
import pathlib
import PIL
import PIL.Image

#### parameters

In [None]:
# --- dataset location and image geometry ---
data_dir = pathlib.Path('./data')
# entries whose name contains a dot, exactly one directory level below data_dir
data = [*data_dir.glob('*/*.*')]
# images are resized to a square of this many pixels per side
img_height = img_width = 227

# --- split ratios (fractions of the whole dataset) ---
test_set_ratio = 0.1
validation_set_ratio = 0.1

# --- reproducibility ---
random_seed = 42

# --- keras input pipeline ---
batch_size = 32

#### load all data

In [None]:
# Build a single batched dataset from the image files found under data_dir.
ds_all = tf.keras.utils.image_dataset_from_directory(
    directory=data_dir,
    image_size=(img_height, img_width),  # target (height, width) every image is resized to
    batch_size=batch_size,
    seed=random_seed,  # fixed seed so repeated runs use the same random state
)

#### partition data

In [None]:
# Size of the full dataset. ds_all was created with batch_size, so every
# size computed in this cell is a number of batches, not a number of images.
ds_all_size = len(ds_all)
print(f"ds_all_size        => {ds_all_size}")

# Test and validation are taken from the first batches because those are
# always full batches (only the last batch can be a partial one).
# max(1, ...) guarantees each of the two splits gets at least one batch.
ds_test_size = max(1, int(ds_all_size * test_set_ratio))
ds_validation_size = max(1, int(ds_all_size * validation_set_ratio))
print(f"ds_test_size       => {ds_test_size}")
print(f"ds_validation_size => {ds_validation_size}")

# The training set is whatever remains after the two held-out splits.
ds_training_size = ds_all_size - ds_validation_size - ds_test_size
print(f"ds_training_size   => {ds_training_size}")

# With very few batches the held-out splits can consume everything. A size of
# 0 gives an empty training set and a negative size would make take() return
# the whole remaining dataset, so fail loudly instead of continuing.
if ds_training_size < 1:
    raise ValueError(
        f"dataset too small to partition: {ds_all_size} batch(es) leave "
        f"{ds_training_size} for training after reserving {ds_test_size} test "
        f"and {ds_validation_size} validation batch(es)"
    )

In [None]:
# Carve the three splits out of ds_all in order (test, validation, training)
# and print how many batches each one reports.
# NOTE(review): if ds_all reshuffles on every iteration, the elements behind
# take()/skip() may differ between passes and the splits could overlap over
# time -- confirm against the image_dataset_from_directory shuffle behaviour.
held_out_size = ds_test_size + ds_validation_size

ds_test = ds_all.take(ds_test_size)
ds_validation = ds_all.skip(ds_test_size).take(ds_validation_size)
ds_training = ds_all.skip(held_out_size).take(ds_training_size)

print(f"ds_test      .__len__() => {len(ds_test)}")
print(f"ds_validation.__len__() => {len(ds_validation)}")
print(f"ds_training  .__len__() => {len(ds_training)}")

In [None]:
# Batch count as reported by the dataset's own length
print(f"ds_all.__len__() => {ds_all.__len__().numpy()!s}")

In [None]:
# Same count via the cardinality helper; !s keeps the str() rendering of the returned value
print(f"tf.data.experimental.cardinality(ds_all) => {tf.data.experimental.cardinality(ds_all)!s}")

In [None]:
# Count the batches the slow way, by iterating over the whole dataset once
print(f"dataset count: {sum(1 for _ in ds_all)}")

In [None]:
# Class labels discovered by the dataset loader; index i is the name for label i.
class_names = ds_all.class_names

# Preview up to nine images from the first batch in a 3x3 grid.
# figsize is given in inches, not pixels, so the image dimensions (227 x 227)
# must not be used here -- they would request an enormous canvas.
plt.figure(figsize=(10, 10))
for images, labels in ds_all.take(1):
    # The batch can hold fewer than nine images when the dataset is small.
    for i in range(min(9, images.shape[0])):
        ax = plt.subplot(3, 3, i + 1)
        # Cast to uint8 so imshow treats the values as 0..255 pixel intensities.
        ax.imshow(images[i].numpy().astype("uint8"))
        ax.set_title(class_names[int(labels[i])])
        ax.axis("off")

In [None]:
# Show the class labels that were read from ds_all.class_names
print(class_names)

In [None]:
# Inspect the training split. The partitioning cell names it ds_training;
# train_ds is never defined in this notebook and raised a NameError.
print(ds_training)

In [None]:
# Pretrained feature extractor: ResNet50V2 with ImageNet weights and without
# its classification head. pooling='max' applies global max pooling, so the
# base model outputs one flat feature vector per image.
base_model = tf.keras.applications.resnet_v2.ResNet50V2(
    include_top=False,
    pooling='max',
    input_shape=(img_height, img_width, 3),
    weights='imagenet')

# One output unit per class found by the dataset loader, instead of a
# hard-coded 2, so the head matches the integer labels for any class count.
num_classes = len(class_names)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(img_height, img_width, 3)),
    # The ResNetV2 ImageNet weights expect pixels scaled to [-1, 1]
    # (what resnet_v2.preprocess_input does). The dataset is assumed to
    # deliver raw 0..255 values -- the preview cell casts them to uint8.
    tf.keras.layers.Rescaling(1.0 / 127.5, offset=-1.0),
    base_model,
    # Small fully connected classification head on top of the features.
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])

model.summary()

In [None]:
# Checkpoint callback that saves to ./snapshots; verbose=1 makes it report each save.
snapshot_callback = tf.keras.callbacks.ModelCheckpoint("./snapshots", verbose=1)

In [None]:
# Freeze the pretrained backbone so only the newly added Dense head is trained.
base_model.trainable = False

# Labels are integer class indices, hence the sparse loss and sparse accuracy.
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])

# Train on the training split. The partitioning cell names it ds_training;
# train_ds was never defined and raised a NameError.
# snapshot_callback was created earlier in the notebook but never handed to
# fit(), so no snapshots were ever written; it is passed in here.
# NOTE(review): steps_per_epoch * epochs is 60 batches; if ds_training holds
# fewer and is not repeated, training may stop early -- confirm.
history = model.fit(
    ds_training,
    steps_per_epoch=2,  # use way more steps here: number of samples / batch size
    epochs=30,  # use way more or use EarlyStopping callback
    callbacks=[snapshot_callback],
)