In [None]:
import tensorflow as tf

from keras import models
from keras import layers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
import pickle
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, confusion_matrix


from sklearn.model_selection import train_test_split


In [None]:
class CustomGenerator:
    """Wrap an iterator of batches and re-expose it as a Python generator.

    Attributes:
        generator: The wrapped iterator; it must support ``next()``.
    """

    def __init__(self, generator):
        """Store the iterator that ``generate_batches`` draws from."""
        self.generator = generator

    def generate_batches(self):
        """Yield items from the wrapped iterator, one per ``next()`` call.

        Keeps yielding for as long as the wrapped iterator produces items
        and finishes cleanly once it is exhausted.

        Yields:
            Whatever the wrapped iterator returns from ``next()``.
        """
        while True:
            try:
                data = next(self.generator)
            except StopIteration:
                # Since PEP 479, a StopIteration escaping a generator body is
                # re-raised as RuntimeError; end the generator explicitly.
                return
            yield data

In [18]:
# Root directory of the dataset; flow_from_directory derives the classes
# from its sub-directories.
# NOTE(review): machine-specific absolute path - adjust when running elsewhere.
dataset_dir = 'D:\\workspace\\CNN-cassifier\\Data\\3IDL_DataSet'

# Fraction of the images held out for validation. Both generators below must
# use the same value so that they partition the files identically.
validation_split = 0.2

# Size every image is resized to; matches the model's input_shape (256, 256, 3)
# and is also the flow_from_directory default.
image_size = (256, 256)

# Training data generator: rescaling plus random augmentation
datagen = ImageDataGenerator(
    rescale=1./255,        # Normalize pixel values to be between 0 and 1
    shear_range=0.2,       # Random shear transformations
    zoom_range=0.2,        # Random zoom transformations
    horizontal_flip=False, # Horizontal flipping is disabled
    validation_split=validation_split,
)

# Evaluation data generator: rescaling only. Random augmentation on held-out
# images would make the reported metrics noisy and non-reproducible.
eval_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=validation_split,
)

batch_size = 64

# Flow the (augmented, shuffled) training images in batches
train_generator = datagen.flow_from_directory(
    dataset_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=True,
    subset='training'
)

# Flow the held-out images in a fixed order so that predictions line up
# with test_generator.classes
test_generator = eval_datagen.flow_from_directory(
    dataset_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False,
    subset='validation'
)

# Save the relevant information to recreate the generator later
test_generator_info = {
    'directory': dataset_dir,
    'batch_size': batch_size,
    'class_mode': 'categorical',
    'shuffle': False,
    'subset': 'validation'
}

Found 4183 images belonging to 3 classes.
Found 1045 images belonging to 3 classes.


In [20]:
# Serialize the generator settings to disk using the newest pickle protocol
with open('test_generator_info.pkl', 'wb') as file:
    file.write(pickle.dumps(test_generator_info, protocol=pickle.HIGHEST_PROTOCOL))

In [None]:
# Summarize the training split. len(train_generator) is the number of
# batches, so the sample and class counts come from dedicated attributes.
print("Number of training samples:", train_generator.samples)
print("Number of batches per epoch:", len(train_generator))
print("Number of classes:", train_generator.num_classes)
print("Class labels:", train_generator.class_indices)

In [None]:
# Filtering phase: two Conv2D + MaxPooling2D stages on 256x256 RGB input
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_uniform', input_shape=(256, 256, 3)),
    MaxPooling2D((2, 2)),
    Conv2D(64, (5, 5), activation='relu', kernel_initializer='he_uniform'),
    MaxPooling2D((2, 2)),
])

In [None]:
# Flatten the convolutional feature maps before the dense layers
flatten_layer = Flatten()
model.add(flatten_layer)

In [None]:
# Fully connected head: two ReLU hidden layers, then a 3-unit softmax output
for hidden_units in (256, 128):
    model.add(Dense(hidden_units, activation='relu', kernel_initializer='he_uniform'))
model.add(Dense(3, activation='softmax'))

In [None]:
# Training configuration: Adam optimizer, categorical cross-entropy loss,
# and categorical accuracy as the tracked metric
adam_optimizer = tf.keras.optimizers.Adam(0.0001)
crossentropy_loss = tf.keras.losses.CategoricalCrossentropy()
model.compile(
    optimizer=adam_optimizer,
    loss=crossentropy_loss,
    metrics=['categorical_accuracy'],
)


In [None]:
# Train the model and keep the History object so the per-epoch metrics can be
# inspected afterwards.
history = model.fit(
    train_generator,
    # len() counts the final partial batch too; `samples // batch_size`
    # would drop it from every epoch.
    steps_per_epoch=len(train_generator),
    # Report loss/accuracy on the held-out split after each epoch
    # (monitoring only - it does not influence the weight updates).
    validation_data=test_generator,
    epochs=15,
)

In [None]:
# Persist the trained model to disk so the backend can load it later
model_path = 'model.h5'
model.save(model_path)

### ----------- Insights -------------

In [None]:
# Evaluate the model on the validation set
validation_loss, validation_accuracy = model.evaluate(test_generator)
print("Validation Accuracy:", validation_accuracy)

# Rewind the iterator so predictions start at the first sample and therefore
# line up with test_generator.classes (the generator is created with
# shuffle=False, so its order is fixed).
test_generator.reset()

# Get predictions on the validation set
y_pred = model.predict(test_generator)
y_true = test_generator.classes

# Convert the per-class probabilities to predicted class indices once
y_pred_labels = y_pred.argmax(axis=1)

# class_indices maps class name -> integer label; order the names by label so
# the report rows and matrix axes follow the label order.
class_indices = test_generator.class_indices
class_names = sorted(class_indices, key=class_indices.get)
label_ids = [class_indices[name] for name in class_names]

# Generate a confusion matrix (passing labels keeps its shape fixed even if a
# class is never predicted)
conf_matrix = confusion_matrix(y_true, y_pred_labels, labels=label_ids)

# Print classification report
print("Classification Report: \n" , classification_report(
    y_true,
    y_pred_labels,
    labels=label_ids,
    target_names=class_names,
))

# Print confusion matrix
print("Confusion Matrix: \n ", conf_matrix)


In [None]:
# Evaluate the model on the test data.
# len(test_generator) covers every batch, including the final partial one;
# `samples // batch_size` would leave those images out of the metrics.
test_loss, test_accuracy = model.evaluate(
    test_generator,
    steps=len(test_generator)
)

print("Test Loss:", test_loss)
print("Test Accuracy:", test_accuracy)