In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import numpy as np
import os

# Set random seeds for reproducibility.
# tf.random.set_seed alone only seeds TensorFlow ops. ImageDataGenerator draws its
# shuffle order and augmentation parameters from NumPy's RNG, so seed Python's
# `random`, NumPy and TensorFlow together in one call.
# NOTE(review): bit-exact reruns on GPU may additionally require
# tf.config.experimental.enable_op_determinism() — confirm whether that is needed.
tf.keras.utils.set_random_seed(42)

# -------------------- 1. Paths & Parameters --------------------
# Root folder read by flow_from_directory (one sub-folder per class).
data_dir = 'new_dataset'  # Adjust path as needed

# Spatial size every image is resized to before entering the network.
img_height, img_width = 224, 224

# Images per mini-batch, and the maximum number of epochs
# (callbacks may end training earlier).
batch_size = 32
epochs = 60

# -------------------- 2. Data Generators --------------------
# Both generators read the same directory. They must share one validation_split
# value so the 'training' and 'validation' subsets partition it consistently.
val_fraction = 0.2

# Training pipeline: rescale to [0, 1] plus random geometric augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=val_fraction,
)

# Validation pipeline: rescaling only, no augmentation.
val_datagen = ImageDataGenerator(rescale=1./255, validation_split=val_fraction)

# Arguments common to both directory iterators (one-hot labels).
flow_kwargs = dict(
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
)

# Shuffle training batches; keep validation order fixed for stable evaluation.
train_generator = train_datagen.flow_from_directory(
    data_dir, subset='training', shuffle=True, **flow_kwargs
)
val_generator = val_datagen.flow_from_directory(
    data_dir, subset='validation', shuffle=False, **flow_kwargs
)

num_classes = len(train_generator.class_indices)
print(f"Number of classes: {num_classes}")

# -------------------- 3. Compute Class Weights --------------------
# 'balanced' weights are n_samples / (n_classes * count(class)), so rarer classes
# contribute more to the loss.
# The weights come back in the order of `present_classes`. Keying the dict by those
# actual label values, rather than by position with enumerate(), keeps every
# weight on the right class even when some class index has no training images.
labels = train_generator.classes
present_classes = np.unique(labels)
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=present_classes,
    y=labels
)
class_weights_dict = {
    int(cls): float(weight) for cls, weight in zip(present_classes, class_weights)
}

# -------------------- 4. Build CNN Model (Option A: Custom CNN) --------------------
# Four Conv -> BatchNorm -> MaxPool stages (32/64/128/256 filters) followed by a
# dense classifier head.
# The input shape is declared with an explicit Input layer rather than
# `input_shape=` on the first Conv2D; Keras 3 warns about the latter form.
model = models.Sequential([
    layers.Input(shape=(img_height, img_width, 3)),

    layers.Conv2D(32, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(2, 2),

    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(2, 2),

    layers.Conv2D(128, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(2, 2),

    layers.Conv2D(256, (3, 3), activation='relu'),
    layers.BatchNormalization(),
    layers.MaxPooling2D(2, 2),

    # NOTE(review): for 224x224 inputs the feature map here is 12x12x256, so Flatten
    # yields 36,864 features. Dense(512) alone then holds ~18.9M weights.
    # GlobalAveragePooling2D would shrink this greatly if overfitting or memory
    # becomes a problem.
    layers.Flatten(),
    layers.Dense(512, activation='relu'),
    layers.Dropout(0.5),  # regularize the large fully-connected layer
    layers.Dense(num_classes, activation='softmax')
])

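# -------------------- Optional: Use Transfer Learning (Option B) --------------------
# NOTE: MobileNetV2's ImageNet weights expect inputs scaled to [-1, 1]
# (tf.keras.applications.mobilenet_v2.preprocess_input). The generators above
# rescale to [0, 1], so adjust the preprocessing before enabling this option.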
# from tensorflow.keras.applications import MobileNetV2
# from tensorflow.keras.models import Model
#
# base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))
# base_model.trainable = False  # Freeze layers initially
#
# x = layers.GlobalAveragePooling2D()(base_model.output)
# x = layers.Dense(128, activation='relu')(x)
# x = layers.Dropout(0.4)(x)
# output = layers.Dense(num_classes, activation='softmax')(x)
#
# model = Model(inputs=base_model.input, outputs=output)

# -------------------- 5. Compile Model --------------------
# categorical_crossentropy matches the one-hot labels from class_mode='categorical';
# a small initial learning rate keeps the freshly initialized network stable.
optimizer = Adam(learning_rate=1e-4)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

# -------------------- 6. Callbacks --------------------
# Stop once val_loss has not improved for 5 epochs and roll back to the best weights.
# NOTE(review): restore_best_weights is applied when early stopping actually
# triggers; if all epochs run, the final weights may be the last epoch's.
# Verify for the installed Keras version.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

# Multiply the learning rate by 0.2 after 3 epochs without val_loss improvement.
lr_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3)

# Save the epoch with the highest val_accuracy. This criterion differs from
# EarlyStopping's (val_loss), so this file and the in-memory model may come
# from different epochs.
best_checkpoint = ModelCheckpoint('best_model.h5', monitor='val_accuracy', save_best_only=True)

callbacks = [early_stopping, lr_on_plateau, best_checkpoint]

# -------------------- 7. Train Model --------------------
# steps_per_epoch / validation_steps are deliberately omitted, so Keras uses
# len(generator) = ceil(samples / batch_size) and every image is visited.
# The previous floor division (samples // batch_size) dropped the final partial
# batch: 261 of 262 training batches for 8381 images, and 65 of 66 validation
# batches for 2082 images.
# With explicit step counts, Keras 3 can also carry the data iterator over
# between epochs.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=epochs,
    callbacks=callbacks,
    class_weight=class_weights_dict  # compensate for class imbalance
)

# -------------------- 8. Save Final Model --------------------
# Persist the weights the model holds at the end of fit().
# NOTE(review): Keras 3 treats .h5 as a legacy format; consider '.keras' if no
# downstream code depends on this exact filename.
final_model_path = 'image_classifier_final.h5'
model.save(final_model_path)

# -------------------- 9. Plot Training History --------------------
# Two side-by-side panels (accuracy, then loss), each comparing training and
# validation curves per epoch. The figure is saved to disk, then displayed.
history_panels = (
    ('accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy', 'Accuracy'),
    ('loss', 'val_loss', 'Training Loss', 'Validation Loss', 'Loss'),
)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (train_key, val_key, train_label, val_label, y_label) in zip(axes, history_panels):
    ax.plot(history.history[train_key], label=train_label)
    ax.plot(history.history[val_key], label=val_label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid()

plt.tight_layout()
plt.savefig('training_history.png')
plt.show()


Found 8381 images belonging to 30 classes.
Found 2082 images belonging to 30 classes.
Number of classes: 30


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


  self._warn_if_super_not_called()


Epoch 1/60
[1m159/261[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m4:44[0m 3s/step - accuracy: 0.1956 - loss: 4.1062



[1m237/261[0m [32m━━━━━━━━━━━━━━━━━━[0m[37m━━[0m [1m1:07[0m 3s/step - accuracy: 0.2139 - loss: 3.7776