Skip to content
This repository has been archived by the owner on Nov 11, 2022. It is now read-only.

Fix InProcessPipelineRunner to handle a null subscription #547

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -30,6 +30,7 @@
import com.google.cloud.dataflow.sdk.options.DataflowPipelineOptions;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.options.ValueProvider;
import com.google.cloud.dataflow.sdk.options.ValueProvider.StaticValueProvider;
import com.google.cloud.dataflow.sdk.transforms.Aggregator;
import com.google.cloud.dataflow.sdk.transforms.Combine;
import com.google.cloud.dataflow.sdk.transforms.DoFn;
Expand Down Expand Up @@ -1290,6 +1291,7 @@ public String getIdLabel() {

@Override
public PCollection<T> apply(PBegin input) {
ValueProvider<SubscriptionPath> subscriptionPath = subscription;
if (subscription == null) {
try {
try (PubsubClient pubsubClient =
Expand All @@ -1299,9 +1301,8 @@ public PCollection<T> apply(PBegin input) {
.as(DataflowPipelineOptions.class))) {
checkState(project.isAccessible(), "createRandomSubscription must be called at runtime.");
checkState(topic.isAccessible(), "createRandomSubscription must be called at runtime.");
SubscriptionPath subscriptionPath =
pubsubClient.createRandomSubscription(
project.get(), topic.get(), DEAULT_ACK_TIMEOUT_SEC);
subscriptionPath = StaticValueProvider.of(pubsubClient.createRandomSubscription(
project.get(), topic.get(), DEAULT_ACK_TIMEOUT_SEC));
LOG.warn("Created subscription {} to topic {}."
+ " Note this subscription WILL NOT be deleted when the pipeline terminates",
subscription, topic);
Expand All @@ -1314,7 +1315,7 @@ public PCollection<T> apply(PBegin input) {
return input.getPipeline().begin()
.apply(Read.from(new PubsubSource<T>(this)))
.apply(ParDo.named("PubsubUnboundedSource.Stats")
.of(new StatsFn<T>(pubsubFactory, subscription,
timestampLabel, idLabel)));
.of(new StatsFn<T>(pubsubFactory, checkNotNull(subscriptionPath),
timestampLabel, idLabel)));
}
}
Expand Up @@ -107,6 +107,11 @@ private static class State {
*/
@Nullable
Map<String, Long> ackDeadline;

/**
* Whether a subscription has been created.
*/
boolean createdSubscription;
}

private static final State STATE = new State();
Expand All @@ -124,12 +129,40 @@ public static PubsubTestClientFactory createFactoryForPublish(
final TopicPath expectedTopic,
final Iterable<OutgoingMessage> expectedOutgoingMessages,
final Iterable<OutgoingMessage> failingOutgoingMessages) {
return createFactoryForPublishInternal(
expectedTopic, expectedOutgoingMessages, failingOutgoingMessages, false);
}

/**
* Return a factory for testing publishers. Only one factory may be in-flight at a time.
* The factory must be closed when the test is complete, at which point final validation will
* occur. Additionally, verify that createSubscription was called.
*/
public static PubsubTestClientFactory createFactoryForPublishVerifySubscription(
final TopicPath expectedTopic,
final Iterable<OutgoingMessage> expectedOutgoingMessages,
final Iterable<OutgoingMessage> failingOutgoingMessages) {
return createFactoryForPublishInternal(
expectedTopic, expectedOutgoingMessages, failingOutgoingMessages, true);
}

/**
* Return a factory for testing publishers. Only one factory may be in-flight at a time.
* The factory must be closed when the test is complete, at which point final validation will
* occur.
*/
public static PubsubTestClientFactory createFactoryForPublishInternal(
final TopicPath expectedTopic,
final Iterable<OutgoingMessage> expectedOutgoingMessages,
final Iterable<OutgoingMessage> failingOutgoingMessages,
final boolean verifySubscriptionCreated) {
synchronized (STATE) {
checkState(!STATE.isActive, "Test still in flight");
STATE.expectedTopic = expectedTopic;
STATE.remainingExpectedOutgoingMessages = Sets.newHashSet(expectedOutgoingMessages);
STATE.remainingFailingOutgoingMessages = Sets.newHashSet(failingOutgoingMessages);
STATE.isActive = true;
STATE.createdSubscription = false;
}
return new PubsubTestClientFactory() {
@Override
Expand All @@ -148,6 +181,9 @@ public String getKind() {
@Override
public void close() {
synchronized (STATE) {
if (verifySubscriptionCreated) {
checkState(STATE.createdSubscription, "Did not call create subscription");
}
checkState(STATE.isActive, "No test still in flight");
checkState(STATE.remainingExpectedOutgoingMessages.isEmpty(),
"Still waiting for %s messages to be published",
Expand Down Expand Up @@ -372,7 +408,10 @@ public List<TopicPath> listTopics(ProjectPath project) throws IOException {
@Override
public void createSubscription(
TopicPath topic, SubscriptionPath subscription, int ackDeadlineSeconds) throws IOException {
throw new UnsupportedOperationException();
synchronized (STATE) {
STATE.createdSubscription = true;
}
return;
}

@Override
Expand Down
Expand Up @@ -36,7 +36,10 @@
import com.google.cloud.dataflow.sdk.util.CoderUtils;
import com.google.cloud.dataflow.sdk.util.PubsubClient;
import com.google.cloud.dataflow.sdk.util.PubsubClient.IncomingMessage;
import com.google.cloud.dataflow.sdk.util.PubsubClient.OutgoingMessage;
import com.google.cloud.dataflow.sdk.util.PubsubClient.ProjectPath;
import com.google.cloud.dataflow.sdk.util.PubsubClient.SubscriptionPath;
import com.google.cloud.dataflow.sdk.util.PubsubClient.TopicPath;
import com.google.cloud.dataflow.sdk.util.PubsubTestClient;
import com.google.cloud.dataflow.sdk.util.PubsubTestClient.PubsubTestClientFactory;

Expand All @@ -60,8 +63,12 @@
*/
@RunWith(JUnit4.class)
public class PubsubUnboundedSourceTest {
private static final ProjectPath PROJECT =
PubsubClient.projectPathFromId("testProject");
private static final SubscriptionPath SUBSCRIPTION =
PubsubClient.subscriptionPathFromName("testProject", "testSubscription");
private static final TopicPath TOPIC =
PubsubClient.topicPathFromName("testProject", "testTopic");
private static final String DATA = "testData";
private static final long TIMESTAMP = 1234L;
private static final long REQ_TIME = 6373L;
Expand Down Expand Up @@ -320,4 +327,14 @@ public void readManyMessages() throws IOException {
assertTrue(dataToMessageNum.isEmpty());
reader.close();
}

@Test
public void testNullSubscription() throws Exception {
factory = PubsubTestClient.createFactoryForPublishVerifySubscription(
TOPIC, ImmutableList.<OutgoingMessage>of(), ImmutableList.<OutgoingMessage>of());
TestPipeline p = TestPipeline.create();
p.apply(new PubsubUnboundedSource<>(
clock, factory, StaticValueProvider.of(PROJECT), StaticValueProvider.of(TOPIC),
null, StringUtf8Coder.of(), TIMESTAMP_LABEL, ID_LABEL));
}
}