Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PartialSerialisedMessage test. #27452

Merged
merged 8 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ public void accept(final String messageString, final Integer sizeInBytes) throws
*/
deserializeAirbyteMessage(messageString)
.ifPresent(message -> {
if (message.getType() == Type.RECORD) {
if (message.getType().equals(Type.RECORD)) {
validateRecord(message);
}

bufferEnqueue.addRecord(message, sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES);
});
}
Expand All @@ -121,18 +120,28 @@ public void accept(final String messageString, final Integer sizeInBytes) throws
* Deserializes to a {@link PartialAirbyteMessage} which can represent both a Record or a State
* Message
*
* PartialAirbyteMessage holds either:
* <li>entire serialized message string when message is a valid State Message
* <li>serialized AirbyteRecordMessage when message is a valid Record Message</li>
*
* @param messageString the string to deserialize
* @return PartialAirbyteMessage if the message is valid, empty otherwise
*/
private Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String messageString) {
@VisibleForTesting
public static Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String messageString) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should add tests for this since this is now a critical piece

// TODO: (ryankfu) plumb in the serialized AirbyteStateMessage to match AirbyteRecordMessage and
// code parity
final Optional<PartialAirbyteMessage> messageOptional = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class)
.map(partial -> {
if (partial.getRecord() != null) {
if (partial.getType().equals(Type.RECORD) && partial.getRecord().getData() != null) {
return partial.withSerialized(partial.getRecord().getData().toString());
} else {
} else if (partial.getType().equals(Type.STATE)) {
return partial.withSerialized(messageString);
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davinchia added another case where I believe we can have bad data which is when messageString is neither a STATE or RECORD message. In this case, this should not be passed to the consumer. I haven't personally seen this happen but this now more closely matches the description that PartialAirbyteMessage should only contain either STATE or RECORD. Lmk if you think this is unnecessary. Not sure if the platform filters out messages

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot guarantee messages arrive at the Destination ungarbled, so we definitely need to consider this case.

Less decided whether we should fail fast or log an error.. I think we should do the simplest for now and throw a RTE to make sure we don't drop data.

I'm wondering if this happens in practice somehow. We should add logging + test this out before committing this code.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I'll add an RTE here, add some logging, publish a new pre-release version, and run it overnight with some resets to see if this ever gets triggered but agree we should fail fast here

return null;
}
});

if (messageOptional.isPresent()) {
return messageOptional;
} else {
Expand All @@ -143,6 +152,7 @@ private Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String m
return Optional.empty();
}
}

}

/**
Expand All @@ -151,7 +161,8 @@ private Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String m
* @param input a JSON string that represents an {@link AirbyteMessage}.
* @return {@code true} if the message is a state message, {@code false} otherwise.
*/
private static boolean isStateMessage(final String input) {
@VisibleForTesting
public static boolean isStateMessage(final String input) {
final Optional<AirbyteTypeMessage> deserialized = Jsons.tryDeserialize(input, AirbyteTypeMessage.class);
return deserialized.filter(airbyteTypeMessage -> airbyteTypeMessage.getType() == Type.STATE).isPresent();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.databind.JsonNode;

import java.util.Objects;

public class PartialAirbyteRecordMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.airbyte.integrations.destination_async.state.FlushFailure;
import io.airbyte.protocol.models.Field;
import io.airbyte.protocol.models.JsonSchemaType;
import io.airbyte.protocol.models.v0.AirbyteLogMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
Expand All @@ -37,13 +38,16 @@
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -74,6 +78,12 @@ class AsyncStreamConsumerTest {
Field.of("id", JsonSchemaType.NUMBER),
Field.of("name", JsonSchemaType.STRING))));

private static final JsonNode PAYLOAD = Jsons.jsonNode(Map.of(
"created_at", "2022-02-01T17:02:19+00:00",
"id", 1,
"make", "Mazda",
"nested_column", Map.of("array_column", List.of(1, 2, 3))));

private static final AirbyteMessage STATE_MESSAGE1 = new AirbyteMessage()
.withType(Type.STATE)
.withState(new AirbyteStateMessage()
Expand Down Expand Up @@ -114,7 +124,7 @@ void setup() {

@Test
void test1StreamWith1State() throws Exception {
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
final List<AirbyteMessage> expectedRecords = generateRecords(1_000);

consumer.start();
consumeRecords(consumer, expectedRecords);
Expand All @@ -130,7 +140,7 @@ void test1StreamWith1State() throws Exception {

@Test
void test1StreamWith2State() throws Exception {
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
final List<AirbyteMessage> expectedRecords = generateRecords(1_000);

consumer.start();
consumeRecords(consumer, expectedRecords);
Expand All @@ -147,21 +157,20 @@ void test1StreamWith2State() throws Exception {

@Test
void test1StreamWith0State() throws Exception {
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
final List<AirbyteMessage> allRecords = generateRecords(1_000);

consumer.start();
consumeRecords(consumer, expectedRecords);
consumeRecords(consumer, allRecords);
consumer.close();

verifyStartAndClose();

verifyRecords(STREAM_NAME, SCHEMA_NAME, expectedRecords);
verifyRecords(STREAM_NAME, SCHEMA_NAME, allRecords);
}

@Test
void testShouldBlockWhenQueuesAreFull() throws Exception {
consumer.start();

}

/*
Expand Down Expand Up @@ -215,6 +224,62 @@ void testBackPressure() throws Exception {
assertTrue(recordCount.get() < 1000, String.format("Record count was %s", recordCount.get()));
}

@Test
void deserializeAirbyteMessageWithAirbyteRecord() {
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME)
.withData(PAYLOAD));
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
final String airbyteRecordString = Jsons.serialize(PAYLOAD);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(airbyteRecordString, partial.get().getSerialized());
}

@Test
void deserializeAirbyteMessageWithEmptyAirbyteRecord() {
final Map emptyMap = Map.of();
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME)
.withData(Jsons.jsonNode(emptyMap)));
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(emptyMap.toString(), partial.get().getSerialized());
}

@Test
void deserializeAirbyteMessageWithNoStateOrRecord() {
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.LOG)
.withLog(new AirbyteLogMessage());
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertTrue(partial.isEmpty());
}

@Test
void deserializeAirbyteMessageWithAirbyteState() {
final String serializedAirbyteMessage = Jsons.serialize(STATE_MESSAGE1);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(serializedAirbyteMessage, partial.get().getSerialized());
}

@Test
void deserializeAirbyteMessageWithBadAirbyteState() {
final AirbyteMessage badState = new AirbyteMessage()
.withState(new AirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState().withStreamDescriptor(STREAM1_DESC).withStreamState(Jsons.jsonNode(1))));
final String serializedAirbyteMessage = Jsons.serialize(badState);
final boolean isState = AsyncStreamConsumer.isStateMessage(serializedAirbyteMessage);
assertFalse(isState);
}

@Nested
class ErrorHandling {

Expand Down Expand Up @@ -246,10 +311,10 @@ void testErrorOnClose() throws Exception {

}

private static void consumeRecords(final AsyncStreamConsumer consumer, final Collection<PartialAirbyteMessage> records) {
private static void consumeRecords(final AsyncStreamConsumer consumer, final Collection<AirbyteMessage> records) {
records.forEach(m -> {
try {
consumer.accept(m.getSerialized(), RECORD_SIZE_20_BYTES);
consumer.accept(Jsons.serialize(m), RECORD_SIZE_20_BYTES);
} catch (final Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -258,25 +323,20 @@ private static void consumeRecords(final AsyncStreamConsumer consumer, final Col

// NOTE: Generates records at chunks of 160 bytes
@SuppressWarnings("SameParameterValue")
private static List<PartialAirbyteMessage> generateRecords(final long targetSizeInBytes) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this was previously doing too much.

simplified to only generate AirbyteMessages and split out the logic to generate expected records into the verifyRecords function

final List<PartialAirbyteMessage> output = Lists.newArrayList();
private static List<AirbyteMessage> generateRecords(final long targetSizeInBytes) {
final List<AirbyteMessage> output = Lists.newArrayList();
long bytesCounter = 0;
for (int i = 0;; i++) {
final JsonNode payload =
Jsons.jsonNode(ImmutableMap.of("id", RandomStringUtils.randomAlphabetic(7), "name", "human " + String.format("%8d", i)));
final long sizeInBytes = RecordSizeEstimator.getStringByteSize(payload);
bytesCounter += sizeInBytes;
final PartialAirbyteMessage airbyteMessage = new PartialAirbyteMessage()
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new PartialAirbyteRecordMessage()
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME))
.withSerialized(Jsons.serialize(new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME)
.withData(payload))));
.withNamespace(SCHEMA_NAME)
.withData(payload));
if (bytesCounter > targetSizeInBytes) {
break;
} else {
Expand All @@ -292,7 +352,7 @@ private void verifyStartAndClose() throws Exception {
}

@SuppressWarnings({"unchecked", "SameParameterValue"})
private void verifyRecords(final String streamName, final String namespace, final Collection<PartialAirbyteMessage> expectedRecords)
private void verifyRecords(final String streamName, final String namespace, final List<AirbyteMessage> allRecords)
throws Exception {
final ArgumentCaptor<Stream<PartialAirbyteMessage>> argumentCaptor = ArgumentCaptor.forClass(Stream.class);
verify(flushFunction, atLeast(1)).flush(
Expand All @@ -306,7 +366,15 @@ private void verifyRecords(final String streamName, final String namespace, fina
// flatten those results into a single list for the simplicity of comparison
.flatMap(s -> s)
.toList();
assertEquals(expectedRecords.stream().toList(), actualRecords);

final var expRecords = allRecords.stream().map(m -> new PartialAirbyteMessage()
.withType(Type.RECORD)
.withRecord(new PartialAirbyteRecordMessage()
.withStream(m.getRecord().getStream())
.withNamespace(m.getRecord().getNamespace())
.withData(m.getRecord().getData()))
.withSerialized(Jsons.serialize(m.getRecord().getData()))).collect(Collectors.toList());
assertEquals(expRecords, actualRecords);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination_async;

import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStreamState;
import io.airbyte.protocol.models.StreamDescriptor;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
import java.time.Instant;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class PartialAirbyteMessageTest {

@Test
void testDeserializeRecord() {
final long emittedAt = Instant.now().toEpochMilli();
final var serializedRec = Jsons.serialize(new AirbyteMessage()
.withType(AirbyteMessage.Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream("users")
.withNamespace("public")
.withEmittedAt(emittedAt)
.withData(Jsons.jsonNode("data"))));

final var rec = Jsons.tryDeserialize(serializedRec, PartialAirbyteMessage.class).get();
Assertions.assertEquals(AirbyteMessage.Type.RECORD, rec.getType());
Assertions.assertEquals("users", rec.getRecord().getStream());
Assertions.assertEquals("public", rec.getRecord().getNamespace());
Assertions.assertEquals("\"data\"", rec.getRecord().getData().toString());
Assertions.assertEquals(emittedAt, rec.getRecord().getEmittedAt());
}

@Test
void testDeserializeState() {
final var serializedState = Jsons.serialize(new io.airbyte.protocol.models.AirbyteMessage()
.withType(io.airbyte.protocol.models.AirbyteMessage.Type.STATE)
.withState(new AirbyteStateMessage().withStream(
new AirbyteStreamState().withStreamDescriptor(
new StreamDescriptor().withName("user").withNamespace("public"))
.withStreamState(Jsons.jsonNode("data")))
.withType(AirbyteStateMessage.AirbyteStateType.STREAM)));

final var rec = Jsons.tryDeserialize(serializedState, PartialAirbyteMessage.class).get();
Assertions.assertEquals(AirbyteMessage.Type.STATE, rec.getType());

final var streamDesc = rec.getState().getStream().getStreamDescriptor();
Assertions.assertEquals("user", streamDesc.getName());
Assertions.assertEquals("public", streamDesc.getNamespace());
Assertions.assertEquals(io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType.STREAM, rec.getState().getType());
}

@Test
void testGarbage() {
final var badSerialization = "messed up data";

final var rec = Jsons.tryDeserialize(badSerialization, PartialAirbyteMessage.class);
Assertions.assertTrue(rec.isEmpty());
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ public SerializedAirbyteMessageConsumer getSerializedMessageConsumer(final JsonN
log.info("destination class: {}", getClass());
// this is how we toggle async snowflake on.
// final boolean useAsyncSnowflake = false;
final boolean useAsyncSnowflake = config.has("loading_method")
&& config.get("loading_method").has("method")
&& config.get("loading_method").get("method").asText().equals("Internal Staging")
// Standard is when someone doesn't specify a loading method but still DestinationType.INTERNAL_STAGING
final boolean useAsyncSnowflake = config.has("loading_method")
&& config.get("loading_method").has("method")
&& config.get("loading_method").get("method").asText().equals("Internal Staging")
// Standard is when someone doesn't specify a loading method but still
// DestinationType.INTERNAL_STAGING
|| config.get("loading_method").get("method").asText().equals("Standard");

log.info("using async snowflake: {}", useAsyncSnowflake);
Expand Down
Loading