Skip to content

Commit

Permalink
Fix ArrayIndexOutOfBoundsException when skipping allowed string value…
Browse files Browse the repository at this point in the history
… with escaped escape character (#92)

* Add a test case with ArrayIndexOutOfBoundsException

* Refactored the escape characters handling to fix ArrayOutOfBounds

* Move counting of non-visible characters inside the utility function, refactor logic to reuse value skipping inside masking as well

* Simplified the counting, added more tests to show how jackson would parse the object

* Change skip to stepOver, added some docs and renamed a variable

---------

Co-authored-by: breus <b.blaauwendraad@gmail.com>
  • Loading branch information
gavlyukovskiy and Breus committed Mar 13, 2024
1 parent da82ad7 commit 5b13781
Show file tree
Hide file tree
Showing 12 changed files with 332 additions and 263 deletions.
218 changes: 93 additions & 125 deletions src/main/java/dev/blaauwendraad/masker/json/KeyContainsMasker.java

Large diffs are not rendered by default.

94 changes: 74 additions & 20 deletions src/main/java/dev/blaauwendraad/masker/json/MaskingState.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
* Represents the state of the {@link JsonMasker} at a given point in time during the {@link JsonMasker#mask(byte[])}
* operation.
*/
public final class MaskingState {
final class MaskingState {
private final byte[] message;
private int currentIndex = 0;
private final List<ReplacementOperation> replacementOperations = new ArrayList<>();
Expand Down Expand Up @@ -43,10 +43,6 @@ public byte byteAtCurrentIndex() {
return message[currentIndex];
}

public byte byteAtCurrentIndexMinusOne() {
return message[currentIndex - 1];
}

public int currentIndex() {
return currentIndex;
}
Expand All @@ -56,26 +52,82 @@ public byte[] getMessage() {
}

/**
* Adds new delayed replacement operation to the list of operations to be applied to the message.
* Replaces a target value (byte slice) with a mask byte. If lengths of both target value and mask are equal, the
* replacement is done in-place, otherwise a replacement operation is recorded to be performed as a batch using
* {@link #flushReplacementOperations}.
*
* @see ReplacementOperation
*/
public void addReplacementOperation(int startIndex, int endIndex, byte[] mask, int maskRepeat) {
ReplacementOperation replacementOperation = new ReplacementOperation(startIndex, endIndex, mask, maskRepeat);
public void replaceTargetValueWith(int startIndex, int length, byte[] mask, int maskRepeat) {
ReplacementOperation replacementOperation = new ReplacementOperation(startIndex, length, mask, maskRepeat);
replacementOperations.add(replacementOperation);
replacementOperationsTotalDifference += replacementOperation.difference();
}

/**
* Returns the list of replacement operations that need to be applied to the message.
* Performs all replacement operations to the message array, must be called at the end of the replacements.
* <p>
* For every operation that required resizing of the original array, to avoid copying the array multiple times,
* those operations were stored in a list and can be performed in one go, thus resizing the array only once.
* <p>
* Replacement operation is only recorded if the length of the target value is different from the length of the mask,
* otherwise the replacement must have been done in-place.
*
* @return the message array with all replacement operations performed.
*/
public List<ReplacementOperation> getReplacementOperations() {
return replacementOperations;
}
public byte[] flushReplacementOperations() {
if (replacementOperations.isEmpty()) {
return message;
}

/**
* Returns the total difference between the masks and target values lengths of all replacement operations.
*/
public int getReplacementOperationsTotalDifference() {
return replacementOperationsTotalDifference;
// Create new empty array with a length computed by the difference of all mismatches of lengths between the target values and the masks
// in some edge cases the length difference might be equal to 0, but since some indices mismatch (otherwise there would be no replacement operations)
// we still have to copy the array to keep track of data according to original indices
byte[] newMessage = new byte[message.length + replacementOperationsTotalDifference];

// Index of the original message array
int index = 0;
// Offset is the difference between the original and new array indices, we need it to calculate indices
// in the new message array using startIndex and endIndex, which are indices in the original array
int offset = 0;
for (ReplacementOperation replacementOperation : replacementOperations) {
// Copy everything from message up until replacement operation start index
System.arraycopy(
message,
index,
newMessage,
index + offset,
replacementOperation.startIndex - index
);
// Insert the mask bytes
int length = replacementOperation.mask.length;
for (int i = 0; i < replacementOperation.maskRepeat; i++) {
System.arraycopy(
replacementOperation.mask,
0,
newMessage,
replacementOperation.startIndex + offset + i * length,
length
);
}
// Adjust index and offset to continue copying from the end of the replacement operation
index = replacementOperation.startIndex + replacementOperation.length;
offset += replacementOperation.difference();
}

// Copy the remainder of the original array
System.arraycopy(
message,
index,
newMessage,
index + offset,
message.length - index
);

// make sure no operations are performed after this
this.currentIndex = Integer.MAX_VALUE;

return newMessage;
}

/**
Expand Down Expand Up @@ -140,19 +192,21 @@ public String toString() {
* a single resize operation.
*
* @param startIndex index from which to start replacing
* @param endIndex index at which to stop replacing
* @param length the length of the target value slice
* @param mask byte array mask to use as replacement for the value
* @param maskRepeat number of times to repeat the mask (for cases when every character or digit is masked)
*
* @see #flushReplacementOperations()
*/
@SuppressWarnings("java:S6218") // never used for comparison
public record ReplacementOperation(int startIndex, int endIndex, byte[] mask, int maskRepeat) {
private record ReplacementOperation(int startIndex, int length, byte[] mask, int maskRepeat) {

/**
* The difference between the mask length and the length of the target value to replace.
* Used to keep track of the offset during replacements.
*/
public int difference() {
return mask.length * maskRepeat - (endIndex - startIndex);
return mask.length * maskRepeat - length;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ public final class KeyMaskingConfig {

KeyMaskingConfig(KeyMaskingConfig.Builder builder) {
if (builder.maskStringsWith != null) {
this.maskStringsWith = builder.maskStringsWith.getBytes(StandardCharsets.UTF_8);
this.maskStringsWith = ("\"" + builder.maskStringsWith + "\"").getBytes(StandardCharsets.UTF_8);
this.maskStringCharactersWith = null;
} else if (builder.maskStringCharactersWith != null) {
this.maskStringsWith = null;
this.maskStringCharactersWith = builder.maskStringCharactersWith.getBytes(StandardCharsets.UTF_8);
} else {
// no quotes for strings as opposed to numbers and booleans because we never change the type of the string
// and only mask value inside the quotes
this.maskStringsWith = "***".getBytes(StandardCharsets.UTF_8);
this.maskStringsWith = "\"***\"".getBytes(StandardCharsets.UTF_8);
this.maskStringCharactersWith = null;
}
if (builder.maskNumbersWithString != null) {
Expand Down
48 changes: 48 additions & 0 deletions src/main/java/dev/blaauwendraad/masker/json/util/Utf8Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,52 @@ public static int getCodePointByteLength(byte input) {
}
throw new IllegalArgumentException("Input byte is not using UTF-8 encoding");
}

/**
 * Counts the bytes inside a JSON string value that do not correspond to a visible character of their own.
 * <p>
 * The interval supplied must lie within a single string value (already inside the quotes): this method does
 * not perform boundary checks and does not look for the end of the string value.
 *
 * @param message   the byte array containing the string
 * @param fromIndex the starting index of the string value (after the opening quote)
 * @param length    the length of the string value in bytes (excluding the quotes)
 * @return the amount of non-visible bytes in the string: backslashes that act as escape characters, the four
 * hexadecimal digits of a unicode escape ('\u0000'), and every byte beyond the first one of a character
 * that is encoded using more than a single byte
 */
public static int countNonVisibleCharacters(byte[] message, int fromIndex, int length) {
    int index = fromIndex;
    int toIndex = fromIndex + length;
    // true when the previous byte was a backslash that itself was not escaped
    boolean isEscapeCharacter = false;
    int nonVisibleCharacterCount = 0;
    while (index < toIndex) {
        byte currentByte = message[index];
        if (isEscapeCharacter) {
            /*
             * Non-escaped backslashes are escape characters and are not actually part of the string but
             * only used for character encoding, so must not be included in the mask. The backslash is
             * accounted for here, once the character it escapes is reached.
             */
            nonVisibleCharacterCount++;
            if (AsciiCharacter.isLowercaseU(currentByte)) {
                /*
                 * The next 4 characters are hexadecimal digits which form a single character and are only
                 * there for encoding, so must not be included in the mask.
                 */
                nonVisibleCharacterCount += 4;
                index += 4;
            }
        } else {
            int codePointByteLength = Utf8Util.getCodePointByteLength(currentByte);
            if (codePointByteLength > 1) {
                /*
                 * We only support UTF-8, so whenever code points are encoded using multiple bytes this should
                 * be represented by a single asterisk and the additional bytes used for encoding need to be
                 * removed.
                 */
                nonVisibleCharacterCount += codePointByteLength - 1;
                /*
                 * Step over the continuation bytes that were just counted, so that they are not inspected
                 * as if they started a new code point on the following iterations.
                 * NOTE(review): getCodePointByteLength presumably expects the first byte of a code point and
                 * rejects other bytes with an IllegalArgumentException - confirm against its implementation.
                 */
                index += codePointByteLength - 1;
            }
        }
        // A backslash only starts an escape sequence when it is not itself the escaped character ("\\")
        isEscapeCharacter = !isEscapeCharacter && AsciiCharacter.isEscapeCharacter(currentByte);
        index++;
    }
    return nonVisibleCharacterCount;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ private static Stream<JsonMaskerTestInstance> targetKeyAllowModeNotPretty() {
.build());
return Stream.of(
new JsonMaskerTestInstance("""
[
[ \s
{
"allowedKey": "yes",
"notAllowedKey": "hello"
}
]
""", """
[
[ \s
{
"allowedKey": "yes",
"notAllowedKey": "***"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,13 @@ void shouldReturnSpecificConfigWhenMatched() {
.isNotNull()
.extracting(KeyMaskingConfig::getMaskStringsWith)
.extracting(bytes -> new String(bytes, StandardCharsets.UTF_8))
.isEqualTo("***");
.isEqualTo("\"***\"");

assertThatConfig(keyMatcher, "maskMeLikeCIA")
.isNotNull()
.extracting(KeyMaskingConfig::getMaskStringsWith)
.extracting(bytes -> new String(bytes, StandardCharsets.UTF_8))
.isEqualTo("[redacted]");
.isEqualTo("\"[redacted]\"");
}

@Test
Expand All @@ -95,13 +95,13 @@ void shouldReturnMaskingConfigInAllowMode() {
.isNotNull()
.extracting(KeyMaskingConfig::getMaskStringsWith)
.extracting(bytes -> new String(bytes, StandardCharsets.UTF_8))
.isEqualTo("***");
.isEqualTo("\"***\"");

assertThatConfig(keyMatcher, "maskMeLikeCIA")
.isNotNull()
.extracting(KeyMaskingConfig::getMaskStringsWith)
.extracting(bytes -> new String(bytes, StandardCharsets.UTF_8))
.isEqualTo("[redacted]");
.isEqualTo("\"[redacted]\"");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ private static JsonNode maskBooleanNode(BooleanNode booleanNode, KeyMaskingConfi
private static TextNode maskTextNode(TextNode textNode, KeyMaskingConfig config) {
String text = textNode.textValue();
if (config.getMaskStringsWith() != null) {
text = new String(config.getMaskStringsWith(), StandardCharsets.UTF_8);
// strip the quotes
text = new String(config.getMaskStringsWith(), 1, config.getMaskStringsWith().length - 2, StandardCharsets.UTF_8);
} else if (config.getMaskStringCharactersWith() != null) {
text = new String(config.getMaskStringCharactersWith(), StandardCharsets.UTF_8).repeat(text.length());
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ private JsonStringCharacters() {
private static final Set<Character> asciiUppercaseLetters =
IntStream.rangeClosed(65, 90).mapToObj(i -> (char) i).collect(Collectors.toSet());
private static final Set<Character> asciiSpecialChars2 = IntStream.rangeClosed(91, 96)
.filter(i -> i != 92 /* escape character */)
.mapToObj(i -> (char) i)
.collect(Collectors.toSet());
private static final Set<Character> asciiLowercaseLetters =
Expand Down
Loading

0 comments on commit 5b13781

Please sign in to comment.