# HiRISENet-Tiny

A tiny neural network classifier for Mars HiRISE images

## Introduction

The Mars HiRISE (High Resolution Imaging Science Experiment) [[1]](#1) is a camera on board the Mars Reconnaissance
Orbiter, which has been orbiting and studying Mars since 2006. One product of the camera's years of service is the
Mars orbital image (HiRISE) labeled data set [[2]](#2), curated by NASA, which provides labeled images of the Martian
terrain from an orbital perspective. **Figure 1** shows one of the first images taken with the HiRISE camera.

<!--- Public Domain, https://commons.wikimedia.org/w/index.php?curid=656514 >
--->
![First HiRISE Image](../assets/images/mro_first_image.jpg "Figure 1. Crop of one of the first images of Mars from the HiRISE camera")

**Figure 1. Crop of one of the first images of Mars from the HiRISE camera.**

In 2018, Wagstaff et al. [[3]](#3) set out to train a deep learning model that would let scientists and researchers run
advanced queries over the HiRISE data in NASA's planetary imagery database (the PDS Imaging Atlas). Using the first
edition of the dataset, the authors fine-tuned the AlexNet convolutional neural network. This initial dataset contained
3,820 greyscale images with the labels crater, dark dune, bright dune, dark slope streak, other, and edge. On this
dataset, the model achieved 99.1%, 88.1%, and 90.6% accuracy on the training, validation, and test sets, respectively.

Following this effort, Wagstaff et al. published a follow-up paper [[4]](#4) that expanded the initial dataset to
version 3.2. This expanded dataset contains a total of 64,947 landmark images, preprocessed and cropped to 227x227 like
the first edition. In v3.2, a subset of 9,022 images was augmented using a variety of image augmentation techniques,
each applied individually to the images in that subset, which artificially expanded the available data. The authors
also introduced the classes impact ejecta, spiders, and swiss cheese while removing edge.

![HiRISE v3.2 Class Images](../assets/images/hirise_v3_classes.png)

The authors again fine-tuned AlexNet on the new dataset. This time, they also analyzed the model's confidence and used
calibration techniques to make the model more reliable. With this approach and the improved dataset, the authors raised
the model's classification accuracy to 99.6%, 88.6%, and 92.8% on the training, validation, and test sets, respectively.
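The exact calibration procedure from the paper isn't reproduced here. For reference, one common way to measure how well a classifier's confidence matches its accuracy is the expected calibration error (ECE): bin predictions by confidence, then average the gap between accuracy and mean confidence in each bin, weighted by how many samples fall in it. This is a minimal sketch, assuming softmax probabilities and integer labels are already collected into tensors; the function name and the 15-bin default are my choices, not the paper's.

```python
import torch


def expected_calibration_error(
    probs: torch.Tensor, labels: torch.Tensor, n_bins: int = 15
) -> float:
    """Expected calibration error (ECE) using equal-width confidence bins.

    probs:  (N, C) softmax probabilities
    labels: (N,) integer class labels
    """
    confidences, predictions = probs.max(dim=1)
    accuracies = predictions.eq(labels).float()
    bin_edges = torch.linspace(0.0, 1.0, n_bins + 1)
    ece = torch.zeros(())
    for lower, upper in zip(bin_edges[:-1], bin_edges[1:]):
        # Samples whose top-class confidence falls in (lower, upper]
        in_bin = (confidences > lower) & (confidences <= upper)
        if in_bin.any():
            # |accuracy - confidence| for this bin, weighted by its share of samples
            gap = (accuracies[in_bin].mean() - confidences[in_bin].mean()).abs()
            ece += in_bin.float().mean() * gap
    return ece.item()


# An overconfident toy classifier: 90% confident on both samples, right on only one
print(expected_calibration_error(torch.tensor([[0.9, 0.1], [0.9, 0.1]]), torch.tensor([0, 1])))
```

A perfectly calibrated model scores 0; the toy example above is off by about 0.4 because it claims 90% confidence while being right half the time.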

![Dataset Class Distribution](../assets/images/hirise_dataset_class_distribution.png)

Note that the class distribution is heavily imbalanced. Images of "Other" dominate the dataset, while "Impact Ejecta"
makes up only a small portion of it.

### About this Project

In this project, I challenged myself to develop a pipeline for training a neural network that classifies HiRISE images
and for running inference with it on the NVIDIA Jetson Nano 2GB.

Given the resource constraints of space-based computing platforms, I set out to train a neural network that could
hypothetically be deployed on orbit to classify images. My goal is a model that is more resource efficient than the
authors' AlexNet while performing almost (🤞) as well or, hopefully, better.

Most importantly, this project is meant to be fun and full of learning opportunities: training a model, making sure the
model is reliable, and exploring ways to make a model more efficient. I am, after all, a spacecraft flight software
engineer.

For this project, I want to follow Andrej Karpathy's
"[A Recipe for Training Neural Networks](https://karpathy.github.io/2019/04/25/recipe/)" and stick to the best practices
laid out by a leader in AI.

------------------------------

## Setup

Let's get some basic project infrastructure set up first.

```python
import sys

# Also make the modules inside data/ importable by their bare module names
sys.path.append("data")
```

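```python
import os

# cuBLAS must see this before it is initialized, or deterministic CUDA matmuls raise an error
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

import torch
from torchvision.transforms import v2

from data.dataset import HiRISE
from data.split_type import SplitType
from utils.gpu_management import GPUManager

# Seed and request deterministic kernels for reproducible runs. warn_only=True warns
# instead of raising for ops with no deterministic CUDA implementation, such as the
# AdaptiveAvgPool2d backward pass used by ConvNeXt.
torch.manual_seed(42)
torch.use_deterministic_algorithms(True, warn_only=True)

device = GPUManager.enable_gpu_if_available()

if device.type == "cuda":
    GPUManager.cleanup_gpu_cache()
```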

## 1. Become one with the data

Per Karpathy, the most important rule is that we become intimately familiar
with the data we are working with. That is because our trained neural network
is a direct reflection of the data it was trained on. That is, the model _is_
the data!

Fortunately, for this dataset, the authors have done a lot of heavy lifting in
preparing and curating a high quality dataset.

It would be a great help to visualize and understand the different types of
features we are training the model to classify.

Let's take a look.

First, let's create dataset objects so we can easily load and view our data.

```python
# Let's start with some basic transforms. We don't want to augment just yet.
img_transforms: v2.Compose = v2.Compose(
    [v2.Resize((227, 227)), v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]
)
```

```python
train_dataset = HiRISE(
    root_dir="/tmp/.dataset",
    split_type=SplitType.TRAIN,
    transform=img_transforms,
    download=True,
)

val_dataset = HiRISE(
    root_dir="/tmp/.dataset",
    split_type=SplitType.VAL,
    transform=img_transforms,
    download=False,
)

test_dataset = HiRISE(
    root_dir="/tmp/.dataset",
    split_type=SplitType.TEST,
    transform=img_transforms,
    download=False,
)
```

Now that we have our datasets, let's see what these images actually look like.
Let's plot one of each class.

```python
train_dataset.show_image_per_class()
```

From these samples, it's clear that some of these terrains look very similar to
each other. Two classes that stand out to me are crater and impact ejecta. Both
look like craters; the distinguishing feature of impact ejecta is the brighter
material surrounding the impact site.

It also seems that the craters in this dataset are generally larger than the
impact sites in the impact ejecta images. So what if the ejecta has been covered
or eroded? I wonder whether the model will have some difficulty with such cases.

As mentioned by the authors of the dataset, the "other" class is a catch-all
class to capture anything that doesn't quite fit into the other classes. They
also mention that this "other" class makes up the overwhelming majority of the
data.

Let's take a look at that now.

```python
train_dataset.show_class_distribution()
```

From the distribution plot we can see that the "Other" landmarks significantly
dominate the class distribution. This clear imbalance of data will need to be
accounted for in our model training.
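One common way to account for the imbalance is to weight the loss by inverse class frequency, so that mistakes on rare classes like impact ejecta cost more than mistakes on "Other". This is a minimal sketch, assuming the training labels can be collected into a list of integer class indices; the helper name and toy labels are hypothetical. Oversampling rare classes with `torch.utils.data.WeightedRandomSampler` is a common alternative.

```python
from collections import Counter

import torch


def inverse_frequency_weights(labels: list[int], num_classes: int) -> torch.Tensor:
    """Per-class weights proportional to 1 / class count, normalized to a mean of 1."""
    counts = Counter(labels)
    # Treat a class that never appears as having 1 sample to avoid dividing by zero
    freq = torch.tensor(
        [max(counts.get(c, 0), 1) for c in range(num_classes)], dtype=torch.float32
    )
    weights = 1.0 / freq
    return weights * num_classes / weights.sum()


# Hypothetical toy labels: class 0 plays "Other" and dominates, class 2 is rare
toy_labels = [0] * 80 + [1] * 15 + [2] * 5
toy_weights = inverse_frequency_weights(toy_labels, num_classes=3)

# The rare class's errors now weigh 16x more than the dominant class's (80 / 5)
toy_criterion = torch.nn.CrossEntropyLoss(weight=toy_weights)
```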

## 2. Set up the end-to-end training/evaluation skeleton + get dumb baselines

Next up, we'll set up our training and evaluation pipeline. We'll also want a
stupid-simple, dumb baseline from a stupid-simple model.
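Given how heavily "Other" dominates, the dumbest baseline of all is a "model" that ignores the image entirely and always predicts the most common training class. Any network we train needs to beat this accuracy to show it has learned something from the pixels. Here's a minimal sketch, assuming the labels are available as lists of integer class indices; the function name and toy labels are hypothetical, since I haven't shown how `HiRISE` exposes its labels.

```python
from collections import Counter


def majority_class_baseline(train_labels: list[int], eval_labels: list[int]) -> float:
    """Accuracy of a "model" that always predicts the most common training class."""
    majority_class, _ = Counter(train_labels).most_common(1)[0]
    correct = sum(1 for label in eval_labels if label == majority_class)
    return correct / len(eval_labels)


# Hypothetical toy labels, where class 0 plays the role of the dominant "Other" class
toy_train_labels = [0] * 70 + [1] * 20 + [2] * 10
toy_val_labels = [0] * 60 + [1] * 25 + [2] * 15
print(majority_class_baseline(toy_train_labels, toy_val_labels))  # 0.6
```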

For this phase we want to simply establish our baselines. To accomplish that,
we need to follow a good ol' software engineering principle: KISS.

That is, keep it simple, stupid!
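In the same keep-it-simple spirit, Karpathy's recipe suggests verifying the loss at initialization: a softmax classifier that starts out uninformative assigns probability 1/C to every class, so its cross-entropy loss should be about -ln(1/C) = ln(C). With the 8 classes in v3.2 (the six original labels, plus impact ejecta, spiders, and swiss cheese, minus edge), that's ln(8), roughly 2.079. The sketch uses all-zero logits to hit that value exactly; a real model with a freshly initialized head will only land close to it, so treat it as a ballpark check.

```python
import math

import torch
import torch.nn.functional as F

NUM_CLASSES = 8  # v3.2: the 6 original labels + 3 new ones - "edge"

# A uniform softmax gives every class probability 1/C, so the loss is -ln(1/C) = ln(C)
expected_init_loss = -math.log(1.0 / NUM_CLASSES)

# Equal logits reproduce that uniform softmax exactly, whatever the targets are
logits = torch.zeros(4, NUM_CLASSES)
targets = torch.tensor([0, 3, 5, 7])
init_loss = F.cross_entropy(logits, targets).item()

print(f"expected {expected_init_loss:.4f}, got {init_loss:.4f}")  # expected 2.0794, got 2.0794
```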



```python
LEARNING_RATE: float = 1e-4
BATCH_SIZE: int = 32
```

```python
# Create a DataLoader for each split. Only the training data is shuffled; a fixed
# order for validation and test keeps evaluation runs directly comparable.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=8, pin_memory=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=True
)
```
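With the loaders and hyperparameters in place, the remaining piece of the skeleton is the loop itself. Below is a minimal sketch of one training epoch and one evaluation pass, assuming the `HiRISE` datasets yield `(image, label)` pairs (I haven't confirmed that from the dataset code shown here). It's smoke-tested on random tensors with a hypothetical stand-in model so it runs on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, criterion, optimizer, device) -> float:
    """Runs one pass over the training data and returns the mean training loss."""
    model.train()
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch doesn't skew the mean
        total_loss += loss.item() * labels.size(0)
        total_samples += labels.size(0)
    return total_loss / total_samples


@torch.no_grad()
def evaluate(model, loader, device) -> float:
    """Returns classification accuracy without tracking gradients or updating weights."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        predictions = model(images).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
    return correct / total


# Smoke test on random tensors with a tiny hypothetical stand-in model (not ConvNeXt)
torch.manual_seed(0)
cpu = torch.device("cpu")
fake_data = TensorDataset(torch.rand(64, 1, 8, 8), torch.randint(0, 8, (64,)))
fake_loader = DataLoader(fake_data, batch_size=16, shuffle=True)
tiny_model = nn.Sequential(nn.Flatten(), nn.Linear(64, 8))
tiny_optimizer = torch.optim.AdamW(tiny_model.parameters(), lr=1e-4)

train_loss = train_one_epoch(tiny_model, fake_loader, nn.CrossEntropyLoss(), tiny_optimizer, cpu)
val_accuracy = evaluate(tiny_model, fake_loader, cpu)
```

In the notebook, the same functions would take `model`, `train_loader` or `val_loader`, `device`, and an optimizer such as `torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)`. AdamW is my assumption here; the notebook hasn't chosen an optimizer yet.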

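```python
import torchvision

NUM_CLASSES: int = 8  # v3.2 labels: the 6 original classes + 3 new ones - "edge"

# Start from ConvNeXt-Tiny pretrained on ImageNet-1k
model = torchvision.models.convnext_tiny(
    weights=torchvision.models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1
)

# HiRISE crops are greyscale, so swap the 3-channel RGB stem for a 1-channel one
model.features[0][0] = torch.nn.Conv2d(1, 96, kernel_size=(4, 4), stride=(4, 4))

# Replace the 1000-class ImageNet head with one sized for the HiRISE classes
model.classifier[2] = torch.nn.Linear(model.classifier[2].in_features, NUM_CLASSES)

model.to(device)
model.train()
```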

```python
from utils import model_insights

model_insights.calc_model_size(model, show=True);
```
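`model_insights.calc_model_size` is a helper from this repo, so its implementation isn't reproduced here. As a rough cross-check using only PyTorch, a model's static footprint is its parameter and buffer counts multiplied by the bytes per element. This ignores activations, gradients, optimizer state, and CUDA context overhead, all of which matter a lot on a 2GB Jetson Nano. `model_footprint` is a hypothetical name.

```python
import torch
from torch import nn


def model_footprint(model: nn.Module) -> tuple[int, float]:
    """Returns (parameter count, approximate size of parameters + buffers in MiB)."""
    n_params = sum(p.numel() for p in model.parameters())
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    n_bytes += sum(b.numel() * b.element_size() for b in model.buffers())
    return n_params, n_bytes / 2**20


# Tiny stand-in model: 10 * 5 weights + 5 biases = 55 float32 parameters, 4 bytes each
n_params, size_mib = model_footprint(nn.Linear(10, 5))
print(n_params, f"{size_mib * 2**20:.0f} bytes")  # 55 220 bytes
```

Calling `model_footprint(model)` on the ConvNeXt model above gives the same kind of static estimate.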

-------------------------------------

## References

<a id="1">[1]</a> 
Wikipedia contributors. (2024, November 19). HiRISE. Wikipedia. https://en.wikipedia.org/wiki/HiRISE

<a id="2">[2]</a> 
Doran, G., Dunkel, E., Lu, S., & Wagstaff, K. (2020). Mars orbital image (HiRISE) labeled data set version 3.2 (3.2.0) [Data set]. Zenodo. https://doi.org/10.5281/zenodo.4002935

<a id="3">[3]</a>
Wagstaff, K., Lu, Y., Stanboli, A., Grimes, K., Gowda, T., & Padams, J. (2018). Deep Mars: CNN Classification of Mars Imagery for the PDS Imaging Atlas. Proceedings of the AAAI Conference on Artificial Intelligence, 32(1). https://doi.org/10.1609/aaai.v32i1.11404

<a id="4">[4]</a>
Wagstaff, K., et al. (2021). Mars Image Content Classification: Three Years of NASA Deployment and Recent Advances. arXiv:2102.05011. https://doi.org/10.48550/arXiv.2102.05011