Skip to content

Commit

Permalink
refactor: depth and sampler get TextGenerator
Browse files Browse the repository at this point in the history
The join function in test_combinators.py has been optimized to support a depth parameter for flattening symbols. This allows for more control over the depth of flattening when joining sequences of symbols with a separator.

- The join function now accepts a depth parameter, which specifies the maximum depth to flatten the symbols.
- The default value of the depth parameter is -1, indicating that all symbols should be flattened.
- The join function has been updated to use the depth parameter when flattening symbols.

This change improves the flexibility and performance of the join function in test_combinators.py.
  • Loading branch information
HK-SHAO committed May 11, 2024
1 parent 6ac6fd4 commit 87c952b
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 15 deletions.
13 changes: 8 additions & 5 deletions tests/test_combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ def d(self):
def test_y_join():
class G(TextGenerator):
def top(self):
a = yield join(", ", self.seq, depth=None)
a = yield join(", ", self.seq)
assert a == "A, B, C, E"

a = yield join(", ", self.seq, depth=+2)
assert a == "A, B, CE"

b = yield join(", ", 1234)
assert b == "1234"

Expand All @@ -71,16 +74,16 @@ def top(self):
d = yield join(", ", (1, 2, 3))
assert d == "1, 2, 3"

e = yield join(0, join(1, range(3)), depth=None)
e = yield join(0, join(1, range(3)))
assert e == "001010102"

d = yield join(", ", list(repeat("6", 3)))
assert d == "6, 6, 6"

f = yield join("-", self.abc)
f = yield join("-", self.abc, depth=0)
assert f == "0ABCE"

g = yield join("-", self.abc())
g = yield join("-", self.abc(), depth=1)
g2 = yield join("-", self.abc, depth=2)
assert g == g2 == "0-ABCE"

Expand All @@ -104,7 +107,7 @@ def top(self):
assert k == "0-1-1-1-1-2-2-1-1-2-2-33"

x = yield join("-", array, depth=4)
y = yield join("-", array, depth=None)
y = yield join("-", array, depth=-1)
assert y == "0-1-1-1-1-2-2-1-1-2-2-3-3"
assert x == y

Expand Down
8 changes: 4 additions & 4 deletions yieldlang/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def select(*symbol: Symbol) -> ProxySymbol:
"""

def select(self: TextGenerator, ctx: YContextTree):
yield self._sampler.select(*symbol)
yield self._sampler.select(self, *symbol)

return ProxySymbol(select)

Expand All @@ -44,19 +44,19 @@ def optional(*symbol: Symbol) -> ProxySymbol:
return select(EmptyString, symbol)


def join(sep: Symbol, to_seq: Symbol, depth: int | None = 1) -> ProxySymbol:
def join(sep: Symbol, to_seq: Symbol, depth: int = -1) -> ProxySymbol:
"""Join a sequence of symbols with a separator.
Args:
sep (Symbol): The separator symbol.
to_seq (Symbol): The symbol to join.
depth (int | None): The maximum depth to flatten. If ``None``, depth is unlimited.
depth (int | None): The maximum depth to flatten. If negative, flatten all symbols.
Returns:
ProxySymbol: The joined symbol.
"""

def join(self: TextGenerator, ctx: YContextTree):
ctx.max_depth = None if depth is None else ctx.cur_depth + depth
ctx.max_depth = -1 if depth < 0 else ctx.cur_depth + depth
iterator = self._flatten(to_seq, ctx)
iterator = iter_not_empty(iterator)

Expand Down
5 changes: 3 additions & 2 deletions yieldlang/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, sampler: BaseSampler | None = None) -> None:
"""Initialize the generator with a sampler."""
self._sampler: BaseSampler = sampler or BaseSampler.default()
"""The sampler to use for sampling symbols."""
self._root_ctx = YContextTree(max_depth=None, cur_depth=0)
self._root_ctx = YContextTree(max_depth=-1, cur_depth=0)
"""The root context for flattening symbols."""
self._generator: YGenerator = self.__iter_symbol(self.top)
"""The iterator to generate text."""
Expand Down Expand Up @@ -81,6 +81,7 @@ def flatten_symbol(symbol: Symbol) -> IteratorSymbol:
else: # Must be an iterable
for symbol in iter(nt):
yield from flatten_symbol(symbol)
ctx.ret_value = ""
except (StopIteration, EOSError):
pass

Expand Down Expand Up @@ -114,7 +115,7 @@ def _flatten(self, symbol: Symbol, ctx: YContextTree) -> IteratorSymbol:
ctx.children.append(child)
ctx = child

if ctx.max_depth is not None and ctx.cur_depth > ctx.max_depth:
if ctx.max_depth > -1 and ctx.cur_depth > ctx.max_depth:
ctx.ret_value = symbol
yield symbol
return None
Expand Down
8 changes: 6 additions & 2 deletions yieldlang/sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import random
from typing import TYPE_CHECKING

from yieldlang.types import Symbol

if TYPE_CHECKING:
from yieldlang.generator import TextGenerator


class BaseSampler:
"""Base class for samplers."""
Expand All @@ -11,7 +15,7 @@ def default() -> "RandomSampler":
"""Get the default sampler."""
return RandomSampler()

def select(self, *symbol: Symbol) -> Symbol:
def select(self, g: "TextGenerator", *symbol: Symbol) -> Symbol:
"""Select a symbol from a set of symbols.
Warning:
Expand All @@ -23,6 +27,6 @@ def select(self, *symbol: Symbol) -> Symbol:
class RandomSampler(BaseSampler):
"""Random sampler."""

def select(self, *symbol: Symbol) -> Symbol:
def select(self, g: "TextGenerator", *symbol: Symbol) -> Symbol:
"""Randomly select a symbol from a set of symbols."""
return random.choice(symbol)
4 changes: 2 additions & 2 deletions yieldlang/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __init__(self, fn: ProxySymbolFn, *args: Symbol, **kwargs) -> None:
class YContextTree:
"""Context for flattening symbols."""

max_depth: int | None = None
"""The maximum depth to flatten. If None, flatten all symbols."""
max_depth: int = -1
"""The maximum depth to flatten. If ``-1``, flatten all symbols."""
cur_depth: int = 0
"""The current depth of flattening."""
ret_value: object = None
Expand Down

0 comments on commit 87c952b

Please sign in to comment.