In [41]:
import gym
from gym import spaces
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
from mpl_toolkits import mplot3d
from collections import namedtuple
from itertools import count
from typing import Dict, List, Set, Tuple

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Torch device used by the rest of the notebook: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [42]:
class UAV3DGrid(gym.Env):
    """3D grid world in which a drone must reach a target cell without hitting obstacles.

    Cell encoding of the observation grid:
        -1      obstacle
         0      free cell
         20     target
         1..16  the drone (the value is the speed level chosen by the last action)

    Actions are integers in ``[0, num_dir * num_speed)``. ``action % num_dir`` selects one of the
    26 neighbouring cells (every combination of -1/0/+1 per axis except "stay in place") and
    ``action // num_dir`` selects the speed level.
    """

    metadata = {'render.modes': ['console', 'rgb_array']}

    def __init__(self, state, start):
        """
        Args:
            state: 3D array describing the world (see the class docstring for the encoding).
            start: (x, y, z) index of the drone's initial cell.

        Raises:
            ValueError: if ``state`` is not a 3D array.
        """
        super(UAV3DGrid, self).__init__()

        world = np.array(state, dtype=np.float32)  # private copy; the caller's array is never mutated
        if world.ndim != 3:
            raise ValueError('state must be a 3D array, got %d dimension(s)' % world.ndim)

        # Keep a pristine copy so reset() can really restore the world.
        self.ini_state = world
        self.state = world.copy()  # Observation
        self.ini_pos = tuple(int(c) for c in start)  # For reset method
        self.cur_pos = self.ini_pos  # Track the current position of the drone

        self.trace = [self.ini_pos]
        self.fig_on = False

        self.num_dir = 26
        self.num_speed = 16  # 0 ~ 15 m/s

        self.observation_space = spaces.Box(low=-1, high=20, shape=self.state.shape, dtype=np.float32)
        self.action_space = spaces.Discrete(self.num_dir * self.num_speed)

    def step(self, action):
        """Apply one action and return ``(observation, reward, done, info)``.

        Rewards: +100 for reaching the target, -100 for leaving the grid or hitting an
        obstacle, ``-1 / speed`` for any other move.

        The returned observation is never modified by later calls, so it can be stored
        (e.g. in a replay buffer) without copying.

        Raises:
            ValueError: if ``action`` is outside ``[0, num_dir * num_speed)``.
        """
        action = int(action)
        n_actions = self.num_dir * self.num_speed
        if not 0 <= action < n_actions:
            raise ValueError('action must be in [0, %d), got %d' % (n_actions, action))

        # action: 0 ~ 415, transform the action index to direction, speed vector
        speed = action // self.num_dir + 1
        code = action % self.num_dir + 1  # 1..26; skipping 0 excludes the "no movement" direction
        move = np.zeros(3, dtype=int)  # (dN, dE, dU)
        for i in range(2, -1, -1):
            # Base-3 digit per axis: 0 -> stay, 1 -> +1, 2 -> -1
            digit, code = divmod(code, 3 ** i)
            move[i] = -1 if digit == 2 else digit

        # Calculate the next position
        next_pos = tuple(int(c) for c in np.array(self.cur_pos) + move)

        # Check whether the drone is out of bound first,
        # then check the drone is on the position of the target or the obstacles
        out_of_bound = any(c < 0 or c >= dim for c, dim in zip(next_pos, self.state.shape))
        if out_of_bound:
            done = True
            reward = -100
        else:
            cell = self.state[next_pos]
            if cell == 20:
                done = True
                reward = 100
            elif cell == -1:
                done = True
                reward = -100
            else:
                done = False
                reward = -1 / speed

            # Update the state on a fresh array so observations handed out earlier stay intact
            new_state = self.state.copy()
            new_state[self.cur_pos] = 0
            new_state[next_pos] = speed
            self.state = new_state
            self.cur_pos = next_pos
            self.trace.append(self.cur_pos)

        return self.state, reward, done, {}

    def reset(self):
        """Restore the initial world and drone position and return the initial observation."""
        self.state = self.ini_state.copy()
        self.cur_pos = self.ini_pos
        self.trace = [self.cur_pos]
        self.fig_on = False

        return self.state

    def render(self, mode='rgb_array'):
        """Draw the scene on the first call after a reset, then the drone's trace on later calls.

        ``mode`` is accepted for API compatibility but is not used.
        """
        if not self.fig_on:
            self._render_evn()
        else:
            trace = np.array(self.trace)
            self.ax.plot3D(trace[:, 0], trace[:, 1], trace[:, 2], 'r-')
            plt.pause(1e-9)

    def _render_evn(self):
        """Create a cleared 3D axes and scatter the obstacles (black) and the target (blue)."""
        self.fig_on = True
        self.ax = plt.axes(projection='3d')
        self.ax.clear()
        plt.show(block=False)

        # Read the static scenery from the pristine grid: the live state may have had
        # a target/obstacle cell overwritten by the drone marker on a terminal step.
        obs_x, obs_y, obs_z = (self.ini_state == -1).nonzero()
        t_x, t_y, t_z = (self.ini_state == 20).nonzero()

        self.ax.scatter3D(obs_x, obs_y, obs_z, zdir='z', c='black')
        self.ax.scatter3D(t_x, t_y, t_z, zdir='z', c='blue')
        
        
        

In [43]:
def create_world(start: Tuple[int, int, int], end: Tuple[int, int, int]):
    """Build the 30 x 30 x 30 grid world with its fixed obstacle layout.

    state is a 3D ndarray whose components are followings:
    current position: 1 ~ 16 (1 -> 0 m/s, 16 -> 15 m/s)
    end position: 20
    obstacle position: -1
    the others: 0

    Args:
        start: (x, y, z) index of the drone's initial cell.
        end: (x, y, z) index of the target cell.

    Returns:
        The float32 grid, or ``None`` (after printing a message) when ``start`` or ``end``
        lies outside the grid. Negative coordinates count as out of bound.

    Raises:
        AssertionError: if ``start`` or ``end`` lies on an obstacle.
    """
    # Set the shape of the grid world
    state = np.zeros((30, 30, 30), dtype=np.float32)

    # Locate the obstacles: a 3 x 3 arrangement of pillars rising from z = 0.
    x_ranges = ((3, 7), (13, 17), (22, 26))
    y_ranges = ((24, 28), (11, 20), (3, 7))
    heights = (
        (17, 15, 21),  # y in 24:28
        (12, 19, 15),  # y in 11:20
        (21, 12, 19),  # y in 3:7
    )
    for (y0, y1), row_heights in zip(y_ranges, heights):
        for (x0, x1), top in zip(x_ranges, row_heights):
            state[x0:x1, y0:y1, 0:top] = -1

    def _in_bounds(pos):
        # Explicit range test: plain indexing would silently accept negative
        # coordinates (they wrap around) and short tuples (they select a sub-array).
        return len(pos) == state.ndim and all(0 <= c < dim for c, dim in zip(pos, state.shape))

    # Check whether Start or End is out of bound or on the obstacle position
    if not (_in_bounds(start) and _in_bounds(end)):
        print('Start or End is out of bound')
        return None
    start, end = tuple(start), tuple(end)
    # Raised explicitly (not via `assert`) so the check survives `python -O`.
    if state[start] == -1 or state[end] == -1:
        raise AssertionError('Start or End is on the obstacle position')

    # Assign the drone's initial position as 1 and the target position as 20
    state[start] = 1
    state[end] = 20

    return state
    

In [44]:
# One step of experience: the state, the action taken, the resulting state and the reward.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records.

    Once ``capacity`` transitions are stored, each new one overwrites the oldest.
    """

    def __init__(self, capacity):
        """
        Args:
            capacity: maximum number of transitions kept; must be positive.

        Raises:
            ValueError: if ``capacity`` is not positive.
        """
        if capacity <= 0:
            raise ValueError('capacity must be positive, got %r' % (capacity,))
        self.capacity = capacity
        self.memory = []
        self.position = 0  # index of the slot the next push writes to

    def push(self, *args):
        """Saves a transition.

        ``args`` are the four ``Transition`` fields in order. The record is built before
        the buffer is touched, so a call with the wrong number of fields raises
        ``TypeError`` and leaves the buffer unchanged.
        """
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [45]:
class DQN(nn.Module):
    """3D-convolutional Q-network mapping a grid observation to one Q-value per action.

    Input:  float tensor of shape (batch, 1, 30, 30, 30).
    Output: float tensor of shape (batch, n_actions).
    """

    def __init__(self, n_actions=416):
        """
        Args:
            n_actions: number of discrete actions (size of the output layer).
        """
        super(DQN, self).__init__()
        # Spatial size per axis: 30 -> 13 -> 6 -> 3, so conv3 yields 64 * 3**3 = 1728 features.
        self.conv1 = nn.Conv3d(1, 16, kernel_size=6, stride=2)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=2, stride=2)
        self.fc1 = nn.Linear(1728, 864)
        self.fc2 = nn.Linear(864, n_actions)

        for layer in [self.conv1, self.conv2, self.conv3, self.fc1, self.fc2]:
            torch.nn.init.kaiming_uniform_(layer.weight)

    def forward(self, x):
        """Return the Q-values for a batch of observations."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc1(torch.flatten(x, start_dim=1)))
        # No activation on the output layer: Q-values must be free to go negative
        # (step penalties and -100 crashes); a ReLU here would clamp them at zero.
        return self.fc2(x)

In [46]:
# --- Hyperparameters ---
SEED = 0              # seed for Python, NumPy and PyTorch RNGs
BATCH_SIZE = 128      # transitions per optimisation step
GAMMA = 0.999         # discount factor
EPS_START = 0.9       # initial exploration rate
EPS_END = 0.05        # final exploration rate
EPS_DECAY = 200       # decay constant (in action selections) of the exploration rate
TARGET_UPDATE = 10    # episodes between target-network synchronisations
start = (0, 0, 0)
target = (29, 29, 15)

# Seed every RNG the experiment uses so that a fresh run is repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Environment ---
state = create_world(start, target)
if state is None:
    # create_world reports out-of-bound positions by returning None; stop here
    # instead of handing None to the environment.
    raise ValueError('create_world failed: start or target is outside the grid')
env = UAV3DGrid(state, start)

# --- Networks: the target net starts as a frozen copy of the policy net ---
policy_net = DQN().to(device)
target_net = DQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = optim.Adam(policy_net.parameters())
memory = ReplayMemory(10000)

# --- Bookkeeping ---
n_actions = env.action_space.n
steps_done = 0            # number of actions selected so far (drives the epsilon decay)
episode_durations = []    # number of steps of each finished episode


In [47]:
def select_action(state):
    """Pick an action epsilon-greedily and advance the global step counter.

    The exploration rate decays exponentially from EPS_START to EPS_END as
    ``steps_done`` grows.

    Args:
        state: observation tensor accepted by ``policy_net`` (batch of one).

    Returns:
        A 1 x 1 ``torch.long`` tensor on ``device`` holding the action index.
    """
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1

    if sample > eps_threshold:
        # Exploit: index of the largest Q-value. The state is moved to the network's
        # device first, otherwise the forward pass fails when CUDA is in use.
        with torch.no_grad():
            return policy_net(state.to(device)).max(1)[1].view(1, 1)

    # Explore: uniformly random action.
    return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)
    

def optimize_model():
    """Run one DQN optimisation step on a random batch from the replay memory.

    Does nothing until the memory holds at least BATCH_SIZE transitions.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)

    # Transpose: list of Transitions -> one Transition whose fields are tuples.
    batch = Transition(*zip(*transitions))

    # Terminal transitions are stored with next_state = None.
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]

    # Move everything to the networks' device; a no-op when it is already there.
    state_batch = torch.cat(batch.state).to(device)
    action_batch = torch.cat(batch.action).to(device)
    reward_batch = torch.cat(batch.reward).to(device)

    # policy_net(state_batch): BATCH_SIZE x n_actions, action_batch: BATCH_SIZE x 1
    # by gathering, q_sa has values for each state in state_batch with taken actions in action_batch
    q_sa = policy_net(state_batch).gather(1, action_batch)

    # terminal states have zero state values
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    if non_final_next_states:
        # Guarded: a batch made only of terminal transitions would make torch.cat fail.
        with torch.no_grad():
            next_q = target_net(torch.cat(non_final_next_states).to(device))
        next_state_values[non_final_mask] = next_q.max(1)[0]  # max Q(s+1,a) = V(s+1)
    expected_q_sa = (next_state_values * GAMMA) + reward_batch

    # Huber loss
    loss = F.smooth_l1_loss(q_sa, expected_q_sa.unsqueeze(1))

    optimizer.zero_grad()
    loss.backward()
    # Clamp every gradient element to [-1, 1]; parameters without a gradient are skipped.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
    optimizer.step()
    
    
    
        

state: 30x30x30, action: 1, next_state: 30x30x30, reward: 1

batch update 방식이기 때문에 모두 Batch_size x dim의 형태로 바꿔 줘야 함
action 같은 경우는 network output 자체가 Batch_size x dim의 형태로 나옴

state와 next_state는 네트워크 입력에 맞게 Batch_size x 1(채널) x 30 x 30 x 30의 형태로 받도록 코드를 수정해야 함

In [48]:
%matplotlib tk

num_episodes = 1000
for i_episode in range(num_episodes):
    env.reset()
    state = torch.from_numpy(env.state).unsqueeze(0).unsqueeze(0)  # shape:  1 (# batch) x 1 (# channels) x 30 x 30 x 30
    for t in count():
        action = select_action(state)
        next_state, reward, done, _ = env.step(action.item())
        reward = torch.tensor([reward], dtype=torch.float32, device=device)
        
        if not done:
            next_state = torch.from_numpy(next_state).unsqueeze(0).unsqueeze(0)
        else:
            next_state = None
            
        memory.push(state, action, next_state, reward)
        
        state = next_state
        env.render()
        optimize_model()
        if done:
            episode_durations.append(t + 1)
            break
        
    if i_episode % TARGET_UPDATE == 0:
        target_net.load_state_dict(policy_net.state_dict())
        
    print(f'Episode: {i_episode}, Episode_durations: {episode_durations[i_episode]}')
        
print('Complete')
env.close()

        

Episode: 0, Episode_durations: 1
Episode: 1, Episode_durations: 1
Episode: 2, Episode_durations: 1
Episode: 3, Episode_durations: 1
Episode: 4, Episode_durations: 1
Episode: 5, Episode_durations: 1
Episode: 6, Episode_durations: 2
Episode: 7, Episode_durations: 2
Episode: 8, Episode_durations: 1
Episode: 9, Episode_durations: 1
Episode: 10, Episode_durations: 1
Episode: 11, Episode_durations: 1
Episode: 12, Episode_durations: 1
Episode: 13, Episode_durations: 1
Episode: 14, Episode_durations: 1
Episode: 15, Episode_durations: 1
Episode: 16, Episode_durations: 1
Episode: 17, Episode_durations: 1
Episode: 18, Episode_durations: 1
Episode: 19, Episode_durations: 1
Episode: 20, Episode_durations: 1
Episode: 21, Episode_durations: 5
Episode: 22, Episode_durations: 1
Episode: 23, Episode_durations: 10
Episode: 24, Episode_durations: 1
Episode: 25, Episode_durations: 3
Episode: 26, Episode_durations: 6
Episode: 27, Episode_durations: 1
Episode: 28, Episode_durations: 4
Episode: 29, Episode_du

KeyboardInterrupt: 