In [1]:
import oracle

In [2]:
import numpy as np

class SGD:
    def __init__(self, oracle: object, N: int, C: float):
        self.oracle = oracle
        self.N = N
        self.C = C
        self.p = oracle.initialize_prices()
        self.S = oracle.s
        self.D = oracle.d

    def update(self):
        p_mean = self.p
        for t in range(self.N):
            index = np.random.randint(0, self.S + self.D)
            self.p = (self.p - (self.C / np.sqrt(t + 1)) * self.oracle.compute_gradient(
                index
            )).clip(min=0)
            self.oracle.prices = self.p
            print("prices: ", self.p)
            print("cost function: ", self.oracle.compute_cost_func(self.oracle.prices))
            print("gradient: ", self.oracle.compute_gradient(index))
            p_mean += self.p
        return p_mean / self.N

    def get_prices(self):
        return self.p

In [3]:
from oracle import Oracle

In [4]:
oracle_obj = Oracle(5, 10, 20, 5, 0.0001)

In [5]:
adam = SGD(oracle_obj, 1000, 0.05)

In [6]:
oracle_obj.compute_y()

array([[0.06976796, 0.37630749, 0.34730139, 0.01556807, 0.38425379,
        0.37179087, 0.18834236, 0.39379938, 0.06308415, 0.50003229,
        0.05103246, 0.07179406, 0.03105245, 0.39527851, 0.38058433,
        0.47566184, 0.27288124, 0.4545388 , 0.27873382, 0.43672126],
       [0.0698548 , 0.3763345 , 0.34730829, 0.01565054, 0.38427892,
        0.37174268, 0.18836413, 0.39382495, 0.06322969, 0.50011792,
        0.05110462, 0.07187863, 0.03104107, 0.39523254, 0.38069652,
        0.47560558, 0.27281703, 0.45464828, 0.27876643, 0.43662124],
       [0.06980949, 0.3763188 , 0.34731264, 0.01550948, 0.38420686,
        0.37183081, 0.18827958, 0.39375097, 0.0632093 , 0.50001631,
        0.05118223, 0.07196139, 0.03111094, 0.39524411, 0.380668  ,
        0.47577966, 0.27290935, 0.45467453, 0.27876926, 0.43673653],
       [0.06980458, 0.37629828, 0.34742266, 0.0154974 , 0.38436315,
        0.37170412, 0.18835893, 0.39379974, 0.06316081, 0.50001078,
        0.05116887, 0.07187696, 0.03118515, 0

In [7]:
oracle_obj.compute_cost_func(oracle_obj.prices)

40.545846323262936

In [8]:
adam.update()

prices:  [0.13945417 0.7524763  0.6946163  0.03098352 0.76843028 0.74342319
 0.3763931  0.7874291  0.12613922 1.         0.10203469 0.14357843
 0.06208071 0.79051129 0.76121843 0.95127091 0.54561866 0.90904686
 0.5574423  0.87325355]
cost function:  40.545846323262936
gradient:  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
prices:  [0.13945417 0.7524763  0.6946163  0.03098352 0.76843028 0.74342319
 0.3763931  0.7874291  0.12613922 1.         0.10203469 0.14357843
 0.06208071 0.79051129 0.76121843 0.95127091 0.54561866 0.90904686
 0.5574423  0.87325355]
cost function:  40.545846323262936
gradient:  [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
prices:  [0.13945417 0.7524763  0.72348381 0.03098352 0.76843028 0.74342319
 0.40526061 0.7874291  0.12613922 1.         0.10203469 0.14357843
 0.06208071 0.8193788  0.76121843 0.95127091 0.54561866 0.90904686
 0.58630981 0.87325355]
cost function:  40.67081850763704
gradient:  [ 0  0  0  0  0  0 -1  0  0  0  0  0  0  0  0  0  0  0  0  0]
prices:  [0.137

array([0.17478522, 0.63890869, 0.62260687, 0.08220661, 0.61322398,
       0.62131367, 0.36984048, 0.6491614 , 0.17750803, 0.79754599,
       0.11833109, 0.19955145, 0.09025483, 0.67627458, 0.62149338,
       0.7793643 , 0.47327679, 0.75759792, 0.48406389, 0.69436169])

In [9]:
oracle_obj.compute_cost_func(oracle_obj.final_price)

45.85247641737952