In [2]:
import pyparsing

from typing import TypeVar, Optional
from collections.abc import Hashable, Callable

DataType = Hashable
# A parsed bracketed tree: either a bare leaf token, or a list whose first
# item is the node label and whose remaining items are subtrees in the same
# format.
TreeList = str | list['TreeList']
TreeTuple = tuple[DataType, Optional[tuple['TreeTuple', ...]]]

class Tree:
    """A rooted, ordered tree whose nodes carry hashable data.

    Parameters
    ----------
    data
        The label of the root node.
    children
        The root's subtrees, in order. ``None`` (the default) makes a leaf.

    Raises
    ------
    TypeError
        If any child is not a ``Tree``.
    """

    # Grammar for bracketed trees such as ``(S (NP the dog) (VP barks))``:
    # a subtree is either a bare token or ``( LABEL subtree* )``.
    LPAR = pyparsing.Suppress('(')
    RPAR = pyparsing.Suppress(')')
    DATA = pyparsing.Regex(r'[^\(\)\s]+')

    PARSER = pyparsing.Forward()
    SUBTREE = pyparsing.ZeroOrMore(PARSER)
    PARSERLIST = pyparsing.Group(LPAR + DATA + SUBTREE + RPAR)
    PARSER <<= DATA | PARSERLIST

    def __init__(self, data: DataType, children: Optional[list['Tree']] = None):
        self._data = data
        # A ``[]`` default would be one list object shared by every leaf, so
        # mutating one leaf's ``children`` would silently change them all.
        self._children = [] if children is None else children

        self._validate()

    def to_tuple(self) -> TreeTuple:
        """Return ``(data, (child_tuple, ...))``, a hashable nested snapshot."""
        return self._data, tuple(c.to_tuple() for c in self._children)

    def __hash__(self) -> int:
        return hash(self.to_tuple())

    def __eq__(self, other: object) -> bool:
        # Trees compare structurally; foreign operands get NotImplemented so
        # Python falls back to its default instead of raising AttributeError.
        if not isinstance(other, Tree):
            return NotImplemented
        return self.to_tuple() == other.to_tuple()

    def __str__(self) -> str:
        """Return the leaf labels joined by single spaces."""
        return ' '.join(self.terminals)

    def __repr__(self) -> str:
        return self.to_string()

    def to_string(self, depth: int = 0) -> str:
        """Render the tree one node per line, children indented under parents.

        Parameters
        ----------
        depth
            Depth of this node. Nodes below the root are prefixed with
            ``'  ' * (depth - 1) + '--'``.
        """
        prefix = (depth - 1) * '  ' + ('--' if depth > 0 else '')
        # str() so that non-string (but hashable) labels can be rendered.
        s = prefix + str(self._data) + '\n'
        s += ''.join(c.to_string(depth + 1) for c in self._children)

        return s

    def __contains__(self, data: DataType) -> bool:
        """Return whether any node has label ``data`` (pre-order, depth-first)."""
        if self._data == data:
            return True
        return any(data in child for child in self._children)

    def __getitem__(self, idx: int | tuple[int, ...]) -> 'Tree':
        """Return a descendant by child index or by a path of child indices.

        An ``int`` selects a direct child. A tuple walks down one child per
        entry; the empty tuple returns this tree itself.
        """
        if isinstance(idx, int):
            return self._children[idx]

        node = self
        for i in idx:
            node = node._children[i]
        return node

    @property
    def data(self) -> DataType:
        return self._data

    @property
    def children(self) -> list['Tree']:
        return self._children

    @property
    def terminals(self) -> list[str]:
        """The string forms of the leaf labels, left to right."""
        if self._children:
            return [w for c in self._children
                    for w in c.terminals]
        else:
            return [str(self._data)]

    def _validate(self) -> None:
        # An explicit check rather than ``assert``: asserts are stripped
        # when Python runs with -O.
        if not all(isinstance(c, Tree) for c in self._children):
            raise TypeError('all children must be trees')

    def index(self, data: DataType, index_path: tuple[int, ...] = tuple()) -> list[tuple[int, ...]]:
        """Return the index paths of all nodes labelled ``data``, in pre-order.

        Parameters
        ----------
        data
            The label to search for.
        index_path
            The path of this node from the overall root; callers normally
            leave the default (the root's own path, ``()``).
        """
        indices = [index_path] if self._data == data else []

        for i, child in enumerate(self._children):
            indices += child.index(data, index_path + (i,))

        return indices

    def relabel(self, label_map: Callable[[DataType], DataType],
                nonterminals_only: bool = False, terminals_only: bool = False) -> 'Tree':
        """Return a new tree with ``label_map`` applied to selected labels.

        With neither flag set, every label is mapped. ``nonterminals_only``
        restricts the mapping to nodes with children, and ``terminals_only``
        restricts it to leaves. Setting both maps every node.
        """
        if not nonterminals_only and not terminals_only:
            data = label_map(self._data)
        elif nonterminals_only and self._children:
            data = label_map(self._data)
        elif terminals_only and not self._children:
            data = label_map(self._data)
        else:
            data = self._data

        children = [c.relabel(label_map, nonterminals_only, terminals_only)
                    for c in self._children]

        return self.__class__(data, children)

    @classmethod
    def from_string(cls, treestr: str) -> 'Tree':
        """Parse a bracketed tree string.

        The first two and last two characters of ``treestr`` are dropped
        before parsing. Presumably this strips an outer ``( ... )`` wrapper
        as in Penn Treebank files -- TODO confirm against the data source.
        """
        treelist = cls.PARSER.parseString(treestr[2:-2])[0]

        return cls.from_list(treelist)

    @classmethod
    def from_list(cls, treelist: TreeList) -> 'Tree':
        """Build a tree from a parsed bracketed structure.

        A bare string becomes a leaf. A list ``[label, sub1, sub2, ...]``
        becomes a node labelled ``label`` whose children are built
        recursively from every remaining item.
        """
        if isinstance(treelist, str):
            # The whole token is the label (not just its first character).
            return cls(treelist)

        label, *subtrees = treelist
        return cls(label, [cls.from_list(sub) for sub in subtrees])

In [3]:
import re

StringVariables = tuple[int, ...]

class MCFGRuleElement:

    """A multiple context free grammar rule element

    Parameters
    ----------
    variable
        The element's symbol, e.g. ``'VP'``.
    string_variables
        One tuple per argument; each tuple lists the string variables that
        are concatenated, in order, to form that argument.

    Attributes
    ----------
    variable
    string_variables
    """

    def __init__(self, variable: str, *string_variables: StringVariables):
        self._variable = variable
        self._string_variables = string_variables

    def __str__(self) -> str:
        # Each argument renders as its variables run together, e.g. (2, 1) -> "21".
        strvars = ', '.join(
            ''.join(str(v) for v in vtup)
            for vtup in self._string_variables
        )

        return f"{self._variable}({strvars})"

    def __repr__(self) -> str:
        # Readable inside containers, e.g. in error messages that format
        # tuples of elements.
        return self.__str__()

    def __eq__(self, other: object) -> bool:
        # Foreign operands get NotImplemented rather than an AttributeError.
        if not isinstance(other, MCFGRuleElement):
            return NotImplemented

        vareq = self._variable == other._variable
        strvareq = self._string_variables == other._string_variables

        return vareq and strvareq

    def to_tuple(self) -> tuple[str, tuple[StringVariables, ...]]:
        return (self._variable, self._string_variables)

    def __hash__(self) -> int:
        return hash(self.to_tuple())

    @property
    def variable(self) -> str:
        return self._variable

    @property
    def string_variables(self) -> tuple[StringVariables, ...]:
        return self._string_variables

    @property
    def unique_string_variables(self) -> set[int]:
        """The set of all string variables used across every argument."""
        return {
            i
            for tup in self.string_variables
            for i in tup
        }
        

In [4]:
# Print a hand-built rule; each element is rendered through its __str__.
print(
    MCFGRuleElement('VPwhemb', (0,), (1,)),
    "->",
    *(MCFGRuleElement('NPwh', (0,)), MCFGRuleElement('Vpres', (1,))),
)

VPwhemb(0, 1) -> NPwh(0) Vpres(1)


In [5]:
# An argument may reorder string variables: (2, 1) places variable 2 before 1.
print(
    MCFGRuleElement('VPwhemb', (0,), (2, 1)),
    "->",
    *(MCFGRuleElement('NPwhdisloc', (0,), (1,)), MCFGRuleElement('Vpres', (2,))),
)


VPwhemb(0, 21) -> NPwhdisloc(0, 1) Vpres(2)


In [6]:
SpanIndices = tuple[int, ...]

class MCFGRuleElementInstance:
    """An instantiated multiple context free grammar rule element

    Parameters
    ----------
    variable
        The element's symbol, e.g. ``'NP'``.
    string_spans
        One span tuple per argument of the element.

    Attributes
    ----------
    variable
    string_spans
    """
    def __init__(self, variable: str, *string_spans: SpanIndices):
        self._variable = variable
        self._string_spans = string_spans

    def __eq__(self, other: object) -> bool:
        # Foreign operands get NotImplemented rather than an AttributeError.
        if not isinstance(other, MCFGRuleElementInstance):
            return NotImplemented

        vareq = self._variable == other._variable
        strspaneq = self._string_spans == other._string_spans

        return vareq and strspaneq

    def to_tuple(self) -> tuple[str, tuple[SpanIndices, ...]]:
        return (self._variable, self._string_spans)

    def __hash__(self) -> int:
        return hash(self.to_tuple())

    def __str__(self) -> str:
        # Spans render as lists, e.g. NP([0, 2], [1]).
        strspans = ', '.join(
            str(list(stup))
            for stup in self._string_spans
        )

        return f"{self._variable}({strspans})"

    def __repr__(self) -> str:
        return self.__str__()

    @property
    def variable(self) -> str:
        return self._variable

    @property
    def string_spans(self) -> tuple[SpanIndices, ...]:
        return self._string_spans

In [7]:
SpanMap = dict[int, SpanIndices]

class MCFGRule:
    """A linear multiple context free grammar rule

    Parameters
    ----------
    left_side
        The element on the left of the arrow.
    right_side
        The elements on the right of the arrow; none for an epsilon rule.

    Attributes
    ----------
    left_side
    right_side

    Raises
    ------
    ValueError
        If right-side elements share a string variable, or if (for a
        non-epsilon rule) the left side does not use exactly the right
        side's string variables.
    """

    def __init__(self, left_side: MCFGRuleElement, *right_side: MCFGRuleElement):
        self._left_side = left_side
        self._right_side = right_side

        self._validate()

    def to_tuple(self) -> tuple[MCFGRuleElement, tuple[MCFGRuleElement, ...]]:
        return (self._left_side, self._right_side)

    def __hash__(self) -> int:
        return hash(self.to_tuple())

    def __repr__(self) -> str:
        return '<Rule: '+str(self)+'>'

    def __str__(self) -> str:
        if self.is_epsilon:
            return str(self._left_side)

        else:
            return str(self._left_side) +\
                ' -> ' +\
                ' '.join(str(el) for el in self._right_side)

    def __eq__(self, other: object) -> bool:
        # Foreign operands get NotImplemented rather than an AttributeError.
        if not isinstance(other, MCFGRule):
            return NotImplemented

        left_side_equal = self._left_side == other._left_side
        right_side_equal = self._right_side == other._right_side

        return left_side_equal and right_side_equal

    def _validate(self):
        # Linearity: no string variable may appear in two right-side elements.
        vs = [
            el.unique_string_variables
            for el in self.right_side
        ]
        sharing = any(
            vs1.intersection(vs2)
            for i, vs1 in enumerate(vs)
            for j, vs2 in enumerate(vs)
            if i < j
        )

        if sharing:
            raise ValueError(
                'right side variables cannot share '
                'string variables'
            )

        if not self.is_epsilon:
            left_vars = self.left_side.unique_string_variables
            right_vars = {
                var for el in self.right_side
                for var in el.unique_string_variables
            }
            if left_vars != right_vars:
                raise ValueError(
                    'number of arguments to instantiate must '
                    'be equal to number of unique string_variables'
                )

    @property
    def left_side(self) -> MCFGRuleElement:
        return self._left_side

    @property
    def right_side(self) -> tuple[MCFGRuleElement, ...]:
        return self._right_side

    @property
    def is_epsilon(self) -> bool:
        return len(self._right_side) == 0

    @property
    def unique_variables(self) -> set[str]:
        return {
            el.variable
            for el in [self._left_side]+list(self._right_side)
        }

    def instantiate_left_side(self, *right_side: MCFGRuleElementInstance) -> MCFGRuleElementInstance:
        """Instantiate the left side of the rule given an instantiated right side

        Parameters
        ----------
        right_side
            The instantiated right side of the rule. For an epsilon rule,
            the instances' variables must equal the first entry of each of
            the left side's string variables.

        Raises
        ------
        ValueError
            If the instances do not match the rule, or if spans that the
            rule concatenates are not adjacent.
        """

        if self.is_epsilon:
            strvars = tuple(v[0] for v in self._left_side.string_variables)
            strconst = tuple(el.variable for el in right_side)

            if strconst == strvars:
                return MCFGRuleElementInstance(
                    self._left_side.variable,
                    *[s for el in right_side for s in el.string_spans]
                )

            # An epsilon rule has no string variables to map, so falling
            # through to the span logic below could only fail (possibly
            # with a KeyError); report the mismatch directly instead.
            raise ValueError(
                f"Instantiated right side {right_side} does not match "
                f"the terminals of {self}"
            )

        new_spans = []
        span_map = self._build_span_map(right_side)

        for vs in self._left_side.string_variables:
            # Concatenated variables must cover contiguous spans.
            for i in range(1, len(vs)):
                end_prev = span_map[vs[i-1]][1]
                begin_curr = span_map[vs[i]][0]

                if end_prev != begin_curr:
                    raise ValueError(
                        f"Spans {span_map[vs[i-1]]} and {span_map[vs[i]]} "
                        f"must be adjacent according to {self} but they "
                        "are not."
                    )

            begin_span = span_map[vs[0]][0]
            end_span = span_map[vs[-1]][1]

            new_spans.append((begin_span, end_span))

        return MCFGRuleElementInstance(
            self._left_side.variable, *new_spans
        )

    def _build_span_map(self, right_side: tuple[MCFGRuleElementInstance, ...]) -> SpanMap:
        """Construct a mapping from string variables to string spans"""

        if self._right_side_aligns(right_side):
            # Each right-side argument is assumed to hold a single variable
            # (as built by from_string), hence strvar[0].
            return {
                strvar[0]: strspan
                for elem, eleminst in zip(
                    self._right_side,
                    right_side
                )
                for strvar, strspan in zip(
                    elem.string_variables,
                    eleminst.string_spans
                )
            }
        else:
            raise ValueError(
                f"Instantiated right side {right_side} do not "
                f"align with rule's right side {self._right_side}"
            )

    def _right_side_aligns(self, right_side: tuple[MCFGRuleElementInstance, ...]) -> bool:
        """Check whether the right side aligns"""

        if len(right_side) == len(self._right_side):
            vars_match = all(
                elem.variable == eleminst.variable
                for elem, eleminst in zip(self._right_side, right_side)
            )
            strvars_match = all(
                len(elem.string_variables) == len(eleminst.string_spans)
                for elem, eleminst in zip(self._right_side, right_side)
            )

            return vars_match and strvars_match
        else:
            return False

    @classmethod
    def from_string(cls, rule_string: str) -> 'MCFGRule':
        """Parse a rule such as ``'A(w1u, x1v) -> B(w1, x1) C(u, v)'``.

        The first element is the left side and the rest form the right side.
        Right-side variable names are numbered in order of appearance, and
        each left-side argument must be a concatenation of those names. A
        string with a single element yields an epsilon rule whose arguments
        are kept as literal strings.

        Raises
        ------
        ValueError
            If no element can be parsed, if a variable is repeated on the
            right side, or if a left-side argument is not made up entirely
            of right-side variables.
        """
        # Raw string: '\w' and '\(' are invalid escapes in a plain literal.
        elem_strs = re.findall(r'(\w+)\(((?:\w+,? ?)+?)\)', rule_string)

        if not elem_strs:
            raise ValueError(
                'could not parse any rule element from ' + repr(rule_string)
            )

        elem_tuples = [(var, [v.strip()
                              for v in svs.split(',')])
                       for var, svs in elem_strs]

        if len(elem_tuples) == 1:
            return cls(MCFGRuleElement(elem_tuples[0][0],
                                       tuple(elem_tuples[0][1])))

        strvars = [v for _, sv in elem_tuples[1:] for v in sv]

        # no duplicate string variables
        if len(strvars) != len(set(strvars)):
            msg = 'variables duplicated on right side of '+rule_string
            raise ValueError(msg)

        var_index = {v: i for i, v in enumerate(strvars)}
        # Longest names first so that e.g. 'w1' is not split into 'w' + '1'
        # when both 'w' and 'w1' are variables.
        var_pattern = re.compile('|'.join(
            re.escape(v) for v in sorted(strvars, key=len, reverse=True)
        ))

        left_args = []
        for vs in elem_tuples[0][1]:
            found = var_pattern.findall(vs)
            # findall skips unmatched characters; reject them instead of
            # silently dropping part of the argument.
            if not found or ''.join(found) != vs:
                raise ValueError(
                    f'left side argument {vs!r} of {rule_string!r} is not '
                    'a concatenation of right side variables'
                )
            left_args.append(tuple(var_index[v] for v in found))

        elem_left = MCFGRuleElement(elem_tuples[0][0], *left_args)

        elems_right = [MCFGRuleElement(var, *[(var_index[sv],) for sv in svs])
                       for var, svs in elem_tuples[1:]]

        return cls(elem_left, *elems_right)

    def string_yield(self):
        """Return the left side's variable; defined only for epsilon rules."""
        if self.is_epsilon:
            return self._left_side.variable
        else:
            raise ValueError(
                'string_yield is only implemented for epsilon rules'
            )
            

In [8]:
# Parse a rule string, bind it to `rule`, and display it in one expression.
# Variable names become integer indices in order of appearance on the right side.
(rule := MCFGRule.from_string('A(w1u, x1v) -> B(w1, x1) C(u, v)'))

<Rule: A(02, 13) -> B(0, 1) C(2, 3)>

In [10]:
# tests/test_grammar.py

import pytest
# Uncomment when this cell is saved as tests/test_grammar.py; inside the
# notebook the classes defined in the cells above are used instead.
# from src.mcfg_parser.grammar import MCFGRule, MCFGRuleElement, MCFGRuleElementInstance

@pytest.fixture
def test_grammar():
    """
    Provides a sample set of grammar rules to test the parser against.
    """
    return [
        'S(uv) -> NP(u) VP(v)',
        'VP(uv) -> Vpres(u) NP(v)',
        'NP(uv) -> D(u) N(v)',
        'D(the)',
        'N(greyhound)'
    ]

def test_rule_element():
    """
    Test the initialization and functionality of MCFGRuleElement.
    """
    elem = MCFGRuleElement('VP', (0,), (1,))
    assert str(elem) == 'VP(0, 1)'

def test_rule_element_instance():
    """
    Test the initialization and functionality of MCFGRuleElementInstance.
    """
    instance = MCFGRuleElementInstance('NP', (0, 2), (1,))
    assert str(instance) == 'NP([0, 2], [1])'

def test_mcfg_rule(test_grammar):
    """
    Test creating and manipulating MCFG rules.

    ``from_string`` replaces variable names with integer indices (numbered
    in order of appearance on the right side), so the rendered rule shows
    indices rather than the original names.
    """
    rule = MCFGRule.from_string(test_grammar[0])
    assert str(rule) == 'S(01) -> NP(0) VP(1)'
    assert rule.left_side.variable == 'S'
    assert rule.left_side.string_variables == ((0, 1),)
    assert rule.right_side[0].variable == 'NP'
    assert rule.right_side[1].variable == 'VP'


In [None]:
# Additional tests in tests/test_grammar.py

import pytest
# Uncomment when this cell is saved as tests/test_grammar.py; inside the
# notebook the classes defined in the cells above are used instead.
# from src.mcfg_parser.grammar import MCFGRule, MCFGRuleElement, MCFGRuleElementInstance

def test_empty_rule():
    """
    Test that a rule string with no parsable element is rejected.

    The element pattern in ``MCFGRule.from_string`` requires at least one
    word character inside the parentheses, so ``S()`` yields no element.
    The exact exception type is an implementation detail.
    """
    with pytest.raises((ValueError, IndexError)):
        MCFGRule.from_string('S()')

def test_invalid_format():
    """
    Test that reusing a string variable on the right side is rejected.
    """
    invalid_rule = 'S(uv) -> NP(u) VP(u)'
    with pytest.raises(ValueError, match='variables duplicated'):
        MCFGRule.from_string(invalid_rule)

def test_multiple_non_terminals():
    """
    Test a rule with multiple non-terminals in the right side.

    Variable names are rendered as their integer indices.
    """
    rule_str = 'S(uv, w) -> NP(u) VP(v) Aux(w)'
    rule = MCFGRule.from_string(rule_str)
    assert str(rule) == 'S(01, 2) -> NP(0) VP(1) Aux(2)'
    assert rule.left_side.variable == 'S'
    assert len(rule.right_side) == 3

def test_epsilon_rule():
    """
    Test an epsilon (lexical) rule, which has an empty right-hand side.
    """
    epsilon_rule = 'D(the)'
    rule = MCFGRule.from_string(epsilon_rule)
    assert rule.is_epsilon
    assert str(rule) == 'D(the)'

def test_rule_element_equality():
    """
    Test equality checking for rule elements.
    """
    elem1 = MCFGRuleElement('VP', (0,), (1,))
    elem2 = MCFGRuleElement('VP', (0,), (1,))
    assert elem1 == elem2

    elem3 = MCFGRuleElement('NP', (0,), (1,))
    assert elem1 != elem3

def test_rule_element_instance_span_checking():
    """
    Test span checks for rule element instances.
    """
    instance1 = MCFGRuleElementInstance('NP', (0, 2), (1,))
    instance2 = MCFGRuleElementInstance('NP', (0, 2), (1,))
    instance3 = MCFGRuleElementInstance('VP', (0, 1), (2,))

    assert instance1 == instance2
    assert instance1 != instance3
