# In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Lambda
from tensorflow.keras.models import Model
# Dataset roots, relative to the current working directory. The "drive/MyDrive"
# prefix suggests a Google Colab session with Google Drive mounted -- confirm.
# Both are read with flow_from_directory, i.e. one sub-directory per class.
_dataset_directory = "drive/MyDrive/DL_PROJECT_DATASET_V1"  # split into training/validation via validation_split=0.2
_test_directory = "drive/MyDrive/DL_PROJECT_TEST_DATASET_V1"  # separate test set

import cv2
import numpy as np
from sklearn.cluster import KMeans

def preprocess_image(img, n_colors=5, target_size=(224, 224)):
    """Derive the three model inputs (image, edge map, colour palette) from one image.

    Args:
        img: RGB image with float values in [0, 1] and 3 channels, e.g. one
            element of a batch produced by ``ImageDataGenerator(rescale=1./255)``.
        n_colors: Number of palette colours to extract with k-means.
        target_size: ``(width, height)`` passed to ``cv2.resize`` for the
            returned image and edge map.

    Returns:
        Tuple ``(resized_image, resized_edges, prominent_colors)``:

        * ``resized_image``: uint8 image in **BGR** (OpenCV) order, resized to
          ``target_size``.
        * ``resized_edges``: uint8 Canny edge map (thresholds 100/200) of the
          grayscale image, resized to ``target_size``.
        * ``prominent_colors``: int array of shape ``(n_colors, 3)`` in **BGR**
          order, sorted from the cluster owning the most pixels to the least.
    """
    # Clip before the uint8 cast: astype() wraps out-of-range values instead of
    # saturating, so a stray value outside [0, 1] would turn into a wrong colour.
    image_u8 = np.clip(img * 255.0, 0, 255).astype(np.uint8)
    image = cv2.cvtColor(image_u8, cv2.COLOR_RGB2BGR)

    # Edge detection works on a single-channel image.
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray_image, 100, 200)
    resized_edges = cv2.resize(edges, target_size)

    # Colour palette: cluster every pixel (BGR) into n_colors groups.
    # NOTE(review): KMeans with n_init=10 over all H*W pixels runs for every image
    # of every batch; this is likely the main cost of the pipeline.
    pixels = image.reshape(-1, 3)
    kmeans = KMeans(n_clusters=n_colors, n_init=10, random_state=0).fit(pixels)
    # KMeans numbers clusters arbitrarily, so cluster_centers_ is NOT ranked.
    # Order the centres by how many pixels each cluster owns so that row 0 is
    # always the dominant colour and each row has a consistent meaning for a
    # model that consumes the palette positionally.
    counts = np.bincount(kmeans.labels_, minlength=n_colors)
    order = np.argsort(-counts, kind="stable")
    prominent_colors = kmeans.cluster_centers_[order].astype(int)

    resized_image = cv2.resize(image, target_size)
    return resized_image, resized_edges, prominent_colors

def custom_generator(image_data_generator, steps_per_epoch, n_colors=5):
    """Adapt a Keras image iterator to the three-input format of the model.

    Every image of each incoming batch is expanded with ``preprocess_image``, and
    the batch is yielded as ``([original, edges, colors], batch_y)`` where

    * ``original``: RGB image scaled to [0, 1], shape ``(B, 224, 224, 3)``
    * ``edges``: edge map scaled to [0, 1], shape ``(B, 224, 224, 1)``
    * ``colors``: palette colours divided by 255, shape ``(B, n_colors, 3)``

    All arrays are float32, the dtype Keras feeds to the model anyway.

    Args:
        image_data_generator: Iterable of ``(batch_x, batch_y)`` pairs, such as
            the iterator returned by ``flow_from_directory``.
        steps_per_epoch: After this many yielded batches the inner ``for`` loop
            is broken and iteration over ``image_data_generator`` restarts. For
            an endless iterator that returns itself from ``__iter__`` (as Keras'
            directory iterators do) this only resets the counter. A finite
            re-iterable restarts from its first batch.
        n_colors: Palette size forwarded to ``preprocess_image``. It must match
            the model's colour input shape.

    Yields:
        ``([batch_x_original, batch_x_edges, batch_x_colors], batch_y)``
    """
    batch_count = 0
    while True:
        for batch_x, batch_y in image_data_generator:
            n_images = batch_x.shape[0]
            batch_x_original = np.zeros((n_images, 224, 224, 3), dtype=np.float32)
            batch_x_edges = np.zeros((n_images, 224, 224, 1), dtype=np.float32)
            batch_x_colors = np.zeros((n_images, n_colors, 3), dtype=np.float32)

            for i, img in enumerate(batch_x):
                original, edges, colors = preprocess_image(img, n_colors=n_colors)
                # preprocess_image returns OpenCV BGR order; convert back to the
                # RGB order of the incoming batch before rescaling.
                batch_x_original[i] = cv2.cvtColor(original, cv2.COLOR_BGR2RGB) / 255.0
                batch_x_edges[i] = np.expand_dims(edges, axis=-1) / 255.0
                batch_x_colors[i] = colors / 255.0

            yield [batch_x_original, batch_x_edges, batch_x_colors], batch_y

            batch_count += 1
            if batch_count >= steps_per_epoch:
                batch_count = 0
                break

# Define CNN structure
def create_cnn(input_shape):
    """Build a small convolutional feature extractor.

    The network has two stages of Conv2D (ReLU) followed by 2x2 max-pooling.
    The first stage uses 32 filters of 3x3 and the second 16 filters of 2x2.
    The result is then flattened.

    Args:
        input_shape: Shape of one sample without the batch axis,
            e.g. ``(224, 224, 3)``.

    Returns:
        A ``Model`` mapping that input to a flat feature vector.
    """
    inputs = Input(shape=input_shape)
    features = inputs
    for n_filters, kernel_size in ((32, (3, 3)), (16, (2, 2))):
        features = Conv2D(n_filters, kernel_size, activation='relu')(features)
        features = MaxPooling2D(pool_size=(2, 2))(features)
    return Model(inputs=inputs, outputs=Flatten()(features))
def create_cnn_for_colors(input_shape):
    """Build the feature extractor for the colour-palette branch.

    It is a single Conv2D layer with 32 filters, a 1x1 kernel and ReLU,
    followed by Flatten. A 1x1 kernel keeps the spatial size, so an input of
    ``(rows, cols, channels)`` yields ``rows * cols * 32`` features.

    NOTE(review): with the ``(5, 3, 1)`` shape this file passes in, there is a
    single channel. Each 1x1 kernel therefore sees one colour component in
    isolation, and the three components of a colour are only combined later by
    the dense head.

    Args:
        input_shape: Shape of one sample without the batch axis.

    Returns:
        A ``Model`` mapping the palette to a flat feature vector.
    """
    palette = Input(shape=input_shape)
    embedded = Conv2D(32, (1, 1), activation='relu')(palette)  # Using smaller kernel size
    return Model(inputs=palette, outputs=Flatten()(embedded))

# --- Model definition -------------------------------------------------------
# Three branches (RGB image CNN, edge-map CNN, colour-palette CNN) whose
# flattened features are concatenated and classified by a dense head.

# One Input per branch; the shapes must match what custom_generator yields.
input_original = Input(shape=(224, 224, 3))
input_edges = Input(shape=(224, 224, 1))
input_colors = Input(shape=(5, 3))  # 5 palette colours x 3 channels

# Create CNNs
cnn_original = create_cnn((224, 224, 3))
cnn_edges = create_cnn((224, 224, 1))
cnn_colors = create_cnn_for_colors((5, 3, 1))  # Note: This architecture might need adjustment

# Get outputs from CNNs
output_original = cnn_original(input_original)
output_edges = cnn_edges(input_edges)
# Add a trailing channel axis, (5, 3) -> (5, 3, 1), so the palette can go through
# Conv2D. A Reshape layer is used instead of Lambda(lambda x: tf.expand_dims(...)).
# A Lambda wrapping a Python lambda is serialised as bytecode that refers to the
# global ``tf``, which makes reloading the saved .h5 models fragile, and recent
# Keras versions refuse to deserialise such lambdas in safe mode.
output_colors = cnn_colors(tf.keras.layers.Reshape((5, 3, 1))(input_colors))

# Concatenate outputs
concatenated = Concatenate()([output_original, output_edges, output_colors])

# Dense layers for classification
dense = Dense(128, activation='relu')(concatenated)
# Assuming 6 classes: this must equal the number of class sub-directories.
output_layer = Dense(6, activation='softmax')(dense)

# Complete model
model = Model(inputs=[input_original, input_edges, input_colors], outputs=output_layer)

# categorical_crossentropy matches the one-hot labels of class_mode='categorical'.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

import math

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# --- Data pipeline ----------------------------------------------------------
# Only training images are augmented. Validation and test images are just
# rescaled, so their metrics are deterministic and comparable across epochs.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=0.2,
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=True,
    zoom_range=0.1,
)
# Same rescale and validation_split as train_datagen, so subset='validation'
# selects exactly the files held out from subset='training'. No augmentation.
validation_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
test_datagen = ImageDataGenerator(rescale=1./255)

batch_size = 32

train_data_gen = train_datagen.flow_from_directory(
    _dataset_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training'
)

val_data_gen = validation_datagen.flow_from_directory(
    _dataset_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation'
)

test_data_gen = test_datagen.flow_from_directory(
    _test_directory,
    target_size=(224, 224),
    batch_size=batch_size,
    class_mode='categorical',
    # Keep file order so predictions line up with test_data_gen.classes.
    shuffle=False,
)

# Derive the sample counts from the iterators instead of hard-coding them, so
# the step counts stay correct when images are added or removed.
total_train_samples = train_data_gen.samples
total_val_samples = val_data_gen.samples
total_test_samples = test_data_gen.samples

# Round up so the final partial batch is included in every epoch.
train_steps = math.ceil(total_train_samples / batch_size)
val_steps = math.ceil(total_val_samples / batch_size)
test_steps = math.ceil(total_test_samples / batch_size)

train_generator = custom_generator(train_data_gen, train_steps)
validation_generator = custom_generator(val_data_gen, val_steps)

# Use tensorflow.keras consistently with the rest of the file. Mixing the
# standalone ``keras`` package with ``tf.keras`` can produce callback/model
# objects from two different Keras implementations.
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# Write the model only on epochs where val_accuracy reaches a new maximum
# (save_best_only=True), not after every epoch.
checkpoint = ModelCheckpoint(
    'art_style_model_best_v3.3.h5',
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    mode='max'
)

# Stop after 5 epochs without a val_accuracy improvement and roll the weights
# back to the best epoch seen.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    verbose=1,
    mode='max',
    restore_best_weights=True
)

# Cut the learning rate 5x after 3 epochs without a val_accuracy improvement.
# This patience must be shorter than EarlyStopping's. With both at 5, training
# stopped on the very epoch the LR would have been reduced, so the schedule
# never took effect (the recorded history keeps lr at 0.001 for all epochs).
reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.2,
    patience=3,
    verbose=1,
    mode='max',
    min_lr=0.00001
)


history = model.fit(
    train_generator,
    steps_per_epoch=train_steps,
    epochs=30,
    validation_data=validation_generator,
    validation_steps=val_steps,
    callbacks=[checkpoint, early_stopping, reduce_lr]
)

# Save the final model. If EarlyStopping ended the run, restore_best_weights has
# already rolled the weights back to the best val_accuracy epoch.
model.save('art_style_model_v3.3.h5')



# Recorded output of a previous training run: presumably ``history.history``
# from the ``model.fit`` call above, pasted from the notebook (TODO confirm).
# As a bare expression statement it has no effect when run as a script.
# Observations from the values:
# - 25 epochs were recorded.
# - val_accuracy peaked at ~0.807 in epoch 20 and did not improve over the
#   following 5 epochs.
# - 'lr' stayed at 0.001 throughout, i.e. the learning rate was never reduced.
{'loss': [1.3677223920822144, 1.2917981147766113, 1.0909900665283203, 1.0387272834777832, 0.9726772308349609, 0.7953414916992188, 0.8516399264335632, 0.8836516737937927, 0.7710679769515991, 0.8679074048995972, 0.7263955473899841, 0.7562497854232788, 0.7076928615570068, 0.6437550783157349, 0.7378449440002441, 0.7285891771316528, 0.6088800430297852, 0.5626384615898132, 0.587087094783783, 0.5227916240692139, 0.5964056849479675, 0.5794382095336914, 0.5608565807342529, 0.5573264956474304, 0.5960157513618469], 'accuracy': [0.4693877696990967, 0.5102040767669678, 0.6034985184669495, 0.5830903649330139, 0.6297376155853271, 0.6851311922073364, 0.647230327129364, 0.6501457691192627, 0.6851311922073364, 0.6705539226531982, 0.7026239037513733, 0.7113702893257141, 0.7230320572853088, 0.7551020383834839, 0.7026239037513733, 0.705539345741272, 0.7842565774917603, 0.7871720194816589, 0.795918345451355, 0.819242000579834, 0.7609329223632812, 0.7900874614715576, 0.8017492890357971, 0.7900874614715576, 0.7813411355018616], 'val_loss': [1.315810203552246, 1.1910982131958008, 1.1532803773880005, 0.9834330677986145, 0.8514275550842285, 0.9633168578147888, 0.8781489133834839, 0.9170699715614319, 0.8699895739555359, 0.8146267533302307, 0.6767652630805969, 0.684617280960083, 0.6725450158119202, 0.6372936367988586, 0.7225661873817444, 0.700664222240448, 0.698101282119751, 0.7462931871414185, 0.6361097693443298, 0.5596312284469604, 0.8618825674057007, 0.6015252470970154, 0.6037704944610596, 0.80154949426651, 0.5815305709838867], 'val_accuracy': [0.4939759075641632, 0.6024096608161926, 0.5783132314682007, 0.5783132314682007, 0.6265060305595398, 0.6265060305595398, 0.6385542154312134, 0.650602400302887, 0.6265060305595398, 0.6385542154312134, 0.7228915691375732, 0.7469879388809204, 0.7469879388809204, 0.7710843086242676, 0.6746987700462341, 0.7108433842658997, 0.7831325531005859, 0.6987951993942261, 0.7469879388809204, 0.8072289228439331, 0.7108433842658997, 0.7710843086242676, 
0.7951807379722595, 0.6746987700462341, 0.7469879388809204], 'lr': [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]}
