In [None]:
# Multi CNN + Edge + Segmentation
# Monitored for validation loss

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Lambda
from tensorflow.keras.models import Model
_dataset_directory = "drive/MyDrive/DL_PROJECT_DATASET_V1"

# Define CNN structure
def create_cnn(input_shape):
    """Build a small two-stage convolutional feature extractor.

    Stage 1 is a 32-filter 3x3 ReLU convolution, stage 2 a 16-filter 2x2 ReLU
    convolution; each is followed by 2x2 max pooling. The final feature map is
    flattened into one vector per sample.

    Args:
        input_shape: Shape of a single input sample without the batch axis,
            e.g. ``(224, 224, 3)``.

    Returns:
        A ``Model`` mapping that input to its flattened feature vector.
    """
    inputs = Input(shape=input_shape)
    features = inputs
    for n_filters, kernel_size in ((32, (3, 3)), (16, (2, 2))):
        features = Conv2D(n_filters, kernel_size, activation='relu')(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
    return Model(inputs=inputs, outputs=Flatten()(features))
def create_cnn_for_colors(input_shape):
    """Build the feature extractor for the prominent-color input.

    A 32-filter ReLU convolution with a 1x1 kernel is applied at every spatial
    position. A 1x1 kernel only mixes channels, never neighbouring positions.
    The result is then flattened, so each position keeps its own slot in the
    output vector.

    Args:
        input_shape: Shape of a single input sample without the batch axis,
            e.g. ``(5, 3, 1)``.

    Returns:
        A ``Model`` mapping that input to its flattened feature vector.
    """
    palette_in = Input(shape=input_shape)
    pointwise = Conv2D(32, (1, 1), activation='relu')
    return Model(inputs=palette_in, outputs=Flatten()(pointwise(palette_in)))

# Create inputs for each image type
input_original = Input(shape=(224, 224, 3))  # RGB image
input_edges = Input(shape=(224, 224, 1))     # single-channel edge map
input_colors = Input(shape=(5, 3))           # 5 prominent colors x 3 channels

# Create one feature extractor per input branch
cnn_original = create_cnn((224, 224, 3))
cnn_edges = create_cnn((224, 224, 1))
cnn_colors = create_cnn_for_colors((5, 3, 1))  # Note: This architecture might need adjustment

# Get outputs from CNNs.
# The (5, 3) color input needs a trailing channel axis before it can go
# through Conv2D. A built-in Reshape layer is used for this instead of
# Lambda(tf.expand_dims): a Lambda layer is saved by serializing its Python
# function, which is not portable, and this model is saved to disk below.
output_original = cnn_original(input_original)
output_edges = cnn_edges(input_edges)
output_colors = cnn_colors(tf.keras.layers.Reshape((5, 3, 1))(input_colors))

# Concatenate the three flattened feature vectors
concatenated = Concatenate()([output_original, output_edges, output_colors])

# Dense layers for classification
dense = Dense(128, activation='relu')(concatenated)
output_layer = Dense(6, activation='softmax')(dense)  # Assuming 6 classes; must match the dataset's class folders

# Complete model: three inputs, one softmax output
model = Model(inputs=[input_original, input_edges, input_colors], outputs=output_layer)

# Compile the model (categorical cross-entropy expects one-hot labels)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

import cv2
import numpy as np
from sklearn.cluster import KMeans

def preprocess_image(img, n_colors=5):
    """Derive the three model inputs from one float RGB image.

    Args:
        img: RGB image as a float array with values in [0, 1], as produced by
            the ``ImageDataGenerator``s in this file (``rescale=1./255``).
        n_colors: Number of k-means clusters used to summarise the colors.
            The model's color input and ``custom_generator`` are both sized
            for 5, so only the default fits the model as built.

    Returns:
        A tuple ``(resized_image, resized_edges, prominent_colors)``:

        * ``resized_image``: the image as uint8 in **BGR** order, 224x224.
        * ``resized_edges``: its Canny edge map, resized to 224x224.
        * ``prominent_colors``: an ``(n_colors, 3)`` int array of k-means
          cluster centres in **BGR** order. Rows are sorted from the most to
          the least populated cluster.
    """
    # Round and clip before casting. astype(np.uint8) alone truncates, so
    # 0.99999994 * 255 would become 254, and out-of-range floats would wrap
    # around instead of saturating.
    as_uint8 = np.clip(np.rint(img * 255.0), 0, 255).astype(np.uint8)
    # OpenCV works in BGR; convert from the generator's RGB.
    image = cv2.cvtColor(as_uint8, cv2.COLOR_RGB2BGR)

    # Convert to grayscale for edge detection
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # Edge detection (Canny, thresholds 100 / 200)
    edges = cv2.Canny(gray_image, 100, 200)
    resized_edges = cv2.resize(edges, (224, 224))

    # Cluster every pixel into n_colors colors.
    pixels = image.reshape(-1, 3)
    kmeans = KMeans(n_clusters=n_colors, n_init=10, random_state=0).fit(pixels)
    # KMeans numbers its clusters arbitrarily. Order the centres by how many
    # pixels each cluster holds, largest first, so row k always means "the
    # k-th most common color". The model's color branch (1x1 conv + Flatten)
    # treats each row position as a separate feature.
    cluster_sizes = np.bincount(kmeans.labels_, minlength=n_colors)
    by_size = np.argsort(-cluster_sizes, kind='stable')
    prominent_colors = kmeans.cluster_centers_[by_size].astype(int)

    resized_image = cv2.resize(image, (224, 224))
    return resized_image, resized_edges, prominent_colors

def custom_generator(image_data_generator, steps_per_epoch):
    """Turn single-image batches into the model's three-input format.

    Every image of every ``(batch_x, batch_y)`` pair is passed through
    ``preprocess_image`` and split into:

    * the image, converted back to RGB and scaled to [0, 1];
    * its edge map with a trailing channel axis, scaled to [0, 1];
    * its 5 prominent colors divided by 255.

    Args:
        image_data_generator: Iterable of ``(batch_x, batch_y)`` pairs whose
            ``batch_x`` holds float RGB images in [0, 1]. In this file it is a
            ``flow_from_directory`` iterator built with ``rescale=1./255``.
        steps_per_epoch: After this many yielded batches the counter is reset
            and the ``for`` loop over ``image_data_generator`` is re-entered.

    Yields:
        ``([originals, edges, colors], batch_y)``, indefinitely.
    """
    batch_count = 0
    while True:
        for batch_x, batch_y in image_data_generator:
            n_images = batch_x.shape[0]  # last batch may be smaller
            # float32 matches the Keras-default dtype of the model's Input
            # layers and halves memory compared with NumPy's default float64.
            batch_x_original = np.zeros((n_images, 224, 224, 3), dtype=np.float32)
            batch_x_edges = np.zeros((n_images, 224, 224, 1), dtype=np.float32)
            batch_x_colors = np.zeros((n_images, 5, 3), dtype=np.float32)

            for i, img in enumerate(batch_x):
                original, edges, colors = preprocess_image(img)
                # preprocess_image returns BGR; the model is fed RGB.
                batch_x_original[i] = cv2.cvtColor(original, cv2.COLOR_BGR2RGB) / 255.0
                batch_x_edges[i] = np.expand_dims(edges, axis=-1) / 255.0
                batch_x_colors[i] = colors / 255.0

            yield [batch_x_original, batch_x_edges, batch_x_colors], batch_y

            batch_count += 1
            if batch_count >= steps_per_epoch:
                # Leave the for loop and re-enter it. An iterator (whose
                # iter() returns itself) resumes where it left off; a
                # re-iterable container starts over.
                batch_count = 0
                break

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training images: rescaled to [0, 1] and randomly augmented.
# validation_split=0.2 holds out 20% of the images; subset='training' below
# selects the rest.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.1
)
# Validation images: same rescale and the same validation_split, so that
# subset='validation' selects the held-out images. Deliberately NO
# augmentation: randomly rotated/flipped/zoomed validation images would make
# val_loss fluctuate from epoch to epoch, and val_loss drives checkpointing,
# early stopping and learning-rate reduction below.
validation_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

batch_size = 32

train_data_gen = train_datagen.flow_from_directory(
    _dataset_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training'
)

val_data_gen = validation_datagen.flow_from_directory(
    _dataset_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'
)

# Use the sample counts flow_from_directory actually found instead of
# hard-coding them, so the step counts follow the dataset on disk.
total_train_samples = train_data_gen.samples
total_val_samples = val_data_gen.samples

# ceil(samples / batch_size): one extra step covers a final partial batch.
train_steps = total_train_samples // batch_size + (1 if total_train_samples % batch_size else 0)
val_steps = total_val_samples // batch_size + (1 if total_val_samples % batch_size else 0)

train_generator = custom_generator(train_data_gen, train_steps)
validation_generator = custom_generator(val_data_gen, val_steps)


from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Keep the best model seen so far, judged by validation loss. With
# save_best_only=True the file is written only on epochs where val_loss
# improves, not after every epoch.
checkpoint = ModelCheckpoint(
    'art_style_model_best_v3.1.h5',
    monitor='val_loss',
    verbose=1,
    save_best_only=True,
    mode='min'
)

# Stop once val_loss has not improved for 5 consecutive epochs, and restore
# the weights from the epoch with the lowest val_loss.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    verbose=1,
    mode='min',
    restore_best_weights=True
)

# Multiply the learning rate by 0.2 after 3 epochs without val_loss
# improvement, never going below 1e-5.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
    verbose=1,
    mode='min',
    min_lr=0.00001
)


history = model.fit(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=30,
    validation_data=validation_generator,
    validation_steps=val_steps,
    callbacks=[checkpoint, early_stopping, reduce_lr]
)

# Save the trained model. NOTE(review): whether these are the restored best
# weights when training runs all 30 epochs without early stopping depends on
# the Keras version's EarlyStopping behaviour -- confirm.
model.save('art_style_model_v3.1.h5')

# NOTE: hard-coded per-epoch metrics for 24 epochs (loss, accuracy, val_loss,
# val_accuracy and learning rate 'lr'), presumably pasted from the
# History.history of an earlier run of the fit above -- confirm.
# This rebinds `history`, discarding the History object returned by model.fit.
# It is a plain dict, so index it as history['val_loss'] rather than
# history.history['val_loss'].
history = {'loss': [4.457160472869873, 1.716261625289917, 1.4640203714370728, 1.260480284690857, 1.1083595752716064, 1.063092589378357, 1.0915950536727905, 0.9701434969902039, 0.9114781022071838, 0.871853768825531, 0.9130118489265442, 0.8544056415557861, 0.8068971037864685, 0.7558664083480835, 0.751266598701477, 0.7298506498336792, 0.7162065505981445, 0.7210341691970825, 0.7144117951393127, 0.7021090984344482, 0.6917153596878052, 0.6845459938049316, 0.6965336203575134, 0.6972512602806091], 'accuracy': [0.22448979318141937, 0.33236151933670044, 0.4810495674610138, 0.5160349607467651, 0.5743440389633179, 0.5830903649330139, 0.588921308517456, 0.6355684995651245, 0.6763848662376404, 0.6501457691192627, 0.6647230386734009, 0.6705539226531982, 0.7026239037513733, 0.7201166152954102, 0.7142857313156128, 0.7259474992752075, 0.7259474992752075, 0.7317784428596497, 0.7346938848495483, 0.7405247688293457, 0.7492711544036865, 0.7434402108192444, 0.7463557124137878, 0.7492711544036865], 'val_loss': [2.0124619007110596, 1.5428327322006226, 1.4092859029769897, 1.2252542972564697, 1.053996205329895, 0.9147937893867493, 0.9492437243461609, 0.9353174567222595, 0.8333977460861206, 0.8799031376838684, 0.8576726317405701, 0.8942350745201111, 0.8115752935409546, 0.6981508731842041, 0.7377256155014038, 0.7041245698928833, 0.7449336051940918, 0.6957244873046875, 0.6725016832351685, 0.701926052570343, 0.7784252762794495, 0.7231735587120056, 0.7501319050788879, 0.8031225800514221], 'val_accuracy': [0.20481927692890167, 0.5662650465965271, 0.4337349534034729, 0.5180723071098328, 0.6144578456878662, 0.6626505851745605, 0.6024096608161926, 0.6987951993942261, 0.6867470145225525, 0.6746987700462341, 0.5903614163398743, 0.6385542154312134, 0.6746987700462341, 0.7349397540092468, 0.7108433842658997, 0.759036123752594, 0.6867470145225525, 0.7469879388809204, 0.759036123752594, 0.7349397540092468, 0.7349397540092468, 0.7228915691375732, 0.6746987700462341, 0.6867470145225525], 'lr': [0.001, 0.001, 
0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.00020000001, 0.00020000001, 0.00020000001, 0.00020000001, 0.00020000001, 4.0000003e-05, 4.0000003e-05, 4.0000003e-05, 4.0000003e-05, 4.0000003e-05, 1e-05, 1e-05]}