In [9]:
# Imports
import curses
import time
import sys
import math
import os
import random
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.image as img
import csv
import numpy as np
from collections import deque
from random import randint
from PIL import Image

# Tensorflow imports
import tensorflow as tf
from tensorflow.keras.models import Sequential,load_model
from tensorflow.keras.layers import Dense, Dropout, Conv2D,Conv1D, MaxPooling2D,MaxPooling1D, Activation, Flatten
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.losses import Huber
import tensorflow.python.keras.backend as backend



In [10]:
# RL settings
DISCOUNT = 0.89
REPLAY_MEMORY_SIZE = 10000  
MIN_REPLAY_MEMORY_SIZE = 1000 
MINIBATCH_SIZE = 128  
UPDATE_TARGET_EVERY = 5  
MODEL_NAME = 'snake'
LEARNING_RATE = 0.005

# Environment settings
EPISODES = 2000
ENV_SIZE = 10

# Exploration settings
EPSILON_DECAY = 0.95
MIN_EPSILON = 0.05

# Stats settings
AGGREGATE_STATS_EVERY = 25  
SHOW_PREVIEW = False

In [11]:
class Field:
    #Initialize game field
    def __init__(self,size):
        self.size = size
        # Used for terminal print
        self.icons = {
            0: ' . ',
            1: ' * ',
            2: ' # ',
            3: ' & ',
            4 : ' 0 '
        }
        # Used for RGB image
        self.d = {
            0: (255, 255, 255),
            1: (216, 173, 230),
            2: (0, 0, 255),
            3: (0,255,0),
            4: (0,0,0)
            }

        self.snake_coords = [[5, 3], [5, 4], [5, 5]]
        self._generate_field()
        self.add_entity()
        if self.get_entity_pos()==[-1,-1]:
            self.add_entity()

    def reset(self):
        # Reset field after episode
        self._generate_field()
        self.add_entity()
        self.snake_coords = [[5, 3], [5, 4], [5, 5]]
        if self.get_entity_pos()==[-1,-1]:
            self.add_entity()
        
    def add_entity(self):  
       # Add food to field 
        while(True):
            i = randint(2, self.size-2)
            j = randint(2, self.size-2)
            entity = [i, j]
            
            if entity not in self.snake_coords:
                self.field[i][j] = 3
                break
 
    def _generate_field(self):
        self.field = np.zeros((self.size, self.size), dtype=np.uint8) 
    
    def _clear_field(self):        
        self.field = [[j if j!= 1 and j!= 2 else 0 for j in i] for i in self.field]

    def render(self, show):
        # Render field enumeration 
        size = self.size
        self._clear_field()

        # Render snake on the field
        for i, j in self.snake_coords:
            self.field[i][j] = 1
        if self.get_entity_pos()==[-1,-1]:
            self.add_entity()

        # Mark head
        head = self.snake_coords[len(self.snake_coords)-1]
        self.field[head[0]][head[1]] = 2

        for i in range(self.size):
                self.field[i][0] = 4
                self.field[0][i] = 4
                self.field[self.size-1][i] = 4
                self.field[i][self.size-1] = 4

        # Print icons to terminal
        if show:
            for i in range(self.size):
                row = ''
                for j in range(self.size):
                    row += self.icons[ self.field[i][j] ]
                print(row,'\n')   
        return np.array(self.field)


    def get_image(self,game_field):
        # Convert enumerated field to RGB image
        env = np.zeros((self.size, self.size, 3), dtype=np.uint8) 
        
        for i in range(self.size):
            for j in range(self.size):
                if game_field[i][j]==0:
                    env[i][j] = self.d[0]
            if game_field[i][j]==1:
                    env[i][j] = self.d[1]
            if game_field[i][j]==2:
                    env[i][j] = self.d[2]
            if game_field[i][j]==3:
                    env[i][j] = self.d[3]
        img = Image.fromarray(env, 'RGB') 
        return img



    def get_entity_pos(self):
        for i in range(self.size):
            for j in range(self.size):
                if self.field[i][j] == 3:
                    return [i, j]
        return [-1, -1]


    def is_snake_eat_entity(self):
        entity = self.get_entity_pos()
        head = self.snake_coords[-1]
        return entity == head


In [12]:
class Snake:
    def __init__(self):
        self.direction = 0
        self.dist =8
        # Init coords
        self.coords = [[5, 3], [5, 4], [5, 5]]

    def reset(self):
        self.direction = 0

        # Init coords
        self.coords = [[5, 3], [5, 4], [5, 5]]
        
    def set_direction_normal(self, ch):

        # Check if wrong direction
        if ch == 1 and self.direction == 0:
            return
        if ch == 0 and self.direction == 1:
            return
        if ch == 3 and self.direction == 2:
            return
        if ch == 2 and self.direction == 3:
            return 

        self.direction = ch

    def set_direction_random(self):
        rand_dir = randint(0, 4)
        # Check if wrong direction
        if rand_dir == 0 and  self.direction != 0:
            self.direction = 1
        if rand_dir == 1 and  self.direction != 1:
            self.direction = 0
        if rand_dir == 2 and self.direction != 2:
            self.direction = 2
        if rand_dir == 3 and self.direction != 3:
            self.direction = 3



    def level_up(self):
        # get last point direction
        a = self.coords[0]
        b = self.coords[1]

        tail = a[:]

        if a[0] < b[0]:
            tail[0]-=1
        elif a[1] < b[1]:
            tail[1]-=1
        elif a[0] > b[0]:
            tail[0]+=1
        elif a[1] > b[1]:
            tail[1]+=1

        tail = self._check_limit(tail)
        self.coords.insert(0, tail)

    def is_dead(self):
        head = self.coords[-1]
        snake_body = self.coords[:-1]

        return head in snake_body

    def hit_wall(self):
        head = self.coords[-1]
        self.walls = []
        for i in range(self.field.size):
            self.walls.append([i,0])
            self.walls.append([0,i])
            self.walls.append([self.field.size-1,i])
            self.walls.append([i,self.field.size-1])

        return head in self.walls


    def _check_limit(self, point):
        # Check field limit
        if point[0] > self.field.size-1:
            point[0] = 0
        elif point[0] < 0:
            point[0] = self.field.size-1
        elif point[1] < 0:
            point[1] = self.field.size-1
        elif point[1] > self.field.size-1:
            point[1] = 0

        return point

    def move(self):
        # Determine head coords
        head = self.coords[-1][:]
        food =  self.field.get_entity_pos()
        
        done=False
        # Calc new head coords
        if self.direction == 3:
            head[0]-=1
        elif self.direction == 2:
            head[0]+=1
        elif self.direction == 0:
            head[1]+=1
        elif self.direction == 1:
            head[1]-=1
        dist=math.sqrt((head[0]-food[0])**2+(head[1]-food[1])**2)
        if(dist<self.dist):
            reward = -0.005
            self.dist = dist
        else:
            reward = -0.01
            self.dist = dist

        # Check field limit
        head = self._check_limit(head)

        del(self.coords[0])
        self.coords.append(head)
        self.field.snake_coords = self.coords

        if  self.is_dead():
            reward = -5
            done = True

        if self.hit_wall():
            reward = -5
            done =True
        # check if snake eat an entity
        if self.field.is_snake_eat_entity():
            #curses.beep()
            self.level_up()
            self.field.add_entity()
            reward = 1
        return done,reward


    def set_field(self, field):
        self.field = field

    def get_entity_pos(self):
            for i in range(ENV_SIZE):
                for j in range(ENV_SIZE):
                    if self.field[i][j] == 3:
                        return [i, j]

            return [-1, -1]

In [13]:
class Agent:
    def __init__(self):

        # Main model
        self.model = self.create_model()

        # Target network
        self.target_model = self.create_model()
        self.target_model.set_weights(self.model.get_weights())

        # An array with last n steps for training
        self.replay_memory = deque(maxlen=REPLAY_MEMORY_SIZE)
        
        self.target_update_counter = 0

    def create_model(self):

            model = Sequential()

            model.add(Conv1D(32,7,strides=2,activation='relu',padding='same', input_shape=(ENV_SIZE, ENV_SIZE)))  
            model.add(MaxPooling1D(pool_size=(3),padding='same'))
            
            model.add(Conv1D(64,3,activation='relu',padding='same'))  
            model.add(MaxPooling1D(pool_size=(2),padding='same'))

            model.add(Flatten()) 
            model.add(Dense(128,activation='relu'))
            model.add(Dense(128,activation='relu'))
            model.add(Dense(4, activation='linear')) 
            model.compile(loss='log_cosh', optimizer=Adam(learning_rate=LEARNING_RATE), metrics=['accuracy'])
            return model


    # Adds step's data to a memory(current observation, action, reward, new observation space, done)
    def update_replay_memory(self, transition):
        self.replay_memory.append(transition)

    # Trains main network every step during episode
    def train(self, terminal_state, step):

        # Start training only if certain number of samples is already saved
        if len(self.replay_memory) < MIN_REPLAY_MEMORY_SIZE:
            return

        # Get a minibatch of random samples 
        minibatch = random.sample(self.replay_memory, MINIBATCH_SIZE)

        # Get current states from minibatch, and Q values from NN
        current_states = np.array([transition[0] for transition in minibatch])/4
        current_qs_list = self.model.predict(current_states)

        # Get future states from minibatch and Q values from NN
        new_current_states = np.array([transition[3] for transition in minibatch])/4
        future_qs_list = self.target_model.predict(new_current_states)
    
        X = []
        y = []

        for index, (current_state, action, reward, new_current_state, done) in enumerate(minibatch):

            # If not a terminal state calc new Q-value
            if not done:
                max_future_q = np.max(future_qs_list[index])
                new_q = reward + DISCOUNT * max_future_q
            else:
                new_q = reward

            # Update Q value for given state
            current_qs = current_qs_list[index]
            current_qs[action] = new_q

            # And append to our training data
            X.append(current_state)
            y.append(current_qs)

        # Fit on all samples as one batch, log only on terminal state
        self.model.fit(np.array(X)/4, np.array(y), batch_size=MINIBATCH_SIZE, verbose=0, shuffle=False)

        # Update target counter
        if terminal_state:
            self.target_update_counter += 1

        # If counter reaches set value, update target network with weights of main network
        if self.target_update_counter > UPDATE_TARGET_EVERY:
            self.target_model.set_weights(self.model.get_weights())
            self.target_update_counter = 0

    # Queries main network for Q values given current observation space 
    def get_qs(self, state):
        return self.model.predict(np.array(state).reshape(-1, *state.shape)/4)[0]

In [None]:
def main():
    ep_rewards = []
    ep_score = []
    avgs = []
    max_rewards = []
    epsilon = 0.01
    random.seed(1)
    np.random.seed(1)
    tf.random.set_seed(1)

    if not os.path.isdir('models'):
        os.makedirs('models')
    if not os.path.isdir('images'):
        os.makedirs('images')
        
    field = Field(ENV_SIZE)
    snake = Snake()
    snake.set_field(field)
    agent = Agent()

    for episode in tqdm(range(1, EPISODES + 1), ascii=True, unit='episodes'):
        

        # Restarting episode - reset episode reward and step number
        episode_reward = 0
        score = 0
        step = 1
        new_step = 0
        average_reward = -11
        old_avg = -11

        # Reset environment and get initial state
        field.reset()
        snake.reset()
        current_field = field.render(False)
        current_state = field.get_image(current_field)

        # Reset flag and start iterating until episode ends
        done = False
        action = 0
        while not done:

            if np.random.random() > epsilon:
                # Get action from Q table
                action = np.argmax(agent.get_qs(current_field))
                
            else:
                # Get random action
                action = np.random.randint(0, 4)
          
            snake.set_direction_normal(action)
            done, reward= snake.move()
            if reward == 1:
                score += 1
                new_step = step

            new_field = field.render(False)
            new_state = field.get_image(new_field)

            if (step-new_step) >= 100*(round(episode/10000))+100:
                    new_step = step
                    done = True
                    reward = -5

            episode_reward += reward

            if SHOW_PREVIEW and not episode % AGGREGATE_STATS_EVERY:
                img = current_state.resize((400,400))
                img.save(f'images/training-ep_{episode}step__{step}.png')

            # Every step we update replay memory and train network
            agent.update_replay_memory((current_field, action, reward, new_field, done))
            agent.train(done, step)
            
            current_field = new_field
            current_state = new_state
            step += 1

        # Append episode reward to a list 
        ep_rewards.append(episode_reward)
        ep_score.append(score)
        if not episode % AGGREGATE_STATS_EVERY or episode == 1:
            old_avg = average_reward
            average_reward = sum(ep_rewards[-AGGREGATE_STATS_EVERY:])/len(ep_rewards[-AGGREGATE_STATS_EVERY:])
            min_reward = min(ep_rewards[-AGGREGATE_STATS_EVERY:])
            max_reward = max(ep_rewards[-AGGREGATE_STATS_EVERY:])
            avgs.append(average_reward)
            max_rewards.append(max_reward)
            old_avg = average_reward
            #print(f'{episode}Ep_{max_reward:_>7.2f}max_{average_reward:_>7.2f}avg_{min_reward:_>7.2f}min')
            
            #Save model
            agent.model.save(f'models/{MODEL_NAME}__{max_reward:_>7.2f}max_{average_reward:_>7.2f}avg_{min_reward:_>7.2f}min__{int(time.time())}.model')
                

        # Decay epsilon
        if epsilon > MIN_EPSILON:
            epsilon *= EPSILON_DECAY
            epsilon = max(MIN_EPSILON, epsilon)
        
        with open(f'Max-rewards.csv', 'w', encoding='UTF8') as f:
                        # create the csv writer
            writer = csv.writer(f)

                        # write a row to the csv file
            writer.writerow(max_rewards)

        with open('Score.csv', 'w', encoding='UTF8') as f:
                        # create the csv writer
            writer = csv.writer(f)

                        # write a row to the csv file
            writer.writerow(ep_score)

        with open(f'Average-rewards.csv', 'w', encoding='UTF8') as f:
                        # create the csv writer
            writer = csv.writer(f)

                        # write a row to the csv file
            writer.writerow(avgs)

    

main()

  0%|          | 0/2000 [00:00<?, ?episodes/s]

INFO:tensorflow:Assets written to: models/snake____-5.09max___-5.09avg___-5.09min__1639601531.model/assets


  1%|1         | 24/2000 [00:07<08:25,  3.91episodes/s]

INFO:tensorflow:Assets written to: models/snake____-4.02max___-4.99avg___-5.09min__1639601538.model/assets


  2%|2         | 49/2000 [00:15<08:30,  3.82episodes/s]

INFO:tensorflow:Assets written to: models/snake____-4.01max___-4.91avg___-5.06min__1639601546.model/assets


  4%|3         | 74/2000 [00:22<08:40,  3.70episodes/s]

INFO:tensorflow:Assets written to: models/snake____-4.01max___-4.83avg___-5.07min__1639601553.model/assets


  5%|4         | 99/2000 [00:29<06:48,  4.66episodes/s]

INFO:tensorflow:Assets written to: models/snake____-4.01max___-4.99avg___-5.06min__1639601560.model/assets


  6%|6         | 124/2000 [00:36<08:50,  3.53episodes/s]

INFO:tensorflow:Assets written to: models/snake____-3.05max___-4.95avg___-5.05min__1639601567.model/assets


  7%|7         | 149/2000 [00:44<07:23,  4.18episodes/s]

INFO:tensorflow:Assets written to: models/snake____-4.02max___-4.95avg___-5.05min__1639601575.model/assets


  9%|8         | 174/2000 [00:52<08:08,  3.74episodes/s]

INFO:tensorflow:Assets written to: models/snake____-5.01max___-5.03avg___-5.07min__1639601583.model/assets


 10%|9         | 199/2000 [00:59<07:56,  3.78episodes/s]

INFO:tensorflow:Assets written to: models/snake____-4.03max___-4.95avg___-5.07min__1639601591.model/assets


 11%|#1        | 224/2000 [01:19<35:32,  1.20s/episodes]

INFO:tensorflow:Assets written to: models/snake____-4.01max___-4.99avg___-5.08min__1639601611.model/assets


 12%|#2        | 249/2000 [02:05<1:14:28,  2.55s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.06max___-4.86avg___-5.12min__1639601659.model/assets


 14%|#3        | 274/2000 [02:56<51:11,  1.78s/episodes]

INFO:tensorflow:Assets written to: models/snake____-4.04max___-4.86avg___-5.12min__1639601709.model/assets


 15%|#4        | 299/2000 [03:53<1:05:39,  2.32s/episodes]

INFO:tensorflow:Assets written to: models/snake____-4.03max___-5.00avg___-5.12min__1639601765.model/assets


 16%|#6        | 324/2000 [04:52<1:22:43,  2.96s/episodes]

INFO:tensorflow:Assets written to: models/snake____-4.08max___-5.00avg___-5.14min__1639601825.model/assets


 17%|#7        | 349/2000 [06:10<1:23:45,  3.04s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.13max___-4.82avg___-5.15min__1639601903.model/assets


 19%|#8        | 374/2000 [07:29<1:12:01,  2.66s/episodes]

INFO:tensorflow:Assets written to: models/snake____-4.08max___-4.91avg___-5.17min__1639601983.model/assets


 20%|#9        | 399/2000 [08:55<1:40:30,  3.77s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.08max___-4.63avg___-5.17min__1639602070.model/assets


 21%|##1       | 424/2000 [10:28<1:21:41,  3.11s/episodes]

INFO:tensorflow:Assets written to: models/snake____-4.01max___-4.92avg___-5.28min__1639602164.model/assets


 22%|##2       | 449/2000 [11:57<1:36:21,  3.73s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.06max___-4.80avg___-5.24min__1639602253.model/assets


 24%|##3       | 474/2000 [13:23<1:34:06,  3.70s/episodes]

INFO:tensorflow:Assets written to: models/snake____-2.16max___-4.43avg___-5.16min__1639602337.model/assets


 25%|##4       | 499/2000 [15:00<1:32:01,  3.68s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.13max___-4.65avg___-5.23min__1639602435.model/assets


 26%|##6       | 524/2000 [16:35<1:29:49,  3.65s/episodes]

INFO:tensorflow:Assets written to: models/snake____-0.18max___-4.61avg___-5.21min__1639602532.model/assets


 27%|##7       | 549/2000 [18:02<1:34:55,  3.92s/episodes]

INFO:tensorflow:Assets written to: models/snake____-4.08max___-4.92avg___-5.30min__1639602618.model/assets


 29%|##8       | 574/2000 [20:09<2:56:24,  7.42s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.42max___-4.69avg___-5.42min__1639602744.model/assets


 30%|##9       | 599/2000 [21:56<2:08:20,  5.50s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.12max___-4.83avg___-5.35min__1639602853.model/assets


 31%|###1      | 624/2000 [23:53<2:24:10,  6.29s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.10max___-4.80avg___-5.47min__1639602970.model/assets


 32%|###2      | 649/2000 [26:35<2:34:12,  6.85s/episodes]

INFO:tensorflow:Assets written to: models/snake____-1.41max___-4.58avg___-5.76min__1639603133.model/assets


 34%|###3      | 674/2000 [29:06<1:56:53,  5.29s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.21max___-4.81avg___-5.75min__1639603284.model/assets


 35%|###4      | 699/2000 [32:16<1:23:33,  3.85s/episodes]

INFO:tensorflow:Assets written to: models/snake____-2.54max___-4.72avg___-5.62min__1639603486.model/assets


 36%|###6      | 724/2000 [34:31<1:13:24,  3.45s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.15max___-4.60avg___-5.28min__1639603609.model/assets


 37%|###7      | 749/2000 [37:07<2:19:10,  6.68s/episodes]

INFO:tensorflow:Assets written to: models/snake____-2.53max___-4.51avg___-5.75min__1639603775.model/assets


 39%|###8      | 774/2000 [39:12<1:43:47,  5.08s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.14max___-4.67avg___-5.32min__1639603887.model/assets


 40%|###9      | 799/2000 [41:29<58:15,  2.91s/episodes]  

INFO:tensorflow:Assets written to: models/snake____-1.31max___-4.38avg___-5.35min__1639604024.model/assets


 41%|####1     | 824/2000 [44:28<1:34:25,  4.82s/episodes]

INFO:tensorflow:Assets written to: models/snake____-1.30max___-4.43avg___-5.29min__1639604200.model/assets


 42%|####2     | 849/2000 [47:10<2:18:45,  7.23s/episodes]

INFO:tensorflow:Assets written to: models/snake____-2.28max___-4.42avg___-5.46min__1639604367.model/assets


 44%|####3     | 874/2000 [49:43<3:03:12,  9.76s/episodes]

INFO:tensorflow:Assets written to: models/snake____-2.52max___-4.38avg___-5.62min__1639604531.model/assets


 45%|####4     | 899/2000 [52:14<1:18:32,  4.28s/episodes]

INFO:tensorflow:Assets written to: models/snake____-2.18max___-4.63avg___-5.46min__1639604675.model/assets


 46%|####6     | 924/2000 [53:51<1:01:26,  3.43s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.08max___-4.36avg___-5.16min__1639604770.model/assets


 47%|####7     | 949/2000 [56:11<57:42,  3.29s/episodes]

INFO:tensorflow:Assets written to: models/snake_____0.51max___-4.30avg___-5.76min__1639604911.model/assets


 49%|####8     | 974/2000 [58:16<1:17:17,  4.52s/episodes]

INFO:tensorflow:Assets written to: models/snake____-3.12max___-4.64avg___-5.43min__1639605032.model/assets


 50%|####9     | 990/2000 [59:45<58:36,  3.48s/episodes]  