Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.io;

import org.apache.spark.storage.StorageUtils;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.nio.file.StandardOpenOption;

/**
* {@link InputStream} implementation which uses direct buffer
* to read a file to avoid extra copy of data between Java and
* native memory which happens when using {@link java.io.BufferedInputStream}.
* Unfortunately, this is not something already available in JDK,
* {@link sun.nio.ch.ChannelInputStream} supports reading a file using nio,
* but does not support buffering.
*/
public final class NioBufferedFileInputStream extends InputStream {

private static final int DEFAULT_BUFFER_SIZE_BYTES = 8192;

private final ByteBuffer byteBuffer;

private final FileChannel fileChannel;

public NioBufferedFileInputStream(File file, int bufferSizeInBytes) throws IOException {
byteBuffer = ByteBuffer.allocateDirect(bufferSizeInBytes);
fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ);
byteBuffer.flip();
}

public NioBufferedFileInputStream(File file) throws IOException {
this(file, DEFAULT_BUFFER_SIZE_BYTES);
}

/**
* Checks weather data is left to be read from the input stream.
* @return true if data is left, false otherwise
* @throws IOException
*/
private boolean refill() throws IOException {
if (!byteBuffer.hasRemaining()) {
byteBuffer.clear();
int nRead = 0;
while (nRead == 0) {
nRead = fileChannel.read(byteBuffer);
}
if (nRead < 0) {
return false;
}
byteBuffer.flip();
}
return true;
}

@Override
public synchronized int read() throws IOException {
if (!refill()) {
return -1;
}
return byteBuffer.get() & 0xFF;
}

@Override
public synchronized int read(byte[] b, int offset, int len) throws IOException {
if (offset < 0 || len < 0 || offset + len < 0 || offset + len > b.length) {
throw new IndexOutOfBoundsException();
}
if (!refill()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: please add the defense codes like BufferedInputStream

        if ((off | len | (off + len) | (b.length - (off + len))) < 0) {
            throw new IndexOutOfBoundsException();
        } else if (len == 0) {
            return 0;
        }

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

return -1;
}
len = Math.min(len, byteBuffer.remaining());
byteBuffer.get(b, offset, len);
return len;
}

@Override
public synchronized int available() throws IOException {
return byteBuffer.remaining();
}

@Override
public synchronized long skip(long n) throws IOException {
if (n <= 0L) {
return 0L;
}
if (byteBuffer.remaining() >= n) {
// The buffered content is enough to skip
byteBuffer.position(byteBuffer.position() + (int) n);
return n;
}
long skippedFromBuffer = byteBuffer.remaining();
long toSkipFromFileChannel = n - skippedFromBuffer;
// Discard everything we have read in the buffer.
byteBuffer.position(0);
byteBuffer.flip();
return skippedFromBuffer + skipFromFileChannel(toSkipFromFileChannel);
}

private long skipFromFileChannel(long n) throws IOException {
long currentFilePosition = fileChannel.position();
long size = fileChannel.size();
if (n > size - currentFilePosition) {
fileChannel.position(size);
return size - currentFilePosition;
} else {
fileChannel.position(currentFilePosition + n);
return n;
}
}

@Override
public synchronized void close() throws IOException {
fileChannel.close();
StorageUtils.dispose(byteBuffer);
}

@Override
protected void finalize() throws IOException {
close();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.google.common.io.Closeables;

import org.apache.spark.SparkEnv;
import org.apache.spark.io.NioBufferedFileInputStream;
import org.apache.spark.serializer.SerializerManager;
import org.apache.spark.storage.BlockId;
import org.apache.spark.unsafe.Platform;
Expand Down Expand Up @@ -69,8 +70,8 @@ public UnsafeSorterSpillReader(
bufferSizeBytes = DEFAULT_BUFFER_SIZE_BYTES;
}

final BufferedInputStream bs =
new BufferedInputStream(new FileInputStream(file), (int) bufferSizeBytes);
final InputStream bs =
new NioBufferedFileInputStream(file, (int) bufferSizeBytes);
try {
this.in = serializerManager.wrapStream(blockId, bs);
this.din = new DataInputStream(this.in);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import com.google.common.io.ByteStreams

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
Expand Down Expand Up @@ -89,7 +90,7 @@ private[spark] class IndexShuffleBlockResolver(
val lengths = new Array[Long](blocks)
// Read the lengths of blocks
val in = try {
new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
new DataInputStream(new NioBufferedFileInputStream(index))
} catch {
case e: IOException =>
return null
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.io;

import org.apache.commons.io.FileUtils;
import org.apache.commons.lang3.RandomUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;

import static org.junit.Assert.assertEquals;

/**
* Tests functionality of {@link NioBufferedFileInputStream}
*/
public class NioBufferedFileInputStreamSuite {

private byte[] randomBytes;

private File inputFile;

@Before
public void setUp() throws IOException {
// Create a byte array of size 2 MB with random bytes
randomBytes = RandomUtils.nextBytes(2 * 1024 * 1024);
inputFile = File.createTempFile("temp-file", ".tmp");
FileUtils.writeByteArrayToFile(inputFile, randomBytes);
}

@After
public void tearDown() {
inputFile.delete();
}

@Test
public void testReadOneByte() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
for (int i = 0; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}

@Test
public void testReadMultipleBytes() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
byte[] readBytes = new byte[8 * 1024];
int i = 0;
while (i < randomBytes.length) {
int read = inputStream.read(readBytes, 0, 8 * 1024);
for (int j = 0; j < read; j++) {
assertEquals(randomBytes[i], readBytes[j]);
i++;
}
}
}

@Test
public void testBytesSkipped() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}

@Test
public void testBytesSkippedAfterRead() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}

@Test
public void testNegativeBytesSkippedAfterRead() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
for (int i = 0; i < 1024; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
// Skipping negative bytes should essential be a no-op
assertEquals(0, inputStream.skip(-1));
assertEquals(0, inputStream.skip(-1024));
assertEquals(0, inputStream.skip(Long.MIN_VALUE));
assertEquals(1024, inputStream.skip(1024));
for (int i = 2048; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}

@Test
public void testSkipFromFileChannel() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile, 10);
// Since the buffer is smaller than the skipped bytes, this will guarantee
// we skip from underlying file channel.
assertEquals(1024, inputStream.skip(1024));
for (int i = 1024; i < 2048; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
assertEquals(256, inputStream.skip(256));
assertEquals(256, inputStream.skip(256));
assertEquals(512, inputStream.skip(512));
for (int i = 3072; i < randomBytes.length; i++) {
assertEquals(randomBytes[i], (byte) inputStream.read());
}
}

@Test
public void testBytesSkippedAfterEOF() throws IOException {
InputStream inputStream = new NioBufferedFileInputStream(inputFile);
assertEquals(randomBytes.length, inputStream.skip(randomBytes.length + 1));
assertEquals(-1, inputStream.read());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.io._
import com.google.common.io.Closeables

import org.apache.spark.SparkException
import org.apache.spark.io.NioBufferedFileInputStream
import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.Platform
Expand Down Expand Up @@ -130,7 +131,7 @@ private[python] case class DiskRowQueue(file: File, fields: Int) extends RowQueu
if (out != null) {
out.close()
out = null
in = new DataInputStream(new BufferedInputStream(new FileInputStream(file.toString)))
in = new DataInputStream(new NioBufferedFileInputStream(file))
}

if (unreadBytes > 0) {
Expand Down