Skip to content

Commit

Permalink
Merge pull request #1406 from MRtrix3/amp2response_speedup
Browse files Browse the repository at this point in the history
amp2response speedup
  • Loading branch information
Lestropie committed Jul 31, 2018
2 parents 6e35388 + cbc888b commit 102618f
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions cmd/amp2response.cpp
Expand Up @@ -194,11 +194,21 @@ void run ()
auto image = header.get_image<float>();
auto mask = Image<bool>::open (argument[1]);
check_dimensions (image, mask, 0, 3);
if (!(mask.ndim() == 3 || (mask.ndim() == 4 && mask.size(3) == 1)))
throw Exception ("input mask must be a 3D image");
auto dir_image = Image<float>::open (argument[2]);
if (dir_image.ndim() < 4 || dir_image.size(3) < 3)
throw Exception ("input direction image \"" + std::string (argument[2]) + "\" does not have expected dimensions");
check_dimensions (image, dir_image, 0, 3);

size_t num_voxels = 0;
for (auto l = Loop (mask, 0, 3) (mask); l; ++l) {
if (mask.value())
++num_voxels;
}
if (!num_voxels)
throw Exception ("input mask does not contain any voxels");

Eigen::MatrixXd responses (dirs_azel.size(), Math::ZSH::NforL (max_lmax));

for (size_t shell_index = 0; shell_index != dirs_azel.size(); ++shell_index) {
Expand All @@ -208,17 +218,17 @@ void run ()
Eigen::MatrixXd dirs_cartesian = Math::Sphere::spherical2cartesian (dirs_azel[shell_index]);

// All directions from all SF voxels get concatenated into a single large matrix
Eigen::MatrixXd cat_transforms;
Eigen::VectorXd cat_data;
Eigen::MatrixXd cat_transforms (num_voxels * dirs_azel[shell_index].rows(), Math::ZSH::NforL (lmax[shell_index]));
Eigen::VectorXd cat_data (num_voxels * dirs_azel[shell_index].rows());

#ifdef AMP2RESPONSE_DEBUG
// To make sure we've got our data rotated correctly, let's generate a scatterplot of
// elevation vs. amplitude
Eigen::MatrixXd scatter;
#endif

size_t sf_counter = 0;
for (auto l = Loop (mask) (image, mask, dir_image); l; ++l) {
size_t voxel_counter = 0;
for (auto l = Loop (mask, 0, 3) (image, mask, dir_image); l; ++l) {
if (mask.value()) {

// Grab the image data
Expand Down Expand Up @@ -291,19 +301,16 @@ void run ()
Eigen::MatrixXd transform = Math::ZSH::init_amp_transform<default_type> (rotated_dirs_azel.col(1), lmax[shell_index]);

// Concatenate these data to the ICLS matrices
const size_t old_rows = cat_transforms.rows();
cat_transforms.conservativeResize (old_rows + transform.rows(), transform.cols());
cat_transforms.block (old_rows, 0, transform.rows(), transform.cols()) = transform;
cat_data.conservativeResize (old_rows + data.size());
cat_data.tail (data.size()) = data;
cat_transforms.block (voxel_counter * data.size(), 0, transform.rows(), transform.cols()) = transform;
cat_data.segment (voxel_counter * data.size(), data.size()) = data;

#ifdef AMP2RESPONSE_DEBUG
scatter.conservativeResize (cat_data.size(), 2);
scatter.block (old_rows, 0, data.size(), 1) = rotated_dirs_azel.col(1);
scatter.block (old_rows, 1, data.size(), 1) = data;
#endif

++sf_counter;
++voxel_counter;

}
}
Expand All @@ -323,13 +330,12 @@ void run ()
Eigen::HouseholderQR<Eigen::MatrixXd> solver (cat_transforms);
rf = solver.solve (cat_data);

CONSOLE (shell_desc + "Response function [" + str(rf.transpose().cast<float>()) + "] solved via ordinary least-squares from " + str(sf_counter) + " voxels");
CONSOLE (shell_desc + "Response function [" + str(rf.transpose().cast<float>()) + "] solved via ordinary least-squares from " + str(voxel_counter) + " voxels");

} else {

// Generate the constraint matrix
// We are going to both constrain the amplitudes to be non-negative, and constrain the derivatives to be non-negative
Eigen::MatrixXd constraints;
const size_t num_angles_constraint = 90;
Eigen::VectorXd els;
els.resize (num_angles_constraint+1);
Expand All @@ -338,7 +344,7 @@ void run ()
Eigen::MatrixXd amp_transform = Math::ZSH::init_amp_transform <default_type> (els, lmax[shell_index]);
Eigen::MatrixXd deriv_transform = Math::ZSH::init_deriv_transform<default_type> (els, lmax[shell_index]);

constraints.resize (amp_transform.rows() + deriv_transform.rows(), amp_transform.cols());
Eigen::MatrixXd constraints (amp_transform.rows() + deriv_transform.rows(), amp_transform.cols());
constraints.block (0, 0, amp_transform.rows(), amp_transform.cols()) = amp_transform;
constraints.block (amp_transform.rows(), 0, deriv_transform.rows(), deriv_transform.cols()) = deriv_transform;

Expand All @@ -349,7 +355,7 @@ void run ()
// Estimate the solution
const size_t niter = solver (rf, cat_data);

CONSOLE (shell_desc + "Response function [" + str(rf.transpose().cast<float>()) + " ] solved after " + str(niter) + " constraint iterations from " + str(sf_counter) + " voxels");
CONSOLE (shell_desc + "Response function [" + str(rf.transpose().cast<float>()) + " ] solved after " + str(niter) + " constraint iterations from " + str(voxel_counter) + " voxels");

}

Expand All @@ -359,7 +365,7 @@ void run ()
rf.resize(1);
rf[0] = cat_data.mean() * std::sqrt(4*Math::pi);

CONSOLE (shell_desc + "Response function [ " + str(float(rf[0])) + " ] from average of " + str(sf_counter) + " voxels");
CONSOLE (shell_desc + "Response function [ " + str(float(rf[0])) + " ] from average of " + str(voxel_counter) + " voxels");

}

Expand Down

0 comments on commit 102618f

Please sign in to comment.