In [1]:
from sklearn.preprocessing import KBinsDiscretizer
import numpy as np
import math, time, random

import gym

# CartPole-v1

In [2]:
env = gym.make('CartPole-v1', render_mode="human")

### Visualise Environment

In [3]:
policy = lambda obs: 0

for _ in range(3):
    obs, info = env.reset(seed=42)
    for _ in range(1000):
        actions = policy(obs)
        obs, reward, done, _, info = env.step(actions)
        env.render()  # This line allows for visualisation
        time.sleep(0.05)

        if done:
            break

env.close()

2025-04-24 13:16:28.441 python[68623:10699131] +[IMKClient subclass]: chose IMKClient_Modern
2025-04-24 13:16:28.441 python[68623:10699131] +[IMKInputSession subclass]: chose IMKInputSession_Modern


In [4]:
?env.env

# Q-Learning

First, in order to use Q-Learning, the action and state space needs to be continuous. Therefore, we will need to discretise the continuous space.

In [5]:
from typing import Tuple

n_bins = (6, 12)
lower_bounds = [env.observation_space.low[2], -math.radians(50)]
upper_bounds = [env.observation_space.high[2], math.radians(50)]


def discretizer(_, __, angle, pole_velocity) -> Tuple[int, ...]:
    """Convert continuoues state into a discrete state."""
    est = KBinsDiscretizer(n_bins=n_bins, encode='ordinal', strategy='uniform')
    est.fit([lower_bounds, upper_bounds])
    return tuple(map(int, est.transform([[angle, pole_velocity]])[0]))


Initialise the Q-value table with zeros.

In [6]:
Q_table = np.zeros(n_bins + (env.action_space.n,))
Q_table.shape

(6, 12, 2)

Create a policy function, uses the Q-table to and greedly selecting the highest Q value.

In [7]:
def policy(state: tuple, epsilon: float = 0.1):
    """Choosing action based on epsilon-greedy policy"""
    if np.random.rand() < epsilon:
        return env.action_space.sample()  # Explore
    else:
        return np.argmax(Q_table[state])  # Exploit

Update function

In [8]:
def new_Q_value(reward: float, state_new: tuple, discount_factor: int = 1) -> float:
    """Temporal difference for updating Q-Value of state-action pair"""
    future_optimal_value = np.max(Q_table[state_new])
    learned_value = reward + discount_factor * future_optimal_value
    return learned_value

Decaying Learning Rate

In [9]:
def learning_rate(n: int, min_rate=0.01) -> float:
    """Decaying learning rate"""
    return max(min_rate, min(1.0, 1.0 - math.log10((n + 1) / 25)))

Decaying Exploration Rate

In [10]:
def exploration_rate(n: int, min_rate=0.1) -> float:
    """Decaying exploration rate"""
    return max(min_rate, min(1.0, 1.0 - math.log10((n + 1) / 25)))

## Training

In [11]:
env.reset()

(array([-0.0334251 ,  0.02402424,  0.01940848, -0.02464325], dtype=float32),
 {})

In [20]:
n_episodes = 10000
for e in range(n_episodes):
    # Discretise state into buckets
    obs, info = env.reset(seed=42)
    current_state, done = discretizer(*obs), False

    if e % 3 == 0:
        print(e)

    while not done:
        # Policy Action, Epsilon Greedy
        action = policy(current_state, exploration_rate(e))

        # Increment environment
        obs, reward, done, _, info = env.step(action)
        new_state = discretizer(*obs)

        # Update the Q-Table
        lr = learning_rate(2)
        learnt_value = new_Q_value(reward, new_state)
        old_value = Q_table[current_state][action]
        Q_table[current_state][action] = (1 - lr) * old_value + lr * learnt_value

        current_state = new_state

        # Every 100 episodes we render environment
        if e % 3 == 0:
            env.render()

0
3
6
9
12
15
18
21
24
27
30
33
36
39
42
45
48
51
54
57
60
63
66
69
72
75
78
81
84
87
90
93
96
99
102
105
108
111
114
117
120
123
126
129
132
135
138
141
144
147
150
153
156
159
162
165
168
171
174
177
180
183
186
189
192
195
198
201
204
207
210
213
216
219
222
225
228
231
234
237
240
243
246
249
252
255
258
261
264
267
270
273
276
279
282
285
288
291
294
297
300
303
306
309
312
315




318
321
324
327
330
333
336
339
342
345
348
351
354
357
360
363
366
369
372
375
378
381
384
387
390
393
396
399
402
405
408
411
414
417
420
423
426
429
432
435
438
441
444
447
450
453
456
459
462
465
468
471
474
477
480
483
486
489
492
495
498
501
504
507
510
513
516
519
522
525
528
531
534
537
540
543
546
549
552
555
558
561
564
567
570
573
576
579
582
585
588
591
594
597
600
603
606
609
612
615
618
621
624
627
630
633
636
639
642
645
648
651
654
657
660
663
666
669
672
675
678
681
684
687
690
693
696
699
702
705
708
711
714
717
720
723
726
729
732
735
738
741
744
747
750
753
756
759
762
765
768
771
774
777
780
783
786
789
792
795
798
801
804
807
810
813
816
819
822
825
828
831
834
837
840
843
846
849
852
855
858
861
864
867
870
873
876
879
882
885
888
891
894
897
900
903
906
909
912
915
918
921
924
927
930
933
936
939
942
945
948
951
954
957
960
963
966
969
972
975
978
981
984
987
990
993
996
999
1002
1005
1008
1011
1014
1017
1020
1023
1026
1029
1032
1035
1038
1041
1044
1047
1050
105

KeyboardInterrupt: 

In [21]:
np.save('Q_table.npy', Q_table)

In [22]:
Q_table

array([[[0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00],
        [0.000e+00, 0.000e+00]],

       [[7.640e+02, 7.750e+02],
        [4.570e+02, 7.630e+02],
        [7.530e+02, 3.990e+02],
        [7.410e+02, 7.460e+02],
        [4.440e+02, 7.510e+02],
        [3.000e+00, 5.640e+02],
        [2.000e+00, 7.060e+02],
        [4.907e+03, 0.000e+00],
        [2.000e+00, 0.000e+00],
        [1.000e+00, 1.000e+00],
        [7.460e+02, 0.000e+00],
        [0.000e+00, 5.310e+02]],

       [[7.770e+02, 7.750e+02],
        [7.470e+02, 7.790e+02],
        [5.490e+03, 7.770e+02],
        [5.079e+03, 7.750e+02],
        [5.473e+03, 7.610e+02],
        [5.491e+03, 4.953e+03],
        [5.517e+03, 5.076e+03],
    