Esta notebook contiene bloques de código útiles para realizar Q-learning en el entorno "Pendulum"

In [28]:
import numpy as np
from pendulum_env_extended import PendulumEnvExtended
import random 

In [29]:
# Module-level hyperparameters and discretization settings.
# NOTE(review): train_policy reads alpha, gamma, epsilon and the epsilon decay
# from the wandb run config, so the values below are not used by it.
alpha = 0.1  # learning rate of the Q update
gamma = 0.99  # discount factor
epsilon = 0.1  # exploration probability for the epsilon-greedy policy
n_episodes = 100  # NOTE(review): not referenced in the visible cells - confirm it is still needed
epsilon_variability = 0.8  # presumably the fractional epsilon decay - confirm
cant_Buckets = 10  # number of grid points for x, y and the action range

In [30]:
# Shared environment instance used by every cell below; created with the
# 'rgb_array' render mode.
env = PendulumEnvExtended(render_mode='rgb_array')

Discretización de los estados

In [31]:
# Bin edges used by get_state to discretize each observation component.
x_space = np.linspace(-1, 1, cant_Buckets)  # edges for the first component (presumably cos(theta) - confirm)
y_space = np.linspace(-1, 1, cant_Buckets)  # edges for the second component (presumably sin(theta) - confirm)
vel_space = np.linspace(-8, 8, 100)  # edges for the velocity; NOTE(review): hard-coded 100 instead of cant_Buckets
x_space  # display the edges

array([-1.        , -0.77777778, -0.55555556, -0.33333333, -0.11111111,
        0.11111111,  0.33333333,  0.55555556,  0.77777778,  1.        ])

Obtener el estado a partir de la observación

In [32]:
def get_state(obs):
    """Convert a continuous observation into a tuple of discrete bin indices.

    Args:
        obs: iterable with exactly three components ``(x, y, vel)``.

    Returns:
        A 3-tuple ``(x_bin, y_bin, vel_bin)`` where each entry is the
        ``np.digitize`` index of the component against ``x_space``,
        ``y_space`` and ``vel_space`` respectively.
    """
    x, y, vel = obs
    components_and_edges = ((x, x_space), (y, y_space), (vel, vel_space))
    return tuple(np.digitize(value, edges) for value, edges in components_and_edges)

In [33]:
# Sanity check of the discretization on a hand-written observation.
state = get_state(np.array([-0.4, 0.2, 0.3])) # state mapping: tells us which bin we are in
state

(3, 6, 52)

Discretización de las acciones

In [34]:
# Discrete action set: evenly spaced values between -2 and 2.
actions = list(np.linspace(-2, 2, cant_Buckets)) # minimum, maximum and number of bins
# Same grid kept as an ndarray; used by getActions as np.digitize bin edges.
actionBuckets=np.linspace(-2, 2, cant_Buckets)
actions

[-2.0,
 -1.5555555555555556,
 -1.1111111111111112,
 -0.6666666666666667,
 -0.22222222222222232,
 0.22222222222222232,
 0.6666666666666665,
 1.1111111111111107,
 1.5555555555555554,
 2.0]

In [35]:
def getActions(action, buckets=None):
    """Return the 0-based index of the action bucket that contains `action`.

    `np.digitize` returns 1-based positions (a value equal to the first edge
    maps to 1 and a value equal to the last edge maps to ``len(buckets)``).
    Used directly as an index, that is shifted by one with respect to
    ``actions[i]`` and overflows the last axis of the Q-table for the largest
    action. The result is therefore shifted to 0-based and clipped to the
    valid range ``[0, len(buckets) - 1]``.

    Args:
        action: action value (scalar or array-like) to locate.
        buckets: increasing bucket edges; defaults to the module-level
            ``actionBuckets``.

    Returns:
        Integer index (or array of indices) such that, for a value taken from
        the bucket grid itself, ``buckets[index] == action``.
    """
    if buckets is None:
        buckets = actionBuckets
    zero_based = np.digitize(action, buckets) - 1
    # Values outside the grid are mapped to the nearest end bucket.
    return np.clip(zero_based, 0, len(buckets) - 1)

In [36]:
def get_sample_action():
    """Draw one action uniformly at random from the discrete `actions` list."""
    sampled_action = random.choice(actions)
    return sampled_action

Inicialización de la tabla Q

In [37]:
# Q-table initialised to zero: one axis per state component plus a final axis
# for the action index. Each state axis has len(edges) + 1 entries because
# get_state indexes it with np.digitize results, which range from 0 to
# len(edges) inclusive.
Q = np.zeros((len(x_space) + 1, len(y_space) + 1, len(vel_space) + 1, len(actions)))
Q

array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
    

Obtención de la acción a partir de la tabla Q

In [38]:
def optimal_policy(state, Q):
    """Return the greedy action for `state` according to the table `Q`.

    Args:
        state: tuple of bin indices selecting the row of action values.
        Q: Q-table whose last axis is indexed like the `actions` list.

    Returns:
        The entry of `actions` at the position of the largest Q-value
        (the first such position on ties, as given by ``np.argmax``).
    """
    best_index = np.argmax(Q[state])
    return actions[best_index]

Epsilon-Greedy Policy

In [39]:
def epsilon_greedy_policy(state, Q, epsilon=0.1):
    """Choose an action with an epsilon-greedy rule.

    With probability `epsilon` a random action is sampled (and 'explore' is
    printed); otherwise the greedy action for `state` under `Q` is returned
    (and 'exploit' is printed).

    Args:
        state: tuple of bin indices for the current state.
        Q: Q-table used for the greedy choice.
        epsilon: probability of exploring.

    Returns:
        The selected action value.
    """
    # A single Bernoulli draw decides between exploring and exploiting.
    if np.random.binomial(1, epsilon) == 1:
        chosen = get_sample_action()
        print('explore')
        return chosen

    chosen = optimal_policy(state, Q)
    print('exploit')
    return chosen

Ejemplo de episodio 

In [40]:
# Install pygame into the kernel's environment (presumably required for
# env.render() - confirm). The kernel may need a restart afterwards.
%pip install pygame

Collecting pygameNote: you may need to restart the kernel to use updated packages.

  Downloading pygame-2.5.2-cp39-cp39-win_amd64.whl.metadata (13 kB)
Downloading pygame-2.5.2-cp39-cp39-win_amd64.whl (10.8 MB)
   ---------------------------------------- 0.0/10.8 MB ? eta -:--:--
    --------------------------------------- 0.2/10.8 MB 6.6 MB/s eta 0:00:02
   - -------------------------------------- 0.5/10.8 MB 6.0 MB/s eta 0:00:02
   -- ------------------------------------- 0.7/10.8 MB 6.1 MB/s eta 0:00:02
   --- ------------------------------------ 0.9/10.8 MB 5.4 MB/s eta 0:00:02
   ---- ----------------------------------- 1.1/10.8 MB 5.5 MB/s eta 0:00:02
   ----- ---------------------------------- 1.5/10.8 MB 6.1 MB/s eta 0:00:02
   ------ --------------------------------- 1.7/10.8 MB 5.9 MB/s eta 0:00:02
   ------- -------------------------------- 2.0/10.8 MB 6.1 MB/s eta 0:00:02
   -------- ------------------------------- 2.3/10.8 MB 6.4 MB/s eta 0:00:02
   --------- -------------

In [41]:
# Run one demonstration episode with an epsilon-greedy policy (epsilon = 0.5).
# Q is not updated here; the loop only shows the agent/environment interaction.
obs, _ = env.reset()
print(obs)
done = False
total_reward = 0
while not done:
    state = get_state(obs)
    print('state', state)
    action = epsilon_greedy_policy(state, Q, 0.5)
    # Action in the form passed to the environment: a 1-element array.
    real_action = np.array([action])

    # step() returns (obs, reward, terminated, truncated, info). The episode
    # is over when either flag is set; checking only `terminated` can loop
    # forever on environments that end through a time limit (`truncated`).
    obs, reward, terminated, truncated, _ = env.step(real_action)
    done = terminated or truncated

    total_reward += reward
    print('->', state, action, reward, obs, done)
    env.render()
print('total_reward', total_reward)

[ 0.4278568   0.90384656 -0.07612711]
state (7, 9, 50)
explore
-> (7, 9, 50) 1.5555555555555554 -1.2769089873511925 [0.38975513 0.9209185  0.8350911 ] False


state (7, 9, 55)
exploit
-> (7, 9, 55) -2.0 -1.443645598755916 [0.33261648 0.9430622  1.22578   ] False
state (6, 9, 58)
explore
-> (6, 9, 58) -0.22222222222222232 -1.667436318430144 [0.24167293 0.9703578  1.8997433 ] False
state (6, 9, 62)
explore
-> (6, 9, 62) -1.1111111111111112 -2.1222879818152647 [0.12075188 0.9926827  2.460845  ] False
state (6, 9, 65)
exploit
-> (6, 9, 65) -2.0 -2.7113481173384524 [-0.02421821  0.9997067   2.9053571 ] False
state (5, 9, 68)
explore
-> (5, 9, 68) 1.1111111111111107 -3.389423430704769 [-0.21365103  0.97691005  3.8218038 ] False
state (4, 9, 74)
exploit
-> (4, 9, 74) -2.0 -4.654796888412884 [-0.41508394  0.9097831   4.254486  ] False
state (3, 9, 76)
exploit
-> (3, 9, 76) -2.0 -5.809392284247679 [-0.6130191  0.7900681  4.6368237] False
state (2, 9, 79)
explore
-> (2, 9, 79) -0.6666666666666667 -7.126357886803334 [-0.79338205  0.60872406  5.1293745 ] False
state (1, 8, 82)
exploit
-> (1, 8, 82) -2.0 -8.820920330467734 [-0.9248499  0.3803322  5.28591

In [42]:
%pip install sklearn

Collecting sklearnNote: you may need to restart the kernel to use updated packages.

  Using cached sklearn-0.0.post12.tar.gz (2.6 kB)
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'error'


  error: subprocess-exited-with-error
  
  × python setup.py egg_info did not run successfully.
  │ exit code: 1
  ╰─> [15 lines of output]
      The 'sklearn' PyPI package is deprecated, use 'scikit-learn'
      rather than 'sklearn' for pip commands.
      
      Here is how to fix this error in the main use cases:
      - use 'pip install scikit-learn' rather than 'pip install sklearn'
      - replace 'sklearn' by 'scikit-learn' in your pip requirements files
        (requirements.txt, setup.py, setup.cfg, Pipfile, etc ...)
      - if the 'sklearn' package is used by one of your dependencies,
        it would be great if you take some time to track which package uses
        'sklearn' instead of 'scikit-learn' and report it to their issue tracker
      - as a last resort, set the environment variable
        SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True to avoid this error
      
      More information is available at
      https://github.com/scikit-learn/sklearn-pypi-packag

In [43]:
import wandb
from wandb.sklearn import plot_precision_recall, plot_feature_importances
from wandb.sklearn import plot_class_proportions, plot_learning_curve, plot_roc

def train_policy():
    """Run one wandb-tracked Q-learning training run on the global `env` and `Q`.

    Hyperparameters are read from the wandb run config (`alpha`, `gamma`,
    `epsilon`, `epsilonVariability`, `episodes`), so this function is meant to
    be passed to `wandb.agent` as the sweep function. The module-level table
    `Q` is updated in place.

    After every episode it logs:
        - 'train_avg_reward': mean of the last 10 step rewards,
        - 'total_reward': sum of the rewards of the episode,
        - 'steps': number of steps of the episode.

    Any exception is caught and printed; the wandb run is always closed.
    """
    try:
        with wandb.init() as run:
            config = run.config
            alpha = config.alpha
            gamma = config.gamma
            epsilon = config.epsilon
            epsilon_variability = config.epsilonVariability
            episodes = config.episodes
            total_rewards = []
            total_reward_promedio = []

            # The training loop has to stay inside the `with` block: leaving
            # it finishes the run, and wandb.log needs an active run.
            for episode in range(episodes):
                obs, _ = env.reset()
                done = False
                total_reward = 0
                step_count = 0

                while not done:
                    state = get_state(obs)
                    # Multiplicative epsilon decay, applied on every step.
                    # NOTE(review): with the sweep's decay range epsilon is
                    # ~0 after a few dozen steps - confirm a per-episode
                    # decay was not intended.
                    epsilon = max(epsilon - epsilon_variability * epsilon, 0)
                    action = epsilon_greedy_policy(state, Q, epsilon)
                    # 0-based position of the action on Q's last axis (the
                    # same indexing optimal_policy uses to read the table).
                    action_idx = actions.index(action)

                    # The episode ends on either flag; time-limited
                    # environments only ever set `truncated`.
                    obs, reward, terminated, truncated, _ = env.step([action])
                    done = terminated or truncated
                    total_reward_promedio.append(reward)

                    # Q-learning update for the (state, action) just taken.
                    next_state = get_state(obs)
                    x_bin, y_bin, vel_bin = state
                    td_target = reward + gamma * np.max(Q[next_state])
                    td_error = td_target - Q[x_bin, y_bin, vel_bin, action_idx]
                    Q[x_bin, y_bin, vel_bin, action_idx] += alpha * td_error

                    total_reward += reward
                    step_count += 1

                total_rewards.append(total_reward)
                # NOTE(review): this averages the last 10 *step* rewards, not
                # the last 10 episode totals - confirm which one the sweep
                # metric should track.
                last_rewards_mean = np.mean(total_reward_promedio[-10:])
                # wandb.log only accepts a dict of metrics.
                wandb.log({
                    'train_avg_reward': last_rewards_mean,
                    'total_reward': total_reward,
                    'steps': step_count,
                })
    except Exception as e:
        print(f"An error occurred: {e}")
    finally:
        # Make sure the run is closed even if wandb.init itself failed midway.
        wandb.finish()

In [44]:
import wandb

# Sweep configuration
# Random search over the hyperparameters read by train_policy, maximizing the
# 'train_avg_reward' metric that train_policy logs.
sweep_configuration = {
    "method": "random",
    "metric": {"goal": "maximize", "name": 'train_avg_reward'},
    "parameters": {
        # NOTE(review): the range 9999-10000 effectively fixes the episode count.
        "episodes": {"max": 10000, "min": 9999},
        "alpha": {"max": 0.9, "min": 0.1},
        "gamma": {"max": 0.9, "min": 0.1},
        "epsilon": {"max": 0.8, "min": 0.1},
        # Read by train_policy as config.epsilonVariability (epsilon decay fraction).
        "epsilonVariability": {"max": 0.5, "min": 0.35}
    },
}

# Register the sweep and start an agent that calls train_policy for each run.
# NOTE(review): no run count is passed to wandb.agent - confirm how the sweep
# is expected to stop.
sweep_id = wandb.sweep(sweep=sweep_configuration, project="pendulum-sweep")
wandb.agent(sweep_id, function=train_policy)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

In [None]:
# Direct call outside of wandb.agent.
# NOTE(review): train_policy reads its hyperparameters from run.config -
# confirm that a config is available when it is not launched by a sweep agent.
train_policy()

exploit
-2.0
[-0.99688566  0.07886058 -0.5943711 ]
-9.419153350359384
-2.0
exploit
-2.0
[-0.9927241   0.12041095 -0.83522564]
-9.199529458934128
-2.0
exploit
-2.0
[-0.98508143  0.17208879 -1.0449175 ]
-8.926026212108594
-2.0
exploit
-2.0
[-0.97280645  0.23161961 -1.2158508 ]
-8.607423329858042
-2.0
exploit
-2.0
[-0.9550852   0.29633123 -1.3421361 ]
-8.253957161248922
-2.0
exploit
-2.0
[-0.93165916  0.3633335  -1.4198877 ]
-7.8771188479309675
-2.0
exploit
-2.0
[-0.9029492   0.42974722 -1.4473876 ]
-7.489351088238139
-2.0
exploit
-2.0
[-0.87006277  0.49294093 -1.4250772 ]
-7.103622174868753
-2.0
exploit
-2.0
[-0.83468527  0.55072725 -1.3553715 ]
-6.732895825007395
-2.0
exploit
-2.0
[-0.7988884  0.6014793 -1.242326 ]
-6.38955707974638
-2.0
exploit
-2.0
[-0.76489866  0.6441507  -1.0912166 ]
-6.084877962004481
-2.0
exploit
-2.0
[-0.7348726  0.6782052 -0.9081035]
-5.828599590981752
-2.0
exploit
-2.0
[-0.7107095   0.70348555 -0.69944966]
-5.628673064194676
-2.0
exploit
-2.0
[-0.69391686  0.72

exploit
-2.0
[-0.8892458   0.45742968 -1.4438069 ]
-7.3226270467856756
-2.0
exploit
-2.0
[-0.85505503  0.5185372  -1.4007347 ]
-6.9417756108692945
-2.0
exploit
-2.0
[-0.81922907  0.5734664  -1.3118317 ]
-6.5813513804282096
-2.0
exploit
-2.0
[-0.78393495  0.62084293 -1.1817319 ]
-6.253321387217598
-2.0
exploit
-2.0
[-0.7513951   0.65985256 -1.0160997 ]
-5.9683423661527994
-2.0
exploit
-2.0
[-0.72367555  0.69014037 -0.8212103 ]
-5.7354085447038035
-2.0
exploit
-2.0
[-0.70252055  0.7116635  -0.60360503]
-5.561662118987371
-2.0
exploit
-2.0
[-0.68924046  0.72453266 -0.3698574 ]
-5.452343173055207
-2.0
exploit
-2.0
[-0.68464553  0.7288762  -0.12645788]
-5.410822537035135
-2.0
exploit
-2.0
[-0.68901366  0.7247483   0.12019923]
-5.438649565735088
-2.0
exploit
-2.0
[-0.7020807   0.71209735  0.36376047]
-5.535557541584369
-2.0
exploit
-2.0
[-0.7230497  0.690796   0.5978335]
-5.699397134451752
-2.0
exploit
-2.0
[-0.75062233  0.6607315   0.8159305 ]
-5.926005392019152
-2.0
exploit
-2.0
[-0.783064