Skip to content

Commit

Permalink
Support passing in Graph.artifacts
Browse files Browse the repository at this point in the history
Signed-off-by: Jacob Hayes <jacob.r.hayes@gmail.com>
  • Loading branch information
JacobHayes committed Mar 19, 2023
1 parent d708f3b commit 78ba386
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 11 deletions.
16 changes: 9 additions & 7 deletions src/arti/graphs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, Union, cast

from pydantic import Field, PrivateAttr
from pydantic import Field, PrivateAttr, validator

import arti
from arti import io
Expand Down Expand Up @@ -105,6 +105,7 @@ class Graph(Model):
"""Graph stores a web of Artifacts connected by Producers."""

name: str
artifacts: ArtifactBox = Field(default_factory=lambda: ArtifactBox(**BOX_KWARGS[SEALED]))
# The Backend *itself* should not affect the results of a Graph build, though the contents
# certainly may (eg: stored annotations), so we avoid serializing it. This also prevents
# embedding any credentials.
Expand All @@ -113,9 +114,13 @@ class Graph(Model):

# Graph starts off sealed, but is opened within a `with Graph(...)` context
_status: Optional[bool] = PrivateAttr(None)
_artifacts: ArtifactBox = PrivateAttr(default_factory=lambda: ArtifactBox(**BOX_KWARGS[SEALED]))
_artifact_to_key: frozendict[Artifact, str] = PrivateAttr(frozendict())

@validator("artifacts")
@classmethod
def _convert_artifacts(cls, artifacts: ArtifactBox) -> ArtifactBox:
    """Rebuild the incoming `artifacts` value as an ArtifactBox using the SEALED box settings."""
    sealed_kwargs = BOX_KWARGS[SEALED]
    return ArtifactBox(artifacts, **sealed_kwargs)

def __enter__(self) -> Graph:
if arti.context.graph is not None:
raise ValueError(f"Another graph is being defined: {arti.context.graph}")
Expand All @@ -135,16 +140,13 @@ def __exit__(
TopologicalSorter(self.dependencies).prepare()

def _toggle(self, status: bool) -> None:
# The Graph object is "frozen", so we must bypass the assignment checks.
object.__setattr__(self, "artifacts", ArtifactBox(self.artifacts, **BOX_KWARGS[status]))
self._status = status
self._artifacts = ArtifactBox(self.artifacts, **BOX_KWARGS[status])
self._artifact_to_key = frozendict(
{artifact: key for key, artifact in self.artifacts.walk()}
)

@property
def artifacts(self) -> ArtifactBox:
return self._artifacts

@property
def artifact_to_key(self) -> frozendict[Artifact, str]:
    """Return the stored `_artifact_to_key` mapping (Artifact -> key)."""
    return self._artifact_to_key
Expand Down
42 changes: 39 additions & 3 deletions src/arti/internal/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from functools import cached_property, partial
Expand All @@ -9,10 +11,12 @@
Literal,
Optional,
TypeVar,
Union,
get_args,
get_origin,
)

from box import Box
from pydantic import BaseModel, Extra, root_validator, validator
from pydantic.fields import ModelField, Undefined
from pydantic.json import pydantic_encoder as pydantic_json_encoder
Expand All @@ -21,6 +25,8 @@
from arti.internal.utils import class_name, frozendict

if TYPE_CHECKING:
from pydantic.typing import AbstractSetIntStr, MappingIntStrAny

from arti.fingerprints import Fingerprint
from arti.types import Type

Expand Down Expand Up @@ -211,7 +217,7 @@ def _fingerprint_json_encoder(obj: Any, encoder: Any = pydantic_json_encoder) ->
return encoder(obj)

@property
def fingerprint(self) -> "Fingerprint":
def fingerprint(self) -> Fingerprint:
from arti.fingerprints import Fingerprint

# `.json` cannot be used, even with a custom encoder, because it calls `.dict`, which
Expand All @@ -233,6 +239,36 @@ def fingerprint(self) -> "Fingerprint":
)
return Fingerprint.from_string(f"{self._class_key_}:{json_repr}")

@classmethod
def _get_value(
    cls,
    v: Any,
    to_dict: bool,
    by_alias: bool,
    include: Optional[Union[AbstractSetIntStr, MappingIntStrAny]],
    exclude: Optional[Union[AbstractSetIntStr, MappingIntStrAny]],
    exclude_unset: bool,
    exclude_defaults: bool,
    exclude_none: bool,
) -> Any:
    """Convert `v` via the parent implementation, re-wrapping Box values in their own class.

    All arguments are forwarded unchanged to the parent `_get_value`. When `v` is a Box, the
    converted value is passed back through `v`'s class together with `v`'s box configuration;
    any other value is returned exactly as the parent produced it.
    """
    converted = super()._get_value(
        v,
        to_dict=to_dict,
        by_alias=by_alias,
        include=include,
        exclude=exclude,
        exclude_unset=exclude_unset,
        exclude_defaults=exclude_defaults,
        exclude_none=exclude_none,
    )
    if not isinstance(v, Box):
        return converted
    # Copying dict subclasses doesn't preserve the subclass[1]. Further, we have extra Box
    # configuration (namely frozen_box=True) we need to preserve.
    #
    # 1: https://github.com/pydantic/pydantic/issues/5225
    box_cls = v.__class__
    box_config = v._Box__box_config()
    return box_cls(converted, **box_config)

# Filter out non-fields from ._iter (and thus .dict, .json, etc), such as `@cached_property`
# after access (which just gets cached in .__dict__).
def _iter(self, *args: Any, **kwargs: Any) -> Generator[tuple[str, Any], None, None]:
Expand All @@ -242,8 +278,8 @@ def _iter(self, *args: Any, **kwargs: Any) -> Generator[tuple[str, Any], None, N

@classmethod
def _pydantic_type_system_post_field_conversion_hook_(
cls, type_: "Type", *, name: str, required: bool
) -> "Type":
cls, type_: Type, *, name: str, required: bool
) -> Type:
return type_


Expand Down
22 changes: 21 additions & 1 deletion tests/arti/graphs/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from arti import Artifact, CompositeKey, Fingerprint, Graph, GraphSnapshot, View, producer
from arti.backends.memory import MemoryBackend
from arti.executors.local import LocalExecutor
from arti.graphs import ArtifactBox
from arti.internal.utils import frozendict
from arti.storage.literal import StringLiteral
from arti.storage.local import LocalFile, LocalFilePartition
Expand Down Expand Up @@ -43,13 +44,31 @@ def test_Graph(graph: Graph) -> None:
assert graph.artifacts.c.a.storage.includes_input_fingerprint_template
assert graph.artifacts.c.b.storage.includes_input_fingerprint_template
# NOTE: We may need to occasionally update this, but ensure graph.backend is not included.
assert graph.fingerprint == Fingerprint.from_int(4705012302096346878)
assert graph.fingerprint == Fingerprint.from_int(3139813064524317498)


def test_Graph_pickle(graph: Graph) -> None:
    """A Graph must compare equal to itself after a pickle round trip."""
    restored = pickle.loads(pickle.dumps(graph))
    assert graph == restored


def test_Graph_copy(graph: Graph) -> None:
    """Copies of a Graph must match the original in equality, artifacts, fingerprint, and hash."""
    # There are a few edge cases in pydantic when copying a model with a mapping subclass
    # field[1], so double check things are ok under various conditions.
    #
    # 1: https://github.com/pydantic/pydantic/issues/5225
    copies = [
        graph.copy(),
        graph.copy(exclude={"backend"}),
        graph.copy(include=set(graph.__fields__)),
    ]
    for copied in copies:
        assert graph == copied
        assert isinstance(copied.artifacts, ArtifactBox)
        assert graph.artifacts == copied.artifacts
        assert graph.fingerprint == copied.fingerprint
        assert hash(graph) == hash(copied)
        assert hash(copied) == copied.fingerprint.key


def test_Graph_literals(tmp_path: Path) -> None:
n_add_runs = 0

Expand Down Expand Up @@ -147,6 +166,7 @@ def test_Graph_snapshot() -> None:
# Ensure order independence
assert s.id == Fingerprint.combine(*reversed(id_components))

assert g.backend is s.backend # Ensure the backend is not copied
# Ensure metadata is written
with g.backend.connect() as conn:
assert conn.read_graph(g.name, g.fingerprint) == g
Expand Down

0 comments on commit 78ba386

Please sign in to comment.