Skip to content

Commit

Permalink
ENH: Made weights of weighted least squares optional in conjugate gra…
Browse files Browse the repository at this point in the history
…dient
  • Loading branch information
Simon Rit authored and SimonRit committed Dec 4, 2023
1 parent deb0f0c commit f17a8ad
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 34 deletions.
24 changes: 7 additions & 17 deletions applications/rtkconjugategradient/rtkconjugategradient.cxx
Expand Up @@ -77,25 +77,15 @@ main(int argc, char * argv[])
inputFilter = constantImageSource;
}

// Read weights if given, otherwise default to weights all equal to one
itk::ImageSource<OutputImageType>::Pointer weightsSource;
// Read weights if given
OutputImageType::Pointer inputWeights;
if (args_info.weights_given)
{
using WeightsReaderType = itk::ImageFileReader<OutputImageType>;
WeightsReaderType::Pointer weightsReader = WeightsReaderType::New();
weightsReader->SetFileName(args_info.weights_arg);
weightsSource = weightsReader;
}
else
{
using ConstantWeightsSourceType = rtk::ConstantImageSource<OutputImageType>;
ConstantWeightsSourceType::Pointer constantWeightsSource = ConstantWeightsSourceType::New();

// Set the weights to be like the projections
TRY_AND_EXIT_ON_ITK_EXCEPTION(reader->UpdateOutputInformation())
constantWeightsSource->SetInformationFromImage(reader->GetOutput());
constantWeightsSource->SetConstant(1.0);
weightsSource = constantWeightsSource;
inputWeights = weightsReader->GetOutput();
inputWeights->Update();
}

// Read Support Mask if given
Expand All @@ -110,9 +100,9 @@ main(int argc, char * argv[])
ConjugateGradientFilterType::Pointer conjugategradient = ConjugateGradientFilterType::New();
SetForwardProjectionFromGgo(args_info, conjugategradient.GetPointer());
SetBackProjectionFromGgo(args_info, conjugategradient.GetPointer());
conjugategradient->SetInput(inputFilter->GetOutput());
conjugategradient->SetInput(1, reader->GetOutput());
conjugategradient->SetInput(2, weightsSource->GetOutput());
conjugategradient->SetInputVolume(inputFilter->GetOutput());
conjugategradient->SetInputProjectionStack(reader->GetOutput());
conjugategradient->SetInputWeights(inputWeights);
conjugategradient->SetCudaConjugateGradient(!args_info.nocudacg_flag);
if (args_info.mask_given)
{
Expand Down
5 changes: 5 additions & 0 deletions include/rtkConjugateGradientConeBeamReconstructionFilter.h
Expand Up @@ -166,9 +166,14 @@ class ITK_TEMPLATE_EXPORT ConjugateGradientConeBeamReconstructionFilter
std::is_same<TSingleComponentImage, TOutputImage>::value,
CudaConstantVolumeSource,
ConstantImageSource<TOutputImage>>::type ConstantImageSourceType;
typedef typename std::conditional<!std::is_same<TOutputImage, CPUOutputImageType>::value &&
std::is_same<TSingleComponentImage, TOutputImage>::value,
CudaConstantVolumeSource,
ConstantImageSource<TWeightsImage>>::type ConstantWeightSourceType;
#else
using DisplacedDetectorFilterType = DisplacedDetectorImageFilter<TWeightsImage>;
using ConstantImageSourceType = ConstantImageSource<TOutputImage>;
using ConstantWeightSourceType = ConstantImageSource<TWeightsImage>;
#endif

/** Set the support mask, if any, for support constraint in reconstruction */
Expand Down
34 changes: 22 additions & 12 deletions include/rtkConjugateGradientConeBeamReconstructionFilter.hxx
Expand Up @@ -20,6 +20,7 @@
#define rtkConjugateGradientConeBeamReconstructionFilter_hxx

#include <itkProgressAccumulator.h>
#include <itkPixelTraits.h>

namespace rtk
{
Expand All @@ -29,7 +30,7 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
ConjugateGradientConeBeamReconstructionFilter()
: m_IterationReporter(this, 0, 1) // report every iteration
{
this->SetNumberOfRequiredInputs(3);
this->SetNumberOfRequiredInputs(2);

// Set the default values of member parameters
m_NumberOfIterations = 3;
Expand Down Expand Up @@ -74,7 +75,7 @@ void
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::SetInputWeights(
const TWeightsImage * weights)
{
this->SetNthInput(2, const_cast<TWeightsImage *>(weights));
this->SetInput("InputWeights", const_cast<TWeightsImage *>(weights));
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
Expand Down Expand Up @@ -104,7 +105,7 @@ template <typename TOutputImage, typename TSingleComponentImage, typename TWeigh
typename TWeightsImage::ConstPointer
ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImage, TWeightsImage>::GetInputWeights()
{
return static_cast<const TWeightsImage *>(this->itk::ProcessObject::GetInput(2));
return static_cast<const TWeightsImage *>(this->itk::ProcessObject::GetInput("InputWeights"));
}

template <typename TOutputImage, typename TSingleComponentImage, typename TWeightsImage>
Expand Down Expand Up @@ -142,11 +143,14 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
return;
inputPtr1->SetRequestedRegion(inputPtr1->GetLargestPossibleRegion());

// Input 2 is the weights map on projections, either user-defined or filled with ones (default)
typename TWeightsImage::Pointer inputPtr2 = const_cast<TWeightsImage *>(this->GetInputWeights().GetPointer());
if (!inputPtr2)
return;
inputPtr2->SetRequestedRegion(inputPtr2->GetLargestPossibleRegion());
// Input "InputWeights" is the weights map on projections, either user-defined or filled with ones (default)
if (this->GetInputWeights().IsNotNull())
{
typename TWeightsImage::Pointer inputWeights = const_cast<TWeightsImage *>(this->GetInputWeights().GetPointer());
if (!inputWeights)
return;
inputWeights->SetRequestedRegion(inputWeights->GetLargestPossibleRegion());
}

// Input "SupportMask" is the support constraint mask on volume, if any
if (this->GetSupportMask().IsNotNull())
Expand Down Expand Up @@ -190,6 +194,16 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
m_CGOperator->SetSupportMask(this->GetSupportMask());
m_ConjugateGradientFilter->SetX(this->GetInputVolume());
m_DisplacedDetectorFilter->SetDisable(m_DisableDisplacedDetectorFilter);
if (this->GetInputWeights().IsNull())
{
using PixelType = typename TWeightsImage::PixelType;
using ComponentType = typename itk::PixelTraits<PixelType>::ValueType;
typename ConstantWeightSourceType::Pointer ones = ConstantWeightSourceType::New();
ones->SetInformationFromImage(this->GetInputProjectionStack());
ones->SetConstant(PixelType(itk::NumericTraits<ComponentType>::One));
ones->Update();
this->SetInputWeights(ones->GetOutput());
}
m_DisplacedDetectorFilter->SetInput(this->GetInputWeights());

// Links with the m_BackProjectionFilter should be set here and not
Expand Down Expand Up @@ -256,10 +270,6 @@ ConjugateGradientConeBeamReconstructionFilter<TOutputImage, TSingleComponentImag
if (this->GetSupportMask())
{
m_MultiplyOutputFilter->Update();
}

if (this->GetSupportMask())
{
this->GraftOutput(m_MultiplyOutputFilter->GetOutput());
}
else
Expand Down
6 changes: 3 additions & 3 deletions include/rtkReconstructionConjugateGradientOperator.hxx
Expand Up @@ -161,10 +161,10 @@ ReconstructionConjugateGradientOperator<TOutputImage, TSingleComponentImage, TWe
inputPtr1->SetRequestedRegion(inputPtr1->GetLargestPossibleRegion());

// Input 2 is the weights map on projections, if any
typename TWeightsImage::Pointer inputPtr2 = const_cast<TWeightsImage *>(this->GetInputWeights().GetPointer());
if (!inputPtr2)
typename TWeightsImage::Pointer inputWeights = const_cast<TWeightsImage *>(this->GetInputWeights().GetPointer());
if (!inputWeights)
return;
inputPtr2->SetRequestedRegion(inputPtr2->GetLargestPossibleRegion());
inputWeights->SetRequestedRegion(inputWeights->GetLargestPossibleRegion());

// Input "SupportMask" is the support constraint mask on volume, if any
if (this->GetSupportMask().IsNotNull())
Expand Down
2 changes: 1 addition & 1 deletion test/rtkconjugategradientreconstructiontest.cxx
Expand Up @@ -142,7 +142,7 @@ main(int, char **)
std::cout << "\n\n****** Case 1: Voxel-Based Backprojector ******" << std::endl;

conjugategradient->SetBackProjectionFilter(ConjugateGradientType::BP_VOXELBASED);
conjugategradient->SetInput(2, uniformWeightsSource->GetOutput());
conjugategradient->SetInputWeights(uniformWeightsSource->GetOutput());
TRY_AND_EXIT_ON_ITK_EXCEPTION(conjugategradient->Update());

CheckImageQuality<OutputImageType>(conjugategradient->GetOutput(), dsl->GetOutput(), 0.08, 23, 2.0);
Expand Down
2 changes: 1 addition & 1 deletion test/rtkcylindricaldetectorreconstructiontest.cxx
Expand Up @@ -136,7 +136,7 @@ main(int, char **)
ConjugateGradientType::Pointer conjugategradient = ConjugateGradientType::New();
conjugategradient->SetInput(tomographySource->GetOutput());
conjugategradient->SetInput(1, rei->GetOutput());
conjugategradient->SetInput(2, uniformWeightsSource->GetOutput());
conjugategradient->SetInputWeights(uniformWeightsSource->GetOutput());
conjugategradient->SetGeometry(geometry);
conjugategradient->SetNumberOfIterations(5);
conjugategradient->SetDisableDisplacedDetectorFilter(true);
Expand Down

0 comments on commit f17a8ad

Please sign in to comment.