In [2]:
import re
import random
from collections import ChainMap

#### Create separate terminal and non-terminal grammar hashes

In [None]:
# Regular expressions used to classify grammar symbols / right-hand sides.
# One or more uppercase ASCII letters, e.g. "ROOT".
non_terminal_pattern = r"^[A-Z]+$"
# Two or more ASCII letters in any case (this also accepts all-uppercase text).
pre_terminal_pattern = r"^[A-Za-z]{2,}$"
# Only lowercase ASCII letters and whitespace, e.g. "chief of staff".
terminal_pattern = r"(^[a-z\s]+)$"

def _spit_probable_choice(items, weights):
    return random.choices(items, weights, k=1)[0]

# Read the grammar as bytes and decode it explicitly (line endings are kept
# exactly as stored); the context manager closes the file handle as soon as
# the text is in memory instead of leaving it to the garbage collector.
with open("grammar.gr", "rb") as _grammar_fh:
    grammar_text = _grammar_fh.read().decode("utf-8")


def _is_non_terminal(text):
    """Return True when ``text`` matches the module-level ``non_terminal_pattern``."""
    return re.search(non_terminal_pattern, text) is not None

def _is_pre_terminal(text):
    """Return True when ``text`` matches the module-level ``pre_terminal_pattern``."""
    return re.search(pre_terminal_pattern, text) is not None

def _is_terminal(text):
    """Return True when ``text`` matches the module-level ``terminal_pattern``."""
    return re.search(terminal_pattern, text) is not None

def _create_grammar_hash(grammar_text):
    """
    Parse raw grammar text into terminal, non-terminal and merged rule tables.

    Every rule line must hold at least three tab-separated fields: weight,
    left-hand symbol and right-hand side.  Anything after a ``#`` on a line
    is discarded.  Lines of at most one character, and lines starting with
    ``#``, whitespace, ``(`` or ``)``, are skipped.

    Args:
        grammar_text (str): Full contents of a grammar file.

    Returns:
        tuple: ``(terminal_hash, non_terminal_hash, merged)`` where

            * ``terminal_hash`` maps a symbol to ``{rhs_string: weight}`` for
              right-hand sides accepted by ``_is_terminal``;
            * ``non_terminal_hash`` maps a symbol to
              ``{tuple_of_rhs_symbols: weight}`` for all other rules;
            * ``merged`` maps every symbol to the union of its expansions
              from both tables.

    Raises:
        IndexError: If a rule line has fewer than three tab-separated fields.
        ValueError: If a weight cannot be converted to ``float``.
    """
    def _is_comment(line):
        return bool(re.search(r"^[#\s()]", line))

    rule_lines = [
        line
        for line in grammar_text.split("\n")
        if len(line) > 1 and not _is_comment(line)
    ]
    parsed_rules = [line.split("#")[0].strip().split("\t") for line in rule_lines]

    non_terminal_hash = {}
    terminal_hash = {}
    for fields in parsed_rules:
        weight, lhs, rhs = fields[0], fields[1], fields[2]
        if _is_terminal(rhs):
            terminal_hash.setdefault(lhs, {})[rhs] = float(weight)
        else:
            non_terminal_hash.setdefault(lhs, {})[tuple(rhs.split())] = float(weight)

    # Merge per symbol rather than with a plain dict.update(): a symbol such
    # as a noun category can own both tuple expansions and plain-string
    # expansions, and update() would replace the whole inner dict, silently
    # dropping the tuple expansions from the merged table.
    merged = {symbol: dict(expansions) for symbol, expansions in non_terminal_hash.items()}
    for symbol, expansions in terminal_hash.items():
        merged.setdefault(symbol, {}).update(expansions)
    return terminal_hash, non_terminal_hash, merged

#### Create a combined grammar hash

In [None]:
# Symbol-classification regexes for this cell.
# Text consisting solely of uppercase ASCII letters.
non_terminal_pattern = r"^[A-Z]+$"
# Text of at least two ASCII letters, upper or lower case.
pre_terminal_pattern = r"^[A-Za-z]{2,}$"
# Text consisting solely of lowercase ASCII letters and whitespace.
terminal_pattern = r"(^[a-z\s]+)$"

def _spit_probable_choice(items, weights):
    return random.choices(items, weights, k=1)[0]

# Load the raw grammar text.  The file is opened in binary mode and decoded
# by hand so its line endings are preserved; the ``with`` block makes sure
# the handle is closed right after reading.
with open("grammar.gr", "rb") as _grammar_fh:
    grammar_text = _grammar_fh.read().decode("utf-8")


def _is_non_terminal(text):
    """Tell whether ``text`` is matched by ``non_terminal_pattern``."""
    found = re.search(non_terminal_pattern, text)
    return found is not None

def _is_pre_terminal(text):
    """Tell whether ``text`` is matched by ``pre_terminal_pattern``."""
    found = re.search(pre_terminal_pattern, text)
    return found is not None

def _is_terminal(text):
    """Tell whether ``text`` is matched by ``terminal_pattern``."""
    found = re.search(terminal_pattern, text)
    return found is not None

def _create_grammar_hash(grammar_text):
    """
    Build one combined rule table from raw grammar text.

    A rule line carries tab-separated fields: weight, left-hand symbol and
    right-hand side; text after a ``#`` is dropped.  Lines of at most one
    character and lines beginning with ``#``, whitespace, ``(`` or ``)`` are
    ignored.

    Args:
        grammar_text (str): Full contents of a grammar file.

    Returns:
        dict: ``{symbol: {expansion: weight}}``.  The expansion is the
        right-hand side string itself when ``_is_terminal`` accepts it, and a
        tuple of its whitespace-separated symbols otherwise.
    """
    grammar_hash = {}
    for raw_line in grammar_text.split("\n"):
        # Skip blanks, one-character lines and comment-like lines.
        if len(raw_line) <= 1 or re.search(r"^[#\s()]", raw_line):
            continue

        fields = raw_line.split("#")[0].strip().split("\t")
        rhs = fields[2]
        lhs = fields[1]
        expansion = rhs if _is_terminal(rhs) else tuple(rhs.split())
        # The weight is converted before the entry is created, so a bad
        # weight never leaves an empty table behind for ``lhs``.
        grammar_hash.setdefault(lhs, {})[expansion] = float(fields[0])
    return grammar_hash

In [5]:
# Build the combined {symbol: {expansion: weight}} table from the loaded grammar text.
grammar_hash = _create_grammar_hash(grammar_text)

#### Grammar Code

In [16]:
class Grammar:
    """
    Context-Free Grammar (CFG) sentence generator.

    Weighted rules are read from a grammar file into ``self.rules``, a dict
    of the form ``{symbol: {expansion: weight}}``.  An expansion is stored as
    a plain string when the right-hand side matches ``TERMINAL_PATTERN`` and
    as a tuple of symbols otherwise.
    """

    # Right-hand sides made only of lowercase ASCII letters and whitespace are
    # kept as one terminal string.  The pattern lives on the class so the
    # class works on a fresh kernel without relying on a notebook global.
    TERMINAL_PATTERN = r"(^[a-z\s]+)$"

    def __init__(self, grammar_file):
        """
        Load the weighted rules found in ``grammar_file``.

        Args:
            grammar_file (str): Path to a .gr grammar file
        """
        # Parse the input grammar file
        self.rules = None
        self._load_rules_from_file(grammar_file)

    @staticmethod
    def _select_probable_choice(items, weights):
        """Return one element of ``items`` drawn at random according to ``weights``."""
        return random.choices(items, weights, k=1)[0]

    @staticmethod
    def _convert_tokens_to_sentence(tokens):
        """
        Join tokens into a sentence string.

        Tokens are joined with single spaces; whitespace before ``.``, ``!``
        or ``?`` is removed unless that whitespace directly follows a ``.``,
        and whitespace before an apostrophe is removed.

        Args:
            tokens (list): Token strings in sentence order.

        Returns:
            str: The assembled sentence.
        """
        sentence = " ".join(tokens)
        sentence = re.sub(r"(?<!\.)\s+([.!?])", r"\1", sentence)
        sentence = re.sub(r"\s+'", "'", sentence)
        return sentence

    @staticmethod
    def _is_terminal(text):
        """Return True when ``text`` matches ``Grammar.TERMINAL_PATTERN``."""
        return bool(re.search(Grammar.TERMINAL_PATTERN, text))

    def _load_rules_from_file(self, grammar_file):
        """
        Read grammar file and store its rules in self.rules

        Each rule line must hold at least three tab-separated fields: weight,
        left-hand symbol and right-hand side.  Text after a ``#`` is dropped.
        Lines of at most one character, and lines starting with ``#``,
        whitespace, ``(`` or ``)``, are skipped.

        Args:
            grammar_file (str): Path to the raw grammar file

        Raises:
            IndexError: If a rule line has fewer than three fields.
            ValueError: If a weight cannot be converted to ``float``.
        """

        def _is_comment(line):
            return bool(re.search(r"^[#\s()]", line))

        # Binary read + explicit decode keeps line endings untouched; the
        # context manager closes the handle deterministically.
        with open(grammar_file, "rb") as grammar_fh:
            grammar_text = grammar_fh.read().decode("utf-8")

        valid_lines = [
            line
            for line in grammar_text.split("\n")
            if len(line) > 1 and not _is_comment(line)
        ]
        parsed_rules = [
            line.split("#")[0].strip().split("\t") for line in valid_lines
        ]

        grammar_hash = {}
        for fields in parsed_rules:
            weight, lhs, rhs = fields[0], fields[1], fields[2]
            if self._is_terminal(rhs):
                expansion = rhs
            else:
                expansion = tuple(rhs.split())
            grammar_hash.setdefault(lhs, {})[expansion] = float(weight)
        self.rules = grammar_hash

    def sample(self, derivation_tree, max_expansions, start_symbol):
        """
        Sample a random sentence from this grammar

        Only symbols that are entirely uppercase count towards the expansion
        budget.  Once the budget is used up, any further uppercase symbol that
        has rules is rendered as ``...`` instead of being expanded.

        Args:
            derivation_tree (bool): if true, the returned string will represent
                the tree (using bracket notation) that records how the sentence
                was derived
            max_expansions (int): max number of nonterminal expansions we allow

            start_symbol (str): start symbol to generate from

        Returns:
            str: the random sentence or its derivation tree.  If
            ``start_symbol`` has no rules, the stripped symbol itself is
            returned.
        """
        if start_symbol not in self.rules:
            return start_symbol.strip()

        def _is_nonterminal(sym):
            return sym.isupper()

        n_expansions = 0

        def _tree_expand(symbol):
            nonlocal n_expansions

            # A symbol without rules is emitted literally.
            if symbol not in self.rules:
                return symbol.split(), symbol

            if _is_nonterminal(symbol):
                n_expansions += 1
                if n_expansions > max_expansions:
                    return ["..."], f"({symbol} ...)"

            expansions = self.rules[symbol]
            right_split = self._select_probable_choice(
                items=list(expansions.keys()), weights=list(expansions.values())
            )

            # Plain-string expansions are terminal phrases.
            if not isinstance(right_split, (tuple, list)):
                return right_split.split(), f"({symbol} {right_split})"

            tokens_list, subtrees = [], []
            for daughter in right_split:
                # Stop descending into uppercase symbols once the budget is spent.
                if (
                    daughter in self.rules
                    and _is_nonterminal(daughter)
                    and n_expansions >= max_expansions
                ):
                    tokens_list.append("...")
                    subtrees.append("...")
                    continue
                tokens, subtree = _tree_expand(daughter)
                tokens_list.extend(tokens)
                subtrees.append(subtree)
            return tokens_list, f"({symbol} {' '.join(subtrees)})"

        tokens, tree_str = _tree_expand(start_symbol)
        return tree_str if derivation_tree else self._convert_tokens_to_sentence(tokens)


In [17]:
# Load the grammar, then display one derivation tree with the budget for
# all-uppercase symbol expansions capped at 3.
grammar = Grammar(grammar_file="grammar.gr")
grammar.sample(derivation_tree=True, max_expansions=3, start_symbol="ROOT")

'(ROOT (S (NP ... ...) ...) .)'

In [24]:
# Print three numbered random sentences generated from ROOT with a budget of 450 expansions.
for i in range(3):
    print(f"{i+1}. {grammar.sample(derivation_tree=False, max_expansions=450, start_symbol='ROOT')}")

1. every pickle in every delicious pickle in the pickled chief of staff ate every sandwich on every sandwich under every floor on a pickle in a floor under the chief of staff on the floor on a president on a pickle!
2. is it true that a pickle on the pickle with every pickle kissed every floor?
3. every chief of staff with a pickle in a pickle in a floor in every president under the sandwich on a floor under every president in the pickle under the delicious sandwich in the sandwich under the president on the chief of staff in every president on every fine perplexed pickle under a pickle with the chief of staff understood a floor on the delicious sandwich with the pickle with the pickled floor!
