# Wood Species Identification Model - Required Modules Explanation

This document outlines the modules required to build a wood species identification model and explains what each one is used for.

## TensorFlow and Keras

- **TensorFlow**: An open-source machine learning framework used for both research and production. TensorFlow offers APIs for beginners and experts to develop for desktop, mobile, web, and cloud.
- **Keras**: A high-level neural networks API, written in Python and capable of running on top of TensorFlow. It enables fast experimentation with deep neural networks.

### Key Components from TensorFlow and Keras:

- **`Sequential`**: A linear stack of layers.
- **`Model`**: The base class for Keras models.
- **`Dense`, `Conv2D`, `MaxPool2D`, `Flatten`, `Dropout`**: Layers used to build the network. `MaxPool2D` is an alias of `MaxPooling2D`, the name used in the model code in this notebook.
- **`RandomFlip`, `RandomRotation`, `RandomZoom`, `Rescaling`**: Data augmentation layers to help the model generalize better.
- **`SparseCategoricalCrossentropy`**: A loss function for multi-class classification with integer (not one-hot) labels.
- **`SparseCategoricalAccuracy`**: A metric that computes how often the predicted class matches the integer label.
- **`Adam`**: An optimizer for training neural networks.
- **`EarlyStopping`, `ModelCheckpoint`, `TensorBoard`**: Callbacks for monitoring and improving training.
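To make the loss concrete, here is a minimal pure-Python sketch of what sparse categorical cross-entropy computes for a single sample when the model outputs logits. The function name and the example logits are illustrative assumptions; this is not the Keras implementation, only the underlying formula `-log(softmax(logits)[label])`.

```python
import math
from typing import List


def sparse_ce_from_logits(logits: List[float], label: int) -> float:
    """Cross-entropy of one sample: -log(softmax(logits)[label])."""
    # Subtract the largest logit before exponentiating for numerical stability.
    peak = max(logits)
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]


# Hypothetical logits for a 3-class problem.
example_logits = [2.0, 0.5, -1.0]
loss_likely = sparse_ce_from_logits(example_logits, label=0)
loss_unlikely = sparse_ce_from_logits(example_logits, label=2)
print(loss_likely, loss_unlikely)
```

The loss is small when the integer label points at the largest logit and grows as the labelled class becomes less likely under the model.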

## Keras Tuner

- **Keras Tuner**: A library for hyperparameter tuning for Keras models. It helps to pick the optimal set of hyperparameters for your model.

### Key Components:

- **`Hyperband`**: A tuner class that implements the Hyperband search algorithm.
- **`HyperParameters`**: A class for defining and managing hyperparameters.

## TensorFlow Data

- **`AUTOTUNE`**: A special value that can be used to indicate that the dataset should tune the number of elements to prefetch dynamically at runtime.
- **`Dataset`**: Provides a way to create and manipulate sequences of data items.

## Custom Modules

- **`src.Dataset` (WSI_Dataset)**: A custom module for loading and preprocessing the wood species identification dataset.
- **`ModelContext`, `ModelFactory`**: Custom modules for managing model lifecycle and factory patterns for creating models.


In [None]:

from keras_tuner import (
    Hyperband,
    HyperParameters,
)
from loguru import logger
from pathlib import Path
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    ConfusionMatrixDisplay,
    f1_score,
    precision_score,
    recall_score,
)
from tensorflow.data import (
    AUTOTUNE,
    Dataset,
)
from tensorflow import keras
from tensorflow.keras import (
    Model,
    regularizers,
)
from tensorflow.keras.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    TensorBoard,
)
from tensorflow.keras.layers import (
    Conv2D,
    Dense,
    Dropout,
    Flatten,
    MaxPool2D,
    MaxPooling2D,
    RandomFlip,
    RandomRotation,
    RandomZoom,
    Rescaling,
)
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.models import (
    clone_model,
    Sequential,
)
from tensorflow.keras.optimizers import Adam
from typing import (
    Any,
    Callable,
    List,
    Optional,
)

from src import (
    Dataset as WSI_Dataset,
    ModelContext,
    ModelFactory,
)

import itertools
import keras_tuner as kt
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import warnings

%load_ext tensorboard

warnings.simplefilter(action="ignore", category=FutureWarning)

os.environ["TF_GPU_ALLOCATOR"] = "cuda_malloc_async"


# Wood Species Identification Model - Constants and Callbacks Explanation

## Constants

### Image Dimensions
- `IMG_WIDTH` and `IMG_HEIGHT`: These constants define the width and height of the images that the model will process. Both are set to 200 pixels, ensuring that all images fed into the model are of uniform size.
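The image size determines the size of the feature maps further down the network. The sketch below works through that arithmetic, assuming two blocks of an unpadded ("valid"), stride-1 convolution followed by 2x2 max pooling, which are the Keras defaults and match the layer stack defined later in this notebook. The helper names are illustrative.

```python
from typing import Sequence


def conv_valid(size: int, kernel: int) -> int:
    """Output side length of a stride-1 convolution without padding."""
    return size - kernel + 1


def max_pool(size: int, pool: int = 2) -> int:
    """Output side length of non-overlapping pooling without padding."""
    return size // pool


def feature_map_side(size: int, kernels: Sequence[int]) -> int:
    """Apply one conv + pool block per kernel size."""
    for kernel in kernels:
        size = max_pool(conv_valid(size, kernel))
    return size


for kernel in (3, 5):
    side = feature_map_side(200, [kernel, kernel])
    print(f"kernel {kernel}: {side}x{side}")
# prints "kernel 3: 48x48" and "kernel 5: 47x47"
```

Multiplying the final side length squared by the number of filters in the second convolution gives the number of inputs to the first dense layer.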

### Model Overwriting
- `OVERWRITE_MODEL`: An optional string naming an existing model (without file extension) to retrain and overwrite. If set to `None`, a new model is trained.

### Data Augmentation
- `DATA_AUGMENTATION`: A boolean flag indicating whether data augmentation should be used during training. Data augmentation can help improve model generalization.
- `DATA_AUGMENTATION_LAYERS`: A list of data augmentation layers that will be applied if `DATA_AUGMENTATION` is `True`. It includes horizontal flipping, random rotations, and random zooms.

### Training Parameters
- `BATCH_SIZE`: The number of samples that will be propagated through the network in one forward/backward pass. It is set to 4.
- `VALIDATION_SPLIT`: The fraction of the data to be used as validation data. Here, it's set to 30% of the data.
- `FINAL_LAYER_UNITS`: The number of neurons in the final layer of the model, which should match the number of classes in the dataset. For this model, it is set to 12.

## Callbacks

### `HYPERMODEL_CREATION_CALLBACK`
This function builds a hypermodel with tunable parameters. It takes a `HyperParameters` object and an optional `Model` and returns a compiled model. When no model is passed, the architecture is built from the hyperparameters; when an existing model is passed, it is only recompiled, so the learning rate is the only tuned setting. The tunable aspects include:
- The number of filters and kernel size for convolutional layers.
- L2 regularization strength for convolutional and dense layers.
- Dropout rates to prevent overfitting.
- The number of units in the dense layer.
- The learning rate for the Adam optimizer.

This dynamic creation of the model architecture enables efficient hyperparameter tuning to find the best model configuration.
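To get a feel for how large this search space is, the sketch below counts the discrete combinations using the ranges declared in `HYPERMODEL_CREATION_CALLBACK`. It assumes that both endpoints of each stepped range are valid values; the helper `count_steps` is illustrative and not part of Keras Tuner.

```python
import math


def count_steps(min_value: float, max_value: float, step: float) -> int:
    """Number of values in a stepped range, assuming both ends are included."""
    return round((max_value - min_value) / step) + 1


# One entry per hyperparameter, taken from the ranges in the model code.
search_space = {
    "conv_1_filter": count_steps(32, 128, 16),
    "conv_1_kernel": 2,
    "conv_1_l2": 4,
    "dropout_1_rate": count_steps(0.1, 0.8, 0.1),
    "conv_2_filter": count_steps(32, 128, 16),
    "conv_2_kernel": 2,
    "conv_2_l2": 4,
    "dropout_2_rate": count_steps(0.1, 0.8, 0.1),
    "dense_1_units": count_steps(32, 512, 32),
    "dense_1_l2": 4,
    "dropout_3_rate": count_steps(0.1, 0.5, 0.1),
    "learning_rate": 3,
}

total_combinations = math.prod(search_space.values())
print(f"{total_combinations:,} combinations")
```

Under these assumptions the count is close to 200 million, which is why an exhaustive grid search is impractical and a budgeted search is used instead.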

### `FIT_CALLBACKS`
This function returns a list of callbacks to be used during model training. These callbacks include:
- `EarlyStopping`: Monitors the validation loss and stops training if it does not improve for a specified number of epochs (`patience=15`), helping to prevent overfitting.
- `ModelCheckpoint`: Saves the model after every epoch where there is an improvement in validation accuracy, ensuring that the best model is retained.
- `TensorBoard`: Enables visualization of the training process, including metrics and model architecture, facilitating debugging and optimization.

These callbacks are essential for monitoring the training process, saving the best model, and preventing overfitting.
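The patience rule behind early stopping can be sketched in a few lines of plain Python. This is a simplified model of the idea (it ignores options such as a minimum improvement threshold) and is not the Keras implementation; the function name and the loss values are illustrative.

```python
from typing import List, Optional


def early_stop_epoch(val_losses: List[float], patience: int) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Hypothetical validation losses: the best value is reached at epoch 3.
losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
print(early_stop_epoch(losses, patience=3))  # prints 6
```

With `patience=15`, as configured in `FIT_CALLBACKS`, training would continue for 15 epochs after the last improvement in validation loss before stopping.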

In [None]:

IMG_WIDTH = 200
IMG_HEIGHT = 200


# The name of the model to be trained, should not include the file extension
# None if a new model is to be trained
OVERWRITE_MODEL: Optional[str] = None


# Set to True if data augmentation is to be used
DATA_AUGMENTATION: bool = False
# The `DATA_AUGMENTATION_LAYERS` will be used only if `DATA_AUGMENTATION` is True
DATA_AUGMENTATION_LAYERS = [
    RandomFlip("horizontal"),
    RandomRotation(0.1),
    RandomZoom(0.1),
]


# Variables for the model
BATCH_SIZE: int = 4
VALIDATION_SPLIT: float = 0.3


# The number of neurons in the final layer, should be 12 for this dataset
FINAL_LAYER_UNITS: int = 12


def HYPERMODEL_CREATION_CALLBACK(
    hp: HyperParameters,
    *,
    model: Optional[Model]=None,
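) -> Model:
    """Build and compile a model from the given hyperparameters.

    If `model` is provided, it is compiled as-is and only the learning rate is tuned.
    """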
    if model is None:
        model = Sequential(
            [
                Rescaling(1. / 255),
                Conv2D(
                    filters=hp.Int("conv_1_filter", min_value=32, max_value=128, step=16),
                    kernel_size=hp.Choice("conv_1_kernel", values=[3, 5]),
                    kernel_regularizer=regularizers.l2(hp.Choice("conv_1_l2", values=[1e-3, 1e-4, 1e-5, 1e-6])),
                    activation="relu",
                    input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
                ),
                MaxPooling2D(
                ),
                Dropout(
                    rate=hp.Float("dropout_1_rate", min_value=0.1, max_value=0.8, step=0.1),
                ),
                Conv2D(
                    filters=hp.Int("conv_2_filter", min_value=32, max_value=128, step=16),
                    kernel_regularizer=regularizers.l2(hp.Choice("conv_2_l2", values=[1e-3, 1e-4, 1e-5, 1e-6])),
                    kernel_size=hp.Choice("conv_2_kernel", values=[3, 5]),
                    activation="relu",
                ),
                MaxPooling2D(
                ),
                Dropout(
                    rate=hp.Float("dropout_2_rate", min_value=0.1, max_value=0.8, step=0.1),
                ),
                Flatten(),
                Dense(
                    units=hp.Int("dense_1_units", min_value=32, max_value=512, step=32),
                    kernel_regularizer=regularizers.l2(hp.Choice("dense_1_l2", values=[1e-3, 1e-4, 1e-5, 1e-6])),
                    activation="relu",
                ),
                Dropout(
                    rate=hp.Float("dropout_3_rate", min_value=0.1, max_value=0.5, step=0.1),
                ),
                Dense(
                    units=FINAL_LAYER_UNITS,
                ),
            ]
        )

    model.compile(
        optimizer=Adam(
            learning_rate=hp.Choice("learning_rate", values=[1e-2, 1e-3, 1e-4])
        ),
        loss=SparseCategoricalCrossentropy(
            from_logits=True
        ),
        metrics=[
            "accuracy",
        ]
    )

    return model


def FIT_CALLBACKS(model_name: str) -> List[Callable]:
    """Return the callbacks to run while training the model named `model_name`.
    """
    return [
        EarlyStopping(
            monitor="val_loss",
            patience=15,
        ),
        ModelCheckpoint(
            filepath=f"caches/checkpoints/{model_name}.keras",
            monitor="val_accuracy",
            mode="max",
            save_best_only=True,
        ),
        TensorBoard(
            log_dir=f"logs/fit/{model_name}",
            histogram_freq=1,
            profile_batch=0,
        ),
    ]


# Retrieving the Dataset for Wood Species Identification

This section demonstrates how to retrieve the dataset split into training, validation, and test sets using a custom dataset class `WSI_Dataset`. This class likely encapsulates the logic for downloading, preprocessing, and partitioning the dataset according to a specified validation split ratio.

The dataset is split as follows:
- **Training set**: Used to train the model.
- **Validation set**: Used to tune the hyperparameters and evaluate the model during training.
- **Test set**: Used to test the model's performance after training is complete.

The `class_names` attribute of the dataset object contains the names of the wood species, which are logged along with the size of each dataset split. Checking this output helps confirm that every split contains the expected classes and shows how the data is divided between the splits.

For more details on the dataset and to access it, visit the dataset's page on Hugging Face: [Wood Species Identification Dataset](https://huggingface.co/datasets/LynBean/wood-species-identification).
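Class names and split sizes alone do not show whether the classes are balanced. A minimal sketch of a per-class count is given below; the class names and labels are hypothetical placeholders, and how labels would be extracted from `WSI_Dataset` is not shown because its interface is not documented here.

```python
from collections import Counter
from typing import Dict, Iterable, List


def class_distribution(labels: Iterable[int], class_names: List[str]) -> Dict[str, int]:
    """Count how many samples carry each integer label, keyed by class name."""
    counts = Counter(labels)
    return {name: counts.get(index, 0) for index, name in enumerate(class_names)}


# Hypothetical labels and class names, for illustration only.
example_names = ["species_a", "species_b", "species_c"]
example_labels = [0, 0, 1, 2, 2, 2]
distribution = class_distribution(example_labels, example_names)
print(distribution)
```

A strongly skewed distribution would suggest reading overall accuracy alongside the per-class metrics and the confusion matrix.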

In [None]:

_raw_train_ds, _raw_val_ds, _raw_test_ds = WSI_Dataset.get(validation_split=VALIDATION_SPLIT)

class_names = _raw_train_ds.class_names

logger.info(f"Raw train set with {len(_raw_train_ds)} samples and {len(_raw_train_ds.class_names)} classes: {', '.join(_raw_train_ds.class_names)}")
logger.info(f"Raw validation set with {len(_raw_val_ds)} samples and {len(_raw_val_ds.class_names)} classes: {', '.join(_raw_val_ds.class_names)}")
logger.info(f"Raw test set with {len(_raw_test_ds)} samples and {len(_raw_test_ds.class_names)} classes: {', '.join(_raw_test_ds.class_names)}")


# Data Preprocessing Explanation

This section of `train.ipynb` covers the data preprocessing steps for the wood species identification model. Preprocessing relies on two functions, `_process_ds` and `_augment`, with data augmentation applied only when the `DATA_AUGMENTATION` flag is set.

## `_process_ds` Function
This function takes a TensorFlow `Dataset` object, a batch size, and a shuffle flag as inputs. It performs the following operations:
1. **Batching**: Groups the dataset into batches of the specified size using `ds.batch(batch)`.
2. **Shuffling**: If `shuffle` is `True`, the dataset is shuffled with a buffer size of 500. Because shuffling follows batching, whole batches are reordered rather than individual samples. This is useful for the training dataset, so that the model does not see the data in a fixed order.
3. **Caching**: The dataset is cached using `ds.cache()`, which stores the elements in memory during the first epoch and reduces read latency in subsequent epochs. Everything upstream of the cache, including any augmentation and the shuffled order, is computed during the first pass and replayed afterwards.
4. **Prefetching**: `ds.prefetch(buffer_size=AUTOTUNE)` allows the dataset to prefetch batches while the model is training, improving efficiency by reducing the time the model spends waiting for data.
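The effect of shuffling after batching can be illustrated with a small pure-Python model of the pipeline. The two generators below only imitate the behaviour of batching and buffered shuffling; they are illustrative stand-ins, not the `tf.data` implementation.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def batch(items: Iterable[T], size: int) -> Iterator[List[T]]:
    """Group consecutive items into lists of at most `size` elements."""
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == size:
            yield current
            current = []
    if current:
        yield current


def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Emit items in random order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and replace it with the new one.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


batches = list(batch(range(10), size=4))
shuffled = list(buffered_shuffle(batches, buffer_size=500))
print(batches)
print(shuffled)
```

Each batch keeps its original members and only the order of the batches changes, which is what happens when `shuffle` is called after `batch`.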

## `_augment` Function
This function applies data augmentation to a dataset; it is called only when the `DATA_AUGMENTATION` flag is `True`. It wraps `DATA_AUGMENTATION_LAYERS` in a `Sequential` model and maps it over the dataset to apply flipping, rotation, and zooming to the images. The transformations run as the dataset is iterated, so no augmented copies are written to storage. Note that `_process_ds` caches the dataset afterwards, so the augmented images produced during the first epoch are reused in later epochs.

## Conditional Data Augmentation
The dataset is conditionally augmented based on the `DATA_AUGMENTATION` flag. If `True`, both the raw training and validation datasets are passed through `_augment`. Augmentation is intended to help the model generalize by training on more varied images. It is usually applied to the training set only, whereas this notebook also augments the validation set.

## Final Dataset Preparation
Finally, the training, validation, and test datasets are processed using the `_process_ds` function with appropriate batching and shuffling settings. The training dataset is shuffled to ensure randomness, while the validation and test datasets do not need shuffling. The test dataset is batched with a size of 1, as each sample is evaluated independently.

The code concludes by logging the length of the batched train, validation, and test datasets. For a batched dataset, `len` returns the number of batches rather than the number of samples; for the test dataset the two coincide because the batch size is 1.

In [None]:

def _process_ds(ds: Dataset, batch: int, shuffle: bool) -> Dataset:
    ds = ds.batch(batch)

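    if shuffle:
        # `shuffle` returns a new dataset, so the result must be reassigned.
        # Called after `batch`, it reorders whole batches, not single samples.
        ds = ds.shuffle(buffer_size=500, reshuffle_each_iteration=True)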

    ds = ds.cache()
    ds = ds.prefetch(buffer_size=AUTOTUNE)
    return ds

def _augment(ds: Dataset) -> Dataset:
    data_augmentation = Sequential(DATA_AUGMENTATION_LAYERS)

    result = ds.map(
        lambda x, y: (data_augmentation(x, training=True), y),
        num_parallel_calls=AUTOTUNE,
    )

    return result

if DATA_AUGMENTATION:
    _raw_train_ds = _augment(_raw_train_ds)
    _raw_val_ds = _augment(_raw_val_ds)
    logger.info("Dataset has been augmented")


train_ds = _process_ds(_raw_train_ds, batch=BATCH_SIZE, shuffle=True)
val_ds = _process_ds(_raw_val_ds, batch=BATCH_SIZE, shuffle=False)
test_ds = _process_ds(_raw_test_ds, batch=1, shuffle=False)

logger.info(f"Batched train set with {len(train_ds)} batches")
logger.info(f"Batched validation set with {len(val_ds)} batches")
logger.info(f"Test set with {len(test_ds)} samples")


# Visualization of First Image from Each Class

This section is designed to visualize the first image from each class label in the dataset. It iterates through the training dataset, collecting one unique image per class until it has an image for each class name. These images are then plotted in a grid, with each image labeled with its corresponding class name. This visualization helps in understanding the diversity and characteristics of the dataset's classes.

In [None]:

_taken_classes = set()

_images_list = []
_labels_list = []


for images, labels in train_ds:
    for i in range(images.shape[0]):

        if labels[i].numpy() in _taken_classes:
            continue

        _taken_classes.add(labels[i].numpy())

        _images_list.append(images[i])
        _labels_list.append(labels[i])

        if len(_taken_classes) >= len(class_names):
            break

    if len(_taken_classes) >= len(class_names):
        break


_images_tensor = tf.stack(_images_list)
_labels_tensor = tf.stack(_labels_list)

logger.debug(f"Images tensor shape: {_images_tensor.shape}")
logger.debug(f"Labels tensor shape: {_labels_tensor.shape}")


plt.figure(figsize=(10, 10))


for image, label in zip(_images_tensor, _labels_tensor):
    plt.subplot(3, 4, label.numpy() + 1)
    plt.imshow(image.numpy().astype("uint8"))
    plt.title(class_names[label])
    plt.axis("off")


In [None]:
if OVERWRITE_MODEL is not None:
    _contexts: List[ModelContext] = ModelContext.models()
    context: Optional[ModelContext] = next(
        filter(lambda x: x.name == OVERWRITE_MODEL, _contexts),
        None
    )

    if context is None:
        raise ValueError(f"Model {OVERWRITE_MODEL} not found")

    logger.info(f"Model {context.name} will be used for this training")


else:
    context = None


# Hyperparameter Tuning with Hyperband

This section demonstrates the process of hyperparameter tuning for a wood species identification model using the Hyperband algorithm. Hyperparameter tuning is crucial for optimizing the model's performance by finding the best set of parameters.

## Overview

1. **Hyperband Initialization**: The `Hyperband` tuner is initialized with a custom model creation function (`HYPERMODEL_CREATION_CALLBACK`), which dynamically constructs a model based on the current set of hyperparameters (`hp`) being evaluated. The tuner aims to maximize validation accuracy (`objective="val_accuracy"`) with at most 200 epochs per model. The `factor=3` setting is the reduction factor: each successive-halving round keeps roughly one third of the models and trains the survivors for about three times as many epochs.

2. **Tuner Search**: The `tuner.search` method starts the hyperparameter search process. It trains different configurations of the model on the `train_ds` dataset, evaluates them on `val_ds`, and uses the callbacks returned by `FIT_CALLBACKS("hyperband")` for early stopping and model checkpointing.

3. **Best Hyperparameters**: After the search completes, the best set of hyperparameters is retrieved using `tuner.get_best_hyperparameters(num_trials=1)[0]`. These parameters are logged for debugging and success reporting, showcasing the optimal values for each hyperparameter.

## Key Components

- **Hyperband Algorithm**: An efficient and effective hyperparameter optimization algorithm that uses a "successive halving" approach. It dynamically allocates resources to more promising configurations.

- **`HYPERMODEL_CREATION_CALLBACK`**: A function that takes hyperparameters as input and returns a compiled model. This allows for flexible model architecture adjustments based on the hyperparameters being tested.

- **Callbacks**: Used during the search to implement strategies like early stopping (to prevent overfitting) and model checkpointing (to save the best model).

This process is essential for fine-tuning the model to achieve the best possible accuracy on the validation dataset, leading to improved performance on unseen data.
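The successive-halving idea can be made concrete by computing a bracket schedule for `max_epochs=200` and `factor=3`. The sketch below follows the schedule described in the original Hyperband algorithm; Keras Tuner's internal rounding and bookkeeping may differ, so the numbers are indicative only, and the function names are illustrative.

```python
from typing import List, Tuple


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def hyperband_schedule(max_epochs: int, factor: int) -> List[List[Tuple[int, float]]]:
    """Return, per bracket, the rounds as (num_configs, epochs_per_config)."""
    # s_max is the largest s such that factor**s <= max_epochs.
    s_max = 0
    while factor ** (s_max + 1) <= max_epochs:
        s_max += 1

    brackets = []
    for s in range(s_max, -1, -1):
        # Number of configurations the bracket starts with.
        start_configs = ceil_div((s_max + 1) * factor ** s, s + 1)
        rounds = []
        for i in range(s + 1):
            configs = start_configs // factor ** i
            epochs = max_epochs / factor ** (s - i)
            rounds.append((configs, epochs))
        brackets.append(rounds)
    return brackets


schedule = hyperband_schedule(max_epochs=200, factor=3)
for index, rounds in enumerate(schedule):
    summary = ", ".join(f"{n} models x {e:.1f} epochs" for n, e in rounds)
    print(f"bracket {index}: {summary}")
```

The first bracket starts many models on a small budget and keeps about a third at each round, while the last bracket trains a few models for the full 200 epochs.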

In [None]:

%tensorboard --logdir logs/fit


initial_model: Optional[Model] = clone_model(context.model) if context is not None else None


tuner = Hyperband(
    lambda hp: HYPERMODEL_CREATION_CALLBACK(
        hp,
        model=initial_model,
    ),
    objective="val_accuracy",
    max_epochs=200,
    factor=3,
    directory="caches",
    project_name="hyperband",
)


tuner.search(
    train_ds,
    epochs=200,
    validation_data=val_ds,
    callbacks=FIT_CALLBACKS("hyperband"),
)

best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

logger.debug(f"Best HPs: {best_hps}")
logger.success(
    "The hyperparameter search is complete. The optimal values are\n"
    + "\n".join(f"{k.capitalize():25s}: {v}" for k, v in best_hps.values.items())
)


# Training the Hypermodel with Optimal Hyperparameters

After completing the hyperparameter tuning process, the next step involves building and training the model using the best hyperparameters found. This section outlines this process, ensuring that the model is trained efficiently to achieve the best performance on the validation dataset.

## Building the Hypermodel

**Model Construction**: The model is constructed using the `tuner.hypermodel.build(best_hps)` method, where `best_hps` contains the optimal set of hyperparameters discovered during the tuning process. This step ensures that the model architecture is configured with the best parameters for training.

## Initial Training

1. **Model Training**: The model is trained using the `context.model.fit` method with the training dataset (`train_ds`), validation dataset (`val_ds`), and a set of callbacks returned by `FIT_CALLBACKS(context.name)`. The training is set to run for a large number of epochs (`10000`), but early stopping is expected to halt training when no improvement is observed.

2. **Evaluation**: After training, the model is evaluated on the test dataset (`test_ds`) to obtain the test loss and accuracy, which are logged for analysis.

## Identifying the Best Epoch

1. **Best Epoch Calculation**: The epoch that achieved the highest validation accuracy during training is identified. This is done by finding the index of the maximum value in `history.history["val_accuracy"]` and adding one, since list indices are zero-based while epochs are counted from one.

2. **Logging the Best Epoch**: A log message is generated to indicate the best epoch for training.
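The calculation can be checked on a small list. The accuracy values below are hypothetical; in the notebook they come from `history.history["val_accuracy"]`.

```python
from typing import List


def best_epoch_from_history(val_accuracy: List[float]) -> int:
    """Return the 1-based epoch with the highest validation accuracy."""
    # list.index returns the first position of the maximum, so ties
    # resolve to the earliest epoch.
    return val_accuracy.index(max(val_accuracy)) + 1


example_val_accuracy = [0.41, 0.63, 0.78, 0.74, 0.78]
print(best_epoch_from_history(example_val_accuracy))  # prints 3
```

Because ties resolve to the earliest epoch, the final training run uses the shortest schedule that reached the best validation accuracy.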

## Final Training

1. **Rebuilding the Model**: The model is rebuilt with the optimal hyperparameters to reset its state.

2. **Final Training Run**: The model is trained again for exactly the best epoch count identified above, so training ends at the point where validation accuracy peaked in the first run rather than continuing past it.

This process of building the hypermodel with the best hyperparameters, identifying the best epoch, and conducting a final training run optimizes the model's performance, ensuring it is well-tuned for making accurate predictions on unseen data.

In [None]:

__model: Model = tuner.hypermodel.build(best_hps)


if context is not None:
    context.model = __model
    logger.info(f"Re-using the model {context.name}")

else:
    context = ModelFactory.create(__model)
    logger.info(f"Created a new model {context.name}")


context.model.summary(
    expand_nested=True,
)


In [None]:

history = context.model.fit(
    train_ds,
    callbacks=FIT_CALLBACKS(context.name),
    validation_data=val_ds,
    epochs=10000,
    verbose=1,
)

eval_result = context.model.evaluate(test_ds)
logger.info(f"Test loss: {eval_result[0]}")
logger.info(f"Test accuracy: {eval_result[1]}")


In [None]:

best_epoch = history.history["val_accuracy"].index(
    max(history.history["val_accuracy"])
) + 1


logger.debug(f"Best epoch: {best_epoch}")
logger.info(f"Re-instantiate the hypermodel and train it with the optimal number of epochs {best_epoch}.")

context.model = tuner.hypermodel.build(best_hps)

context.model.fit(
    train_ds,
    callbacks=FIT_CALLBACKS(context.name),
    validation_data=val_ds,
    epochs=best_epoch,
    verbose=1,
)

eval_result = context.model.evaluate(test_ds)
logger.info(f"Test loss: {eval_result[0]}")
logger.info(f"Test accuracy: {eval_result[1]}")


# Model Prediction and Evaluation

This section demonstrates the process of making predictions with the trained model on the test dataset, followed by evaluating the model's performance using various metrics. Here's a breakdown of the steps involved:

## Making Predictions

- **Model Prediction**: `context.model.predict(test_ds, verbose=1)` generates predictions for the test dataset.

## Preparing Actual and Predicted Labels

- **Extracting Actual Labels**: Actual labels are extracted from the test dataset and converted into a NumPy array for comparison with the predicted labels.
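- **Determining Predicted Labels**: The `argmax` function is applied to the predictions to convert the model's outputs into class labels. The final `Dense` layer has no activation, so the outputs are raw logits rather than probabilities; the `argmax` is the same either way.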
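A short pure-Python sketch shows why taking the `argmax` of raw logits gives the same class as taking it after a softmax: softmax is monotonic, so it preserves the ordering of the scores. The helper names and logits are illustrative.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Convert logits into probabilities that sum to 1."""
    peak = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values: List[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


# Hypothetical logits for one sample of a 4-class problem.
sample_logits = [-1.2, 3.4, 0.5, 2.9]
sample_probs = softmax(sample_logits)
print(argmax(sample_logits), argmax(sample_probs))
```

Applying softmax is therefore only needed when calibrated-looking probabilities are wanted, not for picking the predicted class.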

## Model Performance Metrics

- **Accuracy**: The proportion of correctly predicted observations to the total observations.
- **Precision**: The ratio of correctly predicted positive observations to the total predicted positives. Here, it's calculated with `average='micro'` to aggregate the contributions of all classes.
- **Sensitivity (Recall)**: The ratio of correctly predicted positive observations to all actual positives. Also calculated with `average='micro'`.
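- **Specificity**: The ratio of correctly predicted negative observations to all actual negatives. In the code it is computed with `recall_score(..., pos_label=0, average='micro')`, but scikit-learn ignores `pos_label` unless `average='binary'`, so the logged value equals the micro-averaged sensitivity rather than a true specificity.
- **F1 Score**: The harmonic mean of Precision and Recall. Using `average='micro'` means calculating metrics globally by counting the total true positives, false negatives, and false positives. For single-label multi-class predictions such as these, micro-averaged precision, recall, and F1 all equal the accuracy.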
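For a multi-class problem, specificity is normally derived per class from the confusion matrix. The sketch below is one way to do that in plain Python, assuming rows are actual classes and columns are predicted classes (the scikit-learn convention). The matrix values and function name are hypothetical.

```python
from typing import List


def per_class_specificity(cm: List[List[int]]) -> List[float]:
    """Specificity TN / (TN + FP) for each class of a square confusion matrix."""
    total = sum(sum(row) for row in cm)
    result = []
    for k in range(len(cm)):
        tp = cm[k][k]
        fn = sum(cm[k]) - tp                  # actual k, predicted otherwise
        fp = sum(row[k] for row in cm) - tp   # predicted k, actually otherwise
        tn = total - tp - fn - fp
        negatives = tn + fp
        result.append(tn / negatives if negatives else 0.0)
    return result


# Hypothetical 3-class confusion matrix (rows: actual, columns: predicted).
example_cm = [
    [5, 1, 0],
    [2, 3, 1],
    [0, 0, 4],
]
specificities = per_class_specificity(example_cm)
macro_specificity = sum(specificities) / len(specificities)
micro_recall = sum(example_cm[k][k] for k in range(3)) / sum(map(sum, example_cm))
print(specificities, macro_specificity, micro_recall)
```

In this example the micro-averaged recall is simply the fraction of correct predictions, that is, the accuracy, while the per-class specificities carry separate information.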

## Confusion Matrix

- **Generation**: A confusion matrix is generated using `confusion_matrix(actual, predicted)`, providing a summary of the prediction results.
- **Visualization**: The confusion matrix is visualized using `ConfusionMatrixDisplay`, with class names as labels. The plot is displayed with a blue color map to make it easier to interpret the model's performance across different classes.
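As a reference for reading the plot, this is a minimal pure-Python sketch of how a confusion matrix is tallied from label lists, with rows for actual classes and columns for predicted classes as in scikit-learn. The labels are hypothetical and the function is an illustration, not the library implementation.

```python
from typing import List, Sequence


def build_confusion_matrix(
    actual: Sequence[int], predicted: Sequence[int], num_classes: int
) -> List[List[int]]:
    """Count (actual, predicted) pairs; rows are actual, columns are predicted."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for a, p in zip(actual, predicted):
        matrix[a][p] += 1
    return matrix


# Hypothetical labels for five test samples.
example_actual = [0, 0, 1, 2, 2]
example_predicted = [0, 1, 1, 2, 0]
example_matrix = build_confusion_matrix(example_actual, example_predicted, num_classes=3)
for row in example_matrix:
    print(row)
```

Correct predictions land on the diagonal, so off-diagonal cells show which pairs of species the model confuses.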

This section of the code is crucial for understanding how well the model performs on unseen data, identifying areas where the model excels or struggles, and guiding further improvements.

In [None]:

predictions = context.model.predict(
    test_ds,
    verbose=1,
)

logger.debug(f"Predictions shape: {predictions.shape}")
logger.debug(f"Predictions\n{predictions}")


# Join the per-batch label arrays into a single 1-D array
actual = np.concatenate([l.numpy() for _, l in test_ds])
predicted = np.argmax(predictions, axis=-1)

logger.debug(f"Actual shape: {actual.shape}")
logger.debug(f"Actual values\n{actual}")

logger.debug(f"Predicted shape: {predicted.shape}")
logger.debug(f"Predicted values\n{predicted}")


logger.info(f"Accuracy: {accuracy_score(actual, predicted)}")
logger.info(f"Precision: {precision_score(actual, predicted, average='micro')}")
logger.info(f"Sensitivity recall: {recall_score(actual, predicted, average='micro')}")
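# NOTE: `pos_label` is ignored when `average` is not "binary", so this value
# equals the micro-averaged recall above rather than a true specificity.
logger.info(f"Specificity: {recall_score(actual, predicted, pos_label=0, average='micro')}")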
logger.info(f"F1 score: {f1_score(actual, predicted, average='micro')}")


cm = confusion_matrix(actual, predicted)

logger.debug(f"Confusion Matrix\n{cm}")


cm_display = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=class_names,
)

cm_display.plot(
    cmap="Blues",
    ax=plt.subplots(figsize=(9, 9))[1]
)

plt.show()


# Model Summary and Saving

This section performs two key actions related to the final model:

1. **Model Summary**: `context.model.summary()` displays a summary of the model's architecture, including details about the layers, their shapes, and the total number of parameters. This overview is crucial for understanding the model's structure and complexity.

2. **Saving the Model**: `context.save()` saves the entire model into a single file. This includes the architecture, weights, and training configuration, allowing for easy deployment or further training in the future.

In [None]:
context.model.summary()
context.save()
