In [1]:
import gym

# Smoke test: drive the snake environment with uniformly sampled actions to
# check that it can be created, reset and stepped. No learning happens here.
env = gym.make('gym_snake:snake-v0')
#pi = Policy()

# 30 independent episodes, each cut off after at most 100 steps.
for i_episode in range(30):
    observation = env.reset()
#    env.render()
    for t in range(100):

        # Random policy: draw an action straight from the action space.
        action = env.action_space.sample()

        # The fourth element returned by step() is ignored.
        observation, reward, done, _ = env.step(action)
#        env.render()

        if done:
            # t is zero-based, so the episode length is t + 1.
            print("Episode finished after {} timesteps".format(t+1))
            break

# Release the environment once all episodes have been played.
env.close()

Episode finished after 4 timesteps
Episode finished after 3 timesteps
Episode finished after 4 timesteps
Episode finished after 5 timesteps
Episode finished after 3 timesteps
Episode finished after 3 timesteps
Episode finished after 5 timesteps
Episode finished after 7 timesteps
Episode finished after 7 timesteps
Episode finished after 5 timesteps
Episode finished after 9 timesteps
Episode finished after 5 timesteps
Episode finished after 3 timesteps
Episode finished after 6 timesteps
Episode finished after 4 timesteps
Episode finished after 4 timesteps
Episode finished after 4 timesteps
Episode finished after 7 timesteps
Episode finished after 5 timesteps
Episode finished after 4 timesteps
Episode finished after 4 timesteps
Episode finished after 21 timesteps
Episode finished after 11 timesteps
Episode finished after 4 timesteps
Episode finished after 3 timesteps
Episode finished after 7 timesteps
Episode finished after 12 timesteps
Episode finished after 10 timesteps
Episode finished

In [2]:
from JSAnimation.IPython_display import display_animation
from matplotlib import animation
from IPython.display import display

def display_frames_as_gif(frames, epi=0):
    """Render a sequence of image frames into a video file.

    Despite the name, the result is written as ``snake_DQN_<epi>.mp4`` in the
    current working directory.

    Args:
        frames: Non-empty sequence of image arrays; the first frame's
            ``shape[0]`` / ``shape[1]`` determine the figure size.
        epi: Integer used in the output file name.
    """
    first_frame = frames[0]
    # Size the figure so that one image pixel maps to one figure pixel at 72 dpi.
    width_in = first_frame.shape[1] / 72.0
    height_in = first_frame.shape[0] / 72.0
    plt.figure(figsize=(width_in, height_in), dpi=72)

    image = plt.imshow(first_frame)
    plt.axis('off')

    def _show_frame(frame_no):
        # Swap the displayed pixels for the requested frame.
        image.set_data(frames[frame_no])

    movie = animation.FuncAnimation(
        plt.gcf(),
        _show_frame,
        frames=len(frames),
        interval=100,
    )
    movie.save('snake_DQN_%d.mp4' % epi)

    # Stop the animation timer and drop the figure so nothing lingers in the kernel.
    movie.event_source.stop()
    del movie
    plt.close()

In [3]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from collections import namedtuple

# One replay-memory record: (state, action, next_state, reward).
# The training loop stores next_state=None for the step that ends an episode,
# and Brain.replay filters those entries out when bootstrapping.
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

In [4]:
# Gym environment id handed to gym.make() by Environment.
ENV = 'gym_snake:snake-v0'
# Discount factor applied to the next-state value in the TD target.
GAMMA = 0.99
# Upper bound on the number of steps played in a single episode.
MAX_STEPS = 5000
# Upper bound on the number of training episodes.
NUM_EPISODES = 50000

In [5]:
class ReplayMemory:
    """Fixed-size ring buffer of ``Transition`` records used for experience replay."""

    def __init__(self, CAPACITY):
        """Create an empty buffer that holds at most ``CAPACITY`` transitions."""
        self.index = 0          # slot the next push() writes to
        self.memory = []        # grows up to ``capacity`` entries, then is overwritten in place
        self.capacity = CAPACITY

    def push(self, state, action, state_next, reward):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        slot = self.index
        still_growing = len(self.memory) < self.capacity
        if still_growing:
            # Reserve a fresh slot; it is filled in right below.
            self.memory.append(None)

        self.memory[slot] = Transition(state, action, state_next, reward)

        # Advance the write cursor, wrapping around at capacity.
        self.index = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.memory)

In [6]:
import random
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from IPython.display import clear_output
from tqdm import tqdm, trange

# Number of transitions sampled per learning step; Brain.replay also skips
# learning until the memory holds at least this many transitions.
BATCH_SIZE = 32
# Maximum number of transitions kept in the replay memory.
CAPACITY = 10000

class Brain:
    """DQN learner: a small fully connected Q-network plus its replay memory.

    The network maps a state vector of length ``num_states`` to one Q-value per
    action and is trained with Adam on mini-batches drawn from a
    ``ReplayMemory`` of size ``CAPACITY``.
    """

    def __init__(self, num_states, num_actions):
        """Build the replay memory, the Q-network and its optimizer.

        Args:
            num_states: Size of the network input (length of a state vector).
            num_actions: Number of discrete actions (size of the network output).
        """
        self.num_actions = num_actions

        self.memory = ReplayMemory(CAPACITY)

        # Layer names are kept stable so saved state dicts stay loadable.
        self.model = nn.Sequential()
        self.model.add_module('fc1', nn.Linear(num_states, 500))
        self.model.add_module('relu1', nn.ReLU())
        self.model.add_module('fc2', nn.Linear(500, 500))
        self.model.add_module('relu2', nn.ReLU())
        self.model.add_module('fc4', nn.Linear(500, num_actions))

        print(self.model)

        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001)

    def replay(self):
        """Run one Q-learning update on a random mini-batch from memory.

        Does nothing until the memory holds at least ``BATCH_SIZE`` transitions.
        The TD target is ``reward + GAMMA * max_a Q(next_state, a)``, with the
        bootstrap term fixed to 0 for terminal transitions
        (``next_state is None``).
        """
        if len(self.memory) < BATCH_SIZE:
            return

        transitions = self.memory.sample(BATCH_SIZE)

        # Turn a list of Transitions into one Transition of tuples (column-wise).
        batch = Transition(*zip(*transitions))

        state_batch = torch.cat(batch.state)
        action_batch = torch.cat(batch.action)
        reward_batch = torch.cat(batch.reward)

        # Boolean mask of transitions that did not end the episode. A bool
        # tensor is used because indexing with uint8 masks is deprecated.
        non_final_mask = torch.tensor(
            [s is not None for s in batch.next_state], dtype=torch.bool)

        self.model.eval()

        # Q(s, a) for the action that was actually taken.
        state_action_values = self.model(state_batch).gather(1, action_batch)

        # max_a Q(s', a); stays 0 for terminal transitions.
        next_state_values = torch.zeros(BATCH_SIZE)
        if non_final_mask.any():
            # Only concatenate when there is something to concatenate:
            # torch.cat rejects an empty list, which happens whenever every
            # sampled transition is terminal.
            non_final_next_states = torch.cat(
                [s for s in batch.next_state if s is not None])
            with torch.no_grad():
                next_state_values[non_final_mask] = \
                    self.model(non_final_next_states).max(1)[0]

        expected_state_action_values = reward_batch + GAMMA * next_state_values

        self.model.train()

        # Huber loss between predicted and target Q-values, both shaped (batch, 1).
        loss = F.smooth_l1_loss(state_action_values,
                                expected_state_action_values.unsqueeze(1))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def decide_action(self, state, episode):
        """Choose an action epsilon-greedily.

        Epsilon decays as ``0.5 / (episode + 1)``.

        Args:
            state: Network input; must be 2-D because the arg-max is taken
                along dimension 1.
            episode: Zero-based episode index used for the epsilon schedule.

        Returns:
            ``LongTensor`` of shape (1, 1) holding the chosen action index.
        """
        epsilon = 0.5 * (1 / (episode + 1))

        if epsilon < np.random.uniform(0, 1):
            # Exploit: action with the largest predicted Q-value.
            self.model.eval()
            with torch.no_grad():
                action = self.model(state).max(1)[1].view(1, 1)
        else:
            # Explore: uniformly random action.
            action = torch.LongTensor([[random.randrange(self.num_actions)]])

        return action

In [7]:
class Agent:
    """Thin facade over ``Brain``: acts, remembers and learns on behalf of the training loop."""

    def __init__(self, num_states, num_actions):
        """Create the underlying ``Brain`` for the given state and action sizes."""
        self.brain = Brain(num_states, num_actions)

    def update_q_function(self):
        """Perform one learning step by replaying stored experience."""
        self.brain.replay()

    def get_action(self, state, episode):
        """Return the action the brain selects for ``state`` during ``episode``."""
        return self.brain.decide_action(state, episode)

    def memorize(self, state, action, state_next, reward):
        """Append one transition to the brain's replay memory."""
        replay_buffer = self.brain.memory
        replay_buffer.push(state, action, state_next, reward)

In [8]:
class Environment:
    """Couples the Gym environment named by ``ENV`` with a DQN agent and runs training."""

    def __init__(self, agent=None):
        """Create the Gym environment and attach an agent.

        Args:
            agent: Agent to reuse. When it is omitted (or falsy), a new
                ``Agent`` is built from the sizes of the environment's
                observation and action spaces.
        """
        self.env = gym.make(ENV)
        # Both sizes are read through ``.n`` — assumes the observation space
        # exposes an ``n`` attribute just like the action space; TODO confirm.
        num_states = self.env.observation_space.n
        num_actions = self.env.action_space.n
        # NOTE(review): this is a truthiness test, not ``is not None``.
        if agent:
            self.agent = agent
        else:
            self.agent = Agent(num_states, num_actions)
        
    def run(self):
        """Train the agent and return it.

        Plays up to ``NUM_EPISODES`` episodes of at most ``MAX_STEPS`` steps.
        Every step is stored in the agent's replay memory and followed by one
        learning update. An episode counts as successful when its summed
        reward is at least 10; after 10 successful episodes in a row, one more
        episode is played with frame capture, saved through
        ``display_frames_as_gif`` and training stops.

        Returns:
            The agent held by this environment (``self.agent``).
        """
        # Sliding window holding the lengths of the 10 most recent episodes.
        episode_10_list = np.zeros(10)
        # Number of consecutive successful episodes so far.
        complete_episodes = 0
        # Set once the success streak is reached; the following episode is recorded.
        episode_final = False
        frames = []
        # episode index -> (last step index, summed reward) for the current streak.
        list_outputs = {}
        
        for episode in range(NUM_EPISODES):
            observation = self.env.reset()
            
            state = observation
            state = torch.from_numpy(state).type(torch.FloatTensor)
            # No batch dimension is added here; presumably the observation
            # already has one, since Brain.decide_action reduces over dim 1 —
            # TODO confirm against the environment.
            #state = torch.unsqueeze(state, 0)
            sum_reward = 0
            
            for step in range(MAX_STEPS):
                # Capture frames only during the final, recorded episode.
                if episode_final is True:
                    frames.append(self.env.render(mode='rgb_array'))
                    
                action = self.agent.get_action(state, episode)
                
                # The action tensor is unwrapped to a plain Python number for the env.
                observation_next, reward, done, _ = self.env.step(action.item())
                sum_reward += reward
                
                if done or step+1==MAX_STEPS:
                    # Terminal (or truncated) step: there is no next state to bootstrap from.
                    state_next = None
                    
                    # Slide the window: drop the oldest length, append this episode's.
                    episode_10_list = np.hstack((episode_10_list[1:], step + 1))
                    
                    if sum_reward < 10:
                        # Unsuccessful episode: reset the streak and its log.
                        #reward = torch.FloatTensor([-1.0])
                        complete_episodes = 0
                        list_outputs = {}
                    else:
                        #reward = torch.FloatTensor([1.0])
                        complete_episodes += 1  
                else:
                    #reward = torch.FloatTensor([0.0])
                    state_next = observation_next
                    state_next = torch.from_numpy(state_next).type(torch.FloatTensor)
                    
                # The environment's raw reward is stored as a 1-element float tensor.
                reward = torch.FloatTensor([reward])
                
                self.agent.memorize(state, action, state_next, reward)
                
                self.agent.update_q_function()
                
                state = state_next
                
                if done:
                    # Redraw the log of the current success streak in place.
                    clear_output(wait=True)
                    list_outputs[episode] = (step, sum_reward)
                    # NOTE(review): the mean printed on every line is the current
                    # window mean, not the mean at the time that episode ended.
                    # The Korean text reads "average step count over the last 10 episodes".
                    for epi in list_outputs:
                        print("%d Episode: Finished after %3d steps(%+3d) : 최근 10 에피소드의 평균 단계 수 = %.1lf" % (epi, list_outputs[epi][0] + 1, list_outputs[epi][1], episode_10_list.mean()))
                    break
                    
            if episode_final is True:
                # The recorded episode is over: write the video and stop training.
                display_frames_as_gif(frames)
                break
                
            if complete_episodes >= 10:
                # The Korean text reads "10 consecutive successful episodes".
                print('10 에피소드 연속 성공')
                episode_final = True
                
        return self.agent

In [9]:
# Build the environment with a freshly created agent and train it.
snake_env = Environment()
# NOTE(review): run() returns the Agent, so ``model`` holds an Agent object,
# not the neural network itself.
model = snake_env.run()

19210 Episode: Finished after 117 steps( +9) : 최근 10 에피소드의 평균 단계 수 = 128.6
19211 Episode: Finished after 115 steps(+13) : 최근 10 에피소드의 평균 단계 수 = 128.6
19212 Episode: Finished after 159 steps(+10) : 최근 10 에피소드의 평균 단계 수 = 128.6
19213 Episode: Finished after 143 steps(+11) : 최근 10 에피소드의 평균 단계 수 = 128.6
19214 Episode: Finished after 123 steps(+10) : 최근 10 에피소드의 평균 단계 수 = 128.6
19215 Episode: Finished after 133 steps(+13) : 최근 10 에피소드의 평균 단계 수 = 128.6
19216 Episode: Finished after 121 steps(+11) : 최근 10 에피소드의 평균 단계 수 = 128.6
19217 Episode: Finished after 148 steps(+10) : 최근 10 에피소드의 평균 단계 수 = 128.6
19218 Episode: Finished after 105 steps(+12) : 최근 10 에피소드의 평균 단계 수 = 128.6
19219 Episode: Finished after 140 steps(+11) : 최근 10 에피소드의 평균 단계 수 = 128.6
19220 Episode: Finished after  99 steps(+12) : 최근 10 에피소드의 평균 단계 수 = 128.6
19221 Episode: Finished after 115 steps(+12) : 최근 10 에피소드의 평균 단계 수 = 128.6


In [10]:
def show_video_in_jupyter_nb(width, height, video_url):
    """Return an HTML ``<video>`` element that plays an MP4 inside the notebook.

    Args:
        width: Player width, used as the ``width`` attribute.
        height: Player height, used as the ``height`` attribute.
        video_url: Path or URL of the MP4 file.

    Returns:
        ``IPython.display.HTML`` object; make it the last expression of a cell
        to render the player.
    """
    from html import escape
    from IPython.display import HTML

    # Every value is escaped and the src attribute is quoted, so paths that
    # contain spaces, quotes or '&' still produce well-formed markup.
    return HTML("""<video width="{}" height="{}" controls>
    <source src="{}" type="video/mp4">
    </video>""".format(escape(str(width)), escape(str(height)), escape(str(video_url))))
# 'snake_DQN_0.mp4' is the file name display_frames_as_gif writes when it is
# called with its default epi=0, which is how Environment.run calls it.
video_url = 'snake_DQN_0.mp4'
show_video_in_jupyter_nb(900, 400,video_url)

In [11]:
def send_email(to_addr="01ee@naver.com", from_addr="01eecubic@gmail.com",
               password=None, subject='Snake result',
               body='snake 학습이 완료되었습니다.'):
    """Send a short notification e-mail through Gmail's SMTP server.

    Args:
        to_addr: Recipient address.
        from_addr: Sender address; also used as the SMTP login name.
        password: SMTP password for ``from_addr``. When omitted it is read
            from the ``SNAKE_SMTP_PASSWORD`` environment variable.
        subject: Subject line of the message.
        body: Plain-text body. The default is Korean for
            "snake training has completed."

    Raises:
        RuntimeError: If no password was passed and the environment variable
            is unset or empty.
        smtplib.SMTPException: If connecting, authenticating or sending fails.
    """
    import os
    import smtplib
    from email.mime.text import MIMEText

    # SECURITY: credentials must never live in the notebook source. Any
    # password that was previously committed with this file should be treated
    # as compromised and revoked.
    if password is None:
        password = os.environ.get('SNAKE_SMTP_PASSWORD')
    if not password:
        raise RuntimeError(
            "No SMTP password available: pass password=... or set the "
            "SNAKE_SMTP_PASSWORD environment variable.")

    msg = MIMEText(body)
    msg['Subject'] = subject
    msg['From'] = from_addr
    msg['To'] = to_addr

    # The context manager closes the connection even when login or send fails.
    with smtplib.SMTP('smtp.gmail.com', 587) as smtp:
        smtp.starttls()
        smtp.login(from_addr, password)
        smtp.sendmail(from_addr, to_addr, msg.as_string())
    
# Send the notification e-mail. This performs a real SMTP login and send
# every time the cell is executed.
send_email()