Skip to content

Commit

Permalink
try unwrapping memory in compress/decompress
Browse files Browse the repository at this point in the history
  • Loading branch information
rjernst committed Mar 22, 2024
1 parent cf60789 commit 4c9e73d
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 41 deletions.
Expand Up @@ -15,7 +15,7 @@
import java.nio.ByteBuffer;

class JnaCloseableByteBuffer implements CloseableByteBuffer {
private final Memory memory;
final Memory memory;
private final ByteBuffer bufferView;

JnaCloseableByteBuffer(int len) {
Expand Down
Expand Up @@ -9,8 +9,12 @@
package org.elasticsearch.nativeaccess.jna;

import com.sun.jna.Library;
import com.sun.jna.Memory;
import com.sun.jna.Native;

import com.sun.jna.Pointer;

import org.elasticsearch.nativeaccess.CloseableByteBuffer;
import org.elasticsearch.nativeaccess.lib.ZstdLibrary;

import java.nio.ByteBuffer;
Expand All @@ -20,13 +24,13 @@ class JnaZstdLibrary implements ZstdLibrary {
private interface NativeFunctions extends Library {
long ZSTD_compressBound(int scrLen);

long ZSTD_compress(ByteBuffer dst, int dstLen, ByteBuffer src, int srcLen, int compressionLevel);
long ZSTD_compress(Pointer dst, int dstLen, Pointer src, int srcLen, int compressionLevel);

boolean ZSTD_isError(long code);

String ZSTD_getErrorName(long code);

long ZSTD_decompress(ByteBuffer dst, int dstLen, ByteBuffer src, int srcLen);
long ZSTD_decompress(Pointer dst, int dstLen, Pointer src, int srcLen);
}

private final NativeFunctions functions;
Expand All @@ -41,8 +45,12 @@ public long compressBound(int scrLen) {
}

@Override
public long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel) {
return functions.ZSTD_compress(dst, dst.remaining(), src, src.remaining(), compressionLevel);
public long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel) {
assert dst instanceof JnaCloseableByteBuffer;
assert src instanceof JnaCloseableByteBuffer;
var nativeDst = (JnaCloseableByteBuffer) dst;
var nativeSrc = (JnaCloseableByteBuffer) src;
return functions.ZSTD_compress(nativeDst.memory.share(dst.buffer().position()), dst.buffer().remaining(), nativeSrc.memory.share(src.buffer().position()), src.buffer().remaining(), compressionLevel);
}

@Override
Expand All @@ -56,7 +64,11 @@ public String getErrorName(long code) {
}

@Override
public long decompress(ByteBuffer dst, ByteBuffer src) {
return functions.ZSTD_decompress(dst, dst.remaining(), src, src.remaining());
public long decompress(CloseableByteBuffer dst, CloseableByteBuffer src) {
assert dst instanceof JnaCloseableByteBuffer;
assert src instanceof JnaCloseableByteBuffer;
var nativeDst = (JnaCloseableByteBuffer) dst;
var nativeSrc = (JnaCloseableByteBuffer) src;
return functions.ZSTD_decompress(nativeDst.memory.share(dst.buffer().position()), dst.buffer().remaining(), nativeSrc.memory.share(src.buffer().position()), src.buffer().remaining());
}
}
Expand Up @@ -25,13 +25,13 @@ public final class Zstd {
* Compress the content of {@code src} into {@code dst} at compression level {@code level}, and return the number of compressed bytes.
* {@link ByteBuffer#position()} and {@link ByteBuffer#limit()} of both {@link ByteBuffer}s are left unmodified.
*/
public int compress(ByteBuffer dst, ByteBuffer src, int level) {
public int compress(CloseableByteBuffer dst, CloseableByteBuffer src, int level) {
Objects.requireNonNull(dst, "Null destination buffer");
Objects.requireNonNull(src, "Null source buffer");
assert dst.isDirect();
/*assert dst.isDirect();
assert dst.isReadOnly() == false;
assert src.isDirect();
assert src.isReadOnly() == false;
assert src.isReadOnly() == false;*/
long ret = zstdLib.compress(dst, src, level);
if (zstdLib.isError(ret)) {
throw new IllegalArgumentException(zstdLib.getErrorName(ret));
Expand All @@ -45,13 +45,13 @@ public int compress(ByteBuffer dst, ByteBuffer src, int level) {
* Compress the content of {@code src} into {@code dst}, and return the number of decompressed bytes. {@link ByteBuffer#position()} and
* {@link ByteBuffer#limit()} of both {@link ByteBuffer}s are left unmodified.
*/
public int decompress(ByteBuffer dst, ByteBuffer src) {
public int decompress(CloseableByteBuffer dst, CloseableByteBuffer src) {
Objects.requireNonNull(dst, "Null destination buffer");
Objects.requireNonNull(src, "Null source buffer");
assert dst.isDirect();
/*assert dst.isDirect();
assert dst.isReadOnly() == false;
assert src.isDirect();
assert src.isReadOnly() == false;
assert src.isReadOnly() == false;*/
long ret = zstdLib.decompress(dst, src);
if (zstdLib.isError(ret)) {
throw new IllegalArgumentException(zstdLib.getErrorName(ret));
Expand Down
Expand Up @@ -8,17 +8,17 @@

package org.elasticsearch.nativeaccess.lib;

import java.nio.ByteBuffer;
import org.elasticsearch.nativeaccess.CloseableByteBuffer;

public non-sealed interface ZstdLibrary extends NativeLibrary {

long compressBound(int scrLen);

long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel);
long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel);

boolean isError(long code);

String getErrorName(long code);

long decompress(ByteBuffer dst, ByteBuffer src);
long decompress(CloseableByteBuffer dst, CloseableByteBuffer src);
}
Expand Up @@ -11,15 +11,18 @@
import org.elasticsearch.nativeaccess.CloseableByteBuffer;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.ByteBuffer;

class JdkCloseableByteBuffer implements CloseableByteBuffer {
private final Arena arena;
final MemorySegment segment;
private final ByteBuffer bufferView;

JdkCloseableByteBuffer(int len) {
this.arena = Arena.ofShared();
this.bufferView = this.arena.allocate(len).asByteBuffer();
this.segment = arena.allocate(len);
this.bufferView = segment.asByteBuffer();
}

@Override
Expand Down
Expand Up @@ -8,12 +8,12 @@

package org.elasticsearch.nativeaccess.jdk;

import org.elasticsearch.nativeaccess.CloseableByteBuffer;
import org.elasticsearch.nativeaccess.lib.ZstdLibrary;

import java.lang.foreign.FunctionDescriptor;
import java.lang.foreign.MemorySegment;
import java.lang.invoke.MethodHandle;
import java.nio.ByteBuffer;

import static java.lang.foreign.ValueLayout.ADDRESS;
import static java.lang.foreign.ValueLayout.JAVA_BOOLEAN;
Expand Down Expand Up @@ -49,11 +49,15 @@ public long compressBound(int srcLen) {
}

@Override
public long compress(ByteBuffer dst, ByteBuffer src, int compressionLevel) {
var nativeDst = MemorySegment.ofBuffer(dst);
var nativeSrc = MemorySegment.ofBuffer(src);
public long compress(CloseableByteBuffer dst, CloseableByteBuffer src, int compressionLevel) {
assert dst instanceof JdkCloseableByteBuffer;
assert src instanceof JdkCloseableByteBuffer;
var nativeDst = (JdkCloseableByteBuffer) dst;
var nativeSrc = (JdkCloseableByteBuffer) src;
var segmentDst = nativeDst.segment.asSlice(dst.buffer().position(), dst.buffer().remaining());
var segmentSrc = nativeSrc.segment.asSlice(src.buffer().position(), src.buffer().remaining());
try {
return (long) compress$mh.invokeExact(nativeDst, dst.remaining(), nativeSrc, src.remaining(), compressionLevel);
return (long) compress$mh.invokeExact(segmentDst, segmentDst.byteSize(), segmentSrc, segmentSrc.byteSize(), compressionLevel);
} catch (Throwable t) {
throw new AssertionError(t);
}
Expand All @@ -79,11 +83,15 @@ public String getErrorName(long code) {
}

@Override
public long decompress(ByteBuffer dst, ByteBuffer src) {
var nativeDst = MemorySegment.ofBuffer(dst);
var nativeSrc = MemorySegment.ofBuffer(src);
public long decompress(CloseableByteBuffer dst, CloseableByteBuffer src) {
assert dst instanceof JdkCloseableByteBuffer;
assert src instanceof JdkCloseableByteBuffer;
var nativeDst = (JdkCloseableByteBuffer) dst;
var nativeSrc = (JdkCloseableByteBuffer) src;
var segmentDst = nativeDst.segment.asSlice(dst.buffer().position(), dst.buffer().remaining());
var segmentSrc = nativeSrc.segment.asSlice(src.buffer().position(), src.buffer().remaining());
try {
return (long) decompress$mh.invokeExact(nativeDst, dst.remaining(), nativeSrc, src.remaining());
return (long) decompress$mh.invokeExact(segmentDst, segmentDst.byteSize(), segmentSrc, segmentSrc.byteSize());
} catch (Throwable t) {
throw new AssertionError(t);
}
Expand Down
Expand Up @@ -41,16 +41,16 @@ public void testCompressValidation() {
var srcBuf = src.buffer();
var dstBuf = dst.buffer();

var npe1 = expectThrows(NullPointerException.class, () -> zstd.compress(null, srcBuf, 0));
var npe1 = expectThrows(NullPointerException.class, () -> zstd.compress(null, src, 0));
assertThat(npe1.getMessage(), equalTo("Null destination buffer"));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.compress(dstBuf, null, 0));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.compress(dst, null, 0));
assertThat(npe2.getMessage(), equalTo("Null source buffer"));

// dst capacity too low
for (int i = 0; i < srcBuf.remaining(); ++i) {
srcBuf.put(i, randomByte());
}
var e = expectThrows(IllegalArgumentException.class, () -> zstd.compress(dstBuf, srcBuf, 0));
var e = expectThrows(IllegalArgumentException.class, () -> zstd.compress(dst, src, 0));
assertThat(e.getMessage(), equalTo("Destination buffer is too small"));
}
}
Expand All @@ -64,21 +64,21 @@ public void testDecompressValidation() {
var originalBuf = original.buffer();
var compressedBuf = compressed.buffer();

var npe1 = expectThrows(NullPointerException.class, () -> zstd.decompress(null, originalBuf));
var npe1 = expectThrows(NullPointerException.class, () -> zstd.decompress(null, original));
assertThat(npe1.getMessage(), equalTo("Null destination buffer"));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.decompress(compressedBuf, null));
var npe2 = expectThrows(NullPointerException.class, () -> zstd.decompress(compressed, null));
assertThat(npe2.getMessage(), equalTo("Null source buffer"));

// Invalid compressed format
for (int i = 0; i < originalBuf.remaining(); ++i) {
originalBuf.put(i, (byte) i);
}
var e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(compressedBuf, originalBuf));
var e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(compressed, original));
assertThat(e.getMessage(), equalTo("Unknown frame descriptor"));

int compressedLength = zstd.compress(compressedBuf, originalBuf, 0);
int compressedLength = zstd.compress(compressed, original, 0);
compressedBuf.limit(compressedLength);
e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(restored.buffer(), compressedBuf));
e = expectThrows(IllegalArgumentException.class, () -> zstd.decompress(restored, compressed));
assertThat(e.getMessage(), equalTo("Destination buffer is too small"));

}
Expand Down Expand Up @@ -109,9 +109,9 @@ private void doTestRoundtrip(byte[] data) {
var restored = nativeAccess.newBuffer(data.length)
) {
original.buffer().put(0, data);
int compressedLength = zstd.compress(compressed.buffer(), original.buffer(), randomIntBetween(-3, 9));
int compressedLength = zstd.compress(compressed, original, randomIntBetween(-3, 9));
compressed.buffer().limit(compressedLength);
int decompressedLength = zstd.decompress(restored.buffer(), compressed.buffer());
int decompressedLength = zstd.decompress(restored, compressed);
assertThat(restored.buffer(), equalTo(original.buffer()));
assertThat(decompressedLength, equalTo(data.length));
}
Expand All @@ -127,15 +127,15 @@ private void doTestRoundtrip(byte[] data) {
original.buffer().put(decompressedOffset, data);
original.buffer().position(decompressedOffset);
compressed.buffer().position(compressedOffset);
int compressedLength = zstd.compress(compressed.buffer(), original.buffer(), randomIntBetween(-3, 9));
int compressedLength = zstd.compress(compressed, original, randomIntBetween(-3, 9));
compressed.buffer().limit(compressedOffset + compressedLength);
restored.buffer().position(decompressedOffset);
int decompressedLength = zstd.decompress(restored.buffer(), compressed.buffer());
int decompressedLength = zstd.decompress(restored, compressed);
assertThat(decompressedLength, equalTo(data.length));
assertThat(
restored.buffer().slice(decompressedOffset, data.length),
equalTo(original.buffer().slice(decompressedOffset, data.length))
);
assertThat(decompressedLength, equalTo(data.length));
}
}
}
Expand Up @@ -132,7 +132,7 @@ public void decompress(DataInput in, int originalLength, int offset, int length,
}
src.buffer().flip();

final int decompressedLen = zstd.decompress(dest.buffer(), src.buffer());
final int decompressedLen = zstd.decompress(dest, src);
if (decompressedLen != originalLength) {
throw new CorruptIndexException("Expected " + originalLength + " decompressed bytes, got " + decompressedLen, in);
}
Expand Down Expand Up @@ -183,7 +183,7 @@ public void compress(ByteBuffersDataInput buffersInput, DataOutput out) throws I
}
src.buffer().flip();

final int compressedLen = zstd.compress(dest.buffer(), src.buffer(), level);
final int compressedLen = zstd.compress(dest, src, level);
out.writeVInt(compressedLen);

for (int written = 0; written < compressedLen;) {
Expand Down

0 comments on commit 4c9e73d

Please sign in to comment.