# Model Architecture

There are many ways to tackle this problem, but I narrowed it down to a few that might work and that we already know from the topics covered in class.

## Considerations

We have two inputs: IMU data and images. Each arrives as a stream at a different frequency. Assuming the IMU data arrives at X Hz and the image data at a lower Y Hz, we need to run our model at Y Hz, using a single image and the X/Y IMU samples associated with that image.
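
As a rough sketch of that grouping, the snippet below reshapes a raw IMU stream into one window per camera frame. The function name `window_imu` and the example rates are my own choices, and it assumes the two streams are already time-aligned and that X is an integer multiple of Y; real data would need timestamp matching instead.

```python
import torch


def window_imu(imu: torch.Tensor, imu_hz: int, cam_hz: int) -> torch.Tensor:
    """Group a (T, 6) IMU stream into (num_frames, imu_hz // cam_hz, 6) windows."""
    if imu_hz % cam_hz != 0:
        raise ValueError("this sketch assumes imu_hz is an integer multiple of cam_hz")
    per_frame = imu_hz // cam_hz
    num_frames = imu.shape[0] // per_frame
    # Drop trailing samples that do not fill a whole window.
    return imu[: num_frames * per_frame].reshape(num_frames, per_frame, imu.shape[1])


# 205 fake IMU samples with 6 channels each (assumed: 3-axis accel + 3-axis gyro).
imu_stream = torch.arange(205 * 6, dtype=torch.float32).reshape(205, 6)
windows = window_imu(imu_stream, imu_hz=200, cam_hz=20)
print(windows.shape)  # torch.Size([20, 10, 6])
```

Each row of `windows` is then the IMU input that goes with one image.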

We have an anomaly detection problem, as we need to make sense of our inputs and determine if something bad is happening. Thus, our model needs to learn what is considered normal SLAM and flight behavior for our drone. Since our system is evolving over time, we have two options. 

1. Make the model output a binary label that corresponds to SLAM failure or nominal operation. This can also be converted to a confidence using a sigmoid (or a two-class softmax) if we want. The model then predicts whether there is a SLAM failure within a given window. To get labels, we compare the SLAM estimate to the ground-truth (GT) trajectory of the drone: if the two diverge by more than a set amount, we label that point as a SLAM failure. For the test set, we manually label each window as failure or nominal, based on whatever SLAM-VIO algorithm both we and the drone use, and then run the model on it. This model uses binary cross-entropy as the loss function, which maximizes the classification likelihood.
2. Make the model predict a path, with the initial condition being its current estimated point. Its output would be the path divergence over time, say the average within the next cycle, and we could define a metric that determines whether this is a SLAM failure. The model trains on the raw position data instead of failure labels. It uses an L2 loss per physical direction (x, y, z), which penalizes large path estimation residuals. Assuming Gaussian residuals, this maximizes the likelihood that the model predicts the true path at each timestep.
    1. A prediction-based RNN can learn to forecast future pose readings. If the actual SLAM readings deviate heavily from the predicted readings, the model can flag that as a SLAM failure, since the SLAM algorithm is not behaving as expected. Models following this paradigm are outlined in [Darban et al. (2023)](https://arxiv.org/html/2211.05244v3#:~:text=learning%20architectures%20,based%20models%20are%20more%20effective).
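
To make the two training objectives concrete, here is a minimal sketch of both losses on made-up tensors. All values and variable names are placeholders I picked for illustration; only the loss choices (binary cross-entropy for option 1, L2 for option 2) come from the options above.

```python
import torch
import torch.nn.functional as F

# Option 1: one failure logit per window, trained against 0/1 labels.
logits = torch.tensor([2.0, -1.0, 0.5])
labels = torch.tensor([1.0, 0.0, 0.0])
bce = F.binary_cross_entropy_with_logits(logits, labels)
# A sigmoid turns a single logit into a failure confidence in (0, 1).
confidence = torch.sigmoid(logits)

# Option 2: predicted (x, y, z) positions, trained against the true path.
pred_path = torch.zeros(5, 3)
true_path = torch.ones(5, 3)
# Mean squared error over all timesteps and axes.
l2 = F.mse_loss(pred_path, true_path)
```

Option 1 needs a failure label per window, while option 2 only needs the position data.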

I prefer option 2, as it is more grounded in its physical interpretation and gives us more to plot and show in our presentation. Note that anomalies are not labeled in our training data, so it is hard to make the model learn anomalies directly. Instead, anomaly detection usually means learning **normal** behavior during training and detecting abnormal behavior during testing.

## RNN-based SLAM Anomaly Detector

If we go with option 2, we can build an RNN scheme that follows this data pipeline:

1. IMU branch (200 Hz): Processes raw IMU packets with a small RNN to produce one embedding per camera frame.

1. Vision branch (20 Hz): Encodes each image with a CNN backbone + optional per‐frame RNN to capture appearance dynamics.

1. Fusion RNN (20 Hz): At each camera frame, takes the concatenated IMU & vision embeddings and updates a shared hidden state.

1. Prediction head: From the fusion RNN’s hidden state, outputs a prediction of the next-cycle pose sequence (Δx,Δy,Δz over the next N frames) or directly the expected average drift.

1. Anomaly flag: During inference, compute the L₂ error between true SLAM output and your model’s prediction. If error > threshold → SLAM failure.
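
The last step of the pipeline can be sketched as below. The function name `flag_slam_failure`, the example poses, and the threshold of 0.5 are placeholders of mine; how the threshold is actually chosen is left open here.

```python
import torch


def flag_slam_failure(predicted: torch.Tensor, slam: torch.Tensor, threshold: float) -> torch.Tensor:
    """Return one boolean per frame: True where the L2 error exceeds the threshold."""
    # L2 norm over the (x, y, z) axis gives one error value per frame.
    error = torch.linalg.norm(predicted - slam, dim=-1)
    return error > threshold


# Three frames of made-up positions: the SLAM output jumps away on the last one.
predicted = torch.tensor([[0.1, 0.0, 0.0], [0.2, 0.0, 0.0], [0.3, 0.0, 0.0]])
slam = torch.tensor([[0.1, 0.0, 0.0], [0.2, 0.05, 0.0], [1.3, 0.0, 0.0]])
flags = flag_slam_failure(predicted, slam, threshold=0.5)
print(flags.tolist())  # [False, False, True]
```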


### IMU Branch

For this branch, the IMU samples arrive much faster than the camera frames, so we collect all samples between the previous and current camera frame into one window. The window can be encoded with either a 1D CNN or a small RNN (a GRU here).

1D CNN variant: `[6×10] → Conv1D(filters=32, kernel=3, stride=1) → ReLU → Conv1D(64, 3, 1) → ReLU → GlobalAvgPool → Dense(D) → IMU_embed_t`

GRU variant: `for i in 1…10: h_i = GRU_cell(u_{t,i}, h_{i-1})`, then `IMU_embed_t = Dense(h_{10})`

Input: 10 IMU samples per camera frame (200 Hz vs. 20 Hz sampling), {u<sub>1</sub>, ..., u<sub>10</sub>}, where each u is a 6-dimensional vector of IMU data.

Output: One learned D-dimensional embedding of IMU data per camera frame.
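
A minimal PyTorch sketch of both IMU encoder variants follows. The class names `IMUConvEncoder` and `IMUGRUEncoder`, the choice D = 64, and the GRU hidden size are my assumptions; the layer sequence follows the two pipelines described for this branch.

```python
import torch
import torch.nn as nn


class IMUConvEncoder(nn.Module):
    """1D CNN variant: (batch, 10, 6) IMU window -> (batch, embed_dim)."""

    def __init__(self, embed_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(6, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),  # global average pool over the time axis
            nn.Flatten(),
            nn.Linear(64, embed_dim),
        )

    def forward(self, imu: torch.Tensor) -> torch.Tensor:
        # Conv1d expects (batch, channels, time), so swap the last two axes.
        return self.net(imu.transpose(1, 2))


class IMUGRUEncoder(nn.Module):
    """GRU variant: run over the 10 samples and project the final hidden state."""

    def __init__(self, embed_dim: int = 64, hidden_dim: int = 64):
        super().__init__()
        self.gru = nn.GRU(6, hidden_dim, batch_first=True)
        self.proj = nn.Linear(hidden_dim, embed_dim)

    def forward(self, imu: torch.Tensor) -> torch.Tensor:
        _, h_last = self.gru(imu)  # h_last: (1, batch, hidden_dim)
        return self.proj(h_last[-1])


imu_window = torch.randn(4, 10, 6)  # batch of 4 camera frames' worth of IMU data
cnn_embed = IMUConvEncoder(embed_dim=64)(imu_window)
gru_embed = IMUGRUEncoder(embed_dim=64)(imu_window)
print(cnn_embed.shape, gru_embed.shape)
```

Both variants produce the same output shape, so either can feed the fusion stage unchanged.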

### Vision Branch

This is a standard vision branch, but it can also include an RNN to capture temporal changes. That would look like:

1. CNN -> GAP (global average pooling) -> linear projection
    1. Image(t) ---> e<sub>t</sub> in R<sup>d</sup>
1. RNN update (here I use an LSTM with hidden state h and cell state c)
    1. (h<sub>t-1</sub>, c<sub>t-1</sub>), e<sub>t</sub> -> (h<sub>t</sub>, c<sub>t</sub>)
1. Hidden state -> vision embedding
    1. h<sub>t</sub> ---> Vision(t) in R<sup>d</sup>

Since this is the large CNN for image processing, it will take up the bulk of the computation. We should consider existing models such as MobileNet, which are optimized to be lightweight. Our project is robotics-oriented, so the savings in model size and compute will contribute a lot to real-world applicability.
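
The three steps of this branch could be sketched like this. To keep the example self-contained, I use a tiny two-layer CNN as a stand-in for the backbone; a pretrained MobileNet would replace `self.cnn` in practice. The class name `VisionBranch`, the layer sizes, and d = 64 are assumptions of mine.

```python
import torch
import torch.nn as nn


class VisionBranch(nn.Module):
    """CNN -> GAP -> linear projection -> LSTM cell, one call per camera frame."""

    def __init__(self, embed_dim: int = 64):
        super().__init__()
        # Stand-in backbone (assumption); swap in a real model for actual training.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.proj = nn.Linear(32, embed_dim)
        self.lstm = nn.LSTMCell(embed_dim, embed_dim)

    def forward(self, image, state=None):
        feat = self.gap(self.cnn(image)).flatten(1)  # (batch, 32)
        e_t = self.proj(feat)                        # per-frame embedding e_t
        h_t, c_t = self.lstm(e_t, state)             # state=None starts from zeros
        return h_t, (h_t, c_t)                       # h_t is the vision embedding


branch = VisionBranch(embed_dim=64)
frames = torch.randn(3, 2, 3, 64, 64)  # (time, batch, channels, height, width)
state = None
for frame in frames:
    vision_embed, state = branch(frame, state)
print(vision_embed.shape)
```

The LSTM state is passed back in on each call, which is what lets the embedding reflect changes across frames.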

### Fusion RNN

This fuses the two branches with an RNN to capture temporal relationships between IMU data and camera data.

Concatenate the vision and IMU embeddings (assuming both have dimension D) into a 2*D vector `x_t` and pass it into another RNN: `fus_h_t, fus_c_t = LSTM_cell(x_t, (fus_h_{t-1}, fus_c_{t-1}))`
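
A minimal sketch of this fusion step, assuming both embeddings have the same dimension D = 64 and a fusion hidden size of 128 (both numbers and the class name `FusionRNN` are my own picks). Random tensors stand in for the outputs of the two branches.

```python
import torch
import torch.nn as nn


class FusionRNN(nn.Module):
    """Concatenate IMU and vision embeddings and update a shared LSTM state."""

    def __init__(self, embed_dim: int = 64, hidden_dim: int = 128):
        super().__init__()
        self.cell = nn.LSTMCell(2 * embed_dim, hidden_dim)

    def forward(self, imu_embed, vision_embed, state=None):
        x_t = torch.cat([imu_embed, vision_embed], dim=-1)  # (batch, 2 * embed_dim)
        fus_h, fus_c = self.cell(x_t, state)                # state=None starts from zeros
        return fus_h, (fus_h, fus_c)


fusion = FusionRNN(embed_dim=64, hidden_dim=128)
state = None
for _ in range(3):  # three consecutive camera frames
    imu_embed = torch.randn(2, 64)     # placeholder for the IMU branch output
    vision_embed = torch.randn(2, 64)  # placeholder for the vision branch output
    fus_h, state = fusion(imu_embed, vision_embed, state)
print(fus_h.shape)
```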

### Prediction Head

The final part is the prediction head, which predicts the quantity we threshold on. We can do this in one of two ways.

First, we can predict the displacement over the next K frames, for some small integer K, and compare the predicted change in displacement to the actual change. At inference, we can compute the mean and standard deviation of the loss and, if the loss is above some threshold, declare a failure.
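
One way to turn the mean and standard deviation into a threshold is sketched below. This is my assumption about how they would be used: compute both on errors from nominal flights and flag anything more than k standard deviations above the mean. The function name `calibrate_threshold`, the value k = 3, and the error values are all made up for illustration.

```python
import torch


def calibrate_threshold(nominal_errors: torch.Tensor, k: float = 3.0) -> float:
    """Threshold = mean + k * std of the prediction error on nominal data."""
    return (nominal_errors.mean() + k * nominal_errors.std()).item()


# Made-up per-window errors from nominal flights.
nominal = torch.tensor([0.10, 0.12, 0.08, 0.11, 0.09])
threshold = calibrate_threshold(nominal, k=3.0)

# Made-up errors seen at inference time.
new_errors = torch.tensor([0.10, 0.50])
failures = new_errors > threshold
print(failures.tolist())  # [False, True]
```

A larger k gives fewer false alarms at the cost of reacting later to real failures.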

Alternatively, we can directly predict the L2 drift in the next cycle and declare a failure if it is above a threshold we set. If we use the drift itself as the loss function, training minimizes it on good flying data. Whenever the model then encounters high loss, it sees that what is happening is very different from the good flying data it is used to, and declares a failure.
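
Both head variants can hang off the same fusion hidden state, as in this sketch. The class name `PredictionHead`, the MLP sizes, the horizon K = 5, and the Softplus that keeps the drift non-negative are my assumptions, not settled design choices.

```python
import torch
import torch.nn as nn


class PredictionHead(nn.Module):
    """Map the fusion hidden state to next-K displacements and a scalar drift."""

    def __init__(self, hidden_dim: int = 128, horizon: int = 5):
        super().__init__()
        self.horizon = horizon
        # Variant 1: (dx, dy, dz) for each of the next `horizon` frames.
        self.displacement = nn.Sequential(
            nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Linear(64, horizon * 3)
        )
        # Variant 2: a single drift value per sample; Softplus keeps it >= 0.
        self.drift = nn.Sequential(
            nn.Linear(hidden_dim, 64), nn.ReLU(), nn.Linear(64, 1), nn.Softplus()
        )

    def forward(self, fus_h: torch.Tensor):
        deltas = self.displacement(fus_h).view(-1, self.horizon, 3)
        drift = self.drift(fus_h).squeeze(-1)
        return deltas, drift


head = PredictionHead(hidden_dim=128, horizon=5)
deltas, drift = head(torch.randn(2, 128))  # random stand-in for the fusion state
print(deltas.shape, drift.shape)
```

In practice we would likely train only one of the two outputs, depending on which variant we pick.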

```mermaid
graph TD
    IMU[IMU 200Hz] --> RNN1[RNN or 1D CNN]
    RNN1 --> I[IMU Feature Iₜ]

    CAM[Camera 20Hz] --> ResNet[ResNet CNN]
    ResNet --> GAP[Global Avg Pool]
    GAP --> Linear
    Linear --> V[Cam Feature Vₜ]

    I --> RNN2[Fusion RNN]
    V --> RNN2

    RNN2 --> Head[MLP or RNN]
    Head --> Drift[Drift L₂]
    Head --> Displacement[Next k Δ displacement]
```

In [None]:
# Example of an nn.Module wrapped with Lightning (this uses the old pytorch_lightning package; the new one is just `lightning`, imported as L)
import pytorch_lightning as pl
import torch
import torch.nn as nn


class DriftPredictor(pl.LightningModule):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model  # any nn.Module that maps inputs x to drift predictions
        self.criterion = nn.MSELoss()  # L2 loss on the predicted drift

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = self.criterion(preds, y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [None]:
# Example lightning code for RNN
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import lightning.pytorch as pl


# ---- Dummy Dataset ----
class DummyIMUDataset(Dataset):
    def __init__(self, seq_len=10, num_samples=1000):
        self.seq_len = seq_len
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        imu_seq = torch.randn(self.seq_len, 6)  # e.g. 3-axis accel + 3-axis gyro
        delta_p = imu_seq.sum(dim=0)[-3:] * 0.1  # fake "position delta"
        return imu_seq, delta_p


# ---- Lightning Module ----
class DriftPredictorRNN(pl.LightningModule):
    def __init__(self, input_dim=6, hidden_dim=64, output_dim=3, num_layers=1):
        super().__init__()
        self.rnn = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        _, (hn, _) = self.rnn(x)  # hn: (num_layers, batch, hidden_dim)
        return self.fc(hn[-1])  # Take final hidden state

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log("val_loss", loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# ---- Data ----
train_ds = DummyIMUDataset()
val_ds = DummyIMUDataset()
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32)


# ---- Train ----
model = DriftPredictorRNN()
trainer = pl.Trainer(max_epochs=10, accelerator="cpu")
trainer.fit(model, train_loader, val_loader)