Skip to content

Commit

Permalink
ARROW-8304: [Flight][Python] Fix client example with TLS
Browse files Browse the repository at this point in the history
The `get` command wouldn't use the adequate TLS root certs when fetching a Flight from its endpoints.

Also fix style in the Python examples and configure `archery lint` to check them.

Closes #6808 from pitrou/ARROW-8304-py-flight-client-tls

Lead-authored-by: Antoine Pitrou <antoine@python.org>
Co-authored-by: Ravindra Wagh <ravindra.wagh@cambridgesemantics.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
pitrou and ravindra-wagh committed Apr 2, 2020
1 parent 43e6172 commit 5ab4930
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 20 deletions.
5 changes: 3 additions & 2 deletions dev/archery/archery/utils/lint.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,9 @@ def python_linter(src):
return

setup_py = os.path.join(src.python, "setup.py")
yield LintResult.from_cmd(flake8(setup_py, src.pyarrow, src.dev,
check=False))
yield LintResult.from_cmd(flake8(setup_py, src.pyarrow,
os.path.join(src.python, "examples"),
src.dev, check=False))
config = os.path.join(src.python, ".flake8.cython")
yield LintResult.from_cmd(flake8("--config=" + config, src.pyarrow,
check=False))
Expand Down
23 changes: 12 additions & 11 deletions python/examples/flight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pyarrow.csv as csv


def list_flights(args, client):
def list_flights(args, client, connection_args={}):
print('Flights\n=======')
for flight in client.list_flights():
descriptor = flight.descriptor
Expand Down Expand Up @@ -60,7 +60,7 @@ def list_flights(args, client):
print('---')


def do_action(args, client):
def do_action(args, client, connection_args={}):
try:
buf = pyarrow.allocate_buffer(0)
action = pyarrow.flight.Action(args.action_type, buf)
Expand All @@ -71,19 +71,19 @@ def do_action(args, client):
print("Error calling action:", e)


def push_data(args, client):
def push_data(args, client, connection_args={}):
print('File Name:', args.file)
my_table = csv.read_csv(args.file)
print ('Table rows=', str(len(my_table)))
print('Table rows=', str(len(my_table)))
df = my_table.to_pandas()
print(df.head())
writer, _ = client.do_put(
pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
writer.write_table(my_table)
writer.close()


def get_flight(args, client):
def get_flight(args, client, connection_args={}):
if args.path:
descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
else:
Expand All @@ -94,7 +94,8 @@ def get_flight(args, client):
print('Ticket:', endpoint.ticket)
for location in endpoint.locations:
print(location)
get_client = pyarrow.flight.FlightClient(location)
get_client = pyarrow.flight.FlightClient(location,
**connection_args)
reader = get_client.do_get(endpoint.ticket)
df = reader.read_pandas()
print(df)
Expand Down Expand Up @@ -129,8 +130,8 @@ def main():
cmd_put.set_defaults(action='put')
_add_common_arguments(cmd_put)
cmd_put.add_argument('file', type=str,
help="CSV file to upload.")
help="CSV file to upload.")

cmd_get = subcommands.add_parser('get')
cmd_get.set_defaults(action='get')
_add_common_arguments(cmd_get)
Expand Down Expand Up @@ -161,7 +162,7 @@ def main():
with open(args.tls_roots, "rb") as root_certs:
connection_args["tls_root_certs"] = root_certs.read()
client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
**connection_args)
**connection_args)
while True:
try:
action = pyarrow.flight.Action("healthcheck", b"")
Expand All @@ -171,7 +172,7 @@ def main():
except pyarrow.ArrowIOError as e:
if "Deadline" in str(e):
print("Server is not ready, waiting...")
commands[args.action](args, client)
commands[args.action](args, client, connection_args)


if __name__ == '__main__':
Expand Down
19 changes: 12 additions & 7 deletions python/examples/flight/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@


class FlightServer(pyarrow.flight.FlightServerBase):
def __init__(self, host="localhost", location=None, tls_certificates=None, auth_handler=None):
super(FlightServer, self).__init__(location, auth_handler, tls_certificates)
def __init__(self, host="localhost", location=None,
tls_certificates=None, auth_handler=None):
super(FlightServer, self).__init__(
location, auth_handler, tls_certificates)
self.flights = {}
self.host = host
self.tls_certificates = tls_certificates
Expand All @@ -40,13 +42,16 @@ def descriptor_to_key(self, descriptor):

def _make_flight_info(self, key, descriptor, table):
if self.tls_certificates:
location = pyarrow.flight.Location.for_grpc_tls(self.host, self.port)
location = pyarrow.flight.Location.for_grpc_tls(
self.host, self.port)
else:
location = pyarrow.flight.Location.for_grpc_tcp(self.host, self.port)
endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]),]
location = pyarrow.flight.Location.for_grpc_tcp(
self.host, self.port)
endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]), ]

mock_sink = pyarrow.MockOutputStream()
stream_writer = pyarrow.RecordBatchStreamWriter(mock_sink, table.schema)
stream_writer = pyarrow.RecordBatchStreamWriter(
mock_sink, table.schema)
stream_writer.write_table(table)
stream_writer.close()
data_size = mock_sink.size()
Expand Down Expand Up @@ -139,6 +144,6 @@ def main():
print("Serving on", location)
server.serve()


if __name__ == '__main__':
main()

0 comments on commit 5ab4930

Please sign in to comment.