Skip to content

Commit

Permalink
THRIFT-5762 Expose service result objects in Java
Browse files Browse the repository at this point in the history
Some libraries want to bypass the TServer class and handle the full
service startup manually. For example when building a service that hosts
multiple thrift services where the IFace type is unknown when handling a
request.

For example when you host multiple services on top of netty and through
an HTTP path you want to route to the correct thrift service. In this
situation you treat can treat an IFace as an Object and use the
`getProcessMapView()` method to parse a byte array into a thrift message
and pass let the `AsyncProcessFunction` handle the invocation.

To return a correct thrift response it's necessary to write the
`{service_name}_result` that contains the response args.
While it is possible to get an incoming args object from the
(Async)ProcessFunction its unfortunately not possible to get
a result object without using reflection.

This PR extends the (Async)ProcessFunction by adding a
`getEmptyResultInstance` method that returns a new generic `A` (answer)
that matches the `{service_name}_result` object.

This allows thrift users to write the following processing code:
```java
<I> void handleRequest(
        TProtocol in,
        TProtocol out,
        TBaseAsyncProcessor<I> processor,
        I asyncIface
) throws TException {
    final Map<String, AsyncProcessFunction<Object, TBase<?, ?>, TBase<?, ?>, TBase<?, ?>>> processMap = (Map) processor.getProcessMapView();
    final var msg = in.readMessageBegin();
    final var fn = processMap.get(msg.name);

    final var args = fn.getEmptyArgsInstance();
    args.read(in);
    in.readMessageEnd();

    if (fn.isOneway()) {
        return;
    }

    fn.start(asyncIface, args, new AsyncMethodCallback<>() {
        @OverRide
        public void onComplete(TBase<?, ?> o) {
            try {
                out.writeMessageBegin(new TMessage(fn.getMethodName(), TMessageType.REPLY, msg.getSeqid()));
                final var response_result = fn.getEmptyResultInstance();
                final var success_field = response_result.fieldForId(SUCCESS_ID);
                ((TBase) response_result).setFieldValue(success_field, o);
                response_result.write(out);
                out.writeMessageEnd();
                out.getTransport().flush();
            } catch (TException e) {
                throw new RuntimeException(e);
            }
        }

        @OverRide
        public void onError(Exception e) {
            try {
                out.writeMessageBegin(new TMessage(fn.getMethodName(), TMessageType.EXCEPTION, msg.getSeqid()));
                ((TApplicationException) e).write(out);
                out.writeMessageEnd();
                out.getTransport().flush();
            } catch (TException ex) {
                throw new RuntimeException(ex);
            }
        }
    });
}
```
The above example code doesn't need any reference to the original types
and can dynamically create the correct objects to return a correct
response.
  • Loading branch information
thomasbruggink authored and Jens-G committed Mar 5, 2024
1 parent da2ef34 commit b6cf049
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 92 deletions.
51 changes: 36 additions & 15 deletions compiler/cpp/src/thrift/generate/t_java_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3635,22 +3635,23 @@ void t_java_generator::generate_service_server(t_service* tservice) {
indent(f_service_) << "public Processor(I iface) {" << endl;
indent(f_service_) << " super(iface, getProcessMap(new java.util.HashMap<java.lang.String, "
"org.apache.thrift.ProcessFunction<I, ? extends "
"org.apache.thrift.TBase>>()));"
"org.apache.thrift.TBase, ? extends org.apache.thrift.TBase>>()));"
<< endl;
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << "protected Processor(I iface, java.util.Map<java.lang.String, "
"org.apache.thrift.ProcessFunction<I, ? extends org.apache.thrift.TBase>> "
"processMap) {"
"org.apache.thrift.ProcessFunction<I, ? extends org.apache.thrift.TBase, ? "
"extends org.apache.thrift.TBase>> processMap) {"
<< endl;
indent(f_service_) << " super(iface, getProcessMap(processMap));" << endl;
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << "private static <I extends Iface> java.util.Map<java.lang.String, "
"org.apache.thrift.ProcessFunction<I, ? extends org.apache.thrift.TBase>> "
indent(f_service_) << "private static <I extends Iface> java.util.Map<java.lang.String, "
"org.apache.thrift.ProcessFunction<I, ? extends org.apache.thrift.TBase, "
"? extends org.apache.thrift.TBase>> "
"getProcessMap(java.util.Map<java.lang.String, "
"org.apache.thrift.ProcessFunction<I, ? extends "
" org.apache.thrift.TBase>> processMap) {"
" org.apache.thrift.TBase, ? extends org.apache.thrift.TBase>> processMap) {"
<< endl;
indent_up();
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
Expand Down Expand Up @@ -3702,23 +3703,23 @@ void t_java_generator::generate_service_async_server(t_service* tservice) {
indent(f_service_) << "public AsyncProcessor(I iface) {" << endl;
indent(f_service_) << " super(iface, getProcessMap(new java.util.HashMap<java.lang.String, "
"org.apache.thrift.AsyncProcessFunction<I, ? extends "
"org.apache.thrift.TBase, ?>>()));"
"org.apache.thrift.TBase, ?, ? extends org.apache.thrift.TBase>>()));"
<< endl;
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << "protected AsyncProcessor(I iface, java.util.Map<java.lang.String, "
"org.apache.thrift.AsyncProcessFunction<I, ? extends "
"org.apache.thrift.TBase, ?>> processMap) {"
"org.apache.thrift.TBase, ?, ? extends org.apache.thrift.TBase>> processMap) {"
<< endl;
indent(f_service_) << " super(iface, getProcessMap(processMap));" << endl;
indent(f_service_) << "}" << endl << endl;

indent(f_service_)
<< "private static <I extends AsyncIface> java.util.Map<java.lang.String, "
"org.apache.thrift.AsyncProcessFunction<I, ? extends "
"org.apache.thrift.TBase,?>> getProcessMap(java.util.Map<java.lang.String, "
"org.apache.thrift.TBase, ?, ? extends org.apache.thrift.TBase>> getProcessMap(java.util.Map<java.lang.String, "
"org.apache.thrift.AsyncProcessFunction<I, ? extends "
"org.apache.thrift.TBase, ?>> processMap) {"
"org.apache.thrift.TBase, ?, ? extends org.apache.thrift.TBase>> processMap) {"
<< endl;
indent_up();
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
Expand Down Expand Up @@ -3783,13 +3784,23 @@ void t_java_generator::generate_process_async_function(t_service* tservice, t_fu
// Open class
indent(f_service_) << "public static class " << make_valid_java_identifier(tfunction->get_name())
<< "<I extends AsyncIface> extends org.apache.thrift.AsyncProcessFunction<I, "
<< argsname << ", " << resulttype << "> {" << endl;
<< argsname << ", " << resulttype << ", " << resultname << "> {" << endl;
indent_up();

indent(f_service_) << "public " << make_valid_java_identifier(tfunction->get_name()) << "() {" << endl;
indent(f_service_) << " super(\"" << tfunction->get_name() << "\");" << endl;
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << java_override_annotation() << endl;
indent(f_service_) << "public " << resultname << " getEmptyResultInstance() {" << endl;
if (tfunction->is_oneway()) {
indent(f_service_) << " return null;" << endl;
}
else {
indent(f_service_) << " return new " << resultname << "();" << endl;
}
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << java_override_annotation() << endl;
indent(f_service_) << "public " << argsname << " getEmptyArgsInstance() {" << endl;
indent(f_service_) << " return new " << argsname << "();" << endl;
Expand Down Expand Up @@ -3931,7 +3942,7 @@ void t_java_generator::generate_process_async_function(t_service* tservice, t_fu
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << java_override_annotation() << endl;
indent(f_service_) << "protected boolean isOneway() {" << endl;
indent(f_service_) << "public boolean isOneway() {" << endl;
indent(f_service_) << " return " << ((tfunction->is_oneway()) ? "true" : "false") << ";" << endl;
indent(f_service_) << "}" << endl << endl;

Expand Down Expand Up @@ -3989,7 +4000,7 @@ void t_java_generator::generate_process_function(t_service* tservice, t_function
// Open class
indent(f_service_) << "public static class " << make_valid_java_identifier(tfunction->get_name())
<< "<I extends Iface> extends org.apache.thrift.ProcessFunction<I, "
<< argsname << "> {" << endl;
<< argsname << ", " << resultname << "> {" << endl;
indent_up();

indent(f_service_) << "public " << make_valid_java_identifier(tfunction->get_name()) << "() {" << endl;
Expand All @@ -4002,7 +4013,7 @@ void t_java_generator::generate_process_function(t_service* tservice, t_function
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << java_override_annotation() << endl;
indent(f_service_) << "protected boolean isOneway() {" << endl;
indent(f_service_) << "public boolean isOneway() {" << endl;
indent(f_service_) << " return " << ((tfunction->is_oneway()) ? "true" : "false") << ";" << endl;
indent(f_service_) << "}" << endl << endl;

Expand All @@ -4012,12 +4023,22 @@ void t_java_generator::generate_process_function(t_service* tservice, t_function
<< endl;
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << java_override_annotation() << endl;
indent(f_service_) << "public " << resultname << " getEmptyResultInstance() {" << endl;
if (tfunction->is_oneway()) {
indent(f_service_) << " return null;" << endl;
}
else {
indent(f_service_) << " return new " << resultname << "();" << endl;
}
indent(f_service_) << "}" << endl << endl;

indent(f_service_) << java_override_annotation() << endl;
indent(f_service_) << "public " << resultname << " getResult(I iface, " << argsname
<< " args) throws org.apache.thrift.TException {" << endl;
indent_up();
if (!tfunction->is_oneway()) {
indent(f_service_) << resultname << " result = new " << resultname << "();" << endl;
indent(f_service_) << resultname << " result = getEmptyResultInstance();" << endl;
}

t_struct* xs = tfunction->get_xceptions();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.server.AbstractNonblockingServer;

public abstract class AsyncProcessFunction<I, T extends TBase, R> {
public abstract class AsyncProcessFunction<I, T extends TBase, R, A extends TBase> {
final String methodName;

public AsyncProcessFunction(String methodName) {
this.methodName = methodName;
}

protected abstract boolean isOneway();
public abstract boolean isOneway();

public abstract void start(I iface, T args, AsyncMethodCallback<R> resultHandler)
throws TException;

public abstract T getEmptyArgsInstance();

public abstract A getEmptyResultInstance();

public abstract AsyncMethodCallback<R> getResultHandler(
final AbstractNonblockingServer.AsyncFrameBuffer fb, int seqid);

Expand Down
141 changes: 73 additions & 68 deletions lib/java/src/main/java/org/apache/thrift/ProcessFunction.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,86 +8,91 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class ProcessFunction<I, T extends TBase> {
private final String methodName;
public abstract class ProcessFunction<I, T extends TBase, A extends TBase> {
private final String methodName;

private static final Logger LOGGER = LoggerFactory.getLogger(ProcessFunction.class.getName());
private static final Logger LOGGER = LoggerFactory.getLogger(ProcessFunction.class.getName());

public ProcessFunction(String methodName) {
this.methodName = methodName;
}

public final void process(int seqid, TProtocol iprot, TProtocol oprot, I iface)
throws TException {
T args = getEmptyArgsInstance();
try {
args.read(iprot);
} catch (TProtocolException e) {
iprot.readMessageEnd();
TApplicationException x =
new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage());
oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid));
x.write(oprot);
oprot.writeMessageEnd();
oprot.getTransport().flush();
return;
public ProcessFunction(String methodName) {
this.methodName = methodName;
}
iprot.readMessageEnd();
TSerializable result = null;
byte msgType = TMessageType.REPLY;

try {
result = getResult(iface, args);
} catch (TTransportException ex) {
LOGGER.error("Transport error while processing " + getMethodName(), ex);
throw ex;
} catch (TApplicationException ex) {
LOGGER.error("Internal application error processing " + getMethodName(), ex);
result = ex;
msgType = TMessageType.EXCEPTION;
} catch (Exception ex) {
LOGGER.error("Internal error processing " + getMethodName(), ex);
if (rethrowUnhandledExceptions()) throw new RuntimeException(ex.getMessage(), ex);
if (!isOneway()) {
result =
new TApplicationException(
TApplicationException.INTERNAL_ERROR,
"Internal error processing " + getMethodName());
msgType = TMessageType.EXCEPTION;
}
public final void process(int seqid, TProtocol iprot, TProtocol oprot, I iface)
throws TException {
T args = getEmptyArgsInstance();
try {
args.read(iprot);
} catch (TProtocolException e) {
iprot.readMessageEnd();
TApplicationException x =
new TApplicationException(TApplicationException.PROTOCOL_ERROR, e.getMessage());
oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid));
x.write(oprot);
oprot.writeMessageEnd();
oprot.getTransport().flush();
return;
}
iprot.readMessageEnd();
TSerializable result = null;
byte msgType = TMessageType.REPLY;

try {
result = getResult(iface, args);
} catch (TTransportException ex) {
LOGGER.error("Transport error while processing " + getMethodName(), ex);
throw ex;
} catch (TApplicationException ex) {
LOGGER.error("Internal application error processing " + getMethodName(), ex);
result = ex;
msgType = TMessageType.EXCEPTION;
} catch (Exception ex) {
LOGGER.error("Internal error processing " + getMethodName(), ex);
if (rethrowUnhandledExceptions()) throw new RuntimeException(ex.getMessage(), ex);
if (!isOneway()) {
result =
new TApplicationException(
TApplicationException.INTERNAL_ERROR,
"Internal error processing " + getMethodName());
msgType = TMessageType.EXCEPTION;
}
}

if (!isOneway()) {
oprot.writeMessageBegin(new TMessage(getMethodName(), msgType, seqid));
result.write(oprot);
oprot.writeMessageEnd();
oprot.getTransport().flush();
}
}

if (!isOneway()) {
oprot.writeMessageBegin(new TMessage(getMethodName(), msgType, seqid));
result.write(oprot);
oprot.writeMessageEnd();
oprot.getTransport().flush();
private void handleException(int seqid, TProtocol oprot) throws TException {
if (!isOneway()) {
TApplicationException x =
new TApplicationException(
TApplicationException.INTERNAL_ERROR, "Internal error processing " + getMethodName());
oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid));
x.write(oprot);
oprot.writeMessageEnd();
oprot.getTransport().flush();
}
}
}

private void handleException(int seqid, TProtocol oprot) throws TException {
if (!isOneway()) {
TApplicationException x =
new TApplicationException(
TApplicationException.INTERNAL_ERROR, "Internal error processing " + getMethodName());
oprot.writeMessageBegin(new TMessage(getMethodName(), TMessageType.EXCEPTION, seqid));
x.write(oprot);
oprot.writeMessageEnd();
oprot.getTransport().flush();
protected boolean rethrowUnhandledExceptions() {
return false;
}
}

protected boolean rethrowUnhandledExceptions() {
return false;
}
public abstract boolean isOneway();

protected abstract boolean isOneway();
public abstract TBase<?, ?> getResult(I iface, T args) throws TException;

public abstract TBase getResult(I iface, T args) throws TException;
public abstract T getEmptyArgsInstance();

public abstract T getEmptyArgsInstance();
/**
* Returns null when this is a oneWay function.
*/
public abstract A getEmptyResultInstance();

public String getMethodName() {
return methodName;
}
public String getMethodName() {
return methodName;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ public class TBaseAsyncProcessor<I> implements TAsyncProcessor, TProcessor {
protected final Logger LOGGER = LoggerFactory.getLogger(getClass().getName());

final I iface;
final Map<String, AsyncProcessFunction<I, ? extends TBase, ?>> processMap;
final Map<String, AsyncProcessFunction<I, ? extends TBase, ?, ? extends TBase>> processMap;

public TBaseAsyncProcessor(
I iface, Map<String, AsyncProcessFunction<I, ? extends TBase, ?>> processMap) {
I iface, Map<String, AsyncProcessFunction<I, ? extends TBase, ?, ? extends TBase>> processMap) {
this.iface = iface;
this.processMap = processMap;
}

public Map<String, AsyncProcessFunction<I, ? extends TBase, ?>> getProcessMapView() {
public Map<String, AsyncProcessFunction<I, ? extends TBase, ?, ? extends TBase>> getProcessMapView() {
return Collections.unmodifiableMap(processMap);
}

Expand Down
6 changes: 3 additions & 3 deletions lib/java/src/main/java/org/apache/thrift/TBaseProcessor.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@

public abstract class TBaseProcessor<I> implements TProcessor {
private final I iface;
private final Map<String, ProcessFunction<I, ? extends TBase>> processMap;
private final Map<String, ProcessFunction<I, ? extends TBase, ? extends TBase>> processMap;

protected TBaseProcessor(
I iface, Map<String, ProcessFunction<I, ? extends TBase>> processFunctionMap) {
I iface, Map<String, ProcessFunction<I, ? extends TBase, ? extends TBase>> processFunctionMap) {
this.iface = iface;
this.processMap = processFunctionMap;
}

public Map<String, ProcessFunction<I, ? extends TBase>> getProcessMapView() {
public Map<String, ProcessFunction<I, ? extends TBase, ? extends TBase>> getProcessMapView() {
return Collections.unmodifiableMap(processMap);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ private void handleIO() {
} else if (selected.isWritable()) {
saslHandler.handleWrite();
} else {
LOGGER.error("Invalid intrest op " + selected.interestOps());
LOGGER.error("Invalid interest op " + selected.interestOps());
closeChannel(selected);
continue;
}
Expand Down

0 comments on commit b6cf049

Please sign in to comment.