# **Project Name**    - Fish Image Classification



##### **Project Type**    - Classification
##### **Contribution**    - Individual
##### **Team Member 1 - Aswin K J**

# **Project Summary -**

The project builds a deep learning model for fish image classification and deploys it as a user-friendly web application. The core task is to identify 11 categories of fish and seafood from an image dataset.

The project began with data preparation, in which the images were organized into training, validation, and test sets. To make the model more robust and reduce overfitting (a common issue where a model memorizes the training data instead of learning general features), data augmentation was applied: modified versions of the training images were generated by randomly rotating, flipping, shifting, and zooming them. This artificially expands the dataset and teaches the model to recognize fish from various angles and positions.

Five models were then built and tested to find the most suitable architecture for this task. The first was a custom Convolutional Neural Network (CNN) built from scratch. The other four used transfer learning, in which models already pre-trained to recognize features on the large ImageNet dataset are adapted to a new task. The architectures used for this approach were VGG16, ResNet50, MobileNetV2, and EfficientNetB0. For these models, the pre-trained base layers were kept frozen and only a new custom classifier head was trained on the fish images, which keeps training efficient.

Each model was trained for 10 epochs and then evaluated on the unseen test data. The models were compared on accuracy, loss, and macro-averaged precision, recall, and F1-score. MobileNetV2 performed best, with a test accuracy of 99.34%. VGG16 and the custom CNN also performed well, with test accuracies of 93.32% and 89.74% respectively, while ResNet50 and EfficientNetB0 were less effective in this setup.

With MobileNetV2 identified as the top performer, the project moved to deployment. The best model was saved as best_model.h5 and the class names as class_names.pkl. During deployment, the saved model could not be loaded directly because of problems with how it had been saved, so the model architecture and its weights were loaded separately in the app.py file. The Fish Image Classification app was then successfully deployed on Streamlit Cloud.
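The comparison uses macro-averaged precision, recall and F1-score, taken from the "macro avg" row of scikit-learn's `classification_report`. As a minimal sketch of what that row means (assuming integer class labels and that every class appears; `macro_scores` is a hypothetical helper, not part of the project), the scores are computed per class and then averaged with equal weight, so a small fish category counts as much as a large one:

```python
import numpy as np


def macro_scores(y_true, y_pred, num_classes):
    """Return (precision, recall, F1), each averaged over classes with equal weight."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))  # predicted c, really c
        fp = np.sum((y_pred == c) & (y_true != c))  # predicted c, really another class
        fn = np.sum((y_pred != c) & (y_true == c))  # really c, predicted another class
        # Undefined ratios are set to 0, which is also what scikit-learn reports by default
        p = tp / (tp + fp) if tp + fp > 0 else 0.0
        r = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1)
    # Macro average: plain mean over classes, ignoring how many samples each class has
    return float(np.mean(precisions)), float(np.mean(recalls)), float(np.mean(f1s))


# Toy example with 3 classes (labels are made up for illustration)
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(macro_scores(y_true, y_pred, num_classes=3))
```

Because every class gets the same weight, a model that does well on the large categories but poorly on a rare one shows a lower macro F1 than its overall accuracy would suggest.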

# **GitHub Link -**

https://github.com/aswinkj2006/Fish-Image-Classification

# **Problem Statement**


This project focuses on classifying fish images into multiple categories using deep learning models. The task involves training a CNN from scratch and leveraging transfer learning with pre-trained models to enhance performance. The project also includes saving models for later use and deploying a Streamlit application to predict fish categories from user-uploaded images.

# **General Guidelines** : -  

1.   Well-structured, formatted, and commented code is required.
2.   Exception Handling, Production Grade Code & Deployment Ready Code will be a plus. Those students will be awarded some additional credits.
     
     The additional credits will have advantages over other students during Star Student selection.
       
             [ Note: - Deployment Ready Code means that the whole .ipynb notebook can be executed in one go
                       without a single error being logged. ]

3.   Each and every logic should have proper comments.
4. You may add as many charts as you want. Make sure that for each and every chart, the following questions are answered.
        

```
# Chart visualization code
```
            

*   Why did you pick the specific chart?
*   What is/are the insight(s) found from the chart?
* Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

5. You have to create at least 15 logical & meaningful charts having important insights.


[ Hints : - Do the Visualization in a structured way while following the "UBM" Rule.

U - Univariate Analysis,

B - Bivariate Analysis (Numerical - Categorical, Numerical - Numerical, Categorical - Categorical)

M - Multivariate Analysis
 ]





6. You may add more ML algorithms for model creation. Make sure that for each and every algorithm, the following questions are answered.


*   Explain the ML Model used and its performance using Evaluation metric Score Chart.


*   Cross-Validation & Hyperparameter Tuning

*   Have you seen any improvement? Note down the improvement with the updated Evaluation metric Score Chart.

*   Explain each evaluation metric's indication towards business and the business impact of the ML model used.




















# ***Let's Begin !***

## ***1. Know Your Data***

### Import Libraries

In [None]:
# Import Libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
from tensorflow.keras.applications import VGG16, ResNet50, MobileNetV2, EfficientNetB0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
import pickle
import zipfile
import gdown
import time

### Dataset Loading

In [None]:
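# Load Dataset

# The dataset zip (240+ MB) is hosted on Google Drive and downloaded with gdown.
# FILE_ID must hold the Drive file ID of that zip; its value is not shown in this notebook.
if 'FILE_ID' not in globals():
    raise NameError("Define FILE_ID (the Google Drive file ID of the dataset zip) before running this cell.")

output_path = "data.zip"

# Folder (relative to the current directory) that the zip is extracted into
extract_dir = "data"

# Download dataset
gdown.download(f"https://drive.google.com/uc?id={FILE_ID}", output_path, quiet=False)
print("\n Downloaded Zip File from GDrive")

# Extract dataset
with zipfile.ZipFile(output_path, 'r') as zip_ref:
    zip_ref.extractall(extract_dir)
print("\n Extracted zip into 'data' ")

# Clean up: the zip is no longer needed after extraction
os.remove(output_path)

# The archive has its own top-level 'data' folder, so the splits live under data/data/
DATASET_BASE = os.path.join(extract_dir, "data")

TRAIN_PATH = os.path.join(DATASET_BASE, 'train')
VAL_PATH = os.path.join(DATASET_BASE, 'val')
TEST_PATH = os.path.join(DATASET_BASE, 'test')

# Fail early with a clear message if the extracted layout is not as expected
for split_path in (TRAIN_PATH, VAL_PATH, TEST_PATH):
    if not os.path.isdir(split_path):
        raise FileNotFoundError(f"Expected dataset folder not found: {split_path}")

print(f"Train Path : {TRAIN_PATH}")
print(f"Valid Path : {VAL_PATH}")
print(f"Test Path : {TEST_PATH}")

print("\n Dataset Folders Created Successfully !!")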

In [None]:
# Dataset First Look

In [None]:
# Dataset Rows & Columns count

In [None]:
# Dataset Info

#### Duplicate Values

In [None]:
# Dataset Duplicate Value Count

#### Missing Values/Null Values

In [None]:
# Missing Values/Null Values Count

In [None]:
# Visualizing the missing values

### What did you know about your dataset?

The dataset contains three folders: train, test and val. Each folder holds images of fish and seafood in the same 11 categories:

1. Animal Fish (animal fish)
2. Animal Fish Bass (animal fish bass)
3. Fish Sea Food Black Sea Sprat (fish sea_food black_sea_sprat)
4. Fish Sea Food Gilt Head Bream (fish sea_food gilt_head_bream)
5. Fish Sea Food Hourse Mackerel (fish sea_food hourse_mackerel)
6. Fish Sea Food Red Mullet (fish sea_food red_mullet)
7. Fish Sea Food Red Sea Bream (fish sea_food red_sea_bream)
8. Fish Sea Food Sea Bass (fish sea_food sea_bass)
9. Fish Sea Food Shrimp (fish sea_food shrimp)
10. Fish Sea Food Striped Red Mullet (fish sea_food striped_red_mullet)
11. Fish Sea Food Trout (fish sea_food trout)

The data is loaded this way because the dataset is over 240 MB, so it was uploaded to Google Drive. The notebook downloads the zip from Drive and extracts it into a local `data` folder, from which the images are read. This way, anyone can obtain the files simply by running the notebook.
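To see how many images each split holds per category (useful for the imbalance question later in the template), a small helper could walk the extracted folders. This is a sketch only: `count_images_per_class` and `imbalance_ratio` are hypothetical names, and it assumes the images use .jpg, .jpeg or .png extensions.

```python
from pathlib import Path

# Assumed image extensions; adjust if the dataset uses other formats
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(split_dir):
    """Map each class sub-folder of split_dir to the number of image files it contains."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():  # ignore stray files at the split level
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir() if f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts


def imbalance_ratio(counts):
    """Largest class size divided by the smallest; 1.0 means perfectly balanced."""
    sizes = list(counts.values())
    if min(sizes) == 0:
        return float("inf")  # an empty class is the most extreme imbalance
    return max(sizes) / min(sizes)
```

In the notebook it could be called once per split, for example `count_images_per_class(TRAIN_PATH)`, and the ratio compared across train, val and test.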

In [None]:
# Dataset Columns

In [None]:
# Dataset Describe

### Variables Description

Answer Here

### Check Unique Values for each variable.

In [None]:
# Check Unique Values for each variable.

### Data Wrangling Code

In [None]:
# Write your code to make your dataset analysis ready.

### What all manipulations have you done and insights you found?

Answer Here.

#### Chart - 1

In [None]:
# Chart - 1 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 2

In [None]:
# Chart - 2 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 3

In [None]:
# Chart - 3 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 4

In [None]:
# Chart - 4 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 5

In [None]:
# Chart - 5 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 6

In [None]:
# Chart - 6 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 7

In [None]:
# Chart - 7 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 8

In [None]:
# Chart - 8 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 9

In [None]:
# Chart - 9 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 10

In [None]:
# Chart - 10 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 11

In [None]:
# Chart - 11 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 12

In [None]:
# Chart - 12 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 13

In [None]:
# Chart - 13 visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

##### 3. Will the gained insights help create a positive business impact?
Are there any insights that lead to negative growth? Justify with specific reason.

Answer Here

#### Chart - 14 - Correlation Heatmap

In [None]:
# Correlation Heatmap visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

#### Chart - 15 - Pair Plot

In [None]:
# Pair Plot visualization code

##### 1. Why did you pick the specific chart?

Answer Here.

##### 2. What is/are the insight(s) found from the chart?

Answer Here

### Based on your chart experiments, define three hypothetical statements from the dataset. In the next three questions, perform hypothesis testing to obtain final conclusion about the statements through your code and statistical testing.

Answer Here.

### Hypothetical Statement - 1

#### 1. State Your research hypothesis as a null hypothesis and alternate hypothesis.

Answer Here.

#### 2. Perform an appropriate statistical test.

In [None]:
# Perform Statistical Test to obtain P-Value

##### Which statistical test have you done to obtain P-Value?

Answer Here.

##### Why did you choose the specific statistical test?

Answer Here.

### Hypothetical Statement - 2

#### 1. State Your research hypothesis as a null hypothesis and alternate hypothesis.

Answer Here.

#### 2. Perform an appropriate statistical test.

In [None]:
# Perform Statistical Test to obtain P-Value

##### Which statistical test have you done to obtain P-Value?

Answer Here.

##### Why did you choose the specific statistical test?

Answer Here.

### Hypothetical Statement - 3

#### 1. State Your research hypothesis as a null hypothesis and alternate hypothesis.

Answer Here.

#### 2. Perform an appropriate statistical test.

In [None]:
# Perform Statistical Test to obtain P-Value

##### Which statistical test have you done to obtain P-Value?

Answer Here.

##### Why did you choose the specific statistical test?

Answer Here.

## ***2. Data Augmentation & Engineering***

### Generating Image Data for Models
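The training generator in this section applies random rotations, shifts, shears, zooms and horizontal flips on the fly. As an illustration only (this is not how Keras implements it), two of these transforms can be written in NumPy: a horizontal flip, and a width shift whose exposed columns are filled with the nearest edge pixel, similar in spirit to `ImageDataGenerator`'s default `fill_mode='nearest'`. The helper names are hypothetical, and the integer shift sampling is a simplification of Keras's fractional `width_shift_range`.

```python
import numpy as np


def horizontal_flip(img):
    """Mirror an H x W x C image left-to-right (axis 1 is the width axis)."""
    return img[:, ::-1, :]


def width_shift(img, shift_px):
    """Shift an H x W x C image right by shift_px pixels (negative = left).

    Columns uncovered by the shift are filled with the nearest edge pixel.
    """
    if shift_px == 0:
        return img.copy()
    w = img.shape[1]
    pad = abs(shift_px)
    # Pad both sides with copies of the edge columns, then take a width-w window
    padded = np.pad(img, ((0, 0), (pad, pad), (0, 0)), mode="edge")
    start = pad - shift_px  # output column j comes from original column j - shift_px
    return padded[:, start:start + w, :]


def random_augment(img, rng, width_shift_range=0.2):
    """Flip with probability 0.5, then shift by up to width_shift_range * width pixels."""
    if rng.random() < 0.5:
        img = horizontal_flip(img)
    max_shift = int(width_shift_range * img.shape[1])
    shift = int(rng.integers(-max_shift, max_shift + 1))
    return width_shift(img, shift)
```

Each call to `random_augment(image, np.random.default_rng())` yields a slightly different view of the same fish, which is the effect the generator relies on to reduce overfitting.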

In [None]:
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 10

# Create directories for saving models
os.makedirs('/data/models', exist_ok=True)

print("\n Generating Data for Models to be trained on")

# Data generators with augmentation
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.2
)

val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# Create data generators
train_generator = train_datagen.flow_from_directory(
    TRAIN_PATH,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

validation_generator = val_datagen.flow_from_directory(
    VAL_PATH,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

test_generator = test_datagen.flow_from_directory(
    TEST_PATH,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False
)

# Get class information
NUM_CLASSES = len(train_generator.class_indices)
CLASS_NAMES = list(train_generator.class_indices.keys())

# Clean up class names for better display
CLEAN_CLASS_NAMES = []
for name in CLASS_NAMES:
    clean_name = name.replace('animal_fish_', '').replace('fish_sea_food_', '').replace('_', ' ').title()
    CLEAN_CLASS_NAMES.append(clean_name)

print(f"\n📊 Dataset Info:")
print(f"Number of classes: {NUM_CLASSES}")
print(f"Training samples: {train_generator.samples}")
print(f"Validation samples: {validation_generator.samples}")
print(f"Test samples: {test_generator.samples}")

print(f"\n🏷️ Fish Classes:")
for i, (orig, clean) in enumerate(zip(CLASS_NAMES, CLEAN_CLASS_NAMES), 1):
    print(f"{i:2d}. {clean} ({orig})")

# Save class names
with open('/data/models/class_names.pkl', 'wb') as f:
    pickle.dump(CLEAN_CLASS_NAMES, f)

with open('/data/models/original_class_names.pkl', 'wb') as f:
    pickle.dump(CLASS_NAMES, f)


In [None]:
# Handling Outliers & Outlier treatments

##### What all outlier treatment techniques have you used and why did you use those techniques?

Answer Here.

In [None]:
# Encode your categorical columns

Answer Here.

In [None]:
# Expand Contraction

In [None]:
# Lower Casing

In [None]:
# Remove Punctuations

In [None]:
# Remove URLs & Remove words and digits contain digits

In [None]:
# Remove Stopwords

In [None]:
# Remove White spaces

In [None]:
# Rephrase Text

In [None]:
# Tokenization

In [None]:
# Normalizing Text (i.e., Stemming, Lemmatization etc.)

##### Which text normalization technique have you used and why?

Answer Here.

In [None]:
# POS Tagging

In [None]:
# Vectorizing Text

##### Which text vectorization technique have you used and why?

Answer Here.

In [None]:
# Manipulate Features to minimize feature correlation and create new features

In [None]:
# Select your features wisely to avoid overfitting

##### What all feature selection methods have you used  and why?

Answer Here.

##### Which all features you found important and why?

Answer Here.

In [None]:
# Transform Your data

In [None]:
# Scaling your data

##### Which method have you used to scale your data and why?

##### Do you think that dimensionality reduction is needed? Explain Why?

Answer Here.

In [None]:
# Dimensionality Reduction (If needed)

##### Which dimensionality reduction technique have you used and why? (If dimensionality reduction done on dataset.)

Answer Here.

In [None]:
# Split your data to train and test. Choose Splitting ratio wisely.

##### What data splitting ratio have you used and why?

Answer Here.

##### Do you think the dataset is imbalanced? Explain Why.

Answer Here.

In [None]:
# Handling Imbalanced Dataset (If needed)

##### What technique did you use to handle the imbalance dataset and why? (If needed to be balanced)

Answer Here.

## ***3. ML Model Implementation***

### Creating CNN, Usage and Training of other Models
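Each transfer-learning model below ends in the same small head: `GlobalAveragePooling2D`, `Dropout(0.2)`, `Dense(128, relu)`, `Dropout(0.2)` and a softmax `Dense` layer. A NumPy sketch of that forward pass at inference time (when dropout does nothing) may help show what the frozen base hands to the trainable layers. The weights here are random placeholders rather than trained values, and the helper names are hypothetical; a 7×7×1280 feature map is what MobileNetV2 produces for a 224×224 input with `include_top=False`.

```python
import numpy as np


def global_average_pool(feature_map):
    """(H, W, C) -> (C,): average each channel over all spatial positions."""
    return feature_map.mean(axis=(0, 1))


def relu(x):
    return np.maximum(x, 0.0)


def softmax(logits):
    # Subtract the max before exponentiating for numerical stability
    e = np.exp(logits - logits.max())
    return e / e.sum()


def classifier_head(feature_map, w1, b1, w2, b2):
    """GAP -> Dense(ReLU) -> Dense(softmax); dropout is a no-op at inference time."""
    pooled = global_average_pool(feature_map)
    hidden = relu(pooled @ w1 + b1)
    return softmax(hidden @ w2 + b2)


# Placeholder shapes: 1280 base channels, 128 hidden units, 11 fish classes
rng = np.random.default_rng(0)
features = rng.random((7, 7, 1280))
w1, b1 = rng.normal(size=(1280, 128)) * 0.01, np.zeros(128)
w2, b2 = rng.normal(size=(128, 11)) * 0.01, np.zeros(11)
probs = classifier_head(features, w1, b1, w2, b2)
print(probs.shape, probs.sum())
```

Only `w1`, `b1`, `w2` and `b2` correspond to trainable weights when `base_model.trainable = False`, which is why the transfer-learning models train quickly.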

In [None]:
#Creating Models

print("\n Creating and Training and Evaluating Models")

#1. CNN MODEL

def create_cnn_from_scratch():
    model = Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
    ])
    return model

def create_transfer_learning_model(base_model_name, input_shape=(224, 224, 3)):
    base_models = {
        'VGG16': VGG16,
        'ResNet50': ResNet50,
        'MobileNetV2': MobileNetV2,
        'EfficientNetB0': EfficientNetB0
    }
    
    base_model = base_models[base_model_name](
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )
    
    base_model.trainable = False
    
    model = Sequential([
        base_model,
        GlobalAveragePooling2D(),
        Dropout(0.2),
        Dense(128, activation='relu'),
        Dropout(0.2),
        Dense(NUM_CLASSES, activation='softmax')
    ])
    
    return model

#Training Models 

def train_model(model, model_name, epochs=EPOCHS):
    
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )
    
    callbacks = [
        EarlyStopping(monitor='val_accuracy', patience=10, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-7),
        ModelCheckpoint(
            f'/data/models/{model_name}_best.h5',
            monitor='val_accuracy',
            save_best_only=True,
            mode='max'
        )
    ]
    
    print(f"\n🚀 Training {model_name}...")
    print("="*50)
    
    history = model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
        callbacks=callbacks,
        verbose=1
    )
    
    return history

#Compare Models

def evaluate_model(model, model_name, history):
    
    # Evaluate on validation and test data
    val_loss, val_accuracy = model.evaluate(validation_generator, verbose=0)
    test_loss, test_accuracy = model.evaluate(test_generator, verbose=0)
    
    # Get predictions on test data
    test_generator.reset()
    predictions = model.predict(test_generator)
    predicted_classes = np.argmax(predictions, axis=1)
    true_classes = test_generator.classes
    
    # Classification report
    report = classification_report(
        true_classes, 
        predicted_classes, 
        target_names=CLEAN_CLASS_NAMES,
        output_dict=True
    )
    
    return {
        'model_name': model_name,
        'val_accuracy': val_accuracy,
        'val_loss': val_loss,
        'test_accuracy': test_accuracy,
        'test_loss': test_loss,
        'classification_report': report,
        'history': history.history
    }

def plot_training_history(history, model_name):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
    
    # Accuracy
    ax1.plot(history.history['accuracy'], label='Training Accuracy', color='blue')
    ax1.plot(history.history['val_accuracy'], label='Validation Accuracy', color='red')
    ax1.set_title(f'{model_name} - Accuracy')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Accuracy')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Loss
    ax2.plot(history.history['loss'], label='Training Loss', color='blue')
    ax2.plot(history.history['val_loss'], label='Validation Loss', color='red')
    ax2.set_title(f'{model_name} - Loss')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'/data/models/{model_name}_training_history.png', dpi=300, bbox_inches='tight')
    plt.show()

#MODEL TRAINING LOOP
print("\n🚀 Starting Model Training...")
print("="*60)

# Models to train
models_to_train = [
    ('CNN', create_cnn_from_scratch),
    ('VGG16', lambda: create_transfer_learning_model('VGG16')),
    ('ResNet50', lambda: create_transfer_learning_model('ResNet50')),
    ('MobileNetV2', lambda: create_transfer_learning_model('MobileNetV2')),
    ('EfficientNetB0', lambda: create_transfer_learning_model('EfficientNetB0'))
]

# Store results
results = []
trained_models = {}

# Train all models
for model_name, model_func in models_to_train:
    print(f"\n{'='*60}")
    print(f"🧠 TRAINING: {model_name}")
    print(f"{'='*60}")
    
    # Create model
    model = model_func()
    print(f"📊 Model Parameters: {model.count_params():,}")
    
    # Train model
    history = train_model(model, model_name)
    
    # Evaluate model
    result = evaluate_model(model, model_name, history)
    results.append(result)
    trained_models[model_name] = model
    
    # Plot training history
    plot_training_history(history, model_name)
    
    # Print results
    print(f"\n📈 {model_name} Results:")
    print(f"   Validation Accuracy: {result['val_accuracy']:.4f}")
    print(f"   Test Accuracy: {result['test_accuracy']:.4f}")
    print(f"   Test Loss: {result['test_loss']:.4f}")

#Comparison

# Create comparison DataFrame
comparison_data = []
for result in results:
    comparison_data.append({
        'Model': result['model_name'],
        'Validation Accuracy': result['val_accuracy'],
        'Test Accuracy': result['test_accuracy'],
        'Validation Loss': result['val_loss'],
        'Test Loss': result['test_loss'],
        'Precision (Macro)': result['classification_report']['macro avg']['precision'],
        'Recall (Macro)': result['classification_report']['macro avg']['recall'],
        'F1-Score (Macro)': result['classification_report']['macro avg']['f1-score']
    })

comparison_df = pd.DataFrame(comparison_data)
comparison_df = comparison_df.sort_values('Test Accuracy', ascending=False)

print("\n" + "="*80)
print("🏆 FINAL MODEL COMPARISON RESULTS")
print("="*80)
print(comparison_df.round(4).to_string(index=False))

# Find best model
best_model_name = comparison_df.iloc[0]['Model']
best_model = trained_models[best_model_name]
best_test_accuracy = comparison_df.iloc[0]['Test Accuracy']
best_val_accuracy = comparison_df.iloc[0]['Validation Accuracy']

print(f"\n🏆 BEST MODEL: {best_model_name}")
print(f"🎯 Test Accuracy: {best_test_accuracy:.4f}")
print(f"📊 Validation Accuracy: {best_val_accuracy:.4f}")

# Save the best model
best_model.save('/data/models/best_model.h5')
print(f"\n✅ Best model saved as '/data/models/best_model.h5'")

# Save comparison results
comparison_df.to_csv('/data/models/model_comparison.csv', index=False)

# Save detailed results
with open('/data/models/detailed_results.pkl', 'wb') as f:
    pickle.dump(results, f)

# Model comparison plots
plt.figure(figsize=(18, 12))

# Validation Accuracy
plt.subplot(2, 3, 1)
bars = plt.bar(comparison_df['Model'], comparison_df['Validation Accuracy'], color='lightblue', alpha=0.8)
plt.title('Validation Accuracy Comparison', fontsize=14, fontweight='bold')
plt.ylabel('Accuracy')
plt.xticks(rotation=45)
for i, bar in enumerate(bars):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005, 
             f'{comparison_df.iloc[i]["Validation Accuracy"]:.3f}', 
             ha='center', va='bottom', fontsize=10)

# Test Accuracy
plt.subplot(2, 3, 2)
bars = plt.bar(comparison_df['Model'], comparison_df['Test Accuracy'], color='lightgreen', alpha=0.8)
plt.title('Test Accuracy Comparison', fontsize=14, fontweight='bold')
plt.ylabel('Accuracy')
plt.xticks(rotation=45)
for i, bar in enumerate(bars):
    plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.005, 
             f'{comparison_df.iloc[i]["Test Accuracy"]:.3f}', 
             ha='center', va='bottom', fontsize=10)

# Validation Loss
plt.subplot(2, 3, 3)
bars = plt.bar(comparison_df['Model'], comparison_df['Validation Loss'], color='lightcoral', alpha=0.8)
plt.title('Validation Loss Comparison', fontsize=14, fontweight='bold')
plt.ylabel('Loss')
plt.xticks(rotation=45)

# F1-Score
plt.subplot(2, 3, 4)
bars = plt.bar(comparison_df['Model'], comparison_df['F1-Score (Macro)'], color='lightyellow', alpha=0.8)
plt.title('F1-Score Comparison', fontsize=14, fontweight='bold')
plt.ylabel('F1-Score')
plt.xticks(rotation=45)

# Combined metrics
plt.subplot(2, 3, 5)
metrics = ['Precision (Macro)', 'Recall (Macro)', 'F1-Score (Macro)']
x = np.arange(len(comparison_df))
width = 0.25

for i, metric in enumerate(metrics):
    plt.bar(x + i*width, comparison_df[metric], width, label=metric.split(' ')[0], alpha=0.8)

plt.xlabel('Models')
plt.ylabel('Score')
plt.title('Precision, Recall, F1-Score Comparison', fontsize=14, fontweight='bold')
plt.xticks(x + width, comparison_df['Model'], rotation=45)
plt.legend()

# Accuracy comparison (Val vs Test)
plt.subplot(2, 3, 6)
x = np.arange(len(comparison_df))
width = 0.35

plt.bar(x - width/2, comparison_df['Validation Accuracy'], width, label='Validation', alpha=0.8, color='blue')
plt.bar(x + width/2, comparison_df['Test Accuracy'], width, label='Test', alpha=0.8, color='orange')

plt.xlabel('Models')
plt.ylabel('Accuracy')
plt.title('Validation vs Test Accuracy', fontsize=14, fontweight='bold')
plt.xticks(x, comparison_df['Model'], rotation=45)
plt.legend()

plt.tight_layout()
plt.savefig('/data/models/model_comparison_plots.png', dpi=300, bbox_inches='tight')
plt.show()

# Generate confusion matrix for best model

test_generator.reset()
predictions = best_model.predict(test_generator)
predicted_classes = np.argmax(predictions, axis=1)
true_classes = test_generator.classes

cm = confusion_matrix(true_classes, predicted_classes)

plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=CLEAN_CLASS_NAMES, yticklabels=CLEAN_CLASS_NAMES,
            cbar_kws={'shrink': 0.8})
plt.title(f'Confusion Matrix - {best_model_name}', fontsize=16, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig('/data/models/confusion_matrix_best_model.png', dpi=300, bbox_inches='tight')
plt.show()

#Downloading Results

print("\n" + "="*80)
print("🎉 TRAINING COMPLETE!")
print("="*80)
print(f"🏆 Best Model: {best_model_name}")
print(f"🎯 Test Accuracy: {best_test_accuracy:.4f}")
print(f"📊 Validation Accuracy: {best_val_accuracy:.4f}")

print(f"\n📁 Files saved in '/data/models/':")
print("   ✅ best_model.h5 (best performing model)")
print("   ✅ class_names.pkl (clean class names)")
print("   ✅ original_class_names.pkl (original folder names)")
print("   ✅ model_comparison.csv (comparison results)")
print("   ✅ detailed_results.pkl (detailed results)")
print("   ✅ Training history plots")
print("   ✅ Model comparison plots")
print("   ✅ Confusion matrix")

print(f"\n🚀 Next Steps:")
print("1. Download 'best_model.h5' and 'class_names.pkl'")
print("2. Use these files in your Streamlit app")
print("3. Deploy your fish classification app!")

# Zip all results for easy download
print(f"\n📦 Creating download package...")
import shutil
shutil.make_archive('/data/fish_classification_results', 'zip', '/data/models')
print("✅ Results packaged in: /data/fish_classification_results.zip")

# Show final class mapping
print(f"\n🏷️ FINAL CLASS MAPPING:")
for i, (orig, clean) in enumerate(zip(CLASS_NAMES, CLEAN_CLASS_NAMES)):
    print(f"{i+1:2d}. '{orig}' → '{clean}'")


print(f"\n🎊 Fish Classification Project Complete! 🐟")

In [None]:
# Visualizing evaluation Metric Score chart

##### Which hyperparameter optimization technique have you used and why?

Answer Here.

##### Have you seen any improvement? Note down the improvement with the updated Evaluation metric Score Chart.

Answer Here.

#### 1. Explain the ML Model used and its performance using Evaluation metric Score Chart.

In [None]:
# Visualizing evaluation Metric Score chart

#### 2. Cross-Validation & Hyperparameter Tuning

In [None]:
# ML Model - 1 Implementation with hyperparameter optimization techniques (i.e., GridSearch CV, RandomSearch CV, Bayesian Optimization etc.)

# Fit the Algorithm

# Predict on the model

##### Which hyperparameter optimization technique have you used and why?

Answer Here.

##### Have you seen any improvement? Note down the improvement with the updated Evaluation metric Score Chart.

Answer Here.

#### 3. Explain each evaluation metric's indication towards business and the business impact of the ML model used.

Answer Here.
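The confusion matrix saved by the notebook can be reduced to the precision, recall and F1 scores used in the model comparison. Each metric has a direct business reading:

* **Precision** is how often a predicted species label is correct. Low precision means mislabeled fish reach a buyer or sorting bin.
* **Recall** is how many fish of a species the model actually identifies. Low recall means that species is routinely missed.
* **F1** balances precision and recall.

Below is a minimal NumPy sketch, assuming rows are true classes and columns are predicted classes. The 3-class matrix is made up for illustration and is not a result from this project.

```python
import numpy as np


def per_class_metrics(cm):
    """Return per-class precision, recall and F1 from a confusion matrix.

    Assumes rows are true classes and columns are predicted classes
    (the convention of sklearn.metrics.confusion_matrix).
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)              # correctly classified images per class
    predicted = cm.sum(axis=0)    # images predicted as each class
    actual = cm.sum(axis=1)       # images that truly belong to each class
    # np.divide with `where` avoids division by zero for empty classes
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1


# Toy 3-class matrix for illustration only (not results from this project)
cm = np.array([[8, 2, 0],
               [1, 9, 0],
               [0, 0, 10]])
precision, recall, f1 = per_class_metrics(cm)
accuracy = np.trace(cm) / cm.sum()   # share of all images classified correctly
macro_f1 = f1.mean()                 # unweighted average over classes

print("precision:", precision.round(3))
print("recall:   ", recall.round(3))
print("f1:       ", f1.round(3))
print(f"accuracy: {accuracy:.3f}, macro F1: {macro_f1:.3f}")
```

Macro-averaging weights every species equally. A rare species that is often misclassified therefore lowers the score even when overall accuracy stays high.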

In [None]:
# ML Model - 3 Implementation

# Fit the Algorithm

# Predict on the model

#### 1. Explain the ML Model used and its performance using Evaluation metric Score Chart.

In [None]:
# Visualizing evaluation Metric Score chart

#### 2. Cross-Validation & Hyperparameter Tuning

In [None]:
# ML Model - 3 Implementation with hyperparameter optimization techniques (e.g., GridSearch CV, RandomSearch CV, Bayesian Optimization etc.)

# Fit the Algorithm

# Predict on the model

##### Which hyperparameter optimization technique have you used and why?

Answer Here.

##### Have you seen any improvement? Note down the improvement with updates Evaluation metric Score Chart.

Answer Here.

### 2. Which ML model did you choose from the above created models as your final prediction model and why?

I chose MobileNetV2 because it performed best of all the models tested. Because of several practical constraints, each model was trained for only 10 epochs, yet MobileNetV2 still achieved the highest accuracy and the lowest loss of all the models. A plausible explanation is its lightweight architecture (inverted residual blocks built on depthwise separable convolutions). Combined with ImageNet pre-trained features, this design may have let the new classifier head converge quickly within such a short training budget.
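
MobileNetV2's inverted residual blocks filter spatially with depthwise separable convolutions: a per-channel k x k depthwise convolution followed by a 1 x 1 pointwise convolution. The arithmetic below compares that design with a standard convolution. The channel counts are illustrative, not the network's actual layer sizes. This shows why the architecture is light; it is not a measured explanation of the accuracy result.

```python
def standard_conv_params(k: int, c_in: int, c_out: int) -> int:
    """Weights in a k x k convolution mapping c_in channels to c_out (bias ignored)."""
    return k * k * c_in * c_out


def depthwise_separable_params(k: int, c_in: int, c_out: int) -> int:
    """Weights in a k x k depthwise conv (one filter per input channel)
    followed by a 1 x 1 pointwise conv to c_out channels (bias ignored)."""
    depthwise = k * k * c_in
    pointwise = c_in * c_out
    return depthwise + pointwise


k, c_in, c_out = 3, 32, 64   # illustrative layer sizes
std = standard_conv_params(k, c_in, c_out)
sep = depthwise_separable_params(k, c_in, c_out)
print(f"standard: {std}, separable: {sep}, ratio: {std / sep:.1f}x")
```

With 3 x 3 kernels the saving approaches a factor of about 9 as the output channel count grows. Fewer weights mean cheaper training steps and faster inference in the deployed app.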

## ***8.*** ***Future Work***

### 1. Streamlit App Deployment Code


App Link - https://fish-image-classification.streamlit.app/


In [None]:
'''
import streamlit as st
import tensorflow as tf
import numpy as np
import pickle
from pathlib import Path
from PIL import Image

# Called first so it precedes every other Streamlit command
st.set_page_config(page_title="Fish Species Classifier", layout="centered")

model_json_path = Path("data/models/mobilenetv2_model.json")
model_weights_path = Path("data/models/mobilenetv2.weights.h5")
class_names_path = Path("data/models/class_names.pkl")

if not model_json_path.exists() or not model_weights_path.exists() or not class_names_path.exists():
    raise FileNotFoundError("Model JSON, weights file, or class names file not found in data/models/ folder.")

@st.cache_resource
def load_model_and_classes():
    # Streamlit reruns the whole script on every interaction; caching loads
    # the architecture, weights and class names only once per server process
    with open(model_json_path, "r") as json_file:
        model = tf.keras.models.model_from_json(json_file.read())
    model.load_weights(model_weights_path)
    with open(class_names_path, "rb") as f:
        class_names = pickle.load(f)
    return model, class_names

model, class_names = load_model_and_classes()

st.markdown("<h1 style='text-align: center; color: #1f77b4;'>🐟 Fish Species Classification</h1>", unsafe_allow_html=True)
st.markdown("<p style='text-align: center; font-size:18px;'>Upload an image of a fish to identify its species.</p>", unsafe_allow_html=True)

st.markdown("""
### 🐠 Current Supported Fish Types:

1. **Animal Fish** (`animal fish`)  
2. **Animal Fish Bass** (`animal fish bass`)  
3. **Fish Sea Food Black Sea Sprat** (`fish sea_food black_sea_sprat`)  
4. **Fish Sea Food Gilt Head Bream** (`fish sea_food gilt_head_bream`)  
5. **Fish Sea Food Hourse Mackerel** (`fish sea_food hourse_mackerel`)  
6. **Fish Sea Food Red Mullet** (`fish sea_food red_mullet`)  
7. **Fish Sea Food Red Sea Bream** (`fish sea_food red_sea_bream`)  
8. **Fish Sea Food Sea Bass** (`fish sea_food sea_bass`)  
9. **Fish Sea Food Shrimp** (`fish sea_food shrimp`)  
10. **Fish Sea Food Striped Red Mullet** (`fish sea_food striped_red_mullet`)  
11. **Fish Sea Food Trout** (`fish sea_food trout`)
""")

st.markdown("### 📂 Upload Your Image")
uploaded_file = st.file_uploader("Upload a fish image", type=["jpg", "jpeg", "png"], label_visibility="collapsed")

st.markdown("### 📸 Or Try a Sample Image")
SAMPLES_FOLDER = Path("samples")
sample_images = {
    "Head Bream Fish": SAMPLES_FOLDER / "hb.jpg",
    "Horse Mackerel Fish": SAMPLES_FOLDER / "hm.jpg",
    "Sea Sprat Fish": SAMPLES_FOLDER / "ss.jpg",
    "Striped Red Mullet": SAMPLES_FOLDER / "srm.jpg"
}

sample_choice = st.selectbox("Or pick a sample image:", list(sample_images.keys()))
if st.button("Use Sample Image"):
    # PIL's Image.open accepts a path directly, so no file handle is left open
    uploaded_file = sample_images[sample_choice]

if uploaded_file:
    image = Image.open(uploaded_file).convert("RGB")
    img_resized = image.resize((224, 224))
    img_array = np.array(img_resized) / 255.0
    img_array = np.expand_dims(img_array, axis=0)  # shape (1, 224, 224, 3)

    predictions = model.predict(img_array)[0]
    st.image(image, caption="Uploaded Image", use_column_width=True)

    top_indices = predictions.argsort()[-4:][::-1]  # top 4

    st.markdown("## 🏆 Prediction Result")
    st.markdown(
        f"<h3 style='color: green;'>✅ {class_names[top_indices[0]]} ({predictions[top_indices[0]]*100:.2f}%)</h3>",
        unsafe_allow_html=True
    )

    st.markdown("### 🔍 Other Possible Species")
    for idx in top_indices[1:]:
        st.markdown(f"- **{class_names[idx]}** ({predictions[idx]*100:.2f}%)")

st.markdown("<p style='text-align: center; font-size:14px;'>A Project made by Aswin K J</p>", unsafe_allow_html=True)
'''
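The app ranks species with `predictions.argsort()[-4:][::-1]`. Because `argsort` sorts in ascending order, the last four indices belong to the four largest probabilities, and reversing them puts the best guess first. Below is a minimal standalone sketch of that ranking, using a hypothetical helper `top_k_predictions`, made-up probabilities and placeholder class names, so it runs without the model or Streamlit.

```python
import numpy as np


def top_k_predictions(probs, class_names, k=4):
    """Return the k most likely (class name, percentage) pairs, highest first."""
    probs = np.asarray(probs)
    top_indices = probs.argsort()[-k:][::-1]   # k largest, in descending order
    return [(class_names[i], float(probs[i]) * 100) for i in top_indices]


# Made-up softmax output over five placeholder classes
probs = [0.05, 0.70, 0.10, 0.12, 0.03]
names = ["A", "B", "C", "D", "E"]
result = top_k_predictions(probs, names, k=4)

for name, pct in result:
    print(f"{name}: {pct:.2f}%")
```

The first pair corresponds to the green "Prediction Result" line in the app, and the remaining pairs to "Other Possible Species".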

### ***Congrats! Your model is successfully created and ready for deployment on a live server for a real user interaction !!!***

# **Conclusion**

Based on the analysis documented in this notebook, the project identified the best-performing deep learning architecture for classifying the dataset of 11 fish species. In a comparative study of five models, **MobileNetV2** clearly outperformed the others, achieving a near-perfect **test accuracy of 99.34%**. This result can likely be attributed to transfer learning combined with data augmentation, which helped limit overfitting and allowed the model to generalize well to unseen test images. The project produced a production-ready model saved as `best_model.h5`, along with a `class_names.pkl` file containing the class names. Because of model-loading constraints, a separate `mobilenetv2.weights.h5` weights file was also created. The app (`app.py`) rebuilds the architecture from `mobilenetv2_model.json` and loads these weights at startup. The app was deployed on Streamlit Cloud, completing the project.


### ***Hurrah! You have successfully completed your Machine Learning Capstone Project !!!***