In [40]:
import os
import random

import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [41]:
class Linear_QNet(nn.Module):
    """Two-layer MLP that maps a state vector to one score per action.

    Architecture: Linear(input_size -> hidden_size) -> ReLU ->
    Linear(hidden_size -> output_size). The game loop picks the argmax of the
    output as the action (presumably Q-values, given the class name).

    The submodule names ``linear1`` / ``linear2`` must not change: they form
    the keys of the saved ``state_dict`` (``linear1.weight``, ...).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Linear_QNet, self).__init__()
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the raw (un-activated) output scores for input ``x``."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

In [42]:
MODEL_PATH = os.path.join('models', 'DQL_model_100.pth')

# map_location='cpu': the game loop below feeds CPU tensors to the model, and a
# checkpoint saved from a CUDA model cannot be deserialized on a CPU-only
# machine unless it is remapped here.
# weights_only=True: the file is only used as a state_dict, so refuse to
# unpickle arbitrary Python objects from it (requires torch >= 1.13).
state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=True)

# 5 inputs = length of the vector built by get_state; 3 outputs = one per
# action slot (left / stay / right) used by perform_action.
model = Linear_QNet(5, 128, 3)
model.load_state_dict(state_dict)
model.eval()

Linear_QNet(
  (linear1): Linear(in_features=5, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=3, bias=True)
)

In [60]:
pygame.init()

# --- Window ---------------------------------------------------------------
width, height = 500, 600
window = pygame.display.set_mode((width, height))
pygame.display.set_caption("Catch The Egg!")

# --- Colours (RGB) ----------------------------------------------------------
yellow = (225, 225, 0)  # NOTE(review): pure yellow is (255, 255, 0); 225 may be a typo — confirm
white = (255, 255, 255)
black = (0, 0, 0)
green = (0, 255, 0)
red = (255, 0, 0)
blue = (0, 0, 255)

# --- Basket: horizontally centred, 10 px above the bottom edge --------------
basket_width, basket_height = 30, 50
basket_x = width // 2 - basket_width // 2
basket_y = height - basket_height - 10
basket_speed = 10  # pixels moved per perform_action call
basket = pygame.Rect(basket_x, basket_y, basket_width, basket_height)

# --- Egg --------------------------------------------------------------------
egg_width, egg_height = 23, 35
egg_speed = 5  # pixels fallen per perform_action call

def create_egg():
    """Spawn a new egg at the very top of the window.

    The left edge is drawn uniformly from [10, width - egg_width - 10]
    (inclusive), so the egg never starts within 10 px of either side wall.

    Returns:
        pygame.Rect: the egg's bounding box.
    """
    left = random.randint(10, width - egg_width - 10)
    return pygame.Rect(left, 0, egg_width, egg_height)

# Mutable game state, read and updated through module globals by the game
# functions: the falling eggs (get_state only looks at eggs[0]), the running
# score, and the previous one-hot action (all zeros before the first move).
eggs, score, action = [create_egg()], 0, [0, 0, 0]

# Fonts: small one for the on-screen score, large one for the game-over banner.
font = pygame.font.Font(None, 36)
over_font = pygame.font.Font(None, 72)

def get_state(action):
    """Build the 5-element integer observation fed to the network.

    Layout: ``[action[0], action[1], action[2], egg_at_or_left, egg_right]``
    where the last two flags compare the left edge of ``eggs[0]`` with the
    left edge of ``basket`` (they are always complementary).

    Args:
        action: the previous one-hot action (at least 3 entries).

    Returns:
        numpy.ndarray of dtype int and length 5.
    """
    # eggs and basket are only read here, so no `global` declaration is needed.
    egg = eggs[0]
    egg_at_or_left = egg.left <= basket.left
    features = [action[i] for i in range(3)]
    features += [egg_at_or_left, not egg_at_or_left]
    return np.array(features, dtype=int)

def perform_action(action):
    """Advance the game by one frame.

    Moves the basket according to ``action``, drops every egg by ``egg_speed``
    pixels, respawns any egg the basket catches and increments ``score``.

    Args:
        action: sequence of three 0/1 flags. Index 0 moves the basket left by
            ``basket_speed``, index 2 moves it right; otherwise it stays put.

    Returns:
        tuple: ``(score, game_over)``. ``game_over`` is True as soon as an egg
        has fallen below the bottom of the window.
    """
    # Only `score` is rebound; basket and eggs are mutated in place and the
    # speeds are only read, so they need no `global` declaration.
    global score

    # Clamp to the window so a full-speed step can never push the basket
    # partly off-screen (checking the edge *before* moving let it overshoot
    # by up to basket_speed - 1 pixels).
    if action[0] == 1:
        basket.x = max(0, basket.x - basket_speed)
    elif action[2] == 1:
        basket.x = min(width - basket.width, basket.x + basket_speed)

    for i, egg in enumerate(eggs):
        egg.y += egg_speed
        if egg.y > height:
            return score, True
        if basket.colliderect(egg):
            # Replace the egg that was actually caught, not always eggs[0].
            eggs[i] = create_egg()
            score += 1

    return score, False

clock = pygame.time.Clock()
running = True
try:
    while running:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                running = False

        # The network's input includes the action taken on the previous frame.
        state = torch.tensor(get_state(action), dtype=torch.float)
        # Pure inference: don't build an autograd graph every frame.
        with torch.no_grad():
            pred = model(state)
        move_idx = torch.argmax(pred).item()
        action = [0, 0, 0]
        action[move_idx] = 1

        score, done = perform_action(action)
        if done:
            running = False

        # Draw the current frame.
        window.fill(black)
        pygame.draw.rect(window, yellow, basket)
        for egg in eggs:
            pygame.draw.ellipse(window, white, egg)
        score_text = font.render(f'Score: {score}', True, green)
        window.blit(score_text, (width - 150, 10))
        pygame.display.flip()
        clock.tick(30)  # cap at 30 frames per second

    # Game-over screen: two centred lines, shown for 2 seconds.
    window.fill(black)
    game_over_text = over_font.render('Game Over', True, red)
    final_score_text = over_font.render(f'Final Score: {score}', True, blue)
    window.blit(game_over_text, (width // 2 - game_over_text.get_width() // 2,
                                 height // 2 - game_over_text.get_height() // 2 - 50))
    window.blit(final_score_text, (width // 2 - final_score_text.get_width() // 2,
                                   height // 2 - final_score_text.get_height() // 2))
    pygame.display.flip()
    pygame.time.wait(2000)
finally:
    # Release the display even if the loop raised; otherwise the pygame
    # window may be left open inside the notebook kernel.
    pygame.quit()