In [1]:
import pybullet as p
import time
import numpy as np
import pybullet_data 
import matplotlib.pyplot as plt
from collections import Counter
import random
from Env import Pickup_Bot_Env
# from agent import Agent
import os
import math

pybullet build time: Dec 17 2023 23:51:54


# Temporal Difference 

## SARSA Implementation

### Components

1. **Action-to-Index Mapping:**
   - Defines actions and their corresponding indices for reference.
   - Facilitates action selection and indexing in the Q-table.

2. **Q-Table (Q):**
   - Stores Q-values representing expected cumulative rewards for state-action pairs.
   - Enables the agent to estimate the value of taking specific actions in different states.
   - Key component of the SARSA learning algorithm.

### Training Loop

1. **Episode Iteration:**
   - The training loop iterates over a series of episodes, each representing an attempt by the agent to learn and improve its behavior.

2. **Initialization and Environment Setup:**
   - At the start of each episode, the environment is initialized.
   - The current state `S_t` is obtained from the environment using `get_current_state()`.

3. **Action Selection and Execution:**
   - Based on the current state `S_t` and the epsilon-greedy policy, an action `A_t` is chosen for the agent to perform.
   - The action `A_t` is taken in the environment using `step()`, and the agent receives a reward.

4. **State Transition and Next Action Selection:**
   - The environment transitions to the next state `S_t1`.
   - The next action `A_t1` is selected for the new state `S_t1` using the epsilon-greedy policy.

5. **Q-Value Update (SARSA):**
   - The Q-value for the current state-action pair `(S_t, A_t)` is updated using the SARSA (State-Action-Reward-State-Action) algorithm.
   - Q-values are updated based on the received reward, the next state `S_t1`, and the next action `A_t1`.
   - The update step involves:
     
     Q[S<sub>t</sub>][A<sub>t</sub>] += α * (R<sub>t+1</sub> + Q[S<sub>t+1</sub>][A<sub>t+1</sub>] - Q[S<sub>t</sub>][A<sub>t</sub>])

     

6. **Epsilon Decay and Saving:**
   - Periodically, the epsilon value is decayed to shift the agent's focus from exploration to exploitation.
   - Q-values are saved periodically to track the agent's learning progress and to enable the continuation of training from saved checkpoints.



In [2]:
iterations = 1000
gamma = 0.9
epsilon = 0.1
alpha = 0.1

In [1]:
action_to_index = {
    'move_down' : 0,
    'move_up' : 1 ,
    'move_left' : 2 ,
    'move_right' :3,
    'close_gripper' : 4,
    'open_gripper' :5
}
Q = {
    (0, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (0, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.11, 0.04, 0.04): [0, 0, 0, 0, 0, 0],
    (-1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (0, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (0, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (3.14, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (3.14, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (-1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
    (1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
    (1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50]
}


In [4]:
for episode in range(iterations):

    
    print('#################################################')
    print("Current Episode : ", episode)
    path_to_bot = './bot/robot.urdf'
    
    env = Pickup_Bot_Env(path_to_bot, False)
    S_t = env.get_current_state()
    
    action = env.choose_action(Q[S_t] , epsilon)
   
    A_t = action_to_index[action]
    while env.rounded_position!= env.terminal_state:
        
        env.step(action)
        reward = env.get_reward()
        
        S_t1 = env.rounded_position
        action = env.choose_action(Q[S_t1] , epsilon)
      

        A_t1 = action_to_index[action]
        
        Q[S_t][A_t] += alpha*(reward + Q[S_t1][A_t1] - Q[S_t][A_t])
        print("State | Action | Reward" , S_t , A_t , reward)
        print("State_next | Action_next" , S_t1 ,A_t1 )
        print("\n")
        S_t = S_t1
        A_t = A_t1
        
    if episode % 10 == 0:
        
        os.mkdir('./save_SARSA/{}'.format(episode))
        np.save('./save_SARSA/{}/Q.npy'.format(episode), Q)
        
    env.reset_env()
    

#################################################
Current Episode :  0
False
State | Action | Reward (0, 0.32, -0.84, -0.84) 0 -1
State_next | Action_next (0, 0.11, -0.84, -0.84) 5


State | Action | Reward (0, 0.11, -0.84, -0.84) 5 -1
State_next | Action_next (0, 0.11, -0.84, -0.84) 0


State | Action | Reward (0, 0.11, -0.84, -0.84) 0 -1
State_next | Action_next (0, 0.11, -0.84, -0.84) 0


State | Action | Reward (0, 0.11, -0.84, -0.84) 0 -1
State_next | Action_next (0, 0.11, -0.84, -0.84) 1


State | Action | Reward (0, 0.11, -0.84, -0.84) 1 -1
State_next | Action_next (0, 0.32, -0.84, -0.84) 1


State | Action | Reward (0, 0.32, -0.84, -0.84) 1 -1
State_next | Action_next (0, 0.53, -0.84, -0.84) 0


State | Action | Reward (0, 0.53, -0.84, -0.84) 0 -1
State_next | Action_next (0, 0.32, -0.84, -0.84) 2


State | Action | Reward (0, 0.32, -0.84, -0.84) 2 -1
State_next | Action_next (1.57, 0.32, -0.84, -0.84) 4


State | Action | Reward (1.57, 0.32, -0.84, -0.84) 4 -10
State_next | Ac

## <div align="center">Final Q values</div>
<br>

| States  | Go Down | Go Up | Go Left | Go Right | Close Gripper | Open Gripper |
|-----------------------|---------|-------|---------|----------|---------------|--------------|
| (0, 0.53, -0.84, -0.84) | 76.4583112773043 | 49.529 | 49.53629 | 49.93356115882323 | 48.09 | 49.49847809999999 |
| (0, 0.53, 0.04, 0.04) | 49.0 | 48.1 | 49.0 | 49.0 | 49.0 | 51.55609465134686 |
| (3.14, 0.53, -0.84, -0.84) | 83.78199100326314 | 55.800024503779795 | 49.6119 | 53.63515727454566 | 49.0 | 49.8 |
| (3.14, 0.53, 0.04, 0.04) | 49.0 | 49.43936972615094 | 49.0 | 48.81 | 48.1 | 71.04869151133106 |
| (-1.57, 0.53, -0.84, -0.84) | 49.78971418885 | 49.619 | 49.791 | 56.53534142174903 | 48.09 | 49.8 |
| (-1.57, 0.53, 0.04, 0.04) | 49.0 | 48.1 | 49.0 | 49.0 | 49.0 | 49.980006748666135 |
| (1.57, 0.53, -0.84, -0.84) | 49.699999999999996 | 49.619 | 76.4772562748358 | 51.36040720490163 | 48.01 | 50.24922201255391 |
| (1.57, 0.53, 0.04, 0.04) | 49.0 | 47.29 | 48.1 | 49.0 | 48.1 | 51.95033961443649 |
| (0, 0.11, -0.84, -0.84) | 49.620000000000005 | 79.13156407764221 | 51.68237760154969 | 49.8692700832317 | 48.1 | 49.6091 |
| (0, 0.11, 0.04, 0.04) | 47.29 | 49.0 | 49.5 | 49.5 | 48.09 | 52.652299021496084 |
| (3.14, 0.11, -0.84, -0.84) | 53.168997555 | 55.85480842588169 | 49.8 | 49.89 | 94.52905054342438 | 50 |
| (3.14, 0.11, 0.04, 0.04) | 0 | 0 | 0 | 0 | 0 | 0 |
| (-1.57, 0.11, -0.84, -0.84) | 49.81 | 49.8 | 54.10817749738163 | 49.791 | 48.1 | 49.8 |
| (-1.57, 0.11, 0.04, 0.04) | 48.1 | 49.0 | 48.81 | 59.5 | 50 | 50 |
| (1.57, 0.11, -0.84, -0.84) | 50.65970588221661 | 78.07806265844451 | 54.17736754771449 | 49.79 | 49.0 | 49.8 |
| (1.57, 0.11, 0.04, 0.04) | 48.1 | 49.0 | 63.55 | 50 | 50 | 50 |
| (0, 0.32, -0.84, -0.84) | 68.89884080281884 | 66.40722072183331 | 81.830189684113 | 77.25040812168756 | 54.78553732861665 | 74.67936656825093 |
| (0, 0.32, 0.04, 0.04) | 48.1 | 48.0772 | 49.0 | 49.0 | 50.60874192585897 | 75.97252178811729 |
| (3.14, 0.32, -0.84, -0.84) | 78.42795175724228 | 72.05053072530097 | 80.18410447891276 | 76.5414456513163 | 88.12687341730425 | 80.88294250347288 |
| (3.14, 0.32, 0.04, 0.04) | 99.99999999999994 | 57.57028932541432 | 49.528352975934645 | 57.38954144643438 | 81.18718876806155 | 76.05437091177834 |
| (-1.57, 0.32, -0.84, -0.84) | 49.71894080562687 | 49.81 | 49.7739 | 84.77262971814878 | 47.94978222461721 | 49.799 |
| (-1.57, 0.32, 0.04, 0.04) | 49.0 | 49.0 | 49.0 | 49.95 | 47.71386100312606 | 68.84878067371608 |
| (1.57, 0.32, -0.84, -0.84) | 64.59191270799717 | 63.510722075581064 | 84.59804205776545 | 69.68589580979621 | 52.42231368978855 | 77.27758350658016 |
| (1.57, 0.32, 0.04, 0.04) | 49.05 | 49.0 | 48.1 | 49.593177122415376 | 49.96354113568732 | 78.25979302891598 |


values are for representation, the code cell above ran again later and changed values

In [2]:
def convert_Q_to_policy(Q):
    policy = {}
    for key in Q:
        max_val= max(Q[key])
        policy_arr = [0,0,0,0,0,0]
        for j in range(len(Q[key])):
            if Q[key][j]==max_val:
                policy_arr[j]=1
    
        policy[key] = policy_arr

    return policy

In [3]:
import pybullet as p
import time
import numpy as np
import pybullet_data 
import matplotlib.pyplot as plt
from Env import Pickup_Bot_Env

import math
#results 
Q = np.load('./save_SARSA/990/Q.npy' , allow_pickle=True)
path_to_bot = './bot/robot.urdf'

policy = convert_Q_to_policy(Q.item())

    # Create environment instance
position = (0, 0.54, -0.84, -0.84) # Starting state

env = Pickup_Bot_Env(path_to_bot,position, True , False)


S_t = env.get_current_state()
    
action = env.choose_action(policy[S_t] , epsilon = 0)
   
A_t = action_to_index[action] # as we have saved a .npy the new file needs to be converted back to a dictionary
while env.rounded_position != env.terminal_state:
    env.step(action)
    reward = env.get_reward()
    
    S_t1 = env.rounded_position
    action = env.choose_action(policy[S_t1] , epsilon = 0)
  

    A_t1 = action_to_index[action]
    
    
    print("State | Action | Reward" , S_t , A_t , reward)
    print("State_next | Action_next" , S_t1 ,A_t1 )
    print("\n")
    S_t = S_t1
    A_t = A_t1


env.reset_env()
print("Reached")

pybullet build time: Dec 17 2023 23:51:54


(0, 0.54, -0.84, -0.84)
State | Action | Reward (0, 0.53, -0.84, -0.84) 0 -1
State_next | Action_next (0, 0.11, -0.84, -0.84) 3


State | Action | Reward (0, 0.11, -0.84, -0.84) 3 -1
State_next | Action_next (-1.57, 0.11, -0.84, -0.84) 3


State | Action | Reward (-1.57, 0.11, -0.84, -0.84) 3 -1
State_next | Action_next (3.14, 0.11, -0.84, -0.84) 4


State | Action | Reward (3.14, 0.11, -0.84, -0.84) 4 100
State_next | Action_next (3.14, 0.11, 0.04, 0.04) 0


Reached
