Skip to content

Commit

Permalink
Improved efficiency in DigestManager.verify() (#3810)
Browse files Browse the repository at this point in the history
  • Loading branch information
merlimat committed Feb 27, 2023
1 parent f65b72d commit 1f8de8f
Show file tree
Hide file tree
Showing 15 changed files with 204 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ void populateValueAndReset(ByteBuf buf) {
}

@Override
void update(ByteBuf data) {
void update(ByteBuf data, int offset, int len) {
MutableInt current = currentCrc.get();
final int lastCrc = current.intValue();
current.setValue(Crc32cIntChecksum.resumeChecksum(lastCrc, data));
current.setValue(Crc32cIntChecksum.resumeChecksum(lastCrc, data, offset, len));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class CRC32DigestManager extends DigestManager {
interface CRC32Digest {
long getValueAndReset();

void update(ByteBuf buf);
void update(ByteBuf buf, int offset, int len);
}

private static final FastThreadLocal<CRC32Digest> crc = new FastThreadLocal<CRC32Digest>() {
Expand Down Expand Up @@ -62,7 +62,7 @@ void populateValueAndReset(ByteBuf buf) {
}

@Override
void update(ByteBuf data) {
crc.get().update(data);
void update(ByteBuf data, int offset, int len) {
crc.get().update(data, offset, len);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.FastThreadLocal;
import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException;
import org.apache.bookkeeper.client.BKException.BKDigestMatchException;
Expand Down Expand Up @@ -51,10 +54,10 @@ public abstract class DigestManager {
abstract int getMacCodeLength();

void update(byte[] data) {
update(Unpooled.wrappedBuffer(data, 0, data.length));
update(Unpooled.wrappedBuffer(data), 0, data.length);
}

abstract void update(ByteBuf buffer);
abstract void update(ByteBuf buffer, int offset, int len);

abstract void populateValueAndReset(ByteBuf buffer);

Expand Down Expand Up @@ -109,7 +112,7 @@ public ByteBufList computeDigestAndPackageForSending(long entryId, long lastAddC
headersBuffer.writeLong(lastAddConfirmed);
headersBuffer.writeLong(length);

update(headersBuffer);
update(headersBuffer, 0, METADATA_LENGTH);

// don't unwrap slices
final ByteBuf unwrapped = data.unwrap() != null && data.unwrap() instanceof CompositeByteBuf
Expand All @@ -118,9 +121,9 @@ public ByteBufList computeDigestAndPackageForSending(long entryId, long lastAddC
ReferenceCountUtil.safeRelease(data);

if (unwrapped instanceof CompositeByteBuf) {
((CompositeByteBuf) unwrapped).forEach(this::update);
((CompositeByteBuf) unwrapped).forEach(b -> update(b, b.readerIndex(), b.readableBytes()));
} else {
update(unwrapped);
update(unwrapped, unwrapped.readerIndex(), unwrapped.readableBytes());
}
populateValueAndReset(headersBuffer);

Expand All @@ -144,7 +147,7 @@ public ByteBufList computeDigestAndPackageForSendingLac(long lac) {
headersBuffer.writeLong(ledgerId);
headersBuffer.writeLong(lac);

update(headersBuffer);
update(headersBuffer, 0, LAC_METADATA_LENGTH);
populateValueAndReset(headersBuffer);

return ByteBufList.get(headersBuffer);
Expand All @@ -158,6 +161,18 @@ private void verifyDigest(long entryId, ByteBuf dataReceived) throws BKDigestMat
verifyDigest(entryId, dataReceived, false);
}

private static final FastThreadLocal<ByteBuf> DIGEST_BUFFER = new FastThreadLocal<ByteBuf>() {
@Override
protected ByteBuf initialValue() throws Exception {
return PooledByteBufAllocator.DEFAULT.directBuffer(1024);
}

@Override
protected void onRemoval(ByteBuf value) throws Exception {
value.release();
}
};

private void verifyDigest(long entryId, ByteBuf dataReceived, boolean skipEntryIdCheck)
throws BKDigestMatchException {

Expand All @@ -168,21 +183,18 @@ private void verifyDigest(long entryId, ByteBuf dataReceived, boolean skipEntryI
this.getClass().getName(), dataReceived.readableBytes());
throw new BKDigestMatchException();
}
update(dataReceived.slice(0, METADATA_LENGTH));
update(dataReceived, 0, METADATA_LENGTH);

int offset = METADATA_LENGTH + macCodeLength;
update(dataReceived.slice(offset, dataReceived.readableBytes() - offset));
update(dataReceived, offset, dataReceived.readableBytes() - offset);

ByteBuf digest = allocator.buffer(macCodeLength);
ByteBuf digest = DIGEST_BUFFER.get();
digest.clear();
populateValueAndReset(digest);

try {
if (digest.compareTo(dataReceived.slice(METADATA_LENGTH, macCodeLength)) != 0) {
logger.error("Mac mismatch for ledger-id: " + ledgerId + ", entry-id: " + entryId);
throw new BKDigestMatchException();
}
} finally {
ReferenceCountUtil.safeRelease(digest);
if (!ByteBufUtil.equals(digest, 0, dataReceived, METADATA_LENGTH, macCodeLength)) {
logger.error("Mac mismatch for ledger-id: " + ledgerId + ", entry-id: " + entryId);
throw new BKDigestMatchException();
}

long actualLedgerId = dataReceived.readLong();
Expand Down Expand Up @@ -211,20 +223,17 @@ public long verifyDigestAndReturnLac(ByteBuf dataReceived) throws BKDigestMatchE
throw new BKDigestMatchException();
}

update(dataReceived.slice(0, LAC_METADATA_LENGTH));
update(dataReceived, 0, LAC_METADATA_LENGTH);

ByteBuf digest = allocator.buffer(macCodeLength);
try {
populateValueAndReset(digest);
ByteBuf digest = DIGEST_BUFFER.get();
digest.clear();

if (digest.compareTo(dataReceived.slice(LAC_METADATA_LENGTH, macCodeLength)) != 0) {
logger.error("Mac mismatch for ledger-id LAC: " + ledgerId);
throw new BKDigestMatchException();
}
} finally {
ReferenceCountUtil.safeRelease(digest);
}
populateValueAndReset(digest);

if (!ByteBufUtil.equals(digest, 0, dataReceived, LAC_METADATA_LENGTH, macCodeLength)) {
logger.error("Mac mismatch for ledger-id LAC: " + ledgerId);
throw new BKDigestMatchException();
}
long actualLedgerId = dataReceived.readLong();
long lac = dataReceived.readLong();
if (actualLedgerId != ledgerId) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ public long getValueAndReset() {
}

@Override
public void update(ByteBuf buf) {
int index = buf.readerIndex();
int length = buf.readableBytes();

public void update(ByteBuf buf, int index, int length) {
try {
if (buf.hasMemoryAddress()) {
// Calculate CRC directly from the direct memory pointer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ int getMacCodeLength() {
}

@Override
void update(ByteBuf buffer) {}
void update(ByteBuf buffer, int offset, int len) {}

@Override
void populateValueAndReset(ByteBuf buffer) {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ void populateValueAndReset(ByteBuf buffer) {
}

@Override
void update(ByteBuf data) {
mac.get().update(data.nioBuffer());
void update(ByteBuf data, int offset, int len) {
mac.get().update(data.slice(offset, len).nioBuffer());
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public long getValueAndReset() {
}

@Override
public void update(ByteBuf buf) {
crc.update(buf.nioBuffer());
public void update(ByteBuf buf, int offset, int len) {
crc.update(buf.slice(offset, len).nioBuffer());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ public static int computeChecksum(ByteBuf payload) {
* @param payload
* @return
*/
public static int resumeChecksum(int previousChecksum, ByteBuf payload) {
return CRC32C_HASH.resume(previousChecksum, payload);
public static int resumeChecksum(int previousChecksum, ByteBuf payload, int offset, int len) {
return CRC32C_HASH.resume(previousChecksum, payload, offset, len);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,10 @@

public interface IntHash {
int calculate(ByteBuf buffer);

int calculate(ByteBuf buffer, int offset, int len);

int resume(int current, ByteBuf buffer);

int resume(int current, ByteBuf buffer, int offset, int len);
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,21 @@ public int calculate(ByteBuf buffer) {

@Override
public int resume(int current, ByteBuf buffer) {
return resume(current, buffer, buffer.readerIndex(), buffer.readableBytes());
}

@Override
public int calculate(ByteBuf buffer, int offset, int len) {
return resume(0, buffer, offset, len);
}

@Override
public int resume(int current, ByteBuf buffer, int offset, int len) {
if (buffer.hasArray()) {
return hash.resume(current, buffer.array(), buffer.arrayOffset() + buffer.readerIndex(),
buffer.readableBytes());
return hash.resume(current, buffer.array(), buffer.arrayOffset() + offset,
len);
} else {
return hash.resume(current, buffer.nioBuffer());
return hash.resume(current, buffer.slice(offset, len).nioBuffer());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ public int calculate(ByteBuf buffer) {
return resume(0, buffer);
}

@Override
public int calculate(ByteBuf buffer, int offset, int len) {
return resume(0, buffer, offset, len);
}

private int resume(int current, long address, int offset, int length) {
try {
return (int) UPDATE_DIRECT_BYTEBUFFER.invoke(null, current, address, offset, offset + length);
Expand All @@ -89,19 +94,24 @@ private int resume(int current, byte[] array, int offset, int length) {

@Override
public int resume(int current, ByteBuf buffer) {
return resume(current, buffer, buffer.readerIndex(), buffer.readableBytes());
}

@Override
public int resume(int current, ByteBuf buffer, int offset, int len) {
int negCrc = ~current;

if (buffer.hasMemoryAddress()) {
negCrc = resume(negCrc, buffer.memoryAddress(), buffer.readerIndex(), buffer.readableBytes());
negCrc = resume(negCrc, buffer.memoryAddress(), offset, len);
} else if (buffer.hasArray()) {
int offset = buffer.arrayOffset() + buffer.readerIndex();
negCrc = resume(negCrc, buffer.array(), offset, buffer.readableBytes());
int arrayOffset = buffer.arrayOffset() + offset;
negCrc = resume(negCrc, buffer.array(), arrayOffset, len);
} else {
byte[] b = TL_BUFFER.get();
int toRead = buffer.readableBytes();
int toRead = len;
while (toRead > 0) {
int length = Math.min(toRead, b.length);
buffer.readBytes(b, 0, length);
buffer.slice(offset, len).readBytes(b, 0, length);
negCrc = resume(negCrc, b, 0, length);
toRead -= length;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,27 @@ public class JniIntHash implements IntHash {

@Override
public int calculate(ByteBuf buffer) {
return resume(0, buffer);
return calculate(buffer, buffer.readerIndex(), buffer.readableBytes());
}

@Override
public int resume(int current, ByteBuf buffer) {
return resume(current, buffer, buffer.readerIndex(), buffer.readableBytes());
}

@Override
public int calculate(ByteBuf buffer, int offset, int len) {
return resume(0, buffer, offset, len);
}

@Override
public int resume(int current, ByteBuf buffer, int offset, int len) {
if (buffer.hasMemoryAddress()) {
return hash.resume(current, buffer.memoryAddress() + buffer.readerIndex(),
buffer.readableBytes());
return hash.resume(current, buffer.memoryAddress() + offset, len);
} else if (buffer.hasArray()) {
return hash.resume(current, buffer.array(), buffer.arrayOffset() + buffer.readerIndex(),
buffer.readableBytes());
return hash.resume(current, buffer.array(), buffer.arrayOffset() + offset, len);
} else {
return hash.resume(current, buffer.nioBuffer());
return hash.resume(current, buffer.slice(offset, len).nioBuffer());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public void testCrc32cValue() {
@Test
public void testCrc32cValueResume() {
final byte[] bytes = "Some String".getBytes();
int checksum = Crc32cIntChecksum.resumeChecksum(0, Unpooled.wrappedBuffer(bytes));
int checksum = Crc32cIntChecksum.resumeChecksum(0, Unpooled.wrappedBuffer(bytes), 0, bytes.length);

assertEquals(608512271, checksum);
}
Expand All @@ -58,19 +58,19 @@ public void testCrc32cValueIncremental() {

checksum = Crc32cIntChecksum.computeChecksum(Unpooled.wrappedBuffer(bytes, 0, 1));
for (int i = 1; i < bytes.length; i++) {
checksum = Crc32cIntChecksum.resumeChecksum(checksum, Unpooled.wrappedBuffer(bytes, i, 1));
checksum = Crc32cIntChecksum.resumeChecksum(checksum, Unpooled.wrappedBuffer(bytes), i, 1);
}
assertEquals(608512271, checksum);

checksum = Crc32cIntChecksum.computeChecksum(Unpooled.wrappedBuffer(bytes, 0, 4));
checksum = Crc32cIntChecksum.resumeChecksum(checksum, Unpooled.wrappedBuffer(bytes, 4, 7));
checksum = Crc32cIntChecksum.resumeChecksum(checksum, Unpooled.wrappedBuffer(bytes), 4, 7);
assertEquals(608512271, checksum);


ByteBuf buffer = Unpooled.wrappedBuffer(bytes, 0, 4);
checksum = Crc32cIntChecksum.computeChecksum(buffer);
checksum = Crc32cIntChecksum.resumeChecksum(
checksum, Unpooled.wrappedBuffer(bytes, 4, bytes.length - 4));
checksum, Unpooled.wrappedBuffer(bytes), 4, bytes.length - 4);

assertEquals(608512271, checksum);
}
Expand All @@ -86,7 +86,7 @@ public void testCrc32cLongValue() {
@Test
public void testCrc32cLongValueResume() {
final byte[] bytes = "Some String".getBytes();
long checksum = Crc32cIntChecksum.resumeChecksum(0, Unpooled.wrappedBuffer(bytes));
long checksum = Crc32cIntChecksum.resumeChecksum(0, Unpooled.wrappedBuffer(bytes), 0, bytes.length);

assertEquals(608512271L, checksum);
}
Expand Down
Loading

0 comments on commit 1f8de8f

Please sign in to comment.