Skip to content

Commit

Permalink
Add caching of message verification results to parser
Browse files Browse the repository at this point in the history
Ref. #442
  • Loading branch information
treiher committed Sep 30, 2020
1 parent 9d1ff85 commit 3b2efe9
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 14 deletions.
2 changes: 1 addition & 1 deletion rflx/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def generate(args: argparse.Namespace) -> None:


def parse(files: List, skip_verification: bool = False) -> Model:
parser = Parser(skip_verification)
parser = Parser(skip_verification, cached=True)

error = RecordFluxError()
for f in files:
Expand Down
63 changes: 63 additions & 0 deletions rflx/parser/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import hashlib
import json
import pathlib
from typing import Dict

from rflx import __version__
from rflx.model.message import AbstractMessage

# Directory below the user's home directory in which the cache is stored.
CACHE_DIR = pathlib.Path.home() / ".cache" / "RecordFlux"
# Name of the JSON file inside CACHE_DIR that holds the verification results.
VERIFICATION_FILE = "verification.json"


class Cache:
    """Persistent record of messages that have been marked as verified.

    The cache maps the full name of a message to a hash built from the
    RecordFlux version and the string representation of the message. It is
    stored as JSON in ``CACHE_DIR / VERIFICATION_FILE``. A disabled cache never
    touches the file system and reports every message as unverified.
    """

    def __init__(self, enabled: bool = True) -> None:
        """Create the cache and, if `enabled`, load the stored results from disk."""
        self._enabled = enabled
        # Always defined, so that the object is in a consistent state even when disabled.
        self._verification: Dict[str, str] = {}

        if not enabled:
            return

        self._initialize_cache()
        self._load_cache()

    def is_verified(self, message: AbstractMessage) -> bool:
        """Return True if `message` is stored in the cache with its current hash."""
        if not self._enabled:
            return False

        # A hash is never None, so a missing entry can never compare equal.
        return self._verification.get(message.full_name) == self._message_hash(message)

    def add_verified(self, message: AbstractMessage) -> None:
        """Store `message` as verified and persist the cache if the entry changed."""
        if not self._enabled:
            return

        message_hash = self._message_hash(message)
        # Only rewrite the file when the entry is new or outdated.
        if self._verification.get(message.full_name) != message_hash:
            self._verification[message.full_name] = message_hash
            self._write_cache()

    @staticmethod
    def _initialize_cache() -> None:
        """Create the cache directory and an empty verification file if missing."""
        # `parents=True` is required when e.g. ~/.cache does not exist yet, and
        # `exist_ok=True` avoids a failure if the directory appears concurrently.
        CACHE_DIR.mkdir(parents=True, exist_ok=True)
        verification_file = CACHE_DIR / VERIFICATION_FILE
        if not verification_file.exists():
            with open(verification_file, "w", encoding="utf-8") as f:
                json.dump({}, f)

    def _load_cache(self) -> None:
        """Read the verification file, falling back to an empty cache if it is unusable."""
        try:
            with open(CACHE_DIR / VERIFICATION_FILE, encoding="utf-8") as f:
                loaded = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            # A truncated or corrupted file (e.g. after an interrupted write) must
            # not break parsing; affected messages are simply verified again.
            loaded = {}

        self._verification = loaded if isinstance(loaded, dict) else {}

    def _write_cache(self) -> None:
        """Write the in-memory verification results to the verification file."""
        with open(CACHE_DIR / VERIFICATION_FILE, "w", encoding="utf-8") as f:
            json.dump(self._verification, f)

    @staticmethod
    def _message_hash(message: AbstractMessage) -> str:
        """Return the MD5 hex digest of the RecordFlux version and `str(message)`."""
        # MD5 only serves as a change detector here, not as a security measure.
        return hashlib.md5(f"{__version__}|{message}".encode("utf-8")).hexdigest()
51 changes: 38 additions & 13 deletions rflx/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,19 @@
RefinementSpec,
Specification,
)
from .cache import Cache

log = logging.getLogger(__name__)


class Parser:
def __init__(self, skip_verification: bool = False) -> None:
def __init__(self, skip_verification: bool = False, cached: bool = False) -> None:
self.skip_verification = skip_verification
self.__specifications: Deque[Specification] = deque()
self.__evaluated_specifications: Set[ID] = set()
self.__types: Dict[ID, Type] = {**BUILTIN_TYPES, **INTERNAL_TYPES}
self.__sessions: Dict[ID, Session] = {}
self.__cache = Cache(cached)

def parse(self, specfile: Path) -> None:
self.__parse(specfile)
Expand Down Expand Up @@ -140,6 +142,7 @@ def create_model(self) -> Model:
result = Model(list(self.__types.values()), list(self.__sessions.values()))
except RecordFluxError as e:
error.extend(e)

error.propagate()
return result

Expand Down Expand Up @@ -184,10 +187,12 @@ def __evaluate_types(self, spec: Specification, error: RecordFluxError) -> None:
new_type = create_array(t, self.__types)

elif isinstance(t, MessageSpec):
new_type = create_message(t, self.__types, self.skip_verification)
new_type = create_message(t, self.__types, self.skip_verification, self.__cache)

elif isinstance(t, DerivationSpec):
new_type = create_derived_message(t, self.__types)
new_type = create_derived_message(
t, self.__types, self.skip_verification, self.__cache
)

elif isinstance(t, RefinementSpec):
new_type = create_refinement(t, self.__types)
Expand Down Expand Up @@ -248,7 +253,10 @@ def create_array(array: ArraySpec, types: Mapping[ID, Type]) -> Array:


def create_message(
message: MessageSpec, types: Mapping[ID, Type], skip_verification: bool = False
message: MessageSpec,
types: Mapping[ID, Type],
skip_verification: bool,
cache: Cache,
) -> Message:
components = list(message.components)

Expand Down Expand Up @@ -306,16 +314,21 @@ def create_message(
)
)

return (
return create_proven_message(
UnprovenMessage(
message.identifier, structure, field_types, message.aspects, message.location, error
)
.merged()
.proven(skip_verification)
).merged(),
skip_verification,
cache,
)


def create_derived_message(derivation: DerivationSpec, types: Mapping[ID, Type]) -> Message:
def create_derived_message(
derivation: DerivationSpec,
types: Mapping[ID, Type],
skip_verification: bool,
cache: Cache,
) -> Message:
base_name = qualified_type_name(derivation.base, derivation.package)
messages = message_types(types)
error = RecordFluxError()
Expand Down Expand Up @@ -357,12 +370,24 @@ def create_derived_message(derivation: DerivationSpec, types: Mapping[ID, Type])
)
error.propagate()

return (
UnprovenDerivedMessage(derivation.identifier, base, location=derivation.location)
.merged()
.proven()
return create_proven_message(
UnprovenDerivedMessage(derivation.identifier, base, location=derivation.location).merged(),
skip_verification,
cache,
)


def create_proven_message(
    unproven_message: UnprovenMessage, skip_verification: bool, cache: Cache
) -> Message:
    """Turn `unproven_message` into a proven message, using `cache` to avoid re-verification.

    Verification is skipped if `skip_verification` is set or if the cache
    already contains the message with a matching hash. Any error raised by
    `proven` propagates to the caller, in which case nothing is added to the
    cache.
    """
    proven_message = unproven_message.proven(
        skip_verification or cache.is_verified(unproven_message)
    )

    # Only record the message if verification was not explicitly skipped.
    # Otherwise a message that was never verified would be stored as verified
    # and its verification would be silently skipped in all later runs.
    if not skip_verification:
        cache.add_verified(unproven_message)

    return proven_message


def create_refinement(refinement: RefinementSpec, types: Mapping[ID, Type]) -> Refinement:
messages = message_types(types)
Expand Down
19 changes: 19 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Opaque,
RangeInteger,
Refinement,
UnprovenMessage,
)

NULL_MESSAGE = Message("Null::Message", [], {})
Expand Down Expand Up @@ -217,6 +218,24 @@
DERIVATION_MESSAGE = DerivedMessage("Derivation::Message", ARRAYS_MESSAGE)
DERIVATION_MODEL = Model([*ARRAYS_MODEL.types, DERIVATION_MESSAGE])

# Unproven message "P::M" with a single opaque field F; the link from INITIAL
# to F specifies length=Number(16).
VALID_MESSAGE = UnprovenMessage(
"P::M",
[
Link(INITIAL, Field("F"), length=Number(16)),
Link(Field("F"), FINAL),
],
{Field("F"): Opaque()},
)

# Same structure as VALID_MESSAGE, but the link from INITIAL to F carries no length.
# NOTE(review): presumably the missing length is what makes verification fail;
# the parser test expects a RecordFluxError for this message - confirm.
INVALID_MESSAGE = UnprovenMessage(
"P::M",
[
Link(INITIAL, Field("F")),
Link(Field("F"), FINAL),
],
{Field("F"): Opaque()},
)

MODULAR_INTEGER = ModularInteger("P::Modular", Number(256))
RANGE_INTEGER = RangeInteger("P::Range", Number(1), Number(100), Number(8))
ENUMERATION = Enumeration(
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/parser/cache_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from pathlib import Path

from rflx.parser import cache
from tests.models import TLV_MESSAGE


def test_init(tmp_path: Path) -> None:
    """An enabled cache creates its directory and the verification file."""
    cache_dir = tmp_path / "Test"
    cache.CACHE_DIR = cache_dir
    cache.Cache()
    assert cache_dir.is_dir()
    assert (cache_dir / cache.VERIFICATION_FILE).is_file()


def test_init_existing(tmp_path: Path) -> None:
    """An enabled cache can be created when the verification file already exists."""
    cache.CACHE_DIR = tmp_path
    verification_file = tmp_path / cache.VERIFICATION_FILE
    # Mode "x" fails if the file is already present, i.e. the test creates it itself.
    with open(verification_file, "x") as existing:
        existing.write("{}")
    cache.Cache()
    assert tmp_path.is_dir()
    assert verification_file.is_file()


def test_init_disabled(tmp_path: Path) -> None:
    """A disabled cache creates neither the directory nor the verification file."""
    cache_dir = tmp_path / "Test"
    cache.CACHE_DIR = cache_dir
    cache.Cache(enabled=False)
    assert not cache_dir.is_dir()
    assert not (cache_dir / cache.VERIFICATION_FILE).is_file()


def test_verified(tmp_path: Path) -> None:
    """A message is reported as verified once added; adding it again changes nothing."""
    cache.CACHE_DIR = tmp_path
    verification_cache = cache.Cache()
    assert not verification_cache.is_verified(TLV_MESSAGE)
    # Add the message twice; it must be reported as verified after each addition.
    for _ in range(2):
        verification_cache.add_verified(TLV_MESSAGE)
        assert verification_cache.is_verified(TLV_MESSAGE)


def test_verified_disabled(tmp_path: Path) -> None:
    """A disabled cache never reports a message as verified, even after adding it."""
    cache.CACHE_DIR = tmp_path
    disabled_cache = cache.Cache(enabled=False)
    assert not disabled_cache.is_verified(TLV_MESSAGE)
    disabled_cache.add_verified(TLV_MESSAGE)
    assert not disabled_cache.is_verified(TLV_MESSAGE)
37 changes: 37 additions & 0 deletions tests/unit/parser/parser_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from pathlib import Path

import pytest

from rflx.error import RecordFluxError
from rflx.parser import cache, parser
from tests.models import INVALID_MESSAGE, VALID_MESSAGE

# Relative paths, i.e. resolved against the working directory the tests run in.
# NOTE(review): TEST_DIR is not referenced by the tests visible here and has the
# same value as SPEC_DIR - confirm whether it is still needed.
TEST_DIR = Path("specs")
SPEC_DIR = Path("specs")


def test_create_model() -> None:
    """Parse the TLV specification and create a model with the default parser."""
    tlv_parser = parser.Parser()
    tlv_parser.parse(SPEC_DIR / "tlv.rflx")
    tlv_parser.create_model()


def test_create_model_cached(tmp_path: Path) -> None:
    """Parse the TLV specification and create a model with the verification cache enabled."""
    # Redirect the cache to a temporary directory so that the test neither
    # depends on nor modifies the verification cache in the user's home directory.
    cache.CACHE_DIR = tmp_path
    p = parser.Parser(cached=True)
    p.parse(SPEC_DIR / "tlv.rflx")
    p.create_model()


def test_create_proven_message(tmp_path: Path) -> None:
    """Proving a valid message yields a result and stores the message in the cache."""
    cache.CACHE_DIR = tmp_path
    verification_cache = cache.Cache()
    proven = parser.create_proven_message(VALID_MESSAGE, False, verification_cache)
    assert proven
    assert verification_cache.is_verified(VALID_MESSAGE)


def test_create_proven_message_error(tmp_path: Path) -> None:
    """Proving an invalid message raises an error and leaves the cache untouched."""
    cache.CACHE_DIR = tmp_path
    verification_cache = cache.Cache()
    with pytest.raises(RecordFluxError):
        parser.create_proven_message(INVALID_MESSAGE, False, verification_cache)
    # Checked after the `with` block; inside it the assertion would never run.
    assert not verification_cache.is_verified(INVALID_MESSAGE)

0 comments on commit 3b2efe9

Please sign in to comment.