# In-depth Exploratory Data Analysis (EDA) and Model Explainability for Fish Species Classification

This Jupyter notebook demonstrates an in-depth exploratory data analysis (EDA) pipeline and model explainability using `SHAP` for an image classification task. The primary objective of this project is to classify fish species and potentially estimate their weights.

## 1. Installation and Import of Essential Libraries

This section ensures that all necessary Python libraries for the analysis are installed and imported. Libraries like `missingno` are used for visualizing missing values, `seaborn` and `matplotlib` for general visualizations, and `tensorflow` and `keras` for model building. `shap` is the core library for model explainability.

In [1]:
# Check and install libraries
!pip install missingno shap
!pip install tensorflow keras matplotlib numpy seaborn scikit-learn

# Import necessary libraries
import os
import shutil
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import missingno as msno
from collections import Counter

# Image processing and modeling libraries
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential, load_model, Model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications import ResNet50, InceptionV3, MobileNetV2

# SHAP for model explainability
import shap

Collecting numpy
  Downloading numpy-2.1.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 16.3/16.3 MB 42.1 MB/s eta 0:00:00
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.24.4
    Uninstalling numpy-1.24.4:
      Successfully uninstalled numpy-1.24.4
Successfully installed numpy-2.1.3


2025-06-19 21:01:46.662940: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750366906.691250 1772618 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750366906.699130 1772618 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1750366906.721640 1772618 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750366906.721671 1772618 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1750366906.721674 1772618 computation_placer.cc:177] computation placer alr

ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

**Output Interpretation:**
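This cell shows the output of the `pip install` commands followed by the import messages. The log records that `pip` replaced NumPy 1.24.4 with 2.1.3, after which the imports failed with `ValueError: numpy.dtype size changed, may indicate binary incompatibility`. This error typically means that a compiled package was built against a different NumPy major version than the one loaded at runtime. Restarting the kernel after installation and keeping NumPy at a version compatible with the installed TensorFlow build usually resolves it. The TensorFlow lines about cuFFT, cuDNN, and cuBLAS factories already being registered commonly appear at import time and are generally harmless. The remaining cells assume that all imports succeed.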
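
Because a version mismatch can surface as the `ValueError` shown in the log, it can help to print the installed versions before importing anything heavy. The following is a minimal sketch using only the standard library; the helper name `installed_versions` and the choice of packages to check are illustrative assumptions, not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Package names taken from the pip commands in this section.
for name, ver in installed_versions(["numpy", "tensorflow", "shap"]).items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Running this in a fresh kernel, before the imports, makes it easier to see whether NumPy changed between sessions.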

## 2. Dataset Organization and Preparation

This section covers the examination of the dataset's folder structure and the application of data augmentation and preprocessing steps using `ImageDataGenerator` for the training, validation, and test sets.

### 2.1. Creating Dataset Folder Structure and Distributing Images

Before model training can start, the dataset must be divided into training, validation, and test sets and placed into an appropriate folder structure. This section outlines the process of taking all fish images from an `all_images` folder and copying them into the respective subfolders (70% for training, 15% for validation, and the remainder, roughly 15%, for testing). Because the split is applied separately to each class folder, every class keeps approximately the same proportions in all three subsets.

In [None]:
# Functions to organize the dataset
def create_directory(path):
    if not os.path.exists(path):
        os.makedirs(path)
        print(f"Created: {path}")
    else:
        print(f"Already exists: {path}")

def distribute_images(src_dir, train_dir, val_dir, test_dir, train_split=0.7, val_split=0.15):
    class_names = [d for d in os.listdir(src_dir) if os.path.isdir(os.path.join(src_dir, d))]
    
    for class_name in class_names:
        class_src_path = os.path.join(src_dir, class_name)
        
        # Create class directories for each subset
        create_directory(os.path.join(train_dir, class_name))
        create_directory(os.path.join(val_dir, class_name))
        create_directory(os.path.join(test_dir, class_name))
        
        images = [f for f in os.listdir(class_src_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        random.shuffle(images) # Shuffle images
        
        # Calculate split counts
        train_count = int(len(images) * train_split)
        val_count = int(len(images) * val_split)
        
        train_images = images[:train_count]
        val_images = images[train_count:train_count + val_count]
        test_images = images[train_count + val_count:] # Remaining images for test set

        print(f"\nClass: {class_name}")
        print(f"  Training: {len(train_images)} images")
        print(f"  Validation: {len(val_images)} images")
        print(f"  Test: {len(test_images)} images")

        # Copy images
        for img in train_images:
            shutil.copy(os.path.join(class_src_path, img), os.path.join(train_dir, class_name, img))
        for img in val_images:
            shutil.copy(os.path.join(class_src_path, img), os.path.join(val_dir, class_name, img))
        for img in test_images:
            shutil.copy(os.path.join(class_src_path, img), os.path.join(test_dir, class_name, img))

# Define the base dataset path
base_dir = '/home/tessaayv/datascience-weight-estimation/TheFishProject4_v1/datasets/'
all_images_dir = os.path.join(base_dir, 'all_images') # Main folder containing all images

# Define the new folder structure for species classification
train_species_dir = os.path.join(base_dir, 'species/train')
val_species_dir = os.path.join(base_dir, 'species/val')
test_species_dir = os.path.join(base_dir, 'species/test')

# Create main directories
create_directory(train_species_dir)
create_directory(val_species_dir)
create_directory(test_species_dir)

# Distribute images (run only on first execution or if redistribution is needed)
# If you have already manually organized your dataset or it's already structured, do not run this.
# distribute_images(all_images_dir, train_species_dir, val_species_dir, test_species_dir)

**Output Screenshots (Folder Creation Output):**
These outputs confirm that the necessary directories for the training, validation, and test sets were successfully created. If the directories already exist, they will show "Already exists".

![Folder Creation 1](Screenshot%202025-06-19%20155603.png)
![Folder Creation 2](Screenshot%202025-06-19%20155617.png)
![Folder Creation 3](Screenshot%202025-06-19%20155659.png)
![Folder Creation 4](Screenshot%202025-06-19%20155713.png)
![Folder Creation 5](Screenshot%202025-06-19%20155727.png)

**Output Interpretation:**
The series of screenshots `Screenshot 2025-06-19 155603.png` to `Screenshot 2025-06-19 155727.png` illustrate the execution of the `create_directory` function. For each path (`/home/tessaayv/datascience-weight-estimation/TheFishProject4_v1/datasets/species/train`, `/home/tessaayv/datascience-weight-estimation/TheFishProject4_v1/datasets/species/val`, `/home/tessaayv/datascience-weight-estimation/TheFishProject4_v1/datasets/species/test`), the script checks if the directory exists and either creates it or confirms its existence. This is a crucial initial step to ensure the data is organized correctly for the `ImageDataGenerator`.
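
The split arithmetic used by `distribute_images` can be sketched in isolation. This is a minimal illustration, not the notebook's code: the helper `split_filenames`, the fixed `seed`, and the 151-file example folder are assumptions added here to make the split reproducible and easy to check.

```python
import random


def split_filenames(filenames, train_split=0.7, val_split=0.15, seed=42):
    """Shuffle with a fixed seed, then slice into train/val/test lists."""
    names = sorted(filenames)           # sort first so the result does not depend on listing order
    random.Random(seed).shuffle(names)  # local RNG; the global random state is left untouched
    train_count = int(len(names) * train_split)
    val_count = int(len(names) * val_split)
    train = names[:train_count]
    val = names[train_count:train_count + val_count]
    test = names[train_count + val_count:]  # the remainder absorbs the rounding
    return train, val, test


# Hypothetical class folder with 151 images.
files = [f"fish_{i:03d}.jpg" for i in range(151)]
train, val, test = split_filenames(files)
print(len(train), len(val), len(test))  # 105 22 24
```

Since `int()` truncates, the test subset receives whatever is left over, so it can be slightly larger than the validation subset even though both are nominally 15%.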

### 2.2. Folder Structure and File Count Verification

The base directory of the dataset is defined, and the number of files within each subfolder is verified. This step is essential to ensure that the dataset has been loaded and distributed correctly.

In [None]:
# Base dataset path (can be adjusted if the distribution step above was not run or if coming from a different source)
base_dir = '/home/tessaayv/datascience-weight-estimation/TheFishProject4_v1/datasets/' # Use the correct path for your current setup

# Check if directories exist
train_dir = os.path.join(base_dir, 'species/train')
val_dir = os.path.join(base_dir, 'species/val')
test_dir = os.path.join(base_dir, 'species/test')

print(f"Train directory exists: {os.path.exists(train_dir)}")
print(f"Val directory exists: {os.path.exists(val_dir)}")
print(f"Test directory exists: {os.path.exists(test_dir)}")

# Function to count images in a directory
def count_images_in_directory(directory):
    counts = {}
    if os.path.exists(directory):
        for class_name in os.listdir(directory):
            class_path = os.path.join(directory, class_name)
            if os.path.isdir(class_path):
                counts[class_name] = len([f for f in os.listdir(class_path) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
    return counts

train_counts = count_images_in_directory(train_dir)
val_counts = count_images_in_directory(val_dir)
test_counts = count_images_in_directory(test_dir)

print("\nTraining Set Class Distribution:")
for cls, count in train_counts.items():
    print(f"  {cls}: {count} images")

print("\nValidation Set Class Distribution:")
for cls, count in val_counts.items():
    print(f"  {cls}: {count} images")

print("\nTest Set Class Distribution:")
for cls, count in test_counts.items():
    print(f"  {cls}: {count} images")

**Output Screenshot (Class and Sample Counts - Summary):**
This output shows the number of images for each fish species (class) in the training, validation, and test sets. This is important for understanding the balance of the dataset and identifying potential class imbalances.

![Class and Sample Counts Summary](Screenshot%202025-06-19%20155750.png)

**Output Screenshots (Class and Sample Counts - Detailed):**
These screenshots demonstrate the detailed listing of image counts per class within each subset by the `count_images_in_directory` function. This is particularly useful for verifying that the dataset has been correctly split.

![Class Count Detail 1](Screenshot%202025-06-19%20155956.png)
![Class Count Detail 2](Screenshot%202025-06-19%20160016.png)
![Class Count Detail 3](Screenshot%202025-06-19%20160030.png)
![Class Count Detail 4](Screenshot%202025-06-19%20160041.png)

**Output Interpretation:**
The outputs from `Screenshot 2025-06-19 155750.png` to `Screenshot 2025-06-19 160041.png` give an overview of the class distribution across the training, validation, and test sets. For example, `Screenshot 2025-06-19 155750.png` shows top-level counts such as "Training Set Class Distribution: Red_Sea_Bream: 151 images". The subsequent detailed screenshots (`155956.png` to `160041.png`) show the specific counts for each class in the `train_counts`, `val_counts`, and `test_counts` dictionaries. This breakdown confirms that every class is present in all three subsets; whether the classes are balanced has to be judged from the counts themselves.
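
One simple way to turn the per-class counts into a single balance indicator is the ratio between the largest and smallest class. The sketch below assumes a dictionary in the same `{class_name: count}` shape that `count_images_in_directory` returns; the helper name `imbalance_ratio` and the example counts are hypothetical.

```python
def imbalance_ratio(counts):
    """Ratio of the largest class to the smallest class (1.0 means perfectly balanced)."""
    if not counts:
        raise ValueError("counts is empty")
    smallest = min(counts.values())
    if smallest == 0:
        return float("inf")
    return max(counts.values()) / smallest


# Hypothetical counts, not taken from the dataset.
example_counts = {"class_a": 150, "class_b": 120, "class_c": 300}
print(f"Imbalance ratio: {imbalance_ratio(example_counts):.2f}")  # Imbalance ratio: 2.50
```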

### 2.3. Data Augmentation and Preprocessing with ImageDataGenerator

`ImageDataGenerator` is used to prepare image data for the model. This not only applies data augmentation (e.g., rotation, shifting, zooming) to improve the model's generalization capability but also rescales pixels (to the 0-1 range), allowing the model to learn more effectively.
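
To make two of these operations concrete, the following NumPy sketch reproduces what `rescale=1./255` and a horizontal flip do to pixel data. It is only an illustration on a tiny hypothetical array; it does not call `ImageDataGenerator`, and it leaves out the random rotation, shift, shear, and zoom transforms.

```python
import numpy as np

# Hypothetical 2x3 RGB image with 8-bit pixel values (0, 14, 28, ..., 238).
img = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3) * 14

# rescale=1./255 multiplies every pixel by 1/255, mapping [0, 255] to [0, 1].
scaled = img.astype(np.float32) * (1.0 / 255)

# A horizontal flip mirrors the image left to right (axis 1 is the width axis).
flipped = scaled[:, ::-1, :]

print(scaled.shape, float(scaled.min()), float(scaled.max()))
print(np.array_equal(flipped[:, 0, :], scaled[:, -1, :]))  # True
```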

In [None]:
# Image dimensions and batch size
IMG_HEIGHT = 224
IMG_WIDTH = 224
BATCH_SIZE = 32

# ImageDataGenerator definitions for data augmentation and preprocessing
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# Create data flows (data generators)
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

val_generator = val_datagen.flow_from_directory(
    val_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical'
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    class_mode='categorical',
    shuffle=False # Set shuffle to False for test set to maintain order for evaluation
)

num_classes = len(train_generator.class_indices)
print(f"Total number of classes: {num_classes}")

**Output Screenshot (ImageDataGenerator Output):**
This output shows how many images `ImageDataGenerator` found from the training, validation, and test sets and how many classes it identified.

![ImageDataGenerator Output](Screenshot%202025-06-19%20155812.png)

**Output Interpretation:**
The output from `Screenshot 2025-06-19 155812.png` confirms that the `ImageDataGenerator` successfully identified and loaded images from the specified directories. It indicates:
* **Training Generator:** "Found 1464 images belonging to 5 classes." This means the training set has 1464 images distributed among 5 distinct fish species.
* **Validation Generator:** "Found 312 images belonging to 5 classes." Similarly, the validation set contains 312 images across the same 5 classes.
* **Test Generator:** "Found 318 images belonging to 5 classes." The test set comprises 318 images for the 5 classes.
This output verifies that the data loading pipeline is correctly set up, and the number of classes aligns with the problem definition.
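
The reported counts can be checked against the intended 70/15/15 split with a few lines of arithmetic. The sketch below uses the three counts quoted above and the notebook's batch size of 32; the variable names are chosen here for illustration.

```python
import math

# Image counts reported by the three generators.
found = {"train": 1464, "val": 312, "test": 318}
batch_size = 32  # BATCH_SIZE in the notebook

total = sum(found.values())
shares = {subset: n / total for subset, n in found.items()}
full_batches = {subset: n // batch_size for subset, n in found.items()}              # samples // BATCH_SIZE
batches_to_cover = {subset: math.ceil(n / batch_size) for subset, n in found.items()}  # includes the partial batch

for subset, n in found.items():
    print(f"{subset}: {n} images ({shares[subset]:.1%}), "
          f"{full_batches[subset]} full batches, {batches_to_cover[subset]} batches to cover all images")
```

The shares come out at roughly 69.9%, 14.9%, and 15.2%. Integer division drops the final partial batch, so the two batch counts differ by one for each subset.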

## 3. Visual Exploratory Data Analysis (EDA)

This section involves the visual inspection of images within the dataset and the graphical representation of class distributions.

### 3.1. Displaying Sample Images

Displaying a few randomly selected images from the training set provides a quick insight into the content and quality of the dataset.

In [None]:
# Visualize sample images
def plot_sample_images(generator, num_images=5):
    images, labels = next(generator) # Get a batch
    plt.figure(figsize=(15, 5))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        plt.imshow(images[i])
        class_name = list(generator.class_indices.keys())[np.argmax(labels[i])]
        plt.title(f"Class: {class_name}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

print("\nSample images from the training set:")
plot_sample_images(train_generator)

**Output Screenshot (Sample Images):**
This visualization presents sample augmented training images with their corresponding labels.

![Sample Images](Screenshot%202025-06-19%20160645.png)

**Output Interpretation:**
The image `Screenshot 2025-06-19 160645.png` displays 5 sample images from the training set, each labeled with its corresponding fish species (e.g., "Class: Red_Sea_Bream", "Class: Striped_Red_Mullet"). This visualization helps to visually inspect the data quality, diversity, and the effects of data augmentation (e.g., variations in orientation, size, and position of the fish). It confirms that the images are correctly loaded and associated with their respective classes.
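
The label shown above each image is recovered from a one-hot vector by finding the position of the largest value and looking that index up in the class mapping. A minimal sketch of that lookup with an explicit inverse mapping follows; the helper `decode_one_hot` and the three-class mapping are hypothetical (the real `class_indices` has five entries).

```python
def decode_one_hot(label_vector, class_indices):
    """Map a one-hot (or probability) vector to its class name via an inverted class_indices."""
    index_to_class = {index: name for name, index in class_indices.items()}
    best_index = max(range(len(label_vector)), key=lambda i: label_vector[i])
    return index_to_class[best_index]


# Hypothetical mapping in the {class_name: index} form of generator.class_indices.
class_indices = {"Horse_Mackerel": 0, "Red_Sea_Bream": 1, "Striped_Red_Mullet": 2}
print(decode_one_hot([0.0, 1.0, 0.0], class_indices))  # Red_Sea_Bream
```

Inverting the dictionary makes the index-to-name relationship explicit instead of relying on the order of the dictionary's keys.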

### 3.2. Visualizing Class Distributions

Visualizing class distributions in the training, validation, and test sets with bar plots helps identify potential imbalances in the dataset.

In [None]:
# Function to visualize class distributions
def plot_class_distribution(directory, title):
    counts = count_images_in_directory(directory)
    if not counts:
        print(f"Warning: No images found in {directory}.")
        return

    classes = list(counts.keys())
    values = list(counts.values())

    plt.figure(figsize=(12, 6))
    sns.barplot(x=classes, y=values, palette='viridis')
    plt.title(f'{title} Class Distribution')
    plt.xlabel('Class')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

In [None]:
print("\nClass Distributions:")
plot_class_distribution(train_dir, 'Training Set')
plot_class_distribution(val_dir, 'Validation Set')
plot_class_distribution(test_dir, 'Test Set')

**Output Screenshot (Class Distributions - Training Set):**
This graph shows how many images each fish species (class) has in the training set.

![Training Set Class Distribution](Screenshot%202025-06-19%20160758.png)

**Output Screenshot (Class Distributions - Validation Set):**
This graph shows how many images each fish species (class) has in the validation set.

![Validation Set Class Distribution](Screenshot%202025-06-19%20160811.png)

**Output Screenshot (Class Distributions - Test Set):**
This graph shows how many images each fish species (class) has in the test set.

![Test Set Class Distribution](Screenshot%202025-06-19%20160824.png)

**Output Interpretation:**
The series of bar plots (from `Screenshot 2025-06-19 160758.png` to `Screenshot 2025-06-19 160824.png`) visually represent the distribution of images across different fish species in the training, validation, and test sets.
* **Training Set (160758.png):** Shows the image counts for each species in the training data (e.g., Red Sea Bream, Striped Red Mullet, Horse Mackerel). This helps identify if any class is significantly over- or under-represented.
* **Validation Set (160811.png):** Presents the same distribution for the validation set, which is crucial for monitoring model performance during training without directly influencing it.
* **Test Set (160824.png):** Displays the distribution for the test set, ensuring that the final evaluation is performed on a representative sample of unseen data.
Overall, these graphs help in understanding data balance, which is important for preventing bias in model training and ensuring robust evaluation.
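
If the bar plots reveal an imbalance, one common remedy is to weight each class inversely to its frequency. The sketch below uses the widely used "balanced" heuristic, `total / (n_classes * count)`; the helper name and the example counts are hypothetical, and the notebook itself does not apply class weights.

```python
def balanced_class_weights(counts):
    """weight = total / (n_classes * count), so rarer classes get larger weights."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {name: total / (n_classes * count) for name, count in counts.items()}


# Hypothetical counts, not taken from the dataset.
example_counts = {"class_a": 300, "class_b": 150, "class_c": 150}
weights = balanced_class_weights(example_counts)
for name, weight in weights.items():
    print(f"{name}: {weight:.3f}")
```

Keras's `model.fit` accepts a `class_weight` dictionary keyed by class index, so the names would first have to be mapped to indices through `train_generator.class_indices`.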

## 4. Building and Training the Model with Transfer Learning

In this section, an image classification model is built and trained using transfer learning with a pre-trained neural network (ResNet50).

### 4.1. Loading and Configuring the ResNet50 Model

The top layers of the ResNet50 model are removed, and new classification layers are added. This allows for the creation of a specialized model for the fish species classification task.

In [None]:
# Load pre-trained model for transfer learning (ResNet50)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(IMG_HEIGHT, IMG_WIDTH, 3))

# Freeze the base model (to prevent its weights from changing)
base_model.trainable = False

# Add new classification layers
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

**Output Screenshot (Model Summary):**
This output displays the layers, output shapes, and parameter counts of the constructed neural network model.

![Model Summary](Screenshot%202025-06-19%20160833.png)

**Output Interpretation:**
The screenshot `Screenshot 2025-06-19 160833.png` shows the `model.summary()` output. This provides a detailed breakdown of the model architecture, including:
* **Layer (type) and Output Shape:** Each layer in the `Sequential` model is listed, along with its type (e.g., `ResNet50`, `GlobalAveragePooling2D`, `Dense`, `Dropout`) and the shape of its output tensor. This helps in understanding the flow of data through the network.
* **Param #:** The total number of parameters in each layer, whether trainable or not. Because the base model is frozen (`base_model.trainable = False`), the `resnet50` parameters are counted under "Non-trainable params" in the summary footer, as intended for transfer learning. All trainable parameters belong to the newly added `dense` layers, which learn to classify the fish species.
This summary is essential for verifying that the model has been constructed correctly and that transfer learning has been applied as expected.
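
The trainable parameter count of the new classification head can be verified by hand. The sketch below assumes the 2048-channel feature vector that ResNet50 produces after global average pooling and the five classes reported by the generators; `dense_params` is a helper introduced here for illustration.

```python
def dense_params(inputs, units):
    """A Dense layer has one weight per input-unit pair plus one bias per unit."""
    return inputs * units + units


resnet50_features = 2048  # channels of ResNet50's final feature map, kept by GlobalAveragePooling2D
hidden_units = 256        # Dense(256, activation='relu')
num_classes = 5           # number of classes reported by the generators

# GlobalAveragePooling2D and Dropout have no parameters of their own.
head_params = dense_params(resnet50_features, hidden_units) + dense_params(hidden_units, num_classes)
print(head_params)  # 525829
```

If the "Trainable params" line in the summary footer matches this figure, the base model is frozen as intended.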

### 4.2. Model Training

The model is trained on the training and validation sets using the defined `ImageDataGenerator`s. Callbacks such as `EarlyStopping`, `ModelCheckpoint`, and `ReduceLROnPlateau` are used to optimize the training process.

In [None]:
# Callbacks
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
model_checkpoint = ModelCheckpoint('best_fish_species_model.h5', monitor='val_accuracy', save_best_only=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.00001)

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // BATCH_SIZE,
    epochs=50,
    validation_data=val_generator,
    validation_steps=val_generator.samples // BATCH_SIZE,
    callbacks=[early_stopping, model_checkpoint, reduce_lr]
)

**Output Interpretation:**
The output of the `model.fit()` command is a verbose log of the training process across epochs. It typically shows:
* **Epoch number:** The current training epoch.
* **Loss and Accuracy for Training Data:** The `loss` (categorical cross-entropy) and `accuracy` on the training set for that epoch.
* **Validation Loss and Accuracy:** The `val_loss` and `val_accuracy` on the validation set, which are crucial for monitoring overfitting.
* **Time taken per epoch:** How long each epoch took to complete.
* **Callbacks in action:** Messages from `EarlyStopping` (e.g., "Restoring best model weights from the end of the best epoch.") or `ReduceLROnPlateau` (e.g., "Epoch X: ReduceLROnPlateau reducing learning rate to...").
This output is vital for understanding the model's learning progress, identifying convergence, and detecting signs of overfitting or underfitting.
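
The behaviour of the two loss-based callbacks can be approximated with plain Python. This is a simplified model, not Keras's implementation: it ignores options such as `min_delta` and `cooldown`, the validation losses are hypothetical, and the starting learning rate of `1e-3` is Adam's default in Keras.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which patience runs out, or None if it never does."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


def reduced_learning_rates(initial_lr, factor, min_lr, reductions):
    """Learning rate after each successive plateau reduction, clipped at min_lr."""
    rates = [initial_lr]
    for _ in range(reductions):
        rates.append(max(rates[-1] * factor, min_lr))
    return rates


hypothetical_val_losses = [1.00, 0.80, 0.70, 0.75, 0.72, 0.71]
print(early_stopping_epoch(hypothetical_val_losses, patience=3))  # 6
print(reduced_learning_rates(1e-3, factor=0.2, min_lr=1e-5, reductions=3))
```

With `factor=0.2` and `min_lr=0.00001`, the third reduction would fall below the floor and is clipped to `1e-5`, so further reductions have no effect.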

## 5. Model Performance Evaluation

After training, the model's performance is analyzed by plotting training and validation loss/accuracy graphs and evaluating it on the test set.

### 5.1. Visualizing Training and Validation Metrics

Plotting the model's learning curves (loss and accuracy) over the training process helps identify issues such as overfitting or underfitting.

In [None]:
# Plot training and validation results
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()

**Output Interpretation:**
This section would ideally show two plots: one for "Training and Validation Accuracy" over epochs and another for "Training and Validation Loss" over epochs.
* **Accuracy Plot:** We would expect to see both training and validation accuracy increasing over epochs. Ideally, they should converge or stay close to each other. A significant gap where training accuracy is much higher than validation accuracy indicates overfitting.
* **Loss Plot:** Similarly, both training and validation loss should decrease over epochs. If validation loss starts increasing while training loss continues to decrease, it's another strong sign of overfitting.
These plots are critical for diagnostics during model development.
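
The same diagnostics can be read from the numbers behind the curves. The sketch below summarizes a dictionary shaped like `history.history`; the helper `summarize_history` and all metric values are hypothetical and are not results from this notebook.

```python
def summarize_history(history):
    """Report the epoch with the lowest validation loss and the final train/validation accuracy gap."""
    val_loss = history["val_loss"]
    best_index = min(range(len(val_loss)), key=lambda i: val_loss[i])
    gap = history["accuracy"][-1] - history["val_accuracy"][-1]
    return {
        "best_epoch": best_index + 1,  # 1-based, as in the Keras training log
        "best_val_loss": val_loss[best_index],
        "final_accuracy_gap": gap,
    }


# Hypothetical metrics with the same keys as history.history.
hypothetical_history = {
    "accuracy": [0.60, 0.75, 0.85, 0.92],
    "val_accuracy": [0.58, 0.72, 0.80, 0.79],
    "loss": [1.10, 0.70, 0.45, 0.25],
    "val_loss": [1.15, 0.78, 0.60, 0.66],
}
summary = summarize_history(hypothetical_history)
print(summary)
```

In this made-up run, validation loss bottoms out at epoch 3 while training accuracy keeps rising, which is the overfitting pattern described above.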

### 5.2. Model Evaluation on the Test Set

A final evaluation is performed on the test set to measure the model's generalization capability.

In [None]:
# Load the best model
best_model = load_model('best_fish_species_model.h5')

# Evaluate on the test set
test_loss, test_accuracy = best_model.evaluate(test_generator)
print(f"\nTest Set Loss: {test_loss:.4f}")
print(f"Test Set Accuracy: {test_accuracy:.4f}")

**Output Screenshot (Test Set Evaluation):**
This output displays the final accuracy and loss values of the model on the unseen test set.

![Test Set Evaluation](Screenshot%202025-06-19%20160844.png)

**Output Interpretation:**
The screenshot `Screenshot 2025-06-19 160844.png` shows the result of the `best_model.evaluate(test_generator)` call.
* **`20/20 [==============================] - 5s 237ms/step - loss: 0.2882 - accuracy: 0.9088`**: This line indicates the progress of the evaluation (20 batches processed out of 20) and the final metrics.
* **`Test Set Loss: 0.2882`**: This is the final loss value on the test set. A lower loss indicates better model performance.
* **`Test Set Accuracy: 0.9088`**: This is the final accuracy on the test set, meaning the model correctly classified approximately 90.88% of the images in the unseen test dataset.
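This result is a key indicator of the model's ability to generalize to new, unseen data. With five classes, random guessing would reach about 20% accuracy if the classes were balanced, so ~91% shows that the model has learned discriminative features. Accuracy alone does not show which species are confused with one another, so per-class metrics are a useful complement.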
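
Per-class recall can be derived from a confusion matrix in a few lines. The sketch below uses short hypothetical label lists; in the notebook, the true labels could plausibly come from `test_generator.classes` and the predictions from `np.argmax(best_model.predict(test_generator), axis=1)`, but that wiring is an assumption and is not shown in the original.

```python
from collections import Counter


def per_class_recall(y_true, y_pred):
    """Fraction of each true class that was predicted correctly."""
    pair_counts = Counter(zip(y_true, y_pred))  # confusion matrix as {(true, predicted): count}
    totals = Counter(y_true)
    return {cls: pair_counts[(cls, cls)] / totals[cls] for cls in sorted(totals)}


# Hypothetical class indices for ten test images.
y_true = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 2, 2, 2, 0]

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
recall = per_class_recall(y_true, y_pred)
print(f"accuracy: {accuracy:.2f}")
for cls, value in recall.items():
    print(f"class {cls}: recall {value:.2f}")
```

Here the overall accuracy is 0.80, yet class 0 is recognized only two times out of three, which the single accuracy figure hides.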

## 6. Model Explainability (SHAP)

Using SHAP (SHapley Additive exPlanations), we visualize which pixels or regions contribute more to an image's classification decision, thereby explaining the model's predictions.

### 6.1. Loading Sample Images and Defining SHAP Explainer

Sample images to be used in the SHAP analysis are loaded, and a `shap.Explainer` object is created.

In [None]:
# Function to load sample images from a directory
def load_images_from_directory(directory_path, target_size=(224, 224), max_images_per_class=1):
    images = []
    class_names = [d for d in os.listdir(directory_path) if os.path.isdir(os.path.join(directory_path, d))]
    
    # Get a specific number of images from each class
    for class_name in class_names:
        class_path = os.path.join(directory_path, class_name)
        if os.path.isdir(class_path):
            current_images_in_class = []
            for fname in os.listdir(class_path):
                if fname.lower().endswith(('.jpg', '.jpeg', '.png')):
                    current_images_in_class.append(os.path.join(class_path, fname))
            
            # Randomly select max_images_per_class
            random.shuffle(current_images_in_class)
            for i, img_path in enumerate(current_images_in_class):
                if i >= max_images_per_class:
                    break
                img = image.load_img(img_path, target_size=target_size)
                img_array = image.img_to_array(img) / 255.0  # Normalize
                images.append(img_array)
    return np.array(images)

# Load sample images from the test set for SHAP analysis (e.g., 1 image per class)
sample_images = load_images_from_directory(test_dir, max_images_per_class=1, target_size=(IMG_HEIGHT, IMG_WIDTH))

# Define masker for SHAP
# An inpainting method such as "inpaint_telea" fills masked regions from the
# surrounding pixels; it relies on OpenCV being installed.
masker = shap.maskers.Image("inpaint_telea", sample_images[0].shape)

# Create SHAP Explainer
# The model's prediction function and the masker are provided to the SHAP Explainer.
# With an Image masker, shap.Explainer selects the model-agnostic partition
# algorithm by default, so only the prediction function is needed.
explainer = shap.Explainer(best_model.predict, masker)

# Compute SHAP values
# Calculate SHAP values for each image in the sample_images array
shap_values = explainer(sample_images)

**Output Interpretation:**
This section primarily runs the SHAP calculation, which doesn't produce direct console output unless there are warnings or errors. The critical part is the creation of `shap_values`, which contain the attribution scores for each pixel in the input images, indicating their contribution to the model's output for each class. The setup of the `load_images_from_directory` function also shows how sample images are chosen for explanation, ensuring diversity across classes if `max_images_per_class` is set appropriately.
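
The idea behind the attribution scores can be illustrated with an exact Shapley computation on a toy problem. The sketch below is not how the `shap` library processes images (it approximates these values); the three "regions", the `toy_score` function, and all numbers are made up to show how a prediction is divided among features.

```python
from itertools import combinations
from math import factorial


def shapley_values(players, value):
    """Exact Shapley values: weighted average marginal contribution over all coalitions."""
    n = len(players)
    phi = {}
    for p in players:
        others = [q for q in players if q != p]
        total = 0.0
        for k in range(n):
            for subset in combinations(others, k):
                s = frozenset(subset)
                weight = factorial(k) * factorial(n - k - 1) / factorial(n)
                total += weight * (value(s | {p}) - value(s))
        phi[p] = total
    return phi


def toy_score(visible):
    """Hypothetical class score as a function of which image regions are unmasked."""
    score = 0.2  # score when every region is masked
    score += 0.5 if "body" in visible else 0.0
    score += 0.2 if "fin" in visible else 0.0
    score += 0.1 if {"body", "fin"} <= visible else 0.0  # interaction between body and fin
    score -= 0.05 if "background" in visible else 0.0
    return score


regions = ["body", "fin", "background"]
phi = shapley_values(regions, toy_score)
for region, contribution in phi.items():
    print(f"{region}: {contribution:+.3f}")
```

The contributions sum to the difference between the score with every region visible and the score with every region masked, which is the additivity property that gives SHAP its name.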

### 6.2. Visualizing SHAP Values

SHAP values are visualized as heatmaps, showing the contribution of each pixel to the model's output. This allows us to understand which visual features the model focuses on.

In [None]:
# Visualize SHAP values
# Uses `sample_images` and `shap_values` for visualization.
# shap_values.values is assumed to have shape
# (n_images, height, width, channels, n_classes).
# shap.image_plot expects a list with one array per class, each of shape
# (n_images, height, width, channels), so the class axis is split out first.
n_outputs = shap_values.values.shape[-1]
for i in range(len(sample_images)):
    print(f"\n--- SHAP Explanation for Image {i+1} ---")
    per_class_values = [shap_values.values[i:i+1, ..., k] for k in range(n_outputs)]
    shap.image_plot(per_class_values, sample_images[i:i+1])  # Slicing keeps the batch dimension

**Output Screenshot (SHAP Explanation - Sample 1):**
This visualization displays the SHAP values (importance maps) for the first sample image. Red regions indicate pixels that positively contribute to the model's prediction for a specific class, while blue regions indicate negative contributions.

![SHAP Explanation - Sample 1](Screenshot%202025-06-19%20160904.png)

**Output Screenshot (SHAP Explanation - Sample 2):**
This visualization displays the SHAP values (importance maps) for the second sample image. Similarly, red and blue regions indicate pixel contributions.

![SHAP Explanation - Sample 2](Screenshot%202025-06-19%20160916.png)

**Output Interpretation:**
The screenshots `Screenshot 2025-06-19 160904.png` and `Screenshot 2025-06-19 160916.png` are crucial for model interpretability. They show:
* **Original Image:** The initial image that was input to the model.
* **Prediction and SHAP Values:** For each image, the model's prediction (e.g., "Predicted: Red_Sea_Bream") is shown, along with a heatmap overlay.
* **Heatmap (Red/Blue):**
    * **Red areas** highlight pixels that have a strong *positive* influence on the model's prediction for the predicted class. These are the features the model found most important for its decision. For example, in a fish image, the outline of the fish, its fins, or specific patterns might be highlighted in red.
    * **Blue areas** highlight pixels that have a strong *negative* influence. These pixels, if present or different, would push the model's prediction away from the chosen class.
By examining these SHAP plots, one can understand *why* the model made a particular classification. For instance, if the model classified an image as "Red Sea Bream" and the SHAP plot highlights the distinct shape of the fish or its scales, it indicates that the model is using relevant visual cues, thus increasing trust in its predictions. This is invaluable for debugging, building confidence, and understanding potential biases in the model.
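
A visual impression of "the model looks at the fish" can be backed by a rough number: the share of positive attribution that falls inside a region of interest. The NumPy sketch below uses random stand-in values and a hypothetical square mask; the assumed array shape `(height, width, channels, n_classes)` and the predicted class index are illustrative, not outputs of this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the SHAP values of one image: (height, width, channels, n_classes).
fake_shap = rng.normal(size=(8, 8, 3, 5))
predicted_class = 2  # hypothetical predicted class index

# Sum over the channel axis to get one attribution per pixel for the chosen class.
pixel_attr = fake_shap[..., predicted_class].sum(axis=-1)

# Hypothetical region of interest, e.g. a rough box around the fish.
mask = np.zeros((8, 8), dtype=bool)
mask[2:6, 2:6] = True

# Share of the total positive attribution that lies inside the region.
positive = np.clip(pixel_attr, 0, None)
share_inside = positive[mask].sum() / positive.sum()
print(pixel_attr.shape, f"{share_inside:.2f}")
```

With real SHAP values and a mask that follows the fish outline, a high share would support the claim that the model relies on the fish rather than the background.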