Skip to content

Commit

Permalink
CalcNormConst refactoring. Todo auto-detect if sumProduct needs to be…
Browse files Browse the repository at this point in the history
… re-run.
  • Loading branch information
TobiasMadsen committed May 13, 2015
1 parent 27d3dd2 commit d636c40
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 23 deletions.
22 changes: 5 additions & 17 deletions src/DiscreteFactorGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,19 +226,6 @@ namespace phy {
}
}


number_t DFG::calcNormConst2(stateMaskVec_t const & stateMasks, vector<vector<message_t const *> > & inMessages) const
{
return calcNormConst(0, stateMasks[0], inMessages[0]);
}


number_t DFG::calcNormConst2(stateMaskVec_t const & stateMasks) const
{
return calcNormConst(0, stateMasks[0], inMessages_[0]);
}


void DFG::runSumProduct(stateMaskVec_t const & stateMasks)
{
if (inMessages_.size() == 0)
Expand Down Expand Up @@ -465,14 +452,15 @@ namespace phy {
for(int i = 0; i < roots.size(); ++i){
unsigned const root = roots.at(i);
runSumProductInwardsRec(root, root, stateMasks, inMessages, outMessages);
res *= calcNormConst(root, stateMasks[root], inMessages[root]);
res *= calcNormConstComponent(root, stateMasks[root], inMessages[root]);
}
return res;
}


number_t DFG::calcNormConst(unsigned varId, stateMask_t const * stateMask, vector<message_t const *> const & inMes) const
number_t DFG::calcNormConstComponent(unsigned varId, stateMask_t const * stateMask, vector<message_t const *> const & inMes) const
{
// Calculates Normalizing Constant for the component that contains varId
assert( varId < variables.size() );
unsigned dim = nodes[ variables[varId] ].dimension;
message_t v(vector_t(dim), 0);
Expand Down Expand Up @@ -939,7 +927,7 @@ namespace phy {
stateMask_t const * stateMask = stateMasks[ convNodeToVar(root) ];

// calculate the likelihood of current component
number_t lik_com = calcNormConst(root, stateMask, inMu_[root]);
number_t lik_com = calcNormConstComponent(root, stateMask, inMu_[root]);
res_lik *= lik_com;

if(!nodes[root].isFactor){
Expand Down Expand Up @@ -1223,7 +1211,7 @@ namespace phy {
dfg.calcFactorMarginals(tmpFacMar);
for (unsigned j = 0; j < dfg.factors.size(); j++)
accFacMar[j] += tmpFacMar[j];
normConst = dfg.calcNormConst2(stateMaskVec);
normConst = dfg.calcNormConst(stateMaskVec);
}


Expand Down
8 changes: 2 additions & 6 deletions src/DiscreteFactorGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ using namespace std;
void runSumProduct(stateMaskVec_t const & stateMasks);
void runSumProduct(stateMaskVec_t const & stateMasks, vector<vector<message_t const *> > & inMessages, vector<vector<message_t> > & outMessages) const;

/** Precondition: runSumProduct has been called (setting all out/inMessages). Calc the normalization constant (Z). */
number_t calcNormConst2(stateMaskVec_t const & stateMasks) const;
number_t calcNormConst2(stateMaskVec_t const & stateMasks, vector<vector<message_t const *> > & inMessages) const;

/** Precondition: runSumProcut has been called (setting all out/inMessages). If variableMarginals is an empty vector ::initVariableMarginals will be called*/
void calcVariableMarginals(stateMaskVec_t const & stateMasks);
void calcVariableMarginals(vector<vector_t> & variableMarginals, stateMaskVec_t const & stateMasks);
Expand Down Expand Up @@ -188,7 +184,7 @@ using namespace std;
void calcSumProductMessageVariable(unsigned current, unsigned receiver, stateMaskVec_t const & stateMasks, vector<vector<message_t const *> > & inMessages, vector<vector<message_t> > & outMessages) const;
void calcSumProductMessageVariable(unsigned current, unsigned receiver, stateMask_t const * stateMask, vector<message_t const *> const & inMes, message_t & outMes) const;
void calcSumProductMessage(unsigned current, unsigned receiver, stateMaskVec_t const & stateMasks, vector<vector<message_t const *> > & inMessages, vector<vector<message_t> > & outMessages) const;
number_t calcNormConst(unsigned varId, stateMask_t const * stateMask, vector<message_t const *> const & inMes) const;
number_t calcNormConstComponent(unsigned varId, stateMask_t const * stateMask, vector<message_t const *> const & inMes) const;

// helper functions for maxSum();
unsigned maxNeighborDimension(vector<unsigned> const & nbs) const;
Expand Down Expand Up @@ -290,7 +286,7 @@ using namespace std;

/** helper functions. Interface provided above*/
template <class T>
void initGenericVariableMarginals(vector<T> & variableMarginals, DFG const & dfg)
void initoGenericVariableMarginals(vector<T> & variableMarginals, DFG const & dfg)
{
variableMarginals.clear();
variableMarginals.resize( dfg.variables.size() );
Expand Down
7 changes: 7 additions & 0 deletions tests/boost/test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#define BOOST_TEST_DYN_LINK

#define BOOST_TEST_MODULE UnitTests

// boost test
#include <boost/test/unit_test.hpp>
#include <boost/test/floating_point_comparison.hpp>

0 comments on commit d636c40

Please sign in to comment.