Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
36c1267
initial unet recon demo commit
mersad95zd Jul 31, 2022
58c9be3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2022
b04d380
fixed formatting errors
mersad95zd Jul 31, 2022
73d6f92
Merge branch '01-recon-unet-demo' of https://github.com/mersad95zd/tu…
mersad95zd Jul 31, 2022
c11f9d1
Merge branch 'main' into 01-recon-unet-demo
mersad95zd Jul 31, 2022
a5659f1
fixed formatting errors
mersad95zd Jul 31, 2022
0d961cc
Merge branch '01-recon-unet-demo' of https://github.com/mersad95zd/tu…
mersad95zd Jul 31, 2022
cc765de
fixing formatting errors
mersad95zd Aug 1, 2022
cdc350e
Merge branch 'main' into 01-recon-unet-demo
wyli Aug 1, 2022
9d395f8
more experimental details added
mersad95zd Aug 1, 2022
5f55387
Merge branch 'main' into 01-recon-unet-demo
mersad95zd Aug 1, 2022
6d3ba40
Merge branch '01-recon-unet-demo' of https://github.com/mersad95zd/tu…
mersad95zd Aug 1, 2022
68d35fa
removed checkpoint from this PR; minor fix to checkpoint directory in…
mersad95zd Aug 3, 2022
400a2ef
Merge branch 'main' into 01-recon-unet-demo
mersad95zd Aug 3, 2022
04bf083
Merge branch '01-recon-unet-demo' of https://github.com/mersad95zd/tu…
mersad95zd Aug 3, 2022
430f7f6
clarified common practice for fastMRI inference
mersad95zd Aug 3, 2022
9eaf038
fixed model checkpoint name in all files
mersad95zd Aug 3, 2022
df5df4b
added model checkpoint link to readme
mersad95zd Aug 3, 2022
8fb97c0
init varnet commit
mersad95zd Sep 7, 2022
8291f29
back to init commit
mersad95zd Sep 13, 2022
e1ca9c8
Merge branch 'main' into 02-recon-varnet-demo
mersad95zd Sep 13, 2022
670252d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 13, 2022
7947977
updated readmes
mersad95zd Sep 13, 2022
c0bdc21
Merge branch '02-recon-varnet-demo' of https://github.com/mersad95zd/…
mersad95zd Sep 13, 2022
068fd21
removed unet files
mersad95zd Sep 13, 2022
a813429
updated inference files based on the data split
mersad95zd Sep 13, 2022
6be24e1
minor update to unet inference
mersad95zd Sep 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,17 @@ This folder contains code to train and validate a U-Net for accelerated MRI reco
# Dataset

The experiments are performed on the [fastMRI](https://fastmri.org/dataset) dataset. Users should request access to the dataset
from the [owner's website](https://fastmri.org/dataset).
from the [owner's website](https://fastmri.org/dataset). Remember to use the `$PATH` where you downloaded the data in `train.py`
or `inference.ipynb` accordingly.

**Note.** Since the ground truth is not released with the test set of the fastMRI dataset, it is a common practice in the literature
to perform inference on the validation set of the fastMRI dataset. This could be in the form of testing on the whole validation
set (for example this work [https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8767765/](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8767765/)).
<br>
Another approach is to split the validation set into validation and test sets and keep the test portion for inference (for exmple this work [https://arxiv.org/pdf/2111.02549.pdf](https://arxiv.org/pdf/2111.02549.pdf)). Note that both approaches are conceptually similar
in that splitting the validation set does not change the fact that the splits belong to the same distribution.
<br>
Other workarounds to this problem include (1) skipping validation during training and saving the model checkpoint of the last epoch for inference on the validation set, and (2) submitting model results to the [fastMRI public leaderboard](https://fastmri.org/leaderboards/).
For our experiments we created a subset of the fastMRI dataset which contains a `500/179/133` split for `train/val/test`. Please download [fastmri_data_split.json](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/fastmri_data_split.json) and put it here under `./data`.

# Model checkpoint

We have already provided a model checkpoint [unet_mri_reconstruction.pt](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/unet_mri_reconstruction.pt) for a U-Net with `7,782,849` parameters. To obtain this checkpoint, we trained
a U-Net with the default hyper-parameters in `train.py` on the T2 subset of the brain dataset (`500` training and `180` validation volumes). The user can train their model on an arbitrary portion of the dataset.
a U-Net with the default hyper-parameters in `train.py` on the T2 subset of the brain dataset. The user can train their model on an arbitrary portion of the dataset.

Our checkpoint achieves `0.9496` SSIM on the fastMRI T2 validation subset which is comparabale to the original result reported on the
[fastMRI public leaderboard](https://fastmri.org/leaderboards/) (which is `0.9474` SSIM). The training dynamics for our checkpoint is depicted in the figure below.
The training dynamics for our checkpoint is depicted in the figure below.

<p align="center"><img src="./figures/dynamics.PNG" width="800" height="225"></p>

Expand All @@ -71,5 +64,9 @@ Running `train.py` trains a U-Net. The default setup automatically detects a GPU

# Inference

The notebook `inference.ipynb` contains an example to perform validation. Average SSIM score over the validation set is computed and then
The notebook `inference.ipynb` contains an example to perform inference. Average SSIM score over the test subset is computed and then
one sample is picked for visualization.

Our checkpoint achieves `0.9436` SSIM on our test subset which is comparable to the original result reported on the
[fastMRI public leaderboard](https://fastmri.org/leaderboards/) (which is `0.9474` SSIM). Note that the results reported
on the leaderboard are for the unreleased test set. Moreover, the leaderboard model is trained on the validation set.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"import torch\n",
"import warnings\n",
"import random\n",
"import json\n",
"from fastmri_ssim import skimage_ssim\n",
"import matplotlib.pyplot as plt\n",
"\n",
Expand Down Expand Up @@ -76,7 +77,7 @@
" self.batch_size = 1 # can be set to >1 when input sizes are not different\n",
" self.num_workers = 0\n",
" self.cache_rate = 0.0 # what fraction of the data to be cached for faster loading\n",
" self.data_path_val = '/data/fastmri/fastMRI/multicoil_val_t2/' # path to the validation set\n",
" self.data_path_val = '/data/fastmri/multicoil_val/' # path to the validation set\n",
" self.sample_rate = 0.9 # select 0.9 of the validation set for inference\n",
" self.accelerations = [4] # acceleration factors used for valdiation.\n",
" self.center_fractions = [0.08] # center_fractions used for valdiation.\n",
Expand Down Expand Up @@ -104,16 +105,29 @@
"# Create validation data loader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(\"./data/fastmri_data_split.json\", \"r\") as fn:\n",
" data = json.load(fn)\n",
"test_files = data['test_files']\n",
"fastmri_val_set = list(Path(args.data_path_val).iterdir())\n",
"test_files = [f for f in fastmri_val_set if str(f).split('/')[-1] in test_files]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"val_files = list(Path(args.data_path_val).iterdir())\n",
"random.shuffle(val_files)\n",
"val_files = val_files[:int(args.sample_rate*len(val_files))] # select a subset of the data according to sample_rate\n",
"val_files = [dict([(\"kspace\", val_files[i])]) for i in range(len(val_files))]\n",
"random.shuffle(test_files)\n",
"test_files = test_files[:int(args.sample_rate*len(test_files))] # select a subset of the data according to sample_rate\n",
"test_files = [dict([(\"kspace\", test_files[i])]) for i in range(len(test_files))]\n",
"print(f'#test files: {len(test_files)}')\n",
"\n",
"# define mask transform type (e.g., whether it is equispaced or random)\n",
"if args.mask_type == 'random':\n",
Expand All @@ -129,7 +143,7 @@
" spatial_dims=2,\n",
" is_complex=True)\n",
"\n",
"val_transforms = Compose(\n",
"test_transforms = Compose(\n",
" [\n",
" LoadImaged(keys=[\"kspace\"], reader=FastMRIReader, dtype=np.complex64),\n",
" # user can also add other random transforms\n",
Expand All @@ -145,10 +159,10 @@
" ]\n",
")\n",
"\n",
"val_ds = CacheDataset(\n",
" data=val_files, transform=val_transforms,\n",
"test_ds = CacheDataset(\n",
" data=test_files, transform=test_transforms,\n",
" cache_rate=args.cache_rate, num_workers=args.num_workers)\n",
"val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)"
"test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)"
]
},
{
Expand Down Expand Up @@ -203,30 +217,22 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"161 volume out of 161 done. \r"
]
}
],
"outputs": [],
"source": [
"outputs = defaultdict(list)\n",
"targets = defaultdict(list)\n",
"with torch.no_grad():\n",
" val_ssim = list()\n",
" step = 1\n",
" for val_data in val_loader:\n",
" for test_data in test_loader:\n",
" input, target, mean, std, fname = (\n",
" val_data[\"kspace_masked_ifft\"],\n",
" val_data[\"reconstruction_rss\"],\n",
" val_data[\"mean\"],\n",
" val_data[\"std\"],\n",
" val_data[\"kspace_meta_dict\"][\"filename\"]\n",
" test_data[\"kspace_masked_ifft\"],\n",
" test_data[\"reconstruction_rss\"],\n",
" test_data[\"mean\"],\n",
" test_data[\"std\"],\n",
" test_data[\"kspace_meta_dict\"][\"filename\"]\n",
" )\n",
"\n",
" # iterate through all slices:\n",
Expand All @@ -247,7 +253,7 @@
" # save volume slices according to volume name given by fname\n",
" outputs[fname[0]].append(output.data.cpu().numpy()[0][0]*_std+_mean)\n",
" targets[fname[0]].append(tar.numpy()[0][0]*_std+_mean)\n",
" print(step, ' volume out of', len(val_files), 'done.', '\\r', end='')\n",
" print(step, ' volume out of', len(test_files), 'done.', '\\r', end='')\n",
" step += 1\n",
"\n",
" # compute validation ssims values for all validation samples\n",
Expand All @@ -261,14 +267,14 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"average SSIM score over the validation set: 0.9496\n"
"average SSIM score over the validation set: 0.9436\n"
]
}
],
Expand Down
69 changes: 69 additions & 0 deletions reconstruction/MRI_reconstruction/varnet_demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Accelerated MRI reconstruction with the end-to-end variational network (e2e-VarNet)

<p align="center"><img src="./figures/workflow.PNG" width="800" height="400"></p>


This folder contains code to train and validate an e2e-VarNet ([https://arxiv.org/pdf/2004.06688.pdf](https://arxiv.org/pdf/2004.06688.pdf)) for accelerated MRI reconstruction. Accelerated MRI reconstruction is a compressed sensing task where the goal is to recover a ground-truth image from an under-sampled measurement. The under-sampled measurement is based in the frequency domain and is often called the $k$-space.

***

### List of contents

* [Questions and bugs](#Questions-and-bugs)

* [Dataset](#Dataset)

* [Model checkpoint](#Model-checkpoint)

* [Training](#Training)

* [Inference](#Inference)

***

# Questions and bugs

- For questions relating to the use of MONAI, please us our [Discussions tab](https://github.com/Project-MONAI/MONAI/discussions) on the main repository of MONAI.
- For bugs relating to MONAI functionality, please create an issue on the [main repository](https://github.com/Project-MONAI/MONAI/issues).
- For bugs relating to the running of a tutorial, please create an issue in [this repository](https://github.com/Project-MONAI/Tutorials/issues).

# Dataset

Please see [dataset description](../unet_demo/README.md#dataset) for our dataset preparation.


# Model checkpoint

We have already provided a model checkpoint [varnet_mri_reconstruction.pt](https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/varnet_mri_reconstruction.pt) for a VarNet with `30,069,558` parameters. To obtain this checkpoint, we trained
a VarNet with the default hyper-parameters in `train.py` on our T2 subset of the brain dataset. The user can train their model on an arbitrary portion of the dataset.

The training dynamics for our checkpoint is depicted in the figure below.

<p align="center"><img src="./figures/dynamics.PNG" width="800" height="250"></p>

# Training

Running `train.py` trains a VarNet. The default setup automatically detects a GPU for training; if not available, CPU will be used.

# Run this to get a full list of training arguments
python ./train.py -h

# This is an example of calling train.py
python ./train.py
--data_path_train train_dir \
--data_path_val val_dir \
--exp varnet_mri_recon \
--exp_dir ./ \
--mask_type equispaced \
--num_epochs 50 \
--num_workers 0 \
--lr 0.00001

# Inference

The notebook `inference.ipynb` contains an example to perform inference. Average SSIM score over the test subset is computed and then
one sample is picked for visualization.

Our checkpoint achieves `0.9650` SSIM on our test subset which is comparable to the original result reported on the
[fastMRI public leaderboard](https://fastmri.org/leaderboards/) (which is `0.9606` SSIM). Note that the results reported
on the leaderboard are for the unreleased test set. Moreover, the leaderboard model is trained on the validation set.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading