In [1]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from keras.preprocessing.image_dataset import image_dataset_from_directory
from keras_preprocessing.image import ImageDataGenerator


# Reproducability
def set_seed(seed = 31415):
    np.random.seed(seed)
    tf.random.set_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_DETERMINISTIC_OPS'] = '1'


set_seed()

# Set Matplotlib defaults
plt.rc('figure', autolayout = True)
plt.rc('axes', labelweight = 'bold', labelsize = 'large',
       titleweight = 'bold', titlesize = 18, titlepad = 10)
plt.rc('image', cmap = 'magma')
warnings.filterwarnings("ignore")  # to clean up output cells

# Load training and validation sets
image_generator = ImageDataGenerator(rescale = 1 / 255, validation_split = 0.2)

ds_train = image_dataset_from_directory(
    'input/car-or-truck/train',
    labels = 'inferred',
    label_mode = 'binary',
    image_size = [128, 128],
    interpolation = 'nearest',
    batch_size = 64,
    shuffle = False,
)

ds_valid = image_dataset_from_directory(
    'input/car-or-truck/valid',
    labels = 'inferred',
    label_mode = 'binary',
    image_size = [128, 128],
    interpolation = 'nearest',
    batch_size = 64,
    shuffle = False,
)


# Data Pipeline
def convert_to_float(image, label):
    image = tf.image.convert_image_dtype(image, dtype = tf.float32)
    return image, label


AUTOTUNE = tf.data.experimental.AUTOTUNE
ds_train = (
    ds_train
        .map(convert_to_float)
        .cache()
        .prefetch(buffer_size = AUTOTUNE),
)
ds_valid = (
    ds_valid
        .map(convert_to_float)
        .cache()
        .prefetch(buffer_size = AUTOTUNE)
)

Found 5117 files belonging to 2 classes.
Found 5051 files belonging to 2 classes.


In [2]:
from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    layers.Conv2D(filters = 32, kernel_size = 5, activation = "relu", padding = 'same',
                  input_shape = [128, 128, 3]),
    layers.MaxPool2D(),
    layers.Conv2D(filters = 64, kernel_size = 3, activation = "relu", padding = 'same'),
    layers.MaxPool2D(),
    layers.Conv2D(filters = 128, kernel_size = 3, activation = "relu", padding = 'same'),
    layers.MaxPool2D(),
    layers.Flatten(),
    layers.Dense(units = 6, activation = "relu"),
    layers.Dense(units = 1, activation = "sigmoid"),
])
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 128, 128, 32)      2432      
                                                                 
 max_pooling2d (MaxPooling2D  (None, 64, 64, 32)       0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 64, 64, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 32, 32, 64)       0         
 2D)                                                             
                                                                 
 conv2d_2 (Conv2D)           (None, 32, 32, 128)       73856     
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 16, 16, 128)      0

In [3]:
model.compile(
    optimizer = tf.keras.optimizers.Adam(epsilon = 0.01),
    loss = 'binary_crossentropy',
    metrics = ['binary_accuracy']
)

history = model.fit(
    ds_train,
    validation_data = ds_valid,
    epochs = 10,
)

ValueError: Failed to find data adapter that can handle input: (<class 'tuple'> containing values of types {"<class 'tensorflow.python.data.ops.dataset_ops.PrefetchDataset'>"}), <class 'NoneType'>

In [None]:
from IPython.core.display_functions import display
import pandas as pd

history_frame = pd.DataFrame(history.history)
history_frame.loc[:, ['loss', 'val_loss']].plot()
history_frame.loc[:, ['binary_accuracy', 'val_binary_accuracy']].plot();

display(history_frame)