Skip to content

Commit

Permalink
Add PartialSerialisedMessage test. (#27452)
Browse files Browse the repository at this point in the history
* Test deserialise.

* Add tests.

* Simplify and fix tests.

* Format.

* Adds tests for deserializeAirbyteMessage

* Adds tests for deserializeAirbyteMessage with bad data

* Cleans up deserializeAirbyteMessage and throws Exception when invalid message

* More code cleanup

---------

Co-authored-by: ryankfu <ryan.fu@airbyte.io>
  • Loading branch information
davinchia and ryankfu committed Jun 21, 2023
1 parent f787371 commit 0dc8f16
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

package io.airbyte.integrations.destination_async;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import io.airbyte.commons.json.Jsons;
Expand Down Expand Up @@ -109,10 +107,9 @@ public void accept(final String messageString, final Integer sizeInBytes) throws
*/
deserializeAirbyteMessage(messageString)
.ifPresent(message -> {
if (message.getType() == Type.RECORD) {
if (message.getType().equals(Type.RECORD)) {
validateRecord(message);
}

bufferEnqueue.addRecord(message, sizeInBytes + PARTIAL_DESERIALIZE_REF_BYTES);
});
}
Expand All @@ -121,61 +118,32 @@ public void accept(final String messageString, final Integer sizeInBytes) throws
* Deserializes to a {@link PartialAirbyteMessage} which can represent both a Record or a State
* Message
*
* PartialAirbyteMessage holds either:
* <li>entire serialized message string when message is a valid State Message
* <li>serialized AirbyteRecordMessage when message is a valid Record Message</li>
*
* @param messageString the string to deserialize
* @return PartialAirbyteMessage if the message is valid, empty otherwise
*/
private Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String messageString) {
@VisibleForTesting
public static Optional<PartialAirbyteMessage> deserializeAirbyteMessage(final String messageString) {
// TODO: (ryankfu) plumb in the serialized AirbyteStateMessage to match AirbyteRecordMessage code
// parity
final Optional<PartialAirbyteMessage> messageOptional = Jsons.tryDeserialize(messageString, PartialAirbyteMessage.class)
.map(partial -> {
if (partial.getRecord() != null) {
if (partial.getType().equals(Type.RECORD) && partial.getRecord().getData() != null) {
return partial.withSerialized(partial.getRecord().getData().toString());
} else {
} else if (partial.getType().equals(Type.STATE)) {
return partial.withSerialized(messageString);
} else {
return null;
}
});

if (messageOptional.isPresent()) {
return messageOptional;
} else {
if (isStateMessage(messageString)) {
throw new IllegalStateException("Invalid state message: " + messageString);
} else {
LOGGER.error("Received invalid message: " + messageString);
return Optional.empty();
}
}
}

/**
* Tests whether the provided JSON string represents a state message.
*
* @param input a JSON string that represents an {@link AirbyteMessage}.
* @return {@code true} if the message is a state message, {@code false} otherwise.
*/
private static boolean isStateMessage(final String input) {
final Optional<AirbyteTypeMessage> deserialized = Jsons.tryDeserialize(input, AirbyteTypeMessage.class);
return deserialized.filter(airbyteTypeMessage -> airbyteTypeMessage.getType() == Type.STATE).isPresent();
}

/**
* Custom class that can be used to parse a JSON message to determine the type of the represented
* {@link AirbyteMessage}.
*/
private static class AirbyteTypeMessage {

@JsonProperty("type")
@JsonPropertyDescription("Message type")
private AirbyteMessage.Type type;

@JsonProperty("type")
public AirbyteMessage.Type getType() {
return type;
}

@JsonProperty("type")
public void setType(final AirbyteMessage.Type type) {
this.type = type;
}

throw new RuntimeException(String.format("Invalid state message: %s", messageString));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
import com.fasterxml.jackson.databind.JsonNode;

import java.util.Objects;

public class PartialAirbyteRecordMessage {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.airbyte.integrations.destination_async.state.FlushFailure;
import io.airbyte.protocol.models.Field;
import io.airbyte.protocol.models.JsonSchemaType;
import io.airbyte.protocol.models.v0.AirbyteLogMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteMessage.Type;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
Expand All @@ -37,13 +38,16 @@
import java.time.Instant;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang.RandomStringUtils;
import org.junit.jupiter.api.BeforeEach;
Expand Down Expand Up @@ -74,6 +78,12 @@ class AsyncStreamConsumerTest {
Field.of("id", JsonSchemaType.NUMBER),
Field.of("name", JsonSchemaType.STRING))));

private static final JsonNode PAYLOAD = Jsons.jsonNode(Map.of(
"created_at", "2022-02-01T17:02:19+00:00",
"id", 1,
"make", "Mazda",
"nested_column", Map.of("array_column", List.of(1, 2, 3))));

private static final AirbyteMessage STATE_MESSAGE1 = new AirbyteMessage()
.withType(Type.STATE)
.withState(new AirbyteStateMessage()
Expand Down Expand Up @@ -114,7 +124,7 @@ void setup() {

@Test
void test1StreamWith1State() throws Exception {
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
final List<AirbyteMessage> expectedRecords = generateRecords(1_000);

consumer.start();
consumeRecords(consumer, expectedRecords);
Expand All @@ -130,7 +140,7 @@ void test1StreamWith1State() throws Exception {

@Test
void test1StreamWith2State() throws Exception {
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
final List<AirbyteMessage> expectedRecords = generateRecords(1_000);

consumer.start();
consumeRecords(consumer, expectedRecords);
Expand All @@ -147,21 +157,20 @@ void test1StreamWith2State() throws Exception {

@Test
void test1StreamWith0State() throws Exception {
final List<PartialAirbyteMessage> expectedRecords = generateRecords(1_000);
final List<AirbyteMessage> allRecords = generateRecords(1_000);

consumer.start();
consumeRecords(consumer, expectedRecords);
consumeRecords(consumer, allRecords);
consumer.close();

verifyStartAndClose();

verifyRecords(STREAM_NAME, SCHEMA_NAME, expectedRecords);
verifyRecords(STREAM_NAME, SCHEMA_NAME, allRecords);
}

@Test
void testShouldBlockWhenQueuesAreFull() throws Exception {
consumer.start();

}

/*
Expand Down Expand Up @@ -215,6 +224,60 @@ void testBackPressure() throws Exception {
assertTrue(recordCount.get() < 1000, String.format("Record count was %s", recordCount.get()));
}

@Test
void deserializeAirbyteMessageWithAirbyteRecord() {
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME)
.withData(PAYLOAD));
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
final String airbyteRecordString = Jsons.serialize(PAYLOAD);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(airbyteRecordString, partial.get().getSerialized());
}

@Test
void deserializeAirbyteMessageWithEmptyAirbyteRecord() {
final Map emptyMap = Map.of();
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME)
.withData(Jsons.jsonNode(emptyMap)));
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(emptyMap.toString(), partial.get().getSerialized());
}

@Test
void deserializeAirbyteMessageWithNoStateOrRecord() {
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.LOG)
.withLog(new AirbyteLogMessage());
final String serializedAirbyteMessage = Jsons.serialize(airbyteMessage);
assertThrows(RuntimeException.class, () -> AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage));
}

@Test
void deserializeAirbyteMessageWithAirbyteState() {
final String serializedAirbyteMessage = Jsons.serialize(STATE_MESSAGE1);
final Optional<PartialAirbyteMessage> partial = AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage);
assertEquals(serializedAirbyteMessage, partial.get().getSerialized());
}

@Test
void deserializeAirbyteMessageWithBadAirbyteState() {
final AirbyteMessage badState = new AirbyteMessage()
.withState(new AirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState().withStreamDescriptor(STREAM1_DESC).withStreamState(Jsons.jsonNode(1))));
final String serializedAirbyteMessage = Jsons.serialize(badState);
assertThrows(RuntimeException.class, () -> AsyncStreamConsumer.deserializeAirbyteMessage(serializedAirbyteMessage));
}

@Nested
class ErrorHandling {

Expand Down Expand Up @@ -246,10 +309,10 @@ void testErrorOnClose() throws Exception {

}

private static void consumeRecords(final AsyncStreamConsumer consumer, final Collection<PartialAirbyteMessage> records) {
private static void consumeRecords(final AsyncStreamConsumer consumer, final Collection<AirbyteMessage> records) {
records.forEach(m -> {
try {
consumer.accept(m.getSerialized(), RECORD_SIZE_20_BYTES);
consumer.accept(Jsons.serialize(m), RECORD_SIZE_20_BYTES);
} catch (final Exception e) {
throw new RuntimeException(e);
}
Expand All @@ -258,25 +321,20 @@ private static void consumeRecords(final AsyncStreamConsumer consumer, final Col

// NOTE: Generates records at chunks of 160 bytes
@SuppressWarnings("SameParameterValue")
private static List<PartialAirbyteMessage> generateRecords(final long targetSizeInBytes) {
final List<PartialAirbyteMessage> output = Lists.newArrayList();
private static List<AirbyteMessage> generateRecords(final long targetSizeInBytes) {
final List<AirbyteMessage> output = Lists.newArrayList();
long bytesCounter = 0;
for (int i = 0;; i++) {
final JsonNode payload =
Jsons.jsonNode(ImmutableMap.of("id", RandomStringUtils.randomAlphabetic(7), "name", "human " + String.format("%8d", i)));
final long sizeInBytes = RecordSizeEstimator.getStringByteSize(payload);
bytesCounter += sizeInBytes;
final PartialAirbyteMessage airbyteMessage = new PartialAirbyteMessage()
final AirbyteMessage airbyteMessage = new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new PartialAirbyteRecordMessage()
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME))
.withSerialized(Jsons.serialize(new AirbyteMessage()
.withType(Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream(STREAM_NAME)
.withNamespace(SCHEMA_NAME)
.withData(payload))));
.withNamespace(SCHEMA_NAME)
.withData(payload));
if (bytesCounter > targetSizeInBytes) {
break;
} else {
Expand All @@ -292,7 +350,7 @@ private void verifyStartAndClose() throws Exception {
}

@SuppressWarnings({"unchecked", "SameParameterValue"})
private void verifyRecords(final String streamName, final String namespace, final Collection<PartialAirbyteMessage> expectedRecords)
private void verifyRecords(final String streamName, final String namespace, final List<AirbyteMessage> allRecords)
throws Exception {
final ArgumentCaptor<Stream<PartialAirbyteMessage>> argumentCaptor = ArgumentCaptor.forClass(Stream.class);
verify(flushFunction, atLeast(1)).flush(
Expand All @@ -306,7 +364,15 @@ private void verifyRecords(final String streamName, final String namespace, fina
// flatten those results into a single list for the simplicity of comparison
.flatMap(s -> s)
.toList();
assertEquals(expectedRecords.stream().toList(), actualRecords);

final var expRecords = allRecords.stream().map(m -> new PartialAirbyteMessage()
.withType(Type.RECORD)
.withRecord(new PartialAirbyteRecordMessage()
.withStream(m.getRecord().getStream())
.withNamespace(m.getRecord().getNamespace())
.withData(m.getRecord().getData()))
.withSerialized(Jsons.serialize(m.getRecord().getData()))).collect(Collectors.toList());
assertEquals(expRecords, actualRecords);
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.integrations.destination_async;

import io.airbyte.commons.json.Jsons;
import io.airbyte.integrations.destination_async.partial_messages.PartialAirbyteMessage;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStreamState;
import io.airbyte.protocol.models.StreamDescriptor;
import io.airbyte.protocol.models.v0.AirbyteMessage;
import io.airbyte.protocol.models.v0.AirbyteRecordMessage;
import java.time.Instant;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class PartialAirbyteMessageTest {

@Test
void testDeserializeRecord() {
final long emittedAt = Instant.now().toEpochMilli();
final var serializedRec = Jsons.serialize(new AirbyteMessage()
.withType(AirbyteMessage.Type.RECORD)
.withRecord(new AirbyteRecordMessage()
.withStream("users")
.withNamespace("public")
.withEmittedAt(emittedAt)
.withData(Jsons.jsonNode("data"))));

final var rec = Jsons.tryDeserialize(serializedRec, PartialAirbyteMessage.class).get();
Assertions.assertEquals(AirbyteMessage.Type.RECORD, rec.getType());
Assertions.assertEquals("users", rec.getRecord().getStream());
Assertions.assertEquals("public", rec.getRecord().getNamespace());
Assertions.assertEquals("\"data\"", rec.getRecord().getData().toString());
Assertions.assertEquals(emittedAt, rec.getRecord().getEmittedAt());
}

@Test
void testDeserializeState() {
final var serializedState = Jsons.serialize(new io.airbyte.protocol.models.AirbyteMessage()
.withType(io.airbyte.protocol.models.AirbyteMessage.Type.STATE)
.withState(new AirbyteStateMessage().withStream(
new AirbyteStreamState().withStreamDescriptor(
new StreamDescriptor().withName("user").withNamespace("public"))
.withStreamState(Jsons.jsonNode("data")))
.withType(AirbyteStateMessage.AirbyteStateType.STREAM)));

final var rec = Jsons.tryDeserialize(serializedState, PartialAirbyteMessage.class).get();
Assertions.assertEquals(AirbyteMessage.Type.STATE, rec.getType());

final var streamDesc = rec.getState().getStream().getStreamDescriptor();
Assertions.assertEquals("user", streamDesc.getName());
Assertions.assertEquals("public", streamDesc.getNamespace());
Assertions.assertEquals(io.airbyte.protocol.models.v0.AirbyteStateMessage.AirbyteStateType.STREAM, rec.getState().getType());
}

@Test
void testGarbage() {
final var badSerialization = "messed up data";

final var rec = Jsons.tryDeserialize(badSerialization, PartialAirbyteMessage.class);
Assertions.assertTrue(rec.isEmpty());
}

}
Loading

0 comments on commit 0dc8f16

Please sign in to comment.