diff --git a/examples/java/src/main/java/org/apache/beam/examples/complete/game/StatefulTeamScore2.java b/examples/java/src/main/java/org/apache/beam/examples/complete/game/StatefulTeamScore2.java new file mode 100644 index 0000000000000..faf8996dedc77 --- /dev/null +++ b/examples/java/src/main/java/org/apache/beam/examples/complete/game/StatefulTeamScore2.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.examples.complete.game; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects.firstNonNull; + +import java.util.HashMap; +import java.util.Map; +import org.apache.beam.examples.common.ExampleUtils; +import org.apache.beam.examples.complete.game.utils.GameConstants; +import org.apache.beam.examples.complete.game.utils.WriteToBigQuery.FieldInfo; +import org.apache.beam.examples.complete.game.utils.WriteWindowedToBigQuery; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.extensions.gcp.options.GcpOptions; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.ReadModifyWriteState; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; +import org.joda.time.Instant; + +/** + * This class is part of a series of pipelines that tell a story in a gaming domain. Concepts + * include: stateful processing. + * + *

This pipeline processes an unbounded stream of 'game events'. It uses stateful processing to + * aggregate team scores per team and outputs the team name and its total score every time the team + * passes a new multiple of a threshold score. For example, multiples of the threshold could be the + * corresponding scores required to pass each level of the game. By default, this threshold is set + * to 5000. + * + *

Stateful processing allows us to write pipelines that output based on a runtime state (when a + * team reaches a certain score, every 100 game events, etc.) without time triggers. See + * https://beam.apache.org/blog/2017/02/13/stateful-processing.html for more information on using + * stateful processing. + * + *

Run {@code injector.Injector} to generate pubsub data for this pipeline. The Injector + * documentation provides more detail on how to do this. + * + *

To execute this pipeline, specify the pipeline configuration like this: + * + *

{@code
+ * --project=YOUR_PROJECT_ID
+ * --tempLocation=gs://YOUR_TEMP_DIRECTORY
+ * --runner=YOUR_RUNNER
+ * --dataset=YOUR-DATASET
+ * --topic=projects/YOUR-PROJECT/topics/YOUR-TOPIC
+ * }
+ * + *

The BigQuery dataset you specify must already exist. The PubSub topic you specify should be + * the same topic to which the Injector is publishing. + */ +public class StatefulTeamScore2 extends LeaderBoard { + + /** Options supported by {@link StatefulTeamScore}. */ + public interface Options extends LeaderBoard.Options { + + @Description("Numeric value, multiple of which is used as threshold for outputting team score.") + @Default.Integer(5000) + Integer getThresholdScore(); + + void setThresholdScore(Integer value); + } + + /** + * Create a map of information that describes how to write pipeline output to BigQuery. This map + * is used to write team score sums. + */ + private static Map>> configureCompleteWindowedTableWrite() { + + Map>> tableConfigure = + new HashMap<>(); + tableConfigure.put( + "team", new WriteWindowedToBigQuery.FieldInfo<>("STRING", (c, w) -> c.element().getKey())); + tableConfigure.put( + "total_score", + new WriteWindowedToBigQuery.FieldInfo<>("INTEGER", (c, w) -> c.element().getValue())); + tableConfigure.put( + "processing_time", + new WriteWindowedToBigQuery.FieldInfo<>( + "STRING", (c, w) -> GameConstants.DATE_TIME_FORMATTER.print(Instant.now()))); + return tableConfigure; + } + + public static void main(String[] args) throws Exception { + + Options options = PipelineOptionsFactory.fromArgs(args).withValidation().as(Options.class); + // Enforce that this pipeline is always run in streaming mode. + options.setStreaming(true); + ExampleUtils exampleUtils = new ExampleUtils(options); + Pipeline pipeline = Pipeline.create(options); + + pipeline + // Read game events from Pub/Sub using custom timestamps, which are extracted from the + // pubsub data elements, and parse the data. + .apply( + PubsubIO.readStrings() + .withTimestampAttribute(GameConstants.TIMESTAMP_ATTRIBUTE) + .fromTopic(options.getTopic())) + .apply("ParseGameEvent", ParDo.of(new ParseEventFn())) + // Create mapping. UpdateTeamScore uses team name as key. 
+ .apply( + "MapTeamAsKey", + MapElements.into( + TypeDescriptors.kvs( + TypeDescriptors.strings(), TypeDescriptor.of(GameActionInfo.class))) + .via((GameActionInfo gInfo) -> KV.of(gInfo.team, gInfo))) + // Outputs a team's score every time it passes a new multiple of the threshold. + .apply("UpdateTeamScore", ParDo.of(new UpdateTeamScoreFn(options.getThresholdScore()))) + // Write the results to BigQuery. + .apply( + "WriteTeamLeaders", + new WriteWindowedToBigQuery<>( + options.as(GcpOptions.class).getProject(), + options.getDataset(), + options.getLeaderBoardTableName() + "_team_leader", + configureCompleteWindowedTableWrite())); + + // Run the pipeline and wait for the pipeline to finish; capture cancellation requests from the + // command line. + PipelineResult result = pipeline.run(); + exampleUtils.waitToFinish(result); + } + + /** + * Tracks each team's score separately in a single state cell and outputs the score every time it + * passes a new multiple of a threshold. + * + *

We use stateful {@link DoFn} because: + * + *

+ */ + @VisibleForTesting + public static class UpdateTeamScoreFn + extends DoFn, KV> { + + private static final String TOTAL_SCORE = "totalScore"; + private final int thresholdScore; + + public UpdateTeamScoreFn(int thresholdScore) { + this.thresholdScore = thresholdScore; + } + + /** + * Describes the state for storing team score. Let's break down this statement. + * + *

{@link StateSpec} configures the state cell, which is provided by a runner during pipeline + * execution. + * + *

{@link org.apache.beam.sdk.transforms.DoFn.StateId} annotation assigns an identifier to + * the state, which is used to refer to the state in {@link + * org.apache.beam.sdk.transforms.DoFn.ProcessElement}. + * + *

A {@link ReadModifyWriteState} stores a single value per key and per window. Because our pipeline is + * globally windowed in this example, this {@link ReadModifyWriteState} is just key-partitioned, with one + * score per team. Any other class that extends {@link org.apache.beam.sdk.state.State} can be + * used. + * + *

In order to store the value, the state must be encoded. Therefore, we provide a coder, in + * this case the {@link VarIntCoder}. If the coder is not provided as in {@code + * StateSpecs.value()}, Beam's coder inference will try to provide a coder automatically. + */ + @StateId(TOTAL_SCORE) + private final StateSpec> totalScoreSpec = + StateSpecs.readModifyWrite(VarIntCoder.of()); + + /** + * To use a state cell, annotate a parameter with {@link + * org.apache.beam.sdk.transforms.DoFn.StateId} that matches the state declaration. The type of + * the parameter should match the {@link StateSpec} type. + */ + @ProcessElement + public void processElement( + ProcessContext c, @StateId(TOTAL_SCORE) ReadModifyWriteState totalScore) { + String teamName = c.element().getKey(); + GameActionInfo gInfo = c.element().getValue(); + + // ReadModifyWriteState cells do not contain a default value. If the state is possibly not written, make + // sure to check for null on read. + int oldTotalScore = firstNonNull(totalScore.read(), 0); + totalScore.write(oldTotalScore + gInfo.score); + + // Since there are no negative scores, the easiest way to check whether a team just passed a + // new multiple of the threshold score is to compare the quotients of dividing total scores by + // threshold before and after this aggregation. For example, if the total score was 1999, + // the new total is 2002, and the threshold is 1000, 1999 / 1000 = 1, 2002 / 1000 = 2. + // Therefore, this team passed the threshold. 
+ if (oldTotalScore / this.thresholdScore < totalScore.read() / this.thresholdScore) { + c.output(KV.of(teamName, totalScore.read())); + } + } + } +} diff --git a/examples/java/src/test/java/org/apache/beam/examples/complete/game/StatefulTeamScoreTest2.java b/examples/java/src/test/java/org/apache/beam/examples/complete/game/StatefulTeamScoreTest2.java new file mode 100644 index 0000000000000..078a4bddd8add --- /dev/null +++ b/examples/java/src/test/java/org/apache/beam/examples/complete/game/StatefulTeamScoreTest2.java @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.examples.complete.game; + +import org.apache.beam.examples.complete.game.StatefulTeamScore2.UpdateTeamScoreFn; +import org.apache.beam.examples.complete.game.UserScore.GameActionInfo; +import org.apache.beam.sdk.coders.AvroCoder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.GlobalWindow; +import org.apache.beam.sdk.transforms.windowing.IntervalWindow; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TimestampedValue; +import org.joda.time.Duration; +import org.joda.time.Instant; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link StatefulTeamScore2}. */ +@RunWith(JUnit4.class) +public class StatefulTeamScoreTest2 { + + private Instant baseTime = new Instant(0); + + @Rule public TestPipeline p = TestPipeline.create(); + + /** Some example users, on two separate teams. */ + private enum TestUser { + RED_ONE("scarlet", "red"), + RED_TWO("burgundy", "red"), + BLUE_ONE("navy", "blue"), + BLUE_TWO("sky", "blue"); + + private final String userName; + private final String teamName; + + TestUser(String userName, String teamName) { + this.userName = userName; + this.teamName = teamName; + } + + public String getUser() { + return userName; + } + + public String getTeam() { + return teamName; + } + } + + /** + * Tests that {@link UpdateTeamScoreFn} {@link org.apache.beam.sdk.transforms.DoFn} outputs + * correctly for one team. 
+ */ + @Test + public void testScoreUpdatesOneTeam() { + + TestStream> createEvents = + TestStream.create(KvCoder.of(StringUtf8Coder.of(), AvroCoder.of(GameActionInfo.class))) + .advanceWatermarkTo(baseTime) + .addElements( + event(TestUser.RED_TWO, 99, Duration.standardSeconds(10)), + event(TestUser.RED_ONE, 1, Duration.standardSeconds(20)), + event(TestUser.RED_ONE, 0, Duration.standardSeconds(30)), + event(TestUser.RED_TWO, 100, Duration.standardSeconds(40)), + event(TestUser.RED_TWO, 201, Duration.standardSeconds(50))) + .advanceWatermarkToInfinity(); + + PCollection> teamScores = + p.apply(createEvents).apply(ParDo.of(new UpdateTeamScoreFn(100))); + + String redTeam = TestUser.RED_ONE.getTeam(); + + PAssert.that(teamScores) + .inWindow(GlobalWindow.INSTANCE) + .containsInAnyOrder(KV.of(redTeam, 100), KV.of(redTeam, 200), KV.of(redTeam, 401)); + + p.run().waitUntilFinish(); + } + + /** + * Tests that {@link UpdateTeamScoreFn} {@link org.apache.beam.sdk.transforms.DoFn} outputs + * correctly for multiple teams. 
+ */ + @Test + public void testScoreUpdatesPerTeam() { + + TestStream> createEvents = + TestStream.create(KvCoder.of(StringUtf8Coder.of(), AvroCoder.of(GameActionInfo.class))) + .advanceWatermarkTo(baseTime) + .addElements( + event(TestUser.RED_ONE, 50, Duration.standardSeconds(10)), + event(TestUser.RED_TWO, 50, Duration.standardSeconds(20)), + event(TestUser.BLUE_ONE, 70, Duration.standardSeconds(30)), + event(TestUser.BLUE_TWO, 80, Duration.standardSeconds(40)), + event(TestUser.BLUE_TWO, 50, Duration.standardSeconds(50))) + .advanceWatermarkToInfinity(); + + PCollection> teamScores = + p.apply(createEvents).apply(ParDo.of(new UpdateTeamScoreFn(100))); + + String redTeam = TestUser.RED_ONE.getTeam(); + String blueTeam = TestUser.BLUE_ONE.getTeam(); + + PAssert.that(teamScores) + .inWindow(GlobalWindow.INSTANCE) + .containsInAnyOrder(KV.of(redTeam, 100), KV.of(blueTeam, 150), KV.of(blueTeam, 200)); + + p.run().waitUntilFinish(); + } + + /** + * Tests that {@link UpdateTeamScoreFn} {@link org.apache.beam.sdk.transforms.DoFn} outputs + * correctly per window and per key. 
+ */ + @Test + public void testScoreUpdatesPerWindow() { + + TestStream> createEvents = + TestStream.create(KvCoder.of(StringUtf8Coder.of(), AvroCoder.of(GameActionInfo.class))) + .advanceWatermarkTo(baseTime) + .addElements( + event(TestUser.RED_ONE, 50, Duration.standardMinutes(1)), + event(TestUser.RED_TWO, 50, Duration.standardMinutes(2)), + event(TestUser.RED_ONE, 50, Duration.standardMinutes(3)), + event(TestUser.RED_ONE, 60, Duration.standardMinutes(6)), + event(TestUser.RED_TWO, 60, Duration.standardMinutes(7))) + .advanceWatermarkToInfinity(); + + Duration teamWindowDuration = Duration.standardMinutes(5); + + PCollection> teamScores = + p.apply(createEvents) + .apply(Window.>into(FixedWindows.of(teamWindowDuration))) + .apply(ParDo.of(new UpdateTeamScoreFn(100))); + + String redTeam = TestUser.RED_ONE.getTeam(); + String blueTeam = TestUser.BLUE_ONE.getTeam(); + + IntervalWindow window1 = new IntervalWindow(baseTime, teamWindowDuration); + IntervalWindow window2 = new IntervalWindow(window1.end(), teamWindowDuration); + + PAssert.that(teamScores).inWindow(window1).containsInAnyOrder(KV.of(redTeam, 100)); + + PAssert.that(teamScores).inWindow(window2).containsInAnyOrder(KV.of(redTeam, 120)); + + p.run().waitUntilFinish(); + } + + private TimestampedValue> event( + TestUser user, int score, Duration baseTimeOffset) { + return TimestampedValue.of( + KV.of( + user.getTeam(), + new GameActionInfo( + user.getUser(), user.getTeam(), score, baseTime.plus(baseTimeOffset).getMillis())), + baseTime.plus(baseTimeOffset)); + } +} diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto index 204c408ba1a4c..fa5adec3dc419 100644 --- a/model/pipeline/src/main/proto/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/beam_runner_api.proto @@ -411,7 +411,7 @@ message Parameter { message StateSpec { oneof spec { - ValueStateSpec value_spec = 1; + ReadModifyWriteStateSpec read_modify_write_spec = 1; 
BagStateSpec bag_spec = 2; CombiningStateSpec combining_spec = 3; MapStateSpec map_spec = 4; @@ -419,7 +419,7 @@ message StateSpec { } } -message ValueStateSpec { +message ReadModifyWriteStateSpec { string coder_id = 1; } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 067a651e0b509..6b0648f71e1dc 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -422,14 +422,26 @@ public static RunnerApi.StateSpec translateStateSpec( return stateSpec.match( new StateSpec.Cases() { @Override - public RunnerApi.StateSpec dispatchValue(Coder valueCoder) { + public RunnerApi.StateSpec dispatchReadModifyWrite(Coder valueCoder) { return builder - .setValueSpec( - RunnerApi.ValueStateSpec.newBuilder() + .setReadModifyWriteSpec( + RunnerApi.ReadModifyWriteStateSpec.newBuilder() .setCoderId(registerCoderOrThrow(components, valueCoder))) .build(); } + @Override + public RunnerApi.StateSpec dispatchValue(Coder valueCoder) { + /* We are keeping this method for backward compatibility but we are using + ReadModifyWriteState under the hood. 
+ */ + return builder + .setReadModifyWriteSpec( + RunnerApi.ReadModifyWriteStateSpec.newBuilder() + .setCoderId(registerCoderOrThrow(components, valueCoder))) + .build(); + } + @Override public RunnerApi.StateSpec dispatchBag(Coder elementCoder) { return builder @@ -475,8 +487,8 @@ public RunnerApi.StateSpec dispatchSet(Coder elementCoder) { static StateSpec fromProto(RunnerApi.StateSpec stateSpec, RehydratedComponents components) throws IOException { switch (stateSpec.getSpecCase()) { - case VALUE_SPEC: - return StateSpecs.value(components.getCoder(stateSpec.getValueSpec().getCoderId())); + case READ_MODIFY_WRITE_SPEC: + return StateSpecs.readModifyWrite(components.getCoder(stateSpec.getReadModifyWriteSpec().getCoderId())); case BAG_SPEC: return StateSpecs.bag(components.getCoder(stateSpec.getBagSpec().getElementCoderId())); case COMBINING_SPEC: diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index 6ea3bde27b2ca..5614b94c1d282 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -210,7 +210,7 @@ public static class TestStateAndTimerTranslation { @Parameters(name = "{index}: {0}") public static Iterable> stateSpecs() { return ImmutableList.of( - StateSpecs.value(VarIntCoder.of()), + StateSpecs.readModifyWrite(VarIntCoder.of()), StateSpecs.bag(VarIntCoder.of()), StateSpecs.set(VarIntCoder.of()), StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of())); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java index 9628cff4b63dc..a0c499695bda4 100644 
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; @@ -124,6 +125,11 @@ public ValueState bindValue(StateTag> address, Coder cod return new InMemoryValue<>(coder); } + @Override + public ReadModifyWriteState bindReadModifyWrite(StateTag> address, Coder coder) { + return new InMemoryReadModifyWrite<>(coder); + } + @Override public BagState bindBag(final StateTag> address, Coder elemCoder) { return new InMemoryBag<>(elemCoder); @@ -166,9 +172,62 @@ CombiningState bindCombiningValueWithContext( } } - /** An {@link InMemoryState} implementation of {@link ValueState}. */ + + /** An {@link InMemoryState} implementation of {@link ReadModifyWriteState}. */ + public static final class InMemoryReadModifyWrite + implements ReadModifyWriteState, InMemoryState> { + private final Coder coder; + + private boolean isCleared = true; + private @Nullable T value = null; + + public InMemoryReadModifyWrite(Coder coder) { + this.coder = coder; + } + + @Override + public void clear() { + // Even though we're clearing we can't remove this from the in-memory state map, since + // other users may already have a handle on this Value. 
+ value = null; + isCleared = true; + } + + @Override + public InMemoryReadModifyWrite readLater() { + return this; + } + + @Override + public T read() { + return value; + } + + @Override + public void write(T input) { + isCleared = false; + this.value = input; + } + + @Override + public InMemoryReadModifyWrite copy() { + InMemoryReadModifyWrite that = new InMemoryReadModifyWrite<>(coder); + if (!this.isCleared) { + that.isCleared = this.isCleared; + that.value = uncheckedClone(coder, this.value); + } + return that; + } + + @Override + public boolean isCleared() { + return isCleared; + } + } + + /** An {@link InMemoryState} implementation of {@link ReadModifyWriteState}. */ public static final class InMemoryValue - implements ValueState, InMemoryState> { + implements ValueState, InMemoryState> { private final Coder coder; private boolean isCleared = true; diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/MergingActiveWindowSet.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/MergingActiveWindowSet.java index 9d2d8478b8a81..e338f6f7408cd 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/MergingActiveWindowSet.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/MergingActiveWindowSet.java @@ -32,7 +32,7 @@ import javax.annotation.Nullable; import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.coders.SetCoder; -import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.WindowFn; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; @@ -60,19 +60,19 @@ public class MergingActiveWindowSet implements ActiveWi private final Map> originalActiveWindowToStateAddressWindows; /** Handle representing our state in the backend. 
*/ - private final ValueState>> valueState; + private final ReadModifyWriteState>> readModifyWriteState; public MergingActiveWindowSet(WindowFn windowFn, StateInternals state) { this.windowFn = windowFn; - StateTag>>> tag = + StateTag>>> tag = StateTags.makeSystemTagInternal( - StateTags.value( + StateTags.readModifyWrite( "tree", MapCoder.of(windowFn.windowCoder(), SetCoder.of(windowFn.windowCoder())))); - valueState = state.state(StateNamespaces.global(), tag); + readModifyWriteState = state.state(StateNamespaces.global(), tag); // Little use trying to prefetch this state since the ReduceFnRunner // is stymied until it is available. - activeWindowToStateAddressWindows = emptyIfNull(valueState.read()); + activeWindowToStateAddressWindows = emptyIfNull(readModifyWriteState.read()); originalActiveWindowToStateAddressWindows = deepCopy(activeWindowToStateAddressWindows); } @@ -88,14 +88,14 @@ public void persist() { checkInvariants(); if (activeWindowToStateAddressWindows.isEmpty()) { // Force all persistent state to disappear. - valueState.clear(); + readModifyWriteState.clear(); return; } if (activeWindowToStateAddressWindows.equals(originalActiveWindowToStateAddressWindows)) { // No change. return; } - valueState.write(activeWindowToStateAddressWindows); + readModifyWriteState.write(activeWindowToStateAddressWindows); // No need to update originalActiveWindowToStateAddressWindows since this object is about to // become garbage. 
} diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/PaneInfoTracker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/PaneInfoTracker.java index 1537ad54ebd75..9c6c940a90cb5 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/PaneInfoTracker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/PaneInfoTracker.java @@ -20,7 +20,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import org.apache.beam.sdk.state.ReadableState; -import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.transforms.windowing.AfterWatermark; import org.apache.beam.sdk.transforms.windowing.PaneInfo; import org.apache.beam.sdk.transforms.windowing.PaneInfo.PaneInfoCoder; @@ -43,8 +43,8 @@ public PaneInfoTracker(TimerInternals timerInternals) { } @VisibleForTesting - static final StateTag> PANE_INFO_TAG = - StateTags.makeSystemTagInternal(StateTags.value("pane", PaneInfoCoder.INSTANCE)); + static final StateTag> PANE_INFO_TAG = + StateTags.makeSystemTagInternal(StateTags.readModifyWrite("pane", PaneInfoCoder.INSTANCE)); public void clear(StateAccessor state) { state.access(PANE_INFO_TAG).clear(); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputHandler.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputHandler.java index 41f72a79039ff..bbf6b7d9d2f54 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputHandler.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SideInputHandler.java @@ -31,7 +31,7 @@ import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.SetCoder; import org.apache.beam.sdk.state.CombiningState; -import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.transforms.Combine; 
import org.apache.beam.sdk.transforms.Materializations; import org.apache.beam.sdk.transforms.Materializations.MultimapView; @@ -78,7 +78,7 @@ public class SideInputHandler implements ReadyCheckingSideInputReader { availableWindowsTags; /** State tag for the actual contents of each side input per window. */ - private final Map, StateTag>>> sideInputContentsTags; + private final Map, StateTag>>> sideInputContentsTags; /** * Creates a new {@code SideInputHandler} for the given side inputs that uses the given {@code @@ -114,7 +114,7 @@ public SideInputHandler( availableWindowsTags.put(sideInput, availableTag); - StateTag>> stateTag = + StateTag>> stateTag = StateTags.value( "side-input-data-" + sideInput.getTagInternal().getId(), (Coder) IterableCoder.of(sideInput.getCoderInternal())); @@ -131,7 +131,7 @@ public void addSideInputValue(PCollectionView sideInput, WindowedValue windowCoder = (Coder) sideInput.getWindowingStrategyInternal().getWindowFn().windowCoder(); - StateTag>> stateTag = sideInputContentsTags.get(sideInput); + StateTag>> stateTag = sideInputContentsTags.get(sideInput); for (BoundedWindow window : value.getWindows()) { stateInternals @@ -172,9 +172,9 @@ public Iterable getIterable(PCollectionView view, BoundedWindow window Coder windowCoder = (Coder) view.getWindowingStrategyInternal().getWindowFn().windowCoder(); - StateTag>> stateTag = sideInputContentsTags.get(view); + StateTag>> stateTag = sideInputContentsTags.get(view); - ValueState> state = + ReadModifyWriteState> state = stateInternals.state(StateNamespaces.window(windowCoder, window), stateTag); Iterable elements = state.read(); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java index ec73e7b766200..be9093c9f354f 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java +++ 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/SplittableParDoViaKeyedWorkItems.java @@ -33,7 +33,7 @@ import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.state.TimeDomain; -import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.GroupByKey; @@ -236,13 +236,13 @@ public static class ProcessFn * DoFn.ProcessElement} call and read during subsequent calls in response to timer firings, when * the original element is no longer available. */ - private final StateTag>> elementTag; + private final StateTag>> elementTag; /** * The state cell containing a restriction representing the unprocessed part of work for this * element. */ - private StateTag> restrictionTag; + private StateTag> restrictionTag; private final DoFn fn; private final Coder elementCoder; @@ -267,11 +267,11 @@ public ProcessFn( this.restrictionCoder = restrictionCoder; this.inputWindowingStrategy = inputWindowingStrategy; this.elementTag = - StateTags.value( + StateTags.readModifyWrite( "element", WindowedValue.getFullCoder( elementCoder, inputWindowingStrategy.getWindowFn().windowCoder())); - this.restrictionTag = StateTags.value("restriction", restrictionCoder); + this.restrictionTag = StateTags.readModifyWrite("restriction", restrictionCoder); } public void setStateInternalsFactory(StateInternalsFactory stateInternalsFactory) { @@ -354,9 +354,9 @@ public void processElement(final ProcessContext c) { stateNamespace = timer.getNamespace(); } - ValueState> elementState = + ReadModifyWriteState> elementState = stateInternals.state(stateNamespace, elementTag); - ValueState restrictionState = + ReadModifyWriteState restrictionState = stateInternals.state(stateNamespace, restrictionTag); WatermarkHoldState holdState = 
stateInternals.state(stateNamespace, watermarkHoldTag); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java index 021f8598a395f..7573ed07fa9ea 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java @@ -28,6 +28,7 @@ import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -77,6 +78,8 @@ public interface StateTag extends Serializable { public interface StateBinder { ValueState bindValue(StateTag> spec, Coder coder); + ReadModifyWriteState bindReadModifyWrite(StateTag> spec, Coder coder); + BagState bindBag(StateTag> spec, Coder elemCoder); SetState bindSet(StateTag> spec, Coder elemCoder); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java index 8dd84c17f9584..0d27866fe1027 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java @@ -33,6 +33,7 @@ import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.StateSpecs; import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; @@ -68,6 +69,11 @@ private static StateBinder adaptTagBinder(final StateTag.StateBinder binder) { public ValueState 
bindValue(String id, StateSpec> spec, Coder coder) { return binder.bindValue(tagForSpec(id, spec), coder); } + + @Override + public ReadModifyWriteState bindReadModifyWrite(String id, StateSpec> spec, Coder coder) { + return binder.bindReadModifyWrite(tagForSpec(id, spec), coder); + } @Override public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { @@ -138,6 +144,12 @@ public static StateTag tagForSpec( return new SimpleStateTag<>(new StructuredId(id), spec); } + /** Create a simple state tag for values of type {@code T}. */ + public static StateTag> readModifyWrite(String id, Coder valueCoder) { + return new SimpleStateTag<>(new StructuredId(id), StateSpecs.readModifyWrite(valueCoder)); + } + + /** Create a simple state tag for values of type {@code T}. */ public static StateTag> value(String id, Coder valueCoder) { return new SimpleStateTag<>(new StructuredId(id), StateSpecs.value(valueCoder)); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachineRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachineRunner.java index ac642ceae7c0f..3776dc99667da 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachineRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/triggers/TriggerStateMachineRunner.java @@ -29,7 +29,7 @@ import org.apache.beam.runners.core.StateTags; import org.apache.beam.sdk.coders.BitSetCoder; import org.apache.beam.sdk.state.Timers; -import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -58,8 +58,8 @@ */ public class TriggerStateMachineRunner { @VisibleForTesting - public static 
final StateTag> FINISHED_BITS_TAG = - StateTags.makeSystemTagInternal(StateTags.value("closed", BitSetCoder.of())); + public static final StateTag> FINISHED_BITS_TAG = + StateTags.makeSystemTagInternal(StateTags.readModifyWrite("closed", BitSetCoder.of())); private final ExecutableTriggerStateMachine rootTrigger; private final TriggerStateMachineContextFactory contextFactory; @@ -72,7 +72,7 @@ public TriggerStateMachineRunner( this.contextFactory = contextFactory; } - private FinishedTriggersBitSet readFinishedBits(ValueState state) { + private FinishedTriggersBitSet readFinishedBits(ReadModifyWriteState state) { if (!isFinishedSetNeeded()) { // If no trigger in the tree will ever have finished bits, then we don't need to read them. // So that the code can be agnostic to that fact, we create a BitSet that is all 0 (not @@ -86,7 +86,7 @@ private FinishedTriggersBitSet readFinishedBits(ValueState state) { : FinishedTriggersBitSet.fromBitSet(bitSet); } - private void clearFinishedBits(ValueState state) { + private void clearFinishedBits(ReadModifyWriteState state) { if (!isFinishedSetNeeded()) { // Nothing to clear. return; @@ -138,7 +138,7 @@ public void processValue(W window, Instant timestamp, Timers timers, StateAccess public void prefetchForMerge( W window, Collection mergingWindows, MergingStateAccessor state) { if (isFinishedSetNeeded()) { - for (ValueState value : state.accessInEachMergingWindow(FINISHED_BITS_TAG).values()) { + for (ReadModifyWriteState value : state.accessInEachMergingWindow(FINISHED_BITS_TAG).values()) { value.readLater(); } } @@ -155,7 +155,7 @@ public void onMerge(W window, Timers timers, MergingStateAccessor state) t // And read the finished bits in each merging window. 
ImmutableMap.Builder builder = ImmutableMap.builder(); - for (Map.Entry> entry : + for (Map.Entry> entry : state.accessInEachMergingWindow(FINISHED_BITS_TAG).entrySet()) { // Don't need to clone these, since the trigger context doesn't allow modification builder.put(entry.getKey(), readFinishedBits(entry.getValue())); @@ -197,7 +197,7 @@ private void persistFinishedSet( return; } - ValueState finishedSetState = state.access(FINISHED_BITS_TAG); + ReadModifyWriteState finishedSetState = state.access(FINISHED_BITS_TAG); if (!readFinishedBits(finishedSetState).equals(modifiedFinishedSet)) { if (modifiedFinishedSet.getBitSet().isEmpty()) { finishedSetState.clear(); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java index 0a64a4b8ed6cc..513e8ba7658ba 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.state.StateContext; import org.apache.beam.sdk.state.StateContexts; import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; @@ -278,6 +279,19 @@ public WatermarkHoldState bindWatermark( } } + @Override + public ReadModifyWriteState bindReadModifyWrite(StateTag> address, Coder coder) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState> existingState = + (InMemoryState>) + underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemoryValue<>(coder); + } + } + 
@Override public ValueState bindValue(StateTag> address, Coder coder) { if (containedInUnderlying(namespace, address)) { @@ -406,6 +420,11 @@ public WatermarkHoldState bindWatermark( return underlying.get(namespace, address, c); } + @Override + public ReadModifyWriteState bindReadModifyWrite(StateTag> address, Coder coder) { + return underlying.get(namespace, address, c); + } + @Override public ValueState bindValue(StateTag> address, Coder coder) { return underlying.get(namespace, address, c); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java index 69ff465a333b4..e9e32ddc1bdd6 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineWithContext; @@ -93,6 +94,12 @@ public ValueState bindValue(StateTag> address, Coder return new FlinkBroadcastValueState<>(stateBackend, address, namespace, coder); } + @Override + public ReadModifyWriteState bindReadModifyWrite(StateTag> address, Coder coder) { + + return new FlinkBroadcastReadModifyWriteState<>(stateBackend, address, namespace, coder); + } + @Override public BagState bindBag(StateTag> address, Coder elemCoder) { @@ -273,14 +280,73 @@ void clearInternal() { } private class FlinkBroadcastValueState extends AbstractBroadcastState - implements 
ValueState { + implements ValueState { private final StateNamespace namespace; private final StateTag> address; FlinkBroadcastValueState( + OperatorStateBackend flinkStateBackend, + StateTag> address, + StateNamespace namespace, + Coder coder) { + super(flinkStateBackend, address.getId(), namespace, coder); + + this.namespace = namespace; + this.address = address; + } + + @Override + public void write(T input) { + writeInternal(input); + } + + @Override + public ValueState readLater() { + return this; + } + + @Override + public T read() { + return readInternal(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkBroadcastValueState that = (FlinkBroadcastValueState) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + + @Override + public void clear() { + clearInternal(); + } + } + + private class FlinkBroadcastReadModifyWriteState extends AbstractBroadcastState + implements ReadModifyWriteState { + + private final StateNamespace namespace; + private final StateTag> address; + + FlinkBroadcastReadModifyWriteState( OperatorStateBackend flinkStateBackend, - StateTag> address, + StateTag> address, StateNamespace namespace, Coder coder) { super(flinkStateBackend, address.getId(), namespace, coder); @@ -295,7 +361,7 @@ public void write(T input) { } @Override - public ValueState readLater() { + public ReadModifyWriteState readLater() { return this; } @@ -313,7 +379,7 @@ public boolean equals(Object o) { return false; } - FlinkBroadcastValueState that = (FlinkBroadcastValueState) o; + FlinkBroadcastReadModifyWriteState that = (FlinkBroadcastReadModifyWriteState) o; return namespace.equals(that.namespace) && address.equals(that.address); } diff --git 
a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 2eb450852925f..874f291e801bb 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -43,6 +43,7 @@ import org.apache.beam.sdk.state.StateContext; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineWithContext; @@ -153,10 +154,16 @@ private FlinkStateBinder( @Override public ValueState bindValue( - String id, StateSpec> spec, Coder coder) { + String id, StateSpec> spec, Coder coder) { return new FlinkValueState<>(flinkStateBackend, id, namespace, coder); } + @Override + public ReadModifyWriteState bindReadModifyWrite( + String id, StateSpec> spec, Coder coder) { + return new FlinkReadModifyWriteState<>(flinkStateBackend, id, namespace, coder); + } + @Override public BagState bindBag(String id, StateSpec> spec, Coder elemCoder) { return new FlinkBagState<>(flinkStateBackend, id, namespace, elemCoder); @@ -222,6 +229,89 @@ private static class FlinkValueState implements ValueState { private final KeyedStateBackend flinkStateBackend; FlinkValueState( + KeyedStateBackend flinkStateBackend, + String stateId, + StateNamespace namespace, + Coder coder) { + + this.namespace = namespace; + this.stateId = stateId; + this.flinkStateBackend = flinkStateBackend; + + flinkStateDescriptor = new ValueStateDescriptor<>(stateId, new CoderTypeSerializer<>(coder)); + } + + @Override + public void 
write(T input) { + try { + flinkStateBackend + .getPartitionedState( + namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .update(input); + } catch (Exception e) { + throw new RuntimeException("Error updating state.", e); + } + } + + @Override + public ValueState readLater() { + return this; + } + + @Override + public T read() { + try { + return flinkStateBackend + .getPartitionedState( + namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .value(); + } catch (Exception e) { + throw new RuntimeException("Error reading state.", e); + } + } + + @Override + public void clear() { + try { + flinkStateBackend + .getPartitionedState( + namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor) + .clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkValueState that = (FlinkValueState) o; + + return namespace.equals(that.namespace) && stateId.equals(that.stateId); + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + stateId.hashCode(); + return result; + } + } + + private static class FlinkReadModifyWriteState implements ReadModifyWriteState { + + private final StateNamespace namespace; + private final String stateId; + private final ValueStateDescriptor flinkStateDescriptor; + private final KeyedStateBackend flinkStateBackend; + + FlinkReadModifyWriteState( KeyedStateBackend flinkStateBackend, String stateId, StateNamespace namespace, @@ -247,7 +337,7 @@ public void write(T input) { } @Override - public ValueState readLater() { + public ReadModifyWriteState readLater() { return this; } @@ -284,7 +374,7 @@ public boolean equals(Object o) { return false; } - FlinkValueState that = (FlinkValueState) o; + FlinkReadModifyWriteState that = 
(FlinkReadModifyWriteState) o; return namespace.equals(that.namespace) && stateId.equals(that.stateId); } @@ -1246,6 +1336,19 @@ public EarlyBinder(KeyedStateBackend keyedStateBackend) { @Override public ValueState bindValue(String id, StateSpec> spec, Coder coder) { + try { + keyedStateBackend.getOrCreateKeyedState( + StringSerializer.INSTANCE, + new ValueStateDescriptor<>(id, new CoderTypeSerializer<>(coder))); + } catch (Exception e) { + throw new RuntimeException(e); + } + + return null; + } + + @Override + public ReadModifyWriteState bindReadModifyWrite(String id, StateSpec> spec, Coder coder) { try { keyedStateBackend.getOrCreateKeyedState( StringSerializer.INSTANCE, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java index 7c5babf9ad866..4237307dc0594 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputFetcher.java @@ -41,6 +41,7 @@ import org.apache.beam.sdk.coders.SetCoder; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.WindowFn; @@ -61,7 +62,7 @@ public class StreamingSideInputFetcher { private final StateTag>> elementsAddr; private final StateTag> timersAddr; private final StateTag watermarkHoldingAddr; - private final StateTag>>> blockedMapAddr; + private final StateTag>>> blockedMapAddr; private Map> blockedMap = null; // lazily initialized @@ -95,9 +96,9 @@ public 
StreamingSideInputFetcher( @VisibleForTesting static - StateTag>>> blockedMapAddr( + StateTag>>> blockedMapAddr( Coder mainWindowCoder) { - return StateTags.value( + return StateTags.readModifyWrite( "blockedMap", MapCoder.of(mainWindowCoder, SetCoder.of(GlobalDataRequestCoder.of()))); } @@ -226,7 +227,7 @@ public void persist() { return; } - ValueState>> mapState = + ReadModifyWriteState>> mapState = stepContext.stateInternals().state(StateNamespaces.global(), blockedMapAddr); if (blockedMap.isEmpty()) { // Avoid storing the empty map so we don't leave unnecessary state behind from processing diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java index 9c9779e8176c5..dbe77841a932c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java @@ -48,6 +48,7 @@ import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; import org.apache.beam.sdk.state.StateContexts; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -176,6 +177,13 @@ public ValueState bindValue(StateTag> address, Coder cod result.initializeForWorkItem(reader, scopedReadStateSupplier); return result; } + + @Override + public ReadModifyWriteState bindReadModifyWrite(StateTag> address, Coder coder) { + // TODO: need to implement this in generic way which can be used. 
+ throw new UnsupportedOperationException( + String.format("%s is not supported", ReadModifyWriteState.class.getSimpleName())); + } }; } } diff --git a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java index 38a885f2a171b..817f6cb44595e 100644 --- a/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java +++ b/runners/java-fn-execution/src/main/java/org/apache/beam/runners/fnexecution/splittabledofn/SDFFeederViaStateAndTimers.java @@ -35,7 +35,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.state.TimeDomain; -import org.apache.beam.sdk.state.ValueState; +import org.apache.beam.sdk.state.ReadModifyWriteState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; @@ -64,11 +64,11 @@ public class SDFFeederViaStateAndTimers { private StateNamespace stateNamespace; - private final StateTag>>> seedTag; - private ValueState>> seedState; + private final StateTag>>> seedTag; + private ReadModifyWriteState>> seedState; - private final StateTag> restrictionTag; - private ValueState restrictionState; + private final StateTag> restrictionTag; + private ReadModifyWriteState restrictionState; private StateTag watermarkHoldTag = StateTags.makeSystemTagInternal( @@ -91,8 +91,8 @@ public SDFFeederViaStateAndTimers( this.windowCoder = windowCoder; this.elementRestrictionWireCoder = FullWindowedValueCoder.of(KvCoder.of(elementWireCoder, restrictionWireCoder), windowCoder); - this.seedTag = StateTags.value("seed", elementRestrictionWireCoder); - this.restrictionTag = StateTags.value("restriction", restrictionWireCoder); + this.seedTag = 
StateTags.readModifyWrite("seed", elementRestrictionWireCoder); + this.restrictionTag = StateTags.readModifyWrite("restriction", restrictionWireCoder); } /** Passes the initial element/restriction pair. */ diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/ReadModifyWriteState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/ReadModifyWriteState.java new file mode 100644 index 0000000000000..f749ace082505 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/ReadModifyWriteState.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.state; + +import org.apache.beam.sdk.annotations.Experimental; +import org.apache.beam.sdk.annotations.Experimental.Kind; + +/** + * A {@link ReadableState} cell containing a single value. + * + * @param The type of value being stored. + */ +@Experimental(Kind.STATE) +public interface ReadModifyWriteState extends ReadableState, State { + /** Set the value. 
*/ + void write(T input); + + @Override + ReadModifyWriteState readLater(); +} diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/State.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/State.java index a52319d14e2d9..00c3b3530251b 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/State.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/State.java @@ -24,7 +24,7 @@ * A state cell, supporting a {@link #clear()} operation. * *

Specific types of state add appropriate accessors for reading and writing values, see {@link - * ValueState}, {@link BagState}, and {@link GroupingState}. + * ReadModifyWriteState}, {@link BagState}, and {@link GroupingState}. */ @Experimental(Kind.STATE) public interface State { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateBinder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateBinder.java index e9f37fe66ff2b..3d98273517283 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateBinder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateBinder.java @@ -11,7 +11,7 @@ * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ @@ -32,6 +32,8 @@ public interface StateBinder { ValueState bindValue(String id, StateSpec> spec, Coder coder); + ReadModifyWriteState bindReadModifyWrite(String id, StateSpec> spec, Coder coder); + BagState bindBag(String id, StateSpec> spec, Coder elemCoder); SetState bindSet(String id, StateSpec> spec, Coder elemCoder); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java index 4cbb834898fcd..994eb920b4aa8 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpec.java @@ -72,8 +72,11 @@ public interface StateSpec extends Serializable { /** Cases for doing a "switch" on the type of {@link StateSpec}. 
*/ interface Cases { + ResultT dispatchValue(Coder valueCoder); + ResultT dispatchReadModifyWrite(Coder valueCoder); + ResultT dispatchBag(Coder elementCoder); ResultT dispatchCombining(Combine.CombineFn combineFn, Coder accumCoder); @@ -92,6 +95,11 @@ public ResultT dispatchValue(Coder valueCoder) { return dispatchDefault(); } + @Override + public ResultT dispatchReadModifyWrite(Coder valueCoder) { + return dispatchDefault(); + } + @Override public ResultT dispatchBag(Coder elementCoder) { return dispatchDefault(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 72277273774a9..d8800d186e767 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -44,7 +44,28 @@ private StateSpecs() {} * *

This method attempts to infer the accumulator coder automatically. * - * @see #value(Coder) + * @see #readModifyWrite(Coder) + */ + public static StateSpec> readModifyWrite() { + return new ReadModifyWriteStateSpec<>(null); + } + + /** + * Identical to {@link #readModifyWrite()}, but with a coder explicitly supplied. + * + *

If automatic coder inference fails, use this method. + */ + public static StateSpec> readModifyWrite(Coder valueCoder) { + checkArgument(valueCoder != null, "valueCoder should not be null. Consider readModifyWrite() instead"); + return new ReadModifyWriteStateSpec<>(valueCoder); + } + + /** + * Create a {@link StateSpec} for a single value of type {@code T}. + * + *

This method attempts to infer the accumulator coder automatically. + * + * @see #value(Coder) */ public static StateSpec> value() { return new ValueStateSpec<>(null); @@ -56,7 +77,7 @@ public static StateSpec> value() { *

If automatic coder inference fails, use this method. */ public static StateSpec> value(Coder valueCoder) { - checkArgument(valueCoder != null, "valueCoder should not be null. Consider value() instead"); + checkArgument(valueCoder != null, "valueCoder should not be null. Consider readModifyWrite() instead"); return new ValueStateSpec<>(valueCoder); } @@ -258,46 +279,96 @@ public static StateSpec> convertToBag } } + private static abstract class AbstractReadModifyWriteStateSpec { + + @Nullable protected Coder coder; + + protected AbstractReadModifyWriteStateSpec(@Nullable Coder coder) { + this.coder = coder; + } + + @SuppressWarnings("unchecked") + public void offerCoders(Coder[] coders) { + if (this.coder == null && coders[0] != null) { + this.coder = (Coder) coders[0]; + } + } + + public void finishSpecifying() { + if (coder == null) { + throw new IllegalStateException( + "Unable to infer a coder for ReadModifyWriteState and no Coder" + + " was specified. Please set a coder by either invoking" + + " StateSpecs.readModifyWrite(Coder valueCoder) or by registering the coder in the" + + " Pipeline's CoderRegistry."); + } + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), coder); + } + } + + /** * A specification for a state cell holding a settable value of type {@code T}. * *

Includes the coder for {@code T}. */ - private static class ValueStateSpec implements StateSpec> { - - @Nullable private Coder coder; + private static class ReadModifyWriteStateSpec extends AbstractReadModifyWriteStateSpec + implements StateSpec> { - private ValueStateSpec(@Nullable Coder coder) { - this.coder = coder; + private ReadModifyWriteStateSpec(@Nullable Coder coder) { + super(coder); } @Override - public ValueState bind(String id, StateBinder visitor) { - return visitor.bindValue(id, this, coder); + public ReadModifyWriteState bind(String id, StateBinder visitor) { + return visitor.bindReadModifyWrite(id, this, coder); } @Override public ResultT match(Cases cases) { - return cases.dispatchValue(coder); + return cases.dispatchReadModifyWrite(coder); } - @SuppressWarnings("unchecked") @Override - public void offerCoders(Coder[] coders) { - if (this.coder == null && coders[0] != null) { - this.coder = (Coder) coders[0]; + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof ReadModifyWriteStateSpec)) { + return false; } + + ReadModifyWriteStateSpec that = (ReadModifyWriteStateSpec) obj; + return Objects.equals(this.coder, that.coder); + } + + } + + /** + * A specification for a state cell holding a settable value of type {@code T}. + * + *

Includes the coder for {@code T}. + */ + private static class ValueStateSpec extends AbstractReadModifyWriteStateSpec + implements StateSpec> { + + private ValueStateSpec(@Nullable Coder coder) { + super(coder); } @Override - public void finishSpecifying() { - if (coder == null) { - throw new IllegalStateException( - "Unable to infer a coder for ValueState and no Coder" - + " was specified. Please set a coder by either invoking" - + " StateSpecs.value(Coder valueCoder) or by registering the coder in the" - + " Pipeline's CoderRegistry."); - } + public ValueState bind(String id, StateBinder visitor) { + return visitor.bindValue(id, this, coder); + } + + @Override + public ResultT match(Cases cases) { + return cases.dispatchValue(coder); } @Override @@ -313,11 +384,6 @@ public boolean equals(Object obj) { ValueStateSpec that = (ValueStateSpec) obj; return Objects.equals(this.coder, that.coder); } - - @Override - public int hashCode() { - return Objects.hash(getClass(), coder); - } } /** diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/ValueState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/ValueState.java index 0562c89dde448..9905a96c09970 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/ValueState.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/ValueState.java @@ -26,10 +26,6 @@ * @param The type of value being stored. */ @Experimental(Kind.STATE) -public interface ValueState extends ReadableState, State { - /** Set the value. 
*/ - void write(T input); - - @Override - ValueState readLater(); +public interface ValueState extends ReadModifyWriteState { + // This interface is deprecated; please use ReadModifyWriteState instead. } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 2352e3abb5171..ccbc5c6f21240 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -362,13 +362,13 @@ public interface MultiOutputReceiver { *

{@literal new DoFn, Baz>()} {
    *
    *  {@literal @StateId("my-state-id")}
-   *  {@literal private final StateSpec>} myStateSpec =
+   *  {@literal private final StateSpec>} myStateSpec =
    *       StateSpecs.value(new MyStateCoder());
    *
    *  {@literal @ProcessElement}
    *   public void processElement(
    *       {@literal @Element InputT element},
-   *      {@literal @StateId("my-state-id") ValueState myState}) {
+   *      {@literal @StateId("my-state-id") ReadModifyWriteState myState}) {
    *     myState.read();
    *     myState.write(...);
    *   }
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
index 3f724a9d1f3fd..03e56798f1294 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/GroupIntoBatches.java
@@ -29,7 +29,7 @@
 import org.apache.beam.sdk.state.Timer;
 import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
@@ -122,7 +122,7 @@ static class GroupIntoBatchesDoFn
     private final StateSpec> numElementsInBatchSpec;
 
     @StateId(KEY_ID)
-    private final StateSpec> keySpec;
+    private final StateSpec> keySpec;
 
     private final long prefetchFrequency;
 
@@ -149,7 +149,7 @@ public long apply(long left, long right) {
                 }
               });
 
-      this.keySpec = StateSpecs.value(inputKeyCoder);
+      this.keySpec = StateSpecs.readModifyWrite(inputKeyCoder);
       // prefetch every 20% of batchSize elements. Do not prefetch if batchSize is too little
       this.prefetchFrequency = ((batchSize / 5) <= 1) ? Long.MAX_VALUE : (batchSize / 5);
     }
@@ -159,7 +159,7 @@ public void processElement(
         @TimerId(END_OF_WINDOW_ID) Timer timer,
         @StateId(BATCH_ID) BagState batch,
         @StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState numElementsInBatch,
-        @StateId(KEY_ID) ValueState key,
+        @StateId(KEY_ID) ReadModifyWriteState key,
         @Element KV element,
         BoundedWindow window,
         OutputReceiver>> receiver) {
@@ -190,7 +190,7 @@ public void processElement(
     public void onTimerCallback(
         OutputReceiver>> receiver,
         @Timestamp Instant timestamp,
-        @StateId(KEY_ID) ValueState key,
+        @StateId(KEY_ID) ReadModifyWriteState key,
         @StateId(BATCH_ID) BagState batch,
         @StateId(NUM_ELEMENTS_IN_BATCH_ID) CombiningState numElementsInBatch,
         BoundedWindow window) {
@@ -203,7 +203,7 @@ public void onTimerCallback(
 
     private void flushBatch(
         OutputReceiver>> receiver,
-        ValueState key,
+        ReadModifyWriteState key,
         BagState batch,
         CombiningState numElementsInBatch) {
       Iterable values = batch.read();
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java
index 06856442b7d60..d225cfeccb2f9 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoLifecycleTest.java
@@ -38,7 +38,7 @@
 import java.util.concurrent.atomic.AtomicReference;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.UsesParDoLifecycle;
 import org.apache.beam.sdk.testing.UsesStatefulParDo;
@@ -160,7 +160,7 @@ private static class CallSequenceEnforcingStatefulFn
     private static final String STATE_ID = "foo";
 
     @StateId(STATE_ID)
-    private final StateSpec> valueSpec = StateSpecs.value();
+    private final StateSpec> valueSpec = StateSpecs.value();
   }
 
   @Test
@@ -429,7 +429,7 @@ private static class ExceptionThrowingStatefulFn extends ExceptionThrowing
     private static final String STATE_ID = "foo";
 
     @StateId(STATE_ID)
-    private final StateSpec> valueSpec = StateSpecs.value();
+    private final StateSpec> valueSpec = StateSpecs.value();
 
     private ExceptionThrowingStatefulFn(MethodForException toThrow) {
       super(toThrow);
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 724bc3f01491f..7d661e60c7e65 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -74,7 +74,7 @@
 import org.apache.beam.sdk.state.Timer;
 import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.testing.DataflowPortabilityApiUnsupported;
 import org.apache.beam.sdk.testing.NeedsRunner;
 import org.apache.beam.sdk.testing.PAssert;
@@ -1399,19 +1399,19 @@ public Duration getAllowedTimestampSkew() {
   public static class StateTests extends SharedTestBase implements Serializable {
     @Test
     @Category({ValidatesRunner.class, UsesStatefulParDo.class})
-    public void testValueStateSimple() {
+    public void testReadModifyWriteStateSimple() {
       final String stateId = "foo";
 
       DoFn, Integer> fn =
           new DoFn, Integer>() {
 
             @StateId(stateId)
-            private final StateSpec> intState =
+            private final StateSpec> intState =
                 StateSpecs.value(VarIntCoder.of());
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState state, OutputReceiver r) {
+                @StateId(stateId) ReadModifyWriteState state, OutputReceiver r) {
               Integer currentValue = MoreObjects.firstNonNull(state.read(), 0);
               r.output(currentValue);
               state.write(currentValue + 1);
@@ -1429,20 +1429,20 @@ public void processElement(
 
     @Test
     @Category({ValidatesRunner.class, UsesStatefulParDo.class})
-    public void testValueStateDedup() {
+    public void testReadModifyWriteStateDedup() {
       final String stateId = "foo";
 
       DoFn, Integer> onePerKey =
           new DoFn, Integer>() {
 
             @StateId(stateId)
-            private final StateSpec> seenSpec =
+            private final StateSpec> seenSpec =
                 StateSpecs.value(VarIntCoder.of());
 
             @ProcessElement
             public void processElement(
                 @Element KV element,
-                @StateId(stateId) ValueState seenState,
+                @StateId(stateId) ReadModifyWriteState seenState,
                 OutputReceiver r) {
               Integer seen = MoreObjects.firstNonNull(seenState.read(), 0);
 
@@ -1485,11 +1485,11 @@ public void testStateNotKeyed() {
           new DoFn() {
 
             @StateId(stateId)
-            private final StateSpec> intState = StateSpecs.value();
+            private final StateSpec> intState = StateSpecs.value();
 
             @ProcessElement
             public void processElement(
-                ProcessContext c, @StateId(stateId) ValueState state) {}
+                ProcessContext c, @StateId(stateId) ReadModifyWriteState state) {}
           };
 
       thrown.expect(IllegalArgumentException.class);
@@ -1508,11 +1508,11 @@ public void testStateNotDeterministic() {
           new DoFn, Integer>() {
 
             @StateId(stateId)
-            private final StateSpec> intState = StateSpecs.value();
+            private final StateSpec> intState = StateSpecs.value();
 
             @ProcessElement
             public void processElement(
-                ProcessContext c, @StateId(stateId) ValueState state) {}
+                ProcessContext c, @StateId(stateId) ReadModifyWriteState state) {}
           };
 
       thrown.expect(IllegalArgumentException.class);
@@ -1539,12 +1539,12 @@ public void testCoderInferenceOfList() {
           new DoFn, List>() {
 
             @StateId(stateId)
-            private final StateSpec>> intState = StateSpecs.value();
+            private final StateSpec>> intState = StateSpecs.value();
 
             @ProcessElement
             public void processElement(
                 @Element KV element,
-                @StateId(stateId) ValueState> state,
+                @StateId(stateId) ReadModifyWriteState> state,
                 OutputReceiver> r) {
               MyInteger myInteger = new MyInteger(element.getValue());
               List currentValue = state.read();
@@ -1574,19 +1574,19 @@ public void processElement(
       UsesStatefulParDo.class,
       DataflowPortabilityApiUnsupported.class
     })
-    public void testValueStateFixedWindows() {
+    public void testReadModifyWriteStateFixedWindows() {
       final String stateId = "foo";
 
       DoFn, Integer> fn =
           new DoFn, Integer>() {
 
             @StateId(stateId)
-            private final StateSpec> intState =
+            private final StateSpec> intState =
                 StateSpecs.value(VarIntCoder.of());
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState state, OutputReceiver r) {
+                @StateId(stateId) ReadModifyWriteState state, OutputReceiver r) {
               Integer currentValue = MoreObjects.firstNonNull(state.read(), 0);
               r.output(currentValue);
               state.write(currentValue + 1);
@@ -1622,19 +1622,19 @@ public void processElement(
      */
     @Test
     @Category({ValidatesRunner.class, UsesStatefulParDo.class})
-    public void testValueStateSameId() {
+    public void testReadModifyWriteStateSameId() {
       final String stateId = "foo";
 
       DoFn, KV> fn =
           new DoFn, KV>() {
 
             @StateId(stateId)
-            private final StateSpec> intState =
+            private final StateSpec> intState =
                 StateSpecs.value(VarIntCoder.of());
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState state,
+                @StateId(stateId) ReadModifyWriteState state,
                 OutputReceiver> r) {
               Integer currentValue = MoreObjects.firstNonNull(state.read(), 0);
               r.output(KV.of("sizzle", currentValue));
@@ -1646,12 +1646,12 @@ public void processElement(
           new DoFn, Integer>() {
 
             @StateId(stateId)
-            private final StateSpec> intState =
+            private final StateSpec> intState =
                 StateSpecs.value(VarIntCoder.of());
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState state, OutputReceiver r) {
+                @StateId(stateId) ReadModifyWriteState state, OutputReceiver r) {
               Integer currentValue = MoreObjects.firstNonNull(state.read(), 13);
               r.output(currentValue);
               state.write(currentValue + 13);
@@ -1677,7 +1677,7 @@ public void processElement(
       UsesStatefulParDo.class,
       DataflowPortabilityApiUnsupported.class
     })
-    public void testValueStateTaggedOutput() {
+    public void testReadModifyWriteStateTaggedOutput() {
       final String stateId = "foo";
 
       final TupleTag evenTag = new TupleTag() {};
@@ -1687,12 +1687,12 @@ public void testValueStateTaggedOutput() {
           new DoFn, Integer>() {
 
             @StateId(stateId)
-            private final StateSpec> intState =
+            private final StateSpec> intState =
                 StateSpecs.value(VarIntCoder.of());
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState state, MultiOutputReceiver r) {
+                @StateId(stateId) ReadModifyWriteState state, MultiOutputReceiver r) {
               Integer currentValue = MoreObjects.firstNonNull(state.read(), 0);
               if (currentValue % 2 == 0) {
                 r.get(evenTag).output(currentValue);
@@ -2725,18 +2725,18 @@ public void testEventTimeTimerLoop() {
             private final TimerSpec loopSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
 
             @StateId(stateId)
-            private final StateSpec> countSpec = StateSpecs.value();
+            private final StateSpec> countSpec = StateSpecs.value();
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState countState,
+                @StateId(stateId) ReadModifyWriteState countState,
                 @TimerId(timerId) Timer loopTimer) {
               loopTimer.offset(Duration.millis(1)).setRelative();
             }
 
             @OnTimer(timerId)
             public void onLoopTimer(
-                @StateId(stateId) ValueState countState,
+                @StateId(stateId) ReadModifyWriteState countState,
                 @TimerId(timerId) Timer loopTimer,
                 OutputReceiver r) {
               int count = MoreObjects.firstNonNull(countState.read(), 0);
@@ -2779,14 +2779,14 @@ public void testEventTimeTimerMultipleKeys() throws Exception {
             private final TimerSpec spec = TimerSpecs.timer(TimeDomain.EVENT_TIME);
 
             @StateId(stateId)
-            private final StateSpec> stateSpec =
+            private final StateSpec> stateSpec =
                 StateSpecs.value(StringUtf8Coder.of());
 
             @ProcessElement
             public void processElement(
                 ProcessContext context,
                 @TimerId(timerId) Timer timer,
-                @StateId(stateId) ValueState state,
+                @StateId(stateId) ReadModifyWriteState state,
                 BoundedWindow window) {
               timer.set(window.maxTimestamp());
               state.write(context.element().getKey());
@@ -2796,7 +2796,7 @@ public void processElement(
 
             @OnTimer(timerId)
             public void onTimer(
-                @StateId(stateId) ValueState state, OutputReceiver> r) {
+                @StateId(stateId) ReadModifyWriteState state, OutputReceiver> r) {
               r.output(KV.of(state.read(), timerOutput));
             }
           };
@@ -3208,7 +3208,7 @@ public static class TimerCoderInferenceTests extends SharedTestBase implements S
       UsesStatefulParDo.class,
       DataflowPortabilityApiUnsupported.class
     })
-    public void testValueStateCoderInference() {
+    public void testReadModifyWriteStateCoderInference() {
       final String stateId = "foo";
       MyIntegerCoder myIntegerCoder = MyIntegerCoder.of();
       pipeline.getCoderRegistry().registerCoderForClass(MyInteger.class, myIntegerCoder);
@@ -3217,12 +3217,12 @@ public void testValueStateCoderInference() {
           new DoFn, MyInteger>() {
 
             @StateId(stateId)
-            private final StateSpec> intState = StateSpecs.value();
+            private final StateSpec> intState = StateSpecs.value();
 
             @ProcessElement
             public void processElement(
                 ProcessContext c,
-                @StateId(stateId) ValueState state,
+                @StateId(stateId) ReadModifyWriteState state,
                 OutputReceiver r) {
               MyInteger currentValue = MoreObjects.firstNonNull(state.read(), new MyInteger(0));
               r.output(currentValue);
@@ -3242,7 +3242,7 @@ public void processElement(
 
     @Test
     @Category({ValidatesRunner.class, UsesStatefulParDo.class})
-    public void testValueStateCoderInferenceFailure() throws Exception {
+    public void testReadModifyWriteStateCoderInferenceFailure() throws Exception {
       final String stateId = "foo";
       MyIntegerCoder myIntegerCoder = MyIntegerCoder.of();
 
@@ -3250,11 +3250,11 @@ public void testValueStateCoderInferenceFailure() throws Exception {
           new DoFn, MyInteger>() {
 
             @StateId(stateId)
-            private final StateSpec> intState = StateSpecs.value();
+            private final StateSpec> intState = StateSpecs.value();
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState state, OutputReceiver r) {
+                @StateId(stateId) ReadModifyWriteState state, OutputReceiver r) {
               MyInteger currentValue = MoreObjects.firstNonNull(state.read(), new MyInteger(0));
               r.output(currentValue);
               state.write(new MyInteger(currentValue.getValue() + 1));
@@ -3262,7 +3262,7 @@ public void processElement(
           };
 
       thrown.expect(RuntimeException.class);
-      thrown.expectMessage("Unable to infer a coder for ValueState and no Coder was specified.");
+      thrown.expectMessage("Unable to infer a coder for ReadModifyWriteState and no Coder was specified.");
 
       pipeline
           .apply(Create.of(KV.of("hello", 42), KV.of("hello", 97), KV.of("hello", 84)))
@@ -3278,7 +3278,7 @@ public void processElement(
       UsesStatefulParDo.class,
       DataflowPortabilityApiUnsupported.class
     })
-    public void testValueStateCoderInferenceFromInputCoder() {
+    public void testReadModifyWriteStateCoderInferenceFromInputCoder() {
       final String stateId = "foo";
       MyIntegerCoder myIntegerCoder = MyIntegerCoder.of();
 
@@ -3286,11 +3286,11 @@ public void testValueStateCoderInferenceFromInputCoder() {
           new DoFn, MyInteger>() {
 
             @StateId(stateId)
-            private final StateSpec> intState = StateSpecs.value();
+            private final StateSpec> intState = StateSpecs.value();
 
             @ProcessElement
             public void processElement(
-                @StateId(stateId) ValueState state, OutputReceiver r) {
+                @StateId(stateId) ReadModifyWriteState state, OutputReceiver r) {
               MyInteger currentValue = MoreObjects.firstNonNull(state.read(), new MyInteger(0));
               r.output(currentValue);
               state.write(new MyInteger(currentValue.getValue() + 1));
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
index 0d2be5a2d536c..30c11ff4949d4 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnInvokersTest.java
@@ -47,7 +47,7 @@
 import org.apache.beam.sdk.state.Timer;
 import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver;
 import org.apache.beam.sdk.transforms.DoFn.OutputReceiver;
@@ -235,16 +235,16 @@ public void processElement(
   /** Tests that the generated {@link DoFnInvoker} passes the state parameter that it should. */
   @Test
   public void testDoFnWithState() throws Exception {
-    ValueState mockState = mock(ValueState.class);
+    ReadModifyWriteState mockState = mock(ReadModifyWriteState.class);
     final String stateId = "my-state-id-here";
     when(mockArgumentProvider.state(stateId)).thenReturn(mockState);
 
     class MockFn extends DoFn {
       @StateId(stateId)
-      private final StateSpec> spec = StateSpecs.value(VarIntCoder.of());
+      private final StateSpec> spec = StateSpecs.value(VarIntCoder.of());
 
       @ProcessElement
-      public void processElement(ProcessContext c, @StateId(stateId) ValueState valueState)
+      public void processElement(ProcessContext c, @StateId(stateId) ReadModifyWriteState valueState)
           throws Exception {}
     }
 
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index 8c80cfc42497f..1c69259c60d1e 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -45,7 +45,7 @@
 import org.apache.beam.sdk.state.Timer;
 import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.state.WatermarkHoldState;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Sum;
@@ -765,11 +765,11 @@ public void testStateIdDuplicate() throws Exception {
         DoFnSignatures.getSignature(
             new DoFn, Long>() {
               @StateId("my-id")
-              private final StateSpec> myfield1 =
+              private final StateSpec> myfield1 =
                   StateSpecs.value(VarIntCoder.of());
 
               @StateId("my-id")
-              private final StateSpec> myfield2 =
+              private final StateSpec> myfield2 =
                   StateSpecs.value(VarLongCoder.of());
 
               @ProcessElement
@@ -787,7 +787,7 @@ public void testStateIdNonFinal() throws Exception {
     DoFnSignatures.getSignature(
         new DoFn, Long>() {
           @StateId("my-id")
-          private StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
+          private StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
 
           @ProcessElement
           public void foo(ProcessContext context) {}
@@ -804,7 +804,7 @@ public void testStateParameterNoAnnotation() throws Exception {
     DoFnSignatures.getSignature(
         new DoFn, Long>() {
           @ProcessElement
-          public void myProcessElement(ProcessContext context, ValueState noAnnotation) {}
+          public void myProcessElement(ProcessContext context, ReadModifyWriteState noAnnotation) {}
         }.getClass());
   }
 
@@ -820,7 +820,7 @@ public void testStateParameterUndeclared() throws Exception {
         new DoFn, Long>() {
           @ProcessElement
           public void myProcessElement(
-              ProcessContext context, @StateId("my-id") ValueState undeclared) {}
+              ProcessContext context, @StateId("my-id") ReadModifyWriteState undeclared) {}
         }.getClass());
   }
 
@@ -835,13 +835,13 @@ public void testStateParameterDuplicate() throws Exception {
     DoFnSignatures.getSignature(
         new DoFn, Long>() {
           @StateId("my-id")
-          private final StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
+          private final StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
 
           @ProcessElement
           public void myProcessElement(
               ProcessContext context,
-              @StateId("my-id") ValueState one,
-              @StateId("my-id") ValueState two) {}
+              @StateId("my-id") ReadModifyWriteState one,
+              @StateId("my-id") ReadModifyWriteState two) {}
         }.getClass());
   }
 
@@ -851,7 +851,7 @@ public void testStateParameterWrongStateType() throws Exception {
     thrown.expectMessage("WatermarkHoldState");
     thrown.expectMessage("reference to");
     thrown.expectMessage("supertype");
-    thrown.expectMessage("ValueState");
+    thrown.expectMessage("ReadModifyWriteState");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");
     thrown.expectMessage("index 1");
@@ -859,7 +859,7 @@ public void testStateParameterWrongStateType() throws Exception {
     DoFnSignatures.getSignature(
         new DoFn, Long>() {
           @StateId("my-id")
-          private final StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
+          private final StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
 
           @ProcessElement
           public void myProcessElement(
@@ -870,10 +870,10 @@ public void myProcessElement(
   @Test
   public void testStateParameterWrongGenericType() throws Exception {
     thrown.expect(IllegalArgumentException.class);
-    thrown.expectMessage("ValueState");
+    thrown.expectMessage("ReadModifyWriteState");
     thrown.expectMessage("reference to");
     thrown.expectMessage("supertype");
-    thrown.expectMessage("ValueState");
+    thrown.expectMessage("ReadModifyWriteState");
     thrown.expectMessage("my-id");
     thrown.expectMessage("myProcessElement");
     thrown.expectMessage("index 1");
@@ -881,11 +881,11 @@ public void testStateParameterWrongGenericType() throws Exception {
     DoFnSignatures.getSignature(
         new DoFn, Long>() {
           @StateId("my-id")
-          private final StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
+          private final StateSpec> myfield = StateSpecs.value(VarIntCoder.of());
 
           @ProcessElement
           public void myProcessElement(
-              ProcessContext context, @StateId("my-id") ValueState stringState) {}
+              ProcessContext context, @StateId("my-id") ReadModifyWriteState stringState) {}
         }.getClass());
   }
 
@@ -910,7 +910,7 @@ public void testSimpleStateIdAnonymousDoFn() throws Exception {
         DoFnSignatures.getSignature(
             new DoFn, Long>() {
               @StateId("foo")
-              private final StateSpec> bizzle =
+              private final StateSpec> bizzle =
                   StateSpecs.value(VarIntCoder.of());
 
               @ProcessElement
@@ -924,7 +924,7 @@ public void foo(ProcessContext context) {}
     assertThat(decl.field().getName(), equalTo("bizzle"));
     assertThat(
         decl.stateType(),
-        Matchers.>equalTo(new TypeDescriptor>() {}));
+        Matchers.>equalTo(new TypeDescriptor>() {}));
   }
 
   @Test
@@ -934,7 +934,7 @@ public void testUsageOfStateDeclaredInSuperclass() throws Exception {
           @ProcessElement
           public void process(
               ProcessContext context,
-              @StateId(DoFnDeclaringState.STATE_ID) ValueState state) {}
+              @StateId(DoFnDeclaringState.STATE_ID) ReadModifyWriteState state) {}
         };
 
     thrown.expect(IllegalArgumentException.class);
@@ -954,7 +954,7 @@ public void testDeclOfStateUsedInSuperclass() throws Exception {
     DoFnSignatures.getSignature(
         new DoFnUsingState() {
           @StateId(DoFnUsingState.STATE_ID)
-          private final StateSpec> spec = StateSpecs.value(VarIntCoder.of());
+          private final StateSpec> spec = StateSpecs.value(VarIntCoder.of());
         }.getClass());
   }
 
@@ -963,7 +963,7 @@ public void testDeclAndUsageOfStateInSuperclass() throws Exception {
     class DoFnOverridingAbstractStateUse extends DoFnDeclaringStateAndAbstractUse {
 
       @Override
-      public void processWithState(ProcessContext c, ValueState state) {}
+      public void processWithState(ProcessContext c, ReadModifyWriteState state) {}
     }
 
     DoFnSignature sig =
@@ -995,11 +995,11 @@ public void testSimpleStateIdRefAnonymousDoFn() throws Exception {
         DoFnSignatures.getSignature(
             new DoFn, Long>() {
               @StateId("foo")
-              private final StateSpec> bizzleDecl =
+              private final StateSpec> bizzleDecl =
                   StateSpecs.value(VarIntCoder.of());
 
               @ProcessElement
-              public void foo(ProcessContext context, @StateId("foo") ValueState bizzle) {}
+              public void foo(ProcessContext context, @StateId("foo") ReadModifyWriteState bizzle) {}
             }.getClass());
 
     assertThat(sig.processElement().extraParameters().size(), equalTo(2));
@@ -1028,7 +1028,7 @@ public Void dispatch(StateParameter stateParam) {
   public void testSimpleStateIdNamedDoFn() throws Exception {
     class DoFnForTestSimpleStateIdNamedDoFn extends DoFn, Long> {
       @StateId("foo")
-      private final StateSpec> bizzle = StateSpecs.value(VarIntCoder.of());
+      private final StateSpec> bizzle = StateSpecs.value(VarIntCoder.of());
 
       @ProcessElement
       public void foo(ProcessContext context) {}
@@ -1045,7 +1045,7 @@ public void foo(ProcessContext context) {}
         decl.field(), equalTo(DoFnForTestSimpleStateIdNamedDoFn.class.getDeclaredField("bizzle")));
     assertThat(
         decl.stateType(),
-        Matchers.>equalTo(new TypeDescriptor>() {}));
+        Matchers.>equalTo(new TypeDescriptor>() {}));
   }
 
   @Test
@@ -1054,7 +1054,7 @@ class DoFnForTestGenericStatefulDoFn extends DoFn, Long> {
       // Note that in order to have a coder for T it will require initialization in the constructor,
       // but that isn't important for this test
       @StateId("foo")
-      private final StateSpec> bizzle = null;
+      private final StateSpec> bizzle = null;
 
       @ProcessElement
       public void foo(ProcessContext context) {}
@@ -1073,7 +1073,7 @@ public void foo(ProcessContext context) {}
         decl.field(), equalTo(DoFnForTestGenericStatefulDoFn.class.getDeclaredField("bizzle")));
     assertThat(
         decl.stateType(),
-        Matchers.>equalTo(new TypeDescriptor>() {}));
+        Matchers.>equalTo(new TypeDescriptor>() {}));
   }
 
   @Test
@@ -1158,7 +1158,7 @@ public void testOnWindowExpirationWithAllowedParams() {
         DoFnSignatures.getSignature(
             new DoFn() {
               @StateId("foo")
-              private final StateSpec> bizzle =
+              private final StateSpec> bizzle =
                   StateSpecs.value(VarIntCoder.of());
 
               @ProcessElement
@@ -1167,7 +1167,7 @@ public void process(ProcessContext c) {}
               @OnWindowExpiration
               public void bar(
                   BoundedWindow b,
-                  @StateId("foo") ValueState s,
+                  @StateId("foo") ReadModifyWriteState s,
                   PipelineOptions p,
                   OutputReceiver o,
                   MultiOutputReceiver m) {}
@@ -1195,14 +1195,14 @@ private abstract static class DoFnDeclaringState extends DoFn> bizzle = StateSpecs.value(VarIntCoder.of());
+    private final StateSpec> bizzle = StateSpecs.value(VarIntCoder.of());
   }
 
   private abstract static class DoFnUsingState extends DoFn, Long> {
     public static final String STATE_ID = "my-state-id";
 
     @ProcessElement
-    public void process(ProcessContext context, @StateId(STATE_ID) ValueState state) {}
+    public void process(ProcessContext context, @StateId(STATE_ID) ReadModifyWriteState state) {}
   }
 
   private abstract static class DoFnDeclaringStateAndAbstractUse
@@ -1210,12 +1210,12 @@ private abstract static class DoFnDeclaringStateAndAbstractUse
     public static final String STATE_ID = "my-state-id";
 
     @StateId(STATE_ID)
-    private final StateSpec> myStateSpec =
+    private final StateSpec> myStateSpec =
         StateSpecs.value(StringUtf8Coder.of());
 
     @ProcessElement
     public abstract void processWithState(
-        ProcessContext context, @StateId(STATE_ID) ValueState state);
+        ProcessContext context, @StateId(STATE_ID) ReadModifyWriteState state);
   }
 
   private abstract static class DoFnDeclaringMyTimerId extends DoFn, Long> {
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java
index c660328268c2d..52fbcd3440322 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/rel/BeamSortRel.java
@@ -36,7 +36,7 @@
 import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Flatten;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -259,17 +259,17 @@ public LimitFn(int c, int s) {
     }
 
     @StateId("counter")
-    private final StateSpec> counterState = StateSpecs.value(VarIntCoder.of());
+    private final StateSpec> counterState = StateSpecs.value(VarIntCoder.of());
 
     @StateId("skipped_rows")
-    private final StateSpec> skippedRowsState =
+    private final StateSpec> skippedRowsState =
         StateSpecs.value(VarIntCoder.of());
 
     @ProcessElement
     public void processElement(
         ProcessContext context,
-        @StateId("counter") ValueState counterState,
-        @StateId("skipped_rows") ValueState skippedRowsState) {
+        @StateId("counter") ReadModifyWriteState counterState,
+        @StateId("skipped_rows") ReadModifyWriteState skippedRowsState) {
       Integer toSkipRows = firstNonNull(skippedRowsState.read(), startIndex);
       if (toSkipRows == 0) {
         int current = firstNonNull(counterState.read(), 0);
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
index 201a1df4ed282..01aa2d6211030 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
@@ -44,7 +44,7 @@
 import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.MapElements;
@@ -199,13 +199,13 @@ public static class StreamAssertEqual extends DoFn, Void> {
     private final StateSpec> seenRows = StateSpecs.bag();
 
     @StateId("count")
-    private final StateSpec> countState = StateSpecs.value();
+    private final StateSpec> countState = StateSpecs.value();
 
     @ProcessElement
     public void process(
         ProcessContext context,
         @StateId("seenValues") BagState seenValues,
-        @StateId("count") ValueState countState) {
+        @StateId("count") ReadModifyWriteState countState) {
       // I don't think doing this will be safe in parallel
       int count = MoreObjects.firstNonNull(countState.read(), 0);
       count = count + 1;
diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index b1224bd19792a..c7c569e6d0227 100644
--- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -44,6 +44,7 @@
 import org.apache.beam.sdk.state.StateContext;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.state.WatermarkHoldState;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
@@ -197,12 +198,54 @@ public boolean isEmpty() {
   @Override
   public  ValueState bindValue(String id, StateSpec> spec, Coder coder) {
     return (ValueState)
+            stateKeyObjectCache.computeIfAbsent(
+                    createBagUserStateKey(id),
+                    new Function() {
+                      @Override
+                      public Object apply(StateKey key) {
+                        return new ValueState() {
+                          private final BagUserState impl = createBagUserState(id, coder);
+
+                          @Override
+                          public void clear() {
+                            impl.clear();
+                          }
+
+                          @Override
+                          public void write(T input) {
+                            impl.clear();
+                            impl.append(input);
+                          }
+
+                          @Override
+                          public T read() {
+                            Iterator value = impl.get().iterator();
+                            if (value.hasNext()) {
+                              return value.next();
+                            } else {
+                              return null;
+                            }
+                          }
+
+                          @Override
+                          public ValueState readLater() {
+                            // TODO: Support prefetching.
+                            return this;
+                          }
+                        };
+                      }
+                    });
+  }
+
+  @Override
+  public  ReadModifyWriteState bindReadModifyWrite(String id, StateSpec> spec, Coder coder) {
+    return (ReadModifyWriteState)
         stateKeyObjectCache.computeIfAbsent(
             createBagUserStateKey(id),
             new Function() {
               @Override
               public Object apply(StateKey key) {
-                return new ValueState() {
+                return new ReadModifyWriteState() {
                   private final BagUserState impl = createBagUserState(id, coder);
 
                   @Override
@@ -227,7 +270,7 @@ public T read() {
                   }
 
                   @Override
-                  public ValueState readLater() {
+                  public ReadModifyWriteState readLater() {
                     // TODO: Support prefetching.
                     return this;
                   }
diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
index 088148d70db21..42e225c660555 100644
--- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
+++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/FnApiDoFnRunnerTest.java
@@ -68,7 +68,7 @@
 import org.apache.beam.sdk.state.Timer;
 import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.testing.ResetDateTimeProvider;
 import org.apache.beam.sdk.transforms.Combine.CombineFn;
 import org.apache.beam.sdk.transforms.Create;
@@ -144,7 +144,7 @@ private static class TestStatefulDoFn extends DoFn, String> {
     private static final TupleTag additionalOutput = new TupleTag<>("output");
 
     @StateId("value")
-    private final StateSpec> valueStateSpec =
+    private final StateSpec> valueStateSpec =
         StateSpecs.value(StringUtf8Coder.of());
 
     @StateId("bag")
@@ -157,7 +157,7 @@ private static class TestStatefulDoFn extends DoFn, String> {
     @ProcessElement
     public void processElement(
         ProcessContext context,
-        @StateId("value") ValueState valueState,
+        @StateId("value") ReadModifyWriteState valueState,
         @StateId("bag") BagState bagState,
         @StateId("combine") CombiningState combiningState) {
       context.output("value:" + valueState.read());
diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java
index 6ea6bafe0e2fc..9226931eadeeb 100644
--- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java
+++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaExactlyOnceSink.java
@@ -44,7 +44,7 @@
 import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.GroupByKey;
 import org.apache.beam.sdk.transforms.PTransform;
@@ -210,10 +210,10 @@ private static class Sequencer
     private static final String NEXT_ID = "nextId";
 
     @StateId(NEXT_ID)
-    private final StateSpec> nextIdSpec = StateSpecs.value();
+    private final StateSpec> nextIdSpec = StateSpecs.readModifyWrite();
 
     @ProcessElement
-    public void processElement(@StateId(NEXT_ID) ValueState nextIdState, ProcessContext ctx) {
+    public void processElement(@StateId(NEXT_ID) ReadModifyWriteState nextIdState, ProcessContext ctx) {
       long nextId = MoreObjects.firstNonNull(nextIdState.read(), 0L);
       int shard = ctx.element().getKey();
       for (TimestampedValue> value : ctx.element().getValue()) {
@@ -237,10 +237,10 @@ private static class ExactlyOnceWriter
     private static final ObjectMapper JSON_MAPPER = new ObjectMapper();
 
     @StateId(NEXT_ID)
-    private final StateSpec> sequenceIdSpec = StateSpecs.value();
+    private final StateSpec> sequenceIdSpec = StateSpecs.readModifyWrite();
 
     @StateId(MIN_BUFFERED_ID)
-    private final StateSpec> minBufferedIdSpec = StateSpecs.value();
+    private final StateSpec> minBufferedIdSpec = StateSpecs.readModifyWrite();
 
     @StateId(OUT_OF_ORDER_BUFFER)
     private final StateSpec>>>>
@@ -250,7 +250,7 @@ private static class ExactlyOnceWriter
     // a job is restarted with same groupId, but the metadata from previous run was not cleared.
     // Better to be safe and error out with a clear message.
     @StateId(WRITER_ID)
-    private final StateSpec> writerIdSpec = StateSpecs.value();
+    private final StateSpec> writerIdSpec = StateSpecs.readModifyWrite();
 
     private final WriteRecords spec;
 
@@ -277,11 +277,11 @@ public void setup() {
     @RequiresStableInput
     @ProcessElement
     public void processElement(
-        @StateId(NEXT_ID) ValueState nextIdState,
-        @StateId(MIN_BUFFERED_ID) ValueState minBufferedIdState,
+        @StateId(NEXT_ID) ReadModifyWriteState nextIdState,
+        @StateId(MIN_BUFFERED_ID) ReadModifyWriteState minBufferedIdState,
         @StateId(OUT_OF_ORDER_BUFFER)
             BagState>>> oooBufferState,
-        @StateId(WRITER_ID) ValueState writerIdState,
+        @StateId(WRITER_ID) ReadModifyWriteState writerIdState,
         ProcessContext ctx)
         throws IOException {
 
@@ -521,7 +521,7 @@ void commitTxn(long lastRecordId, Counter numTransactions) throws IOException {
     }
 
     private ShardWriter initShardWriter(
-        int shard, ValueState writerIdState, long nextId) throws IOException {
+        int shard, ReadModifyWriteState writerIdState, long nextId) throws IOException {
 
       String producerName = String.format("producer_%d_for_%s", shard, spec.getSinkGroupId());
       Producer producer = initializeExactlyOnceProducer(spec, producerName);
diff --git a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java
index ca6f693e7f4b4..43cb9476ec355 100644
--- a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java
+++ b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/NexmarkUtils.java
@@ -61,7 +61,7 @@
 import org.apache.beam.sdk.nexmark.sources.generator.GeneratorConfig;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Create;
 import org.apache.beam.sdk.transforms.DoFn;
@@ -587,12 +587,12 @@ public void processElement(ProcessContext context) {
                 private static final String DISK_BUSY = "diskBusy";
 
                 @StateId(DISK_BUSY)
-                private final StateSpec> spec =
+                private final StateSpec> spec =
                     StateSpecs.value(ByteArrayCoder.of());
 
                 @ProcessElement
                 public void processElement(
-                    ProcessContext c, @StateId(DISK_BUSY) ValueState state) {
+                    ProcessContext c, @StateId(DISK_BUSY) ReadModifyWriteState state) {
                   long remain = bytes;
                   long now = System.currentTimeMillis();
                   while (remain > 0) {
diff --git a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3.java b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3.java
index 05d7bf3990c74..327d7880be813 100644
--- a/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3.java
+++ b/sdks/java/testing/nexmark/src/main/java/org/apache/beam/sdk/nexmark/queries/Query3.java
@@ -33,7 +33,7 @@
 import org.apache.beam.sdk.state.Timer;
 import org.apache.beam.sdk.state.TimerSpec;
 import org.apache.beam.sdk.state.TimerSpecs;
-import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.state.ReadModifyWriteState;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.Filter;
 import org.apache.beam.sdk.transforms.ParDo;
@@ -159,12 +159,12 @@ private static class JoinDoFn extends DoFn, KV> personSpec = StateSpecs.value(Person.CODER);
+    private static final StateSpec> personSpec = StateSpecs.value(Person.CODER);
 
     private static final String PERSON_STATE_EXPIRING = "personStateExpiring";
 
     @StateId(AUCTIONS)
-    private final StateSpec>> auctionsSpec =
+    private final StateSpec>> auctionsSpec =
         StateSpecs.value(ListCoder.of(Auction.CODER));
 
     @TimerId(PERSON_STATE_EXPIRING)
@@ -195,8 +195,8 @@ private JoinDoFn(String name, int maxAuctionsWaitingTime) {
     public void processElement(
         ProcessContext c,
         @TimerId(PERSON_STATE_EXPIRING) Timer timer,
-        @StateId(PERSON) ValueState personState,
-        @StateId(AUCTIONS) ValueState> auctionsState) {
+        @StateId(PERSON) ReadModifyWriteState personState,
+        @StateId(AUCTIONS) ReadModifyWriteState> auctionsState) {
       // We would *almost* implement this by  rewindowing into the global window and
       // running a combiner over the result. The combiner's accumulator would be the
       // state we use below. However, combiners cannot emit intermediate results, thus
@@ -271,7 +271,7 @@ public void processElement(
 
     @OnTimer(PERSON_STATE_EXPIRING)
     public void onTimerCallback(
-        OnTimerContext context, @StateId(PERSON) ValueState personState) {
+        OnTimerContext context, @StateId(PERSON) ReadModifyWriteState personState) {
       personState.clear();
     }
   }
diff --git a/sdks/python/apache_beam/runners/direct/direct_userstate.py b/sdks/python/apache_beam/runners/direct/direct_userstate.py
index 42afaa3f63378..c305e0ffa79b8 100644
--- a/sdks/python/apache_beam/runners/direct/direct_userstate.py
+++ b/sdks/python/apache_beam/runners/direct/direct_userstate.py
@@ -22,6 +22,7 @@
 
 from apache_beam.transforms import userstate
 from apache_beam.transforms.trigger import _ListStateTag
+from apache_beam.transforms.trigger import _ReadModifyWriteStateTag
 from apache_beam.transforms.trigger import _SetStateTag
 
 
@@ -40,6 +41,11 @@ def for_spec(state_spec, state_tag, current_value_accessor):
                                         current_value_accessor)
     elif isinstance(state_spec, userstate.SetStateSpec):
       return SetRuntimeState(state_spec, state_tag, current_value_accessor)
+
+    elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
+      return ReadModifyWriteRuntimeState(state_spec,
+                                         state_tag,
+                                         current_value_accessor)
     else:
       raise ValueError('Invalid state spec: %s' % state_spec)
 
@@ -110,6 +116,36 @@ def is_modified(self):
     return self._modified
 
 
+class ReadModifyWriteRuntimeState(DirectRuntimeState,
+                                  userstate.ReadModifyWriteRuntimeState):
+  """Single-value state cell passed to user code; add() overwrites."""
+
+  def __init__(self, state_spec, state_tag, current_value_accessor):
+    super(ReadModifyWriteRuntimeState, self).__init__(
+        state_spec, state_tag, current_value_accessor)
+    self._value = UNREAD_VALUE  # Marker: committed value not loaded yet.
+    self._modified = False  # True once add() stored a value to write out.
+    self._cleared = False  # True once clear() ran in this bundle.
+
+  def read(self):
+    if self._value is UNREAD_VALUE:
+      self._value = self._decode(self._current_value_accessor())
+    # After clear() this is None; after add() it is the pending value.
+    return self._value
+
+  def add(self, value):
+    self._modified = True
+    self._value = value  # Replaces, rather than accumulates, the value.
+
+  def clear(self):
+    self._cleared = True
+    self._value = None  # Not UNREAD_VALUE: that would reload stale state.
+    self._modified = False
+
+  def is_modified(self):
+    return self._modified and self._value is not UNREAD_VALUE
+
+
 class CombiningValueRuntimeState(
     DirectRuntimeState, userstate.CombiningValueRuntimeState):
   """Combining value state interface object passed to user code."""
@@ -169,6 +205,8 @@ def __init__(self, step_context, dofn, key_coder):
         state_tag = _ListStateTag(state_key)
       elif isinstance(state_spec, userstate.SetStateSpec):
         state_tag = _SetStateTag(state_key)
+      elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
+        state_tag = _ReadModifyWriteStateTag(state_key)
       else:
         raise ValueError('Invalid state spec: %s' % state_spec)
       self.state_tags[state_spec] = state_tag
@@ -225,6 +263,14 @@ def commit(self):
           for new_value in runtime_state._current_accumulator:
             state.add_state(
                 window, state_tag, state_spec.coder.encode(new_value))
+      elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
+        if runtime_state._cleared:
+          state.clear_state(window, state_tag)
+
+        if runtime_state.is_modified():
+          state.add_state(window,
+                          state_tag,
+                          state_spec.coder.encode(runtime_state._value))
       else:
         raise ValueError('Invalid state spec: %s' % state_spec)
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner.py b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
index b0399214c7ee7..1633e95d5e31f 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner.py
@@ -938,7 +938,9 @@ def blocking_append(self, state_key, data):
 
     def blocking_clear(self, state_key):
       with self._lock:
-        del self._state[self._to_key(state_key)]
+        key = self._to_key(state_key)
+        if key in self._state:  # Clearing an unwritten key is a no-op.
+          del self._state[key]
 
     @staticmethod
     def _to_key(state_key):
diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py
index a818ec5997a06..ab712e3f07540 100644
--- a/sdks/python/apache_beam/runners/worker/bundle_processor.py
+++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py
@@ -431,6 +431,7 @@ def clear(self):
   def _commit(self):
     if self._cleared:
       self._state_handler.blocking_clear(self._state_key)
+
     if self._added_elements:
       value_coder_impl = self._value_coder.get_impl()
       out = coder_impl.create_OutputStream()
@@ -439,6 +440,47 @@ def _commit(self):
       self._state_handler.blocking_append(self._state_key, out.get())
 
 
+class SynchronousReadModifyWriteRuntimeState(
+    userstate.ReadModifyWriteRuntimeState):
+  """Single-value cell kept in a bag user state; add() overwrites."""
+
+  def __init__(self, state_handler, state_key, value_coder):
+    self._state_handler = state_handler
+    self._state_key = state_key
+    self._value_coder = value_coder
+    self._cleared = False  # clear() was called on this object.
+    self._modified = False  # add() was called; correct for 0, '' and None.
+    self._added_element = None  # Pending value, meaningful when _modified.
+
+  def read(self):
+    if self._modified:
+      return self._added_element  # A pending write wins over stored state.
+    if self._cleared:
+      return None
+    stored = _StateBackedIterable(
+        self._state_handler, self._state_key, self._value_coder)
+    return next(iter(stored), None)
+
+  def add(self, value):
+    self._modified = True  # The write itself is deferred to _commit().
+    self._added_element = value
+
+  def clear(self):
+    self._cleared = True
+    self._modified = False  # Drops any value pending from add().
+    self._added_element = None
+
+  def _commit(self):
+    # Wipe the bag before appending so it never keeps a stale first element.
+    if self._cleared or self._modified:
+      self._state_handler.blocking_clear(self._state_key)
+    if self._modified:
+      value_coder_impl = self._value_coder.get_impl()
+      out = coder_impl.create_OutputStream()
+      value_coder_impl.encode_to_stream(self._added_element, out, True)
+      self._state_handler.blocking_append(self._state_key, out.get())
+
+
 class OutputTimer(object):
   def __init__(self, key, window, receiver):
     self._key = key
@@ -502,7 +544,7 @@ def _create_state(self, state_spec, key, window):
     if isinstance(state_spec,
                   (userstate.BagStateSpec, userstate.CombiningValueStateSpec)):
       bag_state = SynchronousBagRuntimeState(
-          self._state_handler,
+          state_handler=self._state_handler,
           state_key=beam_fn_api_pb2.StateKey(
               bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                   ptransform_id=self._transform_id,
@@ -516,7 +558,17 @@ def _create_state(self, state_spec, key, window):
         return CombiningValueRuntimeState(bag_state, state_spec.combine_fn)
     elif isinstance(state_spec, userstate.SetStateSpec):
       return SynchronousSetRuntimeState(
-          self._state_handler,
+          state_handler=self._state_handler,
+          state_key=beam_fn_api_pb2.StateKey(
+              bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
+                  ptransform_id=self._transform_id,
+                  user_state_id=state_spec.name,
+                  window=self._window_coder.encode(window),
+                  key=self._key_coder.encode(key))),
+          value_coder=state_spec.coder)
+    elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
+      return SynchronousReadModifyWriteRuntimeState(
+          state_handler=self._state_handler,
           state_key=beam_fn_api_pb2.StateKey(
               bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                   ptransform_id=self._transform_id,
diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py
index dbda3cca0f860..a9923cae48dc4 100644
--- a/sdks/python/apache_beam/transforms/trigger.py
+++ b/sdks/python/apache_beam/transforms/trigger.py
@@ -103,6 +103,16 @@ def with_prefix(self, prefix):
     return _SetStateTag(prefix + self.tag)
 
 
+class _ReadModifyWriteStateTag(_StateTag):
+  """StateTag for a read-modify-write cell, identified by its tag string."""
+
+  def __repr__(self):
+    return 'ReadModifyWriteState(%s)' % self.tag
+
+  def with_prefix(self, prefix):
+    return _ReadModifyWriteStateTag(prefix + self.tag)
+
+
 class _CombiningValueStateTag(_StateTag):
   """StateTag pointing to an element, accumulated with a combiner.
 
@@ -865,11 +875,13 @@ def add_state(self, window, tag, value):
   def get_state(self, window, tag):
     if isinstance(tag, _CombiningValueStateTag):
       original_tag, tag = tag, tag.without_extraction()
+
     values = [self.raw_state.get_state(window_id, tag)
               for window_id in self._get_ids(window)]
-    if isinstance(tag, _ValueStateTag):
-      raise ValueError(
-          'Merging requested for non-mergeable state tag: %r.' % tag)
+
+    if isinstance(tag, _ReadModifyWriteStateTag):
+      raise ValueError("ReadModifyWriteStateTag is not allowed for"
+                       " merging windows")
     elif isinstance(tag, _CombiningValueStateTag):
       return original_tag.combine_fn.extract_output(
           original_tag.combine_fn.merge_accumulators(values))
@@ -1231,7 +1243,7 @@ def get_window(self, window_id):
   def add_state(self, window, tag, value):
     if self.defensive_copy:
       value = copy.deepcopy(value)
-    if isinstance(tag, _ValueStateTag):
+    if isinstance(tag, _ReadModifyWriteStateTag):
       self.state[window][tag.tag] = value
     elif isinstance(tag, _CombiningValueStateTag):
       # TODO(robertwb): Store merged accumulators.
@@ -1247,7 +1259,9 @@ def add_state(self, window, tag, value):
 
   def get_state(self, window, tag):
     values = self.state[window][tag.tag]
-    if isinstance(tag, _ValueStateTag):
+    if isinstance(tag, _ReadModifyWriteStateTag):
+      # add_state stores this tag's value directly (not in a list), so
+      # return it as-is.
       return values
     elif isinstance(tag, _CombiningValueStateTag):
       return tag.combine_fn.apply(values)
diff --git a/sdks/python/apache_beam/transforms/userstate.py b/sdks/python/apache_beam/transforms/userstate.py
index 4d7126e9597f1..d3d0a4a91a6da 100644
--- a/sdks/python/apache_beam/transforms/userstate.py
+++ b/sdks/python/apache_beam/transforms/userstate.py
@@ -76,6 +76,31 @@ def to_runner_api(self, context):
             element_coder_id=context.coders.get_id(self.coder)))
 
 
+class ReadModifyWriteStateSpec(StateSpec):
+  """Specification for a user DoFn read-modify-write state cell."""
+
+  def __init__(self, name, coder):
+    """Validates and stores the cell's name and value coder.
+
+    Args:
+      name (str): The name by which the state is identified.
+      coder (Coder): Coder specifying how to encode the value.
+
+    Raises:
+      TypeError: If name is not a str or coder is not a Coder.
+    """
+    if not isinstance(name, str):
+      raise TypeError("ReadModifyWriteState name is not a string")
+    if not isinstance(coder, Coder):
+      raise TypeError("ReadModifyWriteState coder is not of type Coder")
+    self.name, self.coder = name, coder
+
+  def to_runner_api(self, context):
+    return beam_runner_api_pb2.StateSpec(
+        read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec(
+            coder_id=context.coders.get_id(self.coder)))
+
+
 class CombiningValueStateSpec(StateSpec):
   """Specification for a user DoFn combining value state cell."""
 
@@ -267,6 +292,7 @@ def set(self, timestamp):
 
 class RuntimeState(object):
   """State interface object passed to user code."""
+
   def prefetch(self):
     # The default implementation here does nothing.
     pass
@@ -291,6 +317,10 @@ class SetRuntimeState(AccumulatingRuntimeState):
   """Set state interface object passed to user code."""
 
 
+class ReadModifyWriteRuntimeState(AccumulatingRuntimeState):
+  """Read-modify-write state interface object passed to user code."""
+
+
 class CombiningValueRuntimeState(AccumulatingRuntimeState):
   """Combining value state interface object passed to user code."""
 
diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py
index 8e55ceead8532..7c95c668d52f5 100644
--- a/sdks/python/apache_beam/transforms/userstate_test.py
+++ b/sdks/python/apache_beam/transforms/userstate_test.py
@@ -42,6 +42,7 @@
 from apache_beam.transforms.timeutil import TimeDomain
 from apache_beam.transforms.userstate import BagStateSpec
 from apache_beam.transforms.userstate import CombiningValueStateSpec
+from apache_beam.transforms.userstate import ReadModifyWriteStateSpec
 from apache_beam.transforms.userstate import SetStateSpec
 from apache_beam.transforms.userstate import TimerSpec
 from apache_beam.transforms.userstate import get_dofn_specs
@@ -127,6 +128,15 @@ def test_spec_construction(self):
     with self.assertRaises(ValueError):
       DoFn.TimerParam(BagStateSpec('elements', BytesCoder()))
 
+    # ReadModifyWriteStateSpec must reject a bad name or coder with TypeError.
+    with self.assertRaises(TypeError):
+      # invalid state name
+      ReadModifyWriteStateSpec(123, BytesCoder())
+
+    with self.assertRaises(TypeError):
+      # invalid coder
+      ReadModifyWriteStateSpec('value', object())
+
     TimerSpec('timer', TimeDomain.WATERMARK)
     TimerSpec('timer', TimeDomain.REAL_TIME)
     with self.assertRaises(ValueError):
@@ -345,8 +355,8 @@ def process(self, element):
 
     return RecordDoFn()
 
-  def test_simple_stateful_dofn(self):
-    class SimpleTestStatefulDoFn(DoFn):
+  def test_simple_bagstate(self):
+    class SimpleTestBagStateDoFn(DoFn):
       BUFFER_STATE = BagStateSpec('buffer', BytesCoder())
       EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
 
@@ -371,7 +381,7 @@ def expiry_callback(self, buffer=DoFn.StateParam(BUFFER_STATE),
       (p
        | test_stream
        | beam.Map(lambda x: ('mykey', x))
-       | beam.ParDo(SimpleTestStatefulDoFn())
+       | beam.ParDo(SimpleTestBagStateDoFn())
        | beam.ParDo(self.record_dofn()))
 
     # Two firings should occur: once after element 3 since the timer should
@@ -569,6 +579,151 @@ def emit_values(self, set_state=beam.DoFn.StateParam(SET_STATE)):
     result = p.run()
     result.wait_until_finish()
 
+  def test_simple_read_modify_write_state(self):
+    """Expects the expiry timer to emit 3, the last value added."""
+    class SimpleReadModifyWriteStatefulDoFn(beam.DoFn):
+      READ_MODIFY_WRITE = ReadModifyWriteStateSpec('buffer', VarIntCoder())
+      EXPIRY_TIMER = TimerSpec('expiry', TimeDomain.WATERMARK)
+
+      def process(self,
+                  element,
+                  buffer=DoFn.StateParam(READ_MODIFY_WRITE),
+                  timer=DoFn.TimerParam(EXPIRY_TIMER)):
+        _, value = element
+        buffer.add(value)
+        timer.set(20)
+
+      @on_timer(EXPIRY_TIMER)
+      def expiry_callback(self,
+                          buffer=DoFn.StateParam(READ_MODIFY_WRITE)):
+        yield buffer.read()
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements([('key', 1)])
+                     .add_elements([('key', 2)])
+                     .advance_watermark_to(15)
+                     .add_elements([('key', 3)])
+                     .advance_watermark_to(20))
+
+      (p
+       | test_stream
+       | beam.ParDo(SimpleReadModifyWriteStatefulDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    self.assertEqual([3], StatefulDoFnOnDirectRunnerTest.all_records)
+
+  def test_simple_read_modify_write_state_clear(self):
+    """Asserts that nothing is recorded when both timers only clear the cell."""
+    class SimpleReadModifyWriteStatefulCleanDoFn(beam.DoFn):
+      READ_MODIFY_WRITE = ReadModifyWriteStateSpec('buffer', StrUtf8Coder())
+      CLEAR_TIMER = TimerSpec('clear_timer', TimeDomain.WATERMARK)
+      EXPIRY_TIMER = TimerSpec('expiry_timer', TimeDomain.WATERMARK)
+
+      def process(self,
+                  element,
+                  buffer=DoFn.StateParam(READ_MODIFY_WRITE),
+                  expire_timer=DoFn.TimerParam(EXPIRY_TIMER),
+                  clear_timer=DoFn.TimerParam(CLEAR_TIMER)):
+        _, value = element
+        buffer.add(value)
+        clear_timer.set(20)
+        expire_timer.set(40)
+
+      @on_timer(EXPIRY_TIMER)
+      def expiry_callback(self,
+                          buffer=DoFn.StateParam(READ_MODIFY_WRITE)):
+        yield buffer.clear()  # NOTE(review): read() intended? confirm
+
+      @on_timer(CLEAR_TIMER)
+      def clear_callback(self,
+                         buffer=DoFn.StateParam(READ_MODIFY_WRITE)):
+        buffer.clear()
+
+    with TestPipeline() as p:
+      test_stream = (TestStream()
+                     .advance_watermark_to(10)
+                     .add_elements([('testkey', '1')])
+                     .advance_watermark_to(15)
+                     .add_elements([('testkey', '3')])
+                     .advance_watermark_to(20))
+
+      (p
+       | test_stream
+       | beam.ParDo(SimpleReadModifyWriteStatefulCleanDoFn())
+       | beam.ParDo(self.record_dofn()))
+
+    self.assertTrue(not StatefulDoFnOnDirectRunnerTest.all_records)
+
+  def test_stateful_read_modify_write_state_portably(self):
+    """Expects the emit timer to output 3, the last of the values added."""
+    class ReadModifyWriteStatefulDoFn(beam.DoFn):
+      # Every element overwrites the cell and (re)sets the emit timer.
+      BUFFER_STATE = ReadModifyWriteStateSpec('buffer', VarIntCoder())
+      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
+
+      def process(self,
+                  element,
+                  buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
+                  emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
+        _, value = element
+        buffer_state.add(value)
+        emit_timer.set(3)
+
+      @on_timer(EMIT_TIMER)
+      def emit_values(self, buffer_state=beam.DoFn.StateParam(BUFFER_STATE)):
+        yield buffer_state.read()
+
+    p = TestPipeline()
+    values = p | beam.Create([('key', 1),
+                              ('key', 2),
+                              ('key', 3)])
+    actual_values = (values
+                     | beam.Map(lambda t: window.TimestampedValue(t, t[1]))
+                     | beam.ParDo(ReadModifyWriteStatefulDoFn()))
+
+    assert_that(actual_values, equal_to([3]))
+
+    result = p.run()
+    result.wait_until_finish()
+
+  def test_stateful_read_modify_write_state_clean_portably(self):
+    """Expects the emit timer to output None after the cell was cleared."""
+    class ReadModifyWriteStatefulDoFn(beam.DoFn):
+      # Every element overwrites the cell; the element 3 then clears it.
+      BUFFER_STATE = ReadModifyWriteStateSpec('buffer', VarIntCoder())
+      EMIT_TIMER = TimerSpec('emit_timer', TimeDomain.WATERMARK)
+
+      def process(self,
+                  element,
+                  buffer_state=beam.DoFn.StateParam(BUFFER_STATE),
+                  emit_timer=beam.DoFn.TimerParam(EMIT_TIMER)):
+        _, value = element
+        buffer_state.add(value)
+
+        if value == 3:
+          buffer_state.clear()
+
+        emit_timer.set(3)
+
+      @on_timer(EMIT_TIMER)
+      def emit_value(self, buffer_state=beam.DoFn.StateParam(BUFFER_STATE)):
+        yield buffer_state.read()
+
+    p = TestPipeline()
+    values = p | beam.Create([('key', 1),
+                              ('key', 2),
+                              ('key', 3)])
+    actual_values = (values
+                     | beam.Map(lambda t: window.TimestampedValue(t, t[1]))
+                     | beam.ParDo(ReadModifyWriteStatefulDoFn()))
+
+    assert_that(actual_values, equal_to([None]))
+
+    result = p.run()
+    result.wait_until_finish()
+
   def test_stateful_dofn_nonkeyed_input(self):
     p = TestPipeline()
     values = p | beam.Create([1, 2, 3])