In [1]:
import gym
import ptan
import numpy as np
from torch.utils.tensorboard import SummaryWriter

In [2]:
import torch
import torch.nn as nn 
import torch.optim as optim

In [3]:
GAMMA = 0.99
LEARNING_RATE = 0.01
BATCH_SIZE = 8

EPSILON_START = 1.0
EPSILON_STOP = 0.02
EPLSION_STEP = 5000

REPLAY_BUFFER = 50000

In [4]:
class DQN(nn.Module):
    def __init__(self, input_size, n_actions) -> None:
        super(DQN, self).__init__()
        
        self.net = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )
        
    def forward(self, x):
        return self.net(x)

In [5]:
def calc_target(net, local_reward, next_state):
    if next_state is None:
        return local_reward
    state_v = torch.tensor([next_state], dtype=torch.float32)
    next_q_v = net(state_v)
    best_q = next_q_v.max(dim=1)[0].item()
    return local_reward + GAMMA * best_q

In [6]:
env = gym.make("CartPole-v1")
writer = SummaryWriter(comment="-cartpole=dqn")

net = DQN(env.observation_space.shape[0], env.action_space.n)
print(net)

DQN(
  (net): Sequential(
    (0): Linear(in_features=4, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=2, bias=True)
  )
)


In [7]:
selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=EPSILON_START)
agent = ptan.agent.DQNAgent(net, selector, preprocessor=ptan.agent.float32_preprocessor)
exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=GAMMA)
replay_buffer = ptan.experience.ExperienceReplayBuffer(exp_source, REPLAY_BUFFER)

optimizer = optim.Adam(net.parameters(),lr=LEARNING_RATE)
mse_loss = nn.MSELoss()

In [8]:
total_rewards = []
step_idx = 0
done_episodes = 0

In [9]:
while True:
    step_idx += 1
    selector.epsilon = max(EPSILON_STOP, EPSILON_START - step_idx / EPLSION_STEP)
    replay_buffer.populate(1)
    
    if len(replay_buffer) < BATCH_SIZE:
        continue
    
    # sample batch
    batch = replay_buffer.sample(BATCH_SIZE)
    batch_states = [exp.state for exp in batch]
    batch_actions = [exp.action for exp in batch]
    batch_targets = [calc_target(net, exp.reward, exp.last_state) for exp in batch]
    
    # train
    optimizer.zero_grad()
    states_v = torch.FloatTensor(batch_states)
    net_q_v = net(states_v)
    target_q = net_q_v.data.numpy().copy()
    target_q[range(BATCH_SIZE), batch_actions] = batch_targets
    target_q_v = torch.tensor(target_q)
    loss_v = mse_loss(net_q_v, target_q_v)
    loss_v.backward()
    optimizer.step()
    
    # handle new rewards
    new_rewards = exp_source.pop_total_rewards()
    if new_rewards:
        done_episodes += 1
        reward = new_rewards[0]
        total_rewards.append(reward)
        mean_rewards = float(np.mean(total_rewards[-100:]))
        print("%d: reward: %6.2f, mean_100: %6.2f, epsilon: %.2f, episodes: %d"
              % (step_idx, reward, mean_rewards, selector.epsilon, done_episodes))
        writer.add_scalar("reward", reward, step_idx)
        writer.add_scalar("reward_100", reward, step_idx)
        writer.add_scalar("epsilon", selector.epsilon, step_idx)
        writer.add_scalar("episodes", done_episodes, step_idx)
        if mean_rewards > 195:
            print("Solved in %d steps and %d episodes!"
                  % (step_idx, done_episodes))
            break
writer.close()

16: reward:  15.00, mean_100:  15.00, epsilon: 1.00, episodes: 1
39: reward:  23.00, mean_100:  19.00, epsilon: 0.99, episodes: 2
52: reward:  13.00, mean_100:  17.00, epsilon: 0.99, episodes: 3
65: reward:  13.00, mean_100:  16.00, epsilon: 0.99, episodes: 4
77: reward:  12.00, mean_100:  15.20, epsilon: 0.98, episodes: 5
89: reward:  12.00, mean_100:  14.67, epsilon: 0.98, episodes: 6
105: reward:  16.00, mean_100:  14.86, epsilon: 0.98, episodes: 7
122: reward:  17.00, mean_100:  15.12, epsilon: 0.98, episodes: 8


  state_v = torch.tensor([next_state], dtype=torch.float32)


156: reward:  34.00, mean_100:  17.22, epsilon: 0.97, episodes: 9
170: reward:  14.00, mean_100:  16.90, epsilon: 0.97, episodes: 10
197: reward:  27.00, mean_100:  17.82, epsilon: 0.96, episodes: 11
218: reward:  21.00, mean_100:  18.08, epsilon: 0.96, episodes: 12
249: reward:  31.00, mean_100:  19.08, epsilon: 0.95, episodes: 13
275: reward:  26.00, mean_100:  19.57, epsilon: 0.94, episodes: 14
301: reward:  26.00, mean_100:  20.00, epsilon: 0.94, episodes: 15
317: reward:  16.00, mean_100:  19.75, epsilon: 0.94, episodes: 16
333: reward:  16.00, mean_100:  19.53, epsilon: 0.93, episodes: 17
367: reward:  34.00, mean_100:  20.33, epsilon: 0.93, episodes: 18
390: reward:  23.00, mean_100:  20.47, epsilon: 0.92, episodes: 19
406: reward:  16.00, mean_100:  20.25, epsilon: 0.92, episodes: 20
419: reward:  13.00, mean_100:  19.90, epsilon: 0.92, episodes: 21
433: reward:  14.00, mean_100:  19.64, epsilon: 0.91, episodes: 22
469: reward:  36.00, mean_100:  20.35, epsilon: 0.91, episodes: