diff --git a/doc/stages/filters.neighborclassifier.rst b/doc/stages/filters.neighborclassifier.rst new file mode 100644 index 0000000000..490b32aa57 --- /dev/null +++ b/doc/stages/filters.neighborclassifier.rst @@ -0,0 +1,80 @@ +.. _filters.neighborclassifier: + +filters.neighborclassifier +========================== + +The neighborclassifier filter allows you to update the value of the classification +for specific points to a value determined by a K-nearest neighbors vote. +For each point, the k nearest neighbors are queried and if more than half of +them have the same value, the filter updates the selected point accordingly. + +For example, if an automated classification procedure put/left erroneous +vegetation points near the edges of buildings which were largely classified +correctly, you could try using this filter to fix that problem. + +Similarly, some automated classification processes result in predictions for +only a subset of the original point cloud. This filter could be used to +extrapolate those predictions to the original. + +.. embed:: + +Example 1 +--------- + +This pipeline updates the Classification of all points with classification +1 (unclassified) based on the consensus (majority) of their nearest 10 neighbors. + +.. code-block:: json + + { + "pipeline":[ + "autzen_class.las", + { + "type" : "filters.neighborclassifier", + "domain" : "Classification[1:1]", + "k" : 10 + }, + { + "filename":"autzen_class_refined.las" + } + ] + } + +Example 2 +--------- + +This pipeline moves all the classifications from "pred.txt" +to src.las. Any points in src.las that are not in pred.txt will be +assigned based on the closest point in pred.txt. + +.. code-block:: json + + { + "pipeline":[ + "src.las", + { + "type" : "filters.neighborclassifier", + "k" : 1, + "candidate" : "pred.txt" + }, + { + "filename":"dest.las" + } + ] + } + +Options +------- + +candidate + A filename which points to the point cloud containing the points which + will do the voting. 
If not specified, defaults to the input of the filter.
+
+domain
+  A :ref:`range <ranges>` which selects points to be processed by the filter.
+  Can be specified multiple times. Points satisfying any range will be
+  processed.
+
+k
+  An integer which specifies the number of neighbors which vote on each
+  selected point.
diff --git a/filters/NeighborClassifierFilter.cpp b/filters/NeighborClassifierFilter.cpp
new file mode 100644
index 0000000000..b2dbdace86
--- /dev/null
+++ b/filters/NeighborClassifierFilter.cpp
@@ -0,0 +1,226 @@
+/******************************************************************************
+* Copyright (c) 2017, Hobu Inc., info@hobu.co
+*
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following
+* conditions are met:
+*
+* * Redistributions of source code must retain the above copyright
+* notice, this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright
+* notice, this list of conditions and the following disclaimer in
+* the documentation and/or other materials provided
+* with the distribution.
+* * Neither the name of Hobu, Inc. or Flaxen Geo Consulting nor the
+* names of its contributors may be used to endorse or promote
+* products derived from this software without specific prior
+* written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+* COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
+* OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
+* AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
+* OF SUCH DAMAGE.
+****************************************************************************/
+
+#include "NeighborClassifierFilter.hpp"
+
+#include
+#include
+#include
+
+#include "private/DimRange.hpp"
+
+#include
+#include
+namespace pdal
+{
+
+static PluginInfo const s_info = PluginInfo(
+    "filters.neighborclassifier",
+    "Re-assign some point attributes based on KNN voting",
+    "http://pdal.io/stages/filters.neighborclassifier.html" );
+
+CREATE_STATIC_PLUGIN(1, 0, NeighborClassifierFilter, Filter, s_info)
+
+// The dimension that is voted on and rewritten is fixed to Classification.
+NeighborClassifierFilter::NeighborClassifierFilter() : m_dim(Dimension::Id::Classification)
+{}
+
+
+NeighborClassifierFilter::~NeighborClassifierFilter()
+{}
+
+
+// Registers the filter options: 'domain' (range specs selecting the points
+// to update), 'k' (positional neighbor count) and 'candidate' (optional
+// file supplying the voting points).
+void NeighborClassifierFilter::addArgs(ProgramArgs& args)
+{
+    args.add("domain", "Selects which points will be subject to KNN-based assignment",
+        m_domainSpec);
+    args.add("k", "Number of nearest neighbors to consult",
+        m_k).setPositional();
+    args.add("candidate", "candidate file name",
+        m_candidateFile);
+}
+
+// Parses every 'domain' string into a DimRange and validates 'k'.
+// Calls throwError() for an unparsable range or when k < 1.
+void NeighborClassifierFilter::initialize()
+{
+    for (auto const& r : m_domainSpec)
+    {
+        try
+        {
+            DimRange range;
+            range.parse(r);
+            m_domain.push_back(range);
+        }
+        catch (const DimRange::error& err)
+        {
+            throwError("Invalid 'domain' option: '" + r + "': " + err.what());
+        }
+    }
+    if (m_k < 1)
+        throwError("Invalid 'k' option: " + std::to_string(m_k) + ", must be > 0");
+}
+
+// Resolves the dimension name of each domain range against the table layout
+// (calling throwError() for unknown names) and sorts the ranges.
+void NeighborClassifierFilter::prepared(PointTableRef table)
+{
+    PointLayoutPtr layout(table.layout());
+
+    for (auto& r : m_domain)
+    {
+        r.m_id = layout->findDim(r.m_name);
+        if (r.m_id == Dimension::Id::Unknown)
+            throwError("Invalid dimension name in 'domain' option: '" +
+                r.m_name + "'.");
+    }
+    std::sort(m_domain.begin(), m_domain.end());
+}
+
+// Re-classifies 'point' by a vote among its m_k nearest neighbors in 'kdi'.
+//   point - point to update; its m_dim value is overwritten if the vote wins.
+//   temp  - scratch PointRef on the view indexed by 'kdi'; it is moved onto
+//           each neighbor in turn to read that neighbor's m_dim value.
+//   kdi   - index over the voting points.
+// The value changes only when a single value is held by more than half of
+// the returned neighbors and differs from the point's current value.
+void NeighborClassifierFilter::doOneNoDomain(PointRef &point, PointRef &temp, KD3Index &kdi)
+{
+    typedef std::map<double, unsigned int> VoteMap;
+
+    const std::vector<PointId> iSrc = kdi.neighbors(point, m_k);
+    // Without voters there is no winner; max_element() below would return
+    // end() for the empty map and must not be dereferenced.
+    if (iSrc.empty())
+        return;
+    const double thresh = iSrc.size() / 2.0;
+
+    // vote NNs
+    VoteMap counts;
+    for (PointId id : iSrc)
+    {
+        temp.setPointId(id);
+        counts[temp.getFieldAs<double>(m_dim)]++;
+    }
+
+    // pick winner of the vote
+    const auto winner = std::max_element(counts.begin(), counts.end(),
+        [](const VoteMap::value_type& p1, const VoteMap::value_type& p2)
+        { return p1.second < p2.second; });
+
+    // update point
+    const double newclass = winner->first;
+    if (winner->second > thresh && point.getFieldAs<double>(m_dim) != newclass)
+        point.setField(m_dim, newclass);
+}
+
+// Applies the neighbor vote to 'point' if the 'domain' option selects it:
+// every point when no domain was given, otherwise only points whose value
+// passes at least one range.  'temp' and 'kdi' both reference the point
+// cloud that supplies the voters.  Always returns true.
+bool NeighborClassifierFilter::doOne(PointRef& point, PointRef &temp, KD3Index &kdi)
+{
+    // No domain, process all points
+    bool selected = m_domain.empty();
+
+    // Otherwise process only points that satisfy a domain condition
+    for (DimRange& r : m_domain)
+    {
+        if (r.valuePasses(point.getFieldAs<double>(r.m_id)))
+        {
+            selected = true;
+            break;
+        }
+    }
+    if (selected)
+        doOneNoDomain(point, temp, kdi);
+    return true;
+}
+
+// Runs a reader for 'filename' (driver argument left empty) against 'table'
+// and returns the single view it produces.  Calls throwError() when the
+// reader does not yield exactly one view.
+PointViewPtr NeighborClassifierFilter::loadSet(const std::string& filename,
+    PointTable& table)
+{
+    PipelineManager mgr;
+
+    Stage& reader = mgr.makeReader(filename, "");
+    reader.prepare(table);
+    PointViewSet viewSet = reader.execute(table);
+    // Checked at run time: assert() is compiled out when NDEBUG is defined,
+    // which would leave *viewSet.begin() unchecked for an empty set.
+    if (viewSet.size() != 1)
+        throwError("Candidate file '" + filename +
+            "' must produce exactly one point view.");
+    return *viewSet.begin();
+}
+
+// Runs the vote for every point of 'view'.  Voters come from the candidate
+// file when 'candidate' is set, otherwise from 'view' itself.
+// NOTE(review): with no candidate file the points are rewritten in place
+// while the loop is still running, so later votes read neighbors that were
+// already updated and the result depends on point order.
+void NeighborClassifierFilter::filter(PointView& view)
+{
+    // Declared first so it is destroyed after the view and index using it.
+    PointTable candTable;
+    PointViewPtr candView;
+    PointView* voters = &view;
+    if (!m_candidateFile.empty())
+    {
+        candView = loadSet(m_candidateFile, candTable);
+        if (!candView->hasDim(m_dim))
+            throwError("Candidate file '" + m_candidateFile +
+                "' has no Classification dimension.");
+        voters = candView.get();
+    }
+
+    KD3Index kdi(*voters);
+    kdi.build();
+
+    PointRef point_src(view, 0);
+    PointRef point_nn(*voters, 0);
+    for (PointId id = 0; id < view.size(); ++id)
+    {
+        point_src.setPointId(id);
+        doOne(point_src, point_nn, kdi);
+    }
+}
+
+} // namespace pdal
+
diff --git a/filters/NeighborClassifierFilter.hpp b/filters/NeighborClassifierFilter.hpp
new file mode 100644
index 0000000000..4ef355c92c
--- /dev/null
+++ b/filters/NeighborClassifierFilter.hpp
@@ -0,0 +1,76 @@
+/******************************************************************************
+* Copyright (c) 2017, Hobu Inc.
+*
+* All rights reserved.
+* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following +* conditions are met: +* +* * Redistributions of source code must retain the above copyright +* notice, this list of conditions and the following disclaimer. +* * Redistributions in binary form must reproduce the above copyright +* notice, this list of conditions and the following disclaimer in +* the documentation and/or other materials provided +* with the distribution. +* * Neither the name of Hobu, Inc. or Flaxen Geo Consulting nor the +* names of its contributors may be used to endorse or promote +* products derived from this software without specific prior +* written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +* COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS +* OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED +* AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY +* OF SUCH DAMAGE. 
+****************************************************************************/
+
+#pragma once
+
+#include
+#include
+
+extern "C" int32_t NeighborClassifierFilter_ExitFunc();
+extern "C" PF_ExitFunc NeighborClassifierFilter_InitPlugin();
+
+namespace pdal
+{
+
+struct DimRange;
+// Filter that re-assigns a point's Classification by a vote of its k nearest neighbors.
+class PDAL_DLL NeighborClassifierFilter : public Filter
+{
+public:
+    NeighborClassifierFilter();
+    ~NeighborClassifierFilter();
+
+    static void * create();
+    static int32_t destroy(void *);
+    std::string getName() const { return "filters.neighborclassifier"; }
+
+private:
+    virtual void addArgs(ProgramArgs& args);    // registers 'domain', 'k' and 'candidate'
+    virtual void prepared(PointTableRef table); // resolves the dimension names of the domain ranges
+    bool doOne(PointRef& point, PointRef& temp, KD3Index &kdi); // votes if the domain selects 'point'
+    virtual void filter(PointView& view);       // processes every point of the view
+    virtual void initialize();                  // parses 'domain' and validates 'k'
+    void doOneNoDomain(PointRef &point, PointRef& temp, KD3Index &kdi); // unconditional vote for 'point'
+    PointViewPtr loadSet(const std::string &candFileName, PointTable &table); // reads the candidate file
+    NeighborClassifierFilter& operator=(const NeighborClassifierFilter&) = delete;
+    NeighborClassifierFilter(const NeighborClassifierFilter&) = delete;
+    StringList m_domainSpec;         // raw 'domain' option strings
+    std::vector<DimRange> m_domain;  // ranges parsed from m_domainSpec
+    int m_k;                         // 'k' option: number of neighbors that vote
+    Dimension::Id m_dim;             // dimension that is voted on and rewritten
+    std::string m_dimName;           // NOTE(review): appears unused - confirm before removing
+    std::string m_candidateFile;     // 'candidate' option: file supplying the voters
+};
+
+} // namespace pdal
diff --git a/pdal/StageFactory.cpp b/pdal/StageFactory.cpp
index 4a65e8bb85..effefe13fa 100644
--- a/pdal/StageFactory.cpp
+++ b/pdal/StageFactory.cpp
@@ -55,6 +55,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -291,6 +292,7 @@ StageFactory::StageFactory(bool no_plugins)
     PluginManager::initializePlugin(HeadFilter_InitPlugin);
     PluginManager::initializePlugin(IQRFilter_InitPlugin);
     PluginManager::initializePlugin(KDistanceFilter_InitPlugin);
+    PluginManager::initializePlugin(NeighborClassifierFilter_InitPlugin);
     PluginManager::initializePlugin(LocateFilter_InitPlugin);
     PluginManager::initializePlugin(LOFFilter_InitPlugin);
     PluginManager::initializePlugin(MADFilter_InitPlugin);
diff --git a/test/data/las/sample_c.las
b/test/data/las/sample_c.las
new file mode 100644
index 0000000000..e259b13e2a
Binary files /dev/null and b/test/data/las/sample_c.las differ
diff --git a/test/data/las/sample_c_thin.las b/test/data/las/sample_c_thin.las
new file mode 100644
index 0000000000..58befb7a99
Binary files /dev/null and b/test/data/las/sample_c_thin.las differ
diff --git a/test/data/las/sample_nc.las b/test/data/las/sample_nc.las
new file mode 100644
index 0000000000..85083a33b6
Binary files /dev/null and b/test/data/las/sample_nc.las differ
diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt
index 3610677794..8ad7e109d9 100644
--- a/test/unit/CMakeLists.txt
+++ b/test/unit/CMakeLists.txt
@@ -114,6 +114,7 @@ PDAL_ADD_TEST(pdal_filters_decimation_test FILES
 PDAL_ADD_TEST(pdal_filters_divider_test FILES filters/DividerFilterTest.cpp)
 PDAL_ADD_TEST(pdal_filters_ferry_test FILES filters/FerryFilterTest.cpp)
 PDAL_ADD_TEST(pdal_filters_groupby_test FILES filters/GroupByFilterTest.cpp)
+PDAL_ADD_TEST(pdal_filters_neighborclassifier_test FILES filters/NeighborClassifierFilterTest.cpp)
 PDAL_ADD_TEST(pdal_filters_locate_test FILES filters/LocateFilterTest.cpp)
 PDAL_ADD_TEST(pdal_filters_merge_test FILES filters/MergeTest.cpp)
 PDAL_ADD_TEST(pdal_morton_order_test FILES filters/MortonOrderTest.cpp)
diff --git a/test/unit/filters/NeighborClassifierFilterTest.cpp b/test/unit/filters/NeighborClassifierFilterTest.cpp
new file mode 100644
index 0000000000..ae49dd773d
--- /dev/null
+++ b/test/unit/filters/NeighborClassifierFilterTest.cpp
@@ -0,0 +1,174 @@
+/******************************************************************************
+* Copyright (c) 2015, Hobu Inc. (info@hobu.co)
+*
+* All rights reserved.
+*
+* Redistribution and use in source and binary forms, with or without
+* modification, are permitted provided that the following
+* conditions are met:
+*
+* * Redistributions of source code must retain the above copyright
+* notice, this list of conditions and the following disclaimer.
+* * Redistributions in binary form must reproduce the above copyright
+* notice, this list of conditions and the following disclaimer in
+* the documentation and/or other materials provided
+* with the distribution.
+* * Neither the name of Hobu, Inc. nor the
+* names of its contributors may be used to endorse or promote
+* products derived from this software without specific prior
+* written permission.
+*
+* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+* FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+* COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+* BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
+* OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED
+* AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
+* OF SUCH DAMAGE.
+****************************************************************************/
+
+#include
+
+#include
+#include
+#include "Support.hpp"
+
+#include
+
+using namespace pdal;
+
+// Runs a stats filter that enumerates Classification on top of stage 's' and
+// returns the enumeration (classification value -> number of points).  When
+// 'count' is non-null it receives the count reported by the Classification
+// stats summary.
+const stats::Summary::EnumMap GetClassifications(Stage &s, unsigned int *count = nullptr)
+{
+    StatsFilter stats;
+    stats.setInput(s);
+
+    Options statOpts;
+    statOpts.add("enumerate", "Classification");
+    stats.setOptions(statOpts);
+
+    PointTable table;
+    stats.prepare(table);
+    PointViewSet viewSet = stats.execute(table);
+    const stats::Summary& statsClassification = stats.getStats(Dimension::Id::Classification);
+    if (count)
+        *count = statsClassification.count();
+    return statsClassification.values();
+}
+
+// Shared body of the 'domain' tests.  Runs filters.neighborclassifier on
+// las/sample_c.las restricted to 'domain', once with k = 1 and once with
+// k = 3, and compares the class histogram before and after:
+//   k = 1 - every class keeps its original point count;
+//   k = 3 - each class listed in 'emptied' ends with no points and no other
+//           class loses points.
+static void checkDomainVote(const std::string& domain, const std::vector<int>& emptied)
+{
+    Options ro;
+    ro.add("filename", Support::datapath("las/sample_c.las"));
+
+    StageFactory factory;
+    Stage& r = *(factory.createStage("readers.las"));
+    r.setOptions(ro);
+    unsigned int count = 0;
+    stats::Summary::EnumMap OrigClassifications = GetClassifications(r, &count);
+
+    const std::vector<int> kvals = {1, 3};
+    for (int k : kvals)
+    {
+        Options fo;
+        fo.add("domain", domain);
+        fo.add("k", k);
+
+        Stage& f = *(factory.createStage("filters.neighborclassifier"));
+        f.setInput(r);
+        f.setOptions(fo);
+
+        PointTable table;
+        f.prepare(table);
+        PointViewSet viewSet = f.execute(table);
+        // Fatal check: the view set is dereferenced right below.
+        ASSERT_EQ(1u, viewSet.size());
+        PointViewPtr view = *viewSet.begin();
+        EXPECT_EQ(count, view->size());
+
+        stats::Summary::EnumMap NewClassifications = GetClassifications(f);
+
+        for (auto& p : OrigClassifications)
+        {
+            bool mustVanish = false;
+            for (int c : emptied)
+                mustVanish = mustVanish || (p.first == c);
+
+            if (k == 1)
+                EXPECT_EQ(p.second, NewClassifications[p.first]);
+            else if (mustVanish)
+                EXPECT_EQ(0u, NewClassifications[p.first]);
+            else
+                EXPECT_GE(NewClassifications[p.first], p.second);
+        }
+    }
+}
+
+// A single-class domain: only points classified 14 may be re-assigned.
+TEST(NeighborClassifierFilterTest, singleRange)
+{
+    checkDomainVote("Classification[14:14]", {14});
+}
+
+// Two ranges in one 'domain' value: points classified 14 or 11 may be re-assigned.
+TEST(NeighborClassifierFilterTest, multipleRange)
+{
+    checkDomainVote("Classification[14:14], Classification[11:11]", {14, 11});
+}
+
+// Voters come from a separate 'candidate' file (las/sample_c_thin.las) and
+// are applied with k = 1 to las/sample_nc.las; the class-6 count of the
+// result is checked next to that of las/sample_c.las.
+TEST(NeighborClassifierFilterTest, candidate)
+{
+    StageFactory factory;
+
+    Options rClassifications;
+    unsigned int count = 0;
+    rClassifications.add("filename", Support::datapath("las/sample_c.las"));
+    Stage& rC = *(factory.createStage("readers.las"));
+    rC.setOptions(rClassifications);
+    stats::Summary::EnumMap OrigClassifications = GetClassifications(rC, &count);
+
+    Options ro;
+    ro.add("filename", Support::datapath("las/sample_nc.las"));
+
+    Stage& r = *(factory.createStage("readers.las"));
+    r.setOptions(ro);
+
+    Options fo;
+    fo.add("candidate", Support::datapath("las/sample_c_thin.las"));
+    fo.add("k", 1);
+
+    Stage& f = *(factory.createStage("filters.neighborclassifier"));
+    f.setInput(r);
+    f.setOptions(fo);
+
+    PointTable table;
+    f.prepare(table);
+    PointViewSet viewSet = f.execute(table);
+    // Fatal check: the view set is dereferenced right below.
+    ASSERT_EQ(1u, viewSet.size());
+    PointViewPtr view = *viewSet.begin();
+    EXPECT_EQ(count, view->size());
+
+    stats::Summary::EnumMap NewClassifications = GetClassifications(f);
+
+    // Reported separately so a failure shows which of the two counts is off.
+    EXPECT_EQ(12525u, OrigClassifications[6]);
+    EXPECT_EQ(12441u, NewClassifications[6]);
+}