# This is a sample Jupyter Notebook

Below is an example of a code cell. 
Put your cursor into the cell and press Shift+Enter to execute it and select the next one, or click 'Run Cell' button.

Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.

To learn more about Jupyter Notebooks in PyCharm, see [help](https://www.jetbrains.com/help/pycharm/ipython-notebook-support.html).
For an overview of PyCharm, go to Help -> Learn IDE features or refer to [our documentation](https://www.jetbrains.com/help/pycharm/getting-started.html).

In [1]:
# --- Part 1: Training and Saving Models (train_and_save_models.py) -----------
# ==============================================================================
# This script handles data loading, model building, training, and saving.
# It is designed to be run once to generate your trained model files.

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV2
import os

In [2]:
# --- Configuration ---
# IMPORTANT: Adjust these variables according to your dataset.
NUM_CLASSES = 11  # Set this to the number of subdirectories in your data folders.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS_CNN = 8
EPOCHS_TRANSFER = 6

# Define your data directories. An empty string for base_dir
# assumes the train/val/test folders are in the same directory as this script.
base_dir = ''
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'val')
test_dir = os.path.join(base_dir, 'test')

# --- Data Loading and Augmentation ---
print("Setting up data generators...")
try:
    # Data augmentation for the training set
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    # Simple rescaling for validation and test sets
    validation_datagen = ImageDataGenerator(rescale=1./255)
    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical'
    )

    validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical'
    )

    test_generator = test_datagen.flow_from_directory(
        test_dir,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical'
    )

    class_names = list(train_generator.class_indices.keys())
    print("\nData generators created successfully!")
    print(f"Detected classes: {class_names}")

except Exception as e:
    print(f"\nError loading data: {e}")
    print("Please check your file paths and ensure your dataset is correctly structured.")
    exit()

Setting up data generators...
Found 6225 images belonging to 11 classes.
Found 1092 images belonging to 11 classes.
Found 3187 images belonging to 11 classes.

Data generators created successfully!
Detected classes: ['animal fish', 'animal fish bass', 'fish sea_food black_sea_sprat', 'fish sea_food gilt_head_bream', 'fish sea_food hourse_mackerel', 'fish sea_food red_mullet', 'fish sea_food red_sea_bream', 'fish sea_food sea_bass', 'fish sea_food shrimp', 'fish sea_food striped_red_mullet', 'fish sea_food trout']


In [4]:
# ==============================================================================
# --- Model Definition and Training ---------------------------------------
# The following functions define and train your models.
# ==============================================================================
def create_cnn_model():
    """Builds a custom CNN model from scratch."""
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(512, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

def create_transfer_learning_model():
    """Builds a transfer learning model using a pre-trained MobileNetV2."""
    base_model = MobileNetV2(
        input_shape=(IMG_SIZE[0], IMG_SIZE[1], 3),
        include_top=False,
        weights='imagenet'
    )
    base_model.trainable = False
    
    model = models.Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dense(128, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
# Train the models and save them
print("\nTraining Custom CNN Model...")
cnn_model = create_cnn_model()
cnn_model.summary()

history_cnn = cnn_model.fit(
    train_generator,
    epochs=EPOCHS_CNN,
    validation_data=validation_generator
)
cnn_model.save('fish_classifier_cnn.h5')


print("\nTraining Transfer Learning Model...")
transfer_model = create_transfer_learning_model()
transfer_model.summary()

history_transfer = transfer_model.fit(
    train_generator,
    epochs=EPOCHS_TRANSFER,
    validation_data=validation_generator
)
transfer_model.save('fish_classifier_transfer.h5')

print("\nModels saved successfully: 'fish_classifier_cnn.h5' and 'fish_classifier_transfer.h5'")





Training Custom CNN Model...


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


  self._warn_if_super_not_called()


Epoch 1/8
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 480ms/step - accuracy: 0.2586 - loss: 2.2476 - val_accuracy: 0.5522 - val_loss: 1.2402
Epoch 2/8
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 477ms/step - accuracy: 0.5520 - loss: 1.2574 - val_accuracy: 0.6520 - val_loss: 0.9703
Epoch 3/8
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 478ms/step - accuracy: 0.6913 - loss: 0.8847 - val_accuracy: 0.8086 - val_loss: 0.5339
Epoch 4/8
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 477ms/step - accuracy: 0.7416 - loss: 0.7072 - val_accuracy: 0.8370 - val_loss: 0.5179
Epoch 5/8
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m92s[0m 473ms/step - accuracy: 0.7842 - loss: 0.6195 - val_accuracy: 0.7280 - val_loss: 0.7060
Epoch 6/8
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m92s[0m 470ms/step - accuracy: 0.7807 - loss: 0.6052 - val_accuracy: 0.8626 - val_loss: 0.4240
Epoch 7/8
[1m19




Training Transfer Learning Model...


Epoch 1/6
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 341ms/step - accuracy: 0.5916 - loss: 1.2520 - val_accuracy: 0.9533 - val_loss: 0.2093
Epoch 2/6
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 332ms/step - accuracy: 0.8989 - loss: 0.3156 - val_accuracy: 0.9634 - val_loss: 0.1126
Epoch 3/6
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 333ms/step - accuracy: 0.9318 - loss: 0.2339 - val_accuracy: 0.9725 - val_loss: 0.0816
Epoch 4/6
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m71s[0m 365ms/step - accuracy: 0.9472 - loss: 0.1760 - val_accuracy: 0.9734 - val_loss: 0.0785
Epoch 5/6
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m69s[0m 352ms/step - accuracy: 0.9461 - loss: 0.1598 - val_accuracy: 0.9762 - val_loss: 0.0576
Epoch 6/6
[1m195/195[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 395ms/step - accuracy: 0.9535 - loss: 0.1367 - val_accuracy: 0.9734 - val_loss: 0.0761





Models saved successfully: 'fish_classifier_cnn.h5' and 'fish_classifier_transfer.h5'
