In [1]:
import pygame
import time
import random
import torch as T
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gym
import keyboard as k

pygame 2.0.0 (SDL 2.0.12, python 3.6.10)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class Network(nn.Module):
    """Fully connected network with two hidden ReLU layers and its own Adam optimizer.

    Args:
        lr: Learning rate handed to the Adam optimizer.
        input_dims: Number of features in one observation.
        actions: Size of the output layer (number of outputs produced by ``forward``).
        l1_dim: Width of the first hidden layer.
        l2_dim: Width of the second hidden layer.

    Attributes:
        optimizer: Adam optimizer over all parameters of this network.
        device: ``cuda:0`` when CUDA is available, otherwise ``cpu:0``; the
            module is moved there on construction.
    """

    def __init__(self,lr,input_dims,actions,l1_dim,l2_dim):
        super().__init__()
        self.lr=lr
        self.input_dims=input_dims
        self.actions=actions
        self.l1_dim=l1_dim
        self.l2_dim=l2_dim
        self.l1=nn.Linear(self.input_dims,self.l1_dim)
        self.l2=nn.Linear(self.l1_dim,self.l2_dim)
        self.l3=nn.Linear(self.l2_dim,self.actions)
        self.optimizer=optim.Adam(self.parameters(),lr=self.lr)
        self.device=T.device('cuda:0' if T.cuda.is_available() else 'cpu:0')
        self.to(self.device)

    def forward(self,observation):
        """Run the network on ``observation`` and return the raw last-layer output.

        Args:
            observation: A list, NumPy array or tensor whose last dimension
                is ``input_dims``.

        Returns:
            A float32 tensor on ``self.device``; no activation is applied to
            the final layer.
        """
        # as_tensor (instead of the legacy T.Tensor constructor) converts any
        # numeric dtype to float32, accepts tensors already living on another
        # device, and never interprets a bare integer as a size.
        state=T.as_tensor(observation,dtype=T.float32,device=self.device)
        x=F.relu(self.l1(state))
        x=F.relu(self.l2(x))
        x=self.l3(x)

        return x

In [3]:
class Agent():
    """One-step actor-critic agent built from two ``Network`` instances.

    Args:
        alpha: Learning rate of the actor (policy) network.
        beta: Learning rate of the critic (state-value) network.
        input_dims: Number of features in one observation.
        actions: Number of discrete actions the actor chooses between.
        l1_dim: Width of the first hidden layer of both networks.
        l2_dim: Width of the second hidden layer of both networks.
        gamma: Discount factor applied to the next state's value.
    """

    def __init__(self,alpha,beta,input_dims,actions,l1_dim,l2_dim,gamma):
        self.gamma=gamma
        # Log-probability of the most recently sampled action; set by act()
        # and consumed by the next learn() call.
        self.log_probs=None
        self.actor=Network(alpha,input_dims,actions,l1_dim,l2_dim)
        self.critic=Network(beta,input_dims,1,l1_dim,l2_dim)

    def act(self,observation):
        """Sample an action from the actor's policy for ``observation``.

        Stores the log-probability of the sampled action in ``self.log_probs``
        for the following ``learn`` call.

        Returns:
            The sampled action index as a Python int.
        """
        logits=self.actor.forward(observation)
        # Explicit dim: calling softmax without it is deprecated and warns.
        probs=F.softmax(logits,dim=-1)
        action_probs=T.distributions.Categorical(probs)
        action=action_probs.sample()
        self.log_probs=action_probs.log_prob(action)
        return action.item()

    def learn(self,state,reward,new_state,done):
        """Apply one actor-critic update for a single transition.

        ``act`` must have been called for ``state`` beforehand so that
        ``self.log_probs`` refers to the action that was taken.

        Args:
            state: Observation the action was chosen in.
            reward: Scalar reward received for the transition.
            new_state: Observation reached after the action.
            done: Truthy when ``new_state`` is terminal; the bootstrap term is
                then dropped.
        """
        self.actor.optimizer.zero_grad()
        self.critic.optimizer.zero_grad()

        q_val_state=self.critic.forward(state)
        q_val_new_state=self.critic.forward(new_state)

        # TD error: r + gamma * V(s') * (1 - done) - V(s)
        delta=((reward+self.gamma*q_val_new_state*(1-int(done)))-q_val_state)

        # The TD error is only a weighting factor for the policy gradient, so
        # it is detached here; otherwise the actor loss would also push
        # gradients into the critic's weights.
        actor_loss=-self.log_probs*delta.detach()
        critic_loss=delta**2

        (actor_loss+critic_loss).sum().backward()

        self.actor.optimizer.step()
        self.critic.optimizer.step()

In [4]:
# pygame.init()
 
# white = (255, 255, 255)
# yellow = (255, 255, 102)
# black = (0, 0, 0)
# red = (213, 50, 80)
# green = (0, 255, 0)
# blue = (50, 153, 213)
 
# dis_width = 600
# dis_height = 400
 
# dis = pygame.display.set_mode((dis_width, dis_height))
# pygame.display.set_caption('Snake Game by Edureka')
 
# clock = pygame.time.Clock()
 
# snake_block = 10
# snake_speed = 15
 
# font_style = pygame.font.SysFont("bahnschrift", 25)
# score_font = pygame.font.SysFont("comicsansms", 35)
 
 

def Your_score(score):
    """Render ``score`` with ``score_font`` in yellow and blit it at the display's origin."""
    label = "Your Score: " + str(score)
    rendered = score_font.render(label, True, yellow)
    dis.blit(rendered, [0, 0])
 
def dist(x1,y1,x2,y2):
    """Return the Manhattan (L1) distance between the points (x1, y1) and (x2, y2)."""
    dx=x1-x2
    dy=y1-y2
    return abs(dx)+abs(dy)
 
def our_snake(snake_block, snake_list):
    """Draw each segment in ``snake_list`` as a black square of side ``snake_block``."""
    for segment in snake_list:
        left, top = segment[0], segment[1]
        pygame.draw.rect(dis, black, [left, top, snake_block, snake_block])
 
 
def message(msg, color):
    """Render ``msg`` in ``color`` with ``font_style`` and blit it at (width/6, height/3)."""
    text_surface = font_style.render(msg, True, color)
    anchor = [dis_width / 6, dis_height / 3]
    dis.blit(text_surface, anchor)
 
 
def _random_food_position():
    """Return a random ``(x, y)`` food position rounded to a multiple of 10."""
    food_x = round(random.randrange(0, dis_width - snake_block) / 10.0) * 10.0
    food_y = round(random.randrange(0, dis_height - snake_block) / 10.0) * 10.0
    return food_x, food_y


def _encode_state(head_x, head_y, food_x, food_y):
    """Scale head and food coordinates by the window size to build the observation."""
    return [head_x / dis_width, head_y / dis_height,
            food_x / dis_width, food_y / dis_height]


def gameLoop():
    """Play one Snake episode driven by the global ``agent``, learning online.

    Relies on module-level globals set up by the driver cell: ``agent``,
    ``dis``, ``clock``, ``dis_width``, ``dis_height``, ``snake_block``,
    ``snake_speed`` and the colour tuples.

    The agent picks an action only on frames where ``canPress`` has reached 4;
    on all other frames the snake keeps its current heading. ``agent.learn`` is
    called on those same frames with the transition observed during the frame.

    Rewards:
        * -100 when the episode ends (the head was outside the window at the
          start of the frame, or the head overlaps a body segment),
        * +100 when the head lands on the food,
        * +10 when the Manhattan distance to the food shrank, -100 otherwise.

    NOTE(review): a terminal frame on which the agent did not act is never
    passed to ``agent.learn``, and the rare exploratory override below executes
    an action other than the one whose log-probability the agent stored.
    Both behaviours are inherited from the original loop.
    """
    # Action encoding: 0=up, 1=down, 2=left, 3=right, 4=keep current heading.
    n_actions = 5
    # Keep the sampled action with p=0.995; otherwise pick one of the other
    # four uniformly (1/800 each).
    explore_probs = [0.995] + [1 / 800] * (n_actions - 1)
    headings = {
        0: (0, -snake_block),
        1: (0, snake_block),
        2: (-snake_block, 0),
        3: (snake_block, 0),
    }

    canPress = 0
    x1 = dis_width / 2
    y1 = dis_height / 2
    x1_change = 0
    y1_change = 0
    snake_List = []
    Length_of_snake = 1
    done = False
    foodx, foody = _random_food_position()

    state = _encode_state(x1, y1, foodx, foody)

    while True:
        # Let pygame process its internal event queue so the OS does not flag
        # the window as unresponsive (no keyboard handling is needed here).
        pygame.event.pump()

        acting = canPress >= 4
        if acting:
            action = agent.act(state)
            alternatives = [a for a in range(n_actions) if a != action]
            action = int(np.random.choice([action] + alternatives, p=explore_probs))
            if action in headings:
                x1_change, y1_change = headings[action]

        # The boundary test uses the position *before* this frame's move, so
        # leaving the window is detected one frame after it happens.
        if x1 >= dis_width or x1 < 0 or y1 >= dis_height or y1 < 0:
            done = True
        x1 += x1_change
        y1 += y1_change
        dis.fill(blue)
        pygame.draw.rect(dis, green, [foodx, foody, snake_block, snake_block])
        snake_Head = [x1, y1]
        snake_List.append(snake_Head)
        if len(snake_List) > Length_of_snake:
            del snake_List[0]
        # Head overlapping any earlier segment ends the episode.
        if snake_Head in snake_List[:-1]:
            done = True

        our_snake(snake_block, snake_List)
        Your_score(Length_of_snake - 1)

        pygame.display.update()

        ate_food = (not done) and x1 == foodx and y1 == foody
        if ate_food:
            foodx, foody = _random_food_position()
            Length_of_snake += 1

        new_state = _encode_state(x1, y1, foodx, foody)
        clock.tick(snake_speed)

        if done:
            reward = -100
        elif ate_food:
            # Must not fall through to the distance shaping below: the food
            # has just been respawned, so the "distance" would be measured
            # against a different target and would overwrite this reward.
            reward = 100
        else:
            prevDist = dist(state[0], state[1], state[2], state[3])
            newDist = dist(new_state[0], new_state[1], new_state[2], new_state[3])
            reward = 10 if newDist < prevDist else -100

        if acting:
            agent.learn(state, reward, new_state, done)
            canPress = 0
        state = new_state
        if done:
            break
        canPress += 1
 

In [None]:
# Agent arguments: actor lr, critic lr, 4 state features, 5 actions,
# two hidden layers of 32 units, discount factor 0.995.
agent=Agent(0.00025,0.00025,4,5,32,32,0.995)
pygame.init()

# RGB colours used by the drawing helpers.
white = (255, 255, 255)
yellow = (255, 255, 102)
black = (0, 0, 0)
red = (213, 50, 80)
green = (0, 255, 0)
blue = (50, 153, 213)

# Window size in pixels; gameLoop() also uses these as the playfield bounds.
dis_width = 600
dis_height = 400

dis = pygame.display.set_mode((dis_width, dis_height))
pygame.display.set_caption('Snake Game')

clock = pygame.time.Clock()

snake_block = 10  # side length of one snake segment / one step, in pixels
snake_speed = 15  # frame-rate cap passed to clock.tick()

font_style = pygame.font.SysFont("bahnschrift", 25)
score_font = pygame.font.SysFont("comicsansms", 35)

# Number of episodes to play.
N_EPISODES = 100

if __name__ == '__main__':
    try:
        for i in range(N_EPISODES):
            gameLoop()
    finally:
        # Always release the display, even if an episode raises.
        pygame.quit()
