In [1]:
#  The purpose of this notebook is to figure out an efficient way to calculate discounted returns
import numpy as np
from energy_py import Utils

In [2]:
#  discount rate applied to future rewards - read by the methods defined below
discount_factor = 1.0
#  rewards are from memory
rewards = Utils.load_pickle('disc_returns_data/ep_1_rewards.pickle')
#  actual rewards are kept within episode (ie appending during experiment)
actual_rewards = Utils.load_pickle('disc_returns_data/ep_1_rewards_list.pickle')
#  pickled returns for the same episode (presumably calculated by the memory -
#  TODO confirm) - displayed in the last cell to compare the methods against
memory_returns = Utils.load_pickle('disc_returns_data/ep_1_rtns.pickle')

#  zip stops at the shorter sequence, so a length mismatch would otherwise
#  slip through the element-wise check below unnoticed
assert len(rewards) == len(actual_rewards)

#  the two reward records must agree step by step
for r, r_ in zip(rewards, actual_rewards):
    assert r == r_

In [3]:
#  first method taken from dennybritz
def dennybritz(rewards, gamma=None):
    """
    Baseline discounted return calculation (soln from dennybritz).

    For every timestep t the return is rebuilt from scratch as
    sum(gamma**i * rewards[t + i]), so the work grows quadratically with the
    length of the episode.  This is deliberate - it is the baseline the other
    methods in this notebook are timed against.

    args
        rewards : sliceable sequence of rewards in time order
        gamma   : discount rate.  When None the notebook-level
                  discount_factor is used

    returns
        rtns : list holding one discounted return per timestep
    """
    if gamma is None:
        gamma = discount_factor

    rtns = []
    #  enumerate never yields None as an index, so the old `t == None`
    #  special case could not run and has been dropped
    for t in range(len(rewards)):
        total_return = sum(gamma**i * r for i, r in enumerate(rewards[t:]))
        rtns.append(total_return)
    return rtns

In [4]:
#  second method is a more efficient way to do it
def method_two(rewards, gamma=None):
    """
    Discounted returns in a single backwards pass over the episode.

    The return at the final step is just the final reward; every earlier
    return is reward + gamma * (return of the following step), read back from
    the array being filled.

    args
        rewards : iterable of rewards in time order
        gamma   : discount rate.  When None the notebook-level
                  discount_factor is used

    returns
        list of returns (numpy float64 values), one per timestep
    """
    if gamma is None:
        gamma = discount_factor

    #  copy into a list so any iterable can be indexed
    rewards = list(rewards)
    last = len(rewards) - 1
    rtns_ = np.zeros(len(rewards))

    #  fill from the end of the episode towards the start - this avoids
    #  reversing the rewards going in and the returns coming out
    for i in range(last, -1, -1):
        if i == last:
            rtns_[i] = rewards[i]
        else:
            rtns_[i] = rewards[i] + gamma * rtns_[i + 1]

    return list(rtns_)

In [5]:
def method_three(rewards, gamma=None):
    """
    Monte Carlo discounted returns from a running total.

    Walks the episode backwards keeping R, the return of the following
    step, and applies the recursion R_t = r_t + gamma * R_{t+1}.

    args
        rewards : sequence of rewards in time order (must support reversed())
        gamma   : discount rate.  When None the notebook-level
                  discount_factor is used

    returns
        returns : list holding one discounted return per timestep
    """
    if gamma is None:
        gamma = discount_factor

    #  R = the return from s'
    R, returns = 0, []
    #  note that we walk the rewards in reverse here
    for r in reversed(rewards):
        R = r + gamma * R  # recursive definition of the return
        #  append is O(1) - insert(0, R) would shift the whole list each step
        returns.append(R)

    #  returns were collected last-step-first, so flip back into time order
    returns.reverse()
    return returns

In [6]:
def keon(rewards, gamma=None, reset_on_nonzero=True):
    """
    https://github.com/keon/policy-gradient/blob/master/pg.py

    Discounted rewards with a running sum that walks the episode backwards.

    NOTE - as ported, the running sum is reset whenever a reward is non-zero.
    With rewards that are non-zero at every step the reset fires every time
    and the output is simply the rewards themselves.  Pass
    reset_on_nonzero=False to get plain discounted returns.

    args
        rewards          : array-like of rewards in time order
        gamma            : discount rate.  When None the notebook-level
                           discount_factor is used
        reset_on_nonzero : keep the reset described above (default, the
                           behaviour of the ported code)

    returns
        discounted_rewards : np.array with the same shape as rewards
    """
    if gamma is None:
        gamma = discount_factor

    #  accept plain lists as well as arrays (.size needs an ndarray)
    rewards = np.asarray(rewards)

    #  zeros_like would inherit an integer dtype and truncate the discounted
    #  values, so non-float rewards get a float64 output array
    if np.issubdtype(rewards.dtype, np.floating):
        out_dtype = rewards.dtype
    else:
        out_dtype = np.float64
    discounted_rewards = np.zeros_like(rewards, dtype=out_dtype)

    running_add = 0
    for t in reversed(range(0, rewards.size)):
        if reset_on_nonzero and rewards[t] != 0:
            running_add = 0
        running_add = running_add * gamma + rewards[t]
        discounted_rewards[t] = running_add
    return discounted_rewards

In [7]:
%%time
dennybritz(rewards)

CPU times: user 56 ms, sys: 0 ns, total: 56 ms
Wall time: 59.1 ms


[array([-1378.68225098], dtype=float32),
 array([-1370.85998535], dtype=float32),
 array([-1362.97375488], dtype=float32),
 array([-1354.80822754], dtype=float32),
 array([-1346.15759277], dtype=float32),
 array([-1336.83581543], dtype=float32),
 array([-1326.68566895], dtype=float32),
 array([-1316.76037598], dtype=float32),
 array([-1307.53015137], dtype=float32),
 array([-1299.0423584], dtype=float32),
 array([-1291.31921387], dtype=float32),
 array([-1284.35754395], dtype=float32),
 array([-1278.12634277], dtype=float32),
 array([-1272.57043457], dtype=float32),
 array([-1267.61242676], dtype=float32),
 array([-1263.1607666], dtype=float32),
 array([-1259.1126709], dtype=float32),
 array([-1255.36047363], dtype=float32),
 array([-1251.79309082], dtype=float32),
 array([-1248.2989502], dtype=float32),
 array([-1244.76525879], dtype=float32),
 array([-1241.07824707], dtype=float32),
 array([-1237.12414551], dtype=float32),
 array([-1232.78796387], dtype=float32),
 array([-1227.958618

In [8]:
%%time
method_two(rewards)

CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 894 µs


[-1378.6820068359375,
 -1370.8597412109375,
 -1362.9735107421875,
 -1354.8079833984375,
 -1346.1573486328125,
 -1336.83544921875,
 -1326.685546875,
 -1316.7601318359375,
 -1307.530029296875,
 -1299.0419921875,
 -1291.31884765625,
 -1284.3570556640625,
 -1278.1259765625,
 -1272.56982421875,
 -1267.6119384765625,
 -1263.16015625,
 -1259.112060546875,
 -1255.3597412109375,
 -1251.79248046875,
 -1248.29833984375,
 -1244.7646484375,
 -1241.0777587890625,
 -1237.1234130859375,
 -1232.7874755859375,
 -1227.9581298828125,
 -1222.5303955078125,
 -1216.4119873046875,
 -1209.5306396484375,
 -1201.841064453125,
 -1193.33056640625,
 -1184.0213623046875,
 -1173.9688720703125,
 -1163.6746826171875,
 -1154.1988525390625,
 -1145.408203125,
 -1137.1378173828125,
 -1129.197265625,
 -1121.379150390625,
 -1113.4674072265625,
 -1105.24755859375,
 -1096.515869140625,
 -1087.089599609375,
 -1076.816650390625,
 -1066.9825439453125,
 -1057.851318359375,
 -1049.466796875,
 -1041.8482666015625,
 -1034.98852539062

In [9]:
%%time
method_three(rewards)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 629 µs


[array([-1378.68200684], dtype=float32),
 array([-1370.85974121], dtype=float32),
 array([-1362.97351074], dtype=float32),
 array([-1354.8079834], dtype=float32),
 array([-1346.15734863], dtype=float32),
 array([-1336.83544922], dtype=float32),
 array([-1326.68554688], dtype=float32),
 array([-1316.76013184], dtype=float32),
 array([-1307.5300293], dtype=float32),
 array([-1299.04199219], dtype=float32),
 array([-1291.31884766], dtype=float32),
 array([-1284.35705566], dtype=float32),
 array([-1278.12597656], dtype=float32),
 array([-1272.56982422], dtype=float32),
 array([-1267.61193848], dtype=float32),
 array([-1263.16015625], dtype=float32),
 array([-1259.11206055], dtype=float32),
 array([-1255.35974121], dtype=float32),
 array([-1251.79248047], dtype=float32),
 array([-1248.29833984], dtype=float32),
 array([-1244.76464844], dtype=float32),
 array([-1241.07775879], dtype=float32),
 array([-1237.12341309], dtype=float32),
 array([-1232.78747559], dtype=float32),
 array([-1227.9581

In [10]:
memory_returns

array([[-1378.68200684],
       [-1370.85974121],
       [-1362.97351074],
       [-1354.8079834 ],
       [-1346.15734863],
       [-1336.83544922],
       [-1326.68554688],
       [-1316.76013184],
       [-1307.5300293 ],
       [-1299.04199219],
       [-1291.31884766],
       [-1284.35705566],
       [-1278.12597656],
       [-1272.56982422],
       [-1267.61193848],
       [-1263.16015625],
       [-1259.11206055],
       [-1255.35974121],
       [-1251.79248047],
       [-1248.29833984],
       [-1244.76464844],
       [-1241.07775879],
       [-1237.12341309],
       [-1232.78747559],
       [-1227.95812988],
       [-1222.53039551],
       [-1216.4119873 ],
       [-1209.53063965],
       [-1201.84106445],
       [-1193.33056641],
       [-1184.0213623 ],
       [-1173.96887207],
       [-1163.67468262],
       [-1154.19885254],
       [-1145.40820312],
       [-1137.13781738],
       [-1129.19726562],
       [-1121.37915039],
       [-1113.46740723],
       [-1105.24755859],
