Skip to content

Commit

Permalink
feat: basic parser sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
HK-SHAO committed May 13, 2024
1 parent 6025702 commit 70c396f
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ ignore_missing_imports = true
no_site_packages = true
check_untyped_defs = true

exclude = ["^build/", "^dist/", "^tests/"]
exclude = ["^build/", "^dist/", "^tests/", "^tmp/"]

[[tool.mypy.overrides]]
module = "tests.*"
Expand Down
47 changes: 47 additions & 0 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from yieldlang.combinators import select
from yieldlang.generator import TextGenerator
from yieldlang.sampler import ParserSampler


def test_parser_sampler():
    """Smoke-test ``ParserSampler`` by streaming input chunks into a generator.

    Each chunk is passed to ``send`` once, followed by one ``send(None)`` per
    additional character in that chunk, and every returned value is printed.
    If the generator finishes, its return value is printed and feeding stops.
    """

    class G(TextGenerator):
        def top(self):
            yield select("A", "B", "C")
            yield select(1, 2, 3)
            yield select("XYZ", "456")

    generator = G(ParserSampler())
    chunks = ["A", "3", "X", "YZ"]
    try:
        for chunk in chunks:
            print(generator.send(chunk), end="")
            # One extra step for every character after the first.
            for _ in chunk[1:]:
                print(generator.send(None), end="")
    except StopIteration as stop:
        # The generator is exhausted; show what it returned.
        print(stop.value)


def test_next_pointer():
    """Check that ``_next_pointer`` steps through the characters of ``inputs``.

    The input list mixes ``None``, empty strings and non-empty strings; the
    pointer is expected to land on "1", "2", "3", "4", "5" in that order.
    """
    sampler = ParserSampler()
    sampler.inputs = [None, "", "123", "", "4", None, "5"]

    def walk():
        # ``_next_pointer`` is a generator, so it has to be driven with
        # ``yield from`` inside another generator.
        for expected in "12345":
            sampler.pointer = yield from sampler._next_pointer(*sampler.pointer)
            assert sampler._cur_char() == expected

    list(walk())


if __name__ == "__main__":
    # Allow running this module directly, without a test runner.
    for _test in (test_parser_sampler, test_next_pointer):
        _test()
6 changes: 3 additions & 3 deletions yieldlang/combinators.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ def select(*symbol: Symbol) -> ProxySymbol:
ProxySymbol: The proxy symbol that selects a symbol from the set.
"""

def select(g: TextGenerator, ctx: YContextTree):
yield g._sampler.select(g, ctx, *symbol)
def select(self: TextGenerator, ctx: YContextTree):
yield self._sampler.select(self, ctx, *symbol)

return ProxySymbol(select)

Expand All @@ -50,7 +50,7 @@ def join(sep: Symbol, to_seq: Symbol, depth: int = -1) -> ProxySymbol:
Args:
sep (Symbol): The separator symbol.
to_seq (Symbol): The symbol to join.
depth (int): The maximum depth to flatten. If negative, flatten all symbols.
depth (int): The maximum depth to flatten. If negative, flatten all symbols. Defaults to ``-1``.
Returns:
ProxySymbol: The joined symbol.
"""
Expand Down
6 changes: 6 additions & 0 deletions yieldlang/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ def top(self) -> Symbol:
"""
raise NotImplementedError

def send(self, value: str | None) -> str:
self._sampler.inputs.append(value)
if self._generator.gi_running:
return self._generator.send(value)
return next(self._generator)

def __init__(self, sampler: BaseSampler | None = None) -> None:
"""Initialize the generator with a sampler."""
self._sampler: BaseSampler = sampler or BaseSampler.default()
Expand Down
58 changes: 57 additions & 1 deletion yieldlang/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import TYPE_CHECKING

from yieldlang.tree import YContextTree
from yieldlang.types import Symbol
from yieldlang.types import EmptyString, Symbol

if TYPE_CHECKING:
from yieldlang.generator import TextGenerator
Expand All @@ -11,6 +11,9 @@
class BaseSampler:
"""Base class for samplers."""

def __init__(self) -> None:
    """Create a sampler with an empty input buffer."""
    # Input chunks (or ``None`` placeholders) in the order they were appended.
    self.inputs: list[str | None] = []

@staticmethod
def default() -> "RandomSampler":
"""Get the default sampler."""
Expand All @@ -35,3 +38,56 @@ def select(
) -> Symbol:
"""Randomly select a symbol from a set of symbols."""
return random.choice(symbol)


class ParserSampler(BaseSampler):
    """Sampler that chooses symbols by matching them against streamed input.

    Input chunks are collected in ``self.inputs``. ``self.pointer`` is an
    ``(input index, character index)`` pair identifying the last character
    consumed; ``(0, -1)`` means nothing has been consumed yet.
    """

    def __init__(self) -> None:
        """Start with an empty input buffer and the pointer before any input."""
        super().__init__()
        self.pointer = (0, -1)

    def select(
        self, g: "TextGenerator", ctx: YContextTree, *symbol: Symbol
    ) -> Symbol:
        """Pick the first candidate whose text matches the upcoming input.

        This is a generator. Candidates are tried in order, each starting
        from the current ``self.pointer``. Every matched character is yielded
        as it is consumed, along with anything ``_next_pointer`` yields while
        it waits for more input. Characters already yielded for a candidate
        that later mismatches are not retracted.

        Args:
            g: The text generator requesting the selection (unused here).
            ctx: The context tree of the request (unused here).
            *symbol: Candidate symbols; each is compared via ``str(symbol)``.

        Returns:
            The string form of the first fully matched candidate, delivered
            as the generator's return value. ``self.pointer`` is advanced
            past the matched characters.

        Raises:
            EOFError: If no candidate matches the input.
        """
        # TODO: Implement first set
        for candidate in map(str, symbol):
            p = self.pointer
            for c in candidate:
                p = yield from self._next_pointer(*p)
                char = self._char(*p)
                if char != c:
                    break
                yield char
            else:
                # Every character matched (or the candidate was empty).
                self.pointer = p
                return candidate
        raise EOFError

    def _char(self, i: int, j: int) -> str:
        """Return character ``j`` of input entry ``i``."""
        s = self.inputs[i]
        assert s  # narrows ``str | None``; the entry must be non-empty
        return s[j]

    def _cur_char(self) -> str:
        """Return the character that ``self.pointer`` currently refers to."""
        return self._char(*self.pointer)

    def _next_pointer(self, i: int, j: int):
        """Advance ``(i, j)`` to the next available input character.

        This is a generator meant to be driven with ``yield from``. Entries
        of ``self.inputs`` that are ``None`` or empty are skipped. Whenever
        the search runs past the end of ``self.inputs`` it yields
        ``EmptyString`` and retries from the same place once resumed.

        Args:
            i: Index of the current entry in ``self.inputs``.
            j: Index of the last consumed character in that entry, or ``-1``
                if none has been consumed.

        Returns:
            The ``(i, j)`` position of the next character, delivered as the
            generator's return value.
        """
        while True:
            try:
                while not self.inputs[i]:
                    i += 1
                    j = -1
                s = self.inputs[i]
            except IndexError:
                # Ran out of input: hand control back until more arrives.
                yield EmptyString
                continue
            assert s  # guaranteed non-empty by the loop above
            j += 1
            if j < len(s):
                return i, j
            # Entry exhausted: move on with the column reset to "before the
            # first character" (-1, not 0), so that a retry after running out
            # of input still lands on the next entry's first character
            # instead of skipping it.
            i += 1
            j = -1

0 comments on commit 70c396f

Please sign in to comment.