# How to write a new scikit-decide solver: depth-first search

In this tutorial, we detail how to implement a new scikit-decide solver.
To keep things simple, we implement a basic depth-first search that stops as soon as a goal is found.
We will apply it to the maze domain.

Defining a solver is a matter of:
- defining the characteristics from the domain needed by the solver
- selecting the necessary solver characteristics
- auto-generating the code skeleton from the combination above (with all abstract methods needed)
- filling the code as needed

The first steps can be accomplished via the [code generator](https://airbus.github.io/scikit-decide/codegen/) available in the scikit-decide online documentation.
The last step is where you actually need to write something (namely the solver logic).


<div class="alert alert-block alert-warning">

**Disclaimer:** 
The chosen solver is a simple one, used only to showcase *in a pedagogical way* how to implement a new scikit-decide solver.
It is not suited to large domains.
For a more realistic solver, one could use for instance a [greedy best-first search](https://en.wikipedia.org/wiki/Best-first_search)
that uses a heuristic to guide the search, or one of the solvers available in the scikit-decide hub such as [A*](https://airbus.github.io/scikit-decide/reference/_skdecide.hub.solver.astar.astar.html#astar), as presented in the [tutorial dedicated to the maze domain](https://colab.research.google.com/github/airbus/scikit-decide/blob/master/notebooks/11_domain_tuto.ipynb).

</div>


Let's dive in.

## Define the characteristics from the domain needed by the solver


We want to specify the domains that will be compatible with the solver. 
To achieve it, we select the most generic characteristics that the solver can handle.

In our case, we want a domain with the following characteristics:
- deterministic transitions
- deterministic initial state
- full observability
- Markovian dynamics
- a single agent
- enumerable actions

All of these except "enumerable actions" are covered by the template `DeterministicPlanningDomain`.
For the latter, we will see how to enforce it via a specific method during the implementation.

*Translation on the [code generator](https://airbus.github.io/scikit-decide/codegen/) page:*

- Click on "Edit" in "Domain specification"

<img src="./solver-tuto/code-generator-solver-1.png" style="height:15em;">

- Choose the template `DeterministicPlanningDomain`

<img src="./solver-tuto/code-generator-solver-2.png" style="height:25em;">






## Select the solver characteristics

Our solver will be deterministic, so we can choose the `DeterministicPolicySolver` template.
Since a depth-first search can a priori be launched from any state of the domain, we can also add the characteristic `FromAnyState`.

Note that we could also make the solver restorable by implementing a save/load of the graph traversal. In that case we should add the `Restorable` characteristic.
To keep things simple, we do not do so here.


*Translation on the [code generator](https://airbus.github.io/scikit-decide/codegen/) page:*

- Toggle the button "Create Solver".

<img src="./solver-tuto/code-generator-solver-3.png" style="height:15em;">

- Click on "Edit" in "Solver specification"
- Choose the template `DeterministicPolicySolver`
- Update it by clicking on `FromAnyState`

<img src="./solver-tuto/code-generator-solver-4.png" style="height:25em;">

## Generate the skeleton

We just have to click on the "Copy code" button and paste the result:

```python

from typing import *

from skdecide import *
from skdecide.builders.domain import *
from skdecide.builders.solver import *


class D(DeterministicPlanningDomain):
    pass


class MySolver(DeterministicPolicySolver, FromAnyState):
    T_domain = D

    
    def _solve_from(self, memory: D.T_state) -> None:
        pass
    
    def _get_next_action(self, observation: D.T_observation) -> D.T_event:
        pass
    
    def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
        pass
    

```



As we want to enforce an enumerable action space for the domain, we also add to the skeleton the method `_check_domain_additional()`:
```python
    @classmethod
    def _check_domain_additional(cls, domain: Domain) -> bool:
        pass
```

<div class="alert alert-block alert-info">

**Note:**

In scikit-decide, the methods to be implemented by a domain or solver developer are prefixed with `_`.
Conversely, the user of a domain or a solver should call the methods that are not prefixed with `_`.

For instance:

- As a solver *developer*, we *implement* `_solve_from()`.
- As a solver *user*, we *call* `solve()`, as we will see later.

</div>
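
The convention can be illustrated in isolation with a toy example. This is only a sketch of the pattern, not scikit-decide's actual base classes: `ToySolverBase` and `CountdownSolver` are hypothetical names introduced for illustration.

```python
class ToySolverBase:
    """Hypothetical stand-in for a library base class (not scikit-decide code)."""

    def solve(self, start):
        # Public entry point, called by users; it delegates to the private hook.
        return self._solve_from(start)

    def _solve_from(self, start):
        # Private hook, to be implemented by solver developers in subclasses.
        raise NotImplementedError


class CountdownSolver(ToySolverBase):
    def _solve_from(self, start):
        # Developer-side logic: a trivial "plan" counting down to zero.
        return list(range(start, -1, -1))


# User-side code only ever calls the public method.
plan = CountdownSolver().solve(3)
print(plan)
```

Keeping the public method in the base class lets a library add common behavior (checks, casting, bookkeeping) around the developer's code without the developer having to repeat it.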

## Implement the solver

Now the real work starts.

The depth-first search algorithm is quite simple. We see the possible states of the domain as nodes of a graph,
whose edges correspond to the actions applicable in each state.
Starting from the given state, we perform a depth-first search that stops when reaching a goal.

More precisely, we successively apply the first available action to the domain until one of the following happens:
- we reach a goal: we stop the search (no optimization of the cost),
- we reach a dead end: we roll back to the previous state and choose the next available action,
- we exhaust the actions available at the current state: we roll back to the previous state and choose the next available action,
- we cannot roll back any further (the actions available from the initial state are exhausted): the search fails.
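
The loop described above can be sketched without scikit-decide on a toy graph. This is a minimal illustration under stated assumptions, not the solver itself: the `GRAPH` dictionary, the state names and the `dfs_plan` helper are all hypothetical, and actions are simply labels on edges.

```python
# Toy deterministic graph (hypothetical): state -> {action: next_state}.
GRAPH = {
    "A": {"right": "C", "left": "B"},
    "B": {"down": "D"},
    "C": {"down": "E"},
    "D": {},  # dead end
    "E": {"down": "G"},
    "G": {},  # goal
}


def dfs_plan(graph, start, goal, max_depth=1000):
    """Return a list of (state, action) pairs from start to goal, or None."""
    if start == goal:
        return []
    # Stack of (state, action) pairs still to try; the last one pushed is tried first.
    stack = [(start, action) for action in graph[start]]
    visited = {start}
    current = start
    plan = []
    while stack and len(plan) < max_depth:
        state, action = stack.pop()
        # Roll back until the plan ends at the state the popped action starts from.
        while state != current:
            current, _ = plan.pop()
        plan.append((state, action))
        current = graph[state][action]
        if current in visited:
            # Loop in the graph: drop the move and stay where we were.
            plan.pop()
            current = state
            continue
        visited.add(current)
        if current == goal:
            return plan
        # In a dead end nothing is added, so the next pop comes from an earlier state.
        stack.extend((current, a) for a in graph[current])
    return None


plan = dfs_plan(GRAPH, "A", "G")
print(plan)  # [('A', 'right'), ('C', 'down'), ('E', 'down')]
```

Because `"left"` is listed last for state `"A"`, it is popped first: the search walks into the dead end `"D"`, rolls back to `"A"`, and only then finds the goal through `"right"`.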

Some remarks:
- We specify a maximum depth to avoid overly long computations on large domains.
- We store the graph traversal history to avoid recomputing it each time the solver is asked to sample an action.
- We also need to implement an `__init__()` method. We make sure to call the constructor of the base class `Solver`, which
  takes a domain factory as an argument and casts it to the specified level of characteristics. The cast domain factory is available in `self.domain_factory`.
- As the domain is fully observable, observation and state are the same thing.
- **Disclaimer:** This is a simple version that uses no heuristic and thus does not optimize the cost in any way. We only try to reach a goal.
  Of course, as we blindly traverse the graph, solving can take a very long time, so the solver should be used only on small domains.
  But it serves the purpose of showcasing the implementation of a new solver.
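
Storing the traversal as a policy amounts to turning a list of `(state, action)` pairs into a dictionary. A minimal sketch of that idea, with made-up grid states and action names chosen only for illustration:

```python
# Hypothetical plan found by a first search: a list of (state, action) pairs.
plan = [((0, 0), "right"), ((0, 1), "down"), ((1, 1), "down")]

policy = {}
# dict.update() accepts an iterable of key/value pairs, so the plan merges directly.
policy.update(plan)

# A later search from another state only adds entries; earlier ones are kept.
policy.update([((2, 0), "right")])

print(policy[(0, 1)])  # down
print((5, 5) in policy)  # False: a new search would be needed from this state
```

Note that this approach requires states to be hashable, since they are used as dictionary keys.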

In [None]:
from typing import *

from skdecide import *
from skdecide.builders.domain import *
from skdecide.builders.solver import *


class D(DeterministicPlanningDomain):
    pass


class DFSSolver(DeterministicPolicySolver, FromAnyState):
    """Depth-first search solver.

    The considered oriented graph is:
     - nodes: domain states
     - edges: domain actions linking a state to the resulting state when the action is applied

    We perform a DFS of this graph and stop whenever a goal or a max depth is reached.

    The traversal is made from a given point and the resulting plan is stored as a policy mapping a state to the next action.
    Whenever the solver is asked for a new action, either this is from a state in the computed policy, or a new solve is done from that new state.

    Args:
        max_depth: maximal depth for the DFS
        render_during_solve: for pedagogical purposes, it can be nice to see the traversal performed
            during the solving process. This flag enables it, if the domain is renderable.

    """

    T_domain = D

    @classmethod
    def _check_domain_additional(cls, domain: Domain) -> bool:
        """Check that the domain has an enumerable action space."""
        return isinstance(domain.get_action_space(), EnumerableSpace)

    def __init__(
        self,
        domain_factory: Callable[[], Domain],
        max_depth: int = 1000,
        render_during_solve: bool = False,
    ):
        """Constructor."""
        super().__init__(domain_factory)
        self.max_depth = max_depth
        self.render_during_solve = render_during_solve
        # initialize the policy
        self.policy: dict[D.T_observation, D.T_event] = {}
        # initialize a domain to test actions
        self.domain: DeterministicPlanningDomain = self.domain_factory()

    def _solve_from(self, memory: D.T_state) -> None:
        """Solve from the given state.

        Launch the DFS with the state as root node.

        """
        self.current_state = memory
        goal_reached = self.domain.is_goal(self.current_state)
        queue = self._get_state_n_applicable_actions_as_a_list(self.current_state)
        visited_states = {self.current_state}
        current_plan = []
        # DFS loop
        while (
            not goal_reached and len(queue) > 0 and len(current_plan) < self.max_depth
        ):
            state, action = queue.pop()
            # rollback the plan if needed (if a dead-end was reached)
            while state != self.current_state:
                self.current_state, _ = current_plan.pop()

            # update the plan with the new action to test
            current_plan.append((state, action))
            # apply
            self.current_state = self.domain.get_next_state(state, action)
            # check if we reach an already visited state (in case of loops in the graph)
            if self.current_state in visited_states:
                # drop the move
                current_plan.pop()
                self.current_state = state
            else:
                visited_states.add(self.current_state)
                # check if we reach
                #  - a goal
                #  - a state from which we know already a policy from a previous call to `solve()`
                if self.domain.is_goal(self.current_state):
                    # bingo
                    goal_reached = True
                elif self.current_state in self.policy:
                    # from here the previously computed policy gets to the goal: bingo
                    goal_reached = True
                else:
                    # goal not yet reached: we add applicable actions from next state
                    # NB: if we are in a deadend, nothing will be added, so next tested action will be from a previous state
                    queue.extend(
                        self._get_state_n_applicable_actions_as_a_list(
                            self.current_state
                        )
                    )

        # Check stop reason
        if goal_reached:
            # add computed plan to the policy (update only to keep track of previous call to `solve()`)
            self.policy.update(current_plan)
        else:
            # solve fails => raise error
            if len(current_plan) >= self.max_depth:
                # due to max_depth
                raise RuntimeError(
                    "The solver was unable to find a solution within the given max depth."
                )
            else:
                # no valid path exists
                raise RuntimeError(
                    "The solver was trapped in a deadend. The domain has no solution from the given initial state."
                )

    def _get_next_action(self, observation: D.T_observation) -> D.T_event:
        """Choose the next action according to the computed policy.

        If the state has not yet been visited, solve from this state.

        """
        if observation not in self.policy:
            self._solve_from(observation)
        return self.policy[observation]

    def _is_policy_defined_for(self, observation: D.T_observation) -> bool:
        """Tell whether the state (=observation as fully observable) is in the policy."""
        return observation in self.policy

    def _get_state_n_applicable_actions_as_a_list(
        self, state: D.T_state
    ) -> list[tuple[D.T_state, D.T_event]]:
        """Get applicable actions and make a list of (state, action) with it."""
        applicable_actions_space = self.domain.get_applicable_actions(state)
        action_space: EnumerableSpace = self.domain.get_action_space()
        return [
            (state, action)
            for action in action_space.get_elements()
            if action in applicable_actions_space
        ]

    @property
    def current_state(self) -> D.T_state:
        """Current state/node in DFS."""
        return self._current_state

    @current_state.setter
    def current_state(self, state: D.T_state) -> None:
        """Setter for current state in DFS.

        Useful to render current_state each time it is moving during the solving process.

        """
        self._current_state = state
        if self.render_during_solve and isinstance(self.domain, Renderable):
            self.domain.render(state)
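
The `current_state` property is what makes live rendering possible: every assignment made during the search goes through the setter. The same pattern can be shown in isolation with a hypothetical `Tracker` class (not part of scikit-decide) that records states instead of rendering them:

```python
class Tracker:
    """Hypothetical class: logs every state assigned to `current_state`."""

    def __init__(self, record: bool = True):
        self.record = record
        self.history = []

    @property
    def current_state(self):
        return self._current_state

    @current_state.setter
    def current_state(self, state):
        self._current_state = state
        if self.record:
            # In DFSSolver, this is the place where the domain gets rendered.
            self.history.append(state)


tracker = Tracker()
for s in ["A", "B", "A", "C"]:
    tracker.current_state = s  # plain assignment syntax, but the setter runs

print(tracker.history)
```

The benefit of this design is that the search loop itself stays free of rendering calls: it only assigns to `self.current_state`.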

## Test the solver on the maze domain

We use the maze domain from the scikit-decide hub.

We first define a domain factory to be fed to the solver and make sure to use the option allowing an inline display in a Jupyter notebook.

In [None]:
from skdecide.hub.domain.maze import Maze

domain_factory = lambda: Maze(display_in_jupyter=True)

Now the new solver should be compatible with the maze domain. Let us check it:

In [None]:
domain = domain_factory()
assert DFSSolver.check_domain(domain)

We need a state from which to start in the maze; we can take the initial state returned by the `reset()` method.

In [None]:
state = domain.reset()

Now we can solve and rollout from this state. We use the option `render_during_solve` to display the live DFS.

In [None]:
with DFSSolver(
    domain_factory=domain_factory, max_depth=100, render_during_solve=True
) as solver:
    solver.solve(from_memory=state)
    episodes = rollout(
        domain=domain, solver=solver, render=True, verbose=False, return_episodes=True
    )