In [None]:
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.utils import plot_model
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
import pandas as pd
import os
import shutil

In [None]:
# The labels file has no header row; each line is "<file name> <label>" separated by a single space.
dataframe = (
    pd.read_csv('../data/labels-map-proj-v3.txt', sep=' ', header=None)
    .set_axis(['file_names', 'labels'], axis='columns')
)
display(dataframe)

In [None]:
file_path = './sorted_images/'
# Create ./sorted_images/<label>/ for every distinct label, so that flow_from_directory
# can later infer the classes from the folder names.
# exist_ok=True replaces the racy "exists() then makedirs()" check and makes the cell
# safe to re-run.
try:
    os.makedirs(file_path, exist_ok=True)
    for label in dataframe['labels'].unique():
        os.makedirs(os.path.join(file_path, str(label)), exist_ok=True)
except OSError as e:  # IOError is just an alias of OSError in Python 3
    print(f"Error while creating directories: {e}")

In [None]:
# Copy each labelled image into the sub-directory named after its label.
# Files already present from an earlier run are skipped, so re-running the cell is cheap.
for image_name, label in zip(dataframe['file_names'], dataframe['labels']):
    target_location = file_path + str(label) + '/' + image_name
    if os.path.exists(target_location):
        continue
    shutil.copy2('../data/map-proj-v3/' + image_name, target_location)

In [None]:
# validation_split=0.2 holds out 20% of the images for validation. Both generators read
# the same directory tree and select their part via `subset`.
train_datagen = ImageDataGenerator(validation_split=0.2)

SEED = 42  # fixes the shuffling order so training runs are reproducible
batch_size = 32
img_width, img_height = 256, 256

# Arguments shared by the training and validation flows.
flow_kwargs = dict(
    directory=file_path,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',  # one-hot labels for single-label multi-class classification
    seed=SEED,
)

train_generator = train_datagen.flow_from_directory(subset='training', **flow_kwargs)
validation_generator = train_datagen.flow_from_directory(subset='validation', **flow_kwargs)

In [None]:
# Names under which Keras may log the accuracy metric, depending on how the model was
# compiled (e.g. metrics=["accuracy"] vs. BinaryAccuracy(name="acc")).
_ACCURACY_KEYS = ("accuracy", "acc", "categorical_accuracy", "binary_accuracy")


def _first_present_key(metrics, candidates):
    """Return the first name in ``candidates`` that is a key of ``metrics``.

    Raises:
        KeyError: if none of ``candidates`` is present. The message lists the available
            keys, so a metric-name mismatch with ``model.compile`` is easy to spot.
    """
    for name in candidates:
        if name in metrics:
            return name
    raise KeyError(
        f"none of {list(candidates)} found in history; available keys: {sorted(metrics)}"
    )


def plot_history(history):
    """Plot per-epoch training/validation accuracy and loss from a Keras ``History``.

    Accuracy and loss are drawn on two separate axes because they live on different
    scales: accuracy lies in [0, 1], while loss is not bounded to that range. Sharing a
    single y-axis would flatten the accuracy curves.

    Args:
        history: the object returned by ``model.fit``. ``history.history`` maps metric
            names to per-epoch lists. Validation curves are drawn only when the matching
            ``val_*`` entries exist, i.e. when fit was given validation data.

    Raises:
        KeyError: if no accuracy metric (any of ``_ACCURACY_KEYS``) or no ``loss`` was logged.
    """
    metrics = history.history
    acc_key = _first_present_key(metrics, _ACCURACY_KEYS)
    loss_key = _first_present_key(metrics, ("loss",))

    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))
    for ax, key, title in ((ax_acc, acc_key, "Accuracy"), (ax_loss, loss_key, "Loss")):
        ax.plot(metrics[key], label=f"Train {title}")
        val_key = f"val_{key}"
        if val_key in metrics:
            ax.plot(metrics[val_key], label=f"Validation {title}")
        ax.set(title=title, xlabel="Epoch", ylabel=title)
        ax.legend()
    fig.suptitle("Model Metrics")
    fig.tight_layout()
    plt.show()


In [None]:
# model from https://keras.io/examples/vision/image_classification_from_scratch/

def _residual_downsample_stage(features, width):
    """Apply one downsampling stage and return its output tensor.

    The main branch is two (ReLU -> 3x3 SeparableConv2D -> BatchNorm) units followed by a
    3x3 stride-2 max-pool. The stage input is projected by a 1x1 stride-2 Conv2D to
    ``width`` channels, so its shape matches the main branch, and is added back as a
    residual.
    """
    shortcut = features
    for _ in range(2):
        features = layers.Activation("relu")(features)
        features = layers.SeparableConv2D(width, 3, padding="same")(features)
        features = layers.BatchNormalization()(features)
    features = layers.MaxPooling2D(3, strides=2, padding="same")(features)
    projected_shortcut = layers.Conv2D(width, 1, strides=2, padding="same")(shortcut)
    return layers.add([features, projected_shortcut])


def make_model(input_shape, num_classes):
    """Build a residual separable-conv CNN classifier that outputs raw logits.

    Adapted from the Keras "image classification from scratch" example.

    Args:
        input_shape: shape of one input image, e.g. ``(height, width, channels)``.
        num_classes: number of target classes. For exactly 2 classes the head emits a
            single logit; otherwise it emits one logit per class.

    Returns:
        An uncompiled ``keras.Model`` mapping images to logits (no final activation).
    """
    image_input = keras.Input(shape=input_shape)

    # Stem: scale pixel values by 1/255, then a strided 3x3 conv + BN + ReLU.
    features = layers.Rescaling(1.0 / 255)(image_input)
    features = layers.Conv2D(128, 3, strides=2, padding="same")(features)
    features = layers.BatchNormalization()(features)
    features = layers.Activation("relu")(features)

    # Three residual stages, each halving the spatial size.
    for width in (256, 512, 728):
        features = _residual_downsample_stage(features, width)

    features = layers.SeparableConv2D(1024, 3, padding="same")(features)
    features = layers.BatchNormalization()(features)
    features = layers.Activation("relu")(features)

    # Head: pool to a vector, regularize, and project to logits (activation=None).
    features = layers.GlobalAveragePooling2D()(features)
    features = layers.Dropout(0.25)(features)
    output_units = 1 if num_classes == 2 else num_classes
    logits = layers.Dense(output_units, activation=None)(features)
    return keras.Model(image_input, logits)


In [None]:
# NOTE(review): make_model emits a single logit when num_classes == 2. That does not match
# the two-column one-hot labels from class_mode='categorical', so this cell assumes
# more than two classes.
model = make_model(input_shape=(img_width, img_height, 3), num_classes=train_generator.num_classes)
epochs = 25
callbacks = [
    # Saves the full model after every epoch: save_at_1.keras, save_at_2.keras, ...
    keras.callbacks.ModelCheckpoint("save_at_{epoch}.keras"),
]
# The generators use class_mode='categorical' (one-hot, one label per image). The matching
# loss is therefore softmax cross-entropy over the logits, not per-class sigmoid (binary)
# cross-entropy.
# The metric is named "accuracy" so that History logs "accuracy"/"val_accuracy".
model.compile(
    optimizer=keras.optimizers.Adam(3e-4),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.CategoricalAccuracy(name="accuracy")],
)

# summary() prints the table itself and returns None. Wrapping it in display() only
# rendered an extra "None".
model.summary()
plot_model(model, to_file='model3_plot.png', show_shapes=True, show_layer_names=True, dpi=60)

In [None]:
# len(generator) is ceil(samples / batch_size): one pass over every image per epoch,
# including the last partial batch.
# The previous validation_steps=800 // batch_size was unrelated to the actual size of
# the validation split.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=epochs,  # use the constant defined with the model instead of a hard-coded 10
    callbacks=callbacks,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)

In [None]:
# Learning curves recorded by model.fit in the previous cell.
plot_history(history=history)
# TODO: convert model to coral compatible model

In [None]:
# Convert the trained Keras model to TensorFlow Lite.
# Optimize.DEFAULT without a representative dataset applies post-training dynamic-range
# quantization: weights only, activations stay float.
# NOTE(review): the Coral Edge TPU compiler needs a fully integer-quantized model
# (representative_dataset plus int8 ops/inputs), so this output is not yet Coral-compatible.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Persist the flatbuffer; otherwise the converted model only exists in kernel memory.
tflite_path = 'model3.tflite'
with open(tflite_path, 'wb') as f:
    f.write(tflite_model)