Skip to content

Commit 32a427c

Browse files
committed
InMemoryStateInternals.copy clones the values using the coder
1 parent 0371848 commit 32a427c

File tree

2 files changed

+64
-20
lines changed

2 files changed

+64
-20
lines changed

runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java

Lines changed: 59 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.apache.beam.sdk.annotations.Experimental;
3535
import org.apache.beam.sdk.annotations.Experimental.Kind;
3636
import org.apache.beam.sdk.coders.Coder;
37+
import org.apache.beam.sdk.coders.CoderException;
3738
import org.apache.beam.sdk.state.BagState;
3839
import org.apache.beam.sdk.state.CombiningState;
3940
import org.apache.beam.sdk.state.MapState;
@@ -49,6 +50,7 @@
4950
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
5051
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
5152
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
53+
import org.apache.beam.sdk.util.CoderUtils;
5254
import org.apache.beam.sdk.util.CombineFnUtil;
5355
import org.joda.time.Instant;
5456

@@ -126,25 +128,25 @@ public InMemoryStateBinder(StateContext<?> c) {
126128
@Override
127129
public <T> ValueState<T> bindValue(
128130
StateTag<ValueState<T>> address, Coder<T> coder) {
129-
return new InMemoryValue<>();
131+
return new InMemoryValue<>(coder);
130132
}
131133

132134
@Override
133135
public <T> BagState<T> bindBag(
134136
final StateTag<BagState<T>> address, Coder<T> elemCoder) {
135-
return new InMemoryBag<>();
137+
return new InMemoryBag<>(elemCoder);
136138
}
137139

138140
@Override
139141
public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> elemCoder) {
140-
return new InMemorySet<>();
142+
return new InMemorySet<>(elemCoder);
141143
}
142144

143145
@Override
144146
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
145147
StateTag<MapState<KeyT, ValueT>> spec,
146148
Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
147-
return new InMemoryMap<>();
149+
return new InMemoryMap<>(mapKeyCoder, mapValueCoder);
148150
}
149151

150152
@Override
@@ -153,7 +155,7 @@ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
153155
StateTag<CombiningState<InputT, AccumT, OutputT>> address,
154156
Coder<AccumT> accumCoder,
155157
final CombineFn<InputT, AccumT, OutputT> combineFn) {
156-
return new InMemoryCombiningState<>(combineFn);
158+
return new InMemoryCombiningState<>(combineFn, accumCoder);
157159
}
158160

159161
@Override
@@ -178,9 +180,15 @@ public WatermarkHoldState bindWatermark(
178180
*/
179181
public static final class InMemoryValue<T>
180182
implements ValueState<T>, InMemoryState<InMemoryValue<T>> {
183+
private final Coder<T> coder;
184+
181185
private boolean isCleared = true;
182186
private @Nullable T value = null;
183187

188+
public InMemoryValue(Coder<T> coder) {
189+
this.coder = coder;
190+
}
191+
184192
@Override
185193
public void clear() {
186194
// Even though we're clearing we can't remove this from the in-memory state map, since
@@ -207,10 +215,10 @@ public void write(T input) {
207215

208216
@Override
209217
public InMemoryValue<T> copy() {
210-
InMemoryValue<T> that = new InMemoryValue<>();
218+
InMemoryValue<T> that = new InMemoryValue<>(coder);
211219
if (!this.isCleared) {
212220
that.isCleared = this.isCleared;
213-
that.value = this.value;
221+
that.value = unsafeClone(coder, this.value);
214222
}
215223
return that;
216224
}
@@ -305,14 +313,16 @@ public InMemoryWatermarkHold<W> copy() {
305313
public static final class InMemoryCombiningState<InputT, AccumT, OutputT>
306314
implements CombiningState<InputT, AccumT, OutputT>,
307315
InMemoryState<InMemoryCombiningState<InputT, AccumT, OutputT>> {
308-
private boolean isCleared = true;
309316
private final CombineFn<InputT, AccumT, OutputT> combineFn;
317+
private final Coder<AccumT> accumCoder;
318+
private boolean isCleared = true;
310319
private AccumT accum;
311320

312321
public InMemoryCombiningState(
313-
CombineFn<InputT, AccumT, OutputT> combineFn) {
322+
CombineFn<InputT, AccumT, OutputT> combineFn, Coder<AccumT> accumCoder) {
314323
this.combineFn = combineFn;
315324
accum = combineFn.createAccumulator();
325+
this.accumCoder = accumCoder;
316326
}
317327

318328
@Override
@@ -378,7 +388,7 @@ public boolean isCleared() {
378388
@Override
379389
public InMemoryCombiningState<InputT, AccumT, OutputT> copy() {
380390
InMemoryCombiningState<InputT, AccumT, OutputT> that =
381-
new InMemoryCombiningState<>(combineFn);
391+
new InMemoryCombiningState<>(combineFn, accumCoder);
382392
if (!this.isCleared) {
383393
that.isCleared = this.isCleared;
384394
that.addAccum(accum);
@@ -391,8 +401,13 @@ public InMemoryCombiningState<InputT, AccumT, OutputT> copy() {
391401
* An {@link InMemoryState} implementation of {@link BagState}.
392402
*/
393403
public static final class InMemoryBag<T> implements BagState<T>, InMemoryState<InMemoryBag<T>> {
404+
private final Coder<T> elemCoder;
394405
private List<T> contents = new ArrayList<>();
395406

407+
public InMemoryBag(Coder<T> elemCoder) {
408+
this.elemCoder = elemCoder;
409+
}
410+
396411
@Override
397412
public void clear() {
398413
// Even though we're clearing we can't remove this from the in-memory state map, since
@@ -442,8 +457,10 @@ public Boolean read() {
442457

443458
@Override
444459
public InMemoryBag<T> copy() {
445-
InMemoryBag<T> that = new InMemoryBag<>();
446-
that.contents.addAll(this.contents);
460+
InMemoryBag<T> that = new InMemoryBag<>(elemCoder);
461+
for (T elem : this.contents) {
462+
that.contents.add(unsafeClone(elemCoder, elem));
463+
}
447464
return that;
448465
}
449466
}
@@ -452,8 +469,13 @@ public InMemoryBag<T> copy() {
452469
* An {@link InMemoryState} implementation of {@link SetState}.
453470
*/
454471
public static final class InMemorySet<T> implements SetState<T>, InMemoryState<InMemorySet<T>> {
472+
private final Coder<T> elemCoder;
455473
private Set<T> contents = new HashSet<>();
456474

475+
public InMemorySet(Coder<T> elemCoder) {
476+
this.elemCoder = elemCoder;
477+
}
478+
457479
@Override
458480
public void clear() {
459481
contents = new HashSet<>();
@@ -513,8 +535,10 @@ public Boolean read() {
513535

514536
@Override
515537
public InMemorySet<T> copy() {
516-
InMemorySet<T> that = new InMemorySet<>();
517-
that.contents.addAll(this.contents);
538+
InMemorySet<T> that = new InMemorySet<>(elemCoder);
539+
for (T elem : this.contents) {
540+
that.contents.add(unsafeClone(elemCoder, elem));
541+
}
518542
return that;
519543
}
520544
}
@@ -524,8 +548,16 @@ public InMemorySet<T> copy() {
524548
*/
525549
public static final class InMemoryMap<K, V> implements
526550
MapState<K, V>, InMemoryState<InMemoryMap<K, V>> {
551+
private final Coder<K> keyCoder;
552+
private final Coder<V> valueCoder;
553+
527554
private Map<K, V> contents = new HashMap<>();
528555

556+
public InMemoryMap(Coder<K> keyCoder, Coder<V> valueCoder) {
557+
this.keyCoder = keyCoder;
558+
this.valueCoder = valueCoder;
559+
}
560+
529561
@Override
530562
public void clear() {
531563
contents = new HashMap<>();
@@ -600,9 +632,21 @@ public boolean isCleared() {
600632

601633
@Override
602634
public InMemoryMap<K, V> copy() {
603-
InMemoryMap<K, V> that = new InMemoryMap<>();
635+
InMemoryMap<K, V> that = new InMemoryMap<>(keyCoder, valueCoder);
636+
for (Map.Entry<K, V> entry : this.contents.entrySet()) {
637+
that.contents.put(
638+
unsafeClone(keyCoder, entry.getKey()), unsafeClone(valueCoder, entry.getValue()));
639+
}
604640
that.contents.putAll(this.contents);
605641
return that;
606642
}
607643
}
644+
645+
private static <T> T unsafeClone(Coder<T> coder, T value) {
646+
try {
647+
return CoderUtils.clone(coder, value);
648+
} catch (CoderException e) {
649+
throw new RuntimeException(e);
650+
}
651+
}
608652
}

runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ public <T> ValueState<T> bindValue(
300300
underlying.get().get(namespace, address, c);
301301
return existingState.copy();
302302
} else {
303-
return new InMemoryValue<>();
303+
return new InMemoryValue<>(coder);
304304
}
305305
}
306306

@@ -317,7 +317,7 @@ CombiningState<InputT, AccumT, OutputT> bindCombiningValue(
317317
underlying.get().get(namespace, address, c);
318318
return existingState.copy();
319319
} else {
320-
return new InMemoryCombiningState<>(combineFn);
320+
return new InMemoryCombiningState<>(combineFn, accumCoder);
321321
}
322322
}
323323

@@ -331,7 +331,7 @@ public <T> BagState<T> bindBag(
331331
underlying.get().get(namespace, address, c);
332332
return existingState.copy();
333333
} else {
334-
return new InMemoryBag<>();
334+
return new InMemoryBag<>(elemCoder);
335335
}
336336
}
337337

@@ -345,7 +345,7 @@ public <T> SetState<T> bindSet(
345345
underlying.get().get(namespace, address, c);
346346
return existingState.copy();
347347
} else {
348-
return new InMemorySet<>();
348+
return new InMemorySet<>(elemCoder);
349349
}
350350
}
351351

@@ -361,7 +361,7 @@ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
361361
underlying.get().get(namespace, address, c);
362362
return existingState.copy();
363363
} else {
364-
return new InMemoryMap<>();
364+
return new InMemoryMap<>(mapKeyCoder, mapValueCoder);
365365
}
366366
}
367367

0 commit comments

Comments
 (0)