Skip to content

Commit

Permalink
feat(batch-processing): Support for moving non retryable msg to DLQ
Browse files Browse the repository at this point in the history
  • Loading branch information
Pankaj Agrawal committed Aug 12, 2021
1 parent 7ec6fd3 commit 242736b
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,8 @@
Class<? extends SqsMessageHandler<Object>> value();

boolean suppressException() default false;

Class<? extends Exception>[] nonRetryableExceptions() default {};

boolean deleteNonRetryableMessageFromQueue() default false;
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ public static <R> List<R> batchProcessor(final SQSEvent event,
return batchProcessor(event, false, handler);
}

@SafeVarargs
public static <R> List<R> batchProcessor(final SQSEvent event,
final Class<? extends SqsMessageHandler<R>> handler,
final Class<? extends Exception>... nonRetryableExceptions) {
return batchProcessor(event, false, handler, nonRetryableExceptions);
}

/**
* This utility method is used to processes each {@link SQSMessage} inside received {@link SQSEvent}
*
Expand Down Expand Up @@ -166,6 +173,16 @@ public static <R> List<R> batchProcessor(final SQSEvent event,
return batchProcessor(event, suppressException, handlerInstance);
}

@SafeVarargs
public static <R> List<R> batchProcessor(final SQSEvent event,
final boolean suppressException,
final Class<? extends SqsMessageHandler<R>> handler,
final Class<? extends Exception>... nonRetryableExceptions) {

SqsMessageHandler<R> handlerInstance = instantiatedHandler(handler);
return batchProcessor(event, suppressException, handlerInstance, false, nonRetryableExceptions);
}

/**
* This utility method is used to processes each {@link SQSMessage} inside received {@link SQSEvent}
*
Expand Down Expand Up @@ -199,6 +216,14 @@ public static <R> List<R> batchProcessor(final SQSEvent event,
return batchProcessor(event, false, handler);
}

@SafeVarargs
public static <R> List<R> batchProcessor(final SQSEvent event,
final SqsMessageHandler<R> handler,
final Class<? extends Exception>... nonRetryableExceptions) {
return batchProcessor(event, false, handler, false, nonRetryableExceptions);
}


/**
* This utility method is used to processes each {@link SQSMessage} inside received {@link SQSEvent}
*
Expand Down Expand Up @@ -229,6 +254,16 @@ public static <R> List<R> batchProcessor(final SQSEvent event,
public static <R> List<R> batchProcessor(final SQSEvent event,
final boolean suppressException,
final SqsMessageHandler<R> handler) {
return batchProcessor(event, suppressException, handler, false);

}

@SafeVarargs
public static <R> List<R> batchProcessor(final SQSEvent event,
final boolean suppressException,
final SqsMessageHandler<R> handler,
final boolean deleteNonRetryableMessageFromQueue,
final Class<? extends Exception>... nonRetryableExceptions) {
final List<R> handlerReturn = new ArrayList<>();

if(client == null) {
Expand All @@ -246,7 +281,7 @@ public static <R> List<R> batchProcessor(final SQSEvent event,
}
}

batchContext.processSuccessAndHandleFailed(handlerReturn, suppressException);
batchContext.processSuccessAndHandleFailed(handlerReturn, suppressException, deleteNonRetryableMessageFromQueue, nonRetryableExceptions);

return handlerReturn;
}
Expand All @@ -255,12 +290,12 @@ private static <R> SqsMessageHandler<R> instantiatedHandler(final Class<? extend

try {
if (null == handler.getDeclaringClass()) {
return handler.newInstance();
return handler.getDeclaredConstructor().newInstance();
}

final Constructor<? extends SqsMessageHandler<R>> constructor = handler.getDeclaredConstructor(handler.getDeclaringClass());
constructor.setAccessible(true);
return constructor.newInstance(handler.getDeclaringClass().newInstance());
return constructor.newInstance(handler.getDeclaringClass().getDeclaredConstructor().newInstance());
} catch (Exception e) {
LOG.error("Failed creating handler instance", e);
throw new RuntimeException("Unexpected error occurred. Please raise issue at " +
Expand All @@ -276,4 +311,8 @@ private static SQSMessage clonedMessage(final SQSMessage sqsMessage) {
throw new RuntimeException(e);
}
}

public static ObjectMapper objectMapper() {
return objectMapper;
}
}
Original file line number Diff line number Diff line change
@@ -1,28 +1,43 @@
package software.amazon.lambda.powertools.sqs.internal;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.services.sqs.SqsClient;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
import software.amazon.awssdk.services.sqs.model.GetQueueAttributesRequest;
import software.amazon.awssdk.services.sqs.model.GetQueueAttributesResponse;
import software.amazon.awssdk.services.sqs.model.GetQueueUrlRequest;
import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
import software.amazon.awssdk.services.sqs.model.QueueAttributeName;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
import software.amazon.lambda.powertools.sqs.SQSBatchProcessingException;
import software.amazon.lambda.powertools.sqs.SqsUtils;

import static com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage;
import static java.lang.String.format;
import static java.util.stream.Collectors.toList;

public final class BatchContext {
private static final Logger LOG = LoggerFactory.getLogger(BatchContext.class);
private static final Map<String, String> queueArnToQueueUrlMapping = new HashMap<>();
private static final Map<String, String> queueArnToDlqUrlMapping = new HashMap<>();

private final Map<SQSMessage, Exception> messageToException = new HashMap<>();
private final List<SQSMessage> success = new ArrayList<>();
private final List<SQSMessage> failures = new ArrayList<>();
private final List<Exception> exceptions = new ArrayList<>();

private final SqsClient client;

public BatchContext(SqsClient client) {
Expand All @@ -34,37 +49,132 @@ public void addSuccess(SQSMessage event) {
}

public void addFailure(SQSMessage event, Exception e) {
failures.add(event);
exceptions.add(e);
messageToException.put(event, e);
}

public <T> void processSuccessAndHandleFailed(final List<T> successReturns,
final boolean suppressException) {
@SafeVarargs
public final <T> void processSuccessAndHandleFailed(final List<T> successReturns,
final boolean suppressException,
final boolean deleteNonRetryableMessageFromQueue,
final Class<? extends Exception>... nonRetryableExceptions) {
if (hasFailures()) {
deleteSuccessMessage();

List<Exception> exceptions = new ArrayList<>();
List<SQSMessage> failedMessages = new ArrayList<>();
Map<SQSMessage, Exception> nonRetryableMessageToException = new HashMap<>();

messageToException.forEach((sqsMessage, exception) -> {
boolean nonRetryableMessage = Arrays.stream(nonRetryableExceptions)
.anyMatch(aClass -> aClass.isInstance(exception));

if (nonRetryableMessage) {
nonRetryableMessageToException.put(sqsMessage, exception);
} else {
exceptions.add(exception);
failedMessages.add(sqsMessage);
}
});

List<SQSMessage> messagesToBeDeleted = new ArrayList<>(success);

if (!nonRetryableMessageToException.isEmpty() && deleteNonRetryableMessageFromQueue) {
messagesToBeDeleted.addAll(nonRetryableMessageToException.keySet());
} else if (!nonRetryableMessageToException.isEmpty()) {

boolean isMovedToDlq = moveNonRetryableMessagesToDlqIfConfigured(nonRetryableMessageToException);

if (!isMovedToDlq) {
exceptions.addAll(nonRetryableMessageToException.values());
failedMessages.addAll(nonRetryableMessageToException.keySet());
}
}

deleteMessagesFromQueue(messagesToBeDeleted);

if (suppressException) {
List<String> messageIds = failures.stream().
List<String> messageIds = failedMessages.stream().
map(SQSMessage::getMessageId)
.collect(toList());

LOG.debug(format("[%s] records failed processing, but exceptions are suppressed. " +
"Failed messages %s", failures.size(), messageIds));
"Failed messages %s", failedMessages.size(), messageIds));
} else {
throw new SQSBatchProcessingException(exceptions, failures, successReturns);
throw new SQSBatchProcessingException(exceptions, failedMessages, successReturns);
}
}
}

private boolean moveNonRetryableMessagesToDlqIfConfigured(Map<SQSMessage, Exception> nonRetryableMessageToException) {
Optional<String> dlqUrl = fetchDlqUrl(nonRetryableMessageToException);

if (!dlqUrl.isPresent()) {
return false;
}

List<SendMessageBatchRequestEntry> dlqMessages = nonRetryableMessageToException.keySet().stream()
.map(sqsMessage -> {
Map<String, MessageAttributeValue> messageAttributesMap = new HashMap<>();

sqsMessage.getMessageAttributes().forEach((s, messageAttribute) -> {
MessageAttributeValue.Builder builder = MessageAttributeValue.builder();

builder
.dataType(messageAttribute.getDataType())
.stringValue(messageAttribute.getStringValue());

if (null != messageAttribute.getBinaryValue()) {
builder.binaryValue(SdkBytes.fromByteBuffer(messageAttribute.getBinaryValue()));
}

messageAttributesMap.put(s, builder.build());
});

return SendMessageBatchRequestEntry.builder()
.messageBody(sqsMessage.getBody())
.id(sqsMessage.getMessageId())
.messageAttributes(messageAttributesMap)
.build();
})
.collect(toList());

SendMessageBatchResponse sendMessageBatchResponse = client.sendMessageBatch(builder -> builder.queueUrl(dlqUrl.get())
.entries(dlqMessages));

LOG.debug(format("Response from send batch message to DLQ request %s", sendMessageBatchResponse));

return true;
}

private Optional<String> fetchDlqUrl(Map<SQSMessage, Exception> nonRetryableMessageToException) {
return nonRetryableMessageToException.keySet().stream()
.findFirst()
.map(sqsMessage -> queueArnToDlqUrlMapping.computeIfAbsent(sqsMessage.getEventSourceArn(), sourceArn -> {
String queueUrl = url(sourceArn);

GetQueueAttributesResponse queueAttributes = client.getQueueAttributes(GetQueueAttributesRequest.builder()
.attributeNames(QueueAttributeName.REDRIVE_POLICY)
.queueUrl(queueUrl)
.build());

try {
JsonNode jsonNode = SqsUtils.objectMapper().readTree(queueAttributes.attributes().get(QueueAttributeName.REDRIVE_POLICY));
return url(jsonNode.get("deadLetterTargetArn").asText());
} catch (JsonProcessingException e) {
LOG.debug("Unable to parse Re drive policy for queue {}. Even if DLQ exists, failed messages will be send back to main queue.", queueUrl, e);
return null;
}
}));
}

private boolean hasFailures() {
return !failures.isEmpty();
return !messageToException.isEmpty();
}

private void deleteSuccessMessage() {
if (!success.isEmpty()) {
private void deleteMessagesFromQueue(final List<SQSMessage> messages) {
if (!messages.isEmpty()) {
DeleteMessageBatchRequest request = DeleteMessageBatchRequest.builder()
.queueUrl(url())
.entries(success.stream().map(m -> DeleteMessageBatchRequestEntry.builder()
.queueUrl(url(messages.get(0).getEventSourceArn()))
.entries(messages.stream().map(m -> DeleteMessageBatchRequestEntry.builder()
.id(m.getMessageId())
.receiptHandle(m.getReceiptHandle())
.build()).collect(toList()))
Expand All @@ -75,12 +185,15 @@ private void deleteSuccessMessage() {
}
}

private String url() {
String[] arnArray = success.get(0).getEventSourceArn().split(":");
return client.getQueueUrl(GetQueueUrlRequest.builder()
.queueOwnerAWSAccountId(arnArray[4])
.queueName(arnArray[5])
.build())
.queueUrl();
private String url(String queueArn) {
return queueArnToQueueUrlMapping.computeIfAbsent(queueArn, s -> {
String[] arnArray = queueArn.split(":");

return client.getQueueUrl(GetQueueUrlRequest.builder()
.queueOwnerAWSAccountId(arnArray[4])
.queueName(arnArray[5])
.build())
.queueUrl();
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ && placedOnSqsEventRequestHandler(pjp)) {

SQSEvent sqsEvent = (SQSEvent) proceedArgs[0];

batchProcessor(sqsEvent, sqsBatch.suppressException(), sqsBatch.value());
batchProcessor(sqsEvent, sqsBatch.suppressException(), sqsBatch.value(), sqsBatch.nonRetryableExceptions());
}

return pjp.proceed(proceedArgs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,20 @@ public String process(SQSMessage message) {
}
}

@Test
void shouldBatchProcessAndMoveNonRetryableExceptionToDlq() {
String failedId = "2e1424d4-f796-459a-8184-9c92662be6da";

List<String> batchProcessor = batchProcessor(event, (message) -> {
if (failedId.equals(message.getMessageId())) {
throw new IllegalStateException("Failed processing");
}

interactionClient.listQueues();
return "Success";
}, IllegalStateException.class, IllegalArgumentException.class);
}

public class FailureSampleInnerSqsHandler implements SqsMessageHandler<String> {
@Override
public String process(SQSEvent.SQSMessage message) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package software.amazon.lambda.powertools.sqs.handlers;

import com.amazonaws.services.lambda.runtime.Context;
import com.amazonaws.services.lambda.runtime.RequestHandler;
import com.amazonaws.services.lambda.runtime.events.SQSEvent;
import software.amazon.lambda.powertools.sqs.SqsBatch;
import software.amazon.lambda.powertools.sqs.SqsMessageHandler;

import static com.amazonaws.services.lambda.runtime.events.SQSEvent.SQSMessage;
import static software.amazon.lambda.powertools.sqs.internal.SqsMessageBatchProcessorAspectTest.mockedRandom;

public class SqsMessageHandlerWithNonRetryableHandler implements RequestHandler<SQSEvent, String> {

@Override
@SqsBatch(value = InnerMessageHandler.class, nonRetryableExceptions = {IllegalStateException.class, IllegalArgumentException.class})
public String handleRequest(final SQSEvent sqsEvent,
final Context context) {
return "Success";
}

private class InnerMessageHandler implements SqsMessageHandler<Object> {

@Override
public String process(SQSMessage message) {
if(message.getMessageId().isEmpty()) {
throw new IllegalArgumentException("Invalid message and was moved to DLQ");
}

mockedRandom.nextInt();
return "Success";
}
}
}

0 comments on commit 242736b

Please sign in to comment.