Skip to content
Permalink
Browse files
IGNITE-15479 Fix incorrect partial read from socket in sync client - F…
…ixes #50.
  • Loading branch information
ivandasch committed Sep 9, 2021
1 parent ef8687e commit 3bf1cc1ad9e56a3b74a9abbb8a586495afb40169
Showing 2 changed files with 57 additions and 16 deletions.
@@ -156,6 +156,9 @@ def _connection_listener(self):
return self.client._event_listeners


DEFAULT_INITIAL_BUF_SIZE = 1024


class Connection(BaseConnection):
"""
This is a `pyignite` class, that represents a connection to Ignite
@@ -348,39 +351,35 @@ def recv(self, flags=None, reconnect=True) -> bytearray:
if flags is not None:
kwargs['flags'] = flags

data = bytearray(1024)
data = bytearray(DEFAULT_INITIAL_BUF_SIZE)
buffer = memoryview(data)
bytes_total_received, bytes_to_receive = 0, 0
total_rcvd, packet_len = 0, 0
while True:
try:
bytes_received = self._socket.recv_into(buffer, len(buffer), **kwargs)
if bytes_received == 0:
bytes_rcvd = self._socket.recv_into(buffer, len(buffer), **kwargs)
if bytes_rcvd == 0:
raise SocketError('Connection broken.')
bytes_total_received += bytes_received
total_rcvd += bytes_rcvd
except connection_errors as e:
self.failed = True
if reconnect:
self._on_connection_lost(e)
self.reconnect()
raise e

if bytes_total_received < 4:
continue
elif bytes_to_receive == 0:
response_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER)
bytes_to_receive = response_len

if response_len + 4 > len(data):
if packet_len == 0 and total_rcvd > 4:
packet_len = int.from_bytes(data[0:4], PROTOCOL_BYTE_ORDER, signed=True) + 4
if packet_len > len(data):
buffer.release()
data.extend(bytearray(response_len + 4 - len(data)))
buffer = memoryview(data)[bytes_total_received:]
data.extend(bytearray(packet_len - len(data)))
buffer = memoryview(data)[total_rcvd:]
continue

if bytes_total_received >= bytes_to_receive:
if 0 < packet_len <= total_rcvd:
buffer.release()
break

buffer = buffer[bytes_received:]
buffer = buffer[bytes_rcvd:]

return data

@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import secrets
import socket
import unittest.mock as mock

import pytest

from pyignite import Client
from tests.util import get_or_create_cache

old_recv_into = socket.socket.recv_into


def patched_recv_into_factory(buf_len):
def patched_recv_into(self, buffer, nbytes, **kwargs):
return old_recv_into(self, buffer, min(nbytes, buf_len) if buf_len else nbytes, **kwargs)
return patched_recv_into


@pytest.mark.parametrize('buf_len', [0, 1, 4, 16, 32, 64, 128, 256, 512, 1024])
def test_get_large_value(buf_len):
with mock.patch.object(socket.socket, 'recv_into', new=patched_recv_into_factory(buf_len)):
c = Client()
with c.connect("127.0.0.1", 10801):
with get_or_create_cache(c, 'test') as cache:
value = secrets.token_hex((1 << 16) + 1)
cache.put(1, value)
assert value == cache.get(1)

0 comments on commit 3bf1cc1

Please sign in to comment.