Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions cpp/src/plasma/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ class PlasmaClient::Impl : public std::enable_shared_from_this<PlasmaClient::Imp

Status Subscribe(int* fd);

Status DecodeNotification(const uint8_t* buffer, ObjectID* object_id,
int64_t* data_size, int64_t* metadata_size);

Status GetNotification(int fd, ObjectID* object_id, int64_t* data_size,
int64_t* metadata_size);

Expand Down Expand Up @@ -942,13 +945,10 @@ Status PlasmaClient::Impl::Subscribe(int* fd) {
return Status::OK();
}

Status PlasmaClient::Impl::GetNotification(int fd, ObjectID* object_id,
int64_t* data_size, int64_t* metadata_size) {
auto notification = ReadMessageAsync(fd);
if (notification == NULL) {
return Status::IOError("Failed to read object notification from Plasma socket");
}
auto object_info = flatbuffers::GetRoot<fb::ObjectInfo>(notification.get());
Status PlasmaClient::Impl::DecodeNotification(const uint8_t* buffer, ObjectID* object_id,
int64_t* data_size,
int64_t* metadata_size) {
auto object_info = flatbuffers::GetRoot<fb::ObjectInfo>(buffer);
ARROW_CHECK(object_info->object_id()->size() == sizeof(ObjectID));
memcpy(object_id, object_info->object_id()->data(), sizeof(ObjectID));
if (object_info->is_deletion()) {
Expand All @@ -961,6 +961,15 @@ Status PlasmaClient::Impl::GetNotification(int fd, ObjectID* object_id,
return Status::OK();
}

Status PlasmaClient::Impl::GetNotification(int fd, ObjectID* object_id,
int64_t* data_size, int64_t* metadata_size) {
auto notification = ReadMessageAsync(fd);
if (notification == NULL) {
return Status::IOError("Failed to read object notification from Plasma socket");
}
return DecodeNotification(notification.get(), object_id, data_size, metadata_size);
}

Status PlasmaClient::Impl::Connect(const std::string& store_socket_name,
const std::string& manager_socket_name,
int release_delay, int num_retries) {
Expand Down Expand Up @@ -1137,6 +1146,11 @@ Status PlasmaClient::GetNotification(int fd, ObjectID* object_id, int64_t* data_
return impl_->GetNotification(fd, object_id, data_size, metadata_size);
}

Status PlasmaClient::DecodeNotification(const uint8_t* buffer, ObjectID* object_id,
int64_t* data_size, int64_t* metadata_size) {
return impl_->DecodeNotification(buffer, object_id, data_size, metadata_size);
}

Status PlasmaClient::Disconnect() { return impl_->Disconnect(); }

Status PlasmaClient::Fetch(int num_object_ids, const ObjectID* object_ids) {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/plasma/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class ARROW_EXPORT PlasmaClient {
Status GetNotification(int fd, ObjectID* object_id, int64_t* data_size,
int64_t* metadata_size);

Status DecodeNotification(const uint8_t* buffer, ObjectID* object_id,
int64_t* data_size, int64_t* metadata_size);

/// Disconnect from the local plasma instance, including the local store and
/// manager.
///
Expand Down
38 changes: 38 additions & 0 deletions python/pyarrow/_plasma.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ from cpython.pycapsule cimport *
import collections
import pyarrow
import random
import socket

from pyarrow.lib cimport Buffer, NativeFile, check_status, pyarrow_wrap_buffer
from pyarrow.includes.libarrow cimport (CBuffer, CMutableBuffer,
CFixedSizeBufferWriter, CStatus)

from pyarrow import compat

PLASMA_WAIT_TIMEOUT = 2 ** 30

Expand Down Expand Up @@ -131,6 +133,10 @@ cdef extern from "plasma/client.h" nogil:

CStatus Subscribe(int* fd)

CStatus DecodeNotification(const uint8_t* buffer,
CUniqueID* object_id, int64_t* data_size,
int64_t* metadata_size)

CStatus GetNotification(int fd, CUniqueID* object_id,
int64_t* data_size, int64_t* metadata_size)

Expand Down Expand Up @@ -729,6 +735,38 @@ cdef class PlasmaClient:
with nogil:
check_status(self.client.get().Subscribe(&self.notification_fd))

def get_notification_socket(self):
"""
Get the notification socket.
"""
return compat.get_socket_from_fd(self.notification_fd,
family=socket.AF_UNIX,
type=socket.SOCK_STREAM)

def decode_notification(self, const uint8_t* buf):
"""
Get the notification from the buffer.

Returns
-------
ObjectID
The object ID of the object that was stored.
int
The data size of the object that was stored.
int
The metadata size of the object that was stored.
"""
cdef CUniqueID object_id
cdef int64_t data_size
cdef int64_t metadata_size
with nogil:
check_status(self.client.get()
.DecodeNotification(buf,
&object_id,
&data_size,
&metadata_size))
return ObjectID(object_id.binary()), data_size, metadata_size

def get_next_notification(self):
"""
Get the next notification from the notification socket.
Expand Down
10 changes: 10 additions & 0 deletions python/pyarrow/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import sys
import six
from six import BytesIO, StringIO, string_types as py_string
import socket


PY26 = sys.version_info[:2] == (2, 6)
Expand Down Expand Up @@ -267,4 +268,13 @@ def import_pytorch_extension():

integer_types = six.integer_types + (np.integer,)


def get_socket_from_fd(fileno, family, type):
if PY2:
socket_obj = socket.fromfd(fileno, family, type)
return socket.socket(family, type, _sock=socket_obj)
else:
return socket.socket(fileno=fileno, family=family, type=type)


__all__ = []
29 changes: 29 additions & 0 deletions python/pyarrow/tests/test_plasma.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import pytest
import random
import signal
import struct
import subprocess
import sys
import time
Expand Down Expand Up @@ -742,6 +743,34 @@ def test_subscribe(self):
assert data_sizes[j] == recv_dsize
assert metadata_sizes[j] == recv_msize

def test_subscribe_socket(self):
# Subscribe to notifications from the Plasma Store.
self.plasma_client.subscribe()
rsock = self.plasma_client.get_notification_socket()
for i in self.SUBSCRIBE_TEST_SIZES:
# Get notification from socket.
object_ids = [random_object_id() for _ in range(i)]
metadata_sizes = [np.random.randint(1000) for _ in range(i)]
data_sizes = [np.random.randint(1000) for _ in range(i)]

for j in range(i):
self.plasma_client.create(
object_ids[j], data_sizes[j],
metadata=bytearray(np.random.bytes(metadata_sizes[j])))
self.plasma_client.seal(object_ids[j])

# Check that we received notifications for all of the objects.
for j in range(i):
# Assume the plasma store will not be full,
# so we always get the data size instead of -1.
msg_len, = struct.unpack('L', rsock.recv(8))
content = rsock.recv(msg_len)
recv_objid, recv_dsize, recv_msize = (
self.plasma_client.decode_notification(content))
assert object_ids[j] == recv_objid
assert data_sizes[j] == recv_dsize
assert metadata_sizes[j] == recv_msize

def test_subscribe_deletions(self):
# Subscribe to notifications from the Plasma Store. We use
# plasma_client2 to make sure that all used objects will get evicted
Expand Down