In [1]:
from game import SnakeGame
from rewards import manhattan_reward, naive_reward, advanced_naive_reward, euclidean_reward
from qlearning import QLearning
from deepQL import DeepQLearning
import random
from tqdm import tqdm
import pygame

pygame 2.5.2 (SDL 2.28.3, Python 3.12.1)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Width and height passed to SnakeGame(W, H) by benchmark() and play_snake().
W, H = 200, 200

def benchmark(model, n_games=1000):
    """Play ``n_games`` games with ``model`` and print the resulting scores.

    The games are stepped without calling the game's draw routine.

    Args:
        model: Object exposing ``get_movement(state)``; it is asked for an
            action once per step.
        n_games: Number of games to play. Defaults to 1000, the value that
            was previously hard-coded.

    Returns:
        None. The list of final scores and their mean are printed.
    """
    game = SnakeGame(W, H)
    scores = []
    for _ in tqdm(range(n_games)):
        game.reset()
        game_over = False
        while not game_over:
            state = game.get_state()
            action = model.get_movement(state)
            # play_step returns a 3-tuple; its first element is unused here.
            _, score, game_over = game.play_step(action)
        # The loop exits only once play_step reported game over, so `score`
        # holds the value returned by the final step.
        scores.append(score)
    print('Scores:', scores)
    # Guard the mean so n_games=0 does not raise ZeroDivisionError.
    if scores:
        print('Mean Score:', sum(scores) / len(scores))







def play_snake(model, n_games=10, speed=20):
    """Play games with ``model`` while drawing each step with pygame.

    Args:
        model: Object exposing ``get_movement(state)``; it is asked for an
            action once per step.
        n_games: Number of games played back to back. Defaults to 10, the
            value that was previously hard-coded.
        speed: Value passed to ``clock.tick`` after every drawn step.
            Defaults to 20, the value that was previously hard-coded.

    Returns:
        None. The final score of every game is printed.
    """
    pygame.init()
    try:
        game = SnakeGame(W, H)
        clock = pygame.time.Clock()

        # NOTE(review): no pygame events are processed in this loop;
        # presumably play_step() or pygame_draw() services the event queue.
        # Confirm, otherwise the window may stop responding.
        for _ in range(n_games):
            game.reset()
            game_over = False
            while not game_over:
                state = game.get_state()
                action = model.get_movement(state)

                _, score, game_over = game.play_step(action)
                # The final (game-over) step is drawn and ticked as well.
                game.pygame_draw()
                clock.tick(speed)

            print('Game Over => Score:', score)
    finally:
        # Always shut pygame down, even if the model or the game raises,
        # so an error does not leave pygame initialised.
        pygame.quit()

In [3]:
# Game instance handed to both learners. It is sized with the W/H constants
# that benchmark() and play_snake() also use, so the board the models are
# built on cannot drift from the one they are evaluated on.
game = SnakeGame(W, H)
Qmodel = QLearning(game)
DQmodel = DeepQLearning(game)


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Reward function passed to both train() calls. The other candidates
# imported from `rewards` are manhattan_reward, naive_reward and
# euclidean_reward.
reward = advanced_naive_reward

In [12]:
# Train the tabular Q-learning model with the selected reward function.
# NOTE(review): the recorded progress bar ran 50000 iterations, matching the
# first argument; the meaning of the second argument (200) is not visible
# here. Confirm against QLearning.train.
Qmodel.train(50000, 200, reward)

100%|██████████| 50000/50000 [00:19<00:00, 2511.04it/s]

31





In [52]:
# Train the deep Q-learning model with the selected reward function.
# NOTE(review): only 10 iterations are requested here versus 50000 for
# Qmodel; presumably a smoke-test value. Confirm before comparing the two
# models. The meaning of the second argument (200) is not visible here.
DQmodel.train(10, 200, reward)

100%|██████████| 10/10 [00:00<00:00, 12.50it/s]

10





DQN(
  (fc1): Linear(in_features=6, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=3, bias=True)
)

In [35]:
# Save the Q-learning model to disk.
# NOTE(review): the path says "naive" while the reward cell selects
# advanced_naive_reward. Confirm the file name matches the reward that was
# actually used for training.
filename = 'QL/naive_policy.txt'
Qmodel.save_model(filename)

In [None]:
# Load a previously saved Q-learning policy from disk.
# NOTE(review): this cell has no execution count, so it was not run in the
# recorded session. Presumably it replaces whatever Qmodel.train() learned;
# run either the training cell or this one, not both. Confirm against
# QLearning.load_model.
filename = 'QL/b_manhattan_policy.txt'
Qmodel.load_model(filename)

In [13]:
# Save the deep Q-learning model to disk.
# NOTE(review): the path says "manhattan" while the reward cell selects
# advanced_naive_reward. Confirm the file name matches the reward that was
# actually used for training.
filename = 'DQL/manhattan_policy.txt'
DQmodel.save_model(filename)

In [4]:
# Load a previously saved deep Q-learning policy from disk.
# NOTE(review): the execution counts in this notebook are out of order, so
# it is unclear whether the benchmarked DQmodel came from this file or from
# the training cell. Confirm by running the cells top to bottom.
filename = 'DQL/advance_naive_policy.txt'
DQmodel.load_model(filename)

In [60]:
# Evaluate both models with benchmark(): each call prints the per-game scores
# followed by their mean.
benchmark(Qmodel)
benchmark(DQmodel)


100%|██████████| 1000/1000 [00:01<00:00, 609.21it/s]


Scores: [48, 29, 29, 29, 25, 37, 38, 19, 30, 34, 27, 35, 24, 24, 40, 36, 29, 20, 28, 33, 21, 27, 28, 32, 32, 37, 28, 37, 38, 34, 24, 23, 25, 23, 31, 33, 40, 24, 25, 33, 35, 27, 29, 21, 40, 26, 29, 35, 21, 42, 30, 19, 34, 33, 34, 31, 35, 31, 29, 20, 21, 25, 28, 26, 25, 22, 45, 35, 24, 13, 27, 33, 37, 27, 25, 22, 31, 27, 39, 36, 29, 30, 33, 24, 40, 29, 20, 27, 29, 31, 29, 43, 30, 27, 28, 21, 34, 37, 31, 26, 25, 22, 29, 25, 40, 22, 20, 25, 20, 29, 43, 28, 38, 28, 21, 24, 42, 22, 36, 28, 22, 29, 24, 21, 30, 32, 22, 35, 31, 30, 26, 19, 28, 32, 30, 21, 36, 21, 25, 22, 27, 26, 21, 18, 23, 28, 29, 28, 15, 28, 35, 23, 23, 31, 24, 36, 38, 27, 22, 35, 24, 21, 26, 33, 27, 11, 24, 37, 30, 24, 34, 23, 28, 26, 31, 38, 39, 35, 22, 34, 24, 24, 28, 23, 25, 25, 23, 37, 13, 28, 37, 35, 30, 20, 34, 27, 27, 17, 32, 25, 27, 35, 36, 38, 30, 32, 23, 29, 28, 40, 28, 17, 20, 35, 20, 29, 38, 27, 31, 28, 32, 29, 25, 21, 31, 28, 35, 37, 41, 31, 28, 25, 21, 33, 24, 31, 29, 25, 27, 33, 19, 29, 32, 24, 28, 25, 27, 37,

100%|██████████| 1000/1000 [00:13<00:00, 76.68it/s]

Scores: [32, 26, 30, 21, 19, 33, 32, 31, 12, 19, 21, 30, 21, 39, 8, 35, 31, 20, 24, 23, 34, 30, 28, 36, 19, 40, 24, 32, 20, 33, 21, 27, 23, 29, 25, 23, 34, 47, 29, 27, 7, 26, 35, 18, 24, 22, 26, 29, 36, 21, 26, 30, 22, 33, 26, 17, 36, 33, 28, 34, 25, 18, 37, 29, 27, 30, 42, 17, 36, 22, 20, 37, 34, 28, 24, 23, 23, 28, 34, 24, 28, 31, 27, 24, 20, 21, 43, 11, 18, 47, 23, 36, 13, 26, 28, 21, 41, 27, 23, 30, 34, 14, 28, 19, 23, 32, 50, 19, 26, 16, 27, 24, 33, 25, 17, 13, 13, 17, 26, 20, 39, 25, 24, 20, 25, 33, 17, 20, 27, 41, 34, 32, 15, 37, 20, 13, 27, 40, 37, 24, 28, 31, 30, 40, 16, 27, 33, 32, 22, 20, 26, 23, 30, 34, 27, 40, 26, 14, 26, 32, 19, 30, 34, 17, 30, 34, 24, 33, 17, 31, 20, 28, 9, 21, 30, 21, 20, 27, 36, 25, 37, 23, 19, 33, 32, 22, 26, 25, 19, 37, 18, 46, 16, 20, 31, 26, 17, 24, 29, 31, 20, 23, 30, 20, 20, 36, 40, 34, 22, 19, 24, 34, 37, 11, 28, 32, 33, 36, 34, 33, 37, 29, 19, 10, 13, 20, 16, 18, 23, 29, 30, 23, 27, 36, 42, 26, 25, 20, 38, 28, 20, 11, 26, 21, 26, 22, 33, 15, 13




In [5]:
# Play the deep Q-learning model with pygame drawing each step; one
# "Game Over" line is printed per game.
play_snake(DQmodel)

Game Over => Score: 24
Game Over => Score: 24
Game Over => Score: 40
Game Over => Score: 32
Game Over => Score: 17
Game Over => Score: 34
Game Over => Score: 31
Game Over => Score: 12
Game Over => Score: 17
Game Over => Score: 14


: 