In [1]:
import gym
import numpy as np
import ptan

import torch
import torch.nn as nn
import torch.optim as optim

from tensorboardX import SummaryWriter

from lib import common

import collections

import warnings
warnings.filterwarnings('ignore')

## Priority Replay Buffer

In [2]:
class PrioReplayBuffer:
    def __init__(self, exp_source, buf_size, prob_alpha=0.6):
        self.exp_source_iter = iter(exp_source)
        self.prob_alpha = prob_alpha
        self.capacity = buf_size
        self.pos = 0
        self.buffer = []
        # each experience in the buffer will be assigned a "priority"
        # initialize each priority to 0
        # the priorities will be the loss for each experience, this
        # will be used to set the probablity of selecting that 
        # experience in other batches, where the more wrong the loss
        # the more likely it is to be selected, so it can train on it better
        self.priorities = np.zeros((buf_size,), dtype=np.float32)
    
    def __len__(self):
        return len(self.buffer)
    
    def populate(self, count):
        # what is the maximum priority in the buffer if it is not empty?
        max_prio = self.priorities.max() if self.buffer else 1.0
        
        # sample "count," experiences, and keep track of the oldest experience in the buffer
        # by tracking it's position within the buffer
        for _ in range(count):
            sample = next(self.exp_source_iter) # generate/get the next experience (s,a,s',r)
            
            if len(self.buffer) < self.capacity:
                self.buffer.append(sample) # append to the buffer if the buffer isn't full
            else:
                self.buffer[self.pos] = sample # replace the oldest experience in the buffer with the new one
                
            self.priorities[self.pos] = max_prio # set the priority of the current experience to the max so,
                                                 # that its very likely to get sampled, since it hasn't
                                                 # been seen before in any batch
            
            self.pos = (self.pos + 1) % self.capacity # cyclic the oldest xp index to the next one
    
    def sample(self, batch_size, beta=0.4):
        
        if len(self.buffer) == self.capacity:
            prios = self.priorities
        else:
            prios = self.priorities[:self.pos]
        
        # convert priority numbers to probabilities via
        # P(i) = priority_i**alpha/(sum(priority_i**alpha))
        probs = prios ** self.prob_alpha
        probs /= probs.sum()
        
        # randomly select from the experience buffer, a batch size sample, with probabilities
        # of experience selection = probs
        indices = np.random.choice(len(self.buffer), batch_size, p=probs, replace=True)
        samples = [self.buffer[idx] for idx in indices]
        
        # because we pull from from a non-uniform sample
        # we need compute the importance sampling weight for 
        # SGD to work, we kind of cheat here by adding a beta
        # in pure importance sampling beta = 1
        total = len(self.buffer)
        weights = (total*probs[indices])**(-beta)
        weights /= weights.max()
        
        # return the samples, the weights, and indices, where the indices
        # are going to be used to update their priorities after the loss
        # is computed on them
        return(samples, indices, np.array(weights, dtype=np.float32))
    
    def update_priorities(self, batch_indices, batch_priorities):
        for idx, prio in zip(batch_indices, batch_priorities):
            self.priorities[idx] = prio

## Experience Buffer

In [3]:
# each experience is a tuple (s,a,r,d,s')
Experience = collections.namedtuple('Experience', 
                                    field_names=['state','action','reward','done','new_state'])

# make a buffer to hold "capacity" number of experiences in a queue
class ExperienceBuffer:
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)
    
    def __len__(self):
        return len(self.buffer)
    
    def append(self,experience):
        self.buffer.append(experience)
    
    def sample(self, batch_size):
        # randomly pick from the experience buffer
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        
        states, actions, rewards, dones, next_states = \
        zip(*[self.buffer[idx] for idx in indices])
        
        return np.array(states),np.array(actions),np.array(rewards,dtype=np.float32),\
               np.array(dones),np.array(next_states)

## Agent

In [4]:
# an agent has an env, an experience buffer, a Q(s,a) network
class Agent:
    def __init__(self, env, exp_buffer, net):
        self.env = env
        self.exp_buffer = exp_buffer
        self.net = net
        self._reset()
    
    def _reset(self):
        self.state = env.reset()
        self.total_reward = 0.0
    
    def play_step(self, epsilon=0.0, device="cpu"):
        done_reward = None
        
        # epsilon greedy play
        if np.random.random() < epsilon:
            action = env.action_space.sample()
        else:
            state_a = np.array([self.state], copy=False)
            state_v = torch.tensor(state_a).to(device)
            q_vals_v = self.net(state_v)
            _,act_v = torch.max(q_vals_v,dim=1) # argmax_a Q(s,a)
            action = int(act_v.item())
        
        # take the step
        new_state, reward, is_done, _ = self.env.step(action)
        self.total_reward += reward
        
        # create the experience object -(s,a,r,d,s')
        exp = Experience(self.state, action, reward, is_done, new_state)
        # append the experience to the buffer
        self.exp_buffer.append(exp)
        
        # setup for the next iteration unless done
        self.state = new_state
        if is_done:
            done_reward = self.total_reward
            self._reset()
            return done_reward

## Mean Square Error Loss Function

In [5]:
#MSE loss on priority replay
def calc_loss(batch, net, tgt_net, gamma, device="cpu"):
    # grab a collection of random experiences
    # E - {(s1,a1,r1,d1,s'1),(s2,a2,r2,d2,s'2),...}
    states, actions, rewards, dones, next_states = batch
    
    # convert to pytorch vars
    states_v = torch.tensor(states).to(device)
    next_states_v = torch.tensor(next_states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards).to(device)
    done_mask = torch.ByteTensor(dones).to(device)
    
    # Q(s,a) for each (s,a) pair in the batch 
    Q_sa = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    
    # max_a' Q_tgt(s', a') for each s' in the batch
    maxQ_sa = tgt_net(next_states_v).max(1)[0] # max in pytorch gives (values, indices) so you want part 0
    
    # terminal states dont have a max_a' Q_tgt(s',a'), the target for Q_sa is just r(s,a)
    maxQ_sa[done_mask] = 0.0
    
    # build the Q(s,a) target
    # y(s,a) = r(s,a) + gamma * max_a' Q_tgt(s',a')
    y_sa = rewards_v + gamma*maxQ_sa.detach() # i use detach because I dont want the parameters of Q_tgt to change 
    
    # sampling MSE loss,
    # l_i = (Q(s_i,a_i) - y(s_i,a_i))**2
    loss_v = (Q_sa - y_sa)**2
    
    #  L = 1/N sum_i(l_i), l_i + delta -- passed back for updating priorities
    return(loss_v.mean(), loss_v + 1e-5) # avoid divide by zero issue with 1e-5

## Build Environment

In [6]:
# Environment setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = gym.make('LunarLander-v2')


params = {'run_name': 'lunarlander',
         'epsilon_start': 1.0,
         'epsilon_final': 0.02,
         'epsilon_frames':   10**5,
         'replay_size': 100000,
         'replay_initial': 10000,
         'batch_size': 32,
         'gamma': 0.99,
         'target_net_sync':1000,
         }

GAMMA = 0.95 # reward discount
LEARNING_RATE = 0.001
STOP_REWARD = 199 # 200 points to "solve" it

## Q(s,a) approximating function

In [7]:
class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(input_shape, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def forward(self, x):
        return self.fc(x)

In [17]:
EPSILON_FINAL = 0.01
EPSILON_START = 1.0
EPSILON_DECAY_LAST_FRAME = 100000
REPLAY_START_SIZE = 1000

# intial network
net = DQN(env.observation_space.shape[0], env.action_space.n).to(device)
# target network
tgt_net = DQN(env.observation_space.shape[0], env.action_space.n).to(device)

# experience buffer to hold recent experiences
exp_buffer = ExperienceBuffer(params['replay_size'])

# agent to play the game
agent = Agent(env, exp_buffer, net)


optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

frame_idx = 0

total_rewards = []
best_mean_reward = None
MEAN_REWARD_BOUND = 200
SYNC_TARGET_FRAMES = 1000
BATCH_SIZE = 256

while True:
    frame_idx += 1
    
    epsilon = max(EPSILON_FINAL, EPSILON_START - frame_idx / EPSILON_DECAY_LAST_FRAME)
    reward = agent.play_step(epsilon, device=device)
    
    if reward is not None:
        total_rewards.append(reward)
        mean_reward = np.mean(total_rewards[-100:])
        
        if best_mean_reward is None or best_mean_reward < mean_reward:
            if best_mean_reward is not None:
                print("best mean reward updated %.3f -> %.3f"%(best_mean_reward,mean_reward))
                best_mean_reward = mean_reward
        
        if mean_reward > MEAN_REWARD_BOUND:
            print("Solved in %d frames"%frame_idx)
            break
    
    # dont start training until we have a full replay queue
    if len(exp_buffer) < REPLAY_START_SIZE:
        continue
    
    # match the target network to the current one
    if frame_idx % SYNC_TARGET_FRAMES == 0:
        tgt_net.load_state_dict(net.state_dict())
    
    optimizer.zero_grad()
    batch = exp_buffer.sample(BATCH_SIZE)
    loss_t = calc_loss(batch, net, tgt_net, device=device)
    loss_t.backward()
    optimizer.step()

TypeError: calc_loss() missing 1 required positional argument: 'gamma'

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.imshow(frames[5])