Skip to content

Commit

Permalink
Add AbstractFileRegion
Browse files Browse the repository at this point in the history
  • Loading branch information
zsxwing committed Dec 9, 2017
1 parent 93632a0 commit 96df5f2
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 135 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.*;
import io.netty.util.AbstractReferenceCounted;
import org.apache.commons.crypto.stream.CryptoInputStream;
import org.apache.commons.crypto.stream.CryptoOutputStream;

import org.apache.spark.network.util.AbstractFileRegion;
import org.apache.spark.network.util.ByteArrayReadableChannel;
import org.apache.spark.network.util.ByteArrayWritableChannel;

Expand Down Expand Up @@ -161,7 +161,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}
}

private static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
private static class EncryptedMessage extends AbstractFileRegion {
private final boolean isByteBuf;
private final ByteBuf buf;
private final FileRegion region;
Expand Down Expand Up @@ -198,27 +198,14 @@ public long position() {
return 0;
}

@Override
public long transfered() {
return transferred;
}

@Override
public long transferred() {
return transferred;
}

/**
* Override this due to different return types of ReferenceCounted.touch and FileRegion.touch.
*/
@Override
public EncryptedMessage touch() {
super.touch();
return this;
}

@Override
public EncryptedMessage touch(Object o) {
super.touch(o);
if (region != null) {
region.touch(o);
}
Expand All @@ -228,15 +215,6 @@ public EncryptedMessage touch(Object o) {
return this;
}

/**
* Override this due to different return types of ReferenceCounted.touch and FileRegion.touch.
*/
@Override
public EncryptedMessage retain() {
super.retain();
return this;
}

@Override
public EncryptedMessage retain(int increment) {
super.retain(increment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@
import com.google.common.base.Preconditions;
import io.netty.buffer.ByteBuf;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.ReferenceCountUtil;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.util.AbstractFileRegion;

/**
* A wrapper message that holds two separate pieces (a header and a body).
*
* The header must be a ByteBuf, while the body can be a ByteBuf or a FileRegion.
*/
class MessageWithHeader extends AbstractReferenceCounted implements FileRegion {
class MessageWithHeader extends AbstractFileRegion {

@Nullable private final ManagedBuffer managedBuffer;
private final ByteBuf header;
Expand Down Expand Up @@ -90,11 +90,6 @@ public long position() {
return 0;
}

@Override
public long transfered() {
return totalBytesTransferred;
}

@Override
public long transferred() {
return totalBytesTransferred;
Expand Down Expand Up @@ -166,27 +161,14 @@ private int writeNioBuffer(
return ret;
}

/** Override this due to different return types of ReferenceCounted.touch and FileRegion.touch. */
@Override
public MessageWithHeader touch() {
super.touch();
return this;
}

@Override
public MessageWithHeader touch(Object o) {
super.touch(o);
header.touch(o);
ReferenceCountUtil.touch(body, o);
return this;
}

/** Override this due to different return types of ReferenceCounted.touch and FileRegion.touch. */
@Override
public MessageWithHeader retain() {
super.retain();
return this;
}

@Override
public MessageWithHeader retain(int increment) {
super.retain(increment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
import io.netty.channel.ChannelPromise;
import io.netty.channel.FileRegion;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.AbstractReferenceCounted;

import org.apache.spark.network.util.AbstractFileRegion;
import org.apache.spark.network.util.ByteArrayWritableChannel;
import org.apache.spark.network.util.NettyUtils;

Expand Down Expand Up @@ -129,7 +129,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List<Object> out)
}

@VisibleForTesting
static class EncryptedMessage extends AbstractReferenceCounted implements FileRegion {
static class EncryptedMessage extends AbstractFileRegion {

private final SaslEncryptionBackend backend;
private final boolean isByteBuf;
Expand Down Expand Up @@ -182,27 +182,14 @@ public long position() {
/**
* Returns an approximation of the amount of data transferred. See {@link #count()}.
*/
@Override
public long transfered() {
return transferred;
}

@Override
public long transferred() {
return transferred;
}

/**
* Override this due to different return types of ReferenceCounted.touch and FileRegion.touch.
*/
@Override
public EncryptedMessage touch() {
super.touch();
return this;
}

@Override
public EncryptedMessage touch(Object o) {
super.touch(o);
if (buf != null) {
buf.touch(o);
}
Expand All @@ -212,15 +199,6 @@ public EncryptedMessage touch(Object o) {
return this;
}

/**
* Override this due to different return types of ReferenceCounted.retain and FileRegion.retain.
*/
@Override
public EncryptedMessage retain() {
super.retain();
return this;
}

@Override
public EncryptedMessage retain(int increment) {
super.retain(increment);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network.util;

import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;

public abstract class AbstractFileRegion extends AbstractReferenceCounted implements FileRegion {

@Override
@SuppressWarnings("deprecation")
public final long transfered() {
return transferred();
}

@Override
public AbstractFileRegion retain() {
super.retain();
return this;
}

@Override
public AbstractFileRegion retain(int increment) {
super.retain(increment);
return this;
}

@Override
public AbstractFileRegion touch() {
super.touch();
return this;
}

@Override
public AbstractFileRegion touch(Object o) {
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.FileRegion;
import io.netty.util.AbstractReferenceCounted;
import org.apache.spark.network.util.AbstractFileRegion;
import org.junit.Test;
import org.mockito.Mockito;

Expand Down Expand Up @@ -108,7 +107,7 @@ private ByteBuf doWrite(MessageWithHeader msg, int minExpectedWrites) throws Exc
return Unpooled.wrappedBuffer(channel.getData());
}

private static class TestFileRegion extends AbstractReferenceCounted implements FileRegion {
private static class TestFileRegion extends AbstractFileRegion {

private final int writeCount;
private final int writesPerCall;
Expand All @@ -129,39 +128,11 @@ public long position() {
return 0;
}

@Override
public long transfered() {
return 8 * written;
}

@Override
public long transferred() {
return 8 * written;
}

@Override
public TestFileRegion touch() {
super.touch();
return this;
}

@Override
public TestFileRegion touch(Object o) {
return this;
}

@Override
public TestFileRegion retain() {
super.retain();
return this;
}

@Override
public TestFileRegion retain(int increment) {
super.retain(increment);
return this;
}

@Override
public long transferTo(WritableByteChannel target, long position) throws IOException {
for (int i = 0; i < writesPerCall; i++) {
Expand Down
36 changes: 3 additions & 33 deletions core/src/main/scala/org/apache/spark/storage/DiskStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,11 @@ import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable.ListBuffer

import com.google.common.io.Closeables
import io.netty.channel.{DefaultFileRegion, FileRegion}
import io.netty.util.AbstractReferenceCounted
import io.netty.channel.DefaultFileRegion

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.network.util.{AbstractFileRegion, JavaUtils}
import org.apache.spark.security.CryptoStreamUtils
import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBuffer
Expand Down Expand Up @@ -266,7 +265,7 @@ private class EncryptedBlockData(
}

private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize: Long)
extends AbstractReferenceCounted with FileRegion {
extends AbstractFileRegion {

private var _transferred = 0L

Expand All @@ -277,37 +276,8 @@ private class ReadableChannelFileRegion(source: ReadableByteChannel, blockSize:

override def position(): Long = 0

override def transfered(): Long = _transferred

override def transferred(): Long = _transferred

/**
* Override this due to different return types of ReferenceCounted.touch and FileRegion.touch.
*/
override def touch(): this.type = {
super.touch()
this
}

override def touch(o: Object): this.type = {
this
}

/**
* Override this due to different return types of ReferenceCounted.retain and FileRegion.retain.
*/
override def retain(): this.type = {
super.retain()
this
}

override def retain(increment: Int): this.type = {
super.retain(increment)
this
}

override def release(decrement: Int): Boolean = super.release(decrement)

override def transferTo(target: WritableByteChannel, pos: Long): Long = {
assert(pos == transfered(), "Invalid position.")

Expand Down

0 comments on commit 96df5f2

Please sign in to comment.