/
VPCT.py
98 lines (78 loc) · 2.39 KB
/
VPCT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# -*- coding: utf-8 -*-
"""Implementation of the VPCT algorithm (Variance-reduced Parallel Confidence Tree, VHCT + GPO)
"""
# Author: Wenjie Li <li3549@purdue.edu>
# License: MIT
import numpy as np
from PyXAB.algos.Algo import Algorithm
from PyXAB.algos.VHCT import VHCT
from PyXAB.algos.GPO import GPO
from PyXAB.partition.BinaryPartition import BinaryPartition
class VPCT(Algorithm):
    """
    Variance-reduced Parallel Confidence Tree algorithm (VHCT + GPO)

    The class is a thin wrapper: it builds a GPO instance configured with
    VHCT as its ``algo`` and forwards every call to that instance.
    """

    def __init__(
        self, numax=1, rhomax=0.9, rounds=1000, domain=None, partition=BinaryPartition
    ):
        """
        Set up VPCT by constructing the wrapped GPO instance

        Parameters
        ----------
        numax: float
            parameter nu_max in the algorithm
        rhomax: float
            parameter rho_max in the algorithm, the maximum rho used
        rounds: int
            the number of rounds/budget
        domain: list(list)
            the domain of the objective function
        partition:
            the partition used in the optimization process

        Raises
        ------
        ValueError
            If ``domain`` is None or ``partition`` is None.
        """
        super().__init__()

        # Checked in this order so that a missing domain is reported first
        # when both arguments are absent.
        required = (
            (domain, "Parameter space is not given."),
            (partition, "Partition of the parameter space is not given."),
        )
        for value, message in required:
            if value is None:
                raise ValueError(message)

        # All settings are handed to GPO unchanged; VHCT is the algorithm
        # passed as its ``algo`` argument.
        gpo_settings = {
            "numax": numax,
            "rhomax": rhomax,
            "rounds": rounds,
            "domain": domain,
            "partition": partition,
            "algo": VHCT,
        }
        self.algorithm = GPO(**gpo_settings)

    def pull(self, time):
        """
        Ask the wrapped algorithm for the next point to be evaluated

        Parameters
        ----------
        time: int
            The time step of the online process.

        Returns
        -------
        point: list
            The point chosen by the VPCT algorithm
        """
        point = self.algorithm.pull(time)
        return point

    def receive_reward(self, time, reward):
        """
        Forward the reward of the chosen point to the wrapped algorithm

        Parameters
        ----------
        time: int
            The time step of the online process.
        reward: float
            The (Stochastic) reward of the pulled point

        Returns
        -------
        None
        """
        wrapped = self.algorithm
        wrapped.receive_reward(time, reward)

    def get_last_point(self):
        """
        Get the last point of VPCT, as reported by the wrapped algorithm

        Returns
        -------
        The value returned by ``get_last_point`` of the wrapped algorithm.
        """
        last_point = self.algorithm.get_last_point()
        return last_point