# In [None]:
import os

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Lambda, Reshape
from tensorflow.keras.models import Model
# Relative dataset locations (presumably a mounted Google Drive -- confirm):
# the first holds the train/validation images, the second the test images.
_dataset_directory, _test_directory = (
    "drive/MyDrive/DL_PROJECT_DATASET_V2",
    "drive/MyDrive/DL_PROJECT_TEST_DATASET_V2",
)

import cv2
import numpy as np
from sklearn.cluster import KMeans

def preprocess_image(img, n_colors=5):
    """Derive the three model inputs (image, edge map, palette) from one image.

    Args:
        img: RGB image as a float array with values expected in [0, 1]
            (the data generators rescale pixels by 1/255 before this runs).
        n_colors: number of k-means clusters used as the image's palette.

    Returns:
        (resized_image, resized_edges, prominent_colors):
        - resized_image: uint8 BGR image resized to 224x224.
        - resized_edges: uint8 Canny edge map resized to 224x224.
        - prominent_colors: int array of shape (n_colors, 3) holding BGR
          cluster centres, ordered from the most to the least populated
          cluster.
    """
    # Scale to 0-255 and clip before the uint8 cast so values marginally
    # outside [0, 1] saturate instead of wrapping around; then RGB -> BGR,
    # the channel order OpenCV works in.
    image = cv2.cvtColor(
        np.clip(img * 255, 0, 255).astype(np.uint8), cv2.COLOR_RGB2BGR
    )

    # Convert to grayscale for edge detection
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Edge detection (hysteresis thresholds 100 / 200)
    edges = cv2.Canny(gray_image, 100, 200)
    resized_edges = cv2.resize(edges, (224, 224))

    # Top n colors: cluster every pixel in BGR space.
    pixels = image.reshape(-1, 3)
    kmeans = KMeans(n_clusters=n_colors, n_init=10, random_state=0).fit(pixels)
    # KMeans numbers its clusters arbitrarily, so rank the centres by how many
    # pixels each one owns. This makes row 0 the dominant colour and gives the
    # colour input a consistent meaning from image to image.
    cluster_sizes = np.bincount(kmeans.labels_, minlength=n_colors)
    order = np.argsort(-cluster_sizes, kind="stable")
    prominent_colors = kmeans.cluster_centers_[order].astype(int)

    resized_image = cv2.resize(image, (224, 224))
    return resized_image, resized_edges, prominent_colors

def custom_generator(image_data_generator, steps_per_epoch, n_colors=5):
    """Turn a Keras image iterator into batches for the three-input model.

    Every image of each incoming batch is run through ``preprocess_image`` to
    build the original-image, edge-map and colour-palette inputs.

    Args:
        image_data_generator: iterable yielding ``(batch_x, batch_y)`` pairs,
            e.g. an iterator from ``flow_from_directory``. ``batch_x`` is
            assumed to hold RGB floats in [0, 1] -- confirm with the caller.
        steps_per_epoch: number of batches after which the inner loop is
            restarted.
        n_colors: palette size. It is forwarded to ``preprocess_image`` and
            also sizes the colour buffer, so the two can never disagree.

    Yields:
        ``([originals, edges, colors], batch_y)`` with shapes
        ``(B, 224, 224, 3)``, ``(B, 224, 224, 1)`` and ``(B, n_colors, 3)``,
        each scaled to [0, 1].
    """
    batch_count = 0
    while True:
        for batch_x, batch_y in image_data_generator:
            n_images = batch_x.shape[0]
            # float32 matches the default dtype of Keras Input layers and
            # halves the memory of numpy's default float64 buffers.
            batch_x_original = np.zeros((n_images, 224, 224, 3), dtype=np.float32)
            batch_x_edges = np.zeros((n_images, 224, 224, 1), dtype=np.float32)
            batch_x_colors = np.zeros((n_images, n_colors, 3), dtype=np.float32)

            for i, img in enumerate(batch_x):
                original, edges, colors = preprocess_image(img, n_colors=n_colors)
                # preprocess_image returns BGR; convert back to RGB for the CNN.
                batch_x_original[i] = cv2.cvtColor(original, cv2.COLOR_BGR2RGB) / 255.0
                # Add the single channel axis expected by the edge input.
                batch_x_edges[i] = np.expand_dims(edges, axis=-1) / 255.0
                batch_x_colors[i] = colors / 255.0

            yield [batch_x_original, batch_x_edges, batch_x_colors], batch_y

            # Leave the inner loop every steps_per_epoch batches. The outer
            # `while True` then re-enters it, which also restarts iteration
            # if the wrapped iterable is ever exhausted.
            batch_count += 1
            if batch_count >= steps_per_epoch:
                batch_count = 0
                break

# Define CNN structure
def create_cnn(input_shape):
    """Build a small convolutional feature extractor.

    Three Conv2D(ReLU) + 2x2 MaxPooling stages (32, 64 and 8 filters with
    3x3, 2x2 and 2x2 kernels) followed by Flatten.

    Args:
        input_shape: shape of one input sample, e.g. ``(224, 224, 3)``.

    Returns:
        A Keras ``Model`` mapping that input to a flat feature vector.
    """
    inputs = Input(shape=input_shape)
    features = inputs
    for filters, kernel in ((32, (3, 3)), (64, (2, 2)), (8, (2, 2))):
        features = Conv2D(filters, kernel, activation='relu')(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
    return Model(inputs=inputs, outputs=Flatten()(features))
def create_cnn_for_colors(input_shape):
    """Build the feature extractor for the colour-palette input.

    A single 1x1 Conv2D (32 filters, ReLU) is applied, so there is no spatial
    context: with a ``(rows, cols, 1)`` input, every cell value is mapped
    independently to 32 features. The result is then flattened.

    Args:
        input_shape: shape of one palette sample, e.g. ``(5, 3, 1)``.

    Returns:
        A Keras ``Model`` producing a flat feature vector.
    """
    palette = Input(shape=input_shape)
    per_cell = Conv2D(32, (1, 1), activation='relu')(palette)
    return Model(inputs=palette, outputs=Flatten()(per_cell))

# Palette size per image and number of output classes.
# NOTE(review): NUM_CLASSES must equal the number of class sub-directories
# reported by flow_from_directory (class_indices) -- confirm when data changes.
N_COLORS = 5
NUM_CLASSES = 5

# Create inputs for each image type
input_original = Input(shape=(224, 224, 3))
input_edges = Input(shape=(224, 224, 1))
input_colors = Input(shape=(N_COLORS, 3))

# One feature extractor per input representation
cnn_original = create_cnn((224, 224, 3))
cnn_edges = create_cnn((224, 224, 1))
cnn_colors = create_cnn_for_colors((N_COLORS, 3, 1))

# Get outputs from CNNs
output_original = cnn_original(input_original)
output_edges = cnn_edges(input_edges)
# Add a trailing channel axis so the (N_COLORS, 3) palette can go through
# Conv2D. Reshape does the same as a Lambda around tf.expand_dims(axis=-1),
# but Lambda layers are saved as Python bytecode, which makes the saved .h5
# model non-portable.
output_colors = cnn_colors(Reshape((N_COLORS, 3, 1))(input_colors))

# Concatenate outputs
concatenated = Concatenate()([output_original, output_edges, output_colors])

# Dense layers for classification
dense = Dense(64, activation='relu')(concatenated)
dense = Dense(16, activation='relu')(dense)
output_layer = Dense(NUM_CLASSES, activation='softmax')(dense)

# Complete model
model = Model(inputs=[input_original, input_edges, input_colors], outputs=output_layer)

# Compile the model (one-hot targets -> categorical cross-entropy)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training images: rescale to [0, 1] and apply light random augmentation
# (rotation, shifts, flips, zoom). 20% of the files are reserved for the
# 'validation' subset.
train_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.1
)
# Validation images must NOT be augmented. Random transforms make the
# validation metrics used for checkpointing and early stopping noisy and
# unrepresentative. The same validation_split is kept so that
# subset='validation' selects exactly the files held out from training.
validation_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
# Test images: rescaling only.
test_datagen = ImageDataGenerator(rescale=1./255)


# Per-split sample counts and batch size.
# NOTE(review): these counts are hand-entered; verify them against the
# "Found N images" lines printed by flow_from_directory.
total_train_samples = 343
total_val_samples = 83
total_test_samples = 54
batch_size = 32

# Batches needed to cover each split once (ceiling division, so a final
# partial batch is included).
train_steps, val_steps, test_steps = (
    -(-count // batch_size)
    for count in (total_train_samples, total_val_samples, total_test_samples)
)

print('train')
train_data_gen = train_datagen.flow_from_directory(
    _dataset_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training'
)
class_names = list(train_data_gen.class_indices.keys())
print(class_names)
# Get the class indices mapping (class name -> integer label)
class_indices = train_data_gen.class_indices

# Count every file under each class directory. The count covers the whole
# directory (training and validation subsets together) and includes any
# non-image files. Only class directories are walked, so stray files in the
# dataset root no longer raise KeyError. Files in nested folders count toward
# their enclosing class.
class_count = {}
for class_name in class_indices:
    class_dir = os.path.join(_dataset_directory, class_name)
    class_count[class_name] = sum(len(files) for _, _, files in os.walk(class_dir))

# Print the number of elements in each class
for class_name, count in class_count.items():
    print(f"Class '{class_name}': {count} elements")


print('validation')
val_data_gen = validation_datagen.flow_from_directory(
    _dataset_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'
)
class_names = list(val_data_gen.class_indices.keys())
print(class_names)

print('test')
test_data_gen = test_datagen.flow_from_directory(
    _test_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical'
)
class_names = list(test_data_gen.class_indices.keys())
print(class_names)
# The test directory is indexed independently of the training one. If its
# name -> label mapping differs, test labels would silently disagree with
# what the model learned.
if test_data_gen.class_indices != class_indices:
    print('WARNING: test class indices differ from training:', test_data_gen.class_indices)

# Derive the step counts from the iterators themselves: len() of a Keras
# directory iterator is the number of batches that cover its subset once, so
# these stay correct when the dataset changes (unlike hand-entered counts).
train_steps = len(train_data_gen)
val_steps = len(val_data_gen)

train_generator = custom_generator(train_data_gen, train_steps)
validation_generator = custom_generator(val_data_gen, val_steps)

# Import callbacks from tensorflow.keras to match the rest of the file.
# Mixing the standalone `keras` package with tf.keras can break on some
# version combinations.
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
version = "5.2"

# Save the model whenever validation accuracy reaches a new maximum
checkpoint = ModelCheckpoint(
    'art_style_model_best_v'+version+'.h5',
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    mode='max'
)

# Stop training when validation accuracy has not improved for 5 epochs and
# restore the best weights seen
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    verbose=1,
    mode='max',
    restore_best_weights=True
)

# Cut the learning rate 5x when validation accuracy plateaus. The patience is
# shorter than EarlyStopping's. With equal patience both would fire on the
# same epoch, so training would stop before the reduced LR was ever used.
reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.2,
    patience=3,
    verbose=1,
    mode='max',
    min_lr=0.00001
)

history = model.fit(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=10,
    validation_data=validation_generator,
    validation_steps=val_steps,
    callbacks=[checkpoint, early_stopping, reduce_lr]
)

# Save the trained model
model.save('art_style_model_v'+version+'.h5')





{'loss': [1.6337611675262451, 1.5539116859436035, 1.4851839542388916, 1.3919838666915894, 1.4624838829040527, 1.3689720630645752, 1.3602548837661743, 1.2213523387908936, 1.1766377687454224, 1.082789659500122], 'accuracy': [0.24147726595401764, 0.3323170840740204, 0.3323170840740204, 0.3841463327407837, 0.2926829159259796, 0.37804877758026123, 0.3628048896789551, 0.5182926654815674, 0.46341463923454285, 0.5304877758026123], 'val_loss': [1.5759292840957642, 1.5368410348892212, 1.3589630126953125, 1.3686589002609253, 1.4192320108413696, 1.3711737394332886, 1.2673320770263672, 1.1651921272277832, 1.19804048538208, 1.03669011592865], 'val_accuracy': [0.35555556416511536, 0.2888889014720917, 0.42222222685813904, 0.35555556416511536, 0.4000000059604645, 0.3777777850627899, 0.41111111640930176, 0.4888888895511627, 0.3888888955116272, 0.6000000238418579], 'lr': [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]}
