In [17]:
import numpy as np

n_states = 16
n_actions = 4
goal_state = 15

Q_table = np.zeros((n_states,n_actions))

In [18]:
print(Q_table)

[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]


In [15]:
learning_rate = 0.8
# Low discount factor means that I only care about immediate rewards whereas high discount factor means that I care about future rewards 
# A simple example would be if I pick the coin in front of me or I take a risky path that could lead to future rewards
discount_factor = 0.95
exploration_prob = 0.2
# An epoch means that the model has passed through an entire phase of learning with all the available training data. 
# If there are 100 training data samples, then one epoch is when your model trains from all 100 once.
epochs = 1000

In [19]:
for epoch in range(epochs):
    current_state = np.random.randint(0,n_states)
    while current_state != goal_state:
        # Here the reason we have this set up in this way is because we try to balance exploration vs exploitation
        # Here we arent manually defining wether to explore or exploit, rather we are deciding it based on a random value 
        # "np.random.rand() < exploration_prob"
        # So if it is less than exploration proabability then explore else exploit.
        if np.random.rand() < exploration_prob:
            action = np.random.randint(0, n_actions)
        else:
            action = np.argmax(Q_table[current_state])
        # move to the next state, or wrap around -> used in circular queue
        next_state = ( current_state + 1 ) % n_states
        reward = 1 if next_state == goal_state else 0
        Q_table[current_state,action] += learning_rate * (reward + discount_factor * np.max(Q_table[next_state]) - Q_table[current_state, action])
        
        current_state = next_state

In [20]:
print(Q_table)

[[0.47956544 0.4868947  0.39013998 0.48767498]
 [0.51334178 0.51334208 0.51334077 0.51333551]
 [0.54035978 0.5403587  0.54036009 0.54018717]
 [0.56880009 0.56880009 0.56879864 0.56880009]
 [0.59873694 0.59873694 0.59873694 0.59873694]
 [0.63024941 0.63024941 0.63024941 0.63024941]
 [0.66342043 0.66342043 0.66342043 0.66342043]
 [0.6983373  0.6983373  0.6983373  0.6983373 ]
 [0.73509189 0.73509189 0.73509189 0.73509189]
 [0.77378094 0.77378094 0.77378094 0.77378094]
 [0.81450625 0.81450625 0.81450625 0.81450625]
 [0.857375   0.857375   0.857375   0.857375  ]
 [0.9025     0.9025     0.9025     0.9025    ]
 [0.95       0.95       0.95       0.95      ]
 [1.         1.         1.         1.        ]
 [0.         0.         0.         0.        ]]
