In [1]:
import pygame
import sys
import math
import random
import copy
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from gymnasium.envs.registration import register
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
import math
import random

pygame 2.5.2 (SDL 2.28.3, Python 3.11.4)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
SCREEN_WIDTH = 640
SCREEN_HEIGHT = 480
TILESIZE = 32
FPS = 60
DEPTH = 5

PLAYER_SPEED = 2
ENEMY_SPEED = 2
COOLDOWN = 1

PLAYER_LAYER = 3
ENEMY_LAYER = 2
WALL_LAYER = 1
GROUND_LAYER = 0

RED = (255, 0, 0)
BLACK = (0 ,0, 0)
BLUE = (0, 0, 255)
WHITE = (255,255,255)

tilemap = [
    'WWWWWWWWWWWWWWWWWWWW',
    'W.........E........W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W..................W',
    'W.........P........W',
    'WWWWWWWWWWWWWWWWWWWW',
]

In [3]:
def get_heuristic(state):
    v = (state[2] - state[5]) - (abs(math.dist([state[0], state[1]], [state[3], state[4]])))/64
    return v

def minimax(state, depth, alpha, beta, maximizingPlayer, game):
        valid_moves = game.actions()
        if depth == 0 or state[2] <= 0 or state[5] <= 0:
            return get_heuristic(state)
        if maximizingPlayer:
            value = -np.Inf
            for mv in valid_moves:
                child = game.next_state(mv, state, True)
                value = max(value, minimax(child, depth-1, alpha, beta, False, game))
                alpha = max(alpha, value)
                if value >= beta:
                    break
            return value
        else:
            value = np.Inf
            for mv in valid_moves:
                child = game.next_state(mv, state, False)
                value = min(value, minimax(child, depth-1, alpha, beta, True, game))
                beta = min(beta, value)
                if value <= alpha:
                    break
            return value

In [4]:
class Spritesheet():
    def __init__(self, file):
        self.sheet = pygame.image.load(file).convert()

    def get_sprite(self, x, y, width, height):
        sprite = pygame.Surface([width, height])
        sprite.blit(self.sheet, (0, 0), (x, y, width, height))
        sprite.set_colorkey(BLACK)
        return sprite

In [5]:
class Player(pygame.sprite.Sprite):
    def __init__(self, game, x, y, algo=None):
        self.health = 800

        self.last_att = 0
        self.attack_cool = COOLDOWN
        
        self.game = game
        self._layer = PLAYER_LAYER
        self.groups = self.game.all_sprites, self.game.players, self.game.chars
        pygame.sprite.Sprite.__init__(self, self.groups)
        
        self.x = x * TILESIZE
        self.y = y * TILESIZE
        self.x_change = 0
        self.y_change = 0
        
        self.facing = 'up'
        self.ani_loop = 1
        self.width = TILESIZE
        self.height = TILESIZE

        self.image = self.game.character_spritesheet.get_sprite(64, 32, self.width, self.height)
        
        self.rect = self.image.get_rect()
        self.rect.x = self.x
        self.rect.y = self.y

        self.algo = algo

    def update(self):
        if self.algo == None:
            self.movement()
        elif self.algo == "minimax":
            self.movement_ai()
        self.animate()
        
        self.rect.x += self.x_change
        self.collide_blocks("x")
        self.collide_enemy("x")
        self.rect.y += self.y_change
        self.collide_blocks("y")
        self.collide_enemy("y")

        self.x_change = 0
        self.y_change = 0

    def movement(self):
        keys = pygame.key.get_pressed()
        if keys[pygame.K_a]:
            self.x_change -= PLAYER_SPEED
            self.facing = 'left'
        if keys[pygame.K_d]:
            self.x_change += PLAYER_SPEED
            self.facing = 'right'
        if keys[pygame.K_s]:
            self.y_change += PLAYER_SPEED
            self.facing = 'down'
        if keys[pygame.K_w]:
            self.y_change -= PLAYER_SPEED
            self.facing = 'up'

    def movement_ai(self):
        scores = dict(zip(self.game.actions(), [minimax(self.game.next_state(mv, self.game.get_state(), True), DEPTH-1, -np.Inf, np.Inf, False, self.game) for mv in self.game.actions()]))
        max_cols = [key for key in scores.keys() if scores[key] == max(scores.values())]
        move = random.choice(max_cols)

        if move == "left":
            self.x_change -= PLAYER_SPEED
            self.facing = 'left'
        elif move == "right":
            self.x_change += PLAYER_SPEED
            self.facing = 'right'
        elif move == "down":
            self.y_change += PLAYER_SPEED
            self.facing = 'down'
        elif move == "up":
            self.y_change -= PLAYER_SPEED
            self.facing = 'up'
        elif move == "a_left":
            curr = pygame.time.get_ticks() / 1000
            if curr - self.last_att >= self.attack_cool:
                Attack(self.game, self.rect.x-TILESIZE, self.rect.y, "left", self)
                self.last_att = curr
        elif move == "a_right":
            curr = pygame.time.get_ticks() / 1000
            if curr - self.last_att >= self.attack_cool:
                Attack(self.game, self.rect.x+TILESIZE, self.rect.y, "right", self)
                self.last_att = curr
        elif move == "a_down":
            curr = pygame.time.get_ticks() / 1000
            if curr - self.last_att >= self.attack_cool:
                Attack(self.game, self.rect.x, self.rect.y+TILESIZE, "down", self)
                self.last_att = curr
        elif move == "a_up":
            curr = pygame.time.get_ticks() / 1000
            if curr - self.last_att >= self.attack_cool:
                Attack(self.game, self.rect.x, self.rect.y-TILESIZE, "up", self)
                self.last_att = curr

    def collide_enemy(self, direction):
        if direction == "x":
            hit = pygame.sprite.spritecollide(self, self.game.enemies, False)
            if hit:
                if self.x_change > 0:
                    self.rect.x = hit[0].rect.left - self.rect.width
                if self.x_change < 0:
                    self.rect.x = hit[0].rect.right

        if direction == "y":
            hit = pygame.sprite.spritecollide(self, self.game.enemies, False)
            if hit:
                if self.y_change > 0:
                    self.rect.y = hit[0].rect.top - self.rect.height
                if self.y_change < 0:
                    self.rect.y = hit[0].rect.bottom

    def collide_blocks(self, direction):
        if direction == "x":
            hit = pygame.sprite.spritecollide(self, self.game.blocks, False)
            if hit:
                if self.x_change > 0:
                    self.rect.x = hit[0].rect.left - self.rect.width
                if self.x_change < 0:
                    self.rect.x = hit[0].rect.right

        if direction == "y":
            hit = pygame.sprite.spritecollide(self, self.game.blocks, False)
            if hit:
                if self.y_change > 0:
                    self.rect.y = hit[0].rect.top - self.rect.height
                if self.y_change < 0:
                    self.rect.y = hit[0].rect.bottom

    def animate(self):
        down_ani = [
            self.game.character_spritesheet.get_sprite(32, 64, self.width, self.height),
            self.game.character_spritesheet.get_sprite(64, 64, self.width, self.height),
            self.game.character_spritesheet.get_sprite(32, 64, self.width, self.height),
            self.game.character_spritesheet.get_sprite(96, 64, self.width, self.height)
        ]

        up_ani = [
            self.game.character_spritesheet.get_sprite(64, 32, self.width, self.height),
            self.game.character_spritesheet.get_sprite(96, 32, self.width, self.height),
            self.game.character_spritesheet.get_sprite(64, 32, self.width, self.height),
            self.game.character_spritesheet.get_sprite(0, 64, self.width, self.height)
        ]

        right_ani = [
            self.game.character_spritesheet.get_sprite(32, 0, self.width, self.height),
            self.game.character_spritesheet.get_sprite(0, 0, self.width, self.height),
            self.game.character_spritesheet.get_sprite(32, 0, self.width, self.height),
            self.game.character_spritesheet.get_sprite(64, 0, self.width, self.height)
        ]

        left_ani = [
            self.game.character_spritesheet.get_sprite(0, 32, self.width, self.height),
            self.game.character_spritesheet.get_sprite(96, 0, self.width, self.height),
            self.game.character_spritesheet.get_sprite(0, 32, self.width, self.height),
            self.game.character_spritesheet.get_sprite(32, 32, self.width, self.height)
        ]

        if self.facing == "down":
            if self.y_change == 0:
                self.image = self.game.character_spritesheet.get_sprite(32, 64, self.width, self.height)
            else:
                self.image = down_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1
        if self.facing == "up":
            if self.y_change == 0:
                self.image = self.game.character_spritesheet.get_sprite(64, 32, self.width, self.height)
            else:
                self.image = up_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1
        if self.facing == "left":
            if self.x_change == 0:
                self.image = self.game.character_spritesheet.get_sprite(0, 32, self.width, self.height)
            else:
                self.image = left_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1
        if self.facing == "right":
            if self.x_change == 0:
                self.image = self.game.character_spritesheet.get_sprite(32, 0, self.width, self.height)
            else:
                self.image = right_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1

In [6]:
class Enemy(pygame.sprite.Sprite):
    def __init__(self, game, x, y):
        self.health = 800

        self.last_att = 0
        self.attack_cool = COOLDOWN
        
        self.game = game
        self._layer = ENEMY_LAYER
        self.groups = self.game.all_sprites, self.game.enemies, self.game.chars
        pygame.sprite.Sprite.__init__(self, self.groups)

        self.x = x * TILESIZE
        self.y = y * TILESIZE
        self.x_change = 0
        self.y_change = 0
        
        self.facing = 'down'
        self.ani_loop = 1
        self.width = TILESIZE
        self.height = TILESIZE

        self.image = self.game.enemy_spritesheet.get_sprite(32, 64, self.width, self.height)
        
        self.rect = self.image.get_rect()
        self.rect.x = self.x
        self.rect.y = self.y

    def update(self):
        self.animate()
        
        self.rect.x += self.x_change
        self.collide_blocks("x")
        self.collide_enemy("x")
        self.rect.y += self.y_change
        self.collide_blocks("y")
        self.collide_enemy("y")

        self.x_change = 0
        self.y_change = 0

    def movement(self, move):
        player = None
        for elem in self.game.all_sprites:
            if isinstance(elem, Player):
                player = elem
        
        if math.dist([self.rect.x,self.rect.y],[player.rect.x,player.rect.y]) <= 2*TILESIZE:
            PLAYER_SPEED = 2
            ENEMY_SPEED = 1
        else:
            PLAYER_SPEED = 3
            ENEMY_SPEED = 3

        if move == 4:
            self.y_change -= ENEMY_SPEED
            self.facing = "up"
        elif move == 5:
            self.y_change += ENEMY_SPEED
            self.facing = "down"
        elif move == 6:
            self.x_change -= ENEMY_SPEED
            self.facing = "left"
        elif move == 7:
            self.x_change += ENEMY_SPEED
            self.facing = "right"
        curr = pygame.time.get_ticks() / 1000
        if curr - self.last_att >= self.attack_cool:
            if move == 0:
                Attack(self.game, self.rect.x, self.rect.y-TILESIZE, "up", self)
            elif move == 1:
                Attack(self.game, self.rect.x, self.rect.y+TILESIZE, "down", self)
            elif move == 2:
                Attack(self.game, self.rect.x-TILESIZE, self.rect.y, "left", self)
            elif move == 3:
                Attack(self.game, self.rect.x+TILESIZE, self.rect.y, "right", self)
            self.last_att = curr

    def collide_enemy(self, direction):
        if direction == "x":
            hit = pygame.sprite.spritecollide(self, self.game.players, False)
            if hit:
                if self.x_change > 0:
                    self.rect.x = hit[0].rect.left - self.rect.width
                if self.x_change < 0:
                    self.rect.x = hit[0].rect.right

        if direction == "y":
            hit = pygame.sprite.spritecollide(self, self.game.players, False)
            if hit:
                if self.y_change > 0:
                    self.rect.y = hit[0].rect.top - self.rect.height
                if self.y_change < 0:
                    self.rect.y = hit[0].rect.bottom
    
    def collide_blocks(self, direction):
        if direction == "x":
            hit = pygame.sprite.spritecollide(self, self.game.blocks, False)
            if hit:
                if self.x_change > 0:
                    self.rect.x = hit[0].rect.left - self.rect.width
                if self.x_change < 0:
                    self.rect.x = hit[0].rect.right

        if direction == "y":
            hit = pygame.sprite.spritecollide(self, self.game.blocks, False)
            if hit:
                if self.y_change > 0:
                    self.rect.y = hit[0].rect.top - self.rect.height
                if self.y_change < 0:
                    self.rect.y = hit[0].rect.bottom

    def animate(self):
        down_ani = [
            self.game.enemy_spritesheet.get_sprite(32, 64, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(64, 64, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(32, 64, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(96, 64, self.width, self.height)
        ]

        up_ani = [
            self.game.enemy_spritesheet.get_sprite(64, 32, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(96, 32, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(64, 32, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(0, 64, self.width, self.height)
        ]

        right_ani = [
            self.game.enemy_spritesheet.get_sprite(32, 0, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(0, 0, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(32, 0, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(64, 0, self.width, self.height)
        ]

        left_ani = [
            self.game.enemy_spritesheet.get_sprite(0, 32, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(96, 0, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(0, 32, self.width, self.height),
            self.game.enemy_spritesheet.get_sprite(32, 32, self.width, self.height)
        ]

        if self.facing == "down":
            if self.y_change == 0:
                self.image = self.game.enemy_spritesheet.get_sprite(32, 64, self.width, self.height)
            else:
                self.image = down_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1
        if self.facing == "up":
            if self.y_change == 0:
                self.image = self.game.enemy_spritesheet.get_sprite(64, 32, self.width, self.height)
            else:
                self.image = up_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1
        if self.facing == "left":
            if self.x_change == 0:
                self.image = self.game.enemy_spritesheet.get_sprite(0, 32, self.width, self.height)
            else:
                self.image = left_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1
        if self.facing == "right":
            if self.x_change == 0:
                self.image = self.game.enemy_spritesheet.get_sprite(32, 0, self.width, self.height)
            else:
                self.image = right_ani[math.floor(self.ani_loop)]
                self.ani_loop += 0.1
                if self.ani_loop >= 4:
                    self.ani_loop = 1

In [7]:
class Ground(pygame.sprite.Sprite):
    def __init__(self, game, x, y):
        self.game = game
        self._layer = GROUND_LAYER
        self.groups = self.game.all_sprites
        pygame.sprite.Sprite.__init__(self, self.groups)
        self.x = x * TILESIZE
        self.y = y * TILESIZE
        self.width = TILESIZE
        self.height = TILESIZE
        self.image = self.game.tile_spritesheet.get_sprite(0, 0, self.width, self.height)
        self.rect = self.image.get_rect()
        self.rect.x = self.x
        self.rect.y = self.y

In [8]:
class Wall(pygame.sprite.Sprite):
    def __init__(self, game, x, y):
        self.game = game
        self._layer = WALL_LAYER
        self.groups = self.game.all_sprites, self.game.blocks
        pygame.sprite.Sprite.__init__(self, self.groups)
        self.x = x * TILESIZE
        self.y = y * TILESIZE
        self.width = TILESIZE
        self.height = TILESIZE
        self.image = self.game.tile_spritesheet.get_sprite(33, 0, self.width, self.height)
        self.rect = self.image.get_rect()
        self.rect.x = self.x
        self.rect.y = self.y
        

In [9]:
class Button:
    def __init__(self, x, y, width, height, fg, bg, content, fontsize):
        self.font = pygame.font.SysFont('Courier New', fontsize)
        self.content = content
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.fg = fg
        self.bg = bg
        self.image = pygame.Surface((self.width, self.height))
        self.image.fill(self.bg)
        self.rect = self.image.get_rect()
        self.rect.x = x
        self.rect.y = y
        self.text = self.font.render(self.content, True, self.fg)
        self.text_rect = self.text.get_rect(center=(self.width/2,self.height/2))
        self.image.blit(self.text, self.text_rect)

    def is_pressed(self, pos, pressed):
        if self.rect.collidepoint(pos):
            if pressed[0]:
                return True
            return False
        return False

In [10]:
class Attack(pygame.sprite.Sprite):
    def __init__(self, game, x, y, d, p):
        self.game = game
        self._layer = PLAYER_LAYER
        self.groups = self.game.all_sprites, self.game.attacks
        pygame.sprite.Sprite.__init__(self, self.groups)
        self.x = x
        self.y = y
        self.width = TILESIZE
        self.height = TILESIZE
        self.ani_loop = 0
        self.image = self.game.attack_spritesheet.get_sprite(0,0,self.width,self.height)
        self.rect = self.image.get_rect()
        self.rect.x = self.x
        self.rect.y = self.y
        self.direction = d
        self.player = p

    def update(self):
        self.animate()
        self.collide()

    def collide(self):
        hit = pygame.sprite.spritecollide(self, self.game.chars, False)
        if hit:
            for char in hit:
                if self.player != char:
                    char.health -= 10
                    if char.health <= 0:
                        char.kill()
                        self.game.playing = False

    def animate(self):
        down_ani = [
            self.game.attack_spritesheet.get_sprite(64, 32, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(96, 32, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(128, 32, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(160, 32, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(0, 64, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(32, 64, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(64, 64, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(96, 64, self.width, self.height)
        ]

        up_ani = [
            self.game.attack_spritesheet.get_sprite(0, 0, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(32, 0, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(64, 0, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(96, 0, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(128, 0, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(160, 0, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(0, 32, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(32, 32, self.width, self.height)
        ]

        right_ani = [
            self.game.attack_spritesheet.get_sprite(128, 64, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(160, 64, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(0, 96, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(32, 96, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(64, 96, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(96, 96, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(128, 96, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(160, 96, self.width, self.height)
        ]

        left_ani = [
            self.game.attack_spritesheet.get_sprite(0, 128, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(32, 128, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(64, 128, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(96, 128, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(128, 128, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(160, 128, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(0, 160, self.width, self.height),
            self.game.attack_spritesheet.get_sprite(32, 160, self.width, self.height)
        ]

        if self.direction == "down":
                self.image = down_ani[math.floor(self.ani_loop)]
                self.ani_loop += 1
                if self.ani_loop >= 8:
                    self.ani_loop = 0
                    self.kill()
        if self.direction == "up":
                self.image = up_ani[math.floor(self.ani_loop)]
                self.ani_loop += 1
                if self.ani_loop >= 8:
                    self.ani_loop = 0
                    self.kill()
        if self.direction == "left":
                self.image = left_ani[math.floor(self.ani_loop)]
                self.ani_loop += 1
                if self.ani_loop >= 8:
                    self.ani_loop = 0
                    self.kill()
        if self.direction == "right":
                self.image = right_ani[math.floor(self.ani_loop)]
                self.ani_loop += 1
                if self.ani_loop >= 8:
                    self.ani_loop = 0
                    self.kill()

In [11]:
class Game:
    def __init__(self):
        pygame.init()
        pygame.font.init()
        self.screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
        self.clock = pygame.time.Clock()
        self.font = pygame.font.SysFont('Courier New', 35)
        self.running = True
        self.character_spritesheet = Spritesheet("images/sprite-sheet.png")
        self.tile_spritesheet = Spritesheet("images/tile.png")
        self.enemy_spritesheet = Spritesheet("images/enemy.png")
        self.attack_spritesheet = Spritesheet("images/attack.png")
        self.intro_background = pygame.image.load("images/introbg.png")

    def create_tilemap(self):
        for i, row in enumerate(tilemap):
            for j, column in enumerate(row):
                Ground(self, j, i)
                if column == 'W':
                    Wall(self, j, i)
                if column == 'P':
                    self.player = Player(self, j, i, "minimax")
                if column == 'E':
                    self.enemy = Enemy(self, j, i)

    def new(self):
        self.playing = True
        self.all_sprites = pygame.sprite.LayeredUpdates()
        self.blocks = pygame.sprite.LayeredUpdates()
        self.enemies = pygame.sprite.LayeredUpdates()
        self.attacks = pygame.sprite.LayeredUpdates()
        self.players = pygame.sprite.LayeredUpdates()
        self.chars = pygame.sprite.LayeredUpdates()
        self.create_tilemap()

    def events(self):
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                self.playing = False
                self.running = False

            if event.type == pygame.KEYDOWN:
                curr = pygame.time.get_ticks() / 1000
                if curr - self.player.last_att >= self.player.attack_cool:
                    if event.key == pygame.K_UP:
                        Attack(self, self.player.rect.x, self.player.rect.y-TILESIZE, "up", self.player)
                    elif event.key == pygame.K_DOWN:
                        Attack(self, self.player.rect.x, self.player.rect.y+TILESIZE, "down", self.player)
                    elif event.key == pygame.K_LEFT:
                        Attack(self, self.player.rect.x-TILESIZE, self.player.rect.y, "left", self.player)
                    elif event.key == pygame.K_RIGHT:
                        Attack(self, self.player.rect.x+TILESIZE, self.player.rect.y, "right", self.player)
                    self.player.last_att = curr

    def update(self):
        self.all_sprites.update()

    def draw(self):
        self.screen.fill(BLACK)
        self.all_sprites.draw(self.screen)
        self.clock.tick(FPS)
        pygame.display.update()

    def main(self):
        self.events()
        self.update()
        self.draw()

    def actions(self):
        return ["a_up", "a_down", "a_left", "a_right", "up", "down", "left", "right"] 

    def game_over(self):
        text = self.font.render("GAME OVER", True, BLACK)
        text_rect = text.get_rect(center=(SCREEN_WIDTH/2, SCREEN_HEIGHT/3))

        restart_button = Button(SCREEN_WIDTH/2-50, SCREEN_HEIGHT/2, 100, 50, BLACK, RED, "Again", 35)

        for sprite in self.all_sprites:
            sprite.kill()

        while self.running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    self.running = False
            mouse_pos = pygame.mouse.get_pos()
            mouse_pressed = pygame.mouse.get_pressed()
            if restart_button.is_pressed(mouse_pos, mouse_pressed):
                self.new()
                self.main()
            self.screen.blit(text, text_rect)
            self.screen.blit(restart_button.image, restart_button.rect)
            self.clock.tick(FPS)
            pygame.display.update()

    def intro_screen(self):
        intro = True

        title = self.font.render("The Duel", True, RED)
        title_rect = title.get_rect(x=SCREEN_WIDTH/2-75,y=10)
        play_button = Button(SCREEN_WIDTH/2-50, SCREEN_HEIGHT/2, 100, 50, BLACK, RED, "Start", 35)

        while intro:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    intro = False
                    self.running = False
            mouse_pos = pygame.mouse.get_pos()
            mouse_pressed = pygame.mouse.get_pressed()
            if play_button.is_pressed(mouse_pos, mouse_pressed):
                intro = False

            self.screen.blit(self.intro_background, (0,0))
            self.screen.blit(title, title_rect)
            self.screen.blit(play_button.image, play_button.rect)
            self.clock.tick(FPS)
            pygame.display.update()

    def get_state(self):
        return [self.player.rect.x, self.player.rect.y, self.player.health, self.enemy.rect.x, self.enemy.rect.y, self.enemy.health]

    def next_state(self, mv, state, ismax):
        if ismax:
            if (mv == "a_up" and state[1] - state[4] < 0 and abs(state[1] - state[4]) <= TILESIZE and abs(state[0] - state[3]) <= TILESIZE):
                state[5] -= 80
            elif (mv == "a_down" and state[1] - state[4] > 0 and abs(state[1] - state[4]) <= TILESIZE and abs(state[0] - state[3]) <= TILESIZE):
                state[5] -= 80
            elif (mv == "a_left" and state[0] - state[3] > 0 and abs(state[0] - state[3]) <= TILESIZE and abs(state[1] - state[4]) <= TILESIZE):
                state[5] -= 80
            elif (mv == "a_right" and state[0] - state[3] < 0 and abs(state[0] - state[3]) <= TILESIZE and abs(state[1] - state[4]) <= TILESIZE):
                state[5] -= 80
            elif mv == "up":
                state[1] -= PLAYER_SPEED
            elif mv == "down":
                state[1] += PLAYER_SPEED
            elif mv == "left":
                state[0] -= PLAYER_SPEED
            elif mv == "right":
                state[0] += PLAYER_SPEED
        else:
            if (mv == "a_up" and state[4] - state[1] < 0 and abs(state[4] - state[1]) <= TILESIZE and abs(state[0] - state[3]) <= TILESIZE):
                state[2] -= 80
            elif (mv == "a_down" and state[4] - state[1] > 0 and abs(state[4] - state[1]) <= TILESIZE and abs(state[0] - state[3]) <= TILESIZE):
                state[2] -= 80
            elif (mv == "a_left" and state[3] - state[0] > 0 and abs(state[3] - state[0]) <= TILESIZE and abs(state[1] - state[4]) <= TILESIZE):
                state[2] -= 80
            elif (mv == "a_right" and state[3] - state[0] < 0 and abs(state[3] - state[0]) <= TILESIZE and abs(state[1] - state[4]) <= TILESIZE):
                state[2] -= 80
            elif mv == "up":
                state[4] -= ENEMY_SPEED
            elif mv == "down":
                state[4] += ENEMY_SPEED
            elif mv == "left":
                state[3] -= ENEMY_SPEED
            elif mv == "right":
                state[3] += ENEMY_SPEED
        return state

In [12]:
# Define the custom environment
class CustomGameEnv(gym.Env):
    def __init__(self):
        super(CustomGameEnv, self).__init__()

        self.num_actions = 8
        self.reward = 0
        
        # Define observation space (example: using a Box space for state)
        self.observation_space = gym.spaces.Box(low=0, high=800, shape=(6,), dtype=np.int64)
        
        # Define action space (example: using a Discrete space for actions)
        self.action_space = gym.spaces.Discrete(self.num_actions)
        
        # Initialize game
        self.game = Game()
        self.game.intro_screen()
        self.game.new()
        self.done = False
        
        # Initialize observation (state)
        self.state = self.game.get_state()
        self.state = np.array(self.state)
        
    def reset(self, seed=None):
        # Reset the game and return the initial observation
        self.reward = 0
        self.done = False
        self.game.new()
        self.state = self.game.get_state()
        self.state = np.array(self.state)
        info = {}
        return self.state, info
    
    def step(self, action):
        # Apply the action to the game and update the state
        self.game.enemy.movement(action)
        self.game.main()
        self.state = self.game.get_state()
        self.state = np.array(self.state)
        if self.game.playing == False:
            self.done = True

        if self.done == True and self.state[5] <= 0:
            self.reward = -1000
        elif self.done == True and self.state[2] <= 0:
            self.reward = 1000
        else:
            self.reward = (self.state[5] - self.state[2]) - (abs(math.dist([self.state[0], self.state[1]], [self.state[3], self.state[4]])))/320

        # Return the state, reward, done flag, and additional info
        return self.state, self.reward, self.done, False, {}
    
    def close(self):
        # Clean up resources
        self.game.game_over()
        pygame.quit()
        sys.exit()

# Create the custom environment
env = CustomGameEnv()

# Wrap the environment in a vectorized environment
env = DummyVecEnv([lambda: env])

# Create a DQN model using the custom environment and train it
model = DQN.set_parameters("game_model")
model.learn(total_timesteps=200000)

# Save the trained model
model.save("game_model")

# Use the trained model for inference
model = DQN.set_parameters("game_model")
obs = env.reset()
done = False
while not done:
    action, _ = model.predict(obs)
    obs, reward, done, info = env.step(action)

# Close the environment
env.close()




Using cpu device
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.869    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 59       |
|    time_elapsed     | 232      |
|    total_timesteps  | 13767    |
| train/              |          |
|    learning_rate    | 0.01     |
|    loss             | 7.09     |
|    n_updates        | 3416     |
----------------------------------
----------------------------------
| rollout/            |          |
|    exploration_rate | 0.747    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 58       |
|    time_elapsed     | 453      |
|    total_timesteps  | 26677    |
| train/              |          |
|    learning_rate    | 0.01     |
|    loss             | 10.8     |
|    n_updates        | 6644     |
----------------------------------
----------------------------------
| rollout/            |          |
|  

ValueError: not enough values to unpack (expected 2, got 1)