Skip to content

Commit

Permalink
WIP checksum fields
Browse files Browse the repository at this point in the history
get checksum fields
get fields which the checksum depends on
pass msg as bytes and dict (kwargs) to checksum function
test checksum implementation with icmp

Ref. #240
  • Loading branch information
rssen committed May 15, 2020
1 parent fc93c70 commit 8f83586
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 36 deletions.
9 changes: 7 additions & 2 deletions rflx/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import itertools
from abc import ABC, abstractmethod
from copy import copy
from typing import Dict, List, Mapping, NamedTuple, Sequence, Set, Tuple
from typing import Dict, List, Mapping, NamedTuple, Optional, Sequence, Set, Tuple

from rflx.common import flat_name, generic_repr
from rflx.expression import (
Expand Down Expand Up @@ -271,12 +271,17 @@ def __repr__(self) -> str:

class AbstractMessage(Type):
def __init__(
self, identifier: StrID, structure: Sequence[Link], types: Mapping[Field, Type]
self,
identifier: StrID,
structure: Sequence[Link],
types: Mapping[Field, Type],
aspects: Mapping[str, Sequence[Mapping[str, Sequence[Expr]]]] = None,
) -> None:
super().__init__(identifier)

self.structure = structure
self.__types = types
self.aspects = aspects or {}

if structure or types:
self.__verify()
Expand Down
96 changes: 96 additions & 0 deletions rflx/pyrflx/typevalue.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import copy
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union

from rflx import model
from rflx.common import generic_repr
from rflx.expression import (
FALSE,
TRUE,
UNDEFINED,
Add,
And,
Attribute,
Expr,
First,
Last,
Length,
Name,
Sub,
ValueRange,
Variable,
)
from rflx.identifier import ID
Expand Down Expand Up @@ -458,13 +462,27 @@ def accepted_type(self) -> type:
class MessageValue(TypeValue):

_type: Message
# ToDo remove
_checksum_function: callable

def __init__(self, model: Message, refinements: Sequence[Refinement] = None) -> None:
super().__init__(model)
self._refinements = refinements or []
self._fields: Dict[str, MessageValue.Field] = {
f.name: self.Field(TypeValue.construct(self._type.types[f])) for f in self._type.fields
}


self._aspects = self._type.aspects
if "Checksum" in self._aspects.keys():
for d in [*self._aspects["Checksum"]]:
self._checksum_fields = [k for k, v in d.items()]
else:
self._checksum_fields = []
self.checksum_expressions = self._create_checksum_dict()
self.checksum_dependant_fields = self._create_checksum_dependant_fields_list()


self.__type_literals: Mapping[Name, Expr] = {}
self._last_field: str = self._next_field(INITIAL.name)
for t in [
Expand Down Expand Up @@ -690,6 +708,7 @@ def set_refinement(fld: MessageValue.Field, fld_name: str) -> None:
)

self._preset_fields(field_name)
self._update_checksum_fields()

def _preset_fields(self, fld: str) -> None:
nxt = self._next_field(fld)
Expand All @@ -710,6 +729,78 @@ def _preset_fields(self, fld: str) -> None:
self._last_field = nxt
nxt = self._next_field(nxt)

def set_checksum_function(self, checksum_method: callable) -> None:
self._checksum_function = checksum_method

def _create_checksum_dependant_fields_list(self) -> List[str]:
# ValueRanges(F2'First .. (F3'First -1 oder F3'Last))
# Variable -> das benannte Feld
# Length -> das benannte Feld
# Annahme -> self.fields ist topologisch sortiert

checksum_dependant_fields: List[str] = []
for key, expr in self.checksum_expressions.items():

if isinstance(expr, ValueRange):
assert isinstance(expr.lower, First)
included_first_field_of_range = str(expr.lower.prefix)
if isinstance(expr.upper, Last):
excluded_last_field_of_range = str(expr.upper.prefix)
pre = self._type.fields[
model.Field(included_first_field_of_range) : model.Field(
excluded_last_field_of_range
)
]
else:
assert isinstance(expr.upper, Sub)
assert isinstance(expr.upper.left, First)
excluded_last_field_of_range = str(expr.upper.left.prefix)
pre = self.fields[
self.fields.index(included_first_field_of_range) : self.fields.index(
excluded_last_field_of_range
)
]
checksum_dependant_fields.extend(pre)
elif isinstance(expr, Variable):
if expr.name in self.fields:
checksum_dependant_fields.append(expr.name)
elif isinstance(expr, Attribute):
if str(expr.prefix) in self.fields:
checksum_dependant_fields.append(str(expr.prefix))

return checksum_dependant_fields

def _create_checksum_dict(self) -> Mapping[str, Expr]:
field_expressions = Sequence[Expr]
for i in range(len(self._checksum_fields)):
field_expressions = self._aspects["Checksum"][i][self._checksum_fields[i]]
return dict(zip([str(f) for f in field_expressions], field_expressions))

def _update_checksum_fields(self) -> None:

# try to evaluate the checksum expressions
for key, expr in self.checksum_expressions.items():
self.checksum_expressions[key] = self.__simplified(expr)

# if the expressions can be evaluated to numbers, pass them as arguments to
# the checksum function
arguments = {}
for key, expr in self.checksum_expressions.items():

if (
isinstance(expr, ValueRange)
and isinstance(expr.lower, Number)
and isinstance(expr.upper, Number)
):
arguments[key] = (expr.lower.value, expr.upper.value)
elif isinstance(expr, Variable):
if expr.name in self.fields and self._fields[expr.name].set:
arguments[key] = self._fields[expr.name].typeval.value
elif isinstance(expr, Number):
arguments[key] = expr.value

print(f"update {self._checksum_function(self.bytestring, **arguments)}")

def get(self, field_name: str) -> Union["MessageValue", Sequence[TypeValue], int, str, bytes]:
if field_name not in self.valid_fields:
raise ValueError(f"field {field_name} not valid")
Expand Down Expand Up @@ -832,6 +923,11 @@ def __simplified(self, expr: Expr) -> Expr:
}

mapping = {**field_values, **self.__type_literals}
# ToDo implement substituted method for ValueRange
if isinstance(expr, ValueRange):
expr.lower = expr.lower.substituted(mapping=mapping)
expr.upper = expr.upper.substituted(mapping=mapping)
return expr.simplified()

return expr.substituted(mapping=mapping).substituted(mapping=mapping).simplified()

Expand Down
17 changes: 17 additions & 0 deletions tests/checksum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
def checksum_icmp(msg: MessageValue) -> int:
def add_ones_complement(num1, num2) -> int:
MOD = 1 << 16
result = num1 + num2
return result if result < MOD else (result + 1) % MOD

msg.set("Checksum", 0)
message_in_sixteen_bit_chunks = [
int.from_bytes(msg.bytestring[i : i + 2], "big") for i in range(0, len(msg.bytestring), 2)
]
intermediary_result = message_in_sixteen_bit_chunks[0]
for i in range(1, len(message_in_sixteen_bit_chunks)):
intermediary_result = add_ones_complement(
intermediary_result, message_in_sixteen_bit_chunks[i]
)

return intermediary_result ^ 0xFFFF
134 changes: 100 additions & 34 deletions tests/test_pyrflx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from tempfile import TemporaryDirectory
from typing import List

from rflx.expression import UNDEFINED
from rflx import model
from rflx.expression import UNDEFINED, First, Last, Length, Sub, ValueRange, Variable
from rflx.identifier import ID
from rflx.model import (
FINAL,
Expand Down Expand Up @@ -53,46 +54,48 @@ def setUpClass(cls) -> None:
cls.specdir = "specs"
pyrflx = PyRFLX(
[
f"{cls.testdir}/tlv_with_checksum.rflx",
f"{cls.specdir}/ethernet.rflx",
f"{cls.specdir}/tls_record.rflx",
f"{cls.specdir}/tls_alert.rflx",
# f"{cls.testdir}/tlv_with_checksum.rflx",
# f"{cls.specdir}/ethernet.rflx",
# f"{cls.specdir}/tls_record.rflx",
# f"{cls.specdir}/tls_alert.rflx",
f"{cls.specdir}/icmp.rflx",
f"{cls.testdir}/test_odd_length.rflx",
f"{cls.specdir}/ipv4.rflx",
f"{cls.testdir}/array_message.rflx",
f"{cls.testdir}/array_type.rflx",
f"{cls.specdir}/udp.rflx",
f"{cls.specdir}/tlv.rflx",
f"{cls.specdir}/in_ethernet.rflx",
f"{cls.specdir}/in_ipv4.rflx",
# f"{cls.testdir}/test_odd_length.rflx",
# f"{cls.specdir}/ipv4.rflx",
# f"{cls.testdir}/array_message.rflx",
# f"{cls.testdir}/array_type.rflx",
# f"{cls.specdir}/udp.rflx",
# f"{cls.specdir}/tlv.rflx",
# f"{cls.specdir}/in_ethernet.rflx",
# f"{cls.specdir}/in_ipv4.rflx",
]
)
cls.package_tlv_checksum = pyrflx["TLV_With_Checksum"]
cls.package_ethernet = pyrflx["Ethernet"]
cls.package_tls_record = pyrflx["TLS_Record"]
cls.package_tls_alert = pyrflx["TLS_Alert"]
# cls.package_tlv_checksum = pyrflx["TLV_With_Checksum"]
# cls.package_ethernet = pyrflx["Ethernet"]
# cls.package_tls_record = pyrflx["TLS_Record"]
# cls.package_tls_alert = pyrflx["TLS_Alert"]
cls.package_icmp = pyrflx["ICMP"]
cls.package_test_odd_length = pyrflx["Test_Odd_Length"]
cls.package_ipv4 = pyrflx["IPv4"]
cls.package_array_nested_msg = pyrflx["Array_Message"]
cls.package_array_typevalue = pyrflx["Array_Type"]
cls.package_udp = pyrflx["UDP"]
cls.package_tlv = pyrflx["TLV"]

# cls.package_test_odd_length = pyrflx["Test_Odd_Length"]
# cls.package_ipv4 = pyrflx["IPv4"]
# cls.package_array_nested_msg = pyrflx["Array_Message"]
# cls.package_array_typevalue = pyrflx["Array_Type"]
# cls.package_udp = pyrflx["UDP"]
# cls.package_tlv = pyrflx["TLV"]

def setUp(self) -> None:
self.tlv_checksum = self.package_tlv_checksum["Message"]
self.tlv = self.package_tlv["Message"]
self.frame = self.package_ethernet["Frame"]
self.record = self.package_tls_record["TLS_Record"]
self.alert = self.package_tls_alert["Alert"]
# self.tlv_checksum = self.package_tlv_checksum["Message"]
# self.tlv = self.package_tlv["Message"]
# self.frame = self.package_ethernet["Frame"]
# self.record = self.package_tls_record["TLS_Record"]
# self.alert = self.package_tls_alert["Alert"]
self.icmp = self.package_icmp["Echo_Message"]
self.odd_length = self.package_test_odd_length["Test"]
self.ipv4 = self.package_ipv4["Packet"]
self.ipv4_option = self.package_ipv4["Option"]
self.array_test_nested_msg = self.package_array_nested_msg["Message"]
self.array_test_typeval = self.package_array_typevalue["Foo"]
self.udp = self.package_udp["Datagram"]

# self.odd_length = self.package_test_odd_length["Test"]
# self.ipv4 = self.package_ipv4["Packet"]
# self.ipv4_option = self.package_ipv4["Option"]
# self.array_test_nested_msg = self.package_array_nested_msg["Message"]
# self.array_test_typeval = self.package_array_typevalue["Foo"]
# self.udp = self.package_udp["Datagram"]

def test_file_not_found(self) -> None:
with self.assertRaises(FileNotFoundError):
Expand Down Expand Up @@ -1155,3 +1158,66 @@ def test_tlv_generating_tlv_error(self) -> None:
self.tlv.set("Tag", "Msg_Error")
self.assertTrue(self.tlv.valid_message)
self.assertEqual(self.tlv.bytestring, b"\xc0")

def test_aspect_checksum(self) -> None:
def checksum_icmp(message: bytes, **kwargs) -> int:
def add_ones_complement(num1, num2) -> int:
MOD = 1 << 16
result = num1 + num2
return result if result < MOD else (result + 1) % MOD

c_f = kwargs.get("Checksum'First", None)
c_l = kwargs.get("Checksum'Last", None)
if c_l and c_f:
checksum_bytes = message[: (c_f // 8)] + b"\x00\x00" + message[((c_l + 1) // 8) :]
else:
checksum_bytes = message

message_in_sixteen_bit_chunks = [
int.from_bytes(checksum_bytes[i : i + 2], "big")
for i in range(0, len(checksum_bytes), 2)
]
intermediary_result = message_in_sixteen_bit_chunks[0]
for i in range(1, len(message_in_sixteen_bit_chunks)):
intermediary_result = add_ones_complement(
intermediary_result, message_in_sixteen_bit_chunks[i]
)

return intermediary_result ^ 0xFFFF

test_data = (
b"\x47\xb4\x67\x5e\x00\x00\x00\x00"
b"\x4a\xfc\x0d\x00\x00\x00\x00\x00\x10\x11\x12\x13\x14\x15\x16\x17"
b"\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f\x20\x21\x22\x23\x24\x25\x26\x27"
b"\x28\x29\x2a\x2b\x2c\x2d\x2e\x2f\x30\x31\x32\x33\x34\x35\x36\x37"
)

icmp_type = self.icmp._type
icmp_type.aspects = {
"Checksum": [
{
"Checksum": [
ValueRange(First("Identifier"), Sub(First("Data"), Number(1))),
Variable("Data"),
First("Checksum"),
Last("Checksum"),
]
}
]
}
icmp_checksum = MessageValue(icmp_type)
icmp_checksum.set_checksum_function(checksum_icmp)

icmp_checksum.set("Tag", "Echo_Request")
icmp_checksum.set("Code", 0)
# 12824
icmp_checksum.set("Checksum", 12824)
icmp_checksum.set("Identifier", 5)
icmp_checksum.set("Sequence_Number", 1)
icmp_checksum.set(
"Data", test_data,
)

print(f"after message {checksum_icmp(icmp_checksum.bytestring)}")
# self.assertEqual(icmp_checksum.bytestring, b"\x08\x00\x32\x18\x00\x05\x00\x01" + test_data)
# self.assertTrue(icmp_checksum.valid_message)

0 comments on commit 8f83586

Please sign in to comment.