In [None]:
%pip install gym rich gymnasium
%pip install stable_baselines3

In [2]:
from stable_baselines3.common.env_checker import check_env
from stable_baselines3 import PPO, A2C, DQN
import gymnasium as gym
from gymnasium import spaces
import random
from rich.console import Console
from rich.text import Text
import numpy as np
import cv2
import os

In [3]:
class Card:
    def __init__(self, value, suit, visible=False, bonus=False):
        """
        Initialize a card with given value, suit, visibility, and bonus status.

        Parameters:
        - value (int): The rank of the card (1 to 13, where 1 = Ace and 13 = King).
        - suit (int): The suit of the card (0 to 3).
        - visible (bool): Whether the card is face-up.
        - bonus (bool): Whether this card has been given a bonus for moving to the foundation.
        """
        self.value = value  # 1 to 13
        self.suit = suit  # 0 to 3
        # 0 - ♥ (hearts) (red)
        # 1 - ♦ (diamonds) (red)
        # 2 - ♣ (clubs) (black)
        # 3 - ♠ (spades) (black)
        # for observation:
        # 4 - unexisting suit for flatten()
        # 5 - invisible suit

        self.visible = visible  # Face-up or face-down
        self.bonus = bonus  # Bonus flag to prevent duplicate rewards

    def __repr__(self):
        visibility = "Visible" if self.visible else "Hidden"
        return f"Card(value={self.value}, suit={self.suit}, {visibility}, bonus={self.bonus})"



In [None]:
def obtain_visible_cards_from_screenshot():

    baseline_x, baseline_y = 727, 291
    offset_x, offset_y, rso = 67, 24, 40


    def match_templates(image, templates, threshold=0.9):
        """
        Matches templates against the input image and returns a list of match locations.

        Parameters:
        - image (ndarray): The input grayscale image.
        - templates (dict): A dictionary of templates, where keys are labels (e.g., rank or suit names) and values are the template images.
        - threshold (float): The threshold for considering a match.

        Returns:
        - list: A list of matches as (label, x, y), where (x, y) is the top-left corner of the match.
        """
        matches = []
        for label, template in templates.items():
            result = cv2.matchTemplate(image, template, cv2.TM_CCOEFF_NORMED)
            loc = np.where(result >= threshold)
            for pt in zip(*loc[::-1]):
                matches.append((label, pt[0], pt[1], result[pt[1], pt[0]]))
        return matches


    # Paths
    image_path = 'screenshot_1.png'  # Input screenshot
    ranks_path = './icons/ranks/'  # Path to rank templates (e.g., 1.png, 2.png, ..., 13.png)
    suits_path = './icons/suits/'  # Path to suit templates (e.g., 0.png for ♥, 1.png for ♦, etc.)

    # Load input image and convert to grayscale
    image = cv2.imread(image_path)
    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)


    # Load rank and suit templates
    rank_templates = {f.split('.')[0]: cv2.imread(os.path.join(ranks_path, f), cv2.IMREAD_GRAYSCALE)
                        for f in os.listdir(ranks_path)}
    suit_templates = {f.split('.')[0]: cv2.imread(os.path.join(suits_path, f))
                        for f in os.listdir(suits_path)}


    # Perform template matching
    rank_matches = match_templates(gray_image, rank_templates)
    suit_matches = match_templates(image, suit_templates)


    def determine_card_locations(rank_matches, suit_matches):
        """
        Determines card locations based on rank and suit matches and their corresponding regions.

        Parameters:
        - rank_matches (list): List of rank matches as (rank_label, x, y).
        - suit_matches (list): List of suit matches as (suit_label, x, y).

        Returns:
        - dictionary {"tableau": np_field, "foundation": np_foundation, "top_card": ret_draw}
        """
        field_cards = [[[0, 0, 0.0, 0.0] for j in range(19)] for i in range(7)]
        draw_card = [0, 0, 0.0, 0.0]
        foundation_cards = [[0, 0, 0.0, 0.0] for i in range(4)]

        ret_field = []
        ret_foundation = []
        ret_draw = Card(value=0, suit=0, visible=False)

        for rank_label, rank_x, rank_y, sim in rank_matches:
            rank_col = (rank_x - baseline_x) // offset_x
            rank_row = (rank_y - baseline_y) // offset_y

            if (rank_y < baseline_y):
                if (rank_col < 4):
                    foundation_cards[rank_col][0] = rank_label
                    foundation_cards[rank_col][2] = sim
                elif (rank_col == 5):
                    draw_card[0] = rank_label
                    draw_card[2] = sim
                continue

            if (sim > field_cards[rank_col][rank_row][2]):
                field_cards[rank_col][rank_row][0] = rank_label
                field_cards[rank_col][rank_row][2] = sim

        for suit_label, suit_x, suit_y, sim in suit_matches:
            suit_col = (suit_x - baseline_x) // offset_x
            suit_row = (suit_y - baseline_y) // offset_y

            if (suit_y < baseline_y):
                if (suit_col < 4):
                    foundation_cards[suit_col][1] = suit_label
                    foundation_cards[suit_col][3] = sim
                elif (suit_col == 6):
                    draw_card[1] = suit_label
                    draw_card[3] = sim
                continue

            if (sim > field_cards[suit_col][suit_row][3]):
                field_cards[suit_col][suit_row][1] = suit_label
                field_cards[suit_col][suit_row][3] = sim

        for i, row in enumerate(field_cards):
            for j, card in enumerate(row):
                if (card[0] != 0 and card[0] != 14):
                    ret_field.append((i, j, Card(
                        value=card[0],
                        suit=card[1],
                        visible=True)))

        for i, card in enumerate(foundation_cards):
            if (card[0] != 0 and card[0] != 14):
                    ret_foundation.append((i, Card(
                        value=card[0],
                        suit=card[1],
                        visible=True)))

        if (draw_card[0] != 14):
            if (draw_card[0] == 0):
                ret_draw = Card(
                value=0,
                suit=0,
                visible=False)
            else:
                ret_draw = Card(
                    value=draw_card[0],
                    suit=draw_card[1],
                    visible=True)
        np_field = np.array(ret_field)
        np_foundation = np.array(ret_foundation)
        np_draw = np.array(ret_draw)
        all_current_cards_ret = {
            "tableau": np_field,
            "foundation": np_foundation,
            "top_card": ret_draw
        }
        return all_current_cards_ret


    # Determine card locations
    all_current_cards = determine_card_locations(rank_matches, suit_matches)


    # Output results
    print("Detected Cards:")
    print(all_current_cards, sep='\n')

    return all_current_cards



In [15]:

  # Initialize a Console object from the rich library for printing with styles
console = Console()

class SolitaireEnv(gym.Env):
    def __init__(self):
        super(SolitaireEnv, self).__init__()
        # Action type: 0 (Move within tableau), 1 (Draw card), 2 (Move to foundation); Source column (0-11) and card index (0-18); Destination column (0-6)
        self.action_space = spaces.MultiDiscrete([3, 12, 18, 7])

        # Define observation space with structured tableau, foundation, and draw pile
        self.observation_space = spaces.Dict({
            'tableau': spaces.MultiDiscrete([2] * 54 * 7 * 18),  # Each card is one-hot (54), 7 columns, 18 cards max
            'foundation': spaces.MultiDiscrete([14, 14, 14, 14]),  # Foundation unchanged
            'top_card': spaces.MultiDiscrete([2] * 54),  # One-hot encoded top card
        })

        # Track revealed cards from the draw pile
        self.revealed_cards = []
        self.draw_index = 0  # Tracks current position in draw pile

        self.tableau = None
        self.foundation = None
        self.draw_pile = None
        self.draw_pile_cycles = 3
        self.done = False
        self.reward = 0
        self.colors = {
            0: "red",
            1: "red",
            2: "black",
            3: "black"
        }

        self.tries = 400
        self.king_value = 13
        self.ace_value = 1
        self.invisible_value = 14
        self.unexisting_value = 0
        self.invisible_suit = 5
        self.unexisting_suit = 4
        self._reset_game_state()

    def _reset_game_state(self):
        # Initialize the deck as a list of Card objects
        deck = [Card(value, suit) for suit in range(4) for value in range(self.ace_value, self.king_value+1)]
        random.shuffle(deck)

        # Initialize tableau with some cards face-down
        self.tableau = [[deck.pop() for _ in range(i + 1)] for i in range(7)]
        for col in self.tableau:
            for card in col[:-1]:
                card.visible = False  # Face-down
            col[-1].visible = True  # Top card face-up

        # Foundation starts empty
        self.foundation = [[] for _ in range(4)]
        self.draw_pile = deck  # Remaining cards in the draw pile
        self.done = False
        self.reward = 0
        self.revealed_cards = []
        self.draw_index = 0
        self.tries = 400
        
    def reset(self, seed=None):
        """Resets the environment to the initial state."""
        self._reset_game_state()
        return self._get_observation(), {}

    def is_valid_tableau_move(self, source_col, source_idx):
        """Custom logic to validate moves within tableau based on game rules."""
        # Example logic: check if there's a card at the specified column and index
        if source_col >= len(self.tableau):
            return False
        if source_idx >= len(self.tableau[source_col]):
            return False
        if not self.tableau[source_col][source_idx].visible:
            return False
        return True

    def is_valid_move_to_foundation(self, source_col):
        """Custom logic to validate moves to the foundation."""
        # Example logic: check if a card from the source column can move to the foundation
        if source_col >= len(self.tableau):
            return False
        if len(self.tableau[source_col]) == 0:
            return False
        return True

    def compute_action_mask(self):
      # Initialize mask with zeros (invalid actions by default)
      mask = np.zeros((3, 12, 18, 7), dtype=int)

      # Iterate over all possible actions
      for action_type in range(3):  # Action type: 0, 1, 2
          for source_col in range(12):  # Source column: 0-11
              for source_idx in range(18):  # Source index: 0-17
                  for dest_col in range(7):  # Destination column: 0-6
                      if action_type == 0:  # Move within tableau
                          if 0 <= source_col <= 6:
                              # Check if this source_col and source_idx are valid
                              if self.is_valid_tableau_move(source_col, source_idx):
                                  mask[0, source_col, source_idx, dest_col] = 1
                          elif 7 <= source_col <= 11:
                              # Ignore source_idx for these columns, but only allow source_idx = 0
                              if self.is_valid_tableau_move(source_col, 0):
                                  mask[0, source_col, 0, dest_col] = 1

                      elif action_type == 2:  # Move to foundation
                          # Only source_col is relevant; ignore source_idx and dest_col
                          if self.is_valid_move_to_foundation(source_col):
                              mask[2, source_col, 0, 0] = 1

      mask[1, 0, 0, 0] = 1   # only valid move for drawing card

      return mask.flatten()

    def card_to_one_hot(self, card):
          """
          Convert a card to a one-hot encoded array of length 54.
          """
          one_hot = [0] * 54
          if card.visible == True:  # Invisible card
              one_hot[53] = 1
          else:
              index = (card.value - 1) + (card.suit * 13)  # Calculate index for 52 cards
              one_hot[index] = 1
          return one_hot


    def _get_observation(self):
      # Convert tableau to one-hot encoding
      tableau_observation = [
          [self.card_to_one_hot(card) for card in column]
          for column in self.tableau
      ]

      # Pad tableau columns to a fixed length (18 cards)
      max_length = 18
      tableau_observation_padded = [
          column + [[0] * 54] * (max_length - len(column))
          for column in tableau_observation
      ]

      # Flatten the tableau for observation
      tableau_array = np.array(tableau_observation_padded).reshape(-1)

      # Foundation observation (unchanged)
      foundation_observation = [len(pile) for pile in self.foundation]

      # Top card observation
      if self.revealed_cards:
          top_card = self.revealed_cards[-1]
          top_card_observation = self.card_to_one_hot(top_card)
      else:
          top_card_observation = [0] * 54
          top_card_observation[52] = 1  # Non-existent card

      # Combine all parts of the observation
      result = {
          "tableau": tableau_array,
          "foundation": np.array(foundation_observation),
          "top_card": np.array(top_card_observation),
      }
      return result


    def step(self, action: list):
        # action: [int, list[int,int], int]
        # action_type: int
        # source: list[int,int] - [column, card in column]. Column numbers:
        # 0-6: tableau
        # 7-10: foundation
        # 11: draw pile
        # destination: int - column, no need  since all the cards will be moved on top of it

        action_type, source1, source2, destination = action
        current_reward = -1   # Base penalty for each action

        if action_type == 0:  # Move Card within Tableau
            valid_move_made, reward = self._move_within_tableau([source1, source2], destination)   # here reward should be negative
            current_reward += reward
            if not valid_move_made:
                #print("The move isn't valid")
                current_reward -= 10  # Extra penalty for invalid move

        elif action_type == 1:  # Draw Card from Draw Pile
            empty = self._draw_card()
            current_reward -= 70
            if empty:
              current_reward -= 100

        elif action_type == 2:  # Move Card to Foundation
            valid_move_made, reward = self._move_to_foundation(source1)
            current_reward += reward

        flipped_count = self._flip_visible_cards()
        current_reward += flipped_count * 1600

        # Check if game is won (all foundations complete)
        if all(len(foundation) == 13 for foundation in self.foundation):
            self.done = True
              
        self.reward += current_reward
        self.tries -= 1

        return self._get_observation(), current_reward, self.done, self.tries <= 0, {}

    def _draw_card(self):
        # Reveal 1 card at a time from the draw pile
        if self.draw_pile:
            card = self.draw_pile.pop()
            card.visible = True
            self.revealed_cards.append(card)
            return False
        else:
            self.draw_pile = self.revealed_cards[::-1]  # Restart the draw pile if we reach the end
            self.revealed_cards = []
            self.draw_pile_cycles-=1
            if self.draw_pile_cycles < 0:
                return True


    def _is_alternating_color(self, object1, object2):
        # Check if the objects have alternating colors
        return self.colors[object1] != self.colors[object2]


    # all return numbers after false are negative, and positive after true
    def _move_within_tableau(self, source: list[int], destination: int):
        if destination > 6 or destination < 0:
          #print("Wrong destination column")
          return False, -10
        # If the source is from the draw pile
        if source[0] == 11:
            if not self.revealed_cards:
                #print("Invalid move: No cards revealed in the draw pile")
                return False, -50 # No cards revealed in the draw pile
            # Use the last revealed card from the draw pile
            card_to_move = self.revealed_cards[-1]

            # Check if destination column is empty (only Kings can move to empty columns)
            if not self.tableau[destination]:
                if card_to_move.value == self.king_value: # King card value
                    self.tableau[destination].append(card_to_move)
                    self.revealed_cards.pop()  # Remove from revealed list
                    return True, 500
                else:
                    #print("Invalid move: Only Kings can move to an empty column")
                    return False, -60  # Only Kings can move to an empty column

            # Check if the move is valid based on the destination column's top card
            dest_card = self.tableau[destination][-1]
            if (card_to_move.value == dest_card.value - 1 and
                self._is_alternating_color(card_to_move.suit, dest_card.suit)):
                self.tableau[destination].append(card_to_move)
                self.revealed_cards.pop()  # Remove from revealed list
                return True, 500

            #print("Invalid move: Invalid move for draw pile card")
            return False, -50  # Invalid move for draw pile card

        if 7 <= source[0] <= 10:
          suit = source[0] - 7
          if self.foundation[suit] and self.tableau[destination]:
              card_to_move = self.foundation[suit][-1]
              dest_card = self.tableau[destination][-1]
              if (card_to_move.value == dest_card.value - 1 and
                  self._is_alternating_color(card_to_move.suit, dest_card.suit)):
                  self.tableau[destination].append(card_to_move)
                  self.foundation[suit].pop()  # Remove from revealed list
                  return True, 500
              return False, -50
          else:
            return False, -50

        if source[0] < 0:
            #print("Invalid move: Wrong column number")
            return False, -100
        if source[1] >= len(self.tableau[source[0]]):
            #print("Invalid move: Wrong card index number")
            return False, -100

        card_column = source[0]
        card_index = source[1]
        cards_to_move = self.tableau[card_column][card_index:]

        # Check if destination column is empty (only Kings can be moved to an empty column)
        if not self.tableau[destination]:
            if cards_to_move[0].value == self.king_value: # King card value
                # Move the sequence
                self.tableau[destination].extend(cards_to_move)
                del self.tableau[card_column][card_index:]
                return True, 500
            else:
                #print("Invalid move: Only Kings can be moved to an empty column")
                return False, -60  # Only Kings can be moved to an empty column

        # Check if the move is valid based on the destination column’s top card
        dest_card = self.tableau[destination][-1]
        if (cards_to_move[0].value == dest_card.value - 1 and
                self._is_alternating_color(cards_to_move[0].suit, dest_card.suit)):
            # Move the sequence
            self.tableau[destination].extend(cards_to_move)
            del self.tableau[card_column][card_index:]
            return True, 500

        #print("Invalid move: Invalid move within tableau")
        return False, -40  # Move was invalid


    def _move_to_foundation(self, source): # source is int, since we move the top card of source to top of foundation
        valid_reward = 130
        # If source is the draw pile (denoted by 11), take the last revealed card
        if source == 11:
            if not self.revealed_cards:
                #print("Invalid move: No revealed cards in draw pile")
                return False, -50  # No revealed cards in draw pile
            card = self.revealed_cards[-1]
            foundation_index = card.suit # Determine foundation based on suit

            if len(self.foundation[foundation_index]) == card.value - self.ace_value:  # -ace_value because card values are from 2 to 14
                # Move the card to the foundation and remove from revealed list
                self.foundation[foundation_index].append(card)
                self.revealed_cards.pop()
                if not card.bonus:
                    card.bonus = True
                return True, valid_reward

            #print("Invalid move: Invalid move to foundation from draw pile")
            return False, -40  # Invalid move

        # Validate source column
        if source < 0 or source > 6 or not self.tableau[source]:
            #print("Invalid move: No card to move")
            return False, -60  # Invalid move, no card to move

        # Get the top card from the source column
        card = self.tableau[source][-1]
        foundation_index = card.suit  # Determine foundation pile based on suit

        # Check if the card can move to the foundation (must be in ascending order)
        if len(self.foundation[foundation_index]) == card.value - self.ace_value:     # -ace_value because card values are from 2 to 14
            # Move the card to the foundation and remove from tableau
            self.foundation[foundation_index].append(self.tableau[source].pop())
            if not card.bonus:
                card.bonus = True
            return True, valid_reward

        #print("Invalid move: Invalid move to foundation from tableau")
        return False, -40  # Move was invalid


    def _flip_visible_cards(self):
        flipped_count = 0
        for column in self.tableau:
            if column and not column[-1].visible:  # If the top card is face-down
                column[-1].visible = True  # Flip it face-up
                flipped_count += 1
        return flipped_count


    def render(self, mode='human'):
        # Foundations
        foundation_str = []
        for pile in self.foundation:
            if pile:
                card = pile[-1]
                suit = ['[bold red] ♥[/bold red]', '[bold red] ♦[/bold red]', ' ♣', ' ♠'][card.suit]
                # Apply color red for red suits (Diamonds and Hearts)
                foundation_str.append(f"| {card.value} {suit} |" if card.value != self.invisible_value else f"| A {suit} |")
            else:
                foundation_str.append("|     |")  # Empty foundation pile

        # Print foundation row
        console.print("Foundations:", "  ".join(foundation_str))

        # Tableau - display all 7 columns in a single row
        tableau_str = []
        for col in self.tableau:
            tableau_str.append(" ".join([f"┌─────┐" if card.visible else "┌─────┐" for card in col]))  # Card tops
            tableau_str.append(" ".join([f"| {card.value if card.visible else ' ?'}{' ' if card.visible and len(str(card.value)) != 2 else ''}{['[bold red] ♥[/bold red]', '[bold red] ♦[/bold red]', ' ♣', ' ♠'][card.suit] if card.visible else '  '}|" for card in col]).replace(" 0 ", " A "))  # Card values
            tableau_str.append(" ".join([f"|     |" for _ in col]))  # Empty space for spacing between cards
            tableau_str.append(" ".join([f"└─────┘" for _ in col]))  # Card bottoms

        # Print tableau columns in one row
        tableau_str = '\n'.join(tableau_str)
        console.print("Tableau:\n" + tableau_str)

        # Draw pile (remaining count, last 3 revealed, discarded count)
        draw_pile_display = f"Draw Pile: {len(self.draw_pile)} cards remaining"

        # Last 3 revealed cards (if any)
        last_three = [f"|{card.value if card.visible else ' ?'}{' ' if card.visible and len(str(card.value)) != 2 else ''}{['[bold red] ♥[/bold red]', '[bold red] ♦[/bold red]', ' ♣', ' ♠'][card.suit] if card.visible else '  '}|" for card in self.revealed_cards[-3:]]
        last_three_display = f"\nLast 3 Drawn: {' '.join(last_three)}"

        # Discarded cards (number of discarded cards)

        # Print the full draw pile row
        console.print(draw_pile_display, last_three_display)

In [16]:
env = SolitaireEnv()
#check_env(env, warn=True)

print(env.render())

None


In [14]:
print(env.observation_space.sample())

{'foundation': array([9, 2, 7, 7], dtype=int64), 'tableau': array([0, 1, 1, ..., 1, 0, 1], dtype=int64), 'top_card': array([1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1,
       0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0,
       1, 1, 1, 1, 0, 1, 1, 1, 0, 0], dtype=int64)}


In [None]:
from IPython.display import clear_output
import sys
import json

def main():
    # Initialize the environment
    env = SolitaireEnv()
    done = False

    # Initialize dataset to collect state-action pairs
    dataset = []

    # Game loop
    while not done:
        # Render the current state
        clear_output(wait=True)
        print(f"\nReward: {env.reward}")
        env.render()

        # Get user input for action
        try:
            print("\nEnter your action:")
            input_str = input()

            # Split the input string into individual parts and convert them to integers
            action = list(map(int, input_str.split()))

            if action[0] == 5: # remove last element in dataset in case of user's(my) mistake
                dataset.pop()
                continue
            if action[0] == 9:
                break
        except ValueError:
            print("Invalid input! Please enter valid numbers.")
            continue

        # Get the current state (observation)
        current_state = env._get_observation()  # Modify this based on how the state is represented in your environment
        current_state_serializable = {
            key: value.tolist() if isinstance(value, np.ndarray) else value
            for key, value in current_state.items()
        }

        # Add the state and action to the dataset
        dataset.append({'state': current_state_serializable, 'action': action})

        # Take the action
        try:
            obs, reward, done, terminal, info = env.step(action)
            if done:
                print("Congratulations! You have completed the game!\n Your score: ", env.reward)
        except Exception as e:
            print(f"Error: {e}")
            sys.exit(1)

    # Save the dataset to a file after the game ends
    with open("gameplay_data1.json", "w") as f:
        json.dump(dataset, f, indent=4)
    print("Dataset saved!")

if __name__ == "__main__":
    main()


# Combined Pre-Training and Training

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim

class ComplexModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ComplexModel, self).__init__()
        
        # Increased depth and width
        self.fc1 = nn.Linear(input_dim, 512)   # First layer with 512 units
        self.fc2 = nn.Linear(512, 256)         # Second layer with 256 units
        self.fc3 = nn.Linear(256, 128)         # Third layer with 128 units
        self.fc4 = nn.Linear(128, 64)          # Fourth layer with 64 units
        self.fc5 = nn.Linear(64, 32)           # Fifth layer with 32 units
        self.fc6 = nn.Linear(32, output_dim)   # Output layer
        
        # Dropout layer to prevent overfitting
        self.dropout = nn.Dropout(p=0.5)  # 50% dropout rate

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # Apply ReLU activation after first layer
        x = self.dropout(x)  # Apply dropout

        x = torch.relu(self.fc2(x))  # Second layer
        x = self.dropout(x)

        x = torch.relu(self.fc3(x))  # Third layer
        x = self.dropout(x)

        x = torch.relu(self.fc4(x))  # Fourth layer
        x = self.dropout(x)

        x = torch.relu(self.fc5(x))  # Fifth layer
        x = self.dropout(x)

        return self.fc6(x)  # Output layer


In [19]:
import torch
import torch.optim as optim
import torch.nn as nn
import json

def decode_action(action_index):
    action_type = action_index // (12 * 18 * 7)
    remaining = action_index % (12 * 18 * 7)
    
    source_col = remaining // (18 * 7)
    remaining = remaining % (18 * 7)
    
    source_idx = remaining // 7
    dest_col = remaining % 7
    
    return [action_type, source_col, source_idx, dest_col]

# Hyperparameters
input_dim = 6862  # Adjust based on your state space
output_dim = 4536  # Adjust based on your action space

# Model, loss function, optimizer
model = ComplexModel(input_dim, output_dim)
criterion = nn.MSELoss()  # Use MSE loss for Q-value prediction
optimizer = optim.Adam(model.parameters(), lr=0.001)

with open('gameplay_data.json', 'r') as f:
    loaded_dataset = json.load(f)

with open('gameplay_data1.json', 'r') as f:
    loaded_dataset1 = json.load(f)


loaded_dataset.extend(loaded_dataset1)


# Loop over the dataset for pretraining
epochs = 15  # You can adjust the number of epochs
for epoch in range(epochs):
    total_loss = 0
    for data in loaded_dataset:
        state = data['state']
        action = data['action']
        
        # Convert state and action to tensors
        tableau_tensor = torch.tensor(state['tableau'], dtype=torch.float32)
        foundation_tensor = torch.tensor(state['foundation'], dtype=torch.float32)
        top_card_tensor = torch.tensor(state['top_card'], dtype=torch.float32)
        
        state_tensor = torch.cat((tableau_tensor, foundation_tensor, top_card_tensor), dim=0)  # Flatten the state
        
        # Convert action to a scalar index
        action_type, source_col, source_idx, dest_col = action
        action_index = action_type * (12 * 18 * 7) + source_col * (18 * 7) + source_idx * 7 + dest_col
        
        # Forward pass
        optimizer.zero_grad()  # Clear previous gradients
        output = model(state_tensor)  # Get model predictions

        # Create the target tensor with the same shape as the output (4536,)
        target_tensor = torch.zeros_like(output)  # Start with a tensor of zeros
        target_tensor[action_index] = 1.0  # Set the value of the taken action to 1.0 (or the reward)

        # Compute loss (use MSE)
        loss = criterion(output, target_tensor)  # MSE loss

        total_loss += loss.item()
        
        # Backpropagation
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(loaded_dataset)}")


# After pretraining, use the pretrained model in the RL loop:

epsilon = 0.1  # Exploration rate
gamma = 0.99   # Discount factor for rewards
num_episodes = 10000  # Number of episodes for RL training

# Example environment (replace with your own)
env = SolitaireEnv()

# Training loop for RL with pretraining
for episode in range(num_episodes):
    state_dict = env.reset()  # Reset environment to get initial state
    state = state_dict[0]
    done = False
    total_reward = 0
    while not done:
        # Convert state to tensor (flattened)
        state_tensor = torch.cat((
            torch.tensor(state['tableau'], dtype=torch.float32),
            torch.tensor(state['foundation'], dtype=torch.float32),
            torch.tensor(state['top_card'], dtype=torch.float32)
        ), dim=0)

        # Epsilon-greedy strategy for exploration vs exploitation
        if torch.rand(1).item() < epsilon:
            action = env.action_space.sample()  # Random action (exploration)
        else:
            with torch.no_grad():
                output = model(state_tensor)  # Get model predictions (Q-values)
                action = torch.argmax(output).item()  # Max action (exploitation)
                action = decode_action(action)  # Decode to the original action

        # Take action and observe the result
        next_state, reward, done, timer, _ = env.step(action)
        total_reward += reward

        if timer:
            break
        
        # Compute target Q-value (using Bellman equation)
        next_state_tensor = torch.cat((
            torch.tensor(next_state['tableau'], dtype=torch.float32),
            torch.tensor(next_state['foundation'], dtype=torch.float32),
            torch.tensor(next_state['top_card'], dtype=torch.float32)
        ), dim=0)

        with torch.no_grad():
            next_output = model(next_state_tensor)  # Predicted Q-values for next state
            target = reward + gamma * torch.max(next_output)  # Bellman equation

        # Compute the predicted Q-value for the taken action
        output = model(state_tensor)
        predicted_q_value = output[action]  # Extract the Q-value for the taken action

        # Compute loss (use MSE loss)
        loss = criterion(predicted_q_value.unsqueeze(0), torch.tensor([target], dtype=torch.float32))

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        state = next_state  # Move to next state

    print(f"Episode {episode+1}/{num_episodes}, Total Reward: {total_reward}")    


Epoch [1/15], Loss: 0.002432723241353239
Epoch [2/15], Loss: 0.0003096109760734502
Epoch [3/15], Loss: 0.00021695073375645104
Epoch [4/15], Loss: 0.0001952525792612767
Epoch [5/15], Loss: 0.00019031383753864386
Epoch [6/15], Loss: 0.0001887987701245296
Epoch [7/15], Loss: 0.00018815530009290833
Epoch [8/15], Loss: 0.00018788353840503344
Epoch [9/15], Loss: 0.0001877892846919485
Epoch [10/15], Loss: 0.00018773792646183351
Epoch [11/15], Loss: 0.00018772579950650334
Epoch [12/15], Loss: 0.0001877010264877065
Epoch [13/15], Loss: 0.00018769874908303107
Epoch [14/15], Loss: 0.00018769035631083643
Epoch [15/15], Loss: 0.00018769044467030523


  return F.mse_loss(input, target, reduction=self.reduction)


Episode 1/10000, Total Reward: -22720
Episode 2/10000, Total Reward: -20630
Episode 3/10000, Total Reward: -22710
Episode 4/10000, Total Reward: -26620
Episode 5/10000, Total Reward: -24760
Episode 6/10000, Total Reward: -27510
Episode 7/10000, Total Reward: -26240
Episode 8/10000, Total Reward: -27150
Episode 9/10000, Total Reward: -24650
Episode 10/10000, Total Reward: -24370
Episode 11/10000, Total Reward: -24940
Episode 12/10000, Total Reward: -27890
Episode 13/10000, Total Reward: -25740
Episode 14/10000, Total Reward: -24560
Episode 15/10000, Total Reward: -27570
Episode 16/10000, Total Reward: -19610
Episode 17/10000, Total Reward: -25970
Episode 18/10000, Total Reward: -25980
Episode 19/10000, Total Reward: -23690
Episode 20/10000, Total Reward: -27290
Episode 21/10000, Total Reward: -26660
Episode 22/10000, Total Reward: -26660
Episode 23/10000, Total Reward: -26620
Episode 24/10000, Total Reward: -23430
Episode 25/10000, Total Reward: -23300
Episode 26/10000, Total Reward: -2

KeyboardInterrupt: 

Code to check the model in action

In [None]:
state_dict = env.reset()
state = state_dict[0]
n_steps = 20
for step in range(n_steps):
    state_tensor = torch.cat((
            torch.tensor(state['tableau'], dtype=torch.float32),
            torch.tensor(state['foundation'], dtype=torch.float32),
            torch.tensor(state['top_card'], dtype=torch.float32)
        ), dim=0)
    
    with torch.no_grad():
                output = model(state_tensor)  # Get model predictions (Q-values)
                action = torch.argmax(output).item()  # Max action (exploitation)
                action = decode_action(action)  # Decode to the original action

    print(f"Step {step + 1}")
    print("Action: ", action)
    obs, reward, done, truncated, info = env.step(action)
    #print("obs=", obs, "reward=", env.reward, "done=", done)
    print("reward= ", env.reward)
    env.render()