In [12]:
import pygame
import numpy as np
from gymnasium import spaces
from cellitaire.game.game import Game
from cellitaire.game.slot import Slot
from cellitaire.game.card import Card
from cellitaire.environment.rewards.reward import *
import random
import time

In [13]:
class CellitaireEnv:
    """Gym-style environment around a Cellitaire ``Game``.

    Actions are flat slot indices (``row * cols + col``). Observations are the
    1-D vector built by :meth:`get_state`. When ``render_mode == 'human'`` the
    environment additionally posts pygame events describing the board, the
    stockpile, the foundation and the per-step reward.

    Note: :meth:`reset` returns the same 5-tuple layout as :meth:`step`
    (``state, reward, done, truncated, info``) rather than Gymnasium's
    ``(observation, info)`` pair; that shape is kept for existing callers.
    """

    def __init__(self, reward, rows = 7, cols = 12, num_reserved = 6, max_moves = 300, max_illegal_moves = 300, render_mode=None):
        """Store configuration; no game exists until :meth:`reset` is called.

        Args:
            reward: Object exposing ``reset()`` and
                ``calculate_reward(state, done, truncated, info)``.
            rows: Number of board rows.
            cols: Number of board columns.
            num_reserved: Forwarded to ``Game.new_game``.
            max_moves: Legal-move budget; exceeding it sets ``truncated``.
            max_illegal_moves: Illegal-move budget; exceeding it sets ``truncated``.
            render_mode: ``'human'`` enables pygame event publishing.
        """
        self.game = None
        # Snapshots taken at reset; not read anywhere else in this class.
        self.prev_foundation_count = 0
        self.prev_legal_moves = 0
        self.prev_stockpile_count = 0
        self.reward = reward
        self.rows = rows
        self.cols = cols
        self.num_reserved = num_reserved

        # One discrete action per board slot.
        self.action_space = spaces.Discrete(rows * cols)

        self.num_moves = 0
        self.max_moves = max_moves
        self.num_illegal_moves = 0
        self.max_illegal_moves = max_illegal_moves

        # TODO: this definitely isn't right. get_state() returns a 1-D vector of
        # rows * cols board entries + 2 stockpile entries + one entry per
        # foundation pile + 1 legal-move count, which does not match the
        # (1, rows * cols + 6) shape declared here.
        self.observation_space = spaces.Box(low=0.0, high=53.0, shape=(1, rows * cols + 6))

        self.render_mode = render_mode

    def reset(self):
        """Start a fresh game and reset counters and the reward object.

        Returns:
            ``(state, reward, done, truncated, info)`` with ``reward == 0``,
            ``done == truncated == False`` and an empty ``info`` dict.
        """
        self.game = Game()
        self.game.new_game(self.rows, self.cols, self.num_reserved)
        self.reward.reset()

        if self.render_mode == 'human':
            self.publish_updates()

        # Initialize previous feature values.
        self.prev_foundation_count = self.game.foundation.total_cards()
        self.prev_legal_moves = self.legal_actions_count()
        self.prev_stockpile_count = self.game.stockpile.count()

        self.num_moves = 0
        self.num_illegal_moves = 0

        reward = 0
        done = False
        truncated = False
        state = self.get_state()
        info = {}
        return state, reward, done, truncated, info

    def get_legal_actions(self):
        """Return the currently legal ``(row, col)`` coordinates as a list.

        The first group returned by ``board.get_special_slots()`` is always
        legal; the second (placeable) group is only legal while the stockpile
        still holds cards. Order is unspecified because a set is used to
        de-duplicate.
        """
        special_coords, placeable_coords = self.game.board.get_special_slots()
        legal_actions = set(special_coords)
        if self.game.stockpile.count() > 0:
            legal_actions.update(placeable_coords)
        return list(legal_actions)

    def get_legal_actions_as_int(self):
        """Return the legal actions as flat indices (``row * cols + col``)."""
        return [row * self.cols + col for row, col in self.get_legal_actions()]

    def legal_actions_count(self):
        """Return how many legal actions are currently available."""
        return len(self.get_legal_actions())

    def get_board_state(self):
        """Return a 2-D array of ``card_id`` per slot, with 0 for empty slots."""
        return np.array([
            [slot.card.card_id if slot.card is not None else 0 for slot in row]
            for row in self.game.board.slots
        ])

    def get_stockpile_state(self):
        """Return ``[top_card_id_or_0, stockpile_count]`` as a float32 array."""
        top_card = self.game.stockpile.top_card()
        return np.concatenate(
            (np.array([top_card.card_id if top_card is not None else 0], dtype=np.float32),
            np.array([self.game.stockpile.count()], dtype=np.float32))
        )

    def get_foundation_state(self):
        """Return the 1-based rank index of each foundation pile's card (0 if empty)."""
        return np.array(
            [
                Card.RANKS.index(card.rank) + 1 if card is not None else 0
                for card in self.game.foundation.foundation.values()
            ],
            dtype=np.float32,
        )

    def get_state(self):
        """Return the flat observation: board, stockpile, foundation, legal-move count."""
        board_state = self.get_board_state()
        stockpile_state = self.get_stockpile_state()
        foundation_state = self.get_foundation_state()
        return np.concatenate((
            board_state.reshape(1, -1),
            stockpile_state.reshape(1, -1),
            foundation_state.reshape(1, -1),
            np.array([self.legal_actions_count()]).reshape(1, -1)
            ), axis=1).squeeze(0)

    def step(self, action):
        """Apply the move at flat index ``action`` and return the transition.

        A move rejected by the game increments ``num_illegal_moves`` and sets
        ``info["illegal_move"]``; an accepted move increments ``num_moves``.

        Returns:
            ``(new_state, reward, done, truncated, info)`` where ``done`` means
            no legal actions remain and ``truncated`` means one of the move
            budgets has been exceeded.
        """
        action = self.get_action_by_index(action)
        info = {}

        move_executed = self.make_move(action)
        if not move_executed:
            self.num_illegal_moves += 1
            info = {"illegal_move": True}
        else:
            self.num_moves += 1

        new_state = self.get_state()
        done = self.legal_actions_count() < 1
        truncated = self.num_moves > self.max_moves or self.num_illegal_moves > self.max_illegal_moves

        reward = self.reward.calculate_reward(new_state, done, truncated, info)

        if self.render_mode == 'human':
            self.publish_reward_update(reward)

        return new_state, reward, done, truncated, info

    def get_action_by_index(self, action_index):
        """Convert a flat action index into a ``(row, col)`` coordinate."""
        row = action_index // self.cols
        col = action_index % self.cols
        return (row, col)

    def publish_updates(self):
        """Post pygame events describing every slot, the stockpile and the foundation."""
        slot_events = [
            pygame.event.Event(
                GU_SLOT_UPDATE,
                coordinate=(i, j),
                card=slot.card,
                # The misspelled keyword is the attribute name the slot sprite reads.
                is_lonenly=slot.is_lonely,
                is_suffocated=slot.is_suffocated,
                is_placeable=slot.is_placeable
            ) for i, row in enumerate(self.game.board.slots) for j, slot in enumerate(row)
        ]

        for event in slot_events:
            pygame.event.post(event)

        pygame.event.post(pygame.event.Event(GU_STOCKPILE_UPDATE, top_card=self.game.stockpile.top_card(), count=self.game.stockpile.count()))

        pygame.event.post(pygame.event.Event(GU_FOUNDATION_UPDATE, foundation_dict=self.game.foundation.foundation, total_saved=self.game.foundation.total_cards()))

    def publish_reward_update(self, reward):
        """Post a ``REWARD_UPDATED`` pygame event carrying ``reward``."""
        pygame.event.post(pygame.event.Event(REWARD_UPDATED, reward=reward))

    def make_move(self, move):
        """Forward ``move`` to the game and report whether it was executed.

        Returns the value of ``game.make_move(move)`` so that :meth:`step` can
        tell legal from illegal moves. In human render mode an executed move
        also triggers :meth:`publish_updates`.
        """
        move_executed = self.game.make_move(move)
        if move_executed and self.render_mode == "human":
            self.publish_updates()
        return move_executed

    # TODO: would be cool to have human rendering
    def render(self):
        """Return the textual representation of the environment."""
        return self.__str__()

    def __str__(self):
        return f"CellitaireEnv(game={self.game})"


In [14]:
# --- Window -----------------------------------------------------------------
SCREEN_WIDTH, SCREEN_HEIGHT = 1920, 1080
SCREEN_DIMS = (SCREEN_WIDTH, SCREEN_HEIGHT)

# --- RGB palette ------------------------------------------------------------
BLACK = (0, 0, 0)
WHITE = (255, 255, 255)
RED = (255, 0, 0)
GREEN = (0, 255, 0)
BLUE = (0, 0, 255)
YELLOW = (255, 255, 0)
PURPLE = (112, 1, 169)

# --- Spacing (pixels) -------------------------------------------------------
SCREEN_MARGIN = 15
BOARD_MARGIN = 15

# --- Board slots ------------------------------------------------------------
SLOT_WIDTH = SLOT_HEIGHT = 120
SLOT_PADDING = 5
SLOT_BACKGROUND_COLOR = GREEN
SLOT_LONELY_OR_SUFFOCATED_COLOR = PURPLE
SLOT_PLACEABLE_COLOR = YELLOW

# --- Stockpile: two slots wide, one slot tall -------------------------------
STOCKPILE_WIDTH = 2 * SLOT_WIDTH
STOCKPILE_HEIGHT = SLOT_HEIGHT
STOCKPILE_PADDING = 5
STOCKPILE_BACKGROUND_COLOR = BLACK

# --- Foundation: five slots wide, same height as the stockpile --------------
FOUNDATION_WIDTH = 5 * SLOT_WIDTH
FOUNDATION_HEIGHT = STOCKPILE_HEIGHT
FOUNDATION_PADDING = 5
FOUNDATION_PILE_OUTLINE_COLOR = YELLOW
FOUNDATION_BACKGROUND_COLOR = BLACK

# --- Window background ------------------------------------------------------
BACKGROUND_COLOR = BLUE

# Custom pygame event type ids, numbered consecutively from pygame.USEREVENT + 1.
(
    GU_SLOT_UPDATE,        # game update: one board slot changed
    GU_FOUNDATION_UPDATE,  # game update: foundation piles / saved total changed
    GU_STOCKPILE_UPDATE,   # game update: stockpile top card / count changed
    SLOT_CLICKED,          # UI: a slot sprite was clicked
    REWARD_UPDATED,        # environment: a step produced a reward
    GAME_DONE,             # main loop: the game ended and was reset
) = (pygame.USEREVENT + offset for offset in range(1, 7))

def get_card_text_color(card):
    """Return the RGB colour used to render ``card``'s text, keyed by its suit.

    Suits map as ``'s'`` -> BLACK, ``'h'`` -> RED, ``'d'`` -> BLUE and
    ``'c'`` -> GREEN. Any other suit raises ``KeyError``.
    """
    suit_colors = dict(s=BLACK, h=RED, d=BLUE, c=GREEN)
    return suit_colors[card.suit]

In [15]:
class GameWrapper:
    """Holds a ``Game`` and broadcasts its state as pygame events after each move."""

    def __init__(self, rows, cols, num_reserved):
        """Remember the board configuration; the game is created by ``reset``."""
        self.rows, self.cols, self.num_reserved = rows, cols, num_reserved

        self.game = None

    def reset(self):
        """Replace the current game with a freshly dealt one."""
        self.game = Game()
        self.game.new_game(self.rows, self.cols, self.num_reserved)

    def publish_updates(self):
        """Post one event per slot, then a stockpile event, then a foundation event."""
        pending = []
        for row_index, row in enumerate(self.game.board.slots):
            for col_index, slot in enumerate(row):
                pending.append(
                    pygame.event.Event(
                        GU_SLOT_UPDATE,
                        coordinate=(row_index, col_index),
                        card=slot.card,
                        # Spelling matches the attribute read by the slot sprite.
                        is_lonenly=slot.is_lonely,
                        is_suffocated=slot.is_suffocated,
                        is_placeable=slot.is_placeable,
                    )
                )

        for slot_event in pending:
            pygame.event.post(slot_event)

        stockpile_event = pygame.event.Event(
            GU_STOCKPILE_UPDATE,
            top_card=self.game.stockpile.top_card(),
            count=self.game.stockpile.count(),
        )
        pygame.event.post(stockpile_event)

        foundation_event = pygame.event.Event(
            GU_FOUNDATION_UPDATE,
            foundation_dict=self.game.foundation.foundation,
            total_saved=self.game.foundation.total_cards(),
        )
        pygame.event.post(foundation_event)

    def make_move(self, move):
        """Apply ``move`` to the game and always republish the full state."""
        self.game.make_move(move)
        self.publish_updates()

In [16]:
class RewardSprite(pygame.sprite.Sprite):
    """Sprite showing the running total of rewards for the current game."""

    def __init__(self, height, width, x, y):
        """Create a ``width`` x ``height`` surface positioned at ``(x, y)``."""
        super().__init__()

        self.image = pygame.Surface([width, height])
        self.rect = self.image.get_rect()
        self.rect.x = x
        self.rect.y = y

        self.height = height
        self.width = width

        # Sum of all rewards seen since the last GAME_DONE event.
        self.total_reward = 0

    def draw_reward(self):
        """Clear the surface to black and render ``Reward: <total>`` centred on it."""
        r = pygame.draw.rect(self.image, BLACK, pygame.Rect(0, 0, self.width, self.height))

        font = pygame.font.Font(None, 24)
        score_text = f'Reward: {str(self.total_reward)}'
        text_surface = font.render(score_text, True, WHITE)
        text_rect = text_surface.get_rect(center=r.center)
        self.image.blit(text_surface, text_rect)

    def handle_reward_update_event(self, event):
        """Add ``event.reward`` to the running total."""
        self.total_reward += event.reward

    # Backward-compatible alias for the original, misspelled method name.
    hadle_reward_update_event = handle_reward_update_event

    def update(self, events):
        """Consume reward / game-done events for this frame, then redraw.

        ``REWARD_UPDATED`` accumulates into the total; ``GAME_DONE`` resets it to 0.
        """
        for event in events:
            if event.type == REWARD_UPDATED:
                self.handle_reward_update_event(event)
            elif event.type == GAME_DONE:
                self.total_reward = 0

        self.draw_reward()

In [17]:
class StockpileSprite(pygame.sprite.Sprite):
    """Sprite for the stockpile: top card in the left half, card count in the right half."""

    def __init__(self, height, width, x, y):
        """Create the surface at ``(x, y)`` and draw the initial, empty stockpile."""
        super().__init__()

        self.width = width
        self.height = height

        self.image = pygame.Surface([width, height])
        self.rect = self.image.get_rect()
        self.rect.x, self.rect.y = x, y

        # Updated from GU_STOCKPILE_UPDATE events.
        self.card = None
        self.count = 0

        self.draw_stockpile()

    def _blit_centered_text(self, text, color, center):
        """Render ``text`` in the default 24pt font and blit it centred on ``center``."""
        font = pygame.font.Font(None, 24)
        rendered = font.render(text, True, color)
        self.image.blit(rendered, rendered.get_rect(center=center))

    def draw_stockpile_outline(self):
        """Outline the left half of the sprite, where the top card sits."""
        card_area = pygame.Rect(0, 0, self.width // 2, self.height)
        pygame.draw.rect(self.image, YELLOW, card_area, STOCKPILE_PADDING)

    def draw_background(self):
        """Fill the whole sprite with the stockpile background colour."""
        whole_area = pygame.Rect(0, 0, self.width, self.height)
        pygame.draw.rect(self.image, STOCKPILE_BACKGROUND_COLOR, whole_area)

    def draw_card(self):
        """Draw the top card (white face plus suit-coloured text) if there is one."""
        if self.card is not None:
            face_area = pygame.Rect(
                STOCKPILE_PADDING,
                STOCKPILE_PADDING,
                (self.width // 2) - 2 * STOCKPILE_PADDING,
                self.height - 2 * STOCKPILE_PADDING,
            )
            drawn = pygame.draw.rect(self.image, WHITE, face_area)
            self._blit_centered_text(
                str(self.card), get_card_text_color(self.card), drawn.center
            )

    def draw_count(self):
        """Write the remaining card count centred in the right half of the sprite."""
        right_half_center = (3 * self.width // 4, self.height // 2)
        self._blit_centered_text(str(self.count), WHITE, right_half_center)

    def draw_stockpile(self):
        """Redraw every layer in back-to-front order."""
        for draw_layer in (
            self.draw_background,
            self.draw_stockpile_outline,
            self.draw_card,
            self.draw_count,
        ):
            draw_layer()

    def handle_stockpile_update_event(self, event):
        """Copy the top card and count out of a stockpile update event."""
        self.card, self.count = event.top_card, event.count

    def update(self, events):
        """Apply any stockpile update events from this frame, then redraw."""
        for incoming in events:
            if incoming.type != GU_STOCKPILE_UPDATE:
                continue
            self.handle_stockpile_update_event(incoming)

        self.draw_stockpile()

In [18]:
class FoundationSprite(pygame.sprite.Sprite):
    """Sprite for the foundation: a saved-card total followed by one pile per suit.

    The sprite is divided into five equal columns; the first holds the total
    and the remaining ones hold the piles in ``foundation_dict`` order.
    """

    def __init__(self, height, width, x, y):
        """Create the surface at ``(x, y)`` and draw the initial, empty foundation."""
        super().__init__()

        self.width = width
        self.height = height

        self.image = pygame.Surface([width, height])
        self.rect = self.image.get_rect()
        self.rect.x, self.rect.y = x, y

        # Replaced wholesale by GU_FOUNDATION_UPDATE events.
        self.foundation_dict = dict.fromkeys(('s', 'h', 'd', 'c'))
        self.total_saved = 0
        self.draw_foundation()

    def _blit_centered_text(self, text, color, center):
        """Render ``text`` in the default 24pt font and blit it centred on ``center``."""
        font = pygame.font.Font(None, 24)
        rendered = font.render(text, True, color)
        self.image.blit(rendered, rendered.get_rect(center=center))

    def draw_background(self):
        """Fill the whole sprite with the foundation background colour."""
        whole_area = pygame.Rect(0, 0, self.width, self.height)
        pygame.draw.rect(self.image, FOUNDATION_BACKGROUND_COLOR, whole_area)

    def draw_card(self, x_offset, card):
        """Draw ``card`` inside the pile column starting at ``x_offset``; no-op for ``None``."""
        if card is not None:
            face_area = pygame.Rect(
                x_offset + FOUNDATION_PADDING,
                FOUNDATION_PADDING,
                self.width // 5 - 2 * FOUNDATION_PADDING,
                self.height - 2 * FOUNDATION_PADDING,
            )
            drawn = pygame.draw.rect(self.image, WHITE, face_area)
            self._blit_centered_text(str(card), get_card_text_color(card), drawn.center)

    def draw_pile_outline(self, x_offset):
        """Outline the pile column that starts at ``x_offset``."""
        column_area = pygame.Rect(x_offset, 0, self.width // 5, self.height)
        pygame.draw.rect(
            self.image,
            FOUNDATION_PILE_OUTLINE_COLOR,
            column_area,
            FOUNDATION_PADDING,
        )

    def draw_cards(self):
        """Draw every pile; pile ``k`` (1-based) occupies column ``k``."""
        column_width = self.width // 5
        for column, pile_card in enumerate(self.foundation_dict.values(), start=1):
            left_edge = column * column_width
            self.draw_pile_outline(left_edge)
            self.draw_card(left_edge, pile_card)

    def draw_total_saved(self):
        """Write the saved-card total centred in the first column."""
        first_column_center = (self.width // 10, self.height // 2)
        self._blit_centered_text(str(self.total_saved), WHITE, first_column_center)

    def draw_foundation(self):
        """Redraw every layer in back-to-front order."""
        for draw_layer in (self.draw_background, self.draw_cards, self.draw_total_saved):
            draw_layer()

    def handle_foundation_update_event(self, event):
        """Copy the pile mapping and saved total out of a foundation update event."""
        self.foundation_dict, self.total_saved = event.foundation_dict, event.total_saved

    def update(self, events):
        """Apply any foundation update events from this frame, then redraw."""
        for incoming in events:
            if incoming.type != GU_FOUNDATION_UPDATE:
                continue
            self.handle_foundation_update_event(incoming)

        self.draw_foundation()

In [19]:
class SlotSprite(pygame.sprite.Sprite):
    """Sprite for a single board slot identified by its ``(row, col)`` coordinate.

    Shows the slot's card, an outline when the slot is lonely, suffocated or
    placeable, and a darkening overlay while the mouse hovers over it.
    Releasing a mouse button over an outlined slot posts ``SLOT_CLICKED``.
    """

    def __init__(self, height, width, x, y, coordinate):
        """Create the surface at ``(x, y)`` and draw the initial, empty slot."""
        super().__init__()

        self.width = width
        self.height = height
        self.coordinate = coordinate

        self.image = pygame.Surface([width, height])
        self.rect = self.image.get_rect()
        self.rect.x, self.rect.y = x, y

        # Game-driven state, refreshed from GU_SLOT_UPDATE events.
        self.card = None
        self.is_lonely = False
        self.is_suffocated = False
        self.is_placeable = False
        # UI-driven state, refreshed every frame from the mouse position.
        self.is_hovered = False

        self.draw_slot()

    def draw_background(self):
        """Fill the whole slot with the slot background colour."""
        whole_area = pygame.Rect(0, 0, self.width, self.height)
        pygame.draw.rect(self.image, SLOT_BACKGROUND_COLOR, whole_area)

    def draw_outline(self):
        """Outline the slot; the placeable colour wins over the lonely/suffocated one."""
        if self.is_placeable:
            border_color = SLOT_PLACEABLE_COLOR
        elif self.is_lonely or self.is_suffocated:
            border_color = SLOT_LONELY_OR_SUFFOCATED_COLOR
        else:
            return

        whole_area = pygame.Rect(0, 0, self.width, self.height)
        pygame.draw.rect(self.image, border_color, whole_area, SLOT_PADDING)

    def draw_card(self):
        """Draw the slot's card (white face plus suit-coloured text) if there is one."""
        if self.card is not None:
            face_area = pygame.Rect(
                SLOT_PADDING,
                SLOT_PADDING,
                self.width - 2 * SLOT_PADDING,
                self.height - 2 * SLOT_PADDING,
            )
            drawn = pygame.draw.rect(self.image, WHITE, face_area)

            font = pygame.font.Font(None, 24)
            label = font.render(str(self.card), True, get_card_text_color(self.card))
            self.image.blit(label, label.get_rect(center=drawn.center))

    def draw_hover_overlay(self):
        """Darken the slot with a translucent black layer while it is hovered."""
        if self.is_hovered:
            shade = pygame.Surface((self.width, self.height), pygame.SRCALPHA)
            shade.fill((0, 0, 0, 50))
            self.image.blit(shade, (0, 0))
            self.rect = self.image.get_rect(topleft=self.rect.topleft)

    def draw_slot(self):
        """Redraw every layer in back-to-front order."""
        for draw_layer in (
            self.draw_background,
            self.draw_card,
            self.draw_outline,
            self.draw_hover_overlay,
        ):
            draw_layer()

    def handle_slot_update_event(self, event):
        """Copy the card and status flags out of a slot update event."""
        self.card = event.card
        # The event attribute is spelled "is_lonenly" by its publishers.
        self.is_lonely = event.is_lonenly
        self.is_suffocated = event.is_suffocated
        self.is_placeable = event.is_placeable

    def handle_clicked(self):
        """Post ``SLOT_CLICKED`` for this coordinate if the slot is currently actionable."""
        if self.is_lonely or self.is_suffocated or self.is_placeable:
            pygame.event.post(pygame.event.Event(SLOT_CLICKED, coordinate=self.coordinate))

    def update(self, events):
        """Refresh hover state, process this frame's events, then redraw."""
        self.is_hovered = bool(self.rect.collidepoint(pygame.mouse.get_pos()))

        for incoming in events:
            addressed_here = (
                incoming.type == GU_SLOT_UPDATE
                and incoming.coordinate == self.coordinate
            )
            if addressed_here:
                self.handle_slot_update_event(incoming)
            if self.is_hovered and incoming.type == pygame.MOUSEBUTTONUP:
                self.handle_clicked()

        self.draw_slot()

In [20]:
# Initialise pygame before creating the window or touching the event queue.
pygame.init()

screen = pygame.display.set_mode(SCREEN_DIMS)
pygame.display.set_caption("Cellitaire RL")

ROWS = 7
COLS = 12
NUM_RESERVED = 6

# Lower-case duplicates kept from the original cell; only num_reserved is read below.
board_rows = 7
board_cols = 12
num_reserved = 6
test_reward = CombinedReward([
    PlacedCardInFoundationReward(weight=6),
    WinReward(),
    ConstantReward(weight=0.5),
    #PlayedLegalMoveReward(weight=1),
    #PeriodicPlacedCardInFoundationReward(weight=4, reward_period=3),
    CreatedMovesReward(weight=1, num_reserved=num_reserved, foundation_count_dropoff=30)
])

env = CellitaireEnv(reward=test_reward, rows=ROWS, cols=COLS, num_reserved=NUM_RESERVED, render_mode='human')
env.reset()

all_sprites = pygame.sprite.Group()

# Top row: stockpile on the left, foundation right-aligned with the board,
# reward read-out two slot-widths to the left of the foundation.
all_sprites.add(StockpileSprite(
    height=STOCKPILE_HEIGHT,
    width=STOCKPILE_WIDTH,
    x=SCREEN_MARGIN,
    y=SCREEN_MARGIN
))

all_sprites.add(FoundationSprite(
    height=FOUNDATION_HEIGHT,
    width=FOUNDATION_WIDTH,
    x=SCREEN_MARGIN + SLOT_WIDTH * COLS - FOUNDATION_WIDTH,
    y=SCREEN_MARGIN
))

all_sprites.add(RewardSprite(
    height=SLOT_HEIGHT,
    width=SLOT_WIDTH,
    x=SCREEN_MARGIN + SLOT_WIDTH * COLS - FOUNDATION_WIDTH - 2 * SLOT_WIDTH,
    y=SCREEN_MARGIN
))

# Board grid, placed below the top row.
for i in range(ROWS):
    for j in range(COLS):
        all_sprites.add(
            SlotSprite(
                height=SLOT_HEIGHT,
                width=SLOT_WIDTH,
                x=SCREEN_MARGIN + j * SLOT_WIDTH,
                y=SCREEN_MARGIN + STOCKPILE_HEIGHT + BOARD_MARGIN + i * SLOT_HEIGHT,
                coordinate=(i, j)
            )
        )

# Seconds between automatic random moves.
AUTO_MOVE_INTERVAL_SECONDS = 1


def restart_if_game_over():
    """Reset the environment and announce GAME_DONE when no legal move remains."""
    if env.legal_actions_count() < 1:
        env.reset()
        pygame.event.post(pygame.event.Event(GAME_DONE))


running = True
clock = pygame.time.Clock()

start_time = time.time()
try:
    while running:
        events = pygame.event.get()
        for event in events:
            if event.type == pygame.QUIT:
                running = False
            if event.type == SLOT_CLICKED:
                move_index = event.coordinate[0] * COLS + event.coordinate[1]
                env.step(move_index)
                # A human move can end the game too; restart right away so the
                # auto-player never has to pick from an empty move list.
                restart_if_game_over()

        if time.time() - start_time > AUTO_MOVE_INTERVAL_SECONDS:
            legal_actions_as_int = env.get_legal_actions_as_int()
            # random.choice raises IndexError on an empty sequence.
            if legal_actions_as_int:
                env.step(random.choice(legal_actions_as_int))
            restart_if_game_over()
            start_time = time.time()

        all_sprites.update(events)
        screen.fill(BACKGROUND_COLOR)
        all_sprites.draw(screen)
        pygame.display.flip()
        # NOTE(review): this caps the loop at 10,000 FPS, i.e. effectively
        # uncapped — confirm that is intended.
        clock.tick(10 * 1000)
finally:
    # Close the window even if the loop is interrupted or raises.
    pygame.quit()

In [None]:
# Standalone cleanup cell: shuts pygame down. Presumably kept so the window can
# be closed by hand if the game-loop cell is interrupted before its own
# pygame.quit() runs.
pygame.quit()