In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
import gym
from collections import deque

class n_step_replay_buffer(object):
    def __init__(self, capacity, n_step, gamma):
        self.capacity = capacity
        self.n_step = n_step
        self.gamma = gamma
        self.memory = deque(maxlen=self.capacity)
        self.n_step_buffer = deque(maxlen=self.n_step)

    def _get_n_step_info(self):
        reward, next_observation, done = self.n_step_buffer[-1][-3:]
        for _, _, rew, next_obs, do in reversed(list(self.n_step_buffer)[: -1]):
            reward = self.gamma * reward * (1 - do) + rew
            next_observation, done = (next_obs, do) if do else (next_observation, done)
        return reward, next_observation, done

    def store(self, observation, action, reward, next_observation, done):
        print(observation)
        observation = np.expand_dims(observation, 0)
        next_observation = np.expand_dims(next_observation, 0)

        self.n_step_buffer.append([observation, action, reward, next_observation, done])
        if len(self.n_step_buffer) < self.n_step:
            return
        reward, next_observation, done = self._get_n_step_info()
        observation, action = self.n_step_buffer[0][: 2]
        self.memory.append([observation, action, reward, next_observation, done])

    def sample(self, batch_size):
        batch = random.sample(self.memory, batch_size)
        observation, action, reward, next_observation, done = zip(* batch)
        return np.concatenate(observation, 0), action, reward, np.concatenate(next_observation, 0), done

    def __len__(self):
        return len(self.memory)


class ddqn(nn.Module):
    def __init__(self, observation_dim, action_dim):
        super(ddqn, self).__init__()
        self.observation_dim = observation_dim
        self.action_dim = action_dim

        self.fc1 = nn.Linear(self.observation_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, self.action_dim)

    def forward(self, observation):
        x = F.relu(self.fc1(observation))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def act(self, observation, epsilon):
        if random.random() > epsilon:
            q_value = self.forward(observation)
            action = q_value.max(1)[1].data[0].item()
        else:
            action = random.choice(list(range(self.action_dim)))
        return action


def train(buffer, target_model, eval_model, gamma, optimizer, batch_size, loss_fn, count, soft_update_freq, n_step):
    observation, action, reward, next_observation, done = buffer.sample(batch_size)

    observation = torch.FloatTensor(observation)
    action = torch.LongTensor(action)
    reward = torch.FloatTensor(reward)
    next_observation = torch.FloatTensor(next_observation)
    done = torch.FloatTensor(done)

    q_values = eval_model.forward(observation)
    next_q_values = target_model.forward(next_observation)
    argmax_actions = eval_model.forward(next_observation).max(1)[1].detach()
    next_q_value = next_q_values.gather(1, argmax_actions.unsqueeze(1)).squeeze(1)
    q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
    expected_q_value = reward + (gamma ** n_step) * (1 - done) * next_q_value

    #loss = loss_fn(q_value, expected_q_value.detach())
    loss = (expected_q_value.detach() - q_value).pow(2)
    loss = loss.mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if count % soft_update_freq == 0:
        target_model.load_state_dict(eval_model.state_dict())


if __name__ == '__main__':
    gamma = 0.99
    learning_rate = 1e-3
    batch_size = 64
    soft_update_freq = 200
    capacity = 10000
    exploration = 100
    epsilon_init = 0.9
    epsilon_min = 0.01
    decay = 0.99
    episode = 1000000
    n_step = 4
    render = False

    env = gym.make('CartPole-v1')
    env = env.unwrapped
    observation_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    target_net = ddqn(observation_dim, action_dim)
    eval_net = ddqn(observation_dim, action_dim)
    eval_net.load_state_dict(target_net.state_dict())
    optimizer = torch.optim.Adam(eval_net.parameters(), lr=learning_rate)
    buffer = n_step_replay_buffer(capacity, n_step, gamma)
    loss_fn = nn.MSELoss()
    epsilon = epsilon_init
    count = 0

    weight_reward = None
    for i in range(episode):
        obs = env.reset()
        if epsilon > epsilon_min:
            epsilon = epsilon * decay
        reward_total = 0
        if render:
            env.render()
        while True:
            action = eval_net.act(torch.FloatTensor(np.expand_dims(obs, 0)), epsilon)
            count += 1
            next_obs, reward, done, info = env.step(action)
            buffer.store(obs, action, reward, next_obs, done)
            reward_total += reward
            obs = next_obs

            if i > exploration:
                train(buffer, target_net, eval_net, gamma, optimizer, batch_size, loss_fn, count, soft_update_freq, n_step)

            if done:
                if not weight_reward:
                    weight_reward = reward_total
                else:
                    weight_reward = 0.99 * weight_reward + 0.01 * reward_total
                print('episode: {}  epsilon: {:.2f}  reward: {}  weight_reward: {:.3f}'.format(i+1, epsilon, reward_total, weight_reward))
                break

[ 0.04128196 -0.04920882 -0.03878732 -0.01796029]
[ 0.04029779 -0.24375367 -0.03914653  0.2622369 ]
[ 0.03542272 -0.4382956  -0.03390179  0.5423201 ]
[ 0.0266568  -0.6329251  -0.02305539  0.8241313 ]
[ 0.0139983  -0.8277242  -0.00657276  1.1094747 ]
[-0.00255618 -1.0227592   0.01561673  1.4000884 ]
[-0.02301137 -1.2180717   0.0436185   1.6976126 ]
[-0.0473728  -1.4136686   0.07757075  2.0035486 ]
[-0.07564617 -1.609509    0.11764173  2.3192089 ]
[-0.10783635 -1.8054892   0.1640259   2.6456542 ]
episode: 1  epsilon: 0.89  reward: 10.0  weight_reward: 10.000
[ 0.01138813  0.03931922 -0.02711787  0.03114255]
[ 0.01217451  0.23481935 -0.02649502 -0.2699715 ]
[ 0.0168709   0.43030918 -0.03189445 -0.570892  ]
[ 0.02547709  0.62586355 -0.04331229 -0.8734497 ]
[ 0.03799435  0.8215468  -0.06078128 -1.1794292 ]
[ 0.05442529  0.62726444 -0.08436986 -0.9064024 ]
[ 0.06697058  0.43337992 -0.10249791 -0.64138544]
[ 0.07563818  0.62977004 -0.11532561 -0.96450627]
[ 0.08823358  0.43637037 -0.13461575 

[-0.06487426 -0.21065941  0.1098417   0.41443607]
[-0.06908745 -0.40715286  0.11813043  0.73962855]
[-0.07723051 -0.6036906   0.13292299  1.0670302 ]
[-0.08930431 -0.8002966   0.1542636   1.3983021 ]
[-0.10531025 -0.60739225  0.18222965  1.1575555 ]
[-0.11745809 -0.41505155  0.20538075  0.92710114]
episode: 19  epsilon: 0.74  reward: 23.0  weight_reward: 11.588
[-0.0414802  -0.04152054  0.04590651 -0.00029695]
[-0.04231061  0.15291402  0.04590057 -0.27814972]
[-0.03925233  0.34735212  0.04033758 -0.55600965]
[-0.03230529  0.5418852   0.02921738 -0.83571595]
[-0.02146759  0.3463766   0.01250307 -0.5339895 ]
[-0.01454005  0.5413205   0.00182328 -0.82270664]
[-0.00371364  0.7364175  -0.01463086 -1.1148156 ]
[ 0.01101471  0.5414906  -0.03692717 -0.8267579 ]
[ 0.02184452  0.34689257 -0.05346233 -0.54591393]
[ 0.02878237  0.15256095 -0.06438061 -0.27054346]
[ 0.03183359 -0.04158604 -0.06979147  0.00115887]
[ 0.03100187  0.15446381 -0.06976829 -0.31270206]
[ 0.03409114  0.35050675 -0.07602233

[-0.00768108 -0.7530614   0.05112153  1.2300177 ]
[-0.0227423  -0.9488025   0.07572188  1.5382687 ]
[-0.04171835 -1.1447495   0.10648725  1.8535881 ]
[-0.06461334 -1.3408687   0.14355901  2.1773486 ]
[-0.09143072 -1.5370659   0.18710598  2.5106785 ]
episode: 31  epsilon: 0.66  reward: 9.0  weight_reward: 12.364
[ 0.04522843 -0.0122843   0.02571096  0.02086126]
[ 0.04498274 -0.20776536  0.02612819  0.3215441 ]
[ 0.04082743 -0.40324944  0.03255907  0.622351  ]
[ 0.03276245 -0.20859692  0.04500609  0.3400976 ]
[ 0.02859051 -0.01414326  0.05180804  0.06193982]
[ 0.02830764 -0.20996827  0.05304684  0.37050796]
[ 0.02410828 -0.4058022   0.060457    0.67943406]
[ 0.01599223 -0.21156983  0.07404568  0.40638137]
[ 0.01176084 -0.40765938  0.0821733   0.72145927]
[ 0.00360765 -0.60381615  0.09660249  1.0388334 ]
[-0.00846867 -0.80007976  0.11737916  1.3602132 ]
[-0.02447027 -0.9964612   0.14458342  1.6871886 ]
[-0.04439949 -1.1929294   0.1783272   2.0211756 ]
episode: 32  epsilon: 0.65  reward: 1

[-0.03835072 -0.01773185 -0.02121132  0.04094824]
[-0.03870536 -0.21254331 -0.02039236  0.32686403]
[-0.04295622 -0.40736908 -0.01385508  0.613047  ]
[-0.0511036  -0.21205628 -0.00159414  0.31603265]
[-0.05534473 -0.40715548  0.00472652  0.6082124 ]
[-0.06348784 -0.21209992  0.01689077  0.31702194]
[-0.06772984 -0.01722258  0.0232312   0.02971327]
[-0.06807429 -0.21266985  0.02382547  0.32963443]
[-0.07232769 -0.40812272  0.03041816  0.6297345 ]
[-0.08049014 -0.60365564  0.04301285  0.9318398 ]
[-0.09256326 -0.7993308   0.06164965  1.237723  ]
[-0.10854987 -0.60505265  0.08640411  0.96497285]
[-0.12065092 -0.8012223   0.10570356  1.2835008 ]
[-0.13667537 -0.6075932   0.13137358  1.0256972 ]
[-0.14882724 -0.8041963   0.15188752  1.3565735 ]
[-0.16491115 -1.0008621   0.17901899  1.6926594 ]
episode: 44  epsilon: 0.58  reward: 16.0  weight_reward: 12.899
[0.00942695 0.01364919 0.04971664 0.01781895]
[ 0.00969993 -0.1821492   0.05007302  0.3257643 ]
[ 0.00605695 -0.377947    0.0565883   0.

[-0.04036337 -0.80553174  0.03779696  1.1991605 ]
[-0.056474   -1.0011218   0.06178017  1.5034456 ]
[-0.07649644 -1.1969367   0.09184908  1.8147595 ]
[-0.10043518 -1.0029494   0.12814426  1.5519706 ]
[-0.12049416 -1.199354    0.15918368  1.8817335 ]
[-0.14448124 -1.3958118   0.19681835  2.219297  ]
episode: 55  epsilon: 0.52  reward: 10.0  weight_reward: 13.190
[ 0.03807228 -0.01872647 -0.01166362  0.02026937]
[ 0.03769775 -0.21367922 -0.01125823  0.30924958]
[ 0.03342417 -0.01839869 -0.00507324  0.0130375 ]
[ 0.03305619 -0.21344753 -0.00481249  0.30411544]
[ 0.02878724 -0.40850055  0.00126982  0.5952768 ]
[ 0.02061723 -0.60364026  0.01317535  0.88835937]
[ 0.00854443 -0.7989385   0.03094254  1.1851548 ]
[-0.00743434 -0.9944479   0.05464564  1.4873741 ]
[-0.0273233  -1.1901914   0.08439312  1.7966088 ]
[-0.05112713 -1.3861506   0.1203253   2.1142836 ]
[-0.07885014 -1.1924183   0.16261098  1.861077  ]
[-0.1026985  -1.3889077   0.19983251  2.1995199 ]
episode: 56  epsilon: 0.51  reward: 

[-0.17514497 -1.9300886   0.18806855  2.5597064 ]
episode: 66  epsilon: 0.46  reward: 23.0  weight_reward: 13.388
[ 0.04058987 -0.0177683   0.0259735  -0.00880433]
[ 0.04023451 -0.21325293  0.02579741  0.29195908]
[ 0.03596945 -0.01850812  0.03163659  0.00752265]
[ 0.03559928 -0.21406917  0.03178705  0.31001705]
[ 0.0313179  -0.01941419  0.03798738  0.02752589]
[ 0.03092962 -0.21505973  0.0385379   0.3319481 ]
[ 0.02662842 -0.41070843  0.04517686  0.6365306 ]
[ 0.01841425 -0.21624467  0.05790748  0.35841003]
[ 0.01408936 -0.41213998  0.06507568  0.66877574]
[ 0.00584656 -0.6081036   0.07845119  0.9812177 ]
[-0.00631551 -0.80418426  0.09807555  1.2974751 ]
[-0.0223992  -1.0004053   0.12402505  1.6191802 ]
[-0.0424073  -1.1967515   0.15640865  1.9478072 ]
[-0.06634233 -1.3931549   0.1953648   2.284613  ]
episode: 67  epsilon: 0.46  reward: 14.0  weight_reward: 13.394
[-0.04973927  0.04652892  0.01440836  0.00692785]
[-0.04880869 -0.14879668  0.01454692  0.30412173]
[-0.05178463 -0.344122

[ 0.07241023 -0.8199752  -0.00383308  1.1601474 ]
[ 0.05601073 -1.015047    0.01936987  1.4516261 ]
[ 0.03570979 -1.2104014   0.0484024   1.7502973 ]
[ 0.01150176 -1.4060384   0.08340834  2.0576336 ]
[-0.01661901 -1.2118616   0.12456101  1.7918746 ]
[-0.04085624 -1.4081407   0.16039851  2.1205385 ]
[-0.06901905 -1.2149397   0.20280927  1.8814124 ]
episode: 85  epsilon: 0.38  reward: 23.0  weight_reward: 13.151
[-0.0340714   0.03745471  0.02460158  0.04695948]
[-0.03332231 -0.15801121  0.02554077  0.34730178]
[-0.03648254 -0.35348696  0.03248681  0.64792794]
[-0.04355228 -0.5490461   0.04544537  0.9506613 ]
[-0.0545332  -0.3545643   0.06445859  0.67259616]
[-0.06162448 -0.55052024  0.07791051  0.9848573 ]
[-0.07263489 -0.7465944   0.09760766  1.300959  ]
[-0.08756678 -0.9428101   0.12362684  1.6225326 ]
[-0.10642298 -1.1391518   0.15607749  1.9510514 ]
[-0.12920602 -1.3355515   0.19509852  2.2877705 ]
episode: 86  epsilon: 0.38  reward: 10.0  weight_reward: 13.119
[ 0.01003914  0.000170

[ 0.04662298 -0.1739676  -0.02209235  0.31359598]
[ 0.04314363 -0.36876798 -0.01582043  0.59923065]
[ 0.03576827 -0.1734283  -0.00383582  0.30160674]
[ 0.0322997  -0.36849537  0.00219632  0.5930775 ]
[ 0.0249298  -0.563648    0.01405787  0.8864514 ]
[ 0.01365684 -0.7589579   0.03178689  1.1835203 ]
[-0.00152232 -0.95447755  0.0554573   1.4859953 ]
[-0.02061187 -1.1502298   0.08517721  1.7954683 ]
[-0.04361647 -1.346196    0.12108658  2.113364  ]
[-0.07054039 -1.1524743   0.16335385  1.8604213 ]
[-0.09358988 -1.348968    0.20056228  2.1990511 ]
episode: 102  epsilon: 0.32  reward: 12.0  weight_reward: 12.828
[ 0.02429234 -0.01080323 -0.01602724  0.01879141]
[ 0.02407627 -0.20569171 -0.01565141  0.30637476]
[ 0.01996244 -0.40058717 -0.00952392  0.5940808 ]
[ 0.0119507  -0.59557456  0.0023577   0.8837486 ]
[ 3.9204988e-05 -7.9072845e-01  2.0032670e-02  1.1771718e+00]
[-0.01577536 -0.59587234  0.04357611  0.8908355 ]
[-0.02769281 -0.79155755  0.06139281  1.196892  ]
[-0.04352396 -0.9874180

[-0.02735748  0.4164485   0.01695463 -0.5262658 ]
[-0.01902851  0.6113278   0.00642932 -0.81355834]
[-0.00680195  0.80636114 -0.00984185 -1.104212  ]
[ 0.00932527  1.0016111  -0.03192609 -1.3999664 ]
[ 0.0293575   1.1971151  -0.05992541 -1.7024574 ]
[ 0.0532998   1.3928736  -0.09397456 -2.0131757 ]
[ 0.08115727  1.5888381  -0.13423808 -2.3334134 ]
[ 0.11293403  1.7848942  -0.18090634 -2.6641996 ]
episode: 113  epsilon: 0.29  reward: 12.0  weight_reward: 12.964
[-0.04844857 -0.02509494  0.01572362  0.02642708]
[-0.04895047 -0.22043881  0.01625216  0.3240292 ]
[-0.05335924 -0.025552    0.02273275  0.03651553]
[-0.05387028  0.1692367   0.02346306 -0.24890919]


KeyboardInterrupt: 