In [11]:
# Flow_from_directory
# https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator


In [12]:
IMAGE_SIZE = 224
BATCH_SIZE = 32
SEED = 42  # fixes the shuffle order of the training/validation batches
TRAIN_DIR = 'data/full/calc/train/'
TEST_DIR = 'data/full/calc/test/'

# The frozen ResNet50 backbone was pretrained on ImageNet images prepared with
# `resnet50.preprocess_input` (RGB->BGR + per-channel mean subtraction, no 0-1
# scaling), so feed it that same preprocessing instead of `rescale=1./255`.
# `validation_split=0.2` holds out 20% of TRAIN_DIR; both subsets are drawn
# from this one generator so they partition the same set of files.
train_generator = ImageDataGenerator(preprocessing_function=preprocess_input, validation_split=0.2)
flow_kwargs = dict(target_size=(IMAGE_SIZE, IMAGE_SIZE), batch_size=BATCH_SIZE, class_mode='binary')
train = train_generator.flow_from_directory(TRAIN_DIR, subset='training', seed=SEED, **flow_kwargs)
val = train_generator.flow_from_directory(TRAIN_DIR, subset='validation', seed=SEED, **flow_kwargs)

# Test data: same preprocessing, no split, and no shuffling so that the order
# of predictions lines up with `test.classes` / `test.filenames`.
test_generator = ImageDataGenerator(preprocessing_function=preprocess_input)
test = test_generator.flow_from_directory(TEST_DIR, shuffle=False, **flow_kwargs)

Found 988 images belonging to 2 classes.
Found 246 images belonging to 2 classes.
Found 286 images belonging to 2 classes.


In [13]:
# Transfer learning: an ImageNet-pretrained ResNet50 (classification head
# removed, global-average-pooled features) followed by a single sigmoid unit
# for the two-class (binary) problem.
model = Sequential([
    ResNet50(include_top=False, pooling='avg', weights='imagenet'),
    Dense(1, activation='sigmoid'),
])

# Freeze the backbone *before* compiling so only the Dense head is trained.
model.layers[0].trainable = False
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

In [14]:
EPOCHS = 10

# Fit on the 'training' subset and score the 'validation' subset after every
# epoch. Keep the returned History object so metrics can be inspected later.
history = model.fit(train, validation_data=val, epochs=EPOCHS)

# Save the trained model. Plain string path: the previous f-prefix had no
# placeholders. NOTE: if training is interrupted, this line never runs.
model.save('model.h5')

Epoch 1/10
Epoch 2/10
Epoch 3/10

KeyboardInterrupt: 

In [None]:
# Plot training vs. validation accuracy per epoch, read from the History
# object Keras stores on `model.history` after `model.fit`.
import matplotlib.pyplot as plt

acc, val_acc, loss, val_loss = (
    model.history.history[metric]
    for metric in ('accuracy', 'val_accuracy', 'loss', 'val_loss')
)
# NOTE(review): `loss` / `val_loss` are extracted but not plotted in the lines
# shown here — confirm whether a loss plot is intended.

epochs = range(len(acc))

for series, fmt, label in ((acc, 'r', 'Training accuracy'),
                           (val_acc, 'b', 'Validation accuracy')):
    plt.plot(epochs, series, fmt, label=label)
plt.title('Training and validation accuracy')
plt.legend(loc=0)
# Open a new, empty figure so any later plotting does not draw onto this one.
plt.figure()