Skip to content

Commit

Permalink
Add normal refinement via minimum spanning tree propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
chambbj committed Feb 25, 2020
1 parent 6ec91b8 commit dcabe0f
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 43 deletions.
226 changes: 184 additions & 42 deletions filters/NormalFilter.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/******************************************************************************
* Copyright (c) 2016-2017, Bradley J Chambers (brad.chambers@gmail.com)
* Copyright (c) 2016, 2017, 2019, 2020 Bradley J Chambers (brad.chambers@gmail.com)
*
* All rights reserved.
*
Expand Down Expand Up @@ -47,29 +47,46 @@
namespace pdal
{

static StaticPluginInfo const s_info
struct Edge
{
"filters.normal",
"Normal Filter",
"http://pdal.io/stages/filters.normal.html"
PointId m_v0;
PointId m_v1;
float m_weight;

Edge(PointId i, PointId j, float weight)
: m_v0(i), m_v1(j), m_weight(weight)
{
}
};

struct CompareEdgeWeight
{
bool operator()(Edge const& lhs, Edge const& rhs)
{
return lhs.m_weight > rhs.m_weight;
}
};

using namespace Eigen;
using namespace Dimension;

static StaticPluginInfo const s_info{
"filters.normal", "Normal Filter",
"http://pdal.io/stages/filters.normal.html"};

CREATE_STATIC_STAGE(NormalFilter, s_info)

struct NormalArgs
{
int m_knn;
filter::Point m_viewpoint;
bool m_up;
bool m_refine;
};

NormalFilter::NormalFilter() : m_args(new NormalArgs)
{}


NormalFilter::~NormalFilter()
{}
NormalFilter::NormalFilter() : m_args(new NormalArgs) {}

NormalFilter::~NormalFilter() {}

std::string NormalFilter::getName() const
{
Expand All @@ -79,16 +96,17 @@ std::string NormalFilter::getName() const
void NormalFilter::addArgs(ProgramArgs& args)
{
args.add("knn", "k-Nearest Neighbors", m_args->m_knn, 8);
m_viewpointArg = &args.add("viewpoint",
"Viewpoint as WKT or GeoJSON", m_args->m_viewpoint);
m_viewpointArg = &args.add("viewpoint", "Viewpoint as WKT or GeoJSON",
m_args->m_viewpoint);
args.add("always_up", "Normals always oriented with positive Z?",
m_args->m_up, true);
m_args->m_up, true);
args.add("refine",
"Refine normals using minimum spanning tree propagation?",
m_args->m_refine, true);
}

void NormalFilter::addDimensions(PointLayoutPtr layout)
{
using namespace Dimension;

layout->registerDims(
{Id::NormalX, Id::NormalY, Id::NormalZ, Id::Curvature});
}
Expand All @@ -112,53 +130,177 @@ void NormalFilter::prepared(PointTableRef table)
<< "Viewpoint provided. Ignoring always_up = TRUE." << std::endl;
m_args->m_up = false;
}

// The query point is returned as a neighbor of itself, so we must increase
// k by one to get the desired number of neighbors.
++m_args->m_knn;
}

void NormalFilter::filter(PointView& view)
void NormalFilter::compute(PointView& view, KD3Index& kdi)
{
KD3Index& kdi = view.build3dIndex();

for (PointId i = 0; i < view.size(); ++i)
log()->get(LogLevel::Debug) << "Computing normal vectors\n";
for (PointId idx = 0; idx < view.size(); ++idx)
{
// find the k-nearest neighbors
auto ids = kdi.neighbors(i, m_args->m_knn);
PointRef p = view.point(idx);

// compute covariance of the neighborhood
auto B = computeCovariance(view, ids);

// perform the eigen decomposition
Eigen::SelfAdjointEigenSolver<Eigen::Matrix3d> solver(B);
if (solver.info() != Eigen::Success)
// Perform eigen decomposition of covariance matrix computed from
// neighborhood composed of k-nearest neighbors.
PointIdList neighbors = kdi.neighbors(idx, m_args->m_knn);
auto B = computeCovariance(view, neighbors);
SelfAdjointEigenSolver<Matrix3d> solver(B);
if (solver.info() != Success)
throwError("Cannot perform eigen decomposition.");

// The curvature is computed as the ratio of the first (smallest)
// eigenvalue to the sum of all eigenvalues.
auto eval = solver.eigenvalues();
Eigen::Vector3d normal = solver.eigenvectors().col(0);
double sum = eval[0] + eval[1] + eval[2];
double curvature = sum ? std::fabs(eval[0] / sum) : 0;

// The normal is defined by the eigenvector corresponding to the
// smallest eigenvalue.
Vector3d normal = solver.eigenvectors().col(0);

if (m_viewpointArg->set())
{
using namespace Dimension;

PointRef p = view.point(i);
Eigen::Vector3d vp(
(float)(m_args->m_viewpoint.x() - p.getFieldAs<double>(Id::X)),
(float)(m_args->m_viewpoint.y() - p.getFieldAs<double>(Id::Y)),
(float)(m_args->m_viewpoint.z() - p.getFieldAs<double>(Id::Z)));
// If a viewpoint has been specified, orient the normals to face the
// viewpoint by taking the dot product of the vector connecting the
// point with the viewpoint and the normal. Flip the normal, where
// the dot product is negative.
float dx = static_cast<float>(m_args->m_viewpoint.x() -
p.getFieldAs<double>(Id::X));
float dy = static_cast<float>(m_args->m_viewpoint.y() -
p.getFieldAs<double>(Id::Y));
float dz = static_cast<float>(m_args->m_viewpoint.z() -
p.getFieldAs<double>(Id::Z));
Vector3d vp(dx, dy, dz);
if (vp.dot(normal) < 0)
normal *= -1.0;
}
else if (m_args->m_up)
{
// If normals are expected to be upward facing, invert them when the
// Z component is negative.
if (normal[2] < 0)
normal *= -1.0;
}

view.setField(Dimension::Id::NormalX, i, normal[0]);
view.setField(Dimension::Id::NormalY, i, normal[1]);
view.setField(Dimension::Id::NormalZ, i, normal[2]);
// Set the computed normal and curvature dimensions.
p.setField(Id::NormalX, normal[0]);
p.setField(Id::NormalY, normal[1]);
p.setField(Id::NormalZ, normal[2]);
p.setField(Id::Curvature, curvature);
}
}

double sum = eval[0] + eval[1] + eval[2];
view.setField(Dimension::Id::Curvature, i,
sum ? std::fabs(eval[0] / sum) : 0);
void NormalFilter::refine(PointView& view, KD3Index& kdi)
{
log()->get(LogLevel::Debug)
<< "Refining normals using minimum spanning tree\n";

typedef std::vector<Edge> EdgeList;
std::priority_queue<Edge, EdgeList, CompareEdgeWeight> edge_queue;
std::vector<bool> inMST(view.size(), false);
PointId nextIdx(0);
point_count_t count(0);
while (count < view.size())
{
// Find the PointId of the next point not currently part of the minimum
// spanning tree.
while (inMST[nextIdx])
++nextIdx;

// Lambda to add the current PointId to the minimum spanning tree and
// update the edge queue.
auto update = [&](PointId updateIdx) {
// Add the current PointId to the minimum spanning tree.
inMST[updateIdx] = true;
++count;

// Consider neighbors of the newly added PointId, adding them to
// the edge queue if they are not already part of the minimum
// spanning tree.
PointIdList neighbors = kdi.neighbors(updateIdx, m_args->m_knn);
PointRef p = view.point(updateIdx);
Vector3d N1(p.getFieldAs<double>(Id::NormalX),
p.getFieldAs<double>(Id::NormalY),
p.getFieldAs<double>(Id::NormalZ));
for (PointId const& neighborIdx : neighbors)
{
if (updateIdx != neighborIdx && !inMST[neighborIdx])
{
PointRef q = view.point(neighborIdx);
Vector3d N2(q.getFieldAs<double>(Id::NormalX),
q.getFieldAs<double>(Id::NormalY),
q.getFieldAs<double>(Id::NormalZ));
float weight = std::max(
0.0, 1.0 - static_cast<float>(std::fabs(N1.dot(N2))));
edge_queue.emplace(updateIdx, neighborIdx, weight);
}
}
};

update(nextIdx);

// Iterate on the edge queue until empty (or all points have been added
// to the minimum spanning tree).
while (!edge_queue.empty() && (count < view.size()))
{
// Retrieve the edge with the smallest weight.
Edge edge(edge_queue.top());
edge_queue.pop();

// Record the PointId and normal of the PointId (if one exists)
// that is not already in the minimum spanning tree.
PointId newIdx(0);
Vector3d normal;
PointRef p = view.point(edge.m_v0);
Vector3d N1(p.getFieldAs<double>(Id::NormalX),
p.getFieldAs<double>(Id::NormalY),
p.getFieldAs<double>(Id::NormalZ));
PointRef q = view.point(edge.m_v1);
Vector3d N2(q.getFieldAs<double>(Id::NormalX),
q.getFieldAs<double>(Id::NormalY),
q.getFieldAs<double>(Id::NormalZ));
if (!inMST[edge.m_v0])
{
newIdx = edge.m_v0;
normal = N1;
}
else if (!inMST[edge.m_v1])
{
newIdx = edge.m_v1;
normal = N2;
}
else
continue;

// Where the dot product of the normals is less than 0, invert the
// normal of the selected PointId.
if (N1.dot(N2) < 0)
{
normal *= -1;
view.setField(Id::NormalX, newIdx, normal(0));
view.setField(Id::NormalY, newIdx, normal(1));
view.setField(Id::NormalZ, newIdx, normal(2));
}

update(newIdx);
}
}
}

void NormalFilter::filter(PointView& view)
{
KD3Index& kdi = view.build3dIndex();

// Compute the normal/curvature and optionally orient toward viewpoint or
// positive Z.
compute(view, kdi);

// If requested, refine normals through minimum spanning tree propagation.
if (m_args->m_refine)
refine(view, kdi);
}

} // namespace pdal
5 changes: 4 additions & 1 deletion filters/NormalFilter.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/******************************************************************************
* Copyright (c) 2016-2017, Bradley J Chambers (brad.chambers@gmail.com)
* Copyright (c) 2016, 2017, 2020 Bradley J Chambers (brad.chambers@gmail.com)
*
* All rights reserved.
*
Expand Down Expand Up @@ -65,6 +65,9 @@ class PDAL_DLL NormalFilter : public Filter
std::unique_ptr<NormalArgs> m_args;
Arg* m_viewpointArg;

void compute(PointView& view, KD3Index& kdi);
void refine(PointView& view, KD3Index& kdi);

virtual void addArgs(ProgramArgs& args);
virtual void addDimensions(PointLayoutPtr layout);
virtual void prepared(PointTableRef table);
Expand Down

0 comments on commit dcabe0f

Please sign in to comment.