In [1]:
import tensorflow as tf
%matplotlib inline
import matplotlib.pylab as plt
import numpy as np

## numbers for traininng

In [3]:
estimated_image_count = 600 #TODO count it
nr_samples = estimated_image_count
nr_epochs = 1 #TODO way more or early stop (see at training above)
batch_size = 32
steps_per_epoch = nr_samples / batch_size

# load data 

In [4]:
train_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=25,
    width_shift_range=0.1,
    height_shift_range=0.1,
    brightness_range=(0.9, 1.1),
    horizontal_flip=True,
    vertical_flip=True
)

train_dir = "./images/cleaned"
IMG_HEIGHT, IMG_WIDTH = 224, 224

train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH))



print(train_data_gen.class_indices)

Found 543 images belonging to 4 classes.
{'Bier': 0, 'Cocktail': 1, 'Wasser': 2, 'Wein': 3}


# display images

In [5]:
sample_training_images, labels = next(train_data_gen)
class_label = {v:k for k,v in train_data_gen.class_indices.items()}
fig, axes = plt.subplots(8, 4, figsize=(20,20))
axes = axes.flatten()
for img, label, ax in zip(sample_training_images, labels, axes):
    ax.set_title(class_label[np.argmax(label)])
    ax.imshow(img)
    ax.axis('off')
plt.tight_layout()



OSError: cannot identify image file './images/cleaned\\Wein\\12.0641295_PE700405_S5.JPG'

# model

In [6]:
base_model = tf.keras.applications.resnet_v2.ResNet50V2(
        include_top=False,
        pooling='max',
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        weights='imagenet')

model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.Dense(512, 'relu'),
    tf.keras.layers.Dense(256, 'relu'),
    tf.keras.layers.Dense(train_data_gen.num_classes, 'softmax')
])
    
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
resnet50v2 (Model)           (None, 2048)              23564800  
_________________________________________________________________
dense (Dense)                (None, 512)               1049088   
_________________________________________________________________
dense_1 (Dense)              (None, 256)               131328    
_________________________________________________________________
dense_2 (Dense)              (None, 4)                 1028      
Total params: 24,746,244
Trainable params: 24,700,804
Non-trainable params: 45,440
_________________________________________________________________


### utility function that will skip batches if there is a broken image in it

In [7]:
def skip_broken_images(gen):
    while True:
        try:
            data, labels = next(gen)
            yield data, labels
        except GeneratorExit:
            break
        except:
            pass

# train fully connected layers

In [None]:
for layer in base_model.layers:
    layer.trainable = False
    
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])

history = model.fit_generator(
    skip_broken_images(train_data_gen),
    steps_per_epoch=steps_per_epoch, #use way more steps here: number of samples / batch size
    epochs=nr_epochs,# or EarlyStopping callback
    callbacks=[]
)

plt.plot(history.history['loss'])



# train complete network

In [None]:
for layer in base_model.layers:
    layer.trainable = True
        
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])    

history = model.fit_generator(
    skip_broken_images(train_data_gen),
    steps_per_epoch=steps_per_epoch, #use way more steps here: number of samples / batch size
    epochs=nr_epochs, # or EarlyStopping callback
    callbacks=[]
)

plt.plot(history.history['loss'])

# save model

In [None]:
model.save('drink-detection.h5')