Skip to content

Commit

Permalink
ENH: Add quadratic penalization to OSEM algorithm
Browse files Browse the repository at this point in the history
Implement the quadratic penalization describes in [De Pierro, IEEE TMI,
1995].
  • Loading branch information
arobert01 committed Apr 18, 2020
1 parent df4a9c8 commit b806c7e
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 30 deletions.
2 changes: 2 additions & 0 deletions applications/rtkosem/rtkosem.cxx
Expand Up @@ -102,6 +102,8 @@ main(int argc, char * argv[])
osem->SetSigmaZero(args_info.sigmazero_arg);
if (args_info.alphapsf_given)
osem->SetAlpha(args_info.alphapsf_arg);
if (args_info.betaregularization_given)
osem->SetBetaRegularization(args_info.betaregularization_arg);
osem->SetGeometry(geometryReader->GetOutputObject());

osem->SetNumberOfIterations(args_info.niterations_arg);
Expand Down
1 change: 1 addition & 0 deletions applications/rtkosem/rtkosem.ggo
Expand Up @@ -8,3 +8,4 @@ option "output" o "Output file name" s
option "niterations" n "Number of iterations" int no default="5"
option "input" i "Input volume" string no
option "nprojpersubset" - "Number of projections processed between each update of the reconstructed volume (several for OSEM, all for MLEM)" int no default="1"
option "betaregularization" - "Hyperparameter for the regularization" float no default="0.01"
155 changes: 155 additions & 0 deletions include/rtkDePierroRegularizationImageFilter.h
@@ -0,0 +1,155 @@
/*=========================================================================
*
* Copyright RTK Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/

#ifndef rtkDePierroRegularizationImageFilter_h
#define rtkDePierroRegularizationImageFilter_h

#include <itkMultiplyImageFilter.h>
#include <itkSubtractImageFilter.h>
#include <itkDivideImageFilter.h>
#include <itkImageKernelOperator.h>
#include <itkNeighborhoodOperatorImageFilter.h>
#include <itkConstantBoundaryCondition.h>
#include "rtkConstantImageSource.h"

namespace rtk
{

/** \class DePierroRegularizationImageFilter
* \brief Implements a regularization for MLEM/OSEM reconstruction.
*
* Perform the quadratic penalization describe in [De Pierro, IEEE TMI, 1995] for
* MLEM/OSEM reconstruction.
*
* This filter takes the k and k+1 updates of the classic MLEM/OSEM algorithm as
* inputs and return the regularization factor.
*
* \dot
* digraph DePierroRegularizationImageFilter {
*
* Input0 [ label="Input 0 (Update k of MLEM)"];
* Input0 [shape=Mdiamond];
* Input1 [label="Input 1 (Update k+1 of MLEM)"];
* Input1 [shape=Mdiamond];
* Input2 [label="Input 2 (Backprojection of one)"];
* Input2 [shape=Mdiamond];
* Output [label="Output (Regularization factor)"];
* Output [shape=Mdiamond];
* KernelImage [ label="KernelImage"];
* KernelImage [shape=Mdiamond];
*
* node [shape=box];
* ImageKernelOperator [ label="itk::ImageKernelOperator" URL="\ref itk::ImageKernelOperator"];
* NOIF[ label="itk::NeighborhoodOperatorImageFilter" URL="\ref itk::NeighborhoodOperatorImageFilter"];
* Subtract [ label="itk::SubtractImageFilter" URL="\ref itk::SubtractImageFilter"];
* Multiply1 [ label="itk::MultiplyImageFilter (by constant)" URL="\ref itk::MultiplyImageFilter"];
* Multiply2 [ label="itk::MultiplyImageFilter (by constant)" URL="\ref itk::MultiplyImageFilter"];
* CustomBinary [ label="itk::BinaryGeneratorImageFilter ((A + sqrt(A*A + B))/2)" URL="\ref
* itk::BinaryGeneratorImageFilter"]; Input2 -> Subtract; KernelImage -> ImageKernelOperator; ImageKernelOperator ->
* NOIF; Input0 -> NOIF; NOIF -> Multiply1; Multiply1 -> Subtract; Subtract -> CustomBinary; Input1 -> Multiply2;
* Multiply2 -> CustomBinary;
* CustomBinary -> Output;
* }
* \enddot
*
* \author Antoine Robert
*
* \ingroup RTK ReconstructionAlgorithm
*/
template <class TInputImage, class TOutputImage = TInputImage>
class ITK_EXPORT DePierroRegularizationImageFilter : public itk::ImageToImageFilter<TInputImage, TOutputImage>
{
public:
ITK_DISALLOW_COPY_AND_ASSIGN(DePierroRegularizationImageFilter);

/** Standard class type alias. */
using Self = DePierroRegularizationImageFilter;
using Superclass = itk::ImageToImageFilter<TOutputImage, TOutputImage>;
using Pointer = itk::SmartPointer<Self>;
using ConstPointer = itk::SmartPointer<const Self>;

/** Some convenient type alias. */
using InputImageType = TInputImage;
using InputImagePointerType = typename TInputImage::Pointer;
using OutputImageType = TOutputImage;
using InputPixelType = typename TInputImage::PixelType;

/** ImageDimension constants */
static constexpr unsigned int InputImageDimension = TInputImage::ImageDimension;

/** Typedefs of each subfilter of this composite filter */
using MultiplyImageFilterType = itk::MultiplyImageFilter<InputImageType, InputImageType>;
using MultpiplyImageFilterPointerType = typename MultiplyImageFilterType::Pointer;
using ConstantVolumeSourceType = rtk::ConstantImageSource<InputImageType>;
using ConstantVolumeSourcePointerType = typename ConstantVolumeSourceType::Pointer;
using SubtractImageFilterType = itk::SubtractImageFilter<InputImageType, InputImageType>;
using SubtractImageFilterPointerType = typename SubtractImageFilterType::Pointer;
using ImageKernelOperatorType = itk::ImageKernelOperator<InputPixelType, InputImageDimension>;
using NOIFType = itk::NeighborhoodOperatorImageFilter<InputImageType, InputImageType>;
using NOIFPointerType = typename NOIFType::Pointer;
using CustomBinaryFilterType = itk::BinaryGeneratorImageFilter<InputImageType, InputImageType, OutputImageType>;
using CustomBinaryFilterPointerType = typename CustomBinaryFilterType::Pointer;

/** Typedef for the boundary condition */
using BoundaryCondition = itk::ConstantBoundaryCondition<InputImageType>;

/** Method for creation through the object factory. */
itkNewMacro(Self);

/** Run-time type information (and related methods). */
itkTypeMacro(DePierroRegularizationFilter, itk::ImageToImageFilter);

/** Get / Set the hyper parameter for the regularization */
itkGetMacro(Beta, float);
itkSetMacro(Beta, float);

protected:
DePierroRegularizationImageFilter();
~DePierroRegularizationImageFilter() override = default;

void
GenerateInputRequestedRegion() override;

void
GenerateOutputInformation() ITK_OVERRIDE;

void
GenerateData() ITK_OVERRIDE;

MultpiplyImageFilterPointerType m_MultiplyConstant1ImageFilter;
MultpiplyImageFilterPointerType m_MultiplyConstant2ImageFilter;
ConstantVolumeSourcePointerType m_KernelImage;
ConstantVolumeSourcePointerType m_DefaultNormalizationVolume;
SubtractImageFilterPointerType m_SubtractImageFilter;
BoundaryCondition m_BoundsCondition;
ImageKernelOperatorType m_KernelOperator;
NOIFPointerType m_ConvolutionFilter;
CustomBinaryFilterPointerType m_CustomBinaryFilter;

private:
float m_Beta;

}; // end of class

} // end namespace rtk

#ifndef ITK_MANUAL_INSTANTIATION
# include "rtkDePierroRegularizationImageFilter.hxx"
#endif

#endif
140 changes: 140 additions & 0 deletions include/rtkDePierroRegularizationImageFilter.hxx
@@ -0,0 +1,140 @@
/*=========================================================================
*
* Copyright RTK Consortium
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0.txt
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*=========================================================================*/

#ifndef rtkDePierroRegularizationImageFilter_hxx
#define rtkDePierroRegularizationImageFilter_hxx

#include "rtkDePierroRegularizationImageFilter.h"

namespace rtk
{
template <class TInputImage, class TOutputImage>
DePierroRegularizationImageFilter<TInputImage, TOutputImage>::DePierroRegularizationImageFilter()
{
this->SetNumberOfRequiredInputs(2);
// Set default parameters
m_Beta = 0.01;

// Create each filter of the composite filter
m_MultiplyConstant1ImageFilter = MultiplyImageFilterType::New();
m_MultiplyConstant2ImageFilter = MultiplyImageFilterType::New();
m_KernelImage = ConstantVolumeSourceType::New();
m_DefaultNormalizationVolume = ConstantVolumeSourceType::New();
m_SubtractImageFilter = SubtractImageFilterType::New();
m_ConvolutionFilter = NOIFType::New();
m_CustomBinaryFilter = CustomBinaryFilterType::New();

// Set Lambda function
auto customLambda = [](const typename InputImageType::PixelType & input1,
const typename InputImageType::PixelType & input2) -> typename OutputImageType::PixelType
{
return static_cast<typename OutputImageType::PixelType>((input1 + std::sqrt(input1 * input1 + input2)) / 2);
};
m_CustomBinaryFilter->SetFunctor(customLambda);

// Permanent internal connections
m_SubtractImageFilter->SetInput2(m_MultiplyConstant1ImageFilter->GetOutput());
m_CustomBinaryFilter->SetInput1(m_SubtractImageFilter->GetOutput());
m_CustomBinaryFilter->SetInput2(m_MultiplyConstant2ImageFilter->GetOutput());

// Set the kernel image
typename ConstantVolumeSourceType::PointType origin;
typename ConstantVolumeSourceType::SizeType size;
typename ConstantVolumeSourceType::SpacingType spacing;
origin.Fill(-1);
size.Fill(3);
spacing.Fill(1);
m_KernelImage->SetOrigin(origin);
m_KernelImage->SetSpacing(spacing);
m_KernelImage->SetSize(size);
m_KernelImage->SetConstant(1.);
}

template <class TInputImage, class TOutputImage>
void
DePierroRegularizationImageFilter<TInputImage, TOutputImage>::GenerateInputRequestedRegion()
{
// Input 0 is the k uptade of classic MLEM/OSEM algorithm
typename TInputImage::Pointer inputPtr0 = const_cast<TInputImage *>(this->GetInput(0));
if (!inputPtr0)
return;
inputPtr0->SetRequestedRegion(this->GetOutput()->GetRequestedRegion());

// Input 0 is the k+1 uptade of classic MLEM/OSEM algorithm
typename TInputImage::Pointer inputPtr1 = const_cast<TInputImage *>(this->GetInput(1));
if (!inputPtr1)
return;
inputPtr1->SetRequestedRegion(inputPtr1->GetLargestPossibleRegion());

// Input 3 is the normalization volume (optional)
typename TInputImage::Pointer inputPtr3 = const_cast<TInputImage *>(this->GetInput(2));
if (inputPtr3)
inputPtr3->SetRequestedRegion(inputPtr0->GetRequestedRegion());
}

template <class TInputImage, class TOutputImage>
void
DePierroRegularizationImageFilter<TInputImage, TOutputImage>::GenerateOutputInformation()
{
m_MultiplyConstant1ImageFilter->SetInput1(this->GetInput(0));
m_MultiplyConstant2ImageFilter->SetInput1(this->GetInput(1));
if (this->GetInput(2) != nullptr)
{
m_SubtractImageFilter->SetInput1(this->GetInput(2));
}
else
{
m_DefaultNormalizationVolume->SetInformationFromImage(const_cast<TInputImage *>(this->GetInput(0)));
m_DefaultNormalizationVolume->SetConstant(1);
m_SubtractImageFilter->SetInput1(m_DefaultNormalizationVolume->GetOutput());
}
m_MultiplyConstant1ImageFilter->SetConstant2(m_Beta);
m_MultiplyConstant2ImageFilter->SetConstant2(4 * 2 * m_Beta * (pow(3, InputImageDimension) - 1));

m_CustomBinaryFilter->UpdateOutputInformation();
this->GetOutput()->SetOrigin(m_CustomBinaryFilter->GetOutput()->GetOrigin());
this->GetOutput()->SetSpacing(m_CustomBinaryFilter->GetOutput()->GetSpacing());
this->GetOutput()->SetDirection(m_CustomBinaryFilter->GetOutput()->GetDirection());
this->GetOutput()->SetLargestPossibleRegion(m_CustomBinaryFilter->GetOutput()->GetLargestPossibleRegion());
}

template <class TInputImage, class TOutputImage>
void
DePierroRegularizationImageFilter<TInputImage, TOutputImage>::GenerateData()
{
m_KernelImage->Update();
typename TInputImage::IndexType pixelIndex;
pixelIndex.Fill(1);
m_KernelImage->GetOutput()->SetPixel(pixelIndex, pow(3, InputImageDimension) - 1);
m_KernelOperator.SetImageKernel(m_KernelImage->GetOutput());
// The radius of the kernel must be the radius of the patch, NOT the size of the patch
itk::Size<InputImageDimension> radius;
radius.Fill(1);
m_KernelOperator.CreateToRadius(radius);
m_ConvolutionFilter->OverrideBoundaryCondition(&m_BoundsCondition);
m_ConvolutionFilter->SetOperator(m_KernelOperator);
m_ConvolutionFilter->SetInput(this->GetInput(0));
m_MultiplyConstant1ImageFilter->SetInput1(m_ConvolutionFilter->GetOutput());

m_CustomBinaryFilter->Update();
this->GraftOutput(m_CustomBinaryFilter->GetOutput());
}

} // end namespace rtk

#endif // rtkDePierroRegularizationImageFilter_hxx

0 comments on commit b806c7e

Please sign in to comment.