From a0447626dcffd6bcd4decd13e4dac1da4661d6f7 Mon Sep 17 00:00:00 2001 From: Alex Herbert Date: Sun, 24 Nov 2019 23:39:27 +0000 Subject: [PATCH] [CODEC-270] Fix masked check of the final bits to discard. The mask must check all the bits to discard are zero. --- .../apache/commons/codec/binary/Base32.java | 47 +++++++---- .../apache/commons/codec/binary/Base64.java | 26 +++--- .../commons/codec/binary/Base32Test.java | 79 +++++++++++++++++++ .../commons/codec/binary/Base64Test.java | 61 ++++++++++++++ 4 files changed, 188 insertions(+), 25 deletions(-) diff --git a/src/main/java/org/apache/commons/codec/binary/Base32.java b/src/main/java/org/apache/commons/codec/binary/Base32.java index a8ede4eb00..7009fcb0ae 100644 --- a/src/main/java/org/apache/commons/codec/binary/Base32.java +++ b/src/main/java/org/apache/commons/codec/binary/Base32.java @@ -114,8 +114,20 @@ public class Base32 extends BaseNCodec { 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', }; + /** Mask used to extract 7 bits, used when decoding final trailing character. */ + private static final long MASK_7BITS = 0x7fL; + /** Mask used to extract 6 bits, used when decoding final trailing character. */ + private static final long MASK_6BITS = 0x3fL; /** Mask used to extract 5 bits, used when encoding Base32 bytes */ private static final int MASK_5BITS = 0x1f; + /** Mask used to extract 4 bits, used when decoding final trailing character. */ + private static final long MASK_4BITS = 0x0fL; + /** Mask used to extract 3 bits, used when decoding final trailing character. */ + private static final long MASK_3BITS = 0x07L; + /** Mask used to extract 2 bits, used when decoding final trailing character. */ + private static final long MASK_2BITS = 0x03L; + /** Mask used to extract 1 bits, used when decoding final trailing character. */ + private static final long MASK_1BITS = 0x01L; // The static final fields above are used for the original static byte[] methods on Base32. 
// The private member fields below are used with the new streaming approach, which requires @@ -335,7 +347,8 @@ public Base32(final int lineLength, final byte[] lineSeparator, final boolean us * Amount of bytes available from input for decoding. * @param context the context to be used * - * Output is written to {@link Context#buffer} as 8-bit octets, using {@link Context#pos} as the buffer position + * Output is written to {@link org.apache.commons.codec.binary.BaseNCodec.Context#buffer Context#buffer} + * as 8-bit octets, using {@link org.apache.commons.codec.binary.BaseNCodec.Context#pos Context#pos} as the buffer position */ @Override void decode(final byte[] in, int inPos, final int inAvail, final Context context) { @@ -381,35 +394,35 @@ void decode(final byte[] in, int inPos, final int inAvail, final Context context // we ignore partial bytes, i.e. only multiples of 8 count switch (context.modulus) { case 2 : // 10 bits, drop 2 and output one byte - validateCharacter(2, context); + validateCharacter(MASK_2BITS, context); buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 2) & MASK_8BITS); break; case 3 : // 15 bits, drop 7 and output 1 byte - validateCharacter(7, context); + validateCharacter(MASK_7BITS, context); buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 7) & MASK_8BITS); break; case 4 : // 20 bits = 2*8 + 4 - validateCharacter(4, context); + validateCharacter(MASK_4BITS, context); context.lbitWorkArea = context.lbitWorkArea >> 4; // drop 4 bits buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 8) & MASK_8BITS); buffer[context.pos++] = (byte) ((context.lbitWorkArea) & MASK_8BITS); break; case 5 : // 25bits = 3*8 + 1 - validateCharacter(1, context); + validateCharacter(MASK_1BITS, context); context.lbitWorkArea = context.lbitWorkArea >> 1; buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 16) & MASK_8BITS); buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 8) & MASK_8BITS); buffer[context.pos++] = (byte) 
((context.lbitWorkArea) & MASK_8BITS); break; case 6 : // 30bits = 3*8 + 6 - validateCharacter(6, context); + validateCharacter(MASK_6BITS, context); context.lbitWorkArea = context.lbitWorkArea >> 6; buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 16) & MASK_8BITS); buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 8) & MASK_8BITS); buffer[context.pos++] = (byte) ((context.lbitWorkArea) & MASK_8BITS); break; case 7 : // 35 = 4*8 +3 - validateCharacter(3, context); + validateCharacter(MASK_3BITS, context); context.lbitWorkArea = context.lbitWorkArea >> 3; buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 24) & MASK_8BITS); buffer[context.pos++] = (byte) ((context.lbitWorkArea >> 16) & MASK_8BITS); @@ -548,19 +561,23 @@ public boolean isInAlphabet(final byte octet) { } /** - *
<p>
- * Validates whether the character is possible in the context of the set of possible base 32 values. - *
</p>
+ * Validates whether decoding the final trailing character is possible in the context + * of the set of possible base 32 values. + * + *
<p>
The character is valid if the lower bits within the provided mask are zero. This + * is used to test the final trailing base-32 digit is zero in the bits that will be discarded. * - * @param numBits number of least significant bits to check + * @param emptyBitsMask The mask of the lower bits that should be empty * @param context the context to be used * * @throws IllegalArgumentException if the bits being checked contain any non-zero value */ - private void validateCharacter(final int numBits, final Context context) { - if ((context.lbitWorkArea & numBits) != 0) { - throw new IllegalArgumentException( - "Last encoded character (before the paddings if any) is a valid base 32 alphabet but not a possible value"); + private static void validateCharacter(final long emptyBitsMask, final Context context) { + // Use the long bit work area + if ((context.lbitWorkArea & emptyBitsMask) != 0) { + throw new IllegalArgumentException( + "Last encoded character (before the paddings if any) is a valid base 32 alphabet but not a possible value. " + + "Expected the discarded bits to be zero."); } } } diff --git a/src/main/java/org/apache/commons/codec/binary/Base64.java b/src/main/java/org/apache/commons/codec/binary/Base64.java index aed7843c3d..92a3563600 100644 --- a/src/main/java/org/apache/commons/codec/binary/Base64.java +++ b/src/main/java/org/apache/commons/codec/binary/Base64.java @@ -128,6 +128,10 @@ public class Base64 extends BaseNCodec { */ /** Mask used to extract 6 bits, used when encoding */ private static final int MASK_6BITS = 0x3f; + /** Mask used to extract 4 bits, used when decoding final trailing character. */ + private static final int MASK_4BITS = 0xf; + /** Mask used to extract 2 bits, used when decoding final trailing character. */ + private static final int MASK_2BITS = 0x3; // The static final fields above are used for the original static byte[] methods on Base64. 
// The private member fields below are used with the new streaming approach, which requires @@ -469,12 +473,12 @@ void decode(final byte[] in, int inPos, final int inAvail, final Context context // TODO not currently tested; perhaps it is impossible? break; case 2 : // 12 bits = 8 + 4 - validateCharacter(4, context); + validateCharacter(MASK_4BITS, context); context.ibitWorkArea = context.ibitWorkArea >> 4; // dump the extra 4 bits buffer[context.pos++] = (byte) ((context.ibitWorkArea) & MASK_8BITS); break; case 3 : // 18 bits = 8 + 8 + 2 - validateCharacter(2, context); + validateCharacter(MASK_2BITS, context); context.ibitWorkArea = context.ibitWorkArea >> 2; // dump 2 bits buffer[context.pos++] = (byte) ((context.ibitWorkArea >> 8) & MASK_8BITS); buffer[context.pos++] = (byte) ((context.ibitWorkArea) & MASK_8BITS); @@ -784,20 +788,22 @@ protected boolean isInAlphabet(final byte octet) { } /** - *
<p>
- * Validates whether the character is possible in the context of the set of possible base 64 values. - *
</p>
+ * Validates whether decoding the final trailing character is possible in the context + * of the set of possible base 64 values. + * + *
<p>
The character is valid if the lower bits within the provided mask are zero. This + * is used to test the final trailing base-64 digit is zero in the bits that will be discarded. * - * @param numBitsToDrop number of least significant bits to check + * @param emptyBitsMask The mask of the lower bits that should be empty * @param context the context to be used * * @throws IllegalArgumentException if the bits being checked contain any non-zero value */ - private long validateCharacter(final int numBitsToDrop, final Context context) { - if ((context.ibitWorkArea & numBitsToDrop) != 0) { + private static void validateCharacter(final int emptyBitsMask, final Context context) { + if ((context.ibitWorkArea & emptyBitsMask) != 0) { throw new IllegalArgumentException( - "Last encoded character (before the paddings if any) is a valid base 64 alphabet but not a possible value"); + "Last encoded character (before the paddings if any) is a valid base 64 alphabet but not a possible value. " + + "Expected the discarded bits to be zero."); } - return context.ibitWorkArea >> numBitsToDrop; } } diff --git a/src/test/java/org/apache/commons/codec/binary/Base32Test.java b/src/test/java/org/apache/commons/codec/binary/Base32Test.java index 725b912e05..de62be8532 100644 --- a/src/test/java/org/apache/commons/codec/binary/Base32Test.java +++ b/src/test/java/org/apache/commons/codec/binary/Base32Test.java @@ -71,6 +71,16 @@ public class Base32Test { "CPNMUOJ1E2======" }; + /** + * Copy of the standard base-32 encoding table. Used to test decoding the final + * character of encoded bytes. 
+ */ + private static final byte[] ENCODE_TABLE = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + '2', '3', '4', '5', '6', '7', + }; + private static final Object[][] BASE32_BINARY_TEST_CASES; // { null, "O0o0O0o0" } @@ -298,4 +308,73 @@ private void testImpossibleCases(final Base32 codec, final String[] impossible_c } } } + + @Test + public void testBase32DecodingOfTrailing10Bits() { + assertBase32DecodingOfTrailingBits(10); + } + + @Test + public void testBase32DecodingOfTrailing15Bits() { + assertBase32DecodingOfTrailingBits(15); + } + + @Test + public void testBase32DecodingOfTrailing20Bits() { + assertBase32DecodingOfTrailingBits(20); + } + + @Test + public void testBase32DecodingOfTrailing25Bits() { + assertBase32DecodingOfTrailingBits(25); + } + + @Test + public void testBase32DecodingOfTrailing30Bits() { + assertBase32DecodingOfTrailingBits(30); + } + + @Test + public void testBase32DecodingOfTrailing35Bits() { + assertBase32DecodingOfTrailingBits(35); + } + + /** + * Test base 32 decoding of the final trailing bits. Trailing encoded bytes + * cannot fit exactly into 5-bit characters so the last character has a limited + * alphabet where the final bits are zero. This asserts that illegal final + * characters throw an exception when decoding. + * + * @param nbits the number of trailing bits (must be a factor of 5 and {@code <40}) + */ + private static void assertBase32DecodingOfTrailingBits(int nbits) { + final Base32 codec = new Base32(); + // Create the encoded bytes. The first characters must be valid so fill with 'zero'. 
+ final byte[] encoded = new byte[nbits / 5]; + Arrays.fill(encoded, ENCODE_TABLE[0]); + // Compute how many bits would be discarded from 8-bit bytes + final int discard = nbits % 8; + final int emptyBitsMask = (1 << discard) - 1; + // Enumerate all 32 possible final characters in the last position + final int last = encoded.length - 1; + for (int i = 0; i < 32; i++) { + encoded[last] = ENCODE_TABLE[i]; + // If the lower bits are set we expect an exception. This is not a valid + // final character. + if ((i & emptyBitsMask) != 0) { + try { + codec.decode(encoded); + fail("Final base-32 digit should not be allowed"); + } catch (final IllegalArgumentException ex) { + // expected + } + } else { + // Otherwise this should decode + final byte[] decoded = codec.decode(encoded); + // Compute the bits that were encoded. This should match the final decoded byte. + final int bitsEncoded = i >> discard; + assertEquals("Invalid decoding of last character", bitsEncoded, decoded[decoded.length - 1]); + } + } + } } diff --git a/src/test/java/org/apache/commons/codec/binary/Base64Test.java b/src/test/java/org/apache/commons/codec/binary/Base64Test.java index 442fc4c274..cf737098dd 100644 --- a/src/test/java/org/apache/commons/codec/binary/Base64Test.java +++ b/src/test/java/org/apache/commons/codec/binary/Base64Test.java @@ -52,6 +52,18 @@ public class Base64Test { "AB", }; + /** + * Copy of the standard base-64 encoding table. Used to test decoding the final + * character of encoded bytes. 
+ */ + private static final byte[] STANDARD_ENCODE_TABLE = { + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', + 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', + 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', + 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', + '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/' + }; + private final Random random = new Random(); /** @@ -1317,4 +1329,53 @@ public void testBase64ImpossibleSamples() { } } } + + @Test + public void testBase64DecodingOfTrailing12Bits() { + assertBase64DecodingOfTrailingBits(12); + } + + @Test + public void testBase64DecodingOfTrailing18Bits() { + assertBase64DecodingOfTrailingBits(18); + } + + /** + * Test base 64 decoding of the final trailing bits. Trailing encoded bytes + * cannot fit exactly into 6-bit characters so the last character has a limited + * alphabet where the final bits are zero. This asserts that illegal final + * characters throw an exception when decoding. + * + * @param nbits the number of trailing bits (must be a factor of 6 and {@code <24}) + */ + private static void assertBase64DecodingOfTrailingBits(int nbits) { + final Base64 codec = new Base64(); + // Create the encoded bytes. The first characters must be valid so fill with 'zero'. + final byte[] encoded = new byte[nbits / 6]; + Arrays.fill(encoded, STANDARD_ENCODE_TABLE[0]); + // Compute how many bits would be discarded from 8-bit bytes + final int discard = nbits % 8; + final int emptyBitsMask = (1 << discard) - 1; + // Enumerate all 64 possible final characters in the last position + final int last = encoded.length - 1; + for (int i = 0; i < 64; i++) { + encoded[last] = STANDARD_ENCODE_TABLE[i]; + // If the lower bits are set we expect an exception. This is not a valid + // final character. 
+ if ((i & emptyBitsMask) != 0) { + try { + codec.decode(encoded); + fail("Final base-64 digit should not be allowed"); + } catch (final IllegalArgumentException ex) { + // expected + } + } else { + // Otherwise this should decode + final byte[] decoded = codec.decode(encoded); + // Compute the bits that were encoded. This should match the final decoded byte. + final int bitsEncoded = i >> discard; + assertEquals("Invalid decoding of last character", bitsEncoded, decoded[decoded.length - 1]); + } + } + } }