In [2]:
from tqdm import trange
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [3]:
#定义实体
class Cube:
    def __init__(self,SIZE):
        # 实体位置
        self.SIZE = SIZE
        self.x = np.random.randint(0,self.SIZE)
        self.y = np.random.randint(0,self.SIZE) 
        
    # 打印实体位置
    def __str__(self):
        return f'{self.x},{self.y}' 
    
    def __sub__(self,other):
        return (self.x-other.x,self.y-other.y)
    
    # 两个对象相等的情况
    def __eq__(self,other):
        return self.x == other.x and self.y == other.y
    
    def action(self,choise):
        if choise == 0:
            self.move(x=1, y=1)
        elif choise == 1:
            self.move(x=-1, y=1)
        elif choise == 2:
            self.move(x=1, y=-1)
        elif choise == 3:
            self.move(x=-1, y=-1)
        elif choise == 4:
            self.move(x=0, y=1)
        elif choise == 5:
            self.move(x=0, y=-1)
        elif choise == 6:
            self.move(x=1, y=0)
        elif choise == 7:
            self.move(x=-1, y=0)
        elif choise == 8:
            self.move(x=0, y=0)
    
    def move(self,x=False,y=False):
        if not x:
            self.x += np.random.randint(-1,2)
        else:
            self.x += x
        
        if not y:
            self.y += np.random.randint(-1,2)
        else:
            self.y += y
            
        if self.x < 0:
            self.x = 0
        elif self.x >= self.SIZE:
            self.x = self.SIZE - 1
        
        if self.y < 0:
            self.y = 0
        elif self.y >= self.SIZE:
            self.y = self.SIZE - 1

In [4]:
# 设定环境
class envCube:
    SIZE = 10
    #OBSERVATION_SPACE_VALUES = (SIZE,SIZE,3) # 整个环境的观测空间
    OBSERVATION_SPACE_VALUES = (4,)
    ACTION_SPACE_VALUES = 9 # 九个动作
    RETURN_IMAGE = False # 是否返回图像
    
    FOOD_REWARD = 1000 # 食物吃掉得25分
    ENEMY_PENALITY = -1000 # 被吃掉扣300分
    MOVE_PENALITY = -1 # 移动一次扣一分
    
    # 颜色赋值
    d = {1:(255,0,0), # blue
     2:(0,255,0), # green
     3:(0,0,255)} # red
    PLAYER_N = 1
    FOOD_N = 2
    ENEMY_N = 3
    
    def reset(self):
        # 实例化对象
        self.player = Cube(self.SIZE)
        self.food = Cube(self.SIZE)
        while self.food == self.player:
            self.food = Cube(self.SIZE)
        
        self.enemy = Cube(self.SIZE)
        while self.enemy == self.player or self.enemy == self.food:
            self.enemy = Cube(self.SIZE)
        
        if self.RETURN_IMAGE:
            observation = np.array(self.get_image())
        else:
            observation = (self.player - self.food) + (self.player - self.enemy)
        
        self.episode_step = 0
        
        return observation
    
    def step(self,action):
        
        self.episode_step += 1 # 当前步数加一
        self.player.action(action) # 玩家采取行动
        self.food.move() # 食物移动
        self.enemy.move() # 敌人移动
        
        new_observation = (self.player - self.food) + (self.player - self.enemy)
            
        if self.player == self.food:
            reward = self.FOOD_REWARD
        elif self.player == self.enemy:
            reward = self.ENEMY_PENALITY
        else:
            reward = self.MOVE_PENALITY
        
        done = False
        if self.player == self.food or self.player == self.enemy or self.episode_step >= 200:
            done = True
        
        return new_observation,reward,done
    
    # 刷新图形
    def render(self):
        img = self.get_image()    
        img = img.resize((800,800))
        cv2.imshow('Predator',np.array(img))
        cv2.waitKey(1)
        
    def get_image(self):
        env = np.zeros((self.SIZE,self.SIZE,3),dtype = np.uint8)
        env[self.food.x][self.food.y] = self.d[self.FOOD_N] # 把食物放进去（坐标颜色改变）
        env[self.player.x][self.player.y] = self.d[self.PLAYER_N] # 把食物放进去（坐标颜色改变）
        env[self.enemy.x][self.enemy.y] = self.d[self.ENEMY_N] # 把食物放进去（坐标颜色改变）
        img = Image.fromarray(env,'RGB')
        return img

In [7]:
# Hyper Parameters
BATCH_SIZE = 32
LR = 0.01                   # learning rate
GAMMA = 0.9                 # reward discount
TARGET_REPLACE_ITER = 100   # target update frequency
MEMORY_CAPACITY = 500
env = envCube()
N_ACTIONS = env.ACTION_SPACE_VALUES
N_STATES = env.OBSERVATION_SPACE_VALUES[0]
ENV_A_SHAPE = 0 #if isinstance(env.action_space.sample(), int) else env.action_space.sample().shape     # to confirm the shape

class Net(nn.Module):
    def __init__(self, ):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(N_STATES, 32)
        self.fc1.weight.data.normal_(0, 1)   # initialization
        self.fc2 = nn.Linear(32, 32)
        self.fc2.weight.data.normal_(0, 1)   # initialization
        self.out = nn.Linear(32, N_ACTIONS)
        self.out.weight.data.normal_(0, 1)   # initialization

    def forward(self, x):
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        actions_value = self.out(x)
        return actions_value

class DQN(object):
    def __init__(self):
        self.eval_net, self.target_net = Net(), Net()

        self.learn_step_counter = 0                                     # for target updating
        self.memory_counter = 0                                         # for storing memory
        self.memory = np.zeros((MEMORY_CAPACITY, N_STATES * 2 + 2))     # initialize memory
        self.optimizer = torch.optim.Adam(self.eval_net.parameters(), lr=LR)
        self.loss_func = nn.MSELoss()
    
    def save_model(self):
        torch.save(self.eval_net.state_dict(),f'model_{time.time()}')
    
    def choose_action(self, x, epsilon):
        x = torch.unsqueeze(torch.FloatTensor(x), 0)
        # input only one sample
        if np.random.uniform() > epsilon:   # greedy
            actions_value = self.eval_net.forward(x)
            action = torch.max(actions_value, 1)[1].data.numpy()   
            #print(action)
        else:   # random
            action = np.random.randint(0, N_ACTIONS)
            print(action)
        return action

    def store_transition(self, s, a, r, s_):
        transition = np.hstack((s, [a, r], s_))
        # replace the old memory with new memory
        index = self.memory_counter % MEMORY_CAPACITY
        self.memory[index, :] = transition
        self.memory_counter += 1

    def learn(self):
        # target parameter update
        if self.learn_step_counter % TARGET_REPLACE_ITER == 0:
            self.target_net.load_state_dict(self.eval_net.state_dict())
        self.learn_step_counter += 1

        # sample batch transitions
        sample_index = np.random.choice(MEMORY_CAPACITY, BATCH_SIZE)
        b_memory = self.memory[sample_index, :]
        b_s = torch.FloatTensor(b_memory[:, :N_STATES])
        b_a = torch.LongTensor(b_memory[:, N_STATES:N_STATES+1].astype(int))
        b_r = torch.FloatTensor(b_memory[:, N_STATES+1:N_STATES+2])
        b_s_ = torch.FloatTensor(b_memory[:, -N_STATES:])

        # q_eval w.r.t the action in experience
        q_eval = self.eval_net(b_s).gather(1, b_a)  # shape (batch, 1)
        q_next = self.target_net(b_s_).detach()     # detach from graph, don't backpropagate
        q_target = b_r + GAMMA * q_next.max(1)[0].view(BATCH_SIZE, 1)   # shape (batch, 1)
        loss = self.loss_func(q_eval, q_target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
    def fit(self,env,EPISODES=10000,SHOW_EVERY=1000,epsilon=0.9,EPS_DECAY=0.99998,
            DISCOUNT=0.95,LEARNING_RATE=0.01,img_show=False):
        
        episode_rewards = []
        loss = -1
        print(f'Training for {EPISODES} steps ...')
        
        for episode in range(0,EPISODES,SHOW_EVERY):
            
            for i in trange(SHOW_EVERY,position=0, leave=True,ascii=True):
                
                obs = env.reset()
                done = False
                episode_reward = 0
                
                while not done:

                    action = self.choose_action(obs,epsilon) # 根据观测值选择行为

                    new_obs,reward,done = env.step(action)

                    self.store_transition(obs, action, reward, new_obs)
                    
                    if self.memory_counter > MEMORY_CAPACITY and self.memory_counter%5 == 0:
                        self.learn()

                    episode_reward += reward
                    
                    obs = new_obs

                episode_rewards.append(episode_reward)
                epsilon *= EPS_DECAY
            print(f'mean reward: {np.mean(episode_rewards[-SHOW_EVERY:])}, episode #{episode}, epsilon:{epsilon}')
            
        if img_show:
            #画曲线
            moving_avg = np.convolve(episode_rewards,np.ones((SHOW_EVERY,))/SHOW_EVERY,mode='valid')
            plt.plot([i for i in range(len(moving_avg))], moving_avg)
            plt.xlabel('episode #')
            plt.ylabel(f'mean {SHOW_EVERY} reward')
            plt.show()
            
    def test(self, model, episodes = 100):
        
        self.eval_net.load_state_dict(torch.load('model_1643877864.1740255'))
        episode_rewards = []

        print(f'Testing for {episodes} steps ...')
        
        for episode in range(0,episodes):
                
            obs = env.reset()
            done = False
            episode_reward = 0

            while not done:

                action = self.choose_action(obs,0) # 根据观测值选择行为

                new_obs,reward,done = env.step(action)

                episode_reward += reward

                obs = new_obs

            episode_rewards.append(episode_reward)
            print(f'Episode {episode}:, reward {episode_reward}')
        

In [8]:
dqn = DQN()
#optimizer = optim.RMSprop(model.parameters())  # 设置优化器
#memory = ReplayMemory(10000)
dqn.fit(env,EPISODES=1000000,SHOW_EVERY=10000,img_show=True)

Training for 1000000 steps ...


  0%|                                                                                        | 0/10000 [00:00<?, ?it/s]

1
0
6
3
8
2
6
7
5
2
0
3
4
1
4
2
0
0
1
4
8
3
4
5
6
4
6
1
4
1
2
6
0
2
0
0
0
5
7
0
7
5
3
0
0
3
5
2
4
4
6
7
2
7
0
3
4
3
1
6
4
5
0
6
2
8
6
6
2
5
7
6
7
3
0
2
5
2
5
8
4
7
2
6
1
0
4
0
5
6
7
3
0
8
5
7
0
4
3
0
5
6
1
2
2
3
7
3
8
5
3
1
4
0
6
8
8
0
5
3
0
2


  0%|                                                                                | 1/10000 [00:00<35:03,  4.75it/s]

6
1
2
1
8
5
0
0
7
4
0
2
8
2
8
5
1
6
0
2
0
5
6
4
6
4
5
1
6
8
7
5
4
0
7
2
4
0
4
4
2
7
2
8
8
6
2
6
0
5
3
8
1
3
1
1
3
8
7
3
3
5
1
1
2
3
1
4
0
5
4
5
0
0
1
2
4
1
0
1
3
5
1
6


  0%|                                                                                | 3/10000 [00:00<27:44,  6.01it/s]

2
8
2
1
2
3
7
8
2
0
0
5
1
8
5
4
5
7
2
6
5
1
5
6
1
0
5
7
6
2
4
0
0
3
7
1
3
0
4
1
6
7
0
5
0
8
8
8
7
2
0
4
1
0
4
7
4
1
7
7
3
3
8
5
4
7
2
0
5
7
7
3
7
3
8
6
2
3
6
8
3
8
5
7
3
2
6
3
3
5
8
7
6
7
3
2
1
0
7
8
0
4
7
1
6
0
6


  0%|                                                                                | 4/10000 [00:00<27:52,  5.98it/s]

1
2
6
6
5
0
6
7
2
2
1
7
6
0
2
5
0
1
6
6
0
5
3
1
1
8
2
0
2
8
4
6
3
4
3
0
3
5
3
3
2
4
1
2
5
1
1
4
3
6
5
0
1
6
7
6
8
5
7
3
7
0
2
8
7
1
3
3
6
4
7
8
0
2
1
7
5
8
0
6
2
4
0
8
4
1
0
8
6
5
5
0
6
1
6
5
2
6
7
4
3
6


  0%|                                                                                | 5/10000 [00:00<27:19,  6.10it/s]

2
2
5
5
3
3
3
6
5
6
3
5
5
6
8
3
3
4
6
7
2
6
5
4
7
4
5
2
6
5
0
5
0
3
1
6
8
4
5
7
2
1
7
7
6
6
8
0
5
3
7
0
3
0
3
1
4
6
3
5
4
1
0
6
8
6
7
0
0
4
2
4
2
7
8
6
7
3
1
2
1
6
4
4
2
7
6
4
1
4
4
4
2
8
7
8
8
6
0
2
1
3
3
7
1
7
5
3
7
3
3
3
7
3
4
2
2
6
1
3
2
6
4
0
7
0
2
8
6
7
2
1
2
4
0
5
2
7
2
0
8
8
7
8
6
4
6
3
3
7
2
1
1
6
7
8
4
1


  0%|                                                                                | 7/10000 [00:00<26:14,  6.35it/s]

8
3
8
5
7
5
4
1
3
7
0
8
5
2
4
5
4
5
0
7
7
3
8
8
2
4
3
5
2
1
7
3
6
7
0
1
6
8
0
6
5
1
1
3
7
3
5
7
5
0
5


  0%|                                                                                | 8/10000 [00:01<23:43,  7.02it/s]

7
8
2
1
6
7
5
7
6
6
3
1
8
4
5
8
8
1
1
2
0
0
8
2
2
3
0
0
5
2
4
0
1
6
2
2
8
6
1
1
8
3
1
2
1
1
6
7
1
8
3
8
0
5
6


  0%|                                                                                | 9/10000 [00:01<21:50,  7.62it/s]

4
6
4
8
6
5
8
4
5
6
1
8
7
3
6
2
4
3
5
4
3
7
1
7
4
2
7
2
4
4
3
1
1
8
2
4
6
4
4
1
5
5
0
3
6
1
3
0
6
8
1
8
4
1
4
5
3
8
6
0
6
0
5
8
4
8
1
4
1
4
4
6
4
0


  0%|                                                                               | 11/10000 [00:01<18:53,  8.81it/s]

1
6
1
7
4
7
4
5
0
1
2
3
1
2
5
1
3
0
6
1
4
1
0
6
1
0
5
4
4
1
7
1
8
8
3
4
1
7
7
8
1
0
0
4
1
2
4
0
3
1
6
4
3
1
7
5
8
2
3
4
2
8
8
6
1
4
2
3
0
7
8
4
3
8


  0%|1                                                                              | 14/10000 [00:01<15:29, 10.74it/s]

6
0
4
3
6
8
7
4
6
5
3
3
1
8
7
5
5
0
2
4
2
2
4
4
6
1
1
3
0
1
7
8
3
3
7
8
8
5
0
1
3
8
2
3
5
8
8
1
3
3
0
8
0
8
3
3
1
8
3
6
3
6
8
6
4
0
0
1
5
0
8
1
1
0
8
1
5
4
0
1
1
6
7
1
8
1
5
4
0
1
3
6
2
2
8
6
5
1
8
8
2


  0%|1                                                                              | 16/10000 [00:01<15:51, 10.50it/s]

1
6
6
8
0
6
0
0
0
0
3
4
0
2
4
0
6
6
7
8
0
0
0
7
3
4
3
5
7
3
5
1
3
0
5
1
7
1
3
4
4
0
5
8
2
6
0
0
5
7
8
7
7
2
1
8
7
7
0
5
6
0
8
7
2
1
8
7
0
2
5
3
0
8
8
5
5
0
5
6
7
8
6
4
7
6
4
1
1
7
2
2
3
3
0
4
5


  0%|1                                                                              | 18/10000 [00:01<15:40, 10.61it/s]

3
6
8
3
2
5
8
7
5
5
3
4
8
0
0
6
8
7
2
7
7
2
0
0
3
4
4
0
7
6
7
4
2

  0%|1                                                                              | 18/10000 [00:01<17:26,  9.54it/s]


KeyboardInterrupt: 

In [11]:
a = {1:'a',2:'b',3:'c'}

In [16]:
for key,value in a.items():
    print(key,value)

1 a
2 b
3 c
