/
wcfg.py
120 lines (94 loc) · 3.5 KB
/
wcfg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
@author wilkeraziz
"""
from collections import defaultdict, deque
from symbol import is_terminal
from rule import Rule
from math import log
class WCFG(object):
    """A weighted context-free grammar: a collection of rules indexed by left-hand side.

    Rules are kept in insertion order and also grouped by their ``lhs``. The sets of
    terminals and nonterminals are maintained incrementally as rules are added: every
    ``lhs`` is recorded as a nonterminal, and each right-hand-side symbol is classified
    with ``is_terminal``.
    """

    def __init__(self, rules=()):
        """Create a grammar, optionally populated from an iterable of rules."""
        # All rules, in insertion order.
        self._rules = []
        # lhs -> list of rules rewriting that lhs, in insertion order.
        self._rules_by_lhs = defaultdict(list)
        self._terminals = set()
        self._nonterminals = set()
        for rule in rules:
            self.add(rule)

    def add(self, rule):
        """Add one rule; expects ``rule.lhs`` and an iterable ``rule.rhs``."""
        self._rules.append(rule)
        self._rules_by_lhs[rule.lhs].append(rule)
        self._nonterminals.add(rule.lhs)
        for symbol in rule.rhs:
            if is_terminal(symbol):
                self._terminals.add(symbol)
            else:
                self._nonterminals.add(symbol)

    def update(self, rules):
        """Add every rule from an iterable of rules."""
        for rule in rules:
            self.add(rule)

    @property
    def nonterminals(self):
        """Set of nonterminals seen so far (left-hand sides and non-terminal rhs symbols)."""
        return self._nonterminals

    @property
    def terminals(self):
        """Set of terminal symbols seen so far on right-hand sides."""
        return self._terminals

    def __len__(self):
        """Number of rules in the grammar."""
        return len(self._rules)

    def __getitem__(self, lhs):
        """Rules rewriting ``lhs``; an empty ``frozenset`` if there are none.

        NOTE: for a known ``lhs`` this is the internal list, so mutating it
        mutates the grammar's index.
        """
        return self._rules_by_lhs.get(lhs, frozenset())

    def get(self, lhs, default=frozenset()):
        """Rules rewriting ``lhs``, or ``default`` if ``lhs`` has no rules."""
        # dict.get on a defaultdict does not invoke the default factory, so a
        # lookup for an unknown lhs never inserts an empty entry.
        return self._rules_by_lhs.get(lhs, default)

    def can_rewrite(self, lhs):
        """Whether a given nonterminal can be rewritten (has at least one rule).

        This may differ from ``lhs in self.nonterminals``, which only tells whether a
        symbol belongs to the set of nonterminals of the grammar: a nonterminal that
        appears only on right-hand sides is a nonterminal that cannot be rewritten.
        """
        return lhs in self._rules_by_lhs

    def __iter__(self):
        """Iterate over all rules in insertion order."""
        return iter(self._rules)

    def iteritems(self):
        """Iterate over ``(lhs, rules)`` pairs.

        Uses ``items()`` rather than the Python-2-only ``iteritems()`` so this works
        on both Python 2 and Python 3.
        """
        return iter(self._rules_by_lhs.items())

    def __str__(self):
        """One rule per line, grouped by left-hand side."""
        return '\n'.join(str(rule) for _, rules in self.iteritems() for rule in rules)
def count_derivations(wcfg, root):
    """Exhaustively enumerate the derivations of ``root`` and count them.

    Derivations are built top-down from an agenda that starts as ``[root]``.
    - A terminal popped from the agenda is added to the yield.
    - A nonterminal is expanded once per rule in ``wcfg[symbol]``, and each
      expansion continues on its own copy of the agenda.
    - A nonterminal with no rules is a dead end and contributes nothing.

    Because ``deque.extendleft`` reverses ``rule.rhs``, the rightmost pending
    symbol is expanded first. Terminals are therefore met right-to-left and
    prepended, so every recorded yield is in left-to-right (surface) order.

    :param wcfg: grammar indexed by left-hand side (``wcfg[symbol]`` gives its rules)
    :param root: start symbol
    :returns: ``{'d': ..., 'p': ...}``.
        - ``'d'`` maps each derivation (tuple of rules, in the order they were
          applied) to its count.
        - ``'p'`` maps each projection (tuple of terminals) to the number of
          derivations producing it.

    NOTE: the enumeration is exhaustive and recursive (one Python frame per
    agenda step). For a grammar with recursive rules (e.g. ``[X] -> [X] [X]``)
    it cannot finish and eventually exceeds the interpreter's recursion limit.
    """
    counts = {'d': defaultdict(int), 'p': defaultdict(int)}

    def expand(applied, surface, agenda):
        # Nothing left to rewrite: record the completed derivation and its yield.
        if not agenda:
            counts['d'][tuple(applied)] += 1
            counts['p'][tuple(surface)] += 1
            return
        symbol = agenda.popleft()
        if is_terminal(symbol):
            # The agenda is consumed right-to-left, so prepend to keep surface order.
            expand(applied, [symbol] + surface, agenda)
            return
        for rule in wcfg[symbol]:
            branch = deque(agenda)
            branch.extendleft(rule.rhs)
            expand(applied + [rule], surface, branch)

    expand([], [], deque([root]))
    return counts
def read_grammar_rules(istream, transform=log, strip_quotes=False):
    """
    Reads grammar rules in cdec format.

    Each non-blank line must have exactly three fields separated by ``' ||| '``:
    - the left-hand side;
    - the whitespace-separated right-hand-side symbols;
    - a number, which is parsed as a float and passed through ``transform``
      (natural log by default) to obtain the rule's weight.

    Blank (or whitespace-only) lines are skipped.

    :param istream: iterable of lines (e.g. an open file or a list of strings)
    :param transform: function applied to the float in the third field
    :param strip_quotes: if True, remove one surrounding pair of single quotes from
        right-hand-side symbols (a lone ``'`` is kept unchanged)
    :returns: a generator of ``Rule(lhs, rhs, weight)`` objects
    :raises ValueError: if a line does not have exactly three fields, or if the
        third field is not a valid float

    >>> import math
    >>> istream = ['[S] ||| [X] ||| 1.0', '[X] ||| [X] [X] ||| 0.5'] + ['[X] ||| %d ||| 0.1' % i for i in range(1,6)]
    >>> rules = list(read_grammar_rules(istream, transform=log))
    >>> rules
    [[S] -> [X] (0.0), [X] -> [X] [X] (-0.69314718056), [X] -> 1 (-2.30258509299), [X] -> 2 (-2.30258509299), [X] -> 3 (-2.30258509299), [X] -> 4 (-2.30258509299), [X] -> 5 (-2.30258509299)]
    """
    for lineno, line in enumerate(istream, 1):
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. trailing empty lines in a file).
            continue
        fields = line.split(' ||| ')
        if len(fields) != 3:
            raise ValueError('Malformed rule at line %d (expected 3 fields separated by " ||| "): %r'
                             % (lineno, line))
        lhs, rhs, weight = fields
        symbols = rhs.split()
        if strip_quotes:
            # Require at least two characters so that a lone "'" is not turned into ''.
            symbols = [s[1:-1] if len(s) > 1 and s.startswith("'") and s.endswith("'") else s
                       for s in symbols]
        yield Rule(lhs, symbols, transform(float(weight)))