Skip to content

Commit

Permalink
Stein's Unbiased Risk Estimator (SURE) loss and Conjugate Gradient (P…
Browse files Browse the repository at this point in the history
…roject-MONAI#7308)

### Description

Based on the discussion topic
[here](Project-MONAI#7161 (comment)),
we implemented the Conjugate-Gradient algorithm for linear operator
inversion, and Stein's Unbiased Risk Estimator (SURE) [1] loss for
ground-truth-date free diffusion process guidance that is proposed in
[2] and illustrated in the algorithm below:

<img width="650" alt="Screenshot 2023-12-10 at 10 19 25 PM"
src="https://github.com/Project-MONAI/MONAI/assets/8581162/97069466-cbaf-44e0-b7a7-ae9deb8fd7f2">

The Conjugate-Gradient (CG) algorithm is used to solve for the inversion
of the linear operator in Line-4 in the algorithm above, where the
linear operator is too large to store explicitly as a matrix (such as
FFT/IFFT of an image) and invert directly. Instead, we can solve for the
linear inversion iteratively as in CG.

The SURE loss is applied for Line-6 above. This is a differentiable loss
function that can be used to train/giude an operator (e.g. neural
network), where the pseudo ground truth is available but the reference
ground truth is not. For example, in the MRI reconstruction, the pseudo
ground truth is the zero-filled reconstruction and the reference ground
truth is the fully sampled reconstruction. The reference ground truth is
not available due to the lack of fully sampled.

**Reference**
[1] Stein, C.M.: Estimation of the mean of a multivariate normal
distribution. Annals of Statistics 1981 [[paper
link](https://projecteuclid.org/journals/annals-of-statistics/volume-9/issue-6/Estimation-of-the-Mean-of-a-Multivariate-Normal-Distribution/10.1214/aos/1176345632.full)]
[2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with
Diffusion Models. MICCAI 2023
[[paper link](https://arxiv.org/pdf/2310.01799.pdf)]

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: chaoliu <chaoliu@nvidia.com>
Signed-off-by: cxlcl <chaoliucxl@gmail.com>
Signed-off-by: chaoliu <chaoliucxl@gmail.com>
Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
Signed-off-by: Yu0610 <612410030@alum.ccu.edu.tw>
  • Loading branch information
4 people authored and Yu0610 committed Apr 11, 2024
1 parent ed799b9 commit 1b723c8
Show file tree
Hide file tree
Showing 8 changed files with 450 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ Reconstruction Losses
.. autoclass:: JukeboxLoss
:members:

`SURELoss`
~~~~~~~~~~
.. autoclass:: SURELoss
:members:


Loss Wrappers
-------------
Expand Down
5 changes: 5 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,11 @@ Layers
.. autoclass:: LLTM
:members:

`ConjugateGradient`
~~~~~~~~~~~~~~~~~~~
.. autoclass:: ConjugateGradient
:members:

`Utilities`
~~~~~~~~~~~
.. automodule:: monai.networks.layers.convutils
Expand Down
1 change: 1 addition & 0 deletions monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,5 +41,6 @@
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
from .ssim_loss import SSIMLoss
from .sure_loss import SURELoss
from .tversky import TverskyLoss
from .unified_focal_loss import AsymmetricUnifiedFocalLoss
200 changes: 200 additions & 0 deletions monai/losses/sure_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Callable, Optional

import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss


def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
First compute the difference in the complex domain,
then get the absolute value and take the mse
Args:
x, y - B, 2, H, W real valued tensors representing complex numbers
or B,1,H,W complex valued tensors
Returns:
l2_loss - scalar
"""
if not x.is_complex():
x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous())
if not y.is_complex():
y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous())

diff = torch.abs(x - y)
return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean")


def sure_loss_function(
operator: Callable,
x: torch.Tensor,
y_pseudo_gt: torch.Tensor,
y_ref: Optional[torch.Tensor] = None,
eps: Optional[float] = -1.0,
perturb_noise: Optional[torch.Tensor] = None,
complex_input: Optional[bool] = False,
) -> torch.Tensor:
"""
Args:
operator (function): The operator function that takes in an input
tensor x and returns an output tensor y. We will use this to compute
the divergence. More specifically, we will perturb the input x by a
small amount and compute the divergence between the perturbed output
and the reference output
x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
operator. For complex input, the shape is (B, 2, H, W) aka C=2 real.
For real input, the shape is (B, 1, H, W) real.
y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
(B, C, H, W) used to compute the L2 loss. For complex input, the shape is
(B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W)
real.
y_ref (torch.Tensor, optional): The reference output tensor of shape
(B, C, H, W) used to compute the divergence. Defaults to None. For
complex input, the shape is (B, 2, H, W) aka C=2 real. For real input,
the shape is (B, 1, H, W) real.
eps (float, optional): The perturbation scalar. Set to -1 to set it
automatically estimated based on y_pseudo_gtk
perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W).
Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real.
For real input, the shape is (B, 1, H, W) real.
complex_input(bool, optional): Whether the input is complex or not.
Defaults to False.
Returns:
sure_loss (torch.Tensor): The SURE loss scalar.
"""
# perturb input
if perturb_noise is None:
perturb_noise = torch.randn_like(x)
if eps == -1.0:
eps = float(torch.abs(y_pseudo_gt.max())) / 1000
# get y_ref if not provided
if y_ref is None:
y_ref = operator(x)

# get perturbed output
x_perturbed = x + eps * perturb_noise
y_perturbed = operator(x_perturbed)
# divergence
divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref)) # type: ignore
# l2 loss between y_ref, y_pseudo_gt
if complex_input:
l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt)
else:
# real input
l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean")

# sure loss
sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3])
return sure_loss


class SURELoss(_Loss):
"""
Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator.
This is a differentiable loss function that can be used to train/guide an
operator (e.g. neural network), where the pseudo ground truth is available
but the reference ground truth is not. For example, in the MRI
reconstruction, the pseudo ground truth is the zero-filled reconstruction
and the reference ground truth is the fully sampled reconstruction. Often,
the reference ground truth is not available due to the lack of fully sampled
data.
The original SURE loss is proposed in [1]. The SURE loss used for guiding
the diffusion model based MRI reconstruction is proposed in [2].
Reference
[1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics
[2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models.
(https://arxiv.org/pdf/2310.01799.pdf)
"""

def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None:
"""
Args:
perturb_noise (torch.Tensor, optional): The noise vector of shape
(B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real.
For real input, the shape is (B, 1, H, W) real.
eps (float, optional): The perturbation scalar. Defaults to None.
"""
super().__init__()
self.perturb_noise = perturb_noise
self.eps = eps

def forward(
self,
operator: Callable,
x: torch.Tensor,
y_pseudo_gt: torch.Tensor,
y_ref: Optional[torch.Tensor] = None,
complex_input: Optional[bool] = False,
) -> torch.Tensor:
"""
Args:
operator (function): The operator function that takes in an input
tensor x and returns an output tensor y. We will use this to compute
the divergence. More specifically, we will perturb the input x by a
small amount and compute the divergence between the perturbed output
and the reference output
x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka
C=2 real. For real input, the shape is (B, 1, H, W) real.
y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
(B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex
input, the shape is (B, 2, H, W) aka C=2 real. For real input, the
shape is (B, 1, H, W) real.
y_ref (torch.Tensor, optional): The reference output tensor of the
same shape as y_pseudo_gt
Returns:
sure_loss (torch.Tensor): The SURE loss scalar.
"""

# check inputs shapes
if x.dim() != 4:
raise ValueError(f"Input tensor x should be 4D, got {x.dim()}.")
if y_pseudo_gt.dim() != 4:
raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.")
if y_ref is not None and y_ref.dim() != 4:
raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.")
if x.shape != y_pseudo_gt.shape:
raise ValueError(
f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, "
f"y_pseudo_gt shape {y_pseudo_gt.shape}."
)
if y_ref is not None and y_pseudo_gt.shape != y_ref.shape:
raise ValueError(
f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape {y_pseudo_gt.shape}, "
f"y_ref shape {y_ref.shape}."
)

# compute loss
loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input)

return loss
1 change: 1 addition & 0 deletions monai/networks/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

from .conjugate_gradient import ConjugateGradient
from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
from .drop_path import DropPath
from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
Expand Down
112 changes: 112 additions & 0 deletions monai/networks/layers/conjugate_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Callable

import torch
from torch import nn


def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
"""
Complex dot product between tensors x1 and x2: sum(x1.*x2)
"""
if torch.is_complex(x1):
assert torch.is_complex(x2), "x1 and x2 must both be complex"
return torch.sum(x1.conj() * x2)
else:
return torch.sum(x1 * x2)


def _zdot_single(x: torch.Tensor) -> torch.Tensor:
"""
Complex dot product between tensor x and itself
"""
res = _zdot(x, x)
if torch.is_complex(res):
return res.real
else:
return res


class ConjugateGradient(nn.Module):
"""
Congugate Gradient (CG) solver for linear systems Ax = y.
For linear_op that is positive definite and self-adjoint, CG is
guaranteed to converge CG is often used to solve linear systems of the form
Ax = y, where A is too large to store explicitly, but can be computed via a
linear operator.
As a result, here we won't set A explicitly as a matrix, but rather as a
linear operator. For example, A could be a FFT/IFFT operation
"""

def __init__(self, linear_op: Callable, num_iter: int):
"""
Args:
linear_op: Linear operator
num_iter: Number of iterations to run CG
"""
super().__init__()

self.linear_op = linear_op
self.num_iter = num_iter

def update(
self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
perform one iteration of the CG method. It takes the current solution x,
the current search direction p, the current residual r, and the old
residual norm rsold as inputs. Then it computes the new solution, search
direction, residual, and residual norm, and returns them.
"""

dy = self.linear_op(p)
p_dot_dy = _zdot(p, dy)
alpha = rsold / p_dot_dy
x = x + alpha * p
r = r - alpha * dy
rsnew = _zdot_single(r)
beta = rsnew / rsold
rsold = rsnew
p = beta * p + r
return x, p, r, rsold

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""
run conjugate gradient for num_iter iterations to solve Ax = y
Args:
x: tensor (real or complex); Initial guess for linear system Ax = y.
The size of x should be applicable to the linear operator. For
example, if the linear operator is FFT, then x is HCHW; if the
linear operator is a matrix multiplication, then x is a vector
y: tensor (real or complex); Measurement. Same size as x
Returns:
x: Solution to Ax = y
"""
# Compute residual
r = y - self.linear_op(x)
rsold = _zdot_single(r)
p = r

# Update
for _i in range(self.num_iter):
x, p, r, rsold = self.update(x, p, r, rsold)
if rsold < 1e-10:
break
return x
Loading

0 comments on commit 1b723c8

Please sign in to comment.