In [1]:
from gflownet.envs.scrabble import Scrabble
env = Scrabble()

print(env.action_space)
print(env.state)

[(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (15,), (16,), (17,), (18,), (19,), (20,), (21,), (22,), (23,), (24,), (25,), (26,), (-1,)]
[0, 0, 0, 0, 0, 0, 0]


## Testing GFlowNet on Frozen-Lake

### My Code:-

In [27]:
import torch
import random
import numpy as np
import torch.nn as nn
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
from gymnasium import ActionWrapper, RewardWrapper

In [None]:
env = gym.make('FrozenLake-v1', is_slippery=False)

<TimeLimit<OrderEnforcing<PassiveEnvChecker<FrozenLakeEnv<FrozenLake-v1>>>>>

In [28]:
class ModifiedStep(gym.Wrapper):
    def __init__(self, env):
        super().__init__(env)
        self.grid_size = int(np.sqrt(env.observation_space.n))
        self.goal_state = self._find_goal_state()
        self.goal_pos = self._state_to_pos(self.goal_state)

    def _state_to_pos(self, state):
        return (state // self.grid_size, state % self.grid_size)

    def _find_goal_state(self):
        desc = self.env.unwrapped.desc  # shape (4, 4), values like b'S', b'F', b'G'
        for r in range(desc.shape[0]):
            for c in range(desc.shape[1]):
                if desc[r][c] == b'G':
                    return r * self.grid_size + c
        raise ValueError("Goal state not found")

    def step(self, action):
        prev_state = self.env.unwrapped.s
        prev_pos = self._state_to_pos(prev_state)
        prev_dist = self._manhattan(prev_pos, self.goal_pos)

        next_state, reward, terminated, truncated, info = self.env.step(action)
        next_pos = self._state_to_pos(next_state)
        next_dist = self._manhattan(next_pos, self.goal_pos)

        if reward == 1.0:
            return next_state, 1.0, terminated, truncated, info
        elif next_dist < prev_dist:
            return next_state, 0.1, terminated, truncated, info
        else:
            return next_state, 0.0, terminated, truncated, info

    def _manhattan(self, a, b):
        return abs(a[0] - b[0]) + abs(a[1] - b[1])

In [29]:
env = ModifiedStep(env)

In [34]:
state, _ = env.reset()

for _ in range(10):
    action = env.action_space.sample()
    act_dict = {0:'RIGHT', 1:'DOWN', 2:'LEFT', 3:'UP'}
    print(act_dict[action], env.step(action))

RIGHT (0, 0.0, False, False, {'prob': 1.0})
LEFT (1, 0.1, False, False, {'prob': 1.0})
LEFT (2, 0.1, False, False, {'prob': 1.0})
LEFT (3, 0.1, False, False, {'prob': 1.0})
DOWN (7, 0.1, True, False, {'prob': 1.0})
RIGHT (7, 0.0, True, False, {'prob': 1.0})
DOWN (7, 0.0, True, False, {'prob': 1.0})
UP (7, 0.0, True, False, {'prob': 1.0})
DOWN (7, 0.0, True, False, {'prob': 1.0})
DOWN (7, 0.0, True, False, {'prob': 1.0})


### Simplified Version:-

In [35]:
import gymnasium as gym

class SimpleFrozenLake:
    def __init__(self, map_name="4x4", is_slippery=False):
        # Create the FrozenLake environment
        self.env = gym.make("FrozenLake-v1", map_name=map_name, is_slippery=is_slippery)
        self.eos = 4  # Special "End Of Sequence" action
        self.action_space = [0, 1, 2, 3, self.eos]  # 4 directions + EOS
        self.state = None
        self.done = False
        self.n_actions = 0

    def reset(self):
        self.state, _ = self.env.reset()
        self.state = int(self.state)
        self.done = False
        self.n_actions = 0
        return self.state

    def step(self, action):
        if self.done:
            return self.state, action, True

        if action == self.eos:
            self.done = True
            self.n_actions += 1
            return self.state, action, True

        next_state, reward, terminated, truncated, _ = self.env.step(action)
        self.state = int(next_state)
        self.n_actions += 1
        self.done = terminated or truncated
        return self.state, action, True

    def get_action_space(self):
        return self.action_space

    def state_to_coordinates(self, state=None):
        if state is None:
            state = self.state
        row, col = divmod(state, 4)
        return f"({row}, {col})"


In [None]:
env = SimpleFrozenLake()

### GPT Code:-

In [36]:
from typing import List, Tuple, Optional, Union

import gymnasium as gym
import torch
from torchtyping import TensorType

from gflownet.envs.base import GFlowNetEnv
from gflownet.utils.common import tlong, tfloat

class FrozenLake(GFlowNetEnv):
    def __init__(self,
                 map_name: str = "4x4",
                 is_slippery: bool = False,
                 **kwargs):
        # underlying Gym env
        self.env = gym.make("FrozenLake-v1",
                            map_name=map_name,
                            is_slippery=is_slippery)
        # EOS pseudo-action
        self.eos = 4
        # call parent ctor (sets device, logger helpers, etc.)
        super().__init__(**kwargs)

    # ───────────── required API ───────────── #
    def get_action_space(self):
        # 4 navigation moves + EOS
        return [0, 1, 2, 3, self.eos]

    def reset(self):
        self.state, _ = self.env.reset()
        self.state = int(self.state)          # make it a plain int
        self.done = False
        self.n_actions = 0
        return self.state

    def step(self, action: Tuple[int], skip_mask_check: bool = False) -> Tuple[int, int, bool]:
        do_step, self.state, action = self._pre_step(
            action, skip_mask_check or self.skip_mask_check)
        if not do_step:                              # mask said “invalid”
            return self.state, action, False

        # EOS action: mark done but leave state unchanged
        if action == self.eos:
            self.done = True
            self.n_actions += 1
            return self.state, action, True

        # Forward step in Gym env
        next_state, reward, terminated, truncated, _ = self.env.step(action)
        self.state = int(next_state)
        self.n_actions += 1
        self.done = terminated or truncated
        # Auto-inject EOS once the lake episode finishes
        if self.done:
            action = self.eos
        return self.state, action, True

    # ───────────── masks ───────────── #
    def get_mask_invalid_actions_forward(self,
                                         state: Optional[int] = None,
                                         done: Optional[bool] = None):
        """Allow EOS only after termination."""
        done = self._get_done(done)
        if done:                   # only EOS valid
            return [a != self.eos for a in self.action_space]
        else:                      # navigation moves valid, EOS invalid
            return [a == self.eos for a in self.action_space]

    def get_parents(self, state=None, done=None, action=None):
        # For Trajectory-Balance objectives we can get by with a dummy
        # implementation that says “parent is self + EOS”.
        return [state], [self.eos]

    # ───────────── encodings ───────────── #
    def states2policy(self,
                      states: Union[List[int],
                                    TensorType["batch"]]):     # noqa: F821
        """One-hot encode 16 discrete tiles."""
        states = tlong(states, device=self.device)
        n = states.shape[0]
        out = torch.zeros(n, 16, dtype=self.float, device=self.device)
        out[torch.arange(n, device=self.device), states] = 1.
        return out

    def states2proxy(self, states):
        # No extra features – just reuse one-hot
        return self.states2policy(states)

    def state2readable(self, state=None, alphabet=None):
        s = self._get_state(state)
        r, c = divmod(s, 4)
        return f"({r},{c})"