In [1]:
import gym
import gym_Aircraft
from gym_Aircraft.envs.Aircraft_env import (
    NOT_DONE,            # 0
    CRASHED,             # 1
    AVOIDED_TIMEOUT,     # 2
    AVOIDED_IN_ADVANCE,  # 3
)

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.utils as torch_utils

import numpy as np
import matplotlib.pyplot as plt

from collections import Counter
import random

In [2]:
env = gym.make("acav-v0")
# env = gym.make("CartPole-v1")
# env.action_space.seed(960501) 



In [3]:
dim_obs = env.observation_space.shape[0]
dim_act = env.action_space.n
max_episode = 5000
max_replay = 50000
batch_size = 256 
gamma = 1.  # ㅋㅋㅋ .9에서 1로 바꾸니까 갑자기 잘됨
eps = 1.  # 얘도 1로 시작해야 되네
eps_decay = .99
eps_decay_step = 10
learning_rate = 1e-6  # 학습하다가 발산하면 얘부터 만져봐야 됨
max_grad_norm = 10

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

class DQN(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(dim_obs, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, dim_act),
            nn.LeakyReLU(.5),  # 와 이거 혁신이네
        )
        
    def forward(self, *input):
        return self.model(*input).squeeze()

In [None]:
net = DQN().to(device)
criterion = torch.nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)

memory = []
record = {'score': [], 'average loss': [], 'actions': []}

global_step = 0
train_step = 0
for episode in range(max_episode):
    s0 = env.reset()
    done = False
    net_reward = 0
    episode_loss = 0
    episode_norm = 0
    episode_step = 0
    episode_actions = []
    
    while not done:
        if train_step % eps_decay_step == eps_decay_step - 1:
            eps *= eps_decay
            eps = max(.01, eps)
            
        if random.random() > eps:
            a = np.argmax(net(torch.from_numpy(s0).float().to(device)).detach().cpu())
            a = int(a) 
        else:
            a = env.action_space.sample()
        
        # 의도적인 탐험
#         if episode < 50: a = 1
#         elif episode < 100: a = 2
#         elif episode < 150:
#             a = 1 if episode_step < 50 else 0
#         elif episode < 200:
#             a = 2 if episode_step < 70 else 0
            
        s1, reward, done, _ = env.step(a)
        s0 = s1
        net_reward += reward
        if not done: 
            episode_actions.append(a)
        
        memory.append([s0, a, reward, s1, done])
        if len(memory) > max_replay:
            del memory[0]
        
        if global_step == max_replay - 1:
            print('Training now begins')
        
# https://stackoverflow.com/questions/52770780/why-is-my-deep-q-net-and-double-deep-q-net-unstable
        if len(memory) == max_replay and global_step % 2 == 0:  # soft update
            y_batch = np.zeros(batch_size)
            s0_batch = np.zeros((batch_size, dim_obs))
            for j, sample in enumerate(random.sample(memory, batch_size)):
                mem_s0, mem_a, mem_reward, mem_s1, is_terminal = sample
                s0_batch[j] = mem_s0
                if is_terminal is NOT_DONE:
                    # y_j = r_j + gamma * max_a' Q(s_j+1, a')
                    y_batch[j] = mem_reward + gamma * torch.max(net(torch.from_numpy(mem_s1).float().to(device)).detach()).cpu().numpy()
                else:
                    y_batch[j] = mem_reward
            # optimize Q w.r.t ||max_a Q(s, a) - y||
            loss = criterion(
                torch.max(net(torch.from_numpy(s0_batch).float().to(device)), 1)[0],
                torch.from_numpy(y_batch).float().to(device)
            )
            loss.backward()
            episode_norm += torch_utils.clip_grad_norm_(net.parameters(), max_grad_norm)
            optimizer.step()
            train_step += 1
            episode_loss += loss
        episode_step += 1
        global_step += 1
    
    # episode log
    # score가 아니라 얼마나 잘 깨는지 봐야 됨
    # score는 단순히 사람 편하자고 계산하는 수임
    print(f'Epi. {episode+1:4d} score: {net_reward:4.0f}', end=' ')
    if len(memory) == max_replay:
        average_episode_loss = episode_loss/episode_step
        record['score'].append(net_reward)
        record['average loss'].append(average_episode_loss)
        record['actions'].append(episode_actions)
        c = Counter(episode_actions)
        print(
            f'E[L]: {average_episode_loss:6.1f} '
            f'e: {eps:.4f} norm: {episode_norm/episode_step:.0f} '
            f'actions: {c[0]:3d} {c[1]:3d} {c[2]:3d} {done}'
        )
    else:
        print()

Epi.    1 score: -356 
Epi.    2 score: -348 
Epi.    3 score: -388 
Epi.    4 score: -384 
Epi.    5 score: -396 
Epi.    6 score: -360 
Epi.    7 score: -356 
Epi.    8 score: -352 
Epi.    9 score: -356 
Epi.   10 score: -384 
Epi.   11 score: -396 
Epi.   12 score: -380 
Epi.   13 score: -376 
Epi.   14 score: -352 
Epi.   15 score: -368 
Epi.   16 score: -368 
Epi.   17 score: -324 
Epi.   18 score: -364 
Epi.   19 score: -360 
Epi.   20 score: -360 
Epi.   21 score: -344 
Epi.   22 score: -352 
Epi.   23 score: -360 
Epi.   24 score: -356 
Epi.   25 score: -356 
Epi.   26 score: -372 
Epi.   27 score: -352 
Epi.   28 score: -368 
Epi.   29 score: -356 
Epi.   30 score: -368 
Epi.   31 score: -364 
Epi.   32 score: -364 
Epi.   33 score: -376 
Epi.   34 score: -372 
Epi.   35 score: -340 
Epi.   36 score: -364 
Epi.   37 score: -348 
Epi.   38 score: -352 
Epi.   39 score: -384 
Epi.   40 score: -396 
Epi.   41 score: -344 
Epi.   42 score: -376 
Epi.   43 score: -348 
Epi.   44 s

Epi.  372 score: -372 
Epi.  373 score: -328 
Epi.  374 score: -376 
Epi.  375 score: -376 
Epi.  376 score: -356 
Epi.  377 score: -380 
Epi.  378 score: -364 
Epi.  379 score: -364 
Epi.  380 score: -388 
Epi.  381 score: -328 
Epi.  382 score: -356 
Epi.  383 score: -388 
Epi.  384 score: -376 
Epi.  385 score: -340 
Epi.  386 score: -340 
Epi.  387 score: -396 
Epi.  388 score: -368 
Epi.  389 score: -344 
Epi.  390 score: -344 
Epi.  391 score: -388 
Epi.  392 score: -356 
Epi.  393 score: -392 
Epi.  394 score: -356 
Epi.  395 score: -360 
Epi.  396 score: -332 
Epi.  397 score: -360 
Epi.  398 score: -372 
Epi.  399 score: -328 
Epi.  400 score: -372 
Epi.  401 score: -364 
Epi.  402 score: -348 
Epi.  403 score: -384 
Epi.  404 score: -332 
Epi.  405 score: -360 
Epi.  406 score: -324 
Epi.  407 score: -356 
Epi.  408 score: -384 
Epi.  409 score: -372 
Epi.  410 score: -340 
Epi.  411 score: -372 
Epi.  412 score: -380 
Epi.  413 score: -332 
Epi.  414 score: -352 
Epi.  415 s

Epi.  570 score: -324 E[L]:  155.2 e: 0.0100 norm: 3349 actions:  43  44  12 1
Epi.  571 score: -312 E[L]:  113.9 e: 0.0100 norm: 3090 actions:  46  47   6 1
Epi.  572 score: -336 E[L]:  160.3 e: 0.0100 norm: 3386 actions:  40  46  13 1
Epi.  573 score: -344 E[L]:  139.5 e: 0.0100 norm: 3203 actions:  38  46  15 1
Epi.  574 score: -316 E[L]:   84.5 e: 0.0100 norm: 2829 actions:  44  39  15 1
Epi.  575 score: -324 E[L]:   80.1 e: 0.0100 norm: 2784 actions:  42  40  16 1
Epi.  576 score: -320 E[L]:  212.1 e: 0.0100 norm: 3760 actions:  44  45  10 1
Epi.  577 score: -316 E[L]:  186.6 e: 0.0100 norm: 3627 actions:  45  42  12 1
Epi.  578 score: -328 E[L]:  123.4 e: 0.0100 norm: 3188 actions:  41  43  14 1
Epi.  579 score: -324 E[L]:  108.7 e: 0.0100 norm: 2983 actions:  42  40  16 1
Epi.  580 score: -320 E[L]:   82.5 e: 0.0100 norm: 2865 actions:  43  37  18 1
Epi.  581 score: -320 E[L]:   34.6 e: 0.0100 norm: 2495 actions:  43  39  16 1
Epi.  582 score: -304 E[L]:  220.4 e: 0.0100 norm: 4

Epi.  674 score: -348 E[L]:  258.1 e: 0.0100 norm: 4867 actions:  36  44  18 1
Epi.  675 score: -360 E[L]:  210.3 e: 0.0100 norm: 4844 actions:  34  50  15 1
Epi.  676 score: -376 E[L]:  445.3 e: 0.0100 norm: 6169 actions:  30  53  16 1
Epi.  677 score: -360 E[L]:  352.2 e: 0.0100 norm: 5746 actions:  34  53  12 1
Epi.  678 score: -356 E[L]:  260.6 e: 0.0100 norm: 5004 actions:  35  47  17 1
Epi.  679 score: -364 E[L]:  452.5 e: 0.0100 norm: 6266 actions:  33  52  14 1
Epi.  680 score: -360 E[L]:  309.8 e: 0.0100 norm: 5379 actions:  34  49  16 1
Epi.  681 score: -348 E[L]:  362.4 e: 0.0100 norm: 5857 actions:  36  43  19 1
Epi.  682 score: -348 E[L]:  311.9 e: 0.0100 norm: 5461 actions:  37  47  15 1
Epi.  683 score: -356 E[L]:  314.7 e: 0.0100 norm: 5536 actions:  35  51  13 1
Epi.  684 score: -324 E[L]:  319.2 e: 0.0100 norm: 5540 actions:  42  43  13 1
Epi.  685 score: -352 E[L]:  220.0 e: 0.0100 norm: 4931 actions:  37  50  13 1
Epi.  686 score: -352 E[L]:  270.5 e: 0.0100 norm: 5

Epi.  778 score: -392 E[L]: 1224.0 e: 0.0100 norm: 11996 actions:  27  51  22 1
Epi.  779 score: -388 E[L]:  865.5 e: 0.0100 norm: 10151 actions:  27  44  28 1
Epi.  780 score: -344 E[L]:  761.6 e: 0.0100 norm: 9248 actions:  42 163  23 3
Epi.  781 score: -380 E[L]:  815.4 e: 0.0100 norm: 9653 actions:  32 168  27 3
Epi.  782 score: -320 E[L]:  919.6 e: 0.0100 norm: 10420 actions:  45 153  27 3
Epi.  783 score: -368 E[L]:  897.9 e: 0.0100 norm: 10523 actions:  32  42  25 1
Epi.  784 score: -324 E[L]:  762.3 e: 0.0100 norm: 9467 actions:  45 148  33 3
Epi.  785 score: -340 E[L]: 1100.3 e: 0.0100 norm: 11820 actions:  45 151  34 3
Epi.  786 score: -320 E[L]:  864.3 e: 0.0100 norm: 10182 actions:  46 145  35 3
Epi.  787 score: -376 E[L]: 1162.4 e: 0.0100 norm: 12110 actions:  37 161  33 3
Epi.  788 score: -348 E[L]: 1148.2 e: 0.0100 norm: 12210 actions:  37  43  19 1
Epi.  789 score: -356 E[L]: 1373.8 e: 0.0100 norm: 13411 actions:  34  39  25 1
Epi.  790 score: -376 E[L]: 1063.4 e: 0.010

Epi.  881 score: -332 E[L]: 5038.8 e: 0.0100 norm: 39032 actions:  41  48  10 1
Epi.  882 score: -368 E[L]: 6238.3 e: 0.0100 norm: 45583 actions:  33  51  16 1
Epi.  883 score: -432 E[L]: 4943.6 e: 0.0100 norm: 35914 actions:  28 207   1 3
Epi.  884 score: -320 E[L]: 2351.8 e: 0.0100 norm: 21158 actions:  44  45  10 1
Epi.  885 score: -348 E[L]: 4523.1 e: 0.0100 norm: 34515 actions:  45 178   9 3
Epi.  886 score: -392 E[L]: 4861.9 e: 0.0100 norm: 36414 actions:  37 178  20 3
Epi.  887 score: -368 E[L]: 4708.4 e: 0.0100 norm: 35243 actions:  36 178  14 3
Epi.  888 score: -380 E[L]: 5542.3 e: 0.0100 norm: 41070 actions:  32 189   6 3
Epi.  889 score: -344 E[L]: 3634.3 e: 0.0100 norm: 29355 actions:  39  53   8 1
Epi.  890 score: -344 E[L]: 4402.3 e: 0.0100 norm: 32892 actions:  43 174  12 3
Epi.  891 score: -404 E[L]: 4328.9 e: 0.0100 norm: 32400 actions:  37 190  11 3
Epi.  892 score: -412 E[L]: 4948.8 e: 0.0100 norm: 37167 actions:  28 196   7 3
Epi.  893 score: -368 E[L]: 5329.0 e: 0.

In [None]:
k = list(map(lambda x: x.grad, net.parameters()))
l = list(map(lambda x: x.data, net.parameters()))
fig = plt.figure(figsize=(15, 15))
fig.suptitle('Pretty well trained', fontsize=25)

ax = plt.subplot(211)
plt.hist([
    k[0].cpu().numpy().reshape(-1),
    k[2].cpu().numpy().reshape(-1),
], bins=20, color=['r', 'orange'], label=['Layer 1 (input)', 'Layer 2 (output)'])
ax.set_title(r'Gradients $\partial \mathcal{L} / \partial w$', fontsize=22)
ax.legend(fontsize=15)

ax = plt.subplot(212)
plt.hist([
    l[0].cpu().numpy().reshape(-1),
    l[2].cpu().numpy().reshape(-1),
], bins=20, color=['r', 'orange'], label=['Layer 1 (input)', 'Layer 2 (output)'])
ax.set_title('Weights', fontsize=22)
ax.legend(fontsize=15)

plt.show()

In [None]:
fig, host = plt.subplots()
fig.set_size_inches(15, 15)
fig.subplots_adjust(right=.75)

par1 = host.twinx()
# par2 = host.twinx()

# par2.spines['right'].set_position(('axes', 1.2))
# par2.spines['right'].set_visible(True)

p1, = host.plot(record['score'], 'r', label='score')
p2, = par1.plot(record['average loss'], 'orange', label='average loss')

host.set_xlabel('episode', fontsize=22)
host.set_ylabel('score', fontsize=22)
par1.set_ylabel('average loss', fontsize=22)

host.yaxis.label.set_color(p1.get_color())
par1.yaxis.label.set_color(p2.get_color())

tkw = dict(size=4, width=1.5)
host.tick_params(axis='x', **tkw) 
host.tick_params(axis='y', colors=p1.get_color(), **tkw)
par1.tick_params(axis='y', colors=p2.get_color(), **tkw)

lines = [p1, p2]
host.legend(lines, [l.get_label() for l in lines], fontsize=20)  # 이렇게 안 하면 child에 그려진 애들은 legend가 안 붙나봄
plt.show() 

In [None]:
# r: range
# vc: closing velocity
# los: line of sight angle
# daz: azimuthal rate (horizontal look angle, right +)
# dlos: los rate
print('Env Reset')
obs0 = env.reset() 
print('observations')
for label, val in zip(['range', 'vc', 'los', 'daz', 'dlos'], obs0):
    print(f'  {label:8}:\t{val}')

In [None]:
print('Env Get Started')
obs, reward, done, info = env.step(a)
a =  np.argmax(net(torch.from_numpy(obs).float().to(device)).detach().cpu())
print('action:\t', a.item()) 
print('reward:\t', reward)
print('done:\t', done) 
print('observations')
for label, val in zip(['range', 'vc', 'los', 'daz', 'dlos'], obs):
    print(f'  {label:10}:\t{val}')

print('\ninfos')
for label, val in zip(['hdot_cmd', 'range', 'elev', 'azim', 'Pm_NED', 'Pt_NED', 'h'], info):
    print(f'  {label:10}:\t{val}')