## **Analyze a Classification Model Using Labelchecker Data**
Here, we look at the classification performance of a trained model using the small dataset provided.

Here’s an overview of what we’ll cover:

1. **Data Download**: Obtain the example data.
2. **Data Preparation**: Detail the necessary processing steps before analysis.
3. **Model download**: Download the model.
4. **Model loading**: Load the model for analysis.
5. **Data Loading**: Set up data loaders for model prediction.
6. **Model Evaluation**: Assess its performance.

Feel free to replace our example data and model with your own 😎. Let’s get started!

### 0 **import libraries**

In [None]:
# import libraries
import cv2
import json
from pathlib import Path
import requests
import zipfile
from rich import print
from tqdm import tqdm
from typing import Tuple
import pandas as pd
import numpy as np

from plotly import express as px
from plotly import graph_objects as go
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

import tensorflow as tf

from src.schemas.ModelConfig import ModelConfig

### 1. **Data Download**
Let's download the example data and start exploring it!


In [None]:
# name of the dataset; it is also the name of its sub-directory below "data"
dataset_name = "example"

# directory that holds the downloaded and extracted dataset
data_path = Path("data") / dataset_name

# create the dataset directory including missing parents; an existing one is kept
data_path.mkdir(parents=True, exist_ok=True)

In [None]:
# download the data zip file
data_url = "https://zenodo.org/records/14755172/files/data.zip"
data_file = data_path.joinpath("data.zip")

if not data_file.exists():
    print(f"Downloading data from {data_url}")
    # The timeout keeps the cell from hanging forever on a stalled connection;
    # the context manager releases the connection once the cell is done.
    with requests.get(data_url, stream=True, timeout=60) as r:
        if r.status_code == 200:
            # get total file size (0 when the server does not report it)
            total_size = int(r.headers.get("content-length", 0))

            # Stream into a temporary file and rename it afterwards, so that an
            # interrupted download never leaves a truncated data.zip behind
            # that the existence check above would accept on the next run.
            partial_file = data_file.with_name(data_file.name + ".part")
            with open(partial_file, "wb") as f, tqdm(
                desc="Downloading",
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                # `chunk` is deliberately not called `input_data_samples`:
                # that name holds the loaded dataset later in the notebook and
                # must not be overwritten with raw bytes when this cell re-runs.
                for chunk in r.iter_content(chunk_size=1024):
                    size = f.write(chunk)
                    pbar.update(size)
            partial_file.replace(data_file)

            # extract the data with progress bar
            with zipfile.ZipFile(data_file, "r") as zip_ref:
                members = zip_ref.namelist()
                for member in tqdm(members, desc="Extracting"):
                    zip_ref.extract(member, data_path)
            print(f"Data extracted to {data_path}")
        else:
            print(f"Failed to download data: {r.status_code}")
else:
    print(f"Data file {data_file} already exists")


In [None]:
# collect every LabelChecker CSV file found anywhere below the data directory
labelchecker_pattern = "**/LabelChecker_*.csv"
data_files = list(data_path.glob(labelchecker_pattern))
print(f"Found {len(data_files)} data files")

## 2. **Data Preparation**
To clean and prepare the data we:
1. Subset data that has a `LabelTrue` value
2. drop columns with only `missing values`
3. drop columns with `default values`
4. set `image paths`
5. drop columns with `object` data
6. remove labels with fewer than N examples

> For details about the data preprocessing steps, see [train_model.ipynb](./train_model.ipynb)

In [None]:
# check for default values function
def is_default(series: pd.Series) -> bool:
    """Tell whether ``series`` holds exactly one distinct value.

    Missing values count as a value of their own, so a column that contains
    nothing but NaN is reported as "default" too, while an empty column is not.

    Args:
        series (pd.Series): The column to inspect.

    Returns:
        bool: True when the column has exactly one distinct value.
    """
    distinct_count = series.nunique(dropna=False)
    return distinct_count == 1


# check whether a column holds object data; LabelTrue and the image file columns are never flagged
def is_object(
    series: pd.Series,
    columns_to_keep: list[str] = ["LabelTrue", "ImageFilename", "CollageFile"],
) -> bool:
    """Tell whether ``series`` is an object-typed column that may be dropped.

    Args:
        series (pd.Series): The column to inspect; its ``name`` is compared
            against ``columns_to_keep``.
        columns_to_keep (list[str]): Column names that are never reported as
            object columns. The default list is only read, never modified.

    Returns:
        bool: False for a protected column, otherwise whether the dtype of
        the column equals "object".
    """
    protected = series.name in columns_to_keep
    return False if protected else series.dtype == "object"

# drop labels with less than N examples
def drop_labels_with_less_than_examples(data: pd.DataFrame, min_examples: int) -> pd.DataFrame:
    """Keep only the rows whose label occurs at least ``min_examples`` times.

    Args:
        data (pd.DataFrame): Table with a "LabelTrue" column.
        min_examples (int): Minimum number of rows a label needs to be kept.

    Returns:
        pd.DataFrame: The rows of ``data`` that belong to sufficiently
        frequent labels.
    """

    def has_enough_examples(label_rows: pd.DataFrame) -> bool:
        return label_rows.shape[0] >= min_examples

    rows_per_label = data.groupby("LabelTrue")
    return rows_per_label.filter(has_enough_examples)

# build image paths
def build_image_path(df: pd.DataFrame, directory: Path) -> Tuple[bool, list[str]]:
    """
    Builds a list of image paths based on the given DataFrame and directory.

    Single object images (columns "Name" and "ImageFilename") take precedence
    over collage files (column "CollageFile"). Exactly one kind of path is
    returned, so a non-empty result always has one entry per row of ``df`` and
    can be assigned back to a column of ``df``.

    Args:
        df (pd.DataFrame): The DataFrame containing the image filenames and names.
        directory (Path): The directory where the images are located.

    Returns:
        Tuple[bool, list[str]]: A tuple containing a boolean value indicating whether the image paths are for collage files,
        and a list of image paths (POSIX style). When neither kind of image is available the result is ``(True, [])``.

    Raises:
        FileNotFoundError: If any of the image files are missing.
    """

    def resolve(*parts: str) -> str:
        # Join the parts below `directory` and make sure the file exists.
        image_path = directory.joinpath(*parts)
        if not image_path.exists():
            raise FileNotFoundError(f"file {parts[-1]} not found")
        return image_path.as_posix()

    # Single images are stored as <directory>/<Name>/<ImageFilename>.
    has_single_images = (
        "ImageFilename" in df.columns
        and "Name" in df.columns
        and not df["ImageFilename"].isnull().all()
        and not df["Name"].isnull().all()
    )
    if has_single_images:
        # Return right away: also collecting the collage paths would yield a
        # list twice as long as the DataFrame.
        return False, [
            resolve(name, filename)
            for name, filename in zip(df["Name"], df["ImageFilename"])
        ]

    # Collage files are stored directly in <directory>.
    if "CollageFile" in df.columns and not df["CollageFile"].isnull().all():
        return True, [resolve(collage_file) for collage_file in df["CollageFile"]]

    return True, []


def load_and_process_data(
    data_files: list[Path],
) -> pd.DataFrame:
    """
    Load data from the data files and preprocess the data.

    Every CSV file is read, its image file names are replaced by full paths
    (relative to the directory of the CSV file), and all files are combined
    into one table. Afterwards rows without a label, columns without any
    value, columns with a single distinct value and object columns (except
    the label and image file columns) are removed.

    Args:
        data_files (list[Path]): A list of file paths to the data files.

    Returns:
        pd.DataFrame: The preprocessed data.

    Raises:
        FileNotFoundError: If a data file or one of its image files is missing.
        ValueError: If ``data_files`` is empty.
    """
    frames = []
    for data_file in data_files:
        if not data_file.exists():
            raise FileNotFoundError(f"File {data_file} not found")
        df = pd.read_csv(data_file)

        # Build the image paths
        is_collage, image_paths = build_image_path(df, data_file.parent)
        if image_paths:
            if is_collage:
                df["CollageFile"] = image_paths
            else:
                df["ImageFilename"] = image_paths
        frames.append(df)

    if not frames:
        raise ValueError("No data files were provided")

    # Each CSV file brings its own 0..n index; renumber the rows so that the
    # combined table has unique index labels.
    data = pd.concat(frames, ignore_index=True)

    # Drop rows with missing LabelTrue values. Missing values are removed
    # first, because the string accessor is not available on a label column
    # that contains nothing but NaN.
    data = data.dropna(subset=["LabelTrue"])
    data = data.loc[data["LabelTrue"].astype(str).str.len() > 0]

    # Drop columns with all missing values
    data = data.dropna(axis=1, how="all")

    # Drop columns with default values
    data = data.loc[:, ~data.apply(is_default)]

    # Drop all object columns except for the label and image file columns
    data = data.loc[:, ~data.apply(is_object)]
    return data

In [None]:
# load and preprocess all LabelChecker files into a single table
input_data_samples = load_and_process_data(data_files)

# report the size of the table and the columns that survived the preprocessing
sample_count = input_data_samples.shape[0]
column_names = list(input_data_samples.columns)
print(f"the data contains {sample_count} samples")
print(f"the data contains the following columns: {column_names}")

## 3. **Model downloading**

In [None]:
# name of the classification service whose model is analysed
service_name = "ObjectClassification"

# directory that contains the classification services
path_to_service = Path("src", "services", "classification")

# the models of a service live in <path_to_service>/<service_name>/models
model_dir = path_to_service / service_name / "models"
print(f"Model directory: {model_dir}")

# create the model directory including missing parents; an existing one is kept
model_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# download the model zip file
model_url = "https://zenodo.org/records/14755172/files/freshwater_phytoplankton_model.zip"
model_file = model_dir.with_suffix(".zip")

# The archive is deleted after extraction, so its absence says nothing about
# whether the model is already available. Look for an extracted model file
# instead; otherwise the model would be downloaded again on every run.
model_is_available = any(model_dir.glob("**/*.keras"))

if model_is_available:
    print(f"Model already available in {model_dir}")
else:
    print(f"Downloading model from {model_url}")
    # The timeout keeps the cell from hanging forever on a stalled connection;
    # the context manager releases the connection once the cell is done.
    with requests.get(model_url, stream=True, timeout=60) as r:
        if r.status_code == 200:
            # get total file size (0 when the server does not report it)
            total_size = int(r.headers.get("content-length", 0))

            # save the model to a file with progress bar; a leftover archive
            # from an interrupted run is simply overwritten
            with open(model_file, "wb") as f, tqdm(
                desc="Downloading",
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for chunk in r.iter_content(chunk_size=1024):
                    size = f.write(chunk)
                    pbar.update(size)

            # extract the model with progress bar
            with zipfile.ZipFile(model_file, "r") as zip_ref:
                members = zip_ref.namelist()
                for member in tqdm(members, desc="Extracting"):
                    zip_ref.extract(member, model_dir)
            print(f"Model extracted to {model_dir}")

            # remove the zip file
            model_file.unlink()
        else:
            print(f"Failed to download model: {r.status_code}")


## 4. **Model loading**

In [None]:
def load_model_configuration(config_file_path: Path) -> ModelConfig:
    """
    Read a model configuration from a JSON file.

    Args:
        config_file_path (Path): Path to the JSON configuration file.

    Returns:
        ModelConfig: The configuration built from the parsed JSON content.

    Raises:
        ValueError: If parsing the file or building the ModelConfig raises a
            ValueError (this includes malformed JSON). The original error is
            attached as the cause.
    """
    try:
        # JSON files are UTF-8 by convention; do not rely on the platform default
        with open(config_file_path, "r", encoding="utf-8") as file:
            model_config = json.load(file)
        return ModelConfig(model_config=model_config)
    except ValueError as e:
        # chain the original error so that the full traceback stays available
        raise ValueError(
            f"Model configuration file is missing required values: {e}"
        ) from e

In [None]:
# load the model configuration: use the first JSON file found below the model directory
model_config_file = next(model_dir.glob("**/*.json"), None)
if model_config_file is None:
    # fail with a clear message instead of a bare IndexError
    raise FileNotFoundError(f"No model configuration (*.json) found in {model_dir}")
model_config = load_model_configuration(model_config_file)
print(f"Model configuration: {model_config}")

In [None]:
# We need one label encoder that knows every label that can occur: the labels
# found in the data as well as the classes the model can predict (each of the
# two sets may contain labels that the other one lacks)
data_labels = set(input_data_samples["LabelTrue"].unique())
model_labels = set(model_config.Class_names)
target_names = data_labels | model_labels

encoder = LabelEncoder()
encoder.fit(list(target_names))

In [None]:
# replace the textual labels in the data by their integer codes from the shared encoder
# (run this cell only once: afterwards the column holds codes instead of label names)
encoded_true_labels = encoder.transform(input_data_samples["LabelTrue"])
input_data_samples["LabelTrue"] = encoded_true_labels

In [None]:
# helper function to print label counts
def print_label_counts(data: pd.DataFrame, class_names: list[str]):
    """Show a horizontal bar chart with the number of samples per label.

    Args:
        data (pd.DataFrame): Table whose "LabelTrue" column holds values that
            can be used as an index into ``class_names``.
        class_names (list[str]): Label names, looked up by label value.
    """
    # translate every label value to its name and remember how often it occurs
    count_by_name = {
        class_names[label]: count
        for label, count in data["LabelTrue"].value_counts().items()
    }

    # order the labels by ascending count
    ordered = sorted(count_by_name.items(), key=lambda item: item[1])
    names_in_order = [name for name, _ in ordered]
    counts_in_order = [count for _, count in ordered]

    # plot the label counts
    figure = px.bar(
        x=counts_in_order,
        y=names_in_order,
        title="Label Counts",
        orientation="h",
        labels={"x": "Count", "y": "Label"},
        width=800,
        height=1200,
    )
    figure.show()

In [None]:
# list every label known to the shared encoder, then plot how often each one occurs in the data
label_names = encoder.classes_
print(f"the data contains these labels: {label_names}; \na total of {len(label_names)} labels")
print_label_counts(input_data_samples, label_names)

In [None]:
# A second encoder that only knows the classes of the model: not every label
# that occurs in the data is present in the model configuration.
# NOTE(review): LabelEncoder keeps its classes in sorted order; this equals the
# order of the model outputs only if Class_names is sorted the same way - confirm.
model_encoder = LabelEncoder().fit(model_config.Class_names)

In [None]:
# load the model: use the first .keras file found below the model directory
model_file_path = next(model_dir.glob("**/*.keras"), None)
if model_file_path is None:
    # fail with a clear message instead of a bare IndexError
    raise FileNotFoundError(f"No model file (*.keras) found in {model_dir}")
model = tf.keras.models.load_model(model_file_path)
print(f"Model loaded from {model_file_path}")
model.summary()

## 5. **Data loader**

In [None]:
#  read the image file
def decode_image(row: pd.Series, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Load the image that belongs to one data row.

    A row with an "ImageFilename" entry is read as a PNG file. Any other row
    is expected to reference a collage via "CollageFile", from which the
    object is cropped.

    Args:
        row (pd.Series): One data sample, accessed by column name.
            NOTE(review): inside the tf.data pipeline this is presumably a
            mapping of tensors rather than a Series - confirm.
        image_size (tuple[int, int, int]): Model input shape; only the last
            entry (number of channels) is used here.

    Returns:
        tf.Tensor: The decoded, not yet resized image.
    """
    if "ImageFilename" not in row:
        # collage: read the file through OpenCV, then cut out the object
        collage_path = tf.strings.as_string(row["CollageFile"])
        collage = tf.numpy_function(read_tiff, [collage_path], tf.uint8)
        # numpy_function loses the static shape; restore rank and channel count
        collage.set_shape([None, None, 3])
        collage = remove_alpha_channel(collage, image_size=image_size)
        return crop_image(row, collage)

    # single object image stored as PNG
    png_bytes = tf.io.read_file(row["ImageFilename"])
    return tf.io.decode_png(png_bytes, channels=image_size[-1])


# remove the alpha channel
def remove_alpha_channel(image, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Keep only the first ``image_size[-1]`` channels of ``image``.

    Args:
        image: Image indexed as (rows, columns, channels).
        image_size (tuple[int, int, int]): Model input shape; its last entry
            is the number of channels to keep.

    Returns:
        tf.Tensor: The image without the surplus channels (e.g. alpha).
    """
    channel_count = image_size[-1]
    without_alpha = image[:, :, :channel_count]
    return tf.convert_to_tensor(without_alpha)

# read TIFF images
def read_tiff(path_tensor: tf.Tensor):
    """Read an image file with OpenCV and return it as an RGB uint8 array.

    Args:
        path_tensor: The file path as UTF-8 encoded bytes (the value is
            decoded directly, so it must offer ``decode``).

    Returns:
        The image as a uint8 array with the channels in RGB order.

    Raises:
        ValueError: If OpenCV cannot read an image from the path.
    """
    file_path = path_tensor.decode("utf-8")
    bgr_pixels = cv2.imread(file_path, cv2.IMREAD_COLOR)
    if bgr_pixels is None:
        raise ValueError(f"Image not found at path: {file_path}")
    # OpenCV delivers BGR; the rest of the pipeline works with RGB
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    return rgb_pixels.astype(np.uint8)

# crop out the object image from the collage
def crop_image(row: pd.Series, image):
    """Cut the object described by ``row`` out of a collage image.

    Args:
        row (pd.Series): Data sample providing "ImageX", "ImageY", "ImageW"
            and "ImageH" (position and size of the object in the collage).
        image: The collage, indexed as (rows, columns, ...).

    Returns:
        The part of ``image`` that spans rows ImageY..ImageY+ImageH and
        columns ImageX..ImageX+ImageW.
    """
    left = int(tf.squeeze(row["ImageX"]))
    top = int(tf.squeeze(row["ImageY"]))
    width = int(tf.squeeze(row["ImageW"]))
    height = int(tf.squeeze(row["ImageH"]))
    return image[top : top + height, left : left + width]


def resize_image(image, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Resize ``image`` to the height and width given by ``image_size``.

    Args:
        image: The image to resize.
        image_size (tuple[int, int, int]): Model input shape; only the first
            two entries (height, width) are used.

    Returns:
        tf.Tensor: The resized image.
    """
    target_height, target_width = image_size[0], image_size[1]
    return tf.image.resize(image, [target_height, target_width])


# combining all the image processing functions
def get_image(row: pd.Series, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Load the image of one data row and resize it to the model input size.

    Args:
        row (pd.Series): One data sample.
        image_size (tuple[int, int, int]): Model input shape.

    Returns:
        tf.Tensor: The decoded and resized image.
    """
    return resize_image(
        decode_image(row, image_size=image_size),
        image_size=image_size,
    )

In [None]:
# object features
def get_features(row: pd.Series, feature_names: list[str]) -> tf.Tensor:
    """Collect the numeric object features of one data row.

    Args:
        row (pd.Series): One data sample, accessed by feature name.
        feature_names (list[str]): Names of the features, in the order in
            which they appear in the result.

    Returns:
        tf.Tensor: A float64 tensor with one value per feature name.
    """
    feature_values = [float(row[name]) for name in feature_names]
    return tf.convert_to_tensor(feature_values, dtype=tf.float64)

In [None]:
# labels
def get_label(row: pd.Series):
    """Take the "LabelTrue" entry out of ``row`` and return it.

    The entry is removed from ``row``, so ``row`` no longer contains the
    label afterwards.
    """
    label = row.pop("LabelTrue")
    return label

In [None]:
def get_data(
    row: pd.Series,
    image_size: Tuple[int, int, int],
    feature_names: list[str],
):
    """Turn one data row into a ``(model inputs, label)`` pair.

    Args:
        row (pd.Series): One data sample; its "LabelTrue" entry is removed.
        image_size (Tuple[int, int, int]): Model input shape for the image.
        feature_names (list[str]): Names of the numeric features to use.

    Returns:
        A tuple ``((features, image), label)``.
    """
    image_tensor = get_image(row, image_size=image_size)
    feature_tensor = get_features(row, feature_names=feature_names)
    # Note: the model expects the features first and the image second
    model_inputs = (feature_tensor, image_tensor)
    return model_inputs, get_label(row)

In [None]:
batch_size = 22

# create the evaluation dataset: one element per row of the data table, turned
# into ((features, image), label) and grouped into batches. The dataset is not
# shuffled, so predictions can be compared with the labels row by row.
ds = (
    tf.data.Dataset.from_tensor_slices(dict(input_data_samples))
    .map(
        lambda x: get_data(x, image_size=model_config.Input_shape, feature_names=model_config.Features),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(batch_size=batch_size)
)

## 6. **Model evaluation**

In [None]:
# make predictions for every batch of the dataset
predictions = model.predict(ds, verbose=1)
# per sample, take the index of the highest value along axis 1
# (assumes `predictions` is laid out as (samples, classes) - TODO confirm)
predicted_labels = np.argmax(predictions, axis=1)

In [None]:
# re-encode the predicted labels to ensure that class mapping is consistent:
# model output index -> class name (model encoder) -> code of the shared encoder,
# which is the encoding used for the LabelTrue column.
# NOTE(review): this assumes that output index i of the model corresponds to
# model_encoder.classes_[i] (sorted class names) - confirm against the training setup.
inverse_predicted_labels = model_encoder.inverse_transform(predicted_labels)
remapped_predicted_labels = encoder.transform(inverse_predicted_labels)

In [None]:
# set the target names: the names of all labels that occur in the ground truth
# or in the predictions. They must be listed in ascending order of the encoded
# label, because scikit-learn orders the rows/columns of the confusion matrix
# and of the classification report by sorted label value; the iteration order
# of a set is not guaranteed to match that.
present_labels = sorted(
    set(input_data_samples["LabelTrue"].unique()).union(set(remapped_predicted_labels))
)
target_names = [encoder.classes_[i] for i in present_labels]

In [None]:
# share of samples whose predicted label equals the true label
overall_accuracy = accuracy_score(input_data_samples["LabelTrue"], remapped_predicted_labels)
print(f"Overall classification accuracy is: {overall_accuracy}")

In [None]:
# plot confusion matrix
def plot_confusion_matrix(
    true_labels: np.ndarray,
    predicted_labels: np.ndarray,
    class_names: list[str],
    text_size: int = 10,
    normalize: bool = True,
    width: int = 1000,
    height: int = 1000,
):
    """Show the confusion matrix of a classification as a heatmap.

    Args:
        true_labels (np.ndarray): Ground-truth labels.
        predicted_labels (np.ndarray): Labels predicted by the model.
        class_names (list[str]): Tick labels for both axes; they have to be in
            the same order as the rows/columns of the confusion matrix.
        text_size (int): Font size of the numbers inside the cells.
        normalize (bool): If True, the matrix is computed with
            ``normalize="true"``, otherwise without normalization.
        width (int): Width of the figure.
        height (int): Height of the figure.
    """
    normalization = "true" if normalize else None
    cm = confusion_matrix(
        y_true=true_labels,
        y_pred=predicted_labels,
        normalize=normalization,
    )

    # rows of the matrix (y axis) are the true labels, columns (x axis) the predicted ones
    heatmap = go.Heatmap(
        z=cm,
        x=class_names,
        y=class_names,
        colorscale="Viridis",
        showscale=False,
        text=cm,
        texttemplate="%{text:.2f}",
        textfont={"size": text_size},
    )
    fig = go.Figure(data=heatmap)

    layout = {
        "title": "Confusion Matrix",
        "title_x": 0.5,
        "xaxis_title": "Predicted",
        "yaxis_title": "True",
        "autosize": False,
        "width": width,
        "height": height,
    }
    fig.update_layout(**layout)

    fig.show()

In [None]:
# row-normalized confusion matrix of true versus predicted labels
plot_confusion_matrix(
    true_labels=input_data_samples["LabelTrue"],
    predicted_labels=remapped_predicted_labels,
    class_names=target_names,
    normalize=True,
    text_size=10,
    width=1200,
    height=1200,
)

In [None]:
# per-class precision, recall and F1 score; metrics that would require a
# division by zero are reported as 0
report = classification_report(
    y_true=input_data_samples["LabelTrue"],
    y_pred=remapped_predicted_labels,
    target_names=target_names,
    zero_division=0,
)
print(report)
