# Solving Yashi with SAT

First, we need some dependencies and imports.

In [None]:
%pip install python-sat
from pysat.solvers import Minisat22
from pysat.examples.rc2 import RC2
from pysat.formula import WCNFPlus

%pip install matplotlib
from matplotlib import pyplot

from itertools import combinations, product
import requests

## Edges

The first step is to detect the possible edges from a set of node positions.
This is done by iterating through every pair of nodes and checking whether a fully horizontal or vertical segment can connect them. Moreover, there must be no other node in between them.

In [None]:
def find_edges(nodes):
    def edge_exist(ns, ne):
        nsx, nsy = nodes[ns]
        nex, ney = nodes[ne]

        # The edge is vertical.
        if nsx == nex and nsy != ney:
            # Check that no other node is strictly in between (nsy < ney because nodes are sorted).
            for nmx, nmy in nodes:
                if nmx == nsx and nsy < nmy < ney:
                    return False
            return True
        
        # The edge is horizontal.
        if nsy == ney and nsx != nex:
            # Check that no other node is strictly in between (nsx < nex because nodes are sorted).
            for nmx, nmy in nodes:
                if nmy == nsy and nsx < nmx < nex:
                    return False
            return True
        
        return False
    
    for ns, ne in combinations(range(len(nodes)), 2):
        if edge_exist(ns, ne):
            yield (ns, ne)
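
As a cross-check of the pairwise scan above, the same edge set can be obtained by grouping the nodes by column and by row and linking consecutive nodes in each group. The sketch below is an alternative formulation, not the implementation used in this notebook; `edges_by_lines` is a hypothetical helper name, and it assumes `nodes` is already sorted and free of duplicates.

```python
from collections import defaultdict

def edges_by_lines(nodes):
    """Link consecutive nodes of every column and row (hypothetical helper).

    Assumes `nodes` is a sorted list of distinct (x, y) tuples.
    """
    columns = defaultdict(list)  # x -> indices of the nodes in that column
    rows = defaultdict(list)     # y -> indices of the nodes in that row
    for i, (x, y) in enumerate(nodes):
        columns[x].append(i)
        rows[y].append(i)

    edges = set()
    # Inside a column the indices are already ordered by y, because nodes is sorted.
    for line in columns.values():
        edges.update(zip(line, line[1:]))
    # Inside a row the indices are already ordered by x, for the same reason.
    for line in rows.values():
        edges.update(zip(line, line[1:]))
    return sorted(edges)

example = sorted([(0, 0), (0, 2), (2, 0), (4, 0), (4, 2)])
print(edges_by_lines(example))  # [(0, 1), (0, 2), (1, 4), (2, 3), (3, 4)]
```

Note that nodes 0 and 3 share a row but are not linked, because node 2 sits between them.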

## Crossing edges

Once the set of possible edges has been found, the second step is finding out which of them cross each other.
This is also done by iterating through every pair of possible edges and checking whether they cross, that is, whether they have different directions and each lies strictly between the two endpoints of the other.

In [None]:
def find_crossings(nodes, edges):
    def edges_cross(e1, e2):
        ns1, ne1 = edges[e1]
        ns2, ne2 = edges[e2]
        
        ns1x, ns1y = nodes[ns1]
        ne1x, ne1y = nodes[ne1]
        ns2x, ns2y = nodes[ns2]
        ne2x, ne2y = nodes[ne2]

        # Note that nodes are sorted by x then y coordinate, both in nodes and in edges.

        # s1x < e1x is true only if edge1 is horizontal.
        # s2y < e2y is true only if edge2 is vertical.
        # The inclusion of s2x and s1y in the middle is true only if they cross.
        if ns1x < ns2x < ne1x and ns2y < ns1y < ne2y:
            return True
        
        # Like above, but now edge1 is vertical and edge2 is horizontal.
        if ns2x < ns1x < ne2x and ns1y < ns2y < ne1y:
            return True
        
        return False
    
    for e1, e2 in combinations(range(len(edges)), 2):
        if edges_cross(e1, e2):
            yield (e1, e2)
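
To make the crossing condition concrete, here is a minimal sketch that applies the same strict inequalities directly to coordinates. `segments_cross` is a hypothetical helper that takes one horizontal and one vertical segment; it is meant as an illustration of the test, not as a replacement for `find_crossings`.

```python
def segments_cross(h, v):
    """Return True if horizontal segment h and vertical segment v cross.

    h = ((x1, y), (x2, y)) with x1 < x2; v = ((x, y1), (x, y2)) with y1 < y2.
    Touching at an endpoint does not count, hence the strict inequalities.
    """
    (hx1, hy), (hx2, _) = h
    (vx, vy1), (_, vy2) = v
    return hx1 < vx < hx2 and vy1 < hy < vy2

horizontal = ((0, 1), (2, 1))
print(segments_cross(horizontal, ((1, 0), (1, 2))))  # True: they meet at (1, 1)
print(segments_cross(horizontal, ((2, 0), (2, 2))))  # False: only touches the endpoint (2, 1)
print(segments_cross(horizontal, ((3, 0), (3, 2))))  # False: entirely to the right
```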

## Variables

At this point we need to generate the SAT clauses, but the solver library requires variables to be positive integers. In this section we define three functions that map the three kinds of variables used to integer ids. The variables used are:

- $e_i$, which represent whether the edge $e_i$ is selected or not and are generated by `var_edge`;
- $c_{ij}$, which represent the fact that the node $n_i$ is a direct child of node $n_j$ in the final tree and are generated by `var_child`. Note that these variables exist only for nodes $n_i$ and $n_j$ that are directly connected by edges;
- $u_{ij}$, which represent whether the node $n_i$ stands under (not necessarily directly) the node $n_j$ in the final tree and are generated by `var_under`.

In [None]:
# e_i variables are directly encoded with the edge index (plus 1 because variables need to start from 1).
def var_edge(e):
    return 1 + e

# After the first len(edges) variables, the next 2*len(edges) variables are used for c_ij variables.
# Note that an edge from ni to nj is expected to exist, or this function will error.
# These are encoded sorted by edge index, which gets multiplied by 2 because there are two variables per edge.
def var_child(edges, ni, nj):
    # Edge nodes are always stored in order (ns, ne) with ns < ne.
    # Using tuple(sorted((ni, nj))) thus gives the edge between ni and nj exactly as it is stored.
    return 1 + len(edges) + 2 * edges.index(tuple(sorted((ni, nj)))) + (ni > nj)

# After the first len(edges) + 2*len(edges) = 3*len(edges) variables, the next len(nodes)**2 variables
# are used for u_ij variables. These are encoded like the index of a 2D matrix.
def var_under(nodes, edges, ni, nj):
    return 1 + 3 * len(edges) + ni * len(nodes) + nj
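
When inspecting a model it can help to go the other way, from an integer id back to the variable it stands for. The sketch below inverts the layout described above; `decode_var` is a hypothetical debugging helper that is not used elsewhere in this notebook, and it assumes `1 <= v <= 3 * len(edges) + num_nodes ** 2`.

```python
def decode_var(num_nodes, edges, v):
    """Describe what SAT variable v stands for (hypothetical debugging helper)."""
    num_edges = len(edges)
    v -= 1  # back to a 0-based offset
    if v < num_edges:
        return ("edge", v)
    v -= num_edges
    if v < 2 * num_edges:
        ns, ne = edges[v // 2]
        # Even offsets encode c_{ns,ne}; odd offsets encode c_{ne,ns}.
        return ("child", ns, ne) if v % 2 == 0 else ("child", ne, ns)
    v -= 2 * num_edges
    return ("under", v // num_nodes, v % num_nodes)

demo_edges = [(0, 1), (0, 2)]  # two edges on three nodes
for v in range(1, 3 * len(demo_edges) + 3 ** 2 + 1):
    print(v, decode_var(3, demo_edges, v))
```

With two edges and three nodes, ids 1 and 2 are edge variables, ids 3 to 6 are child variables, and ids 7 to 15 are under variables.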

## Clauses

### Clauses for crossings

We need the solution to not contain crossing edges. This can be stated as $\neg e_i \vee \neg e_j$ for every pair of edges $e_i$ and $e_j$ that cross.

In [None]:
def clauses_crossings(crossings):
    for ei, ej in crossings:
        yield [-var_edge(ei), -var_edge(ej)]

### Clauses for children

A graph on $n$ nodes is a spanning tree if and only if it has exactly $n - 1$ edges (`len(nodes) - 1` in Python) and is acyclic. This section enforces the first of these two conditions. Luckily it can be expressed as a local property: fix node $n_0$ as the root and view each selected edge as connecting a child to its parent. Every selected edge is then the parent edge of exactly one node, and every node except the root has exactly one parent edge, so exactly $n - 1$ edges are selected.

Let $N_i = \{ n_j \mid \exists e_k.\ e_k = (n_i, n_j) \vee e_k = (n_j, n_i) \}$ be the set of neighbours of $n_i$.

This produces the following clauses:

- an edge is selected if and only if it connects a child to its parent. This can be stated as $e_i \leftrightarrow (c_{jk} \vee c_{kj})$ for every edge $e_i$ connecting $n_j$ and $n_k$. It can then be converted to CNF as $(\neg e_i \vee c_{jk} \vee c_{kj}) \land (e_i \vee \neg c_{jk}) \land (e_i \vee \neg c_{kj})$;
- the root node $n_0$ has no parent. This can be stated as $\neg c_{0i}$ for all nodes $n_i \in N_0$;
- every node except the root has exactly one parent, which can be split into:
  - every node except the root has at least one parent. This can be stated as $\bigvee_{n_j \in N_i} c_{ij}$ for all $i \neq 0$;
  - every node except the root has at most one parent. This can be stated as $\neg c_{ij} \vee \neg c_{ik}$ for all pairs of distinct nodes $n_j$, $n_k$ in $N_i$ and all $i \neq 0$.
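
The CNF conversion in the first item can be checked by brute force over the eight possible truth assignments. The snippet below is only a sanity check written for this explanation; `iff_formula` and `cnf_formula` are hypothetical names.

```python
from itertools import product

def iff_formula(e, cjk, ckj):
    # e <-> (c_jk or c_kj)
    return e == (cjk or ckj)

def cnf_formula(e, cjk, ckj):
    # (not e or c_jk or c_kj) and (e or not c_jk) and (e or not c_kj)
    return ((not e) or cjk or ckj) and (e or not cjk) and (e or not ckj)

assert all(
    iff_formula(*values) == cnf_formula(*values)
    for values in product([False, True], repeat=3)
)
print("CNF conversion agrees on all 8 assignments")
```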

In [None]:
def clauses_children(nodes, edges):
    # An edge is selected iff it connects a child to its parent.
    for e, (ni, nj) in enumerate(edges):
        ve = var_edge(e)
        vij = var_child(edges, ni, nj)
        vji = var_child(edges, nj, ni)
        yield [-ve, vij, vji]
        yield [ve, -vij]
        yield [ve, -vji]
    
    # The root node 0 has no parent.
    for ni, nj in edges:
        if ni == 0:
            yield [-var_child(edges, 0, nj)]

    # Every node except the root has exactly one parent.
    for ni in range(1, len(nodes)):
        # Let Ni be the set of neighbours of ni.
        Ni = [nj for nj in range(len(nodes)) if (ni, nj) in edges or (nj, ni) in edges]
        # ni has at least one parent.
        yield [var_child(edges, ni, nj) for nj in Ni]
        # ni has at most one parent.
        for nj, nk in combinations(Ni, 2):
            yield [-var_child(edges, ni, nj), -var_child(edges, ni, nk)]

### Clauses for the "under" relation

The second condition for a graph to be a spanning tree is acyclicity. We can enforce this by requiring the existence of an "under" relation, that is, a relation that contains the "child" relation just defined and is both transitive and irreflexive.

To see why this prevents cycles, assume that a cycle exists, then pick one of its edges. This edge must connect a child $n_j$ to its parent $n_i$, so $n_j$ is under $n_i$. Consider then the other edge of the cycle connected to $n_j$. It can't connect $n_j$ to its parent, because that edge is unique and we've already considered it, so it must connect $n_j$ to one of its children $n_k$. Thus $n_k$ is under $n_j$, and by transitivity it's under $n_i$. Continuing this reasoning along all the other edges of the cycle, we get that all the other nodes we encounter are under the node $n_i$. However, since we have a cycle, we must reach $n_i$ again, thus obtaining that $n_i$ is under $n_i$. This is impossible because the under relation is irreflexive, so we have a contradiction and a cycle can't exist.

This also doesn't prevent valid solutions because every tree forms a hierarchy which clearly satisfies this relation.

This produces the following clauses:

- a child is always under its parent. This can be stated as $c_{ij} \to u_{ij}$ for all $n_i$ and $n_j$ connected by an edge, which in CNF becomes $\neg c_{ij} \vee u_{ij}$;
- irreflexivity: a node is never under itself. This can be stated as $\neg u_{ii}$ for all $n_i$;
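- transitivity: if $n_i$ is under $n_j$ and $n_j$ is under $n_k$ then $n_i$ is under $n_k$ as well. This can be stated as $u_{ij} \land u_{jk} \to u_{ik}$ for all $n_i$, $n_j$ and $n_k$, which in CNF becomes $\neg u_{ij} \vee \neg u_{jk} \vee u_{ik}$;
  - transitivity (optimized): it is enough to consider the cases where $n_i$ is a child of $n_j$, that is $c_{ij} \land u_{jk} \to u_{ik}$ for all $n_j \in N_i$ and all $n_k$, which in CNF becomes $\neg c_{ij} \vee \neg u_{jk} \vee u_{ik}$. This is the version generated by the code below.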
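
The argument above can be illustrated without a solver. The sketch below builds the smallest "under" relation implied by a parent assignment and shows that it stays irreflexive for a tree, while a cycle of parent links forces some node under itself. `under_relation` is a hypothetical helper and the two parent maps are made-up examples.

```python
def under_relation(parent):
    """Return the set of pairs (i, j) such that node i is under node j.

    `parent` maps each node to its parent; roots are simply absent from the map.
    """
    under = set()
    for node in parent:
        ancestor = parent[node]
        seen = set()
        # Walk up the parent chain; `seen` stops the walk if the chain loops.
        while ancestor is not None and ancestor not in seen:
            under.add((node, ancestor))
            seen.add(ancestor)
            ancestor = parent.get(ancestor)
    return under

tree = {1: 0, 2: 0, 3: 1}         # node 0 is the root
cycle = {1: 2, 2: 3, 3: 1, 4: 0}  # 1 -> 2 -> 3 -> 1 never reaches the root

tree_under = under_relation(tree)
cycle_under = under_relation(cycle)
print(sorted(tree_under))                   # [(1, 0), (2, 0), (3, 0), (3, 1)]
print(any(i == j for i, j in tree_under))   # False: irreflexivity holds
print(any(i == j for i, j in cycle_under))  # True: some node is under itself
```

Note that in the second map every node except the root still has exactly one parent, so the children clauses alone would accept it; only the under relation rules it out.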

In [None]:
def clauses_under(nodes, edges):
    # A child is always under its parent.
    for ni, nj in edges:
        # Note that for each edge there are two cij variables due to the two possibilities
        # for which node is the child and which is the parent.
        # Thus we need to generate two implications here, one for each variable.
        yield [-var_child(edges, ni, nj), var_under(nodes, edges, ni, nj)]
        yield [-var_child(edges, nj, ni), var_under(nodes, edges, nj, ni)]
        
    # Irreflexivity: a node is never under itself.
    for n in range(len(nodes)):
        yield [-var_under(nodes, edges, n, n)]

    # Transitivity (optimized): if ni is a child of nj and nj is under nk then ni is under nk as well.
    for ni, nk in product(range(len(nodes)), repeat=2):
        # Let Ni be the set of neighbours of ni.
        Ni = [nj for nj in range(len(nodes)) if (ni, nj) in edges or (nj, ni) in edges]
        for nj in Ni:
            yield [-var_child(edges, ni, nj), -var_under(nodes, edges, nj, nk), var_under(nodes, edges, ni, nk)]

Finally we merge all the clauses in a single function.

In [None]:
def clauses(nodes, edges, crossings):
    yield from clauses_crossings(crossings)
    yield from clauses_children(nodes, edges)
    yield from clauses_under(nodes, edges)

Here's a small helper to plot the solution of a yashi game.

In [None]:
def plot_solution(nodes, edges, model):
    xs = [x for x, _ in nodes]
    ys = [y for _, y in nodes]
    minx, maxx = min(xs), max(xs)
    miny, maxy = min(ys), max(ys)

    pyplot.figure()
    pyplot.gca().set_aspect("equal", "box")
    pyplot.xticks(range(minx, maxx + 1))
    pyplot.yticks(range(miny, maxy + 1))
    pyplot.xlim(minx - 0.5, maxx + 0.5)
    pyplot.ylim(miny - 0.5, maxy + 0.5)
    pyplot.plot(xs, ys, "ok", linestyle='', markersize=7)

    for e in range(len(edges)):
        if var_edge(e) in model:
            ns, ne = edges[e]
            nsx, nsy = nodes[ns]
            nex, ney = nodes[ne]
            pyplot.plot((nsx, nex), (nsy, ney), "k")

This is the main entrypoint for solving a yashi game, given a list of nodes, each described by its position in the grid.

In [None]:
def yashi(nodes):
    nodes = sorted(nodes)
    edges = list(find_edges(nodes))
    crossings = find_crossings(nodes, edges)

    solver = Minisat22(bootstrap_with = clauses(nodes, edges, crossings))

    if solver.solve():
        plot_solution(nodes, edges, solver.get_model())
    else:
        print("There's no solution")
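
A model returned by the solver can be double-checked independently of the encoding. The sketch below uses a small union-find structure to test whether a set of selected edges forms a spanning tree; `is_spanning_tree` is a hypothetical helper, and it deliberately does not check crossings.

```python
def is_spanning_tree(num_nodes, selected_edges):
    """Check that the selected (ns, ne) pairs form a spanning tree on num_nodes nodes."""
    if len(selected_edges) != num_nodes - 1:
        return False

    group = list(range(num_nodes))  # union-find forest, one set per node

    def find(n):
        while group[n] != n:
            group[n] = group[group[n]]  # path halving
            n = group[n]
        return n

    for ns, ne in selected_edges:
        a, b = find(ns), find(ne)
        if a == b:
            return False  # this edge would close a cycle
        group[a] = b
    # num_nodes - 1 edges and no cycle imply that the graph is connected.
    return True

print(is_spanning_tree(5, [(0, 1), (0, 2), (2, 3), (3, 4)]))  # True
print(is_spanning_tree(4, [(0, 1), (1, 2), (0, 2)]))          # False: cycle, node 3 unreachable
```

With the names used in this notebook, the selected edges of a model would be `[edges[e] for e in range(len(edges)) if var_edge(e) in model]`.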

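## Minimum-length solution

To find a minimum-length solution we can use Max-SAT. Max-SAT maximizes the total weight of the satisfied soft clauses, while our goal is to minimize the total length, so we maximize its negation instead. The length of a solution is the sum of the lengths of its selected edges, so it naturally follows that each edge should contribute the negation of its length when it is selected. Up to a constant, this is the same as rewarding every edge that is *not* selected with its length, since $\sum_i -\ell_i e_i = \sum_i \ell_i (1 - e_i) - \sum_i \ell_i$, where $\ell_i$ is the length of edge $e_i$ and $e_i$ is read as 0 or 1.

This is the main entrypoint for solving a yashi game with the minimum-length solution.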

In [None]:
def min_yashi(nodes):
    nodes = sorted(nodes)
    edges = list(find_edges(nodes))
    crossings = find_crossings(nodes, edges)

    # RC2 is a PySAT solver for Max-SAT
    solver = RC2(WCNFPlus(), solver="m22")

    # Manually add hard clauses because there's no `extend` or `bootstrap_with`
    for clause in clauses(nodes, edges, crossings):
        solver.add_clause(clause)

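    # Add the soft clauses for the edge lengths.
    # Rewarding "edge not selected" with a positive weight equal to the edge length
    # is equivalent to penalizing "edge selected" with the negated length, and it
    # avoids relying on how the solver treats negative weights.
    for e, (ns, ne) in enumerate(edges):
        nsx, nsy = nodes[ns]
        nex, ney = nodes[ne]
        weight = abs(nsx - nex) + abs(nsy - ney)
        solver.add_clause([-var_edge(e)], weight=weight)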

    model = solver.compute()
    if model:
        plot_solution(nodes, edges, model)
    else:
        print("There's no solution")

Helper function to extract the list of nodes from a `sumsumpuzzle.com` URL.

In [None]:
def sumsumpuzzle_nodes(url):
    page = requests.get(url).text
    grid = page.split("grid = [,\n")[1].split("];\n</script>")[0]
    for y, line in enumerate(reversed(grid.split("\n"))):
        for x, cell in enumerate(line.lstrip("[,").rstrip(",").rstrip("]").split(",")):
            if cell == "dot":
                yield (x+1, y+1)

## Examples

Examples of solving some instances from `sumsumpuzzle.com`.

In [None]:
urls = [
    "http://www.sumsumpuzzle.com/Y201108_0707P.htm",
    "http://www.sumsumpuzzle.com/Y201108_0808P.htm",
    "http://www.sumsumpuzzle.com/Y201108_0909P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1010P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1212P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1414P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1616P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1818P.htm",
    "http://www.sumsumpuzzle.com/Y201108_2020P.htm",
]
for url in urls:
    yashi(sumsumpuzzle_nodes(url))

Example where `min_yashi` is guaranteed to return the minimum-length solution, while `yashi` may return any valid one.

In [None]:
nodes = [(0, 2), (0, 0), (2, 0), (4, 0), (4, 2)]
yashi(nodes)
min_yashi(nodes)
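
This instance is small enough to verify by exhaustive search. The sketch below enumerates every set of `len(nodes) - 1` candidate edges, keeps the connected ones, and picks the shortest. The edge list is written out by hand for this instance (it has no crossings), and `demo_nodes`, `demo_edges`, `length` and `connects_all` are hypothetical names used only here.

```python
from itertools import combinations

demo_nodes = sorted([(0, 2), (0, 0), (2, 0), (4, 0), (4, 2)])
# Candidate edges of this instance as pairs of indices into demo_nodes.
demo_edges = [(0, 1), (0, 2), (1, 4), (2, 3), (3, 4)]

def length(edge):
    (ax, ay), (bx, by) = demo_nodes[edge[0]], demo_nodes[edge[1]]
    return abs(ax - bx) + abs(ay - by)

def connects_all(chosen):
    # Flood fill from node 0 over the chosen edges.
    reached, frontier = {0}, [0]
    while frontier:
        n = frontier.pop()
        for a, b in chosen:
            for here, there in ((a, b), (b, a)):
                if here == n and there not in reached:
                    reached.add(there)
                    frontier.append(there)
    return len(reached) == len(demo_nodes)

# A connected graph with len(demo_nodes) - 1 edges is a spanning tree.
trees = [c for c in combinations(demo_edges, len(demo_nodes) - 1) if connects_all(c)]
best = min(trees, key=lambda c: sum(map(length, c)))
print(sum(map(length, best)), best)  # 8 ((0, 1), (0, 2), (2, 3), (3, 4))
```

The five candidate edges form a single cycle, so every spanning tree drops exactly one of them; the shortest one drops the only edge of length 4, for a total length of 8 instead of 10.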

Example of `min_yashi` with some instances from `sumsumpuzzle.com`.

In [None]:
urls = [
    "http://www.sumsumpuzzle.com/Y201108_0707P.htm",
    "http://www.sumsumpuzzle.com/Y201108_0808P.htm",
    "http://www.sumsumpuzzle.com/Y201108_0909P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1010P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1212P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1414P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1616P.htm",
    "http://www.sumsumpuzzle.com/Y201108_1818P.htm",
    "http://www.sumsumpuzzle.com/Y201108_2020P.htm",
]
for url in urls:
    min_yashi(sumsumpuzzle_nodes(url))