# Train

In [8]:
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Rescaling, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.data import AUTOTUNE
import os

## Set parameters

In [9]:
BATCH_SIZE = 16
IMG_HEIGHT = 64
IMG_WIDTH = 64
DATA_PATH = 'Train_dataset'
# image_dataset_from_directory infers one class per *subdirectory* of DATA_PATH,
# so count directories only; len(os.listdir(...)) would also count stray files
# (e.g. `.DS_Store`, a README) and oversize the final Dense layer.
with os.scandir(DATA_PATH) as entries:
    NUM_CLASSES = sum(entry.is_dir() for entry in entries)
MODEL_NAME = 'model.hdf5'

## Read dataset

In [10]:
# Both subsets must be drawn with the same validation_split and seed: Keras
# shuffles the file list before splitting, and mismatched values would let the
# training and validation subsets overlap.
VALIDATION_SPLIT = 0.2
SPLIT_SEED = 123


def _load_subset(subset):
    """Load the `subset` ("training" or "validation") split of DATA_PATH."""
    return image_dataset_from_directory(
        DATA_PATH,
        validation_split=VALIDATION_SPLIT,
        subset=subset,
        seed=SPLIT_SEED,
        image_size=(IMG_HEIGHT, IMG_WIDTH),
        batch_size=BATCH_SIZE)


train_ds = _load_subset("training")
val_ds = _load_subset("validation")

# The output layer is sized from NUM_CLASSES, so it must agree with the classes
# Keras actually inferred (checked here, before later cells replace train_ds with
# a cached pipeline that no longer carries `class_names`).
assert train_ds.class_names == val_ds.class_names, "train/val class names differ"
assert len(train_ds.class_names) == NUM_CLASSES, (
    f"NUM_CLASSES={NUM_CLASSES} but Keras found classes {train_ds.class_names}")

Found 4151 files belonging to 6 classes.
Using 3321 files for training.
Found 4151 files belonging to 6 classes.
Using 830 files for validation.


In [11]:
# cache() records the elements produced during the first full pass and replays
# them afterwards, which freezes the loader's shuffle order after epoch 1.
# Re-shuffle the cached training batches so each epoch sees a new order;
# validation order does not affect the metrics, so it is left unshuffled.
SHUFFLE_BUFFER = 1000  # measured in batches; larger than the ~208 training batches
train_ds = train_ds.cache().shuffle(SHUFFLE_BUFFER).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

## Create neural network architecture

In [12]:
# Three identical feature-extraction stages: 64-filter 3x3 ReLU convolution
# followed by default (2x2) max pooling.
conv_stages = []
for _ in range(3):
    conv_stages.append(Conv2D(64, 3, activation='relu'))
    conv_stages.append(MaxPooling2D())

# Pixels are multiplied by 1/255 before the conv stack. The last Dense layer has
# no activation, so the network outputs one raw logit per class.
model = Sequential([
    Rescaling(1./255),
    *conv_stages,
    Flatten(),
    Dense(256, activation='relu'),
    Dense(NUM_CLASSES),
])

## Compile and train neural network

In [13]:
# Stop after 10 epochs with no val_loss improvement. When that triggers, roll the
# in-memory model back to its best epoch so `model` matches the checkpoint on
# disk (otherwise it keeps the weights from the final, worse epoch).
early_stop = EarlyStopping(monitor='val_loss', mode='min', patience=10,
                           restore_best_weights=True)
# Write the model to MODEL_NAME only when val_loss reaches a new minimum.
model_checkpoint = ModelCheckpoint(MODEL_NAME, save_best_only=True,
                                   monitor='val_loss', mode='min')
model.compile(
    optimizer='adam',
    # The final Dense layer has no softmax, so the loss consumes raw logits.
    loss=SparseCategoricalCrossentropy(from_logits=True),
    metrics=[SparseCategoricalAccuracy()])

# Keep the History object so loss/accuracy curves can be inspected afterwards.
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=100,
    callbacks=[early_stop, model_checkpoint],
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100


<keras.callbacks.History at 0x275bf45d040>