In [1]:
import gym
from gym import spaces
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.style as mplstyle
from matplotlib import animation
from mpl_toolkits import mplot3d
from collections import namedtuple
from itertools import count
from typing import Dict, List, Set, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import copy

# Apply matplotlib's built-in 'fast' style sheet to every figure created below.
mplstyle.use('fast')

# Tensors and networks are placed on the GPU when CUDA is usable, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class UAV3DGrid(gym.Env):
    """Gym environment of a drone flying through a 30 x 30 x 30 voxel grid.

    Cell encoding of ``self.state`` (float32 array, axes ordered N, E, U):

    * ``-1``      obstacle
    * ``0``       free space
    * ``1 ~ 16``  drone position; the stored value is the speed component of
                  the last action (1 -> 0 m/s, 16 -> 15 m/s)
    * ``20``      target position

    An action is an integer in ``[0, num_dir * num_speed)`` that combines one
    of the 26 neighbouring directions with one of the 16 speed levels.
    """

    metadata = {'render.modes': ['console', 'rgb_array']}

    def __init__(self, start=(0, 0, 0), end=(29, 29, 15)):
        """Create the deterministic obstacle layout with the given start/target cells."""
        super(UAV3DGrid, self).__init__()

        self.state, _, _ = self._create_world(start=start, end=end)  # Observation
        self.cur_pos = start  # Track the current position of the drone
        self.end_pos = end

        self.trace = [self.cur_pos]  # cells visited in the current episode
        self.fig_on = False          # True once the matplotlib figure exists
        self.reset_trig = True       # True until the first render after a reset

        self.num_dir = 26
        self.num_speed = 16 # 0 ~ 15 m/s

        self.observation_space = spaces.Box(low=-1, high=20, shape=(30, 30, 30), dtype=np.float32)
        self.action_space = spaces.Discrete(self.num_dir * self.num_speed)

    def step(self, action):
        """Move the drone by one cell and score the move.

        Args:
            action: integer in ``[0, num_dir * num_speed)``.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is the internal state
            array itself (not a copy), ``reward`` is +100 on reaching the
            target, -100 on leaving the grid or hitting an obstacle,
            ``speed / 50`` when the move reduces the Euclidean distance to the
            target and ``-1 / speed`` otherwise, and ``info`` is an empty dict.
        """
        # action: 0 ~ 415, transform the action index to direction, speed vector
        a = np.zeros(4, dtype=int) # (dN, dE, dU, speed)
        speed_idx, direction = divmod(int(action), self.num_dir)
        a[3] = speed_idx + 1
        # Directions are numbered 1 ~ 26 so that the all-zero move is excluded.
        direction += 1
        # Read the direction as three base-3 digits; digit 2 stands for -1.
        for i in range(2, -1, -1):
            quo, direction = divmod(direction, 3 ** i)
            a[i] = -1 if quo == 2 else quo

        # Calculate the next position
        next_pos = tuple(np.array(self.cur_pos) + a[:3])

        # Check whether the drone is out of bound first,
        # then check the drone is on the position of the target or the obstacles
        if any(pos < 0 or pos >= size for pos, size in zip(next_pos, self.state.shape)):
            # Out of bounds: the grid and the trace are left untouched.
            done = True
            reward = -100
        else:
            if self.state[next_pos] == 20:
                done = True
                reward = 100
            elif self.state[next_pos] == -1:
                done = True
                reward = -100
            else:
                done = False
                target = np.array(self.end_pos)
                new_gap = np.linalg.norm(np.array(next_pos) - target)
                old_gap = np.linalg.norm(np.array(self.cur_pos) - target)
                if new_gap < old_gap:
                    reward = a[3] / 50
                else:
                    reward = -1 / a[3]

            # Update the state
            self.state[self.cur_pos] = 0
            self.cur_pos = next_pos
            self.state[self.cur_pos] = a[3]
            self.trace.append(self.cur_pos)

        obs = self.state

        return obs, reward, done, {}

    def reset(self, num_obs=9, dist=2):
        """Generate a fresh random world and return its state array.

        Args:
            num_obs: number of random obstacle blocks to place.
            dist: maximum Euclidean distance between the start and target cells.
        """
        self.state, self.cur_pos, self.end_pos = self._create_world(num_obs=num_obs, dist=dist, mode='rand')
        self.trace = [self.cur_pos]
        self.reset_trig = True

        return self.state

    def render(self, rwd, qval, mode='rgb_array'):
        """Draw the world (first call after a reset) or extend the flight path.

        Args:
            rwd: sequence of ``[episode, reward]`` pairs for the reward curve.
            qval: sequence of ``[episode, q_value]`` pairs for the Q-value curve.
            mode: accepted for gym compatibility; not used by this method.
        """
        if self.reset_trig:
            self._render_evn(rwd, qval)
        else:
            trace = np.array(self.trace)
            self.ax3d.plot3D(trace[:, 0], trace[:, 1], trace[:, 2], 'r-')
            plt.pause(1e-9)

    def _render_evn(self, rwd, qval):
        """Redraw obstacles, target and the two training curves from scratch."""
        if not self.fig_on:
            # The figure and its three axes are created once and then reused.
            self.fig = plt.figure(figsize=(8,12))
            self.ax3d = plt.subplot2grid((4, 1), (0, 0), rowspan=2, projection='3d')
            self.ax_rwd = plt.subplot2grid((4, 1), (2, 0))
            self.ax_qval = plt.subplot2grid((4, 1), (3, 0))

            self.fig_on = True

        self.reset_trig = False
        self.ax3d.clear()
        self.ax_rwd.clear()
        self.ax_qval.clear()

        self.ax_rwd.set_xlabel('episode')
        self.ax_rwd.set_ylabel('episode_reward')
        self.ax_rwd.grid()
        self.ax_qval.set_xlabel('episode')
        self.ax_qval.set_ylabel('episode_avg_q_val')
        self.ax_qval.grid()

        obs_x, obs_y, obs_z = (self.state==-1).nonzero()
        t_x, t_y, t_z = (self.state==20).nonzero()

        # Nothing to plot until at least one episode has finished.
        if len(rwd) > 0:
            rwd, qval = np.array(rwd), np.array(qval)
            self.ax_rwd.plot(rwd[:, 0], rwd[:, 1], 'b-')
            self.ax_qval.plot(qval[:, 0], qval[:, 1], 'g-')

        self.ax3d.scatter3D(obs_x, obs_y, obs_z, zdir='z', c='black')
        self.ax3d.scatter3D(t_x, t_y, t_z, zdir='z', c='blue')

    def _create_world(self, start=(0, 0, 0), end=(29, 29, 15), num_obs=9, dist=2, mode='det'):
        """
        Build a world and return ``(state, start, end)``.

        state is a 3D ndarray whose components are followings:
        current position: 1 ~ 16 (1 -> 0 m/s, 16 -> 15 m/s)
        end position: 20
        obstacle position: -1
        the others: 0

        With ``mode='det'`` a fixed set of nine obstacle columns is used
        together with the given ``start`` / ``end``. Any other mode places
        ``num_obs`` random 4 x 4 x 4 obstacle blocks (clipped at the grid
        border) and samples a free start cell plus a free end cell at most
        ``dist`` (Euclidean) away from it; ``start`` / ``end`` are then ignored.

        Raises:
            AssertionError: in 'det' mode when start or end lies on an obstacle.
            ValueError: in random mode when no valid start/end pair can be found.
        """
        # Set the shape of the grid world
        state = np.zeros((30, 30, 30), dtype=np.float32)

        if mode == 'det':
            # Locate the obstacles
            state[3:7, 24:28, 0:17], state[13:17, 24:28, 0:15], state[22:26, 24:28, 0:21] = -1, -1, -1
            state[3:7, 11:20, 0:12], state[13:17, 11:20, 0:19], state[22:26, 11:20, 0:15] = -1, -1, -1
            state[3:7, 3:7  , 0:21], state[13:17, 3:7  , 0:12], state[22:26, 3:7  , 0:19] = -1, -1, -1

            # Check whether Start or End is out of bound or on the obstacle position
            assert ((state[start] != -1) & (state[end] != -1)), 'Start or End is on the obstacle position'

            # Assign the drone's initial position as 1 and the target position as 20
            state[start] = 1
            state[end] = 20

        else:
            # One random centre per requested obstacle, so any num_obs is valid.
            for _ in range(num_obs):
                n, e, u = (random.randint(0, size - 1) for size in state.shape)
                # Clamp the lower bounds at 0: a negative slice start would wrap
                # around and select nothing, silently dropping the obstacle.
                state[max(n - 2, 0):n + 2,
                      max(e - 2, 0):e + 2,
                      max(u - 2, 0):u + 2] = -1

            free = np.argwhere(state == 0)
            if len(free) < 2:
                raise ValueError('Not enough free cells to place the start and the end')

            # Re-draw the start if it happens to have no free cell within `dist`.
            for _ in range(100):
                si = random.randrange(len(free))
                gaps = np.linalg.norm(free - free[si], axis=1)
                ei_cand = np.flatnonzero(gaps <= dist)
                ei_cand = ei_cand[ei_cand != si]
                if ei_cand.size > 0:
                    break
            else:
                raise ValueError('No free end cell within dist of the sampled start cells')
            ei = random.choice(ei_cand.tolist())

            start = tuple(int(v) for v in free[si])
            end = tuple(int(v) for v in free[ei])

            state[start] = 1
            state[end] = 20

        return state, start, end
        
        
        

In [3]:
# One replay-buffer record: the state, the action taken in it, the resulting
# next state and the reward received.
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])

class SumTree(object):
    """Binary sum tree stored in a flat array, used for proportional sampling.

    The ``capacity`` leaves occupy indices ``capacity - 1 .. 2 * capacity - 2``
    of ``self.tree``. Every internal node holds the sum of its two children, so
    ``self.tree[0]`` is the total priority.
    """

    def __init__(self, capacity):
        """Allocate an all-zero tree with ``capacity`` leaves.

        Raises:
            ValueError: if ``capacity`` is smaller than 1.
        """
        # capacity = # of leaf nodes
        if capacity < 1:
            raise ValueError('capacity must be a positive integer')
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)
        self.pt = 0  # position (among the leaves) that the next push overwrites

    def push(self, priority):
        """Write ``priority`` into the next leaf, wrapping around when full."""
        idx = self.pt + self.capacity - 1
        self.update(idx, priority)

        self.pt = (self.pt + 1) % self.capacity

    def update(self, idx, priority):
        """Set tree node ``idx`` to ``priority`` and refresh all ancestor sums."""
        diff = priority - self.tree[idx]
        self.tree[idx] = priority
        self._propagate(idx, diff)

    def _propagate(self, idx, diff):
        """Add ``diff`` to every ancestor of node ``idx``."""
        # Iterating (instead of recursing) also covers the single-leaf tree,
        # where the leaf is the root and has no parent to update.
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] += diff

    def get_leaf(self, p):
        """Return ``(tree_index, priority)`` of the leaf whose cumulative
        priority interval contains ``p``."""
        p_idx = self._retrieve(0, p)

        return p_idx, self.tree[p_idx]

    def _retrieve(self, idx, val):
        """Descend from node ``idx`` to the leaf selected by ``val``."""
        n_nodes = len(self.tree)
        while True:
            left = 2 * idx + 1
            if left >= n_nodes:
                return idx
            right = left + 1

            left_sum = self.tree[left]
            # Never step into an empty right subtree: a `val` that overshoots
            # the left sum would otherwise land on an unused, zero-priority leaf.
            if val <= left_sum or self.tree[right] <= 0:
                idx = left
            else:
                val -= left_sum
                idx = right

        
class PrioritizedReplayMemory(object):
    """Ring buffer of ``Transition`` records sampled in proportion to priority.

    Priorities live in a ``SumTree``; leaf ``i`` of the tree (tree index
    ``i + capacity - 1``) belongs to ``self.memory[i]``.
    """

    eps = 0.01             # added to every error so that no priority is zero
    alpha = 0.6            # exponent turning a clipped error into a priority
    beta = 0.4             # importance-sampling exponent, grown on each sample()
    beta_inc = 0.001       # amount added to beta per sample() call (capped at 1)
    abs_error_upper = 1.   # upper clip for errors; also the first push's priority

    def __init__(self, capacity):
        """Create an empty memory holding at most ``capacity`` transitions."""
        self.capacity = capacity
        self.memory = []
        self.mem_pt = 0  # slot that the next push writes to

        self.sum_tree = SumTree(self.capacity)

    def push(self, *args):
        """Saves a transition."""
        if len(self.memory) < self.capacity:
            self.memory.append(None)

        # New transitions get the largest priority currently stored in the
        # leaves, falling back to abs_error_upper while the tree is empty.
        max_priority = np.max(self.sum_tree.tree[-self.sum_tree.capacity:])

        if max_priority == 0:
            max_priority = self.abs_error_upper

        # push an experience and a priority of it
        self.sum_tree.push(max_priority)
        self.memory[self.mem_pt] = Transition(*args)
        self.mem_pt = (self.mem_pt + 1) % self.capacity

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions, one from each equal priority segment.

        Returns:
            ``(tis, exps, is_weights)``: the tree indices of the drawn leaves,
            the corresponding transitions, and the importance-sampling weights
            (float32 array scaled so that the largest weight is 1).
        """
        p_tot = self.sum_tree.tree[0]
        segment = p_tot / batch_size

        tis = []
        ps = []
        exps = []

        self.beta = np.min([1., self.beta + self.beta_inc])

        for i in range(batch_size):
            a = i * segment
            b = (i+1) * segment

            rand_p = random.uniform(a, b)
            ti, p = self.sum_tree.get_leaf(rand_p)

            tis.append(ti)
            ps.append(p)
            # Convert the tree index back to the matching slot of self.memory.
            exps.append(self.memory[ti - self.capacity + 1])

        probs = np.array(ps, dtype=np.float32) / p_tot
        is_weights = np.power(probs * len(self), -self.beta)
        is_weights /= is_weights.max() # normalize importance weights between 0 ~ 1 for learning stability

        return tis, exps, is_weights

    def mem_batch_update(self, tis, errors):
        """Replace the priorities of leaves ``tis`` using the given absolute errors.

        Each priority is ``min(error + eps, abs_error_upper) ** alpha``.
        ``errors`` may be any array-like and is not modified.
        """
        # Build a new array instead of `errors += eps`, which would change the
        # caller's array in place and fail on plain lists.
        shifted = np.asarray(errors) + self.eps
        clipped_errors = np.minimum(shifted, self.abs_error_upper)
        ps = np.power(clipped_errors, self.alpha)

        for ti, p in zip(tis, ps):
            self.sum_tree.update(ti, p)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [4]:
class DuelingDQN(nn.Module):
    """Dueling Q-network over a single-channel 30 x 30 x 30 voxel grid.

    Three strided 3D convolutions feed two fully connected heads: ``value``
    (one scalar per sample) and ``advantage`` (one entry per action). The heads
    are combined as ``Q = V + A - mean(A)``.

    Args:
        n_actions: number of discrete actions, i.e. the width of the output.
            Defaults to 416.
    """

    def __init__(self, n_actions=416):
        super(DuelingDQN, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=6, stride=2)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=2, stride=2)
        self.bn1 = nn.BatchNorm3d(64)
        # 1728 = 64 channels x 3 x 3 x 3, the conv output for a 30^3 input
        # (spatial size 30 -> 13 -> 6 -> 3).
        self.advantage = nn.Sequential(
            nn.Linear(1728, 864),
            nn.ReLU(),
            nn.Linear(864, n_actions))
        self.value = nn.Sequential(
            nn.Linear(1728, 864),
            nn.ReLU(),
            nn.Linear(864, 1))

    def forward(self, x):
        """Return Q-values of shape ``(batch, n_actions)`` for ``x`` of shape
        ``(batch, 1, 30, 30, 30)``."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.bn1(self.conv3(x)))
        x = x.flatten(start_dim=1)
        advantage = self.advantage(x)
        value = self.value(x)
        # Subtracting the mean advantage makes the V / A split identifiable.
        out = value + advantage - advantage.mean(dim=-1, keepdim=True)

        return out

In [5]:
# --- Hyper-parameters -------------------------------------------------------
BATCH_SIZE = 32       # transitions per optimisation step
GAMMA = 0.999         # discount factor
EPS_START = 0.9       # initial exploration probability
EPS_END = 0.1         # exploration probability approached as steps grow
EPS_DECAY = 4000      # decay constant (in steps) of the exploration schedule
TARGET_UPDATE = 100   # episodes between target-network synchronisations
SEED = 0              # seed for every random number generator used below
start = (0, 0, 0)
target = (29, 29, 15)

# Seed all three RNG sources so that Restart & Run All gives comparable runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = UAV3DGrid(start, target)

# The target network starts as a copy of the policy network and is kept in
# eval mode; only the policy network is handed to the optimizer.
policy_net = DuelingDQN().to(device)
target_net = DuelingDQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters())
# NOTE(review): one stored 30x30x30 float32 state is about 108 kB, so a full
# buffer of 2**18 transitions needs tens of GB on `device` - confirm this fits.
memory = PrioritizedReplayMemory(2**18)

n_actions = env.action_space.n
steps_done = 0          # global step counter driving the epsilon schedule
episode_durations = []  # number of steps of every finished episode


In [6]:
def select_action(state):
    """Choose an action for ``state`` with a decaying epsilon-greedy rule.

    The exploration probability falls exponentially from ``EPS_START`` towards
    ``EPS_END`` as the global ``steps_done`` counter (incremented on every
    call) grows.

    Args:
        state: network input for the current observation (batch of one).

    Returns:
        ``(q_value, action)``: the policy network's Q-value of the chosen
        action as a float, and the action index as a 1 x 1 tensor on ``device``.
    """
    global steps_done
    draw = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    if draw <= eps_threshold:
        # Explore: uniformly random action; its Q-value is still reported.
        with torch.no_grad():
            rand_ind = random.randrange(n_actions)
            rand_qval = policy_net(state)[0, rand_ind].item()
            rand_action = torch.tensor([[rand_ind]], dtype=torch.long).to(device)
        return rand_qval, rand_action

    # Exploit: action with the largest predicted Q-value.
    with torch.no_grad():
        best_q, best_idx = policy_net(state).max(1)
        return best_q.item(), best_idx.view(1, 1).to(device)
    

def optimize_model():
    """Run one double-DQN update on a prioritized batch from ``memory``.

    Does nothing until ``memory`` holds at least ``BATCH_SIZE`` transitions.
    Otherwise it samples a batch, refreshes the priorities of the sampled
    transitions with their absolute TD errors, and takes one optimizer step on
    the importance-weighted Huber loss with gradients clamped to [-2, 2].
    """
    if len(memory) < BATCH_SIZE:
        return
    tree_idxes, transitions, is_w = memory.sample(BATCH_SIZE)

    batch = Transition(*zip(*transitions)) # zip(*transition) --> transpose, *zip: unpacking

    # next_state is None for transitions that ended an episode.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  dtype=torch.bool, device=device)
    non_final_next = [s for s in batch.next_state if s is not None]

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # policy_net(state_batch): BATCH_SIZE x n_actions, action_batch: BATCH_SIZE x 1
    # by gathering, q_sa has values for each state in state_batch with taken actions in action_batch
    q_sa = policy_net(state_batch).gather(1, action_batch)

    # terminal states have zero state values, implement double Q-learning
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_next:  # torch.cat fails on an empty list (all-terminal batch)
        non_final_next_states = torch.cat(non_final_next)
        # The bootstrap target is a constant: no gradients are needed here.
        with torch.no_grad():
            # The policy net picks the action, the target net evaluates it.
            argmax_q_sa = policy_net(non_final_next_states).max(1)[1].unsqueeze(1)
            next_state_values[non_final_mask] = (
                target_net(non_final_next_states).gather(1, argmax_q_sa).squeeze(1))

    # Keep the target as BATCH_SIZE x 1 so that it lines up with q_sa; a flat
    # (BATCH_SIZE,) target would broadcast to BATCH_SIZE x BATCH_SIZE below.
    expected_q_sa = ((next_state_values * GAMMA) + reward_batch).unsqueeze(1)

    td_errors = torch.abs(expected_q_sa - q_sa).view(-1).detach().cpu().numpy()
    memory.mem_batch_update(tree_idxes, td_errors)

    # Huber loss with importance sampling weight
    is_weights = torch.as_tensor(is_w, dtype=torch.float32, device=device).view(-1, 1)
    loss = (is_weights * F.smooth_l1_loss(q_sa, expected_q_sa, reduction='none')).mean()
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        if param.grad is not None:
            param.grad.data.clamp_(-2, 2)
    optimizer.step()
    
    
    
        

In [None]:
%matplotlib tk

def normalize_obs(grid):
    """Turn a grid array into a network input on ``device``.

    Adds batch and channel axes (result shape 1 x 1 x 30 x 30 x 30 for a
    30 x 30 x 30 grid) and rescales the cell values from [-1, 20] to [0, 1].
    The arithmetic creates a new tensor, so later in-place changes of ``grid``
    by the environment do not affect the returned value.
    """
    obs = torch.from_numpy(grid).unsqueeze(0).unsqueeze(0)
    return ((obs - (-1)) / (20 - (-1))).to(device)


rwds = []    # [episode, discounted return] pairs for the reward curve
qvals = []   # [episode, mean selected Q-value] pairs for the Q-value curve
num_episodes = 20000
for i_episode in range(num_episodes):
    epi_rwd = 0
    epi_avg_q = 0
    # Curriculum: more obstacles and a more distant target as training proceeds.
    obs = env.reset(num_obs=(1+i_episode//500), dist=(2+i_episode//1000))
    norm_state = normalize_obs(obs)
    for t in count():
        q_val, action = select_action(norm_state)
        next_state, reward, done, _ = env.step(action.item())
        # Discounted return: the reward of step t is weighted by GAMMA ** t.
        epi_rwd += (GAMMA ** t) * reward
        epi_avg_q += q_val

        reward = torch.tensor([reward], dtype=torch.float32, device=device)

        # A terminal transition is stored with next_state = None.
        norm_next_state = None if done else normalize_obs(next_state)

        memory.push(norm_state, action, norm_next_state, reward)

        norm_state = norm_next_state
        env.render(rwd=rwds, qval=qvals)
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            rwds.append([i_episode, epi_rwd])
            qvals.append([i_episode, epi_avg_q / (t + 1)])
            break

    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())

print('Complete')
env.close()

        