# VORTEX
**Brief:** Build an interactive application for testing different models under distribution shifts.

Common acquisition-related perturbations in MRI, such as changes in signal-to-noise ratio (SNR) and patient motion, can substantially degrade the quality of the reconstructed images.
VORTEX helps mitigate this problem by building invariance to these perturbations during model training. VORTEX also helps reduce the amount of fully-sampled data required for training.

In this demo, we will use [Meddlr](https://github.com/ad12/meddlr) and [Meerkat](https://github.com/HazyResearch/meerkat) to build interactive applications to explore how
different models perform under these two distribution shifts. We will learn how to:

- Convert Meddlr datasets to Meerkat dataframes
- Use pre-built interfaces in Meerkat to visualize our data
- Integrate pre-trained or custom models into Meerkat

**Reference:**
    Desai et al. VORTEX: Physics-Driven Data Augmentations Using Consistency
    Training for Robust Accelerated MRI Reconstruction. MIDL 2022.

**Requirements:**
- `pip install meerkat-ml meddlr`
- `pip install torch torchvision`

In [1]:
import meddlr as mr
import meerkat as mk

%load_ext autoreload
%autoreload 2

## Start Meerkat Server
Starting the Meerkat server allows us to interact with the Meerkat application in the notebook.

If you do not see a view at the bottom of the notebook, change the `api_port` and `frontend_port` and restart the notebook.
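
Before changing the ports, it can help to check which ones are actually free. The helper below is a minimal sketch that uses only the Python standard library; `is_port_free` is a hypothetical name introduced for illustration and is not part of Meerkat.

```python
import socket


def is_port_free(port: int, host: str = "127.0.0.1") -> bool:
    """Return True if nothing is currently bound to (host, port)."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            return False
    return True


# Example: check the default ports used in this notebook.
for name, port in {"api_port": 5000, "frontend_port": 8000}.items():
    status = "free" if is_port_free(port) else "in use"
    print(f"{name}={port} is {status}")
```

If a port is reported as in use, pick another pair of ports and pass them to `mk.gui.start`.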

**Remote Server:** If you are running this notebook on a remote machine, you will need to forward the API and frontend ports to your local machine:

```bash
# If api_port = 5000 and frontend_port = 8000
ssh -L 5000:localhost:5000 -L 8000:localhost:8000 <user>@<remote-machine>
```

In [2]:
mk.gui.start(api_port=5000, frontend_port=8000, dev=False)

(APIInfo(api=<fastapi.applications.FastAPI object at 0x2ba03fdc0>, port=5000, server=<meerkat.interactive.server.Server object at 0x2ba2242e0>, name='127.0.0.1', shared=False, process=None, _url=None),
 FrontendInfo(package_manager='npm', port=8000, name='localhost', shared=False, process=<subprocess.Popen object at 0x2ba224310>, _url=None))

## Build DataFrame
Meerkat DataFrames help manage complex data types, such as high-dimensional images and k-space data.

Let's convert the mridata Stanford 3D knee FSE test split into a Meerkat DataFrame.
The dataset is in the ismrmrd HDF5 format, with an additional field for sensitivity maps (`maps`).

Each row in the dataframe will correspond to one axial slice of a scan from the knee dataset.
Columns will include:

- `kspace`: The fully-sampled kspace for the `ky x kz` slice
- `target`: The ground truth slice
- `maps`: The sensitivity maps for the slice

**Note:** If you have the dataset downloaded locally with `meddlr`, you can fetch the dataset using `DatasetCatalog`. 

In [3]:
from meddlr.data import DatasetCatalog

# mridata Stanford 3D FSE dataset.
paths = [
    "https://huggingface.co/datasets/arjundd/mridata-stanford-knee-3d-fse/resolve/main/files/ec00945c-ad90-46b7-8c38-a69e9e801074.h5",
    "https://huggingface.co/datasets/arjundd/mridata-stanford-knee-3d-fse/resolve/main/files/ee2efe48-1e9d-480e-9364-e53db01532d4.h5",
    "https://huggingface.co/datasets/arjundd/mridata-stanford-knee-3d-fse/resolve/main/files/efa383b6-9446-438a-9901-1fe951653dbd.h5",
]

# If you have the Stanford 3D FSE dataset downloaded locally, you can use this:
# dataset_dicts = DatasetCatalog.get("mridata_knee_2019_test")
# paths = [d["file_name"] for d in dataset_dicts]

In [4]:
# Convert paths to the slice dataframe.
from meddlr_viz.utils import build_slice_df

df = build_slice_df(paths, defer=True)
df["id"] = df["path"].apply(lambda x: x.split("/")[-1].split(".")[0])
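
The `id` column is the file name without its extension. As a standalone illustration of what the lambda above does to one of the URLs (plain Python, no Meerkat required; `path_to_id` is a hypothetical helper name used only here):

```python
def path_to_id(path: str) -> str:
    """Mirror the lambda above: keep the last path component, drop the extension."""
    return path.split("/")[-1].split(".")[0]


url = (
    "https://huggingface.co/datasets/arjundd/mridata-stanford-knee-3d-fse/"
    "resolve/main/files/ec00945c-ad90-46b7-8c38-a69e9e801074.h5"
)
print(path_to_id(url))  # ec00945c-ad90-46b7-8c38-a69e9e801074
```

Note that this keeps only the text before the first dot, so a file name that contains additional dots would be truncated.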



## Reconstruction Models
Meddlr offers several reconstruction models in its Model Zoo and easy-to-use [APIs](https://github.com/ad12/meddlr#-model-zoo).

Let's start by using a few pre-trained models from the VORTEX paper. These models are hosted on Hugging Face in the Meddlr format.
Providing the URLs for these models will automatically download and load them!

In [5]:
# Pre-trained models from the VORTEX paper.
# More pre-trained models are available at https://github.com/ad12/meddlr/blob/main/projects/vortex/MODEL_ZOO.md
MODELS = {
    "Supervised": "https://huggingface.co/arjundd/noise2recon-release/resolve/main/mridata_knee_3dfse/12x/Supervised_1sub",
    "Supervised + Aug": "https://huggingface.co/arjundd/vortex-release/resolve/main/mridata_knee_3dfse/Aug_Physics",
    "SSDU": "https://huggingface.co/arjundd/vortex-release/resolve/main/mridata_knee_3dfse/SSDU",
    "VORTEX": "https://huggingface.co/arjundd/vortex-release/resolve/main/mridata_knee_3dfse/VORTEX_Physics",
}

### Aside: Adding your own models

Interested in using your own models? No problem! Just write a wrapper module for your model.

Let's make a dummy model that takes in kspace and returns the zero-filled reconstruction.

In [6]:
import torch
from torch import nn
from typing import Dict
from meddlr.forward.mri import SenseModel

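class ZeroFilledModel(nn.Module):
    """Dummy model: returns the zero-filled (SENSE adjoint) reconstruction."""

    def forward(self, inputs: Dict[str, torch.Tensor]):
        # A is the SENSE operator built from the coil maps and sampling mask.
        A = SenseModel(inputs["maps"], weights=inputs["mask"])
        # The adjoint maps the undersampled k-space back to image space.
        return A(inputs["kspace"], adjoint=True)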

MODELS["Dummy Model"] = ZeroFilledModel()
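
To make the zero-filled reconstruction concrete, here is a NumPy sketch of what a SENSE adjoint computes on a synthetic slice. It is an illustration under simplifying assumptions (uncentered orthonormal FFTs, arrays shaped `(ky, kz, coils)`, random synthetic maps, hypothetical function names), not Meddlr's `SenseModel` implementation.

```python
import numpy as np


def sense_forward(image, maps, mask):
    """Image (ky, kz) -> masked multi-coil k-space (ky, kz, coils)."""
    coil_images = maps * image[..., None]
    kspace = np.fft.fft2(coil_images, axes=(0, 1), norm="ortho")
    return mask[..., None] * kspace


def sense_adjoint(kspace, maps, mask):
    """Masked multi-coil k-space -> coil-combined image (zero-filled recon)."""
    coil_images = np.fft.ifft2(mask[..., None] * kspace, axes=(0, 1), norm="ortho")
    return np.sum(np.conj(maps) * coil_images, axis=-1)


rng = np.random.default_rng(0)
ky, kz, coils = 16, 16, 4
image = rng.standard_normal((ky, kz)) + 1j * rng.standard_normal((ky, kz))

# Synthetic sensitivity maps, normalized so that sum_c |maps_c|^2 = 1 per pixel.
maps = rng.standard_normal((ky, kz, coils)) + 1j * rng.standard_normal((ky, kz, coils))
maps /= np.linalg.norm(maps, axis=-1, keepdims=True)

full_mask = np.ones((ky, kz))
half_mask = (rng.random((ky, kz)) < 0.5).astype(float)  # roughly 2x undersampling

# With full sampling and normalized maps, the adjoint recovers the image.
recon_full = sense_adjoint(sense_forward(image, maps, full_mask), maps, full_mask)
# With undersampling, unacquired samples stay zero and the result is aliased.
recon_zf = sense_adjoint(sense_forward(image, maps, half_mask), maps, half_mask)

print("fully sampled matches:", np.allclose(recon_full, image))
print("relative error (zero-filled):", np.linalg.norm(recon_zf - image) / np.linalg.norm(image))
```

The zero-filled result is a useful baseline: any learned model in `MODELS` should improve on it once the k-space is undersampled.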

## `MRIPerturbationInference` Interface

We added `MRIPerturbationInference` to Meerkat's suite of pre-built interfaces.

In this interface we can:
- Interactively control the SNR, 1D translational motion extent, and acceleration we apply to the k-space
- Toggle what models we want to test
- Change the scans that we want to visualize

This interface gives a quick way to visualize results from your models without the overhead of writing the scans to disk.
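
To give a feel for the three controls, the sketch below applies comparable perturbations to a synthetic k-space array with NumPy. This is a simplified illustration, not the code behind `MRIPerturbationInference`: the function names, the noise parameterization (`sigma`), the motion model (a random phase error per `ky` line, bounded by `alpha * pi`), and the uniform random sampling mask are all assumptions made for this example.

```python
import numpy as np


def add_noise(kspace, sigma, rng):
    """Add complex Gaussian noise; real and imaginary parts each have std sigma."""
    noise = rng.standard_normal(kspace.shape) + 1j * rng.standard_normal(kspace.shape)
    return kspace + sigma * noise


def add_motion(kspace, alpha, rng):
    """Apply a random phase error in [-alpha*pi, alpha*pi] to each ky line."""
    phase = rng.uniform(-alpha * np.pi, alpha * np.pi, size=(kspace.shape[0], 1))
    return kspace * np.exp(-1j * phase)


def undersample(kspace, acceleration, rng):
    """Keep roughly 1/acceleration of the samples and zero out the rest."""
    mask = rng.random(kspace.shape) < 1.0 / acceleration
    return kspace * mask, mask


rng = np.random.default_rng(0)
kspace = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))

noisy = add_noise(kspace, sigma=0.2, rng=rng)  # lower SNR
moved = add_motion(kspace, alpha=0.4, rng=rng)  # phase errors only
sampled, mask = undersample(kspace, acceleration=12, rng=rng)

print("fraction of samples kept:", mask.mean())
```

In this sketch the motion perturbation changes only the phase of each line, so the k-space magnitudes are unchanged, while noise and undersampling alter or remove the samples themselves.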

**Note:** All reconstructions are computed dynamically. If you are using a CPU, this may take a while. If you have access to a GPU, we recommend using it for this demo.

In [7]:
from meddlr_viz.gui.perturbation import MRIPerturbationInference
view = MRIPerturbationInference(df, models=MODELS, acc=(12, 24, 1))
view._get_ipython_height = lambda: "600px"

view
