In [1]:
import gym
import numpy as np

# Create CartPole environment
env = gym.make('CartPole-v1')

# Define parameters
num_episodes = 1000
max_steps_per_episode = 500
learning_rate = 0.1
discount_rate = 0.99
epsilon = 0.1

# Discretize the observation space
num_bins = (6, 12, 6, 12)  # (cart position, cart velocity, pole angle, pole velocity)
num_actions = env.action_space.n

# Initialize Q-table with zeros
q_table = {}

def discretize(observation):
    print("Observation:", observation)
    # Rest of the function code...

# Helper function to select an action using epsilon-greedy policy
def choose_action(state):
    if np.random.uniform(0, 1) < epsilon:
        return np.random.choice(num_actions)
    else:
        return np.argmax(q_table[state])
    

# Q-learning algorithm
for episode in range(num_episodes):
    observation = env.reset()
    state = discretize(observation)
    done = False
    total_reward = 0

    for step in range(max_steps_per_episode):
        if state not in q_table:
            q_table[state] = np.zeros(num_actions)

        action = choose_action(state)
        next_observation, reward, done, _, _ = env.step(action)
        next_state = discretize(next_observation)

        if next_state not in q_table:
            q_table[next_state] = np.zeros(num_actions)

        # Update Q-value
        q_table[state][action] += learning_rate * (
            reward + discount_rate * np.max(q_table[next_state]) - q_table[state][action])

        total_reward += reward
        state = next_state

        if done:
            break

    if episode % 100 == 0:
        print(f"Episode: {episode}, Total Reward: {total_reward}")

print("Training finished.")


Observation: (array([-0.00768522, -0.00622621,  0.00232681,  0.04219859], dtype=float32), {})
Observation: [-0.00780974 -0.20138144  0.00317079  0.33561474]
Observation: [-0.01183737 -0.3965484   0.00988308  0.6292959 ]
Observation: [-0.01976834 -0.5918068   0.022469    0.9250749 ]
Observation: [-0.03160448 -0.78722495  0.0409705   1.2247334 ]
Observation: [-0.04734898 -0.9828499   0.06546516  1.5299665 ]
Observation: [-0.06700597 -1.1786972   0.09606449  1.8423413 ]
Observation: [-0.09057992 -1.3747389   0.13291131  2.1632473 ]
Observation: [-0.11807469 -1.5708876   0.17617626  2.4938364 ]
Observation: [-0.14949244 -1.7669799   0.226053    2.8349504 ]
Episode: 0, Total Reward: 9.0
Observation: (array([-0.00922083, -0.04188472,  0.02916505, -0.04310755], dtype=float32), {})
Observation: [-0.01005853 -0.23741248  0.0283029   0.25863266]
Observation: [-0.01480678 -0.4329268   0.03347555  0.56010664]
Observation: [-0.02346531 -0.62850225  0.04467768  0.8631454 ]
Observation: [-0.03603536 

  if not isinstance(terminated, (bool, np.bool8)):


Observation: (array([ 0.00153783,  0.03523894, -0.01818465,  0.0305198 ], dtype=float32), {})
Observation: [ 0.00224261 -0.15961757 -0.01757425  0.31741026]
Observation: [-0.00094974 -0.35448486 -0.01122604  0.6044995 ]
Observation: [-8.0394391e-03 -5.4944801e-01  8.6394453e-04  8.9362544e-01]
Observation: [-0.0190284  -0.7445817   0.01873645  1.1865798 ]
Observation: [-0.03392003 -0.9399415   0.04246805  1.4850763 ]
Observation: [-0.05271886 -1.1355547   0.07216958  1.7907133 ]
Observation: [-0.07542996 -1.3314079   0.10798384  2.104928  ]
Observation: [-0.10205812 -1.5274341   0.1500824   2.42894   ]
Observation: [-0.1326068  -1.7234949   0.19866121  2.763684  ]
Observation: [-0.16707669 -1.9193627   0.2539349   3.1097302 ]
Observation: (array([-0.02261535,  0.04988339, -0.03631738,  0.03124148], dtype=float32), {})
Observation: [-0.02161768 -0.14469944 -0.03569255  0.3122483 ]
Observation: [-0.02451167 -0.3392952  -0.02944759  0.5934647 ]
Observation: [-0.03129757 -0.5339928  -0.017