In [1]:
pip install pygame

Note: you may need to restart the kernel to use updated packages.


In [8]:
import pygame

pygame 2.1.2 (SDL 2.0.18, Python 3.9.13)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [9]:
import torch 
import math
import random
import numpy as np
import matplotlib.pyplot as plt


In [10]:
from dataclasses import dataclass

In [11]:
from enum import Enum

class Direction(Enum):
    RIGHT = 1
    LEFT = 2
    UP = 3
    DOWN = 4

In [12]:
@dataclass
class Point:
    
    x: int
    y: int

In [7]:
#from collections import namedtuple

#Point = namedtuple('Point', 'x, y')

In [19]:



class SnakeGameAI:
    
    
    
    pygame.init()

    #Point = namedtuple('Point', 'x, y')
    BLOCK = 20
    SPEED = 4
  
    WHITE = (255, 255, 255)
    RED = (200,0,0)
    BLUE1 = (0, 0, 255)
    BLUE2 = (0, 100, 255)
    BLACK = (0,0,0)
    
    def __init__(self,w=640,h=480):
        self.w = w
        self.h = h
        
        self.display = pygame.display.set_mode((self.w,self.h))
        
        self.clock = pygame.time.Clock()
        self.reset()
    
    def reset(self):
        
        self.direction = Direction.RIGHT
        
        self.head = Point(self.w/2,self.h/2)
        
        self.snake = [self.head,Point(self.head.x - BLOCK,self.head.y),
                     Point(self.head.x - 2*BLOCK,self.head.y)]

        self.score = 0
        
        self.food = None
        
        self._place_food()
        
        self.frame_iteration = 0
        
    def _place_food(self):
        x = random.randint(0,self.w//BLOCK)*BLOCK
        y = random.randint(0,self.h//BLOCK)*BLOCK
        
        self.food = Point(x,y)
        
        if self.food in self.snake:
            self._place_food()
        
        
        
        
    def play_step(self, action):
        
        self.frame_iteration += 1
        
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
            
        #    if event.type == pygame.KEYDOWN:
         #       if event.key == pygame.K_LEFT:
          #          self.direction = Direction.LEFT
           #     if event.key == pygame.K_RIGHT:
            #        self.direction = Direction.RIGHT
             #   if event.key == pygame.K_UP:
              #      self.direction = Direction.UP
               # if event.key == pygame.K_DOWN:
                #    self.direction = Direction.DOWN
        
        
        self._move(action) ####only updates head
        self.snake.insert(0,self.head)
        
        ###after move is game over
                
        reward = 0

        game_over = False       
        
        if self._is_collision() or self.frame_iteration > 100*len(self.snake):
            game_over =  True
            reward-=10
            return reward,game_over, self.score
        
        if self.head == self.food:
            self.score += 1
            reward = 10
            self._place_food()
        
        else:
            self.snake.pop()
            
        
        self._update_ui()
        self.clock.tick(SPEED)
        
        
        return reward, game_over, self.score
        
        
        
    
    def _move(self,action):
        
        x = self.head.x
        y = self.head.y
        
        new_dir = None
        
        #######action format = [straight, right, left]  now if i will bw going right a right will be down
        clock_wise = [Direction.RIGHT,Direction.DOWN,Direction.LEFT,Direction.UP]
        idx = clock_wise.index(self.direction)
        
        if np.array_equal([action],[1,0,0]):
            new_dir = clock_wise[idx]
            
        if np.array_equal([action],[0,1,0]):
            new_dir = clock_wise[(idx+1)%4]
        
        if np.array_equal([action],[0,0,1]):
            new_dir = clock_wise[(idx-1)%4]
        
        self.direction = new_dir
        
        if self.direction == Direction.RIGHT:
            x += BLOCK
        if self.direction == Direction.LEFT:
            x -= BLOCK
        if self.direction == Direction.UP:
            y+= BLOCK
        if self.direction == Direction.DOWN:
            y-=BLOCK
        
        self.head  = Point(x,y)
        
    def _is_collision(self,pt=None):
        if pt is None:
            pt = self.head
        
        if pt.x > self.w - BLOCK or pt.x<0 or pt.y>self.h-BLOCK or pt.y<0:
            return True
        
        if self.head in self.snake[1:]:
            return True
        
        return False
            

    def _update_ui(self):
        self.display.fill(BLACK)
        
        for pt in self.snake:
            x = pt.x
            y = pt.y
            
            pygame.draw.rect(self.display, BLUE1, pygame.Rect(x,y, BLOCK,BLOCK))
            pygame.draw.rect(self.display, BLUE1, pygame.Rect(x+4,y+4, 12,12))
            
        pygame.draw.rect(self.display, RED, pygame.Rect(self.food.x,self.food.y, BLOCK,BLOCK))
        
        #text = font.render("score:" + str(self.score),True, WHITE)
        
        #self.display.blit(text, (0,0))
        
        pygame.display.flip()


In [2]:
from collections import deque

In [24]:
MAX_MEMORY = 100000
BATCH_SIZE = 1000
LR = 1e-6


class Agent:
    
    def __init__(self):
        self.n_games = 0
        self.epsilon = 0
        self.gamma = 0.9
        self.memory = deque(maxlen=MAX_MEMORY)
        self.model = Linear_QNet(11, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)
       
        
        
    
    def get_state(self,game):
        head = game.snake[0]
        
        point_l = Point(head.x-20, head.y)
        point_r = Point(head.x+20, head.y)
        point_u = Point(head.x, head.y-20)
        point_d = Point(head.x, head.y+20)
        
        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN
        
        state = [
            
            (dir_r and game._is_collision(point_r)) or 
            (dir_l and game._is_collision(point_l)) or
            (dir_u and game._is_collision(point_u)) or
            (dir_d and game._is_collision(point_d)),
            
            (dir_u and game._is_collision(point_r)) or 
            (dir_d and game._is_collision(point_l)) or
            (dir_l and game._is_collision(point_u)) or
            (dir_r and game._is_collision(point_d)),
            
            (dir_d and game._is_collision(point_r)) or 
            (dir_u and game._is_collision(point_l)) or
            (dir_r and game._is_collision(point_u)) or
            (dir_l and game._is_collision(point_d)),
            
            dir_l,
            dir_r,
            dir_u,
            dir_d,
            
            
            game.food.x < game.head.x,
            game.food.x> game.head.x,
            game.food.y < game.head.y,
            game.food.y < game.head.y
            
        ]
        return state
    
    def remember(self,state,action,reward,next_state,done):
        self.memory.append((state,action,reward,next_state,done))
    
    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample  = random.sample(self.memory,BATCH_SIZE)
        else:
            mini_sample = self.memory
            
        
        states,actions,rewards,next_states,dones = zip(*mini_sample)    
        self.trainer.train_step(states,actions,rewards,next_states,dones)
    
    def train_short_memory(self,state,action,reward,next_state,done):
        self.trainer.train_step(state,action,reward,next_state,done)
    
    def get_action(self,state):
        self.epsilon = 80 - self.n_games
        final_move = [0,0,0]
        
        if random.randint(0,200) < self.epsilon:
            move = random.randint(0,2)
            
            final_move[move] = 1
        else:
            state0 = torch.tensor(state,dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1
        return final_move
    
def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    game = SnakeGameAI()
    i = 0
    while True:
        if i == 100:
            break
        i+=1
        
        state_old = agent.get_state(game)
        
        final_move = agent.get_action(state_old)
        
        reward, done, score = game.play_step(final_move)
        
        state_new = agent.get_state(game)
        
        #train short
        
        agent.train_short_memory(state_old, final_move, reward, state_new, done)
        
        agent.remember(state_old, final_move, reward, state_new, done)
        
        if done:
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()
            
            if score > record:
                record = score
                
                agent.model.save()
            print("Game",agent.n_games,'Score',score, 'Record',record)
        

if __name__ == '__main__':
    train()

Game 1 Score 0 Record 0
Game 2 Score 0 Record 0
Game 3 Score 0 Record 0
Game 4 Score 0 Record 0
Game 5 Score 0 Record 0
Game 6 Score 0 Record 0
Game 7 Score 0 Record 0
Game 8 Score 0 Record 0
Game 9 Score 0 Record 0
Game 10 Score 0 Record 0
Game 11 Score 0 Record 0
Game 12 Score 0 Record 0
Game 13 Score 0 Record 0
Game 14 Score 0 Record 0
Game 15 Score 0 Record 0
Game 16 Score 0 Record 0
Game 17 Score 0 Record 0
Game 18 Score 0 Record 0
Game 19 Score 0 Record 0
Game 20 Score 0 Record 0
Game 21 Score 0 Record 0
Game 22 Score 0 Record 0
Game 23 Score 0 Record 0
Game 24 Score 0 Record 0
Game 25 Score 0 Record 0
Game 26 Score 0 Record 0
Game 27 Score 0 Record 0
Game 28 Score 0 Record 0
Game 29 Score 0 Record 0
Game 30 Score 0 Record 0
Game 31 Score 0 Record 0
Game 32 Score 0 Record 0
Game 33 Score 0 Record 0
Game 34 Score 0 Record 0
Game 35 Score 0 Record 0
Game 36 Score 0 Record 0
Game 37 Score 0 Record 0
Game 38 Score 0 Record 0
Game 39 Score 0 Record 0
Game 40 Score 0 Record 0
Game 41 S

In [4]:
import torch.nn as nn
import torch.nn.functional as F
import os

In [22]:
class Linear_QNet(nn.Module):
    def __init__(self,input_size,hidden_size,output_size):
        super().__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)
        
    def forward(self,x):
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        return x
    
    def save(self, file_name = 'qnet.pth'):
        model_folder_path = './qnet'
        if not os.path.exists(model_folder_path):
            os.makedirs(model_folder_path)
        file_name = os.path.join(model_folder_path, file_name)
        torch.save(self.state_dict(), file_name)
        


In [21]:
class QTrainer:
    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.model = model
        self.gamma = gamma
        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()
        
    def train_step(self, state, action, reward,next_state,done):
        state = torch.tensor(state, dtype=torch.float)
        next_state = torch.tensor(next_state, dtype=torch.float)
        action = torch.tensor(action, dtype=torch.float)
        reward = torch.tensor(reward, dtype=torch.float32)
        
        if len(state.shape) == 1:
            state = torch.unsqueeze(state,0)
            next_state = torch.unsqueeze(next_state,0)
            action = torch.unsqueeze(action,0)
            reward = torch.unsqueeze(reward,0)
            done = (done,)
            
        pred = self.model(state)
        target = pred.clone()
        for idx in range(len(done)):
            Q_new = reward[idx]
            if not done[idx]:
                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))

            target[idx][torch.argmax(action[idx]).item()] = Q_new
    
        # 2: Q_new = r + y * max(next_predicted Q value) -> only do this if not done
        # pred.clone()
        # preds[argmax(action)] = Q_new
        self.optimizer.zero_grad()
        loss = self.criterion(target, pred)
        loss.backward()

        self.optimizer.step()
        
        
        
        