In [None]:
import gym
import numpy as np
import itertools
from gym import spaces
import itertools

class StickGameEnv(gym.Env):
    """Stick-flipping game exposed through the classic ``gym`` API.

    The board is a 12-element 0/1 vector: ``state[i] == 1`` means the stick
    at index ``i`` (stick number ``i + 1``) is still standing. Each discrete
    action id maps, through ``all_combinations``, to a tuple of stick
    *indices* to knock down.

    An episode ends with reward 1 when every stick is down, or with reward 0
    as soon as an action names a stick that is already down.
    """

    def __init__(self):
        super(StickGameEnv, self).__init__()
        # Action table: action id -> tuple of stick indices.
        self.all_combinations = generate_all_combinations()
        self.action_space = spaces.Discrete(len(self.all_combinations))
        self.observation_space = spaces.MultiBinary(12)
        self.state = None
        self.reset()

    def reset(self):
        """Stand all 12 sticks back up and return the initial observation."""
        self.state = np.ones(12, dtype=int)
        # Return a copy so callers never hold a reference to the internal
        # array, which ``step`` mutates in place.
        return self.state.copy()

    def step(self, action):
        """Apply one action.

        Args:
            action: Integer-like action id, a key of ``all_combinations``.

        Returns:
            ``(observation, reward, done, info)``. The observation is a copy
            of the board; ``reward`` is 1 only on the step that clears it.
        """
        action = int(action)
        combination = self.all_combinations[action]
        if not self._is_valid_action(combination):
            # Invalid action: the game ends immediately without reward.
            return self.state.copy(), 0, True, {}

        for stick in combination:
            self.state[stick] = 0

        done = bool(np.all(self.state == 0))
        reward = 1 if done else 0

        return self.state.copy(), reward, done, {}

    def _is_valid_action(self, combination):
        """Return True when every stick index in ``combination`` is still standing."""
        return all(self.state[stick] == 1 for stick in combination)

    def valid_actions(self, dice_roll=None):
        """List the stick subsets that are standing and sum to ``dice_roll``.

        Args:
            dice_roll: Target sum of stick numbers (1-based). When omitted, a
                uniform random integer in ``[2, 12]`` is drawn, which is the
                historical behaviour of this method.

        Returns:
            A list of 12-element 0/1 arrays, one per matching subset, with a
            1 at each chosen stick's index.
        """
        if dice_roll is None:
            dice_roll = np.random.randint(2, 13)
        # Work with 1-based stick numbers so the subset sum is the face value.
        sticks = [i + 1 for i, present in enumerate(self.state) if present == 1]

        valid = []
        for size in range(1, len(sticks) + 1):
            for subset in itertools.combinations(sticks, size):
                if sum(subset) == dice_roll:
                    # Subsets are drawn from standing sticks only, so no
                    # further check against the board is required.
                    subset_action = np.zeros(12, dtype=int)
                    subset_action[[stick - 1 for stick in subset]] = 1
                    valid.append(subset_action)
        return valid

def generate_all_combinations():
    """Build the action table for the 0-based stick indexing scheme.

    For every target total from 2 to 12, collects each set of distinct stick
    indices (drawn from ``range(12)``) whose stick numbers (index + 1) add up
    to that total. Entries are ordered by target, then by subset size, then
    in ``itertools.combinations`` order.

    Returns:
        dict mapping consecutive action ids (starting at 0) to tuples of
        stick indices.
    """
    stick_indices = range(12)
    matches = []

    for target in range(2, 13):
        for size in range(1, target + 1):
            # index + 1 is the stick's number, so the numbers sum to
            # sum(indices) + size.
            matches.extend(
                combo
                for combo in itertools.combinations(stick_indices, size)
                if sum(combo) + size == target
            )

    return dict(enumerate(matches))

# Build the action table (action id -> tuple of stick indices) once at module level.
# NOTE(review): StickGameEnv.__init__ calls generate_all_combinations() itself and no
# visible cell reads this global — confirm it is still needed.
all_combinations = generate_all_combinations()


In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env

# Create the vectorized environment. The class is already a zero-argument
# factory, so it can be passed directly instead of being wrapped in a lambda.
env = make_vec_env(StickGameEnv, n_envs=1)

# Initialize the DQN model with the default MLP policy
model = DQN("MlpPolicy", env, verbose=1)

# Train the model. ``total_timesteps`` is documented as an int, so use an
# integer literal rather than the float ``1e6``.
model.learn(total_timesteps=1_000_000)


In [None]:
def evaluate_model(model, num_episodes=100, env=None, max_steps_per_episode=1000):
    """Estimate how often ``model`` wins the game.

    Plays ``num_episodes`` episodes using ``model.predict(obs,
    deterministic=True)`` and counts an episode as a win when it finishes
    with a final reward of 1.

    Args:
        model: Object exposing ``predict(obs, deterministic=True)`` that
            returns ``(action, state)``.
        num_episodes: Number of episodes to play. Values <= 0 return 0.0
            instead of dividing by zero.
        env: Environment to evaluate on. Defaults to a fresh
            ``StickGameEnv()``. Both the 4-tuple (gym) and 5-tuple
            (gymnasium) ``step`` conventions are accepted, as is a
            ``reset`` that returns ``(obs, info)``.
        max_steps_per_episode: Safety cap on steps per episode. A
            deterministic policy that keeps repeating a rejected action on
            an environment that does not terminate would otherwise loop
            forever; such an episode is counted as a loss.

    Returns:
        Fraction of episodes won, as a float in ``[0, 1]``.
    """
    if num_episodes <= 0:
        return 0.0
    if env is None:
        env = StickGameEnv()

    win_count = 0

    for _ in range(num_episodes):
        reset_result = env.reset()
        # gymnasium-style reset returns (obs, info); classic gym returns obs.
        obs = reset_result[0] if isinstance(reset_result, tuple) else reset_result

        reward, done = 0, False
        for _step in range(max_steps_per_episode):
            action, _states = model.predict(obs, deterministic=True)
            step_result = env.step(action)
            if len(step_result) == 5:
                obs, reward, terminated, truncated, _info = step_result
                done = terminated or truncated
            else:
                obs, reward, done, _info = step_result
            if done:
                break

        if done and reward == 1:
            win_count += 1

    return win_count / num_episodes

# Evaluate the trained model over 10,000 episodes and report the fraction won.
# NOTE(review): evaluate_model predicts with deterministic=True; if the environment
# in scope is itself deterministic, every episode plays out identically — confirm
# that this many episodes is actually needed.
win_rate = evaluate_model(model, num_episodes=10000)
print("Win Rate:", win_rate)


In [11]:
import gymnasium as gym
import numpy as np
import itertools
from gymnasium import spaces

class StickGameEnv(gym.Env):
    """Dice-and-sticks game exposed through the Gymnasium API.

    Twelve sticks numbered 1..12 start standing. Each turn two dice are
    rolled and the agent must knock down a set of standing sticks whose
    numbers add up to the roll.

    Observation: 13 integers — the 12 stick flags (1 = standing) followed by
    the current dice roll.

    Action: an id into ``all_combinations``, each entry being a tuple of
    1-based stick numbers.

    Rewards and termination:
        * +1 and ``terminated`` when the last stick goes down.
        * -1 (episode continues, same roll) when the chosen combination does
          not sum to the roll or names a stick that is already down.
        * 0 and ``terminated`` when, after a new roll, no legal move exists
          (the game is lost); without this the episode could never end.
    """

    def __init__(self):
        super(StickGameEnv, self).__init__()
        self.all_combinations = generate_all_combinations()
        self.action_space = spaces.Discrete(len(self.all_combinations))

        # 12 stick flags (0 or 1) plus one slot for the dice roll (2..12).
        self.observation_space = spaces.Box(low=0, high=12, shape=(13,), dtype=int)

        self.state = None
        self.dice_roll = None
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new game.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that
                ``self.np_random`` (used for the dice) is reproducible.
            options: Accepted for Gymnasium API compatibility; unused.

        Returns:
            ``(observation, info)`` with an empty info dict.
        """
        super().reset(seed=seed)
        self.state = np.ones(12, dtype=int)  # all sticks standing
        self.dice_roll = self._roll_dice()   # initial dice roll
        return self._observation(), {}

    def step(self, action):
        """Apply one action and roll the dice for the next turn.

        Args:
            action: Integer-like action id, a key of ``all_combinations``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
        """
        action = int(action)
        combination = self.all_combinations[action]

        if not self._is_valid_action(combination):
            # Penalize the illegal move; board and roll stay unchanged.
            return self._observation(), -1, False, False, {}

        for stick in combination:
            self.state[stick - 1] = 0

        cleared = bool(np.all(self.state == 0))
        reward = 1 if cleared else 0

        # Roll the dice for the next state.
        self.dice_roll = self._roll_dice()
        # The game also ends (as a loss) when the new roll cannot be played.
        terminated = cleared or not self._has_legal_move()
        truncated = False
        return self._observation(), reward, terminated, truncated, {}

    def _observation(self):
        """Return a fresh array of the 12 stick flags followed by the dice roll."""
        return np.append(self.state, self.dice_roll)

    def _roll_dice(self):
        """Roll two six-sided dice with the environment's seeded RNG and return the sum."""
        return int(self.np_random.integers(1, 7) + self.np_random.integers(1, 7))

    def _is_valid_action(self, combination):
        """Return True when ``combination`` sums to the current roll and all its sticks stand."""
        return sum(combination) == self.dice_roll and all(
            self.state[stick - 1] == 1 for stick in combination
        )

    def _has_legal_move(self):
        """Return True when at least one action is legal for the current board and roll."""
        return any(
            self._is_valid_action(combination)
            for combination in self.all_combinations.values()
        )

    def valid_actions(self, dice_roll=None):
        """List the standing-stick subsets that sum to ``dice_roll``.

        Args:
            dice_roll: Target sum; defaults to the environment's current roll.

        Returns:
            A list of 12-element 0/1 arrays, one per matching subset, with a
            1 at index ``stick - 1`` for each chosen stick.
        """
        if dice_roll is None:
            dice_roll = self.dice_roll
        sticks = [i + 1 for i, present in enumerate(self.state) if present == 1]

        valid = []
        for size in range(1, len(sticks) + 1):
            for subset in itertools.combinations(sticks, size):
                if sum(subset) == dice_roll:
                    # Subsets are drawn from standing sticks only, so no
                    # further check against the board is required.
                    subset_action = np.zeros(12, dtype=int)
                    subset_action[[stick - 1 for stick in subset]] = 1
                    valid.append(subset_action)
        return valid

def generate_all_combinations():
    """Build the action table using 1-based stick numbers.

    For each dice total from 2 to 12, gathers every set of distinct stick
    numbers from 1..12 that adds up to that total. Action ids are assigned
    consecutively from 0 in the order: total, then subset size, then
    ``itertools.combinations`` order.

    Returns:
        dict mapping action id to a tuple of stick numbers.
    """
    stick_numbers = range(1, 13)
    action_table = {}

    for total in range(2, 13):
        largest_size = min(total, 12)
        for size in range(1, largest_size + 1):
            for candidate in itertools.combinations(stick_numbers, size):
                if sum(candidate) != total:
                    continue
                # The next free id is always the current table size.
                action_table[len(action_table)] = candidate

    return action_table

In [12]:
from stable_baselines3.common.env_checker import check_env
# Run Stable-Baselines3's environment checker on a fresh instance.
# StickGameEnv resolves to whichever definition was executed most recently in the
# kernel (the gymnasium-based class when the cells are run in order).
check_env(StickGameEnv())