### working implementation (old)


In [468]:
from __future__ import annotations

import copy
from collections import Counter
from collections import defaultdict as dd
from collections import deque
from itertools import product
from typing import List, Tuple


from rayuela.base.semiring import Boolean, ProductSemiring, Real, Semiring
from rayuela.base.state import PairState, State
from rayuela.base.symbol import Sym, ε
from rayuela.fsa.pathsum import Pathsum
from rayuela.fsa.fsa import FSA

from rayuela.base.symbol import EOS, BOS, ε
from rayuela.fsa.sampler import Sampler


class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA (with Real semiring),
    with internal BOS, EOS tokens not in the user's alphabet.

    One state is created per valid (n-1)-symbol context, plus one final
    state named "<<FINAL>>".  Each context draws a single Dirichlet sample
    over (alphabet + EOS): the EOS arc leads to the final state and every
    other symbol leads to the context obtained by shifting that symbol in.
    """
    def __init__(self,
                 alphabet,  # user-supplied normal symbols
                 n=2,
                 alpha=1.0):
        """
        alphabet: iterable of normal symbols (no boundary tokens)
        n: n-gram order (n <= 1 yields a single empty context)
        alpha: Dirichlet concentration
        """
        # Snapshot into a list so tuples, strings or one-shot iterables work;
        # the builder concatenates this with the boundary tokens.
        self.alphabet = list(alphabet)  # user symbols only
        self.n = n
        self.alpha = alpha

        self.fsa = FSA(R=Real)
        self.context2state = {}  # to store (n-1)-tuple -> State
        self._build_ngram_fsa()

    def _is_valid_context(self, ctx):
        """
        Return True unless ctx contains EOS or has a BOS that follows a
        normal (non-BOS) token.
        """
        if EOS in ctx:
            return False
        saw_normal = False
        for token in ctx:
            if token == BOS:
                if saw_normal:
                    return False
            else:
                saw_normal = True
        return True

    def _get_or_add_state(self, ctx):
        """Return the State registered for ctx, creating and adding it if needed."""
        if ctx not in self.context2state:
            st = State(str(ctx) if ctx else "()")
            self.fsa.add_state(st)
            self.context2state[ctx] = st
        return self.context2state[ctx]

    def _build_ngram_fsa(self):
        """
        1) define final absorbing state
        2) define contexts: (n-1)-tuples from {BOS}+alphabet, skip EOS in context
        3) sample dist over (alphabet + EOS), add arcs

        Raises ValueError("Zero probability") if a sampled probability is
        below 1e-15.
        """
        q_final = State("<<FINAL>>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        if self.n <= 1:
            all_ctxs = [()]
        else:
            # possible context symbols
            csyms = [BOS, *self.alphabet]  # not EOS
            all_ctxs = [
                ctx
                for ctx in product(csyms, repeat=self.n - 1)
                if self._is_valid_context(ctx)
            ]

        # build states
        for ctx in all_ctxs:
            st = self._get_or_add_state(ctx)
            print(f"Adding state {st} for context {ctx}")

        # start ctx: all-BOS history (empty for the unigram case)
        start_ctx = (BOS,) * (self.n - 1) if self.n > 1 else ()
        s0 = self._get_or_add_state(start_ctx)
        self.fsa.set_I(s0, Real.one)

        # sample distributions
        out_syms = [*self.alphabet, EOS]
        from numpy.random import dirichlet

        for ctx in all_ctxs:
            dist = dirichlet([self.alpha] * len(out_syms))
            s_from = self.context2state[ctx]
            for sym, pval in zip(out_syms, dist):
                if pval < 1e-15:
                    raise ValueError("Zero probability")
                w = Real(pval)
                if sym == EOS:
                    # arc to final
                    self.fsa.set_arc(s_from, sym, q_final, w)
                    continue

                # Every ctx here has length n-1, so shifting keeps the length.
                new_ctx = ctx[1:] + (sym,) if self.n > 1 else ()
                if self._is_valid_context(new_ctx):
                    s_to = self._get_or_add_state(new_ctx)
                    self.fsa.set_arc(s_from, sym, s_to, w)


class KShuffleNgram:
    """
    Naive approach: states=(ctx,buf).
    => We store partially read block in 'buf'.
    => If we must emit that block (full or leftover), we create a small chain
       of ephemeral states, each responsible for a single symbol output.

    That is "one symbol per transition" for each flush.

    Reading a symbol from the base n-gram FSA is an ε-arc that appends the
    symbol to 'buf'; once 'buf' holds k symbols, or the base model emits
    EOS, the block is passed through perturbation_fnc and emitted.
    """

    def __init__(self, ngram_model: RandomNGramModel, k: int, perturbation_fnc):
        """
        ngram_model: base model whose fsa / context2state / n are read
        k: block size (must be >= 1)
        perturbation_fnc: callable taking a list of symbols and returning
            the list to emit

        Raises ValueError when k < 1: an empty block would be "full"
        immediately and produce an ε self-loop on the start state.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k!r}")
        self.ngram_model = ngram_model
        self.k = k
        self.perturbation_fnc = perturbation_fnc
        self.R = Real

        self.fsa = FSA(R=self.R)
        self._build_kshuffle()

    def _build_kshuffle(self):
        """Breadth-first construction of the (ctx, buf) state space."""
        base_fsa = self.ngram_model.fsa
        context2state = self.ngram_model.context2state

        # new final
        q_final = State("<KSHUFFLE_FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        n = self.ngram_model.n
        start_ctx = (BOS,) * (n - 1) if n > 1 else ()

        def ensure_state(ctx, buf):
            st = State(f"{ctx}|{buf}")
            self.fsa.add_state(st)
            return st

        s0 = ensure_state(start_ctx, ())
        self.fsa.set_I(s0, Real.one)
        visited = {(start_ctx, ())}
        queue = deque([(start_ctx, ())])

        # final states in base_fsa
        base_final = {qq for (qq, ww) in base_fsa.F if ww != self.R.zero}

        # State -> context, to recover the context an arc leads to
        inv_map = {st: ctx for (ctx, st) in context2state.items()}

        # gather arcs from base FSA for a given context
        def arcs_for_ctx(ctx):
            st_base = context2state.get(ctx, None)
            if st_base is None:
                return []
            return [
                (a, j, w)
                for (a, j, w) in base_fsa.arcs(
                    st_base, no_eps=True, nozero=True, reverse=False
                )
            ]

        while queue:
            (ctx, buf) = queue.popleft()
            s_from = ensure_state(ctx, buf)

            st_base = context2state.get(ctx, None)
            base_is_final = (st_base in base_final) if st_base else False

            # If base is final => leftover flush => ephemeral chain => final
            if base_is_final:
                pblock = self.perturbation_fnc(list(buf))
                self._emit_symbol_chain(s_from, pblock, q_final, EOSweight=Real.one)
                continue

            # If buffer<k => read from base arcs
            if len(buf) < self.k:
                for (a, j, w) in arcs_for_ctx(ctx):
                    if a == EOS:
                        # partial leftover flush, taken with probability w
                        pblock = self.perturbation_fnc(list(buf))
                        self._emit_symbol_chain(s_from, pblock, q_final, symbol=a, weight=w)
                        continue

                    new_ctx = inv_map.get(j, None)
                    if new_ctx is None:
                        continue
                    new_buf = buf + (str(a),)
                    if (new_ctx, new_buf) not in visited:
                        visited.add((new_ctx, new_buf))
                        queue.append((new_ctx, new_buf))
                    s_to = ensure_state(new_ctx, new_buf)
                    self.fsa.set_arc(s_from, ε, s_to, w)
                continue

            # else => buffer==k => full block flush
            if len(buf) == self.k:
                pblock = self.perturbation_fnc(list(buf))
                # ephemeral chain => next BFS-labeled state = (ctx,())
                next_state = ensure_state(ctx, ())
                if (ctx, ()) not in visited:
                    visited.add((ctx, ()))
                    queue.append((ctx, ()))
                self._emit_symbol_chain(s_from, pblock, next_state, EOSweight=Real.one, eplabel=ε)
                continue

    def _emit_symbol_chain(self, s_from:State, syms:List[str], final_st:State,
                           symbol: Sym = None, weight=None, EOSweight=None, eplabel=ε):
        """
        Creates ephemeral chain from s_from for each symbol in 'syms' (one symbol per transition).
        Then from the last ephemeral state => final_st with either the 'symbol' (like <EOS>) or eplabel.

        'weight' (default Real.one) is placed on the FIRST arc leaving s_from,
        so the branching probability is paid at s_from itself and s_from's
        outgoing weights are not inflated by a weight-one chain arc.
        'EOSweight' (default Real.one) is used for every later arc of the chain.
        'eplabel' labels the last transition when no 'symbol' is given.

        The product of the weights along the chain is the same as when
        'weight' is carried by the last arc.
        """
        step_weight = self.R.one if EOSweight is None else EOSweight
        # `is None` rather than truthiness: a weight object must never be
        # replaced by one just because it happens to evaluate as false.
        pending = self.R.one if weight is None else weight
        last_label = eplabel if symbol is None else symbol

        curr_st = s_from
        # for each symbol => ephemeral state
        for i, sym in enumerate(syms):
            e_st = State(f"EMIT_{sym}_@{s_from}_count={i}")
            self.fsa.add_state(e_st)
            # connect curr_st -sym-> e_st
            self.fsa.set_arc(curr_st, Sym(sym), e_st, pending)
            pending = step_weight
            curr_st = e_st

        # now from curr_st => final_st (carries 'weight' when syms is empty)
        self.fsa.set_arc(curr_st, last_label, final_st, pending)


###############################################################################
# Usage
###############################################################################

# Demo: build a random trigram model over {a, b}, wrap it in a k-shuffle
# model with block size 2, and print state counts, initial/final weights
# and the entropy reported by FSA.entropy() for both automata.
if __name__=="__main__":
    user_alphabet = ['a', 'b']
    rng_model = RandomNGramModel(alphabet=user_alphabet, n=3, alpha=1.0)
    print("Base n-gram FSA states:", rng_model.fsa.num_states, "initial:", list(rng_model.fsa.I), "final:", list(rng_model.fsa.F))

    # Perturbation applied to every emitted block: rotate left by one.
    def left_rotate(lst):
        """Return lst with its first element moved to the end (empty list returned as is)."""
        if len(lst) == 0:
            return lst
        return lst[1:] + [lst[0]]

    kmodel = KShuffleNgram(rng_model, k=2, perturbation_fnc=left_rotate)
    print("K-shuffle FSA states:", kmodel.fsa.num_states, "initial:", list(kmodel.fsa.I), "final:", list(kmodel.fsa.F))

    print("Base n-gram FSA entropy:", rng_model.fsa.entropy())
    print("K-shuffle FSA entropy:", kmodel.fsa.entropy())




Adding state ('BOS', 'BOS') for context ('BOS', 'BOS')
Adding state ('BOS', 'a') for context ('BOS', 'a')
Adding state ('BOS', 'b') for context ('BOS', 'b')
Adding state ('a', 'a') for context ('a', 'a')
Adding state ('a', 'b') for context ('a', 'b')
Adding state ('b', 'a') for context ('b', 'a')
Adding state ('b', 'b') for context ('b', 'b')
Base n-gram FSA states: 8 initial: [(('BOS', 'BOS'), 1.0)] final: [(<<FINAL>>, 1.0)]
K-shuffle FSA states: 30 initial: [(('BOS', 'BOS')|(), 1.0)] final: [(<KSHUFFLE_FINAL>, 1.0)]
Base n-gram FSA entropy: 13.29995
K-shuffle FSA entropy: 13.29995


In [469]:
# Transformer is imported here but not referenced anywhere in this cell.
from rayuela.fsa.transformer import Transformer

# Build ε-free versions of both automata and compare size and entropy.
# NOTE(review): the base model goes through normalize() while the k-shuffle
# model goes through push() before epsremove(); the later revision of this
# cell uses normalize() for both — confirm which preprocessing is intended.
rng_model_non_eps = rng_model.fsa.normalize().epsremove().normalize()
kshuffle_model_non_eps = kmodel.fsa.push().epsremove().normalize()
print("Non-EPS states (base):", rng_model_non_eps.num_states)
print("Non-EPS states (kshuffle):", kshuffle_model_non_eps.num_states)
print("Base n-gram FSA entropy:", rng_model_non_eps.entropy())
print("K-shuffle FSA entropy:", kshuffle_model_non_eps.entropy())


Non-EPS states (base): 8
Non-EPS states (kshuffle): 26
Base n-gram FSA entropy: 13.29995
K-shuffle FSA entropy: 33.5551


  print("inverse", self.value)


### Experiment


In [733]:
from __future__ import annotations

import copy
from collections import Counter
from collections import defaultdict as dd
from collections import deque
from itertools import product
from typing import List, Tuple


from rayuela.base.semiring import Boolean, ProductSemiring, Real, Semiring
from rayuela.base.state import PairState, State
from rayuela.base.symbol import Sym, ε
from rayuela.fsa.pathsum import Pathsum
# from rayuela.fsa.fsa import FSA

from rayuela.base.symbol import EOS, BOS, ε
from rayuela.fsa.sampler import Sampler


class RandomNGramModel:
    """
    Builds a random n-gram model as a Weighted FSA (with Real semiring),
    with internal BOS, EOS tokens not in the user's alphabet.

    States are the valid (n-1)-symbol contexts plus one final state named
    "<<FINAL>>".  For every context a Dirichlet sample over
    (alphabet + EOS) supplies the outgoing arc weights.
    """
    def __init__(self,
                 alphabet,  # user-supplied normal symbols
                 n=2,
                 alpha=1.0):
        """
        alphabet: iterable of normal symbols (no boundary tokens)
        n: n-gram order (n <= 1 yields a single empty context)
        alpha: Dirichlet concentration
        """
        # Copy into a list: the builder concatenates the alphabet with the
        # boundary tokens, which would fail for tuples or other iterables.
        self.alphabet = list(alphabet)  # user symbols only
        self.n = n
        self.alpha = alpha

        self.fsa = FSA(R=Real)
        self.context2state = {}  # to store (n-1)-tuple -> State
        self._build_ngram_fsa()

    def _is_valid_context(self, ctx):
        """
        Reject contexts that contain EOS or in which BOS reappears after a
        normal token; accept everything else.
        """
        if EOS in ctx:
            return False
        saw_normal = False
        for token in ctx:
            if token == BOS:
                if saw_normal:
                    return False
            else:
                saw_normal = True
        return True

    def _state_for(self, ctx):
        """Look up the State of ctx, registering a new one on first use."""
        st = self.context2state.get(ctx)
        if st is None:
            st = State(str(ctx) if ctx else "()")
            self.fsa.add_state(st)
            self.context2state[ctx] = st
        return st

    def _build_ngram_fsa(self):
        """
        1) define final absorbing state
        2) define contexts: (n-1)-tuples from {BOS}+alphabet, skip EOS in context
        3) sample dist over (alphabet + EOS), add arcs

        Raises ValueError("Zero probability") if a sampled probability is
        below 1e-15.
        """
        q_final = State("<<FINAL>>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        if self.n <= 1:
            all_ctxs = [()]
        else:
            # possible context symbols
            csyms = [BOS, *self.alphabet]  # not EOS
            raw = product(csyms, repeat=self.n - 1)
            all_ctxs = [r for r in raw if self._is_valid_context(r)]

        # build states
        for ctx in all_ctxs:
            self._state_for(ctx)

        # start ctx: all-BOS history (empty for the unigram case)
        start_ctx = (BOS,) * (self.n - 1) if self.n > 1 else ()
        self.fsa.set_I(self._state_for(start_ctx), Real.one)

        # sample distributions
        out_syms = [*self.alphabet, EOS]
        from numpy.random import dirichlet

        for ctx in all_ctxs:
            dist = dirichlet([self.alpha] * len(out_syms))
            st_from = self.context2state[ctx]
            for sym, pval in zip(out_syms, dist):
                if pval < 1e-15:
                    raise ValueError("Zero probability")
                w = Real(pval)
                if sym == EOS:
                    # arc to final
                    self.fsa.set_arc(st_from, sym, q_final, w)
                    continue

                # Contexts always have length n-1, so drop the oldest symbol
                # and append the new one.
                new_ctx = ctx[1:] + (sym,) if self.n > 1 else ()
                if self._is_valid_context(new_ctx):
                    self.fsa.set_arc(st_from, sym, self._state_for(new_ctx), w)


class KShuffleNgram:
    """
    Naive approach: states=(ctx,buf).
    => We store partially read block in 'buf'.
    => If we must emit that block (full or leftover), we create a small chain
       of ephemeral states, each responsible for a single symbol output.

    That is "one symbol per transition" for each flush.

    Reading a symbol from the base n-gram FSA is an ε-arc carrying the base
    arc weight; a flush applies perturbation_fnc to the buffered block and
    emits the result symbol by symbol.
    """

    def __init__(self, ngram_model: RandomNGramModel, k: int, perturbation_fnc):
        """
        ngram_model: base model whose fsa / context2state / n are read
        k: block size (must be >= 1)
        perturbation_fnc: callable taking a list of symbols and returning
            the list to emit

        Raises ValueError when k < 1: with an empty block every state is
        immediately "full" and the start state only gets an ε self-loop.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k!r}")
        self.ngram_model = ngram_model
        self.k = k
        self.perturbation_fnc = perturbation_fnc
        self.R = Real

        self.fsa = FSA(R=self.R)
        self._build_kshuffle()

    def _build_kshuffle(self):
        """Breadth-first construction of the (ctx, buf) state space."""
        base_fsa = self.ngram_model.fsa
        context2state = self.ngram_model.context2state

        # new final
        q_final = State("<KSHUFFLE_FINAL>")
        self.fsa.add_state(q_final)
        self.fsa.set_F(q_final, Real.one)

        n = self.ngram_model.n
        start_ctx = (BOS,) * (n - 1) if n > 1 else ()

        def ensure_state(ctx, buf):
            st = State(f"{ctx}|{buf}")
            self.fsa.add_state(st)
            return st

        s0 = ensure_state(start_ctx, ())
        self.fsa.set_I(s0, Real.one)
        visited = {(start_ctx, ())}
        queue = deque([(start_ctx, ())])

        # State -> context, to recover the context an arc leads to
        inv_map = {st: ctx for (ctx, st) in context2state.items()}

        # gather arcs from base FSA for a given context
        def arcs_for_ctx(ctx):
            st_base = context2state.get(ctx, None)
            if st_base is None:
                return []
            return [
                (a, j, w)
                for (a, j, w) in base_fsa.arcs(
                    st_base, no_eps=True, nozero=True, reverse=False
                )
            ]

        while queue:
            (ctx, buf) = queue.popleft()
            st_from = ensure_state(ctx, buf)

            if len(buf) < self.k:
                # buffer not full => read from base arcs
                for (a, j, w) in arcs_for_ctx(ctx):
                    if a == EOS:
                        # flush partial leftover: leave st_from with prob w
                        # and emit the perturbed block one symbol at a time
                        pblock = self.perturbation_fnc(list(buf))
                        self._emit_partial_symbol_chain(st_from, pblock, q_final, w)
                        continue

                    new_ctx = inv_map.get(j, None)
                    if new_ctx is None:
                        continue
                    new_buf = buf + (str(a),)
                    if (new_ctx, new_buf) not in visited:
                        visited.add((new_ctx, new_buf))
                        queue.append((new_ctx, new_buf))
                    s_to = ensure_state(new_ctx, new_buf)
                    self.fsa.set_arc(st_from, ε, s_to, w)

            elif len(buf) == self.k:
                # full block flush; ephemeral chain ends in (ctx, ())
                pblock = self.perturbation_fnc(list(buf))
                next_state = ensure_state(ctx, ())
                if (ctx, ()) not in visited:
                    visited.add((ctx, ()))
                    queue.append((ctx, ()))
                self._emit_full_symbol_chain(st_from, pblock, next_state)

            else:
                raise ValueError("Invalid buffer state")

    def _emit_chain(self, st_from: State, syms: List[str], final_st: State,
                    last_label, first_weight):
        """
        Shared chain builder: one ephemeral state per symbol in 'syms', then
        a closing arc labelled 'last_label' into final_st.  'first_weight'
        goes on the first arc leaving st_from; all later arcs get R.one.
        """
        curr_st = st_from
        pending = first_weight
        for i, sym in enumerate(syms):
            e_st = State(f"EMIT({sym})@{st_from}(count={i})")
            self.fsa.add_state(e_st)
            self.fsa.set_arc(curr_st, Sym(sym), e_st, pending)
            pending = self.R.one
            curr_st = e_st
        self.fsa.set_arc(curr_st, last_label, final_st, pending)

    def _emit_full_symbol_chain(self, st_from: State, syms: List[str], final_st: State):
        """
        Creates ephemeral chain from st_from, emitting all symbols in 'syms' and then to final_st with ε.
        Every arc of the chain has weight R.one.
        """
        self._emit_chain(st_from, syms, final_st, ε, self.R.one)

    def _emit_partial_symbol_chain(self, st_from: State, syms: List[str], final_st: State, eos_weight: Real):
        """
        Creates ephemeral chain from st_from, emitting partial leftover symbols in 'syms' and then to final_st with EOS.
        The first arc leaving st_from has weight eos_weight (the EOS arc itself
        when 'syms' is empty); every other arc has weight R.one.
        """
        self._emit_chain(st_from, syms, final_st, EOS, eos_weight)


# def high_precision_entropy(fsa, precision=50):
#     """Computes the entropy of the FSA with higher numerical precision.

#     Returns:
#         Real: The entropy of the FSA.
#     """
#     from math import log
#     from decimal import Decimal, getcontext
#     from rayuela.base.semiring import Entropy

#     # Set higher precision
#     getcontext().prec = precision

#     assert fsa.R == Real

#     def _high_precision_entropy(w):
#         w_dec = Decimal(str(float(w)))
#         log_w = w_dec.ln()
#         return Entropy(float(w), -float(w_dec * log_w))
#         # return Entropy(float(w), -float(log_w))


#     return Real(
#         fsa.lift(Entropy, _high_precision_entropy)
#         .pathsum()
#         .value[1]
#     )

def fixed_entropy(self) -> Real:
    """Computes the entropy of the FSA.

    Every weight p is lifted to the Entropy semiring as (p, -p * log(p))
    and the second component of the pathsum is returned.  A weight of zero
    is lifted to (0, 0), following the convention 0 * log(0) = 0, instead
    of failing with a math domain error.

    Args:
        self: the FSA to measure; its semiring must be Real.

    Returns:
        Real: The entropy of the FSA.
    """
    from math import log

    from rayuela.base.semiring import Entropy

    assert self.R == Real

    def lift_weight(w):
        p = float(w)
        if p == 0.0:
            # log(0) is undefined; a zero-weight arc contributes nothing.
            return Entropy(0.0, 0.0)
        # the second term must be p * -log(p), not -log(p)
        return Entropy(p, p * -log(p))

    return Real(
        self.lift(Entropy, lift_weight)
        .pathsum()
        .value[1]
    )

###############################################################################
# Usage
###############################################################################

# Experiment: random trigram model over {a, b, c} and its k-shuffle
# counterpart (block size 2, left rotation), compared by state count and by
# the entropy computed with fixed_entropy.
if __name__=="__main__":
    user_alphabet = ['a', 'b', 'c']
    rng_model = RandomNGramModel(alphabet=user_alphabet, n=3, alpha=1.0)
    print("Base n-gram FSA states:", rng_model.fsa.num_states)

    # Perturbation applied to every emitted block: rotate left by one.
    def left_rotate(lst):
        """Return lst with its first element moved to the end (empty list returned as is)."""
        if len(lst) == 0:
            return lst
        return lst[1:] + [lst[0]]

    kmodel = KShuffleNgram(rng_model, k=2, perturbation_fnc=left_rotate)
    print("K-shuffle FSA states:", kmodel.fsa.num_states)

    print("Base n-gram FSA entropy:", fixed_entropy(rng_model.fsa))
    print("K-shuffle FSA entropy:", fixed_entropy(kmodel.fsa))


# Transformer is imported here but not referenced by the statements below.
from rayuela.fsa.transformer import Transformer

# Same comparison after normalising and removing ε-arcs from both automata.
# These lines run at top level and rely on rng_model / kmodel defined in the
# guarded block above.
rng_model_non_eps = rng_model.fsa.normalize().epsremove().normalize()
kshuffle_model_non_eps = kmodel.fsa.normalize().epsremove().normalize()
print("Base n-gram FSA entropy (non-eps):", fixed_entropy(rng_model_non_eps))
print("K-shuffle FSA entropy (non-eps):", fixed_entropy(kshuffle_model_non_eps))

Base n-gram FSA states: 14
K-shuffle FSA states: 62
Base n-gram FSA entropy: 4.52596
K-shuffle FSA entropy: 4.52596
Base n-gram FSA entropy (non-eps): 4.52596
K-shuffle FSA entropy (non-eps): 4.52596


  print("inverse", self.value)


In [742]:
# sample
def sample_strings(fsa, n=10, remove_eps=False):
    """Draw n samples from fsa with a Sampler, joined with an empty separator.

    When remove_eps is true, a list is returned in which every 'ε' character
    has been stripped from each sample; otherwise the sampler's result is
    returned unchanged.
    """
    drawn = Sampler(fsa).sample(n, sep='')
    if not remove_eps:
        return drawn
    return [text.replace('ε', '') for text in drawn]

# NOTE(review): these two samplers are not used by the calls below;
# sample_strings constructs its own Sampler for each call.
original_sampler = Sampler(rng_model.fsa)
kshuffle_sampler = Sampler(kmodel.fsa)

# Draw 10 strings from each model; ε characters are stripped from the
# k-shuffle samples only.
print("Original sample:", sample_strings(rng_model.fsa, 10))
print("K-shuffle sample:", sample_strings(kmodel.fsa, 10, remove_eps=True))


100%|██████████| 10/10 [00:00<00:00, 2567.36it/s]


Original sample: ['b', '', '', '', 'a', '', 'abbba', 'babbccbb', '', 'a']


100%|██████████| 10/10 [00:00<00:00, 682.51it/s]

K-shuffle sample: ['bacccbcbbccbbb', 'b', 'ccb', '', 'acaaaacc', 'acaacacbbb', '', 'acbaa', 'a', 'cc']





### Test whether the entropy implementation is correct


In [726]:
import math

def fixed_entropy(self) -> Real:
    """Computes the entropy of the FSA.

    Each weight p becomes the Entropy-semiring value (p, -p * log(p)); the
    result is the second component of the pathsum over the lifted FSA.
    Zero weights are mapped to (0, 0) (convention 0 * log(0) = 0) so that
    they no longer trigger a math domain error in log.

    Args:
        self: the FSA to measure; its semiring must be Real.

    Returns:
        Real: The entropy of the FSA.
    """
    from math import log

    from rayuela.base.semiring import Entropy

    assert self.R == Real

    def to_entropy(w):
        p = float(w)
        if p == 0.0:
            # log(0) is undefined; a zero-weight arc contributes nothing.
            return Entropy(0.0, 0.0)
        # the second term must be p * -log(p), not -log(p)
        return Entropy(p, p * -log(p))

    return Real(
        self.lift(Entropy, to_entropy)
        .pathsum()
        .value[1]
    )

# Sanity check on a two-state automaton: s0 reaches s1 by 'a' with weight
# 0.25 or by 'b' with weight 0.75, so there are exactly two paths.
simple_fsa = FSA(R=Real)
s0 = State("s0")
s1 = State("s1")
simple_fsa.add_state(s0)
simple_fsa.add_state(s1)
# No explicit weights are passed for the initial and final state
# (presumably the library default is the semiring one — TODO confirm).
simple_fsa.set_I(s0)
simple_fsa.set_F(s1)
simple_fsa.set_arc(s0, Sym('a'), s1, Real(0.25))
simple_fsa.set_arc(s0, Sym('b'), s1, Real(0.75))

# Closed-form value -sum p*log(p) next to the library's entropy() and the
# corrected implementation.
print("True Entropy:", -0.25 * math.log(0.25) - 0.75 * math.log(0.75))
print("Rayuela Entropy:", simple_fsa.entropy())
print("Corrected Entropy:", fixed_entropy(simple_fsa))


True Entropy: 0.5623351446188083
Rayuela Entropy: 1.67398
Corrected Entropy: 0.56234


### k-local entropy


In [788]:
#!/usr/bin/env python3

import math
import numpy as np
from collections import defaultdict, deque

from rayuela.fsa.fsa import FSA
from rayuela.base.symbol import Sym, BOS
from rayuela.base.state import State
from rayuela.fsa.pathsum import Pathsum, Strategy
from rayuela.base.semiring import Real, Semiring

##################################################
# 1) LocalEntropy
##################################################

class LocalEntropy(Semiring):
    """
    Semiring over pairs (sumProb, sumProbEnt).
      sumProb = total probability mass
      sumProbEnt = entropy-like term accumulated with the product rule

    add  => (p1 + p2, e1 + e2)
    mul  => (p1 * p2, p1 * e2 + e1 * p2)
    star => (1 / (1 - p), e / (1 - p)^2)

    The class attributes `zero` = (0, 0) and `one` = (1, 0) are assigned
    right after the class body; __add__ and __mul__ short-cut on those exact
    objects by identity.
    """
    def __init__(self, sp, spe):
        super().__init__((sp, spe))

    def star(self):
        """Closure of this element; fails with a division by zero when sumProb is 1."""
        tmp = 1.0 / (1.0 - self.value[0])
        return LocalEntropy(tmp, tmp * tmp * self.value[1])

    def __add__(self, other):
        if other is self.zero:
            return self
        if self is self.zero:
            return other
        return LocalEntropy(self.value[0] + other.value[0], self.value[1] + other.value[1])

    def __mul__(self, other):
        if other is self.one:
            return self
        if self is self.one:
            return other
        if other is self.zero:
            return self.zero
        if self is self.zero:
            return self.zero
        return LocalEntropy(
            self.value[0] * other.value[0],
            self.value[0] * other.value[1] + self.value[1] * other.value[0],
        )

    def __eq__(self, other):
        # Let Python fall back to the other operand instead of raising
        # AttributeError when compared with something that has no `.value`.
        if not isinstance(other, LocalEntropy):
            return NotImplemented
        return self.value == other.value

    def __hash__(self):
        return hash(self.value)

    # The class used to define __repr__ twice; only this later definition
    # was ever in effect, so it is the one kept.
    def __repr__(self):
        return f"LE({self.value[0]:.4g}, {self.value[1]:.4g})"


LocalEntropy.zero = LocalEntropy(0.0, 0.0)
LocalEntropy.one = LocalEntropy(1.0, 0.0)


##################################################
# 2) Build k-context FSA in LocalEntropy
##################################################

def build_k_context_fsa(noeps_fsa: FSA, k: int) -> FSA:
    """Unfold noeps_fsa into an FSA over LocalEntropy whose states also
    remember a symbol buffer.

    Each new state is a pair (original state, buffer).  The buffer starts as
    k-1 copies of BOS (empty when k <= 1) and, for k > 1, is shifted by one
    symbol on every arc.  A probability p is lifted to
    LocalEntropy(p, -p * log(p)); values not above 1e-15 are skipped.

    NOTE(review): arc weights are divided by the state's total outgoing
    mass (arcs plus final weight), whereas the final weight itself is lifted
    without that division — confirm this is intended for inputs that are
    not locally normalised.
    NOTE(review): the buffer only determines state identity; no weight
    depends on it.

    Args:
        noeps_fsa: source automaton; its weights are read through `.value`.
        k: context length.

    Returns:
        The unfolded FSA over the LocalEntropy semiring.
    """
    from collections import deque
    from rayuela.base.symbol import BOS

    def lifted(p):
        return LocalEntropy(p, p * -math.log(p))

    k_fsa = FSA(R=LocalEntropy)

    # Per original state: outgoing arcs and total mass (final weight + arcs).
    out_arcs = defaultdict(list)
    out_mass = {}
    for q in noeps_fsa.Q:
        q_arcs = list(noeps_fsa.arcs(q))
        mass = noeps_fsa.ρ[q].value
        for _, _, w in q_arcs:
            mass += w.value
        out_arcs[q] = q_arcs
        out_mass[q] = mass

    state_of = {}

    def ensure_state(oldQ, buf):
        key = (oldQ, tuple(buf))
        if key in state_of:
            return state_of[key]
        st = State(f"{oldQ}||{buf}")
        k_fsa.add_state(st)
        state_of[key] = st
        return st

    seen = set()
    frontier = deque()

    # Initial states: every initial state of the source, paired with the
    # all-BOS buffer.
    start_buf = [BOS] * (k - 1) if k > 1 else []
    start_key_buf = tuple(start_buf)
    for (q, w_init) in noeps_fsa.I:
        p_init = w_init.value
        st_init = ensure_state(q, start_buf)
        if p_init > 1e-15:
            k_fsa.set_I(st_init, lifted(p_init))
        seen.add((q, start_key_buf))
        frontier.append((q, start_key_buf))

    # Breadth-first unfolding.
    while frontier:
        (old_q, buf) = frontier.popleft()
        st_from = state_of[(old_q, buf)]

        # Final weight: accumulated directly in the final-weight map.
        p_final = noeps_fsa.ρ[old_q].value
        if p_final > 1e-15:
            previous = k_fsa.ρ.get(st_from, k_fsa.R.zero)
            k_fsa.ρ[st_from] = previous + lifted(p_final)

        mass = out_mass[old_q]
        if mass < 1e-15:
            continue

        for (a, r, w) in out_arcs[old_q]:
            p_arc = w.value / mass
            if not p_arc > 1e-15:
                continue
            if k > 1:
                next_buf = list(buf[1:]) + [a.value]
            else:
                next_buf = []
            st_to = ensure_state(r, next_buf)
            k_fsa.set_arc(st_from, a, st_to, lifted(p_arc))
            next_key = (r, tuple(next_buf))
            if next_key not in seen:
                seen.add(next_key)
                frontier.append(next_key)

    return k_fsa




##################################################
# 3) The k_local_entropy function
##################################################

def k_local_entropy(noeps_fsa: FSA, k: int) -> float:
    """Return sumProbEnt / sumProb over all initial-to-final paths of the
    k-context unfolding of noeps_fsa.

    Steps:
      1) unfold noeps_fsa with build_k_context_fsa (assumed to have
         single-symbol arcs and no ε);
      2) compute all-pairs pathsums in the LocalEntropy semiring;
      3) for every (initial, final) pair present in the result, multiply
         initial weight * pathsum * final weight and accumulate both
         components.

    Returns 0.0 when the accumulated probability mass is below 1e-15.
    """
    unfolded = build_k_context_fsa(noeps_fsa, k)
    pair_sums = Pathsum(unfolded).allpairs(zero=True)

    mass = 0.0
    mass_times_entropy = 0.0
    for (start, w_init) in unfolded.I:
        for (end, w_final) in unfolded.F:
            if (start, end) not in pair_sums:
                continue
            through = (w_init * pair_sums[(start, end)]) * w_final
            mass += through.value[0]                # sumProb
            mass_times_entropy += through.value[1]  # sumProbEnt

    if mass < 1e-15:
        return 0.0
    return mass_times_entropy / mass


##################################################
# 4) Demo usage
##################################################

# Demo: run k_local_entropy with k=2 on a hand-built two-state automaton.
if __name__=="__main__":
    from rayuela.fsa.fsa import FSA
    from rayuela.base.semiring import Real
    from rayuela.base.state import State
    from rayuela.base.symbol import Sym

    # 1) an example no-ε PFSA
    fsa = FSA(R=Real)
    s0, s1 = State(0), State(1)
    fsa.set_I(s0, Real(1.0))
    fsa.set_F(s1, Real(0.4))

    # arcs
    # s0 --a/0.6--> s0
    # s0 --b/0.4--> s1
    # out of s1 => final=0.4 => leftover prob=0.6 => can loop or do nothing
    # (no arc leaving s1 is added in this script, so s1 is not locally
    # normalised)
    fsa.add_arc(s0, Sym("a"), s0, Real(0.6))
    fsa.add_arc(s0, Sym("b"), s1, Real(0.4))

    # compute k=2 local average entropy
    val = k_local_entropy(fsa, k=2)
    print(f"k=2 local average entropy => {val:.4f}")


k=2 local average entropy => 2.5988


In [796]:
# Rebuild the random trigram model over {a, b, c}, this time with a
# k-shuffle block size of 3, and compare state counts and entropies.
user_alphabet = ['a', 'b', 'c']
rng_model = RandomNGramModel(alphabet=user_alphabet, n=3, alpha=1.0)
print("Base n-gram FSA states:", rng_model.fsa.num_states)

# Perturbation applied to every emitted block: rotate left by one.
def left_rotate(lst):
    """Return lst with its first element moved to the end (empty list returned as is)."""
    if len(lst) == 0:
        return lst
    return lst[1:] + [lst[0]]

kmodel = KShuffleNgram(rng_model, k=3, perturbation_fnc=left_rotate)
print("K-shuffle FSA states:", kmodel.fsa.num_states)

print("Base n-gram FSA entropy:", fixed_entropy(rng_model.fsa))
print("K-shuffle FSA entropy:", fixed_entropy(kmodel.fsa))


# Transformer is imported here but not referenced by the statements below.
from rayuela.fsa.transformer import Transformer

# ε-free, normalised versions of both automata; these are the inputs later
# passed to k_local_entropy.
rng_model_non_eps = rng_model.fsa.normalize().epsremove().normalize()
kshuffle_model_non_eps = kmodel.fsa.normalize().epsremove().normalize()
print("Base n-gram FSA entropy (non-eps):", fixed_entropy(rng_model_non_eps))
print("K-shuffle FSA entropy (non-eps):", fixed_entropy(kshuffle_model_non_eps))

Base n-gram FSA states: 14
K-shuffle FSA states: 170
Base n-gram FSA entropy: 5.6021
K-shuffle FSA entropy: 5.6021


  print("inverse", self.value)


Base n-gram FSA entropy (non-eps): 5.6021
K-shuffle FSA entropy (non-eps): 5.6021


In [797]:
# Evaluate k_local_entropy for several context lengths on the ε-free base
# and k-shuffle automata.
# NOTE(review): in the recorded output every value equals the plain entropy
# of the model regardless of k — confirm that the measure is meant to be
# independent of k.
ks = [2, 3]

for k in ks:
    print(f"K={k} Local Entropy (base_model): {k_local_entropy(rng_model_non_eps, k):.5f}")
    print(f"K={k} Local Entropy (kshuffle_model): {k_local_entropy(kshuffle_model_non_eps, k):.5f}")

K=2 Local Entropy (base_model): 5.60210
K=2 Local Entropy (kshuffle_model): 5.60210
K=3 Local Entropy (base_model): 5.60210
K=3 Local Entropy (kshuffle_model): 5.60210
