In [241]:
import re
import random
from collections import ChainMap

In [242]:
## Regexes that classify grammar symbols (used by the _is_* predicates of this cell)
non_terminal_pattern, pre_terminal_pattern, terminal_pattern = (
    r"^[A-Z]+$",          # uppercase letters only, e.g. "NP"
    r"^[A-Za-z]{2,}$",    # two or more ASCII letters of either case
    r"(^[a-z\s]+)$",      # lowercase letters and whitespace only, e.g. "ice cream"
)

def _spit_probable_choice(items, weights):
    return random.choices(items, weights, k=1)[0]

# Read raw bytes and decode explicitly (same text as a binary read + decode);
# the `with` block closes the handle deterministically instead of relying on GC.
with open("grammar.gr", "rb") as grammar_fh:
    grammar_text = grammar_fh.read().decode("utf-8")


def _is_non_terminal(text):
    """Return True when ``text`` matches ``non_terminal_pattern`` (uppercase letters only)."""
    match = re.search(non_terminal_pattern, text)
    return match is not None

def _is_pre_terminal(text):
    """Return True when ``text`` matches ``pre_terminal_pattern`` (two or more letters, any case)."""
    match = re.search(pre_terminal_pattern, text)
    return match is not None

def _is_terminal(text):
    """Return True when ``text`` matches ``terminal_pattern`` (lowercase letters and whitespace only)."""
    match = re.search(terminal_pattern, text)
    return match is not None

def _create_grammar_hash(grammar_text):
    """
    Parse .gr grammar text into terminal and non-terminal rule tables.

    A rule line holds three whitespace-separated fields, ``weight LHS RHS``
    (the RHS may itself contain spaces). Text after ``#`` is dropped. Lines of
    length 0 or 1, and lines starting with ``#``, whitespace, ``(`` or ``)``,
    are skipped.

    Args:
        grammar_text (str): full contents of a grammar file

    Returns:
        tuple: ``(terminal_hash, non_terminal_hash, merged)``, each a dict of
        ``LHS -> {RHS: weight}``. An RHS accepted by ``_is_terminal`` is kept as
        its string; any other RHS becomes a tuple of symbols. ``merged`` holds,
        for every LHS, the union of its terminal and non-terminal rules.

    Raises:
        ValueError: if a rule line has fewer than three fields or its weight
            is not a number.

    NOTE(review): the next cell redefines ``_create_grammar_hash`` with a
    single-dict return value; once that cell runs, it shadows this version.
    """
    def _is_comment(line):
        return re.search(r"^[#\s()]", line) is not None

    terminal_hash = {}
    non_terminal_hash = {}
    for line in grammar_text.split("\n"):
        if len(line) <= 1 or _is_comment(line):
            continue
        # Strip any inline comment, then split into at most three fields so the
        # RHS keeps all of its symbols even when they are separated by tabs.
        fields = line.split("#")[0].strip().split(None, 2)
        if len(fields) < 3:
            raise ValueError(f"Malformed grammar rule (expected 'weight LHS RHS'): {line!r}")
        weight, lhs, rhs = fields
        if _is_terminal(rhs):
            terminal_hash.setdefault(lhs, {})[rhs] = float(weight)
        else:
            non_terminal_hash.setdefault(lhs, {})[tuple(rhs.split())] = float(weight)

    # Merge per LHS: a plain dict.update would let an LHS's terminal rules
    # replace (and silently drop) its non-terminal rules.
    merged = {lhs: dict(rules) for lhs, rules in non_terminal_hash.items()}
    for lhs, rules in terminal_hash.items():
        merged.setdefault(lhs, {}).update(rules)
    return terminal_hash, non_terminal_hash, merged

In [243]:
## Combined grammar hash: regexes that classify grammar symbols
non_terminal_pattern, pre_terminal_pattern, terminal_pattern = (
    r"^[A-Z]+$",          # uppercase letters only, e.g. "NP"
    r"^[A-Za-z]{2,}$",    # two or more ASCII letters of either case
    r"(^[a-z\s]+)$",      # lowercase letters and whitespace only, e.g. "ice cream"
)

def _spit_probable_choice(items, weights):
    return random.choices(items, weights, k=1)[0]

# Read raw bytes and decode explicitly (same text as a binary read + decode);
# the `with` block closes the handle deterministically instead of relying on GC.
with open("grammar.gr", "rb") as grammar_fh:
    grammar_text = grammar_fh.read().decode("utf-8")


def _is_non_terminal(text):
    """Return True when ``text`` matches ``non_terminal_pattern`` (uppercase letters only)."""
    match = re.search(non_terminal_pattern, text)
    return match is not None

def _is_pre_terminal(text):
    """Return True when ``text`` matches ``pre_terminal_pattern`` (two or more letters, any case)."""
    match = re.search(pre_terminal_pattern, text)
    return match is not None

def _is_terminal(text):
    """Return True when ``text`` matches ``terminal_pattern`` (lowercase letters and whitespace only)."""
    match = re.search(terminal_pattern, text)
    return match is not None

def _create_grammar_hash(grammar_text):
    """
    Parse .gr grammar text into a single combined rule table.

    A rule line holds three whitespace-separated fields, ``weight LHS RHS``
    (the RHS may itself contain spaces). Text after ``#`` is dropped. Lines of
    length 0 or 1, and lines starting with ``#``, whitespace, ``(`` or ``)``,
    are skipped.

    Args:
        grammar_text (str): full contents of a grammar file

    Returns:
        dict: ``LHS -> {RHS: weight}``. An RHS accepted by ``_is_terminal`` is
        kept as its string; any other RHS becomes a tuple of symbols. A repeated
        (LHS, RHS) pair keeps the weight of its last occurrence.

    Raises:
        ValueError: if a rule line has fewer than three fields or its weight
            is not a number.
    """
    def _is_comment(line):
        return re.search(r"^[#\s()]", line) is not None

    grammar_hash = {}
    for line in grammar_text.split("\n"):
        if len(line) <= 1 or _is_comment(line):
            continue
        # Strip any inline comment, then split into at most three fields so the
        # RHS keeps all of its symbols even when they are separated by tabs.
        fields = line.split("#")[0].strip().split(None, 2)
        if len(fields) < 3:
            raise ValueError(f"Malformed grammar rule (expected 'weight LHS RHS'): {line!r}")
        weight, lhs, rhs = fields
        rhs_key = rhs if _is_terminal(rhs) else tuple(rhs.split())
        grammar_hash.setdefault(lhs, {})[rhs_key] = float(weight)
    return grammar_hash

In [244]:
# Build the LHS -> {RHS: weight} table from the grammar text read earlier,
# using the most recently defined _create_grammar_hash.
grammar_hash = _create_grammar_hash(grammar_text=grammar_text)

In [234]:
class Grammar:
    """
    Context-Free Grammar (CFG) sentence generator backed by a weighted .gr file.

    ``self.rules`` maps each LHS symbol to ``{RHS: weight}``; an RHS made only of
    lowercase words and whitespace is stored as that string, any other RHS as a
    tuple of child symbols.
    """

    # Kept on the class so rule parsing does not depend on notebook-level globals.
    _COMMENT_PATTERN = re.compile(r"^[#\s()]")
    _TERMINAL_PATTERN = re.compile(r"(^[a-z\s]+)$")

    def __init__(self, grammar_file):
        """
        Build a sentence generator from a grammar file.

        Args:
            grammar_file (str): Path to a .gr grammar file

        Raises:
            OSError: if the file cannot be opened.
            ValueError: if a rule line is malformed (see ``_load_rules_from_file``).
        """
        self.rules = None
        self._load_rules_from_file(grammar_file)

    @staticmethod
    def _spit_probable_choice(items, weights):
        """Return one element of ``items`` drawn with probability proportional to ``weights``."""
        return random.choices(items, weights, k=1)[0]

    @staticmethod
    def _convert_tokens_to_sentence(tokens):
        """Join tokens with spaces, then remove whitespace before ``.``, ``!``, ``?`` and ``'``."""
        sentence = " ".join(tokens)
        sentence = re.sub(r"\s+([.!?])", r"\1", sentence)
        sentence = re.sub(r"\s+'", "'", sentence)
        return sentence

    def _load_rules_from_file(self, grammar_file):
        """
        Read a grammar file and store its rules in ``self.rules``.

        Each rule line holds ``weight LHS RHS`` separated by whitespace (tabs in
        the usual .gr layout); text after ``#`` is ignored. Lines of length 0 or
        1, and lines starting with ``#``, whitespace, ``(`` or ``)``, are skipped.

        Args:
            grammar_file (str): Path to the raw grammar file

        Raises:
            ValueError: if a rule line has fewer than three fields or its weight
                is not a number (``self.rules`` is then left unchanged).
        """
        with open(grammar_file, "rb") as handle:
            grammar_text = handle.read().decode("utf-8")

        rules = {}
        for line in grammar_text.split("\n"):
            if len(line) <= 1 or self._COMMENT_PATTERN.search(line):
                continue
            # Split into at most three fields so an RHS containing tabs keeps all
            # of its symbols instead of being truncated at the first tab.
            fields = line.split("#")[0].strip().split(None, 2)
            if len(fields) < 3:
                raise ValueError(f"Malformed grammar rule (expected 'weight LHS RHS'): {line!r}")
            weight, lhs, rhs = fields
            rhs_key = rhs if self._TERMINAL_PATTERN.search(rhs) else tuple(rhs.split())
            rules.setdefault(lhs, {})[rhs_key] = float(weight)
        self.rules = rules

    def sample(self, derivation_tree, max_expansions, start_symbol):
        """
        Sample a random sentence from this grammar

        Args:
            derivation_tree (bool): if true, the returned string will represent
                the tree (using bracket notation) that records how the sentence
                was derived
            max_expansions (int): max total number of nonterminal expansions
                allowed for the whole sample; once the budget is spent, any
                remaining nonterminal is emitted verbatim
            start_symbol (str): start symbol to generate from

        Returns:
            str: the random sentence or its derivation tree. If ``start_symbol``
            has no rules it is returned with surrounding whitespace stripped.
        """
        if start_symbol not in self.rules:
            return start_symbol.strip()

        # Shared across the whole recursion so the limit counts expansions, not depth.
        expansions_used = 0

        def _tree_expand(symbol):
            """Expand ``symbol`` depth-first; return (tokens, bracketed tree string)."""
            nonlocal expansions_used
            if symbol not in self.rules or expansions_used >= max_expansions:
                return symbol.split(), symbol
            expansions_used += 1

            options = self.rules[symbol]
            right_split = self._spit_probable_choice(
                items=list(options.keys()), weights=list(options.values())
            )

            if isinstance(right_split, tuple):
                tokens_list, subtrees = [], []
                for child in right_split:
                    tokens, tree = _tree_expand(child)
                    tokens_list.extend(tokens)
                    subtrees.append(tree)
                return tokens_list, f"({symbol} {' '.join(subtrees)})"
            # Terminal rule: the RHS string holds one or more words.
            return right_split.split(), f"({symbol} {right_split})"

        tokens, tree_str = _tree_expand(start_symbol)
        if derivation_tree:
            return tree_str
        return self._convert_tokens_to_sentence(tokens)

In [None]:
grammar = Grammar("grammar.gr")
# Display the bracketed derivation tree rather than the flat sentence.
grammar.sample(
    start_symbol="ROOT",
    max_expansions=10,
    derivation_tree=True,
)