Skip to content

Commit

Permalink
Added threshold for matrix and sigma argument for legacy.
Browse files Browse the repository at this point in the history
  • Loading branch information
AnthonyEbert committed Mar 15, 2018
1 parent b31073b commit 12b5a50
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 11 deletions.
4 changes: 2 additions & 2 deletions R/RcppExports.R
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

kernelMatrix_sum_multi <- function(x, y, Sinv) {
.Call('_EasyMMD_kernelMatrix_sum_multi', PACKAGE = 'EasyMMD', x, y, Sinv)
kernelMatrix_sum_multi <- function(x, y, Sinv, threshold) {
.Call('_EasyMMD_kernelMatrix_sum_multi', PACKAGE = 'EasyMMD', x, y, Sinv, threshold)
}

kernelMatrix_sum <- function(x, y, sigma, approx_exp) {
Expand Down
6 changes: 4 additions & 2 deletions R/mmd.R
Expand Up @@ -34,7 +34,9 @@
#' # Different var
#'
#' MMD_4 <- MMD(y, x, var = 0.25)
MMD <- function(y, x, y_kmmd = NULL, var = 1, bias = FALSE, threshold = Inf, approx_exp = 0){
MMD <- function(y, x, y_kmmd = NULL, var = 1, bias = FALSE, threshold = Inf, approx_exp = 0, sigma = NULL){

if(!is.null(sigma)){var = sigma^2}

stopifnot(class(x) == class(y))

Expand Down Expand Up @@ -94,7 +96,7 @@ kernelMatrix_sum_wrap <- function(y, x, var = 1, threshold = Inf, approx_exp = 0

Sinv <- solve(var)

return(kernelMatrix_sum_multi(y, x, Sinv = Sinv))
return(kernelMatrix_sum_multi(y, x, Sinv = Sinv, threshold = threshold))
}
}

Expand Down
2 changes: 1 addition & 1 deletion man/MMD.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions src/MMD_full_loop.cpp
Expand Up @@ -27,19 +27,22 @@ inline
}

// [[Rcpp::export]]
double kernelMatrix_sum_multi(const arma::mat& x, const arma::mat& y, const arma::mat Sinv) {
double kernelMatrix_sum_multi(const arma::mat& x, const arma::mat& y, const arma::mat Sinv, const double threshold) {

int n_x = x.n_rows;
int n_y = y.n_rows;

double b;
double c = threshold * threshold;

double output_2 = 0;
//
for(int i = 0; i < n_x; ++i){
for(int j = 0; j < n_y; ++j){
b = maha(y.row(j), x.row(i), Sinv);
output_2 += std::exp(- 0.5 * b);
if(b < c){
output_2 += std::exp(- 0.5 * b);
}

if(j % 2048 == 0)
{
Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Expand Up @@ -7,15 +7,16 @@
using namespace Rcpp;

// kernelMatrix_sum_multi
double kernelMatrix_sum_multi(const arma::mat& x, const arma::mat& y, const arma::mat Sinv);
RcppExport SEXP _EasyMMD_kernelMatrix_sum_multi(SEXP xSEXP, SEXP ySEXP, SEXP SinvSEXP) {
double kernelMatrix_sum_multi(const arma::mat& x, const arma::mat& y, const arma::mat Sinv, const double threshold);
RcppExport SEXP _EasyMMD_kernelMatrix_sum_multi(SEXP xSEXP, SEXP ySEXP, SEXP SinvSEXP, SEXP thresholdSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< const arma::mat& >::type x(xSEXP);
Rcpp::traits::input_parameter< const arma::mat& >::type y(ySEXP);
Rcpp::traits::input_parameter< const arma::mat >::type Sinv(SinvSEXP);
rcpp_result_gen = Rcpp::wrap(kernelMatrix_sum_multi(x, y, Sinv));
Rcpp::traits::input_parameter< const double >::type threshold(thresholdSEXP);
rcpp_result_gen = Rcpp::wrap(kernelMatrix_sum_multi(x, y, Sinv, threshold));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -50,7 +51,7 @@ END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_EasyMMD_kernelMatrix_sum_multi", (DL_FUNC) &_EasyMMD_kernelMatrix_sum_multi, 3},
{"_EasyMMD_kernelMatrix_sum_multi", (DL_FUNC) &_EasyMMD_kernelMatrix_sum_multi, 4},
{"_EasyMMD_kernelMatrix_sum", (DL_FUNC) &_EasyMMD_kernelMatrix_sum, 4},
{"_EasyMMD_kernelMatrix_threshold_sum", (DL_FUNC) &_EasyMMD_kernelMatrix_threshold_sum, 5},
{NULL, NULL, 0}
Expand Down
4 changes: 4 additions & 0 deletions tests/testthat/test_main.R
Expand Up @@ -19,6 +19,7 @@ print(system.time(MMD(x,y, var = 1/2, threshold = 6)))
set.seed(1)
x <- rnorm(1e4)
y <- rnorm(1e4, 5)
mmd_0 <- MMD(y, x, sigma = 1/sqrt(2))
mmd_1 <- MMD(y, x, var = 1/2)
mmd_2 <- MMD(y, x, var = 1/2, threshold = 6)
mmd_3 <- MMD(y, x, var = 1/2, approx_exp = 1)
Expand All @@ -28,11 +29,14 @@ mmd_6 <- MMD_l(y, x, var = 1/2)

set.seed(2)
mmd_7 <- MMD_l_multi(y, x, var = 1/2, k = 10)
mmd_8 <- MMD(matrix(y, ncol = 1), matrix(x, ncol = 1), var = matrix(1/2), threshold = 6)

testthat::expect_equal(mmd_0, mmd_1)
testthat::expect_equal(mmd_1, 0.88824062832437)
testthat::expect_equal(mmd_1, mmd_2)
testthat::expect_equal(mmd_3, 0.88804234756543)
testthat::expect_equal(mmd_4, 0.88835124947997)
testthat::expect_equal(mmd_4, mmd_5)
testthat::expect_equal(mmd_6, 0.87759954311812)
testthat::expect_equal(mmd_7, 0.88606371828789)
testthat::expect_equal(mmd_8, mmd_1)

0 comments on commit 12b5a50

Please sign in to comment.