In [1]:
import pybullet as p
import time
import numpy as np
import pybullet_data 
import matplotlib.pyplot as plt
from collections import Counter
import random
from Env import Pickup_Bot_Env
# from agent import Agent
import os
import math

pybullet build time: Dec 17 2023 23:51:54


the main idea here is to only update the state-action pairs after some specifed(n) time steps, considering all the rewards followed untill n time steps 

# n-Step SARSA
### Components:

1. **Action to Index Mapping:**
   - Defines the mapping of actions to their corresponding index values.
   - Each action is associated with a unique index for representation and lookup.

2. **Q-Values Initialization:**
   - Initializes the Q-values table.
   - Each state-action pair is associated with an initial Q-value.

3. **Function to Calculate G:**
   - Computes the return G for a given set of states, actions, and rewards.
   - Considers the discounted sum of rewards from a specific time step to the horizon.

### Training Loop:

- **Outermost Loop over Episodes:**
  - Iterates through a 3 n values `[ 2 , 3 , 5 ]`.
  - Each `n` represents the timestep lookahead before updaring the Q value.

- **Outer Loop over Episodes:**
  - Iterates through a predefined number of training episodes.
  - Each episode represents a complete trial or interaction of the agent with the environment.

- **Initialization:**
  - Sets up the environment and initializes the current state (S_t).

- **Action Selection:**
  - Chooses an action based on the current state and the Q-values using an epsilon-greedy policy.

- **Temporal Difference Learning:**
  - Implements n-step SARSA algorithm for temporal difference learning.
  - Updates Q-values based on the observed rewards and transitions.
  - Handles episodes until a terminal state or maximum horizon is reached.

- **Saving Q-values:**
  - Periodically saves the updated Q-values during training for analysis or future use.





In [2]:
iterations = 1000
gamma = 0.9
epsilon = 0.1
alpha = 0.1
n_all=[2,3,5]

In [3]:
def init():
    action_to_index = {
        'move_down' : 0,
        'move_up' : 1 ,
        'move_left' : 2 ,
        'move_right' :3,
        'close_gripper' : 4,
        'open_gripper' :5
    }
    Q = {
        (0, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (0, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (3.14, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (3.14, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (-1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (-1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (1.57, 0.53, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (1.57, 0.53, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (0, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (0, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (3.14, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (3.14, 0.11, 0.04, 0.04): [0, 0, 0, 0, 0, 0],
        (-1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (-1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (1.57, 0.11, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (1.57, 0.11, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (0, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (0, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (3.14, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (3.14, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (-1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (-1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50],
        (1.57, 0.32, -0.84, -0.84): [50, 50, 50, 50, 50, 50],
        (1.57, 0.32, 0.04, 0.04): [50, 50, 50, 50, 50, 50]
    }
    return Q , action_to_index

In [4]:
def calculate_G(l , h , info , gamma , update_state_time):
    G = 0
    # print("\n\n\n" , info)
    for i in range(l , h+1):
        
        power = i-update_state_time - 1
        G += (gamma**power)*info[i][0]
    return G

In [6]:
for n in n_all: 
    print(f"--------------------------------------------->n = {n}<----------------------------------------------")
    os.mkdir("./save_nStepSARSA/{}".format(n))
    Q , action_to_index = init() #we do this because for diffrenct n we need to start with the same initial Q values
    
    for episode in range(iterations):
        print('#################################################')
        print("Current Episode : ", episode)
        path_to_bot = './bot/robot.urdf'
        
        env = Pickup_Bot_Env(path_to_bot, False)
        S_t = env.get_current_state() #beign state
        
        action = env.choose_action(Q[S_t] , epsilon) #chose an action
       
        A_t = action_to_index[action] #get index of action
    
        T = float('inf')
        t = 0
    
        state_reward_collect = [(None , env.rounded_position , A_t) , ] #for t=0 we have no reward and this way of storing ensures that the reward follows time step fasion i.e. t=1 has reward and is stored at first index
        while True:
            if t < T:
                env.step(action)
                reward = env.get_reward()
                S_t1 = env.rounded_position
                
    
                
                if env.rounded_position== env.terminal_state:
                    T = t + 1
    
                else:
                    action = env.choose_action(Q[S_t1] , epsilon)
                    A_t1 = action_to_index[action]
                    state_reward_collect.append((reward , S_t1 , A_t1))
                    # print('hjiu')
    
                        
            update_state_time = t - n +1
            if update_state_time >= 0:
                # print(t)
                # print(T)
                # print(update_state_time + 1 , min(update_state_time + n , T))
                
                G = calculate_G(update_state_time + 1 , min(update_state_time + n , T-1) , state_reward_collect , gamma , update_state_time )

                # print(update_state_time , T)
                # print(state_reward_collect)
                update_state = state_reward_collect[update_state_time][1]
                action_taken_update_state = state_reward_collect[update_state_time][2]
                
                if update_state_time + n < T :
                    Q_update_state = state_reward_collect[update_state_time + n][1]
                    Q_update_action = state_reward_collect[update_state_time + n][2]
                    G += ((gamma)**n)*(Q[Q_update_state][Q_update_action])
                    
                Q[update_state][action_taken_update_state] += alpha*(G - Q[update_state][action_taken_update_state])
                
            if T==float('inf'):
                print("State_next | Action_next" , S_t1 ,A_t1 )
            t+=1
            if  update_state_time == T-1:
                break
          
            
        if episode % 10 == 0:
            
            os.mkdir('./save_nStepSARSA/{}/{}'.format(n , episode))
            np.save('./save_nStepSARSA/{}/{}/Q.npy'.format(n , episode), Q)
            
        env.reset_env()


--------------------------------------------->n = 2<----------------------------------------------
#################################################
Current Episode :  0
False
State_next | Action_next (0, 0.11, -0.84, -0.84) 0
State_next | Action_next (0, 0.11, -0.84, -0.84) 0
State_next | Action_next (0, 0.11, -0.84, -0.84) 0
State_next | Action_next (0, 0.11, -0.84, -0.84) 1
State_next | Action_next (0, 0.32, -0.84, -0.84) 4
State_next | Action_next (0, 0.32, 0.04, 0.04) 0
State_next | Action_next (0, 0.11, 0.04, 0.04) 0
State_next | Action_next (0, 0.11, 0.04, 0.04) 1
State_next | Action_next (0, 0.32, 0.04, 0.04) 1
State_next | Action_next (0, 0.53, 0.04, 0.04) 0
State_next | Action_next (0, 0.32, 0.04, 0.04) 1
State_next | Action_next (0, 0.53, 0.04, 0.04) 2
State_next | Action_next (1.57, 0.53, 0.04, 0.04) 2
State_next | Action_next (3.14, 0.53, 0.04, 0.04) 0
State_next | Action_next (3.14, 0.32, 0.04, 0.04) 0
#################################################
Current Episode :  1

In [11]:
Q_n_2 = np.load('./save_nStepSARSA/2/990/Q.npy' , allow_pickle=True)
Q_n_3 = np.load('./save_nStepSARSA/3/990/Q.npy' , allow_pickle=True)
Q_n_5 = np.load('./save_nStepSARSA/5/990/Q.npy', allow_pickle=True)

## <div align="center">Final Q values for n = 2</div>

<br>

| States (Index Values) | Go Down | Go Up | Go Left | Go Right | Close Gripper | Open Gripper |
|-----------------------|---------|-------|---------|----------|---------------|--------------|
| (0, 0.53, -0.84, -0.84) | -3.5925325682301286 | -3.1145713626675517 | -3.107209431667953 | -3.301061609984307 | -5.618291925942252 | -3.078247887495238 |
| (0, 0.53, 0.04, 0.04) | 2.6833372653486025 | -2.305991977997291 | 2.1074042719881825 | 1.3712309524767863 | 0.7353678586856836 | 3.139915948943419 |
| (3.14, 0.53, -0.84, -0.84) | -1.9234536332210603 | -2.1902075232362064 | -1.9619264998910344 | -2.0958686268890614 | -3.6995050655693458 | -2.091139230103842 |
| (3.14, 0.53, 0.04, 0.04) | 1.7917082155470694 | 5.350823609086342 | 5.194826997239729 | 2.081485928912122 | 4.529652917463553 | 5.669360616644171 |
| (-1.57, 0.53, -0.84, -0.84) | -2.528030398904096 | -2.556462017250635 | -2.7987262890636897 | -2.5821366213343606 | -4.102417215370537 | -2.588723810638343 |
| (-1.57, 0.53, 0.04, 0.04) | 2.7658327003884056 | 0.4222048847427122 | 3.167597291001637 | 4.250368114421859 | 2.369821795371786 | 3.5988629398802425 |
| (1.57, 0.53, -0.84, -0.84) | -2.6028249362536346 | -2.8115157072871386 | -2.617617693745577 | -2.6435463974137505 | -3.7414980792065267 | -2.773807222635598 |
| (1.57, 0.53, 0.04, 0.04) | 3.1723265405754915 | 1.6380781980869181 | 0.7200890407492608 | 3.5686480804010126 | 3.758745626726372 | 3.344573810264393 |
| (0, 0.11, -0.84, -0.84) | -2.8322397520062537 | -2.643992843824886 | -2.718504344571313 | -2.4810040591671374 | -6.932734719235112 | -2.8964989650480617 |
| (0, 0.11, 0.04, 0.04) | 1.108302360481522 | 0.039538560819329494 | 2.0910049038974066 | -0.2867160481218394 | 1.1412246429669333 | 2.1854223456976283 |
| (3.14, 0.11, -0.84, -0.84) | -1.0451072001589736 | -2.004945560961031 | -1.9713019352762857 | -1.8408647995521357 | 8.23599985129841e-42 | -1.0231177779434242 |
| (3.14, 0.11, 0.04, 0.04) | 0 | 0 | 0 | 0 | 0 | 0 |
| (-1.57, 0.11, -0.84, -0.84) | -1.6357429731849733 | -2.294638790296471 | -2.149526838428787 | -1.056121525604486 | -8.39969585665409 | -1.9051117616003534 |
| (-1.57, 0.11, 0.04, 0.04) | 6.140126189524825 | 5.831828847938978 | 5.063610211218542 | 6.0788327295284645 | 1.885665362047947 | 6.532288584953676 |
| (1.57, 0.11, -0.84, -0.84) | -1.7975859820560822 | -2.475862818747576 | -1.2450436018005795 | -2.8668921638615767 | -7.08094806106085 | -1.7845585915973314 |
| (1.57, 0.11, 0.04, 0.04) | 7.413798379103538 | 8.28084503419776 | 7.504731764849956 | 6.643398234774019 | 8.141457200479081 | 7.862394191043257 |
| (0, 0.32, -0.84, -0.84) | -3.970022409367797 | -4.528627730135709 | -3.169767834959741 | -4.026157746600714 | -13.725757060934896 | -4.273213978841413 |
| (0, 0.32, 0.04, 0.04) | -1.6056259754864164 | -1.9726552774939115 | -2.0273320732722038 | -1.9566067099921107 | -1.411404278719365 | -1.2018987756281307 |
| (3.14, 0.32, -0.84, -0.84) | -1.222856818622871 | -2.1070390009359135 | -2.3087733212305803 | -2.028501247040621 | -6.871660852120583 | -1.9549236903857534 |
| (3.14, 0.32, 0.04, 0.04) | 7.504731764849956 | 7.319168881086348 | 5.833517013639717 | 6.813953135114872 | 5.3806627261141715 | 7.2214014881743935 |
| (-1.57, 0.32, -0.84, -0.84) | -2.0784952578162175 | -2.539322562515712 | -2.602761648249586 | -2.3769263148686353 | -8.304341761888848 | -2.4641359904839972 |
| (-1.57, 0.32, 0.04, 0.04) | 1.192354253195841 | 1.3967383778930804 | 0.9176186523340293 | 1.4883133379637077 | 1.5838948724798425 | 1.25251393016976 |
| (1.57, 0.32, -0.84, -0.84) | -2.338828832368719 | -2.785650169536498 | -2.2624553378412986 | -3.1425382102382273 | -9.662547319187091 | -2.7707566479589776 |
| (1.57, 0.32, 0.04, 0.04) | 1.3320893978665271 | -1.4354494843122674 | 1.9712375713696701 | 0.7112851857669051 | 2.322626494886744 | 1.9712692741305342 |

## <div align="center">Final Q values for n = 3</div>

<br>



| States (Index Values) | Go Down | Go Up | Go Left | Go Right | Close Gripper | Open Gripper |
|-----------------------|---------|-------|---------|----------|---------------|--------------|
| (0, 0.53, -0.84, -0.84) | -4.1430570668246265 | -3.776698950661918 | -4.606174006486317 | -3.698285119936012 | -4.704941316081023 | -3.7320413757735222 |
| (0, 0.53, 0.04, 0.04) | 1.8136958767729952 | -0.022917934100008974 | 5.363546093713068 | 5.587672714805393 | -2.416219686766589 | 4.720703828639915 |
| (3.14, 0.53, -0.84, -0.84) | -2.1265931912041407 | -2.9818805399819848 | -2.56521092271839 | -3.3453476704966465 | -3.4468791319405714 | -2.658911733927117 |
| (3.14, 0.53, 0.04, 0.04) | 5.24133855847087 | -0.37692223857301554 | 7.328457195407166 | 2.039698010921861 | 0.6497022752269008 | 7.734365053249565 |
| (-1.57, 0.53, -0.84, -0.84) | -3.0235818954080305 | -2.939228210513532 | -2.976600107912237 | -2.8517685999044913 | -3.7348681963523815 | -2.953401721539118 |
| (-1.57, 0.53, 0.04, 0.04) | 6.137142792253556 | -2.742758013756367 | 5.689130415355066 | 6.2843179887656895 | 7.365915978222444 | 6.862451180633507 |
| (1.57, 0.53, -0.84, -0.84) | -3.142555239153461 | -4.172097955900401 | -3.089632301619883 | -3.23988128007295 | -6.186347927167031 | -3.4178101782571515 |
| (1.57, 0.53, 0.04, 0.04) | 3.8380564718799786 | 0.5897325564500755 | 5.116138510198946 | 3.881127359830318 | -1.056472738881571 | 4.659695162707793 |
| (0, 0.11, -0.84, -0.84) | -5.069520551702623 | -3.1898660112339647 | -2.1192447626978437 | -3.057093371980705 | -10.730143442913155 | -2.957000169322285 |
| (0, 0.11, 0.04, 0.04) | -1.4646870820981674 | 3.976279603404503 | 3.5926846912225203 | 0.956818347361823 | -0.3422752839168153 | 3.501797727296797 |
| (3.14, 0.11, -0.84, -0.84) | -1.0021878419513102 | -1.6257657314476104 | -1.4005061066379434 | -2.15943039045361 | 5.403639502436887e-42 | -1.2556491826267313 |
| (3.14, 0.11, 0.04, 0.04) | 0 | 0 | 0 | 0 | 0 | 0 |
| (-1.57, 0.11, -0.84, -0.84) | -2.762621618799643 | -2.096033054943975 | -2.7102531297788723 | -1.167343774665017 | -10.867040785255927 | -1.7737444159130045 |
| (-1.57, 0.11, 0.04, 0.04) | 3.1195148853589787 | 2.6567413103341013 | 5.950835746616246 | 6.0788327295284645 | -0.5545278957369262 | 6.151092003179529 |
| (1.57, 0.11, -0.84, -0.84) | -2.2111765188306878 | -1.6580453538026505 | -1.1166165420044036 | -1.9092921194542667 | -4.041693745527086 | -2.5148101658062227 |
| (1.57, 0.11, 0.04, 0.04) | 7.325955335095893 | 10.929014660116941 | 11.4383962274805 | 10.192975545486698 | 10.173524354710864 | 11.793426582752538 |
| (0, 0.32, -0.84, -0.84) | -3.802191324971843 | -4.646992519573955 | -4.313275932162313 | -3.111179112491673 | -13.268431185464015 | -3.7936977834133825 |
| (0, 0.32, 0.04, 0.04) | -1.1210714795784746 | 0.6540252161546396 | -0.35971537395970987 | 0.2275444425944091 | -0.38860408480215414 | 0.3025231874521088 |
| (3.14, 0.32, -0.84, -0.84) | -1.0960746764917202 | -2.397017756959974 | -2.677944439618263 | -1.7313270525905622 | -12.515819871344487 | -1.7849388645304463 |
| (3.14, 0.32, 0.04, 0.04) | 7.504731764849956 | 1.6847757106213037 | 8.146830439406699 | 6.232487941177503 | 5.854324221920751 | 7.134552837667934 |
| (-1.57, 0.32, -0.84, -0.84) | -3.8145852069120036 | -3.7626243778222843 | -3.9546165415402923 | -2.4046888974100806 | -10.566531468814853 | -3.7333285938827485 |
| (-1.57, 0.32, 0.04, 0.04) | 1.5949410678257177 | 5.147894097191051 | 4.170605343828882 | 4.644983856880533 | 2.6869591793734764 | 5.001620368071007 |
| (1.57, 0.32, -0.84, -0.84) | -3.1875410799120396 | -3.1999450410395234 | -2.6253877753487553 | -3.640772841868832 | -9.537616883207164 | -3.5462116234677445 |
| (1.57, 0.32, 0.04, 0.04) | 4.1816380610991715 | 4.128594371967086 | 3.6404582866086947 | 4.892403501646526 | -3.4458790422579058 | 3.269345907788935 |

<br>

## <div align="center">Final Q values for n = 5</div>

<br>


In [1]:
def convert_Q_to_policy(Q):
    policy = {}
    for key in Q:
        max_val= max(Q[key])
        policy_arr = [0,0,0,0,0,0]
        for j in range(len(Q[key])):
            if Q[key][j]==max_val:
                policy_arr[j]=1
    
        policy[key] = policy_arr

    return policy

In [4]:
import pybullet as p
import time
import numpy as np
import pybullet_data 
import matplotlib.pyplot as plt
from Env import Pickup_Bot_Env

import math
#results 
Q = np.load('./save_nStepSARSA/3/990/Q.npy' , allow_pickle=True)
path_to_bot = './bot/robot.urdf'

policy = convert_Q_to_policy(Q.item())

Q_before_run , action_to_index = init()
    # Create environment instance
position = (0, 0.54, -0.84, -0.84) # Starting state

env = Pickup_Bot_Env(path_to_bot,position, True , False)


S_t = env.get_current_state()
    
action = env.choose_action(policy[S_t] , epsilon = 0)
   
A_t = action_to_index[action] # as we have saved a .npy the new file needs to be converted back to a dictionary
while env.rounded_position != env.terminal_state:
    env.step(action)
    reward = env.get_reward()
    
    S_t1 = env.rounded_position
    action = env.choose_action(policy[S_t1] , epsilon = 0)
  

    A_t1 = action_to_index[action]
    
    
    print("State | Action | Reward" , S_t , A_t , reward)
    print("State_next | Action_next" , S_t1 ,A_t1 )
    print("\n")
    S_t = S_t1
    A_t = A_t1


env.reset_env()
print("Reached")

(0, 0.54, -0.84, -0.84)
State | Action | Reward (0, 0.53, -0.84, -0.84) 3 -1
State_next | Action_next (-1.57, 0.53, -0.84, -0.84) 3


State | Action | Reward (-1.57, 0.53, -0.84, -0.84) 3 -1
State_next | Action_next (3.14, 0.53, -0.84, -0.84) 0


State | Action | Reward (3.14, 0.53, -0.84, -0.84) 0 -1
State_next | Action_next (3.14, 0.11, -0.84, -0.84) 4


State | Action | Reward (3.14, 0.11, -0.84, -0.84) 4 100
State_next | Action_next (3.14, 0.11, 0.04, 0.04) 0


Reached
