In [18]:
import argparse
import os
import random


class PCFG:
    """
    PCFG to sample sentences from.

    The grammar file is tab-separated, one rule per line:
    ``weight<TAB>lhs<TAB>rhs`` with an optional fourth ``change_idx`` column.
    ``rhs`` is a space-separated sequence of symbols; symbols that have no
    rule of their own are terminals.  Lines starting with ``#``, a space, a
    tab or a newline are skipped, and text after ``#`` on a rule line is a
    comment.  Rule weights are normalised per left-hand side.

    Attributes:
        rules: dict mapping each lhs to a list of ``[rhs, probability]`` pairs.
        change_rules: dict mapping ``lhs + "\\t" + rhs`` to the rule's fourth
            column (kept as a string) for rules that have one.
        expansions: number of expansions made by the last
            ``sample_sentence`` call.
    """
    def __init__(self, grammar_file):
        self.rules = None
        self.change_rules = None
        self.expansions = 0
        self.load_rules(grammar_file)

    def load_rules(self, grammar_file):
        """Parse ``grammar_file`` and populate ``rules`` and ``change_rules``.

        Args:
            grammar_file: path to the tab-separated grammar file.

        Raises:
            ValueError: if a rule line does not have 3 or 4 tab-separated
                fields, or if the weights of some left-hand side sum to zero.
        """
        new_rules = {}
        change = {}
        with open(grammar_file, 'r', encoding='utf-8') as g_file:
            lines = g_file.readlines()
        for lineno, line in enumerate(lines, start=1):
            if line.startswith(('#', " ", "\t", "\n")):
                continue
            # Drop an inline comment, then trailing whitespace (this also
            # removes a tab left in front of the comment, and a "\r").
            body = line.split("#", 1)[0].rstrip()
            if not body:
                continue
            fields = body.split("\t")
            if len(fields) == 3:
                weight, lhs, rhs = fields
                idx = -1
            elif len(fields) == 4:
                weight, lhs, rhs, idx = fields
            else:
                # Previously such a line silently reused the previous rule.
                raise ValueError(
                    f"{grammar_file}:{lineno}: expected 3 or 4 tab-separated "
                    f"fields, got {len(fields)}: {line!r}")
            new_rules.setdefault(lhs, []).append([rhs, float(weight)])
            if idx != -1:
                change[lhs + "\t" + rhs] = idx
        for lhs, poss in new_rules.items():
            total = sum(prob for _, prob in poss)
            if total == 0:
                raise ValueError(
                    f"{grammar_file}: weights for {lhs!r} sum to zero")
            for rule in poss:
                rule[1] /= total
        self.rules = new_rules
        self.change_rules = change

    def sample_sentence(self, max_expansions, bracketing):
        """Sample one sentence from ``ROOT`` by always expanding the left-most
        unexpanded nonterminal.

        Args:
            max_expansions: once more than this many expansions have been
                made, sampling stops and every remaining nonterminal is
                replaced by ``"..."``.
            bracketing: if true, each expansion is written as
                ``( LABEL child ... )``, where LABEL is the expanded symbol,
                prefixed by the rule's change index when it has one.

        Returns:
            The sampled tokens joined by single spaces.
        """
        self.expansions = 0
        sent = ["ROOT"]
        idx = 0
        done = False
        while not done:
            symbol = sent[idx]
            if symbol not in self.rules:
                # Terminal, bracket or label token: move to the next token.
                idx += 1
                done = idx >= len(sent)
                continue
            replace, change_idx = self.expand(symbol)
            if bracketing:
                label = symbol if change_idx == -1 else change_idx + symbol
                sent[idx:idx + 1] = ["(", label] + replace + [")"]
                # Skip "(" and the label so idx points at the first child.
                idx += 2
            else:
                sent[idx:idx + 1] = replace
            self.expansions += 1
            if self.expansions > max_expansions or idx >= len(sent):
                done = True
        if self.expansions > max_expansions:
            for i, token in enumerate(sent):
                if token not in self.rules:
                    continue
                # With bracketing, a rule symbol right after "(" is a label,
                # not an unexpanded nonterminal.
                if bracketing and i > 0 and sent[i - 1] == "(":
                    continue
                sent[i] = "..."
        return ' '.join(sent)

    def expand(self, symbol):
        """Sample a right-hand side for ``symbol``.

        Returns:
            ``(rhs_tokens, change_idx)``: the chosen rhs split on single
            spaces, and the rule's fourth-column value (a string), or ``-1``
            if the rule has none.
        """
        poss = self.rules[symbol]
        sample = random.random()
        cumulative = 0.0
        # Fallback: rounding can leave the summed probabilities just below
        # ``sample``; then pick the last rule instead of an empty rhs.
        chosen = poss[-1][0]
        for rhs, prob in poss:
            cumulative += prob
            if sample <= cumulative:
                chosen = rhs
                break
        idx = self.change_rules.get(symbol + "\t" + chosen, -1)
        return chosen.split(" "), idx

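## Assign Zipf weights

Re-weight the lexical rules (`Nonterminal -> terminal`) of the grammar according to each rule's rank among the rules that share its left-hand side.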

In [23]:
def update_grammar_with_zipf_weights(grammar_file, output_file=None, exponent=1.0):
    """Re-weight lexical rules so each left-hand side follows a Zipf distribution.

    A rule is lexical when its left-hand side starts with an upper-case
    letter and its right-hand side is a single symbol that does not (for
    example ``Noun_S<TAB>cat``; ``[eos]`` also counts as a terminal).
    Lexical rules that share a left-hand side are ranked 1, 2, 3, ... in file
    order, and the rule of rank ``r`` gets weight ``1 / r ** exponent``.
    ``PCFG`` normalises weights per left-hand side, so this yields
    P(rank r) proportional to 1 / r**exponent.  Every other line (comments,
    blank lines, non-lexical rules, lines that are not rules) is copied
    through unchanged, including inline comments and line endings.

    Args:
        grammar_file: path to the tab-separated grammar to read.
        output_file: if given, the re-weighted grammar is also written there.
        exponent: Zipf exponent ``s``.  ``1.0`` is classic Zipf, ``0`` gives
            every lexical rule weight 1, and ``-1`` gives weight = rank.

    Returns:
        The re-weighted grammar as a list of lines (ready for ``''.join``).
    """
    def is_terminal(symbol):
        # Terminals do not start with an upper-case letter (covers "[eos]").
        return bool(symbol) and not symbol[0].isupper()

    def is_nonterminal(symbol):
        return symbol[:1].isupper()

    with open(grammar_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    # Per left-hand side, the number of lexical rules seen so far; ranks are
    # assigned in file order, so duplicate rules get distinct ranks.
    ranks = {}
    new_lines = []
    for line in lines:
        if line.startswith(('#', " ", "\t", "\n")):
            new_lines.append(line)
            continue

        # Keep the inline comment separately so it can be written back.
        body, hash_sign, comment = line.partition("#")
        parts = body.strip().split("\t")
        if len(parts) < 3:
            # Not a rule line: pass it through instead of dropping it.
            new_lines.append(line)
            continue

        lhs, rhs = parts[1], parts[2]
        if not (is_nonterminal(lhs) and ' ' not in rhs and is_terminal(rhs)):
            # Non-lexical rule: copy the original line verbatim.
            new_lines.append(line)
            continue

        rank = ranks.get(lhs, 0) + 1
        ranks[lhs] = rank
        weight = 1.0 / rank ** exponent
        new_line = "\t".join([f"{weight:.10g}", lhs, rhs, *parts[3:]])
        if hash_sign:
            new_line += "\t#" + comment.rstrip("\r\n")
        new_lines.append(new_line + '\n')

    if output_file:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.writelines(new_lines)
    return new_lines


In [24]:
# Directory holding the grammar files; override with the DATA_GEN_DIR env var
# instead of editing a machine-specific absolute path.
DATA_GEN_DIR = os.environ.get(
    "DATA_GEN_DIR", "/home/agiats/Projects/impossible_inherent_entropy/data_gen"
)
# Set to True to also write the re-weighted grammar to ``zipf_grammar_path``;
# by default this cell only previews it.
WRITE_ZIPF_GRAMMAR = False

base_grammar_path = os.path.join(DATA_GEN_DIR, "base-grammar_eos.gr")
zipf_grammar_path = os.path.join(DATA_GEN_DIR, "base-grammar_eos_zipf.gr")
zipf_lines = update_grammar_with_zipf_weights(
    base_grammar_path, zipf_grammar_path if WRITE_ZIPF_GRAMMAR else None
)
# print (not a bare expression) so the grammar's newlines render as lines.
print(''.join(zipf_lines))

1	ROOT	S
1	S	NP_Subj_S VP_S EOS	1
1	S	NP_Subj_P VP_P EOS	1
1	S'	NP_Subj_S VP_S	1
1	S'	NP_Subj_P VP_P	1
1	VP_S	VP_Past_S
1	VP_S	VP_Pres_S
1	VP_S	VP_Comp_S
1	VP_P	VP_Past_P
1	VP_P	VP_Pres_P
1	VP_P	VP_Comp_P
1	VP_Comp_S	VP_Comp_Pres_S
1	VP_Comp_S	VP_Comp_Past_S
1	VP_Comp_P	VP_Comp_Pres_P
1	VP_Comp_P	VP_Comp_Past_P
1	NP_Subj_S	NP_S Subj
1	NP_Subj_P	NP_P Subj
1	VP_Past_S	IVerb_Past_S
1	VP_Pres_S	IVerb_Pres_S
1	VP_Past_P	IVerb_Past_P
1	VP_Pres_P	IVerb_Pres_P
1	VP_Past_S	NP_Obj TVerb_Past_S	2
1	VP_Pres_S	NP_Obj TVerb_Pres_S	2
1	VP_Comp_Pres_S	S_Comp Verb_Comp_Pres_S	2
1	VP_Comp_Past_S	S_Comp Verb_Comp_Past_S	2
1	VP_Past_P	NP_Obj TVerb_Past_P	2
1	VP_Pres_P	NP_Obj TVerb_Pres_P	2
1	VP_Comp_Pres_P	S_Comp Verb_Comp_Pres_P	2
1	VP_Comp_Past_P	S_Comp Verb_Comp_Past_P	2
1	S_Comp	S' Comp	3
1	NP_Obj	NP_S Obj
1	NP_Obj	NP_P Obj
1	NP_S	Noun_S
1	NP_P	Noun_P
0.2	NP_S	PP NP_S	4
0.2	NP_P	PP NP_P	4
1	PP	NP_S Prep	4
1	PP	NP_P Prep	4
1	NP_S	Adj Noun_S	5
1	NP_P	Adj Noun_P	5
1	Adj	Adj CC Adj
1	NP_S	VP_S Rel Noun_S	