Skip to content

Commit

Permalink
Prefix OnlineSupervisedMStep with Fast
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath committed Sep 9, 2016
1 parent 034d42f commit b116ca1
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 49 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Expand Up @@ -42,11 +42,11 @@ set(SOURCES
src/ldaplusplus/em/AbstractEStep.cpp
src/ldaplusplus/em/CorrespondenceSupervisedEStep.cpp
src/ldaplusplus/em/CorrespondenceSupervisedMStep.cpp
src/ldaplusplus/em/FastOnlineSupervisedMStep.cpp
src/ldaplusplus/em/FastSupervisedEStep.cpp
src/ldaplusplus/em/FastSupervisedMStep.cpp
src/ldaplusplus/em/MultinomialSupervisedEStep.cpp
src/ldaplusplus/em/MultinomialSupervisedMStep.cpp
src/ldaplusplus/em/OnlineSupervisedMStep.cpp
src/ldaplusplus/em/SemiSupervisedEStep.cpp
src/ldaplusplus/em/SemiSupervisedMStep.cpp
src/ldaplusplus/em/SupervisedEStep.cpp
Expand Down
32 changes: 16 additions & 16 deletions include/ldaplusplus/LDABuilder.hpp
Expand Up @@ -432,10 +432,10 @@ class LDABuilder : public LDABuilderInterface<Scalar>
}

/**
* Create an OnlineSupervisedMStep without specifying class weights.
* Create an FastOnlineSupervisedMStep without specifying class weights.
*
* You can also see a description of the parameters at
* OnlineSupervisedMStep::OnlineSupervisedMStep.
* FastOnlineSupervisedMStep::FastOnlineSupervisedMStep.
*
* @param num_classes The number of classes
* @param regularization_penalty The L2 penalty for the logistic
Expand All @@ -447,25 +447,25 @@ class LDABuilder : public LDABuilderInterface<Scalar>
* @param eta_learning_rate The learning rate for the SGD
* update of \f$\eta\f$
* @param beta_weight The weight for the online update
* of \f$\beta\f$
* of \f$\beta\f$
*/
std::shared_ptr<em::MStepInterface<Scalar> > get_supervised_online_m_step(
std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
size_t num_classes,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Scalar eta_momentum = 0.9,
Scalar eta_learning_rate = 0.01,
Scalar beta_weight = 0.9
);
LDABuilder & set_supervised_online_m_step(
LDABuilder & set_fast_supervised_online_m_step(
size_t num_classes,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Scalar eta_momentum = 0.9,
Scalar eta_learning_rate = 0.01,
Scalar beta_weight = 0.9
) {
return set_m(get_supervised_online_m_step(
return set_m(get_fast_supervised_online_m_step(
num_classes,
regularization_penalty,
minibatch_size,
Expand All @@ -476,10 +476,10 @@ class LDABuilder : public LDABuilderInterface<Scalar>
}

/**
* Create an OnlineSupervisedMStep.
* Create an FastOnlineSupervisedMStep.
*
* You can also see a description of the parameters at
* OnlineSupervisedMStep::OnlineSupervisedMStep.
* FastOnlineSupervisedMStep::FastOnlineSupervisedMStep.
*
* @param class_weights Weights to account for class
* imbalance
Expand All @@ -494,23 +494,23 @@ class LDABuilder : public LDABuilderInterface<Scalar>
* @param beta_weight The weight for the online update
* of \f$\beta\f$
*/
std::shared_ptr<em::MStepInterface<Scalar> > get_supervised_online_m_step(
std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
std::vector<Scalar> class_weights,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Scalar eta_momentum = 0.9,
Scalar eta_learning_rate = 0.01,
Scalar beta_weight = 0.9
);
LDABuilder & set_supervised_online_m_step(
LDABuilder & set_fast_supervised_online_m_step(
std::vector<Scalar> class_weights,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Scalar eta_momentum = 0.9,
Scalar eta_learning_rate = 0.01,
Scalar beta_weight = 0.9
) {
return set_m(get_supervised_online_m_step(
return set_m(get_fast_supervised_online_m_step(
class_weights,
regularization_penalty,
minibatch_size,
Expand All @@ -521,10 +521,10 @@ class LDABuilder : public LDABuilderInterface<Scalar>
}

/**
* Create an OnlineSupervisedMStep.
* Create an FastOnlineSupervisedMStep.
*
* You can also see a description of the parameters at
* OnlineSupervisedMStep::OnlineSupervisedMStep.
* FastOnlineSupervisedMStep::FastOnlineSupervisedMStep.
*
* @param class_weights Weights to account for class
* imbalance
Expand All @@ -539,23 +539,23 @@ class LDABuilder : public LDABuilderInterface<Scalar>
* @param beta_weight The weight for the online update
* of \f$\beta\f$
*/
std::shared_ptr<em::MStepInterface<Scalar> > get_supervised_online_m_step(
std::shared_ptr<em::MStepInterface<Scalar> > get_fast_supervised_online_m_step(
Eigen::Matrix<Scalar, Eigen::Dynamic, 1> class_weights,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Scalar eta_momentum = 0.9,
Scalar eta_learning_rate = 0.01,
Scalar beta_weight = 0.9
);
LDABuilder & set_supervised_online_m_step(
LDABuilder & set_fast_supervised_online_m_step(
Eigen::Matrix<Scalar, Eigen::Dynamic, 1> class_weights,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Scalar eta_momentum = 0.9,
Scalar eta_learning_rate = 0.01,
Scalar beta_weight = 0.9
) {
return set_m(get_supervised_online_m_step(
return set_m(get_fast_supervised_online_m_step(
class_weights,
regularization_penalty,
minibatch_size,
Expand Down
@@ -1,5 +1,5 @@
#ifndef _ONLINE_SUPERVISED_M_STEP_HPP_
#define _ONLINE_SUPERVISED_M_STEP_HPP_
#ifndef _LDAPLUSPLUS_EM_FASTONLINESUPERVISEDMSTEP_HPP_
#define _LDAPLUSPLUS_EM_FASTONLINESUPERVISEDMSTEP_HPP_

#include "ldaplusplus/em/MStepInterface.hpp"

Expand All @@ -8,7 +8,7 @@ namespace em {


/**
* OnlineSupervisedMStep is an online implementation of the classical
* FastOnlineSupervisedMStep is an online implementation of the classical
* categorical supervised LDA.
*
* m_step() is called by doc_m_step() according to the minibatch_size
Expand All @@ -24,14 +24,14 @@ namespace em {
* SupervisedMStep.
*/
template <typename Scalar>
class OnlineSupervisedMStep : public MStepInterface<Scalar>
class FastOnlineSupervisedMStep : public MStepInterface<Scalar>
{
typedef Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic> MatrixX;
typedef Eigen::Matrix<Scalar, Eigen::Dynamic, 1> VectorX;

public:
/**
* Create an OnlineSupervisedMStep that accounts for class imbalance by
* Create an FastOnlineSupervisedMStep that accounts for class imbalance by
* weighting the classes.
*
* @param class_weights Weights to account for class
Expand All @@ -47,7 +47,7 @@ class OnlineSupervisedMStep : public MStepInterface<Scalar>
* @param beta_weight The weight for the online update
* of \f$\beta\f$
*/
OnlineSupervisedMStep(
FastOnlineSupervisedMStep(
VectorX class_weights,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Expand All @@ -56,7 +56,7 @@ class OnlineSupervisedMStep : public MStepInterface<Scalar>
Scalar beta_weight = 0.9
);
/**
* Create an OnlineSupervisedMStep that uses uniform weights for the
* Create an FastOnlineSupervisedMStep that uses uniform weights for the
* classes.
*
* @param num_classes The number of classes
Expand All @@ -71,7 +71,7 @@ class OnlineSupervisedMStep : public MStepInterface<Scalar>
* @param beta_weight The weight for the online update
* of \f$\beta\f$
*/
OnlineSupervisedMStep(
FastOnlineSupervisedMStep(
size_t num_classes,
Scalar regularization_penalty = 1e-2,
size_t minibatch_size = 128,
Expand Down Expand Up @@ -134,4 +134,4 @@ class OnlineSupervisedMStep : public MStepInterface<Scalar>
} // namespace em
} // namespace ldaplusplus

#endif // _ONLINE_SUPERVISED_M_STEP_HPP_
#endif // _LDAPLUSPLUS_EM_FASTONLINESUPERVISEDMSTEP_HPP_
2 changes: 1 addition & 1 deletion slda.cpp
Expand Up @@ -346,7 +346,7 @@ LDA<double> create_lda_for_training(
std::stof(args["--regularization_penalty"].asString())
);
} else if (args["--online_m_step"].asBool()) {
builder.set_supervised_online_m_step(
builder.set_fast_supervised_online_m_step(
create_class_weights(y),
std::stof(args["--regularization_penalty"].asString()),
args["--batch_size"].asLong(),
Expand Down
14 changes: 7 additions & 7 deletions src/ldaplusplus/LDABuilder.cpp
@@ -1,10 +1,10 @@
#include "ldaplusplus/LDABuilder.hpp"
#include "ldaplusplus/em/CorrespondenceSupervisedEStep.hpp"
#include "ldaplusplus/em/CorrespondenceSupervisedMStep.hpp"
#include "ldaplusplus/em/FastOnlineSupervisedMStep.hpp"
#include "ldaplusplus/em/FastSupervisedMStep.hpp"
#include "ldaplusplus/em/MultinomialSupervisedEStep.hpp"
#include "ldaplusplus/em/MultinomialSupervisedMStep.hpp"
#include "ldaplusplus/em/OnlineSupervisedMStep.hpp"
#include "ldaplusplus/em/SemiSupervisedEStep.hpp"
#include "ldaplusplus/em/SemiSupervisedMStep.hpp"
#include "ldaplusplus/em/SupervisedEStep.hpp"
Expand Down Expand Up @@ -175,15 +175,15 @@ std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_supervised_
}

template <typename Scalar>
std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_supervised_online_m_step(
std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_fast_supervised_online_m_step(
size_t num_classes,
Scalar regularization_penalty,
size_t minibatch_size,
Scalar eta_momentum,
Scalar eta_learning_rate,
Scalar beta_weight
) {
return std::make_shared<em::OnlineSupervisedMStep<Scalar> >(
return std::make_shared<em::FastOnlineSupervisedMStep<Scalar> >(
num_classes,
regularization_penalty,
minibatch_size,
Expand All @@ -194,7 +194,7 @@ std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_supervised_
}

template <typename Scalar>
std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_supervised_online_m_step(
std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_fast_supervised_online_m_step(
std::vector<Scalar> class_weights,
Scalar regularization_penalty,
size_t minibatch_size,
Expand All @@ -208,7 +208,7 @@ std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_supervised_
weights[i] = class_weights[i];
}

return get_supervised_online_m_step(
return get_fast_supervised_online_m_step(
weights,
regularization_penalty,
minibatch_size,
Expand All @@ -219,15 +219,15 @@ std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_supervised_
}

template <typename Scalar>
std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_supervised_online_m_step(
std::shared_ptr<em::MStepInterface<Scalar> > LDABuilder<Scalar>::get_fast_supervised_online_m_step(
Eigen::Matrix<Scalar, Eigen::Dynamic, 1> class_weights,
Scalar regularization_penalty,
size_t minibatch_size,
Scalar eta_momentum,
Scalar eta_learning_rate,
Scalar beta_weight
) {
return std::make_shared<em::OnlineSupervisedMStep<Scalar> >(
return std::make_shared<em::FastOnlineSupervisedMStep<Scalar> >(
class_weights,
regularization_penalty,
minibatch_size,
Expand Down
@@ -1,16 +1,15 @@
#include <utility>

#include "ldaplusplus/em/OnlineSupervisedMStep.hpp"
#include "ldaplusplus/em/FastOnlineSupervisedMStep.hpp"
#include "ldaplusplus/optimization/MultinomialLogisticRegression.hpp"
#include "ldaplusplus/events/ProgressEvents.hpp"

namespace ldaplusplus {

using em::OnlineSupervisedMStep;
namespace em {


template <typename Scalar>
OnlineSupervisedMStep<Scalar>::OnlineSupervisedMStep(
FastOnlineSupervisedMStep<Scalar>::FastOnlineSupervisedMStep(
VectorX class_weights,
Scalar regularization_penalty,
size_t minibatch_size,
Expand All @@ -28,14 +27,14 @@ OnlineSupervisedMStep<Scalar>::OnlineSupervisedMStep(
{}

template <typename Scalar>
OnlineSupervisedMStep<Scalar>::OnlineSupervisedMStep(
FastOnlineSupervisedMStep<Scalar>::FastOnlineSupervisedMStep(
size_t num_classes,
Scalar regularization_penalty,
size_t minibatch_size,
Scalar eta_momentum,
Scalar eta_learning_rate,
Scalar beta_weight
) : OnlineSupervisedMStep(
) : FastOnlineSupervisedMStep(
VectorX::Constant(num_classes, 1),
regularization_penalty,
minibatch_size,
Expand All @@ -46,7 +45,7 @@ OnlineSupervisedMStep<Scalar>::OnlineSupervisedMStep(
{}

template <typename Scalar>
void OnlineSupervisedMStep<Scalar>::doc_m_step(
void FastOnlineSupervisedMStep<Scalar>::doc_m_step(
const std::shared_ptr<corpus::Document> doc,
const std::shared_ptr<parameters::Parameters> v_parameters,
std::shared_ptr<parameters::Parameters> m_parameters
Expand Down Expand Up @@ -87,7 +86,7 @@ void OnlineSupervisedMStep<Scalar>::doc_m_step(
}

template <typename Scalar>
void OnlineSupervisedMStep<Scalar>::m_step(
void FastOnlineSupervisedMStep<Scalar>::m_step(
std::shared_ptr<parameters::Parameters> parameters
) {
// Check whether we should actually perform the m_step
Expand Down Expand Up @@ -126,7 +125,9 @@ void OnlineSupervisedMStep<Scalar>::m_step(


// Instantiations
template class OnlineSupervisedMStep<float>;
template class OnlineSupervisedMStep<double>;
template class FastOnlineSupervisedMStep<float>;
template class FastOnlineSupervisedMStep<double>;

}

} // namespace em
} // namespace ldaplusplus
8 changes: 4 additions & 4 deletions test/test_online_maximization_step.cpp
Expand Up @@ -10,8 +10,8 @@

#include "ldaplusplus/Parameters.hpp"
#include "ldaplusplus/events/ProgressEvents.hpp"
#include "ldaplusplus/em/SupervisedEStep.hpp"
#include "ldaplusplus/em/OnlineSupervisedMStep.hpp"
#include "ldaplusplus/em/FastOnlineSupervisedMStep.hpp"
#include "ldaplusplus/em/FastSupervisedEStep.hpp"

using namespace Eigen;
using namespace ldaplusplus;
Expand Down Expand Up @@ -50,8 +50,8 @@ TYPED_TEST(TestOnlineMaximizationStep, Maximization) {
MatrixX<TypeParam>::Zero(10, 6)
);

em::SupervisedEStep<TypeParam> e_step(10, 1e-2, 10);
em::OnlineSupervisedMStep<TypeParam> m_step(
em::FastSupervisedEStep<TypeParam> e_step(10, 1e-2, 10);
em::FastOnlineSupervisedMStep<TypeParam> m_step(
6,
1e-2,
25
Expand Down

0 comments on commit b116ca1

Please sign in to comment.