This repository has been archived by the owner on Dec 13, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 64
/
Rule.java
628 lines (552 loc) · 20.9 KB
/
Rule.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.joshua.decoder.ff.tm;
import static org.apache.joshua.decoder.ff.tm.OwnerMap.UNKNOWN_OWNER_ID;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;
import org.apache.joshua.corpus.Vocabulary;
import org.apache.joshua.decoder.ff.FeatureFunction;
import org.apache.joshua.decoder.ff.FeatureVector;
import org.apache.joshua.decoder.segment_file.Sentence;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* This class define the interface for Rule.
*
* All feature scores are interpreted as negative log probabilities, and are therefore negated.
* Note that not all features need to be negative log probs, but you should be aware that they
* will be negated, so if you want a positive count, it should come in as negative.
*
* Normally, the feature score in the rule should be *cost* (i.e., -LogP), so that the feature
* weight should be positive
*
* @author Zhifei Li, zhifei.work@gmail.com
* @author Matt Post post@cs.jhu.edu
*/
public class Rule implements Comparator<Rule>, Comparable<Rule> {
private static final Logger LOG = LoggerFactory.getLogger(Rule.class);
private int lhs; // tag of this rule
private int[] source; // pointer to the RuleCollection, as all the rules under it share the same
// Source side
protected int arity;
// And a string containing the sparse ones
//protected final String sparseFeatureString;
protected final Supplier<String> sparseFeatureStringSupplier;
private final Supplier<FeatureVector> featuresSupplier;
/*
* a feature function will be fired for this rule only if the owner of the rule matches the owner
* of the feature function
*/
private OwnerId owner = UNKNOWN_OWNER_ID;
/**
* This is the cost computed only from the features present with the grammar rule. This cost is
* needed to sort the rules in the grammar for cube pruning, but isn't the full cost of applying
* the rule (which will include contextual features that can't be computed until the rule is
* applied).
*/
private float estimatedCost = Float.NEGATIVE_INFINITY;
private float precomputableCost = Float.NEGATIVE_INFINITY;
private int[] target;
// The alignment string, e.g., 0-0 0-1 1-1 2-1
private String alignmentString;
private final Supplier<byte[]> alignmentSupplier;
/**
* Constructs a new rule using the provided parameters. Rule id for this rule is
* undefined. Note that some of the sparse features may be unlabeled, but they cannot be mapped to
* their default names ("tm_OWNER_INDEX") until later, when we know the owner of the rule. This is
* not known until the rule is actually added to a grammar in Grammar::addRule().
*
* Constructor used by other constructors below;
*
* @param lhs Left-hand side of the rule.
* @param source Source language right-hand side of the rule.
* @param target Target language right-hand side of the rule.
* @param sparseFeatures Feature value scores for the rule.
* @param arity Number of nonterminals in the source language right-hand side.
* @param owner todo
*/
public Rule(int lhs, int[] source, int[] target, String sparseFeatures, int arity, OwnerId owner) {
this.lhs = lhs;
this.source = source;
this.arity = arity;
this.owner = owner;
this.target = target;
this.sparseFeatureStringSupplier = Suppliers.memoize(() -> sparseFeatures);
this.featuresSupplier = initializeFeatureSupplierFromString();
this.alignmentSupplier = initializeAlignmentSupplier();
}
/**
* Constructor used by PackedGrammar's sortRules()
* @param lhs todo
* @param sourceRhs todo
* @param targetRhs todo
* @param features todo
* @param arity todo
* @param owner todo
*/
public Rule(int lhs, int[] sourceRhs, int[] targetRhs, FeatureVector features, int arity, OwnerId owner) {
this.lhs = lhs;
this.source = sourceRhs;
this.arity = arity;
this.owner = owner;
this.target = targetRhs;
this.featuresSupplier = Suppliers.memoize(() -> features);
this.sparseFeatureStringSupplier = initializeSparseFeaturesStringSupplier();
this.alignmentSupplier = initializeAlignmentSupplier();
}
/**
* Constructor used for SamtFormatReader and GrammarBuilderWalkerFunction's getRuleWithSpans()
* Rule is unowned.
* @param lhs todo
* @param sourceRhs todo
* @param targetRhs todo
* @param sparseFeatures todo
* @param arity todo
*/
public Rule(int lhs, int[] sourceRhs, int[] targetRhs, String sparseFeatures, int arity) {
this(lhs, sourceRhs, targetRhs, sparseFeatures, arity, OwnerMap.UNKNOWN_OWNER_ID);
}
/**
* Constructor used for addOOVRules(), HieroFormatReader and PhraseRule.
* @param lhs todo
* @param sourceRhs todo
* @param targetRhs todo
* @param sparseFeatures todo
* @param arity todo
* @param alignment todo
*/
public Rule(int lhs, int[] sourceRhs, int[] targetRhs, String sparseFeatures, int arity, String alignment) {
this(lhs, sourceRhs, targetRhs, sparseFeatures, arity);
this.alignmentString = alignment;
}
/**
* Constructor (implicitly) used by PackedRule
*/
public Rule() {
this.lhs = -1;
this.sparseFeatureStringSupplier = initializeSparseFeaturesStringSupplier();
this.featuresSupplier = initializeFeatureSupplierFromString();
this.alignmentSupplier = initializeAlignmentSupplier();
}
// ==========================================================================
// Lazy loading Suppliers for alignments, feature vector, and feature strings
// ==========================================================================
private Supplier<byte[]> initializeAlignmentSupplier(){
return Suppliers.memoize(() ->{
byte[] alignment = null;
String alignmentString = getAlignmentString();
if (alignmentString != null) {
String[] tokens = alignmentString.split("[-\\s]+");
alignment = new byte[tokens.length];
for (int i = 0; i < tokens.length; i++)
alignment[i] = (byte) Short.parseShort(tokens[i]);
}
return alignment;
});
}
/**
* If Rule was constructed with sparseFeatures String, we lazily populate the
* FeatureSupplier.
*/
private Supplier<FeatureVector> initializeFeatureSupplierFromString(){
return Suppliers.memoize(() ->{
if (!owner.equals(UNKNOWN_OWNER_ID)) {
return new FeatureVector(getFeatureString(), "tm_" + OwnerMap.getOwner(owner) + "_");
} else {
return new FeatureVector();
}
});
}
/**
* If Rule was constructed with a FeatureVector, we lazily populate the sparseFeaturesStringSupplier.
*/
private Supplier<String> initializeSparseFeaturesStringSupplier() {
return Suppliers.memoize(() -> getFeatureVector().toString());
}
// ===============================================================
// Attributes
// ===============================================================
public void setEnglish(int[] eng) {
this.target = eng;
}
public int[] getEnglish() {
return this.target;
}
/**
* Two Rules are equal of they have the same LHS, the same source RHS and the same target
* RHS.
*
* @param o the object to check for equality
* @return true if o is the same Rule as this rule, false otherwise
*/
public boolean equals(Object o) {
if (!(o instanceof Rule)) {
return false;
}
Rule other = (Rule) o;
if (getLHS() != other.getLHS()) {
return false;
}
if (!Arrays.equals(getFrench(), other.getFrench())) {
return false;
}
return Arrays.equals(target, other.getEnglish());
}
public int hashCode() {
// I just made this up. If two rules are equal they'll have the
// same hashcode. Maybe someone else can do a better job though?
int frHash = Arrays.hashCode(getFrench());
int enHash = Arrays.hashCode(target);
return frHash ^ enHash ^ getLHS();
}
// ===============================================================
// Attributes
// ===============================================================
public void setArity(int arity) {
this.arity = arity;
}
public int getArity() {
return this.arity;
}
public void setOwner(final OwnerId owner) {
this.owner = owner;
}
public OwnerId getOwner() {
return this.owner;
}
public void setLHS(int lhs) {
this.lhs = lhs;
}
public int getLHS() {
return this.lhs;
}
public void setFrench(int[] french) {
this.source = french;
}
public int[] getFrench() {
return this.source;
}
/**
* This function does the work of turning the string version of the sparse features (passed in
* when the rule was created) into an actual set of features. This is a bit complicated because we
* support intermingled labeled and unlabeled features, where the unlabeled features are mapped to
* a default name template of the form "tm_OWNER_INDEX".
*
* This function returns the dense (phrasal) features discovered when the rule was loaded. Dense
* features are the list of unlabeled features that preceded labeled ones. They can also be
* specified as labeled features of the form "tm_OWNER_INDEX", but the former format is preferred.
*
* @return the {@link org.apache.joshua.decoder.ff.FeatureVector} for this rule
*/
public FeatureVector getFeatureVector() {
return featuresSupplier.get();
}
/**
* This function returns the estimated cost of a rule, which should have been computed when the
* grammar was first sorted via a call to Rule::estimateRuleCost(). This function is a getter
* only; it will not compute the value if it has not already been set. It is necessary in addition
* to estimateRuleCost(models) because sometimes the value needs to be retrieved from contexts
* that do not have access to the feature functions.
*
* This function is called by the rule comparator when sorting the grammar. As such it may be
* called many times and any implementation of it should be a cached implementation.
*
* @return the estimated cost of the rule (a lower bound on the true cost)
*/
public float getEstimatedCost() {
return estimatedCost;
}
/**
* Precomputable costs is the inner product of the weights found on each grammar rule and the
* weight vector. This is slightly different from the estimated rule cost, which can include other
* features (such as a language model estimate). This getter and setter should also be cached, and
* is basically provided to allow the PhraseModel feature to cache its (expensive) computation for
* each rule.
*
* The weights are passed in as dense weights and sparse weights. This allows the dense weight
* precomputation to be even faster (since we don't have to query a hash map.
*
* @param dense_weights the dense weights from the model
* @param weights the sparse weights from the model
*/
public void setPrecomputableCost(float[] dense_weights, FeatureVector weights) {
float cost = 0.0f;
FeatureVector features = getFeatureVector();
for (int i = 0; i < features.getDenseFeatures().size() && i < dense_weights.length; i++) {
cost += dense_weights[i] * features.getDense(i);
}
for (String key: features.getSparseFeatures().keySet()) {
cost += weights.getSparse(key) * features.getSparse(key);
}
this.precomputableCost = cost;
}
/**
* @return the precomputed model cost of each rule
*/
public float getPrecomputableCost() {
return precomputableCost;
}
public float getDenseFeature(int k) {
return getFeatureVector().getDense(k);
}
/**
* This function estimates the cost of a rule, which is used for sorting the rules for cube
* pruning. The estimated cost is basically the set of precomputable features (features listed
* along with the rule in the grammar file) along with any other estimates that other features
* would like to contribute (e.g., a language model estimate). This cost will be a lower bound on
* the rule's actual cost.
*
* The value of this function is used only for sorting the rules. When the rule is later applied
* in context to particular hypernodes, the rule's actual cost is computed.
*
* @param models the list of models available to the decoder
* @return estimated cost of the rule
*/
public float estimateRuleCost(List<FeatureFunction> models) {
if (null == models)
return 0.0f;
if (this.estimatedCost <= Float.NEGATIVE_INFINITY) {
this.estimatedCost = 0.0f; // weights.innerProduct(computeFeatures());
// LOG.debug("estimateCost({} ;; {})", getFrenchWords(), getEnglishWords());
for (FeatureFunction ff : models) {
float val = ff.estimateCost(this);
// LOG.debug(" FEATURE {} -> {}", ff.getName(), val);
this.estimatedCost += val;
}
}
return estimatedCost;
}
// ===============================================================
// Methods
// ===============================================================
public String toString() {
String sb = Vocabulary.word(this.getLHS()) +
" ||| " +
getFrenchWords() +
" ||| " +
getEnglishWords() +
" |||" +
" " + getFeatureVector() +
String.format(" ||| est=%.3f", getEstimatedCost()) +
String.format(" pre=%.3f", getPrecomputableCost());
return sb;
}
/**
* Returns a version of the rule suitable for reading in from a text file.
*
* @return string version of the rule
*/
public String textFormat() {
StringBuffer sb = new StringBuffer();
sb.append(Vocabulary.word(this.getLHS()));
sb.append(" |||");
int nt = 1;
for (int i = 0; i < getFrench().length; i++) {
if (getFrench()[i] < 0)
sb.append(" ").append(
Vocabulary.word(getFrench()[i]).replaceFirst("\\]", String.format(",%d]", nt++)));
else
sb.append(" ").append(Vocabulary.word(getFrench()[i]));
}
sb.append(" |||");
nt = 1;
for (int i = 0; i < getEnglish().length; i++) {
if (getEnglish()[i] < 0)
sb.append(" ").append(
Vocabulary.word(getEnglish()[i]).replaceFirst("\\]", String.format(",%d]", nt++)));
else
sb.append(" ").append(Vocabulary.word(getEnglish()[i]));
}
sb.append(" |||");
sb.append(" ").append(getFeatureString());
if (getAlignmentString() != null)
sb.append(" ||| ").append(getAlignmentString());
return sb.toString();
}
public String getFeatureString() {
return sparseFeatureStringSupplier.get();
}
/**
* Returns an alignment as a sequence of integers. The integers at positions i and i+1 are paired,
* with position i indexing the source and i+1 the target.
*
* @return a byte[] from the {@link com.google.common.base.Supplier}
*/
public byte[] getAlignment() {
return this.alignmentSupplier.get();
}
public String getAlignmentString() {
return this.alignmentString;
}
/**
* The nonterminals on the English side are pointers to the source side nonterminals (-1 and -2),
* rather than being directly encoded. These number indicate the correspondence between the
* nonterminals on each side, introducing a level of indirection however when we want to resolve
* them. So to get the ID, we need to look up the corresponding source side ID.
*
* @return The string of English words
*/
public String getEnglishWords() {
int[] foreignNTs = getForeignNonTerminals();
StringBuilder sb = new StringBuilder();
for (Integer index : getEnglish()) {
if (index >= 0)
sb.append(Vocabulary.word(index)).append(" ");
else
sb.append(Vocabulary.word(foreignNTs[-index - 1]).replace("]",
String.format(",%d] ", Math.abs(index))));
}
return sb.toString().trim();
}
public boolean isTerminal() {
for (int i = 0; i < getEnglish().length; i++)
if (getEnglish()[i] < 0)
return false;
return true;
}
/**
* Return the French (source) nonterminals as list of Strings
*
* @return a list of strings
*/
public int[] getForeignNonTerminals() {
int[] nts = new int[getArity()];
int index = 0;
for (int id : getFrench())
if (id < 0)
nts[index++] = -id;
return nts;
}
/**
* Returns an array of size getArity() containing the source indeces of non terminals.
*
* @return an array of size getArity() containing the source indeces of non terminals
*/
public int[] getNonTerminalSourcePositions() {
int[] nonTerminalPositions = new int[getArity()];
int ntPos = 0;
for (int sourceIdx = 0; sourceIdx < getFrench().length; sourceIdx++) {
if (getFrench()[sourceIdx] < 0)
nonTerminalPositions[ntPos++] = sourceIdx;
}
return nonTerminalPositions;
}
/**
* Parses the Alignment byte[] into a Map from target to (possibly a list of) source positions.
* Used by the WordAlignmentExtractor.
*
* @return a {@link java.util.Map} of alignments
*/
public Map<Integer, List<Integer>> getAlignmentMap() {
byte[] alignmentArray = getAlignment();
Map<Integer, List<Integer>> alignmentMap = new HashMap<>();
if (alignmentArray != null) {
for (int alignmentIdx = 0; alignmentIdx < alignmentArray.length; alignmentIdx += 2 ) {
int s = alignmentArray[alignmentIdx];
int t = alignmentArray[alignmentIdx + 1];
List<Integer> values = alignmentMap.get(t);
if (values == null)
alignmentMap.put(t, values = new ArrayList<>());
values.add(s);
}
}
return alignmentMap;
}
/**
* Return the English (target) nonterminals as list of Strings
*
* @return list of strings
*/
public int[] getEnglishNonTerminals() {
int[] nts = new int[getArity()];
int[] foreignNTs = getForeignNonTerminals();
int index = 0;
for (int i : getEnglish()) {
if (i < 0)
nts[index++] = foreignNTs[Math.abs(getEnglish()[i]) - 1];
}
return nts;
}
private int[] getNormalizedEnglishNonterminalIndices() {
int[] result = new int[getArity()];
int ntIndex = 0;
for (Integer index : getEnglish()) {
if (index < 0)
result[ntIndex++] = -index - 1;
}
return result;
}
public boolean isInverting() {
int[] normalizedEnglishNonTerminalIndices = getNormalizedEnglishNonterminalIndices();
if (normalizedEnglishNonTerminalIndices.length == 2) {
if (normalizedEnglishNonTerminalIndices[0] == 1) {
return true;
}
}
return false;
}
public String getFrenchWords() {
return Vocabulary.getWords(getFrench());
}
public static final String NT_REGEX = "\\[[^\\]]+?\\]";
private Pattern getPattern() {
String source = getFrenchWords();
String pattern = Pattern.quote(source);
pattern = pattern.replaceAll(NT_REGEX, "\\\\E.+\\\\Q");
pattern = pattern.replaceAll("\\\\Q\\\\E", "");
pattern = "(?:^|\\s)" + pattern + "(?:$|\\s)";
return Pattern.compile(pattern);
}
/**
* Matches the string representation of the rule's source side against a sentence
*
* @param sentence {@link org.apache.joshua.lattice.Lattice} input
* @return true if there is a match
*/
public boolean matches(Sentence sentence) {
// System.err.println(String.format("match(%s,%s) = %s", Pattern.quote(getFrenchWords()),
// sentence.annotatedSource(), match));
return getPattern().matcher(sentence.fullSource()).find();
}
/**
* This comparator is used for sorting the rules during cube pruning. An estimate of the cost
* of each rule is computed and used to sort.
*/
public static final Comparator<Rule> EstimatedCostComparator = (rule1, rule2) -> {
float cost1 = rule1.getEstimatedCost();
float cost2 = rule2.getEstimatedCost();
return Float.compare(cost2, cost1);
};
public int compare(Rule rule1, Rule rule2) {
return EstimatedCostComparator.compare(rule1, rule2);
}
public int compareTo(Rule other) {
return EstimatedCostComparator.compare(this, other);
}
public String getRuleString() {
return String.format("%s -> %s ||| %s", Vocabulary.word(getLHS()), getFrenchWords(), getEnglishWords());
}
}