<a href="https://colab.research.google.com/github/LanYuCL/RLstartup/blob/master/DQN_GPU.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab only: mount the user's Google Drive under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [2]:
pip install pygame



In [3]:
# Install the xvfb and python-opengl system packages; all output is discarded.
!apt-get install -y xvfb python-opengl > /dev/null 2>&1

In [4]:
# Install gym and pyvirtualdisplay; all output is discarded.
!pip install gym pyvirtualdisplay > /dev/null 2>&1

In [5]:
import pygame
import torch
import torch.nn as nn
import pickle
import random
import tqdm
import sys
import math
import numpy as np
from abc import ABC
from PIL import Image
from collections import deque, OrderedDict
from pygame.locals import *
from pyvirtualdisplay import Display
# Start an invisible 800x1200 virtual display before any pygame window is opened.
thedisplay = Display(visible=0, size=(800, 1200))
thedisplay.start()

pygame 2.0.0 (SDL 2.0.12, python 3.6.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


<pyvirtualdisplay.display.Display at 0x7f9db9cbbcf8>

In [6]:
# ------------------------------------------------------------------------------------------------------------------------------------------------------
# Project: Reinforcement Learning on Snake
# Author: Cai Ruikai
# Date: 2020.10.10
# ------------------------------------------------------------------------------------------------------------------------------------------------------

# Settings used by Snake_Env whenever it is built without an explicit config.
default_config = dict(
    SCREEN_WIDTH=200,  # window width in pixels
    SCREEN_HEIGHT=200,  # window height in pixels
    BLOCK_SIZE=20,  # side length of one grid cell in pixels

    CROSS_BOUNDARY=True,  # whether the snake may pass through an edge and wrap around

    BACKGROUND_COLOR=(60, 60, 60),  # background colour (RGB)
    SNAKE_COLOR=(0, 160, 100),  # snake colour (RGB)
    FOOD_COLOR=(240, 240, 240),  # food colour (RGB)
    OBSTACLE_COLOR=(255, 0, 0),  # obstacle colour (RGB)

    OBSTACLE=False,  # whether obstacles are placed on the board
    OBSTACLE_NUM=2,  # number of obstacles
    OBSTACLE_FRESH=False,  # whether obstacle positions are re-rolled during play
    OBSTACLE_FRESH_RATE=30,  # re-roll interval, in steps
)


class Snake_Env:
    def __init__(self, config=None):
        # use default config
        if not config:
            config = default_config

        # game environment setting
        self._SCREEN_WIDTH = config['SCREEN_WIDTH']
        self._SCREEN_HEIGHT = config['SCREEN_HEIGHT']
        self._BLOCK_SIZE = config['BLOCK_SIZE']
        self._X_AREA = (0, self._SCREEN_WIDTH // self._BLOCK_SIZE - 1)
        self._Y_AREA = (0, self._SCREEN_HEIGHT // self._BLOCK_SIZE - 1)

        self._CROSS_BOUNDARY = config['CROSS_BOUNDARY']

        self._OBSTACLE_NUM = config['OBSTACLE_NUM']
        self._OBSTACLE_FRESH = config['OBSTACLE_FRESH']
        self._OBSTACLE_FRESH_RATE = config['OBSTACLE_FRESH_RATE']

        self._BACKGROUND_COLOR = config['BACKGROUND_COLOR']
        self._SNAKE_COLOR = config['SNAKE_COLOR']
        self._OBSTACLE_COLOR = config['OBSTACLE_COLOR']
        self._FOOD_COLOR = config['FOOD_COLOR']
        self._OBSTACLE=config['OBSTACLE']

        # game status init
        self._step = 0
        self._score = 0
        self._step_reward = 0
        self._game_over = 0

        # pygame init
        pygame.init()
        self.screen = pygame.display.set_mode((self._SCREEN_WIDTH, self._SCREEN_HEIGHT))
        pygame.display.set_caption('Snake     Score:0')

        # snake,fruits,obstacles init
        self._snake, self._death_pos, self._foods, self._obstacles = deque(), (-1, 0), list(), list()
        self._init_spfo()
        self.render()
        pygame.display.update()

        # obs_dim, act_dim
        self._obs_dim = (self._SCREEN_WIDTH, self._SCREEN_HEIGHT, 3)
        self._act_dim = 4

    def _init_spfo(self):
        # snake
        for i in range(3):
            self._snake.append((3 - i, 0))
        # pos
        self._death_pos = (-1, 0)
        # foods
        self._foods.append(self._generate_xy())
        # obstacles
        if self._OBSTACLE:
            for i in range(self._OBSTACLE_NUM):
                self._obstacles.append(self._generate_xy())

    def _generate_xy(self):
        x = random.randint(self._X_AREA[0], self._X_AREA[1])
        y = random.randint(self._Y_AREA[0], self._Y_AREA[1])
        while (x, y) in set(list(self._snake) + self._obstacles + self._foods):
            x = random.randint(self._X_AREA[0], self._X_AREA[1])
            y = random.randint(self._Y_AREA[0], self._Y_AREA[1])
        return x, y

    def _refresh_food(self, eaten_food):
        if eaten_food == 0:
            self._foods[0] = self._generate_xy()
        # if eaten_food == 1:
        #     self._foods.pop()
        # if self._step % self._X_AREA[1]*2 == 0:
        #     self._foods = self._foods[:1]
        #     self._foods.append(self._generate_xy())
        # if self._step % self._X_AREA[1] > self._X_AREA[1] *0.9:
        #     self._foods = self._foods[:1]

    def _refresh_obstacle(self):
        if not self._OBSTACLE: return
        if self._step % self._OBSTACLE_FRESH_RATE == 0:
            self._obstacles = list()
            for i in range(self._OBSTACLE_NUM):
                self._obstacles.append(self._generate_xy())

    def _print_over(self):
        """Blit a red 'GAME OVER' banner centred on the screen surface."""
        message = 'GAME OVER'
        banner_font = pygame.font.Font(None, 50)
        text_width, text_height = banner_font.size(message)
        banner = banner_font.render(message, True, (255, 0, 0))
        left = self._SCREEN_WIDTH // 2 - text_width // 2
        top = self._SCREEN_HEIGHT // 2 - text_height // 2
        self.screen.blit(banner, (left, top))

    def _move_snake(self, action):
        env_action = ((-1, 0), (1, 0), (0, -1), (0, 1))
        reward, eaten_food = -0.15, -1
        # judge whether action is legal
        contrary = {(0, 1): (0, -1), (0, -1): (0, 1), (1, 0): (-1, 0), (-1, 0): (1, 0)}
        action = env_action[action]
        if action == self._death_pos:
            self._game_over = 1
            # print('\nDeath: action is legal')
        else:
            next_head = (self._snake[0][0] + action[0], self._snake[0][1] + action[1])
            # obstacles
            if next_head in self._obstacles:
                self._game_over = 1
                # print('\nDeath: obstacles')
            # boundary
            if self._CROSS_BOUNDARY:
                if next_head[0] < 0:
                    next_head = (self._X_AREA[1], next_head[1])
                elif next_head[0] > self._X_AREA[1]:
                    next_head = (0, next_head[1])
                elif next_head[1] < 0:
                    next_head = (next_head[0], self._Y_AREA[1])
                elif next_head[1] > self._Y_AREA[1]:
                    next_head = (next_head[0], 0)
            else:
                if next_head[0] < 0 or next_head[0] > self._X_AREA[1] or next_head[1] < 0 or next_head[1] > \
                        self._Y_AREA[
                            1]:
                    self._game_over = 1
                    # print('\nDeath: boundary')
            # body:
            if next_head in self._snake:
                self._game_over = 1
                # print('\nDeath: body')

            # fruit
            if next_head in self._foods:
                eaten_food = self._foods.index(next_head)
                # reward += 10 if eaten_food == 0 else 50
                reward += 50

            else:
                self._snake.pop()

            self._snake.appendleft(next_head)
        if not self._game_over:
            self._death_pos = contrary[action]
            dis = (math.sqrt(pow((self._foods[0][0] - next_head[0]), 2) + pow((self._foods[0][1] - next_head[1]), 2)))
            dis_reward = (1 / max(1.0, dis)) * 1
            reward += dis_reward
        else:
            reward = -10
        self._score += eaten_food+1 if eaten_food<1 else 5
        return reward, eaten_food

    def render(self):
        """Redraw the whole board and push it to the display.

        Draw order: background, grid lines, food, obstacles, snake, then the
        'GAME OVER' banner if the game has ended. The window caption is set to
        the current score.
        """
        surface = self.screen
        cell = self._BLOCK_SIZE
        radius = cell // 2

        def centre(pos):
            # pixel centre of grid cell ``pos``
            return pos[0] * cell + radius, pos[1] * cell + radius

        def square(pos):
            # ((left, top), (width, height)) rectangle covering grid cell ``pos``
            return (pos[0] * cell, pos[1] * cell), (cell, cell)

        # background
        surface.fill(self._BACKGROUND_COLOR)
        # vertical grid lines
        for px in range(cell, self._SCREEN_WIDTH, cell):
            pygame.draw.line(surface, (0, 0, 0), (px, 0), (px, self._SCREEN_HEIGHT), 1)
        # horizontal grid lines
        for py in range(cell, self._SCREEN_HEIGHT, cell):
            pygame.draw.line(surface, (0, 0, 0), (0, py), (self._SCREEN_WIDTH, py), 1)
        # food: the first item uses the configured colour, any further item is gold
        for slot, food in enumerate(self._foods):
            food_color = (255, 215, 0) if slot else self._FOOD_COLOR
            pygame.draw.circle(surface, food_color, centre(food), radius, 0)
        # obstacles
        for obstacle in self._obstacles:
            pygame.draw.rect(surface, self._OBSTACLE_COLOR, square(obstacle), 0)
        # snake: round head followed by square body segments
        segments = iter(self._snake)
        head = next(segments, None)
        if head is not None:
            pygame.draw.circle(surface, self._SNAKE_COLOR, centre(head), radius, 0)
        for segment in segments:
            pygame.draw.rect(surface, self._SNAKE_COLOR, square(segment), 0)

        if self._game_over:
            self._print_over()

        pygame.display.set_caption('Score:{:.3f}'.format(self._score))
        pygame.display.update()

    def init(self):
        self.__init__()
        obs = pygame.surfarray.array3d(pygame.display.get_surface()).transpose((1, 0, 2))
        return obs

    def obs_dim(self):
        return self._obs_dim

    def act_dim(self):
        return self._act_dim

    def reset(self):
        return self.init()

    def frame_step(self, action):
        self._step += 1

        self._step_reward, eaten_food = self._move_snake(action)
        self._refresh_food(eaten_food)
        if self._OBSTACLE_FRESH:
            self._refresh_obstacle()

        self.render()
        obs = pygame.surfarray.array3d(pygame.display.get_surface()).transpose((1, 0, 2))

        return obs, self._step_reward, self._score, self._game_over, self.get_game_info()

    @staticmethod
    def get_human_action():
        """Wait until an arrow key is released and return its action index.

        Returns:
            0 for left, 1 for right, 2 for up, 3 for down.

        A window-close event ends the process through ``sys.exit()``. All
        other events are ignored and polling continues.
        """
        arrow_to_action = {
            pygame.K_LEFT: 0,
            pygame.K_RIGHT: 1,
            pygame.K_UP: 2,
            pygame.K_DOWN: 3,
        }
        while True:
            for event in pygame.event.get():
                if event.type == QUIT:
                    sys.exit()
                if event.type == pygame.KEYUP and event.key in arrow_to_action:
                    return arrow_to_action[event.key]

    def get_game_info(self):
        info = {'step': self._step,
                'snake': list(self._snake),
                'obstacles': self._obstacles,
                'foods': self._foods}
        return info

    def print_game_info(self):
        info = self.get_game_info()
        print('\ncurrent step:{} reward :{} game score : {}'.format(info['step'], self._step_reward, self._score))
        for k, v in info.items():
            if k == 'count':
                continue
            print(k, v)


def demo():
    """Play Snake with the arrow keys.

    After every death the final score is printed and a new game starts, so
    the loop only ends when the process exits.
    """
    env = Snake_Env()
    env.init()
    env.render()
    while True:
        step_result = env.frame_step(env.get_human_action())
        env.print_game_info()
        game_score, game_over = step_result[2], step_result[3]
        if not game_over:
            continue
        print(game_score)
        env.reset()
        env.render()



In [7]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def to_tensor(x):
    """Return ``x`` as a batched PyTorch tensor.

    NumPy arrays are converted to float32 tensors on ``device``; tensors are
    used as they are. A 1-D or 3-D input gets a leading batch axis, so the
    result is always 2-D or 4-D.

    Raises:
        AssertionError: if ``x`` is neither an array nor a tensor, or if the
            final rank is not 2 or 4.
    """
    tensor = torch.from_numpy(x).type(torch.float32).to(device) if isinstance(x, np.ndarray) else x
    assert isinstance(tensor, torch.Tensor), type(tensor)
    if tensor.dim() in (1, 3):
        tensor = tensor.unsqueeze(0)
    assert tensor.dim() in (2, 4), tensor.shape
    return tensor


def process_state(state):
    """Convert environment frames to the channel-first layout fed to the conv net.

    Args:
        state: NumPy array holding one frame ``[H, W, C]`` or a batch of
            frames ``[N, H, W, C]``.

    Returns:
        A float32 tensor on ``device`` laid out as ``[N, C, H, W]``.

    Raises:
        AssertionError: if ``state`` is neither 3-D nor 4-D.
    """
    frames = torch.from_numpy(state).type(torch.float32).to(device)
    if frames.dim() == 3:
        # a single frame becomes a batch of one
        frames = frames.unsqueeze(0)
    if frames.dim() == 4:
        # channel-last -> channel-first
        frames = frames.permute(0, 3, 1, 2)
    assert frames.dim() == 4, frames.shape

    return frames


class ReplayMemory:
    """Fixed-capacity FIFO buffer of transitions with uniform random sampling.

    A transition is a 5-tuple ``(state, action, reward, next_state, done)``.
    """

    def __init__(self, capacity=50000, learn_start=500):
        """
        Args:
            capacity: maximum number of transitions kept; older ones are
                dropped first.
            learn_start: buffer size at which a one-off "start learning"
                message is printed.
        """
        self.capacity = capacity
        self.learn_start = learn_start

        self.memory = deque(maxlen=capacity)

    def push(self, transition):
        """Append one transition, announcing when ``learn_start`` is reached."""
        self.memory.append(transition)
        if len(self.memory) == self.learn_start:
            print("Current memory contains {} transitions,start learning!".format(self.learn_start))

    def get_batch(self, batch_size):
        """Sample ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, not_done_mask)``.
            ``states`` and ``next_states`` are stacked NumPy arrays;
            ``actions`` is a long tensor of shape ``(batch, 1)``; ``rewards``
            and ``not_done_mask`` (1 - done) are squeezed float tensors.

        Raises:
            ValueError: if fewer than ``batch_size`` transitions are stored
                (raised by ``random.sample``).
        """
        batch = random.sample(self.memory, batch_size)

        states = np.stack([transition[0] for transition in batch])
        next_states = np.stack([transition[3] for transition in batch])

        actions = to_tensor(np.stack([transition[1] for transition in batch])).long().view(-1, 1)

        rewards = to_tensor(np.stack([transition[2] for transition in batch])).squeeze()
        not_done_mask = to_tensor(np.stack([1 - transition[4] for transition in batch])).squeeze()

        return states, actions, rewards, next_states, not_done_mask

    def load(self, path):
        """Replace the buffer contents with the transitions pickled at ``path``.

        NOTE: unpickling can execute arbitrary code, so only load files that
        were written by ``save`` from a trusted run.
        """
        with open(path, 'rb') as handle:
            stored = pickle.load(handle)
        # Re-wrap so this buffer's capacity still applies, whatever maxlen the
        # saved deque carried; when too many are stored the newest are kept.
        self.memory = deque(stored, maxlen=self.capacity)

    def save(self, path):
        """Pickle the buffer contents to ``path``."""
        with open(path, 'wb') as handle:
            pickle.dump(self.memory, handle)

    def __len__(self):
        return len(self.memory)



In [8]:


# Double DQN method
DQN_Default_Conf = dict(
    memory_size=100000,  # replay memory capacity, in transitions
    learn_start=500,  # transitions collected before training begins
    batch_size=64,  # transitions sampled per gradient step
    learn_freq=2,  # a gradient step is taken when the in-episode step index is a multiple of this
    target_update_freq=30,  # environment steps that must pass before the target network is synced
    clip_norm=5,  # gradient-norm clipping threshold
    learning_rate=0.001,  # Adam learning rate
    eps=0.3,  # probability of taking a random action
    max_train_iteration=5000000,  # number of training episodes
    reward_threshold=1000,
    max_episode_length=500,  # step cap per training episode
    gamma=0.9,  # discount factor
    evaluate_int=1,  # target-network syncs between evaluations
)


class DQN_Network(nn.Module, ABC):
    """Convolutional Q-network: one frame in, one Q-value per action out.

    NOTE(review): ``linear1`` is hard-wired to 484 input features, which is
    what the five convolutions produce for a 200x200 frame (22 x 22 x 1).
    ``obs_dim`` is only stored on the instance and does not size any layer,
    so frames of any other resolution will fail in ``linear1``.
    """

    def __init__(self, obs_dim=(520, 520, 3), act_dim=4):
        """
        Args:
            obs_dim: observation shape; stored as ``input_shape`` only.
            act_dim: number of actions, i.e. width of the output layer.
        """
        super().__init__()

        self.input_shape = obs_dim
        self.num_actions = act_dim

        # Layers are created in forward order so parameter initialisation
        # consumes random numbers in the same sequence as the layer list.
        layers = OrderedDict()
        layers['conv1'] = nn.Conv2d(3, 16, 3, 2)
        layers['relu1'] = nn.ReLU()
        layers['conv2'] = nn.Conv2d(16, 32, 3)
        layers['relu2'] = nn.ReLU()
        layers['conv3'] = nn.Conv2d(32, 32, 3, 2)
        layers['relu3'] = nn.ReLU()
        layers['conv4'] = nn.Conv2d(32, 16, 3)
        layers['relu4'] = nn.ReLU()
        layers['conv5'] = nn.Conv2d(16, 1, 3, 2)
        layers['relu5'] = nn.ReLU()
        layers['flatten'] = nn.Flatten()
        layers['linear1'] = nn.Linear(484, 128)
        layers['relu7'] = nn.ReLU()
        layers['linear2'] = nn.Linear(128, act_dim)
        self.model = nn.Sequential(layers)

    def forward(self, observation):
        """Return Q-values for raw environment frames (converted by ``process_state``)."""
        batch = process_state(observation)
        return self.model(batch)


class DQN_Agent(DQN_Network, ABC):
    """A ``DQN_Network`` built with default arguments plus a weight-loading helper."""

    def __init__(self):
        super(DQN_Agent, self).__init__()

    def load_weights(self, weights=None):
        """Load a state dict into ``self.model`` and return the model.

        Args:
            weights: state dict for the inner ``nn.Sequential`` model. When
                it is ``None`` or empty nothing is loaded.

        Returns:
            ``self.model``.
        """
        if weights:
            # The original referenced the non-existent attribute ``self.mode``,
            # which raised AttributeError whenever weights were supplied.
            self.model.load_state_dict(weights)
        return self.model


class DQN_Method:
    def __init__(self, config=None):
        if not config:
            config = DQN_Default_Conf
        # parameters
        self.best_scores = 0
        self.learn_freq = config["learn_freq"]
        self.learn_start = config["learn_start"]
        self.learning_rate = config["learning_rate"]
        self.target_update_freq = config["target_update_freq"]
        self.memory = ReplayMemory(capacity=config["memory_size"], learn_start=config["learn_start"])

        self.batch_size = config["batch_size"]
        self.max_train_iteration = config["max_train_iteration"]
        self.max_episode_length = config["max_episode_length"]

        self.eps = config["eps"]
        self.gamma = config["gamma"]
        self.clip_norm = config["clip_norm"]
        self.evaluate_int = config["evaluate_int"]
        self.reward_threshold = config["reward_threshold"]

        self.total_step = 0
        self.step_since_update = 0
        self.step_since_evaluate = 0

        # create environment
        self.env = Snake_Env()
        self.obs_dim = self.env.obs_dim()
        self.act_dim = self.env.act_dim()

        # create double DQN network
        self.network = DQN_Network(self.obs_dim, self.act_dim).to(device)
        self.network.eval()

        self.target_network = DQN_Network(self.obs_dim, self.act_dim).to(device)
        self.target_network.load_state_dict(self.network.state_dict())
        self.network.eval()

        # set optimizer and loss
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=self.learning_rate)
        self.loss = nn.MSELoss()

    def compute_action(self, observation, eps=None):
        values = self.target_network(observation)
        values = values.cpu().detach().numpy()
        if not eps:
            eps = self.eps
        action = np.argmax(values) if np.random.random() > eps else np.random.choice(self.act_dim)

        return action

    def prepare_memory(self):
        #pbar = tqdm.tqdm(total=self.learn_start, desc="preparing replay memory")
        while len(self.memory) < self.learn_start:
            env = Snake_Env()
            current_state = env.reset()
            act = self.compute_action(current_state)
            for t in range(self.max_episode_length):

                next_state, reward, score, game_over, other_info = env.frame_step(act)
                transition = (current_state, act, reward, next_state, game_over)
                # self.test(transition)

                self.memory.push(transition)
                #pbar.update()
                current_state = next_state
                act = self.compute_action(current_state)
                if game_over:
                    break
        #pbar.close()

    def test(self, transition):
        """Debug helper: display both frames of a transition and wait for Enter.

        Args:
            transition: ``(current_state, act, reward, next_state, game_over)``.
        """
        current_state, act, reward, next_state, game_over = transition
        for frame, caption in ((current_state, "current_state"), (next_state, "next_state")):
            Image.fromarray(np.uint8(frame)).show(title=caption)
        print(act, reward, game_over)
        input()

    def train(self,  pre_trained=None):
        if pre_trained:
            self.network=torch.load(pre_trained)
            self.target_network=torch.load(pre_trained)
            self.best_scores=float(pre_trained.split('_')[1])
            self.total_step=int(pre_trained.split('_')[2][:-3])

        self.prepare_memory()
        print('Start Training')
        # train network in max_train_iteration
        for train_iteration in range(self.max_train_iteration):
            current_state = self.env.reset()
            act = self.compute_action(current_state)
            stat = {"loss": []}

            # each train iteration has max episode length
            for t in range(self.max_episode_length):

                next_state, reward, score, game_over, other_info = self.env.frame_step(act)
                transition = (current_state, act, reward, next_state, game_over)
                # self.test(transition)

                self.memory.push(transition)

                self.total_step += 1
                self.step_since_update += 1

                if game_over:
                    break

                current_state = next_state
                act = self.compute_action(current_state)

                if t % self.learn_freq != 0: continue

                states, actions, rewards, next_states, not_done_mask = self.memory.get_batch(self.batch_size)
                # image1=Image.fromarray(states[0])
                # image1.show('1')
                # image2 = Image.fromarray(next_states[0])
                # image2.show('2')
                # print(actions[0],rewards[0],not_done_mask[0])
                # input()

                with torch.no_grad():
                    Q_t_plus_one_max = self.target_network(next_states).max(1)[0]
                    Q_t_plus_one = Q_t_plus_one_max * not_done_mask
                    Q_target = rewards + self.gamma * Q_t_plus_one

                self.network.train()
                Q_t = self.network(states)
                Q_t = Q_t.gather(1, actions).squeeze()

                assert Q_t.shape == Q_target.shape, print(Q_t.shape, Q_target.shape)

                # Update the network
                self.optimizer.zero_grad()
                loss = self.loss(input=Q_t, target=Q_target)
                loss_value = loss.item()
                stat['loss'].append(loss_value)
                loss.backward()
                nn.utils.clip_grad_norm_(self.network.parameters(), self.clip_norm)
                self.optimizer.step()
                self.network.eval()

            # update target network
            if self.step_since_update > self.target_update_freq:
                if train_iteration % 10000 == 0:
                  print('\nCurrent train iteration:{} Current memory:{}'.format(train_iteration,len(self.memory)))
                  print('Current step:{},{} steps has passed since last update,Now update behavior policy'.format(self.total_step, self.step_since_update))

                self.step_since_update = 0
                self.target_network.load_state_dict(self.network.state_dict())
                self.target_network.eval()

                # evaluate and save network
                self.step_since_evaluate += 1
                if self.step_since_evaluate >= self.evaluate_int:
                    self.step_since_evaluate = 0
                    eva_score, eva_length, eva_reward = self.evaluate()
                    if train_iteration % 10000 == 0:
                      print("best score:{},loss:{:.2f},episode length:{},evaluate score:{},evaluate reward:{:.2f}"
                            .format(self.best_scores,np.mean(stat["loss"]),eva_length ,eva_score,eva_reward))

                    # save best network
                    if eva_score > self.best_scores and eva_score > 1:
                        print('save model of performance:', eva_score)
                        self.best_scores = eva_score
                        torch.save(self.target_network, 'best_{}_{}.pt'.format(eva_score, self.total_step))

    def evaluate(self, weights=None, num_episodes=30, episodes_len=100):
        #pbar = tqdm.tqdm(total=num_episodes, desc="evaluating")
        env = Snake_Env()
        policy = self.target_network
        if weights:
            policy.load_state_dict(weights)
        rewards = []
        epo_len = []
        scores = []
        for i in range(num_episodes):
            obs = env.reset()
            with torch.no_grad():
                act = np.argmax(policy(obs).cpu().detach().numpy())
            epo = 0
            score = 0
            ep_reward=0
            for t in range(episodes_len):
                next_state, reward, score, game_over, other_info = env.frame_step(act)
                act = np.argmax(policy(next_state).cpu().detach().numpy())
                if game_over:
                    break
                epo += 1
                ep_reward+=reward
            epo_len.append(epo)
            rewards.append(ep_reward)
            scores.append(score)
            #pbar.update()
        #pbar.close()
        return np.mean(scores), np.mean(epo_len),np.mean(rewards)


def demo():
    """Placeholder that does nothing.

    NOTE(review): this rebinds the name ``demo``, which the environment cell
    earlier in the notebook uses for the interactive Snake demo.
    """
    return None

# Build the trainer with the default configuration and start training immediately.
DQN = DQN_Method()
# To resume from a saved checkpoint instead, pass its path:
#DQN.train(pre_trained='best_3.1_80096.pt')
DQN.train()


Current memory contains 500 transitions,start learning!
Start Training

Current train iteration:0 Current memory:539
Current step:34,34 steps has passed since last update,Now update behavior policy
best score:0,loss:36.28,episode length:100.0,evaluate score:0.06666666666666667,evaluate reward:9.05


KeyboardInterrupt: ignored