Skip to content

Commit

Permalink
Merge pull request #254 from Luke-Pratley/fix_l1_norm_eval
Browse files Browse the repository at this point in the history
Fix l1 norm eval
  • Loading branch information
Luke-Pratley committed Nov 11, 2019
2 parents f9ab56a + 655a2d8 commit 1ce136c
Show file tree
Hide file tree
Showing 12 changed files with 62 additions and 39 deletions.
2 changes: 1 addition & 1 deletion cpp/sopt/forward_backward.h
Expand Up @@ -233,7 +233,7 @@ void ForwardBackward<SCALAR>::iteration_step(t_Vector &out, t_Vector &residual,
f_gradient(z, residual);
g_proximal(out, gamma() * beta(), out - beta() / nu() * (Phi().adjoint() * z));
p = out + lambda * (out - p);
residual = (Phi() * p) / nu() - target();
residual = (Phi() * p) - target();
}

template <class SCALAR>
Expand Down
6 changes: 4 additions & 2 deletions cpp/sopt/imaging_forward_backward.h
Expand Up @@ -355,7 +355,8 @@ bool ImagingForwardBackward<SCALAR>::objective_convergence(ScalarRelativeVariati
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
auto const current =
((gamma() > 0) ? sopt::l1_norm(Psi().adjoint() * x, l1_proximal_weights()) * gamma() : 0) +
((gamma() > 0) ? sopt::l1_norm((Psi().adjoint() * x).eval(), l1_proximal_weights()) * gamma()
: 0) +
std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma());
return scalvar(current);
};
Expand All @@ -369,7 +370,8 @@ bool ImagingForwardBackward<SCALAR>::objective_convergence(mpi::Communicator con
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
auto const current = obj_comm.all_sum_all<t_real>(
((gamma() > 0) ? sopt::l1_norm(Psi().adjoint() * x, l1_proximal_weights()) * gamma() : 0) +
((gamma() > 0) ? sopt::l1_norm((Psi().adjoint() * x).eval(), l1_proximal_weights()) * gamma()
: 0) +
std::pow(sopt::l2_norm(residual), 2) / (2 * sigma() * sigma()));
return scalvar(current);
};
Expand Down
2 changes: 1 addition & 1 deletion cpp/sopt/imaging_padmm.h
Expand Up @@ -326,7 +326,7 @@ bool ImagingProximalADMM<SCALAR>::objective_convergence(ScalarRelativeVariation<
t_Vector const &residual) const {
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
auto const current = sopt::l1_norm(Psi().adjoint() * x, l1_proximal_weights());
auto const current = sopt::l1_norm((Psi().adjoint() * x).eval(), l1_proximal_weights());
return scalvar(current);
};

Expand Down
16 changes: 6 additions & 10 deletions cpp/sopt/imaging_primal_dual.h
Expand Up @@ -51,10 +51,10 @@ class ImagingPrimalDual {
template <class DERIVED>
ImagingPrimalDual(Eigen::MatrixBase<DERIVED> const &target)
: l1_proximal_([](t_Vector &out, const Real &gamma, const t_Vector &x) {
proximal::l1_norm(out, gamma, x);
proximal::l1_norm<t_Vector, t_Vector>(out, gamma, x);
}),
l1_proximal_weighted_([](t_Vector &out, const Vector<Real> &gamma, const t_Vector &x) {
proximal::l1_norm(out, gamma, x);
proximal::l1_norm<t_Vector, t_Vector, Vector<Real>>(out, gamma, x);
}),
l1_proximal_weights_(Vector<Real>::Ones(1)),
l2ball_proximal_(1e0),
Expand Down Expand Up @@ -290,12 +290,11 @@ class ImagingPrimalDual {
//! check that l1 and weighted l1 proximal operators are the same function (except for weights)
bool check_l1_weight_proximal(const t_Proximal<Real> &no_weights,
const t_Proximal<Vector<Real>> &with_weights) const {
Vector<SCALAR> output;
Vector<SCALAR> outputw;

const Vector<SCALAR> x = Vector<SCALAR>::Ones(this->l1_proximal_weights().size());
Vector<SCALAR> output = Vector<SCALAR>::Zero(this->l1_proximal_weights().size());
Vector<SCALAR> outputw = Vector<SCALAR>::Zero(this->l1_proximal_weights().size());
no_weights(output, 1, x);
with_weights(outputw, Vector<Real>::Ones(1), x);
with_weights(outputw, Vector<Real>::Ones(this->l1_proximal_weights().size()), x);
return output.isApprox(outputw);
};
};
Expand Down Expand Up @@ -379,10 +378,7 @@ bool ImagingPrimalDual<SCALAR>::objective_convergence(ScalarRelativeVariation<Sc
t_Vector const &residual) const {
if (static_cast<bool>(objective_convergence())) return objective_convergence()(x, residual);
if (scalvar.relative_tolerance() <= 0e0) return true;
auto const current =
(l1_proximal_weights().size() > 1)
? sopt::l1_norm(l1_proximal_weights().array() * (Psi().adjoint() * x).array())
: sopt::l1_norm(l1_proximal_weights()(0) * (Psi().adjoint() * x));
auto const current = sopt::l1_norm((Psi().adjoint() * x).eval(), l1_proximal_weights());
return scalvar(current);
};

Expand Down
6 changes: 4 additions & 2 deletions cpp/sopt/l1_proximal.h
Expand Up @@ -143,11 +143,13 @@ typename std::enable_if<is_complex<SCALAR>::value == is_complex<typename T0::Sca
L1TightFrame<SCALAR>::objective(Eigen::MatrixBase<T0> const &x, Eigen::MatrixBase<T1> const &z,
Real const &gamma) const {
#ifdef SOPT_MPI
auto const adj = gamma * sopt::mpi::l1_norm(Psi().adjoint() * z, weights(), adjoint_space_comm());
auto const adj =
gamma * sopt::mpi::l1_norm((Psi().adjoint() * z).eval(), weights(), adjoint_space_comm());
auto const dir = direct_space_comm().all_sum_all(0.5 * (x - z).squaredNorm());
return adj + dir;
#else
return 0.5 * (x - z).squaredNorm() + gamma * sopt::l1_norm(Psi().adjoint() * z, weights());
return 0.5 * (x - z).squaredNorm() +
gamma * sopt::l1_norm((Psi().adjoint() * z).eval(), weights());
#endif
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/sopt/l2_primal_dual.h
Expand Up @@ -236,11 +236,11 @@ class ImagingPrimalDual {
// E.g.: `paddm.l2_proximal_itermax(100).l2ball_epsilon(1e-2).l2_proximal_tolerance(1e-4)`.
// ~~~
#define SOPT_MACRO(VAR, NAME, PROXIMAL) \
/** \brief Forwards to l2ball_proximal **/ \
/** \brief Forwards to l2ball_proximal **/ \
decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) NAME##_proximal_##VAR() const { \
return NAME##_proximal().VAR(); \
} \
/** \brief Forwards to l2ball_proximal **/ \
/** \brief Forwards to l2ball_proximal **/ \
ImagingPrimalDual<Scalar> &NAME##_proximal_##VAR( \
decltype(std::declval<proximal::PROXIMAL<Scalar> const>().VAR()) VAR) { \
NAME##_proximal().VAR(VAR); \
Expand Down
2 changes: 1 addition & 1 deletion cpp/sopt/padmm.h
Expand Up @@ -195,7 +195,7 @@ class ProximalADMM {
static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
t_LinearTransform const &phi, Real nu) {
std::tuple<t_Vector, t_Vector> guess;
std::get<0>(guess) = phi.adjoint() * target / nu;
std::get<0>(guess) = (phi.adjoint() * target).eval() / nu;
std::get<1>(guess) = phi * std::get<0>(guess) - target;
return guess;
}
Expand Down
14 changes: 14 additions & 0 deletions cpp/sopt/power_method.h
Expand Up @@ -81,6 +81,20 @@ std::tuple<t_real, T, sopt::LinearTransform<T>> normalise_operator(
#ifdef SOPT_MPI
//! Performs an all sum all operation to collectively normalise different serial operators
template <class T>
std::tuple<t_real, T> all_sum_all_power_method(const sopt::mpi::Communicator &comm,
const sopt::LinearTransform<T> &op,
const t_uint &niters,
const t_real &relative_difference,
const T &initial_vector) {
const auto all_sum_all_op = sopt::LinearTransform<T>(
[&op](T &output, const T &input) { output = (op * input).eval(); }, op.sizes(),
[&op, comm](T &output, const T &input) {
output = comm.all_sum_all((op.adjoint() * input).eval());
},
op.adjoint().sizes());
return power_method(all_sum_all_op, niters, relative_difference, initial_vector.derived());
}
template <class T>
std::tuple<t_real, T, std::shared_ptr<sopt::LinearTransform<T>>> all_sum_all_normalise_operator(
const sopt::mpi::Communicator &comm, const std::shared_ptr<sopt::LinearTransform<T> const> &op,
const t_uint &niters, const t_real &relative_difference, const T &initial_vector) {
Expand Down
33 changes: 16 additions & 17 deletions cpp/sopt/primal_dual.h
Expand Up @@ -225,23 +225,23 @@ class PrimalDual {

//! \brief Computes initial guess for x and the residual using the targets
//! \details with y the vector of measurements
//! - x = Φ^T y / ν
//! - x = Φ^T y * xi * tau
//! - residuals = Φ x - y
std::tuple<t_Vector, t_Vector> initial_guess() const {
return PrimalDual<SCALAR>::initial_guess(target(), Phi(), nu());
return PrimalDual<SCALAR>::initial_guess(target(), Phi(), xi());
}

//! \brief Computes initial guess for x and the residual using the targets
//! \details with y the vector of measurements
//! - x = Φ^T y / ν
//! - x = Φ^T y * xi * tau
//! - residuals = Φ x - y
//!
//! This function simplifies creating overloads for operator() in PD wrappers.
static std::tuple<t_Vector, t_Vector> initial_guess(t_Vector const &target,
t_LinearTransform const &phi, Real nu) {
std::tuple<t_Vector, t_Vector> guess;
std::get<0>(guess) = phi.adjoint() * target / nu;
std::get<1>(guess) = phi * std::get<0>(guess) - target;
std::get<0>(guess) = (phi.adjoint() * t_Vector::Zero(target.size())).eval();
std::get<1>(guess) = target;
return guess;
}

Expand Down Expand Up @@ -289,8 +289,8 @@ void PrimalDual<SCALAR>::iteration_step(t_Vector &out, t_Vector &out_hold, t_Vec
}
// dual calculations for wavelet
if (random_wavelet_update) {
q = Psi().adjoint() * out_hold;
f_proximal(u_hold, gamma(), u + q);
q = (Psi().adjoint() * out_hold) * sigma();
f_proximal(u_hold, gamma(), (u + q));
u_hold = u + q - u_hold;
u = u + update_scale() * (u_hold - u);
u_update = Psi() * u;
Expand All @@ -301,18 +301,17 @@ void PrimalDual<SCALAR>::iteration_step(t_Vector &out, t_Vector &out_hold, t_Vec
if (v_all_sum_all_comm().size() > 0 and u_all_sum_all_comm().size() > 0)
constraint()(
out_hold,
r - tau() *
(u_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(u_update)) * sigma() +
v_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(v_update)) * xi()));
r - tau() * (u_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(u_update)) +
v_all_sum_all_comm().all_sum_all(static_cast<const t_Vector>(v_update))));
else
#endif
constraint()(out_hold, r - tau() * (u_update * sigma() + v_update * xi()));
constraint()(out_hold, r - tau() * (u_update + v_update));
out = r + update_scale() * (out_hold - r);
out_hold = 2 * out_hold - r;
random_measurement_update = random_measurement_updater_();
random_wavelet_update = random_wavelet_updater_();
// update residual
if (random_measurement_update) residual = Phi() * out_hold - target();
if (random_measurement_update) residual = ((Phi() * out_hold) * xi() - target());
}

template <class SCALAR>
Expand All @@ -325,13 +324,13 @@ typename PrimalDual<SCALAR>::Diagnostic PrimalDual<SCALAR>::operator()(
t_Vector residual = res_guess;
out = x_guess;
t_Vector out_hold = x_guess;
t_Vector r = out;
t_Vector v = t_Vector::Zero(target().size());
t_Vector v_hold = t_Vector::Zero(target().size());
t_Vector r = x_guess;
t_Vector v = residual;
t_Vector v_hold = residual;
t_Vector v_update = x_guess;
t_Vector u = Psi().adjoint() * t_Vector::Zero(x_guess.size());
t_Vector u = Psi().adjoint() * out;
t_Vector u_hold = u;
t_Vector u_update = x_guess;
t_Vector u_update = out;
t_Vector q = u;

t_uint niters(0);
Expand Down
4 changes: 2 additions & 2 deletions cpp/sopt/proximal.h
Expand Up @@ -64,13 +64,13 @@ auto euclidian_norm(typename real_type<typename T0::Scalar>::type const &t,
template <class T0, class T1>
void l1_norm(Eigen::DenseBase<T0> &out, typename real_type<typename T0::Scalar>::type gamma,
Eigen::DenseBase<T1> const &x) {
out = sopt::soft_threshhold(x, gamma);
out = sopt::soft_threshhold<T0>(x, gamma);
}
//! Proxmal of the weighted l1 norm
template <class T0, class T1, class T2>
void l1_norm(Eigen::DenseBase<T0> &out, Eigen::DenseBase<T2> const &gamma,
Eigen::DenseBase<T1> const &x) {
out = sopt::soft_threshhold(x, gamma);
out = sopt::soft_threshhold<T0, T2>(x, gamma);
}

//! \brief Proximal of the l1 norm
Expand Down
3 changes: 2 additions & 1 deletion cpp/tests/primal_dual.cc
Expand Up @@ -28,6 +28,7 @@ TEST_CASE("Primal Dual Imaging", "[primaldual]") {
auto const epsilon = target.stableNorm() / 2;

auto primaldual = algorithm::ImagingPrimalDual<Scalar>(target)
.l1_proximal_weights(t_Vector::Ones(target.size()))
.Phi(mId)
.Psi(mId)
.itermax(5000)
Expand All @@ -38,7 +39,7 @@ TEST_CASE("Primal Dual Imaging", "[primaldual]") {
.residual_convergence(epsilon);

auto const result = primaldual();
CHECK((result.x - target).stableNorm() <= Approx(epsilon));
CHECK((result.x - target).stableNorm() <= Approx(epsilon).margin(1e-12));
CHECK(result.good);
primaldual
.l1_proximal([](t_Vector &output, const t_real &gamma, const t_Vector &input) {
Expand Down
9 changes: 9 additions & 0 deletions cpp/tests/wavelets.cc
Expand Up @@ -59,6 +59,15 @@ void check_round_trip(Eigen::ArrayBase<T0> const &input_, sopt::t_uint db,
CHECK(not transform.isApprox(sopt::wavelets::direct_transform(input, nlevels - 1, dbwave), 1e-4));
}

TEST_CASE("wavelet data") {
for (sopt::t_int num = 1; num < 100; num++) {
if (num < 39)
REQUIRE(sopt::wavelets::daubechies_data(num).coefficients.size() == 2 * num);
else
REQUIRE_THROWS(sopt::wavelets::daubechies_data(num));
}
}

TEST_CASE("Wavelet transform innards with integer data", "[wavelet]") {
using namespace sopt::wavelets;

Expand Down

0 comments on commit 1ce136c

Please sign in to comment.