In [None]:
import tensorflow as tf

In [29]:
# Load ./images/ (one subdirectory per class; labels are inferred from the
# folder names) and split it 70/30 into training and validation datasets.
# Both subsets are built from ONE shared set of arguments: directory,
# validation_split and seed must be identical for the two calls, otherwise
# the training and validation files are drawn from different shuffles and
# can overlap.
seed = 10000
_dataset_kwargs = dict(
    directory='./images/',
    labels='inferred',
    image_size=(128, 128),  # must match IMAGE_SHAPE[:2] used for the model
    validation_split=0.3,
    seed=seed,
)

train_images = tf.keras.preprocessing.image_dataset_from_directory(
    subset='training',
    **_dataset_kwargs,
)

val_images = tf.keras.preprocessing.image_dataset_from_directory(
    subset='validation',
    **_dataset_kwargs,
)

Found 1997 files belonging to 2 classes.
Using 1398 files for training.
Found 1997 files belonging to 2 classes.
Using 599 files for validation.


In [17]:
# Build a binary classifier by transfer learning from MobileNetV2.
# MobileNetV2 is an ImageNet image classifier and is the base model used in the
# TensorFlow transfer-learning tutorial this notebook follows:
# https://www.tensorflow.org/tutorials/images/transfer_learning

IMAGE_SHAPE = (128, 128, 3)  # must match image_size=(128, 128) used when loading the datasets

# `Rescaling` lives at tf.keras.layers.Rescaling in newer TF releases, where the
# old `experimental.preprocessing` path is no longer available; older releases
# only provide the experimental path. Use whichever this TF version has.
if hasattr(tf.keras.layers, 'Rescaling'):
    _Rescaling = tf.keras.layers.Rescaling
else:
    _Rescaling = tf.keras.layers.experimental.preprocessing.Rescaling

base_classifier = tf.keras.applications.MobileNetV2(
    input_shape=IMAGE_SHAPE,
    include_top=False,   # drop the ImageNet classification head; keep the conv feature extractor
    weights='imagenet',
)
base_classifier.trainable = False  # freeze the pretrained weights; only the new head is trained

# Scaling for the classifier model: maps pixel values in [0, 255] to [-1, 1]
# (x / 127.5 - 1), the input range MobileNetV2 expects.
temp_scaling_input = _Rescaling(1/127.5, offset=-1)

# Scaling for when this model is placed behind a generator: the generator
# outputs values in [0, 1], which this maps to the [-1, 1] range the model
# expects (2 * x - 1). It is not part of the classifier graph built below.
gen_scaling_input = _Rescaling(2, offset=-1)

pooling_layer = tf.keras.layers.GlobalAveragePooling2D()
# A single raw logit (no activation); pairs with BinaryCrossentropy(from_logits=True).
prediction_layer = tf.keras.layers.Dense(1)

inputs = tf.keras.layers.Input(shape=IMAGE_SHAPE)
x = temp_scaling_input(inputs)
# training=False keeps the base model's BatchNormalization layers in inference
# mode, even if the base is later unfrozen for fine-tuning (as in the tutorial).
x = base_classifier(x, training=False)
x = pooling_layer(x)
outputs = prediction_layer(x)
classifier_model = tf.keras.Model(inputs, outputs)

In [20]:
# Compile the classifier with Adam and binary cross-entropy.
# The final Dense(1) layer has no activation, so the model emits raw logits:
#   * the loss is told from_logits=True;
#   * accuracy must threshold at 0 (logit > 0 <=> probability > 0.5). The plain
#     string 'accuracy' resolves to BinaryAccuracy with its default threshold
#     of 0.5, which would be applied to logits and misreport accuracy.
#     name='accuracy' keeps the history keys as 'accuracy' / 'val_accuracy'.
LEARNING_RATE = 1e-4

classifier_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0, name='accuracy')],
)

In [30]:
# Train the classifier on train_images, evaluating on val_images at the end of
# every epoch; the returned History object records per-epoch loss and metrics.
EPOCHS = 10  # number of full passes over the training dataset

history = classifier_model.fit(
    x=train_images,
    validation_data=val_images,
    epochs=EPOCHS,
)

Epoch 1/10
Epoch 2/10
 9/44 [=====>........................] - ETA: 6s - loss: 0.0913 - accuracy: 0.9653

KeyboardInterrupt: 

In [28]:
import os

# Delete files from an image folder whose content is not an image type that
# TensorFlow's decoder accepts; such files presumably make iteration over the
# datasets fail with an "unknown image format" error.
#
# NOTE: this cell was executed (In [28]) before the dataset-loading cell
# (In [29]); on a fresh kernel it must run BEFORE the datasets are built.
#
# The previous version used `imghdr`. It is deprecated (removed in Python 3.13)
# and does not recognise JPEGs whose first segment is not JFIF/Exif (e.g. an
# ICC-profile APP2 segment), so it deleted valid JPEGs. Here we check the same
# leading "magic bytes" that tf.io.decode_image uses to detect the format.

valid_images = {'gif', 'png', 'jpeg', 'jpg', 'bmp'}

# (leading bytes, type name) pairs, checked in order.
_IMAGE_SIGNATURES = (
    (b'\xff\xd8\xff', 'jpeg'),
    (b'\x89PNG\r\n\x1a\n', 'png'),
    (b'GIF8', 'gif'),
    (b'BM', 'bmp'),
)


def _sniff_image_type(path):
    """Return 'jpeg', 'png', 'gif' or 'bmp' from the file's leading bytes, else None."""
    with open(path, 'rb') as fh:
        header = fh.read(8)
    for magic, kind in _IMAGE_SIGNATURES:
        if header.startswith(magic):
            return kind
    return None


def remove_invalid_images(directory, valid_types=valid_images):
    """Delete regular files in `directory` whose content is not an accepted image type.

    Args:
        directory: folder to scan (not recursive); subdirectories are left untouched.
        valid_types: type names to keep, as returned by `_sniff_image_type`.

    Returns:
        Sorted list of the file names that were removed. Each name is also printed.
    """
    removed = []
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if not os.path.isfile(path):
            continue  # skip subdirectories etc.; opening them would raise
        if _sniff_image_type(path) not in valid_types:
            print(name)
            os.remove(path)
            removed.append(name)
    return removed


# Only the 'not_creepy' class folder is cleaned here.
# TODO(review): confirm whether the other class folder under ./images/ needs the same pass.
image_dir = './images/not_creepy/'
if os.path.isdir(image_dir):
    remove_invalid_images(image_dir)
else:
    print(f'{image_dir} not found; skipping cleanup')