## 🐟 Fish Image Species Classification

Given *images of fish*, let's try to predict the **species** of fish present in a given image.

We will use a pretrained TensorFlow/Keras CNN (MobileNetV2) to make our predictions.

Data source: https://www.kaggle.com/datasets/crowww/a-large-scale-fish-dataset

### Importing Libraries

In [1]:
import numpy as np
import pandas as pd

from pathlib import Path
import os.path

from sklearn.model_selection import train_test_split

import tensorflow as tf

In [2]:
image_dir = Path('archive/Fish_Dataset/Fish_Dataset')
image_dir

PosixPath('archive/Fish_Dataset/Fish_Dataset')

### Creating File DataFrame

In [5]:
# Get filepaths and labels
filepaths = list(image_dir.glob(r'**/*.png'))
filepaths

[PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00455.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00158.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00727.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00927.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00187.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00464.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00383.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00458.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00431.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00166.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00941.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00946.png'),
 PosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00958.png'),

In [13]:
labels = list(map(lambda x: os.path.split(os.path.split(x)[0])[1], filepaths))
labels

['Sea Bass',
 'Sea Bass',
 'Sea Bass',
 ...
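
Since `filepaths` holds `Path` objects, the label can also be read as `path.parent.name`, which avoids the nested `os.path.split` calls. A minimal sketch, using two paths copied from the listing above (`PurePosixPath` is used here only so the example does not touch the filesystem):

```python
import os.path
from pathlib import PurePosixPath

# Two paths copied from the file listing above
example_paths = [
    PurePosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00455.png'),
    PurePosixPath('archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea Bass/00158.png'),
]

# The label is the name of the folder that directly contains the image
labels_pathlib = [p.parent.name for p in example_paths]

# Equivalent result from the nested os.path.split approach
labels_os = [os.path.split(os.path.split(str(p))[0])[1] for p in example_paths]

print(labels_pathlib)
print(labels_pathlib == labels_os)
```

Both approaches take the immediate parent folder, so they should agree on every file in the dataset.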

In [14]:
filepaths = pd.Series(filepaths, name='Filepath').astype(str)
labels = pd.Series(labels, name='Label')

# Concatenate filepaths and labels
image_df = pd.concat([filepaths, labels], axis=1)

In [15]:
image_df

| | Filepath | Label |
|---|---|---|
| 0 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 1 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 2 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 3 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 4 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| ... | ... | ... |
| 17995 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 17996 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 17997 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 17998 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |


In [17]:
# Mark ground-truth (GT) images, whose folder names end in 'GT', as NaN so they can be dropped
image_df['Label'] = image_df['Label'].apply(lambda x: np.nan if x[-2:] == 'GT' else x)

In [18]:
image_df = image_df.dropna(axis=0)
image_df

| | Filepath | Label |
|---|---|---|
| 0 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 1 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 2 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 3 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 4 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| ... | ... | ... |
| 17995 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 17996 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 17997 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 17998 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |


In [20]:
image_df.sample(200*9)['Label'].value_counts()

Label
Red Sea Bream         224
Black Sea Sprat       214
Hourse Mackerel       204
Gilt-Head Bream       200
Sea Bass              199
Striped Red Mullet    197
Shrimp                192
Red Mullet            187
Trout                 183
Name: count, dtype: int64

In [21]:
image_df['Label'].value_counts()

Label
Sea Bass              1000
Striped Red Mullet    1000
Gilt-Head Bream       1000
Red Mullet            1000
Hourse Mackerel       1000
Shrimp                1000
Black Sea Sprat       1000
Red Sea Bream         1000
Trout                 1000
Name: count, dtype: int64

In [23]:
# Sample 200 images from each class
samples = []

for category in image_df['Label'].unique():
    category_slice = image_df.query("Label == @category")
    samples.append(category_slice.sample(200, random_state=1))

In [24]:
samples

[                                              Filepath     Label
 507  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 818  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 452  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 368  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 242  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 ..                                                 ...       ...
 430  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 874  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 550  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 608  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 207  archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea...  Sea Bass
 
 [200 rows x 2 columns],
                                                Filepath               Label
 2507  archive/Fish_Dataset/Fish_Dataset/Striped Red ...  Striped Red Mullet
 2818  archive/Fish_Dataset

In [25]:
image_df = pd.concat(samples, axis=0).sample(frac=1.0, random_state=1).reset_index(drop=True)

In [26]:
image_df

| | Filepath | Label |
|---|---|---|
| 0 | archive/Fish_Dataset/Fish_Dataset/Red Sea Brea... | Red Sea Bream |
| 1 | archive/Fish_Dataset/Fish_Dataset/Gilt-Head Br... | Gilt-Head Bream |
| 2 | archive/Fish_Dataset/Fish_Dataset/Red Mullet/R... | Red Mullet |
| 3 | archive/Fish_Dataset/Fish_Dataset/Black Sea Sp... | Black Sea Sprat |
| 4 | archive/Fish_Dataset/Fish_Dataset/Hourse Macke... | Hourse Mackerel |
| ... | ... | ... |
| 1795 | archive/Fish_Dataset/Fish_Dataset/Hourse Macke... | Hourse Mackerel |
| 1796 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 1797 | archive/Fish_Dataset/Fish_Dataset/Shrimp/Shrim... | Shrimp |
| 1798 | archive/Fish_Dataset/Fish_Dataset/Striped Red ... | Striped Red Mullet |


In [28]:
image_df['Label'].value_counts()

Label
Red Sea Bream         200
Gilt-Head Bream       200
Red Mullet            200
Black Sea Sprat       200
Hourse Mackerel       200
Sea Bass              200
Trout                 200
Shrimp                200
Striped Red Mullet    200
Name: count, dtype: int64

In [30]:
train_df, test_df = train_test_split(image_df, train_size=0.7, shuffle=True, random_state=1)

In [31]:
train_df

| | Filepath | Label |
|---|---|---|
| 1145 | archive/Fish_Dataset/Fish_Dataset/Sea Bass/Sea... | Sea Bass |
| 927 | archive/Fish_Dataset/Fish_Dataset/Striped Red ... | Striped Red Mullet |
| 1189 | archive/Fish_Dataset/Fish_Dataset/Hourse Macke... | Hourse Mackerel |
| 1065 | archive/Fish_Dataset/Fish_Dataset/Shrimp/Shrim... | Shrimp |
| 671 | archive/Fish_Dataset/Fish_Dataset/Gilt-Head Br... | Gilt-Head Bream |
| ... | ... | ... |
| 905 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 1791 | archive/Fish_Dataset/Fish_Dataset/Hourse Macke... | Hourse Mackerel |
| 1096 | archive/Fish_Dataset/Fish_Dataset/Trout/Trout/... | Trout |
| 235 | archive/Fish_Dataset/Fish_Dataset/Striped Red ... | Striped Red Mullet |


### Loading the Images

In [34]:
train_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function = tf.keras.applications.mobilenet_v2.preprocess_input,
    validation_split = 0.2
)

test_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function = tf.keras.applications.mobilenet_v2.preprocess_input
)
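
Both generators apply `mobilenet_v2.preprocess_input`, which rescales pixel values from [0, 255] to [-1, 1], the input range the pretrained MobileNetV2 weights expect. A minimal pure-Python sketch of that mapping for a single value; `scale_pixel` is a hypothetical helper for illustration and not part of Keras:

```python
def scale_pixel(value):
    """Map a pixel intensity in [0, 255] to the range [-1, 1]."""
    return value / 127.5 - 1.0

# Black, mid-grey and white intensities
for v in (0, 127.5, 255):
    print(v, '->', scale_pixel(v))
```

Using the same preprocessing function for the training and test generators keeps the inputs on the same scale at training and evaluation time.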

In [35]:
train_images = train_generator.flow_from_dataframe(
    dataframe = train_df,
    x_col = 'Filepath',
    y_col = 'Label',
    target_size = (224, 224),
    color_mode = 'rgb',
    class_mode = 'categorical',
    batch_size = 32,
    shuffle = True,
    seed = 42,
    subset = 'training'
)

val_images = train_generator.flow_from_dataframe(
    dataframe = train_df,
    x_col = 'Filepath',
    y_col = 'Label',
    target_size = (224, 224),
    color_mode = 'rgb',
    class_mode = 'categorical',
    batch_size = 32,
    shuffle = True,
    seed = 42,
    subset = 'validation'
)

test_images = test_generator.flow_from_dataframe(
    dataframe = test_df,
    x_col = 'Filepath',
    y_col = 'Label',
    target_size = (224, 224),
    color_mode = 'rgb',
    class_mode = 'categorical',
    batch_size = 32,
    shuffle = False
)

Found 1008 validated image filenames belonging to 9 classes.
Found 252 validated image filenames belonging to 9 classes.
Found 540 validated image filenames belonging to 9 classes.
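
The three counts reported above follow from the split settings. A short worked example of the arithmetic, assuming the floor and truncation rounding shown here (which reproduces the reported numbers):

```python
import math

n_images = 200 * 9                          # 200 sampled images for each of 9 classes

# train_test_split(train_size=0.7): 70% for training + validation, the rest for testing
n_train_full = math.floor(0.7 * n_images)
n_test = n_images - n_train_full

# validation_split=0.2: 20% of the training portion is held out for validation
n_val = int(0.2 * n_train_full)
n_train = n_train_full - n_val

# batch_size=32: number of batches needed to cover the training subset once
steps_per_epoch = math.ceil(n_train / 32)

print(n_train, n_val, n_test, steps_per_epoch)
```

The last value is the number of batches per epoch, since the final batch is allowed to be smaller than 32 images.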


### Load Pretrained Model

In [33]:
pretrained_model = tf.keras.applications.MobileNetV2(
    input_shape = (224, 224, 3),
    include_top = False,
    weights = 'imagenet',
    pooling = 'avg'
)

pretrained_model.trainable = False

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5
9406464/9406464 ━━━━━━━━━━━━━━━━━━━━ 7s 1us/step


### Training

In [36]:
pretrained_model.summary()

In [39]:
next(train_images)[1]

array([[0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0.
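
Each row above is a one-hot vector with one position per class, as produced by `class_mode = 'categorical'`. The sketch below shows how such an encoding can be built. It assumes class indices are assigned in sorted order of the label names, which should match the default behaviour of `flow_from_dataframe`; `train_images.class_indices` gives the actual mapping. The `one_hot` helper is hypothetical.

```python
# The nine labels present in the dataset, sorted alphabetically (assumed index order)
class_names = sorted([
    'Sea Bass', 'Striped Red Mullet', 'Gilt-Head Bream', 'Red Mullet',
    'Hourse Mackerel', 'Shrimp', 'Black Sea Sprat', 'Red Sea Bream', 'Trout',
])
class_indices = {name: i for i, name in enumerate(class_names)}

def one_hot(label):
    """Return a list with 1.0 at the label's index and 0.0 elsewhere."""
    vec = [0.0] * len(class_names)
    vec[class_indices[label]] = 1.0
    return vec

print(class_indices)
print(one_hot('Trout'))
```

Vectors in this form are what the `categorical_crossentropy` loss compares against the 9-way softmax output.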

In [40]:
inputs = pretrained_model.input

x = tf.keras.layers.Dense(128, activation='relu')(pretrained_model.output)
x = tf.keras.layers.Dense(128, activation='relu')(x)

outputs = tf.keras.layers.Dense(9, activation='softmax')(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs)

model.compile(
    optimizer = 'adam',
    loss = 'categorical_crossentropy',
    metrics = ['accuracy']
)

history = model.fit(
    train_images,
    validation_data = val_images,
    epochs = 100,
    callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor = 'val_loss',
            patience = 3,
            restore_best_weights = True
        )
    ]
)

Epoch 1/100
32/32 ━━━━━━━━━━━━━━━━━━━━ 51s 1s/step - accuracy: 0.7569 - loss: 0.8446 - val_accuracy: 0.9683 - val_loss: 0.1363
Epoch 2/100
32/32 ━━━━━━━━━━━━━━━━━━━━ 41s 1s/step - accuracy: 0.9812 - loss: 0.0884 - val_accuracy: 0.9921 - val_loss: 0.0481
Epoch 3/100
32/32 ━━━━━━━━━━━━━━━━━━━━ 41s 1s/step - accuracy: 0.9940 - loss: 0.0338 - val_accuracy: 0.9960 - val_loss: 0.0334
Epoch 4/100
32/32 ━━━━━━━━━━━━━━━━━━━━ 84s 1s/step - accuracy: 0.9940 - loss: 0.0267 - val_accuracy: 0.9960 - val_loss: 0.0259
Epoch 5/100
32/32 ━━━━━━━━━━━━━━━━━━━━ 43s 1s/step - accuracy: 1.0000 - loss: 0.0081 - val_accuracy: 1.0000 - val_loss: 0.0155
Epoch 6
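
Although `epochs = 100` is requested, the `EarlyStopping` callback ends training once `val_loss` has failed to improve for `patience = 3` consecutive epochs. The function below is a simplified sketch of that rule, not the Keras implementation (it ignores options such as `min_delta` and `baseline`), and the loss values are hypothetical rather than taken from the run above:

```python
def early_stopping_epoch(val_losses, patience=3):
    """Return (stopped_epoch, best_epoch) as 0-based epoch indices.

    stopped_epoch is None if every epoch would run without triggering a stop.
    """
    best_loss = float('inf')
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch

# Hypothetical validation losses for illustration only
hypothetical_losses = [0.50, 0.30, 0.20, 0.25, 0.22, 0.21, 0.19]
stopped, best = early_stopping_epoch(hypothetical_losses)
print(stopped, best)
```

With `restore_best_weights = True`, the model is left with the weights from the best epoch rather than those from the epoch at which training stopped.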

### Results

In [41]:
results = model.evaluate(test_images, verbose=0)

# results[0] is the loss; results[1] is the accuracy metric passed to compile()
print("Test Loss: {:.5f}".format(results[0]))
print("Test Accuracy: {:.2f}%".format(results[1]*100))

Test Loss: 0.01968
Test Accuracy: 99.81%
