In [1]:
import gymnasium as gym
import numpy as np
import random
from Qtabularfunctions import*
from Cartpolefamily import*
import matplotlib.pyplot as plt
import random

In [None]:
# --- Configuration -----------------------------------------------------------
# Clipping bounds for the 4-D CartPole observation used to build the state
# discretization grid: [cart position, cart velocity, pole angle, pole angular velocity].
# ±4.8 and ±0.418 rad match Gymnasium's CartPole observation-space bounds; the
# velocity bounds (±3.0, ±3.5) are hand-picked, since Gymnasium leaves velocities unbounded.
low = np.array([-4.8, -3.0, -0.418, -3.5])
high = np.array([4.8, 3.0, 0.418, 3.5])

episodes = 20000  # number of training episodes
min_td_error = 1e-4  # an update with |TD error| below this counts as "small"
consecutive_small_errors = 10  # end an episode after this many small updates in a row

num_actions = 101  # number of discrete force levels (action indices 0..num_actions-1)
lr = 0.1  # Q-learning step size: update amount = lr * TD error
gamma = 0.99  # discount factor
epsilon = 0.5  # initial exploration rate passed to the agent (decayed during training)
force_mag = 100  # largest force magnitude; actions are spread evenly over [-force_mag, force_mag]

# Seed the global RNGs so category sampling and exploration repeat across Run-All.
# NOTE(review): this only covers code drawing from the global `random` / `np.random`
# state; the agent or environment may keep their own RNGs — confirm.
seed = 0
random.seed(seed)
np.random.seed(seed)


In [9]:
gen = CartPoleCategoryGenerator()

# Baseline agent built from the configuration cell; its discretizer (`agent.disc`)
# provides the discrete states that the bookkeeping dicts below are keyed on.
agent = TabularQLearningAgent(
    statespace=[low, high],
    num_actions=num_actions,
    lr=lr,
    gamma=gamma,
    epsilon=epsilon,
    force_mag=force_mag,
)

# Each discrete state starts with every action index 0..num_actions-1
# (a separate list per state, so no list is shared between keys).
actionspace_dict = {
    discrete_state: list(range(num_actions))
    for discrete_state in agent.disc.get_all_discrete_states()
}

# Per-state collection of greedy action indices, filled in after each training run.
actionset_dict = {
    discrete_state: []
    for discrete_state in agent.disc.get_all_discrete_states()
}

Continuous action space discretized into 101 actions:
Discrete actions: [-100.  -98.  -96.  -94.  -92.  -90.  -88.  -86.  -84.  -82.  -80.  -78.
  -76.  -74.  -72.  -70.  -68.  -66.  -64.  -62.  -60.  -58.  -56.  -54.
  -52.  -50.  -48.  -46.  -44.  -42.  -40.  -38.  -36.  -34.  -32.  -30.
  -28.  -26.  -24.  -22.  -20.  -18.  -16.  -14.  -12.  -10.   -8.   -6.
   -4.   -2.    0.    2.    4.    6.    8.   10.   12.   14.   16.   18.
   20.   22.   24.   26.   28.   30.   32.   34.   36.   38.   40.   42.
   44.   46.   48.   50.   52.   54.   56.   58.   60.   62.   64.   66.
   68.   70.   72.   74.   76.   78.   80.   82.   84.   86.   88.   90.
   92.   94.   96.   98.  100.]


In [None]:
# Train a fresh agent on one randomly chosen environment category (from the first
# three categories), then record the greedy action it learned for every discrete state.


def _reset_state(reset_result):
    """Return the initial observation from ``env.reset()``.

    Gymnasium returns ``(observation, info)``; the legacy gym API returns only the
    observation. Both forms are accepted, mirroring the ``env.step`` handling below.
    """
    if (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        return reset_result[0]
    return reset_result


def _unpack_step(step_result):
    """Normalize ``env.step`` output to ``(next_state, reward, done, info)``.

    Accepts the legacy 4-tuple or the Gymnasium 5-tuple; in the latter the episode
    is done when it is either terminated or truncated.
    """
    if len(step_result) == 4:
        return step_result
    next_state, reward, terminated, truncated, info = step_result
    return next_state, reward, terminated or truncated, info


random_cat = random.choice(list(gen.categories.keys())[:3])
env = gen.generate_env(random_cat)
agent = TabularQLearningAgent(
    statespace=[low, high],
    num_actions=num_actions,
    # `actionspace` was never defined in this notebook (NameError on a fresh kernel);
    # the per-state action lists from the setup cell are presumably the intended
    # value — confirm against TabularQLearningAgent's signature.
    actionspace=actionspace_dict,
    lr=lr,
    gamma=gamma,
    epsilon=epsilon,
    force_mag=force_mag
)

episode_rewards = []  # undiscounted total reward of each episode (e.g. for plotting)

try:
    for episode in range(episodes):
        state = _reset_state(env.reset())
        total_reward = 0
        done = False
        small_error_count = 0

        while not done:
            # Choose an action index (0..num_actions-1) and map it to its force value
            a = agent.choose_action(state)
            action = agent.discrete_actions[a]

            next_state, reward, done, info = _unpack_step(env.step(action))

            # Update the Q-table; the agent returns the TD error of this update
            td_error = agent.update(state, a, reward, next_state, done)

            state = next_state
            total_reward += reward

            # Track a streak of near-zero TD errors; any larger error resets it
            if abs(td_error) < min_td_error:
                small_error_count += 1
            else:
                small_error_count = 0

            # End the episode early once updates have been negligible for long enough
            if small_error_count >= consecutive_small_errors:
                done = True
                if episode % 1000 == 0:  # Print only occasionally
                    print(f"Episode {episode} stopped early due to small TD error")

        episode_rewards.append(total_reward)

        # Decrease exploration every 100 episodes
        if episode % 100 == 0:
            agent.decrease_epsilon()
finally:
    # Release the environment even if training is interrupted (e.g. kernel interrupt)
    env.close()

# Append this run's greedy action index (argmax over the Q row) for every state.
# Re-running this cell appends again; the following cell de-duplicates.
for state_tuple in agent.disc.get_all_discrete_states():
    actionset_dict[state_tuple].append(np.argmax(agent.Q[state_tuple]))

Continuous action space discretized into 101 actions:
Discrete actions: [-100.  -98.  -96.  -94.  -92.  -90.  -88.  -86.  -84.  -82.  -80.  -78.
  -76.  -74.  -72.  -70.  -68.  -66.  -64.  -62.  -60.  -58.  -56.  -54.
  -52.  -50.  -48.  -46.  -44.  -42.  -40.  -38.  -36.  -34.  -32.  -30.
  -28.  -26.  -24.  -22.  -20.  -18.  -16.  -14.  -12.  -10.   -8.   -6.
   -4.   -2.    0.    2.    4.    6.    8.   10.   12.   14.   16.   18.
   20.   22.   24.   26.   28.   30.   32.   34.   36.   38.   40.   42.
   44.   46.   48.   50.   52.   54.   56.   58.   60.   62.   64.   66.
   68.   70.   72.   74.   76.   78.   80.   82.   84.   86.   88.   90.
   92.   94.   96.   98.  100.]

--- Q-update Debug ---
State: (0, 0, 0, 0), Action: 4, Next State: (0, 1, 2, 3)
Current Q: 0.000
Target: 1.000
TD Error: 1.000
Update amount: 0.100
New Q: 0.100
Max Q in next state: 0.000

--- Q-update Debug ---
State: (0, 1, 1, 3), Action: 98, Next State: (0, 1, 1, 2)
Current Q: 0.982
Target: 1.000
TD Error: 


--- Q-update Debug ---
State: (0, 1, 4, 4), Action: 18, Next State: (0, 0, 4, 4)
Current Q: 1.127
Target: 1.000
TD Error: -0.127
Update amount: -0.013
New Q: 1.114
Max Q in next state: 1.057

--- Q-update Debug ---
State: (0, 1, 4, 4), Action: 18, Next State: (0, 0, 4, 4)
Current Q: 1.134
Target: 1.000
TD Error: -0.134
Update amount: -0.013
New Q: 1.120
Max Q in next state: 1.141

--- Q-update Debug ---
State: (0, 1, 1, 1), Action: 0, Next State: (0, 1, 1, 2)
Current Q: 1.001
Target: 1.000
TD Error: -0.001
Update amount: -0.000
New Q: 1.001
Max Q in next state: 1.000

--- Q-update Debug ---
State: (0, 1, 4, 3), Action: 65, Next State: (0, 3, 4, 1)
Current Q: 1.235
Target: 1.000
TD Error: -0.235
Update amount: -0.023
New Q: 1.211
Max Q in next state: 1.269

--- Q-update Debug ---
State: (0, 2, 1, 2), Action: 83, Next State: (0, 3, 1, 1)
Current Q: 1.000
Target: 1.000
TD Error: 0.000
Update amount: 0.000
New Q: 1.000
Max Q in next state: 1.099

--- Q-update Debug ---
State: (0, 0, 3, 4)

In [17]:
# Collapse repeated greedy-action entries (e.g. from re-running the training cell)
# into one entry per distinct action index; the existing dict is updated in place.
actionset_dict.update(
    {state_key: list(set(action_indices)) for state_key, action_indices in actionset_dict.items()}
)

In [18]:
# Display the greedy action indices collected per discrete state (one key per state — large output)
actionset_dict

{(0, 0, 0, 0): [np.int64(4)],
 (0, 0, 0, 1): [np.int64(15)],
 (0, 0, 0, 2): [np.int64(0)],
 (0, 0, 0, 3): [np.int64(0)],
 (0, 0, 0, 4): [np.int64(75)],
 (0, 0, 1, 0): [np.int64(3)],
 (0, 0, 1, 1): [np.int64(0)],
 (0, 0, 1, 2): [np.int64(17)],
 (0, 0, 1, 3): [np.int64(64)],
 (0, 0, 1, 4): [np.int64(81)],
 (0, 0, 2, 0): [np.int64(92)],
 (0, 0, 2, 1): [np.int64(56)],
 (0, 0, 2, 2): [np.int64(29)],
 (0, 0, 2, 3): [np.int64(0)],
 (0, 0, 2, 4): [np.int64(0)],
 (0, 0, 3, 0): [np.int64(0)],
 (0, 0, 3, 1): [np.int64(0)],
 (0, 0, 3, 2): [np.int64(22)],
 (0, 0, 3, 3): [np.int64(28)],
 (0, 0, 3, 4): [np.int64(68)],
 (0, 0, 4, 0): [np.int64(69)],
 (0, 0, 4, 1): [np.int64(0)],
 (0, 0, 4, 2): [np.int64(8)],
 (0, 0, 4, 3): [np.int64(0)],
 (0, 0, 4, 4): [np.int64(94)],
 (0, 1, 0, 0): [np.int64(50)],
 (0, 1, 0, 1): [np.int64(35)],
 (0, 1, 0, 2): [np.int64(5)],
 (0, 1, 0, 3): [np.int64(0)],
 (0, 1, 0, 4): [np.int64(29)],
 (0, 1, 1, 0): [np.int64(57)],
 (0, 1, 1, 1): [np.int64(0)],
 (0, 1, 1, 2): [np.int6