In [None]:
import dataclasses
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
import scipy.special
import scipy.stats

In [None]:
def compute_entropy(n: npt.ArrayLike, v: npt.ArrayLike) -> np.ndarray:
    """Return ``log(v**n / n!)`` elementwise.

    This is the log-weight of ``n`` indistinguishable particles occupying a
    region of weight ``v`` (e.g. a probability). Summed over regions it gives
    a multinomial log-probability up to the constant ``log(N!)`` term.

    Args:
        n: Occupation count(s). Non-integer values are allowed; ``n!`` is
            extended through ``gammaln(n + 1)``.
        v: Region weight(s); broadcast against ``n``.

    Returns:
        ``n * log(v) - log(n!)`` using the convention ``0 * log(0) == 0``, so an
        empty region of zero weight contributes 0 instead of NaN.
    """
    n = np.asarray(n)
    # xlogy returns 0 when n == 0 even if v == 0; the plain n * np.log(v) form
    # yields NaN there (plus a divide-by-zero warning), e.g. for a split
    # probability of exactly 0 or 1.
    return scipy.special.xlogy(n, v) - scipy.special.gammaln(n + 1)

def compute_split_entropy_diff(
        n_a: npt.ArrayLike,
        n: npt.ArrayLike,
        p_a: npt.ArrayLike,
    ) -> np.ndarray:
    """Log-weight of an observed two-way split relative to the expected split.

    With ``n`` particles, ``n_a`` of them on side A (probability ``p_a``) and
    the rest on side B (probability ``1 - p_a``), this returns
    ``W(n_a, n - n_a) - W(p_a * n, (1 - p_a) * n)``, where ``W`` is the sum of
    ``compute_entropy`` over both sides. The ``n!`` terms cancel, so this is the
    log-ratio of the binomial probability of the observed split to that of the
    expected split (continuously extended via ``gammaln``).

    Args:
        n_a: Observed count(s) on side A.
        n: Total particle count(s).
        p_a: Probability of landing on side A.

    Returns:
        The difference, broadcast over the inputs.
    """
    count_a, total, prob_a = (np.array(arg) for arg in (n_a, n, p_a))
    prob_b = 1 - prob_a

    observed = compute_entropy(count_a, prob_a) + compute_entropy(total - count_a, prob_b)
    expected_a = prob_a * total
    expected_b = prob_b * total
    return observed - compute_entropy(expected_a, prob_a) - compute_entropy(expected_b, prob_b)

@dataclasses.dataclass
class State:
    """Particles performing independent lattice random walks inside a box.

    Attributes:
        v: Box extent along each axis; coordinate ``k`` is kept within
            ``[0, v[k] - 1]`` by ``step``.
        position: One row of lattice coordinates per particle. ``init_random``
            builds it with shape ``(n_particles, len(v))`` and dtype ``int32``.
    """
    v: tuple[int, ...]
    position: np.ndarray

    @staticmethod
    def init_random(
        n_particles: int,
        v: tuple[int, ...],
        rng: np.random.Generator
    ) -> "State":
        """Place ``n_particles`` uniformly at random on the lattice ``v``.

        Each coordinate ``k`` is drawn independently from ``[0, v[k])``.
        """
        return State(
            v=v,
            position=np.stack(
                [rng.integers(0, v_i, n_particles, dtype=np.int32) for v_i in v],
                axis=1,
            ),
        )

    def n_particles(self) -> int:
        """Number of particles, i.e. rows of ``position``."""
        return self.position.shape[0]

    @staticmethod
    def _unit_displacements(ndim: int) -> np.ndarray:
        """Return the ``2 * ndim`` unit moves: ``+e_0..+e_{d-1}`` then ``-e_0..-e_{d-1}``."""
        eye = np.eye(ndim, dtype=np.int64)
        return np.concatenate([eye, -eye])

    def step(self, rng: np.random.Generator) -> None:
        """Move every particle one lattice step along a uniformly random direction.

        Moves that would leave the box are clipped back to the boundary, so a
        particle pushing against a wall stays where it is. ``position`` is
        updated in place and keeps its dtype.
        """
        ndim = len(self.v)
        # For 2-D this is [[1, 0], [0, 1], [-1, 0], [0, -1]] and the draw below
        # is rng.integers(0, 4, ...), so seeded 2-D runs are unchanged.
        displacements = self._unit_displacements(ndim)
        index = rng.integers(0, 2 * ndim, size=self.n_particles())
        self.position += displacements[index]
        upper = np.asarray(self.v, dtype=self.position.dtype) - 1
        # Clip in place with dtype-matched bounds; np.clip with list bounds
        # would return a new int64 array and silently drop the int32 dtype.
        np.clip(self.position, 0, upper, out=self.position)

    def compute_half_split_count(self):
        """Count particles in the lower half along axis 0 (``x < v[0] // 2``).

        Returns a NumPy integer.
        """
        return np.sum(self.position[:, 0] < self.v[0] // 2)

In [None]:
# Simulation parameters.
SEED = 0
N_PARTICLES = 10_000
BOX_SHAPE = (100, 50)
N_STEPS = 100_000
PROGRESS_EVERY = 10_000

rng = np.random.default_rng(SEED)
state = State.init_random(N_PARTICLES, BOX_SHAPE, rng)

# half_split_count[t] = particles in the left half after t steps
# (index 0 is the initial configuration).
half_split_count = np.empty(N_STEPS + 1, dtype=np.int_)
half_split_count[0] = state.compute_half_split_count()
for step_number in range(1, N_STEPS + 1):
    state.step(rng)
    half_split_count[step_number] = state.compute_half_split_count()
    if step_number % PROGRESS_EVERY == 0:
        print(f"Done step {step_number}")

In [None]:
# Log-weight of the observed left/right split relative to an exact 50/50
# split: 0 at an even split, negative otherwise.
entropy = compute_split_entropy_diff(
    half_split_count,
    state.n_particles(),
    0.5,
)
figure, axes = plt.subplots(1, 1)
axes.plot(entropy, linewidth=0.1)
axes.set_xlabel("step")
axes.set_ylabel("entropy difference vs. even split")
axes.set_title("Split entropy over time")
# pyplot.show() accepts only the keyword `block`; a figure passed positionally
# is not part of its API.
plt.show()
plt.close(figure)

In [None]:
count, bin_edges = np.histogram(entropy)
# Empty bins would give log(0) = -inf plus a divide-by-zero warning; leave
# them as NaN so stairs() draws a gap there instead.
log_count = np.full(count.shape, np.nan)
np.log(count, out=log_count, where=count > 0)
figure, axes = plt.subplots(1, 1)
axes.stairs(log_count, bin_edges)
axes.set_xlabel("entropy difference vs. even split")
axes.set_ylabel("log(count)")
axes.set_title("Histogram of split entropy")
# pyplot.show() accepts only the keyword `block`; a figure passed positionally
# is not part of its API.
plt.show()
plt.close(figure)
