Issue with sampling from Wishart distribution #247

Closed
donaldRwilliams opened this Issue Mar 6, 2019 · 17 comments

4 participants
donaldRwilliams commented Mar 6, 2019

Hi:
I was wondering whether changes were made to the functions, as sampling does not seem to work as expected for me. More specifically, the average taken across all the diagonal elements looks OK, but each diagonal element should individually average 20. This is not the case for the Rcpp implementation, where the values can be very far from the expected ones.

R solution

samples = 10000
S = toeplitz((10:1)/10)

R_wish <- rWishart(samples, 20, S)
R_list <- list()
for(i in 1:samples){
  R_list[[i]] <- diag(R_wish[,,i])
}

colMeans(do.call(rbind, R_list))
# 20.05181 19.96682 19.93232 19.91842 19.89735 19.86378 19.87150 19.90565 19.91630 19.92229

# now take the average of the diagonals
mean(colMeans(do.call(rbind, R_list)))
# 19.92462

Rcpp solution

Rcpp::cppFunction("arma::mat test_rwishart(arma::mat S, int df) {
  return arma::wishrnd(S, df);
                  }", plugins = "cpp11", depends = "RcppArmadillo")


Rcpp_wish = replicate(samples, test_rwishart(S, 20))


Rcpp_list <- list()
for(i in 1:samples){
 Rcpp_list[[i]] <- diag(Rcpp_wish[,,i])
  
}
colMeans(do.call(rbind, Rcpp_list))
# these should all be 20
# 77.111685 21.990086 20.196240 18.345458 16.408652 14.297822 11.994180  9.526644  6.773458  3.668038

# now take the average of the diagonals
mean(colMeans(do.call(rbind, Rcpp_list)))
# 20.03123

Finally, a stopifnot(all.equal(...)) check also shows that they are not equal.

Also note that I posted this on a closed issue as well, but figured opening a new issue would be better.

Thanks!


stephensrmmartin commented Mar 6, 2019

I can confirm this unexpected behavior on Linux.

This appears to get worse as more correlation is present.

E.g., using a diagonal matrix, the expected values of the diagonals are equal to the df.

Using a matrix with a diagonal of 1, and covariances/correlations of .4, the expected values are way off, and in descending order. (e.g., 84, 27, 16, 12, 11, 10, 9, 9, 9, 8; when all should be 20).

Is this an RcppArmadillo problem, or an Armadillo problem?

Member

eddelbuettel commented Mar 6, 2019

Is this an RcppArmadillo problem, or an Armadillo problem?

We don't know at this point. arma::wishrnd() is Armadillo, but for R users' convenience we allow (and default to) use of the R RNGs, so the first things I would check are:

  • where is the seeding coming from?
  • what happens when you decouple? Given a working arma function, wrap a main() around it and sample N points from a standalone binary that does not involve R in any form.

Can you guys look into this, please?

Author

donaldRwilliams commented Mar 6, 2019

We are looking into it, but also note that the unexpected behavior is not seen when using an identity matrix, or when the correlations are very small. If the cause were on the R side, I would not expect the problem to depend on the kind of scale matrix.

Member

eddelbuettel commented Mar 6, 2019

Good point -- agreed! Was just trying to be careful...


stephensrmmartin commented Mar 6, 2019

#include <iostream>
#include <armadillo>
using namespace std;
using namespace arma;

mat test_wishart(mat S, int df){
 return(wishrnd(S, df)); 
}

int main(){
  // Scale matrix: unit diagonal, every off-diagonal entry set to .4
  mat S(10, 10, fill::ones);
  for(int i = 0; i < 10; i++){
    for(int j = i + 1; j < 10; j++){
      S(i,j) = .4;
      S(j,i) = .4;
    }
  }
  
  //S.save("S_matrix.txt",csv_ascii);
  
  int df = 20;
  int samples = 1000;
  
  cube matrices(10,10,samples);
  mat matrix_diags(samples,10);
  
  // Draw the samples and keep the diagonal of each one as a row
  for(int k = 0; k < samples; k++){
    matrices.slice(k) = test_wishart(S, df);
    matrix_diags.row(k) = matrices.slice(k).diag().t();
  }
  // Column means = average of each diagonal element across samples
  mat mean_diag(10,1); 
  mean_diag.col(0) = mean(matrix_diags).t();
  
  mean_diag.save("simOut.txt",csv_ascii);
}

This is a standalone program. It is indeed an Armadillo problem.

We're using .4s here, but you can change it, compile/run it, and see in simOut.txt that the expected values of the diagonals are NOT the df, and are in descending order. They should be all approximately equal to the df.
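
For an independent cross-check that does not depend on Armadillo at all, here is a minimal sketch of a Wishart sampler built on the Bartlett decomposition, using only the C++ standard library. To be clear, this is not Armadillo's implementation and is not meant as the fix; the names cholesky_lower, wishart_bartlett and mean_wishart_diag are made up for illustration, and the sketch assumes a symmetric positive definite scale matrix and an integer df of at least the matrix dimension.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Lower-triangular Cholesky factor L with S = L * L^T.
// Assumes S is symmetric positive definite (not checked here).
Matrix cholesky_lower(const Matrix& S) {
    const std::size_t p = S.size();
    Matrix L(p, std::vector<double>(p, 0.0));
    for (std::size_t i = 0; i < p; ++i) {
        for (std::size_t j = 0; j <= i; ++j) {
            double sum = S[i][j];
            for (std::size_t k = 0; k < j; ++k) sum -= L[i][k] * L[j][k];
            L[i][j] = (i == j) ? std::sqrt(sum) : sum / L[j][j];
        }
    }
    return L;
}

// One draw W ~ Wishart(S, df) via the Bartlett decomposition:
// W = (L A)(L A)^T, where A is lower triangular with
// A_ii = sqrt(chi-square with df - i degrees of freedom), i = 0..p-1,
// and A_ij ~ N(0, 1) for i > j.
Matrix wishart_bartlett(const Matrix& L, int df, std::mt19937& rng) {
    const std::size_t p = L.size();
    assert(df > 0 && static_cast<std::size_t>(df) >= p);  // keeps every chi-square df positive
    std::normal_distribution<double> normal(0.0, 1.0);

    Matrix A(p, std::vector<double>(p, 0.0));
    for (std::size_t i = 0; i < p; ++i) {
        std::chi_squared_distribution<double> chi2(static_cast<double>(df) - static_cast<double>(i));
        A[i][i] = std::sqrt(chi2(rng));
        for (std::size_t j = 0; j < i; ++j) A[i][j] = normal(rng);
    }

    // B = L * A; both factors are lower triangular, so B is too.
    Matrix B(p, std::vector<double>(p, 0.0));
    for (std::size_t i = 0; i < p; ++i)
        for (std::size_t j = 0; j <= i; ++j)
            for (std::size_t k = j; k <= i; ++k) B[i][j] += L[i][k] * A[k][j];

    // W = B * B^T
    Matrix W(p, std::vector<double>(p, 0.0));
    for (std::size_t i = 0; i < p; ++i)
        for (std::size_t j = 0; j < p; ++j)
            for (std::size_t k = 0; k <= std::min(i, j); ++k) W[i][j] += B[i][k] * B[j][k];
    return W;
}

// Average of each diagonal element over `samples` draws.
// For a unit-diagonal S, every entry should come out close to df.
std::vector<double> mean_wishart_diag(const Matrix& S, int df, int samples, unsigned seed) {
    const Matrix L = cholesky_lower(S);
    std::mt19937 rng(seed);
    std::vector<double> mean(S.size(), 0.0);
    for (int s = 0; s < samples; ++s) {
        const Matrix W = wishart_bartlett(L, df, rng);
        for (std::size_t i = 0; i < W.size(); ++i) mean[i] += W[i][i] / samples;
    }
    return mean;
}
```

Calling mean_wishart_diag with the same 10 x 10 matrix as above (ones on the diagonal, .4 elsewhere) and df = 20 gives a reference to compare simOut.txt against: the averages should be roughly flat around the df rather than descending.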

Contributor

coatless commented Mar 6, 2019

@eddelbuettel + @conradsnicta I can confirm a regression:

Going by the issue ticket that incorporated the feature (#201), it seems like the diagonal isn't stable anymore.

# Build the C++ function
Rcpp::cppFunction("arma::mat test_rwishart(arma::mat S, int df) {
  return arma::wishrnd(S, df);
}", plugins = "cpp11", depends = "RcppArmadillo")

# Run through tests:

## Artificial
S = toeplitz((10:1)/10)

set.seed(11)
R = replicate(1000, test_rwishart(S, 20))

dim(R) 
#  10 10  1000

mR = apply(R, 1:2, mean)  
# ~= E[ Wish(S, 20) ] = 20 * S

stopifnot(all.equal(mR, 20*S, tolerance = .009))
# Error: mR and 20 * S are not equal:
#  Mean relative difference: 0.5437535

## See Details, the variance is
Va = 20*(S^2 + tcrossprod(diag(S)))
vR = apply(R, 1:2, var)

stopifnot(all.equal(vR, Va, tolerance = 1/16))
# Error: vR and Va are not equal:
#   Mean relative difference: 0.8660115
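
The two reference quantities in the check above, 20 * S for the mean and 20 * (S^2 + tcrossprod(diag(S))) for the variance, are elementwise formulas. As a rough sketch of the same targets for a standalone C++ test (the helper names wishart_mean and wishart_variance are hypothetical, not part of Armadillo or RcppArmadillo):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// E[W_ij] = df * S_ij for W ~ Wishart(S, df).
Matrix wishart_mean(const Matrix& S, double df) {
    Matrix M = S;
    for (auto& row : M) {
        assert(row.size() == S.size());  // S must be square
        for (double& x : row) x *= df;
    }
    return M;
}

// Var(W_ij) = df * (S_ij^2 + S_ii * S_jj).
// This matches 20 * (S^2 + tcrossprod(diag(S))) in the R check,
// where S^2 is the elementwise square.
Matrix wishart_variance(const Matrix& S, double df) {
    const std::size_t p = S.size();
    Matrix V(p, std::vector<double>(p, 0.0));
    for (std::size_t i = 0; i < p; ++i) {
        assert(S[i].size() == p);  // S must be square
        for (std::size_t j = 0; j < p; ++j)
            V[i][j] = df * (S[i][j] * S[i][j] + S[i][i] * S[j][j]);
    }
    return V;
}
```

With a unit-diagonal S and df = 20, each diagonal element then has expected value 20 and variance 40, which is why sample means of 77 or 3.7 over 1000 or more draws cannot be explained by sampling noise.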

Member

eddelbuettel commented Mar 6, 2019

And thanks for carrying it over to GitLab where @conradsnicta is more likely to see it!

Author

donaldRwilliams commented Mar 7, 2019

Update: @conradsnicta is going to look into it when he has the time. My lab mate told me that it is not easy to jump into the Armadillo code and readily see all that is going on and, in particular, what might be causing this.
But the unexpected behavior is not negligible (e.g., values 10 times smaller or larger than expected). Even if only the diagonals are off, this implies that correlations obtained by standardizing will be off quite a bit as well.

Hopefully there is a solution soon.
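
To illustrate the point about standardizing, here is a small sketch of the covariance-to-correlation step in plain C++. The function name cov_to_cor is hypothetical, and the input mentioned in the note afterwards is an assumption: nobody in this thread reported the off-diagonal averages, so the example supposes they are unbiased.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Standardize a covariance-like matrix: R_ij = W_ij / sqrt(W_ii * W_jj).
// Requires a square matrix with strictly positive diagonal entries.
Matrix cov_to_cor(const Matrix& W) {
    const std::size_t p = W.size();
    Matrix R(p, std::vector<double>(p, 0.0));
    for (std::size_t i = 0; i < p; ++i) {
        assert(W[i].size() == p && W[i][i] > 0.0);
    }
    for (std::size_t i = 0; i < p; ++i)
        for (std::size_t j = 0; j < p; ++j)
            R[i][j] = W[i][j] / std::sqrt(W[i][i] * W[j][j]);
    return R;
}
```

For a correct mean matrix with 20 on the diagonal and 8 off the diagonal (df = 20, correlation .4), this returns .4. If the first two diagonal entries were instead inflated to 84 and 27 (the averages reported earlier in the thread) while the off-diagonal stayed at the assumed 8, the implied correlation would shrink to well under half of that.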


stephensrmmartin commented Mar 7, 2019

Looks like the latest push to Armadillo from @conradsnicta fixes the problem.

Member

eddelbuettel commented Mar 7, 2019

Great -- and I just took that (well-confined) commit and applied it to the branch I just pushed (as I need to update the RNG settings for one test with sample() given an upcoming change in R 3.6.0).

Could you by chance test that? You could just drop in this file (and/or of course clone the repo / fetch the commit and build a local 0.9.200.7.1). Heck, I just put my tarball here. Feedback would be appreciated -- passes all tests so far.

Author

donaldRwilliams commented Mar 7, 2019

Glad to hear. Is there a dev version of RcppArmadillo that I can install? (I could not locate one.)

Member

eddelbuettel commented Mar 7, 2019

  1. I committed to a branch of this repo.
  2. To make it easier I also provided you with a link to the altered file. Not marked up the link is https://github.com/RcppCore/RcppArmadillo/blob/bugfix/sample_test/inst/include/armadillo_bits/op_wishrnd_meat.hpp and it still resolves for me.
  3. Further, I placed a complete tarball -- the 'dev version' you ask about -- on my server (where I disabled index.html file listing). Not marked up the URL is http://dirk.eddelbuettel.com/tmp/RcppArmadillo_0.9.200.7.1.tar.gz and it resolved for me from work:
~/tmp$ wget http://dirk.eddelbuettel.com/tmp/RcppArmadillo_0.9.200.7.1.tar.gz
--2019-03-07 14:29:36--  http://dirk.eddelbuettel.com/tmp/RcppArmadillo_0.9.200.7.1.tar.gz
Resolving dirk.eddelbuettel.com (dirk.eddelbuettel.com)... 73.120.251.142
Connecting to dirk.eddelbuettel.com (dirk.eddelbuettel.com)|73.120.251.142|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1253060 (1.2M) [application/x-gzip]
Saving to: ‘RcppArmadillo_0.9.200.7.1.tar.gz’

100%[======================================================>] 1,253,060   --.-K/s   in 0.01s   

2019-03-07 14:29:38 (110 MB/s) - ‘RcppArmadillo_0.9.200.7.1.tar.gz’ saved [1253060/1253060]

~/tmp$ 

Can you please try again with one of the redundant methods I provided, or kindly explain in more than ten words what you tried and what failed?

Author

donaldRwilliams commented Mar 7, 2019

I will give it a try, and report back to you.


stephensrmmartin commented Mar 7, 2019

We used the bugfix/sample_test branch and confirmed it works. Thank you both for the update(s).

Member

eddelbuettel commented Mar 7, 2019

Super, thanks so much for confirming. Should be on CRAN tomorrow, or, at the latest, by the weekend.

Author

donaldRwilliams commented Mar 7, 2019

Great! Thank you for the excellent support.

Member

eddelbuettel commented Mar 8, 2019

On CRAN now. Simulate away :)

Thanks for the very helpful and detailed bug report, and of course to @conradsnicta for the quick fix.
