This notebook closely follows the TensorFlow image classification tutorial:
https://www.tensorflow.org/tutorials/images/classification

In [None]:
import tensorflow as tf

# Settings shared by both splits. The validation_split and seed must be the
# same in the two calls so the "training" and "validation" subsets are carved
# out of the same shuffled file list and do not overlap.
_split_kwargs = dict(
    validation_split=0.1,
    seed=12345,
    image_size=(40, 40),
    batch_size=32,
    color_mode="grayscale",
)

# 90% of the images in "combined" for training...
train_ds = tf.keras.utils.image_dataset_from_directory(
    "combined", subset="training", **_split_kwargs
)
# ...and the remaining 10% held out for validation.
val_ds = tf.keras.utils.image_dataset_from_directory(
    "combined", subset="validation", **_split_kwargs
)

In [None]:
import matplotlib.pyplot as plt

# Class names are inferred from the sub-directory names of "combined";
# later cells use this list to size the model's output layer.
class_names = train_ds.class_names

# Preview up to nine images from the first training batch in a 3x3 grid.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    # The final batch can be smaller than nine images; don't index past it.
    for i in range(min(9, images.shape[0])):
        plt.subplot(3, 3, i + 1)
        # The datasets are loaded with color_mode="grayscale", so each image
        # is (H, W, 1). Drop the channel axis and render with a fixed 0-255
        # gray scale. Otherwise matplotlib applies its default colour map and
        # rescales every image to its own min/max.
        plt.imshow(
            images[i].numpy().squeeze(-1).astype("uint8"),
            cmap="gray",
            vmin=0,
            vmax=255,
        )
        plt.title(class_names[int(labels[i])])
        plt.axis("off")

In [None]:
# im not sure this is nessisary for such a small dataset but its in the tutorial so whatever
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

In [None]:
# Sanity check: print the shape of a single image from one training batch.
images, labels = next(iter(train_ds.take(1)))
print(images[0].shape)

In [None]:
import tensorflow.keras.layers as layers
import tensorflow.keras as keras
num_classes = len(class_names)

# Side length of the square input images; must match the image_size used
# when the datasets were loaded.
img_side = 40

# A small fully-connected classifier:
#   - rescale pixel values by 1/255,
#   - flatten each single-channel image into a vector,
#   - one 64-unit ReLU hidden layer,
#   - one score per class, converted to probabilities by the final Softmax.
# Because of that last layer the model outputs probabilities, not logits.
model = keras.Sequential()
model.add(layers.Rescaling(1./255, input_shape=(img_side, img_side, 1)))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(num_classes))
model.add(layers.Softmax())

In [None]:
# The model built above ends in a Softmax layer, so its outputs are already
# probabilities. The loss must therefore use from_logits=False; from_logits=True
# would apply softmax a second time and distort both the loss and its
# gradients. Labels are integer class indices, hence the "Sparse" variant.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
model.summary()

In [None]:
# Number of full passes over the training set. Validation metrics are
# computed on val_ds after every epoch and recorded in `history`.
epochs = 20
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

In [None]:
# Per-epoch metrics recorded by model.fit.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

# Take the x-axis from the recorded history instead of the `epochs` setting,
# so the curves still line up if training ran for a different number of
# epochs (e.g. if an early-stopping callback is added). Epochs are numbered
# from 1 to match Keras' own progress output.
epochs_range = range(1, len(acc) + 1)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.xlabel('Epoch')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.xlabel('Epoch')
plt.title('Training and Validation Loss')
plt.show()

In [None]:
# Persist the trained model under "my_model".
# NOTE(review): Keras 3 (TF >= 2.16) rejects save paths without a ".keras" or
# ".h5" extension; older tf.keras writes a SavedModel directory here. Confirm
# the TF version and what the loading side expects before changing the path.
model.save(filepath="my_model")