Skip to content

Commit

Permalink
dsf: simplfied and centralized module configuration and introduced so…
Browse files Browse the repository at this point in the history
…cket timeout
  • Loading branch information
mfs12 committed Jun 9, 2021
1 parent 6d1ffcb commit 15accc1
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 24 deletions.
10 changes: 7 additions & 3 deletions src/dsf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
SOCKET_DIRECTORY = "/run/dsf"
SOCKET_FILE = "dcs.sock"
FULL_SOCKET_PATH = SOCKET_DIRECTORY + "/" + SOCKET_FILE
# path to unix socket file
SOCKET_FILE = "/run/dsf/dsf.sock"

# allowed connection per unix server
DEFAULT_BACKLOG = 4

# DSF protocol version
PROTOCOL_VERSION = 11
40 changes: 25 additions & 15 deletions src/dsf/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import socket
from typing import Optional

from . import DEFAULT_BACKLOG, FULL_SOCKET_PATH
from . import DEFAULT_BACKLOG, SOCKET_FILE
from .commands import responses, basecommands, code, result, codechannel
from .commands.basecommands import MessageType, LogLevel
from .initmessages import serverinitmessage, clientinitmessages
Expand Down Expand Up @@ -50,20 +50,21 @@ class BaseConnection:
using a UNIX socket
"""

def __init__(self, debug: bool = False):
def __init__(self, debug: bool = False, timeout: int = 3):
self.debug = debug
self.timeout = timeout
self.socket: Optional[socket.socket] = None
self.id = None
self.input = ""

def connect(
self, init_message: clientinitmessages.ClientInitMessage, socket_path: str
self, init_message: clientinitmessages.ClientInitMessage, socket_file: str
):
"""Establishes a connection to the given UNIX socket file"""

self.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.socket.connect(socket_path)
self.socket.setblocking(True)
self.socket.connect(socket_file)
self.socket.settimeout(self.timeout)
server_init_message = serverinitmessage.ServerInitMessage.from_json(
json.loads(self.socket.recv(50).decode("utf8"))
)
Expand Down Expand Up @@ -147,12 +148,21 @@ def receive_json(self) -> str:
# Refill the buffer and check again
BUFF_SIZE = 4096 # 4 KiB
data = b""
part = b""
while True:
part = self.socket.recv(BUFF_SIZE)
data += part
try:
part = self.socket.recv(BUFF_SIZE)
data += part
except socket.timeout:
pass
except Exception as e:
raise e
# either 0 or end of data
if len(part) == 0:
raise TimeoutError
if len(part) < BUFF_SIZE:
break

json_string += data.decode("utf8")

end_index = self.get_json_object_end_index(json_string)
Expand Down Expand Up @@ -210,9 +220,9 @@ def add_http_endpoint(
endpoint_type, namespace, path, is_upload_request
)
)
socket_path = res.result
socket_file = res.result
return HttpEndpointUnixSocket(
endpoint_type, namespace, path, socket_path, backlog, self.debug
endpoint_type, namespace, path, socket_file, backlog, self.debug
)

def add_user_session(
Expand Down Expand Up @@ -410,9 +420,9 @@ def write_message(
class CommandConnection(BaseCommandConnection):
"""Connection class for sending commands to the control server"""

def connect(self, socket_path: str = FULL_SOCKET_PATH): # type: ignore
def connect(self, socket_file: str = SOCKET_FILE): # type: ignore
"""Establishes a connection to the given UNIX socket file"""
return super().connect(clientinitmessages.command_init_message(), socket_path)
return super().connect(clientinitmessages.command_init_message(), socket_file)


class InterceptConnection(BaseCommandConnection):
Expand All @@ -435,13 +445,13 @@ def __init__(
self.filters = filters
self.priority_codes = priority_codes

def connect(self, socket_path: str = FULL_SOCKET_PATH): # type: ignore
def connect(self, socket_file: str = SOCKET_FILE): # type: ignore
"""Establishes a connection to the given UNIX socket file"""
iim = clientinitmessages.intercept_init_message(
self.interception_mode, self.channels, self.filters, self.priority_codes
)

return super().connect(iim, socket_path)
return super().connect(iim, socket_file)

def receive_code(self) -> code.Code:
"""Wait for a code to be intercepted and read it"""
Expand Down Expand Up @@ -484,12 +494,12 @@ def __init__(
self.filter_str = filter_str
self.filter_list = filter_list

def connect(self, socket_path: str = FULL_SOCKET_PATH): # type: ignore
def connect(self, socket_file: str = SOCKET_FILE): # type: ignore
"""Establishes a connection to the given UNIX socket file"""
sim = clientinitmessages.subscribe_init_message(
self.subscription_mode, self.filter_str, self.filter_list
)
return super().connect(sim, socket_path)
return super().connect(sim, socket_file)

def get_machine_model(self) -> MachineModel:
"""
Expand Down
10 changes: 5 additions & 5 deletions src/dsf/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,23 +108,23 @@ def __init__(
endpoint_type: HttpEndpointType,
namespace: str,
path: str,
socket_path: str,
socket_file: str,
backlog: int = DEFAULT_BACKLOG,
debug: bool = False,
):
"""Open a new UNIX socket on the given file path"""
self.endpoint_type = endpoint_type
self.namespace = namespace
self.endpoint_path = path
self.socket_path = socket_path
self.socket_file = socket_file
self.backlog = backlog
self.handler = None
self.debug = debug
self._loop = None
self._server = None

try:
os.remove(self.socket_path)
os.remove(self.socket_file)
except FileNotFoundError:
# We don't care if the file was missing
# TODO: should we care about deletion failed?
Expand All @@ -143,7 +143,7 @@ def close(self):
self.event_loop.cancel()
self.executor.shutdown(wait=False)
try:
os.remove(self.socket_path)
os.remove(self.socket_file)
except FileNotFoundError:
pass

Expand All @@ -155,7 +155,7 @@ def start_connection_listener(self):
try:
self._loop = asyncio.new_event_loop()
self._server = asyncio.start_unix_server(
self.handle_connection, self.socket_path, backlog=self.backlog
self.handle_connection, self.socket_file, backlog=self.backlog
)
self._loop.create_task(self._server)
self._loop.run_forever()
Expand Down
4 changes: 3 additions & 1 deletion src/dsf/initmessages/serverinitmessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""

from .. import PROTOCOL_VERSION


class IncompatibleVersionException(Exception):
"""Exception raised when the server and client are incompatible"""
Expand All @@ -34,7 +36,7 @@ def from_json(cls, data):
"""Deserialize a dictionary coming from JSON into an instance of this class"""
return cls(**data)

PROTOCOL_VERSION = 11
PROTOCOL_VERSION = PROTOCOL_VERSION

def __init__(self, version: int, id: int):
self.version = version
Expand Down

0 comments on commit 15accc1

Please sign in to comment.