In [1]:
import pybullet as p
import time
import numpy as np
import pybullet_data 
import matplotlib.pyplot as plt
from collections import Counter
import random
from Env import Pickup_Bot_Env
# from agent import Agent
import os
import math

pybullet build time: Dec 17 2023 23:51:54


# Double Q-Learning Algorithm for Reinforcement Learning

The provided code snippet implements a Double Q-Learning algorithm for reinforcement learning. Here's a breakdown of the key components:

- **`action_to_index`**: Maps actions to corresponding indices for Q-value storage.
- **`Q1` and `Q2`**: Dictionaries storing Q-values for state-action pairs, each representing an independent Q-table.
- **`pick_Q()`**: Randomly selects either Q1 or Q2 with equal probability.
- **`average_Q1_Q2(Q1, Q2)`**: Computes the average Q-values between Q1 and Q2 for each state-action pair.
- **`for` loop**: Iterates over a defined number of episodes.
- **Episode Execution**:
  - Initializes the environment and the current state.
  - Computes the final Q-values by averaging Q1 and Q2.
  - Selects an action based on the current state and exploration rate (epsilon).
  - Executes actions in the environment until reaching a terminal state.
  - Updates Q-values based on rewards and transitions between states using Double Q-Learning logic.
  - Adjusts epsilon every 50 episodes and saves the final Q-values to disk.
  - Resets the environment for the next episode.



In [2]:
iterations = 1000
gamma = 0.9
epsilon = 0.3 #had to tinker a bit so that exploration is appreciated 
alpha = 0.1

In [3]:
action_to_index = {
    'move_down' : 0,
    'move_up' : 1 ,
    'move_left' : 2 ,
    'move_right' :3,
    'close_gripper' : 4,
    'open_gripper' :5
}
Q1 = {
    (0, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (0, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.11, 0.04, 0.04): [0, 0, 0, 0, 0, 0],
    (-1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (0, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50]
}

Q2 = {
    (0, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (0, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.11, 0.04, 0.04): [0, 0, 0, 0, 0, 0],
    (-1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (0, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50]
}


In [4]:
def pick_Q():
    choose = random.choices(['Q1' , 'Q2'] , weights=[0.5 , 0.5])
    return choose[0]

def average_Q1_Q2(Q1,Q2):
    Q = {}
    for state in Q1:
        Q[state] = [(g + h) / 2 for g, h in zip(Q1[state], Q2[state])]
    return Q

In [5]:
for episode in range(iterations):
    print('#################################################')
    print("Current Episode : ", episode)
    path_to_bot = './bot/robot.urdf'
    
    env = Pickup_Bot_Env(path_to_bot, False)
    S_t = env.get_current_state()

    final_Q = average_Q1_Q2(Q1,Q2)
    action = env.choose_action(final_Q[S_t] , epsilon)
   
    A_t = action_to_index[action]
    while env.rounded_position!= env.terminal_state:
        
        env.step(action)
        reward = env.get_reward()
        
        S_t1 = env.rounded_position
        action = env.choose_action(final_Q[S_t1] , epsilon)
      

        A_t1 = action_to_index[action]

        
        if pick_Q() == 'Q1': #which value to update
            
            max_idx = np.argmax(Q1[S_t1])
        
            Q1[S_t][A_t] += alpha*(reward + Q2[S_t1][max_idx] - Q1[S_t][A_t]) #The update will now happen with taking the max of all actions that are possible

        elif pick_Q() == 'Q2' :
            max_idx = np.argmax(Q2[S_t1])
        
            Q2[S_t][A_t] += alpha*(reward + Q1[S_t1][max_idx] - Q2[S_t][A_t])

        final_Q = average_Q1_Q2(Q1,Q2)
        print("State | Action | Reward" , S_t , A_t , reward)
        print("State_next | Action_next" , S_t1 ,A_t1 )
        print("\n")
        S_t = S_t1
        A_t = A_t1
        
    if episode % 50 == 0:
        epsilon = epsilon * (0.95 ** (episode / 10))
        os.mkdir('./save_DoubleQLearning/{}'.format(episode))
        np.save('./save_DoubleQLearning/{}/final_Q.npy'.format(episode), final_Q)
        
    env.reset_env()
    

#################################################
Current Episode :  0
State | Action | Reward (0, 0.32, -0.84, -0.84) 0 -1
State_next | Action_next (0, 0.11, -0.84, -0.84) 1


State | Action | Reward (0, 0.11, -0.84, -0.84) 1 -1
State_next | Action_next (0, 0.32, -0.84, -0.84) 1


State | Action | Reward (0, 0.32, -0.84, -0.84) 1 -1
State_next | Action_next (0, 0.53, -0.84, -0.84) 2


State | Action | Reward (0, 0.53, -0.84, -0.84) 2 -1
State_next | Action_next (1.57, 0.53, -0.84, -0.84) 4


State | Action | Reward (1.57, 0.53, -0.84, -0.84) 4 -10
State_next | Action_next (1.57, 0.53, 0.04, 0.04) 5


State | Action | Reward (1.57, 0.53, 0.04, 0.04) 5 -1
State_next | Action_next (1.57, 0.53, -0.84, -0.84) 0


State | Action | Reward (1.57, 0.53, -0.84, -0.84) 0 -1
State_next | Action_next (1.57, 0.32, -0.84, -0.84) 0


State | Action | Reward (1.57, 0.32, -0.84, -0.84) 0 -1
State_next | Action_next (1.57, 0.11, -0.84, -0.84) 4


State | Action | Reward (1.57, 0.11, -0.84, -0.84) 4 -10


## <div align="center">Final Averaged Q values</div>

| States (Index Values) | Go Down | Go Up | Go Left | Go Right | Close Gripper | Open Gripper |
|-----------------------|---------|-------|---------|----------|---------------|--------------|
| (0, 0.53, -0.84, -0.84) | 51.61411900969908 | 49.527541238626 | 49.499135499999994 | 49.52937849999999 | 49.05 | 49.549533997836 |
| (0, 0.53, 0.04, 0.04) | 48.945499999999996 | 49.5 | 48.945499999999996 | 49.405 | 49.394549999999995 | 49.9386 |
| (3.14, 0.53, -0.84, -0.84) | 49.85875 | 49.83955 | 49.8 | 49.652685500000004 | 48.55 | 49.7581 |
| (3.14, 0.53, 0.04, 0.04) | 49.75 | 48.9225 | 48.955 | 48.995000000000005 | 49.05 | 49.9 |
| (-1.57, 0.53, -0.84, -0.84) | 50.284244755610814 | 49.705 | 49.742173967322366 | 49.74369 | 49.5 | 49.72795 |
| (-1.57, 0.53, 0.04, 0.04) | 48.945 | 49.5 | 49.05 | 49.5 | 49.405 | 50.0 |
| (1.57, 0.53, -0.84, -0.84) | 49.7321915 | 49.65 | 49.7595 | 49.77527409315 | 48.166855 | 49.577495 |
| (1.57, 0.53, 0.04, 0.04) | 48.95 | 48.4505 | 49.5 | 48.53195 | 49.05 | 49.75405 |
| (0, 0.11, -0.84, -0.84) | 53.55948456223213 | 52.04981883629688 | 50.8214029702739 | 97.99999999579254 | 44.760594345 | 51.18351565672661 |
| (0, 0.11, 0.04, 0.04) | 48.3945 | 48.9495 | 49.5 | 49.05 | 48.34545 | 52.016327701139645 |
| (3.14, 0.11, -0.84, -0.84) | 61.82652363592791 | 51.314180899272856 | 57.76632261796594 | 50.36203227744541 | 99.99999999982893 | 65.09452662817178 |
| (3.14, 0.11, 0.04, 0.04) | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| (-1.57, 0.11, -0.84, -0.84) | 58.83111484913981 | 51.03267086694383 | 52.034578203169545 | 98.99999999917901 | 48.99866105 | 55.94593147152604 |
| (-1.57, 0.11, 0.04, 0.04) | 48.455 | 48.995000000000005 | 49.445 | 72.08948777500001 | 49.5 | 50.828593440079324 |
| (1.57, 0.11, -0.84, -0.84) | 49.8505 | 49.84363999999999 | 66.19997984531524 | 49.8595 | 48.95 | 49.834205 |
| (1.57, 0.11, 0.04, 0.04) | 48.95 | 49.5 | 56.775 | 50.0 | 50.0 | 49.95 |
| (0, 0.32, -0.84, -0.84) | 96.99999996115821 | 49.290769870779975 | 49.4418542073642 | 52.97477379033016 | 45.20595316226293 | 52.72766138186209 |
| (0, 0.32, 0.04, 0.04) | 48.394949999999994 | 48.894549999999995 | 49.45 | 49.445 | 49.5 | 50.47039887526957 |
| (3.14, 0.32, -0.84, -0.84) | 64.9855417791258 | 49.81715 | 50.340733672592016 | 49.81274795 | 49.7275 | 49.85095 |
| (3.14, 0.32, 0.04, 0.04) | 61.0975 | 49.5 | 50.0 | 50.0 | 50.0 | 49.945 |
| (-1.57, 0.32, -0.84, -0.84) | 64.52902927096216 | 49.708600000000004 | 49.7419115315 | 49.802005 | 48.445 | 50.820117661528286 |
| (-1.57, 0.32, 0.04, 0.04) | 49.5 | 49.45 | 49.45 | 49.0 | 48.945499999999996 | 50.26123001535448 |
| (1.57, 0.32, -0.84, -0.84) | 49.766624500000006 | 49.7795695 | 51.4727563688042 | 49.62817040635 | 49.5 | 49.629204276772505 |
| (1.57, 0.32, 0.04, 0.04) | 49.5 | 49.5 | 49.5 | 49.445 | 
