1
1
package edu .stanford .nlp .parser .shiftreduce ;
2
2
3
+ import java .io .ByteArrayInputStream ;
4
+ import java .io .ByteArrayOutputStream ;
3
5
import java .io .Serializable ;
4
6
7
+ import edu .stanford .nlp .io .ByteArrayUtils ;
5
8
import edu .stanford .nlp .util .ArrayUtils ;
6
9
10
+
7
11
/**
8
12
* Stores one row of the sparse matrix which makes up the multiclass perceptron.
9
13
*
23
27
*/
24
28
25
29
public class Weight implements Serializable {
30
+ static final short [] EMPTY = {};
31
+
26
32
public Weight () {
27
- packed = null ;
33
+ packed = EMPTY ;
28
34
}
29
35
30
36
public Weight (Weight other ) {
31
37
if (other .size () == 0 ) {
32
- packed = null ;
38
+ packed = EMPTY ;
33
39
return ;
34
40
}
35
41
packed = ArrayUtils .copy (other .packed );
36
42
condense ();
37
43
}
38
44
39
45
public int size () {
40
- if (packed == null ) {
41
- return 0 ;
42
- }
43
- return packed .length ;
46
+ // TODO: find a fast way of doing this... we know it's a multiple of 3 after all
47
+ return packed .length / 3 ;
44
48
}
45
49
46
- private int unpackIndex (int i ) {
47
- long pack = packed [i ];
48
- return (int ) (pack >>> 32 );
50
+ private short unpackIndex (int i ) {
51
+ return packed [i * 3 ];
49
52
}
50
53
51
54
private float unpackScore (int i ) {
52
- long pack = packed [i ];
53
- return Float .intBitsToFloat ((int ) (pack & 0xFFFFFFFF ));
55
+ i = i * 3 + 1 ;
56
+ final int high = ((int ) packed [i ++]) << 16 ;
57
+ final int low = packed [i ] & 0x0000FFFF ;
58
+ return Float .intBitsToFloat (high | low );
54
59
}
55
60
56
- private static long packedValue (int index , float score ) {
57
- long pack = ((long ) (Float .floatToIntBits (score ))) & 0x00000000FFFFFFFFL ;
58
- pack = pack | (((long ) index ) << 32 );
59
- return pack ;
60
- }
61
-
62
- private static void pack (long [] packed , int i , int index , float score ) {
63
- packed [i ] = packedValue (index , score );
61
+ private static void pack (short [] packed , int i , int index , float score ) {
62
+ if (i > Short .MAX_VALUE ) {
63
+ throw new ArithmeticException ("How did you make an index with 30,000 weights??" );
64
+ }
65
+ int pos = i * 3 ;
66
+ packed [pos ++] = (short ) index ;
67
+ final int bits = Float .floatToIntBits (score );
68
+ packed [pos ++] = (short ) ((bits & 0xFFFF0000 ) >> 16 );
69
+ packed [pos ] = (short ) (bits & 0x0000FFFF );
64
70
}
65
71
66
72
private void pack (int i , int index , float score ) {
67
- packed [i ] = packedValue (index , score );
73
+ if (i > Short .MAX_VALUE ) {
74
+ throw new ArithmeticException ("How did you make an index with 30,000 weights??" );
75
+ }
76
+ int pos = i * 3 ;
77
+ packed [pos ++] = (short ) index ;
78
+ final int bits = Float .floatToIntBits (score );
79
+ packed [pos ++] = (short ) ((bits & 0xFFFF0000 ) >> 16 );
80
+ packed [pos ] = (short ) (bits & 0x0000FFFF );
68
81
}
69
82
70
83
public void score (float [] scores ) {
71
- final int length = size ();
72
- if (length > scores .length ) {
84
+ if (packed .length > scores .length * 3 ) {
73
85
throw new AssertionError ("Called with an array of scores too small to fit" );
74
86
}
75
- for (int i = 0 ; i < length ; ++ i ) {
87
+ for (int i = 0 ; i < packed . length ; ) {
76
88
// Since this is the critical method, we optimize it even further.
77
89
// We could do this:
78
90
// int index = unpackIndex; float score = unpackScore;
79
- // That results in an extra array lookup
80
- final long pack = packed [i ];
81
- final int index = (int ) (pack >>> 32 );
82
- final float score = Float .intBitsToFloat ((int ) (pack & 0xFFFFFFFF ));
91
+ // That results in extra operations
92
+ final short index = packed [i ++];
93
+ final int high = ((int ) packed [i ++]) << 16 ;
94
+ final int low = packed [i ++] & 0x0000FFFF ;
95
+ final int bits = high | low ;
96
+ // final int bits = (((int) packed[i++]) << 16) | (packed[i++] & 0x0000FFFF);
97
+ final float score = Float .intBitsToFloat (bits );
83
98
scores [index ] += score ;
84
99
}
85
100
}
@@ -98,7 +113,7 @@ public void addScaled(Weight other, float scale) {
98
113
void condense () {
99
114
// threshold is in case floating point math makes a feature we
100
115
// don't care about exist
101
- if (packed == null ) {
116
+ if (packed == null || packed . length == 0 ) {
102
117
return ;
103
118
}
104
119
@@ -111,15 +126,15 @@ void condense() {
111
126
}
112
127
113
128
if (nonzero == 0 ) {
114
- packed = null ;
129
+ packed = EMPTY ;
115
130
return ;
116
131
}
117
132
118
133
if (nonzero == length ) {
119
134
return ;
120
135
}
121
136
122
- long [] newPacked = new long [nonzero ];
137
+ short [] newPacked = new short [nonzero * 3 ];
123
138
int j = 0 ;
124
139
for (int i = 0 ; i < length ; ++i ) {
125
140
if (Math .abs (unpackScore (i )) <= THRESHOLD ) {
@@ -152,23 +167,23 @@ public void updateWeight(int index, float increment) {
152
167
return ;
153
168
}
154
169
155
- if (packed == null ) {
156
- packed = new long [ 1 ];
170
+ if (packed == null || packed . length == 0 ) {
171
+ packed = new short [ 3 ];
157
172
pack (0 , index , increment );
158
173
return ;
159
174
}
160
175
161
176
final int length = size ();
162
177
for (int i = 0 ; i < length ; ++i ) {
163
178
if (unpackIndex (i ) == index ) {
164
- float score = unpackScore (i );
179
+ final float score = unpackScore (i );
165
180
pack (i , index , score + increment );
166
181
return ;
167
182
}
168
183
}
169
184
170
- long [] newPacked = new long [ length + 1 ];
171
- for (int i = 0 ; i < length ; ++i ) {
185
+ short [] newPacked = new short [ packed . length + 3 ];
186
+ for (int i = 0 ; i < packed . length ; ++i ) {
172
187
newPacked [i ] = packed [i ];
173
188
}
174
189
pack (newPacked , length , index , increment );
@@ -231,15 +246,33 @@ void l2Reg(float reg) {
231
246
public String toString () {
232
247
StringBuilder builder = new StringBuilder ();
233
248
final int length = size ();
249
+ builder .append ("Weight(" );
234
250
for (int i = 0 ; i < length ; ++i ) {
235
- if (i > 0 ) builder .append (" " );
251
+ if (i > 0 ) builder .append (" " );
236
252
builder .append (unpackIndex (i ) + "=" + unpackScore (i ));
237
253
}
254
+ builder .append (")" );
238
255
return builder .toString ();
239
256
}
240
257
241
- private long [] packed ;
258
+ private short [] packed ;
242
259
243
- private static final long serialVersionUID = 1 ;
260
+ void writeBytes (ByteArrayOutputStream bout ) {
261
+ ByteArrayUtils .writeInt (bout , packed .length );
262
+ for (int i = 0 ; i < packed .length ; ++i ) {
263
+ ByteArrayUtils .writeShort (bout , packed [i ]);
264
+ }
265
+ }
266
+
267
+ static Weight readBytes (ByteArrayInputStream bin ) {
268
+ int len = ByteArrayUtils .readInt (bin );
269
+ Weight weight = new Weight ();
270
+ weight .packed = new short [len ];
271
+ for (int i = 0 ; i < len ; ++i ) {
272
+ weight .packed [i ] = ByteArrayUtils .readShort (bin );
273
+ }
274
+ return weight ;
275
+ }
244
276
277
+ private static final long serialVersionUID = 3 ;
245
278
}
0 commit comments