Skip to content

Commit

Permalink
Support redirecting input from client to daemon, #541
Browse files Browse the repository at this point in the history
The implementation currently switches on the redirection when the daemon actually starts reading the System.in stream using InputStream.read() or InputStream.available().
  • Loading branch information
gnodet committed Dec 13, 2022
1 parent 1249211 commit 8b884ed
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 3 deletions.
76 changes: 76 additions & 0 deletions common/src/main/java/org/mvndaemon/mvnd/common/Message.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ public abstract class Message {
public static final int EXECUTION_FAILURE = 24;
public static final int PRINT_OUT = 25;
public static final int PRINT_ERR = 26;
public static final int REQUEST_INPUT = 27;
public static final int INPUT_DATA = 28;

final int type;

Expand Down Expand Up @@ -115,6 +117,10 @@ public static Message read(DataInputStream input) throws IOException {
case PRINT_OUT:
case PRINT_ERR:
return StringMessage.read(type, input);
case REQUEST_INPUT:
return RequestInput.read(input);
case INPUT_DATA:
return InputData.read(input);
}
throw new IllegalStateException("Unexpected message type: " + type);
}
Expand All @@ -137,6 +143,8 @@ public static int getClassOrder(Message m) {
case DISPLAY:
case PRINT_OUT:
case PRINT_ERR:
case REQUEST_INPUT:
case INPUT_DATA:
return 2;
case PROJECT_STARTED:
return 3;
Expand Down Expand Up @@ -1025,6 +1033,66 @@ public static TransferEvent read(int type, DataInputStream input) throws IOExcep
}
}

public static class RequestInput extends Message {

private String projectId;

public static RequestInput read(DataInputStream input) throws IOException {
String projectId = readUTF(input);
return new RequestInput(projectId);
}

public RequestInput(String projectId) {
super(REQUEST_INPUT);
this.projectId = projectId;
}

public String getProjectId() {
return projectId;
}

@Override
public String toString() {
return "RequestInput{" + "projectId='" + projectId + '\'' + '}';
}

@Override
public void write(DataOutputStream output) throws IOException {
super.write(output);
writeUTF(output, projectId);
}
}

public static class InputData extends Message {

final String data;

public static Message read(DataInputStream input) throws IOException {
String data = readUTF(input);
return new InputData(data);
}

private InputData(String data) {
super(INPUT_DATA);
this.data = data;
}

public String getData() {
return data;
}

@Override
public String toString() {
return "InputResponse{" + "data='" + data + "\'" + '}';
}

@Override
public void write(DataOutputStream output) throws IOException {
super.write(output);
writeUTF(output, data);
}
}

public int getType() {
return type;
}
Expand All @@ -1037,6 +1105,14 @@ public static StringMessage display(String message) {
return new StringMessage(DISPLAY, message);
}

public static RequestInput requestInput(String projectId) {
return new RequestInput(projectId);
}

public static InputData inputResponse(String data) {
return new InputData(data);
}

public static StringMessage out(String message) {
return new StringMessage(PRINT_OUT, message);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import org.mvndaemon.mvnd.common.Message.ExecutionFailureEvent;
import org.mvndaemon.mvnd.common.Message.MojoStartedEvent;
import org.mvndaemon.mvnd.common.Message.ProjectEvent;
import org.mvndaemon.mvnd.common.Message.RequestInput;
import org.mvndaemon.mvnd.common.Message.StringMessage;
import org.mvndaemon.mvnd.common.Message.TransferEvent;
import org.mvndaemon.mvnd.common.OsUtils;
Expand Down Expand Up @@ -112,6 +113,8 @@ public class TerminalOutput implements ClientOutput {
private volatile Consumer<Message> daemonDispatch;
/** A sink for queuing messages to the main queue */
private volatile Consumer<Message> daemonReceive;
/** The project id which is trying to read the input stream */
private volatile String projectReadingInput;

/*
* The following non-final fields are read/written from the main thread only.
Expand Down Expand Up @@ -441,6 +444,15 @@ private boolean doAccept(Message entry) {
failures.add(efe);
break;
}
case Message.REQUEST_INPUT: {
RequestInput ri = (RequestInput) entry;
projectReadingInput = ri.getProjectId();
break;
}
case Message.INPUT_DATA: {
daemonDispatch.accept(entry);
break;
}
default:
throw new IllegalStateException("Unexpected message " + entry);
}
Expand Down Expand Up @@ -480,17 +492,30 @@ void readInputLoop() {
try {
while (!closing) {
if (readInput.readLock().tryLock(10, TimeUnit.MILLISECONDS)) {
try {
if (projectReadingInput != null) {
char[] buf = new char[256];
int idx = 0;
while (idx < buf.length) {
int c = terminal.reader().read(idx > 0 ? 1 : 10);
if (c < 0) {
break;
}
buf[idx++] = (char) c;
}
if (idx > 0) {
String data = String.valueOf(buf, 0, idx);
daemonReceive.accept(Message.inputResponse(data));
}
} else {
int c = terminal.reader().read(10);
if (c == -1) {
break;
}
if (c == KEY_PLUS || c == KEY_MINUS || c == KEY_CTRL_L || c == KEY_CTRL_M || c == KEY_CTRL_B) {
daemonReceive.accept(Message.keyboardInput((char) c));
}
} finally {
readInput.readLock().unlock();
}
readInput.readLock().unlock();
}
}
} catch (InterruptedException e) {
Expand Down
74 changes: 74 additions & 0 deletions daemon/src/main/java/org/mvndaemon/mvnd/daemon/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,23 @@
import static org.mvndaemon.mvnd.common.DaemonState.Stopped;

import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.lang.reflect.Field;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
Expand All @@ -47,6 +52,7 @@
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.maven.cli.DaemonMavenCli;
Expand Down Expand Up @@ -482,6 +488,8 @@ private void handle(DaemonConnection connection, BuildRequest buildRequest) {
final BlockingQueue<Message> sendQueue = new PriorityBlockingQueue<>(64, Message.getMessageComparator());
final BlockingQueue<Message> recvQueue = new LinkedBlockingDeque<>();
final BuildEventListener buildEventListener = new ClientDispatcher(sendQueue);
final DaemonInputStream daemonInputStream =
new DaemonInputStream(projectId -> sendQueue.add(Message.requestInput(projectId)));
try (ProjectBuildLogAppender logAppender = new ProjectBuildLogAppender(buildEventListener)) {

LOGGER.info("Executing request");
Expand Down Expand Up @@ -529,6 +537,8 @@ private void handle(DaemonConnection connection, BuildRequest buildRequest) {
if (message == Message.BareMessage.CANCEL_BUILD_SINGLETON) {
updateState(Canceled);
return;
} else if (message instanceof Message.InputData) {
daemonInputStream.addInputData(((Message.InputData) message).getData());
} else {
synchronized (recvQueue) {
recvQueue.put(message);
Expand Down Expand Up @@ -581,6 +591,7 @@ public <T extends Message> T request(Message request, Class<T> responseType, Pre
}
}
});
System.setIn(daemonInputStream);
System.setOut(new LoggingOutputStream(s -> sendQueue.add(Message.out(s))).printStream());
System.setErr(new LoggingOutputStream(s -> sendQueue.add(Message.err(s))).printStream());
int exitCode = cli.main(
Expand Down Expand Up @@ -650,4 +661,67 @@ public long getLastBusy() {
public String toString() {
return info.toString();
}

static class DaemonInputStream extends InputStream {
private final Consumer<String> startReadingFromProject;
private final LinkedList<byte[]> datas = new LinkedList<>();
private int pos = -1;
private String projectReading = null;

DaemonInputStream(Consumer<String> startReadingFromProject) {
this.startReadingFromProject = startReadingFromProject;
}

@Override
public int available() throws IOException {
synchronized (datas) {
String projectId = ProjectBuildLogAppender.getProjectId();
if (!Objects.equals(projectId, projectReading)) {
projectReading = projectId;
startReadingFromProject.accept(projectId);
}
return datas.stream().mapToInt(a -> a.length).sum() - Math.max(pos, 0);
}
}

@Override
public int read() throws IOException {
synchronized (datas) {
String projectId = ProjectBuildLogAppender.getProjectId();
if (!Objects.equals(projectId, projectReading)) {
projectReading = projectId;
startReadingFromProject.accept(projectId);
// TODO: start a 10ms timer to turn data off
}
for (; ; ) {
if (datas.isEmpty()) {
try {
datas.wait();
} catch (InterruptedException e) {
throw new InterruptedIOException("Interrupted");
}
pos = -1;
continue;
}
byte[] curData = datas.getFirst();
if (pos >= curData.length) {
datas.removeFirst();
pos = -1;
continue;
}
if (pos < 0) {
pos = 0;
}
return curData[pos++];
}
}
}

public void addInputData(String data) {
synchronized (datas) {
datas.add(data.getBytes(Charset.forName(System.getProperty("file.encoding"))));
datas.notifyAll();
}
}
}
}

0 comments on commit 8b884ed

Please sign in to comment.