In [4]:
import argparse
import os
import random


class PCFG:
    """Probabilistic context-free grammar that samples sentences top-down from ``ROOT``.

    Grammar file format: one rule per line, tab-separated::

        weight<TAB>lhs<TAB>rhs[<TAB>label]

    ``rhs`` is a space-separated list of symbols. Any symbol that appears as
    some rule's ``lhs`` is treated as a nonterminal; everything else is a
    terminal. The optional ``label`` string is prefixed to the lhs node name
    in bracketed output whenever that rule is used. Lines starting with
    ``#``, a space, a tab, or a newline are ignored, and text after ``#``
    on a rule line is treated as a comment.
    """

    def __init__(self, grammar_file):
        self.rules = None
        self.change_rules = None
        # Number of expansions performed by the most recent sample_sentence call.
        self.expansions = 0
        self.load_rules(grammar_file)

    def load_rules(self, grammar_file):
        """Parse ``grammar_file`` into ``self.rules`` and ``self.change_rules``.

        ``self.rules`` maps each lhs to a list of ``[rhs, probability]`` pairs;
        the weights of each lhs are normalised to sum to 1.
        ``self.change_rules`` maps ``lhs + TAB + rhs`` to the label string from
        the optional fourth column.

        Lines that do not have exactly 3 or 4 tab-separated fields are skipped
        rather than silently re-using the fields of the preceding rule.

        Raises:
            ValueError: if a weight field is not a number.
            ZeroDivisionError: if all weights of some lhs are zero.
        """
        new_rules = {}
        change = {}
        # Close the file deterministically instead of leaking the handle.
        with open(grammar_file, 'r') as g_file:
            lines = g_file.readlines()
        for line in lines:
            if line.startswith(('#', " ", "\t", "\n")):
                continue
            # Drop a trailing comment.
            comment_start = line.find("#")
            if comment_start != -1:
                line = line[:comment_start]
            fields = line.rstrip().split("\t")
            if len(fields) == 3:
                weight, lhs, rhs = fields
                label = -1
            elif len(fields) == 4:
                weight, lhs, rhs, label = fields
            else:
                # Malformed rule line: skip it.
                continue
            new_rules.setdefault(lhs, []).append([rhs, float(weight)])
            if label != -1:
                change[lhs + "\t" + rhs] = label
        for poss in new_rules.values():
            total = sum(w for _, w in poss)
            for rule in poss:
                rule[1] /= total
        self.rules = new_rules
        self.change_rules = change

    def sample_sentence(self, max_expansions, bracketing):
        """Sample one sentence by repeatedly expanding the leftmost unexpanded nonterminal.

        Args:
            max_expansions: stop once more than this many rules have been applied.
            bracketing: if True, each expansion is written as ``( LHS rhs... )``,
                with the rule's label prefixed to LHS when the rule has one.

        Returns:
            The sentence as a space-joined string. If the expansion limit was
            exceeded, "Max expansions reached" is printed and every remaining
            nonterminal (other than bracket node labels) is replaced by "...".
        """
        self.expansions = 0
        done = False
        sent = ["ROOT"]
        idx = 0
        while not done:
            if sent[idx] not in self.rules:
                # Terminal or bracket token: move on.
                idx += 1
                if idx >= len(sent):
                    done = True
                continue
            replace, change_idx = self.expand(sent[idx])
            if bracketing:
                node = sent[idx] if change_idx == -1 else change_idx + sent[idx]
                sent = sent[:idx] + ["(", node] + replace + [")"] + sent[idx + 1:]
                # Skip "(" and the node label so the next symbol examined is
                # the first rhs symbol.
                idx += 2
            else:
                sent = sent[:idx] + replace + sent[idx + 1:]
            self.expansions += 1
            if self.expansions > max_expansions or idx >= len(sent):
                done = True
        if self.expansions > max_expansions:
            print("Max expansions reached")
            for i, tok in enumerate(sent):
                # In bracketed output a nonterminal right after "(" is a node
                # label, not an unexpanded symbol.
                is_node_label = bracketing and sent[i - 1] == "("
                if tok in self.rules and not is_node_label:
                    sent[i] = "..."
        return ' '.join(sent)

    def expand(self, symbol):
        """Sample one right-hand side for ``symbol`` according to its rule probabilities.

        Returns:
            ``(rhs_symbols, label)`` where ``rhs_symbols`` is the chosen rhs split
            on spaces and ``label`` is the rule's fourth-column string, or -1 if
            the rule has none.
        """
        poss = self.rules[symbol]
        sample = random.random()
        cumulative = 0.0
        # Default to the last rule: normalised probabilities can sum to slightly
        # less than 1 due to float round-off, and without this fallback no rule
        # would be chosen and an empty-string symbol would be emitted.
        chosen = poss[-1][0]
        for rhs, prob in poss:
            cumulative += prob
            if sample <= cumulative:
                chosen = rhs
                break
        label = self.change_rules.get(symbol + "\t" + chosen, -1)
        return chosen.split(" "), label



In [5]:
# Monte Carlo sampling: estimate each sentence's probability by its relative
# frequency over many independent samples.
n_sample = 1_000_000


def monte_carlo_estimation(grammar_file, max_expansions, n_sample, bracketing=False):
    """Sample ``n_sample`` sentences from the grammar and count each distinct one.

    Args:
        grammar_file: path passed to ``PCFG``.
        max_expansions: expansion limit forwarded to ``PCFG.sample_sentence``.
        n_sample: number of sentences to draw.
        bracketing: forwarded to ``PCFG.sample_sentence`` (default False).

    Returns:
        dict mapping sentence string -> number of times it was sampled.
        Divide a count by ``n_sample`` for the estimated probability.
    """
    grammar = PCFG(grammar_file)
    sentence_to_count = {}
    for _ in range(n_sample):
        # Pass bracketing by keyword: a positional argument after a keyword
        # argument is a SyntaxError.
        sent = grammar.sample_sentence(max_expansions=max_expansions, bracketing=bracketing)
        sentence_to_count[sent] = sentence_to_count.get(sent, 0) + 1
    return sentence_to_count




In [6]:
# Estimate sentence probabilities from 1M samples (at most 400 expansions each).
n_sample = 1_000_000
sentence_to_count = monte_carlo_estimation(
    grammar_file="data_gen/base-grammar_eos.gr",
    max_expansions=400,
    n_sample=n_sample,
)

Max expansions reached
Max expansions reached
Max expansions reached


## Shuffle

In [22]:
# Demo: shuffle a small list in place using the global RNG.
# Uncomment the next line to make the printed permutation reproducible.
# random.seed(42)
tokens = list(range(1, 6))
random.shuffle(tokens)
print(tokens)

[5, 4, 3, 1, 2]


In [73]:
class PCFG:
    """Stub grammar used to exercise the shuffle subclasses without a grammar file.

    Ignores ``grammar_file`` and always returns the same fixed sentence.
    NOTE: defining this rebinds the name ``PCFG``, so classes defined after
    it subclass this stub rather than a real grammar.
    """

    _FIXED_SENTENCE = "1 2 3 4 5 [eos]"

    def __init__(self, grammar_file):
        # Nothing to load for the stub.
        pass

    def sample_sentence(self, max_expansions, bracketing):
        """Return the fixed test sentence; both arguments are ignored."""
        return self._FIXED_SENTENCE

class PCFGDeterministicShuffle(PCFG):
    """PCFG whose sampled sentences are permuted by a fixed-seed shuffle.

    A trailing "[eos]" token is kept at the end. The shuffle RNG is re-created
    from ``seed`` on every call, so all sentences of the same length receive
    the same permutation.
    """

    def __init__(self, grammar_file, seed=42):
        super().__init__(grammar_file)
        self.seed = seed

    def sample_sentence(self, max_expansions, bracketing):
        sent = super().sample_sentence(max_expansions, bracketing)

        tokens = sent.split(' ')
        # Remove and store [eos] so it is not shuffled.
        eos = tokens.pop() if tokens[-1] == '[eos]' else None

        # Use a private RNG. Re-seeding the global `random` module here would
        # also reset the state the base grammar samples from, so every later
        # call would generate the same sentence.
        random.Random(self.seed).shuffle(tokens)

        # Add back [eos]
        if eos:
            tokens.append(eos)
        return ' '.join(tokens)

class PCFGNonDeterministicShuffle(PCFG):
    """PCFG whose sentences are shuffled with a fresh random permutation each call.

    Uses the global ``random`` module, so results vary unless it is seeded
    elsewhere. A trailing "[eos]" token stays at the end.
    """

    def sample_sentence(self, max_expansions, bracketing):
        words = super().sample_sentence(max_expansions, bracketing).split(' ')
        has_eos = words[-1] == '[eos]'
        body = words[:-1] if has_eos else words
        random.shuffle(body)
        return ' '.join(body + ['[eos]'] if has_eos else body)

class PCFGLocalShuffle(PCFG):
    """PCFG whose sentences are shuffled within consecutive windows of tokens.

    Tokens are split into chunks of ``window`` (the last chunk may be shorter)
    and each chunk is shuffled independently with an RNG seeded from ``seed``.
    A trailing "[eos]" token stays at the end.

    Raises:
        ValueError: if ``window`` is less than 1.
    """

    def __init__(self, grammar_file, window=5, seed=42):
        # A non-positive window would either crash inside range() or, if
        # negative, silently drop every token.
        if window < 1:
            raise ValueError("window must be a positive integer")
        super().__init__(grammar_file)
        self.window = window
        self.seed = seed

    def sample_sentence(self, max_expansions, bracketing):
        sent = super().sample_sentence(max_expansions, bracketing)

        tokens = sent.split(' ')
        eos = tokens.pop() if tokens[-1] == '[eos]' else None

        # Private RNG: re-seeding the global `random` module would also reset
        # the state the base grammar samples from.
        rng = random.Random(self.seed)
        shuffled_tokens = []
        for start in range(0, len(tokens), self.window):
            # Slicing already copies and clamps at the end of the list.
            batch = tokens[start:start + self.window]
            rng.shuffle(batch)
            shuffled_tokens.extend(batch)

        if eos:
            shuffled_tokens.append(eos)
        return ' '.join(shuffled_tokens)

class PCFGEvenOddShuffle(PCFG):
    """PCFG whose sentences are reordered as even-position tokens, then odd-position ones.

    Positions are 0-based and counted after a trailing "[eos]" is set aside;
    "[eos]" stays last. E.g. "1 2 3 4 5 [eos]" -> "1 3 5 2 4 [eos]".
    """

    def sample_sentence(self, max_expansions, bracketing):
        words = super().sample_sentence(max_expansions, bracketing).split(' ')
        tail = []
        if words[-1] == '[eos]':
            tail = [words.pop()]
        return ' '.join(words[::2] + words[1::2] + tail)

In [75]:
# Sample one sentence and apply the fixed-seed shuffle.
pcfg_det = PCFGDeterministicShuffle(grammar_file="data_gen/base-grammar_eos.gr")
pcfg_det.sample_sentence(max_expansions=400, bracketing=False)

'4 2 3 5 1 [eos]'

In [76]:
# Sample one sentence and shuffle it with a fresh random permutation.
pcfg_nondet = PCFGNonDeterministicShuffle(grammar_file="data_gen/base-grammar_eos.gr")
pcfg_nondet.sample_sentence(max_expansions=400, bracketing=False)

'4 3 1 5 2 [eos]'

In [77]:
# Shuffle within consecutive windows of 2 tokens, seeded with 42.
pcfg_local = PCFGLocalShuffle(
    grammar_file="data_gen/base-grammar_eos.gr", window=2, seed=42
)
pcfg_local.sample_sentence(max_expansions=400, bracketing=False)

'2 1 4 3 5 [eos]'

In [78]:
# Reorder one sentence as even-position tokens followed by odd-position tokens.
pcfg_even_odd = PCFGEvenOddShuffle(grammar_file="data_gen/base-grammar_eos.gr")
pcfg_even_odd.sample_sentence(max_expansions=400, bracketing=False)

'1 3 5 2 4 [eos]'

In [84]:
class PCFG:
    """Probabilistic context-free grammar that samples sentences top-down from ``ROOT``.

    Grammar file format: one rule per line, tab-separated::

        weight<TAB>lhs<TAB>rhs[<TAB>label]

    ``rhs`` is a space-separated list of symbols. Any symbol that appears as
    some rule's ``lhs`` is treated as a nonterminal; everything else is a
    terminal. The optional ``label`` string is prefixed to the lhs node name
    in bracketed output whenever that rule is used. Lines starting with
    ``#``, a space, a tab, or a newline are ignored, and text after ``#``
    on a rule line is treated as a comment.
    """

    def __init__(self, grammar_file):
        self.rules = None
        self.change_rules = None
        # Number of expansions performed by the most recent sample_sentence call.
        self.expansions = 0
        self.load_rules(grammar_file)

    def load_rules(self, grammar_file):
        """Parse ``grammar_file`` into ``self.rules`` and ``self.change_rules``.

        ``self.rules`` maps each lhs to a list of ``[rhs, probability]`` pairs;
        the weights of each lhs are normalised to sum to 1.
        ``self.change_rules`` maps ``lhs + TAB + rhs`` to the label string from
        the optional fourth column.

        Lines that do not have exactly 3 or 4 tab-separated fields are skipped
        rather than silently re-using the fields of the preceding rule.

        Raises:
            ValueError: if a weight field is not a number.
            ZeroDivisionError: if all weights of some lhs are zero.
        """
        new_rules = {}
        change = {}
        # Close the file deterministically instead of leaking the handle.
        with open(grammar_file, 'r') as g_file:
            lines = g_file.readlines()
        for line in lines:
            if line.startswith(('#', " ", "\t", "\n")):
                continue
            # Drop a trailing comment.
            comment_start = line.find("#")
            if comment_start != -1:
                line = line[:comment_start]
            fields = line.rstrip().split("\t")
            if len(fields) == 3:
                weight, lhs, rhs = fields
                label = -1
            elif len(fields) == 4:
                weight, lhs, rhs, label = fields
            else:
                # Malformed rule line: skip it.
                continue
            new_rules.setdefault(lhs, []).append([rhs, float(weight)])
            if label != -1:
                change[lhs + "\t" + rhs] = label
        for poss in new_rules.values():
            total = sum(w for _, w in poss)
            for rule in poss:
                rule[1] /= total
        self.rules = new_rules
        self.change_rules = change

    def sample_sentence(self, max_expansions, bracketing):
        """Sample one sentence by repeatedly expanding the leftmost unexpanded nonterminal.

        Args:
            max_expansions: stop once more than this many rules have been applied.
            bracketing: if True, each expansion is written as ``( LHS rhs... )``,
                with the rule's label prefixed to LHS when the rule has one.

        Returns:
            The sentence as a space-joined string. If the expansion limit was
            exceeded, "Max expansions reached" is printed and every remaining
            nonterminal (other than bracket node labels) is replaced by "...".
        """
        self.expansions = 0
        done = False
        sent = ["ROOT"]
        idx = 0
        while not done:
            if sent[idx] not in self.rules:
                # Terminal or bracket token: move on.
                idx += 1
                if idx >= len(sent):
                    done = True
                continue
            replace, change_idx = self.expand(sent[idx])
            if bracketing:
                node = sent[idx] if change_idx == -1 else change_idx + sent[idx]
                sent = sent[:idx] + ["(", node] + replace + [")"] + sent[idx + 1:]
                # Skip "(" and the node label so the next symbol examined is
                # the first rhs symbol.
                idx += 2
            else:
                sent = sent[:idx] + replace + sent[idx + 1:]
            self.expansions += 1
            if self.expansions > max_expansions or idx >= len(sent):
                done = True
        if self.expansions > max_expansions:
            print("Max expansions reached")
            for i, tok in enumerate(sent):
                # In bracketed output a nonterminal right after "(" is a node
                # label, not an unexpanded symbol.
                is_node_label = bracketing and sent[i - 1] == "("
                if tok in self.rules and not is_node_label:
                    sent[i] = "..."
        return ' '.join(sent)

    def expand(self, symbol):
        """Sample one right-hand side for ``symbol`` according to its rule probabilities.

        Returns:
            ``(rhs_symbols, label)`` where ``rhs_symbols`` is the chosen rhs split
            on spaces and ``label`` is the rule's fourth-column string, or -1 if
            the rule has none.
        """
        poss = self.rules[symbol]
        sample = random.random()
        cumulative = 0.0
        # Default to the last rule: normalised probabilities can sum to slightly
        # less than 1 due to float round-off, and without this fallback no rule
        # would be chosen and an empty-string symbol would be emitted.
        chosen = poss[-1][0]
        for rhs, prob in poss:
            cumulative += prob
            if sample <= cumulative:
                chosen = rhs
                break
        label = self.change_rules.get(symbol + "\t" + chosen, -1)
        return chosen.split(" "), label

class PCFGDeterministicShuffle(PCFG):
    """PCFG whose sampled sentences are permuted by a fixed-seed shuffle.

    A trailing "[eos]" token is kept at the end. The shuffle RNG is re-created
    from ``seed`` on every call, so all sentences of the same length receive
    the same permutation.
    """

    def __init__(self, grammar_file, seed=42):
        super().__init__(grammar_file)
        self.seed = seed

    def sample_sentence(self, max_expansions, bracketing):
        sent = super().sample_sentence(max_expansions, bracketing)

        tokens = sent.split(' ')
        # Remove and store [eos] so it is not shuffled.
        eos = tokens.pop() if tokens[-1] == '[eos]' else None

        # Use a private RNG. Re-seeding the global `random` module here would
        # also reset the state the base grammar samples from, so every later
        # call would generate the same sentence.
        random.Random(self.seed).shuffle(tokens)

        # Add back [eos]
        if eos:
            tokens.append(eos)
        return ' '.join(tokens)

class PCFGNonDeterministicShuffle(PCFG):
    """PCFG whose sentences are shuffled with a fresh random permutation each call.

    Uses the global ``random`` module, so results vary unless it is seeded
    elsewhere. A trailing "[eos]" token stays at the end.
    """

    def sample_sentence(self, max_expansions, bracketing):
        words = super().sample_sentence(max_expansions, bracketing).split(' ')
        has_eos = words[-1] == '[eos]'
        body = words[:-1] if has_eos else words
        random.shuffle(body)
        return ' '.join(body + ['[eos]'] if has_eos else body)

class PCFGLocalShuffle(PCFG):
    """PCFG whose sentences are shuffled within consecutive windows of tokens.

    Tokens are split into chunks of ``window`` (the last chunk may be shorter)
    and each chunk is shuffled independently with an RNG seeded from ``seed``.
    A trailing "[eos]" token stays at the end.

    Raises:
        ValueError: if ``window`` is less than 1.
    """

    def __init__(self, grammar_file, window=5, seed=42):
        # A non-positive window would either crash inside range() or, if
        # negative, silently drop every token.
        if window < 1:
            raise ValueError("window must be a positive integer")
        super().__init__(grammar_file)
        self.window = window
        self.seed = seed

    def sample_sentence(self, max_expansions, bracketing):
        sent = super().sample_sentence(max_expansions, bracketing)

        tokens = sent.split(' ')
        eos = tokens.pop() if tokens[-1] == '[eos]' else None

        # Private RNG: re-seeding the global `random` module would also reset
        # the state the base grammar samples from.
        rng = random.Random(self.seed)
        shuffled_tokens = []
        for start in range(0, len(tokens), self.window):
            # Slicing already copies and clamps at the end of the list.
            batch = tokens[start:start + self.window]
            rng.shuffle(batch)
            shuffled_tokens.extend(batch)

        if eos:
            shuffled_tokens.append(eos)
        return ' '.join(shuffled_tokens)

class PCFGEvenOddShuffle(PCFG):
    """PCFG whose sentences are reordered as even-position tokens, then odd-position ones.

    Positions are 0-based and counted after a trailing "[eos]" is set aside;
    "[eos]" stays last. E.g. "1 2 3 4 5 [eos]" -> "1 3 5 2 4 [eos]".
    """

    def sample_sentence(self, max_expansions, bracketing):
        words = super().sample_sentence(max_expansions, bracketing).split(' ')
        tail = []
        if words[-1] == '[eos]':
            tail = [words.pop()]
        return ' '.join(words[::2] + words[1::2] + tail)


In [89]:
# Sample one sentence from the grammar and apply the seed-42 shuffle.
pcfg_det = PCFGDeterministicShuffle(
    grammar_file="data_gen/base-grammar_eos.gr", seed=42
)
pcfg_det.sample_sentence(max_expansions=400, bracketing=False)

'ob rel peachician strubdifies peagerizes dargine sub [eos]'

## Hop

In [None]:
MARKER_HOP_SINGLE = "SINGLE"
MARKER_HOP_PLURAL = "PLURAL"


class PCFGHop(PCFG):
    """Placeholder for the "hop" perturbation; not implemented yet.

    The original cell declared this class with an empty body, which is an
    IndentationError. Sampling raises instead of silently returning
    unperturbed sentences.

    TODO: implement. MARKER_HOP_SINGLE / MARKER_HOP_PLURAL are defined for
    this perturbation, but their intended placement is not specified here.
    """

    def sample_sentence(self, max_expansions, bracketing):
        raise NotImplementedError("PCFGHop.sample_sentence is not implemented yet")


## Reverse

In [5]:
MARKER_REVERSE = "REVERSE"


class PCFG:
    """Stub grammar used to exercise the reverse subclasses without a grammar file.

    Ignores ``grammar_file`` and always returns the same fixed sentence.
    NOTE: defining this rebinds the name ``PCFG``, so classes defined after
    it subclass this stub rather than a real grammar.
    """

    _FIXED_SENTENCE = "1 2 3 4 5 [eos]"

    def __init__(self, grammar_file):
        # Nothing to load for the stub.
        pass

    def sample_sentence(self, max_expansions, bracketing):
        """Return the fixed test sentence; both arguments are ignored."""
        return self._FIXED_SENTENCE

class PCFGNoReverse(PCFG):
    """Baseline: insert MARKER_REVERSE at a random position but reorder nothing.

    A trailing "[eos]" token is kept last, after the marker.
    """

    def sample_sentence(self, max_expansions, bracketing):
        sent = super().sample_sentence(max_expansions, bracketing)
        if sent is None:
            return None
        words = sent.split(' ')
        tail = [words.pop()] if words[-1] == '[eos]' else []
        # randint is inclusive, so the marker may also land after the last word.
        cut = random.randint(0, len(words))
        return ' '.join(words[:cut] + [MARKER_REVERSE] + words[cut:] + tail)

class PCFGPartialReverse(PCFG):
    """Insert MARKER_REVERSE at a random position and reverse every word after it.

    A trailing "[eos]" token is excluded from the reversal and kept last.
    """

    def sample_sentence(self, max_expansions, bracketing):
        sent = super().sample_sentence(max_expansions, bracketing)
        if sent is None:
            return None
        words = sent.split(' ')
        tail = [words.pop()] if words[-1] == '[eos]' else []
        cut = random.randint(0, len(words))
        head, rest = words[:cut], words[cut:]
        return ' '.join(head + [MARKER_REVERSE] + rest[::-1] + tail)

class PCFGFullReverse(PCFG):
    """Insert MARKER_REVERSE at a random position, then reverse the whole sequence.

    The marker is reversed along with the words. A trailing "[eos]" token is
    excluded from the reversal and kept last.
    """

    def sample_sentence(self, max_expansions, bracketing):
        sent = super().sample_sentence(max_expansions, bracketing)
        if sent is None:
            return None
        words = sent.split(' ')
        tail = [words.pop()] if words[-1] == '[eos]' else []
        cut = random.randint(0, len(words))
        marked = words[:cut] + [MARKER_REVERSE] + words[cut:]
        return ' '.join(list(reversed(marked)) + tail)

In [22]:
# Baseline: insert the REVERSE marker without reordering any words.
pcfg_no_reverse = PCFGNoReverse(grammar_file="data_gen/base-grammar_eos.gr")
pcfg_no_reverse.sample_sentence(max_expansions=400, bracketing=False)

'1 2 3 4 REVERSE 5 [eos]'

In [26]:
# Insert the REVERSE marker and reverse only the words after it.
pcfg_partial_reverse = PCFGPartialReverse(grammar_file="data_gen/base-grammar_eos.gr")
pcfg_partial_reverse.sample_sentence(max_expansions=400, bracketing=False)

'1 REVERSE 5 4 3 2 [eos]'

In [34]:
# Insert the REVERSE marker, then reverse the entire sentence (marker included).
pcfg_full_reverse = PCFGFullReverse(grammar_file="data_gen/base-grammar_eos.gr")
pcfg_full_reverse.sample_sentence(max_expansions=400, bracketing=False)

'5 4 3 REVERSE 2 1 [eos]'