In [237]:
#!/usr/bin/env python3

#############################
# 1) Real Semiring
#############################

import numpy as np
from math import isclose
from __future__ import annotations

import copy
from collections import Counter
from collections import defaultdict as dd
from collections import deque
from itertools import product
from typing import Callable, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type, Union

import numpy as np
from frozendict import frozendict

import rayuela
from rayuela.base.semiring import Boolean, ProductSemiring, Real, Semiring
from rayuela.base.state import PairState, State
from rayuela.base.symbol import Expr, Sym, ε, ε_1, ε_2, φ
from rayuela.cfg.nonterminal import NT, S
from rayuela.fsa.pathsum import Pathsum, Strategy
from rayuela.fsa.fsa import FSA

def build_ngram_wfsa(ngram_model, n, alphabet, bos='[BOS]', eos='[EOS]'):
    """
    Build a Weighted FSA (Real semiring) capturing the distribution
    of an n-gram model, one-symbol-per-transition style.

    ngram_model: dict => ngram_model[context][symbol] = probability
    n: the 'n' in n-gram
    alphabet: list of symbols
    bos, eos: special tokens
    """
    fsa = FSA(R=Real)

    # define final "absorbing" state for after EOS
    q_final = State("<<FINAL>>")
    fsa.add_state(q_final)
    fsa.set_F(q_final, Real.one)

    # gather contexts
    contexts = list(ngram_model.keys())  # each is (n-1)-tuple or () if n=1
    context2state = {}
    for ctx in contexts:
        ctxName = str(ctx) if ctx else "()"
        q = State(ctxName)
        context2state[ctx] = q
        fsa.add_state(q)

    # start context
    if n>1:
        start_ctx = tuple([bos]*(n-1))
    else:
        start_ctx = ()

    if start_ctx not in context2state:
        qstart = State(str(start_ctx))
        context2state[start_ctx] = qstart
        fsa.add_state(qstart)
    else:
        qstart = context2state[start_ctx]
    fsa.set_I(qstart, Real.one)

    # define transitions
    for ctx in contexts:
        s_from = context2state[ctx]
        dist_dict = ngram_model[ctx]
        for symbol, prob in dist_dict.items():
            if prob <= 1e-15:
                continue
            w = Real(prob)
            if symbol == eos:
                # go to final
                fsa.add_arc(s_from, Sym(symbol), q_final, w)
            else:
                # next context
                if n>1:
                    new_ctx = tuple(list(ctx[1:]) + [symbol]) if len(ctx)==(n-1) else (symbol,)
                else:
                    new_ctx = ()
                if new_ctx not in context2state:
                    qq = State(str(new_ctx))
                    context2state[new_ctx] = qq
                    fsa.add_state(qq)
                s_to = context2state[new_ctx]
                fsa.add_arc(s_from, Sym(symbol), s_to, w)

    return fsa

#############################
# 4) build_kshuffle_ngram_wfsa
#############################

def build_kshuffle_ngram_wfsa(ngram_model, n, k, alphabet, permute_block, bos='[BOS]', eos='[EOS]'):
    """
    Merge an n-gram model with a k-local deterministic shuffle in a single WFSA,
    using a buffer of size k and a 'phase' to read tokens or emit them.

    One-symbol-per-transition approach.

    ngram_model: dict => ngram_model[context][symbol] = probability
    n: order of n-gram
    k: block size
    alphabet: list of symbols
    permute_block: function that permutes a list of length k => new list
    bos, eos: special tokens
    """
    fsa = FSA(R=Real)

    # final absorbing state
    q_final = State("<FINAL>")
    fsa.add_state(q_final)
    fsa.set_F(q_final, Real(1.0))

    # gather contexts
    contexts = list(ngram_model.keys())
    # create a map for states: (ctx, phase, buf) => State
    state_map = {}

    def ensure_state(ctx, phase, buf):
        key = (ctx, phase, buf)
        if key not in state_map:
            st_name = f"{ctx}|ph={phase}|{buf}"
            st = State(st_name)
            state_map[key] = st
            fsa.add_state(st)
        return state_map[key]

    # start context
    if n>1:
        start_ctx = tuple([bos]*(n-1))
    else:
        start_ctx = ()

    # ensure start state
    s0 = ensure_state(start_ctx, 0, ())
    fsa.set_I(s0, Real(1.0))

    # We'll do BFS-like expansion of states
    queue = deque()
    visited = set()

    init_key = (start_ctx, 0, ())
    queue.append(init_key)
    visited.add(init_key)

    max_phase = 2*k

    def get_dist(ctx):
        return ngram_model.get(ctx, {})

    while queue:
        (ctx, phase, buf) = queue.popleft()
        s_from = ensure_state(ctx, phase, buf)

        if phase < k:
            # reading
            dist_dict = get_dist(ctx)
            for symbol, prob in dist_dict.items():
                if prob <= 1e-15:
                    continue
                w = Real(prob)
                if symbol == eos:
                    # transition to final
                    fsa.add_arc(s_from, Sym(symbol), q_final, w)
                else:
                    # next context
                    if n>1:
                        if len(ctx)==(n-1):
                            new_ctx = tuple(list(ctx[1:]) + [symbol])
                        else:
                            new_ctx = (symbol,)
                    else:
                        new_ctx = ()
                    new_buf = tuple(list(buf) + [symbol])
                    new_phase = phase+1
                    to_key = (new_ctx, new_phase, new_buf)
                    if to_key not in visited:
                        visited.add(to_key)
                        queue.append(to_key)
                    s_to = ensure_state(new_ctx, new_phase, new_buf)
                    # emit no symbol => we can treat as epsilon
                    fsa.add_arc(s_from, Sym(''), s_to, w)

        elif phase < 2*k:
            # emission
            idx = phase - k
            if len(buf) == k:
                pblock = permute_block(list(buf))
                if idx < k:
                    emit_symbol = pblock[idx]
                    w = Real(1.0)
                    new_phase = phase+1
                    to_key = (ctx, new_phase, buf)
                    if new_phase == 2*k:
                        # We'll handle resetting after we actually do a transition
                        pass
                    if to_key not in visited:
                        visited.add(to_key)
                        queue.append(to_key)
                    s_to = ensure_state(ctx, new_phase, buf)
                    fsa.add_arc(s_from, Sym(emit_symbol), s_to, w)

        else:
            # phase == 2k => done emitting => reset
            to_key = (ctx, 0, ())
            if to_key not in visited:
                visited.add(to_key)
                queue.append(to_key)
            s_to = ensure_state(ctx, 0, ())
            fsa.add_arc(s_from, Sym(''), s_to, Real(1.0))

    return fsa



In [238]:
"""
Example usage:
    We'll define a small n-gram model, build an n-gram WFSA, then
    build a (k=2)-shuffle version, and do some minimal checks.
"""

# Let's define a toy 2-gram model: (context=1 token) => { next_token: probability }
# Suppose we have an alphabet: [BOS], [EOS], 'a', 'b'
n = 2
bos, eos = '[BOS]', '[EOS]'
alphabet = [bos, eos, 'a', 'b']

# ngram_model:
#   context is (bos,) => { 'a':0.6, 'b':0.4 }
#   context is ('a',) => { 'a':0.3, 'b':0.2, eos:0.5 }
#   context is ('b',) => { 'a':0.1, 'b':0.1, eos:0.8 }
my_ngram_model = {
    (bos,): {'a':0.6, 'b':0.4},
    ('a',): {'a':0.3, 'b':0.2, eos:0.5},
    ('b',): {'a':0.1, 'b':0.1, eos:0.8},
}

# 1) Build the plain n-gram WFSA
fsa_ng = build_ngram_wfsa(my_ngram_model, n, alphabet, bos, eos)
print("N-gram FSA states:", len(fsa_ng.Q))

# 2) Define a block permutation for k=2 => swap
def swap2(block):
    if len(block)<2: return block
    return [block[1], block[0]]

# 3) Build k=2 shuffle WFSA
k=2
fsa_kshuffle = build_kshuffle_ngram_wfsa(my_ngram_model, n, k, alphabet, swap2, bos, eos)
print("k-shuffle FSA states:", len(fsa_kshuffle.Q))

# 4) We can do a naive 'entropy' check
# (Recall the placeholder pathsum => 0.0 in this minimal version.)
print("N-gram FSA approximate entropy:", fsa_ng.entropy())
print("k-shuffle FSA approximate entropy:", fsa_kshuffle.entropy())

# The real rayuela library has advanced pathsum & entropies if the FSA is acyclic
# or if you do a correct lifting to the entropy semiring, etc.
# But structurally, these two FSAs demonstrate the n-gram and n-gram+k-shuffle logic.

N-gram FSA states: 4
k-shuffle FSA states: 18
N-gram FSA approximate entropy: 7.92733
k-shuffle FSA approximate entropy: 7.92733


In [5]:
from rayuela.fsa.sampler import Sampler

fsa_ng_sampler = Sampler(fsa_ng)
fsa_kshuffle_sampler = Sampler(fsa_kshuffle)


In [11]:
fsa_ng_sampler.sample(10)

100%|██████████| 10/10 [00:00<00:00, 1048.44it/s]


['a [EOS]',
 'b [EOS]',
 'b [EOS]',
 'a [EOS]',
 'a a b [EOS]',
 'a [EOS]',
 'b [EOS]',
 'a [EOS]',
 'a b [EOS]',
 'b [EOS]']

In [24]:
fsa_kshuffle_sampler.sample(10)

100%|██████████| 10/10 [00:00<00:00, 855.39it/s]


[' [EOS]',
 ' [EOS]',
 ' [EOS]',
 ' [EOS]',
 ' [EOS]',
 '  a a   [EOS]',
 '  a a    b a    b b  [EOS]',
 ' [EOS]',
 ' [EOS]',
 ' [EOS]']

In [239]:
import numpy as np
from math import isclose
from itertools import product
from collections import defaultdict, deque
from rayuela.fsa.fsa import FSA



class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA over the Real semiring.
    Example usage:
       model = RandomNGramModel(n=3, alpha=1.0, bos='[BOS]', eos='[EOS]')
       # Then model.fsa is your random n-gram FSA
       # You can also call model.build_kshuffle_fsa(...) to get a k-local shuffle version.
    """

    def __init__(self, alphabet: list[str], n=2, alpha=1.0, bos='[BOS]', eos='[EOS]'):
        """
        alphabet: list of symbols
        n: n-gram order
        alpha: Dirichlet concentration parameter
        bos, eos: special boundary tokens
        """
        self.alphabet = alphabet
        self.n = n
        self.alpha = alpha
        self.bos = bos
        self.eos = eos


        # We'll store the random n-gram distribution as an FSA in self.fsa
        self.fsa = FSA(R=Real)

        self.__build_model()

    def __build_model(self):
        """
        Internally builds the random n-gram WFSA.
        We do the following:
          1) define contexts (n-1)-tuples from (alphabet + [BOS]) but not [EOS] in context
          2) For each context, sample a random distribution over (alphabet + [EOS]) with Dirichlet
          3) Add arcs to next context or final state
        """
        # Create a final absorbing state
        q_final = State("<FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        # If n=1, context = ()
        # Otherwise context = (bos, bos, ..., bos) for start
        if self.n <= 1:
            all_contexts = [()]
        else:
            # We'll allow up to (n-1)-length combos from [BOS] + alphabet
            # but exclude [EOS] from context
            possible_symbols = [self.bos] + self.alphabet  # no eos
            # enumerates all (n-1)-tuples
            all_contexts = list(product(possible_symbols, repeat=self.n-1))

        # We'll store them in a dict for convenience
        self.context2state = {}
        for ctx in all_contexts:
            sname = str(ctx) if ctx else "()"
            st = State(sname)
            self.fsa.add_state(st)
            self.context2state[ctx] = st

        # Start context = (bos, bos, ..., bos) if n>1, else ()
        if self.n>1:
            start_ctx = tuple([self.bos]*(self.n-1))
        else:
            start_ctx = ()
        s0 = self.context2state.get(start_ctx, None)
        if not s0:
            # create one if missing
            s0 = State(str(start_ctx))
            self.context2state[start_ctx] = s0
            self.fsa.add_state(s0)
        self.fsa.set_I(s0, Real.one)

        # For each context, we'll sample a distribution over (alphabet + eos)
        # using Dirichlet
        import numpy as np
        from numpy.random import dirichlet

        out_alphabet = self.alphabet + [self.eos]

        for ctx in all_contexts:
            alpha_vec = [self.alpha]*len(out_alphabet)
            probs = dirichlet(alpha_vec)
            # For each symbol, define a transition
            s_from = self.context2state[ctx]
            for sym_idx, pval in enumerate(probs):
                sym = out_alphabet[sym_idx]
                if pval < 1e-15:
                    continue
                w = Real(pval)

                if sym == self.eos:
                    # transition to final
                    self.fsa.add_arc(s_from, Sym(sym), q_final, w)
                else:
                    # next context: shift left + sym
                    if self.n>1:
                        new_ctx = tuple(list(ctx[1:]) + [sym]) if len(ctx)==(self.n-1) else (sym,)
                    else:
                        new_ctx = ()
                    if new_ctx not in self.context2state:
                        # create it on the fly if missing
                        stx = State(str(new_ctx))
                        self.fsa.add_state(stx)
                        self.context2state[new_ctx] = stx
                    s_to = self.context2state[new_ctx]
                    self.fsa.add_arc(s_from, Sym(sym), s_to, w)

    def build_kshuffle_fsa(self, k, permute_block) -> FSA:
        """
        Builds a single FSA that merges:
          - this random n-gram FSA
          - a k-local shuffle (block-based) in a single symbol-per-transition approach.

        The code is similar to the "buffer" approach: states = (context, phase, buffer).
        For demonstration, partial blocks not carefully handled.
        """
        fsa_k = FSA(R=Real)

        # We'll reuse self.fsa's transitions and states as "the base model",
        # but we embed them in a bigger state machine that does the shuffle.

        # final absorbing
        q_final = State("<FINAL-KSHUFFLE>")
        fsa_k.add_state(q_final)
        fsa_k.set_F(q_final, Real.one)

        # We'll do BFS expansions. We store states as (ctx, phase, buf_tuple).
        from collections import deque
        visited = set()
        queue = deque()

        # define a helper for creating states
        def ensure_state(ctx, phase, buf):
            st_name = f"{ctx}|ph={phase}|{buf}"
            st = State(st_name)
            fsa_k.add_state(st)
            return st

        # Start with (start_ctx, 0, ())
        if self.n>1:
            start_ctx = tuple([self.bos]*(self.n-1))
        else:
            start_ctx = ()
        s0 = ensure_state(start_ctx, 0, ())
        fsa_k.set_I(s0, Real.one)
        queue.append( (start_ctx, 0, ()) )
        visited.add( (start_ctx, 0, ()) )

        max_phase = 2*k

        # We'll need to look up transitions from the old self.fsa (the n-gram model).
        # Let's store them in a dict for easy reference:
        # base_arcs[ctx] = list of (symbol, next_ctx, prob).
        # We can discover it by scanning self.fsa, but we also have context2state for states.
        base_arcs = dict()

        # We'll also identify the 'q_final' in the old FSA if it has a final state
        old_final_states = set()
        for q, w in self.fsa.F:
            if w != Real.zero:
                old_final_states.add(q)

        # gather transitions from the old FSA
        inv_map = {v:k for k,v in self.context2state.items()}  # state->context
        for ctx, st in self.context2state.items():
            arcs_list = []
            for a, j, w in self.fsa.arcs(st):
                arcs_list.append((a, j, w))
            base_arcs[ctx] = arcs_list

        while queue:
            (ctx, phase, buf) = queue.popleft()
            s_from = ensure_state(ctx, phase, buf)

            # Check if the old FSA state is final => the old context leads to final?
            st_old = self.context2state.get(ctx)
            if st_old in old_final_states:
                # we interpret that if the base model can end, we also have a transition to the new final
                # with weight 1 if we want. Or you can keep partial block logic. We'll do a direct approach:
                # If phase < k, we can skip, or we add an arc to new final:
                fsa_k.add_arc(s_from, Sym(self.eos), q_final, Real(1.0))

            if phase < k:
                # reading from base model
                arcs_list = base_arcs.get(ctx, [])
                for (a, j, w) in arcs_list:
                    if w == Real.zero:
                        continue
                    sym = str(a)
                    if sym == self.eos:
                        # go to final in the new FSA
                        # weigh by w
                        fsa_k.add_arc(s_from, a, q_final, w)
                    else:
                        # next context
                        # find j's context
                        new_c = inv_map.get(j, None)
                        new_buf = tuple(list(buf) + [sym])
                        new_phase = phase + 1
                        to_key = (new_c, new_phase, new_buf)
                        if to_key not in visited:
                            visited.add(to_key)
                            queue.append(to_key)
                        s_to = ensure_state(new_c, new_phase, new_buf)
                        # emit no symbol => treat as ε
                        fsa_k.add_arc(s_from, Sym(''), s_to, w)

            elif phase < 2*k:
                # emission
                idx = phase - k
                if len(buf) == k:
                    # permute
                    pblock = permute_block(list(buf))
                    if idx < k:
                        out_sym = pblock[idx]
                        new_phase = phase+1
                        w = Real(1.0)
                        to_key = (ctx, new_phase, buf)
                        if to_key not in visited:
                            visited.add(to_key)
                            queue.append(to_key)
                        s_to = ensure_state(ctx, new_phase, buf)
                        fsa_k.add_arc(s_from, Sym(out_sym), s_to, w)

            else:
                # phase == 2k => reset
                to_key = (ctx, 0, ())
                if to_key not in visited:
                    visited.add(to_key)
                    queue.append(to_key)
                s_to = ensure_state(ctx, 0, ())
                fsa_k.add_arc(s_from, Sym(''), s_to, Real.one)

        return fsa_k

In [240]:
rng_model = RandomNGramModel(alphabet=['a','b'], n=3, alpha=0.3, bos='[BOS]', eos='[EOS]')

print("Random Ngram FSA states:", rng_model.fsa.num_states)

# define a k=2 shuffle function
def swap_block(block):
    if len(block)<2:
        return block
    return [block[1], block[0]]

# build a k-shuffle version
fsa_k = rng_model.build_kshuffle_fsa(k=2, permute_block=swap_block)
print("k-shuffle FSA states:", fsa_k.num_states)

# Check stubbed entropy:
print("Random n-gram FSA 'entropy':", rng_model.fsa.entropy())
print("k-shuffle FSA 'entropy':", fsa_k.entropy())


Random Ngram FSA states: 10
k-shuffle FSA states: 24
Random n-gram FSA 'entropy': 8.04627
k-shuffle FSA 'entropy': 8.04627


In [121]:
rng_sampler = Sampler(rng_model.fsa)
kshuffle_sampler = Sampler(fsa_k)

In [124]:
rng_sampler.sample(100)

100%|██████████| 100/100 [00:00<00:00, 1997.67it/s]


['b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'b a a a a [EOS]',
 'b a a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b b [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'b a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'a [EOS]',
 'b a 

In [159]:
class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA (with Real semiring).
    Uses _is_valid_context to skip contexts that contain [EOS]
    or that have a [BOS] reappearing after normal tokens.
    """

    def __init__(self, alphabet, n=2, alpha=1.0, bos='[BOS]', eos='[EOS]'):
        """
        alphabet: list[str]
        n: n-gram order
        alpha: Dirichlet concentration
        bos,eos: boundary tokens
        """
        self.alphabet = alphabet
        self.n = n
        self.alpha = alpha
        self.bos = bos
        self.eos = eos

        self.fsa = FSA(R=Real)
        self._build_ngram_fsa()

    def _is_valid_context(self, ctx):
        """
        Rules:
        - EOS cannot be in context
        - BOS must appear continuously from the left edge, followed by normal characters only
        """
        if self.eos in ctx:
            return False
        saw_normal = False
        for token in ctx:
            if token == self.bos:
                if saw_normal:
                    return False
            else:
                saw_normal = True
        return True

    def _build_ngram_fsa(self):
        """
        1) create final absorbing state
        2) define contexts (n-1)-tuples from [BOS]+alphabet, excluding [EOS] in context,
           only keep them if _is_valid_context(ctx) => True
        3) for each context, sample distribution over (alphabet + [EOS]) => arcs
        """
        # final absorbing
        q_final = State("<FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        # gather contexts
        if self.n <= 1:
            all_contexts = [()]
        else:
            # exclude [EOS] from contexts
            context_alphabet = [a for a in self.alphabet if a != self.eos]
            # enumerates all (n-1)-tuples
            raw_contexts = product(context_alphabet, repeat=self.n-1)
            # filter them with _is_valid_context
            all_contexts = [ctx for ctx in raw_contexts if self._is_valid_context(ctx)]

        self.context2state = {}
        for ctx in all_contexts:
            sname = str(ctx) if ctx else "()"
            st = State(sname)
            self.fsa.add_state(st)
            self.context2state[ctx] = st

        # define start context
        if self.n>1:
            start_ctx = tuple([self.bos]*(self.n-1))
        else:
            start_ctx = ()
        if start_ctx not in self.context2state and self._is_valid_context(start_ctx):
            st0 = State(str(start_ctx))
            self.fsa.add_state(st0)
            self.context2state[start_ctx] = st0

        s0 = self.context2state.get(start_ctx, None)
        if s0 is None:
            # fallback if start_ctx isn't valid => no arcs from start
            s0 = State("<NoValidStart>")
            self.fsa.add_state(s0)
        self.fsa.set_I(s0, Real.one)

        # We'll sample distributions for each context
        out_alphabet = self.alphabet + [self.eos]
        from numpy.random import dirichlet

        for ctx in all_contexts:
            dist = dirichlet([self.alpha]*len(out_alphabet))
            s_from = self.context2state[ctx]
            for i, pval in enumerate(dist):
                if pval < 1e-15:
                    continue
                sym = out_alphabet[i]
                w = Real(pval)
                if sym == self.eos:
                    # arc to final
                    self.fsa.add_arc(s_from, Sym(sym), q_final, w)
                else:
                    # shift context
                    if self.n>1:
                        new_ctx = tuple(list(ctx[1:]) + [sym]) if len(ctx)==(self.n-1) else (sym,)
                    else:
                        new_ctx = ()
                    # must also check if new_ctx is valid => if not, skip
                    if self._is_valid_context(new_ctx):
                        # ensure state
                        if new_ctx not in self.context2state:
                            stN = State(str(new_ctx))
                            self.fsa.add_state(stN)
                            self.context2state[new_ctx] = stN
                        s_to = self.context2state[new_ctx]
                        self.fsa.add_arc(s_from, Sym(sym), s_to, w)

class KShuffleNgram:
    """
    A separate class that merges:
      - A RandomNGramModel's FSA
      - A k-local block shuffle, via a single symbol-per-transition approach.
    """

    def __init__(self, ngram_model: RandomNGramModel, k: int, perturbation_fnc):
        """
        ngram_model: a RandomNGramModel
        k: block size
        perturbation_fnc: function that rearranges a full block
        """
        self.ngram_model = ngram_model
        self.k = k
        self.perturbation_fnc = perturbation_fnc
        self.R = Real

        self.fsa = FSA(R=self.R)
        self._build_kshuffle()

    def _build_kshuffle(self):
        # final
        q_final = State("<KSHUFFLE_FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        # base_fsa
        base_fsa = self.ngram_model.fsa

        # We'll track states as (context, buffer_tuple)
        # BFS approach
        visited = set()
        queue = deque()

        # define start context
        if self.ngram_model.n>1:
            start_ctx = tuple([self.ngram_model.bos]*(self.ngram_model.n-1))
        else:
            start_ctx = ()
        def ensure_state(ctx, buf):
            st_name = f"{ctx}|{buf}"
            st = State(st_name)
            self.fsa.add_state(st)
            return st

        s0 = ensure_state(start_ctx, ())
        self.fsa.set_I(s0, Real.one)
        queue.append((start_ctx, ()))
        visited.add((start_ctx, ()))

        # find base_fsa final states:
        base_final_states = set()
        for q, w in base_fsa.F:
            if w != self.R.zero:
                base_final_states.add(q)

        # We'll need a reverse map for base_fsa context -> state
        # But we have ngram_model.context2state
        inv_map = {v: k for k, v in self.ngram_model.context2state.items()}

        def get_arcs_for_context(ctx):
            """
            Return arcs from the base_fsa for the state that represents ctx.
            Using base_fsa.arcs(...).
            """
            st_base = self.ngram_model.context2state.get(ctx, None)
            if not st_base:
                return []
            results = []
            for (a, j, w) in base_fsa.arcs(st_base, nozero=True, no_eps=True, reverse=False):
                results.append((a, j, w))
            return results

        while queue:
            (ctx, buf) = queue.popleft()
            s_from = ensure_state(ctx, buf)

            # if base context is final => flush leftover, go final
            st_base = self.ngram_model.context2state.get(ctx, None)
            base_is_final = (st_base in base_final_states) if st_base else False
            if base_is_final:
                # flush leftover
                current_st = s_from
                for symL in buf:
                    st_next = ensure_state(ctx, ())
                    self.fsa.add_arc(current_st, Sym(symL), st_next, Real.one)
                    current_st = st_next
                # then arc to <KSHUFFLE_FINAL>
                self.fsa.add_arc(current_st, Sym(self.ngram_model.eos), q_final, Real.one)

            # if buffer < k => we can read more from base_fsa
            if len(buf) < self.k and not base_is_final:
                # arcs from base_fsa
                arcs_list = get_arcs_for_context(ctx)
                for (a, j, w) in arcs_list:
                    if w == self.R.zero:
                        continue
                    symbol = str(a)
                    if symbol == self.ngram_model.eos:
                        # end => flush leftover
                        # do same approach
                        current_st = s_from
                        for symL in buf:
                            st_next = ensure_state(ctx, ())
                            self.fsa.add_arc(current_st, Sym(symL), st_next, Real.one)
                            current_st = st_next
                        # then go final
                        self.fsa.add_arc(current_st, a, q_final, w)
                    else:
                        # read
                        new_ctx = inv_map.get(j, None)
                        if new_ctx is None:
                            # skip if no known context
                            continue
                        new_buf = tuple(list(buf) + [symbol])
                        to_key = (new_ctx, new_buf)
                        if to_key not in visited:
                            visited.add(to_key)
                            queue.append(to_key)
                        s_to = ensure_state(new_ctx, new_buf)
                        # emit epsilon
                        self.fsa.add_arc(s_from, Sym(''), s_to, w)

            # if buffer == k => we output the block in permuted order
            if len(buf) == self.k:
                # do single-step chain
                pblock = self.perturbation_fnc(list(buf))
                current_st = s_from
                for tok in pblock:
                    mid_st = ensure_state(ctx, buf)
                    self.fsa.add_arc(current_st, Sym(tok), mid_st, Real.one)
                    current_st = mid_st
                # after that => empty buffer
                final_st = ensure_state(ctx, ())
                self.fsa.add_arc(current_st, Sym(''), final_st, Real.one)
                # queue next
                if (ctx, ()) not in visited:
                    visited.add((ctx, ()))
                    queue.append((ctx, ()))


In [181]:
# Suppose we have an alphabet
alphabet = ["[BOS]", "[EOS]", "a", "b"]

# Build a random 2-gram model
rng_model = RandomNGramModel(alphabet=alphabet, n=2, alpha=1.0, bos="[BOS]", eos="[EOS]")
print("Random N-gram FSA states:", rng_model.fsa.num_states)

# define a simple block-permutation function
def swap2(block):
    if len(block)==2:
        return [block[1], block[0]]
    return block

# Build KShuffleNgram
kshuf = KShuffleNgram(ngram_model=rng_model, k=2, perturbation_fnc=swap2)
print("k-shuffle FSA states:", kshuf.fsa.num_states)

# Show stub entropies
print("Base FSA 'entropy':", rng_model.fsa.entropy())
print("K-shuffle FSA 'entropy':", kshuf.fsa.entropy())

Random N-gram FSA states: 4
k-shuffle FSA states: 16
Base FSA 'entropy': 24.71265
K-shuffle FSA 'entropy': 0.19565


In [223]:
class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA (with Real semiring),
    WITHOUT requiring user to pass [BOS] or [EOS] in the 'alphabet'.
    We'll define them as internal tokens.
    """

    _BOS = Sym("<BOS>")
    _EOS = Sym("<EOS>")

    def __init__(self,
                 alphabet,         # user-supplied normal symbols (no BOS/EOS)
                 n=2,
                 alpha=1.0):
        """
        alphabet: list of normal symbols (no boundary tokens)
        n: n-gram order
        alpha: Dirichlet concentration
        """
        self.alphabet = alphabet  # user symbols only
        self.n = n
        self.alpha = alpha

        self.fsa = FSA(R=Real)
        self._build_ngram_fsa()

    def _is_valid_context(self, ctx):
        """
        Let's define:
        - <EOS> cannot be in context
        - <BOS> must appear from left to right with no normal tokens before it
        (Similar logic to the older snippet)
        But if you want simpler logic, you can skip or adjust.
        """
        if self._EOS in ctx:
            return False
        saw_normal = False
        for token in ctx:
            if token == self._BOS:
                if saw_normal:
                    return False
            else:
                saw_normal = True
        return True

    def _build_ngram_fsa(self):
        """
        1) define final absorbing state
        2) define contexts: (n-1)-tuples from {<BOS>}+alphabet (no <EOS> in context),
           filter them with _is_valid_context
        3) for each context, sample distribution over alphabet + <EOS>
        4) arcs to next context or final
        """
        q_final = State("<<FINAL>>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        # define possible for contexts
        if self.n <= 1:
            all_contexts = [()]
        else:
            context_syms = [self._BOS] + self.alphabet  # no <EOS> here
            raw = product(context_syms, repeat=self.n-1)
            all_contexts = [ctx for ctx in raw if self._is_valid_context(ctx)]

        # create states for each context
        self.context2state = {}
        for ctx in all_contexts:
            st = State(str(ctx) if ctx else "()")
            self.fsa.add_state(st)
            self.context2state[ctx] = st

        # define start context = (<BOS>, ... <BOS>) if n>1 else ()
        if self.n>1:
            start_ctx = tuple([self._BOS]*(self.n-1))
        else:
            start_ctx = ()

        # if not valid, no arcs from start
        if self._is_valid_context(start_ctx):
            if start_ctx not in self.context2state:
                s0_ = State(str(start_ctx))
                self.fsa.add_state(s0_)
                self.context2state[start_ctx] = s0_
            s0 = self.context2state[start_ctx]
        else:
            s0 = State("<NoValidStart>")
            self.fsa.add_state(s0)

        self.fsa.set_I(s0, Real.one)

        # for each context, sample distribution over (alphabet + <EOS>)
        out_syms = list(self.alphabet) + [self._EOS]
        from numpy.random import dirichlet

        for ctx in all_contexts:
            dist = dirichlet([self.alpha]*len(out_syms))
            s_from = self.context2state[ctx]
            for i, pval in enumerate(dist):
                if pval < 1e-15:
                    continue
                sym = out_syms[i]
                w = Real(pval)
                if sym == self._EOS:
                    # arc to q_final
                    self.fsa.add_arc(s_from, sym, q_final, w)
                else:
                    # shift context
                    if self.n>1:
                        new_ctx = tuple(list(ctx[1:]) + [sym]) if len(ctx)==(self.n-1) else (sym,)
                    else:
                        new_ctx = ()

                    if self._is_valid_context(new_ctx):
                        if new_ctx not in self.context2state:
                            st_new = State(str(new_ctx) if new_ctx else "()")
                            self.fsa.add_state(st_new)
                            self.context2state[new_ctx] = st_new
                        s_to = self.context2state[new_ctx]
                        self.fsa.add_arc(s_from, sym, s_to, w)

###############################################################################
# (D) KShuffleNgram for optional k-local block shuffle
###############################################################################

class KShuffleNgram:
    """
    Takes the ngram_model's FSA and merges in a k-local shuffle approach.
    """

    def __init__(self, ngram_model: RandomNGramModel, k: int, perturbation_fnc):
        self.ngram_model = ngram_model
        self.k = k
        self.perturbation_fnc = perturbation_fnc
        self.R = Real

        self.fsa = FSA(R=self.R)
        self._build_kshuffle()

    def _build_kshuffle(self):
        q_final = State("<KSHUFFLE_FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        base_fsa = self.ngram_model.fsa

        visited = set()
        queue = deque()

        # start context: same logic
        if self.ngram_model.n>1:
            start_ctx = tuple([self.ngram_model._BOS]*(self.ngram_model.n-1))
        else:
            start_ctx = ()

        def ensure_state(ctx, buf):
            name = f"{ctx}|{buf}"
            st = State(name)
            self.fsa.add_state(st)
            return st

        s0 = ensure_state(start_ctx, ())
        self.fsa.set_I(s0, Real.one)
        visited.add((start_ctx, ()))
        queue.append((start_ctx, ()))

        base_final_states = set()
        for (q, w) in base_fsa.F:
            if w != self.R.zero:
                base_final_states.add(q)

        # We invert the ngram_model's context2state dict for arcs
        inv_map = {v: k for (k,v) in self.ngram_model.context2state.items()}

        def arcs_for_ctx(ctx):
            # Return arcs from base_fsa for the state representing 'ctx'
            st_base = self.ngram_model.context2state.get(ctx, None)
            if not st_base:
                return []
            results = []
            # gather arcs from st_base
            for (a, j, w) in base_fsa.arcs(st_base, nozero=True, no_eps=True):
                results.append((a,j,w))
            return results

        while queue:
            (ctx, buf) = queue.popleft()
            s_from = ensure_state(ctx, buf)

            st_base = self.ngram_model.context2state.get(ctx, None)
            base_is_final = (st_base in base_final_states) if st_base else False

            if base_is_final:
                # flush leftover => final
                current_st = s_from
                for leftover_tok in buf:
                    mid_st = ensure_state(ctx, ())
                    self.fsa.add_arc(current_st, Sym(leftover_tok), mid_st, Real.one)
                    current_st = mid_st
                self.fsa.add_arc(current_st, Sym(self.ngram_model._EOS), q_final, Real.one)

            # read if buffer<k
            if len(buf)<self.k and not base_is_final:
                A = arcs_for_ctx(ctx)
                for (a,j,w) in A:
                    if w == self.R.zero:
                        continue
                    symbol = str(a)
                    if a == self.ngram_model._EOS:
                        # flush leftover => final
                        cur_st = s_from
                        for tok_ in buf:
                            mid_st = ensure_state(ctx, ())
                            self.fsa.add_arc(cur_st, Sym(tok_), mid_st, Real.one)
                            cur_st = mid_st
                        self.fsa.add_arc(cur_st, a, q_final, w)
                    else:
                        new_ctx = inv_map.get(j, None)
                        if new_ctx is not None:
                            new_buf = tuple(list(buf) + [symbol])
                            if (new_ctx,new_buf) not in visited:
                                visited.add((new_ctx,new_buf))
                                queue.append((new_ctx,new_buf))
                            s_to = ensure_state(new_ctx, new_buf)
                            # epsilon
                            self.fsa.add_arc(s_from, Sym(''), s_to, w)

            # if buffer==k => output block
            if len(buf)==self.k:
                pblock = self.perturbation_fnc(list(buf))
                current_st = s_from
                for out_sym in pblock:
                    mid_st = ensure_state(ctx, buf)
                    self.fsa.add_arc(current_st, Sym(out_sym), mid_st, Real.one)
                    current_st = mid_st
                # reset buffer
                final_st = ensure_state(ctx, ())
                self.fsa.add_arc(current_st, Sym(''), final_st, Real.one)
                if (ctx,()) not in visited:
                    visited.add((ctx,()))
                    queue.append((ctx,()))

###############################################################################
# Example usage
###############################################################################
if __name__=="__main__":
    # Suppose the user has an alphabet with no boundary tokens
    user_alphabet = ['a','b']

    # Build a random 2-gram model
    rng_model = RandomNGramModel(alphabet=user_alphabet, n=2, alpha=1.0)
    print("Random NGram FSA states:", rng_model.fsa.num_states)

    # define a simple block-permutation
    def swap2(block):
        if len(block)==2:
            return [block[1], block[0]]
        return block

    # build the k-shuffle version
    kmodel = KShuffleNgram(rng_model, k=2, perturbation_fnc=swap2)
    print("K-shuffle FSA states:", kmodel.fsa.num_states)

    # show stubbed entropies
    print("Base ngram FSA entropy:", rng_model.fsa.entropy())
    print("K-shuffle FSA entropy:", kmodel.fsa.entropy())

Random NGram FSA states: 4
K-shuffle FSA states: 10
Base ngram FSA entropy: 18.85219
K-shuffle FSA entropy: 0.27923


In [200]:
sampler = Sampler(rng_model.fsa)
sampler.sample(10)


100%|██████████| 10/10 [00:00<00:00, 656.84it/s]


['a a b a a b <EOS>',
 'b <EOS>',
 '<EOS>',
 'a b <EOS>',
 'a b <EOS>',
 'a b a b <EOS>',
 '<EOS>',
 'a b <EOS>',
 'c a b <EOS>',
 'a a b b c a b b b a a c <EOS>']

In [406]:
from rayuela.base.symbol import EOS, BOS, ε

class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA (with Real semiring),
    with internal BOS, EOS tokens not in the user's alphabet.
    """
    def __init__(self,
                 alphabet,  # user-supplied normal symbols
                 n=2,
                 alpha=1.0):
        """
        alphabet: list of normal symbols (no boundary tokens)
        n: n-gram order
        alpha: Dirichlet concentration
        """
        self.alphabet = alphabet  # user symbols only
        self.n = n
        self.alpha = alpha

        self.fsa = FSA(R=Real)
        self.context2state = {}  # to store (n-1)-tuple -> State
        self._build_ngram_fsa()

    def _is_valid_context(self, ctx):
        """
        If we want to prevent EOS in context
        and not have BOS reappear after normal tokens:
        """
        if EOS in ctx:
            return False
        saw_normal = False
        for token in ctx:
            if token == BOS:
                if saw_normal:
                    return False
            else:
                saw_normal = True
        return True

    def _build_ngram_fsa(self):
        """
        1) define final absorbing state
        2) define contexts: (n-1)-tuples from {BOS}+alphabet, skip EOS in context
        3) sample dist over (alphabet + EOS), add arcs
        """
        q_final = State("<<FINAL>>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        if self.n <= 1:
            all_ctxs = [()]
        else:
            # possible context symbols
            csyms = [BOS] + self.alphabet  # not EOS
            raw = product(csyms, repeat=self.n-1)
            all_ctxs = [r for r in raw if self._is_valid_context(r)]

        # build states
        for ctx in all_ctxs:
            sname = str(ctx) if ctx else "()"
            st = State(sname)
            self.fsa.add_state(st)
            self.context2state[ctx] = st

        # start ctx
        if self.n>1:
            start_ctx = tuple([BOS]*(self.n-1))
        else:
            start_ctx = ()
        if self._is_valid_context(start_ctx) and start_ctx not in self.context2state:
            st0_ = State(str(start_ctx) if start_ctx else "()")
            self.fsa.add_state(st0_)
            self.context2state[start_ctx] = st0_

        s0 = self.context2state.get(start_ctx, State("<INVALIDSTART>"))
        self.fsa.add_state(s0)
        self.fsa.set_I(s0, Real.one)

        # sample distributions
        out_syms = list(self.alphabet) + [EOS]
        from numpy.random import dirichlet

        for ctx in all_ctxs:
            dist = dirichlet([self.alpha]*len(out_syms))
            s_from = self.context2state[ctx]
            for i, pval in enumerate(dist):
                if pval<1e-15:
                    continue
                sym = out_syms[i]
                w = Real(pval)
                if sym==EOS:
                    # arc to final
                    self.fsa.set_arc(s_from, sym, q_final, w)
                else:
                    if self.n>1:
                        new_ctx = tuple(list(ctx[1:]) + [sym]) if len(ctx)==(self.n-1) else (sym,)
                    else:
                        new_ctx = ()

                    if self._is_valid_context(new_ctx):
                        if new_ctx not in self.context2state:
                            stN = State(str(new_ctx) if new_ctx else "()")
                            self.fsa.add_state(stN)
                            self.context2state[new_ctx] = stN
                        s_to = self.context2state[new_ctx]
                        self.fsa.set_arc(s_from, sym, s_to, w)

class KShuffleNgram:
    """
    BFS-labeled approach: states=(ctx,buf).
    When we reach a full block or leftover, we output it symbol-by-symbol to
    EXACTLY ONE ephemeral pivot state, then from pivot -> next BFS-labeled.
    This ensures we don't do arcs from (ctx,buf) to itself for each symbol.
    Also we use set_arc to avoid doubling if BFS hits the same arc multiple times.
    """

    def __init__(self, ngram_model: RandomNGramModel, k: int, perturbation_fnc):
        self.ngram_model = ngram_model
        self.k = k
        self.perturbation_fnc = perturbation_fnc
        self.R = Real

        self.fsa = FSA(R=self.R)
        self._build_kshuffle()

    def _build_kshuffle(self):
        base_fsa = self.ngram_model.fsa

        q_final = State("<KSHUFFLE_FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final)

        queue = deque()
        visited = set()

        if self.ngram_model.n>1:
            start_ctx = tuple([BOS]*(self.ngram_model.n-1))
        else:
            start_ctx = ()

        def ensure_state(ctx, buf):
            sname = f"{ctx}|{buf}"
            st = State(sname)
            self.fsa.add_state(st)
            return st

        s0 = ensure_state(start_ctx, ())
        self.fsa.set_I(s0)
        visited.add((start_ctx, ()))
        queue.append((start_ctx, ()))

        # final states in base_fsa
        base_final = set()
        for (qq,ww) in base_fsa.F:
            if ww!=self.R.zero:
                base_final.add(qq)
        inv_map = {v:k for (k,v) in self.ngram_model.context2state.items()}

        def arcs_for_ctx(ctx):
            st_base = self.ngram_model.context2state.get(ctx,None)
            if st_base is None:
                return []
            results=[]
            for (a,j,w) in base_fsa.arcs(st_base, nozero=True, no_eps=True):
                results.append((a,j,w))
            return results

        while queue:
            (ctx, buf) = queue.popleft()
            s_from = ensure_state(ctx, buf)

            st_base = self.ngram_model.context2state.get(ctx,None)
            if st_base in base_final or st_base is None:
                raise ValueError("Base final state reached")

            # 2) If buffer < k => read arcs from base
            if len(buf)<self.k:
                A = arcs_for_ctx(ctx)
                for (a,j,w) in A:
                    if w==self.R.zero:
                        continue
                    if a==EOS:
                        # partial leftover
                        leftover_list = list(buf)
                        pblock = self.perturbation_fnc(leftover_list)
                        if pblock:
                            comb = "".join(pblock)
                            pivot = State(f"PARTIAL_EMIT{ctx}_{buf}")
                            self.fsa.add_state(pivot)
                            self.fsa.set_arc(s_from, Sym(comb), pivot, self.R.one)
                            self.fsa.set_arc(pivot, Sym(str(a)), q_final, w)
                        else:
                            self.fsa.set_arc(s_from, Sym(str(a)), q_final, w)
                    else:
                        symbol = str(a)
                        new_ctx = inv_map.get(j,None)
                        if new_ctx is None:
                            continue
                        new_buf = tuple(list(buf)+[symbol])
                        if (new_ctx,new_buf) not in visited:
                            visited.add((new_ctx,new_buf))
                            queue.append((new_ctx,new_buf))
                        s_to = ensure_state(new_ctx,new_buf)
                        # epsilon
                        self.fsa.set_arc(s_from, ε, s_to, w)
                continue

            # 3) else => buffer==k => full block
            if len(buf)==self.k:
                pblock = self.perturbation_fnc(list(buf))
                comb = "".join(pblock)
                pivot = State(f"FULL_EMIT{ctx}_{buf}")
                self.fsa.add_state(pivot)
                self.fsa.set_arc(s_from, Sym(comb), pivot, self.R.one)

                # now pivot => (ctx,())
                final_st = ensure_state(ctx,())
                self.fsa.set_arc(pivot, ε, final_st, self.R.one)

                if (ctx,()) not in visited:
                    visited.add((ctx,()))
                    queue.append((ctx,()))
                continue




Base n-gram FSA states: 14
K-shuffle FSA states: 53
Base n-gram FSA entropy: 23.57117
K-shuffle FSA entropy: 23.57117


## working implementation

In [49]:
from __future__ import annotations

import copy
from collections import Counter
from collections import defaultdict as dd
from collections import deque
from itertools import product
from typing import List, Tuple


from rayuela.base.semiring import Boolean, ProductSemiring, Real, Semiring
from rayuela.base.state import PairState, State
from rayuela.base.symbol import Sym, ε
from rayuela.fsa.pathsum import Pathsum
from rayuela.fsa.fsa import FSA

from rayuela.base.symbol import EOS, BOS, ε
from rayuela.fsa.sampler import Sampler


class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA (with Real semiring),
    with internal BOS, EOS tokens not in the user's alphabet.
    """
    def __init__(self,
                 alphabet,  # user-supplied normal symbols
                 n=2,
                 alpha=1.0):
        """
        alphabet: list of normal symbols (no boundary tokens)
        n: n-gram order
        alpha: Dirichlet concentration
        """
        self.alphabet = alphabet  # user symbols only
        self.n = n
        self.alpha = alpha

        self.fsa = FSA(R=Real)
        self.context2state = {}  # to store (n-1)-tuple -> State
        self._build_ngram_fsa()

    def _is_valid_context(self, ctx):
        """
        If we want to prevent EOS in context
        and not have BOS reappear after normal tokens:
        """
        if EOS in ctx:
            return False
        saw_normal = False
        for token in ctx:
            if token == BOS:
                if saw_normal:
                    return False
            else:
                saw_normal = True
        return True

    def _build_ngram_fsa(self):
        """
        1) define final absorbing state
        2) define contexts: (n-1)-tuples from {BOS}+alphabet, skip EOS in context
        3) sample dist over (alphabet + EOS), add arcs
        """
        q_final = State("<<FINAL>>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        if self.n <= 1:
            all_ctxs = [()]
        else:
            # possible context symbols
            csyms = [BOS] + self.alphabet  # not EOS
            raw = product(csyms, repeat=self.n-1)
            all_ctxs = [r for r in raw if self._is_valid_context(r)]

        # build states
        for ctx in all_ctxs:
            sname = str(ctx) if ctx else "()"
            st = State(sname)
            self.fsa.add_state(st)
            self.context2state[ctx] = st

        # start ctx
        if self.n>1:
            start_ctx = tuple([BOS]*(self.n-1))
        else:
            start_ctx = ()
        if self._is_valid_context(start_ctx) and start_ctx not in self.context2state:
            st0_ = State(str(start_ctx) if start_ctx else "()")
            self.fsa.add_state(st0_)
            self.context2state[start_ctx] = st0_

        s0 = self.context2state.get(start_ctx, State("<INVALIDSTART>"))
        self.fsa.add_state(s0)
        self.fsa.set_I(s0, Real.one)

        # sample distributions
        out_syms = list(self.alphabet) + [EOS]
        from numpy.random import dirichlet

        for ctx in all_ctxs:
            dist = dirichlet([self.alpha]*len(out_syms))
            s_from = self.context2state[ctx]
            for i, pval in enumerate(dist):
                if pval<1e-15:
                    continue
                sym = out_syms[i]
                w = Real(pval)
                if sym==EOS:
                    # arc to final
                    self.fsa.set_arc(s_from, sym, q_final, w)
                else:
                    if self.n>1:
                        new_ctx = tuple(list(ctx[1:]) + [sym]) if len(ctx)==(self.n-1) else (sym,)
                    else:
                        new_ctx = ()

                    if self._is_valid_context(new_ctx):
                        if new_ctx not in self.context2state:
                            stN = State(str(new_ctx) if new_ctx else "()")
                            self.fsa.add_state(stN)
                            self.context2state[new_ctx] = stN
                        s_to = self.context2state[new_ctx]
                        self.fsa.set_arc(s_from, sym, s_to, w)


class KShuffleNgram:
    """
    Naive approach: states=(ctx,buf).
    => We store partially read block in 'buf'.
    => If we must emit that block (full or leftover), we create a small chain
       of ephemeral states, each responsible for a single symbol output.

    That is "one symbol per transition" for each flush.
    """

    def __init__(self, ngram_model: RandomNGramModel, k: int, perturbation_fnc):
        self.ngram_model = ngram_model
        self.k = k
        self.perturbation_fnc = perturbation_fnc
        self.R = Real

        self.fsa = FSA(R=self.R)
        self._build_kshuffle()

    def _build_kshuffle(self):
        base_fsa = self.ngram_model.fsa

        # new final
        q_final = State("<KSHUFFLE_FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        queue = deque()
        visited = set()

        if self.ngram_model.n>1:
            start_ctx = tuple([BOS]*(self.ngram_model.n-1))
        else:
            start_ctx = ()

        def ensure_state(ctx, buf):
            sname = f"{ctx}|{buf}"
            st = State(sname)
            self.fsa.add_state(st)
            return st

        s0 = ensure_state(start_ctx, ())
        self.fsa.set_I(s0, Real.one)
        visited.add((start_ctx, ()))
        queue.append((start_ctx, ()))

        # final states in base_fsa
        base_final = set()
        for (qq,ww) in base_fsa.F:
            if ww!=self.R.zero:
                base_final.add(qq)

        inv_map = {v:k for (k,v) in self.ngram_model.context2state.items()}

        # gather arcs from base FSA for a given context
        def arcs_for_ctx(ctx):
            st_base = self.ngram_model.context2state.get(ctx,None)
            if st_base is None:
                return []
            results=[]
            for (a,j,w) in base_fsa.arcs(st_base, no_eps=True, nozero=True, reverse=False):
                results.append((a,j,w))
            return results

        while queue:
            (ctx, buf) = queue.popleft()
            s_from = ensure_state(ctx, buf)

            st_base = self.ngram_model.context2state.get(ctx, None)
            base_is_final = (st_base in base_final) if st_base else False

            # If base is final => leftover flush => ephemeral chain => final
            if base_is_final:
                leftover_list = list(buf)
                pblock = self.perturbation_fnc(leftover_list)
                self._emit_symbol_chain(s_from, pblock, q_final, EOSweight=Real.one)
                continue

            # If buffer<k => read from base arcs
            if len(buf)<self.k:
                A = arcs_for_ctx(ctx)
                for (a,j,w) in A:
                    if a==EOS:
                        # partial leftover flush
                        leftover_list = list(buf)
                        pblock = self.perturbation_fnc(leftover_list)
                        self._emit_symbol_chain(s_from, pblock, q_final, symbol=a, weight=w)
                    else:
                        symbol = str(a)
                        new_ctx = inv_map.get(j,None)
                        if new_ctx is None:
                            continue
                        new_buf = tuple(list(buf)+[symbol])
                        if (new_ctx,new_buf) not in visited:
                            visited.add((new_ctx,new_buf))
                            queue.append((new_ctx,new_buf))
                        s_to = ensure_state(new_ctx, new_buf)
                        self.fsa.set_arc(s_from, ε, s_to, w)
                continue

            # else => buffer==k => full block flush
            if len(buf)==self.k:
                pblock = self.perturbation_fnc(list(buf))
                # ephemeral chain => next BFS-labeled state = (ctx,())
                next_state = ensure_state(ctx, ())
                if (ctx,()) not in visited:
                    visited.add((ctx,()))
                    queue.append((ctx,()))
                self._emit_symbol_chain(s_from, pblock, next_state, EOSweight=Real.one, eplabel=ε)
                continue

    def _emit_symbol_chain(self, s_from:State, syms:List[str], final_st:State,
                           symbol: Sym = None, weight=None, EOSweight=None, eplabel=ε):
        """
        Creates ephemeral chain from s_from for each symbol in 'syms' (one symbol per transition).
        Then from the last ephemeral state => final_st with either the 'symbol' (like <EOS>) or eplabel.
        'weight' is used for that final transition if provided, else Real.one.
        'EOSweight' is used for ephemeral transitions if not provided, default Real.one
        'eplabel' for the final transition if no symbol is left.

        This ensures single symbol per transition PFSA for the flush.
        """
        if EOSweight is None:
            EOSweight = self.R.one
        curr_st = s_from
        # for each symbol => ephemeral state
        for i, sym in enumerate(syms):
            e_name = f"EMIT_{s_from}_{i}_{sym}"
            e_st = State(e_name)
            self.fsa.add_state(e_st)
            # connect curr_st -sym-> e_st
            self.fsa.set_arc(curr_st, Sym(sym), e_st, EOSweight)
            curr_st = e_st

        # now from curr_st => final_st
        # if we have an actual 'symbol' to emit (like <EOS>):
        if symbol is not None:
            self.fsa.set_arc(curr_st, symbol, final_st, weight if weight else self.R.one)
        else:
            # else we do eplabel => final
            self.fsa.set_arc(curr_st, eplabel, final_st, weight if weight else self.R.one)


###############################################################################
# Usage
###############################################################################

if __name__=="__main__":
    user_alphabet = ['a','b']
    rng_model = RandomNGramModel(alphabet=user_alphabet, n=3, alpha=0.4)
    print("Base n-gram FSA states:", rng_model.fsa.num_states)

    # left rotate
    def left_rotate(lst):
        if len(lst) == 0:
            return lst
        return lst[1:] + [lst[0]]

    kmodel = KShuffleNgram(rng_model, k=3, perturbation_fnc=left_rotate)
    print("K-shuffle FSA states:", kmodel.fsa.num_states)

    print("Base n-gram FSA entropy:", rng_model.fsa.entropy())
    print("K-shuffle FSA entropy:", kmodel.fsa.entropy())



Base n-gram FSA states: 8
K-shuffle FSA states: 62
Base n-gram FSA entropy: 34.45959


K-shuffle FSA entropy: 34.45959


In [66]:
kmodel.fsa

In [53]:
from rayuela.fsa.transformer import Transformer

rng_model_non_eps = rng_model.fsa.push().epsremove().normalize()
kshuffle_model_non_eps = kmodel.fsa.push().epsremove()
print("Non-EPS states (base):", rng_model_non_eps.num_states)
print("Non-EPS states (kshuffle):", kshuffle_model_non_eps.num_states)
print("Base n-gram FSA entropy:", rng_model_non_eps.entropy())
print("K-shuffle FSA entropy:", kshuffle_model_non_eps.entropy())


Non-EPS states (base): 8
Non-EPS states (kshuffle): 62
Base n-gram FSA entropy: 34.45959
K-shuffle FSA entropy: 157.56298


In [55]:
kmodel.fsa.push().entropy()

34.45959

In [56]:
kmodel.fsa.push().epsremove().entropy()

157.56298

In [57]:
def partition(fsa, partition_symbol: Sym = ε) -> Tuple[FSA, FSA]:
    """Partition FSA into two
    (one with arcs of the partition symbol and one with all others)

    Args:
        fsa (FSA): The input FSA
        partition_symbol (Sym, optional): The symbol based on which to
        partition the input FSA

    Returns:
        Tuple[FSA, FSA]: The FSA with non-partition symbol arcs
                            and the FSA with only the partition symbol arcs
    """

    E = fsa.spawn()
    N = fsa.spawn(keep_init=True, keep_final=True)

    for q in fsa.Q:
        E.add_state(q)
        N.add_state(q)

    for i in fsa.Q:
        for a, j, w in fsa.arcs(i):
            if a == partition_symbol:
                E.add_arc(i, a, j, w)
            else:
                N.add_arc(i, a, j, w)

    return N, E

@staticmethod
def epsremoval(fsa):
    # note that N keeps same initial and final weights
    N, E = Transformer.partition(fsa)
    W = Pathsum(E).lehmann(zero=False)

    for i in fsa.Q:
        for a, j, w in fsa.arcs(i, no_eps=True):
            print(a)
            for k in fsa.Q:
                N.add_arc(i, a, k, w * W[j, k])

    # additional initial states
    for i, j in product(fsa.Q, repeat=2):
        N.add_I(j, fsa.λ[i] * W[i, j])

    return N

In [63]:
fsa = kmodel.fsa
E = fsa.spawn()
N = fsa.spawn(keep_init=True, keep_final=True)

for q in fsa.Q:
    E.add_state(q)
    N.add_state(q)

for i in fsa.Q:
    for a, j, w in fsa.arcs(i):
        if a == ε:
            E.add_arc(i, a, j, w)
        else:
            N.add_arc(i, a, j, w)

In [64]:
kshuffle_model_non_eps = kshuffle_model_non_eps.epsremove()
kshuffle_model_non_eps.entropy()

157.56298

In [28]:
kshuffle_model_non_eps = kshuffle_model_non_eps.normalize()
kshuffle_model_non_eps.entropy()

  return Real(1.0 / self.value)


64.38882

In [25]:
kshuffle_model_non_eps.entropy()

64.38882

In [496]:
kshuffle_model_non_eps_sampler = Sampler(kshuffle_model_non_eps)
print(["".join([i for i in item if i != "ε"]) for item in kshuffle_model_non_eps_sampler.sample(10, sep='')])

100%|██████████| 10/10 [00:00<00:00, 566.98it/s]

['', 'aa', 'ccba', 'cacbba', 'a', 'abbacccbb', 'cccba', '', 'bbba', 'cabaccaa']





In [49]:

import numpy as np
from math import isclose
from __future__ import annotations

import copy
from collections import Counter
from collections import defaultdict as dd
from collections import deque
from itertools import product
from typing import Callable, Dict, Generator, List, Optional, Sequence, Set, Tuple, Type, Union

import numpy as np
from frozendict import frozendict

import rayuela
from rayuela.base.semiring import Boolean, ProductSemiring, Real, Semiring
from rayuela.base.state import PairState, State
from rayuela.base.symbol import Expr, Sym, ε, ε_1, ε_2, φ
from rayuela.cfg.nonterminal import NT, S
from rayuela.fsa.pathsum import Pathsum, Strategy
from rayuela.fsa.fsa import FSA

from rayuela.base.symbol import EOS, BOS, ε
from rayuela.fsa.sampler import Sampler


class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA (with Real semiring),
    with internal BOS, EOS tokens not in the user's alphabet.
    """
    def __init__(self,
                 alphabet,  # user-supplied normal symbols
                 n=2,
                 alpha=1.0):
        """
        alphabet: list of normal symbols (no boundary tokens)
        n: n-gram order
        alpha: Dirichlet concentration
        """
        self.alphabet = alphabet  # user symbols only
        self.n = n
        self.alpha = alpha

        self.fsa = FSA(R=Real)
        self.context2state = {}  # to store (n-1)-tuple -> State
        self._build_ngram_fsa()

    def _is_valid_context(self, ctx):
        """
        If we want to prevent EOS in context
        and not have BOS reappear after normal tokens:
        """
        if EOS in ctx:
            return False
        saw_normal = False
        for token in ctx:
            if token == BOS:
                if saw_normal:
                    return False
            else:
                saw_normal = True
        return True

    def _build_ngram_fsa(self):
        """
        1) define final absorbing state
        2) define contexts: (n-1)-tuples from {BOS}+alphabet, skip EOS in context
        3) sample dist over (alphabet + EOS), add arcs
        """
        q_final = State("<<FINAL>>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        if self.n <= 1:
            all_ctxs = [()]
        else:
            # possible context symbols
            csyms = [BOS] + self.alphabet  # not EOS
            raw = product(csyms, repeat=self.n-1)
            all_ctxs = [r for r in raw if self._is_valid_context(r)]

        # build states
        for ctx in all_ctxs:
            sname = str(ctx) if ctx else "()"
            st = State(sname)
            self.fsa.add_state(st)
            self.context2state[ctx] = st

        # start ctx
        if self.n>1:
            start_ctx = tuple([BOS]*(self.n-1))
        else:
            start_ctx = ()
        if self._is_valid_context(start_ctx) and start_ctx not in self.context2state:
            st0_ = State(str(start_ctx) if start_ctx else "()")
            self.fsa.add_state(st0_)
            self.context2state[start_ctx] = st0_

        s0 = self.context2state.get(start_ctx, State("<INVALIDSTART>"))
        self.fsa.add_state(s0)
        self.fsa.set_I(s0, Real.one)

        # sample distributions
        out_syms = list(self.alphabet) + [EOS]
        from numpy.random import dirichlet

        for ctx in all_ctxs:
            dist = dirichlet([self.alpha]*len(out_syms))
            s_from = self.context2state[ctx]
            for i, pval in enumerate(dist):
                if pval<1e-15:
                    continue
                sym = out_syms[i]
                w = Real(pval)
                if sym==EOS:
                    # arc to final
                    self.fsa.set_arc(s_from, sym, q_final, w)
                else:
                    if self.n>1:
                        new_ctx = tuple(list(ctx[1:]) + [sym]) if len(ctx)==(self.n-1) else (sym,)
                    else:
                        new_ctx = ()

                    if self._is_valid_context(new_ctx):
                        if new_ctx not in self.context2state:
                            stN = State(str(new_ctx) if new_ctx else "()")
                            self.fsa.add_state(stN)
                            self.context2state[new_ctx] = stN
                        s_to = self.context2state[new_ctx]
                        self.fsa.set_arc(s_from, sym, s_to, w)


class KShuffleNgram:
    """
    BFS-labeled approach: states=(ctx,buf).
    => We store partially read block in 'buf'.
    => If we must emit that block (full or leftover), we create a small chain
       of ephemeral states, each responsible for a single symbol output.

    That is "one symbol per transition" for each flush.
    """

    def __init__(self, ngram_model: RandomNGramModel, k: int, perturbation_fnc):
        self.ngram_model = ngram_model
        self.k = k
        self.perturbation_fnc = perturbation_fnc
        self.R = Real

        self.fsa = FSA(R=self.R)
        self._build_kshuffle()

    def _build_kshuffle(self):
        base_fsa = self.ngram_model.fsa

        # new final
        q_final = State("<KSHUFFLE_FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        queue = deque()
        visited = set()

        if self.ngram_model.n>1:
            start_ctx = tuple([BOS]*(self.ngram_model.n-1))
        else:
            start_ctx = ()

        def ensure_state(ctx, buf):
            sname = f"{ctx}|{buf}"
            st = State(sname)
            self.fsa.add_state(st)
            return st

        s0 = ensure_state(start_ctx, ())
        self.fsa.set_I(s0, Real.one)
        visited.add((start_ctx, ()))
        queue.append((start_ctx, ()))

        # final states in base_fsa
        base_final = set()
        for (qq,ww) in base_fsa.F:
            if ww!=self.R.zero:
                base_final.add(qq)

        inv_map = {v:k for (k,v) in self.ngram_model.context2state.items()}

        # gather arcs from base FSA for a given context
        def arcs_for_ctx(ctx):
            st_base = self.ngram_model.context2state.get(ctx,None)
            if st_base is None:
                return []
            results=[]
            for (a,j,w) in base_fsa.arcs(st_base, no_eps=True, nozero=True, reverse=False):
                results.append((a,j,w))
            return results

        while queue:
            (ctx, buf) = queue.popleft()
            s_from = ensure_state(ctx, buf)

            st_base = self.ngram_model.context2state.get(ctx, None)
            base_is_final = (st_base in base_final) if st_base else False

            # If base is final => leftover flush => ephemeral chain => final
            if base_is_final:
                leftover_list = list(buf)
                pblock = self.perturbation_fnc(leftover_list)
                self._emit_symbol_chain(s_from, pblock, q_final, EOSweight=Real.one)
                continue

            # If buffer<k => read from base arcs
            if len(buf)<self.k:
                A = arcs_for_ctx(ctx)
                for (a,j,w) in A:
                    if a==EOS:
                        # partial leftover flush
                        leftover_list = list(buf)
                        pblock = self.perturbation_fnc(leftover_list)
                        # print("Partial block:", pblock)
                        # print(f"Context: {ctx}, Buffer: {buf}, Arc: {a}, Weight: {w}")                        # ephemeral chain => final
                        self._emit_symbol_chain(s_from, pblock, q_final, symbol=a, weight=w)
                    else:
                        symbol = str(a)
                        new_ctx = inv_map.get(j,None)
                        if new_ctx is None:
                            continue
                        new_buf = tuple(list(buf)+[symbol])
                        if (new_ctx,new_buf) not in visited:
                            visited.add((new_ctx,new_buf))
                            queue.append((new_ctx,new_buf))
                        s_to = ensure_state(new_ctx, new_buf)
                        self.fsa.set_arc(s_from, ε, s_to, w)
                continue

            # else => buffer==k => full block flush
            if len(buf)==self.k:
                pblock = self.perturbation_fnc(list(buf))
                # ephemeral chain => next BFS-labeled state = (ctx,())
                next_state = ensure_state(ctx, ())
                if (ctx,()) not in visited:
                    visited.add((ctx,()))
                    queue.append((ctx,()))
                self._emit_symbol_chain(s_from, pblock, next_state, EOSweight=Real.one, eplabel=ε)
                continue

    def _emit_symbol_chain(self, s_from:State, syms:List[str], final_st:State,
                           symbol: Sym = None, weight=None, EOSweight=None, eplabel=ε):
        """
        Creates ephemeral chain from s_from for each symbol in 'syms' (one symbol per transition).
        Then from the last ephemeral state => final_st with either the 'symbol' (like <EOS>) or eplabel.
        'weight' is used for that final transition if provided, else Real.one.
        'EOSweight' is used for ephemeral transitions if not provided, default Real.one
        'eplabel' for the final transition if no symbol is left.

        This ensures single symbol per transition PFSA for the flush.
        """
        if EOSweight is None:
            EOSweight = self.R.one
        curr_st = s_from
        # for each symbol => ephemeral state
        for i, sym in enumerate(syms):
            e_name = f"EMIT_{s_from}_{i}_{sym}"
            e_st = State(e_name)
            self.fsa.add_state(e_st)
            # connect curr_st -sym-> e_st
            self.fsa.set_arc(curr_st, Sym(sym), e_st, EOSweight)
            curr_st = e_st

        # now from curr_st => final_st
        # if we have an actual 'symbol' to emit (like <EOS>):
        if symbol is not None:
            self.fsa.set_arc(curr_st, symbol, final_st, weight if weight else self.R.one)
        else:
            # else we do eplabel => final
            self.fsa.set_arc(curr_st, eplabel, final_st, weight if weight else self.R.one)


###################################################
# 4) a naive epsremoval
###################################################
class Transformer:
    @staticmethod
    def partition(fsa, partition_symbol: Sym = ε) -> Tuple[FSA, FSA]:
        """Partition FSA into two
        (one with arcs of the partition symbol and one with all others)

        Args:
            fsa (FSA): The input FSA
            partition_symbol (Sym, optional): The symbol based on which to
            partition the input FSA

        Returns:
            Tuple[FSA, FSA]: The FSA with non-partition symbol arcs
                             and the FSA with only the partition symbol arcs
        """

        E = fsa.spawn()
        N = fsa.spawn(keep_init=True, keep_final=True)

        for q in fsa.Q:
            E.add_state(q)
            N.add_state(q)

        for i in fsa.Q:
            for a, j, w in fsa.arcs(i):
                if a == partition_symbol:
                    E.add_arc(i, a, j, w)
                else:
                    N.add_arc(i, a, j, w)

        return N, E

    @staticmethod
    def epsremoval_bebugged(fsa):
        """
        naive approach => triple loop => overcounts arcs from each state
        leading to bigger local arc sums => bigger local partial expansions => bigger entropy
        """
        from itertools import product

        N,E= Transformer.partition(fsa)
        W= Pathsum(E).lehmann(zero=False)
        print(W)

        # triple loop => for each i-a->j in original no_eps arcs, for each k => i-a->k
        for i in fsa.Q:
            for a,j,w in fsa.arcs(i, no_eps=True):
                for k in fsa.Q:
                    neww= w * W[(j,k)]
                    if neww.value>1e-15:
                        N.add_arc(i,a,k,neww)
        # fix initial => double loop
        for i in fsa.Q:
            for j in fsa.Q:
                N.add_I(j, fsa.λ[i]* W[(i,j)])
        return N


###################################################
# 5) Demo
###################################################

if __name__=="__main__":
    user_alphabet= ['a','b']
    rng_model= RandomNGramModel(alphabet=user_alphabet, n=3, alpha=0.4)
    print("Base n-gram FSA states:", rng_model.fsa.num_states)
    base_ent= rng_model.fsa.entropy()
    print("Base n-gram FSA entropy =>", base_ent)

    # define a simple block-perturbation function
    def left_rotate(block):
        if len(block)==0:
            return block
        return block[1:] + [block[0]]

    # K-shuffle
    kmodel= KShuffleNgram(rng_model, k=2, perturbation_fnc=left_rotate)
    shuffle_ent= kmodel.fsa.entropy()
    print("K-shuffle FSA states:", kmodel.fsa.num_states)
    print("K-shuffle FSA entropy =>", shuffle_ent)

    # Now the naive eps removal
    bebugged_noeps= Transformer.epsremoval_bebugged(kmodel.fsa)
    big_ent= bebugged_noeps.entropy()
    print("After naive eps removal => entropy =>", big_ent)


Base n-gram FSA states: 8
Base n-gram FSA entropy => 8.03324
K-shuffle FSA states: 30
K-shuffle FSA entropy => 8.03324
frozendict.frozendict({(EMIT_('b', 'a')|('b', 'a')_1_b, EMIT_('b', 'a')|('b', 'a')_1_b): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, ('b', 'a')|('a',)): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, ('b', 'b')|('b',)): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, ('a', 'b')|('a', 'b')): 0.01029, (EMIT_('b', 'a')|('b', 'a')_1_b, EMIT_('b', 'b')|('b', 'b')_0_b): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, EMIT_('b', 'a')|('a',)_0_a): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, EMIT_('BOS', 'b')|('b',)_0_b): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, EMIT_('a', 'a')|('a', 'a')_1_a): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, ('b', 'a')|()): 1.0, (EMIT_('b', 'a')|('b', 'a')_1_b, ('a', 'b')|('b',)): 0.16415, (EMIT_('b', 'a')|('b', 'a')_1_b, ('BOS', 'BOS')|()): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, <KSHUFFLE_FINAL>): 0.0, (EMIT_('b', 'a')|('b', 'a')_1_b, EMIT_('a', 'a')|('a',)_0_a): 0.0, (EMIT_('b', 'a')|('b', 'a')_1

### k-local entropy

In [7]:
#!/usr/bin/env python3

from collections import defaultdict as dd
import math
import numpy as np

####################################################
# A) The base Semiring class with the required methods
####################################################
class Semiring:
    zero: "Semiring"
    one: "Semiring"
    idempotent = False

    def __init__(self, value):
        # We'll store the "underlying numeric" or "structured" value in .value
        self.value = value

    @classmethod
    def zeros(cls, N, M):
        # returns an NxM matrix with "semiring.zero" in each cell
        mat = np.empty((N, M), dtype=object)
        for i in range(N):
            for j in range(M):
                mat[i, j] = cls.zero
        return mat

    @classmethod
    def chart(cls, default=None):
        if default is None:
            default = cls.zero
        return dd(lambda: default)

    @classmethod
    def diag(cls, N):
        # NxN with diagonal=cls.one, off-diag=cls.zero
        mat = cls.zeros(N, N)
        for i in range(N):
            mat[i, i] = cls.one
        return mat

    @classmethod
    @property
    def is_field(self):
        # e.g. Real semiring might set True
        return False

    def __add__(self, other):
        raise NotImplementedError

    def __mul__(self, other):
        raise NotImplementedError

    def __eq__(self, other):
        # default eq for .value
        return self.value == other.value

    def __hash__(self):
        # must be hashable
        return hash(self.value)

    def star(self):
        # if needed for star-closed
        raise NotImplementedError


####################################################
# B) Our LocalEntropySemiring that accumulates
#    ( sumProb, sumProbEnt ) i.e. (p, p * H)
####################################################
class LocalEntropySemiring(Semiring):
    """
    We store a pair: (sp, spe) = ( sumProb, sumProbTimesEntropy ).
      sp  = total probability mass
      spe = sum of (prob(path) * pathEntropy).

    The operations:

    1) add (+):
       (p1, e1) + (p2, e2) = (p1+p2, e1+e2)

    2) mul (×):
       (p1, e1) x (p2, e2) = (p1*p2, e1*p2 + e2*p1)
       => the usual "expectation semiring" formula.

    The final ratio: average entropy = ( sumProbEnt / sumProb ).

    We'll store them in .value as a tuple (sp, spe).
    """

    idempotent = False

    def __init__(self, pair):
        # pair = (sp, spe)
        super().__init__(pair)  # self.value = pair
        # we can define convenient attributes:
        self.sp = pair[0]
        self.spe = pair[1]

    def __add__(self, other):
        # (p1, e1) + (p2, e2) => (p1+p2, e1+ e2)
        return LocalEntropySemiring( (self.sp + other.sp, self.spe + other.spe) )

    def __mul__(self, other):
        # (p1, e1) x (p2, e2) => (p1*p2, e1*p2 + e2*p1)
        spNew = self.sp * other.sp
        speNew = (self.spe * other.sp) + (other.spe * self.sp)
        return LocalEntropySemiring( (spNew, speNew) )

    def star(self):
        raise NotImplementedError("No star in LocalEntropySemiring")

    def __repr__(self):
        return f"LE({self.sp:.4g}, {self.spe:.4g})"

    @staticmethod
    def make(prob: float):
        """
        Creates a local-entropy semiring element for a single arc of probability p:
         => ( p, p * (-log p) )
        """
        if prob < 1e-15:
            return LocalEntropySemiring((0.0, 0.0))
        ent = prob * (-math.log(prob))
        return LocalEntropySemiring((prob, ent))

    #####################################################
    # "static" semiring-wide methods: zero, one, etc.
    #####################################################
    @classmethod
    def zero(cls):
        # additive identity => (0.0, 0.0)
        return cls((0.0, 0.0))

    @classmethod
    def one(cls):
        # multiplicative identity => (1.0, 0.0)
        return cls((1.0, 0.0))

    @classmethod
    def zeros(cls, N, M):
        # NxM matrix of zero elements
        mat = np.empty((N, M), dtype=object)
        z = cls.zero()
        for i in range(N):
            for j in range(M):
                mat[i, j] = z
        return mat

    @classmethod
    def chart(cls, default=None):
        if default is None:
            default = cls.zero()
        return dd(lambda: default)

    @classmethod
    def diag(cls, N):
        mat = cls.zeros(N, N)
        one = cls.one()
        for i in range(N):
            mat[i, i] = one
        return mat

    def value(self):
        # returns the (sumProb, sumProbEnt) pair
        return (self.sp, self.spe)


####################################################
# C) Building a "k-context FSA" in localEnt semiring
####################################################
def build_k_context_fsa(noeps_fsa, k: int):
    """
    from a no-eps FSA (with arcs that presumably sum to 1 from each state),
    produce a new FSA in "LocalEntropySemiring". states = ( oldQ, (k-1)-context ).
    arcs => define prob = w.value / sumOut => semiring => (p, p*-log p).
    """
    from rayuela.fsa.fsa import FSA
    from rayuela.base.symbol import Sym
    from rayuela.base.state import State

    KFSA = FSA(R=LocalEntropySemiring)

    # gather adjacency
    from collections import defaultdict, deque
    adjacency = defaultdict(list)
    outSum = {}

    for q in noeps_fsa.Q:
        arcsq = list(noeps_fsa.arcs(q))
        s = 0.0
        for aSym, rSt, w in arcsq:
            s += w.value
        s += noeps_fsa.ρ[q].value
        outSum[q] = s
        adjacency[q] = arcsq

    newStateDict = {}
    def ensure_state(q, buf):
        key = (q, tuple(buf))
        if key not in newStateDict:
            st = State(str(key))
            KFSA.add_state(st)
            newStateDict[key] = st
        return newStateDict[key]

    queue = deque()
    visited = set()

    # define initial states
    for q, wInit in noeps_fsa.I:
        prob = wInit.value
        ctx0 = ["<BOS>"]*(k-1) if k>1 else []
        stNew = ensure_state(q, ctx0)
        if prob>1e-15:
            # localEnt = (prob, prob*-log(prob))
            wEl = LocalEntropySemiring.make(prob)
            KFSA.set_I(stNew, wEl)
        queue.append((q, tuple(ctx0)))
        visited.add((q, tuple(ctx0)))

    while queue:
        (oldQ, buf) = queue.popleft()
        stFrom = newStateDict[(oldQ, buf)]
        # final?
        fval = noeps_fsa.ρ[oldQ].value
        if fval>1e-15:
            fEl = LocalEntropySemiring.make(fval)
            KFSA.add_F(stFrom, fEl)

        # gather arcs
        s = outSum[oldQ]
        if s<1e-15:
            continue
        for (aSym, rSt, w) in adjacency[oldQ]:
            arcProb = w.value/s
            if arcProb<1e-15:
                continue
            # shift buffer
            newBuf = list(buf)[1:] + [aSym.value] if k>1 else []
            stTo = ensure_state(rSt, newBuf)

            arcWeight = LocalEntropySemiring.make(arcProb)
            KFSA.set_arc(stFrom, aSym, stTo, arcWeight)

            if (rSt, tuple(newBuf)) not in visited:
                visited.add((rSt, tuple(newBuf)))
                queue.append((rSt, tuple(newBuf)))

    return KFSA


########################################################
# D) The actual k_local_entropy function
########################################################

def k_local_entropy(original_fsa, k: int):
    """
    1) eps-removal => single symbol arcs
    2) build k-context fsa in localEntropySemiring
    3) run pathsum(Strategy.LEHMANN) => get allpairs => sum init->final => ratio => average local ent
    """
    from rayuela.fsa.transformer import Transformer
    from rayuela.fsa.pathsum import Pathsum, Strategy

    # remove eps arcs
    noeps = Transformer.epsremoval(original_fsa)
    # optionally push/normalize => omitted here

    # build k-fsa
    kfsa = build_k_context_fsa(noeps, k)

    # run pathsum => allpairs => localEntropySemiring
    ps = Pathsum(kfsa)
    W = ps.allpairs(strategy=Strategy.LEHMANN, zero=True)  # returns {(p,q): LocalEntropySemiring}

    total_sp = 0.0
    total_spe = 0.0

    for (p, wInit) in kfsa.I:     # wInit is LocalEntropySemiring
        for (q, wFinal) in kfsa.F:
            if (p,q) in W:
                outPQ = (wInit * W[(p,q)]) * wFinal
                total_sp += outPQ.sp
                total_spe += outPQ.spe

    if total_sp<1e-15:
        return 0.0
    return total_spe / total_sp


########################################################
# E) Minimal Demo
########################################################
if __name__=="__main__":
    from rayuela.fsa.fsa import FSA
    from rayuela.base.state import State
    from rayuela.base.symbol import Sym, ε
    from rayuela.base.semiring import Real

    # make small PFSA with an eps arc
    fsa = FSA(R=Real)
    s0, s1, s2 = State(0), State(1), State(2)
    fsa.set_I(s0, Real(1.0))
    fsa.set_F(s2, Real(1.0))

    # arcs:
    fsa.add_arc(s0, ε, s1, Real(0.3))      # eps
    fsa.add_arc(s0, Sym("a"), s0, Real(0.5))
    fsa.add_arc(s1, Sym("b"), s2, Real(1.0))

    # compute k=2 local entropy
    val = k_local_entropy(fsa, k=2)
    print(f"2-local average entropy => {val:.4f}")


AttributeError: 'function' object has no attribute 'value'