-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
29106a7
commit 49944e2
Showing
3 changed files
with
321 additions
and
0 deletions.
There are no files selected for viewing
57 changes: 57 additions & 0 deletions
57
dora/core/server/worker/src/main/java/alluxio/worker/ucx/AsyncFuture.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
package alluxio.worker.ucx; | ||
|
||
import java.util.concurrent.BlockingQueue; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.LinkedBlockingQueue; | ||
import java.util.concurrent.TimeUnit; | ||
import java.util.concurrent.TimeoutException; | ||
import java.util.concurrent.atomic.AtomicInteger; | ||
|
||
/** | ||
* Collectively wait for a collection of T results | ||
* @param <T> | ||
*/ | ||
public class AsyncFuture<T> { | ||
private AtomicInteger mCompletedCount = new AtomicInteger(0); | ||
private AtomicInteger mTotalExpected = new AtomicInteger(Integer.MAX_VALUE); | ||
private final CompletableFuture<Boolean> mFuture; | ||
|
||
public AsyncFuture() { | ||
mFuture = new CompletableFuture<>(); | ||
} | ||
|
||
public AsyncFuture(int totalExpected) { | ||
mTotalExpected.compareAndSet(Integer.MAX_VALUE, totalExpected); | ||
mFuture = new CompletableFuture<>(); | ||
} | ||
|
||
public void setTotalExpected(int totalExpected) { | ||
mTotalExpected.compareAndSet(Integer.MAX_VALUE, totalExpected); | ||
} | ||
|
||
public void complete(T result) { | ||
if (mCompletedCount.incrementAndGet() >= mTotalExpected.get()) { | ||
mFuture.complete(true); | ||
} | ||
} | ||
|
||
public void fail(Throwable ex) { | ||
mFuture.completeExceptionally(ex); | ||
} | ||
|
||
public boolean get() throws ExecutionException, InterruptedException { | ||
if (mCompletedCount.incrementAndGet() >= mTotalExpected.get()) { | ||
mFuture.complete(true); | ||
} | ||
return mFuture.get(); | ||
} | ||
|
||
public boolean get(long timeout, TimeUnit unit) | ||
throws ExecutionException, InterruptedException, TimeoutException { | ||
if (mCompletedCount.incrementAndGet() >= mTotalExpected.get()) { | ||
mFuture.complete(true); | ||
} | ||
return mFuture.get(timeout, unit); | ||
} | ||
} |
188 changes: 188 additions & 0 deletions
188
dora/core/server/worker/src/main/java/alluxio/worker/ucx/UcxDataReader.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,188 @@ | ||
package alluxio.worker.ucx; | ||
|
||
import alluxio.PositionReader; | ||
import alluxio.file.ByteBufferTargetBuffer; | ||
import alluxio.file.ReadTargetBuffer; | ||
import alluxio.proto.dataserver.Protocol; | ||
import alluxio.ucx.AlluxioUcxUtils; | ||
import alluxio.worker.ucx.UcxConnection; | ||
import alluxio.wire.WorkerNetAddress; | ||
import alluxio.worker.ucx.UcpProxy; | ||
|
||
import com.google.common.base.Preconditions; | ||
import org.openucx.jucx.UcxCallback; | ||
import org.openucx.jucx.UcxException; | ||
import org.openucx.jucx.UcxUtils; | ||
import org.openucx.jucx.ucp.UcpConstants; | ||
import org.openucx.jucx.ucp.UcpEndpoint; | ||
import org.openucx.jucx.ucp.UcpEndpointParams; | ||
import org.openucx.jucx.ucp.UcpMemory; | ||
import org.openucx.jucx.ucp.UcpRequest; | ||
import org.openucx.jucx.ucp.UcpWorker; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.io.IOException; | ||
import java.net.InetAddress; | ||
import java.net.InetSocketAddress; | ||
import java.net.UnknownHostException; | ||
import java.nio.ByteBuffer; | ||
import java.util.LinkedList; | ||
import java.util.Map; | ||
import java.util.TreeMap; | ||
import java.util.function.Supplier; | ||
|
||
public class UcxDataReader implements PositionReader { | ||
private static final Logger LOG = LoggerFactory.getLogger(UcxDataReader.class); | ||
public static final int PAGE_SIZE = 4096; | ||
|
||
InetSocketAddress mAddr; | ||
private static InetSocketAddress sLocalAddr = null; | ||
|
||
UcxConnection mConnection; | ||
// make this a global, one per process only instance | ||
UcpWorker mWorker; | ||
UcpEndpoint mWorkerEndpoint; | ||
|
||
Supplier<Protocol.ReadRequest.Builder> mRequestBuilder; | ||
public UcxDataReader(InetSocketAddress addr, UcpWorker worker, | ||
Protocol.ReadRequest.Builder requestBuilder) { | ||
try { | ||
sLocalAddr = new InetSocketAddress(InetAddress.getLocalHost(),0); | ||
} catch (UnknownHostException e) { | ||
throw new RuntimeException(e); | ||
} | ||
mAddr = addr; | ||
mWorker = worker; | ||
mRequestBuilder = requestBuilder::clone; | ||
} | ||
|
||
|
||
public void acquireServerConn() throws IOException { | ||
try { | ||
mConnection = UcxConnection.initNewConnection(mAddr, mWorker); | ||
} catch (Exception e) { | ||
throw new IOException( | ||
String.format("Error initializing conn with remote:%s", mAddr), e); | ||
} | ||
} | ||
|
||
synchronized public int progressWorker() throws Exception { | ||
return mWorker.progress(); | ||
} | ||
|
||
public void waitForRequest(UcpRequest ucpRequest) { | ||
while(!ucpRequest.isCompleted()) { | ||
try { | ||
progressWorker(); | ||
} catch (Exception e) { | ||
LOG.error("Error progressing req:", e); | ||
} | ||
} | ||
} | ||
|
||
@Override | ||
public int readInternal(long position, ReadTargetBuffer buffer, int length) throws IOException { | ||
// use Stream API | ||
// return readInternalStream(position, buffer, length); | ||
// use RMA API | ||
return readInternalRMA(position, buffer, length); | ||
} | ||
|
||
public int readInternalRMA(long position, ReadTargetBuffer buffer, int length) | ||
throws IOException { | ||
return 0; | ||
} | ||
|
||
public int readInternalStream(long position, ReadTargetBuffer buffer, int length) | ||
throws IOException { | ||
Protocol.ReadRequest.Builder builder = mRequestBuilder.get() | ||
.setLength(length) | ||
.setOffset(position) | ||
.clearCancel(); | ||
Protocol.ReadRequest readRequest = builder.build(); | ||
byte[] serializedBytes = readRequest.toByteArray(); | ||
ByteBuffer buf = ByteBuffer.allocateDirect(PAGE_SIZE); | ||
buf.putInt(serializedBytes.length); | ||
buf.put(serializedBytes); | ||
buf.clear(); | ||
UcpRequest sendRequest = mWorkerEndpoint.sendTaggedNonBlocking( | ||
buf, mConnection.getTagToSend(), new UcxCallback() { | ||
public void onSuccess(UcpRequest request) { | ||
LOG.info("ReadReq:{} sent.", readRequest); | ||
} | ||
|
||
public void onError(int ucsStatus, String errorMsg) { | ||
throw new UcxException(errorMsg); | ||
} | ||
}); | ||
LOG.info("Waiting for read request to send..."); | ||
waitForRequest(sendRequest); | ||
// now wait to recv data | ||
Preconditions.checkArgument(buffer.byteBuffer().isDirect(), "ByteBuffer must be direct buffer"); | ||
int bytesRead = 0; | ||
ByteBuffer preamble = ByteBuffer.allocateDirect(16); | ||
TreeMap<Long, ByteBuffer> buffers = new TreeMap<>(); | ||
preamble.clear(); | ||
LinkedList<UcpRequest> dataUcpRecvReqs = new LinkedList<>(); | ||
while (bytesRead < length) { | ||
UcpRequest recvReq = mWorkerEndpoint.recvStreamNonBlocking(UcxUtils.getAddress(preamble), 16, | ||
UcpConstants.UCP_STREAM_RECV_FLAG_WAITALL, new UcxCallback() { | ||
public void onSuccess(UcpRequest request) {} | ||
|
||
public void onError(int ucsStatus, String errorMsg) { | ||
throw new UcxException(errorMsg); | ||
} | ||
}); | ||
LOG.info("Waiting for preamble..."); | ||
waitForRequest(recvReq); | ||
preamble.clear(); | ||
long seq = preamble.getLong(); | ||
long size = preamble.getLong(); | ||
preamble.clear(); | ||
ByteBuffer seqBuffer = ByteBuffer.allocateDirect(8); | ||
ByteBuffer dataBuffer = ByteBuffer.allocateDirect((int)size); | ||
long[] addrs = new long[2]; | ||
long[] sizes = new long[2]; | ||
addrs[0] = UcxUtils.getAddress(seqBuffer); | ||
addrs[1] = UcxUtils.getAddress(dataBuffer); | ||
sizes[0] = 8; | ||
sizes[1] = size; | ||
LOG.info("preamble info:seq:{}:len:{}", seq, size); | ||
UcpRequest dataRecvReq = mWorkerEndpoint.recvStreamNonBlocking(addrs, sizes, | ||
UcpConstants.UCP_STREAM_RECV_FLAG_WAITALL, new UcxCallback() { | ||
public void onSuccess(UcpRequest request) { | ||
ByteBuffer seqBufView = UcxUtils.getByteBufferView(addrs[0], sizes[0]); | ||
seqBufView.clear(); | ||
long sequence = seqBufView.getLong(); | ||
ByteBuffer dataBufView = UcxUtils.getByteBufferView(addrs[1], sizes[1]); | ||
dataBufView.clear(); | ||
LOG.info("Received buffers, seq:{}, data buf size:{}", sequence, sizes[1]); | ||
buffers.put(sequence, dataBufView); | ||
} | ||
|
||
public void onError(int ucsStatus, String errorMsg) { | ||
LOG.error("Error receiving buffers, seq:{}, data buf size:{}, errorMsg:{}", | ||
seq, size, errorMsg); | ||
throw new UcxException(errorMsg); | ||
} | ||
}); | ||
LOG.info("Offering actual data recReq to q..."); | ||
dataUcpRecvReqs.offer(dataRecvReq); | ||
// waitForRequest(recvReq); | ||
bytesRead += size; | ||
} | ||
while(!dataUcpRecvReqs.isEmpty()) { | ||
UcpRequest nextReq = dataUcpRecvReqs.poll(); | ||
waitForRequest(nextReq); | ||
} | ||
buffer.byteBuffer().clear(); | ||
while (!buffers.isEmpty()) { | ||
Map.Entry<Long, ByteBuffer> entry = buffers.pollFirstEntry(); | ||
LOG.info("Copying seq:{},bufsize:{}", entry.getKey(), entry.getValue()); | ||
entry.getValue().clear(); | ||
buffer.byteBuffer().put(entry.getValue()); | ||
} | ||
return 0; | ||
} | ||
} |
76 changes: 76 additions & 0 deletions
76
dora/core/server/worker/src/test/java/alluxio/worker/ucx/TestUcxConnection.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package alluxio.worker.ucx; | ||
|
||
import alluxio.concurrent.jsr.CompletableFuture; | ||
|
||
import org.junit.BeforeClass; | ||
import org.junit.Test; | ||
import org.openucx.jucx.ucp.UcpConnectionRequest; | ||
import org.openucx.jucx.ucp.UcpContext; | ||
import org.openucx.jucx.ucp.UcpEndpoint; | ||
import org.openucx.jucx.ucp.UcpEndpointParams; | ||
import org.openucx.jucx.ucp.UcpListener; | ||
import org.openucx.jucx.ucp.UcpListenerConnectionHandler; | ||
import org.openucx.jucx.ucp.UcpListenerParams; | ||
import org.openucx.jucx.ucp.UcpParams; | ||
import org.openucx.jucx.ucp.UcpWorker; | ||
import org.openucx.jucx.ucp.UcpWorkerParams; | ||
|
||
import java.net.InetAddress; | ||
import java.net.InetSocketAddress; | ||
import java.net.UnknownHostException; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
public class TestUcxConnection { | ||
public static UcpContext sGlobalContext; | ||
|
||
@BeforeClass | ||
public void initContext() { | ||
sGlobalContext = new UcpContext(new UcpParams() | ||
.requestStreamFeature() | ||
.requestTagFeature() | ||
.requestWakeupFeature()); | ||
} | ||
|
||
|
||
@Test | ||
public void testEstablishConnection() throws Exception { | ||
InetAddress localAddr = InetAddress.getLocalHost(); | ||
int serverPort = 1234; | ||
UcpWorker serverWorker = sGlobalContext.newWorker(new UcpWorkerParams().requestThreadSafety()); | ||
CompletableFuture<UcpConnectionRequest> incomingConn = new CompletableFuture<>(); | ||
UcpListenerParams listenerParams = new UcpListenerParams() | ||
.setConnectionHandler(new UcpListenerConnectionHandler() { | ||
@Override | ||
public void onConnectionRequest(UcpConnectionRequest connectionRequest) { | ||
incomingConn.complete(connectionRequest); | ||
// mConnectionRequests.offer(connectionRequest); | ||
} | ||
}); | ||
InetSocketAddress remoteAddr = new InetSocketAddress(localAddr, serverPort); | ||
UcpListener ucpListener = serverWorker.newListener( | ||
listenerParams.setSockAddr(remoteAddr)); | ||
Thread serverThread = new Thread(() -> { | ||
try { | ||
UcpConnectionRequest incomeConnReq = incomingConn.get(); | ||
if (incomeConnReq != null) { | ||
UcpEndpoint bootstrapEp = serverWorker.newEndpoint(new UcpEndpointParams() | ||
.setPeerErrorHandlingMode() | ||
.setConnectionRequest(incomeConnReq)); | ||
UcxConnection ucxConnection = UcxConnection.acceptIncomingConnection( | ||
bootstrapEp, serverWorker, incomeConnReq.getClientAddress()); | ||
System.out.println("Conn established from server:" + ucxConnection.toString()); | ||
} | ||
} catch (Exception e) { | ||
e.printStackTrace(); | ||
} | ||
}); | ||
serverThread.start(); | ||
|
||
UcpWorker clientWorker = sGlobalContext.newWorker(new UcpWorkerParams().requestThreadSafety()); | ||
// client init conn | ||
UcxConnection connToServer = UcxConnection.initNewConnection(remoteAddr, clientWorker); | ||
System.out.println("Conn established to server:" + connToServer.toString()); | ||
serverThread.join(); | ||
} | ||
|
||
} |