Skip to content

Commit

Permalink
Implemented version check in the MATLAB and R wrappers; fixed bug in …
Browse files Browse the repository at this point in the history
…MATLAB wrapper that was preventing it from running, introduced in #63
  • Loading branch information
linqiaozhi committed Feb 8, 2019
1 parent 01e181e commit 0f7e5c5
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
3 changes: 2 additions & 1 deletion fast_tsne.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ fftRtsne <- function(X,
load_affinities=NULL,
fast_tsne_path=NULL, nthreads=0, perplexity_list = NULL,
get_costs = FALSE, df = 1.0,... ) {
version_number = '1.1.0'

if (is.null(fast_tsne_path)) {
if(.Platform$OS.type == "unix") {
Expand Down Expand Up @@ -131,7 +132,7 @@ fftRtsne <- function(X,
print(df)
close(f)

flag= system2(command=fast_tsne_path, args=c(data_path, result_path, nthreads));
flag= system2(command=fast_tsne_path, args=c(version_number,data_path, result_path, nthreads));
if (flag != 0) {
stop('tsne call failed');
}
Expand Down
17 changes: 13 additions & 4 deletions fast_tsne.m
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@
% CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
% IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
% OF SUCH DAMAGE.

version_number = '1.1.0';
if (nargin == 1)
opts.perplexity = 30;
end
Expand Down Expand Up @@ -275,6 +277,12 @@
else
nthreads = opts.nthreads;
end

if (~isfield(opts, 'df'))
df = 1;
else
df = opts.df;
end

X = double(X);

Expand All @@ -283,7 +291,7 @@

% Compile t-SNE C code
if(~exist(fullfile(tsne_path,'./fast_tsne'),'file') && isunix)
system(sprintf('g++ -std=c++11 -O3 src/sptree.cpp src/tsne.cpp src/nbodyfft.cpp -o bin/fast_tsne -pthread -lfftw3 -lm')); end
system(sprintf('g++ -std=c++11 -O3 src/sptree.cpp src/tsne.cpp src/nbodyfft.cpp -o bin/fast_tsne -pthread -lfftw3 -lm'));
end

% Compile t-SNE C code on Windows
Expand All @@ -296,12 +304,12 @@
stop_lying_iter, K, sigma, nbody_algo, no_momentum_during_exag, knn_algo,...
early_exag_coeff, n_trees, search_k, start_late_exag_iter, late_exag_coeff, rand_seed,...
nterms, intervals_per_integer, min_num_intervals, initialization, load_affinities, ...
perplexity_list, mom_switch_iter, momentum, final_momentum, learning_rate);
perplexity_list, mom_switch_iter, momentum, final_momentum, learning_rate,df);

disp('Data written');
tic
%[flag, cmdout] = system(fullfile(tsne_path,'/fast_tsne'), '-echo');
cmd = sprintf('%s data.dat result.dat %d',fullfile(tsne_path,'/fast_tsne'), nthreads);
cmd = sprintf('%s %s data.dat result.dat %d',fullfile(tsne_path,'/fast_tsne'), version_number, nthreads);
[flag, cmdout] = system(cmd, '-echo');
if(flag~=0)
error(cmdout);
Expand All @@ -318,7 +326,7 @@ function write_data(filename, X, no_dims, theta, perplexity, max_iter,...
stop_lying_iter, K, sigma, nbody_algo, no_momentum_during_exag, knn_algo,...
early_exag_coeff, n_trees, search_k, start_late_exag_iter, late_exag_coeff, rand_seed,...
nterms, intervals_per_integer, min_num_intervals, initialization, load_affinities, ...
perplexity_list, mom_switch_iter, momentum, final_momentum, learning_rate)
perplexity_list, mom_switch_iter, momentum, final_momentum, learning_rate,df)

[n, d] = size(X);

Expand Down Expand Up @@ -353,6 +361,7 @@ function write_data(filename, X, no_dims, theta, perplexity, max_iter,...
fwrite(h, min_num_intervals, 'int');
fwrite(h, X', 'double');
fwrite(h, rand_seed, 'integer*4');
fwrite(h, df, 'double');
fwrite(h, load_affinities, 'integer*4');
if ~isnan(initialization)
fwrite(h, initialization', 'double');
Expand Down
25 changes: 18 additions & 7 deletions src/tsne.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1871,13 +1871,14 @@ void TSNE::save_data(const char *result_path, double* data, double* costs, int n


int main(int argc, char *argv[]) {
printf("=============== t-SNE v1.0.1 ===============\n");
const char version_number[] = "1.1.0";
printf("=============== t-SNE v%s ===============\n", version_number);

// Define some variables
int N, D, no_dims, max_iter, stop_lying_iter;
int K, nbody_algo, knn_algo, no_momentum_during_exag;
int mom_switch_iter;
double momentum, final_momentum, learning_rate;
int mom_switch_iter;
double momentum, final_momentum, learning_rate;
int n_trees, search_k, start_late_exag_iter;
double sigma, early_exag_coeff, late_exag_coeff;
double perplexity, theta, *data, *initial_data;
Expand All @@ -1899,13 +1900,23 @@ int main(int argc, char *argv[]) {
data_path = "data.dat";
result_path = "result.dat";
nthreads = 0;
if(argc >= 2) {
data_path = argv[1];
}
if (argc >=2 ) {
if ( strcmp(argv[1],version_number)) {
std::cout<<"Wrapper passed wrong version number: "<< argv[1] <<std::endl;
exit(-1);
}
}else{
std::cout<<"Please pass version number as first argument." <<std::endl;
exit(-1);

}
if(argc >= 3) {
result_path = argv[2];
data_path = argv[2];
}
if(argc >= 4) {
result_path = argv[3];
}
if(argc >= 5) {
nthreads = (unsigned int)strtoul(argv[3], (char **)NULL, 10);
}
if (nthreads == 0) {
Expand Down

0 comments on commit 0f7e5c5

Please sign in to comment.