In [9]:
import jax
import jax.numpy as jnp
from typing import NamedTuple

In [12]:
class State(NamedTuple):
    """Immutable snapshot of the wave-function-collapse (WFC) solver.

    Shape legend: H = grid rows, W = grid columns, C = number of patterns.

    Attributes:
        wave: Array of shape (H, W, C); ``wave[r, c, k]`` is True while
            pattern ``k`` is still possible at cell ``(r, c)`` and False
            once it has been ruled out.
        compatible: Documented as shape (H, W, C, C), where
            ``compatible[i, j, k, l]`` is 1 if pattern ``k`` is compatible
            with pattern ``l`` at cell ``(i, j)`` and 0 otherwise.
            NOTE(review): elsewhere in this notebook it is created with
            shape (H, W) and decremented as a per-cell counter -- confirm
            which layout is intended.
        weights: Unnormalized, positive probability assigned to each
            pattern (presumably shape (C,) -- confirm).
        entropies: Array of shape (H, W) holding the entropy of each cell;
            used to choose the next cell to collapse.
    """

    wave: jnp.ndarray  # (H, W, C) possibility mask
    compatible: jnp.ndarray  # see NOTE(review) in the class docstring
    weights: jnp.ndarray  # per-pattern unnormalized weights
    entropies: jnp.ndarray  # (H, W) per-cell entropy


In [None]:
def init_state(self, nr: int, nc: int, weights: jnp.ndarray):
    """Build the initial WFC state for an ``nr`` x ``nc`` grid.

    Every pattern starts out possible at every cell (total superposition),
    so the wave is all True. Each cell's entropy is the entropy of the full
    weight distribution.

    Args:
        self: Not read by this body. The signature suggests this was meant
            to be a method -- TODO confirm and move into the owning class.
        nr: Number of grid rows (H).
        nc: Number of grid columns (W).
        weights: Unnormalized per-pattern weights. ``weights.shape[0]`` sets
            the number of patterns C.

    Returns:
        A ``State`` containing:
        - ``wave``: shape (nr, nc, C), all True.
        - ``compatible``: zeros of shape (nr, nc).
        - ``weights``: the given ``weights``.
        - ``entropies``: shape (nr, nc), filled with ``init_entropy(weights)``.
    """
    # Derive the pattern count from the weights that go into the State, so
    # the wave's last axis always matches them.
    n_patterns = weights.shape[0]
    # WFC begins with every pattern allowed everywhere. An all-False wave
    # would mark every cell as a contradiction before solving starts, and
    # would disagree with the full-distribution entropy below.
    wave = jnp.ones((nr, nc, n_patterns), dtype=jnp.bool_)
    # NOTE(review): a (nr, nc) array of zeros does not match the (H, W, C, C)
    # layout described for State.compatible. Confirm the intended shape and
    # initial value.
    compatible = jnp.zeros((nr, nc))
    entropies = jnp.full((nr, nc), init_entropy(weights))
    return State(wave, compatible, weights, entropies)

def init_entropy(weights):
    """Shannon entropy (in nats) of the distribution proportional to ``weights``.

    With ``s = sum(w)`` and ``p = w / s``, this evaluates::

        H = -sum(p * log p) = log(s) - sum(w * log w) / s

    This avoids materialising the normalized probabilities.

    Zero weights contribute nothing (``0 * log 0`` is taken as 0). A
    zero-weight pattern therefore no longer turns the result into NaN.

    Args:
        weights: Non-negative, unnormalized weights. They are summed over
            every element.

    Returns:
        Scalar array holding the entropy.
    """
    from jax.scipy.special import xlogy

    s = jnp.sum(weights)
    # xlogy(w, w) equals w * log(w), but yields 0 (not NaN) where w == 0.
    return jnp.log(s) - jnp.sum(xlogy(weights, weights)) / s

def run(key, max_steps=100):
    """Driver-loop stub for the WFC solver.

    As written, this only advances the PRNG ``key`` ``max_steps`` times and
    returns ``None``. The per-step work has not been filled in yet.

    TODO: consume the second half of each split (the subkey) in the
    per-step work.
    """
    for _step in range(max_steps):
        # Keep only the carried key; the subkey half of the split is unused.
        key = jax.random.split(key)[0]

def eliminate(state: State, r: int, c: int, t: int):
    """Rule out pattern ``t`` at grid cell ``(r, c)``.

    Returns a new State in which:
    - ``wave[r, c, t]`` is False;
    - ``compatible[r, c]`` is one lower;
    - ``weights`` and ``entropies`` are carried over unchanged.

    The updates go through ``.at[...].set``, which returns updated copies.

    NOTE(review): the counter is decremented even if pattern ``t`` was
    already ruled out, and the cell's entropy is not refreshed. Confirm
    that both are intended.
    """
    cell_count = state.compatible[r, c]
    return state._replace(
        wave=state.wave.at[r, c, t].set(False),
        compatible=state.compatible.at[r, c].set(cell_count - 1),
    )

def next_node(key, state=None):
    """Pick the next cell to collapse: the undecided cell with the lowest entropy.

    Cells still counted as undecided are those whose ``state.wave`` has more
    than one pattern left. Among them, the cell with the smallest
    ``state.entropies`` value wins. A tiny random offset drawn from ``key``
    breaks ties between equal-entropy cells.

    Args:
        key: JAX PRNG key, used only for the tie-breaking noise.
        state: The current ``State``. It is required; it has a default only
            so the original one-argument signature still parses.

    Returns:
        ``(row, col)`` as Python ints, or ``None`` when no cell has more
        than one pattern left (the grid is fully collapsed, or only
        contradictions remain).

    Raises:
        TypeError: If ``state`` is not given.

    Not jit-compatible: the ``None`` return and the ``int`` conversions need
    concrete values.
    """
    # NOTE(review): an `nweights` property (returning self.weights.shape[0])
    # and an abstract static `weights` method were previously nested inside
    # this function. They were dead code here, and `abstractmethod` was
    # never imported, so every call raised NameError. They look like members
    # of a class that owns `weights` (init_state reads `self.nweights`).
    # TODO: move them into that class.
    if state is None:
        raise TypeError("next_node() requires a `state` argument")

    # Cells with 0 or 1 remaining patterns are decided (or contradictory).
    n_options = jnp.sum(state.wave, axis=-1)
    # Tiny noise so exact entropy ties are broken randomly. The 1e-6 scale
    # is an arbitrary small value.
    noise = jax.random.uniform(key, state.entropies.shape, maxval=1e-6)
    scores = jnp.where(n_options > 1, state.entropies + noise, jnp.inf)

    flat_idx = jnp.argmin(scores)
    if jnp.isinf(scores.ravel()[flat_idx]):
        return None
    row, col = jnp.unravel_index(flat_idx, scores.shape)
    return int(row), int(col)