## Environment 

We need to create our own environment.

It should expose the following basic methods:

- `env.reset()`
- `env.step(action)` -> observation, reward, terminated
- `env.compute_reward()`
- `env.compute_termination()`
- `env.compute_next_state()`



In [2]:
import torch 

import numpy as np 

In [1]:
from src.mechanism.utils import load_puzzle, generate_state_from_moves
from src.mechanism.reduce import iterate_reduce_sequence

In [3]:
class Sampler:
    """Base class for integer samplers.

    Subclasses override :meth:`sample` to draw values from a concrete
    distribution.
    """

    def __init__(self) -> None:
        pass

    def sample(self, size=None):
        """Draw samples of the requested size.

        Args:
            size: requested output size; its interpretation is left to the
                subclass.

        Raises:
            NotImplementedError: always; the base class has no distribution.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement sample()"
        )


class Uniform_sampler(Sampler):
    """Sampler drawing integers uniformly with ``torch.randint``."""

    def __init__(self, low, high) -> None:
        """Store the bounds forwarded to ``torch.randint``.

        Args:
            low: lower bound (inclusive, per ``torch.randint``).
            high: upper bound (exclusive, per ``torch.randint``).
        """
        super().__init__()
        self.low = low
        self.high = high

    def sample(self, size) -> torch.Tensor:
        """Return a tensor of integers drawn from ``[low, high)``.

        Args:
            size: output shape as a tuple/list of ints; a bare int is
                accepted as shorthand for a 1-D shape.

        Returns:
            An integer tensor with the requested shape.
        """
        # torch.randint only accepts a sequence for `size`, so wrap a bare int.
        if isinstance(size, int):
            size = (size,)
        return torch.randint(self.low, self.high, size)

In [4]:
# Configuration consumed by PuzzleEnv.
env_config = dict(
    puzzle_name="cube_2x2x2",
    num_envs=100,
    max_steps=100,
    # "shuffle_range" is unpacked as the (low, high) arguments of the sampler.
    reset_config=dict(
        sampler="uniform",
        shuffle_range=[1, 10],
    ),
    reward_config=dict(
        success=50,
        time=-1,
    ),
)

In [10]:
class PuzzleEnv:
    """Batch of ``num_envs`` puzzle states that are shuffled on reset.

    The constructor loads the puzzle definition, reads the configuration and
    immediately calls :meth:`reset`. The transition, reward and termination
    hooks (``compute_*``) are still stubs.
    """

    def __init__(self, env_config):
        """Build the environment from a configuration dict.

        Args:
            env_config: mapping with the keys ``"puzzle_name"``,
                ``"num_envs"``, ``"max_steps"`` and ``"reset_config"`` (which
                must contain ``"shuffle_range"``); ``"reward_config"`` is
                optional.
        """
        self._load_puzzle(env_config["puzzle_name"])
        self._load_config(env_config)

        # for exporting purposes (reset() also reads it, so set it first)
        # self.config = env_config
        self.puzzle_name = env_config["puzzle_name"]

        self.reset()

    def _load_puzzle(self, puzzle_name):
        """Load the move table and solved state of ``puzzle_name``."""
        self.move_dict, self.final_state = load_puzzle(
            puzzle_name, puzzle_dir="./puzzles"
        )
        self.state_size = len(self.final_state)

        # We just want to identify a move by an index: position i of
        # `action_names` and of `swaps` refer to the same entry of `move_dict`.
        self.action_names = np.array(list(self.move_dict))
        self.swaps = list(self.move_dict.values())
        self.num_actions = len(self.swaps)

    def _load_config(self, config):
        """Copy the settings used by the environment out of ``config``."""
        self.num_envs = config["num_envs"]
        self.max_steps = config["max_steps"]

        self.reset_config = config["reset_config"]
        # "shuffle_range" is unpacked as the sampler's (low, high) bounds.
        self.sampler = Uniform_sampler(*self.reset_config["shuffle_range"])

        # Kept so the reward hook can read it; optional for older configs.
        self.reward_config = config.get("reward_config", {})

    def step(self, actions):
        """Advance every environment by one action.

        Args:
            actions: forwarded unchanged to :meth:`compute_next_state`.

        Returns:
            A ``(observation, reward, terminated)`` tuple: the current
            ``self.states`` plus whatever :meth:`compute_reward` and
            :meth:`compute_termination` return (``None`` while they are
            stubs).
        """
        self.compute_next_state(actions)
        reward = self.compute_reward()
        terminated = self.compute_termination()
        return self.states, reward, terminated

    def reset(self):
        """Reinitialize & shuffle every state"""
        self.states = torch.empty(
            (self.num_envs, self.state_size), dtype=torch.float32
        )
        # Ground-truth shuffle moves, one independent entry per environment
        # (built by appending so no two slots ever share a list object).
        self.gt_moves = []
        shuffle_lengths = self.sampler.sample((self.num_envs,)).tolist()

        for i, n in enumerate(shuffle_lengths):
            # sample n moves
            non_reduced_moves = np.random.choice(self.action_names, n)
            # reduce moves
            reduced_moves = iterate_reduce_sequence(
                non_reduced_moves, self.puzzle_name
            )
            # generate state from move
            state = generate_state_from_moves(
                reduced_moves, self.move_dict, self.final_state
            )
            self.states[i, :] = torch.tensor(state, dtype=torch.float32)
            self.gt_moves.append(reduced_moves)

    def compute_next_state(self, actions):
        """Apply ``actions`` to the states (not implemented yet)."""
        pass

    def compute_reward(self):
        """Compute the per-step reward (not implemented yet)."""
        pass

    def compute_termination(self):
        """Compute the termination flags (not implemented yet)."""
        pass


# Constructing the environment also runs reset(), so `env.states` is populated.
env = PuzzleEnv(env_config=env_config)

In [24]:
# Scratch check: indexing a NumPy array with a list of indices returns one
# element per index, repeated indices included.
np.array(['a', 'b', 'c'])[[0, 0, 0]]

array(['a', 'a', 'a'], dtype='<U1')

In [14]:
# Scratch check: a trailing comma makes a one-element tuple, the form used for
# the `size` argument passed to the sampler in reset().
type((1,))

tuple