In [1]:
import pygame
import traceback
import numpy as np
import glob
import cv2
import math
import random
import os
import re
import sys
from collections import Counter
from pygame.locals import *
# pygame.init()

pygame 2.0.0.dev6 (SDL 2.0.10, python 3.8.3)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, Input, Flatten, Dense, Lambda, ReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.python.keras.models import Model
from tensorflow.keras.losses import MeanAbsoluteError


In [3]:
import os

from easydict import EasyDict as edict

# Every run artefact lives under one root directory. Point the
# SNAKE_GAME_DATA_DIR environment variable elsewhere instead of editing paths.
DATA_ROOT = os.environ.get('SNAKE_GAME_DATA_DIR',
                           '/Users/vijay/Downloads/Code_Data/snake_game')


def _data_subdir(name):
    """Return ``DATA_ROOT/name`` with a trailing separator.

    Callers build file paths by plain string concatenation
    (e.g. ``config.IMG_DIR + '3.jpg'``), so the separator must be present.
    """
    return os.path.join(DATA_ROOT, name, '')


config = edict()

config.CHECKPOINT_DIR = _data_subdir('checkpoints')
config.NUM_EPOCHS     = 500
config.BATCH_SIZE     = 64
config.EPSILON        = 0.9
config.BUFFER_SAMPLE_RATE = 0.8
config.BUFFER_SAMPLE_RATE_DECAY_RATE = 0.1
config.DISCOUNT_FACTOR = 0.99
config.IMG_DIR = _data_subdir('screenshots')
config.BUFFERS_NPY_DIR = _data_subdir('npy_files')
config.FINAL_WEIGHTS_DIR = _data_subdir('final_weights')

In [4]:
class settings:
    """Display settings for the snake window plus the flag that keeps the game running."""

    def __init__(self):
        # Grey background, RGB.
        self.screen_color = (100, 100, 100)
        # Window size in pixels (square board).
        self.screen_width, self.screen_height = 200, 200
        self.run_game = True
        self.snake_nodes_group = []
        
        
class Snake_Node:
    """One body segment of the snake, drawn as a 10x10 outlined square."""

    def __init__(self, node_num, x, y, color, radius):
        # Index of this segment along the body and its top-left pixel position.
        self.snake_node_num = node_num
        self.x, self.y = x, y
        self.color = color
        # Stored for callers; draw_snake_node does not use it.
        self.radius = radius
        self.direction = 'up'

    def draw_snake_node(self, screen):
        """Draw the segment outline (1 px border) on ``screen`` and return pygame's Rect."""
        bounds = (self.x, self.y, 10, 10)
        return pygame.draw.rect(screen, self.color, bounds, 1)
    
    
    
class Snake_Head:
    """The snake's head, drawn as a filled 10x10 square."""

    def __init__(self, x, y, color, radius):
        # Top-left pixel position and appearance of the head.
        self.x, self.y = x, y
        self.color = color
        # Stored for callers; draw_snake_head does not use it.
        self.radius = radius

    def draw_snake_head(self, screen):
        """Draw the filled head square on ``screen`` and return pygame's Rect."""
        bounds = (self.x, self.y, 10, 10)
        return pygame.draw.rect(screen, self.color, bounds)
        
        
        
        

In [5]:


class Game_functions:
    
    ####################################
    
    def __init__(self):
        self.num_of_steps_in_curr_episode = 0
        self.start_training_gap_period = False
        self.is_apple_there = False
        self.game_settings = settings()
        self.divide_screen_into_segments()
        self.num_of_snake_nodes = 0
        self.snake_head_moving_direction = ' '
        self.xy_list = []
        self.radius = 10
        self.rect_width = 5
        self.rect_height = 5
        self.full_reward = 0
#         self.full_reward = False
        self.img_count = 0
        self.prev_distance = 0
        self.new_game = True
        self.score    = 10
        self.img_dir  = config.IMG_DIR
        
        
        self.create_dirs()
        self.set_up_the_screen(True)
        
    
    
    
    ###################################
    
    def create_dirs(self):
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        
        if not os.path.exists(config.CHECKPOINT_DIR):
            os.makedirs(config.CHECKPOINT_DIR)
        
        if not os.path.exists(config.BUFFERS_NPY_DIR):
            os.makedirs(config.BUFFERS_NPY_DIR)
        
        if not os.path.exists(config.FINAL_WEIGHTS_DIR):
            os.makedirs(config.FINAL_WEIGHTS_DIR)
            
            
    ####################################
    
    def divide_screen_into_segments(self):
        self.segment_indices = []
        for i in range(0, 200):
            if i % 10 == 0:
                self.segment_indices.append(i)
                
                
    
    ####################################
    
    def draw_snake_nodes_xy(self, from_list):
        
        if from_list:
            for index, xy_tuple in enumerate(self.xy_list):
                if index != 0:
                    self.create_snake_head_or_apple(xy_tuple[0], xy_tuple[1], 5, False)
                    pygame.display.update()
        
        
    
    
    ###################################
    
    def get_new_xy_for_apple(self):
        x, y = random.choice(self.segment_indices),random.choice(self.segment_indices)
        if (x, y) not in self.xy_list:
            return x, y
        else:
            return self.get_new_xy_for_apple()
        
        
    ####################################
    
    def create_snake_head_or_apple(self, x, y, is_apple):
        
        if not is_apple:
            color = (255,255,255)
            
            if self.num_of_snake_nodes != 0:
                
                snake_node_obj = Snake_Node(self.num_of_snake_nodes, x, y, color, self.radius)
                snake_node = snake_node_obj.draw_snake_node(self.screen)
                
                
            elif self.num_of_snake_nodes == 0:
                snake_head_obj = Snake_Head(x, y, (0, 255, 0), self.radius)
                self.snake_head = snake_head_obj.draw_snake_head(self.screen)

                
            self.num_of_snake_nodes = self.num_of_snake_nodes + 1
            
            if self.num_of_snake_nodes < 5:
                pygame.time.wait(80)
            else:
                pygame.time.wait(10)
            
        if is_apple:
            color = (255, 0, 0)
            self.apple = pygame.draw.rect(self.screen, color, (x, y, 10, 10))
#             self.apple = pygame.draw.circle(self.screen, color, (x, y), self.radius)
            self.is_apple_there = True
            self.apple_x = x
            self.apple_y = y


    
    
    
        
    ####################################
    
    def start_a_new_game(self):
#         print('start a new game')
        if self.new_game:
#             self.set_up_the_screen()
            self.is_apple_there     = False
            self.num_of_snake_nodes = 0
            self.xy_list.clear()
            self.screen.fill((100, 100, 100))
            self.prev_distance      =  0
            self.score              = 10
            
            x, y = random.choice(self.segment_indices),random.choice(self.segment_indices)
            self.xy_list.append((x, y))
            self.draw_rect_at_new_location()
            self.full_reward = 0
            self.new_game = False
            self.end_game = False
            self.curr_reward = 0
#             self.set_up_the_screen()

    ####################################
    
    def check_if_snake_is_at_boundaries(self, new_x, new_y, direction):

        if direction == 'up':
            if new_y  < 0:
                return True
    

        if direction == 'down':
            if new_y  >= self.game_settings.screen_width:
                return True
    

        if direction == 'left':
            if new_x  < 0:
                return True
  

        if direction == 'right':
            if new_x  >= self.game_settings.screen_width:
                return True
        
        
        return False






    ####################################

    def draw_rect_at_new_location(self):
        
        self.screen.fill((100, 100, 100))
        self.num_of_snake_nodes = 0
        
        for xy_tuple in self.xy_list:        
            self.create_snake_head_or_apple(xy_tuple[0], xy_tuple[1], False)
    
        
        self.create_apple()
        pygame.display.update()


    
    ####################################
    
    def get_new_x_y_for_moving(self, x, y, direction):
        
        if direction == 'up':
            y = y - self.radius
            

        elif direction == 'down':
            y = y + self.radius
            

        elif direction == 'left':
            x = x - self.radius
            

        elif direction == 'right':
            x = x + self.radius
        
        
        return x, y
    
    
#     ###################################
    
    def move_nodes(self, head_direction):
        prev_tuple = (-1, -1)
        curr_tuple = (-1, -1)
        for index, xy_tuple in enumerate(self.xy_list):
            if index == 0:
                prev_tuple = xy_tuple
            elif index != 0:
                curr_tuple = self.xy_list[index]
                self.xy_list[index] = prev_tuple
                prev_tuple = curr_tuple
        
    
    
    ####################################
    
    def check_if_snake_ran_into_its_own_body(self):
        
        tuple_counter = Counter(self.xy_list)
        if tuple_counter[self.xy_list[0]] != 1:
            return True
        
        return False
        
 
    
    ####################################
    
    def move_snake_head_new(self, direction):
        
        '''
        First, get the new xy for the snake head to move
        check whether the snake head is at the boundaries wrt new xy
        if not, get new xy for all the snake nodes 
        update the xy_list with all the new xy values of both nodes and head
        after updating, check, after moving, if the head bumped into its own body
        '''
#         end_game_in_curr_step = False
        
        x, y = self.xy_list[0][0], self.xy_list[0][1]
        new_x, new_y = self.get_new_x_y_for_moving(x, y, direction)
#         if new_x == -1 and new_y == -1:
#             is_snake_at_boundary = True
#         else:
#             is_snake_at_boundary = False
#         print(new_x, new_y)
        is_snake_at_boundary = self.check_if_snake_is_at_boundaries(new_x, new_y, direction)

        if not is_snake_at_boundary:
#             print('not bound')
            self.move_nodes(direction)
            self.xy_list[0] = (new_x, new_y)
            snake_ran_into_its_own_body = self.check_if_snake_ran_into_its_own_body()
            if not snake_ran_into_its_own_body:
                self.new_game = False
#                 self.move_nodes(direction)
                self.draw_rect_at_new_location()
#                 pygame.display.update()
#                 print(new_x_, new_y)
            else:
                '''
                if the snake bumped into its own body, then a reward of -1 has to be given to the snake
                '''
                
                self.full_reward = -1
                self.new_game    = True
#                 end_game_in_curr_step = True
#                 self.end_game()
#                 self.start_a_new_game()
            

        else:
            '''
            if the snake bumped into any boundary, then a reward of -1 has to be given to the snake
            '''
#             print('bundar')
#             print(new_x, new_y)
            self.full_reward = -1
            self.new_game  = True
#             end_game_in_curr_step = True
#             self.end_game()
#             self.start_a_new_game()

        
#         return end_game_in_curr_step
        
    

    ####################################
    
    def create_apple(self):
        
        if not self.is_apple_there:
            x, y = self.get_new_xy_for_apple()
#             x, y = random.choice(self.segment_indices),random.choice(self.segment_indices)
            
            self.create_snake_head_or_apple(x, y, True)
        
        if self.is_apple_there:
            self.create_snake_head_or_apple(self.apple_x, self.apple_y, True)
            
        pygame.display.update()
    
    
    
    
    #####################################
    '''
    if an action is supplied (direction) and in that direction, if there is a node in that direction, then stop moving it
    
    '''
    def check_whether_to_move_nodes_or_not(self, direction):
        if direction == 'up':
            if self.xy_list[0][0] == self.xy_list[1][0] and self.xy_list[0][1] == self.xy_list[1][1] + self.radius:
                return False
        
        elif direction == 'down':
            if self.xy_list[0][0] == self.xy_list[1][0] and self.xy_list[0][1] == self.xy_list[1][1] - self.radius:
                return False
        
        elif direction == 'left':
            if self.xy_list[0][0] == self.xy_list[1][0] + self.radius and self.xy_list[0][1] == self.xy_list[1][1]:
                return False
            
        elif direction == 'right':
            if self.xy_list[0][0] == self.xy_list[1][0] - self.radius and self.xy_list[0][1] == self.xy_list[1][1]:
                return False
        
        return True
    
    
    
    #####################################
    
    def get_new_xy_for_new_node_after_eating_apple(self):
        
        last_node_x, last_node_y = self.xy_list[-1][0], self.xy_list[-1][1]
        new_x, new_y = -1, -1
        
        if last_node_x + self.radius < self.game_settings.screen_width:
            if (last_node_x + self.radius, last_node_y) not in self.xy_list: 
                new_x = last_node_x + self.radius
                new_y = last_node_y
        
        elif last_node_x - self.radius > 0:
            if (last_node_x - self.radius, last_node_y) not in self.xy_list:
                new_x = last_node_x - self.radius
                new_y = last_node_y
        
        
        elif last_node_y - self.radius > 0 :
            if (last_node_x, last_node_y - self.radius) not in self.xy_list: 
                new_x = last_node_x
                new_y = last_node_y -self.radius5
            
        elif last_node_y + self.radius < self.game_settings.screen_width:
            if (last_node_x, last_node_y + self.radius) not in self.xy_list:
                new_x = last_node_x
                new_y = last_node_y + self.radius
            
        
        return new_x, new_y
    
    
    
    #####################################
    
    def add_just_eaten_apple_to_the_snake_end(self):
        
        new_node_x, new_node_y = self.get_new_xy_for_new_node_after_eating_apple()
        self.xy_list.append((new_node_x, new_node_y))
        self.is_apple_there = False
        self.draw_rect_at_new_location()
        self.start_training_gap_period = True
        
        
    
    
    #####################################
    
    '''
    if apple's xy and snake_head's xy (first value in xy_list), then we consider that snake has eaten apple
    '''
    def is_apple_eaten(self):
        
        if self.is_apple_there:
            if self.apple_x == self.xy_list[0][0] and self.apple_y == self.xy_list[0][1]:
                '''
                if apple is eaten, a full reward of 1 has to be given to the snake
                '''
                self.full_reward = 1
                self.score = self.score + 10
                self.add_just_eaten_apple_to_the_snake_end()
                
    
    
    
    ###################################
    
    def check_if_snake_moved_closer_to_apple(self):
        
        snake_head_xy = self.xy_list[0]
        apple_xy      = (self.apple_x, self.apple_y)
        
        curr_distance = math.sqrt( ((snake_head_xy[0] - apple_xy[0]) ** 2) + ((snake_head_xy[1] - apple_xy[1]) ** 2))
        
        curr_length = self.score
        ratio = (curr_length + self.prev_distance) / (curr_length + curr_distance)
        reward = math.log(ratio, curr_length) # calculate logarithm base curr_length of ratio
        if reward <= -1:
            reward = reward + 1
        elif reward >= 1:
            reward = reward - 1
        self.prev_distance = curr_distance
        return reward
#             if curr_distance < self.prev_distance:
#                 return True
#             else: 
#                 False
#             self.prev_distance = curr_distance
        

    
    ###################################
    
    def get_reward(self):
        '''
        -- if snake_head moved towards apple, reward = 0.2
        --                            if not, reward = -0.2
        --- if snake bumped into its own body or bumped into boundaries, reward = -1
        --- if snake eats apple, reward = 1
        --- if snake cannot move, then reward = -0.1
        '''
        if self.full_reward == 0:
            distance_reward = self.check_if_snake_moved_closer_to_apple()
            return distance_reward
        
        else:
            return self.full_reward
        

        
        
    
    
    ##################################
    
    def get_curr_state_of_the_game(self):
        
        '''
        -- Stacking last 4 screenshots of the image as one
        -- if the num of screenshots is less than 4, then the last screenshot has to be appended 
            required num of times
        --- if the num of screenshots is greater than 4, then the recent last 4 screenshots have to be 
            appended
        '''
#         print('  previous')
#         for path in self.curr_screenshots_paths:
#             print(path)
        
#         print(' ')
        self.curr_screenshots_paths.clear()
        img_ids = []
        if self.img_count < 4 and self.img_count > 0:
            for id_ in range(self.img_count):
                img_ids.append(id_)
            remaining = 4 - self.img_count
            for _ in range(remaining):
                img_ids.append(img_ids[-1])

        elif self.img_count >= 4:
            for id_ in range(1, 5):
                img_ids.append(self.img_count - id_)
        
            img_ids.reverse()
        
        for id_ in img_ids:
            self.curr_screenshots_paths.append(self.img_dir + str(id_) + '.jpg')
        reward = self.get_reward()        
        end_game = self.new_game
        
#         self.curr_screenshot_path = curr_screenshot_path
        self.curr_reward = reward
        self.end_game    = end_game
        
    

        
    ###################################    
    '''
    -- if we started the game newly then is_game_just_started is set to True.
    -- if game ended because of the snake bumping into itself or into boundaries, then we have to end the current game and have to 
        start a new game. In this case, is_game_just_started is set to False
    '''
    def set_up_the_screen(self, is_game_just_started):
        pygame.init()
        pygame.display.init()
        self.screen = pygame.display.set_mode((self.game_settings.screen_width, self.game_settings.screen_height))
        pygame.display.flip()
        self.start_a_new_game()
        
        if is_game_just_started:
#             pygame.image.save(self.screen, self.img_dir + str(self.img_count) + '.jpg')
            self.img_count = self.img_count + 1
            self.curr_screenshots_paths = []
            self.curr_reward = 0
            self.end_game = False

        
    
    
    
    ####################################
    
    def end_the_game_and_start_a_new_one(self):
    
        pygame.display.quit()
        pygame.quit()
        exit()
        self.new_game = True
        self.set_up_the_screen(False)
        
        
    
    
    def end_the_game(self):
        pygame.display.quit()
        pygame.quit()
        exit()
        
    ####################################
    
    def run(self, action):
#         while self.game_settings.run_game:
        try:
            
            apple_eaten = False
            events = pygame.event.get()
            
            for event in events:
                if event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_x:
                        self.game_settings.run_game = False
                        pygame.display.quit()
                        pygame.quit()
                        exit()
                    
                
            if self.num_of_snake_nodes == 1:
                move_nodes = True

            keys = pygame.key.get_pressed()
            
            if keys[pygame.K_x]:
                self.game_settings.run_game = False
                pygame.display.quit()
                pygame.quit()
                exit()

            ############

            if action == 'up':
                if self.num_of_snake_nodes > 1:
                    move_nodes = self.check_whether_to_move_nodes_or_not(action)
                    if not move_nodes:
                        self.full_reward = -0.1


                if move_nodes:

                    self.move_snake_head_new(action)
#                     self.img_count = self.img_count + 1


            ##############

            elif action == 'down':
                if self.num_of_snake_nodes > 1:
                    move_nodes = self.check_whether_to_move_nodes_or_not(action)
                    if not move_nodes:
                        self.full_reward = -0.1

                if move_nodes:

                    self.move_snake_head_new(action)
#                     self.img_count = self.img_count + 1


            ##############

            elif action == 'left':
                if self.num_of_snake_nodes > 1:
                    move_nodes = self.check_whether_to_move_nodes_or_not(action)
                    if not move_nodes:
                        self.full_reward = -0.1

                if move_nodes:

                    self.move_snake_head_new(action)
#                     self.img_count = self.img_count + 1




            ###########

            elif action == 'right':
                if self.num_of_snake_nodes > 1:
                    move_nodes = self.check_whether_to_move_nodes_or_not(action)
                    if not move_nodes:
                        self.full_reward = -0.1

                if move_nodes:

                    self.move_snake_head_new(action)
#                     self.img_count = self.img_count + 1
    
            if not self.new_game:
                self.is_apple_eaten()
                pygame.display.flip()
#             if not self.end_game:
                pygame.image.save(self.screen, self.img_dir + str(self.img_count) + '.jpg')
                self.img_count = self.img_count + 1
            self.get_curr_state_of_the_game()
#             else:
#                 self.end_the_game()
            
        except:
            pass

In [6]:
    
# game_functions = Game_functions()
# # game_functions.run('test')
# for i in range(200):
#     actions = ['right', 'right', 'right', 'left']
#     action_index = random.randint(0, 3)
#     action = actions[action_index]
#     game_functions.run(action)

In [7]:

class Q_values_approximator_model:
    """Builds the convolutional Q-network that maps a stacked observation to one Q-value per action."""

    def __init__(self):
        pass

    def normalize_to_range_01(self, img):
        """Cast pixel values to float32 and scale them from [0, 255] to [0, 1]."""
        return tf.cast(img, tf.float32) / 255.0

    def get_q_values(self, input_shape=(64, 64, 12), num_actions=4):
        """Build the Q-network, store it on ``self.model`` and return it.

        Args:
            input_shape: shape of one observation. The default matches four
                64x64 screenshots stacked along the channel axis.
            num_actions: number of discrete actions, i.e. size of the output layer.

        Returns:
            The built ``tf.keras.Model`` (same object as ``self.model``).
        """
        input_data = Input(shape=input_shape)
        data = Lambda(self.normalize_to_range_01)(input_data)

        # Three strided conv layers progressively shrink the spatial size.
        for filters, kernel_size, stride in ((32, 7, 4), (64, 5, 2), (128, 3, 2)):
            data = Conv2D(filters=filters, kernel_size=kernel_size,
                          strides=(stride, stride), padding='SAME')(data)
            data = ReLU()(data)

        data = Flatten()(data)
        data = Dense(512)(data)
        data = ReLU()(data)

        # Linear output: one unbounded Q-value per action.
        data = Dense(num_actions)(data)

        self.model = tf.keras.Model(input_data, data)
        return self.model
        

    

In [8]:
class Training:
    def __init__(self):
        # Environment and training-gap bookkeeping.
        self.game_controls = Game_functions()
        self.training_gap_steps = 0

        # Replay memory: buffer 1 receives experiences with reward <= 0,
        # buffer 2 those with reward > 0.
        self.experiences_buffer_1 = []
        self.experiences_buffer_2 = []
        self.observation_period_threshold = 50000
        self.num_of_steps_completed_in_obv_period = 0
        self.actions = ['up', 'down', 'right', 'left']

        # Q-network, optimiser and checkpointing (keeps the 3 newest checkpoints).
        approximator = Q_values_approximator_model()
        approximator.get_q_values()
        self.q_values_pred_model = approximator.model
        self.model_optimizer = Adam(learning_rate=0.001)
        self.checkpoint = tf.train.Checkpoint(
            curr_epoch=tf.Variable(0),
            optimizer=self.model_optimizer,
            model=self.q_values_pred_model,
        )
        self.checkpoint_manager = tf.train.CheckpointManager(
            self.checkpoint,
            directory=config.CHECKPOINT_DIR,
            max_to_keep=3,
        )
        # NOTE(review): despite the attribute name, this is mean *absolute* error.
        self.mse_loss = MeanAbsoluteError()

        # Ring-buffer bookkeeping; a write index of -1 means "nothing written yet".
        self.buffers_max_len = self.observation_period_threshold
        self.from_index_to_add_exp_buffer_1 = -1
        self.from_index_to_add_exp_buffer_2 = -1
        
    ############################################
        
    def preprocess_and_stack_images(self, curr_game_images_paths):
        """Load the first four screenshots, resize each to 64x64 and stack them on the channel axis.

        Args:
            curr_game_images_paths: sequence of at least four image paths, oldest first.
                Only the first four are read.

        Returns:
            The four resized images concatenated along axis 2. This presumably
            gives shape (64, 64, 12), matching the Q-network input.

        Raises:
            ValueError: if fewer than four paths are given.
            FileNotFoundError: if an image cannot be read.
        """
        paths = list(curr_game_images_paths)[:4]
        if len(paths) < 4:
            raise ValueError('expected at least 4 image paths, got ' + str(len(paths)))

        images = []
        for image_path in paths:
            image = cv2.imread(image_path)
            # cv2.imread signals failure by returning None instead of raising.
            if image is None:
                raise FileNotFoundError('could not read image: ' + str(image_path))
            images.append(cv2.resize(image, (64, 64)))

        return np.concatenate(images, axis=2)
    
    
    
    
    ############################################
    
    '''
    After deciding upon action to be taken, run the game for one step
    Then, collect the next state images paths, reward received and whether the next state is terminal
    '''
    def perform_a_step_in_an_episode(self, action):
        
        self.game_controls.run(action)
        game_imgs_paths_after_action = self.game_controls.curr_screenshots_paths 
        reward_after_action = self.game_controls.curr_reward
        end_game_after_action = self.game_controls.end_game
        start_training_gap_period = self.game_controls.start_training_gap_period

        return game_imgs_paths_after_action, reward_after_action, end_game_after_action, start_training_gap_period
    
    
    
    
    ############################################
    
    def determing_num_of_training_gap_steps_after_snake_eating_apple(self):
        length_of_snake = self.game_controls.score 
        k = 10
        p = 0.4
        q = 0.2
        if length_of_snake <= k:
            self.training_gap_steps = 6
        elif length_of_snake > k:
            self.training_gap_steps = math.ceil((p * length_of_snake) + q)
    
    
    
    
    
    ############################################
    
    def check_if_buffers_reached_their_maximum_capacity(self, buffer_num):
        
        if buffer_num == 1: #and len(self.experiences_buffer_1) == self.buffers_max_len:
            if self.from_index_to_add_exp_buffer_1 == -1 or self.from_index_to_add_exp_buffer_1 > self.buffers_max_len - 1:
                self.from_index_to_add_exp_buffer_1 = 0

        
        elif buffer_num == 2: #and len(self.experiences_buffer_2) == self.buffers_max_len:
            if self.from_index_to_add_exp_buffer_2 == -1 or self.from_index_to_add_exp_buffer_2 > self.buffers_max_len - 1:
                self.from_index_to_add_exp_buffer_2 = 0
        
    
    
    
    
    ############################################
    
    def save_buffer_experiences_as_npy_files_periodically(self):
        """Snapshot both experience buffers to ``.npy`` files in ``config.BUFFERS_NPY_DIR``.

        Each experience is ``[curr_state_paths, action, reward, next_state_paths,
        end_game]``. That is a ragged mix of lists, strings and numbers, so the
        buffers are saved as object arrays. Load them back with
        ``np.load(path, allow_pickle=True)``.

        Each file is overwritten in place, so the previous snapshot survives
        until the new array has been built successfully.
        """
        # An explicit dtype=object is required for ragged rows: newer NumPy
        # raises ValueError on inhomogeneous nested lists without it.
        buffer_1 = np.array(self.experiences_buffer_1, dtype=object)
        buffer_2 = np.array(self.experiences_buffer_2, dtype=object)
        np.save(config.BUFFERS_NPY_DIR + 'buffer_1_npy.npy', buffer_1, allow_pickle=True)
        np.save(config.BUFFERS_NPY_DIR + 'buffer_2_npy.npy', buffer_2, allow_pickle=True)
    
    
    
    
    ############################################
    
    '''
    --- Get an action and run the game as per the action
    --- Now, record the updated state of the game by getting the curr screenshot of the screen, 
         current reward and also whether the current state resulted in quitting game which happens
         when snake bumps into boundaries or ran into itself.

    '''
    def add_experiences_to_buffers_after_running_game_for_one_step(self, action, curr_state_imgs_paths, 
                                                                               next_state_imgs_paths, reward_after_action, end_game):
        
#         print('  before ' + str(self.from_index_to_add_exp_buffer_1) +' ' + str(self.from_index_to_add_exp_buffer_2))
        self.check_if_buffers_reached_their_maximum_capacity(1)
        self.check_if_buffers_reached_their_maximum_capacity(2)

        buffer_1_len = len(self.experiences_buffer_1)
        buffer_2_len = len(self.experiences_buffer_2)
#         print('  after ' + str(self.from_index_to_add_exp_buffer_1) +' ' + str(self.from_index_to_add_exp_buffer_2))
#         for path in curr_state_imgs_paths:
#             print('   ' + str(path))
        
#         print(' ')
#         print(' ')
        if reward_after_action <= 0:
            
            if buffer_1_len == self.buffers_max_len:
                self.experiences_buffer_1[self.from_index_to_add_exp_buffer_1] = [curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game]
            elif buffer_1_len < self.buffers_max_len - 1:
                self.experiences_buffer_1.append([curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game])
            
            self.from_index_to_add_exp_buffer_1 = self.from_index_to_add_exp_buffer_1 + 1
            
#             if self.from_index_to_add_exp_buffer_1 == -1:
#                 self.experiences_buffer_1.append([curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game])
#             else:
#                 self.experiences_buffer_1[self.from_index_to_add_exp_buffer_1] = [curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game]
#                 self.from_index_to_add_exp_buffer_1 = self.from_index_to_add_exp_buffer_1 + 1
        
        else:
            if buffer_2_len == self.buffers_max_len:
                self.experiences_buffer_2[self.from_index_to_add_exp_buffer_2] = [curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game]
            elif buffer_2_len < self.buffers_max_len:
                self.experiences_buffer_2.append([curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game])
            
            self.from_index_to_add_exp_buffer_2 = self.from_index_to_add_exp_buffer_2 + 1
            
            
            
#             if self.from_index_to_add_exp_buffer_2 == -1:
#                 self.experiences_buffer_2.append([curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game])
#             else:
#                 self.experiences_buffer_2[self.from_index_to_add_exp_buffer_2] = [curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game]
#                 self.from_index_to_add_exp_buffer_2 = self.from_index_to_add_exp_buffer_2 + 1
        
    
    
    
    
    
    ############################################
    
    def start_exploration_in_observation_period(self):
        
        '''If not resuming (if we are staring now), that means, we have just started and there is no current state
        Therefore, for two steps, choose a random action and then carry on forward'''
        if self.game_controls.img_count == 1:
            steps_taken = 0
            while steps_taken < 5:
                
                action = self.actions[random.randint(0, 3)]
                self.perform_a_step_in_an_episode(action)
                
                steps_taken = steps_taken + 1
            

              
        num_of_training_gap_steps_completed = 0
        enter_training_gap_period = False
        
        self.game_controls.get_curr_state_of_the_game()
        curr_state_imgs_paths = self.game_controls.curr_screenshots_paths.copy()
        curr_reward           = self.game_controls.curr_reward
        end_game              = self.game_controls.end_game
        start_training_gap    = self.game_controls.start_training_gap_period
#         curr_state           = self.preprocess_and_stack_images(curr_game_imgs_paths)

        while self.num_of_steps_completed_in_obv_period < self.observation_period_threshold and not end_game:
        
            
            '''
            -- so, when the snake eats an apple, then we should immediately enter into a training_gap_period during which no experiences are
                stored into experience buffers.
            
            -- if start_training_gap and enter_training_gap_period == False, then we haven't entered the training_gap and therefore set
                enter_training_gap_period = True
            
            -- if not start_training_gap, then add experiences to the buffer. if in training_gap_period, then just perform single step in the game
                without storing experiences in the experience buffers. Also, increase the number of steps completed in trainig_gap_period
            
            -- if the num_of_steps in training_gap is equal to the pre-determined num_of_steps in training_gap_period, then exit training_gap_period
                by setting enter_training_gap_period = False and self.game_controls.start_training_gap_period = False
            
            '''
            
            if (self.num_of_steps_completed_in_obv_period % 5000 == 0 or self.num_of_steps_completed_in_obv_period == self.observation_period_threshold - 1) and self.num_of_steps_completed_in_obv_period > 0:
                self.save_buffer_experiences_as_npy_files_periodically()
                print('npy files saved and ' + str(self.num_of_steps_completed_in_obv_period) + ' steps completed')

                
                
            if start_training_gap and not enter_training_gap_period:
                enter_training_gap_period = True
                self.determing_num_of_training_gap_steps_after_snake_eating_apple()
            
            
            if enter_training_gap_period and (num_of_training_gap_steps_completed == self.training_gap_steps - 1):
                enter_training_gap_period = False
                self.game_controls.start_training_gap_period = False
            
            action = self.actions[random.randint(0, 3)]
            
            next_state_imgs_paths, reward_after_action, end_game, start_training_gap = self.perform_a_step_in_an_episode(action)
            
            
            
            
            if not enter_training_gap_period:

    
                self.add_experiences_to_buffers_after_running_game_for_one_step(action, curr_state_imgs_paths.copy(), 
                                                                               self.game_controls.curr_screenshots_paths.copy(), 
                                                                                reward_after_action, end_game)  
            elif enter_training_gap_period:
                num_of_training_gap_steps_completed = num_of_training_gap_steps_completed + 1

            curr_state_imgs_paths = next_state_imgs_paths.copy()
            self.num_of_steps_completed_in_obv_period = self.num_of_steps_completed_in_obv_period + 1
        

    
    
    ############################################
        
    def restore_checkpoint(self):

        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print('restored checkpoint successfully at epoch ' + str(self.checkpoint.curr_epoch.numpy()))
        else:
            print('No checkpoint restoration')

    
    
    
    
    ############################################

    def epsilon_greedy_policy_for_action(self, curr_state_imgs_paths):
        if random.random() < config.EPSILON:
            return self.actions[random.randint(0, 3)]
        else:
            curr_state_imgs = self.preprocess_and_stack_images(curr_state_imgs_paths)
            q_values = self.checkpoint.model(np.expand_dims(curr_state_imgs, axis = 0), training = True)
            return self.actions[np.argmax(q_values.numpy()[0])]

    
    
    ############################################

    def adjust_buffer_sample_rate(self, epoch, decay_interval=1000, min_sample_rate=0.5):
        '''Decay config.BUFFER_SAMPLE_RATE once every `decay_interval` epochs, never below `min_sample_rate`.

        config.BUFFER_SAMPLE_RATE is the fraction of each mini-batch drawn from experiences_buffer_1
        (see sample_mini_batches_from_experience_buffers). Each decay subtracts
        config.BUFFER_SAMPLE_RATE * config.BUFFER_SAMPLE_RATE_DECAY_RATE, and the result is clamped
        at `min_sample_rate`.

        Args:
            epoch: current epoch number. Epoch 0 never triggers a decay.
            decay_interval: number of epochs between decays.
            min_sample_rate: lower bound for the rate. A rate already at or below it is left unchanged.
        '''
        if epoch <= 0 or epoch % decay_interval != 0:
            return

        # Callers may invoke this on every game step; decay at most once per epoch.
        if getattr(self, '_last_buffer_sample_rate_decay_epoch', None) == epoch:
            return
        self._last_buffer_sample_rate_decay_epoch = epoch

        current_rate = config.BUFFER_SAMPLE_RATE
        if current_rate <= min_sample_rate:
            return
        decay_value = current_rate * config.BUFFER_SAMPLE_RATE_DECAY_RATE
        config.BUFFER_SAMPLE_RATE = max(current_rate - decay_value, min_sample_rate)




    
    ############################################
    
    def sample_mini_batches_from_experience_buffers(self):
        
        experience_1_buffer_length = len(self.experiences_buffer_1)
        experience_2_buffer_length = len(self.experiences_buffer_2)
        required_num_of_samples_from_buffer_1 = int(64 * config.BUFFER_SAMPLE_RATE)
        
        if experience_1_buffer_length < required_num_of_samples_from_buffer_1:
            num_samples_from_buffer_1 = experience_1_buffer_length
        else:
            num_samples_from_buffer_1 = int(64 * config.BUFFER_SAMPLE_RATE)
            
        num_samples_from_buffer_2 = 64 - num_samples_from_buffer_1

        random_samples_from_buffer_1 = random.sample(self.experiences_buffer_1, num_samples_from_buffer_1)
        random_samples_from_buffer_2 = random.sample(self.experiences_buffer_2, num_samples_from_buffer_2)

        return random_samples_from_buffer_1, random_samples_from_buffer_2



    
    
    ############################################

    def get_target_labels_for_samples(self, buffer_samples):
        '''
        Turn sampled experiences into the entries train_step uses to build Q-learning targets.

        Each experience is laid out as
            experience[0] = curr_state_imgs_paths
            experience[1] = action (a text: 'up', 'down', 'left' or 'right')
            experience[2] = reward_after_action
            experience[3] = next_state_imgs_paths
            experience[4] = end_game

        Returns one entry per experience, in the same order:
            [True,  curr_state_imgs_paths, reward_after_action]  if end_game (the target is just the reward)
            [False, next_state_imgs_paths, reward_after_action]  otherwise (the target also needs Q(next_state))
        '''
        return [
            [True, exp[0], exp[2]] if exp[4] else [False, exp[3], exp[2]]
            for exp in buffer_samples
        ]



    
    
    ############################################
    '''
    -- Each training step samples 64 experiences in total from the experience buffers.
    -- To get the predicted Q-values, the screenshots of every sampled experience have to be run through the model.
    -- Either we run the model separately on each of the 64 experiences, or
    -- we stack the images of all experiences into one batch and run the model once on that batch.
    -- In the latter case, the shape of the model input would be [64, 64, 64, 12]:
        64     - total number of experiences (one stacked state per experience)
        64, 64 - height and width of each image
        12     - each screenshot has 3 channels and the 4 most recent screenshots are concatenated, giving 4 * 3 = 12 channels
    '''
    def get_actual_labels_for_samples(self, buffer_samples):
        actual_labels = []
        images = []
        action_values = []
        for experience in buffer_samples:
            curr_state_imgs_paths = experience[0]
            action = experience[1]
            action_index = self.actions.index(action)
            action_values.append(action_index)
            curr_state_imgs = self.preprocess_and_stack_images(curr_state_imgs_paths)
            images.append(curr_state_imgs)
          
        return images, action_values
    

    
    
    
    
    ############################################
    
    def loss_function(self, target_labels, actual_labels):
        '''Element-wise squared error between target and predicted values.

        No reduction is applied; the result has the broadcast shape of the two inputs.
        NOTE(review): train_step computes its loss with self.mse_loss, not with this helper.
        '''
        error = np.subtract(np.asarray(target_labels), np.asarray(actual_labels))
        return np.square(error)
    
    
    
    
    ############################################
    
    def train_step(self, buffer_samples):
        '''Perform one Q-learning gradient update on a mini-batch of experiences.

        The targets (built by _compute_q_learning_targets) are
            y = reward_after_action                                               if end_game
            y = reward_after_action + DISCOUNT_FACTOR * max_a Q(next_state, a)    otherwise
        and the loss is self.mse_loss(y, Q(curr_state, action)).

        The targets are computed before the GradientTape is opened and are passed as a constant,
        so gradients flow only through the predicted Q-values of the actions actually taken
        (semi-gradient TD), not through the bootstrap target.

        Args:
            buffer_samples: experiences laid out as
                [curr_state_imgs_paths, action, reward_after_action, next_state_imgs_paths, end_game].

        Returns:
            The loss tensor returned by self.mse_loss.
        '''
        targets = self._compute_q_learning_targets(buffer_samples)
        images, action_indices = self.get_actual_labels_for_samples(buffer_samples)
        model = self.checkpoint.model

        with tf.GradientTape() as tape:
            q_values = model(np.asarray(images), training=True)
            # Q-value of the action taken in each experience: shape (batch,).
            predicted = tf.gather(q_values, action_indices, axis=1, batch_dims=1)
            loss = self.mse_loss(targets, predicted)

        model_gradients = tape.gradient(loss, model.trainable_variables)
        self.checkpoint.optimizer.apply_gradients(zip(model_gradients, model.trainable_variables))
        return loss

    def _compute_q_learning_targets(self, buffer_samples):
        '''Return a float32 constant tensor holding one Q-learning target per experience.'''
        # Each entry is [is_terminal, state_imgs_paths, reward_after_action]; the reward is always entry[2].
        target_entries = self.get_target_labels_for_samples(buffer_samples)
        targets = np.asarray([entry[2] for entry in target_entries], dtype=np.float32)

        non_terminal = [i for i, entry in enumerate(target_entries) if not entry[0]]
        if non_terminal:
            next_states = np.asarray([self.preprocess_and_stack_images(target_entries[i][1]) for i in non_terminal])
            # A single batched forward pass in inference mode; its output is used as a fixed target.
            next_q_values = self.checkpoint.model(next_states, training=False)
            max_next_q = tf.reduce_max(next_q_values, axis=1).numpy()
            targets[non_terminal] += config.DISCOUNT_FACTOR * max_next_q

        return tf.constant(targets, dtype=tf.float32)
                
        
    

    ############################################
    
    '''
    When the game is resumed after a pause or a stop, we need to know how many steps of the observation period were
    already completed, so that recording experiences can continue from that step onwards. Before recording starts,
    we count the experiences stored in each buffer by reading the buffers' npy files in the config.BUFFERS_NPY_DIR
    directory. Each experience is recorded after one game step, so the total number of experiences in both npy files
    gives the number of observation-period steps already completed. (Steps played inside a training gap are not
    recorded, so this total can undercount.)
    '''
    
    def check_the_num_of_steps_completed_in_observation_period(self):
        num_of_steps_completed_in_obsv_period = 0
        files = glob.glob(config.BUFFERS_NPY_DIR + '*')
        if len(files) != 0:
            for file in files:
                
                buffer_experiences = np.load(file, allow_pickle = True)
                if 'buffer_1_npy' in file:
                    self.experiences_buffer_1 = list(buffer_experiences)
                    self.from_index_to_add_exp_buffer_1 = len(self.experiences_buffer_1)
                
                else:
                    self.experiences_buffer_2 = list(buffer_experiences)
                    self.from_index_to_add_exp_buffer_2 = len(self.experiences_buffer_2)
                num_of_steps_completed_in_obsv_period = num_of_steps_completed_in_obsv_period + len(buffer_experiences)
                
            
            
            return num_of_steps_completed_in_obsv_period 
        else:
            return 0
    
    
    
    ############################################
    
    def check_for_in_screenshots_and_get_images_count(self):
        files = os.listdir(self.game_controls.img_dir)
        files = [file for file in files if 'DS' not in file]
        if len(files) != 0:
            '''Sort the image names using the integer part'''
            files.sort(key=lambda f: int(re.sub('\D', '', f)))
            recent_img_name = files[-1]
            match_object = re.search('\D', recent_img_name)
            num = int(recent_img_name[0:match_object.start()])
            return num
        else:
            return 1
    
    
    
    ##########################################
    
    def observation_period(self):
        '''
        For the first 50000 seteps, just let the agent choose random actions and explore very
        extensively so tha it could have a vast experience buffer
        '''
#         try:
        self.game_controls.img_count = self.check_for_in_screenshots_and_get_images_count()
        obsv_period_steps_completed = self.check_the_num_of_steps_completed_in_observation_period()
        
        if self.game_controls.img_count > 45000:
            self.num_of_steps_completed_in_obv_period = self.observation_period_threshold - 1
            
        
        else:
        
            if self.game_controls.img_count > obsv_period_steps_completed:
                self.num_of_steps_completed_in_obv_period = self.game_controls.img_count
            else:
                self.num_of_steps_completed_in_obv_period = obsv_period_steps_completed

            print('num of steps completed are ' + str(self.num_of_steps_completed_in_obv_period))
            if not self.num_of_steps_completed_in_obv_period == self.observation_period_threshold - 1:
                self.start_exploration_in_observation_period()
                
                while self.num_of_steps_completed_in_obv_period < self.observation_period_threshold:
                    self.game_controls.end_the_game_and_start_a_new_one()
                    self.start_exploration_in_observation_period()

#                 self.game_controls.end_the_game_and_start_a_new_one()
        print(' ')
        print(' final count is ' + str(self.game_controls.img_count))
#         self.game_controls.self.game_controls.end_the_game_and_start_a_new_one()()
        print('exploration period ended')
    
    
    
    
    
    '''
    If the snake fails to eat any apple within the past P steps, it receives a negative reward as punishment. (Not implemented yet.)
    '''
#     def check_and_apply_negative_reward_if_snake_not_eat_apple_for_p_steps(self):
        
    
    
    ############################################
    
    def train(self):
        '''
        Train the Q-network for the epochs that remain after restoring the latest checkpoint.

        Each epoch restarts the game and plays until end_game. Actions come from
        epsilon_greedy_policy_for_action, and every step is added to the experience buffers. After
        every step, one train_step is run on a mini-batch sampled from both buffers.

        Every 10th epoch (excluding epoch 0) the buffers are saved to npy files, a checkpoint is
        written, and the mean loss since the previous report is printed. In the final epoch
        (config.NUM_EPOCHS - 1) the model weights are written to config.FINAL_WEIGHTS_DIR, the
        buffers are saved, and the game is closed.
        '''
        imgs_count = self.check_for_in_screenshots_and_get_images_count()

        self.restore_checkpoint()
        epochs_completed = self.checkpoint.curr_epoch.numpy()
        epochs_remaining = config.NUM_EPOCHS - epochs_completed

        self.game_controls.img_count = imgs_count
        loss_log = tf.keras.metrics.Mean('perc_loss', dtype = tf.float32)
        for epoch in range(epochs_remaining):
            print('epoch ' + str(epoch))

            self.game_controls.end_the_game_and_start_a_new_one()
            curr_epoch = self.checkpoint.curr_epoch.numpy()

            # Called once per epoch, with the checkpointed epoch: the loop index `epoch` restarts at 0
            # on every resume, and calling this per game step would re-run it many times per epoch.
            self.adjust_buffer_sample_rate(curr_epoch)

            # NOTE(review): the state and end_game are read right after restarting the game. Confirm that
            # end_the_game_and_start_a_new_one refreshes them; otherwise values from the previous game
            # (e.g. right after the observation period) can carry over.
            curr_state_imgs_paths = self.game_controls.curr_screenshots_paths.copy()
            end_game = self.game_controls.end_game

            # Play until the current state is terminal.
            while not end_game:
                action_to_be_taken = self.epsilon_greedy_policy_for_action(curr_state_imgs_paths)
                _, curr_reward, end_game, _ = self.perform_a_step_in_an_episode(action_to_be_taken)

                self.add_experiences_to_buffers_after_running_game_for_one_step(action_to_be_taken, curr_state_imgs_paths.copy(),
                                                                                self.game_controls.curr_screenshots_paths.copy(),
                                                                                curr_reward,
                                                                                end_game)
                curr_state_imgs_paths = self.game_controls.curr_screenshots_paths.copy()

                # Train the model on random samples from both experience buffers.
                buffer_1_samples, buffer_2_samples = self.sample_mini_batches_from_experience_buffers()
                loss = self.train_step(buffer_1_samples + buffer_2_samples)
                loss_log.update_state(loss)

            if curr_epoch % 10 == 0 and curr_epoch > 0:
                self.save_buffer_experiences_as_npy_files_periodically()
                self.checkpoint_manager.save()
                print('In epoch ' + str(curr_epoch) + ' the loss is ' + str(loss_log.result()))
                loss_log.reset_states()

            if curr_epoch == config.NUM_EPOCHS - 1:
                self.checkpoint.model.save_weights(config.FINAL_WEIGHTS_DIR + 'snake_game_weights.h5')
                self.save_buffer_experiences_as_npy_files_periodically()
                self.game_controls.end_the_game()
            else:
                # The epoch counter is only advanced up to the final epoch index.
                self.checkpoint.curr_epoch.assign_add(1)
                
                
                
    

In [None]:
'''Create the trainer and run the observation (random exploration) period.
Re-running this cell resumes from the existing screenshots and saved npy buffers.'''

train_snake_game = Training()
# Random-action exploration until observation_period_threshold steps are recorded; the starting step
# is taken from the screenshot count / saved npy buffers (see Training.observation_period).
train_snake_game.observation_period()

num of steps completed are 1
npy files saved and 5000 steps completed
npy files saved and 10000 steps completed


In [None]:
# Train for the remaining epochs; train() restores the latest checkpoint first.
train_snake_game.train()

In [None]:
import numpy as np
# Inspect the saved experience buffers. The directory comes from config so it stays in sync with
# check_the_num_of_steps_completed_in_observation_period instead of repeating an absolute path.
# `file` is reused by the next cell.
file = np.load(os.path.join(config.BUFFERS_NPY_DIR, 'buffer_1_npy.npy'), allow_pickle = True)
filee = np.load(os.path.join(config.BUFFERS_NPY_DIR, 'buffer_2_npy.npy'), allow_pickle = True)
print(len(file), len(filee))

In [None]:
# Preview the first experiences of buffer 1 (`file` from the previous cell); printing the whole
# buffer, which holds thousands of entries, floods the notebook output.
NUM_EXPERIENCES_TO_SHOW = 10  # set to None to print every experience
for experience in file[:NUM_EXPERIENCES_TO_SHOW]:
    print(experience)
    print(' ')

In [None]:
# Scratch check: number of the most recent screenshot in config.IMG_DIR.
# Only names starting with an integer are considered, so '.DS_Store' or other stray files are skipped.
screenshot_numbers = []
for name in os.listdir(config.IMG_DIR):
    leading_digits = re.match(r'\d+', name)
    if leading_digits:
        screenshot_numbers.append(int(leading_digits.group()))
if screenshot_numbers:
    num = max(screenshot_numbers)
    print(num)



In [None]:
import random
random.randrange(0, 4)  # same draw as random.randint(0, 3): one of the four action indices