In [1]:
from snake_gym import SnakeGym

In [2]:
import sys
import pylab
import random
import numpy as np
from collections import deque

import torch
from torch import nn, optim
import torch.nn.functional as F

In [3]:
class DQNAgent:
    """Deep Q-Network agent with a replay memory and a separate target network.

    The online network (``self.model``) is trained on mini-batches sampled from
    ``self.memory``; ``self.target_model`` supplies the bootstrap values and is
    only changed by ``update_target_model``.

    Args:
        state_size: Number of input features of the Q-network.
        action_size: Number of discrete actions (outputs of the Q-network).
        load_model: When True, initialise the online network from ``model_path``.
        model_path: Path of the ``state_dict`` checkpoint read when
            ``load_model`` is True.
    """

    def __init__(self, state_size, action_size, load_model=False,
                 model_path='./snake_dqn_trained.bin'):
        self.render = False
        self.load_model = load_model
        self.model_path = model_path

        # Sizes of the state vector and of the action set.
        self.state_size = state_size
        self.action_size = action_size

        # DQN hyperparameters.
        self.discount_factor = 0.99
        self.learning_rate = 0.001
        self.epsilon = 1.0
        self.epsilon_decay = 0.9999
        self.epsilon_min = 0.05
        self.batch_size = 64
        self.train_start = 2000

        # Replay memory; the oldest transitions are dropped once it is full.
        self.memory = deque(maxlen=50000)

        # Online network, target network and the optimizer of the online one.
        self.model = self.build_model()
        self.target_model = self.build_model()
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.learning_rate)

        if self.load_model:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            self.model.load_state_dict(torch.load(self.model_path))

        # Sync the target network *after* the optional load so that it starts
        # from the same weights as the online network.
        self.update_target_model()

    def build_model(self):
        """Return a fresh MLP mapping a state batch to one Q-value per action."""
        model = nn.Sequential(
            nn.Linear(self.state_size, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, self.action_size),
        )
        return model

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

    def get_action(self, state):
        """Pick an action with the epsilon-greedy policy.

        Args:
            state: Float tensor of shape ``(1, state_size)``.

        Returns:
            Long tensor of shape ``(1, 1)`` holding the chosen action index.
        """
        if np.random.rand() <= self.epsilon:
            # Explore: uniform over the configured action set (not a fixed 4).
            return torch.tensor(
                [[random.randrange(self.action_size)]], dtype=torch.long)

        # Exploit: greedy action of the online network; no graph is needed.
        with torch.no_grad():
            return self.model(state).max(1)[1].view(1, 1)

    def append_sample(self, state, action, reward, next_state, done):
        """Store the transition <s, a, r, s', done> in the replay memory.

        ``state`` and ``action`` are stored as given (they are expected to be
        tensors already, as produced by the caller and ``get_action``);
        ``reward``, ``next_state`` and ``done`` are converted to float tensors
        with a leading batch dimension of 1.
        """
        reward = torch.tensor([float(reward)], dtype=torch.float32)
        next_state = torch.as_tensor(
            np.array([next_state]), dtype=torch.float32)
        done = torch.tensor([float(done)], dtype=torch.float32)

        self.memory.append((state, action, reward, next_state, done))

    def train_model(self):
        """Run one gradient step on a random mini-batch from the replay memory.

        Does nothing while the memory holds fewer than ``batch_size``
        transitions. Epsilon is decayed once per call after the memory has
        grown past ``train_start`` and until it reaches ``epsilon_min``.
        """
        if len(self.memory) < self.batch_size:
            return

        # Use this agent's own memory/threshold, not a module-level `agent`.
        if self.epsilon > self.epsilon_min and len(self.memory) > self.train_start:
            self.epsilon *= self.epsilon_decay

        # Draw a random mini-batch and stack each field along the batch axis.
        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.cat(states)
        actions = torch.cat(actions)
        rewards = torch.cat(rewards)
        next_states = torch.cat(next_states)
        dones = torch.cat(dones)

        # Q(s, a) of the online network for the actions that were taken.
        # squeeze(1) keeps the batch axis even for a batch of one.
        current_q = self.model(states).gather(1, actions).squeeze(1)

        # Bellman optimality target from the target network. Terminal
        # transitions (done == 1) must not bootstrap from the next state.
        with torch.no_grad():
            max_next_q = self.target_model(next_states).max(1)[0]
        expected_q = rewards + self.discount_factor * max_next_q * (1.0 - dones)

        self.optimizer.zero_grad()
        loss = F.mse_loss(current_q, expected_q)
        loss.backward()
        self.optimizer.step()


In [4]:
# batch = random.sample(agent.memory, agent.batch_size) # 64
# states, actions, rewards, next_states, dones = zip(*batch)

# # print(states)
# print(len(states), states[0].shape)
# states = torch.cat(states)
# actions = torch.cat(actions)
# rewards = torch.cat(rewards)
# next_states = torch.cat(next_states)
# dones = torch.cat(dones)

# print(states)
# print(states.shape) # 57600 == 900 * 64
# print(agent.model)

# agent.model(states) #ERROR!

# # current_q = agent.model(states).gather(1, actions)
# # max_next_q = agent.target_model(next_states).detach().max(1)[0]
# # expected_q = rewards + (agent.discount_factor * max_next_q)

In [5]:
# Create the Snake environment and read the two sizes the agent's network needs.
env = SnakeGym()
state_size, action_size = env.state_size, env.action_size

print(state_size, action_size)

400 4


In [6]:
# Number of episodes the training loop iterates over.
EPISODES = 500000

agent = DQNAgent(state_size, action_size)

# Per-episode statistics collected during training and used for the plots.
scores = []
episodes = []
episode_lengths = []
epsilons = []
eat_cnts = []

In [7]:
# Print the current wall-clock time so the start of training can be checked later.
import datetime
print(datetime.datetime.now())

2022-06-30 16:41:10.531423


In [8]:
# Main training loop: play EPISODES episodes, store every transition in the
# replay memory and train the agent once the memory is large enough.
for e in range(EPISODES):
    done = False
    score = 0
    state = env.reset()

    while not done:
        # Uncomment to watch every 100th episode while training:
        # if e % 100 == 0:
        #     env.render()

        # Add a leading batch dimension so the state can be fed to the network.
        state = torch.FloatTensor(np.array([state]))
        action = agent.get_action(state)

        # NOTE(review): `action` is a (1, 1) tensor; this assumes env.step
        # accepts it as-is — confirm against SnakeGym.
        next_state, reward, done, info = env.step(action)

        agent.append_sample(state, action, reward, next_state, done)
        if len(agent.memory) > agent.train_start:
            agent.train_model()

        score += reward
        state = next_state

        if done:
            # Refresh the target network at the end of every episode.
            agent.update_target_model()

            scores.append(score)
            episodes.append(e)
            episode_lengths.append(info['episode_length'])
            epsilons.append(agent.epsilon)
            eat_cnts.append(info['eat_cnt'])
            print(f"episode: {e:3}, score: {score:4}, memory length: {len(agent.memory):5}, epsilon: {agent.epsilon:.4f}, episode length: {info['episode_length']:3}, eat: {info['eat_cnt']}")

            # Every 200 episodes: checkpoint the weights and redraw the curves.
            if e % 200 == 0:
                torch.save(agent.model.state_dict(), "./snake_dqn_trained.bin")

                # Clear the current figure first; otherwise every checkpoint
                # stacks another full set of lines on the same axes and the
                # figure keeps growing for the whole run.
                pylab.clf()

                panels = (
                    ('scores', scores, 'b'),
                    ('episode length', episode_lengths, 'g'),
                    ('epsilon', epsilons, 'r'),
                    ('eat cnt', eat_cnts, 'y'),
                )
                for row, (label, values, color) in enumerate(panels, start=1):
                    pylab.subplot(len(panels), 1, row)
                    pylab.xlabel('episodes')
                    pylab.ylabel(label)
                    pylab.plot(episodes, values, color)

                pylab.savefig('./snake_dqn.png')

episode:   0, score:  -91, memory length:     8, epsilon: 1.0000, episode length:   8, eat: 1
episode:   1, score:  -76, memory length:    32, epsilon: 1.0000, episode length:  24, eat: 0
episode:   2, score:  -96, memory length:    36, epsilon: 1.0000, episode length:   4, eat: 0
episode:   3, score:  -86, memory length:    49, epsilon: 1.0000, episode length:  13, eat: 1
episode:   4, score:  -96, memory length:    53, epsilon: 1.0000, episode length:   4, eat: 0
episode:   5, score:  -84, memory length:    69, epsilon: 1.0000, episode length:  16, eat: 0
episode:   6, score:  -99, memory length:    70, epsilon: 1.0000, episode length:   1, eat: 0
episode:   7, score:  -99, memory length:    71, epsilon: 1.0000, episode length:   1, eat: 0
episode:   8, score:  -99, memory length:    72, epsilon: 1.0000, episode length:   1, eat: 0
episode:   9, score:  -89, memory length:    83, epsilon: 1.0000, episode length:  11, eat: 0
episode:  10, score:  -79, memory length:   104, epsilon: 1.

In [None]:
# Visual test: play one rendered episode with the saved weights.

# Setting the flag on an existing agent does not load anything by itself (it is
# only consulted inside DQNAgent.__init__), so restore the checkpoint written
# by the training loop explicitly and keep the target network in sync.
# NOTE: torch.load unpickles the file; only load trusted checkpoints.
agent.load_model = True
agent.model.load_state_dict(torch.load('./snake_dqn_trained.bin'))
agent.update_target_model()

# NOTE(review): actions are still epsilon-greedy with the agent's current
# epsilon; on a freshly constructed agent (epsilon == 1.0) they are all random.
for e in range(1):
    done = False
    score = 0
    state = env.reset()

    while not done:
        env.render()

        # Add a leading batch dimension so the state can be fed to the network.
        state = torch.FloatTensor(np.array([state]))
        action = agent.get_action(state)

        next_state, reward, done, info = env.step(action)

        score += reward
        state = next_state

        if done:
            print(f"episode: {e:3}, score: {score:4}, memory length: {len(agent.memory):5}, epsilon: {agent.epsilon:.4f}, episode length: {info['episode_length']:3}")