In [3]:
import argparse
import os
import random


class PCFG:
    """
    Probabilistic context-free grammar (PCFG) to sample sentences from.

    The grammar is read from a tab-separated rule file (see ``load_rules``)
    and sentences are generated by repeatedly expanding the ``ROOT`` symbol.
    """
    def __init__(self, grammar_file):
        """
        Build the grammar from the rule file at ``grammar_file``.

        :param grammar_file: path of the grammar file to load.
        """
        # lhs symbol -> list of [rhs string, normalised probability] pairs
        self.rules = None
        # "lhs\trhs" -> value of the optional fourth column (kept as a string)
        self.change_rules = None
        # number of expansions done by the most recent sample_sentence() call
        self.expansions = 0
        self.load_rules(grammar_file)

    def load_rules(self, grammar_file):
        """
        Parse ``grammar_file`` and store its rules on the instance.

        Every rule line has three or four tab-separated fields::

            weight <TAB> lhs <TAB> rhs [<TAB> idx]

        Lines starting with ``#`` or with whitespace are skipped, and
        everything after a ``#`` on a rule line is discarded. The weights of
        all rules sharing a left-hand side are normalised to sum to one.

        :param grammar_file: path of the grammar file to load.
        :raises ValueError: if a rule line does not have three or four
            fields, or if a weight cannot be converted to ``float``.
        :raises ZeroDivisionError: if the weights of a left-hand side sum
            to zero.
        """
        new_rules = {}
        change = {}
        # The context manager guarantees the file is closed, even on errors.
        with open(grammar_file, 'r') as g_file:
            for line_no, line in enumerate(g_file, start=1):
                if line.startswith(('#', " ", "\t", "\n")):
                    continue
                comment_at = line.find("#")
                if comment_at != -1:
                    line = line[:comment_at]
                stripped = line.rstrip()
                if not stripped:
                    continue
                fields = stripped.split("\t")
                if len(fields) == 3:
                    weight, lhs, rhs = fields
                    idx = -1
                elif len(fields) == 4:
                    weight, lhs, rhs, idx = fields
                else:
                    # Previously such a line silently re-used the fields of
                    # the preceding rule and so duplicated that rule.
                    raise ValueError(
                        f"{grammar_file}:{line_no}: expected 3 or 4 "
                        f"tab-separated fields, got {len(fields)}"
                    )
                new_rules.setdefault(lhs, []).append([rhs, float(weight)])
                if idx != -1:
                    change[lhs + "\t" + rhs] = idx
        # Turn the raw weights into probabilities per left-hand side.
        for poss in new_rules.values():
            total = sum(entry[1] for entry in poss)
            for entry in poss:
                entry[1] /= total
        self.rules = new_rules
        self.change_rules = change

    def sample_sentence(self, max_expansions, bracketing):
        """
        Sample one sentence by expanding ``ROOT`` from left to right.

        :param max_expansions: expansion budget. Sampling stops as soon as
            more than this many expansions were performed; in that case
            "Max expansions reached" is printed and every symbol that could
            still be expanded is replaced by ``"..."``.
        :param bracketing: if true, every expanded symbol is emitted as
            ``( LABEL expansion )``, where LABEL is the symbol itself,
            prefixed with the rule's fourth column when the rule has one.
        :return: the resulting tokens joined by single spaces.
        """
        self.expansions = 0
        sent = ["ROOT"]
        idx = 0
        while idx < len(sent):
            symbol = sent[idx]
            if symbol not in self.rules:
                # Terminal (or bracket/label): nothing to expand, move on.
                idx += 1
                continue
            replace, change_idx = self.expand(symbol)
            if bracketing:
                label = symbol if change_idx == -1 else change_idx + symbol
                sent = (sent[:idx]
                        + ["(", label] + replace + [")"]
                        + sent[idx + 1:])
                # Skip the opening bracket and the label just inserted.
                idx += 2
            else:
                sent = sent[:idx] + replace + sent[idx + 1:]
            self.expansions += 1
            if self.expansions > max_expansions:
                break
        if self.expansions > max_expansions:
            print("Max expansions reached")
            for pos, token in enumerate(sent):
                if token not in self.rules:
                    continue
                # In bracketed output the token right after "(" is a
                # constituent label, not a pending non-terminal.
                if bracketing and sent[pos - 1] == "(":
                    continue
                sent[pos] = "..."
        return ' '.join(sent)

    def expand(self, symbol):
        """
        Randomly choose one right-hand side for ``symbol``.

        A uniform sample is compared against the running sum of the rule
        probabilities; the first rule whose cumulative probability reaches
        the sample is selected.

        :param symbol: a left-hand side present in ``self.rules``.
        :return: ``(tokens, idx)`` where ``tokens`` is the chosen right-hand
            side split on single spaces and ``idx`` is the rule's fourth
            column from the grammar file, or ``-1`` if it has none.
        """
        poss = self.rules[symbol]
        sample = random.random()
        cumulative = 0.0
        # Default to the last rule: rounding can leave the cumulative sum a
        # hair below the sample, which used to yield an empty expansion.
        chosen = poss[-1][0]
        for rhs, prob in poss:
            cumulative += prob
            if sample <= cumulative:
                chosen = rhs
                break
        idx = self.change_rules.get(symbol + "\t" + chosen, -1)
        return chosen.split(" "), idx



In [5]:
# Number of sentences to draw for the Monte Carlo estimate of sentence probabilities.
n_sample = 10 ** 6

def monte_carlo_estimation(grammar_file, max_expansions, n_sample):
    """
    Sample sentences from a grammar and count how often each one occurs.

    :param grammar_file: path of the grammar file handed to ``PCFG``.
    :param max_expansions: expansion budget passed to every
        ``sample_sentence`` call.
    :param n_sample: number of sentences to draw.
    :return: dict mapping each sampled (unbracketed) sentence to the number
        of times it was drawn.
    """
    grammar = PCFG(grammar_file)
    sentence_to_count = {}
    for _ in range(n_sample):
        # ``bracketing`` must be given by keyword here: a positional argument
        # after ``max_expansions=...`` is a SyntaxError.
        sent = grammar.sample_sentence(max_expansions=max_expansions, bracketing=False)
        sentence_to_count[sent] = sentence_to_count.get(sent, 0) + 1

    return sentence_to_count




In [6]:
# Run the sampler on the EOS base grammar with max_expansions=400.
n_sample = 10 ** 6
sentence_to_count = monte_carlo_estimation(
    grammar_file="data_gen/base-grammar_eos.gr",
    max_expansions=400,
    n_sample=n_sample,
)

Max expansions reached
Max expansions reached
Max expansions reached


In [7]:
# Relative frequency of every sampled sentence: its count divided by the number of samples.
sentence_to_prob = {
    sentence: count / n_sample
    for sentence, count in sentence_to_count.items()
}

{'brelland sub calikated sa embarricate rel doloners sub inthippenated rel amackist ob blamiciateda [eos]': 1,
 'gloralizers sub noachate [eos]': 1,
 'crotifers sub la sub filk sa fancheda rel tibs ob baniticate [eos]': 1,
 'botig denaticians sub naid shagoners sub froachereda sa thespireda [eos]': 1,
 'revalternates rel bactivitor sub zurched [eos]': 1,
 'naid wainsters sub flingerists ob chozed rel senth ob wandify sa embarricateda rel trachiers sub bunessistes rel businer sub garfs sa unseff [eos]': 1,
 'pi sub thespirates [eos]': 31,
 'naid stroaters sub me ob dupresse rel briticists ob croil [eos]': 1,
 'me sub parled [eos]': 34,
 'spean duskists sub dorpals sub parls rel pleddist sub clim milturor sub hurperates sa kurched sa humicianeda sa mangere [eos]': 1,
 'kuestors sub thespirateda [eos]': 2,
 'bilth ampercials hi skeer penursifers sub calikateda [eos]': 1,
 'narters sub protter sub flapperated sa nobeda rel visivets sub threak buestors sub onstigipateda sa crailed rel sheri