In [None]:
import os
import shutil

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
def preprocessing_func(data_dir='dataset', target_size=(800, 800), batch_size=32,
                       color_mode="grayscale"):
    """Build Keras directory iterators for the train / valid / test splits.

    Each split is read with ``ImageDataGenerator.flow_from_directory`` from
    ``<data_dir>/train``, ``<data_dir>/valid`` and ``<data_dir>/test``. Labels
    are one-hot encoded (``class_mode='categorical'``). Pixel values are
    rescaled by 1/255.

    Only the training split is augmented (shear, zoom, horizontal flip) and
    shuffled. Validation and test are read in a fixed order, so that rows of
    ``model.predict(test_set)`` line up with ``test_set.classes``.

    Args:
        data_dir: Root directory containing the ``train``, ``valid`` and
            ``test`` sub-directories.
        target_size: (height, width) every image is resized to. Must agree
            with the model's input shape.
        batch_size: Number of images per batch.
        color_mode: Passed to ``flow_from_directory``; "grayscale" yields a
            single channel.

    Returns:
        Tuple ``(training_set, validation_set, test_set)`` of directory iterators.
    """
    train_transform = ImageDataGenerator(rescale=1./255,
                                         shear_range=0.2,
                                         zoom_range=0.2,
                                         horizontal_flip=True)
    # Evaluation splits are only rescaled, never augmented.
    eval_transform = ImageDataGenerator(rescale=1./255)

    def _flow(transform, split, shuffle):
        # One place for the shared loader settings so the splits can't drift apart.
        return transform.flow_from_directory(os.path.join(data_dir, split),
                                             target_size=tuple(target_size),
                                             color_mode=color_mode,
                                             batch_size=batch_size,
                                             class_mode='categorical',
                                             shuffle=shuffle)

    training_set = _flow(train_transform, 'train', shuffle=True)
    validation_set = _flow(eval_transform, 'valid', shuffle=False)
    test_set = _flow(eval_transform, 'test', shuffle=False)

    return training_set, validation_set, test_set

In [3]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

def build_cnn(input_shape=(800, 800, 1), num_classes=6, dropout_rate=0.0):
    """Build (but do not compile) a small three-stage CNN classifier.

    Architecture:
    - Three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with 16, 32 and 64 filters.
    - Then Flatten -> Dense(64, ReLU) -> optional Dropout -> Dense(num_classes, softmax).

    Args:
        input_shape: (height, width, channels) of one image. Must match the
            ``target_size`` / ``color_mode`` used to load the data
            (800x800 grayscale -> (800, 800, 1)).
        num_classes: Width of the softmax output. Must equal the number of
            class sub-directories found by the training iterator.
        dropout_rate: If > 0, a Dropout layer with this rate is inserted
            before the output layer. The default 0.0 reproduces the original
            model exactly.

    Returns:
        An uncompiled ``Sequential`` model.

    Raises:
        ValueError: if ``num_classes`` < 1 or ``dropout_rate`` is not in [0, 1).

    Note:
        With 800x800 inputs the spatial size goes 800 -> 98 after the three
        stages. Flatten therefore yields 98*98*64 = 614,656 values, and the
        Dense(64) layer alone holds ~39M weights. Consider a smaller input
        size or global pooling if memory or overfitting is a problem.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")
    if not 0.0 <= dropout_rate < 1.0:
        raise ValueError(f"dropout_rate must be in [0, 1), got {dropout_rate}")

    cnn = Sequential()

    # Feature extractor: each Conv2D (default 'valid' padding) trims 2 pixels
    # per side pair, and each MaxPooling2D(2x2) halves the spatial size.
    cnn.add(Conv2D(16, (3, 3), activation='relu', input_shape=tuple(input_shape)))
    cnn.add(MaxPooling2D((2, 2)))
    for filters in (32, 64):
        cnn.add(Conv2D(filters, (3, 3), activation='relu'))
        cnn.add(MaxPooling2D((2, 2)))

    # Classifier head.
    cnn.add(Flatten())
    cnn.add(Dense(64, activation='relu'))
    if dropout_rate > 0:
        cnn.add(Dropout(dropout_rate))
    cnn.add(Dense(units=num_classes, activation='softmax'))

    return cnn

In [4]:
def cnn_train_model(cnn, training_set, validation_set, epochs=5,
                    model_path='models/exit_model.keras'):
    """Compile ``cnn``, train it on ``training_set`` and save it to disk.

    The model is compiled with Adam, categorical cross-entropy and an accuracy
    metric, then fitted with ``validation_set`` as validation data.

    Args:
        cnn: Model to train (e.g. the result of ``build_cnn()``).
        training_set: Training data passed as ``x`` to ``cnn.fit``.
        validation_set: Passed as ``validation_data`` to ``cnn.fit``.
        epochs: Number of training epochs.
        model_path: File the trained model is saved to. Missing parent
            directories are created.

    Returns:
        Whatever ``cnn.fit`` returns (a ``History`` for Keras models), so the
        caller can inspect per-epoch loss / accuracy.
    """
    cnn.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    history = cnn.fit(x=training_set, validation_data=validation_set, epochs=epochs)

    # makedirs(exist_ok=True) replaces the isdir/mkdir branches. It also
    # creates nested parents and does not fail if the folder already exists.
    model_dir = os.path.dirname(model_path)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)
    cnn.save(model_path)
    print(f"Model saved to '{model_path}'.")

    return history

In [5]:
def _remove_hidden_entries(path, depth):
    """Delete hidden ('.'-prefixed) entries below ``path``, scanning ``depth`` levels.

    ``depth=1`` inspects only the direct children of ``path``. Each extra level
    also inspects the children of every non-hidden sub-directory. Regular
    (non-hidden) files are left alone and never descended into.
    """
    for name in sorted(os.listdir(path)):
        entry = os.path.join(path, name)
        if name.startswith("."):
            # Hidden directories (e.g. `.ipynb_checkpoints`) cannot be deleted
            # with os.remove, so remove them as whole trees. Symlinks are
            # unlinked rather than followed.
            if os.path.isdir(entry) and not os.path.islink(entry):
                shutil.rmtree(entry)
            else:
                os.remove(entry)
            # The entry is gone; never try to list its contents.
            continue
        if depth > 1 and os.path.isdir(entry):
            _remove_hidden_entries(entry, depth - 1)


def check_for_hid_files(base_dir="."):   # for removing "." files
    """Remove hidden files/folders from the ``dataset`` and ``dataset_predict`` trees.

    Hidden entries (names starting with '.') are deleted:
    - in ``dataset``: up to three levels deep;
    - in ``dataset_predict``: up to two levels deep.

    A missing dataset directory is reported and skipped.

    Args:
        base_dir: Directory containing ``dataset`` / ``dataset_predict``.
            Defaults to the current working directory, as before.
    """
    print("======== Checking on hidden '.' files started. ========")

    # Number of directory levels scanned below each dataset root.
    datasets_for_check = {"dataset": 3, "dataset_predict": 2}

    for dataset, depth in datasets_for_check.items():
        root = os.path.join(base_dir, dataset)
        if not os.path.isdir(root):
            print(f"Directory '{dataset}' doesn't exist.")
            continue

        _remove_hidden_entries(root, depth)
        print(f"======== Dataset '{dataset}' has been checked for hidden files. ========")

# EXIT CODE

## CREATING AND TRAINING THE MODEL

In [None]:
# Remove stray hidden files first. flow_from_directory treats every
# sub-directory (hidden ones included, e.g. .ipynb_checkpoints) as a class.
check_for_hid_files()
training_set, validation_set, test_set = preprocessing_func()
cnn = build_cnn()

# Fail fast, before a long training run, if the softmax width does not match
# the number of class folders actually found on disk.
n_outputs = cnn.layers[-1].units
if n_outputs != training_set.num_classes:
    raise ValueError(
        f"Model predicts {n_outputs} classes but the training split contains "
        f"{training_set.num_classes} class folders: {list(training_set.class_indices)}"
    )

cnn_train_model(cnn, training_set, validation_set)

In [None]:
# Evaluate the trained model on the held-out test split: [loss, accuracy].
loss_value, acc_value = cnn.evaluate(test_set)

for metric_line in (f"Loss value: {loss_value}", f"Accuracy value: {acc_value}"):
    print(metric_line)