# HiRISENet-Tiny

A tiny neural network classifier for Mars HiRISE images

## Introduction

The Mars HiRISE (High Resolution Imaging Science Experiment) [[1]](#1) is a camera on board the Mars Reconnaissance
Orbiter which has been orbiting and studying Mars since 2006. A product of this payload's years of service has been the
curation of the Mars orbital image (HiRISE) labeled data set [[2]](#2) by NASA. This dataset provides a curated set of
labelled images of the Martian terrain from an orbital perspective. **Figure 1** shows one of the first images taken
with the HiRISE camera.

<!--- Public Domain, https://commons.wikimedia.org/w/index.php?curid=656514 >
--->
![First HiRISE Image](../assets/images/mro_first_image.jpg "Figure 1. The first orbital image captured by HiRISE")

**Figure 1. Crop of one of the first images of Mars from the HiRISE camera.**

In 2018, Wagstaff et al. [[3]](#3) set out to train a deep learning model that would let scientists and researchers run
advanced queries over the HiRISE data in NASA's planetary imagery database. Using the first edition of the dataset, the
authors fine-tuned the AlexNet convolutional neural network. This initial dataset contained 3,820 greyscale images
labeled as crater, dark dune, bright dune, dark slope streak, other, or edge. On this dataset, the model achieved
99.1%, 88.1%, and 90.6% accuracy on the training, validation, and test sets, respectively.

Following this effort, Wagstaff et al. published a follow-up paper [[4]](#4) that expanded the initial dataset to
version 3.2. The expanded dataset consists of 64,947 landmark images, preprocessed and cropped to 227x227 as in the
first edition. In v3.2, a subset of 9,022 images was augmented, with each augmentation technique applied individually
to the images in the subset, which artificially expanded the available data. The authors also introduced the classes
impact ejecta, spiders, and swiss cheese while removing edge.

![HiRISE v3.2 Class Images](../assets/images/hirise_v3_classes.png)

The authors again fine-tuned the AlexNet model on the new dataset. For this iteration, they analyzed the model's
confidence and applied calibration techniques to make the model more reliable. With this approach and the improved
dataset, classification accuracy rose to 99.6% (training), 88.6% (validation), and 92.8% (test).

![Dataset Class Distribution](../assets/images/hirise_dataset_class_distribution.png)

Note that the class distribution is heavily imbalanced: images of "Other" dominate the dataset, while
"Impact Ejecta" makes up only a small portion of it.

### About this Project

In this project I challenged myself to develop a pipeline for training and
deployment of a neural network to run on the NVIDIA Jetson Nano 2GB. This neural
network shall be able to classify the greyscale HiRISE images.

My goal is to train a modern convolutional neural network that is more resource
efficient than AlexNet and, hopefully, as accurate as the authors' trained AlexNet.

As a flight software engineer, I am focused on creating a trusted model that is
also resource efficient. As part of this project I also want to analyze and
evaluate the resource utilization and performance of the model.

Previous work by Dunkel et al. [[5]](#5) provides some clues as to which metrics
machine learning practitioners should be cognisant of in a space-based
environment. Based on the information the authors provided, I have derived some
requirements that outline the metrics I will be collecting.

Key metrics I would like to collect include:

* Inference time of the model on a single input image
* Peak RAM utilization of the model
* Disk utilization of the model
* Energy consumption of the model
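
To make the first two metrics concrete, here is a minimal sketch of how latency and peak memory could be sampled around any callable. The helper `profile_callable` and the placeholder workload `fake_inference` are names I made up for illustration. Note that `tracemalloc` only sees Python-level allocations, so it would not capture native tensor or GPU memory; on the Jetson a board-level tool would be needed for that, and for energy.

```python
import statistics
import time
import tracemalloc


def profile_callable(fn, *args, warmup=2, repeats=10):
    """Return (median latency in ms, peak Python heap bytes) for fn(*args)."""
    # Warm-up runs so one-time setup costs do not skew the timings
    for _ in range(warmup):
        fn(*args)

    timings_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        timings_ms.append((time.perf_counter() - start) * 1000.0)

    # Trace a single extra run to sample peak Python-level memory
    tracemalloc.start()
    fn(*args)
    _, peak_bytes = tracemalloc.get_traced_memory()
    tracemalloc.stop()

    return statistics.median(timings_ms), peak_bytes


def fake_inference(n):
    # Hypothetical stand-in for model(image)
    return sum(i * i for i in range(n))


latency_ms, peak_bytes = profile_callable(fake_inference, 100_000)
print(f"median latency: {latency_ms:.2f} ms, peak heap: {peak_bytes} bytes")
```

Using the median rather than the mean keeps a single slow run (for example, a cold cache) from dominating the reported number.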

#### Requirements

| Requirement | Description                                                                                    | Rationale                                             | Verification Method |
|-------------|------------------------------------------------------------------------------------------------|-------------------------------------------------------|---------------------|
| HIRISE-001  | A neural network model shall be trained on the HiRISE v3.2 dataset.                            | The goal of the project is to classify HiRISE images. | Test Set Evaluation |
| HIRISE-002  | A neural network model shall achieve a minimum of 80% accuracy on the test dataset.            | The model needs to be nearly as good as AlexNet.      | Test Set Evaluation |
| HIRISE-003  | A neural network model shall execute on the NVIDIA Jetson Nano 2GB.                            | The flight software board is the NVIDIA Jetson Nano.  | Inspection          |
| HIRISE-004  | A neural network model inference time shall not exceed 1,069 ms on the NVIDIA Jetson Nano CPU. | The existing benchmark for HiRISENet is 1,069 ms.     | Profiling Test      |
| HIRISE-005  | A neural network model inference time shall not exceed 234 ms on the NVIDIA Jetson Nano GPU.   | The existing benchmark for HiRISENet is 234 ms.       | Profiling Test      |
| HIRISE-006  | A neural network energy consumption shall not exceed 10.7 J on the NVIDIA Jetson Nano CPU.     | The existing benchmark for HiRISENet is 10.7 J.       | Profiling Test      |
| HIRISE-007  | A neural network energy consumption shall not exceed 2.3 J on the NVIDIA Jetson Nano GPU.      | The existing benchmark for HiRISENet is 2.3 J.        | Profiling Test      |
| HIRISE-008  | A neural network model parameter size shall not exceed 233 MB.                                 | AlexNet in PyTorch 2.5 is ~233 MB in size.            | Profiling Test      |
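
As a sanity check on how the latency and energy budgets relate, the sketch below assumes energy is simply mean power multiplied by inference time. Under that assumption, both benchmark pairs in the table imply a mean draw of roughly 10 W. The function name is my own.

```python
def implied_mean_power_w(energy_j: float, latency_ms: float) -> float:
    """Mean power (W) implied by an energy budget (J) and a latency budget (ms)."""
    return energy_j / (latency_ms / 1000.0)


cpu_power_w = implied_mean_power_w(10.7, 1069)  # 1,069 ms / 10.7 J pair
gpu_power_w = implied_mean_power_w(2.3, 234)  # 234 ms / 2.3 J pair

print(f"{cpu_power_w:.2f} W, {gpu_power_w:.2f} W")  # prints "10.01 W, 9.83 W"
```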

------------------------------

## Setup

Let's get some basic project infrastructure set up first.

In [1]:
import sys

sys.path.append("data")

In [5]:
import torch
from utils.gpu_management import GPUManager

from data.dataset import HiRISE
from data.split_type import SplitType
import torchvision
from torchvision.transforms import v2
from tempfile import TemporaryDirectory
from pathlib import Path

torch.manual_seed(42)
# ? How does this work?
# torch.use_deterministic_algorithms(True)

device = GPUManager.enable_gpu_if_available()

if device.type == "cuda":
    GPUManager.cleanup_gpu_cache()

__CUDA VERSION: 90100
__Number CUDA Devices: 1
__CUDA Device Name: NVIDIA RTX A2000 8GB Laptop GPU
__CUDA Device Total Memory [GB]: 8.58947584


In [6]:
temp_dir = TemporaryDirectory()
temp_dir.name
temp_dir.name = "/tmp/tmpr2oojxfd"

### Exploratory Data Analysis

The first thing we should do is gain a deep understanding of our dataset. The
model we train is a direct reflection of our data. That is, the model _is_ the
data.

The HiRISE v3.2 [[2]](#2) dataset has the following characteristics:

* **Number of Classes:** The dataset comprises 8 classes, each representing a different Martian terrain feature: bright dune, crater, dark dune, impact ejecta, other, slope streak, spider, swiss cheese.

* **Images:** The dataset contains 64,947 227x227 images. The images have been pre-processed and a subset of the original images were augmented with the following transformations: 90° clockwise rotation, 180° clockwise rotation, 270° clockwise rotation, horizontal flip, vertical flip, random brightness adjustment.

* **File Format:** The images are stored in JPG format.

* **Directory Structure:** The data has been split by the authors into training, validation, and test splits.

Before we can feed our dataset into a model, we need to transform it into a
format that the model can understand. For this, we use the `transforms` module
from `torchvision`.

In [7]:
data_transforms = {
    "train": v2.Compose(
        [
            v2.Resize((227, 227)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            # v2.Normalize(mean=[0.5], std=[0.5], inplace=True)
        ]
    ),
    "val": v2.Compose(
        [
            v2.Resize((227, 227)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            # v2.Normalize(mean=[0.5], std=[0.5], inplace=True)
        ]
    ),
    "test": v2.Compose(
        [
            v2.Resize((227, 227)),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
        ]
    ),
}

At this point, the transforms we use are pretty standard for image data.
We first resize each image to 227x227. The images should already be this size
but you can never be too careful. `v2.ToImage()` followed by `v2.ToDtype(...)`
is the recommended torchvision transforms v2 way of converting an image to a
tensor. The old `ToTensor()` is deprecated as of this writing.

Lastly, the `v2.Normalize()` transformation, which is currently commented out
above, normalizes the tensor image. The mean and standard deviation of 0.5 are
simply the midpoint of [0.0, 1.0], so pixel values are mapped to [-1.0, 1.0].
Normalization helps improve the convergence speed and overall performance of
the model.
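
To see what a mean of 0.5 and a standard deviation of 0.5 do to a pixel, here is the per-pixel arithmetic written out in plain Python. This is only a sketch of the formula `(x - mean) / std`, not torchvision's implementation.

```python
def normalize(pixel: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-pixel operation of a mean/std normalization."""
    return (pixel - mean) / std


for value in (0.0, 0.5, 1.0):
    print(value, "->", normalize(value))
# 0.0 -> -1.0
# 0.5 -> 0.0
# 1.0 -> 1.0
```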

In [4]:
data_folder: Path = Path("/tmp/.hirise")

train_dataset = HiRISE(
    root_dir=data_folder,
    split_type=SplitType.TRAIN,
    transform=data_transforms["train"],
    download=True,
)

val_dataset = HiRISE(
    root_dir=data_folder,
    split_type=SplitType.VAL,
    transform=data_transforms["val"],
    download=True,
)

test_dataset = HiRISE(
    root_dir=data_folder,
    split_type=SplitType.TEST,
    transform=data_transforms["test"],
    download=True,
)


Now that we have our datasets, let's see what these images actually look like.
Let's plot one of each class.

In [None]:
train_dataset.show_image_per_class()

From the images we show, it's clear that some of these terrains are very
similar to each other. Two classes that stand out to me are crater and impact
ejecta. Both look like craters, the distinguishing feature being the brighter
material surrounding the impact crater.

It also seems like a crater in this dataset is an impact zone that is larger
than an impact ejecta image. So then, what if the impact ejecta has been covered
or eroded in the image? I wonder if the model will have some difficulties with
that.

As mentioned by the authors of the dataset, the "other" class is a catch-all
class to capture anything that doesn't quite fit into the other classes. They
also mention that this "other" class makes up the overwhelming majority of the
data.

Let's take a look at that now.

In [None]:
train_dataset.show_class_distribution()

From the distribution plot we can see that the "Other" landmarks significantly
dominate the class distribution. This clear imbalance of data will need to be
accounted for in our model training.
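
One common way to account for this kind of imbalance is to weight each class inversely to its frequency. The sketch below uses made-up toy labels, not the real HiRISE counts, and the helper name is my own. The resulting weights could then be passed, ordered by class index, as the `weight` tensor of `torch.nn.CrossEntropyLoss`.

```python
from collections import Counter


def inverse_frequency_weights(labels):
    """Weight each class by total / (num_classes * class_count)."""
    counts = Counter(labels)
    total = len(labels)
    num_classes = len(counts)
    return {cls: total / (num_classes * n) for cls, n in counts.items()}


# Toy labels for illustration only; not the actual dataset distribution
toy_labels = ["other"] * 8 + ["crater"] * 3 + ["impact ejecta"] * 1
weights = inverse_frequency_weights(toy_labels)

for cls, weight in weights.items():
    print(f"{cls}: {weight:.3f}")
# other: 0.500
# crater: 1.333
# impact ejecta: 4.000
```

With this scheme a perfectly balanced dataset gives every class a weight of 1.0, so the weights only deviate from 1.0 to the extent the data is skewed.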

### Setting Up the Model

To train our model we first need to create a `DataLoader` for each of our data
splits. The `DataLoader` is a PyTorch class that facilitates iterating over our
data in batches as we train.

In [None]:
dataloaders = {
    name: torch.utils.data.DataLoader(
        dataset,
        batch_size=32,
        # Shuffle only the training split; evaluation does not depend on order
        shuffle=(name == "train"),
        num_workers=8,
        pin_memory=True,
    )
    for name, dataset in zip(
        ["train", "val", "test"], [train_dataset, val_dataset, test_dataset]
    )
}

The authors of HiRISENet used a pre-trained AlexNet model and fine-tuned it on
the HiRISE dataset. They achieved strong results with that model and employed
techniques like calibration to make its confidence estimates more reliable.

In this project I was first enticed by the idea of using a pre-trained vision
transformer. However, given the limited size of the dataset, I opted for the
ConvNeXt architecture of Liu et al. [[6]](#6). Convolutional neural networks are
a more mature and well-understood architecture. As a flight software engineer I
like ol' reliable tried and true over flashy state-of-the-art. Additionally, the
ConvNeXt architecture performs just as well as vision transformers on some
tasks "while maintaining the simplicity and efficiency of standard ConvNets."

For the ConvNeXt architecture, the authors opted for layer normalization instead
of batch normalization, following the convention of Transformers. They also
adopted depthwise convolutions, which cut computation relative to standard
convolutions. Other parts of this architecture were "borrowed" from or inspired
by design choices in vision transformers, such as the patchified input stem and
larger kernel sizes, although ConvNeXt itself contains no self-attention layers.

For this project, we will use the "tiny" version of the ConvNeXt model, which is
pre-trained on ImageNet data. We will then fine-tune the model (i.e., transfer
learning) so that it learns our data.

In [None]:
model = torchvision.models.convnext_tiny(
    weights=torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
)

The existing ConvNeXt model expects 3 input channels while the HiRISE image
data is single channel. This means we have to modify the input layer of the model
to accommodate this discrepancy.

In [None]:
# Only change the input channel and keep everything else the same
model.features[0][0] = torch.nn.Conv2d(
    1, 96, kernel_size=(4, 4), stride=(4, 4), device=device
)

# Move the model to the GPU if available
model.to(device);
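
One thing to keep in mind is that a freshly constructed `Conv2d` starts from random weights, so the pre-trained stem filters are discarded. An alternative I have not benchmarked here is to collapse the RGB kernels into one channel by summing them, which reproduces the response the original stem would give if the greyscale image were copied into all three channels. The sketch below uses a randomly initialized stand-in for the pre-trained stem, and `adapt_stem_to_greyscale` is a hypothetical helper name.

```python
import torch


def adapt_stem_to_greyscale(rgb_conv: torch.nn.Conv2d) -> torch.nn.Conv2d:
    """Build a 1-channel conv whose kernel is the sum of the RGB kernels."""
    grey_conv = torch.nn.Conv2d(
        in_channels=1,
        out_channels=rgb_conv.out_channels,
        kernel_size=rgb_conv.kernel_size,
        stride=rgb_conv.stride,
        padding=rgb_conv.padding,
        bias=rgb_conv.bias is not None,
    )
    with torch.no_grad():
        # Sum over the input-channel axis: (out, 3, k, k) -> (out, 1, k, k)
        grey_conv.weight.copy_(rgb_conv.weight.sum(dim=1, keepdim=True))
        if rgb_conv.bias is not None:
            grey_conv.bias.copy_(rgb_conv.bias)
    return grey_conv


# Stand-in for the pre-trained stem (same shape as the layer replaced above)
rgb_stem = torch.nn.Conv2d(3, 96, kernel_size=4, stride=4)
grey_stem = adapt_stem_to_greyscale(rgb_stem)

image = torch.rand(1, 1, 227, 227)
with torch.no_grad():
    grey_out = grey_stem(image)
    rgb_out = rgb_stem(image.repeat(1, 3, 1, 1))

print(grey_out.shape, torch.allclose(grey_out, rgb_out, atol=1e-5))
```

The classifier head is a separate matter: as loaded, the model still emits ImageNet's 1,000 logits, so swapping the final linear layer for an 8-output one is probably worth testing as well.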

In this project, I simply used a pre-trained ConvNeXt model that's available
from PyTorch with the smaller ImageNet-1K weights. This makes the model
size ~109 MB, which fulfills our HIRISE-008 requirement.

In [None]:
from utils import model_insights

model_insights.calc_model_size(model, show=True);
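
As a rough cross-check of the ~109 MB and ~233 MB figures, parameter size is just the parameter count times the bytes per parameter. The counts below are the ones I am assuming from torchvision's published model documentation for the stock 3-channel models, so treat them as approximate for our modified network. The helper name is my own.

```python
BYTES_PER_FLOAT32 = 4


def param_size_mib(num_params: int, bytes_per_param: int = BYTES_PER_FLOAT32) -> float:
    """Size of the parameters alone, in mebibytes (2**20 bytes)."""
    return num_params * bytes_per_param / 2**20


# Assumed parameter counts for the stock torchvision models
convnext_tiny_mib = param_size_mib(28_589_128)
alexnet_mib = param_size_mib(61_100_840)

print(f"ConvNeXt-Tiny: {convnext_tiny_mib:.1f} MiB")  # prints 109.1 MiB
print(f"AlexNet: {alexnet_mib:.1f} MiB")  # prints 233.1 MiB
```

Both results line up with the numbers quoted in this notebook, which suggests the "MB" figures are really mebibytes of float32 parameters.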

### Initial Model Evaluation

To follow some machine learning best practices, we should first establish a
baseline. That means no crazy augmentations and even a stupid simple model. The
objective with the initial training round is to just make sure everything is
set up correctly and that our stupid model can train on the data. For this
iteration we want to reduce the number of knobs and variables to the absolute
minimum.

Since the authors of the HiRISE dataset have done a lot of the heavy
lifting for us, we will jump right into training our model. For this round we
will skip any regularization techniques and just train the model directly on
the available data. Based on the initial results, we will then modify our
setup, for example by adding augmentations.

First, we need to set up some hyperparameters.

In [None]:
LEARNING_RATE: float = 1e-3
BATCH_SIZE: int = 32
EPOCHS: int = 10

For the loss function we will use `CrossEntropyLoss`. This is a fairly standard
loss function for multi-class classification. Similarly, the `Adam`
optimizer is standard and pretty foolproof.
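
For intuition, cross-entropy for a single sample is the negative log of the softmax probability assigned to the true class. The plain-Python sketch below illustrates that formula on made-up logits. It is not PyTorch's implementation, which also handles batching and reduction.

```python
import math


def cross_entropy(logits, target_index):
    """Negative log-softmax probability of the target class."""
    # Subtract the max logit before exponentiating for numerical stability
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target_index]


confident = cross_entropy([4.0, 0.0, 0.0], target_index=0)  # right and confident
uniform = cross_entropy([0.0, 0.0, 0.0], target_index=0)  # no preference
wrong = cross_entropy([4.0, 0.0, 0.0], target_index=1)  # confident but wrong

print(f"{confident:.3f} < {uniform:.3f} < {wrong:.3f}")
```

The loss is small when the model is confidently right, equals `log(num_classes)` when it has no preference, and grows when it is confidently wrong.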

In [None]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# ? Do we need a scheduler?
# exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
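
To help answer the scheduler question, here is what a step schedule with `step_size=7` and `gamma=0.1` would do to our learning rate over 10 epochs. This is a plain-Python sketch of the closed-form step decay, `lr * gamma ** (epoch // step_size)`, rather than a call into PyTorch.

```python
def step_lr(initial_lr: float, epoch: int, step_size: int = 7, gamma: float = 0.1) -> float:
    """Learning rate after `epoch` completed epochs under step decay."""
    return initial_lr * gamma ** (epoch // step_size)


schedule = [step_lr(1e-3, epoch) for epoch in range(10)]
for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

With only 10 epochs, the rate would stay at 1e-3 for epochs 0 through 6 and drop to 1e-4 for the last three, so the schedule would have a fairly limited effect on a run this short.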

In [None]:
dataset_sizes = {
    name: len(dataset)
    for name, dataset in zip(
        ["train", "val", "test"], [train_dataset, val_dataset, test_dataset]
    )
}

In [None]:
import time
import os


def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()  # Set model to training mode
    running_loss = 0.0
    running_corrects = 0

    # Iterate over data.
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        # backward + optimize
        loss.backward()
        optimizer.step()

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    return running_loss, running_corrects


def validate_epoch(model, dataloader, criterion, device):
    model.eval()  # Set model to evaluate mode
    running_loss = 0.0
    running_corrects = 0

    # Iterate over data.
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # forward
        with torch.no_grad():
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

        # statistics
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    return running_loss, running_corrects


def train_model(model, criterion, optimizer, scheduler, num_epochs=3):
    since = time.time()

    # Create a temporary directory to save training checkpoints
    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model_params.pt")
        torch.save(model.state_dict(), best_model_params_path)
        best_acc = 0.0
        history = []

        for epoch in range(num_epochs):
            # Each epoch has a training and validation phase
            train_loss, train_corrects = train_epoch(
                model, dataloaders["train"], criterion, optimizer, device
            )
            if scheduler is not None:
                scheduler.step()
            val_loss, val_corrects = validate_epoch(
                model, dataloaders["val"], criterion, device
            )

            train_loss /= dataset_sizes["train"]
            train_acc = train_corrects.double() / dataset_sizes["train"]
            val_loss /= dataset_sizes["val"]
            val_acc = val_corrects.double() / dataset_sizes["val"]

            history.append([train_acc, val_acc, train_loss, val_loss])
            print(
                f"Epoch {epoch}/{num_epochs - 1}: "
                f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, "
                f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}"
            )

            # save a checkpoint whenever validation accuracy improves
            if val_acc > best_acc:
                best_acc = val_acc
                torch.save(model.state_dict(), best_model_params_path)

        time_elapsed = time.time() - since
        print(
            f"\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s"
        )
        print(f"Best val Acc: {best_acc:.4f}")

        # load best model weights
        model.load_state_dict(torch.load(best_model_params_path, weights_only=True))

    return model, history

In [None]:
model_ft, model_ft_history = train_model(
    model, criterion, optimizer, None, num_epochs=1
)
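
Because the classes are so imbalanced, overall accuracy alone can hide poor performance on rare classes. Below is a minimal sketch of a per-class accuracy breakdown that could be run on collected test-set predictions. The toy targets and predictions are invented for illustration, and the helper name is my own.

```python
from collections import defaultdict


def per_class_accuracy(targets, predictions):
    """Fraction of correct predictions for each true class."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for target, prediction in zip(targets, predictions):
        total[target] += 1
        correct[target] += int(target == prediction)
    return {cls: correct[cls] / total[cls] for cls in total}


# Toy example: a model that always predicts the majority class
targets = ["other", "other", "other", "other", "crater", "impact ejecta"]
predictions = ["other", "other", "other", "other", "crater", "other"]

overall = sum(t == p for t, p in zip(targets, predictions)) / len(targets)
by_class = per_class_accuracy(targets, predictions)

print(f"overall: {overall:.3f}")  # overall: 0.833
print(by_class)  # {'other': 1.0, 'crater': 1.0, 'impact ejecta': 0.0}
```

Here the overall accuracy clears an 80% bar even though the rarest class is never predicted correctly, which is exactly the failure mode worth watching for when verifying HIRISE-002.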

-------------------------------------

## References

<a id="1">[1]</a> 
Wikipedia contributors. (2024, November 19). HiRISE. Wikipedia. https://en.wikipedia.org/wiki/HiRISE

<a id="2">[2]</a> 
Gary Doran, Emily Dunkel, Steven Lu, & Kiri Wagstaff. (2020). Mars orbital image (HiRISE) labeled data set version 3.2 (3.2.0) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.4002935

<a id="3">[3]</a>
Wagstaff, K., Lu, Y., Stanboli, A., Grimes, K., Gowda, T., & Padams, J. (2018). Deep Mars: CNN Classification of Mars Imagery for the PDS Imaging Atlas. Proceedings of the AAAI Conference on Artificial Intelligence, 32(1). https://doi.org/10.1609/aaai.v32i1.11404

<a id="4">[4]</a>
Wagstaff, Kiri, et al. Mars Image Content Classification: Three Years of NASA Deployment and Recent Advances. arXiv:2102.05011, arXiv, 9 Feb. 2021. arXiv.org, https://doi.org/10.48550/arXiv.2102.05011.

<a id="5">[5]</a>
Dunkel, Emily R., et al. “Benchmarking Deep Learning Models on Myriad and Snapdragon Processors for Space Applications.” Journal of Aerospace Information Systems, vol. 20, no. 10, Oct. 2023, pp. 660–74. DOI.org (Crossref), https://doi.org/10.2514/1.I011216.

<a id="6">[6]</a>
Liu, Zhuang, et al. A ConvNet for the 2020s. arXiv:2201.03545, arXiv, 2 Mar. 2022. arXiv.org, https://doi.org/10.48550/arXiv.2201.03545.
