Skip to content

Commit

Permalink
ARROW-5978: [FlightRPC] [Java] Properly release buffers in Flight integration client
Browse files Browse the repository at this point in the history

Fixes a bug where dictionaries were not properly released when cleaning up a Flight stream.

Travis build: https://travis-ci.com/lihalite/arrow/builds/119807464

Recreated from #4905

Closes #4913 from lihalite/flight-leak and squashes the following commits:

ac8ba8d <David Li> Improve documentation/tests for FlightStream dictionary provider
bca02a7 <David Li> Add test case for freeing dictionaries in Flight
a096a80 <David Li> Properly release buffers in Flight integration client

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Micah Kornfield <emkornfield@gmail.com>
  • Loading branch information
lidavidm authored and emkornfield committed Aug 20, 2019
1 parent 8f690e3 commit a40d6b6
Show file tree
Hide file tree
Showing 8 changed files with 271 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ public ArrowDictionaryBatch asDictionaryBatch() throws IOException {
Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf.");
Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH);
ArrowBuf underlying = bufs.get(0);
// Retain a reference to keep the batch alive when the message is closed
underlying.getReferenceManager().retain();
return MessageSerializer.deserializeDictionaryBatch(message, underlying);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import java.util.List;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
Expand Down Expand Up @@ -74,4 +76,14 @@ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDe
}
return schema;
}

/**
 * Closes the vector of every dictionary referenced by the fields of {@code schema}.
 *
 * @param schema the schema whose fields are scanned for dictionary references
 * @param provider the provider used to look up each referenced dictionary
 * @throws Exception propagated from closing the dictionary vectors
 */
static void closeDictionaries(final Schema schema, final DictionaryProvider provider) throws Exception {
  // toMessageFormat is invoked only for its effect on the set (its return value is ignored):
  // the set is read afterwards as the IDs of the dictionaries that must be closed.
  final Set<Long> referencedIds = new HashSet<>();
  schema.getFields().forEach(f -> DictionaryUtility.toMessageFormat(f, provider, referencedIds));

  final List<AutoCloseable> vectors = referencedIds.stream()
      .map(provider::lookup)
      .map(dictionary -> (AutoCloseable) dictionary.getVector())
      .collect(Collectors.toList());
  AutoCloseables.close(vectors);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.stream.Collectors;

import org.apache.arrow.flight.ArrowMessage.HeaderType;
import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
Expand All @@ -40,7 +41,6 @@
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;

import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.SettableFuture;
Expand Down Expand Up @@ -95,10 +95,42 @@ public Schema getSchema() {
return schema;
}

/**
 * Returns the dictionary provider backing this stream.
 *
 * <p>The caller does NOT receive an owned reference to the underlying dictionaries, and the dictionaries may be
 * updated as the stream is read. Intended for stream-style processing, where application code does not hold on to
 * values once the stream has been closed.
 *
 * @return the dictionary provider for this stream
 * @throws IllegalStateException if {@link #takeDictionaryOwnership()} was called
 * @see #takeDictionaryOwnership()
 */
public DictionaryProvider getDictionaryProvider() {
  final DictionaryProvider current = dictionaries;
  if (current != null) {
    return current;
  }
  // A null field means the provider was already handed over via takeDictionaryOwnership()
  throw new IllegalStateException("Dictionary ownership was claimed by the application.");
}

/**
 * Hands the dictionaries of this stream over to the caller. Should be called once the stream has been fully read,
 * but before it is closed.
 *
 * <p>After this call the client is responsible for closing the dictionaries in the returned provider. May only be
 * called once.
 *
 * @return The dictionary provider for the stream.
 * @throws IllegalStateException if called more than once.
 */
public DictionaryProvider takeDictionaryOwnership() {
  final DictionaryProvider owned = dictionaries;
  if (owned == null) {
    throw new IllegalStateException("Dictionary ownership was claimed by the application.");
  }
  // Clear the field so that close() finds no dictionaries left to release
  dictionaries = null;
  return owned;
}

/**
 * Get the descriptor held by this stream.
 *
 * @return the value of the {@code descriptor} field
 */
public FlightDescriptor getDescriptor() {
return descriptor;
}
Expand All @@ -117,8 +149,13 @@ public void close() throws Exception {
.map(t -> ((AutoCloseable) t))
.collect(Collectors.toList());

final List<FieldVector> dictionaryVectors =
dictionaries == null ? Collections.emptyList() : dictionaries.getDictionaryIds().stream()
.map(id -> dictionaries.lookup(id).getVector()).collect(Collectors.toList());

// Must check for null since ImmutableList doesn't accept nulls
AutoCloseables.close(Iterables.concat(closeables,
dictionaryVectors,
applicationMetadata != null ? ImmutableList.of(root.get(), applicationMetadata)
: ImmutableList.of(root.get())));
}
Expand Down Expand Up @@ -168,6 +205,9 @@ public boolean next() {
} else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) {
try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) {
final long id = arb.getDictionaryId();
if (dictionaries == null) {
throw new IllegalStateException("Dictionary ownership was claimed by the application.");
}
final Dictionary dictionary = dictionaries.lookup(id);
if (dictionary == null) {
throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id);
Expand Down Expand Up @@ -195,8 +235,10 @@ public boolean next() {
/**
 * Get the root that holds the data of this stream, obtained from the {@code root} future.
 *
 * @return the stream's {@link VectorSchemaRoot}
 * @throws RuntimeException built from {@code CallStatus.INTERNAL} if the calling thread is interrupted while
 *     waiting on the future, or converted by {@code StatusUtils.fromThrowable} from the cause of a failed future
 */
public VectorSchemaRoot getRoot() {
  try {
    return root.get();
  } catch (InterruptedException e) {
    // Restore the interrupt status so code further up the stack can still observe the interruption
    Thread.currentThread().interrupt();
    throw CallStatus.INTERNAL.withCause(e).toRuntimeException();
  } catch (ExecutionException e) {
    // Report the underlying failure rather than the ExecutionException wrapper
    throw StatusUtils.fromThrowable(e.getCause());
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.apache.arrow.util.AutoCloseables;

/**
* An Example Flight Server that provides access to the InMemoryStore.
* An Example Flight Server that provides access to the InMemoryStore. Used for integration testing.
*/
public class ExampleFlightServer implements AutoCloseable {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.arrow.flight.example;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Collectors;

Expand All @@ -31,6 +33,7 @@
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Iterables;
Expand Down Expand Up @@ -106,6 +109,13 @@ public FlightInfo getFlightInfo(final Location l) {

/**
 * Closes the streams of this holder, the vectors of the dictionaries referenced by its schema, and finally its
 * allocator, in a single pass.
 *
 * @throws Exception propagated from closing any of the resources
 */
@Override
public void close() throws Exception {
  // toMessageFormat is invoked only for its effect on the set (its return value is ignored):
  // the set is read afterwards as the IDs of the dictionaries that must be closed.
  final Set<Long> dictionaryIds = new HashSet<>();
  schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, dictionaryProvider, dictionaryIds));

  final Iterable<AutoCloseable> dictionaries = dictionaryIds.stream()
      .map(id -> (AutoCloseable) dictionaryProvider.lookup(id).getVector())::iterator;

  // Everything is closed exactly once; the allocator goes last so that any buffers that may have been
  // allocated from it are released before it is closed.
  AutoCloseables.close(Iterables.concat(streams, dictionaries, ImmutableList.of(allocator)));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import org.apache.arrow.vector.VectorUnloader;

/**
* A FlightProducer that hosts an in memory store of Arrow buffers.
* A FlightProducer that hosts an in memory store of Arrow buffers. Used for integration testing.
*/
public class InMemoryStore implements FlightProducer, AutoCloseable {

Expand Down Expand Up @@ -80,8 +80,7 @@ public Stream getStream(Ticket t) {
}

@Override
public void listFlights(CallContext context, Criteria criteria,
StreamListener<FlightInfo> listener) {
public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
try {
for (FlightHolder h : holders.values()) {
listener.onNext(h.getFlightInfo(location));
Expand All @@ -93,8 +92,7 @@ public void listFlights(CallContext context, Criteria criteria,
}

@Override
public FlightInfo getFlightInfo(CallContext context,
FlightDescriptor descriptor) {
public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
FlightHolder h = holders.get(descriptor);
if (h == null) {
throw new IllegalStateException("Unknown descriptor.");
Expand All @@ -121,6 +119,8 @@ public Runnable acceptPut(CallContext context,
ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata()));
creator.add(unloader.getRecordBatch());
}
// Closing the stream will release the dictionaries
flightStream.takeDictionaryOwnership();
creator.complete();
success = true;
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.JsonFileReader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.util.Validator;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.CommandLineParser;
Expand Down Expand Up @@ -84,12 +85,19 @@ private void run(String[] args) throws ParseException, IOException {
final String host = cmd.getOptionValue("host", "localhost");
final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));

final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final Location defaultLocation = Location.forGrpcInsecure(host, port);
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build();
try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) {

final String inputPath = cmd.getOptionValue("j");
final String inputPath = cmd.getOptionValue("j");
testStream(allocator, defaultLocation, client, inputPath);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

private static void testStream(BufferAllocator allocator, Location server, FlightClient client, String inputPath)
throws IOException {
// 1. Read data from JSON and upload to server.
FlightDescriptor descriptor = FlightDescriptor.path(inputPath);
VectorSchemaRoot jsonRoot;
Expand Down Expand Up @@ -121,7 +129,9 @@ public void onNext(PutResult val) {
metadata.writeBytes(rawMetadata);
// Transfers ownership of the buffer, so do not release it ourselves
stream.putNext(metadata);
jsonLoader.load(unloader.getRecordBatch());
try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
jsonLoader.load(arb);
}
root.clear();
counter++;
}
Expand All @@ -141,25 +151,29 @@ public void onNext(PutResult val) {
// 3. Download the data from the server.
List<Location> locations = endpoint.getLocations();
if (locations.size() == 0) {
locations = Collections.singletonList(defaultLocation);
locations = Collections.singletonList(server);
}
for (Location location : locations) {
System.out.println("Verifying location " + location.getUri());
FlightClient readClient = FlightClient.builder(allocator, location).build();
FlightStream stream = readClient.getStream(endpoint.getTicket());
VectorSchemaRoot downloadedRoot;
try (VectorSchemaRoot root = stream.getRoot()) {
downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
try (FlightClient readClient = FlightClient.builder(allocator, location).build();
FlightStream stream = readClient.getStream(endpoint.getTicket());
VectorSchemaRoot root = stream.getRoot();
VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator)) {
VectorLoader loader = new VectorLoader(downloadedRoot);
VectorUnloader unloader = new VectorUnloader(root);
while (stream.next()) {
loader.load(unloader.getRecordBatch());
try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
loader.load(arb);
}
}
}

// 4. Validate the data.
Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
// 4. Validate the data.
Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
jsonRoot.close();
}
}

0 comments on commit a40d6b6

Please sign in to comment.