Skip to content

Commit

Permalink
ARROW-7928: [Python] Update Python flight server and client examples …
Browse files Browse the repository at this point in the history
…for latest API

Patch by @ravindra-wagh
(had to redo a new branch and GitHub PR because of git UI crappiness, sorry)

Closes #6479 from pitrou/ARROW-7928-update-py-flight-examples and squashes the following commits:

d54dd7d <Antoine Pitrou> Some nits
269590f <ravindra-wagh> Updated client to upload a csv file to the server
322dace <ravindra-wagh> Update server.py

Lead-authored-by: ravindra-wagh <57516622+ravindra-wagh@users.noreply.github.com>
Co-authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
2 people authored and lidavidm committed Feb 25, 2020
1 parent 23d74c0 commit fb8868d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 21 deletions.
30 changes: 24 additions & 6 deletions python/examples/flight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pyarrow
import pyarrow.flight
import pyarrow.csv as csv


def list_flights(args, client):
Expand All @@ -48,10 +49,8 @@ def list_flights(args, client):
print("Unknown")

print("Number of endpoints:", len(flight.endpoints))

if args.list:
print(flight.schema)

print("Schema:")
print(flight.schema)
print('---')

print('\nActions\n=======')
Expand All @@ -72,6 +71,18 @@ def do_action(args, client):
print("Error calling action:", e)


def push_data(args, client):
print('File Name:', args.file)
my_table = csv.read_csv(args.file)
print ('Table rows=', str(len(my_table)))
df = my_table.to_pandas()
print(df.head())
writer, _ = client.do_put(
pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
writer.write_table(my_table)
writer.close()


def get_flight(args, client):
if args.path:
descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
Expand All @@ -83,7 +94,7 @@ def get_flight(args, client):
print('Ticket:', endpoint.ticket)
for location in endpoint.locations:
print(location)
get_client = pyarrow.flight.FlightClient.connect(location)
get_client = pyarrow.flight.FlightClient(location)
reader = get_client.do_get(endpoint.ticket)
df = reader.read_pandas()
print(df)
Expand Down Expand Up @@ -112,6 +123,12 @@ def main():
cmd_do.add_argument('action_type', type=str,
help="The action type to run.")

cmd_put = subcommands.add_parser('put')
cmd_put.set_defaults(action='put')
_add_common_arguments(cmd_put)
cmd_put.add_argument('file', type=str,
help="CSV file to upload.")

cmd_get = subcommands.add_parser('get')
cmd_get.set_defaults(action='get')
_add_common_arguments(cmd_get)
Expand All @@ -130,6 +147,7 @@ def main():
'list': list_flights,
'do': do_action,
'get': get_flight,
'put': push_data,
}
host, port = args.host.split(':')
port = int(port)
Expand All @@ -140,7 +158,7 @@ def main():
if args.tls_roots:
with open(args.tls_roots, "rb") as root_certs:
connection_args["tls_root_certs"] = root_certs.read()
client = pyarrow.flight.FlightClient.connect(f"{scheme}://{host}:{port}",
client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
**connection_args)
while True:
try:
Expand Down
39 changes: 24 additions & 15 deletions python/examples/flight/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@


class FlightServer(pyarrow.flight.FlightServerBase):
def __init__(self):
super(FlightServer, self).__init__()
def __init__(self, host="localhost", location=None, **kwargs):
super(FlightServer, self).__init__(location, **kwargs)
self.flights = {}
self.host = host

@classmethod
def descriptor_to_key(self, descriptor):
Expand All @@ -45,27 +46,36 @@ def list_flights(self, context, criteria):
descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

endpoints = [
pyarrow.flight.FlightEndpoint(repr(key),
[('localhost', 5005)]),
pyarrow.flight.FlightEndpoint(
repr(key),
[pyarrow.flight.Location.for_grpc_tcp(self.host, self.port)]),
]

mock_sink = pyarrow.MockOutputStream()
stream_writer = pyarrow.RecordBatchStreamWriter(mock_sink, table.schema)
stream_writer.write_table(table)
stream_writer.close()
data_size = mock_sink.size()

yield pyarrow.flight.FlightInfo(table.schema,
descriptor, endpoints,
table.num_rows, 0)
table.num_rows, data_size)

def get_flight_info(self, context, descriptor):
key = FlightServer.descriptor_to_key(descriptor)
if key in self.flights:
table = self.flights[key]
print(table.schema)
endpoints = [
pyarrow.flight.FlightEndpoint(repr(key),
[('localhost', 5005)]),
[pyarrow.flight.Location.for_grpc_tcp(self.host, self.port)]),
]
return pyarrow.flight.FlightInfo(table.schema,
descriptor, endpoints,
table.num_rows, 0)
raise KeyError('Flight not found.')

def do_put(self, context, descriptor, reader):
def do_put(self, context, descriptor, reader, writer):
key = FlightServer.descriptor_to_key(descriptor)
print(key)
self.flights[key] = reader.read_all()
Expand Down Expand Up @@ -95,7 +105,7 @@ def do_action(self, context, action):
# request
threading.Thread(target=self._shutdown).start()
else:
raise KeyError(f"Unknown action {action.type!r}")
raise KeyError("Unknown action {!r}".format(action.type))

def _shutdown(self):
"""Shut down after a delay."""
Expand All @@ -106,12 +116,11 @@ def _shutdown(self):

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=5005)
parser.add_argument("--tls", nargs=2, default=None)

args = parser.parse_args()

server = FlightServer()
kwargs = {}
scheme = "grpc+tcp"
if args.tls:
Expand All @@ -120,12 +129,12 @@ def main():
kwargs["tls_cert_chain"] = cert_file.read()
with open(args.tls[1], "rb") as key_file:
kwargs["tls_private_key"] = key_file.read()

location = "{}://0.0.0.0:{}".format(scheme, args.port)
server.init(location, **kwargs)
location = "{}://{}:{}".format(scheme, args.host, args.port)
server = FlightServer(args.host, location, **kwargs)
print("Serving on", location)
server.run()

server.serve()

if __name__ == '__main__':
main()

0 comments on commit fb8868d

Please sign in to comment.