-
Notifications
You must be signed in to change notification settings - Fork 38
/
bads.m
1528 lines (1309 loc) · 65.6 KB
/
bads.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
function [x,fval,exitflag,output,optimState,gpstruct] = bads(fun,x0,LB,UB,PLB,PUB,nonbcon,options,varargin)
%BADS Constrained optimization using Bayesian Adaptive Direct Search (v1.1.2)
% BADS attempts to solve problems of the form:
% min F(X) subject to: LB <= X <= UB
% X C(X) <= 0 (optional)
%
% X = BADS(FUN,X0,LB,UB) starts at X0 and finds a local minimum X to the
% target function FUN. FUN accepts input X and returns a scalar function
% value evaluated at X. X0 may be a scalar, vector or empty matrix. If X0
% is empty, the starting point is chosen from the initial mesh only.
% LB and UB define a set of lower and upper bounds on the design
% variables, X, so that a solution is found in the range LB <= X <= UB.
% LB and UB can be scalars or vectors. If scalars, the bound is
% replicated in each dimension. Use empty matrices for LB and UB if no
% bounds exist. Set LB(i) = -Inf if X(i) is unbounded below; set
% UB(i) = Inf if X(i) is unbounded above. Note that:
% - unbounded variables are currently supported but deprecated,
% support may be removed in future version of BADS;
% - if LB and/or UB contain unbounded variables, the respective
% values of PLB and/or PUB need to be specified (see below);
% - if X0 is empty, LB and UB need to be specified as vectors.
%
% X = BADS(FUN,X0,LB,UB,PLB,PUB) specifies a set of plausible lower and
% upper bounds such that LB <= PLB < PUB <= UB. Both PLB and PUB
% need to be finite. PLB and PUB are used to design the initial mesh of
% the direct search, and represent a plausible range for the
% optimization variables. As a rule of thumb, set PLB and PUB such that
% there is > 90% probability that the minimum is found within the box
% (where in doubt, just set PLB=LB and PUB=UB).
% As an exception to the strict bound ordering provided above, BADS accepts
% 'fixed' variables such that X0(i),LB(i),UB(i),PLB(i),PUB(i) are all
% equal, in which case the fixed variables take constant values, and
% BADS runs on a problem with reduced dimensionality.
%
% X = BADS(FUN,X0,LB,UB,PLB,PUB,NONBCON) subjects the minimization to the
% non-bound constraints defined in NONBCON. The function NONBCON accepts
% a N-by-D matrix XI where N is any number of points to evaluate and D is
% the number of dimensions, and returns a N-by-1 vector C, representing
% the degree of violation of non-bound inequalities for each point in XI.
% BADS minimizes FUN such that C(X)<=0.
%
% X = BADS(FUN,X0,LB,UB,PLB,PUB,NONBCON,OPTIONS) minimizes with the default
% optimization parameters replaced by values in the structure OPTIONS.
% BADS('defaults') returns the default OPTIONS struct.
%
% [X,FVAL] = BADS(...) returns FVAL, the value of the objective function
% FUN at the solution X. If the target function is stochastic, FVAL is
% the expected mean of the function value at X.
%
% [X,FVAL,EXITFLAG] = BADS(...) returns EXITFLAG which describes the exit
% condition of BADS. Possible values of EXITFLAG and the corresponding
% exit conditions are:
%
% 0 Maximum number of function evaluations or iterations reached.
% 1 Magnitude of mesh size is less than specified tolerance.
% 2 Change in estimated function value less than the specified tolerance
% for OPTIONS.TolStallIters iterations.
%
% [X,FVAL,EXITFLAG,OUTPUT] = BADS(...) returns a structure OUTPUT with the
% following information:
% function: <Objective function name>
% problemtype: <Type of problem> (unconstrained or bound constrained)
% targettype: <Type of target function> (deterministic or stochastic)
% iterations: <Total iterations>
% funccount: <Total function evaluations>
% meshsize: <Mesh size at X>
% overhead: <Fractional overhead (total runtime / total fcn time - 1)>
% rngstate: <Status of random number generator>
% algorithm: <Bayesian adaptive direct search>
% message: <BADS termination message>
% fval: <Expected mean of function value at X>
% fsd: <Estimated standard deviation of function value at X>
%
% [X,FVAL,EXITFLAG,OUTPUT,OPTIMSTATE] = BADS(...) returns a detailed
% optimization structure OPTIMSTATE, mostly for debugging purposes.
%
% [X,FVAL,EXITFLAG,OUTPUT,OPTIMSTATE,GPSTRUCT] = BADS(...) returns the
% Gaussian process (GP) structure GPSTRUCT.
%
% OPTIONS = BADS('defaults') returns a basic default OPTIONS structure.
%
% EXITFLAG = BADS('test') runs a battery of tests. Here EXITFLAG is 0 if
% everything works correctly.
%
% Examples:
% FUN can be a function handle (using @)
% X = bads(@rosenbrocks, ...)
% In this case, F = rosenbrocks(X) returns the scalar function value F of
% the function evaluated at X.
%
% An example with no hard bounds, only plausible bounds
% plb = [-2 -2]; pub = [2 2];
% [X,FVAL,EXITFLAG] = bads(@rosenbrocks,[0 0],[],[],plb,pub);
%
% FUN can also be an anonymous function:
% X = bads(@(x) 3*sin(x(1))+exp(x(2)),[1 1],[0 0],[pi/2 5])
% returns X = [0 0].
%
% To run BADS on a noisy (stochastic) objective function, set
% OPTIONS.UncertaintyHandling = true
% OPTIONS.NoiseSize = SIGMA
% where SIGMA is an estimate of the SD of the noise in your problem in
% a good region of the parameter space. (If not specified, default
% SIGMA = 1). To help BADS work better, it is recommended that you
% provide to BADS an estimate of the noise at each location (see below).
%
% Set OPTIONS.UncertaintyHandling = false for a deterministic function.
% If OPTIONS.UncertaintyHandling is not specified, BADS will determine at
% runtime if the objective function is noisy.
%
% If you can estimate the SD of the noise at each input location, return
% the estimate as the *second output* of FUN. That is [FVAL,SD] = FUN(X)
% where FVAL is the observed (noisy) function value at X, and SD is the
% estimated standard deviation of the observation at X. Then set
% OPTIONS.SpecifyTargetNoise = true
%
% See BADS_EXAMPLES for more examples. The most recent version of the
% algorithm and additional documentation can be found here:
% https://github.com/acerbilab/bads
% Also, check out the FAQ: https://github.com/acerbilab/bads/wiki
%
% Reference: Acerbi, L. & Ma, W. J. (2017). "Practical Bayesian
% Optimization for Model Fitting with Bayesian Adaptive Direct Search".
% In Advances in Neural Information Processing Systems 30, pages 1834-1844.
% (arXiv preprint: https://arxiv.org/abs/1705.04405).
%
% See also BADS_EXAMPLES, @.
%--------------------------------------------------------------------------
% BADS: Bayesian Adaptive Direct Search for nonlinear function minimization.
% To be used under the terms of the GNU General Public License
% (http://www.gnu.org/copyleft/gpl.html).
%
% Author (copyright): Luigi Acerbi, 2017-2022
% e-mail: luigi.acerbi@helsinki.fi
% URL: http://luigiacerbi.com
% Version: 1.1.2
% Release date: Nov 14, 2022
% Code repository: https://github.com/acerbilab/bads
%--------------------------------------------------------------------------
%% Start timer
t0 = tic;
bads_version = '1.1.2';
%% Basic default options
defopts.Display = 'iter % Level of display ("iter", "notify", "final", or "off")';
defopts.MaxIter = '200*nvars % Max number of iterations';
defopts.MaxFunEvals = '500*nvars % Max number of objective fcn evaluations';
defopts.PeriodicVars = '[] % Array with indices of periodic variables';
defopts.NonlinearScaling = 'on % Automatic nonlinear rescaling of variables';
defopts.CompletePoll = 'off % Complete polling around the current iterate';
defopts.AccelerateMesh = 'on % Accelerate mesh contraction';
defopts.OutputFcn = '[] % Output function';
defopts.UncertaintyHandling = '[] % Explicit noise handling (if empty, determine at runtime)';
defopts.NoiseSize = '[] % Base observation noise magnitude (SD)';
defopts.SpecifyTargetNoise = 'no % Target function returns noise estimate (SD) as second output';
defopts.NoiseFinalSamples = '10 % Samples to estimate FVAL at the end (for noisy objectives)';
defopts.OptimToolbox = '[] % Use Optimization Toolbox (if empty, determine at runtime)';
%% If called with no arguments or with 'defaults', return default options
if nargout <= 1 && (nargin == 0 || (nargin == 1 && ischar(fun) && strcmpi(fun,'defaults')))
if nargin < 1
fprintf('Basic default options returned (type "help bads" for help).\n');
end
x = defopts;
return;
end
%% If called with one argument which is 'test', run test
if nargout <= 1 && nargin == 1 && ischar(fun) && strcmpi(fun,'test')
x = runtest();
return;
end
%% If called with one argument which is 'version', return version
if nargout <= 1 && nargin == 1 && ischar(fun) && strcmpi(fun,'version')
x = bads_version;
return;
end
%% Advanced options (do not modify unless you *know* what you are doing)
% Running mode
defopts.Plot = 'off % Show optimization plots ("profile", "scatter", or "off")';
defopts.Debug = 'off % Debug mode, plot additional info';
defopts.TrueMinX = '[] % Location of the global minimum (for debug only)';
% Tolerance and termination conditions
defopts.TolMesh = '1e-6 % Tolerance on mesh size';
defopts.TolFun = '1e-3 % Min significant change of objective fcn';
defopts.TolStallIters = '4 + floor(nvars/2) % Max iterations with no significant change (doubled under uncertainty)';
defopts.TolNoise = 'sqrt(eps)*options.TolFun % Min variability for a fcn to be considered noisy';
% Initialization
defopts.Ninit = 'nvars % Number of initial objective fcn evaluations';
defopts.InitFcn = '@initSobol % Initialization function';
% defoptions.InitFcn = '@initLHS';
defopts.Restarts = '0 % Number of restart attempts';
defopts.CacheSize = '1e4 % Size of cache for storing function evaluations';
defopts.FunValues = '[] % Struct with pregress fcn evaluations (X and Y fields)';
% Poll options
defopts.PollMethod = '@pollMADS2N % Poll function';
defopts.Nbasis = '200*nvars';
defopts.PollMeshMultiplier = '2 % Mesh multiplicative factor between iterations';
defopts.ForcePollMesh = 'no % Force poll vectors to be on mesh';
defopts.AlternativeIncumbent = 'off % Use alternative incumbent offset';
defopts.AdaptiveIncumbentShift = 'off % Adaptive multiplier to incumbent uncertainty';
defopts.gpRescalePoll = '1 % GP-based geometric scaling factor of poll vectors';
defopts.TolPoI = '1e-6/nvars % Threshold probability of improvement (PoI); set to 0 to always complete polling';
defopts.SkipPoll = 'yes % Skip polling if PoI below threshold, even with no success';
defopts.ConsecutiveSkipping = 'yes % Allow consecutive incomplete polls';
defopts.SkipPollAfterSearch = 'yes % Skip polling after successful search';
defopts.MinFailedPollSteps = 'Inf % Number of failed fcn evaluations before skipping is allowed';
defopts.AccelerateMeshSteps = '3 % Accelerate mesh after this number of stalled iterations';
defopts.SloppyImprovement = 'yes % Move incumbent even after insufficient improvement';
defopts.MeshOverflowsWarning = '2 + nvars/2 % Threshold # mesh overflows for warning';
% Improvement parameters
defopts.TolImprovement = '1 % Minimum significant improvement at unit mesh size';
defopts.ForcingExponent = '3/2 % Exponent of forcing function';
defopts.IncumbentSigmaMultiplier = '0.1 % Multiplier to incumbent uncertainty for acquisition functions';
defopts.ImprovementQuantile = '0.5 % Quantile when computing improvement (<0.5 for conservative improvement)';
defopts.FinalQuantile = '1e-3 % Top quantile when choosing final iteration';
% Search properties
defopts.Nsearch = '2^12 % Number of candidate search points';
defopts.Nsearchiter = '2 % Number of optimization iterations for search';
defopts.ESbeta = '1 % Multiplier in ES';
defopts.ESstart = '0.25 % Starting scale value in ES';
defopts.SearchImproveFrac = '0 % Fraction of candidate search points with (slower) improved estimate';
defopts.SearchScaleSuccess = 'sqrt(2) % Search radius expansion factor for successful search';
defopts.SearchScaleIncremental = '2 % Search radius expansion factor for incremental search';
defopts.SearchScaleFailure = 'sqrt(0.5) % Search radius contraction factor for failed search';
defopts.SearchFactorMin = '0.5';
defopts.SearchMethod = '{@searchHedge,{{@searchES,1,1},{@searchES,2,1}}} % Search function(s)';
defopts.SearchGridNumber = '10 % iteration scale factor between poll and search';
defopts.MaxPollGridNumber = '0 % Maximum poll integer';
defopts.SearchGridMultiplier = '2 % multiplier integer scale factor between poll and search';
defopts.SearchSizeLocked = 'on % Relative search scale factor locked to poll scale factor';
% defopts.SearchNtry = 'max(2*nvars,5+nvars) % Number of searches per iteration';
defopts.SearchNtry = 'max(nvars,floor(3+nvars/2)) % Number of searches per iteration';
defopts.SearchMeshExpand = '0 % Search-triggered mesh expansion after this number of successful search rounds';
defopts.SearchMeshIncrement = '1 % Mesh size increment after search-triggered mesh expansion';
defopts.SearchOptimize = 'no % Further optimize acquisition function';
% Gaussian process properties
defopts.Ndata = '50 + 10*nvars % Number of training data (minimum 200 under uncertainty)';
defopts.MinNdata = '50 % Minimum number of training data (doubled under uncertainty)';
defopts.BufferNdata = '100 % Max number of training data removed if too far from current point';
defopts.gpSamples = '0 % Hyperparameters samples (0 = optimize)';
defopts.MinRefitTime = '2*nvars % Minimum fcn evals before refitting the GP';
defopts.PollTraining = 'yes % Train GP also during poll stage';
defopts.DoubleRefit = 'off % Always try a second GP fit';
defopts.gpMeanPercentile = '90 % Percentile of empirical GP mean';
defopts.gpMeanRangeFun = '@(ym,y) (ym - median(y))/5*2 % Empirical range of hyperprior over the mean';
defopts.gpdefFcn = '{@gpdefBads,''rq'',[1,1]} % GP definition fcn';
defopts.gpMethod = 'nearest % GP training set selection method';
defopts.gpCluster = 'no % Cluster additional points during training';
defopts.RotateGP = 'no % Rotate GP basis';
defopts.gpRadius = '3 % Radius of training set';
defopts.UseEffectiveRadius = 'yes %';
defopts.gpCovPrior = 'iso % GP hyper-prior over covariance';
defopts.gpFixedMean = 'no';
defopts.FitLik = 'yes % Fit the likelihood term';
defopts.PollAcqFcn = '{@acqLCB,[]} % Acquisition fcn for poll stage';
defopts.SearchAcqFcn = '{@acqLCB,[]} % Acquisition fcn for search stage';
defopts.AcqHedge = 'off % Hedge acquisition function';
defopts.CholAttempts = '0 % Attempts at performing the Cholesky decomposition';
defopts.NoiseNudge = '[1 0] % Increase nudge to noise in case of Cholesky failure';
defopts.RemovePointsAfterTries = '1 % Start removing training points after this number of failures';
defopts.gpSVGDiters = '200 % SVGD iterations for GP training';
defopts.gpWarnings = 'off % Issue warning if GP hyperparameters fit fails';
defopts.NormAlphaLevel = '1e-6 % Alpha level for normality test of gp predictions';
% GP warping parameters (unsupported)
defopts.FitnessShaping = 'off % Nonlinear rescaling of objective fcn';
defopts.WarpFunc = '0 % GP warping function type';
% Noise parameters
defopts.UncertainIncumbent = 'yes % Treat incumbent as if uncertain regardless of uncertainty handling';
defopts.MeshNoiseMultiplier = '0.5 % Contribution to log noise magnitude from log mesh size (0 for noisy functions)';
% Hedge heuristic parameters (currently used during the search stage)
defopts.HedgeGamma = '0.125';
defopts.HedgeBeta = '1e-3/options.TolFun';
defopts.HedgeDecay = '0.1^(1/(2*nvars))';
%% If called with 'all', return all default options
if strcmpi(fun,'all')
x = defopts;
return;
end
%% Check that all BADS subfolders are on the MATLAB path
add2path();
%% Input arguments
if nargin < 3 || isempty(LB); LB = -Inf; end
if nargin < 4 || isempty(UB); UB = Inf; end
if nargin < 5; PLB = []; end
if nargin < 6; PUB = []; end
if nargin < 7; nonbcon = []; end
if nargin < 8; options = []; end
%% Initialize display printing options
if ~isfield(options,'Display') || isempty(options.Display)
options.Display = defopts.Display;
end
switch lower(options.Display(1:3))
case {'not','notify','notify-detailed'}
prnt = 1;
case {'non','none','off'}
prnt = 0;
case {'ite','all','iter','iter-detailed'}
prnt = 3;
case {'fin','final','final-detailed'}
prnt = 2;
otherwise
prnt = 1;
end
%% Initialize variables and algorithm structures
if isempty(x0)
if prnt > 2
fprintf('X0 not specified. Taking the number of dimensions from PLB and PUB...');
end
if isempty(PLB) || isempty(PUB)
error('If no starting point is provided, PLB and PUB need to be specified.');
end
x0 = NaN(size(PLB));
if prnt > 2
fprintf(' NVARS = %d.\n', numel(x0));
end
end
nvars = numel(x0);
optimState = [];
% Check boundaries and if there are fixed variables
[LB,UB,PLB,PUB,fixidx] = boundscheck(x0,LB,UB,PLB,PUB);
% If there are fixed variables, rerun BADS with lowered dimensionality
if any(fixidx)
fixedvars = LB(fixidx);
if isempty(varargin)
fun_fix = @(x) fun(expandvars(x,fixidx,fixedvars));
else
fun_fix = @(x) fun(expandvars(x,fixidx,fixedvars),varargin{:});
end
if ~isempty(nonbcon)
nonbcon_fix = @(x) nonbcon(expandvars(x,fixidx,fixedvars));
else
nonbcon_fix = [];
end
if isfield(options,'OutputFcn') && ~isempty(options.OutputFcn) && ...
(isa(options.OutputFcn,'function_handle') || ~isempty(eval(options.OutputFcn)))
if ischar(options.OutputFcn)
outputfun_tmp = eval(options.OutputFcn);
elseif isa(options.OutputFcn,'function_handle')
outputfun_tmp = options.OutputFcn;
else
error('OPTIONS.OutputFcn should be a function handle to an output function.');
end
outputfun_fix = @(x,optimState,state) outputfun_tmp(expandvars(x,fixidx,fixedvars),optimState,state);
else
outputfun_fix = [];
end
options.OutputFcn = outputfun_fix; % Assign Output function
% Run of BADS with lowered dimensionality
[x,fval,exitflag,output,optimState,gpstruct] = fixedbads(fun_fix,x0,LB,UB,PLB,PUB,nonbcon_fix,options,fixidx,nargout);
return;
end
% Convert from char to function handles
if ischar(fun); fun = str2func(fun); end
if ischar(nonbcon); nonbcon = str2func(nonbcon); end
% Setup algorithm options
options = setupoptions(nvars,defopts,options);
% Output function
outputfun = options.OutputFcn;
% Setup and transform variables
[u0,LB,UB,PLB,PUB,MeshSizeInteger,optimState] = ...
setupvars(x0,LB,UB,PLB,PUB,optimState,nonbcon,options,prnt);
optimState = updateSearchBounds(optimState);
% Store objective function
optimState.fun = fun;
if isempty(varargin)
funwrapper = fun; % No additional function arguments passed
else
funwrapper = @(u_) fun(u_,varargin{:});
end
% Store constraints function
optimState.nonbcon = nonbcon;
% Initialize function logger
[~,optimState] = funlogger([],u0,optimState,'init',options.CacheSize,options.SpecifyTargetNoise);
%% Initial function evaluations
iter = 0;
optimState.iter = iter;
% Evaluate starting point and initial mesh, determine if function is noisy
[u,yval,fval,isFinished_flag,optimState,displayFormat] = ...
evalinitmesh(u0,funwrapper,optimState,options,prnt);
if ~isfinite(fval); error('Cannot find valid starting point.'); end
exitflag = 0;
msg = 'Optimization terminated: reached maximum number of function evaluations after initialization.';
if ~isempty(outputfun)
isFinished_flag = outputfun(origunits(u,optimState),optimState,'init');
end
% Change options for uncertainty handling
if optimState.UncertaintyHandling
options.TolStallIters = 2*options.TolStallIters;
options.Ndata = max(200,options.Ndata);
options.MinNdata = 2*options.MinNdata;
options.MeshOverflowsWarning = 2*options.MeshOverflowsWarning;
%options.gpMeanPercentile = 50;
options.MinFailedPollSteps = Inf;
options.MeshNoiseMultiplier = 0;
if isempty(options.NoiseSize); options.NoiseSize = 1; end
% Keep some function evaluations for the final resampling
options.NoiseFinalSamples = min(options.NoiseFinalSamples, options.MaxFunEvals - optimState.funccount);
options.MaxFunEvals = options.MaxFunEvals - options.NoiseFinalSamples;
else
if isempty(options.NoiseSize); options.NoiseSize = sqrt(options.TolFun); end
end
if optimState.UncertaintyHandling % Current uncertainty in estimate
if options.SpecifyTargetNoise
[~,idx] = min(optimState.Y); % Assume the min has been picked
fsd = optimState.S(idx);
else
fsd = options.NoiseSize(1);
end
else
fsd = 0;
end
optimState.fsd = fsd;
ubest = u; % Current best minimum location
optimState.usuccess = ubest; % Store sequence of successful x, y, and f values
optimState.ysuccess = yval;
optimState.fsuccess = fval;
optimState.u = u;
% Initialize Gaussian Process (GP) structure
if options.FitLik; gplik = []; else gplik = log(options.TolFun); end
gpstruct = feval(options.gpdefFcn{:},nvars,gplik,optimState,options,[]);
gpstruct.fun = funwrapper;
fhyp = gpstruct.hyp;
% Initialize struct with GP prediction statistics
Nsamples = max(1,options.gpSamples);
optimState.gpstats = savegpstats([],[],[],[],ones(1,Nsamples)/Nsamples);
lastskipped = 0; % Last skipped iteration
SearchSuccesses = 0;
SearchSpree = 0;
Restarts = options.Restarts;
%% Optimization loop
iter = 1;
while ~isFinished_flag
optimState.iter = iter;
refitted_flag = false; % GP refitted this iteration
gpexitflag = Inf; % Exit flag from GP training
action = []; % Action performed this iteration (for printing purposes)
% Compute mesh size and search mesh size
MeshSize = options.PollMeshMultiplier^MeshSizeInteger;
if options.SearchSizeLocked
optimState.SearchSizeInteger = min(0,MeshSizeInteger*options.SearchGridMultiplier - options.SearchGridNumber);
end
optimState.meshsize = MeshSize;
optimState.searchmeshsize = options.PollMeshMultiplier.^optimState.SearchSizeInteger;
% Update bounds to grid for search mesh
optimState = updateSearchBounds(optimState);
% Minimum improvement for a poll/search to be considered successful
SufficientImprovement = options.TolImprovement*(MeshSize^options.ForcingExponent);
if options.SloppyImprovement
SufficientImprovement = max(SufficientImprovement, options.TolFun);
end
optimState.SearchSufficientImprovement = SufficientImprovement;
% Multiple successful searches raise the bar for improvement
% optimState.SearchSufficientImprovement = SufficientImprovement*(2^SearchSuccesses);
%----------------------------------------------------------------------
%% Search stage
% Perform search if there are still available attempts, and if there
% are more than NVARS stored points
DoSearchStep_flag = optimState.searchcount < options.SearchNtry ...
&& size(gpstruct.y,1) > nvars;
if DoSearchStep_flag
% Check whether it is time to refit the GP
[refitgp_flag,~,optimState] = IsRefitTime(optimState,options);
if refitgp_flag || optimState.searchcount == 0; gpstruct.post = []; end
if isempty(gpstruct.post)
% Local GP approximation on current point
[gpstruct,gptempflag] = gpupdate(gpstruct, ...
options.gpMethod, ...
u, ...
[], ... % [upoll; gridunits(x,optimState)], ...
optimState, ...
options, ...
refitgp_flag);
if refitgp_flag; refitted_flag = true; end
gpexitflag = min(gptempflag,gpexitflag);
end
% Update optimization target (based on GP prediction at incumbent)
optimState = UpdateTarget(ubest,fhyp,optimState,gpstruct,options);
% Generate search set (normalized coordinates)
optimState.searchcount = optimState.searchcount + 1;
[usearchset,optimState] = feval(options.SearchMethod{:}, ...
u, ...
gpstruct, ...
LB, ...
UB, ...
optimState, ...
options);
% Enforce periodicity
usearchset = periodCheck(usearchset,LB,UB,optimState);
% Force candidate points on search grid
usearchset = force2grid(usearchset, optimState);
% Remove already evaluated or unfeasible points from search set
usearchset = uCheck(usearchset,optimState.TolMesh,optimState,1);
if ~isempty(usearchset) % Non-empty search set
ns = size(usearchset, 1);
ymu = zeros(numel(gpstruct.hyp),ns);
ys = zeros(numel(gpstruct.hyp),ns);
% Evaluate acquisition function on search set
try
%----------------------------------------------------------
if options.AcqHedge
[optimState.hedge,acqIndex,ymu,ys] = ...
acqPortfolio('acq',optimState.hedge,usearchset,optimState.ftarget,0,gpstruct,optimState,options,SufficientImprovement);
index = acqIndex(optimState.hedge.chosen);
z = 1;
%----------------------------------------------------------
else
% Batch evaluation of acquisition function on search set
[z,~,ymu,ys] = ...
feval(options.SearchAcqFcn{:},usearchset,optimState.ftarget,gpstruct,optimState,0);
% Evaluate best candidate point in original coordinates
[~,index] = min(z);
end
catch
% Failed evaluation of the acquisition function
index = []; z = [];
end
% Randomly choose index if something went wrong
if isempty(index) || ~isfinite(index); index = randi(size(usearchset,1)); end
acqu = [];
%--------------------------------------------------------------
% Local optimization of the acquisition function
% (generally it does not improve results)
if options.SearchOptimize
acqoptoptions = optimset('Display','off','GradObj','off','DerivativeCheck','off',...
'TolX',optimState.TolMesh,'TolFun',options.TolFun);
try
acqu = usearchset(index,:);
acqoptoptions.MaxFunEval = options.Nsearch;
acqoptf = @(u_) feval(options.SearchAcqFcn{:},u_,optimState.ftarget,gpstruct,optimState,0);
% Limit search within search box
% NEEDS TO BE ADJUSTED FOR PERIODIC VARIABLES
acqlb = max(LB, min(acqu,u - optimState.searchfactor*MeshSize));
acqub = min(UB, max(acqu,u + optimState.searchfactor*MeshSize));
[acqu,facq,~,output] = fmincon(acqoptf, ...
acqu,[],[],[],[],acqlb,acqub,[],acqoptoptions);
acqu = force2grid(acqu, optimState);
acqu = periodCheck(acqu,LB,UB,optimState,1);
acqu = uCheck(acqu,optimState.TolMesh,optimState,1);
catch
acqu = [];
end
end
%--------------------------------------------------------------
if isempty(acqu); acqu = usearchset(index,:); end
usearch = acqu;
% Evaluate function on search point
[ysearch,optimState,ysearch_sd] = funlogger(funwrapper,usearch,optimState,'iter');
if ~isempty(z)
% Save statistics of gp prediction
optimState.gpstats = ...
savegpstats(optimState.gpstats,ysearch,ymu(:,index),ys(:,index),gpstruct.hypweight);
end
% Add search point to training set
if ~isempty(usearch) && optimState.searchcount < options.SearchNtry
gpstruct = gpupdate(gpstruct, ...
'add', ...
usearch, ...
[ysearch,ysearch_sd], ...
optimState, ...
options, ...
0);
end
if optimState.UncertaintyHandling
gpstructnew = gpupdate(gpstruct, ...
options.gpMethod, ...
usearch, ...
[], ...
optimState, ...
options, ...
0);
% Compute estimated function value at point
[~,~,fsearch,fs2] = gppred(usearch,gpstructnew);
if numel(gpstructnew.hyp) > 1
fsearch = weightedsum(gpstructnew.hypweight,fsearch,1);
fs2 = weightedsum(gpstructnew.hypweight,fs2,1);
end
fsearchsd = sqrt(fs2);
else
fsearch = ysearch;
fsearchsd = 0;
end
% Compute distance of search point from current point
searchdist = sqrt(udist(ubest,usearch,gpstruct.lenscale,optimState));
else % Empty search set
ysearch = yval;
fsearch = fval;
fsearchsd = 0;
searchdist = 0;
end
% Evaluate search
SearchImprovement = EvalImprovement(fval,fsearch,fsd,fsearchsd,options.ImprovementQuantile);
fvalold = fval;
% Declare if search was success or failure
if (SearchImprovement > 0 && options.SloppyImprovement) ...
|| SearchImprovement > optimState.SearchSufficientImprovement
% Search did not fail
%--------------------------------------------------------------
if options.AcqHedge
method = optimState.hedge.str{optimState.hedge.chosen};
else
%--------------------------------------------------------------
method = feval(options.SearchMethod{:},[],[],[],[],optimState);
end
if SearchImprovement > optimState.SearchSufficientImprovement
SearchSuccesses = SearchSuccesses + 1;
searchstring = ['Successful search (' method ')'];
optimState.usuccess = [optimState.usuccess; usearch];
optimState.ysuccess = [optimState.ysuccess; ysearch];
optimState.fsuccess = [optimState.fsuccess; fsearch];
searchstatus = 'success';
else
searchstring = ['Incremental search (' method ')'];
% searchstring = [];
searchstatus = 'incremental';
end
% Update incumbent point
[ubest,yval,fval,fsd,optimState,gpstruct] = UpdateIncumbent(ubest,yval,fval,fsd,usearch,ysearch,fsearch,fsearchsd,optimState,gpstruct,options);
if optimState.UncertaintyHandling; gpstruct = gpstructnew; end
gpstruct.post = []; % Reset posterior
else
% Search failed
searchstatus = 'failure';
searchstring = [];
end
%------------------------------------------------------------------
% Update portfolio acquisition function
% (hedging over acquisition functions: reward the chosen strategy
% based on the outcome at the selected search point)
if options.AcqHedge && ~isempty(usearchset)
optimState.hedge = ...
acqPortfolio('update',optimState.hedge,usearchset(acqIndex,:),fsearch,fsearchsd,gpstruct,optimState,options,SufficientImprovement,fvalold,MeshSize);
end
% Update search portfolio (needs improvement)
% (same update path when hedging over search methods instead of
% acquisition functions)
if ~options.AcqHedge && ~isempty(optimState.hedge) && ~isempty(usearch)
optimState.hedge = ...
acqPortfolio('update',optimState.hedge,usearch,fsearch,fsearchsd,gpstruct,optimState,options,SufficientImprovement,fvalold,MeshSize);
end
%------------------------------------------------------------------
% Update search statistics and search scale factor
optimState = UpdateSearch(optimState,searchstatus,searchdist,options);
% Print search results
% (fsd column is shown only for noisy objectives)
if prnt > 2 && ~isempty(searchstring)
if optimState.UncertaintyHandling
fprintf(displayFormat,iter,optimState.funccount,fval,fsd,MeshSize,searchstring,'');
else
fprintf(displayFormat,iter,optimState.funccount,fval,MeshSize,searchstring,'');
end
end
end % Search stage
% Decide whether to perform the poll stage
switch optimState.searchcount
case {0, options.SearchNtry} % Skipped or just finished search
optimState.searchcount = 0;
if SearchSuccesses > 0 && options.SkipPollAfterSearch
% Searches succeeded: skip polling this iteration and
% track the run of consecutive search-only iterations
DoPollStep_flag = false;
SearchSpree = SearchSpree + 1;
% Periodically expand the mesh during a search spree
if options.SearchMeshExpand > 0 && ...
mod(SearchSpree,options.SearchMeshExpand) == 0 && ...
options.SearchMeshIncrement > 0
% Check if mesh size is already maximal
optimState = meshOverflowCheck(MeshSizeInteger,optimState,options);
MeshSizeInteger = min(MeshSizeInteger + options.SearchMeshIncrement, options.MaxPollGridNumber);
end
else
% No successful search (or skipping disabled): poll
DoPollStep_flag = true;
SearchSpree = 0;
end
SearchSuccesses = 0;
% optimState.searchfactor = 1;
otherwise % In-between searches, no poll
DoPollStep_flag = false;
end
%----------------------------------------------------------------------
%% Poll stage
% Directional poll around the incumbent: generate a basis of poll
% vectors, evaluate the most promising candidates, and move the
% incumbent if a sufficiently better point is found.
u = ubest;
if DoPollStep_flag
PollBestImprovement = 0; % Best improvement so far
upollbest = u; % Best poll point
ypollbest = yval; % Observed func value at best point
fpollbest = fval; % Estimated func value at best point
fpollhyp = fhyp; % gp hyper-parameters at best point
fpollbestsd = fsd; % Uncertainty of objective func
optimState.pollcount = 0; % Poll iterations
goodpoll_flag = false; % Found a good poll
B = []; % Poll basis
upoll = []; % Poll vectors
unew = []; % Last polled point
% Poll loop
% Repeatedly: extend the poll basis, pick the most promising poll
% point via the acquisition function, evaluate it, and update the
% best-polled bookkeeping. Stops when the poll set is exhausted,
% the evaluation budget is hit, after at most 2*nvars poll
% evaluations, or early when the GP suggests polling further is
% unlikely to yield sufficient improvement (incomplete polling).
while (~isempty(upoll) || isempty(B)) ...
&& optimState.funccount < options.MaxFunEvals ...
&& optimState.pollcount <= nvars*2
% Fill in basis vectors
% (empty Bnew means the basis cannot be extended further)
Bnew = feval(options.PollMethod{:}, ...
B, ...
u, ...
gpstruct, ...
LB, ...
UB, ...
optimState, ...
options);
% Create new poll vectors
if ~isempty(Bnew)
% GP-based vector scaling: each basis direction is scaled by
% the mesh size and the per-dimension GP poll scale
vv = bsxfun(@times,Bnew*MeshSize,gpstruct.pollscale);
% Add vector to current point, fix to grid
upollnew = bsxfun(@plus,u,vv);
% Handle periodic variables (wrap within bounds)
upollnew = periodCheck(upollnew,LB,UB,optimState);
if options.ForcePollMesh
upollnew = force2grid(upollnew, optimState);
end
% Validate candidates against bounds/mesh tolerance
upollnew = uCheck(upollnew,optimState.TolMesh,optimState,0);
% Add new poll points to polling set
upoll = [upoll; upollnew];
B = [B; Bnew];
end
% Cannot refill poll vector set, stop polling
if isempty(upoll); break; end
% Check whether it is time to refit the GP
% (unrelgp_flag marks the current GP approximation as unreliable)
[refitgp_flag,unrelgp_flag,optimState] = IsRefitTime(optimState,options);
if ~options.PollTraining && iter > 1; refitgp_flag = false; end
% Rebuild GP local approximation if refitting or if at beginning of polling stage
if refitgp_flag || optimState.pollcount == 0; gpstruct.post = []; end
% Local GP approximation around polled points
if isempty(gpstruct.post)
[gpstruct,gptempflag] = gpupdate(gpstruct, ...
options.gpMethod, ...
u, ...
upoll, ...
optimState, ...
options, ...
refitgp_flag);
if refitgp_flag; refitted_flag = true; end
% Keep the worst (minimum) GP exit flag seen so far
gpexitflag = min(gptempflag,gpexitflag);
end
% Refresh the improvement target used by the acquisition function
optimState = UpdateTarget(upollbest,fpollhyp,optimState,gpstruct,options);
% Evaluate acquisition function on poll vectors
%--------------------------------------------------------------
if options.AcqHedge
% Portfolio of acquisition functions; the hedge-chosen
% strategy selects the next poll point
[optimState.hedge,acqIndex,ymu,ys,fm,fs] = ...
acqPortfolio('acq',optimState.hedge,upoll,optimState.ftarget,0,gpstruct,optimState,options,SufficientImprovement);
index = acqIndex(optimState.hedge.chosen);
else
%--------------------------------------------------------------
% Batch evaluation of acquisition function on search set
[z,~,ymu,ys,fm,fs] = feval(options.PollAcqFcn{:},upoll,optimState.ftarget,gpstruct,optimState,0);
[~,index] = min(z); % Lowest acquisition value wins
end
% Something went wrong, random vector
if isempty(index) || isnan(index); index = randi(size(upoll,1)); end
% Compute probability that improvement at any location is
% less than SufficientImprovement (assumes independence --
% conservative estimate towards continuing polling)
% gammaz is the standardized margin to the improvement target
% under the GP posterior at each remaining poll point
gammaz = (optimState.ftarget - SufficientImprovement - fm)./fs;
if numel(gpstruct.hyp) > 1
% Average over GP hyperparameter samples
gammaz = weightedsum(gpstruct.hypweight,gammaz,1);
end
if all(isfinite(gammaz)) && isreal(gammaz)
fpi = 0.5*erfc(-gammaz/sqrt(2)); % Normal CDF: per-point prob. of sufficient improvement
fpi = sort(fpi,'descend');
% Probability that none of the nvars most promising points
% yields sufficient improvement
pless = prod(1-fpi(1:min(nvars,end)));
else
% Non-finite/complex prediction: distrust the GP
pless = 0;
unrelgp_flag = 1;
end
% Consider whether to stop polling
if ~options.CompletePoll
% Stop polling if last poll was good
if goodpoll_flag
if unrelgp_flag
break; % GP is unreliable, just stop polling
elseif pless > 1-options.TolPoI
break; % Use GP prediction whether to stop polling
end
else
% No good poll so far -- if GP is reliable, stop polling
% if probability of improvement at any location is too low
if ~unrelgp_flag && ...
(options.ConsecutiveSkipping || lastskipped < iter-1) ...
&& optimState.pollcount >= options.MinFailedPollSteps ...
&& pless > 1-options.TolPoI
lastskipped = iter;
break;
end
end
end
% Evaluate function and store value
unew = upoll(index,:);
[ypoll,optimState,ypoll_sd] = funlogger(funwrapper,unew,optimState,'iter');
% Remove polled vector from set
upoll(index,:) = [];
% Save statistics of gp prediction
optimState.gpstats = ...
savegpstats(optimState.gpstats,ypoll,ymu(:,index),ys(:,index),gpstruct.hypweight);
if optimState.UncertaintyHandling
% Add just polled point to training set
gpstruct = gpupdate(gpstruct, ...
'add', ...
unew, ...
[ypoll,ypoll_sd], ...
optimState, ...
options, ...
0);
% Compute estimated function value at point
% (noisy objective: use the GP posterior instead of the raw
% noisy observation)
[~,~,fpoll,fs2] = gppred(unew,gpstruct);
if numel(gpstruct.hyp) > 1
% Average over GP hyperparameter samples
fpoll = weightedsum(gpstruct.hypweight,fpoll,1);
fs2 = weightedsum(gpstruct.hypweight,fs2,1);
end
fpollsd = sqrt(fs2);
else
% Deterministic objective: use the observed value directly
fpoll = ypoll;
fpollsd = 0;
end
% Compute estimated improvement over incumbent
PollImprovement = EvalImprovement(fval,fpoll,fsd,fpollsd,options.ImprovementQuantile);
% Check if current point improves over best polled point so far
if PollImprovement > PollBestImprovement
upollbest = unew;
ypollbest = ypoll;
fpollbest = fpoll;
fpollhyp = gpstruct.hyp;
fpollbestsd = fpollsd;
PollBestImprovement = PollImprovement;
% A poll with sufficient improvement counts as "good"
% (enables early poll termination above)
if PollBestImprovement > SufficientImprovement
goodpoll_flag = true;
end
end
% Increase poll counter
optimState.pollcount = optimState.pollcount + 1;
end % Poll loop
% Evaluate poll
% Move the incumbent if the best polled point shows improvement
% (any positive improvement with SloppyImprovement, otherwise only
% sufficient improvement)
if (PollBestImprovement > 0 && options.SloppyImprovement) || ...
PollBestImprovement > SufficientImprovement
polldirection = find(abs(upollbest - ubest) > 1e-12,1); % The sign can be wrong for periodic variables (unused anyhow)
[ubest,yval,fval,fsd,optimState,gpstruct] = UpdateIncumbent(ubest,yval,fval,fsd,upollbest,ypollbest,fpollbest,fpollbestsd,optimState,gpstruct,options);
u = ubest;
pollmoved_flag = true;
else
pollmoved_flag = false;
end
% Update mesh size depending on poll outcome
if PollBestImprovement > SufficientImprovement
% Check if mesh size is already maximal
optimState = meshOverflowCheck(MeshSizeInteger,optimState,options);
% Successful poll, increase mesh size
MeshSizeInteger = min(MeshSizeInteger + 1, options.MaxPollGridNumber);
SuccessPoll_flag = true;
% Record successful point (raw, observed, and estimated values)
optimState.usuccess = [optimState.usuccess; ubest];
optimState.ysuccess = [optimState.ysuccess; yval];
optimState.fsuccess = [optimState.fsuccess; fval];
else
% Failed poll, decrease mesh size
MeshSizeInteger = MeshSizeInteger - 1;
% Accelerated mesh reduction if stalling
% (shrink the mesh by an extra step when improvement over the
% last AccelerateMeshSteps iterations is below TolFun)
if options.AccelerateMesh && iter > options.AccelerateMeshSteps
% Evaluate improvement in the last iterations
HistoricImprovement = ...
EvalImprovement(optimState.iterList.fval(iter-options.AccelerateMeshSteps),fval,optimState.iterList.fsd(iter-options.AccelerateMeshSteps),fsd,options.ImprovementQuantile);
if HistoricImprovement < options.TolFun
MeshSizeInteger = MeshSizeInteger - 1;
end
end
% Keep the search mesh size consistent with the shrunk poll mesh
optimState.SearchSizeInteger = min(optimState.SearchSizeInteger, MeshSizeInteger*options.SearchGridMultiplier - options.SearchGridNumber);
SuccessPoll_flag = false;
% Profile plot of iteration
% Diagnostic visualization of the objective landscape and GP fit
% around the current point (active only with options.Plot='profile')
if strcmpi(options.Plot,'profile') && ~isempty(gpstruct.x)
% figure(iter);
gpstruct.ftarget = optimState.ftarget;
hold off;
landscapeplot(@(u_) funwrapper(origunits(u_,optimState)), ...
u, ...
LB, ...
UB, ...
MeshSize, ...
gpstruct, ...
[], ...
31);