In [45]:
import oracle

In [100]:
import numpy as np

class SGD:
    """Projected stochastic gradient descent on the oracle's price vector.

    Every iteration samples one index uniformly from ``[0, S + D)``, asks the
    oracle for the gradient at that index, takes a step of size
    ``C / sqrt(t + 1)`` and clips the resulting prices at zero.
    """

    def __init__(self, oracle: object, N: int, C: float):
        """
        Parameters:
        oracle (object): Must provide ``initialize_prices()``,
            ``compute_gradient(index)``, ``compute_cost_func()`` and the
            attributes ``s``, ``d`` and ``prices``.
        N (int): Number of iterations performed by ``update``.
        C (float): Step-size constant (step at iteration t is C / sqrt(t + 1)).
        """
        self.oracle = oracle
        self.N = N
        self.C = C
        self.p = oracle.initialize_prices()
        self.S = oracle.s
        self.D = oracle.d

    def update(self, verbose: bool = True):
        """
        Runs N SGD iterations and writes each new price vector to the oracle.
        Parameters:
        verbose (bool): When True (default), print prices, cost and gradient
            after every iteration.
        Returns:
            numpy.ndarray: Average of the N price iterates. If N <= 0 no step
            is taken and a float copy of the current prices is returned.
        """
        # Accumulate into a fresh array. Starting from `self.p` itself would
        # alias the initial prices (mutating them in place through `+=`) and
        # would count the starting point as an extra term in the average.
        p_sum = np.zeros_like(self.p, dtype=float)
        for t in range(self.N):
            index = np.random.randint(0, self.S + self.D)
            step = self.C / np.sqrt(t + 1)
            # Projection onto the non-negative orthant keeps prices >= 0.
            self.p = (self.p - step * self.oracle.compute_gradient(index)).clip(min=0)
            self.oracle.prices = self.p
            if verbose:
                print("prices: ", self.p)
                print("cost function: ", self.oracle.compute_cost_func())
                # Re-evaluated after the oracle's prices were updated above.
                print("gradient: ", self.oracle.compute_gradient(index))
            p_sum += self.p
        if self.N <= 0:
            # Nothing to average; avoid dividing by zero.
            return np.array(self.p, dtype=float)
        return p_sum / self.N

    def get_prices(self):
        """Returns the most recent price vector."""
        return self.p

In [124]:
class ADAM:
    """Adaptive Moment Estimation with projection of prices onto [0, inf)."""

    def __init__(
        self,
        oracle: object,
        N: int,
        C: float,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-7,
    ):
        """
        Parameters:
        oracle (object): Must provide ``initialize_prices()``,
            ``compute_gradient(index)``, ``compute_cost_func()`` and the
            attributes ``s``, ``d``, ``n`` and ``prices``.
        N (int): Number of iterations performed by ``update``.
        C (float): Step size.
        beta1 (float): Decay rate of the first moment.
        beta2 (float): Decay rate of the second moment.
        eps (float): Stabiliser added to the second moment inside the square root.
        """
        self.oracle = oracle
        self.N = N
        self.C = C
        self.p = oracle.initialize_prices()
        self.S = oracle.s
        self.D = oracle.d
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps

    def update(self, verbose: bool = True):
        """
        Updates prices of products using ADAM.
        Parameters:
        verbose (bool): When True (default), print prices, cost and gradient
            after every iteration.
        Returns:
            numpy.ndarray: Average of the N price iterates. If N <= 0 no step
            is taken and a float copy of the current prices is returned.
        """
        # Accumulate into a fresh array. Starting from `self.p` itself would
        # alias the initial prices (mutating them in place through `+=`) and
        # would count the starting point as an extra term in the average.
        p_sum = np.zeros_like(self.p, dtype=float)
        m = np.zeros(self.oracle.n)
        v = np.zeros(self.oracle.n)
        for t in range(self.N):
            index = np.random.randint(0, self.S + self.D)
            g = self.oracle.compute_gradient(index)
            m = self.beta1 * m + (1 - self.beta1) * g
            v = self.beta2 * v + (1 - self.beta2) * g**2
            # Bias correction for the zero-initialised moment estimates.
            m_hat = m / (1 - self.beta1 ** (t + 1))
            v_hat = v / (1 - self.beta2 ** (t + 1))
            # eps sits inside the square root here (kept as in the original
            # implementation); clipping keeps prices non-negative.
            self.p = (self.p - self.C / np.sqrt(v_hat + self.eps) * m_hat).clip(min=0)
            self.oracle.prices = self.p
            if verbose:
                print("prices: ", self.p)
                print("cost function: ", self.oracle.compute_cost_func())
                # Re-evaluated after the oracle's prices were updated above.
                print("gradient: ", self.oracle.compute_gradient(index))
            p_sum += self.p
        if self.N <= 0:
            # Nothing to average; avoid dividing by zero.
            return np.array(self.p, dtype=float)
        return p_sum / self.N

    def get_prices(self):
        """Returns the most recent price vector."""
        return self.p


In [117]:
from oracle import Oracle

In [125]:
# Build the oracle used by the optimizer cells below. The meaning of the five
# positional arguments is defined by Oracle's constructor in the local `oracle`
# module, which is not shown in this notebook — TODO confirm and name them.
oracle_obj = Oracle(5, 10, 20, 5, 0.0001)

In [126]:
# ADAM optimizer bound to the oracle: 1000 iterations with step size 0.1.
adam = ADAM(oracle=oracle_obj, N=1000, C=0.1)

In [127]:
# Call the oracle's customer initialisation; the array it returns is displayed
# in full as the cell output below.
oracle_obj.initialize_customers()

array([[0.13798993, 0.08952909, 0.17231658, 0.07251769, 0.0104049 ,
        0.0670355 , 0.10572867, 0.06417051, 0.10974685, 0.11449248],
       [0.26128551, 0.28854099, 0.20448758, 0.3385809 , 0.28906166,
        0.33049831, 0.2287294 , 0.23024066, 0.26998817, 0.3550602 ],
       [0.48635915, 0.55786426, 0.52610126, 0.43058143, 0.57506642,
        0.45227407, 0.53549323, 0.59245327, 0.47504337, 0.47110624],
       [0.70183908, 0.69298766, 0.69434601, 0.73568796, 0.72334001,
        0.6429992 , 0.79667302, 0.63362658, 0.72567955, 0.7345591 ],
       [0.96434467, 0.82996055, 0.87481122, 0.85112639, 0.99651081,
        0.89704294, 0.80865744, 0.91624216, 0.96057641, 0.85519759],
       [0.13807956, 0.05767393, 0.01894484, 0.11762043, 0.09944595,
        0.17082593, 0.01416183, 0.16823869, 0.08912638, 0.02401797],
       [0.27544066, 0.34690817, 0.35408594, 0.27148649, 0.28596183,
        0.30271124, 0.28869092, 0.29524273, 0.2946287 , 0.35949027],
       [0.54603656, 0.42688819, 0.5329769

In [121]:
# NOTE(review): this cell's execution count (121) is lower than that of the cell
# which creates `oracle_obj` (125), so the output below may come from an earlier
# oracle instance — re-run the notebook top to bottom to confirm.
oracle_obj.compute_y()

array([[0.27606217, 0.00953143, 0.00496254, 0.16607723, 0.12882174,
        0.14618139, 0.01322161, 0.13823425, 0.01672136, 0.34431311,
        0.50001131, 0.24392959, 0.02259077, 0.03676252, 0.05968089,
        0.0171732 , 0.09285296, 0.43065033, 0.14516466, 0.09088258],
       [0.27605795, 0.00952345, 0.00495179, 0.16617294, 0.12880388,
        0.14613463, 0.01321031, 0.13834932, 0.01682432, 0.3444801 ,
        0.50007364, 0.24384263, 0.02261094, 0.03691975, 0.05959728,
        0.01711529, 0.09293609, 0.43063353, 0.14517038, 0.09079754],
       [0.2759642 , 0.00949726, 0.00505483, 0.16606798, 0.12886512,
        0.14626031, 0.01317207, 0.13842101, 0.01671548, 0.34432671,
        0.50009604, 0.24395611, 0.02271314, 0.03682443, 0.05959613,
        0.01715978, 0.09297482, 0.43060636, 0.14522971, 0.09097803],
       [0.27605614, 0.00953007, 0.00495896, 0.16623781, 0.12888425,
        0.14613006, 0.01320231, 0.13823805, 0.01676084, 0.34437675,
        0.50006595, 0.24383182, 0.02274539, 0

In [128]:
# Evaluate the oracle's cost function before the ADAM run in the next cell.
oracle_obj.compute_cost_func()

57.33150541998077

In [129]:
# Run the ADAM iterations. Each iteration prints prices, cost and gradient; the
# array returned by update() is displayed as the cell output after the log.
adam.update()

prices:  [0.85559559 0.92868767 0.40187243 0.06244636 0.58208936 0.15847973
 0.96085013 0.36313476 0.90486961 0.69102325 1.         0.24231389
 0.07729738 0.48559745 0.78756436 0.24111326 0.7194817  0.14273874
 0.53183233 0.79933669]
cost function:  49.745187347231145
gradient:  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
prices:  [0.85559559 0.92868767 0.40187243 0.06244636 0.58208936 0.15847973
 0.96085013 0.36313476 0.90486961 0.75802907 1.         0.24231389
 0.07729738 0.55260327 0.78756436 0.24111326 0.7194817  0.14273874
 0.53183233 0.79933669]
cost function:  53.72474355744318
gradient:  [ 0  0  0  0 -1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
prices:  [0.79171428 0.86480636 0.33799131 0.         0.51820811 0.09459989
 0.89696881 0.29925369 0.84098829 0.78378751 0.93611868 0.17843318
 0.01342236 0.58548777 0.72368306 0.17723255 0.65560041 0.07885926
 0.46795111 0.73545539]
cost function:  53.10830464919114
gradient:  [3.96008772e-01 4.32513014e-01 1.69042077e-01 1.17161566e-

array([0.16511813, 0.18890905, 0.23664974, 0.21309426, 0.17745992,
       0.19371208, 0.18706431, 0.20879815, 0.20010569, 0.20838667,
       0.22122122, 0.19795006, 0.24606802, 0.20413529, 0.25991878,
       0.25913794, 0.19963317, 0.1589622 , 0.17402248, 0.23892624])