Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test new prox #16

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ jobs:
fail-fast: false
matrix:
config:
- {os: windows-latest, r: '3.6'}
- {os: macOS-latest, r: '3.6'}
- {os: windows-latest, r: '4.0'}
- {os: macOS-latest, r: '4.0'}
- {os: macOS-latest, r: 'devel'}
- {os: ubuntu-16.04, r: '3.6', rspm: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
- {os: ubuntu-20.04, r: '4.0', rspm: "https://packagemanager.rstudio.com/cran/__linux__/focal/latest"}

env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Description: Fit Generalized Linear Models ("GLMs") using the "Exclusive Lasso"
License: GPL (>= 2)
Imports: Rcpp (>= 0.12.12), Matrix, RColorBrewer, foreach
LinkingTo: Rcpp, RcppArmadillo (>= 0.8.300.1.0)
Suggests: testthat, knitr, microbenchmark, glmnet, ncvreg, grpreg, covr, nnls
Suggests: testthat, knitr, microbenchmark, glmnet, ncvreg, grpreg, covr, nnls, rmarkdown
VignetteBuilder: knitr
BugReports: https://github.com/DataSlingers/ExclusiveLasso/issues
URL: https://github.com/DataSlingers/ExclusiveLasso
63 changes: 61 additions & 2 deletions src/ExclusiveLasso.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,72 @@ double exclusive_lasso_penalty(const arma::vec& x, const arma::ivec& groups){
return ans / 2.0;
}

// [[Rcpp::export]]
arma::vec exclusive_lasso_new_prox_inner(const arma::vec& z,
const arma::ivec& groups,
double lambda){

arma::uword p = z.n_elem;
arma::vec beta(p);

// TODO -- parallelize?
// Loop over groups
for(arma::sword g = arma::min(groups); g <= arma::max(groups); g++){
// Identify elements in group
arma::uvec g_ix = arma::find(g == groups);
int g_n_elem = g_ix.n_elem;

arma::vec z_g = z(g_ix);
arma::uvec z_g_sort_ix = arma::sort_index(z_g, "descend");
arma::vec z_g_sorted = arma::sort(z_g, "descend");

arma::vec s = arma::cumsum(z_g_sorted);
arma::vec L = arma::regspace(1, g_n_elem);
// Remove the factor of 2 here and in beta_g to match our other scaling
arma::vec alpha = s / (1 + lambda * L);
double alpha_bar = arma::max(alpha);

arma::vec beta_g = (z_g - lambda * alpha_bar);
beta_g = beta_g % (beta_g >= 0);

beta(g_ix) = beta_g;
}
return beta;
}


//[[Rcpp::export]]
arma::vec exclusive_lasso_prox(const arma::vec& z,
const arma::ivec& groups,
const arma::ivec groups,
double lambda,
const arma::vec& lower_bound,
const arma::vec& upper_bound,
double thresh=1e-7){
bool apply_box_constraints = arma::any(lower_bound != -EXLASSO_INF) || arma::all(upper_bound != EXLASSO_INF);

if(apply_box_constraints){
return exclusive_lasso_prox_old(z, groups, lambda, lower_bound, upper_bound, thresh);
}

arma::vec result = arma::sign(z) % exclusive_lasso_new_prox_inner(arma::abs(z), groups, lambda);

for(int i=0; i<z.n_elem; i++){
// Impose box constraints
if(apply_box_constraints){
result(i) = std::fmax(result(i), lower_bound(i));
result(i) = std::fmin(result(i), upper_bound(i));
}
}

return result;
}

// [[Rcpp::export]]
arma::vec exclusive_lasso_prox_old(const arma::vec& z,
const arma::ivec& groups,
double lambda,
const arma::vec& lower_bound,
const arma::vec& upper_bound,
double thresh=1e-7){

bool apply_box_constraints = arma::any(lower_bound != -EXLASSO_INF) || arma::all(upper_bound != EXLASSO_INF);

Expand Down