Skip to content

Commit 83426ab

Browse files
authored
Merge pull request #1142 from stanfordnlp/sr_custom_serialization
Sr custom serialization
2 parents 5cd0429 + 7e9eaf1 commit 83426ab

File tree

8 files changed

+317
-45
lines changed

8 files changed

+317
-45
lines changed

scripts/srparser/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ englishSR.ser.gz:
144144
englishSR.beam.ser.gz:
145145
@echo Training $@
146146
@echo Will test on $(ENGLISH_TEST)
147-
java -mx50g edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser -trainTreebank $(ENGLISH_TRAIN) -devTreebank $(ENGLISH_DEV) -serializedPath $@ $(DEFAULT_OPTIONS) -preTag -taggerSerializedFile $(ENGLISH_TAGGER) -tlpp $(ENGLISH_TLPP) $(TRAIN_BEAM) $(AUGMENT_LESS) > $@.out 2>&1
147+
java -mx80g edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser -trainTreebank $(ENGLISH_TRAIN) -devTreebank $(ENGLISH_DEV) -serializedPath $@ $(DEFAULT_OPTIONS) -preTag -taggerSerializedFile $(ENGLISH_TAGGER) -tlpp $(ENGLISH_TLPP) $(TRAIN_BEAM) $(AUGMENT_LESS) > $@.out 2>&1
148148
java -mx5g edu.stanford.nlp.parser.shiftreduce.ShiftReduceParser $(TEST_ARGS) -testTreebank $(ENGLISH_TEST) -serializedPath $@ -preTag -taggerSerializedFile $(ENGLISH_TAGGER) >> $@.out 2>&1
149149

150150
frenchSR.ser.gz:
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package edu.stanford.nlp.io;
2+
3+
import java.io.ByteArrayInputStream;
4+
import java.io.ByteArrayOutputStream;
5+
6+
/**
 * Static methods for writing shorts, ints, and longs to a
 * {@link ByteArrayOutputStream} and reading them back from a
 * {@link ByteArrayInputStream} using bit fiddling.
 * <br>
 * All values are stored big-endian (most significant byte first), so a value
 * written by {@code writeX} is recovered by the matching {@code readX}.
 *
 * @author John Bauer
 */
public class ByteArrayUtils {

  /**
   * Reads {@code numBytes} bytes from {@code bin} and assembles them,
   * most significant byte first, into the low bytes of a long.
   *
   * @throws IllegalArgumentException if the stream runs out of bytes first
   */
  private static long readBigEndian(ByteArrayInputStream bin, int numBytes) {
    long result = 0L;
    for (int i = 0; i < numBytes; ++i) {
      final int next = bin.read();
      // read() signals end of stream with -1.  Masking that with 0xFF would
      // silently turn a truncated stream into 0xFF bytes, so fail loudly.
      if (next < 0) {
        throw new IllegalArgumentException("Unexpected end of stream: needed " + numBytes +
                                           " bytes but only " + i + " were available");
      }
      result = (result << 8) | next;
    }
    return result;
  }

  /**
   * Writes the low {@code numBytes} bytes of {@code val} to {@code bout},
   * most significant byte first.
   */
  private static void writeBigEndian(ByteArrayOutputStream bout, long val, int numBytes) {
    for (int shift = (numBytes - 1) * 8; shift >= 0; shift -= 8) {
      bout.write((int) ((val >>> shift) & 0xff));
    }
  }

  /**
   * Reads a big-endian short (2 bytes) from {@code bin}.
   *
   * @throws IllegalArgumentException if fewer than 2 bytes remain
   */
  public static short readShort(ByteArrayInputStream bin) {
    return (short) readBigEndian(bin, 2);
  }

  /** Writes {@code val} to {@code bout} as 2 bytes, big-endian. */
  public static void writeShort(ByteArrayOutputStream bout, short val) {
    writeBigEndian(bout, val, 2);
  }

  /**
   * Reads a big-endian int (4 bytes) from {@code bin}.
   *
   * @throws IllegalArgumentException if fewer than 4 bytes remain
   */
  public static int readInt(ByteArrayInputStream bin) {
    return (int) readBigEndian(bin, 4);
  }

  /** Writes {@code val} to {@code bout} as 4 bytes, big-endian. */
  public static void writeInt(ByteArrayOutputStream bout, int val) {
    writeBigEndian(bout, val, 4);
  }

  /**
   * Reads a big-endian long (8 bytes) from {@code bin}.
   *
   * @throws IllegalArgumentException if fewer than 8 bytes remain
   */
  public static long readLong(ByteArrayInputStream bin) {
    return readBigEndian(bin, 8);
  }

  /** Writes {@code val} to {@code bout} as 8 bytes, big-endian. */
  public static void writeLong(ByteArrayOutputStream bout, long val) {
    writeBigEndian(bout, val, 8);
  }
}

src/edu/stanford/nlp/parser/shiftreduce/BinaryTransition.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,11 @@ public boolean equals(Object o) {
232232

233233
@Override
234234
public int hashCode() {
235-
// TODO: fix the hashcode for the side? would require rebuilding all models
236235
switch(side) {
237236
case LEFT:
238237
return 97197711 ^ label.hashCode();
239238
case RIGHT:
240-
return 97197711 ^ label.hashCode();
239+
return 85635467 ^ label.hashCode();
241240
default:
242241
throw new IllegalArgumentException("Unknown side " + side);
243242
}

src/edu/stanford/nlp/parser/shiftreduce/PerceptronModel.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@ public class PerceptronModel extends BaseModel {
4444

4545
private float learningRate = 1.0f;
4646

47-
Map<String, Weight> featureWeights;
47+
WeightMap featureWeights;
4848
final FeatureFactory featureFactory;
4949

5050
public PerceptronModel(ShiftReduceOptions op, Index<Transition> transitionIndex,
5151
Set<String> knownStates, Set<String> rootStates, Set<String> rootOnlyStates) {
5252
super(op, transitionIndex, knownStates, rootStates, rootOnlyStates);
53-
this.featureWeights = Generics.newHashMap();
53+
this.featureWeights = new WeightMap();
5454

5555
String[] classes = op.featureFactoryClass.split(";");
5656
if (classes.length == 1) {
@@ -74,7 +74,7 @@ public PerceptronModel(PerceptronModel other) {
7474
super(other);
7575
this.featureFactory = other.featureFactory;
7676

77-
this.featureWeights = Generics.newHashMap();
77+
this.featureWeights = new WeightMap();
7878
for (String feature : other.featureWeights.keySet()) {
7979
featureWeights.put(feature, new Weight(other.featureWeights.get(feature)));
8080
}
@@ -110,7 +110,7 @@ public void averageModels(Collection<PerceptronModel> models) {
110110
}
111111
}
112112

113-
featureWeights = Generics.newHashMap();
113+
featureWeights = new WeightMap();
114114
for (String feature : features) {
115115
featureWeights.put(feature, new Weight());
116116
}

src/edu/stanford/nlp/parser/shiftreduce/Weight.java

Lines changed: 71 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package edu.stanford.nlp.parser.shiftreduce;
22

3+
import java.io.ByteArrayInputStream;
4+
import java.io.ByteArrayOutputStream;
35
import java.io.Serializable;
46

7+
import edu.stanford.nlp.io.ByteArrayUtils;
58
import edu.stanford.nlp.util.ArrayUtils;
69

10+
711
/**
812
* Stores one row of the sparse matrix which makes up the multiclass perceptron.
913
*
@@ -23,63 +27,74 @@
2327
*/
2428

2529
public class Weight implements Serializable {
30+
static final short[] EMPTY = {};
31+
2632
public Weight() {
27-
packed = null;
33+
packed = EMPTY;
2834
}
2935

3036
public Weight(Weight other) {
3137
if (other.size() == 0) {
32-
packed = null;
38+
packed = EMPTY;
3339
return;
3440
}
3541
packed = ArrayUtils.copy(other.packed);
3642
condense();
3743
}
3844

3945
public int size() {
40-
if (packed == null) {
41-
return 0;
42-
}
43-
return packed.length;
46+
// TODO: find a fast way of doing this... we know it's a multiple of 3 after all
47+
return packed.length / 3;
4448
}
4549

46-
private int unpackIndex(int i) {
47-
long pack = packed[i];
48-
return (int) (pack >>> 32);
50+
private short unpackIndex(int i) {
51+
return packed[i * 3];
4952
}
5053

5154
private float unpackScore(int i) {
52-
long pack = packed[i];
53-
return Float.intBitsToFloat((int) (pack & 0xFFFFFFFF));
55+
i = i * 3 + 1;
56+
final int high = ((int) packed[i++]) << 16;
57+
final int low = packed[i] & 0x0000FFFF;
58+
return Float.intBitsToFloat(high | low);
5459
}
5560

56-
private static long packedValue(int index, float score) {
57-
long pack = ((long) (Float.floatToIntBits(score))) & 0x00000000FFFFFFFFL;
58-
pack = pack | (((long) index) << 32);
59-
return pack;
60-
}
61-
62-
private static void pack(long[] packed, int i, int index, float score) {
63-
packed[i] = packedValue(index, score);
61+
private static void pack(short[] packed, int i, int index, float score) {
62+
if (i > Short.MAX_VALUE) {
63+
throw new ArithmeticException("How did you make an index with 30,000 weights??");
64+
}
65+
int pos = i * 3;
66+
packed[pos++] = (short) index;
67+
final int bits = Float.floatToIntBits(score);
68+
packed[pos++] = (short) ((bits & 0xFFFF0000) >> 16);
69+
packed[pos] = (short) (bits & 0x0000FFFF);
6470
}
6571

6672
private void pack(int i, int index, float score) {
67-
packed[i] = packedValue(index, score);
73+
if (i > Short.MAX_VALUE) {
74+
throw new ArithmeticException("How did you make an index with 30,000 weights??");
75+
}
76+
int pos = i * 3;
77+
packed[pos++] = (short) index;
78+
final int bits = Float.floatToIntBits(score);
79+
packed[pos++] = (short) ((bits & 0xFFFF0000) >> 16);
80+
packed[pos] = (short) (bits & 0x0000FFFF);
6881
}
6982

7083
public void score(float[] scores) {
71-
final int length = size();
72-
if (length > scores.length) {
84+
if (packed.length > scores.length * 3) {
7385
throw new AssertionError("Called with an array of scores too small to fit");
7486
}
75-
for (int i = 0; i < length; ++i) {
87+
for (int i = 0; i < packed.length; ) {
7688
// Since this is the critical method, we optimize it even further.
7789
// We could do this:
7890
// int index = unpackIndex; float score = unpackScore;
79-
// That results in an extra array lookup
80-
final long pack = packed[i];
81-
final int index = (int) (pack >>> 32);
82-
final float score = Float.intBitsToFloat((int) (pack & 0xFFFFFFFF));
91+
// That results in extra operations
92+
final short index = packed[i++];
93+
final int high = ((int) packed[i++]) << 16;
94+
final int low = packed[i++] & 0x0000FFFF;
95+
final int bits = high | low;
96+
// final int bits = (((int) packed[i++]) << 16) | (packed[i++] & 0x0000FFFF);
97+
final float score = Float.intBitsToFloat(bits);
8398
scores[index] += score;
8499
}
85100
}
@@ -98,7 +113,7 @@ public void addScaled(Weight other, float scale) {
98113
void condense() {
99114
// threshold is in case floating point math makes a feature we
100115
// don't care about exist
101-
if (packed == null) {
116+
if (packed == null || packed.length == 0) {
102117
return;
103118
}
104119

@@ -111,15 +126,15 @@ void condense() {
111126
}
112127

113128
if (nonzero == 0) {
114-
packed = null;
129+
packed = EMPTY;
115130
return;
116131
}
117132

118133
if (nonzero == length) {
119134
return;
120135
}
121136

122-
long[] newPacked = new long[nonzero];
137+
short[] newPacked = new short[nonzero * 3];
123138
int j = 0;
124139
for (int i = 0; i < length; ++i) {
125140
if (Math.abs(unpackScore(i)) <= THRESHOLD) {
@@ -152,23 +167,23 @@ public void updateWeight(int index, float increment) {
152167
return;
153168
}
154169

155-
if (packed == null) {
156-
packed = new long[1];
170+
if (packed == null || packed.length == 0) {
171+
packed = new short[3];
157172
pack(0, index, increment);
158173
return;
159174
}
160175

161176
final int length = size();
162177
for (int i = 0; i < length; ++i) {
163178
if (unpackIndex(i) == index) {
164-
float score = unpackScore(i);
179+
final float score = unpackScore(i);
165180
pack(i, index, score + increment);
166181
return;
167182
}
168183
}
169184

170-
long[] newPacked = new long[length + 1];
171-
for (int i = 0; i < length; ++i) {
185+
short[] newPacked = new short[packed.length + 3];
186+
for (int i = 0; i < packed.length; ++i) {
172187
newPacked[i] = packed[i];
173188
}
174189
pack(newPacked, length, index, increment);
@@ -231,15 +246,33 @@ void l2Reg(float reg) {
231246
public String toString() {
232247
StringBuilder builder = new StringBuilder();
233248
final int length = size();
249+
builder.append("Weight(");
234250
for (int i = 0; i < length; ++i) {
235-
if (i > 0) builder.append(" ");
251+
if (i > 0) builder.append(" ");
236252
builder.append(unpackIndex(i) + "=" + unpackScore(i));
237253
}
254+
builder.append(")");
238255
return builder.toString();
239256
}
240257

241-
private long[] packed;
258+
private short[] packed;
242259

243-
private static final long serialVersionUID = 1;
260+
void writeBytes(ByteArrayOutputStream bout) {
261+
ByteArrayUtils.writeInt(bout, packed.length);
262+
for (int i = 0; i < packed.length; ++i) {
263+
ByteArrayUtils.writeShort(bout, packed[i]);
264+
}
265+
}
266+
267+
static Weight readBytes(ByteArrayInputStream bin) {
268+
int len = ByteArrayUtils.readInt(bin);
269+
Weight weight = new Weight();
270+
weight.packed = new short[len];
271+
for (int i = 0; i < len; ++i) {
272+
weight.packed[i] = ByteArrayUtils.readShort(bin);
273+
}
274+
return weight;
275+
}
244276

277+
private static final long serialVersionUID = 3;
245278
}

0 commit comments

Comments
 (0)