# Fully unsupervised dynamic MRI reconstruction via geometrotemporal equivariance

Paper | [Repo](https://github.com/Andrewwango/ddei) | [Website](https://andrewwango.github.io/ddei)

![](img/demo_results.gif)

**Aim**: reconstruct dynamic MRI videos from accelerated undersampled measurements.

**Applications**: real-time cardiac imaging, free-breathing motion, vocal tract speech...

**Goals**:

- Capture true motion, including aperiodic and irregular motion (i.e. real-time MRI)
- Capture higher spatiotemporal resolution with fewer measurements (leading to faster, cheaper, portable MRI)

**Why is it hard?** Ground truth is impossible to obtain: there is no such thing as truly fully-sampled dynamic MRI data at the same frame rate. Hence all supervised methods are fundamentally flawed, committing a [_data crime_](https://www.pnas.org/doi/full/10.1073/pnas.2117203119). The best pseudo-ground-truth, e.g. from retrospective gating or cine imaging, must assume periodicity, so any method that relies on it cannot capture true motion, which is ultimately what we want to image. We therefore need unsupervised methods.

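**Our method**: we posit that the unknown set of MRI videos is $G$-invariant, where $G$ is a group of geometrotemporal transformations such as temporal shifts, spatial diffeomorphisms and rotations. This builds on [Equivariant Imaging](https://openaccess.thecvf.com/content/ICCV2021/papers/Chen_Equivariant_Imaging_Learning_Beyond_the_Range_Space_ICCV_2021_paper.pdf), which exploits such invariances to learn reconstruction from undersampled measurements alone.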
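A sketch of the resulting objective, written in the notation of the Equivariant Imaging paper rather than copied from DDEI: let $A$ be the undersampled dynamic MRI operator, $y = Ax$ the measurements of a video $x$, $f_\theta$ the reconstruction network and $T_g$ the action of a transform $g \in G$ on a video. $G$-invariance of the set of videos $\mathcal{X}$ means $T_g x \in \mathcal{X}$ for every $x \in \mathcal{X}$ and $g \in G$. Training then minimises a measurement consistency (MC) term plus an equivariance (EI) term, with $g$ drawn at random from $G$ at each step and $\alpha > 0$ a weighting; the exact norms and weights used in DDEI may differ.

$$
\mathcal{L}(\theta) = \underbrace{\big\| A f_\theta(y) - y \big\|_2^2}_{\text{MC}} + \alpha \, \underbrace{\big\| T_g \hat{x} - f_\theta\big(A T_g \hat{x}\big) \big\|_2^2}_{\text{EI}}, \qquad \hat{x} = f_\theta(y).
$$

The EI term supplies what MC alone cannot: $A$ discards part of every video, but a transformed video $T_g \hat{x}$ is probed by $A$ differently, so requiring $f_\theta$ to recover it constrains the reconstruction in the null space of $A$.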

You can implement our method with the [`deepinv`](https://deepinv.github.io) library. See [train.py](train.py) for a full demo, including training and evaluating competing methods. For example:

```python
import torch
from torch.utils.data import DataLoader

import deepinv as dinv

# Helpers from this repo's utils module
from utils import Trainer, ArtifactRemovalCRNN, CRNN, DeepinvSliceDataset, CineNetDataTransform
```

### Define dynamic MRI physics:

We define the accelerated dynamic MRI physics. The undersampling mask (4x, 8x or 16x acceleration) is set on the fly, since it varies per subject.

```python
# img_size = (batch, 2 [real/imag], frames, height, width)
physics = dinv.physics.DynamicMRI(img_size=(1, 2, 12, 512, 256), device="cpu")
```
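For intuition, here is a minimal NumPy sketch of what a single-coil Cartesian dynamic MRI operator computes: each frame $x_t$ is Fourier transformed and multiplied by its own binary mask $M_t$, so $y_t = M_t \mathcal{F} x_t$. Everything here is an assumption rather than deepinv's `DynamicMRI` implementation: the function names are hypothetical, the video is a complex `(T, H, W)` array instead of deepinv's real/imaginary channels, and `gaussian_line_mask` is a made-up variable-density line sampler, not the CMRxRecon `TimeVaryingGaussianMask08` masks.

```python
import numpy as np

def dynamic_mri_forward(x, mask):
    """A x: orthonormal 2D FFT of every frame of x (shape (T, H, W)), then masking."""
    return mask * np.fft.fft2(x, axes=(-2, -1), norm="ortho")

def dynamic_mri_adjoint(y, mask):
    """A^H y: zero-filled inverse 2D FFT of every frame."""
    return np.fft.ifft2(mask * y, axes=(-2, -1), norm="ortho")

def gaussian_line_mask(T, H, W, acc, sigma=0.15, seed=0):
    """Hypothetical time-varying mask: H // acc phase-encode lines per frame,
    drawn without replacement with Gaussian density around the k-space centre."""
    rng = np.random.default_rng(seed)
    freqs = np.fft.fftfreq(H)            # unshifted layout: DC at index 0
    p = np.exp(-0.5 * (freqs / sigma) ** 2)
    p /= p.sum()
    mask = np.zeros((T, H, W))
    for t in range(T):                   # a fresh draw per frame -> time-varying
        rows = rng.choice(H, size=H // acc, replace=False, p=p)
        mask[t, rows, :] = 1.0
    return mask

# Example: 8x acceleration on a small random complex video
rng = np.random.default_rng(1)
T, H, W = 12, 64, 32
x = rng.standard_normal((T, H, W)) + 1j * rng.standard_normal((T, H, W))
mask = gaussian_line_mask(T, H, W, acc=8)
y = dynamic_mri_forward(x, mask)
x_zf = dynamic_mri_adjoint(y, mask)      # zero-filled reconstruction
print(y.shape, mask[0].mean())           # fraction of k-space sampled in frame 0
```

Drawing a new set of lines for every frame is what makes the sampling time-varying: each frame sees a different subset of k-space.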

### Define the DDEI loss function:

See [train.py](train.py) for a full demo of training with competing methods' losses using [deepinv](https://deepinv.github.io).

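```python
# Geometrotemporal transforms: `|` randomly selects one of its two operands each time it is applied
transform = dinv.transform.ShiftTime() | (dinv.transform.CPABDiffeomorphism() | dinv.transform.Rotate())
# Measurement consistency loss + equivariant imaging loss
loss = [dinv.loss.MCLoss(), dinv.loss.EILoss(transform=transform)]
```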
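To see what the two loss terms compute, they can be sketched in NumPy on a toy problem. This is not deepinv's `MCLoss`/`EILoss` implementation: the operator is a hypothetical single-coil masked FFT written as plain functions, the "network" $f$ is simply zero-filling $A^H y$, and $T_g$ is a circular time shift (`np.roll`) standing in for `ShiftTime`. It illustrates why the EI term matters: zero-filling already satisfies measurement consistency exactly, so MC alone cannot reject it, but with time-varying masks it fails the equivariance check.

```python
import numpy as np

def forward(x, mask):
    # A x: masked orthonormal 2D FFT of each frame; x has shape (T, H, W)
    return mask * np.fft.fft2(x, axes=(-2, -1), norm="ortho")

def adjoint(y, mask):
    # A^H y: zero-filled inverse FFT of each frame
    return np.fft.ifft2(mask * y, axes=(-2, -1), norm="ortho")

def mse(a, b):
    return np.mean(np.abs(a - b) ** 2)

def shift_time(x, shift):
    # hypothetical T_g: circular shift of the video along the time axis
    return np.roll(x, shift, axis=0)

def mc_loss(y, x_hat, mask):
    # measurement consistency: mean |A x_hat - y|^2
    return mse(forward(x_hat, mask), y)

def ei_loss(x_hat, recon, mask, shift):
    # equivariance: mean |T_g x_hat - f(A T_g x_hat)|^2
    x_g = shift_time(x_hat, shift)
    return mse(x_g, recon(forward(x_g, mask)))

def losses_for(mask, x, shift=3):
    y = forward(x, mask)
    recon = lambda meas: adjoint(meas, mask)   # toy f: zero-filled reconstruction
    x_hat = recon(y)
    return mc_loss(y, x_hat, mask), ei_loss(x_hat, recon, mask, shift)

rng = np.random.default_rng(0)
T, H, W = 12, 32, 16
x = rng.standard_normal((T, H, W))
# random phase-encode lines (about 25% per frame), drawn independently per frame
varying = (rng.random((T, H, 1)) < 0.25) * np.ones((1, 1, W))
# the same lines in every frame
static = np.broadcast_to(varying[:1], (T, H, W))

mc_var, ei_var = losses_for(varying, x)
mc_sta, ei_sta = losses_for(static, x)
print(f"time-varying mask: MC={mc_var:.2e}, EI={ei_var:.2e}")  # MC ~ 0, EI > 0
print(f"static mask:       MC={mc_sta:.2e}, EI={ei_sta:.2e}")  # both ~ 0
```

In this toy setting the static-mask case shows that a temporal shift adds no information when every frame is sampled identically; it is the per-frame variation of the masks that makes the time-shift EI term informative.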

### Define the data:

```python
dataset = DeepinvSliceDataset(
    root="data/CMRxRecon",
    transform=CineNetDataTransform(time_window=12, apply_mask=True, normalize=True),
    set_name="TrainingSet",
    acc_folders=["FullSample"],
    mask_folder="TimeVaryingGaussianMask08",
    dataset_cache_file="dataset_cache_new.pkl",
)
```

### Define neural network:

For the reconstruction network $f_\theta$ we use a very small [CRNN](https://ieeexplore.ieee.org/document/8425639), a lightweight unrolled network with 2 unrolled iterations and 1154 parameters. Our framework is **NN-agnostic**: any state-of-the-art network can be used as the backbone.

```python
model = ArtifactRemovalCRNN(CRNN(num_cascades=2)).to("cpu")
```
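Unrolled networks such as CRNN alternate a learned regularization step with a data-consistency (DC) step that re-imposes the measured k-space. The NumPy sketch below shows that generic cascade structure only; it is not the repo's `CRNN`. The `temporal_smoother` "denoiser" is a hypothetical, untrained placeholder for a CNN, and the operator is an assumed single-coil Cartesian masked FFT. Because the regularizer is just a function argument, the cascade itself does not depend on the choice of backbone.

```python
import numpy as np

def data_consistency(x, y, mask, lam=None):
    """DC step: make the k-space of x agree with the measurements y where mask == 1.

    lam=None replaces sampled entries outright (noiseless case); a finite lam
    blends them as (k + lam * y) / (1 + lam).
    """
    k = np.fft.fft2(x, axes=(-2, -1), norm="ortho")
    if lam is None:
        k = (1 - mask) * k + mask * y
    else:
        k = (1 - mask) * k + mask * (k + lam * y) / (1 + lam)
    return np.fft.ifft2(k, axes=(-2, -1), norm="ortho")

def temporal_smoother(x):
    # hypothetical placeholder for a trained CNN: returns a residual that, added
    # to x, moves each frame halfway towards the mean of its two temporal
    # neighbours (circular in time)
    neighbours = 0.5 * (np.roll(x, 1, axis=0) + np.roll(x, -1, axis=0))
    return 0.5 * (neighbours - x)

def unrolled_recon(y, mask, regularizer, num_cascades=2):
    # start from the zero-filled reconstruction A^H y
    x = np.fft.ifft2(mask * y, axes=(-2, -1), norm="ortho")
    for _ in range(num_cascades):
        x = x + regularizer(x)               # regularization step (a CNN in practice)
        x = data_consistency(x, y, mask)     # re-impose the measured k-space
    return x

rng = np.random.default_rng(0)
T, H, W = 12, 32, 16
x_true = rng.standard_normal((T, H, W))
mask = (rng.random((T, H, 1)) < 0.25) * np.ones((1, 1, W))
y = mask * np.fft.fft2(x_true, axes=(-2, -1), norm="ortho")

x_rec = unrolled_recon(y, mask, temporal_smoother, num_cascades=2)
# hard DC leaves the sampled k-space entries equal to the measurements
residual = np.abs(mask * np.fft.fft2(x_rec, axes=(-2, -1), norm="ortho") - y).max()
print(f"max k-space residual at sampled locations: {residual:.1e}")
```

Hard replacement suits noiseless data; the `lam` variant trades off the network's k-space against noisy measurements.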

### Train the network!

We train the network using a modified [`deepinv.Trainer`](https://deepinv.github.io/deepinv/stubs/deepinv.Trainer.html). For the full training demo, see [train.py](train.py).

```python
trainer = Trainer(
    model=model,
    physics=physics,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    train_dataloader=DataLoader(dataset=dataset),
    losses=loss,
    metrics=dinv.metric.PSNR(complex_abs=True, max_pixel=None),
)

trainer.train()
```

### Full results

Example test-set reconstructions of cardiac long-axis views (top two rows) and a short-axis slice (bottom):

![](img/results_fig_1.gif)