In [1]:
from pyqlearning.qlearning.greedy_q_learning import GreedyQLearning
from numpy.random import normal

In [2]:
import warnings

# Ignore every FutureWarning for the remainder of the session.
warnings.simplefilter(action="ignore", category=FutureWarning)

In [13]:
class PriceReinforcementLearning(GreedyQLearning):
    '''
    Greedy Q-learning agent over the states 1 through 5.

    From a state the agent may pick the state one below, the state itself,
    or the state one above, provided the pick stays within 1..5. The reward
    depends only on the chosen action and grows as the action key shrinks.
    '''

    def __init__(self):
        '''Initialise the base learner, then set alpha to 1.0 and the epsilon-greedy rate to 0.8.'''
        super().__init__()
        self.set_alpha_value(1.0)
        self.set_epsilon_greedy_rate(0.8)

    def extract_possible_actions(self, state_key):
        '''
        List the actions available from ``state_key``.

        Args:
            state_key:      Current state.

        Returns:
            The values among ``state_key - 1``, ``state_key`` and
            ``state_key + 1`` that lie in the closed interval [1, 5],
            in ascending order.
        '''
        # A tuple (not range) is used so non-integer numeric keys still work.
        neighbours = (state_key - 1, state_key, state_key + 1)
        return [candidate for candidate in neighbours if 1 <= candidate <= 5]

    def observe_reward_value(self, state_key, action_key):
        '''
        Compute the reward for taking ``action_key``.

        Args:
            state_key:      Current state (not used in the computation).
            action_key:     Chosen action.

        Returns:
            ``5 - action_key``.
        '''
        reward = 5 - action_key
        return reward

    def learn(self, a, limit=1000):
        '''
        Run the inherited learning loop, then print the learned Q-table.

        Args:
            a:              Initial state, passed straight to the base class.
            limit:          Iteration limit, passed straight to the base class.
        '''
        super().learn(a, limit)
        # Highest Q-values first.
        ranked_q_df = self.q_df.sort_values("q_value", ascending=False)
        print(ranked_q_df)

In [14]:
# Seed the Q-table with 1.0 for every (state, action) pair the agent can
# actually take: states 1..5, each with the actions the agent itself reports.
ql = PriceReinforcementLearning()
for state_key in range(1, 6):
    for action_key in ql.extract_possible_actions(state_key):
        ql.save_q_df(state_key, action_key, 1.0)

In [15]:
# Train starting from state 1 with limit=1000; learn() also prints the
# resulting Q-table sorted by q_value, highest first.
ql.learn(1, limit=1000)

   state_key  action_key   q_value
0        1.0         1.0  8.000000
0        2.0         1.0  8.000000
0        1.0         2.0  7.000000
0        3.0         2.0  7.000000
0        2.0         2.0  7.000000
0        4.0         3.0  5.500000
0        2.0         3.0  5.500000
0        3.0         4.0  3.750000
0        4.0         4.0  2.000000
0        5.0         4.0  2.000000
0        3.0         3.0  1.000000
0        4.0         5.0  0.999992
0        5.0         5.0  0.500000


In [42]:
class InventoryReinforcementLearning(GreedyQLearning):
    '''
    Greedy Q-learning agent whose state is an integer stock level.

    Every state offers the two actions ``0`` and ``1``. Action ``1`` adds
    10 units before a randomly drawn quantity is subtracted (presumably a
    restock order followed by customer demand).
    '''

    def __init__(self):
        '''Initialise the base learner and set its alpha value to 1.0.'''
        super().__init__()
        self.set_alpha_value(1.0)

    def learn(self, state_key, limit=1000):
        '''
        Run the learning loop, then print the Q-table rows with state_key <= 30.

        ``observe_reward_value`` of this class returns both the reward and
        the next state, so the state transition is taken directly from it.

        Args:
            state_key:      Initial state.
            limit:          The maximum number of iterative updates.
        '''
        self.t = 1
        while self.t <= limit:
            candidate_actions = self.extract_possible_actions(state_key)
            if len(candidate_actions):
                action_key = self.select_action(
                    state_key=state_key,
                    next_action_list=candidate_actions
                )
                reward_value, next_state_key = self.observe_reward_value(state_key, action_key)

                # Q-value of the predicted action in the next state.
                successor_actions = self.extract_possible_actions(next_state_key)
                successor_action_key = self.predict_next_action(next_state_key, successor_actions)
                next_max_q = self.extract_q_df(next_state_key, successor_action_key)

                # Fold the observed reward into the Q-table, then move on.
                self.update_q(
                    state_key=state_key,
                    action_key=action_key,
                    reward_value=reward_value,
                    next_max_q=next_max_q
                )
                state_key = next_state_key

            # Normalisation hooks.
            self.normalize_q_value()
            self.normalize_r_value()

            # Visualisation hook, then the termination check.
            self.visualize_learning_result(state_key)
            if self.check_the_end_flag(state_key) is True:
                break

            # Next episode.
            self.t += 1

        q_table = self.q_df
        shown_rows = q_table[q_table.state_key <= 30]
        print(shown_rows.sort_values(["state_key", "action_key"]))

    def extract_possible_actions(self, state_key):
        '''
        List the available actions.

        Args:
            state_key:      Current state (not used; the list is fixed).

        Returns:
            ``[0, 1]``.
        '''
        return [0, 1]

    def observe_reward_value(self, state_key, action_key):
        '''
        Simulate one step and return the reward together with the next state.

        Action ``1`` first adds 10 to the stock. A value drawn from
        ``normal(5, 1)`` and truncated with ``int`` is then subtracted.

        Args:
            state_key:      Current stock level.
            action_key:     Chosen action.

        Returns:
            Tuple ``(reward, next_state_key)``.
            For action ``1`` the reward is ``stock * -0.1`` when the
            resulting stock is positive, otherwise 0. For any other action
            it is ``-5`` when the resulting stock is negative, otherwise 0.
            The next state is the resulting stock, or 10 when that stock is
            not positive.
        '''
        restocked = action_key == 1
        stock = state_key + 10 if restocked else state_key
        stock -= int(normal(5, 1))

        if restocked:
            reward = stock * -0.1 if stock > 0 else 0
        else:
            reward = -5 if stock < 0 else 0

        next_state_key = stock if stock > 0 else 10
        return reward, next_state_key

In [43]:
# Train the inventory agent from an initial state of 10 (limit left at its
# default of 1000); learn() prints the Q-table rows with state_key <= 30.
# NOTE(review): observe_reward_value draws from normal() and no seed is set in
# the visible cells, so the printed values will likely differ between runs.
ql = InventoryReinforcementLearning()
ql.learn(10)

   state_key  action_key   q_value
0          1         0.0 -5.329688
0          1         1.0 -1.150400
0          2         0.0 -5.183588
0          2         1.0 -0.699319
0          3         0.0 -5.119149
0          3         1.0 -1.112402
0          4         0.0 -0.195129
0          4         1.0 -1.450672
0          5         0.0 -0.087415
0          5         1.0 -0.914321
0          6         0.0 -0.575200
0          6         1.0 -1.525800
0          7         0.0 -2.931616
0          7         1.0 -1.412415
0          8         0.0 -0.931518
0          8         1.0 -1.414290
0          9         0.0 -2.957908
0          9         1.0 -1.601600
0         10         0.0 -0.811765
0         10         1.0 -2.010746
0         11         0.0 -0.064131
0         11         1.0 -1.879381
0         12         0.0 -0.851318
0         12         1.0 -2.104130
0         13         0.0 -0.203200
0         13         1.0 -1.962244
0         14         0.0 -0.789690
0         14        