In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Lambda, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
n = 224                 # side length (pixels) every image is resized to
m = 4                   # the image is cut into an m x m grid of cells
grid_size = n // m      # side length (pixels) of a single grid cell
num_cnn_layers = 3      # number of stages in each per-cell CNN
num_classes = 6         # number of art style categories

# Root folder of the image dataset; both the training and the validation
# generators below read from it.
_dataset_directory = "drive/MyDrive/DL_PROJECT_DATASET_V1"
dataset_dir = _dataset_directory

# ---------------------------------------------------------------------------
# Image generators
# ---------------------------------------------------------------------------
# Random transformations applied to training images only.
_augmentation = {
    'rotation_range': 20,
    'width_shift_range': 0.2,
    'height_shift_range': 0.2,
    'shear_range': 0.2,
    'zoom_range': 0.2,
    'horizontal_flip': True,
    'fill_mode': 'nearest',
}

# Training data: pixel values scaled to [0, 1], augmented, and 20% of the
# files reserved for validation (80% / 20% split).
train_datagen = ImageDataGenerator(
    rescale=1.0 / 255,
    validation_split=0.2,
    **_augmentation,
)

# Validation data: same rescaling and the same 20% split, no augmentation.
validation_datagen = ImageDataGenerator(validation_split=0.2, rescale=1.0 / 255)

# Test data: rescaling only. Not consumed by any generator in this section.
test_datagen = ImageDataGenerator(rescale=1.0 / 255)

# Arguments shared by both directory iterators: images are resized to n x n,
# served in batches of 32, with one-hot ('categorical') labels.
_flow_options = {
    'directory': dataset_dir,
    'target_size': (n, n),
    'batch_size': 32,
    'class_mode': 'categorical',
}

# The two iterators differ only in which side of the split they draw from.
train_generator = train_datagen.flow_from_directory(
    subset='training', **_flow_options
)
validation_generator = validation_datagen.flow_from_directory(
    subset='validation', **_flow_options
)

# CNN for each grid cell
def create_cnn(input_shape):
    """Build the small convolutional feature extractor used for one grid cell.

    The network stacks ``num_cnn_layers`` identical stages (the count is read
    from the module-level variable of that name). Each stage applies, in
    order:

    1. ``Conv2D`` with 32 filters, 3x3 kernel, ReLU activation
    2. ``Conv2D`` with 124 filters, 2x2 kernel, ReLU activation
    3. ``MaxPooling2D`` with a 2x2 window

    Args:
        input_shape: Shape of one input sample (without the batch axis),
            passed straight to ``Input(shape=...)``.

    Returns:
        A Keras ``Model`` mapping the input tensor to the output of the last
        pooling layer. Every call builds a new model with its own layers.
    """
    cell_input = Input(shape=input_shape)
    features = cell_input
    for _ in range(num_cnn_layers):
        # Layers are instantiated in the order they are applied.
        # (A BatchNormalization step after the second convolution was tried
        # at some point and is currently not part of the stage.)
        stage_layers = (
            Conv2D(32, (3, 3), activation='relu'),
            Conv2D(124, (2, 2), activation='relu'),
            MaxPooling2D((2, 2)),
        )
        for layer in stage_layers:
            features = layer(features)
    return Model(inputs=cell_input, outputs=features)

# Input layer for the whole image (n x n pixels, 3 channels).
input_layer = Input(shape=(n, n, 3))


def _crop_cell(images, top, left, size):
    """Return the ``size`` x ``size`` window starting at row ``top``, column ``left``.

    ``images`` is indexed as [batch, row, column, channel]; the batch and
    channel axes are passed through unchanged.
    """
    return images[:, top:top + size, left:left + size, :]


# Create one CNN per grid cell and collect their outputs.
cnn_outputs = []
for i in range(m):
    for j in range(m):
        # The crop offsets are handed to the layer through `arguments`, so
        # every Lambda carries its own fixed values. An inline
        # `lambda z: z[:, i*grid_size:..., j*grid_size:...]` would look up
        # i and j only when the function runs, and Keras may run a Lambda's
        # function again after construction (training, inference, reload).
        # By then the loops have finished and every branch would crop the
        # same last cell.
        grid_input = Lambda(
            _crop_cell,
            arguments={
                'top': i * grid_size,
                'left': j * grid_size,
                'size': grid_size,
            },
        )(input_layer)
        # Each cell gets its own CNN instance (weights are not shared).
        grid_cnn = create_cnn((grid_size, grid_size, 3))
        cnn_outputs.append(grid_cnn(grid_input))

# Merge the per-cell feature maps (Concatenate joins on the last axis by
# default) and flatten them into one vector per image.
merged = Concatenate()(cnn_outputs)
flattened = Flatten()(merged)

# Final dense layers: a 128-unit hidden layer followed by a softmax over the
# classes. Change num_classes to match your task.
dense = Dense(128, activation='relu')(flattened)
output_layer = Dense(num_classes, activation='softmax')(dense)

# Complete model
model = Model(inputs=input_layer, outputs=output_layer)

# Compile the model. categorical_crossentropy expects one-hot labels.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])


from keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

# ModelCheckpoint: write the model to disk only when val_accuracy reaches a
# new maximum (save_best_only=True), so the file always holds the best model
# seen so far rather than the one from the latest epoch.
checkpoint = ModelCheckpoint(
    'art_style_model_best_v2.1.h5',
    monitor='val_accuracy',
    verbose=1,
    save_best_only=True,
    mode='max'
)

# EarlyStopping: stop training once val_accuracy has not improved for 5
# consecutive epochs, then roll the model back to the best weights observed.
early_stopping = EarlyStopping(
    monitor='val_accuracy',
    patience=5,
    verbose=1,
    mode='max',
    restore_best_weights=True
)

# ReduceLROnPlateau: multiply the learning rate by 0.2 when val_accuracy has
# not improved for 3 epochs, never going below 1e-5. Its patience is shorter
# than the early-stopping patience, so the rate is lowered before training is
# given up on.
reduce_lr = ReduceLROnPlateau(
    monitor='val_accuracy',
    factor=0.2,
    patience=3,
    verbose=1,
    mode='max',
    min_lr=0.00001
)



# Train for at most 30 epochs on the training generator, evaluating on the
# validation generator after each epoch. The callbacks decide when to save,
# when to lower the learning rate and when to stop early.
training_callbacks = [checkpoint, early_stopping, reduce_lr]
history = model.fit(
    x=train_generator,
    validation_data=validation_generator,
    epochs=30,
    callbacks=training_callbacks,
)

# Save the model in the state it is in when training ends.
model.save('art_style_model_2.1.h5')

# Optional evaluation on a separate test set. Left disabled:
# `test_dataset_dir` is not defined in this section and must be set before
# enabling it.
# test_generator = test_datagen.flow_from_directory(
#     test_dataset_dir,
#     target_size=(n, n),
#     batch_size=32,
#     class_mode='categorical'
# )
# model.evaluate(test_generator)


{'loss': [1.7418172359466553, 1.4547181129455566, 1.4345157146453857, 1.402978539466858, 1.3294391632080078, 1.2639915943145752, 1.2911114692687988, 1.2897870540618896, 1.2472788095474243, 1.2193827629089355, 1.1779048442840576], 'accuracy': [0.27988338470458984, 0.44897958636283875, 0.4548105001449585, 0.44606414437294006, 0.5072886347770691, 0.5218659043312073, 0.5131195187568665, 0.5102040767669678, 0.524781346321106, 0.533527672290802, 0.5160349607467651], 'val_loss': [1.5180459022521973, 1.511977195739746, 1.296093463897705, 1.3138688802719116, 1.2082133293151855, 1.2678929567337036, 1.2448573112487793, 1.2557536363601685, 1.1840312480926514, 1.1906942129135132, 1.1829643249511719], 'val_accuracy': [0.4457831382751465, 0.4337349534034729, 0.45783132314682007, 0.4819277226924896, 0.5180723071098328, 0.5421686768531799, 0.46987950801849365, 0.4939759075641632, 0.5301204919815063, 0.4939759075641632, 0.5060241222381592], 'lr': [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.00020000001, 0.00020000001]}
