diff --git a/logger/writers/test_udp_writer.py b/logger/writers/test_udp_writer.py new file mode 100644 index 00000000..21941bdc --- /dev/null +++ b/logger/writers/test_udp_writer.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 + +import logging +import signal +import socket +import sys +import threading +import time +import unittest + +from os.path import dirname, realpath +sys.path.append(dirname(dirname(dirname(realpath(__file__))))) +from logger.writers.udp_writer import UDPWriter + +SAMPLE_DATA = ['f1 line 1', + 'f1 line 2', + 'f1 line 3'] + +BINARY_DATA = [b'\xff\xa1', + b'\xff\xa2', + b'\xff\xa3'] + + + +class ReaderTimeout(StopIteration): + """A custom exception we can raise when we hit timeout.""" + pass + +################################################################################ + + +class TestUDPWriter(unittest.TestCase): + ############################ + def _handler(self, signum, frame): + """If timeout fires, raise our custom exception""" + logging.info('Read operation timed out') + raise ReaderTimeout + + ############################ + # Actually run the UDPWriter in internal method + def write(self, host, port, eol=None, data=None, interval=0, delay=0, encoding='utf-8', mc_interface=None, mc_ttl=3): + writer = UDPWriter(destination=host, port=port, eol=eol, encoding=encoding, mc_interface=mc_interface, mc_ttl=mc_ttl) + + time.sleep(delay) + for line in data: + writer.write(line) + time.sleep(interval) + + ############################ + # + # NOTE: The only simple to really verify that it's broadcast, as apposed to + # unicast that we successfully read, is to do packet analysis in + # wireshark/tshark or similar. But this test case will catch most + # breakage from other developemnt. + # + def test_broadcast(self): + # Main method starts here + host = "" + port = 8001 + + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True) + sock.bind((host, port)) + + # Start the writer + threading.Thread(target=self.write, + args=(host, port), + kwargs={"data": SAMPLE_DATA, "interval": 0.1}).start() + + # Set timeout we can catch if things are taking too long + signal.signal(signal.SIGALRM, self._handler) + signal.alarm(3) + try: + # Check that we get the lines we expect from it + for line in SAMPLE_DATA: + record = sock.recv(4096) + logging.info('looking for "%s", got "%s"', line, record) + if record: + record = record.decode('utf-8') + self.assertEqual(line, record) + except ReaderTimeout: + self.assertTrue(False, 'UDPReader timed out in test - is port ' + '%s open?' % addr) + signal.alarm(0) + + ############################ + # + # NOTE: We're really testing 2 things here: unicast output and handling odd + # `eol` parameter values (e.g., escaped oddness). + # + def test_unicast_with_eol(self): + # Main method starts here + host = '' + port = 8002 + eol = '\\r\\n' # simulate escaped entry from config file + + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + sock.bind((host, port)) + + # Start the writer + threading.Thread(target=self.write, + args=(host, port), + kwargs={"data": SAMPLE_DATA, "interval": 0.1, "eol": eol}).start() + + # Set timeout we can catch if things are taking too long + signal.signal(signal.SIGALRM, self._handler) + signal.alarm(3) + try: + # Check that we get the lines we expect from it + for line in SAMPLE_DATA: + record = sock.recv(4096) + if record: + record = record.decode('utf-8') + line += eol.encode().decode('unicode_escape') + logging.info('looking for "%s", got "%s"', line, record) + self.assertEqual(line, record) + except ReaderTimeout: + self.assertTrue(False, 'UDPReader timed out in test - is port ' + '%s open?' % addr) + signal.alarm(0) + + ############################ + def test_binary(self): + # Main method starts here + host = '' + port = 8003 + + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + sock.bind((host, port)) + + # Start the writer + threading.Thread(target=self.write, + args=(host, port), + kwargs={"data": BINARY_DATA, "interval": 0.1, "encoding": None}).start() + + # Set timeout we can catch if things are taking too long + signal.signal(signal.SIGALRM, self._handler) + signal.alarm(3) + try: + # Check that we get the lines we expect from it + for line in BINARY_DATA: + record = sock.recv(4096) + logging.info('looking for "%s", got "%s"', line, record) + self.assertEqual(line, record) + except ReaderTimeout: + self.assertTrue(False, 'UDPReader timed out in test - is port ' + '%s open?' % addr) + signal.alarm(0) + + ############################ + # + # NOTE: This doesn't actually prove that your network can support multicast + # routing, which is complicated. It's only verifying that we sent + # packets to a multicast address and received them, all locally. So, + # it also doesn't really test the IGMP group membership stuff very + # well. Again, packet analysis is really needed to prove it's + # working. + # + def test_multicast(self): + # Main method starts here + host = "239.192.0.100" + port = 8004 + ttl = 3 + source_ip = socket.gethostbyname(socket.gethostname()) + + sock = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM) + + # join the multicast group + logging.info("joing %s to multicast group %s", source_ip, host) + # NOTE: Since these are both already encoded as binary by inet_aton(), + # we can just concatenate them. Alternatively, could use + # struct.pack("4s4s", ...) + sock.setsockopt(socket.IPPROTO_IP, socket.IP_ADD_MEMBERSHIP, + socket.inet_aton(host) + socket.inet_aton(source_ip)) + + # bind to the mc host, port + sock.bind((host, port)) + + # Start the writer + threading.Thread(target=self.write, + args=(host, port), + kwargs={"data": SAMPLE_DATA, "interval": 0.1, "mc_interface": source_ip, "mc_ttl": ttl}).start() + + # Set timeout we can catch if things are taking too long + signal.signal(signal.SIGALRM, self._handler) + signal.alarm(3) + try: + # Check that we get the lines we expect from it + for line in SAMPLE_DATA: + record = sock.recv(4096) + logging.info('looking for "%s", got "%s"', line, record) + if record: + record = record.decode('utf-8') + self.assertEqual(line, record) + except ReaderTimeout: + self.assertTrue(False, 'UDPReader timed out in test - is port ' + '%s open?' % addr) + signal.alarm(0) + + +################################################################################ +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + parser.add_argument('-v', '--verbosity', dest='verbosity', + default=0, action='count', + help='Increase output verbosity') + args = parser.parse_args() + + LOGGING_FORMAT = '%(asctime)-15s %(filename)s:%(lineno)d %(message)s' + logging.basicConfig(format=LOGGING_FORMAT) + + LOG_LEVELS = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG} + args.verbosity = min(args.verbosity, max(LOG_LEVELS)) + logging.getLogger().setLevel(LOG_LEVELS[args.verbosity]) + + unittest.main(warnings='ignore') diff --git a/logger/writers/udp_writer.py b/logger/writers/udp_writer.py index d35d6aab..83b08d1a 100644 --- a/logger/writers/udp_writer.py +++ b/logger/writers/udp_writer.py @@ -1,7 +1,6 @@ #!/usr/bin/env python3 import json -import ipaddress import logging import socket import struct @@ -10,18 +9,18 @@ from os.path import dirname, realpath sys.path.append(dirname(dirname(dirname(realpath(__file__))))) -from logger.utils.das_record import DASRecord # noqa: E402 -from logger.writers.network_writer import NetworkWriter # noqa: E402 +from logger.utils.formats import Text +from logger.writers.writer import Writer -class UDPWriter(NetworkWriter): +class UDPWriter(Writer): """Write UDP packets to network.""" def __init__(self, port, destination='', interface='', # DEPRECATED! - ttl=3, num_retry=2, warning_limit=5, eol=''): - """ - Write text records to a network socket. + mc_interface=None, mc_ttl=3, num_retry=2, warning_limit=5, eol='', + encoding='utf-8', encoding_errors='ignore'): + """Write text records to a network socket. ``` port Port to which packets should be sent @@ -38,7 +37,10 @@ def __init__(self, port, destination='', argument and specify the broadcast address of the desired network. - ttl For multicast, how many network hops to allow + mc_interface REQUIRED for multicast, the interface to send from. Can be + specified as either IP or a resolvable hostname. + + mc_ttl For multicast, how many network hops to allow. num_retry Number of times to retry if write fails. If writer exceeds this number, it will give up on writing the message and @@ -50,17 +52,51 @@ def __init__(self, port, destination='', eol If specified, an end of line string to append to record before sending. + + encoding - 'utf-8' by default. If empty or None, do not attempt any + decoding and return raw bytes. Other possible encodings are + listed in online documentation here: + https://docs.python.org/3/library/codecs.html#standard-encodings + + encoding_errors - 'ignore' by default. Other error strategies are + 'strict', 'replace', and 'backslashreplace', described here: + https://docs.python.org/3/howto/unicode.html#encodings ``` + """ - self.ttl = ttl + super().__init__(input_format=Text, + encoding=encoding, + encoding_errors=encoding_errors) + self.num_retry = num_retry self.warning_limit = warning_limit self.num_warnings = 0 + self.good_writes = 0 # consecutive good writes, for detecting UDP errors + + # 'eol' comes in as a (probably escaped) string. We need to + # unescape it, which means converting to bytes and back. + if eol is not None and self.encoding: + eol = self._unescape_str(eol) self.eol = eol self.target_str = 'interface: %s, destination: %s, port: %d' % ( interface, destination, port) + # do name resolution once in the constructor + # + # NOTE: This means the hostname must be valid when we start, otherwise + # the config_check code will puke. That's fine. The alternative + # is we let name resolution happen while we're running, but then + # each failed lookup is going to block our write() routine for a + # few seconds - not good. + # + # NOTE: This also catches specifying impropperly formatted IP + # addresses. The only way through gethostbyname() w/out throwing + # an exception is to provide a valid hostname or IP address. + # Propperly formatted IPs just get returned. + # + if destination: + destination = socket.gethostbyname(destination) if interface: logging.warning('DEPRECATED: UDPWriter(interface=%s). Instead of the ' '"interface" parameter, UDPWriters should use the' @@ -68,10 +104,9 @@ def __init__(self, port, destination='', 'interface, specify UDPWriter(destination=) address as the destination.', interface) + interface = socket.gethostbyname(interface) if interface and destination: - ipaddress.ip_address(interface) # throw a ValueError if bad addr - ipaddress.ip_address(destination) # At the moment, we don't know how to do both interface and # multicast/unicast. If they've specified both, then complain # and ignore the interface part. @@ -95,20 +130,23 @@ def __init__(self, port, destination='', destination = '' else: # Change interface's lowest tuple to 'broadcast' value (255) - ipaddress.ip_address(interface) destination = interface[:interface.rfind('.')] + '.255' - # If we've been given a destination, make sure it's a valid IP - elif destination: - ipaddress.ip_address(destination) - # If no destination, it's a broadcast; set flag allowing broadcast and # set dest to special string - else: + elif not destination: destination = '' self.destination = destination - self.port = port + # make sure port gets stored as an int, even if passed in as a string + self.port = int(port) + + # multicast options + if mc_interface: + # resolve once in constructor + mc_interface = socket.gethostbyname(mc_interface) + self.mc_interface = mc_interface + self.mc_ttl = mc_ttl # Try opening the socket self.socket = self._open_socket() @@ -126,10 +164,18 @@ def _open_socket(self): except AttributeError: logging.warning('Unable to set socket REUSEPORT; may be unsupported') - # Set the time-to-live for messages, in case of multicast - udp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, - struct.pack('b', self.ttl)) - udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True) + # set multicast/broadcast options + if self.mc_interface: + # set the time-to-live for messages + udp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, + struct.pack('b', self.mc_ttl)) + # set outgoing multicast interface + udp_socket.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, + socket.inet_aton(self.mc_interface)) + else: + # maybe broadcast, but very non-trivial to detect broadcast IP, so + # we set the broadcast flag anytime we're not doing multicast + udp_socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True) try: udp_socket.connect((self.destination, self.port)) @@ -152,15 +198,7 @@ def write(self, record): self.write(single_record) return - # If record is not a string, try converting to JSON. If we don't know - # how, throw a hail Mary and force it into str format - if not isinstance(record, str): - if type(record) in [int, float, bool, list, dict]: - record = json.dumps(record) - elif isinstance(record, DASRecord): - record = record.as_json() - else: - record = str(record) + # Append eol if configured if self.eol: record += self.eol @@ -175,18 +213,33 @@ def write(self, record): num_tries = bytes_sent = 0 rec_len = len(record) - while num_tries < self.num_retry and bytes_sent < rec_len: + while num_tries <= self.num_retry and bytes_sent < rec_len: try: - bytes_sent = self.socket.send(record.encode('utf-8')) + bytes_sent = self.socket.send(self._encode_str(record)) # If here, write at least partially succeeded. Reset warnings - if self.num_warnings == self.warning_limit: - logging.info('UDPWriter.write() succeeded in writing after series of ' - 'failures; resetting warnings.') - self.num_warnings = 0 # we've succeeded + # + # NOTE: If the host is unreachable, every other send will fail. + # Since UDP doesn't actually know it failed, the initial + # send() cannot fail. However, the network stack will + # see the ICMP host unreachable message and will store + # THAT as the the error message for next write, then the + # next send fails and clears the error... Then the next + # "succeeds" and the next fails, etc, etc + # + # So we look for 2 consecutive "successful" writes before + # resetting num_warnings. + # + self.good_writes += 1 + if self.good_writes >= 2: + if self.num_warnings == self.warning_limit: + logging.info('UDPWriter.write() succeeded in writing after series of ' + 'failures; resetting warnings.') + self.num_warnings = 0 # we've succeeded except (OSError, ConnectionRefusedError) as e: # If we failed, complain, unless we've already complained too much + self.good_writes = 0 if self.num_warnings < self.warning_limit: logging.error('UDPWriter error: %s: %s', self.target_str, str(e)) logging.error('UDPWriter record: %s', record) diff --git a/logger/writers/writer.py b/logger/writers/writer.py index 8a4e8740..5a9a3028 100644 --- a/logger/writers/writer.py +++ b/logger/writers/writer.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import logging import sys from os.path import dirname, realpath @@ -14,9 +15,9 @@ class Writer: """ ############################ - def __init__(self, input_format=formats.Unknown): - """ - Abstract base class for data Writers. + def __init__(self, input_format=formats.Unknown, + encoding='utf-8', encoding_errors='ignore'): + """Abstract base class for data Writers. Concrete classes should implement the write(record) method and assign an input format that the writer accepts. @@ -38,8 +39,64 @@ def __init__(self, filename): ``` The test will also return False if either the reader or writer have format specification of "Unknown". + + Two additional arguments govern how records will be encoded/decoded + from bytes, if desired by the Reader subclass when it calls + _encode_str() or _decode_bytes: + + encoding - 'utf-8' by default. If empty or None, do not attempt any + decoding and return raw bytes. Other possible encodings are + listed in online documentation here: + https://docs.python.org/3/library/codecs.html#standard-encodings + + encoding_errors - 'ignore' by default. Other error strategies are + 'strict', 'replace', and 'backslashreplace', described here: + https://docs.python.org/3/howto/unicode.html#encodings + """ self.input_format(input_format) + self.encoding = encoding + self.encoding_errors = encoding_errors + + ############################ + def _unescape_str(self, the_str): + """Unescape a string by encoding it to bytes, then unescaping when we + decode it. Ugly. + """ + if not self.encoding: + return the_str + + encoded = the_str.encode(encoding=self.encoding, errors=self.encoding_errors) + return encoded.decode('unicode_escape') + + ############################ + def _encode_str(self, the_str, unescape=True): + """Encode a string to bytes, unescaping things like \n and \r. Unescaping + requires ugly convolutions of encoding, then decoding while we escape things, + then encoding a second time. + """ + if not self.encoding: + return the_str + if unescape: + the_str = self._unescape_str(the_str) + return the_str.encode(encoding=self.encoding, errors=self.encoding_errors) + + ############################ + def _decode_bytes(self, record): + """Decode a record from bytes to str, if we have an encoding specified.""" + if not record: + return None + + if not self.encoding: + return record + + try: + return record.decode(encoding=self.encoding, + errors=self.encoding_errors) + except UnicodeDecodeError as e: + logging.warning('Error decoding string "%s" from encoding "%s": %s', + record, self.encoding, str(e)) + return None ############################ def input_format(self, new_format=None):