Skip to content

Commit

Permalink
Reduce Mem Copy in client
Browse files Browse the repository at this point in the history
pr-link: #13168
change-id: cid-56b6b88ec1b44e389e46248e5a6ac12a9c956539
  • Loading branch information
apc999 committed Apr 14, 2021
1 parent 4c4ec92 commit 32f64b6
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 54 deletions.
Expand Up @@ -40,6 +40,7 @@

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;

import javax.annotation.concurrent.NotThreadSafe;

Expand Down Expand Up @@ -106,18 +107,17 @@ public static BlockInStream create(FileSystemContext context, BlockInfo info,
long blockId = info.getBlockId();
long blockSize = info.getLength();

AlluxioConfiguration alluxioConf = context.getClusterConf();
boolean shortCircuit = alluxioConf.getBoolean(PropertyKey.USER_SHORT_CIRCUIT_ENABLED);
boolean shortCircuitPreferred =
alluxioConf.getBoolean(PropertyKey.USER_SHORT_CIRCUIT_PREFERRED);
boolean sourceSupportsDomainSocket = NettyUtils.isDomainSocketSupported(dataSource);

if (dataSourceType == BlockInStreamSource.PROCESS_LOCAL) {
// Interaction between the current client and the worker it is embedded in should
// go through worker-internal communication directly, without involving RPC
return createProcessLocalBlockInStream(context, dataSource, blockId, blockSize, options);
}

AlluxioConfiguration alluxioConf = context.getClusterConf();
boolean shortCircuit = alluxioConf.getBoolean(PropertyKey.USER_SHORT_CIRCUIT_ENABLED);
boolean shortCircuitPreferred =
alluxioConf.getBoolean(PropertyKey.USER_SHORT_CIRCUIT_PREFERRED);
boolean sourceSupportsDomainSocket = NettyUtils.isDomainSocketSupported(dataSource);
boolean sourceIsLocal = dataSourceType == BlockInStreamSource.NODE_LOCAL;

// Short circuit is enabled when
Expand Down Expand Up @@ -291,10 +291,23 @@ public int read(byte[] b) throws IOException {

/**
 * Reads up to {@code len} bytes into {@code b}, starting at index {@code off}.
 *
 * The array is wrapped in a {@link ByteBuffer} (no copy) and the read is delegated to
 * {@code read(ByteBuffer, int, int)}.
 *
 * @param b the destination array; must not be null
 * @param off the start offset in {@code b}; must be non-negative
 * @param len the maximum number of bytes to read; must be non-negative, and
 *        {@code off + len} must not exceed {@code b.length}
 * @return the value returned by the {@code ByteBuffer} overload for the wrapped array
 * @throws IOException if the delegated read fails
 */
@Override
public int read(byte[] b, int off, int len) throws IOException {
// NOTE(review): this text comes from a diff view. checkIfClosed() and the off/len range check
// are also performed by read(ByteBuffer, int, int), so they look like pre-change lines shown
// next to their replacement - confirm which lines belong to the current method.
checkIfClosed();
Preconditions.checkArgument(b != null, PreconditionMessage.ERR_READ_BUFFER_NULL);
Preconditions.checkArgument(off >= 0 && len >= 0 && len + off <= b.length,
PreconditionMessage.ERR_BUFFER_STATE.toString(), b.length, off, len);
// ByteBuffer.wrap shares the array, so bytes written to the buffer land directly in b.
return read(ByteBuffer.wrap(b), off, len);
}

/**
* Reads up to len bytes of data from the input stream into the byte buffer.
*
* @param byteBuffer the buffer into which the data is read
* @param off the start offset in the buffer at which the data is written
* @param len the maximum number of bytes to read
* @return the total number of bytes read into the buffer, or -1 if there is no more data because
* the end of the stream has been reached
*/
public int read(ByteBuffer byteBuffer, int off, int len) throws IOException {
Preconditions.checkArgument(off >= 0 && len >= 0 && len + off <= byteBuffer.capacity(),
PreconditionMessage.ERR_BUFFER_STATE.toString(), byteBuffer.capacity(), off, len);
checkIfClosed();
if (len == 0) {
return 0;
}
Expand All @@ -310,13 +323,14 @@ public int read(byte[] b, int off, int len) throws IOException {
return -1;
}
int toRead = Math.min(len, mCurrentChunk.readableBytes());
byteBuffer.position(off).limit(off + toRead);
if (mDetailedMetricsEnabled) {
try (Timer.Context ctx = MetricsSystem
.timer(MetricKey.CLIENT_BLOCK_READ_FROM_CHUNK.getName()).time()) {
mCurrentChunk.readBytes(b, off, toRead);
mCurrentChunk.readBytes(byteBuffer);
}
} else {
mCurrentChunk.readBytes(b, off, toRead);
mCurrentChunk.readBytes(byteBuffer);
}
mPos += toRead;
return toRead;
Expand Down
Expand Up @@ -11,9 +11,9 @@

package alluxio.client.block.stream;

import alluxio.client.ReadType;
import alluxio.client.file.FileSystemContext;
import alluxio.client.file.options.InStreamOptions;
import alluxio.grpc.ReadPType;
import alluxio.metrics.MetricKey;
import alluxio.metrics.MetricsSystem;
import alluxio.network.protocol.databuffer.DataBuffer;
Expand Down Expand Up @@ -118,7 +118,7 @@ public Factory(FileSystemContext context, long blockId,
mBlockId = blockId;
mChunkSize = chunkSize;
mClosed = false;
mIsPromote = ReadType.fromProto(options.getOptions().getReadType()).isPromote();
mIsPromote = options.getOptions().getReadType() == ReadPType.CACHE_PROMOTE;
mIsPositionShort = options.getPositionShort();
mOpenUfsBlockOptions = options.getOpenUfsBlockOptions(blockId);
mBlockWorker = context.getProcessLocalWorker();
Expand Down Expand Up @@ -146,7 +146,9 @@ public void close() throws IOException {
if (mClosed) {
return;
}
mReader.close();
if (mReader != null) {
mReader.close();
}
mClosed = true;
}
}
Expand Down
Expand Up @@ -37,6 +37,7 @@
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -170,8 +171,21 @@ public int read(byte[] b) throws IOException {
/**
 * Reads up to {@code len} bytes into {@code b}, starting at index {@code off}.
 *
 * The array is wrapped in a {@link ByteBuffer} (no copy) and the read is delegated to
 * {@code read(ByteBuffer, int, int)}.
 *
 * @param b the destination array; must not be null
 * @param off the start offset in {@code b}; must be non-negative
 * @param len the maximum number of bytes to read; must be non-negative, and
 *        {@code off + len} must not exceed {@code b.length}
 * @return the value returned by the {@code ByteBuffer} overload for the wrapped array
 * @throws IOException if the delegated read fails
 */
@Override
public int read(byte[] b, int off, int len) throws IOException {
Preconditions.checkArgument(b != null, PreconditionMessage.ERR_READ_BUFFER_NULL);
// NOTE(review): this text comes from a diff view. The same off/len range check is performed
// against the buffer capacity by read(ByteBuffer, int, int), so the check below looks like a
// pre-change line shown next to its replacement - confirm.
Preconditions.checkArgument(off >= 0 && len >= 0 && len + off <= b.length,
PreconditionMessage.ERR_BUFFER_STATE.toString(), b.length, off, len);
// ByteBuffer.wrap shares the array, so bytes written to the buffer land directly in b.
return read(ByteBuffer.wrap(b), off, len);
}

/**
* Reads up to len bytes of data from the input stream into the byte buffer.
*
* @param byteBuffer the buffer into which the data is read
* @param off the start offset in the buffer at which the data is written
* @param len the maximum number of bytes to read
* @return the total number of bytes read into the buffer, or -1 if there is no more data because
* the end of the stream has been reached
*/
public int read(ByteBuffer byteBuffer, int off, int len) throws IOException {
Preconditions.checkArgument(off >= 0 && len >= 0 && len + off <= byteBuffer.capacity(),
PreconditionMessage.ERR_BUFFER_STATE.toString(), byteBuffer.capacity(), off, len);
if (len == 0) {
return 0;
}
Expand All @@ -186,7 +200,7 @@ public int read(byte[] b, int off, int len) throws IOException {
while (bytesLeft > 0 && mPosition != mLength && retry.attempt()) {
try {
updateStream();
int bytesRead = mBlockInStream.read(b, currentOffset, bytesLeft);
int bytesRead = mBlockInStream.read(byteBuffer, currentOffset, bytesLeft);
if (bytesRead > 0) {
bytesLeft -= bytesRead;
currentOffset += bytesRead;
Expand Down Expand Up @@ -353,10 +367,14 @@ private void updateStream() throws IOException {
boolean isBlockInfoOutdated = true;
// blockInfo is "outdated" when all the locations in that blockInfo are failed workers;
// if at least one location is not a failed worker, then it's not outdated.
for (BlockLocation location : blockInfo.getLocations()) {
if (!mFailedWorkers.containsKey(location.getWorkerAddress())) {
isBlockInfoOutdated = false;
break;
if (mFailedWorkers.isEmpty() || mFailedWorkers.size() < blockInfo.getLocations().size()) {
isBlockInfoOutdated = false;
} else {
for (BlockLocation location : blockInfo.getLocations()) {
if (!mFailedWorkers.containsKey(location.getWorkerAddress())) {
isBlockInfoOutdated = false;
break;
}
}
}
if (isBlockInfoOutdated) {
Expand Down
Expand Up @@ -16,6 +16,7 @@
import alluxio.wire.WorkerNetAddress;

import java.io.IOException;
import java.nio.ByteBuffer;

/**
* A {@link BlockInStream} which reads from the given byte array. The stream is able to track how
Expand All @@ -35,8 +36,8 @@ public TestBlockInStream(byte[] data, long id, long length, boolean shortCircuit
}

@Override
public int read(byte[] b, int off, int len) throws IOException {
int bytesRead = super.read(b, off, len);
public int read(ByteBuffer byteBuffer, int off, int len) throws IOException {
int bytesRead = super.read(byteBuffer, off, len);
if (bytesRead <= 0) {
return bytesRead;
}
Expand Down
Expand Up @@ -59,6 +59,7 @@
import org.powermock.modules.junit4.PowerMockRunnerDelegate;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand Down Expand Up @@ -590,7 +591,7 @@ public void testSkip() throws IOException {
*/
@Test
public void failGetInStream() throws IOException {
when(mBlockStore.getInStream(anyLong(), any(InStreamOptions.class), any()))
when(mBlockStore.getInStream(any(BlockInfo.class), any(InStreamOptions.class), any()))
.thenThrow(new UnavailableException("test exception"));
try {
mTestStream.read();
Expand Down Expand Up @@ -695,7 +696,7 @@ public void readOneRetry() throws Exception {
TestBlockInStream workingStream = mInStreams.get(0);
TestBlockInStream brokenStream = mock(TestBlockInStream.class);
when(mBlockStore
.getInStream(eq(0L), any(InStreamOptions.class), any()))
.getInStream(any(BlockInfo.class), any(InStreamOptions.class), any()))
.thenReturn(brokenStream).thenReturn(workingStream);
when(brokenStream.read()).thenThrow(new UnavailableException("test exception"));
when(brokenStream.getPos()).thenReturn(offset);
Expand All @@ -704,8 +705,7 @@ public void readOneRetry() throws Exception {
int b = mTestStream.read();

doReturn(0).when(brokenStream).read();
verify(brokenStream, times(1))
.read();
verify(brokenStream, times(1)).read();
assertEquals(offset, b);
}

Expand All @@ -714,19 +714,19 @@ public void readBufferRetry() throws Exception {
TestBlockInStream workingStream = mInStreams.get(0);
TestBlockInStream brokenStream = mock(TestBlockInStream.class);
when(mBlockStore
.getInStream(eq(0L), any(InStreamOptions.class), any()))
.getInStream(any(BlockInfo.class), any(InStreamOptions.class), any()))
.thenReturn(brokenStream).thenReturn(workingStream);
when(brokenStream.read(any(byte[].class), anyInt(), anyInt()))
when(brokenStream.read(any(ByteBuffer.class), anyInt(), anyInt()))
.thenThrow(new UnavailableException("test exception"));
when(brokenStream.getPos()).thenReturn(BLOCK_LENGTH / 2);

mTestStream.seek(BLOCK_LENGTH / 2);
byte[] b = new byte[(int) BLOCK_LENGTH * 2];
mTestStream.read(b, 0, b.length);

doReturn(0).when(brokenStream).read(any(byte[].class), anyInt(), anyInt());
doReturn(0).when(brokenStream).read(any(ByteBuffer.class), anyInt(), anyInt());
verify(brokenStream, times(1))
.read(any(byte[].class), anyInt(), anyInt());
.read(any(ByteBuffer.class), anyInt(), anyInt());
assertArrayEquals(BufferUtils.getIncreasingByteArray((int) BLOCK_LENGTH / 2, (int)
BLOCK_LENGTH * 2), b);
}
Expand Down Expand Up @@ -759,7 +759,7 @@ public void positionedReadRetry() throws Exception {
*/
@Test
public void blockInStreamOutOfSync() throws Exception {
when(mBlockStore.getInStream(anyLong(), any(InStreamOptions.class), any()))
when(mBlockStore.getInStream(any(BlockInfo.class), any(InStreamOptions.class), any()))
.thenAnswer(new Answer<BlockInStream>() {
@Override
public BlockInStream answer(InvocationOnMock invocation) throws Throwable {
Expand Down
Expand Up @@ -68,7 +68,14 @@ public void readBytes(OutputStream outputStream, int length) throws IOException

/**
 * Copies bytes from {@code mBuffer} into {@code outputBuf}, limited to the space remaining in
 * {@code outputBuf} when {@code mBuffer} holds more than fits.
 *
 * @param outputBuf the destination buffer
 */
@Override
public void readBytes(ByteBuffer outputBuf) {
// NOTE(review): this unconditional put runs before the size check below. If mBuffer is a
// java.nio.ByteBuffer (presumably - confirm), put(ByteBuffer) throws BufferOverflowException
// when the source has more remaining bytes than the destination, which is exactly the case the
// else-branch handles. It looks like the pre-change line of the diff shown next to its
// replacement - confirm it is not part of the current method.
outputBuf.put(mBuffer);
if (mBuffer.remaining() <= outputBuf.remaining()) {
// Everything left in mBuffer fits into outputBuf: copy it all.
outputBuf.put(mBuffer);
} else {
// Only part fits: temporarily lower mBuffer's limit so that put() transfers exactly
// outputBuf.remaining() bytes, then restore the limit so the rest stays readable.
int oldLimit = mBuffer.limit();
mBuffer.limit(mBuffer.position() + outputBuf.remaining());
outputBuf.put(mBuffer);
mBuffer.limit(oldLimit);
}
}

@Override
Expand Down
Expand Up @@ -35,7 +35,7 @@ public BlockReader() {
* Reads data from the block.
*
* @param offset the offset from starting of the block file in bytes
* @param length the length of data to read in bytes, -1 for the rest of the block
* @param length the length of data to read in bytes
* @return {@link ByteBuffer} the data that was read
*/
public abstract ByteBuffer read(long offset, long length) throws IOException;
Expand Down
Expand Up @@ -91,10 +91,6 @@ public String getFilePath() {
/**
 * Maps a region of the local block file read-only and returns it as a {@link ByteBuffer}.
 *
 * @param offset the offset from the start of the file, in bytes
 * @param length the number of bytes to map; {@code -1} is replaced by the number of bytes
 *        from {@code offset} to the end of the file
 * @return a read-only mapping of the requested region
 * @throws IOException declared by the signature; presumably raised when the mapping fails
 */
public ByteBuffer read(long offset, long length) throws IOException {
// Reject requests that would run past the end of the file.
Preconditions.checkArgument(offset + length <= mFileSize,
"offset=%s, length=%s, exceeding fileSize=%s", offset, length, mFileSize);
// TODO(calvin): May need to make sure length is an int.
// NOTE(review): this text comes from a diff view. The -1 handling below looks like pre-change
// code (the BlockReader javadoc in the same commit drops the "-1 for the rest of the block"
// wording) - confirm whether it is still part of the current method.
if (length == -1L) {
length = mFileSize - offset;
}
return mLocalFileChannel.map(FileChannel.MapMode.READ_ONLY, offset, length);
}

Expand Down
Expand Up @@ -101,11 +101,6 @@ public void read() throws Exception {
// Read entire block by setting the length to be block size.
buffer = mReader.read(0, TEST_BLOCK_SIZE);
Assert.assertTrue(BufferUtils.equalIncreasingByteBuffer(0, (int) TEST_BLOCK_SIZE, buffer));

// Read entire block by setting the length to be -1
int length = -1;
buffer = mReader.read(0, length);
Assert.assertTrue(BufferUtils.equalIncreasingByteBuffer(0, (int) TEST_BLOCK_SIZE, buffer));
}

/**
Expand Down
Expand Up @@ -12,6 +12,7 @@
package alluxio.fuse;

import alluxio.AlluxioURI;
import alluxio.client.file.AlluxioFileInStream;
import alluxio.client.file.FileInStream;
import alluxio.client.file.FileOutStream;
import alluxio.client.file.FileSystem;
Expand Down Expand Up @@ -355,9 +356,10 @@ public int read(String path, ByteBuffer buf, long size, long offset, FuseFileInf

private int readInternal(String path, ByteBuffer buf, long size, long offset, FuseFileInfo fi) {
MetricsSystem.counter(MetricKey.FUSE_BYTES_TO_READ.getName()).inc(size);
final int sz = (int) size;
int nread = 0;
int rd = 0;
long fd = fi.fh.get();
Long fd = fi.fh.get();
try {
FileInStream is = mOpenFileEntries.get(fd);
if (is == null) {
Expand All @@ -372,19 +374,12 @@ private int readInternal(String path, ByteBuffer buf, long size, long offset, Fu
}
if (offset - is.getPos() < is.remaining()) {
is.seek(offset);
final int sz = (int) size;
final byte[] dest = new byte[sz];
while (rd >= 0 && nread < sz) {
rd = is.read(dest, nread, sz - nread);
rd = ((AlluxioFileInStream) is).read(buf, nread, sz - nread);
if (rd >= 0) {
nread += rd;
}
}
if (nread == -1) { // EOF
nread = 0;
} else if (nread > 0) {
buf.put(dest, 0, nread);
}
}
}
} catch (Throwable e) {
Expand Down
11 changes: 9 additions & 2 deletions integration/fuse/src/main/java/alluxio/fuse/StackFS.java
Expand Up @@ -168,11 +168,18 @@ public int read(String path, ByteBuffer buf, long size, long offset, FuseFileInf
private int readInternal(String path, ByteBuffer buf, long size, long offset, FuseFileInfo fi) {
MetricsSystem.counter("Stackfs.BytesToRead").inc(size);
path = transformPath(path);
final int sz = (int) size;
int nread = 0;
byte[] tmpbuf = new byte[sz];
try (FileInputStream fis = new FileInputStream(path)) {
byte[] tmpbuf = new byte[(int) size];
long nskipped = fis.skip(offset);
nread = fis.read(tmpbuf, 0, (int) size);
int rd = 0;
while (rd >= 0 && nread < sz) {
rd = fis.read(tmpbuf, nread, sz - nread);
if (rd >= 0) {
nread += rd;
}
}
buf.put(tmpbuf, 0, nread);
} catch (IndexOutOfBoundsException e) {
return 0;
Expand Down

0 comments on commit 32f64b6

Please sign in to comment.